"""Async SQLAlchemy database setup for PostgreSQL.""" import logging import time from typing import AsyncGenerator from sqlalchemy import event from sqlalchemy.engine import Connection as SyncConnection # Opaque DBAPI-specific types for SQLAlchemy cursor-execute event parameters. # These are driver-dependent and not meaningfully typed by SQLAlchemy stubs. type _DBAPICursor = str | bytes | int | None type _ExecParams = tuple[str | int | float | None, ...] | dict[str, str | int | float | None] | None type _ExecContext = str | int | None from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine, ) from sqlalchemy.orm import DeclarativeBase from musehub.config import settings as settings logger = logging.getLogger(__name__) class Base(DeclarativeBase): """Base class for SQLAlchemy models.""" pass # Global engine and session factory (initialized on startup) _engine: AsyncEngine | None = None _async_session_factory: async_sessionmaker[AsyncSession] | None = None def get_database_url() -> str: """Get the database URL from settings.""" url = settings.database_url if not url: raise RuntimeError( "DATABASE_URL must be set. " "Example: postgresql+asyncpg://musehub:musehub@localhost:5432/musehub" ) return url async def init_db() -> None: """Initialize database engine and session factory. Schema is managed by Alembic (``alembic upgrade head`` runs in the container entrypoint *before* the app starts). This function only creates the async engine and session factory. """ global _engine, _async_session_factory database_url = get_database_url() logger.info(f"Initializing database: {database_url.split('@')[-1] if '@' in database_url else database_url}") # pool_size=20/max_overflow=40: handles ~20 concurrent requests without # queueing, with headroom for bursty agent traffic. _engine = create_async_engine( database_url, echo=settings.debug, pool_size=20, max_overflow=40, pool_recycle=1800, pool_timeout=settings.db_pool_timeout, # Verify connections before checkout. If a connection was returned to # the pool in a broken transaction state (e.g. after an asyncpg-level # error during autobegin), pre_ping discards it and issues a fresh one # rather than propagating the "cannot use Connection.transaction() in a # manually started transaction" error to the next request. pool_pre_ping=True, # TCP keepalive every 30s — prevents asyncpg idle-connection reaping # during long-running push/fetch operations. connect_args={ "server_settings": { "tcp_keepalives_idle": "30", "statement_timeout": "60000", # 60s — catches runaway queries }, }, ) # Slow query log: warn on any statement exceeding the configured threshold. _slow_ms = settings.slow_query_threshold_ms if _slow_ms > 0: @event.listens_for(_engine.sync_engine, "before_cursor_execute") def _before_cursor_execute( conn: SyncConnection, cursor: _DBAPICursor, statement: str, parameters: _ExecParams, context: _ExecContext, executemany: bool ) -> None: conn.info["query_start_time"] = time.monotonic() @event.listens_for(_engine.sync_engine, "after_cursor_execute") def _after_cursor_execute( conn: SyncConnection, cursor: _DBAPICursor, statement: str, parameters: _ExecParams, context: _ExecContext, executemany: bool ) -> None: elapsed_ms = (time.monotonic() - conn.info.get("query_start_time", 0)) * 1000 if elapsed_ms >= _slow_ms: logger.warning( "SLOW QUERY (%.1f ms): %.200s", elapsed_ms, statement.replace("\n", " "), ) _async_session_factory = async_sessionmaker( bind=_engine, class_=AsyncSession, expire_on_commit=False, ) # Import models so relationships resolve even though Alembic owns DDL. from musehub.db import muse_cli_models # noqa: F401 from musehub.db.database import Base # noqa: PLC0415 from musehub.db.schema_check import assert_schema_matches_orm # noqa: PLC0415 await assert_schema_matches_orm(_engine, Base) logger.info("✅ Database initialised — ORM and schema are in sync") def get_engine() -> AsyncEngine: """Return the module-level engine. Raises if init_db() has not been called.""" if _engine is None: raise RuntimeError("Database not initialized. Call init_db() first.") return _engine async def close_db() -> None: """Close database connection.""" global _engine, _async_session_factory if _engine: await _engine.dispose() _engine = None _async_session_factory = None logger.info("Database connection closed") async def get_db() -> AsyncGenerator[AsyncSession, None]: """ Dependency for getting async database sessions. Usage: @app.get("/users") async def get_users(db: AsyncSession = Depends(get_db)): ... """ if _async_session_factory is None: raise RuntimeError("Database not initialized. Call init_db() first.") async with _async_session_factory() as session: try: yield session await session.commit() except Exception: await session.rollback() raise def AsyncSessionLocal() -> AsyncSession: """ Get a new async session directly (for non-FastAPI contexts). Usage: async with AsyncSessionLocal() as session: ... """ if _async_session_factory is None: raise RuntimeError("Database not initialized. Call init_db() first.") return _async_session_factory()