- Add stripe_customer_id and stripe_subscription_id fields to the User model
- Add Stripe config settings (secret key, publishable key, price ID, webhook secret)
- Create billing API endpoints: checkout session, webhook handler, portal, status
- Add frontend Billing page with upgrade/manage-subscription UI
- Add billing route and Pro nav link
- Add stripe dependency to requirements
153 lines
4.8 KiB
Python
153 lines
4.8 KiB
Python
import logging

import stripe
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.database import get_db
from app.core.security import get_current_user
from app.models.user import User
|
|
|
|
# Global Stripe credential: no stripe.* call in this module passes api_key
# explicitly, so they all authenticate with this key.
stripe.api_key = settings.STRIPE_SECRET_KEY

# Every endpoint below is registered on this router under the /billing prefix
# and grouped under the "billing" tag.
router = APIRouter(tags=["billing"], prefix="/billing")
|
|
|
|
|
|
@router.post("/create-checkout")
async def create_checkout(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start a Stripe Checkout session for the Pro subscription.

    On the user's first checkout a Stripe customer is created and its id is
    stored on ``user.stripe_customer_id``. A subscription-mode Checkout session
    for ``settings.STRIPE_PRICE_ID`` is then opened for that customer.

    Returns:
        ``{"url": <hosted Checkout page URL>}`` for the frontend to redirect to.

    Raises:
        HTTPException(400): the user is already Pro.
    """
    if user.is_pro:
        raise HTTPException(status_code=400, detail="Already subscribed to Pro")

    # stripe-python performs blocking HTTP requests. Running them in the
    # threadpool keeps this coroutine from stalling the event loop, and with
    # it every other in-flight request.
    if not user.stripe_customer_id:
        customer = await run_in_threadpool(
            stripe.Customer.create,
            email=user.email,
            name=user.name,
            metadata={"vynl_user_id": str(user.id)},
        )
        user.stripe_customer_id = customer.id
        # Flushed, not committed here.
        # NOTE(review): presumably get_db commits at request end; confirm.
        await db.flush()

    session = await run_in_threadpool(
        stripe.checkout.Session.create,
        customer=user.stripe_customer_id,
        mode="subscription",
        line_items=[{"price": settings.STRIPE_PRICE_ID, "quantity": 1}],
        success_url=f"{settings.FRONTEND_URL}/billing?success=true",
        cancel_url=f"{settings.FRONTEND_URL}/billing?canceled=true",
        metadata={"vynl_user_id": str(user.id)},
    )

    return {"url": session.url}
|
|
|
|
|
|
async def _user_for_customer(db: AsyncSession, customer_id) -> "User | None":
    """Return the user whose ``stripe_customer_id`` equals *customer_id*, or None.

    A falsy *customer_id* short-circuits to None without querying.
    """
    if not customer_id:
        return None
    result = await db.execute(
        select(User).where(User.stripe_customer_id == customer_id)
    )
    return result.scalar_one_or_none()


def _is_tracked_subscription(user: User, subscription_id) -> bool:
    """Return True if *subscription_id* is the user's stored subscription, or none is stored."""
    tracked = user.stripe_subscription_id
    return tracked is None or tracked == subscription_id


@router.post("/webhook")
async def stripe_webhook(
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Receive Stripe webhook events and sync the user's Pro status.

    The raw body is verified against ``settings.STRIPE_WEBHOOK_SECRET``
    before any of it is trusted. Handled event types:

    * ``checkout.session.completed``: grant Pro and store the subscription id.
    * ``customer.subscription.updated``: Pro iff status is ``active``/``trialing``.
    * ``customer.subscription.deleted``: revoke Pro and clear the subscription id.
    * ``invoice.payment_failed``: revoke Pro.

    Subscription update and delete events are applied only when they refer to
    the subscription stored on the user (or when none is stored). This stops
    events for an older, superseded subscription from revoking access that a
    newer one granted. Other event types are acknowledged without changes.

    Returns:
        ``{"status": "ok"}`` once the event has been processed.

    Raises:
        HTTPException(400): the payload is malformed or the signature is invalid.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature", "")

    try:
        event = stripe.Webhook.construct_event(
            payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
        )
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid payload")
    except stripe.SignatureVerificationError:
        raise HTTPException(status_code=400, detail="Invalid signature")

    event_type = event["type"]
    data = event["data"]["object"]

    # Changes below are flushed, not committed.
    # NOTE(review): presumably get_db commits at request end; confirm.
    if event_type == "checkout.session.completed":
        user = await _user_for_customer(db, data.get("customer"))
        if user:
            user.is_pro = True
            user.stripe_subscription_id = data.get("subscription")
            await db.flush()

    elif event_type == "customer.subscription.deleted":
        user = await _user_for_customer(db, data.get("customer"))
        if user and _is_tracked_subscription(user, data.get("id")):
            user.is_pro = False
            user.stripe_subscription_id = None
            await db.flush()

    elif event_type == "customer.subscription.updated":
        user = await _user_for_customer(db, data.get("customer"))
        if user and _is_tracked_subscription(user, data.get("id")):
            user.is_pro = data.get("status") in ("active", "trialing")
            user.stripe_subscription_id = data.get("id")
            await db.flush()

    elif event_type == "invoice.payment_failed":
        # NOTE(review): this revokes Pro on any failed invoice for the customer.
        # Within this module, access comes back only through a later
        # subscription.updated (active/trialing) or a new checkout.
        user = await _user_for_customer(db, data.get("customer"))
        if user:
            user.is_pro = False
            await db.flush()

    return {"status": "ok"}
|
|
|
|
|
|
@router.post("/portal")
async def create_portal(
    user: User = Depends(get_current_user),
):
    """Open a Stripe Billing Portal session where the user can manage billing.

    Returns:
        ``{"url": <portal session URL>}``. The portal sends the user back to
        ``{FRONTEND_URL}/billing``.

    Raises:
        HTTPException(400): the user has no Stripe customer id yet.
    """
    if not user.stripe_customer_id:
        raise HTTPException(status_code=400, detail="No billing account found")

    # Blocking Stripe HTTP call; run it in the threadpool so the event loop
    # stays free for other requests.
    session = await run_in_threadpool(
        stripe.billing_portal.Session.create,
        customer=user.stripe_customer_id,
        return_url=f"{settings.FRONTEND_URL}/billing",
    )

    return {"url": session.url}
|
|
|
|
|
|
@router.get("/status")
async def billing_status(
    user: User = Depends(get_current_user),
):
    """Report the user's Pro flag plus live subscription details from Stripe.

    ``is_pro`` always comes from the database. ``subscription_status`` and
    ``current_period_end`` are fetched from Stripe when a subscription id is
    stored, and are ``None`` otherwise. They are also ``None`` when Stripe
    cannot be reached; that lookup is best-effort.
    """
    subscription_status = None
    current_period_end = None

    if user.stripe_subscription_id:
        try:
            # Blocking Stripe HTTP call; keep it off the event loop.
            sub = await run_in_threadpool(
                stripe.Subscription.retrieve, user.stripe_subscription_id
            )
        except stripe.StripeError:
            # Best-effort: still answer with the DB flag, but leave a trace
            # rather than silently discarding the failure.
            logging.getLogger(__name__).warning(
                "Could not retrieve Stripe subscription %s for user %s",
                user.stripe_subscription_id,
                user.id,
                exc_info=True,
            )
        else:
            subscription_status = sub.status
            # On recent Stripe API versions this field may live on the
            # subscription items instead. Attribute access on a missing key
            # raises AttributeError, so default to None rather than fail.
            current_period_end = getattr(sub, "current_period_end", None)

    return {
        "is_pro": user.is_pro,
        "subscription_status": subscription_status,
        "current_period_end": current_period_end,
    }
|