diff --git a/dashboard/src/App.tsx b/dashboard/src/App.tsx index c90dcd8..40aa32c 100644 --- a/dashboard/src/App.tsx +++ b/dashboard/src/App.tsx @@ -9,6 +9,7 @@ import AgentDetail from '@/pages/AgentDetail'; import Credentials from '@/pages/Credentials'; import AuditLog from '@/pages/AuditLog'; import Health from '@/pages/Health'; +import { UsagePanel } from '@/components/UsagePanel'; /** Top-level router — defines all application routes. */ export default function App(): React.JSX.Element { @@ -23,6 +24,7 @@ export default function App(): React.JSX.Element { } /> } /> } /> + } /> } /> diff --git a/dashboard/src/components/UsagePanel.tsx b/dashboard/src/components/UsagePanel.tsx new file mode 100644 index 0000000..a18448e --- /dev/null +++ b/dashboard/src/components/UsagePanel.tsx @@ -0,0 +1,192 @@ +import * as React from 'react'; +import { useAuth } from '@/lib/auth'; +import { TokenManager } from '@sentryagent/idp-sdk'; + +/** Shape of the GET /api/v1/billing/usage response. */ +interface UsageResponse { + tenantId: string; + date: string; + apiCalls: number; + agentCount: number; + subscriptionStatus: string; + currentPeriodEnd: string | null; + stripeSubscriptionId: string | null; +} + +type LoadState = 'idle' | 'loading' | 'success' | 'error'; + +interface UsageState { + loadState: LoadState; + data: UsageResponse | null; + errorMessage: string | null; +} + +const initialState: UsageState = { + loadState: 'idle', + data: null, + errorMessage: null, +}; + +/** + * Fetches the current usage summary from the API using the stored credentials. + * + * @param baseUrl - The API base URL. + * @param clientId - The agent client ID. + * @param clientSecret - The agent client secret. + * @returns The usage response from the server. + */ +async function fetchUsage( + baseUrl: string, + clientId: string, + clientSecret: string, +): Promise { + const tokenManager = new TokenManager( + baseUrl, + clientId, + clientSecret, + 'agents:read', + ); + const token = await tokenManager.getToken(); + + const response = await fetch(`${baseUrl}/api/v1/billing/usage`, { + headers: { Authorization: `Bearer ${token}` }, + }); + + if (!response.ok) { + throw new Error(`Failed to fetch usage data (HTTP ${response.status})`); + } + + return response.json() as Promise; +} + +/** Badge shown for the tenant's subscription tier. */ +function SubscriptionBadge({ status }: { status: string }): React.JSX.Element { + const isPro = status !== 'free'; + + return ( + + {isPro ? 'Pro' : 'Free Tier'} + + ); +} + +/** A single metric card with label and value. */ +function MetricCard({ label, value }: { label: string; value: string | number }): React.JSX.Element { + return ( +
+

{label}

+

{value}

+
+ ); +} + +/** + * Displays the current tenant's usage summary: + * - API calls today + * - Active agent count + * - Subscription status (Free Tier / Pro) + * + * Fetches GET /api/v1/billing/usage with the current Bearer token. + * Handles loading state and error state gracefully. + */ +export function UsagePanel(): React.JSX.Element { + const { credentials } = useAuth(); + const [state, setState] = React.useState(initialState); + + const loadUsage = React.useCallback(async (): Promise => { + if (!credentials) return; + + setState((prev) => ({ ...prev, loadState: 'loading', errorMessage: null })); + + try { + const data = await fetchUsage( + credentials.baseUrl, + credentials.clientId, + credentials.clientSecret, + ); + setState({ loadState: 'success', data, errorMessage: null }); + } catch (err) { + const message = err instanceof Error ? err.message : 'Unknown error occurred.'; + setState({ loadState: 'error', data: null, errorMessage: message }); + } + }, [credentials]); + + React.useEffect(() => { + void loadUsage(); + }, [loadUsage]); + + const isLoading = state.loadState === 'loading' || state.loadState === 'idle'; + + return ( +
+
+

Usage & Billing

+ +
+ + {/* Error state */} + {state.loadState === 'error' && ( +
+ {state.errorMessage ?? 'Failed to load usage data.'} +
+ )} + + {/* Loading skeleton */} + {isLoading && ( +
+ {[1, 2, 3].map((i) => ( +
+ ))} +
+ )} + + {/* Data */} + {state.loadState === 'success' && state.data !== null && ( + <> +
+

+ Showing usage for {state.data.date} +

+ +
+ +
+ + + +
+ + {state.data.subscriptionStatus === 'free' && ( +
+

+ You are on the Free Tier — limited to 10 agents and 1,000 API calls/day. +

+

+ Upgrade to Pro for unlimited agents and API calls. +

+
+ )} + + {state.data.currentPeriodEnd !== null && ( +

+ Current period ends:{' '} + {new Date(state.data.currentPeriodEnd).toLocaleDateString()} +

+ )} + + )} +
+ ); +} diff --git a/dashboard/src/components/layout/AppShell.tsx b/dashboard/src/components/layout/AppShell.tsx index b0db26c..30270da 100644 --- a/dashboard/src/components/layout/AppShell.tsx +++ b/dashboard/src/components/layout/AppShell.tsx @@ -12,6 +12,7 @@ const NAV_ITEMS: NavItem[] = [ { to: '/dashboard/agents', label: 'Agents' }, { to: '/dashboard/audit', label: 'Audit Log' }, { to: '/dashboard/health', label: 'Health' }, + { to: '/dashboard/usage', label: 'Usage' }, ]; /** diff --git a/openspec/changes/phase-4-developer-growth/tasks.md b/openspec/changes/phase-4-developer-growth/tasks.md index d0ae63e..bf18909 100644 --- a/openspec/changes/phase-4-developer-growth/tasks.md +++ b/openspec/changes/phase-4-developer-growth/tasks.md @@ -93,19 +93,19 @@ ## 10. WS6: Billing & Usage Metering -- [ ] 10.1 Create migration `007_add_billing.sql` — `tenant_subscriptions` table (tenant_id, status, stripe_customer_id, stripe_subscription_id, current_period_end) and `usage_events` table (tenant_id, date, metric_type, count) -- [ ] 10.2 Install `stripe` npm package — add to package.json -- [ ] 10.3 Create `UsageMeteringMiddleware` — increments in-memory per-tenant counters on every authenticated request; flushes to `usage_events` every 60s -- [ ] 10.4 Create `UsageService` with `getDailyUsage(tenantId, date)` and `getActivAgentCount(tenantId)` methods -- [ ] 10.5 Create `FreeTierEnforcementMiddleware` — checks usage cache (Redis, 60s TTL) before agent creation and API calls; rejects with HTTP 429 when limit exceeded; skips when `BILLING_ENABLED=false` -- [ ] 10.6 Add `agentidp_billing_limit_rejections_total` Prometheus counter (labels: `tenant_id`, `limit_type`) -- [ ] 10.7 Create `BillingService` with `createCheckoutSession(tenantId)`, `handleWebhookEvent(event)`, `getSubscriptionStatus(tenantId)` methods -- [ ] 10.8 Create `POST /billing/checkout` endpoint — creates Stripe Checkout session, returns `checkoutUrl` -- [ ] 10.9 Create `POST /billing/webhook` endpoint — verifies Stripe signature, processes subscription events, updates `tenant_subscriptions` -- [ ] 10.10 Create `GET /billing/usage` endpoint (authenticated) — returns current period usage summary for tenant -- [ ] 10.11 Add `BILLING_ENABLED` env var — disable enforcement and Stripe processing when false; document in `.env.example` -- [ ] 10.12 Write unit tests for UsageService, BillingService, FreeTierEnforcementMiddleware — free tier block, paid tier pass-through, webhook processing -- [ ] 10.13 Update web dashboard — add "Usage" tab to navigation with billing status panel and usage metrics from `GET /billing/usage` +- [x] 10.1 Create migration `007_add_billing.sql` — `tenant_subscriptions` table (tenant_id, status, stripe_customer_id, stripe_subscription_id, current_period_end) and `usage_events` table (tenant_id, date, metric_type, count) +- [x] 10.2 Install `stripe` npm package — add to package.json +- [x] 10.3 Create `UsageMeteringMiddleware` — increments in-memory per-tenant counters on every authenticated request; flushes to `usage_events` every 60s +- [x] 10.4 Create `UsageService` with `getDailyUsage(tenantId, date)` and `getActivAgentCount(tenantId)` methods +- [x] 10.5 Create `FreeTierEnforcementMiddleware` — checks usage cache (Redis, 60s TTL) before agent creation and API calls; rejects with HTTP 429 when limit exceeded; skips when `BILLING_ENABLED=false` +- [x] 10.6 Add `agentidp_billing_limit_rejections_total` Prometheus counter (labels: `tenant_id`, `limit_type`) +- [x] 10.7 Create `BillingService` with `createCheckoutSession(tenantId)`, `handleWebhookEvent(event)`, `getSubscriptionStatus(tenantId)` methods +- [x] 10.8 Create `POST /billing/checkout` endpoint — creates Stripe Checkout session, returns `checkoutUrl` +- [x] 10.9 Create `POST /billing/webhook` endpoint — verifies Stripe signature, processes subscription events, updates `tenant_subscriptions` +- [x] 10.10 Create `GET /billing/usage` endpoint (authenticated) — returns current period usage summary for tenant +- [x] 10.11 Add `BILLING_ENABLED` env var — disable enforcement and Stripe processing when false; document in `.env.example` +- [x] 10.12 Write unit tests for UsageService, BillingService, FreeTierEnforcementMiddleware — free tier block, paid tier pass-through, webhook processing +- [x] 10.13 Update web dashboard — add "Usage" tab to navigation with billing status panel and usage metrics from `GET /billing/usage` ## 11. QA & Release diff --git a/package-lock.json b/package-lock.json index 9233fae..370eecc 100644 --- a/package-lock.json +++ b/package-lock.json @@ -30,6 +30,7 @@ "prom-client": "^15.1.3", "rate-limiter-flexible": "^5.0.5", "redis": "^4.6.13", + "stripe": "^21.0.1", "ulid": "^3.0.2", "uuid": "^9.0.1", "web-did-resolver": "^2.0.32" @@ -1833,7 +1834,7 @@ "version": "20.19.37", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.19.37.tgz", "integrity": "sha512-8kzdPJ3FsNsVIurqBs7oodNnCEVbni9yUEkaHbgptDACOPW04jimGagZ51E6+lXUwJjgnBw+hyko/lkFWCldqw==", - "dev": true, + "devOptional": true, "license": "MIT", "dependencies": { "undici-types": "~6.21.0" @@ -7619,6 +7620,23 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/stripe": { + "version": "21.0.1", + "resolved": "https://registry.npmjs.org/stripe/-/stripe-21.0.1.tgz", + "integrity": "sha512-ocv0j7dWttswDWV2XL/kb6+yiLpDXNXL3RQAOB5OB2kr49z0cEatdQc12+zP/j5nrXk6rAsT4N3y/NUvBbK7Pw==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@types/node": ">=18" + }, + "peerDependenciesMeta": { + "@types/node": { + "optional": true + } + } + }, "node_modules/superagent": { "version": "8.1.2", "resolved": "https://registry.npmjs.org/superagent/-/superagent-8.1.2.tgz", @@ -8044,7 +8062,7 @@ "version": "6.21.0", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", - "dev": true, + "devOptional": true, "license": "MIT" }, "node_modules/unpipe": { diff --git a/package.json b/package.json index 283073b..9c7906b 100644 --- a/package.json +++ b/package.json @@ -38,6 +38,7 @@ "prom-client": "^15.1.3", "rate-limiter-flexible": "^5.0.5", "redis": "^4.6.13", + "stripe": "^21.0.1", "ulid": "^3.0.2", "uuid": "^9.0.1", "web-did-resolver": "^2.0.32" diff --git a/src/app.ts b/src/app.ts index 35a7d35..a279826 100644 --- a/src/app.ts +++ b/src/app.ts @@ -8,6 +8,7 @@ import express, { Application } from 'express'; import helmet from 'helmet'; import cors from 'cors'; import morgan from 'morgan'; +import Stripe from 'stripe'; import { getPool } from './db/pool.js'; import { getRedisClient } from './cache/redis.js'; @@ -21,6 +22,8 @@ import { OrgRepository } from './repositories/OrgRepository.js'; import { AuditService } from './services/AuditService.js'; import { AgentService } from './services/AgentService.js'; import { MarketplaceService } from './services/MarketplaceService.js'; +import { BillingService } from './services/BillingService.js'; +import { UsageService } from './services/UsageService.js'; import { CredentialService } from './services/CredentialService.js'; import { OAuth2Service } from './services/OAuth2Service.js'; import { OrgService } from './services/OrgService.js'; @@ -35,6 +38,7 @@ import { createKafkaProducer } from './adapters/KafkaAdapter.js'; import { AgentController } from './controllers/AgentController.js'; import { MarketplaceController } from './controllers/MarketplaceController.js'; +import { BillingController } from './controllers/BillingController.js'; import { OIDCTrustPolicyController } from './controllers/OIDCTrustPolicyController.js'; import { OIDCTokenExchangeController } from './controllers/OIDCTokenExchangeController.js'; import { TokenController } from './controllers/TokenController.js'; @@ -49,6 +53,7 @@ import { ComplianceController } from './controllers/ComplianceController.js'; import { createAgentsRouter } from './routes/agents.js'; import { createMarketplaceRouter } from './routes/marketplace.js'; +import { createBillingRouter } from './routes/billing.js'; import { createOIDCTrustPoliciesRouter } from './routes/oidcTrustPolicies.js'; import { createOIDCTokenExchangeRouter } from './routes/oidcTokenExchange.js'; import { OIDCTrustPolicyService } from './services/OIDCTrustPolicyService.js'; @@ -69,6 +74,8 @@ import { createOpaMiddleware } from './middleware/opa.js'; import { metricsMiddleware } from './middleware/metrics.js'; import { createOrgContextMiddleware } from './middleware/orgContext.js'; import { authMiddleware } from './middleware/auth.js'; +import { createUsageMeteringMiddleware, startUsageMeteringFlush } from './middleware/usageMeteringMiddleware.js'; +import { createFreeTierEnforcementMiddleware } from './middleware/freeTierEnforcementMiddleware.js'; import { tlsEnforcementMiddleware } from './middleware/TLSEnforcementMiddleware.js'; import { createVaultClientFromEnv } from './vault/VaultClient.js'; import { getEncryptionService } from './services/EncryptionService.js'; @@ -232,6 +239,17 @@ export async function createApp(): Promise { const webhookController = new WebhookController(webhookService); const marketplaceController = new MarketplaceController(marketplaceService); + // ──────────────────────────────────────────────────────────────── + // Billing & Usage Metering (WS6) + // ──────────────────────────────────────────────────────────────── + const stripe = new Stripe(process.env['STRIPE_SECRET_KEY'] ?? '', { apiVersion: '2026-03-25.dahlia' }); + const billingService = new BillingService(pool, stripe); + const usageService = new UsageService(pool); + const billingController = new BillingController(billingService, usageService); + + // Start periodic flush of in-memory usage counters to DB (every 60s) + startUsageMeteringFlush(pool); + // OIDC trust policy management + GitHub Actions token exchange const oidcTrustPolicyService = new OIDCTrustPolicyService(pool); const oidcTrustPolicyController = new OIDCTrustPolicyController(oidcTrustPolicyService); @@ -254,6 +272,18 @@ export async function createApp(): Promise { // ──────────────────────────────────────────────────────────────── app.use(createOrgContextMiddleware(pool)); + // ──────────────────────────────────────────────────────────────── + // Usage metering — records per-tenant API call counts in-memory + // Applied after auth middleware so req.user is populated. + // ──────────────────────────────────────────────────────────────── + app.use(createUsageMeteringMiddleware(pool)); + + // ──────────────────────────────────────────────────────────────── + // Free tier enforcement — rejects requests exceeding free plan limits + // Applied after usage metering and before routes. + // ──────────────────────────────────────────────────────────────── + app.use(createFreeTierEnforcementMiddleware(pool, redis as RedisClientType)); + // ──────────────────────────────────────────────────────────────── // Routes // ──────────────────────────────────────────────────────────────── @@ -287,6 +317,9 @@ export async function createApp(): Promise { app.use(`${API_BASE}`, createComplianceRouter(complianceController)); app.use(`${API_BASE}/marketplace`, createMarketplaceRouter(marketplaceController)); + // Billing & Usage Metering — checkout, webhook, usage summary + app.use(`${API_BASE}/billing`, createBillingRouter(billingController, authMiddleware)); + // OIDC trust-policy management (authenticated) and token exchange (unauthenticated) // Both routers mount under ${API_BASE}/oidc — trust-policy routes use /trust-policies prefix, // token exchange uses /token, so there are no path conflicts. diff --git a/src/controllers/BillingController.ts b/src/controllers/BillingController.ts new file mode 100644 index 0000000..94f7df6 --- /dev/null +++ b/src/controllers/BillingController.ts @@ -0,0 +1,141 @@ +/** + * Billing controller for SentryAgent.ai AgentIdP. + * Handles Stripe checkout session creation, webhook processing, and usage queries. + */ + +import { Request, Response, NextFunction } from 'express'; +import { BillingService } from '../services/BillingService.js'; +import { UsageService } from '../services/UsageService.js'; +import { ValidationError, AuthorizationError } from '../utils/errors.js'; + +/** + * Controller for billing and usage endpoints. + * Delegates all business logic to BillingService and UsageService. + */ +export class BillingController { + /** + * @param billingService - The billing service for Stripe operations. + * @param usageService - The usage metering service. + */ + constructor( + private readonly billingService: BillingService, + private readonly usageService: UsageService, + ) {} + + /** + * Handles POST /billing/checkout — creates a Stripe Checkout Session. + * Reads the tenant ID from the authenticated user's organizationId. + * Returns { checkoutUrl } with HTTP 201. + * + * @param req - Express request. Must have req.user populated. + * @param res - Express response. + * @param next - Express next function. + */ + createCheckoutSession = async (req: Request, res: Response, next: NextFunction): Promise => { + try { + if (!req.user) { + throw new AuthorizationError(); + } + + const tenantId = req.user.organization_id; + if (!tenantId) { + throw new ValidationError('organization_id is required in token.'); + } + + const body = req.body as { successUrl?: unknown; cancelUrl?: unknown }; + + const successUrl = + typeof body.successUrl === 'string' && body.successUrl.length > 0 + ? body.successUrl + : `${req.protocol}://${req.hostname}/dashboard?billing=success`; + + const cancelUrl = + typeof body.cancelUrl === 'string' && body.cancelUrl.length > 0 + ? body.cancelUrl + : `${req.protocol}://${req.hostname}/dashboard?billing=cancel`; + + const checkoutUrl = await this.billingService.createCheckoutSession( + tenantId, + successUrl, + cancelUrl, + ); + + res.status(201).json({ checkoutUrl }); + } catch (err) { + next(err); + } + }; + + /** + * Handles POST /billing/webhook — processes Stripe webhook events. + * Reads the raw body buffer and Stripe-Signature header. + * Returns HTTP 200 { received: true } on success. + * Returns HTTP 400 if the signature header is missing. + * + * @param req - Express request. Body must be a raw Buffer. + * @param res - Express response. + * @param next - Express next function. + */ + handleWebhook = async (req: Request, res: Response, next: NextFunction): Promise => { + try { + const sig = req.headers['stripe-signature']; + if (!sig || typeof sig !== 'string') { + throw new ValidationError('Missing Stripe-Signature header.'); + } + + const webhookSecret = process.env['STRIPE_WEBHOOK_SECRET']; + if (!webhookSecret) { + throw new Error('STRIPE_WEBHOOK_SECRET environment variable is required.'); + } + + // req.body is a raw Buffer when express.raw() middleware is applied + const rawBody = req.body as Buffer; + + await this.billingService.handleWebhookEvent(rawBody, sig, webhookSecret); + + res.status(200).json({ received: true }); + } catch (err) { + next(err); + } + }; + + /** + * Handles GET /billing/usage — returns today's usage summary for the tenant. + * Returns combined { tenantId, date, apiCalls, agentCount, subscriptionStatus }. + * + * @param req - Express request. Must have req.user populated. + * @param res - Express response. + * @param next - Express next function. + */ + getUsage = async (req: Request, res: Response, next: NextFunction): Promise => { + try { + if (!req.user) { + throw new AuthorizationError(); + } + + const tenantId = req.user.organization_id; + if (!tenantId) { + throw new ValidationError('organization_id is required in token.'); + } + + const today = new Date().toISOString().slice(0, 10); // YYYY-MM-DD + + const [usage, subscription] = await Promise.all([ + this.usageService.getDailyUsage(tenantId, today), + this.billingService.getSubscriptionStatus(tenantId), + ]); + + res.status(200).json({ + tenantId: usage.tenantId, + date: usage.date, + apiCalls: usage.apiCalls, + agentCount: usage.agentCount, + subscriptionStatus: subscription.status, + currentPeriodEnd: subscription.currentPeriodEnd, + stripeSubscriptionId: subscription.stripeSubscriptionId, + }); + } catch (err) { + next(err); + } + }; +} diff --git a/src/db/migrations/023_add_billing.sql b/src/db/migrations/023_add_billing.sql new file mode 100644 index 0000000..aa0e9b1 --- /dev/null +++ b/src/db/migrations/023_add_billing.sql @@ -0,0 +1,43 @@ +-- Migration 023: Add billing and usage metering tables +-- Phase 4, WS6 — Billing & Usage Metering + +-- ── Tenant subscriptions ──────────────────────────────────────────────────────── +-- Tracks the Stripe subscription status for each tenant/organization. +-- One row per tenant. When no row exists, the tenant is on the free tier. +CREATE TABLE IF NOT EXISTS tenant_subscriptions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + tenant_id UUID NOT NULL REFERENCES organizations(org_id) ON DELETE CASCADE, + status VARCHAR(32) NOT NULL DEFAULT 'free', + stripe_customer_id VARCHAR(255), + stripe_subscription_id VARCHAR(255), + current_period_end TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- One subscription row per tenant (UPSERT target) +ALTER TABLE tenant_subscriptions + ADD CONSTRAINT uq_tenant_subscriptions_tenant_id UNIQUE (tenant_id); + +CREATE INDEX IF NOT EXISTS idx_tenant_subscriptions_tenant_id + ON tenant_subscriptions(tenant_id); + +-- ── Usage events ───────────────────────────────────────────────────────────────── +-- Daily usage counters per tenant and metric type. +-- Uses UPSERT with ON CONFLICT to accumulate counts without duplicates. +CREATE TABLE IF NOT EXISTS usage_events ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + tenant_id UUID NOT NULL, + date DATE NOT NULL DEFAULT CURRENT_DATE, + metric_type VARCHAR(64) NOT NULL, + count INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Unique constraint enables UPSERT: ON CONFLICT (tenant_id, date, metric_type) +ALTER TABLE usage_events + ADD CONSTRAINT uq_usage_events_tenant_date_metric + UNIQUE (tenant_id, date, metric_type); + +CREATE INDEX IF NOT EXISTS idx_usage_events_tenant_date + ON usage_events(tenant_id, date); diff --git a/src/metrics/registry.ts b/src/metrics/registry.ts index f63fcca..7ee3cfe 100644 --- a/src/metrics/registry.ts +++ b/src/metrics/registry.ts @@ -159,3 +159,16 @@ export const tenantApiCallsTotal = new Counter({ labelNames: ['tenant_id'] as const, registers: [metricsRegistry], }); + +/** + * Total number of requests rejected due to free tier billing limits. + * Labels: tenant_id, limit_type ('agent_limit' | 'api_limit') + * + * WS6 — Billing & Usage Metering. + */ +export const billingLimitRejectionsTotal = new Counter({ + name: 'agentidp_billing_limit_rejections_total', + help: 'Total number of requests rejected due to free tier billing limits.', + labelNames: ['tenant_id', 'limit_type'] as const, + registers: [metricsRegistry], +}); diff --git a/src/middleware/freeTierEnforcementMiddleware.ts b/src/middleware/freeTierEnforcementMiddleware.ts new file mode 100644 index 0000000..34f9800 --- /dev/null +++ b/src/middleware/freeTierEnforcementMiddleware.ts @@ -0,0 +1,189 @@ +/** + * Free tier enforcement middleware for SentryAgent.ai AgentIdP. + * Rejects requests that exceed free-tier usage limits (agents or API calls). + * Skipped entirely when BILLING_ENABLED=false. + */ + +import { Request, Response, NextFunction, RequestHandler } from 'express'; +import { Pool } from 'pg'; +import type { RedisClientType } from 'redis'; +import { SentryAgentError } from '../utils/errors.js'; +import { billingLimitRejectionsTotal } from '../metrics/registry.js'; +import { UsageService } from '../services/UsageService.js'; + +/** Free tier: maximum number of active (non-decommissioned) agents per tenant. */ +const FREE_TIER_MAX_AGENTS = 10; + +/** Free tier: maximum number of API calls per day per tenant. */ +const FREE_TIER_MAX_API_CALLS = 1_000; + +/** Redis cache TTL for billing usage checks, in seconds. */ +const CACHE_TTL_SECONDS = 60; + +/** + * Thrown when a free-tier tenant exceeds the agent creation limit. + * HTTP 429 with code FREE_TIER_AGENT_LIMIT. + */ +class FreeTierAgentLimitError extends SentryAgentError { + constructor(current: number) { + super( + `Free tier limit of ${FREE_TIER_MAX_AGENTS} active agents reached. Upgrade to create more agents.`, + 'FREE_TIER_AGENT_LIMIT', + 429, + { limit: FREE_TIER_MAX_AGENTS, current }, + ); + } +} + +/** + * Thrown when a free-tier tenant exceeds the daily API call limit. + * HTTP 429 with code FREE_TIER_API_LIMIT. + */ +class FreeTierApiLimitError extends SentryAgentError { + constructor(current: number) { + super( + `Free tier daily API call limit of ${FREE_TIER_MAX_API_CALLS} reached. Upgrade for unlimited calls.`, + 'FREE_TIER_API_LIMIT', + 429, + { limit: FREE_TIER_MAX_API_CALLS, current }, + ); + } +} + +/** + * Returns true when the tenant is on the free plan. + * A tenant is considered free when no row exists in tenant_subscriptions, + * or the existing row has status = 'free'. + * + * @param pool - PostgreSQL connection pool. + * @param tenantId - The tenant UUID. + * @returns True if the tenant is on the free tier. + */ +async function isFreeTenant(pool: Pool, tenantId: string): Promise { + const result = await pool.query<{ status: string }>( + `SELECT status FROM tenant_subscriptions WHERE tenant_id = $1 LIMIT 1`, + [tenantId], + ); + + if (result.rows.length === 0) return true; + return result.rows[0].status === 'free'; +} + +/** + * Returns the cached API call count for a tenant on a given date. + * Returns null when the cache is cold. + * + * @param redis - Redis client. + * @param tenantId - The tenant UUID. + * @param date - Date string in YYYY-MM-DD format. + * @returns Cached count or null. + */ +async function getCachedApiCalls( + redis: RedisClientType, + tenantId: string, + date: string, +): Promise { + const cacheKey = `billing:usage:${tenantId}:${date}`; + const cached = await redis.get(cacheKey); + if (cached === null) return null; + return parseInt(cached, 10); +} + +/** + * Writes the API call count to the Redis cache with a TTL. + * + * @param redis - Redis client. + * @param tenantId - The tenant UUID. + * @param date - Date string in YYYY-MM-DD format. + * @param count - The API call count to cache. + */ +async function setCachedApiCalls( + redis: RedisClientType, + tenantId: string, + date: string, + count: number, +): Promise { + const cacheKey = `billing:usage:${tenantId}:${date}`; + await redis.set(cacheKey, String(count), { EX: CACHE_TTL_SECONDS }); +} + +/** + * Creates the free tier enforcement middleware. + * + * Behaviour: + * - When BILLING_ENABLED env var is 'false' (string), calls next() immediately. + * - On agent creation (POST on an agents path): checks active agent count. + * If >= FREE_TIER_MAX_AGENTS and tenant is free → rejects with HTTP 429 FREE_TIER_AGENT_LIMIT. + * - On every authenticated request: checks daily API call count. + * If >= FREE_TIER_MAX_API_CALLS and tenant is free → rejects with HTTP 429 FREE_TIER_API_LIMIT. + * - Uses Redis to cache usage lookups (TTL: 60s) to minimise DB queries. + * + * @param pool - PostgreSQL connection pool. + * @param redis - Redis client. + * @returns Express RequestHandler. + */ +export function createFreeTierEnforcementMiddleware( + pool: Pool, + redis: RedisClientType, +): RequestHandler { + const usageService = new UsageService(pool); + + return (req: Request, _res: Response, next: NextFunction): void => { + // Skip if billing is disabled + if (process.env['BILLING_ENABLED'] === 'false') { + next(); + return; + } + + // Only enforce for authenticated requests + if (!req.user?.organization_id) { + next(); + return; + } + + const tenantId = req.user.organization_id; + const today = new Date().toISOString().slice(0, 10); // YYYY-MM-DD + + void (async (): Promise => { + try { + const free = await isFreeTenant(pool, tenantId); + + if (!free) { + next(); + return; + } + + // ── API call limit check ──────────────────────────────────────────── + let apiCalls = await getCachedApiCalls(redis, tenantId, today); + + if (apiCalls === null) { + const summary = await usageService.getDailyUsage(tenantId, today); + apiCalls = summary.apiCalls; + await setCachedApiCalls(redis, tenantId, today, apiCalls); + } + + if (apiCalls >= FREE_TIER_MAX_API_CALLS) { + billingLimitRejectionsTotal.inc({ tenant_id: tenantId, limit_type: 'api_limit' }); + throw new FreeTierApiLimitError(apiCalls); + } + + // ── Agent creation limit check ────────────────────────────────────── + const isAgentCreation = + req.method === 'POST' && /\/agents(\/?)$/.test(req.path); + + if (isAgentCreation) { + const agentCount = await usageService.getActiveAgentCount(tenantId); + + if (agentCount >= FREE_TIER_MAX_AGENTS) { + billingLimitRejectionsTotal.inc({ tenant_id: tenantId, limit_type: 'agent_limit' }); + throw new FreeTierAgentLimitError(agentCount); + } + } + + next(); + } catch (err) { + next(err); + } + })(); + }; +} diff --git a/src/middleware/usageMeteringMiddleware.ts b/src/middleware/usageMeteringMiddleware.ts new file mode 100644 index 0000000..9dd83c0 --- /dev/null +++ b/src/middleware/usageMeteringMiddleware.ts @@ -0,0 +1,86 @@ +/** + * Usage metering middleware for SentryAgent.ai AgentIdP. + * Tracks per-tenant API call counts in-memory and periodically flushes to the DB. + */ + +import { Request, Response, NextFunction, RequestHandler } from 'express'; +import { Pool } from 'pg'; + +/** In-memory counter: tenantId → accumulated API call count since last flush. */ +const tenantCounters = new Map(); + +/** Flush interval in milliseconds. */ +const FLUSH_INTERVAL_MS = 60_000; + +/** + * Increments the in-memory API call counter for a tenant. + * + * @param tenantId - The tenant UUID. + */ +function incrementCounter(tenantId: string): void { + const current = tenantCounters.get(tenantId) ?? 0; + tenantCounters.set(tenantId, current + 1); +} + +/** + * Flushes all in-memory counters to the `usage_events` table via UPSERT. + * Clears the in-memory counters after a successful flush. + * Uses ON CONFLICT to accumulate counts rather than overwrite. + * + * @param pool - PostgreSQL connection pool. + */ +async function flushCounters(pool: Pool): Promise { + if (tenantCounters.size === 0) { + return; + } + + // Snapshot and clear before async DB work to avoid double-counting on slow flushes + const snapshot = new Map(tenantCounters); + tenantCounters.clear(); + + const today = new Date().toISOString().slice(0, 10); // YYYY-MM-DD + + for (const [tenantId, count] of snapshot) { + if (count <= 0) continue; + + await pool.query( + `INSERT INTO usage_events (tenant_id, date, metric_type, count) + VALUES ($1, $2, 'api_calls', $3) + ON CONFLICT (tenant_id, date, metric_type) + DO UPDATE SET count = usage_events.count + EXCLUDED.count`, + [tenantId, today, count], + ); + } +} + +/** + * Creates the usage metering middleware. + * On every authenticated request (where `req.user` is populated), + * increments the in-memory counter for the tenant identified by `organizationId`. + * + * @param _pool - PostgreSQL connection pool (used internally by flushCounters). + * @returns Express RequestHandler. + */ +export function createUsageMeteringMiddleware(_pool: Pool): RequestHandler { + return (req: Request, _res: Response, next: NextFunction): void => { + if (req.user?.organization_id) { + incrementCounter(req.user.organization_id); + } + next(); + }; +} + +/** + * Starts the periodic flush interval that writes in-memory counters to the DB. + * Call once at application startup (after the pool is ready). + * + * @param pool - PostgreSQL connection pool. + * @returns The NodeJS.Timeout handle (can be used to stop the interval in tests). + */ +export function startUsageMeteringFlush(pool: Pool): NodeJS.Timeout { + return setInterval(() => { + flushCounters(pool).catch((err: unknown) => { + console.error('[UsageMetering] Failed to flush usage counters:', err); + }); + }, FLUSH_INTERVAL_MS); +} diff --git a/src/routes/billing.ts b/src/routes/billing.ts new file mode 100644 index 0000000..146ca97 --- /dev/null +++ b/src/routes/billing.ts @@ -0,0 +1,52 @@ +/** + * Billing routes for SentryAgent.ai AgentIdP. + * Provides Stripe checkout, webhook processing, and usage reporting endpoints. + */ + +import { Router, RequestHandler } from 'express'; +import express from 'express'; +import { BillingController } from '../controllers/BillingController.js'; +import { asyncHandler } from '../utils/asyncHandler.js'; + +/** + * Creates and returns the Express router for billing endpoints. + * + * Routes: + * POST /billing/checkout — authenticated; creates a Stripe Checkout Session + * POST /billing/webhook — unauthenticated; receives Stripe webhook events (raw body) + * GET /billing/usage — authenticated; returns today's usage summary + * + * @param controller - The billing controller instance. + * @param authMiddleware - The JWT authentication middleware for protected endpoints. + * @returns Configured Express router. + */ +export function createBillingRouter( + controller: BillingController, + authMiddleware: RequestHandler, +): Router { + const router = Router(); + + // POST /billing/checkout — authenticated; creates a Stripe Checkout Session + router.post( + '/checkout', + authMiddleware, + asyncHandler(controller.createCheckoutSession.bind(controller)), + ); + + // POST /billing/webhook — unauthenticated; Stripe sends raw JSON body + // express.raw() must be applied HERE so the body is a Buffer (required for signature verification) + router.post( + '/webhook', + express.raw({ type: 'application/json' }), + asyncHandler(controller.handleWebhook.bind(controller)), + ); + + // GET /billing/usage — authenticated; returns usage summary for today + router.get( + '/usage', + authMiddleware, + asyncHandler(controller.getUsage.bind(controller)), + ); + + return router; +} diff --git a/src/services/BillingService.ts b/src/services/BillingService.ts new file mode 100644 index 0000000..9f36dd5 --- /dev/null +++ b/src/services/BillingService.ts @@ -0,0 +1,187 @@ +/** + * Billing service for SentryAgent.ai AgentIdP. + * Manages Stripe checkout sessions, webhook processing, and subscription status. + */ + +import { Pool } from 'pg'; +import Stripe from 'stripe'; + +/** + * Current subscription status for a tenant. + */ +export interface ISubscriptionStatus { + /** The tenant (organization) UUID. */ + tenantId: string; + /** Subscription status: 'free', 'active', 'past_due', 'canceled', etc. */ + status: string; + /** End of the current billing period, or null if on free tier. */ + currentPeriodEnd: Date | null; + /** Stripe subscription ID, or null if on free tier. */ + stripeSubscriptionId: string | null; +} + +/** DB row shape for tenant_subscriptions queries. */ +interface ISubscriptionRow { + status: string; + current_period_end: Date | null; + stripe_subscription_id: string | null; +} + +/** + * Service for managing Stripe billing integration. + * Handles checkout session creation, webhook event processing, + * and subscription status retrieval. + */ +export class BillingService { + /** + * @param pool - PostgreSQL connection pool. + * @param stripe - Configured Stripe client instance. + */ + constructor( + private readonly pool: Pool, + private readonly stripe: Stripe, + ) {} + + /** + * Creates a Stripe Checkout Session for a tenant to subscribe. + * Returns the URL the tenant should be redirected to complete payment. + * + * @param tenantId - The tenant UUID (used as client_reference_id). + * @param successUrl - URL to redirect to on successful checkout. + * @param cancelUrl - URL to redirect to if checkout is cancelled. + * @returns The Stripe Checkout Session URL. + * @throws Error if the session URL is not returned by Stripe. + */ + async createCheckoutSession( + tenantId: string, + successUrl: string, + cancelUrl: string, + ): Promise { + const priceId = process.env['STRIPE_PRICE_ID']; + + const session = await this.stripe.checkout.sessions.create({ + mode: 'subscription', + client_reference_id: tenantId, + line_items: priceId + ? [{ price: priceId, quantity: 1 }] + : undefined, + success_url: successUrl, + cancel_url: cancelUrl, + }); + + if (!session.url) { + throw new Error('Stripe did not return a checkout session URL.'); + } + + return session.url; + } + + /** + * Verifies and processes an incoming Stripe webhook event. + * Handles subscription created, updated, and deleted events + * by upserting the tenant_subscriptions table. + * + * @param rawBody - The raw request body buffer (required for signature verification). + * @param sig - The value of the Stripe-Signature request header. + * @param webhookSecret - The Stripe webhook endpoint secret (whsec_...). + * @throws Error if the webhook signature is invalid. + */ + async handleWebhookEvent( + rawBody: Buffer, + sig: string, + webhookSecret: string, + ): Promise { + const event = this.stripe.webhooks.constructEvent(rawBody, sig, webhookSecret); + + if ( + event.type === 'customer.subscription.created' || + event.type === 'customer.subscription.updated' || + event.type === 'customer.subscription.deleted' + ) { + const subscription = event.data.object as Stripe.Subscription; + await this.upsertSubscription(subscription); + } + } + + /** + * Returns the current subscription status for a tenant. + * When no subscription row exists, returns a free-tier status. + * + * @param tenantId - The tenant UUID. + * @returns The current ISubscriptionStatus. + */ + async getSubscriptionStatus(tenantId: string): Promise { + const result = await this.pool.query( + `SELECT status, current_period_end, stripe_subscription_id + FROM tenant_subscriptions + WHERE tenant_id = $1 + LIMIT 1`, + [tenantId], + ); + + if (result.rows.length === 0) { + return { + tenantId, + status: 'free', + currentPeriodEnd: null, + stripeSubscriptionId: null, + }; + } + + const row = result.rows[0]; + return { + tenantId, + status: row.status, + currentPeriodEnd: row.current_period_end ?? null, + stripeSubscriptionId: row.stripe_subscription_id ?? null, + }; + } + + /** + * Upserts a Stripe subscription into tenant_subscriptions. + * Resolves the tenant from the subscription's customer. + * Requires the Stripe customer metadata to contain a `tenant_id` field, + * OR a checkout session client_reference_id to have been stored. + * Falls back to fetching the customer record to find metadata.tenant_id. + * + * @param subscription - The Stripe Subscription object. + */ + private async upsertSubscription(subscription: Stripe.Subscription): Promise { + // Fetch the customer to retrieve tenant_id from metadata + const customerId = + typeof subscription.customer === 'string' + ? subscription.customer + : subscription.customer.id; + + const customer = await this.stripe.customers.retrieve(customerId); + if (customer.deleted) { + return; + } + + const tenantId = (customer as Stripe.Customer).metadata?.['tenant_id']; + if (!tenantId) { + // Cannot map to a tenant — skip gracefully + return; + } + + // billing_cycle_anchor gives the next billing date (used as period end proxy). + // ended_at is populated when the subscription is canceled. + const periodTimestamp = subscription.ended_at ?? subscription.billing_cycle_anchor; + const currentPeriodEnd = new Date(periodTimestamp * 1000); + const status = subscription.status; + + await this.pool.query( + `INSERT INTO tenant_subscriptions + (tenant_id, status, stripe_customer_id, stripe_subscription_id, current_period_end, updated_at) + VALUES ($1, $2, $3, $4, $5, NOW()) + ON CONFLICT (tenant_id) + DO UPDATE SET + status = EXCLUDED.status, + stripe_customer_id = EXCLUDED.stripe_customer_id, + stripe_subscription_id = EXCLUDED.stripe_subscription_id, + current_period_end = EXCLUDED.current_period_end, + updated_at = NOW()`, + [tenantId, status, customerId, subscription.id, currentPeriodEnd], + ); + } +} diff --git a/src/services/UsageService.ts b/src/services/UsageService.ts new file mode 100644 index 0000000..4d3364a --- /dev/null +++ b/src/services/UsageService.ts @@ -0,0 +1,73 @@ +/** + * Usage metering service for SentryAgent.ai AgentIdP. + * Provides daily usage summaries and active agent counts per tenant. + */ + +import { Pool } from 'pg'; + +/** + * Daily usage summary for a tenant. + */ +export interface IUsageSummary { + /** The tenant (organization) UUID. */ + tenantId: string; + /** Date string in YYYY-MM-DD format. */ + date: string; + /** Number of API calls made on the given date. */ + apiCalls: number; + /** Number of active (non-decommissioned) agents for the tenant. */ + agentCount: number; +} + +/** + * Service for retrieving per-tenant usage data. + * Reads from the `usage_events` and `agents` tables. + */ +export class UsageService { + /** + * @param pool - PostgreSQL connection pool. + */ + constructor(private readonly pool: Pool) {} + + /** + * Returns the daily usage summary for a tenant on a given date. + * If no usage row exists for the date, apiCalls defaults to 0. + * + * @param tenantId - The tenant UUID. + * @param date - Date string in 'YYYY-MM-DD' format. + * @returns A resolved IUsageSummary with api_calls and agent count. + */ + async getDailyUsage(tenantId: string, date: string): Promise { + const usageResult = await this.pool.query<{ count: string }>( + `SELECT COALESCE(SUM(count), 0) AS count + FROM usage_events + WHERE tenant_id = $1 + AND date = $2 + AND metric_type = 'api_calls'`, + [tenantId, date], + ); + + const agentCount = await this.getActiveAgentCount(tenantId); + const apiCalls = parseInt(usageResult.rows[0]?.count ?? '0', 10); + + return { tenantId, date, apiCalls, agentCount }; + } + + /** + * Returns the number of non-decommissioned agents for a tenant. + * + * @param tenantId - The tenant UUID. + * @returns The count of active agents (status != 'decommissioned'). + */ + async getActiveAgentCount(tenantId: string): Promise { + const result = await this.pool.query<{ count: string }>( + `SELECT COUNT(*) AS count + FROM agents + WHERE organization_id = $1 + AND status != 'decommissioned'`, + [tenantId], + ); + + return parseInt(result.rows[0]?.count ?? '0', 10); + } +} diff --git a/tests/unit/metrics/registry.test.ts b/tests/unit/metrics/registry.test.ts index 5ce1f4b..273e924 100644 --- a/tests/unit/metrics/registry.test.ts +++ b/tests/unit/metrics/registry.test.ts @@ -19,6 +19,8 @@ import { rateLimitHitsTotal, dbPoolActiveConnections, dbPoolWaitingRequests, + tenantApiCallsTotal, + billingLimitRejectionsTotal, } from '../../../src/metrics/registry'; describe('metricsRegistry', () => { @@ -33,9 +35,9 @@ describe('metricsRegistry', () => { expect(metricsRegistry).not.toBe(register); }); - it('contains exactly 12 metric entries', async () => { + it('contains exactly 14 metric entries', async () => { const entries = await metricsRegistry.getMetricsAsJSON(); - expect(entries).toHaveLength(12); + expect(entries).toHaveLength(14); }); // ────────────────────────────────────────────────────────────────── @@ -54,6 +56,8 @@ describe('metricsRegistry', () => { 'agentidp_rate_limit_hits_total', 'agentidp_db_pool_active_connections', 'agentidp_db_pool_waiting_requests', + 'agentidp_tenant_api_calls_total', + 'agentidp_billing_limit_rejections_total', ])('registers metric "%s"', async (metricName) => { const entries = await metricsRegistry.getMetricsAsJSON(); const names = entries.map((e) => e.name); @@ -200,4 +204,33 @@ describe('metricsRegistry', () => { expect(() => dbPoolWaitingRequests.set(2)).not.toThrow(); }); }); + + describe('tenantApiCallsTotal', () => { + it('has name agentidp_tenant_api_calls_total', () => { + const metric = tenantApiCallsTotal as unknown as { name: string }; + expect(metric.name).toBe('agentidp_tenant_api_calls_total'); + }); + + it('increments with tenant_id label without throwing', () => { + expect(() => + tenantApiCallsTotal.inc({ tenant_id: 'org-test-001' }), + ).not.toThrow(); + }); + }); + + describe('billingLimitRejectionsTotal', () => { + it('has name agentidp_billing_limit_rejections_total', () => { + const metric = billingLimitRejectionsTotal as unknown as { name: string }; + expect(metric.name).toBe('agentidp_billing_limit_rejections_total'); + }); + + it('increments with tenant_id and limit_type labels without throwing', () => { + expect(() => + billingLimitRejectionsTotal.inc({ tenant_id: 'org-test-001', limit_type: 'agent_limit' }), + ).not.toThrow(); + expect(() => + billingLimitRejectionsTotal.inc({ tenant_id: 'org-test-002', limit_type: 'api_limit' }), + ).not.toThrow(); + }); + }); }); diff --git a/tests/unit/middleware/billing.test.ts b/tests/unit/middleware/billing.test.ts new file mode 100644 index 0000000..09bcbae --- /dev/null +++ b/tests/unit/middleware/billing.test.ts @@ -0,0 +1,304 @@ +/** + * Unit tests for billing middleware: + * - FreeTierEnforcementMiddleware + * - BillingController.handleWebhook + */ + +import { Request, Response, NextFunction } from 'express'; +import { Pool, QueryResult } from 'pg'; +import { createFreeTierEnforcementMiddleware } from '../../../src/middleware/freeTierEnforcementMiddleware'; +import { BillingController } from '../../../src/controllers/BillingController'; +import { BillingService } from '../../../src/services/BillingService'; +import { UsageService } from '../../../src/services/UsageService'; +import { ITokenPayload } from '../../../src/types/index'; + +// ── Helpers ────────────────────────────────────────────────────────────────── + +function makePool(queryFn: jest.Mock): Pool { + return { query: queryFn } as unknown as Pool; +} + +type RedisClientMock = { + get: jest.Mock; + set: jest.Mock; +}; + +function makeRedis(overrides: Partial = {}): RedisClientMock { + return { + get: overrides.get ?? jest.fn().mockResolvedValue(null), + set: overrides.set ?? jest.fn().mockResolvedValue('OK'), + }; +} + +function makeRequest(overrides: Partial<{ + method: string; + path: string; + organizationId: string | undefined; +}>): Request { + const organizationId = overrides.organizationId ?? 'org-uuid-123'; + const user: ITokenPayload | undefined = organizationId !== undefined + ? { + sub: 'agent-1', + client_id: 'agent-1', + scope: 'agents:read', + jti: 'jti-1', + iat: Math.floor(Date.now() / 1000), + exp: Math.floor(Date.now() / 1000) + 3600, + organization_id: organizationId, + } + : undefined; + + return { + method: overrides.method ?? 'GET', + path: overrides.path ?? '/api/v1/agents', + user, + headers: {}, + } as unknown as Request; +} + +function makeResponse(): Response { + return { + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), + } as unknown as Response; +} + +// ════════════════════════════════════════════════════════════════════════════ +// FreeTierEnforcementMiddleware +// ════════════════════════════════════════════════════════════════════════════ + +describe('createFreeTierEnforcementMiddleware', () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + it('should call next() immediately when BILLING_ENABLED=false', () => { + process.env['BILLING_ENABLED'] = 'false'; + + const pool = makePool(jest.fn()); + const redis = makeRedis(); + const next = jest.fn() as NextFunction; + + const middleware = createFreeTierEnforcementMiddleware(pool, redis as never); + middleware(makeRequest({}), makeResponse(), next); + + // next() called synchronously because billing is disabled + expect(next).toHaveBeenCalledTimes(1); + expect((next as jest.Mock).mock.calls[0]).toHaveLength(0); + }); + + it('should call next() without error for unauthenticated request', () => { + process.env['BILLING_ENABLED'] = 'true'; + + const pool = makePool(jest.fn()); + const redis = makeRedis(); + const next = jest.fn() as NextFunction; + + const req: Request = { + method: 'GET', + path: '/api/v1/agents', + user: undefined, + headers: {}, + } as unknown as Request; + + const middleware = createFreeTierEnforcementMiddleware(pool, redis as never); + middleware(req, makeResponse(), next); + + expect(next).toHaveBeenCalledTimes(1); + expect((next as jest.Mock).mock.calls[0]).toHaveLength(0); + }); + + it('should call next(error) with FREE_TIER_API_LIMIT when daily API calls >= 1000', async () => { + process.env['BILLING_ENABLED'] = 'true'; + + // Query sequence: + // 1: isFreeTenant → no subscription row (free) + // 2: getDailyUsage → usage_events count = 1000 + // 3: getDailyUsage → agent count (getActiveAgentCount called inside getDailyUsage) + const mockQuery = jest.fn() + .mockResolvedValueOnce({ rows: [] } as unknown as QueryResult) // isFreeTenant + .mockResolvedValueOnce({ rows: [{ count: '1000' }] } as unknown as QueryResult) // usage_events + .mockResolvedValueOnce({ rows: [{ count: '5' }] } as unknown as QueryResult); // agents count + + // Cache miss → hits DB + const redis = makeRedis({ get: jest.fn().mockResolvedValue(null) }); + + const nextCalled = new Promise((resolve) => { + const next = ((err?: unknown) => { resolve(err); }) as NextFunction; + const middleware = createFreeTierEnforcementMiddleware(makePool(mockQuery), redis as never); + middleware(makeRequest({ method: 'GET', path: '/api/v1/agents' }), makeResponse(), next); + }); + + const callArg = await nextCalled; + expect(callArg).toBeDefined(); + expect((callArg as { code: string }).code).toBe('FREE_TIER_API_LIMIT'); + }); + + it('should call next(error) with FREE_TIER_AGENT_LIMIT when agent count >= 10 on POST /agents', async () => { + process.env['BILLING_ENABLED'] = 'true'; + + // Query sequence: + // 1: isFreeTenant → free + // 2: usage_events → 500 calls (below limit, so continue) + // 3: agents count for getDailyUsage + // 4: agents count for isAgentCreation check + const mockQuery = jest.fn() + .mockResolvedValueOnce({ rows: [] } as unknown as QueryResult) // isFreeTenant + .mockResolvedValueOnce({ rows: [{ count: '500' }] } as unknown as QueryResult) // usage_events + .mockResolvedValueOnce({ rows: [{ count: '3' }] } as unknown as QueryResult) // getDailyUsage agent count + .mockResolvedValueOnce({ rows: [{ count: '10' }] } as unknown as QueryResult); // isAgentCreation check + + const redis = makeRedis({ get: jest.fn().mockResolvedValue(null) }); + + const nextCalled = new Promise((resolve) => { + const next = ((err?: unknown) => { resolve(err); }) as NextFunction; + const middleware = createFreeTierEnforcementMiddleware(makePool(mockQuery), redis as never); + middleware(makeRequest({ method: 'POST', path: '/agents' }), makeResponse(), next); + }); + + const callArg = await nextCalled; + expect(callArg).toBeDefined(); + expect((callArg as { code: string }).code).toBe('FREE_TIER_AGENT_LIMIT'); + }); + + it('should call next() without error for paid tenant regardless of limits', async () => { + process.env['BILLING_ENABLED'] = 'true'; + + // Active subscription → paid + const mockQuery = jest.fn().mockResolvedValue({ + rows: [{ status: 'active' }], + } as unknown as QueryResult); + + const redis = makeRedis(); + + const nextCalled = new Promise((resolve) => { + const next = ((err?: unknown) => { resolve(err); }) as NextFunction; + const middleware = createFreeTierEnforcementMiddleware(makePool(mockQuery), redis as never); + middleware(makeRequest({ method: 'POST', path: '/agents' }), makeResponse(), next); + }); + + const callArg = await nextCalled; + // next() called with no error + expect(callArg).toBeUndefined(); + }); + + it('should use Redis cache and skip DB usage query on cache hit', async () => { + process.env['BILLING_ENABLED'] = 'true'; + + // Only 1 DB query: isFreeTenant (no subscription) + // The second query for usage is replaced by a Redis cache hit + const mockQuery = jest.fn() + .mockResolvedValueOnce({ rows: [] } as unknown as QueryResult); // isFreeTenant + + // Cache returns api_calls = 100 (below 1000 limit) + const redis = makeRedis({ get: jest.fn().mockResolvedValue('100') }); + + const nextCalled = new Promise((resolve) => { + const next = ((err?: unknown) => { resolve(err); }) as NextFunction; + const middleware = createFreeTierEnforcementMiddleware(makePool(mockQuery), redis as never); + middleware(makeRequest({ method: 'GET', path: '/api/v1/agents' }), makeResponse(), next); + }); + + const callArg = await nextCalled; + // Only 1 query (isFreeTenant); no DB call for usage + expect(mockQuery).toHaveBeenCalledTimes(1); + expect(callArg).toBeUndefined(); + }); +}); + +// ════════════════════════════════════════════════════════════════════════════ +// BillingController.handleWebhook +// ════════════════════════════════════════════════════════════════════════════ + +describe('BillingController.handleWebhook', () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv, STRIPE_WEBHOOK_SECRET: 'whsec_test' }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + function makeBillingController( + handleWebhookEventFn: jest.Mock = jest.fn().mockResolvedValue(undefined), + ): BillingController { + const billingService = { + handleWebhookEvent: handleWebhookEventFn, + createCheckoutSession: jest.fn(), + getSubscriptionStatus: jest.fn().mockResolvedValue({ + tenantId: 'test', + status: 'free', + currentPeriodEnd: null, + stripeSubscriptionId: null, + }), + } as unknown as BillingService; + + const usageService = { + getDailyUsage: jest.fn().mockResolvedValue({ + tenantId: 'test', + date: '2026-04-02', + apiCalls: 0, + agentCount: 0, + }), + getActiveAgentCount: jest.fn().mockResolvedValue(0), + } as unknown as UsageService; + + return new BillingController(billingService, usageService); + } + + it('should return 200 { received: true } for valid Stripe-Signature', async () => { + const controller = makeBillingController(); + const req = { + headers: { 'stripe-signature': 'valid-sig' }, + body: Buffer.from('{}'), + } as unknown as Request; + const res = makeResponse(); + const next = jest.fn() as NextFunction; + + await controller.handleWebhook(req, res, next); + + expect(res.status).toHaveBeenCalledWith(200); + expect(res.json).toHaveBeenCalledWith({ received: true }); + expect(next).not.toHaveBeenCalled(); + }); + + it('should call next() with ValidationError when Stripe-Signature header is missing', async () => { + const controller = makeBillingController(); + const req = { + headers: {}, + body: Buffer.from('{}'), + } as unknown as Request; + const res = makeResponse(); + const next = jest.fn() as NextFunction; + + await controller.handleWebhook(req, res, next); + + expect(next).toHaveBeenCalledTimes(1); + const callArg = (next as jest.Mock).mock.calls[0]?.[0]; + expect((callArg as { code: string }).code).toBe('VALIDATION_ERROR'); + }); + + it('should call next() with the error when BillingService throws', async () => { + const stripeError = new Error('Invalid signature'); + const controller = makeBillingController(jest.fn().mockRejectedValue(stripeError)); + + const req = { + headers: { 'stripe-signature': 'bad-sig' }, + body: Buffer.from('{}'), + } as unknown as Request; + const res = makeResponse(); + const next = jest.fn() as NextFunction; + + await controller.handleWebhook(req, res, next); + + expect(next).toHaveBeenCalledWith(stripeError); + }); +}); diff --git a/tests/unit/services/BillingService.test.ts b/tests/unit/services/BillingService.test.ts new file mode 100644 index 0000000..b3c49b2 --- /dev/null +++ b/tests/unit/services/BillingService.test.ts @@ -0,0 +1,262 @@ +/** + * Unit tests for src/services/BillingService.ts and src/services/UsageService.ts + */ + +import { Pool, QueryResult } from 'pg'; +import Stripe from 'stripe'; +import { BillingService } from '../../../src/services/BillingService'; +import { UsageService } from '../../../src/services/UsageService'; + +// ── Mock pg Pool ───────────────────────────────────────────────────────────── + +function makePool(queryFn: jest.Mock): Pool { + return { query: queryFn } as unknown as Pool; +} + +// ── Mock Stripe ─────────────────────────────────────────────────────────────── + +function makeStripe(overrides: Partial<{ + checkoutUrl: string; + constructEvent: () => Stripe.Event; + retrieveCustomer: () => Stripe.Customer; +}> = {}): Stripe { + const checkoutUrl = overrides.checkoutUrl ?? 'https://checkout.stripe.com/session_test'; + + const fakeSubscription: Stripe.Subscription = { + id: 'sub_1', + object: 'subscription', + status: 'active', + customer: 'cus_1', + billing_cycle_anchor: Math.floor(Date.now() / 1000) + 86400, + ended_at: null, + } as unknown as Stripe.Subscription; + + const fakeEvent: Stripe.Event = { + id: 'evt_1', + type: 'customer.subscription.created', + object: 'event', + data: { object: fakeSubscription }, + api_version: '2026-03-25.dahlia', + created: Math.floor(Date.now() / 1000), + livemode: false, + pending_webhooks: 0, + request: null, + } as unknown as Stripe.Event; + + const fakeCustomer: Stripe.Customer = { + id: 'cus_1', + object: 'customer', + created: Math.floor(Date.now() / 1000), + livemode: false, + metadata: { tenant_id: 'tenant-uuid-123' }, + deleted: undefined, + } as unknown as Stripe.Customer; + + return { + checkout: { + sessions: { + create: jest.fn().mockResolvedValue({ url: checkoutUrl, id: 'cs_1' }), + }, + }, + webhooks: { + constructEvent: overrides.constructEvent + ? jest.fn(overrides.constructEvent) + : jest.fn().mockReturnValue(fakeEvent), + }, + customers: { + retrieve: overrides.retrieveCustomer + ? jest.fn(overrides.retrieveCustomer) + : jest.fn().mockResolvedValue(fakeCustomer), + }, + } as unknown as Stripe; +} + +// ════════════════════════════════════════════════════════════════════════════ +// UsageService +// ════════════════════════════════════════════════════════════════════════════ + +describe('UsageService', () => { + describe('getDailyUsage()', () => { + it('should return correct IUsageSummary with api_calls from DB', async () => { + const mockQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ count: '42' }] } as unknown as QueryResult) + .mockResolvedValueOnce({ rows: [{ count: '5' }] } as unknown as QueryResult); + + const service = new UsageService(makePool(mockQuery)); + const result = await service.getDailyUsage('tenant-1', '2026-04-02'); + + expect(result).toEqual({ + tenantId: 'tenant-1', + date: '2026-04-02', + apiCalls: 42, + agentCount: 5, + }); + }); + + it('should default apiCalls to 0 when no usage_events row exists', async () => { + const mockQuery = jest.fn() + .mockResolvedValueOnce({ rows: [{ count: '0' }] } as unknown as QueryResult) + .mockResolvedValueOnce({ rows: [{ count: '2' }] } as unknown as QueryResult); + + const service = new UsageService(makePool(mockQuery)); + const result = await service.getDailyUsage('tenant-2', '2026-04-02'); + + expect(result.apiCalls).toBe(0); + expect(result.agentCount).toBe(2); + }); + }); + + describe('getActiveAgentCount()', () => { + it('should return the count of non-decommissioned agents', async () => { + const mockQuery = jest.fn().mockResolvedValue({ + rows: [{ count: '7' }], + } as unknown as QueryResult); + + const service = new UsageService(makePool(mockQuery)); + const count = await service.getActiveAgentCount('tenant-1'); + + expect(count).toBe(7); + expect(mockQuery).toHaveBeenCalledWith( + expect.stringContaining("status != 'decommissioned'"), + ['tenant-1'], + ); + }); + + it('should return 0 when no agents exist', async () => { + const mockQuery = jest.fn().mockResolvedValue({ + rows: [{ count: '0' }], + } as unknown as QueryResult); + + const service = new UsageService(makePool(mockQuery)); + const count = await service.getActiveAgentCount('tenant-empty'); + + expect(count).toBe(0); + }); + }); +}); + +// ════════════════════════════════════════════════════════════════════════════ +// BillingService +// ════════════════════════════════════════════════════════════════════════════ + +describe('BillingService', () => { + describe('createCheckoutSession()', () => { + it('should return the Stripe checkout URL', async () => { + const mockQuery = jest.fn(); + const stripe = makeStripe({ checkoutUrl: 'https://checkout.stripe.com/test' }); + const service = new BillingService(makePool(mockQuery), stripe); + + const url = await service.createCheckoutSession( + 'tenant-1', + 'https://app.com/success', + 'https://app.com/cancel', + ); + + expect(url).toBe('https://checkout.stripe.com/test'); + }); + }); + + describe('getSubscriptionStatus()', () => { + it('should return free status when no subscription row exists', async () => { + const mockQuery = jest.fn().mockResolvedValue({ rows: [] } as unknown as QueryResult); + const stripe = makeStripe(); + const service = new BillingService(makePool(mockQuery), stripe); + + const status = await service.getSubscriptionStatus('tenant-free'); + + expect(status).toEqual({ + tenantId: 'tenant-free', + status: 'free', + currentPeriodEnd: null, + stripeSubscriptionId: null, + }); + }); + + it('should return subscription data when a row exists', async () => { + const periodEnd = new Date('2026-05-01T00:00:00Z'); + const mockQuery = jest.fn().mockResolvedValue({ + rows: [{ + status: 'active', + current_period_end: periodEnd, + stripe_subscription_id: 'sub_abc123', + }], + } as unknown as QueryResult); + + const stripe = makeStripe(); + const service = new BillingService(makePool(mockQuery), stripe); + + const status = await service.getSubscriptionStatus('tenant-paid'); + + expect(status.status).toBe('active'); + expect(status.stripeSubscriptionId).toBe('sub_abc123'); + expect(status.currentPeriodEnd).toEqual(periodEnd); + }); + }); + + describe('handleWebhookEvent()', () => { + it('should process customer.subscription.created and upsert DB', async () => { + const mockQuery = jest.fn().mockResolvedValue({ rows: [] } as unknown as QueryResult); + const stripe = makeStripe(); + const service = new BillingService(makePool(mockQuery), stripe); + + await service.handleWebhookEvent( + Buffer.from('{}'), + 'stripe-sig-header', + 'whsec_test', + ); + + // constructEvent should have been called with raw body, sig, and secret + expect(stripe.webhooks.constructEvent).toHaveBeenCalledWith( + expect.any(Buffer), + 'stripe-sig-header', + 'whsec_test', + ); + + // DB upsert should have been called + expect(mockQuery).toHaveBeenCalledWith( + expect.stringContaining('INSERT INTO tenant_subscriptions'), + expect.any(Array), + ); + }); + + it('should throw when Stripe signature verification fails', async () => { + const mockQuery = jest.fn(); + const sigError = new Error('Stripe signature verification failed'); + const stripe = makeStripe({ + constructEvent: () => { throw sigError; }, + }); + + const service = new BillingService(makePool(mockQuery), stripe); + + await expect( + service.handleWebhookEvent(Buffer.from('{}'), 'bad-sig', 'whsec_test'), + ).rejects.toThrow('Stripe signature verification failed'); + }); + + it('should skip processing for unrecognised event types', async () => { + const mockQuery = jest.fn(); + + const unknownEvent: unknown = { + id: 'evt_unknown', + type: 'payment_intent.created', + object: 'event', + data: { object: {} }, + api_version: '2026-03-25.dahlia', + created: Math.floor(Date.now() / 1000), + livemode: false, + pending_webhooks: 0, + request: null, + }; + + const stripe = makeStripe({ + constructEvent: () => unknownEvent as Stripe.Event, + }); + + const service = new BillingService(makePool(mockQuery), stripe); + await service.handleWebhookEvent(Buffer.from('{}'), 'sig', 'whsec_test'); + + // No DB query for unrecognised events + expect(mockQuery).not.toHaveBeenCalled(); + }); + }); +});