// Package main renders a CoreDNS Corefile for the VPN DNS resolver, keeps a
// hosts-format override file in sync with the gateway's service DNS API, and
// runs /coredns in the foreground until it exits.
package main

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"
	"unicode"
)

const (
	configDir     = "/tmp/nexavpn-vpn-dns"
	corefilePath  = configDir + "/Corefile"
	overridesPath = configDir + "/service-overrides.hosts"
)

// syncInterval is how often the service overrides are re-fetched.
const syncInterval = 15 * time.Second

// syncTimeout bounds one sync round-trip (connect, headers and body). It is
// shorter than syncInterval so a hung endpoint cannot stall the sync loop.
const syncTimeout = 10 * time.Second

// corefileTemplate is filled with the listen address, the overrides hosts file
// path and the space-separated upstream resolvers, in that order. Each
// directive must stay on its own line; `forward` takes its upstreams on the
// same line.
const corefileTemplate = `%s {
    errors
    hosts %s {
        ttl 30
        reload 15s
        fallthrough
    }
    forward . %s
    cache 30
}
`

// dnsResponse is the JSON body returned by the service DNS sync endpoint.
type dnsResponse struct {
	Records []dnsRecord `json:"records"`
}

// dnsRecord maps one domain name to the IP address it should resolve to.
type dnsRecord struct {
	Domain   string `json:"domain"`
	TargetIP string `json:"target_ip"`
}

// main prepares the CoreDNS configuration, performs an initial override sync,
// keeps syncing in the background and then runs CoreDNS in the foreground.
// Any CoreDNS exit terminates the process.
func main() {
	ctx := context.Background()

	if err := os.MkdirAll(configDir, 0o755); err != nil {
		log.Fatalf("unable to create coredns config dir: %v", err)
	}
	if err := writeCorefile(); err != nil {
		log.Fatalf("unable to write Corefile: %v", err)
	}

	// A failed initial sync is only logged; the ticker below keeps retrying.
	if err := refreshOverrides(ctx); err != nil {
		log.Printf("initial dns override sync failed: %v", err)
	}

	// The sync loop lives for the whole process; it ends when main exits.
	go func() {
		ticker := time.NewTicker(syncInterval)
		defer ticker.Stop()
		for range ticker.C {
			if err := refreshOverrides(ctx); err != nil {
				log.Printf("dns override sync failed: %v", err)
			}
		}
	}()

	cmd := exec.Command("/coredns", "-conf", corefilePath)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	log.Printf("starting coredns on %s", envOrDefault("NEXAVPN_VPN_DNS_ADDR", ":53"))
	if err := cmd.Run(); err != nil {
		log.Fatalf("coredns exited: %v", err)
	}
}

// writeCorefile writes the CoreDNS configuration to corefilePath.
//
// The listen address comes from NEXAVPN_VPN_DNS_ADDR (default ":53"). The
// upstream resolvers come from the comma-separated NEXAVPN_VPN_DNS_UPSTREAMS
// (default "172.16.0.100,172.16.0.105"). It returns an error if the upstream
// list is empty after parsing.
func writeCorefile() error {
	upstreams := parseList(envOrDefault("NEXAVPN_VPN_DNS_UPSTREAMS", "172.16.0.100,172.16.0.105"))
	if len(upstreams) == 0 {
		return errors.New("no upstream dns servers configured")
	}
	corefile := renderCorefile(envOrDefault("NEXAVPN_VPN_DNS_ADDR", ":53"), upstreams)
	return os.WriteFile(corefilePath, []byte(corefile), 0o644)
}

// renderCorefile returns the Corefile text for the given listen address and
// upstream resolvers. The server block consults the overrides hosts file first
// and falls through to the upstreams.
func renderCorefile(addr string, upstreams []string) string {
	return fmt.Sprintf(corefileTemplate, addr, overridesPath, strings.Join(upstreams, " "))
}

// refreshOverrides fetches the current service DNS records and replaces the
// hosts override file with them.
//
// On any fetch or decode error the existing file is left untouched and the
// error is returned.
func refreshOverrides(ctx context.Context) error {
	records, err := fetchRecords(ctx)
	if err != nil {
		return err
	}
	content := renderOverrides(records)

	// Only touch the file when the override set actually changed.
	if current, err := os.ReadFile(overridesPath); err == nil && string(current) == content {
		return nil
	}
	if err := writeFileAtomic(filepath.Clean(overridesPath), []byte(content), 0o644); err != nil {
		return fmt.Errorf("writing %s: %w", overridesPath, err)
	}
	return nil
}

// fetchRecords GETs the service DNS records from NEXAVPN_DNS_SYNC_URL,
// authenticating with GATEWAY_BOOTSTRAP_TOKEN.
//
// The whole exchange, including reading the body, is bounded by syncTimeout.
// Non-2xx responses are reported as errors.
func fetchRecords(ctx context.Context) ([]dnsRecord, error) {
	syncURL := strings.TrimRight(envOrDefault("NEXAVPN_DNS_SYNC_URL", "http://127.0.0.1:8080/api/v1/gateway-agent/dns/services"), "/")

	// http.DefaultClient has no timeout of its own; without this a stuck
	// endpoint would block the caller forever.
	ctx, cancel := context.WithTimeout(ctx, syncTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, syncURL, nil)
	if err != nil {
		return nil, fmt.Errorf("building service dns sync request: %w", err)
	}
	req.Header.Set("X-Gateway-Bootstrap-Token", os.Getenv("GATEWAY_BOOTSTRAP_TOKEN"))

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("service dns sync failed with status %s", resp.Status)
	}

	var payload dnsResponse
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return nil, fmt.Errorf("decoding service dns sync response: %w", err)
	}
	return payload.Records, nil
}

// renderOverrides converts records into hosts-file lines of the form
// "<ip> <domain>\n".
//
// Domains are normalized with normalizeDomain. Records with an empty field are
// dropped silently. Records whose IP does not parse, or whose domain contains
// whitespace or '#', are dropped with a log line, because writing them
// verbatim could split into extra hosts entries or be cut off as a comment.
// An empty or nil input yields "".
func renderOverrides(records []dnsRecord) string {
	var b strings.Builder
	for _, record := range records {
		domain := normalizeDomain(record.Domain)
		targetIP := strings.TrimSpace(record.TargetIP)
		if domain == "" || targetIP == "" {
			continue
		}
		if net.ParseIP(targetIP) == nil || !isHostsName(domain) {
			log.Printf("skipping invalid dns override %q -> %q", domain, targetIP)
			continue
		}
		b.WriteString(targetIP)
		b.WriteByte(' ')
		b.WriteString(domain)
		b.WriteByte('\n')
	}
	return b.String()
}

// isHostsName reports whether name can be written as a single hosts-file
// field: it must not contain whitespace (a field separator) or '#' (which
// starts a comment).
func isHostsName(name string) bool {
	return strings.IndexFunc(name, func(r rune) bool {
		return unicode.IsSpace(r) || r == '#'
	}) < 0
}

// writeFileAtomic replaces path with data without exposing a truncated or
// half-written file to concurrent readers.
//
// It writes a temporary file in the same directory, applies perm, and renames
// it over path.
func writeFileAtomic(path string, data []byte, perm os.FileMode) error {
	tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	// Cleans up on failure paths. After a successful rename the name no longer
	// exists, so this Remove fails harmlessly and its error is intentionally
	// ignored.
	defer os.Remove(tmpName)

	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		return err
	}
	// CreateTemp uses 0600; restore the permissions the file is expected to have.
	if err := tmp.Chmod(perm); err != nil {
		tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}

// parseList splits a comma-separated string into trimmed, non-empty values.
// It drops duplicates while keeping first-occurrence order, and always returns
// a non-nil slice.
func parseList(raw string) []string {
	seen := make(map[string]struct{})
	values := make([]string, 0)
	for _, part := range strings.Split(raw, ",") {
		value := strings.TrimSpace(part)
		if value == "" {
			continue
		}
		if _, ok := seen[value]; ok {
			continue
		}
		seen[value] = struct{}{}
		values = append(values, value)
	}
	return values
}

// normalizeDomain lower-cases and trims value, then removes a single trailing
// dot so fully-qualified and relative spellings compare equal.
func normalizeDomain(value string) string {
	value = strings.TrimSpace(strings.ToLower(value))
	value = strings.TrimSuffix(value, ".")
	return value
}

// envOrDefault returns the whitespace-trimmed value of the environment
// variable key. If that value is empty or unset, it returns fallback.
func envOrDefault(key string, fallback string) string {
	if value := strings.TrimSpace(os.Getenv(key)); value != "" {
		return value
	}
	return fallback
}