import logging
from typing import Any, Callable, Optional

import stripe
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.database import get_db
from app.core.security import get_current_user
from app.models.user import User

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/billing", tags=["billing"])

stripe.api_key = settings.STRIPE_SECRET_KEY

# Subscription statuses that grant Pro access.
_PRO_STATUSES = ("active", "trialing")

# Webhook event types this module acts on; everything else is acknowledged
# without touching the database.
_HANDLED_EVENTS = frozenset(
    {
        "checkout.session.completed",
        "customer.subscription.deleted",
        "customer.subscription.updated",
        "invoice.payment_failed",
    }
)


async def _call_stripe(func: Callable[..., Any], **kwargs: Any) -> Any:
    """Run a blocking Stripe SDK call in a worker thread.

    The Stripe SDK calls used here are synchronous network requests; running
    them directly inside an ``async def`` endpoint would stall the event loop.

    Raises:
        HTTPException: 502 if Stripe reports an error.
    """
    try:
        return await run_in_threadpool(func, **kwargs)
    except stripe.StripeError as err:
        logger.exception(
            "Stripe API call %s failed", getattr(func, "__qualname__", func)
        )
        raise HTTPException(
            status_code=502, detail="Payment provider error"
        ) from err


async def _find_user_by_customer(
    db: AsyncSession, customer_id: Optional[str]
) -> Optional[User]:
    """Return the user whose ``stripe_customer_id`` matches, or ``None``."""
    if not customer_id:
        return None
    result = await db.execute(
        select(User).where(User.stripe_customer_id == customer_id)
    )
    return result.scalar_one_or_none()


def _is_current_subscription(user: User, subscription_id: Optional[str]) -> bool:
    """Return False only when an event clearly targets a different subscription.

    If either the recorded id or the event's id is unknown, the event is
    treated as applying to the user (the previous, unguarded behavior).
    """
    recorded = user.stripe_subscription_id
    return recorded is None or subscription_id is None or recorded == subscription_id


def _current_period_end(sub: Any) -> Optional[int]:
    """Return the subscription's current period end timestamp, if available.

    Older Stripe API versions put ``current_period_end`` on the Subscription
    itself; newer ones put it on each subscription item instead
    (NOTE(review): believed to be API version 2025-03-31 "basil" onward —
    confirm against the account's pinned API version). Reading it via ``.get``
    avoids an ``AttributeError`` when the field is absent.
    """
    period_end = sub.get("current_period_end")
    if period_end is not None:
        return period_end
    # ``sub.items`` would hit dict.items, so index the key explicitly.
    items = (sub.get("items") or {}).get("data") or []
    return items[0].get("current_period_end") if items else None


@router.post("/create-checkout")
async def create_checkout(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start a Stripe Checkout session for the Pro subscription.

    Creates a Stripe customer for the user on first use and stores its id on
    the user row (flushed only; committing is left to the ``get_db``
    dependency — presumably at request end, confirm).

    Returns:
        ``{"url": <hosted Checkout URL>}``.

    Raises:
        HTTPException: 400 if the user is already Pro; 502 if Stripe fails.
    """
    if user.is_pro:
        raise HTTPException(status_code=400, detail="Already subscribed to Pro")

    # Create Stripe customer if needed
    if not user.stripe_customer_id:
        customer = await _call_stripe(
            stripe.Customer.create,
            email=user.email,
            name=user.name,
            metadata={"vynl_user_id": str(user.id)},
        )
        user.stripe_customer_id = customer.id
        await db.flush()

    session = await _call_stripe(
        stripe.checkout.Session.create,
        customer=user.stripe_customer_id,
        mode="subscription",
        line_items=[{"price": settings.STRIPE_PRICE_ID, "quantity": 1}],
        success_url=f"{settings.FRONTEND_URL}/billing?success=true",
        cancel_url=f"{settings.FRONTEND_URL}/billing?canceled=true",
        metadata={"vynl_user_id": str(user.id)},
    )
    return {"url": session.url}


@router.post("/webhook")
async def stripe_webhook(
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Apply Stripe subscription lifecycle events to the matching user.

    The request signature is verified with ``STRIPE_WEBHOOK_SECRET``. The user
    is located by the event object's ``customer`` id. Unknown event types and
    events with no matching user are acknowledged with ``{"status": "ok"}``.

    Raises:
        HTTPException: 400 if the payload or signature is invalid.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature", "")

    try:
        event = stripe.Webhook.construct_event(
            payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
        )
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid payload") from None
    except stripe.SignatureVerificationError:
        raise HTTPException(status_code=400, detail="Invalid signature") from None

    event_type = event["type"]
    data = event["data"]["object"]

    if event_type not in _HANDLED_EVENTS:
        return {"status": "ok"}

    user = await _find_user_by_customer(db, data.get("customer"))
    if user is None:
        logger.warning(
            "Stripe event %s for unknown customer %r ignored",
            event_type,
            data.get("customer"),
        )
        return {"status": "ok"}

    if event_type == "checkout.session.completed":
        user.is_pro = True
        user.stripe_subscription_id = data.get("subscription")

    elif event_type == "customer.subscription.deleted":
        # Stripe may deliver events late/out of order; a deletion of some
        # older subscription must not revoke the one we currently track.
        if _is_current_subscription(user, data.get("id")):
            user.is_pro = False
            user.stripe_subscription_id = None

    elif event_type == "customer.subscription.updated":
        sub_id = data.get("id")
        is_active = data.get("status") in _PRO_STATUSES
        # Adopt any subscription reported active; ignore non-active updates
        # for a subscription other than the one on record.
        if is_active or _is_current_subscription(user, sub_id):
            user.is_pro = is_active
            user.stripe_subscription_id = sub_id

    elif event_type == "invoice.payment_failed":
        # Revokes Pro on the first failed payment (existing product
        # behavior). A later successful retry is expected to restore access
        # via customer.subscription.updated.
        if _is_current_subscription(user, data.get("subscription")):
            user.is_pro = False

    await db.flush()
    return {"status": "ok"}


@router.post("/portal")
async def create_portal(
    user: User = Depends(get_current_user),
):
    """Create a Stripe Billing Portal session for the user.

    Returns:
        ``{"url": <portal URL>}``.

    Raises:
        HTTPException: 400 if the user has no Stripe customer; 502 if Stripe
            fails.
    """
    if not user.stripe_customer_id:
        raise HTTPException(status_code=400, detail="No billing account found")

    session = await _call_stripe(
        stripe.billing_portal.Session.create,
        customer=user.stripe_customer_id,
        return_url=f"{settings.FRONTEND_URL}/billing",
    )
    return {"url": session.url}


@router.get("/status")
async def billing_status(
    user: User = Depends(get_current_user),
):
    """Report the user's Pro flag plus live subscription details from Stripe.

    The Stripe lookup is best-effort. On a Stripe error, ``subscription_status``
    and ``current_period_end`` are returned as ``None`` and a warning is
    logged.
    """
    subscription_status = None
    current_period_end = None

    if user.stripe_subscription_id:
        try:
            sub = await run_in_threadpool(
                stripe.Subscription.retrieve, user.stripe_subscription_id
            )
        except stripe.StripeError:
            logger.warning(
                "Could not retrieve Stripe subscription %s",
                user.stripe_subscription_id,
                exc_info=True,
            )
        else:
            subscription_status = sub.status
            current_period_end = _current_period_end(sub)

    return {
        "is_pro": user.is_pro,
        "subscription_status": subscription_status,
        "current_period_end": current_period_end,
    }