package handler

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"codeberg.org/VARASYS/ZDDC/zddc/internal/config"
)

// TestCORSMiddleware drives CORSMiddleware through a table of requests and
// checks, for each one, the response status, the CORS-related response
// headers (Access-Control-Allow-Origin, Access-Control-Allow-Credentials,
// Vary, Access-Control-Allow-Headers) and whether the wrapped handler ran.
//
// Header expectations left at their zero value ("") assert that the
// corresponding header is absent from (or empty in) the response.
func TestCORSMiddleware(t *testing.T) {
	// Two configurations: one with an explicit origin allow-list and one
	// with no origins configured at all.
	allowed := config.Config{CORSOrigins: []string{"https://zddc.varasys.io", "https://other.example"}}
	disabled := config.Config{CORSOrigins: nil}

	// pass is the terminal handler; it always answers 200 "ok" so that a
	// 200 status in the recorder means the request reached it.
	pass := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		// The write targets an in-memory recorder; its error is irrelevant here.
		_, _ = w.Write([]byte("ok"))
	})

	cases := []struct {
		name           string
		cfg            config.Config
		method         string
		origin         string // value of the Origin request header; "" means not sent
		acrHeaders     string // value of Access-Control-Request-Headers; "" means not sent
		wantStatus     int
		wantAllowOrig  string
		wantAllowCreds string
		wantVary       string
		wantAllowHdrs  string
		wantNextCalled bool
	}{
		// Listed origins get the full set of CORS headers and reach the handler.
		{
			name:           "allowed origin, GET",
			cfg:            allowed,
			method:         http.MethodGet,
			origin:         "https://zddc.varasys.io",
			wantStatus:     http.StatusOK,
			wantAllowOrig:  "https://zddc.varasys.io",
			wantAllowCreds: "true",
			wantVary:       "Origin",
			wantNextCalled: true,
		},
		{
			name:           "second allowed origin, GET",
			cfg:            allowed,
			method:         http.MethodGet,
			origin:         "https://other.example",
			wantStatus:     http.StatusOK,
			wantAllowOrig:  "https://other.example",
			wantAllowCreds: "true",
			wantVary:       "Origin",
			wantNextCalled: true,
		},
		// Unlisted or missing origins reach the handler without CORS headers.
		{
			name:           "disallowed origin, GET",
			cfg:            allowed,
			method:         http.MethodGet,
			origin:         "https://evil.example",
			wantStatus:     http.StatusOK,
			wantNextCalled: true,
		},
		{
			name:           "no Origin header, GET",
			cfg:            allowed,
			method:         http.MethodGet,
			wantStatus:     http.StatusOK,
			wantNextCalled: true,
		},
		// A preflight from a listed origin is answered with 204 by the
		// middleware itself, echoing the requested headers; the wrapped
		// handler must not run.
		{
			name:           "OPTIONS preflight, allowed origin",
			cfg:            allowed,
			method:         http.MethodOptions,
			origin:         "https://zddc.varasys.io",
			acrHeaders:     "X-Auth-Request-Email, Content-Type",
			wantStatus:     http.StatusNoContent,
			wantAllowOrig:  "https://zddc.varasys.io",
			wantAllowCreds: "true",
			wantVary:       "Origin",
			wantAllowHdrs:  "X-Auth-Request-Email, Content-Type",
			wantNextCalled: false,
		},
		// A preflight from an unlisted origin is not intercepted.
		{
			name:           "OPTIONS preflight, disallowed origin falls through",
			cfg:            allowed,
			method:         http.MethodOptions,
			origin:         "https://evil.example",
			wantStatus:     http.StatusOK,
			wantNextCalled: true,
		},
		// With no origins configured, even an otherwise valid origin gets
		// no CORS headers.
		{
			name:           "CORS disabled passes through",
			cfg:            disabled,
			method:         http.MethodGet,
			origin:         "https://zddc.varasys.io",
			wantStatus:     http.StatusOK,
			wantNextCalled: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Wrap pass so the test can observe whether the middleware
			// forwarded the request.
			called := false
			next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				pass.ServeHTTP(w, r)
			})

			h := CORSMiddleware(tc.cfg, next)

			req := httptest.NewRequest(tc.method, "/", nil)
			if tc.origin != "" {
				req.Header.Set("Origin", tc.origin)
			}
			if tc.acrHeaders != "" {
				req.Header.Set("Access-Control-Request-Headers", tc.acrHeaders)
			}

			rec := httptest.NewRecorder()
			h.ServeHTTP(rec, req)

			if rec.Code != tc.wantStatus {
				t.Errorf("status = %d, want %d", rec.Code, tc.wantStatus)
			}
			if got := rec.Header().Get("Access-Control-Allow-Origin"); got != tc.wantAllowOrig {
				t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, tc.wantAllowOrig)
			}
			if got := rec.Header().Get("Access-Control-Allow-Credentials"); got != tc.wantAllowCreds {
				t.Errorf("Access-Control-Allow-Credentials = %q, want %q", got, tc.wantAllowCreds)
			}
			if got := rec.Header().Get("Vary"); got != tc.wantVary {
				t.Errorf("Vary = %q, want %q", got, tc.wantVary)
			}
			if got := rec.Header().Get("Access-Control-Allow-Headers"); got != tc.wantAllowHdrs {
				t.Errorf("Access-Control-Allow-Headers = %q, want %q", got, tc.wantAllowHdrs)
			}
			if called != tc.wantNextCalled {
				t.Errorf("next called = %v, want %v", called, tc.wantNextCalled)
			}
		})
	}
}