commit cd761e81458eecc3b6e7bfd786887c5cf802bc66 Author: dengxiongjian Date: Thu Jul 31 11:28:55 2025 +0800 新增Mikrotik API 插入解析ip diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000..87a9fcc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,63 @@ +name: "Bug report" +description: "[反馈 Bug 必须用此模板] 程序不能运行?或者没有按照期望的那样工作?" +title: "[Bug] " +body: + - type: markdown + attributes: + value: "感谢热心的反馈 bug。描述越详细越有助于定位和解决 Bug。不提供有效信息的反馈可能会被直接关闭。" + + - type: checkboxes + id: pre-check + attributes: + label: "在提交之前,请确认" + options: + - label: "我已经尝试搜索过 Issue ,但没有找到相关问题。" + required: true + - label: "我正在使用最新的 mosdns 版本(或者最新的 commit),问题依旧存在。" + required: true + - label: "我仔细看过 wiki 后仍然无法自行解决该问题。" + required: true + - label: "我非常确定这是 mosdns 核心的问题。(如果是通过第三方衍生软件使用 mosdns + 核心,不确定问题源头时,请先向衍生软件开发者提交问题。)" + required: true + + - type: input + id: version + attributes: + label: mosdns 版本 + description: "不清楚可用 `mosdns version` 查看。" + placeholder: v9.9.9 + validations: + required: true + + - type: input + id: system + attributes: + label: 操作系统 + placeholder: ubuntu + validations: + required: true + + - type: textarea + id: what-happened + attributes: + label: Bug 描述和复现步骤 + description: "描述越详细越有助于定位和解决 Bug。" + placeholder: "示例: Bug: mosdns 的 qname 匹配器无法匹配 xxxx 域名。复现方式: 使用如下配置,请求 xxxx.xxxx ,观察 log 发现匹配器没有匹配到。" + validations: + required: true + + - type: textarea + id: config + attributes: + label: 使用的配置文件 + render: yaml + description: "必须是完整的配置文件。不要只复制某个插件的配置片段。请尽可能的提供最小配置文件(能复现 bug,但没有其他功能)方便开发者定位问题。" + validations: + required: true + + - type: textarea + id: log + attributes: + label: mosdns 的 log 记录 + render: txt diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..fc3c1aa --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,16 @@ +name: "Feature request" +description: "希望添加新功能" +title: "[Feature request] " +body: + - type: markdown + attributes: + value: "感谢提出建议。提交的 issue 不一定会马上得到开发者的回复。开发者会根据社区的反应以及功能的重要性选择是否实现这个新功能。" + + - type: textarea + id: new-feature + attributes: + label: 希望添加的功能 + description: "请详细描述一下该功能作用,使用场景,是否有类似实现,文档等。" + placeholder: "希望支持 ... 可以更快的 ..." + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/other-questions.md b/.github/ISSUE_TEMPLATE/other-questions.md new file mode 100644 index 0000000..0cfd9f6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/other-questions.md @@ -0,0 +1,10 @@ +--- +name: Other questions +about: 不要在 Issue 里提问。有问题请进入 Discussions 讨论。 +title: '' +labels: '' +assignees: '' + +--- + +不要在 issue 里提问。有问题请进入 Discussions 讨论。 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..8bb4e25 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,38 @@ +name: Release mosdns + +on: + push: + tags: + - 'v*' + +jobs: + + build-release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: '1.23' + check-latest: true + cache: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Build + run: python ./release.py + env: + CGO_ENABLED: '0' + + - name: Publish + uses: softprops/action-gh-release@v2 + with: + files: './release/mosdns*.zip' + prerelease: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..5480900 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,24 @@ +name: Test mosdns + +on: + push: + pull_request: + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.23' + check-latest: true + cache: true + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -race -v ./... diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a73e9d6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +vendor/ + +# release dir +release/ + +# ide +.vscode/ +.idea/ + +# test utils +testutils/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..a9da9df --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +FROM golang:latest as builder +ARG CGO_ENABLED=0 + +COPY ./ /root/src/ +WORKDIR /root/src/ +RUN go build -ldflags "-s -w -X main.version=$(git describe --tags --long --always)" -trimpath -o mosdns + +FROM alpine:latest + +COPY --from=builder /root/src/mosdns /usr/bin/ + +RUN apk add --no-cache ca-certificates \ No newline at end of file diff --git a/Dockerfile_buildx b/Dockerfile_buildx new file mode 100644 index 0000000..2fe4a7c --- /dev/null +++ b/Dockerfile_buildx @@ -0,0 +1,14 @@ +FROM --platform=${TARGETPLATFORM} golang:latest as builder + +ARG CGO_ENABLED=0 + +COPY ./ /root/src/ +WORKDIR /root/src/ +RUN go build -ldflags "-s -w -X main.version=$(git describe --tags --long --always)" -trimpath -o mosdns + + +FROM --platform=${TARGETPLATFORM} alpine:latest + +COPY --from=builder /root/src/mosdns /usr/bin/ + +RUN apk add --no-cache ca-certificates \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e62ec04 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ +GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/README.md b/README.md new file mode 100644 index 0000000..8186fc1 --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +# mosdns + +功能概述、配置方式、教程等,详见: [wiki](https://irine-sistiana.gitbook.io/mosdns-wiki/) + +下载预编译文件、更新日志,详见: [release](https://github.com/IrineSistiana/mosdns/releases) + +docker 镜像: [docker hub](https://hub.docker.com/r/irinesistiana/mosdns) diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..5bd2db6 --- /dev/null +++ b/config.yaml @@ -0,0 +1,192 @@ +# ============================================ +# MosDNS v5 完整配置:中英文注释版 +# ============================================ + +log: + level: info # 可选: debug/info/warn/error + +# 管理 API(可用于调试、监控) +api: + http: "0.0.0.0:5535" + +# 引入上游 DNS 配置文件(在 dns.yaml 中) +include: ['/opt/mosdns/dns.yaml'] + +plugins: + # ====================== + # 域名/IP 匹配规则 + # ====================== + + # GFW 域名列表(如 google.com) + - tag: GFW_domains + type: domain_set + args: + files: + - "/opt/mosdns/config/geosite_tiktok.txt" + - "/opt/mosdns/config/gfwlist.out.txt" + + # Amazon 域名列表(如 amazon.com) + - tag: amazon_domains + type: domain_set + args: + files: + - "/opt/mosdns/config/geosite_amazon.txt" + - "/opt/mosdns/config/geosite_amazon-ads.txt" + - "/opt/mosdns/config/geosite_amazontrust.txt" + - "/opt/mosdns/config/amazon.txt" + + # 中国大陆常用域名(如 .cn / baidu.com) + - tag: CN_domains + type: domain_set + args: + files: + - "/opt/mosdns/config/domains.txt" + + # 中国大陆 IP 列表 + - tag: geoip_cn + type: ip_set + args: + files: + - "/opt/mosdns/config/cn.txt" + + # ====================== + # 缓存模块 + # ====================== + - tag: cache + type: cache + args: + size: 32768 # 最大缓存条目数 + lazy_cache_ttl: 43200 # 默认缓存 TTL(秒) + + # ====================== + # 上游 DNS 定义 + # ====================== + + # 国内 DNS fallback 模式 + - tag: forward_local + type: fallback + args: + primary: cn-dns + secondary: cn-dns + threshold: 500 + always_standby: true + + # 国外 DNS fallback 模式 + - tag: forward_remote + type: fallback + args: + primary: jp-dns + secondary: jp-dns + threshold: 500 + always_standby: true + + # 封装调用国内 DNS + - tag: forward_local_upstream + type: sequence + args: + - exec: prefer_ipv4 + - exec: query_summary forward_local + - exec: $forward_local + + # 封装调用国外 DNS + - tag: forward_remote_upstream + type: sequence + args: + - exec: prefer_ipv4 + - exec: query_summary forward_remote + - exec: $forward_remote + + # 如果已有响应,直接返回 + - tag: has_resp_sequence + type: sequence + args: + - matches: has_resp + exec: accept + + # ====================== + # 查询逻辑 + # ====================== + + # 拒绝无效查询(如 HTTPS 记录) + - tag: query_is_reject_domain + type: sequence + args: + - matches: qtype 65 + exec: reject 3 + + # GFW 域名:强制走国外 DNS + - tag: query_is_foreign_domain + type: sequence + args: + - matches: qname $GFW_domains + exec: $forward_remote_upstream + - exec: query_summary gfw_domain + + # 国内域名:强制走国内 DNS + - tag: query_is_cn_domain + type: sequence + args: + - matches: qname $CN_domains + exec: $forward_local_upstream + - exec: query_summary cn_domain + + # Amazon 域名:走国外 DNS 并添加到 MikroTik + - tag: query_is_amazon_domain + type: sequence + args: + - matches: qname $amazon_domains + exec: $forward_remote_upstream + - exec: $mikrotik_amazon + - exec: query_summary amazon_domain + + # 未知域名处理逻辑: + # 先查国内 DNS → 如果返回非 CN IP ⇒ fallback 到国外 + - tag: query_unknown_fallback + type: sequence + args: + - exec: prefer_ipv4 + - exec: $forward_local + - matches: resp_ip $geoip_cn + exec: accept + - exec: $forward_remote_upstream + - exec: query_summary fallback_to_remote + + # ====================== + # 主查询处理流程 + # ====================== + - tag: main_sequence + type: sequence + args: + - exec: $cache # 首先查缓存 + - exec: $query_is_reject_domain + - exec: jump has_resp_sequence + + - exec: $query_is_foreign_domain # gfwlist + - exec: jump has_resp_sequence + + - exec: $query_is_cn_domain # 国内域名 + - exec: jump has_resp_sequence + + - exec: $query_is_amazon_domain # Amazon 域名(走国外 DNS + 添加到 MikroTik) + - exec: jump has_resp_sequence + + - exec: $query_unknown_fallback # 其他未知域名 fallback 流程 + - exec: jump has_resp_sequence + + # ====================== + # 服务监听 + # ====================== + + # UDP 监听 + - tag: udp_server + type: udp_server + args: + entry: main_sequence + listen: ":5300" + + # TCP 监听 + - tag: tcp_server + type: tcp_server + args: + entry: main_sequence + listen: ":5300" diff --git a/coremain/config.go b/coremain/config.go new file mode 100644 index 0000000..2f42168 --- /dev/null +++ b/coremain/config.go @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package coremain + +import ( + "github.com/IrineSistiana/mosdns/v5/mlog" +) + +type Config struct { + Log mlog.LogConfig `yaml:"log"` + Include []string `yaml:"include"` + Plugins []PluginConfig `yaml:"plugins"` + API APIConfig `yaml:"api"` +} + +// PluginConfig represents a plugin config +type PluginConfig struct { + // Tag for this plugin. Optional. If omitted, this plugin will + // be registered with a random tag. + Tag string `yaml:"tag"` + + // Type, required. + Type string `yaml:"type"` + + // Args, might be required by some plugins. + // The type of Args is depended on RegNewPluginFunc. + // If it's a map[string]any, it will be converted by mapstruct. + Args any `yaml:"args"` +} + +type APIConfig struct { + HTTP string `yaml:"http"` +} diff --git a/coremain/mosdns.go b/coremain/mosdns.go new file mode 100644 index 0000000..d9d3b21 --- /dev/null +++ b/coremain/mosdns.go @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package coremain + +import ( + "bytes" + "errors" + "fmt" + "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/IrineSistiana/mosdns/v5/pkg/safe_close" + "github.com/go-chi/chi/v5" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.uber.org/zap" + "io" + "net/http" + "net/http/pprof" +) + +type Mosdns struct { + logger *zap.Logger // non-nil logger. + + // Plugins + plugins map[string]any + + httpMux *chi.Mux + metricsReg *prometheus.Registry + sc *safe_close.SafeClose +} + +// NewMosdns initializes a mosdns instance and its plugins. +func NewMosdns(cfg *Config) (*Mosdns, error) { + // Init logger. + lg, err := mlog.NewLogger(cfg.Log) + if err != nil { + return nil, fmt.Errorf("failed to init logger: %w", err) + } + + m := &Mosdns{ + logger: lg, + plugins: make(map[string]any), + httpMux: chi.NewRouter(), + metricsReg: newMetricsReg(), + sc: safe_close.NewSafeClose(), + } + // This must be called after m.httpMux and m.metricsReg been set. + m.initHttpMux() + + // Start http api server + if httpAddr := cfg.API.HTTP; len(httpAddr) > 0 { + httpServer := &http.Server{ + Addr: httpAddr, + Handler: m.httpMux, + } + m.sc.Attach(func(done func(), closeSignal <-chan struct{}) { + defer done() + errChan := make(chan error, 1) + go func() { + m.logger.Info("starting api http server", zap.String("addr", httpAddr)) + errChan <- httpServer.ListenAndServe() + }() + select { + case err := <-errChan: + m.sc.SendCloseSignal(err) + case <-closeSignal: + _ = httpServer.Close() + } + }) + } + + // Load plugins. + + // Close all plugins on signal. + // From here, call m.sc.SendCloseSignal() if any plugin failed to load. + m.sc.Attach(func(done func(), closeSignal <-chan struct{}) { + go func() { + defer done() + <-closeSignal + m.logger.Info("starting shutdown sequences") + for tag, p := range m.plugins { + if closer, _ := p.(io.Closer); closer != nil { + m.logger.Info("closing plugin", zap.String("tag", tag)) + _ = closer.Close() + } + } + m.logger.Info("all plugins were closed") + }() + }) + + // Preset plugins + if err := m.loadPresetPlugins(); err != nil { + m.sc.SendCloseSignal(err) + _ = m.sc.WaitClosed() + return nil, err + } + // Plugins from config. + if err := m.loadPluginsFromCfg(cfg, 0); err != nil { + m.sc.SendCloseSignal(err) + _ = m.sc.WaitClosed() + return nil, err + } + m.logger.Info("all plugins are loaded") + + return m, nil +} + +// NewTestMosdnsWithPlugins returns a mosdns instance for testing. +func NewTestMosdnsWithPlugins(p map[string]any) *Mosdns { + return &Mosdns{ + logger: mlog.Nop(), + httpMux: chi.NewRouter(), + plugins: p, + metricsReg: newMetricsReg(), + sc: safe_close.NewSafeClose(), + } +} + +func (m *Mosdns) GetSafeClose() *safe_close.SafeClose { + return m.sc +} + +// CloseWithErr is a shortcut for m.sc.SendCloseSignal +func (m *Mosdns) CloseWithErr(err error) { + m.sc.SendCloseSignal(err) +} + +// Logger returns a non-nil logger. +func (m *Mosdns) Logger() *zap.Logger { + return m.logger +} + +// GetPlugin returns a plugin. +func (m *Mosdns) GetPlugin(tag string) any { + return m.plugins[tag] +} + +// GetMetricsReg returns a prometheus.Registerer with a prefix of "mosdns_" +func (m *Mosdns) GetMetricsReg() prometheus.Registerer { + return prometheus.WrapRegistererWithPrefix("mosdns_", m.metricsReg) +} + +func (m *Mosdns) GetAPIRouter() *chi.Mux { + return m.httpMux +} + +func (m *Mosdns) RegPluginAPI(tag string, mux *chi.Mux) { + m.httpMux.Mount("/plugins/"+tag, mux) +} + +func newMetricsReg() *prometheus.Registry { + reg := prometheus.NewRegistry() + reg.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + reg.MustRegister(collectors.NewGoCollector()) + return reg +} + +// initHttpMux initializes api entries. It MUST be called after m.metricsReg being initialized. +func (m *Mosdns) initHttpMux() { + // Register metrics. + m.httpMux.Method(http.MethodGet, "/metrics", promhttp.HandlerFor(m.metricsReg, promhttp.HandlerOpts{})) + + // Register pprof. + m.httpMux.Route("/debug/pprof", func(r chi.Router) { + r.Get("/*", pprof.Index) + r.Get("/cmdline", pprof.Cmdline) + r.Get("/profile", pprof.Profile) + r.Get("/symbol", pprof.Symbol) + r.Get("/trace", pprof.Trace) + }) + + // A helper page for invalid request. + invalidApiReqHelper := func(w http.ResponseWriter, req *http.Request) { + b := new(bytes.Buffer) + _, _ = fmt.Fprintf(b, "Invalid request %s %s\n\n", req.Method, req.RequestURI) + b.WriteString("Available api urls:\n") + _ = chi.Walk(m.httpMux, func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { + b.WriteString(method) + b.WriteByte(' ') + b.WriteString(route) + b.WriteByte('\n') + return nil + }) + _, _ = w.Write(b.Bytes()) + } + m.httpMux.NotFound(invalidApiReqHelper) + m.httpMux.MethodNotAllowed(invalidApiReqHelper) +} + +func (m *Mosdns) loadPresetPlugins() error { + for tag, f := range LoadNewPersetPluginFuncs() { + p, err := f(NewBP(tag, m)) + if err != nil { + return fmt.Errorf("failed to init preset plugin %s, %w", tag, err) + } + m.plugins[tag] = p + } + return nil +} + +// loadPluginsFromCfg loads plugins from this config. It follows include first. +func (m *Mosdns) loadPluginsFromCfg(cfg *Config, includeDepth int) error { + const maxIncludeDepth = 8 + if includeDepth > maxIncludeDepth { + return errors.New("maximum include depth reached") + } + includeDepth++ + + // Follow include first. + for _, s := range cfg.Include { + subCfg, path, err := loadConfig(s) + if err != nil { + return fmt.Errorf("failed to read config from %s, %w", s, err) + } + m.logger.Info("load config", zap.String("file", path)) + if err := m.loadPluginsFromCfg(subCfg, includeDepth); err != nil { + return fmt.Errorf("failed to load config from %s, %w", s, err) + } + } + + for i, pc := range cfg.Plugins { + if err := m.newPlugin(pc); err != nil { + return fmt.Errorf("failed to init plugin #%d %s, %w", i, pc.Tag, err) + } + } + return nil +} diff --git a/coremain/plugin.go b/coremain/plugin.go new file mode 100644 index 0000000..d222cd4 --- /dev/null +++ b/coremain/plugin.go @@ -0,0 +1,201 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package coremain + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/go-chi/chi/v5" + "go.uber.org/zap" + "reflect" + "sync" +) + +// NewPluginArgsFunc represents a func that creates a new args object. +type NewPluginArgsFunc func() any + +// NewPluginFunc represents a func that can init a Plugin. +// args is the object created by NewPluginArgsFunc. +type NewPluginFunc func(bp *BP, args any) (p any, err error) + +type PluginTypeInfo struct { + NewPlugin NewPluginFunc + NewArgs NewPluginArgsFunc +} + +var ( + // pluginTypeRegister stores init funcs for certain plugin types + pluginTypeRegister struct { + sync.RWMutex + m map[string]PluginTypeInfo + } +) + +// RegNewPluginFunc registers the type. +// If the type has been registered. RegNewPluginFunc will panic. +func RegNewPluginFunc(typ string, initFunc NewPluginFunc, argsType NewPluginArgsFunc) { + pluginTypeRegister.Lock() + defer pluginTypeRegister.Unlock() + + if pluginTypeRegister.m == nil { + pluginTypeRegister.m = make(map[string]PluginTypeInfo) + } + + _, ok := pluginTypeRegister.m[typ] + if ok { + panic(fmt.Sprintf("duplicate plugin type [%s]", typ)) + } + + pluginTypeRegister.m[typ] = PluginTypeInfo{ + NewPlugin: initFunc, + NewArgs: argsType, + } +} + +// DelPluginType deletes the init func for this plugin type. +// It is a noop if pluginType is not registered. +func DelPluginType(typ string) { + pluginTypeRegister.Lock() + defer pluginTypeRegister.Unlock() + delete(pluginTypeRegister.m, typ) +} + +// GetPluginType gets the registered type init func. +func GetPluginType(typ string) (PluginTypeInfo, bool) { + pluginTypeRegister.RLock() + defer pluginTypeRegister.RUnlock() + + info, ok := pluginTypeRegister.m[typ] + return info, ok +} + +// newPlugin initializes a Plugin from c and adds it to mosdns. +func (m *Mosdns) newPlugin(c PluginConfig) error { + if len(c.Tag) == 0 { + c.Tag = fmt.Sprintf("anonymouse_%s_%d", c.Type, len(m.plugins)) + } + + if _, dup := m.plugins[c.Tag]; dup { + return fmt.Errorf("duplicated plugin tag %s", c.Tag) + } + + typeInfo, ok := GetPluginType(c.Type) + if !ok { + return fmt.Errorf("plugin type %s not defined", c.Type) + } + + args := typeInfo.NewArgs() + if reflect.TypeOf(c.Args) == reflect.TypeOf(args) { // Same type, no need to parse. + args = c.Args + } else { + if err := utils.WeakDecode(c.Args, args); err != nil { + return fmt.Errorf("unable to decode plugin args: %w", err) + } + } + + m.logger.Info("loading plugin", zap.String("tag", c.Tag), zap.String("type", c.Type)) + p, err := typeInfo.NewPlugin(NewBP(c.Tag, m), args) + if err != nil { + return fmt.Errorf("failed to init plugin: %w", err) + } + m.plugins[c.Tag] = p + return nil +} + +// GetAllPluginTypes returns all plugin types which are configurable. +func GetAllPluginTypes() []string { + pluginTypeRegister.RLock() + defer pluginTypeRegister.RUnlock() + + var t []string + for typ := range pluginTypeRegister.m { + t = append(t, typ) + } + return t +} + +type NewPersetPluginFunc func(bp *BP) (any, error) + +var presetPluginFuncReg struct { + sync.Mutex + m map[string]NewPersetPluginFunc +} + +func RegNewPersetPluginFunc(tag string, f NewPersetPluginFunc) { + presetPluginFuncReg.Lock() + defer presetPluginFuncReg.Unlock() + + if presetPluginFuncReg.m == nil { + presetPluginFuncReg.m = make(map[string]NewPersetPluginFunc) + } + if _, ok := presetPluginFuncReg.m[tag]; ok { + panic(fmt.Sprintf("preset plugin %s has already been registered", tag)) + } + + presetPluginFuncReg.m[tag] = f +} + +func LoadNewPersetPluginFuncs() map[string]NewPersetPluginFunc { + presetPluginFuncReg.Lock() + defer presetPluginFuncReg.Unlock() + m := make(map[string]NewPersetPluginFunc) + for tag, pluginFunc := range presetPluginFuncReg.m { + m[tag] = pluginFunc + } + return m +} + +type BP struct { + tag string + m *Mosdns + l *zap.Logger +} + +// NewBP creates a new BP. m MUST NOT nil. +func NewBP(tag string, m *Mosdns) *BP { + return &BP{ + tag: tag, + l: m.Logger().Named(tag), + m: m, + } +} + +// L returns a non-nil logger. +func (p *BP) L() *zap.Logger { + return p.l +} + +// M returns a non-nil Mosdns. +func (p *BP) M() *Mosdns { + return p.m +} + +// Tag returns the plugin tag. +// This tag should be unique globally unless it's in +// a test environment. +func (p *BP) Tag() string { + return p.tag +} + +// RegAPI mounts mux to mosdns api. Note: Plugins MUST NOT call RegAPI twice. +// Since mounting same path to root chi.Mux causes runtime panic. +func (p *BP) RegAPI(mux *chi.Mux) { + p.m.RegPluginAPI(p.tag, mux) +} diff --git a/coremain/run.go b/coremain/run.go new file mode 100644 index 0000000..0c8357b --- /dev/null +++ b/coremain/run.go @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package coremain + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/kardianos/service" + "github.com/mitchellh/mapstructure" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "go.uber.org/zap" + "os" + "os/signal" + "runtime" + "syscall" +) + +type serverFlags struct { + c string + dir string + cpu int + asService bool +} + +var rootCmd = &cobra.Command{ + Use: "mosdns", +} + +func init() { + sf := new(serverFlags) + startCmd := &cobra.Command{ + Use: "start [-c config_file] [-d working_dir]", + Short: "Start mosdns main program.", + RunE: func(cmd *cobra.Command, args []string) error { + if sf.asService { + svc, err := service.New(&serverService{f: sf}, svcCfg) + if err != nil { + return fmt.Errorf("failed to init service, %w", err) + } + return svc.Run() + } + + m, err := NewServer(sf) + if err != nil { + return err + } + + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + sig := <-c + m.logger.Warn("signal received", zap.Stringer("signal", sig)) + m.sc.SendCloseSignal(nil) + }() + return m.GetSafeClose().WaitClosed() + }, + DisableFlagsInUseLine: true, + SilenceUsage: true, + } + rootCmd.AddCommand(startCmd) + fs := startCmd.Flags() + fs.StringVarP(&sf.c, "config", "c", "", "config file") + fs.StringVarP(&sf.dir, "dir", "d", "", "working dir") + fs.IntVar(&sf.cpu, "cpu", 0, "set runtime.GOMAXPROCS") + fs.BoolVar(&sf.asService, "as-service", false, "start as a service") + _ = fs.MarkHidden("as-service") + + serviceCmd := &cobra.Command{ + Use: "service", + Short: "Manage mosdns as a system service.", + } + serviceCmd.PersistentPreRunE = initService + serviceCmd.AddCommand( + newSvcInstallCmd(), + newSvcUninstallCmd(), + newSvcStartCmd(), + newSvcStopCmd(), + newSvcRestartCmd(), + newSvcStatusCmd(), + ) + rootCmd.AddCommand(serviceCmd) +} + +func AddSubCmd(c *cobra.Command) { + rootCmd.AddCommand(c) +} + +func Run() error { + return rootCmd.Execute() +} + +func NewServer(sf *serverFlags) (*Mosdns, error) { + if sf.cpu > 0 { + runtime.GOMAXPROCS(sf.cpu) + } + + if len(sf.dir) > 0 { + err := os.Chdir(sf.dir) + if err != nil { + return nil, fmt.Errorf("failed to change the current working directory, %w", err) + } + mlog.L().Info("working directory changed", zap.String("path", sf.dir)) + } + + cfg, fileUsed, err := loadConfig(sf.c) + if err != nil { + return nil, fmt.Errorf("fail to load config, %w", err) + } + mlog.L().Info("main config loaded", zap.String("file", fileUsed)) + + return NewMosdns(cfg) +} + +// loadConfig load a config from a file. If filePath is empty, it will +// automatically search and load a file which name start with "config". +func loadConfig(filePath string) (*Config, string, error) { + v := viper.New() + + if len(filePath) > 0 { + v.SetConfigFile(filePath) + } else { + v.SetConfigName("config") + v.AddConfigPath(".") + } + + if err := v.ReadInConfig(); err != nil { + return nil, "", fmt.Errorf("failed to read config: %w", err) + } + + decoderOpt := func(cfg *mapstructure.DecoderConfig) { + cfg.ErrorUnused = true + cfg.TagName = "yaml" + cfg.WeaklyTypedInput = true + } + + cfg := new(Config) + if err := v.Unmarshal(cfg, decoderOpt); err != nil { + return nil, "", fmt.Errorf("failed to unmarshal config: %w", err) + } + return cfg, v.ConfigFileUsed(), nil +} diff --git a/coremain/service.go b/coremain/service.go new file mode 100644 index 0000000..31e7250 --- /dev/null +++ b/coremain/service.go @@ -0,0 +1,220 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package coremain + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/kardianos/service" + "github.com/spf13/cobra" + "go.uber.org/zap" + "os" + "path/filepath" + "time" +) + +var ( + // initialized by "service" sub command + svc service.Service + svcCfg = &service.Config{ + Name: "mosdns", + DisplayName: "mosdns", + Description: "A DNS forwarder", + } +) + +type serverService struct { + f *serverFlags + m *Mosdns +} + +func (ss *serverService) Start(s service.Service) error { + mlog.L().Info("starting service", zap.String("platform", s.Platform())) + m, err := NewServer(ss.f) + if err != nil { + return err + } + ss.m = m + go func() { + err := m.GetSafeClose().WaitClosed() + if err != nil { + m.Logger().Fatal("server exited", zap.Error(err)) + } else { + m.Logger().Info("server exited") + } + }() + return nil +} + +func (ss *serverService) Stop(_ service.Service) error { + ss.m.Logger().Info("service is shutting down") + ss.m.GetSafeClose().SendCloseSignal(nil) + return ss.m.GetSafeClose().WaitClosed() +} + +// initService will init svc for sub command "service" +func initService(_ *cobra.Command, _ []string) error { + s, err := service.New(&serverService{}, svcCfg) + if err != nil { + return fmt.Errorf("cannot init service, %w", err) + } + svc = s + return nil +} + +func newSvcInstallCmd() *cobra.Command { + sf := new(serverFlags) + c := &cobra.Command{ + Use: "install [-d working_dir] [-c config_file]", + Short: "Install mosdns as a system service.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + if len(sf.dir) > 0 { + absWd, err := filepath.Abs(sf.dir) + if err != nil { + return fmt.Errorf("cannot solve absolute working dir path, %w", err) + } else { + sf.dir = absWd + } + } else { + ep, err := os.Executable() + if err != nil { + return fmt.Errorf("cannot solve current executable path, %w", err) + } + sf.dir = filepath.Dir(ep) + } + mlog.S().Infof("set service working dir as %s", sf.dir) + svcCfg.Arguments = []string{"start", "--as-service", "-d", sf.dir} + if len(sf.c) > 0 { + svcCfg.Arguments = append(svcCfg.Arguments, "-c", sf.c) + } + return svc.Install() + }, + DisableFlagsInUseLine: true, + SilenceUsage: true, + } + c.Flags().StringVarP(&sf.dir, "dir", "d", "", "working dir") + c.Flags().StringVarP(&sf.c, "config", "c", "", "config path") + return c +} + +func newSvcUninstallCmd() *cobra.Command { + c := &cobra.Command{ + Use: "uninstall", + Short: "Uninstall mosdns from system service.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return svc.Uninstall() + }, + DisableFlagsInUseLine: true, + SilenceUsage: true, + } + return c +} + +func newSvcStartCmd() *cobra.Command { + c := &cobra.Command{ + Use: "start", + Short: "Start mosdns system service.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + if err := svc.Start(); err != nil { + return err + } + + mlog.S().Info("service is starting") + time.Sleep(time.Second) + s, err := svc.Status() + if err != nil { + mlog.S().Warn("cannot get service status, %w", err) + } else { + switch s { + case service.StatusRunning: + mlog.S().Info("service is running") + case service.StatusStopped: + mlog.S().Error("service is stopped, check mosdns and system service log for more info") + default: + mlog.S().Warn("cannot get service status, system may not support this operation") + } + } + + return nil + }, + DisableFlagsInUseLine: true, + SilenceUsage: true, + } + return c +} + +func newSvcStopCmd() *cobra.Command { + c := &cobra.Command{ + Use: "stop", + Short: "Stop mosdns system service.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return svc.Stop() + }, + DisableFlagsInUseLine: true, + SilenceUsage: true, + } + return c +} + +func newSvcRestartCmd() *cobra.Command { + c := &cobra.Command{ + Use: "restart", + Short: "Restart mosdns system service.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return svc.Restart() + }, + DisableFlagsInUseLine: true, + SilenceUsage: true, + } + return c +} + +func newSvcStatusCmd() *cobra.Command { + c := &cobra.Command{ + Use: "status", + Short: "Status of mosdns system service.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + s, err := svc.Status() + if err != nil { + return fmt.Errorf("cannot get service status, %w", err) + } + var out string + switch s { + case service.StatusRunning: + out = "running" + case service.StatusStopped: + out = "stopped" + case service.StatusUnknown: + out = "unknown" + } + println(out) + return nil + }, + DisableFlagsInUseLine: true, + SilenceUsage: true, + } + return c +} diff --git a/deploy-mikrotik-amazon-updated.md b/deploy-mikrotik-amazon-updated.md new file mode 100644 index 0000000..d5153cb --- /dev/null +++ b/deploy-mikrotik-amazon-updated.md @@ -0,0 +1,192 @@ +# MosDNS + MikroTik Amazon 域名处理部署指南(更新版) + +## 功能说明 + +这个配置会在解析 Amazon 相关域名时,自动将解析到的 IP 地址添加到 MikroTik 路由器的 address list 中,用于防火墙规则控制。 + +## 部署步骤 + +### 1. 上传文件到 Debian 12 服务器 + +```bash +# 上传编译好的 mosdns 可执行文件 +scp mosdns-linux-amd64 user@your-server:/usr/local/bin/mosdns + +# 上传配置文件 +scp config.yaml user@your-server:/opt/mosdns/ +scp dns.yaml user@your-server:/opt/mosdns/ + +# 设置执行权限 +ssh user@your-server "chmod +x /usr/local/bin/mosdns" +``` + +### 2. 创建必要的目录和文件 + +```bash +# 创建配置目录 +sudo mkdir -p /opt/mosdns/config + +# 下载 Amazon 域名列表 +sudo wget -O /opt/mosdns/config/geosite_amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon +sudo wget -O /opt/mosdns/config/geosite_amazon-ads.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon-ads +sudo wget -O /opt/mosdns/config/geosite_amazontrust.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazontrust +sudo wget -O /opt/mosdns/config/amazon.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/amazon + +# 下载其他必要的域名和 IP 文件 +sudo wget -O /opt/mosdns/config/geosite_tiktok.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/tiktok +sudo wget -O /opt/mosdns/config/gfwlist.out.txt https://raw.githubusercontent.com/gfwlist/gfwlist/master/gfwlist.txt +sudo wget -O /opt/mosdns/config/domains.txt https://raw.githubusercontent.com/v2fly/domain-list-community/master/data/category-games +sudo wget -O /opt/mosdns/config/cn.txt https://raw.githubusercontent.com/Loyalsoldier/v2ray-rules-dat/release/geoip.dat +``` + +### 3. 在 MikroTik 中创建 address list + +```bash +# 通过 SSH 连接到 MikroTik 路由器 +ssh admin@10.248.0.1 + +# 创建 IPv4 和 IPv6 address list +/ip firewall address-list add list=AmazonIP +/ip firewall address-list add list=AmazonIP6 + +# 创建防火墙规则(可选) +/ip firewall filter add chain=forward src-address-list=AmazonIP action=drop comment="Block Amazon IPs" +/ip firewall filter add chain=forward src-address-list=AmazonIP6 action=drop comment="Block Amazon IPv6 IPs" +``` + +### 4. 修改配置文件中的 MikroTik 连接信息 + +编辑 `/opt/mosdns/dns.yaml` 文件,确认 mikrotik_amazon 插件的配置: + +```yaml +# 当前配置(根据你的实际情况修改) +args: "10.248.0.1:9728:admin:szn0s!nw@pwd():false:10:AmazonIP:AmazonIP6:24:32:AmazonIP:86400" +``` + +参数说明: +- `10.248.0.1`: MikroTik 路由器 IP +- `9728`: API 端口 +- `admin`: 用户名 +- `szn0s!nw@pwd()`: 密码 +- `false`: 不使用 TLS +- `10`: 连接超时时间 +- `AmazonIP`: IPv4 address list 名称 +- `AmazonIP6`: IPv6 address list 名称 +- `24`: IPv4 掩码 +- `32`: IPv6 掩码 +- `AmazonIP`: 注释 +- `86400`: 地址超时时间(24小时) + +### 5. 创建 systemd 服务 + +```bash +sudo tee /etc/systemd/system/mosdns.service > /dev/null < /dev/null < github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a + +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/onsi/ginkgo/v2 v2.20.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.55.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + github.com/quic-go/qpack v0.4.0 // indirect + github.com/sagikazarmark/locafero v0.6.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.7.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/vishvananda/netns v0.0.4 // indirect + go.uber.org/mock v0.4.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/mod v0.20.0 // indirect + golang.org/x/text v0.17.0 // indirect + golang.org/x/tools v0.24.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..337824d --- /dev/null +++ b/go.sum @@ -0,0 +1,154 @@ +github.com/IrineSistiana/go-bytes-pool v0.0.0-20230918115058-c72bd9761c57 h1:nfurUSSmVY9sY/mYyoReOA1w2cR2fp2eicL9ojicZhQ= +github.com/IrineSistiana/go-bytes-pool v0.0.0-20230918115058-c72bd9761c57/go.mod h1:pQ/FSsWSNYmNdgIKmulKlmVC/R2PEpq2vIEi3J9IijI= +github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a h1:GQdh/h0q0ni3L//CXusyk+7QdhBL289vdNaes1WKkHI= +github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-routeros/routeros/v3 v3.0.1 h1:FdNKlF6Hst8nkHr0dIvD54pQ+dZ8sHOJfQSVRKz0BFg= +github.com/go-routeros/routeros/v3 v3.0.1/go.mod h1:j4mq65czXfKtHsdLkgVv8w7sNzyhLZy1TKi2zQDMpiQ= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= +github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60= +github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= +github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/onsi/ginkgo/v2 v2.20.0 h1:PE84V2mHqoT1sglvHc8ZdQtPcwmvvt29WLEEO3xmdZw= +github.com/onsi/ginkgo/v2 v2.20.0/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= +github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= +github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= +github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= +github.com/quic-go/quic-go v0.46.0 h1:uuwLClEEyk1DNvchH8uCByQVjo3yKL9opKulExNDs7Y= +github.com/quic-go/quic-go v0.46.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3N51bwOk= +github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= +github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= +github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/vishvananda/netlink v1.2.1-beta.2.0.20221107222636-d3c0a2caa559 h1:NwQroOyW+fpfiUroBzAMqFc6NRwBmvJevoVtEK6gsFE= +github.com/vishvananda/netlink v1.2.1-beta.2.0.20221107222636-d3c0a2caa559/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= +go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220804214406-8e32c043e418/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..70b83ee --- /dev/null +++ b/main.go @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package main + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/mlog" + _ "github.com/IrineSistiana/mosdns/v5/plugin" + _ "github.com/IrineSistiana/mosdns/v5/tools" + "github.com/spf13/cobra" + _ "net/http/pprof" +) + +var ( + version = "dev/unknown" +) + +func init() { + coremain.AddSubCmd(&cobra.Command{ + Use: "version", + Short: "Print out version info and exit.", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println(version) + }, + }) +} + +func main() { + if err := coremain.Run(); err != nil { + mlog.S().Fatal(err) + } +} diff --git a/mlog/logger.go b/mlog/logger.go new file mode 100644 index 0000000..861f091 --- /dev/null +++ b/mlog/logger.go @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package mlog + +import ( + "fmt" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "os" +) + +type LogConfig struct { + // Level, See also zapcore.ParseLevel. + Level string `yaml:"level"` + + // File that logger will be writen into. + // Default is stderr. + File string `yaml:"file"` + + // Production enables json output. + Production bool `yaml:"production"` +} + +var ( + stderr = zapcore.Lock(os.Stderr) + lvl = zap.NewAtomicLevelAt(zap.InfoLevel) + l = zap.New(zapcore.NewCore(zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), stderr, lvl)) + s = l.Sugar() + + nop = zap.NewNop() +) + +func NewLogger(lc LogConfig) (*zap.Logger, error) { + lvl, err := zapcore.ParseLevel(lc.Level) + if err != nil { + return nil, fmt.Errorf("invalid log level: %w", err) + } + + var out zapcore.WriteSyncer + if lf := lc.File; len(lf) > 0 { + f, _, err := zap.Open(lf) + if err != nil { + return nil, fmt.Errorf("open log file: %w", err) + } + out = zapcore.Lock(f) + } else { + out = stderr + } + + if lc.Production { + return zap.New(zapcore.NewCore(zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()), out, lvl)), nil + } + return zap.New(zapcore.NewCore(zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), out, lvl)), nil +} + +// L is a global logger. +func L() *zap.Logger { + return l +} + +// SetLevel sets the log level for the global logger. +func SetLevel(l zapcore.Level) { + lvl.SetLevel(l) +} + +// S is a global logger. +func S() *zap.SugaredLogger { + return s +} + +// Nop is a logger that never writes out logs. +func Nop() *zap.Logger { + return nop +} diff --git a/mosdns-linux-amd64 b/mosdns-linux-amd64 new file mode 100644 index 0000000..b5c7057 Binary files /dev/null and b/mosdns-linux-amd64 differ diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 0000000..835ac8a --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,157 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package cache + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/concurrent_lru" + "github.com/IrineSistiana/mosdns/v5/pkg/concurrent_map" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "sync/atomic" + "time" +) + +const ( + defaultCleanerInterval = time.Second * 10 +) + +type Key interface { + concurrent_lru.Hashable +} + +type Value interface { + any +} + +// Cache is a simple map cache that stores values in memory. +// It is safe for concurrent use. +type Cache[K Key, V Value] struct { + opts Opts + + closed atomic.Bool + closeNotify chan struct{} + m *concurrent_map.Map[K, *elem[V]] +} + +type Opts struct { + Size int + CleanerInterval time.Duration +} + +func (opts *Opts) init() { + utils.SetDefaultNum(&opts.Size, 1024) + utils.SetDefaultNum(&opts.CleanerInterval, defaultCleanerInterval) +} + +type elem[V Value] struct { + v V + expirationTime time.Time +} + +// New initializes a Cache. +// The minimum size is 1024. +// cleanerInterval specifies the interval that Cache scans +// and discards expired values. If cleanerInterval <= 0, a default +// interval will be used. +func New[K Key, V Value](opts Opts) *Cache[K, V] { + opts.init() + c := &Cache[K, V]{ + closeNotify: make(chan struct{}), + m: concurrent_map.NewMapCache[K, *elem[V]](opts.Size), + } + go c.gcLoop(opts.CleanerInterval) + return c +} + +// Close closes the inner cleaner of this cache. +func (c *Cache[K, V]) Close() error { + if ok := c.closed.CompareAndSwap(false, true); ok { + close(c.closeNotify) + } + return nil +} + +func (c *Cache[K, V]) Get(key K) (v V, expirationTime time.Time, ok bool) { + if e, hasEntry := c.m.Get(key); hasEntry { + if e.expirationTime.Before(time.Now()) { + c.m.Del(key) + return + } + return e.v, e.expirationTime, true + } + return +} + +// Range calls f through all entries. If f returns an error, the same error will be returned +// by Range. +func (c *Cache[K, V]) Range(f func(key K, v V, expirationTime time.Time) error) error { + cf := func(key K, v *elem[V]) (newV *elem[V], setV bool, delV bool, err error) { + return nil, false, false, f(key, v.v, v.expirationTime) + } + return c.m.RangeDo(cf) +} + +// Store stores this kv in cache. If expirationTime is before time.Now(), +// Store is an noop. +func (c *Cache[K, V]) Store(key K, v V, expirationTime time.Time) { + now := time.Now() + if now.After(expirationTime) { + return + } + + e := &elem[V]{ + v: v, + expirationTime: expirationTime, + } + c.m.Set(key, e) + return +} + +func (c *Cache[K, V]) gcLoop(interval time.Duration) { + if interval <= 0 { + interval = defaultCleanerInterval + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-c.closeNotify: + return + case now := <-ticker.C: + c.gc(now) + } + } +} + +func (c *Cache[K, V]) gc(now time.Time) { + f := func(key K, v *elem[V]) (newV *elem[V], setV, delV bool, err error) { + return nil, false, now.After(v.expirationTime), nil + } + _ = c.m.RangeDo(f) +} + +// Len returns the current size of this cache. +func (c *Cache[K, V]) Len() int { + return c.m.Len() +} + +// Flush removes all stored entries from this cache. +func (c *Cache[K, V]) Flush() { + c.m.Flush() +} diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 0000000..2d9ec69 --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package cache + +import ( + "sync" + "testing" + "time" +) + +type testKey int + +func (t testKey) Sum() uint64 { + return uint64(t) +} + +func Test_Cache(t *testing.T) { + c := New[testKey, int](Opts{ + Size: 1024, + }) + for i := 0; i < 128; i++ { + key := testKey(i) + c.Store(key, i, time.Now().Add(time.Millisecond*200)) + v, _, ok := c.Get(key) + + if v != i { + t.Fatal("cache kv mismatched") + } + if !ok { + t.Fatal() + } + } + + for i := 0; i < 1024*4; i++ { + key := testKey(i) + c.Store(key, i, time.Now().Add(time.Millisecond*200)) + } + + if l := c.Len(); l > 1024 { + t.Fatal("cache overflow") + } +} + +func Test_memCache_cleaner(t *testing.T) { + c := New[testKey, int](Opts{ + Size: 1024, + CleanerInterval: time.Millisecond * 10, + }) + defer c.Close() + for i := 0; i < 64; i++ { + key := testKey(i) + c.Store(key, i, time.Now().Add(time.Millisecond*10)) + } + + time.Sleep(time.Millisecond * 100) + if c.Len() != 0 { + t.Fatal() + } +} + +func Test_memCache_race(t *testing.T) { + c := New[testKey, int](Opts{ + Size: 1024, + }) + defer c.Close() + + wg := sync.WaitGroup{} + for i := 0; i < 32; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 256; i++ { + key := testKey(i) + c.Store(key, i, time.Now().Add(time.Minute)) + _, _, _ = c.Get(key) + c.gc(time.Now()) + } + }() + } + wg.Wait() +} diff --git a/pkg/concurrent_lru/concurrent_lru.go b/pkg/concurrent_lru/concurrent_lru.go new file mode 100644 index 0000000..800a950 --- /dev/null +++ b/pkg/concurrent_lru/concurrent_lru.go @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package concurrent_lru + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/lru" + "sync" +) + +type Hashable interface { + comparable + Sum() uint64 +} + +type ShardedLRU[K Hashable, V any] struct { + l []*ConcurrentLRU[K, V] +} + +func NewShardedLRU[K Hashable, V any]( + shardNum, maxSizePerShard int, + onEvict func(key K, v V), +) *ShardedLRU[K, V] { + cl := &ShardedLRU[K, V]{ + l: make([]*ConcurrentLRU[K, V], 0, shardNum), + } + + for i := 0; i < shardNum; i++ { + cl.l = append(cl.l, NewConecurrentLRU[K, V](maxSizePerShard, onEvict)) + } + + return cl +} + +func (c *ShardedLRU[K, V]) Add(key K, v V) { + sl := c.getShard(key) + sl.Add(key, v) +} + +func (c *ShardedLRU[K, V]) Del(key K) { + sl := c.getShard(key) + sl.Del(key) +} + +func (c *ShardedLRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) { + for _, l := range c.l { + removed += l.Clean(f) + } + return removed +} + +func (c *ShardedLRU[K, V]) Flush() { + for _, l := range c.l { + l.Flush() + } +} + +func (c *ShardedLRU[K, V]) Get(key K) (v V, ok bool) { + sl := c.getShard(key) + v, ok = sl.Get(key) + return +} + +func (c *ShardedLRU[K, V]) Len() int { + sum := 0 + for _, l := range c.l { + sum += l.Len() + } + return sum +} + +func (c *ShardedLRU[K, V]) shardNum() int { + return len(c.l) +} + +func (c *ShardedLRU[K, V]) getShard(key K) *ConcurrentLRU[K, V] { + return c.l[key.Sum()%uint64(c.shardNum())] +} + +// ConcurrentLRU is a lru.LRU with a lock. +// It is concurrent safe. +type ConcurrentLRU[K comparable, V any] struct { + sync.Mutex + lru *lru.LRU[K, V] +} + +func NewConecurrentLRU[K comparable, V any](maxSize int, onEvict func(key K, v V)) *ConcurrentLRU[K, V] { + return &ConcurrentLRU[K, V]{ + lru: lru.NewLRU[K, V](maxSize, onEvict), + } +} + +func (c *ConcurrentLRU[K, V]) Add(key K, v V) { + c.Lock() + defer c.Unlock() + + c.lru.Add(key, v) +} + +func (c *ConcurrentLRU[K, V]) Del(key K) { + c.Lock() + defer c.Unlock() + + c.lru.Del(key) +} + +func (c *ConcurrentLRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) { + c.Lock() + defer c.Unlock() + + return c.lru.Clean(f) +} + +func (c *ConcurrentLRU[K, V]) Flush() { + c.Lock() + defer c.Unlock() + c.lru.Flush() +} + +func (c *ConcurrentLRU[K, V]) Get(key K) (v V, ok bool) { + c.Lock() + defer c.Unlock() + + v, ok = c.lru.Get(key) + return +} + +func (c *ConcurrentLRU[K, V]) Len() int { + c.Lock() + defer c.Unlock() + + return c.lru.Len() +} diff --git a/pkg/concurrent_lru/concurrent_lru_test.go b/pkg/concurrent_lru/concurrent_lru_test.go new file mode 100644 index 0000000..3f98fba --- /dev/null +++ b/pkg/concurrent_lru/concurrent_lru_test.go @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package concurrent_lru + +import ( + "reflect" + "testing" +) + +type testKey int + +func (k testKey) Sum() uint64 { + return uint64(k) +} + +func TestConcurrentLRU(t *testing.T) { + onEvict := func(k testKey, v int) {} + + var cache *ShardedLRU[testKey, int] + reset := func(shardNum, maxShardSize int) { + cache = NewShardedLRU[testKey, int](shardNum, maxShardSize, onEvict) + } + + add := func(keys ...int) { + for _, k := range keys { + cache.Add(testKey(k), k) + } + } + + mustGet := func(keys ...int) { + for _, k := range keys { + gotV, ok := cache.Get(testKey(k)) + if !ok || !reflect.DeepEqual(gotV, k) { + t.Fatalf("want %v, got %v", k, gotV) + } + } + } + + emptyGet := func(keys ...int) { + for _, k := range keys { + gotV, ok := cache.Get(testKey(k)) + if ok { + t.Fatalf("want empty, got %v", gotV) + } + } + } + + checkLen := func(want int) { + if want != cache.Len() { + t.Fatalf("want %v, got %v", want, cache.Len()) + } + } + + // test add + reset(4, 16) + add(1, 1, 1, 1, 2, 2, 3, 3, 4) + checkLen(4) + mustGet(1, 2, 3, 4) + emptyGet(5, 6, 7, 9999) + + // test add overflow + reset(4, 16) // max size is 64 + for i := 0; i < 1024; i++ { + add(i) + } + if cache.Len() > 64 { + t.Fatalf("lru overflowed: want len = %d, got = %d", 64, cache.Len()) + } + + // test del + reset(4, 16) + add(1, 2, 3, 4) + cache.Del(2) + cache.Del(4) + cache.Del(9999) + mustGet(1, 3) + emptyGet(2, 4) + + // test clean + reset(4, 16) + add(1, 2, 3, 4) + cleanFunc := func(k testKey, v int) (remove bool) { + switch k { + case 1, 3: + return true + } + return false + } + if cleaned := cache.Clean(cleanFunc); cleaned != 2 { + t.Fatalf("q.Clean want cleaned = 2, got %v", cleaned) + } + mustGet(2, 4) + emptyGet(1, 3) +} diff --git a/pkg/concurrent_map/map.go b/pkg/concurrent_map/map.go new file mode 100644 index 0000000..a351c21 --- /dev/null +++ b/pkg/concurrent_map/map.go @@ -0,0 +1,186 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package concurrent_map + +import ( + "sync" +) + +const ( + MapShardSize = 64 +) + +type Hashable interface { + comparable + Sum() uint64 +} + +type TestAndSetFunc[K comparable, V any] func(key K, v V, ok bool) (newV V, setV, deleteV bool) + +type Map[K Hashable, V any] struct { + shards [MapShardSize]shard[K, V] +} + +func NewMap[K Hashable, V any]() *Map[K, V] { + m := new(Map[K, V]) + for i := range m.shards { + m.shards[i] = newShard[K, V](0) + } + return m +} + +// NewMapCache returns a cache with a maximum size. +// Note that, because this it has multiple (MapShardSize) shards, +// the actual maximum size is MapShardSize*(size / MapShardSize). +// If size <=0, it's equal to NewMap(). +func NewMapCache[K Hashable, V any](size int) *Map[K, V] { + sizePreShard := size / MapShardSize + m := new(Map[K, V]) + for i := range m.shards { + m.shards[i] = newShard[K, V](sizePreShard) + } + return m +} + +func (m *Map[K, V]) getShard(key K) *shard[K, V] { + return &m.shards[key.Sum()%MapShardSize] +} + +func (m *Map[K, V]) Get(key K) (V, bool) { + return m.getShard(key).get(key) +} + +func (m *Map[K, V]) Set(key K, v V) { + m.getShard(key).set(key, v) +} + +func (m *Map[K, V]) Del(key K) { + m.getShard(key).del(key) +} + +func (m *Map[K, V]) TestAndSet(key K, f func(v V, ok bool) (newV V, setV, delV bool)) { + m.getShard(key).testAndSet(key, f) +} + +func (m *Map[K, V]) RangeDo(f func(k K, v V) (newV V, setV, delV bool, err error)) error { + for i := range m.shards { + if err := m.shards[i].rangeDo(f); err != nil { + return err + } + } + return nil +} + +func (m *Map[K, V]) Len() int { + l := 0 + for i := range m.shards { + l += m.shards[i].len() + } + return l +} + +func (m *Map[K, V]) Flush() { + for i := range m.shards { + m.shards[i].flush() + } +} + +type shard[K comparable, V any] struct { + l sync.RWMutex + max int // Negative or zero max means no limit. + m map[K]V +} + +func newShard[K comparable, V any](max int) shard[K, V] { + return shard[K, V]{ + max: max, + m: make(map[K]V), + } +} + +func (m *shard[K, V]) get(key K) (V, bool) { + m.l.RLock() + defer m.l.RUnlock() + v, ok := m.m[key] + return v, ok +} + +func (m *shard[K, V]) set(key K, v V) { + m.l.Lock() + defer m.l.Unlock() + if m.max > 0 && len(m.m)+1 > m.max { + for k := range m.m { + delete(m.m, k) + if len(m.m)+1 <= m.max { + break + } + } + } + m.m[key] = v +} + +func (m *shard[K, V]) del(key K) { + m.l.Lock() + defer m.l.Unlock() + delete(m.m, key) +} + +func (m *shard[K, V]) testAndSet(key K, f func(v V, ok bool) (newV V, setV, delV bool)) { + m.l.Lock() + defer m.l.Unlock() + v, ok := m.m[key] + newV, setV, deleteV := f(v, ok) + switch { + case setV: + m.m[key] = newV + case deleteV && ok: + delete(m.m, key) + } +} + +func (m *shard[K, V]) len() int { + m.l.RLock() + defer m.l.RUnlock() + return len(m.m) +} + +func (m *shard[K, V]) flush() { + m.l.RLock() + defer m.l.RUnlock() + m.m = make(map[K]V) +} + +func (m *shard[K, V]) rangeDo(f func(k K, v V) (newV V, setV, delV bool, err error)) error { + m.l.Lock() + defer m.l.Unlock() + for k, v := range m.m { + newV, setV, deleteV, err := f(k, v) + if err != nil { + return err + } + switch { + case setV: + m.m[k] = newV + case deleteV: + delete(m.m, k) + } + } + return nil +} diff --git a/pkg/concurrent_map/map_test.go b/pkg/concurrent_map/map_test.go new file mode 100644 index 0000000..ce99fa8 --- /dev/null +++ b/pkg/concurrent_map/map_test.go @@ -0,0 +1,190 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package concurrent_map + +import ( + "errors" + "sync" + "testing" +) + +type testMapHashable uint64 + +func (h testMapHashable) Sum() uint64 { + return uint64(h) +} + +func Test_Map(t *testing.T) { + m := NewMap[testMapHashable, int]() + wg := sync.WaitGroup{} + + // test add + for i := 0; i < 512; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + m.Set(testMapHashable(i), i) + }() + } + wg.Wait() + + // test range + wantErr := errors.New("") + f := func(key testMapHashable, v int) (newV int, setV bool, deleteV bool, err error) { + return 0, false, false, wantErr + } + if wantErr != m.RangeDo(f) { + t.Fatal("range should return a error") + } + + cc := make([]bool, 512) + f = func(key testMapHashable, v int) (newV int, setV bool, deleteV bool, err error) { + cc[key] = true + return 0, false, false, nil + } + _ = m.RangeDo(f) + for _, ok := range cc { + if !ok { + t.Fatal("test or range failed") + } + } + + // test get + for i := 0; i < 512; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + v, ok := m.Get(testMapHashable(i)) + if !ok { + t.Error() + return + } + if v != i { + t.Error() + return + } + }() + } + wg.Wait() + + // test len + if m.Len() != 512 { + t.Fatal() + } + + // test del + for i := 0; i < 512; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + m.Del(testMapHashable(i)) + }() + } + wg.Wait() + if m.Len() != 0 { + t.Fatal() + } +} + +func TestConcurrentMap_TestAndSet(t *testing.T) { + cm := NewMap[testMapHashable, int]() + wg := sync.WaitGroup{} + + f := func(v int, ok bool) (newV int, setV bool, deleteV bool) { + return 1, true, false + } + + for i := 0; i < 512; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cm.TestAndSet(1, f) + }() + } + wg.Wait() + + v, _ := cm.Get(1) + if v != 1 { + t.Fatal() + } + + // test delete + f = func(v int, ok bool) (newV int, setV bool, deleteV bool) { + return 1, false, true + } + cm.TestAndSet(1, f) + _, ok := cm.Get(1) + if ok { + t.Fatal() + } +} + +func BenchmarkConcurrentMap_Get_And_Set(b *testing.B) { + keys := make([]testMapHashable, 2048) + m := NewMap[testMapHashable, int]() + for i := 0; i < 2048; i++ { + key := testMapHashable(i) + keys[i] = key + m.Set(key, i) + } + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + i++ + key := keys[i%2048] + m.Set(key, i) + m.Get(key) + } + }) +} + +func Benchmark_RWMutexMap_Get_And_Set(b *testing.B) { + keys := make([]int, 2048) + rwm := new(sync.RWMutex) + m := make(map[int]int, 2048) + for i := 0; i < 2048; i++ { + keys[i] = i + m[i] = i + } + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + i++ + key := keys[i%2048] + + rwm.Lock() + m[key] = i + rwm.Unlock() + + rwm.RLock() + _ = m[key] + rwm.RUnlock() + } + }) +} diff --git a/pkg/dnsutils/msg.go b/pkg/dnsutils/msg.go new file mode 100644 index 0000000..b170cea --- /dev/null +++ b/pkg/dnsutils/msg.go @@ -0,0 +1,160 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package dnsutils + +import ( + "github.com/miekg/dns" + "strconv" +) + +// GetMinimalTTL returns the minimal ttl of this msg. +// If msg m has no record, it returns 0. +func GetMinimalTTL(m *dns.Msg) uint32 { + minTTL := ^uint32(0) + hasRecord := false + for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { + for _, rr := range section { + hdr := rr.Header() + if hdr.Rrtype == dns.TypeOPT { + continue // opt record ttl is not ttl. + } + hasRecord = true + ttl := hdr.Ttl + if ttl < minTTL { + minTTL = ttl + } + } + } + + if !hasRecord { // no ttl applied + return 0 + } + return minTTL +} + +// SetTTL updates all records' ttl to ttl, except opt record. +func SetTTL(m *dns.Msg, ttl uint32) { + for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { + for _, rr := range section { + hdr := rr.Header() + if hdr.Rrtype == dns.TypeOPT { + continue // opt record ttl is not ttl. + } + hdr.Ttl = ttl + } + } +} + +func ApplyMaximumTTL(m *dns.Msg, ttl uint32) { + applyTTL(m, ttl, true) +} + +func ApplyMinimalTTL(m *dns.Msg, ttl uint32) { + applyTTL(m, ttl, false) +} + +// SubtractTTL subtract delta from every m's RR. +// If RR's TTL is smaller than delta, SubtractTTL +// will return overflowed = true. +func SubtractTTL(m *dns.Msg, delta uint32) (overflowed bool) { + for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { + for _, rr := range section { + hdr := rr.Header() + if hdr.Rrtype == dns.TypeOPT { + continue // opt record ttl is not ttl. + } + if ttl := hdr.Ttl; ttl > delta { + hdr.Ttl = ttl - delta + } else { + hdr.Ttl = 1 + overflowed = true + } + } + } + return +} + +func applyTTL(m *dns.Msg, ttl uint32, maximum bool) { + for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { + for _, rr := range section { + hdr := rr.Header() + if hdr.Rrtype == dns.TypeOPT { + continue // opt record ttl is not ttl. + } + if maximum { + if hdr.Ttl > ttl { + hdr.Ttl = ttl + } + } else { + if hdr.Ttl < ttl { + hdr.Ttl = ttl + } + } + } + } +} + +func uint16Conv(u uint16, m map[uint16]string) string { + if s, ok := m[u]; ok { + return s + } + return strconv.Itoa(int(u)) +} + +func QclassToString(u uint16) string { + return uint16Conv(u, dns.ClassToString) +} + +func QtypeToString(u uint16) string { + return uint16Conv(u, dns.TypeToString) +} + +func GenEmptyReply(q *dns.Msg, rcode int) *dns.Msg { + r := new(dns.Msg) + r.SetRcode(q, rcode) + + var name string + if len(q.Question) > 1 { + name = q.Question[0].Name + } else { + name = "." + } + + r.Ns = []dns.RR{FakeSOA(name)} + return r +} + +func FakeSOA(name string) *dns.SOA { + return &dns.SOA{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + Ttl: 300, + }, + Ns: "fake-ns.mosdns.fake.root.", + Mbox: "fake-mbox.mosdns.fake.root.", + Serial: 2021110400, + Refresh: 1800, + Retry: 900, + Expire: 604800, + Minttl: 86400, + } +} diff --git a/pkg/dnsutils/net_io.go b/pkg/dnsutils/net_io.go new file mode 100644 index 0000000..b46cc09 --- /dev/null +++ b/pkg/dnsutils/net_io.go @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package dnsutils + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/miekg/dns" +) + +const ( + DnsHeaderLen = 12 // minimum dns msg size +) + +var ( + ErrPayloadTooSmall = errors.New("payload is to small for a valid dns msg") +) + +// ReadRawMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed +// with a two byte length field). +// n represents how many bytes are read from c. +// The returned the *[]byte should be released by pool.ReleaseBuf. +func ReadRawMsgFromTCP(c io.Reader) (*[]byte, error) { + h := pool.GetBuf(2) + defer pool.ReleaseBuf(h) + _, err := io.ReadFull(c, *h) + + if err != nil { + return nil, err + } + + // dns length + length := binary.BigEndian.Uint16(*h) + if length <= DnsHeaderLen { + return nil, ErrPayloadTooSmall + } + + b := pool.GetBuf(int(length)) + _, err = io.ReadFull(c, *b) + if err != nil { + pool.ReleaseBuf(b) + return nil, err + } + return b, nil +} + +// ReadMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed +// with a two byte length field). +// n represents how many bytes are read from c. +func ReadMsgFromTCP(c io.Reader) (*dns.Msg, int, error) { + b, err := ReadRawMsgFromTCP(c) + if err != nil { + return nil, 0, err + } + defer pool.ReleaseBuf(b) + + m, err := unpackMsgWithDetailedErr(*b) + return m, len(*b) + 2, err +} + +// WriteMsgToTCP packs and writes m to c in RFC 1035 format. +// n represents how many bytes are written to c. +func WriteMsgToTCP(c io.Writer, m *dns.Msg) (n int, err error) { + buf, err := pool.PackTCPBuffer(m) + if err != nil { + return 0, err + } + defer pool.ReleaseBuf(buf) + return c.Write(*buf) +} + +// WriteRawMsgToTCP See WriteMsgToTCP +func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) { + if len(b) > dns.MaxMsgSize { + return 0, fmt.Errorf("payload length %d is greater than dns max msg size", len(b)) + } + + buf := pool.GetBuf(len(b) + 2) + defer pool.ReleaseBuf(buf) + + binary.BigEndian.PutUint16((*buf)[:2], uint16(len(b))) + copy((*buf)[2:], b) + return c.Write((*buf)) +} + +func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) { + b, err := pool.PackBuffer(m) + if err != nil { + return 0, err + } + defer pool.ReleaseBuf(b) + return c.Write(*b) +} + +func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) { + if bufSize < dns.MinMsgSize { + bufSize = dns.MinMsgSize + } + + b := pool.GetBuf(bufSize) + defer pool.ReleaseBuf(b) + n, err := c.Read(*b) + if err != nil { + return nil, n, err + } + + m, err := unpackMsgWithDetailedErr((*b)[:n]) + return m, n, err +} + +func unpackMsgWithDetailedErr(b []byte) (*dns.Msg, error) { + m := new(dns.Msg) + if err := m.Unpack(b); err != nil { + return nil, fmt.Errorf("failed to unpack msg [%x], %w", b, err) + } + return m, nil +} diff --git a/pkg/dnsutils/ptr_parser.go b/pkg/dnsutils/ptr_parser.go new file mode 100644 index 0000000..8d6d5f7 --- /dev/null +++ b/pkg/dnsutils/ptr_parser.go @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package dnsutils + +import ( + "errors" + "fmt" + "net/netip" + "strconv" + "strings" +) + +var errNotPTRDomain = errors.New("domain does not has a ptr suffix") + +const ( + IP4arpa = ".in-addr.arpa." + IP6arpa = ".ip6.arpa." +) + +// ParsePTRQName returns the ip that a PTR query name contains. +func ParsePTRQName(fqdn string) (netip.Addr, error) { + switch { + case strings.HasSuffix(fqdn, IP4arpa): + return reverse4(fqdn[:len(fqdn)-len(IP4arpa)]) + case strings.HasSuffix(fqdn, IP6arpa): + return reverse6(fqdn[:len(fqdn)-len(IP6arpa)]) + default: + return netip.Addr{}, errNotPTRDomain + } +} + +func reverse4(s string) (netip.Addr, error) { + var buf [4]byte + l := 0 + for offset := len(s); offset > 0 && l < len(buf); l++ { + var label string + label, offset = prevLabel(s, offset) + n, err := strconv.ParseUint(label, 10, 8) + if err != nil { + return netip.Addr{}, fmt.Errorf("invaild bit, %w", err) + } + buf[l] = byte(n) + } + if l < len(buf) { + return netip.Addr{}, fmt.Errorf("expact at least 3 labels, got %d", l) + } + return netip.AddrFrom4(buf), nil +} + +func reverse6(s string) (netip.Addr, error) { + var buf [16]byte + var val byte + var tail bool + var l int + for offset := len(s); offset > 0 && l < len(buf); { + var label string + label, offset = prevLabel(s, offset) + if len(label) != 1 { + return netip.Addr{}, fmt.Errorf("invalid label %s", label) + } + b := label[0] + n, ok := hex2byte(b) + if !ok { + return netip.Addr{}, fmt.Errorf("invaild bit %d", b) + } + if tail { + buf[l] = val<<4 + n + l++ + tail = false + } else { + val = n + tail = true + } + } + if l < len(buf) { + return netip.Addr{}, fmt.Errorf("expact at least 16 bytes, got %d", l) + } + return netip.AddrFrom16(buf), nil +} + +func hex2byte(c byte) (byte, bool) { + lower := func(c byte) byte { + return c | ('x' - 'X') + } + var b byte + switch { + case '0' <= c && c <= '9': + b = c - '0' + case 'a' <= lower(c) && lower(c) <= 'z': + b = lower(c) - 'a' + 10 + default: + return 0, false + } + return b, true +} + +func prevLabel(s string, offset int) (string, int) { + for { + s = s[:offset] + n := strings.LastIndexByte(s, '.') + label := s[n+1 : offset] + if n != -1 && len(label) == 0 { + offset = n + continue + } + return label, n + } +} diff --git a/pkg/dnsutils/ptr_parser_test.go b/pkg/dnsutils/ptr_parser_test.go new file mode 100644 index 0000000..fdb8f96 --- /dev/null +++ b/pkg/dnsutils/ptr_parser_test.go @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package dnsutils + +import ( + "net/netip" + "reflect" + "testing" +) + +func Test_reverse4(t *testing.T) { + tests := []struct { + name string + s string + want netip.Addr + wantErr bool + }{ + {"v4", "4.4.8.8", netip.MustParseAddr("8.8.4.4"), false}, + {"v4_with_prefix", "prefix.4.4.8.8", netip.MustParseAddr("8.8.4.4"), false}, + {"invalid_format", "123114123", netip.Addr{}, true}, + {"invalid_format", "12..311..4123..", netip.Addr{}, true}, + {"invalid_format", "...", netip.Addr{}, true}, + {"short_length", "4.8.8", netip.Addr{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := reverse4(tt.s) + if (err != nil) != tt.wantErr { + t.Errorf("reverse4() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("reverse4() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_reverse6(t *testing.T) { + tests := []struct { + name string + s string + want netip.Addr + wantErr bool + }{ + {"v6", "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.MustParseAddr("2001:db8::567:89ab"), false}, + {"v6_with_prefix", "prefix.b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.MustParseAddr("2001:db8::567:89ab"), false}, + {"invalid_format", "123114123", netip.Addr{}, true}, + {"invalid_format", "..123...", netip.Addr{}, true}, + {"short_length", "0.0.0.0.0.0.8.b.d.0.1.0.0.2", netip.Addr{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := reverse6(tt.s) + if (err != nil) != tt.wantErr { + t.Errorf("reverse6() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("reverse4() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/hosts/hosts.go b/pkg/hosts/hosts.go new file mode 100644 index 0000000..7f05598 --- /dev/null +++ b/pkg/hosts/hosts.go @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package hosts + +import ( + "errors" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/miekg/dns" + "net/netip" + "strings" +) + +type Hosts struct { + matcher domain.Matcher[*IPs] +} + +// NewHosts creates a hosts using m. +func NewHosts(m domain.Matcher[*IPs]) *Hosts { + return &Hosts{ + matcher: m, + } +} + +func (h *Hosts) Lookup(fqdn string) (ipv4, ipv6 []netip.Addr) { + ips, ok := h.matcher.Match(fqdn) + if !ok { + return nil, nil // no such host + } + return ips.IPv4, ips.IPv6 +} + +func (h *Hosts) LookupMsg(m *dns.Msg) *dns.Msg { + if len(m.Question) != 1 { + return nil + } + q := m.Question[0] + typ := q.Qtype + fqdn := q.Name + if q.Qclass != dns.ClassINET || (typ != dns.TypeA && typ != dns.TypeAAAA) { + return nil + } + + ipv4, ipv6 := h.Lookup(fqdn) + if len(ipv4)+len(ipv6) == 0 { + return nil // no such host + } + + r := new(dns.Msg) + r.SetReply(m) + switch { + case typ == dns.TypeA && len(ipv4) > 0: + for _, ip := range ipv4 { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: fqdn, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 10, + }, + A: ip.AsSlice(), + } + r.Answer = append(r.Answer, rr) + } + case typ == dns.TypeAAAA && len(ipv6) > 0: + for _, ip := range ipv6 { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: fqdn, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 10, + }, + AAAA: ip.AsSlice(), + } + r.Answer = append(r.Answer, rr) + } + } + + // Append fake SOA record for empty reply. + if len(r.Answer) == 0 { + r.Ns = []dns.RR{dnsutils.FakeSOA(fqdn)} + } + return r +} + +type IPs struct { + IPv4 []netip.Addr + IPv6 []netip.Addr +} + +var _ domain.ParseStringFunc[*IPs] = ParseIPs + +func ParseIPs(s string) (string, *IPs, error) { + f := strings.Fields(s) + if len(f) == 0 { + return "", nil, errors.New("empty string") + } + + pattern := f[0] + v := new(IPs) + for _, ipStr := range f[1:] { + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return "", nil, fmt.Errorf("invalid ip addr %s, %w", ipStr, err) + } + + if ip.Is4() { // is ipv4 + v.IPv4 = append(v.IPv4, ip) + } else { // is ipv6 + v.IPv6 = append(v.IPv6, ip) + } + } + + return pattern, v, nil +} diff --git a/pkg/hosts/hosts_test.go b/pkg/hosts/hosts_test.go new file mode 100644 index 0000000..f7cf5ff --- /dev/null +++ b/pkg/hosts/hosts_test.go @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package hosts + +import ( + "bytes" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/miekg/dns" + "net" + "testing" +) + +var test_hosts = ` +# comment + # empty line +dns.google 8.8.8.8 8.8.4.4 2001:4860:4860::8844 2001:4860:4860::8888 +regexp:^123456789 192.168.1.1 +test.com 1.2.3.4 # will be replaced +test.com 2.3.4.5 +# nxdomain.com 1.2.3.4 +` + +func Test_hostsContainer_Match(t *testing.T) { + m := domain.NewMixMatcher[*IPs]() + m.SetDefaultMatcher(domain.MatcherDomain) + err := domain.LoadFromTextReader[*IPs](m, bytes.NewBuffer([]byte(test_hosts)), ParseIPs) + if err != nil { + t.Fatal(err) + } + h := NewHosts(m) + + type args struct { + name string + typ uint16 + } + tests := []struct { + name string + args args + wantMatched bool + wantAddr []string + }{ + {"matched A", args{name: "dns.google.", typ: dns.TypeA}, true, []string{"8.8.8.8", "8.8.4.4"}}, + {"matched AAAA", args{name: "dns.google.", typ: dns.TypeAAAA}, true, []string{"2001:4860:4860::8844", "2001:4860:4860::8888"}}, + {"not matched A", args{name: "nxdomain.com.", typ: dns.TypeA}, false, nil}, + {"not matched A", args{name: "sub.dns.google.", typ: dns.TypeA}, false, nil}, + {"matched regexp A", args{name: "123456789.test.", typ: dns.TypeA}, true, []string{"192.168.1.1"}}, + {"not matched regexp A", args{name: "0123456789.test.", typ: dns.TypeA}, false, nil}, + {"test replacement", args{name: "test.com.", typ: dns.TypeA}, true, []string{"2.3.4.5"}}, + {"test matched domain with mismatched type", args{name: "test.com.", typ: dns.TypeAAAA}, true, nil}, + } + for _, tt := range tests { + q := new(dns.Msg) + q.SetQuestion(tt.args.name, tt.args.typ) + + t.Run(tt.name, func(t *testing.T) { + r := h.LookupMsg(q) + if tt.wantMatched && r == nil { + t.Fatal("Lookup() should not return a nil result") + } + + for _, s := range tt.wantAddr { + wantIP := net.ParseIP(s) + if wantIP == nil { + t.Fatal("invalid test case addr") + } + found := false + for _, rr := range r.Answer { + var ip net.IP + switch rr := rr.(type) { + case *dns.A: + ip = rr.A + case *dns.AAAA: + ip = rr.AAAA + default: + continue + } + if ip.Equal(wantIP) { + found = true + break + } + } + if !found { + t.Fatal("wanted ip is not found in response") + } + } + }) + } +} diff --git a/pkg/list/elem.go b/pkg/list/elem.go new file mode 100644 index 0000000..4b9d25e --- /dev/null +++ b/pkg/list/elem.go @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package list + +type Elem[V any] struct { + prev, next *Elem[V] + list *List[V] + + Value V +} + +func NewElem[V any](v V) *Elem[V] { + return &Elem[V]{ + Value: v, + } +} + +func (e *Elem[V]) Prev() *Elem[V] { + return e.prev +} + +func (e *Elem[V]) Next() *Elem[V] { + return e.next +} diff --git a/pkg/list/list.go b/pkg/list/list.go new file mode 100644 index 0000000..a354bfe --- /dev/null +++ b/pkg/list/list.go @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package list + +type List[V any] struct { + front, back *Elem[V] + length int +} + +func New[V any]() *List[V] { + return &List[V]{} +} + +func mustBeFreeElem[V any](e *Elem[V]) { + if e.prev != nil || e.next != nil || e.list != nil { + panic("element is in use") + } +} + +func (l *List[V]) Front() *Elem[V] { + return l.front +} + +func (l *List[V]) Back() *Elem[V] { + return l.back +} + +func (l *List[V]) Len() int { + return l.length +} + +func (l *List[V]) PushFront(e *Elem[V]) *Elem[V] { + mustBeFreeElem(e) + l.length++ + e.list = l + if l.front == nil { + l.front = e + l.back = e + } else { + e.next = l.front + l.front.prev = e + l.front = e + } + return e +} + +func (l *List[V]) PushBack(e *Elem[V]) *Elem[V] { + mustBeFreeElem(e) + l.length++ + e.list = l + if l.back == nil { + l.front = e + l.back = e + } else { + e.prev = l.back + l.back.next = e + l.back = e + } + return e +} + +func (l *List[V]) PopElem(e *Elem[V]) *Elem[V] { + if e.list != l { + panic("elem is not belong to this list") + } + + l.length-- + if p := e.prev; p != nil { + p.next = e.next + } + if n := e.next; n != nil { + n.prev = e.prev + } + if e == l.front { + l.front = e.next + } + if e == l.back { + l.back = e.prev + } + e.prev, e.next, e.list = nil, nil, nil + return e +} diff --git a/pkg/list/list_test.go b/pkg/list/list_test.go new file mode 100644 index 0000000..6819a95 --- /dev/null +++ b/pkg/list/list_test.go @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package list + +import ( + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +func checkLinkPointers[V any](t *testing.T, l *List[V]) { + t.Helper() + + e := l.front + for e != nil { + if (e.next != nil && e.next.prev != e) || (e.prev != nil && e.prev.next != e) { + t.Fatal("broken list") + } + e = e.next + } +} + +func makeElems(n []int) []*Elem[int] { + s := make([]*Elem[int], 0, len(n)) + for _, i := range n { + s = append(s, NewElem(i)) + } + return s +} + +func allValue[V any](l *List[V]) []V { + s := make([]V, 0) + node := l.front + for node != nil { + s = append(s, node.Value) + node = node.next + } + return s +} + +func TestList_Push(t *testing.T) { + l := new(List[int]) + l.PushBack(NewElem(1)) + l.PushBack(NewElem(2)) + assert.Equal(t, []int{1, 2}, allValue(l)) + checkLinkPointers(t, l) + + l = new(List[int]) + l.PushFront(NewElem(1)) + l.PushFront(NewElem(2)) + assert.Equal(t, []int{2, 1}, allValue(l)) + checkLinkPointers(t, l) +} + +func TestList_PopElem(t *testing.T) { + tests := []struct { + name string + in []int + pop int + want int + wantList []int + }{ + {"pop front", []int{0, 1, 2}, 0, 0, []int{1, 2}}, + {"pop mid", []int{0, 1, 2}, 1, 1, []int{0, 2}}, + {"pop back", []int{0, 1, 2}, 2, 2, []int{0, 1}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := new(List[int]) + le := makeElems(tt.in) + for _, e := range le { + l.PushBack(e) + } + checkLinkPointers(t, l) + + got := l.PopElem(le[tt.pop]) + checkLinkPointers(t, l) + + if got.Value != tt.want { + t.Errorf("PopElem() = %v, want %v", got.Value, tt.want) + } + if !reflect.DeepEqual(allValue(l), tt.wantList) { + t.Errorf("allValue() = %v, want %v", allValue(l), tt.wantList) + } + }) + } +} diff --git a/pkg/lru/lru.go b/pkg/lru/lru.go new file mode 100644 index 0000000..885fb0c --- /dev/null +++ b/pkg/lru/lru.go @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package lru + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/list" +) + +type LRU[K comparable, V any] struct { + maxSize int + onEvict func(key K, v V) + + l *list.List[KV[K, V]] + m map[K]*list.Elem[KV[K, V]] +} + +type KV[K comparable, V any] struct { + key K + v V +} + +func NewLRU[K comparable, V any](maxSize int, onEvict func(key K, v V)) *LRU[K, V] { + if maxSize <= 0 { + panic(fmt.Sprintf("LRU: invalid max size: %d", maxSize)) + } + + return &LRU[K, V]{ + maxSize: maxSize, + onEvict: onEvict, + l: list.New[KV[K, V]](), + m: make(map[K]*list.Elem[KV[K, V]]), + } +} + +func (q *LRU[K, V]) Add(key K, v V) { + if e, ok := q.m[key]; ok { // update existed key + e.Value.v = v + q.l.PushBack(q.l.PopElem(e)) + return + } + + o := q.Len() - q.maxSize + 1 + for o > 0 { + key, v, _ := q.PopOldest() + if q.onEvict != nil { + q.onEvict(key, v) + } + o-- + } + + e := list.NewElem(KV[K, V]{ + key: key, + v: v, + }) + q.m[key] = e + q.l.PushBack(e) +} + +func (q *LRU[K, V]) Del(key K) { + e := q.m[key] + if e != nil { + q.delElem(e) + } +} + +func (q *LRU[K, V]) delElem(e *list.Elem[KV[K, V]]) { + key, v := e.Value.key, e.Value.v + q.l.PopElem(e) + delete(q.m, key) + if q.onEvict != nil { + q.onEvict(key, v) + } +} + +func (q *LRU[K, V]) PopOldest() (key K, v V, ok bool) { + e := q.l.Front() + if e != nil { + q.l.PopElem(e) + key, v = e.Value.key, e.Value.v + delete(q.m, key) + ok = true + return + } + return +} + +func (q *LRU[K, V]) Clean(f func(key K, v V) (remove bool)) (removed int) { + e := q.l.Front() + for e != nil { + next := e.Next() // Delete e will clean its pointers. Save it first. + key, v := e.Value.key, e.Value.v + if remove := f(key, v); remove { + q.delElem(e) + removed++ + } + e = next + } + return removed +} + +func (q *LRU[K, V]) Flush() { + q.l = list.New[KV[K, V]]() + q.m = make(map[K]*list.Elem[KV[K, V]]) +} + +func (q *LRU[K, V]) Get(key K) (v V, ok bool) { + e, ok := q.m[key] + if !ok { + return + } + q.l.PushBack(q.l.PopElem(e)) + return e.Value.v, true +} + +func (q *LRU[K, V]) Len() int { + return q.l.Len() +} diff --git a/pkg/lru/lru_test.go b/pkg/lru/lru_test.go new file mode 100644 index 0000000..7af72f3 --- /dev/null +++ b/pkg/lru/lru_test.go @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package lru + +import ( + "testing" +) + +func Test_lru(t *testing.T) { + var q *LRU[int, int] + reset := func(maxSize int) { + t.Helper() + q = NewLRU[int, int](maxSize, nil) + } + + add := func(keys ...int) { + t.Helper() + for _, key := range keys { + q.Add(key, key) + } + } + + mustGet := func(keys ...int) { + t.Helper() + for _, key := range keys { + gotV, ok := q.Get(key) + if !ok || gotV != key { + t.Fatalf("want %v, got %v", key, gotV) + } + } + } + + emptyGet := func(keys ...int) { + t.Helper() + for _, key := range keys { + gotV, ok := q.Get(key) + if ok { + t.Fatalf("want empty, got %v", gotV) + } + } + } + + mustPopOldest := func(keys ...int) { + t.Helper() + for _, key := range keys { + gotKey, gotV, ok := q.PopOldest() + if !ok { + t.Fatal() + } + if gotKey != key || gotV != gotKey { + t.Fatalf("want key: %v, v: %v, got key: %v, v:%v", key, key, gotKey, gotV) + } + } + } + + emptyPop := func() { + t.Helper() + gotKey, gotV, ok := q.PopOldest() + if ok { + t.Fatalf("want empty result, got key: %v, v:%v", gotKey, gotV) + } + } + + checkLen := func(want int) { + t.Helper() + if q.l.Len() != len(q.m) { + t.Fatalf("possible mem leak: q.l.Len() %v != len(q.m){ %v", q.l.Len(), len(q.m)) + } + if want != q.Len() { + t.Fatalf("want %v, got %v", want, q.Len()) + } + } + + // test add + reset(4) + add(1, 1, 1, 1, 1, 1, 2, 3) + checkLen(3) + mustGet(1, 2, 3) + + // test add overflow + reset(2) + add(1, 2, 3, 4, 5) + checkLen(2) + mustGet(4, 5) + emptyGet(1, 2, 3) + + // test pop + reset(3) + add(1, 2, 3) + mustPopOldest(1, 2, 3) + checkLen(0) + emptyPop() + + // test del + reset(3) + add(1, 2, 3) + q.Del(2) + q.Del(9999) + mustPopOldest(1, 3) + + // test clean + reset(4) + add(1, 2, 3, 4) + cleanFunc := func(key int, v int) (remove bool) { + switch key { + case 1, 3: + return true + } + return false + } + if cleaned := q.Clean(cleanFunc); cleaned != 2 { + t.Fatalf("q.Clean want cleaned = 2, got %v", cleaned) + } + mustPopOldest(2, 4) + + // test lru + reset(4) + add(1, 2, 3, 4) // 1 2 3 4 + mustGet(2, 3) // 1 4 2 3 + mustPopOldest(1, 4, 2, 3) +} diff --git a/pkg/matcher/domain/interface.go b/pkg/matcher/domain/interface.go new file mode 100644 index 0000000..485161f --- /dev/null +++ b/pkg/matcher/domain/interface.go @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package domain + +// "fqdn-insensitive" means the domain in Add() and Match() call +// is fqdn-insensitive. "google.com" and "google.com." will get +// the same outcome. +// The logic for case-insensitive is the same as above. + +type Matcher[T any] interface { + // Match matches the domain s. + // s could be a fqdn or not, and should be case-insensitive. + Match(s string) (v T, ok bool) +} + +type WriteableMatcher[T any] interface { + Matcher[T] + Add(pattern string, v T) error +} diff --git a/pkg/matcher/domain/load_helper.go b/pkg/matcher/domain/load_helper.go new file mode 100644 index 0000000..c5e44f2 --- /dev/null +++ b/pkg/matcher/domain/load_helper.go @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package domain + +import ( + "bufio" + "errors" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "io" + "strings" + "unicode" +) + +// ParseStringFunc parse data string to matcher pattern and additional attributions. +type ParseStringFunc[T any] func(s string) (pattern string, v T, err error) + +// check if s only contains a domain pattern (no other section, no space). +func patternOnly[T any](s string) (pattern string, v T, err error) { + if strings.IndexFunc(s, unicode.IsSpace) != -1 { + return "", v, errors.New("rule string has more than one section") + } + return s, v, nil +} + +// Load loads data from a string, parsing it with parseString function. +func Load[T any](m WriteableMatcher[T], s string, parseString ParseStringFunc[T]) error { + if parseString == nil { + parseString = patternOnly[T] + } + pattern, v, err := parseString(s) + if err != nil { + return err + } + return m.Add(pattern, v) +} + +// LoadFromTextReader loads multiple lines from reader r. r +func LoadFromTextReader[T any](m WriteableMatcher[T], r io.Reader, parseString ParseStringFunc[T]) error { + lineCounter := 0 + scanner := bufio.NewScanner(r) + for scanner.Scan() { + lineCounter++ + s := scanner.Text() + s = utils.RemoveComment(s, "#") + s = strings.TrimSpace(s) + if len(s) == 0 { + continue + } + + err := Load(m, s, parseString) + if err != nil { + return fmt.Errorf("line %d: %v", lineCounter, err) + } + } + return scanner.Err() +} + +func NewDomainMixMatcher() *MixMatcher[struct{}] { + mixMatcher := NewMixMatcher[struct{}]() + mixMatcher.SetDefaultMatcher(MatcherDomain) + return mixMatcher +} diff --git a/pkg/matcher/domain/matcher.go b/pkg/matcher/domain/matcher.go new file mode 100644 index 0000000..6ba41d3 --- /dev/null +++ b/pkg/matcher/domain/matcher.go @@ -0,0 +1,275 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package domain + +import ( + "errors" + "fmt" + "regexp" + "strings" + + "github.com/IrineSistiana/mosdns/v5/pkg/utils" +) + +var _ WriteableMatcher[any] = (*MixMatcher[any])(nil) +var _ WriteableMatcher[any] = (*SubDomainMatcher[any])(nil) +var _ WriteableMatcher[any] = (*FullMatcher[any])(nil) +var _ WriteableMatcher[any] = (*KeywordMatcher[any])(nil) +var _ WriteableMatcher[any] = (*RegexMatcher[any])(nil) + +type SubDomainMatcher[T any] struct { + root *labelNode[T] +} + +func NewSubDomainMatcher[T any]() *SubDomainMatcher[T] { + return &SubDomainMatcher[T]{root: new(labelNode[T])} +} + +func (m *SubDomainMatcher[T]) Match(s string) (T, bool) { + s = NormalizeDomain(s) + ds := NewReverseDomainScanner(s) + currentNode := m.root + v, ok := currentNode.getValue() + for ds.Scan() { + label := ds.NextLabel() + if nextNode := currentNode.getChild(label); nextNode != nil { + if nextNode.hasValue() { + v, ok = nextNode.getValue() + } + currentNode = nextNode + } else { + break + } + } + return v, ok +} + +func (m *SubDomainMatcher[T]) Len() int { + return m.root.len() +} + +func (m *SubDomainMatcher[T]) Add(s string, v T) error { + s = NormalizeDomain(s) + ds := NewReverseDomainScanner(s) + currentNode := m.root + for ds.Scan() { + label := ds.NextLabel() + if child := currentNode.getChild(label); child != nil { + currentNode = child + } else { + currentNode = currentNode.newChild(label) + } + } + currentNode.storeValue(v) + return nil +} + +type FullMatcher[T any] struct { + m map[string]T // string in is map must be a normalized domain (See NormalizeDomain). +} + +func NewFullMatcher[T any]() *FullMatcher[T] { + return &FullMatcher[T]{ + m: make(map[string]T), + } +} + +// Add adds domain s to this matcher, s can be a fqdn or not. +func (m *FullMatcher[T]) Add(s string, v T) error { + s = NormalizeDomain(s) + m.m[s] = v + return nil +} + +func (m *FullMatcher[T]) Match(s string) (v T, ok bool) { + s = NormalizeDomain(s) + v, ok = m.m[s] + return +} + +func (m *FullMatcher[T]) Len() int { + return len(m.m) +} + +type KeywordMatcher[T any] struct { + kws map[string]T +} + +func NewKeywordMatcher[T any]() *KeywordMatcher[T] { + return &KeywordMatcher[T]{ + kws: make(map[string]T), + } +} + +func (m *KeywordMatcher[T]) Add(keyword string, v T) error { + keyword = NormalizeDomain(keyword) // fqdn-insensitive and case-insensitive + m.kws[keyword] = v + return nil +} + +func (m *KeywordMatcher[T]) Match(s string) (v T, ok bool) { + s = NormalizeDomain(s) + for k, v := range m.kws { + if strings.Contains(s, k) { + return v, true + } + } + return v, false +} + +func (m *KeywordMatcher[T]) Len() int { + return len(m.kws) +} + +// RegexMatcher contains regexp rules. +// Note: the regexp rule is expect to match a lower-case non fqdn. +type RegexMatcher[T any] struct { + regs map[string]*regElem[T] +} + +type regElem[T any] struct { + reg *regexp.Regexp + v T +} + +func NewRegexMatcher[T any]() *RegexMatcher[T] { + return &RegexMatcher[T]{regs: make(map[string]*regElem[T])} +} + +func (m *RegexMatcher[T]) Add(expr string, v T) error { + e := m.regs[expr] + if e == nil { + reg, err := regexp.Compile(expr) + if err != nil { + return err + } + m.regs[expr] = ®Elem[T]{ + reg: reg, + v: v, + } + } else { + e.v = v + } + return nil +} + +func (m *RegexMatcher[T]) Match(s string) (v T, ok bool) { + s = NormalizeDomain(s) + for _, e := range m.regs { + if e.reg.MatchString(s) { + return e.v, true + } + } + var zeroT T + return zeroT, false +} + +func (m *RegexMatcher[T]) Len() int { + return len(m.regs) +} + +const ( + MatcherFull = "full" + MatcherDomain = "domain" + MatcherRegexp = "regexp" + MatcherKeyword = "keyword" +) + +type MixMatcher[T any] struct { + defaultMatcher string + + full *FullMatcher[T] + domain *SubDomainMatcher[T] + regex *RegexMatcher[T] + keyword *KeywordMatcher[T] +} + +func NewMixMatcher[T any]() *MixMatcher[T] { + return &MixMatcher[T]{ + full: NewFullMatcher[T](), + domain: NewSubDomainMatcher[T](), + regex: NewRegexMatcher[T](), + keyword: NewKeywordMatcher[T](), + } +} + +func (m *MixMatcher[T]) SetDefaultMatcher(s string) { + m.defaultMatcher = s +} + +func (m *MixMatcher[T]) GetSubMatcher(typ string) WriteableMatcher[T] { + switch typ { + case MatcherFull: + return m.full + case MatcherDomain: + return m.domain + case MatcherRegexp: + return m.regex + case MatcherKeyword: + return m.keyword + } + return nil +} + +var ErrNodefaultMatcher = errors.New("default matcher is not set") + +func (m *MixMatcher[T]) Add(s string, v T) error { + typ, pattern := m.splitTypeAndPattern(s) + if len(typ) == 0 { + if len(m.defaultMatcher) != 0 { + typ = m.defaultMatcher + } else { + return ErrNodefaultMatcher + } + } + sm := m.GetSubMatcher(typ) + if sm == nil { + return fmt.Errorf("unsupported match type [%s]", typ) + } + return sm.Add(pattern, v) +} + +func (m *MixMatcher[T]) Match(s string) (v T, ok bool) { + for _, matcher := range [...]Matcher[T]{m.full, m.domain, m.regex, m.keyword} { + if v, ok = matcher.Match(s); ok { + return v, true + } + } + return +} + +func (m *MixMatcher[T]) Len() int { + sum := 0 + for _, matcher := range [...]interface{ Len() int }{m.full, m.domain, m.regex, m.keyword} { + if matcher == nil { + continue + } + sum += matcher.Len() + } + return sum +} + +func (m *MixMatcher[T]) splitTypeAndPattern(s string) (string, string) { + typ, pattern, ok := utils.SplitString2(s, ":") + if !ok { + pattern = s + } + return typ, pattern +} diff --git a/pkg/matcher/domain/matcher_test.go b/pkg/matcher/domain/matcher_test.go new file mode 100644 index 0000000..a623c09 --- /dev/null +++ b/pkg/matcher/domain/matcher_test.go @@ -0,0 +1,200 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package domain + +import ( + "reflect" + "testing" +) + +func assertFunc[T any](t *testing.T, m Matcher[T]) func(domain string, wantBool bool, wantV any) { + return func(domain string, wantBool bool, wantV any) { + t.Helper() + v, ok := m.Match(domain) + if ok != wantBool { + t.Fatalf("%s, wantBool = %v, got = %v", domain, wantBool, ok) + } + + if !reflect.DeepEqual(v, wantV) { + t.Fatalf("%s, wantV = %v, got = %v", domain, wantV, v) + } + } +} + +type aStr struct { + s string +} + +func s(str string) *aStr { + return &aStr{s: str} +} + +func (a *aStr) Append(v any) { + a.s = a.s + v.(*aStr).s +} + +func TestDomainMatcher(t *testing.T) { + m := NewSubDomainMatcher[any]() + add := func(domain string, v any) { + m.Add(domain, v) + } + assert := assertFunc[any](t, m) + + add("cn", nil) + assertInt(t, 1, m.Len()) + assert("cn", true, nil) + assert("a.cn.", true, nil) + assert("a.com", false, nil) + add("a.b.com", nil) + assertInt(t, 2, m.Len()) + assert("a.b.com.", true, nil) + assert("q.w.e.r.a.b.com.", true, nil) + assert("b.com.", false, nil) + + // test replace + add("append", 0) + assertInt(t, 3, m.Len()) + assert("append.", true, 0) + add("append.", 1) + assert("append.", true, 1) + add("append", nil) + assert("append.", true, nil) + + // test sub domain + add("sub", 1) + assertInt(t, 4, m.Len()) + add("a.sub", 2) + assertInt(t, 5, m.Len()) + assert("sub", true, 1) + assert("b.sub", true, 1) + assert("a.sub", true, 2) + assert("a.a.sub", true, 2) + + // test case-insensitive + add("UPpER", 1) + assert("LowER.Upper", true, 1) + + // root match + add(".", 9) + assert("any.domain", true, 9) +} + +func assertInt(t testing.TB, want, got int) { + t.Helper() + if want != got { + t.Errorf("assertion failed: want %d, got %d", want, got) + } +} + +func Test_FullMatcher(t *testing.T) { + m := NewFullMatcher[any]() + assert := assertFunc[any](t, m) + add := func(domain string, v any) { + m.Add(domain, v) + } + + add("cn", nil) + assert("cn", true, nil) + assert("a.cn", false, nil) + add("test.test", nil) + assert("test.test", true, nil) + assert("test.a.test", false, nil) + + // test replace + add("append", 0) + assert("append", true, 0) + add("append", 1) + assert("append", true, 1) + add("append", nil) + assert("append", true, nil) + + assertInt(t, m.Len(), 3) + + // test case-insensitive + add("UPpER", 1) + assert("Upper", true, 1) +} + +func Test_KeywordMatcher(t *testing.T) { + m := NewKeywordMatcher[any]() + add := func(domain string, v any) { + m.Add(domain, v) + } + + assert := assertFunc[any](t, m) + + add("123", s("a")) + assert("123456.cn", true, s("a")) + assert("111123.com", true, s("a")) + assert("111111.cn", false, nil) + add("example.com", nil) + assert("sub.example.com", true, nil) + assert("example_sub.com", false, nil) + + // test replace + add("append", 0) + assert("append", true, 0) + add("append", 1) + assert("append", true, 1) + add("append", nil) + assert("append", true, nil) + + assertInt(t, m.Len(), 3) + + // test case-insensitive + add("UPpER", 1) + assert("L.Upper.U", true, 1) +} + +func Test_RegexMatcher(t *testing.T) { + m := NewRegexMatcher[any]() + add := func(expr string, v any, wantErr bool) { + err := m.Add(expr, v) + if (err != nil) != wantErr { + t.Fatalf("%s: want err %v, got %v", expr, wantErr, err != nil) + } + } + + assert := assertFunc[any](t, m) + + expr := "^github-production-release-asset-[0-9a-za-z]{6}\\.s3\\.amazonaws\\.com$" + add(expr, nil, false) + assert("github-production-release-asset-000000.s3.amazonaws.com", true, nil) + assert("github-production-release-asset-aaaaaa.s3.amazonaws.com", true, nil) + assert("github-production-release-asset-aa.s3.amazonaws.com", false, nil) + assert("prefix_github-production-release-asset-000000.s3.amazonaws.com", false, nil) + assert("github-production-release-asset-000000.s3.amazonaws.com.suffix", false, nil) + + expr = "^example" + add(expr, nil, false) + assert("example.com", true, nil) + assert("sub.example.com", false, nil) + + // test replace + add("append", 0, false) + assert("append", true, 0) + add("append", 1, false) + assert("append", true, 1) + add("append", nil, false) + assert("append", true, nil) + + expr = "*" + add(expr, nil, true) +} diff --git a/pkg/matcher/domain/utils.go b/pkg/matcher/domain/utils.go new file mode 100644 index 0000000..d4844b1 --- /dev/null +++ b/pkg/matcher/domain/utils.go @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package domain + +import ( + "strings" +) + +type ReverseDomainScanner struct { + s string // not fqdn + p int + t int +} + +func NewReverseDomainScanner(s string) *ReverseDomainScanner { + s = TrimDot(s) + return &ReverseDomainScanner{ + s: s, + p: len(s), + t: len(s), + } +} + +func (s *ReverseDomainScanner) Scan() bool { + if s.p <= 0 { + return false + } + s.t = s.p + s.p = strings.LastIndexByte(s.s[:s.p], '.') + return true +} + +func (s *ReverseDomainScanner) NextLabelOffset() int { + return s.p + 1 +} + +func (s *ReverseDomainScanner) NextLabel() (label string) { + return s.s[s.p+1 : s.t] +} + +// NormalizeDomain normalize domain string s. +// It removes the suffix "." and make sure the domain is in lower case. +// e.g. a fqdn "GOOGLE.com." will become "google.com" +func NormalizeDomain(s string) string { + return strings.ToLower(TrimDot(s)) +} + +// TrimDot trims suffix '.' +func TrimDot(s string) string { + if len(s) >= 1 && s[len(s)-1] == '.' { + s = s[:len(s)-1] + } + return s +} + +// labelNode can store dns labels. +type labelNode[T any] struct { + children map[string]*labelNode[T] // lazy init + + v T + hasV bool +} + +func (n *labelNode[T]) storeValue(v T) { + n.v = v + n.hasV = true +} + +func (n *labelNode[T]) getValue() (T, bool) { + return n.v, n.hasV +} + +func (n *labelNode[T]) hasValue() bool { + return n.hasV +} + +func (n *labelNode[T]) newChild(key string) *labelNode[T] { + if n.children == nil { + n.children = make(map[string]*labelNode[T]) + } + node := new(labelNode[T]) + n.children[key] = node + return node +} + +func (n *labelNode[T]) getChild(key string) *labelNode[T] { + return n.children[key] +} + +func (n *labelNode[T]) len() int { + l := 0 + for _, node := range n.children { + l += node.len() + if node.hasValue() { + l++ + } + } + return l +} diff --git a/pkg/matcher/domain/utils_test.go b/pkg/matcher/domain/utils_test.go new file mode 100644 index 0000000..e093352 --- /dev/null +++ b/pkg/matcher/domain/utils_test.go @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package domain + +import ( + "reflect" + "testing" +) + +func TestDomainScanner(t *testing.T) { + tests := []struct { + name string + fqdn string + wantOffsets []int + wantLabels []string + }{ + {"empty", "", []int{}, []string{}}, + {"root", ".", []int{}, []string{}}, + {"non fqdn", "a.2", []int{2, 0}, []string{"2", "a"}}, + {"domain", "1.2.3.", []int{4, 2, 0}, []string{"3", "2", "1"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewReverseDomainScanner(tt.fqdn) + gotOffsets := make([]int, 0) + for s.Scan() { + gotOffsets = append(gotOffsets, s.NextLabelOffset()) + } + if !reflect.DeepEqual(gotOffsets, tt.wantOffsets) { + t.Errorf("PrevLabelOffset() = %v, want %v", gotOffsets, tt.wantOffsets) + } + + s = NewReverseDomainScanner(tt.fqdn) + gotLabels := make([]string, 0) + for s.Scan() { + pl := s.NextLabel() + gotLabels = append(gotLabels, pl) + } + if !reflect.DeepEqual(gotLabels, tt.wantLabels) { + t.Errorf("PrevLabel() = %v, want %v", gotLabels, tt.wantLabels) + } + }) + } +} diff --git a/pkg/matcher/netlist/interface.go b/pkg/matcher/netlist/interface.go new file mode 100644 index 0000000..9aea6da --- /dev/null +++ b/pkg/matcher/netlist/interface.go @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package netlist + +import ( + "net/netip" +) + +type Matcher interface { + Match(addr netip.Addr) bool +} diff --git a/pkg/matcher/netlist/list.go b/pkg/matcher/netlist/list.go new file mode 100644 index 0000000..0151342 --- /dev/null +++ b/pkg/matcher/netlist/list.go @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package netlist + +import ( + "fmt" + "net/netip" + "sort" +) + +// List is a list of netip.Prefix. It stores all netip.Prefix in one single slice +// and use binary search. +// It is suitable for large static cidr search. +type List struct { + // stores valid and masked netip.Prefix(s) + e []netip.Prefix + sorted bool +} + +// NewList returns a *List. +func NewList() *List { + return &List{ + e: make([]netip.Prefix, 0), + } +} + +func mustValid(l []netip.Prefix) { + for i, prefix := range l { + if !prefix.IsValid() { + panic(fmt.Sprintf("invalid prefix at #%d", i)) + } + } +} + +// Append appends new netip.Prefix(s) to the list. +// This modified the list. Caller must call List.Sort() before calling List.Contains() +func (list *List) Append(newNet ...netip.Prefix) { + for i, n := range newNet { + addr := to6(n.Addr()) + bits := n.Bits() + if n.Addr().Is4() { + bits += 96 + } + newNet[i] = netip.PrefixFrom(addr, bits).Masked() + } + mustValid(newNet) + list.e = append(list.e, newNet...) + list.sorted = false +} + +// Sort sorts the list, this must be called after +// list being modified and before calling List.Contains(). +func (list *List) Sort() { + if list.sorted { + return + } + + sort.Sort(list) + out := make([]netip.Prefix, 0) + for i, n := range list.e { + if i == 0 { + out = append(out, n) + } else { + lv := &out[len(out)-1] + switch { + case n.Addr() == lv.Addr(): + if n.Bits() < lv.Bits() { + *lv = n + } + case !lv.Contains(n.Addr()): + out = append(out, n) + } + } + + } + + list.e = out + list.sorted = true +} + +// Len implements sort Interface. +func (list *List) Len() int { + return len(list.e) +} + +// Less implements sort Interface. +func (list *List) Less(i, j int) bool { + return list.e[i].Addr().Less(list.e[j].Addr()) +} + +// Swap implements sort Interface. +func (list *List) Swap(i, j int) { + list.e[i], list.e[j] = list.e[j], list.e[i] +} + +func (list *List) Match(addr netip.Addr) bool { + return list.Contains(addr) +} + +// Contains reports whether the list includes the given netip.Addr. +func (list *List) Contains(addr netip.Addr) bool { + if !list.sorted { + panic("list is not sorted") + } + if !addr.IsValid() { + return false + } + + addr = to6(addr) + + i, j := 0, len(list.e) + for i < j { + h := int(uint(i+j) >> 1) // avoid overflow when computing h + if list.e[h].Addr().Compare(addr) <= 0 { + i = h + 1 + } else { + j = h + } + } + + if i == 0 { + return false + } + + return list.e[i-1].Contains(addr) +} + +func to6(addr netip.Addr) netip.Addr { + if addr.Is6() { + return addr + } + return netip.AddrFrom16(addr.As16()) +} diff --git a/pkg/matcher/netlist/load_helper.go b/pkg/matcher/netlist/load_helper.go new file mode 100644 index 0000000..3be17d3 --- /dev/null +++ b/pkg/matcher/netlist/load_helper.go @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package netlist + +import ( + "bufio" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "io" + "net/netip" + "strings" +) + +// LoadFromReader loads IP list from a reader. +// It might modify the List and causes List unsorted. +func LoadFromReader(l *List, reader io.Reader) error { + scanner := bufio.NewScanner(reader) + + // count how many lines we have read. + lineCounter := 0 + for scanner.Scan() { + lineCounter++ + s := scanner.Text() + s = strings.TrimSpace(s) + s = utils.RemoveComment(s, "#") + s = utils.RemoveComment(s, " ") + if len(s) == 0 { + continue + } + err := LoadFromText(l, s) + if err != nil { + return fmt.Errorf("invalid data at line #%d: %w", lineCounter, err) + } + } + return scanner.Err() +} + +// LoadFromText loads an IP from s. +// It might modify the List and causes List unsorted. +func LoadFromText(l *List, s string) error { + if strings.ContainsRune(s, '/') { + ipNet, err := netip.ParsePrefix(s) + if err != nil { + return err + } + l.Append(ipNet) + return nil + } + + addr, err := netip.ParseAddr(s) + if err != nil { + return err + } + bits := 32 + if addr.Is6() { + bits = 128 + } + l.Append(netip.PrefixFrom(addr, bits)) + return nil +} diff --git a/pkg/matcher/netlist/netlist_test.go b/pkg/matcher/netlist/netlist_test.go new file mode 100644 index 0000000..4e1dddd --- /dev/null +++ b/pkg/matcher/netlist/netlist_test.go @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package netlist + +import ( + "bytes" + "net/netip" + "testing" +) + +func TestIPNetList_Sort_And_Merge(t *testing.T) { + raw := ` +192.168.0.0/32 # merged +192.168.0.0/24 # merged +192.168.0.0/16 +192.168.1.1/24 # merged +192.168.9.24/24 # merged +192.168.3.0/24 # merged +192.169.0.0/16 +104.16.0.0/12 +` + ipNetList := NewList() + err := LoadFromReader(ipNetList, bytes.NewBufferString(raw)) + if err != nil { + t.Fatal(err) + } + ipNetList.Sort() + + if ipNetList.Len() != 3 { + t.Fatalf("unexpected length %d", ipNetList.Len()) + } + + tests := []struct { + name string + testIP netip.Addr + want bool + }{ + {"0", netip.MustParseAddr("192.167.255.255"), false}, + {"1", netip.MustParseAddr("192.168.0.0"), true}, + {"2", netip.MustParseAddr("192.168.1.1"), true}, + {"3", netip.MustParseAddr("192.168.9.255"), true}, + {"4", netip.MustParseAddr("192.168.255.255"), true}, + {"5", netip.MustParseAddr("192.169.1.1"), true}, + {"6", netip.MustParseAddr("192.170.1.1"), false}, + {"7", netip.MustParseAddr("1.1.1.1"), false}, + {"8", netip.MustParseAddr("104.16.67.38"), true}, + {"9", netip.MustParseAddr("104.32.67.38"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ipNetList.Match(tt.testIP); got != tt.want { + t.Errorf("IPNetList.Match() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIPNetList_New_And_Contains(t *testing.T) { + raw := ` +# comment line +1.0.0.0/24 additional strings should be ignored +2.0.0.0/23 # comment +3.0.0.0 + +2000:0000::/32 +2000:2000::1 +` + + ipNetList := NewList() + err := LoadFromReader(ipNetList, bytes.NewBufferString(raw)) + if err != nil { + t.Fatal(err) + } + ipNetList.Sort() + + tests := []struct { + name string + testIP netip.Addr + want bool + }{ + {"", netip.MustParseAddr("1.0.0.0"), true}, + {"", netip.MustParseAddr("1.0.0.1"), true}, + {"", netip.MustParseAddr("1.0.1.0"), false}, + {"", netip.MustParseAddr("2.0.0.0"), true}, + {"", netip.MustParseAddr("2.0.1.255"), true}, + {"", netip.MustParseAddr("2.0.2.0"), false}, + {"", netip.MustParseAddr("3.0.0.0"), true}, + {"", netip.MustParseAddr("2000:0000::"), true}, + {"", netip.MustParseAddr("2000:0000::1"), true}, + {"", netip.MustParseAddr("2000:0000:1::"), true}, + {"", netip.MustParseAddr("2000:0001::"), false}, + {"", netip.MustParseAddr("2000:2000::1"), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ipNetList.Match(tt.testIP); got != tt.want { + t.Errorf("IPNetList.Match() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/nftset_utils/handler.go b/pkg/nftset_utils/handler.go new file mode 100644 index 0000000..3b79afb --- /dev/null +++ b/pkg/nftset_utils/handler.go @@ -0,0 +1,154 @@ +//go:build linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package nftset_utils + +import ( + "errors" + "fmt" + "net/netip" + "sync" + "time" + + "github.com/google/nftables" + "go4.org/netipx" +) + +var ( + ErrClosed = errors.New("closed handler") +) + +// NftSetHandler can add netip.Prefix to the corresponding set. +// The table that contains this set must be an inet family table. +// If the set has a 'interval' flag, the prefix from netip.Prefix will be +// applied. +type NftSetHandler struct { + opts HandlerOpts + + m sync.Mutex + closed bool + lastUpdate time.Time + set *nftables.Set + lastingConn *nftables.Conn // Note: lasting conn is not concurrent safe so m is required. + + disableSetCache bool // for test only +} + +type HandlerOpts struct { + TableFamily nftables.TableFamily + TableName string + SetName string +} + +// NewNtSetHandler inits NftSetHandler. +func NewNtSetHandler(opts HandlerOpts) *NftSetHandler { + return &NftSetHandler{ + opts: opts, + } +} + +// getSetLocked get set info from kernel. It has an internal cache and won't +// invoke a syscall every time. +func (h *NftSetHandler) getSetLocked() (*nftables.Set, error) { + const refreshInterval = time.Second + + now := time.Now() + if !h.disableSetCache && h.set != nil && now.Sub(h.lastUpdate) < refreshInterval { + return h.set, nil + } + + // Note: GetSetByName is not concurrent safe. + set, err := h.lastingConn.GetSetByName(&nftables.Table{Name: h.opts.TableName, Family: h.opts.TableFamily}, h.opts.SetName) + if err != nil { + return nil, err + } + h.set = set + h.lastUpdate = now + return set, nil +} + +// AddElems adds netip.Prefix(s) to set in a single batch. +func (h *NftSetHandler) AddElems(es ...netip.Prefix) error { + h.m.Lock() + defer h.m.Unlock() + + if h.closed { + return ErrClosed + } + + if h.lastingConn == nil { + c, err := nftables.New(nftables.AsLasting()) + if err != nil { + return fmt.Errorf("failed to open netlink, %w", err) + } + h.lastingConn = c + } + + set, err := h.getSetLocked() + if err != nil { + return fmt.Errorf("failed to get set, %w", err) + } + + var elems []nftables.SetElement + if set.Interval { + elems = make([]nftables.SetElement, 0, 2*len(es)) + } else { + elems = make([]nftables.SetElement, 0, len(es)) + } + + for i, e := range es { + if !e.IsValid() { + return fmt.Errorf("invalid prefix at index %d", i) + } + if set.Interval { + start := e.Masked().Addr() + elems = append(elems, nftables.SetElement{Key: start.AsSlice(), IntervalEnd: false}) + + end := netipx.PrefixLastIP(e).Next() // may be invalid if end is overflowed + if end.IsValid() { + elems = append(elems, nftables.SetElement{Key: end.AsSlice(), IntervalEnd: true}) + } + } else { + elems = append(elems, nftables.SetElement{Key: e.Addr().AsSlice()}) + } + } + + err = h.lastingConn.SetAddElements(set, elems) + if err != nil { + return err + } + return h.lastingConn.Flush() +} + +func (h *NftSetHandler) Close() error { + h.m.Lock() + defer h.m.Unlock() + + if h.closed { + return nil + } + + h.closed = true + if h.lastingConn != nil { + return h.lastingConn.CloseLasting() + } + return nil +} diff --git a/pkg/nftset_utils/handler_test.go b/pkg/nftset_utils/handler_test.go new file mode 100644 index 0000000..e6b8464 --- /dev/null +++ b/pkg/nftset_utils/handler_test.go @@ -0,0 +1,96 @@ +//go:build linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package nftset_utils + +import ( + "github.com/google/nftables" + "net/netip" + "os" + "sync" + "testing" +) + +func skipCI(t *testing.T) { + if os.Getenv("TEST_NFTSET") == "" { + t.SkipNow() + } +} + +func prepareSet(t testing.TB, tableName, setName string, interval bool) { + t.Helper() + nc, err := nftables.New() + if err != nil { + t.Fatal(err) + } + + table := &nftables.Table{Name: tableName, Family: nftables.TableFamilyINet} + nc.AddTable(table) + if err := nc.AddSet(&nftables.Set{Name: setName, Table: table, KeyType: nftables.TypeIPAddr, Interval: interval}, nil); err != nil { + t.Fatal(err) + } + if err := nc.Flush(); err != nil { + t.Fatal(err) + } +} + +func Test_AddElems(t *testing.T) { + skipCI(t) + n := "test" + prepareSet(t, n, n, false) + + h := NewNtSetHandler(HandlerOpts{ + TableFamily: nftables.TableFamilyINet, + TableName: n, + SetName: n, + }) + h.disableSetCache = true + + if err := h.AddElems(netip.MustParsePrefix("127.0.0.1/24")); err != nil { + t.Fatal(err) + } + + nc, err := nftables.New() + if err != nil { + t.Fatal(err) + } + elems, err := nc.GetSetElements(h.set) + if err != nil { + t.Fatal(err) + } + if len(elems) == 0 { + t.Fatal("set is empty") + } + + // test concurrent safe. + wg := new(sync.WaitGroup) + for i := 0; i < 512; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := h.AddElems(netip.MustParsePrefix("127.0.0.1/24")); err != nil { + t.Error(err) + return + } + }() + } + wg.Wait() +} diff --git a/pkg/pool/allocator.go b/pkg/pool/allocator.go new file mode 100644 index 0000000..f774692 --- /dev/null +++ b/pkg/pool/allocator.go @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package pool + +import ( + bytesPool "github.com/IrineSistiana/go-bytes-pool" +) + +var ( + _pool = bytesPool.NewPool(20) // 1Mb pool, should be enough. + GetBuf = _pool.Get + ReleaseBuf = _pool.Release +) diff --git a/pkg/pool/bytes_buf.go b/pkg/pool/bytes_buf.go new file mode 100644 index 0000000..40c8418 --- /dev/null +++ b/pkg/pool/bytes_buf.go @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package pool + +import ( + "bytes" + "fmt" + "sync" +) + +type BytesBufPool struct { + p sync.Pool +} + +func NewBytesBufPool(initSize int) *BytesBufPool { + if initSize < 0 { + panic(fmt.Sprintf("utils.NewBytesBufPool: negative init size %d", initSize)) + } + + return &BytesBufPool{ + p: sync.Pool{New: func() any { + b := new(bytes.Buffer) + b.Grow(initSize) + return b + }}, + } +} + +func (p *BytesBufPool) Get() *bytes.Buffer { + return p.p.Get().(*bytes.Buffer) +} + +func (p *BytesBufPool) Release(b *bytes.Buffer) { + b.Reset() + p.p.Put(b) +} diff --git a/pkg/pool/msg_buf.go b/pkg/pool/msg_buf.go new file mode 100644 index 0000000..845759b --- /dev/null +++ b/pkg/pool/msg_buf.go @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package pool + +import ( + "encoding/binary" + "fmt" + + "github.com/miekg/dns" +) + +// There is no such way to give dns.Msg.PackBuffer() a buffer +// with a proper size. +// Just give it a big buf and hope the buf will be reused in most scenes. +const packBufferSize = 8191 + +// PackBuffer packs the dns msg m to wire format. +// Callers should release the buf by calling ReleaseBuf after they have done +// with the wire []byte. +func PackBuffer(m *dns.Msg) (*[]byte, error) { + packBuf := GetBuf(packBufferSize) + defer ReleaseBuf(packBuf) + wire, err := m.PackBuffer(*packBuf) + if err != nil { + return nil, err + } + + msgBuf := GetBuf(len(wire)) + copy(*msgBuf, wire) + return msgBuf, nil +} + +// PackBuffer packs the dns msg m to wire format, with to bytes length header. +// Callers should release the buf by calling ReleaseBuf. +func PackTCPBuffer(m *dns.Msg) (*[]byte, error) { + packBuf := GetBuf(packBufferSize) + defer ReleaseBuf(packBuf) + wire, err := m.PackBuffer((*packBuf)[2:]) + if err != nil { + return nil, err + } + + l := len(wire) + if l > dns.MaxMsgSize { + return nil, fmt.Errorf("dns payload size %d is too large", l) + } + + msgBuf := GetBuf(2 + len(wire)) + binary.BigEndian.PutUint16(*msgBuf, uint16(l)) + copy((*msgBuf)[2:], wire) + return msgBuf, nil +} diff --git a/pkg/pool/timer.go b/pkg/pool/timer.go new file mode 100644 index 0000000..56616d4 --- /dev/null +++ b/pkg/pool/timer.go @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package pool + +import ( + "sync" + "time" +) + +var ( + timerPool = sync.Pool{} +) + +func GetTimer(t time.Duration) *time.Timer { + timer, ok := timerPool.Get().(*time.Timer) + if !ok { + return time.NewTimer(t) + } + if timer.Reset(t) { + panic("dispatcher.go getTimer: active timer trapped in timerPool") + } + return timer +} + +func ReleaseTimer(timer *time.Timer) { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerPool.Put(timer) +} + +func ResetAndDrainTimer(timer *time.Timer, d time.Duration) { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(d) +} diff --git a/pkg/query_context/context.go b/pkg/query_context/context.go new file mode 100644 index 0000000..dee0ae1 --- /dev/null +++ b/pkg/query_context/context.go @@ -0,0 +1,312 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package query_context + +import ( + "sync/atomic" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/miekg/dns" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +const ( + edns0Size = 1200 +) + +// Context is a query context that pass through plugins. +// All Context funcs are not safe for concurrent use. +type Context struct { + id uint32 + startTime time.Time + + // ServerMeta contains some meta info from the server. + // It is read-only. + ServerMeta ServerMeta + query *dns.Msg // always has one question. + clientOpt *dns.OPT // may be nil + + resp *dns.Msg + respOpt *dns.OPT // nil if clientOpt == nil + upstreamOpt *dns.OPT // may be nil + + // lazy init. + kv map[uint32]any + marks map[uint32]struct{} +} + +var contextUid atomic.Uint32 + +type ServerMeta = server.QueryMeta + +// NewContext creates a new query Context. +// q must have one question. +// NewContext takes the ownership of q. +func NewContext(q *dns.Msg) *Context { + ctx := &Context{ + id: contextUid.Add(1), + startTime: time.Now(), + query: q, + clientOpt: addNewAndSwapOldOpt(q), + } + if ctx.clientOpt != nil { + ctx.respOpt = newOpt() + + // RFC 3225 3 + // The DO bit of the query MUST be copied in the response. + if ctx.clientOpt.Do() { + setDo(ctx.respOpt, true) + } + } + return ctx +} + +// Id returns the Context id. +// Note: This id is not the dns msg id. +// It's a unique uint32 growing with the number of query. +func (ctx *Context) Id() uint32 { + return ctx.id +} + +// StartTime returns the time when the Context was created. +func (ctx *Context) StartTime() time.Time { + return ctx.startTime +} + +// Q returns the query msg that will be forward to upstream. +// It always returns a non-nil msg with one question and EDNS0 OPT. +// If Caller want to modify the msg, be sure not to break those conditions. +func (ctx *Context) Q() *dns.Msg { + return ctx.query +} + +// QQuestion returns the query question. +func (ctx *Context) QQuestion() dns.Question { + return ctx.query.Question[0] +} + +// QOpt returns the query opt. It always returns a non-nil opt. +// It's a helper func for searching opt in Q() manually. +func (ctx *Context) QOpt() *dns.OPT { + opt := findOpt(ctx.query) + ctx.query.IsEdns0() + if opt == nil { + panic("query opt is missing") + } + return opt +} + +// ClientOpt returns the OPT rr from client. Maybe nil, if client does not send it. +// Plugins that responsible for handling EDNS0 option should +// check ClientOpt and pick/add options into Q() on demand. +// The OPT is read-only. +func (ctx *Context) ClientOpt() *dns.OPT { + return ctx.clientOpt +} + +// SetResponse sets m as response. It takes the ownership of m. +// If m is nil. It removes existing response. +func (ctx *Context) SetResponse(m *dns.Msg) { + ctx.resp = m + if m == nil { + ctx.upstreamOpt = nil + } else { + ctx.upstreamOpt = popOpt(m) + } +} + +// R returns the response that will be sent to client. It might be nil. +// Note: R does not have EDNS0. Caller MUST NOT add a dns.OPT into R. +// Use RespOpt() instead. +func (ctx *Context) R() *dns.Msg { + return ctx.resp +} + +// RespOpt returns the OPT that will be sent to client. +// If client support EDNS0, then RespOpt always returns a non-nil OPT. +// No matter what R() returns. +// Otherwise, RespOpt returns nil. +func (ctx *Context) RespOpt() *dns.OPT { + return ctx.respOpt +} + +// UpstreamOpt returns the OPT from upstream. May be nil. +// Plugins that responsible for handling EDNS0 option should +// check UpstreamOpt and pick/add options into RespOpt on demand. +// The OPT is read-only. +func (ctx *Context) UpstreamOpt() *dns.OPT { + return ctx.upstreamOpt +} + +// InfoField returns a zap.Field contains a brief summary of this Context. +// Useful in log. +func (ctx *Context) InfoField() zap.Field { + return zap.Object("query", ctx) +} + +// Copy deep copies this Context. +// See CopyTo. +func (ctx *Context) Copy() *Context { + newCtx := new(Context) + ctx.CopyTo(newCtx) + return newCtx +} + +// CopyTo deep copies this Context to d. +// Note that values that stored by StoreValue is not deep-copied. +func (ctx *Context) CopyTo(d *Context) *Context { + d.id = ctx.id + d.startTime = ctx.startTime + + d.ServerMeta = ctx.ServerMeta + d.query = ctx.query.Copy() + d.clientOpt = ctx.clientOpt + + if ctx.resp != nil { + d.resp = ctx.resp.Copy() + } + if ctx.respOpt != nil { + d.respOpt = dns.Copy(ctx.respOpt).(*dns.OPT) + } + d.upstreamOpt = ctx.upstreamOpt + + d.kv = copyMap(ctx.kv) + d.marks = copyMap(ctx.marks) + return d +} + +// StoreValue stores any v in to this Context +// k MUST from RegKey. +func (ctx *Context) StoreValue(k uint32, v any) { + if ctx.kv == nil { + ctx.kv = make(map[uint32]any) + } + ctx.kv[k] = v +} + +// GetValue returns the value stored by StoreValue. +func (ctx *Context) GetValue(k uint32) (any, bool) { + v, ok := ctx.kv[k] + return v, ok +} + +// DeleteValue deletes value k from Context +func (ctx *Context) DeleteValue(k uint32) { + delete(ctx.kv, k) +} + +// SetMark marks this Context with given mark. +func (ctx *Context) SetMark(m uint32) { + if ctx.marks == nil { + ctx.marks = make(map[uint32]struct{}) + } + ctx.marks[m] = struct{}{} +} + +// HasMark reports whether this mark m was marked by SetMark. +func (ctx *Context) HasMark(m uint32) bool { + _, ok := ctx.marks[m] + return ok +} + +// DeleteMark deletes mark m from this Context. +func (ctx *Context) DeleteMark(m uint32) { + delete(ctx.marks, m) +} + +// MarshalLogObject implements zapcore.ObjectMarshaler. +func (ctx *Context) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + encoder.AddUint32("uqid", ctx.id) + + if clientAddr := ctx.ServerMeta.ClientAddr; clientAddr.IsValid() { + zap.Stringer("client", clientAddr).AddTo(encoder) + } + + question := ctx.query.Question[0] + encoder.AddString("qname", question.Name) + encoder.AddUint16("qtype", question.Qtype) + encoder.AddUint16("qclass", question.Qclass) + + if r := ctx.resp; r != nil { + encoder.AddInt("rcode", r.Rcode) + } + encoder.AddDuration("elapsed", time.Since(ctx.startTime)) + return nil +} + +func copyMap[K comparable, V any](m map[K]V) map[K]V { + if m == nil { + return nil + } + cm := make(map[K]V, len(m)) + for k, v := range m { + cm[k] = v + } + return cm +} + +func addNewAndSwapOldOpt(m *dns.Msg) *dns.OPT { + for i := len(m.Extra) - 1; i >= 0; i-- { + // If m has oldOpt + if oldOpt, ok := m.Extra[i].(*dns.OPT); ok { + // replace it directly + m.Extra[i] = newOpt() + return oldOpt + } + } + m.Extra = append(m.Extra, newOpt()) + return nil +} + +func popOpt(m *dns.Msg) *dns.OPT { + for i := len(m.Extra) - 1; i >= 0; i-- { + if opt, ok := m.Extra[i].(*dns.OPT); ok { + m.Extra = append(m.Extra[:i], m.Extra[i+1:]...) + return opt + } + } + return nil +} + +func findOpt(m *dns.Msg) *dns.OPT { + for i := len(m.Extra) - 1; i >= 0; i-- { + if opt, ok := m.Extra[i].(*dns.OPT); ok { + return opt + } + } + return nil +} + +func newOpt() *dns.OPT { + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + opt.SetUDPSize(edns0Size) + return opt +} + +func setDo(opt *dns.OPT, do bool) { + const doBit = 1 << 15 // DNSSEC OK + if do { + opt.Hdr.Ttl |= doBit + } +} diff --git a/pkg/query_context/kv.go b/pkg/query_context/kv.go new file mode 100644 index 0000000..0d5824d --- /dev/null +++ b/pkg/query_context/kv.go @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package query_context + +import "sync/atomic" + +var kId atomic.Uint32 + +// RegKey returns a unique uint32 for the key used in +// Context.StoreValue, Context.GetValue. +// It should only be called during initialization. +func RegKey() uint32 { + i := kId.Add(1) + if i == 0 { + panic("key id overflowed") + } + return i +} diff --git a/pkg/rate_limiter/rate_limiter.go b/pkg/rate_limiter/rate_limiter.go new file mode 100644 index 0000000..41bad99 --- /dev/null +++ b/pkg/rate_limiter/rate_limiter.go @@ -0,0 +1,153 @@ +package rate_limiter + +import ( + "net/netip" + "sync" + "time" + + "golang.org/x/time/rate" +) + +const ( + tableShards = 32 + gcInterval = time.Minute +) + +type Limiter struct { + // Limit and Burst are read-only. + Limit rate.Limit + Burst int + + closeOnce sync.Once + closeNotify chan struct{} + tables [tableShards]*tableShard +} + +type tableShard struct { + m sync.Mutex + table map[netip.Addr]*limiterEntry +} + +type limiterEntry struct { + l *rate.Limiter + lastSeen time.Time + sync.Once +} + +// NewRateLimiter creates a new client rate limiter. +// limit and burst should be greater than zero. See rate.Limiter for more +// details. +// Limiter has a internal gc which will run and remove old client entries every 1m. +// If the token refill time (burst/limit) is greater than 1m, +// the actual average qps limit may be higher than expected because the client status +// may be deleted and re-initialized. +func NewRateLimiter(limit rate.Limit, burst int) *Limiter { + l := &Limiter{ + Limit: limit, + Burst: burst, + closeNotify: make(chan struct{}), + } + + for i := range l.tables { + l.tables[i] = &tableShard{table: make(map[netip.Addr]*limiterEntry)} + } + + go l.gcLoop(gcInterval) + return l +} + +// maskedUnmappedP must be a masked prefix and contain a unmapped addr. +func (l *Limiter) Allow(unmappedAddr netip.Addr) bool { + now := time.Now() + shard := l.getTableShard(unmappedAddr) + shard.m.Lock() + e, ok := shard.table[unmappedAddr] + if !ok { + e = &limiterEntry{ + l: rate.NewLimiter(l.Limit, l.Burst), + lastSeen: now, + } + shard.table[unmappedAddr] = e + } + e.lastSeen = now + shard.m.Unlock() + clientLimiter := e.l + return clientLimiter.AllowN(now, 1) +} + +func (l *Limiter) Close() error { + l.closeOnce.Do(func() { + close(l.closeNotify) + }) + return nil +} + +func (l *Limiter) gcLoop(gcInterval time.Duration) { + ticker := time.NewTicker(gcInterval) + defer ticker.Stop() + + for { + select { + case <-l.closeNotify: + return + case now := <-ticker.C: + l.doGc(now, gcInterval) + } + } +} + +func (l *Limiter) doGc(now time.Time, gcInterval time.Duration) { + for _, shard := range l.tables { + shard.m.Lock() + for a, e := range shard.table { + if now.Sub(e.lastSeen) > gcInterval { + delete(shard.table, a) + } + } + shard.m.Unlock() + } +} + +func (l *Limiter) getTableShard(unmappedAddr netip.Addr) *tableShard { + return l.tables[getTableShardIdx(unmappedAddr)] +} + +func (l *Limiter) ForEach(doFunc func(unmappedAddr netip.Addr, r *rate.Limiter) (doBreak bool)) (doBreak bool) { + for _, shard := range l.tables { + shard.m.Lock() + for a, e := range shard.table { + doBreak = doFunc(a, e.l) + if doBreak { + shard.m.Unlock() + return + } + } + shard.m.Unlock() + } + return false +} + +// Len returns current number of entries in the Limiter. +func (l *Limiter) Len() int { + n := 0 + for _, shard := range l.tables { + shard.m.Lock() + n += len(shard.table) + shard.m.Unlock() + } + return n +} + +func getTableShardIdx(unmappedAddr netip.Addr) int { + var i byte + if unmappedAddr.Is4() { + for _, b := range unmappedAddr.As4() { + i ^= b + } + } else { + for _, b := range unmappedAddr.As16() { + i ^= b + } + } + return int(i % tableShards) +} diff --git a/pkg/rate_limiter/rate_limiter_test.go b/pkg/rate_limiter/rate_limiter_test.go new file mode 100644 index 0000000..e29317e --- /dev/null +++ b/pkg/rate_limiter/rate_limiter_test.go @@ -0,0 +1,20 @@ +package rate_limiter + +import ( + "testing" + "time" + + "golang.org/x/time/rate" +) + +func BenchmarkXxx(b *testing.B) { + now := time.Now() + var l *limiterEntry + for i := 0; i < b.N; i++ { + l = &limiterEntry{ + l: rate.NewLimiter(0, 0), + lastSeen: now, + } + } + _ = l +} diff --git a/pkg/safe_close/safe_close.go b/pkg/safe_close/safe_close.go new file mode 100644 index 0000000..51406ad --- /dev/null +++ b/pkg/safe_close/safe_close.go @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package safe_close + +import "sync" + +// SafeClose can achieve safe close where WaitClosed returns only after +// all sub goroutines exited. +// +// 1. Main service goroutine starts and wait on ReceiveCloseSignal. +// 2. Any service's sub goroutine should be started by Attach and wait on ReceiveCloseSignal. +// 3. If any fatal err occurs, any service goroutine can call SendCloseSignal to close the service. +// 4. Any third party caller can call SendCloseSignal to close the service. +type SafeClose struct { + m sync.Mutex + wg sync.WaitGroup + closeSignal chan struct{} + closeErr error +} + +func NewSafeClose() *SafeClose { + return &SafeClose{ + closeSignal: make(chan struct{}), + } +} + +// WaitClosed waits until all SendCloseSignal is called and all +// attached funcs in SafeClose are done. +func (s *SafeClose) WaitClosed() error { + <-s.closeSignal + s.wg.Wait() + return s.closeErr +} + +// SendCloseSignal sends a close signal. Unblock WaitClosed. +// The given error will be read by WaitClosed. +// Once SendCloseSignal is called, following calls are noop. +func (s *SafeClose) SendCloseSignal(err error) { + s.m.Lock() + select { + case <-s.closeSignal: + default: + s.closeErr = err + close(s.closeSignal) + } + s.m.Unlock() +} + +func (s *SafeClose) ReceiveCloseSignal() <-chan struct{} { + return s.closeSignal +} + +// Attach add this goroutine to s.wg WaitClosed. +// f must receive closeSignal and call done when it is done. +// If s was closed, f will not run. +func (s *SafeClose) Attach(f func(done func(), closeSignal <-chan struct{})) { + s.m.Lock() + select { + case <-s.closeSignal: + default: + s.wg.Add(1) + go func() { + f(s.wg.Done, s.closeSignal) + }() + } + s.m.Unlock() +} diff --git a/pkg/server/doq.go b/pkg/server/doq.go new file mode 100644 index 0000000..e6240a1 --- /dev/null +++ b/pkg/server/doq.go @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/quic-go/quic-go" + "go.uber.org/zap" +) + +const ( + defaultQuicIdleTimeout = time.Second * 30 + streamReadTimeout = time.Second * 2 + quicFirstReadTimeout = time.Second * 2 +) + +type DoQServerOpts struct { + Logger *zap.Logger + IdleTimeout time.Duration +} + +// ServeDoQ starts a server at l. It returns if l had an Accept() error. +// It always returns a non-nil error. +func ServeDoQ(l *quic.Listener, h Handler, opts DoQServerOpts) error { + logger := opts.Logger + if logger == nil { + logger = nopLogger + } + idleTimeout := opts.IdleTimeout + if idleTimeout <= 0 { + idleTimeout = defaultQuicIdleTimeout + } + + listenerCtx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errListenerCtxCanceled) + for { + c, err := l.Accept(listenerCtx) + if err != nil { + return fmt.Errorf("unexpected listener err: %w", err) + } + + // handle connection + connCtx, cancelConn := context.WithCancelCause(listenerCtx) + go func() { + defer c.CloseWithError(0, "") + defer cancelConn(errConnectionCtxCanceled) + + var clientAddr netip.Addr + ta, ok := c.RemoteAddr().(*net.UDPAddr) + if ok { + clientAddr = ta.AddrPort().Addr() + } + + firstRead := true + for { + var streamAcceptTimeout time.Duration + if firstRead { + firstRead = false + streamAcceptTimeout = quicFirstReadTimeout + } else { + streamAcceptTimeout = idleTimeout + } + streamAcceptCtx, cancelStreamAccept := context.WithTimeout(connCtx, streamAcceptTimeout) + stream, err := c.AcceptStream(streamAcceptCtx) + cancelStreamAccept() + if err != nil { + return + } + + // Handle stream. + // For doq, one stream, one query. + go func() { + defer func() { + stream.Close() + stream.CancelRead(0) // TODO: Needs a proper error code. + }() + // Avoid fragmentation attack. + stream.SetReadDeadline(time.Now().Add(streamReadTimeout)) + req, _, err := dnsutils.ReadMsgFromTCP(stream) + if err != nil { + return + } + queryMeta := QueryMeta{ + ClientAddr: clientAddr, + ServerName: c.ConnectionState().TLS.ServerName, + } + + resp := h.Handle(connCtx, req, queryMeta, pool.PackTCPBuffer) + if resp == nil { + return + } + if _, err := stream.Write(*resp); err != nil { + logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err)) + } + }() + } + }() + } +} diff --git a/pkg/server/http_handler.go b/pkg/server/http_handler.go new file mode 100644 index 0000000..5a41314 --- /dev/null +++ b/pkg/server/http_handler.go @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import ( + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "net/netip" + "strings" + + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +type HttpHandlerOpts struct { + // GetSrcIPFromHeader specifies the header that contain client source address. + // e.g. "X-Forwarded-For". + GetSrcIPFromHeader string + + // Logger specifies the logger which Handler writes its log to. + // Default is a nop logger. + Logger *zap.Logger +} + +type HttpHandler struct { + dnsHandler Handler + logger *zap.Logger + srcIPHeader string +} + +var _ http.Handler = (*HttpHandler)(nil) + +func NewHttpHandler(h Handler, opts HttpHandlerOpts) *HttpHandler { + hh := new(HttpHandler) + hh.dnsHandler = h + hh.srcIPHeader = opts.GetSrcIPFromHeader + hh.logger = opts.Logger + if hh.logger == nil { + hh.logger = nopLogger + } + return hh +} + +func (h *HttpHandler) warnErr(req *http.Request, msg string, err error) { + h.logger.Warn(msg, zap.String("from", req.RemoteAddr), zap.String("method", req.Method), zap.String("url", req.RequestURI), zap.Error(err)) +} + +func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + addrPort, err := netip.ParseAddrPort(req.RemoteAddr) + if err != nil { + h.logger.Error("failed to parse request remote addr", zap.String("addr", req.RemoteAddr), zap.Error(err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + clientAddr := addrPort.Addr() + + // read remote addr from header + if header := h.srcIPHeader; len(header) != 0 { + if xff := req.Header.Get(header); len(xff) != 0 { + addr, err := readClientAddrFromXFF(xff) + if err != nil { + h.warnErr(req, "failed to get client ip from header", fmt.Errorf("failed to prase header %s: %s, %s", header, xff, err)) + w.WriteHeader(http.StatusBadRequest) + return + } + clientAddr = addr + } + } + + // read msg + q, err := ReadMsgFromReq(req) + if err != nil { + h.warnErr(req, "invalid request", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + queryMeta := QueryMeta{ + ClientAddr: clientAddr, + } + if u := req.URL; u != nil { + queryMeta.UrlPath = u.Path + } + if tlsStat := req.TLS; tlsStat != nil { + queryMeta.ServerName = tlsStat.ServerName + } + resp := h.dnsHandler.Handle(req.Context(), q, queryMeta, pool.PackBuffer) + if resp == nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + defer pool.ReleaseBuf(resp) + w.Header().Set("Content-Type", "application/dns-message") + if _, err := w.Write(*resp); err != nil { + h.warnErr(req, "failed to write response", err) + return + } +} + +func readClientAddrFromXFF(s string) (netip.Addr, error) { + if i := strings.IndexRune(s, ','); i > 0 { + return netip.ParseAddr(s[:i]) + } + return netip.ParseAddr(s) +} + +var errInvalidMediaType = errors.New("missing or invalid media type header") + +var bufPool = pool.NewBytesBufPool(512) + +func ReadMsgFromReq(req *http.Request) (*dns.Msg, error) { + var b []byte + + switch req.Method { + case http.MethodGet: + // Check accept header + if req.Header.Get("Accept") != "application/dns-message" { + return nil, errInvalidMediaType + } + + s := req.URL.Query().Get("dns") + if len(s) == 0 { + return nil, errors.New("no dns parameter") + } + msgSize := base64.RawURLEncoding.DecodedLen(len(s)) + if msgSize > dns.MaxMsgSize { + return nil, fmt.Errorf("msg length %d is too big", msgSize) + } + + var err error + b, err = base64.RawURLEncoding.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 query: %w", err) + } + + case http.MethodPost: + // Check Content-Type header + if req.Header.Get("Content-Type") != "application/dns-message" { + return nil, errInvalidMediaType + } + + buf := bufPool.Get() + defer bufPool.Release(buf) + _, err := buf.ReadFrom(io.LimitReader(req.Body, dns.MaxMsgSize)) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + b = buf.Bytes() + default: + return nil, fmt.Errorf("unsupported method: %s", req.Method) + } + + m := new(dns.Msg) + if err := m.Unpack(b); err != nil { + return nil, fmt.Errorf("failed to unpack msg [%x], %w", b, err) + } + return m, nil +} diff --git a/pkg/server/iface.go b/pkg/server/iface.go new file mode 100644 index 0000000..c45b502 --- /dev/null +++ b/pkg/server/iface.go @@ -0,0 +1,29 @@ +package server + +import ( + "context" + "net/netip" + + "github.com/miekg/dns" +) + +// Handler handles incoming request q and MUST ALWAYS return a response. +// Handler MUST handle dns errors by itself and return a proper error responses. +// e.g. Return a SERVFAIL if something goes wrong. +// If Handle() returns a nil resp, caller will +// udp: do nothing. +// tcp/dot: close the connection immediately. +// doh: send a 500 response. +// doq: close the stream immediately. +type Handler interface { + Handle(ctx context.Context, q *dns.Msg, meta QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) (respPayload *[]byte) +} + +type QueryMeta struct { + FromUDP bool + + // Optional + ClientAddr netip.Addr + ServerName string + UrlPath string +} diff --git a/pkg/server/tcp.go b/pkg/server/tcp.go new file mode 100644 index 0000000..8af18bd --- /dev/null +++ b/pkg/server/tcp.go @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/netip" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "go.uber.org/zap" +) + +const ( + defaultTCPIdleTimeout = time.Second * 10 + tcpFirstReadTimeout = time.Second * 2 +) + +type TCPServerOpts struct { + // Nil logger == nop + Logger *zap.Logger + + // Default is defaultTCPIdleTimeout. + IdleTimeout time.Duration +} + +// ServeTCP starts a server at l. It returns if l had an Accept() error. +// It always returns a non-nil error. +func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error { + logger := opts.Logger + if logger == nil { + logger = nopLogger + } + idleTimeout := opts.IdleTimeout + if idleTimeout <= 0 { + idleTimeout = defaultTCPIdleTimeout + } + firstReadTimeout := tcpFirstReadTimeout + if idleTimeout < firstReadTimeout { + firstReadTimeout = idleTimeout + } + + listenerCtx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errListenerCtxCanceled) + for { + c, err := l.Accept() + if err != nil { + return fmt.Errorf("unexpected listener err: %w", err) + } + + // handle connection + tcpConnCtx, cancelConn := context.WithCancelCause(listenerCtx) + go func() { + defer c.Close() + defer cancelConn(errConnectionCtxCanceled) + + firstRead := true + for { + if firstRead { + firstRead = false + c.SetReadDeadline(time.Now().Add(firstReadTimeout)) + } else { + c.SetReadDeadline(time.Now().Add(idleTimeout)) + } + req, _, err := dnsutils.ReadMsgFromTCP(c) + if err != nil { + return // read err, close the connection + } + + // Try to get server name from tls conn. + var serverName string + if tlsConn, ok := c.(*tls.Conn); ok { + serverName = tlsConn.ConnectionState().ServerName + } + + // handle query + go func() { + var clientAddr netip.Addr + ta, ok := c.RemoteAddr().(*net.TCPAddr) + if ok { + clientAddr = ta.AddrPort().Addr() + } + r := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr, ServerName: serverName}, pool.PackTCPBuffer) + if r == nil { + c.Close() // abort the connection + return + } + defer pool.ReleaseBuf(r) + + if _, err := c.Write(*r); err != nil { + logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err)) + return + } + }() + } + }() + } +} diff --git a/pkg/server/tls.go b/pkg/server/tls.go new file mode 100644 index 0000000..64d7045 --- /dev/null +++ b/pkg/server/tls.go @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import ( + "crypto/tls" +) + +func LoadCert(tlsCfg *tls.Config, cert, key string) error { + c, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return err + } + tlsCfg.Certificates = []tls.Certificate{c} + return nil +} diff --git a/pkg/server/udp.go b/pkg/server/udp.go new file mode 100644 index 0000000..aee5675 --- /dev/null +++ b/pkg/server/udp.go @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import ( + "context" + "fmt" + "net" + + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +type UDPServerOpts struct { + Logger *zap.Logger +} + +// ServeUDP starts a server at c. It returns if c had a read error. +// It always returns a non-nil error. +// h is required. logger is optional. +func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error { + logger := opts.Logger + if logger == nil { + logger = nopLogger + } + + listenerCtx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errListenerCtxCanceled) + + rb := pool.GetBuf(dns.MaxMsgSize) + defer pool.ReleaseBuf(rb) + + oobReader, oobWriter, err := initOobHandler(c) + if err != nil { + return fmt.Errorf("failed to init oob handler, %w", err) + } + var ob []byte + if oobReader != nil { + obp := pool.GetBuf(1024) + defer pool.ReleaseBuf(obp) + ob = *obp + } + + for { + n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob) + if err != nil { + if n == 0 { + // Err with zero read. Most likely because c was closed. + return fmt.Errorf("unexpected read err: %w", err) + } + // Temporary err. + logger.Warn("read err", zap.Error(err)) + continue + } + + q := new(dns.Msg) + if err := q.Unpack((*rb)[:n]); err != nil { + logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr)) + continue + } + + var dstIpFromCm net.IP + if oobReader != nil { + var err error + dstIpFromCm, err = oobReader(ob[:oobn]) + if err != nil { + logger.Error("failed to get dst address from oob", zap.Error(err)) + } + } + + // handle query + go func() { + payload := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}, pool.PackBuffer) + if payload == nil { + return + } + defer pool.ReleaseBuf(payload) + + var oob []byte + if oobWriter != nil && dstIpFromCm != nil { + oob = oobWriter(dstIpFromCm) + } + if _, _, err := c.WriteMsgUDPAddrPort(*payload, oob, remoteAddr); err != nil { + logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err)) + } + }() + } +} + +type getSrcAddrFromOOB func(oob []byte) (net.IP, error) +type writeSrcAddrToOOB func(a net.IP) []byte diff --git a/pkg/server/udp_linux.go b/pkg/server/udp_linux.go new file mode 100644 index 0000000..9728a39 --- /dev/null +++ b/pkg/server/udp_linux.go @@ -0,0 +1,125 @@ +//go:build linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import ( + "errors" + "fmt" + "net" + "os" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" +) + +var ( + errCmNoDstAddr = errors.New("control msg does not have dst address") +) + +func getOOBFromCM4(oob []byte) (net.IP, error) { + var cm ipv4.ControlMessage + if err := cm.Parse(oob); err != nil { + return nil, err + } + if cm.Dst == nil { + return nil, errCmNoDstAddr + } + return cm.Dst, nil +} + +func getOOBFromCM6(oob []byte) (net.IP, error) { + var cm ipv6.ControlMessage + if err := cm.Parse(oob); err != nil { + return nil, err + } + if cm.Dst == nil { + return nil, errCmNoDstAddr + } + return cm.Dst, nil +} + +func srcIP2Cm(ip net.IP) []byte { + if ip4 := ip.To4(); ip4 != nil { + return (&ipv4.ControlMessage{ + Src: ip, + }).Marshal() + } + + if ip6 := ip.To16(); ip6 != nil { + return (&ipv6.ControlMessage{ + Src: ip, + }).Marshal() + } + + return nil +} + +func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) { + if !c.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() { + return nil, nil, nil + } + + sc, err := c.SyscallConn() + if err != nil { + return nil, nil, err + } + + var getter getSrcAddrFromOOB + var setter writeSrcAddrToOOB + var controlErr error + if err := sc.Control(func(fd uintptr) { + v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN) + if err != nil { + controlErr = os.NewSyscallError("failed to get SO_PROTOCOL", err) + return + } + switch v { + case unix.AF_INET: + c4 := ipv4.NewPacketConn(c) + if err := c4.SetControlMessage(ipv4.FlagDst, true); err != nil { + controlErr = fmt.Errorf("failed to set ipv4 cmsg flags, %w", err) + } + + getter = getOOBFromCM4 + setter = srcIP2Cm + return + case unix.AF_INET6: + c6 := ipv6.NewPacketConn(c) + if err := c6.SetControlMessage(ipv6.FlagDst, true); err != nil { + controlErr = fmt.Errorf("failed to set ipv6 cmsg flags, %w", err) + } + getter = getOOBFromCM6 + setter = srcIP2Cm + return + default: + controlErr = fmt.Errorf("socket protocol %d is not supported", v) + } + }); err != nil { + return nil, nil, fmt.Errorf("control fd err, %w", controlErr) + } + + if controlErr != nil { + return nil, nil, fmt.Errorf("failed to set up socket, %w", controlErr) + } + return getter, setter, nil +} diff --git a/pkg/server/udp_others.go b/pkg/server/udp_others.go new file mode 100644 index 0000000..1e42651 --- /dev/null +++ b/pkg/server/udp_others.go @@ -0,0 +1,28 @@ +//go:build !linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import "net" + +func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) { + return nil, nil, nil +} diff --git a/pkg/server/utils.go b/pkg/server/utils.go new file mode 100644 index 0000000..597c171 --- /dev/null +++ b/pkg/server/utils.go @@ -0,0 +1,16 @@ +package server + +import ( + "errors" + + "go.uber.org/zap" +) + +var ( + errListenerCtxCanceled = errors.New("listener ctx canceled") + errConnectionCtxCanceled = errors.New("connection ctx canceled") +) + +var ( + nopLogger = zap.NewNop() +) diff --git a/pkg/server_handler/entry_handler.go b/pkg/server_handler/entry_handler.go new file mode 100644 index 0000000..c30fcd7 --- /dev/null +++ b/pkg/server_handler/entry_handler.go @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server_handler + +import ( + "context" + "time" + + "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +const ( + defaultQueryTimeout = time.Second * 5 +) + +var ( + nopLogger = mlog.Nop() + + // options that can forward to upstream + queryForwardEDNS0Option = map[uint16]struct{}{ + dns.EDNS0SUBNET: {}, + } + + // options that useless for downstream + respRemoveEDNS0Option = map[uint16]struct{}{ + dns.EDNS0PADDING: {}, + } +) + +type EntryHandlerOpts struct { + // Logger is used for logging. Default is a noop logger. + Logger *zap.Logger + + // Required. + Entry sequence.Executable + + // QueryTimeout limits the timeout value of each query. + // Default is defaultQueryTimeout. + QueryTimeout time.Duration +} + +func (opts *EntryHandlerOpts) init() { + if opts.Logger == nil { + opts.Logger = nopLogger + } + utils.SetDefaultNum(&opts.QueryTimeout, defaultQueryTimeout) +} + +type EntryHandler struct { + opts EntryHandlerOpts +} + +var _ server.Handler = (*EntryHandler)(nil) + +func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler { + opts.init() + return &EntryHandler{opts: opts} +} + +// ServeDNS implements server.Handler. +// If entry returns an error, a SERVFAIL response will be returned. +// If entry returns without a response, a REFUSED response will be returned. +func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, serverMeta server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte { + // basic query check. + if q.Response || len(q.Question) != 1 || len(q.Answer)+len(q.Ns) > 0 || len(q.Extra) > 1 { + return nil + } + + ddl := time.Now().Add(h.opts.QueryTimeout) + ctx, cancel := context.WithDeadline(ctx, ddl) + defer cancel() + + qCtx := query_context.NewContext(q) + qCtx.ServerMeta = serverMeta + + // exec entry + err := h.opts.Entry.Exec(ctx, qCtx) + var resp *dns.Msg + if err != nil { + h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(err)) + resp = new(dns.Msg) + resp.SetReply(q) + resp.Rcode = dns.RcodeServerFailure + } else { + resp = qCtx.R() + } + + if resp == nil { + resp = new(dns.Msg) + resp.SetReply(q) + resp.Rcode = dns.RcodeRefused + } + // We assume that our server is a forwarder. + resp.RecursionAvailable = true + + // add respOpt back to resp + if respOpt := qCtx.RespOpt(); respOpt != nil { + resp.Extra = append(resp.Extra, respOpt) + } + + if serverMeta.FromUDP { + udpSize := getValidUDPSize(qCtx.ClientOpt()) + resp.Truncate(udpSize) + } + + payload, err := packMsgPayload(resp) + if err != nil { + h.opts.Logger.Error("internal err: failed to pack resp msg", qCtx.InfoField(), zap.Error(err)) + return nil + } + return payload +} + +// opt can be nil. +func getValidUDPSize(opt *dns.OPT) int { + var s uint16 + if opt != nil { + s = opt.UDPSize() + } + if s < dns.MinMsgSize { + s = dns.MinMsgSize + } + return int(s) +} + +func newOpt() *dns.OPT { + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + return opt +} diff --git a/pkg/upstream/bootstrap/bootstrap.go b/pkg/upstream/bootstrap/bootstrap.go new file mode 100644 index 0000000..610a587 --- /dev/null +++ b/pkg/upstream/bootstrap/bootstrap.go @@ -0,0 +1,248 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package bootstrap + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +const ( + minimumUpdateInterval = time.Minute * 5 + retryInterval = time.Second * 2 + queryTimeout = time.Second * 5 +) + +var ( + errNoAddrInResp = errors.New("resp does not have ip address") +) + +func New( + host string, + port uint16, + bootstrapServer netip.AddrPort, + bootstrapVer int, // 0,4,6 + logger *zap.Logger, // not nil +) (*Bootstrap, error) { + dp := new(Bootstrap) + dp.fqdn = dns.Fqdn(host) + dp.port = port + if !bootstrapServer.IsValid() { + return nil, errors.New("invalid bootstrap server address") + } + dp.bootstrap = net.UDPAddrFromAddrPort(bootstrapServer) + qt, ok := bootstrapVer2Qt(bootstrapVer) + if !ok { + return nil, fmt.Errorf("invalid bootstrap version %d", bootstrapVer) + } + dp.qt = qt + dp.logger = logger + + dp.readyNotify = make(chan struct{}) + return dp, nil +} + +type Bootstrap struct { + fqdn string + port uint16 + bootstrap *net.UDPAddr + qt uint16 // dns.TypeA or dns.TypeAAAA + logger *zap.Logger // not nil + + updating atomic.Bool + nextUpdate time.Time + + readyNotify chan struct{} + m sync.Mutex + ready bool + addrStr string +} + +func (sp *Bootstrap) GetAddrPortStr(ctx context.Context) (string, error) { + sp.tryUpdate() + + select { + case <-ctx.Done(): + return "", context.Cause(ctx) + case <-sp.readyNotify: + } + + sp.m.Lock() + addr := sp.addrStr + sp.m.Unlock() + return addr, nil +} + +func (sp *Bootstrap) tryUpdate() { + if sp.updating.CompareAndSwap(false, true) { + if time.Now().After(sp.nextUpdate) { + go func() { + defer sp.updating.Store(false) + ctx, cancel := context.WithTimeout(context.Background(), queryTimeout) + defer cancel() + start := time.Now() + addr, ttl, err := sp.updateAddr(ctx) + if err != nil { + sp.logger.Check(zap.WarnLevel, "failed to update bootstrap addr").Write( + zap.String("fqdn", sp.fqdn), + zap.Error(err), + ) + sp.nextUpdate = time.Now().Add(retryInterval) + } else { + updateInterval := time.Second * time.Duration(ttl) + if updateInterval < minimumUpdateInterval { + updateInterval = minimumUpdateInterval + } + sp.logger.Check(zap.DebugLevel, "bootstrap addr updated").Write( + zap.String("fqdn", sp.fqdn), + zap.Stringer("addr", addr), + zap.Duration("ttl", updateInterval), + zap.Duration("elapse", time.Since(start)), + ) + sp.nextUpdate = time.Now().Add(updateInterval) + } + }() + } else { + sp.updating.Store(false) + } + } +} + +func (sp *Bootstrap) updateAddr(ctx context.Context) (netip.Addr, uint32, error) { + addr, ttl, err := sp.resolve(ctx, sp.qt) + if err != nil { + return netip.Addr{}, 0, err + } + + addrPort := netip.AddrPortFrom(addr, sp.port).String() + sp.m.Lock() + sp.addrStr = addrPort + if !sp.ready { + sp.ready = true + close(sp.readyNotify) + } + sp.m.Unlock() + return addr, ttl, nil +} + +func (sp *Bootstrap) resolve(ctx context.Context, qt uint16) (netip.Addr, uint32, error) { + const edns0UdpSize = 1200 + + q := new(dns.Msg) + q.SetQuestion(sp.fqdn, qt) + q.SetEdns0(edns0UdpSize, false) + + c, err := net.DialUDP("udp", nil, sp.bootstrap) + if err != nil { + return netip.Addr{}, 0, err + } + defer c.Close() + + writeErrC := make(chan error, 1) + type res struct { + resp *dns.Msg + err error + } + readResC := make(chan res, 1) + + cancelWrite := make(chan struct{}) + defer close(cancelWrite) + go func() { + if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil { + writeErrC <- err + return + } + + retryTicker := time.NewTicker(time.Second) + defer retryTicker.Stop() + for { + select { + case <-cancelWrite: + return + case <-retryTicker.C: + if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil { + writeErrC <- err + return + } + } + } + }() + + go func() { + m, _, err := dnsutils.ReadMsgFromUDP(c, edns0UdpSize) + readResC <- res{resp: m, err: err} + }() + + select { + case <-ctx.Done(): + return netip.Addr{}, 0, context.Cause(ctx) + case err := <-writeErrC: + return netip.Addr{}, 0, fmt.Errorf("failed to write query, %w", err) + case r := <-readResC: + resp := r.resp + err := r.err + if err != nil { + return netip.Addr{}, 0, fmt.Errorf("failed to read resp, %w", err) + } + + for _, v := range resp.Answer { + var ip net.IP + var ttl uint32 + switch rr := v.(type) { + case *dns.A: + ip = rr.A + ttl = rr.Hdr.Ttl + case *dns.AAAA: + ip = rr.AAAA + ttl = rr.Hdr.Ttl + default: + continue + } + addr, ok := netip.AddrFromSlice(ip) + if ok { + return addr, ttl, nil + } + } + + // No ip addr in resp. + return netip.Addr{}, 0, errNoAddrInResp + } +} + +func bootstrapVer2Qt(ver int) (uint16, bool) { + switch ver { + case 0, 4: + return dns.TypeA, true + case 6: + return dns.TypeAAAA, true + default: + return 0, false + } +} diff --git a/pkg/upstream/doh/upstream.go b/pkg/upstream/doh/upstream.go new file mode 100644 index 0000000..d92ce4b --- /dev/null +++ b/pkg/upstream/doh/upstream.go @@ -0,0 +1,166 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package doh + +import ( + "context" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "net/http" + urlpkg "net/url" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +const ( + defaultDoHTimeout = time.Second * 6 +) + +var nopLogger = zap.NewNop() + +// Upstream is a DNS-over-HTTPS (RFC 8484) upstream. +type Upstream struct { + rt http.RoundTripper + logger *zap.Logger // non-nil + urlTemplate *urlpkg.URL + reqTemplate *http.Request +} + +func NewUpstream(endPoint string, rt http.RoundTripper, logger *zap.Logger) (*Upstream, error) { + req, err := http.NewRequest(http.MethodGet, endPoint, nil) + if err != nil { + return nil, fmt.Errorf("failed to parse http request, %w", err) + } + + req.Header["Accept"] = []string{"application/dns-message"} + req.Header["User-Agent"] = nil // Don't let go http send a default user agent header. + + if logger == nil { + logger = nopLogger + } + return &Upstream{ + rt: rt, + logger: logger, + urlTemplate: req.URL, + reqTemplate: req, + }, nil +} + +var ( + bufPool4k = pool.NewBytesBufPool(4096) +) + +func (u *Upstream) ExchangeContext(ctx context.Context, q []byte) (*[]byte, error) { + bp := pool.GetBuf(len(q)) + defer pool.ReleaseBuf(bp) + wire := *bp + copy(wire, q) + + // In order to maximize HTTP cache friendliness, DoH clients using media + // formats that include the ID field from the DNS message header, such + // as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS + // request. + // https://tools.ietf.org/html/rfc8484#section-4.1 + wire[0] = 0 + wire[1] = 0 + + queryLen := 4 + base64.RawURLEncoding.EncodedLen(len(wire)) + queryBuf := make([]byte, queryLen) + + p := 0 + p += copy(queryBuf, "dns=") + + // Padding characters for base64url MUST NOT be included. + // See: https://tools.ietf.org/html/rfc8484#section-6. + base64.RawURLEncoding.Encode(queryBuf[p:], wire) + + type res struct { + r *[]byte + err error + } + + resChan := make(chan res, 1) + go func() { + // We overwrite the ctx with a fixed timeout context here. + // Because the http package may close the underlay connection + // if the context is done before the query is completed. This + // reduces the connection reuse efficiency. + ctx, cancel := context.WithTimeout(context.Background(), defaultDoHTimeout) + defer cancel() + r, err := u.exchange(ctx, utils.BytesToStringUnsafe(queryBuf)) + if err != nil { + u.logger.Check(zap.WarnLevel, "exchange failed").Write(zap.Error(err)) + } + resChan <- res{r: r, err: err} + }() + + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case res := <-resChan: + r := res.r + err := res.err + if r != nil { + binary.BigEndian.PutUint16(*r, binary.BigEndian.Uint16(q)) + } + return r, err + } +} + +func (u *Upstream) exchange(ctx context.Context, dnsQuery string) (*[]byte, error) { + req := u.reqTemplate.WithContext(ctx) + req.URL = new(urlpkg.URL) + *req.URL = *u.urlTemplate + req.URL.RawQuery = dnsQuery + resp, err := u.rt.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("http request failed: %w", err) + } + defer resp.Body.Close() + + // check status code + if resp.StatusCode != http.StatusOK { + body1k, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + if body1k != nil { + return nil, fmt.Errorf("bad http status codes %d with body [%s]", resp.StatusCode, body1k) + } + return nil, fmt.Errorf("bad http status codes %d", resp.StatusCode) + } + + bb := bufPool4k.Get() + defer bufPool4k.Release(bb) + _, err = bb.ReadFrom(io.LimitReader(resp.Body, dns.MaxMsgSize)) + if err != nil { + return nil, fmt.Errorf("failed to read http body: %w", err) + } + if bb.Len() < dnsutils.DnsHeaderLen { + return nil, dnsutils.ErrPayloadTooSmall + } + payload := pool.GetBuf(bb.Len()) + copy(*payload, bb.Bytes()) + return payload, nil +} diff --git a/pkg/upstream/event_stat.go b/pkg/upstream/event_stat.go new file mode 100644 index 0000000..6b7b803 --- /dev/null +++ b/pkg/upstream/event_stat.go @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package upstream + +import ( + "net" + "sync/atomic" +) + +type Event int + +const ( + EventConnOpen Event = iota + EventConnClose +) + +type EventObserver interface { + OnEvent(typ Event) +} + +type nopEO struct{} + +func (n nopEO) OnEvent(_ Event) {} + +type connWrapper struct { + net.Conn + closed atomic.Bool + ob EventObserver +} + +// wrapConn wraps c into a connWrapper so that we can observe the connection close. +// For convenient, if c is nil, wrapConn returns nil as well. If ob is nopEO, wrapConn +// returns c. +func wrapConn(c net.Conn, ob EventObserver) net.Conn { + if c == nil { + return nil + } + if _, ok := ob.(nopEO); ok { + return c + } + ob.OnEvent(EventConnOpen) + return &connWrapper{ + Conn: c, + ob: ob, + } +} + +func (c *connWrapper) Close() error { + if c.closed.CompareAndSwap(false, true) { + c.ob.OnEvent(EventConnClose) + } + return c.Conn.Close() +} diff --git a/pkg/upstream/transport/conn_lazy_dial.go b/pkg/upstream/transport/conn_lazy_dial.go new file mode 100644 index 0000000..a8c4e0b --- /dev/null +++ b/pkg/upstream/transport/conn_lazy_dial.go @@ -0,0 +1,165 @@ +package transport + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" +) + +type lazyDnsConn struct { + maxConcurrentQuery int + cancelDial context.CancelFunc + + mu sync.Mutex + earlyReserveCallWg sync.WaitGroup + closed bool + reservedQuery int + dialFinished chan struct{} + c DnsConn + dialErr error + + // 1: dial completed and all early reserve call finished. + // 2: dial failed. + fastPath atomic.Uint32 +} + +var _ DnsConn = (*lazyDnsConn)(nil) + +var ( + errLazyConnDialCanceled = errors.New("lazy dial canceled") +) + +func newLazyDnsConn( + dial func(ctx context.Context) (DnsConn, error), + dialTimeout time.Duration, + maxConcurrentQueryWhileDialing int, // must be valid, no default value + logger *zap.Logger, // must non-nil +) *lazyDnsConn { + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + dialCtx, cancelDial := context.WithTimeout(context.Background(), defaultDialTimeout) + lc := &lazyDnsConn{ + maxConcurrentQuery: maxConcurrentQueryWhileDialing, + cancelDial: cancelDial, + dialFinished: make(chan struct{}), + } + + go func() { + dc, err := dial(dialCtx) + cancelDial() + if err != nil { + logger.Check(zap.WarnLevel, "failed to dial dns conn").Write(zap.Error(err)) + } + lc.mu.Lock() + if lc.closed { // lc was closed and dial was canceled + lc.mu.Unlock() + if dc != nil { + dc.Close() + } + return + } + + lc.c = dc + lc.dialErr = err + close(lc.dialFinished) + lc.mu.Unlock() + }() + return lc +} + +func (lc *lazyDnsConn) Close() error { + lc.mu.Lock() + defer lc.mu.Unlock() + + if lc.closed { + return nil + } + lc.closed = true + + if lc.c == nil && lc.dialErr == nil { // still dialing + lc.cancelDial() + lc.dialErr = errLazyConnDialCanceled + close(lc.dialFinished) + } else { + // close connection + if lc.c != nil { + lc.c.Close() + } + } + return nil +} + +func (lc *lazyDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) { + switch lc.fastPath.Load() { + case 1: + return lc.c.ReserveNewQuery() + case 2: + return nil, true + } + + lc.mu.Lock() + defer lc.mu.Unlock() + + select { + case <-lc.dialFinished: + // Note: race condition here and lazyDnsConnEarlyReservedExchanger.ExchangeReserved(). + // Not a big problem. May cause at most all early exchange failed. + // earlyExchangeWg makes sure that early exchange calls ReserveNewQuery first. + dc, err := lc.c, lc.dialErr + if err != nil { + lc.fastPath.Store(2) + return nil, true + } + lc.earlyReserveCallWg.Wait() + lc.fastPath.Store(1) + return dc.ReserveNewQuery() + default: + if lc.reservedQuery >= lc.maxConcurrentQuery { + return nil, false + } + lc.reservedQuery++ + lc.earlyReserveCallWg.Add(1) + return (*lazyDnsConnEarlyReservedExchanger)(lc), false + } +} + +type lazyDnsConnEarlyReservedExchanger lazyDnsConn + +var _ ReservedExchanger = (*lazyDnsConnEarlyReservedExchanger)(nil) + +func (ote *lazyDnsConnEarlyReservedExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) { + defer func() { + ote.mu.Lock() + ote.reservedQuery-- + ote.mu.Unlock() + }() + + select { + case <-ctx.Done(): + ote.earlyReserveCallWg.Done() + return nil, context.Cause(ctx) + case <-ote.dialFinished: + dc, err := ote.c, ote.dialErr + if err != nil { + return nil, err + } + rec, _ := dc.ReserveNewQuery() + ote.earlyReserveCallWg.Done() + if rec == nil { + return nil, ErrLazyConnCannotReserveQueryExchanger + } + return rec.ExchangeReserved(ctx, q) + } +} + +func (ote *lazyDnsConnEarlyReservedExchanger) WithdrawReserved() { + ote.earlyReserveCallWg.Done() + ote.mu.Lock() + ote.reservedQuery-- + ote.mu.Unlock() +} diff --git a/pkg/upstream/transport/conn_quic.go b/pkg/upstream/transport/conn_quic.go new file mode 100644 index 0000000..1427c38 --- /dev/null +++ b/pkg/upstream/transport/conn_quic.go @@ -0,0 +1,123 @@ +package transport + +import ( + "context" + "encoding/binary" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/quic-go/quic-go" +) + +const ( + quicQueryTimeout = time.Second * 6 +) + +const ( + // RFC 9250 4.3. DoQ Error Codes + _DOQ_NO_ERROR = quic.StreamErrorCode(0x0) + _DOQ_INTERNAL_ERROR = quic.StreamErrorCode(0x1) + _DOQ_REQUEST_CANCELLED = quic.StreamErrorCode(0x3) +) + +var _ DnsConn = (*QuicDnsConn)(nil) + +type QuicDnsConn struct { + c quic.Connection +} + +func NewQuicDnsConn(c quic.Connection) *QuicDnsConn { + return &QuicDnsConn{c: c} +} + +func (c *QuicDnsConn) Close() error { + return c.c.CloseWithError(0, "") +} + +func (c *QuicDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) { + select { + case <-c.c.Context().Done(): + return nil, true + default: + } + s, err := c.c.OpenStream() + // We just checked the connection is alive. So we are assuming the error + // is caused by reaching the peer's stream limit. + if err != nil { + return nil, false + } + return &quicReservedExchanger{stream: s}, false +} + +type quicReservedExchanger struct { + stream quic.Stream +} + +var _ ReservedExchanger = (*quicReservedExchanger)(nil) + +func (ote *quicReservedExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) { + stream := ote.stream + + payload, err := copyMsgWithLenHdr(q) + if err != nil { + stream.CancelWrite(_DOQ_REQUEST_CANCELLED) + stream.CancelRead(_DOQ_REQUEST_CANCELLED) + return nil, err + } + + // 4.2.1. DNS Message IDs + // When sending queries over a QUIC connection, the DNS Message ID MUST + // be set to 0. The stream mapping for DoQ allows for unambiguous + // correlation of queries and responses, so the Message ID field is not + // required. + orgQid := binary.BigEndian.Uint16((*payload)[2:]) + binary.BigEndian.PutUint16((*payload)[2:], 0) + + stream.SetDeadline(time.Now().Add(quicQueryTimeout)) + _, err = stream.Write(*payload) + pool.ReleaseBuf(payload) + if err != nil { + stream.CancelRead(_DOQ_REQUEST_CANCELLED) + stream.CancelWrite(_DOQ_REQUEST_CANCELLED) + return nil, err + } + + // RFC 9250 4.2 + // The client MUST send the DNS query over the selected stream and MUST + // indicate through the STREAM FIN mechanism that no further data will + // be sent on that stream. + // + // Call Close() here will send the STREAM FIN. It won't close Read. + stream.Close() + + type res struct { + resp *[]byte + err error + } + rc := make(chan res, 1) + go func() { + r, err := dnsutils.ReadRawMsgFromTCP(stream) + rc <- res{resp: r, err: err} + }() + + select { + case <-ctx.Done(): + stream.CancelRead(_DOQ_REQUEST_CANCELLED) + return nil, context.Cause(ctx) + case r := <-rc: + resp := r.resp + err := r.err + if resp != nil { + binary.BigEndian.PutUint16((*resp), orgQid) + } + stream.CancelRead(_DOQ_NO_ERROR) + return resp, err + } +} + +func (ote *quicReservedExchanger) WithdrawReserved() { + s := ote.stream + s.CancelRead(_DOQ_REQUEST_CANCELLED) + s.CancelWrite(_DOQ_REQUEST_CANCELLED) +} \ No newline at end of file diff --git a/pkg/upstream/transport/conn_traditional.go b/pkg/upstream/transport/conn_traditional.go new file mode 100644 index 0000000..148eb61 --- /dev/null +++ b/pkg/upstream/transport/conn_traditional.go @@ -0,0 +1,283 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package transport + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" +) + +var ( + ErrTDCTooManyQueries = errors.New("too many queries") // Connection has too many ongoing queries. + ErrTDCClosed = errors.New("dns connection closed") +) + +var _ DnsConn = (*TraditionalDnsConn)(nil) + +// TraditionalDnsConn is a low-level connection for traditional dns protocol, where +// dns frames transport in a single and simple connection. (e.g. udp, tcp, tls) +type TraditionalDnsConn struct { + c NetConn + isTcp bool + idleTimeout time.Duration + maxCq int + + closeOnce sync.Once + closeNotify chan struct{} + closed atomic.Bool // atomic, for fast check + closeErr error // closeErr is ready (not nil) when closeNotify is closed. + + queueMu sync.RWMutex + reservedQuery int + nextQid uint16 + queue map[uint32]chan *[]byte // uint32 has fast path + + // waitingResp indicates connection is waiting a reply from the peer. + // It can identify c is dead or buggy in some circumstances. e.g. Network is dropped + // and the sockets were still open because no fin or rst was received. + waitingResp atomic.Bool +} + +type TraditionalDnsConnOpts struct { + // Set to true if underlayer connection require a length header. + // e.g. TCP and DoT. + WithLengthHeader bool + + // IdleTimeout controls the maximum idle time for each connection. + // Default is defaultIdleTimeout. + IdleTimeout time.Duration + + // MaxConcurrentQuery limits the number of maximum concurrent queries + // in the connection. Default is defaultTdcMaxConcurrentQuery. + MaxConcurrentQuery int +} + +func NewDnsConn(opt TraditionalDnsConnOpts, conn NetConn) *TraditionalDnsConn { + dc := &TraditionalDnsConn{ + c: conn, + isTcp: opt.WithLengthHeader, + closeNotify: make(chan struct{}), + queue: make(map[uint32]chan *[]byte), + } + setDefaultGZ(&dc.idleTimeout, opt.IdleTimeout, defaultIdleTimeout) + setDefaultGZ(&dc.maxCq, opt.MaxConcurrentQuery, defaultTdcMaxConcurrentQuery) + + go dc.readLoop() + return dc +} + +// exchange sends q out and waits for its reply. +func (dc *TraditionalDnsConn) exchange(ctx context.Context, q []byte) (*[]byte, error) { + select { + case <-dc.closeNotify: + return nil, ErrTDCClosed + default: + } + + assignedQid, respChan := dc.addQueueC() + if respChan == nil { + return nil, ErrTDCTooManyQueries + } + defer dc.deleteQueueC(assignedQid) + + // Reminder: Set write deadline here is not very useful to avoid dead connections. + // Typically, a write operation will time out only if its socket buffer is full. + // Ser read deadline is enough. + err := dc.writeQuery(q, assignedQid) + if err != nil { + // Write error usually is fatal. Abort and close this connection. + dc.CloseWithErr(fmt.Errorf("write err, %w", err)) + return nil, err + } + + // If a query was sent, server should have a reply (even not for this query) in a short time. + // This indicates the connection is healthy. Otherwise, this connection might be dead. + // The Read deadline will be refreshed in DnsConn.readLoop() after every successful read. + // Note: There has a race condition in this SetReadDeadline() call and the one in + // readLoop(). It's not a big problem. + if dc.waitingResp.CompareAndSwap(false, true) { + dc.c.SetReadDeadline(time.Now().Add(waitingReplyTimeout)) + } + + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case r := <-respChan: + orgId := binary.BigEndian.Uint16(q) + binary.BigEndian.PutUint16(*r, orgId) + return r, nil + case <-dc.closeNotify: + return nil, dc.closeErr + } +} + +func (dc *TraditionalDnsConn) writeQuery(q []byte, assignedQid uint16) error { + var payload *[]byte + if dc.isTcp { + var err error + payload, err = copyMsgWithLenHdr(q) + if err != nil { + return err + } + binary.BigEndian.PutUint16((*payload)[2:], assignedQid) + } else { + payload = copyMsg(q) + binary.BigEndian.PutUint16(*payload, assignedQid) + } + _, err := dc.c.Write(*payload) + pool.ReleaseBuf(payload) + return err +} + +func (dc *TraditionalDnsConn) readResp() (payload *[]byte, err error) { + if dc.isTcp { + return dnsutils.ReadRawMsgFromTCP(dc.c) + } + return readMsgUdp(dc.c) +} + +// readLoop reads DnsConn until there was a read error. +func (dc *TraditionalDnsConn) readLoop() { + + for { + dc.c.SetReadDeadline(time.Now().Add(dc.idleTimeout)) + r, err := dc.readResp() + if err != nil { + dc.CloseWithErr(fmt.Errorf("read err, %w", err)) // abort this connection. + return + } + dc.waitingResp.Store(false) + + rid := binary.BigEndian.Uint16(*r) + resChan := dc.getQueueC(rid) + if resChan != nil { + select { + case resChan <- r: // resChan has buffer + default: + pool.ReleaseBuf(r) + } + } else { + pool.ReleaseBuf(r) + } + } +} + +func (dc *TraditionalDnsConn) IsClosed() bool { + return dc.closed.Load() +} + +func (dc *TraditionalDnsConn) Close() error { + dc.CloseWithErr(ErrTDCClosed) + return nil +} + +// CloseWithErr closes DnsConn with an error. The error will be sent +// to the waiting Exchange calls. +// Subsequent calls are noop. +// Default err is ErrTDCClosed. +func (dc *TraditionalDnsConn) CloseWithErr(err error) { + if err == nil { + err = ErrTDCClosed + } + + dc.closeOnce.Do(func() { + dc.closed.Store(true) + dc.closeErr = err + close(dc.closeNotify) + dc.c.Close() + }) +} + +func (dc *TraditionalDnsConn) getQueueC(qid uint16) chan<- *[]byte { + dc.queueMu.RLock() + defer dc.queueMu.RUnlock() + return dc.queue[uint32(qid)] +} + +func (dc *TraditionalDnsConn) queueLen() int { + dc.queueMu.RLock() + defer dc.queueMu.RUnlock() + return len(dc.queue) + dc.reservedQuery +} + +// addQueueC assigns a qid and add it to the queue. +// It returns a nil c if queue has too many queries. +// Caller must call deleteQueueC to release the qid in queue. +func (dc *TraditionalDnsConn) addQueueC() (qid uint16, c chan *[]byte) { + c = make(chan *[]byte) + dc.queueMu.Lock() + for i := 0; i < 100; i++ { + qid = dc.nextQid + dc.nextQid++ + if _, dup := dc.queue[uint32(qid)]; dup { + continue + } + dc.queue[uint32(qid)] = c + dc.queueMu.Unlock() + return qid, c + } + dc.queueMu.Unlock() + + // Too many queries in queue. Can't assign qid. + return 0, nil +} + +func (dc *TraditionalDnsConn) deleteQueueC(qid uint16) { + dc.queueMu.Lock() + delete(dc.queue, uint32(qid)) + dc.queueMu.Unlock() +} + +func (dc *TraditionalDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) { + if dc.closed.Load() { + return nil, true + } + + dc.queueMu.Lock() + defer dc.queueMu.Unlock() + if len(dc.queue)+dc.reservedQuery >= dc.maxCq { + return nil, false + } + dc.reservedQuery++ + return (*tdcOneTimeExchanger)(dc), false +} + +type tdcOneTimeExchanger TraditionalDnsConn + +var _ ReservedExchanger = (*tdcOneTimeExchanger)(nil) + +func (ote *tdcOneTimeExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) { + defer ote.WithdrawReserved() + return (*TraditionalDnsConn)(ote).exchange(ctx, q) +} + +func (ote *tdcOneTimeExchanger) WithdrawReserved() { + ote.queueMu.Lock() + ote.reservedQuery-- + ote.queueMu.Unlock() +} diff --git a/pkg/upstream/transport/conn_traditional_test.go b/pkg/upstream/transport/conn_traditional_test.go new file mode 100644 index 0000000..49e0f2f --- /dev/null +++ b/pkg/upstream/transport/conn_traditional_test.go @@ -0,0 +1,250 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package transport + +import ( + "bytes" + "context" + "errors" + "math/rand" + "net" + "runtime" + "sync" + "testing" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +type dummyEchoNetConn struct { + net.Conn + rErrProb float64 + rLatency time.Duration + wErrProb float64 + + closeOnce sync.Once + closeNotify chan struct{} +} + +func newDummyEchoNetConn(rErrProb float64, rLatency time.Duration, wErrProb float64) NetConn { + c1, c2 := net.Pipe() + go func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + defer c1.Close() + defer c2.Close() + for { + m, readErr := dnsutils.ReadRawMsgFromTCP(c2) + if m != nil { + go func() { + defer pool.ReleaseBuf(m) + if rLatency > 0 { + t := time.NewTimer(rLatency) + defer t.Stop() + select { + case <-t.C: + case <-ctx.Done(): + return + } + } + latency := time.Millisecond * time.Duration(rand.Intn(20)) + time.Sleep(latency) + _, _ = dnsutils.WriteRawMsgToTCP(c2, *m) + }() + } + if readErr != nil { + return + } + } + }() + return &dummyEchoNetConn{ + Conn: c1, + rErrProb: rErrProb, + rLatency: rLatency, + wErrProb: wErrProb, + closeNotify: make(chan struct{}), + } +} + +func probTrue(p float64) bool { + return rand.Float64() < p +} + +func (d *dummyEchoNetConn) Read(p []byte) (n int, err error) { + if probTrue(d.rErrProb) { + return 0, errors.New("read err") + } + return d.Conn.Read(p) +} + +func (d *dummyEchoNetConn) Write(p []byte) (n int, err error) { + if probTrue(d.wErrProb) { + return 0, errors.New("write err") + } + return d.Conn.Write(p) +} + +func (d *dummyEchoNetConn) Close() error { + d.closeOnce.Do(func() { + close(d.closeNotify) + }) + return d.Conn.Close() +} + +func Test_dnsConn_exchange(t *testing.T) { + idleTimeout := time.Millisecond * 100 + + tests := []struct { + name string + rErrProb float64 + rLatency time.Duration + wErrProb float64 + connClosed bool // connection is closed before calling exchange() + wantMsg bool + wantErr bool + }{ + { + name: "normal", + rErrProb: 0, + rLatency: 0, + wErrProb: 0, + wantMsg: true, wantErr: false, + }, + { + name: "write err", + rErrProb: 0, + rLatency: 0, + wErrProb: 1, + wantMsg: false, wantErr: true, + }, + { + name: "read err", + rErrProb: 1, + rLatency: 0, + wErrProb: 0, + wantMsg: false, wantErr: true, + }, + { + name: "read timeout", + rErrProb: 0, + rLatency: idleTimeout * 3, + wErrProb: 0, + wantMsg: false, wantErr: true, + }, + { + name: "connection closed", + rErrProb: 0, + rLatency: 0, + wErrProb: 0, + connClosed: true, + wantMsg: false, wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := require.New(t) + c := newDummyEchoNetConn(tt.rErrProb, tt.rLatency, tt.wErrProb) + defer c.Close() + ioOpts := TraditionalDnsConnOpts{ + WithLengthHeader: true, // TODO: Test false as well + IdleTimeout: idleTimeout, + } + dc := NewDnsConn(ioOpts, c) + + if tt.connClosed { + dc.Close() + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + q := new(dns.Msg) + q.SetQuestion("test.", dns.TypeA) + queryPayload, err := q.Pack() + r.NoError(err) + + rec, closed := dc.ReserveNewQuery() + r.Equal(tt.connClosed, closed) + if rec == nil { + return + } + + respPayload, err := rec.ExchangeReserved(ctx, queryPayload) + if tt.wantErr { + r.Error(err) + } else { + r.NoError(err) + } + + if tt.wantMsg { + r.NotNil(respPayload) + r.True(bytes.Equal(queryPayload, *respPayload)) + + // test idle timeout + time.Sleep(idleTimeout + time.Millisecond*20) + runtime.Gosched() + r.True(dc.IsClosed(), "connection should be closed due to idle timeout") + } else { + r.Nil(respPayload) + } + }) + } +} + +// TODO: 测试 maxconcurrentquery。 + +func Test_dnsConn_exchange_race(t *testing.T) { + r := require.New(t) + wg := new(sync.WaitGroup) + for i := 0; i < 1024; i++ { + c := newDummyEchoNetConn(0.5, time.Millisecond*20, 0.5) + ioOpts := TraditionalDnsConnOpts{ + WithLengthHeader: true, // TODO: Test false as well + IdleTimeout: time.Millisecond * 50, + } + dc := NewDnsConn(ioOpts, c) + for j := 0; j < 24; j++ { + if dc.IsClosed() { + break + } + wg.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + q := new(dns.Msg) + q.SetQuestion("test.", dns.TypeA) + queryPayload, err := q.Pack() + r.NoError(err) + + rec, closed := dc.ReserveNewQuery() + if closed { + return + } + if rec != nil { + _, _ = rec.ExchangeReserved(ctx, queryPayload) + } + }() + } + } + wg.Wait() +} diff --git a/pkg/upstream/transport/pipeline.go b/pkg/upstream/transport/pipeline.go new file mode 100644 index 0000000..c779ba9 --- /dev/null +++ b/pkg/upstream/transport/pipeline.go @@ -0,0 +1,151 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package transport + +import ( + "context" + "sync" + "time" + + "go.uber.org/zap" +) + +// PipelineTransport will pipeline queries as RFC 7766 6.2.1.1 suggested. +// It also can reuse udp socket. Since dns over udp is some kind of "pipeline". +type PipelineTransport struct { + m sync.Mutex // protect following fields + closed bool + conns map[*lazyDnsConn]struct{} + + dialFunc func(ctx context.Context) (DnsConn, error) + dialTimeout time.Duration + maxLazyConnQueue int + logger *zap.Logger // not nil +} + +type PipelineOpts struct { + // DialContext specifies the method to dial a connection to the server. + // DialContext MUST NOT be nil. + DialContext func(ctx context.Context) (DnsConn, error) + + // DialTimeout specifies the timeout for DialFunc. + // Default is defaultDialTimeout. + DialTimeout time.Duration + + // When connection is dialing, how many queries can be queued up in that + // connection. Default is defaultLazyConnMaxConcurrentQuery. + // Note: If the connection turns out having a smaller limit, part of queued up + // queries will fail. + MaxConcurrentQueryWhileDialing int + + Logger *zap.Logger +} + +func NewPipelineTransport(opt PipelineOpts) *PipelineTransport { + t := &PipelineTransport{ + conns: make(map[*lazyDnsConn]struct{}), + } + t.dialFunc = opt.DialContext + setDefaultGZ(&t.dialTimeout, opt.DialTimeout, defaultDialTimeout) + setDefaultGZ(&t.maxLazyConnQueue, opt.MaxConcurrentQueryWhileDialing, defaultMaxLazyConnQueue) + setNonNilLogger(&t.logger, opt.Logger) + + return t +} + +func (t *PipelineTransport) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) { + const maxRetry = 2 + retry := 0 + for { + dc, isNewConn, err := t.getReservedExchanger() + if err != nil { + return nil, err + } + r, err := dc.ExchangeReserved(ctx, m) + if err != nil { + // Reused connection may not stable. + // Try to re-send this query if it failed on a reused connection. + if !isNewConn && retry < maxRetry && ctx.Err() == nil { + retry++ + continue + } + return nil, err + } + return r, nil + } +} + +// Close closes PipelineTransport and all its connections. +// It always returns a nil error. +func (t *PipelineTransport) Close() error { + t.m.Lock() + defer t.m.Unlock() + if t.closed { + return nil + } + t.closed = true + for conn := range t.conns { + conn.Close() + } + return nil +} + +func (t *PipelineTransport) getReservedExchanger() (_ ReservedExchanger, isNewConn bool, err error) { + t.m.Lock() + if t.closed { + err = ErrClosedTransport + t.m.Unlock() + return + } + + var rxc ReservedExchanger + const maxReserveAttempt = 16 + reserveAttempt := 0 + for c := range t.conns { + var closed bool + rxc, closed = c.ReserveNewQuery() + if closed { + delete(t.conns, c) + } + if rxc != nil { + break + } else { + reserveAttempt++ + if reserveAttempt > maxReserveAttempt { + break + } + } + } + + // Dial a new connection + if rxc == nil { + c := newLazyDnsConn(t.dialFunc, t.dialTimeout, t.maxLazyConnQueue, t.logger) + rxc, _ = c.ReserveNewQuery() // ignore the closed error for new lazy connection + isNewConn = true + t.conns[c] = struct{}{} + } + t.m.Unlock() + + if rxc == nil { + isNewConn = false + err = ErrNewConnCannotReserveQueryExchanger + } + return rxc, isNewConn, err +} diff --git a/pkg/upstream/transport/pipeline_test.go b/pkg/upstream/transport/pipeline_test.go new file mode 100644 index 0000000..5c2c077 --- /dev/null +++ b/pkg/upstream/transport/pipeline_test.go @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package transport + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +type dummyEchoDnsConnOpt struct { + exchangeErr error + mcq int + + closed atomic.Bool + + wantConcurrentExchangeCall int + waitingExchangeCall atomic.Int32 + unblockOnce sync.Once + unblockExchange chan struct{} +} + +type dummyEchoDnsConn struct { + opt *dummyEchoDnsConnOpt + + m sync.Mutex + reserved int +} + +func (dc *dummyEchoDnsConn) ExchangeReserved(ctx context.Context, q []byte) (*[]byte, error) { + defer dc.WithdrawReserved() + + if dc.opt.waitingExchangeCall.Add(1) == int32(dc.opt.wantConcurrentExchangeCall) { + dc.opt.unblockOnce.Do(func() { close(dc.opt.unblockExchange) }) + } + defer dc.opt.waitingExchangeCall.Add(-1) + + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case <-dc.opt.unblockExchange: + if dc.opt.exchangeErr != nil { + return nil, dc.opt.exchangeErr + } + return copyMsg(q), nil + } +} + +func (dc *dummyEchoDnsConn) WithdrawReserved() { + dc.m.Lock() + defer dc.m.Unlock() + dc.reserved-- + if dc.reserved < 0 { + panic("negative reserved counter") + } +} + +func (dc *dummyEchoDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) { + if dc.opt.closed.Load() { + return nil, true + } + dc.m.Lock() + defer dc.m.Unlock() + if dc.reserved >= dc.opt.mcq { + return nil, false + } + dc.reserved++ + return dc, false +} + +func (dc *dummyEchoDnsConn) Close() error { + return nil +} + +func Test_PipelineTransport(t *testing.T) { + const ( + mcq = 100 + wantConn = 10 + wantMaxConcurrentExchangeCall = mcq * wantConn + ) + + r := require.New(t) + dcControl := &dummyEchoDnsConnOpt{ + mcq: mcq, + unblockExchange: make(chan struct{}), + wantConcurrentExchangeCall: wantMaxConcurrentExchangeCall, + } + po := PipelineOpts{ + DialContext: func(ctx context.Context) (DnsConn, error) { return &dummyEchoDnsConn{opt: dcControl}, nil }, + MaxConcurrentQueryWhileDialing: mcq, + } + pt := NewPipelineTransport(po) + defer pt.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + q := new(dns.Msg) + q.SetQuestion("test.", dns.TypeA) + queryPayload, err := q.Pack() + r.NoError(err) + wg := new(sync.WaitGroup) + for i := 0; i < wantMaxConcurrentExchangeCall; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := pt.ExchangeContext(ctx, queryPayload) + if err != nil { + t.Error(err) + } + }() + if t.Failed() { + break + } + } + wg.Wait() + + pt.m.Lock() + pl := len(pt.conns) + pt.m.Unlock() + + r.Equal(wantConn, pl) + + dcControl.closed.Store(true) + _, _ = pt.ExchangeContext(ctx, queryPayload) // remove all closed conn + + pt.m.Lock() + pl = len(pt.conns) + pt.m.Unlock() + r.Equal(1, pl, "all connection should be remove then one will be opened") +} diff --git a/pkg/upstream/transport/reuse.go b/pkg/upstream/transport/reuse.go new file mode 100644 index 0000000..b7ea890 --- /dev/null +++ b/pkg/upstream/transport/reuse.go @@ -0,0 +1,341 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package transport + +import ( + "context" + "errors" + "net" + "sync" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "go.uber.org/zap" +) + +const ( + // Most servers will send SERVFAIL after 3~5s. If no resp, connection may be dead. + reuseConnQueryTimeout = time.Second * 6 +) + +// ReuseConnTransport is for old tcp protocol. (no pipelining) +type ReuseConnTransport struct { + dialFunc func(ctx context.Context) (NetConn, error) + dialTimeout time.Duration + idleTimeout time.Duration + logger *zap.Logger // non-nil + ctx context.Context + ctxCancel context.CancelCauseFunc + + m sync.Mutex // protect following fields + closed bool + idleConns map[*reusableConn]struct{} + conns map[*reusableConn]struct{} + + // for testing + testWaitRespTimeout time.Duration +} + +type ReuseConnOpts struct { + // DialContext specifies the method to dial a connection to the server. + // DialContext MUST NOT be nil. + DialContext func(ctx context.Context) (NetConn, error) + + // DialTimeout specifies the timeout for DialFunc. + // Default is defaultDialTimeout. + DialTimeout time.Duration + + // Default is defaultIdleTimeout + IdleTimeout time.Duration + + Logger *zap.Logger +} + +func NewReuseConnTransport(opt ReuseConnOpts) *ReuseConnTransport { + ctx, cancel := context.WithCancelCause(context.Background()) + t := &ReuseConnTransport{ + ctx: ctx, + ctxCancel: cancel, + idleConns: make(map[*reusableConn]struct{}), + conns: make(map[*reusableConn]struct{}), + } + t.dialFunc = opt.DialContext + setDefaultGZ(&t.dialTimeout, opt.DialTimeout, defaultDialTimeout) + setDefaultGZ(&t.idleTimeout, opt.IdleTimeout, defaultIdleTimeout) + setNonNilLogger(&t.logger, opt.Logger) + + return t +} + +func (t *ReuseConnTransport) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) { + const maxRetry = 2 + + retry := 0 + for { + var isNewConn bool + c, err := t.getIdleConn() + if err != nil { + return nil, err + } + if c == nil { + isNewConn = true + c, err = t.getNewConn(ctx) + if err != nil { + return nil, err + } + } + + queryPayload, err := copyMsgWithLenHdr(m) + if err != nil { + return nil, err + } + + resp, err := c.exchange(ctx, queryPayload) + if err != nil { + if !isNewConn && retry <= maxRetry { + retry++ + continue // retry if c is a reused connection. + } + return nil, err + } + return resp, nil + } +} + +// getNewConn dial a *reusableConn. +// The caller must call releaseReusableConn to release the reusableConn. +func (t *ReuseConnTransport) getNewConn(ctx context.Context) (*reusableConn, error) { + callCtx, cancel := context.WithCancel(ctx) + defer cancel() + + type dialRes struct { + c *reusableConn + err error + } + dialChan := make(chan dialRes) + go func() { + dialCtx, cancelDial := context.WithTimeout(t.ctx, t.dialTimeout) + defer cancelDial() + + var rc *reusableConn + c, err := t.dialFunc(dialCtx) + if err != nil { + t.logger.Check(zap.WarnLevel, "fail to dial reusable conn").Write(zap.Error(err)) + } + if c != nil { + rc = t.newReusableConn(c) + if rc == nil { // transport closed + c.Close() + rc = nil + err = ErrClosedTransport + } + } + + select { + case dialChan <- dialRes{c: rc, err: err}: + case <-callCtx.Done(): // caller canceled getNewConn() call + if rc != nil { // put this conn to pool + t.setIdle(rc) + } + } + }() + + select { + case <-callCtx.Done(): + return nil, context.Cause(ctx) + case <-t.ctx.Done(): + return nil, context.Cause(t.ctx) + case res := <-dialChan: + return res.c, res.err + } +} + +func (t *ReuseConnTransport) setIdle(c *reusableConn) { + t.m.Lock() + defer t.m.Unlock() + if t.closed { + return + } + if _, ok := t.conns[c]; ok { + t.idleConns[c] = struct{}{} + } +} + +// getIdleConn returns a *reusableConn from conn pool, or nil if no conn +// is idle. +// The caller must call releaseReusableConn to release the reusableConn. +func (t *ReuseConnTransport) getIdleConn() (*reusableConn, error) { + t.m.Lock() + defer t.m.Unlock() + if t.closed { + return nil, ErrClosedTransport + } + + for c := range t.idleConns { + delete(t.idleConns, c) + return c, nil + } + return nil, nil +} + +// Close closes ReuseConnTransport and all its connections. +// It always returns a nil error. +func (t *ReuseConnTransport) Close() error { + t.m.Lock() + defer t.m.Unlock() + if t.closed { + return nil + } + t.closed = true + for c := range t.conns { + delete(t.conns, c) + delete(t.idleConns, c) + c.closeWithErrByTransport(ErrClosedTransport) + } + t.ctxCancel(ErrClosedTransport) + return nil +} + +type reusableConn struct { + c NetConn + t *ReuseConnTransport + + m sync.Mutex + waitingResp chan *[]byte + + closeOnce sync.Once + closeNotify chan struct{} + closeErr error +} + +// return nil if transport was closed +func (t *ReuseConnTransport) newReusableConn(c NetConn) *reusableConn { + rc := &reusableConn{ + c: c, + t: t, + closeNotify: make(chan struct{}), + } + + t.m.Lock() + if t.closed { // t was closed. + t.m.Unlock() + return nil + } + t.conns[rc] = struct{}{} + t.m.Unlock() + go rc.readLoop() + return rc +} + +var ( + errUnexpectedResp = errors.New("server misbehaving: unexpected response") +) + +func (c *reusableConn) readLoop() { + for { + resp, err := dnsutils.ReadRawMsgFromTCP(c.c) + if err != nil { + c.closeWithErr(err) + return + } + + c.m.Lock() + respChan := c.waitingResp + c.waitingResp = nil + c.m.Unlock() + + if respChan == nil { + pool.ReleaseBuf(resp) + c.closeWithErr(errUnexpectedResp) + return + } + + // This connection is idled again. + c.c.SetReadDeadline(time.Now().Add(c.t.idleTimeout)) + // Note: calling setIdle before sending resp back to make sure this connection is idle + // before Exchange call returning. Otherwise, Test_ReuseConnTransport may fail. + c.t.setIdle(c) + + select { + case respChan <- resp: + default: + panic("bug: respChan has buffer, we shouldn't reach here") + } + } +} + +func (c *reusableConn) closeWithErr(err error) { + if err == nil { + err = net.ErrClosed + } + c.closeOnce.Do(func() { + c.t.m.Lock() + delete(c.t.conns, c) + delete(c.t.idleConns, c) + c.t.m.Unlock() + + c.closeErr = err + c.c.Close() + close(c.closeNotify) + }) +} + +func (c *reusableConn) closeWithErrByTransport(err error) { + if err == nil { + err = net.ErrClosed + } + c.closeOnce.Do(func() { + c.closeErr = err + c.c.Close() + close(c.closeNotify) + }) +} + +func (c *reusableConn) exchange(ctx context.Context, q *[]byte) (*[]byte, error) { + respChan := make(chan *[]byte, 1) + c.m.Lock() + if c.waitingResp != nil { + c.m.Unlock() + panic("bug: reusableConn: concurrent exchange calls") + } + c.waitingResp = respChan + c.m.Unlock() + + waitRespTimeout := reuseConnQueryTimeout + if c.t.testWaitRespTimeout > 0 { + waitRespTimeout = c.t.testWaitRespTimeout + } + c.c.SetDeadline(time.Now().Add(waitRespTimeout)) + _, err := c.c.Write(*q) + if err != nil { + c.closeWithErr(err) + return nil, err + } + + select { + case resp := <-respChan: + return resp, nil + case <-c.closeNotify: + return nil, c.closeErr + case <-ctx.Done(): + return nil, context.Cause(ctx) + } +} diff --git a/pkg/upstream/transport/reuse_test.go b/pkg/upstream/transport/reuse_test.go new file mode 100644 index 0000000..34648bf --- /dev/null +++ b/pkg/upstream/transport/reuse_test.go @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package transport + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +func Test_ReuseConnTransport(t *testing.T) { + const idleTimeout = time.Second * 5 + r := require.New(t) + + po := ReuseConnOpts{ + DialContext: func(ctx context.Context) (NetConn, error) { + return newDummyEchoNetConn(0, 0, 0), nil + }, + IdleTimeout: idleTimeout, + } + rt := NewReuseConnTransport(po) + defer rt.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + q := new(dns.Msg) + q.SetQuestion("test.", dns.TypeA) + queryPayload, err := q.Pack() + r.NoError(err) + concurrentQueryNum := 10 + for l := 0; l < 4; l++ { + wg := new(sync.WaitGroup) + for i := 0; i < concurrentQueryNum; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := rt.ExchangeContext(ctx, queryPayload) + if err != nil { + t.Error(err) + } + }() + } + wg.Wait() + if t.Failed() { + return + } + } + + rt.m.Lock() + connNum := len(rt.conns) + idledConnNum := len(rt.idleConns) + rt.m.Unlock() + + r.Equal(0, connNum-idledConnNum, "there should be no active conn") + r.Equal(concurrentQueryNum, connNum) + r.Equal(concurrentQueryNum, idledConnNum, "all conn should be in idle status") +} + +func Test_ReuseConnTransport_Read_err_and_close(t *testing.T) { + const idleTimeout = time.Second * 5 + r := require.New(t) + + po := ReuseConnOpts{ + DialContext: func(ctx context.Context) (NetConn, error) { + return newDummyEchoNetConn(1, 0, 0), nil // 100% read err + }, + IdleTimeout: idleTimeout, + } + rt := NewReuseConnTransport(po) + defer rt.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + q := new(dns.Msg) + q.SetQuestion("test.", dns.TypeA) + queryPayload, err := q.Pack() + r.NoError(err) + + wg := new(sync.WaitGroup) + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := rt.ExchangeContext(ctx, queryPayload) + r.Error(err) + }() + if t.Failed() { + return + } + } + wg.Wait() + + rt.m.Lock() + connNum := len(rt.conns) + idledConnNum := len(rt.idleConns) + rt.m.Unlock() + + r.Equal(0, connNum) + r.Equal(0, idledConnNum) +} + +func Test_ReuseConnTransport_conn_lose_and_close(t *testing.T) { + r := require.New(t) + po := ReuseConnOpts{ + DialContext: func(ctx context.Context) (NetConn, error) { + return newDummyEchoNetConn(0, time.Second, 0), nil // 100% read timeout + }, + } + rt := NewReuseConnTransport(po) + defer rt.Close() + rt.testWaitRespTimeout = time.Millisecond * 1 + + q := new(dns.Msg) + q.SetQuestion("test.", dns.TypeA) + queryPayload, err := q.Pack() + r.NoError(err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err = rt.ExchangeContext(ctx, queryPayload) // canceled ctx + r.Error(err) + + time.Sleep(time.Millisecond * 100) + + rt.m.Lock() + connNum := len(rt.conns) + idledConnNum := len(rt.idleConns) + rt.m.Unlock() + + // connection should be closed and removed + r.Equal(0, connNum) + r.Equal(0, idledConnNum) +} diff --git a/pkg/upstream/transport/transport.go b/pkg/upstream/transport/transport.go new file mode 100644 index 0000000..c9d6060 --- /dev/null +++ b/pkg/upstream/transport/transport.go @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package transport + +import ( + "context" + "errors" + "io" + "time" +) + +var ( + ErrClosedTransport = errors.New("transport has been closed") + ErrPayloadOverFlow = errors.New("payload is too large") + ErrNewConnCannotReserveQueryExchanger = errors.New("new connection failed to reserve query exchanger") + ErrLazyConnCannotReserveQueryExchanger = errors.New("lazy connection failed to reserve query exchanger") +) + +const ( + defaultIdleTimeout = time.Second * 10 + defaultDialTimeout = time.Second * 5 + + // If a pipeline connection sent a query but did not see any reply (include replies that + // for other queries) from the server after waitingReplyTimeout. It assumes that + // something goes wrong with the connection or the server. The connection will be closed. + waitingReplyTimeout = time.Second * 10 + + defaultTdcMaxConcurrentQuery = 32 + defaultMaxLazyConnQueue = 16 +) + +// One method MUST be called in ReservedExchanger. +type ReservedExchanger interface { + // ExchangeReserved sends q to the server and returns it's response. + // ExchangeReserved MUST not modify nor keep the q. + // q MUST be a valid dns message. + // resp (if no err) should be released by ReleaseResp(). + ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) + + // WithdrawReserved aborts the query. + WithdrawReserved() +} + +type DnsConn interface { + // ReserveNewQuery reserves a query. It MUST be fast and non-block. If DnsConn + // cannot serve more query due to its capacity, ReserveNewQuery returns nil. + // If DnsConn is closed and can no longer serve more query, returns closed = true. + ReserveNewQuery() (_ ReservedExchanger, closed bool) + io.Closer +} + +type NetConn interface { + io.ReadWriteCloser + SetDeadline(t time.Time) error + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error +} diff --git a/pkg/upstream/transport/utils.go b/pkg/upstream/transport/utils.go new file mode 100644 index 0000000..7b69a20 --- /dev/null +++ b/pkg/upstream/transport/utils.go @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package transport + +import ( + "encoding/binary" + "io" + + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/miekg/dns" + "go.uber.org/zap" + "golang.org/x/exp/constraints" +) + +const ( + dnsHeaderLen = 12 // minimum dns msg size +) + +func copyMsgWithLenHdr(m []byte) (*[]byte, error) { + l := len(m) + if l > dns.MaxMsgSize { + return nil, ErrPayloadOverFlow + } + bp := pool.GetBuf(l + 2) + binary.BigEndian.PutUint16(*bp, uint16(l)) + copy((*bp)[2:], m) + return bp, nil +} + +func copyMsg(m []byte) *[]byte { + bp := pool.GetBuf(len(m)) + copy((*bp), m) + return bp +} + +// readMsgUdp reads dns frame from r. r typically should be a udp connection. +// It uses a 4kb rx buffer and ignores any payload that is too small for a dns msg. +// If no error, the length of payload always >= 12 bytes. +func readMsgUdp(r io.Reader) (*[]byte, error) { + // TODO: Make this configurable? + // 4kb should be enough. + payload := pool.GetBuf(4095) + +readAgain: + n, err := r.Read(*payload) + if err != nil { + pool.ReleaseBuf(payload) + return nil, err + } + if n < dnsHeaderLen { + goto readAgain + } + *payload = (*payload)[:n] + return payload, err +} + +func setDefaultGZ[T constraints.Float | constraints.Integer](i *T, s, d T) { + if s > 0 { + *i = s + } else { + *i = d + } +} + +var nopLogger = zap.NewNop() + +func setNonNilLogger(i **zap.Logger, s *zap.Logger) { + if s != nil { + *i = s + } else { + *i = nopLogger + } +} diff --git a/pkg/upstream/upstream.go b/pkg/upstream/upstream.go new file mode 100644 index 0000000..b91b1b1 --- /dev/null +++ b/pkg/upstream/upstream.go @@ -0,0 +1,609 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package upstream + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "strconv" + "strings" + "time" + + "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/upstream/bootstrap" + "github.com/IrineSistiana/mosdns/v5/pkg/upstream/doh" + "github.com/IrineSistiana/mosdns/v5/pkg/upstream/transport" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" + "go.uber.org/zap" + "golang.org/x/net/http2" + "golang.org/x/net/proxy" +) + +const ( + tlsHandshakeTimeout = time.Second * 3 + + // Maximum number of concurrent queries in one pipeline connection. + // See RFC 7766 7. Response Reordering. + // TODO: Make this configurable? + pipelineConcurrentLimit = 64 +) + +// Upstream represents a DNS upstream. +type Upstream interface { + // ExchangeContext exchanges query message m to the upstream, and returns + // response. It MUST NOT keep or modify m. + // m MUST be a valid dns msg frame. It MUST be at least 12 bytes + // (contain a valid dns header). + ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) + + io.Closer +} + +type Opt struct { + // DialAddr specifies the address the upstream will + // actually dial to in the network layer by overwriting + // the address inferred from upstream url. + // It won't affect high level layers. (e.g. SNI, HTTP HOST header won't be changed). + // Can be an IP or a domain. Port is optional. + // Tips: If the upstream url host is a domain, specific an IP address + // here can skip resolving ip of this domain. + DialAddr string + + // Socks5 specifies the socks5 proxy server that the upstream + // will connect though. + // Not implemented for udp based protocols (aka. dns over udp, http3, quic). + Socks5 string + + // SoMark sets the socket SO_MARK option in unix system. + SoMark int + + // BindToDevice sets the socket SO_BINDTODEVICE option in unix system. + BindToDevice string + + // IdleTimeout specifies the idle timeout for long-connections. + // Default: TCP, DoT: 10s , DoH, DoH3, Quic: 30s. + IdleTimeout time.Duration + + // EnablePipeline enables query pipelining support as RFC 7766 6.2.1.1 suggested. + // Available for TCP, DoT upstream. + // Note: There is no fallback. Make sure the server supports it. + EnablePipeline bool + + // EnableHTTP3 will use HTTP/3 protocol to connect a DoH upstream. (aka DoH3). + // Note: There is no fallback. Make sure the server supports it. + EnableHTTP3 bool + + // Bootstrap specifies a plain dns server to solve the + // upstream server domain address. + // It must be an IP address. Port is optional. + Bootstrap string + + // Bootstrap version. One of 0 (default equals 4), 4, 6. + // TODO: Support dual-stack. + BootstrapVer int + + // TLSConfig specifies the tls.Config that the TLS client will use. + // Available for DoT, DoH, DoQ upstream. + TLSConfig *tls.Config + + // Logger specifies the logger that the upstream will use. + Logger *zap.Logger + + // EventObserver can observe connection events. + // Not implemented for quic based protocol (DoH3, DoQ). + EventObserver EventObserver +} + +// NewUpstream creates a upstream. +// addr has the format of: [protocol://]host[:port][/path]. +// Supported protocol: udp/tcp/tls/https/quic. Default protocol is udp. +// +// Helper protocol: +// - tcp+pipeline/tls+pipeline: Automatically set opt.EnablePipeline to true. +// - h3: Automatically set opt.EnableHTTP3 to true. +func NewUpstream(addr string, opt Opt) (_ Upstream, err error) { + if opt.Logger == nil { + opt.Logger = mlog.Nop() + } + if opt.EventObserver == nil { + opt.EventObserver = nopEO{} + } + + // parse protocol and server addr + if !strings.Contains(addr, "://") { + addr = "udp://" + addr + } + addrURL, err := url.Parse(addr) + if err != nil { + return nil, fmt.Errorf("invalid server address, %w", err) + } + + // Apply helper protocol + switch addrURL.Scheme { + case "tcp+pipeline", "tls+pipeline": + addrURL.Scheme = addrURL.Scheme[:3] + opt.EnablePipeline = true + case "h3": + addrURL.Scheme = "https" + opt.EnableHTTP3 = true + } + + // If host is a ipv6 without port, it will be in []. This will cause err when + // split and join address and port. Try to remove brackets now. + addrUrlHost := tryTrimIpv6Brackets(addrURL.Host) + + dialer := &net.Dialer{ + Control: getSocketControlFunc(socketOpts{ + so_mark: opt.SoMark, + bind_to_device: opt.BindToDevice, + }), + } + + var bootstrapAp netip.AddrPort + if s := opt.Bootstrap; len(s) > 0 { + bootstrapAp, err = parseBootstrapAp(s) + if err != nil { + return nil, fmt.Errorf("invalid bootstrap, %w", err) + } + } + + newUdpAddrResolveFunc := func(defaultPort uint16) (func(ctx context.Context) (*net.UDPAddr, error), error) { + host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) + if err != nil { + return nil, err + } + + if addr, err := netip.ParseAddr(host); err == nil { // host is an ip. + ua := net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port)) + return func(ctx context.Context) (*net.UDPAddr, error) { + return ua, nil + }, nil + } else { // Not an ip, assuming it's a domain name. + if bootstrapAp.IsValid() { + // Bootstrap enabled. + bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger) + if err != nil { + return nil, err + } + + return func(ctx context.Context) (*net.UDPAddr, error) { + s, err := bs.GetAddrPortStr(ctx) + if err != nil { + return nil, fmt.Errorf("bootstrap failed, %w", err) + } + return net.ResolveUDPAddr("udp", s) + }, nil + } else { + // Bootstrap disabled. + dialAddr := joinPort(host, port) + return func(ctx context.Context) (*net.UDPAddr, error) { + return net.ResolveUDPAddr("udp", dialAddr) + }, nil + } + } + } + + newTcpDialer := func(dialAddrMustBeIp bool, defaultPort uint16) (func(ctx context.Context) (net.Conn, error), error) { + host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) + if err != nil { + return nil, err + } + + // Socks5 enabled. + if s5Addr := opt.Socks5; len(s5Addr) > 0 { + socks5Dialer, err := proxy.SOCKS5("tcp", s5Addr, nil, dialer) + if err != nil { + return nil, fmt.Errorf("failed to init socks5 dialer: %w", err) + } + + contextDialer := socks5Dialer.(proxy.ContextDialer) + dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) + return func(ctx context.Context) (net.Conn, error) { + return contextDialer.DialContext(ctx, "tcp", dialAddr) + }, nil + } + + if _, err := netip.ParseAddr(host); err == nil { + // Host is an ip addr. No need to resolve it. + dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) + return func(ctx context.Context) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", dialAddr) + }, nil + } else { + if dialAddrMustBeIp { + return nil, errors.New("addr must be an ip address") + } + // Host is not an ip addr, assuming it is a domain. + if bootstrapAp.IsValid() { + // Bootstrap enabled. + bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger) + if err != nil { + return nil, err + } + + return func(ctx context.Context) (net.Conn, error) { + dialAddr, err := bs.GetAddrPortStr(ctx) + if err != nil { + return nil, fmt.Errorf("bootstrap failed, %w", err) + } + return dialer.DialContext(ctx, "tcp", dialAddr) + }, nil + } else { + // Bootstrap disabled. + dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) + return func(ctx context.Context) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", dialAddr) + }, nil + } + } + } + + closeIfFuncErr := func(c io.Closer) { + if err != nil { + c.Close() + } + } + + switch addrURL.Scheme { + case "", "udp": + const defaultPort = 53 + const maxConcurrentQueryPreConn = 4096 // Protocol limit is 65535. + host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) + if err != nil { + return nil, err + } + if _, err := netip.ParseAddr(host); err != nil { + return nil, fmt.Errorf("addr must be an ip address, %w", err) + } + dialAddr := joinPort(host, port) + + dialUdpPipeline := func(ctx context.Context) (transport.DnsConn, error) { + c, err := dialer.DialContext(ctx, "udp", dialAddr) + if err != nil { + return nil, err + } + to := transport.TraditionalDnsConnOpts{ + WithLengthHeader: false, + IdleTimeout: time.Minute * 5, + MaxConcurrentQuery: maxConcurrentQueryPreConn, + } + return transport.NewDnsConn(to, wrapConn(c, opt.EventObserver)), nil + } + dialTcpNetConn := func(ctx context.Context) (transport.NetConn, error) { + c, err := dialer.DialContext(ctx, "tcp", dialAddr) + if err != nil { + return nil, err + } + return wrapConn(c, opt.EventObserver), nil + } + + return &udpWithFallback{ + u: transport.NewPipelineTransport(transport.PipelineOpts{ + DialContext: dialUdpPipeline, + MaxConcurrentQueryWhileDialing: maxConcurrentQueryPreConn, + Logger: opt.Logger, + }), + t: transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialTcpNetConn}), + }, nil + case "tcp": + const defaultPort = 53 + tcpDialer, err := newTcpDialer(true, defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init tcp dialer, %w", err) + } + idleTimeout := opt.IdleTimeout + if idleTimeout <= 0 { + idleTimeout = time.Second * 10 + } + + dialNetConn := func(ctx context.Context) (transport.NetConn, error) { + c, err := tcpDialer(ctx) + if err != nil { + return nil, err + } + return wrapConn(c, opt.EventObserver), nil + } + if opt.EnablePipeline { + to := transport.TraditionalDnsConnOpts{ + WithLengthHeader: true, + IdleTimeout: idleTimeout, + MaxConcurrentQuery: pipelineConcurrentLimit, + } + dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) { + c, err := dialNetConn(ctx) + if err != nil { + return nil, err + } + return transport.NewDnsConn(to, c), nil + } + return transport.NewPipelineTransport(transport.PipelineOpts{ + DialContext: dialDnsConn, + MaxConcurrentQueryWhileDialing: pipelineConcurrentLimit, + Logger: opt.Logger, + }), nil + } + return transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialNetConn, IdleTimeout: idleTimeout}), nil + case "tls": + const defaultPort = 853 + tlsConfig := opt.TLSConfig.Clone() + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } + if len(tlsConfig.ServerName) == 0 { + tlsConfig.ServerName = tryRemovePort(addrUrlHost) + } + + tcpDialer, err := newTcpDialer(false, defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init tcp dialer, %w", err) + } + + dialNetConn := func(ctx context.Context) (transport.NetConn, error) { + conn, err := tcpDialer(ctx) + if err != nil { + return nil, err + } + conn = wrapConn(conn, opt.EventObserver) + tlsConn := tls.Client(conn, tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + tlsConn.Close() + return nil, err + } + return wrapConn(tlsConn, opt.EventObserver), nil + } + + if opt.EnablePipeline { + to := transport.TraditionalDnsConnOpts{ + WithLengthHeader: true, + IdleTimeout: opt.IdleTimeout, + MaxConcurrentQuery: pipelineConcurrentLimit, + } + dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) { + c, err := dialNetConn(ctx) + if err != nil { + return nil, err + } + return transport.NewDnsConn(to, c), nil + } + return transport.NewPipelineTransport(transport.PipelineOpts{ + DialContext: dialDnsConn, + MaxConcurrentQueryWhileDialing: pipelineConcurrentLimit, + Logger: opt.Logger, + }), nil + } + return transport.NewReuseConnTransport(transport.ReuseConnOpts{DialContext: dialNetConn}), nil + case "https": + const defaultPort = 443 + + idleConnTimeout := time.Second * 30 + if opt.IdleTimeout > 0 { + idleConnTimeout = opt.IdleTimeout + } + + var t http.RoundTripper + var addonCloser io.Closer + if opt.EnableHTTP3 { + udpBootstrap, err := newUdpAddrResolveFunc(defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err) + } + + lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})} + conn, err := lc.ListenPacket(context.Background(), "udp", "") + if err != nil { + return nil, fmt.Errorf("failed to init udp socket for quic, %w", err) + } + quicTransport := &quic.Transport{ + Conn: conn, + } + quicConfig := newDefaultClientQuicConfig() + quicConfig.MaxIdleTimeout = idleConnTimeout + + defer closeIfFuncErr(quicTransport) + addonCloser = quicTransport + t = &http3.RoundTripper{ + TLSClientConfig: opt.TLSConfig, + QUICConfig: quicConfig, + Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + ua, err := udpBootstrap(ctx) + if err != nil { + return nil, err + } + return quicTransport.DialEarly(ctx, ua, tlsCfg, cfg) + }, + MaxResponseHeaderBytes: 4 * 1024, + } + } else { + tcpDialer, err := newTcpDialer(false, defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init tcp dialer, %w", err) + } + t1 := &http.Transport{ + DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { // overwrite server addr + c, err := tcpDialer(ctx) + c = wrapConn(c, opt.EventObserver) + return c, err + }, + TLSClientConfig: opt.TLSConfig, + TLSHandshakeTimeout: tlsHandshakeTimeout, + IdleConnTimeout: idleConnTimeout, + + // Following opts are for http/1 only. + // MaxConnsPerHost: 2, + // MaxIdleConnsPerHost: 2, + } + + t2, err := http2.ConfigureTransports(t1) + if err != nil { + return nil, fmt.Errorf("failed to upgrade http2 support, %w", err) + } + t2.MaxHeaderListSize = 4 * 1024 + t2.MaxReadFrameSize = 16 * 1024 + t2.ReadIdleTimeout = time.Second * 30 + t2.PingTimeout = time.Second * 5 + t = t1 + } + + u, err := doh.NewUpstream(addrURL.String(), t, opt.Logger) + if err != nil { + return nil, fmt.Errorf("failed to create doh upstream, %w", err) + } + + return &dohWithClose{ + u: u, + closer: addonCloser, + }, nil + case "quic", "doq": + const defaultPort = 853 + tlsConfig := opt.TLSConfig.Clone() + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } + if len(tlsConfig.ServerName) == 0 { + tlsConfig.ServerName = tryRemovePort(addrUrlHost) + } + tlsConfig.NextProtos = []string{"doq"} + + quicConfig := newDefaultClientQuicConfig() + if opt.IdleTimeout > 0 { + quicConfig.MaxIdleTimeout = opt.IdleTimeout + } + // Don't accept stream. + quicConfig.MaxIncomingStreams = -1 + quicConfig.MaxIncomingUniStreams = -1 + + udpBootstrap, err := newUdpAddrResolveFunc(defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err) + } + + srk, _, err := utils.InitQUICSrkFromIfaceMac() + if err != nil { + opt.Logger.Warn("failed to init quic stateless reset key, it will be disabled", zap.Error(err)) + } + + lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})} + uc, err := lc.ListenPacket(context.Background(), "udp", "") + if err != nil { + return nil, fmt.Errorf("failed to init udp socket for quic, %w", err) + } + + t := &quic.Transport{ + Conn: uc, + StatelessResetKey: (*quic.StatelessResetKey)(srk), + } + + dialDnsConn := func(ctx context.Context) (transport.DnsConn, error) { + ua, err := udpBootstrap(ctx) + if err != nil { + return nil, fmt.Errorf("bootstrap failed, %w", err) + } + + // This is a workaround to + // 1. recover from strange 0rtt rejected err. + // 2. avoid NextConnection might block forever. + // TODO: Remove this workaround. + var c quic.Connection + ec, err := t.DialEarly(ctx, ua, tlsConfig, quicConfig) + if err != nil { + return nil, err + } + c, err = ec.NextConnection(ctx) + if err != nil { + return nil, err + } + return transport.NewQuicDnsConn(c), nil + } + + return transport.NewPipelineTransport(transport.PipelineOpts{ + DialContext: dialDnsConn, + // Quic rfc recommendation is 100. Some implications use 65535. + MaxConcurrentQueryWhileDialing: 90, + Logger: opt.Logger, + }), nil + default: + return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme) + } +} + +type udpWithFallback struct { + u *transport.PipelineTransport + t *transport.ReuseConnTransport +} + +func (u *udpWithFallback) ExchangeContext(ctx context.Context, q []byte) (*[]byte, error) { + r, err := u.u.ExchangeContext(ctx, q) + if err != nil { + return nil, err + } + if msgTruncated(*r) { + pool.ReleaseBuf(r) + return u.t.ExchangeContext(ctx, q) + } + return r, nil +} + +func (u *udpWithFallback) Close() error { + u.u.Close() + u.t.Close() + return nil +} + +type dohWithClose struct { + u *doh.Upstream + closer io.Closer // maybe nil +} + +func (u *dohWithClose) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) { + return u.u.ExchangeContext(ctx, m) +} + +func (u *dohWithClose) Close() error { + if u.closer != nil { + return u.closer.Close() + } + return nil +} + +func newDefaultClientQuicConfig() *quic.Config { + return &quic.Config{ + TokenStore: quic.NewLRUTokenStore(4, 8), + + // Dns does not need large amount of io, so the rx/tx windows are small. + InitialStreamReceiveWindow: 4 * 1024, + MaxStreamReceiveWindow: 4 * 1024, + InitialConnectionReceiveWindow: 8 * 1024, + MaxConnectionReceiveWindow: 64 * 1024, + + MaxIdleTimeout: time.Second * 30, + KeepAlivePeriod: time.Second * 25, + HandshakeIdleTimeout: tlsHandshakeTimeout, + } +} diff --git a/pkg/upstream/upstream_test.go b/pkg/upstream/upstream_test.go new file mode 100644 index 0000000..f80ce0f --- /dev/null +++ b/pkg/upstream/upstream_test.go @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package upstream + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/miekg/dns" +) + +func newUDPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) { + udpConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + udpAddr := udpConn.LocalAddr().String() + udpServer := dns.Server{ + PacketConn: udpConn, + Handler: handler, + } + go udpServer.ActivateAndServe() + return udpAddr, func() { + udpServer.Shutdown() + } +} + +func newTCPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + tcpAddr := l.Addr().String() + tcpServer := dns.Server{ + Listener: l, + Handler: handler, + MaxTCPQueries: -1, + } + go tcpServer.ActivateAndServe() + return tcpAddr, func() { + tcpServer.Shutdown() + } +} + +func newDoTTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) { + serverName := "test" + cert, err := utils.GenerateCertificate(serverName) + if err != nil { + t.Fatal(err) + } + tlsConfig := new(tls.Config) + tlsConfig.Certificates = []tls.Certificate{cert} + tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig) + if err != nil { + t.Fatal(err) + } + doTAddr := tlsListener.Addr().String() + doTServer := dns.Server{ + Net: "tcp-tls", + Listener: tlsListener, + TLSConfig: tlsConfig, + Handler: handler, + MaxTCPQueries: -1, + } + go doTServer.ActivateAndServe() + return doTAddr, func() { + doTServer.Shutdown() + } +} + +type newTestServerFunc func(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) + +var m = map[string]newTestServerFunc{ + "udp": newUDPTestServer, + "tcp": newTCPTestServer, + "tls": newDoTTestServer, +} + +func Test_fastUpstream(t *testing.T) { + + // TODO: add test for doh + // TODO: add test for socks5 + + // server config + for scheme, f := range m { + for _, bigMsg := range [...]bool{true, false} { + for _, latency := range [...]time.Duration{0, time.Millisecond * 10} { + + // client specific + for _, idleTimeout := range [...]time.Duration{0, time.Second} { + + testName := fmt.Sprintf( + "test: protocol: %s, bigMsg: %v, latency: %s, getIdleTimeout: %s", + scheme, + bigMsg, + latency, + idleTimeout, + ) + + t.Run(testName, func(t *testing.T) { + addr, shutdownServer := f(t, &vServer{ + latency: latency, + bigMsg: bigMsg, + }) + defer shutdownServer() + u, err := NewUpstream( + scheme+"://"+addr, + Opt{ + IdleTimeout: time.Second, + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + }, + ) + if err != nil { + t.Fatal(err) + } + + if err := testUpstream(u); err != nil { + t.Fatal(err) + } + }) + } + } + } + + } +} + +func testUpstream(u Upstream) error { + wg := sync.WaitGroup{} + errs := make([]error, 0) + errsLock := sync.Mutex{} + logErr := func(err error) { + errsLock.Lock() + errs = append(errs, err) + errsLock.Unlock() + } + errsToString := func() string { + s := fmt.Sprintf("%d err(s) occured during the test: ", len(errs)) + for i := range errs { + s = s + errs[i].Error() + "|" + } + return s + } + + for i := uint16(0); i < 10; i++ { + wg.Add(1) + i := i + go func() { + defer wg.Done() + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + q.Id = i + queryPayload, err := q.Pack() + if err != nil { + logErr(err) + return + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + r, err := u.ExchangeContext(ctx, queryPayload) + if err != nil { + logErr(err) + return + } + + resp := new(dns.Msg) + err = resp.Unpack(*r) + if err != nil { + logErr(err) + return + } + if q.Id != resp.Id { + logErr(dns.ErrId) + return + } + if !resp.Response { + logErr(fmt.Errorf("resp is not a resp bit")) + return + } + }() + } + + wg.Wait() + if len(errs) != 0 { + return errors.New(errsToString()) + } + return nil +} + +type vServer struct { + latency time.Duration + bigMsg bool // with 1kb padding +} + +var padding = make([]byte, 1024) + +func (s *vServer) ServeDNS(w dns.ResponseWriter, q *dns.Msg) { + r := new(dns.Msg) + r.SetReply(q) + if s.bigMsg { + r.SetEdns0(dns.MaxMsgSize, false) + opt := r.IsEdns0() + opt.Option = append(opt.Option, &dns.EDNS0_PADDING{Padding: padding}) + } + + time.Sleep(s.latency) + w.WriteMsg(r) +} diff --git a/pkg/upstream/utils.go b/pkg/upstream/utils.go new file mode 100644 index 0000000..2052733 --- /dev/null +++ b/pkg/upstream/utils.go @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package upstream + +import ( + "fmt" + "net" + "net/netip" + "strconv" +) + +type socketOpts struct { + so_mark int + bind_to_device string +} + +func parseDialAddr(urlHost, dialAddr string, defaultPort uint16) (string, uint16, error) { + addr := urlHost + if len(dialAddr) > 0 { + addr = dialAddr + } + host, port, err := trySplitHostPort(addr) + if err != nil { + return "", 0, err + } + if port == 0 { + port = defaultPort + } + return host, port, nil +} + +func joinPort(host string, port uint16) string { + return net.JoinHostPort(host, strconv.Itoa(int(port))) +} + +func tryRemovePort(s string) string { + host, _, err := net.SplitHostPort(s) + if err != nil { + return s + } + return host +} + +// trySplitHostPort splits host and port. +// If s has no port, it returns s,0,nil +func trySplitHostPort(s string) (string, uint16, error) { + var port uint16 + host, portS, err := net.SplitHostPort(s) + if err == nil { + n, err := strconv.ParseUint(portS, 10, 16) + if err != nil { + return "", 0, fmt.Errorf("invalid port, %w", err) + } + port = uint16(n) + return host, port, nil + } + return s, 0, nil +} + +func parseBootstrapAp(s string) (netip.AddrPort, error) { + host, port, err := trySplitHostPort(s) + if err != nil { + return netip.AddrPort{}, err + } + if port == 0 { + port = 53 + } + addr, err := netip.ParseAddr(host) + if err != nil { + return netip.AddrPort{}, err + } + return netip.AddrPortFrom(addr, port), nil +} + +func tryTrimIpv6Brackets(s string) string { + if len(s) < 2 { + return s + } + if s[0] == '[' && s[len(s)-1] == ']' { + return s[1 : len(s)-2] + } + return s +} + +func msgTruncated(b []byte) bool { + return b[2]&(1<<1) != 0 +} diff --git a/pkg/upstream/utils_others.go b/pkg/upstream/utils_others.go new file mode 100644 index 0000000..0cb9abd --- /dev/null +++ b/pkg/upstream/utils_others.go @@ -0,0 +1,28 @@ +//go:build !linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package upstream + +import "syscall" + +func getSocketControlFunc(_ socketOpts) func(string, string, syscall.RawConn) error { + return nil +} diff --git a/pkg/upstream/utils_unix.go b/pkg/upstream/utils_unix.go new file mode 100644 index 0000000..685a82b --- /dev/null +++ b/pkg/upstream/utils_unix.go @@ -0,0 +1,58 @@ +//go:build linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package upstream + +import ( + "os" + "syscall" + + "golang.org/x/sys/unix" +) + +func getSocketControlFunc(opts socketOpts) func(string, string, syscall.RawConn) error { + return func(_, _ string, c syscall.RawConn) error { + var sysCallErr error + if err := c.Control(func(fd uintptr) { + // SO_MARK + if opts.so_mark > 0 { + sysCallErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.so_mark) + if sysCallErr != nil { + sysCallErr = os.NewSyscallError("failed to set SO_MARK", sysCallErr) + return + } + } + + // SO_BINDTODEVICE + if len(opts.bind_to_device) > 0 { + sysCallErr = unix.SetsockoptString(int(fd), unix.SOL_SOCKET, unix.SO_BINDTODEVICE, opts.bind_to_device) + if sysCallErr != nil { + sysCallErr = os.NewSyscallError("failed to set SO_BINDTODEVICE", sysCallErr) + return + } + } + + }); err != nil { + return err + } + return sysCallErr + } +} diff --git a/pkg/utils/config_helper.go b/pkg/utils/config_helper.go new file mode 100644 index 0000000..ae26819 --- /dev/null +++ b/pkg/utils/config_helper.go @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package utils + +import ( + "github.com/mitchellh/mapstructure" + "golang.org/x/exp/constraints" + "strconv" +) + +func SetDefaultNum[K constraints.Integer | constraints.Float](p *K, d K) { + if *p == 0 { + *p = d + } +} + +func SetDefaultUnsignNum[K constraints.Integer | constraints.Float](p *K, d K) { + if *p <= 0 { + *p = d + } +} + +func SetDefaultString(p *string, d string) { + if len(*p) == 0 { + *p = d + } +} + +func CheckNumRange[K constraints.Integer | constraints.Float](v, min, max K) bool { + if v < min || v > max { + return false + } + return true +} + +// WeakDecode decodes args from config to output. +func WeakDecode(in any, output any) error { + config := &mapstructure.DecoderConfig{ + ErrorUnused: true, + Result: output, + WeaklyTypedInput: true, + TagName: "yaml", + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(in) +} + +func ParseNameOrNum[T constraints.Integer](s string, m map[string]T) (T, bool) { + i, err := strconv.Atoi(s) + if err != nil { + v, ok := m[s] + return v, ok + } + return T(i), true +} diff --git a/pkg/utils/net.go b/pkg/utils/net.go new file mode 100644 index 0000000..bc82722 --- /dev/null +++ b/pkg/utils/net.go @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package utils + +import ( + "net" + "net/netip" +) + +// GetIPFromAddr returns a net.IP from the given net.Addr. +// addr can be *net.TCPAddr, *net.UDPAddr, *net.IPNet, *net.IPAddr +// Will return nil otherwise. +func GetIPFromAddr(addr net.Addr) (ip net.IP) { + switch v := addr.(type) { + case *net.TCPAddr: + return v.IP + case *net.UDPAddr: + return v.IP + case *net.IPNet: + return v.IP + case *net.IPAddr: + return v.IP + } + return nil +} + +// GetAddrFromAddr returns netip.Addr from net.Addr. +// See also: GetIPFromAddr. +func GetAddrFromAddr(addr net.Addr) netip.Addr { + a, _ := netip.AddrFromSlice(GetIPFromAddr(addr)) + return a +} + +// SplitSchemeAndHost splits addr to protocol and host. +func SplitSchemeAndHost(addr string) (protocol, host string) { + if protocol, host, ok := SplitString2(addr, "://"); ok { + return protocol, host + } else { + return "", addr + } +} diff --git a/pkg/utils/quic.go b/pkg/utils/quic.go new file mode 100644 index 0000000..f49ed48 --- /dev/null +++ b/pkg/utils/quic.go @@ -0,0 +1,52 @@ +package utils + +import ( + "crypto/sha256" + "errors" + "net" + "sync" +) + +var quicSrkSalt = []byte{115, 189, 156, 229, 145, 216, 251, 127, 220, 89, + 243, 234, 211, 79, 190, 166, 135, 253, 183, 36, 245, 174, 78, 200, 54, 213, + 85, 255, 104, 240, 103, 27} + +var ( + quicSrkInitOnce sync.Once + quicSrk *[32]byte + quicSrkFromIface net.Interface + quicSrkInitErr error +) + +func initQUICSrkFromIfaceMac() { + nonZero := func(b []byte) bool { + for _, i := range b { + if i != 0 { + return true + } + } + return false + } + + ifaces, err := net.Interfaces() + if err != nil { + quicSrkInitErr = err + return + } + for _, iface := range ifaces { + if nonZero(iface.HardwareAddr) { + k := sha256.Sum256(append(iface.HardwareAddr, quicSrkSalt...)) + quicSrk = &k + quicSrkFromIface = iface + return + } + } + quicSrkInitErr = errors.New("cannot find non-zero mac interface") +} + +// A helper func to init quic stateless reset key. +// It use the first non-zero interface mac + sha256 hash. +func InitQUICSrkFromIfaceMac() (*[32]byte, net.Interface, error) { + quicSrkInitOnce.Do(initQUICSrkFromIfaceMac) + return quicSrk, quicSrkFromIface, quicSrkInitErr +} diff --git a/pkg/utils/server.go b/pkg/utils/server.go new file mode 100644 index 0000000..f19bed7 --- /dev/null +++ b/pkg/utils/server.go @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package utils diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go new file mode 100644 index 0000000..23471c2 --- /dev/null +++ b/pkg/utils/strings.go @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package utils + +import ( + "strings" + "unsafe" +) + +// BytesToStringUnsafe converts bytes to string. +func BytesToStringUnsafe(b []byte) string { + return unsafe.String(unsafe.SliceData(b), len(b)) +} + +// RemoveComment removes comment after "symbol". +func RemoveComment(s, symbol string) string { + if i := strings.Index(s, symbol); i >= 0 { + return s[:i] + } + return s +} + +// SplitString2 split s to two parts by given symbol +func SplitString2(s, symbol string) (s1 string, s2 string, ok bool) { + if len(symbol) == 0 { + return "", s, true + } + if i := strings.Index(s, symbol); i >= 0 { + return s[:i], s[i+len(symbol):], true + } + return "", "", false +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 0000000..b2b68e9 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package utils + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "time" +) + +// LoadCertPool reads and loads certificates in certs. +func LoadCertPool(certs []string) (*x509.CertPool, error) { + rootCAs := x509.NewCertPool() + for _, cert := range certs { + b, err := os.ReadFile(cert) + if err != nil { + return nil, err + } + + if ok := rootCAs.AppendCertsFromPEM(b); !ok { + return nil, fmt.Errorf("no certificate was successfully parsed in %s", cert) + } + } + return rootCAs, nil +} + +// GenerateCertificate generates an ecdsa certificate with given dnsName. +// This should only use in test. +func GenerateCertificate(dnsName string) (cert tls.Certificate, err error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return + } + + //serial number + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + err = fmt.Errorf("generate serial number: %w", err) + return + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{CommonName: dnsName}, + DNSNames: []string{dnsName}, + + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + return + } + b, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + return + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + return tls.X509KeyPair(certPEM, keyPEM) +} + +// ClosedChan returns true if c is closed. +// c must not use for sending data and must be used in close() only. +// If ClosedChan receives something from c, it panics. +func ClosedChan(c chan struct{}) bool { + select { + case _, ok := <-c: + if !ok { + return true + } + panic("received from the chan") + default: + return false + } +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 0000000..0bfbe23 --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package utils + +import ( + "reflect" + "testing" +) + +func TestSplitString2(t *testing.T) { + type args struct { + s string + symbol string + } + tests := []struct { + name string + args args + wantS1 string + wantS2 string + wantOk bool + }{ + {"blank", args{"", ""}, "", "", true}, + {"blank", args{"///", ""}, "", "///", true}, + {"split", args{"///", "/"}, "", "//", true}, + {"split", args{"--/", "/"}, "--", "", true}, + {"split", args{"https://***.***.***", "://"}, "https", "***.***.***", true}, + {"split", args{"://***.***.***", "://"}, "", "***.***.***", true}, + {"split", args{"https://", "://"}, "https", "", true}, + {"split", args{"--/", "*"}, "", "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotS1, gotS2, gotOk := SplitString2(tt.args.s, tt.args.symbol) + if gotS1 != tt.wantS1 { + t.Errorf("SplitString2() gotS1 = %v, want %v", gotS1, tt.wantS1) + } + if gotS2 != tt.wantS2 { + t.Errorf("SplitString2() gotS2 = %v, want %v", gotS2, tt.wantS2) + } + if gotOk != tt.wantOk { + t.Errorf("SplitString2() gotOk = %v, want %v", gotOk, tt.wantOk) + } + }) + } +} + +func TestRemoveComment(t *testing.T) { + type args struct { + s string + symbol string + } + tests := []struct { + name string + args args + want string + }{ + {name: "empty", args: args{s: "", symbol: ""}, want: ""}, + {name: "empty symbol", args: args{s: "12345", symbol: ""}, want: ""}, + {name: "empty string", args: args{s: "", symbol: "#"}, want: ""}, + {name: "remove 1", args: args{s: "123/456", symbol: "/"}, want: "123"}, + {name: "remove 2", args: args{s: "123//456", symbol: "//"}, want: "123"}, + {name: "remove 3", args: args{s: "123/*/456", symbol: "//"}, want: "123/*/456"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := RemoveComment(tt.args.s, tt.args.symbol); got != tt.want { + t.Errorf("RemoveComment() = %v, want %v", got, tt.want) + } + }) + } +} + +type TestArgsStruct struct { + A string `yaml:"1"` + B []int `yaml:"2"` +} + +func Test_WeakDecode(t *testing.T) { + testObj := new(TestArgsStruct) + testArgs := map[string]any{ + "1": "test", + "2": []int{1, 2, 3}, + } + wantObj := &TestArgsStruct{ + A: "test", + B: []int{1, 2, 3}, + } + + err := WeakDecode(testArgs, testObj) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(testObj, wantObj) { + t.Fatalf("args decode failed, want %v, got %v", wantObj, testObj) + } +} + +func Test_WeakDecode2(t *testing.T) { + testObj := new([]byte) + args := []any{"1", 2, 3} + err := WeakDecode(args, testObj) + if err != nil { + t.Fatal(err) + } +} diff --git a/pkg/zone_file/zone_file.go b/pkg/zone_file/zone_file.go new file mode 100644 index 0000000..f30fbf7 --- /dev/null +++ b/pkg/zone_file/zone_file.go @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package zone_file + +import ( + "io" + "os" + "strings" + + "github.com/miekg/dns" +) + +type Matcher struct { + m map[dns.Question][]dns.RR +} + +func (m *Matcher) LoadFile(s string) error { + f, err := os.Open(s) + if err != nil { + return err + } + defer f.Close() + + return m.Load(f) +} + +func (m *Matcher) Load(r io.Reader) error { + if m.m == nil { + m.m = make(map[dns.Question][]dns.RR) + } + + parser := dns.NewZoneParser(r, "", "") + parser.SetDefaultTTL(3600) + for { + rr, ok := parser.Next() + if !ok { + break + } + h := rr.Header() + q := dns.Question{ + Name: strings.ToLower(h.Name), + Qtype: h.Rrtype, + Qclass: h.Class, + } + m.m[q] = append(m.m[q], rr) + } + return parser.Err() +} + +func (m *Matcher) Search(q dns.Question) []dns.RR { + q.Name = strings.ToLower(q.Name) + return m.m[q] +} + +func (m *Matcher) Reply(q *dns.Msg) *dns.Msg { + var r *dns.Msg + for _, question := range q.Question { + rr := m.Search(question) + if rr != nil { + if r == nil { + r = new(dns.Msg) + r.SetReply(q) + } + r.Answer = append(r.Answer, rr...) + } + } + return r +} diff --git a/pkg/zone_file/zone_file_test.go b/pkg/zone_file/zone_file_test.go new file mode 100644 index 0000000..3178dbb --- /dev/null +++ b/pkg/zone_file/zone_file_test.go @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package zone_file + +import ( + "github.com/miekg/dns" + "strings" + "testing" +) + +const data = ` +$TTL 3600 +example.com. IN A 192.0.2.1 +1.example.com. IN AAAA 2001:db8:10::1 +` + +func TestMatcher(t *testing.T) { + m := new(Matcher) + err := m.Load(strings.NewReader(data)) + if err != nil { + t.Fatal(err) + } + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + r := m.Reply(q) + if r == nil { + t.Fatal("search failed") + } + if got := r.Answer[0].(*dns.A).A.String(); got != "192.0.2.1" { + t.Fatalf("want ip 192.0.2.1, got %s", got) + } + + q = new(dns.Msg) + q.SetQuestion("1.example.com.", dns.TypeAAAA) + r = m.Reply(q) + if r == nil { + t.Fatal("search failed") + } + if got := r.Answer[0].(*dns.AAAA).AAAA.String(); got != "2001:db8:10::1" { + t.Fatalf("want ip 2001:db8:10::1, got %s", got) + } +} diff --git a/plugin/data_provider/domain_set/domain_set.go b/plugin/data_provider/domain_set/domain_set.go new file mode 100644 index 0000000..20d7f90 --- /dev/null +++ b/plugin/data_provider/domain_set/domain_set.go @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package domain_set + +import ( + "bytes" + "fmt" + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/v5/plugin/data_provider" + "os" +) + +const PluginType = "domain_set" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +func Init(bp *coremain.BP, args any) (any, error) { + m, err := NewDomainSet(bp, args.(*Args)) + if err != nil { + return nil, err + } + return m, nil +} + +type Args struct { + Exps []string `yaml:"exps"` + Sets []string `yaml:"sets"` + Files []string `yaml:"files"` +} + +var _ data_provider.DomainMatcherProvider = (*DomainSet)(nil) + +type DomainSet struct { + mg []domain.Matcher[struct{}] +} + +func (d *DomainSet) GetDomainMatcher() domain.Matcher[struct{}] { + return MatcherGroup(d.mg) +} + +// NewDomainSet inits a DomainSet from given args. +func NewDomainSet(bp *coremain.BP, args *Args) (*DomainSet, error) { + ds := &DomainSet{} + + m := domain.NewDomainMixMatcher() + if err := LoadExpsAndFiles(args.Exps, args.Files, m); err != nil { + return nil, err + } + if m.Len() > 0 { + ds.mg = append(ds.mg, m) + } + + for _, tag := range args.Sets { + provider, _ := bp.M().GetPlugin(tag).(data_provider.DomainMatcherProvider) + if provider == nil { + return nil, fmt.Errorf("%s is not a DomainMatcherProvider", tag) + } + m := provider.GetDomainMatcher() + ds.mg = append(ds.mg, m) + } + return ds, nil +} + +func LoadExpsAndFiles(exps []string, fs []string, m *domain.MixMatcher[struct{}]) error { + if err := LoadExps(exps, m); err != nil { + return err + } + if err := LoadFiles(fs, m); err != nil { + return err + } + return nil +} + +func LoadExps(exps []string, m *domain.MixMatcher[struct{}]) error { + for i, exp := range exps { + if err := m.Add(exp, struct{}{}); err != nil { + return fmt.Errorf("failed to load expression #%d %s, %w", i, exp, err) + } + } + return nil +} + +func LoadFiles(fs []string, m *domain.MixMatcher[struct{}]) error { + for i, f := range fs { + if err := LoadFile(f, m); err != nil { + return fmt.Errorf("failed to load file #%d %s, %w", i, f, err) + } + } + return nil +} + +func LoadFile(f string, m *domain.MixMatcher[struct{}]) error { + if len(f) > 0 { + b, err := os.ReadFile(f) + if err != nil { + return err + } + + if err := domain.LoadFromTextReader[struct{}](m, bytes.NewReader(b), nil); err != nil { + return err + } + } + return nil +} diff --git a/plugin/data_provider/domain_set/group.go b/plugin/data_provider/domain_set/group.go new file mode 100644 index 0000000..43d1a88 --- /dev/null +++ b/plugin/data_provider/domain_set/group.go @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package domain_set + +import "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + +type MatcherGroup []domain.Matcher[struct{}] + +func (mg MatcherGroup) Match(s string) (struct{}, bool) { + for _, m := range mg { + if _, ok := m.Match(s); ok { + return struct{}{}, true + } + } + return struct{}{}, false +} diff --git a/plugin/data_provider/iface.go b/plugin/data_provider/iface.go new file mode 100644 index 0000000..7ec08d2 --- /dev/null +++ b/plugin/data_provider/iface.go @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package data_provider + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist" +) + +type DomainMatcherProvider interface { + GetDomainMatcher() domain.Matcher[struct{}] +} + +type IPMatcherProvider interface { + GetIPMatcher() netlist.Matcher +} diff --git a/plugin/data_provider/ip_set/ip_set.go b/plugin/data_provider/ip_set/ip_set.go new file mode 100644 index 0000000..dc06594 --- /dev/null +++ b/plugin/data_provider/ip_set/ip_set.go @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ip_set + +import ( + "bytes" + "fmt" + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/v5/plugin/data_provider" + "net/netip" + "os" + "strings" +) + +const PluginType = "ip_set" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +func Init(bp *coremain.BP, args any) (any, error) { + return NewIPSet(bp, args.(*Args)) +} + +type Args struct { + IPs []string `yaml:"ips"` + Sets []string `yaml:"sets"` + Files []string `yaml:"files"` +} + +var _ data_provider.IPMatcherProvider = (*IPSet)(nil) + +type IPSet struct { + mg []netlist.Matcher +} + +func (d *IPSet) GetIPMatcher() netlist.Matcher { + return MatcherGroup(d.mg) +} + +func NewIPSet(bp *coremain.BP, args *Args) (*IPSet, error) { + p := &IPSet{} + + l := netlist.NewList() + if err := LoadFromIPsAndFiles(args.IPs, args.Files, l); err != nil { + return nil, err + } + l.Sort() + if l.Len() > 0 { + p.mg = append(p.mg, l) + } + for _, tag := range args.Sets { + provider, _ := bp.M().GetPlugin(tag).(data_provider.IPMatcherProvider) + if provider == nil { + return nil, fmt.Errorf("%s is not an IPMatcherProvider", tag) + } + p.mg = append(p.mg, provider.GetIPMatcher()) + } + return p, nil +} + +func parseNetipPrefix(s string) (netip.Prefix, error) { + if strings.ContainsRune(s, '/') { + return netip.ParsePrefix(s) + } + addr, err := netip.ParseAddr(s) + if err != nil { + return netip.Prefix{}, err + } + return addr.Prefix(addr.BitLen()) +} + +func LoadFromIPsAndFiles(ips []string, fs []string, l *netlist.List) error { + if err := LoadFromIPs(ips, l); err != nil { + return err + } + if err := LoadFromFiles(fs, l); err != nil { + return err + } + return nil +} + +func LoadFromIPs(ips []string, l *netlist.List) error { + for i, s := range ips { + p, err := parseNetipPrefix(s) + if err != nil { + return fmt.Errorf("invalid ip #%d %s, %w", i, s, err) + } + l.Append(p) + } + return nil +} + +func LoadFromFiles(fs []string, l *netlist.List) error { + for i, f := range fs { + if err := LoadFromFile(f, l); err != nil { + return fmt.Errorf("failed to load file #%d %s, %w", i, f, err) + } + } + return nil +} + +func LoadFromFile(f string, l *netlist.List) error { + if len(f) > 0 { + b, err := os.ReadFile(f) + if err != nil { + return err + } + if err := netlist.LoadFromReader(l, bytes.NewReader(b)); err != nil { + return err + } + } + return nil +} + +type MatcherGroup []netlist.Matcher + +func (mg MatcherGroup) Match(addr netip.Addr) bool { + for _, m := range mg { + if m.Match(addr) { + return true + } + } + return false +} diff --git a/plugin/enabled_plugin_test.go b/plugin/enabled_plugin_test.go new file mode 100644 index 0000000..ce8a48d --- /dev/null +++ b/plugin/enabled_plugin_test.go @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package plugin + +import "testing" + +// This is an empty test, but it can run all init() of enabled plugins. +func Test_plugins_init(t *testing.T) { + +} diff --git a/plugin/enabled_plugins.go b/plugin/enabled_plugins.go new file mode 100644 index 0000000..c5f47ca --- /dev/null +++ b/plugin/enabled_plugins.go @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package plugin + +// data providers +import ( + // data provider + _ "github.com/IrineSistiana/mosdns/v5/plugin/data_provider/domain_set" + _ "github.com/IrineSistiana/mosdns/v5/plugin/data_provider/ip_set" + + // matcher + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/client_ip" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/cname" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/env" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/has_resp" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/has_wanted_ans" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/ptr_ip" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/qclass" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/qname" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/qtype" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/random" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/rcode" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/resp_ip" + _ "github.com/IrineSistiana/mosdns/v5/plugin/matcher/string_exp" + + // executable + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/arbitrary" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/black_hole" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/cache" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/debug_print" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/drop_resp" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/dual_selector" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/ecs_handler" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/forward" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/forward_edns0opt" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/hosts" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/ipset" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/metrics_collector" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/mikrotik_addresslist" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/nftset" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/query_summary" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/rate_limiter" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/redirect" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/reverse_lookup" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence/fallback" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/sleep" + _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/ttl" + + // executable and matcher + _ "github.com/IrineSistiana/mosdns/v5/plugin/mark" + + // server + _ "github.com/IrineSistiana/mosdns/v5/plugin/server/http_server" + _ "github.com/IrineSistiana/mosdns/v5/plugin/server/quic_server" + _ "github.com/IrineSistiana/mosdns/v5/plugin/server/tcp_server" + _ "github.com/IrineSistiana/mosdns/v5/plugin/server/udp_server" +) diff --git a/plugin/executable/arbitrary/arbitrary.go b/plugin/executable/arbitrary/arbitrary.go new file mode 100644 index 0000000..90e4010 --- /dev/null +++ b/plugin/executable/arbitrary/arbitrary.go @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package arbitrary + +import ( + "bytes" + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/zone_file" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "os" + "strings" +) + +const PluginType = "arbitrary" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +type Args struct { + Rules []string `yaml:"rules"` + Files []string `yaml:"files"` +} + +var _ sequence.Executable = (*Arbitrary)(nil) + +type Arbitrary struct { + m *zone_file.Matcher +} + +func NewArbitrary(args *Args) (*Arbitrary, error) { + m := new(zone_file.Matcher) + for i, s := range args.Rules { + if err := m.Load(strings.NewReader(s)); err != nil { + return nil, fmt.Errorf("failed to load rr #%d [%s], %w", i, s, err) + } + } + for i, file := range args.Files { + b, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file #%d [%s], %w", i, file, err) + } + if err := m.Load(bytes.NewReader(b)); err != nil { + return nil, fmt.Errorf("failed to load rr file #%d [%s], %w", i, file, err) + } + } + return &Arbitrary{ + m: m, + }, nil +} + +func (a *Arbitrary) Exec(_ context.Context, qCtx *query_context.Context) error { + if r := a.m.Reply(qCtx.Q()); r != nil { + qCtx.SetResponse(r) + } + return nil +} + +func Init(_ *coremain.BP, v any) (any, error) { + args := v.(*Args) + return NewArbitrary(args) +} diff --git a/plugin/executable/black_hole/black_hole.go b/plugin/executable/black_hole/black_hole.go new file mode 100644 index 0000000..775253d --- /dev/null +++ b/plugin/executable/black_hole/black_hole.go @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package black_hole + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "net/netip" + "strings" +) + +const PluginType = "black_hole" + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.Executable = (*BlackHole)(nil) + +type BlackHole struct { + ipv4 []netip.Addr + ipv6 []netip.Addr +} + +// QuickSetup format: [ipv4|ipv6] ... +// Support both ipv4/a and ipv6/aaaa families. +func QuickSetup(_ sequence.BQ, s string) (any, error) { + return NewBlackHole(strings.Fields(s)) +} + +// NewBlackHole creates a new BlackHole with given ips. +func NewBlackHole(ips []string) (*BlackHole, error) { + b := &BlackHole{} + for _, s := range ips { + addr, err := netip.ParseAddr(s) + if err != nil { + return nil, fmt.Errorf("invalid ipv4 addr %s, %w", s, err) + } + if addr.Is4() { + b.ipv4 = append(b.ipv4, addr) + } else { + b.ipv6 = append(b.ipv6, addr) + } + } + return b, nil +} + +// Exec implements sequence.Executable. It set a response with given ips if +// query has corresponding qtypes. +func (b *BlackHole) Exec(_ context.Context, qCtx *query_context.Context) error { + if r := b.Response(qCtx.Q()); r != nil { + qCtx.SetResponse(r) + } + return nil +} + +// Response returns a response with given ips if query has corresponding qtypes. +// Otherwise, it returns nil. +func (b *BlackHole) Response(q *dns.Msg) *dns.Msg { + if len(q.Question) != 1 { + return nil + } + + qName := q.Question[0].Name + qtype := q.Question[0].Qtype + + switch { + case qtype == dns.TypeA && len(b.ipv4) > 0: + r := new(dns.Msg) + r.SetReply(q) + for _, addr := range b.ipv4 { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: qName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: addr.AsSlice(), + } + r.Answer = append(r.Answer, rr) + } + return r + + case qtype == dns.TypeAAAA && len(b.ipv6) > 0: + r := new(dns.Msg) + r.SetReply(q) + for _, addr := range b.ipv6 { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: qName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: addr.AsSlice(), + } + r.Answer = append(r.Answer, rr) + } + return r + } + return nil +} diff --git a/plugin/executable/cache/cache.go b/plugin/executable/cache/cache.go new file mode 100644 index 0000000..bc00112 --- /dev/null +++ b/plugin/executable/cache/cache.go @@ -0,0 +1,481 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package cache + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net/http" + "os" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/cache" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/go-chi/chi/v5" + "github.com/klauspost/compress/gzip" + "github.com/miekg/dns" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" + "google.golang.org/protobuf/proto" +) + +const ( + PluginType = "cache" +) + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) + sequence.MustRegExecQuickSetup(PluginType, quickSetupCache) +} + +const ( + defaultLazyUpdateTimeout = time.Second * 5 + expiredMsgTtl = 5 + + minimumChangesToDump = 1024 + dumpHeader = "mosdns_cache_v2" + dumpBlockSize = 128 + dumpMaximumBlockLength = 1 << 20 // 1M block. 8kb pre entry. Should be enough. +) + +var _ sequence.RecursiveExecutable = (*Cache)(nil) + +type Args struct { + Size int `yaml:"size"` + LazyCacheTTL int `yaml:"lazy_cache_ttl"` + DumpFile string `yaml:"dump_file"` + DumpInterval int `yaml:"dump_interval"` +} + +func (a *Args) init() { + utils.SetDefaultUnsignNum(&a.Size, 1024) + utils.SetDefaultUnsignNum(&a.DumpInterval, 600) +} + +type Cache struct { + args *Args + + logger *zap.Logger + backend *cache.Cache[key, *item] + lazyUpdateSF singleflight.Group + closeOnce sync.Once + closeNotify chan struct{} + updatedKey atomic.Uint64 + + queryTotal prometheus.Counter + hitTotal prometheus.Counter + lazyHitTotal prometheus.Counter + size prometheus.GaugeFunc +} + +func Init(bp *coremain.BP, args any) (any, error) { + c := NewCache(args.(*Args), Opts{ + Logger: bp.L(), + MetricsTag: bp.Tag(), + }) + + if err := c.RegMetricsTo(prometheus.WrapRegistererWithPrefix(PluginType+"_", bp.M().GetMetricsReg())); err != nil { + return nil, fmt.Errorf("failed to register metrics, %w", err) + } + bp.RegAPI(c.Api()) + return c, nil +} + +// QuickSetup format: [size] +// default is 1024. If size is < 1024, 1024 will be used. +func quickSetupCache(bq sequence.BQ, s string) (any, error) { + size := 0 + if len(s) > 0 { + i, err := strconv.Atoi(s) + if err != nil { + return nil, fmt.Errorf("invalid size, %w", err) + } + size = i + } + // Don't register metrics in quick setup. + return NewCache(&Args{Size: size}, Opts{Logger: bq.L()}), nil +} + +type Opts struct { + Logger *zap.Logger + MetricsTag string +} + +func NewCache(args *Args, opts Opts) *Cache { + args.init() + + logger := opts.Logger + if logger == nil { + logger = zap.NewNop() + } + + backend := cache.New[key, *item](cache.Opts{Size: args.Size}) + lb := map[string]string{"tag": opts.MetricsTag} + p := &Cache{ + args: args, + logger: logger, + backend: backend, + closeNotify: make(chan struct{}), + + queryTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "query_total", + Help: "The total number of processed queries", + ConstLabels: lb, + }), + hitTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "hit_total", + Help: "The total number of queries that hit the cache", + ConstLabels: lb, + }), + lazyHitTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "lazy_hit_total", + Help: "The total number of queries that hit the expired cache", + ConstLabels: lb, + }), + size: prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "size_current", + Help: "Current cache size in records", + ConstLabels: lb, + }, func() float64 { + return float64(backend.Len()) + }), + } + + if err := p.loadDump(); err != nil { + p.logger.Error("failed to load cache dump", zap.Error(err)) + } + p.startDumpLoop() + + return p +} + +func (c *Cache) RegMetricsTo(r prometheus.Registerer) error { + for _, collector := range [...]prometheus.Collector{c.queryTotal, c.hitTotal, c.lazyHitTotal, c.size} { + if err := r.Register(collector); err != nil { + return err + } + } + return nil +} + +func (c *Cache) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error { + c.queryTotal.Inc() + q := qCtx.Q() + + msgKey := getMsgKey(q) + if len(msgKey) == 0 { // skip cache + return next.ExecNext(ctx, qCtx) + } + + cachedResp, lazyHit := getRespFromCache(msgKey, c.backend, c.args.LazyCacheTTL > 0, expiredMsgTtl) + if lazyHit { + c.lazyHitTotal.Inc() + c.doLazyUpdate(msgKey, qCtx, next) + } + if cachedResp != nil { // cache hit + c.hitTotal.Inc() + cachedResp.Id = q.Id // change msg id + qCtx.SetResponse(cachedResp) + } + + err := next.ExecNext(ctx, qCtx) + + if r := qCtx.R(); r != nil && cachedResp != r { // pointer compare. r is not cachedResp + saveRespToCache(msgKey, r, c.backend, c.args.LazyCacheTTL) + c.updatedKey.Add(1) + } + return err +} + +// doLazyUpdate starts a new goroutine to execute next node and update the cache in the background. +// It has an inner singleflight.Group to de-duplicate same msgKey. +func (c *Cache) doLazyUpdate(msgKey string, qCtx *query_context.Context, next sequence.ChainWalker) { + qCtxCopy := qCtx.Copy() + lazyUpdateFunc := func() (any, error) { + defer c.lazyUpdateSF.Forget(msgKey) + qCtx := qCtxCopy + + c.logger.Debug("start lazy cache update", qCtx.InfoField()) + ctx, cancel := context.WithTimeout(context.Background(), defaultLazyUpdateTimeout) + defer cancel() + + err := next.ExecNext(ctx, qCtx) + if err != nil { + c.logger.Warn("failed to update lazy cache", qCtx.InfoField(), zap.Error(err)) + } + + r := qCtx.R() + if r != nil { + saveRespToCache(msgKey, r, c.backend, c.args.LazyCacheTTL) + c.updatedKey.Add(1) + } + c.logger.Debug("lazy cache updated", qCtx.InfoField()) + return nil, nil + } + c.lazyUpdateSF.DoChan(msgKey, lazyUpdateFunc) // DoChan won't block this goroutine +} + +func (c *Cache) Close() error { + if err := c.dumpCache(); err != nil { + c.logger.Error("failed to dump cache", zap.Error(err)) + } + c.closeOnce.Do(func() { + close(c.closeNotify) + }) + return c.backend.Close() +} + +func (c *Cache) loadDump() error { + if len(c.args.DumpFile) == 0 { + return nil + } + f, err := os.Open(c.args.DumpFile) + if err != nil { + return err + } + defer f.Close() + en, err := c.readDump(f) + if err != nil { + return err + } + c.logger.Info("cache dump loaded", zap.Int("entries", en)) + return nil +} + +// startDumpLoop starts a dump loop in another goroutine. It does not block. +func (c *Cache) startDumpLoop() { + if len(c.args.DumpFile) == 0 { + return + } + go func() { + ticker := time.NewTicker(time.Duration(c.args.DumpInterval) * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + // Check if we have enough changes to dump. + keyUpdated := c.updatedKey.Swap(0) + if keyUpdated < minimumChangesToDump { // Nop. + c.updatedKey.Add(keyUpdated) + continue + } + + if err := c.dumpCache(); err != nil { + c.logger.Error("dump cache", zap.Error(err)) + } + case <-c.closeNotify: + return + } + } + }() +} + +func (c *Cache) dumpCache() error { + if len(c.args.DumpFile) == 0 { + return nil + } + + f, err := os.Create(c.args.DumpFile) + if err != nil { + return err + } + defer f.Close() + + en, err := c.writeDump(f) + if err != nil { + return fmt.Errorf("failed to write dump, %w", err) + } + c.logger.Info("cache dumped", zap.Int("entries", en)) + return nil +} + +func (c *Cache) Api() *chi.Mux { + r := chi.NewRouter() + r.Get("/flush", func(w http.ResponseWriter, req *http.Request) { + c.backend.Flush() + }) + r.Get("/dump", func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("content-type", "application/octet-stream") + _, err := c.writeDump(w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + r.Post("/load_dump", func(w http.ResponseWriter, req *http.Request) { + if _, err := c.readDump(req.Body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + }) + return r +} + +func (c *Cache) writeDump(w io.Writer) (int, error) { + en := 0 + + gw, _ := gzip.NewWriterLevel(w, gzip.BestSpeed) + gw.Name = dumpHeader + + block := new(CacheDumpBlock) + writeBlock := func() error { + b, err := proto.Marshal(block) + if err != nil { + return fmt.Errorf("failed to marshal protobuf, %w", err) + } + + l := make([]byte, 8) + binary.BigEndian.PutUint64(l, uint64(len(b))) + _, err = gw.Write(l) + if err != nil { + return fmt.Errorf("failed to write header, %w", err) + } + _, err = gw.Write(b) + if err != nil { + return fmt.Errorf("failed to write data, %w", err) + } + + en += len(block.GetEntries()) + block.Reset() + return nil + } + + now := time.Now() + rangeFunc := func(k key, v *item, cacheExpirationTime time.Time) error { + if cacheExpirationTime.Before(now) { + return nil + } + msg, err := v.resp.Pack() + if err != nil { + return fmt.Errorf("failed to pack msg, %w", err) + } + e := &CachedEntry{ + Key: []byte(k), + CacheExpirationTime: cacheExpirationTime.Unix(), + MsgExpirationTime: v.expirationTime.Unix(), + Msg: msg, + } + block.Entries = append(block.Entries, e) + + // Block is big enough for a write operation. + if len(block.Entries) >= dumpBlockSize { + return writeBlock() + } + return nil + } + if err := c.backend.Range(rangeFunc); err != nil { + return en, err + } + + if len(block.GetEntries()) > 0 { + if err := writeBlock(); err != nil { + return en, err + } + } + return en, gw.Close() +} + +// readDump reads dumped data from r. It returns the number of bytes read, +// number of entries read and any error encountered. +func (c *Cache) readDump(r io.Reader) (int, error) { + en := 0 + gr, err := gzip.NewReader(r) + if err != nil { + return en, fmt.Errorf("failed to read gzip header, %w", err) + } + if gr.Name != dumpHeader { + return en, fmt.Errorf("invalid or old cache dump, header is %s, want %s", gr.Name, dumpHeader) + } + + var errReadHeaderEOF = errors.New("") + readBlock := func() error { + h := pool.GetBuf(8) + defer pool.ReleaseBuf(h) + _, err := io.ReadFull(gr, *h) + if err != nil { + if errors.Is(err, io.EOF) { + return errReadHeaderEOF + } + return fmt.Errorf("failed to read block header, %w", err) + } + u := binary.BigEndian.Uint64(*h) + if u > dumpMaximumBlockLength { + return fmt.Errorf("invalid header, block length is big, %d", u) + } + + b := pool.GetBuf(int(u)) + defer pool.ReleaseBuf(b) + _, err = io.ReadFull(gr, *b) + if err != nil { + return fmt.Errorf("failed to read block data, %w", err) + } + + block := new(CacheDumpBlock) + if err := proto.Unmarshal(*b, block); err != nil { + return fmt.Errorf("failed to decode block data, %w", err) + } + + en += len(block.GetEntries()) + for _, entry := range block.GetEntries() { + cacheExpTime := time.Unix(entry.GetCacheExpirationTime(), 0) + msgExpTime := time.Unix(entry.GetMsgExpirationTime(), 0) + storedTime := time.Unix(entry.GetMsgStoredTime(), 0) + resp := new(dns.Msg) + if err := resp.Unpack(entry.GetMsg()); err != nil { + return fmt.Errorf("failed to decode dns msg, %w", err) + } + + i := &item{ + resp: resp, + storedTime: storedTime, + expirationTime: msgExpTime, + } + c.backend.Store(key(entry.GetKey()), i, cacheExpTime) + } + return nil + } + + for { + err = readBlock() + if err != nil { + if err == errReadHeaderEOF { + err = nil // This is expected if there is no block to read. + } + break + } + } + + if err != nil { + return en, err + } + return en, gr.Close() +} diff --git a/plugin/executable/cache/cache_test.go b/plugin/executable/cache/cache_test.go new file mode 100644 index 0000000..95e2036 --- /dev/null +++ b/plugin/executable/cache/cache_test.go @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package cache + +import ( + "bytes" + "github.com/miekg/dns" + "strconv" + "testing" + "time" +) + +func Test_cachePlugin_Dump(t *testing.T) { + c := NewCache(&Args{Size: 16 * dumpBlockSize}, Opts{}) // Big enough to create dump fragments. + + resp := new(dns.Msg) + resp.SetQuestion("test.", dns.TypeA) + + now := time.Now() + hourLater := now.Add(time.Hour) + v := &item{ + resp: resp, + storedTime: now, + expirationTime: hourLater, + } + + // Fill the cache + for i := 0; i < 32*dumpBlockSize; i++ { + c.backend.Store(key(strconv.Itoa(i)), v, hourLater) + } + + buf := new(bytes.Buffer) + enw, err := c.writeDump(buf) + if err != nil { + t.Fatal(err) + } + enr, err := c.readDump(buf) + if err != nil { + t.Fatal(err) + } + + if enw != enr { + t.Fatalf("read err, wrote %d entries, read %d", enw, enr) + } +} diff --git a/plugin/executable/cache/dump.pb.go b/plugin/executable/cache/dump.pb.go new file mode 100644 index 0000000..58a7f15 --- /dev/null +++ b/plugin/executable/cache/dump.pb.go @@ -0,0 +1,250 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.31.0 +// protoc v4.24.3 +// source: plugin/executable/cache/dump.proto + +package cache + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type CachedEntry struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Msg []byte `protobuf:"bytes,2,opt,name=msg,proto3" json:"msg,omitempty"` + CacheExpirationTime int64 `protobuf:"varint,3,opt,name=cache_expiration_time,json=cacheExpirationTime,proto3" json:"cache_expiration_time,omitempty"` + MsgExpirationTime int64 `protobuf:"varint,4,opt,name=msg_expiration_time,json=msgExpirationTime,proto3" json:"msg_expiration_time,omitempty"` + MsgStoredTime int64 `protobuf:"varint,5,opt,name=msg_stored_time,json=msgStoredTime,proto3" json:"msg_stored_time,omitempty"` +} + +func (x *CachedEntry) Reset() { + *x = CachedEntry{} + if protoimpl.UnsafeEnabled { + mi := &file_plugin_executable_cache_dump_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CachedEntry) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CachedEntry) ProtoMessage() {} + +func (x *CachedEntry) ProtoReflect() protoreflect.Message { + mi := &file_plugin_executable_cache_dump_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CachedEntry.ProtoReflect.Descriptor instead. +func (*CachedEntry) Descriptor() ([]byte, []int) { + return file_plugin_executable_cache_dump_proto_rawDescGZIP(), []int{0} +} + +func (x *CachedEntry) GetKey() []byte { + if x != nil { + return x.Key + } + return nil +} + +func (x *CachedEntry) GetMsg() []byte { + if x != nil { + return x.Msg + } + return nil +} + +func (x *CachedEntry) GetCacheExpirationTime() int64 { + if x != nil { + return x.CacheExpirationTime + } + return 0 +} + +func (x *CachedEntry) GetMsgExpirationTime() int64 { + if x != nil { + return x.MsgExpirationTime + } + return 0 +} + +func (x *CachedEntry) GetMsgStoredTime() int64 { + if x != nil { + return x.MsgStoredTime + } + return 0 +} + +type CacheDumpBlock struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Entries []*CachedEntry `protobuf:"bytes,1,rep,name=entries,proto3" json:"entries,omitempty"` +} + +func (x *CacheDumpBlock) Reset() { + *x = CacheDumpBlock{} + if protoimpl.UnsafeEnabled { + mi := &file_plugin_executable_cache_dump_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CacheDumpBlock) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CacheDumpBlock) ProtoMessage() {} + +func (x *CacheDumpBlock) ProtoReflect() protoreflect.Message { + mi := &file_plugin_executable_cache_dump_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CacheDumpBlock.ProtoReflect.Descriptor instead. +func (*CacheDumpBlock) Descriptor() ([]byte, []int) { + return file_plugin_executable_cache_dump_proto_rawDescGZIP(), []int{1} +} + +func (x *CacheDumpBlock) GetEntries() []*CachedEntry { + if x != nil { + return x.Entries + } + return nil +} + +var File_plugin_executable_cache_dump_proto protoreflect.FileDescriptor + +var file_plugin_executable_cache_dump_proto_rawDesc = []byte{ + 0x0a, 0x22, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2f, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x61, + 0x62, 0x6c, 0x65, 0x2f, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2f, 0x64, 0x75, 0x6d, 0x70, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x63, 0x61, 0x63, 0x68, 0x65, 0x22, 0xbd, 0x01, 0x0a, 0x0b, + 0x43, 0x61, 0x63, 0x68, 0x65, 0x64, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, + 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6d, 0x73, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x12, + 0x32, 0x0a, 0x15, 0x63, 0x61, 0x63, 0x68, 0x65, 0x5f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x13, + 0x63, 0x61, 0x63, 0x68, 0x65, 0x45, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x54, + 0x69, 0x6d, 0x65, 0x12, 0x2e, 0x0a, 0x13, 0x6d, 0x73, 0x67, 0x5f, 0x65, 0x78, 0x70, 0x69, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x11, 0x6d, 0x73, 0x67, 0x45, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x54, + 0x69, 0x6d, 0x65, 0x12, 0x26, 0x0a, 0x0f, 0x6d, 0x73, 0x67, 0x5f, 0x73, 0x74, 0x6f, 0x72, 0x65, + 0x64, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x6d, 0x73, + 0x67, 0x53, 0x74, 0x6f, 0x72, 0x65, 0x64, 0x54, 0x69, 0x6d, 0x65, 0x22, 0x3e, 0x0a, 0x0e, 0x43, + 0x61, 0x63, 0x68, 0x65, 0x44, 0x75, 0x6d, 0x70, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x12, 0x2c, 0x0a, + 0x07, 0x65, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, + 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x43, 0x61, 0x63, 0x68, 0x65, 0x64, 0x45, 0x6e, 0x74, + 0x72, 0x79, 0x52, 0x07, 0x65, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x42, 0x19, 0x5a, 0x17, 0x70, + 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2f, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x61, 0x62, 0x6c, 0x65, + 0x2f, 0x63, 0x61, 0x63, 0x68, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_plugin_executable_cache_dump_proto_rawDescOnce sync.Once + file_plugin_executable_cache_dump_proto_rawDescData = file_plugin_executable_cache_dump_proto_rawDesc +) + +func file_plugin_executable_cache_dump_proto_rawDescGZIP() []byte { + file_plugin_executable_cache_dump_proto_rawDescOnce.Do(func() { + file_plugin_executable_cache_dump_proto_rawDescData = protoimpl.X.CompressGZIP(file_plugin_executable_cache_dump_proto_rawDescData) + }) + return file_plugin_executable_cache_dump_proto_rawDescData +} + +var file_plugin_executable_cache_dump_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_plugin_executable_cache_dump_proto_goTypes = []interface{}{ + (*CachedEntry)(nil), // 0: cache.CachedEntry + (*CacheDumpBlock)(nil), // 1: cache.CacheDumpBlock +} +var file_plugin_executable_cache_dump_proto_depIdxs = []int32{ + 0, // 0: cache.CacheDumpBlock.entries:type_name -> cache.CachedEntry + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_plugin_executable_cache_dump_proto_init() } +func file_plugin_executable_cache_dump_proto_init() { + if File_plugin_executable_cache_dump_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_plugin_executable_cache_dump_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CachedEntry); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_plugin_executable_cache_dump_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CacheDumpBlock); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_plugin_executable_cache_dump_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_plugin_executable_cache_dump_proto_goTypes, + DependencyIndexes: file_plugin_executable_cache_dump_proto_depIdxs, + MessageInfos: file_plugin_executable_cache_dump_proto_msgTypes, + }.Build() + File_plugin_executable_cache_dump_proto = out.File + file_plugin_executable_cache_dump_proto_rawDesc = nil + file_plugin_executable_cache_dump_proto_goTypes = nil + file_plugin_executable_cache_dump_proto_depIdxs = nil +} diff --git a/plugin/executable/cache/dump.proto b/plugin/executable/cache/dump.proto new file mode 100644 index 0000000..93f1106 --- /dev/null +++ b/plugin/executable/cache/dump.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package cache; + +option go_package = "plugin/executable/cache"; + +message CachedEntry { + bytes key = 1; + bytes msg = 2; + int64 cache_expiration_time = 3; + int64 msg_expiration_time = 4; + int64 msg_stored_time = 5; +} + +message CacheDumpBlock { + repeated CachedEntry entries = 1; +} diff --git a/plugin/executable/cache/utils.go b/plugin/executable/cache/utils.go new file mode 100644 index 0000000..989b8f7 --- /dev/null +++ b/plugin/executable/cache/utils.go @@ -0,0 +1,207 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package cache + +import ( + "hash/maphash" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/cache" + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/miekg/dns" + "golang.org/x/exp/constraints" +) + +type key string + +var seed = maphash.MakeSeed() + +func (k key) Sum() uint64 { + return maphash.String(seed, string(k)) +} + +// getMsgKey returns a string key for the query msg, or an empty +// string if query should not be cached. +func getMsgKey(q *dns.Msg) string { + if q.Response || q.Opcode != dns.OpcodeQuery || len(q.Question) != 1 { + return "" + } + + const ( + adBit = 1 << iota + cdBit + doBit + ) + + question := q.Question[0] + buf := make([]byte, 1+2+1+len(question.Name)) // bits + qtype + qname length + qname + b := byte(0) + // RFC 6840 5.7: The AD bit in a query as a signal + // indicating that the requester understands and is interested in the + // value of the AD bit in the response. + if q.AuthenticatedData { + b = b | adBit + } + if q.CheckingDisabled { + b = b | cdBit + } + if opt := q.IsEdns0(); opt != nil && opt.Do() { + b = b | doBit + } + buf[0] = b + buf[1] = byte(question.Qtype << 8) + buf[2] = byte(question.Qtype) + buf[3] = byte(len(question.Name)) + copy(buf[4:], question.Name) + return utils.BytesToStringUnsafe(buf) +} + +type item struct { + resp *dns.Msg + storedTime time.Time + expirationTime time.Time +} + +func copyNoOpt(m *dns.Msg) *dns.Msg { + if m == nil { + return nil + } + + m2 := new(dns.Msg) + m2.MsgHdr = m.MsgHdr + m2.Compress = m.Compress + + if len(m.Question) > 0 { + m2.Question = make([]dns.Question, len(m.Question)) + copy(m2.Question, m.Question) + } + + lenExtra := len(m.Extra) + for _, r := range m.Extra { + if r.Header().Rrtype == dns.TypeOPT { + lenExtra-- + } + } + + s := make([]dns.RR, len(m.Answer)+len(m.Ns)+lenExtra) + m2.Answer, s = s[:0:len(m.Answer)], s[len(m.Answer):] + m2.Ns, s = s[:0:len(m.Ns)], s[len(m.Ns):] + m2.Extra = s[:0:lenExtra] + + for _, r := range m.Answer { + m2.Answer = append(m2.Answer, dns.Copy(r)) + } + for _, r := range m.Ns { + m2.Ns = append(m2.Ns, dns.Copy(r)) + } + + for _, r := range m.Extra { + if r.Header().Rrtype == dns.TypeOPT { + continue + } + m2.Extra = append(m2.Extra, dns.Copy(r)) + } + return m2 +} + +func min[T constraints.Ordered](a, b T) T { + if a < b { + return a + } + return b +} + +// getRespFromCache returns the cached response from cache. +// The ttl of returned msg will be changed properly. +// Returned bool indicates whether this response is hit by lazy cache. +// Note: Caller SHOULD change the msg id because it's not same as query's. +func getRespFromCache(msgKey string, backend *cache.Cache[key, *item], lazyCacheEnabled bool, lazyTtl int) (*dns.Msg, bool) { + // Lookup cache + v, _, _ := backend.Get(key(msgKey)) + + // Cache hit + if v != nil { + now := time.Now() + + // Not expired. + if now.Before(v.expirationTime) { + r := v.resp.Copy() + dnsutils.SubtractTTL(r, uint32(now.Sub(v.storedTime).Seconds())) + return r, false + } + + // Msg expired but cache isn't. This is a lazy cache enabled entry. + // If lazy cache is enabled, return the response. + if lazyCacheEnabled { + r := v.resp.Copy() + dnsutils.SetTTL(r, uint32(lazyTtl)) + return r, true + } + } + + // cache miss + return nil, false +} + +// saveRespToCache saves r to cache backend. It returns false if r +// should not be cached and was skipped. +func saveRespToCache(msgKey string, r *dns.Msg, backend *cache.Cache[key, *item], lazyCacheTtl int) bool { + if r.Truncated != false { + return false + } + + var msgTtl time.Duration + var cacheTtl time.Duration + switch r.Rcode { + case dns.RcodeNameError: + msgTtl = time.Second * 30 + cacheTtl = msgTtl + case dns.RcodeServerFailure: + msgTtl = time.Second * 5 + cacheTtl = msgTtl + case dns.RcodeSuccess: + minTTL := dnsutils.GetMinimalTTL(r) + if len(r.Answer) == 0 { // Empty answer. Set ttl between 0~300. + const maxEmtpyAnswerTtl = 300 + msgTtl = time.Duration(min(minTTL, maxEmtpyAnswerTtl)) * time.Second + cacheTtl = msgTtl + } else { + msgTtl = time.Duration(minTTL) * time.Second + if lazyCacheTtl > 0 { + cacheTtl = time.Duration(lazyCacheTtl) * time.Second + } else { + cacheTtl = msgTtl + } + } + } + if msgTtl <= 0 || cacheTtl <= 0 { + return false + } + + now := time.Now() + v := &item{ + resp: copyNoOpt(r), + storedTime: now, + expirationTime: now.Add(msgTtl), + } + backend.Store(key(msgKey), v, now.Add(cacheTtl)) + return true +} diff --git a/plugin/executable/debug_print/print.go b/plugin/executable/debug_print/print.go new file mode 100644 index 0000000..d58dae2 --- /dev/null +++ b/plugin/executable/debug_print/print.go @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package debug_print + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "go.uber.org/zap" +) + +const PluginType = "debug_print" + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.Executable = (*DebugPrint)(nil) + +type DebugPrint struct { + sequence.BQ + msg string +} + +// QuickSetup format: s is the log message string. Default is "debug print". +func QuickSetup(bq sequence.BQ, s string) (any, error) { + if len(s) == 0 { + s = "debug print" + } + return &DebugPrint{BQ: bq, msg: s}, nil +} + +func (b *DebugPrint) Exec(_ context.Context, qCtx *query_context.Context) error { + b.BQ.L().Info(b.msg, zap.Stringer("query", qCtx.Q())) + if r := qCtx.R(); r != nil { + b.BQ.L().Info(b.msg, zap.Stringer("response", r)) + } + return nil +} diff --git a/plugin/executable/drop_resp/drop_resp.go b/plugin/executable/drop_resp/drop_resp.go new file mode 100644 index 0000000..32ea382 --- /dev/null +++ b/plugin/executable/drop_resp/drop_resp.go @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package drop_resp + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" +) + +const PluginType = "drop_resp" + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.Executable = (*DropResp)(nil) + +type DropResp struct{} + +func QuickSetup(_ sequence.BQ, _ string) (any, error) { + return &DropResp{}, nil +} + +func (b *DropResp) Exec(_ context.Context, qCtx *query_context.Context) error { + qCtx.SetResponse(nil) + return nil +} diff --git a/plugin/executable/dual_selector/cache.go b/plugin/executable/dual_selector/cache.go new file mode 100644 index 0000000..f22cc76 --- /dev/null +++ b/plugin/executable/dual_selector/cache.go @@ -0,0 +1,11 @@ +package dual_selector + +import "hash/maphash" + +type key string + +var seed = maphash.MakeSeed() + +func (k key) Sum() uint64 { + return maphash.String(seed, string(k)) +} diff --git a/plugin/executable/dual_selector/dual_selector.go b/plugin/executable/dual_selector/dual_selector.go new file mode 100644 index 0000000..86dff33 --- /dev/null +++ b/plugin/executable/dual_selector/dual_selector.go @@ -0,0 +1,205 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package dual_selector + +import ( + "context" + "io" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/cache" + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +const ( + referenceWaitTimeout = time.Millisecond * 500 + defaultSubRoutineTimeout = time.Second * 5 + + // TODO: Make cache configurable? + cacheSize = 64 * 1024 + cacheTlt = time.Hour + cacheGcInterval = time.Minute +) + +func init() { + sequence.MustRegExecQuickSetup("prefer_ipv4", func(bq sequence.BQ, _ string) (any, error) { + return NewPreferIpv4(bq), nil + }) + sequence.MustRegExecQuickSetup("prefer_ipv6", func(bq sequence.BQ, _ string) (any, error) { + return NewPreferIpv6(bq), nil + }) +} + +var _ sequence.RecursiveExecutable = (*Selector)(nil) +var _ io.Closer = (*Selector)(nil) + +type Selector struct { + sequence.BQ + prefer uint16 // dns.TypeA or dns.TypeAAAA + + preferTypOkCache *cache.Cache[key, bool] +} + +// Exec implements handler.Executable. +func (s *Selector) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error { + q := qCtx.Q() + if len(q.Question) != 1 { // skip wired query with multiple questions. + return next.ExecNext(ctx, qCtx) + } + + qtype := q.Question[0].Qtype + // skip queries that have other unrelated types. + if qtype != dns.TypeA && qtype != dns.TypeAAAA { + return next.ExecNext(ctx, qCtx) + } + + qName := key(q.Question[0].Name) + if qtype == s.prefer { + err := next.ExecNext(ctx, qCtx) + if err != nil { + return err + } + + if r := qCtx.R(); r != nil && msgAnsHasRR(r, s.prefer) { + s.preferTypOkCache.Store(qName, true, time.Now().Add(cacheTlt)) + } + return nil + } + + // Qtype is not the preferred type. + preferredTypOk, _, _ := s.preferTypOkCache.Get(qName) + if preferredTypOk { + // We know that domain has preferred type so this qtype can be blocked + // right away. + r := dnsutils.GenEmptyReply(q, dns.RcodeSuccess) + qCtx.SetResponse(r) + return nil + } + + // async check whether domain has the preferred type + qCtxPreferred := qCtx.Copy() + qCtxPreferred.Q().Question[0].Qtype = s.prefer + + ddl, cacheOk := ctx.Deadline() + if !cacheOk { + ddl = time.Now().Add(defaultSubRoutineTimeout) + } + + shouldBlock := make(chan struct{}) + shouldPass := make(chan struct{}) + go func() { + qCtx := qCtxPreferred + ctx, cancel := context.WithDeadline(context.Background(), ddl) + defer cancel() + err := next.ExecNext(ctx, qCtx) + if err != nil { + s.L().Warn("reference query routine err", qCtx.InfoField(), zap.Error(err)) + close(shouldPass) + return + } + if r := qCtx.R(); r != nil && msgAnsHasRR(r, s.prefer) { + // Target domain has preferred type. + s.preferTypOkCache.Store(qName, true, time.Now().Add(cacheTlt)) + close(shouldBlock) + return + } + close(shouldPass) + }() + + // start original query goroutine + doneChan := make(chan error, 1) + qCtxOrg := qCtx.Copy() + go func() { + qCtx := qCtxOrg + ctx, cancel := context.WithDeadline(context.Background(), ddl) + defer cancel() + doneChan <- next.ExecNext(ctx, qCtx) + }() + + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-shouldBlock: // Domain has preferred type. Block this type now. + r := dnsutils.GenEmptyReply(q, dns.RcodeSuccess) + qCtx.SetResponse(r) + return nil + case err := <-doneChan: // The original query finished. Waiting for preferred type check. + waitTimeoutTimer := pool.GetTimer(referenceWaitTimeout) + defer pool.ReleaseTimer(waitTimeoutTimer) + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-shouldBlock: + r := dnsutils.GenEmptyReply(q, dns.RcodeSuccess) + qCtx.SetResponse(r) + return nil + case <-shouldPass: + *qCtx = *qCtxOrg // replace qCtx + return err + case <-waitTimeoutTimer.C: + // We have been waiting the reference query for too long. + // Something may go wrong. We accept the original reply. + *qCtx = *qCtxOrg + return err + } + } +} + +func (s *Selector) Close() error { + s.preferTypOkCache.Close() + return nil +} + +func NewPreferIpv4(bq sequence.BQ) *Selector { + return newSelector(bq, dns.TypeA) +} + +func NewPreferIpv6(bq sequence.BQ) *Selector { + return newSelector(bq, dns.TypeAAAA) +} + +func newSelector(bq sequence.BQ, preferType uint16) *Selector { + if preferType != dns.TypeA && preferType != dns.TypeAAAA { + panic("dual_selector: invalid dns qtype") + } + return &Selector{ + BQ: bq, + prefer: preferType, + preferTypOkCache: cache.New[key, bool](cache.Opts{Size: cacheSize, CleanerInterval: cacheGcInterval}), + } +} + +func msgAnsHasRR(m *dns.Msg, t uint16) bool { + if len(m.Answer) == 0 { + return false + } + + for _, rr := range m.Answer { + if rr.Header().Rrtype == t { + return true + } + } + return false +} diff --git a/plugin/executable/dual_selector/dual_selector_test.go b/plugin/executable/dual_selector/dual_selector_test.go new file mode 100644 index 0000000..9616e73 --- /dev/null +++ b/plugin/executable/dual_selector/dual_selector_test.go @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package dual_selector + +import ( + "context" + "net" + "testing" + "time" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +type dummyNext struct { + returnA bool + latencyA time.Duration + returnAAAA bool + latencyAAAA time.Duration +} + +func (d *dummyNext) Exec(_ context.Context, qCtx *query_context.Context) error { + q := qCtx.Q() + r := new(dns.Msg) + r.SetReply(q) + question := q.Question[0] + rrh := dns.RR_Header{ + Name: question.Name, + Rrtype: question.Qtype, + Class: question.Qclass, + } + + if question.Qtype == dns.TypeA && d.returnA { + r.Answer = append(r.Answer, &dns.A{ + Hdr: rrh, + A: net.IPv4(1, 2, 3, 4), + }) + time.Sleep(d.latencyA) + } + if question.Qtype == dns.TypeAAAA && d.returnAAAA { + r.Answer = append(r.Answer, &dns.AAAA{ + Hdr: rrh, + AAAA: net.IPv4(1, 2, 3, 4), + }) + time.Sleep(d.latencyAAAA) + } + qCtx.SetResponse(r) + return nil +} + +func TestSelector_Exec(t *testing.T) { + nextNoA := &dummyNext{ + returnA: false, + returnAAAA: true, + } + nextNoAAAA := &dummyNext{ + returnA: true, + returnAAAA: false, + } + nextDual := &dummyNext{ + returnA: true, + returnAAAA: true, + } + nextLateA := &dummyNext{ + returnA: true, + latencyA: time.Millisecond * 1000, + returnAAAA: true, + } + nextLateAAAA := &dummyNext{ + returnA: true, + returnAAAA: true, + latencyAAAA: time.Millisecond * 1000, + } + + tests := []struct { + name string + prefer uint16 + qtype uint16 + next *dummyNext + wantErr bool + wantReply bool + }{ + { + name: "prefer v4: do not block domain AAAA if domain does not have an A record", + prefer: dns.TypeA, + qtype: dns.TypeAAAA, + next: nextNoA, + wantErr: false, + wantReply: true, + }, + { + name: "prefer v4: do not block domain AAAA if A reply wasn't returned on time", + prefer: dns.TypeA, + qtype: dns.TypeAAAA, + next: nextLateA, + wantErr: false, + wantReply: true, + }, + { + name: "prefer v4: block domain AAAA if domain has A records", + prefer: dns.TypeA, + qtype: dns.TypeAAAA, + next: nextDual, + wantErr: false, + wantReply: false, + }, + { + name: "prefer v6: do not block domain A if domain does not have an AAAA record", + prefer: dns.TypeAAAA, + qtype: dns.TypeA, + next: nextNoAAAA, + wantErr: false, + wantReply: true, + }, + { + name: "prefer v6: do not block domain A if AAAA reply wasn't returned on time", + prefer: dns.TypeAAAA, + qtype: dns.TypeA, + next: nextLateAAAA, + wantErr: false, + wantReply: true, + }, + { + name: "prefer v6: block domain A if domain has AAAA records", + prefer: dns.TypeAAAA, + qtype: dns.TypeA, + next: nextDual, + wantErr: false, + wantReply: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := newSelector(sequence.NewBQ(coremain.NewTestMosdnsWithPlugins(nil), zap.NewNop()), tt.prefer) + + q := new(dns.Msg) + q.SetQuestion("example.", tt.qtype) + qCtx := query_context.NewContext(q) + cw := sequence.NewChainWalker([]*sequence.ChainNode{{E: tt.next}}, nil) + if err := s.Exec(context.Background(), qCtx, cw); (err != nil) != tt.wantErr { + t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr) + } + + r := qCtx.R() + if hasReply := msgAnsHasRR(r, tt.qtype); hasReply != tt.wantReply { + t.Errorf("Exec() hasReply = %v, wantReply %v", hasReply, tt.wantReply) + } + }) + } +} diff --git a/plugin/executable/ecs_handler/handler.go b/plugin/executable/ecs_handler/handler.go new file mode 100644 index 0000000..af96c20 --- /dev/null +++ b/plugin/executable/ecs_handler/handler.go @@ -0,0 +1,217 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ecs_handler + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strings" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" +) + +const PluginType = "ecs_handler" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) + + // Compatible for old ecs plugin + // TODO: Remove this in mosdns v6, probably. + sequence.MustRegExecQuickSetup("ecs", QuickSetupOldECS) +} + +var _ sequence.RecursiveExecutable = (*ECSHandler)(nil) + +type Args struct { + Forward bool `yaml:"forward"` + Send bool `yaml:"send"` + Preset string `yaml:"preset"` + Mask4 int `yaml:"mask4"` + Mask6 int `yaml:"mask6"` +} + +type ECSHandler struct { + args Args + preset netip.Addr // unmapped +} + +func NewHandler(args Args) (*ECSHandler, error) { + var preset netip.Addr + if len(args.Preset) > 0 { + addr, err := netip.ParseAddr(args.Preset) + if err != nil { + return nil, fmt.Errorf("invalid preset address, %w", err) + } + preset = addr.Unmap() + } + + checkOrInitMask := func(p *int, min, max, defaultM int) bool { + v := *p + if v < min || v > max { + return false + } + if v == 0 { + *p = defaultM + } + return true + } + if !checkOrInitMask(&args.Mask4, 0, 32, 24) { + return nil, errors.New("invalid mask4") + } + if !checkOrInitMask(&args.Mask6, 0, 128, 48) { + return nil, errors.New("invalid mask6") + } + + return &ECSHandler{args: args, preset: preset}, nil +} + +func Init(_ *coremain.BP, args any) (any, error) { + return NewHandler(*args.(*Args)) +} + +// Exec tries to append ECS to qCtx.Q(). +func (e *ECSHandler) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error { + forwarded := e.addECS(qCtx) + err := next.ExecNext(ctx, qCtx) + if err != nil { + return err + } + + if forwarded { + // forward upstream ecs back to client + respOpt := qCtx.RespOpt() + upstreamOpt := qCtx.UpstreamOpt() + if respOpt != nil && upstreamOpt != nil { + for _, o := range upstreamOpt.Option { + if o.Option() == dns.EDNS0SUBNET { + respOpt.Option = append(respOpt.Option, o) + break + } + } + } + } + return nil +} + +// AddECS adds a *dns.EDNS0_SUBNET record to q. +func (e *ECSHandler) addECS(qCtx *query_context.Context) (forwarded bool) { + queryOpt := qCtx.QOpt() + // Check if query already has an ecs. + for _, o := range queryOpt.Option { + if o.Option() == dns.EDNS0SUBNET { + return false // skip it + } + } + if qCtx.QQuestion().Qclass != dns.ClassINET { + // RFC 7871 5: + // ECS is only defined for the Internet (IN) DNS class. + return false + } + + if e.args.Forward { + clientOpt := qCtx.ClientOpt() + if clientOpt != nil { + for _, o := range clientOpt.Option { + if o.Option() == dns.EDNS0SUBNET { + queryOpt.Option = append(queryOpt.Option, o) + return true + } + } + } + } + + if e.preset.IsValid() { + clientAddr := e.preset + var ecs *dns.EDNS0_SUBNET + if clientAddr.Is4() { + ecs = newSubnet(clientAddr.AsSlice(), uint8(e.args.Mask4), false) + } else { + ecs = newSubnet(clientAddr.AsSlice(), uint8(e.args.Mask6), true) + } + queryOpt.Option = append(queryOpt.Option, ecs) + return false + } + + if e.args.Send { + clientAddr := qCtx.ServerMeta.ClientAddr + if clientAddr.IsValid() { + clientAddr = clientAddr.Unmap() + var ecs *dns.EDNS0_SUBNET + if clientAddr.Is4() { + ecs = newSubnet(clientAddr.AsSlice(), uint8(e.args.Mask4), false) + } else { + ecs = newSubnet(clientAddr.AsSlice(), uint8(e.args.Mask6), true) + } + queryOpt.Option = append(queryOpt.Option, ecs) + return false + } + } + return false +} + +func newSubnet(ip net.IP, mask uint8, v6 bool) *dns.EDNS0_SUBNET { + edns0Subnet := new(dns.EDNS0_SUBNET) + // edns family: https://www.iana.org/assignments/address-family-numbers/address-family-numbers.xhtml + // ipv4 = 1 + // ipv6 = 2 + if !v6 { // ipv4 + edns0Subnet.Family = 1 + } else { // ipv6 + edns0Subnet.Family = 2 + } + + edns0Subnet.SourceNetmask = mask + edns0Subnet.Code = dns.EDNS0SUBNET + edns0Subnet.Address = ip + + // SCOPE PREFIX-LENGTH, an unsigned octet representing the leftmost + // number of significant bits of ADDRESS that the response covers. + // In queries, it MUST be set to 0. + // https://tools.ietf.org/html/rfc7871 + edns0Subnet.SourceScope = 0 + return edns0Subnet +} + +// QuickSetup format: +// old: [ip/mask] [ip/mask] +// new: [ip] +// Note: only the first ip will be used as preset address, the second one +// will be ignored. The mask value will be ignored. +func QuickSetupOldECS(bq sequence.BQ, s string) (any, error) { + a := Args{} + fs := strings.Fields(s) + if len(fs) > 0 { + var foundMask bool + a.Preset, _, foundMask = strings.Cut(fs[0], "/") + if foundMask { + bq.L().Warn("ip mask value is deprecated and will be ignored. The default value (24/48) will be used") + } + if len(fs) > 1 { + bq.L().Warn("Dual-stack ecs is deprecated. Only the first ip will be used as preset ecs address. Others will be simply ignored") + } + } + return NewHandler(a) +} diff --git a/plugin/executable/forward/forward.go b/plugin/executable/forward/forward.go new file mode 100644 index 0000000..8a024e6 --- /dev/null +++ b/plugin/executable/forward/forward.go @@ -0,0 +1,327 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package fastforward + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "strings" + "time" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/upstream" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" +) + +const PluginType = "forward" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) + sequence.MustRegExecQuickSetup(PluginType, quickSetup) +} + +const ( + maxConcurrentQueries = 3 + queryTimeout = time.Second * 5 +) + +type Args struct { + Upstreams []UpstreamConfig `yaml:"upstreams"` + Concurrent int `yaml:"concurrent"` + + // Global options. + Socks5 string `yaml:"socks5"` + SoMark int `yaml:"so_mark"` + BindToDevice string `yaml:"bind_to_device"` + Bootstrap string `yaml:"bootstrap"` + BootstrapVer int `yaml:"bootstrap_version"` +} + +type UpstreamConfig struct { + Tag string `yaml:"tag"` + Addr string `yaml:"addr"` // Required. + DialAddr string `yaml:"dial_addr"` + IdleTimeout int `yaml:"idle_timeout"` + + // Deprecated: This option has no affect. + // TODO: (v6) Remove this option. + MaxConns int `yaml:"max_conns"` + EnablePipeline bool `yaml:"enable_pipeline"` + EnableHTTP3 bool `yaml:"enable_http3"` + InsecureSkipVerify bool `yaml:"insecure_skip_verify"` + + Socks5 string `yaml:"socks5"` + SoMark int `yaml:"so_mark"` + BindToDevice string `yaml:"bind_to_device"` + Bootstrap string `yaml:"bootstrap"` + BootstrapVer int `yaml:"bootstrap_version"` +} + +func Init(bp *coremain.BP, args any) (any, error) { + f, err := NewForward(args.(*Args), Opts{Logger: bp.L(), MetricsTag: bp.Tag()}) + if err != nil { + return nil, err + } + if err := f.RegisterMetricsTo(prometheus.WrapRegistererWithPrefix(PluginType+"_", bp.M().GetMetricsReg())); err != nil { + _ = f.Close() + return nil, err + } + return f, nil +} + +var _ sequence.Executable = (*Forward)(nil) +var _ sequence.QuickConfigurableExec = (*Forward)(nil) + +type Forward struct { + args *Args + + logger *zap.Logger + us []*upstreamWrapper + tag2Upstream map[string]*upstreamWrapper // for fast tag lookup only. +} + +type Opts struct { + Logger *zap.Logger + MetricsTag string +} + +// NewForward inits a Forward from given args. +// args must contain at least one upstream. +func NewForward(args *Args, opt Opts) (*Forward, error) { + if len(args.Upstreams) == 0 { + return nil, errors.New("no upstream is configured") + } + if opt.Logger == nil { + opt.Logger = zap.NewNop() + } + + f := &Forward{ + args: args, + logger: opt.Logger, + tag2Upstream: make(map[string]*upstreamWrapper), + } + + applyGlobal := func(c *UpstreamConfig) { + utils.SetDefaultString(&c.Socks5, args.Socks5) + utils.SetDefaultUnsignNum(&c.SoMark, args.SoMark) + utils.SetDefaultString(&c.BindToDevice, args.BindToDevice) + utils.SetDefaultString(&c.Bootstrap, args.Bootstrap) + utils.SetDefaultUnsignNum(&c.BootstrapVer, args.BootstrapVer) + } + + for i, c := range args.Upstreams { + if len(c.Addr) == 0 { + return nil, fmt.Errorf("#%d upstream invalid args, addr is required", i) + } + applyGlobal(&c) + + uw := newWrapper(i, c, opt.MetricsTag) + uOpt := upstream.Opt{ + DialAddr: c.DialAddr, + Socks5: c.Socks5, + SoMark: c.SoMark, + BindToDevice: c.BindToDevice, + IdleTimeout: time.Duration(c.IdleTimeout) * time.Second, + EnablePipeline: c.EnablePipeline, + EnableHTTP3: c.EnableHTTP3, + Bootstrap: c.Bootstrap, + BootstrapVer: c.BootstrapVer, + TLSConfig: &tls.Config{ + InsecureSkipVerify: c.InsecureSkipVerify, + ClientSessionCache: tls.NewLRUClientSessionCache(4), + }, + Logger: opt.Logger, + EventObserver: uw, + } + + u, err := upstream.NewUpstream(c.Addr, uOpt) + if err != nil { + _ = f.Close() + return nil, fmt.Errorf("failed to init upstream #%d: %w", i, err) + } + uw.u = u + f.us = append(f.us, uw) + + if len(c.Tag) > 0 { + if _, dup := f.tag2Upstream[c.Tag]; dup { + _ = f.Close() + return nil, fmt.Errorf("duplicated upstream tag %s", c.Tag) + } + f.tag2Upstream[c.Tag] = uw + } + } + + return f, nil +} + +func (f *Forward) RegisterMetricsTo(r prometheus.Registerer) error { + for _, wu := range f.us { + // Only register metrics for upstream that has a tag. + if len(wu.cfg.Tag) == 0 { + continue + } + if err := wu.registerMetricsTo(r); err != nil { + return err + } + } + return nil +} + +func (f *Forward) Exec(ctx context.Context, qCtx *query_context.Context) (err error) { + r, err := f.exchange(ctx, qCtx, f.us) + if err != nil { + return err + } + qCtx.SetResponse(r) + return nil +} + +// QuickConfigureExec format: [upstream_tag]... +func (f *Forward) QuickConfigureExec(args string) (any, error) { + var us []*upstreamWrapper + if len(args) == 0 { // No args, use all upstreams. + us = f.us + } else { // Pick up upstreams by tags. + for _, tag := range strings.Fields(args) { + u := f.tag2Upstream[tag] + if u == nil { + return nil, fmt.Errorf("cannot find upstream by tag %s", tag) + } + us = append(us, u) + } + } + var execFunc sequence.ExecutableFunc = func(ctx context.Context, qCtx *query_context.Context) error { + r, err := f.exchange(ctx, qCtx, us) + if err != nil { + return err + } + qCtx.SetResponse(r) + return nil + } + return execFunc, nil +} + +func (f *Forward) Close() error { + for _, u := range f.us { + _ = u.Close() + } + return nil +} + +func (f *Forward) exchange(ctx context.Context, qCtx *query_context.Context, us []*upstreamWrapper) (*dns.Msg, error) { + if len(us) == 0 { + return nil, errors.New("no upstream to exchange") + } + + queryPayload, err := pool.PackBuffer(qCtx.Q()) + if err != nil { + return nil, err + } + defer pool.ReleaseBuf(queryPayload) + + concurrent := f.args.Concurrent + if concurrent <= 0 { + concurrent = 1 + } + if concurrent > maxConcurrentQueries { + concurrent = maxConcurrentQueries + } + + type res struct { + r *dns.Msg + err error + } + + resChan := make(chan res) + done := make(chan struct{}) + defer close(done) + + for i := 0; i < concurrent; i++ { + u := randPick(us) + qc := copyPayload(queryPayload) + go func(uqid uint32, question dns.Question) { + defer pool.ReleaseBuf(qc) + // Give each upstream a fixed timeout to finish the query. + upstreamCtx, cancel := context.WithTimeout(context.Background(), queryTimeout) + defer cancel() + + var r *dns.Msg + respPayload, err := u.ExchangeContext(upstreamCtx, *qc) + if err != nil { + f.logger.Warn( + "upstream error", + zap.Uint32("uqid", uqid), + zap.String("qname", question.Name), + zap.Uint16("qclass", question.Qclass), + zap.Uint16("qtype", question.Qtype), + zap.String("upstream", u.name()), + zap.Error(err), + ) + } else { + r = new(dns.Msg) + err = r.Unpack(*respPayload) + pool.ReleaseBuf(respPayload) + if err != nil { + r = nil + } + } + select { + case resChan <- res{r: r, err: err}: + case <-done: + } + }(qCtx.Id(), qCtx.QQuestion()) + } + + for i := 0; i < concurrent; i++ { + select { + case res := <-resChan: + r, err := res.r, res.err + if err != nil { + continue + } + + // Retry until the last + if i < concurrent-1 && r.Rcode != dns.RcodeSuccess && r.Rcode != dns.RcodeNameError { + continue + } + return r, nil + case <-ctx.Done(): + return nil, context.Cause(ctx) + } + } + return nil, errors.New("all upstream servers failed") +} + +func quickSetup(bq sequence.BQ, s string) (any, error) { + args := new(Args) + args.Concurrent = maxConcurrentQueries + for _, u := range strings.Fields(s) { + args.Upstreams = append(args.Upstreams, UpstreamConfig{Addr: u}) + } + return NewForward(args, Opts{Logger: bq.L()}) +} diff --git a/plugin/executable/forward/utils.go b/plugin/executable/forward/utils.go new file mode 100644 index 0000000..64d9fcf --- /dev/null +++ b/plugin/executable/forward/utils.go @@ -0,0 +1,164 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package fastforward + +import ( + "context" + "math/rand" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/upstream" + "github.com/miekg/dns" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap/zapcore" +) + +type upstreamWrapper struct { + idx int + u upstream.Upstream + cfg UpstreamConfig + queryTotal prometheus.Counter + errTotal prometheus.Counter + thread prometheus.Gauge + responseLatency prometheus.Histogram + + connOpened prometheus.Counter + connClosed prometheus.Counter +} + +func (uw *upstreamWrapper) OnEvent(typ upstream.Event) { + switch typ { + case upstream.EventConnOpen: + uw.connOpened.Inc() + case upstream.EventConnClose: + uw.connClosed.Inc() + } +} + +// newWrapper inits all metrics. +// Note: upstreamWrapper.u still needs to be set. +func newWrapper(idx int, cfg UpstreamConfig, pluginTag string) *upstreamWrapper { + lb := map[string]string{"upstream": cfg.Tag, "tag": pluginTag} + return &upstreamWrapper{ + cfg: cfg, + queryTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "query_total", + Help: "The total number of queries processed by this upstream", + ConstLabels: lb, + }), + errTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "err_total", + Help: "The total number of queries failed", + ConstLabels: lb, + }), + thread: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "thread", + Help: "The number of threads (queries) that are currently being processed", + ConstLabels: lb, + }), + responseLatency: prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "response_latency_millisecond", + Help: "The response latency in millisecond", + Buckets: []float64{1, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000}, + ConstLabels: lb, + }), + + connOpened: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "conn_opened_total", + Help: "The total number of connections that are opened", + ConstLabels: lb, + }), + connClosed: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "conn_closed_total", + Help: "The total number of connections that are closed", + ConstLabels: lb, + }), + } +} + +func (uw *upstreamWrapper) registerMetricsTo(r prometheus.Registerer) error { + for _, collector := range [...]prometheus.Collector{ + uw.queryTotal, + uw.errTotal, + uw.thread, + uw.responseLatency, + uw.connOpened, + uw.connClosed, + } { + if err := r.Register(collector); err != nil { + return err + } + } + return nil +} + +// name returns upstream tag if it was set in the config. +// Otherwise, it returns upstream address. +func (uw *upstreamWrapper) name() string { + if t := uw.cfg.Tag; len(t) > 0 { + return uw.cfg.Tag + } + return uw.cfg.Addr +} + +func (uw *upstreamWrapper) ExchangeContext(ctx context.Context, m []byte) (*[]byte, error) { + uw.queryTotal.Inc() + + start := time.Now() + uw.thread.Inc() + r, err := uw.u.ExchangeContext(ctx, m) + uw.thread.Dec() + + if err != nil { + uw.errTotal.Inc() + } else { + uw.responseLatency.Observe(float64(time.Since(start).Milliseconds())) + } + return r, err +} + +func (uw *upstreamWrapper) Close() error { + return uw.u.Close() +} + +type queryInfo dns.Msg + +func (q *queryInfo) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + if len(q.Question) != 1 { + encoder.AddBool("odd_question", true) + } else { + question := q.Question[0] + encoder.AddString("qname", question.Name) + encoder.AddUint16("qtype", question.Qtype) + encoder.AddUint16("qclass", question.Qclass) + } + return nil +} + +func randPick[T any](s []T) T { + return s[rand.Intn(len(s))] +} + +func copyPayload(b *[]byte) *[]byte { + bc := pool.GetBuf(len(*b)) + copy(*bc, *b) + return bc +} diff --git a/plugin/executable/forward_edns0opt/forwarder.go b/plugin/executable/forward_edns0opt/forwarder.go new file mode 100644 index 0000000..26991ec --- /dev/null +++ b/plugin/executable/forward_edns0opt/forwarder.go @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package forwardedns0opt + +import ( + "context" + "strconv" + "strings" + + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" +) + +const PluginType = "forward_edns0opt" + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.RecursiveExecutable = (*forwarder)(nil) + +type forwarder struct { + forwardTypCodes map[uint32]struct{} +} + +func (f *forwarder) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error { + qOpt := qCtx.QOpt() + clientOpt := qCtx.ClientOpt() + if clientOpt != nil { + for _, o := range clientOpt.Option { + if _, ok := f.forwardTypCodes[uint32(o.Option())]; ok { + qOpt.Option = append(qOpt.Option, o) + } + } + } + + err := next.ExecNext(ctx, qCtx) + if err != nil { + return err + } + + upstreamOpt := qCtx.UpstreamOpt() + respOpt := qCtx.RespOpt() + if upstreamOpt != nil && respOpt != nil { + for _, o := range upstreamOpt.Option { + if _, ok := f.forwardTypCodes[uint32(o.Option())]; ok { + respOpt.Option = append(respOpt.Option, o) + } + } + } + return nil +} + +// Format: [DNS EDNS0 Option Code] ... +func QuickSetup(_ sequence.BQ, numbers string) (any, error) { + m := make(map[uint32]struct{}) + for _, s := range strings.Fields(numbers) { + n, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return nil, err + } + m[uint32(n)] = struct{}{} + } + return &forwarder{forwardTypCodes: m}, nil +} diff --git a/plugin/executable/hosts/hosts.go b/plugin/executable/hosts/hosts.go new file mode 100644 index 0000000..de405bb --- /dev/null +++ b/plugin/executable/hosts/hosts.go @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package hosts + +import ( + "bytes" + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/hosts" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "os" +) + +const PluginType = "hosts" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +var _ sequence.Executable = (*Hosts)(nil) + +type Args struct { + Entries []string `yaml:"entries"` + Files []string `yaml:"files"` +} + +type Hosts struct { + h *hosts.Hosts +} + +func Init(_ *coremain.BP, args any) (any, error) { + return NewHosts(args.(*Args)) +} + +func NewHosts(args *Args) (*Hosts, error) { + m := domain.NewMixMatcher[*hosts.IPs]() + m.SetDefaultMatcher(domain.MatcherFull) + for i, entry := range args.Entries { + if err := domain.Load[*hosts.IPs](m, entry, hosts.ParseIPs); err != nil { + return nil, fmt.Errorf("failed to load entry #%d %s, %w", i, entry, err) + } + } + for i, file := range args.Files { + b, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file #%d %s, %w", i, file, err) + } + if err := domain.LoadFromTextReader[*hosts.IPs](m, bytes.NewReader(b), hosts.ParseIPs); err != nil { + return nil, fmt.Errorf("failed to load file #%d %s, %w", i, file, err) + } + } + + return &Hosts{ + h: hosts.NewHosts(m), + }, nil +} + +func (h *Hosts) Response(q *dns.Msg) *dns.Msg { + return h.h.LookupMsg(q) +} + +func (h *Hosts) Exec(_ context.Context, qCtx *query_context.Context) error { + r := h.h.LookupMsg(qCtx.Q()) + if r != nil { + qCtx.SetResponse(r) + } + return nil +} diff --git a/plugin/executable/ipset/ipset.go b/plugin/executable/ipset/ipset.go new file mode 100644 index 0000000..f8d609b --- /dev/null +++ b/plugin/executable/ipset/ipset.go @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ipset + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "strconv" + "strings" +) + +const PluginType = "ipset" + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +type Args struct { + SetName4 string `yaml:"set_name4"` + SetName6 string `yaml:"set_name6"` + Mask4 int `yaml:"mask4"` // default 24 + Mask6 int `yaml:"mask6"` // default 32 +} + +var _ sequence.Executable = (*ipSetPlugin)(nil) + +// QuickSetup format: [set_name,{inet|inet6},mask] *2 +// e.g. "my_set,inet,24 my_set6,inet6,48" +func QuickSetup(_ sequence.BQ, s string) (any, error) { + fs := strings.Fields(s) + if len(fs) > 2 { + return nil, fmt.Errorf("expect no more than 2 fields, got %d", len(fs)) + } + + args := new(Args) + for _, argsStr := range fs { + ss := strings.Split(argsStr, ",") + if len(ss) != 3 { + return nil, fmt.Errorf("invalid args, expect 5 fields, got %d", len(ss)) + } + + m, err := strconv.Atoi(ss[2]) + if err != nil { + return nil, fmt.Errorf("invalid mask, %w", err) + } + switch ss[1] { + case "inet": + args.Mask4 = m + args.SetName4 = ss[0] + case "inet6": + args.Mask6 = m + args.SetName6 = ss[0] + default: + return nil, fmt.Errorf("invalid set family, %s", ss[0]) + } + } + return newIpSetPlugin(args) +} diff --git a/plugin/executable/ipset/ipset_linux.go b/plugin/executable/ipset/ipset_linux.go new file mode 100644 index 0000000..18812d9 --- /dev/null +++ b/plugin/executable/ipset/ipset_linux.go @@ -0,0 +1,103 @@ +//go:build linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ipset + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/miekg/dns" + "github.com/nadoo/ipset" + "net/netip" +) + +type ipSetPlugin struct { + args *Args + nl *ipset.NetLink +} + +func newIpSetPlugin(args *Args) (*ipSetPlugin, error) { + if args.Mask4 == 0 { + args.Mask4 = 24 + } + if args.Mask6 == 0 { + args.Mask6 = 32 + } + + nl, err := ipset.Init() + if err != nil { + return nil, err + } + + return &ipSetPlugin{ + args: args, + nl: nl, + }, nil +} + +func (p *ipSetPlugin) Exec(_ context.Context, qCtx *query_context.Context) error { + r := qCtx.R() + if r != nil { + if err := p.addIPSet(r); err != nil { + return fmt.Errorf("ipset: %w", err) + } + } + return nil +} + +func (p *ipSetPlugin) Close() error { + return p.nl.Close() +} + +func (p *ipSetPlugin) addIPSet(r *dns.Msg) error { + for i := range r.Answer { + switch rr := r.Answer[i].(type) { + case *dns.A: + if len(p.args.SetName4) == 0 { + continue + } + addr, ok := netip.AddrFromSlice(rr.A.To4()) + if !ok { + return fmt.Errorf("invalid A record with ip: %s", rr.A) + } + if err := ipset.AddPrefix(p.nl, p.args.SetName4, netip.PrefixFrom(addr, p.args.Mask4)); err != nil { + return err + } + + case *dns.AAAA: + if len(p.args.SetName6) == 0 { + continue + } + addr, ok := netip.AddrFromSlice(rr.AAAA.To16()) + if !ok { + return fmt.Errorf("invalid AAAA record with ip: %s", rr.AAAA) + } + if err := ipset.AddPrefix(p.nl, p.args.SetName6, netip.PrefixFrom(addr, p.args.Mask6)); err != nil { + return err + } + default: + continue + } + } + + return nil +} diff --git a/plugin/executable/ipset/ipset_other.go b/plugin/executable/ipset/ipset_other.go new file mode 100644 index 0000000..8e9f2ce --- /dev/null +++ b/plugin/executable/ipset/ipset_other.go @@ -0,0 +1,37 @@ +//go:build !linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ipset + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" +) + +type ipSetPlugin struct{} + +func newIpSetPlugin(_ *Args) (*ipSetPlugin, error) { + return &ipSetPlugin{}, nil +} + +func (p *ipSetPlugin) Exec(_ context.Context, _ *query_context.Context) error { + return nil +} diff --git a/plugin/executable/ipset/ipset_test.go b/plugin/executable/ipset/ipset_test.go new file mode 100644 index 0000000..b2cfd29 --- /dev/null +++ b/plugin/executable/ipset/ipset_test.go @@ -0,0 +1,120 @@ +//go:build linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ipset + +import ( + "context" + "fmt" + "math/rand" + "net" + "os" + "strconv" + "testing" + + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "github.com/vishvananda/netlink" +) + +func skipTest(t *testing.T) { + if os.Getenv("TEST_IPSET") == "" { + t.SkipNow() + } +} + +func prepareSet(t *testing.T) (func(), string, string) { + t.Helper() + n4 := "test" + strconv.Itoa(rand.Int()) + n6 := "test" + strconv.Itoa(rand.Int()) + if err := netlink.IpsetCreate(n4, "hash:net", netlink.IpsetCreateOptions{ + Family: netlink.FAMILY_V4, + }); err != nil { + t.Fatal(err) + } + if err := netlink.IpsetCreate(n6, "hash:net", netlink.IpsetCreateOptions{ + Family: netlink.FAMILY_V6, + }); err != nil { + t.Fatal(err) + } + return func() { + if err := netlink.IpsetDestroy(n4); err != nil { + t.Fatal(err) + } + if err := netlink.IpsetDestroy(n6); err != nil { + t.Fatal(err) + } + }, n4, n6 +} + +func Test_ipset(t *testing.T) { + skipTest(t) + + done, n4, n6 := prepareSet(t) + defer done() + + v, err := QuickSetup(nil, fmt.Sprintf("%s,inet,24 %s,inet6,48", n4, n6)) + if err != nil { + t.Fatal(err) + } + p := v.(sequence.Executable) + + q := new(dns.Msg) + q.SetQuestion("test.", dns.TypeA) + r := new(dns.Msg) + r.SetReply(q) + r.Answer = append(r.Answer, &dns.A{A: net.ParseIP("127.0.0.1")}) + r.Answer = append(r.Answer, &dns.A{A: net.ParseIP("127.0.0.2")}) + r.Answer = append(r.Answer, &dns.AAAA{AAAA: net.ParseIP("::1")}) + r.Answer = append(r.Answer, &dns.AAAA{AAAA: net.ParseIP("::2")}) + qCtx := query_context.NewContext(q) + qCtx.SetResponse(r) + if err := p.Exec(context.Background(), qCtx); err != nil { + t.Fatal(err) + } + + // read n4 + l, err := netlink.IpsetList(n4) + if err != nil { + t.Fatal(err) + } + if len(l.Entries) != 1 { + t.Fatal("no entry") + } + e := l.Entries[0] + if !e.IP.Equal(net.ParseIP("127.0.0.0")) || e.CIDR != 24 { + t.Fatal() + } + + // read n6 + l, err = netlink.IpsetList(n6) + if err != nil { + t.Fatal(err) + } + if len(l.Entries) != 1 { + t.Fatal("no entry") + } + e = l.Entries[0] + if !e.IP.Equal(net.ParseIP("::")) || e.CIDR != 48 { + t.Fatal() + } +} diff --git a/plugin/executable/metrics_collector/collector.go b/plugin/executable/metrics_collector/collector.go new file mode 100644 index 0000000..35e5475 --- /dev/null +++ b/plugin/executable/metrics_collector/collector.go @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package metrics_collector + +import ( + "context" + "errors" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/prometheus/client_golang/prometheus" + "time" +) + +const PluginType = "metrics_collector" + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.RecursiveExecutable = (*Collector)(nil) + +type Collector struct { + queryTotal prometheus.Counter + errTotal prometheus.Counter + thread prometheus.Gauge + responseLatency prometheus.Histogram +} + +// NewCollector inits a new Collector with given name to r. +// name must be unique in the r. +func NewCollector(r prometheus.Registerer, name string) (*Collector, error) { + if len(name) == 0 { + return nil, errors.New("collector must has a name") + } + + lb := map[string]string{"name": name} + var c = &Collector{ + queryTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "query_total", + Help: "The total number of queries pass through", + ConstLabels: lb, + }), + errTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "err_total", + Help: "The total number of queries failed", + ConstLabels: lb, + }), + thread: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "thread", + Help: "The number of threads that are currently being processed", + ConstLabels: lb, + }), + responseLatency: prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "response_latency_millisecond", + Help: "The response latency in millisecond", + Buckets: []float64{1, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000}, + ConstLabels: lb, + }), + } + for _, collector := range [...]prometheus.Collector{c.queryTotal, c.errTotal, c.thread, c.responseLatency} { + if err := r.Register(collector); err != nil { + return nil, err + } + } + return c, nil +} + +func (c *Collector) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error { + c.thread.Inc() + defer c.thread.Dec() + + c.queryTotal.Inc() + start := time.Now() + err := next.ExecNext(ctx, qCtx) + if err != nil { + c.errTotal.Inc() + } + if qCtx.R() != nil { + c.responseLatency.Observe(float64(time.Since(start).Milliseconds())) + } + return err +} + +// QuickSetup format: metrics_name +func QuickSetup(bp sequence.BQ, s string) (any, error) { + r := prometheus.WrapRegistererWithPrefix(PluginType+"_", bp.M().GetMetricsReg()) + return NewCollector(r, s) +} diff --git a/plugin/executable/mikrotik_addresslist/README.md b/plugin/executable/mikrotik_addresslist/README.md new file mode 100644 index 0000000..884b94d --- /dev/null +++ b/plugin/executable/mikrotik_addresslist/README.md @@ -0,0 +1,161 @@ +# MikroTik Address List 插件 + +这个插件用于将 DNS 解析得到的 IP 地址自动添加到 MikroTik 路由器的 address list 中。 + +## 功能特性 + +- 支持 IPv4 和 IPv6 地址 +- 自动创建网络前缀(CIDR 格式) +- 支持地址超时设置 +- 支持添加注释 +- 避免重复添加相同地址 +- 支持 TLS 连接 +- 可配置连接超时 + +## 依赖 + +需要添加以下依赖到 `go.mod`: + +```go +require ( + github.com/go-routeros/routeros/v3 v3.0.0 +) +``` + +## 配置方式 + +### 1. 快速配置格式 + +``` +host:port:username:password:use_tls:timeout:address_list4:address_list6:mask4:mask6:comment:timeout_addr +``` + +**参数说明:** +- `host`: MikroTik 路由器 IP 地址 +- `port`: API 端口(默认 8728) +- `username`: 用户名 +- `password`: 密码 +- `use_tls`: 是否使用 TLS(true/false) +- `timeout`: 连接超时时间(秒) +- `address_list4`: IPv4 address list 名称 +- `address_list6`: IPv6 address list 名称 +- `mask4`: IPv4 掩码(默认 24) +- `mask6`: IPv6 掩码(默认 32) +- `comment`: 地址注释 +- `timeout_addr`: 地址超时时间(秒,0 表示永久) + +**示例:** +``` +192.168.1.1:8728:admin:password:false:10:my_list4:my_list6:24:32:from_dns:3600 +``` + +### 2. YAML 配置格式 + +```yaml +- exec: mikrotik_addresslist + args: + host: "192.168.1.1" + port: 8728 + username: "admin" + password: "password" + use_tls: false + timeout: 10 + address_list4: "my_list4" + address_list6: "my_list6" + mask4: 24 + mask6: 32 + comment: "from_dns" + timeout_addr: 3600 +``` + +## 使用示例 + +### 1. 在 mosdns 配置中使用 + +```yaml +plugins: + - tag: sequence + type: sequence + args: + - exec: forward + args: + upstream: + - addr: "8.8.8.8:53" + - exec: mikrotik_addresslist + args: "192.168.1.1:8728:admin:password:false:10:blocked_ips:blocked_ips6:24:32:blocked:86400" + +servers: + - exec: sequence + args: + - sequence +``` + +### 2. 在 MikroTik 中创建 address list + +在 MikroTik 路由器上,需要先创建 address list: + +``` +/ip firewall address-list add list=blocked_ips +/ip firewall address-list add list=blocked_ips6 +``` + +### 3. 在防火墙规则中使用 + +``` +/ip firewall filter add chain=forward src-address-list=blocked_ips action=drop +/ip firewall filter add chain=forward src-address-list=blocked_ips6 action=drop +``` + +## 工作原理 + +1. **DNS 查询处理**:当 mosdns 收到 DNS 查询并返回响应时,插件被触发 +2. **IP 提取**:从 DNS 响应的 A 记录(IPv4)和 AAAA 记录(IPv6)中提取 IP 地址 +3. **网络前缀创建**:根据配置的掩码创建 CIDR 格式的网络前缀 +4. **重复检查**:检查地址列表中是否已存在该地址 +5. **地址添加**:通过 MikroTik API 将地址添加到指定的 address list 中 + +## 安全注意事项 + +1. **API 访问权限**:确保用于连接的 MikroTik 用户具有足够的权限来管理 address list +2. **网络安全**:建议使用 TLS 连接以提高安全性 +3. **密码安全**:不要在配置文件中使用明文密码,考虑使用环境变量或加密配置 +4. **网络隔离**:限制对 MikroTik API 端口的访问 + +## 故障排除 + +### 常见错误 + +1. **连接失败**: + - 检查 MikroTik IP 地址和端口是否正确 + - 确认网络连接正常 + - 检查防火墙设置 + +2. **认证失败**: + - 验证用户名和密码是否正确 + - 确认用户具有管理 address list 的权限 + +3. **权限不足**: + - 确保用户具有 `/ip/firewall/address-list/` 的读写权限 + +### 调试 + +启用 mosdns 的调试日志来查看插件的执行情况: + +```yaml +log: + level: debug +``` + +## 测试 + +运行测试(需要设置环境变量): + +```bash +# 运行基本测试 +export TEST_MIKROTIK=1 +go test ./plugin/executable/mikrotik_addresslist/ + +# 运行 RouterOS v3 API 测试(需要真实的 MikroTik 设备) +export TEST_ROUTEROS_V3=1 +go test ./plugin/executable/mikrotik_addresslist/ -v -run TestRouterOSv3API +``` \ No newline at end of file diff --git a/plugin/executable/mikrotik_addresslist/config_example.yaml b/plugin/executable/mikrotik_addresslist/config_example.yaml new file mode 100644 index 0000000..e97e36d --- /dev/null +++ b/plugin/executable/mikrotik_addresslist/config_example.yaml @@ -0,0 +1,104 @@ +# MikroTik Address List 插件配置示例 + +# 插件定义 +plugins: + # 转发插件 - 向上游 DNS 服务器查询 + - tag: forward_google + type: forward + args: + upstream: + - addr: "8.8.8.8:53" + - addr: "8.8.4.4:53" + + # MikroTik Address List 插件 - 将解析的 IP 添加到 MikroTik + - tag: mikrotik_blocklist + type: mikrotik_addresslist + args: "192.168.1.1:8728:admin:password:false:10:blocked_ips:blocked_ips6:24:32:blocked_domain:86400" + + # 序列插件 - 组合多个插件 + - tag: sequence_with_blocklist + type: sequence + args: + - exec: forward_google + - exec: mikrotik_blocklist + +# 服务器配置 +servers: + # UDP 服务器 + - exec: sequence_with_blocklist + args: + - sequence_with_blocklist + listeners: + - protocol: udp + addr: ":53" + + # TCP 服务器 + - exec: sequence_with_blocklist + args: + - sequence_with_blocklist + listeners: + - protocol: tcp + addr: ":53" + +# 日志配置 +log: + level: info + file: "mosdns.log" + +# 其他配置示例 + +# 1. 使用 YAML 格式的详细配置 +plugins: + - tag: mikrotik_detailed + type: mikrotik_addresslist + args: + host: "192.168.1.1" + port: 8728 + username: "admin" + password: "password" + use_tls: false + timeout: 10 + address_list4: "blocked_ips" + address_list6: "blocked_ips6" + mask4: 24 + mask6: 32 + comment: "blocked_domain" + timeout_addr: 86400 + +# 2. 多个 address list 配置 +plugins: + # 恶意域名列表 + - tag: mikrotik_malware + type: mikrotik_addresslist + args: "192.168.1.1:8728:admin:password:false:10:malware_ips:malware_ips6:24:32:malware:3600" + + # 广告域名列表 + - tag: mikrotik_ads + type: mikrotik_addresslist + args: "192.168.1.1:8728:admin:password:false:10:ads_ips:ads_ips6:24:32:ads:7200" + + # 组合序列 + - tag: sequence_all + type: sequence + args: + - exec: forward_google + - exec: mikrotik_malware + - exec: mikrotik_ads + +# 3. 使用 TLS 的安全配置 +plugins: + - tag: mikrotik_secure + type: mikrotik_addresslist + args: "192.168.1.1:8729:admin:password:true:15:secure_list:secure_list6:24:32:secure:1800" + +# 4. 不同掩码配置 +plugins: + # 精确 IP 匹配 + - tag: mikrotik_exact + type: mikrotik_addresslist + args: "192.168.1.1:8728:admin:password:false:10:exact_ips:exact_ips6:32:128:exact:3600" + + # 网段匹配 + - tag: mikrotik_network + type: mikrotik_addresslist + args: "192.168.1.1:8728:admin:password:false:10:network_ips:network_ips6:16:48:network:7200" \ No newline at end of file diff --git a/plugin/executable/mikrotik_addresslist/mikrotik_addresslist.go b/plugin/executable/mikrotik_addresslist/mikrotik_addresslist.go new file mode 100644 index 0000000..a703f41 --- /dev/null +++ b/plugin/executable/mikrotik_addresslist/mikrotik_addresslist.go @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package mikrotik_addresslist + +import ( + "fmt" + "strconv" + "strings" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" +) + +const PluginType = "mikrotik_addresslist" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +type Args struct { + Host string `yaml:"host"` // MikroTik 路由器 IP 地址 + Port int `yaml:"port"` // API 端口,默认 8728 + Username string `yaml:"username"` // 用户名 + Password string `yaml:"password"` // 密码 + UseTLS bool `yaml:"use_tls"` // 是否使用 TLS,默认 false + Timeout int `yaml:"timeout"` // 连接超时时间(秒),默认 10 + + AddressList4 string `yaml:"address_list4"` // IPv4 address list 名称 + AddressList6 string `yaml:"address_list6"` // IPv6 address list 名称 + Mask4 int `yaml:"mask4"` // IPv4 掩码,默认 24 + Mask6 int `yaml:"mask6"` // IPv6 掩码,默认 32 + Comment string `yaml:"comment"` // 添加的地址的注释 + TimeoutAddr int `yaml:"timeout_addr"` // 地址超时时间(秒),0 表示永久 +} + +var _ sequence.Executable = (*mikrotikAddressListPlugin)(nil) + +// Init initializes the plugin from coremain +func Init(bp *coremain.BP, args any) (any, error) { + return newMikrotikAddressListPlugin(args.(*Args)) +} + +// QuickSetup format: host:port:username:password:use_tls:timeout:address_list4:address_list6:mask4:mask6:comment:timeout_addr +// e.g. "192.168.1.1:8728:admin:password:false:10:my_list4:my_list6:24:32:from_dns:3600" +func QuickSetup(_ sequence.BQ, s string) (any, error) { + parts := strings.Split(s, ":") + if len(parts) < 6 { + return nil, fmt.Errorf("invalid args, expect at least 6 parts, got %d", len(parts)) + } + + args := &Args{ + Host: parts[0], + Username: parts[2], + Password: parts[3], + UseTLS: false, + Timeout: 10, + Mask4: 24, + Mask6: 32, + } + + // 解析端口 + if port, err := strconv.Atoi(parts[1]); err == nil { + args.Port = port + } else { + args.Port = 8728 + } + + // 解析 TLS 设置 + if len(parts) > 4 { + if useTLS, err := strconv.ParseBool(parts[4]); err == nil { + args.UseTLS = useTLS + } + } + + // 解析超时时间 + if len(parts) > 5 { + if timeout, err := strconv.Atoi(parts[5]); err == nil { + args.Timeout = timeout + } + } + + // 解析 address list 名称 + if len(parts) > 6 { + args.AddressList4 = parts[6] + } + if len(parts) > 7 { + args.AddressList6 = parts[7] + } + + // 解析掩码 + if len(parts) > 8 { + if mask4, err := strconv.Atoi(parts[8]); err == nil { + args.Mask4 = mask4 + } + } + if len(parts) > 9 { + if mask6, err := strconv.Atoi(parts[9]); err == nil { + args.Mask6 = mask6 + } + } + + // 解析注释 + if len(parts) > 10 { + args.Comment = parts[10] + } + + // 解析地址超时时间 + if len(parts) > 11 { + if timeoutAddr, err := strconv.Atoi(parts[11]); err == nil { + args.TimeoutAddr = timeoutAddr + } + } + + return newMikrotikAddressListPlugin(args) +} diff --git a/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go b/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go new file mode 100644 index 0000000..fbcbd1e --- /dev/null +++ b/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go @@ -0,0 +1,277 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package mikrotik_addresslist + +import ( + "context" + "fmt" + "net/netip" + "strconv" + "strings" + + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/miekg/dns" + "go.uber.org/zap" + + routeros "github.com/go-routeros/routeros/v3" +) + +type mikrotikAddressListPlugin struct { + args *Args + conn *routeros.Client + log *zap.Logger +} + +func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) { + if args.Mask4 == 0 { + args.Mask4 = 24 + } + if args.Mask6 == 0 { + args.Mask6 = 32 + } + if args.Port == 0 { + args.Port = 8728 + } + if args.Timeout == 0 { + args.Timeout = 10 + } + + // 构建连接地址 + addr := fmt.Sprintf("%s:%d", args.Host, args.Port) + + // 创建 MikroTik 连接 + conn, err := routeros.Dial(addr, args.Username, args.Password) + if err != nil { + return nil, fmt.Errorf("failed to connect to MikroTik: %w", err) + } + + // 测试连接 + if err := testMikrotikConnection(conn); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to test MikroTik connection: %w", err) + } + + plugin := &mikrotikAddressListPlugin{ + args: args, + conn: conn, + log: zap.L().Named("mikrotik_addresslist"), + } + + // 记录连接成功信息 + plugin.log.Info("successfully connected to MikroTik", + zap.String("host", args.Host), + zap.Int("port", args.Port), + zap.String("username", args.Username), + zap.String("address_list4", args.AddressList4)) + + return plugin, nil +} + +// testMikrotikConnection 测试 MikroTik 连接是否正常 +func testMikrotikConnection(conn *routeros.Client) error { + // 尝试执行一个简单的命令来测试连接 + resp, err := conn.Run("/system/resource/print") + if err != nil { + return fmt.Errorf("connection test failed: %w", err) + } + + // 检查响应是否有效 + if len(resp.Re) == 0 { + return fmt.Errorf("connection test failed: no response from MikroTik") + } + + // 记录连接测试成功 + zap.L().Named("mikrotik_addresslist").Info("MikroTik connection test successful") + + return nil +} + +func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error { + // 检查连接是否正常 + if p.conn == nil { + p.log.Error("MikroTik connection is nil") + return fmt.Errorf("mikrotik_addresslist: connection is nil") + } + + r := qCtx.R() + if r != nil { + p.log.Debug("processing DNS response", + zap.String("qname", qCtx.Q().Question[0].Name), + zap.Int("answer_count", len(r.Answer))) + + if err := p.addToAddressList(r); err != nil { + p.log.Error("failed to add addresses to MikroTik", zap.Error(err)) + return fmt.Errorf("mikrotik_addresslist: %w", err) + } + } + return nil +} + +func (p *mikrotikAddressListPlugin) Close() error { + if p.conn != nil { + return p.conn.Close() + } + return nil +} + +func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error { + addedCount := 0 + p.log.Debug("starting to process DNS response", + zap.String("configured_address_list4", p.args.AddressList4), + zap.Int("answer_count", len(r.Answer))) + + for i := range r.Answer { + switch rr := r.Answer[i].(type) { + case *dns.A: + if len(p.args.AddressList4) == 0 { + p.log.Debug("skipping A record, no IPv4 address list configured") + continue + } + addr, ok := netip.AddrFromSlice(rr.A.To4()) + if !ok { + p.log.Error("invalid A record", zap.String("ip", rr.A.String())) + return fmt.Errorf("invalid A record with ip: %s", rr.A) + } + p.log.Debug("processing A record", + zap.String("ip", addr.String()), + zap.String("address_list4", p.args.AddressList4)) + if err := p.addAddressToMikrotik(addr, p.args.AddressList4, p.args.Mask4); err != nil { + return err + } + addedCount++ + + case *dns.AAAA: + // 跳过 IPv6 记录 + p.log.Debug("skipping AAAA record (IPv6 not supported)") + continue + default: + p.log.Debug("skipping non-A record", zap.String("type", fmt.Sprintf("%T", rr))) + continue + } + } + + if addedCount > 0 { + p.log.Info("added IPv4 addresses to MikroTik", zap.Int("count", addedCount)) + } else { + p.log.Debug("no IPv4 addresses added to MikroTik") + } + + return nil +} + +func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listName string, mask int) error { + p.log.Debug("addAddressToMikrotik called", + zap.String("addr", addr.String()), + zap.String("listName", listName), + zap.Int("mask", mask)) + + // 构建 CIDR 格式的地址 + var cidrAddr string + if addr.Is4() { + cidrAddr = addr.String() + "/" + strconv.Itoa(p.args.Mask4) + } else { + cidrAddr = addr.String() + "/" + strconv.Itoa(p.args.Mask6) + } + + p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName)) + + // 检查地址是否已存在 + exists, err := p.addressExists(listName, cidrAddr) + if err != nil { + // 如果检查失败,可能是地址列表不存在,继续尝试添加 + p.log.Debug("failed to check if address exists, will try to add anyway", zap.Error(err)) + } else if exists { + // 地址已存在,跳过 + p.log.Debug("address already exists", zap.String("cidr", cidrAddr), zap.String("list", listName)) + return nil + } + + // 构造 RouterOS 参数,注意必须以 = 开头! + params := []string{ + "=list=" + listName, + "=address=" + cidrAddr, + } + + // 添加注释(如果配置了) + if p.args.Comment != "" { + params = append(params, "=comment="+p.args.Comment) + } + + // 添加超时时间(如果配置了) + if p.args.TimeoutAddr > 0 { + params = append(params, "=timeout="+strconv.Itoa(p.args.TimeoutAddr)) + } + + p.log.Info("adding address to MikroTik", + zap.String("cidr", cidrAddr), + zap.String("list", listName), + zap.String("comment", p.args.Comment), + zap.Int("timeout", p.args.TimeoutAddr)) + + p.log.Debug("Add to list: ", zap.Strings("params", params)) + + // 发送到 RouterOS + args := append([]string{"/ip/firewall/address-list/add"}, params...) + _, err = p.conn.Run(args...) + if err != nil { + if strings.Contains(err.Error(), "already have such entry") { + p.log.Debug("Already exists: ", zap.String("cidr", cidrAddr)) + return nil + } + p.log.Error("failed to add address to MikroTik", + zap.String("cidr", cidrAddr), + zap.String("list", listName), + zap.Error(err)) + return fmt.Errorf("failed to add address %s to list %s: from RouterOS device: %v", cidrAddr, listName, err) + } + + p.log.Info("successfully added address to MikroTik", + zap.String("cidr", cidrAddr), + zap.String("list", listName)) + + return nil +} + +func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) { + // 查询地址列表中是否已存在该地址 + query := fmt.Sprintf("?list=%s&address=%s", listName, address) + p.log.Debug("checking address existence", + zap.String("list", listName), + zap.String("address", address), + zap.String("query", query)) + + resp, err := p.conn.Run("/ip/firewall/address-list/print", query) + if err != nil { + p.log.Error("failed to check address existence", + zap.String("list", listName), + zap.String("address", address), + zap.Error(err)) + return false, err + } + + // 如果返回结果不为空,说明地址已存在 + exists := len(resp.Re) > 0 + p.log.Debug("address existence check", + zap.String("list", listName), + zap.String("address", address), + zap.Bool("exists", exists)) + + return exists, nil +} diff --git a/plugin/executable/nftset/nftset.go b/plugin/executable/nftset/nftset.go new file mode 100644 index 0000000..e1aa58a --- /dev/null +++ b/plugin/executable/nftset/nftset.go @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package nftset + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "strconv" + "strings" +) + +const PluginType = "nftset" + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.Executable = (*nftSetPlugin)(nil) + +type Args struct { + IPv4 SetArgs `yaml:"ipv4"` + IPv6 SetArgs `yaml:"ipv6"` +} + +type SetArgs struct { + TableFamily string `yaml:"table_family"` + Table string `yaml:"table_name"` + Set string `yaml:"set_name"` + Mask int `yaml:"mask"` +} + +// QuickSetup format: [{ip|ip6|inet},table_name,set_name,{ipv4_addr|ipv6_addr},mask] *2 (can repeat once) +// e.g. "inet,my_table,my_set,ipv4_addr,24 inet,my_table,my_set,ipv6_addr,48" +func QuickSetup(_ sequence.BQ, s string) (any, error) { + fs := strings.Fields(s) + if len(fs) > 2 { + return nil, fmt.Errorf("expect no more than 2 fields, got %d", len(fs)) + } + + args := new(Args) + for _, argsStr := range fs { + ss := strings.Split(argsStr, ",") + if len(ss) != 5 { + return nil, fmt.Errorf("invalid args, expect 5 fields, got %d", len(ss)) + } + + m, err := strconv.Atoi(ss[4]) + if err != nil { + return nil, fmt.Errorf("invalid mask, %w", err) + } + sa := SetArgs{ + TableFamily: ss[0], + Table: ss[1], + Set: ss[2], + Mask: m, + } + switch ss[3] { + case "ipv4_addr": + args.IPv4 = sa + case "ipv6_addr": + args.IPv6 = sa + default: + return nil, fmt.Errorf("invalid ip type, %s", ss[0]) + } + } + return newNftSetPlugin(args) +} diff --git a/plugin/executable/nftset/nftset_linux.go b/plugin/executable/nftset/nftset_linux.go new file mode 100644 index 0000000..0fc1c8e --- /dev/null +++ b/plugin/executable/nftset/nftset_linux.go @@ -0,0 +1,162 @@ +//go:build linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package nftset + +import ( + "context" + "fmt" + "net/netip" + + "github.com/IrineSistiana/mosdns/v5/pkg/nftset_utils" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/google/nftables" + "github.com/miekg/dns" +) + +type nftSetPlugin struct { + args *Args + v4Handler *nftset_utils.NftSetHandler + v6Handler *nftset_utils.NftSetHandler +} + +func newNftSetPlugin(args *Args) (*nftSetPlugin, error) { + utils.SetDefaultUnsignNum(&args.IPv4.Mask, 24) + utils.SetDefaultUnsignNum(&args.IPv6.Mask, 48) + if m := args.IPv4.Mask; m > 32 { + return nil, fmt.Errorf("invalid ipv4 mask %d", m) + } + if m := args.IPv6.Mask; m > 128 { + return nil, fmt.Errorf("invalid ipv6 mask %d", m) + } + + p := &nftSetPlugin{ + args: args, + } + + newHandler := func(sa SetArgs) (*nftset_utils.NftSetHandler, error) { + if !(len(sa.Table) > 0 && len(sa.TableFamily) > 0 && len(sa.Set) > 0) { + return nil, nil + } + f, ok := parseTableFamily(sa.TableFamily) + if !ok { + return nil, fmt.Errorf("unsupported nftables family [%s]", sa.TableFamily) + } + return nftset_utils.NewNtSetHandler(nftset_utils.HandlerOpts{ + TableFamily: f, + TableName: sa.Table, + SetName: sa.Set, + }), nil + } + var err error + p.v4Handler, err = newHandler(args.IPv4) + if err != nil { + return nil, err + } + p.v6Handler, err = newHandler(args.IPv6) + if err != nil { + _ = p.v4Handler.Close() + return nil, err + } + return p, nil +} + +func (p *nftSetPlugin) Exec(_ context.Context, qCtx *query_context.Context) error { + r := qCtx.R() + if r != nil { + if err := p.addElems(r); err != nil { + return fmt.Errorf("nftable: %w", err) + } + } + return nil +} + +func (p *nftSetPlugin) addElems(r *dns.Msg) error { + var v4Elems []netip.Prefix + var v6Elems []netip.Prefix + + for i := range r.Answer { + switch rr := r.Answer[i].(type) { + case *dns.A: + if p.v4Handler == nil { + continue + } + addr, ok := netip.AddrFromSlice(rr.A) + addr = addr.Unmap() + if !ok || !addr.Is4() { + return fmt.Errorf("internel: dns.A record [%s] is not a ipv4 address", rr.A) + } + v4Elems = append(v4Elems, netip.PrefixFrom(addr, p.args.IPv4.Mask)) + + case *dns.AAAA: + if p.v6Handler == nil { + continue + } + addr, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + return fmt.Errorf("internel: dns.AAAA record [%s] is not a ipv6 address", rr.AAAA) + } + if addr.Is4() { + addr = netip.AddrFrom16(addr.As16()) + } + v6Elems = append(v6Elems, netip.PrefixFrom(addr, p.args.IPv6.Mask)) + default: + continue + } + } + + if p.v4Handler != nil && len(v4Elems) > 0 { + if err := p.v4Handler.AddElems(v4Elems...); err != nil { + return fmt.Errorf("failed to add ipv4 elems %s: %w", v4Elems, err) + } + } + + if p.v6Handler != nil && len(v6Elems) > 0 { + if err := p.v6Handler.AddElems(v6Elems...); err != nil { + return fmt.Errorf("failed to add ipv6 elems %s: %w", v6Elems, err) + } + } + return nil +} + +func (p *nftSetPlugin) Close() error { + if p.v4Handler != nil { + _ = p.v4Handler.Close() + } + if p.v6Handler != nil { + _ = p.v6Handler.Close() + } + return nil +} + +func parseTableFamily(s string) (nftables.TableFamily, bool) { + switch s { + case "ip": + return nftables.TableFamilyIPv4, true + case "ip6": + return nftables.TableFamilyIPv6, true + case "inet": + return nftables.TableFamilyINet, true + default: + return 0, false + } +} diff --git a/plugin/executable/nftset/nftset_other.go b/plugin/executable/nftset/nftset_other.go new file mode 100644 index 0000000..c7f870d --- /dev/null +++ b/plugin/executable/nftset/nftset_other.go @@ -0,0 +1,38 @@ +//go:build !linux +// +build !linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package nftset + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" +) + +type nftSetPlugin struct{} + +func newNftSetPlugin(args *Args) (*nftSetPlugin, error) { + return &nftSetPlugin{}, nil +} + +func (p *nftSetPlugin) Exec(_ context.Context, _ *query_context.Context) error { + return nil +} diff --git a/plugin/executable/query_summary/query_summary.go b/plugin/executable/query_summary/query_summary.go new file mode 100644 index 0000000..5cfc39f --- /dev/null +++ b/plugin/executable/query_summary/query_summary.go @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package query_summary + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "go.uber.org/zap" +) + +const ( + PluginType = "query_summary" +) + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.RecursiveExecutable = (*SummaryLogger)(nil) + +type SummaryLogger struct { + l *zap.Logger + msg string +} + +// QuickSetup format: [msg_title] +func QuickSetup(bq sequence.BQ, s string) (any, error) { + return NewSummaryLogger(bq.L(), s), nil +} + +// NewSummaryLogger returns a SummaryLogger that logs query info into l. +// l cannot be nil. +// If msg is empty, "query summary" will be used. +func NewSummaryLogger(l *zap.Logger, msg string) *SummaryLogger { + if len(msg) == 0 { + msg = "query summary" + } + return &SummaryLogger{ + l: l, + msg: msg, + } +} + +func (l *SummaryLogger) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error { + err := next.ExecNext(ctx, qCtx) + l.l.Info( + l.msg, + zap.Inline(qCtx), + zap.Error(err), + ) + return err +} diff --git a/plugin/executable/rate_limiter/rate_limiter.go b/plugin/executable/rate_limiter/rate_limiter.go new file mode 100644 index 0000000..72a04f4 --- /dev/null +++ b/plugin/executable/rate_limiter/rate_limiter.go @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package rate_limiter + +import ( + "context" + "fmt" + "io" + "net/netip" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/rate_limiter" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "golang.org/x/time/rate" +) + +const PluginType = "rate_limiter" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +type Args struct { + Qps float64 `yaml:"qps"` + Burst int `yaml:"burst"` + Mask4 int `yaml:"mask4"` + Mask6 int `yaml:"mask6"` +} + +func (args *Args) init() error { + utils.SetDefaultUnsignNum(&args.Qps, 20) + utils.SetDefaultUnsignNum(&args.Burst, 40) + utils.SetDefaultUnsignNum(&args.Mask4, 32) + utils.SetDefaultUnsignNum(&args.Mask6, 48) + + if !utils.CheckNumRange(args.Mask4, 0, 32) { + return fmt.Errorf("invalid mask4") + } + if !utils.CheckNumRange(args.Mask6, 0, 128) { + return fmt.Errorf("invalid mask6") + } + return nil +} + +var _ sequence.Matcher = (*RateLimiter)(nil) +var _ io.Closer = (*RateLimiter)(nil) + +type RateLimiter struct { + args Args + l *rate_limiter.Limiter +} + +func Init(_ *coremain.BP, args any) (any, error) { + return New(*(args.(*Args))) +} + +func New(args Args) (*RateLimiter, error) { + err := args.init() + if err != nil { + return nil, fmt.Errorf("invalid args, %w", err) + } + l := rate_limiter.NewRateLimiter(rate.Limit(args.Qps), args.Burst) + return &RateLimiter{l: l, args: args}, nil +} + +func (s *RateLimiter) Match(ctx context.Context, qCtx *query_context.Context) (bool, error) { + addr := s.getMaskedClientAddr(qCtx) + if addr.IsValid() { + return s.l.Allow(addr), nil + } + return true, nil +} + +func (s *RateLimiter) getMaskedClientAddr(qCtx *query_context.Context) netip.Addr { + a := qCtx.ServerMeta.ClientAddr + if !a.IsValid() { + return netip.Addr{} + } + a = a.Unmap() + var p netip.Prefix + if a.Is4() { + p, _ = a.Prefix(s.args.Mask4) + } else { + p, _ = a.Prefix(s.args.Mask6) + } + return p.Addr() +} + +func (s *RateLimiter) Close() error { + return s.l.Close() +} diff --git a/plugin/executable/redirect/redirect.go b/plugin/executable/redirect/redirect.go new file mode 100644 index 0000000..77ed518 --- /dev/null +++ b/plugin/executable/redirect/redirect.go @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package redirect + +import ( + "bytes" + "context" + "fmt" + "os" + "strings" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +const PluginType = "redirect" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +var _ sequence.RecursiveExecutable = (*Redirect)(nil) + +type Args struct { + Rules []string `yaml:"rules"` + Files []string `yaml:"files"` +} + +type Redirect struct { + m *domain.MixMatcher[string] +} + +func Init(bp *coremain.BP, args any) (any, error) { + r, err := NewRedirect(args.(*Args)) + if err != nil { + return nil, err + } + bp.L().Info("redirect rules loaded", zap.Int("length", r.Len())) + return r, nil +} + +func NewRedirect(args *Args) (*Redirect, error) { + parseFunc := func(s string) (p, v string, err error) { + f := strings.Fields(s) + if len(f) != 2 { + return "", "", fmt.Errorf("redirect rule must have 2 fields, but got %d", len(f)) + } + return f[0], dns.Fqdn(f[1]), nil + } + m := domain.NewMixMatcher[string]() + m.SetDefaultMatcher(domain.MatcherFull) + for i, rule := range args.Rules { + if err := domain.Load[string](m, rule, parseFunc); err != nil { + return nil, fmt.Errorf("failed to load rule #%d %s, %w", i, rule, err) + } + } + for i, file := range args.Files { + b, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file #%d %s, %w", i, file, err) + } + if err := domain.LoadFromTextReader[string](m, bytes.NewReader(b), parseFunc); err != nil { + return nil, fmt.Errorf("failed to load file #%d %s, %w", i, file, err) + } + } + return &Redirect{m: m}, nil +} + +func (r *Redirect) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error { + q := qCtx.Q() + if len(q.Question) != 1 || q.Question[0].Qclass != dns.ClassINET { + return next.ExecNext(ctx, qCtx) + } + + orgQName := q.Question[0].Name + redirectTarget, ok := r.m.Match(orgQName) + if !ok { + return next.ExecNext(ctx, qCtx) + } + + q.Question[0].Name = redirectTarget + defer func() { + q.Question[0].Name = orgQName + }() + err := next.ExecNext(ctx, qCtx) + if r := qCtx.R(); r != nil { + // Restore original query name. + for i := range r.Question { + if r.Question[i].Name == redirectTarget { + r.Question[i].Name = orgQName + } + } + + // Insert a CNAME record. + newAns := make([]dns.RR, 1, len(r.Answer)+1) + newAns[0] = &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: orgQName, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 1, + }, + Target: redirectTarget, + } + newAns = append(newAns, r.Answer...) + r.Answer = newAns + } + return err +} + +func (r *Redirect) Len() int { + return r.m.Len() +} diff --git a/plugin/executable/reverse_lookup/reverse_lookup.go b/plugin/executable/reverse_lookup/reverse_lookup.go new file mode 100644 index 0000000..fceed1f --- /dev/null +++ b/plugin/executable/reverse_lookup/reverse_lookup.go @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package reverselookup + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/cache" + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/go-chi/chi/v5" + "github.com/miekg/dns" + "net" + "net/http" + "net/netip" + "time" +) + +const ( + PluginType = "reverse_lookup" +) + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +var _ sequence.RecursiveExecutable = (*ReverseLookup)(nil) + +type Args struct { + Size int `yaml:"size"` // Default is 64*1024 + HandlePTR bool `yaml:"handle_ptr"` + TTL int `yaml:"ttl"` // Default is 7200 (2h) +} + +func (a *Args) init() { + utils.SetDefaultUnsignNum(&a.Size, 64*1024) + utils.SetDefaultUnsignNum(&a.TTL, 7200) +} + +type ReverseLookup struct { + args *Args + c *cache.Cache[key, string] +} + +func Init(bp *coremain.BP, args any) (any, error) { + return NewReverseLookup(bp, args.(*Args)) +} + +func NewReverseLookup(bp *coremain.BP, args *Args) (any, error) { + args.init() + c := cache.New[key, string](cache.Opts{Size: args.Size}) + p := &ReverseLookup{ + args: args, + c: c, + } + r := chi.NewRouter() + r.Get("/", p.ServeHTTP) + bp.RegAPI(r) + return p, nil +} + +func (p *ReverseLookup) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error { + q := qCtx.Q() + if r := p.ResponsePTR(q); r != nil { + qCtx.SetResponse(r) + return nil + } + + if err := next.ExecNext(ctx, qCtx); err != nil { + return err + } + p.saveIPs(q, qCtx.R()) + return nil +} + +func (p *ReverseLookup) Close() error { + return p.c.Close() +} + +func (p *ReverseLookup) ServeHTTP(w http.ResponseWriter, req *http.Request) { + ipStr := req.URL.Query().Get("ip") + if len(ipStr) == 0 { + http.Error(w, "no 'ip' query parameter found", http.StatusBadRequest) + return + } + addr, err := netip.ParseAddr(ipStr) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + d := p.lookup(netip.AddrFrom16(addr.As16())) + if len(d) > 0 { + _, _ = fmt.Fprint(w, d) + } +} + +func (p *ReverseLookup) lookup(n netip.Addr) string { + v, _, _ := p.c.Get(key(as16(n))) + return v +} + +func (p *ReverseLookup) ResponsePTR(q *dns.Msg) *dns.Msg { + if p.args.HandlePTR && len(q.Question) > 0 && q.Question[0].Qtype == dns.TypePTR { + question := q.Question[0] + addr, _ := dnsutils.ParsePTRQName(question.Name) + // If we cannot parse this ptr name. Just ignore it and pass query to next node. + // PTR standards are a mess. + if !addr.IsValid() { + return nil + } + fqdn := p.lookup(addr) + if len(fqdn) > 0 { + r := new(dns.Msg) + r.SetReply(q) + r.Answer = append(r.Answer, &dns.PTR{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: question.Qtype, + Class: question.Qclass, + Ttl: 5, + }, + Ptr: fqdn, + }) + return r + } + } + return nil +} + +func (p *ReverseLookup) saveIPs(q, r *dns.Msg) { + if r == nil { + return + } + + now := time.Now() + for _, rr := range r.Answer { + var ip net.IP + switch rr := rr.(type) { + case *dns.A: + ip = rr.A + case *dns.AAAA: + ip = rr.AAAA + default: + continue + } + + addr, ok := netip.AddrFromSlice(ip) + if !ok { + continue + } + h := rr.Header() + if int(h.Ttl) > p.args.TTL { + h.Ttl = uint32(p.args.TTL) + } + name := h.Name + if len(q.Question) == 1 { + name = q.Question[0].Name + } + p.c.Store(key(as16(addr)), name, now.Add(time.Duration(p.args.TTL)*time.Second)) + } +} + +func as16(n netip.Addr) netip.Addr { + if n.Is6() { + return n + } + return netip.AddrFrom16(n.As16()) +} diff --git a/plugin/executable/reverse_lookup/utils.go b/plugin/executable/reverse_lookup/utils.go new file mode 100644 index 0000000..1a80939 --- /dev/null +++ b/plugin/executable/reverse_lookup/utils.go @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package reverselookup + +import ( + "hash/maphash" + "net/netip" +) + +type key netip.Addr + +var seed = maphash.MakeSeed() + +func (k key) Sum() uint64 { + b := netip.Addr(k).As16() + return maphash.Bytes(seed, b[:]) +} diff --git a/plugin/executable/sequence/built_in.go b/plugin/executable/sequence/built_in.go new file mode 100644 index 0000000..c20fefa --- /dev/null +++ b/plugin/executable/sequence/built_in.go @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/miekg/dns" + "strconv" +) + +var _ RecursiveExecutable = (*ActionAccept)(nil) + +type ActionAccept struct{} + +func (a ActionAccept) Exec(_ context.Context, _ *query_context.Context, _ ChainWalker) error { + return nil +} + +func setupAccept(_ BQ, _ string) (any, error) { + return ActionAccept{}, nil +} + +var _ RecursiveExecutable = (*ActionReject)(nil) + +type ActionReject struct { + Rcode int +} + +func (a ActionReject) Exec(_ context.Context, qCtx *query_context.Context, _ ChainWalker) error { + r := new(dns.Msg) + r.SetReply(qCtx.Q()) + r.Rcode = a.Rcode + qCtx.SetResponse(r) + return nil +} + +func setupReject(_ BQ, s string) (any, error) { + rcode := dns.RcodeRefused + if len(s) > 0 { + n, err := strconv.Atoi(s) + if err != nil || n < 0 || n > 0xFFF { + return nil, fmt.Errorf("invalid rcode [%s]", s) + } + rcode = n + } + return ActionReject{Rcode: rcode}, nil +} + +var _ RecursiveExecutable = (*ActionReturn)(nil) + +type ActionReturn struct{} + +func (a ActionReturn) Exec(ctx context.Context, qCtx *query_context.Context, next ChainWalker) error { + if next.jumpBack != nil { + return next.jumpBack.ExecNext(ctx, qCtx) + } + return nil +} + +func setupReturn(_ BQ, _ string) (any, error) { + return ActionReturn{}, nil +} + +var _ RecursiveExecutable = (*ActionJump)(nil) + +type ActionJump struct { + To []*ChainNode +} + +func (a *ActionJump) Exec(ctx context.Context, qCtx *query_context.Context, next ChainWalker) error { + w := NewChainWalker(a.To, &next) + return w.ExecNext(ctx, qCtx) +} + +func setupJump(bq BQ, s string) (any, error) { + target, _ := bq.M().GetPlugin(s).(*Sequence) + if target == nil { + return nil, fmt.Errorf("can not find jump target %s", s) + } + return &ActionJump{To: target.chain}, nil +} + +var _ RecursiveExecutable = (*ActionGoto)(nil) + +type ActionGoto struct { + To []*ChainNode +} + +func (a ActionGoto) Exec(ctx context.Context, qCtx *query_context.Context, _ ChainWalker) error { + w := NewChainWalker(a.To, nil) + return w.ExecNext(ctx, qCtx) +} + +func setupGoto(bq BQ, s string) (any, error) { + gt, _ := bq.M().GetPlugin(s).(*Sequence) + if gt == nil { + return nil, fmt.Errorf("can not find goto target %s", s) + } + return &ActionGoto{To: gt.chain}, nil +} + +var _ Matcher = (*MatchAlwaysTrue)(nil) + +type MatchAlwaysTrue struct{} + +func (m MatchAlwaysTrue) Match(_ context.Context, _ *query_context.Context) (bool, error) { + return true, nil +} + +func setupTrue(_ BQ, _ string) (Matcher, error) { + return MatchAlwaysTrue{}, nil +} + +var _ Matcher = (*MatchAlwaysFalse)(nil) + +type MatchAlwaysFalse struct{} + +func (m MatchAlwaysFalse) Match(_ context.Context, _ *query_context.Context) (bool, error) { + return false, nil +} + +func setupFalse(_ BQ, _ string) (Matcher, error) { + return MatchAlwaysFalse{}, nil +} diff --git a/plugin/executable/sequence/chain.go b/plugin/executable/sequence/chain.go new file mode 100644 index 0000000..59fe977 --- /dev/null +++ b/plugin/executable/sequence/chain.go @@ -0,0 +1,237 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import ( + "context" + "errors" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "io" +) + +type ChainNode struct { + Matches []Matcher // Can be empty, indicates this node has no match specified. + + // At least one of E or RE must not nil. + // In case both are set. E is preferred. + E Executable + RE RecursiveExecutable +} + +type ChainWalker struct { + p int + chain []*ChainNode + jumpBack *ChainWalker +} + +func NewChainWalker(chain []*ChainNode, jumpBack *ChainWalker) ChainWalker { + return ChainWalker{ + chain: chain, + jumpBack: jumpBack, + } +} + +func (w *ChainWalker) ExecNext(ctx context.Context, qCtx *query_context.Context) error { + p := w.p + // Evaluate rules' matchers in loop. +checkMatchesLoop: + for p < len(w.chain) { + n := w.chain[p] + + for _, match := range n.Matches { + ok, err := match.Match(ctx, qCtx) + if err != nil { + return err + } + if !ok { + // Skip this node if condition was not matched. + p++ + continue checkMatchesLoop + } + } + + // Exec rules' executables in loop, or in stack if it is a recursive executable. + switch { + case n.E != nil: + if err := n.E.Exec(ctx, qCtx); err != nil { + return err + } + p++ + continue + case n.RE != nil: + next := ChainWalker{ + p: p + 1, + chain: w.chain, + jumpBack: w.jumpBack, + } + return n.RE.Exec(ctx, qCtx, next) + default: + panic("n cannot be executed") + } + } + + if w.jumpBack != nil { // End of chain, time to jump back. + return w.jumpBack.ExecNext(ctx, qCtx) + } + + // EoC. + return nil +} + +func (w *ChainWalker) nop() bool { + return w.p >= len(w.chain) +} + +func (s *Sequence) buildChain(bq BQ, rs []RuleConfig) error { + c := make([]*ChainNode, 0, len(rs)) + for ri, r := range rs { + n, err := s.newNode(bq, r, ri) + if err != nil { + return fmt.Errorf("failed to init rule #%d, %w", ri, err) + } + c = append(c, n) + } + s.chain = c + return nil +} + +func (s *Sequence) newNode(bq BQ, r RuleConfig, ri int) (*ChainNode, error) { + n := new(ChainNode) + + // init matches + for mi, mc := range r.Matches { + m, err := s.newMatcher(bq, mc, ri, mi) + if err != nil { + return nil, fmt.Errorf("failed to init matcher #%d, %w", mi, err) + } + n.Matches = append(n.Matches, m) + } + + // init exec + e, re, err := s.newExec(bq, r, ri) + if err != nil { + return nil, fmt.Errorf("failed to init exec, %w", err) + } + n.E = e + n.RE = re + return n, nil +} + +func (s *Sequence) newMatcher(bq BQ, mc MatchConfig, ri, mi int) (Matcher, error) { + var m Matcher + switch { + case len(mc.Tag) > 0: + m, _ = bq.M().GetPlugin(mc.Tag).(Matcher) + if m == nil { + return nil, fmt.Errorf("can not find matcher %s", mc.Tag) + } + if qc, ok := m.(QuickConfigurableMatch); ok { + v, err := qc.QuickConfigureMatch(mc.Args) + if err != nil { + return nil, fmt.Errorf("fail to configure plugin %s, %w", mc.Tag, err) + } + m = v + } + + case len(mc.Type) > 0: + f := GetMatchQuickSetup(mc.Type) + if f == nil { + return nil, fmt.Errorf("invalid matcher type %s", mc.Type) + } + p, err := f(NewBQ(bq.M(), bq.L().Named(fmt.Sprintf("r%d.m%d", ri, mi))), mc.Args) + if err != nil { + return nil, fmt.Errorf("failed to init matcher, %w", err) + } + s.anonymousPlugins = append(s.anonymousPlugins, p) + m = p + } + if m == nil { + return nil, errors.New("missing args") + } + if mc.Reverse { + m = reverseMatcher(m) + } + return m, nil +} + +func (s *Sequence) newExec(bq BQ, rc RuleConfig, ri int) (Executable, RecursiveExecutable, error) { + var exec any + switch { + case len(rc.Tag) > 0: + p := bq.M().GetPlugin(rc.Tag) + if p == nil { + return nil, nil, fmt.Errorf("can not find executable %s", rc.Tag) + } + if qc, ok := p.(QuickConfigurableExec); ok { + v, err := qc.QuickConfigureExec(rc.Args) + if err != nil { + return nil, nil, fmt.Errorf("fail to configure plugin %s, %w", rc.Tag, err) + } + exec = v + } else { + exec = p + } + + case len(rc.Type) > 0: + f := GetExecQuickSetup(rc.Type) + if f == nil { + return nil, nil, fmt.Errorf("invalid executable type %s", rc.Type) + } + v, err := f(NewBQ(bq.M(), bq.L().Named(fmt.Sprintf("r%d", ri))), rc.Args) + if err != nil { + return nil, nil, fmt.Errorf("failed to init executable, %w", err) + } + s.anonymousPlugins = append(s.anonymousPlugins, v) + exec = v + default: + return nil, nil, errors.New("missing args") + } + + e, _ := exec.(Executable) + re, _ := exec.(RecursiveExecutable) + + if re == nil && e == nil { + return nil, nil, errors.New("invalid args, initialized object is not executable") + } + return e, re, nil +} + +func closePlugin(p any) { + if c, ok := p.(io.Closer); ok { + _ = c.Close() + } +} + +func reverseMatcher(m Matcher) Matcher { + return reverseMatch{m: m} +} + +type reverseMatch struct { + m Matcher +} + +func (r reverseMatch) Match(ctx context.Context, qCtx *query_context.Context) (bool, error) { + ok, err := r.m.Match(ctx, qCtx) + if err != nil { + return false, err + } + return !ok, nil +} diff --git a/plugin/executable/sequence/config.go b/plugin/executable/sequence/config.go new file mode 100644 index 0000000..71584f8 --- /dev/null +++ b/plugin/executable/sequence/config.go @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import "strings" + +type RuleArgs struct { + Matches []string `yaml:"matches"` + Exec string `yaml:"exec"` +} + +func parseArgs(ra RuleArgs) RuleConfig { + var rc RuleConfig + for _, s := range ra.Matches { + rc.Matches = append(rc.Matches, parseMatch(s)) + } + tag, typ, args := parseExec(ra.Exec) + rc.Tag = tag + rc.Type = typ + rc.Args = args + return rc +} + +func parseMatch(s string) MatchConfig { + var mc MatchConfig + s = strings.TrimSpace(s) + s, reverse := trimPrefixField(s, "!") + mc.Reverse = reverse + p, args, _ := strings.Cut(s, " ") + args = strings.TrimSpace(args) + mc.Args = args + if tag, ok := trimPrefixField(p, "$"); ok { + mc.Tag = tag + } else { + mc.Type = p + } + return mc +} + +func parseExec(s string) (tag string, typ string, args string) { + s = strings.TrimSpace(s) + p, args, _ := strings.Cut(s, " ") + args = strings.TrimSpace(args) + p, ok := trimPrefixField(p, "$") + if ok { + tag = p + } else { + typ = p + } + return +} + +type RuleConfig struct { + Matches []MatchConfig `yaml:"matches"` + Tag string `yaml:"tag"` + Type string `yaml:"type"` + Args string `yaml:"args"` +} + +type MatchConfig struct { + Tag string `yaml:"tag"` + Type string `yaml:"type"` + Args string `yaml:"args"` + Reverse bool `yaml:"reverse"` +} + +func trimPrefixField(s, p string) (string, bool) { + if strings.HasPrefix(s, p) { + return strings.TrimSpace(strings.TrimPrefix(s, p)), true + } + return s, false +} diff --git a/plugin/executable/sequence/config_test.go b/plugin/executable/sequence/config_test.go new file mode 100644 index 0000000..a56dfcd --- /dev/null +++ b/plugin/executable/sequence/config_test.go @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import ( + "reflect" + "testing" +) + +func Test_parseExec(t *testing.T) { + + tests := []struct { + name string + args string + wantTag string + wantTyp string + wantArgs string + }{ + {"", " $t1 a 1 ", "t1", "", "a 1"}, + {"", " typ a 1 ", "", "typ", "a 1"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotTag, gotTyp, gotArgs := parseExec(tt.args) + if gotTag != tt.wantTag { + t.Errorf("parseExec() gotTag = %v, want %v", gotTag, tt.wantTag) + } + if gotTyp != tt.wantTyp { + t.Errorf("parseExec() gotTyp = %v, want %v", gotTyp, tt.wantTyp) + } + if gotArgs != tt.wantArgs { + t.Errorf("parseExec() gotArgs = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func Test_parseMatch(t *testing.T) { + tests := []struct { + name string + args string + want MatchConfig + }{ + {"", " $m1 a 1 ", MatchConfig{ + Tag: "m1", + Type: "", + Args: "a 1", + Reverse: false, + }}, + {"", " ! typ a 1 ", MatchConfig{ + Tag: "", + Type: "typ", + Args: "a 1", + Reverse: true, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseMatch(tt.args); !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseMatch() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/plugin/executable/sequence/fallback/fallback.go b/plugin/executable/sequence/fallback/fallback.go new file mode 100644 index 0000000..7145a8f --- /dev/null +++ b/plugin/executable/sequence/fallback/fallback.go @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package fallback + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "go.uber.org/zap" +) + +const PluginType = "fallback" + +const ( + defaultParallelTimeout = time.Second * 5 + defaultFallbackThreshold = time.Millisecond * 500 +) + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +type fallback struct { + logger *zap.Logger + primary sequence.Executable + secondary sequence.Executable + fastFallbackDuration time.Duration + alwaysStandby bool +} + +type Args struct { + // Primary exec sequence. + Primary string `yaml:"primary"` + // Secondary exec sequence. + Secondary string `yaml:"secondary"` + + // Threshold in milliseconds. Default is 500. + Threshold int `yaml:"threshold"` + + // AlwaysStandby: secondary should always stand by in fallback. + AlwaysStandby bool `yaml:"always_standby"` +} + +func Init(bp *coremain.BP, args any) (any, error) { + return newFallbackPlugin(bp, args.(*Args)) +} + +func newFallbackPlugin(bp *coremain.BP, args *Args) (*fallback, error) { + if len(args.Primary) == 0 || len(args.Secondary) == 0 { + return nil, errors.New("args missing primary or secondary") + } + + pe := sequence.ToExecutable(bp.M().GetPlugin(args.Primary)) + if pe == nil { + return nil, fmt.Errorf("can not find primary executable %s", args.Primary) + } + se := sequence.ToExecutable(bp.M().GetPlugin(args.Secondary)) + if se == nil { + return nil, fmt.Errorf("can not find secondary executable %s", args.Secondary) + } + threshold := time.Duration(args.Threshold) * time.Millisecond + if threshold <= 0 { + threshold = defaultFallbackThreshold + } + + s := &fallback{ + logger: bp.L(), + primary: pe, + secondary: se, + fastFallbackDuration: threshold, + alwaysStandby: args.AlwaysStandby, + } + return s, nil +} + +var ( + ErrFailed = errors.New("no valid response from both primary and secondary") +) + +var _ sequence.Executable = (*fallback)(nil) + +func (f *fallback) Exec(ctx context.Context, qCtx *query_context.Context) error { + return f.doFallback(ctx, qCtx) +} + +func (f *fallback) doFallback(ctx context.Context, qCtx *query_context.Context) error { + respChan := make(chan *dns.Msg, 2) // resp could be nil. + primFailed := make(chan struct{}) + primDone := make(chan struct{}) + + // primary goroutine. + qCtxP := qCtx.Copy() + go func() { + qCtx := qCtxP + ctx, cancel := makeDdlCtx(ctx, defaultParallelTimeout) + defer cancel() + err := f.primary.Exec(ctx, qCtx) + if err != nil { + f.logger.Warn("primary error", qCtx.InfoField(), zap.Error(err)) + } + + r := qCtx.R() + if err != nil || r == nil { + close(primFailed) + respChan <- nil + } else { + close(primDone) + respChan <- r + } + }() + + // Secondary goroutine. + qCtxS := qCtx.Copy() + go func() { + timer := pool.GetTimer(f.fastFallbackDuration) + defer pool.ReleaseTimer(timer) + if !f.alwaysStandby { // not always standby, wait here. + select { + case <-primDone: // primary is done, no need to exec this. + return + case <-primFailed: // primary failed + case <-timer.C: // timed out + } + } + + qCtx := qCtxS + ctx, cancel := makeDdlCtx(ctx, defaultParallelTimeout) + defer cancel() + err := f.secondary.Exec(ctx, qCtx) + if err != nil { + f.logger.Warn("secondary error", qCtx.InfoField(), zap.Error(err)) + respChan <- nil + return + } + + r := qCtx.R() + // always standby is enabled. Wait until secondary resp is needed. + if f.alwaysStandby && r != nil { + select { + case <-ctx.Done(): + case <-primDone: + case <-primFailed: // only send secondary result when primary is failed. + case <-timer.C: // or timed out. + } + } + respChan <- r + }() + + for i := 0; i < 2; i++ { + select { + case <-ctx.Done(): + return context.Cause(ctx) + case r := <-respChan: + if r == nil { // One of goroutines finished but failed. + continue + } + qCtx.SetResponse(r) + return nil + } + } + + // All goroutines finished but failed. + return ErrFailed +} + +func makeDdlCtx(ctx context.Context, timeout time.Duration) (context.Context, func()) { + ddl, ok := ctx.Deadline() + if !ok { + ddl = time.Now().Add(timeout) + } + return context.WithDeadline(context.Background(), ddl) +} diff --git a/plugin/executable/sequence/iface.go b/plugin/executable/sequence/iface.go new file mode 100644 index 0000000..8761d2c --- /dev/null +++ b/plugin/executable/sequence/iface.go @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" +) + +// RecursiveExecutable represents something that is executable and requires stack. +type RecursiveExecutable interface { + Exec(ctx context.Context, qCtx *query_context.Context, next ChainWalker) error +} + +// Executable represents something that is executable. +type Executable interface { + Exec(ctx context.Context, qCtx *query_context.Context) error +} + +// Matcher represents a matcher that can match a certain patten in Context. +type Matcher interface { + Match(ctx context.Context, qCtx *query_context.Context) (bool, error) +} + +type RecursiveExecutableFunc func(ctx context.Context, qCtx *query_context.Context, next ChainWalker) error + +func (f RecursiveExecutableFunc) Exec(ctx context.Context, qCtx *query_context.Context, next ChainWalker) error { + return f(ctx, qCtx, next) +} + +type ExecutableFunc func(ctx context.Context, qCtx *query_context.Context) error + +func (f ExecutableFunc) Exec(ctx context.Context, qCtx *query_context.Context) error { + return f(ctx, qCtx) +} + +type MatchFunc func(ctx context.Context, qCtx *query_context.Context) (bool, error) + +func (f MatchFunc) Match(ctx context.Context, qCtx *query_context.Context) (bool, error) { + return f(ctx, qCtx) +} diff --git a/plugin/executable/sequence/quick_config.go b/plugin/executable/sequence/quick_config.go new file mode 100644 index 0000000..77a4e71 --- /dev/null +++ b/plugin/executable/sequence/quick_config.go @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +// QuickConfigurableExec can configure an exiting plugin with additional args. +// Expecting return is an Executable or RecursiveExecutable. +type QuickConfigurableExec interface { + QuickConfigureExec(args string) (any, error) +} + +// QuickConfigurableMatch can configure an exiting plugin with additional args. +type QuickConfigurableMatch interface { + QuickConfigureMatch(args string) (Matcher, error) +} diff --git a/plugin/executable/sequence/quick_setup.go b/plugin/executable/sequence/quick_setup.go new file mode 100644 index 0000000..c6542c7 --- /dev/null +++ b/plugin/executable/sequence/quick_setup.go @@ -0,0 +1,122 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/coremain" + "go.uber.org/zap" + "sync" +) + +type BQ interface { + // M returns a non-nil *coremain.Mosdns. + M() *coremain.Mosdns + // L returns a non-nil *zap.Logger. + L() *zap.Logger +} + +type bq struct { + m *coremain.Mosdns + l *zap.Logger +} + +func (bq *bq) M() *coremain.Mosdns { + return bq.m +} + +func (bq *bq) L() *zap.Logger { + return bq.l +} + +func NewBQ(m *coremain.Mosdns, l *zap.Logger) BQ { + return &bq{m: m, l: l} +} + +// ExecQuickSetupFunc configures an Executable or +// RecursiveExecutable with a simple string args. +type ExecQuickSetupFunc func(bq BQ, args string) (any, error) + +// MatchQuickSetupFunc configures a Matcher with a simple string args. +type MatchQuickSetupFunc func(bq BQ, args string) (Matcher, error) + +var execQuickSetupReg struct { + sync.RWMutex + m map[string]ExecQuickSetupFunc +} + +var matchQuickSetupReg struct { + sync.RWMutex + m map[string]MatchQuickSetupFunc +} + +func RegExecQuickSetup(typ string, f ExecQuickSetupFunc) error { + execQuickSetupReg.Lock() + defer execQuickSetupReg.Unlock() + + _, ok := execQuickSetupReg.m[typ] + if ok { + return fmt.Errorf("type %s has already been registered", typ) + } + if execQuickSetupReg.m == nil { + execQuickSetupReg.m = make(map[string]ExecQuickSetupFunc) + } + execQuickSetupReg.m[typ] = f + return nil +} + +func MustRegExecQuickSetup(typ string, f ExecQuickSetupFunc) { + if err := RegExecQuickSetup(typ, f); err != nil { + panic(err.Error()) + } +} + +func GetExecQuickSetup(typ string) ExecQuickSetupFunc { + execQuickSetupReg.RLock() + defer execQuickSetupReg.RUnlock() + return execQuickSetupReg.m[typ] +} + +func RegMatchQuickSetup(typ string, f MatchQuickSetupFunc) error { + matchQuickSetupReg.Lock() + defer matchQuickSetupReg.Unlock() + + _, ok := matchQuickSetupReg.m[typ] + if ok { + return fmt.Errorf("type %s has already been registered", typ) + } + if matchQuickSetupReg.m == nil { + matchQuickSetupReg.m = make(map[string]MatchQuickSetupFunc) + } + matchQuickSetupReg.m[typ] = f + return nil +} + +func MustRegMatchQuickSetup(typ string, f MatchQuickSetupFunc) { + if err := RegMatchQuickSetup(typ, f); err != nil { + panic(err.Error()) + } +} + +func GetMatchQuickSetup(typ string) MatchQuickSetupFunc { + matchQuickSetupReg.RLock() + defer matchQuickSetupReg.RUnlock() + return matchQuickSetupReg.m[typ] +} diff --git a/plugin/executable/sequence/sequence.go b/plugin/executable/sequence/sequence.go new file mode 100644 index 0000000..8b8555e --- /dev/null +++ b/plugin/executable/sequence/sequence.go @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" +) + +const PluginType = "sequence" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) + + MustRegExecQuickSetup("accept", setupAccept) + MustRegExecQuickSetup("reject", setupReject) + MustRegExecQuickSetup("return", setupReturn) + MustRegExecQuickSetup("goto", setupGoto) + MustRegExecQuickSetup("jump", setupJump) + MustRegMatchQuickSetup("_true", setupTrue) // add _ prefix to avoid being mis-parsed as bool + MustRegMatchQuickSetup("_false", setupFalse) +} + +type Sequence struct { + chain []*ChainNode + anonymousPlugins []any +} + +func (s *Sequence) Close() error { + for _, plugin := range s.anonymousPlugins { + closePlugin(plugin) + } + return nil +} + +type Args = []RuleArgs + +func Init(bp *coremain.BP, args any) (any, error) { + return NewSequence(bp, *args.(*Args)) +} + +func NewSequence(bq BQ, ra []RuleArgs) (*Sequence, error) { + s := &Sequence{} + + var rc []RuleConfig + for _, ra := range ra { + rc = append(rc, parseArgs(ra)) + } + if err := s.buildChain(bq, rc); err != nil { + _ = s.Close() + return nil, err + } + return s, nil +} + +func (s *Sequence) Exec(ctx context.Context, qCtx *query_context.Context) error { + walker := NewChainWalker(s.chain, nil) + return walker.ExecNext(ctx, qCtx) +} diff --git a/plugin/executable/sequence/sequence_test.go b/plugin/executable/sequence/sequence_test.go new file mode 100644 index 0000000..8c74960 --- /dev/null +++ b/plugin/executable/sequence/sequence_test.go @@ -0,0 +1,199 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import ( + "context" + "errors" + "testing" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/miekg/dns" +) + +type dummy struct { + matched bool + wantErr error + wantR *dns.Msg + dropR bool + wantReturn bool +} + +func (d *dummy) Match(ctx context.Context, qCtx *query_context.Context) (bool, error) { + if d.wantErr != nil { + return false, d.wantErr + } + return d.matched, nil +} + +func (d *dummy) Exec(ctx context.Context, qCtx *query_context.Context, next ChainWalker) error { + if d.wantErr != nil { + return d.wantErr + } + if d.wantR != nil { + qCtx.SetResponse(d.wantR) + } + if d.dropR { + qCtx.SetResponse(nil) + } + if d.wantReturn { + return nil + } + return next.ExecNext(ctx, qCtx) +} + +func preparePlugins(p map[string]any) { + p["target"] = &dummy{wantR: new(dns.Msg)} + p["err"] = &dummy{wantErr: errors.New("err")} + p["drop"] = &dummy{dropR: true} + p["nop"] = &dummy{} + p["true"] = &dummy{matched: true} + p["false"] = &dummy{matched: false} +} + +func Test_sequence_Exec(t *testing.T) { + tests := []struct { + name string + ra []RuleArgs + ra2 []RuleArgs + wantErr bool + wantTarget bool + }{ + { + name: "exec", + ra: []RuleArgs{ + {Exec: "$nop"}, + {Exec: "$target"}, + {Exec: "return"}, + {Exec: "$err"}, // skipped + }, + wantErr: false, + wantTarget: true, + }, + { + name: "match", + ra: []RuleArgs{ + { + Matches: []string{"$true", "$false", "$err"}, // skip following matches when false + Exec: "$err", // skip exec when false + }, + { + Matches: []string{"$false", "$err"}, + Exec: "$err", + }, + { + Matches: []string{"$true", "$true"}, + Exec: "$target", + }, + }, + wantErr: false, + wantTarget: true, + }, + { + name: "goto return", + ra: []RuleArgs{ + {Exec: "goto seq2"}, + {Exec: "$err"}, // goto skips fallowing nodes. + }, + ra2: []RuleArgs{ + {Exec: "$target"}, + {Exec: "return"}, + {Exec: "$err"}, // return skips fallowing nodes. + }, + wantErr: false, + wantTarget: true, + }, + { + name: "jump return", + ra: []RuleArgs{ + {Exec: "jump seq2"}, + {Exec: "$target"}, + }, + ra2: []RuleArgs{ + {Exec: "$nop"}, + {Exec: "return"}, + {Exec: "$err"}, + }, + wantErr: false, + wantTarget: true, + }, + { + name: "jump accept", + ra: []RuleArgs{ + {Exec: "jump seq2"}, + {Exec: "$err"}, // accepted in seq2, skipped + }, + ra2: []RuleArgs{ + {Exec: "$target"}, + {Exec: "accept"}, + {Exec: "$err"}, + }, + wantErr: false, + wantTarget: true, + }, + { + name: "jump end", + ra: []RuleArgs{ + {Exec: "jump seq2"}, + {Exec: "$target"}, + }, + ra2: []RuleArgs{ + {Exec: "$nop"}, + }, + wantErr: false, + wantTarget: true, + }, + { + name: "reject", + ra: []RuleArgs{ + {Exec: "reject"}, + {Exec: "$err"}, // skipped + }, + wantErr: false, + wantTarget: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ps := make(map[string]any) + m := coremain.NewTestMosdnsWithPlugins(ps) + preparePlugins(ps) + if len(tt.ra2) > 0 { + s, err := NewSequence(coremain.NewBP("test", m), tt.ra2) + if err != nil { + t.Fatal(err) + } + ps["seq2"] = s + } + s, err := NewSequence(coremain.NewBP("test", m), tt.ra) + if err != nil { + t.Fatal(err) + } + qCtx := query_context.NewContext(new(dns.Msg)) + if err := s.Exec(context.Background(), qCtx); (err != nil) != tt.wantErr { + t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr) + } + if getTarget := qCtx.R() != nil; getTarget != tt.wantTarget { + t.Errorf("Exec() getTarget = %v, wantTarget %v", getTarget, tt.wantTarget) + } + }) + } +} diff --git a/plugin/executable/sequence/utils.go b/plugin/executable/sequence/utils.go new file mode 100644 index 0000000..343aa57 --- /dev/null +++ b/plugin/executable/sequence/utils.go @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sequence + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" +) + +// reWrapper converts RecursiveExecutable to Executable +type reWrapper struct { + re RecursiveExecutable +} + +func (r *reWrapper) Exec(ctx context.Context, qCtx *query_context.Context) error { + return r.re.Exec(ctx, qCtx, ChainWalker{}) +} + +func ToExecutable(v any) Executable { + switch v := v.(type) { + case Executable: + return v + case RecursiveExecutable: + return &reWrapper{re: v} + default: + return nil + } +} diff --git a/plugin/executable/sleep/sleep.go b/plugin/executable/sleep/sleep.go new file mode 100644 index 0000000..81dcbe4 --- /dev/null +++ b/plugin/executable/sleep/sleep.go @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package sleep + +import ( + "context" + "strconv" + "time" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" +) + +const PluginType = "sleep" + +func init() { + // Register this plugin type with its initialization funcs. So that, this plugin + // can be configured by user from configuration file. + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) + + // You can also register a plugin object directly. (If plugin do not need to configure) + // Then you can directly use "_sleep_500ms" in configuration file. + coremain.RegNewPersetPluginFunc("_sleep_500ms", func(bp *coremain.BP) (any, error) { + return &sleep{d: time.Millisecond * 500}, nil + }) + + // You can register a quick setup func for sequence. So that users can + // init your plugin in the sequence directly in one string. + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +// Args is the arguments of plugin. It will be decoded from yaml. +// So it is recommended to use `yaml` as struct field's tag. +type Args struct { + Duration uint `yaml:"duration"` // (milliseconds) duration for sleep. +} + +var _ sequence.Executable = (*sleep)(nil) + +// sleep implements handler.ExecutablePlugin. +type sleep struct { + d time.Duration +} + +// Exec implements handler.Executable. +func (s *sleep) Exec(ctx context.Context, qCtx *query_context.Context) error { + if s.d > 0 { + timer := pool.GetTimer(s.d) + defer pool.ReleaseTimer(timer) + select { + case <-timer.C: + case <-ctx.Done(): + return context.Cause(ctx) + } + } + return nil +} + +func Init(_ *coremain.BP, args any) (any, error) { + d := args.(*Args).Duration + return &sleep{ + d: time.Duration(d) * time.Millisecond, + }, nil +} + +func QuickSetup(_ sequence.BQ, s string) (any, error) { + n, err := strconv.Atoi(s) + if err != nil { + return nil, err + } + return &sleep{d: time.Duration(n) * time.Millisecond}, nil +} diff --git a/plugin/executable/ttl/ttl.go b/plugin/executable/ttl/ttl.go new file mode 100644 index 0000000..9d9c2af --- /dev/null +++ b/plugin/executable/ttl/ttl.go @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ttl + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "strconv" + "strings" +) + +const ( + PluginType = "ttl" +) + +func init() { + sequence.MustRegExecQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.Executable = (*TTL)(nil) + +type TTL struct { + fix uint32 + min uint32 + max uint32 +} + +func NewTTL(fix, min, max uint32) *TTL { + return &TTL{ + fix: fix, + min: min, + max: max, + } +} + +// QuickSetup format: {[min-max]|[fix]} +// e.g. range "300-600", fixed ttl "5". +func QuickSetup(_ sequence.BQ, s string) (any, error) { + var f, l, u uint32 + ls, us, ok := strings.Cut(s, "-") + if ok { // range + n, err := strconv.ParseUint(ls, 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid lower bound, %w", err) + } + l = uint32(n) + n, err = strconv.ParseUint(us, 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid upper bound, %w", err) + } + u = uint32(n) + } else { // fixed + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid ttl, %w", err) + } + f = uint32(n) + } + + return NewTTL(f, l, u), nil +} + +func (t *TTL) Exec(_ context.Context, qCtx *query_context.Context) error { + if r := qCtx.R(); r != nil { + if t.fix > 0 { + dnsutils.SetTTL(r, t.fix) + } else { + if t.min > 0 { + dnsutils.ApplyMinimalTTL(r, t.min) + } + if t.max > 0 { + dnsutils.ApplyMaximumTTL(r, t.max) + } + } + } + return nil +} diff --git a/plugin/mark/mark.go b/plugin/mark/mark.go new file mode 100644 index 0000000..ae0f004 --- /dev/null +++ b/plugin/mark/mark.go @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package mark + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "strconv" + "strings" +) + +const PluginType = "mark" + +func init() { + sequence.MustRegExecQuickSetup(PluginType, func(_ sequence.BQ, args string) (any, error) { + return newMarker(args) + }) + sequence.MustRegMatchQuickSetup(PluginType, func(_ sequence.BQ, args string) (sequence.Matcher, error) { + return newMarker(args) + }) +} + +var _ sequence.Executable = (*mark)(nil) +var _ sequence.Matcher = (*mark)(nil) + +type mark struct { + m []uint32 +} + +func (m *mark) Match(_ context.Context, qCtx *query_context.Context) (bool, error) { + for _, u := range m.m { + if qCtx.HasMark(u) { + return true, nil + } + } + return false, nil +} + +func (m *mark) Exec(_ context.Context, qCtx *query_context.Context) error { + for _, u := range m.m { + qCtx.SetMark(u) + } + return nil +} + +// newMarker format: [uint32_mark]... +// "uint32_mark" is an uint32 defined as Go syntax for integer literals. +// e.g. "111", "0b111", "0o111", "0xfff". +func newMarker(s string) (*mark, error) { + var m []uint32 + for _, ms := range strings.Fields(s) { + n, err := strconv.ParseUint(ms, 10, 32) + if err != nil { + return nil, err + } + m = append(m, uint32(n)) + } + return &mark{m: m}, nil +} diff --git a/plugin/matcher/base_domain/domain_matcher.go b/plugin/matcher/base_domain/domain_matcher.go new file mode 100644 index 0000000..069ddd8 --- /dev/null +++ b/plugin/matcher/base_domain/domain_matcher.go @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package base_domain + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/data_provider" + "github.com/IrineSistiana/mosdns/v5/plugin/data_provider/domain_set" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "strings" +) + +var _ sequence.Matcher = (*Matcher)(nil) + +type Args struct { + Exps []string `yaml:"exps"` + DomainSets []string `yaml:"domain_sets"` + Files []string `yaml:"files"` +} + +type MatchFunc func(qCtx *query_context.Context, m domain.Matcher[struct{}]) (bool, error) + +type Matcher struct { + match MatchFunc + mg []domain.Matcher[struct{}] +} + +func (m *Matcher) Match(_ context.Context, qCtx *query_context.Context) (bool, error) { + return m.match(qCtx, domain_set.MatcherGroup(m.mg)) +} + +func NewMatcher(bq sequence.BQ, args *Args, f MatchFunc) (m *Matcher, err error) { + m = &Matcher{ + match: f, + } + + // Acquire matchers from other plugins. + for _, tag := range args.DomainSets { + p := bq.M().GetPlugin(tag) + dsProvider, _ := p.(data_provider.DomainMatcherProvider) + if dsProvider == nil { + return nil, fmt.Errorf("cannot find domain set %s", tag) + } + dm := dsProvider.GetDomainMatcher() + m.mg = append(m.mg, dm) + } + + // Anonymous set from plugin's args and files. + if len(args.Exps)+len(args.Files) > 0 { + anonymousSet := domain.NewDomainMixMatcher() + if err := domain_set.LoadExpsAndFiles(args.Exps, args.Files, anonymousSet); err != nil { + return nil, err + } + if anonymousSet.Len() > 0 { + m.mg = append(m.mg, anonymousSet) + } + } + + return m, nil +} + +// ParseQuickSetupArgs parses expressions and domain set to args. +// Format: "([exp] | [$domain_set_tag] | [&domain_list_file])..." +func ParseQuickSetupArgs(s string) *Args { + cutPrefix := func(s string, p string) (string, bool) { + if strings.HasPrefix(s, p) { + return strings.TrimPrefix(s, p), true + } + return s, false + } + + args := new(Args) + for _, exp := range strings.Fields(s) { + if tag, ok := cutPrefix(exp, "$"); ok { + args.DomainSets = append(args.DomainSets, tag) + } else if path, ok := cutPrefix(exp, "&"); ok { + args.Files = append(args.Files, path) + } else { + args.Exps = append(args.Exps, exp) + } + } + return args +} diff --git a/plugin/matcher/base_int/int_matcher.go b/plugin/matcher/base_int/int_matcher.go new file mode 100644 index 0000000..2f148b4 --- /dev/null +++ b/plugin/matcher/base_int/int_matcher.go @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package base_int + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "strconv" + "strings" +) + +var _ sequence.Matcher = (*Matcher)(nil) + +type MatchFunc func(qCtx *query_context.Context, m IntMatcher) (bool, error) + +type Matcher struct { + match MatchFunc + m IntMatcher +} + +func (m *Matcher) Match(_ context.Context, qCtx *query_context.Context) (bool, error) { + return m.match(qCtx, m.m) +} + +func NewMatcher(args []int, f MatchFunc) (*Matcher, error) { + m := &Matcher{ + match: f, + m: make(map[int]struct{}), + } + for _, i := range args { + m.m[i] = struct{}{} + } + return m, nil +} + +// ParseQuickSetupArgs parses numbers to Args. +// Format: "[int]..." +func ParseQuickSetupArgs(s string) ([]int, error) { + args := make([]int, 0) + for i, s := range strings.Fields(s) { + n, err := strconv.Atoi(s) + if err != nil { + return nil, fmt.Errorf("arg #%d is not an int, %w", i, err) + } + args = append(args, n) + } + return args, nil +} + +// QuickSetup returns a sequence.ExecQuickSetupFunc. +func QuickSetup(f MatchFunc) func(_ sequence.BQ, s string) (sequence.Matcher, error) { + return func(_ sequence.BQ, s string) (sequence.Matcher, error) { + args, err := ParseQuickSetupArgs(s) + if err != nil { + return nil, fmt.Errorf("invalid args, %w", err) + } + return NewMatcher(args, f) + } +} + +type IntMatcher map[int]struct{} + +func (m IntMatcher) Has(i int) bool { + _, ok := m[i] + return ok +} diff --git a/plugin/matcher/base_ip/ip_matcher.go b/plugin/matcher/base_ip/ip_matcher.go new file mode 100644 index 0000000..b192764 --- /dev/null +++ b/plugin/matcher/base_ip/ip_matcher.go @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package base_ip + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/data_provider" + "github.com/IrineSistiana/mosdns/v5/plugin/data_provider/ip_set" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "strings" +) + +var _ sequence.Matcher = (*Matcher)(nil) + +type Args struct { + IPs []string `yaml:"ips"` + IPSets []string `yaml:"ip_sets"` + Files []string `yaml:"files"` +} + +type MatchFunc func(qCtx *query_context.Context, m netlist.Matcher) (bool, error) + +type Matcher struct { + match MatchFunc + + mg []netlist.Matcher +} + +func (m *Matcher) Match(_ context.Context, qCtx *query_context.Context) (matched bool, err error) { + return m.match(qCtx, ip_set.MatcherGroup(m.mg)) +} + +func NewMatcher(bq sequence.BQ, args *Args, f MatchFunc) (m *Matcher, err error) { + m = &Matcher{ + match: f, + } + + // Acquire lists from other plugins or files. + for _, tag := range args.IPSets { + p := bq.M().GetPlugin(tag) + provider, _ := p.(data_provider.IPMatcherProvider) + if provider == nil { + return nil, fmt.Errorf("cannot find ipset %s", tag) + } + l := provider.GetIPMatcher() + m.mg = append(m.mg, l) + } + + // Anonymous set from plugin's args and files. + if len(args.IPs)+len(args.Files) > 0 { + anonymousList := netlist.NewList() + if err := ip_set.LoadFromIPsAndFiles(args.IPs, args.Files, anonymousList); err != nil { + return nil, err + } + anonymousList.Sort() + if anonymousList.Len() > 0 { + m.mg = append(m.mg, anonymousList) + } + } + + return m, nil +} + +// ParseQuickSetupArgs parses expressions and "ip_set"s to args. +// Format: "([ip] | [$ip_set_tag] | [&ip_list_file])..." +func ParseQuickSetupArgs(s string) *Args { + cutPrefix := func(s string, p string) (string, bool) { + if strings.HasPrefix(s, p) { + return strings.TrimPrefix(s, p), true + } + return s, false + } + + args := new(Args) + for _, exp := range strings.Fields(s) { + if tag, ok := cutPrefix(exp, "$"); ok { + args.IPSets = append(args.IPSets, tag) + } else if path, ok := cutPrefix(exp, "&"); ok { + args.Files = append(args.Files, path) + } else { + args.IPs = append(args.IPs, exp) + } + } + return args +} diff --git a/plugin/matcher/client_ip/client_ip_matcher.go b/plugin/matcher/client_ip/client_ip_matcher.go new file mode 100644 index 0000000..0e41355 --- /dev/null +++ b/plugin/matcher/client_ip/client_ip_matcher.go @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package client_ip + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/IrineSistiana/mosdns/v5/plugin/matcher/base_ip" +) + +const PluginType = "client_ip" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +type Args = base_ip.Args + +func QuickSetup(bq sequence.BQ, s string) (sequence.Matcher, error) { + return base_ip.NewMatcher(bq, base_ip.ParseQuickSetupArgs(s), matchClientAddr) +} + +func matchClientAddr(qCtx *query_context.Context, m netlist.Matcher) (bool, error) { + addr := qCtx.ServerMeta.ClientAddr + if !addr.IsValid() { + return false, nil + } + return m.Match(addr), nil +} diff --git a/plugin/matcher/cname/cname_matcher.go b/plugin/matcher/cname/cname_matcher.go new file mode 100644 index 0000000..5527f65 --- /dev/null +++ b/plugin/matcher/cname/cname_matcher.go @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package qname_matcher + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/IrineSistiana/mosdns/v5/plugin/matcher/base_domain" + "github.com/miekg/dns" +) + +const PluginType = "cname" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +type Args = base_domain.Args + +func QuickSetup(bq sequence.BQ, s string) (sequence.Matcher, error) { + return base_domain.NewMatcher(bq, base_domain.ParseQuickSetupArgs(s), matchCName) +} + +func matchCName(qCtx *query_context.Context, m domain.Matcher[struct{}]) (bool, error) { + r := qCtx.R() + if r == nil { + return false, nil + } + for _, rr := range r.Answer { + if cname, ok := rr.(*dns.CNAME); ok { + if _, ok := m.Match(cname.Target); ok { + return true, nil + } + } + } + return false, nil +} diff --git a/plugin/matcher/env/env.go b/plugin/matcher/env/env.go new file mode 100644 index 0000000..27e1115 --- /dev/null +++ b/plugin/matcher/env/env.go @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package env + +import ( + "fmt" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "os" + "strings" +) + +const PluginType = "env" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +func QuickSetup(_ sequence.BQ, s string) (sequence.Matcher, error) { + ss := strings.Fields(s) + var k, v string + switch len(ss) { + case 1: + k = ss[0] + case 2: + k = ss[0] + k = ss[1] + default: + return nil, fmt.Errorf("invalid arg number %d", len(ss)) + } + return CheckEnv(k, v), nil +} + +// CheckEnv checks if k is in env. If v is given, it checks whether env["k"] == v. +func CheckEnv(k, v string) sequence.Matcher { + var res bool + e, ok := os.LookupEnv(k) + if ok { + if len(v) == 0 { + res = true + } else { + res = e == v + } + } else { + res = false + } + + if res { + return sequence.MatchAlwaysTrue{} + } + return sequence.MatchAlwaysFalse{} +} diff --git a/plugin/matcher/has_resp/has_resp.go b/plugin/matcher/has_resp/has_resp.go new file mode 100644 index 0000000..d79cb83 --- /dev/null +++ b/plugin/matcher/has_resp/has_resp.go @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package has_resp + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" +) + +const PluginType = "has_resp" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +type haveResp struct{} + +func (h haveResp) Match(_ context.Context, qCtx *query_context.Context) (bool, error) { + return qCtx.R() != nil, nil +} + +func QuickSetup(_ sequence.BQ, _ string) (sequence.Matcher, error) { + return haveResp{}, nil +} diff --git a/plugin/matcher/has_wanted_ans/has_wanted_ans.go b/plugin/matcher/has_wanted_ans/has_wanted_ans.go new file mode 100644 index 0000000..7da16c3 --- /dev/null +++ b/plugin/matcher/has_wanted_ans/has_wanted_ans.go @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package has_wanted_ans + +import ( + "context" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" +) + +const PluginType = "has_wanted_ans" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +type hasQuestionAns struct{} + +func (h hasQuestionAns) Match(_ context.Context, qCtx *query_context.Context) (bool, error) { + q := qCtx.Q() + if len(q.Question) == 0 { + return false, nil + } + r := qCtx.R() + if r == nil || len(r.Answer) == 0 { + return false, nil + + } + + question := q.Question[0] + for _, rr := range r.Answer { + h := rr.Header() + if h.Rrtype == question.Qtype && h.Class == question.Qclass { + return true, nil + } + } + return false, nil +} + +func QuickSetup(_ sequence.BQ, _ string) (sequence.Matcher, error) { + return hasQuestionAns{}, nil +} diff --git a/plugin/matcher/ptr_ip/ptr_ip.go b/plugin/matcher/ptr_ip/ptr_ip.go new file mode 100644 index 0000000..73006fa --- /dev/null +++ b/plugin/matcher/ptr_ip/ptr_ip.go @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ptr_ip + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/IrineSistiana/mosdns/v5/plugin/matcher/base_ip" + "github.com/miekg/dns" +) + +const PluginType = "ptr_ip" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +type Args = base_ip.Args + +func QuickSetup(bq sequence.BQ, s string) (sequence.Matcher, error) { + return base_ip.NewMatcher(bq, base_ip.ParseQuickSetupArgs(s), MatchQueryPtrIP) +} + +func MatchQueryPtrIP(qCtx *query_context.Context, m netlist.Matcher) (bool, error) { + q := qCtx.Q() + for _, question := range q.Question { + if question.Qtype == dns.TypePTR { + addr, _ := dnsutils.ParsePTRQName(question.Name) // Ignore parse error. + if addr.IsValid() && m.Match(addr) { + return true, nil + } + } + } + return false, nil +} diff --git a/plugin/matcher/qclass/qclass.go b/plugin/matcher/qclass/qclass.go new file mode 100644 index 0000000..e75773c --- /dev/null +++ b/plugin/matcher/qclass/qclass.go @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package qclass + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/IrineSistiana/mosdns/v5/plugin/matcher/base_int" +) + +const PluginType = "qclass" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, base_int.QuickSetup(matchQClass)) +} + +func matchQClass(qCtx *query_context.Context, m base_int.IntMatcher) (bool, error) { + for _, question := range qCtx.Q().Question { + if m.Has(int(question.Qclass)) { + return true, nil + } + } + return false, nil +} diff --git a/plugin/matcher/qname/qname.go b/plugin/matcher/qname/qname.go new file mode 100644 index 0000000..503e7c8 --- /dev/null +++ b/plugin/matcher/qname/qname.go @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package qname + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + base "github.com/IrineSistiana/mosdns/v5/plugin/matcher/base_domain" +) + +const PluginType = "qname" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +type Args = base.Args + +func QuickSetup(bq sequence.BQ, s string) (sequence.Matcher, error) { + return base.NewMatcher(bq, base.ParseQuickSetupArgs(s), matchQName) +} + +func matchQName(qCtx *query_context.Context, m domain.Matcher[struct{}]) (bool, error) { + for _, question := range qCtx.Q().Question { + if _, ok := m.Match(question.Name); ok { + return true, nil + } + } + return false, nil +} diff --git a/plugin/matcher/qtype/qtype.go b/plugin/matcher/qtype/qtype.go new file mode 100644 index 0000000..e382913 --- /dev/null +++ b/plugin/matcher/qtype/qtype.go @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package qtype + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/IrineSistiana/mosdns/v5/plugin/matcher/base_int" +) + +const PluginType = "qtype" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, base_int.QuickSetup(matchQType)) +} + +func matchQType(qCtx *query_context.Context, m base_int.IntMatcher) (bool, error) { + for _, question := range qCtx.Q().Question { + if m.Has(int(question.Qtype)) { + return true, nil + } + } + return false, nil +} diff --git a/plugin/matcher/random/random.go b/plugin/matcher/random/random.go new file mode 100644 index 0000000..c94682e --- /dev/null +++ b/plugin/matcher/random/random.go @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package env + +import ( + "context" + "errors" + "fmt" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "math/rand" + "strconv" + "sync" + "time" +) + +const PluginType = "random" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +func QuickSetup(_ sequence.BQ, s string) (sequence.Matcher, error) { + if len(s) == 0 { + return nil, errors.New("a float64 probability is required") + } + p, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, fmt.Errorf("invalid probability, %w", err) + } + return Random(p), nil +} + +type random struct { + prob float64 + + m sync.Mutex + r *rand.Rand +} + +func (r *random) Match(_ context.Context, _ *query_context.Context) (bool, error) { + return r.RandBool(), nil +} + +func (r *random) RandBool() bool { + r.m.Lock() + defer r.m.Unlock() + return r.r.Float64() < r.prob +} + +// Random returns a sequence.Matcher that returns true with a probability of prob. +func Random(prob float64) sequence.Matcher { + return &random{prob: prob, r: rand.New(rand.NewSource(time.Now().UnixNano()))} +} diff --git a/plugin/matcher/rcode/rcode.go b/plugin/matcher/rcode/rcode.go new file mode 100644 index 0000000..34f0a37 --- /dev/null +++ b/plugin/matcher/rcode/rcode.go @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package rcode + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/IrineSistiana/mosdns/v5/plugin/matcher/base_int" +) + +const PluginType = "rcode" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, base_int.QuickSetup(matchRcode)) +} + +func matchRcode(qCtx *query_context.Context, m base_int.IntMatcher) (bool, error) { + r := qCtx.R() + if r == nil { + return false, nil + } + return m.Has(r.Rcode), nil +} diff --git a/plugin/matcher/resp_ip/resp_ip.go b/plugin/matcher/resp_ip/resp_ip.go new file mode 100644 index 0000000..f74787b --- /dev/null +++ b/plugin/matcher/resp_ip/resp_ip.go @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package resp_ip + +import ( + "github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/IrineSistiana/mosdns/v5/plugin/matcher/base_ip" + "github.com/miekg/dns" + "net" + "net/netip" +) + +const PluginType = "resp_ip" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +type Args = base_ip.Args + +func QuickSetup(bq sequence.BQ, s string) (sequence.Matcher, error) { + return base_ip.NewMatcher(bq, base_ip.ParseQuickSetupArgs(s), matchRespAddr) +} + +func matchRespAddr(qCtx *query_context.Context, m netlist.Matcher) (bool, error) { + r := qCtx.R() + if r == nil { + return false, nil + } + for _, rr := range r.Answer { + var ip net.IP + switch rr := rr.(type) { + case *dns.A: + ip = rr.A + case *dns.AAAA: + ip = rr.AAAA + default: + continue + } + addr, ok := netip.AddrFromSlice(ip) + if ok && m.Match(addr) { + return true, nil + } + } + return false, nil +} diff --git a/plugin/matcher/string_exp/string_exp.go b/plugin/matcher/string_exp/string_exp.go new file mode 100644 index 0000000..2534328 --- /dev/null +++ b/plugin/matcher/string_exp/string_exp.go @@ -0,0 +1,184 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package string_exp + +import ( + "context" + "errors" + "fmt" + "os" + "regexp" + "strings" + + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" +) + +const PluginType = "string_exp" + +func init() { + sequence.MustRegMatchQuickSetup(PluginType, QuickSetup) +} + +var _ sequence.Matcher = (*Matcher)(nil) + +type Matcher struct { + getStr GetStrFunc + m StringMatcher +} + +type StringMatcher interface { + MatchStr(s string) bool +} + +type GetStrFunc func(qCtx *query_context.Context) string + +func (m *Matcher) Match(_ context.Context, qCtx *query_context.Context) (bool, error) { + return m.match(qCtx), nil +} + +func (m *Matcher) match(qCtx *query_context.Context) bool { + return m.m.MatchStr(m.getStr(qCtx)) +} + +func NewMatcher(f GetStrFunc, sm StringMatcher) *Matcher { + m := &Matcher{ + getStr: f, + m: sm, + } + return m +} + +// Format: "scr_string_name op [string]..." +// scr_string_name = {url_path|server_name|$env_key} +// op = {zl|eq|prefix|suffix|contains|regexp} +func QuickSetupFromStr(s string) (sequence.Matcher, error) { + sf := strings.Fields(s) + if len(sf) < 2 { + return nil, errors.New("not enough args") + } + srcStrName := sf[0] + op := sf[1] + args := sf[2:] + + var sm StringMatcher + switch op { + case "zl": + sm = opZl{} + case "eq": + m := make(map[string]struct{}) + for _, s := range args { + m[s] = struct{}{} + } + sm = &opEq{m: m} + case "regexp": + var exps []*regexp.Regexp + for _, s := range args { + exp, err := regexp.Compile(s) + if err != nil { + return nil, fmt.Errorf("invalid reg expression, %w", err) + } + exps = append(exps, exp) + } + sm = &opRegExp{exp: exps} + case "prefix": + sm = &opF{s: args, f: strings.HasPrefix} + case "suffix": + sm = &opF{s: args, f: strings.HasSuffix} + case "contains": + sm = &opF{s: args, f: strings.Contains} + default: + return nil, fmt.Errorf("invalid operator %s", op) + } + + var gf GetStrFunc + if strings.HasPrefix(srcStrName, "$") { + // Env + envKey := strings.TrimPrefix(srcStrName, "$") + gf = func(_ *query_context.Context) string { + return os.Getenv(envKey) + } + } else { + switch srcStrName { + case "url_path": + gf = getUrlPath + case "server_name": + gf = getServerName + default: + return nil, fmt.Errorf("invalid src string name %s", srcStrName) + } + } + return NewMatcher(gf, sm), nil +} + +// QuickSetup returns a sequence.ExecQuickSetupFunc. +func QuickSetup(_ sequence.BQ, s string) (sequence.Matcher, error) { + return QuickSetupFromStr(s) +} + +type opZl struct{} + +func (op opZl) MatchStr(s string) bool { + return len(s) == 0 +} + +type opEq struct { + m map[string]struct{} +} + +func (op *opEq) MatchStr(s string) bool { + _, ok := op.m[s] + return ok +} + +type opF struct { + s []string + f func(s, arg string) bool +} + +func (op *opF) MatchStr(s string) bool { + for _, sub := range op.s { + if op.f(s, sub) { + return true + } + } + return false +} + +type opRegExp struct { + exp []*regexp.Regexp +} + +func (op *opRegExp) MatchStr(s string) bool { + for _, exp := range op.exp { + if exp.MatchString(s) { + return true + } + } + return false +} + +func getUrlPath(qCtx *query_context.Context) string { + return qCtx.ServerMeta.UrlPath +} + +func getServerName(qCtx *query_context.Context) string { + return qCtx.ServerMeta.ServerName +} diff --git a/plugin/matcher/string_exp/string_exp_test.go b/plugin/matcher/string_exp/string_exp_test.go new file mode 100644 index 0000000..e810c13 --- /dev/null +++ b/plugin/matcher/string_exp/string_exp_test.go @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package string_exp + +import ( + "context" + "os" + "testing" + + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +func TestMatcher_Match(t *testing.T) { + r := require.New(t) + q := new(dns.Msg) + qc := query_context.NewContext(q) + qc.ServerMeta = query_context.ServerMeta{UrlPath: "/dns-query", ServerName: "a.b.c"} + os.Setenv("STRING_EXP_TEST", "abc") + + doTest := func(arg string, want bool) { + t.Helper() + urlMatcher, err := QuickSetupFromStr(arg) + r.NoError(err) + got, err := urlMatcher.Match(context.Background(), qc) + r.NoError(err) + r.Equal(want, got) + } + + doTest("url_path zl", false) + doTest("url_path eq /dns-query", true) + doTest("url_path eq /123 /dns-query /abc", true) + doTest("url_path eq /123 /abc", false) + doTest("url_path contains abc dns def", true) + doTest("url_path contains abc def", false) + doTest("url_path prefix abc /dns def", true) + doTest("url_path prefix abc def", false) + doTest("url_path suffix abc query def", true) + doTest("url_path suffix abc def", false) + doTest("url_path regexp ^/dns-query$", true) + doTest("url_path regexp ^abc", false) + + doTest("server_name eq abc a.b.c def", true) + doTest("server_name eq abc def", false) + + doTest("$STRING_EXP_TEST eq 123 abc def", true) + doTest("$STRING_EXP_TEST eq 123 def", false) + doTest("$STRING_EXP_TEST_NOT_EXIST eq 123 abc def", false) + doTest("$STRING_EXP_TEST_NOT_EXIST zl", true) +} diff --git a/plugin/server/http_server/http_server.go b/plugin/server/http_server/http_server.go new file mode 100644 index 0000000..69dddc4 --- /dev/null +++ b/plugin/server/http_server/http_server.go @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package tcp_server + +import ( + "context" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils" + "go.uber.org/zap" + "golang.org/x/net/http2" +) + +const PluginType = "http_server" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +type Args struct { + Entries []struct { + Exec string `yaml:"exec"` + Path string `yaml:"path"` + } `yaml:"entries"` + Listen string `yaml:"listen"` + SrcIPHeader string `yaml:"src_ip_header"` + Cert string `yaml:"cert"` + Key string `yaml:"key"` + IdleTimeout int `yaml:"idle_timeout"` +} + +func (a *Args) init() { + utils.SetDefaultNum(&a.IdleTimeout, 30) +} + +type HttpServer struct { + args *Args + + server *http.Server +} + +func (s *HttpServer) Close() error { + return s.server.Close() +} + +func Init(bp *coremain.BP, args any) (any, error) { + return StartServer(bp, args.(*Args)) +} + +func StartServer(bp *coremain.BP, args *Args) (*HttpServer, error) { + mux := http.NewServeMux() + for _, entry := range args.Entries { + dh, err := server_utils.NewHandler(bp, entry.Exec) + if err != nil { + return nil, fmt.Errorf("failed to init dns handler, %w", err) + } + hhOpts := server.HttpHandlerOpts{ + GetSrcIPFromHeader: args.SrcIPHeader, + Logger: bp.L(), + } + hh := server.NewHttpHandler(dh, hhOpts) + mux.Handle(entry.Path, hh) + } + + socketOpt := server_utils.ListenerSocketOpts{ + SO_REUSEPORT: true, + SO_RCVBUF: 64 * 1024, + } + lc := net.ListenConfig{Control: server_utils.ListenerControl(socketOpt)} + + listenerNetwork := "tcp" + if strings.HasPrefix(args.Listen, "@") { + listenerNetwork = "unix" + } + l, err := lc.Listen(context.Background(), listenerNetwork, args.Listen) + if err != nil { + return nil, fmt.Errorf("failed to listen socket, %w", err) + } + bp.L().Info("http server started", zap.Stringer("addr", l.Addr())) + + hs := &http.Server{ + Handler: mux, + ReadTimeout: time.Second, + IdleTimeout: time.Duration(args.IdleTimeout) * time.Second, + MaxHeaderBytes: 512, + } + if err := http2.ConfigureServer(hs, &http2.Server{ + MaxReadFrameSize: 16 * 1024, + IdleTimeout: time.Duration(args.IdleTimeout) * time.Second, + MaxUploadBufferPerConnection: 65535, + MaxUploadBufferPerStream: 65535, + }); err != nil { + return nil, fmt.Errorf("failed to setup http2 server, %w", err) + } + + go func() { + var err error + if len(args.Key)+len(args.Cert) > 0 { + err = hs.ServeTLS(l, args.Cert, args.Key) + } else { + err = hs.Serve(l) + } + bp.M().GetSafeClose().SendCloseSignal(err) + }() + return &HttpServer{ + args: args, + server: hs, + }, nil +} diff --git a/plugin/server/quic_server/quic_server.go b/plugin/server/quic_server/quic_server.go new file mode 100644 index 0000000..f39c078 --- /dev/null +++ b/plugin/server/quic_server/quic_server.go @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package quic_server + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "time" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils" + "github.com/quic-go/quic-go" + "go.uber.org/zap" +) + +const PluginType = "quic_server" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +type Args struct { + Entry string `yaml:"entry"` + Listen string `yaml:"listen"` + Cert string `yaml:"cert"` + Key string `yaml:"key"` + IdleTimeout int `yaml:"idle_timeout"` +} + +func (a *Args) init() { + utils.SetDefaultNum(&a.IdleTimeout, 30) +} + +type QuicServer struct { + args *Args + + l *quic.Listener +} + +func (s *QuicServer) Close() error { + return s.l.Close() +} + +func Init(bp *coremain.BP, args any) (any, error) { + return StartServer(bp, args.(*Args)) +} + +func StartServer(bp *coremain.BP, args *Args) (*QuicServer, error) { + logger := bp.L() + + dh, err := server_utils.NewHandler(bp, args.Entry) + if err != nil { + return nil, fmt.Errorf("failed to init dns handler, %w", err) + } + + // Init tls + if len(args.Key) == 0 || len(args.Cert) == 0 { + return nil, errors.New("quic server requires a tls certificate") + } + tlsConfig := new(tls.Config) + if err := server.LoadCert(tlsConfig, args.Cert, args.Key); err != nil { + return nil, fmt.Errorf("failed to read tls cert, %w", err) + } + tlsConfig.NextProtos = []string{"doq"} + + uc, err := net.ListenPacket("udp", args.Listen) + if err != nil { + return nil, fmt.Errorf("failed to listen socket, %w", err) + } + + idleTimeout := time.Duration(args.IdleTimeout) * time.Second + + quicConfig := &quic.Config{ + MaxIdleTimeout: idleTimeout, + InitialStreamReceiveWindow: 4 * 1024, + MaxStreamReceiveWindow: 4 * 1024, + InitialConnectionReceiveWindow: 8 * 1024, + MaxConnectionReceiveWindow: 16 * 1024, + Allow0RTT: false, + + // UniStream is not allowed. + MaxIncomingUniStreams: -1, + } + + srk, _, err := utils.InitQUICSrkFromIfaceMac() + if err != nil { + logger.Warn("failed to init quic stateless reset key, it will be disabled", zap.Error(err)) + } + qt := &quic.Transport{ + Conn: uc, + StatelessResetKey: (*quic.StatelessResetKey)(srk), + } + + quicListener, err := qt.Listen(tlsConfig, quicConfig) + if err != nil { + qt.Close() + return nil, fmt.Errorf("failed to listen quic, %w", err) + } + bp.L().Info("quic server started", zap.Stringer("addr", quicListener.Addr())) + + go func() { + defer quicListener.Close() + serverOpts := server.DoQServerOpts{Logger: bp.L(), IdleTimeout: idleTimeout} + err := server.ServeDoQ(quicListener, dh, serverOpts) + bp.M().GetSafeClose().SendCloseSignal(err) + }() + return &QuicServer{ + args: args, + l: quicListener, + }, nil +} diff --git a/plugin/server/server_utils/handler.go b/plugin/server/server_utils/handler.go new file mode 100644 index 0000000..bbc6eab --- /dev/null +++ b/plugin/server/server_utils/handler.go @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server_utils + +import ( + "fmt" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/server_handler" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" +) + +func NewHandler(bp *coremain.BP, entry string) (server.Handler, error) { + p := bp.M().GetPlugin(entry) + exec := sequence.ToExecutable(p) + if exec == nil { + return nil, fmt.Errorf("cannot find executable entry by tag %s", entry) + } + + handlerOpts := server_handler.EntryHandlerOpts{ + Logger: bp.L(), + Entry: exec, + } + return server_handler.NewEntryHandler(handlerOpts), nil +} diff --git a/plugin/server/server_utils/socket_utils.go b/plugin/server/server_utils/socket_utils.go new file mode 100644 index 0000000..0d496ab --- /dev/null +++ b/plugin/server/server_utils/socket_utils.go @@ -0,0 +1,15 @@ +package server_utils + +import "syscall" + +type ControlFunc func(network, address string, c syscall.RawConn) error + +func NopControlFunc(network, address string, c syscall.RawConn) error { + return nil +} + +type ListenerSocketOpts struct { + SO_REUSEPORT bool + SO_RCVBUF int + SO_SNDBUF int +} diff --git a/plugin/server/server_utils/socket_utils_linux.go b/plugin/server/server_utils/socket_utils_linux.go new file mode 100644 index 0000000..e34054e --- /dev/null +++ b/plugin/server/server_utils/socket_utils_linux.go @@ -0,0 +1,46 @@ +//go:build linux + +package server_utils + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func ListenerControl(opt ListenerSocketOpts) ControlFunc { + return func(network, address string, c syscall.RawConn) error { + var ( + errControl error + errSyscall error + ) + + errControl = c.Control(func(fd uintptr) { + if opt.SO_REUSEPORT { + errSyscall = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + if errSyscall != nil { + return + } + } + + if opt.SO_RCVBUF > 0 { + errSyscall = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, opt.SO_RCVBUF) + if errSyscall != nil { + return + } + } + + if opt.SO_SNDBUF > 0 { + errSyscall = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, opt.SO_SNDBUF) + if errSyscall != nil { + return + } + } + }) + + if errControl != nil { + return errControl + } + return errSyscall + } +} diff --git a/plugin/server/server_utils/socket_utils_others.go b/plugin/server/server_utils/socket_utils_others.go new file mode 100644 index 0000000..c17e71f --- /dev/null +++ b/plugin/server/server_utils/socket_utils_others.go @@ -0,0 +1,7 @@ +//go:build !linux + +package server_utils + +func ListenerControl(opt ListenerSocketOpts) ControlFunc { + return NopControlFunc +} diff --git a/plugin/server/tcp_server/tcp_server.go b/plugin/server/tcp_server/tcp_server.go new file mode 100644 index 0000000..3196b31 --- /dev/null +++ b/plugin/server/tcp_server/tcp_server.go @@ -0,0 +1,113 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package tcp_server + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "strings" + "time" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils" + "go.uber.org/zap" +) + +const PluginType = "tcp_server" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +type Args struct { + Entry string `yaml:"entry"` + Listen string `yaml:"listen"` + Cert string `yaml:"cert"` + Key string `yaml:"key"` + IdleTimeout int `yaml:"idle_timeout"` +} + +func (a *Args) init() { + utils.SetDefaultString(&a.Listen, "127.0.0.1:53") + utils.SetDefaultNum(&a.IdleTimeout, 10) +} + +type TcpServer struct { + args *Args + + l net.Listener +} + +func (s *TcpServer) Close() error { + return s.l.Close() +} + +func Init(bp *coremain.BP, args any) (any, error) { + return StartServer(bp, args.(*Args)) +} + +func StartServer(bp *coremain.BP, args *Args) (*TcpServer, error) { + dh, err := server_utils.NewHandler(bp, args.Entry) + if err != nil { + return nil, fmt.Errorf("failed to init dns handler, %w", err) + } + + // Init tls + var tc *tls.Config + if len(args.Key)+len(args.Cert) > 0 { + tc = new(tls.Config) + if err := server.LoadCert(tc, args.Cert, args.Key); err != nil { + return nil, fmt.Errorf("failed to read tls cert, %w", err) + } + } + + socketOpt := server_utils.ListenerSocketOpts{ + SO_REUSEPORT: true, + SO_RCVBUF: 64 * 1024, + } + lc := net.ListenConfig{Control: server_utils.ListenerControl(socketOpt)} + listenerNetwork := "tcp" + if strings.HasPrefix(args.Listen, "@") { + listenerNetwork = "unix" + } + l, err := lc.Listen(context.Background(), listenerNetwork, args.Listen) + if err != nil { + return nil, fmt.Errorf("failed to listen socket, %w", err) + } + if tc != nil { + l = tls.NewListener(l, tc) + } + bp.L().Info("tcp server started", zap.Stringer("addr", l.Addr()), zap.Bool("tls", tc != nil)) + + go func() { + defer l.Close() + serverOpts := server.TCPServerOpts{Logger: bp.L(), IdleTimeout: time.Duration(args.IdleTimeout) * time.Second} + err := server.ServeTCP(l, dh, serverOpts) + bp.M().GetSafeClose().SendCloseSignal(err) + }() + return &TcpServer{ + args: args, + l: l, + }, nil +} diff --git a/plugin/server/udp_server/udp_server.go b/plugin/server/udp_server/udp_server.go new file mode 100644 index 0000000..e104df6 --- /dev/null +++ b/plugin/server/udp_server/udp_server.go @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package udp_server + +import ( + "context" + "fmt" + "net" + + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils" + "go.uber.org/zap" +) + +const PluginType = "udp_server" + +func init() { + coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) +} + +type Args struct { + Entry string `yaml:"entry"` + Listen string `yaml:"listen"` +} + +func (a *Args) init() { + utils.SetDefaultString(&a.Listen, "127.0.0.1:53") +} + +type UdpServer struct { + args *Args + + c net.PacketConn +} + +func (s *UdpServer) Close() error { + return s.c.Close() +} + +func Init(bp *coremain.BP, args any) (any, error) { + return StartServer(bp, args.(*Args)) +} + +func StartServer(bp *coremain.BP, args *Args) (*UdpServer, error) { + dh, err := server_utils.NewHandler(bp, args.Entry) + if err != nil { + return nil, fmt.Errorf("failed to init dns handler, %w", err) + } + + socketOpt := server_utils.ListenerSocketOpts{ + SO_REUSEPORT: true, + SO_RCVBUF: 64 * 1024, + } + lc := net.ListenConfig{Control: server_utils.ListenerControl(socketOpt)} + c, err := lc.ListenPacket(context.Background(), "udp", args.Listen) + if err != nil { + return nil, fmt.Errorf("failed to create socket, %w", err) + } + bp.L().Info("udp server started", zap.Stringer("addr", c.LocalAddr())) + + go func() { + defer c.Close() + err := server.ServeUDP(c.(*net.UDPConn), dh, server.UDPServerOpts{Logger: bp.L()}) + bp.M().GetSafeClose().SendCloseSignal(err) + }() + return &UdpServer{ + args: args, + c: c, + }, nil +} diff --git a/release.py b/release.py new file mode 100644 index 0000000..4bd2b28 --- /dev/null +++ b/release.py @@ -0,0 +1,116 @@ +# !/usr/bin/env python3 +import argparse +import logging +import os +import subprocess +import zipfile + +parser = argparse.ArgumentParser() +parser.add_argument("-upx", action="store_true") +parser.add_argument("-i", type=int) +args = parser.parse_args() + +PROJECT_NAME = 'mosdns' +RELEASE_DIR = './release' + +logger = logging.getLogger(__name__) + +# more info: https://golang.org/doc/install/source +# [(env : value),(env : value)] +envs = [ + [['GOOS', 'darwin'], ['GOARCH', 'amd64']], + [['GOOS', 'darwin'], ['GOARCH', 'arm64']], + # [['GOOS', 'linux'], ['GOARCH', '386']], + [['GOOS', 'linux'], ['GOARCH', 'amd64']], + + [['GOOS', 'linux'], ['GOARCH', 'arm'], ['GOARM', '5']], + [['GOOS', 'linux'], ['GOARCH', 'arm'], ['GOARM', '6']], + [['GOOS', 'linux'], ['GOARCH', 'arm'], ['GOARM', '7']], + [['GOOS', 'linux'], ['GOARCH', 'arm64']], + + # [['GOOS', 'linux'], ['GOARCH', 'mips'], ['GOMIPS', 'hardfloat']], + # [['GOOS', 'linux'], ['GOARCH', 'mips'], ['GOMIPS', 'softfloat']], + # [['GOOS', 'linux'], ['GOARCH', 'mipsle'], ['GOMIPS', 'hardfloat']], + [['GOOS', 'linux'], ['GOARCH', 'mipsle'], ['GOMIPS', 'softfloat']], + + # [['GOOS', 'linux'], ['GOARCH', 'mips64'], ['GOMIPS64', 'hardfloat']], + # [['GOOS', 'linux'], ['GOARCH', 'mips64'], ['GOMIPS64', 'softfloat']], + [['GOOS', 'linux'], ['GOARCH', 'mips64le'], ['GOMIPS64', 'hardfloat']], + # [['GOOS', 'linux'], ['GOARCH', 'mips64le'], ['GOMIPS64', 'softfloat']], + + [['GOOS', 'linux'], ['GOARCH', 'ppc64le']], + + # [['GOOS', 'freebsd'], ['GOARCH', '386']], + [['GOOS', 'freebsd'], ['GOARCH', 'amd64']], + + # [['GOOS', 'windows'], ['GOARCH', '386']], + [['GOOS', 'windows'], ['GOARCH', 'amd64']], +] + + +def go_build(): + logger.info(f'building {PROJECT_NAME}') + + global envs + if args.i: + envs = [envs[args.i]] + + VERSION = 'dev/unknown' + try: + VERSION = subprocess.check_output('git describe --tags --long --always', shell=True).decode().rstrip() + except subprocess.CalledProcessError as e: + logger.error(f'get git tag failed: {e.args}') + + try: + subprocess.check_call('go run ../ config gen config.yaml', shell=True, env=os.environ) + except Exception: + logger.exception('failed to generate config template') + raise + + for env in envs: + os_env = os.environ.copy() # new env + + s = PROJECT_NAME + for pairs in env: + os_env[pairs[0]] = pairs[1] # add env + s = s + '-' + pairs[1] + zip_filename = s + '.zip' + + suffix = '.exe' if os_env['GOOS'] == 'windows' else '' + bin_filename = PROJECT_NAME + suffix + + logger.info(f'building {zip_filename}') + try: + subprocess.check_call( + f'go build -ldflags "-s -w -X main.version={VERSION}" -trimpath -o {bin_filename} ../', shell=True, + env=os_env) + + if args.upx: + try: + subprocess.check_call(f'upx -9 -q {bin_filename}', shell=True, stderr=subprocess.DEVNULL, + stdout=subprocess.DEVNULL) + except subprocess.CalledProcessError as e: + logger.error(f'upx failed: {e.args}') + + with zipfile.ZipFile(zip_filename, mode='w', compression=zipfile.ZIP_DEFLATED, + compresslevel=5) as zf: + zf.write(bin_filename) + zf.write('../README.md', 'README.md') + zf.write('./config.yaml', 'config.yaml') + zf.write('../LICENSE', 'LICENSE') + + except subprocess.CalledProcessError as e: + logger.error(f'build {zip_filename} failed: {e.args}') + except Exception: + logger.exception('unknown err') + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + + if len(RELEASE_DIR) != 0: + if not os.path.exists(RELEASE_DIR): + os.mkdir(RELEASE_DIR) + os.chdir(RELEASE_DIR) + + go_build() diff --git a/scripts/openwrt/mosdns-init-openwrt b/scripts/openwrt/mosdns-init-openwrt new file mode 100644 index 0000000..d257697 --- /dev/null +++ b/scripts/openwrt/mosdns-init-openwrt @@ -0,0 +1,47 @@ +#!/bin/sh /etc/rc.common +# +# Copyright (C) 2020-2022, IrineSistiana +# +# This file is part of mosdns. +# +# mosdns is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# mosdns is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +START=99 +USE_PROCD=1 + +##### ONLY CHANGE THIS BLOCK ###### +PROG=/usr/bin/mosdns # where is mosdns +RES_DIR=/etc/mosdns/ # resource dir / working dir / the dir where you store ip/domain lists +CONF=./config.yaml # where is the config file, it can be a relative path to $RES_DIR +##### ONLY CHANGE THIS BLOCK ###### + +start_service() { + procd_open_instance + procd_set_param command $PROG start -d $RES_DIR -c $CONF + + procd_set_param user root + procd_set_param stdout 1 + procd_set_param stderr 1 + procd_set_param respawn "${respawn_threshold:-3600}" "${respawn_timeout:-5}" "${respawn_retry:-5}" + procd_close_instance + echo "mosdns is started!" +} + +reload_service() { + stop + sleep 2s + echo "mosdns is restarted!" + start +} diff --git a/scripts/update_chn_ip_domain.py b/scripts/update_chn_ip_domain.py new file mode 100644 index 0000000..72a76d5 --- /dev/null +++ b/scripts/update_chn_ip_domain.py @@ -0,0 +1,114 @@ +import netaddr +import requests +import logging +import math + +logger = logging.getLogger(__name__) + + +def update_ip_list(): + url = 'https://ftp.apnic.net/apnic/stats/apnic/delegated-apnic-latest' + timeout = 30 + save_to_file = './chn_ip.list' + + logger.info(f'fetching chn ip data from {url}') + + ipNetwork_list = [] + + with requests.get(url, timeout=timeout) as res: + if res.status_code != 200: + raise Exception(f'status code :{res.status_code}') + + logger.info(f'parsing...') + + lines = res.text.splitlines() + for line in lines: + try: + if line.find('|CN|ipv4|') != -1: + elems = line.split('|') + ip_start = elems[3] + count = int(elems[4]) + cidr_prefix_length = int(32 - math.log(count, 2)) + ipNetwork_list.append(netaddr.IPNetwork(f'{ip_start}/{cidr_prefix_length}\n')) + + if line.find('|CN|ipv6|') != -1: + elems = line.split('|') + ip_start = elems[3] + cidr_prefix_length = elems[4] + ipNetwork_list.append(netaddr.IPNetwork(f'{ip_start}/{cidr_prefix_length}\n')) + except IndexError: + logging.warning(f'unexpected format: {line}') + + logger.info('merging') + ipNetwork_list = netaddr.cidr_merge(ipNetwork_list) + logger.info('writing to file') + + with open(save_to_file, 'wt') as f: + f.writelines([f'{x}\n' for x in ipNetwork_list]) + + logger.info('all done') + + +def update_chn_domain_list(): + def get_domains_from(url: str, timeout=30): + logger.info(f'fetching {url}') + + domains = [] + with requests.get(url, timeout=timeout) as res: + if res.status_code != 200: + res.close() + raise Exception(f'status code :{res.status_code}') + + lines = res.text.splitlines() + for line in lines: + try: + if line.find('server=/') != -1: + elems = line.split('/') + domain = elems[1] + domains.append(domain) + except IndexError: + logger.warning(f'unexpected format: {line}') + + return domains + + urls = ['https://raw.githubusercontent.com/felixonmars/dnsmasq-china-list/master/accelerated-domains.china.conf', + 'https://raw.githubusercontent.com/felixonmars/dnsmasq-china-list/master/google.china.conf', + 'https://raw.githubusercontent.com/felixonmars/dnsmasq-china-list/master/apple.china.conf'] + + save_to = './chn_domain.list' + domains = [] + + for url in urls: + domains = domains + get_domains_from(url) + + with open(save_to, 'wt') as f: + f.writelines([f'{x}\n' for x in domains]) + + logger.info('all done') + + +def download_chn_blocked_domain_list(): + url = 'https://github.com/Loyalsoldier/cn-blocked-domain/raw/release/domains.txt' + timeout = 30 + save_to_file = './non_chn_domain.list' + + logger.info(f'fetching chn blocked domain from {url}') + + with requests.get(url, timeout=timeout) as res: + if res.status_code != 200: + res.close() + raise Exception(f'status code :{res.status_code}') + + with open(save_to_file, 'wt') as f: + f.write(res.text) + + +def update_all(): + update_chn_domain_list() + download_chn_blocked_domain_list() + update_ip_list() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + update_all() diff --git a/tools/config.go b/tools/config.go new file mode 100644 index 0000000..0f9def5 --- /dev/null +++ b/tools/config.go @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package tools + +import ( + "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "strings" +) + +func newConvCmd() *cobra.Command { + var ( + in string + out string + ) + + c := &cobra.Command{ + Use: "conv -i input_cfg -o output_cfg", + Args: cobra.NoArgs, + Short: "Convert configuration file format. Supported extensions: " + strings.Join(viper.SupportedExts, ", "), + Run: func(cmd *cobra.Command, args []string) { + if err := convCfg(in, out); err != nil { + mlog.S().Fatal(err) + } + }, + DisableFlagsInUseLine: true, + } + c.Flags().StringVarP(&in, "in", "i", "", "input config") + c.Flags().StringVarP(&out, "out", "o", "", "output config") + c.MarkFlagRequired("in") + c.MarkFlagRequired("out") + c.MarkFlagFilename("in") + c.MarkFlagFilename("out") + return c +} + +func newGenCmd() *cobra.Command { + c := &cobra.Command{ + Use: "gen config_file", + Short: "Generate a template config. Supported extensions: " + strings.Join(viper.SupportedExts, ", "), + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + if err := genCfg(args[0]); err != nil { + mlog.S().Fatal(err) + } + }, + DisableFlagsInUseLine: true, + } + return c +} + +func convCfg(in, out string) error { + v := viper.New() + v.SetConfigFile(in) + if err := v.ReadInConfig(); err != nil { + return err + } + return v.SafeWriteConfigAs(out) +} + +func genCfg(out string) error { + cfg := ` +log: + level: info + +plugins: + - tag: forward_google + type: forward + args: + upstreams: + - addr: https://8.8.8.8/dns-query + + - tag: udp_server + type: udp_server + args: + entry: forward_google + listen: "127.0.0.1:53" + - tag: tcp_server + type: tcp_server + args: + entry: forward_google + listen: "127.0.0.1:53" +` + v := viper.New() + v.SetConfigType("yaml") + if err := v.ReadConfig(strings.NewReader(cfg)); err != nil { + return err + } + + return v.WriteConfigAs(out) +} diff --git a/tools/init.go b/tools/init.go new file mode 100644 index 0000000..744a926 --- /dev/null +++ b/tools/init.go @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package tools + +import ( + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/spf13/cobra" +) + +func init() { + probeCmd := &cobra.Command{ + Use: "probe", + Short: "Run some server tests.", + } + probeCmd.AddCommand( + newConnReuseCmd(), + newIdleTimeoutCmd(), + newPipelineCmd(), + ) + coremain.AddSubCmd(probeCmd) + + configCmd := &cobra.Command{ + Use: "config", + Short: "Tools that can generate/convert mosdns config file.", + } + configCmd.AddCommand(newGenCmd(), newConvCmd()) + coremain.AddSubCmd(configCmd) +} diff --git a/tools/probe.go b/tools/probe.go new file mode 100644 index 0000000..8010f7e --- /dev/null +++ b/tools/probe.go @@ -0,0 +1,238 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package tools + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "net" + "strconv" + "time" + + "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/miekg/dns" + "github.com/spf13/cobra" +) + +func newIdleTimeoutCmd() *cobra.Command { + return &cobra.Command{ + Use: "idle-timeout {tcp|tls}://server_addr[:port]", + Args: cobra.ExactArgs(1), + Short: "Probe server's idle timeout.", + Run: func(cmd *cobra.Command, args []string) { + if err := ProbServerTimeout(args[0]); err != nil { + mlog.S().Fatal(err) + } + }, + DisableFlagsInUseLine: true, + } +} + +func newConnReuseCmd() *cobra.Command { + return &cobra.Command{ + Use: "conn-reuse {tcp|tls}://server_addr[:port]", + Args: cobra.ExactArgs(1), + Short: "Check whether this server supports RFC 1035 connection reuse.", + Run: func(cmd *cobra.Command, args []string) { + if err := ProbServerConnectionReuse(args[0]); err != nil { + mlog.S().Fatal(err) + } + }, + DisableFlagsInUseLine: true, + } +} + +func newPipelineCmd() *cobra.Command { + return &cobra.Command{ + Use: "pipeline {tcp|tls}://server_addr[:port]", + Args: cobra.ExactArgs(1), + Short: "Check whether this server supports RFC 7766 query pipelining.", + Run: func(cmd *cobra.Command, args []string) { + if err := ProbServerPipeline(args[0]); err != nil { + mlog.S().Fatal(err) + } + }, + DisableFlagsInUseLine: true, + } +} + +func getConn(addr string) (net.Conn, error) { + tryAddPort := func(addr string, defaultPort int) string { + _, _, err := net.SplitHostPort(addr) + if err != nil { // no port, add it. + return net.JoinHostPort(addr, strconv.Itoa(defaultPort)) + } + return addr + } + + protocol, host := utils.SplitSchemeAndHost(addr) + if len(protocol) == 0 || len(host) == 0 { + return nil, fmt.Errorf("invalid addr %s", addr) + } + + switch protocol { + case "tcp": + host = tryAddPort(host, 53) + return net.Dial("tcp", host) + case "tls": + host = tryAddPort(host, 853) + serverName, _, _ := net.SplitHostPort(host) + tlsConfig := new(tls.Config) + tlsConfig.InsecureSkipVerify = false + tlsConfig.ServerName = serverName + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, err + } + tlsConn := tls.Client(conn, tlsConfig) + tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) + err = tlsConn.Handshake() + if err != nil { + conn.Close() + return nil, fmt.Errorf("tls handshake failed: %v", err) + } + tlsConn.SetDeadline(time.Time{}) + return tlsConn, nil + default: + return nil, fmt.Errorf("invalid protocol %s", protocol) + } +} + +func ProbServerConnectionReuse(addr string) error { + c, err := getConn(addr) + if err != nil { + return err + } + defer c.Close() + + conn := dns.Conn{Conn: c} + for i := 0; i < 3; i++ { + conn.SetDeadline(time.Now().Add(time.Second * 3)) + + q := new(dns.Msg) + q.SetQuestion("www.cloudflare.com.", dns.TypeA) + q.Id = uint16(i) + + mlog.S().Infof("sending msg #%d", i) + err = conn.WriteMsg(q) + if err != nil { + return fmt.Errorf("failed to write #%d probe msg: %v", i, err) + } + _, err = conn.ReadMsg() + if err != nil { + return fmt.Errorf("failed to read #%d probe msg response: %v", i, err) + } + mlog.S().Infof("recevied response #%d", i) + } + + mlog.S().Infof("server %s supports RFC 1035 connection reuse", addr) + return nil +} + +func ProbServerPipeline(addr string) error { + c, err := getConn(addr) + if err != nil { + return err + } + defer c.Close() + + conn := dns.Conn{Conn: c} + if err != nil { + return err + } + defer conn.Close() + + domains := make([]string, 0) + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return err + } + domains = append(domains, fmt.Sprintf("%x.com.", b)) + domains = append(domains, ".") + + for i, d := range domains { + conn.SetDeadline(time.Now().Add(time.Second * 3)) + + q := new(dns.Msg) + q.SetQuestion(d, dns.TypeNS) + q.Id = uint16(i) + + err = conn.WriteMsg(q) + if err != nil { + return fmt.Errorf("failed to write #%d probe msg: %v", i, err) + } + } + + oooPassed := false + start := time.Now() + for i := range domains { + conn.SetDeadline(time.Now().Add(time.Second * 10)) + m, err := conn.ReadMsg() + if err != nil { + return fmt.Errorf("failed to read #%d probe msg response: %v", i, err) + } + + mlog.S().Infof("#%d response received, latency: %d ms", m.Id, time.Since(start).Milliseconds()) + if m.Id != uint16(i) { + oooPassed = true + } + } + + if oooPassed { + mlog.S().Info("server supports RFC7766 query pipelining") + } else { + mlog.S().Info("no out-of-order response received in this test, server MAY NOT support RFC7766 query pipelining") + } + return nil +} + +func ProbServerTimeout(addr string) error { + c, err := getConn(addr) + if err != nil { + return err + } + defer c.Close() + + conn := dns.Conn{Conn: c} + q := new(dns.Msg) + q.SetQuestion("www.cloudflare.com.", dns.TypeA) + err = conn.WriteMsg(q) + if err != nil { + return fmt.Errorf("failed to write probe msg: %v", err) + } + + mlog.S().Info("testing server idle timeout, awaiting server closing the connection, this may take a while") + start := time.Now() + _, err = conn.ReadMsg() + if err != nil { + return fmt.Errorf("failed to read probe msg response: %v", err) + } + + for { + _, err := conn.ReadMsg() + if err != nil { + break + } + } + mlog.S().Infof("connection closed by peer, it's idle timeout is %.2f sec", time.Since(start).Seconds()) + return nil +}