# Source code for duck.http.session.engine

"""
Duck's default session engine, implemented by the SessionStore class.
"""

import uuid
import datetime
import warnings

from typing import Optional, Union
from functools import wraps

from duck.settings import SETTINGS
from duck.utils.importer import import_module_once
from duck.utils.asyncio import raise_if_in_async_context
from duck.contrib.sync import ensure_async
from duck.logging import logger


# TODO: Need to make the SESSION API fully async-compatible

# Handle to the session storage connector module, loaded through Duck's
# `import_module_once`; SessionStore obtains its connector from it via
# `get_session_storage_connector()`.
session_connector_mod = import_module_once(
    "duck.http.session.session_storage_connector"
)


class SessionError(Exception):
    """
    Base class for all errors raised by the session engine.
    """


class SessionExpired(SessionError):
    """
    Raised when saving a session whose expiry date has already passed.
    """
[docs] class SessionStore(dict): """ Session store for storing session data. """ __slots__ = { "_session_key", "_loaded", "_modified", "disable_warnings", "session_storage_connector", } def __init__(self, session_key: str, disable_warnings: bool = False): """ Initializes the session store. """ super().__init__() self._session_key = session_key self._loaded = False self._modified = False self.disable_warnings = disable_warnings self.session_storage_connector = session_connector_mod.get_session_storage_connector() @property def session_key(self): return self._session_key @session_key.setter def session_key(self, key: str): self._session_key = key @property def modified(self) -> bool: """ Returns the state whether the session has been modified after load or creation. """ return self._modified @property def loaded(self) -> bool: """ Returns the state whether the session has been loaded. """ return self._loaded @modified.setter def modified(self, what: bool): """ Sets whether the session has been modified. """ self._modified = what
[docs] def needs_update(self) -> bool: """ Returns whether the session data is worthy to be saved, this is the lazy behavior of Duck. """ if not self.loaded: # If session hasn't been loaded, we consider it as not modified to avoid unnecessary saves return False return self.modified
[docs] @staticmethod def generate_session_id() -> str: """ Generates and returns a random session ID. """ return str(uuid.uuid4())
[docs] def session_expired(self) -> bool: """ Returns boolean on whether if the session has expired depending on expiry set on session. """ expiry_date = self.get_expiry_date() now = datetime.datetime.utcnow() return now >= expiry_date
[docs] def get_expiry_age(self): """ Returns the session max age from current settings. """ return SETTINGS["SESSION_COOKIE_AGE"]
[docs] def get_expiry_date(self): """ Returns the datetime the session is going to expire. """ expire_date = self.get("expiry_date") if not expire_date: self.set_expiry( datetime.datetime.utcnow() + datetime.timedelta(seconds=self.get_expiry_age()) ) return self.get_expiry_date() return expire_date
[docs] def set_expiry( self, expiry: Optional[Union[int, float, datetime.datetime, datetime.timedelta]] = None, ): """ Sets the session expiry. Args: expiry Optional[Union[int, float, datetime.datetime, datetime.timedelta]]: Float or int represents the seconds to expire from now and None represents the now plus the default session max_age. """ if expiry is None: self["expiry_date"] = ( datetime.datetime.utcnow() + datetime.timedelta(seconds=self.get_expiry_age()) ) elif isinstance(expiry, (datetime.datetime, datetime.timedelta)): self["expiry_date"] = expiry elif isinstance(expiry, (int, float)): self["expiry_date"] = ( datetime.datetime.utcnow() + datetime.timedelta(seconds=expiry) ) else: raise SessionError( f"Invalid expiry, expected any of [int, float, datetime.datetime, datetime.timedelta, None] but got '{type(expiry)}'" )
[docs] @staticmethod def check_session_storage_connector(method): """ Decorator to ensure a valid session storage connector is present. Validates: - Attribute exists - Attribute is not None - Attribute is correct type """ from duck.http.session.session_storage_connector import SessionStorageConnector @wraps(method) def wrapper(self, *args, **kwargs): connector = getattr(self, "session_storage_connector", None) if connector is None: raise ValueError("Session storage connector is not set") if not isinstance(connector, SessionStorageConnector): raise TypeError( "Invalid session storage connector provided. " f"Expected {SessionStorageConnector}, got {type(connector)}" ) # Execute the decorated method return method(self, *args, **kwargs) # Return wrapper return wrapper
[docs] @staticmethod def ensure_session_loaded(method): """ Decorator which ensures that the session is loaded. """ @wraps(method) def wrapper(self, *args, **kwargs): if not self.loaded: self.load() # Execute the decorated method return method(self, *args, **kwargs) # Return wrapper return wrapper
[docs] @check_session_storage_connector def load(self) -> dict: """ Loads the session from storage. """ if self.loaded and not self.disable_warnings: logger.warn( f"{self.__class__.__name__} is already loaded; reloading may be inefficient." ) # Retrieve session data from storage, if session key is invalid or session doesn't exist, get_session should return None, so we can safely fallback to empty dict session_data = self.session_storage_connector.get_session(self.session_key) or {} # Update session store with retrieved data, this will also set the modified flag if data is not empty super().update(session_data) # Avoids recursion error if not self.loaded: self._modified = False # If session hasn't been loaded for the first time, set _modified to False # Update state and return session data self._loaded = True return session_data
[docs] def save(self): """ Save the session """ raise_if_in_async_context("Please use 'async_save' method instead.") self._save()
[docs] @check_session_storage_connector @ensure_session_loaded def _save(self): """ Saves the session to storage. """ if not self.session_key: raise ValueError("Session key is not set or invalid.") # Normalize session data session_data = dict(self) if not self.session_expired(): # Session is not expired expiry_age = self.get_expiry_age() # Set the session in the real storage self.session_storage_connector.set_session( self.session_key, session_data, expiry=expiry_age, ) # Save the current state of session storage self.session_storage_connector.save() else: raise SessionExpired("Cannot save an expired session, use `set_expiry` to reset the session expiry.") # Reset session modification self._modified = False # reset session modification.
[docs] @check_session_storage_connector def exists(self, session_key: Optional[str] = None) -> bool: """ Checks if a session with the specified key exists. Args: session_key (Optional[str]): The session key or None if you want to use the current session key. """ session_key = session_key or self.session_key try: return bool(self.session_storage_connector.get_session(session_key)) except KeyError: return False
[docs] @check_session_storage_connector def assign_new_session_key(self) -> str: """ Creates a new session with a new session key. """ self.session_key = self.generate_session_id() return self.session_key
[docs] @check_session_storage_connector @ensure_session_loaded def get(self, *args, **kw): """ Return value for a key. """ return super().get(*args, **kw)
[docs] @check_session_storage_connector def delete(self, session_key: Optional[str] = None): """ Deletes and clears the session from session storage. Args: session_key (Optional[str]): The session key or None if you want to use the current session key. """ session_key = session_key or self.session_key self.session_storage_connector.delete_session(session_key) self.clear()
[docs] @ensure_session_loaded def update(self, data: dict): """ Overrides the update method to ensure items are tracked for modification. """ super().update(data) if data: self._modified = True
[docs] @ensure_session_loaded def clear(self): """ Clears all session data. """ is_empty = bool(self or None) super().clear() if not is_empty: self._modified = True
[docs] @ensure_session_loaded def pop(self, *args, **kwargs): """ Pops some session data. """ data = super().pop(*args, **kwargs) # If data is not empty, we consider the session modified, otherwise we don't to avoid unnecessary saves if data: self._modified = True # Return popped data return data
[docs] @ensure_session_loaded def popitem(self, *args, **kwargs): """ Pops some session data. """ data = super().popitem(*args, **kwargs) # If data is not empty, we consider the session modified, otherwise we don't to avoid unnecessary saves if data: self._modified = True # Return popped data return data
# SOME ASYNC API's
[docs] async def async_save(self): """ Asynchronously save session. """ await ensure_async(self._save)()
[docs] @ensure_session_loaded def __setitem__(self, key, value): super().__setitem__(key, value) self._modified = True
[docs] @ensure_session_loaded def __getitem__(self, key): value = super().__getitem__(key) return value
[docs] @ensure_session_loaded def __delitem__(self, key): super().__delitem__(key) self._modified = True
[docs] @ensure_session_loaded def __repr__(self): return f"<{self.__class__.__name__} {dict(self)}>"