mosdns/coremain/config_validator.go
dengxiongjian 0413ee5d44
Some checks failed
Test mosdns / build (push) Has been cancelled
二次开发
2025-10-16 21:07:48 +08:00

303 lines
7.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package coremain
import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"go.uber.org/zap"
)
// ConfigValidator checks a Config before it is used and collects the problems
// it finds. Problems that make validation fail are stored in errors; non-fatal
// notices are stored in warnings. Validate resets both at the start of each run.
type ConfigValidator struct {
	// config is the configuration under validation. validateBasicStructure may
	// write a default log level into it.
	config *Config
	// errors holds the problems that make Validate return a non-nil error.
	errors []error
	// warnings holds non-fatal notices; Validate logs them but does not fail.
	warnings []string
	// logger receives progress and result messages from Validate.
	logger *zap.Logger
}
// NewConfigValidator returns a ConfigValidator for config that reports its
// progress and results through logger.
//
// A nil logger is replaced with zap.NewNop(): Validate logs unconditionally,
// and zap's Logger methods dereference the receiver, so a nil *zap.Logger
// would panic on the first log call.
func NewConfigValidator(config *Config, logger *zap.Logger) *ConfigValidator {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &ConfigValidator{
		config:   config,
		errors:   make([]error, 0),
		warnings: make([]string, 0),
		logger:   logger,
	}
}
// Validate runs every validation pass over the configuration. It returns nil
// when no errors were found; otherwise it returns a single error whose message
// lists every collected error, one per line. Warnings never cause failure and
// are only logged.
//
// Collected errors and warnings are reset at the start of each call, so
// GetValidationResult and IsValid reflect the most recent run.
func (v *ConfigValidator) Validate() error {
	v.errors = []error{}
	v.warnings = []string{}
	v.logger.Info("开始配置验证")

	// 1. Basic structure (also records an error for a nil config).
	v.validateBasicStructure()

	// Every later pass dereferences v.config. Skip them when it is nil so the
	// error recorded above is returned instead of a nil-pointer panic.
	if v.config != nil {
		// 2. Plugin reference integrity.
		v.validatePluginReferences()
		// 3. Required plugins.
		v.validateRequiredPlugins()
		// 4. File paths.
		v.validateFilePaths()
		// 5. Conflicts (duplicate tags, port clashes).
		v.validateConflicts()
		// 6. Circular dependencies.
		v.validateCircularDependencies()
	}

	// Log warnings before deciding the outcome so they are not dropped when
	// validation fails.
	if len(v.warnings) > 0 {
		v.logger.Warn("配置验证警告",
			zap.Int("warning_count", len(v.warnings)))
		for _, warning := range v.warnings {
			v.logger.Warn(warning)
		}
	}

	if len(v.errors) > 0 {
		v.logger.Error("配置验证失败",
			zap.Int("error_count", len(v.errors)),
			zap.Int("warning_count", len(v.warnings)))
		errorMsgs := make([]string, 0, len(v.errors))
		for _, err := range v.errors {
			errorMsgs = append(errorMsgs, err.Error())
		}
		return fmt.Errorf("配置验证失败:\n%s", strings.Join(errorMsgs, "\n"))
	}

	v.logger.Info("配置验证通过")
	return nil
}
// validateBasicStructure checks the top-level sections of the configuration.
//
// A nil config is recorded as an error and nothing else is checked. An empty
// log level is replaced with "info" (this mutates the config) and reported as
// a warning. Empty API and Web listen addresses are each reported as a warning.
func (v *ConfigValidator) validateBasicStructure() {
	cfg := v.config
	if cfg == nil {
		v.errors = append(v.errors, fmt.Errorf("配置不能为空"))
		return
	}

	// Fill in a default rather than failing; the caller is told via a warning.
	if cfg.Log.Level == "" {
		cfg.Log.Level = "info"
		v.warnings = append(v.warnings, "日志级别未设置,使用默认值: info")
	}

	if cfg.API.HTTP == "" {
		v.warnings = append(v.warnings, "API地址未设置API服务将被禁用")
	}

	if cfg.Web.HTTP == "" {
		v.warnings = append(v.warnings, "Web界面地址未设置Web界面将被禁用")
	}
}
// validatePluginReferences records an error for every dependency returned by
// extractPluginDependencies that names a tag no plugin in the configuration
// declares.
func (v *ConfigValidator) validatePluginReferences() {
	declared := make(map[string]struct{}, len(v.config.Plugins))
	for _, plugin := range v.config.Plugins {
		declared[plugin.Tag] = struct{}{}
	}

	for _, plugin := range v.config.Plugins {
		for _, ref := range v.extractPluginDependencies(plugin) {
			if _, ok := declared[ref]; ok {
				continue
			}
			v.errors = append(v.errors,
				fmt.Errorf("插件 '%s' 引用了不存在的插件 '%s'", plugin.Tag, ref))
		}
	}
}
// validateRequiredPlugins records an error for each mandatory plugin tag
// (currently only "main") that no plugin in the configuration declares.
func (v *ConfigValidator) validateRequiredPlugins() {
	present := make(map[string]bool, len(v.config.Plugins))
	for _, plugin := range v.config.Plugins {
		present[plugin.Tag] = true
	}

	for _, required := range []string{"main"} {
		if present[required] {
			continue
		}
		v.errors = append(v.errors,
			fmt.Errorf("缺少必需插件: %s", required))
	}
}
// validateFilePaths runs the per-type list-file checks for every domain_set
// and ip_set plugin. Plugins of any other type are not inspected.
func (v *ConfigValidator) validateFilePaths() {
	plugins := v.config.Plugins
	for i := range plugins {
		p := plugins[i]
		switch {
		case p.Type == "domain_set":
			v.validateDomainSetFiles(p)
		case p.Type == "ip_set":
			v.validateIPSetFiles(p)
		}
	}
}
// validateDomainSetFiles checks every entry of a domain_set plugin's "files"
// argument with validateFilePath and records an error for each invalid one.
//
// Only map[string]interface{} args with a []interface{} "files" value are
// inspected; any other shape is skipped. A non-string entry is reported as an
// error instead of triggering a panic from an unchecked type assertion.
func (v *ConfigValidator) validateDomainSetFiles(plugin PluginConfig) {
	args, ok := plugin.Args.(map[string]interface{})
	if !ok {
		return
	}
	files, ok := args["files"].([]interface{})
	if !ok {
		return
	}
	for i, f := range files {
		path, ok := f.(string)
		if !ok {
			v.errors = append(v.errors,
				fmt.Errorf("域名文件路径无效 (插件: %s): 第 %d 项不是字符串 (类型 %T)", plugin.Tag, i+1, f))
			continue
		}
		if err := v.validateFilePath(path); err != nil {
			v.errors = append(v.errors,
				fmt.Errorf("域名文件路径无效 (插件: %s): %w", plugin.Tag, err))
		}
	}
}
// validateIPSetFiles checks every entry of an ip_set plugin's "files" argument
// with validateFilePath and records an error for each invalid one.
//
// Only map[string]interface{} args with a []interface{} "files" value are
// inspected; any other shape is skipped. A non-string entry is reported as an
// error instead of triggering a panic from an unchecked type assertion.
func (v *ConfigValidator) validateIPSetFiles(plugin PluginConfig) {
	args, ok := plugin.Args.(map[string]interface{})
	if !ok {
		return
	}
	files, ok := args["files"].([]interface{})
	if !ok {
		return
	}
	for i, f := range files {
		path, ok := f.(string)
		if !ok {
			v.errors = append(v.errors,
				fmt.Errorf("IP文件路径无效 (插件: %s): 第 %d 项不是字符串 (类型 %T)", plugin.Tag, i+1, f))
			continue
		}
		if err := v.validateFilePath(path); err != nil {
			v.errors = append(v.errors,
				fmt.Errorf("IP文件路径无效 (插件: %s): %w", plugin.Tag, err))
		}
	}
}
// validateFilePath reports whether path names an existing, accessible regular
// (non-directory) file. Relative paths are resolved against the current working
// directory so error messages show the absolute path that was checked.
//
// Errors are returned for:
//   - an empty path (which would otherwise resolve to the working directory),
//   - a path that cannot be made absolute,
//   - a path that does not exist,
//   - any other stat failure, such as permission denied,
//   - a path that is a directory.
func (v *ConfigValidator) validateFilePath(path string) error {
	if path == "" {
		return fmt.Errorf("文件路径为空")
	}

	if !filepath.IsAbs(path) {
		absPath, err := filepath.Abs(path)
		if err != nil {
			return fmt.Errorf("无法解析路径: %s: %w", path, err)
		}
		path = absPath
	}

	info, err := os.Stat(path)
	switch {
	case os.IsNotExist(err):
		return fmt.Errorf("文件不存在: %s", path)
	case err != nil:
		// Previously ignored; an unreadable file would fail later at load time.
		return fmt.Errorf("无法访问文件 %s: %w", path, err)
	case info.IsDir():
		return fmt.Errorf("路径是目录而不是文件: %s", path)
	}
	return nil
}
// validateConflicts reports settings that clash with each other:
//   - every repeated occurrence of a plugin tag (a tag used three times yields
//     two errors);
//   - API and Web listen addresses whose ports, as returned by extractPort,
//     are non-empty and equal.
func (v *ConfigValidator) validateConflicts() {
	seenTags := make(map[string]bool)
	for _, plugin := range v.config.Plugins {
		if seenTags[plugin.Tag] {
			v.errors = append(v.errors,
				fmt.Errorf("发现重复的插件标签: %s", plugin.Tag))
		}
		seenTags[plugin.Tag] = true
	}

	apiPort := v.extractPort(v.config.API.HTTP)
	if apiPort == "" {
		return
	}
	if webPort := v.extractPort(v.config.Web.HTTP); webPort == apiPort {
		v.errors = append(v.errors,
			fmt.Errorf("API端口和Web端口冲突: %s", apiPort))
	}
}
// validateCircularDependencies converts the configured plugins to
// utils.PluginConfig and runs utils.TopologicalSort over them. Any error from
// the sort is recorded (wrapped) as a circular-dependency error.
// NOTE(review): if TopologicalSort can fail for reasons other than a cycle,
// those failures are also reported as cycles here — confirm against utils.
func (v *ConfigValidator) validateCircularDependencies() {
	nodes := make([]utils.PluginConfig, 0, len(v.config.Plugins))
	for _, p := range v.config.Plugins {
		nodes = append(nodes, utils.PluginConfig{Tag: p.Tag, Type: p.Type, Args: p.Args})
	}

	if _, err := utils.TopologicalSort(nodes); err != nil {
		v.errors = append(v.errors,
			fmt.Errorf("检测到循环依赖: %w", err))
	}
}
// pluginReferencePattern matches "$tag" style references inside plugin
// arguments; submatch 1 is the referenced tag. It is compiled once at package
// initialization instead of on every call.
var pluginReferencePattern = regexp.MustCompile(`\$([a-zA-Z0-9_-]+)`)

// extractPluginDependencies returns the tags referenced from plugin.Args via
// "$tag" syntax, in order of first appearance and without duplicates, so that
// a tag referenced several times yields a single entry. It returns nil when
// there are no references.
//
// Args is rendered with fmt's %+v verb and scanned textually, so any "$name"
// token anywhere in the arguments counts as a reference, except for a fixed set
// of keywords excluded below to avoid false positives.
func (v *ConfigValidator) extractPluginDependencies(plugin PluginConfig) []string {
	configStr := fmt.Sprintf("%+v", plugin.Args)

	var deps []string
	seen := make(map[string]struct{})
	for _, match := range pluginReferencePattern.FindAllStringSubmatch(configStr, -1) {
		dep := match[1]
		switch dep {
		case "primary", "secondary", "timeout", "china_ip", "always_standby", "verbose":
			// Keywords rather than plugin tags; skip to avoid misidentification.
			continue
		}
		if _, dup := seen[dep]; dup {
			continue
		}
		seen[dep] = struct{}{}
		deps = append(deps, dep)
	}
	return deps
}
// listenPortPattern captures the trailing ":<digits>" of a listen address such
// as "0.0.0.0:8080" or "[::1]:8080". It is compiled once at package
// initialization instead of on every call.
var listenPortPattern = regexp.MustCompile(`:(\d+)$`)

// extractPort returns the numeric port at the end of addr. It returns "" when
// addr is empty or does not end in ":<digits>", for example a bare host or a
// named service port such as "localhost:http".
func (v *ConfigValidator) extractPort(addr string) string {
	m := listenPortPattern.FindStringSubmatch(addr)
	if m == nil {
		return ""
	}
	return m[1]
}
// GetValidationResult returns the errors and warnings currently held by the
// validator; Validate resets and refills them on every call. The returned
// slices are the validator's own storage, not copies, so callers should treat
// them as read-only.
func (v *ConfigValidator) GetValidationResult() (errors []error, warnings []string) {
	errors = v.errors
	warnings = v.warnings
	return errors, warnings
}
// IsValid reports whether the validator currently holds no errors, i.e.
// whether the most recent Validate run (if any) found nothing fatal.
func (v *ConfigValidator) IsValid() bool {
	errCount := len(v.errors)
	return errCount == 0
}