You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

127 lines
2.5 KiB
Go

4 years ago
package proxy
import (
"bytes"
"io"
"net"
2 years ago
"os"
4 years ago
"strings"
2 years ago
"sync"
4 years ago
2 years ago
log "github.com/sirupsen/logrus"
4 years ago
)
// normalErrMsgs holds substrings of error messages that are expected while
// proxying (peer resets, broken pipes, timeouts, refused connections, closed
// pipes/connections). logErr logs any error whose message contains one of
// these at Debug level instead of Error level.
var normalErrMsgs = []string{
	"read: connection reset by peer",
	"write: broken pipe",
	"i/o timeout",
	"net/http: TLS handshake timeout",
	"io: read/write on closed pipe",
	"connect: connection refused",
	"connect: connection reset by peer",
	"use of closed network connection",
}
// logErr logs err through entry, printing only unexpected errors loudly.
//
// If err's message contains any substring listed in normalErrMsgs, the error
// is considered an expected network condition: it is logged at Debug level
// and logErr returns false. Any other error is logged at Error level and
// logErr returns true. A nil err is ignored and reports false (previously it
// would panic on err.Error()).
func logErr(entry *log.Entry, err error) (logged bool) {
	if err == nil {
		return false
	}
	msg := err.Error()
	for _, expected := range normalErrMsgs {
		if strings.Contains(msg, expected) {
			entry.Debug(err)
			return false
		}
	}
	entry.Error(err)
	return true
}
// transfer relays traffic in both directions between client and server.
//
// Each direction is copied in its own goroutine:
//   - client -> server: when that copy ends, client is closed.
//   - server -> client: when that copy ends, server is closed and, if client
//     is a *wrapClientConn whose Conn is a *net.TCPConn, the read half of that
//     TCP connection is shut down with CloseRead.
//
// transfer waits for both directions to report. If either reports a non-nil
// error, that error is passed to logErr and transfer returns immediately
// without waiting for the other direction. transfer itself closes neither
// stream; the remaining goroutine keeps running until its io.Copy ends, and
// closing done on return lets it exit instead of blocking on errChan.
func transfer(entry *log.Entry, server, client io.ReadWriteCloser) {
	done := make(chan struct{})
	defer close(done)
	errChan := make(chan error)

	// report hands one copy result to the loop below, or drops it if
	// transfer has already returned (done is closed).
	report := func(err error) {
		select {
		case <-done:
		case errChan <- err:
		}
	}

	go func() {
		_, err := io.Copy(server, client)
		entry.Debugln("client copy end", err)
		client.Close()
		report(err)
	}()

	go func() {
		_, err := io.Copy(client, server)
		entry.Debugln("server copy end", err)
		server.Close()
		if clientConn, ok := client.(*wrapClientConn); ok {
			// Checked assertion: an unchecked one would panic (crashing the
			// whole process from this goroutine) if the wrapped conn is not
			// a *net.TCPConn.
			if tcpConn, ok := clientConn.Conn.(*net.TCPConn); ok {
				closeErr := tcpConn.CloseRead()
				entry.Debugln("clientConn.Conn.(*net.TCPConn).CloseRead()", closeErr)
			}
		}
		report(err)
	}()

	for i := 0; i < 2; i++ {
		if err := <-errChan; err != nil {
			logErr(entry, err)
			return // Stop at the first error without waiting for the other direction.
		}
	}
}
4 years ago
// 尝试将 Reader 读取至 buffer 中
4 years ago
// 如果未达到 limit则成功读取进入 buffer
// 否则 buffer 返回 nil且返回新 Reader状态为未读取前
2 years ago
func readerToBuffer(r io.Reader, limit int64) ([]byte, io.Reader, error) {
4 years ago
buf := bytes.NewBuffer(make([]byte, 0))
lr := io.LimitReader(r, limit)
_, err := io.Copy(buf, lr)
if err != nil {
return nil, nil, err
}
// 达到上限
if int64(buf.Len()) == limit {
// 返回新的 Reader
return nil, io.MultiReader(bytes.NewBuffer(buf.Bytes()), r), nil
}
// 返回 buffer
return buf.Bytes(), nil, nil
}
2 years ago
// Wireshark 解析 https 设置
var tlsKeyLogWriter io.Writer
var tlsKeyLogOnce sync.Once
func getTlsKeyLogWriter() io.Writer {
tlsKeyLogOnce.Do(func() {
logfile := os.Getenv("SSLKEYLOGFILE")
if logfile == "" {
return
}
writer, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
2 years ago
log.Debugf("getTlsKeyLogWriter OpenFile error: %v", err)
2 years ago
return
}
tlsKeyLogWriter = writer
})
return tlsKeyLogWriter
}