项目中需要实现一个 HTTP 反向代理，同时让它充当网关，对转发的请求进行鉴权。我选择基于 Smiley's HTTP Proxy Servlet 来实现。
原始的 ProxyServlet 类（节选）
package org.mitre.dsmiley.httpproxy;
import javax.servlet.http.HttpServlet;
/**
* An HTTP reverse proxy/gateway servlet. It is designed to be extended for customization
* if desired. Most of the work is handled by
* <a href="http://hc.apache.org/httpcomponents-client-ga/">Apache HttpClient</a>.
* <p>
* There are alternatives to a servlet based proxy such as Apache mod_proxy if that is available to you. However
* this servlet is easily customizable by Java, secure-able by your web application's security (e.g. spring-security),
* portable across servlet engines, and is embeddable into another web application.
* </p>
* <p>
* Inspiration: http://httpd.apache.org/docs/2.0/mod/mod_proxy.html
* </p>
*
* @author David Smiley dsmiley@apache.org
*/
// Excerpt only: method bodies and the remaining members of the library class are omitted ("...").
@SuppressWarnings({"deprecation", "serial", "WeakerAccess"})
public class ProxyServlet extends HttpServlet {
// Per-request entry point. The Kotlin subclass shown later overrides it to catch its own
// ZZException and turn it into an error response.
@Override
protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
throws ServletException, IOException {
//... (library implementation omitted in this excerpt)
}
// ... (other members omitted, including overridable hooks such as rewriteUrlFromRequest and
// rewriteQueryStringFromRequest that the subclass below overrides)
}
继承这个类，并重写其中的几个方法：
自定义异常处理
/**
 * Entry point for every proxied request, overridden to handle this gateway's own errors.
 *
 * A [ZZException] thrown while the parent servlet processes the request (for example by the token check
 * performed during URL rewriting) is logged and written back to the client by [handleZZException]
 * instead of propagating to the container. Any other exception propagates unchanged.
 *
 * The per-request thread-locals (`tkTL`, `layerIdTL`) are cleared in `finally`. Servlet containers reuse
 * worker threads, so a value left behind would otherwise be seen by a later request on the same thread;
 * a request that sends no `tk` parameter would then be authenticated with the previous caller's token.
 */
override fun service(servletRequest: HttpServletRequest?, servletResponse: HttpServletResponse?) {
    logger.info("========= 重写请求 =========")
    try {
        super.service(servletRequest, servletResponse)
    } catch (e: ZZException) {
        logger.error("ZZException", e)
        // The nullable types only come from the Java signature; the container always passes both objects.
        handleZZException(e, servletRequest!!, servletResponse!!)
    } finally {
        tkTL.remove()
        layerIdTL.remove()
    }
}
/**
 * Reports a [ZZException] to the client as a JSON error document.
 *
 * Logs the request URI, then sets status 403 (Forbidden) and content type
 * `application/json;charset=UTF-8`, and writes a Jackson-serialized body of the form
 * `{"code": <ex.code>, "msg": <ex.message>}`.
 *
 * @param ex the exception to report
 * @param request the request that failed (used only for logging)
 * @param response the response the error document is written to
 */
private fun handleZZException(ex: ZZException, request: HttpServletRequest, response: HttpServletResponse) {
    logger.info("RequestURI:{}", request.requestURI)
    response.status = HttpStatus.FORBIDDEN.value()
    response.contentType = "application/json;charset=UTF-8"
    val payloadJson = jacksonObjectMapper().writeValueAsString(
        mapOf("code" to ex.code, "msg" to ex.message)
    )
    response.writer.write(payloadJson)
}
重写请求的query参数
/**
 * Rewrites the query string that is forwarded upstream.
 *
 * Every non-blank `&`-separated parameter is handed to [processQueryParam], which decides whether to
 * append it (followed by `&`) to the output buffer. A single trailing `&` is dropped from the result.
 *
 * @return the rewritten query string, or `""` when [queryString] is null or blank
 */
override fun rewriteQueryStringFromRequest(servletRequest: HttpServletRequest, queryString: String?): String {
    if (queryString.isNullOrBlank()) return ""
    logger.debug("请求参数: $queryString")
    val kept = StringBuffer(queryString.length)
    for (param in queryString.split('&')) {
        if (param.isNotBlank()) {
            processQueryParam(param, kept)
        }
    }
    return kept.toString().removeSuffix("&")
}
/**
 * Handles a single `name=value` parameter of the query string.
 *
 * The `tk` parameter (name matched case-insensitively) is not forwarded upstream. Its value is stored
 * in `tkTL` for the later token check. Every other parameter is appended to [queryStrBuffer],
 * followed by `&`.
 *
 * @throws ZZException with code "403" when the `tk` parameter is present but its value is blank
 */
protected open fun processQueryParam(param: String, queryStrBuffer: StringBuffer) {
    // Match the parameter *name*. The former `contains("tk=")` test also hit names such as `mytk=...`
    // and then sliced a bogus token out of them with substring(3), dropping the real parameter.
    if (param.startsWith("tk=", ignoreCase = true)) {
        val tkStr = param.substring("tk=".length)
        if (StrUtil.isBlank(tkStr)) {
            throw ZZException("403", "令牌不能为空")
        }
        tkTL.set(tkStr)
        return
    }
    queryStrBuffer.append(param).append('&')
}
重写请求的url
/**
 * Calls the parent [ProxyServlet.rewriteUrlFromRequest] directly, bypassing this class's override
 * and therefore its token / IP / layer check.
 */
protected fun superRewriteUrlFromRequest(servletRequest: HttpServletRequest): String {
    val parentUrl: String = super.rewriteUrlFromRequest(servletRequest)
    return parentUrl
}
/**
 * Builds the upstream URL for [servletRequest] and authenticates the request along the way.
 *
 * Steps:
 * 1. Append the target URI and the encoded path info (skipped when there is none).
 * 2. URL-decode the incoming query string and split off any `#fragment`.
 * 3. Run the query through [rewriteQueryStringFromRequest], which removes the `tk` parameter into `tkTL`.
 * 4. Check token, client IP and layer via [checkAuth], and record the result via `processLogInfo`.
 * 5. Re-encode and append the remaining query string (and the fragment when `doSendUrlFragment` is set).
 *
 * @throws ZZException when the request carries no `tk` parameter, or when [checkAuth] rejects it
 */
override fun rewriteUrlFromRequest(servletRequest: HttpServletRequest): String {
    logger.debug("========= 重写请求url =========")
    val uri = StringBuilder(500)
    uri.append(getTargetUri(servletRequest))
    // Path info is null when the request hits the servlet mapping root exactly; appending the encoded
    // null would put the literal text "null" into the upstream URL.
    val pathInfo: String? = rewritePathInfoFromRequest(servletRequest)
    if (pathInfo != null) {
        uri.append(URLUtil.encode(pathInfo))
    }
    // The query string is fully decoded here, so it must be re-encoded before being appended below.
    var queryString: String? = URLUtil.decode(servletRequest.queryString)
    var fragment: String? = null
    val fragIdx = queryString?.indexOf('#') ?: -1
    if (queryString != null && fragIdx >= 0) {
        fragment = queryString.substring(fragIdx + 1)
        queryString = queryString.substring(0, fragIdx)
    }
    val rewrittenQuery = rewriteQueryStringFromRequest(servletRequest, queryString)
    val ip = getIpAddr(servletRequest)
    // Without a tk parameter, tkTL is never set. Reject with the same 403 used for a blank token
    // instead of letting the null reach checkAuth's non-null parameter (NPE -> HTTP 500).
    val tk: String = tkTL.get() ?: throw ZZException("403", "令牌不能为空")
    val (appInfo, layerInfo) = checkAuth(tk, ip, layerIdTL.get())
    processLogInfo(appInfo, layerInfo)
    if (rewrittenQuery.isNotEmpty()) {
        uri.append('?')
        // encodePercent = true: after decoding, every remaining '%' is a literal character.
        uri.append(encodeUriQuery(rewrittenQuery, true))
    }
    if (doSendUrlFragment && fragment != null) {
        uri.append('#')
        uri.append(encodeUriQuery(fragment, true))
    }
    logger.info("重写后的url: $uri")
    return uri.toString()
}
令牌鉴权处理
/**
 * Authenticates a request by token, client IP and requested layer.
 *
 * Reads the token map stored under `<platformName>:platform:app:token` via `redisTool` and looks up
 * [tk] in it. It then runs `checkIp` against the app's `appIpWhitelist` entry, and finally searches the
 * app's `layers` for an entry whose `lRealLayerId` equals [layerId].
 *
 * @return the app info and the matching layer info
 * @throws ZZException with code "403" when the token is unknown or does not grant [layerId]
 *   (`checkIp` may also throw; presumably for an IP outside the whitelist — confirm in its definition)
 */
protected open fun checkAuth(
    tk: String,
    ip: String,
    layerId: String
): Pair<Map<String, Any>, Map<String, Any?>> {
    logger.debug("tk:$tk\nlayerId:$layerId")
    val appInfo = redisTool
        .getMap<String, Map<String, Any?>>("$platformName:platform:app:token")
        .getMap(tk)
        .takeUnless { it.isEmpty() }
        ?: throw ZZException("403", "令牌无效")
    checkIp(appInfo.getString("appIpWhitelist"), ip)
    val layers: List<Map<String, Any?>> = appInfo.getList("layers")
    val layerInfo = layers.firstOrNull { layer -> StrUtil.equals(layer.getString("lRealLayerId"), layerId) }
        ?: throw ZZException("403", "令牌与图层不匹配")
    logger.info("图层信息: $layerInfo")
    return appInfo to layerInfo
}
注入Servlet
/**
 * Registers the OGC standard-service proxy (WMTS) for the URL pattern `/geoserver/*`.
 *
 * The proxy servlet's [ProxyServlet.P_TARGET_URI] init parameter points at
 * `http://127.0.0.1:8080/geowebcache/service`, and [ProxyServlet.P_LOG] is set to `"true"`.
 */
@Bean
fun ogcServletRegistrationBean(proxyServlet: OGCServerProxyServlet): ServletRegistrationBean<OGCServerProxyServlet> =
    ServletRegistrationBean(proxyServlet, "/geoserver/*").apply {
        addInitParameter(ProxyServlet.P_TARGET_URI, "http://127.0.0.1:8080/geowebcache/service")
        addInitParameter(ProxyServlet.P_LOG, "true")
    }