# Source code for swh.auth.starlette.backends
# Copyright (C) 2023-2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import datetime
import hashlib
from typing import Any, Dict, Optional, Tuple
from aiocache.base import BaseCache
from jwcrypto.common import JWException
from starlette.authentication import (
    AuthCredentials,
    AuthenticationBackend,
    AuthenticationError,
    SimpleUser,
)
from starlette.requests import HTTPConnection
from swh.auth.keycloak import (
    ExpiredSignatureError,
    KeycloakError,
    KeycloakOpenIDConnect,
    keycloak_error_message,
)
# [docs]  (Sphinx viewcode link marker, not part of the module source)
class BearerTokenAuthBackend(AuthenticationBackend):
    """
    Starlette authentication backend using Keycloak OpenID Connect authorization.

    A Keycloak server, a realm and a cache to store access tokens must be provided.

    The ``Authorization`` header is expected to carry a bearer token. The token
    is first decoded as an access token; when decoding raises a
    ``JWException`` it is treated as a refresh token and the backend takes
    care of obtaining (and caching) an access token from it.
    """

    def __init__(
        self, server_url: str, realm_name: str, client_id: str, cache: BaseCache
    ):
        """
        Args:
            server_url: Keycloak URL
            realm_name: Keycloak realm name
            client_id: Keycloak client ID
            cache: An aiocache cache instance
        """
        self.client_id = client_id
        self.oidc_client = KeycloakOpenIDConnect(
            server_url=server_url,
            realm_name=realm_name,
            client_id=client_id,
        )
        self.cache = cache

    def _get_token_from_header(self, auth_header: str) -> str:
        """Extract the bearer token from the value of an Authorization header.

        Args:
            auth_header: raw header value, expected as ``Bearer <token>``

        Returns:
            The token part of the header, without surrounding whitespace.

        Raises:
            AuthenticationError: the header has no ``<type> <token>`` shape or
                its authorization type is not ``Bearer``.
        """
        try:
            auth_type, bearer_token = auth_header.split(" ", 1)
        except ValueError:
            raise AuthenticationError("Invalid auth header") from None
        # Authentication scheme names are case-insensitive (RFC 9110, 11.1).
        if auth_type.lower() != "bearer":
            raise AuthenticationError("Invalid or unsupported authorization type")
        # Tolerate extra spaces between the scheme and the token.
        return bearer_token.strip()

    def _get_token_cache_key(self, refresh_token: str) -> str:
        """Return the cache key under which the access token obtained from
        ``refresh_token`` is stored.

        The key is built from the SHA1 hex digest of the token so that the
        token itself is not used as a cache key.
        """
        # UTF-8 yields the same bytes as ASCII for ASCII-only tokens but does
        # not raise UnicodeEncodeError when a (bogus) token holds other chars.
        digest = hashlib.sha1(refresh_token.encode("utf-8")).hexdigest()
        return f"api_token_{digest}"

    def _get_new_access_token(self, refresh_token: str) -> Dict[str, Any]:
        """Request a new access token from Keycloak using a refresh token.

        Returns:
            The response of the OpenID Connect client; ``authenticate`` reads
            its ``"access_token"`` entry.

        Raises:
            AuthenticationError: Keycloak rejected the refresh token.
        """
        try:
            access_token = self.oidc_client.refresh_token(refresh_token)
        except KeycloakError as e:
            raise AuthenticationError(
                "Invalid or expired user token", keycloak_error_message(e)
            ) from e
        return access_token

    def _decode_token(self, access_token: Optional[str]) -> Optional[Dict[str, Any]]:
        """Decode an access token with the OpenID Connect client.

        Returns:
            The decoded token, or ``None`` when no token is given or when
            decoding fails with one of the errors handled below. Other
            exceptions (notably ``JWException``, which ``authenticate`` relies
            on) are propagated to the caller.
        """
        if not access_token:
            return None
        try:
            decoded_token = self.oidc_client.decode_token(access_token)
        except (KeycloakError, UnicodeEncodeError, ExpiredSignatureError, ValueError):
            # token is either too old or an invalid one
            decoded_token = None
        return decoded_token

    async def authenticate(
        self, conn: HTTPConnection
    ) -> Optional[Tuple[AuthCredentials, SimpleUser]]:
        """Authenticate the user of a connection from its Authorization header.

        Args:
            conn: the HTTP connection to authenticate

        Returns:
            ``None`` when no Authorization header is sent (anonymous user),
            otherwise the user credentials, whose scopes are the realm roles
            followed by the roles granted for the configured client, and a
            user named after the ``preferred_username`` of the decoded token.

        Raises:
            AuthenticationError: the header is malformed, the token could not
                be decoded or the refresh token was rejected by Keycloak.
        """
        auth_header = conn.headers.get("Authorization")
        if auth_header is None:
            # anonymous user
            return None

        token = self._get_token_from_header(auth_header)

        try:
            # check if access token was provided in authorization header
            decoded_token = self._decode_token(token)
            if not decoded_token:
                raise AuthenticationError("Access token failed to be decoded")
        except JWException:
            # token is a refresh one so backend handles access token renewal
            cache_key = self._get_token_cache_key(token)

            # read access token from the cache
            access_token = await self.cache.get(cache_key)
            decoded_token = self._decode_token(access_token)

            if not access_token or not decoded_token:
                # nothing usable in cache: ask Keycloak for a new access token
                access_token = self._get_new_access_token(token)["access_token"]
                decoded_token = self._decode_token(access_token)
                if not decoded_token:
                    raise AuthenticationError("Access token failed to be decoded")

                # keep the access token in cache until its expiration date;
                # a non-positive TTL means it is already expired, so caching
                # it would be pointless
                ttl = int(decoded_token["exp"] - datetime.now().timestamp())
                if ttl > 0:
                    await self.cache.set(cache_key, access_token, ttl=ttl)

        # set user scopes; a new list is built so that the role lists held by
        # the decoded token are left untouched
        realm_roles = decoded_token.get("realm_access", {}).get("roles", [])
        client_roles = (
            decoded_token.get("resource_access", {})
            .get(self.client_id, {})
            .get("roles", [])
        )
        user_scopes = [*realm_roles, *client_roles]

        return AuthCredentials(scopes=user_scopes), SimpleUser(
            decoded_token["preferred_username"]
        )