# Copyright (C) 2019-2024  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import Counter, defaultdict
from datetime import datetime, timezone
from itertools import chain
import re
from typing import Any, Dict, Iterable, Iterator, List, Optional
from swh.indexer import codemeta
from swh.model import model
from swh.model.hashutil import hash_to_hex
from swh.search.interface import SORT_BY_OPTIONS, OriginDict, PagedResult
from swh.search.utils import get_expansion, parse_and_format_date
_words_regexp = re.compile(r"\w+")
def _dict_words_set(d):
    """Recursively extract set of words from dict content."""
    values = set()
    def extract(obj, words):
        if isinstance(obj, dict):
            for k, v in obj.items():
                extract(v, words)
        elif isinstance(obj, list):
            for item in obj:
                extract(item, words)
        else:
            words.update(_words_regexp.findall(str(obj).lower()))
        return words
    return extract(d, values)
def _nested_get(nested_dict, nested_keys, default=""):
    """Extracts values from deeply nested dictionary nested_dict
    using the nested_keys and returns a list of all of the values
    discovered in the process.
    >>> nested_dict = [
    ... {"name": [{"@value": {"first": "f1", "last": "l1"}}], "address": "XYZ"},
    ... {"name": [{"@value": {"first": "f2", "last": "l2"}}], "address": "ABC"},
    ... ]
    >>> _nested_get(nested_dict, ["name", "@value", "last"])
    ['l1', 'l2']
    >>> _nested_get(nested_dict, ["address"])
    ['XYZ', 'ABC']
    It doesn't allow fetching intermediate values and returns "" for such cases
    >>> _nested_get(nested_dict, ["name", "@value"])
    ['', '']
    """
    def _nested_get_recursive(nested_dict, nested_keys):
        try:
            curr_obj = nested_dict
            type_curr_obj = type(curr_obj)
            for i, key in enumerate(nested_keys):
                if key in curr_obj:
                    curr_obj = curr_obj[key]
                    type_curr_obj = type(curr_obj)
                else:
                    if type_curr_obj == list:
                        curr_obj = [
                            _nested_get_recursive(obj, nested_keys[i:])
                            for obj in curr_obj
                        ]
                    # If value isn't a list or string or integer
                    elif type_curr_obj != str and type_curr_obj != int:
                        return default
            # If only one element is present in the list, take it out
            # This ensures a flat array every time
            if type_curr_obj == list and len(curr_obj) == 1:
                curr_obj = curr_obj[0]
            return curr_obj
        except Exception:
            return default
    res = _nested_get_recursive(nested_dict, nested_keys)
    if type(res) is not list:
        return [res]
    return res
def _tokenize(x):
    return x.lower().replace(",", " ").split()
def _get_sorting_key(origin, field):
    """Get value of the field from an origin for sorting origins.
    Here field should be a member of SORT_BY_OPTIONS.
    If "-" is present at the start of field then invert the value
    in a way that it reverses the sorting order.
    """
    reversed = False
    if field[0] == "-":
        field = field[1:]
        reversed = True
    DATETIME_OBJ_MAX = datetime.max.replace(tzinfo=timezone.utc)
    DATETIME_MIN = "0001-01-01T00:00:00Z"
    DATE_OBJ_MAX = datetime.max
    DATE_MIN = "0001-01-01"
    if field == "score":
        if reversed:
            return -origin.get(field, 0)
        else:
            return origin.get(field, 0)
    if field in ["date_created", "date_modified", "date_published"]:
        date = datetime.strptime(
            _nested_get(origin, get_expansion(field), DATE_MIN)[0], "%Y-%m-%d"
        )
        if reversed:
            return DATE_OBJ_MAX - date
        else:
            return date
    elif field in ["nb_visits"]:  # unlike other options, nb_visits is of type integer
        if reversed:
            return -origin.get(field, 0)
        else:
            return origin.get(field, 0)
    elif field in SORT_BY_OPTIONS:
        date = datetime.fromisoformat(
            origin.get(field, DATETIME_MIN).replace("Z", "+00:00")
        )
        if reversed:
            return DATETIME_OBJ_MAX - date
        else:
            return date
[docs]
class InMemorySearch:
    def __init__(self):
        self.initialize()
[docs]
    def check(self):
        return True 
[docs]
    def deinitialize(self) -> None:
        if hasattr(self, "_origins"):
            del self._origins
            del self._origin_ids 
[docs]
    def initialize(self) -> None:
        self._origins: Dict[str, Dict[str, Any]] = defaultdict(dict)
        self._origin_ids: List[str] = [] 
[docs]
    def flush(self) -> None:
        pass 
    _url_splitter = re.compile(r"\W")
[docs]
    def origin_update(self, documents: Iterable[OriginDict]) -> None:
        for source_document in documents:
            id_ = hash_to_hex(model.Origin(url=source_document["url"]).id)
            document: Dict[str, Any] = {
                **source_document,
                "sha1": id_,
            }
            if "url" in document:
                document["_url_tokens"] = set(
                    self._url_splitter.split(source_document["url"])
                )
            if "visit_types" in document:
                document["visit_types"] = source_document["visit_types"]
                if "visit_types" in self._origins[id_]:
                    document["visit_types"] = list(
                        set(document["visit_types"] + self._origins[id_]["visit_types"])
                    )
            if "nb_visits" in document:
                document["nb_visits"] = max(
                    document["nb_visits"], self._origins[id_].get("nb_visits", 0)
                )
            if "last_visit_date" in document:
                document["last_visit_date"] = max(
                    datetime.fromisoformat(document["last_visit_date"]),
                    datetime.fromisoformat(
                        self._origins[id_]
                        .get(
                            "last_visit_date",
                            "0001-01-01T00:00:00.000000Z",
                        )
                        .replace("Z", "+00:00")
                    ),
                ).isoformat()
            if "snapshot_id" in document and "last_eventful_visit_date" in document:
                incoming_date = datetime.fromisoformat(
                    document["last_eventful_visit_date"]
                )
                current_date = datetime.fromisoformat(
                    self._origins[id_]
                    .get(
                        "last_eventful_visit_date",
                        "0001-01-01T00:00:00Z",
                    )
                    .replace("Z", "+00:00")
                )
                incoming_snapshot_id = document["snapshot_id"]
                current_snapshot_id = self._origins[id_].get("snapshot_id", "")
                if (
                    incoming_snapshot_id == current_snapshot_id
                    or incoming_date < current_date
                ):
                    # update not required so override the incoming_values
                    document["snapshot_id"] = current_snapshot_id
                    document["last_eventful_visit_date"] = current_date.isoformat()
            if "last_revision_date" in document:
                document["last_revision_date"] = max(
                    datetime.fromisoformat(document["last_revision_date"]),
                    datetime.fromisoformat(
                        self._origins[id_]
                        .get(
                            "last_revision_date",
                            "0001-01-01T00:00:00Z",
                        )
                        .replace("Z", "+00:00")
                    ),
                ).isoformat()
            if "last_release_date" in document:
                document["last_release_date"] = max(
                    datetime.fromisoformat(document["last_release_date"]),
                    datetime.fromisoformat(
                        self._origins[id_]
                        .get(
                            "last_release_date",
                            "0001-01-01T00:00:00Z",
                        )
                        .replace("Z", "+00:00")
                    ),
                ).isoformat()
            if "jsonld" in document:
                jsonld = document["jsonld"]
                for date_field in ["dateCreated", "dateModified", "datePublished"]:
                    if date_field in jsonld:
                        date = jsonld[date_field]
                        # If date{Created,Modified,Published} value isn't parsable
                        # It gets rejected and isn't stored (unlike other fields)
                        formatted_date = parse_and_format_date(date)
                        if formatted_date is None:
                            jsonld.pop(date_field)
                        else:
                            jsonld[date_field] = formatted_date
                document["jsonld"] = codemeta.expand(jsonld)
                if len(document["jsonld"]) != 1:
                    continue
                metadata = document["jsonld"][0]
                if "http://schema.org/license" in metadata:
                    metadata["http://schema.org/license"] = [
                        {"@id": license["@id"].lower()}
                        for license in metadata["http://schema.org/license"]
                    ]
                if "http://schema.org/programmingLanguage" in metadata:
                    metadata["http://schema.org/programmingLanguage"] = [
                        {"@value": license["@value"].lower()}
                        for license in metadata["http://schema.org/programmingLanguage"]
                    ]
            self._origins[id_].update(document)
            if id_ not in self._origin_ids:
                self._origin_ids.append(id_) 
[docs]
    def origin_search(
        self,
        *,
        query: str = "",
        url_pattern: Optional[str] = None,
        metadata_pattern: Optional[str] = None,
        with_visit: bool = False,
        visit_types: Optional[List[str]] = None,
        min_nb_visits: int = 0,
        min_last_visit_date: str = "",
        min_last_eventful_visit_date: str = "",
        min_last_revision_date: str = "",
        min_last_release_date: str = "",
        min_date_created: str = "",
        min_date_modified: str = "",
        min_date_published: str = "",
        programming_languages: Optional[List[str]] = None,
        licenses: Optional[List[str]] = None,
        keywords: Optional[List[str]] = None,
        fork_weight: Optional[float] = 0.5,
        sort_by: Optional[List[str]] = None,
        page_token: Optional[str] = None,
        limit: int = 50,
    ) -> PagedResult[OriginDict]:
        if sort_by:
            sort_by.append("-score")
        else:
            sort_by = ["-score"]
        hits = self._get_hits()
        if url_pattern:
            tokens = set(self._url_splitter.split(url_pattern))
            def predicate(match):
                missing_tokens = tokens - match["_url_tokens"]
                if len(missing_tokens) == 0:
                    return True
                elif len(missing_tokens) > 1:
                    return False
                else:
                    # There is one missing token, look up by prefix.
                    (missing_token,) = missing_tokens
                    return any(
                        token.startswith(missing_token)
                        for token in match["_url_tokens"]
                    )
            hits = filter(predicate, hits)
        if metadata_pattern:
            metadata_pattern_words = set(
                _words_regexp.findall(metadata_pattern.lower())
            )
            def predicate(match):
                if "jsonld" not in match:
                    return False
                return metadata_pattern_words.issubset(_dict_words_set(match["jsonld"]))
            hits = filter(predicate, hits)
        if url_pattern is None and metadata_pattern is None:
            raise ValueError(
                "At least one of url_pattern and metadata_pattern must be provided."
            )
        next_page_token: Optional[str] = None
        if with_visit:
            hits = filter(lambda o: o.get("has_visits"), hits)
        if min_nb_visits:
            hits = filter(lambda o: o.get("nb_visits", 0) >= min_nb_visits, hits)
        if min_last_visit_date:
            hits = filter(
                lambda o: datetime.fromisoformat(
                    o.get("last_visit_date", "0001-01-01T00:00:00Z").replace(
                        "Z", "+00:00"
                    )
                )
                >= datetime.fromisoformat(min_last_visit_date),
                hits,
            )
        if min_last_eventful_visit_date:
            hits = filter(
                lambda o: datetime.fromisoformat(
                    o.get("last_eventful_visit_date", "0001-01-01T00:00:00Z").replace(
                        "Z", "+00:00"
                    )
                )
                >= datetime.fromisoformat(min_last_eventful_visit_date),
                hits,
            )
        if min_last_revision_date:
            hits = filter(
                lambda o: datetime.fromisoformat(
                    o.get("last_revision_date", "0001-01-01T00:00:00Z").replace(
                        "Z", "+00:00"
                    )
                )
                >= datetime.fromisoformat(min_last_revision_date),
                hits,
            )
        if min_last_release_date:
            hits = filter(
                lambda o: datetime.fromisoformat(
                    o.get("last_release_date", "0001-01-01T00:00:00Z").replace(
                        "Z", "+00:00"
                    )
                )
                >= datetime.fromisoformat(min_last_release_date),
                hits,
            )
        if min_date_created:
            min_date_created_obj = datetime.strptime(min_date_created, "%Y-%m-%d")
            hits = filter(
                lambda o: datetime.strptime(
                    _nested_get(o, get_expansion("date_created"))[0], "%Y-%m-%d"
                )
                >= min_date_created_obj,
                hits,
            )
        if min_date_modified:
            min_date_modified_obj = datetime.strptime(min_date_modified, "%Y-%m-%d")
            hits = filter(
                lambda o: datetime.strptime(
                    _nested_get(o, get_expansion("date_modified"))[0], "%Y-%m-%d"
                )
                >= min_date_modified_obj,
                hits,
            )
        if min_date_published:
            min_date_published_obj = datetime.strptime(min_date_published, "%Y-%m-%d")
            hits = filter(
                lambda o: datetime.strptime(
                    _nested_get(o, get_expansion("date_published"))[0], "%Y-%m-%d"
                )
                >= min_date_published_obj,
                hits,
            )
        if licenses:
            queried_licenses = [license_keyword.lower() for license_keyword in licenses]
            hits = filter(
                lambda o: any(
                    # If any of the queried licenses are found, include the origin
                    any(
                        # returns True if queried_license_keyword is found
                        # in any of the licenses of the origin
                        queried_license_keyword in origin_license
                        for origin_license in _nested_get(o, get_expansion("licenses"))
                    )
                    for queried_license_keyword in queried_licenses
                ),
                hits,
            )
        if programming_languages:
            queried_programming_languages = [
                lang_keyword.lower() for lang_keyword in programming_languages
            ]
            hits = filter(
                lambda o: any(
                    # If any of the queried languages are found, include the origin
                    any(
                        # returns True if queried_lang_keyword is found
                        # in any of the langs of the origin
                        queried_lang_keyword in origin_lang
                        for origin_lang in _nested_get(
                            o, get_expansion("programming_languages")
                        )
                    )
                    for queried_lang_keyword in queried_programming_languages
                ),
                hits,
            )
        if keywords:
            from copy import deepcopy
            hits_list = deepcopy(list(hits))
            for origin in hits_list:
                origin_keywords = [
                    _tokenize(keyword)
                    for keyword in _nested_get(origin, get_expansion("keywords"))
                ]
                origin_descriptions = [
                    _tokenize(description)
                    for description in _nested_get(
                        origin, get_expansion("descriptions")
                    )
                ]
                for q_keyword in keywords:
                    for origin_keyword_tokens in origin_keywords:
                        if q_keyword in origin_keyword_tokens:
                            origin["score"] = origin.get("score", 0) + 2
                    for origin_description_token in origin_descriptions:
                        if q_keyword in origin_description_token:
                            origin["score"] = origin.get("score", 0) + 1
            hits = (origin for origin in hits_list if origin.get("score", 0) > 0)
        if visit_types is not None:
            visit_types_set = set(visit_types)
            hits = filter(
                lambda o: visit_types_set.intersection(set(o.get("visit_types", []))),
                hits,
            )
        hits_list = list(hits)
        if fork_weight is not None:
            hits_list = [
                {
                    **hit,
                    "score": hit.get("score", 1)
                    * (
                        fork_weight
                        if any(
                            "https://forgefed.org/ns#forkedFrom" in doc
                            for doc in hit.get("jsonld", [])
                        )
                        else 1.0
                    ),
                }
                for hit in hits_list
            ]
        if sort_by:
            sort_by_list = list(sort_by)
            hits_list.sort(
                key=lambda o: tuple(
                    _get_sorting_key(o, field) for field in sort_by_list
                )
            )
        start_at_index = int(page_token) if page_token else 0
        origins = [
            {
                field: hit.get(field, default)
                for field, default in [
                    ("url", ""),
                    ("visit_types", []),
                    ("has_visits", False),
                ]
            }
            for hit in hits_list[start_at_index : start_at_index + limit]
        ]
        if len(origins) == limit:
            next_page_token = str(start_at_index + limit)
        assert len(origins) <= limit
        return PagedResult(
            results=origins,
            next_page_token=next_page_token,
        ) 
[docs]
    def origin_get(self, url: str) -> Optional[Dict[str, Any]]:
        origin_id = hash_to_hex(model.Origin(url=url).id)
        document = self._origins.get(origin_id)
        if document is None:
            return None
        else:
            return {k: v for (k, v) in document.items() if k != "_url_tokens"} 
[docs]
    def origin_delete(self, url: str) -> bool:
        origin_id = hash_to_hex(model.Origin(url=url).id)
        try:
            del self._origins[origin_id]
        except KeyError:
            return False
        try:
            self._origin_ids.remove(origin_id)
        except ValueError:
            assert False, "this should not have happened"
        return True 
[docs]
    def visit_types_count(self) -> Counter:
        hits = self._get_hits()
        return Counter(chain(*[hit.get("visit_types", []) for hit in hits])) 
    def _get_hits(self) -> Iterator[Dict[str, Any]]:
        return (
            self._origins[id_]
            for id_ in self._origin_ids
            if not self._origins[id_].get("blocklisted")
        )