gohttp/server/middleware.go

226 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
	"container/list"
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"path"
	"reflect"
	"strings"

	"git.pyer.club/kingecg/gohttpd/model"
	"git.pyer.club/kingecg/gologger"
	"github.com/golang-jwt/jwt/v5"
)
// Middleware is one step of a request-processing chain. It receives the
// response writer, the request, and the handler that follows it; calling
// next.ServeHTTP continues the chain, returning without calling it ends the
// request at this step.
type Middleware func(rw http.ResponseWriter, req *http.Request, next http.Handler)

// MiddlewareLink is an ordered chain of Middleware values stored in an
// embedded container/list. Use NewMiddlewareLink to create one.
type MiddlewareLink struct {
	*list.List
}

// IsEqualMiddleware reports whether a and b refer to the same function.
// Two nil values are equal and a nil value never equals a non-nil one.
// Non-nil values are compared by the code pointer reflect reports for them;
// reflect documents that pointer as not necessarily identifying a single
// function uniquely, so closures of one function literal compare equal.
func IsEqualMiddleware(a, b Middleware) bool {
	switch {
	case a == nil || b == nil:
		return a == nil && b == nil
	default:
		return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer()
	}
}
// Add inserts m into the chain just in front of the chain's last element,
// so middleware runs in the order it was added and the Done pass-through
// stays at the tail. A nil m is ignored, and adding Done itself is a no-op
// once the chain has been seeded.
//
// A zero-value MiddlewareLink (nil List), or one wrapped around an empty
// list directly instead of via NewMiddlewareLink, is first seeded with Done.
// Without that, the first middleware added would become the tail and every
// later one would be inserted in front of it, reversing execution order.
func (ml *MiddlewareLink) Add(m Middleware) {
	if m == nil {
		return
	}
	if ml.List == nil {
		ml.List = list.New()
	}
	if ml.Len() == 0 {
		// Store Done as a Middleware (not its bare func type) so that
		// WrapHandler's type assertion recognises it.
		ml.PushBack(Middleware(Done))
	}
	if IsEqualMiddleware(m, Done) {
		return
	}
	ml.InsertBefore(m, ml.Back())
}
// func (ml *MiddlewareLink) ServeHTTP(w http.ResponseWriter, r *http.Request) bool {
// canContinue := true
// next := func() {
// canContinue = true
// }
// for e := ml.Front(); e != nil && canContinue; e = e.Next() {
// canContinue = false
// e.Value.(Middleware)(w, r, next)
// if !canContinue {
// break
// }
// }
// return canContinue
// }
// wrap adapts a single Middleware into an http.Handler: serving a request
// through the returned handler invokes m with next as its continuation.
func (ml *MiddlewareLink) wrap(m Middleware, next http.Handler) http.Handler {
	var h http.HandlerFunc = func(rw http.ResponseWriter, req *http.Request) {
		m(rw, req, next)
	}
	return h
}
// WrapHandler composes the chain around next and returns the resulting
// handler. The element at the front of the list runs first. Each middleware
// receives, as its next, the handler built from everything behind it, and
// the last element receives next itself. A nil or empty chain returns next
// unchanged.
//
// An element may be a Middleware or a plain function with the same
// signature; the latter is what a direct ml.PushBack(SomeFunc) stores,
// because PushBack takes an untyped value. Elements of any other type, and
// nil functions, are skipped rather than ending the walk. Ending the walk
// would silently drop every middleware in front of the bad entry, such as an
// authentication step.
func (ml *MiddlewareLink) WrapHandler(next http.Handler) http.Handler {
	if ml == nil || ml.List == nil {
		return next
	}
	handler := next
	// Walk back-to-front so the front element ends up outermost.
	for e := ml.Back(); e != nil; e = e.Prev() {
		var m Middleware
		switch v := e.Value.(type) {
		case Middleware:
			m = v
		case func(http.ResponseWriter, *http.Request, http.Handler):
			m = v
		}
		if m == nil {
			continue
		}
		handler = ml.wrap(m, handler)
	}
	return handler
}
// NewMiddlewareLink creates a chain backed by a fresh list and seeds it with
// the Done pass-through; middleware added later through Add is placed in
// front of it.
func NewMiddlewareLink() *MiddlewareLink {
	link := &MiddlewareLink{List: list.New()}
	link.Add(Done)
	return link
}
// BasicAuth is a Middleware that guards requests with HTTP Basic
// authentication against the credentials in the Admin section of
// model.GetConfig(). If either the configured username or password is
// empty, authentication is disabled and the request goes straight to next.
// Otherwise a request without matching credentials gets a
// `WWW-Authenticate: Basic realm="Restricted"` challenge and a 401.
func BasicAuth(w http.ResponseWriter, r *http.Request, next http.Handler) {
	config := model.GetConfig()
	wantUser, wantPass := config.Admin.Username, config.Admin.Password
	if wantUser == "" || wantPass == "" {
		next.ServeHTTP(w, r)
		return
	}
	user, pass, ok := r.BasicAuth()
	if ok {
		// Compare in constant time so response timing does not reveal how
		// much of a guess was correct. Both comparisons always run (bitwise
		// &, not &&), so timing also does not reveal which field was wrong.
		// ConstantTimeCompare still returns early on a length mismatch, so
		// credential lengths are not hidden.
		userOK := subtle.ConstantTimeCompare([]byte(user), []byte(wantUser))
		passOK := subtle.ConstantTimeCompare([]byte(pass), []byte(wantPass))
		if userOK&passOK == 1 {
			next.ServeHTTP(w, r)
			return
		}
	}
	w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
	http.Error(w, "Unauthorized.", http.StatusUnauthorized)
}
// JwtAuth is a Middleware that requires an HMAC-signed JWT, read from the
// "auth_token" cookie. It passes every request through unchanged when
// config.Admin.Jwt.Secret is empty. It also always passes through requests
// whose last path segment is "login", so that a token can be obtained.
//
// On success, the parsed *jwt.RegisteredClaims are stored in the request
// context under the key "user" and the request continues to next. Any
// failure answers 401 "Unauthorized." and stops the chain: a missing cookie,
// a token that fails to parse or validate, a non-HMAC signing method, or
// claims that are not valid.
//
// NOTE(review): path.Base matches ANY path ending in "/login" (including
// "/login/"), not just the login endpoint. Confirm that routing makes this
// safe.
// NOTE(review): the context key is a plain string (staticcheck SA1029).
// Changing it would break readers of ctx.Value("user") elsewhere.
func JwtAuth(w http.ResponseWriter, r *http.Request, next http.Handler) {
	logger := gologger.GetLogger("JwtAuth")
	jwtConfig := model.GetConfig().Admin.Jwt

	unauthorized := func() {
		http.Error(w, "Unauthorized.", http.StatusUnauthorized)
	}

	if jwtConfig.Secret == "" || path.Base(r.URL.Path) == "login" {
		next.ServeHTTP(w, r)
		return
	}

	// The token travels in a cookie rather than an Authorization header.
	cookie, err := r.Cookie("auth_token")
	if err != nil {
		unauthorized()
		return
	}

	keyFunc := func(tok *jwt.Token) (interface{}, error) {
		// Only HMAC-family signing methods are accepted; anything else is
		// rejected before the secret is handed out.
		if _, isHMAC := tok.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method: %v", tok.Header["alg"])
		}
		return []byte(jwtConfig.Secret), nil
	}
	token, err := jwt.ParseWithClaims(cookie.Value, &jwt.RegisteredClaims{}, keyFunc)
	if err != nil {
		logger.Error("Failed to parse JWT: %v", err)
		unauthorized()
		return
	}

	claims, isRegistered := token.Claims.(*jwt.RegisteredClaims)
	if !isRegistered || !token.Valid {
		unauthorized()
		return
	}
	// Authenticated: expose the claims to downstream handlers.
	ctx := context.WithValue(r.Context(), "user", claims)
	next.ServeHTTP(w, r.WithContext(ctx))
}
// RecordAccess is a Middleware that passes the request's r.Host (which may
// include a port) to model.Incr and then continues the chain.
// NOTE(review): presumably model.Incr maintains per-host access counters;
// confirm in the model package.
func RecordAccess(w http.ResponseWriter, r *http.Request, next http.Handler) {
	host := r.Host
	model.Incr(host)
	next.ServeHTTP(w, r)
}
// Parse is a Middleware that pre-parses the body of POST and PUT requests
// according to their Content-Type, then hands the request to next:
//
//   - application/x-www-form-urlencoded: r.ParseForm
//   - multipart/form-data: r.ParseMultipartForm
//   - any Content-Type containing "application/json": the body is decoded
//     into a value of type T. The value is stored under the "data" key of
//     the map[string]interface{} found in the request context at
//     RequestCtxKey("data"), if such a map is present.
//
// A parse or decode error is answered with 400 Bad Request and the chain
// stops. Other methods and content types pass through untouched.
func Parse[T any](w http.ResponseWriter, r *http.Request, next http.Handler) {
	// Same in-memory budget net/http uses when multipart parsing is
	// triggered implicitly (e.g. by FormValue); larger parts spill to
	// temporary files.
	const maxMultipartMemory = 32 << 20

	if r.Method != http.MethodPost && r.Method != http.MethodPut {
		next.ServeHTTP(w, r)
		return
	}

	contentType := r.Header.Get("Content-Type")
	// Match on the bare, lower-cased media type. Clients append parameters
	// such as "; charset=UTF-8", and multipart always carries
	// "; boundary=...", so comparing the whole header never matched them.
	mediaType, _, _ := strings.Cut(contentType, ";")
	mediaType = strings.ToLower(strings.TrimSpace(mediaType))

	switch {
	case mediaType == "application/x-www-form-urlencoded":
		if err := r.ParseForm(); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
	case mediaType == "multipart/form-data":
		// Use a fixed budget rather than r.ContentLength. ContentLength is -1
		// when the length is unknown and is otherwise chosen by the client.
		if err := r.ParseMultipartForm(maxMultipartMemory); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
	case strings.Contains(contentType, "application/json"):
		var t T
		if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Comma-ok: a plain type assertion panics when the context carries
		// no map under this key; a nil map would panic on write.
		if m, ok := r.Context().Value(RequestCtxKey("data")).(map[string]interface{}); ok && m != nil {
			m["data"] = t
		}
	}
	next.ServeHTTP(w, r)
}
// IPAccessControl is a Middleware that enforces the per-server AllowIPs /
// DenyIPs lists returned by model.GetServerConfig. The server name is read
// from the request context under RequestCtxKey("serverName"); it is "" when
// that value is absent or not a string. When no server config is found, the
// request is let through.
//
// The client address is the host part of r.RemoteAddr. Deny entries take
// precedence over allow entries. A non-empty allow list admits only the
// addresses it names. Rejected requests get 403 "Forbidden".
//
// NOTE(review): r.RemoteAddr is the immediate peer; behind a reverse proxy
// it is the proxy's address, not the end client's.
func IPAccessControl(w http.ResponseWriter, r *http.Request, next http.Handler) {
	serverName, _ := r.Context().Value(RequestCtxKey("serverName")).(string)

	config := model.GetServerConfig(serverName)
	if config != nil {
		// net.SplitHostPort understands bracketed IPv6 ("[::1]:8080").
		// Splitting on ":" yielded "[" for every IPv6 client, so IPv6 list
		// entries could never match. Without a port, the address is used
		// as-is.
		clientIP := r.RemoteAddr
		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
			clientIP = host
		}

		// Deny rules are checked first and win over allow rules.
		for _, ip := range config.DenyIPs {
			if ip == clientIP {
				http.Error(w, "Forbidden", http.StatusForbidden)
				return
			}
		}

		if len(config.AllowIPs) > 0 {
			allowed := false
			for _, ip := range config.AllowIPs {
				if ip == clientIP {
					allowed = true
					break
				}
			}
			if !allowed {
				http.Error(w, "Forbidden", http.StatusForbidden)
				return
			}
		}
	}
	next.ServeHTTP(w, r)
}
// Done is a pass-through Middleware that forwards the request to next
// unchanged. NewMiddlewareLink seeds every chain with it and Add keeps it as
// the last element, so a chain always ends by delegating to the handler it
// wraps.
func Done(rw http.ResponseWriter, req *http.Request, next http.Handler) {
	next.ServeHTTP(rw, req)
}