# Source code for swh.storage.proxies.blocking
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from contextlib import contextmanager
from typing import Dict, Iterable, Iterator, List, Optional, Union
import warnings
import psycopg_pool
from swh.model.model import Origin, OriginVisit, OriginVisitStatus
from swh.storage import get_storage
from swh.storage.exc import BlockedOriginException
from swh.storage.interface import StorageInterface
from swh.storage.metrics import DifferentialTimer
from swh.storage.proxies.blocking.db import BlockingState
from .db import BlockingQuery
BLOCKING_OVERHEAD_METRIC = "swh_storage_blocking_overhead_seconds"
[docs]
def get_datastore(cls, db=None, blocking_db=None, **kwargs):
    """Connect to the blocking database and return the resulting admin object.

    Args:
        cls: datastore class name; must be ``"postgresql"`` or ``"blocking"``.
        db: value handed to ``BlockingAdmin.connect``.
        blocking_db: fallback used only when ``db`` is ``None``.
        **kwargs: accepted and ignored.

    Returns:
        Whatever ``BlockingAdmin.connect`` returns for the selected value.

    Raises:
        AssertionError: if ``cls`` is not one of the two supported names.
    """
    assert cls in ("postgresql", "blocking")
    # Imported lazily, after the ``cls`` check, exactly as the module always did.
    from .db import BlockingAdmin

    conninfo = blocking_db if db is None else db
    return BlockingAdmin.connect(conninfo)
[docs]
def blocking_overhead_timer(method_name: str) -> DifferentialTimer:
    """Build the ``DifferentialTimer`` tracking blocking overhead for an endpoint.

    The timer is created for the ``BLOCKING_OVERHEAD_METRIC`` metric and is
    tagged with ``endpoint=method_name``.
    """
    endpoint_tags = {"endpoint": method_name}
    return DifferentialTimer(BLOCKING_OVERHEAD_METRIC, tags=endpoint_tags)
[docs]
class BlockingProxyStorage:
    """Blocking storage proxy.

    This proxy prevents visits of origins from a known list from being
    performed at all.

    It uses a specific PostgreSQL database (which for now is colocated with the
    swh.storage PostgreSQL database), the access to which is implemented in the
    :mod:`.db` submodule.

    Sample configuration

    .. code-block:: yaml

        storage:
          cls: blocking
          db: 'dbname=swh-blocking-proxy'
          max_pool_conns: 10
          storage:
          - cls: remote
            url: http://storage.internal.staging.swh.network:5002/

    """

    def __init__(
        self,
        storage: Union[Dict, StorageInterface],
        db: Optional[str] = None,
        blocking_db: Optional[str] = None,
        min_pool_conns: int = 1,
        max_pool_conns: int = 5,
    ):
        """Set up the proxied storage and the blocking database pool.

        Args:
            storage: either an already instantiated storage, or a dict of
                keyword arguments passed to ``get_storage`` to build one.
            db: connection string of the blocking database.
            blocking_db: deprecated name of ``db``; only used when ``db`` is
                ``None``, in which case a ``DeprecationWarning`` is emitted.
            min_pool_conns: ``min_size`` of the connection pool.
            max_pool_conns: ``max_size`` of the connection pool.

        Raises:
            AssertionError: if both ``db`` and ``blocking_db`` are ``None``.
        """
        if db is None:
            # NOTE(review): this check disappears under ``python -O``; it is
            # kept as an assert so the raised exception type does not change.
            assert blocking_db is not None
            warnings.warn(
                "'blocking_db' field in the blocking storage configuration "
                "was renamed 'db' field",
                DeprecationWarning,
                # Attribute the warning to whoever instantiated the proxy.
                stacklevel=2,
            )
            db = blocking_db

        self.storage: StorageInterface = (
            get_storage(**storage) if isinstance(storage, dict) else storage
        )

        self._blocking_pool = psycopg_pool.ConnectionPool(
            db,
            min_size=min_pool_conns,
            max_size=max_pool_conns,
        )

    def _raise_if_blocked(self, urls: List[str]) -> None:
        """Raise ``BlockedOriginException`` if any of ``urls`` is blocked.

        The statuses returned by ``BlockingQuery.origins_are_blocked`` are
        inspected; as soon as one of them has a state other than
        ``BlockingState.NON_BLOCKED`` the whole status mapping is attached to
        the raised exception.
        """
        with self._blocking_query() as q:
            statuses = q.origins_are_blocked(urls)
            if statuses and any(
                status.state != BlockingState.NON_BLOCKED
                for status in statuses.values()
            ):
                raise BlockedOriginException(statuses)

    def origin_visit_status_add(
        self, visit_statuses: List[OriginVisitStatus]
    ) -> Dict[str, int]:
        """Add visit statuses unless one of their origins is blocked.

        Raises:
            BlockedOriginException: if the ``origin`` of any given visit status
                is reported with a state other than ``NON_BLOCKED``; nothing
                is forwarded to the proxied storage in that case.
        """
        self._raise_if_blocked([v.origin for v in visit_statuses])
        return self.storage.origin_visit_status_add(visit_statuses)

    def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
        """Add visits unless one of their origins is blocked.

        Raises:
            BlockedOriginException: if the ``origin`` of any given visit is
                reported with a state other than ``NON_BLOCKED``; nothing is
                forwarded to the proxied storage in that case.
        """
        self._raise_if_blocked([v.origin for v in visits])
        return self.storage.origin_visit_add(visits)

    def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
        """Add origins unless one of them is blocked.

        Each origin URL is checked individually with
        ``BlockingQuery.origin_is_blocked``.

        Raises:
            BlockedOriginException: carrying a ``{url: status}`` mapping of
                every origin whose status has a state other than
                ``NON_BLOCKED``; nothing is forwarded to the proxied storage
                in that case.
        """
        with self._blocking_query() as q:
            statuses = {}
            for origin in origins:
                status = q.origin_is_blocked(origin.url)
                if status and status.state != BlockingState.NON_BLOCKED:
                    statuses[origin.url] = status
            if statuses:
                raise BlockedOriginException(statuses)
        return self.storage.origin_add(origins)

    @contextmanager
    def _blocking_query(self) -> Iterator[BlockingQuery]:
        """Yield a ``BlockingQuery`` built from the pool, then release it.

        ``put_conn`` is called on exit, including when the body raises, as
        long as the query object was actually created.
        """
        ret = None
        try:
            ret = BlockingQuery.from_pool(self._blocking_pool)
            yield ret
        finally:
            # Identity check: only skip the release when creation itself failed.
            if ret is not None:
                ret.put_conn()

    def __getattr__(self, key):
        """Delegate attributes not defined on the proxy to the proxied storage.

        Raises:
            AttributeError: if ``key`` is ``"storage"`` (the proxied storage is
                not set on this instance), or if the lookup on the proxied
                storage raises it.
            NotImplementedError: if the proxied storage attribute is falsy.
        """
        if key == "storage":
            # ``__getattr__`` only runs when normal lookup failed, so getting
            # here means ``self.storage`` does not exist yet. Looking it up
            # below would re-enter this method and recurse without end.
            raise AttributeError(key)
        method = getattr(self.storage, key)
        if method:
            # Avoid going through __getattr__ again next time
            setattr(self, key, method)
            return method
        # Raise a NotImplementedError to make sure we don't forget to add
        # blocking to any new storage functions
        raise NotImplementedError(key)