235 lines
5.6 KiB
Go
235 lines
5.6 KiB
Go
package server
|
||
|
||
import (
	"container/list"
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"path"
	"reflect"
	"strings"

	"git.pyer.club/kingecg/gohttpd/model"
	"git.pyer.club/kingecg/gologger"
	"github.com/golang-jwt/jwt/v5"
)
|
||
|
||
// Middleware is one link of a handler chain. It is invoked with the current
// response writer and request plus the next handler in the chain; calling
// next.ServeHTTP continues processing, returning without calling it stops
// the chain at this link.
type Middleware func(rw http.ResponseWriter, req *http.Request, next http.Handler)
|
||
|
||
// MiddlewareLink is an ordered chain of Middleware values kept in a doubly
// linked list. Embedding *list.List promotes the list operations (Len, Back,
// Front, PushBack, InsertBefore, ...) onto the link itself. Use
// NewMiddlewareLink to obtain a usable value; the zero value has a nil list.
type MiddlewareLink struct {
	*list.List
}
|
||
|
||
func IsEqualMiddleware(a, b Middleware) bool {
|
||
if a == nil && b == nil {
|
||
return true
|
||
}
|
||
if a == nil || b == nil {
|
||
return false
|
||
}
|
||
return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer()
|
||
}
|
||
func (ml *MiddlewareLink) Add(m Middleware) {
|
||
if m == nil {
|
||
return
|
||
}
|
||
|
||
if ml.List.Len() == 0 {
|
||
ml.PushBack(m)
|
||
} else {
|
||
if IsEqualMiddleware(m, Done) {
|
||
return
|
||
}
|
||
el := ml.Back()
|
||
ml.InsertBefore(m, el)
|
||
}
|
||
}
|
||
|
||
func (ml *MiddlewareLink) wrap(m Middleware, next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
m(w, r, next)
|
||
})
|
||
}
|
||
func (ml *MiddlewareLink) WrapHandler(next http.Handler) http.Handler {
|
||
if ml.Back() == nil {
|
||
return next
|
||
}
|
||
|
||
var handler http.Handler = next
|
||
|
||
for e := ml.Back(); e != nil; e = e.Prev() {
|
||
|
||
middleware, ok := e.Value.(Middleware)
|
||
if !ok {
|
||
break
|
||
}
|
||
|
||
handler = ml.wrap(middleware, handler)
|
||
}
|
||
return handler
|
||
}
|
||
|
||
func (ml *MiddlewareLink) Clone() *MiddlewareLink {
|
||
ret := NewMiddlewareLink()
|
||
for e := ml.Back(); e != nil; e = e.Prev() {
|
||
middleware, ok := e.Value.(Middleware)
|
||
if !ok {
|
||
break
|
||
}
|
||
ret.Add(middleware)
|
||
}
|
||
return ret
|
||
}
|
||
func NewMiddlewareLink() *MiddlewareLink {
|
||
ml := &MiddlewareLink{list.New()}
|
||
ml.Add(Done)
|
||
return ml
|
||
}
|
||
|
||
func BasicAuth(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||
config := model.GetConfig()
|
||
|
||
if config.Admin.Username == "" || config.Admin.Password == "" {
|
||
next.ServeHTTP(w, r)
|
||
return
|
||
}
|
||
user, pass, ok := r.BasicAuth()
|
||
if ok && user == config.Admin.Username && pass == config.Admin.Password {
|
||
next.ServeHTTP(w, r)
|
||
} else {
|
||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||
http.Error(w, "Unauthorized.", http.StatusUnauthorized)
|
||
}
|
||
}
|
||
func JwtAuth(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||
l := gologger.GetLogger("JwtAuth")
|
||
config := getServerConfig(r)
|
||
if config == nil || config.Jwt == nil {
|
||
http.Error(w, "Jwt config error", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
jwtConfig := config.Jwt
|
||
if jwtConfig.Secret == "" || path.Base(r.URL.Path) == "login" {
|
||
next.ServeHTTP(w, r)
|
||
return
|
||
}
|
||
// 从cookie中获取token
|
||
tokenCookie, err := r.Cookie("auth_token")
|
||
if err != nil {
|
||
http.Error(w, "Unauthorized.", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
tokenString := tokenCookie.Value
|
||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||
// 确保签名方法是正确的
|
||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||
}
|
||
return []byte(jwtConfig.Secret), nil
|
||
})
|
||
if err != nil {
|
||
l.Error("Failed to parse JWT: %v", err)
|
||
http.Error(w, "Unauthorized.", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
if claims, ok := token.Claims.(*jwt.RegisteredClaims); ok && token.Valid {
|
||
// 验证通过,将用户信息存储在请求上下文中
|
||
ctx := context.WithValue(r.Context(), "user", claims)
|
||
next.ServeHTTP(w, r.WithContext(ctx))
|
||
return
|
||
}
|
||
http.Error(w, "Unauthorized.", http.StatusUnauthorized)
|
||
}
|
||
|
||
func RecordAccess(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||
model.Incr(r.Host)
|
||
next.ServeHTTP(w, r)
|
||
}
|
||
func Parse[T any](w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||
|
||
if r.Method == "POST" || r.Method == "PUT" {
|
||
//判断r的content-type是否是application/x-www-form-urlencoded
|
||
if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" {
|
||
r.ParseForm()
|
||
|
||
} else if r.Header.Get("Content-Type") == "multipart/form-data" {
|
||
r.ParseMultipartForm(r.ContentLength)
|
||
|
||
} else {
|
||
// 判断r的content-type是否是application/json
|
||
contentType := r.Header.Get("Content-Type")
|
||
if strings.Contains(contentType, "application/json") {
|
||
var t T
|
||
// 读取json数据
|
||
if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
|
||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||
return
|
||
}
|
||
ctx := r.Context()
|
||
m := ctx.Value(RequestCtxKey("data")).(map[string]interface{})
|
||
if m != nil {
|
||
m["data"] = t
|
||
}
|
||
|
||
}
|
||
}
|
||
|
||
}
|
||
next.ServeHTTP(w, r)
|
||
}
|
||
|
||
func getServerConfig(r *http.Request) *model.HttpServerConfig {
|
||
ctx := r.Context()
|
||
serverName, ok := ctx.Value(RequestCtxKey("serverName")).(string)
|
||
if !ok {
|
||
serverName = ""
|
||
return nil
|
||
}
|
||
config := model.GetServerConfig(serverName)
|
||
return config
|
||
}
|
||
|
||
// IPAccessControl 中间件实现IP访问控制
|
||
func IPAccessControl(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||
|
||
// get serverName from request context
|
||
config := getServerConfig(r)
|
||
if config != nil {
|
||
allowedIPs := config.AllowIPs
|
||
deniedIPs := config.DenyIPs
|
||
clientIP := strings.Split(r.RemoteAddr, ":")[0] // 获取客户端IP
|
||
|
||
// 首先检查是否被禁止(Deny规则优先级高于Allow)
|
||
for _, ip := range deniedIPs {
|
||
if ip == clientIP {
|
||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||
return
|
||
}
|
||
}
|
||
|
||
// 然后检查是否被允许
|
||
if len(allowedIPs) > 0 {
|
||
allowed := false
|
||
for _, ip := range allowedIPs {
|
||
if ip == clientIP {
|
||
allowed = true
|
||
break
|
||
}
|
||
}
|
||
if !allowed {
|
||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||
return
|
||
}
|
||
}
|
||
} else {
|
||
http.Error(w, "Server Config Error", http.StatusInternalServerError)
|
||
}
|
||
|
||
next.ServeHTTP(w, r)
|
||
}
|
||
|
||
// Done is the terminal Middleware that NewMiddlewareLink places at the tail
// of every chain. It performs no work of its own and hands the request
// straight to next.
func Done(w http.ResponseWriter, r *http.Request, next http.Handler) {
	next.ServeHTTP(w, r)
}
|