添加ip控制

This commit is contained in:
程广 2025-05-29 17:39:08 +08:00
parent be64000bff
commit ea38f85fb7
3 changed files with 60 additions and 6 deletions

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"net/http"
"strings"
@ -47,10 +48,15 @@ var Gzip_Response Directive = func(args ...string) Middleware {
}
}
var DRecordAccess Directive = func(args ...string) Middleware {
serverName := args[0]
return func(w http.ResponseWriter, r *http.Request, next http.Handler) {
l := gologger.GetLogger("Directive")
l.Debug("Record-Access")
model.Incr(r.URL.Host)
// put serverName to request context
ctx := r.Context()
ctx = context.WithValue(ctx, RequestCtxKey("serverName"), serverName)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
}
}
@ -60,6 +66,8 @@ var JWTDirective Directive = func(args ...string) Middleware {
var BasicAuthDirective Directive = func(args ...string) Middleware {
return BasicAuth
}
// 在DirectiveMap中注册新指令
var DirectiveMap = map[string]Directive{
"Set-Header": Set_Header,
"Add-Header": Add_Header,

View File

@ -178,6 +178,48 @@ func Parse[T any](w http.ResponseWriter, r *http.Request, next http.Handler) {
next.ServeHTTP(w, r)
}
// IPAccessControl 中间件实现IP访问控制
func IPAccessControl(w http.ResponseWriter, r *http.Request, next http.Handler) {
// get serverName from request context
ctx := r.Context()
serverName, ok := ctx.Value(RequestCtxKey("serverName")).(string)
if !ok {
serverName = ""
}
config := model.GetServerConfig(serverName)
if config != nil {
allowedIPs := config.AllowIPs
deniedIPs := config.DenyIPs
clientIP := strings.Split(r.RemoteAddr, ":")[0] // 获取客户端IP
// 首先检查是否被禁止Deny规则优先级高于Allow
for _, ip := range deniedIPs {
if ip == clientIP {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
}
// 然后检查是否被允许
if len(allowedIPs) > 0 {
allowed := false
for _, ip := range allowedIPs {
if ip == clientIP {
allowed = true
break
}
}
if !allowed {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
}
}
next.ServeHTTP(w, r)
}
func Done(w http.ResponseWriter, r *http.Request, next http.Handler) {
next.ServeHTTP(w, r)
}

View File

@ -308,18 +308,22 @@ func NewServeMux(c *model.HttpServerConfig) *ServerMux {
wrappedHandler: make(map[string]http.Handler),
}
s.AddDirective("Record-Access")
if c.AuthType == "jwt" {
s.AddDirective("Jwt-Auth")
}
if c.AuthType == "basic" {
s.AddDirective("Basic-Auth")
s.AddDirective("Record-Access " + c.Name)
// 添加IP访问控制中间件
if len(c.AllowIPs) > 0 || len(c.DenyIPs) > 0 {
s.directiveHandlers.Add(IPAccessControl)
}
// 遍历配置中的所有指令
for _, directive := range c.Directives {
// 将指令添加到 ServerMux 的指令处理中间件链中
s.AddDirective(string(directive))
}
if c.AuthType == "jwt" {
s.AddDirective("Jwt-Auth")
}
if c.AuthType == "basic" {
s.AddDirective("Basic-Auth")
}
// 遍历配置中的所有 HTTP 路径
for _, httpPath := range c.Paths {
// 检查路径配置中是否指定了根目录