303 lines
7.2 KiB
Go
303 lines
7.2 KiB
Go
package coremain
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"regexp"
|
||
"strings"
|
||
|
||
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// ConfigValidator 配置验证器
|
||
type ConfigValidator struct {
|
||
config *Config
|
||
errors []error
|
||
warnings []string
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// NewConfigValidator 创建配置验证器
|
||
func NewConfigValidator(config *Config, logger *zap.Logger) *ConfigValidator {
|
||
return &ConfigValidator{
|
||
config: config,
|
||
errors: make([]error, 0),
|
||
warnings: make([]string, 0),
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// Validate 验证配置
|
||
func (v *ConfigValidator) Validate() error {
|
||
v.errors = []error{}
|
||
v.warnings = []string{}
|
||
|
||
v.logger.Info("开始配置验证")
|
||
|
||
// 1. 检查基本结构
|
||
v.validateBasicStructure()
|
||
|
||
// 2. 检查插件引用完整性
|
||
v.validatePluginReferences()
|
||
|
||
// 3. 检查必需插件
|
||
v.validateRequiredPlugins()
|
||
|
||
// 4. 检查文件路径
|
||
v.validateFilePaths()
|
||
|
||
// 5. 检查配置冲突
|
||
v.validateConflicts()
|
||
|
||
// 6. 检查循环依赖
|
||
v.validateCircularDependencies()
|
||
|
||
// 输出验证结果
|
||
if len(v.errors) > 0 {
|
||
v.logger.Error("配置验证失败",
|
||
zap.Int("error_count", len(v.errors)),
|
||
zap.Int("warning_count", len(v.warnings)))
|
||
|
||
var errorMsgs []string
|
||
for _, err := range v.errors {
|
||
errorMsgs = append(errorMsgs, err.Error())
|
||
}
|
||
|
||
return fmt.Errorf("配置验证失败:\n%s", strings.Join(errorMsgs, "\n"))
|
||
}
|
||
|
||
if len(v.warnings) > 0 {
|
||
v.logger.Warn("配置验证警告",
|
||
zap.Int("warning_count", len(v.warnings)))
|
||
|
||
for _, warning := range v.warnings {
|
||
v.logger.Warn(warning)
|
||
}
|
||
}
|
||
|
||
v.logger.Info("配置验证通过")
|
||
return nil
|
||
}
|
||
|
||
// validateBasicStructure 验证基本结构
|
||
func (v *ConfigValidator) validateBasicStructure() {
|
||
if v.config == nil {
|
||
v.errors = append(v.errors, fmt.Errorf("配置不能为空"))
|
||
return
|
||
}
|
||
|
||
// 检查日志配置
|
||
if v.config.Log.Level == "" {
|
||
v.config.Log.Level = "info" // 设置默认值
|
||
v.warnings = append(v.warnings, "日志级别未设置,使用默认值: info")
|
||
}
|
||
|
||
// 检查API配置
|
||
if v.config.API.HTTP == "" {
|
||
v.warnings = append(v.warnings, "API地址未设置,API服务将被禁用")
|
||
}
|
||
|
||
// 检查Web配置
|
||
if v.config.Web.HTTP == "" {
|
||
v.warnings = append(v.warnings, "Web界面地址未设置,Web界面将被禁用")
|
||
}
|
||
}
|
||
|
||
// validatePluginReferences 验证插件引用完整性
|
||
func (v *ConfigValidator) validatePluginReferences() {
|
||
existingPlugins := make(map[string]bool)
|
||
|
||
// 收集所有插件标签
|
||
for _, p := range v.config.Plugins {
|
||
existingPlugins[p.Tag] = true
|
||
}
|
||
|
||
// 检查每个插件的引用
|
||
for _, p := range v.config.Plugins {
|
||
deps := v.extractPluginDependencies(p)
|
||
for _, dep := range deps {
|
||
if !existingPlugins[dep] {
|
||
v.errors = append(v.errors,
|
||
fmt.Errorf("插件 '%s' 引用了不存在的插件 '%s'", p.Tag, dep))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// validateRequiredPlugins 验证必需插件
|
||
func (v *ConfigValidator) validateRequiredPlugins() {
|
||
requiredTags := []string{"main"}
|
||
|
||
for _, tag := range requiredTags {
|
||
found := false
|
||
for _, p := range v.config.Plugins {
|
||
if p.Tag == tag {
|
||
found = true
|
||
break
|
||
}
|
||
}
|
||
if !found {
|
||
v.errors = append(v.errors,
|
||
fmt.Errorf("缺少必需插件: %s", tag))
|
||
}
|
||
}
|
||
}
|
||
|
||
// validateFilePaths 验证文件路径
|
||
func (v *ConfigValidator) validateFilePaths() {
|
||
for _, p := range v.config.Plugins {
|
||
switch p.Type {
|
||
case "domain_set":
|
||
v.validateDomainSetFiles(p)
|
||
case "ip_set":
|
||
v.validateIPSetFiles(p)
|
||
}
|
||
}
|
||
}
|
||
|
||
// validateDomainSetFiles 验证域名文件
|
||
func (v *ConfigValidator) validateDomainSetFiles(plugin PluginConfig) {
|
||
if args, ok := plugin.Args.(map[string]interface{}); ok {
|
||
if files, ok := args["files"].([]interface{}); ok {
|
||
for _, f := range files {
|
||
path := f.(string)
|
||
if err := v.validateFilePath(path); err != nil {
|
||
v.errors = append(v.errors,
|
||
fmt.Errorf("域名文件路径无效 (插件: %s): %w", plugin.Tag, err))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// validateIPSetFiles 验证IP文件
|
||
func (v *ConfigValidator) validateIPSetFiles(plugin PluginConfig) {
|
||
if args, ok := plugin.Args.(map[string]interface{}); ok {
|
||
if files, ok := args["files"].([]interface{}); ok {
|
||
for _, f := range files {
|
||
path := f.(string)
|
||
if err := v.validateFilePath(path); err != nil {
|
||
v.errors = append(v.errors,
|
||
fmt.Errorf("IP文件路径无效 (插件: %s): %w", plugin.Tag, err))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// validateFilePath 验证文件路径
|
||
func (v *ConfigValidator) validateFilePath(path string) error {
|
||
// 检查是否为绝对路径
|
||
if !filepath.IsAbs(path) {
|
||
// 转换为绝对路径
|
||
absPath, err := filepath.Abs(path)
|
||
if err != nil {
|
||
return fmt.Errorf("无法解析路径: %s", path)
|
||
}
|
||
path = absPath
|
||
}
|
||
|
||
// 检查文件是否存在
|
||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||
return fmt.Errorf("文件不存在: %s", path)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// validateConflicts 验证配置冲突
|
||
func (v *ConfigValidator) validateConflicts() {
|
||
tagCount := make(map[string]int)
|
||
|
||
// 检查重复的插件标签
|
||
for _, p := range v.config.Plugins {
|
||
tagCount[p.Tag]++
|
||
if tagCount[p.Tag] > 1 {
|
||
v.errors = append(v.errors,
|
||
fmt.Errorf("发现重复的插件标签: %s", p.Tag))
|
||
}
|
||
}
|
||
|
||
// 检查端口冲突
|
||
apiPort := v.extractPort(v.config.API.HTTP)
|
||
webPort := v.extractPort(v.config.Web.HTTP)
|
||
|
||
if apiPort != "" && webPort != "" && apiPort == webPort {
|
||
v.errors = append(v.errors,
|
||
fmt.Errorf("API端口和Web端口冲突: %s", apiPort))
|
||
}
|
||
}
|
||
|
||
// validateCircularDependencies 验证循环依赖
|
||
func (v *ConfigValidator) validateCircularDependencies() {
|
||
// 转换为utils.PluginConfig
|
||
utilsPlugins := make([]utils.PluginConfig, len(v.config.Plugins))
|
||
for i, p := range v.config.Plugins {
|
||
utilsPlugins[i] = utils.PluginConfig{
|
||
Tag: p.Tag,
|
||
Type: p.Type,
|
||
Args: p.Args,
|
||
}
|
||
}
|
||
|
||
// 使用拓扑排序检测循环依赖
|
||
_, err := utils.TopologicalSort(utilsPlugins)
|
||
if err != nil {
|
||
v.errors = append(v.errors,
|
||
fmt.Errorf("检测到循环依赖: %w", err))
|
||
}
|
||
}
|
||
|
||
// extractPluginDependencies 从插件配置中提取依赖关系
|
||
func (v *ConfigValidator) extractPluginDependencies(plugin PluginConfig) []string {
|
||
var deps []string
|
||
|
||
// 将配置转换为字符串进行正则匹配
|
||
configStr := fmt.Sprintf("%+v", plugin.Args)
|
||
|
||
// 正则表达式匹配 $plugin_name 格式的引用
|
||
re := regexp.MustCompile(`\$([a-zA-Z0-9_-]+)`)
|
||
matches := re.FindAllStringSubmatch(configStr, -1)
|
||
|
||
for _, match := range matches {
|
||
if len(match) > 1 {
|
||
dep := match[1]
|
||
// 排除一些常见的关键字,避免误识别
|
||
if dep != "primary" && dep != "secondary" && dep != "timeout" &&
|
||
dep != "china_ip" && dep != "always_standby" && dep != "verbose" {
|
||
deps = append(deps, dep)
|
||
}
|
||
}
|
||
}
|
||
|
||
return deps
|
||
}
|
||
|
||
// extractPort 从地址中提取端口号
|
||
func (v *ConfigValidator) extractPort(addr string) string {
|
||
if addr == "" {
|
||
return ""
|
||
}
|
||
|
||
// 支持 IPv4:port 和 [IPv6]:port 格式
|
||
re := regexp.MustCompile(`:(\d+)$`)
|
||
matches := re.FindStringSubmatch(addr)
|
||
if len(matches) > 1 {
|
||
return matches[1]
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
// GetValidationResult 获取验证结果
|
||
func (v *ConfigValidator) GetValidationResult() (errors []error, warnings []string) {
|
||
return v.errors, v.warnings
|
||
}
|
||
|
||
// IsValid 检查配置是否有效
|
||
func (v *ConfigValidator) IsValid() bool {
|
||
return len(v.errors) == 0
|
||
}
|