optimize middle.Dial

addon-dailer
lqqyt2423 2 years ago
parent 7f55903797
commit 637f752a3d

@ -49,6 +49,7 @@ type ConnContext struct {
ServerConn *ServerConn ServerConn *ServerConn
proxy *Proxy proxy *Proxy
pipeConn *pipeConn
} }
func newConnContext(c net.Conn, proxy *Proxy) *ConnContext { func newConnContext(c net.Conn, proxy *Proxy) *ConnContext {
@ -105,36 +106,38 @@ func (connCtx *ConnContext) initHttpServerConn() {
connCtx.ServerConn = serverConn connCtx.ServerConn = serverConn
} }
func (connCtx *ConnContext) initHttpsServerConn() { func (connCtx *ConnContext) initServerTcpConn() error {
if connCtx.ServerConn != nil { log.Debugln("in initServerTcpConn")
return
}
if !connCtx.ClientConn.Tls {
return
}
ServerConn := newServerConn() ServerConn := newServerConn()
ServerConn.client = &http.Client{ connCtx.ServerConn = ServerConn
Transport: &http.Transport{ ServerConn.Address = connCtx.pipeConn.host
Proxy: http.ProxyFromEnvironment,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { plainConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", ServerConn.Address)
log.Debugln("in https DialTLSContext")
plainConn, err := (&net.Dialer{}).DialContext(ctx, network, addr)
if err != nil { if err != nil {
return nil, err return err
} }
cw := &wrapServerConn{ ServerConn.Conn = &wrapServerConn{
Conn: plainConn, Conn: plainConn,
proxy: connCtx.proxy, proxy: connCtx.proxy,
connCtx: connCtx, connCtx: connCtx,
} }
ServerConn.Conn = cw
ServerConn.Address = addr
for _, addon := range connCtx.proxy.Addons { for _, addon := range connCtx.proxy.Addons {
addon.ServerConnected(connCtx) addon.ServerConnected(connCtx)
} }
return nil
}
func (connCtx *ConnContext) initHttpsServerConn() {
if !connCtx.ClientConn.Tls {
return
}
connCtx.ServerConn.client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugln("in https DialTLSContext")
firstTLSHost, _, err := net.SplitHostPort(addr) firstTLSHost, _, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -144,7 +147,7 @@ func (connCtx *ConnContext) initHttpsServerConn() {
KeyLogWriter: getTlsKeyLogWriter(), KeyLogWriter: getTlsKeyLogWriter(),
ServerName: firstTLSHost, ServerName: firstTLSHost,
} }
tlsConn := tls.Client(cw, cfg) tlsConn := tls.Client(connCtx.ServerConn.Conn, cfg)
return tlsConn, nil return tlsConn, nil
}, },
ForceAttemptHTTP2: false, // disable http2 ForceAttemptHTTP2: false, // disable http2
@ -155,7 +158,6 @@ func (connCtx *ConnContext) initHttpsServerConn() {
return http.ErrUseLastResponse return http.ErrUseLastResponse
}, },
} }
connCtx.ServerConn = ServerConn
} }
// wrap tcpConn for remote client // wrap tcpConn for remote client

@ -44,19 +44,22 @@ func (a *pipeAddr) String() string { return a.remoteAddr }
type pipeConn struct { type pipeConn struct {
net.Conn net.Conn
r *bufio.Reader r *bufio.Reader
host string host string // server host:port
remoteAddr string remoteAddr string // client ip:port
connContext *ConnContext connContext *ConnContext
} }
func newPipeConn(c net.Conn, req *http.Request) *pipeConn { func newPipeConn(c net.Conn, req *http.Request) *pipeConn {
return &pipeConn{ connContext := req.Context().Value(connContextKey).(*ConnContext)
pipeConn := &pipeConn{
Conn: c, Conn: c,
r: bufio.NewReader(c), r: bufio.NewReader(c),
host: req.Host, host: req.Host,
remoteAddr: req.RemoteAddr, remoteAddr: req.RemoteAddr,
connContext: req.Context().Value(connContextKey).(*ConnContext), connContext: connContext,
} }
connContext.pipeConn = pipeConn
return pipeConn
} }
func (c *pipeConn) Peek(n int) ([]byte, error) { func (c *pipeConn) Peek(n int) ([]byte, error) {
@ -116,9 +119,9 @@ func newMiddle(proxy *Proxy) (interceptor, error) {
}, },
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Debugf("middle GetCertificate ServerName: %v\n", chi.ServerName) log.Debugf("middle GetCertificate ServerName: %v\n", clientHello.ServerName)
return ca.GetCert(chi.ServerName) return ca.GetCert(clientHello.ServerName)
}, },
}, },
} }
@ -130,9 +133,14 @@ func (m *middle) Start() error {
return m.server.ServeTLS(m.listener, "", "") return m.server.ServeTLS(m.listener, "", "")
} }
// todo: should block until ServerConnected
func (m *middle) Dial(req *http.Request) (net.Conn, error) { func (m *middle) Dial(req *http.Request) (net.Conn, error) {
pipeClientConn, pipeServerConn := newPipes(req) pipeClientConn, pipeServerConn := newPipes(req)
err := pipeServerConn.connContext.initServerTcpConn()
if err != nil {
pipeClientConn.Close()
pipeServerConn.Close()
return nil, err
}
go m.intercept(pipeServerConn) go m.intercept(pipeServerConn)
return pipeClientConn, nil return pipeClientConn, nil
} }

Loading…
Cancel
Save