package agentidp import ( "context" "encoding/json" "net/http" "net/http/httptest" "sync" "testing" "time" ) func newTokenServer(t *testing.T, statusCode int, body interface{}) *httptest.Server { t.Helper() return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost || r.URL.Path != "/api/v1/token" { t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) _ = json.NewEncoder(w).Encode(body) })) } var tokenResp = map[string]interface{}{ "access_token": "eyJ.abc.def", "token_type": "Bearer", "expires_in": 3600, "scope": "agents:read", } func TestTokenManager_GetToken_Issues(t *testing.T) { srv := newTokenServer(t, 200, tokenResp) defer srv.Close() tm := NewTokenManager(srv.URL, "client-id", "secret", "agents:read") tok, err := tm.GetToken(context.Background()) if err != nil { t.Fatalf("unexpected error: %v", err) } if tok != "eyJ.abc.def" { t.Errorf("expected token eyJ.abc.def, got %q", tok) } } func TestTokenManager_GetToken_Caches(t *testing.T) { callCount := 0 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(tokenResp) })) defer srv.Close() tm := NewTokenManager(srv.URL, "client-id", "secret", "agents:read") _, _ = tm.GetToken(context.Background()) _, _ = tm.GetToken(context.Background()) if callCount != 1 { t.Errorf("expected 1 HTTP call (cached), got %d", callCount) } } func TestTokenManager_GetToken_RefreshesNearExpiry(t *testing.T) { callCount := 0 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ resp := map[string]interface{}{ "access_token": "eyJ.abc.def", "token_type": "Bearer", "expires_in": 3600, "scope": "agents:read", } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() tm := NewTokenManager(srv.URL, "client-id", "secret", "agents:read") _, _ = tm.GetToken(context.Background()) // Force the cached token to appear nearly expired tm.mu.Lock() tm.cached = &cachedToken{ accessToken: "old-token", expiresAt: time.Now().Add(30 * time.Second), // < refreshBufferSeconds } tm.mu.Unlock() tok, err := tm.GetToken(context.Background()) if err != nil { t.Fatalf("unexpected error: %v", err) } if tok != "eyJ.abc.def" { t.Errorf("expected refreshed token, got %q", tok) } if callCount != 2 { t.Errorf("expected 2 HTTP calls (initial + refresh), got %d", callCount) } } func TestTokenManager_GetToken_AuthFailure(t *testing.T) { srv := newTokenServer(t, 401, map[string]interface{}{ "error": "invalid_client", "error_description": "Bad credentials.", }) defer srv.Close() tm := NewTokenManager(srv.URL, "client-id", "bad-secret", "agents:read") _, err := tm.GetToken(context.Background()) if err == nil { t.Fatal("expected error, got nil") } apiErr, ok := err.(*AgentIdPError) if !ok { t.Fatalf("expected *AgentIdPError, got %T", err) } if apiErr.Code != "invalid_client" { t.Errorf("expected code invalid_client, got %q", apiErr.Code) } if apiErr.HTTPStatus != 401 { t.Errorf("expected HTTPStatus 401, got %d", apiErr.HTTPStatus) } } func TestTokenManager_ClearCache(t *testing.T) { callCount := 0 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(tokenResp) })) defer srv.Close() tm := NewTokenManager(srv.URL, "client-id", "secret", "agents:read") _, _ = tm.GetToken(context.Background()) tm.ClearCache() _, _ = tm.GetToken(context.Background()) if callCount != 2 { t.Errorf("expected 2 HTTP calls (cache cleared), got %d", callCount) } } func TestTokenManager_GoroutineSafe(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(tokenResp) })) defer srv.Close() tm := NewTokenManager(srv.URL, "client-id", "secret", "agents:read") var wg sync.WaitGroup for i := 0; i < 20; i++ { wg.Add(1) go func() { defer wg.Done() tok, err := tm.GetToken(context.Background()) if err != nil { t.Errorf("goroutine error: %v", err) } if tok != "eyJ.abc.def" { t.Errorf("unexpected token: %q", tok) } }() } wg.Wait() }