package resolver import ( "fmt" "strings" "time" "github.com/miekg/dns" ) // Resolver wraps miekg/dns with timeout and retry logic. type Resolver struct { Timeout time.Duration Retries int } // NewResolver creates a Resolver with sensible defaults. func NewResolver() *Resolver { return &Resolver{ Timeout: 3 * time.Second, Retries: 1, } } // Query sends a UDP DNS query. If the response is truncated it automatically // retries over TCP. func (r *Resolver) Query(name string, server string, qtype uint16) (*dns.Msg, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(name), qtype) m.RecursionDesired = true c := new(dns.Client) c.Timeout = r.Timeout c.Net = "udp" var resp *dns.Msg var err error for attempt := 0; attempt <= r.Retries; attempt++ { resp, _, err = c.Exchange(m, ensurePort(server)) if err == nil { break } } if err != nil { return nil, fmt.Errorf("udp query %s @%s: %w", dns.TypeToString[qtype], server, err) } // Fall back to TCP on truncation. if resp.Truncated { return r.QueryTCP(name, server, qtype) } return resp, nil } // QueryTCP sends a DNS query over TCP. func (r *Resolver) QueryTCP(name string, server string, qtype uint16) (*dns.Msg, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(name), qtype) m.RecursionDesired = true c := new(dns.Client) c.Timeout = r.Timeout c.Net = "tcp" var resp *dns.Msg var err error for attempt := 0; attempt <= r.Retries; attempt++ { resp, _, err = c.Exchange(m, ensurePort(server)) if err == nil { break } } if err != nil { return nil, fmt.Errorf("tcp query %s @%s: %w", dns.TypeToString[qtype], server, err) } return resp, nil } // QueryNoRecurse sends a UDP query with RD=0 (non-recursive). Falls back to // TCP on truncation. func (r *Resolver) QueryNoRecurse(name string, server string, qtype uint16) (*dns.Msg, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(name), qtype) m.RecursionDesired = false c := new(dns.Client) c.Timeout = r.Timeout c.Net = "udp" var resp *dns.Msg var err error for attempt := 0; attempt <= r.Retries; attempt++ { resp, _, err = c.Exchange(m, ensurePort(server)) if err == nil { break } } if err != nil { return nil, fmt.Errorf("udp query (no recurse) %s @%s: %w", dns.TypeToString[qtype], server, err) } if resp.Truncated { m.RecursionDesired = false c2 := new(dns.Client) c2.Timeout = r.Timeout c2.Net = "tcp" for attempt := 0; attempt <= r.Retries; attempt++ { resp, _, err = c2.Exchange(m, ensurePort(server)) if err == nil { break } } if err != nil { return nil, fmt.Errorf("tcp query (no recurse) %s @%s: %w", dns.TypeToString[qtype], server, err) } } return resp, nil } // QueryEDNS sends a UDP query with EDNS0 buffer size set. func (r *Resolver) QueryEDNS(name string, server string, qtype uint16, bufsize uint16) (*dns.Msg, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(name), qtype) m.RecursionDesired = false m.SetEdns0(bufsize, false) c := new(dns.Client) c.Timeout = r.Timeout c.Net = "udp" var resp *dns.Msg var err error for attempt := 0; attempt <= r.Retries; attempt++ { resp, _, err = c.Exchange(m, ensurePort(server)) if err == nil { break } } if err != nil { return nil, fmt.Errorf("edns query %s @%s: %w", dns.TypeToString[qtype], server, err) } return resp, nil } // QueryVersionBind asks for version.bind TXT in the CH class. func (r *Resolver) QueryVersionBind(server string) (string, error) { m := new(dns.Msg) m.SetQuestion("version.bind.", dns.TypeTXT) m.Question[0].Qclass = dns.ClassCHAOS m.RecursionDesired = false c := new(dns.Client) c.Timeout = r.Timeout c.Net = "udp" var resp *dns.Msg var err error for attempt := 0; attempt <= r.Retries; attempt++ { resp, _, err = c.Exchange(m, ensurePort(server)) if err == nil { break } } if err != nil { return "", err } for _, rr := range resp.Answer { if txt, ok := rr.(*dns.TXT); ok { if len(txt.Txt) > 0 { return txt.Txt[0], nil } } } return "", nil } // QueryAXFR attempts a zone transfer. Returns true if the server allows it. // Uses a short timeout to avoid blocking on unresponsive servers. func (r *Resolver) QueryAXFR(name string, server string) (bool, error) { type axfrResult struct { allowed bool err error } ch := make(chan axfrResult, 1) go func() { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(name), dns.TypeAXFR) tr := new(dns.Transfer) tr.DialTimeout = 3 * time.Second tr.ReadTimeout = 3 * time.Second env, err := tr.In(m, ensurePort(server)) if err != nil { ch <- axfrResult{false, nil} return } for e := range env { if e.Error != nil { ch <- axfrResult{false, nil} return } if len(e.RR) > 0 { ch <- axfrResult{true, nil} return } } ch <- axfrResult{false, nil} }() select { case res := <-ch: return res.allowed, res.err case <-time.After(5 * time.Second): return false, nil } } func ensurePort(server string) string { // Already has port (IPv4:port or [IPv6]:port). if strings.Contains(server, "]:") || (!strings.Contains(server, "[") && strings.Count(server, ":") == 1) { return server } // IPv6 without brackets/port. if strings.Contains(server, ":") && !strings.Contains(server, "[") { return "[" + server + "]:53" } return server + ":53" }