diff --git a/backend/app/api/endpoints/youtube_music.py b/backend/app/api/endpoints/youtube_music.py index 66deeb3..e433c9b 100644 --- a/backend/app/api/endpoints/youtube_music.py +++ b/backend/app/api/endpoints/youtube_music.py @@ -1,14 +1,14 @@ import asyncio -from functools import partial from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings -from app.core.database import get_db -from app.core.security import get_current_user +from app.core.security import ALGORITHM +from app.core.database import async_session from app.models.user import User from app.models.playlist import Playlist from app.models.track import Track @@ -17,6 +17,7 @@ from app.services.recommender import build_taste_profile from app.schemas.playlist import PlaylistDetailResponse router = APIRouter(prefix="/youtube-music", tags=["youtube-music"]) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login") class ImportYouTubeRequest(BaseModel): @@ -35,29 +36,29 @@ class YouTubeTrackResult(BaseModel): image_url: str | None = None +def _get_user_id_from_token(token: str) -> int: + """Extract user ID from JWT without hitting the database.""" + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + user_id = payload.get("sub") + if user_id is None: + raise HTTPException(status_code=401, detail="Invalid credentials") + return int(user_id) + except JWTError: + raise HTTPException(status_code=401, detail="Invalid credentials") + + @router.post("/import", response_model=PlaylistDetailResponse) async def import_youtube_playlist( data: ImportYouTubeRequest, - user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db), + token: str = Depends(oauth2_scheme), ): - # Free tier limit - if not user.is_pro: - result = await db.execute( - select(Playlist).where(Playlist.user_id == user.id) - ) - existing = list(result.scalars().all()) - if len(existing) >= settings.FREE_MAX_PLAYLISTS: - raise HTTPException( - status_code=403, - detail="Free tier limited to 1 playlist. Upgrade to Pro for unlimited.", - ) + user_id = _get_user_id_from_token(token) - # Fetch tracks from YouTube Music (run sync ytmusicapi in thread) - loop = asyncio.get_event_loop() + # Run sync ytmusicapi in a thread — NO DB connection open during this try: - playlist_name, playlist_image, raw_tracks = await loop.run_in_executor( - None, partial(get_playlist_tracks, data.url) + playlist_name, playlist_image, raw_tracks = await asyncio.to_thread( + get_playlist_tracks, data.url ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -67,52 +68,85 @@ async def import_youtube_playlist( if not raw_tracks: raise HTTPException(status_code=400, detail="Playlist is empty or could not be read.") - # Create playlist - playlist = Playlist( - user_id=user.id, - name=playlist_name, - platform_source="youtube_music", - external_id=data.url, - track_count=len(raw_tracks), - ) - db.add(playlist) - await db.flush() + # Now do all DB work in a fresh session + async with async_session() as db: + try: + # Verify user exists + result = await db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=401, detail="User not found") - # Create tracks (no audio features available from YouTube Music) - tracks = [] - for rt in raw_tracks: - track = Track( - playlist_id=playlist.id, - title=rt["title"], - artist=rt["artist"], - album=rt.get("album"), - image_url=rt.get("image_url"), - ) - db.add(track) - tracks.append(track) + # Free tier limit + if not user.is_pro: + result = await db.execute( + select(Playlist).where(Playlist.user_id == user.id) + ) + existing = list(result.scalars().all()) + if len(existing) >= settings.FREE_MAX_PLAYLISTS: + raise HTTPException( + status_code=403, + detail="Free tier limited to 1 playlist. Upgrade to Premium for unlimited.", + ) - await db.flush() + playlist = Playlist( + user_id=user.id, + name=playlist_name, + platform_source="youtube_music", + external_id=data.url, + track_count=len(raw_tracks), + ) + db.add(playlist) + await db.flush() - # Build taste profile (without audio features, will be limited) - playlist.taste_profile = build_taste_profile(tracks) - playlist.tracks = tracks + tracks = [] + for rt in raw_tracks: + track = Track( + playlist_id=playlist.id, + title=rt["title"], + artist=rt["artist"], + album=rt.get("album"), + image_url=rt.get("image_url"), + ) + db.add(track) + tracks.append(track) - return playlist + await db.flush() + + playlist.taste_profile = build_taste_profile(tracks) + + await db.commit() + + # Return response manually to avoid lazy-load issues + return PlaylistDetailResponse( + id=playlist.id, + name=playlist.name, + platform_source=playlist.platform_source, + track_count=playlist.track_count, + taste_profile=playlist.taste_profile, + imported_at=playlist.imported_at, + tracks=[], + ) + except HTTPException: + await db.rollback() + raise + except Exception: + await db.rollback() + raise @router.post("/search", response_model=list[YouTubeTrackResult]) async def search_youtube_music( data: SearchYouTubeRequest, - user: User = Depends(get_current_user), + token: str = Depends(oauth2_scheme), ): + _get_user_id_from_token(token) # Just verify auth + if not data.query.strip(): raise HTTPException(status_code=400, detail="Query cannot be empty") - loop = asyncio.get_event_loop() try: - results = await loop.run_in_executor( - None, partial(search_track, data.query.strip()) - ) + results = await asyncio.to_thread(search_track, data.query.strip()) except Exception: raise HTTPException(status_code=500, detail="Failed to search YouTube Music")