This commit is contained in:
commit
cd761e8145
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
63
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
63
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
name: "Bug report"
|
||||
description: "[反馈 Bug 必须用此模板] 程序不能运行?或者没有按照期望的那样工作?"
|
||||
title: "[Bug] "
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: "感谢热心的反馈 bug。描述越详细越有助于定位和解决 Bug。不提供有效信息的反馈可能会被直接关闭。"
|
||||
|
||||
- type: checkboxes
|
||||
id: pre-check
|
||||
attributes:
|
||||
label: "在提交之前,请确认"
|
||||
options:
|
||||
- label: "我已经尝试搜索过 Issue ,但没有找到相关问题。"
|
||||
required: true
|
||||
- label: "我正在使用最新的 mosdns 版本(或者最新的 commit),问题依旧存在。"
|
||||
required: true
|
||||
- label: "我仔细看过 wiki 后仍然无法自行解决该问题。"
|
||||
required: true
|
||||
- label: "我非常确定这是 mosdns 核心的问题。(如果是通过第三方衍生软件使用 mosdns
|
||||
核心,不确定问题源头时,请先向衍生软件开发者提交问题。)"
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: mosdns 版本
|
||||
description: "不清楚可用 `mosdns version` 查看。"
|
||||
placeholder: v9.9.9
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: system
|
||||
attributes:
|
||||
label: 操作系统
|
||||
placeholder: ubuntu
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: what-happened
|
||||
attributes:
|
||||
label: Bug 描述和复现步骤
|
||||
description: "描述越详细越有助于定位和解决 Bug。"
|
||||
placeholder: "示例: Bug: mosdns 的 qname 匹配器无法匹配 xxxx 域名。复现方式: 使用如下配置,请求 xxxx.xxxx ,观察 log 发现匹配器没有匹配到。"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: config
|
||||
attributes:
|
||||
label: 使用的配置文件
|
||||
render: yaml
|
||||
description: "必须是完整的配置文件。不要只复制某个插件的配置片段。请尽可能的提供最小配置文件(能复现 bug,但没有其他功能)方便开发者定位问题。"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: log
|
||||
attributes:
|
||||
label: mosdns 的 log 记录
|
||||
render: txt
|
||||
16
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
16
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
name: "Feature request"
|
||||
description: "希望添加新功能"
|
||||
title: "[Feature request] "
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: "感谢提出建议。提交的 issue 不一定会马上得到开发者的回复。开发者会根据社区的反应以及功能的重要性选择是否实现这个新功能。"
|
||||
|
||||
- type: textarea
|
||||
id: new-feature
|
||||
attributes:
|
||||
label: 希望添加的功能
|
||||
description: "请详细描述一下该功能作用,使用场景,是否有类似实现,文档等。"
|
||||
placeholder: "希望支持 ... 可以更快的 ..."
|
||||
validations:
|
||||
required: true
|
||||
10
.github/ISSUE_TEMPLATE/other-questions.md
vendored
Normal file
10
.github/ISSUE_TEMPLATE/other-questions.md
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
---
|
||||
name: Other questions
|
||||
about: 不要在 Issue 里提问。有问题请进入 Discussions 讨论。
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
不要在 issue 里提问。有问题请进入 Discussions 讨论。
|
||||
38
.github/workflows/release.yml
vendored
Normal file
38
.github/workflows/release.yml
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
name: Release mosdns
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
|
||||
build-release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
cache: true
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Build
|
||||
run: python ./release.py
|
||||
env:
|
||||
CGO_ENABLED: '0'
|
||||
|
||||
- name: Publish
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: './release/mosdns*.zip'
|
||||
prerelease: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
24
.github/workflows/test.yml
vendored
Normal file
24
.github/workflows/test.yml
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
name: Test mosdns
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
cache: true
|
||||
|
||||
- name: Build
|
||||
run: go build -v ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -race -v ./...
|
||||
25
.gitignore
vendored
Normal file
25
.gitignore
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
vendor/
|
||||
|
||||
# release dir
|
||||
release/
|
||||
|
||||
# ide
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# test utils
|
||||
testutils/
|
||||
12
Dockerfile
Normal file
12
Dockerfile
Normal file
@ -0,0 +1,12 @@
|
||||
FROM golang:latest as builder
|
||||
ARG CGO_ENABLED=0
|
||||
|
||||
COPY ./ /root/src/
|
||||
WORKDIR /root/src/
|
||||
RUN go build -ldflags "-s -w -X main.version=$(git describe --tags --long --always)" -trimpath -o mosdns
|
||||
|
||||
FROM alpine:latest
|
||||
|
||||
COPY --from=builder /root/src/mosdns /usr/bin/
|
||||
|
||||
RUN apk add --no-cache ca-certificates
|
||||
14
Dockerfile_buildx
Normal file
14
Dockerfile_buildx
Normal file
@ -0,0 +1,14 @@
|
||||
FROM --platform=${TARGETPLATFORM} golang:latest as builder
|
||||
|
||||
ARG CGO_ENABLED=0
|
||||
|
||||
COPY ./ /root/src/
|
||||
WORKDIR /root/src/
|
||||
RUN go build -ldflags "-s -w -X main.version=$(git describe --tags --long --always)" -trimpath -o mosdns
|
||||
|
||||
|
||||
FROM --platform=${TARGETPLATFORM} alpine:latest
|
||||
|
||||
COPY --from=builder /root/src/mosdns /usr/bin/
|
||||
|
||||
RUN apk add --no-cache ca-certificates
|
||||
674
LICENSE
Normal file
674
LICENSE
Normal file
@ -0,0 +1,674 @@
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
<program> Copyright (C) <year> <name of author>
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
7
README.md
Normal file
7
README.md
Normal file
@ -0,0 +1,7 @@
|
||||
# mosdns
|
||||
|
||||
功能概述、配置方式、教程等,详见: [wiki](https://irine-sistiana.gitbook.io/mosdns-wiki/)
|
||||
|
||||
下载预编译文件、更新日志,详见: [release](https://github.com/IrineSistiana/mosdns/releases)
|
||||
|
||||
docker 镜像: [docker hub](https://hub.docker.com/r/irinesistiana/mosdns)
|
||||
192
config.yaml
Normal file
192
config.yaml
Normal file
@ -0,0 +1,192 @@
|
||||
# ============================================
|
||||
# MosDNS v5 完整配置:中英文注释版
|
||||
# ============================================
|
||||
|
||||
log:
|
||||
level: info # 可选: debug/info/warn/error
|
||||
|
||||
# 管理 API(可用于调试、监控)
|
||||
api:
|
||||
http: "0.0.0.0:5535"
|
||||
|
||||
# 引入上游 DNS 配置文件(在 dns.yaml 中)
|
||||
include: ['/opt/mosdns/dns.yaml']
|
||||
|
||||
plugins:
|
||||
# ======================
|
||||
# 域名/IP 匹配规则
|
||||
# ======================
|
||||
|
||||
# GFW 域名列表(如 google.com)
|
||||
- tag: GFW_domains
|
||||
type: domain_set
|
||||
args:
|
||||
files:
|
||||
- "/opt/mosdns/config/geosite_tiktok.txt"
|
||||
- "/opt/mosdns/config/gfwlist.out.txt"
|
||||
|
||||
# Amazon 域名列表(如 amazon.com)
|
||||
- tag: amazon_domains
|
||||
type: domain_set
|
||||
args:
|
||||
files:
|
||||
- "/opt/mosdns/config/geosite_amazon.txt"
|
||||
- "/opt/mosdns/config/geosite_amazon-ads.txt"
|
||||
- "/opt/mosdns/config/geosite_amazontrust.txt"
|
||||
- "/opt/mosdns/config/amazon.txt"
|
||||
|
||||
# 中国大陆常用域名(如 .cn / baidu.com)
|
||||
- tag: CN_domains
|
||||
type: domain_set
|
||||
args:
|
||||
files:
|
||||
- "/opt/mosdns/config/domains.txt"
|
||||
|
||||
# 中国大陆 IP 列表
|
||||
- tag: geoip_cn
|
||||
type: ip_set
|
||||
args:
|
||||
files:
|
||||
- "/opt/mosdns/config/cn.txt"
|
||||
|
||||
# ======================
|
||||
# 缓存模块
|
||||
# ======================
|
||||
- tag: cache
|
||||
type: cache
|
||||
args:
|
||||
size: 32768 # 最大缓存条目数
|
||||
lazy_cache_ttl: 43200 # 默认缓存 TTL(秒)
|
||||
|
||||
# ======================
|
||||
# 上游 DNS 定义
|
||||
# ======================
|
||||
|
||||
# 国内 DNS fallback 模式
|
||||
- tag: forward_local
|
||||
type: fallback
|
||||
args:
|
||||
primary: cn-dns
|
||||
secondary: cn-dns
|
||||
threshold: 500
|
||||
always_standby: true
|
||||
|
||||
# 国外 DNS fallback 模式
|
||||
- tag: forward_remote
|
||||
type: fallback
|
||||
args:
|
||||
primary: jp-dns
|
||||
secondary: jp-dns
|
||||
threshold: 500
|
||||
always_standby: true
|
||||
|
||||
# 封装调用国内 DNS
|
||||
- tag: forward_local_upstream
|
||||
type: sequence
|
||||
args:
|
||||
- exec: prefer_ipv4
|
||||
- exec: query_summary forward_local
|
||||
- exec: $forward_local
|
||||
|
||||
# 封装调用国外 DNS
|
||||
- tag: forward_remote_upstream
|
||||
type: sequence
|
||||
args:
|
||||
- exec: prefer_ipv4
|
||||
- exec: query_summary forward_remote
|
||||
- exec: $forward_remote
|
||||
|
||||
# 如果已有响应,直接返回
|
||||
- tag: has_resp_sequence
|
||||
type: sequence
|
||||
args:
|
||||
- matches: has_resp
|
||||
exec: accept
|
||||
|
||||
# ======================
|
||||
# 查询逻辑
|
||||
# ======================
|
||||
|
||||
# 拒绝无效查询(如 HTTPS 记录)
|
||||
- tag: query_is_reject_domain
|
||||
type: sequence
|
||||
args:
|
||||
- matches: qtype 65
|
||||
exec: reject 3
|
||||
|
||||
# GFW 域名:强制走国外 DNS
|
||||
- tag: query_is_foreign_domain
|
||||
type: sequence
|
||||
args:
|
||||
- matches: qname $GFW_domains
|
||||
exec: $forward_remote_upstream
|
||||
- exec: query_summary gfw_domain
|
||||
|
||||
# 国内域名:强制走国内 DNS
|
||||
- tag: query_is_cn_domain
|
||||
type: sequence
|
||||
args:
|
||||
- matches: qname $CN_domains
|
||||
exec: $forward_local_upstream
|
||||
- exec: query_summary cn_domain
|
||||
|
||||
# Amazon 域名:走国外 DNS 并添加到 MikroTik
|
||||
- tag: query_is_amazon_domain
|
||||
type: sequence
|
||||
args:
|
||||
- matches: qname $amazon_domains
|
||||
exec: $forward_remote_upstream
|
||||
- exec: $mikrotik_amazon
|
||||
- exec: query_summary amazon_domain
|
||||
|
||||
# 未知域名处理逻辑:
|
||||
# 先查国内 DNS → 如果返回非 CN IP ⇒ fallback 到国外
|
||||
- tag: query_unknown_fallback
|
||||
type: sequence
|
||||
args:
|
||||
- exec: prefer_ipv4
|
||||
- exec: $forward_local
|
||||
- matches: resp_ip $geoip_cn
|
||||
exec: accept
|
||||
- exec: $forward_remote_upstream
|
||||
- exec: query_summary fallback_to_remote
|
||||
|
||||
# ======================
|
||||
# 主查询处理流程
|
||||
# ======================
|
||||
- tag: main_sequence
|
||||
type: sequence
|
||||
args:
|
||||
- exec: $cache # 首先查缓存
|
||||
- exec: $query_is_reject_domain
|
||||
- exec: jump has_resp_sequence
|
||||
|
||||
- exec: $query_is_foreign_domain # gfwlist
|
||||
- exec: jump has_resp_sequence
|
||||
|
||||
- exec: $query_is_cn_domain # 国内域名
|
||||
- exec: jump has_resp_sequence
|
||||
|
||||
- exec: $query_is_amazon_domain # Amazon 域名(走国外 DNS + 添加到 MikroTik)
|
||||
- exec: jump has_resp_sequence
|
||||
|
||||
- exec: $query_unknown_fallback # 其他未知域名 fallback 流程
|
||||
- exec: jump has_resp_sequence
|
||||
|
||||
# ======================
|
||||
# 服务监听
|
||||
# ======================
|
||||
|
||||
# UDP 监听
|
||||
- tag: udp_server
|
||||
type: udp_server
|
||||
args:
|
||||
entry: main_sequence
|
||||
listen: ":5300"
|
||||
|
||||
# TCP 监听
|
||||
- tag: tcp_server
|
||||
type: tcp_server
|
||||
args:
|
||||
entry: main_sequence
|
||||
listen: ":5300"
|
||||
50
coremain/config.go
Normal file
50
coremain/config.go
Normal file
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package coremain
|
||||
|
||||
import (
|
||||
"github.com/IrineSistiana/mosdns/v5/mlog"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Log mlog.LogConfig `yaml:"log"`
|
||||
Include []string `yaml:"include"`
|
||||
Plugins []PluginConfig `yaml:"plugins"`
|
||||
API APIConfig `yaml:"api"`
|
||||
}
|
||||
|
||||
// PluginConfig represents a plugin config
|
||||
type PluginConfig struct {
|
||||
// Tag for this plugin. Optional. If omitted, this plugin will
|
||||
// be registered with a random tag.
|
||||
Tag string `yaml:"tag"`
|
||||
|
||||
// Type, required.
|
||||
Type string `yaml:"type"`
|
||||
|
||||
// Args, might be required by some plugins.
|
||||
// The type of Args is depended on RegNewPluginFunc.
|
||||
// If it's a map[string]any, it will be converted by mapstruct.
|
||||
Args any `yaml:"args"`
|
||||
}
|
||||
|
||||
type APIConfig struct {
|
||||
HTTP string `yaml:"http"`
|
||||
}
|
||||
244
coremain/mosdns.go
Normal file
244
coremain/mosdns.go
Normal file
@ -0,0 +1,244 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package coremain
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/mlog"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/safe_close"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
)
|
||||
|
||||
type Mosdns struct {
|
||||
logger *zap.Logger // non-nil logger.
|
||||
|
||||
// Plugins
|
||||
plugins map[string]any
|
||||
|
||||
httpMux *chi.Mux
|
||||
metricsReg *prometheus.Registry
|
||||
sc *safe_close.SafeClose
|
||||
}
|
||||
|
||||
// NewMosdns initializes a mosdns instance and its plugins.
|
||||
func NewMosdns(cfg *Config) (*Mosdns, error) {
|
||||
// Init logger.
|
||||
lg, err := mlog.NewLogger(cfg.Log)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init logger: %w", err)
|
||||
}
|
||||
|
||||
m := &Mosdns{
|
||||
logger: lg,
|
||||
plugins: make(map[string]any),
|
||||
httpMux: chi.NewRouter(),
|
||||
metricsReg: newMetricsReg(),
|
||||
sc: safe_close.NewSafeClose(),
|
||||
}
|
||||
// This must be called after m.httpMux and m.metricsReg been set.
|
||||
m.initHttpMux()
|
||||
|
||||
// Start http api server
|
||||
if httpAddr := cfg.API.HTTP; len(httpAddr) > 0 {
|
||||
httpServer := &http.Server{
|
||||
Addr: httpAddr,
|
||||
Handler: m.httpMux,
|
||||
}
|
||||
m.sc.Attach(func(done func(), closeSignal <-chan struct{}) {
|
||||
defer done()
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
m.logger.Info("starting api http server", zap.String("addr", httpAddr))
|
||||
errChan <- httpServer.ListenAndServe()
|
||||
}()
|
||||
select {
|
||||
case err := <-errChan:
|
||||
m.sc.SendCloseSignal(err)
|
||||
case <-closeSignal:
|
||||
_ = httpServer.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Load plugins.
|
||||
|
||||
// Close all plugins on signal.
|
||||
// From here, call m.sc.SendCloseSignal() if any plugin failed to load.
|
||||
m.sc.Attach(func(done func(), closeSignal <-chan struct{}) {
|
||||
go func() {
|
||||
defer done()
|
||||
<-closeSignal
|
||||
m.logger.Info("starting shutdown sequences")
|
||||
for tag, p := range m.plugins {
|
||||
if closer, _ := p.(io.Closer); closer != nil {
|
||||
m.logger.Info("closing plugin", zap.String("tag", tag))
|
||||
_ = closer.Close()
|
||||
}
|
||||
}
|
||||
m.logger.Info("all plugins were closed")
|
||||
}()
|
||||
})
|
||||
|
||||
// Preset plugins
|
||||
if err := m.loadPresetPlugins(); err != nil {
|
||||
m.sc.SendCloseSignal(err)
|
||||
_ = m.sc.WaitClosed()
|
||||
return nil, err
|
||||
}
|
||||
// Plugins from config.
|
||||
if err := m.loadPluginsFromCfg(cfg, 0); err != nil {
|
||||
m.sc.SendCloseSignal(err)
|
||||
_ = m.sc.WaitClosed()
|
||||
return nil, err
|
||||
}
|
||||
m.logger.Info("all plugins are loaded")
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// NewTestMosdnsWithPlugins returns a mosdns instance for testing.
|
||||
func NewTestMosdnsWithPlugins(p map[string]any) *Mosdns {
|
||||
return &Mosdns{
|
||||
logger: mlog.Nop(),
|
||||
httpMux: chi.NewRouter(),
|
||||
plugins: p,
|
||||
metricsReg: newMetricsReg(),
|
||||
sc: safe_close.NewSafeClose(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mosdns) GetSafeClose() *safe_close.SafeClose {
|
||||
return m.sc
|
||||
}
|
||||
|
||||
// CloseWithErr is a shortcut for m.sc.SendCloseSignal
|
||||
func (m *Mosdns) CloseWithErr(err error) {
|
||||
m.sc.SendCloseSignal(err)
|
||||
}
|
||||
|
||||
// Logger returns a non-nil logger.
|
||||
func (m *Mosdns) Logger() *zap.Logger {
|
||||
return m.logger
|
||||
}
|
||||
|
||||
// GetPlugin returns a plugin.
|
||||
func (m *Mosdns) GetPlugin(tag string) any {
|
||||
return m.plugins[tag]
|
||||
}
|
||||
|
||||
// GetMetricsReg returns a prometheus.Registerer with a prefix of "mosdns_"
|
||||
func (m *Mosdns) GetMetricsReg() prometheus.Registerer {
|
||||
return prometheus.WrapRegistererWithPrefix("mosdns_", m.metricsReg)
|
||||
}
|
||||
|
||||
func (m *Mosdns) GetAPIRouter() *chi.Mux {
|
||||
return m.httpMux
|
||||
}
|
||||
|
||||
func (m *Mosdns) RegPluginAPI(tag string, mux *chi.Mux) {
|
||||
m.httpMux.Mount("/plugins/"+tag, mux)
|
||||
}
|
||||
|
||||
func newMetricsReg() *prometheus.Registry {
|
||||
reg := prometheus.NewRegistry()
|
||||
reg.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
||||
reg.MustRegister(collectors.NewGoCollector())
|
||||
return reg
|
||||
}
|
||||
|
||||
// initHttpMux initializes api entries. It MUST be called after m.metricsReg being initialized.
|
||||
func (m *Mosdns) initHttpMux() {
|
||||
// Register metrics.
|
||||
m.httpMux.Method(http.MethodGet, "/metrics", promhttp.HandlerFor(m.metricsReg, promhttp.HandlerOpts{}))
|
||||
|
||||
// Register pprof.
|
||||
m.httpMux.Route("/debug/pprof", func(r chi.Router) {
|
||||
r.Get("/*", pprof.Index)
|
||||
r.Get("/cmdline", pprof.Cmdline)
|
||||
r.Get("/profile", pprof.Profile)
|
||||
r.Get("/symbol", pprof.Symbol)
|
||||
r.Get("/trace", pprof.Trace)
|
||||
})
|
||||
|
||||
// A helper page for invalid request.
|
||||
invalidApiReqHelper := func(w http.ResponseWriter, req *http.Request) {
|
||||
b := new(bytes.Buffer)
|
||||
_, _ = fmt.Fprintf(b, "Invalid request %s %s\n\n", req.Method, req.RequestURI)
|
||||
b.WriteString("Available api urls:\n")
|
||||
_ = chi.Walk(m.httpMux, func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
|
||||
b.WriteString(method)
|
||||
b.WriteByte(' ')
|
||||
b.WriteString(route)
|
||||
b.WriteByte('\n')
|
||||
return nil
|
||||
})
|
||||
_, _ = w.Write(b.Bytes())
|
||||
}
|
||||
m.httpMux.NotFound(invalidApiReqHelper)
|
||||
m.httpMux.MethodNotAllowed(invalidApiReqHelper)
|
||||
}
|
||||
|
||||
func (m *Mosdns) loadPresetPlugins() error {
|
||||
for tag, f := range LoadNewPersetPluginFuncs() {
|
||||
p, err := f(NewBP(tag, m))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init preset plugin %s, %w", tag, err)
|
||||
}
|
||||
m.plugins[tag] = p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadPluginsFromCfg loads plugins from this config. It follows include first.
|
||||
func (m *Mosdns) loadPluginsFromCfg(cfg *Config, includeDepth int) error {
|
||||
const maxIncludeDepth = 8
|
||||
if includeDepth > maxIncludeDepth {
|
||||
return errors.New("maximum include depth reached")
|
||||
}
|
||||
includeDepth++
|
||||
|
||||
// Follow include first.
|
||||
for _, s := range cfg.Include {
|
||||
subCfg, path, err := loadConfig(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config from %s, %w", s, err)
|
||||
}
|
||||
m.logger.Info("load config", zap.String("file", path))
|
||||
if err := m.loadPluginsFromCfg(subCfg, includeDepth); err != nil {
|
||||
return fmt.Errorf("failed to load config from %s, %w", s, err)
|
||||
}
|
||||
}
|
||||
|
||||
for i, pc := range cfg.Plugins {
|
||||
if err := m.newPlugin(pc); err != nil {
|
||||
return fmt.Errorf("failed to init plugin #%d %s, %w", i, pc.Tag, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
201
coremain/plugin.go
Normal file
201
coremain/plugin.go
Normal file
@ -0,0 +1,201 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package coremain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"go.uber.org/zap"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// NewPluginArgsFunc represents a func that creates a new args object.
|
||||
type NewPluginArgsFunc func() any
|
||||
|
||||
// NewPluginFunc represents a func that can init a Plugin.
|
||||
// args is the object created by NewPluginArgsFunc.
|
||||
type NewPluginFunc func(bp *BP, args any) (p any, err error)
|
||||
|
||||
type PluginTypeInfo struct {
|
||||
NewPlugin NewPluginFunc
|
||||
NewArgs NewPluginArgsFunc
|
||||
}
|
||||
|
||||
var (
|
||||
// pluginTypeRegister stores init funcs for certain plugin types
|
||||
pluginTypeRegister struct {
|
||||
sync.RWMutex
|
||||
m map[string]PluginTypeInfo
|
||||
}
|
||||
)
|
||||
|
||||
// RegNewPluginFunc registers the type.
|
||||
// If the type has been registered. RegNewPluginFunc will panic.
|
||||
func RegNewPluginFunc(typ string, initFunc NewPluginFunc, argsType NewPluginArgsFunc) {
|
||||
pluginTypeRegister.Lock()
|
||||
defer pluginTypeRegister.Unlock()
|
||||
|
||||
if pluginTypeRegister.m == nil {
|
||||
pluginTypeRegister.m = make(map[string]PluginTypeInfo)
|
||||
}
|
||||
|
||||
_, ok := pluginTypeRegister.m[typ]
|
||||
if ok {
|
||||
panic(fmt.Sprintf("duplicate plugin type [%s]", typ))
|
||||
}
|
||||
|
||||
pluginTypeRegister.m[typ] = PluginTypeInfo{
|
||||
NewPlugin: initFunc,
|
||||
NewArgs: argsType,
|
||||
}
|
||||
}
|
||||
|
||||
// DelPluginType deletes the init func for this plugin type.
|
||||
// It is a noop if pluginType is not registered.
|
||||
func DelPluginType(typ string) {
|
||||
pluginTypeRegister.Lock()
|
||||
defer pluginTypeRegister.Unlock()
|
||||
delete(pluginTypeRegister.m, typ)
|
||||
}
|
||||
|
||||
// GetPluginType gets the registered type init func.
|
||||
func GetPluginType(typ string) (PluginTypeInfo, bool) {
|
||||
pluginTypeRegister.RLock()
|
||||
defer pluginTypeRegister.RUnlock()
|
||||
|
||||
info, ok := pluginTypeRegister.m[typ]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// newPlugin initializes a Plugin from c and adds it to mosdns.
|
||||
func (m *Mosdns) newPlugin(c PluginConfig) error {
|
||||
if len(c.Tag) == 0 {
|
||||
c.Tag = fmt.Sprintf("anonymouse_%s_%d", c.Type, len(m.plugins))
|
||||
}
|
||||
|
||||
if _, dup := m.plugins[c.Tag]; dup {
|
||||
return fmt.Errorf("duplicated plugin tag %s", c.Tag)
|
||||
}
|
||||
|
||||
typeInfo, ok := GetPluginType(c.Type)
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin type %s not defined", c.Type)
|
||||
}
|
||||
|
||||
args := typeInfo.NewArgs()
|
||||
if reflect.TypeOf(c.Args) == reflect.TypeOf(args) { // Same type, no need to parse.
|
||||
args = c.Args
|
||||
} else {
|
||||
if err := utils.WeakDecode(c.Args, args); err != nil {
|
||||
return fmt.Errorf("unable to decode plugin args: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Info("loading plugin", zap.String("tag", c.Tag), zap.String("type", c.Type))
|
||||
p, err := typeInfo.NewPlugin(NewBP(c.Tag, m), args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init plugin: %w", err)
|
||||
}
|
||||
m.plugins[c.Tag] = p
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllPluginTypes returns all plugin types which are configurable.
|
||||
func GetAllPluginTypes() []string {
|
||||
pluginTypeRegister.RLock()
|
||||
defer pluginTypeRegister.RUnlock()
|
||||
|
||||
var t []string
|
||||
for typ := range pluginTypeRegister.m {
|
||||
t = append(t, typ)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type NewPersetPluginFunc func(bp *BP) (any, error)
|
||||
|
||||
var presetPluginFuncReg struct {
|
||||
sync.Mutex
|
||||
m map[string]NewPersetPluginFunc
|
||||
}
|
||||
|
||||
func RegNewPersetPluginFunc(tag string, f NewPersetPluginFunc) {
|
||||
presetPluginFuncReg.Lock()
|
||||
defer presetPluginFuncReg.Unlock()
|
||||
|
||||
if presetPluginFuncReg.m == nil {
|
||||
presetPluginFuncReg.m = make(map[string]NewPersetPluginFunc)
|
||||
}
|
||||
if _, ok := presetPluginFuncReg.m[tag]; ok {
|
||||
panic(fmt.Sprintf("preset plugin %s has already been registered", tag))
|
||||
}
|
||||
|
||||
presetPluginFuncReg.m[tag] = f
|
||||
}
|
||||
|
||||
func LoadNewPersetPluginFuncs() map[string]NewPersetPluginFunc {
|
||||
presetPluginFuncReg.Lock()
|
||||
defer presetPluginFuncReg.Unlock()
|
||||
m := make(map[string]NewPersetPluginFunc)
|
||||
for tag, pluginFunc := range presetPluginFuncReg.m {
|
||||
m[tag] = pluginFunc
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
type BP struct {
|
||||
tag string
|
||||
m *Mosdns
|
||||
l *zap.Logger
|
||||
}
|
||||
|
||||
// NewBP creates a new BP. m MUST NOT nil.
|
||||
func NewBP(tag string, m *Mosdns) *BP {
|
||||
return &BP{
|
||||
tag: tag,
|
||||
l: m.Logger().Named(tag),
|
||||
m: m,
|
||||
}
|
||||
}
|
||||
|
||||
// L returns a non-nil logger.
|
||||
func (p *BP) L() *zap.Logger {
|
||||
return p.l
|
||||
}
|
||||
|
||||
// M returns a non-nil Mosdns.
|
||||
func (p *BP) M() *Mosdns {
|
||||
return p.m
|
||||
}
|
||||
|
||||
// Tag returns the plugin tag.
|
||||
// This tag should be unique globally unless it's in
|
||||
// a test environment.
|
||||
func (p *BP) Tag() string {
|
||||
return p.tag
|
||||
}
|
||||
|
||||
// RegAPI mounts mux to mosdns api. Note: Plugins MUST NOT call RegAPI twice.
|
||||
// Since mounting same path to root chi.Mux causes runtime panic.
|
||||
func (p *BP) RegAPI(mux *chi.Mux) {
|
||||
p.m.RegPluginAPI(p.tag, mux)
|
||||
}
|
||||
159
coremain/run.go
Normal file
159
coremain/run.go
Normal file
@ -0,0 +1,159 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package coremain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/mlog"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type serverFlags struct {
|
||||
c string
|
||||
dir string
|
||||
cpu int
|
||||
asService bool
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "mosdns",
|
||||
}
|
||||
|
||||
func init() {
|
||||
sf := new(serverFlags)
|
||||
startCmd := &cobra.Command{
|
||||
Use: "start [-c config_file] [-d working_dir]",
|
||||
Short: "Start mosdns main program.",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if sf.asService {
|
||||
svc, err := service.New(&serverService{f: sf}, svcCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init service, %w", err)
|
||||
}
|
||||
return svc.Run()
|
||||
}
|
||||
|
||||
m, err := NewServer(sf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
|
||||
sig := <-c
|
||||
m.logger.Warn("signal received", zap.Stringer("signal", sig))
|
||||
m.sc.SendCloseSignal(nil)
|
||||
}()
|
||||
return m.GetSafeClose().WaitClosed()
|
||||
},
|
||||
DisableFlagsInUseLine: true,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
rootCmd.AddCommand(startCmd)
|
||||
fs := startCmd.Flags()
|
||||
fs.StringVarP(&sf.c, "config", "c", "", "config file")
|
||||
fs.StringVarP(&sf.dir, "dir", "d", "", "working dir")
|
||||
fs.IntVar(&sf.cpu, "cpu", 0, "set runtime.GOMAXPROCS")
|
||||
fs.BoolVar(&sf.asService, "as-service", false, "start as a service")
|
||||
_ = fs.MarkHidden("as-service")
|
||||
|
||||
serviceCmd := &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "Manage mosdns as a system service.",
|
||||
}
|
||||
serviceCmd.PersistentPreRunE = initService
|
||||
serviceCmd.AddCommand(
|
||||
newSvcInstallCmd(),
|
||||
newSvcUninstallCmd(),
|
||||
newSvcStartCmd(),
|
||||
newSvcStopCmd(),
|
||||
newSvcRestartCmd(),
|
||||
newSvcStatusCmd(),
|
||||
)
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
}
|
||||
|
||||
func AddSubCmd(c *cobra.Command) {
|
||||
rootCmd.AddCommand(c)
|
||||
}
|
||||
|
||||
func Run() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func NewServer(sf *serverFlags) (*Mosdns, error) {
|
||||
if sf.cpu > 0 {
|
||||
runtime.GOMAXPROCS(sf.cpu)
|
||||
}
|
||||
|
||||
if len(sf.dir) > 0 {
|
||||
err := os.Chdir(sf.dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to change the current working directory, %w", err)
|
||||
}
|
||||
mlog.L().Info("working directory changed", zap.String("path", sf.dir))
|
||||
}
|
||||
|
||||
cfg, fileUsed, err := loadConfig(sf.c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to load config, %w", err)
|
||||
}
|
||||
mlog.L().Info("main config loaded", zap.String("file", fileUsed))
|
||||
|
||||
return NewMosdns(cfg)
|
||||
}
|
||||
|
||||
// loadConfig load a config from a file. If filePath is empty, it will
|
||||
// automatically search and load a file which name start with "config".
|
||||
func loadConfig(filePath string) (*Config, string, error) {
|
||||
v := viper.New()
|
||||
|
||||
if len(filePath) > 0 {
|
||||
v.SetConfigFile(filePath)
|
||||
} else {
|
||||
v.SetConfigName("config")
|
||||
v.AddConfigPath(".")
|
||||
}
|
||||
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
|
||||
decoderOpt := func(cfg *mapstructure.DecoderConfig) {
|
||||
cfg.ErrorUnused = true
|
||||
cfg.TagName = "yaml"
|
||||
cfg.WeaklyTypedInput = true
|
||||
}
|
||||
|
||||
cfg := new(Config)
|
||||
if err := v.Unmarshal(cfg, decoderOpt); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
return cfg, v.ConfigFileUsed(), nil
|
||||
}
|
||||
220
coremain/service.go
Normal file
220
coremain/service.go
Normal file
@ -0,0 +1,220 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package coremain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/mlog"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// initialized by "service" sub command
|
||||
svc service.Service
|
||||
svcCfg = &service.Config{
|
||||
Name: "mosdns",
|
||||
DisplayName: "mosdns",
|
||||
Description: "A DNS forwarder",
|
||||
}
|
||||
)
|
||||
|
||||
type serverService struct {
|
||||
f *serverFlags
|
||||
m *Mosdns
|
||||
}
|
||||
|
||||
func (ss *serverService) Start(s service.Service) error {
|
||||
mlog.L().Info("starting service", zap.String("platform", s.Platform()))
|
||||
m, err := NewServer(ss.f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ss.m = m
|
||||
go func() {
|
||||
err := m.GetSafeClose().WaitClosed()
|
||||
if err != nil {
|
||||
m.Logger().Fatal("server exited", zap.Error(err))
|
||||
} else {
|
||||
m.Logger().Info("server exited")
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ss *serverService) Stop(_ service.Service) error {
|
||||
ss.m.Logger().Info("service is shutting down")
|
||||
ss.m.GetSafeClose().SendCloseSignal(nil)
|
||||
return ss.m.GetSafeClose().WaitClosed()
|
||||
}
|
||||
|
||||
// initService will init svc for sub command "service"
|
||||
func initService(_ *cobra.Command, _ []string) error {
|
||||
s, err := service.New(&serverService{}, svcCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot init service, %w", err)
|
||||
}
|
||||
svc = s
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSvcInstallCmd() *cobra.Command {
|
||||
sf := new(serverFlags)
|
||||
c := &cobra.Command{
|
||||
Use: "install [-d working_dir] [-c config_file]",
|
||||
Short: "Install mosdns as a system service.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(sf.dir) > 0 {
|
||||
absWd, err := filepath.Abs(sf.dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot solve absolute working dir path, %w", err)
|
||||
} else {
|
||||
sf.dir = absWd
|
||||
}
|
||||
} else {
|
||||
ep, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot solve current executable path, %w", err)
|
||||
}
|
||||
sf.dir = filepath.Dir(ep)
|
||||
}
|
||||
mlog.S().Infof("set service working dir as %s", sf.dir)
|
||||
svcCfg.Arguments = []string{"start", "--as-service", "-d", sf.dir}
|
||||
if len(sf.c) > 0 {
|
||||
svcCfg.Arguments = append(svcCfg.Arguments, "-c", sf.c)
|
||||
}
|
||||
return svc.Install()
|
||||
},
|
||||
DisableFlagsInUseLine: true,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
c.Flags().StringVarP(&sf.dir, "dir", "d", "", "working dir")
|
||||
c.Flags().StringVarP(&sf.c, "config", "c", "", "config path")
|
||||
return c
|
||||
}
|
||||
|
||||
func newSvcUninstallCmd() *cobra.Command {
|
||||
c := &cobra.Command{
|
||||
Use: "uninstall",
|
||||
Short: "Uninstall mosdns from system service.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return svc.Uninstall()
|
||||
},
|
||||
DisableFlagsInUseLine: true,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func newSvcStartCmd() *cobra.Command {
|
||||
c := &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "Start mosdns system service.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := svc.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mlog.S().Info("service is starting")
|
||||
time.Sleep(time.Second)
|
||||
s, err := svc.Status()
|
||||
if err != nil {
|
||||
mlog.S().Warn("cannot get service status, %w", err)
|
||||
} else {
|
||||
switch s {
|
||||
case service.StatusRunning:
|
||||
mlog.S().Info("service is running")
|
||||
case service.StatusStopped:
|
||||
mlog.S().Error("service is stopped, check mosdns and system service log for more info")
|
||||
default:
|
||||
mlog.S().Warn("cannot get service status, system may not support this operation")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
DisableFlagsInUseLine: true,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func newSvcStopCmd() *cobra.Command {
|
||||
c := &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "Stop mosdns system service.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return svc.Stop()
|
||||
},
|
||||
DisableFlagsInUseLine: true,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func newSvcRestartCmd() *cobra.Command {
|
||||
c := &cobra.Command{
|
||||
Use: "restart",
|
||||
Short: "Restart mosdns system service.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return svc.Restart()
|
||||
},
|
||||
DisableFlagsInUseLine: true,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func newSvcStatusCmd() *cobra.Command {
|
||||
c := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Status of mosdns system service.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
s, err := svc.Status()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get service status, %w", err)
|
||||
}
|
||||
var out string
|
||||
switch s {
|
||||
case service.StatusRunning:
|
||||
out = "running"
|
||||
case service.StatusStopped:
|
||||
out = "stopped"
|
||||
case service.StatusUnknown:
|
||||
out = "unknown"
|
||||
}
|
||||
println(out)
|
||||
return nil
|
||||
},
|
||||
DisableFlagsInUseLine: true,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
return c
|
||||
}
|
||||
192
deploy-mikrotik-amazon-updated.md
Normal file
192
deploy-mikrotik-amazon-updated.md
Normal file
@ -0,0 +1,192 @@
|
||||
# MosDNS + MikroTik Amazon 域名处理部署指南(更新版)
|
||||
|
||||
## 功能说明
|
||||
|
||||
这个配置会在解析 Amazon 相关域名时,自动将解析到的 IP 地址添加到 MikroTik 路由器的 address list 中,用于防火墙规则控制。
|
||||
|
||||
## 部署步骤
|
||||
|
||||
### 1. 上传文件到 Debian 12 服务器
|
||||
|
||||
```bash
|
||||
# 上传编译好的 mosdns 可执行文件
|
||||
scp mosdns-linux-amd64 user@your-server:/usr/local/bin/mosdns
|
||||
|
||||
# 上传配置文件
|
||||
scp config.yaml user@your-server:/opt/mosdns/
|
||||
scp dns.yaml user@your-server:/opt/mosdns/
|
||||
|
||||
# 设置执行权限
|
||||
ssh user@your-server "chmod +x /usr/local/bin/mosdns"
|
||||
```
|
||||
|
||||
### 2. 创建必要的目录和文件
|
||||
|
||||
```bash
|
||||
# 创建配置目录
|
||||
sudo mkdir -p /opt/mosdns/config
|
||||
|
||||
# 下载 Amazon 域名列表
|
||||
sudo wget -O /opt/mosdns/config/geosite_amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon
|
||||
sudo wget -O /opt/mosdns/config/geosite_amazon-ads.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon-ads
|
||||
sudo wget -O /opt/mosdns/config/geosite_amazontrust.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazontrust
|
||||
sudo wget -O /opt/mosdns/config/amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon
|
||||
|
||||
# 下载其他必要的域名和 IP 文件
|
||||
sudo wget -O /opt/mosdns/config/geosite_tiktok.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/tiktok
|
||||
sudo wget -O /opt/mosdns/config/gfwlist.out.txt https://raw.githubusercontent.com/gfwlist/gfwlist/master/gfwlist.txt
|
||||
sudo wget -O /opt/mosdns/config/domains.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/category-games
|
||||
sudo wget -O /opt/mosdns/config/cn.txt https://raw.githubusercontent.com/Loyalsoldier/v2ray-rules-dat/release/geoip.dat
|
||||
```
|
||||
|
||||
### 3. 在 MikroTik 中创建 address list
|
||||
|
||||
```bash
|
||||
# 通过 SSH 连接到 MikroTik 路由器
|
||||
ssh admin@10.248.0.1
|
||||
|
||||
# 创建 IPv4 和 IPv6 address list
|
||||
/ip firewall address-list add list=AmazonIP
|
||||
/ip firewall address-list add list=AmazonIP6
|
||||
|
||||
# 创建防火墙规则(可选)
|
||||
/ip firewall filter add chain=forward src-address-list=AmazonIP action=drop comment="Block Amazon IPs"
|
||||
/ip firewall filter add chain=forward src-address-list=AmazonIP6 action=drop comment="Block Amazon IPv6 IPs"
|
||||
```
|
||||
|
||||
### 4. 修改配置文件中的 MikroTik 连接信息
|
||||
|
||||
编辑 `/opt/mosdns/dns.yaml` 文件,确认 mikrotik_amazon 插件的配置:
|
||||
|
||||
```yaml
|
||||
# 当前配置(根据你的实际情况修改)
|
||||
args: "10.248.0.1:9728:admin:szn0s!nw@pwd():false:10:AmazonIP:AmazonIP6:24:32:AmazonIP:86400"
|
||||
```
|
||||
|
||||
参数说明:
|
||||
- `10.248.0.1`: MikroTik 路由器 IP
|
||||
- `9728`: API 端口
|
||||
- `admin`: 用户名
|
||||
- `szn0s!nw@pwd()`: 密码
|
||||
- `false`: 不使用 TLS
|
||||
- `10`: 连接超时时间
|
||||
- `AmazonIP`: IPv4 address list 名称
|
||||
- `AmazonIP6`: IPv6 address list 名称
|
||||
- `24`: IPv4 掩码
|
||||
- `32`: IPv6 掩码
|
||||
- `AmazonIP`: 注释
|
||||
- `86400`: 地址超时时间(24小时)
|
||||
|
||||
### 5. 创建 systemd 服务
|
||||
|
||||
```bash
|
||||
sudo tee /etc/systemd/system/mosdns.service > /dev/null <<EOF
|
||||
[Unit]
|
||||
Description=MosDNS DNS Server
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=root
|
||||
ExecStart=/usr/local/bin/mosdns -c /opt/mosdns/config.yaml
|
||||
Restart=always
|
||||
RestartSec=3
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
# 重新加载 systemd 配置
|
||||
sudo systemctl daemon-reload
|
||||
|
||||
# 启用并启动服务
|
||||
sudo systemctl enable mosdns
|
||||
sudo systemctl start mosdns
|
||||
|
||||
# 检查服务状态
|
||||
sudo systemctl status mosdns
|
||||
```
|
||||
|
||||
### 6. 配置 DNS 转发
|
||||
|
||||
修改 `/etc/systemd/resolved.conf`:
|
||||
|
||||
```ini
|
||||
[Resolve]
|
||||
DNS=127.0.0.1:5300
|
||||
#FallbackDNS=8.8.8.8 8.8.4.4
|
||||
#Domains=
|
||||
#DNSSEC=no
|
||||
#DNSOverTLS=no
|
||||
#MulticastDNS=yes
|
||||
#LLMNR=yes
|
||||
#Cache=yes
|
||||
#DNSStubListener=no
|
||||
```
|
||||
|
||||
重启 systemd-resolved:
|
||||
|
||||
```bash
|
||||
sudo systemctl restart systemd-resolved
|
||||
```
|
||||
|
||||
### 7. 测试配置
|
||||
|
||||
```bash
|
||||
# 测试 Amazon 域名解析
|
||||
nslookup amazon.com 127.0.0.1:5300
|
||||
nslookup aws.amazon.com 127.0.0.1:5300
|
||||
|
||||
# 检查 MikroTik address list 是否更新
|
||||
ssh admin@10.248.0.1 "/ip firewall address-list print where list=AmazonIP"
|
||||
|
||||
# 查看 mosdns 日志
|
||||
sudo journalctl -u mosdns -f
|
||||
```
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 工作流程
|
||||
|
||||
1. **域名匹配**:当查询 Amazon 相关域名时,匹配 `amazon_domains` 集合
|
||||
2. **DNS 解析**:使用国外 DNS 服务器解析域名
|
||||
3. **IP 提取**:从 DNS 响应中提取 A 和 AAAA 记录
|
||||
4. **地址添加**:通过 MikroTik API 将 IP 添加到 address list
|
||||
5. **超时管理**:IP 地址会在 24 小时后自动过期
|
||||
|
||||
### 监控和调试
|
||||
|
||||
```bash
|
||||
# 查看实时日志
|
||||
sudo journalctl -u mosdns -f
|
||||
|
||||
# 查看服务状态
|
||||
sudo systemctl status mosdns
|
||||
|
||||
# 测试 MikroTik 连接
|
||||
curl -k https://10.248.0.1:9729/api/rest/ip/firewall/address-list
|
||||
|
||||
# 查看 API 状态
|
||||
curl http://localhost:5535/metrics
|
||||
```
|
||||
|
||||
### 故障排除
|
||||
|
||||
1. **连接失败**:检查 MikroTik IP、端口和认证信息
|
||||
2. **权限不足**:确保 MikroTik 用户具有管理 address list 的权限
|
||||
3. **域名文件缺失**:确保所有域名列表文件都已下载
|
||||
4. **DNS 解析失败**:检查上游 DNS 服务器配置
|
||||
|
||||
## 安全注意事项
|
||||
|
||||
1. 不要在配置文件中使用明文密码,考虑使用环境变量
|
||||
2. 限制对 MikroTik API 端口的访问
|
||||
3. 定期更新域名列表文件
|
||||
4. 监控 address list 大小,避免过多条目影响性能
|
||||
|
||||
## 更新日志
|
||||
|
||||
- 修复了插件注册问题,现在支持 YAML 配置和快速配置
|
||||
- 更新了路径配置为 `/opt/mosdns/`
|
||||
- 更新了端口配置为 `5300`
|
||||
- 更新了 API 端口为 `5535`
|
||||
182
deploy-mikrotik-amazon.md
Normal file
182
deploy-mikrotik-amazon.md
Normal file
@ -0,0 +1,182 @@
|
||||
# MosDNS + MikroTik Amazon 域名处理部署指南
|
||||
|
||||
## 功能说明
|
||||
|
||||
这个配置会在解析 Amazon 相关域名时,自动将解析到的 IP 地址添加到 MikroTik 路由器的 address list 中,用于防火墙规则控制。
|
||||
|
||||
## 部署步骤
|
||||
|
||||
### 1. 上传文件到 Debian 12 服务器
|
||||
|
||||
```bash
|
||||
# 上传编译好的 mosdns 可执行文件
|
||||
scp mosdns-linux-amd64 user@your-server:/usr/local/bin/mosdns
|
||||
|
||||
# 上传配置文件
|
||||
scp config.yaml user@your-server:/usr/local/mosdns/
|
||||
scp dns.yaml user@your-server:/usr/local/mosdns/
|
||||
|
||||
# 设置执行权限
|
||||
ssh user@your-server "chmod +x /usr/local/bin/mosdns"
|
||||
```
|
||||
|
||||
### 2. 创建必要的目录和文件
|
||||
|
||||
```bash
|
||||
# 创建配置目录
|
||||
sudo mkdir -p /usr/local/mosdns/config
|
||||
|
||||
# 下载 Amazon 域名列表
|
||||
sudo wget -O /usr/local/mosdns/config/geosite_amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon
|
||||
sudo wget -O /usr/local/mosdns/config/geosite_amazon-ads.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon-ads
|
||||
sudo wget -O /usr/local/mosdns/config/geosite_amazontrust.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazontrust
|
||||
sudo wget -O /usr/local/mosdns/config/amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon
|
||||
|
||||
# 下载其他必要的域名和 IP 文件
|
||||
sudo wget -O /usr/local/mosdns/config/geosite_tiktok.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/tiktok
|
||||
sudo wget -O /usr/local/mosdns/config/gfwlist.out.txt https://raw.githubusercontent.com/gfwlist/gfwlist/master/gfwlist.txt
|
||||
sudo wget -O /usr/local/mosdns/config/domains.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/category-games
|
||||
sudo wget -O /usr/local/mosdns/config/cn.txt https://raw.githubusercontent.com/Loyalsoldier/v2ray-rules-dat/release/geoip.dat
|
||||
```
|
||||
|
||||
### 3. 在 MikroTik 中创建 address list
|
||||
|
||||
```bash
|
||||
# 通过 SSH 连接到 MikroTik 路由器
|
||||
ssh admin@192.168.1.1
|
||||
|
||||
# 创建 IPv4 和 IPv6 address list
|
||||
/ip firewall address-list add list=amazon_ips
|
||||
/ip firewall address-list add list=amazon_ips6
|
||||
|
||||
# 创建防火墙规则(可选)
|
||||
/ip firewall filter add chain=forward src-address-list=amazon_ips action=drop comment="Block Amazon IPs"
|
||||
/ip firewall filter add chain=forward src-address-list=amazon_ips6 action=drop comment="Block Amazon IPv6 IPs"
|
||||
```
|
||||
|
||||
### 4. 修改配置文件中的 MikroTik 连接信息
|
||||
|
||||
编辑 `/usr/local/mosdns/dns.yaml` 文件,修改 mikrotik_amazon 插件的配置:
|
||||
|
||||
```yaml
|
||||
# 修改为你的 MikroTik 实际信息
|
||||
args: "192.168.1.1:8728:admin:your-password:false:10:amazon_ips:amazon_ips6:24:32:amazon_domain:86400"
|
||||
```
|
||||
|
||||
参数说明:
|
||||
- `192.168.1.1`: MikroTik 路由器 IP
|
||||
- `8728`: API 端口
|
||||
- `admin`: 用户名
|
||||
- `your-password`: 密码
|
||||
- `false`: 不使用 TLS
|
||||
- `10`: 连接超时时间
|
||||
- `amazon_ips`: IPv4 address list 名称
|
||||
- `amazon_ips6`: IPv6 address list 名称
|
||||
- `24`: IPv4 掩码
|
||||
- `32`: IPv6 掩码
|
||||
- `amazon_domain`: 注释
|
||||
- `86400`: 地址超时时间(24小时)
|
||||
|
||||
### 5. 创建 systemd 服务
|
||||
|
||||
```bash
|
||||
sudo tee /etc/systemd/system/mosdns.service > /dev/null <<EOF
|
||||
[Unit]
|
||||
Description=MosDNS DNS Server
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=root
|
||||
ExecStart=/usr/local/bin/mosdns -c /usr/local/mosdns/config.yaml
|
||||
Restart=always
|
||||
RestartSec=3
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
# 重新加载 systemd 配置
|
||||
sudo systemctl daemon-reload
|
||||
|
||||
# 启用并启动服务
|
||||
sudo systemctl enable mosdns
|
||||
sudo systemctl start mosdns
|
||||
|
||||
# 检查服务状态
|
||||
sudo systemctl status mosdns
|
||||
```
|
||||
|
||||
### 6. 配置 DNS 转发
|
||||
|
||||
修改 `/etc/systemd/resolved.conf`:
|
||||
|
||||
```ini
|
||||
[Resolve]
|
||||
DNS=127.0.0.1
|
||||
#FallbackDNS=8.8.8.8 8.8.4.4
|
||||
#Domains=
|
||||
#DNSSEC=no
|
||||
#DNSOverTLS=no
|
||||
#MulticastDNS=yes
|
||||
#LLMNR=yes
|
||||
#Cache=yes
|
||||
#DNSStubListener=no
|
||||
```
|
||||
|
||||
重启 systemd-resolved:
|
||||
|
||||
```bash
|
||||
sudo systemctl restart systemd-resolved
|
||||
```
|
||||
|
||||
### 7. 测试配置
|
||||
|
||||
```bash
|
||||
# 测试 Amazon 域名解析
|
||||
nslookup amazon.com 127.0.0.1
|
||||
nslookup aws.amazon.com 127.0.0.1
|
||||
|
||||
# 检查 MikroTik address list 是否更新
|
||||
ssh admin@192.168.1.1 "/ip firewall address-list print where list=amazon_ips"
|
||||
|
||||
# 查看 mosdns 日志
|
||||
sudo journalctl -u mosdns -f
|
||||
```
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 工作流程
|
||||
|
||||
1. **域名匹配**:当查询 Amazon 相关域名时,匹配 `amazon_domains` 集合
|
||||
2. **DNS 解析**:使用国外 DNS 服务器解析域名
|
||||
3. **IP 提取**:从 DNS 响应中提取 A 和 AAAA 记录
|
||||
4. **地址添加**:通过 MikroTik API 将 IP 添加到 address list
|
||||
5. **超时管理**:IP 地址会在 24 小时后自动过期
|
||||
|
||||
### 监控和调试
|
||||
|
||||
```bash
|
||||
# 查看实时日志
|
||||
sudo journalctl -u mosdns -f
|
||||
|
||||
# 查看服务状态
|
||||
sudo systemctl status mosdns
|
||||
|
||||
# 测试 MikroTik 连接
|
||||
curl -k https://192.168.1.1:8729/api/rest/ip/firewall/address-list
|
||||
```
|
||||
|
||||
### 故障排除
|
||||
|
||||
1. **连接失败**:检查 MikroTik IP、端口和认证信息
|
||||
2. **权限不足**:确保 MikroTik 用户具有管理 address list 的权限
|
||||
3. **域名文件缺失**:确保所有域名列表文件都已下载
|
||||
4. **DNS 解析失败**:检查上游 DNS 服务器配置
|
||||
|
||||
## 安全注意事项
|
||||
|
||||
1. 不要在配置文件中使用明文密码,考虑使用环境变量
|
||||
2. 限制对 MikroTik API 端口的访问
|
||||
3. 定期更新域名列表文件
|
||||
4. 监控 address list 大小,避免过多条目影响性能
|
||||
55
dns-example-gfw.yaml
Normal file
55
dns-example-gfw.yaml
Normal file
@ -0,0 +1,55 @@
|
||||
################ DNS Plugins #################
|
||||
plugins:
|
||||
|
||||
- tag: mikrotik-one
|
||||
type: forward
|
||||
args:
|
||||
concurrent: 1
|
||||
upstreams:
|
||||
- addr: "udp://10.248.0.1"
|
||||
|
||||
- tag: cn-dns
|
||||
type: forward
|
||||
args:
|
||||
concurrent: 6
|
||||
upstreams:
|
||||
- addr: "udp://202.96.128.86"
|
||||
- addr: "udp://202.96.128.166"
|
||||
- addr: "udp://119.29.29.29"
|
||||
- addr: "udp://223.5.5.5"
|
||||
- addr: "udp://114.114.114.114"
|
||||
- addr: "udp://180.76.76.76"
|
||||
|
||||
- tag: jp-dns
|
||||
type: forward
|
||||
args:
|
||||
concurrent: 4
|
||||
upstreams:
|
||||
- addr: "tls://1dot1dot1dot1.cloudflare-dns.com"
|
||||
dial_addr: "1.1.1.1"
|
||||
enable_pipeline: true
|
||||
- addr: "tls://1dot1dot1dot1.cloudflare-dns.com"
|
||||
dial_addr: "1.0.0.1"
|
||||
enable_pipeline: true
|
||||
- addr: "tls://dns.google"
|
||||
dial_addr: "8.8.8.8"
|
||||
enable_pipeline: true
|
||||
- addr: "tls://dns.google"
|
||||
dial_addr: "8.8.4.4"
|
||||
enable_pipeline: true
|
||||
|
||||
# MikroTik Address List 插件 - 处理 Amazon 相关域名
|
||||
# 示例:将地址列表改为 gfw
|
||||
- tag: mikrotik_amazon
|
||||
type: mikrotik_addresslist
|
||||
args:
|
||||
host: "10.248.0.1"
|
||||
port: 9728
|
||||
username: "admin"
|
||||
password: "szn0s!nw@pwd()"
|
||||
use_tls: false
|
||||
timeout: 10
|
||||
address_list4: "gfw" # 改为 gfw,插件会自动创建这个地址列表
|
||||
mask4: 24
|
||||
comment: "amazon_domain"
|
||||
timeout_addr: 86400
|
||||
57
dns.yaml
Normal file
57
dns.yaml
Normal file
@ -0,0 +1,57 @@
|
||||
################ DNS Plugins #################
|
||||
plugins:
|
||||
|
||||
- tag: mikrotik-one
|
||||
type: forward
|
||||
args:
|
||||
concurrent: 1
|
||||
upstreams:
|
||||
- addr: "udp://10.248.0.1"
|
||||
|
||||
|
||||
- tag: cn-dns
|
||||
type: forward
|
||||
args:
|
||||
concurrent: 6
|
||||
upstreams:
|
||||
- addr: "udp://202.96.128.86"
|
||||
- addr: "udp://202.96.128.166"
|
||||
- addr: "udp://119.29.29.29"
|
||||
- addr: "udp://223.5.5.5"
|
||||
- addr: "udp://114.114.114.114"
|
||||
- addr: "udp://180.76.76.76"
|
||||
|
||||
|
||||
- tag: jp-dns
|
||||
type: forward
|
||||
args:
|
||||
concurrent: 4 # 同步向 3 条上游并发查询
|
||||
upstreams:
|
||||
- addr: "tls://1dot1dot1dot1.cloudflare-dns.com"
|
||||
dial_addr: "1.1.1.1"
|
||||
enable_pipeline: true
|
||||
- addr: "tls://1dot1dot1dot1.cloudflare-dns.com"
|
||||
dial_addr: "1.0.0.1"
|
||||
enable_pipeline: true
|
||||
- addr: "tls://dns.google"
|
||||
dial_addr: "8.8.8.8"
|
||||
enable_pipeline: true
|
||||
- addr: "tls://dns.google"
|
||||
dial_addr: "8.8.4.4"
|
||||
enable_pipeline: true
|
||||
|
||||
# MikroTik Address List 插件 - 处理 Amazon 相关域名
|
||||
# 示例:将地址列表改为 gfw
|
||||
- tag: mikrotik_amazon
|
||||
type: mikrotik_addresslist
|
||||
args:
|
||||
host: "10.248.0.1"
|
||||
port: 9728
|
||||
username: "admin"
|
||||
password: "szn0s!nw@pwd()"
|
||||
use_tls: false
|
||||
timeout: 10
|
||||
address_list4: "gfw" # 改为 gfw,插件会自动创建这个地址列表
|
||||
mask4: 24
|
||||
comment: "amazon_domain"
|
||||
timeout_addr: 86400
|
||||
71
go.mod
Normal file
71
go.mod
Normal file
@ -0,0 +1,71 @@
|
||||
module github.com/IrineSistiana/mosdns/v5
|
||||
|
||||
go 1.22
|
||||
|
||||
require (
|
||||
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230918115058-c72bd9761c57
|
||||
github.com/go-chi/chi/v5 v5.1.0
|
||||
github.com/go-routeros/routeros/v3 v3.0.1
|
||||
github.com/google/nftables v0.2.0
|
||||
github.com/kardianos/service v1.2.2
|
||||
github.com/klauspost/compress v1.17.9
|
||||
github.com/miekg/dns v1.1.62
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/quic-go/quic-go v0.46.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/spf13/viper v1.19.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2.0.20221107222636-d3c0a2caa559
|
||||
go.uber.org/zap v1.27.0
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
|
||||
golang.org/x/net v0.28.0
|
||||
golang.org/x/sync v0.8.0
|
||||
golang.org/x/sys v0.24.0
|
||||
golang.org/x/time v0.6.0
|
||||
google.golang.org/protobuf v1.34.2
|
||||
)
|
||||
|
||||
replace github.com/nadoo/ipset v0.5.0 => github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a
|
||||
|
||||
require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.20.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.55.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.6.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.11.0 // indirect
|
||||
github.com/spf13/cast v1.7.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.26.0 // indirect
|
||||
golang.org/x/mod v0.20.0 // indirect
|
||||
golang.org/x/text v0.17.0 // indirect
|
||||
golang.org/x/tools v0.24.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
154
go.sum
Normal file
154
go.sum
Normal file
@ -0,0 +1,154 @@
|
||||
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230918115058-c72bd9761c57 h1:nfurUSSmVY9sY/mYyoReOA1w2cR2fp2eicL9ojicZhQ=
|
||||
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230918115058-c72bd9761c57/go.mod h1:pQ/FSsWSNYmNdgIKmulKlmVC/R2PEpq2vIEi3J9IijI=
|
||||
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a h1:GQdh/h0q0ni3L//CXusyk+7QdhBL289vdNaes1WKkHI=
|
||||
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
|
||||
github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-routeros/routeros/v3 v3.0.1 h1:FdNKlF6Hst8nkHr0dIvD54pQ+dZ8sHOJfQSVRKz0BFg=
|
||||
github.com/go-routeros/routeros/v3 v3.0.1/go.mod h1:j4mq65czXfKtHsdLkgVv8w7sNzyhLZy1TKi2zQDMpiQ=
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
|
||||
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
|
||||
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k=
|
||||
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
|
||||
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/onsi/ginkgo/v2 v2.20.0 h1:PE84V2mHqoT1sglvHc8ZdQtPcwmvvt29WLEEO3xmdZw=
|
||||
github.com/onsi/ginkgo/v2 v2.20.0/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI=
|
||||
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
|
||||
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
|
||||
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
|
||||
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/quic-go v0.46.0 h1:uuwLClEEyk1DNvchH8uCByQVjo3yKL9opKulExNDs7Y=
|
||||
github.com/quic-go/quic-go v0.46.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3N51bwOk=
|
||||
github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
|
||||
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
|
||||
github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2.0.20221107222636-d3c0a2caa559 h1:NwQroOyW+fpfiUroBzAMqFc6NRwBmvJevoVtEK6gsFE=
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2.0.20221107222636-d3c0a2caa559/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik=
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
||||
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
|
||||
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI=
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
|
||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
|
||||
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220804214406-8e32c043e418/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
|
||||
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
|
||||
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
|
||||
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
50
main.go
Normal file
50
main.go
Normal file
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/coremain"
|
||||
"github.com/IrineSistiana/mosdns/v5/mlog"
|
||||
_ "github.com/IrineSistiana/mosdns/v5/plugin"
|
||||
_ "github.com/IrineSistiana/mosdns/v5/tools"
|
||||
"github.com/spf13/cobra"
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "dev/unknown"
|
||||
)
|
||||
|
||||
func init() {
|
||||
coremain.AddSubCmd(&cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print out version info and exit.",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Println(version)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := coremain.Run(); err != nil {
|
||||
mlog.S().Fatal(err)
|
||||
}
|
||||
}
|
||||
91
mlog/logger.go
Normal file
91
mlog/logger.go
Normal file
@ -0,0 +1,91 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package mlog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"os"
|
||||
)
|
||||
|
||||
type LogConfig struct {
|
||||
// Level, See also zapcore.ParseLevel.
|
||||
Level string `yaml:"level"`
|
||||
|
||||
// File that logger will be writen into.
|
||||
// Default is stderr.
|
||||
File string `yaml:"file"`
|
||||
|
||||
// Production enables json output.
|
||||
Production bool `yaml:"production"`
|
||||
}
|
||||
|
||||
var (
|
||||
stderr = zapcore.Lock(os.Stderr)
|
||||
lvl = zap.NewAtomicLevelAt(zap.InfoLevel)
|
||||
l = zap.New(zapcore.NewCore(zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), stderr, lvl))
|
||||
s = l.Sugar()
|
||||
|
||||
nop = zap.NewNop()
|
||||
)
|
||||
|
||||
func NewLogger(lc LogConfig) (*zap.Logger, error) {
|
||||
lvl, err := zapcore.ParseLevel(lc.Level)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid log level: %w", err)
|
||||
}
|
||||
|
||||
var out zapcore.WriteSyncer
|
||||
if lf := lc.File; len(lf) > 0 {
|
||||
f, _, err := zap.Open(lf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open log file: %w", err)
|
||||
}
|
||||
out = zapcore.Lock(f)
|
||||
} else {
|
||||
out = stderr
|
||||
}
|
||||
|
||||
if lc.Production {
|
||||
return zap.New(zapcore.NewCore(zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()), out, lvl)), nil
|
||||
}
|
||||
return zap.New(zapcore.NewCore(zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), out, lvl)), nil
|
||||
}
|
||||
|
||||
// L is a global logger.
|
||||
func L() *zap.Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
// SetLevel sets the log level for the global logger.
|
||||
func SetLevel(l zapcore.Level) {
|
||||
lvl.SetLevel(l)
|
||||
}
|
||||
|
||||
// S is a global logger.
|
||||
func S() *zap.SugaredLogger {
|
||||
return s
|
||||
}
|
||||
|
||||
// Nop is a logger that never writes out logs.
|
||||
func Nop() *zap.Logger {
|
||||
return nop
|
||||
}
|
||||
BIN
mosdns-linux-amd64
Normal file
BIN
mosdns-linux-amd64
Normal file
Binary file not shown.
157
pkg/cache/cache.go
vendored
Normal file
157
pkg/cache/cache.go
vendored
Normal file
@ -0,0 +1,157 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/concurrent_lru"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/concurrent_map"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCleanerInterval = time.Second * 10
|
||||
)
|
||||
|
||||
type Key interface {
|
||||
concurrent_lru.Hashable
|
||||
}
|
||||
|
||||
type Value interface {
|
||||
any
|
||||
}
|
||||
|
||||
// Cache is a simple map cache that stores values in memory.
|
||||
// It is safe for concurrent use.
|
||||
type Cache[K Key, V Value] struct {
|
||||
opts Opts
|
||||
|
||||
closed atomic.Bool
|
||||
closeNotify chan struct{}
|
||||
m *concurrent_map.Map[K, *elem[V]]
|
||||
}
|
||||
|
||||
type Opts struct {
|
||||
Size int
|
||||
CleanerInterval time.Duration
|
||||
}
|
||||
|
||||
func (opts *Opts) init() {
|
||||
utils.SetDefaultNum(&opts.Size, 1024)
|
||||
utils.SetDefaultNum(&opts.CleanerInterval, defaultCleanerInterval)
|
||||
}
|
||||
|
||||
type elem[V Value] struct {
|
||||
v V
|
||||
expirationTime time.Time
|
||||
}
|
||||
|
||||
// New initializes a Cache.
|
||||
// The minimum size is 1024.
|
||||
// cleanerInterval specifies the interval that Cache scans
|
||||
// and discards expired values. If cleanerInterval <= 0, a default
|
||||
// interval will be used.
|
||||
func New[K Key, V Value](opts Opts) *Cache[K, V] {
|
||||
opts.init()
|
||||
c := &Cache[K, V]{
|
||||
closeNotify: make(chan struct{}),
|
||||
m: concurrent_map.NewMapCache[K, *elem[V]](opts.Size),
|
||||
}
|
||||
go c.gcLoop(opts.CleanerInterval)
|
||||
return c
|
||||
}
|
||||
|
||||
// Close closes the inner cleaner of this cache.
|
||||
func (c *Cache[K, V]) Close() error {
|
||||
if ok := c.closed.CompareAndSwap(false, true); ok {
|
||||
close(c.closeNotify)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) Get(key K) (v V, expirationTime time.Time, ok bool) {
|
||||
if e, hasEntry := c.m.Get(key); hasEntry {
|
||||
if e.expirationTime.Before(time.Now()) {
|
||||
c.m.Del(key)
|
||||
return
|
||||
}
|
||||
return e.v, e.expirationTime, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Range calls f through all entries. If f returns an error, the same error will be returned
|
||||
// by Range.
|
||||
func (c *Cache[K, V]) Range(f func(key K, v V, expirationTime time.Time) error) error {
|
||||
cf := func(key K, v *elem[V]) (newV *elem[V], setV bool, delV bool, err error) {
|
||||
return nil, false, false, f(key, v.v, v.expirationTime)
|
||||
}
|
||||
return c.m.RangeDo(cf)
|
||||
}
|
||||
|
||||
// Store stores this kv in cache. If expirationTime is before time.Now(),
|
||||
// Store is an noop.
|
||||
func (c *Cache[K, V]) Store(key K, v V, expirationTime time.Time) {
|
||||
now := time.Now()
|
||||
if now.After(expirationTime) {
|
||||
return
|
||||
}
|
||||
|
||||
e := &elem[V]{
|
||||
v: v,
|
||||
expirationTime: expirationTime,
|
||||
}
|
||||
c.m.Set(key, e)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) gcLoop(interval time.Duration) {
|
||||
if interval <= 0 {
|
||||
interval = defaultCleanerInterval
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.closeNotify:
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
c.gc(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) gc(now time.Time) {
|
||||
f := func(key K, v *elem[V]) (newV *elem[V], setV, delV bool, err error) {
|
||||
return nil, false, now.After(v.expirationTime), nil
|
||||
}
|
||||
_ = c.m.RangeDo(f)
|
||||
}
|
||||
|
||||
// Len returns the current size of this cache.
|
||||
func (c *Cache[K, V]) Len() int {
|
||||
return c.m.Len()
|
||||
}
|
||||
|
||||
// Flush removes all stored entries from this cache.
|
||||
func (c *Cache[K, V]) Flush() {
|
||||
c.m.Flush()
|
||||
}
|
||||
98
pkg/cache/cache_test.go
vendored
Normal file
98
pkg/cache/cache_test.go
vendored
Normal file
@ -0,0 +1,98 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testKey int
|
||||
|
||||
func (t testKey) Sum() uint64 {
|
||||
return uint64(t)
|
||||
}
|
||||
|
||||
func Test_Cache(t *testing.T) {
|
||||
c := New[testKey, int](Opts{
|
||||
Size: 1024,
|
||||
})
|
||||
for i := 0; i < 128; i++ {
|
||||
key := testKey(i)
|
||||
c.Store(key, i, time.Now().Add(time.Millisecond*200))
|
||||
v, _, ok := c.Get(key)
|
||||
|
||||
if v != i {
|
||||
t.Fatal("cache kv mismatched")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 1024*4; i++ {
|
||||
key := testKey(i)
|
||||
c.Store(key, i, time.Now().Add(time.Millisecond*200))
|
||||
}
|
||||
|
||||
if l := c.Len(); l > 1024 {
|
||||
t.Fatal("cache overflow")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_memCache_cleaner(t *testing.T) {
|
||||
c := New[testKey, int](Opts{
|
||||
Size: 1024,
|
||||
CleanerInterval: time.Millisecond * 10,
|
||||
})
|
||||
defer c.Close()
|
||||
for i := 0; i < 64; i++ {
|
||||
key := testKey(i)
|
||||
c.Store(key, i, time.Now().Add(time.Millisecond*10))
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
if c.Len() != 0 {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_memCache_race(t *testing.T) {
|
||||
c := New[testKey, int](Opts{
|
||||
Size: 1024,
|
||||
})
|
||||
defer c.Close()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 32; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 256; i++ {
|
||||
key := testKey(i)
|
||||
c.Store(key, i, time.Now().Add(time.Minute))
|
||||
_, _, _ = c.Get(key)
|
||||
c.gc(time.Now())
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
149
pkg/concurrent_lru/concurrent_lru.go
Normal file
149
pkg/concurrent_lru/concurrent_lru.go
Normal file
@ -0,0 +1,149 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package concurrent_lru
|
||||
|
||||
import (
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/lru"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Hashable interface {
|
||||
comparable
|
||||
Sum() uint64
|
||||
}
|
||||
|
||||
type ShardedLRU[K Hashable, V any] struct {
|
||||
l []*ConcurrentLRU[K, V]
|
||||
}
|
||||
|
||||
func NewShardedLRU[K Hashable, V any](
|
||||
shardNum, maxSizePerShard int,
|
||||
onEvict func(key K, v V),
|
||||
) *ShardedLRU[K, V] {
|
||||
cl := &ShardedLRU[K, V]{
|
||||
l: make([]*ConcurrentLRU[K, V], 0, shardNum),
|
||||
}
|
||||
|
||||
for i := 0; i < shardNum; i++ {
|
||||
cl.l = append(cl.l, NewConecurrentLRU[K, V](maxSizePerShard, onEvict))
|
||||
}
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
func (c *ShardedLRU[K, V]) Add(key K, v V) {
|
||||
sl := c.getShard(key)
|
||||
sl.Add(key, v)
|
||||
}
|
||||
|
||||
func (c *ShardedLRU[K, V]) Del(key K) {
|
||||
sl := c.getShard(key)
|
||||
sl.Del(key)
|
||||
}
|
||||
|
||||
func (c *ShardedLRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) {
|
||||
for _, l := range c.l {
|
||||
removed += l.Clean(f)
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (c *ShardedLRU[K, V]) Flush() {
|
||||
for _, l := range c.l {
|
||||
l.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ShardedLRU[K, V]) Get(key K) (v V, ok bool) {
|
||||
sl := c.getShard(key)
|
||||
v, ok = sl.Get(key)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ShardedLRU[K, V]) Len() int {
|
||||
sum := 0
|
||||
for _, l := range c.l {
|
||||
sum += l.Len()
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func (c *ShardedLRU[K, V]) shardNum() int {
|
||||
return len(c.l)
|
||||
}
|
||||
|
||||
func (c *ShardedLRU[K, V]) getShard(key K) *ConcurrentLRU[K, V] {
|
||||
return c.l[key.Sum()%uint64(c.shardNum())]
|
||||
}
|
||||
|
||||
// ConcurrentLRU is a lru.LRU with a lock.
|
||||
// It is concurrent safe.
|
||||
type ConcurrentLRU[K comparable, V any] struct {
|
||||
sync.Mutex
|
||||
lru *lru.LRU[K, V]
|
||||
}
|
||||
|
||||
func NewConecurrentLRU[K comparable, V any](maxSize int, onEvict func(key K, v V)) *ConcurrentLRU[K, V] {
|
||||
return &ConcurrentLRU[K, V]{
|
||||
lru: lru.NewLRU[K, V](maxSize, onEvict),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConcurrentLRU[K, V]) Add(key K, v V) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
c.lru.Add(key, v)
|
||||
}
|
||||
|
||||
func (c *ConcurrentLRU[K, V]) Del(key K) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
c.lru.Del(key)
|
||||
}
|
||||
|
||||
func (c *ConcurrentLRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
return c.lru.Clean(f)
|
||||
}
|
||||
|
||||
func (c *ConcurrentLRU[K, V]) Flush() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.lru.Flush()
|
||||
}
|
||||
|
||||
func (c *ConcurrentLRU[K, V]) Get(key K) (v V, ok bool) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
v, ok = c.lru.Get(key)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ConcurrentLRU[K, V]) Len() int {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
return c.lru.Len()
|
||||
}
|
||||
111
pkg/concurrent_lru/concurrent_lru_test.go
Normal file
111
pkg/concurrent_lru/concurrent_lru_test.go
Normal file
@ -0,0 +1,111 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package concurrent_lru
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testKey int
|
||||
|
||||
func (k testKey) Sum() uint64 {
|
||||
return uint64(k)
|
||||
}
|
||||
|
||||
func TestConcurrentLRU(t *testing.T) {
|
||||
onEvict := func(k testKey, v int) {}
|
||||
|
||||
var cache *ShardedLRU[testKey, int]
|
||||
reset := func(shardNum, maxShardSize int) {
|
||||
cache = NewShardedLRU[testKey, int](shardNum, maxShardSize, onEvict)
|
||||
}
|
||||
|
||||
add := func(keys ...int) {
|
||||
for _, k := range keys {
|
||||
cache.Add(testKey(k), k)
|
||||
}
|
||||
}
|
||||
|
||||
mustGet := func(keys ...int) {
|
||||
for _, k := range keys {
|
||||
gotV, ok := cache.Get(testKey(k))
|
||||
if !ok || !reflect.DeepEqual(gotV, k) {
|
||||
t.Fatalf("want %v, got %v", k, gotV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
emptyGet := func(keys ...int) {
|
||||
for _, k := range keys {
|
||||
gotV, ok := cache.Get(testKey(k))
|
||||
if ok {
|
||||
t.Fatalf("want empty, got %v", gotV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
checkLen := func(want int) {
|
||||
if want != cache.Len() {
|
||||
t.Fatalf("want %v, got %v", want, cache.Len())
|
||||
}
|
||||
}
|
||||
|
||||
// test add
|
||||
reset(4, 16)
|
||||
add(1, 1, 1, 1, 2, 2, 3, 3, 4)
|
||||
checkLen(4)
|
||||
mustGet(1, 2, 3, 4)
|
||||
emptyGet(5, 6, 7, 9999)
|
||||
|
||||
// test add overflow
|
||||
reset(4, 16) // max size is 64
|
||||
for i := 0; i < 1024; i++ {
|
||||
add(i)
|
||||
}
|
||||
if cache.Len() > 64 {
|
||||
t.Fatalf("lru overflowed: want len = %d, got = %d", 64, cache.Len())
|
||||
}
|
||||
|
||||
// test del
|
||||
reset(4, 16)
|
||||
add(1, 2, 3, 4)
|
||||
cache.Del(2)
|
||||
cache.Del(4)
|
||||
cache.Del(9999)
|
||||
mustGet(1, 3)
|
||||
emptyGet(2, 4)
|
||||
|
||||
// test clean
|
||||
reset(4, 16)
|
||||
add(1, 2, 3, 4)
|
||||
cleanFunc := func(k testKey, v int) (remove bool) {
|
||||
switch k {
|
||||
case 1, 3:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
if cleaned := cache.Clean(cleanFunc); cleaned != 2 {
|
||||
t.Fatalf("q.Clean want cleaned = 2, got %v", cleaned)
|
||||
}
|
||||
mustGet(2, 4)
|
||||
emptyGet(1, 3)
|
||||
}
|
||||
186
pkg/concurrent_map/map.go
Normal file
186
pkg/concurrent_map/map.go
Normal file
@ -0,0 +1,186 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package concurrent_map
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
MapShardSize = 64
|
||||
)
|
||||
|
||||
type Hashable interface {
|
||||
comparable
|
||||
Sum() uint64
|
||||
}
|
||||
|
||||
type TestAndSetFunc[K comparable, V any] func(key K, v V, ok bool) (newV V, setV, deleteV bool)
|
||||
|
||||
type Map[K Hashable, V any] struct {
|
||||
shards [MapShardSize]shard[K, V]
|
||||
}
|
||||
|
||||
func NewMap[K Hashable, V any]() *Map[K, V] {
|
||||
m := new(Map[K, V])
|
||||
for i := range m.shards {
|
||||
m.shards[i] = newShard[K, V](0)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// NewMapCache returns a cache with a maximum size.
|
||||
// Note that, because this it has multiple (MapShardSize) shards,
|
||||
// the actual maximum size is MapShardSize*(size / MapShardSize).
|
||||
// If size <=0, it's equal to NewMap().
|
||||
func NewMapCache[K Hashable, V any](size int) *Map[K, V] {
|
||||
sizePreShard := size / MapShardSize
|
||||
m := new(Map[K, V])
|
||||
for i := range m.shards {
|
||||
m.shards[i] = newShard[K, V](sizePreShard)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) getShard(key K) *shard[K, V] {
|
||||
return &m.shards[key.Sum()%MapShardSize]
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||
return m.getShard(key).get(key)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Set(key K, v V) {
|
||||
m.getShard(key).set(key, v)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Del(key K) {
|
||||
m.getShard(key).del(key)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) TestAndSet(key K, f func(v V, ok bool) (newV V, setV, delV bool)) {
|
||||
m.getShard(key).testAndSet(key, f)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) RangeDo(f func(k K, v V) (newV V, setV, delV bool, err error)) error {
|
||||
for i := range m.shards {
|
||||
if err := m.shards[i].rangeDo(f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Len() int {
|
||||
l := 0
|
||||
for i := range m.shards {
|
||||
l += m.shards[i].len()
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Flush() {
|
||||
for i := range m.shards {
|
||||
m.shards[i].flush()
|
||||
}
|
||||
}
|
||||
|
||||
type shard[K comparable, V any] struct {
|
||||
l sync.RWMutex
|
||||
max int // Negative or zero max means no limit.
|
||||
m map[K]V
|
||||
}
|
||||
|
||||
func newShard[K comparable, V any](max int) shard[K, V] {
|
||||
return shard[K, V]{
|
||||
max: max,
|
||||
m: make(map[K]V),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *shard[K, V]) get(key K) (V, bool) {
|
||||
m.l.RLock()
|
||||
defer m.l.RUnlock()
|
||||
v, ok := m.m[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *shard[K, V]) set(key K, v V) {
|
||||
m.l.Lock()
|
||||
defer m.l.Unlock()
|
||||
if m.max > 0 && len(m.m)+1 > m.max {
|
||||
for k := range m.m {
|
||||
delete(m.m, k)
|
||||
if len(m.m)+1 <= m.max {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.m[key] = v
|
||||
}
|
||||
|
||||
func (m *shard[K, V]) del(key K) {
|
||||
m.l.Lock()
|
||||
defer m.l.Unlock()
|
||||
delete(m.m, key)
|
||||
}
|
||||
|
||||
func (m *shard[K, V]) testAndSet(key K, f func(v V, ok bool) (newV V, setV, delV bool)) {
|
||||
m.l.Lock()
|
||||
defer m.l.Unlock()
|
||||
v, ok := m.m[key]
|
||||
newV, setV, deleteV := f(v, ok)
|
||||
switch {
|
||||
case setV:
|
||||
m.m[key] = newV
|
||||
case deleteV && ok:
|
||||
delete(m.m, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *shard[K, V]) len() int {
|
||||
m.l.RLock()
|
||||
defer m.l.RUnlock()
|
||||
return len(m.m)
|
||||
}
|
||||
|
||||
func (m *shard[K, V]) flush() {
|
||||
m.l.RLock()
|
||||
defer m.l.RUnlock()
|
||||
m.m = make(map[K]V)
|
||||
}
|
||||
|
||||
func (m *shard[K, V]) rangeDo(f func(k K, v V) (newV V, setV, delV bool, err error)) error {
|
||||
m.l.Lock()
|
||||
defer m.l.Unlock()
|
||||
for k, v := range m.m {
|
||||
newV, setV, deleteV, err := f(k, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch {
|
||||
case setV:
|
||||
m.m[k] = newV
|
||||
case deleteV:
|
||||
delete(m.m, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
190
pkg/concurrent_map/map_test.go
Normal file
190
pkg/concurrent_map/map_test.go
Normal file
@ -0,0 +1,190 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package concurrent_map
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testMapHashable uint64
|
||||
|
||||
func (h testMapHashable) Sum() uint64 {
|
||||
return uint64(h)
|
||||
}
|
||||
|
||||
func Test_Map(t *testing.T) {
|
||||
m := NewMap[testMapHashable, int]()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
// test add
|
||||
for i := 0; i < 512; i++ {
|
||||
i := i
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.Set(testMapHashable(i), i)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// test range
|
||||
wantErr := errors.New("")
|
||||
f := func(key testMapHashable, v int) (newV int, setV bool, deleteV bool, err error) {
|
||||
return 0, false, false, wantErr
|
||||
}
|
||||
if wantErr != m.RangeDo(f) {
|
||||
t.Fatal("range should return a error")
|
||||
}
|
||||
|
||||
cc := make([]bool, 512)
|
||||
f = func(key testMapHashable, v int) (newV int, setV bool, deleteV bool, err error) {
|
||||
cc[key] = true
|
||||
return 0, false, false, nil
|
||||
}
|
||||
_ = m.RangeDo(f)
|
||||
for _, ok := range cc {
|
||||
if !ok {
|
||||
t.Fatal("test or range failed")
|
||||
}
|
||||
}
|
||||
|
||||
// test get
|
||||
for i := 0; i < 512; i++ {
|
||||
i := i
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
v, ok := m.Get(testMapHashable(i))
|
||||
if !ok {
|
||||
t.Error()
|
||||
return
|
||||
}
|
||||
if v != i {
|
||||
t.Error()
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// test len
|
||||
if m.Len() != 512 {
|
||||
t.Fatal()
|
||||
}
|
||||
|
||||
// test del
|
||||
for i := 0; i < 512; i++ {
|
||||
i := i
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.Del(testMapHashable(i))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if m.Len() != 0 {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentMap_TestAndSet(t *testing.T) {
|
||||
cm := NewMap[testMapHashable, int]()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
f := func(v int, ok bool) (newV int, setV bool, deleteV bool) {
|
||||
return 1, true, false
|
||||
}
|
||||
|
||||
for i := 0; i < 512; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cm.TestAndSet(1, f)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
v, _ := cm.Get(1)
|
||||
if v != 1 {
|
||||
t.Fatal()
|
||||
}
|
||||
|
||||
// test delete
|
||||
f = func(v int, ok bool) (newV int, setV bool, deleteV bool) {
|
||||
return 1, false, true
|
||||
}
|
||||
cm.TestAndSet(1, f)
|
||||
_, ok := cm.Get(1)
|
||||
if ok {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConcurrentMap_Get_And_Set(b *testing.B) {
|
||||
keys := make([]testMapHashable, 2048)
|
||||
m := NewMap[testMapHashable, int]()
|
||||
for i := 0; i < 2048; i++ {
|
||||
key := testMapHashable(i)
|
||||
keys[i] = key
|
||||
m.Set(key, i)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
i++
|
||||
key := keys[i%2048]
|
||||
m.Set(key, i)
|
||||
m.Get(key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Benchmark_RWMutexMap_Get_And_Set(b *testing.B) {
|
||||
keys := make([]int, 2048)
|
||||
rwm := new(sync.RWMutex)
|
||||
m := make(map[int]int, 2048)
|
||||
for i := 0; i < 2048; i++ {
|
||||
keys[i] = i
|
||||
m[i] = i
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
i++
|
||||
key := keys[i%2048]
|
||||
|
||||
rwm.Lock()
|
||||
m[key] = i
|
||||
rwm.Unlock()
|
||||
|
||||
rwm.RLock()
|
||||
_ = m[key]
|
||||
rwm.RUnlock()
|
||||
}
|
||||
})
|
||||
}
|
||||
160
pkg/dnsutils/msg.go
Normal file
160
pkg/dnsutils/msg.go
Normal file
@ -0,0 +1,160 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package dnsutils
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// GetMinimalTTL returns the minimal ttl of this msg.
|
||||
// If msg m has no record, it returns 0.
|
||||
func GetMinimalTTL(m *dns.Msg) uint32 {
|
||||
minTTL := ^uint32(0)
|
||||
hasRecord := false
|
||||
for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||||
for _, rr := range section {
|
||||
hdr := rr.Header()
|
||||
if hdr.Rrtype == dns.TypeOPT {
|
||||
continue // opt record ttl is not ttl.
|
||||
}
|
||||
hasRecord = true
|
||||
ttl := hdr.Ttl
|
||||
if ttl < minTTL {
|
||||
minTTL = ttl
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRecord { // no ttl applied
|
||||
return 0
|
||||
}
|
||||
return minTTL
|
||||
}
|
||||
|
||||
// SetTTL updates all records' ttl to ttl, except opt record.
|
||||
func SetTTL(m *dns.Msg, ttl uint32) {
|
||||
for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||||
for _, rr := range section {
|
||||
hdr := rr.Header()
|
||||
if hdr.Rrtype == dns.TypeOPT {
|
||||
continue // opt record ttl is not ttl.
|
||||
}
|
||||
hdr.Ttl = ttl
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ApplyMaximumTTL(m *dns.Msg, ttl uint32) {
|
||||
applyTTL(m, ttl, true)
|
||||
}
|
||||
|
||||
func ApplyMinimalTTL(m *dns.Msg, ttl uint32) {
|
||||
applyTTL(m, ttl, false)
|
||||
}
|
||||
|
||||
// SubtractTTL subtract delta from every m's RR.
|
||||
// If RR's TTL is smaller than delta, SubtractTTL
|
||||
// will return overflowed = true.
|
||||
func SubtractTTL(m *dns.Msg, delta uint32) (overflowed bool) {
|
||||
for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||||
for _, rr := range section {
|
||||
hdr := rr.Header()
|
||||
if hdr.Rrtype == dns.TypeOPT {
|
||||
continue // opt record ttl is not ttl.
|
||||
}
|
||||
if ttl := hdr.Ttl; ttl > delta {
|
||||
hdr.Ttl = ttl - delta
|
||||
} else {
|
||||
hdr.Ttl = 1
|
||||
overflowed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func applyTTL(m *dns.Msg, ttl uint32, maximum bool) {
|
||||
for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||||
for _, rr := range section {
|
||||
hdr := rr.Header()
|
||||
if hdr.Rrtype == dns.TypeOPT {
|
||||
continue // opt record ttl is not ttl.
|
||||
}
|
||||
if maximum {
|
||||
if hdr.Ttl > ttl {
|
||||
hdr.Ttl = ttl
|
||||
}
|
||||
} else {
|
||||
if hdr.Ttl < ttl {
|
||||
hdr.Ttl = ttl
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func uint16Conv(u uint16, m map[uint16]string) string {
|
||||
if s, ok := m[u]; ok {
|
||||
return s
|
||||
}
|
||||
return strconv.Itoa(int(u))
|
||||
}
|
||||
|
||||
func QclassToString(u uint16) string {
|
||||
return uint16Conv(u, dns.ClassToString)
|
||||
}
|
||||
|
||||
func QtypeToString(u uint16) string {
|
||||
return uint16Conv(u, dns.TypeToString)
|
||||
}
|
||||
|
||||
func GenEmptyReply(q *dns.Msg, rcode int) *dns.Msg {
|
||||
r := new(dns.Msg)
|
||||
r.SetRcode(q, rcode)
|
||||
|
||||
var name string
|
||||
if len(q.Question) > 1 {
|
||||
name = q.Question[0].Name
|
||||
} else {
|
||||
name = "."
|
||||
}
|
||||
|
||||
r.Ns = []dns.RR{FakeSOA(name)}
|
||||
return r
|
||||
}
|
||||
|
||||
func FakeSOA(name string) *dns.SOA {
|
||||
return &dns.SOA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeSOA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
Ns: "fake-ns.mosdns.fake.root.",
|
||||
Mbox: "fake-mbox.mosdns.fake.root.",
|
||||
Serial: 2021110400,
|
||||
Refresh: 1800,
|
||||
Retry: 900,
|
||||
Expire: 604800,
|
||||
Minttl: 86400,
|
||||
}
|
||||
}
|
||||
138
pkg/dnsutils/net_io.go
Normal file
138
pkg/dnsutils/net_io.go
Normal file
@ -0,0 +1,138 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package dnsutils
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
DnsHeaderLen = 12 // minimum dns msg size
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPayloadTooSmall = errors.New("payload is to small for a valid dns msg")
|
||||
)
|
||||
|
||||
// ReadRawMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
|
||||
// with a two byte length field).
|
||||
// n represents how many bytes are read from c.
|
||||
// The returned the *[]byte should be released by pool.ReleaseBuf.
|
||||
func ReadRawMsgFromTCP(c io.Reader) (*[]byte, error) {
|
||||
h := pool.GetBuf(2)
|
||||
defer pool.ReleaseBuf(h)
|
||||
_, err := io.ReadFull(c, *h)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// dns length
|
||||
length := binary.BigEndian.Uint16(*h)
|
||||
if length <= DnsHeaderLen {
|
||||
return nil, ErrPayloadTooSmall
|
||||
}
|
||||
|
||||
b := pool.GetBuf(int(length))
|
||||
_, err = io.ReadFull(c, *b)
|
||||
if err != nil {
|
||||
pool.ReleaseBuf(b)
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// ReadMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
|
||||
// with a two byte length field).
|
||||
// n represents how many bytes are read from c.
|
||||
func ReadMsgFromTCP(c io.Reader) (*dns.Msg, int, error) {
|
||||
b, err := ReadRawMsgFromTCP(c)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer pool.ReleaseBuf(b)
|
||||
|
||||
m, err := unpackMsgWithDetailedErr(*b)
|
||||
return m, len(*b) + 2, err
|
||||
}
|
||||
|
||||
// WriteMsgToTCP packs and writes m to c in RFC 1035 format.
|
||||
// n represents how many bytes are written to c.
|
||||
func WriteMsgToTCP(c io.Writer, m *dns.Msg) (n int, err error) {
|
||||
buf, err := pool.PackTCPBuffer(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer pool.ReleaseBuf(buf)
|
||||
return c.Write(*buf)
|
||||
}
|
||||
|
||||
// WriteRawMsgToTCP See WriteMsgToTCP
|
||||
func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) {
|
||||
if len(b) > dns.MaxMsgSize {
|
||||
return 0, fmt.Errorf("payload length %d is greater than dns max msg size", len(b))
|
||||
}
|
||||
|
||||
buf := pool.GetBuf(len(b) + 2)
|
||||
defer pool.ReleaseBuf(buf)
|
||||
|
||||
binary.BigEndian.PutUint16((*buf)[:2], uint16(len(b)))
|
||||
copy((*buf)[2:], b)
|
||||
return c.Write((*buf))
|
||||
}
|
||||
|
||||
func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) {
|
||||
b, err := pool.PackBuffer(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer pool.ReleaseBuf(b)
|
||||
return c.Write(*b)
|
||||
}
|
||||
|
||||
func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) {
|
||||
if bufSize < dns.MinMsgSize {
|
||||
bufSize = dns.MinMsgSize
|
||||
}
|
||||
|
||||
b := pool.GetBuf(bufSize)
|
||||
defer pool.ReleaseBuf(b)
|
||||
n, err := c.Read(*b)
|
||||
if err != nil {
|
||||
return nil, n, err
|
||||
}
|
||||
|
||||
m, err := unpackMsgWithDetailedErr((*b)[:n])
|
||||
return m, n, err
|
||||
}
|
||||
|
||||
func unpackMsgWithDetailedErr(b []byte) (*dns.Msg, error) {
|
||||
m := new(dns.Msg)
|
||||
if err := m.Unpack(b); err != nil {
|
||||
return nil, fmt.Errorf("failed to unpack msg [%x], %w", b, err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
125
pkg/dnsutils/ptr_parser.go
Normal file
125
pkg/dnsutils/ptr_parser.go
Normal file
@ -0,0 +1,125 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package dnsutils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var errNotPTRDomain = errors.New("domain does not has a ptr suffix")
|
||||
|
||||
const (
|
||||
IP4arpa = ".in-addr.arpa."
|
||||
IP6arpa = ".ip6.arpa."
|
||||
)
|
||||
|
||||
// ParsePTRQName returns the ip that a PTR query name contains.
|
||||
func ParsePTRQName(fqdn string) (netip.Addr, error) {
|
||||
switch {
|
||||
case strings.HasSuffix(fqdn, IP4arpa):
|
||||
return reverse4(fqdn[:len(fqdn)-len(IP4arpa)])
|
||||
case strings.HasSuffix(fqdn, IP6arpa):
|
||||
return reverse6(fqdn[:len(fqdn)-len(IP6arpa)])
|
||||
default:
|
||||
return netip.Addr{}, errNotPTRDomain
|
||||
}
|
||||
}
|
||||
|
||||
func reverse4(s string) (netip.Addr, error) {
|
||||
var buf [4]byte
|
||||
l := 0
|
||||
for offset := len(s); offset > 0 && l < len(buf); l++ {
|
||||
var label string
|
||||
label, offset = prevLabel(s, offset)
|
||||
n, err := strconv.ParseUint(label, 10, 8)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invaild bit, %w", err)
|
||||
}
|
||||
buf[l] = byte(n)
|
||||
}
|
||||
if l < len(buf) {
|
||||
return netip.Addr{}, fmt.Errorf("expact at least 3 labels, got %d", l)
|
||||
}
|
||||
return netip.AddrFrom4(buf), nil
|
||||
}
|
||||
|
||||
func reverse6(s string) (netip.Addr, error) {
|
||||
var buf [16]byte
|
||||
var val byte
|
||||
var tail bool
|
||||
var l int
|
||||
for offset := len(s); offset > 0 && l < len(buf); {
|
||||
var label string
|
||||
label, offset = prevLabel(s, offset)
|
||||
if len(label) != 1 {
|
||||
return netip.Addr{}, fmt.Errorf("invalid label %s", label)
|
||||
}
|
||||
b := label[0]
|
||||
n, ok := hex2byte(b)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("invaild bit %d", b)
|
||||
}
|
||||
if tail {
|
||||
buf[l] = val<<4 + n
|
||||
l++
|
||||
tail = false
|
||||
} else {
|
||||
val = n
|
||||
tail = true
|
||||
}
|
||||
}
|
||||
if l < len(buf) {
|
||||
return netip.Addr{}, fmt.Errorf("expact at least 16 bytes, got %d", l)
|
||||
}
|
||||
return netip.AddrFrom16(buf), nil
|
||||
}
|
||||
|
||||
func hex2byte(c byte) (byte, bool) {
|
||||
lower := func(c byte) byte {
|
||||
return c | ('x' - 'X')
|
||||
}
|
||||
var b byte
|
||||
switch {
|
||||
case '0' <= c && c <= '9':
|
||||
b = c - '0'
|
||||
case 'a' <= lower(c) && lower(c) <= 'z':
|
||||
b = lower(c) - 'a' + 10
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
return b, true
|
||||
}
|
||||
|
||||
func prevLabel(s string, offset int) (string, int) {
|
||||
for {
|
||||
s = s[:offset]
|
||||
n := strings.LastIndexByte(s, '.')
|
||||
label := s[n+1 : offset]
|
||||
if n != -1 && len(label) == 0 {
|
||||
offset = n
|
||||
continue
|
||||
}
|
||||
return label, n
|
||||
}
|
||||
}
|
||||
81
pkg/dnsutils/ptr_parser_test.go
Normal file
81
pkg/dnsutils/ptr_parser_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package dnsutils
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_reverse4(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want netip.Addr
|
||||
wantErr bool
|
||||
}{
|
||||
{"v4", "4.4.8.8", netip.MustParseAddr("8.8.4.4"), false},
|
||||
{"v4_with_prefix", "prefix.4.4.8.8", netip.MustParseAddr("8.8.4.4"), false},
|
||||
{"invalid_format", "123114123", netip.Addr{}, true},
|
||||
{"invalid_format", "12..311..4123..", netip.Addr{}, true},
|
||||
{"invalid_format", "...", netip.Addr{}, true},
|
||||
{"short_length", "4.8.8", netip.Addr{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := reverse4(tt.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("reverse4() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("reverse4() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_reverse6(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want netip.Addr
|
||||
wantErr bool
|
||||
}{
|
||||
{"v6", "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.MustParseAddr("2001:db8::567:89ab"), false},
|
||||
{"v6_with_prefix", "prefix.b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.MustParseAddr("2001:db8::567:89ab"), false},
|
||||
{"invalid_format", "123114123", netip.Addr{}, true},
|
||||
{"invalid_format", "..123...", netip.Addr{}, true},
|
||||
{"short_length", "0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.Addr{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := reverse6(tt.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("reverse6() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("reverse4() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
134
pkg/hosts/hosts.go
Normal file
134
pkg/hosts/hosts.go
Normal file
@ -0,0 +1,134 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain"
|
||||
"github.com/miekg/dns"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Hosts struct {
|
||||
matcher domain.Matcher[*IPs]
|
||||
}
|
||||
|
||||
// NewHosts creates a hosts using m.
|
||||
func NewHosts(m domain.Matcher[*IPs]) *Hosts {
|
||||
return &Hosts{
|
||||
matcher: m,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hosts) Lookup(fqdn string) (ipv4, ipv6 []netip.Addr) {
|
||||
ips, ok := h.matcher.Match(fqdn)
|
||||
if !ok {
|
||||
return nil, nil // no such host
|
||||
}
|
||||
return ips.IPv4, ips.IPv6
|
||||
}
|
||||
|
||||
func (h *Hosts) LookupMsg(m *dns.Msg) *dns.Msg {
|
||||
if len(m.Question) != 1 {
|
||||
return nil
|
||||
}
|
||||
q := m.Question[0]
|
||||
typ := q.Qtype
|
||||
fqdn := q.Name
|
||||
if q.Qclass != dns.ClassINET || (typ != dns.TypeA && typ != dns.TypeAAAA) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipv4, ipv6 := h.Lookup(fqdn)
|
||||
if len(ipv4)+len(ipv6) == 0 {
|
||||
return nil // no such host
|
||||
}
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetReply(m)
|
||||
switch {
|
||||
case typ == dns.TypeA && len(ipv4) > 0:
|
||||
for _, ip := range ipv4 {
|
||||
rr := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: fqdn,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 10,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
r.Answer = append(r.Answer, rr)
|
||||
}
|
||||
case typ == dns.TypeAAAA && len(ipv6) > 0:
|
||||
for _, ip := range ipv6 {
|
||||
rr := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: fqdn,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 10,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}
|
||||
r.Answer = append(r.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
||||
// Append fake SOA record for empty reply.
|
||||
if len(r.Answer) == 0 {
|
||||
r.Ns = []dns.RR{dnsutils.FakeSOA(fqdn)}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
type IPs struct {
|
||||
IPv4 []netip.Addr
|
||||
IPv6 []netip.Addr
|
||||
}
|
||||
|
||||
var _ domain.ParseStringFunc[*IPs] = ParseIPs
|
||||
|
||||
func ParseIPs(s string) (string, *IPs, error) {
|
||||
f := strings.Fields(s)
|
||||
if len(f) == 0 {
|
||||
return "", nil, errors.New("empty string")
|
||||
}
|
||||
|
||||
pattern := f[0]
|
||||
v := new(IPs)
|
||||
for _, ipStr := range f[1:] {
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid ip addr %s, %w", ipStr, err)
|
||||
}
|
||||
|
||||
if ip.Is4() { // is ipv4
|
||||
v.IPv4 = append(v.IPv4, ip)
|
||||
} else { // is ipv6
|
||||
v.IPv6 = append(v.IPv6, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return pattern, v, nil
|
||||
}
|
||||
105
pkg/hosts/hosts_test.go
Normal file
105
pkg/hosts/hosts_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain"
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var test_hosts = `
|
||||
# comment
|
||||
# empty line
|
||||
dns.google 8.8.8.8 8.8.4.4 2001:4860:4860::8844 2001:4860:4860::8888
|
||||
regexp:^123456789 192.168.1.1
|
||||
test.com 1.2.3.4 # will be replaced
|
||||
test.com 2.3.4.5
|
||||
# nxdomain.com 1.2.3.4
|
||||
`
|
||||
|
||||
func Test_hostsContainer_Match(t *testing.T) {
|
||||
m := domain.NewMixMatcher[*IPs]()
|
||||
m.SetDefaultMatcher(domain.MatcherDomain)
|
||||
err := domain.LoadFromTextReader[*IPs](m, bytes.NewBuffer([]byte(test_hosts)), ParseIPs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h := NewHosts(m)
|
||||
|
||||
type args struct {
|
||||
name string
|
||||
typ uint16
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantMatched bool
|
||||
wantAddr []string
|
||||
}{
|
||||
{"matched A", args{name: "dns.google.", typ: dns.TypeA}, true, []string{"8.8.8.8", "8.8.4.4"}},
|
||||
{"matched AAAA", args{name: "dns.google.", typ: dns.TypeAAAA}, true, []string{"2001:4860:4860::8844", "2001:4860:4860::8888"}},
|
||||
{"not matched A", args{name: "nxdomain.com.", typ: dns.TypeA}, false, nil},
|
||||
{"not matched A", args{name: "sub.dns.google.", typ: dns.TypeA}, false, nil},
|
||||
{"matched regexp A", args{name: "123456789.test.", typ: dns.TypeA}, true, []string{"192.168.1.1"}},
|
||||
{"not matched regexp A", args{name: "0123456789.test.", typ: dns.TypeA}, false, nil},
|
||||
{"test replacement", args{name: "test.com.", typ: dns.TypeA}, true, []string{"2.3.4.5"}},
|
||||
{"test matched domain with mismatched type", args{name: "test.com.", typ: dns.TypeAAAA}, true, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion(tt.args.name, tt.args.typ)
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := h.LookupMsg(q)
|
||||
if tt.wantMatched && r == nil {
|
||||
t.Fatal("Lookup() should not return a nil result")
|
||||
}
|
||||
|
||||
for _, s := range tt.wantAddr {
|
||||
wantIP := net.ParseIP(s)
|
||||
if wantIP == nil {
|
||||
t.Fatal("invalid test case addr")
|
||||
}
|
||||
found := false
|
||||
for _, rr := range r.Answer {
|
||||
var ip net.IP
|
||||
switch rr := rr.(type) {
|
||||
case *dns.A:
|
||||
ip = rr.A
|
||||
case *dns.AAAA:
|
||||
ip = rr.AAAA
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if ip.Equal(wantIP) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("wanted ip is not found in response")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
41
pkg/list/elem.go
Normal file
41
pkg/list/elem.go
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package list
|
||||
|
||||
type Elem[V any] struct {
|
||||
prev, next *Elem[V]
|
||||
list *List[V]
|
||||
|
||||
Value V
|
||||
}
|
||||
|
||||
func NewElem[V any](v V) *Elem[V] {
|
||||
return &Elem[V]{
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Elem[V]) Prev() *Elem[V] {
|
||||
return e.prev
|
||||
}
|
||||
|
||||
func (e *Elem[V]) Next() *Elem[V] {
|
||||
return e.next
|
||||
}
|
||||
99
pkg/list/list.go
Normal file
99
pkg/list/list.go
Normal file
@ -0,0 +1,99 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package list
|
||||
|
||||
type List[V any] struct {
|
||||
front, back *Elem[V]
|
||||
length int
|
||||
}
|
||||
|
||||
func New[V any]() *List[V] {
|
||||
return &List[V]{}
|
||||
}
|
||||
|
||||
func mustBeFreeElem[V any](e *Elem[V]) {
|
||||
if e.prev != nil || e.next != nil || e.list != nil {
|
||||
panic("element is in use")
|
||||
}
|
||||
}
|
||||
|
||||
func (l *List[V]) Front() *Elem[V] {
|
||||
return l.front
|
||||
}
|
||||
|
||||
func (l *List[V]) Back() *Elem[V] {
|
||||
return l.back
|
||||
}
|
||||
|
||||
func (l *List[V]) Len() int {
|
||||
return l.length
|
||||
}
|
||||
|
||||
func (l *List[V]) PushFront(e *Elem[V]) *Elem[V] {
|
||||
mustBeFreeElem(e)
|
||||
l.length++
|
||||
e.list = l
|
||||
if l.front == nil {
|
||||
l.front = e
|
||||
l.back = e
|
||||
} else {
|
||||
e.next = l.front
|
||||
l.front.prev = e
|
||||
l.front = e
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func (l *List[V]) PushBack(e *Elem[V]) *Elem[V] {
|
||||
mustBeFreeElem(e)
|
||||
l.length++
|
||||
e.list = l
|
||||
if l.back == nil {
|
||||
l.front = e
|
||||
l.back = e
|
||||
} else {
|
||||
e.prev = l.back
|
||||
l.back.next = e
|
||||
l.back = e
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func (l *List[V]) PopElem(e *Elem[V]) *Elem[V] {
|
||||
if e.list != l {
|
||||
panic("elem is not belong to this list")
|
||||
}
|
||||
|
||||
l.length--
|
||||
if p := e.prev; p != nil {
|
||||
p.next = e.next
|
||||
}
|
||||
if n := e.next; n != nil {
|
||||
n.prev = e.prev
|
||||
}
|
||||
if e == l.front {
|
||||
l.front = e.next
|
||||
}
|
||||
if e == l.back {
|
||||
l.back = e.prev
|
||||
}
|
||||
e.prev, e.next, e.list = nil, nil, nil
|
||||
return e
|
||||
}
|
||||
104
pkg/list/list_test.go
Normal file
104
pkg/list/list_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package list
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func checkLinkPointers[V any](t *testing.T, l *List[V]) {
|
||||
t.Helper()
|
||||
|
||||
e := l.front
|
||||
for e != nil {
|
||||
if (e.next != nil && e.next.prev != e) || (e.prev != nil && e.prev.next != e) {
|
||||
t.Fatal("broken list")
|
||||
}
|
||||
e = e.next
|
||||
}
|
||||
}
|
||||
|
||||
func makeElems(n []int) []*Elem[int] {
|
||||
s := make([]*Elem[int], 0, len(n))
|
||||
for _, i := range n {
|
||||
s = append(s, NewElem(i))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func allValue[V any](l *List[V]) []V {
|
||||
s := make([]V, 0)
|
||||
node := l.front
|
||||
for node != nil {
|
||||
s = append(s, node.Value)
|
||||
node = node.next
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestList_Push(t *testing.T) {
|
||||
l := new(List[int])
|
||||
l.PushBack(NewElem(1))
|
||||
l.PushBack(NewElem(2))
|
||||
assert.Equal(t, []int{1, 2}, allValue(l))
|
||||
checkLinkPointers(t, l)
|
||||
|
||||
l = new(List[int])
|
||||
l.PushFront(NewElem(1))
|
||||
l.PushFront(NewElem(2))
|
||||
assert.Equal(t, []int{2, 1}, allValue(l))
|
||||
checkLinkPointers(t, l)
|
||||
}
|
||||
|
||||
func TestList_PopElem(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in []int
|
||||
pop int
|
||||
want int
|
||||
wantList []int
|
||||
}{
|
||||
{"pop front", []int{0, 1, 2}, 0, 0, []int{1, 2}},
|
||||
{"pop mid", []int{0, 1, 2}, 1, 1, []int{0, 2}},
|
||||
{"pop back", []int{0, 1, 2}, 2, 2, []int{0, 1}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := new(List[int])
|
||||
le := makeElems(tt.in)
|
||||
for _, e := range le {
|
||||
l.PushBack(e)
|
||||
}
|
||||
checkLinkPointers(t, l)
|
||||
|
||||
got := l.PopElem(le[tt.pop])
|
||||
checkLinkPointers(t, l)
|
||||
|
||||
if got.Value != tt.want {
|
||||
t.Errorf("PopElem() = %v, want %v", got.Value, tt.want)
|
||||
}
|
||||
if !reflect.DeepEqual(allValue(l), tt.wantList) {
|
||||
t.Errorf("allValue() = %v, want %v", allValue(l), tt.wantList)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
135
pkg/lru/lru.go
Normal file
135
pkg/lru/lru.go
Normal file
@ -0,0 +1,135 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package lru
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/list"
|
||||
)
|
||||
|
||||
type LRU[K comparable, V any] struct {
|
||||
maxSize int
|
||||
onEvict func(key K, v V)
|
||||
|
||||
l *list.List[KV[K, V]]
|
||||
m map[K]*list.Elem[KV[K, V]]
|
||||
}
|
||||
|
||||
type KV[K comparable, V any] struct {
|
||||
key K
|
||||
v V
|
||||
}
|
||||
|
||||
func NewLRU[K comparable, V any](maxSize int, onEvict func(key K, v V)) *LRU[K, V] {
|
||||
if maxSize <= 0 {
|
||||
panic(fmt.Sprintf("LRU: invalid max size: %d", maxSize))
|
||||
}
|
||||
|
||||
return &LRU[K, V]{
|
||||
maxSize: maxSize,
|
||||
onEvict: onEvict,
|
||||
l: list.New[KV[K, V]](),
|
||||
m: make(map[K]*list.Elem[KV[K, V]]),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *LRU[K, V]) Add(key K, v V) {
|
||||
if e, ok := q.m[key]; ok { // update existed key
|
||||
e.Value.v = v
|
||||
q.l.PushBack(q.l.PopElem(e))
|
||||
return
|
||||
}
|
||||
|
||||
o := q.Len() - q.maxSize + 1
|
||||
for o > 0 {
|
||||
key, v, _ := q.PopOldest()
|
||||
if q.onEvict != nil {
|
||||
q.onEvict(key, v)
|
||||
}
|
||||
o--
|
||||
}
|
||||
|
||||
e := list.NewElem(KV[K, V]{
|
||||
key: key,
|
||||
v: v,
|
||||
})
|
||||
q.m[key] = e
|
||||
q.l.PushBack(e)
|
||||
}
|
||||
|
||||
func (q *LRU[K, V]) Del(key K) {
|
||||
e := q.m[key]
|
||||
if e != nil {
|
||||
q.delElem(e)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *LRU[K, V]) delElem(e *list.Elem[KV[K, V]]) {
|
||||
key, v := e.Value.key, e.Value.v
|
||||
q.l.PopElem(e)
|
||||
delete(q.m, key)
|
||||
if q.onEvict != nil {
|
||||
q.onEvict(key, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *LRU[K, V]) PopOldest() (key K, v V, ok bool) {
|
||||
e := q.l.Front()
|
||||
if e != nil {
|
||||
q.l.PopElem(e)
|
||||
key, v = e.Value.key, e.Value.v
|
||||
delete(q.m, key)
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (q *LRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) {
|
||||
e := q.l.Front()
|
||||
for e != nil {
|
||||
next := e.Next() // Delete e will clean its pointers. Save it first.
|
||||
key, v := e.Value.key, e.Value.v
|
||||
if remove := f(key, v); remove {
|
||||
q.delElem(e)
|
||||
removed++
|
||||
}
|
||||
e = next
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (q *LRU[K, V]) Flush() {
|
||||
q.l = list.New[KV[K, V]]()
|
||||
q.m = make(map[K]*list.Elem[KV[K, V]])
|
||||
}
|
||||
|
||||
func (q *LRU[K, V]) Get(key K) (v V, ok bool) {
|
||||
e, ok := q.m[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
q.l.PushBack(q.l.PopElem(e))
|
||||
return e.Value.v, true
|
||||
}
|
||||
|
||||
func (q *LRU[K, V]) Len() int {
|
||||
return q.l.Len()
|
||||
}
|
||||
138
pkg/lru/lru_test.go
Normal file
138
pkg/lru/lru_test.go
Normal file
@ -0,0 +1,138 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package lru
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_lru(t *testing.T) {
|
||||
var q *LRU[int, int]
|
||||
reset := func(maxSize int) {
|
||||
t.Helper()
|
||||
q = NewLRU[int, int](maxSize, nil)
|
||||
}
|
||||
|
||||
add := func(keys ...int) {
|
||||
t.Helper()
|
||||
for _, key := range keys {
|
||||
q.Add(key, key)
|
||||
}
|
||||
}
|
||||
|
||||
mustGet := func(keys ...int) {
|
||||
t.Helper()
|
||||
for _, key := range keys {
|
||||
gotV, ok := q.Get(key)
|
||||
if !ok || gotV != key {
|
||||
t.Fatalf("want %v, got %v", key, gotV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
emptyGet := func(keys ...int) {
|
||||
t.Helper()
|
||||
for _, key := range keys {
|
||||
gotV, ok := q.Get(key)
|
||||
if ok {
|
||||
t.Fatalf("want empty, got %v", gotV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mustPopOldest := func(keys ...int) {
|
||||
t.Helper()
|
||||
for _, key := range keys {
|
||||
gotKey, gotV, ok := q.PopOldest()
|
||||
if !ok {
|
||||
t.Fatal()
|
||||
}
|
||||
if gotKey != key || gotV != gotKey {
|
||||
t.Fatalf("want key: %v, v: %v, got key: %v, v:%v", key, key, gotKey, gotV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
emptyPop := func() {
|
||||
t.Helper()
|
||||
gotKey, gotV, ok := q.PopOldest()
|
||||
if ok {
|
||||
t.Fatalf("want empty result, got key: %v, v:%v", gotKey, gotV)
|
||||
}
|
||||
}
|
||||
|
||||
checkLen := func(want int) {
|
||||
t.Helper()
|
||||
if q.l.Len() != len(q.m) {
|
||||
t.Fatalf("possible mem leak: q.l.Len() %v != len(q.m){ %v", q.l.Len(), len(q.m))
|
||||
}
|
||||
if want != q.Len() {
|
||||
t.Fatalf("want %v, got %v", want, q.Len())
|
||||
}
|
||||
}
|
||||
|
||||
// test add
|
||||
reset(4)
|
||||
add(1, 1, 1, 1, 1, 1, 2, 3)
|
||||
checkLen(3)
|
||||
mustGet(1, 2, 3)
|
||||
|
||||
// test add overflow
|
||||
reset(2)
|
||||
add(1, 2, 3, 4, 5)
|
||||
checkLen(2)
|
||||
mustGet(4, 5)
|
||||
emptyGet(1, 2, 3)
|
||||
|
||||
// test pop
|
||||
reset(3)
|
||||
add(1, 2, 3)
|
||||
mustPopOldest(1, 2, 3)
|
||||
checkLen(0)
|
||||
emptyPop()
|
||||
|
||||
// test del
|
||||
reset(3)
|
||||
add(1, 2, 3)
|
||||
q.Del(2)
|
||||
q.Del(9999)
|
||||
mustPopOldest(1, 3)
|
||||
|
||||
// test clean
|
||||
reset(4)
|
||||
add(1, 2, 3, 4)
|
||||
cleanFunc := func(key int, v int) (remove bool) {
|
||||
switch key {
|
||||
case 1, 3:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
if cleaned := q.Clean(cleanFunc); cleaned != 2 {
|
||||
t.Fatalf("q.Clean want cleaned = 2, got %v", cleaned)
|
||||
}
|
||||
mustPopOldest(2, 4)
|
||||
|
||||
// test lru
|
||||
reset(4)
|
||||
add(1, 2, 3, 4) // 1 2 3 4
|
||||
mustGet(2, 3) // 1 4 2 3
|
||||
mustPopOldest(1, 4, 2, 3)
|
||||
}
|
||||
36
pkg/matcher/domain/interface.go
Normal file
36
pkg/matcher/domain/interface.go
Normal file
@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package domain
|
||||
|
||||
// "fqdn-insensitive" means the domain in Add() and Match() call
|
||||
// is fqdn-insensitive. "google.com" and "google.com." will get
|
||||
// the same outcome.
|
||||
// The logic for case-insensitive is the same as above.
|
||||
|
||||
type Matcher[T any] interface {
|
||||
// Match matches the domain s.
|
||||
// s could be a fqdn or not, and should be case-insensitive.
|
||||
Match(s string) (v T, ok bool)
|
||||
}
|
||||
|
||||
type WriteableMatcher[T any] interface {
|
||||
Matcher[T]
|
||||
Add(pattern string, v T) error
|
||||
}
|
||||
80
pkg/matcher/domain/load_helper.go
Normal file
80
pkg/matcher/domain/load_helper.go
Normal file
@ -0,0 +1,80 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package domain
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
"io"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// ParseStringFunc parse data string to matcher pattern and additional attributions.
|
||||
type ParseStringFunc[T any] func(s string) (pattern string, v T, err error)
|
||||
|
||||
// check if s only contains a domain pattern (no other section, no space).
|
||||
func patternOnly[T any](s string) (pattern string, v T, err error) {
|
||||
if strings.IndexFunc(s, unicode.IsSpace) != -1 {
|
||||
return "", v, errors.New("rule string has more than one section")
|
||||
}
|
||||
return s, v, nil
|
||||
}
|
||||
|
||||
// Load loads data from a string, parsing it with parseString function.
|
||||
func Load[T any](m WriteableMatcher[T], s string, parseString ParseStringFunc[T]) error {
|
||||
if parseString == nil {
|
||||
parseString = patternOnly[T]
|
||||
}
|
||||
pattern, v, err := parseString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.Add(pattern, v)
|
||||
}
|
||||
|
||||
// LoadFromTextReader loads multiple lines from reader r. r
|
||||
func LoadFromTextReader[T any](m WriteableMatcher[T], r io.Reader, parseString ParseStringFunc[T]) error {
|
||||
lineCounter := 0
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
lineCounter++
|
||||
s := scanner.Text()
|
||||
s = utils.RemoveComment(s, "#")
|
||||
s = strings.TrimSpace(s)
|
||||
if len(s) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err := Load(m, s, parseString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("line %d: %v", lineCounter, err)
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func NewDomainMixMatcher() *MixMatcher[struct{}] {
|
||||
mixMatcher := NewMixMatcher[struct{}]()
|
||||
mixMatcher.SetDefaultMatcher(MatcherDomain)
|
||||
return mixMatcher
|
||||
}
|
||||
275
pkg/matcher/domain/matcher.go
Normal file
275
pkg/matcher/domain/matcher.go
Normal file
@ -0,0 +1,275 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
)
|
||||
|
||||
var _ WriteableMatcher[any] = (*MixMatcher[any])(nil)
|
||||
var _ WriteableMatcher[any] = (*SubDomainMatcher[any])(nil)
|
||||
var _ WriteableMatcher[any] = (*FullMatcher[any])(nil)
|
||||
var _ WriteableMatcher[any] = (*KeywordMatcher[any])(nil)
|
||||
var _ WriteableMatcher[any] = (*RegexMatcher[any])(nil)
|
||||
|
||||
type SubDomainMatcher[T any] struct {
|
||||
root *labelNode[T]
|
||||
}
|
||||
|
||||
func NewSubDomainMatcher[T any]() *SubDomainMatcher[T] {
|
||||
return &SubDomainMatcher[T]{root: new(labelNode[T])}
|
||||
}
|
||||
|
||||
func (m *SubDomainMatcher[T]) Match(s string) (T, bool) {
|
||||
s = NormalizeDomain(s)
|
||||
ds := NewReverseDomainScanner(s)
|
||||
currentNode := m.root
|
||||
v, ok := currentNode.getValue()
|
||||
for ds.Scan() {
|
||||
label := ds.NextLabel()
|
||||
if nextNode := currentNode.getChild(label); nextNode != nil {
|
||||
if nextNode.hasValue() {
|
||||
v, ok = nextNode.getValue()
|
||||
}
|
||||
currentNode = nextNode
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *SubDomainMatcher[T]) Len() int {
|
||||
return m.root.len()
|
||||
}
|
||||
|
||||
func (m *SubDomainMatcher[T]) Add(s string, v T) error {
|
||||
s = NormalizeDomain(s)
|
||||
ds := NewReverseDomainScanner(s)
|
||||
currentNode := m.root
|
||||
for ds.Scan() {
|
||||
label := ds.NextLabel()
|
||||
if child := currentNode.getChild(label); child != nil {
|
||||
currentNode = child
|
||||
} else {
|
||||
currentNode = currentNode.newChild(label)
|
||||
}
|
||||
}
|
||||
currentNode.storeValue(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
type FullMatcher[T any] struct {
|
||||
m map[string]T // string in is map must be a normalized domain (See NormalizeDomain).
|
||||
}
|
||||
|
||||
func NewFullMatcher[T any]() *FullMatcher[T] {
|
||||
return &FullMatcher[T]{
|
||||
m: make(map[string]T),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds domain s to this matcher, s can be a fqdn or not.
|
||||
func (m *FullMatcher[T]) Add(s string, v T) error {
|
||||
s = NormalizeDomain(s)
|
||||
m.m[s] = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *FullMatcher[T]) Match(s string) (v T, ok bool) {
|
||||
s = NormalizeDomain(s)
|
||||
v, ok = m.m[s]
|
||||
return
|
||||
}
|
||||
|
||||
func (m *FullMatcher[T]) Len() int {
|
||||
return len(m.m)
|
||||
}
|
||||
|
||||
type KeywordMatcher[T any] struct {
|
||||
kws map[string]T
|
||||
}
|
||||
|
||||
func NewKeywordMatcher[T any]() *KeywordMatcher[T] {
|
||||
return &KeywordMatcher[T]{
|
||||
kws: make(map[string]T),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *KeywordMatcher[T]) Add(keyword string, v T) error {
|
||||
keyword = NormalizeDomain(keyword) // fqdn-insensitive and case-insensitive
|
||||
m.kws[keyword] = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *KeywordMatcher[T]) Match(s string) (v T, ok bool) {
|
||||
s = NormalizeDomain(s)
|
||||
for k, v := range m.kws {
|
||||
if strings.Contains(s, k) {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
return v, false
|
||||
}
|
||||
|
||||
func (m *KeywordMatcher[T]) Len() int {
|
||||
return len(m.kws)
|
||||
}
|
||||
|
||||
// RegexMatcher contains regexp rules.
|
||||
// Note: the regexp rule is expect to match a lower-case non fqdn.
|
||||
type RegexMatcher[T any] struct {
|
||||
regs map[string]*regElem[T]
|
||||
}
|
||||
|
||||
type regElem[T any] struct {
|
||||
reg *regexp.Regexp
|
||||
v T
|
||||
}
|
||||
|
||||
func NewRegexMatcher[T any]() *RegexMatcher[T] {
|
||||
return &RegexMatcher[T]{regs: make(map[string]*regElem[T])}
|
||||
}
|
||||
|
||||
func (m *RegexMatcher[T]) Add(expr string, v T) error {
|
||||
e := m.regs[expr]
|
||||
if e == nil {
|
||||
reg, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.regs[expr] = ®Elem[T]{
|
||||
reg: reg,
|
||||
v: v,
|
||||
}
|
||||
} else {
|
||||
e.v = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *RegexMatcher[T]) Match(s string) (v T, ok bool) {
|
||||
s = NormalizeDomain(s)
|
||||
for _, e := range m.regs {
|
||||
if e.reg.MatchString(s) {
|
||||
return e.v, true
|
||||
}
|
||||
}
|
||||
var zeroT T
|
||||
return zeroT, false
|
||||
}
|
||||
|
||||
func (m *RegexMatcher[T]) Len() int {
|
||||
return len(m.regs)
|
||||
}
|
||||
|
||||
const (
|
||||
MatcherFull = "full"
|
||||
MatcherDomain = "domain"
|
||||
MatcherRegexp = "regexp"
|
||||
MatcherKeyword = "keyword"
|
||||
)
|
||||
|
||||
type MixMatcher[T any] struct {
|
||||
defaultMatcher string
|
||||
|
||||
full *FullMatcher[T]
|
||||
domain *SubDomainMatcher[T]
|
||||
regex *RegexMatcher[T]
|
||||
keyword *KeywordMatcher[T]
|
||||
}
|
||||
|
||||
func NewMixMatcher[T any]() *MixMatcher[T] {
|
||||
return &MixMatcher[T]{
|
||||
full: NewFullMatcher[T](),
|
||||
domain: NewSubDomainMatcher[T](),
|
||||
regex: NewRegexMatcher[T](),
|
||||
keyword: NewKeywordMatcher[T](),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MixMatcher[T]) SetDefaultMatcher(s string) {
|
||||
m.defaultMatcher = s
|
||||
}
|
||||
|
||||
func (m *MixMatcher[T]) GetSubMatcher(typ string) WriteableMatcher[T] {
|
||||
switch typ {
|
||||
case MatcherFull:
|
||||
return m.full
|
||||
case MatcherDomain:
|
||||
return m.domain
|
||||
case MatcherRegexp:
|
||||
return m.regex
|
||||
case MatcherKeyword:
|
||||
return m.keyword
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrNodefaultMatcher = errors.New("default matcher is not set")
|
||||
|
||||
func (m *MixMatcher[T]) Add(s string, v T) error {
|
||||
typ, pattern := m.splitTypeAndPattern(s)
|
||||
if len(typ) == 0 {
|
||||
if len(m.defaultMatcher) != 0 {
|
||||
typ = m.defaultMatcher
|
||||
} else {
|
||||
return ErrNodefaultMatcher
|
||||
}
|
||||
}
|
||||
sm := m.GetSubMatcher(typ)
|
||||
if sm == nil {
|
||||
return fmt.Errorf("unsupported match type [%s]", typ)
|
||||
}
|
||||
return sm.Add(pattern, v)
|
||||
}
|
||||
|
||||
func (m *MixMatcher[T]) Match(s string) (v T, ok bool) {
|
||||
for _, matcher := range [...]Matcher[T]{m.full, m.domain, m.regex, m.keyword} {
|
||||
if v, ok = matcher.Match(s); ok {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MixMatcher[T]) Len() int {
|
||||
sum := 0
|
||||
for _, matcher := range [...]interface{ Len() int }{m.full, m.domain, m.regex, m.keyword} {
|
||||
if matcher == nil {
|
||||
continue
|
||||
}
|
||||
sum += matcher.Len()
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func (m *MixMatcher[T]) splitTypeAndPattern(s string) (string, string) {
|
||||
typ, pattern, ok := utils.SplitString2(s, ":")
|
||||
if !ok {
|
||||
pattern = s
|
||||
}
|
||||
return typ, pattern
|
||||
}
|
||||
200
pkg/matcher/domain/matcher_test.go
Normal file
200
pkg/matcher/domain/matcher_test.go
Normal file
@ -0,0 +1,200 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package domain
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func assertFunc[T any](t *testing.T, m Matcher[T]) func(domain string, wantBool bool, wantV any) {
|
||||
return func(domain string, wantBool bool, wantV any) {
|
||||
t.Helper()
|
||||
v, ok := m.Match(domain)
|
||||
if ok != wantBool {
|
||||
t.Fatalf("%s, wantBool = %v, got = %v", domain, wantBool, ok)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(v, wantV) {
|
||||
t.Fatalf("%s, wantV = %v, got = %v", domain, wantV, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type aStr struct {
|
||||
s string
|
||||
}
|
||||
|
||||
func s(str string) *aStr {
|
||||
return &aStr{s: str}
|
||||
}
|
||||
|
||||
func (a *aStr) Append(v any) {
|
||||
a.s = a.s + v.(*aStr).s
|
||||
}
|
||||
|
||||
func TestDomainMatcher(t *testing.T) {
|
||||
m := NewSubDomainMatcher[any]()
|
||||
add := func(domain string, v any) {
|
||||
m.Add(domain, v)
|
||||
}
|
||||
assert := assertFunc[any](t, m)
|
||||
|
||||
add("cn", nil)
|
||||
assertInt(t, 1, m.Len())
|
||||
assert("cn", true, nil)
|
||||
assert("a.cn.", true, nil)
|
||||
assert("a.com", false, nil)
|
||||
add("a.b.com", nil)
|
||||
assertInt(t, 2, m.Len())
|
||||
assert("a.b.com.", true, nil)
|
||||
assert("q.w.e.r.a.b.com.", true, nil)
|
||||
assert("b.com.", false, nil)
|
||||
|
||||
// test replace
|
||||
add("append", 0)
|
||||
assertInt(t, 3, m.Len())
|
||||
assert("append.", true, 0)
|
||||
add("append.", 1)
|
||||
assert("append.", true, 1)
|
||||
add("append", nil)
|
||||
assert("append.", true, nil)
|
||||
|
||||
// test sub domain
|
||||
add("sub", 1)
|
||||
assertInt(t, 4, m.Len())
|
||||
add("a.sub", 2)
|
||||
assertInt(t, 5, m.Len())
|
||||
assert("sub", true, 1)
|
||||
assert("b.sub", true, 1)
|
||||
assert("a.sub", true, 2)
|
||||
assert("a.a.sub", true, 2)
|
||||
|
||||
// test case-insensitive
|
||||
add("UPpER", 1)
|
||||
assert("LowER.Upper", true, 1)
|
||||
|
||||
// root match
|
||||
add(".", 9)
|
||||
assert("any.domain", true, 9)
|
||||
}
|
||||
|
||||
func assertInt(t testing.TB, want, got int) {
|
||||
t.Helper()
|
||||
if want != got {
|
||||
t.Errorf("assertion failed: want %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FullMatcher(t *testing.T) {
|
||||
m := NewFullMatcher[any]()
|
||||
assert := assertFunc[any](t, m)
|
||||
add := func(domain string, v any) {
|
||||
m.Add(domain, v)
|
||||
}
|
||||
|
||||
add("cn", nil)
|
||||
assert("cn", true, nil)
|
||||
assert("a.cn", false, nil)
|
||||
add("test.test", nil)
|
||||
assert("test.test", true, nil)
|
||||
assert("test.a.test", false, nil)
|
||||
|
||||
// test replace
|
||||
add("append", 0)
|
||||
assert("append", true, 0)
|
||||
add("append", 1)
|
||||
assert("append", true, 1)
|
||||
add("append", nil)
|
||||
assert("append", true, nil)
|
||||
|
||||
assertInt(t, m.Len(), 3)
|
||||
|
||||
// test case-insensitive
|
||||
add("UPpER", 1)
|
||||
assert("Upper", true, 1)
|
||||
}
|
||||
|
||||
func Test_KeywordMatcher(t *testing.T) {
|
||||
m := NewKeywordMatcher[any]()
|
||||
add := func(domain string, v any) {
|
||||
m.Add(domain, v)
|
||||
}
|
||||
|
||||
assert := assertFunc[any](t, m)
|
||||
|
||||
add("123", s("a"))
|
||||
assert("123456.cn", true, s("a"))
|
||||
assert("111123.com", true, s("a"))
|
||||
assert("111111.cn", false, nil)
|
||||
add("example.com", nil)
|
||||
assert("sub.example.com", true, nil)
|
||||
assert("example_sub.com", false, nil)
|
||||
|
||||
// test replace
|
||||
add("append", 0)
|
||||
assert("append", true, 0)
|
||||
add("append", 1)
|
||||
assert("append", true, 1)
|
||||
add("append", nil)
|
||||
assert("append", true, nil)
|
||||
|
||||
assertInt(t, m.Len(), 3)
|
||||
|
||||
// test case-insensitive
|
||||
add("UPpER", 1)
|
||||
assert("L.Upper.U", true, 1)
|
||||
}
|
||||
|
||||
func Test_RegexMatcher(t *testing.T) {
|
||||
m := NewRegexMatcher[any]()
|
||||
add := func(expr string, v any, wantErr bool) {
|
||||
err := m.Add(expr, v)
|
||||
if (err != nil) != wantErr {
|
||||
t.Fatalf("%s: want err %v, got %v", expr, wantErr, err != nil)
|
||||
}
|
||||
}
|
||||
|
||||
assert := assertFunc[any](t, m)
|
||||
|
||||
expr := "^github-production-release-asset-[0-9a-za-z]{6}\\.s3\\.amazonaws\\.com$"
|
||||
add(expr, nil, false)
|
||||
assert("github-production-release-asset-000000.s3.amazonaws.com", true, nil)
|
||||
assert("github-production-release-asset-aaaaaa.s3.amazonaws.com", true, nil)
|
||||
assert("github-production-release-asset-aa.s3.amazonaws.com", false, nil)
|
||||
assert("prefix_github-production-release-asset-000000.s3.amazonaws.com", false, nil)
|
||||
assert("github-production-release-asset-000000.s3.amazonaws.com.suffix", false, nil)
|
||||
|
||||
expr = "^example"
|
||||
add(expr, nil, false)
|
||||
assert("example.com", true, nil)
|
||||
assert("sub.example.com", false, nil)
|
||||
|
||||
// test replace
|
||||
add("append", 0, false)
|
||||
assert("append", true, 0)
|
||||
add("append", 1, false)
|
||||
assert("append", true, 1)
|
||||
add("append", nil, false)
|
||||
assert("append", true, nil)
|
||||
|
||||
expr = "*"
|
||||
add(expr, nil, true)
|
||||
}
|
||||
116
pkg/matcher/domain/utils.go
Normal file
116
pkg/matcher/domain/utils.go
Normal file
@ -0,0 +1,116 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package domain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ReverseDomainScanner struct {
|
||||
s string // not fqdn
|
||||
p int
|
||||
t int
|
||||
}
|
||||
|
||||
func NewReverseDomainScanner(s string) *ReverseDomainScanner {
|
||||
s = TrimDot(s)
|
||||
return &ReverseDomainScanner{
|
||||
s: s,
|
||||
p: len(s),
|
||||
t: len(s),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ReverseDomainScanner) Scan() bool {
|
||||
if s.p <= 0 {
|
||||
return false
|
||||
}
|
||||
s.t = s.p
|
||||
s.p = strings.LastIndexByte(s.s[:s.p], '.')
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *ReverseDomainScanner) NextLabelOffset() int {
|
||||
return s.p + 1
|
||||
}
|
||||
|
||||
func (s *ReverseDomainScanner) NextLabel() (label string) {
|
||||
return s.s[s.p+1 : s.t]
|
||||
}
|
||||
|
||||
// NormalizeDomain normalize domain string s.
|
||||
// It removes the suffix "." and make sure the domain is in lower case.
|
||||
// e.g. a fqdn "GOOGLE.com." will become "google.com"
|
||||
func NormalizeDomain(s string) string {
|
||||
return strings.ToLower(TrimDot(s))
|
||||
}
|
||||
|
||||
// TrimDot trims suffix '.'
|
||||
func TrimDot(s string) string {
|
||||
if len(s) >= 1 && s[len(s)-1] == '.' {
|
||||
s = s[:len(s)-1]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// labelNode can store dns labels.
|
||||
type labelNode[T any] struct {
|
||||
children map[string]*labelNode[T] // lazy init
|
||||
|
||||
v T
|
||||
hasV bool
|
||||
}
|
||||
|
||||
func (n *labelNode[T]) storeValue(v T) {
|
||||
n.v = v
|
||||
n.hasV = true
|
||||
}
|
||||
|
||||
func (n *labelNode[T]) getValue() (T, bool) {
|
||||
return n.v, n.hasV
|
||||
}
|
||||
|
||||
func (n *labelNode[T]) hasValue() bool {
|
||||
return n.hasV
|
||||
}
|
||||
|
||||
func (n *labelNode[T]) newChild(key string) *labelNode[T] {
|
||||
if n.children == nil {
|
||||
n.children = make(map[string]*labelNode[T])
|
||||
}
|
||||
node := new(labelNode[T])
|
||||
n.children[key] = node
|
||||
return node
|
||||
}
|
||||
|
||||
func (n *labelNode[T]) getChild(key string) *labelNode[T] {
|
||||
return n.children[key]
|
||||
}
|
||||
|
||||
func (n *labelNode[T]) len() int {
|
||||
l := 0
|
||||
for _, node := range n.children {
|
||||
l += node.len()
|
||||
if node.hasValue() {
|
||||
l++
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
61
pkg/matcher/domain/utils_test.go
Normal file
61
pkg/matcher/domain/utils_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package domain
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDomainScanner(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fqdn string
|
||||
wantOffsets []int
|
||||
wantLabels []string
|
||||
}{
|
||||
{"empty", "", []int{}, []string{}},
|
||||
{"root", ".", []int{}, []string{}},
|
||||
{"non fqdn", "a.2", []int{2, 0}, []string{"2", "a"}},
|
||||
{"domain", "1.2.3.", []int{4, 2, 0}, []string{"3", "2", "1"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := NewReverseDomainScanner(tt.fqdn)
|
||||
gotOffsets := make([]int, 0)
|
||||
for s.Scan() {
|
||||
gotOffsets = append(gotOffsets, s.NextLabelOffset())
|
||||
}
|
||||
if !reflect.DeepEqual(gotOffsets, tt.wantOffsets) {
|
||||
t.Errorf("PrevLabelOffset() = %v, want %v", gotOffsets, tt.wantOffsets)
|
||||
}
|
||||
|
||||
s = NewReverseDomainScanner(tt.fqdn)
|
||||
gotLabels := make([]string, 0)
|
||||
for s.Scan() {
|
||||
pl := s.NextLabel()
|
||||
gotLabels = append(gotLabels, pl)
|
||||
}
|
||||
if !reflect.DeepEqual(gotLabels, tt.wantLabels) {
|
||||
t.Errorf("PrevLabel() = %v, want %v", gotLabels, tt.wantLabels)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
28
pkg/matcher/netlist/interface.go
Normal file
28
pkg/matcher/netlist/interface.go
Normal file
@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package netlist
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Matcher interface {
|
||||
Match(addr netip.Addr) bool
|
||||
}
|
||||
150
pkg/matcher/netlist/list.go
Normal file
150
pkg/matcher/netlist/list.go
Normal file
@ -0,0 +1,150 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package netlist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// List is a list of netip.Prefix. It stores all netip.Prefix in one single slice
|
||||
// and use binary search.
|
||||
// It is suitable for large static cidr search.
|
||||
type List struct {
|
||||
// stores valid and masked netip.Prefix(s)
|
||||
e []netip.Prefix
|
||||
sorted bool
|
||||
}
|
||||
|
||||
// NewList returns a *List.
|
||||
func NewList() *List {
|
||||
return &List{
|
||||
e: make([]netip.Prefix, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func mustValid(l []netip.Prefix) {
|
||||
for i, prefix := range l {
|
||||
if !prefix.IsValid() {
|
||||
panic(fmt.Sprintf("invalid prefix at #%d", i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Append appends new netip.Prefix(s) to the list.
|
||||
// This modified the list. Caller must call List.Sort() before calling List.Contains()
|
||||
func (list *List) Append(newNet ...netip.Prefix) {
|
||||
for i, n := range newNet {
|
||||
addr := to6(n.Addr())
|
||||
bits := n.Bits()
|
||||
if n.Addr().Is4() {
|
||||
bits += 96
|
||||
}
|
||||
newNet[i] = netip.PrefixFrom(addr, bits).Masked()
|
||||
}
|
||||
mustValid(newNet)
|
||||
list.e = append(list.e, newNet...)
|
||||
list.sorted = false
|
||||
}
|
||||
|
||||
// Sort sorts the list, this must be called after
|
||||
// list being modified and before calling List.Contains().
|
||||
func (list *List) Sort() {
|
||||
if list.sorted {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Sort(list)
|
||||
out := make([]netip.Prefix, 0)
|
||||
for i, n := range list.e {
|
||||
if i == 0 {
|
||||
out = append(out, n)
|
||||
} else {
|
||||
lv := &out[len(out)-1]
|
||||
switch {
|
||||
case n.Addr() == lv.Addr():
|
||||
if n.Bits() < lv.Bits() {
|
||||
*lv = n
|
||||
}
|
||||
case !lv.Contains(n.Addr()):
|
||||
out = append(out, n)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
list.e = out
|
||||
list.sorted = true
|
||||
}
|
||||
|
||||
// Len implements sort Interface.
|
||||
func (list *List) Len() int {
|
||||
return len(list.e)
|
||||
}
|
||||
|
||||
// Less implements sort Interface.
|
||||
func (list *List) Less(i, j int) bool {
|
||||
return list.e[i].Addr().Less(list.e[j].Addr())
|
||||
}
|
||||
|
||||
// Swap implements sort Interface.
|
||||
func (list *List) Swap(i, j int) {
|
||||
list.e[i], list.e[j] = list.e[j], list.e[i]
|
||||
}
|
||||
|
||||
func (list *List) Match(addr netip.Addr) bool {
|
||||
return list.Contains(addr)
|
||||
}
|
||||
|
||||
// Contains reports whether the list includes the given netip.Addr.
|
||||
func (list *List) Contains(addr netip.Addr) bool {
|
||||
if !list.sorted {
|
||||
panic("list is not sorted")
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
addr = to6(addr)
|
||||
|
||||
i, j := 0, len(list.e)
|
||||
for i < j {
|
||||
h := int(uint(i+j) >> 1) // avoid overflow when computing h
|
||||
if list.e[h].Addr().Compare(addr) <= 0 {
|
||||
i = h + 1
|
||||
} else {
|
||||
j = h
|
||||
}
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return list.e[i-1].Contains(addr)
|
||||
}
|
||||
|
||||
func to6(addr netip.Addr) netip.Addr {
|
||||
if addr.Is6() {
|
||||
return addr
|
||||
}
|
||||
return netip.AddrFrom16(addr.As16())
|
||||
}
|
||||
77
pkg/matcher/netlist/load_helper.go
Normal file
77
pkg/matcher/netlist/load_helper.go
Normal file
@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package netlist
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LoadFromReader loads IP list from a reader.
|
||||
// It might modify the List and causes List unsorted.
|
||||
func LoadFromReader(l *List, reader io.Reader) error {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
|
||||
// count how many lines we have read.
|
||||
lineCounter := 0
|
||||
for scanner.Scan() {
|
||||
lineCounter++
|
||||
s := scanner.Text()
|
||||
s = strings.TrimSpace(s)
|
||||
s = utils.RemoveComment(s, "#")
|
||||
s = utils.RemoveComment(s, " ")
|
||||
if len(s) == 0 {
|
||||
continue
|
||||
}
|
||||
err := LoadFromText(l, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid data at line #%d: %w", lineCounter, err)
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// LoadFromText loads an IP from s.
|
||||
// It might modify the List and causes List unsorted.
|
||||
func LoadFromText(l *List, s string) error {
|
||||
if strings.ContainsRune(s, '/') {
|
||||
ipNet, err := netip.ParsePrefix(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.Append(ipNet)
|
||||
return nil
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bits := 32
|
||||
if addr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
l.Append(netip.PrefixFrom(addr, bits))
|
||||
return nil
|
||||
}
|
||||
118
pkg/matcher/netlist/netlist_test.go
Normal file
118
pkg/matcher/netlist/netlist_test.go
Normal file
@ -0,0 +1,118 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package netlist
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPNetList_Sort_And_Merge(t *testing.T) {
|
||||
raw := `
|
||||
192.168.0.0/32 # merged
|
||||
192.168.0.0/24 # merged
|
||||
192.168.0.0/16
|
||||
192.168.1.1/24 # merged
|
||||
192.168.9.24/24 # merged
|
||||
192.168.3.0/24 # merged
|
||||
192.169.0.0/16
|
||||
104.16.0.0/12
|
||||
`
|
||||
ipNetList := NewList()
|
||||
err := LoadFromReader(ipNetList, bytes.NewBufferString(raw))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ipNetList.Sort()
|
||||
|
||||
if ipNetList.Len() != 3 {
|
||||
t.Fatalf("unexpected length %d", ipNetList.Len())
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testIP netip.Addr
|
||||
want bool
|
||||
}{
|
||||
{"0", netip.MustParseAddr("192.167.255.255"), false},
|
||||
{"1", netip.MustParseAddr("192.168.0.0"), true},
|
||||
{"2", netip.MustParseAddr("192.168.1.1"), true},
|
||||
{"3", netip.MustParseAddr("192.168.9.255"), true},
|
||||
{"4", netip.MustParseAddr("192.168.255.255"), true},
|
||||
{"5", netip.MustParseAddr("192.169.1.1"), true},
|
||||
{"6", netip.MustParseAddr("192.170.1.1"), false},
|
||||
{"7", netip.MustParseAddr("1.1.1.1"), false},
|
||||
{"8", netip.MustParseAddr("104.16.67.38"), true},
|
||||
{"9", netip.MustParseAddr("104.32.67.38"), false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ipNetList.Match(tt.testIP); got != tt.want {
|
||||
t.Errorf("IPNetList.Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPNetList_New_And_Contains(t *testing.T) {
|
||||
raw := `
|
||||
# comment line
|
||||
1.0.0.0/24 additional strings should be ignored
|
||||
2.0.0.0/23 # comment
|
||||
3.0.0.0
|
||||
|
||||
2000:0000::/32
|
||||
2000:2000::1
|
||||
`
|
||||
|
||||
ipNetList := NewList()
|
||||
err := LoadFromReader(ipNetList, bytes.NewBufferString(raw))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ipNetList.Sort()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testIP netip.Addr
|
||||
want bool
|
||||
}{
|
||||
{"", netip.MustParseAddr("1.0.0.0"), true},
|
||||
{"", netip.MustParseAddr("1.0.0.1"), true},
|
||||
{"", netip.MustParseAddr("1.0.1.0"), false},
|
||||
{"", netip.MustParseAddr("2.0.0.0"), true},
|
||||
{"", netip.MustParseAddr("2.0.1.255"), true},
|
||||
{"", netip.MustParseAddr("2.0.2.0"), false},
|
||||
{"", netip.MustParseAddr("3.0.0.0"), true},
|
||||
{"", netip.MustParseAddr("2000:0000::"), true},
|
||||
{"", netip.MustParseAddr("2000:0000::1"), true},
|
||||
{"", netip.MustParseAddr("2000:0000:1::"), true},
|
||||
{"", netip.MustParseAddr("2000:0001::"), false},
|
||||
{"", netip.MustParseAddr("2000:2000::1"), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ipNetList.Match(tt.testIP); got != tt.want {
|
||||
t.Errorf("IPNetList.Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
154
pkg/nftset_utils/handler.go
Normal file
154
pkg/nftset_utils/handler.go
Normal file
@ -0,0 +1,154 @@
|
||||
//go:build linux
|
||||
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package nftset_utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClosed = errors.New("closed handler")
|
||||
)
|
||||
|
||||
// NftSetHandler can add netip.Prefix to the corresponding set.
|
||||
// The table that contains this set must be an inet family table.
|
||||
// If the set has a 'interval' flag, the prefix from netip.Prefix will be
|
||||
// applied.
|
||||
type NftSetHandler struct {
|
||||
opts HandlerOpts
|
||||
|
||||
m sync.Mutex
|
||||
closed bool
|
||||
lastUpdate time.Time
|
||||
set *nftables.Set
|
||||
lastingConn *nftables.Conn // Note: lasting conn is not concurrent safe so m is required.
|
||||
|
||||
disableSetCache bool // for test only
|
||||
}
|
||||
|
||||
type HandlerOpts struct {
|
||||
TableFamily nftables.TableFamily
|
||||
TableName string
|
||||
SetName string
|
||||
}
|
||||
|
||||
// NewNtSetHandler inits NftSetHandler.
|
||||
func NewNtSetHandler(opts HandlerOpts) *NftSetHandler {
|
||||
return &NftSetHandler{
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// getSetLocked get set info from kernel. It has an internal cache and won't
|
||||
// invoke a syscall every time.
|
||||
func (h *NftSetHandler) getSetLocked() (*nftables.Set, error) {
|
||||
const refreshInterval = time.Second
|
||||
|
||||
now := time.Now()
|
||||
if !h.disableSetCache && h.set != nil && now.Sub(h.lastUpdate) < refreshInterval {
|
||||
return h.set, nil
|
||||
}
|
||||
|
||||
// Note: GetSetByName is not concurrent safe.
|
||||
set, err := h.lastingConn.GetSetByName(&nftables.Table{Name: h.opts.TableName, Family: h.opts.TableFamily}, h.opts.SetName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.set = set
|
||||
h.lastUpdate = now
|
||||
return set, nil
|
||||
}
|
||||
|
||||
// AddElems adds netip.Prefix(s) to set in a single batch.
|
||||
func (h *NftSetHandler) AddElems(es ...netip.Prefix) error {
|
||||
h.m.Lock()
|
||||
defer h.m.Unlock()
|
||||
|
||||
if h.closed {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
if h.lastingConn == nil {
|
||||
c, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open netlink, %w", err)
|
||||
}
|
||||
h.lastingConn = c
|
||||
}
|
||||
|
||||
set, err := h.getSetLocked()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get set, %w", err)
|
||||
}
|
||||
|
||||
var elems []nftables.SetElement
|
||||
if set.Interval {
|
||||
elems = make([]nftables.SetElement, 0, 2*len(es))
|
||||
} else {
|
||||
elems = make([]nftables.SetElement, 0, len(es))
|
||||
}
|
||||
|
||||
for i, e := range es {
|
||||
if !e.IsValid() {
|
||||
return fmt.Errorf("invalid prefix at index %d", i)
|
||||
}
|
||||
if set.Interval {
|
||||
start := e.Masked().Addr()
|
||||
elems = append(elems, nftables.SetElement{Key: start.AsSlice(), IntervalEnd: false})
|
||||
|
||||
end := netipx.PrefixLastIP(e).Next() // may be invalid if end is overflowed
|
||||
if end.IsValid() {
|
||||
elems = append(elems, nftables.SetElement{Key: end.AsSlice(), IntervalEnd: true})
|
||||
}
|
||||
} else {
|
||||
elems = append(elems, nftables.SetElement{Key: e.Addr().AsSlice()})
|
||||
}
|
||||
}
|
||||
|
||||
err = h.lastingConn.SetAddElements(set, elems)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.lastingConn.Flush()
|
||||
}
|
||||
|
||||
func (h *NftSetHandler) Close() error {
|
||||
h.m.Lock()
|
||||
defer h.m.Unlock()
|
||||
|
||||
if h.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.closed = true
|
||||
if h.lastingConn != nil {
|
||||
return h.lastingConn.CloseLasting()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
96
pkg/nftset_utils/handler_test.go
Normal file
96
pkg/nftset_utils/handler_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
//go:build linux
|
||||
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package nftset_utils
|
||||
|
||||
import (
|
||||
"github.com/google/nftables"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func skipCI(t *testing.T) {
|
||||
if os.Getenv("TEST_NFTSET") == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
}
|
||||
|
||||
func prepareSet(t testing.TB, tableName, setName string, interval bool) {
|
||||
t.Helper()
|
||||
nc, err := nftables.New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
table := &nftables.Table{Name: tableName, Family: nftables.TableFamilyINet}
|
||||
nc.AddTable(table)
|
||||
if err := nc.AddSet(&nftables.Set{Name: setName, Table: table, KeyType: nftables.TypeIPAddr, Interval: interval}, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := nc.Flush(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AddElems(t *testing.T) {
|
||||
skipCI(t)
|
||||
n := "test"
|
||||
prepareSet(t, n, n, false)
|
||||
|
||||
h := NewNtSetHandler(HandlerOpts{
|
||||
TableFamily: nftables.TableFamilyINet,
|
||||
TableName: n,
|
||||
SetName: n,
|
||||
})
|
||||
h.disableSetCache = true
|
||||
|
||||
if err := h.AddElems(netip.MustParsePrefix("127.0.0.1/24")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nc, err := nftables.New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
elems, err := nc.GetSetElements(h.set)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(elems) == 0 {
|
||||
t.Fatal("set is empty")
|
||||
}
|
||||
|
||||
// test concurrent safe.
|
||||
wg := new(sync.WaitGroup)
|
||||
for i := 0; i < 512; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := h.AddElems(netip.MustParsePrefix("127.0.0.1/24")); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
30
pkg/pool/allocator.go
Normal file
30
pkg/pool/allocator.go
Normal file
@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pool
|
||||
|
||||
import (
|
||||
bytesPool "github.com/IrineSistiana/go-bytes-pool"
|
||||
)
|
||||
|
||||
var (
|
||||
_pool = bytesPool.NewPool(20) // 1Mb pool, should be enough.
|
||||
GetBuf = _pool.Get
|
||||
ReleaseBuf = _pool.Release
|
||||
)
|
||||
53
pkg/pool/bytes_buf.go
Normal file
53
pkg/pool/bytes_buf.go
Normal file
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type BytesBufPool struct {
|
||||
p sync.Pool
|
||||
}
|
||||
|
||||
func NewBytesBufPool(initSize int) *BytesBufPool {
|
||||
if initSize < 0 {
|
||||
panic(fmt.Sprintf("utils.NewBytesBufPool: negative init size %d", initSize))
|
||||
}
|
||||
|
||||
return &BytesBufPool{
|
||||
p: sync.Pool{New: func() any {
|
||||
b := new(bytes.Buffer)
|
||||
b.Grow(initSize)
|
||||
return b
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BytesBufPool) Get() *bytes.Buffer {
|
||||
return p.p.Get().(*bytes.Buffer)
|
||||
}
|
||||
|
||||
func (p *BytesBufPool) Release(b *bytes.Buffer) {
|
||||
b.Reset()
|
||||
p.p.Put(b)
|
||||
}
|
||||
69
pkg/pool/msg_buf.go
Normal file
69
pkg/pool/msg_buf.go
Normal file
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pool
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// There is no such way to give dns.Msg.PackBuffer() a buffer
|
||||
// with a proper size.
|
||||
// Just give it a big buf and hope the buf will be reused in most scenes.
|
||||
const packBufferSize = 8191
|
||||
|
||||
// PackBuffer packs the dns msg m to wire format.
|
||||
// Callers should release the buf by calling ReleaseBuf after they have done
|
||||
// with the wire []byte.
|
||||
func PackBuffer(m *dns.Msg) (*[]byte, error) {
|
||||
packBuf := GetBuf(packBufferSize)
|
||||
defer ReleaseBuf(packBuf)
|
||||
wire, err := m.PackBuffer(*packBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgBuf := GetBuf(len(wire))
|
||||
copy(*msgBuf, wire)
|
||||
return msgBuf, nil
|
||||
}
|
||||
|
||||
// PackBuffer packs the dns msg m to wire format, with to bytes length header.
|
||||
// Callers should release the buf by calling ReleaseBuf.
|
||||
func PackTCPBuffer(m *dns.Msg) (*[]byte, error) {
|
||||
packBuf := GetBuf(packBufferSize)
|
||||
defer ReleaseBuf(packBuf)
|
||||
wire, err := m.PackBuffer((*packBuf)[2:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := len(wire)
|
||||
if l > dns.MaxMsgSize {
|
||||
return nil, fmt.Errorf("dns payload size %d is too large", l)
|
||||
}
|
||||
|
||||
msgBuf := GetBuf(2 + len(wire))
|
||||
binary.BigEndian.PutUint16(*msgBuf, uint16(l))
|
||||
copy((*msgBuf)[2:], wire)
|
||||
return msgBuf, nil
|
||||
}
|
||||
60
pkg/pool/timer.go
Normal file
60
pkg/pool/timer.go
Normal file
@ -0,0 +1,60 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
timerPool = sync.Pool{}
|
||||
)
|
||||
|
||||
func GetTimer(t time.Duration) *time.Timer {
|
||||
timer, ok := timerPool.Get().(*time.Timer)
|
||||
if !ok {
|
||||
return time.NewTimer(t)
|
||||
}
|
||||
if timer.Reset(t) {
|
||||
panic("dispatcher.go getTimer: active timer trapped in timerPool")
|
||||
}
|
||||
return timer
|
||||
}
|
||||
|
||||
func ReleaseTimer(timer *time.Timer) {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timerPool.Put(timer)
|
||||
}
|
||||
|
||||
func ResetAndDrainTimer(timer *time.Timer, d time.Duration) {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(d)
|
||||
}
|
||||
312
pkg/query_context/context.go
Normal file
312
pkg/query_context/context.go
Normal file
@ -0,0 +1,312 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package query_context
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/server"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
const (
|
||||
edns0Size = 1200
|
||||
)
|
||||
|
||||
// Context is a query context that pass through plugins.
|
||||
// All Context funcs are not safe for concurrent use.
|
||||
type Context struct {
|
||||
id uint32
|
||||
startTime time.Time
|
||||
|
||||
// ServerMeta contains some meta info from the server.
|
||||
// It is read-only.
|
||||
ServerMeta ServerMeta
|
||||
query *dns.Msg // always has one question.
|
||||
clientOpt *dns.OPT // may be nil
|
||||
|
||||
resp *dns.Msg
|
||||
respOpt *dns.OPT // nil if clientOpt == nil
|
||||
upstreamOpt *dns.OPT // may be nil
|
||||
|
||||
// lazy init.
|
||||
kv map[uint32]any
|
||||
marks map[uint32]struct{}
|
||||
}
|
||||
|
||||
var contextUid atomic.Uint32
|
||||
|
||||
type ServerMeta = server.QueryMeta
|
||||
|
||||
// NewContext creates a new query Context.
|
||||
// q must have one question.
|
||||
// NewContext takes the ownership of q.
|
||||
func NewContext(q *dns.Msg) *Context {
|
||||
ctx := &Context{
|
||||
id: contextUid.Add(1),
|
||||
startTime: time.Now(),
|
||||
query: q,
|
||||
clientOpt: addNewAndSwapOldOpt(q),
|
||||
}
|
||||
if ctx.clientOpt != nil {
|
||||
ctx.respOpt = newOpt()
|
||||
|
||||
// RFC 3225 3
|
||||
// The DO bit of the query MUST be copied in the response.
|
||||
if ctx.clientOpt.Do() {
|
||||
setDo(ctx.respOpt, true)
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Id returns the Context id.
|
||||
// Note: This id is not the dns msg id.
|
||||
// It's a unique uint32 growing with the number of query.
|
||||
func (ctx *Context) Id() uint32 {
|
||||
return ctx.id
|
||||
}
|
||||
|
||||
// StartTime returns the time when the Context was created.
|
||||
func (ctx *Context) StartTime() time.Time {
|
||||
return ctx.startTime
|
||||
}
|
||||
|
||||
// Q returns the query msg that will be forward to upstream.
|
||||
// It always returns a non-nil msg with one question and EDNS0 OPT.
|
||||
// If Caller want to modify the msg, be sure not to break those conditions.
|
||||
func (ctx *Context) Q() *dns.Msg {
|
||||
return ctx.query
|
||||
}
|
||||
|
||||
// QQuestion returns the query question.
|
||||
func (ctx *Context) QQuestion() dns.Question {
|
||||
return ctx.query.Question[0]
|
||||
}
|
||||
|
||||
// QOpt returns the query opt. It always returns a non-nil opt.
|
||||
// It's a helper func for searching opt in Q() manually.
|
||||
func (ctx *Context) QOpt() *dns.OPT {
|
||||
opt := findOpt(ctx.query)
|
||||
ctx.query.IsEdns0()
|
||||
if opt == nil {
|
||||
panic("query opt is missing")
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
// ClientOpt returns the OPT rr from client. Maybe nil, if client does not send it.
|
||||
// Plugins that responsible for handling EDNS0 option should
|
||||
// check ClientOpt and pick/add options into Q() on demand.
|
||||
// The OPT is read-only.
|
||||
func (ctx *Context) ClientOpt() *dns.OPT {
|
||||
return ctx.clientOpt
|
||||
}
|
||||
|
||||
// SetResponse sets m as response. It takes the ownership of m.
|
||||
// If m is nil. It removes existing response.
|
||||
func (ctx *Context) SetResponse(m *dns.Msg) {
|
||||
ctx.resp = m
|
||||
if m == nil {
|
||||
ctx.upstreamOpt = nil
|
||||
} else {
|
||||
ctx.upstreamOpt = popOpt(m)
|
||||
}
|
||||
}
|
||||
|
||||
// R returns the response that will be sent to client. It might be nil.
|
||||
// Note: R does not have EDNS0. Caller MUST NOT add a dns.OPT into R.
|
||||
// Use RespOpt() instead.
|
||||
func (ctx *Context) R() *dns.Msg {
|
||||
return ctx.resp
|
||||
}
|
||||
|
||||
// RespOpt returns the OPT that will be sent to client.
|
||||
// If client support EDNS0, then RespOpt always returns a non-nil OPT.
|
||||
// No matter what R() returns.
|
||||
// Otherwise, RespOpt returns nil.
|
||||
func (ctx *Context) RespOpt() *dns.OPT {
|
||||
return ctx.respOpt
|
||||
}
|
||||
|
||||
// UpstreamOpt returns the OPT from upstream. May be nil.
|
||||
// Plugins that responsible for handling EDNS0 option should
|
||||
// check UpstreamOpt and pick/add options into RespOpt on demand.
|
||||
// The OPT is read-only.
|
||||
func (ctx *Context) UpstreamOpt() *dns.OPT {
|
||||
return ctx.upstreamOpt
|
||||
}
|
||||
|
||||
// InfoField returns a zap.Field contains a brief summary of this Context.
|
||||
// Useful in log.
|
||||
func (ctx *Context) InfoField() zap.Field {
|
||||
return zap.Object("query", ctx)
|
||||
}
|
||||
|
||||
// Copy deep copies this Context.
|
||||
// See CopyTo.
|
||||
func (ctx *Context) Copy() *Context {
|
||||
newCtx := new(Context)
|
||||
ctx.CopyTo(newCtx)
|
||||
return newCtx
|
||||
}
|
||||
|
||||
// CopyTo deep copies this Context to d.
|
||||
// Note that values that stored by StoreValue is not deep-copied.
|
||||
func (ctx *Context) CopyTo(d *Context) *Context {
|
||||
d.id = ctx.id
|
||||
d.startTime = ctx.startTime
|
||||
|
||||
d.ServerMeta = ctx.ServerMeta
|
||||
d.query = ctx.query.Copy()
|
||||
d.clientOpt = ctx.clientOpt
|
||||
|
||||
if ctx.resp != nil {
|
||||
d.resp = ctx.resp.Copy()
|
||||
}
|
||||
if ctx.respOpt != nil {
|
||||
d.respOpt = dns.Copy(ctx.respOpt).(*dns.OPT)
|
||||
}
|
||||
d.upstreamOpt = ctx.upstreamOpt
|
||||
|
||||
d.kv = copyMap(ctx.kv)
|
||||
d.marks = copyMap(ctx.marks)
|
||||
return d
|
||||
}
|
||||
|
||||
// StoreValue stores any v in to this Context
|
||||
// k MUST from RegKey.
|
||||
func (ctx *Context) StoreValue(k uint32, v any) {
|
||||
if ctx.kv == nil {
|
||||
ctx.kv = make(map[uint32]any)
|
||||
}
|
||||
ctx.kv[k] = v
|
||||
}
|
||||
|
||||
// GetValue returns the value stored by StoreValue.
|
||||
func (ctx *Context) GetValue(k uint32) (any, bool) {
|
||||
v, ok := ctx.kv[k]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// DeleteValue deletes value k from Context
|
||||
func (ctx *Context) DeleteValue(k uint32) {
|
||||
delete(ctx.kv, k)
|
||||
}
|
||||
|
||||
// SetMark marks this Context with given mark.
|
||||
func (ctx *Context) SetMark(m uint32) {
|
||||
if ctx.marks == nil {
|
||||
ctx.marks = make(map[uint32]struct{})
|
||||
}
|
||||
ctx.marks[m] = struct{}{}
|
||||
}
|
||||
|
||||
// HasMark reports whether this mark m was marked by SetMark.
|
||||
func (ctx *Context) HasMark(m uint32) bool {
|
||||
_, ok := ctx.marks[m]
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeleteMark deletes mark m from this Context.
|
||||
func (ctx *Context) DeleteMark(m uint32) {
|
||||
delete(ctx.marks, m)
|
||||
}
|
||||
|
||||
// MarshalLogObject implements zapcore.ObjectMarshaler.
|
||||
func (ctx *Context) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
|
||||
encoder.AddUint32("uqid", ctx.id)
|
||||
|
||||
if clientAddr := ctx.ServerMeta.ClientAddr; clientAddr.IsValid() {
|
||||
zap.Stringer("client", clientAddr).AddTo(encoder)
|
||||
}
|
||||
|
||||
question := ctx.query.Question[0]
|
||||
encoder.AddString("qname", question.Name)
|
||||
encoder.AddUint16("qtype", question.Qtype)
|
||||
encoder.AddUint16("qclass", question.Qclass)
|
||||
|
||||
if r := ctx.resp; r != nil {
|
||||
encoder.AddInt("rcode", r.Rcode)
|
||||
}
|
||||
encoder.AddDuration("elapsed", time.Since(ctx.startTime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyMap[K comparable, V any](m map[K]V) map[K]V {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
cm := make(map[K]V, len(m))
|
||||
for k, v := range m {
|
||||
cm[k] = v
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
func addNewAndSwapOldOpt(m *dns.Msg) *dns.OPT {
|
||||
for i := len(m.Extra) - 1; i >= 0; i-- {
|
||||
// If m has oldOpt
|
||||
if oldOpt, ok := m.Extra[i].(*dns.OPT); ok {
|
||||
// replace it directly
|
||||
m.Extra[i] = newOpt()
|
||||
return oldOpt
|
||||
}
|
||||
}
|
||||
m.Extra = append(m.Extra, newOpt())
|
||||
return nil
|
||||
}
|
||||
|
||||
func popOpt(m *dns.Msg) *dns.OPT {
|
||||
for i := len(m.Extra) - 1; i >= 0; i-- {
|
||||
if opt, ok := m.Extra[i].(*dns.OPT); ok {
|
||||
m.Extra = append(m.Extra[:i], m.Extra[i+1:]...)
|
||||
return opt
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findOpt(m *dns.Msg) *dns.OPT {
|
||||
for i := len(m.Extra) - 1; i >= 0; i-- {
|
||||
if opt, ok := m.Extra[i].(*dns.OPT); ok {
|
||||
return opt
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newOpt() *dns.OPT {
|
||||
opt := new(dns.OPT)
|
||||
opt.Hdr.Name = "."
|
||||
opt.Hdr.Rrtype = dns.TypeOPT
|
||||
opt.SetUDPSize(edns0Size)
|
||||
return opt
|
||||
}
|
||||
|
||||
func setDo(opt *dns.OPT, do bool) {
|
||||
const doBit = 1 << 15 // DNSSEC OK
|
||||
if do {
|
||||
opt.Hdr.Ttl |= doBit
|
||||
}
|
||||
}
|
||||
35
pkg/query_context/kv.go
Normal file
35
pkg/query_context/kv.go
Normal file
@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package query_context
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
var kId atomic.Uint32
|
||||
|
||||
// RegKey returns a unique uint32 for the key used in
|
||||
// Context.StoreValue, Context.GetValue.
|
||||
// It should only be called during initialization.
|
||||
func RegKey() uint32 {
|
||||
i := kId.Add(1)
|
||||
if i == 0 {
|
||||
panic("key id overflowed")
|
||||
}
|
||||
return i
|
||||
}
|
||||
153
pkg/rate_limiter/rate_limiter.go
Normal file
153
pkg/rate_limiter/rate_limiter.go
Normal file
@ -0,0 +1,153 @@
|
||||
package rate_limiter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
tableShards = 32
|
||||
gcInterval = time.Minute
|
||||
)
|
||||
|
||||
type Limiter struct {
|
||||
// Limit and Burst are read-only.
|
||||
Limit rate.Limit
|
||||
Burst int
|
||||
|
||||
closeOnce sync.Once
|
||||
closeNotify chan struct{}
|
||||
tables [tableShards]*tableShard
|
||||
}
|
||||
|
||||
type tableShard struct {
|
||||
m sync.Mutex
|
||||
table map[netip.Addr]*limiterEntry
|
||||
}
|
||||
|
||||
type limiterEntry struct {
|
||||
l *rate.Limiter
|
||||
lastSeen time.Time
|
||||
sync.Once
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new client rate limiter.
|
||||
// limit and burst should be greater than zero. See rate.Limiter for more
|
||||
// details.
|
||||
// Limiter has a internal gc which will run and remove old client entries every 1m.
|
||||
// If the token refill time (burst/limit) is greater than 1m,
|
||||
// the actual average qps limit may be higher than expected because the client status
|
||||
// may be deleted and re-initialized.
|
||||
func NewRateLimiter(limit rate.Limit, burst int) *Limiter {
|
||||
l := &Limiter{
|
||||
Limit: limit,
|
||||
Burst: burst,
|
||||
closeNotify: make(chan struct{}),
|
||||
}
|
||||
|
||||
for i := range l.tables {
|
||||
l.tables[i] = &tableShard{table: make(map[netip.Addr]*limiterEntry)}
|
||||
}
|
||||
|
||||
go l.gcLoop(gcInterval)
|
||||
return l
|
||||
}
|
||||
|
||||
// maskedUnmappedP must be a masked prefix and contain a unmapped addr.
|
||||
func (l *Limiter) Allow(unmappedAddr netip.Addr) bool {
|
||||
now := time.Now()
|
||||
shard := l.getTableShard(unmappedAddr)
|
||||
shard.m.Lock()
|
||||
e, ok := shard.table[unmappedAddr]
|
||||
if !ok {
|
||||
e = &limiterEntry{
|
||||
l: rate.NewLimiter(l.Limit, l.Burst),
|
||||
lastSeen: now,
|
||||
}
|
||||
shard.table[unmappedAddr] = e
|
||||
}
|
||||
e.lastSeen = now
|
||||
shard.m.Unlock()
|
||||
clientLimiter := e.l
|
||||
return clientLimiter.AllowN(now, 1)
|
||||
}
|
||||
|
||||
func (l *Limiter) Close() error {
|
||||
l.closeOnce.Do(func() {
|
||||
close(l.closeNotify)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Limiter) gcLoop(gcInterval time.Duration) {
|
||||
ticker := time.NewTicker(gcInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.closeNotify:
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
l.doGc(now, gcInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) doGc(now time.Time, gcInterval time.Duration) {
|
||||
for _, shard := range l.tables {
|
||||
shard.m.Lock()
|
||||
for a, e := range shard.table {
|
||||
if now.Sub(e.lastSeen) > gcInterval {
|
||||
delete(shard.table, a)
|
||||
}
|
||||
}
|
||||
shard.m.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) getTableShard(unmappedAddr netip.Addr) *tableShard {
|
||||
return l.tables[getTableShardIdx(unmappedAddr)]
|
||||
}
|
||||
|
||||
func (l *Limiter) ForEach(doFunc func(unmappedAddr netip.Addr, r *rate.Limiter) (doBreak bool)) (doBreak bool) {
|
||||
for _, shard := range l.tables {
|
||||
shard.m.Lock()
|
||||
for a, e := range shard.table {
|
||||
doBreak = doFunc(a, e.l)
|
||||
if doBreak {
|
||||
shard.m.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
shard.m.Unlock()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Len returns current number of entries in the Limiter.
|
||||
func (l *Limiter) Len() int {
|
||||
n := 0
|
||||
for _, shard := range l.tables {
|
||||
shard.m.Lock()
|
||||
n += len(shard.table)
|
||||
shard.m.Unlock()
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func getTableShardIdx(unmappedAddr netip.Addr) int {
|
||||
var i byte
|
||||
if unmappedAddr.Is4() {
|
||||
for _, b := range unmappedAddr.As4() {
|
||||
i ^= b
|
||||
}
|
||||
} else {
|
||||
for _, b := range unmappedAddr.As16() {
|
||||
i ^= b
|
||||
}
|
||||
}
|
||||
return int(i % tableShards)
|
||||
}
|
||||
20
pkg/rate_limiter/rate_limiter_test.go
Normal file
20
pkg/rate_limiter/rate_limiter_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package rate_limiter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func BenchmarkXxx(b *testing.B) {
|
||||
now := time.Now()
|
||||
var l *limiterEntry
|
||||
for i := 0; i < b.N; i++ {
|
||||
l = &limiterEntry{
|
||||
l: rate.NewLimiter(0, 0),
|
||||
lastSeen: now,
|
||||
}
|
||||
}
|
||||
_ = l
|
||||
}
|
||||
84
pkg/safe_close/safe_close.go
Normal file
84
pkg/safe_close/safe_close.go
Normal file
@ -0,0 +1,84 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package safe_close
|
||||
|
||||
import "sync"
|
||||
|
||||
// SafeClose can achieve safe close where WaitClosed returns only after
|
||||
// all sub goroutines exited.
|
||||
//
|
||||
// 1. Main service goroutine starts and wait on ReceiveCloseSignal.
|
||||
// 2. Any service's sub goroutine should be started by Attach and wait on ReceiveCloseSignal.
|
||||
// 3. If any fatal err occurs, any service goroutine can call SendCloseSignal to close the service.
|
||||
// 4. Any third party caller can call SendCloseSignal to close the service.
|
||||
type SafeClose struct {
|
||||
m sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
closeSignal chan struct{}
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func NewSafeClose() *SafeClose {
|
||||
return &SafeClose{
|
||||
closeSignal: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// WaitClosed waits until all SendCloseSignal is called and all
|
||||
// attached funcs in SafeClose are done.
|
||||
func (s *SafeClose) WaitClosed() error {
|
||||
<-s.closeSignal
|
||||
s.wg.Wait()
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
// SendCloseSignal sends a close signal. Unblock WaitClosed.
|
||||
// The given error will be read by WaitClosed.
|
||||
// Once SendCloseSignal is called, following calls are noop.
|
||||
func (s *SafeClose) SendCloseSignal(err error) {
|
||||
s.m.Lock()
|
||||
select {
|
||||
case <-s.closeSignal:
|
||||
default:
|
||||
s.closeErr = err
|
||||
close(s.closeSignal)
|
||||
}
|
||||
s.m.Unlock()
|
||||
}
|
||||
|
||||
func (s *SafeClose) ReceiveCloseSignal() <-chan struct{} {
|
||||
return s.closeSignal
|
||||
}
|
||||
|
||||
// Attach add this goroutine to s.wg WaitClosed.
|
||||
// f must receive closeSignal and call done when it is done.
|
||||
// If s was closed, f will not run.
|
||||
func (s *SafeClose) Attach(f func(done func(), closeSignal <-chan struct{})) {
|
||||
s.m.Lock()
|
||||
select {
|
||||
case <-s.closeSignal:
|
||||
default:
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
f(s.wg.Done, s.closeSignal)
|
||||
}()
|
||||
}
|
||||
s.m.Unlock()
|
||||
}
|
||||
123
pkg/server/doq.go
Normal file
123
pkg/server/doq.go
Normal file
@ -0,0 +1,123 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/quic-go/quic-go"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQuicIdleTimeout = time.Second * 30
|
||||
streamReadTimeout = time.Second * 2
|
||||
quicFirstReadTimeout = time.Second * 2
|
||||
)
|
||||
|
||||
type DoQServerOpts struct {
|
||||
Logger *zap.Logger
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// ServeDoQ starts a server at l. It returns if l had an Accept() error.
|
||||
// It always returns a non-nil error.
|
||||
func ServeDoQ(l *quic.Listener, h Handler, opts DoQServerOpts) error {
|
||||
logger := opts.Logger
|
||||
if logger == nil {
|
||||
logger = nopLogger
|
||||
}
|
||||
idleTimeout := opts.IdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = defaultQuicIdleTimeout
|
||||
}
|
||||
|
||||
listenerCtx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(errListenerCtxCanceled)
|
||||
for {
|
||||
c, err := l.Accept(listenerCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unexpected listener err: %w", err)
|
||||
}
|
||||
|
||||
// handle connection
|
||||
connCtx, cancelConn := context.WithCancelCause(listenerCtx)
|
||||
go func() {
|
||||
defer c.CloseWithError(0, "")
|
||||
defer cancelConn(errConnectionCtxCanceled)
|
||||
|
||||
var clientAddr netip.Addr
|
||||
ta, ok := c.RemoteAddr().(*net.UDPAddr)
|
||||
if ok {
|
||||
clientAddr = ta.AddrPort().Addr()
|
||||
}
|
||||
|
||||
firstRead := true
|
||||
for {
|
||||
var streamAcceptTimeout time.Duration
|
||||
if firstRead {
|
||||
firstRead = false
|
||||
streamAcceptTimeout = quicFirstReadTimeout
|
||||
} else {
|
||||
streamAcceptTimeout = idleTimeout
|
||||
}
|
||||
streamAcceptCtx, cancelStreamAccept := context.WithTimeout(connCtx, streamAcceptTimeout)
|
||||
stream, err := c.AcceptStream(streamAcceptCtx)
|
||||
cancelStreamAccept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle stream.
|
||||
// For doq, one stream, one query.
|
||||
go func() {
|
||||
defer func() {
|
||||
stream.Close()
|
||||
stream.CancelRead(0) // TODO: Needs a proper error code.
|
||||
}()
|
||||
// Avoid fragmentation attack.
|
||||
stream.SetReadDeadline(time.Now().Add(streamReadTimeout))
|
||||
req, _, err := dnsutils.ReadMsgFromTCP(stream)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
queryMeta := QueryMeta{
|
||||
ClientAddr: clientAddr,
|
||||
ServerName: c.ConnectionState().TLS.ServerName,
|
||||
}
|
||||
|
||||
resp := h.Handle(connCtx, req, queryMeta, pool.PackTCPBuffer)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
if _, err := stream.Write(*resp); err != nil {
|
||||
logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
179
pkg/server/http_handler.go
Normal file
179
pkg/server/http_handler.go
Normal file
@ -0,0 +1,179 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type HttpHandlerOpts struct {
|
||||
// GetSrcIPFromHeader specifies the header that contain client source address.
|
||||
// e.g. "X-Forwarded-For".
|
||||
GetSrcIPFromHeader string
|
||||
|
||||
// Logger specifies the logger which Handler writes its log to.
|
||||
// Default is a nop logger.
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
type HttpHandler struct {
|
||||
dnsHandler Handler
|
||||
logger *zap.Logger
|
||||
srcIPHeader string
|
||||
}
|
||||
|
||||
var _ http.Handler = (*HttpHandler)(nil)
|
||||
|
||||
func NewHttpHandler(h Handler, opts HttpHandlerOpts) *HttpHandler {
|
||||
hh := new(HttpHandler)
|
||||
hh.dnsHandler = h
|
||||
hh.srcIPHeader = opts.GetSrcIPFromHeader
|
||||
hh.logger = opts.Logger
|
||||
if hh.logger == nil {
|
||||
hh.logger = nopLogger
|
||||
}
|
||||
return hh
|
||||
}
|
||||
|
||||
func (h *HttpHandler) warnErr(req *http.Request, msg string, err error) {
|
||||
h.logger.Warn(msg, zap.String("from", req.RemoteAddr), zap.String("method", req.Method), zap.String("url", req.RequestURI), zap.Error(err))
|
||||
}
|
||||
|
||||
func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
addrPort, err := netip.ParseAddrPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to parse request remote addr", zap.String("addr", req.RemoteAddr), zap.Error(err))
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
clientAddr := addrPort.Addr()
|
||||
|
||||
// read remote addr from header
|
||||
if header := h.srcIPHeader; len(header) != 0 {
|
||||
if xff := req.Header.Get(header); len(xff) != 0 {
|
||||
addr, err := readClientAddrFromXFF(xff)
|
||||
if err != nil {
|
||||
h.warnErr(req, "failed to get client ip from header", fmt.Errorf("failed to prase header %s: %s, %s", header, xff, err))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientAddr = addr
|
||||
}
|
||||
}
|
||||
|
||||
// read msg
|
||||
q, err := ReadMsgFromReq(req)
|
||||
if err != nil {
|
||||
h.warnErr(req, "invalid request", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
queryMeta := QueryMeta{
|
||||
ClientAddr: clientAddr,
|
||||
}
|
||||
if u := req.URL; u != nil {
|
||||
queryMeta.UrlPath = u.Path
|
||||
}
|
||||
if tlsStat := req.TLS; tlsStat != nil {
|
||||
queryMeta.ServerName = tlsStat.ServerName
|
||||
}
|
||||
resp := h.dnsHandler.Handle(req.Context(), q, queryMeta, pool.PackBuffer)
|
||||
if resp == nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer pool.ReleaseBuf(resp)
|
||||
w.Header().Set("Content-Type", "application/dns-message")
|
||||
if _, err := w.Write(*resp); err != nil {
|
||||
h.warnErr(req, "failed to write response", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func readClientAddrFromXFF(s string) (netip.Addr, error) {
|
||||
if i := strings.IndexRune(s, ','); i > 0 {
|
||||
return netip.ParseAddr(s[:i])
|
||||
}
|
||||
return netip.ParseAddr(s)
|
||||
}
|
||||
|
||||
var errInvalidMediaType = errors.New("missing or invalid media type header")
|
||||
|
||||
var bufPool = pool.NewBytesBufPool(512)
|
||||
|
||||
func ReadMsgFromReq(req *http.Request) (*dns.Msg, error) {
|
||||
var b []byte
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
// Check accept header
|
||||
if req.Header.Get("Accept") != "application/dns-message" {
|
||||
return nil, errInvalidMediaType
|
||||
}
|
||||
|
||||
s := req.URL.Query().Get("dns")
|
||||
if len(s) == 0 {
|
||||
return nil, errors.New("no dns parameter")
|
||||
}
|
||||
msgSize := base64.RawURLEncoding.DecodedLen(len(s))
|
||||
if msgSize > dns.MaxMsgSize {
|
||||
return nil, fmt.Errorf("msg length %d is too big", msgSize)
|
||||
}
|
||||
|
||||
var err error
|
||||
b, err = base64.RawURLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 query: %w", err)
|
||||
}
|
||||
|
||||
case http.MethodPost:
|
||||
// Check Content-Type header
|
||||
if req.Header.Get("Content-Type") != "application/dns-message" {
|
||||
return nil, errInvalidMediaType
|
||||
}
|
||||
|
||||
buf := bufPool.Get()
|
||||
defer bufPool.Release(buf)
|
||||
_, err := buf.ReadFrom(io.LimitReader(req.Body, dns.MaxMsgSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
b = buf.Bytes()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported method: %s", req.Method)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
if err := m.Unpack(b); err != nil {
|
||||
return nil, fmt.Errorf("failed to unpack msg [%x], %w", b, err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
29
pkg/server/iface.go
Normal file
29
pkg/server/iface.go
Normal file
@ -0,0 +1,29 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Handler handles incoming request q and MUST ALWAYS return a response.
|
||||
// Handler MUST handle dns errors by itself and return a proper error responses.
|
||||
// e.g. Return a SERVFAIL if something goes wrong.
|
||||
// If Handle() returns a nil resp, caller will
|
||||
// udp: do nothing.
|
||||
// tcp/dot: close the connection immediately.
|
||||
// doh: send a 500 response.
|
||||
// doq: close the stream immediately.
|
||||
type Handler interface {
|
||||
Handle(ctx context.Context, q *dns.Msg, meta QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) (respPayload *[]byte)
|
||||
}
|
||||
|
||||
type QueryMeta struct {
|
||||
FromUDP bool
|
||||
|
||||
// Optional
|
||||
ClientAddr netip.Addr
|
||||
ServerName string
|
||||
UrlPath string
|
||||
}
|
||||
119
pkg/server/tcp.go
Normal file
119
pkg/server/tcp.go
Normal file
@ -0,0 +1,119 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTCPIdleTimeout = time.Second * 10
|
||||
tcpFirstReadTimeout = time.Second * 2
|
||||
)
|
||||
|
||||
type TCPServerOpts struct {
|
||||
// Nil logger == nop
|
||||
Logger *zap.Logger
|
||||
|
||||
// Default is defaultTCPIdleTimeout.
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// ServeTCP starts a server at l. It returns if l had an Accept() error.
|
||||
// It always returns a non-nil error.
|
||||
func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error {
|
||||
logger := opts.Logger
|
||||
if logger == nil {
|
||||
logger = nopLogger
|
||||
}
|
||||
idleTimeout := opts.IdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = defaultTCPIdleTimeout
|
||||
}
|
||||
firstReadTimeout := tcpFirstReadTimeout
|
||||
if idleTimeout < firstReadTimeout {
|
||||
firstReadTimeout = idleTimeout
|
||||
}
|
||||
|
||||
listenerCtx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(errListenerCtxCanceled)
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unexpected listener err: %w", err)
|
||||
}
|
||||
|
||||
// handle connection
|
||||
tcpConnCtx, cancelConn := context.WithCancelCause(listenerCtx)
|
||||
go func() {
|
||||
defer c.Close()
|
||||
defer cancelConn(errConnectionCtxCanceled)
|
||||
|
||||
firstRead := true
|
||||
for {
|
||||
if firstRead {
|
||||
firstRead = false
|
||||
c.SetReadDeadline(time.Now().Add(firstReadTimeout))
|
||||
} else {
|
||||
c.SetReadDeadline(time.Now().Add(idleTimeout))
|
||||
}
|
||||
req, _, err := dnsutils.ReadMsgFromTCP(c)
|
||||
if err != nil {
|
||||
return // read err, close the connection
|
||||
}
|
||||
|
||||
// Try to get server name from tls conn.
|
||||
var serverName string
|
||||
if tlsConn, ok := c.(*tls.Conn); ok {
|
||||
serverName = tlsConn.ConnectionState().ServerName
|
||||
}
|
||||
|
||||
// handle query
|
||||
go func() {
|
||||
var clientAddr netip.Addr
|
||||
ta, ok := c.RemoteAddr().(*net.TCPAddr)
|
||||
if ok {
|
||||
clientAddr = ta.AddrPort().Addr()
|
||||
}
|
||||
r := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr, ServerName: serverName}, pool.PackTCPBuffer)
|
||||
if r == nil {
|
||||
c.Close() // abort the connection
|
||||
return
|
||||
}
|
||||
defer pool.ReleaseBuf(r)
|
||||
|
||||
if _, err := c.Write(*r); err != nil {
|
||||
logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err))
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
33
pkg/server/tls.go
Normal file
33
pkg/server/tls.go
Normal file
@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
func LoadCert(tlsCfg *tls.Config, cert, key string) error {
|
||||
c, err := tls.LoadX509KeyPair(cert, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tlsCfg.Certificates = []tls.Certificate{c}
|
||||
return nil
|
||||
}
|
||||
109
pkg/server/udp.go
Normal file
109
pkg/server/udp.go
Normal file
@ -0,0 +1,109 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type UDPServerOpts struct {
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
// ServeUDP starts a server at c. It returns if c had a read error.
|
||||
// It always returns a non-nil error.
|
||||
// h is required. logger is optional.
|
||||
func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error {
|
||||
logger := opts.Logger
|
||||
if logger == nil {
|
||||
logger = nopLogger
|
||||
}
|
||||
|
||||
listenerCtx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(errListenerCtxCanceled)
|
||||
|
||||
rb := pool.GetBuf(dns.MaxMsgSize)
|
||||
defer pool.ReleaseBuf(rb)
|
||||
|
||||
oobReader, oobWriter, err := initOobHandler(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init oob handler, %w", err)
|
||||
}
|
||||
var ob []byte
|
||||
if oobReader != nil {
|
||||
obp := pool.GetBuf(1024)
|
||||
defer pool.ReleaseBuf(obp)
|
||||
ob = *obp
|
||||
}
|
||||
|
||||
for {
|
||||
n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob)
|
||||
if err != nil {
|
||||
if n == 0 {
|
||||
// Err with zero read. Most likely because c was closed.
|
||||
return fmt.Errorf("unexpected read err: %w", err)
|
||||
}
|
||||
// Temporary err.
|
||||
logger.Warn("read err", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
q := new(dns.Msg)
|
||||
if err := q.Unpack((*rb)[:n]); err != nil {
|
||||
logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr))
|
||||
continue
|
||||
}
|
||||
|
||||
var dstIpFromCm net.IP
|
||||
if oobReader != nil {
|
||||
var err error
|
||||
dstIpFromCm, err = oobReader(ob[:oobn])
|
||||
if err != nil {
|
||||
logger.Error("failed to get dst address from oob", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// handle query
|
||||
go func() {
|
||||
payload := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}, pool.PackBuffer)
|
||||
if payload == nil {
|
||||
return
|
||||
}
|
||||
defer pool.ReleaseBuf(payload)
|
||||
|
||||
var oob []byte
|
||||
if oobWriter != nil && dstIpFromCm != nil {
|
||||
oob = oobWriter(dstIpFromCm)
|
||||
}
|
||||
if _, _, err := c.WriteMsgUDPAddrPort(*payload, oob, remoteAddr); err != nil {
|
||||
logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
type getSrcAddrFromOOB func(oob []byte) (net.IP, error)
|
||||
type writeSrcAddrToOOB func(a net.IP) []byte
|
||||
125
pkg/server/udp_linux.go
Normal file
125
pkg/server/udp_linux.go
Normal file
@ -0,0 +1,125 @@
|
||||
//go:build linux
|
||||
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
errCmNoDstAddr = errors.New("control msg does not have dst address")
|
||||
)
|
||||
|
||||
func getOOBFromCM4(oob []byte) (net.IP, error) {
|
||||
var cm ipv4.ControlMessage
|
||||
if err := cm.Parse(oob); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cm.Dst == nil {
|
||||
return nil, errCmNoDstAddr
|
||||
}
|
||||
return cm.Dst, nil
|
||||
}
|
||||
|
||||
func getOOBFromCM6(oob []byte) (net.IP, error) {
|
||||
var cm ipv6.ControlMessage
|
||||
if err := cm.Parse(oob); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cm.Dst == nil {
|
||||
return nil, errCmNoDstAddr
|
||||
}
|
||||
return cm.Dst, nil
|
||||
}
|
||||
|
||||
func srcIP2Cm(ip net.IP) []byte {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return (&ipv4.ControlMessage{
|
||||
Src: ip,
|
||||
}).Marshal()
|
||||
}
|
||||
|
||||
if ip6 := ip.To16(); ip6 != nil {
|
||||
return (&ipv6.ControlMessage{
|
||||
Src: ip,
|
||||
}).Marshal()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
|
||||
if !c.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
sc, err := c.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var getter getSrcAddrFromOOB
|
||||
var setter writeSrcAddrToOOB
|
||||
var controlErr error
|
||||
if err := sc.Control(func(fd uintptr) {
|
||||
v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN)
|
||||
if err != nil {
|
||||
controlErr = os.NewSyscallError("failed to get SO_PROTOCOL", err)
|
||||
return
|
||||
}
|
||||
switch v {
|
||||
case unix.AF_INET:
|
||||
c4 := ipv4.NewPacketConn(c)
|
||||
if err := c4.SetControlMessage(ipv4.FlagDst, true); err != nil {
|
||||
controlErr = fmt.Errorf("failed to set ipv4 cmsg flags, %w", err)
|
||||
}
|
||||
|
||||
getter = getOOBFromCM4
|
||||
setter = srcIP2Cm
|
||||
return
|
||||
case unix.AF_INET6:
|
||||
c6 := ipv6.NewPacketConn(c)
|
||||
if err := c6.SetControlMessage(ipv6.FlagDst, true); err != nil {
|
||||
controlErr = fmt.Errorf("failed to set ipv6 cmsg flags, %w", err)
|
||||
}
|
||||
getter = getOOBFromCM6
|
||||
setter = srcIP2Cm
|
||||
return
|
||||
default:
|
||||
controlErr = fmt.Errorf("socket protocol %d is not supported", v)
|
||||
}
|
||||
}); err != nil {
|
||||
return nil, nil, fmt.Errorf("control fd err, %w", controlErr)
|
||||
}
|
||||
|
||||
if controlErr != nil {
|
||||
return nil, nil, fmt.Errorf("failed to set up socket, %w", controlErr)
|
||||
}
|
||||
return getter, setter, nil
|
||||
}
|
||||
28
pkg/server/udp_others.go
Normal file
28
pkg/server/udp_others.go
Normal file
@ -0,0 +1,28 @@
|
||||
//go:build !linux
|
||||
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import "net"
|
||||
|
||||
func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
16
pkg/server/utils.go
Normal file
16
pkg/server/utils.go
Normal file
@ -0,0 +1,16 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
errListenerCtxCanceled = errors.New("listener ctx canceled")
|
||||
errConnectionCtxCanceled = errors.New("connection ctx canceled")
|
||||
)
|
||||
|
||||
var (
|
||||
nopLogger = zap.NewNop()
|
||||
)
|
||||
154
pkg/server_handler/entry_handler.go
Normal file
154
pkg/server_handler/entry_handler.go
Normal file
@ -0,0 +1,154 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server_handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/mlog"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/server"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueryTimeout = time.Second * 5
|
||||
)
|
||||
|
||||
var (
|
||||
nopLogger = mlog.Nop()
|
||||
|
||||
// options that can forward to upstream
|
||||
queryForwardEDNS0Option = map[uint16]struct{}{
|
||||
dns.EDNS0SUBNET: {},
|
||||
}
|
||||
|
||||
// options that useless for downstream
|
||||
respRemoveEDNS0Option = map[uint16]struct{}{
|
||||
dns.EDNS0PADDING: {},
|
||||
}
|
||||
)
|
||||
|
||||
type EntryHandlerOpts struct {
|
||||
// Logger is used for logging. Default is a noop logger.
|
||||
Logger *zap.Logger
|
||||
|
||||
// Required.
|
||||
Entry sequence.Executable
|
||||
|
||||
// QueryTimeout limits the timeout value of each query.
|
||||
// Default is defaultQueryTimeout.
|
||||
QueryTimeout time.Duration
|
||||
}
|
||||
|
||||
func (opts *EntryHandlerOpts) init() {
|
||||
if opts.Logger == nil {
|
||||
opts.Logger = nopLogger
|
||||
}
|
||||
utils.SetDefaultNum(&opts.QueryTimeout, defaultQueryTimeout)
|
||||
}
|
||||
|
||||
type EntryHandler struct {
|
||||
opts EntryHandlerOpts
|
||||
}
|
||||
|
||||
var _ server.Handler = (*EntryHandler)(nil)
|
||||
|
||||
func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler {
|
||||
opts.init()
|
||||
return &EntryHandler{opts: opts}
|
||||
}
|
||||
|
||||
// ServeDNS implements server.Handler.
|
||||
// If entry returns an error, a SERVFAIL response will be returned.
|
||||
// If entry returns without a response, a REFUSED response will be returned.
|
||||
func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, serverMeta server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte {
|
||||
// basic query check.
|
||||
if q.Response || len(q.Question) != 1 || len(q.Answer)+len(q.Ns) > 0 || len(q.Extra) > 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ddl := time.Now().Add(h.opts.QueryTimeout)
|
||||
ctx, cancel := context.WithDeadline(ctx, ddl)
|
||||
defer cancel()
|
||||
|
||||
qCtx := query_context.NewContext(q)
|
||||
qCtx.ServerMeta = serverMeta
|
||||
|
||||
// exec entry
|
||||
err := h.opts.Entry.Exec(ctx, qCtx)
|
||||
var resp *dns.Msg
|
||||
if err != nil {
|
||||
h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(err))
|
||||
resp = new(dns.Msg)
|
||||
resp.SetReply(q)
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
} else {
|
||||
resp = qCtx.R()
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
resp = new(dns.Msg)
|
||||
resp.SetReply(q)
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
}
|
||||
// We assume that our server is a forwarder.
|
||||
resp.RecursionAvailable = true
|
||||
|
||||
// add respOpt back to resp
|
||||
if respOpt := qCtx.RespOpt(); respOpt != nil {
|
||||
resp.Extra = append(resp.Extra, respOpt)
|
||||
}
|
||||
|
||||
if serverMeta.FromUDP {
|
||||
udpSize := getValidUDPSize(qCtx.ClientOpt())
|
||||
resp.Truncate(udpSize)
|
||||
}
|
||||
|
||||
payload, err := packMsgPayload(resp)
|
||||
if err != nil {
|
||||
h.opts.Logger.Error("internal err: failed to pack resp msg", qCtx.InfoField(), zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// opt can be nil.
|
||||
func getValidUDPSize(opt *dns.OPT) int {
|
||||
var s uint16
|
||||
if opt != nil {
|
||||
s = opt.UDPSize()
|
||||
}
|
||||
if s < dns.MinMsgSize {
|
||||
s = dns.MinMsgSize
|
||||
}
|
||||
return int(s)
|
||||
}
|
||||
|
||||
func newOpt() *dns.OPT {
|
||||
opt := new(dns.OPT)
|
||||
opt.Hdr.Name = "."
|
||||
opt.Hdr.Rrtype = dns.TypeOPT
|
||||
return opt
|
||||
}
|
||||
248
pkg/upstream/bootstrap/bootstrap.go
Normal file
248
pkg/upstream/bootstrap/bootstrap.go
Normal file
@ -0,0 +1,248 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
minimumUpdateInterval = time.Minute * 5
|
||||
retryInterval = time.Second * 2
|
||||
queryTimeout = time.Second * 5
|
||||
)
|
||||
|
||||
var (
|
||||
errNoAddrInResp = errors.New("resp does not have ip address")
|
||||
)
|
||||
|
||||
func New(
|
||||
host string,
|
||||
port uint16,
|
||||
bootstrapServer netip.AddrPort,
|
||||
bootstrapVer int, // 0,4,6
|
||||
logger *zap.Logger, // not nil
|
||||
) (*Bootstrap, error) {
|
||||
dp := new(Bootstrap)
|
||||
dp.fqdn = dns.Fqdn(host)
|
||||
dp.port = port
|
||||
if !bootstrapServer.IsValid() {
|
||||
return nil, errors.New("invalid bootstrap server address")
|
||||
}
|
||||
dp.bootstrap = net.UDPAddrFromAddrPort(bootstrapServer)
|
||||
qt, ok := bootstrapVer2Qt(bootstrapVer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid bootstrap version %d", bootstrapVer)
|
||||
}
|
||||
dp.qt = qt
|
||||
dp.logger = logger
|
||||
|
||||
dp.readyNotify = make(chan struct{})
|
||||
return dp, nil
|
||||
}
|
||||
|
||||
type Bootstrap struct {
|
||||
fqdn string
|
||||
port uint16
|
||||
bootstrap *net.UDPAddr
|
||||
qt uint16 // dns.TypeA or dns.TypeAAAA
|
||||
logger *zap.Logger // not nil
|
||||
|
||||
updating atomic.Bool
|
||||
nextUpdate time.Time
|
||||
|
||||
readyNotify chan struct{}
|
||||
m sync.Mutex
|
||||
ready bool
|
||||
addrStr string
|
||||
}
|
||||
|
||||
func (sp *Bootstrap) GetAddrPortStr(ctx context.Context) (string, error) {
|
||||
sp.tryUpdate()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", context.Cause(ctx)
|
||||
case <-sp.readyNotify:
|
||||
}
|
||||
|
||||
sp.m.Lock()
|
||||
addr := sp.addrStr
|
||||
sp.m.Unlock()
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
func (sp *Bootstrap) tryUpdate() {
|
||||
if sp.updating.CompareAndSwap(false, true) {
|
||||
if time.Now().After(sp.nextUpdate) {
|
||||
go func() {
|
||||
defer sp.updating.Store(false)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), queryTimeout)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
addr, ttl, err := sp.updateAddr(ctx)
|
||||
if err != nil {
|
||||
sp.logger.Check(zap.WarnLevel, "failed to update bootstrap addr").Write(
|
||||
zap.String("fqdn", sp.fqdn),
|
||||
zap.Error(err),
|
||||
)
|
||||
sp.nextUpdate = time.Now().Add(retryInterval)
|
||||
} else {
|
||||
updateInterval := time.Second * time.Duration(ttl)
|
||||
if updateInterval < minimumUpdateInterval {
|
||||
updateInterval = minimumUpdateInterval
|
||||
}
|
||||
sp.logger.Check(zap.DebugLevel, "bootstrap addr updated").Write(
|
||||
zap.String("fqdn", sp.fqdn),
|
||||
zap.Stringer("addr", addr),
|
||||
zap.Duration("ttl", updateInterval),
|
||||
zap.Duration("elapse", time.Since(start)),
|
||||
)
|
||||
sp.nextUpdate = time.Now().Add(updateInterval)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
sp.updating.Store(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *Bootstrap) updateAddr(ctx context.Context) (netip.Addr, uint32, error) {
|
||||
addr, ttl, err := sp.resolve(ctx, sp.qt)
|
||||
if err != nil {
|
||||
return netip.Addr{}, 0, err
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(addr, sp.port).String()
|
||||
sp.m.Lock()
|
||||
sp.addrStr = addrPort
|
||||
if !sp.ready {
|
||||
sp.ready = true
|
||||
close(sp.readyNotify)
|
||||
}
|
||||
sp.m.Unlock()
|
||||
return addr, ttl, nil
|
||||
}
|
||||
|
||||
func (sp *Bootstrap) resolve(ctx context.Context, qt uint16) (netip.Addr, uint32, error) {
|
||||
const edns0UdpSize = 1200
|
||||
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion(sp.fqdn, qt)
|
||||
q.SetEdns0(edns0UdpSize, false)
|
||||
|
||||
c, err := net.DialUDP("udp", nil, sp.bootstrap)
|
||||
if err != nil {
|
||||
return netip.Addr{}, 0, err
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
writeErrC := make(chan error, 1)
|
||||
type res struct {
|
||||
resp *dns.Msg
|
||||
err error
|
||||
}
|
||||
readResC := make(chan res, 1)
|
||||
|
||||
cancelWrite := make(chan struct{})
|
||||
defer close(cancelWrite)
|
||||
go func() {
|
||||
if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil {
|
||||
writeErrC <- err
|
||||
return
|
||||
}
|
||||
|
||||
retryTicker := time.NewTicker(time.Second)
|
||||
defer retryTicker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-cancelWrite:
|
||||
return
|
||||
case <-retryTicker.C:
|
||||
if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil {
|
||||
writeErrC <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
m, _, err := dnsutils.ReadMsgFromUDP(c, edns0UdpSize)
|
||||
readResC <- res{resp: m, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return netip.Addr{}, 0, context.Cause(ctx)
|
||||
case err := <-writeErrC:
|
||||
return netip.Addr{}, 0, fmt.Errorf("failed to write query, %w", err)
|
||||
case r := <-readResC:
|
||||
resp := r.resp
|
||||
err := r.err
|
||||
if err != nil {
|
||||
return netip.Addr{}, 0, fmt.Errorf("failed to read resp, %w", err)
|
||||
}
|
||||
|
||||
for _, v := range resp.Answer {
|
||||
var ip net.IP
|
||||
var ttl uint32
|
||||
switch rr := v.(type) {
|
||||
case *dns.A:
|
||||
ip = rr.A
|
||||
ttl = rr.Hdr.Ttl
|
||||
case *dns.AAAA:
|
||||
ip = rr.AAAA
|
||||
ttl = rr.Hdr.Ttl
|
||||
default:
|
||||
continue
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if ok {
|
||||
return addr, ttl, nil
|
||||
}
|
||||
}
|
||||
|
||||
// No ip addr in resp.
|
||||
return netip.Addr{}, 0, errNoAddrInResp
|
||||
}
|
||||
}
|
||||
|
||||
func bootstrapVer2Qt(ver int) (uint16, bool) {
|
||||
switch ver {
|
||||
case 0, 4:
|
||||
return dns.TypeA, true
|
||||
case 6:
|
||||
return dns.TypeAAAA, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
166
pkg/upstream/doh/upstream.go
Normal file
166
pkg/upstream/doh/upstream.go
Normal file
@ -0,0 +1,166 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package doh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
urlpkg "net/url"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDoHTimeout = time.Second * 6
|
||||
)
|
||||
|
||||
var nopLogger = zap.NewNop()
|
||||
|
||||
// Upstream is a DNS-over-HTTPS (RFC 8484) upstream.
|
||||
type Upstream struct {
|
||||
rt http.RoundTripper
|
||||
logger *zap.Logger // non-nil
|
||||
urlTemplate *urlpkg.URL
|
||||
reqTemplate *http.Request
|
||||
}
|
||||
|
||||
func NewUpstream(endPoint string, rt http.RoundTripper, logger *zap.Logger) (*Upstream, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, endPoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse http request, %w", err)
|
||||
}
|
||||
|
||||
req.Header["Accept"] = []string{"application/dns-message"}
|
||||
req.Header["User-Agent"] = nil // Don't let go http send a default user agent header.
|
||||
|
||||
if logger == nil {
|
||||
logger = nopLogger
|
||||
}
|
||||
return &Upstream{
|
||||
rt: rt,
|
||||
logger: logger,
|
||||
urlTemplate: req.URL,
|
||||
reqTemplate: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var (
|
||||
bufPool4k = pool.NewBytesBufPool(4096)
|
||||
)
|
||||
|
||||
func (u *Upstream) ExchangeContext(ctx context.Context, q []byte) (*[]byte, error) {
|
||||
bp := pool.GetBuf(len(q))
|
||||
defer pool.ReleaseBuf(bp)
|
||||
wire := *bp
|
||||
copy(wire, q)
|
||||
|
||||
// In order to maximize HTTP cache friendliness, DoH clients using media
|
||||
// formats that include the ID field from the DNS message header, such
|
||||
// as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS
|
||||
// request.
|
||||
// https://tools.ietf.org/html/rfc8484#section-4.1
|
||||
wire[0] = 0
|
||||
wire[1] = 0
|
||||
|
||||
queryLen := 4 + base64.RawURLEncoding.EncodedLen(len(wire))
|
||||
queryBuf := make([]byte, queryLen)
|
||||
|
||||
p := 0
|
||||
p += copy(queryBuf, "dns=")
|
||||
|
||||
// Padding characters for base64url MUST NOT be included.
|
||||
// See: https://tools.ietf.org/html/rfc8484#section-6.
|
||||
base64.RawURLEncoding.Encode(queryBuf[p:], wire)
|
||||
|
||||
type res struct {
|
||||
r *[]byte
|
||||
err error
|
||||
}
|
||||
|
||||
resChan := make(chan res, 1)
|
||||
go func() {
|
||||
// We overwrite the ctx with a fixed timeout context here.
|
||||
// Because the http package may close the underlay connection
|
||||
// if the context is done before the query is completed. This
|
||||
// reduces the connection reuse efficiency.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDoHTimeout)
|
||||
defer cancel()
|
||||
r, err := u.exchange(ctx, utils.BytesToStringUnsafe(queryBuf))
|
||||
if err != nil {
|
||||
u.logger.Check(zap.WarnLevel, "exchange failed").Write(zap.Error(err))
|
||||
}
|
||||
resChan <- res{r: r, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, context.Cause(ctx)
|
||||
case res := <-resChan:
|
||||
r := res.r
|
||||
err := res.err
|
||||
if r != nil {
|
||||
binary.BigEndian.PutUint16(*r, binary.BigEndian.Uint16(q))
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Upstream) exchange(ctx context.Context, dnsQuery string) (*[]byte, error) {
|
||||
req := u.reqTemplate.WithContext(ctx)
|
||||
req.URL = new(urlpkg.URL)
|
||||
*req.URL = *u.urlTemplate
|
||||
req.URL.RawQuery = dnsQuery
|
||||
resp, err := u.rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// check status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body1k, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
if body1k != nil {
|
||||
return nil, fmt.Errorf("bad http status codes %d with body [%s]", resp.StatusCode, body1k)
|
||||
}
|
||||
return nil, fmt.Errorf("bad http status codes %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
bb := bufPool4k.Get()
|
||||
defer bufPool4k.Release(bb)
|
||||
_, err = bb.ReadFrom(io.LimitReader(resp.Body, dns.MaxMsgSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read http body: %w", err)
|
||||
}
|
||||
if bb.Len() < dnsutils.DnsHeaderLen {
|
||||
return nil, dnsutils.ErrPayloadTooSmall
|
||||
}
|
||||
payload := pool.GetBuf(bb.Len())
|
||||
copy(*payload, bb.Bytes())
|
||||
return payload, nil
|
||||
}
|
||||
70
pkg/upstream/event_stat.go
Normal file
70
pkg/upstream/event_stat.go
Normal file
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type Event int
|
||||
|
||||
const (
|
||||
EventConnOpen Event = iota
|
||||
EventConnClose
|
||||
)
|
||||
|
||||
type EventObserver interface {
|
||||
OnEvent(typ Event)
|
||||
}
|
||||
|
||||
type nopEO struct{}
|
||||
|
||||
func (n nopEO) OnEvent(_ Event) {}
|
||||
|
||||
type connWrapper struct {
|
||||
net.Conn
|
||||
closed atomic.Bool
|
||||
ob EventObserver
|
||||
}
|
||||
|
||||
// wrapConn wraps c into a connWrapper so that we can observe the connection close.
|
||||
// For convenient, if c is nil, wrapConn returns nil as well. If ob is nopEO, wrapConn
|
||||
// returns c.
|
||||
func wrapConn(c net.Conn, ob EventObserver) net.Conn {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if _, ok := ob.(nopEO); ok {
|
||||
return c
|
||||
}
|
||||
ob.OnEvent(EventConnOpen)
|
||||
return &connWrapper{
|
||||
Conn: c,
|
||||
ob: ob,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connWrapper) Close() error {
|
||||
if c.closed.CompareAndSwap(false, true) {
|
||||
c.ob.OnEvent(EventConnClose)
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
165
pkg/upstream/transport/conn_lazy_dial.go
Normal file
165
pkg/upstream/transport/conn_lazy_dial.go
Normal file
@ -0,0 +1,165 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type lazyDnsConn struct {
|
||||
maxConcurrentQuery int
|
||||
cancelDial context.CancelFunc
|
||||
|
||||
mu sync.Mutex
|
||||
earlyReserveCallWg sync.WaitGroup
|
||||
closed bool
|
||||
reservedQuery int
|
||||
dialFinished chan struct{}
|
||||
c DnsConn
|
||||
dialErr error
|
||||
|
||||
// 1: dial completed and all early reserve call finished.
|
||||
// 2: dial failed.
|
||||
fastPath atomic.Uint32
|
||||
}
|
||||
|
||||
var _ DnsConn = (*lazyDnsConn)(nil)
|
||||
|
||||
var (
|
||||
errLazyConnDialCanceled = errors.New("lazy dial canceled")
|
||||
)
|
||||
|
||||
func newLazyDnsConn(
|
||||
dial func(ctx context.Context) (DnsConn, error),
|
||||
dialTimeout time.Duration,
|
||||
maxConcurrentQueryWhileDialing int, // must be valid, no default value
|
||||
logger *zap.Logger, // must non-nil
|
||||
) *lazyDnsConn {
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = defaultDialTimeout
|
||||
}
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), defaultDialTimeout)
|
||||
lc := &lazyDnsConn{
|
||||
maxConcurrentQuery: maxConcurrentQueryWhileDialing,
|
||||
cancelDial: cancelDial,
|
||||
dialFinished: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func() {
|
||||
dc, err := dial(dialCtx)
|
||||
cancelDial()
|
||||
if err != nil {
|
||||
logger.Check(zap.WarnLevel, "failed to dial dns conn").Write(zap.Error(err))
|
||||
}
|
||||
lc.mu.Lock()
|
||||
if lc.closed { // lc was closed and dial was canceled
|
||||
lc.mu.Unlock()
|
||||
if dc != nil {
|
||||
dc.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
lc.c = dc
|
||||
lc.dialErr = err
|
||||
close(lc.dialFinished)
|
||||
lc.mu.Unlock()
|
||||
}()
|
||||
return lc
|
||||
}
|
||||
|
||||
func (lc *lazyDnsConn) Close() error {
|
||||
lc.mu.Lock()
|
||||
defer lc.mu.Unlock()
|
||||
|
||||
if lc.closed {
|
||||
return nil
|
||||
}
|
||||
lc.closed = true
|
||||
|
||||
if lc.c == nil && lc.dialErr == nil { // still dialing
|
||||
lc.cancelDial()
|
||||
lc.dialErr = errLazyConnDialCanceled
|
||||
close(lc.dialFinished)
|
||||
} else {
|
||||
// close connection
|
||||
if lc.c != nil {
|
||||
lc.c.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lc *lazyDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
|
||||
switch lc.fastPath.Load() {
|
||||
case 1:
|
||||
return lc.c.ReserveNewQuery()
|
||||
case 2:
|
||||
return nil, true
|
||||
}
|
||||
|
||||
lc.mu.Lock()
|
||||
defer lc.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-lc.dialFinished:
|
||||
// Note: race condition here and lazyDnsConnEarlyReservedExchanger.ExchangeReserved().
|
||||
// Not a big problem. May cause at most all early exchange failed.
|
||||
// earlyExchangeWg makes sure that early exchange calls ReserveNewQuery first.
|
||||
dc, err := lc.c, lc.dialErr
|
||||
if err != nil {
|
||||
lc.fastPath.Store(2)
|
||||
return nil, true
|
||||
}
|
||||
lc.earlyReserveCallWg.Wait()
|
||||
lc.fastPath.Store(1)
|
||||
return dc.ReserveNewQuery()
|
||||
default:
|
||||
if lc.reservedQuery >= lc.maxConcurrentQuery {
|
||||
return nil, false
|
||||
}
|
||||
lc.reservedQuery++
|
||||
lc.earlyReserveCallWg.Add(1)
|
||||
return (*lazyDnsConnEarlyReservedExchanger)(lc), false
|
||||
}
|
||||
}
|
||||
|
||||
type lazyDnsConnEarlyReservedExchanger lazyDnsConn
|
||||
|
||||
var _ ReservedExchanger = (*lazyDnsConnEarlyReservedExchanger)(nil)
|
||||
|
||||
func (ote *lazyDnsConnEarlyReservedExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) {
|
||||
defer func() {
|
||||
ote.mu.Lock()
|
||||
ote.reservedQuery--
|
||||
ote.mu.Unlock()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ote.earlyReserveCallWg.Done()
|
||||
return nil, context.Cause(ctx)
|
||||
case <-ote.dialFinished:
|
||||
dc, err := ote.c, ote.dialErr
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rec, _ := dc.ReserveNewQuery()
|
||||
ote.earlyReserveCallWg.Done()
|
||||
if rec == nil {
|
||||
return nil, ErrLazyConnCannotReserveQueryExchanger
|
||||
}
|
||||
return rec.ExchangeReserved(ctx, q)
|
||||
}
|
||||
}
|
||||
|
||||
func (ote *lazyDnsConnEarlyReservedExchanger) WithdrawReserved() {
|
||||
ote.earlyReserveCallWg.Done()
|
||||
ote.mu.Lock()
|
||||
ote.reservedQuery--
|
||||
ote.mu.Unlock()
|
||||
}
|
||||
123
pkg/upstream/transport/conn_quic.go
Normal file
123
pkg/upstream/transport/conn_quic.go
Normal file
@ -0,0 +1,123 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
const (
|
||||
quicQueryTimeout = time.Second * 6
|
||||
)
|
||||
|
||||
const (
|
||||
// RFC 9250 4.3. DoQ Error Codes
|
||||
_DOQ_NO_ERROR = quic.StreamErrorCode(0x0)
|
||||
_DOQ_INTERNAL_ERROR = quic.StreamErrorCode(0x1)
|
||||
_DOQ_REQUEST_CANCELLED = quic.StreamErrorCode(0x3)
|
||||
)
|
||||
|
||||
var _ DnsConn = (*QuicDnsConn)(nil)
|
||||
|
||||
type QuicDnsConn struct {
|
||||
c quic.Connection
|
||||
}
|
||||
|
||||
func NewQuicDnsConn(c quic.Connection) *QuicDnsConn {
|
||||
return &QuicDnsConn{c: c}
|
||||
}
|
||||
|
||||
func (c *QuicDnsConn) Close() error {
|
||||
return c.c.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
func (c *QuicDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
|
||||
select {
|
||||
case <-c.c.Context().Done():
|
||||
return nil, true
|
||||
default:
|
||||
}
|
||||
s, err := c.c.OpenStream()
|
||||
// We just checked the connection is alive. So we are assuming the error
|
||||
// is caused by reaching the peer's stream limit.
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return &quicReservedExchanger{stream: s}, false
|
||||
}
|
||||
|
||||
type quicReservedExchanger struct {
|
||||
stream quic.Stream
|
||||
}
|
||||
|
||||
var _ ReservedExchanger = (*quicReservedExchanger)(nil)
|
||||
|
||||
func (ote *quicReservedExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) {
|
||||
stream := ote.stream
|
||||
|
||||
payload, err := copyMsgWithLenHdr(q)
|
||||
if err != nil {
|
||||
stream.CancelWrite(_DOQ_REQUEST_CANCELLED)
|
||||
stream.CancelRead(_DOQ_REQUEST_CANCELLED)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4.2.1. DNS Message IDs
|
||||
// When sending queries over a QUIC connection, the DNS Message ID MUST
|
||||
// be set to 0. The stream mapping for DoQ allows for unambiguous
|
||||
// correlation of queries and responses, so the Message ID field is not
|
||||
// required.
|
||||
orgQid := binary.BigEndian.Uint16((*payload)[2:])
|
||||
binary.BigEndian.PutUint16((*payload)[2:], 0)
|
||||
|
||||
stream.SetDeadline(time.Now().Add(quicQueryTimeout))
|
||||
_, err = stream.Write(*payload)
|
||||
pool.ReleaseBuf(payload)
|
||||
if err != nil {
|
||||
stream.CancelRead(_DOQ_REQUEST_CANCELLED)
|
||||
stream.CancelWrite(_DOQ_REQUEST_CANCELLED)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// RFC 9250 4.2
|
||||
// The client MUST send the DNS query over the selected stream and MUST
|
||||
// indicate through the STREAM FIN mechanism that no further data will
|
||||
// be sent on that stream.
|
||||
//
|
||||
// Call Close() here will send the STREAM FIN. It won't close Read.
|
||||
stream.Close()
|
||||
|
||||
type res struct {
|
||||
resp *[]byte
|
||||
err error
|
||||
}
|
||||
rc := make(chan res, 1)
|
||||
go func() {
|
||||
r, err := dnsutils.ReadRawMsgFromTCP(stream)
|
||||
rc <- res{resp: r, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.CancelRead(_DOQ_REQUEST_CANCELLED)
|
||||
return nil, context.Cause(ctx)
|
||||
case r := <-rc:
|
||||
resp := r.resp
|
||||
err := r.err
|
||||
if resp != nil {
|
||||
binary.BigEndian.PutUint16((*resp), orgQid)
|
||||
}
|
||||
stream.CancelRead(_DOQ_NO_ERROR)
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func (ote *quicReservedExchanger) WithdrawReserved() {
|
||||
s := ote.stream
|
||||
s.CancelRead(_DOQ_REQUEST_CANCELLED)
|
||||
s.CancelWrite(_DOQ_REQUEST_CANCELLED)
|
||||
}
|
||||
283
pkg/upstream/transport/conn_traditional.go
Normal file
283
pkg/upstream/transport/conn_traditional.go
Normal file
@ -0,0 +1,283 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTDCTooManyQueries = errors.New("too many queries") // Connection has too many ongoing queries.
|
||||
ErrTDCClosed = errors.New("dns connection closed")
|
||||
)
|
||||
|
||||
var _ DnsConn = (*TraditionalDnsConn)(nil)
|
||||
|
||||
// TraditionalDnsConn is a low-level connection for traditional dns protocol, where
|
||||
// dns frames transport in a single and simple connection. (e.g. udp, tcp, tls)
|
||||
type TraditionalDnsConn struct {
|
||||
c NetConn
|
||||
isTcp bool
|
||||
idleTimeout time.Duration
|
||||
maxCq int
|
||||
|
||||
closeOnce sync.Once
|
||||
closeNotify chan struct{}
|
||||
closed atomic.Bool // atomic, for fast check
|
||||
closeErr error // closeErr is ready (not nil) when closeNotify is closed.
|
||||
|
||||
queueMu sync.RWMutex
|
||||
reservedQuery int
|
||||
nextQid uint16
|
||||
queue map[uint32]chan *[]byte // uint32 has fast path
|
||||
|
||||
// waitingResp indicates connection is waiting a reply from the peer.
|
||||
// It can identify c is dead or buggy in some circumstances. e.g. Network is dropped
|
||||
// and the sockets were still open because no fin or rst was received.
|
||||
waitingResp atomic.Bool
|
||||
}
|
||||
|
||||
type TraditionalDnsConnOpts struct {
|
||||
// Set to true if underlayer connection require a length header.
|
||||
// e.g. TCP and DoT.
|
||||
WithLengthHeader bool
|
||||
|
||||
// IdleTimeout controls the maximum idle time for each connection.
|
||||
// Default is defaultIdleTimeout.
|
||||
IdleTimeout time.Duration
|
||||
|
||||
// MaxConcurrentQuery limits the number of maximum concurrent queries
|
||||
// in the connection. Default is defaultTdcMaxConcurrentQuery.
|
||||
MaxConcurrentQuery int
|
||||
}
|
||||
|
||||
func NewDnsConn(opt TraditionalDnsConnOpts, conn NetConn) *TraditionalDnsConn {
|
||||
dc := &TraditionalDnsConn{
|
||||
c: conn,
|
||||
isTcp: opt.WithLengthHeader,
|
||||
closeNotify: make(chan struct{}),
|
||||
queue: make(map[uint32]chan *[]byte),
|
||||
}
|
||||
setDefaultGZ(&dc.idleTimeout, opt.IdleTimeout, defaultIdleTimeout)
|
||||
setDefaultGZ(&dc.maxCq, opt.MaxConcurrentQuery, defaultTdcMaxConcurrentQuery)
|
||||
|
||||
go dc.readLoop()
|
||||
return dc
|
||||
}
|
||||
|
||||
// exchange sends q out and waits for its reply.
|
||||
func (dc *TraditionalDnsConn) exchange(ctx context.Context, q []byte) (*[]byte, error) {
|
||||
select {
|
||||
case <-dc.closeNotify:
|
||||
return nil, ErrTDCClosed
|
||||
default:
|
||||
}
|
||||
|
||||
assignedQid, respChan := dc.addQueueC()
|
||||
if respChan == nil {
|
||||
return nil, ErrTDCTooManyQueries
|
||||
}
|
||||
defer dc.deleteQueueC(assignedQid)
|
||||
|
||||
// Reminder: Set write deadline here is not very useful to avoid dead connections.
|
||||
// Typically, a write operation will time out only if its socket buffer is full.
|
||||
// Ser read deadline is enough.
|
||||
err := dc.writeQuery(q, assignedQid)
|
||||
if err != nil {
|
||||
// Write error usually is fatal. Abort and close this connection.
|
||||
dc.CloseWithErr(fmt.Errorf("write err, %w", err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If a query was sent, server should have a reply (even not for this query) in a short time.
|
||||
// This indicates the connection is healthy. Otherwise, this connection might be dead.
|
||||
// The Read deadline will be refreshed in DnsConn.readLoop() after every successful read.
|
||||
// Note: There has a race condition in this SetReadDeadline() call and the one in
|
||||
// readLoop(). It's not a big problem.
|
||||
if dc.waitingResp.CompareAndSwap(false, true) {
|
||||
dc.c.SetReadDeadline(time.Now().Add(waitingReplyTimeout))
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, context.Cause(ctx)
|
||||
case r := <-respChan:
|
||||
orgId := binary.BigEndian.Uint16(q)
|
||||
binary.BigEndian.PutUint16(*r, orgId)
|
||||
return r, nil
|
||||
case <-dc.closeNotify:
|
||||
return nil, dc.closeErr
|
||||
}
|
||||
}
|
||||
|
||||
func (dc *TraditionalDnsConn) writeQuery(q []byte, assignedQid uint16) error {
|
||||
var payload *[]byte
|
||||
if dc.isTcp {
|
||||
var err error
|
||||
payload, err = copyMsgWithLenHdr(q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
binary.BigEndian.PutUint16((*payload)[2:], assignedQid)
|
||||
} else {
|
||||
payload = copyMsg(q)
|
||||
binary.BigEndian.PutUint16(*payload, assignedQid)
|
||||
}
|
||||
_, err := dc.c.Write(*payload)
|
||||
pool.ReleaseBuf(payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func (dc *TraditionalDnsConn) readResp() (payload *[]byte, err error) {
|
||||
if dc.isTcp {
|
||||
return dnsutils.ReadRawMsgFromTCP(dc.c)
|
||||
}
|
||||
return readMsgUdp(dc.c)
|
||||
}
|
||||
|
||||
// readLoop reads DnsConn until there was a read error.
|
||||
func (dc *TraditionalDnsConn) readLoop() {
|
||||
|
||||
for {
|
||||
dc.c.SetReadDeadline(time.Now().Add(dc.idleTimeout))
|
||||
r, err := dc.readResp()
|
||||
if err != nil {
|
||||
dc.CloseWithErr(fmt.Errorf("read err, %w", err)) // abort this connection.
|
||||
return
|
||||
}
|
||||
dc.waitingResp.Store(false)
|
||||
|
||||
rid := binary.BigEndian.Uint16(*r)
|
||||
resChan := dc.getQueueC(rid)
|
||||
if resChan != nil {
|
||||
select {
|
||||
case resChan <- r: // resChan has buffer
|
||||
default:
|
||||
pool.ReleaseBuf(r)
|
||||
}
|
||||
} else {
|
||||
pool.ReleaseBuf(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dc *TraditionalDnsConn) IsClosed() bool {
|
||||
return dc.closed.Load()
|
||||
}
|
||||
|
||||
func (dc *TraditionalDnsConn) Close() error {
|
||||
dc.CloseWithErr(ErrTDCClosed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseWithErr closes DnsConn with an error. The error will be sent
|
||||
// to the waiting Exchange calls.
|
||||
// Subsequent calls are noop.
|
||||
// Default err is ErrTDCClosed.
|
||||
func (dc *TraditionalDnsConn) CloseWithErr(err error) {
|
||||
if err == nil {
|
||||
err = ErrTDCClosed
|
||||
}
|
||||
|
||||
dc.closeOnce.Do(func() {
|
||||
dc.closed.Store(true)
|
||||
dc.closeErr = err
|
||||
close(dc.closeNotify)
|
||||
dc.c.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func (dc *TraditionalDnsConn) getQueueC(qid uint16) chan<- *[]byte {
|
||||
dc.queueMu.RLock()
|
||||
defer dc.queueMu.RUnlock()
|
||||
return dc.queue[uint32(qid)]
|
||||
}
|
||||
|
||||
func (dc *TraditionalDnsConn) queueLen() int {
|
||||
dc.queueMu.RLock()
|
||||
defer dc.queueMu.RUnlock()
|
||||
return len(dc.queue) + dc.reservedQuery
|
||||
}
|
||||
|
||||
// addQueueC assigns a qid and add it to the queue.
|
||||
// It returns a nil c if queue has too many queries.
|
||||
// Caller must call deleteQueueC to release the qid in queue.
|
||||
func (dc *TraditionalDnsConn) addQueueC() (qid uint16, c chan *[]byte) {
|
||||
c = make(chan *[]byte)
|
||||
dc.queueMu.Lock()
|
||||
for i := 0; i < 100; i++ {
|
||||
qid = dc.nextQid
|
||||
dc.nextQid++
|
||||
if _, dup := dc.queue[uint32(qid)]; dup {
|
||||
continue
|
||||
}
|
||||
dc.queue[uint32(qid)] = c
|
||||
dc.queueMu.Unlock()
|
||||
return qid, c
|
||||
}
|
||||
dc.queueMu.Unlock()
|
||||
|
||||
// Too many queries in queue. Can't assign qid.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (dc *TraditionalDnsConn) deleteQueueC(qid uint16) {
|
||||
dc.queueMu.Lock()
|
||||
delete(dc.queue, uint32(qid))
|
||||
dc.queueMu.Unlock()
|
||||
}
|
||||
|
||||
func (dc *TraditionalDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
|
||||
if dc.closed.Load() {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
dc.queueMu.Lock()
|
||||
defer dc.queueMu.Unlock()
|
||||
if len(dc.queue)+dc.reservedQuery >= dc.maxCq {
|
||||
return nil, false
|
||||
}
|
||||
dc.reservedQuery++
|
||||
return (*tdcOneTimeExchanger)(dc), false
|
||||
}
|
||||
|
||||
type tdcOneTimeExchanger TraditionalDnsConn
|
||||
|
||||
var _ ReservedExchanger = (*tdcOneTimeExchanger)(nil)
|
||||
|
||||
func (ote *tdcOneTimeExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) {
|
||||
defer ote.WithdrawReserved()
|
||||
return (*TraditionalDnsConn)(ote).exchange(ctx, q)
|
||||
}
|
||||
|
||||
func (ote *tdcOneTimeExchanger) WithdrawReserved() {
|
||||
ote.queueMu.Lock()
|
||||
ote.reservedQuery--
|
||||
ote.queueMu.Unlock()
|
||||
}
|
||||
250
pkg/upstream/transport/conn_traditional_test.go
Normal file
250
pkg/upstream/transport/conn_traditional_test.go
Normal file
@ -0,0 +1,250 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dummyEchoNetConn struct {
|
||||
net.Conn
|
||||
rErrProb float64
|
||||
rLatency time.Duration
|
||||
wErrProb float64
|
||||
|
||||
closeOnce sync.Once
|
||||
closeNotify chan struct{}
|
||||
}
|
||||
|
||||
func newDummyEchoNetConn(rErrProb float64, rLatency time.Duration, wErrProb float64) NetConn {
|
||||
c1, c2 := net.Pipe()
|
||||
go func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
for {
|
||||
m, readErr := dnsutils.ReadRawMsgFromTCP(c2)
|
||||
if m != nil {
|
||||
go func() {
|
||||
defer pool.ReleaseBuf(m)
|
||||
if rLatency > 0 {
|
||||
t := time.NewTimer(rLatency)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-t.C:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
latency := time.Millisecond * time.Duration(rand.Intn(20))
|
||||
time.Sleep(latency)
|
||||
_, _ = dnsutils.WriteRawMsgToTCP(c2, *m)
|
||||
}()
|
||||
}
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &dummyEchoNetConn{
|
||||
Conn: c1,
|
||||
rErrProb: rErrProb,
|
||||
rLatency: rLatency,
|
||||
wErrProb: wErrProb,
|
||||
closeNotify: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func probTrue(p float64) bool {
|
||||
return rand.Float64() < p
|
||||
}
|
||||
|
||||
func (d *dummyEchoNetConn) Read(p []byte) (n int, err error) {
|
||||
if probTrue(d.rErrProb) {
|
||||
return 0, errors.New("read err")
|
||||
}
|
||||
return d.Conn.Read(p)
|
||||
}
|
||||
|
||||
func (d *dummyEchoNetConn) Write(p []byte) (n int, err error) {
|
||||
if probTrue(d.wErrProb) {
|
||||
return 0, errors.New("write err")
|
||||
}
|
||||
return d.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (d *dummyEchoNetConn) Close() error {
|
||||
d.closeOnce.Do(func() {
|
||||
close(d.closeNotify)
|
||||
})
|
||||
return d.Conn.Close()
|
||||
}
|
||||
|
||||
func Test_dnsConn_exchange(t *testing.T) {
|
||||
idleTimeout := time.Millisecond * 100
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rErrProb float64
|
||||
rLatency time.Duration
|
||||
wErrProb float64
|
||||
connClosed bool // connection is closed before calling exchange()
|
||||
wantMsg bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
rErrProb: 0,
|
||||
rLatency: 0,
|
||||
wErrProb: 0,
|
||||
wantMsg: true, wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "write err",
|
||||
rErrProb: 0,
|
||||
rLatency: 0,
|
||||
wErrProb: 1,
|
||||
wantMsg: false, wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "read err",
|
||||
rErrProb: 1,
|
||||
rLatency: 0,
|
||||
wErrProb: 0,
|
||||
wantMsg: false, wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "read timeout",
|
||||
rErrProb: 0,
|
||||
rLatency: idleTimeout * 3,
|
||||
wErrProb: 0,
|
||||
wantMsg: false, wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "connection closed",
|
||||
rErrProb: 0,
|
||||
rLatency: 0,
|
||||
wErrProb: 0,
|
||||
connClosed: true,
|
||||
wantMsg: false, wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := require.New(t)
|
||||
c := newDummyEchoNetConn(tt.rErrProb, tt.rLatency, tt.wErrProb)
|
||||
defer c.Close()
|
||||
ioOpts := TraditionalDnsConnOpts{
|
||||
WithLengthHeader: true, // TODO: Test false as well
|
||||
IdleTimeout: idleTimeout,
|
||||
}
|
||||
dc := NewDnsConn(ioOpts, c)
|
||||
|
||||
if tt.connClosed {
|
||||
dc.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||
defer cancel()
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("test.", dns.TypeA)
|
||||
queryPayload, err := q.Pack()
|
||||
r.NoError(err)
|
||||
|
||||
rec, closed := dc.ReserveNewQuery()
|
||||
r.Equal(tt.connClosed, closed)
|
||||
if rec == nil {
|
||||
return
|
||||
}
|
||||
|
||||
respPayload, err := rec.ExchangeReserved(ctx, queryPayload)
|
||||
if tt.wantErr {
|
||||
r.Error(err)
|
||||
} else {
|
||||
r.NoError(err)
|
||||
}
|
||||
|
||||
if tt.wantMsg {
|
||||
r.NotNil(respPayload)
|
||||
r.True(bytes.Equal(queryPayload, *respPayload))
|
||||
|
||||
// test idle timeout
|
||||
time.Sleep(idleTimeout + time.Millisecond*20)
|
||||
runtime.Gosched()
|
||||
r.True(dc.IsClosed(), "connection should be closed due to idle timeout")
|
||||
} else {
|
||||
r.Nil(respPayload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 测试 maxconcurrentquery。
|
||||
|
||||
func Test_dnsConn_exchange_race(t *testing.T) {
|
||||
r := require.New(t)
|
||||
wg := new(sync.WaitGroup)
|
||||
for i := 0; i < 1024; i++ {
|
||||
c := newDummyEchoNetConn(0.5, time.Millisecond*20, 0.5)
|
||||
ioOpts := TraditionalDnsConnOpts{
|
||||
WithLengthHeader: true, // TODO: Test false as well
|
||||
IdleTimeout: time.Millisecond * 50,
|
||||
}
|
||||
dc := NewDnsConn(ioOpts, c)
|
||||
for j := 0; j < 24; j++ {
|
||||
if dc.IsClosed() {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||
defer cancel()
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("test.", dns.TypeA)
|
||||
queryPayload, err := q.Pack()
|
||||
r.NoError(err)
|
||||
|
||||
rec, closed := dc.ReserveNewQuery()
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
if rec != nil {
|
||||
_, _ = rec.ExchangeReserved(ctx, queryPayload)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
151
pkg/upstream/transport/pipeline.go
Normal file
151
pkg/upstream/transport/pipeline.go
Normal file
@ -0,0 +1,151 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PipelineTransport will pipeline queries as RFC 7766 6.2.1.1 suggested.
|
||||
// It also can reuse udp socket. Since dns over udp is some kind of "pipeline".
|
||||
type PipelineTransport struct {
|
||||
m sync.Mutex // protect following fields
|
||||
closed bool
|
||||
conns map[*lazyDnsConn]struct{}
|
||||
|
||||
dialFunc func(ctx context.Context) (DnsConn, error)
|
||||
dialTimeout time.Duration
|
||||
maxLazyConnQueue int
|
||||
logger *zap.Logger // not nil
|
||||
}
|
||||
|
||||
type PipelineOpts struct {
|
||||
// DialContext specifies the method to dial a connection to the server.
|
||||
// DialContext MUST NOT be nil.
|
||||
DialContext func(ctx context.Context) (DnsConn, error)
|
||||
|
||||
// DialTimeout specifies the timeout for DialFunc.
|
||||
// Default is defaultDialTimeout.
|
||||
DialTimeout time.Duration
|
||||
|
||||
// When connection is dialing, how many queries can be queued up in that
|
||||
// connection. Default is defaultLazyConnMaxConcurrentQuery.
|
||||
// Note: If the connection turns out having a smaller limit, part of queued up
|
||||
// queries will fail.
|
||||
MaxConcurrentQueryWhileDialing int
|
||||
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewPipelineTransport(opt PipelineOpts) *PipelineTransport {
|
||||
t := &PipelineTransport{
|
||||
conns: make(map[*lazyDnsConn]struct{}),
|
||||
}
|
||||
t.dialFunc = opt.DialContext
|
||||
setDefaultGZ(&t.dialTimeout, opt.DialTimeout, defaultDialTimeout)
|
||||
setDefaultGZ(&t.maxLazyConnQueue, opt.MaxConcurrentQueryWhileDialing, defaultMaxLazyConnQueue)
|
||||
setNonNilLogger(&t.logger, opt.Logger)
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *PipelineTransport) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) {
|
||||
const maxRetry = 2
|
||||
retry := 0
|
||||
for {
|
||||
dc, isNewConn, err := t.getReservedExchanger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := dc.ExchangeReserved(ctx, m)
|
||||
if err != nil {
|
||||
// Reused connection may not stable.
|
||||
// Try to re-send this query if it failed on a reused connection.
|
||||
if !isNewConn && retry < maxRetry && ctx.Err() == nil {
|
||||
retry++
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes PipelineTransport and all its connections.
|
||||
// It always returns a nil error.
|
||||
func (t *PipelineTransport) Close() error {
|
||||
t.m.Lock()
|
||||
defer t.m.Unlock()
|
||||
if t.closed {
|
||||
return nil
|
||||
}
|
||||
t.closed = true
|
||||
for conn := range t.conns {
|
||||
conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *PipelineTransport) getReservedExchanger() (_ ReservedExchanger, isNewConn bool, err error) {
|
||||
t.m.Lock()
|
||||
if t.closed {
|
||||
err = ErrClosedTransport
|
||||
t.m.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
var rxc ReservedExchanger
|
||||
const maxReserveAttempt = 16
|
||||
reserveAttempt := 0
|
||||
for c := range t.conns {
|
||||
var closed bool
|
||||
rxc, closed = c.ReserveNewQuery()
|
||||
if closed {
|
||||
delete(t.conns, c)
|
||||
}
|
||||
if rxc != nil {
|
||||
break
|
||||
} else {
|
||||
reserveAttempt++
|
||||
if reserveAttempt > maxReserveAttempt {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dial a new connection
|
||||
if rxc == nil {
|
||||
c := newLazyDnsConn(t.dialFunc, t.dialTimeout, t.maxLazyConnQueue, t.logger)
|
||||
rxc, _ = c.ReserveNewQuery() // ignore the closed error for new lazy connection
|
||||
isNewConn = true
|
||||
t.conns[c] = struct{}{}
|
||||
}
|
||||
t.m.Unlock()
|
||||
|
||||
if rxc == nil {
|
||||
isNewConn = false
|
||||
err = ErrNewConnCannotReserveQueryExchanger
|
||||
}
|
||||
return rxc, isNewConn, err
|
||||
}
|
||||
152
pkg/upstream/transport/pipeline_test.go
Normal file
152
pkg/upstream/transport/pipeline_test.go
Normal file
@ -0,0 +1,152 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dummyEchoDnsConnOpt struct {
|
||||
exchangeErr error
|
||||
mcq int
|
||||
|
||||
closed atomic.Bool
|
||||
|
||||
wantConcurrentExchangeCall int
|
||||
waitingExchangeCall atomic.Int32
|
||||
unblockOnce sync.Once
|
||||
unblockExchange chan struct{}
|
||||
}
|
||||
|
||||
type dummyEchoDnsConn struct {
|
||||
opt *dummyEchoDnsConnOpt
|
||||
|
||||
m sync.Mutex
|
||||
reserved int
|
||||
}
|
||||
|
||||
func (dc *dummyEchoDnsConn) ExchangeReserved(ctx context.Context, q []byte) (*[]byte, error) {
|
||||
defer dc.WithdrawReserved()
|
||||
|
||||
if dc.opt.waitingExchangeCall.Add(1) == int32(dc.opt.wantConcurrentExchangeCall) {
|
||||
dc.opt.unblockOnce.Do(func() { close(dc.opt.unblockExchange) })
|
||||
}
|
||||
defer dc.opt.waitingExchangeCall.Add(-1)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, context.Cause(ctx)
|
||||
case <-dc.opt.unblockExchange:
|
||||
if dc.opt.exchangeErr != nil {
|
||||
return nil, dc.opt.exchangeErr
|
||||
}
|
||||
return copyMsg(q), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (dc *dummyEchoDnsConn) WithdrawReserved() {
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
dc.reserved--
|
||||
if dc.reserved < 0 {
|
||||
panic("negative reserved counter")
|
||||
}
|
||||
}
|
||||
|
||||
func (dc *dummyEchoDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
|
||||
if dc.opt.closed.Load() {
|
||||
return nil, true
|
||||
}
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
if dc.reserved >= dc.opt.mcq {
|
||||
return nil, false
|
||||
}
|
||||
dc.reserved++
|
||||
return dc, false
|
||||
}
|
||||
|
||||
func (dc *dummyEchoDnsConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_PipelineTransport(t *testing.T) {
|
||||
const (
|
||||
mcq = 100
|
||||
wantConn = 10
|
||||
wantMaxConcurrentExchangeCall = mcq * wantConn
|
||||
)
|
||||
|
||||
r := require.New(t)
|
||||
dcControl := &dummyEchoDnsConnOpt{
|
||||
mcq: mcq,
|
||||
unblockExchange: make(chan struct{}),
|
||||
wantConcurrentExchangeCall: wantMaxConcurrentExchangeCall,
|
||||
}
|
||||
po := PipelineOpts{
|
||||
DialContext: func(ctx context.Context) (DnsConn, error) { return &dummyEchoDnsConn{opt: dcControl}, nil },
|
||||
MaxConcurrentQueryWhileDialing: mcq,
|
||||
}
|
||||
pt := NewPipelineTransport(po)
|
||||
defer pt.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("test.", dns.TypeA)
|
||||
queryPayload, err := q.Pack()
|
||||
r.NoError(err)
|
||||
wg := new(sync.WaitGroup)
|
||||
for i := 0; i < wantMaxConcurrentExchangeCall; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := pt.ExchangeContext(ctx, queryPayload)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if t.Failed() {
|
||||
break
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
pt.m.Lock()
|
||||
pl := len(pt.conns)
|
||||
pt.m.Unlock()
|
||||
|
||||
r.Equal(wantConn, pl)
|
||||
|
||||
dcControl.closed.Store(true)
|
||||
_, _ = pt.ExchangeContext(ctx, queryPayload) // remove all closed conn
|
||||
|
||||
pt.m.Lock()
|
||||
pl = len(pt.conns)
|
||||
pt.m.Unlock()
|
||||
r.Equal(1, pl, "all connection should be remove then one will be opened")
|
||||
}
|
||||
341
pkg/upstream/transport/reuse.go
Normal file
341
pkg/upstream/transport/reuse.go
Normal file
@ -0,0 +1,341 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
// Most servers will send SERVFAIL after 3~5s. If no resp, connection may be dead.
|
||||
reuseConnQueryTimeout = time.Second * 6
|
||||
)
|
||||
|
||||
// ReuseConnTransport is for old tcp protocol. (no pipelining)
|
||||
type ReuseConnTransport struct {
|
||||
dialFunc func(ctx context.Context) (NetConn, error)
|
||||
dialTimeout time.Duration
|
||||
idleTimeout time.Duration
|
||||
logger *zap.Logger // non-nil
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelCauseFunc
|
||||
|
||||
m sync.Mutex // protect following fields
|
||||
closed bool
|
||||
idleConns map[*reusableConn]struct{}
|
||||
conns map[*reusableConn]struct{}
|
||||
|
||||
// for testing
|
||||
testWaitRespTimeout time.Duration
|
||||
}
|
||||
|
||||
type ReuseConnOpts struct {
|
||||
// DialContext specifies the method to dial a connection to the server.
|
||||
// DialContext MUST NOT be nil.
|
||||
DialContext func(ctx context.Context) (NetConn, error)
|
||||
|
||||
// DialTimeout specifies the timeout for DialFunc.
|
||||
// Default is defaultDialTimeout.
|
||||
DialTimeout time.Duration
|
||||
|
||||
// Default is defaultIdleTimeout
|
||||
IdleTimeout time.Duration
|
||||
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewReuseConnTransport(opt ReuseConnOpts) *ReuseConnTransport {
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
t := &ReuseConnTransport{
|
||||
ctx: ctx,
|
||||
ctxCancel: cancel,
|
||||
idleConns: make(map[*reusableConn]struct{}),
|
||||
conns: make(map[*reusableConn]struct{}),
|
||||
}
|
||||
t.dialFunc = opt.DialContext
|
||||
setDefaultGZ(&t.dialTimeout, opt.DialTimeout, defaultDialTimeout)
|
||||
setDefaultGZ(&t.idleTimeout, opt.IdleTimeout, defaultIdleTimeout)
|
||||
setNonNilLogger(&t.logger, opt.Logger)
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *ReuseConnTransport) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) {
|
||||
const maxRetry = 2
|
||||
|
||||
retry := 0
|
||||
for {
|
||||
var isNewConn bool
|
||||
c, err := t.getIdleConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c == nil {
|
||||
isNewConn = true
|
||||
c, err = t.getNewConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
queryPayload, err := copyMsgWithLenHdr(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.exchange(ctx, queryPayload)
|
||||
if err != nil {
|
||||
if !isNewConn && retry <= maxRetry {
|
||||
retry++
|
||||
continue // retry if c is a reused connection.
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
// getNewConn dial a *reusableConn.
|
||||
// The caller must call releaseReusableConn to release the reusableConn.
|
||||
func (t *ReuseConnTransport) getNewConn(ctx context.Context) (*reusableConn, error) {
|
||||
callCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
type dialRes struct {
|
||||
c *reusableConn
|
||||
err error
|
||||
}
|
||||
dialChan := make(chan dialRes)
|
||||
go func() {
|
||||
dialCtx, cancelDial := context.WithTimeout(t.ctx, t.dialTimeout)
|
||||
defer cancelDial()
|
||||
|
||||
var rc *reusableConn
|
||||
c, err := t.dialFunc(dialCtx)
|
||||
if err != nil {
|
||||
t.logger.Check(zap.WarnLevel, "fail to dial reusable conn").Write(zap.Error(err))
|
||||
}
|
||||
if c != nil {
|
||||
rc = t.newReusableConn(c)
|
||||
if rc == nil { // transport closed
|
||||
c.Close()
|
||||
rc = nil
|
||||
err = ErrClosedTransport
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case dialChan <- dialRes{c: rc, err: err}:
|
||||
case <-callCtx.Done(): // caller canceled getNewConn() call
|
||||
if rc != nil { // put this conn to pool
|
||||
t.setIdle(rc)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-callCtx.Done():
|
||||
return nil, context.Cause(ctx)
|
||||
case <-t.ctx.Done():
|
||||
return nil, context.Cause(t.ctx)
|
||||
case res := <-dialChan:
|
||||
return res.c, res.err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReuseConnTransport) setIdle(c *reusableConn) {
|
||||
t.m.Lock()
|
||||
defer t.m.Unlock()
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
if _, ok := t.conns[c]; ok {
|
||||
t.idleConns[c] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// getIdleConn returns a *reusableConn from conn pool, or nil if no conn
|
||||
// is idle.
|
||||
// The caller must call releaseReusableConn to release the reusableConn.
|
||||
func (t *ReuseConnTransport) getIdleConn() (*reusableConn, error) {
|
||||
t.m.Lock()
|
||||
defer t.m.Unlock()
|
||||
if t.closed {
|
||||
return nil, ErrClosedTransport
|
||||
}
|
||||
|
||||
for c := range t.idleConns {
|
||||
delete(t.idleConns, c)
|
||||
return c, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Close closes ReuseConnTransport and all its connections.
|
||||
// It always returns a nil error.
|
||||
func (t *ReuseConnTransport) Close() error {
|
||||
t.m.Lock()
|
||||
defer t.m.Unlock()
|
||||
if t.closed {
|
||||
return nil
|
||||
}
|
||||
t.closed = true
|
||||
for c := range t.conns {
|
||||
delete(t.conns, c)
|
||||
delete(t.idleConns, c)
|
||||
c.closeWithErrByTransport(ErrClosedTransport)
|
||||
}
|
||||
t.ctxCancel(ErrClosedTransport)
|
||||
return nil
|
||||
}
|
||||
|
||||
type reusableConn struct {
|
||||
c NetConn
|
||||
t *ReuseConnTransport
|
||||
|
||||
m sync.Mutex
|
||||
waitingResp chan *[]byte
|
||||
|
||||
closeOnce sync.Once
|
||||
closeNotify chan struct{}
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// return nil if transport was closed
|
||||
func (t *ReuseConnTransport) newReusableConn(c NetConn) *reusableConn {
|
||||
rc := &reusableConn{
|
||||
c: c,
|
||||
t: t,
|
||||
closeNotify: make(chan struct{}),
|
||||
}
|
||||
|
||||
t.m.Lock()
|
||||
if t.closed { // t was closed.
|
||||
t.m.Unlock()
|
||||
return nil
|
||||
}
|
||||
t.conns[rc] = struct{}{}
|
||||
t.m.Unlock()
|
||||
go rc.readLoop()
|
||||
return rc
|
||||
}
|
||||
|
||||
var (
|
||||
errUnexpectedResp = errors.New("server misbehaving: unexpected response")
|
||||
)
|
||||
|
||||
func (c *reusableConn) readLoop() {
|
||||
for {
|
||||
resp, err := dnsutils.ReadRawMsgFromTCP(c.c)
|
||||
if err != nil {
|
||||
c.closeWithErr(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.m.Lock()
|
||||
respChan := c.waitingResp
|
||||
c.waitingResp = nil
|
||||
c.m.Unlock()
|
||||
|
||||
if respChan == nil {
|
||||
pool.ReleaseBuf(resp)
|
||||
c.closeWithErr(errUnexpectedResp)
|
||||
return
|
||||
}
|
||||
|
||||
// This connection is idled again.
|
||||
c.c.SetReadDeadline(time.Now().Add(c.t.idleTimeout))
|
||||
// Note: calling setIdle before sending resp back to make sure this connection is idle
|
||||
// before Exchange call returning. Otherwise, Test_ReuseConnTransport may fail.
|
||||
c.t.setIdle(c)
|
||||
|
||||
select {
|
||||
case respChan <- resp:
|
||||
default:
|
||||
panic("bug: respChan has buffer, we shouldn't reach here")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *reusableConn) closeWithErr(err error) {
|
||||
if err == nil {
|
||||
err = net.ErrClosed
|
||||
}
|
||||
c.closeOnce.Do(func() {
|
||||
c.t.m.Lock()
|
||||
delete(c.t.conns, c)
|
||||
delete(c.t.idleConns, c)
|
||||
c.t.m.Unlock()
|
||||
|
||||
c.closeErr = err
|
||||
c.c.Close()
|
||||
close(c.closeNotify)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *reusableConn) closeWithErrByTransport(err error) {
|
||||
if err == nil {
|
||||
err = net.ErrClosed
|
||||
}
|
||||
c.closeOnce.Do(func() {
|
||||
c.closeErr = err
|
||||
c.c.Close()
|
||||
close(c.closeNotify)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *reusableConn) exchange(ctx context.Context, q *[]byte) (*[]byte, error) {
|
||||
respChan := make(chan *[]byte, 1)
|
||||
c.m.Lock()
|
||||
if c.waitingResp != nil {
|
||||
c.m.Unlock()
|
||||
panic("bug: reusableConn: concurrent exchange calls")
|
||||
}
|
||||
c.waitingResp = respChan
|
||||
c.m.Unlock()
|
||||
|
||||
waitRespTimeout := reuseConnQueryTimeout
|
||||
if c.t.testWaitRespTimeout > 0 {
|
||||
waitRespTimeout = c.t.testWaitRespTimeout
|
||||
}
|
||||
c.c.SetDeadline(time.Now().Add(waitRespTimeout))
|
||||
_, err := c.c.Write(*q)
|
||||
if err != nil {
|
||||
c.closeWithErr(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case resp := <-respChan:
|
||||
return resp, nil
|
||||
case <-c.closeNotify:
|
||||
return nil, c.closeErr
|
||||
case <-ctx.Done():
|
||||
return nil, context.Cause(ctx)
|
||||
}
|
||||
}
|
||||
154
pkg/upstream/transport/reuse_test.go
Normal file
154
pkg/upstream/transport/reuse_test.go
Normal file
@ -0,0 +1,154 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ReuseConnTransport(t *testing.T) {
|
||||
const idleTimeout = time.Second * 5
|
||||
r := require.New(t)
|
||||
|
||||
po := ReuseConnOpts{
|
||||
DialContext: func(ctx context.Context) (NetConn, error) {
|
||||
return newDummyEchoNetConn(0, 0, 0), nil
|
||||
},
|
||||
IdleTimeout: idleTimeout,
|
||||
}
|
||||
rt := NewReuseConnTransport(po)
|
||||
defer rt.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("test.", dns.TypeA)
|
||||
queryPayload, err := q.Pack()
|
||||
r.NoError(err)
|
||||
concurrentQueryNum := 10
|
||||
for l := 0; l < 4; l++ {
|
||||
wg := new(sync.WaitGroup)
|
||||
for i := 0; i < concurrentQueryNum; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := rt.ExchangeContext(ctx, queryPayload)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if t.Failed() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
rt.m.Lock()
|
||||
connNum := len(rt.conns)
|
||||
idledConnNum := len(rt.idleConns)
|
||||
rt.m.Unlock()
|
||||
|
||||
r.Equal(0, connNum-idledConnNum, "there should be no active conn")
|
||||
r.Equal(concurrentQueryNum, connNum)
|
||||
r.Equal(concurrentQueryNum, idledConnNum, "all conn should be in idle status")
|
||||
}
|
||||
|
||||
func Test_ReuseConnTransport_Read_err_and_close(t *testing.T) {
|
||||
const idleTimeout = time.Second * 5
|
||||
r := require.New(t)
|
||||
|
||||
po := ReuseConnOpts{
|
||||
DialContext: func(ctx context.Context) (NetConn, error) {
|
||||
return newDummyEchoNetConn(1, 0, 0), nil // 100% read err
|
||||
},
|
||||
IdleTimeout: idleTimeout,
|
||||
}
|
||||
rt := NewReuseConnTransport(po)
|
||||
defer rt.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("test.", dns.TypeA)
|
||||
queryPayload, err := q.Pack()
|
||||
r.NoError(err)
|
||||
|
||||
wg := new(sync.WaitGroup)
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := rt.ExchangeContext(ctx, queryPayload)
|
||||
r.Error(err)
|
||||
}()
|
||||
if t.Failed() {
|
||||
return
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
rt.m.Lock()
|
||||
connNum := len(rt.conns)
|
||||
idledConnNum := len(rt.idleConns)
|
||||
rt.m.Unlock()
|
||||
|
||||
r.Equal(0, connNum)
|
||||
r.Equal(0, idledConnNum)
|
||||
}
|
||||
|
||||
func Test_ReuseConnTransport_conn_lose_and_close(t *testing.T) {
|
||||
r := require.New(t)
|
||||
po := ReuseConnOpts{
|
||||
DialContext: func(ctx context.Context) (NetConn, error) {
|
||||
return newDummyEchoNetConn(0, time.Second, 0), nil // 100% read timeout
|
||||
},
|
||||
}
|
||||
rt := NewReuseConnTransport(po)
|
||||
defer rt.Close()
|
||||
rt.testWaitRespTimeout = time.Millisecond * 1
|
||||
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("test.", dns.TypeA)
|
||||
queryPayload, err := q.Pack()
|
||||
r.NoError(err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
|
||||
defer cancel()
|
||||
_, err = rt.ExchangeContext(ctx, queryPayload) // canceled ctx
|
||||
r.Error(err)
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
rt.m.Lock()
|
||||
connNum := len(rt.conns)
|
||||
idledConnNum := len(rt.idleConns)
|
||||
rt.m.Unlock()
|
||||
|
||||
// connection should be closed and removed
|
||||
r.Equal(0, connNum)
|
||||
r.Equal(0, idledConnNum)
|
||||
}
|
||||
74
pkg/upstream/transport/transport.go
Normal file
74
pkg/upstream/transport/transport.go
Normal file
@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClosedTransport = errors.New("transport has been closed")
|
||||
ErrPayloadOverFlow = errors.New("payload is too large")
|
||||
ErrNewConnCannotReserveQueryExchanger = errors.New("new connection failed to reserve query exchanger")
|
||||
ErrLazyConnCannotReserveQueryExchanger = errors.New("lazy connection failed to reserve query exchanger")
|
||||
)
|
||||
|
||||
const (
|
||||
defaultIdleTimeout = time.Second * 10
|
||||
defaultDialTimeout = time.Second * 5
|
||||
|
||||
// If a pipeline connection sent a query but did not see any reply (include replies that
|
||||
// for other queries) from the server after waitingReplyTimeout. It assumes that
|
||||
// something goes wrong with the connection or the server. The connection will be closed.
|
||||
waitingReplyTimeout = time.Second * 10
|
||||
|
||||
defaultTdcMaxConcurrentQuery = 32
|
||||
defaultMaxLazyConnQueue = 16
|
||||
)
|
||||
|
||||
// One method MUST be called in ReservedExchanger.
|
||||
type ReservedExchanger interface {
|
||||
// ExchangeReserved sends q to the server and returns it's response.
|
||||
// ExchangeReserved MUST not modify nor keep the q.
|
||||
// q MUST be a valid dns message.
|
||||
// resp (if no err) should be released by ReleaseResp().
|
||||
ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error)
|
||||
|
||||
// WithdrawReserved aborts the query.
|
||||
WithdrawReserved()
|
||||
}
|
||||
|
||||
type DnsConn interface {
|
||||
// ReserveNewQuery reserves a query. It MUST be fast and non-block. If DnsConn
|
||||
// cannot serve more query due to its capacity, ReserveNewQuery returns nil.
|
||||
// If DnsConn is closed and can no longer serve more query, returns closed = true.
|
||||
ReserveNewQuery() (_ ReservedExchanger, closed bool)
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type NetConn interface {
|
||||
io.ReadWriteCloser
|
||||
SetDeadline(t time.Time) error
|
||||
SetReadDeadline(t time.Time) error
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
90
pkg/upstream/transport/utils.go
Normal file
90
pkg/upstream/transport/utils.go
Normal file
@ -0,0 +1,90 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsHeaderLen = 12 // minimum dns msg size
|
||||
)
|
||||
|
||||
func copyMsgWithLenHdr(m []byte) (*[]byte, error) {
|
||||
l := len(m)
|
||||
if l > dns.MaxMsgSize {
|
||||
return nil, ErrPayloadOverFlow
|
||||
}
|
||||
bp := pool.GetBuf(l + 2)
|
||||
binary.BigEndian.PutUint16(*bp, uint16(l))
|
||||
copy((*bp)[2:], m)
|
||||
return bp, nil
|
||||
}
|
||||
|
||||
func copyMsg(m []byte) *[]byte {
|
||||
bp := pool.GetBuf(len(m))
|
||||
copy((*bp), m)
|
||||
return bp
|
||||
}
|
||||
|
||||
// readMsgUdp reads dns frame from r. r typically should be a udp connection.
|
||||
// It uses a 4kb rx buffer and ignores any payload that is too small for a dns msg.
|
||||
// If no error, the length of payload always >= 12 bytes.
|
||||
func readMsgUdp(r io.Reader) (*[]byte, error) {
|
||||
// TODO: Make this configurable?
|
||||
// 4kb should be enough.
|
||||
payload := pool.GetBuf(4095)
|
||||
|
||||
readAgain:
|
||||
n, err := r.Read(*payload)
|
||||
if err != nil {
|
||||
pool.ReleaseBuf(payload)
|
||||
return nil, err
|
||||
}
|
||||
if n < dnsHeaderLen {
|
||||
goto readAgain
|
||||
}
|
||||
*payload = (*payload)[:n]
|
||||
return payload, err
|
||||
}
|
||||
|
||||
func setDefaultGZ[T constraints.Float | constraints.Integer](i *T, s, d T) {
|
||||
if s > 0 {
|
||||
*i = s
|
||||
} else {
|
||||
*i = d
|
||||
}
|
||||
}
|
||||
|
||||
var nopLogger = zap.NewNop()
|
||||
|
||||
func setNonNilLogger(i **zap.Logger, s *zap.Logger) {
|
||||
if s != nil {
|
||||
*i = s
|
||||
} else {
|
||||
*i = nopLogger
|
||||
}
|
||||
}
|
||||
609
pkg/upstream/upstream.go
Normal file
609
pkg/upstream/upstream.go
Normal file
@ -0,0 +1,609 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/mlog"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/bootstrap"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/doh"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/transport"
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
const (
|
||||
tlsHandshakeTimeout = time.Second * 3
|
||||
|
||||
// Maximum number of concurrent queries in one pipeline connection.
|
||||
// See RFC 7766 7. Response Reordering.
|
||||
// TODO: Make this configurable?
|
||||
pipelineConcurrentLimit = 64
|
||||
)
|
||||
|
||||
// Upstream represents a DNS upstream.
|
||||
type Upstream interface {
|
||||
// ExchangeContext exchanges query message m to the upstream, and returns
|
||||
// response. It MUST NOT keep or modify m.
|
||||
// m MUST be a valid dns msg frame. It MUST be at least 12 bytes
|
||||
// (contain a valid dns header).
|
||||
ExchangeContext(ctx context.Context, m []byte) (*[]byte, error)
|
||||
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type Opt struct {
|
||||
// DialAddr specifies the address the upstream will
|
||||
// actually dial to in the network layer by overwriting
|
||||
// the address inferred from upstream url.
|
||||
// It won't affect high level layers. (e.g. SNI, HTTP HOST header won't be changed).
|
||||
// Can be an IP or a domain. Port is optional.
|
||||
// Tips: If the upstream url host is a domain, specific an IP address
|
||||
// here can skip resolving ip of this domain.
|
||||
DialAddr string
|
||||
|
||||
// Socks5 specifies the socks5 proxy server that the upstream
|
||||
// will connect though.
|
||||
// Not implemented for udp based protocols (aka. dns over udp, http3, quic).
|
||||
Socks5 string
|
||||
|
||||
// SoMark sets the socket SO_MARK option in unix system.
|
||||
SoMark int
|
||||
|
||||
// BindToDevice sets the socket SO_BINDTODEVICE option in unix system.
|
||||
BindToDevice string
|
||||
|
||||
// IdleTimeout specifies the idle timeout for long-connections.
|
||||
// Default: TCP, DoT: 10s , DoH, DoH3, Quic: 30s.
|
||||
IdleTimeout time.Duration
|
||||
|
||||
// EnablePipeline enables query pipelining support as RFC 7766 6.2.1.1 suggested.
|
||||
// Available for TCP, DoT upstream.
|
||||
// Note: There is no fallback. Make sure the server supports it.
|
||||
EnablePipeline bool
|
||||
|
||||
// EnableHTTP3 will use HTTP/3 protocol to connect a DoH upstream. (aka DoH3).
|
||||
// Note: There is no fallback. Make sure the server supports it.
|
||||
EnableHTTP3 bool
|
||||
|
||||
// Bootstrap specifies a plain dns server to solve the
|
||||
// upstream server domain address.
|
||||
// It must be an IP address. Port is optional.
|
||||
Bootstrap string
|
||||
|
||||
// Bootstrap version. One of 0 (default equals 4), 4, 6.
|
||||
// TODO: Support dual-stack.
|
||||
BootstrapVer int
|
||||
|
||||
// TLSConfig specifies the tls.Config that the TLS client will use.
|
||||
// Available for DoT, DoH, DoQ upstream.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// Logger specifies the logger that the upstream will use.
|
||||
Logger *zap.Logger
|
||||
|
||||
// EventObserver can observe connection events.
|
||||
// Not implemented for quic based protocol (DoH3, DoQ).
|
||||
EventObserver EventObserver
|
||||
}
|
||||
|
||||
// NewUpstream creates a upstream.
|
||||
// addr has the format of: [protocol://]host[:port][/path].
|
||||
// Supported protocol: udp/tcp/tls/https/quic. Default protocol is udp.
|
||||
//
|
||||
// Helper protocol:
|
||||
// - tcp+pipeline/tls+pipeline: Automatically set opt.EnablePipeline to true.
|
||||
// - h3: Automatically set opt.EnableHTTP3 to true.
|
||||
func NewUpstream(addr string, opt Opt) (_ Upstream, err error) {
|
||||
if opt.Logger == nil {
|
||||
opt.Logger = mlog.Nop()
|
||||
}
|
||||
if opt.EventObserver == nil {
|
||||
opt.EventObserver = nopEO{}
|
||||
}
|
||||
|
||||
// parse protocol and server addr
|
||||
if !strings.Contains(addr, "://") {
|
||||
addr = "udp://" + addr
|
||||
}
|
||||
addrURL, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid server address, %w", err)
|
||||
}
|
||||
|
||||
// Apply helper protocol
|
||||
switch addrURL.Scheme {
|
||||
case "tcp+pipeline", "tls+pipeline":
|
||||
addrURL.Scheme = addrURL.Scheme[:3]
|
||||
opt.EnablePipeline = true
|
||||
case "h3":
|
||||
addrURL.Scheme = "https"
|
||||
opt.EnableHTTP3 = true
|
||||
}
|
||||
|
||||
// If host is a ipv6 without port, it will be in []. This will cause err when
|
||||
// split and join address and port. Try to remove brackets now.
|
||||
addrUrlHost := tryTrimIpv6Brackets(addrURL.Host)
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Control: getSocketControlFunc(socketOpts{
|
||||
so_mark: opt.SoMark,
|
||||
bind_to_device: opt.BindToDevice,
|
||||
}),
|
||||
}
|
||||
|
||||
var bootstrapAp netip.AddrPort
|
||||
if s := opt.Bootstrap; len(s) > 0 {
|
||||
bootstrapAp, err = parseBootstrapAp(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid bootstrap, %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
newUdpAddrResolveFunc := func(defaultPort uint16) (func(ctx context.Context) (*net.UDPAddr, error), error) {
|
||||
host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if addr, err := netip.ParseAddr(host); err == nil { // host is an ip.
|
||||
ua := net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
|
||||
return func(ctx context.Context) (*net.UDPAddr, error) {
|
||||
return ua, nil
|
||||
}, nil
|
||||
} else { // Not an ip, assuming it's a domain name.
|
||||
if bootstrapAp.IsValid() {
|
||||
// Bootstrap enabled.
|
||||
bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(ctx context.Context) (*net.UDPAddr, error) {
|
||||
s, err := bs.GetAddrPortStr(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bootstrap failed, %w", err)
|
||||
}
|
||||
return net.ResolveUDPAddr("udp", s)
|
||||
}, nil
|
||||
} else {
|
||||
// Bootstrap disabled.
|
||||
dialAddr := joinPort(host, port)
|
||||
return func(ctx context.Context) (*net.UDPAddr, error) {
|
||||
return net.ResolveUDPAddr("udp", dialAddr)
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newTcpDialer := func(dialAddrMustBeIp bool, defaultPort uint16) (func(ctx context.Context) (net.Conn, error), error) {
|
||||
host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Socks5 enabled.
|
||||
if s5Addr := opt.Socks5; len(s5Addr) > 0 {
|
||||
socks5Dialer, err := proxy.SOCKS5("tcp", s5Addr, nil, dialer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init socks5 dialer: %w", err)
|
||||
}
|
||||
|
||||
contextDialer := socks5Dialer.(proxy.ContextDialer)
|
||||
dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
return contextDialer.DialContext(ctx, "tcp", dialAddr)
|
||||
}, nil
|
||||
}
|
||||
|
||||
if _, err := netip.ParseAddr(host); err == nil {
|
||||
// Host is an ip addr. No need to resolve it.
|
||||
dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", dialAddr)
|
||||
}, nil
|
||||
} else {
|
||||
if dialAddrMustBeIp {
|
||||
return nil, errors.New("addr must be an ip address")
|
||||
}
|
||||
// Host is not an ip addr, assuming it is a domain.
|
||||
if bootstrapAp.IsValid() {
|
||||
// Bootstrap enabled.
|
||||
bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
dialAddr, err := bs.GetAddrPortStr(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bootstrap failed, %w", err)
|
||||
}
|
||||
return dialer.DialContext(ctx, "tcp", dialAddr)
|
||||
}, nil
|
||||
} else {
|
||||
// Bootstrap disabled.
|
||||
dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", dialAddr)
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
closeIfFuncErr := func(c io.Closer) {
|
||||
if err != nil {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
switch addrURL.Scheme {
|
||||
case "", "udp":
|
||||
const defaultPort = 53
|
||||
const maxConcurrentQueryPreConn = 4096 // Protocol limit is 65535.
|
||||
host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := netip.ParseAddr(host); err != nil {
|
||||
return nil, fmt.Errorf("addr must be an ip address, %w", err)
|
||||
}
|
||||
dialAddr := joinPort(host, port)
|
||||
|
||||
dialUdpPipeline := func(ctx context.Context) (transport.DnsConn, error) {
|
||||
c, err := dialer.DialContext(ctx, "udp", dialAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
to := transport.TraditionalDnsConnOpts{
|
||||
WithLengthHeader: false,
|
||||
IdleTimeout: time.Minute * 5,
|
||||
MaxConcurrentQuery: maxConcurrentQueryPreConn,
|
||||
}
|
||||
return transport.NewDnsConn(to, wrapConn(c, opt.EventObserver)), nil
|
||||
}
|
||||
dialTcpNetConn := func(ctx context.Context) (transport.NetConn, error) {
|
||||
c, err := dialer.DialContext(ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wrapConn(c, opt.EventObserver), nil
|
||||
}
|
||||
|
||||
return &udpWithFallback{
|
||||
u: transport.NewPipelineTransport(transport.PipelineOpts{
|
||||
DialContext: dialUdpPipeline,
|
||||
MaxConcurrentQueryWhileDialing: maxConcurrentQueryPreConn,
|
||||
Logger: opt.Logger,
|
||||
}),
|
||||
t: transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialTcpNetConn}),
|
||||
}, nil
|
||||
case "tcp":
|
||||
const defaultPort = 53
|
||||
tcpDialer, err := newTcpDialer(true, defaultPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
|
||||
}
|
||||
idleTimeout := opt.IdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = time.Second * 10
|
||||
}
|
||||
|
||||
dialNetConn := func(ctx context.Context) (transport.NetConn, error) {
|
||||
c, err := tcpDialer(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wrapConn(c, opt.EventObserver), nil
|
||||
}
|
||||
if opt.EnablePipeline {
|
||||
to := transport.TraditionalDnsConnOpts{
|
||||
WithLengthHeader: true,
|
||||
IdleTimeout: idleTimeout,
|
||||
MaxConcurrentQuery: pipelineConcurrentLimit,
|
||||
}
|
||||
dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) {
|
||||
c, err := dialNetConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return transport.NewDnsConn(to, c), nil
|
||||
}
|
||||
return transport.NewPipelineTransport(transport.PipelineOpts{
|
||||
DialContext: dialDnsConn,
|
||||
MaxConcurrentQueryWhileDialing: pipelineConcurrentLimit,
|
||||
Logger: opt.Logger,
|
||||
}), nil
|
||||
}
|
||||
return transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialNetConn, IdleTimeout: idleTimeout}), nil
|
||||
case "tls":
|
||||
const defaultPort = 853
|
||||
tlsConfig := opt.TLSConfig.Clone()
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = new(tls.Config)
|
||||
}
|
||||
if len(tlsConfig.ServerName) == 0 {
|
||||
tlsConfig.ServerName = tryRemovePort(addrUrlHost)
|
||||
}
|
||||
|
||||
tcpDialer, err := newTcpDialer(false, defaultPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
|
||||
}
|
||||
|
||||
dialNetConn := func(ctx context.Context) (transport.NetConn, error) {
|
||||
conn, err := tcpDialer(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn = wrapConn(conn, opt.EventObserver)
|
||||
tlsConn := tls.Client(conn, tlsConfig)
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return wrapConn(tlsConn, opt.EventObserver), nil
|
||||
}
|
||||
|
||||
if opt.EnablePipeline {
|
||||
to := transport.TraditionalDnsConnOpts{
|
||||
WithLengthHeader: true,
|
||||
IdleTimeout: opt.IdleTimeout,
|
||||
MaxConcurrentQuery: pipelineConcurrentLimit,
|
||||
}
|
||||
dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) {
|
||||
c, err := dialNetConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return transport.NewDnsConn(to, c), nil
|
||||
}
|
||||
return transport.NewPipelineTransport(transport.PipelineOpts{
|
||||
DialContext: dialDnsConn,
|
||||
MaxConcurrentQueryWhileDialing: pipelineConcurrentLimit,
|
||||
Logger: opt.Logger,
|
||||
}), nil
|
||||
}
|
||||
return transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialNetConn}), nil
|
||||
case "https":
|
||||
const defaultPort = 443
|
||||
|
||||
idleConnTimeout := time.Second * 30
|
||||
if opt.IdleTimeout > 0 {
|
||||
idleConnTimeout = opt.IdleTimeout
|
||||
}
|
||||
|
||||
var t http.RoundTripper
|
||||
var addonCloser io.Closer
|
||||
if opt.EnableHTTP3 {
|
||||
udpBootstrap, err := newUdpAddrResolveFunc(defaultPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err)
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})}
|
||||
conn, err := lc.ListenPacket(context.Background(), "udp", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init udp socket for quic, %w", err)
|
||||
}
|
||||
quicTransport := &quic.Transport{
|
||||
Conn: conn,
|
||||
}
|
||||
quicConfig := newDefaultClientQuicConfig()
|
||||
quicConfig.MaxIdleTimeout = idleConnTimeout
|
||||
|
||||
defer closeIfFuncErr(quicTransport)
|
||||
addonCloser = quicTransport
|
||||
t = &http3.RoundTripper{
|
||||
TLSClientConfig: opt.TLSConfig,
|
||||
QUICConfig: quicConfig,
|
||||
Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
ua, err := udpBootstrap(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return quicTransport.DialEarly(ctx, ua, tlsCfg, cfg)
|
||||
},
|
||||
MaxResponseHeaderBytes: 4 * 1024,
|
||||
}
|
||||
} else {
|
||||
tcpDialer, err := newTcpDialer(false, defaultPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
|
||||
}
|
||||
t1 := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { // overwrite server addr
|
||||
c, err := tcpDialer(ctx)
|
||||
c = wrapConn(c, opt.EventObserver)
|
||||
return c, err
|
||||
},
|
||||
TLSClientConfig: opt.TLSConfig,
|
||||
TLSHandshakeTimeout: tlsHandshakeTimeout,
|
||||
IdleConnTimeout: idleConnTimeout,
|
||||
|
||||
// Following opts are for http/1 only.
|
||||
// MaxConnsPerHost: 2,
|
||||
// MaxIdleConnsPerHost: 2,
|
||||
}
|
||||
|
||||
t2, err := http2.ConfigureTransports(t1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to upgrade http2 support, %w", err)
|
||||
}
|
||||
t2.MaxHeaderListSize = 4 * 1024
|
||||
t2.MaxReadFrameSize = 16 * 1024
|
||||
t2.ReadIdleTimeout = time.Second * 30
|
||||
t2.PingTimeout = time.Second * 5
|
||||
t = t1
|
||||
}
|
||||
|
||||
u, err := doh.NewUpstream(addrURL.String(), t, opt.Logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create doh upstream, %w", err)
|
||||
}
|
||||
|
||||
return &dohWithClose{
|
||||
u: u,
|
||||
closer: addonCloser,
|
||||
}, nil
|
||||
case "quic", "doq":
|
||||
const defaultPort = 853
|
||||
tlsConfig := opt.TLSConfig.Clone()
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = new(tls.Config)
|
||||
}
|
||||
if len(tlsConfig.ServerName) == 0 {
|
||||
tlsConfig.ServerName = tryRemovePort(addrUrlHost)
|
||||
}
|
||||
tlsConfig.NextProtos = []string{"doq"}
|
||||
|
||||
quicConfig := newDefaultClientQuicConfig()
|
||||
if opt.IdleTimeout > 0 {
|
||||
quicConfig.MaxIdleTimeout = opt.IdleTimeout
|
||||
}
|
||||
// Don't accept stream.
|
||||
quicConfig.MaxIncomingStreams = -1
|
||||
quicConfig.MaxIncomingUniStreams = -1
|
||||
|
||||
udpBootstrap, err := newUdpAddrResolveFunc(defaultPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err)
|
||||
}
|
||||
|
||||
srk, _, err := utils.InitQUICSrkFromIfaceMac()
|
||||
if err != nil {
|
||||
opt.Logger.Warn("failed to init quic stateless reset key, it will be disabled", zap.Error(err))
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})}
|
||||
uc, err := lc.ListenPacket(context.Background(), "udp", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init udp socket for quic, %w", err)
|
||||
}
|
||||
|
||||
t := &quic.Transport{
|
||||
Conn: uc,
|
||||
StatelessResetKey: (*quic.StatelessResetKey)(srk),
|
||||
}
|
||||
|
||||
dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) {
|
||||
ua, err := udpBootstrap(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bootstrap failed, %w", err)
|
||||
}
|
||||
|
||||
// This is a workaround to
|
||||
// 1. recover from strange 0rtt rejected err.
|
||||
// 2. avoid NextConnection might block forever.
|
||||
// TODO: Remove this workaround.
|
||||
var c quic.Connection
|
||||
ec, err := t.DialEarly(ctx, ua, tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err = ec.NextConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return transport.NewQuicDnsConn(c), nil
|
||||
}
|
||||
|
||||
return transport.NewPipelineTransport(transport.PipelineOpts{
|
||||
DialContext: dialDnsConn,
|
||||
// Quic rfc recommendation is 100. Some implications use 65535.
|
||||
MaxConcurrentQueryWhileDialing: 90,
|
||||
Logger: opt.Logger,
|
||||
}), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
type udpWithFallback struct {
|
||||
u *transport.PipelineTransport
|
||||
t *transport.ReuseConnTransport
|
||||
}
|
||||
|
||||
func (u *udpWithFallback) ExchangeContext(ctx context.Context, q []byte) (*[]byte, error) {
|
||||
r, err := u.u.ExchangeContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msgTruncated(*r) {
|
||||
pool.ReleaseBuf(r)
|
||||
return u.t.ExchangeContext(ctx, q)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (u *udpWithFallback) Close() error {
|
||||
u.u.Close()
|
||||
u.t.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type dohWithClose struct {
|
||||
u *doh.Upstream
|
||||
closer io.Closer // maybe nil
|
||||
}
|
||||
|
||||
func (u *dohWithClose) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) {
|
||||
return u.u.ExchangeContext(ctx, m)
|
||||
}
|
||||
|
||||
func (u *dohWithClose) Close() error {
|
||||
if u.closer != nil {
|
||||
return u.closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDefaultClientQuicConfig() *quic.Config {
|
||||
return &quic.Config{
|
||||
TokenStore: quic.NewLRUTokenStore(4, 8),
|
||||
|
||||
// Dns does not need large amount of io, so the rx/tx windows are small.
|
||||
InitialStreamReceiveWindow: 4 * 1024,
|
||||
MaxStreamReceiveWindow: 4 * 1024,
|
||||
InitialConnectionReceiveWindow: 8 * 1024,
|
||||
MaxConnectionReceiveWindow: 64 * 1024,
|
||||
|
||||
MaxIdleTimeout: time.Second * 30,
|
||||
KeepAlivePeriod: time.Second * 25,
|
||||
HandshakeIdleTimeout: tlsHandshakeTimeout,
|
||||
}
|
||||
}
|
||||
233
pkg/upstream/upstream_test.go
Normal file
233
pkg/upstream/upstream_test.go
Normal file
@ -0,0 +1,233 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func newUDPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
|
||||
udpConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
udpAddr := udpConn.LocalAddr().String()
|
||||
udpServer := dns.Server{
|
||||
PacketConn: udpConn,
|
||||
Handler: handler,
|
||||
}
|
||||
go udpServer.ActivateAndServe()
|
||||
return udpAddr, func() {
|
||||
udpServer.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
func newTCPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tcpAddr := l.Addr().String()
|
||||
tcpServer := dns.Server{
|
||||
Listener: l,
|
||||
Handler: handler,
|
||||
MaxTCPQueries: -1,
|
||||
}
|
||||
go tcpServer.ActivateAndServe()
|
||||
return tcpAddr, func() {
|
||||
tcpServer.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
func newDoTTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
|
||||
serverName := "test"
|
||||
cert, err := utils.GenerateCertificate(serverName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tlsConfig := new(tls.Config)
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
doTAddr := tlsListener.Addr().String()
|
||||
doTServer := dns.Server{
|
||||
Net: "tcp-tls",
|
||||
Listener: tlsListener,
|
||||
TLSConfig: tlsConfig,
|
||||
Handler: handler,
|
||||
MaxTCPQueries: -1,
|
||||
}
|
||||
go doTServer.ActivateAndServe()
|
||||
return doTAddr, func() {
|
||||
doTServer.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
type newTestServerFunc func(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func())
|
||||
|
||||
var m = map[string]newTestServerFunc{
|
||||
"udp": newUDPTestServer,
|
||||
"tcp": newTCPTestServer,
|
||||
"tls": newDoTTestServer,
|
||||
}
|
||||
|
||||
func Test_fastUpstream(t *testing.T) {
|
||||
|
||||
// TODO: add test for doh
|
||||
// TODO: add test for socks5
|
||||
|
||||
// server config
|
||||
for scheme, f := range m {
|
||||
for _, bigMsg := range [...]bool{true, false} {
|
||||
for _, latency := range [...]time.Duration{0, time.Millisecond * 10} {
|
||||
|
||||
// client specific
|
||||
for _, idleTimeout := range [...]time.Duration{0, time.Second} {
|
||||
|
||||
testName := fmt.Sprintf(
|
||||
"test: protocol: %s, bigMsg: %v, latency: %s, getIdleTimeout: %s",
|
||||
scheme,
|
||||
bigMsg,
|
||||
latency,
|
||||
idleTimeout,
|
||||
)
|
||||
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
addr, shutdownServer := f(t, &vServer{
|
||||
latency: latency,
|
||||
bigMsg: bigMsg,
|
||||
})
|
||||
defer shutdownServer()
|
||||
u, err := NewUpstream(
|
||||
scheme+"://"+addr,
|
||||
Opt{
|
||||
IdleTimeout: time.Second,
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := testUpstream(u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func testUpstream(u Upstream) error {
|
||||
wg := sync.WaitGroup{}
|
||||
errs := make([]error, 0)
|
||||
errsLock := sync.Mutex{}
|
||||
logErr := func(err error) {
|
||||
errsLock.Lock()
|
||||
errs = append(errs, err)
|
||||
errsLock.Unlock()
|
||||
}
|
||||
errsToString := func() string {
|
||||
s := fmt.Sprintf("%d err(s) occured during the test: ", len(errs))
|
||||
for i := range errs {
|
||||
s = s + errs[i].Error() + "|"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
for i := uint16(0); i < 10; i++ {
|
||||
wg.Add(1)
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("example.com.", dns.TypeA)
|
||||
q.Id = i
|
||||
queryPayload, err := q.Pack()
|
||||
if err != nil {
|
||||
logErr(err)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := u.ExchangeContext(ctx, queryPayload)
|
||||
if err != nil {
|
||||
logErr(err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := new(dns.Msg)
|
||||
err = resp.Unpack(*r)
|
||||
if err != nil {
|
||||
logErr(err)
|
||||
return
|
||||
}
|
||||
if q.Id != resp.Id {
|
||||
logErr(dns.ErrId)
|
||||
return
|
||||
}
|
||||
if !resp.Response {
|
||||
logErr(fmt.Errorf("resp is not a resp bit"))
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if len(errs) != 0 {
|
||||
return errors.New(errsToString())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type vServer struct {
|
||||
latency time.Duration
|
||||
bigMsg bool // with 1kb padding
|
||||
}
|
||||
|
||||
var padding = make([]byte, 1024)
|
||||
|
||||
func (s *vServer) ServeDNS(w dns.ResponseWriter, q *dns.Msg) {
|
||||
r := new(dns.Msg)
|
||||
r.SetReply(q)
|
||||
if s.bigMsg {
|
||||
r.SetEdns0(dns.MaxMsgSize, false)
|
||||
opt := r.IsEdns0()
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_PADDING{Padding: padding})
|
||||
}
|
||||
|
||||
time.Sleep(s.latency)
|
||||
w.WriteMsg(r)
|
||||
}
|
||||
104
pkg/upstream/utils.go
Normal file
104
pkg/upstream/utils.go
Normal file
@ -0,0 +1,104 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type socketOpts struct {
|
||||
so_mark int
|
||||
bind_to_device string
|
||||
}
|
||||
|
||||
func parseDialAddr(urlHost, dialAddr string, defaultPort uint16) (string, uint16, error) {
|
||||
addr := urlHost
|
||||
if len(dialAddr) > 0 {
|
||||
addr = dialAddr
|
||||
}
|
||||
host, port, err := trySplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if port == 0 {
|
||||
port = defaultPort
|
||||
}
|
||||
return host, port, nil
|
||||
}
|
||||
|
||||
func joinPort(host string, port uint16) string {
|
||||
return net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
}
|
||||
|
||||
func tryRemovePort(s string) string {
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return s
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// trySplitHostPort splits host and port.
|
||||
// If s has no port, it returns s,0,nil
|
||||
func trySplitHostPort(s string) (string, uint16, error) {
|
||||
var port uint16
|
||||
host, portS, err := net.SplitHostPort(s)
|
||||
if err == nil {
|
||||
n, err := strconv.ParseUint(portS, 10, 16)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("invalid port, %w", err)
|
||||
}
|
||||
port = uint16(n)
|
||||
return host, port, nil
|
||||
}
|
||||
return s, 0, nil
|
||||
}
|
||||
|
||||
func parseBootstrapAp(s string) (netip.AddrPort, error) {
|
||||
host, port, err := trySplitHostPort(s)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
if port == 0 {
|
||||
port = 53
|
||||
}
|
||||
addr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
return netip.AddrPortFrom(addr, port), nil
|
||||
}
|
||||
|
||||
func tryTrimIpv6Brackets(s string) string {
|
||||
if len(s) < 2 {
|
||||
return s
|
||||
}
|
||||
if s[0] == '[' && s[len(s)-1] == ']' {
|
||||
return s[1 : len(s)-2]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func msgTruncated(b []byte) bool {
|
||||
return b[2]&(1<<1) != 0
|
||||
}
|
||||
28
pkg/upstream/utils_others.go
Normal file
28
pkg/upstream/utils_others.go
Normal file
@ -0,0 +1,28 @@
|
||||
//go:build !linux
|
||||
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package upstream
|
||||
|
||||
import "syscall"
|
||||
|
||||
func getSocketControlFunc(_ socketOpts) func(string, string, syscall.RawConn) error {
|
||||
return nil
|
||||
}
|
||||
58
pkg/upstream/utils_unix.go
Normal file
58
pkg/upstream/utils_unix.go
Normal file
@ -0,0 +1,58 @@
|
||||
//go:build linux
|
||||
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func getSocketControlFunc(opts socketOpts) func(string, string, syscall.RawConn) error {
|
||||
return func(_, _ string, c syscall.RawConn) error {
|
||||
var sysCallErr error
|
||||
if err := c.Control(func(fd uintptr) {
|
||||
// SO_MARK
|
||||
if opts.so_mark > 0 {
|
||||
sysCallErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.so_mark)
|
||||
if sysCallErr != nil {
|
||||
sysCallErr = os.NewSyscallError("failed to set SO_MARK", sysCallErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// SO_BINDTODEVICE
|
||||
if len(opts.bind_to_device) > 0 {
|
||||
sysCallErr = unix.SetsockoptString(int(fd), unix.SOL_SOCKET, unix.SO_BINDTODEVICE, opts.bind_to_device)
|
||||
if sysCallErr != nil {
|
||||
sysCallErr = os.NewSyscallError("failed to set SO_BINDTODEVICE", sysCallErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return sysCallErr
|
||||
}
|
||||
}
|
||||
77
pkg/utils/config_helper.go
Normal file
77
pkg/utils/config_helper.go
Normal file
@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"golang.org/x/exp/constraints"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func SetDefaultNum[K constraints.Integer | constraints.Float](p *K, d K) {
|
||||
if *p == 0 {
|
||||
*p = d
|
||||
}
|
||||
}
|
||||
|
||||
func SetDefaultUnsignNum[K constraints.Integer | constraints.Float](p *K, d K) {
|
||||
if *p <= 0 {
|
||||
*p = d
|
||||
}
|
||||
}
|
||||
|
||||
func SetDefaultString(p *string, d string) {
|
||||
if len(*p) == 0 {
|
||||
*p = d
|
||||
}
|
||||
}
|
||||
|
||||
func CheckNumRange[K constraints.Integer | constraints.Float](v, min, max K) bool {
|
||||
if v < min || v > max {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// WeakDecode decodes args from config to output.
|
||||
func WeakDecode(in any, output any) error {
|
||||
config := &mapstructure.DecoderConfig{
|
||||
ErrorUnused: true,
|
||||
Result: output,
|
||||
WeaklyTypedInput: true,
|
||||
TagName: "yaml",
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return decoder.Decode(in)
|
||||
}
|
||||
|
||||
func ParseNameOrNum[T constraints.Integer](s string, m map[string]T) (T, bool) {
|
||||
i, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
v, ok := m[s]
|
||||
return v, ok
|
||||
}
|
||||
return T(i), true
|
||||
}
|
||||
58
pkg/utils/net.go
Normal file
58
pkg/utils/net.go
Normal file
@ -0,0 +1,58 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// GetIPFromAddr returns a net.IP from the given net.Addr.
|
||||
// addr can be *net.TCPAddr, *net.UDPAddr, *net.IPNet, *net.IPAddr
|
||||
// Will return nil otherwise.
|
||||
func GetIPFromAddr(addr net.Addr) (ip net.IP) {
|
||||
switch v := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return v.IP
|
||||
case *net.UDPAddr:
|
||||
return v.IP
|
||||
case *net.IPNet:
|
||||
return v.IP
|
||||
case *net.IPAddr:
|
||||
return v.IP
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAddrFromAddr returns netip.Addr from net.Addr.
|
||||
// See also: GetIPFromAddr.
|
||||
func GetAddrFromAddr(addr net.Addr) netip.Addr {
|
||||
a, _ := netip.AddrFromSlice(GetIPFromAddr(addr))
|
||||
return a
|
||||
}
|
||||
|
||||
// SplitSchemeAndHost splits addr to protocol and host.
|
||||
func SplitSchemeAndHost(addr string) (protocol, host string) {
|
||||
if protocol, host, ok := SplitString2(addr, "://"); ok {
|
||||
return protocol, host
|
||||
} else {
|
||||
return "", addr
|
||||
}
|
||||
}
|
||||
52
pkg/utils/quic.go
Normal file
52
pkg/utils/quic.go
Normal file
@ -0,0 +1,52 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var quicSrkSalt = []byte{115, 189, 156, 229, 145, 216, 251, 127, 220, 89,
|
||||
243, 234, 211, 79, 190, 166, 135, 253, 183, 36, 245, 174, 78, 200, 54, 213,
|
||||
85, 255, 104, 240, 103, 27}
|
||||
|
||||
var (
|
||||
quicSrkInitOnce sync.Once
|
||||
quicSrk *[32]byte
|
||||
quicSrkFromIface net.Interface
|
||||
quicSrkInitErr error
|
||||
)
|
||||
|
||||
func initQUICSrkFromIfaceMac() {
|
||||
nonZero := func(b []byte) bool {
|
||||
for _, i := range b {
|
||||
if i != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
quicSrkInitErr = err
|
||||
return
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
if nonZero(iface.HardwareAddr) {
|
||||
k := sha256.Sum256(append(iface.HardwareAddr, quicSrkSalt...))
|
||||
quicSrk = &k
|
||||
quicSrkFromIface = iface
|
||||
return
|
||||
}
|
||||
}
|
||||
quicSrkInitErr = errors.New("cannot find non-zero mac interface")
|
||||
}
|
||||
|
||||
// A helper func to init quic stateless reset key.
|
||||
// It use the first non-zero interface mac + sha256 hash.
|
||||
func InitQUICSrkFromIfaceMac() (*[32]byte, net.Interface, error) {
|
||||
quicSrkInitOnce.Do(initQUICSrkFromIfaceMac)
|
||||
return quicSrk, quicSrkFromIface, quicSrkInitErr
|
||||
}
|
||||
20
pkg/utils/server.go
Normal file
20
pkg/utils/server.go
Normal file
@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package utils
|
||||
49
pkg/utils/strings.go
Normal file
49
pkg/utils/strings.go
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// BytesToStringUnsafe converts bytes to string.
|
||||
func BytesToStringUnsafe(b []byte) string {
|
||||
return unsafe.String(unsafe.SliceData(b), len(b))
|
||||
}
|
||||
|
||||
// RemoveComment removes comment after "symbol".
|
||||
func RemoveComment(s, symbol string) string {
|
||||
if i := strings.Index(s, symbol); i >= 0 {
|
||||
return s[:i]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// SplitString2 split s to two parts by given symbol
|
||||
func SplitString2(s, symbol string) (s1 string, s2 string, ok bool) {
|
||||
if len(symbol) == 0 {
|
||||
return "", s, true
|
||||
}
|
||||
if i := strings.Index(s, symbol); i >= 0 {
|
||||
return s[:i], s[i+len(symbol):], true
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
107
pkg/utils/utils.go
Normal file
107
pkg/utils/utils.go
Normal file
@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LoadCertPool reads and loads certificates in certs.
|
||||
func LoadCertPool(certs []string) (*x509.CertPool, error) {
|
||||
rootCAs := x509.NewCertPool()
|
||||
for _, cert := range certs {
|
||||
b, err := os.ReadFile(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok := rootCAs.AppendCertsFromPEM(b); !ok {
|
||||
return nil, fmt.Errorf("no certificate was successfully parsed in %s", cert)
|
||||
}
|
||||
}
|
||||
return rootCAs, nil
|
||||
}
|
||||
|
||||
// GenerateCertificate generates an ecdsa certificate with given dnsName.
|
||||
// This should only use in test.
|
||||
func GenerateCertificate(dnsName string) (cert tls.Certificate, err error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
//serial number
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("generate serial number: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{CommonName: dnsName},
|
||||
DNSNames: []string{dnsName},
|
||||
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
return tls.X509KeyPair(certPEM, keyPEM)
|
||||
}
|
||||
|
||||
// ClosedChan returns true if c is closed.
|
||||
// c must not use for sending data and must be used in close() only.
|
||||
// If ClosedChan receives something from c, it panics.
|
||||
func ClosedChan(c chan struct{}) bool {
|
||||
select {
|
||||
case _, ok := <-c:
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
panic("received from the chan")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
123
pkg/utils/utils_test.go
Normal file
123
pkg/utils/utils_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitString2(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
symbol string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantS1 string
|
||||
wantS2 string
|
||||
wantOk bool
|
||||
}{
|
||||
{"blank", args{"", ""}, "", "", true},
|
||||
{"blank", args{"///", ""}, "", "///", true},
|
||||
{"split", args{"///", "/"}, "", "//", true},
|
||||
{"split", args{"--/", "/"}, "--", "", true},
|
||||
{"split", args{"https://***.***.***", "://"}, "https", "***.***.***", true},
|
||||
{"split", args{"://***.***.***", "://"}, "", "***.***.***", true},
|
||||
{"split", args{"https://", "://"}, "https", "", true},
|
||||
{"split", args{"--/", "*"}, "", "", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotS1, gotS2, gotOk := SplitString2(tt.args.s, tt.args.symbol)
|
||||
if gotS1 != tt.wantS1 {
|
||||
t.Errorf("SplitString2() gotS1 = %v, want %v", gotS1, tt.wantS1)
|
||||
}
|
||||
if gotS2 != tt.wantS2 {
|
||||
t.Errorf("SplitString2() gotS2 = %v, want %v", gotS2, tt.wantS2)
|
||||
}
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("SplitString2() gotOk = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveComment(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
symbol string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{name: "empty", args: args{s: "", symbol: ""}, want: ""},
|
||||
{name: "empty symbol", args: args{s: "12345", symbol: ""}, want: ""},
|
||||
{name: "empty string", args: args{s: "", symbol: "#"}, want: ""},
|
||||
{name: "remove 1", args: args{s: "123/456", symbol: "/"}, want: "123"},
|
||||
{name: "remove 2", args: args{s: "123//456", symbol: "//"}, want: "123"},
|
||||
{name: "remove 3", args: args{s: "123/*/456", symbol: "//"}, want: "123/*/456"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := RemoveComment(tt.args.s, tt.args.symbol); got != tt.want {
|
||||
t.Errorf("RemoveComment() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type TestArgsStruct struct {
|
||||
A string `yaml:"1"`
|
||||
B []int `yaml:"2"`
|
||||
}
|
||||
|
||||
func Test_WeakDecode(t *testing.T) {
|
||||
testObj := new(TestArgsStruct)
|
||||
testArgs := map[string]any{
|
||||
"1": "test",
|
||||
"2": []int{1, 2, 3},
|
||||
}
|
||||
wantObj := &TestArgsStruct{
|
||||
A: "test",
|
||||
B: []int{1, 2, 3},
|
||||
}
|
||||
|
||||
err := WeakDecode(testArgs, testObj)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(testObj, wantObj) {
|
||||
t.Fatalf("args decode failed, want %v, got %v", wantObj, testObj)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_WeakDecode2(t *testing.T) {
|
||||
testObj := new([]byte)
|
||||
args := []any{"1", 2, 3}
|
||||
err := WeakDecode(args, testObj)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
85
pkg/zone_file/zone_file.go
Normal file
85
pkg/zone_file/zone_file.go
Normal file
@ -0,0 +1,85 @@
|
||||
/*
|
||||
* Copyright (C) 2020-2022, IrineSistiana
|
||||
*
|
||||
* This file is part of mosdns.
|
||||
*
|
||||
* mosdns is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* mosdns is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package zone_file
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Matcher struct {
|
||||
m map[dns.Question][]dns.RR
|
||||
}
|
||||
|
||||
func (m *Matcher) LoadFile(s string) error {
|
||||
f, err := os.Open(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return m.Load(f)
|
||||
}
|
||||
|
||||
func (m *Matcher) Load(r io.Reader) error {
|
||||
if m.m == nil {
|
||||
m.m = make(map[dns.Question][]dns.RR)
|
||||
}
|
||||
|
||||
parser := dns.NewZoneParser(r, "", "")
|
||||
parser.SetDefaultTTL(3600)
|
||||
for {
|
||||
rr, ok := parser.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
h := rr.Header()
|
||||
q := dns.Question{
|
||||
Name: strings.ToLower(h.Name),
|
||||
Qtype: h.Rrtype,
|
||||
Qclass: h.Class,
|
||||
}
|
||||
m.m[q] = append(m.m[q], rr)
|
||||
}
|
||||
return parser.Err()
|
||||
}
|
||||
|
||||
func (m *Matcher) Search(q dns.Question) []dns.RR {
|
||||
q.Name = strings.ToLower(q.Name)
|
||||
return m.m[q]
|
||||
}
|
||||
|
||||
func (m *Matcher) Reply(q *dns.Msg) *dns.Msg {
|
||||
var r *dns.Msg
|
||||
for _, question := range q.Question {
|
||||
rr := m.Search(question)
|
||||
if rr != nil {
|
||||
if r == nil {
|
||||
r = new(dns.Msg)
|
||||
r.SetReply(q)
|
||||
}
|
||||
r.Answer = append(r.Answer, rr...)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user