addon-dailer
lqqyt2423 4 years ago
parent 573723b61e
commit 730bd208bb

@ -0,0 +1,54 @@
package addon
import (
"time"
"github.com/lqqyt2423/go-mitmproxy/flow"
_log "github.com/sirupsen/logrus"
)
var log = _log.WithField("at", "addon")
type Addon interface {
// HTTP request headers were successfully read. At this point, the body is empty.
Requestheaders(*flow.Flow)
// The full HTTP request has been read.
Request(*flow.Flow)
// HTTP response headers were successfully read. At this point, the body is empty.
Responseheaders(*flow.Flow)
// The full HTTP response has been read.
Response(*flow.Flow)
}
// Base do nothing
type Base struct{}
func (addon *Base) Requestheaders(*flow.Flow) {}
func (addon *Base) Request(*flow.Flow) {}
func (addon *Base) Responseheaders(*flow.Flow) {}
func (addon *Base) Response(*flow.Flow) {}
// Log log http record
type Log struct {
Base
}
func (addon *Log) Requestheaders(f *flow.Flow) {
log := log.WithField("in", "Log")
start := time.Now()
go func() {
<-f.Done()
var StatusCode int
if f.Response != nil {
StatusCode = f.Response.StatusCode
}
var contentLen int
if f.Response != nil && f.Response.Body != nil {
contentLen = len(f.Response.Body)
}
log.Infof("%v %v %v %v - %v ms\n", f.Request.Method, f.Request.URL.String(), StatusCode, contentLen, time.Since(start).Milliseconds())
}()
}

@ -3,7 +3,6 @@ package flow
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"time"
_log "github.com/sirupsen/logrus" _log "github.com/sirupsen/logrus"
) )
@ -45,47 +44,3 @@ func (f *Flow) Done() <-chan struct{} {
func (f *Flow) Finish() { func (f *Flow) Finish() {
close(f.done) close(f.done)
} }
type Addon interface {
// HTTP request headers were successfully read. At this point, the body is empty.
Requestheaders(*Flow)
// The full HTTP request has been read.
Request(*Flow)
// HTTP response headers were successfully read. At this point, the body is empty.
Responseheaders(*Flow)
// The full HTTP response has been read.
Response(*Flow)
}
// BaseAddon do nothing
type BaseAddon struct{}
func (addon *BaseAddon) Requestheaders(*Flow) {}
func (addon *BaseAddon) Request(*Flow) {}
func (addon *BaseAddon) Responseheaders(*Flow) {}
func (addon *BaseAddon) Response(*Flow) {}
// LogAddon log http record
type LogAddon struct {
BaseAddon
}
func (addon *LogAddon) Requestheaders(flo *Flow) {
log := log.WithField("in", "LogAddon")
start := time.Now()
go func() {
<-flo.Done()
var StatusCode int
if flo.Response != nil {
StatusCode = flo.Response.StatusCode
}
var contentLen int
if flo.Response != nil && flo.Response.Body != nil {
contentLen = len(flo.Response.Body)
}
log.Infof("%v %v %v %v - %v ms\n", flo.Request.Method, flo.Request.URL.String(), StatusCode, contentLen, time.Since(start).Milliseconds())
}()
}

@ -3,9 +3,76 @@ package proxy
import ( import (
"bytes" "bytes"
"io" "io"
"os"
"strings"
"sync"
_log "github.com/sirupsen/logrus"
) )
var NormalErrMsgs []string = []string{
"read: connection reset by peer",
"write: broken pipe",
"i/o timeout",
"net/http: TLS handshake timeout",
"io: read/write on closed pipe",
"connect: connection refused",
"connect: connection reset by peer",
}
// 仅打印预料之外的错误信息
func LogErr(log *_log.Entry, err error) (loged bool) {
msg := err.Error()
for _, str := range NormalErrMsgs {
if strings.Contains(msg, str) {
log.Debug(err)
return
}
}
log.Error(err)
loged = true
return
}
// 转发流量
// Read a => Write b
// Read b => Write a
func Transfer(log *_log.Entry, a, b io.ReadWriter) {
done := make(chan struct{})
defer close(done)
forward := func(dst io.Writer, src io.Reader, ec chan<- error) {
_, err := io.Copy(dst, src)
if v, ok := dst.(*conn); ok {
// 避免内存泄漏
_ = v.Writer.CloseWithError(nil)
}
select {
case <-done:
return
case ec <- err:
}
}
errChan := make(chan error)
go forward(a, b, errChan)
go forward(b, a, errChan)
for i := 0; i < 2; i++ {
if err := <-errChan; err != nil {
LogErr(log, err)
return // 如果有错误,直接返回
}
}
}
// 尝试将 Reader 读取至 buffer 中 // 尝试将 Reader 读取至 buffer 中
// 如果未达到 limit则成功读取进入 buffer
// 否则 buffer 返回 nil且返回新 Reader状态为未读取前
func ReaderToBuffer(r io.Reader, limit int64) ([]byte, io.Reader, error) { func ReaderToBuffer(r io.Reader, limit int64) ([]byte, io.Reader, error) {
buf := bytes.NewBuffer(make([]byte, 0)) buf := bytes.NewBuffer(make([]byte, 0))
lr := io.LimitReader(r, limit) lr := io.LimitReader(r, limit)
@ -24,3 +91,25 @@ func ReaderToBuffer(r io.Reader, limit int64) ([]byte, io.Reader, error) {
// 返回 buffer // 返回 buffer
return buf.Bytes(), nil, nil return buf.Bytes(), nil, nil
} }
// Wireshark 解析 https 设置
var tlsKeyLogWriter io.Writer
var tlsKeyLogOnce sync.Once
func GetTlsKeyLogWriter() io.Writer {
tlsKeyLogOnce.Do(func() {
logfile := os.Getenv("SSLKEYLOGFILE")
if logfile == "" {
return
}
writer, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
log.WithField("in", "GetTlsKeyLogWriter").Debug(err)
return
}
tlsKeyLogWriter = writer
})
return tlsKeyLogWriter
}

@ -0,0 +1,24 @@
package proxy
import (
"net"
)
// 拦截 https 流量通用接口
type Interceptor interface {
// 初始化
Start() error
// 针对每个 host 连接
Dial(host string) (net.Conn, error)
}
// 直接转发 https 流量
type Forward struct{}
func (i *Forward) Start() error {
return nil
}
func (i *Forward) Dial(host string) (net.Conn, error) {
return net.Dial("tcp", host)
}

@ -0,0 +1,213 @@
package proxy
import (
"bufio"
"crypto/tls"
"net"
"net/http"
"os"
"strings"
"time"
mock_conn "github.com/jordwest/mock-conn"
"github.com/lqqyt2423/go-mitmproxy/cert"
)
// 模拟了标准库中 server 运行,目的是仅通过当前进程内存转发 socket 数据,不需要经过 tcp 或 unix socket
// mock net.Listener
type listener struct {
connChan chan net.Conn
}
func (l *listener) Accept() (net.Conn, error) { return <-l.connChan, nil }
func (l *listener) Close() error { return nil }
func (l *listener) Addr() net.Addr { return nil }
type ioRes struct {
n int
err error
}
// mock net.Conn
type conn struct {
mock_conn.End
host string // remote host
readErrChan chan error // Read 方法提前返回时的错误
}
// 建立客户端和服务端通信的通道
func newPipes(host string) (client *conn, server *connBuf) {
pipes := mock_conn.NewConn()
client = &conn{*pipes.Client, host, nil}
serverConn := &conn{*pipes.Server, host, make(chan error)}
server = newConnBuf(serverConn)
return client, server
}
// 当接收到 readErrChan 时,可提前返回
func (c *conn) Read(data []byte) (int, error) {
select {
case err := <-c.readErrChan:
return 0, err
default:
}
resChan := make(chan *ioRes)
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
return
default:
}
n, err := c.End.Read(data)
select {
case resChan <- &ioRes{n, err}:
return
case <-done:
close(resChan)
}
}()
select {
case res := <-resChan:
return res.n, res.err
case err := <-c.readErrChan:
return 0, err
}
}
func (c *conn) SetDeadline(t time.Time) error {
if !t.Equal(time.Time{}) {
log.WithField("host", c.host).Warnf("SetDeadline %v\n", t)
}
return nil
}
// http server 会在连接快结束时调用此方法
func (c *conn) SetReadDeadline(t time.Time) error {
if !t.Equal(time.Time{}) {
if !t.After(time.Now()) {
// 使当前 Read 尽快返回
c.readErrChan <- os.ErrDeadlineExceeded
} else {
log.WithField("host", c.host).Warnf("SetReadDeadline %v\n", t)
}
}
return nil
}
func (c *conn) SetWriteDeadline(t time.Time) error {
log.WithField("host", c.host).Warnf("SetWriteDeadline %v\n", t)
return nil
}
// add Peek method for conn
type connBuf struct {
*conn
r *bufio.Reader
}
func newConnBuf(c *conn) *connBuf {
return &connBuf{c, bufio.NewReader(c)}
}
func (b *connBuf) Peek(n int) ([]byte, error) {
return b.r.Peek(n)
}
func (b *connBuf) Read(data []byte) (int, error) {
return b.r.Read(data)
}
// Middle: man-in-the-middle
type Middle struct {
Proxy *Proxy
CA *cert.CA
Listener net.Listener
Server *http.Server
}
func NewMiddle(proxy *Proxy) (Interceptor, error) {
ca, err := cert.NewCA("")
if err != nil {
return nil, err
}
m := &Middle{
Proxy: proxy,
CA: ca,
}
server := &http.Server{
Handler: m,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
TLSConfig: &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Debugf("Middle GetCertificate ServerName: %v\n", chi.ServerName)
return ca.GetCert(chi.ServerName)
},
},
}
// 每次连接尽快结束,因为连接并无开销
server.SetKeepAlivesEnabled(false)
m.Server = server
return m, nil
}
func (m *Middle) Start() error {
m.Listener = &listener{make(chan net.Conn)}
return m.Server.ServeTLS(m.Listener, "", "")
}
func (m *Middle) Dial(host string) (net.Conn, error) {
clientConn, serverConn := newPipes(host)
go m.intercept(serverConn)
return clientConn, nil
}
func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) {
if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
// wss
DefaultWebSocket.WSS(res, req)
return
}
if req.URL.Scheme == "" {
req.URL.Scheme = "https"
}
if req.URL.Host == "" {
req.URL.Host = req.Host
}
m.Proxy.ServeHTTP(res, req)
}
// 解析 connect 流量
// 如果是 tls 流量,则进入 listener.Accept => Middle.ServeHTTP
// 否则很可能是 ws 流量
func (m *Middle) intercept(serverConn *connBuf) {
log := log.WithField("in", "Middle.intercept").WithField("host", serverConn.host)
buf, err := serverConn.Peek(3)
if err != nil {
log.Errorf("Peek error: %v\n", err)
serverConn.Close()
return
}
if buf[0] == 0x16 && buf[1] == 0x03 && (buf[2] >= 0x0 || buf[2] <= 0x03) {
// tls
m.Listener.(*listener).connChan <- serverConn
} else {
// ws
DefaultWebSocket.WS(serverConn, serverConn.host)
}
}

@ -1,93 +0,0 @@
package proxy
import (
"crypto/tls"
"net"
"net/http"
"github.com/lqqyt2423/go-mitmproxy/cert"
)
type Mitm interface {
Start() error
Dial(host string) (net.Conn, error)
}
// 直接转发 https 流量
type MitmForward struct{}
func (m *MitmForward) Start() error {
return nil
}
func (m *MitmForward) Dial(host string) (net.Conn, error) {
return net.Dial("tcp", host)
}
// 内部解析 https 流量
// 每个连接都会消耗掉两个文件描述符,可能会达到打开文件上限
type MitmServer struct {
Proxy *Proxy
CA *cert.CA
Listener net.Listener
Server *http.Server
}
func NewMitmServer(proxy *Proxy) (Mitm, error) {
ca, err := cert.NewCA("")
if err != nil {
return nil, err
}
m := &MitmServer{
Proxy: proxy,
CA: ca,
}
server := &http.Server{
Handler: m,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
TLSConfig: &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Debugf("MitmServer GetCertificate ServerName: %v\n", chi.ServerName)
return ca.GetCert(chi.ServerName)
},
},
}
// 尽快关闭内部的连接,释放文件描述符
server.SetKeepAlivesEnabled(false)
m.Server = server
return m, nil
}
func (m *MitmServer) Start() error {
ln, err := net.Listen("tcp", "127.0.0.1:") // port number is automatically chosen
if err != nil {
return err
}
m.Listener = ln
m.Server.Addr = ln.Addr().String()
log.Infof("MitmServer Server Addr is %v\n", m.Server.Addr)
defer ln.Close()
return m.Server.ServeTLS(ln, "", "")
}
func (m *MitmServer) Dial(host string) (net.Conn, error) {
return net.Dial("tcp", m.Server.Addr)
}
func (m *MitmServer) ServeHTTP(res http.ResponseWriter, req *http.Request) {
if req.URL.Scheme == "" {
req.URL.Scheme = "https"
}
if req.URL.Host == "" {
req.URL.Host = req.Host
}
m.Proxy.ServeHTTP(res, req)
}

@ -1,269 +0,0 @@
package proxy
import (
"bufio"
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
"os"
"strings"
"time"
mock_conn "github.com/jordwest/mock-conn"
"github.com/lqqyt2423/go-mitmproxy/cert"
)
// 模拟实现 net
type listener struct {
connChan chan net.Conn
}
func (l *listener) Accept() (net.Conn, error) {
return <-l.connChan, nil
}
func (l *listener) Close() error {
return nil
}
func (l *listener) Addr() net.Addr {
return nil
}
type ioRes struct {
n int
err error
}
type conn struct {
*mock_conn.End
Host string // remote host
readErrChan chan error // Read 方法提前返回时的错误
}
func newConn(end *mock_conn.End, host string) *conn {
return &conn{
End: end,
Host: host,
readErrChan: make(chan error),
}
}
// 当接收到 readErrChan 时,可提前返回
func (c *conn) Read(data []byte) (int, error) {
select {
case err := <-c.readErrChan:
return 0, err
default:
}
resChan := make(chan *ioRes)
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
return
default:
}
n, err := c.End.Read(data)
select {
case resChan <- &ioRes{n, err}:
return
case <-done:
close(resChan)
}
}()
select {
case res := <-resChan:
return res.n, res.err
case err := <-c.readErrChan:
return 0, err
}
}
func (c *conn) SetDeadline(t time.Time) error {
if !t.Equal(time.Time{}) {
log.WithField("host", c.Host).Warnf("SetDeadline %v\n", t)
}
return nil
}
// http server 会在连接快结束时调用此方法
func (c *conn) SetReadDeadline(t time.Time) error {
if !t.Equal(time.Time{}) {
if !t.After(time.Now()) {
// 使当前 Read 尽快返回
c.readErrChan <- os.ErrDeadlineExceeded
} else {
log.Warnf("SetReadDeadline %v\n", t)
}
}
return nil
}
func (c *conn) SetWriteDeadline(t time.Time) error {
log.WithField("host", c.Host).Warnf("SetWriteDeadline %v\n", t)
return nil
}
// wrap conn for peek
type connBuf struct {
*conn
r *bufio.Reader
}
func newConnBuf(c *conn) *connBuf {
return &connBuf{c, bufio.NewReader(c)}
}
func (b *connBuf) Peek(n int) ([]byte, error) {
return b.r.Peek(n)
}
func (b *connBuf) Read(data []byte) (int, error) {
return b.r.Read(data)
}
type MitmMemory struct {
Proxy *Proxy
CA *cert.CA
Listener net.Listener
Server *http.Server
}
func NewMitmMemory(proxy *Proxy) (Mitm, error) {
ca, err := cert.NewCA("")
if err != nil {
return nil, err
}
m := &MitmMemory{
Proxy: proxy,
CA: ca,
}
server := &http.Server{
Handler: m,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
TLSConfig: &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Debugf("MitmMemory GetCertificate ServerName: %v\n", chi.ServerName)
return ca.GetCert(chi.ServerName)
},
},
}
// 每次连接尽快结束,因为连接并无开销
server.SetKeepAlivesEnabled(false)
m.Server = server
return m, nil
}
func (m *MitmMemory) Start() error {
ln := &listener{
connChan: make(chan net.Conn),
}
m.Listener = ln
return m.Server.ServeTLS(ln, "", "")
}
func (m *MitmMemory) Dial(host string) (net.Conn, error) {
log := log.WithField("in", "MitmMemory.Dial").WithField("host", host)
pipes := mock_conn.NewConn()
// 如果是 tls 流量,则进入 listener.Accept => MitmMemory.ServeHTTP
// 否则很可能是 ws 流量,直接转发流量
go func() {
conn := newConn(pipes.Server, host)
connb := newConnBuf(conn)
buf, err := connb.Peek(3)
if err != nil {
log.Errorf("Peek error: %v\n", err)
connb.Close()
return
}
// tls
if buf[0] == 0x16 && buf[1] == 0x03 && (buf[2] >= 0x0 || buf[2] <= 0x03) {
m.Listener.(*listener).connChan <- connb
} else {
// websocket ws://
log.Debug("begin websocket ws://")
defer connb.Close()
remoteConn, err := net.Dial("tcp", host)
if err != nil {
if !ignoreErr(log, err) {
log.Error(err)
}
return
}
defer remoteConn.Close()
transfer(log, connb, remoteConn)
}
}()
return newConn(pipes.Client, host), nil
}
func (m *MitmMemory) ServeHTTP(res http.ResponseWriter, req *http.Request) {
log := log.WithField("in", "MitmMemory.ServeHTTP").WithField("host", req.Host)
// websocket wss://
if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
log.Debug("begin websocket wss://")
upgradeBuf, err := httputil.DumpRequest(req, false)
if err != nil {
log.Errorf("DumpRequest: %v\n", err)
res.WriteHeader(502)
return
}
cconn, _, err := res.(http.Hijacker).Hijack()
if err != nil {
log.Errorf("Hijack: %v\n", err)
res.WriteHeader(502)
return
}
defer cconn.Close()
host := req.Host
if !strings.Contains(host, ":") {
host = host + ":443"
}
conn, err := tls.Dial("tcp", host, nil)
if err != nil {
log.Errorf("tls.Dial: %v\n", err)
return
}
defer conn.Close()
_, err = conn.Write(upgradeBuf)
if err != nil {
log.Errorf("wss upgrade: %v\n", err)
return
}
transfer(log, conn, cconn)
return
}
if req.URL.Scheme == "" {
req.URL.Scheme = "https"
}
if req.URL.Host == "" {
req.URL.Host = req.Host
}
m.Proxy.ServeHTTP(res, req)
}

@ -6,86 +6,80 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"os"
"strings"
"sync"
"time" "time"
"github.com/lqqyt2423/go-mitmproxy/addon"
"github.com/lqqyt2423/go-mitmproxy/flow" "github.com/lqqyt2423/go-mitmproxy/flow"
_log "github.com/sirupsen/logrus" _log "github.com/sirupsen/logrus"
) )
var log = _log.WithField("at", "proxy") var log = _log.WithField("at", "proxy")
var ignoreErr = func(log *_log.Entry, err error) bool { type Options struct {
errs := err.Error() Addr string
strs := []string{ StreamLargeBodies int64
"read: connection reset by peer",
"write: broken pipe",
"i/o timeout",
"net/http: TLS handshake timeout",
"io: read/write on closed pipe",
"connect: connection refused",
"connect: connection reset by peer",
}
for _, str := range strs {
if strings.Contains(errs, str) {
log.Debug(err)
return true
}
}
return false
} }
func transfer(log *_log.Entry, a, b io.ReadWriter) { type Proxy struct {
done := make(chan struct{}) Server *http.Server
defer close(done) Client *http.Client
Interceptor Interceptor
forward := func(dst io.Writer, src io.Reader, ec chan<- error) { StreamLargeBodies int64 // 当请求或响应体大于此字节时,转为 stream 模式
_, err := io.Copy(dst, src) Addons []addon.Addon
}
if v, ok := dst.(*conn); ok { func NewProxy(opts *Options) (*Proxy, error) {
// 避免内存泄漏的关键 proxy := new(Proxy)
_ = v.Writer.CloseWithError(nil)
}
select { proxy.Server = &http.Server{
case <-done: Addr: opts.Addr,
return Handler: proxy,
case ec <- err:
}
} }
errChan := make(chan error) proxy.Client = &http.Client{
go forward(a, b, errChan) Transport: &http.Transport{
go forward(b, a, errChan) Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
for i := 0; i < 2; i++ { ForceAttemptHTTP2: false, // disable http2
if err := <-errChan; err != nil { DisableCompression: true,
if !ignoreErr(log, err) { TLSClientConfig: &tls.Config{
log.Error(err) KeyLogWriter: GetTlsKeyLogWriter(),
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// 禁止自动重定向
return http.ErrUseLastResponse
},
} }
return // 如果有错误,直接返回
interceptor, err := NewMiddle(proxy)
if err != nil {
return nil, err
} }
proxy.Interceptor = interceptor
if opts.StreamLargeBodies > 0 {
proxy.StreamLargeBodies = opts.StreamLargeBodies
} else {
proxy.StreamLargeBodies = 1024 * 1024 * 5 // default: 5mb
} }
}
type Options struct { proxy.Addons = make([]addon.Addon, 0)
Addr string proxy.AddAddon(&addon.Log{})
StreamLargeBodies int64
}
type Proxy struct { return proxy, nil
Server *http.Server
Client *http.Client
Mitm Mitm
StreamLargeBodies int64 // 当请求或响应体大于此字节时,转为 stream 模式
Addons []flow.Addon
} }
func (proxy *Proxy) AddAddon(addon flow.Addon) { func (proxy *Proxy) AddAddon(addon addon.Addon) {
proxy.Addons = append(proxy.Addons, addon) proxy.Addons = append(proxy.Addons, addon)
} }
@ -99,7 +93,7 @@ func (proxy *Proxy) Start() error {
}() }()
go func() { go func() {
err := proxy.Mitm.Start() err := proxy.Interceptor.Start()
errChan <- err errChan <- err
}() }()
@ -130,7 +124,7 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
return return
} }
endRes := func(response *flow.Response, body io.Reader) { reply := func(response *flow.Response, body io.Reader) {
if response.Header != nil { if response.Header != nil {
for key, value := range response.Header { for key, value := range response.Header {
for _, v := range value { for _, v := range value {
@ -142,13 +136,13 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
if body != nil { if body != nil {
_, err := io.Copy(res, body) _, err := io.Copy(res, body)
if err != nil && !ignoreErr(log, err) { if err != nil {
log.Error(err) LogErr(log, err)
} }
} else if response.Body != nil && len(response.Body) > 0 { } else if response.Body != nil && len(response.Body) > 0 {
_, err := res.Write(response.Body) _, err := res.Write(response.Body)
if err != nil && !ignoreErr(log, err) { if err != nil {
log.Error(err) LogErr(log, err)
} }
} }
} }
@ -160,27 +154,27 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
} }
}() }()
flo := flow.NewFlow() f := flow.NewFlow()
flo.Request = &flow.Request{ f.Request = &flow.Request{
Method: req.Method, Method: req.Method,
URL: req.URL, URL: req.URL,
Proto: req.Proto, Proto: req.Proto,
Header: req.Header, Header: req.Header,
} }
defer flo.Finish() defer f.Finish()
// trigger addon event Requestheaders // trigger addon event Requestheaders
for _, addon := range proxy.Addons { for _, addon := range proxy.Addons {
addon.Requestheaders(flo) addon.Requestheaders(f)
if flo.Response != nil { if f.Response != nil {
endRes(flo.Response, nil) reply(f.Response, nil)
return return
} }
} }
// request body // Read request body
var reqBody io.Reader = req.Body var reqBody io.Reader = req.Body
if !flo.Stream { if !f.Stream {
reqBuf, r, err := ReaderToBuffer(req.Body, proxy.StreamLargeBodies) reqBuf, r, err := ReaderToBuffer(req.Body, proxy.StreamLargeBodies)
reqBody = r reqBody = r
if err != nil { if err != nil {
@ -188,65 +182,62 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
res.WriteHeader(502) res.WriteHeader(502)
return return
} }
if reqBuf == nil { if reqBuf == nil {
log.Warnf("request body size >= %v\n", proxy.StreamLargeBodies) log.Warnf("request body size >= %v\n", proxy.StreamLargeBodies)
flo.Stream = true f.Stream = true
} else { } else {
flo.Request.Body = reqBuf f.Request.Body = reqBuf
}
// trigger addon event Request // trigger addon event Request
if !flo.Stream {
for _, addon := range proxy.Addons { for _, addon := range proxy.Addons {
addon.Request(flo) addon.Request(f)
if flo.Response != nil { if f.Response != nil {
endRes(flo.Response, nil) reply(f.Response, nil)
return return
} }
} }
reqBody = bytes.NewReader(flo.Request.Body) reqBody = bytes.NewReader(f.Request.Body)
} }
} }
proxyReq, err := http.NewRequest(flo.Request.Method, flo.Request.URL.String(), reqBody) proxyReq, err := http.NewRequest(f.Request.Method, f.Request.URL.String(), reqBody)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
res.WriteHeader(502) res.WriteHeader(502)
return return
} }
for key, value := range flo.Request.Header { for key, value := range f.Request.Header {
for _, v := range value { for _, v := range value {
proxyReq.Header.Add(key, v) proxyReq.Header.Add(key, v)
} }
} }
proxyRes, err := proxy.Client.Do(proxyReq) proxyRes, err := proxy.Client.Do(proxyReq)
if err != nil { if err != nil {
if !ignoreErr(log, err) { LogErr(log, err)
log.Error(err)
}
res.WriteHeader(502) res.WriteHeader(502)
return return
} }
defer proxyRes.Body.Close() defer proxyRes.Body.Close()
flo.Response = &flow.Response{ f.Response = &flow.Response{
StatusCode: proxyRes.StatusCode, StatusCode: proxyRes.StatusCode,
Header: proxyRes.Header, Header: proxyRes.Header,
} }
// trigger addon event Responseheaders // trigger addon event Responseheaders
for _, addon := range proxy.Addons { for _, addon := range proxy.Addons {
addon.Responseheaders(flo) addon.Responseheaders(f)
if flo.Response.Body != nil { if f.Response.Body != nil {
endRes(flo.Response, nil) reply(f.Response, nil)
return return
} }
} }
// response body // Read response body
var resBody io.Reader = proxyRes.Body var resBody io.Reader = proxyRes.Body
if !flo.Stream { if !f.Stream {
resBuf, r, err := ReaderToBuffer(proxyRes.Body, proxy.StreamLargeBodies) resBuf, r, err := ReaderToBuffer(proxyRes.Body, proxy.StreamLargeBodies)
resBody = r resBody = r
if err != nil { if err != nil {
@ -256,20 +247,18 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
} }
if resBuf == nil { if resBuf == nil {
log.Warnf("response body size >= %v\n", proxy.StreamLargeBodies) log.Warnf("response body size >= %v\n", proxy.StreamLargeBodies)
flo.Stream = true f.Stream = true
} else { } else {
flo.Response.Body = resBuf f.Response.Body = resBuf
}
// trigger addon event Response // trigger addon event Response
if !flo.Stream {
for _, addon := range proxy.Addons { for _, addon := range proxy.Addons {
addon.Response(flo) addon.Response(f)
} }
} }
} }
endRes(flo.Response, resBody) reply(f.Response, resBody)
} }
func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) {
@ -280,8 +269,7 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) {
log.Debug("receive connect") log.Debug("receive connect")
conn, err := proxy.Mitm.Dial(req.Host) conn, err := proxy.Interceptor.Dial(req.Host)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
res.WriteHeader(502) res.WriteHeader(502)
@ -303,77 +291,5 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) {
return return
} }
transfer(log, conn, cconn) Transfer(log, conn, cconn)
}
func NewProxy(opts *Options) (*Proxy, error) {
proxy := new(Proxy)
proxy.Server = &http.Server{
Addr: opts.Addr,
Handler: proxy,
}
proxy.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: false, // disable http2
DisableCompression: true,
TLSClientConfig: &tls.Config{
KeyLogWriter: GetTlsKeyLogWriter(),
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// 禁止自动重定向
return http.ErrUseLastResponse
},
}
mitm, err := NewMitmMemory(proxy)
if err != nil {
return nil, err
}
proxy.Mitm = mitm
if opts.StreamLargeBodies > 0 {
proxy.StreamLargeBodies = opts.StreamLargeBodies
} else {
proxy.StreamLargeBodies = 1024 * 1024 * 5 // default: 5mb
}
proxy.Addons = make([]flow.Addon, 0)
proxy.AddAddon(&flow.LogAddon{})
return proxy, nil
}
var tlsKeyLogWriter io.Writer
var tlsKeyLogOnce sync.Once
// Wireshark 解析 https 设置
func GetTlsKeyLogWriter() io.Writer {
tlsKeyLogOnce.Do(func() {
logfile := os.Getenv("SSLKEYLOGFILE")
if logfile == "" {
return
}
writer, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
log.WithField("in", "GetTlsKeyLogWriter").Debug(err)
return
}
tlsKeyLogWriter = writer
})
return tlsKeyLogWriter
} }

@ -0,0 +1,65 @@
package proxy
import (
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
"strings"
)
// 当前仅做了转发 websocket 流量
type WebSocket struct{}
var DefaultWebSocket WebSocket
func (s *WebSocket) WS(conn net.Conn, host string) {
log := log.WithField("in", "WebSocket.WS").WithField("host", host)
defer conn.Close()
remoteConn, err := net.Dial("tcp", host)
if err != nil {
LogErr(log, err)
return
}
defer remoteConn.Close()
Transfer(log, conn, remoteConn)
}
func (s *WebSocket) WSS(res http.ResponseWriter, req *http.Request) {
log := log.WithField("in", "WebSocket.WSS").WithField("host", req.Host)
upgradeBuf, err := httputil.DumpRequest(req, false)
if err != nil {
log.Errorf("DumpRequest: %v\n", err)
res.WriteHeader(502)
return
}
cconn, _, err := res.(http.Hijacker).Hijack()
if err != nil {
log.Errorf("Hijack: %v\n", err)
res.WriteHeader(502)
return
}
defer cconn.Close()
host := req.Host
if !strings.Contains(host, ":") {
host = host + ":443"
}
conn, err := tls.Dial("tcp", host, nil)
if err != nil {
log.Errorf("tls.Dial: %v\n", err)
return
}
defer conn.Close()
_, err = conn.Write(upgradeBuf)
if err != nil {
log.Errorf("wss upgrade: %v\n", err)
return
}
Transfer(log, conn, cconn)
}
Loading…
Cancel
Save