# Copyright (C) 2024-2025 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from contextlib import contextmanager
import functools
import inspect
import itertools
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    TypeVar,
    Union,
)
import warnings
import attr
import psycopg_pool
from swh.core.utils import grouper
from swh.model.hashutil import MultiHash
from swh.model.model import (
    ExtID,
    Origin,
    Person,
    RawExtrinsicMetadata,
    Release,
    Revision,
    Sha1Git,
)
from swh.model.swhids import ExtendedObjectType, ExtendedSWHID
from swh.storage import get_storage
from swh.storage.exc import MaskedObjectException
from swh.storage.interface import HashDict, PagedResult, Sha1, StorageInterface, TResult
from swh.storage.metrics import DifferentialTimer
from swh.storage.proxies.masking.db import MaskedStatus
from .db import MaskingQuery
# Number of revisions grouped into one batch when the results of
# ``revision_log`` are checked against the masking database.
BATCH_SIZE = 1024
# Metric name handed to :class:`DifferentialTimer` by
# :func:`masking_overhead_timer`.
MASKING_OVERHEAD_METRIC = "swh_storage_masking_overhead_seconds"
[docs]
def get_datastore(cls, db=None, masking_db=None, **kwargs):
    """Return the result of :meth:`MaskingAdmin.connect` for the masking database.

    ``cls`` must be ``"postgresql"`` or ``"masking"``. The connection string
    is ``db`` when it is given, and ``masking_db`` otherwise. Any other keyword
    argument is accepted and ignored.
    """
    assert cls in ("postgresql", "masking")

    from .db import MaskingAdmin

    conninfo = masking_db if db is None else db
    return MaskingAdmin.connect(conninfo)
[docs]
def masking_overhead_timer(method_name: str) -> DifferentialTimer:
    """Build the :class:`DifferentialTimer` used to report the masking overhead
    of the storage endpoint ``method_name``"""
    endpoint_tags = {"endpoint": method_name}
    return DifferentialTimer(MASKING_OVERHEAD_METRIC, tags=endpoint_tags)
[docs]
class MaskingProxyStorage:
    """Masking storage proxy
    This proxy can return modified objects or stop them from being retrieved at
    all.
    It uses a specific PostgreSQL database (which for now is colocated with the
    swh.storage PostgreSQL database), the access to which is implemented in the
    :mod:`.db` submodule.
    Sample configuration
    .. code-block:: yaml
        storage:
          cls: masking
          db: 'dbname=swh-storage'
          max_pool_conns: 10
          storage:
          - cls: remote
            url: http://storage.internal.staging.swh.network:5002/
    """
    def __init__(
        self,
        storage: Union[Dict, StorageInterface],
        db: Optional[str] = None,
        masking_db: Optional[str] = None,
        min_pool_conns: int = 1,
        max_pool_conns: int = 5,
    ):
        """Set up the proxied storage and the masking database connection pool.

        Arguments:
          storage: either a storage instance, or a configuration dict passed
            to :func:`get_storage` to instantiate it
          db: connection string of the masking database
          masking_db: deprecated name of ``db``, only used when ``db`` is
            :const:`None`
          min_pool_conns: ``min_size`` of the connection pool
          max_pool_conns: ``max_size`` of the connection pool
        """
        conninfo = db
        if conninfo is None:
            # Fall back to the legacy configuration key
            assert masking_db is not None
            warnings.warn(
                "'masking_db' field in the masking storage configuration "
                "was renamed 'db' field",
                DeprecationWarning,
            )
            conninfo = masking_db

        backend: StorageInterface
        if isinstance(storage, dict):
            backend = get_storage(**storage)
        else:
            backend = storage
        self.storage = backend

        self._masking_pool = psycopg_pool.ConnectionPool(
            conninfo,
            min_size=min_pool_conns,
            max_size=max_pool_conns,
        )

        # Build the dispatch tables once, at instantiation time, so that
        # __getattr__ only has to look names up.
        self._gen_method_dicts()
    @contextmanager
    def _masking_query(self) -> Iterator[MaskingQuery]:
        """Context manager yielding a :class:`MaskingQuery` obtained from the
        connection pool; ``put_conn()`` is called on it when the context exits."""
        # If from_pool() raises, there is nothing to give back to the pool.
        query = MaskingQuery.from_pool(self._masking_pool)
        try:
            yield query
        finally:
            if query:
                query.put_conn()
    @staticmethod
    def _get_swhids_in_result(method_name: str, result: Any) -> List[ExtendedSWHID]:
        """Return the SWHIDs to check in the masking database for ``result``.

        The SWHIDs are derived from the object itself when it provides enough
        information (a ``swhid()`` method or an ``object_type`` attribute), and
        from ``method_name`` otherwise, for methods returning bare identifiers,
        dicts or tuples. Rules are tried in order; the first match wins.

        Arguments:
          method_name: name of the storage method which returned ``result``
          result: a single, non-:const:`None` object returned by that method

        Raises:
          TypeError: if ``result`` is :const:`None`
          AssertionError: if ``result`` does not have the shape expected for
            ``method_name``
          ValueError: if no rule matches ``result`` and ``method_name``
        """
        if result is None:
            # Callers are expected to skip None results before calling us
            raise TypeError(f"Filtering of Nones missing in {method_name}")
        result_type = getattr(result, "object_type", None)
        if result_type == RawExtrinsicMetadata.object_type:
            # Raw Extrinsic Metadata have their own SWHID, but they are also
            # masked when their target is masked: check both
            return [result.swhid(), result.target]
        if hasattr(result, "swhid"):
            swhid = result.swhid()
            # Convert to an ExtendedSWHID when the returned SWHID type offers
            # the conversion
            if hasattr(swhid, "to_extended"):
                swhid = swhid.to_extended()
            assert isinstance(
                swhid, ExtendedSWHID
            ), f"{method_name} returned object with unexpected swhid() method return type"
            return [swhid]
        if result_type:
            # Objects with an object_type but no swhid() method
            if result_type == ExtID.object_type:
                # ExtIDs are masked through their target
                return [result.target.to_extended()]
            if result_type.value.startswith("origin_visit"):
                # Object types whose value starts with "origin_visit" are
                # masked through the origin they refer to
                return [Origin(url=result.origin).swhid()]
        # From here on, the result carries no type information: rely on the
        # name of the method which returned it
        if (
            object_type := method_name.removesuffix("_get_random").removesuffix(
                "_get_id_partition"
            )
        ) != method_name:
            # The method name had one of the two suffixes stripped; what
            # remains is used as the object type name.
            # Returns bare sha1_gits
            assert isinstance(result, bytes), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType[object_type.upper()],
                    object_id=result,
                )
            ]
        elif method_name == "revision_log":
            # returns dicts of revisions
            assert (
                isinstance(result, dict) and "id" in result
            ), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.REVISION, object_id=result["id"]
                )
            ]
        elif method_name == "revision_shortlog":
            # Returns tuples (revision.id, revision.parents)
            assert isinstance(
                result[0], bytes
            ), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.REVISION, object_id=result[0]
                )
            ]
        elif method_name == "origin_get_by_sha1":
            # Returns origin dicts, because why not
            assert (
                isinstance(result, dict) and "url" in result
            ), f"{method_name} returned unexpected type"
            return [Origin(url=result["url"]).swhid()]
        elif method_name == "origin_visit_get_with_statuses":
            # Returns an OriginVisitWithStatuses
            return [Origin(url=result.visit.origin).swhid()]
        elif method_name == "snapshot_get":
            # Returns a snapshot dict
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.SNAPSHOT, object_id=result["id"]
                )
            ]
        # No rule matched: fail loudly rather than letting the object through
        # unchecked
        raise ValueError(f"Cannot get swhid for result of method {method_name}")
    def _masked_result(
        self, method_name: str, result: Any
    ) -> Optional[Dict[ExtendedSWHID, List[MaskedStatus]]]:
        """Find the SWHIDs of the ``result`` object, and check if any of them is
        masked, returning the associated masking information."""
        with self._masking_query() as q:
            return q.swhids_are_masked(self._get_swhids_in_result(method_name, result))
    def _raise_if_masked_result(self, method_name: str, result: Any) -> None:
        """Raise a :exc:`MaskedObjectException` if ``result`` is masked."""
        masked = self._masked_result(method_name, result)
        if masked:
            raise MaskedObjectException(masked)
    def _raise_if_masked_swhids(self, swhids: List[ExtendedSWHID]) -> None:
        """Raise a :exc:`MaskedObjectException` if any SWHID is masked."""
        with self._masking_query() as q:
            masked = q.swhids_are_masked(swhids)
        if masked:
            raise MaskedObjectException(masked)
    def __getattr__(self, key):
        method = None
        if key in self._methods_by_name:
            method = self._methods_by_name[key](key)
        else:
            suffix = key.rsplit("_", 1)[-1]
            if suffix in self._methods_by_suffix:
                method = self._methods_by_suffix[suffix](key)
        if method:
            # Avoid going through __getattr__ again next time
            setattr(self, key, method)
            return method
        # Raise a NotImplementedError to make sure we don't forget to add
        # masking to any new storage functions
        raise NotImplementedError(key)
    def _gen_method_dicts(self):
        """Generate the :attr:`_methods_by_name` and :attr:`_methods_by_suffix`
        used by :meth:`__getattr__`"""
        _passthrough = functools.partial(getattr, self.storage)
        self._methods_by_name = {
            # Returns a single object
            "snapshot_get": self._getter_optional,
            "origin_visit_find_by_date": self._getter_optional,
            "origin_visit_get_by": self._getter_optional,
            "origin_visit_status_get_latest": self._getter_optional,
            "origin_visit_get_latest": self._getter_optional,
            # Returns a PagedResult
            "origin_list": self._getter_pagedresult,
            "origin_visit_get": self._getter_pagedresult,
            "origin_search": self._getter_pagedresult,
            "raw_extrinsic_metadata_get": self._getter_pagedresult,
            "origin_visit_get_with_statuses": self._getter_pagedresult,
            "origin_visit_status_get": self._getter_pagedresult,
            # Returns a list of (optional) objects
            "origin_get_by_sha1": self._getter_list,
            "content_find": self._getter_list,
            "skipped_content_find": self._getter_list,
            "revision_shortlog": self._getter_list,
            "extid_get_from_target": self._getter_list,
            "raw_extrinsic_metadata_get_by_ids": self._getter_list,
            # Filter arguments
            "directory_entry_get_by_path": self._getter_filtering_arguments,
            "directory_get_entries": self._getter_filtering_arguments,
            "directory_get_raw_manifest": self._getter_filtering_arguments,
            "directory_ls": self._getter_filtering_arguments,
            "raw_extrinsic_metadata_get_authorities": self._getter_filtering_arguments,
            "snapshot_branch_get_by_name": self._getter_filtering_arguments,
            "snapshot_count_branches": self._getter_filtering_arguments,
            "snapshot_get_branches": self._getter_filtering_arguments,
            "origin_snapshot_get_all": self._getter_filtering_arguments,
            # Content functions that don't match common getter or adder suffixes
            "content_add_metadata": _passthrough,
            "content_missing_per_sha1": _passthrough,
            "content_missing_per_sha1_git": _passthrough,
            "content_update": _passthrough,
            # These objects aren't maskable
            "extid_get_from_extid": _passthrough,
            "object_find_by_sha1_git": _passthrough,
            "object_find_recent_references": _passthrough,
            "metadata_authority_get": _passthrough,
            "metadata_fetcher_get": _passthrough,
            # Utility methods
            "check_config": _passthrough,
            "clear_buffers": _passthrough,
            "flush": _passthrough,
            "origin_count": _passthrough,
            "refresh_stat_counters": _passthrough,
            "stat_counters": _passthrough,
            # For tests
            "journal_writer": _passthrough,
        }
        self._methods_by_suffix = {
            # These methods will never need do any masking
            "add": _passthrough,
            "missing": _passthrough,
            # Partitions return PagedResults
            "partition": self._getter_pagedresult,
            # Getters return lists of optional objects
            "get": self._getter_list,
            "random": self._getter_random,
        }
[docs]
    def content_get_data(self, content: Union[HashDict, Sha1]) -> Optional[bytes]:
        """Return the data of ``content`` from the proxied storage, unless the
        content is masked.

        Raises:
          MaskedObjectException: if the content is masked
        """
        with masking_overhead_timer("content_get_data") as timer:
            with timer.inner():
                data = self.storage.content_get_data(content)

            if data is None:
                return None

            if isinstance(content, dict) and "sha1_git" in content:
                sha1_git = content["sha1_git"]
            else:
                # The caller did not give us the identifier used in the SWHID
                # of the object, so it is computed by hashing the data that
                # was just retrieved.
                sha1_git = MultiHash.from_data(data, ["sha1_git"]).digest()["sha1_git"]

            content_swhid = ExtendedSWHID(
                object_type=ExtendedObjectType.CONTENT,
                object_id=sha1_git,
            )
            self._raise_if_masked_swhids([content_swhid])
            return data
    def _get_swhids_in_args(
        self, method_name: str, parsed_args: Dict[str, Any]
    ) -> List[ExtendedSWHID]:
        """Extract SWHIDs from the parsed arguments of ``method_name``.
        Arguments:
          method_name: name of the called method
          parsed_args: arguments of the method parsed with :func:`inspect.getcallargs`
        """
        if method_name in ("directory_entry_get_by_path", "directory_ls"):
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=parsed_args["directory"],
                )
            ]
        elif method_name == "directory_get_entries":
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=parsed_args["directory_id"],
                )
            ]
        elif method_name.startswith("snapshot_"):
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.SNAPSHOT,
                    object_id=parsed_args["snapshot_id"],
                )
            ]
        elif method_name == "raw_extrinsic_metadata_get_authorities":
            return [parsed_args["target"]]
        elif method_name == "directory_get_raw_manifest":
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=object_id,
                )
                for object_id in parsed_args["directory_ids"]
            ]
        elif method_name == "origin_snapshot_get_all":
            return [Origin(url=parsed_args["origin_url"]).swhid()]
        else:
            raise ValueError(f"Cannot get swhid for arguments of method {method_name}")
    def _getter_filtering_arguments(self, method_name: str):
        """Build a wrapper around ``method_name`` which masks on the arguments
        of the call, rather than on the returned value.

        The proxied storage is called first. A :const:`None` result is returned
        as is; otherwise :exc:`MaskedObjectException` is raised if an object
        designated by the arguments is masked."""

        @functools.wraps(getattr(self.storage, method_name))
        def filtered_call(*args, **kwargs):
            with masking_overhead_timer(method_name) as timer:
                backend_method = getattr(self.storage, method_name)
                with timer.inner():
                    outcome = backend_method(*args, **kwargs)

                if outcome is not None:
                    # Bind the call against the interface signature, so that
                    # positional and keyword arguments can both be looked up
                    # by parameter name.
                    interface_method = getattr(StorageInterface, method_name)
                    call_args = (
                        inspect.signature(interface_method)
                        .bind(self, *args, **kwargs)
                        .arguments
                    )
                    requested = self._get_swhids_in_args(method_name, call_args)
                    self._raise_if_masked_swhids(requested)

                return outcome

        return filtered_call
    RANDOM_ATTEMPTS = 5
    def _getter_random(self, method_name: str):
        """Build a wrapper around a method returning a random object.

        The proxied storage is asked up to :const:`RANDOM_ATTEMPTS` times for an
        object; the first one which is not masked is returned. :const:`None` is
        returned if the storage returns :const:`None`, or if every attempt
        yielded a masked object."""

        @functools.wraps(getattr(self.storage, method_name))
        def pick_unmasked(*args, **kwargs):
            with masking_overhead_timer(method_name) as timer:
                backend_method = getattr(self.storage, method_name)
                attempts_left = self.RANDOM_ATTEMPTS
                while attempts_left > 0:
                    attempts_left -= 1
                    with timer.inner():
                        candidate = backend_method(*args, **kwargs)
                    if candidate is None:
                        return None
                    if self._masked_result(method_name, candidate):
                        # Masked: draw again
                        continue
                    return candidate
            return None

        return pick_unmasked
    def _getter_optional(self, method_name: str):
        """Build a wrapper around a method returning an optional object.

        A :const:`None` result is returned as is; any other result is checked
        and :exc:`MaskedObjectException` is raised if it is masked."""

        @functools.wraps(getattr(self.storage, method_name))
        def checked_call(*args, **kwargs):
            with masking_overhead_timer(method_name) as timer:
                backend_method = getattr(self.storage, method_name)
                with timer.inner():
                    outcome = backend_method(*args, **kwargs)
                if outcome is not None:
                    self._raise_if_masked_result(method_name, outcome)
                return outcome

        return checked_call
    def _raise_if_masked_result_in_list(
        self, method_name: str, results: Iterable[TResult]
    ) -> List[TResult]:
        """Raise a :exc:`MaskedObjectException` if any non-:const:`None` object
        in ``results`` is masked."""
        result_swhids = set()
        results = list(results)
        for result in results:
            if result is not None:
                result_swhids.update(self._get_swhids_in_result(method_name, result))
        if result_swhids:
            self._raise_if_masked_swhids(list(result_swhids))
        return results
    def _getter_list(
        self,
        method_name: str,
    ):
        """Build a wrapper around a method returning a list (or a generator) of
        optional objects.

        The whole batch is materialized and checked at once, and
        :exc:`MaskedObjectException` is raised for the masked objects it
        contains."""

        @functools.wraps(getattr(self.storage, method_name))
        def checked_call(*args, **kwargs):
            with masking_overhead_timer(method_name) as timer:
                backend_method = getattr(self.storage, method_name)
                with timer.inner():
                    # Consume generators inside the inner timer
                    items = [*backend_method(*args, **kwargs)]
                self._raise_if_masked_result_in_list(method_name, items)
                return items

        return checked_call
    def _getter_pagedresult(self, method_name: str) -> Callable[..., PagedResult]:
        """Build a wrapper around a method returning a :class:`PagedResult`.

        :exc:`MaskedObjectException` is raised if objects of the returned page
        are masked; otherwise the page is returned unchanged."""

        @functools.wraps(getattr(self.storage, method_name))
        def checked_call(*args, **kwargs) -> PagedResult:
            with masking_overhead_timer(method_name) as timer:
                backend_method = getattr(self.storage, method_name)
                with timer.inner():
                    page = backend_method(*args, **kwargs)
                self._raise_if_masked_result_in_list(method_name, page.results)
                return page

        return checked_call
    # Patching proxy feature set
    TRevision = TypeVar("TRevision", Revision, Optional[Revision])
    def _apply_revision_display_names(
        self, revisions: List[TRevision]
    ) -> List[TRevision]:
        emails = set()
        for rev in revisions:
            if (
                rev is not None
                and rev.author is not None
                and rev.author.email  # ignore None or empty email addresses
            ):
                emails.add(rev.author.email)
            if (
                rev is not None
                and rev.committer is not None
                and rev.committer.email  # ignore None or empty email addresses
            ):
                emails.add(rev.committer.email)
        with self._masking_query() as q:
            display_names = q.display_name(list(emails))
        # Short path for the common case
        if not display_names:
            return revisions
        persons: Dict[Optional[bytes], Person] = {
            email: Person.from_fullname(display_name)
            for (email, display_name) in display_names.items()
        }
        return [
            (
                None
                if revision is None
                else attr.evolve(
                    revision,
                    author=(
                        revision.author
                        if revision.author is None
                        else persons.get(revision.author.email, revision.author)
                    ),
                    committer=(
                        revision.committer
                        if revision.committer is None
                        else persons.get(revision.committer.email, revision.committer)
                    ),
                )
            )
            for revision in revisions
        ]
    TRelease = TypeVar("TRelease", Release, Optional[Release])
    def _apply_release_display_names(self, releases: List[TRelease]) -> List[TRelease]:
        emails = set()
        for rel in releases:
            if (
                rel is not None
                and rel.author is not None
                and rel.author.email  # ignore None or empty email addresses
            ):
                emails.add(rel.author.email)
        with self._masking_query() as q:
            display_names = q.display_name(list(emails))
        # Short path for the common case
        if not display_names:
            return releases
        persons: Dict[Optional[bytes], Person] = {
            email: Person.from_fullname(display_name)
            for (email, display_name) in display_names.items()
        }
        return [
            (
                None
                if release is None
                else attr.evolve(
                    release,
                    author=(
                        release.author
                        if release.author is None
                        else persons.get(release.author.email, release.author)
                    ),
                )
            )
            for release in releases
        ]
[docs]
    def revision_get(
        self, revision_ids: List[Sha1Git], ignore_displayname: bool = False
    ) -> List[Optional[Revision]]:
        revisions = self.storage.revision_get(revision_ids)
        self._raise_if_masked_result_in_list("revision_get", revisions)
        return self._apply_revision_display_names(revisions) 
[docs]
    def revision_log(
        self,
        revisions: List[Sha1Git],
        ignore_displayname: bool = False,
        limit: Optional[int] = None,
    ) -> Iterable[Optional[Dict[str, Any]]]:
        """Yield the revision dicts of the proxied storage's ``revision_log``,
        with display names applied.

        The log is processed lazily, in batches of :const:`BATCH_SIZE` entries:
        each batch is checked for masked revisions (raising
        :exc:`MaskedObjectException`) before any of its entries is yielded.

        NOTE(review): ``ignore_displayname`` is accepted but not used here --
        confirm this is intended.
        """
        backend_log = self.storage.revision_log(revisions, limit=limit)
        for batch in grouper(backend_log, BATCH_SIZE):
            # Masking and display names work on model objects, not dicts
            parsed = [Revision.from_dict(rev_dict) for rev_dict in batch]
            allowed = self._raise_if_masked_result_in_list("revision_log", parsed)
            for revision in self._apply_revision_display_names(allowed):
                yield Revision.to_dict(revision)
[docs]
    def revision_get_partition(
        self,
        partition_id: int,
        nb_partitions: int,
        page_token: Optional[str] = None,
        limit: int = 1000,
    ) -> PagedResult[Revision]:
        """Return a page of revisions from the proxied storage's
        ``revision_get_partition``, with display names applied.

        Arguments are forwarded unchanged to the proxied storage.

        Raises:
          MaskedObjectException: if any revision of the page is masked
        """
        page: PagedResult[Revision] = self.storage.revision_get_partition(
            partition_id, nb_partitions, page_token, limit
        )
        return PagedResult(
            results=self._apply_revision_display_names(
                # The method name is used in diagnostics, so spell it right
                self._raise_if_masked_result_in_list(
                    "revision_get_partition", page.results
                )
            ),
            next_page_token=page.next_page_token,
        )
[docs]
    def release_get(
        self, releases: List[Sha1Git], ignore_displayname: bool = False
    ) -> List[Optional[Release]]:
        return self._apply_release_display_names(
            self._raise_if_masked_result_in_list(
                "release_get", self.storage.release_get(releases)
            )
        ) 
[docs]
    def release_get_partition(
        self,
        partition_id: int,
        nb_partitions: int,
        page_token: Optional[str] = None,
        limit: int = 1000,
    ) -> PagedResult[Release]:
        """Return a page of releases from the proxied storage's
        ``release_get_partition``, with display names applied.

        Arguments are forwarded unchanged to the proxied storage.

        Raises:
          MaskedObjectException: if any release of the page is masked
        """
        page = self.storage.release_get_partition(
            partition_id, nb_partitions, page_token, limit
        )
        allowed = self._raise_if_masked_result_in_list(
            "release_get_partition", page.results
        )
        patched = self._apply_release_display_names(allowed)
        return PagedResult(results=patched, next_page_token=page.next_page_token)