添加ip控制
This commit is contained in:
parent
be64000bff
commit
ea38f85fb7
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -47,10 +48,15 @@ var Gzip_Response Directive = func(args ...string) Middleware {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var DRecordAccess Directive = func(args ...string) Middleware {
|
var DRecordAccess Directive = func(args ...string) Middleware {
|
||||||
|
serverName := args[0]
|
||||||
return func(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
return func(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||||||
l := gologger.GetLogger("Directive")
|
l := gologger.GetLogger("Directive")
|
||||||
l.Debug("Record-Access")
|
l.Debug("Record-Access")
|
||||||
model.Incr(r.URL.Host)
|
model.Incr(r.URL.Host)
|
||||||
|
// put serverName to request context
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, RequestCtxKey("serverName"), serverName)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -60,6 +66,8 @@ var JWTDirective Directive = func(args ...string) Middleware {
|
||||||
var BasicAuthDirective Directive = func(args ...string) Middleware {
|
var BasicAuthDirective Directive = func(args ...string) Middleware {
|
||||||
return BasicAuth
|
return BasicAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 在DirectiveMap中注册新指令
|
||||||
var DirectiveMap = map[string]Directive{
|
var DirectiveMap = map[string]Directive{
|
||||||
"Set-Header": Set_Header,
|
"Set-Header": Set_Header,
|
||||||
"Add-Header": Add_Header,
|
"Add-Header": Add_Header,
|
||||||
|
|
|
@ -178,6 +178,48 @@ func Parse[T any](w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IPAccessControl 中间件实现IP访问控制
|
||||||
|
func IPAccessControl(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||||||
|
|
||||||
|
// get serverName from request context
|
||||||
|
ctx := r.Context()
|
||||||
|
serverName, ok := ctx.Value(RequestCtxKey("serverName")).(string)
|
||||||
|
if !ok {
|
||||||
|
serverName = ""
|
||||||
|
}
|
||||||
|
config := model.GetServerConfig(serverName)
|
||||||
|
if config != nil {
|
||||||
|
allowedIPs := config.AllowIPs
|
||||||
|
deniedIPs := config.DenyIPs
|
||||||
|
clientIP := strings.Split(r.RemoteAddr, ":")[0] // 获取客户端IP
|
||||||
|
|
||||||
|
// 首先检查是否被禁止(Deny规则优先级高于Allow)
|
||||||
|
for _, ip := range deniedIPs {
|
||||||
|
if ip == clientIP {
|
||||||
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 然后检查是否被允许
|
||||||
|
if len(allowedIPs) > 0 {
|
||||||
|
allowed := false
|
||||||
|
for _, ip := range allowedIPs {
|
||||||
|
if ip == clientIP {
|
||||||
|
allowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
func Done(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
func Done(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
|
@ -308,18 +308,22 @@ func NewServeMux(c *model.HttpServerConfig) *ServerMux {
|
||||||
wrappedHandler: make(map[string]http.Handler),
|
wrappedHandler: make(map[string]http.Handler),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.AddDirective("Record-Access")
|
s.AddDirective("Record-Access " + c.Name)
|
||||||
if c.AuthType == "jwt" {
|
// 添加IP访问控制中间件
|
||||||
s.AddDirective("Jwt-Auth")
|
if len(c.AllowIPs) > 0 || len(c.DenyIPs) > 0 {
|
||||||
}
|
s.directiveHandlers.Add(IPAccessControl)
|
||||||
if c.AuthType == "basic" {
|
|
||||||
s.AddDirective("Basic-Auth")
|
|
||||||
}
|
}
|
||||||
// 遍历配置中的所有指令
|
// 遍历配置中的所有指令
|
||||||
for _, directive := range c.Directives {
|
for _, directive := range c.Directives {
|
||||||
// 将指令添加到 ServerMux 的指令处理中间件链中
|
// 将指令添加到 ServerMux 的指令处理中间件链中
|
||||||
s.AddDirective(string(directive))
|
s.AddDirective(string(directive))
|
||||||
}
|
}
|
||||||
|
if c.AuthType == "jwt" {
|
||||||
|
s.AddDirective("Jwt-Auth")
|
||||||
|
}
|
||||||
|
if c.AuthType == "basic" {
|
||||||
|
s.AddDirective("Basic-Auth")
|
||||||
|
}
|
||||||
// 遍历配置中的所有 HTTP 路径
|
// 遍历配置中的所有 HTTP 路径
|
||||||
for _, httpPath := range c.Paths {
|
for _, httpPath := range c.Paths {
|
||||||
// 检查路径配置中是否指定了根目录
|
// 检查路径配置中是否指定了根目录
|
||||||
|
|
Loading…
Reference in New Issue