Source code for duck.contrib.websockets.extensions
"""
WebSocket Extensions module.
Provides base and concrete implementations for WebSocket frame extensions,
including permessage-deflate compression as defined in RFC 7692.
"""
import zlib
from duck.contrib.websockets.opcodes import OpCode, CONTROL_OPCODES
from duck.contrib.websockets.exceptions import ExtensionError
[docs]
class Extension:
    """
    Common base for WebSocket extensions.

    An extension transforms WebSocket frames on their way out (``encode``)
    or on their way in (``decode``); compression and encryption are typical
    examples. Subclasses are expected to override both hooks.
    """

    def __init__(self, name: str):
        """
        Store the (whitespace-stripped) extension name.

        Args:
            name (str): Name of the extension as it should appear in the
                Sec-WebSocket-Extensions header.

        Raises:
            ValueError: If ``name`` is not a string or is empty/blank.
        """
        # Non-strings collapse to "" so a single check rejects every bad input.
        cleaned = name.strip() if isinstance(name, str) else ""
        if not cleaned:
            raise ValueError("Extension name must be a non-empty string.")
        self.name = cleaned

    def check_frame(self, frame):
        """
        Ensure ``frame`` is a ``Frame`` instance.

        Args:
            frame: The object to validate.

        Raises:
            ExtensionError: If ``frame`` is not an instance of ``Frame``.
        """
        # Imported at call time rather than at module import time.
        from duck.contrib.websockets.frame import Frame

        if isinstance(frame, Frame):
            return
        raise ExtensionError(
            f"The frame should be an instance of Frame, not {type(frame)}."
        )

    def encode(self, frame) -> "Frame":
        """
        Transform an outgoing frame. Must be overridden by subclasses.

        Args:
            frame: The frame to encode.

        Raises:
            NotImplementedError: Always, in this base class.

        Returns:
            Frame: The encoded frame (subclasses typically return the same
            object after transforming it).
        """
        raise NotImplementedError("Implement this method to be able to encode frames.")

    def decode(self, frame) -> "Frame":
        """
        Transform an incoming frame. Must be overridden by subclasses.

        Args:
            frame: The frame to decode.

        Raises:
            NotImplementedError: Always, in this base class.

        Returns:
            Frame: The decoded frame (subclasses typically return the same
            object after transforming it).
        """
        raise NotImplementedError("Implement this method to be able to decode frames.")
[docs]
class PerMessageDeflate(Extension):
    """
    Per-message Deflate Extension (RFC 7692).

    Compresses the payload of outgoing non-control frames and decompresses
    the payload of incoming ones using raw DEFLATE. Compression is a
    per-*message* property: RSV1 on the first frame of a message says whether
    the whole message is compressed, and the 4-byte sync-flush tail is
    removed/restored only at the end of a message.

    Args:
        name (str): Extension name, forwarded to ``Extension.__init__``.
        client_no_context_takeover (bool): Whether to disable decompression context reuse.
        server_no_context_takeover (bool): Whether to disable compression context reuse.
        client_max_window_bits (int): Maximum window bits for client decompression (8-15).
            Defaults to 15, which means 32K LZ77 history buffer.

    Raises:
        ValueError: If ``name`` is invalid or ``client_max_window_bits`` is
            outside 8-15.
    """

    # Empty stored DEFLATE block that Z_SYNC_FLUSH appends; RFC 7692 strips it
    # from the end of every compressed message and the receiver re-adds it.
    _DEFLATE_TAIL = b"\x00\x00\xff\xff"

    # True while the inbound message currently being received was flagged as
    # compressed (RSV1 set on its first frame). Class-level default so that
    # instances are usable even before the first data frame arrives.
    _inbound_compressed = False

    def __init__(
        self,
        name: str,
        client_no_context_takeover: bool = False,
        server_no_context_takeover: bool = False,
        client_max_window_bits: int = 15
    ):
        super().__init__(name)
        if not (8 <= client_max_window_bits <= 15):
            raise ValueError("client_max_window_bits must be between 8 and 15.")
        self.client_no_context_takeover = client_no_context_takeover
        self.server_no_context_takeover = server_no_context_takeover
        self.client_max_window_bits = client_max_window_bits
        # Negative wbits selects raw DEFLATE (no zlib header/trailer).
        self._compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS)
        self._decompressor = zlib.decompressobj(wbits=-client_max_window_bits)

    def encode(self, frame) -> "Frame":
        """
        Compresses the payload of the given frame using DEFLATE.

        Control frames are returned untouched. Data frames are compressed and
        sync-flushed; the trailing ``00 00 ff ff`` is stripped only from the
        final frame of a message, as required by RFC 7692. RSV1 is set on the
        first frame of a message only.

        Args:
            frame: A Frame instance to encode.

        Raises:
            ExtensionError: If the input is not a valid Frame.

        Returns:
            Frame: The same frame object, with its payload compressed when
            applicable.
        """
        self.check_frame(frame)
        if frame.opcode in CONTROL_OPCODES:
            return frame

        # NOTE(review): assumes Frame exposes the FIN bit as `fin`; when the
        # attribute is absent every frame is treated as final, which matches
        # the previous behaviour of this method - confirm against Frame.
        is_final = bool(getattr(frame, "fin", True))

        compressed = self._compressor.compress(frame.payload)
        compressed += self._compressor.flush(zlib.Z_SYNC_FLUSH)
        # The tail may only be dropped at the end of the message; a tail kept
        # inside a message is just an empty block and is valid DEFLATE.
        if is_final and compressed.endswith(self._DEFLATE_TAIL):
            compressed = compressed[:-len(self._DEFLATE_TAIL)]
        frame.payload = compressed

        # RSV1 marks the whole message and goes on its first frame only.
        if frame.opcode != OpCode.CONTINUATION:
            frame.rsv1 = True

        # Dropping the LZ77 window is only legal between messages; doing it
        # mid-message would break back-references in later fragments.
        if is_final and self.server_no_context_takeover:
            self._compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS)

        return frame

    def decode(self, frame) -> "Frame":
        """
        Decompresses the payload of the given frame using DEFLATE.

        Control frames are returned untouched, and so are the frames of
        messages whose first frame did not have RSV1 set (RFC 7692 allows a
        peer to send uncompressed messages even after negotiating the
        extension). For compressed messages the 4-byte tail removed by the
        sender is restored on the final frame before inflating. RSV1 is
        cleared on the first frame after a successful decode.

        Args:
            frame: A Frame instance to decode.

        Raises:
            ExtensionError: If the input is not a valid Frame.
            zlib.error: If the payload is not valid DEFLATE data.

        Returns:
            Frame: The same frame object, with its payload decompressed when
            applicable.
        """
        self.check_frame(frame)
        if frame.opcode in CONTROL_OPCODES:
            return frame

        is_first = frame.opcode != OpCode.CONTINUATION
        if is_first:
            # Continuation frames never carry RSV1, so remember the decision
            # made on the first frame for the rest of the message.
            # NOTE(review): assumes the frame parser populates `rsv1` from
            # the wire before this method is called - confirm against Frame.
            self._inbound_compressed = bool(frame.rsv1)
        if not self._inbound_compressed:
            return frame

        # NOTE(review): same `fin` assumption as in encode(); a missing
        # attribute means "final", i.e. the previous per-frame behaviour.
        is_final = bool(getattr(frame, "fin", True))

        data = frame.payload
        if is_final:
            # Restore the sync-flush tail the sender stripped from the message.
            data = data + self._DEFLATE_TAIL
        frame.payload = self._decompressor.decompress(data)

        # Clear RSV1 to avoid protocol errors during frame validation
        if is_first:
            frame.rsv1 = False

        # Reset only once the whole message has been inflated.
        if is_final and self.client_no_context_takeover:
            self._decompressor = zlib.decompressobj(wbits=-self.client_max_window_bits)

        return frame

    def __repr__(self):
        """
        Returns a debug-friendly string representation of the extension.
        """
        return (
            f"<PerMessageDeflate "
            f"client_no_context_takeover={self.client_no_context_takeover}, "
            f"server_no_context_takeover={self.server_no_context_takeover}, "
            f"client_max_window_bits={self.client_max_window_bits}>"
        )