"""Source code for lback.auth.jwt_auth: JWT creation, decoding and validation helpers."""

import logging
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

import jwt

from lback.core.signals import dispatcher


logger = logging.getLogger(__name__)

class JWTAuth:
    """
    Utility class for creating, decoding, and validating JSON Web Tokens (JWT).

    Supports both access and refresh tokens. Integrates SignalDispatcher to emit
    events related to token lifecycle and validation.
    """

    def __init__(self, secret: str, algorithm: str = "HS256", access_exp: int = 3600, refresh_exp: int = 86400):
        """
        Initializes the JWTAuth utility.

        Args:
            secret: The secret key used for signing and verifying tokens.
                Keep this secret and secure. An empty secret is accepted but
                logged as an error.
            algorithm: The signing algorithm to use (default is HS256).
            access_exp: Expiration time for access tokens in seconds (default is 1 hour).
            refresh_exp: Expiration time for refresh tokens in seconds (default is 24 hours).
        """
        if not secret:
            logger.error("JWTAuth initialized with an empty secret. Tokens will not be secure.")
        self.secret = secret
        self.algorithm = algorithm
        self.access_exp = access_exp
        self.refresh_exp = refresh_exp
        logger.info(
            "JWTAuth initialized with algorithm: %s, access expiry: %ss, refresh expiry: %ss",
            self.algorithm, self.access_exp, self.refresh_exp,
        )

    # ------------------------------------------------------------------ #
    # Token creation
    # ------------------------------------------------------------------ #

    def _create_token(self, payload: Dict[str, Any], token_type: str, lifetime_seconds: int, signal_name: str) -> str:
        """
        Sign a copy of ``payload`` tagged with ``type`` and ``exp`` claims, then emit ``signal_name``.

        Args:
            payload: Caller-supplied claims. It is copied, never mutated. The
                ``type`` and ``exp`` keys are overwritten in the copy.
            token_type: Value stored in the ``type`` claim ("access" or "refresh").
            lifetime_seconds: Seconds from now until the token expires.
            signal_name: Name of the signal sent after the token is signed.

        Returns:
            The signed token as a ``str``.
        """
        claims = dict(payload)
        claims["type"] = token_type
        # Timezone-aware "now": datetime.utcnow() is deprecated (Python 3.12+) and
        # returns a naive datetime. PyJWT encodes both as the same UTC timestamp.
        claims["exp"] = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
        logger.debug("Creating %s token with payload: %s", token_type, payload)
        token = jwt.encode(claims, self.secret, algorithm=self.algorithm)
        # PyJWT < 2.0 returns bytes from encode(); normalize so the declared `-> str` holds.
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        dispatcher.send(signal_name, sender=self, payload=payload, token=token)
        logger.debug("Signal '%s' sent.", signal_name)
        return token

    def create_access_token(self, payload: Dict[str, Any]) -> str:
        """
        Creates a new access token.

        Emits 'jwt_access_token_created' signal on creation.

        Args:
            payload: A dictionary containing the data to encode in the token.
                Should NOT contain sensitive information like passwords.
                A 'user_id' or similar identifier is common.

        Returns:
            A signed JWT access token string, expiring after ``self.access_exp`` seconds.
        """
        return self._create_token(payload, "access", self.access_exp, "jwt_access_token_created")

    def create_refresh_token(self, payload: Dict[str, Any]) -> str:
        """
        Creates a new refresh token.

        Emits 'jwt_refresh_token_created' signal on creation.

        Args:
            payload: A dictionary containing the data to encode in the token.
                Should contain information needed to issue a new access token,
                like 'user_id'.

        Returns:
            A signed JWT refresh token string, expiring after ``self.refresh_exp`` seconds.
        """
        return self._create_token(payload, "refresh", self.refresh_exp, "jwt_refresh_token_created")

    # ------------------------------------------------------------------ #
    # Token decoding / validation
    # ------------------------------------------------------------------ #

    def _emit_decode_failure(self, token: Optional[str], error_type: str, **details: Any) -> None:
        """Send 'jwt_decode_failed' with ``error_type`` plus any extra keyword ``details``, then log it."""
        dispatcher.send("jwt_decode_failed", sender=self, token=token, error_type=error_type, **details)
        logger.debug("Signal 'jwt_decode_failed' (%s) sent.", error_type)

    def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Decodes a JWT token and verifies its signature and expiration.

        Emits 'jwt_token_decoded' on success.
        Emits 'jwt_decode_failed' on failure with a specific ``error_type``:
        "empty_token", "expired", "invalid_signature", "invalid_audience",
        "invalid_issuer", "invalid_issued_at", "decode_error", "invalid_token"
        (any other PyJWT validation failure, e.g. a not-yet-valid ``nbf``),
        or "unexpected_exception".

        Only the algorithm configured on this instance is accepted.

        Args:
            token: The JWT token string to decode.

        Returns:
            A dictionary containing the decoded payload if valid, otherwise None.

        Raises:
            Whatever a 'jwt_token_decoded' receiver raises, if ``dispatcher.send``
            propagates it. This matches the token-creation methods. Receiver
            errors are not reported as decode failures.
        """
        if not token:
            logger.debug("Attempted to decode an empty token.")
            self._emit_decode_failure(token, "empty_token", exception=None)
            return None

        # Keep only the PyJWT call inside the try. Success signalling happens
        # afterwards, so a failing signal receiver cannot make a valid token
        # look invalid.
        try:
            decoded = jwt.decode(token, self.secret, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError:
            logger.warning("JWT token expired.")
            self._emit_decode_failure(token, "expired", exception=None)
            return None
        except jwt.InvalidSignatureError:
            # Must precede DecodeError: InvalidSignatureError subclasses it in PyJWT.
            logger.warning("JWT token has an invalid signature.")
            self._emit_decode_failure(token, "invalid_signature", exception=None)
            return None
        except jwt.InvalidAudienceError:
            logger.warning("JWT token has an invalid audience claim.")
            self._emit_decode_failure(token, "invalid_audience", exception=None)
            return None
        except jwt.InvalidIssuerError:
            logger.warning("JWT token has an invalid issuer claim.")
            self._emit_decode_failure(token, "invalid_issuer", exception=None)
            return None
        except jwt.InvalidIssuedAtError:
            logger.warning("JWT token has an invalid issued at claim.")
            self._emit_decode_failure(token, "invalid_issued_at", exception=None)
            return None
        except jwt.DecodeError:
            logger.warning("JWT token decoding failed. Malformed token?")
            self._emit_decode_failure(token, "decode_error", exception=None)
            return None
        except jwt.InvalidTokenError as exc:
            # Remaining PyJWT validation failures (e.g. ImmatureSignatureError,
            # MissingRequiredClaimError, InvalidAlgorithmError) are expected
            # rejections, not crashes, so they are not logged with a traceback.
            logger.warning("JWT token failed validation: %s", exc)
            self._emit_decode_failure(token, "invalid_token", exception=exc)
            return None
        except Exception as exc:
            logger.error("An unexpected error occurred during JWT token decoding: %s", exc, exc_info=True)
            self._emit_decode_failure(token, "unexpected_exception", exception=exc)
            return None

        logger.debug("Successfully decoded token. Payload: %s", decoded)
        dispatcher.send("jwt_token_decoded", sender=self, token=token, payload=decoded)
        logger.debug("Signal 'jwt_token_decoded' sent.")
        return decoded

    def is_token_valid(self, token: str, token_type: Optional[str] = None) -> bool:
        """
        Checks if a JWT token is valid (signature and expiry) and optionally matches a specific type.

        This method primarily relies on decode_token, which emits failure signals.
        It additionally emits 'jwt_decode_failed' with error_type "type_mismatch"
        (plus ``expected_type``/``actual_type``) when the ``type`` claim differs.

        Args:
            token: The JWT token string to validate.
            token_type: Optional. The expected type of the token ('access' or 'refresh').
                If None, only checks signature and expiry.

        Returns:
            True if the token is valid and matches the type (if specified), False otherwise.
        """
        decoded = self.decode_token(token)
        if decoded is None:
            return False
        if token_type is None:
            return True

        actual_type = decoded.get("type")
        if actual_type == token_type:
            return True

        logger.debug("Token type mismatch. Expected '%s', got '%s'.", token_type, actual_type)
        self._emit_decode_failure(token, "type_mismatch", expected_type=token_type, actual_type=actual_type)
        return False

    def get_payload(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Decodes a token and returns its payload if valid.

        Relies on decode_token for validation and signal emission.

        Args:
            token: The JWT token string.

        Returns:
            The decoded payload dictionary if the token is valid, otherwise None.
        """
        return self.decode_token(token)

    def get_user_id(self, token: str) -> Optional[Any]:
        """
        Decodes a token and returns the 'user_id' claim if the token is valid and the claim exists.

        Relies on decode_token for validation and signal emission.

        Args:
            token: The JWT token string.

        Returns:
            The value of the 'user_id' claim if found in a valid token, otherwise None.
        """
        decoded = self.decode_token(token)
        if decoded:
            return decoded.get("user_id")
        return None