tranfer websocket

addon-dailer
lqqyt2423 4 years ago
parent 3cc022c365
commit 5b9e246780

@ -1,10 +1,13 @@
package proxy package proxy
import ( import (
"bufio"
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
"os" "os"
"strings"
"time" "time"
mock_conn "github.com/jordwest/mock-conn" mock_conn "github.com/jordwest/mock-conn"
@ -37,12 +40,14 @@ type ioRes struct {
type conn struct { type conn struct {
*mock_conn.End *mock_conn.End
Host string // remote host
readErrChan chan error // Read 方法提前返回时的错误 readErrChan chan error // Read 方法提前返回时的错误
} }
func newConn(end *mock_conn.End) *conn { func newConn(end *mock_conn.End, host string) *conn {
return &conn{ return &conn{
End: end, End: end,
Host: host,
readErrChan: make(chan error), readErrChan: make(chan error),
} }
} }
@ -84,7 +89,9 @@ func (c *conn) Read(data []byte) (int, error) {
} }
func (c *conn) SetDeadline(t time.Time) error { func (c *conn) SetDeadline(t time.Time) error {
log.Warnf("SetDeadline %v\n", t) if !t.Equal(time.Time{}) {
log.WithField("host", c.Host).Warnf("SetDeadline %v\n", t)
}
return nil return nil
} }
@ -103,10 +110,28 @@ func (c *conn) SetReadDeadline(t time.Time) error {
} }
func (c *conn) SetWriteDeadline(t time.Time) error { func (c *conn) SetWriteDeadline(t time.Time) error {
log.Warnf("SetWriteDeadline %v\n", t) log.WithField("host", c.Host).Warnf("SetWriteDeadline %v\n", t)
return nil return nil
} }
// wrap conn for peek
type connBuf struct {
*conn
r *bufio.Reader
}
func newConnBuf(c *conn) *connBuf {
return &connBuf{c, bufio.NewReader(c)}
}
func (b *connBuf) Peek(n int) ([]byte, error) {
return b.r.Peek(n)
}
func (b *connBuf) Read(data []byte) (int, error) {
return b.r.Read(data)
}
type MitmMemory struct { type MitmMemory struct {
Proxy *Proxy Proxy *Proxy
CA *cert.CA CA *cert.CA
@ -153,12 +178,85 @@ func (m *MitmMemory) Start() error {
} }
func (m *MitmMemory) Dial(host string) (net.Conn, error) { func (m *MitmMemory) Dial(host string) (net.Conn, error) {
log := log.WithField("in", "MitmMemory.Dial").WithField("host", host)
pipes := mock_conn.NewConn() pipes := mock_conn.NewConn()
m.Listener.(*listener).connChan <- newConn(pipes.Server)
return newConn(pipes.Client), nil // 如果是 tls 流量,则进入 listener.Accept => MitmMemory.ServeHTTP
// 否则很可能是 ws 流量,直接转发流量
go func() {
conn := newConn(pipes.Server, host)
connb := newConnBuf(conn)
buf, err := connb.Peek(3)
if err != nil {
log.Errorf("Peek error: %v\n", err)
connb.Close()
return
}
// tls
if buf[0] == 0x16 && buf[1] == 0x03 && (buf[2] >= 0x0 || buf[2] <= 0x03) {
m.Listener.(*listener).connChan <- connb
} else {
// websocket ws://
log.Debug("begin websocket ws://")
defer connb.Close()
remoteConn, err := net.Dial("tcp", host)
if err != nil {
if !ignoreErr(log, err) {
log.Error(err)
}
return
}
defer remoteConn.Close()
transfer(log, connb, remoteConn)
}
}()
return newConn(pipes.Client, host), nil
} }
func (m *MitmMemory) ServeHTTP(res http.ResponseWriter, req *http.Request) { func (m *MitmMemory) ServeHTTP(res http.ResponseWriter, req *http.Request) {
log := log.WithField("in", "MitmMemory.ServeHTTP").WithField("host", req.Host)
// websocket wss://
if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
log.Debug("begin websocket wss://")
upgradeBuf, err := httputil.DumpRequest(req, false)
if err != nil {
log.Errorf("DumpRequest: %v\n", err)
res.WriteHeader(502)
return
}
cconn, _, err := res.(http.Hijacker).Hijack()
if err != nil {
log.Errorf("Hijack: %v\n", err)
res.WriteHeader(502)
return
}
defer cconn.Close()
host := req.Host
if !strings.Contains(host, ":") {
host = host + ":443"
}
conn, err := tls.Dial("tcp", host, nil)
if err != nil {
log.Errorf("tls.Dial: %v\n", err)
return
}
defer conn.Close()
_, err = conn.Write(upgradeBuf)
if err != nil {
log.Errorf("wss upgrade: %v\n", err)
return
}
transfer(log, conn, cconn)
return
}
if req.URL.Scheme == "" { if req.URL.Scheme == "" {
req.URL.Scheme = "https" req.URL.Scheme = "https"
} }

@ -27,7 +27,7 @@ var ignoreErr = func(log *_log.Entry, err error) bool {
for _, str := range strs { for _, str := range strs {
if strings.Contains(errs, str) { if strings.Contains(errs, str) {
log.Debug(str) log.Debug(err)
return true return true
} }
} }
@ -35,6 +35,23 @@ var ignoreErr = func(log *_log.Entry, err error) bool {
return false return false
} }
func transfer(log *_log.Entry, a, b io.ReadWriter) {
done := make(chan struct{})
go func() {
_, err := io.Copy(a, b)
if err != nil && !ignoreErr(log, err) {
log.Error(err)
}
close(done)
}()
_, err := io.Copy(b, a)
if err != nil && !ignoreErr(log, err) {
log.Error(err)
}
<-done
}
type Options struct { type Options struct {
Addr string Addr string
} }
@ -70,11 +87,13 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
} }
log := log.WithFields(_log.Fields{ log := log.WithFields(_log.Fields{
"in": "ServeHTTP", "in": "Proxy.ServeHTTP",
"url": req.URL, "url": req.URL,
"method": req.Method, "method": req.Method,
}) })
log.Debug("receive request")
if !req.URL.IsAbs() || req.URL.Host == "" { if !req.URL.IsAbs() || req.URL.Host == "" {
res.WriteHeader(400) res.WriteHeader(400)
_, err := io.WriteString(res, "此为代理服务器,不能直接发起请求") _, err := io.WriteString(res, "此为代理服务器,不能直接发起请求")
@ -125,11 +144,11 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) {
log := log.WithFields(_log.Fields{ log := log.WithFields(_log.Fields{
"in": "handleConnect", "in": "Proxy.handleConnect",
"host": req.Host, "host": req.Host,
}) })
log.Debug("CONNECT") log.Debug("receive connect")
conn, err := proxy.Mitm.Dial(req.Host) conn, err := proxy.Mitm.Dial(req.Host)
@ -154,21 +173,7 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) {
return return
} }
done := make(chan struct{}) transfer(log, conn, cconn)
go func() {
_, err := io.Copy(conn, cconn)
if err != nil && !ignoreErr(log, err) {
log.Error(err)
}
close(done)
}()
_, err = io.Copy(cconn, conn)
if err != nil && !ignoreErr(log, err) {
log.Error(err)
}
<-done
} }
func NewProxy(opts *Options) (*Proxy, error) { func NewProxy(opts *Options) (*Proxy, error) {

Loading…
Cancel
Save