You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-mitmproxy/proxy/mitm_memory.go

270 lines
5.2 KiB
Go

package proxy
import (
"bufio"
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
"os"
"strings"
"time"
mock_conn "github.com/jordwest/mock-conn"
"github.com/lqqyt2423/go-mitmproxy/cert"
)
// 模拟实现 net
type listener struct {
connChan chan net.Conn
}
func (l *listener) Accept() (net.Conn, error) {
return <-l.connChan, nil
}
func (l *listener) Close() error {
return nil
}
func (l *listener) Addr() net.Addr {
return nil
}
type ioRes struct {
n int
err error
}
type conn struct {
*mock_conn.End
Host string // remote host
readErrChan chan error // Read 方法提前返回时的错误
}
func newConn(end *mock_conn.End, host string) *conn {
return &conn{
End: end,
Host: host,
readErrChan: make(chan error),
}
}
// 当接收到 readErrChan 时,可提前返回
func (c *conn) Read(data []byte) (int, error) {
select {
case err := <-c.readErrChan:
return 0, err
default:
}
resChan := make(chan *ioRes)
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
return
default:
}
n, err := c.End.Read(data)
select {
case resChan <- &ioRes{n, err}:
return
case <-done:
close(resChan)
}
}()
select {
case res := <-resChan:
return res.n, res.err
case err := <-c.readErrChan:
return 0, err
}
}
func (c *conn) SetDeadline(t time.Time) error {
if !t.Equal(time.Time{}) {
log.WithField("host", c.Host).Warnf("SetDeadline %v\n", t)
}
return nil
}
// http server 会在连接快结束时调用此方法
func (c *conn) SetReadDeadline(t time.Time) error {
if !t.Equal(time.Time{}) {
if !t.After(time.Now()) {
// 使当前 Read 尽快返回
c.readErrChan <- os.ErrDeadlineExceeded
} else {
log.Warnf("SetReadDeadline %v\n", t)
}
}
return nil
}
func (c *conn) SetWriteDeadline(t time.Time) error {
log.WithField("host", c.Host).Warnf("SetWriteDeadline %v\n", t)
return nil
}
// wrap conn for peek
type connBuf struct {
*conn
r *bufio.Reader
}
func newConnBuf(c *conn) *connBuf {
return &connBuf{c, bufio.NewReader(c)}
}
func (b *connBuf) Peek(n int) ([]byte, error) {
return b.r.Peek(n)
}
func (b *connBuf) Read(data []byte) (int, error) {
return b.r.Read(data)
}
type MitmMemory struct {
Proxy *Proxy
CA *cert.CA
Listener net.Listener
Server *http.Server
}
func NewMitmMemory(proxy *Proxy) (Mitm, error) {
ca, err := cert.NewCA("")
if err != nil {
return nil, err
}
m := &MitmMemory{
Proxy: proxy,
CA: ca,
}
server := &http.Server{
Handler: m,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
TLSConfig: &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Debugf("MitmMemory GetCertificate ServerName: %v\n", chi.ServerName)
return ca.GetCert(chi.ServerName)
},
},
}
// 每次连接尽快结束,因为连接并无开销
server.SetKeepAlivesEnabled(false)
m.Server = server
return m, nil
}
func (m *MitmMemory) Start() error {
ln := &listener{
connChan: make(chan net.Conn),
}
m.Listener = ln
return m.Server.ServeTLS(ln, "", "")
}
func (m *MitmMemory) Dial(host string) (net.Conn, error) {
log := log.WithField("in", "MitmMemory.Dial").WithField("host", host)
pipes := mock_conn.NewConn()
// 如果是 tls 流量,则进入 listener.Accept => MitmMemory.ServeHTTP
// 否则很可能是 ws 流量,直接转发流量
go func() {
conn := newConn(pipes.Server, host)
connb := newConnBuf(conn)
buf, err := connb.Peek(3)
if err != nil {
log.Errorf("Peek error: %v\n", err)
connb.Close()
return
}
// tls
if buf[0] == 0x16 && buf[1] == 0x03 && (buf[2] >= 0x0 || buf[2] <= 0x03) {
m.Listener.(*listener).connChan <- connb
} else {
// websocket ws://
log.Debug("begin websocket ws://")
defer connb.Close()
remoteConn, err := net.Dial("tcp", host)
if err != nil {
if !ignoreErr(log, err) {
log.Error(err)
}
return
}
defer remoteConn.Close()
transfer(log, connb, remoteConn)
}
}()
return newConn(pipes.Client, host), nil
}
func (m *MitmMemory) ServeHTTP(res http.ResponseWriter, req *http.Request) {
log := log.WithField("in", "MitmMemory.ServeHTTP").WithField("host", req.Host)
// websocket wss://
if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
log.Debug("begin websocket wss://")
upgradeBuf, err := httputil.DumpRequest(req, false)
if err != nil {
log.Errorf("DumpRequest: %v\n", err)
res.WriteHeader(502)
return
}
cconn, _, err := res.(http.Hijacker).Hijack()
if err != nil {
log.Errorf("Hijack: %v\n", err)
res.WriteHeader(502)
return
}
defer cconn.Close()
host := req.Host
if !strings.Contains(host, ":") {
host = host + ":443"
}
conn, err := tls.Dial("tcp", host, nil)
if err != nil {
log.Errorf("tls.Dial: %v\n", err)
return
}
defer conn.Close()
_, err = conn.Write(upgradeBuf)
if err != nil {
log.Errorf("wss upgrade: %v\n", err)
return
}
transfer(log, conn, cconn)
return
}
if req.URL.Scheme == "" {
req.URL.Scheme = "https"
}
if req.URL.Host == "" {
req.URL.Host = req.Host
}
m.Proxy.ServeHTTP(res, req)
}