在项目中遇到一个需求：在实现 HTTP 反向代理的同时，还要作为网关对请求进行鉴权。我选择了 Smiley's HTTP Proxy Servlet 来实现。

原始ProxyServlet类

package org.mitre.dsmiley.httpproxy;

import javax.servlet.http.HttpServlet;

/**
 * An HTTP reverse proxy/gateway servlet. It is designed to be extended for customization
 * if desired. Most of the work is handled by
 * <a href="http://hc.apache.org/httpcomponents-client-ga/">Apache HttpClient</a>.
 * <p>
 *   There are alternatives to a servlet based proxy such as Apache mod_proxy if that is available to you. However
 *   this servlet is easily customizable by Java, secure-able by your web application's security (e.g. spring-security),
 *   portable across servlet engines, and is embeddable into another web application.
 * </p>
 * <p>
 *   Inspiration: http://httpd.apache.org/docs/2.0/mod/mod_proxy.html
 * </p>
 *
 * <p>
 *   NOTE: this is an abbreviated excerpt; method bodies and the remaining members are elided.
 * </p>
 *
 * @author David Smiley dsmiley@apache.org
 */
@SuppressWarnings({"deprecation", "serial", "WeakerAccess"})
public class ProxyServlet extends HttpServlet {
    
  /**
   * Entry point for every request handled by this servlet (overrides {@link HttpServlet#service}).
   * The implementation is elided in this excerpt; subclasses may override it and delegate to
   * {@code super.service(...)}.
   */
  @Override
  protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
      throws ServletException, IOException {
    //...
  }
    // ...
}

继承这个类，并重写其中的部分方法：

自定义异常处理

/**
 * Overrides the proxy's `service` so that a [ZZException] raised while the proxy request is being
 * prepared is handled differently. Such exceptions come, for example, from the query-string / URL
 * rewrite hooks or from `checkAuth`. Instead of propagating to the servlet container, they are turned
 * into a JSON error response by [handleZZException].
 *
 * The per-request thread-local state ([tkTL], [layerIdTL]) is always cleared in `finally`. Servlet
 * containers serve requests from a thread pool, so a value left behind by one request would be
 * visible to the next request handled by the same thread. `tkTL` is only set when a `tk` parameter
 * is present, so a request without a token could be authorized with the previous caller's token.
 */
override fun service(servletRequest: HttpServletRequest?, servletResponse: HttpServletResponse?) {
    logger.info("========= 重写请求 =========")
    try {
        super.service(servletRequest, servletResponse)
    } catch (e: ZZException) {
        logger.error("ZZException", e)
        // The servlet container never passes null request/response objects to service().
        handleZZException(e, servletRequest!!, servletResponse!!)
    } finally {
        // NOTE(review): assumes tkTL / layerIdTL are ThreadLocal instances (used via get()/set()) — confirm.
        tkTL.remove()
        layerIdTL.remove()
    }
}

/**
 * Reports a [ZZException] raised inside `service` to the client.
 *
 * The response is a `403 Forbidden` JSON body of the form `{"code": <ex.code>, "msg": <ex.message>}`.
 *
 * If the response has already been committed, neither the status nor the body can be replaced
 * anymore, so the error is only logged in that case.
 *
 * @param ex the exception to report
 * @param request the current request (only used for logging)
 * @param response the response to write the JSON error to
 */
private fun handleZZException(ex: ZZException, request: HttpServletRequest, response: HttpServletResponse) {
    logger.info("RequestURI:{}", request.requestURI)
    if (response.isCommitted) {
        // Once committed, setting the status is silently ignored, and getWriter() throws
        // IllegalStateException if the output stream has already been obtained.
        logger.warn("Response already committed, cannot report ZZException for {}", request.requestURI)
        return
    }

    // Authentication failures are always reported as HTTP 403 Forbidden.
    response.status = HttpStatus.FORBIDDEN.value()
    response.contentType = "application/json;charset=UTF-8"

    // Error payload: business error code plus human-readable message.
    val errorDetails = mapOf(
        "code" to ex.code,
        "msg" to ex.message,
    )
    val json = jacksonObjectMapper().writeValueAsString(errorDetails)
    response.writer.write(json)
}

重写请求的query参数

/**
 * Rewrites the query string that is forwarded to the target server.
 *
 * Each non-blank `&`-separated pair is handed to [processQueryParam]. That function either keeps the
 * pair (appending `pair&` to the buffer) or consumes it (e.g. the `tk` token). The trailing `&` is
 * stripped from the result.
 *
 * @param servletRequest the incoming request (unused here; part of the overridden signature)
 * @param queryString the query string without the leading `?`, possibly null
 * @return the rewritten query string, or `""` when [queryString] is null or blank
 */
override fun rewriteQueryStringFromRequest(servletRequest: HttpServletRequest, queryString: String?): String {
    if (queryString.isNullOrBlank()) return ""
    logger.debug("请求参数: $queryString")

    val kept = StringBuffer(queryString.length)
    for (pair in queryString.split('&')) {
        if (pair.isNotBlank()) {
            processQueryParam(pair, kept)
        }
    }
    // Kept pairs are appended as "pair&", so drop the final separator.
    return kept.toString().removeSuffix("&")
}

/**
 * Processes a single `name=value` pair of the incoming query string.
 *
 * The `tk` (access token) parameter is matched case-insensitively on its *name*. It is stored in
 * [tkTL] and removed from the forwarded query. Every other pair is appended to [queryStrBuffer],
 * followed by `&`.
 *
 * @param param one raw `name=value` pair
 * @param queryStrBuffer accumulator for the pairs that are forwarded to the target
 * @throws ZZException with code "403" when the `tk` parameter is present but its value is blank
 */
protected open fun processQueryParam(param: String, queryStrBuffer: StringBuffer) {
    val tokenPrefix = "tk="
    // Only a parameter whose name is exactly "tk" counts. A `contains("tk=")` check would also match
    // names such as "stk"/"atk" or values containing "tk=". It would then extract a bogus token with
    // substring(3) and silently drop that parameter from the forwarded query.
    if (param.startsWith(tokenPrefix, ignoreCase = true)) {
        val tkStr = param.substring(tokenPrefix.length)
        if (StrUtil.isBlank(tkStr)) {
            throw ZZException("403", "令牌不能为空")
        }
        tkTL.set(tkStr)
        return
    }
    queryStrBuffer.append("$param&")
}

重写请求的url

/**
 * Returns the URL produced by the superclass implementation of `rewriteUrlFromRequest`, bypassing
 * this class's override (and therefore its authentication). It is protected, so subclasses can use
 * it as a fallback.
 */
protected fun superRewriteUrlFromRequest(servletRequest: HttpServletRequest): String {
    return super.rewriteUrlFromRequest(servletRequest)
}

/**
 * Builds the URL the request is forwarded to, authenticating the caller on the way.
 *
 * 1. Appends the target URI plus the URL-encoded rewritten path info.
 * 2. Percent-decodes the incoming query string, splits off an optional `#fragment`, and passes the
 *    rest through [rewriteQueryStringFromRequest], which extracts the `tk` token.
 * 3. Verifies the token, client IP and layer id with [checkAuth] and records them via [processLogInfo].
 * 4. Appends the (re-encoded) rewritten query and, if enabled, the fragment.
 *
 * @return the absolute URL to forward the request to
 * @throws ZZException with code "403" when no token / layer id is available or [checkAuth] rejects the request
 */
override fun rewriteUrlFromRequest(servletRequest: HttpServletRequest): String {
    logger.debug("========= 重写请求url =========")
    val uri = StringBuilder(500)

    // Target URI followed by the (encoded) path below the servlet mapping.
    uri.append(getTargetUri(servletRequest))
    uri.append(URLUtil.encode(rewritePathInfoFromRequest(servletRequest)))

    // Split "query#fragment" on the decoded query string.
    // NOTE(review): decoding *before* the query is split on '&' means an encoded "%26" inside a value
    // is treated as a parameter separator; decoding per parameter would avoid that.
    val decodedQuery: String? = URLUtil.decode(servletRequest.queryString)
    val fragIdx = decodedQuery?.indexOf('#') ?: -1
    val (queryPart, fragment) = if (decodedQuery != null && fragIdx >= 0) {
        decodedQuery.substring(0, fragIdx) to decodedQuery.substring(fragIdx + 1)
    } else {
        decodedQuery to null
    }

    val queryString = rewriteQueryStringFromRequest(servletRequest, queryPart)

    val ip = getIpAddr(servletRequest)

    // A missing token / layer id must be rejected with 403. Passing null into checkAuth's non-null
    // parameters would instead fail with a NullPointerException (HTTP 500).
    val tk = tkTL.get() ?: throw ZZException("403", "令牌不能为空")
    val layerId = layerIdTL.get() ?: throw ZZException("403", "图层ID不能为空")
    val (appInfo, layerInfo) = checkAuth(tk, ip, layerId)
    processLogInfo(appInfo, layerInfo)

    if (queryString.isNotEmpty()) {
        uri.append('?')
        // The query was decoded above. Re-encode it (leaving '%' untouched, as ProxyServlet does for
        // the fragment); otherwise spaces, quotes, braces, etc. make the outgoing request URI invalid.
        uri.append(encodeUriQuery(queryString, false))
    }
    if (doSendUrlFragment && fragment != null) {
        uri.append('#')
        uri.append(encodeUriQuery(fragment, false))
    }
    logger.info("重写后的url: $uri")
    return uri.toString()
}

令牌鉴权处理

/**
 * Authenticates a proxied request by token, client IP and requested layer.
 *
 * 1. Looks [tk] up in the Redis map `"$platformName:platform:app:token"`.
 * 2. Passes the app's `appIpWhitelist` and [ip] to `checkIp`.
 * 3. Finds the entry of the app's `layers` whose `lRealLayerId` equals [layerId].
 *
 * @param tk access token taken from the `tk` query parameter
 * @param ip client IP address of the incoming request
 * @param layerId id of the requested layer
 * @return the app info map paired with the matched layer info map
 * @throws ZZException with code "403" when the token is unknown or not granted access to [layerId]
 */
protected open fun checkAuth(
    tk: String,
    ip: String,
    layerId: String
): Pair<Map<String, Any>, Map<String, Any?>> {
    // Never write the full credential to the log; a short prefix is enough to correlate requests.
    val maskedTk = if (tk.length > 4) tk.take(4) + "****" else "****"
    logger.debug("tk:$maskedTk\nlayerId:$layerId")
    // NOTE(review): this appears to load the whole token map on every request. If redisTool offers a
    // single-field lookup, prefer it — confirm against redisTool's API.
    val tokenMap = redisTool.getMap<String, Map<String, Any?>>("$platformName:platform:app:token")
    val appInfo = tokenMap.getMap(tk)
    if (appInfo.isEmpty()) {
        throw ZZException("403", "令牌无效")
    }
    // Check that the client IP is in the app's whitelist.
    checkIp(appInfo.getString("appIpWhitelist"), ip)

    val layers: List<Map<String, Any?>> = appInfo.getList("layers")
    val layerInfo = layers.find { StrUtil.equals(it.getString("lRealLayerId"), layerId) }
        ?: throw ZZException("403", "令牌与图层不匹配")
    logger.info("图层信息: $layerInfo")
    return Pair(appInfo, layerInfo)
}

注册Servlet（通过 ServletRegistrationBean）

/**
 * Registers the OGC standard-service proxy (WMTS).
 *
 * Every request under `/geoserver/*` is handled by [OGCServerProxyServlet] and forwarded to
 * `http://127.0.0.1:8080/geowebcache/service`. The servlet's `P_LOG` init parameter is set to "true".
 */
@Bean
fun ogcServletRegistrationBean(proxyServlet: OGCServerProxyServlet): ServletRegistrationBean<OGCServerProxyServlet> =
    ServletRegistrationBean(proxyServlet, "/geoserver/*").apply {
        addInitParameter(ProxyServlet.P_TARGET_URI, "http://127.0.0.1:8080/geowebcache/service")
        addInitParameter(ProxyServlet.P_LOG, "true")
    }