/**
 * Unit tests for billing middleware and webhook handling:
 *   - FreeTierEnforcementMiddleware
 *   - BillingController.handleWebhook
 */

import { Request, Response, NextFunction } from 'express';
import { Pool, QueryResult } from 'pg';
import { createFreeTierEnforcementMiddleware } from '../../../src/middleware/freeTierEnforcementMiddleware';
import { BillingController } from '../../../src/controllers/BillingController';
import { BillingService } from '../../../src/services/BillingService';
import { UsageService } from '../../../src/services/UsageService';
import { ITokenPayload } from '../../../src/types/index';

// ── Helpers ──────────────────────────────────────────────────────────────────

function makePool(queryFn: jest.Mock): Pool {
  return { query: queryFn } as unknown as Pool;
}

type RedisClientMock = {
  get: jest.Mock;
  set: jest.Mock;
};

function makeRedis(overrides: Partial<RedisClientMock> = {}): RedisClientMock {
  return {
    get: overrides.get ?? jest.fn().mockResolvedValue(null),
    set: overrides.set ?? jest.fn().mockResolvedValue('OK'),
  };
}

function makeRequest(overrides: Partial<{
  method: string;
  path: string;
  organizationId: string | undefined;
}>): Request {
  // Use `in` rather than `??` so that an explicit `organizationId: undefined`
  // yields an unauthenticated request instead of falling back to the default.
  const organizationId = 'organizationId' in overrides
    ? overrides.organizationId
    : 'org-uuid-123';

  const user: ITokenPayload | undefined = organizationId !== undefined
    ? {
        sub: 'agent-1',
        client_id: 'agent-1',
        scope: 'agents:read',
        jti: 'jti-1',
        iat: Math.floor(Date.now() / 1000),
        exp: Math.floor(Date.now() / 1000) + 3600,
        organization_id: organizationId,
      }
    : undefined;

  return {
    method: overrides.method ?? 'GET',
    path: overrides.path ?? '/api/v1/agents',
    user,
    headers: {},
  } as unknown as Request;
}

function makeResponse(): Response {
  return {
    status: jest.fn().mockReturnThis(),
    json: jest.fn().mockReturnThis(),
  } as unknown as Response;
}

// ════════════════════════════════════════════════════════════════════════════
// FreeTierEnforcementMiddleware
// ════════════════════════════════════════════════════════════════════════════

describe('createFreeTierEnforcementMiddleware', () => {
  const originalEnv = process.env;

  beforeEach(() => {
    process.env = { ...originalEnv };
  });

  afterEach(() => {
    process.env = originalEnv;
  });

  it('should call next() immediately when BILLING_ENABLED=false', () => {
    process.env['BILLING_ENABLED'] = 'false';
    const pool = makePool(jest.fn());
    const redis = makeRedis();
    const next = jest.fn() as NextFunction;

    const middleware = createFreeTierEnforcementMiddleware(pool, redis as never);
    middleware(makeRequest({}), makeResponse(), next);

    // next() is called synchronously because billing is disabled
    expect(next).toHaveBeenCalledTimes(1);
    expect((next as jest.Mock).mock.calls[0]).toHaveLength(0);
  });

  it('should call next() without error for an unauthenticated request', () => {
    process.env['BILLING_ENABLED'] = 'true';
    const pool = makePool(jest.fn());
    const redis = makeRedis();
    const next = jest.fn() as NextFunction;

    const req: Request = {
      method: 'GET',
      path: '/api/v1/agents',
      user: undefined,
      headers: {},
    } as unknown as Request;

    const middleware = createFreeTierEnforcementMiddleware(pool, redis as never);
    middleware(req, makeResponse(), next);

    expect(next).toHaveBeenCalledTimes(1);
    expect((next as jest.Mock).mock.calls[0]).toHaveLength(0);
  });

  it('should call next(error) with FREE_TIER_API_LIMIT when daily API calls >= 1000', async () => {
    process.env['BILLING_ENABLED'] = 'true';

    // Query sequence:
    //   1: isFreeTenant  → no subscription row (free)
    //   2: getDailyUsage → usage_events count = 1000
    //   3: getDailyUsage → agent count (getActiveAgentCount is called inside getDailyUsage)
    const mockQuery = jest.fn()
      .mockResolvedValueOnce({ rows: [] } as unknown as QueryResult) // isFreeTenant
      .mockResolvedValueOnce({ rows: [{ count: '1000' }] } as unknown as QueryResult) // usage_events
      .mockResolvedValueOnce({ rows: [{ count: '5' }] } as unknown as QueryResult); // agents count

    // Cache miss → falls through to the DB
    const redis = makeRedis({ get: jest.fn().mockResolvedValue(null) });

    const nextCalled = new Promise<unknown>((resolve) => {
      const next = ((err?: unknown) => {
        resolve(err);
      }) as NextFunction;
      const middleware = createFreeTierEnforcementMiddleware(makePool(mockQuery), redis as never);
      middleware(makeRequest({ method: 'GET', path: '/api/v1/agents' }), makeResponse(), next);
    });

    const callArg = await nextCalled;
    expect(callArg).toBeDefined();
    expect((callArg as { code: string }).code).toBe('FREE_TIER_API_LIMIT');
  });

  it('should call next(error) with FREE_TIER_AGENT_LIMIT when agent count >= 10 on POST /agents', async () => {
    process.env['BILLING_ENABLED'] = 'true';

    // Query sequence:
    //   1: isFreeTenant → free
    //   2: usage_events → 500 calls (below the limit, so enforcement continues)
    //   3: agents count for getDailyUsage
    //   4: agents count for the isAgentCreation check
    const mockQuery = jest.fn()
      .mockResolvedValueOnce({ rows: [] } as unknown as QueryResult) // isFreeTenant
      .mockResolvedValueOnce({ rows: [{ count: '500' }] } as unknown as QueryResult) // usage_events
      .mockResolvedValueOnce({ rows: [{ count: '3' }] } as unknown as QueryResult) // getDailyUsage agent count
      .mockResolvedValueOnce({ rows: [{ count: '10' }] } as unknown as QueryResult); // isAgentCreation check

    const redis = makeRedis({ get: jest.fn().mockResolvedValue(null) });

    const nextCalled = new Promise<unknown>((resolve) => {
      const next = ((err?: unknown) => {
        resolve(err);
      }) as NextFunction;
      const middleware = createFreeTierEnforcementMiddleware(makePool(mockQuery), redis as never);
      middleware(makeRequest({ method: 'POST', path: '/agents' }), makeResponse(), next);
    });

    const callArg = await nextCalled;
    expect(callArg).toBeDefined();
    expect((callArg as { code: string }).code).toBe('FREE_TIER_AGENT_LIMIT');
  });

  it('should call next() without error for a paid tenant regardless of limits', async () => {
    process.env['BILLING_ENABLED'] = 'true';

    // Active subscription → paid
    const mockQuery = jest.fn().mockResolvedValue({
      rows: [{ status: 'active' }],
    } as unknown as QueryResult);

    const redis = makeRedis();

    const nextCalled = new Promise<unknown>((resolve) => {
      const next = ((err?: unknown) => {
        resolve(err);
      }) as NextFunction;
      const middleware = createFreeTierEnforcementMiddleware(makePool(mockQuery), redis as never);
      middleware(makeRequest({ method: 'POST', path: '/agents' }), makeResponse(), next);
    });

    const callArg = await nextCalled;
    // next() called with no error
    expect(callArg).toBeUndefined();
  });

  it('should use the Redis cache and skip the DB usage query on a cache hit', async () => {
    process.env['BILLING_ENABLED'] = 'true';

    // Only one DB query: isFreeTenant (no subscription).
    // The usage query is replaced by a Redis cache hit.
    const mockQuery = jest.fn()
      .mockResolvedValueOnce({ rows: [] } as unknown as QueryResult); // isFreeTenant

    // Cache returns api_calls = 100 (below the 1000 limit)
    const redis = makeRedis({ get: jest.fn().mockResolvedValue('100') });

    const nextCalled = new Promise<unknown>((resolve) => {
      const next = ((err?: unknown) => {
        resolve(err);
      }) as NextFunction;
      const middleware = createFreeTierEnforcementMiddleware(makePool(mockQuery), redis as never);
      middleware(makeRequest({ method: 'GET', path: '/api/v1/agents' }), makeResponse(), next);
    });

    const callArg = await nextCalled;
    // Only one query (isFreeTenant); no DB call for usage
    expect(mockQuery).toHaveBeenCalledTimes(1);
    expect(callArg).toBeUndefined();
  });
});

// ════════════════════════════════════════════════════════════════════════════
// BillingController.handleWebhook
// ════════════════════════════════════════════════════════════════════════════

describe('BillingController.handleWebhook', () => {
  const originalEnv = process.env;

  beforeEach(() => {
    process.env = { ...originalEnv, STRIPE_WEBHOOK_SECRET: 'whsec_test' };
  });

  afterEach(() => {
    process.env = originalEnv;
  });

  function makeBillingController(
    handleWebhookEventFn: jest.Mock = jest.fn().mockResolvedValue(undefined),
  ): BillingController {
    const billingService = {
      handleWebhookEvent: handleWebhookEventFn,
      createCheckoutSession: jest.fn(),
      getSubscriptionStatus: jest.fn().mockResolvedValue({
        tenantId: 'test',
        status: 'free',
        currentPeriodEnd: null,
        stripeSubscriptionId: null,
      }),
    } as unknown as BillingService;

    const usageService = {
      getDailyUsage: jest.fn().mockResolvedValue({
        tenantId: 'test',
        date: '2026-04-02',
        apiCalls: 0,
        agentCount: 0,
      }),
      getActiveAgentCount: jest.fn().mockResolvedValue(0),
    } as unknown as UsageService;

    return new BillingController(billingService, usageService);
  }

  it('should return 200 { received: true } for a valid Stripe-Signature', async () => {
    const controller = makeBillingController();
    const req = {
      headers: { 'stripe-signature': 'valid-sig' },
      body: Buffer.from('{}'),
    } as unknown as Request;
    const res = makeResponse();
    const next = jest.fn() as NextFunction;

    await controller.handleWebhook(req, res, next);

    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.json).toHaveBeenCalledWith({ received: true });
    expect(next).not.toHaveBeenCalled();
  });

  it('should call next() with a ValidationError when the Stripe-Signature header is missing', async () => {
    const controller = makeBillingController();
    const req = {
      headers: {},
      body: Buffer.from('{}'),
    } as unknown as Request;
    const res = makeResponse();
    const next = jest.fn() as NextFunction;

    await controller.handleWebhook(req, res, next);

    expect(next).toHaveBeenCalledTimes(1);
    const callArg = (next as jest.Mock).mock.calls[0]?.[0];
    expect((callArg as { code: string }).code).toBe('VALIDATION_ERROR');
  });

  it('should call next() with the error when BillingService throws', async () => {
    const stripeError = new Error('Invalid signature');
    const controller = makeBillingController(jest.fn().mockRejectedValue(stripeError));
    const req = {
      headers: { 'stripe-signature': 'bad-sig' },
      body: Buffer.from('{}'),
    } as unknown as Request;
    const res = makeResponse();
    const next = jest.fn() as NextFunction;

    await controller.handleWebhook(req, res, next);

    expect(next).toHaveBeenCalledWith(stripeError);
  });
});