gohttp/handler/proxy.go

152 lines
4.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
	"fmt"
	"net/http"
	"net/http/httputil"
	"strconv"
	"strings"
	"sync"
	"time"

	"git.pyer.club/kingecg/gohttpd/healthcheck"
	"git.pyer.club/kingecg/gohttpd/model"
	"git.pyer.club/kingecg/gologger"
)
// ProxyHandler reverse-proxies requests to one of a fixed set of upstream
// servers. Clients without a usable sticky-session cookie "s" are assigned an
// upstream in round-robin order; clients presenting one are routed to the
// upstream whose index it holds. When a health checker is configured,
// unhealthy upstreams are skipped.
type ProxyHandler struct {
	proxy     []*httputil.ReverseProxy // one proxy per entry of Upstreams, same order
	Upstreams []string
	count     int                        // next round-robin index; guarded by mu
	mu        sync.Mutex                 // ServeHTTP runs concurrently, so count needs a lock
	checker   *healthcheck.HealthChecker // health checker; nil disables health checks
}

// ServeHTTP picks an upstream for r and forwards the request to it.
//
// The starting upstream comes from the sticky-session cookie "s" when it holds
// a valid upstream index, otherwise from the round-robin counter. Starting
// there, each upstream is considered at most once and the request is sent to
// the first one the health checker reports healthy (every upstream counts as
// healthy when no checker is configured). If none is healthy, or no upstreams
// exist, the client receives 503 Service Unavailable.
func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	l := gologger.GetLogger("Proxy")
	originalURL := r.Host + r.URL.String()

	n := len(p.proxy)
	if n == 0 {
		// Guard a zero-value or misconfigured handler instead of panicking
		// on the modulo/index operations below.
		l.Error("no upstream configured")
		http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
		return
	}

	proxyIndex, ok := p.stickyIndex(r)
	if !ok {
		proxyIndex = p.nextIndex()
	}

	// Try every upstream once, beginning at proxyIndex, so a fail-over can
	// reach any healthy server regardless of how many are configured.
	for i := 0; i < n; i++ {
		if p.checker == nil || p.checker.CheckHealth(p.Upstreams[proxyIndex]) {
			l.Info(fmt.Sprintf("proxy %s to %s", originalURL, p.Upstreams[proxyIndex]))
			p.proxy[proxyIndex].ServeHTTP(w, r)
			return
		}
		proxyIndex = (proxyIndex + 1) % n
	}
	l.Error("All upstream servers are unhealthy")
	http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
}

// stickyIndex returns the upstream index carried by the request's "s" cookie.
// ok is false when the cookie is missing, is not an integer, or is out of
// range. The value is client-controlled and must never index p.proxy unchecked.
func (p *ProxyHandler) stickyIndex(r *http.Request) (idx int, ok bool) {
	c, err := r.Cookie("s")
	if err != nil {
		return 0, false
	}
	idx, err = strconv.Atoi(c.Value)
	if err != nil || idx < 0 || idx >= len(p.proxy) {
		return 0, false
	}
	return idx, true
}

// nextIndex returns the current round-robin index and advances the counter,
// wrapping at len(p.proxy). Callers must ensure len(p.proxy) > 0.
func (p *ProxyHandler) nextIndex() int {
	p.mu.Lock()
	defer p.mu.Unlock()
	idx := p.count % len(p.proxy)
	p.count = (idx + 1) % len(p.proxy)
	return idx
}
// makeProxy builds the httputil.ReverseProxy used for a single upstream.
//
// For every entry in path.Directives, a leading "Proxy_" is stripped and the
// first "$target" placeholder is replaced with upstream. The result is passed
// to GetUpdaterFn. Each non-nil request updater obtained this way is run, in
// directive order, by the proxy's Director. Directives for which GetUpdaterFn
// returns nil are skipped.
//
// On the response side, a sticky-session cookie "s" carrying index is added
// unless the upstream response already sets a cookie named "s".
//
// Parameters:
//   - upstream: upstream server address, substituted for "$target"
//   - path: HTTP path configuration supplying the directives
//   - index: position of upstream in the path's upstream list
//
// Returns the configured *httputil.ReverseProxy.
func makeProxy(upstream string, path *model.HttpPath, index int) *httputil.ReverseProxy {
	var updaters []func(*http.Request)
	for _, raw := range path.Directives {
		expanded := strings.Replace(strings.TrimPrefix(raw, "Proxy_"), "$target", upstream, 1)
		if fn := GetUpdaterFn(expanded); fn != nil {
			updaters = append(updaters, fn)
		}
	}

	stickyValue := strconv.Itoa(index)
	return &httputil.ReverseProxy{
		Director: func(req *http.Request) {
			for _, update := range updaters {
				update(req)
			}
		},
		ModifyResponse: func(resp *http.Response) error {
			// Respect a sticky cookie the upstream chose to set itself.
			for _, existing := range resp.Cookies() {
				if existing.Name == "s" {
					return nil
				}
			}
			sticky := http.Cookie{Name: "s", Value: stickyValue}
			resp.Header.Add("Set-Cookie", sticky.String())
			return nil
		},
	}
}
// NewProxyHandler creates a ProxyHandler for the HTTP path p, with one reverse
// proxy per configured upstream.
//
// When p lists more than one upstream and has a HealthCheck section, a health
// checker is created for all upstreams and started, and upstream state changes
// are logged. Interval and Timeout are parsed with time.ParseDuration. An empty
// value falls back to 10s and 5s respectively. Unparsable or non-positive
// values also fall back, with a warning logged. Retries falls back to 3 when
// it is not positive.
//
// It panics if p defines no upstreams.
func NewProxyHandler(p *model.HttpPath) *ProxyHandler {
	upstreamCount := len(p.Upstreams)
	if upstreamCount == 0 {
		panic("no upstream defined")
	}
	ph := &ProxyHandler{
		Upstreams: p.Upstreams,
		proxy:     make([]*httputil.ReverseProxy, upstreamCount),
	}
	for index, upstream := range p.Upstreams {
		ph.proxy[index] = makeProxy(upstream, p, index)
	}

	// Health checking is only set up when there is more than one upstream,
	// i.e. when there is something to fail over to.
	if upstreamCount < 2 || p.HealthCheck == nil {
		return ph
	}

	// durationOr parses a configured duration, returning def when value is
	// empty. Unparsable or non-positive values are logged and replaced by def,
	// so a misconfiguration is visible instead of silently ignored.
	durationOr := func(field, value string, def time.Duration) time.Duration {
		if value == "" {
			return def
		}
		d, err := time.ParseDuration(value)
		if err != nil || d <= 0 {
			gologger.GetLogger("Proxy").Warn(fmt.Sprintf(
				"invalid health check %s %q, using default %s", field, value, def))
			return def
		}
		return d
	}
	interval := durationOr("interval", p.HealthCheck.Interval, 10*time.Second)
	timeout := durationOr("timeout", p.HealthCheck.Timeout, 5*time.Second)
	retries := 3
	if p.HealthCheck.Retries > 0 {
		retries = p.HealthCheck.Retries
	}

	ph.checker = healthcheck.NewHealthChecker(interval, timeout, retries)
	ph.checker.StartHealthCheck(ph.Upstreams, func(upstream string, healthy bool) {
		// Callback for upstream health state changes.
		logger := gologger.GetLogger("Proxy")
		if healthy {
			logger.Info(fmt.Sprintf("Upstream %s is now healthy", upstream))
		} else {
			logger.Warn(fmt.Sprintf("Upstream %s is now unhealthy", upstream))
		}
	})
	return ph
}