Source code for duck.contrib.jwt
"""
JWT utilities for Duck's auth system.
Provides token encoding, decoding, issuance, and rotation.
Requires PyJWT — install with: pip install PyJWT
"""
import datetime
from typing import Any, Dict
from duck.settings import SETTINGS
[docs]
class JWTError(Exception):
    """
    Root of the exception hierarchy used by this JWT module.

    Catch this type to handle any JWT-related failure raised here.
    """
[docs]
class JWTExpired(JWTError):
    """
    Signals that a token's expiry time has already passed.
    """
[docs]
class JWTInvalid(JWTError):
    """
    Signals that a token could not be decoded or is structurally invalid.
    """
[docs]
def get_jwt_lib():
    """
    Load the PyJWT library on demand.

    The import statement lives inside the function body, so it is executed
    each time the function is called rather than when this module is loaded.

    Returns:
        The imported ``jwt`` module.

    Raises:
        ImportError: If ``import jwt`` fails. The original import error is
            chained as the cause.
    """
    try:
        import jwt
    except (ImportError, ModuleNotFoundError) as exc:
        raise ImportError(
            "Failed to import the 'PyJWT' package. "
            "Install it using: pip install PyJWT"
        ) from exc
    else:
        return jwt
[docs]
def get_secret_key() -> str:
    """
    Look up the JWT signing secret in Duck settings.

    ``JWT_SECRET_KEY`` takes precedence; when it is missing or falsy, the
    general ``SECRET_KEY`` setting is consulted instead.

    Returns:
        The first truthy value found among ``JWT_SECRET_KEY`` and
        ``SECRET_KEY``.

    Raises:
        JWTError: If neither setting holds a truthy value.
    """
    # Settings are probed in priority order; the first truthy value wins.
    for setting_name in ("JWT_SECRET_KEY", "SECRET_KEY"):
        candidate = SETTINGS.get(setting_name)
        if candidate:
            return candidate

    raise JWTError(
        "JWT_SECRET_KEY is not set in Duck settings. "
        "This is required for JWT encoding and decoding."
    )
[docs]
def get_algorithm() -> str:
    """
    Read the JWT signing algorithm from Duck settings.

    Returns:
        The value of the ``JWT_ALGORITHM`` setting, or ``"HS256"`` when that
        setting is absent.
    """
    algorithm = SETTINGS.get("JWT_ALGORITHM", "HS256")
    return algorithm
[docs]
def get_access_lifetime() -> float:
    """
    Retrieve the access token lifetime from Duck settings.

    Defaults to 3600 seconds (60 minutes) if ``JWT_ACCESS_LIFETIME`` is not
    set. The value is returned as configured; no conversion is applied.

    Returns:
        Seconds representing the access token lifetime.
    """
    # The return annotation is a number of seconds, matching
    # get_refresh_lifetime(); the default below is an int, not a timedelta.
    return SETTINGS.get("JWT_ACCESS_LIFETIME", 3600)
[docs]
def get_refresh_lifetime() -> float:
    """
    Read the refresh token lifetime from Duck settings.

    Returns:
        The value of the ``JWT_REFRESH_LIFETIME`` setting, or the number of
        seconds in 7 days when that setting is absent.
    """
    one_week_seconds = 7 * 24 * 3600
    return SETTINGS.get("JWT_REFRESH_LIFETIME", one_week_seconds)
[docs]
def encode_token(
    payload: dict[str, Any],
    token_type: str = "access",
) -> str:
    """
    Encode a signed JWT token from the given payload.

    The caller is responsible for supplying the expiry: ``payload`` must
    contain a non-``None`` ``"exp"`` claim. The ``"type"`` and ``"iat"``
    claims are always set by this function and override any values of the
    same name in ``payload``. The input dict is not modified.

    Args:
        payload: Data to embed in the token (e.g.
            ``{"user_id": 1, "exp": ...}``). Must include ``"exp"``.
        token_type: Value stored in the token's ``"type"`` claim, e.g.
            ``"access"`` or ``"refresh"``.

    Returns:
        A signed JWT string.

    Raises:
        ImportError: If the ``PyJWT`` package is not installed.
        JWTError: If ``payload`` is ``None``, lacks an ``"exp"`` key, or has
            ``"exp"`` set to ``None``; also if no secret key is configured.
    """
    jwt = get_jwt_lib()

    # A missing payload is reported like a missing expiry instead of
    # surfacing a TypeError from the membership test below.
    if payload is None:
        payload = {}

    if "exp" not in payload:
        raise JWTError("Expiry not set, please include key 'exp' in payload.")
    if payload["exp"] is None:
        raise JWTError("Expiry is None, please include key 'exp' in payload.")

    # Use an aware UTC datetime: calling .timestamp() on the naive result of
    # utcnow() interprets it as local time and skews "iat" by the UTC offset.
    issued_at = int(datetime.datetime.now(datetime.timezone.utc).timestamp())

    claims = {
        **payload,
        "type": token_type,
        "iat": issued_at,
    }
    return jwt.encode(claims, get_secret_key(), algorithm=get_algorithm())
[docs]
def decode_token(token: str, verify_expiry: bool = True) -> Dict[str, Any]:
    """
    Verify a signed JWT and return its payload.

    The token is checked against the configured secret key and signing
    algorithm.

    Args:
        token: The raw JWT string to decode.
        verify_expiry: Passed to PyJWT as the ``verify_exp`` option. When
            True, an expired token results in ``JWTExpired``.

    Returns:
        The decoded payload as a dict.

    Raises:
        JWTExpired: If PyJWT reports the token as expired.
        JWTInvalid: If PyJWT reports any other invalid-token condition.
    """
    jwt = get_jwt_lib()
    secret = get_secret_key()
    allowed_algorithms = [get_algorithm()]
    decode_options = {"verify_exp": verify_expiry}

    try:
        decoded = jwt.decode(
            token,
            secret,
            algorithms=allowed_algorithms,
            options=decode_options,
        )
    # The expired case is handled first so it is reported as JWTExpired
    # rather than falling through to the generic invalid-token handler.
    except jwt.ExpiredSignatureError as exc:
        raise JWTExpired("Token has expired.") from exc
    except jwt.InvalidTokenError as exc:
        raise JWTInvalid(f"Token is invalid: {exc}") from exc
    return decoded
[docs]
def issue_token_pair(
    payload: dict[str, Any] | None = None,
) -> dict[str, str]:
    """
    Issue both an access and a refresh token from the same payload.

    The same claims, including ``"exp"``, are embedded in both tokens; the
    two tokens differ only in their ``"type"`` claim.

    Args:
        payload: Claims embedded in both tokens. Must contain a
            non-``None`` ``"exp"`` claim, because ``encode_token`` rejects
            payloads without one. ``None`` is treated as an empty payload.

    Returns:
        A dict with ``"access"`` and ``"refresh"`` JWT strings.

    Raises:
        JWTError: If the payload has no usable ``"exp"`` claim (this
            includes the default ``payload=None``).
    """
    # Normalise the default so a missing payload fails with the module's
    # JWTError instead of a TypeError inside encode_token.
    claims = {} if payload is None else payload
    return {
        "access": encode_token(claims, token_type="access"),
        "refresh": encode_token(claims, token_type="refresh"),
    }