Source code for swh.provenance.backend.postgresql

# Copyright (C) 2024  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from contextlib import contextmanager
import logging
from typing import Any, List, Optional, Union

import psycopg
import psycopg_pool

from swh.core.db import BaseDb
from swh.core.db.common import db_transaction
from swh.core.db.db_utils import swh_db_version
from swh.model.swhids import CoreSWHID, QualifiedSWHID
from swh.provenance.exc import ProvenanceDBError

logger = logging.getLogger(__name__)


class Db(BaseDb):
    """Database access wrapper used by the PostgreSQL provenance backend.

    This is a plain :class:`swh.core.db.BaseDb` subclass; it currently adds
    no provenance-specific helpers of its own.
    """
class PostgresqlProvenance:
    """Provenance index backend stored in a PostgreSQL database.

    The database may be given either as a connection string, from which a
    connection pool is built, or as an already-open psycopg connection that
    is wrapped in a single :class:`Db` reused by every call.
    """

    # Schema version this backend expects (compared in ``check_config``).
    current_version: int = 1

    def __init__(
        self,
        db: Union[str, psycopg.Connection[Any]],
        min_pool_conns: int = 1,
        max_pool_conns: int = 10,
    ):
        """
        Args:
            db: a connection string (used as the pool's ``conninfo``) or an
                existing psycopg connection.
            min_pool_conns: minimum pool size; only used when ``db`` is a str.
            max_pool_conns: maximum pool size; only used when ``db`` is a str.

        Raises:
            ProvenanceDBError: wrapping any ``psycopg.OperationalError`` raised
                while creating/opening the pool or wrapping the connection.
        """
        self._db: Optional[Db]
        self._pool: Optional[psycopg_pool.ConnectionPool]
        try:
            if isinstance(db, str):
                self._pool = psycopg_pool.ConnectionPool(
                    conninfo=db,
                    min_size=min_pool_conns,
                    max_size=max_pool_conns,
                    open=False,
                )
                self._db = None
                # Wait for the first connection to be ready, and raise the
                # appropriate exception if connection fails
                self._pool.open(wait=True, timeout=1)
            else:
                self._pool = None
                self._db = Db(db)
        except psycopg.OperationalError as e:
            # Keep the original psycopg error as the cause for debugging.
            raise ProvenanceDBError(e) from e

    def get_db(self) -> Db:
        """Return a :class:`Db` to run queries with.

        When built from a connection, the shared ``Db`` is returned. Otherwise
        a new ``Db`` is created from the pool; it must be released with
        :meth:`put_db`.
        """
        if self._db is not None:
            return self._db
        assert self._pool is not None
        return Db.from_pool(self._pool)

    def put_db(self, db: Db) -> None:
        """Release a :class:`Db` previously obtained from :meth:`get_db`.

        The shared ``Db`` (built from a user-supplied connection) is left
        untouched; any other ``Db`` has ``put_conn()`` called on it
        (presumably handing its connection back to the pool).
        """
        if db is not self._db:
            db.put_conn()

    @contextmanager
    def db(self):
        """Yield a :class:`Db` from :meth:`get_db`, releasing it on exit."""
        db = None
        try:
            db = self.get_db()
            yield db
        finally:
            if db is not None:
                self.put_db(db)

    @db_transaction()
    def check_config(self, *, check_write: bool, db: Db, cur=None) -> bool:
        """Check that the database can be used by this backend.

        Args:
            check_write: if True, check the current user has the INSERT
                privilege on ``content_in_revision``; otherwise SELECT.

        Returns:
            False (after logging a warning) if ``swh_db_version`` reports a
            version different from :attr:`current_version`; otherwise the
            result of ``has_table_privilege`` for the requested privilege.
        """
        dbversion = swh_db_version(db.conn)
        if dbversion != self.current_version:
            logger.warning(
                "database dbversion (%s) != %s current_version (%s)",
                dbversion,
                __name__,
                self.current_version,
            )
            return False

        # Check permissions on one of the tables
        check = "INSERT" if check_write else "SELECT"

        cur.execute(
            "select has_table_privilege(current_user, 'content_in_revision', %s)",
            (check,),
        )
        return cur.fetchone()[0]

    @db_transaction()
    def whereis(
        self, swhid: CoreSWHID, *, db: Db, cur=None
    ) -> Optional[QualifiedSWHID]:
        """Return the provenance information for ``swhid``.

        NOTE(review): currently a placeholder. No database query is made; the
        result is ``swhid`` converted to a QualifiedSWHID with no qualifiers.
        """
        return QualifiedSWHID(
            object_type=swhid.object_type,
            object_id=swhid.object_id,
        )

    @db_transaction()
    def whereare(
        self, *, swhids: List[CoreSWHID], db: Db, cur=None
    ) -> List[Optional[QualifiedSWHID]]:
        """Given a list of SWHIDs, return a list of provenance info.

        The result has one entry per input SWHID, in the same order. See
        :meth:`whereis` for details on the provenance info.

        ``db`` and ``cur`` are injected by ``db_transaction`` (without them in
        the signature every call failed with a TypeError). They are forwarded
        to :meth:`whereis` so all lookups share this call's transaction.
        """
        return [self.whereis(swhid=si, db=db, cur=cur) for si in swhids]