新增Mikrotik API 插入解析ip
Some checks are pending
Test mosdns / build (push) Waiting to run

This commit is contained in:
dengxiongjian 2025-07-31 11:28:55 +08:00
commit cd761e8145
186 changed files with 20932 additions and 0 deletions

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

63
.github/ISSUE_TEMPLATE/bug-report.yml vendored Normal file
View File

@ -0,0 +1,63 @@
name: "Bug report"
description: "[反馈 Bug 必须用此模板] 程序不能运行?或者没有按照期望的那样工作?"
title: "[Bug] "
body:
- type: markdown
attributes:
value: "感谢热心的反馈 bug。描述越详细越有助于定位和解决 Bug。不提供有效信息的反馈可能会被直接关闭。"
- type: checkboxes
id: pre-check
attributes:
label: "在提交之前,请确认"
options:
- label: "我已经尝试搜索过 Issue ,但没有找到相关问题。"
required: true
- label: "我正在使用最新的 mosdns 版本(或者最新的 commit),问题依旧存在。"
required: true
- label: "我仔细看过 wiki 后仍然无法自行解决该问题。"
required: true
- label: "我非常确定这是 mosdns 核心的问题。(如果是通过第三方衍生软件使用 mosdns
核心,不确定问题源头时,请先向衍生软件开发者提交问题。)"
required: true
- type: input
id: version
attributes:
label: mosdns 版本
description: "不清楚可用 `mosdns version` 查看。"
placeholder: v9.9.9
validations:
required: true
- type: input
id: system
attributes:
label: 操作系统
placeholder: ubuntu
validations:
required: true
- type: textarea
id: what-happened
attributes:
label: Bug 描述和复现步骤
description: "描述越详细越有助于定位和解决 Bug。"
placeholder: "示例: Bug: mosdns 的 qname 匹配器无法匹配 xxxx 域名。复现方式: 使用如下配置,请求 xxxx.xxxx ,观察 log 发现匹配器没有匹配到。"
validations:
required: true
- type: textarea
id: config
attributes:
label: 使用的配置文件
render: yaml
description: "必须是完整的配置文件。不要只复制某个插件的配置片段。请尽可能的提供最小配置文件(能复现 bug但没有其他功能)方便开发者定位问题。"
validations:
required: true
- type: textarea
id: log
attributes:
label: mosdns 的 log 记录
render: txt

View File

@ -0,0 +1,16 @@
name: "Feature request"
description: "希望添加新功能"
title: "[Feature request] "
body:
- type: markdown
attributes:
value: "感谢提出建议。提交的 issue 不一定会马上得到开发者的回复。开发者会根据社区的反应以及功能的重要性选择是否实现这个新功能。"
- type: textarea
id: new-feature
attributes:
label: 希望添加的功能
description: "请详细描述一下该功能作用,使用场景,是否有类似实现,文档等。"
placeholder: "希望支持 ... 可以更快的 ..."
validations:
required: true

View File

@ -0,0 +1,10 @@
---
name: Other questions
about: 不要在 Issue 里提问。有问题请进入 Discussions 讨论。
title: ''
labels: ''
assignees: ''
---
不要在 issue 里提问。有问题请进入 Discussions 讨论。

38
.github/workflows/release.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: Release mosdns
on:
push:
tags:
- 'v*'
jobs:
build-release:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
check-latest: true
cache: true
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Build
run: python ./release.py
env:
CGO_ENABLED: '0'
- name: Publish
uses: softprops/action-gh-release@v2
with:
files: './release/mosdns*.zip'
prerelease: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

24
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,24 @@
name: Test mosdns
on:
push:
pull_request:
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23'
check-latest: true
cache: true
- name: Build
run: go build -v ./...
- name: Test
run: go test -race -v ./...

25
.gitignore vendored Normal file
View File

@ -0,0 +1,25 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
vendor/
# release dir
release/
# ide
.vscode/
.idea/
# test utils
testutils/

12
Dockerfile Normal file
View File

@ -0,0 +1,12 @@
FROM golang:latest as builder
ARG CGO_ENABLED=0
COPY ./ /root/src/
WORKDIR /root/src/
RUN go build -ldflags "-s -w -X main.version=$(git describe --tags --long --always)" -trimpath -o mosdns
FROM alpine:latest
COPY --from=builder /root/src/mosdns /usr/bin/
RUN apk add --no-cache ca-certificates

14
Dockerfile_buildx Normal file
View File

@ -0,0 +1,14 @@
FROM --platform=${TARGETPLATFORM} golang:latest as builder
ARG CGO_ENABLED=0
COPY ./ /root/src/
WORKDIR /root/src/
RUN go build -ldflags "-s -w -X main.version=$(git describe --tags --long --always)" -trimpath -o mosdns
FROM --platform=${TARGETPLATFORM} alpine:latest
COPY --from=builder /root/src/mosdns /usr/bin/
RUN apk add --no-cache ca-certificates

674
LICENSE Normal file
View File

@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

7
README.md Normal file
View File

@ -0,0 +1,7 @@
# mosdns
功能概述、配置方式、教程等,详见: [wiki](https://irine-sistiana.gitbook.io/mosdns-wiki/)
下载预编译文件、更新日志,详见: [release](https://github.com/IrineSistiana/mosdns/releases)
docker 镜像: [docker hub](https://hub.docker.com/r/irinesistiana/mosdns)

192
config.yaml Normal file
View File

@ -0,0 +1,192 @@
# ============================================
# MosDNS v5 完整配置:中英文注释版
# ============================================
log:
level: info # 可选: debug/info/warn/error
# 管理 API可用于调试、监控
api:
http: "0.0.0.0:5535"
# 引入上游 DNS 配置文件(在 dns.yaml 中)
include: ['/opt/mosdns/dns.yaml']
plugins:
# ======================
# 域名/IP 匹配规则
# ======================
# GFW 域名列表(如 google.com
- tag: GFW_domains
type: domain_set
args:
files:
- "/opt/mosdns/config/geosite_tiktok.txt"
- "/opt/mosdns/config/gfwlist.out.txt"
# Amazon 域名列表(如 amazon.com
- tag: amazon_domains
type: domain_set
args:
files:
- "/opt/mosdns/config/geosite_amazon.txt"
- "/opt/mosdns/config/geosite_amazon-ads.txt"
- "/opt/mosdns/config/geosite_amazontrust.txt"
- "/opt/mosdns/config/amazon.txt"
# 中国大陆常用域名(如 .cn / baidu.com
- tag: CN_domains
type: domain_set
args:
files:
- "/opt/mosdns/config/domains.txt"
# 中国大陆 IP 列表
- tag: geoip_cn
type: ip_set
args:
files:
- "/opt/mosdns/config/cn.txt"
# ======================
# 缓存模块
# ======================
- tag: cache
type: cache
args:
size: 32768 # 最大缓存条目数
lazy_cache_ttl: 43200 # 默认缓存 TTL
# ======================
# 上游 DNS 定义
# ======================
# 国内 DNS fallback 模式
- tag: forward_local
type: fallback
args:
primary: cn-dns
secondary: cn-dns
threshold: 500
always_standby: true
# 国外 DNS fallback 模式
- tag: forward_remote
type: fallback
args:
primary: jp-dns
secondary: jp-dns
threshold: 500
always_standby: true
# 封装调用国内 DNS
- tag: forward_local_upstream
type: sequence
args:
- exec: prefer_ipv4
- exec: query_summary forward_local
- exec: $forward_local
# 封装调用国外 DNS
- tag: forward_remote_upstream
type: sequence
args:
- exec: prefer_ipv4
- exec: query_summary forward_remote
- exec: $forward_remote
# 如果已有响应,直接返回
- tag: has_resp_sequence
type: sequence
args:
- matches: has_resp
exec: accept
# ======================
# 查询逻辑
# ======================
# 拒绝无效查询(如 HTTPS 记录)
- tag: query_is_reject_domain
type: sequence
args:
- matches: qtype 65
exec: reject 3
# GFW 域名:强制走国外 DNS
- tag: query_is_foreign_domain
type: sequence
args:
- matches: qname $GFW_domains
exec: $forward_remote_upstream
- exec: query_summary gfw_domain
# 国内域名:强制走国内 DNS
- tag: query_is_cn_domain
type: sequence
args:
- matches: qname $CN_domains
exec: $forward_local_upstream
- exec: query_summary cn_domain
# Amazon 域名:走国外 DNS 并添加到 MikroTik
- tag: query_is_amazon_domain
type: sequence
args:
- matches: qname $amazon_domains
exec: $forward_remote_upstream
- exec: $mikrotik_amazon
- exec: query_summary amazon_domain
# 未知域名处理逻辑:
# 先查国内 DNS → 如果返回非 CN IP ⇒ fallback 到国外
- tag: query_unknown_fallback
type: sequence
args:
- exec: prefer_ipv4
- exec: $forward_local
- matches: resp_ip $geoip_cn
exec: accept
- exec: $forward_remote_upstream
- exec: query_summary fallback_to_remote
# ======================
# 主查询处理流程
# ======================
- tag: main_sequence
type: sequence
args:
- exec: $cache # 首先查缓存
- exec: $query_is_reject_domain
- exec: jump has_resp_sequence
- exec: $query_is_foreign_domain # gfwlist
- exec: jump has_resp_sequence
- exec: $query_is_cn_domain # 国内域名
- exec: jump has_resp_sequence
- exec: $query_is_amazon_domain # Amazon 域名(走国外 DNS + 添加到 MikroTik
- exec: jump has_resp_sequence
- exec: $query_unknown_fallback # 其他未知域名 fallback 流程
- exec: jump has_resp_sequence
# ======================
# 服务监听
# ======================
# UDP 监听
- tag: udp_server
type: udp_server
args:
entry: main_sequence
listen: ":5300"
# TCP 监听
- tag: tcp_server
type: tcp_server
args:
entry: main_sequence
listen: ":5300"

50
coremain/config.go Normal file
View File

@ -0,0 +1,50 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package coremain
import (
"github.com/IrineSistiana/mosdns/v5/mlog"
)
type Config struct {
Log mlog.LogConfig `yaml:"log"`
Include []string `yaml:"include"`
Plugins []PluginConfig `yaml:"plugins"`
API APIConfig `yaml:"api"`
}
// PluginConfig represents a plugin config
type PluginConfig struct {
// Tag for this plugin. Optional. If omitted, this plugin will
// be registered with a random tag.
Tag string `yaml:"tag"`
// Type, required.
Type string `yaml:"type"`
// Args, might be required by some plugins.
// The type of Args is depended on RegNewPluginFunc.
// If it's a map[string]any, it will be converted by mapstruct.
Args any `yaml:"args"`
}
type APIConfig struct {
HTTP string `yaml:"http"`
}

244
coremain/mosdns.go Normal file
View File

@ -0,0 +1,244 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package coremain
import (
"bytes"
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/v5/mlog"
"github.com/IrineSistiana/mosdns/v5/pkg/safe_close"
"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
"io"
"net/http"
"net/http/pprof"
)
type Mosdns struct {
logger *zap.Logger // non-nil logger.
// Plugins
plugins map[string]any
httpMux *chi.Mux
metricsReg *prometheus.Registry
sc *safe_close.SafeClose
}
// NewMosdns initializes a mosdns instance and its plugins.
func NewMosdns(cfg *Config) (*Mosdns, error) {
// Init logger.
lg, err := mlog.NewLogger(cfg.Log)
if err != nil {
return nil, fmt.Errorf("failed to init logger: %w", err)
}
m := &Mosdns{
logger: lg,
plugins: make(map[string]any),
httpMux: chi.NewRouter(),
metricsReg: newMetricsReg(),
sc: safe_close.NewSafeClose(),
}
// This must be called after m.httpMux and m.metricsReg been set.
m.initHttpMux()
// Start http api server
if httpAddr := cfg.API.HTTP; len(httpAddr) > 0 {
httpServer := &http.Server{
Addr: httpAddr,
Handler: m.httpMux,
}
m.sc.Attach(func(done func(), closeSignal <-chan struct{}) {
defer done()
errChan := make(chan error, 1)
go func() {
m.logger.Info("starting api http server", zap.String("addr", httpAddr))
errChan <- httpServer.ListenAndServe()
}()
select {
case err := <-errChan:
m.sc.SendCloseSignal(err)
case <-closeSignal:
_ = httpServer.Close()
}
})
}
// Load plugins.
// Close all plugins on signal.
// From here, call m.sc.SendCloseSignal() if any plugin failed to load.
m.sc.Attach(func(done func(), closeSignal <-chan struct{}) {
go func() {
defer done()
<-closeSignal
m.logger.Info("starting shutdown sequences")
for tag, p := range m.plugins {
if closer, _ := p.(io.Closer); closer != nil {
m.logger.Info("closing plugin", zap.String("tag", tag))
_ = closer.Close()
}
}
m.logger.Info("all plugins were closed")
}()
})
// Preset plugins
if err := m.loadPresetPlugins(); err != nil {
m.sc.SendCloseSignal(err)
_ = m.sc.WaitClosed()
return nil, err
}
// Plugins from config.
if err := m.loadPluginsFromCfg(cfg, 0); err != nil {
m.sc.SendCloseSignal(err)
_ = m.sc.WaitClosed()
return nil, err
}
m.logger.Info("all plugins are loaded")
return m, nil
}
// NewTestMosdnsWithPlugins returns a mosdns instance for testing.
func NewTestMosdnsWithPlugins(p map[string]any) *Mosdns {
return &Mosdns{
logger: mlog.Nop(),
httpMux: chi.NewRouter(),
plugins: p,
metricsReg: newMetricsReg(),
sc: safe_close.NewSafeClose(),
}
}
func (m *Mosdns) GetSafeClose() *safe_close.SafeClose {
return m.sc
}
// CloseWithErr is a shortcut for m.sc.SendCloseSignal
func (m *Mosdns) CloseWithErr(err error) {
m.sc.SendCloseSignal(err)
}
// Logger returns a non-nil logger.
func (m *Mosdns) Logger() *zap.Logger {
return m.logger
}
// GetPlugin returns a plugin.
func (m *Mosdns) GetPlugin(tag string) any {
return m.plugins[tag]
}
// GetMetricsReg returns a prometheus.Registerer with a prefix of "mosdns_"
func (m *Mosdns) GetMetricsReg() prometheus.Registerer {
return prometheus.WrapRegistererWithPrefix("mosdns_", m.metricsReg)
}
func (m *Mosdns) GetAPIRouter() *chi.Mux {
return m.httpMux
}
func (m *Mosdns) RegPluginAPI(tag string, mux *chi.Mux) {
m.httpMux.Mount("/plugins/"+tag, mux)
}
func newMetricsReg() *prometheus.Registry {
reg := prometheus.NewRegistry()
reg.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
reg.MustRegister(collectors.NewGoCollector())
return reg
}
// initHttpMux initializes api entries. It MUST be called after m.metricsReg being initialized.
func (m *Mosdns) initHttpMux() {
// Register metrics.
m.httpMux.Method(http.MethodGet, "/metrics", promhttp.HandlerFor(m.metricsReg, promhttp.HandlerOpts{}))
// Register pprof.
m.httpMux.Route("/debug/pprof", func(r chi.Router) {
r.Get("/*", pprof.Index)
r.Get("/cmdline", pprof.Cmdline)
r.Get("/profile", pprof.Profile)
r.Get("/symbol", pprof.Symbol)
r.Get("/trace", pprof.Trace)
})
// A helper page for invalid request.
invalidApiReqHelper := func(w http.ResponseWriter, req *http.Request) {
b := new(bytes.Buffer)
_, _ = fmt.Fprintf(b, "Invalid request %s %s\n\n", req.Method, req.RequestURI)
b.WriteString("Available api urls:\n")
_ = chi.Walk(m.httpMux, func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
b.WriteString(method)
b.WriteByte(' ')
b.WriteString(route)
b.WriteByte('\n')
return nil
})
_, _ = w.Write(b.Bytes())
}
m.httpMux.NotFound(invalidApiReqHelper)
m.httpMux.MethodNotAllowed(invalidApiReqHelper)
}
func (m *Mosdns) loadPresetPlugins() error {
for tag, f := range LoadNewPersetPluginFuncs() {
p, err := f(NewBP(tag, m))
if err != nil {
return fmt.Errorf("failed to init preset plugin %s, %w", tag, err)
}
m.plugins[tag] = p
}
return nil
}
// loadPluginsFromCfg loads plugins from this config. It follows include first.
func (m *Mosdns) loadPluginsFromCfg(cfg *Config, includeDepth int) error {
const maxIncludeDepth = 8
if includeDepth > maxIncludeDepth {
return errors.New("maximum include depth reached")
}
includeDepth++
// Follow include first.
for _, s := range cfg.Include {
subCfg, path, err := loadConfig(s)
if err != nil {
return fmt.Errorf("failed to read config from %s, %w", s, err)
}
m.logger.Info("load config", zap.String("file", path))
if err := m.loadPluginsFromCfg(subCfg, includeDepth); err != nil {
return fmt.Errorf("failed to load config from %s, %w", s, err)
}
}
for i, pc := range cfg.Plugins {
if err := m.newPlugin(pc); err != nil {
return fmt.Errorf("failed to init plugin #%d %s, %w", i, pc.Tag, err)
}
}
return nil
}

201
coremain/plugin.go Normal file
View File

@ -0,0 +1,201 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package coremain
import (
"fmt"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/go-chi/chi/v5"
"go.uber.org/zap"
"reflect"
"sync"
)
// NewPluginArgsFunc represents a func that creates a new args object.
type NewPluginArgsFunc func() any
// NewPluginFunc represents a func that can init a Plugin.
// args is the object created by NewPluginArgsFunc.
type NewPluginFunc func(bp *BP, args any) (p any, err error)
type PluginTypeInfo struct {
NewPlugin NewPluginFunc
NewArgs NewPluginArgsFunc
}
var (
// pluginTypeRegister stores init funcs for certain plugin types
pluginTypeRegister struct {
sync.RWMutex
m map[string]PluginTypeInfo
}
)
// RegNewPluginFunc registers the type.
// If the type has been registered. RegNewPluginFunc will panic.
func RegNewPluginFunc(typ string, initFunc NewPluginFunc, argsType NewPluginArgsFunc) {
pluginTypeRegister.Lock()
defer pluginTypeRegister.Unlock()
if pluginTypeRegister.m == nil {
pluginTypeRegister.m = make(map[string]PluginTypeInfo)
}
_, ok := pluginTypeRegister.m[typ]
if ok {
panic(fmt.Sprintf("duplicate plugin type [%s]", typ))
}
pluginTypeRegister.m[typ] = PluginTypeInfo{
NewPlugin: initFunc,
NewArgs: argsType,
}
}
// DelPluginType deletes the init func for this plugin type.
// It is a noop if pluginType is not registered.
func DelPluginType(typ string) {
pluginTypeRegister.Lock()
defer pluginTypeRegister.Unlock()
delete(pluginTypeRegister.m, typ)
}
// GetPluginType gets the registered type init func.
func GetPluginType(typ string) (PluginTypeInfo, bool) {
pluginTypeRegister.RLock()
defer pluginTypeRegister.RUnlock()
info, ok := pluginTypeRegister.m[typ]
return info, ok
}
// newPlugin initializes a Plugin from c and adds it to mosdns.
func (m *Mosdns) newPlugin(c PluginConfig) error {
if len(c.Tag) == 0 {
c.Tag = fmt.Sprintf("anonymouse_%s_%d", c.Type, len(m.plugins))
}
if _, dup := m.plugins[c.Tag]; dup {
return fmt.Errorf("duplicated plugin tag %s", c.Tag)
}
typeInfo, ok := GetPluginType(c.Type)
if !ok {
return fmt.Errorf("plugin type %s not defined", c.Type)
}
args := typeInfo.NewArgs()
if reflect.TypeOf(c.Args) == reflect.TypeOf(args) { // Same type, no need to parse.
args = c.Args
} else {
if err := utils.WeakDecode(c.Args, args); err != nil {
return fmt.Errorf("unable to decode plugin args: %w", err)
}
}
m.logger.Info("loading plugin", zap.String("tag", c.Tag), zap.String("type", c.Type))
p, err := typeInfo.NewPlugin(NewBP(c.Tag, m), args)
if err != nil {
return fmt.Errorf("failed to init plugin: %w", err)
}
m.plugins[c.Tag] = p
return nil
}
// GetAllPluginTypes returns all plugin types which are configurable.
func GetAllPluginTypes() []string {
pluginTypeRegister.RLock()
defer pluginTypeRegister.RUnlock()
var t []string
for typ := range pluginTypeRegister.m {
t = append(t, typ)
}
return t
}
type NewPersetPluginFunc func(bp *BP) (any, error)
var presetPluginFuncReg struct {
sync.Mutex
m map[string]NewPersetPluginFunc
}
func RegNewPersetPluginFunc(tag string, f NewPersetPluginFunc) {
presetPluginFuncReg.Lock()
defer presetPluginFuncReg.Unlock()
if presetPluginFuncReg.m == nil {
presetPluginFuncReg.m = make(map[string]NewPersetPluginFunc)
}
if _, ok := presetPluginFuncReg.m[tag]; ok {
panic(fmt.Sprintf("preset plugin %s has already been registered", tag))
}
presetPluginFuncReg.m[tag] = f
}
func LoadNewPersetPluginFuncs() map[string]NewPersetPluginFunc {
presetPluginFuncReg.Lock()
defer presetPluginFuncReg.Unlock()
m := make(map[string]NewPersetPluginFunc)
for tag, pluginFunc := range presetPluginFuncReg.m {
m[tag] = pluginFunc
}
return m
}
type BP struct {
tag string
m *Mosdns
l *zap.Logger
}
// NewBP creates a new BP. m MUST NOT nil.
func NewBP(tag string, m *Mosdns) *BP {
return &BP{
tag: tag,
l: m.Logger().Named(tag),
m: m,
}
}
// L returns a non-nil logger.
func (p *BP) L() *zap.Logger {
return p.l
}
// M returns a non-nil Mosdns.
func (p *BP) M() *Mosdns {
return p.m
}
// Tag returns the plugin tag.
// This tag should be unique globally unless it's in
// a test environment.
func (p *BP) Tag() string {
return p.tag
}
// RegAPI mounts mux to mosdns api. Note: Plugins MUST NOT call RegAPI twice.
// Since mounting same path to root chi.Mux causes runtime panic.
func (p *BP) RegAPI(mux *chi.Mux) {
p.m.RegPluginAPI(p.tag, mux)
}

159
coremain/run.go Normal file
View File

@ -0,0 +1,159 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package coremain
import (
"fmt"
"github.com/IrineSistiana/mosdns/v5/mlog"
"github.com/kardianos/service"
"github.com/mitchellh/mapstructure"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.uber.org/zap"
"os"
"os/signal"
"runtime"
"syscall"
)
type serverFlags struct {
c string
dir string
cpu int
asService bool
}
var rootCmd = &cobra.Command{
Use: "mosdns",
}
func init() {
sf := new(serverFlags)
startCmd := &cobra.Command{
Use: "start [-c config_file] [-d working_dir]",
Short: "Start mosdns main program.",
RunE: func(cmd *cobra.Command, args []string) error {
if sf.asService {
svc, err := service.New(&serverService{f: sf}, svcCfg)
if err != nil {
return fmt.Errorf("failed to init service, %w", err)
}
return svc.Run()
}
m, err := NewServer(sf)
if err != nil {
return err
}
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
sig := <-c
m.logger.Warn("signal received", zap.Stringer("signal", sig))
m.sc.SendCloseSignal(nil)
}()
return m.GetSafeClose().WaitClosed()
},
DisableFlagsInUseLine: true,
SilenceUsage: true,
}
rootCmd.AddCommand(startCmd)
fs := startCmd.Flags()
fs.StringVarP(&sf.c, "config", "c", "", "config file")
fs.StringVarP(&sf.dir, "dir", "d", "", "working dir")
fs.IntVar(&sf.cpu, "cpu", 0, "set runtime.GOMAXPROCS")
fs.BoolVar(&sf.asService, "as-service", false, "start as a service")
_ = fs.MarkHidden("as-service")
serviceCmd := &cobra.Command{
Use: "service",
Short: "Manage mosdns as a system service.",
}
serviceCmd.PersistentPreRunE = initService
serviceCmd.AddCommand(
newSvcInstallCmd(),
newSvcUninstallCmd(),
newSvcStartCmd(),
newSvcStopCmd(),
newSvcRestartCmd(),
newSvcStatusCmd(),
)
rootCmd.AddCommand(serviceCmd)
}
func AddSubCmd(c *cobra.Command) {
rootCmd.AddCommand(c)
}
func Run() error {
return rootCmd.Execute()
}
func NewServer(sf *serverFlags) (*Mosdns, error) {
if sf.cpu > 0 {
runtime.GOMAXPROCS(sf.cpu)
}
if len(sf.dir) > 0 {
err := os.Chdir(sf.dir)
if err != nil {
return nil, fmt.Errorf("failed to change the current working directory, %w", err)
}
mlog.L().Info("working directory changed", zap.String("path", sf.dir))
}
cfg, fileUsed, err := loadConfig(sf.c)
if err != nil {
return nil, fmt.Errorf("fail to load config, %w", err)
}
mlog.L().Info("main config loaded", zap.String("file", fileUsed))
return NewMosdns(cfg)
}
// loadConfig load a config from a file. If filePath is empty, it will
// automatically search and load a file which name start with "config".
func loadConfig(filePath string) (*Config, string, error) {
v := viper.New()
if len(filePath) > 0 {
v.SetConfigFile(filePath)
} else {
v.SetConfigName("config")
v.AddConfigPath(".")
}
if err := v.ReadInConfig(); err != nil {
return nil, "", fmt.Errorf("failed to read config: %w", err)
}
decoderOpt := func(cfg *mapstructure.DecoderConfig) {
cfg.ErrorUnused = true
cfg.TagName = "yaml"
cfg.WeaklyTypedInput = true
}
cfg := new(Config)
if err := v.Unmarshal(cfg, decoderOpt); err != nil {
return nil, "", fmt.Errorf("failed to unmarshal config: %w", err)
}
return cfg, v.ConfigFileUsed(), nil
}

220
coremain/service.go Normal file
View File

@ -0,0 +1,220 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package coremain
import (
"fmt"
"github.com/IrineSistiana/mosdns/v5/mlog"
"github.com/kardianos/service"
"github.com/spf13/cobra"
"go.uber.org/zap"
"os"
"path/filepath"
"time"
)
var (
// initialized by "service" sub command
svc service.Service
svcCfg = &service.Config{
Name: "mosdns",
DisplayName: "mosdns",
Description: "A DNS forwarder",
}
)
type serverService struct {
f *serverFlags
m *Mosdns
}
func (ss *serverService) Start(s service.Service) error {
mlog.L().Info("starting service", zap.String("platform", s.Platform()))
m, err := NewServer(ss.f)
if err != nil {
return err
}
ss.m = m
go func() {
err := m.GetSafeClose().WaitClosed()
if err != nil {
m.Logger().Fatal("server exited", zap.Error(err))
} else {
m.Logger().Info("server exited")
}
}()
return nil
}
func (ss *serverService) Stop(_ service.Service) error {
ss.m.Logger().Info("service is shutting down")
ss.m.GetSafeClose().SendCloseSignal(nil)
return ss.m.GetSafeClose().WaitClosed()
}
// initService will init svc for sub command "service"
func initService(_ *cobra.Command, _ []string) error {
s, err := service.New(&serverService{}, svcCfg)
if err != nil {
return fmt.Errorf("cannot init service, %w", err)
}
svc = s
return nil
}
func newSvcInstallCmd() *cobra.Command {
sf := new(serverFlags)
c := &cobra.Command{
Use: "install [-d working_dir] [-c config_file]",
Short: "Install mosdns as a system service.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
if len(sf.dir) > 0 {
absWd, err := filepath.Abs(sf.dir)
if err != nil {
return fmt.Errorf("cannot solve absolute working dir path, %w", err)
} else {
sf.dir = absWd
}
} else {
ep, err := os.Executable()
if err != nil {
return fmt.Errorf("cannot solve current executable path, %w", err)
}
sf.dir = filepath.Dir(ep)
}
mlog.S().Infof("set service working dir as %s", sf.dir)
svcCfg.Arguments = []string{"start", "--as-service", "-d", sf.dir}
if len(sf.c) > 0 {
svcCfg.Arguments = append(svcCfg.Arguments, "-c", sf.c)
}
return svc.Install()
},
DisableFlagsInUseLine: true,
SilenceUsage: true,
}
c.Flags().StringVarP(&sf.dir, "dir", "d", "", "working dir")
c.Flags().StringVarP(&sf.c, "config", "c", "", "config path")
return c
}
func newSvcUninstallCmd() *cobra.Command {
c := &cobra.Command{
Use: "uninstall",
Short: "Uninstall mosdns from system service.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
return svc.Uninstall()
},
DisableFlagsInUseLine: true,
SilenceUsage: true,
}
return c
}
func newSvcStartCmd() *cobra.Command {
c := &cobra.Command{
Use: "start",
Short: "Start mosdns system service.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
if err := svc.Start(); err != nil {
return err
}
mlog.S().Info("service is starting")
time.Sleep(time.Second)
s, err := svc.Status()
if err != nil {
mlog.S().Warn("cannot get service status, %w", err)
} else {
switch s {
case service.StatusRunning:
mlog.S().Info("service is running")
case service.StatusStopped:
mlog.S().Error("service is stopped, check mosdns and system service log for more info")
default:
mlog.S().Warn("cannot get service status, system may not support this operation")
}
}
return nil
},
DisableFlagsInUseLine: true,
SilenceUsage: true,
}
return c
}
func newSvcStopCmd() *cobra.Command {
c := &cobra.Command{
Use: "stop",
Short: "Stop mosdns system service.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
return svc.Stop()
},
DisableFlagsInUseLine: true,
SilenceUsage: true,
}
return c
}
func newSvcRestartCmd() *cobra.Command {
c := &cobra.Command{
Use: "restart",
Short: "Restart mosdns system service.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
return svc.Restart()
},
DisableFlagsInUseLine: true,
SilenceUsage: true,
}
return c
}
func newSvcStatusCmd() *cobra.Command {
c := &cobra.Command{
Use: "status",
Short: "Status of mosdns system service.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
s, err := svc.Status()
if err != nil {
return fmt.Errorf("cannot get service status, %w", err)
}
var out string
switch s {
case service.StatusRunning:
out = "running"
case service.StatusStopped:
out = "stopped"
case service.StatusUnknown:
out = "unknown"
}
println(out)
return nil
},
DisableFlagsInUseLine: true,
SilenceUsage: true,
}
return c
}

View File

@ -0,0 +1,192 @@
# MosDNS + MikroTik Amazon 域名处理部署指南(更新版)
## 功能说明
这个配置会在解析 Amazon 相关域名时,自动将解析到的 IP 地址添加到 MikroTik 路由器的 address list 中,用于防火墙规则控制。
## 部署步骤
### 1. 上传文件到 Debian 12 服务器
```bash
# 上传编译好的 mosdns 可执行文件
scp mosdns-linux-amd64 user@your-server:/usr/local/bin/mosdns
# 上传配置文件
scp config.yaml user@your-server:/opt/mosdns/
scp dns.yaml user@your-server:/opt/mosdns/
# 设置执行权限
ssh user@your-server "chmod +x /usr/local/bin/mosdns"
```
### 2. 创建必要的目录和文件
```bash
# 创建配置目录
sudo mkdir -p /opt/mosdns/config
# 下载 Amazon 域名列表
sudo wget -O /opt/mosdns/config/geosite_amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon
sudo wget -O /opt/mosdns/config/geosite_amazon-ads.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon-ads
sudo wget -O /opt/mosdns/config/geosite_amazontrust.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazontrust
sudo wget -O /opt/mosdns/config/amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon
# 下载其他必要的域名和 IP 文件
sudo wget -O /opt/mosdns/config/geosite_tiktok.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/tiktok
sudo wget -O /opt/mosdns/config/gfwlist.out.txt https://raw.githubusercontent.com/gfwlist/gfwlist/master/gfwlist.txt
sudo wget -O /opt/mosdns/config/domains.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/category-games
sudo wget -O /opt/mosdns/config/cn.txt https://raw.githubusercontent.com/Loyalsoldier/v2ray-rules-dat/release/geoip.dat
```
### 3. 在 MikroTik 中创建 address list
```bash
# 通过 SSH 连接到 MikroTik 路由器
ssh admin@10.248.0.1
# 创建 IPv4 和 IPv6 address list
/ip firewall address-list add list=AmazonIP
/ip firewall address-list add list=AmazonIP6
# 创建防火墙规则(可选)
/ip firewall filter add chain=forward src-address-list=AmazonIP action=drop comment="Block Amazon IPs"
/ip firewall filter add chain=forward src-address-list=AmazonIP6 action=drop comment="Block Amazon IPv6 IPs"
```
### 4. 修改配置文件中的 MikroTik 连接信息
编辑 `/opt/mosdns/dns.yaml` 文件,确认 mikrotik_amazon 插件的配置:
```yaml
# 当前配置(根据你的实际情况修改)
args: "10.248.0.1:9728:admin:szn0s!nw@pwd():false:10:AmazonIP:AmazonIP6:24:32:AmazonIP:86400"
```
参数说明:
- `10.248.0.1`: MikroTik 路由器 IP
- `9728`: API 端口
- `admin`: 用户名
- `szn0s!nw@pwd()`: 密码
- `false`: 不使用 TLS
- `10`: 连接超时时间
- `AmazonIP`: IPv4 address list 名称
- `AmazonIP6`: IPv6 address list 名称
- `24`: IPv4 掩码
- `32`: IPv6 掩码
- `AmazonIP`: 注释
- `86400`: 地址超时时间24小时
### 5. 创建 systemd 服务
```bash
sudo tee /etc/systemd/system/mosdns.service > /dev/null <<EOF
[Unit]
Description=MosDNS DNS Server
After=network.target
[Service]
Type=simple
User=root
ExecStart=/usr/local/bin/mosdns -c /opt/mosdns/config.yaml
Restart=always
RestartSec=3
[Install]
WantedBy=multi-user.target
EOF
# 重新加载 systemd 配置
sudo systemctl daemon-reload
# 启用并启动服务
sudo systemctl enable mosdns
sudo systemctl start mosdns
# 检查服务状态
sudo systemctl status mosdns
```
### 6. 配置 DNS 转发
修改 `/etc/systemd/resolved.conf`
```ini
[Resolve]
DNS=127.0.0.1:5300
#FallbackDNS=8.8.8.8 8.8.4.4
#Domains=
#DNSSEC=no
#DNSOverTLS=no
#MulticastDNS=yes
#LLMNR=yes
#Cache=yes
#DNSStubListener=no
```
重启 systemd-resolved
```bash
sudo systemctl restart systemd-resolved
```
### 7. 测试配置
```bash
# 测试 Amazon 域名解析
nslookup amazon.com 127.0.0.1:5300
nslookup aws.amazon.com 127.0.0.1:5300
# 检查 MikroTik address list 是否更新
ssh admin@10.248.0.1 "/ip firewall address-list print where list=AmazonIP"
# 查看 mosdns 日志
sudo journalctl -u mosdns -f
```
## 配置说明
### 工作流程
1. **域名匹配**:当查询 Amazon 相关域名时,匹配 `amazon_domains` 集合
2. **DNS 解析**:使用国外 DNS 服务器解析域名
3. **IP 提取**:从 DNS 响应中提取 A 和 AAAA 记录
4. **地址添加**:通过 MikroTik API 将 IP 添加到 address list
5. **超时管理**IP 地址会在 24 小时后自动过期
### 监控和调试
```bash
# 查看实时日志
sudo journalctl -u mosdns -f
# 查看服务状态
sudo systemctl status mosdns
# 测试 MikroTik 连接
curl -k https://10.248.0.1:9729/api/rest/ip/firewall/address-list
# 查看 API 状态
curl http://localhost:5535/metrics
```
### 故障排除
1. **连接失败**:检查 MikroTik IP、端口和认证信息
2. **权限不足**:确保 MikroTik 用户具有管理 address list 的权限
3. **域名文件缺失**:确保所有域名列表文件都已下载
4. **DNS 解析失败**:检查上游 DNS 服务器配置
## 安全注意事项
1. 不要在配置文件中使用明文密码,考虑使用环境变量
2. 限制对 MikroTik API 端口的访问
3. 定期更新域名列表文件
4. 监控 address list 大小,避免过多条目影响性能
## 更新日志
- 修复了插件注册问题,现在支持 YAML 配置和快速配置
- 更新了路径配置为 `/opt/mosdns/`
- 更新了端口配置为 `5300`
- 更新了 API 端口为 `5535`

182
deploy-mikrotik-amazon.md Normal file
View File

@ -0,0 +1,182 @@
# MosDNS + MikroTik Amazon 域名处理部署指南
## 功能说明
这个配置会在解析 Amazon 相关域名时,自动将解析到的 IP 地址添加到 MikroTik 路由器的 address list 中,用于防火墙规则控制。
## 部署步骤
### 1. 上传文件到 Debian 12 服务器
```bash
# 上传编译好的 mosdns 可执行文件
scp mosdns-linux-amd64 user@your-server:/usr/local/bin/mosdns
# 上传配置文件
scp config.yaml user@your-server:/usr/local/mosdns/
scp dns.yaml user@your-server:/usr/local/mosdns/
# 设置执行权限
ssh user@your-server "chmod +x /usr/local/bin/mosdns"
```
### 2. 创建必要的目录和文件
```bash
# 创建配置目录
sudo mkdir -p /usr/local/mosdns/config
# 下载 Amazon 域名列表
sudo wget -O /usr/local/mosdns/config/geosite_amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon
sudo wget -O /usr/local/mosdns/config/geosite_amazon-ads.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon-ads
sudo wget -O /usr/local/mosdns/config/geosite_amazontrust.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazontrust
sudo wget -O /usr/local/mosdns/config/amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon
# 下载其他必要的域名和 IP 文件
sudo wget -O /usr/local/mosdns/config/geosite_tiktok.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/tiktok
sudo wget -O /usr/local/mosdns/config/gfwlist.out.txt https://raw.githubusercontent.com/gfwlist/gfwlist/master/gfwlist.txt
sudo wget -O /usr/local/mosdns/config/domains.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/category-games
sudo wget -O /usr/local/mosdns/config/cn.txt https://raw.githubusercontent.com/Loyalsoldier/v2ray-rules-dat/release/geoip.dat
```
### 3. 在 MikroTik 中创建 address list
```bash
# 通过 SSH 连接到 MikroTik 路由器
ssh admin@192.168.1.1
# 创建 IPv4 和 IPv6 address list
/ip firewall address-list add list=amazon_ips
/ip firewall address-list add list=amazon_ips6
# 创建防火墙规则(可选)
/ip firewall filter add chain=forward src-address-list=amazon_ips action=drop comment="Block Amazon IPs"
/ip firewall filter add chain=forward src-address-list=amazon_ips6 action=drop comment="Block Amazon IPv6 IPs"
```
### 4. 修改配置文件中的 MikroTik 连接信息
编辑 `/usr/local/mosdns/dns.yaml` 文件,修改 mikrotik_amazon 插件的配置:
```yaml
# 修改为你的 MikroTik 实际信息
args: "192.168.1.1:8728:admin:your-password:false:10:amazon_ips:amazon_ips6:24:32:amazon_domain:86400"
```
参数说明:
- `192.168.1.1`: MikroTik 路由器 IP
- `8728`: API 端口
- `admin`: 用户名
- `your-password`: 密码
- `false`: 不使用 TLS
- `10`: 连接超时时间
- `amazon_ips`: IPv4 address list 名称
- `amazon_ips6`: IPv6 address list 名称
- `24`: IPv4 掩码
- `32`: IPv6 掩码
- `amazon_domain`: 注释
- `86400`: 地址超时时间24小时
### 5. 创建 systemd 服务
```bash
sudo tee /etc/systemd/system/mosdns.service > /dev/null <<EOF
[Unit]
Description=MosDNS DNS Server
After=network.target
[Service]
Type=simple
User=root
ExecStart=/usr/local/bin/mosdns -c /usr/local/mosdns/config.yaml
Restart=always
RestartSec=3
[Install]
WantedBy=multi-user.target
EOF
# 重新加载 systemd 配置
sudo systemctl daemon-reload
# 启用并启动服务
sudo systemctl enable mosdns
sudo systemctl start mosdns
# 检查服务状态
sudo systemctl status mosdns
```
### 6. 配置 DNS 转发
修改 `/etc/systemd/resolved.conf`
```ini
[Resolve]
DNS=127.0.0.1
#FallbackDNS=8.8.8.8 8.8.4.4
#Domains=
#DNSSEC=no
#DNSOverTLS=no
#MulticastDNS=yes
#LLMNR=yes
#Cache=yes
#DNSStubListener=no
```
重启 systemd-resolved
```bash
sudo systemctl restart systemd-resolved
```
### 7. 测试配置
```bash
# 测试 Amazon 域名解析
nslookup amazon.com 127.0.0.1
nslookup aws.amazon.com 127.0.0.1
# 检查 MikroTik address list 是否更新
ssh admin@192.168.1.1 "/ip firewall address-list print where list=amazon_ips"
# 查看 mosdns 日志
sudo journalctl -u mosdns -f
```
## 配置说明
### 工作流程
1. **域名匹配**:当查询 Amazon 相关域名时,匹配 `amazon_domains` 集合
2. **DNS 解析**:使用国外 DNS 服务器解析域名
3. **IP 提取**:从 DNS 响应中提取 A 和 AAAA 记录
4. **地址添加**:通过 MikroTik API 将 IP 添加到 address list
5. **超时管理**IP 地址会在 24 小时后自动过期
### 监控和调试
```bash
# 查看实时日志
sudo journalctl -u mosdns -f
# 查看服务状态
sudo systemctl status mosdns
# 测试 MikroTik 连接
curl -k https://192.168.1.1:8729/api/rest/ip/firewall/address-list
```
### 故障排除
1. **连接失败**:检查 MikroTik IP、端口和认证信息
2. **权限不足**:确保 MikroTik 用户具有管理 address list 的权限
3. **域名文件缺失**:确保所有域名列表文件都已下载
4. **DNS 解析失败**:检查上游 DNS 服务器配置
## 安全注意事项
1. 不要在配置文件中使用明文密码,考虑使用环境变量
2. 限制对 MikroTik API 端口的访问
3. 定期更新域名列表文件
4. 监控 address list 大小,避免过多条目影响性能

55
dns-example-gfw.yaml Normal file
View File

@ -0,0 +1,55 @@
################ DNS Plugins #################
plugins:
- tag: mikrotik-one
type: forward
args:
concurrent: 1
upstreams:
- addr: "udp://10.248.0.1"
- tag: cn-dns
type: forward
args:
concurrent: 6
upstreams:
- addr: "udp://202.96.128.86"
- addr: "udp://202.96.128.166"
- addr: "udp://119.29.29.29"
- addr: "udp://223.5.5.5"
- addr: "udp://114.114.114.114"
- addr: "udp://180.76.76.76"
- tag: jp-dns
type: forward
args:
concurrent: 4
upstreams:
- addr: "tls://1dot1dot1dot1.cloudflare-dns.com"
dial_addr: "1.1.1.1"
enable_pipeline: true
- addr: "tls://1dot1dot1dot1.cloudflare-dns.com"
dial_addr: "1.0.0.1"
enable_pipeline: true
- addr: "tls://dns.google"
dial_addr: "8.8.8.8"
enable_pipeline: true
- addr: "tls://dns.google"
dial_addr: "8.8.4.4"
enable_pipeline: true
# MikroTik Address List 插件 - 处理 Amazon 相关域名
# 示例:将地址列表改为 gfw
- tag: mikrotik_amazon
type: mikrotik_addresslist
args:
host: "10.248.0.1"
port: 9728
username: "admin"
password: "szn0s!nw@pwd()"
use_tls: false
timeout: 10
address_list4: "gfw" # 改为 gfw插件会自动创建这个地址列表
mask4: 24
comment: "amazon_domain"
timeout_addr: 86400

57
dns.yaml Normal file
View File

@ -0,0 +1,57 @@
################ DNS Plugins #################
plugins:
- tag: mikrotik-one
type: forward
args:
concurrent: 1
upstreams:
- addr: "udp://10.248.0.1"
- tag: cn-dns
type: forward
args:
concurrent: 6
upstreams:
- addr: "udp://202.96.128.86"
- addr: "udp://202.96.128.166"
- addr: "udp://119.29.29.29"
- addr: "udp://223.5.5.5"
- addr: "udp://114.114.114.114"
- addr: "udp://180.76.76.76"
- tag: jp-dns
type: forward
args:
concurrent: 4 # 同步向 3 条上游并发查询
upstreams:
- addr: "tls://1dot1dot1dot1.cloudflare-dns.com"
dial_addr: "1.1.1.1"
enable_pipeline: true
- addr: "tls://1dot1dot1dot1.cloudflare-dns.com"
dial_addr: "1.0.0.1"
enable_pipeline: true
- addr: "tls://dns.google"
dial_addr: "8.8.8.8"
enable_pipeline: true
- addr: "tls://dns.google"
dial_addr: "8.8.4.4"
enable_pipeline: true
# MikroTik Address List 插件 - 处理 Amazon 相关域名
# 示例:将地址列表改为 gfw
- tag: mikrotik_amazon
type: mikrotik_addresslist
args:
host: "10.248.0.1"
port: 9728
username: "admin"
password: "szn0s!nw@pwd()"
use_tls: false
timeout: 10
address_list4: "gfw" # 改为 gfw插件会自动创建这个地址列表
mask4: 24
comment: "amazon_domain"
timeout_addr: 86400

71
go.mod Normal file
View File

@ -0,0 +1,71 @@
module github.com/IrineSistiana/mosdns/v5
go 1.22
require (
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230918115058-c72bd9761c57
github.com/go-chi/chi/v5 v5.1.0
github.com/go-routeros/routeros/v3 v3.0.1
github.com/google/nftables v0.2.0
github.com/kardianos/service v1.2.2
github.com/klauspost/compress v1.17.9
github.com/miekg/dns v1.1.62
github.com/mitchellh/mapstructure v1.5.0
github.com/nadoo/ipset v0.5.0
github.com/prometheus/client_golang v1.19.1
github.com/quic-go/quic-go v0.46.0
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.2.1-beta.2.0.20221107222636-d3c0a2caa559
go.uber.org/zap v1.27.0
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
golang.org/x/net v0.28.0
golang.org/x/sync v0.8.0
golang.org/x/sys v0.24.0
golang.org/x/time v0.6.0
google.golang.org/protobuf v1.34.2
)
replace github.com/nadoo/ipset v0.5.0 => github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo/v2 v2.20.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.55.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/quic-go/qpack v0.4.0 // indirect
github.com/sagikazarmark/locafero v0.6.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/mod v0.20.0 // indirect
golang.org/x/text v0.17.0 // indirect
golang.org/x/tools v0.24.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

154
go.sum Normal file
View File

@ -0,0 +1,154 @@
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230918115058-c72bd9761c57 h1:nfurUSSmVY9sY/mYyoReOA1w2cR2fp2eicL9ojicZhQ=
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230918115058-c72bd9761c57/go.mod h1:pQ/FSsWSNYmNdgIKmulKlmVC/R2PEpq2vIEi3J9IijI=
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a h1:GQdh/h0q0ni3L//CXusyk+7QdhBL289vdNaes1WKkHI=
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-routeros/routeros/v3 v3.0.1 h1:FdNKlF6Hst8nkHr0dIvD54pQ+dZ8sHOJfQSVRKz0BFg=
github.com/go-routeros/routeros/v3 v3.0.1/go.mod h1:j4mq65czXfKtHsdLkgVv8w7sNzyhLZy1TKi2zQDMpiQ=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k=
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/onsi/ginkgo/v2 v2.20.0 h1:PE84V2mHqoT1sglvHc8ZdQtPcwmvvt29WLEEO3xmdZw=
github.com/onsi/ginkgo/v2 v2.20.0/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/quic-go v0.46.0 h1:uuwLClEEyk1DNvchH8uCByQVjo3yKL9opKulExNDs7Y=
github.com/quic-go/quic-go v0.46.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3N51bwOk=
github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/vishvananda/netlink v1.2.1-beta.2.0.20221107222636-d3c0a2caa559 h1:NwQroOyW+fpfiUroBzAMqFc6NRwBmvJevoVtEK6gsFE=
github.com/vishvananda/netlink v1.2.1-beta.2.0.20221107222636-d3c0a2caa559/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI=
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220804214406-8e32c043e418/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

50
main.go Normal file
View File

@ -0,0 +1,50 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package main
import (
"fmt"
"github.com/IrineSistiana/mosdns/v5/coremain"
"github.com/IrineSistiana/mosdns/v5/mlog"
_ "github.com/IrineSistiana/mosdns/v5/plugin"
_ "github.com/IrineSistiana/mosdns/v5/tools"
"github.com/spf13/cobra"
_ "net/http/pprof"
)
var (
version = "dev/unknown"
)
func init() {
coremain.AddSubCmd(&cobra.Command{
Use: "version",
Short: "Print out version info and exit.",
Run: func(cmd *cobra.Command, args []string) {
fmt.Println(version)
},
})
}
func main() {
if err := coremain.Run(); err != nil {
mlog.S().Fatal(err)
}
}

91
mlog/logger.go Normal file
View File

@ -0,0 +1,91 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package mlog
import (
"fmt"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"os"
)
type LogConfig struct {
// Level, See also zapcore.ParseLevel.
Level string `yaml:"level"`
// File that logger will be writen into.
// Default is stderr.
File string `yaml:"file"`
// Production enables json output.
Production bool `yaml:"production"`
}
var (
stderr = zapcore.Lock(os.Stderr)
lvl = zap.NewAtomicLevelAt(zap.InfoLevel)
l = zap.New(zapcore.NewCore(zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), stderr, lvl))
s = l.Sugar()
nop = zap.NewNop()
)
func NewLogger(lc LogConfig) (*zap.Logger, error) {
lvl, err := zapcore.ParseLevel(lc.Level)
if err != nil {
return nil, fmt.Errorf("invalid log level: %w", err)
}
var out zapcore.WriteSyncer
if lf := lc.File; len(lf) > 0 {
f, _, err := zap.Open(lf)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
out = zapcore.Lock(f)
} else {
out = stderr
}
if lc.Production {
return zap.New(zapcore.NewCore(zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()), out, lvl)), nil
}
return zap.New(zapcore.NewCore(zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), out, lvl)), nil
}
// L is a global logger.
func L() *zap.Logger {
return l
}
// SetLevel sets the log level for the global logger.
func SetLevel(l zapcore.Level) {
lvl.SetLevel(l)
}
// S is a global logger.
func S() *zap.SugaredLogger {
return s
}
// Nop is a logger that never writes out logs.
func Nop() *zap.Logger {
return nop
}

BIN
mosdns-linux-amd64 Normal file

Binary file not shown.

157
pkg/cache/cache.go vendored Normal file
View File

@ -0,0 +1,157 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package cache
import (
"github.com/IrineSistiana/mosdns/v5/pkg/concurrent_lru"
"github.com/IrineSistiana/mosdns/v5/pkg/concurrent_map"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"sync/atomic"
"time"
)
const (
defaultCleanerInterval = time.Second * 10
)
type Key interface {
concurrent_lru.Hashable
}
type Value interface {
any
}
// Cache is a simple map cache that stores values in memory.
// It is safe for concurrent use.
type Cache[K Key, V Value] struct {
opts Opts
closed atomic.Bool
closeNotify chan struct{}
m *concurrent_map.Map[K, *elem[V]]
}
type Opts struct {
Size int
CleanerInterval time.Duration
}
func (opts *Opts) init() {
utils.SetDefaultNum(&opts.Size, 1024)
utils.SetDefaultNum(&opts.CleanerInterval, defaultCleanerInterval)
}
type elem[V Value] struct {
v V
expirationTime time.Time
}
// New initializes a Cache.
// The minimum size is 1024.
// cleanerInterval specifies the interval that Cache scans
// and discards expired values. If cleanerInterval <= 0, a default
// interval will be used.
func New[K Key, V Value](opts Opts) *Cache[K, V] {
opts.init()
c := &Cache[K, V]{
closeNotify: make(chan struct{}),
m: concurrent_map.NewMapCache[K, *elem[V]](opts.Size),
}
go c.gcLoop(opts.CleanerInterval)
return c
}
// Close closes the inner cleaner of this cache.
func (c *Cache[K, V]) Close() error {
if ok := c.closed.CompareAndSwap(false, true); ok {
close(c.closeNotify)
}
return nil
}
func (c *Cache[K, V]) Get(key K) (v V, expirationTime time.Time, ok bool) {
if e, hasEntry := c.m.Get(key); hasEntry {
if e.expirationTime.Before(time.Now()) {
c.m.Del(key)
return
}
return e.v, e.expirationTime, true
}
return
}
// Range calls f through all entries. If f returns an error, the same error will be returned
// by Range.
func (c *Cache[K, V]) Range(f func(key K, v V, expirationTime time.Time) error) error {
cf := func(key K, v *elem[V]) (newV *elem[V], setV bool, delV bool, err error) {
return nil, false, false, f(key, v.v, v.expirationTime)
}
return c.m.RangeDo(cf)
}
// Store stores this kv in cache. If expirationTime is before time.Now(),
// Store is an noop.
func (c *Cache[K, V]) Store(key K, v V, expirationTime time.Time) {
now := time.Now()
if now.After(expirationTime) {
return
}
e := &elem[V]{
v: v,
expirationTime: expirationTime,
}
c.m.Set(key, e)
return
}
func (c *Cache[K, V]) gcLoop(interval time.Duration) {
if interval <= 0 {
interval = defaultCleanerInterval
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-c.closeNotify:
return
case now := <-ticker.C:
c.gc(now)
}
}
}
func (c *Cache[K, V]) gc(now time.Time) {
f := func(key K, v *elem[V]) (newV *elem[V], setV, delV bool, err error) {
return nil, false, now.After(v.expirationTime), nil
}
_ = c.m.RangeDo(f)
}
// Len returns the current size of this cache.
func (c *Cache[K, V]) Len() int {
return c.m.Len()
}
// Flush removes all stored entries from this cache.
func (c *Cache[K, V]) Flush() {
c.m.Flush()
}

98
pkg/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,98 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package cache
import (
"sync"
"testing"
"time"
)
type testKey int
func (t testKey) Sum() uint64 {
return uint64(t)
}
func Test_Cache(t *testing.T) {
c := New[testKey, int](Opts{
Size: 1024,
})
for i := 0; i < 128; i++ {
key := testKey(i)
c.Store(key, i, time.Now().Add(time.Millisecond*200))
v, _, ok := c.Get(key)
if v != i {
t.Fatal("cache kv mismatched")
}
if !ok {
t.Fatal()
}
}
for i := 0; i < 1024*4; i++ {
key := testKey(i)
c.Store(key, i, time.Now().Add(time.Millisecond*200))
}
if l := c.Len(); l > 1024 {
t.Fatal("cache overflow")
}
}
func Test_memCache_cleaner(t *testing.T) {
c := New[testKey, int](Opts{
Size: 1024,
CleanerInterval: time.Millisecond * 10,
})
defer c.Close()
for i := 0; i < 64; i++ {
key := testKey(i)
c.Store(key, i, time.Now().Add(time.Millisecond*10))
}
time.Sleep(time.Millisecond * 100)
if c.Len() != 0 {
t.Fatal()
}
}
func Test_memCache_race(t *testing.T) {
c := New[testKey, int](Opts{
Size: 1024,
})
defer c.Close()
wg := sync.WaitGroup{}
for i := 0; i < 32; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 256; i++ {
key := testKey(i)
c.Store(key, i, time.Now().Add(time.Minute))
_, _, _ = c.Get(key)
c.gc(time.Now())
}
}()
}
wg.Wait()
}

View File

@ -0,0 +1,149 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package concurrent_lru
import (
"github.com/IrineSistiana/mosdns/v5/pkg/lru"
"sync"
)
type Hashable interface {
comparable
Sum() uint64
}
type ShardedLRU[K Hashable, V any] struct {
l []*ConcurrentLRU[K, V]
}
func NewShardedLRU[K Hashable, V any](
shardNum, maxSizePerShard int,
onEvict func(key K, v V),
) *ShardedLRU[K, V] {
cl := &ShardedLRU[K, V]{
l: make([]*ConcurrentLRU[K, V], 0, shardNum),
}
for i := 0; i < shardNum; i++ {
cl.l = append(cl.l, NewConecurrentLRU[K, V](maxSizePerShard, onEvict))
}
return cl
}
func (c *ShardedLRU[K, V]) Add(key K, v V) {
sl := c.getShard(key)
sl.Add(key, v)
}
func (c *ShardedLRU[K, V]) Del(key K) {
sl := c.getShard(key)
sl.Del(key)
}
func (c *ShardedLRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) {
for _, l := range c.l {
removed += l.Clean(f)
}
return removed
}
func (c *ShardedLRU[K, V]) Flush() {
for _, l := range c.l {
l.Flush()
}
}
func (c *ShardedLRU[K, V]) Get(key K) (v V, ok bool) {
sl := c.getShard(key)
v, ok = sl.Get(key)
return
}
func (c *ShardedLRU[K, V]) Len() int {
sum := 0
for _, l := range c.l {
sum += l.Len()
}
return sum
}
func (c *ShardedLRU[K, V]) shardNum() int {
return len(c.l)
}
func (c *ShardedLRU[K, V]) getShard(key K) *ConcurrentLRU[K, V] {
return c.l[key.Sum()%uint64(c.shardNum())]
}
// ConcurrentLRU is a lru.LRU with a lock.
// It is concurrent safe.
type ConcurrentLRU[K comparable, V any] struct {
sync.Mutex
lru *lru.LRU[K, V]
}
func NewConecurrentLRU[K comparable, V any](maxSize int, onEvict func(key K, v V)) *ConcurrentLRU[K, V] {
return &ConcurrentLRU[K, V]{
lru: lru.NewLRU[K, V](maxSize, onEvict),
}
}
func (c *ConcurrentLRU[K, V]) Add(key K, v V) {
c.Lock()
defer c.Unlock()
c.lru.Add(key, v)
}
func (c *ConcurrentLRU[K, V]) Del(key K) {
c.Lock()
defer c.Unlock()
c.lru.Del(key)
}
func (c *ConcurrentLRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) {
c.Lock()
defer c.Unlock()
return c.lru.Clean(f)
}
func (c *ConcurrentLRU[K, V]) Flush() {
c.Lock()
defer c.Unlock()
c.lru.Flush()
}
func (c *ConcurrentLRU[K, V]) Get(key K) (v V, ok bool) {
c.Lock()
defer c.Unlock()
v, ok = c.lru.Get(key)
return
}
func (c *ConcurrentLRU[K, V]) Len() int {
c.Lock()
defer c.Unlock()
return c.lru.Len()
}

View File

@ -0,0 +1,111 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package concurrent_lru
import (
"reflect"
"testing"
)
type testKey int
func (k testKey) Sum() uint64 {
return uint64(k)
}
func TestConcurrentLRU(t *testing.T) {
onEvict := func(k testKey, v int) {}
var cache *ShardedLRU[testKey, int]
reset := func(shardNum, maxShardSize int) {
cache = NewShardedLRU[testKey, int](shardNum, maxShardSize, onEvict)
}
add := func(keys ...int) {
for _, k := range keys {
cache.Add(testKey(k), k)
}
}
mustGet := func(keys ...int) {
for _, k := range keys {
gotV, ok := cache.Get(testKey(k))
if !ok || !reflect.DeepEqual(gotV, k) {
t.Fatalf("want %v, got %v", k, gotV)
}
}
}
emptyGet := func(keys ...int) {
for _, k := range keys {
gotV, ok := cache.Get(testKey(k))
if ok {
t.Fatalf("want empty, got %v", gotV)
}
}
}
checkLen := func(want int) {
if want != cache.Len() {
t.Fatalf("want %v, got %v", want, cache.Len())
}
}
// test add
reset(4, 16)
add(1, 1, 1, 1, 2, 2, 3, 3, 4)
checkLen(4)
mustGet(1, 2, 3, 4)
emptyGet(5, 6, 7, 9999)
// test add overflow
reset(4, 16) // max size is 64
for i := 0; i < 1024; i++ {
add(i)
}
if cache.Len() > 64 {
t.Fatalf("lru overflowed: want len = %d, got = %d", 64, cache.Len())
}
// test del
reset(4, 16)
add(1, 2, 3, 4)
cache.Del(2)
cache.Del(4)
cache.Del(9999)
mustGet(1, 3)
emptyGet(2, 4)
// test clean
reset(4, 16)
add(1, 2, 3, 4)
cleanFunc := func(k testKey, v int) (remove bool) {
switch k {
case 1, 3:
return true
}
return false
}
if cleaned := cache.Clean(cleanFunc); cleaned != 2 {
t.Fatalf("q.Clean want cleaned = 2, got %v", cleaned)
}
mustGet(2, 4)
emptyGet(1, 3)
}

186
pkg/concurrent_map/map.go Normal file
View File

@ -0,0 +1,186 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package concurrent_map
import (
"sync"
)
const (
MapShardSize = 64
)
type Hashable interface {
comparable
Sum() uint64
}
type TestAndSetFunc[K comparable, V any] func(key K, v V, ok bool) (newV V, setV, deleteV bool)
type Map[K Hashable, V any] struct {
shards [MapShardSize]shard[K, V]
}
func NewMap[K Hashable, V any]() *Map[K, V] {
m := new(Map[K, V])
for i := range m.shards {
m.shards[i] = newShard[K, V](0)
}
return m
}
// NewMapCache returns a cache with a maximum size.
// Note that, because this it has multiple (MapShardSize) shards,
// the actual maximum size is MapShardSize*(size / MapShardSize).
// If size <=0, it's equal to NewMap().
func NewMapCache[K Hashable, V any](size int) *Map[K, V] {
sizePreShard := size / MapShardSize
m := new(Map[K, V])
for i := range m.shards {
m.shards[i] = newShard[K, V](sizePreShard)
}
return m
}
func (m *Map[K, V]) getShard(key K) *shard[K, V] {
return &m.shards[key.Sum()%MapShardSize]
}
func (m *Map[K, V]) Get(key K) (V, bool) {
return m.getShard(key).get(key)
}
func (m *Map[K, V]) Set(key K, v V) {
m.getShard(key).set(key, v)
}
func (m *Map[K, V]) Del(key K) {
m.getShard(key).del(key)
}
func (m *Map[K, V]) TestAndSet(key K, f func(v V, ok bool) (newV V, setV, delV bool)) {
m.getShard(key).testAndSet(key, f)
}
func (m *Map[K, V]) RangeDo(f func(k K, v V) (newV V, setV, delV bool, err error)) error {
for i := range m.shards {
if err := m.shards[i].rangeDo(f); err != nil {
return err
}
}
return nil
}
func (m *Map[K, V]) Len() int {
l := 0
for i := range m.shards {
l += m.shards[i].len()
}
return l
}
func (m *Map[K, V]) Flush() {
for i := range m.shards {
m.shards[i].flush()
}
}
type shard[K comparable, V any] struct {
l sync.RWMutex
max int // Negative or zero max means no limit.
m map[K]V
}
func newShard[K comparable, V any](max int) shard[K, V] {
return shard[K, V]{
max: max,
m: make(map[K]V),
}
}
func (m *shard[K, V]) get(key K) (V, bool) {
m.l.RLock()
defer m.l.RUnlock()
v, ok := m.m[key]
return v, ok
}
func (m *shard[K, V]) set(key K, v V) {
m.l.Lock()
defer m.l.Unlock()
if m.max > 0 && len(m.m)+1 > m.max {
for k := range m.m {
delete(m.m, k)
if len(m.m)+1 <= m.max {
break
}
}
}
m.m[key] = v
}
func (m *shard[K, V]) del(key K) {
m.l.Lock()
defer m.l.Unlock()
delete(m.m, key)
}
func (m *shard[K, V]) testAndSet(key K, f func(v V, ok bool) (newV V, setV, delV bool)) {
m.l.Lock()
defer m.l.Unlock()
v, ok := m.m[key]
newV, setV, deleteV := f(v, ok)
switch {
case setV:
m.m[key] = newV
case deleteV && ok:
delete(m.m, key)
}
}
func (m *shard[K, V]) len() int {
m.l.RLock()
defer m.l.RUnlock()
return len(m.m)
}
func (m *shard[K, V]) flush() {
m.l.RLock()
defer m.l.RUnlock()
m.m = make(map[K]V)
}
func (m *shard[K, V]) rangeDo(f func(k K, v V) (newV V, setV, delV bool, err error)) error {
m.l.Lock()
defer m.l.Unlock()
for k, v := range m.m {
newV, setV, deleteV, err := f(k, v)
if err != nil {
return err
}
switch {
case setV:
m.m[k] = newV
case deleteV:
delete(m.m, k)
}
}
return nil
}

View File

@ -0,0 +1,190 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package concurrent_map
import (
"errors"
"sync"
"testing"
)
type testMapHashable uint64
func (h testMapHashable) Sum() uint64 {
return uint64(h)
}
func Test_Map(t *testing.T) {
m := NewMap[testMapHashable, int]()
wg := sync.WaitGroup{}
// test add
for i := 0; i < 512; i++ {
i := i
wg.Add(1)
go func() {
defer wg.Done()
m.Set(testMapHashable(i), i)
}()
}
wg.Wait()
// test range
wantErr := errors.New("")
f := func(key testMapHashable, v int) (newV int, setV bool, deleteV bool, err error) {
return 0, false, false, wantErr
}
if wantErr != m.RangeDo(f) {
t.Fatal("range should return a error")
}
cc := make([]bool, 512)
f = func(key testMapHashable, v int) (newV int, setV bool, deleteV bool, err error) {
cc[key] = true
return 0, false, false, nil
}
_ = m.RangeDo(f)
for _, ok := range cc {
if !ok {
t.Fatal("test or range failed")
}
}
// test get
for i := 0; i < 512; i++ {
i := i
wg.Add(1)
go func() {
defer wg.Done()
v, ok := m.Get(testMapHashable(i))
if !ok {
t.Error()
return
}
if v != i {
t.Error()
return
}
}()
}
wg.Wait()
// test len
if m.Len() != 512 {
t.Fatal()
}
// test del
for i := 0; i < 512; i++ {
i := i
wg.Add(1)
go func() {
defer wg.Done()
m.Del(testMapHashable(i))
}()
}
wg.Wait()
if m.Len() != 0 {
t.Fatal()
}
}
func TestConcurrentMap_TestAndSet(t *testing.T) {
cm := NewMap[testMapHashable, int]()
wg := sync.WaitGroup{}
f := func(v int, ok bool) (newV int, setV bool, deleteV bool) {
return 1, true, false
}
for i := 0; i < 512; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cm.TestAndSet(1, f)
}()
}
wg.Wait()
v, _ := cm.Get(1)
if v != 1 {
t.Fatal()
}
// test delete
f = func(v int, ok bool) (newV int, setV bool, deleteV bool) {
return 1, false, true
}
cm.TestAndSet(1, f)
_, ok := cm.Get(1)
if ok {
t.Fatal()
}
}
func BenchmarkConcurrentMap_Get_And_Set(b *testing.B) {
keys := make([]testMapHashable, 2048)
m := NewMap[testMapHashable, int]()
for i := 0; i < 2048; i++ {
key := testMapHashable(i)
keys[i] = key
m.Set(key, i)
}
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
i++
key := keys[i%2048]
m.Set(key, i)
m.Get(key)
}
})
}
func Benchmark_RWMutexMap_Get_And_Set(b *testing.B) {
keys := make([]int, 2048)
rwm := new(sync.RWMutex)
m := make(map[int]int, 2048)
for i := 0; i < 2048; i++ {
keys[i] = i
m[i] = i
}
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
i++
key := keys[i%2048]
rwm.Lock()
m[key] = i
rwm.Unlock()
rwm.RLock()
_ = m[key]
rwm.RUnlock()
}
})
}

160
pkg/dnsutils/msg.go Normal file
View File

@ -0,0 +1,160 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package dnsutils
import (
"github.com/miekg/dns"
"strconv"
)
// GetMinimalTTL returns the minimal ttl of this msg.
// If msg m has no record, it returns 0.
func GetMinimalTTL(m *dns.Msg) uint32 {
minTTL := ^uint32(0)
hasRecord := false
for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
for _, rr := range section {
hdr := rr.Header()
if hdr.Rrtype == dns.TypeOPT {
continue // opt record ttl is not ttl.
}
hasRecord = true
ttl := hdr.Ttl
if ttl < minTTL {
minTTL = ttl
}
}
}
if !hasRecord { // no ttl applied
return 0
}
return minTTL
}
// SetTTL updates all records' ttl to ttl, except opt record.
func SetTTL(m *dns.Msg, ttl uint32) {
for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
for _, rr := range section {
hdr := rr.Header()
if hdr.Rrtype == dns.TypeOPT {
continue // opt record ttl is not ttl.
}
hdr.Ttl = ttl
}
}
}
func ApplyMaximumTTL(m *dns.Msg, ttl uint32) {
applyTTL(m, ttl, true)
}
func ApplyMinimalTTL(m *dns.Msg, ttl uint32) {
applyTTL(m, ttl, false)
}
// SubtractTTL subtract delta from every m's RR.
// If RR's TTL is smaller than delta, SubtractTTL
// will return overflowed = true.
func SubtractTTL(m *dns.Msg, delta uint32) (overflowed bool) {
for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
for _, rr := range section {
hdr := rr.Header()
if hdr.Rrtype == dns.TypeOPT {
continue // opt record ttl is not ttl.
}
if ttl := hdr.Ttl; ttl > delta {
hdr.Ttl = ttl - delta
} else {
hdr.Ttl = 1
overflowed = true
}
}
}
return
}
func applyTTL(m *dns.Msg, ttl uint32, maximum bool) {
for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
for _, rr := range section {
hdr := rr.Header()
if hdr.Rrtype == dns.TypeOPT {
continue // opt record ttl is not ttl.
}
if maximum {
if hdr.Ttl > ttl {
hdr.Ttl = ttl
}
} else {
if hdr.Ttl < ttl {
hdr.Ttl = ttl
}
}
}
}
}
func uint16Conv(u uint16, m map[uint16]string) string {
if s, ok := m[u]; ok {
return s
}
return strconv.Itoa(int(u))
}
func QclassToString(u uint16) string {
return uint16Conv(u, dns.ClassToString)
}
func QtypeToString(u uint16) string {
return uint16Conv(u, dns.TypeToString)
}
func GenEmptyReply(q *dns.Msg, rcode int) *dns.Msg {
r := new(dns.Msg)
r.SetRcode(q, rcode)
var name string
if len(q.Question) > 1 {
name = q.Question[0].Name
} else {
name = "."
}
r.Ns = []dns.RR{FakeSOA(name)}
return r
}
func FakeSOA(name string) *dns.SOA {
return &dns.SOA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 300,
},
Ns: "fake-ns.mosdns.fake.root.",
Mbox: "fake-mbox.mosdns.fake.root.",
Serial: 2021110400,
Refresh: 1800,
Retry: 900,
Expire: 604800,
Minttl: 86400,
}
}

138
pkg/dnsutils/net_io.go Normal file
View File

@ -0,0 +1,138 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package dnsutils
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
)
const (
DnsHeaderLen = 12 // minimum dns msg size
)
var (
ErrPayloadTooSmall = errors.New("payload is to small for a valid dns msg")
)
// ReadRawMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
// with a two byte length field).
// n represents how many bytes are read from c.
// The returned the *[]byte should be released by pool.ReleaseBuf.
func ReadRawMsgFromTCP(c io.Reader) (*[]byte, error) {
h := pool.GetBuf(2)
defer pool.ReleaseBuf(h)
_, err := io.ReadFull(c, *h)
if err != nil {
return nil, err
}
// dns length
length := binary.BigEndian.Uint16(*h)
if length <= DnsHeaderLen {
return nil, ErrPayloadTooSmall
}
b := pool.GetBuf(int(length))
_, err = io.ReadFull(c, *b)
if err != nil {
pool.ReleaseBuf(b)
return nil, err
}
return b, nil
}
// ReadMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
// with a two byte length field).
// n represents how many bytes are read from c.
func ReadMsgFromTCP(c io.Reader) (*dns.Msg, int, error) {
b, err := ReadRawMsgFromTCP(c)
if err != nil {
return nil, 0, err
}
defer pool.ReleaseBuf(b)
m, err := unpackMsgWithDetailedErr(*b)
return m, len(*b) + 2, err
}
// WriteMsgToTCP packs and writes m to c in RFC 1035 format.
// n represents how many bytes are written to c.
func WriteMsgToTCP(c io.Writer, m *dns.Msg) (n int, err error) {
buf, err := pool.PackTCPBuffer(m)
if err != nil {
return 0, err
}
defer pool.ReleaseBuf(buf)
return c.Write(*buf)
}
// WriteRawMsgToTCP See WriteMsgToTCP
func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) {
if len(b) > dns.MaxMsgSize {
return 0, fmt.Errorf("payload length %d is greater than dns max msg size", len(b))
}
buf := pool.GetBuf(len(b) + 2)
defer pool.ReleaseBuf(buf)
binary.BigEndian.PutUint16((*buf)[:2], uint16(len(b)))
copy((*buf)[2:], b)
return c.Write((*buf))
}
func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) {
b, err := pool.PackBuffer(m)
if err != nil {
return 0, err
}
defer pool.ReleaseBuf(b)
return c.Write(*b)
}
func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) {
if bufSize < dns.MinMsgSize {
bufSize = dns.MinMsgSize
}
b := pool.GetBuf(bufSize)
defer pool.ReleaseBuf(b)
n, err := c.Read(*b)
if err != nil {
return nil, n, err
}
m, err := unpackMsgWithDetailedErr((*b)[:n])
return m, n, err
}
func unpackMsgWithDetailedErr(b []byte) (*dns.Msg, error) {
m := new(dns.Msg)
if err := m.Unpack(b); err != nil {
return nil, fmt.Errorf("failed to unpack msg [%x], %w", b, err)
}
return m, nil
}

125
pkg/dnsutils/ptr_parser.go Normal file
View File

@ -0,0 +1,125 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package dnsutils
import (
"errors"
"fmt"
"net/netip"
"strconv"
"strings"
)
var errNotPTRDomain = errors.New("domain does not has a ptr suffix")
const (
IP4arpa = ".in-addr.arpa."
IP6arpa = ".ip6.arpa."
)
// ParsePTRQName returns the ip that a PTR query name contains.
func ParsePTRQName(fqdn string) (netip.Addr, error) {
switch {
case strings.HasSuffix(fqdn, IP4arpa):
return reverse4(fqdn[:len(fqdn)-len(IP4arpa)])
case strings.HasSuffix(fqdn, IP6arpa):
return reverse6(fqdn[:len(fqdn)-len(IP6arpa)])
default:
return netip.Addr{}, errNotPTRDomain
}
}
func reverse4(s string) (netip.Addr, error) {
var buf [4]byte
l := 0
for offset := len(s); offset > 0 && l < len(buf); l++ {
var label string
label, offset = prevLabel(s, offset)
n, err := strconv.ParseUint(label, 10, 8)
if err != nil {
return netip.Addr{}, fmt.Errorf("invaild bit, %w", err)
}
buf[l] = byte(n)
}
if l < len(buf) {
return netip.Addr{}, fmt.Errorf("expact at least 3 labels, got %d", l)
}
return netip.AddrFrom4(buf), nil
}
func reverse6(s string) (netip.Addr, error) {
var buf [16]byte
var val byte
var tail bool
var l int
for offset := len(s); offset > 0 && l < len(buf); {
var label string
label, offset = prevLabel(s, offset)
if len(label) != 1 {
return netip.Addr{}, fmt.Errorf("invalid label %s", label)
}
b := label[0]
n, ok := hex2byte(b)
if !ok {
return netip.Addr{}, fmt.Errorf("invaild bit %d", b)
}
if tail {
buf[l] = val<<4 + n
l++
tail = false
} else {
val = n
tail = true
}
}
if l < len(buf) {
return netip.Addr{}, fmt.Errorf("expact at least 16 bytes, got %d", l)
}
return netip.AddrFrom16(buf), nil
}
func hex2byte(c byte) (byte, bool) {
lower := func(c byte) byte {
return c | ('x' - 'X')
}
var b byte
switch {
case '0' <= c && c <= '9':
b = c - '0'
case 'a' <= lower(c) && lower(c) <= 'z':
b = lower(c) - 'a' + 10
default:
return 0, false
}
return b, true
}
func prevLabel(s string, offset int) (string, int) {
for {
s = s[:offset]
n := strings.LastIndexByte(s, '.')
label := s[n+1 : offset]
if n != -1 && len(label) == 0 {
offset = n
continue
}
return label, n
}
}

View File

@ -0,0 +1,81 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package dnsutils
import (
"net/netip"
"reflect"
"testing"
)
func Test_reverse4(t *testing.T) {
tests := []struct {
name string
s string
want netip.Addr
wantErr bool
}{
{"v4", "4.4.8.8", netip.MustParseAddr("8.8.4.4"), false},
{"v4_with_prefix", "prefix.4.4.8.8", netip.MustParseAddr("8.8.4.4"), false},
{"invalid_format", "123114123", netip.Addr{}, true},
{"invalid_format", "12..311..4123..", netip.Addr{}, true},
{"invalid_format", "...", netip.Addr{}, true},
{"short_length", "4.8.8", netip.Addr{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := reverse4(tt.s)
if (err != nil) != tt.wantErr {
t.Errorf("reverse4() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("reverse4() got = %v, want %v", got, tt.want)
}
})
}
}
func Test_reverse6(t *testing.T) {
tests := []struct {
name string
s string
want netip.Addr
wantErr bool
}{
{"v6", "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.MustParseAddr("2001:db8::567:89ab"), false},
{"v6_with_prefix", "prefix.b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.MustParseAddr("2001:db8::567:89ab"), false},
{"invalid_format", "123114123", netip.Addr{}, true},
{"invalid_format", "..123...", netip.Addr{}, true},
{"short_length", "0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.Addr{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := reverse6(tt.s)
if (err != nil) != tt.wantErr {
t.Errorf("reverse6() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("reverse4() got = %v, want %v", got, tt.want)
}
})
}
}

134
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,134 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package hosts
import (
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain"
"github.com/miekg/dns"
"net/netip"
"strings"
)
type Hosts struct {
matcher domain.Matcher[*IPs]
}
// NewHosts creates a hosts using m.
func NewHosts(m domain.Matcher[*IPs]) *Hosts {
return &Hosts{
matcher: m,
}
}
func (h *Hosts) Lookup(fqdn string) (ipv4, ipv6 []netip.Addr) {
ips, ok := h.matcher.Match(fqdn)
if !ok {
return nil, nil // no such host
}
return ips.IPv4, ips.IPv6
}
func (h *Hosts) LookupMsg(m *dns.Msg) *dns.Msg {
if len(m.Question) != 1 {
return nil
}
q := m.Question[0]
typ := q.Qtype
fqdn := q.Name
if q.Qclass != dns.ClassINET || (typ != dns.TypeA && typ != dns.TypeAAAA) {
return nil
}
ipv4, ipv6 := h.Lookup(fqdn)
if len(ipv4)+len(ipv6) == 0 {
return nil // no such host
}
r := new(dns.Msg)
r.SetReply(m)
switch {
case typ == dns.TypeA && len(ipv4) > 0:
for _, ip := range ipv4 {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: fqdn,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 10,
},
A: ip.AsSlice(),
}
r.Answer = append(r.Answer, rr)
}
case typ == dns.TypeAAAA && len(ipv6) > 0:
for _, ip := range ipv6 {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: fqdn,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 10,
},
AAAA: ip.AsSlice(),
}
r.Answer = append(r.Answer, rr)
}
}
// Append fake SOA record for empty reply.
if len(r.Answer) == 0 {
r.Ns = []dns.RR{dnsutils.FakeSOA(fqdn)}
}
return r
}
type IPs struct {
IPv4 []netip.Addr
IPv6 []netip.Addr
}
var _ domain.ParseStringFunc[*IPs] = ParseIPs
func ParseIPs(s string) (string, *IPs, error) {
f := strings.Fields(s)
if len(f) == 0 {
return "", nil, errors.New("empty string")
}
pattern := f[0]
v := new(IPs)
for _, ipStr := range f[1:] {
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return "", nil, fmt.Errorf("invalid ip addr %s, %w", ipStr, err)
}
if ip.Is4() { // is ipv4
v.IPv4 = append(v.IPv4, ip)
} else { // is ipv6
v.IPv6 = append(v.IPv6, ip)
}
}
return pattern, v, nil
}

105
pkg/hosts/hosts_test.go Normal file
View File

@ -0,0 +1,105 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package hosts
import (
"bytes"
"github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain"
"github.com/miekg/dns"
"net"
"testing"
)
var test_hosts = `
# comment
# empty line
dns.google 8.8.8.8 8.8.4.4 2001:4860:4860::8844 2001:4860:4860::8888
regexp:^123456789 192.168.1.1
test.com 1.2.3.4 # will be replaced
test.com 2.3.4.5
# nxdomain.com 1.2.3.4
`
func Test_hostsContainer_Match(t *testing.T) {
m := domain.NewMixMatcher[*IPs]()
m.SetDefaultMatcher(domain.MatcherDomain)
err := domain.LoadFromTextReader[*IPs](m, bytes.NewBuffer([]byte(test_hosts)), ParseIPs)
if err != nil {
t.Fatal(err)
}
h := NewHosts(m)
type args struct {
name string
typ uint16
}
tests := []struct {
name string
args args
wantMatched bool
wantAddr []string
}{
{"matched A", args{name: "dns.google.", typ: dns.TypeA}, true, []string{"8.8.8.8", "8.8.4.4"}},
{"matched AAAA", args{name: "dns.google.", typ: dns.TypeAAAA}, true, []string{"2001:4860:4860::8844", "2001:4860:4860::8888"}},
{"not matched A", args{name: "nxdomain.com.", typ: dns.TypeA}, false, nil},
{"not matched A", args{name: "sub.dns.google.", typ: dns.TypeA}, false, nil},
{"matched regexp A", args{name: "123456789.test.", typ: dns.TypeA}, true, []string{"192.168.1.1"}},
{"not matched regexp A", args{name: "0123456789.test.", typ: dns.TypeA}, false, nil},
{"test replacement", args{name: "test.com.", typ: dns.TypeA}, true, []string{"2.3.4.5"}},
{"test matched domain with mismatched type", args{name: "test.com.", typ: dns.TypeAAAA}, true, nil},
}
for _, tt := range tests {
q := new(dns.Msg)
q.SetQuestion(tt.args.name, tt.args.typ)
t.Run(tt.name, func(t *testing.T) {
r := h.LookupMsg(q)
if tt.wantMatched && r == nil {
t.Fatal("Lookup() should not return a nil result")
}
for _, s := range tt.wantAddr {
wantIP := net.ParseIP(s)
if wantIP == nil {
t.Fatal("invalid test case addr")
}
found := false
for _, rr := range r.Answer {
var ip net.IP
switch rr := rr.(type) {
case *dns.A:
ip = rr.A
case *dns.AAAA:
ip = rr.AAAA
default:
continue
}
if ip.Equal(wantIP) {
found = true
break
}
}
if !found {
t.Fatal("wanted ip is not found in response")
}
}
})
}
}

41
pkg/list/elem.go Normal file
View File

@ -0,0 +1,41 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package list
type Elem[V any] struct {
prev, next *Elem[V]
list *List[V]
Value V
}
func NewElem[V any](v V) *Elem[V] {
return &Elem[V]{
Value: v,
}
}
func (e *Elem[V]) Prev() *Elem[V] {
return e.prev
}
func (e *Elem[V]) Next() *Elem[V] {
return e.next
}

99
pkg/list/list.go Normal file
View File

@ -0,0 +1,99 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package list
type List[V any] struct {
front, back *Elem[V]
length int
}
func New[V any]() *List[V] {
return &List[V]{}
}
func mustBeFreeElem[V any](e *Elem[V]) {
if e.prev != nil || e.next != nil || e.list != nil {
panic("element is in use")
}
}
func (l *List[V]) Front() *Elem[V] {
return l.front
}
func (l *List[V]) Back() *Elem[V] {
return l.back
}
func (l *List[V]) Len() int {
return l.length
}
func (l *List[V]) PushFront(e *Elem[V]) *Elem[V] {
mustBeFreeElem(e)
l.length++
e.list = l
if l.front == nil {
l.front = e
l.back = e
} else {
e.next = l.front
l.front.prev = e
l.front = e
}
return e
}
func (l *List[V]) PushBack(e *Elem[V]) *Elem[V] {
mustBeFreeElem(e)
l.length++
e.list = l
if l.back == nil {
l.front = e
l.back = e
} else {
e.prev = l.back
l.back.next = e
l.back = e
}
return e
}
func (l *List[V]) PopElem(e *Elem[V]) *Elem[V] {
if e.list != l {
panic("elem is not belong to this list")
}
l.length--
if p := e.prev; p != nil {
p.next = e.next
}
if n := e.next; n != nil {
n.prev = e.prev
}
if e == l.front {
l.front = e.next
}
if e == l.back {
l.back = e.prev
}
e.prev, e.next, e.list = nil, nil, nil
return e
}

104
pkg/list/list_test.go Normal file
View File

@ -0,0 +1,104 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package list
import (
"github.com/stretchr/testify/assert"
"reflect"
"testing"
)
func checkLinkPointers[V any](t *testing.T, l *List[V]) {
t.Helper()
e := l.front
for e != nil {
if (e.next != nil && e.next.prev != e) || (e.prev != nil && e.prev.next != e) {
t.Fatal("broken list")
}
e = e.next
}
}
func makeElems(n []int) []*Elem[int] {
s := make([]*Elem[int], 0, len(n))
for _, i := range n {
s = append(s, NewElem(i))
}
return s
}
func allValue[V any](l *List[V]) []V {
s := make([]V, 0)
node := l.front
for node != nil {
s = append(s, node.Value)
node = node.next
}
return s
}
func TestList_Push(t *testing.T) {
l := new(List[int])
l.PushBack(NewElem(1))
l.PushBack(NewElem(2))
assert.Equal(t, []int{1, 2}, allValue(l))
checkLinkPointers(t, l)
l = new(List[int])
l.PushFront(NewElem(1))
l.PushFront(NewElem(2))
assert.Equal(t, []int{2, 1}, allValue(l))
checkLinkPointers(t, l)
}
func TestList_PopElem(t *testing.T) {
tests := []struct {
name string
in []int
pop int
want int
wantList []int
}{
{"pop front", []int{0, 1, 2}, 0, 0, []int{1, 2}},
{"pop mid", []int{0, 1, 2}, 1, 1, []int{0, 2}},
{"pop back", []int{0, 1, 2}, 2, 2, []int{0, 1}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := new(List[int])
le := makeElems(tt.in)
for _, e := range le {
l.PushBack(e)
}
checkLinkPointers(t, l)
got := l.PopElem(le[tt.pop])
checkLinkPointers(t, l)
if got.Value != tt.want {
t.Errorf("PopElem() = %v, want %v", got.Value, tt.want)
}
if !reflect.DeepEqual(allValue(l), tt.wantList) {
t.Errorf("allValue() = %v, want %v", allValue(l), tt.wantList)
}
})
}
}

135
pkg/lru/lru.go Normal file
View File

@ -0,0 +1,135 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package lru
import (
"fmt"
"github.com/IrineSistiana/mosdns/v5/pkg/list"
)
type LRU[K comparable, V any] struct {
maxSize int
onEvict func(key K, v V)
l *list.List[KV[K, V]]
m map[K]*list.Elem[KV[K, V]]
}
type KV[K comparable, V any] struct {
key K
v V
}
func NewLRU[K comparable, V any](maxSize int, onEvict func(key K, v V)) *LRU[K, V] {
if maxSize <= 0 {
panic(fmt.Sprintf("LRU: invalid max size: %d", maxSize))
}
return &LRU[K, V]{
maxSize: maxSize,
onEvict: onEvict,
l: list.New[KV[K, V]](),
m: make(map[K]*list.Elem[KV[K, V]]),
}
}
func (q *LRU[K, V]) Add(key K, v V) {
if e, ok := q.m[key]; ok { // update existed key
e.Value.v = v
q.l.PushBack(q.l.PopElem(e))
return
}
o := q.Len() - q.maxSize + 1
for o > 0 {
key, v, _ := q.PopOldest()
if q.onEvict != nil {
q.onEvict(key, v)
}
o--
}
e := list.NewElem(KV[K, V]{
key: key,
v: v,
})
q.m[key] = e
q.l.PushBack(e)
}
func (q *LRU[K, V]) Del(key K) {
e := q.m[key]
if e != nil {
q.delElem(e)
}
}
func (q *LRU[K, V]) delElem(e *list.Elem[KV[K, V]]) {
key, v := e.Value.key, e.Value.v
q.l.PopElem(e)
delete(q.m, key)
if q.onEvict != nil {
q.onEvict(key, v)
}
}
func (q *LRU[K, V]) PopOldest() (key K, v V, ok bool) {
e := q.l.Front()
if e != nil {
q.l.PopElem(e)
key, v = e.Value.key, e.Value.v
delete(q.m, key)
ok = true
return
}
return
}
func (q *LRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) {
e := q.l.Front()
for e != nil {
next := e.Next() // Delete e will clean its pointers. Save it first.
key, v := e.Value.key, e.Value.v
if remove := f(key, v); remove {
q.delElem(e)
removed++
}
e = next
}
return removed
}
func (q *LRU[K, V]) Flush() {
q.l = list.New[KV[K, V]]()
q.m = make(map[K]*list.Elem[KV[K, V]])
}
func (q *LRU[K, V]) Get(key K) (v V, ok bool) {
e, ok := q.m[key]
if !ok {
return
}
q.l.PushBack(q.l.PopElem(e))
return e.Value.v, true
}
func (q *LRU[K, V]) Len() int {
return q.l.Len()
}

138
pkg/lru/lru_test.go Normal file
View File

@ -0,0 +1,138 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package lru
import (
"testing"
)
func Test_lru(t *testing.T) {
var q *LRU[int, int]
reset := func(maxSize int) {
t.Helper()
q = NewLRU[int, int](maxSize, nil)
}
add := func(keys ...int) {
t.Helper()
for _, key := range keys {
q.Add(key, key)
}
}
mustGet := func(keys ...int) {
t.Helper()
for _, key := range keys {
gotV, ok := q.Get(key)
if !ok || gotV != key {
t.Fatalf("want %v, got %v", key, gotV)
}
}
}
emptyGet := func(keys ...int) {
t.Helper()
for _, key := range keys {
gotV, ok := q.Get(key)
if ok {
t.Fatalf("want empty, got %v", gotV)
}
}
}
mustPopOldest := func(keys ...int) {
t.Helper()
for _, key := range keys {
gotKey, gotV, ok := q.PopOldest()
if !ok {
t.Fatal()
}
if gotKey != key || gotV != gotKey {
t.Fatalf("want key: %v, v: %v, got key: %v, v:%v", key, key, gotKey, gotV)
}
}
}
emptyPop := func() {
t.Helper()
gotKey, gotV, ok := q.PopOldest()
if ok {
t.Fatalf("want empty result, got key: %v, v:%v", gotKey, gotV)
}
}
checkLen := func(want int) {
t.Helper()
if q.l.Len() != len(q.m) {
t.Fatalf("possible mem leak: q.l.Len() %v != len(q.m){ %v", q.l.Len(), len(q.m))
}
if want != q.Len() {
t.Fatalf("want %v, got %v", want, q.Len())
}
}
// test add
reset(4)
add(1, 1, 1, 1, 1, 1, 2, 3)
checkLen(3)
mustGet(1, 2, 3)
// test add overflow
reset(2)
add(1, 2, 3, 4, 5)
checkLen(2)
mustGet(4, 5)
emptyGet(1, 2, 3)
// test pop
reset(3)
add(1, 2, 3)
mustPopOldest(1, 2, 3)
checkLen(0)
emptyPop()
// test del
reset(3)
add(1, 2, 3)
q.Del(2)
q.Del(9999)
mustPopOldest(1, 3)
// test clean
reset(4)
add(1, 2, 3, 4)
cleanFunc := func(key int, v int) (remove bool) {
switch key {
case 1, 3:
return true
}
return false
}
if cleaned := q.Clean(cleanFunc); cleaned != 2 {
t.Fatalf("q.Clean want cleaned = 2, got %v", cleaned)
}
mustPopOldest(2, 4)
// test lru
reset(4)
add(1, 2, 3, 4) // 1 2 3 4
mustGet(2, 3) // 1 4 2 3
mustPopOldest(1, 4, 2, 3)
}

View File

@ -0,0 +1,36 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package domain
// "fqdn-insensitive" means the domain in Add() and Match() call
// is fqdn-insensitive. "google.com" and "google.com." will get
// the same outcome.
// The logic for case-insensitive is the same as above.
type Matcher[T any] interface {
// Match matches the domain s.
// s could be a fqdn or not, and should be case-insensitive.
Match(s string) (v T, ok bool)
}
type WriteableMatcher[T any] interface {
Matcher[T]
Add(pattern string, v T) error
}

View File

@ -0,0 +1,80 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package domain
import (
"bufio"
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"io"
"strings"
"unicode"
)
// ParseStringFunc parse data string to matcher pattern and additional attributions.
type ParseStringFunc[T any] func(s string) (pattern string, v T, err error)
// check if s only contains a domain pattern (no other section, no space).
func patternOnly[T any](s string) (pattern string, v T, err error) {
if strings.IndexFunc(s, unicode.IsSpace) != -1 {
return "", v, errors.New("rule string has more than one section")
}
return s, v, nil
}
// Load loads data from a string, parsing it with parseString function.
func Load[T any](m WriteableMatcher[T], s string, parseString ParseStringFunc[T]) error {
if parseString == nil {
parseString = patternOnly[T]
}
pattern, v, err := parseString(s)
if err != nil {
return err
}
return m.Add(pattern, v)
}
// LoadFromTextReader loads multiple lines from reader r. r
func LoadFromTextReader[T any](m WriteableMatcher[T], r io.Reader, parseString ParseStringFunc[T]) error {
lineCounter := 0
scanner := bufio.NewScanner(r)
for scanner.Scan() {
lineCounter++
s := scanner.Text()
s = utils.RemoveComment(s, "#")
s = strings.TrimSpace(s)
if len(s) == 0 {
continue
}
err := Load(m, s, parseString)
if err != nil {
return fmt.Errorf("line %d: %v", lineCounter, err)
}
}
return scanner.Err()
}
func NewDomainMixMatcher() *MixMatcher[struct{}] {
mixMatcher := NewMixMatcher[struct{}]()
mixMatcher.SetDefaultMatcher(MatcherDomain)
return mixMatcher
}

View File

@ -0,0 +1,275 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package domain
import (
"errors"
"fmt"
"regexp"
"strings"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
)
var _ WriteableMatcher[any] = (*MixMatcher[any])(nil)
var _ WriteableMatcher[any] = (*SubDomainMatcher[any])(nil)
var _ WriteableMatcher[any] = (*FullMatcher[any])(nil)
var _ WriteableMatcher[any] = (*KeywordMatcher[any])(nil)
var _ WriteableMatcher[any] = (*RegexMatcher[any])(nil)
type SubDomainMatcher[T any] struct {
root *labelNode[T]
}
func NewSubDomainMatcher[T any]() *SubDomainMatcher[T] {
return &SubDomainMatcher[T]{root: new(labelNode[T])}
}
func (m *SubDomainMatcher[T]) Match(s string) (T, bool) {
s = NormalizeDomain(s)
ds := NewReverseDomainScanner(s)
currentNode := m.root
v, ok := currentNode.getValue()
for ds.Scan() {
label := ds.NextLabel()
if nextNode := currentNode.getChild(label); nextNode != nil {
if nextNode.hasValue() {
v, ok = nextNode.getValue()
}
currentNode = nextNode
} else {
break
}
}
return v, ok
}
func (m *SubDomainMatcher[T]) Len() int {
return m.root.len()
}
func (m *SubDomainMatcher[T]) Add(s string, v T) error {
s = NormalizeDomain(s)
ds := NewReverseDomainScanner(s)
currentNode := m.root
for ds.Scan() {
label := ds.NextLabel()
if child := currentNode.getChild(label); child != nil {
currentNode = child
} else {
currentNode = currentNode.newChild(label)
}
}
currentNode.storeValue(v)
return nil
}
type FullMatcher[T any] struct {
m map[string]T // string in is map must be a normalized domain (See NormalizeDomain).
}
func NewFullMatcher[T any]() *FullMatcher[T] {
return &FullMatcher[T]{
m: make(map[string]T),
}
}
// Add adds domain s to this matcher, s can be a fqdn or not.
func (m *FullMatcher[T]) Add(s string, v T) error {
s = NormalizeDomain(s)
m.m[s] = v
return nil
}
func (m *FullMatcher[T]) Match(s string) (v T, ok bool) {
s = NormalizeDomain(s)
v, ok = m.m[s]
return
}
func (m *FullMatcher[T]) Len() int {
return len(m.m)
}
type KeywordMatcher[T any] struct {
kws map[string]T
}
func NewKeywordMatcher[T any]() *KeywordMatcher[T] {
return &KeywordMatcher[T]{
kws: make(map[string]T),
}
}
func (m *KeywordMatcher[T]) Add(keyword string, v T) error {
keyword = NormalizeDomain(keyword) // fqdn-insensitive and case-insensitive
m.kws[keyword] = v
return nil
}
func (m *KeywordMatcher[T]) Match(s string) (v T, ok bool) {
s = NormalizeDomain(s)
for k, v := range m.kws {
if strings.Contains(s, k) {
return v, true
}
}
return v, false
}
func (m *KeywordMatcher[T]) Len() int {
return len(m.kws)
}
// RegexMatcher contains regexp rules.
// Note: the regexp rule is expect to match a lower-case non fqdn.
type RegexMatcher[T any] struct {
regs map[string]*regElem[T]
}
type regElem[T any] struct {
reg *regexp.Regexp
v T
}
func NewRegexMatcher[T any]() *RegexMatcher[T] {
return &RegexMatcher[T]{regs: make(map[string]*regElem[T])}
}
func (m *RegexMatcher[T]) Add(expr string, v T) error {
e := m.regs[expr]
if e == nil {
reg, err := regexp.Compile(expr)
if err != nil {
return err
}
m.regs[expr] = &regElem[T]{
reg: reg,
v: v,
}
} else {
e.v = v
}
return nil
}
func (m *RegexMatcher[T]) Match(s string) (v T, ok bool) {
s = NormalizeDomain(s)
for _, e := range m.regs {
if e.reg.MatchString(s) {
return e.v, true
}
}
var zeroT T
return zeroT, false
}
func (m *RegexMatcher[T]) Len() int {
return len(m.regs)
}
const (
MatcherFull = "full"
MatcherDomain = "domain"
MatcherRegexp = "regexp"
MatcherKeyword = "keyword"
)
type MixMatcher[T any] struct {
defaultMatcher string
full *FullMatcher[T]
domain *SubDomainMatcher[T]
regex *RegexMatcher[T]
keyword *KeywordMatcher[T]
}
func NewMixMatcher[T any]() *MixMatcher[T] {
return &MixMatcher[T]{
full: NewFullMatcher[T](),
domain: NewSubDomainMatcher[T](),
regex: NewRegexMatcher[T](),
keyword: NewKeywordMatcher[T](),
}
}
func (m *MixMatcher[T]) SetDefaultMatcher(s string) {
m.defaultMatcher = s
}
func (m *MixMatcher[T]) GetSubMatcher(typ string) WriteableMatcher[T] {
switch typ {
case MatcherFull:
return m.full
case MatcherDomain:
return m.domain
case MatcherRegexp:
return m.regex
case MatcherKeyword:
return m.keyword
}
return nil
}
var ErrNodefaultMatcher = errors.New("default matcher is not set")
func (m *MixMatcher[T]) Add(s string, v T) error {
typ, pattern := m.splitTypeAndPattern(s)
if len(typ) == 0 {
if len(m.defaultMatcher) != 0 {
typ = m.defaultMatcher
} else {
return ErrNodefaultMatcher
}
}
sm := m.GetSubMatcher(typ)
if sm == nil {
return fmt.Errorf("unsupported match type [%s]", typ)
}
return sm.Add(pattern, v)
}
func (m *MixMatcher[T]) Match(s string) (v T, ok bool) {
for _, matcher := range [...]Matcher[T]{m.full, m.domain, m.regex, m.keyword} {
if v, ok = matcher.Match(s); ok {
return v, true
}
}
return
}
func (m *MixMatcher[T]) Len() int {
sum := 0
for _, matcher := range [...]interface{ Len() int }{m.full, m.domain, m.regex, m.keyword} {
if matcher == nil {
continue
}
sum += matcher.Len()
}
return sum
}
func (m *MixMatcher[T]) splitTypeAndPattern(s string) (string, string) {
typ, pattern, ok := utils.SplitString2(s, ":")
if !ok {
pattern = s
}
return typ, pattern
}

View File

@ -0,0 +1,200 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package domain
import (
"reflect"
"testing"
)
func assertFunc[T any](t *testing.T, m Matcher[T]) func(domain string, wantBool bool, wantV any) {
return func(domain string, wantBool bool, wantV any) {
t.Helper()
v, ok := m.Match(domain)
if ok != wantBool {
t.Fatalf("%s, wantBool = %v, got = %v", domain, wantBool, ok)
}
if !reflect.DeepEqual(v, wantV) {
t.Fatalf("%s, wantV = %v, got = %v", domain, wantV, v)
}
}
}
type aStr struct {
s string
}
func s(str string) *aStr {
return &aStr{s: str}
}
func (a *aStr) Append(v any) {
a.s = a.s + v.(*aStr).s
}
func TestDomainMatcher(t *testing.T) {
m := NewSubDomainMatcher[any]()
add := func(domain string, v any) {
m.Add(domain, v)
}
assert := assertFunc[any](t, m)
add("cn", nil)
assertInt(t, 1, m.Len())
assert("cn", true, nil)
assert("a.cn.", true, nil)
assert("a.com", false, nil)
add("a.b.com", nil)
assertInt(t, 2, m.Len())
assert("a.b.com.", true, nil)
assert("q.w.e.r.a.b.com.", true, nil)
assert("b.com.", false, nil)
// test replace
add("append", 0)
assertInt(t, 3, m.Len())
assert("append.", true, 0)
add("append.", 1)
assert("append.", true, 1)
add("append", nil)
assert("append.", true, nil)
// test sub domain
add("sub", 1)
assertInt(t, 4, m.Len())
add("a.sub", 2)
assertInt(t, 5, m.Len())
assert("sub", true, 1)
assert("b.sub", true, 1)
assert("a.sub", true, 2)
assert("a.a.sub", true, 2)
// test case-insensitive
add("UPpER", 1)
assert("LowER.Upper", true, 1)
// root match
add(".", 9)
assert("any.domain", true, 9)
}
func assertInt(t testing.TB, want, got int) {
t.Helper()
if want != got {
t.Errorf("assertion failed: want %d, got %d", want, got)
}
}
func Test_FullMatcher(t *testing.T) {
m := NewFullMatcher[any]()
assert := assertFunc[any](t, m)
add := func(domain string, v any) {
m.Add(domain, v)
}
add("cn", nil)
assert("cn", true, nil)
assert("a.cn", false, nil)
add("test.test", nil)
assert("test.test", true, nil)
assert("test.a.test", false, nil)
// test replace
add("append", 0)
assert("append", true, 0)
add("append", 1)
assert("append", true, 1)
add("append", nil)
assert("append", true, nil)
assertInt(t, m.Len(), 3)
// test case-insensitive
add("UPpER", 1)
assert("Upper", true, 1)
}
func Test_KeywordMatcher(t *testing.T) {
m := NewKeywordMatcher[any]()
add := func(domain string, v any) {
m.Add(domain, v)
}
assert := assertFunc[any](t, m)
add("123", s("a"))
assert("123456.cn", true, s("a"))
assert("111123.com", true, s("a"))
assert("111111.cn", false, nil)
add("example.com", nil)
assert("sub.example.com", true, nil)
assert("example_sub.com", false, nil)
// test replace
add("append", 0)
assert("append", true, 0)
add("append", 1)
assert("append", true, 1)
add("append", nil)
assert("append", true, nil)
assertInt(t, m.Len(), 3)
// test case-insensitive
add("UPpER", 1)
assert("L.Upper.U", true, 1)
}
func Test_RegexMatcher(t *testing.T) {
m := NewRegexMatcher[any]()
add := func(expr string, v any, wantErr bool) {
err := m.Add(expr, v)
if (err != nil) != wantErr {
t.Fatalf("%s: want err %v, got %v", expr, wantErr, err != nil)
}
}
assert := assertFunc[any](t, m)
expr := "^github-production-release-asset-[0-9a-za-z]{6}\\.s3\\.amazonaws\\.com$"
add(expr, nil, false)
assert("github-production-release-asset-000000.s3.amazonaws.com", true, nil)
assert("github-production-release-asset-aaaaaa.s3.amazonaws.com", true, nil)
assert("github-production-release-asset-aa.s3.amazonaws.com", false, nil)
assert("prefix_github-production-release-asset-000000.s3.amazonaws.com", false, nil)
assert("github-production-release-asset-000000.s3.amazonaws.com.suffix", false, nil)
expr = "^example"
add(expr, nil, false)
assert("example.com", true, nil)
assert("sub.example.com", false, nil)
// test replace
add("append", 0, false)
assert("append", true, 0)
add("append", 1, false)
assert("append", true, 1)
add("append", nil, false)
assert("append", true, nil)
expr = "*"
add(expr, nil, true)
}

116
pkg/matcher/domain/utils.go Normal file
View File

@ -0,0 +1,116 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package domain
import (
"strings"
)
type ReverseDomainScanner struct {
s string // not fqdn
p int
t int
}
func NewReverseDomainScanner(s string) *ReverseDomainScanner {
s = TrimDot(s)
return &ReverseDomainScanner{
s: s,
p: len(s),
t: len(s),
}
}
func (s *ReverseDomainScanner) Scan() bool {
if s.p <= 0 {
return false
}
s.t = s.p
s.p = strings.LastIndexByte(s.s[:s.p], '.')
return true
}
func (s *ReverseDomainScanner) NextLabelOffset() int {
return s.p + 1
}
func (s *ReverseDomainScanner) NextLabel() (label string) {
return s.s[s.p+1 : s.t]
}
// NormalizeDomain normalize domain string s.
// It removes the suffix "." and make sure the domain is in lower case.
// e.g. a fqdn "GOOGLE.com." will become "google.com"
func NormalizeDomain(s string) string {
return strings.ToLower(TrimDot(s))
}
// TrimDot trims suffix '.'
func TrimDot(s string) string {
if len(s) >= 1 && s[len(s)-1] == '.' {
s = s[:len(s)-1]
}
return s
}
// labelNode can store dns labels.
type labelNode[T any] struct {
children map[string]*labelNode[T] // lazy init
v T
hasV bool
}
func (n *labelNode[T]) storeValue(v T) {
n.v = v
n.hasV = true
}
func (n *labelNode[T]) getValue() (T, bool) {
return n.v, n.hasV
}
func (n *labelNode[T]) hasValue() bool {
return n.hasV
}
func (n *labelNode[T]) newChild(key string) *labelNode[T] {
if n.children == nil {
n.children = make(map[string]*labelNode[T])
}
node := new(labelNode[T])
n.children[key] = node
return node
}
func (n *labelNode[T]) getChild(key string) *labelNode[T] {
return n.children[key]
}
func (n *labelNode[T]) len() int {
l := 0
for _, node := range n.children {
l += node.len()
if node.hasValue() {
l++
}
}
return l
}

View File

@ -0,0 +1,61 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package domain
import (
"reflect"
"testing"
)
func TestDomainScanner(t *testing.T) {
tests := []struct {
name string
fqdn string
wantOffsets []int
wantLabels []string
}{
{"empty", "", []int{}, []string{}},
{"root", ".", []int{}, []string{}},
{"non fqdn", "a.2", []int{2, 0}, []string{"2", "a"}},
{"domain", "1.2.3.", []int{4, 2, 0}, []string{"3", "2", "1"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := NewReverseDomainScanner(tt.fqdn)
gotOffsets := make([]int, 0)
for s.Scan() {
gotOffsets = append(gotOffsets, s.NextLabelOffset())
}
if !reflect.DeepEqual(gotOffsets, tt.wantOffsets) {
t.Errorf("PrevLabelOffset() = %v, want %v", gotOffsets, tt.wantOffsets)
}
s = NewReverseDomainScanner(tt.fqdn)
gotLabels := make([]string, 0)
for s.Scan() {
pl := s.NextLabel()
gotLabels = append(gotLabels, pl)
}
if !reflect.DeepEqual(gotLabels, tt.wantLabels) {
t.Errorf("PrevLabel() = %v, want %v", gotLabels, tt.wantLabels)
}
})
}
}

View File

@ -0,0 +1,28 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package netlist
import (
"net/netip"
)
type Matcher interface {
Match(addr netip.Addr) bool
}

150
pkg/matcher/netlist/list.go Normal file
View File

@ -0,0 +1,150 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package netlist
import (
"fmt"
"net/netip"
"sort"
)
// List is a list of netip.Prefix. It stores all netip.Prefix in one single slice
// and use binary search.
// It is suitable for large static cidr search.
type List struct {
// stores valid and masked netip.Prefix(s)
e []netip.Prefix
sorted bool
}
// NewList returns a *List.
func NewList() *List {
return &List{
e: make([]netip.Prefix, 0),
}
}
func mustValid(l []netip.Prefix) {
for i, prefix := range l {
if !prefix.IsValid() {
panic(fmt.Sprintf("invalid prefix at #%d", i))
}
}
}
// Append appends new netip.Prefix(s) to the list.
// This modified the list. Caller must call List.Sort() before calling List.Contains()
func (list *List) Append(newNet ...netip.Prefix) {
for i, n := range newNet {
addr := to6(n.Addr())
bits := n.Bits()
if n.Addr().Is4() {
bits += 96
}
newNet[i] = netip.PrefixFrom(addr, bits).Masked()
}
mustValid(newNet)
list.e = append(list.e, newNet...)
list.sorted = false
}
// Sort sorts the list, this must be called after
// list being modified and before calling List.Contains().
func (list *List) Sort() {
if list.sorted {
return
}
sort.Sort(list)
out := make([]netip.Prefix, 0)
for i, n := range list.e {
if i == 0 {
out = append(out, n)
} else {
lv := &out[len(out)-1]
switch {
case n.Addr() == lv.Addr():
if n.Bits() < lv.Bits() {
*lv = n
}
case !lv.Contains(n.Addr()):
out = append(out, n)
}
}
}
list.e = out
list.sorted = true
}
// Len implements sort Interface.
func (list *List) Len() int {
return len(list.e)
}
// Less implements sort Interface.
func (list *List) Less(i, j int) bool {
return list.e[i].Addr().Less(list.e[j].Addr())
}
// Swap implements sort Interface.
func (list *List) Swap(i, j int) {
list.e[i], list.e[j] = list.e[j], list.e[i]
}
func (list *List) Match(addr netip.Addr) bool {
return list.Contains(addr)
}
// Contains reports whether the list includes the given netip.Addr.
func (list *List) Contains(addr netip.Addr) bool {
if !list.sorted {
panic("list is not sorted")
}
if !addr.IsValid() {
return false
}
addr = to6(addr)
i, j := 0, len(list.e)
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
if list.e[h].Addr().Compare(addr) <= 0 {
i = h + 1
} else {
j = h
}
}
if i == 0 {
return false
}
return list.e[i-1].Contains(addr)
}
func to6(addr netip.Addr) netip.Addr {
if addr.Is6() {
return addr
}
return netip.AddrFrom16(addr.As16())
}

View File

@ -0,0 +1,77 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package netlist
import (
"bufio"
"fmt"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"io"
"net/netip"
"strings"
)
// LoadFromReader loads IP list from a reader.
// It might modify the List and causes List unsorted.
func LoadFromReader(l *List, reader io.Reader) error {
scanner := bufio.NewScanner(reader)
// count how many lines we have read.
lineCounter := 0
for scanner.Scan() {
lineCounter++
s := scanner.Text()
s = strings.TrimSpace(s)
s = utils.RemoveComment(s, "#")
s = utils.RemoveComment(s, " ")
if len(s) == 0 {
continue
}
err := LoadFromText(l, s)
if err != nil {
return fmt.Errorf("invalid data at line #%d: %w", lineCounter, err)
}
}
return scanner.Err()
}
// LoadFromText loads an IP from s.
// It might modify the List and causes List unsorted.
func LoadFromText(l *List, s string) error {
if strings.ContainsRune(s, '/') {
ipNet, err := netip.ParsePrefix(s)
if err != nil {
return err
}
l.Append(ipNet)
return nil
}
addr, err := netip.ParseAddr(s)
if err != nil {
return err
}
bits := 32
if addr.Is6() {
bits = 128
}
l.Append(netip.PrefixFrom(addr, bits))
return nil
}

View File

@ -0,0 +1,118 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package netlist
import (
"bytes"
"net/netip"
"testing"
)
func TestIPNetList_Sort_And_Merge(t *testing.T) {
raw := `
192.168.0.0/32 # merged
192.168.0.0/24 # merged
192.168.0.0/16
192.168.1.1/24 # merged
192.168.9.24/24 # merged
192.168.3.0/24 # merged
192.169.0.0/16
104.16.0.0/12
`
ipNetList := NewList()
err := LoadFromReader(ipNetList, bytes.NewBufferString(raw))
if err != nil {
t.Fatal(err)
}
ipNetList.Sort()
if ipNetList.Len() != 3 {
t.Fatalf("unexpected length %d", ipNetList.Len())
}
tests := []struct {
name string
testIP netip.Addr
want bool
}{
{"0", netip.MustParseAddr("192.167.255.255"), false},
{"1", netip.MustParseAddr("192.168.0.0"), true},
{"2", netip.MustParseAddr("192.168.1.1"), true},
{"3", netip.MustParseAddr("192.168.9.255"), true},
{"4", netip.MustParseAddr("192.168.255.255"), true},
{"5", netip.MustParseAddr("192.169.1.1"), true},
{"6", netip.MustParseAddr("192.170.1.1"), false},
{"7", netip.MustParseAddr("1.1.1.1"), false},
{"8", netip.MustParseAddr("104.16.67.38"), true},
{"9", netip.MustParseAddr("104.32.67.38"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ipNetList.Match(tt.testIP); got != tt.want {
t.Errorf("IPNetList.Match() = %v, want %v", got, tt.want)
}
})
}
}
func TestIPNetList_New_And_Contains(t *testing.T) {
raw := `
# comment line
1.0.0.0/24 additional strings should be ignored
2.0.0.0/23 # comment
3.0.0.0
2000:0000::/32
2000:2000::1
`
ipNetList := NewList()
err := LoadFromReader(ipNetList, bytes.NewBufferString(raw))
if err != nil {
t.Fatal(err)
}
ipNetList.Sort()
tests := []struct {
name string
testIP netip.Addr
want bool
}{
{"", netip.MustParseAddr("1.0.0.0"), true},
{"", netip.MustParseAddr("1.0.0.1"), true},
{"", netip.MustParseAddr("1.0.1.0"), false},
{"", netip.MustParseAddr("2.0.0.0"), true},
{"", netip.MustParseAddr("2.0.1.255"), true},
{"", netip.MustParseAddr("2.0.2.0"), false},
{"", netip.MustParseAddr("3.0.0.0"), true},
{"", netip.MustParseAddr("2000:0000::"), true},
{"", netip.MustParseAddr("2000:0000::1"), true},
{"", netip.MustParseAddr("2000:0000:1::"), true},
{"", netip.MustParseAddr("2000:0001::"), false},
{"", netip.MustParseAddr("2000:2000::1"), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ipNetList.Match(tt.testIP); got != tt.want {
t.Errorf("IPNetList.Match() = %v, want %v", got, tt.want)
}
})
}
}

154
pkg/nftset_utils/handler.go Normal file
View File

@ -0,0 +1,154 @@
//go:build linux
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package nftset_utils
import (
"errors"
"fmt"
"net/netip"
"sync"
"time"
"github.com/google/nftables"
"go4.org/netipx"
)
var (
ErrClosed = errors.New("closed handler")
)
// NftSetHandler can add netip.Prefix to the corresponding set.
// The table that contains this set must be an inet family table.
// If the set has a 'interval' flag, the prefix from netip.Prefix will be
// applied.
type NftSetHandler struct {
opts HandlerOpts
m sync.Mutex
closed bool
lastUpdate time.Time
set *nftables.Set
lastingConn *nftables.Conn // Note: lasting conn is not concurrent safe so m is required.
disableSetCache bool // for test only
}
type HandlerOpts struct {
TableFamily nftables.TableFamily
TableName string
SetName string
}
// NewNtSetHandler inits NftSetHandler.
func NewNtSetHandler(opts HandlerOpts) *NftSetHandler {
return &NftSetHandler{
opts: opts,
}
}
// getSetLocked get set info from kernel. It has an internal cache and won't
// invoke a syscall every time.
func (h *NftSetHandler) getSetLocked() (*nftables.Set, error) {
const refreshInterval = time.Second
now := time.Now()
if !h.disableSetCache && h.set != nil && now.Sub(h.lastUpdate) < refreshInterval {
return h.set, nil
}
// Note: GetSetByName is not concurrent safe.
set, err := h.lastingConn.GetSetByName(&nftables.Table{Name: h.opts.TableName, Family: h.opts.TableFamily}, h.opts.SetName)
if err != nil {
return nil, err
}
h.set = set
h.lastUpdate = now
return set, nil
}
// AddElems adds netip.Prefix(s) to set in a single batch.
func (h *NftSetHandler) AddElems(es ...netip.Prefix) error {
h.m.Lock()
defer h.m.Unlock()
if h.closed {
return ErrClosed
}
if h.lastingConn == nil {
c, err := nftables.New(nftables.AsLasting())
if err != nil {
return fmt.Errorf("failed to open netlink, %w", err)
}
h.lastingConn = c
}
set, err := h.getSetLocked()
if err != nil {
return fmt.Errorf("failed to get set, %w", err)
}
var elems []nftables.SetElement
if set.Interval {
elems = make([]nftables.SetElement, 0, 2*len(es))
} else {
elems = make([]nftables.SetElement, 0, len(es))
}
for i, e := range es {
if !e.IsValid() {
return fmt.Errorf("invalid prefix at index %d", i)
}
if set.Interval {
start := e.Masked().Addr()
elems = append(elems, nftables.SetElement{Key: start.AsSlice(), IntervalEnd: false})
end := netipx.PrefixLastIP(e).Next() // may be invalid if end is overflowed
if end.IsValid() {
elems = append(elems, nftables.SetElement{Key: end.AsSlice(), IntervalEnd: true})
}
} else {
elems = append(elems, nftables.SetElement{Key: e.Addr().AsSlice()})
}
}
err = h.lastingConn.SetAddElements(set, elems)
if err != nil {
return err
}
return h.lastingConn.Flush()
}
func (h *NftSetHandler) Close() error {
h.m.Lock()
defer h.m.Unlock()
if h.closed {
return nil
}
h.closed = true
if h.lastingConn != nil {
return h.lastingConn.CloseLasting()
}
return nil
}

View File

@ -0,0 +1,96 @@
//go:build linux
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package nftset_utils
import (
"github.com/google/nftables"
"net/netip"
"os"
"sync"
"testing"
)
func skipCI(t *testing.T) {
if os.Getenv("TEST_NFTSET") == "" {
t.SkipNow()
}
}
func prepareSet(t testing.TB, tableName, setName string, interval bool) {
t.Helper()
nc, err := nftables.New()
if err != nil {
t.Fatal(err)
}
table := &nftables.Table{Name: tableName, Family: nftables.TableFamilyINet}
nc.AddTable(table)
if err := nc.AddSet(&nftables.Set{Name: setName, Table: table, KeyType: nftables.TypeIPAddr, Interval: interval}, nil); err != nil {
t.Fatal(err)
}
if err := nc.Flush(); err != nil {
t.Fatal(err)
}
}
func Test_AddElems(t *testing.T) {
skipCI(t)
n := "test"
prepareSet(t, n, n, false)
h := NewNtSetHandler(HandlerOpts{
TableFamily: nftables.TableFamilyINet,
TableName: n,
SetName: n,
})
h.disableSetCache = true
if err := h.AddElems(netip.MustParsePrefix("127.0.0.1/24")); err != nil {
t.Fatal(err)
}
nc, err := nftables.New()
if err != nil {
t.Fatal(err)
}
elems, err := nc.GetSetElements(h.set)
if err != nil {
t.Fatal(err)
}
if len(elems) == 0 {
t.Fatal("set is empty")
}
// test concurrent safe.
wg := new(sync.WaitGroup)
for i := 0; i < 512; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := h.AddElems(netip.MustParsePrefix("127.0.0.1/24")); err != nil {
t.Error(err)
return
}
}()
}
wg.Wait()
}

30
pkg/pool/allocator.go Normal file
View File

@ -0,0 +1,30 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package pool
import (
bytesPool "github.com/IrineSistiana/go-bytes-pool"
)
var (
_pool = bytesPool.NewPool(20) // 1Mb pool, should be enough.
GetBuf = _pool.Get
ReleaseBuf = _pool.Release
)

53
pkg/pool/bytes_buf.go Normal file
View File

@ -0,0 +1,53 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package pool
import (
"bytes"
"fmt"
"sync"
)
type BytesBufPool struct {
p sync.Pool
}
func NewBytesBufPool(initSize int) *BytesBufPool {
if initSize < 0 {
panic(fmt.Sprintf("utils.NewBytesBufPool: negative init size %d", initSize))
}
return &BytesBufPool{
p: sync.Pool{New: func() any {
b := new(bytes.Buffer)
b.Grow(initSize)
return b
}},
}
}
func (p *BytesBufPool) Get() *bytes.Buffer {
return p.p.Get().(*bytes.Buffer)
}
func (p *BytesBufPool) Release(b *bytes.Buffer) {
b.Reset()
p.p.Put(b)
}

69
pkg/pool/msg_buf.go Normal file
View File

@ -0,0 +1,69 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package pool
import (
"encoding/binary"
"fmt"
"github.com/miekg/dns"
)
// There is no such way to give dns.Msg.PackBuffer() a buffer
// with a proper size.
// Just give it a big buf and hope the buf will be reused in most scenes.
const packBufferSize = 8191
// PackBuffer packs the dns msg m to wire format.
// Callers should release the buf by calling ReleaseBuf after they have done
// with the wire []byte.
func PackBuffer(m *dns.Msg) (*[]byte, error) {
packBuf := GetBuf(packBufferSize)
defer ReleaseBuf(packBuf)
wire, err := m.PackBuffer(*packBuf)
if err != nil {
return nil, err
}
msgBuf := GetBuf(len(wire))
copy(*msgBuf, wire)
return msgBuf, nil
}
// PackBuffer packs the dns msg m to wire format, with to bytes length header.
// Callers should release the buf by calling ReleaseBuf.
func PackTCPBuffer(m *dns.Msg) (*[]byte, error) {
packBuf := GetBuf(packBufferSize)
defer ReleaseBuf(packBuf)
wire, err := m.PackBuffer((*packBuf)[2:])
if err != nil {
return nil, err
}
l := len(wire)
if l > dns.MaxMsgSize {
return nil, fmt.Errorf("dns payload size %d is too large", l)
}
msgBuf := GetBuf(2 + len(wire))
binary.BigEndian.PutUint16(*msgBuf, uint16(l))
copy((*msgBuf)[2:], wire)
return msgBuf, nil
}

60
pkg/pool/timer.go Normal file
View File

@ -0,0 +1,60 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package pool
import (
"sync"
"time"
)
var (
timerPool = sync.Pool{}
)
func GetTimer(t time.Duration) *time.Timer {
timer, ok := timerPool.Get().(*time.Timer)
if !ok {
return time.NewTimer(t)
}
if timer.Reset(t) {
panic("dispatcher.go getTimer: active timer trapped in timerPool")
}
return timer
}
func ReleaseTimer(timer *time.Timer) {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timerPool.Put(timer)
}
func ResetAndDrainTimer(timer *time.Timer, d time.Duration) {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(d)
}

View File

@ -0,0 +1,312 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package query_context
import (
"sync/atomic"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/server"
"github.com/miekg/dns"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
const (
edns0Size = 1200
)
// Context is a query context that pass through plugins.
// All Context funcs are not safe for concurrent use.
type Context struct {
id uint32
startTime time.Time
// ServerMeta contains some meta info from the server.
// It is read-only.
ServerMeta ServerMeta
query *dns.Msg // always has one question.
clientOpt *dns.OPT // may be nil
resp *dns.Msg
respOpt *dns.OPT // nil if clientOpt == nil
upstreamOpt *dns.OPT // may be nil
// lazy init.
kv map[uint32]any
marks map[uint32]struct{}
}
var contextUid atomic.Uint32
type ServerMeta = server.QueryMeta
// NewContext creates a new query Context.
// q must have one question.
// NewContext takes the ownership of q.
func NewContext(q *dns.Msg) *Context {
ctx := &Context{
id: contextUid.Add(1),
startTime: time.Now(),
query: q,
clientOpt: addNewAndSwapOldOpt(q),
}
if ctx.clientOpt != nil {
ctx.respOpt = newOpt()
// RFC 3225 3
// The DO bit of the query MUST be copied in the response.
if ctx.clientOpt.Do() {
setDo(ctx.respOpt, true)
}
}
return ctx
}
// Id returns the Context id.
// Note: This id is not the dns msg id.
// It's a unique uint32 growing with the number of query.
func (ctx *Context) Id() uint32 {
return ctx.id
}
// StartTime returns the time when the Context was created.
func (ctx *Context) StartTime() time.Time {
return ctx.startTime
}
// Q returns the query msg that will be forward to upstream.
// It always returns a non-nil msg with one question and EDNS0 OPT.
// If Caller want to modify the msg, be sure not to break those conditions.
func (ctx *Context) Q() *dns.Msg {
return ctx.query
}
// QQuestion returns the query question.
func (ctx *Context) QQuestion() dns.Question {
return ctx.query.Question[0]
}
// QOpt returns the query opt. It always returns a non-nil opt.
// It's a helper func for searching opt in Q() manually.
func (ctx *Context) QOpt() *dns.OPT {
opt := findOpt(ctx.query)
ctx.query.IsEdns0()
if opt == nil {
panic("query opt is missing")
}
return opt
}
// ClientOpt returns the OPT rr from client. Maybe nil, if client does not send it.
// Plugins that responsible for handling EDNS0 option should
// check ClientOpt and pick/add options into Q() on demand.
// The OPT is read-only.
func (ctx *Context) ClientOpt() *dns.OPT {
return ctx.clientOpt
}
// SetResponse sets m as response. It takes the ownership of m.
// If m is nil. It removes existing response.
func (ctx *Context) SetResponse(m *dns.Msg) {
ctx.resp = m
if m == nil {
ctx.upstreamOpt = nil
} else {
ctx.upstreamOpt = popOpt(m)
}
}
// R returns the response that will be sent to client. It might be nil.
// Note: R does not have EDNS0. Caller MUST NOT add a dns.OPT into R.
// Use RespOpt() instead.
func (ctx *Context) R() *dns.Msg {
return ctx.resp
}
// RespOpt returns the OPT that will be sent to client.
// If client support EDNS0, then RespOpt always returns a non-nil OPT.
// No matter what R() returns.
// Otherwise, RespOpt returns nil.
func (ctx *Context) RespOpt() *dns.OPT {
return ctx.respOpt
}
// UpstreamOpt returns the OPT from upstream. May be nil.
// Plugins that responsible for handling EDNS0 option should
// check UpstreamOpt and pick/add options into RespOpt on demand.
// The OPT is read-only.
func (ctx *Context) UpstreamOpt() *dns.OPT {
return ctx.upstreamOpt
}
// InfoField returns a zap.Field contains a brief summary of this Context.
// Useful in log.
func (ctx *Context) InfoField() zap.Field {
return zap.Object("query", ctx)
}
// Copy deep copies this Context.
// See CopyTo.
func (ctx *Context) Copy() *Context {
newCtx := new(Context)
ctx.CopyTo(newCtx)
return newCtx
}
// CopyTo deep copies this Context to d.
// Note that values that stored by StoreValue is not deep-copied.
func (ctx *Context) CopyTo(d *Context) *Context {
d.id = ctx.id
d.startTime = ctx.startTime
d.ServerMeta = ctx.ServerMeta
d.query = ctx.query.Copy()
d.clientOpt = ctx.clientOpt
if ctx.resp != nil {
d.resp = ctx.resp.Copy()
}
if ctx.respOpt != nil {
d.respOpt = dns.Copy(ctx.respOpt).(*dns.OPT)
}
d.upstreamOpt = ctx.upstreamOpt
d.kv = copyMap(ctx.kv)
d.marks = copyMap(ctx.marks)
return d
}
// StoreValue stores any v in to this Context
// k MUST from RegKey.
func (ctx *Context) StoreValue(k uint32, v any) {
if ctx.kv == nil {
ctx.kv = make(map[uint32]any)
}
ctx.kv[k] = v
}
// GetValue returns the value stored by StoreValue.
func (ctx *Context) GetValue(k uint32) (any, bool) {
v, ok := ctx.kv[k]
return v, ok
}
// DeleteValue deletes value k from Context
func (ctx *Context) DeleteValue(k uint32) {
delete(ctx.kv, k)
}
// SetMark marks this Context with given mark.
func (ctx *Context) SetMark(m uint32) {
if ctx.marks == nil {
ctx.marks = make(map[uint32]struct{})
}
ctx.marks[m] = struct{}{}
}
// HasMark reports whether this mark m was marked by SetMark.
func (ctx *Context) HasMark(m uint32) bool {
_, ok := ctx.marks[m]
return ok
}
// DeleteMark deletes mark m from this Context.
func (ctx *Context) DeleteMark(m uint32) {
delete(ctx.marks, m)
}
// MarshalLogObject implements zapcore.ObjectMarshaler.
func (ctx *Context) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
encoder.AddUint32("uqid", ctx.id)
if clientAddr := ctx.ServerMeta.ClientAddr; clientAddr.IsValid() {
zap.Stringer("client", clientAddr).AddTo(encoder)
}
question := ctx.query.Question[0]
encoder.AddString("qname", question.Name)
encoder.AddUint16("qtype", question.Qtype)
encoder.AddUint16("qclass", question.Qclass)
if r := ctx.resp; r != nil {
encoder.AddInt("rcode", r.Rcode)
}
encoder.AddDuration("elapsed", time.Since(ctx.startTime))
return nil
}
func copyMap[K comparable, V any](m map[K]V) map[K]V {
if m == nil {
return nil
}
cm := make(map[K]V, len(m))
for k, v := range m {
cm[k] = v
}
return cm
}
func addNewAndSwapOldOpt(m *dns.Msg) *dns.OPT {
for i := len(m.Extra) - 1; i >= 0; i-- {
// If m has oldOpt
if oldOpt, ok := m.Extra[i].(*dns.OPT); ok {
// replace it directly
m.Extra[i] = newOpt()
return oldOpt
}
}
m.Extra = append(m.Extra, newOpt())
return nil
}
func popOpt(m *dns.Msg) *dns.OPT {
for i := len(m.Extra) - 1; i >= 0; i-- {
if opt, ok := m.Extra[i].(*dns.OPT); ok {
m.Extra = append(m.Extra[:i], m.Extra[i+1:]...)
return opt
}
}
return nil
}
func findOpt(m *dns.Msg) *dns.OPT {
for i := len(m.Extra) - 1; i >= 0; i-- {
if opt, ok := m.Extra[i].(*dns.OPT); ok {
return opt
}
}
return nil
}
func newOpt() *dns.OPT {
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.SetUDPSize(edns0Size)
return opt
}
func setDo(opt *dns.OPT, do bool) {
const doBit = 1 << 15 // DNSSEC OK
if do {
opt.Hdr.Ttl |= doBit
}
}

35
pkg/query_context/kv.go Normal file
View File

@ -0,0 +1,35 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package query_context
import "sync/atomic"
var kId atomic.Uint32
// RegKey returns a unique uint32 for the key used in
// Context.StoreValue, Context.GetValue.
// It should only be called during initialization.
func RegKey() uint32 {
i := kId.Add(1)
if i == 0 {
panic("key id overflowed")
}
return i
}

View File

@ -0,0 +1,153 @@
package rate_limiter
import (
"net/netip"
"sync"
"time"
"golang.org/x/time/rate"
)
const (
tableShards = 32
gcInterval = time.Minute
)
type Limiter struct {
// Limit and Burst are read-only.
Limit rate.Limit
Burst int
closeOnce sync.Once
closeNotify chan struct{}
tables [tableShards]*tableShard
}
type tableShard struct {
m sync.Mutex
table map[netip.Addr]*limiterEntry
}
type limiterEntry struct {
l *rate.Limiter
lastSeen time.Time
sync.Once
}
// NewRateLimiter creates a new client rate limiter.
// limit and burst should be greater than zero. See rate.Limiter for more
// details.
// Limiter has a internal gc which will run and remove old client entries every 1m.
// If the token refill time (burst/limit) is greater than 1m,
// the actual average qps limit may be higher than expected because the client status
// may be deleted and re-initialized.
func NewRateLimiter(limit rate.Limit, burst int) *Limiter {
l := &Limiter{
Limit: limit,
Burst: burst,
closeNotify: make(chan struct{}),
}
for i := range l.tables {
l.tables[i] = &tableShard{table: make(map[netip.Addr]*limiterEntry)}
}
go l.gcLoop(gcInterval)
return l
}
// maskedUnmappedP must be a masked prefix and contain a unmapped addr.
func (l *Limiter) Allow(unmappedAddr netip.Addr) bool {
now := time.Now()
shard := l.getTableShard(unmappedAddr)
shard.m.Lock()
e, ok := shard.table[unmappedAddr]
if !ok {
e = &limiterEntry{
l: rate.NewLimiter(l.Limit, l.Burst),
lastSeen: now,
}
shard.table[unmappedAddr] = e
}
e.lastSeen = now
shard.m.Unlock()
clientLimiter := e.l
return clientLimiter.AllowN(now, 1)
}
func (l *Limiter) Close() error {
l.closeOnce.Do(func() {
close(l.closeNotify)
})
return nil
}
func (l *Limiter) gcLoop(gcInterval time.Duration) {
ticker := time.NewTicker(gcInterval)
defer ticker.Stop()
for {
select {
case <-l.closeNotify:
return
case now := <-ticker.C:
l.doGc(now, gcInterval)
}
}
}
func (l *Limiter) doGc(now time.Time, gcInterval time.Duration) {
for _, shard := range l.tables {
shard.m.Lock()
for a, e := range shard.table {
if now.Sub(e.lastSeen) > gcInterval {
delete(shard.table, a)
}
}
shard.m.Unlock()
}
}
func (l *Limiter) getTableShard(unmappedAddr netip.Addr) *tableShard {
return l.tables[getTableShardIdx(unmappedAddr)]
}
func (l *Limiter) ForEach(doFunc func(unmappedAddr netip.Addr, r *rate.Limiter) (doBreak bool)) (doBreak bool) {
for _, shard := range l.tables {
shard.m.Lock()
for a, e := range shard.table {
doBreak = doFunc(a, e.l)
if doBreak {
shard.m.Unlock()
return
}
}
shard.m.Unlock()
}
return false
}
// Len returns current number of entries in the Limiter.
func (l *Limiter) Len() int {
n := 0
for _, shard := range l.tables {
shard.m.Lock()
n += len(shard.table)
shard.m.Unlock()
}
return n
}
func getTableShardIdx(unmappedAddr netip.Addr) int {
var i byte
if unmappedAddr.Is4() {
for _, b := range unmappedAddr.As4() {
i ^= b
}
} else {
for _, b := range unmappedAddr.As16() {
i ^= b
}
}
return int(i % tableShards)
}

View File

@ -0,0 +1,20 @@
package rate_limiter
import (
"testing"
"time"
"golang.org/x/time/rate"
)
func BenchmarkXxx(b *testing.B) {
now := time.Now()
var l *limiterEntry
for i := 0; i < b.N; i++ {
l = &limiterEntry{
l: rate.NewLimiter(0, 0),
lastSeen: now,
}
}
_ = l
}

View File

@ -0,0 +1,84 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package safe_close
import "sync"
// SafeClose can achieve safe close where WaitClosed returns only after
// all sub goroutines exited.
//
// 1. Main service goroutine starts and wait on ReceiveCloseSignal.
// 2. Any service's sub goroutine should be started by Attach and wait on ReceiveCloseSignal.
// 3. If any fatal err occurs, any service goroutine can call SendCloseSignal to close the service.
// 4. Any third party caller can call SendCloseSignal to close the service.
type SafeClose struct {
m sync.Mutex
wg sync.WaitGroup
closeSignal chan struct{}
closeErr error
}
func NewSafeClose() *SafeClose {
return &SafeClose{
closeSignal: make(chan struct{}),
}
}
// WaitClosed waits until all SendCloseSignal is called and all
// attached funcs in SafeClose are done.
func (s *SafeClose) WaitClosed() error {
<-s.closeSignal
s.wg.Wait()
return s.closeErr
}
// SendCloseSignal sends a close signal. Unblock WaitClosed.
// The given error will be read by WaitClosed.
// Once SendCloseSignal is called, following calls are noop.
func (s *SafeClose) SendCloseSignal(err error) {
s.m.Lock()
select {
case <-s.closeSignal:
default:
s.closeErr = err
close(s.closeSignal)
}
s.m.Unlock()
}
func (s *SafeClose) ReceiveCloseSignal() <-chan struct{} {
return s.closeSignal
}
// Attach add this goroutine to s.wg WaitClosed.
// f must receive closeSignal and call done when it is done.
// If s was closed, f will not run.
func (s *SafeClose) Attach(f func(done func(), closeSignal <-chan struct{})) {
s.m.Lock()
select {
case <-s.closeSignal:
default:
s.wg.Add(1)
go func() {
f(s.wg.Done, s.closeSignal)
}()
}
s.m.Unlock()
}

123
pkg/server/doq.go Normal file
View File

@ -0,0 +1,123 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server
import (
"context"
"fmt"
"net"
"net/netip"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/quic-go/quic-go"
"go.uber.org/zap"
)
const (
defaultQuicIdleTimeout = time.Second * 30
streamReadTimeout = time.Second * 2
quicFirstReadTimeout = time.Second * 2
)
type DoQServerOpts struct {
Logger *zap.Logger
IdleTimeout time.Duration
}
// ServeDoQ starts a server at l. It returns if l had an Accept() error.
// It always returns a non-nil error.
func ServeDoQ(l *quic.Listener, h Handler, opts DoQServerOpts) error {
logger := opts.Logger
if logger == nil {
logger = nopLogger
}
idleTimeout := opts.IdleTimeout
if idleTimeout <= 0 {
idleTimeout = defaultQuicIdleTimeout
}
listenerCtx, cancel := context.WithCancelCause(context.Background())
defer cancel(errListenerCtxCanceled)
for {
c, err := l.Accept(listenerCtx)
if err != nil {
return fmt.Errorf("unexpected listener err: %w", err)
}
// handle connection
connCtx, cancelConn := context.WithCancelCause(listenerCtx)
go func() {
defer c.CloseWithError(0, "")
defer cancelConn(errConnectionCtxCanceled)
var clientAddr netip.Addr
ta, ok := c.RemoteAddr().(*net.UDPAddr)
if ok {
clientAddr = ta.AddrPort().Addr()
}
firstRead := true
for {
var streamAcceptTimeout time.Duration
if firstRead {
firstRead = false
streamAcceptTimeout = quicFirstReadTimeout
} else {
streamAcceptTimeout = idleTimeout
}
streamAcceptCtx, cancelStreamAccept := context.WithTimeout(connCtx, streamAcceptTimeout)
stream, err := c.AcceptStream(streamAcceptCtx)
cancelStreamAccept()
if err != nil {
return
}
// Handle stream.
// For doq, one stream, one query.
go func() {
defer func() {
stream.Close()
stream.CancelRead(0) // TODO: Needs a proper error code.
}()
// Avoid fragmentation attack.
stream.SetReadDeadline(time.Now().Add(streamReadTimeout))
req, _, err := dnsutils.ReadMsgFromTCP(stream)
if err != nil {
return
}
queryMeta := QueryMeta{
ClientAddr: clientAddr,
ServerName: c.ConnectionState().TLS.ServerName,
}
resp := h.Handle(connCtx, req, queryMeta, pool.PackTCPBuffer)
if resp == nil {
return
}
if _, err := stream.Write(*resp); err != nil {
logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err))
}
}()
}
}()
}
}

179
pkg/server/http_handler.go Normal file
View File

@ -0,0 +1,179 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server
import (
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"strings"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
"go.uber.org/zap"
)
type HttpHandlerOpts struct {
// GetSrcIPFromHeader specifies the header that contain client source address.
// e.g. "X-Forwarded-For".
GetSrcIPFromHeader string
// Logger specifies the logger which Handler writes its log to.
// Default is a nop logger.
Logger *zap.Logger
}
type HttpHandler struct {
dnsHandler Handler
logger *zap.Logger
srcIPHeader string
}
var _ http.Handler = (*HttpHandler)(nil)
func NewHttpHandler(h Handler, opts HttpHandlerOpts) *HttpHandler {
hh := new(HttpHandler)
hh.dnsHandler = h
hh.srcIPHeader = opts.GetSrcIPFromHeader
hh.logger = opts.Logger
if hh.logger == nil {
hh.logger = nopLogger
}
return hh
}
func (h *HttpHandler) warnErr(req *http.Request, msg string, err error) {
h.logger.Warn(msg, zap.String("from", req.RemoteAddr), zap.String("method", req.Method), zap.String("url", req.RequestURI), zap.Error(err))
}
func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
addrPort, err := netip.ParseAddrPort(req.RemoteAddr)
if err != nil {
h.logger.Error("failed to parse request remote addr", zap.String("addr", req.RemoteAddr), zap.Error(err))
w.WriteHeader(http.StatusInternalServerError)
return
}
clientAddr := addrPort.Addr()
// read remote addr from header
if header := h.srcIPHeader; len(header) != 0 {
if xff := req.Header.Get(header); len(xff) != 0 {
addr, err := readClientAddrFromXFF(xff)
if err != nil {
h.warnErr(req, "failed to get client ip from header", fmt.Errorf("failed to prase header %s: %s, %s", header, xff, err))
w.WriteHeader(http.StatusBadRequest)
return
}
clientAddr = addr
}
}
// read msg
q, err := ReadMsgFromReq(req)
if err != nil {
h.warnErr(req, "invalid request", err)
w.WriteHeader(http.StatusBadRequest)
return
}
queryMeta := QueryMeta{
ClientAddr: clientAddr,
}
if u := req.URL; u != nil {
queryMeta.UrlPath = u.Path
}
if tlsStat := req.TLS; tlsStat != nil {
queryMeta.ServerName = tlsStat.ServerName
}
resp := h.dnsHandler.Handle(req.Context(), q, queryMeta, pool.PackBuffer)
if resp == nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
defer pool.ReleaseBuf(resp)
w.Header().Set("Content-Type", "application/dns-message")
if _, err := w.Write(*resp); err != nil {
h.warnErr(req, "failed to write response", err)
return
}
}
func readClientAddrFromXFF(s string) (netip.Addr, error) {
if i := strings.IndexRune(s, ','); i > 0 {
return netip.ParseAddr(s[:i])
}
return netip.ParseAddr(s)
}
var errInvalidMediaType = errors.New("missing or invalid media type header")
var bufPool = pool.NewBytesBufPool(512)
func ReadMsgFromReq(req *http.Request) (*dns.Msg, error) {
var b []byte
switch req.Method {
case http.MethodGet:
// Check accept header
if req.Header.Get("Accept") != "application/dns-message" {
return nil, errInvalidMediaType
}
s := req.URL.Query().Get("dns")
if len(s) == 0 {
return nil, errors.New("no dns parameter")
}
msgSize := base64.RawURLEncoding.DecodedLen(len(s))
if msgSize > dns.MaxMsgSize {
return nil, fmt.Errorf("msg length %d is too big", msgSize)
}
var err error
b, err = base64.RawURLEncoding.DecodeString(s)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 query: %w", err)
}
case http.MethodPost:
// Check Content-Type header
if req.Header.Get("Content-Type") != "application/dns-message" {
return nil, errInvalidMediaType
}
buf := bufPool.Get()
defer bufPool.Release(buf)
_, err := buf.ReadFrom(io.LimitReader(req.Body, dns.MaxMsgSize))
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
b = buf.Bytes()
default:
return nil, fmt.Errorf("unsupported method: %s", req.Method)
}
m := new(dns.Msg)
if err := m.Unpack(b); err != nil {
return nil, fmt.Errorf("failed to unpack msg [%x], %w", b, err)
}
return m, nil
}

29
pkg/server/iface.go Normal file
View File

@ -0,0 +1,29 @@
package server
import (
"context"
"net/netip"
"github.com/miekg/dns"
)
// Handler handles incoming request q and MUST ALWAYS return a response.
// Handler MUST handle dns errors by itself and return a proper error responses.
// e.g. Return a SERVFAIL if something goes wrong.
// If Handle() returns a nil resp, caller will
// udp: do nothing.
// tcp/dot: close the connection immediately.
// doh: send a 500 response.
// doq: close the stream immediately.
type Handler interface {
Handle(ctx context.Context, q *dns.Msg, meta QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) (respPayload *[]byte)
}
type QueryMeta struct {
FromUDP bool
// Optional
ClientAddr netip.Addr
ServerName string
UrlPath string
}

119
pkg/server/tcp.go Normal file
View File

@ -0,0 +1,119 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/netip"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"go.uber.org/zap"
)
const (
defaultTCPIdleTimeout = time.Second * 10
tcpFirstReadTimeout = time.Second * 2
)
type TCPServerOpts struct {
// Nil logger == nop
Logger *zap.Logger
// Default is defaultTCPIdleTimeout.
IdleTimeout time.Duration
}
// ServeTCP starts a server at l. It returns if l had an Accept() error.
// It always returns a non-nil error.
func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error {
logger := opts.Logger
if logger == nil {
logger = nopLogger
}
idleTimeout := opts.IdleTimeout
if idleTimeout <= 0 {
idleTimeout = defaultTCPIdleTimeout
}
firstReadTimeout := tcpFirstReadTimeout
if idleTimeout < firstReadTimeout {
firstReadTimeout = idleTimeout
}
listenerCtx, cancel := context.WithCancelCause(context.Background())
defer cancel(errListenerCtxCanceled)
for {
c, err := l.Accept()
if err != nil {
return fmt.Errorf("unexpected listener err: %w", err)
}
// handle connection
tcpConnCtx, cancelConn := context.WithCancelCause(listenerCtx)
go func() {
defer c.Close()
defer cancelConn(errConnectionCtxCanceled)
firstRead := true
for {
if firstRead {
firstRead = false
c.SetReadDeadline(time.Now().Add(firstReadTimeout))
} else {
c.SetReadDeadline(time.Now().Add(idleTimeout))
}
req, _, err := dnsutils.ReadMsgFromTCP(c)
if err != nil {
return // read err, close the connection
}
// Try to get server name from tls conn.
var serverName string
if tlsConn, ok := c.(*tls.Conn); ok {
serverName = tlsConn.ConnectionState().ServerName
}
// handle query
go func() {
var clientAddr netip.Addr
ta, ok := c.RemoteAddr().(*net.TCPAddr)
if ok {
clientAddr = ta.AddrPort().Addr()
}
r := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr, ServerName: serverName}, pool.PackTCPBuffer)
if r == nil {
c.Close() // abort the connection
return
}
defer pool.ReleaseBuf(r)
if _, err := c.Write(*r); err != nil {
logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err))
return
}
}()
}
}()
}
}

33
pkg/server/tls.go Normal file
View File

@ -0,0 +1,33 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server
import (
"crypto/tls"
)
func LoadCert(tlsCfg *tls.Config, cert, key string) error {
c, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
return err
}
tlsCfg.Certificates = []tls.Certificate{c}
return nil
}

109
pkg/server/udp.go Normal file
View File

@ -0,0 +1,109 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server
import (
"context"
"fmt"
"net"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
"go.uber.org/zap"
)
type UDPServerOpts struct {
Logger *zap.Logger
}
// ServeUDP starts a server at c. It returns if c had a read error.
// It always returns a non-nil error.
// h is required. logger is optional.
func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error {
logger := opts.Logger
if logger == nil {
logger = nopLogger
}
listenerCtx, cancel := context.WithCancelCause(context.Background())
defer cancel(errListenerCtxCanceled)
rb := pool.GetBuf(dns.MaxMsgSize)
defer pool.ReleaseBuf(rb)
oobReader, oobWriter, err := initOobHandler(c)
if err != nil {
return fmt.Errorf("failed to init oob handler, %w", err)
}
var ob []byte
if oobReader != nil {
obp := pool.GetBuf(1024)
defer pool.ReleaseBuf(obp)
ob = *obp
}
for {
n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob)
if err != nil {
if n == 0 {
// Err with zero read. Most likely because c was closed.
return fmt.Errorf("unexpected read err: %w", err)
}
// Temporary err.
logger.Warn("read err", zap.Error(err))
continue
}
q := new(dns.Msg)
if err := q.Unpack((*rb)[:n]); err != nil {
logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr))
continue
}
var dstIpFromCm net.IP
if oobReader != nil {
var err error
dstIpFromCm, err = oobReader(ob[:oobn])
if err != nil {
logger.Error("failed to get dst address from oob", zap.Error(err))
}
}
// handle query
go func() {
payload := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}, pool.PackBuffer)
if payload == nil {
return
}
defer pool.ReleaseBuf(payload)
var oob []byte
if oobWriter != nil && dstIpFromCm != nil {
oob = oobWriter(dstIpFromCm)
}
if _, _, err := c.WriteMsgUDPAddrPort(*payload, oob, remoteAddr); err != nil {
logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
}
}()
}
}
type getSrcAddrFromOOB func(oob []byte) (net.IP, error)
type writeSrcAddrToOOB func(a net.IP) []byte

125
pkg/server/udp_linux.go Normal file
View File

@ -0,0 +1,125 @@
//go:build linux
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server
import (
"errors"
"fmt"
"net"
"os"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
var (
errCmNoDstAddr = errors.New("control msg does not have dst address")
)
func getOOBFromCM4(oob []byte) (net.IP, error) {
var cm ipv4.ControlMessage
if err := cm.Parse(oob); err != nil {
return nil, err
}
if cm.Dst == nil {
return nil, errCmNoDstAddr
}
return cm.Dst, nil
}
func getOOBFromCM6(oob []byte) (net.IP, error) {
var cm ipv6.ControlMessage
if err := cm.Parse(oob); err != nil {
return nil, err
}
if cm.Dst == nil {
return nil, errCmNoDstAddr
}
return cm.Dst, nil
}
func srcIP2Cm(ip net.IP) []byte {
if ip4 := ip.To4(); ip4 != nil {
return (&ipv4.ControlMessage{
Src: ip,
}).Marshal()
}
if ip6 := ip.To16(); ip6 != nil {
return (&ipv6.ControlMessage{
Src: ip,
}).Marshal()
}
return nil
}
func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
if !c.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() {
return nil, nil, nil
}
sc, err := c.SyscallConn()
if err != nil {
return nil, nil, err
}
var getter getSrcAddrFromOOB
var setter writeSrcAddrToOOB
var controlErr error
if err := sc.Control(func(fd uintptr) {
v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN)
if err != nil {
controlErr = os.NewSyscallError("failed to get SO_PROTOCOL", err)
return
}
switch v {
case unix.AF_INET:
c4 := ipv4.NewPacketConn(c)
if err := c4.SetControlMessage(ipv4.FlagDst, true); err != nil {
controlErr = fmt.Errorf("failed to set ipv4 cmsg flags, %w", err)
}
getter = getOOBFromCM4
setter = srcIP2Cm
return
case unix.AF_INET6:
c6 := ipv6.NewPacketConn(c)
if err := c6.SetControlMessage(ipv6.FlagDst, true); err != nil {
controlErr = fmt.Errorf("failed to set ipv6 cmsg flags, %w", err)
}
getter = getOOBFromCM6
setter = srcIP2Cm
return
default:
controlErr = fmt.Errorf("socket protocol %d is not supported", v)
}
}); err != nil {
return nil, nil, fmt.Errorf("control fd err, %w", controlErr)
}
if controlErr != nil {
return nil, nil, fmt.Errorf("failed to set up socket, %w", controlErr)
}
return getter, setter, nil
}

28
pkg/server/udp_others.go Normal file
View File

@ -0,0 +1,28 @@
//go:build !linux
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server
import "net"
func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
return nil, nil, nil
}

16
pkg/server/utils.go Normal file
View File

@ -0,0 +1,16 @@
package server
import (
"errors"
"go.uber.org/zap"
)
var (
errListenerCtxCanceled = errors.New("listener ctx canceled")
errConnectionCtxCanceled = errors.New("connection ctx canceled")
)
var (
nopLogger = zap.NewNop()
)

View File

@ -0,0 +1,154 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server_handler
import (
"context"
"time"
"github.com/IrineSistiana/mosdns/v5/mlog"
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
"github.com/IrineSistiana/mosdns/v5/pkg/server"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
"github.com/miekg/dns"
"go.uber.org/zap"
)
const (
defaultQueryTimeout = time.Second * 5
)
var (
nopLogger = mlog.Nop()
// options that can forward to upstream
queryForwardEDNS0Option = map[uint16]struct{}{
dns.EDNS0SUBNET: {},
}
// options that useless for downstream
respRemoveEDNS0Option = map[uint16]struct{}{
dns.EDNS0PADDING: {},
}
)
type EntryHandlerOpts struct {
// Logger is used for logging. Default is a noop logger.
Logger *zap.Logger
// Required.
Entry sequence.Executable
// QueryTimeout limits the timeout value of each query.
// Default is defaultQueryTimeout.
QueryTimeout time.Duration
}
func (opts *EntryHandlerOpts) init() {
if opts.Logger == nil {
opts.Logger = nopLogger
}
utils.SetDefaultNum(&opts.QueryTimeout, defaultQueryTimeout)
}
type EntryHandler struct {
opts EntryHandlerOpts
}
var _ server.Handler = (*EntryHandler)(nil)
func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler {
opts.init()
return &EntryHandler{opts: opts}
}
// ServeDNS implements server.Handler.
// If entry returns an error, a SERVFAIL response will be returned.
// If entry returns without a response, a REFUSED response will be returned.
func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, serverMeta server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte {
// basic query check.
if q.Response || len(q.Question) != 1 || len(q.Answer)+len(q.Ns) > 0 || len(q.Extra) > 1 {
return nil
}
ddl := time.Now().Add(h.opts.QueryTimeout)
ctx, cancel := context.WithDeadline(ctx, ddl)
defer cancel()
qCtx := query_context.NewContext(q)
qCtx.ServerMeta = serverMeta
// exec entry
err := h.opts.Entry.Exec(ctx, qCtx)
var resp *dns.Msg
if err != nil {
h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(err))
resp = new(dns.Msg)
resp.SetReply(q)
resp.Rcode = dns.RcodeServerFailure
} else {
resp = qCtx.R()
}
if resp == nil {
resp = new(dns.Msg)
resp.SetReply(q)
resp.Rcode = dns.RcodeRefused
}
// We assume that our server is a forwarder.
resp.RecursionAvailable = true
// add respOpt back to resp
if respOpt := qCtx.RespOpt(); respOpt != nil {
resp.Extra = append(resp.Extra, respOpt)
}
if serverMeta.FromUDP {
udpSize := getValidUDPSize(qCtx.ClientOpt())
resp.Truncate(udpSize)
}
payload, err := packMsgPayload(resp)
if err != nil {
h.opts.Logger.Error("internal err: failed to pack resp msg", qCtx.InfoField(), zap.Error(err))
return nil
}
return payload
}
// opt can be nil.
func getValidUDPSize(opt *dns.OPT) int {
var s uint16
if opt != nil {
s = opt.UDPSize()
}
if s < dns.MinMsgSize {
s = dns.MinMsgSize
}
return int(s)
}
func newOpt() *dns.OPT {
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
return opt
}

View File

@ -0,0 +1,248 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package bootstrap
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/miekg/dns"
"go.uber.org/zap"
)
const (
minimumUpdateInterval = time.Minute * 5
retryInterval = time.Second * 2
queryTimeout = time.Second * 5
)
var (
errNoAddrInResp = errors.New("resp does not have ip address")
)
func New(
host string,
port uint16,
bootstrapServer netip.AddrPort,
bootstrapVer int, // 0,4,6
logger *zap.Logger, // not nil
) (*Bootstrap, error) {
dp := new(Bootstrap)
dp.fqdn = dns.Fqdn(host)
dp.port = port
if !bootstrapServer.IsValid() {
return nil, errors.New("invalid bootstrap server address")
}
dp.bootstrap = net.UDPAddrFromAddrPort(bootstrapServer)
qt, ok := bootstrapVer2Qt(bootstrapVer)
if !ok {
return nil, fmt.Errorf("invalid bootstrap version %d", bootstrapVer)
}
dp.qt = qt
dp.logger = logger
dp.readyNotify = make(chan struct{})
return dp, nil
}
type Bootstrap struct {
fqdn string
port uint16
bootstrap *net.UDPAddr
qt uint16 // dns.TypeA or dns.TypeAAAA
logger *zap.Logger // not nil
updating atomic.Bool
nextUpdate time.Time
readyNotify chan struct{}
m sync.Mutex
ready bool
addrStr string
}
func (sp *Bootstrap) GetAddrPortStr(ctx context.Context) (string, error) {
sp.tryUpdate()
select {
case <-ctx.Done():
return "", context.Cause(ctx)
case <-sp.readyNotify:
}
sp.m.Lock()
addr := sp.addrStr
sp.m.Unlock()
return addr, nil
}
func (sp *Bootstrap) tryUpdate() {
if sp.updating.CompareAndSwap(false, true) {
if time.Now().After(sp.nextUpdate) {
go func() {
defer sp.updating.Store(false)
ctx, cancel := context.WithTimeout(context.Background(), queryTimeout)
defer cancel()
start := time.Now()
addr, ttl, err := sp.updateAddr(ctx)
if err != nil {
sp.logger.Check(zap.WarnLevel, "failed to update bootstrap addr").Write(
zap.String("fqdn", sp.fqdn),
zap.Error(err),
)
sp.nextUpdate = time.Now().Add(retryInterval)
} else {
updateInterval := time.Second * time.Duration(ttl)
if updateInterval < minimumUpdateInterval {
updateInterval = minimumUpdateInterval
}
sp.logger.Check(zap.DebugLevel, "bootstrap addr updated").Write(
zap.String("fqdn", sp.fqdn),
zap.Stringer("addr", addr),
zap.Duration("ttl", updateInterval),
zap.Duration("elapse", time.Since(start)),
)
sp.nextUpdate = time.Now().Add(updateInterval)
}
}()
} else {
sp.updating.Store(false)
}
}
}
func (sp *Bootstrap) updateAddr(ctx context.Context) (netip.Addr, uint32, error) {
addr, ttl, err := sp.resolve(ctx, sp.qt)
if err != nil {
return netip.Addr{}, 0, err
}
addrPort := netip.AddrPortFrom(addr, sp.port).String()
sp.m.Lock()
sp.addrStr = addrPort
if !sp.ready {
sp.ready = true
close(sp.readyNotify)
}
sp.m.Unlock()
return addr, ttl, nil
}
func (sp *Bootstrap) resolve(ctx context.Context, qt uint16) (netip.Addr, uint32, error) {
const edns0UdpSize = 1200
q := new(dns.Msg)
q.SetQuestion(sp.fqdn, qt)
q.SetEdns0(edns0UdpSize, false)
c, err := net.DialUDP("udp", nil, sp.bootstrap)
if err != nil {
return netip.Addr{}, 0, err
}
defer c.Close()
writeErrC := make(chan error, 1)
type res struct {
resp *dns.Msg
err error
}
readResC := make(chan res, 1)
cancelWrite := make(chan struct{})
defer close(cancelWrite)
go func() {
if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil {
writeErrC <- err
return
}
retryTicker := time.NewTicker(time.Second)
defer retryTicker.Stop()
for {
select {
case <-cancelWrite:
return
case <-retryTicker.C:
if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil {
writeErrC <- err
return
}
}
}
}()
go func() {
m, _, err := dnsutils.ReadMsgFromUDP(c, edns0UdpSize)
readResC <- res{resp: m, err: err}
}()
select {
case <-ctx.Done():
return netip.Addr{}, 0, context.Cause(ctx)
case err := <-writeErrC:
return netip.Addr{}, 0, fmt.Errorf("failed to write query, %w", err)
case r := <-readResC:
resp := r.resp
err := r.err
if err != nil {
return netip.Addr{}, 0, fmt.Errorf("failed to read resp, %w", err)
}
for _, v := range resp.Answer {
var ip net.IP
var ttl uint32
switch rr := v.(type) {
case *dns.A:
ip = rr.A
ttl = rr.Hdr.Ttl
case *dns.AAAA:
ip = rr.AAAA
ttl = rr.Hdr.Ttl
default:
continue
}
addr, ok := netip.AddrFromSlice(ip)
if ok {
return addr, ttl, nil
}
}
// No ip addr in resp.
return netip.Addr{}, 0, errNoAddrInResp
}
}
func bootstrapVer2Qt(ver int) (uint16, bool) {
switch ver {
case 0, 4:
return dns.TypeA, true
case 6:
return dns.TypeAAAA, true
default:
return 0, false
}
}

View File

@ -0,0 +1,166 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package doh
import (
"context"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"net/http"
urlpkg "net/url"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/miekg/dns"
"go.uber.org/zap"
)
const (
defaultDoHTimeout = time.Second * 6
)
var nopLogger = zap.NewNop()
// Upstream is a DNS-over-HTTPS (RFC 8484) upstream.
type Upstream struct {
rt http.RoundTripper
logger *zap.Logger // non-nil
urlTemplate *urlpkg.URL
reqTemplate *http.Request
}
func NewUpstream(endPoint string, rt http.RoundTripper, logger *zap.Logger) (*Upstream, error) {
req, err := http.NewRequest(http.MethodGet, endPoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to parse http request, %w", err)
}
req.Header["Accept"] = []string{"application/dns-message"}
req.Header["User-Agent"] = nil // Don't let go http send a default user agent header.
if logger == nil {
logger = nopLogger
}
return &Upstream{
rt: rt,
logger: logger,
urlTemplate: req.URL,
reqTemplate: req,
}, nil
}
var (
bufPool4k = pool.NewBytesBufPool(4096)
)
func (u *Upstream) ExchangeContext(ctx context.Context, q []byte) (*[]byte, error) {
bp := pool.GetBuf(len(q))
defer pool.ReleaseBuf(bp)
wire := *bp
copy(wire, q)
// In order to maximize HTTP cache friendliness, DoH clients using media
// formats that include the ID field from the DNS message header, such
// as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS
// request.
// https://tools.ietf.org/html/rfc8484#section-4.1
wire[0] = 0
wire[1] = 0
queryLen := 4 + base64.RawURLEncoding.EncodedLen(len(wire))
queryBuf := make([]byte, queryLen)
p := 0
p += copy(queryBuf, "dns=")
// Padding characters for base64url MUST NOT be included.
// See: https://tools.ietf.org/html/rfc8484#section-6.
base64.RawURLEncoding.Encode(queryBuf[p:], wire)
type res struct {
r *[]byte
err error
}
resChan := make(chan res, 1)
go func() {
// We overwrite the ctx with a fixed timeout context here.
// Because the http package may close the underlay connection
// if the context is done before the query is completed. This
// reduces the connection reuse efficiency.
ctx, cancel := context.WithTimeout(context.Background(), defaultDoHTimeout)
defer cancel()
r, err := u.exchange(ctx, utils.BytesToStringUnsafe(queryBuf))
if err != nil {
u.logger.Check(zap.WarnLevel, "exchange failed").Write(zap.Error(err))
}
resChan <- res{r: r, err: err}
}()
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case res := <-resChan:
r := res.r
err := res.err
if r != nil {
binary.BigEndian.PutUint16(*r, binary.BigEndian.Uint16(q))
}
return r, err
}
}
func (u *Upstream) exchange(ctx context.Context, dnsQuery string) (*[]byte, error) {
req := u.reqTemplate.WithContext(ctx)
req.URL = new(urlpkg.URL)
*req.URL = *u.urlTemplate
req.URL.RawQuery = dnsQuery
resp, err := u.rt.RoundTrip(req)
if err != nil {
return nil, fmt.Errorf("http request failed: %w", err)
}
defer resp.Body.Close()
// check status code
if resp.StatusCode != http.StatusOK {
body1k, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
if body1k != nil {
return nil, fmt.Errorf("bad http status codes %d with body [%s]", resp.StatusCode, body1k)
}
return nil, fmt.Errorf("bad http status codes %d", resp.StatusCode)
}
bb := bufPool4k.Get()
defer bufPool4k.Release(bb)
_, err = bb.ReadFrom(io.LimitReader(resp.Body, dns.MaxMsgSize))
if err != nil {
return nil, fmt.Errorf("failed to read http body: %w", err)
}
if bb.Len() < dnsutils.DnsHeaderLen {
return nil, dnsutils.ErrPayloadTooSmall
}
payload := pool.GetBuf(bb.Len())
copy(*payload, bb.Bytes())
return payload, nil
}

View File

@ -0,0 +1,70 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package upstream
import (
"net"
"sync/atomic"
)
type Event int
const (
EventConnOpen Event = iota
EventConnClose
)
type EventObserver interface {
OnEvent(typ Event)
}
type nopEO struct{}
func (n nopEO) OnEvent(_ Event) {}
type connWrapper struct {
net.Conn
closed atomic.Bool
ob EventObserver
}
// wrapConn wraps c into a connWrapper so that we can observe the connection close.
// For convenient, if c is nil, wrapConn returns nil as well. If ob is nopEO, wrapConn
// returns c.
func wrapConn(c net.Conn, ob EventObserver) net.Conn {
if c == nil {
return nil
}
if _, ok := ob.(nopEO); ok {
return c
}
ob.OnEvent(EventConnOpen)
return &connWrapper{
Conn: c,
ob: ob,
}
}
func (c *connWrapper) Close() error {
if c.closed.CompareAndSwap(false, true) {
c.ob.OnEvent(EventConnClose)
}
return c.Conn.Close()
}

View File

@ -0,0 +1,165 @@
package transport
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
)
type lazyDnsConn struct {
maxConcurrentQuery int
cancelDial context.CancelFunc
mu sync.Mutex
earlyReserveCallWg sync.WaitGroup
closed bool
reservedQuery int
dialFinished chan struct{}
c DnsConn
dialErr error
// 1: dial completed and all early reserve call finished.
// 2: dial failed.
fastPath atomic.Uint32
}
var _ DnsConn = (*lazyDnsConn)(nil)
var (
errLazyConnDialCanceled = errors.New("lazy dial canceled")
)
func newLazyDnsConn(
dial func(ctx context.Context) (DnsConn, error),
dialTimeout time.Duration,
maxConcurrentQueryWhileDialing int, // must be valid, no default value
logger *zap.Logger, // must non-nil
) *lazyDnsConn {
if dialTimeout <= 0 {
dialTimeout = defaultDialTimeout
}
dialCtx, cancelDial := context.WithTimeout(context.Background(), defaultDialTimeout)
lc := &lazyDnsConn{
maxConcurrentQuery: maxConcurrentQueryWhileDialing,
cancelDial: cancelDial,
dialFinished: make(chan struct{}),
}
go func() {
dc, err := dial(dialCtx)
cancelDial()
if err != nil {
logger.Check(zap.WarnLevel, "failed to dial dns conn").Write(zap.Error(err))
}
lc.mu.Lock()
if lc.closed { // lc was closed and dial was canceled
lc.mu.Unlock()
if dc != nil {
dc.Close()
}
return
}
lc.c = dc
lc.dialErr = err
close(lc.dialFinished)
lc.mu.Unlock()
}()
return lc
}
func (lc *lazyDnsConn) Close() error {
lc.mu.Lock()
defer lc.mu.Unlock()
if lc.closed {
return nil
}
lc.closed = true
if lc.c == nil && lc.dialErr == nil { // still dialing
lc.cancelDial()
lc.dialErr = errLazyConnDialCanceled
close(lc.dialFinished)
} else {
// close connection
if lc.c != nil {
lc.c.Close()
}
}
return nil
}
func (lc *lazyDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
switch lc.fastPath.Load() {
case 1:
return lc.c.ReserveNewQuery()
case 2:
return nil, true
}
lc.mu.Lock()
defer lc.mu.Unlock()
select {
case <-lc.dialFinished:
// Note: race condition here and lazyDnsConnEarlyReservedExchanger.ExchangeReserved().
// Not a big problem. May cause at most all early exchange failed.
// earlyExchangeWg makes sure that early exchange calls ReserveNewQuery first.
dc, err := lc.c, lc.dialErr
if err != nil {
lc.fastPath.Store(2)
return nil, true
}
lc.earlyReserveCallWg.Wait()
lc.fastPath.Store(1)
return dc.ReserveNewQuery()
default:
if lc.reservedQuery >= lc.maxConcurrentQuery {
return nil, false
}
lc.reservedQuery++
lc.earlyReserveCallWg.Add(1)
return (*lazyDnsConnEarlyReservedExchanger)(lc), false
}
}
type lazyDnsConnEarlyReservedExchanger lazyDnsConn
var _ ReservedExchanger = (*lazyDnsConnEarlyReservedExchanger)(nil)
func (ote *lazyDnsConnEarlyReservedExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) {
defer func() {
ote.mu.Lock()
ote.reservedQuery--
ote.mu.Unlock()
}()
select {
case <-ctx.Done():
ote.earlyReserveCallWg.Done()
return nil, context.Cause(ctx)
case <-ote.dialFinished:
dc, err := ote.c, ote.dialErr
if err != nil {
return nil, err
}
rec, _ := dc.ReserveNewQuery()
ote.earlyReserveCallWg.Done()
if rec == nil {
return nil, ErrLazyConnCannotReserveQueryExchanger
}
return rec.ExchangeReserved(ctx, q)
}
}
func (ote *lazyDnsConnEarlyReservedExchanger) WithdrawReserved() {
ote.earlyReserveCallWg.Done()
ote.mu.Lock()
ote.reservedQuery--
ote.mu.Unlock()
}

View File

@ -0,0 +1,123 @@
package transport
import (
"context"
"encoding/binary"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/quic-go/quic-go"
)
const (
quicQueryTimeout = time.Second * 6
)
const (
// RFC 9250 4.3. DoQ Error Codes
_DOQ_NO_ERROR = quic.StreamErrorCode(0x0)
_DOQ_INTERNAL_ERROR = quic.StreamErrorCode(0x1)
_DOQ_REQUEST_CANCELLED = quic.StreamErrorCode(0x3)
)
var _ DnsConn = (*QuicDnsConn)(nil)
type QuicDnsConn struct {
c quic.Connection
}
func NewQuicDnsConn(c quic.Connection) *QuicDnsConn {
return &QuicDnsConn{c: c}
}
func (c *QuicDnsConn) Close() error {
return c.c.CloseWithError(0, "")
}
func (c *QuicDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
select {
case <-c.c.Context().Done():
return nil, true
default:
}
s, err := c.c.OpenStream()
// We just checked the connection is alive. So we are assuming the error
// is caused by reaching the peer's stream limit.
if err != nil {
return nil, false
}
return &quicReservedExchanger{stream: s}, false
}
type quicReservedExchanger struct {
stream quic.Stream
}
var _ ReservedExchanger = (*quicReservedExchanger)(nil)
func (ote *quicReservedExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) {
stream := ote.stream
payload, err := copyMsgWithLenHdr(q)
if err != nil {
stream.CancelWrite(_DOQ_REQUEST_CANCELLED)
stream.CancelRead(_DOQ_REQUEST_CANCELLED)
return nil, err
}
// 4.2.1. DNS Message IDs
// When sending queries over a QUIC connection, the DNS Message ID MUST
// be set to 0. The stream mapping for DoQ allows for unambiguous
// correlation of queries and responses, so the Message ID field is not
// required.
orgQid := binary.BigEndian.Uint16((*payload)[2:])
binary.BigEndian.PutUint16((*payload)[2:], 0)
stream.SetDeadline(time.Now().Add(quicQueryTimeout))
_, err = stream.Write(*payload)
pool.ReleaseBuf(payload)
if err != nil {
stream.CancelRead(_DOQ_REQUEST_CANCELLED)
stream.CancelWrite(_DOQ_REQUEST_CANCELLED)
return nil, err
}
// RFC 9250 4.2
// The client MUST send the DNS query over the selected stream and MUST
// indicate through the STREAM FIN mechanism that no further data will
// be sent on that stream.
//
// Call Close() here will send the STREAM FIN. It won't close Read.
stream.Close()
type res struct {
resp *[]byte
err error
}
rc := make(chan res, 1)
go func() {
r, err := dnsutils.ReadRawMsgFromTCP(stream)
rc <- res{resp: r, err: err}
}()
select {
case <-ctx.Done():
stream.CancelRead(_DOQ_REQUEST_CANCELLED)
return nil, context.Cause(ctx)
case r := <-rc:
resp := r.resp
err := r.err
if resp != nil {
binary.BigEndian.PutUint16((*resp), orgQid)
}
stream.CancelRead(_DOQ_NO_ERROR)
return resp, err
}
}
func (ote *quicReservedExchanger) WithdrawReserved() {
s := ote.stream
s.CancelRead(_DOQ_REQUEST_CANCELLED)
s.CancelWrite(_DOQ_REQUEST_CANCELLED)
}

View File

@ -0,0 +1,283 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"context"
"encoding/binary"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
)
var (
ErrTDCTooManyQueries = errors.New("too many queries") // Connection has too many ongoing queries.
ErrTDCClosed = errors.New("dns connection closed")
)
var _ DnsConn = (*TraditionalDnsConn)(nil)
// TraditionalDnsConn is a low-level connection for traditional dns protocol, where
// dns frames transport in a single and simple connection. (e.g. udp, tcp, tls)
type TraditionalDnsConn struct {
c NetConn
isTcp bool
idleTimeout time.Duration
maxCq int
closeOnce sync.Once
closeNotify chan struct{}
closed atomic.Bool // atomic, for fast check
closeErr error // closeErr is ready (not nil) when closeNotify is closed.
queueMu sync.RWMutex
reservedQuery int
nextQid uint16
queue map[uint32]chan *[]byte // uint32 has fast path
// waitingResp indicates connection is waiting a reply from the peer.
// It can identify c is dead or buggy in some circumstances. e.g. Network is dropped
// and the sockets were still open because no fin or rst was received.
waitingResp atomic.Bool
}
type TraditionalDnsConnOpts struct {
// Set to true if underlayer connection require a length header.
// e.g. TCP and DoT.
WithLengthHeader bool
// IdleTimeout controls the maximum idle time for each connection.
// Default is defaultIdleTimeout.
IdleTimeout time.Duration
// MaxConcurrentQuery limits the number of maximum concurrent queries
// in the connection. Default is defaultTdcMaxConcurrentQuery.
MaxConcurrentQuery int
}
func NewDnsConn(opt TraditionalDnsConnOpts, conn NetConn) *TraditionalDnsConn {
dc := &TraditionalDnsConn{
c: conn,
isTcp: opt.WithLengthHeader,
closeNotify: make(chan struct{}),
queue: make(map[uint32]chan *[]byte),
}
setDefaultGZ(&dc.idleTimeout, opt.IdleTimeout, defaultIdleTimeout)
setDefaultGZ(&dc.maxCq, opt.MaxConcurrentQuery, defaultTdcMaxConcurrentQuery)
go dc.readLoop()
return dc
}
// exchange sends q out and waits for its reply.
func (dc *TraditionalDnsConn) exchange(ctx context.Context, q []byte) (*[]byte, error) {
select {
case <-dc.closeNotify:
return nil, ErrTDCClosed
default:
}
assignedQid, respChan := dc.addQueueC()
if respChan == nil {
return nil, ErrTDCTooManyQueries
}
defer dc.deleteQueueC(assignedQid)
// Reminder: Set write deadline here is not very useful to avoid dead connections.
// Typically, a write operation will time out only if its socket buffer is full.
// Ser read deadline is enough.
err := dc.writeQuery(q, assignedQid)
if err != nil {
// Write error usually is fatal. Abort and close this connection.
dc.CloseWithErr(fmt.Errorf("write err, %w", err))
return nil, err
}
// If a query was sent, server should have a reply (even not for this query) in a short time.
// This indicates the connection is healthy. Otherwise, this connection might be dead.
// The Read deadline will be refreshed in DnsConn.readLoop() after every successful read.
// Note: There has a race condition in this SetReadDeadline() call and the one in
// readLoop(). It's not a big problem.
if dc.waitingResp.CompareAndSwap(false, true) {
dc.c.SetReadDeadline(time.Now().Add(waitingReplyTimeout))
}
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case r := <-respChan:
orgId := binary.BigEndian.Uint16(q)
binary.BigEndian.PutUint16(*r, orgId)
return r, nil
case <-dc.closeNotify:
return nil, dc.closeErr
}
}
func (dc *TraditionalDnsConn) writeQuery(q []byte, assignedQid uint16) error {
var payload *[]byte
if dc.isTcp {
var err error
payload, err = copyMsgWithLenHdr(q)
if err != nil {
return err
}
binary.BigEndian.PutUint16((*payload)[2:], assignedQid)
} else {
payload = copyMsg(q)
binary.BigEndian.PutUint16(*payload, assignedQid)
}
_, err := dc.c.Write(*payload)
pool.ReleaseBuf(payload)
return err
}
func (dc *TraditionalDnsConn) readResp() (payload *[]byte, err error) {
if dc.isTcp {
return dnsutils.ReadRawMsgFromTCP(dc.c)
}
return readMsgUdp(dc.c)
}
// readLoop reads DnsConn until there was a read error.
func (dc *TraditionalDnsConn) readLoop() {
for {
dc.c.SetReadDeadline(time.Now().Add(dc.idleTimeout))
r, err := dc.readResp()
if err != nil {
dc.CloseWithErr(fmt.Errorf("read err, %w", err)) // abort this connection.
return
}
dc.waitingResp.Store(false)
rid := binary.BigEndian.Uint16(*r)
resChan := dc.getQueueC(rid)
if resChan != nil {
select {
case resChan <- r: // resChan has buffer
default:
pool.ReleaseBuf(r)
}
} else {
pool.ReleaseBuf(r)
}
}
}
func (dc *TraditionalDnsConn) IsClosed() bool {
return dc.closed.Load()
}
func (dc *TraditionalDnsConn) Close() error {
dc.CloseWithErr(ErrTDCClosed)
return nil
}
// CloseWithErr closes DnsConn with an error. The error will be sent
// to the waiting Exchange calls.
// Subsequent calls are noop.
// Default err is ErrTDCClosed.
func (dc *TraditionalDnsConn) CloseWithErr(err error) {
if err == nil {
err = ErrTDCClosed
}
dc.closeOnce.Do(func() {
dc.closed.Store(true)
dc.closeErr = err
close(dc.closeNotify)
dc.c.Close()
})
}
func (dc *TraditionalDnsConn) getQueueC(qid uint16) chan<- *[]byte {
dc.queueMu.RLock()
defer dc.queueMu.RUnlock()
return dc.queue[uint32(qid)]
}
func (dc *TraditionalDnsConn) queueLen() int {
dc.queueMu.RLock()
defer dc.queueMu.RUnlock()
return len(dc.queue) + dc.reservedQuery
}
// addQueueC assigns a qid and add it to the queue.
// It returns a nil c if queue has too many queries.
// Caller must call deleteQueueC to release the qid in queue.
func (dc *TraditionalDnsConn) addQueueC() (qid uint16, c chan *[]byte) {
c = make(chan *[]byte)
dc.queueMu.Lock()
for i := 0; i < 100; i++ {
qid = dc.nextQid
dc.nextQid++
if _, dup := dc.queue[uint32(qid)]; dup {
continue
}
dc.queue[uint32(qid)] = c
dc.queueMu.Unlock()
return qid, c
}
dc.queueMu.Unlock()
// Too many queries in queue. Can't assign qid.
return 0, nil
}
func (dc *TraditionalDnsConn) deleteQueueC(qid uint16) {
dc.queueMu.Lock()
delete(dc.queue, uint32(qid))
dc.queueMu.Unlock()
}
func (dc *TraditionalDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
if dc.closed.Load() {
return nil, true
}
dc.queueMu.Lock()
defer dc.queueMu.Unlock()
if len(dc.queue)+dc.reservedQuery >= dc.maxCq {
return nil, false
}
dc.reservedQuery++
return (*tdcOneTimeExchanger)(dc), false
}
type tdcOneTimeExchanger TraditionalDnsConn
var _ ReservedExchanger = (*tdcOneTimeExchanger)(nil)
func (ote *tdcOneTimeExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) {
defer ote.WithdrawReserved()
return (*TraditionalDnsConn)(ote).exchange(ctx, q)
}
func (ote *tdcOneTimeExchanger) WithdrawReserved() {
ote.queueMu.Lock()
ote.reservedQuery--
ote.queueMu.Unlock()
}

View File

@ -0,0 +1,250 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"bytes"
"context"
"errors"
"math/rand"
"net"
"runtime"
"sync"
"testing"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
type dummyEchoNetConn struct {
net.Conn
rErrProb float64
rLatency time.Duration
wErrProb float64
closeOnce sync.Once
closeNotify chan struct{}
}
func newDummyEchoNetConn(rErrProb float64, rLatency time.Duration, wErrProb float64) NetConn {
c1, c2 := net.Pipe()
go func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
defer c1.Close()
defer c2.Close()
for {
m, readErr := dnsutils.ReadRawMsgFromTCP(c2)
if m != nil {
go func() {
defer pool.ReleaseBuf(m)
if rLatency > 0 {
t := time.NewTimer(rLatency)
defer t.Stop()
select {
case <-t.C:
case <-ctx.Done():
return
}
}
latency := time.Millisecond * time.Duration(rand.Intn(20))
time.Sleep(latency)
_, _ = dnsutils.WriteRawMsgToTCP(c2, *m)
}()
}
if readErr != nil {
return
}
}
}()
return &dummyEchoNetConn{
Conn: c1,
rErrProb: rErrProb,
rLatency: rLatency,
wErrProb: wErrProb,
closeNotify: make(chan struct{}),
}
}
func probTrue(p float64) bool {
return rand.Float64() < p
}
func (d *dummyEchoNetConn) Read(p []byte) (n int, err error) {
if probTrue(d.rErrProb) {
return 0, errors.New("read err")
}
return d.Conn.Read(p)
}
func (d *dummyEchoNetConn) Write(p []byte) (n int, err error) {
if probTrue(d.wErrProb) {
return 0, errors.New("write err")
}
return d.Conn.Write(p)
}
func (d *dummyEchoNetConn) Close() error {
d.closeOnce.Do(func() {
close(d.closeNotify)
})
return d.Conn.Close()
}
func Test_dnsConn_exchange(t *testing.T) {
idleTimeout := time.Millisecond * 100
tests := []struct {
name string
rErrProb float64
rLatency time.Duration
wErrProb float64
connClosed bool // connection is closed before calling exchange()
wantMsg bool
wantErr bool
}{
{
name: "normal",
rErrProb: 0,
rLatency: 0,
wErrProb: 0,
wantMsg: true, wantErr: false,
},
{
name: "write err",
rErrProb: 0,
rLatency: 0,
wErrProb: 1,
wantMsg: false, wantErr: true,
},
{
name: "read err",
rErrProb: 1,
rLatency: 0,
wErrProb: 0,
wantMsg: false, wantErr: true,
},
{
name: "read timeout",
rErrProb: 0,
rLatency: idleTimeout * 3,
wErrProb: 0,
wantMsg: false, wantErr: true,
},
{
name: "connection closed",
rErrProb: 0,
rLatency: 0,
wErrProb: 0,
connClosed: true,
wantMsg: false, wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := require.New(t)
c := newDummyEchoNetConn(tt.rErrProb, tt.rLatency, tt.wErrProb)
defer c.Close()
ioOpts := TraditionalDnsConnOpts{
WithLengthHeader: true, // TODO: Test false as well
IdleTimeout: idleTimeout,
}
dc := NewDnsConn(ioOpts, c)
if tt.connClosed {
dc.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
q := new(dns.Msg)
q.SetQuestion("test.", dns.TypeA)
queryPayload, err := q.Pack()
r.NoError(err)
rec, closed := dc.ReserveNewQuery()
r.Equal(tt.connClosed, closed)
if rec == nil {
return
}
respPayload, err := rec.ExchangeReserved(ctx, queryPayload)
if tt.wantErr {
r.Error(err)
} else {
r.NoError(err)
}
if tt.wantMsg {
r.NotNil(respPayload)
r.True(bytes.Equal(queryPayload, *respPayload))
// test idle timeout
time.Sleep(idleTimeout + time.Millisecond*20)
runtime.Gosched()
r.True(dc.IsClosed(), "connection should be closed due to idle timeout")
} else {
r.Nil(respPayload)
}
})
}
}
// TODO: 测试 maxconcurrentquery。
func Test_dnsConn_exchange_race(t *testing.T) {
r := require.New(t)
wg := new(sync.WaitGroup)
for i := 0; i < 1024; i++ {
c := newDummyEchoNetConn(0.5, time.Millisecond*20, 0.5)
ioOpts := TraditionalDnsConnOpts{
WithLengthHeader: true, // TODO: Test false as well
IdleTimeout: time.Millisecond * 50,
}
dc := NewDnsConn(ioOpts, c)
for j := 0; j < 24; j++ {
if dc.IsClosed() {
break
}
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
q := new(dns.Msg)
q.SetQuestion("test.", dns.TypeA)
queryPayload, err := q.Pack()
r.NoError(err)
rec, closed := dc.ReserveNewQuery()
if closed {
return
}
if rec != nil {
_, _ = rec.ExchangeReserved(ctx, queryPayload)
}
}()
}
}
wg.Wait()
}

View File

@ -0,0 +1,151 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"context"
"sync"
"time"
"go.uber.org/zap"
)
// PipelineTransport will pipeline queries as RFC 7766 6.2.1.1 suggested.
// It also can reuse udp socket. Since dns over udp is some kind of "pipeline".
type PipelineTransport struct {
m sync.Mutex // protect following fields
closed bool
conns map[*lazyDnsConn]struct{}
dialFunc func(ctx context.Context) (DnsConn, error)
dialTimeout time.Duration
maxLazyConnQueue int
logger *zap.Logger // not nil
}
type PipelineOpts struct {
// DialContext specifies the method to dial a connection to the server.
// DialContext MUST NOT be nil.
DialContext func(ctx context.Context) (DnsConn, error)
// DialTimeout specifies the timeout for DialFunc.
// Default is defaultDialTimeout.
DialTimeout time.Duration
// When connection is dialing, how many queries can be queued up in that
// connection. Default is defaultLazyConnMaxConcurrentQuery.
// Note: If the connection turns out having a smaller limit, part of queued up
// queries will fail.
MaxConcurrentQueryWhileDialing int
Logger *zap.Logger
}
func NewPipelineTransport(opt PipelineOpts) *PipelineTransport {
t := &PipelineTransport{
conns: make(map[*lazyDnsConn]struct{}),
}
t.dialFunc = opt.DialContext
setDefaultGZ(&t.dialTimeout, opt.DialTimeout, defaultDialTimeout)
setDefaultGZ(&t.maxLazyConnQueue, opt.MaxConcurrentQueryWhileDialing, defaultMaxLazyConnQueue)
setNonNilLogger(&t.logger, opt.Logger)
return t
}
func (t *PipelineTransport) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) {
const maxRetry = 2
retry := 0
for {
dc, isNewConn, err := t.getReservedExchanger()
if err != nil {
return nil, err
}
r, err := dc.ExchangeReserved(ctx, m)
if err != nil {
// Reused connection may not stable.
// Try to re-send this query if it failed on a reused connection.
if !isNewConn && retry < maxRetry && ctx.Err() == nil {
retry++
continue
}
return nil, err
}
return r, nil
}
}
// Close closes PipelineTransport and all its connections.
// It always returns a nil error.
func (t *PipelineTransport) Close() error {
t.m.Lock()
defer t.m.Unlock()
if t.closed {
return nil
}
t.closed = true
for conn := range t.conns {
conn.Close()
}
return nil
}
func (t *PipelineTransport) getReservedExchanger() (_ ReservedExchanger, isNewConn bool, err error) {
t.m.Lock()
if t.closed {
err = ErrClosedTransport
t.m.Unlock()
return
}
var rxc ReservedExchanger
const maxReserveAttempt = 16
reserveAttempt := 0
for c := range t.conns {
var closed bool
rxc, closed = c.ReserveNewQuery()
if closed {
delete(t.conns, c)
}
if rxc != nil {
break
} else {
reserveAttempt++
if reserveAttempt > maxReserveAttempt {
break
}
}
}
// Dial a new connection
if rxc == nil {
c := newLazyDnsConn(t.dialFunc, t.dialTimeout, t.maxLazyConnQueue, t.logger)
rxc, _ = c.ReserveNewQuery() // ignore the closed error for new lazy connection
isNewConn = true
t.conns[c] = struct{}{}
}
t.m.Unlock()
if rxc == nil {
isNewConn = false
err = ErrNewConnCannotReserveQueryExchanger
}
return rxc, isNewConn, err
}

View File

@ -0,0 +1,152 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
type dummyEchoDnsConnOpt struct {
exchangeErr error
mcq int
closed atomic.Bool
wantConcurrentExchangeCall int
waitingExchangeCall atomic.Int32
unblockOnce sync.Once
unblockExchange chan struct{}
}
type dummyEchoDnsConn struct {
opt *dummyEchoDnsConnOpt
m sync.Mutex
reserved int
}
func (dc *dummyEchoDnsConn) ExchangeReserved(ctx context.Context, q []byte) (*[]byte, error) {
defer dc.WithdrawReserved()
if dc.opt.waitingExchangeCall.Add(1) == int32(dc.opt.wantConcurrentExchangeCall) {
dc.opt.unblockOnce.Do(func() { close(dc.opt.unblockExchange) })
}
defer dc.opt.waitingExchangeCall.Add(-1)
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case <-dc.opt.unblockExchange:
if dc.opt.exchangeErr != nil {
return nil, dc.opt.exchangeErr
}
return copyMsg(q), nil
}
}
func (dc *dummyEchoDnsConn) WithdrawReserved() {
dc.m.Lock()
defer dc.m.Unlock()
dc.reserved--
if dc.reserved < 0 {
panic("negative reserved counter")
}
}
func (dc *dummyEchoDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
if dc.opt.closed.Load() {
return nil, true
}
dc.m.Lock()
defer dc.m.Unlock()
if dc.reserved >= dc.opt.mcq {
return nil, false
}
dc.reserved++
return dc, false
}
func (dc *dummyEchoDnsConn) Close() error {
return nil
}
func Test_PipelineTransport(t *testing.T) {
const (
mcq = 100
wantConn = 10
wantMaxConcurrentExchangeCall = mcq * wantConn
)
r := require.New(t)
dcControl := &dummyEchoDnsConnOpt{
mcq: mcq,
unblockExchange: make(chan struct{}),
wantConcurrentExchangeCall: wantMaxConcurrentExchangeCall,
}
po := PipelineOpts{
DialContext: func(ctx context.Context) (DnsConn, error) { return &dummyEchoDnsConn{opt: dcControl}, nil },
MaxConcurrentQueryWhileDialing: mcq,
}
pt := NewPipelineTransport(po)
defer pt.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
q := new(dns.Msg)
q.SetQuestion("test.", dns.TypeA)
queryPayload, err := q.Pack()
r.NoError(err)
wg := new(sync.WaitGroup)
for i := 0; i < wantMaxConcurrentExchangeCall; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := pt.ExchangeContext(ctx, queryPayload)
if err != nil {
t.Error(err)
}
}()
if t.Failed() {
break
}
}
wg.Wait()
pt.m.Lock()
pl := len(pt.conns)
pt.m.Unlock()
r.Equal(wantConn, pl)
dcControl.closed.Store(true)
_, _ = pt.ExchangeContext(ctx, queryPayload) // remove all closed conn
pt.m.Lock()
pl = len(pt.conns)
pt.m.Unlock()
r.Equal(1, pl, "all connection should be remove then one will be opened")
}

View File

@ -0,0 +1,341 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"context"
"errors"
"net"
"sync"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"go.uber.org/zap"
)
const (
// Most servers will send SERVFAIL after 3~5s. If no resp, connection may be dead.
reuseConnQueryTimeout = time.Second * 6
)
// ReuseConnTransport is for old tcp protocol. (no pipelining)
type ReuseConnTransport struct {
dialFunc func(ctx context.Context) (NetConn, error)
dialTimeout time.Duration
idleTimeout time.Duration
logger *zap.Logger // non-nil
ctx context.Context
ctxCancel context.CancelCauseFunc
m sync.Mutex // protect following fields
closed bool
idleConns map[*reusableConn]struct{}
conns map[*reusableConn]struct{}
// for testing
testWaitRespTimeout time.Duration
}
type ReuseConnOpts struct {
// DialContext specifies the method to dial a connection to the server.
// DialContext MUST NOT be nil.
DialContext func(ctx context.Context) (NetConn, error)
// DialTimeout specifies the timeout for DialFunc.
// Default is defaultDialTimeout.
DialTimeout time.Duration
// Default is defaultIdleTimeout
IdleTimeout time.Duration
Logger *zap.Logger
}
func NewReuseConnTransport(opt ReuseConnOpts) *ReuseConnTransport {
ctx, cancel := context.WithCancelCause(context.Background())
t := &ReuseConnTransport{
ctx: ctx,
ctxCancel: cancel,
idleConns: make(map[*reusableConn]struct{}),
conns: make(map[*reusableConn]struct{}),
}
t.dialFunc = opt.DialContext
setDefaultGZ(&t.dialTimeout, opt.DialTimeout, defaultDialTimeout)
setDefaultGZ(&t.idleTimeout, opt.IdleTimeout, defaultIdleTimeout)
setNonNilLogger(&t.logger, opt.Logger)
return t
}
func (t *ReuseConnTransport) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) {
const maxRetry = 2
retry := 0
for {
var isNewConn bool
c, err := t.getIdleConn()
if err != nil {
return nil, err
}
if c == nil {
isNewConn = true
c, err = t.getNewConn(ctx)
if err != nil {
return nil, err
}
}
queryPayload, err := copyMsgWithLenHdr(m)
if err != nil {
return nil, err
}
resp, err := c.exchange(ctx, queryPayload)
if err != nil {
if !isNewConn && retry <= maxRetry {
retry++
continue // retry if c is a reused connection.
}
return nil, err
}
return resp, nil
}
}
// getNewConn dial a *reusableConn.
// The caller must call releaseReusableConn to release the reusableConn.
func (t *ReuseConnTransport) getNewConn(ctx context.Context) (*reusableConn, error) {
callCtx, cancel := context.WithCancel(ctx)
defer cancel()
type dialRes struct {
c *reusableConn
err error
}
dialChan := make(chan dialRes)
go func() {
dialCtx, cancelDial := context.WithTimeout(t.ctx, t.dialTimeout)
defer cancelDial()
var rc *reusableConn
c, err := t.dialFunc(dialCtx)
if err != nil {
t.logger.Check(zap.WarnLevel, "fail to dial reusable conn").Write(zap.Error(err))
}
if c != nil {
rc = t.newReusableConn(c)
if rc == nil { // transport closed
c.Close()
rc = nil
err = ErrClosedTransport
}
}
select {
case dialChan <- dialRes{c: rc, err: err}:
case <-callCtx.Done(): // caller canceled getNewConn() call
if rc != nil { // put this conn to pool
t.setIdle(rc)
}
}
}()
select {
case <-callCtx.Done():
return nil, context.Cause(ctx)
case <-t.ctx.Done():
return nil, context.Cause(t.ctx)
case res := <-dialChan:
return res.c, res.err
}
}
func (t *ReuseConnTransport) setIdle(c *reusableConn) {
t.m.Lock()
defer t.m.Unlock()
if t.closed {
return
}
if _, ok := t.conns[c]; ok {
t.idleConns[c] = struct{}{}
}
}
// getIdleConn returns a *reusableConn from conn pool, or nil if no conn
// is idle.
// The caller must call releaseReusableConn to release the reusableConn.
func (t *ReuseConnTransport) getIdleConn() (*reusableConn, error) {
t.m.Lock()
defer t.m.Unlock()
if t.closed {
return nil, ErrClosedTransport
}
for c := range t.idleConns {
delete(t.idleConns, c)
return c, nil
}
return nil, nil
}
// Close closes ReuseConnTransport and all its connections.
// It always returns a nil error.
func (t *ReuseConnTransport) Close() error {
t.m.Lock()
defer t.m.Unlock()
if t.closed {
return nil
}
t.closed = true
for c := range t.conns {
delete(t.conns, c)
delete(t.idleConns, c)
c.closeWithErrByTransport(ErrClosedTransport)
}
t.ctxCancel(ErrClosedTransport)
return nil
}
type reusableConn struct {
c NetConn
t *ReuseConnTransport
m sync.Mutex
waitingResp chan *[]byte
closeOnce sync.Once
closeNotify chan struct{}
closeErr error
}
// return nil if transport was closed
func (t *ReuseConnTransport) newReusableConn(c NetConn) *reusableConn {
rc := &reusableConn{
c: c,
t: t,
closeNotify: make(chan struct{}),
}
t.m.Lock()
if t.closed { // t was closed.
t.m.Unlock()
return nil
}
t.conns[rc] = struct{}{}
t.m.Unlock()
go rc.readLoop()
return rc
}
var (
errUnexpectedResp = errors.New("server misbehaving: unexpected response")
)
func (c *reusableConn) readLoop() {
for {
resp, err := dnsutils.ReadRawMsgFromTCP(c.c)
if err != nil {
c.closeWithErr(err)
return
}
c.m.Lock()
respChan := c.waitingResp
c.waitingResp = nil
c.m.Unlock()
if respChan == nil {
pool.ReleaseBuf(resp)
c.closeWithErr(errUnexpectedResp)
return
}
// This connection is idled again.
c.c.SetReadDeadline(time.Now().Add(c.t.idleTimeout))
// Note: calling setIdle before sending resp back to make sure this connection is idle
// before Exchange call returning. Otherwise, Test_ReuseConnTransport may fail.
c.t.setIdle(c)
select {
case respChan <- resp:
default:
panic("bug: respChan has buffer, we shouldn't reach here")
}
}
}
func (c *reusableConn) closeWithErr(err error) {
if err == nil {
err = net.ErrClosed
}
c.closeOnce.Do(func() {
c.t.m.Lock()
delete(c.t.conns, c)
delete(c.t.idleConns, c)
c.t.m.Unlock()
c.closeErr = err
c.c.Close()
close(c.closeNotify)
})
}
func (c *reusableConn) closeWithErrByTransport(err error) {
if err == nil {
err = net.ErrClosed
}
c.closeOnce.Do(func() {
c.closeErr = err
c.c.Close()
close(c.closeNotify)
})
}
func (c *reusableConn) exchange(ctx context.Context, q *[]byte) (*[]byte, error) {
respChan := make(chan *[]byte, 1)
c.m.Lock()
if c.waitingResp != nil {
c.m.Unlock()
panic("bug: reusableConn: concurrent exchange calls")
}
c.waitingResp = respChan
c.m.Unlock()
waitRespTimeout := reuseConnQueryTimeout
if c.t.testWaitRespTimeout > 0 {
waitRespTimeout = c.t.testWaitRespTimeout
}
c.c.SetDeadline(time.Now().Add(waitRespTimeout))
_, err := c.c.Write(*q)
if err != nil {
c.closeWithErr(err)
return nil, err
}
select {
case resp := <-respChan:
return resp, nil
case <-c.closeNotify:
return nil, c.closeErr
case <-ctx.Done():
return nil, context.Cause(ctx)
}
}

View File

@ -0,0 +1,154 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"context"
"sync"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
func Test_ReuseConnTransport(t *testing.T) {
const idleTimeout = time.Second * 5
r := require.New(t)
po := ReuseConnOpts{
DialContext: func(ctx context.Context) (NetConn, error) {
return newDummyEchoNetConn(0, 0, 0), nil
},
IdleTimeout: idleTimeout,
}
rt := NewReuseConnTransport(po)
defer rt.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
q := new(dns.Msg)
q.SetQuestion("test.", dns.TypeA)
queryPayload, err := q.Pack()
r.NoError(err)
concurrentQueryNum := 10
for l := 0; l < 4; l++ {
wg := new(sync.WaitGroup)
for i := 0; i < concurrentQueryNum; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := rt.ExchangeContext(ctx, queryPayload)
if err != nil {
t.Error(err)
}
}()
}
wg.Wait()
if t.Failed() {
return
}
}
rt.m.Lock()
connNum := len(rt.conns)
idledConnNum := len(rt.idleConns)
rt.m.Unlock()
r.Equal(0, connNum-idledConnNum, "there should be no active conn")
r.Equal(concurrentQueryNum, connNum)
r.Equal(concurrentQueryNum, idledConnNum, "all conn should be in idle status")
}
func Test_ReuseConnTransport_Read_err_and_close(t *testing.T) {
const idleTimeout = time.Second * 5
r := require.New(t)
po := ReuseConnOpts{
DialContext: func(ctx context.Context) (NetConn, error) {
return newDummyEchoNetConn(1, 0, 0), nil // 100% read err
},
IdleTimeout: idleTimeout,
}
rt := NewReuseConnTransport(po)
defer rt.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
q := new(dns.Msg)
q.SetQuestion("test.", dns.TypeA)
queryPayload, err := q.Pack()
r.NoError(err)
wg := new(sync.WaitGroup)
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := rt.ExchangeContext(ctx, queryPayload)
r.Error(err)
}()
if t.Failed() {
return
}
}
wg.Wait()
rt.m.Lock()
connNum := len(rt.conns)
idledConnNum := len(rt.idleConns)
rt.m.Unlock()
r.Equal(0, connNum)
r.Equal(0, idledConnNum)
}
func Test_ReuseConnTransport_conn_lose_and_close(t *testing.T) {
r := require.New(t)
po := ReuseConnOpts{
DialContext: func(ctx context.Context) (NetConn, error) {
return newDummyEchoNetConn(0, time.Second, 0), nil // 100% read timeout
},
}
rt := NewReuseConnTransport(po)
defer rt.Close()
rt.testWaitRespTimeout = time.Millisecond * 1
q := new(dns.Msg)
q.SetQuestion("test.", dns.TypeA)
queryPayload, err := q.Pack()
r.NoError(err)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err = rt.ExchangeContext(ctx, queryPayload) // canceled ctx
r.Error(err)
time.Sleep(time.Millisecond * 100)
rt.m.Lock()
connNum := len(rt.conns)
idledConnNum := len(rt.idleConns)
rt.m.Unlock()
// connection should be closed and removed
r.Equal(0, connNum)
r.Equal(0, idledConnNum)
}

View File

@ -0,0 +1,74 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"context"
"errors"
"io"
"time"
)
var (
ErrClosedTransport = errors.New("transport has been closed")
ErrPayloadOverFlow = errors.New("payload is too large")
ErrNewConnCannotReserveQueryExchanger = errors.New("new connection failed to reserve query exchanger")
ErrLazyConnCannotReserveQueryExchanger = errors.New("lazy connection failed to reserve query exchanger")
)
const (
defaultIdleTimeout = time.Second * 10
defaultDialTimeout = time.Second * 5
// If a pipeline connection sent a query but did not see any reply (include replies that
// for other queries) from the server after waitingReplyTimeout. It assumes that
// something goes wrong with the connection or the server. The connection will be closed.
waitingReplyTimeout = time.Second * 10
defaultTdcMaxConcurrentQuery = 32
defaultMaxLazyConnQueue = 16
)
// One method MUST be called in ReservedExchanger.
type ReservedExchanger interface {
// ExchangeReserved sends q to the server and returns it's response.
// ExchangeReserved MUST not modify nor keep the q.
// q MUST be a valid dns message.
// resp (if no err) should be released by ReleaseResp().
ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error)
// WithdrawReserved aborts the query.
WithdrawReserved()
}
type DnsConn interface {
// ReserveNewQuery reserves a query. It MUST be fast and non-block. If DnsConn
// cannot serve more query due to its capacity, ReserveNewQuery returns nil.
// If DnsConn is closed and can no longer serve more query, returns closed = true.
ReserveNewQuery() (_ ReservedExchanger, closed bool)
io.Closer
}
type NetConn interface {
io.ReadWriteCloser
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}

View File

@ -0,0 +1,90 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"encoding/binary"
"io"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
"go.uber.org/zap"
"golang.org/x/exp/constraints"
)
const (
dnsHeaderLen = 12 // minimum dns msg size
)
func copyMsgWithLenHdr(m []byte) (*[]byte, error) {
l := len(m)
if l > dns.MaxMsgSize {
return nil, ErrPayloadOverFlow
}
bp := pool.GetBuf(l + 2)
binary.BigEndian.PutUint16(*bp, uint16(l))
copy((*bp)[2:], m)
return bp, nil
}
func copyMsg(m []byte) *[]byte {
bp := pool.GetBuf(len(m))
copy((*bp), m)
return bp
}
// readMsgUdp reads dns frame from r. r typically should be a udp connection.
// It uses a 4kb rx buffer and ignores any payload that is too small for a dns msg.
// If no error, the length of payload always >= 12 bytes.
func readMsgUdp(r io.Reader) (*[]byte, error) {
// TODO: Make this configurable?
// 4kb should be enough.
payload := pool.GetBuf(4095)
readAgain:
n, err := r.Read(*payload)
if err != nil {
pool.ReleaseBuf(payload)
return nil, err
}
if n < dnsHeaderLen {
goto readAgain
}
*payload = (*payload)[:n]
return payload, err
}
func setDefaultGZ[T constraints.Float | constraints.Integer](i *T, s, d T) {
if s > 0 {
*i = s
} else {
*i = d
}
}
var nopLogger = zap.NewNop()
func setNonNilLogger(i **zap.Logger, s *zap.Logger) {
if s != nil {
*i = s
} else {
*i = nopLogger
}
}

609
pkg/upstream/upstream.go Normal file
View File

@ -0,0 +1,609 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package upstream
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"strconv"
"strings"
"time"
"github.com/IrineSistiana/mosdns/v5/mlog"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/bootstrap"
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/doh"
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/transport"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"go.uber.org/zap"
"golang.org/x/net/http2"
"golang.org/x/net/proxy"
)
const (
tlsHandshakeTimeout = time.Second * 3
// Maximum number of concurrent queries in one pipeline connection.
// See RFC 7766 7. Response Reordering.
// TODO: Make this configurable?
pipelineConcurrentLimit = 64
)
// Upstream represents a DNS upstream.
type Upstream interface {
// ExchangeContext exchanges query message m to the upstream, and returns
// response. It MUST NOT keep or modify m.
// m MUST be a valid dns msg frame. It MUST be at least 12 bytes
// (contain a valid dns header).
ExchangeContext(ctx context.Context, m []byte) (*[]byte, error)
io.Closer
}
type Opt struct {
// DialAddr specifies the address the upstream will
// actually dial to in the network layer by overwriting
// the address inferred from upstream url.
// It won't affect high level layers. (e.g. SNI, HTTP HOST header won't be changed).
// Can be an IP or a domain. Port is optional.
// Tips: If the upstream url host is a domain, specific an IP address
// here can skip resolving ip of this domain.
DialAddr string
// Socks5 specifies the socks5 proxy server that the upstream
// will connect though.
// Not implemented for udp based protocols (aka. dns over udp, http3, quic).
Socks5 string
// SoMark sets the socket SO_MARK option in unix system.
SoMark int
// BindToDevice sets the socket SO_BINDTODEVICE option in unix system.
BindToDevice string
// IdleTimeout specifies the idle timeout for long-connections.
// Default: TCP, DoT: 10s , DoH, DoH3, Quic: 30s.
IdleTimeout time.Duration
// EnablePipeline enables query pipelining support as RFC 7766 6.2.1.1 suggested.
// Available for TCP, DoT upstream.
// Note: There is no fallback. Make sure the server supports it.
EnablePipeline bool
// EnableHTTP3 will use HTTP/3 protocol to connect a DoH upstream. (aka DoH3).
// Note: There is no fallback. Make sure the server supports it.
EnableHTTP3 bool
// Bootstrap specifies a plain dns server to solve the
// upstream server domain address.
// It must be an IP address. Port is optional.
Bootstrap string
// Bootstrap version. One of 0 (default equals 4), 4, 6.
// TODO: Support dual-stack.
BootstrapVer int
// TLSConfig specifies the tls.Config that the TLS client will use.
// Available for DoT, DoH, DoQ upstream.
TLSConfig *tls.Config
// Logger specifies the logger that the upstream will use.
Logger *zap.Logger
// EventObserver can observe connection events.
// Not implemented for quic based protocol (DoH3, DoQ).
EventObserver EventObserver
}
// NewUpstream creates a upstream.
// addr has the format of: [protocol://]host[:port][/path].
// Supported protocol: udp/tcp/tls/https/quic. Default protocol is udp.
//
// Helper protocol:
// - tcp+pipeline/tls+pipeline: Automatically set opt.EnablePipeline to true.
// - h3: Automatically set opt.EnableHTTP3 to true.
func NewUpstream(addr string, opt Opt) (_ Upstream, err error) {
if opt.Logger == nil {
opt.Logger = mlog.Nop()
}
if opt.EventObserver == nil {
opt.EventObserver = nopEO{}
}
// parse protocol and server addr
if !strings.Contains(addr, "://") {
addr = "udp://" + addr
}
addrURL, err := url.Parse(addr)
if err != nil {
return nil, fmt.Errorf("invalid server address, %w", err)
}
// Apply helper protocol
switch addrURL.Scheme {
case "tcp+pipeline", "tls+pipeline":
addrURL.Scheme = addrURL.Scheme[:3]
opt.EnablePipeline = true
case "h3":
addrURL.Scheme = "https"
opt.EnableHTTP3 = true
}
// If host is a ipv6 without port, it will be in []. This will cause err when
// split and join address and port. Try to remove brackets now.
addrUrlHost := tryTrimIpv6Brackets(addrURL.Host)
dialer := &net.Dialer{
Control: getSocketControlFunc(socketOpts{
so_mark: opt.SoMark,
bind_to_device: opt.BindToDevice,
}),
}
var bootstrapAp netip.AddrPort
if s := opt.Bootstrap; len(s) > 0 {
bootstrapAp, err = parseBootstrapAp(s)
if err != nil {
return nil, fmt.Errorf("invalid bootstrap, %w", err)
}
}
newUdpAddrResolveFunc := func(defaultPort uint16) (func(ctx context.Context) (*net.UDPAddr, error), error) {
host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
if err != nil {
return nil, err
}
if addr, err := netip.ParseAddr(host); err == nil { // host is an ip.
ua := net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
return func(ctx context.Context) (*net.UDPAddr, error) {
return ua, nil
}, nil
} else { // Not an ip, assuming it's a domain name.
if bootstrapAp.IsValid() {
// Bootstrap enabled.
bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger)
if err != nil {
return nil, err
}
return func(ctx context.Context) (*net.UDPAddr, error) {
s, err := bs.GetAddrPortStr(ctx)
if err != nil {
return nil, fmt.Errorf("bootstrap failed, %w", err)
}
return net.ResolveUDPAddr("udp", s)
}, nil
} else {
// Bootstrap disabled.
dialAddr := joinPort(host, port)
return func(ctx context.Context) (*net.UDPAddr, error) {
return net.ResolveUDPAddr("udp", dialAddr)
}, nil
}
}
}
newTcpDialer := func(dialAddrMustBeIp bool, defaultPort uint16) (func(ctx context.Context) (net.Conn, error), error) {
host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
if err != nil {
return nil, err
}
// Socks5 enabled.
if s5Addr := opt.Socks5; len(s5Addr) > 0 {
socks5Dialer, err := proxy.SOCKS5("tcp", s5Addr, nil, dialer)
if err != nil {
return nil, fmt.Errorf("failed to init socks5 dialer: %w", err)
}
contextDialer := socks5Dialer.(proxy.ContextDialer)
dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
return func(ctx context.Context) (net.Conn, error) {
return contextDialer.DialContext(ctx, "tcp", dialAddr)
}, nil
}
if _, err := netip.ParseAddr(host); err == nil {
// Host is an ip addr. No need to resolve it.
dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
return func(ctx context.Context) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", dialAddr)
}, nil
} else {
if dialAddrMustBeIp {
return nil, errors.New("addr must be an ip address")
}
// Host is not an ip addr, assuming it is a domain.
if bootstrapAp.IsValid() {
// Bootstrap enabled.
bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger)
if err != nil {
return nil, err
}
return func(ctx context.Context) (net.Conn, error) {
dialAddr, err := bs.GetAddrPortStr(ctx)
if err != nil {
return nil, fmt.Errorf("bootstrap failed, %w", err)
}
return dialer.DialContext(ctx, "tcp", dialAddr)
}, nil
} else {
// Bootstrap disabled.
dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
return func(ctx context.Context) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", dialAddr)
}, nil
}
}
}
closeIfFuncErr := func(c io.Closer) {
if err != nil {
c.Close()
}
}
switch addrURL.Scheme {
case "", "udp":
const defaultPort = 53
const maxConcurrentQueryPreConn = 4096 // Protocol limit is 65535.
host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
if err != nil {
return nil, err
}
if _, err := netip.ParseAddr(host); err != nil {
return nil, fmt.Errorf("addr must be an ip address, %w", err)
}
dialAddr := joinPort(host, port)
dialUdpPipeline := func(ctx context.Context) (transport.DnsConn, error) {
c, err := dialer.DialContext(ctx, "udp", dialAddr)
if err != nil {
return nil, err
}
to := transport.TraditionalDnsConnOpts{
WithLengthHeader: false,
IdleTimeout: time.Minute * 5,
MaxConcurrentQuery: maxConcurrentQueryPreConn,
}
return transport.NewDnsConn(to, wrapConn(c, opt.EventObserver)), nil
}
dialTcpNetConn := func(ctx context.Context) (transport.NetConn, error) {
c, err := dialer.DialContext(ctx, "tcp", dialAddr)
if err != nil {
return nil, err
}
return wrapConn(c, opt.EventObserver), nil
}
return &udpWithFallback{
u: transport.NewPipelineTransport(transport.PipelineOpts{
DialContext: dialUdpPipeline,
MaxConcurrentQueryWhileDialing: maxConcurrentQueryPreConn,
Logger: opt.Logger,
}),
t: transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialTcpNetConn}),
}, nil
case "tcp":
const defaultPort = 53
tcpDialer, err := newTcpDialer(true, defaultPort)
if err != nil {
return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
}
idleTimeout := opt.IdleTimeout
if idleTimeout <= 0 {
idleTimeout = time.Second * 10
}
dialNetConn := func(ctx context.Context) (transport.NetConn, error) {
c, err := tcpDialer(ctx)
if err != nil {
return nil, err
}
return wrapConn(c, opt.EventObserver), nil
}
if opt.EnablePipeline {
to := transport.TraditionalDnsConnOpts{
WithLengthHeader: true,
IdleTimeout: idleTimeout,
MaxConcurrentQuery: pipelineConcurrentLimit,
}
dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) {
c, err := dialNetConn(ctx)
if err != nil {
return nil, err
}
return transport.NewDnsConn(to, c), nil
}
return transport.NewPipelineTransport(transport.PipelineOpts{
DialContext: dialDnsConn,
MaxConcurrentQueryWhileDialing: pipelineConcurrentLimit,
Logger: opt.Logger,
}), nil
}
return transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialNetConn, IdleTimeout: idleTimeout}), nil
case "tls":
const defaultPort = 853
tlsConfig := opt.TLSConfig.Clone()
if tlsConfig == nil {
tlsConfig = new(tls.Config)
}
if len(tlsConfig.ServerName) == 0 {
tlsConfig.ServerName = tryRemovePort(addrUrlHost)
}
tcpDialer, err := newTcpDialer(false, defaultPort)
if err != nil {
return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
}
dialNetConn := func(ctx context.Context) (transport.NetConn, error) {
conn, err := tcpDialer(ctx)
if err != nil {
return nil, err
}
conn = wrapConn(conn, opt.EventObserver)
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
tlsConn.Close()
return nil, err
}
return wrapConn(tlsConn, opt.EventObserver), nil
}
if opt.EnablePipeline {
to := transport.TraditionalDnsConnOpts{
WithLengthHeader: true,
IdleTimeout: opt.IdleTimeout,
MaxConcurrentQuery: pipelineConcurrentLimit,
}
dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) {
c, err := dialNetConn(ctx)
if err != nil {
return nil, err
}
return transport.NewDnsConn(to, c), nil
}
return transport.NewPipelineTransport(transport.PipelineOpts{
DialContext: dialDnsConn,
MaxConcurrentQueryWhileDialing: pipelineConcurrentLimit,
Logger: opt.Logger,
}), nil
}
return transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialNetConn}), nil
case "https":
const defaultPort = 443
idleConnTimeout := time.Second * 30
if opt.IdleTimeout > 0 {
idleConnTimeout = opt.IdleTimeout
}
var t http.RoundTripper
var addonCloser io.Closer
if opt.EnableHTTP3 {
udpBootstrap, err := newUdpAddrResolveFunc(defaultPort)
if err != nil {
return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err)
}
lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})}
conn, err := lc.ListenPacket(context.Background(), "udp", "")
if err != nil {
return nil, fmt.Errorf("failed to init udp socket for quic, %w", err)
}
quicTransport := &quic.Transport{
Conn: conn,
}
quicConfig := newDefaultClientQuicConfig()
quicConfig.MaxIdleTimeout = idleConnTimeout
defer closeIfFuncErr(quicTransport)
addonCloser = quicTransport
t = &http3.RoundTripper{
TLSClientConfig: opt.TLSConfig,
QUICConfig: quicConfig,
Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
ua, err := udpBootstrap(ctx)
if err != nil {
return nil, err
}
return quicTransport.DialEarly(ctx, ua, tlsCfg, cfg)
},
MaxResponseHeaderBytes: 4 * 1024,
}
} else {
tcpDialer, err := newTcpDialer(false, defaultPort)
if err != nil {
return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
}
t1 := &http.Transport{
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { // overwrite server addr
c, err := tcpDialer(ctx)
c = wrapConn(c, opt.EventObserver)
return c, err
},
TLSClientConfig: opt.TLSConfig,
TLSHandshakeTimeout: tlsHandshakeTimeout,
IdleConnTimeout: idleConnTimeout,
// Following opts are for http/1 only.
// MaxConnsPerHost: 2,
// MaxIdleConnsPerHost: 2,
}
t2, err := http2.ConfigureTransports(t1)
if err != nil {
return nil, fmt.Errorf("failed to upgrade http2 support, %w", err)
}
t2.MaxHeaderListSize = 4 * 1024
t2.MaxReadFrameSize = 16 * 1024
t2.ReadIdleTimeout = time.Second * 30
t2.PingTimeout = time.Second * 5
t = t1
}
u, err := doh.NewUpstream(addrURL.String(), t, opt.Logger)
if err != nil {
return nil, fmt.Errorf("failed to create doh upstream, %w", err)
}
return &dohWithClose{
u: u,
closer: addonCloser,
}, nil
case "quic", "doq":
const defaultPort = 853
tlsConfig := opt.TLSConfig.Clone()
if tlsConfig == nil {
tlsConfig = new(tls.Config)
}
if len(tlsConfig.ServerName) == 0 {
tlsConfig.ServerName = tryRemovePort(addrUrlHost)
}
tlsConfig.NextProtos = []string{"doq"}
quicConfig := newDefaultClientQuicConfig()
if opt.IdleTimeout > 0 {
quicConfig.MaxIdleTimeout = opt.IdleTimeout
}
// Don't accept stream.
quicConfig.MaxIncomingStreams = -1
quicConfig.MaxIncomingUniStreams = -1
udpBootstrap, err := newUdpAddrResolveFunc(defaultPort)
if err != nil {
return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err)
}
srk, _, err := utils.InitQUICSrkFromIfaceMac()
if err != nil {
opt.Logger.Warn("failed to init quic stateless reset key, it will be disabled", zap.Error(err))
}
lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})}
uc, err := lc.ListenPacket(context.Background(), "udp", "")
if err != nil {
return nil, fmt.Errorf("failed to init udp socket for quic, %w", err)
}
t := &quic.Transport{
Conn: uc,
StatelessResetKey: (*quic.StatelessResetKey)(srk),
}
dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) {
ua, err := udpBootstrap(ctx)
if err != nil {
return nil, fmt.Errorf("bootstrap failed, %w", err)
}
// This is a workaround to
// 1. recover from strange 0rtt rejected err.
// 2. avoid NextConnection might block forever.
// TODO: Remove this workaround.
var c quic.Connection
ec, err := t.DialEarly(ctx, ua, tlsConfig, quicConfig)
if err != nil {
return nil, err
}
c, err = ec.NextConnection(ctx)
if err != nil {
return nil, err
}
return transport.NewQuicDnsConn(c), nil
}
return transport.NewPipelineTransport(transport.PipelineOpts{
DialContext: dialDnsConn,
// Quic rfc recommendation is 100. Some implications use 65535.
MaxConcurrentQueryWhileDialing: 90,
Logger: opt.Logger,
}), nil
default:
return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme)
}
}
type udpWithFallback struct {
u *transport.PipelineTransport
t *transport.ReuseConnTransport
}
func (u *udpWithFallback) ExchangeContext(ctx context.Context, q []byte) (*[]byte, error) {
r, err := u.u.ExchangeContext(ctx, q)
if err != nil {
return nil, err
}
if msgTruncated(*r) {
pool.ReleaseBuf(r)
return u.t.ExchangeContext(ctx, q)
}
return r, nil
}
func (u *udpWithFallback) Close() error {
u.u.Close()
u.t.Close()
return nil
}
type dohWithClose struct {
u *doh.Upstream
closer io.Closer // maybe nil
}
func (u *dohWithClose) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) {
return u.u.ExchangeContext(ctx, m)
}
func (u *dohWithClose) Close() error {
if u.closer != nil {
return u.closer.Close()
}
return nil
}
func newDefaultClientQuicConfig() *quic.Config {
return &quic.Config{
TokenStore: quic.NewLRUTokenStore(4, 8),
// Dns does not need large amount of io, so the rx/tx windows are small.
InitialStreamReceiveWindow: 4 * 1024,
MaxStreamReceiveWindow: 4 * 1024,
InitialConnectionReceiveWindow: 8 * 1024,
MaxConnectionReceiveWindow: 64 * 1024,
MaxIdleTimeout: time.Second * 30,
KeepAlivePeriod: time.Second * 25,
HandshakeIdleTimeout: tlsHandshakeTimeout,
}
}

View File

@ -0,0 +1,233 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package upstream
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/miekg/dns"
)
func newUDPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
udpConn, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
udpAddr := udpConn.LocalAddr().String()
udpServer := dns.Server{
PacketConn: udpConn,
Handler: handler,
}
go udpServer.ActivateAndServe()
return udpAddr, func() {
udpServer.Shutdown()
}
}
func newTCPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
tcpAddr := l.Addr().String()
tcpServer := dns.Server{
Listener: l,
Handler: handler,
MaxTCPQueries: -1,
}
go tcpServer.ActivateAndServe()
return tcpAddr, func() {
tcpServer.Shutdown()
}
}
func newDoTTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
serverName := "test"
cert, err := utils.GenerateCertificate(serverName)
if err != nil {
t.Fatal(err)
}
tlsConfig := new(tls.Config)
tlsConfig.Certificates = []tls.Certificate{cert}
tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
if err != nil {
t.Fatal(err)
}
doTAddr := tlsListener.Addr().String()
doTServer := dns.Server{
Net: "tcp-tls",
Listener: tlsListener,
TLSConfig: tlsConfig,
Handler: handler,
MaxTCPQueries: -1,
}
go doTServer.ActivateAndServe()
return doTAddr, func() {
doTServer.Shutdown()
}
}
type newTestServerFunc func(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func())
var m = map[string]newTestServerFunc{
"udp": newUDPTestServer,
"tcp": newTCPTestServer,
"tls": newDoTTestServer,
}
func Test_fastUpstream(t *testing.T) {
// TODO: add test for doh
// TODO: add test for socks5
// server config
for scheme, f := range m {
for _, bigMsg := range [...]bool{true, false} {
for _, latency := range [...]time.Duration{0, time.Millisecond * 10} {
// client specific
for _, idleTimeout := range [...]time.Duration{0, time.Second} {
testName := fmt.Sprintf(
"test: protocol: %s, bigMsg: %v, latency: %s, getIdleTimeout: %s",
scheme,
bigMsg,
latency,
idleTimeout,
)
t.Run(testName, func(t *testing.T) {
addr, shutdownServer := f(t, &vServer{
latency: latency,
bigMsg: bigMsg,
})
defer shutdownServer()
u, err := NewUpstream(
scheme+"://"+addr,
Opt{
IdleTimeout: time.Second,
TLSConfig: &tls.Config{InsecureSkipVerify: true},
},
)
if err != nil {
t.Fatal(err)
}
if err := testUpstream(u); err != nil {
t.Fatal(err)
}
})
}
}
}
}
}
func testUpstream(u Upstream) error {
wg := sync.WaitGroup{}
errs := make([]error, 0)
errsLock := sync.Mutex{}
logErr := func(err error) {
errsLock.Lock()
errs = append(errs, err)
errsLock.Unlock()
}
errsToString := func() string {
s := fmt.Sprintf("%d err(s) occured during the test: ", len(errs))
for i := range errs {
s = s + errs[i].Error() + "|"
}
return s
}
for i := uint16(0); i < 10; i++ {
wg.Add(1)
i := i
go func() {
defer wg.Done()
q := new(dns.Msg)
q.SetQuestion("example.com.", dns.TypeA)
q.Id = i
queryPayload, err := q.Pack()
if err != nil {
logErr(err)
return
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := u.ExchangeContext(ctx, queryPayload)
if err != nil {
logErr(err)
return
}
resp := new(dns.Msg)
err = resp.Unpack(*r)
if err != nil {
logErr(err)
return
}
if q.Id != resp.Id {
logErr(dns.ErrId)
return
}
if !resp.Response {
logErr(fmt.Errorf("resp is not a resp bit"))
return
}
}()
}
wg.Wait()
if len(errs) != 0 {
return errors.New(errsToString())
}
return nil
}
type vServer struct {
latency time.Duration
bigMsg bool // with 1kb padding
}
var padding = make([]byte, 1024)
func (s *vServer) ServeDNS(w dns.ResponseWriter, q *dns.Msg) {
r := new(dns.Msg)
r.SetReply(q)
if s.bigMsg {
r.SetEdns0(dns.MaxMsgSize, false)
opt := r.IsEdns0()
opt.Option = append(opt.Option, &dns.EDNS0_PADDING{Padding: padding})
}
time.Sleep(s.latency)
w.WriteMsg(r)
}

104
pkg/upstream/utils.go Normal file
View File

@ -0,0 +1,104 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package upstream
import (
"fmt"
"net"
"net/netip"
"strconv"
)
type socketOpts struct {
so_mark int
bind_to_device string
}
func parseDialAddr(urlHost, dialAddr string, defaultPort uint16) (string, uint16, error) {
addr := urlHost
if len(dialAddr) > 0 {
addr = dialAddr
}
host, port, err := trySplitHostPort(addr)
if err != nil {
return "", 0, err
}
if port == 0 {
port = defaultPort
}
return host, port, nil
}
func joinPort(host string, port uint16) string {
return net.JoinHostPort(host, strconv.Itoa(int(port)))
}
func tryRemovePort(s string) string {
host, _, err := net.SplitHostPort(s)
if err != nil {
return s
}
return host
}
// trySplitHostPort splits host and port.
// If s has no port, it returns s,0,nil
func trySplitHostPort(s string) (string, uint16, error) {
var port uint16
host, portS, err := net.SplitHostPort(s)
if err == nil {
n, err := strconv.ParseUint(portS, 10, 16)
if err != nil {
return "", 0, fmt.Errorf("invalid port, %w", err)
}
port = uint16(n)
return host, port, nil
}
return s, 0, nil
}
func parseBootstrapAp(s string) (netip.AddrPort, error) {
host, port, err := trySplitHostPort(s)
if err != nil {
return netip.AddrPort{}, err
}
if port == 0 {
port = 53
}
addr, err := netip.ParseAddr(host)
if err != nil {
return netip.AddrPort{}, err
}
return netip.AddrPortFrom(addr, port), nil
}
func tryTrimIpv6Brackets(s string) string {
if len(s) < 2 {
return s
}
if s[0] == '[' && s[len(s)-1] == ']' {
return s[1 : len(s)-2]
}
return s
}
func msgTruncated(b []byte) bool {
return b[2]&(1<<1) != 0
}

View File

@ -0,0 +1,28 @@
//go:build !linux
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package upstream
import "syscall"
func getSocketControlFunc(_ socketOpts) func(string, string, syscall.RawConn) error {
return nil
}

View File

@ -0,0 +1,58 @@
//go:build linux
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package upstream
import (
"os"
"syscall"
"golang.org/x/sys/unix"
)
func getSocketControlFunc(opts socketOpts) func(string, string, syscall.RawConn) error {
return func(_, _ string, c syscall.RawConn) error {
var sysCallErr error
if err := c.Control(func(fd uintptr) {
// SO_MARK
if opts.so_mark > 0 {
sysCallErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.so_mark)
if sysCallErr != nil {
sysCallErr = os.NewSyscallError("failed to set SO_MARK", sysCallErr)
return
}
}
// SO_BINDTODEVICE
if len(opts.bind_to_device) > 0 {
sysCallErr = unix.SetsockoptString(int(fd), unix.SOL_SOCKET, unix.SO_BINDTODEVICE, opts.bind_to_device)
if sysCallErr != nil {
sysCallErr = os.NewSyscallError("failed to set SO_BINDTODEVICE", sysCallErr)
return
}
}
}); err != nil {
return err
}
return sysCallErr
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package utils
import (
"github.com/mitchellh/mapstructure"
"golang.org/x/exp/constraints"
"strconv"
)
func SetDefaultNum[K constraints.Integer | constraints.Float](p *K, d K) {
if *p == 0 {
*p = d
}
}
func SetDefaultUnsignNum[K constraints.Integer | constraints.Float](p *K, d K) {
if *p <= 0 {
*p = d
}
}
func SetDefaultString(p *string, d string) {
if len(*p) == 0 {
*p = d
}
}
func CheckNumRange[K constraints.Integer | constraints.Float](v, min, max K) bool {
if v < min || v > max {
return false
}
return true
}
// WeakDecode decodes args from config to output.
func WeakDecode(in any, output any) error {
config := &mapstructure.DecoderConfig{
ErrorUnused: true,
Result: output,
WeaklyTypedInput: true,
TagName: "yaml",
}
decoder, err := mapstructure.NewDecoder(config)
if err != nil {
return err
}
return decoder.Decode(in)
}
func ParseNameOrNum[T constraints.Integer](s string, m map[string]T) (T, bool) {
i, err := strconv.Atoi(s)
if err != nil {
v, ok := m[s]
return v, ok
}
return T(i), true
}

58
pkg/utils/net.go Normal file
View File

@ -0,0 +1,58 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package utils
import (
"net"
"net/netip"
)
// GetIPFromAddr returns a net.IP from the given net.Addr.
// addr can be *net.TCPAddr, *net.UDPAddr, *net.IPNet, *net.IPAddr
// Will return nil otherwise.
func GetIPFromAddr(addr net.Addr) (ip net.IP) {
switch v := addr.(type) {
case *net.TCPAddr:
return v.IP
case *net.UDPAddr:
return v.IP
case *net.IPNet:
return v.IP
case *net.IPAddr:
return v.IP
}
return nil
}
// GetAddrFromAddr returns netip.Addr from net.Addr.
// See also: GetIPFromAddr.
func GetAddrFromAddr(addr net.Addr) netip.Addr {
a, _ := netip.AddrFromSlice(GetIPFromAddr(addr))
return a
}
// SplitSchemeAndHost splits addr to protocol and host.
func SplitSchemeAndHost(addr string) (protocol, host string) {
if protocol, host, ok := SplitString2(addr, "://"); ok {
return protocol, host
} else {
return "", addr
}
}

52
pkg/utils/quic.go Normal file
View File

@ -0,0 +1,52 @@
package utils
import (
"crypto/sha256"
"errors"
"net"
"sync"
)
var quicSrkSalt = []byte{115, 189, 156, 229, 145, 216, 251, 127, 220, 89,
243, 234, 211, 79, 190, 166, 135, 253, 183, 36, 245, 174, 78, 200, 54, 213,
85, 255, 104, 240, 103, 27}
var (
quicSrkInitOnce sync.Once
quicSrk *[32]byte
quicSrkFromIface net.Interface
quicSrkInitErr error
)
func initQUICSrkFromIfaceMac() {
nonZero := func(b []byte) bool {
for _, i := range b {
if i != 0 {
return true
}
}
return false
}
ifaces, err := net.Interfaces()
if err != nil {
quicSrkInitErr = err
return
}
for _, iface := range ifaces {
if nonZero(iface.HardwareAddr) {
k := sha256.Sum256(append(iface.HardwareAddr, quicSrkSalt...))
quicSrk = &k
quicSrkFromIface = iface
return
}
}
quicSrkInitErr = errors.New("cannot find non-zero mac interface")
}
// A helper func to init quic stateless reset key.
// It use the first non-zero interface mac + sha256 hash.
func InitQUICSrkFromIfaceMac() (*[32]byte, net.Interface, error) {
quicSrkInitOnce.Do(initQUICSrkFromIfaceMac)
return quicSrk, quicSrkFromIface, quicSrkInitErr
}

20
pkg/utils/server.go Normal file
View File

@ -0,0 +1,20 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package utils

49
pkg/utils/strings.go Normal file
View File

@ -0,0 +1,49 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package utils
import (
"strings"
"unsafe"
)
// BytesToStringUnsafe converts bytes to string.
func BytesToStringUnsafe(b []byte) string {
return unsafe.String(unsafe.SliceData(b), len(b))
}
// RemoveComment removes comment after "symbol".
func RemoveComment(s, symbol string) string {
if i := strings.Index(s, symbol); i >= 0 {
return s[:i]
}
return s
}
// SplitString2 split s to two parts by given symbol
func SplitString2(s, symbol string) (s1 string, s2 string, ok bool) {
if len(symbol) == 0 {
return "", s, true
}
if i := strings.Index(s, symbol); i >= 0 {
return s[:i], s[i+len(symbol):], true
}
return "", "", false
}

107
pkg/utils/utils.go Normal file
View File

@ -0,0 +1,107 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package utils
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"os"
"time"
)
// LoadCertPool reads and loads certificates in certs.
func LoadCertPool(certs []string) (*x509.CertPool, error) {
rootCAs := x509.NewCertPool()
for _, cert := range certs {
b, err := os.ReadFile(cert)
if err != nil {
return nil, err
}
if ok := rootCAs.AppendCertsFromPEM(b); !ok {
return nil, fmt.Errorf("no certificate was successfully parsed in %s", cert)
}
}
return rootCAs, nil
}
// GenerateCertificate generates an ecdsa certificate with given dnsName.
// This should only use in test.
func GenerateCertificate(dnsName string) (cert tls.Certificate, err error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return
}
//serial number
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
err = fmt.Errorf("generate serial number: %w", err)
return
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{CommonName: dnsName},
DNSNames: []string{dnsName},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
return
}
b, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
return tls.X509KeyPair(certPEM, keyPEM)
}
// ClosedChan returns true if c is closed.
// c must not use for sending data and must be used in close() only.
// If ClosedChan receives something from c, it panics.
func ClosedChan(c chan struct{}) bool {
select {
case _, ok := <-c:
if !ok {
return true
}
panic("received from the chan")
default:
return false
}
}

123
pkg/utils/utils_test.go Normal file
View File

@ -0,0 +1,123 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package utils
import (
"reflect"
"testing"
)
func TestSplitString2(t *testing.T) {
type args struct {
s string
symbol string
}
tests := []struct {
name string
args args
wantS1 string
wantS2 string
wantOk bool
}{
{"blank", args{"", ""}, "", "", true},
{"blank", args{"///", ""}, "", "///", true},
{"split", args{"///", "/"}, "", "//", true},
{"split", args{"--/", "/"}, "--", "", true},
{"split", args{"https://***.***.***", "://"}, "https", "***.***.***", true},
{"split", args{"://***.***.***", "://"}, "", "***.***.***", true},
{"split", args{"https://", "://"}, "https", "", true},
{"split", args{"--/", "*"}, "", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotS1, gotS2, gotOk := SplitString2(tt.args.s, tt.args.symbol)
if gotS1 != tt.wantS1 {
t.Errorf("SplitString2() gotS1 = %v, want %v", gotS1, tt.wantS1)
}
if gotS2 != tt.wantS2 {
t.Errorf("SplitString2() gotS2 = %v, want %v", gotS2, tt.wantS2)
}
if gotOk != tt.wantOk {
t.Errorf("SplitString2() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func TestRemoveComment(t *testing.T) {
type args struct {
s string
symbol string
}
tests := []struct {
name string
args args
want string
}{
{name: "empty", args: args{s: "", symbol: ""}, want: ""},
{name: "empty symbol", args: args{s: "12345", symbol: ""}, want: ""},
{name: "empty string", args: args{s: "", symbol: "#"}, want: ""},
{name: "remove 1", args: args{s: "123/456", symbol: "/"}, want: "123"},
{name: "remove 2", args: args{s: "123//456", symbol: "//"}, want: "123"},
{name: "remove 3", args: args{s: "123/*/456", symbol: "//"}, want: "123/*/456"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := RemoveComment(tt.args.s, tt.args.symbol); got != tt.want {
t.Errorf("RemoveComment() = %v, want %v", got, tt.want)
}
})
}
}
type TestArgsStruct struct {
A string `yaml:"1"`
B []int `yaml:"2"`
}
func Test_WeakDecode(t *testing.T) {
testObj := new(TestArgsStruct)
testArgs := map[string]any{
"1": "test",
"2": []int{1, 2, 3},
}
wantObj := &TestArgsStruct{
A: "test",
B: []int{1, 2, 3},
}
err := WeakDecode(testArgs, testObj)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(testObj, wantObj) {
t.Fatalf("args decode failed, want %v, got %v", wantObj, testObj)
}
}
func Test_WeakDecode2(t *testing.T) {
testObj := new([]byte)
args := []any{"1", 2, 3}
err := WeakDecode(args, testObj)
if err != nil {
t.Fatal(err)
}
}

View File

@ -0,0 +1,85 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package zone_file
import (
"io"
"os"
"strings"
"github.com/miekg/dns"
)
type Matcher struct {
m map[dns.Question][]dns.RR
}
func (m *Matcher) LoadFile(s string) error {
f, err := os.Open(s)
if err != nil {
return err
}
defer f.Close()
return m.Load(f)
}
func (m *Matcher) Load(r io.Reader) error {
if m.m == nil {
m.m = make(map[dns.Question][]dns.RR)
}
parser := dns.NewZoneParser(r, "", "")
parser.SetDefaultTTL(3600)
for {
rr, ok := parser.Next()
if !ok {
break
}
h := rr.Header()
q := dns.Question{
Name: strings.ToLower(h.Name),
Qtype: h.Rrtype,
Qclass: h.Class,
}
m.m[q] = append(m.m[q], rr)
}
return parser.Err()
}
func (m *Matcher) Search(q dns.Question) []dns.RR {
q.Name = strings.ToLower(q.Name)
return m.m[q]
}
func (m *Matcher) Reply(q *dns.Msg) *dns.Msg {
var r *dns.Msg
for _, question := range q.Question {
rr := m.Search(question)
if rr != nil {
if r == nil {
r = new(dns.Msg)
r.SetReply(q)
}
r.Answer = append(r.Answer, rr...)
}
}
return r
}

Some files were not shown because too many files have changed in this diff Show More