234 lines
5.1 KiB
Go
234 lines
5.1 KiB
Go
/*
|
|
* Copyright (C) 2020-2022, IrineSistiana
|
|
*
|
|
* This file is part of mosdns.
|
|
*
|
|
* mosdns is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* mosdns is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package upstream
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
func newUDPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
|
|
udpConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
udpAddr := udpConn.LocalAddr().String()
|
|
udpServer := dns.Server{
|
|
PacketConn: udpConn,
|
|
Handler: handler,
|
|
}
|
|
go udpServer.ActivateAndServe()
|
|
return udpAddr, func() {
|
|
udpServer.Shutdown()
|
|
}
|
|
}
|
|
|
|
func newTCPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
|
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tcpAddr := l.Addr().String()
|
|
tcpServer := dns.Server{
|
|
Listener: l,
|
|
Handler: handler,
|
|
MaxTCPQueries: -1,
|
|
}
|
|
go tcpServer.ActivateAndServe()
|
|
return tcpAddr, func() {
|
|
tcpServer.Shutdown()
|
|
}
|
|
}
|
|
|
|
func newDoTTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
|
|
serverName := "test"
|
|
cert, err := utils.GenerateCertificate(serverName)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tlsConfig := new(tls.Config)
|
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
|
tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
doTAddr := tlsListener.Addr().String()
|
|
doTServer := dns.Server{
|
|
Net: "tcp-tls",
|
|
Listener: tlsListener,
|
|
TLSConfig: tlsConfig,
|
|
Handler: handler,
|
|
MaxTCPQueries: -1,
|
|
}
|
|
go doTServer.ActivateAndServe()
|
|
return doTAddr, func() {
|
|
doTServer.Shutdown()
|
|
}
|
|
}
|
|
|
|
type newTestServerFunc func(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func())
|
|
|
|
var m = map[string]newTestServerFunc{
|
|
"udp": newUDPTestServer,
|
|
"tcp": newTCPTestServer,
|
|
"tls": newDoTTestServer,
|
|
}
|
|
|
|
func Test_fastUpstream(t *testing.T) {
|
|
|
|
// TODO: add test for doh
|
|
// TODO: add test for socks5
|
|
|
|
// server config
|
|
for scheme, f := range m {
|
|
for _, bigMsg := range [...]bool{true, false} {
|
|
for _, latency := range [...]time.Duration{0, time.Millisecond * 10} {
|
|
|
|
// client specific
|
|
for _, idleTimeout := range [...]time.Duration{0, time.Second} {
|
|
|
|
testName := fmt.Sprintf(
|
|
"test: protocol: %s, bigMsg: %v, latency: %s, getIdleTimeout: %s",
|
|
scheme,
|
|
bigMsg,
|
|
latency,
|
|
idleTimeout,
|
|
)
|
|
|
|
t.Run(testName, func(t *testing.T) {
|
|
addr, shutdownServer := f(t, &vServer{
|
|
latency: latency,
|
|
bigMsg: bigMsg,
|
|
})
|
|
defer shutdownServer()
|
|
u, err := NewUpstream(
|
|
scheme+"://"+addr,
|
|
Opt{
|
|
IdleTimeout: time.Second,
|
|
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
|
},
|
|
)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := testUpstream(u); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
func testUpstream(u Upstream) error {
|
|
wg := sync.WaitGroup{}
|
|
errs := make([]error, 0)
|
|
errsLock := sync.Mutex{}
|
|
logErr := func(err error) {
|
|
errsLock.Lock()
|
|
errs = append(errs, err)
|
|
errsLock.Unlock()
|
|
}
|
|
errsToString := func() string {
|
|
s := fmt.Sprintf("%d err(s) occured during the test: ", len(errs))
|
|
for i := range errs {
|
|
s = s + errs[i].Error() + "|"
|
|
}
|
|
return s
|
|
}
|
|
|
|
for i := uint16(0); i < 10; i++ {
|
|
wg.Add(1)
|
|
i := i
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
q := new(dns.Msg)
|
|
q.SetQuestion("example.com.", dns.TypeA)
|
|
q.Id = i
|
|
queryPayload, err := q.Pack()
|
|
if err != nil {
|
|
logErr(err)
|
|
return
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
r, err := u.ExchangeContext(ctx, queryPayload)
|
|
if err != nil {
|
|
logErr(err)
|
|
return
|
|
}
|
|
|
|
resp := new(dns.Msg)
|
|
err = resp.Unpack(*r)
|
|
if err != nil {
|
|
logErr(err)
|
|
return
|
|
}
|
|
if q.Id != resp.Id {
|
|
logErr(dns.ErrId)
|
|
return
|
|
}
|
|
if !resp.Response {
|
|
logErr(fmt.Errorf("resp is not a resp bit"))
|
|
return
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
if len(errs) != 0 {
|
|
return errors.New(errsToString())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type vServer struct {
|
|
latency time.Duration
|
|
bigMsg bool // with 1kb padding
|
|
}
|
|
|
|
var padding = make([]byte, 1024)
|
|
|
|
func (s *vServer) ServeDNS(w dns.ResponseWriter, q *dns.Msg) {
|
|
r := new(dns.Msg)
|
|
r.SetReply(q)
|
|
if s.bigMsg {
|
|
r.SetEdns0(dns.MaxMsgSize, false)
|
|
opt := r.IsEdns0()
|
|
opt.Option = append(opt.Option, &dns.EDNS0_PADDING{Padding: padding})
|
|
}
|
|
|
|
time.Sleep(s.latency)
|
|
w.WriteMsg(r)
|
|
}
|