package checker

import (
	"fmt"
	"net"
	"strings"
	"time"

	"github.com/intodns/backend/internal/resolver"
	"github.com/miekg/dns"
)

// checkMX runs the 11 MX checks for domain and returns them as the "mx"
// category.
//
// MX records are fetched through the resolver at 8.8.8.8. If that lookup
// fails, the category contains a single informational "mx-present" result
// describing the failure. If the domain has no MX records, only "mx-present"
// is reported; otherwise all eleven checks run in the order listed below.
func checkMX(domain string, r *resolver.Resolver) Category {
	cat := Category{Name: "mx", Title: "Mail (MX)"}
	domain = dns.Fqdn(domain)

	// Get MX records.
	resp, err := r.Query(domain, "8.8.8.8", dns.TypeMX)
	if err != nil || resp == nil {
		// The resolver may return (nil, nil); don't print "<nil>" as the reason.
		reason := "no response"
		if err != nil {
			reason = err.Error()
		}
		cat.Checks = append(cat.Checks, CheckResult{
			ID:      "mx-present",
			Title:   "MX Records Present",
			Status:  StatusInfo,
			Message: fmt.Sprintf("Failed to query MX records: %s", reason),
		})
		return cat
	}

	var mxRecords []*dns.MX
	for _, rr := range resp.Answer {
		if mx, ok := rr.(*dns.MX); ok {
			mxRecords = append(mxRecords, mx)
		}
	}

	// 1. mx-present
	cat.Checks = append(cat.Checks, checkMXPresent(mxRecords))
	if len(mxRecords) == 0 {
		return cat
	}

	// 2. mx-reachable
	cat.Checks = append(cat.Checks, checkMXReachable(mxRecords, r))
	// 3. mx-no-cname
	cat.Checks = append(cat.Checks, checkMXNoCNAME(mxRecords, r))
	// 4. mx-no-ip
	cat.Checks = append(cat.Checks, checkMXNoIP(mxRecords))
	// 5. mx-priority
	cat.Checks = append(cat.Checks, checkMXPriority(mxRecords))
	// 6. mx-reverse-dns
	cat.Checks = append(cat.Checks, checkMXReverseDNS(mxRecords, r))
	// 7. mx-public-ip
	cat.Checks = append(cat.Checks, checkMXPublicIP(mxRecords, r))
	// 8. mx-consistent
	cat.Checks = append(cat.Checks, checkMXConsistent(domain, r))
	// 9. mx-a-records
	cat.Checks = append(cat.Checks, checkMXARecords(mxRecords, r))
	// 10. mx-aaaa-records
	cat.Checks = append(cat.Checks, checkMXAAAARecords(mxRecords, r))
	// 11. mx-localhost
	cat.Checks = append(cat.Checks, checkMXLocalhost(mxRecords, r))

	return cat
}

// checkMXPresent reports whether any MX records were found and lists each
// one with its preference. An empty set is only informational, since the
// domain may simply not handle email.
//
// res is a named result in this and every other check below: a deferred
// assignment to an unnamed local is made after the return value has already
// been copied, which would leave DurationMs permanently zero.
func checkMXPresent(mxRecords []*dns.MX) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-present", Title: "MX Records Present"}
	defer func() { res.DurationMs = measureDuration(start) }()

	if len(mxRecords) == 0 {
		res.Status = StatusInfo
		res.Message = "No MX records found (domain may not handle email)"
		return res
	}

	res.Status = StatusPass
	res.Message = fmt.Sprintf("%d MX records found", len(mxRecords))
	for _, mx := range mxRecords {
		res.Details = append(res.Details, fmt.Sprintf("Priority %d: %s", mx.Preference, mx.Mx))
	}
	return res
}

// checkMXReachable resolves every MX host to IP addresses (via resolveNS)
// and fails if any host yields none. Each host's addresses go into Details.
func checkMXReachable(mxRecords []*dns.MX, r *resolver.Resolver) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-reachable", Title: "MX Reachability"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	allResolvable := true
	for _, mx := range mxRecords {
		ips := resolveNS(mx.Mx, r) // reuse the name resolution helper
		if len(ips) == 0 {
			allResolvable = false
			res.Details = append(res.Details, fmt.Sprintf("%s: no IP addresses found", mx.Mx))
			continue
		}
		res.Details = append(res.Details, fmt.Sprintf("%s: %s", mx.Mx, strings.Join(ips, ", ")))
	}

	if allResolvable {
		res.Status = StatusPass
		res.Message = "All MX hosts resolve to IP addresses"
	} else {
		res.Status = StatusFail
		res.Message = "Some MX hosts do not resolve"
	}
	return res
}

// checkMXNoCNAME fails if any MX target answers a CNAME query with a CNAME
// record (MX targets must not be aliases, RFC 2181). Lookups that fail are
// skipped rather than treated as violations.
func checkMXNoCNAME(mxRecords []*dns.MX, r *resolver.Resolver) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-no-cname", Title: "MX No CNAME"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	hasCNAME := false
	for _, mx := range mxRecords {
		resp, err := r.Query(mx.Mx, "8.8.8.8", dns.TypeCNAME)
		if err != nil || resp == nil {
			// Best effort: a failed lookup is not evidence of a CNAME.
			continue
		}
		for _, rr := range resp.Answer {
			if cname, ok := rr.(*dns.CNAME); ok {
				hasCNAME = true
				res.Details = append(res.Details, fmt.Sprintf("%s is a CNAME to %s (RFC 2181 violation)", mx.Mx, cname.Target))
			}
		}
	}

	if hasCNAME {
		res.Status = StatusFail
		res.Message = "Some MX records point to CNAMEs (violates RFC 2181)"
	} else {
		res.Status = StatusPass
		res.Message = "No MX records point to CNAMEs"
	}
	return res
}

// checkMXNoIP fails if any MX target (with its trailing root dot removed)
// parses as an IP address instead of being a hostname.
func checkMXNoIP(mxRecords []*dns.MX) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-no-ip", Title: "MX No IP Literal"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	hasIPLiteral := false
	for _, mx := range mxRecords {
		name := strings.TrimSuffix(mx.Mx, ".")
		if net.ParseIP(name) != nil {
			hasIPLiteral = true
			res.Details = append(res.Details, fmt.Sprintf("%s is an IP literal", name))
		}
	}

	if hasIPLiteral {
		res.Status = StatusFail
		res.Message = "MX records contain IP literals (must be hostnames)"
	} else {
		res.Status = StatusPass
		res.Message = "No MX records contain IP literals"
	}
	return res
}

// checkMXPriority reports how many distinct preference values the MX set
// uses. The result is informational unless at least two different
// preferences exist, which passes.
func checkMXPriority(mxRecords []*dns.MX) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-priority", Title: "MX Priority Diversity"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	priorities := make(map[uint16]struct{}, len(mxRecords))
	for _, mx := range mxRecords {
		priorities[mx.Preference] = struct{}{}
		res.Details = append(res.Details, fmt.Sprintf("Priority %d: %s", mx.Preference, mx.Mx))
	}

	switch {
	case len(mxRecords) == 1:
		res.Status = StatusInfo
		res.Message = "Only one MX record; no priority diversity needed"
	case len(priorities) >= 2:
		res.Status = StatusPass
		res.Message = fmt.Sprintf("MX records use %d different priority levels for redundancy", len(priorities))
	default:
		res.Status = StatusInfo
		res.Message = "All MX records share the same priority (round-robin)"
	}
	return res
}

// checkMXReverseDNS looks up a PTR record for every IP address of every MX
// host and warns if any lookup fails or returns no PTR. IPs for which
// reverseDNS yields no query name are skipped and do not count as checked,
// so a run that looked nothing up reports a warning instead of a pass.
func checkMXReverseDNS(mxRecords []*dns.MX, r *resolver.Resolver) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-reverse-dns", Title: "MX Reverse DNS (PTR)"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	allHavePTR := true
	checked := 0
	for _, mx := range mxRecords {
		for _, ip := range resolveNS(mx.Mx, r) {
			ptrName := reverseDNS(ip)
			if ptrName == "" {
				// No reverse query name for this IP, so nothing is looked up;
				// don't count it as checked.
				continue
			}
			checked++

			resp, err := r.Query(ptrName, "8.8.8.8", dns.TypePTR)
			if err != nil || resp == nil {
				allHavePTR = false
				res.Details = append(res.Details, fmt.Sprintf("%s (%s): PTR lookup failed", mx.Mx, ip))
				continue
			}

			found := false
			for _, rr := range resp.Answer {
				if ptr, ok := rr.(*dns.PTR); ok {
					found = true
					res.Details = append(res.Details, fmt.Sprintf("%s (%s): PTR -> %s", mx.Mx, ip, ptr.Ptr))
				}
			}
			if !found {
				allHavePTR = false
				res.Details = append(res.Details, fmt.Sprintf("%s (%s): no PTR record", mx.Mx, ip))
			}
		}
	}

	switch {
	case checked == 0:
		res.Status = StatusWarn
		res.Message = "No MX IPs to check for reverse DNS"
	case allHavePTR:
		res.Status = StatusPass
		res.Message = "All MX IPs have reverse DNS (PTR) records"
	default:
		res.Status = StatusWarn
		res.Message = "Some MX IPs lack reverse DNS (PTR) records"
	}
	return res
}

// checkMXPublicIP fails if any address of any MX host is rejected by
// isPublicIP, and warns if no addresses were found at all.
func checkMXPublicIP(mxRecords []*dns.MX, r *resolver.Resolver) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-public-ip", Title: "MX Public IPs"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	allPublic := true
	checked := 0
	for _, mx := range mxRecords {
		for _, ipStr := range resolveNS(mx.Mx, r) {
			checked++
			if isPublicIP(net.ParseIP(ipStr)) {
				res.Details = append(res.Details, fmt.Sprintf("%s (%s): public", mx.Mx, ipStr))
			} else {
				allPublic = false
				res.Details = append(res.Details, fmt.Sprintf("%s (%s): PRIVATE/RESERVED", mx.Mx, ipStr))
			}
		}
	}

	switch {
	case checked == 0:
		res.Status = StatusWarn
		res.Message = "No MX IPs to check"
	case allPublic:
		res.Status = StatusPass
		res.Message = "All MX IPs are publicly routable"
	default:
		res.Status = StatusFail
		res.Message = "Some MX IPs are not publicly routable"
	}
	return res
}

// checkMXConsistent asks each of the domain's nameservers directly (with a
// non-recursive query) for the MX set and fails if the answers differ.
//
// The NS list comes from the resolver at 8.8.8.8. For each nameserver, its
// addresses are tried in order and only the first one that answers is used.
// MX sets are compared as sorted, lower-cased "preference:host" strings. An
// empty MX set is a legitimate answer and is compared like any other. If
// fewer than two nameservers answer, nothing can be compared and the check
// warns instead of passing.
func checkMXConsistent(domain string, r *resolver.Resolver) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-consistent", Title: "MX Consistency"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	// Get NS list.
	nsResp, err := r.Query(domain, "8.8.8.8", dns.TypeNS)
	if err != nil || nsResp == nil {
		res.Status = StatusWarn
		res.Message = "Could not retrieve NS for consistency check"
		return res
	}
	var nsNames []string
	for _, rr := range nsResp.Answer {
		if ns, ok := rr.(*dns.NS); ok {
			nsNames = appendUniqLower(nsNames, ns.Ns)
		}
	}
	if len(nsNames) < 2 {
		res.Status = StatusInfo
		res.Message = "Fewer than 2 NS; consistency check skipped"
		return res
	}

	var referenceSet string
	// Tracked separately because "" (an empty MX set) is a valid reference.
	haveReference := false
	allSame := true
	answered := 0
	for _, ns := range nsNames {
		gotAnswer := false
		for _, ip := range resolveNS(ns, r) {
			resp, err := r.QueryNoRecurse(domain, ip, dns.TypeMX)
			if err != nil || resp == nil {
				continue // try this nameserver's next address
			}

			var mxNames []string
			for _, rr := range resp.Answer {
				if mx, ok := rr.(*dns.MX); ok {
					mxNames = append(mxNames, fmt.Sprintf("%d:%s", mx.Preference, strings.ToLower(mx.Mx)))
				}
			}
			sorted := sortedStrings(mxNames)
			setStr := strings.Join(sorted, ",")
			res.Details = append(res.Details, fmt.Sprintf("%s: %s", ns, strings.Join(sorted, " ")))

			if !haveReference {
				referenceSet = setStr
				haveReference = true
			} else if setStr != referenceSet {
				allSame = false
			}
			gotAnswer = true
			break // one answer per nameserver is enough
		}
		if gotAnswer {
			answered++
		} else {
			res.Details = append(res.Details, fmt.Sprintf("%s: no answer (unresolvable or query failed)", ns))
		}
	}

	switch {
	case answered < 2:
		res.Status = StatusWarn
		res.Message = "Fewer than 2 nameservers answered; MX consistency not verified"
	case allSame:
		res.Status = StatusPass
		res.Message = "All nameservers return the same MX set"
	default:
		res.Status = StatusFail
		res.Message = "Nameservers return different MX sets"
	}
	return res
}

// checkMXARecords queries the A records of every MX host through 8.8.8.8
// and warns if any lookup fails or returns no A record.
func checkMXARecords(mxRecords []*dns.MX, r *resolver.Resolver) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-a-records", Title: "MX A Records"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	allHaveA := true
	for _, mx := range mxRecords {
		resp, err := r.Query(mx.Mx, "8.8.8.8", dns.TypeA)
		if err != nil || resp == nil {
			allHaveA = false
			res.Details = append(res.Details, fmt.Sprintf("%s: error resolving A record", mx.Mx))
			continue
		}
		found := false
		for _, rr := range resp.Answer {
			if a, ok := rr.(*dns.A); ok {
				found = true
				res.Details = append(res.Details, fmt.Sprintf("%s: %s", mx.Mx, a.A.String()))
			}
		}
		if !found {
			allHaveA = false
			res.Details = append(res.Details, fmt.Sprintf("%s: no A record", mx.Mx))
		}
	}

	if allHaveA {
		res.Status = StatusPass
		res.Message = "All MX hosts have A records"
	} else {
		res.Status = StatusWarn
		res.Message = "Some MX hosts lack A records"
	}
	return res
}

// checkMXAAAARecords queries the AAAA records of every MX host through
// 8.8.8.8. It is purely informational: the status is StatusInfo whether or
// not any IPv6 address is found.
func checkMXAAAARecords(mxRecords []*dns.MX, r *resolver.Resolver) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-aaaa-records", Title: "MX AAAA Records"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	anyAAAA := false
	for _, mx := range mxRecords {
		resp, err := r.Query(mx.Mx, "8.8.8.8", dns.TypeAAAA)
		if err != nil || resp == nil {
			res.Details = append(res.Details, fmt.Sprintf("%s: error resolving", mx.Mx))
			continue
		}
		found := false
		for _, rr := range resp.Answer {
			if aaaa, ok := rr.(*dns.AAAA); ok {
				found = true
				anyAAAA = true
				res.Details = append(res.Details, fmt.Sprintf("%s: %s", mx.Mx, aaaa.AAAA.String()))
			}
		}
		if !found {
			res.Details = append(res.Details, fmt.Sprintf("%s: no AAAA record", mx.Mx))
		}
	}

	res.Status = StatusInfo
	if anyAAAA {
		res.Message = "Some MX hosts have AAAA records (IPv6 capable)"
	} else {
		res.Message = "No MX hosts have AAAA records"
	}
	return res
}

// checkMXLocalhost fails if any MX target is literally "localhost" (case-
// insensitive, trailing dot ignored) or resolves to a loopback address.
func checkMXLocalhost(mxRecords []*dns.MX, r *resolver.Resolver) (res CheckResult) {
	start := time.Now()
	res = CheckResult{ID: "mx-localhost", Title: "MX Not Localhost"}
	defer func() { res.DurationMs = measureDuration(start) }() // res is a named result

	hasLocalhost := false
	for _, mx := range mxRecords {
		name := strings.ToLower(strings.TrimSuffix(mx.Mx, "."))
		if name == "localhost" {
			hasLocalhost = true
			res.Details = append(res.Details, fmt.Sprintf("MX %s points to localhost", mx.Mx))
			continue
		}
		// Also check if any resolved IP is loopback.
		for _, ipStr := range resolveNS(mx.Mx, r) {
			ip := net.ParseIP(ipStr)
			if ip != nil && ip.IsLoopback() {
				hasLocalhost = true
				res.Details = append(res.Details, fmt.Sprintf("MX %s resolves to loopback %s", mx.Mx, ipStr))
			}
		}
	}

	if hasLocalhost {
		res.Status = StatusFail
		res.Message = "MX record points to localhost"
	} else {
		res.Status = StatusPass
		res.Message = "No MX records point to localhost"
	}
	return res
}