package cache
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"codeberg.org/VARASYS/ZDDC/zddc/internal/config"
)
// newTestCache builds a Cache whose root is a fresh temporary directory
// and whose upstream is a throwaway httptest server driven by
// upstreamHandler. It returns the cache together with that server; the
// server is closed automatically when the test finishes. A failure to
// construct the cache aborts the calling test.
func newTestCache(t *testing.T, mode string, upstreamHandler http.HandlerFunc) (*Cache, *httptest.Server) {
	t.Helper()

	srv := httptest.NewServer(upstreamHandler)
	t.Cleanup(srv.Close)

	cfg := config.Config{
		Root:     t.TempDir(),
		Upstream: srv.URL,
		Mode:     mode,
	}
	cache, err := New(cfg)
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	return cache, srv
}
func TestNew_RequiresUpstream(t *testing.T) {
if _, err := New(config.Config{Root: t.TempDir()}); err == nil {
t.Error("expected error for empty upstream")
}
}
func TestNew_StripsTrailingSlash(t *testing.T) {
c, err := New(config.Config{
Root: t.TempDir(),
Upstream: "http://example.com/",
})
if err != nil {
t.Fatalf("New: %v", err)
}
if got := c.Upstream(); got != "http://example.com" {
t.Errorf("Upstream() = %q, want trailing slash stripped", got)
}
}
func TestNew_BearerFile(t *testing.T) {
dir := t.TempDir()
tokenPath := filepath.Join(dir, "token")
if err := os.WriteFile(tokenPath, []byte(" abc123\n"), 0o600); err != nil {
t.Fatalf("write token: %v", err)
}
c, err := New(config.Config{
Root: t.TempDir(),
Upstream: "http://example.com",
BearerFile: tokenPath,
})
if err != nil {
t.Fatalf("New: %v", err)
}
if c.bearer != "abc123" {
t.Errorf("bearer = %q, want abc123 (whitespace trimmed)", c.bearer)
}
}
// TestNew_BearerFileEmptyRejected verifies that New returns an error when
// BearerFile points at a file that contains only newlines.
func TestNew_BearerFileEmptyRejected(t *testing.T) {
	dir := t.TempDir()
	empty := filepath.Join(dir, "empty")
	// The seed write must succeed: if the file were missing, New could
	// fail for an unrelated reason and this test would pass without
	// exercising the empty-token check.
	if err := os.WriteFile(empty, []byte("\n\n"), 0o600); err != nil {
		t.Fatalf("write empty bearer file: %v", err)
	}

	if _, err := New(config.Config{
		Root:       t.TempDir(),
		Upstream:   "http://example.com",
		BearerFile: empty,
	}); err == nil {
		t.Error("expected error for empty bearer file")
	}
}
// TestServeHTTP_RejectsWriteMethods sends PUT, POST and DELETE through the
// cache and expects each to be answered with 405 and an Allow header of
// "GET, HEAD". The upstream handler flags an error if it is ever invoked.
func TestServeHTTP_RejectsWriteMethods(t *testing.T) {
	failIfCalled := func(w http.ResponseWriter, r *http.Request) {
		t.Errorf("upstream should not be called for write methods")
	}
	cache, _ := newTestCache(t, "cache", failIfCalled)

	writeMethods := []string{http.MethodPut, http.MethodPost, http.MethodDelete}
	for i := range writeMethods {
		method := writeMethods[i]

		resp := httptest.NewRecorder()
		cache.ServeHTTP(resp, httptest.NewRequest(method, "/foo", nil))

		if code := resp.Code; code != http.StatusMethodNotAllowed {
			t.Errorf("%s = %d, want 405", method, code)
		}
		allow := resp.Header().Get("Allow")
		if allow != "GET, HEAD" {
			t.Errorf("%s Allow = %q", method, allow)
		}
	}
}
// TestServeHTTP_MissThenHit requests the same file twice. The first
// response must be tagged "miss" and leave a cached copy under the cache
// root; the second must be tagged "hit" with the same body. Afterwards the
// marker file must exist and mention the upstream.
//
// hits counts upstream requests; it is recorded but not asserted here.
func TestServeHTTP_MissThenHit(t *testing.T) {
	var hits int32
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hits, 1)
		if got := r.URL.Path; got != "/foo.txt" {
			t.Errorf("upstream got %q, want /foo.txt", got)
		}
		hdr := w.Header()
		hdr.Set("Content-Type", "text/plain")
		hdr.Set("Last-Modified", "Mon, 02 Jan 2006 15:04:05 GMT")
		_, _ = w.Write([]byte("hello"))
	})

	// fetch issues one GET for /foo.txt and returns the recorded response.
	fetch := func() *httptest.ResponseRecorder {
		resp := httptest.NewRecorder()
		cache.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/foo.txt", nil))
		return resp
	}

	// Cold cache: the response is expected to be reported as a miss.
	first := fetch()
	if first.Code != http.StatusOK {
		t.Fatalf("first GET = %d", first.Code)
	}
	if state := first.Header().Get(HeaderName); state != "miss" {
		t.Errorf("first cache header = %q, want miss", state)
	}
	if body := first.Body.String(); body != "hello" {
		t.Errorf("body = %q", body)
	}

	// The miss must have left a file behind in the cache root.
	cachedPath := filepath.Join(cache.root, "foo.txt")
	if _, err := os.Stat(cachedPath); err != nil {
		t.Fatalf("expected cached file: %v", err)
	}

	// Warm cache: the identical request is expected to be reported as a hit.
	second := fetch()
	if second.Code != http.StatusOK {
		t.Fatalf("second GET = %d", second.Code)
	}
	if state := second.Header().Get(HeaderName); state != "hit" {
		t.Errorf("second cache header = %q, want hit", state)
	}
	if body := second.Body.String(); body != "hello" {
		t.Errorf("second body = %q", body)
	}

	// The marker file must be readable and reference the upstream.
	markerPath := filepath.Join(cache.root, MarkerFile)
	raw, err := os.ReadFile(markerPath)
	if err != nil {
		t.Fatalf("marker missing: %v", err)
	}
	if text := string(raw); !strings.Contains(text, "upstream:") {
		t.Errorf("marker contents unexpected: %s", text)
	}
}
// TestServeHTTP_ProxyModeDoesNotPersist runs the cache in "proxy" mode and
// expects a 200 tagged "proxy", with neither the requested file nor the
// marker file appearing under the cache root.
func TestServeHTTP_ProxyModeDoesNotPersist(t *testing.T) {
	cache, _ := newTestCache(t, "proxy", func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("payload"))
	})

	resp := httptest.NewRecorder()
	cache.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/foo.txt", nil))

	if resp.Code != http.StatusOK {
		t.Fatalf("status = %d", resp.Code)
	}
	if state := resp.Header().Get(HeaderName); state != "proxy" {
		t.Errorf("cache header = %q, want proxy", state)
	}

	// Nothing may have been persisted for the requested path.
	_, fileErr := os.Stat(filepath.Join(cache.root, "foo.txt"))
	if !os.IsNotExist(fileErr) {
		t.Errorf("proxy mode wrote to cache: %v", fileErr)
	}

	// With no caching performed, the marker must be absent as well.
	_, markerErr := os.Stat(filepath.Join(cache.root, MarkerFile))
	if !os.IsNotExist(markerErr) {
		t.Errorf("marker file written in proxy mode")
	}
}
// TestServeHTTP_DirectoryListingsCachedAsSidecar requests the directory
// URL /Project/ twice with Accept: text/html. The first response must be a
// "miss" and leave an HTML listing sidecar in the Project directory of the
// cache root; the second must be a "hit" carrying the same body.
//
// hits counts upstream requests; it is recorded but not asserted here.
func TestServeHTTP_DirectoryListingsCachedAsSidecar(t *testing.T) {
	var hits int32
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hits, 1)
		w.Header().Set("Content-Type", "text/html")
		_, _ = w.Write([]byte("listing"))
	})

	// list performs one HTML listing request for /Project/.
	list := func() *httptest.ResponseRecorder {
		resp := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/Project/", nil)
		req.Header.Set("Accept", "text/html")
		cache.ServeHTTP(resp, req)
		return resp
	}

	// Cold cache: body is served and the sidecar gets written.
	first := list()
	if first.Code != http.StatusOK {
		t.Fatalf("status = %d", first.Code)
	}
	if state := first.Header().Get(HeaderName); state != "miss" {
		t.Errorf("first cache header = %q, want miss", state)
	}
	sidecarPath := filepath.Join(cache.root, "Project", listingCachePrefix+"html")
	if _, err := os.Stat(sidecarPath); err != nil {
		t.Fatalf("expected listing sidecar: %v", err)
	}

	// Warm cache: same request, now expected to be a hit.
	second := list()
	if state := second.Header().Get(HeaderName); state != "hit" {
		t.Errorf("second cache header = %q, want hit", state)
	}
	if body := second.Body.String(); body != "listing" {
		t.Errorf("body = %q", body)
	}
}
// TestServeHTTP_ListingFormatVariesByAccept requests the same directory
// once as JSON and once as HTML and expects each response body to match
// its format, with a separate sidecar file on disk for each.
func TestServeHTTP_ListingFormatVariesByAccept(t *testing.T) {
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		wantsJSON := strings.Contains(r.Header.Get("Accept"), "application/json")
		if !wantsJSON {
			_, _ = w.Write([]byte("html"))
			return
		}
		_, _ = w.Write([]byte(`[{"name":"foo"}]`))
	})

	// listAs requests /Project/ with the given Accept header and returns
	// the response body.
	listAs := func(accept string) string {
		resp := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/Project/", nil)
		req.Header.Set("Accept", accept)
		cache.ServeHTTP(resp, req)
		return resp.Body.String()
	}

	// JSON first, so the JSON sidecar is produced.
	if body := listAs("application/json"); !strings.Contains(body, "foo") {
		t.Errorf("json body = %q", body)
	}

	// HTML second, which must be stored independently.
	if body := listAs("text/html"); !strings.Contains(body, "html") {
		t.Errorf("html body = %q", body)
	}

	// Each format must have its own sidecar on disk.
	projectDir := filepath.Join(cache.root, "Project")
	if _, err := os.Stat(filepath.Join(projectDir, listingCachePrefix+"json")); err != nil {
		t.Errorf("json sidecar missing: %v", err)
	}
	if _, err := os.Stat(filepath.Join(projectDir, listingCachePrefix+"html")); err != nil {
		t.Errorf("html sidecar missing: %v", err)
	}
}
// TestServeHTTP_ListingOfflineServesStale pre-seeds an HTML listing sidecar
// for /Project/, points the cache at an unreachable upstream, and expects
// the listing request to succeed with the seeded content.
//
// The sidecar is seeded with non-empty, recognizable content so the body
// check is meaningful: an empty seed combined with strings.Contains(body, "")
// would hold for any response body.
func TestServeHTTP_ListingOfflineServesStale(t *testing.T) {
	const staleListing = "<html><body>stale listing</body></html>"

	root := t.TempDir()
	projectDir := filepath.Join(root, "Project")
	if err := os.MkdirAll(projectDir, 0o755); err != nil {
		t.Fatalf("mkdir: %v", err)
	}
	if err := os.WriteFile(filepath.Join(projectDir, listingCachePrefix+"html"), []byte(staleListing), 0o644); err != nil {
		t.Fatalf("seed: %v", err)
	}

	c, err := New(config.Config{Root: root, Upstream: "http://127.0.0.1:1", Mode: "cache"})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	// Keep the client timeout short so an attempt against the unreachable
	// upstream cannot stall the test.
	c.client.Timeout = 200 * time.Millisecond

	rec := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodGet, "/Project/", nil)
	r.Header.Set("Accept", "text/html")
	c.ServeHTTP(rec, r)

	if rec.Code != http.StatusOK {
		t.Fatalf("offline listing = %d, want 200", rec.Code)
	}
	if got := rec.Body.String(); !strings.Contains(got, staleListing) {
		t.Errorf("body = %q, want it to contain %q", got, staleListing)
	}
}
// TestServeHTTP_HEAD_HitDoesNotReturnBody seeds the cache with a GET and
// then issues a HEAD for the same path, expecting a 200 with an empty
// recorded body.
func TestServeHTTP_HEAD_HitDoesNotReturnBody(t *testing.T) {
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("hello"))
	})

	// do sends one request for /foo.txt using the given method.
	do := func(method string) *httptest.ResponseRecorder {
		resp := httptest.NewRecorder()
		cache.ServeHTTP(resp, httptest.NewRequest(method, "/foo.txt", nil))
		return resp
	}

	// Populate the cache first.
	seed := do(http.MethodGet)
	if seed.Code != http.StatusOK {
		t.Fatalf("seed: %d", seed.Code)
	}

	// The HEAD response must succeed without any body bytes.
	head := do(http.MethodHead)
	if head.Code != http.StatusOK {
		t.Fatalf("HEAD: %d", head.Code)
	}
	if n := head.Body.Len(); n != 0 {
		t.Errorf("HEAD body length = %d, want 0", n)
	}
}
func TestServeHTTP_OfflineServesStale(t *testing.T) {
root := t.TempDir()
// Pre-seed a cached file.
if err := os.WriteFile(filepath.Join(root, "stale.txt"), []byte("stale-content"), 0o644); err != nil {
t.Fatalf("seed: %v", err)
}
c, err := New(config.Config{
Root: root,
Upstream: "http://127.0.0.1:1", // unreachable port
Mode: "cache",
})
if err != nil {
t.Fatalf("New: %v", err)
}
// Speed up the timeout so the test doesn't hang.
c.client.Timeout = 200 * time.Millisecond
rec := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/stale.txt", nil)
c.ServeHTTP(rec, r)
if rec.Code != http.StatusOK {
t.Fatalf("offline-with-cache = %d, want 200", rec.Code)
}
if got := rec.Header().Get(HeaderName); got != "hit" {
// On hit we don't even hit the network. That's expected.
t.Logf("first attempt was %q (likely cache hit before any network)", got)
}
if got := rec.Body.String(); got != "stale-content" {
t.Errorf("body = %q", got)
}
}
func TestServeHTTP_OfflineMissReturns503(t *testing.T) {
root := t.TempDir()
c, err := New(config.Config{
Root: root,
Upstream: "http://127.0.0.1:1",
Mode: "cache",
})
if err != nil {
t.Fatalf("New: %v", err)
}
c.client.Timeout = 200 * time.Millisecond
rec := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/never-cached.txt", nil)
c.ServeHTTP(rec, r)
if rec.Code != http.StatusServiceUnavailable {
t.Errorf("offline-no-cache = %d, want 503", rec.Code)
}
if got := rec.Header().Get(HeaderName); got != "offline" {
t.Errorf("cache header = %q, want offline", got)
}
}
// TestServeHTTP_BearerForwarded configures a bearer token file and expects
// the upstream to receive the token as "Authorization: Bearer <token>".
func TestServeHTTP_BearerForwarded(t *testing.T) {
	dir := t.TempDir()
	tokenPath := filepath.Join(dir, "token")
	// Fail loudly if the token cannot be written; otherwise the failure
	// would only surface later as a less specific New or header error.
	if err := os.WriteFile(tokenPath, []byte("secrettoken"), 0o600); err != nil {
		t.Fatalf("write token: %v", err)
	}

	// seenAuth records the Authorization header observed by the upstream.
	var seenAuth string
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seenAuth = r.Header.Get("Authorization")
		_, _ = w.Write([]byte("ok"))
	}))
	defer upstream.Close()

	c, err := New(config.Config{
		Root:       t.TempDir(),
		Upstream:   upstream.URL,
		Mode:       "cache",
		BearerFile: tokenPath,
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	rec := httptest.NewRecorder()
	c.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/foo.txt", nil))

	if seenAuth != "Bearer secrettoken" {
		t.Errorf("Authorization = %q, want Bearer secrettoken", seenAuth)
	}
}
// TestServeHTTP_PreservesQuery expects the request URI that reaches the
// upstream to include the original query string.
func TestServeHTTP_PreservesQuery(t *testing.T) {
	const target = "/foo.txt?q=bar"

	var seenURL string
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		seenURL = r.URL.RequestURI()
		// The upstream marks its JSON reply as no-store.
		w.Header().Set("Cache-Control", "no-store")
		_, _ = w.Write([]byte(`{}`))
	})

	resp := httptest.NewRecorder()
	cache.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, target, nil))

	if seenURL == target {
		return
	}
	t.Errorf("upstream saw %q, want /foo.txt?q=bar", seenURL)
}
// TestServeHTTP_HonorsNoStore has the upstream reply with
// Cache-Control: no-store and expects the response to be tagged "proxy"
// with no file written to the cache root.
func TestServeHTTP_HonorsNoStore(t *testing.T) {
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Cache-Control", "no-store")
		_, _ = w.Write([]byte("ephemeral"))
	})

	resp := httptest.NewRecorder()
	cache.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/dynamic.json", nil))

	if resp.Code != http.StatusOK {
		t.Fatalf("status: %d", resp.Code)
	}
	if state := resp.Header().Get(HeaderName); state != "proxy" {
		t.Errorf("cache header = %q, want proxy (no-store should bypass cache)", state)
	}

	_, statErr := os.Stat(filepath.Join(cache.root, "dynamic.json"))
	if !os.IsNotExist(statErr) {
		t.Errorf("no-store response was cached")
	}
}
// TestServeHTTP_PathTraversalRejected requests "/../etc/passwd" and
// checks that nothing was written at etc/passwd beside the cache root.
// Whether the upstream is contacted is recorded but deliberately not
// asserted.
func TestServeHTTP_PathTraversalRejected(t *testing.T) {
	upstreamCalled := false
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		upstreamCalled = true
		_, _ = w.Write([]byte("data"))
	})

	resp := httptest.NewRecorder()
	cache.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/../etc/passwd", nil))

	// Forwarding to the upstream is tolerated; the only hard requirement
	// is that no cache write lands outside the root.
	_ = upstreamCalled

	escaped := filepath.Join(filepath.Dir(cache.root), "etc", "passwd")
	_, statErr := os.Stat(escaped)
	if !os.IsNotExist(statErr) {
		t.Error("path traversal wrote outside cache root")
	}
}
// TestServeHTTP_ForwardsErrorStatus expects an upstream 403 to be passed
// through to the client and not stored in the cache root.
func TestServeHTTP_ForwardsErrorStatus(t *testing.T) {
	forbid := func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Forbidden", http.StatusForbidden)
	}
	cache, _ := newTestCache(t, "cache", forbid)

	resp := httptest.NewRecorder()
	cache.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/secret.txt", nil))

	if code := resp.Code; code != http.StatusForbidden {
		t.Errorf("status = %d, want 403", code)
	}

	_, statErr := os.Stat(filepath.Join(cache.root, "secret.txt"))
	if !os.IsNotExist(statErr) {
		t.Error("403 response was cached")
	}
}
func TestRevalidate_PurgesOn403(t *testing.T) {
root := t.TempDir()
if err := os.WriteFile(filepath.Join(root, "victim.txt"), []byte("cached"), 0o644); err != nil {
t.Fatalf("seed: %v", err)
}
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer upstream.Close()
c, err := New(config.Config{Root: root, Upstream: upstream.URL, Mode: "cache"})
if err != nil {
t.Fatalf("New: %v", err)
}
c.revalidate("/victim.txt", time.Now())
if _, err := os.Stat(filepath.Join(root, "victim.txt")); !os.IsNotExist(err) {
t.Error("revalidate did not purge after 403")
}
}
func TestRevalidate_PurgesOn404(t *testing.T) {
root := t.TempDir()
if err := os.WriteFile(filepath.Join(root, "gone.txt"), []byte("cached"), 0o644); err != nil {
t.Fatalf("seed: %v", err)
}
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer upstream.Close()
c, err := New(config.Config{Root: root, Upstream: upstream.URL, Mode: "cache"})
if err != nil {
t.Fatalf("New: %v", err)
}
c.revalidate("/gone.txt", time.Now())
if _, err := os.Stat(filepath.Join(root, "gone.txt")); !os.IsNotExist(err) {
t.Error("revalidate did not purge after 404")
}
}
// TestRevalidate_NoPurgeOn200ButRefreshes seeds a cached file with an
// mtime one hour in the past, has the upstream answer 200 with new
// content, and expects revalidate to leave the file in place holding the
// new content.
func TestRevalidate_NoPurgeOn200ButRefreshes(t *testing.T) {
	root := t.TempDir()
	cached := filepath.Join(root, "fresh.txt")
	if err := os.WriteFile(cached, []byte("old-content"), 0o644); err != nil {
		t.Fatalf("seed: %v", err)
	}

	// Backdate the file by an hour. The test's premise depends on this
	// timestamp, so a failure here is fatal rather than ignored.
	hourAgo := time.Now().Add(-time.Hour)
	if err := os.Chtimes(cached, hourAgo, hourAgo); err != nil {
		t.Fatalf("chtimes: %v", err)
	}

	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("new-content"))
	}))
	defer upstream.Close()

	c, err := New(config.Config{Root: root, Upstream: upstream.URL, Mode: "cache"})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	c.revalidate("/fresh.txt", hourAgo)

	// A read error here means the file is gone or unreadable, which is a
	// different failure from "content not refreshed" and is reported as such.
	got, err := os.ReadFile(cached)
	if err != nil {
		t.Fatalf("cached file unreadable after revalidate: %v", err)
	}
	if string(got) != "new-content" {
		t.Errorf("revalidate did not refresh: got %q", string(got))
	}
}
// TestRevalidate_NoOpOn304 seeds a cached file, has the upstream answer
// 304, and expects revalidate to leave the file's content untouched. The
// upstream handler also flags an error if the request lacks an
// If-Modified-Since header.
func TestRevalidate_NoOpOn304(t *testing.T) {
	root := t.TempDir()
	cached := filepath.Join(root, "still.txt")
	if err := os.WriteFile(cached, []byte("original"), 0o644); err != nil {
		t.Fatalf("seed: %v", err)
	}

	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Always answer 304, but flag a request that is not conditional.
		if r.Header.Get("If-Modified-Since") == "" {
			t.Errorf("revalidate did not send If-Modified-Since")
		}
		w.WriteHeader(http.StatusNotModified)
	}))
	defer upstream.Close()

	c, err := New(config.Config{Root: root, Upstream: upstream.URL, Mode: "cache"})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	c.revalidate("/still.txt", time.Now())

	// Distinguish "file removed/unreadable" from "content changed" instead
	// of folding both into a confusing empty-string comparison.
	got, err := os.ReadFile(cached)
	if err != nil {
		t.Fatalf("cached file unreadable after 304: %v", err)
	}
	if string(got) != "original" {
		t.Errorf("304 caused content change: got %q", string(got))
	}
}
// TestRangeRequest_Hit seeds the cache with a ten-byte file and then asks
// for bytes 2-5, expecting a 206 with body "2345" and a Content-Range
// header starting with "bytes 2-5/".
func TestRangeRequest_Hit(t *testing.T) {
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		_, _ = w.Write([]byte("0123456789"))
	})

	// Populate the cache with a plain GET.
	seed := httptest.NewRecorder()
	cache.ServeHTTP(seed, httptest.NewRequest(http.MethodGet, "/data.txt", nil))
	if seed.Code != http.StatusOK {
		t.Fatalf("seed: %d", seed.Code)
	}

	// Ask for a slice of the now-cached file.
	partial := httptest.NewRecorder()
	rangeReq := httptest.NewRequest(http.MethodGet, "/data.txt", nil)
	rangeReq.Header.Set("Range", "bytes=2-5")
	cache.ServeHTTP(partial, rangeReq)

	if partial.Code != http.StatusPartialContent {
		t.Fatalf("range = %d, want 206", partial.Code)
	}
	if body := partial.Body.String(); body != "2345" {
		t.Errorf("range body = %q", body)
	}
	contentRange := partial.Header().Get("Content-Range")
	if !strings.HasPrefix(contentRange, "bytes 2-5/") {
		t.Errorf("Content-Range = %q", contentRange)
	}
}
// TestServeHTTP_ConcurrentRequestsForSameURL fires eight simultaneous GETs
// for the same uncached path. Every response must be a 200 with the
// upstream body, and the cached file must afterwards hold that body.
//
// hits counts upstream requests; it is recorded but not asserted here.
func TestServeHTTP_ConcurrentRequestsForSameURL(t *testing.T) {
	const workers = 8

	var hits int32
	cache, _ := newTestCache(t, "cache", func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hits, 1)
		_, _ = io.WriteString(w, "concurrent")
	})

	// fetchOnce performs a single request and reports mismatches via
	// t.Errorf (only non-fatal reporting is used off the test goroutine).
	var wg sync.WaitGroup
	fetchOnce := func() {
		defer wg.Done()
		resp := httptest.NewRecorder()
		cache.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/c.txt", nil))
		if resp.Code != http.StatusOK {
			t.Errorf("status = %d", resp.Code)
		}
		if body := resp.Body.String(); body != "concurrent" {
			t.Errorf("body = %q", body)
		}
	}

	wg.Add(workers)
	for remaining := workers; remaining > 0; remaining-- {
		go fetchOnce()
	}
	wg.Wait()

	// After all requests settle, the cached file must hold the full body.
	onDisk, err := os.ReadFile(filepath.Join(cache.root, "c.txt"))
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if content := string(onDisk); content != "concurrent" {
		t.Errorf("cached body = %q", content)
	}
}
// TestCachePathFor_Boundaries table-tests cachePathFor's second return
// value: empty and root paths, paths containing "..", and the marker file
// are expected to be rejected, while ordinary file paths are accepted.
func TestCachePathFor_Boundaries(t *testing.T) {
	noop := func(http.ResponseWriter, *http.Request) {}
	cache, _ := newTestCache(t, "cache", noop)

	type boundary struct {
		urlPath   string
		cacheable bool
	}
	table := []boundary{
		{urlPath: "", cacheable: false},
		{urlPath: "/", cacheable: false},
		{urlPath: "/../etc/passwd", cacheable: false},
		{urlPath: "/foo/../bar", cacheable: false},
		{urlPath: "/foo/bar.txt", cacheable: true},
		{urlPath: "/" + MarkerFile, cacheable: false},
		{urlPath: "/Project/foo.txt", cacheable: true},
	}

	for i := range table {
		entry := table[i]
		if _, got := cache.cachePathFor(entry.urlPath); got != entry.cacheable {
			t.Errorf("cachePathFor(%q) ok=%v, want %v", entry.urlPath, got, entry.cacheable)
		}
	}
}