# Copyright (C) 2019-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
import time
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Type
from confluent_kafka import KafkaException, Producer
from swh.journal.serializers import KeyType, key_to_kafka, pprint_key, value_to_kafka
from .interface import ValueProtocol
logger = logging.getLogger(__name__)
[docs]
DeliveryTag = NamedTuple(
    "DeliveryTag",
    [
        # Kafka topic the message was produced to.
        ("topic", str),
        # Serialized message key, i.e. the output of ``key_to_kafka``.
        ("kafka_key", bytes),
    ],
)
DeliveryTag.__doc__ = """Unique tag allowing us to check for a message delivery"""
[docs]
DeliveryFailureInfo = NamedTuple(
    "DeliveryFailureInfo",
    [
        # Object type of the failed message, derived from its topic name.
        ("object_type", str),
        # Original (unserialized) key of the message.
        ("key", KeyType),
        # Human-readable description of the failure.
        ("message", str),
        # Short error code: a kafka error name or an ``SWH_*`` marker.
        ("code", str),
    ],
)
DeliveryFailureInfo.__doc__ = """Verbose information for failed deliveries"""
[docs]
def get_object_type(topic: str) -> str:
    """Get the object type from a topic string.

    The object type is whatever follows the last ``.`` of ``topic``
    (``"swh.journal.objects.revision"`` gives ``"revision"``); a topic
    containing no dot is returned unchanged.
    """
    _head, _sep, object_type = topic.rpartition(".")
    return object_type
[docs]
class KafkaDeliveryError(Exception):
    """Delivery failed on some kafka messages.

    Args:
      message: human-readable summary of the failure.
      delivery_failures: details of every message whose delivery failed; it
        is copied into a list, so any iterable (even a one-shot one) is fine.
    """

    def __init__(self, message: str, delivery_failures: Iterable[DeliveryFailureInfo]):
        self.message = message
        self.delivery_failures = list(delivery_failures)
        # Exception keeps the raw constructor arguments in ``self.args``;
        # re-initialize it with the materialized list so that ``args`` (used by
        # repr() and by pickling) never holds a caller's one-shot iterator.
        super().__init__(self.message, self.delivery_failures)

    def pretty_failures(self) -> str:
        """Return a comma-separated, human-readable summary of the failures."""
        return ", ".join(
            f"{f.object_type} {pprint_key(f.key)} ({f.message})"
            for f in self.delivery_failures
        )

    def __str__(self):
        return f"KafkaDeliveryError({self.message}, [{self.pretty_failures()}])"
[docs]
class KafkaJournalWriter:
    """This class is used to write serialized versions of value objects to a series
    of Kafka topics. The type parameter of value objects, which must implement
    the `ValueProtocol`, is the type of values this writer will write.
    Typically, `ValueProtocol` will be `swh.model.model.BaseModel`.

    Topics used to send objects representations are built from a ``prefix`` plus the
    type of the object:

      ``{prefix}.{object_type}``

    Objects can be sent as is, or can be anonymized. The anonymization feature, when
    activated, will write anonymized versions of value objects in the main topic, and
    stock (non-anonymized) objects will be sent to a dedicated (privileged) set of
    topics:

      ``{prefix}_privileged.{object_type}``

    The anonymization of a value object is the result of calling its
    ``anonymize()`` method. An object is considered anonymizable if this
    method returns a (non-None) value.

    Args:
      brokers: list of broker addresses and ports.
      prefix: the prefix used to build the topic names for objects.
      client_id: the id of the writer sent to kafka.
      value_sanitizer: a function that takes the object type and the dict
        representation of an object as argument, and returns another dict
        that should be actually stored in the journal (eg. removing keys
        that do not belong there)
      producer_config: extra configuration keys passed to the `Producer`.
      flush_timeout: timeout, in seconds, after which the `flush` operation
        will fail if some message deliveries are still pending.
      producer_class: override for the kafka producer class.
      anonymize: if True, activate the anonymization feature.
      auto_flush: if True (default), flush the kafka producer in
        ``write_addition()`` and ``write_additions()``. This should be set
        to False ONLY for testing purpose. DO NOT USE ON PRODUCTION ENVIRONMENT.
    """

    def __init__(
        self,
        brokers: Iterable[str],
        prefix: str,
        client_id: str,
        value_sanitizer: Callable[[str, Dict[str, Any]], Dict[str, Any]],
        producer_config: Optional[Dict] = None,
        flush_timeout: float = 120,
        producer_class: Type[Producer] = Producer,
        anonymize: bool = False,
        auto_flush: bool = True,
    ):
        self._prefix = prefix
        self._prefix_privileged = f"{self._prefix}_privileged"
        self.anonymize = anonymize
        self.auto_flush = auto_flush

        if not producer_config:
            producer_config = {}
        if "message.max.bytes" not in producer_config:
            # Default the maximum message size to 100 MiB unless the caller set
            # it; a new dict is built so the caller's mapping is not mutated.
            producer_config = {
                "message.max.bytes": 100 * 1024 * 1024,
                **producer_config,
            }

        self.producer = producer_class(
            {
                "bootstrap.servers": ",".join(brokers),
                "client.id": client_id,
                "on_delivery": self._on_delivery,
                "error_cb": self._error_cb,
                "logger": logger,
                "acks": "all",
                # Caller-provided keys come last, so they override the defaults.
                **producer_config,
            }
        )

        # Delivery management
        self.flush_timeout = flush_timeout
        # delivery tag -> original object "key" mapping
        self.deliveries_pending: Dict[DeliveryTag, KeyType] = {}
        # List of (object_type, key, error_msg, error_name) for failed deliveries
        self.delivery_failures: List[DeliveryFailureInfo] = []

        self.value_sanitizer = value_sanitizer

    def _error_cb(self, error):
        """Producer-level error callback (registered as ``error_cb``).

        Fatal errors are re-raised as ``KafkaException``; other errors are
        only logged.
        """
        if error.fatal():
            raise KafkaException(error)
        logger.info("Received non-fatal kafka error: %s", error)

    def _on_delivery(self, error, message):
        """Per-message delivery callback (registered as ``on_delivery``).

        The message stops being tracked as pending whether or not it was
        delivered; a failed delivery is recorded in ``delivery_failures``.
        """
        delivery_tag = DeliveryTag(message.topic(), message.key())
        # None when the tag is not (or no longer) tracked, e.g. the same key was
        # produced twice to the same topic and an earlier callback popped it.
        sent_key = self.deliveries_pending.pop(delivery_tag, None)
        if error is not None:
            self.delivery_failures.append(
                DeliveryFailureInfo(
                    get_object_type(delivery_tag.topic),
                    sent_key,
                    error.str(),
                    error.name(),
                )
            )

    def reliable_produce(self, topic: str, key: KeyType, kafka_value: Optional[bytes]):
        """Produce one message, retrying up to 5 times on ``BufferError``.

        ``kafka_value`` is sent as-is; ``None`` produces a tombstone (see
        ``delete()``). On success the message is tracked in
        ``deliveries_pending`` until its delivery callback runs. If every
        attempt raises ``BufferError``, nothing is raised here: a
        ``SWH_BUFFER_ERROR`` entry is appended to ``delivery_failures`` and is
        reported by the next ``flush()``.
        """
        kafka_key = key_to_kafka(key)
        max_attempts = 5
        last_exception: Optional[Exception] = None
        for attempt in range(max_attempts):
            try:
                self.producer.produce(
                    topic=topic,
                    key=kafka_key,
                    value=kafka_value,
                )
            except BufferError as e:
                last_exception = e
                # Linear backoff: 1s, 4s, 7s, 10s, 13s.
                wait = 1 + 3 * attempt
                if logger.isEnabledFor(logging.DEBUG):  # pprint_key is expensive
                    logger.debug(
                        "BufferError producing %s %s; waiting for %ss",
                        get_object_type(topic),
                        # Pretty-print the original key, not its serialized form.
                        pprint_key(key),
                        wait,
                    )
                # Let the producer run pending delivery callbacks before
                # retrying (presumably freeing room in its local queue).
                self.producer.poll(wait)
            else:
                self.deliveries_pending[DeliveryTag(topic, kafka_key)] = key
                return

        # We reach this point if all delivery attempts have failed
        self.delivery_failures.append(
            DeliveryFailureInfo(
                get_object_type(topic), key, str(last_exception), "SWH_BUFFER_ERROR"
            )
        )

    def send(self, topic: str, key: KeyType, value):
        """Serialize ``value`` with ``value_to_kafka`` and produce it to ``topic``."""
        kafka_value = value_to_kafka(value)
        return self.reliable_produce(topic, key, kafka_value)

    def delivery_error(self, message) -> KafkaDeliveryError:
        """Get all failed deliveries, and clear them.

        Deliveries still pending are reported as well (with code
        ``SWH_FLUSH_TIMEOUT``) and are no longer tracked afterwards.
        """
        ret = self.delivery_failures
        self.delivery_failures = []
        while self.deliveries_pending:
            delivery_tag, orig_key = self.deliveries_pending.popitem()
            ret.append(
                DeliveryFailureInfo(
                    get_object_type(delivery_tag.topic),
                    orig_key,
                    "No delivery before flush() timeout",
                    "SWH_FLUSH_TIMEOUT",
                )
            )
        return KafkaDeliveryError(message, ret)

    def flush(self) -> None:
        """Wait for every pending delivery to be acknowledged.

        Raises:
          KafkaDeliveryError: if deliveries are still pending after
            ``flush_timeout`` seconds, or if any delivery has failed.
        """
        start = time.monotonic()
        self.producer.flush(self.flush_timeout)
        # Keep serving delivery callbacks until nothing is pending or the
        # overall flush_timeout (measured from `start`) has elapsed.
        while self.deliveries_pending:
            if time.monotonic() - start > self.flush_timeout:
                break
            self.producer.poll(0.1)

        if self.deliveries_pending:
            # Delivery timeout
            raise self.delivery_error(
                "flush() exceeded timeout (%ss)" % self.flush_timeout,
            )
        elif self.delivery_failures:
            raise self.delivery_error("Failed deliveries after flush()")

    def _write_addition(self, object_type: str, object_: ValueProtocol) -> None:
        """Write a single object to the journal"""
        key = object_.unique_key()

        if self.anonymize:
            anon_object_ = object_.anonymize()
            if anon_object_:  # can be either None, or an anonymized object
                # if the object is anonymizable, send the non-anonymized version in the
                # privileged channel
                topic = f"{self._prefix_privileged}.{object_type}"
                dict_ = self.value_sanitizer(object_type, object_.to_dict())
                logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_)
                self.send(topic, key=key, value=dict_)
                object_ = anon_object_

        topic = f"{self._prefix}.{object_type}"
        dict_ = self.value_sanitizer(object_type, object_.to_dict())
        logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_)
        self.send(topic, key=key, value=dict_)

    def write_addition(self, object_type: str, object_: ValueProtocol) -> None:
        """Write a single object to the journal (flushing if ``auto_flush``)."""
        self._write_addition(object_type, object_)
        if self.auto_flush:
            self.flush()

    def write_additions(
        self, object_type: str, objects: Iterable[ValueProtocol]
    ) -> None:
        """Write a set of objects to the journal (flushing once if ``auto_flush``)."""
        for object_ in objects:
            self._write_addition(object_type, object_)
        if self.auto_flush:
            self.flush()

    def delete(self, object_type: str, object_keys: Iterable[KeyType]) -> None:
        """Write a tombstone for the given keys.

        For older data to be removed, the topic must be configured with
        ``cleanup.policy=compact``. Please also consider setting:

        - ``max.compaction.lag.ms``: delay between the appearance of a tombstone
          and the actual deletion of older values.
        - ``delete.retention.ms``: how long must tombstones themselves be kept.
          This is important as they enable journal clients to learn that a given
          key has been deleted and act accordingly.

        Note that deletion won’t happen for keys located in the currently active
        log segment. It will only be possible once enough newer entries have been
        added, pushing older keys to “dirty” log segments that can be compacted.

        Unlike ``write_additions()``, this method never flushes the producer.
        """
        # The keys may be iterated once per topic; materialize them so a
        # one-shot iterator (e.g. a generator) is not exhausted by the first pass.
        keys = list(object_keys)

        topic = f"{self._prefix}.{object_type}"
        for key in keys:
            self.reliable_produce(topic, key, None)

        # Handle non-anonymized objects
        # XXX: is this list already available elsewhere?
        if object_type in ("revision", "release"):
            topic = f"{self._prefix_privileged}.{object_type}"
            for key in keys:
                self.reliable_produce(topic, key, None)