Source code for duck.http.jwt.engine

"""
Module representing Duck's JWT store — payload management only.

Transport extraction and client delivery live in JWTMiddleware.
This store is concerned only with decoding, encoding, and mutating
the JWT payload — mirroring what SessionStore does for session data.
"""

import datetime

from typing import Optional, Union, Dict
from functools import wraps

from duck.settings import SETTINGS
from duck.logging import logger
from duck.contrib.jwt import (
    encode_token,
    decode_token,
    get_access_lifetime,
    get_refresh_lifetime,
    JWTInvalid,
)


class JWTStore:
    """
    Store for decoding, encoding, and managing a JWT payload.

    Mirrors the lazy-load pattern of ``SessionStore`` — the payload is decoded
    on first access and mutations are tracked via ``modified``.

    Example usage:

    ```py
    store = JWTStore(raw_token="eyJ...")
    store.load()

    user_id = store.get("user_id")
    store.set("role", "distributor")

    if store.needs_update():
        # Re-apply the expiry (falls back to the configured access lifetime
        # when the payload has no ``exp`` claim yet).
        store.set_expiry(store.expiry_secs)

        # Retrieve access and refresh tokens (a dict, not a tuple).
        tokens = store.encode_all()
        access, refresh = tokens["access"], tokens["refresh"]

        # Send token to client here.

        # Mark the token as synced with the client.
        store.mark_updated()
    ```

    Settings keys consumed (indirectly, through the ``duck.contrib.jwt``
    helpers this class calls):

    - ``JWT_SECRET_KEY`` — signing secret (required).
    - ``JWT_ALGORITHM`` — defaults to ``"HS256"``.
    - ``JWT_ACCESS_LIFETIME`` — default expiry in seconds, defaults to ``3600``.
    """

    # Slots declared as a tuple, not a set literal
    __slots__ = (
        "raw_token",
        "payload",
        "loaded",
        "modified",
    )

    def __init__(self, raw_token: Optional[str] = None):
        """
        Initializes the JWT store.

        Args:
            raw_token (Optional[str]): Encoded JWT string to decode on load,
                or ``None`` to start with an empty payload.
        """
        self.raw_token: Optional[str] = raw_token
        self.payload: dict = {}
        self.loaded: bool = False
        self.modified: bool = False

    # Decorator
    #
    # Defined as a plain function while the class body executes and wrapped in
    # ``staticmethod`` at the very end of the class: ``staticmethod`` objects
    # are only callable from Python 3.10 onwards, so decorating methods with a
    # ``@staticmethod`` inside the class body fails on older interpreters.
    def ensure_loaded(method):
        """
        Decorator that auto-loads the store before the method executes.
        """

        @wraps(method)
        def wrapper(self, *args, **kwargs):
            if not self.loaded:
                self.load()
            return method(self, *args, **kwargs)

        return wrapper

    # Time helpers

    @staticmethod
    def _utcnow() -> datetime.datetime:
        """Returns the current UTC time as a naive datetime."""
        # Same value ``datetime.utcnow()`` produced, without the deprecated call.
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _now_timestamp() -> float:
        """Returns the current unix timestamp, independent of the host timezone."""
        return datetime.datetime.now(datetime.timezone.utc).timestamp()

    @staticmethod
    def _exp_to_timestamp(exp) -> float:
        """Normalises an ``exp`` claim (datetime or number) to a unix timestamp."""
        if isinstance(exp, datetime.datetime):
            if exp.tzinfo is None:
                # ``set_expiry`` stores naive UTC datetimes; ``.timestamp()``
                # would otherwise interpret them as local time.
                exp = exp.replace(tzinfo=datetime.timezone.utc)
            return exp.timestamp()
        return float(exp)

    @property
    @ensure_loaded
    def expiry_secs(self) -> Optional[float]:
        """
        Returns the remaining seconds until the token expires.

        Computed live from the ``exp`` claim rather than the originally set
        duration, so it reflects actual time-to-expiry at call time.

        Returns:
            Optional[float]: Seconds remaining, ``0.0`` if already expired.
                When no ``exp`` claim is present the value of
                ``get_access_lifetime()`` is returned instead, so the result
                can always be fed back into ``set_expiry``.
        """
        exp = self.payload.get("exp")
        if exp is None:
            return get_access_lifetime()
        remaining = self._exp_to_timestamp(exp) - self._now_timestamp()
        return max(0.0, float(remaining))

    # Core load / encode

    def load(self) -> dict:
        """
        Decodes the raw token and populates the internal payload.

        If no raw token was provided, the store loads with an empty payload
        and is treated as an unauthenticated context — no error is raised.

        Expiry is deliberately NOT verified here (``verify_expiry=False``);
        callers are expected to consult ``is_expired()`` afterwards.

        Returns:
            dict: The decoded JWT payload.

        Raises:
            JWTInvalid: If ``decode_token`` rejects the token.
        """
        if not self.raw_token:
            # No token — empty unauthenticated store
            self.payload = {}
            self.loaded = True
            return self.payload

        try:
            self.payload = decode_token(self.raw_token, verify_expiry=False)
        except JWTInvalid as e:
            raise JWTInvalid("Token invalid, try using `reset()` to reset the JWT.") from e

        # Mark clean after fresh decode
        self.loaded = True
        self.modified = False
        return self.payload

    def reset(self):
        """
        Discards the raw token held by the store.

        Only ``raw_token`` is cleared; ``payload``, ``loaded`` and ``modified``
        are left untouched.
        """
        self.raw_token = None

    @ensure_loaded
    def encode(self) -> str:
        """
        Encodes the current payload into a signed access token.

        Also updates ``raw_token`` with the newly encoded value.

        Returns:
            str: The signed JWT token string.

        Raises:
            JWTError: If ``exp`` has not been set on the payload.
        """
        if "exp" not in self.payload:
            # NOTE(review): ``JWTError`` was never imported by this module;
            # it is assumed to live next to ``JWTInvalid`` — confirm and move
            # the import to the top-level import block.
            from duck.contrib.jwt import JWTError

            raise JWTError("Expiry not set, please set it using `set_expiry` or `reset_expiry` method.")

        token = encode_token(
            payload=self.payload,
            token_type="access",
        )

        # PyJWT >=2.0 returns str; older versions return bytes
        if isinstance(token, bytes):
            token = token.decode("utf-8")

        # Keep raw_token in sync
        self.raw_token = token
        return token

    def encode_all(self) -> Dict[str, str]:
        """
        Returns a dictionary with the keys ``"access"`` and ``"refresh"``
        holding a freshly encoded access token and refresh token.
        """
        return {"access": self.encode(), "refresh": self.encode_refresh_token()}

    def encode_refresh_token(self) -> str:
        """
        Encodes a refresh token carrying only an ``exp`` claim.

        Returns:
            str: The signed refresh token string.
        """
        # ``exp`` is an absolute point in time, so the configured lifetime is
        # added to "now" exactly as ``set_expiry`` does for access tokens.
        # NOTE(review): assumes ``get_refresh_lifetime()`` returns seconds,
        # like ``get_access_lifetime()`` — confirm.
        expires_at = self._utcnow() + datetime.timedelta(seconds=get_refresh_lifetime())
        token = encode_token({"exp": expires_at}, token_type="refresh")

        # PyJWT >=2.0 returns str; older versions return bytes
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        return token

    # Payload access

    @ensure_loaded
    def get(self, key: str, default=None):
        """
        Returns a claim value from the payload.

        Args:
            key (str): Claim key to look up.
            default: Returned if the key is absent.
        """
        return self.payload.get(key, default)

    @ensure_loaded
    def set(self, key: str, value):
        """
        Sets a claim and marks the store as modified.

        Args:
            key (str): Claim key.
            value: Value to store.
        """
        self.payload[key] = value
        self.modified = True

    @ensure_loaded
    def delete(self, key: str):
        """
        Removes a claim from the payload; a missing key is ignored and does
        not mark the store as modified.

        Args:
            key (str): Claim key to remove.
        """
        if key in self.payload:
            del self.payload[key]
            self.modified = True

    @ensure_loaded
    def update(self, data: dict):
        """
        Merges a dict of claims into the payload.

        Args:
            data (dict): Claims to merge in. An empty/falsy value is a no-op.
        """
        if data:
            self.payload.update(data)
            self.modified = True

    @ensure_loaded
    def clear(self):
        """
        Clears all claims from the payload (no-op when already empty).
        """
        if self.payload:
            self.payload.clear()
            self.modified = True

    @ensure_loaded
    def keys(self):
        """Returns a view of the claim keys."""
        return self.payload.keys()

    @ensure_loaded
    def values(self):
        """Returns a view of the claim values."""
        return self.payload.values()

    @ensure_loaded
    def items(self):
        """Returns a view of the ``(key, value)`` claim pairs."""
        return self.payload.items()

    # Expiry helpers

    @ensure_loaded
    def set_expiry(
        self,
        expiry: Optional[Union[int, float, datetime.datetime, datetime.timedelta]] = None,
    ):
        """
        Sets the ``exp`` claim on the payload and marks the store as modified.

        The store is loaded first so that a later lazy load cannot overwrite
        the expiry set here.

        Args:
            expiry (Optional[Union[int, float, datetime.datetime, datetime.timedelta]]):
                - ``int`` / ``float`` — seconds from now.
                - ``datetime.timedelta`` — duration added to now.
                - ``datetime.datetime`` — absolute expiry datetime, stored as given.
                - ``None`` — falls back to ``get_access_lifetime()`` seconds.

        Raises:
            JWTError: If the expiry type is not supported.
        """
        now = self._utcnow()

        if expiry is None:
            expires_at = now + datetime.timedelta(seconds=get_access_lifetime())
        elif isinstance(expiry, datetime.datetime):
            expires_at = expiry
        elif isinstance(expiry, datetime.timedelta):
            expires_at = now + expiry
        elif isinstance(expiry, (int, float)):
            expires_at = now + datetime.timedelta(seconds=expiry)
        else:
            # NOTE(review): see ``encode`` — ``JWTError`` is assumed to be
            # exported by ``duck.contrib.jwt``; confirm.
            from duck.contrib.jwt import JWTError

            raise JWTError(
                f"Unsupported expiry type '{type(expiry)}'. "
                "Expected int, float, timedelta, datetime, or None."
            )

        self.payload["exp"] = expires_at
        self.modified = True

    def reset_expiry(self):
        """
        Re-applies the expiry using ``expiry_secs``.

        Because ``expiry_secs`` is the *remaining* time, a payload that already
        has an ``exp`` claim keeps (approximately) the same expiry instant;
        a payload without one receives the configured access lifetime.
        """
        self.set_expiry(self.expiry_secs)

    @ensure_loaded
    def is_expired(self) -> bool:
        """
        Returns whether the ``exp`` claim is in the past.

        Returns:
            bool: ``True`` if ``exp`` has been reached, ``False`` otherwise —
                including when no ``exp`` claim is present (non-expiring).
        """
        exp = self.payload.get("exp")
        if exp is None:
            return False  # No expiry claim — treated as non-expiring
        return self._now_timestamp() >= self._exp_to_timestamp(exp)

    def mark_updated(self):
        """
        Resets the modified flag after the token has been sent to the client.
        """
        self.modified = False

    def needs_update(self) -> bool:
        """
        Returns whether the payload has unsaved modifications.

        Mirrors ``SessionStore.needs_update`` — only meaningful once loaded.

        Returns:
            bool: ``True`` if the store is loaded and has been modified.
        """
        if not self.loaded:
            return False
        return self.modified

    # Dunder helpers

    @ensure_loaded
    def __setitem__(self, key, value):
        self.payload[key] = value
        self.modified = True

    @ensure_loaded
    def __getitem__(self, key):
        return self.payload[key]

    @ensure_loaded
    def __delitem__(self, key):
        del self.payload[key]
        self.modified = True

    @ensure_loaded
    def __contains__(self, key):
        return key in self.payload

    @ensure_loaded
    def __repr__(self):
        return f"<{self.__class__.__name__} {self.payload}>"

    @ensure_loaded
    def __iter__(self):
        return iter(self.payload)

    # Expose the decorator as a static method now that every in-class use of
    # the plain function is done (keeps ``JWTStore.ensure_loaded`` available).
    ensure_loaded = staticmethod(ensure_loaded)