140 lines
4 KiB
Go
140 lines
4 KiB
Go
package handler
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"codeberg.org/VARASYS/ZDDC/zddc/internal/config"
|
|
)
|
|
|
|
func TestCORSMiddleware(t *testing.T) {
|
|
allowed := config.Config{CORSOrigins: []string{"https://zddc.varasys.io", "https://other.example"}}
|
|
disabled := config.Config{CORSOrigins: nil}
|
|
|
|
pass := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("ok"))
|
|
})
|
|
|
|
cases := []struct {
|
|
name string
|
|
cfg config.Config
|
|
method string
|
|
origin string
|
|
acrHeaders string
|
|
wantStatus int
|
|
wantAllowOrig string
|
|
wantAllowCreds string
|
|
wantVary string
|
|
wantAllowHdrs string
|
|
wantNextCalled bool
|
|
}{
|
|
{
|
|
name: "allowed origin, GET",
|
|
cfg: allowed,
|
|
method: http.MethodGet,
|
|
origin: "https://zddc.varasys.io",
|
|
wantStatus: http.StatusOK,
|
|
wantAllowOrig: "https://zddc.varasys.io",
|
|
wantAllowCreds: "true",
|
|
wantVary: "Origin",
|
|
wantNextCalled: true,
|
|
},
|
|
{
|
|
name: "second allowed origin, GET",
|
|
cfg: allowed,
|
|
method: http.MethodGet,
|
|
origin: "https://other.example",
|
|
wantStatus: http.StatusOK,
|
|
wantAllowOrig: "https://other.example",
|
|
wantAllowCreds: "true",
|
|
wantVary: "Origin",
|
|
wantNextCalled: true,
|
|
},
|
|
{
|
|
name: "disallowed origin, GET",
|
|
cfg: allowed,
|
|
method: http.MethodGet,
|
|
origin: "https://evil.example",
|
|
wantStatus: http.StatusOK,
|
|
wantNextCalled: true,
|
|
},
|
|
{
|
|
name: "no Origin header, GET",
|
|
cfg: allowed,
|
|
method: http.MethodGet,
|
|
wantStatus: http.StatusOK,
|
|
wantNextCalled: true,
|
|
},
|
|
{
|
|
name: "OPTIONS preflight, allowed origin",
|
|
cfg: allowed,
|
|
method: http.MethodOptions,
|
|
origin: "https://zddc.varasys.io",
|
|
acrHeaders: "X-Auth-Request-Email, Content-Type",
|
|
wantStatus: http.StatusNoContent,
|
|
wantAllowOrig: "https://zddc.varasys.io",
|
|
wantAllowCreds: "true",
|
|
wantVary: "Origin",
|
|
wantAllowHdrs: "X-Auth-Request-Email, Content-Type",
|
|
wantNextCalled: false,
|
|
},
|
|
{
|
|
name: "OPTIONS preflight, disallowed origin falls through",
|
|
cfg: allowed,
|
|
method: http.MethodOptions,
|
|
origin: "https://evil.example",
|
|
wantStatus: http.StatusOK,
|
|
wantNextCalled: true,
|
|
},
|
|
{
|
|
name: "CORS disabled passes through",
|
|
cfg: disabled,
|
|
method: http.MethodGet,
|
|
origin: "https://zddc.varasys.io",
|
|
wantStatus: http.StatusOK,
|
|
wantNextCalled: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
called := false
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
pass.ServeHTTP(w, r)
|
|
})
|
|
h := CORSMiddleware(tc.cfg, next)
|
|
|
|
req := httptest.NewRequest(tc.method, "/", nil)
|
|
if tc.origin != "" {
|
|
req.Header.Set("Origin", tc.origin)
|
|
}
|
|
if tc.acrHeaders != "" {
|
|
req.Header.Set("Access-Control-Request-Headers", tc.acrHeaders)
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != tc.wantStatus {
|
|
t.Errorf("status = %d, want %d", rec.Code, tc.wantStatus)
|
|
}
|
|
if got := rec.Header().Get("Access-Control-Allow-Origin"); got != tc.wantAllowOrig {
|
|
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, tc.wantAllowOrig)
|
|
}
|
|
if got := rec.Header().Get("Access-Control-Allow-Credentials"); got != tc.wantAllowCreds {
|
|
t.Errorf("Access-Control-Allow-Credentials = %q, want %q", got, tc.wantAllowCreds)
|
|
}
|
|
if got := rec.Header().Get("Vary"); got != tc.wantVary {
|
|
t.Errorf("Vary = %q, want %q", got, tc.wantVary)
|
|
}
|
|
if got := rec.Header().Get("Access-Control-Allow-Headers"); got != tc.wantAllowHdrs {
|
|
t.Errorf("Access-Control-Allow-Headers = %q, want %q", got, tc.wantAllowHdrs)
|
|
}
|
|
if called != tc.wantNextCalled {
|
|
t.Errorf("next called = %v, want %v", called, tc.wantNextCalled)
|
|
}
|
|
})
|
|
}
|
|
}
|