"""
Duck WebSocket Implementation.
In your urlpatterns, you can use the WebSocket protocol as follows:

```
# urls.py
from duck.urls import path
from duck.contrib.websockets import WebSocketView, OpCode

class MyWebSocket(WebSocketView):
    async def on_open(self):
        print("WebSocket connection established")

    async def on_receive(self, data, opcode: int):
        if opcode == OpCode.TEXT:
            message = "Client sent " + data.decode("utf-8")
            await self.send_text(message)
        else:
            # Handle binary here
            pass

# Now create a urlpattern entry
urlpatterns = [
    path("/ws/myws", MyWebSocket, name="mywebsocket"),
    # other patterns here.
]
```
"""
import ssl
import time
import json
import struct
import copy
import base64
import hashlib
import asyncio
import enum
from typing import (
Sequence,
List,
Optional,
Union,
)
from duck.views import View
from duck.settings import SETTINGS
from duck.exceptions.all import ExpectingNoResponse
from duck.http.request import HttpRequest
from duck.http.response import (
HttpSwitchProtocolResponse,
HttpBadRequestResponse
)
from duck.contrib.websockets.frame import Frame
from duck.contrib.websockets.logging import log_message, logger
from duck.contrib.websockets.opcodes import (
OpCode,
CloseCode,
DATA_OPCODES,
)
from duck.contrib.websockets.extensions import (
Extension,
PerMessageDeflate,
)
from duck.contrib.websockets.exceptions import (
ProtocolError,
PayloadTooBig,
)
from duck.utils.xsocket import xsocket
from duck.utils.xsocket.io import SocketIO
from duck.utils.asyncio import create_task
[docs]
class State(enum.IntEnum):
    """
    Integer-valued states of a WebSocket connection.
    """

    # Connection is open.
    OPEN = 1

    # No connection: the value a view starts with and ends on.
    CLOSED = 0

    # The handshake response is about to be sent.
    INITIATING = -1

    # The upgrade response was sent successfully.
    INITIATED = 2
[docs]
class WebSocketView(View):
    """
    Strictly asynchronous view that serves a single WebSocket connection.

    The view performs the HTTP upgrade handshake, optionally negotiates the
    permessage-deflate extension (RFC 7692) and then exchanges frames with
    the client until the connection is closed.

    Features:
    - Upgrade handshake with a minimum `Sec-WebSocket-Version` check.
    - permessage-deflate negotiation, including the context takeover flags.
    - Sending and receiving frames through the negotiated extensions.
    - Ping/pong heartbeat with exponential backoff on missed pongs.
    - Reassembly of fragmented incoming messages.
    - Fragmentation of outgoing data messages larger than `MAX_FRAME_SIZE`.
    - Cancellation of pending tasks and socket shutdown when the connection ends.
    """

    MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    """
    str: Fixed GUID appended to the client's key when computing Sec-WebSocket-Accept.
    """

    PING_INTERVAL = 20
    """
    int: Base number of seconds between two ping frames.
    """

    PONG_TIMEOUT = 10
    """
    int: Seconds to wait for a pong after a ping has been sent.
    """

    MAX_BACKOFF = 45
    """
    int: Upper bound, in seconds, of the exponential backoff delay.
    """

    RECEIVE_TIMEOUT = 120
    """
    int: Timeout in seconds passed to each socket read while receiving a frame.
    """

    MAX_FRAME_SIZE = 1024 * 1024
    """
    int: Maximum frame size in bytes (1MB). Used as the size limit when parsing
    incoming frames and as the chunk size when fragmenting outgoing messages.
    """
def __init__(self, upgrade_request: HttpRequest, **kwargs):
    """
    Create the view for the HTTP request that asked for a WebSocket upgrade.

    Only in-memory state is prepared here; nothing is sent to the client.

    Args:
        upgrade_request (HttpRequest): The HTTP request that initiated the WebSocket upgrade.
        **kwargs: Extra keyword arguments, stored unchanged on `self.kwargs`.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import - confirm.
    from duck.http.core.httpd.httpd import response_handler

    # Public inputs
    self.request = upgrade_request
    self.kwargs = kwargs

    # Handshake / connection state
    self.state = State.CLOSED
    self.initiated_upgrade = False
    self.client_websocket_version = None
    self.min_websocket_version = 8

    # First frame of a fragmented message while its continuation frames are pending
    self.fragmented_frame: Optional[Frame] = None

    # Extensions negotiated upon upgrade
    self.extensions: List[Extension] = []

    # Private bookkeeping
    self._response_handler = response_handler
    self._closing = False
    self._last_pong_time = time.time()
    self._heartbeat_task = None
    self._receiver_task = None
    self._data_handling_tasks: set[asyncio.Task] = set()
@property
def server(self):
    """
    Returns the server reached through the upgrade request's application.
    """
    application = self.request.application
    return application.server
@property
def sock(self):
    """
    Returns the client socket carried by the upgrade request.
    """
    upgrade_request = self.request
    return upgrade_request.client_socket
[docs]
def strictly_async(self):
    """
    Mark the view as strictly asynchronous.

    Returns:
        bool: Always True.
    """
    return True
[docs]
def get_sec_accept_key(self, sec_websocket_key: str) -> str:
    """
    Compute the Sec-WebSocket-Accept value for the handshake response.

    Args:
        sec_websocket_key (str): Value of the client's Sec-WebSocket-Key header.

    Returns:
        str: Base64 text of the SHA-1 digest of the key followed by `MAGIC_STRING`.
    """
    handshake_material = (sec_websocket_key + self.MAGIC_STRING).encode("utf-8")
    digest = hashlib.sha1(handshake_material).digest()
    accept_key = base64.b64encode(digest)
    return accept_key.decode("utf-8")
[docs]
async def run(self) -> None:
    """
    Entry point for executing the WebSocket view.

    Switches the client socket to non-blocking mode and runs `run_forever`.
    A WebSocket view never produces an HTTP response of its own, so a clean
    return from `run_forever` is reported by raising `ExpectingNoResponse`.

    Raises:
        ExpectingNoResponse: If `run_forever` completes without raising an exception.
        Exception: Re-raised from `run_forever` when it fails while the state is
            neither `State.INITIATED` nor `State.CLOSED`. In those two states the
            error is only logged and this method returns normally.
        BaseException: Non-`Exception` errors such as `asyncio.CancelledError`
            propagate unchanged, so task cancellation is never masked.
    """
    # Set client socket blocking to False if set to true.
    self.request.client_socket.setblocking(False)

    try:
        await self.run_forever()
    except Exception as exc:
        if self.state not in (State.INITIATED, State.CLOSED):
            # Nothing has been sent to the client yet: let the caller handle the error.
            raise

        # Don't raise here: the initial response has already been sent to the
        # client, so the server must not answer with an Internal Server Error.
        logger.log_raw(
            f'\nError invoking websocket view for URL "{self.request.path}" ',
            level=logger.ERROR,
            custom_color=logger.Fore.YELLOW,
        )
        logger.log_exception(exc)
        return

    # Only reached when run_forever returned without raising.
    raise ExpectingNoResponse("WebSocket view must not return a response.")
[docs]
async def initiate_upgrade_to_ws(self) -> bool:
    """
    Perform the WebSocket handshake and send the upgrade response to the client.

    Validates the upgrade headers and the request origin, negotiates the
    permessage-deflate extension with its context takeover flags when the
    client offers it, then sends either a switching-protocols response or a
    bad-request response.

    Returns:
        bool: True if the upgrade response was sent, False if the request was
        rejected (invalid headers or failed origin check).
    """
    from duck.shortcuts import simple_response, template_response
    from duck.http.middlewares.security.csrf import (
        CSRFMiddleware,
        OriginError,
    )

    error_msg = None
    sec_key = self.request.get_header('sec-websocket-key', '').strip()
    version = self.request.get_header("sec-websocket-version", "").strip()

    if self.request.get_header("upgrade", "").lower() != "websocket":
        error_msg = "Missing Upgrade: websocket"
    elif getattr(self.sock, "h2_handling", False):
        error_msg = "WebSocket not allowed on HTTP/2"
    elif not sec_key:
        error_msg = "Missing Sec-WebSocket-Key"
    elif not version:
        error_msg = "Missing Sec-WebSocket-Version"
    else:
        try:
            self.client_websocket_version = int(version)
            if self.client_websocket_version < self.min_websocket_version:
                error_msg = f"Minimum version {self.min_websocket_version} required"
        except ValueError:
            error_msg = "Invalid Sec-WebSocket-Version"

    headers = {"Sec-WebSocket-Accept": self.get_sec_accept_key(sec_key)}
    exts = self.request.get_header("sec-websocket-extensions", "").lower()
    negotiated_exts = []

    # Check origin validity
    try:
        CSRFMiddleware._check_origin_ok(self.request)
    except OriginError as e:
        err = str(e)
        body = None if not SETTINGS['DEBUG'] else f"Error: {err}"
        response = (template_response if SETTINGS['DEBUG'] else simple_response)(HttpBadRequestResponse, body=body)
        self.request.META["DEBUG_MESSAGE"] = err
        await self._response_handler.async_send_response(response, self.sock, request=self.request)
        # Explicit False keeps the declared bool contract for callers.
        return False

    if "permessage-deflate" in exts:
        client_no_context_takeover = "client_no_context_takeover" in exts
        server_no_context_takeover = "server_no_context_takeover" in exts

        # Add permessage-deflate extension
        permessage_deflate = PerMessageDeflate(
            name="permessage-deflate",
            client_no_context_takeover=client_no_context_takeover,
            server_no_context_takeover=server_no_context_takeover,
        )

        # Add extension in list
        self.extensions.append(permessage_deflate)

        # Add negotiated extensions, echoing back the flags the client offered
        negotiated_exts.append(permessage_deflate.name)
        if client_no_context_takeover:
            negotiated_exts.append("client_no_context_takeover")
        if server_no_context_takeover:
            negotiated_exts.append("server_no_context_takeover")

    # Set negotiated extensions in headers
    if negotiated_exts:
        headers["Sec-WebSocket-Extensions"] = "; ".join(negotiated_exts)

    if not error_msg:
        response = HttpSwitchProtocolResponse(upgrade_to="websocket", headers=headers)
    else:
        make_response = template_response if SETTINGS.get("DEBUG") else simple_response
        response = make_response(HttpBadRequestResponse, body=error_msg)

    # Update state and send response
    self.state = State.INITIATING
    await self._response_handler.async_send_response(response, self.sock, request=self.request)

    # Record and return the outcome of the upgrade.
    self.initiated_upgrade = error_msg is None
    if self.initiated_upgrade:
        self.state = State.INITIATED
    return self.initiated_upgrade
[docs]
async def run_forever(self):
    """
    Run the WebSocket connection until it is closed or an error occurs.

    Performs the upgrade handshake, then starts the heartbeat and receive
    loops as tasks and waits for the first of them to finish. An exception
    raised by the finished loop is inspected here: a timeout triggers a close
    frame, connection-level errors are ignored silently, and anything else is
    logged when `SETTINGS['DEBUG']` is enabled.

    On exit the remaining tasks are cancelled and `safe_close` is called to
    release the connection.
    """
    pending = []

    if not await self.initiate_upgrade_to_ws():
        SocketIO.close(self.sock)
        self.state = State.CLOSED
        return

    # Log some message to the terminal
    log_message(self.request, f"{self.request.path} [WebSocket] [OPEN]")

    # asyncio.TimeoutError only became an alias of the builtin TimeoutError in
    # Python 3.11; check both so timeouts are recognised on older versions too.
    timeout_errors = (TimeoutError, asyncio.TimeoutError)

    # Errors that simply mean the connection went away; not worth logging.
    ignore_errors = (
        ConnectionError,
        ConnectionResetError,
        ssl.SSLError,
        ssl.SSLWantReadError,
        ssl.SSLWantWriteError,
        OSError,
        struct.error,
        BrokenPipeError,
        asyncio.CancelledError,
    )

    loop_exc = None  # exception that may have been raised by a receive_loop or heartbeat_loop

    try:
        self._heartbeat_task = create_task(self._heartbeat_loop(), ignore_errors=[BaseException])  # ignore all exceptions.
        self._receiver_task = create_task(self._receive_loop(), ignore_errors=[BaseException])

        done, pending = await asyncio.wait(
            [self._receiver_task, self._heartbeat_task], return_when=asyncio.FIRST_COMPLETED
        )

        for task in done:
            try:
                exc = task.exception()
            except (Exception, asyncio.CancelledError) as e:
                # task.exception() raises CancelledError for a cancelled task;
                # it is a BaseException, so it has to be named explicitly.
                exc = e

            if exc is not None:
                loop_exc = exc

                if isinstance(exc, timeout_errors):
                    # Did not receive data within specific timeout
                    await self.send_close(CloseCode.GOING_AWAY, reason="WebSocket Timeout Error")
                    return

                elif not isinstance(exc, ignore_errors):
                    if "Connection closed while reading exact bytes" in str(exc):
                        # Avoid logging this exception, it happens almost every time
                        # the connection is closed unexpectedly.
                        return

                    if SETTINGS['DEBUG']:
                        # Only log these errors in debug mode.
                        logger.log_raw(f'Error running websocket loop for URL "{self.request.path}" ', level=logger.WARNING)
                        logger.log_exception(exc)

                # Break loop on first exception
                break

    finally:
        # WebSocket Shutdown
        for task in {*self._data_handling_tasks, *(pending or [])}:
            if not task.done():
                try:
                    task.cancel()
                except Exception:
                    pass

        # Clear the data handling tasks
        self._data_handling_tasks.clear()

        if isinstance(loop_exc, timeout_errors):
            await self.safe_close(disable_logging=True)
            log_message(self.request, debug_message=[f"{self.request.path} [WebSocket] [CLOSE]", "WebSocket Timeout"])
        else:
            await self.safe_close()

        self.state = State.CLOSED
[docs]
async def on_new_frame(self, frame: Frame):
    """
    Handle a complete data frame by passing it to `on_receive`.

    If `on_receive` raises, the exception is logged and a close frame with
    `CloseCode.INTERNAL_ERROR` is sent. The close reason is always
    "Internal Server Error"; the exception text is appended only when
    `SETTINGS['DEBUG']` is enabled.

    Args:
        frame (Frame): The frame whose payload and opcode are handed to `on_receive`.
    """
    try:
        await self.on_receive(frame.payload, frame.opcode)
    except Exception as e:
        logger.log_exception(e)

        # Build the reason step by step: a single conditional expression around
        # the concatenation would blank the whole reason outside debug mode.
        reason = "Internal Server Error"
        if SETTINGS['DEBUG']:
            reason += f": {e}"

        await self.send_close(CloseCode.INTERNAL_ERROR, reason=reason)
[docs]
async def on_open(self):
    """
    Hook awaited once before the receive loop starts reading frames.

    The default implementation does nothing; override it in subclasses.
    """
    return None
[docs]
async def on_receive(self, message: bytes, opcode, **kwargs):
    """
    Hook awaited with every complete data message.

    Subclasses must override this method; the default always raises.

    Args:
        message (bytes): Message payload.
        opcode (int): Message opcode.

    Raises:
        NotImplementedError: Always, unless overridden.
    """
    not_implemented_msg = "Implement this method to handle received WebSocket messages."
    raise NotImplementedError(not_implemented_msg)
[docs]
async def on_close(self, frame: Optional[Frame] = None):
    """
    Called when the WebSocket connection is closed.

    Override to implement cleanup logic but make sure to call
    `safe_close(call_on_close_handler=False)` to actually close the connection.

    Args:
        frame (Frame, optional): The close frame sent by the client, or None
            when the close was not triggered by a client close frame.
    """
    # Compare against None explicitly so the reply never depends on how a
    # Frame object evaluates in a boolean context.
    if frame is not None:
        try:
            # Answer the client's close frame before shutting down.
            await self.send_close(CloseCode.NORMAL_CLOSURE, reason="Normal closure")
        except Exception:
            # Best effort: the connection is being torn down regardless.
            pass

    await self.safe_close(call_on_close_handler=False)
[docs]
async def read_frame(self) -> Frame:
    """
    Read and parse one WebSocket frame from the client socket.

    Parsing is delegated to `Frame.parse`, which is given a byte reader, the
    negotiated extensions, `MAX_FRAME_SIZE` as the size limit and
    `mask_required=True`.

    Returns:
        Frame: The parsed frame.

    Raises:
        ProtocolError: If the connection closes before the requested bytes are read.
        PayloadTooBig: Presumably raised by `Frame.parse` when the payload
            exceeds `max_size` - confirm against `Frame.parse`.
    """

    async def read_exact(n: int):
        """
        Read exactly `n` bytes from the client socket.
        """
        timeout = self.RECEIVE_TIMEOUT
        received = bytearray()
        missing = n

        while missing > 0:
            # Give other tasks a chance to run between reads.
            await asyncio.sleep(0)

            chunk = await SocketIO.async_receive(
                sock=self.request.client_socket,
                timeout=timeout,
                bufsize=missing,
            )

            if not chunk:
                # connection closed before reading n bytes
                raise ProtocolError("Connection closed while reading exact bytes")

            received.extend(chunk)
            missing = n - len(received)

        return bytes(received)

    return await Frame.parse(
        read_exact=read_exact,
        mask_required=True,
        max_size=self.MAX_FRAME_SIZE,
        extensions=self.extensions
    )
[docs]
async def _heartbeat_loop(self):
    """
    Periodically send ping frames and verify that a pong came back.

    After each ping the loop waits `PONG_TIMEOUT` seconds and then checks
    whether `_last_pong_time` was updated since the ping was sent. A missed
    pong increases the failure count and delays the next ping with an
    exponential backoff capped at `MAX_BACKOFF`; an answered ping resets the
    count.

    Raises:
        TimeoutError: After three consecutive pings without a pong.
    """
    failures = 0
    base = self.PING_INTERVAL

    while not self._closing:
        # Same clock as `_last_pong_time`, taken before the ping leaves, so any
        # pong recorded from now on counts as the answer to this ping.
        ping_sent_at = time.time()

        try:
            await self.send_ping()
        except Exception:
            # Best effort: a ping that could not be sent shows up below as a missing pong.
            pass

        await asyncio.sleep(self.PONG_TIMEOUT)

        # Compare timestamps instead of measuring "time since last pong": the
        # sleep above always lasts at least PONG_TIMEOUT, so an elapsed-time
        # check would reject pongs that were answered almost immediately.
        if self._last_pong_time < ping_sent_at:
            failures += 1
            if failures >= 3:
                raise TimeoutError("WebSocket pong timeout")

            delay = min(base * (2 ** failures), self.MAX_BACKOFF)
            await asyncio.sleep(delay)
        else:
            failures = 0
            # Wait out the rest of the ping interval (never a negative delay).
            await asyncio.sleep(max(0, base - self.PONG_TIMEOUT))
[docs]
async def _receive_loop(self):
    """
    Read frames from the client until the connection starts closing.

    Awaits `on_open` once, then handles each frame by opcode: close, ping and
    pong frames are processed inline, complete data messages are handed to
    `on_new_frame` in a background task, and fragmented messages are buffered
    on `fragmented_frame` until their final continuation frame arrives.

    Raises:
        ProtocolError: On a continuation frame without an initial fragment or
            on an unexpected opcode.
        Exception: Anything else raised while reading or handling a frame.
    """

    def forget(finished_task):
        # Drop a finished handler task from the bookkeeping set.
        self._data_handling_tasks.discard(finished_task)

    def dispatch(message_frame):
        # Use duck.utils.asyncio.create_task as it automatically raises errors
        # (if any) compared to the default asyncio.create_task
        handler_task = create_task(
            coro=self.on_new_frame(message_frame),
            on_complete=forget,
        )
        # Add data handling task to task list
        self._data_handling_tasks.add(handler_task)

    # Execute on_open event.
    await self.on_open()

    while not self._closing:
        try:
            frame = await self.read_frame()
        except PayloadTooBig:
            too_big_reason = f"Payload too big. Max payload is {self.MAX_FRAME_SIZE} bytes"
            await self.send_close(CloseCode.GOING_AWAY, reason=too_big_reason)
            break

        opcode = frame.opcode

        if opcode == OpCode.CLOSE:
            await self.on_close(frame)

        elif opcode == OpCode.PING:
            await self.send_pong(frame.payload)

        elif opcode == OpCode.PONG:
            self._last_pong_time = time.time()

        elif opcode in (OpCode.TEXT, OpCode.BINARY):
            if frame.fin:
                dispatch(frame)
            else:
                # Start buffering fragmented message
                self.fragmented_frame = frame

        elif opcode == OpCode.CONTINUATION:
            if not self.fragmented_frame:
                raise ProtocolError("Continuation frame without initial fragment")

            # Add more payload to fragmented frame.
            self.fragmented_frame.payload += frame.payload

            if frame.fin:
                self.fragmented_frame.fin = True

                # Hand over a copy, since the buffered frame is reset right after.
                complete_frame = copy.copy(self.fragmented_frame)
                try:
                    dispatch(complete_frame)
                finally:
                    # Reset the fragmented frame.
                    self.fragmented_frame = None

        else:
            raise ProtocolError(f"Unexpected opcode: {frame.opcode}")

        # Yield control to eventloop.
        await asyncio.sleep(0)
[docs]
async def send_frame(self, frame: Frame):
    """
    Serialize a frame and write it to the client socket.

    The frame is serialized unmasked with the extensions negotiated during
    the upgrade. SSL, broken-pipe and connection errors are passed to the
    socket helper as errors to ignore.

    Args:
        frame (Frame): The frame to send.
    """
    wire_bytes = frame.serialize(mask=False, extensions=self.extensions)
    tolerated_errors = [ssl.SSLError, BrokenPipeError, ConnectionResetError, ConnectionError]
    await SocketIO.async_send(
        sock=self.sock,
        data=wire_bytes,
        ignore_error_list=tolerated_errors,
    )
[docs]
async def send(self, data: Union[str, bytes, bytearray, memoryview], opcode: int = OpCode.TEXT):
    """
    Send a WebSocket message, fragmenting large data messages.

    Text is encoded to bytes first. A message whose opcode is in
    `DATA_OPCODES` and whose payload is larger than `MAX_FRAME_SIZE` is split
    into a first frame carrying `opcode` followed by continuation frames, the
    last of which has `fin` set. Every other message goes out as one final frame.

    Args:
        data (Union[str, bytes, bytearray, memoryview]): Payload data.
        opcode (int): WebSocket frame opcode. Defaults to `OpCode.TEXT`.
    """
    if isinstance(data, (bytearray, memoryview)):
        # Other bytes-like payloads are accepted as well; they have no .encode().
        data = bytes(data)
    elif not isinstance(data, bytes):
        data = data.encode()

    max_size = self.MAX_FRAME_SIZE
    data_len = len(data)

    if opcode not in DATA_OPCODES or data_len <= max_size:
        # Fits in a single frame (or is a frame type that is never fragmented here).
        frame = Frame(
            opcode=opcode,
            fin=True,
            payload=data,
            rsv1=False,
            rsv2=False,
            rsv3=False,
        )
        await self.send_frame(frame)
        return

    # Send fragmented frames: data_len > max_size, so there are at least two.
    for offset in range(0, data_len, max_size):
        is_first = offset == 0
        is_last = offset + max_size >= data_len
        frame = Frame(
            opcode=opcode if is_first else OpCode.CONTINUATION,
            fin=is_last,
            payload=data[offset:offset + max_size],
            rsv1=False,
            rsv2=False,
            rsv3=False,
        )
        await self.send_frame(frame)
[docs]
async def send_text(self, data: str):
    """
    Send `data` as a text message.

    Args:
        data (str): Text message to send.
    """
    text_opcode = OpCode.TEXT
    await self.send(data, opcode=text_opcode)
[docs]
async def send_json(self, data: Union[dict, list]):
    """
    Send a Python object as a JSON text message.

    Args:
        data (Union[dict, list]): Python data to serialize with `json.dumps`.
    """
    await self.send_text(json.dumps(data))
[docs]
async def send_binary(self, data: bytes):
    """
    Send `data` as a binary message.

    Args:
        data (bytes): Raw bytes to send.
    """
    binary_opcode = OpCode.BINARY
    await self.send(data, opcode=binary_opcode)
[docs]
async def send_ping(self, data: bytes = b""):
    """
    Send a ping control frame.

    Args:
        data (bytes): Optional ping payload. Defaults to an empty payload.
    """
    ping_opcode = OpCode.PING
    await self.send(data, opcode=ping_opcode)
[docs]
async def send_pong(self, data: bytes = b""):
    """
    Send a pong control frame.

    Args:
        data (bytes): Optional pong payload. Defaults to an empty payload.
    """
    pong_opcode = OpCode.PONG
    await self.send(data, opcode=pong_opcode)
[docs]
async def send_close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = ""):
    """
    Send a close control frame.

    This is failsafe: any exception raised while building or sending the
    frame is logged and not propagated.

    Args:
        code (int): WebSocket close status code. Defaults to `CloseCode.NORMAL_CLOSURE`.
        reason (str): Optional close reason. It is truncated so that the whole
            payload stays within the 125-byte limit of a control frame.
    """
    try:
        # A control frame payload may not exceed 125 bytes (RFC 6455, section 5.5);
        # two of them hold the status code, leaving 123 for the reason.
        reason_bytes = reason.encode()[:123]

        # The cut may have split a multi-byte character: drop the dangling bytes
        # so the reason remains valid UTF-8.
        reason_bytes = reason_bytes.decode("utf-8", errors="ignore").encode()

        payload = struct.pack(">H", code) + reason_bytes
        await self.send(payload, opcode=OpCode.CLOSE)
    except Exception as e:
        logger.log_exception(e)
[docs]
async def safe_close(self, disable_logging: bool = False, call_on_close_handler: bool = True):
    """
    Close the WebSocket connection once, optionally invoking `on_close` first.

    Calls made after the connection has started closing return immediately.

    Args:
        disable_logging (bool): Skip the close log message.
        call_on_close_handler (bool): Whether to await `on_close` before closing the socket.
    """
    if self._closing:
        # Close logic already ran (or is running); nothing left to do.
        return

    should_log = self.initiated_upgrade and not disable_logging
    if should_log:
        log_message(self.request, f"{self.request.path} [WebSocket] [CLOSE]")

    # Flag first, so a nested safe_close() issued by on_close() is a no-op.
    self._closing = True

    if call_on_close_handler:
        try:
            await self.on_close(frame=None)
        except Exception as close_error:
            logger.log_exception(close_error)

    # fail-safe method to close socket.
    SocketIO.close(self.sock)
    self.state = State.CLOSED