https add server hook

addon-dailer
lqqyt2423 2 years ago
parent 46f165105d
commit 2e0c62a08b

@ -23,7 +23,7 @@ func NewConnContext(c net.Conn) *ConnContext {
} }
} }
func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenServerConnected func()) { func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenConnected func()) {
if connCtx.Server != nil { if connCtx.Server != nil {
return return
} }
@ -48,7 +48,7 @@ func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.C
cw := connWrap(c) cw := connWrap(c)
server.Conn = cw server.Conn = cw
defer whenServerConnected() defer whenConnected()
return cw, nil return cw, nil
}, },
ForceAttemptHTTP2: false, // disable http2 ForceAttemptHTTP2: false, // disable http2
@ -67,7 +67,7 @@ func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.C
connCtx.Server = server connCtx.Server = server
} }
func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool) { func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenConnected func()) {
if connCtx.Server != nil { if connCtx.Server != nil {
return return
} }
@ -80,18 +80,33 @@ func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool) {
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
// todo: change here DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
DialContext: (&net.Dialer{ log.Debugln("in https DialTLSContext")
// Timeout: 30 * time.Second,
// KeepAlive: 30 * time.Second, plainConn, err := (&net.Dialer{}).DialContext(ctx, network, addr)
}).DialContext, if err != nil {
return nil, err
}
cw := connWrap(plainConn)
server.Conn = cw
whenConnected()
firstTLSHost, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
cfg := &tls.Config{
InsecureSkipVerify: sslInsecure,
KeyLogWriter: GetTlsKeyLogWriter(),
ServerName: firstTLSHost,
}
tlsConn := tls.Client(cw, cfg)
return tlsConn, nil
},
ForceAttemptHTTP2: false, // disable http2 ForceAttemptHTTP2: false, // disable http2
DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true. DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true.
TLSClientConfig: &tls.Config{
InsecureSkipVerify: sslInsecure,
KeyLogWriter: GetTlsKeyLogWriter(),
},
}, },
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
// 禁止自动重定向 // 禁止自动重定向

@ -32,14 +32,14 @@ func (pipeAddr) Network() string { return "pipe" }
func (a *pipeAddr) String() string { return a.remoteAddr } func (a *pipeAddr) String() string { return a.remoteAddr }
// 建立客户端和服务端通信的通道 // 建立客户端和服务端通信的通道
func newPipes(req *http.Request) (net.Conn, *connBuf) { func newPipes(req *http.Request) (net.Conn, *pipeConn) {
client, srv := net.Pipe() client, srv := net.Pipe()
server := newConnBuf(srv, req) server := newPipeConn(srv, req)
return client, server return client, server
} }
// add Peek method for conn // add Peek method for conn
type connBuf struct { type pipeConn struct {
net.Conn net.Conn
r *bufio.Reader r *bufio.Reader
host string host string
@ -47,8 +47,8 @@ type connBuf struct {
connContext *flow.ConnContext connContext *flow.ConnContext
} }
func newConnBuf(c net.Conn, req *http.Request) *connBuf { func newPipeConn(c net.Conn, req *http.Request) *pipeConn {
return &connBuf{ return &pipeConn{
Conn: c, Conn: c,
r: bufio.NewReader(c), r: bufio.NewReader(c),
host: req.Host, host: req.Host,
@ -57,16 +57,16 @@ func newConnBuf(c net.Conn, req *http.Request) *connBuf {
} }
} }
func (b *connBuf) Peek(n int) ([]byte, error) { func (c *pipeConn) Peek(n int) ([]byte, error) {
return b.r.Peek(n) return c.r.Peek(n)
} }
func (b *connBuf) Read(data []byte) (int, error) { func (c *pipeConn) Read(data []byte) (int, error) {
return b.r.Read(data) return c.r.Read(data)
} }
func (b *connBuf) RemoteAddr() net.Addr { func (c *pipeConn) RemoteAddr() net.Addr {
return &pipeAddr{remoteAddr: b.remoteAddr} return &pipeAddr{remoteAddr: c.remoteAddr}
} }
// Middle: man-in-the-middle // Middle: man-in-the-middle
@ -93,7 +93,7 @@ func NewMiddle(proxy *Proxy, caPath string) (Interceptor, error) {
IdleTimeout: 5 * time.Second, IdleTimeout: 5 * time.Second,
ConnContext: func(ctx context.Context, c net.Conn) context.Context { ConnContext: func(ctx context.Context, c net.Conn) context.Context {
return context.WithValue(ctx, flow.ConnContextKey, c.(*tls.Conn).NetConn().(*connBuf).connContext) return context.WithValue(ctx, flow.ConnContextKey, c.(*tls.Conn).NetConn().(*pipeConn).connContext)
}, },
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
@ -116,9 +116,9 @@ func (m *Middle) Start() error {
} }
func (m *Middle) Dial(req *http.Request) (net.Conn, error) { func (m *Middle) Dial(req *http.Request) (net.Conn, error) {
clientConn, serverConn := newPipes(req) pipeClientConn, pipeServerConn := newPipes(req)
go m.intercept(serverConn) go m.intercept(pipeServerConn)
return clientConn, nil return pipeClientConn, nil
} }
func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) { func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) {
@ -140,24 +140,38 @@ func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) {
// 解析 connect 流量 // 解析 connect 流量
// 如果是 tls 流量,则进入 listener.Accept => Middle.ServeHTTP // 如果是 tls 流量,则进入 listener.Accept => Middle.ServeHTTP
// 否则很可能是 ws 流量 // 否则很可能是 ws 流量
func (m *Middle) intercept(serverConn *connBuf) { func (m *Middle) intercept(pipeServerConn *pipeConn) {
log := log.WithField("in", "Middle.intercept").WithField("host", serverConn.host) log := log.WithField("in", "Middle.intercept").WithField("host", pipeServerConn.host)
buf, err := serverConn.Peek(3) buf, err := pipeServerConn.Peek(3)
if err != nil { if err != nil {
log.Errorf("Peek error: %v\n", err) log.Errorf("Peek error: %v\n", err)
serverConn.Close() pipeServerConn.Close()
return return
} }
// https://github.com/mitmproxy/mitmproxy/blob/main/mitmproxy/net/tls.py is_tls_record_magic // https://github.com/mitmproxy/mitmproxy/blob/main/mitmproxy/net/tls.py is_tls_record_magic
if buf[0] == 0x16 && buf[1] == 0x03 && buf[2] <= 0x03 { if buf[0] == 0x16 && buf[1] == 0x03 && buf[2] <= 0x03 {
// tls // tls
serverConn.connContext.Client.Tls = true pipeServerConn.connContext.Client.Tls = true
serverConn.connContext.InitHttpsServer(m.Proxy.Opts.SslInsecure) pipeServerConn.connContext.InitHttpsServer(
m.Listener.(*middleListener).connChan <- serverConn m.Proxy.Opts.SslInsecure,
func(c net.Conn) net.Conn {
return &serverConn{
Conn: c,
proxy: m.Proxy,
connCtx: pipeServerConn.connContext,
}
},
func() {
for _, addon := range m.Proxy.Addons {
addon.ServerConnected(pipeServerConn.connContext)
}
},
)
m.Listener.(*middleListener).connChan <- pipeServerConn
} else { } else {
// ws // ws
DefaultWebSocket.WS(serverConn, serverConn.host) DefaultWebSocket.WS(pipeServerConn, pipeServerConn.host)
} }
} }

Loading…
Cancel
Save