# Copyright (C) 2019-2024  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import contextlib
import logging
import multiprocessing
import socket
import subprocess
import threading
import time
import grpc
import pytest
from swh.provenance import get_provenance
from swh.provenance.grpc.swhprovenance_pb2_grpc import ProvenanceServiceStub
from swh.provenance.grpc_server import (
    ExecutableNotFound,
    default_rust_executable_dir,
    spawn_rust_grpc_server,
)
logger = logging.getLogger(__name__)
@pytest.fixture
def swh_provenance(swh_provenance_config):
    """Provide the object returned by ``get_provenance``.

    The items of the ``swh_provenance_config`` fixture are passed to
    ``get_provenance`` as keyword arguments.
    """
    provenance = get_provenance(**swh_provenance_config)
    yield provenance
class ProvenanceServerProcess(multiprocessing.Process):
    """Child process in charge of spawning the Rust provenance gRPC server.

    The child reports back to its parent through a
    :class:`multiprocessing.Queue`: either a dict describing the spawned
    server (``grpc_url``, ``port`` and ``pid`` keys), or the exception raised
    while trying to spawn it. Once :meth:`start` has returned, that value is
    available as ``self.result``.
    """

    def __init__(self, config, *args, **kwargs):
        """Store the configuration and create the result queue.

        Args:
            config: provenance configuration dict; :meth:`run` reads its
                ``cls`` and ``grpc_server`` keys.
            *args: forwarded to :class:`multiprocessing.Process`.
            **kwargs: forwarded to :class:`multiprocessing.Process`.
        """
        self.config = config
        self.q = multiprocessing.Queue()
        super().__init__(*args, **kwargs)

    def run(self):
        """Spawn the gRPC server and push the outcome on the queue.

        Runs in the child process. Exceptions are never propagated: they are
        logged and put on the queue instead, so the parent waiting in
        :meth:`start` always gets an answer.
        """
        try:
            backend = self.config["cls"]
            # The message matters: the exception is sent to the parent
            # process, where it would otherwise show up without any context.
            assert (
                backend == "local_rust"
            ), f"unsupported provenance backend {backend!r}, expected 'local_rust'"
            (server, port) = spawn_rust_grpc_server(**self.config["grpc_server"])
            self.q.put(
                {
                    "grpc_url": f"localhost:{port}",
                    "port": port,
                    "pid": server.pid,
                }
            )
        except Exception as e:
            if isinstance(e, ExecutableNotFound):
                # hack to add a bit more context and help to the user,
                # especially when this is used from another swh package...
                # XXX on py>=3.11 we could use e.add_note() instead
                e.args = (
                    *e.args,
                    "This probably means you need to build the rust grpc server "
                    "for swh-provenance.",
                )
            logger.exception(e)
            self.q.put(e)

    def start(self, *args, **kwargs):
        """Start the child process and wait for its report.

        Blocks until the child has put a value on the queue, then stores that
        value (server description dict or exception) in ``self.result``.
        The arguments are accepted for compatibility but ignored.
        """
        super().start()
        self.result = self.q.get()
 
class StatsdServer:
    """Minimal UDP server recording every datagram it receives.

    The server binds to an OS-assigned port on 127.0.0.1 (exposed as
    ``host`` and ``port``) and collects incoming datagrams from a background
    thread until :meth:`close` is called.
    """

    def __init__(self):
        """Bind the socket and start the listener thread."""
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._sock.bind(("127.0.0.1", 0))
        # Short timeout, so the listener regularly re-checks self._closing.
        self._sock.settimeout(0.1)
        (self.host, self.port) = self._sock.getsockname()
        self._closing = False
        self.datagrams = []
        self.new_datagram = threading.Event()
        """Woken up every time a datagram is added to self.datagrams."""
        # Start the listener last: it uses the attributes initialized above,
        # so they must all exist before the first datagram can be handled.
        self._thread = threading.Thread(target=self._listen)
        self._thread.start()

    def _listen(self):
        """Record incoming datagrams until the server is asked to close."""
        try:
            while not self._closing:
                try:
                    (datagram, _addr) = self._sock.recvfrom(4096)
                except socket.timeout:
                    # socket.timeout only became an alias of TimeoutError in
                    # Python 3.10; catching it works on older versions too.
                    continue
                self.datagrams.append(datagram)
                self.new_datagram.set()
        finally:
            # Release the socket even if recvfrom() failed unexpectedly.
            self._sock.close()

    def close(self):
        """Stop the listener thread and wait until the socket is released."""
        self._closing = True
        # A thread cannot join itself; only wait when called from outside.
        if threading.current_thread() is not self._thread:
            self._thread.join()
 
@pytest.fixture(scope="session")
def provenance_statsd_server():
    """Provide a session-wide ``StatsdServer``, closed on teardown."""
    statsd_server = StatsdServer()
    try:
        yield statsd_server
    finally:
        statsd_server.close()
@pytest.fixture(scope="session", params=["rust"])
def provenance_grpc_backend_implementation(request):
    """Name of the gRPC backend implementation under test.

    The fixture is parametrized; ``"rust"`` is currently its only value.
    """
    implementation = request.param
    return implementation
@pytest.fixture(scope="session")
def provenance_database_and_graph(tmpdir_factory):
    """Create a temporary directory and run the Rust test-data tools on it.

    Two executables from the default Rust executable directory are run in
    turn, each required to exit successfully:
    ``swh-provenance-gen-test-database`` and ``swh-provenance-index``.

    Returns:
        the path of the temporary directory.
    """

    def run_tool(tool_name, *tool_args):
        # The executable directory is looked up anew for every invocation.
        executable = f"{default_rust_executable_dir({})}/{tool_name}"
        subprocess.run([executable, *tool_args], check=True)

    database_path = tmpdir_factory.mktemp("provenance_database")
    run_tool("swh-provenance-gen-test-database", "main", database_path)
    run_tool("swh-provenance-index", "--database", f"file://{database_path}")
    return database_path
@pytest.fixture(scope="session")
def provenance_grpc_server_config(
    provenance_grpc_backend_implementation,
    provenance_statsd_server,
    provenance_database_and_graph,
):
    """Build the configuration of the provenance gRPC server under test.

    Returns:
        a dict with a single ``provenance`` key, holding the backend class
        name (``local_<implementation>``) and the ``grpc_server`` settings:
        database URL, JSON graph location, debug flag and statsd endpoint.
    """
    backend_cls = f"local_{provenance_grpc_backend_implementation}"
    grpc_server_settings = dict(
        db=f"file://{provenance_database_and_graph}",
        graph=provenance_database_and_graph / "graph.json",
        graph_format="json",
        debug=True,
        statsd_host=provenance_statsd_server.host,
        statsd_port=provenance_statsd_server.port,
    )
    return {"provenance": {"cls": backend_cls, "grpc_server": grpc_server_settings}}
@pytest.fixture(scope="session")
def provenance_grpc_server_process(
    provenance_grpc_server_config, provenance_statsd_server
):
    """Provide a not-yet-started ``ProvenanceServerProcess``.

    The process is built from the ``provenance`` section of the server
    configuration, and killed on teardown.
    """
    process_config = provenance_grpc_server_config["provenance"]
    server = ProvenanceServerProcess(process_config)
    yield server
    # AttributeError here means the server was never started: nothing to kill.
    with contextlib.suppress(AttributeError):
        server.kill()
@pytest.fixture(scope="session")
def provenance_grpc_server_started(provenance_grpc_server_process):
    """Start the provenance server process and wait for its port to open.

    Yields:
        the started server process; its ``result`` attribute holds the
        ``grpc_url``, ``port`` and ``pid`` of the spawned gRPC server.

    Raises:
        Exception: whatever exception the server process reported instead of
            a server description.
    """
    server = provenance_grpc_server_process
    server.start()
    if isinstance(server.result, Exception):
        raise server.result
    # wait for the server to be up
    address = ("localhost", server.result["port"])
    for _ in range(100):
        try:
            # This is only a probe: close the connection right away, and stop
            # polling as soon as one attempt succeeds.
            with socket.create_connection(address, timeout=1.0):
                break
        except ConnectionRefusedError:
            time.sleep(0.01)
    else:
        # Best effort: let the tests run anyway, but leave a trace in the logs.
        logger.warning(
            "provenance gRPC server still refuses connections on %s:%s",
            *address,
        )
    yield server
    server.kill()
@pytest.fixture(scope="module")
def provenance_grpc_stub(provenance_grpc_server):
    """Provide a ``ProvenanceServiceStub`` connected to the test server.

    The stub uses an insecure gRPC channel opened on the URL given by the
    ``provenance_grpc_server`` fixture; the channel's context manager is
    exited on teardown.
    """
    channel_manager = grpc.insecure_channel(provenance_grpc_server)
    with channel_manager as channel:
        yield ProvenanceServiceStub(channel)
@pytest.fixture(scope="module")
def provenance_grpc_server(provenance_grpc_server_started):
    """Provide the ``grpc_url`` reported by the started provenance server."""
    server_description = provenance_grpc_server_started.result
    grpc_url = server_description["grpc_url"]
    yield grpc_url