diff --git a/server/directive.go b/server/directive.go index 34f7f7a..cbe4ff2 100644 --- a/server/directive.go +++ b/server/directive.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net/http" "strings" @@ -47,10 +48,15 @@ var Gzip_Response Directive = func(args ...string) Middleware { } } var DRecordAccess Directive = func(args ...string) Middleware { + serverName := args[0] return func(w http.ResponseWriter, r *http.Request, next http.Handler) { l := gologger.GetLogger("Directive") l.Debug("Record-Access") model.Incr(r.URL.Host) + // put serverName to request context + ctx := r.Context() + ctx = context.WithValue(ctx, RequestCtxKey("serverName"), serverName) + r = r.WithContext(ctx) next.ServeHTTP(w, r) } } @@ -60,6 +66,8 @@ var JWTDirective Directive = func(args ...string) Middleware { var BasicAuthDirective Directive = func(args ...string) Middleware { return BasicAuth } + +// 在DirectiveMap中注册新指令 var DirectiveMap = map[string]Directive{ "Set-Header": Set_Header, "Add-Header": Add_Header, diff --git a/server/middleware.go b/server/middleware.go index c36d14a..5895e5c 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -178,6 +178,48 @@ func Parse[T any](w http.ResponseWriter, r *http.Request, next http.Handler) { next.ServeHTTP(w, r) } +// IPAccessControl 中间件实现IP访问控制 +func IPAccessControl(w http.ResponseWriter, r *http.Request, next http.Handler) { + + // get serverName from request context + ctx := r.Context() + serverName, ok := ctx.Value(RequestCtxKey("serverName")).(string) + if !ok { + serverName = "" + } + config := model.GetServerConfig(serverName) + if config != nil { + allowedIPs := config.AllowIPs + deniedIPs := config.DenyIPs + clientIP := strings.Split(r.RemoteAddr, ":")[0] // 获取客户端IP + + // 首先检查是否被禁止(Deny规则优先级高于Allow) + for _, ip := range deniedIPs { + if ip == clientIP { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + } + + // 然后检查是否被允许 + if len(allowedIPs) > 0 { + allowed := false + for _, ip := range allowedIPs { + if ip == clientIP { + allowed = true + break + } + } + if !allowed { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + } + } + + next.ServeHTTP(w, r) +} + func Done(w http.ResponseWriter, r *http.Request, next http.Handler) { next.ServeHTTP(w, r) } diff --git a/server/server.go b/server/server.go index e47899c..b6d30cf 100644 --- a/server/server.go +++ b/server/server.go @@ -308,18 +308,22 @@ func NewServeMux(c *model.HttpServerConfig) *ServerMux { wrappedHandler: make(map[string]http.Handler), } - s.AddDirective("Record-Access") - if c.AuthType == "jwt" { - s.AddDirective("Jwt-Auth") - } - if c.AuthType == "basic" { - s.AddDirective("Basic-Auth") + s.AddDirective("Record-Access " + c.Name) + // 添加IP访问控制中间件 + if len(c.AllowIPs) > 0 || len(c.DenyIPs) > 0 { + s.directiveHandlers.Add(IPAccessControl) } // 遍历配置中的所有指令 for _, directive := range c.Directives { // 将指令添加到 ServerMux 的指令处理中间件链中 s.AddDirective(string(directive)) } + if c.AuthType == "jwt" { + s.AddDirective("Jwt-Auth") + } + if c.AuthType == "basic" { + s.AddDirective("Basic-Auth") + } // 遍历配置中的所有 HTTP 路径 for _, httpPath := range c.Paths { // 检查路径配置中是否指定了根目录