package api import ( "log" "net/http" "sync" "time" "github.com/intodns/backend/internal/checker" "github.com/intodns/backend/internal/store" ) // NewRouter builds the HTTP mux with all routes and middleware. func NewRouter(ch *checker.Checker, st *store.Store) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/dnstest/api/check/stream", CheckStreamHandler(ch, st)) mux.HandleFunc("/dnstest/api/check", CheckHandler(ch, st)) mux.HandleFunc("/dnstest/api/health", HealthHandler()) mux.HandleFunc("/dnstest/api/history", HistoryHandler(st)) mux.HandleFunc("/dnstest/api/report", ReportHandler(st)) // Stack middleware: logging -> CORS -> rate limit -> mux. var handler http.Handler = mux handler = rateLimitMiddleware(handler) handler = corsMiddleware(handler) handler = loggingMiddleware(handler) return handler } // --- CORS middleware --- func corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } // --- Rate limiting middleware (10 req/min per IP) --- type ipEntry struct { count int windowStart time.Time } var ( rateMu sync.Mutex rateMap = make(map[string]*ipEntry) rateOnce sync.Once ) func startRateCleanup() { rateOnce.Do(func() { go func() { for { time.Sleep(1 * time.Minute) rateMu.Lock() now := time.Now() for ip, entry := range rateMap { if now.Sub(entry.windowStart) > 2*time.Minute { delete(rateMap, ip) } } rateMu.Unlock() } }() }) } func rateLimitMiddleware(next http.Handler) http.Handler { startRateCleanup() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := extractIP(r) rateMu.Lock() entry, ok := rateMap[ip] now := time.Now() if !ok || now.Sub(entry.windowStart) > time.Minute { // New window. rateMap[ip] = &ipEntry{count: 1, windowStart: now} rateMu.Unlock() next.ServeHTTP(w, r) return } entry.count++ if entry.count > 10 { rateMu.Unlock() w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", "60") w.WriteHeader(http.StatusTooManyRequests) w.Write([]byte(`{"error":"rate limit exceeded, try again in 1 minute"}`)) return } rateMu.Unlock() next.ServeHTTP(w, r) }) } func extractIP(r *http.Request) string { // Check X-Forwarded-For first. if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // Take the first IP in the chain. for i := 0; i < len(xff); i++ { if xff[i] == ',' { return xff[:i] } } return xff } // Fall back to RemoteAddr (strip port). addr := r.RemoteAddr for i := len(addr) - 1; i >= 0; i-- { if addr[i] == ':' { return addr[:i] } } return addr } // --- Logging middleware --- func loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() lw := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK} next.ServeHTTP(lw, r) log.Printf("%s %s %d %s %s", r.Method, r.URL.Path, lw.statusCode, time.Since(start).Round(time.Millisecond), extractIP(r)) }) } type loggingResponseWriter struct { http.ResponseWriter statusCode int } func (lw *loggingResponseWriter) WriteHeader(code int) { lw.statusCode = code lw.ResponseWriter.WriteHeader(code) } func (lw *loggingResponseWriter) Flush() { if f, ok := lw.ResponseWriter.(http.Flusher); ok { f.Flush() } }