Source code for duck.http.core.proxyhandler

"""
Module containing proxy handler class.
"""
import ssl
import asyncio
import socket

from http.cookies import SimpleCookie
from collections import defaultdict
from typing import (
    Optional,
    Tuple,
    Generator,
    AsyncGenerator,
    Union
)
from duck.settings import SETTINGS
from duck.http.content import Content
from duck.http.headers import Headers
from duck.http.request import HttpRequest
from duck.http.response import StreamingHttpResponse
from duck.http.response_payload import (
    BaseResponsePayload,
    SimpleHttpResponsePayload,
)
from duck.meta import Meta
from duck.utils.net import is_domain
from duck.utils.importer import x_import
from duck.utils.headers import parse_headers_from_bytes
from duck.utils.importer import import_module_once
from duck.utils.xsocket import xsocket, create_xsocket
from duck.utils.xsocket.io import SocketIO


# Buffer size to use when receiving payload from the remote server.
# Ensure that the buffer is large enough to hold most typical payloads
# but not too large to cause inefficient memory usage.
PAYLOAD_BUFFER_SIZE = SETTINGS["SERVER_BUFFER"]

# Timeout to establish a connection with the remote proxy server (e.g., Django).
# A value of 1 second is typically used for fast responses. 
# Consider increasing if connection time to the server is high.
CONNECT_TIMEOUT = SETTINGS["PROXY_CONNECT_TIMEOUT"]  # Timeout in seconds for establishing the connection

# Timeout to wait for data from the remote proxy server.
# This value should balance between waiting for data and not blocking indefinitely.
# Increase if network latency or server load is high.
READ_TIMEOUT = SETTINGS["PROXY_READ_TIMEOUT"]  # Timeout in seconds for reading data

# The amount of data to stream at once from the remote proxy server.
# 4096 bytes (4 KB) is a reasonable default chunk size for streaming,
# but can be adjusted based on specific requirements.
STREAM_CHUNK_SIZE = SETTINGS["PROXY_STREAM_CHUNK_SIZE"]  # Streaming chunk size in bytes

# Import middleware for fixing request before reaching Django server.
DJANGO_REQUEST_FIXER_MIDDLEWARE = x_import("duck.http.middlewares.contrib.DjangoRequestFixerMiddleware")


[docs] class ReverseProxyError(Exception): """ Raised on exceptions encountered on reverse-proxy ops. """
[docs] class BadGatewayError(ReverseProxyError): """ Custom exception for bad gateway errors (HTTP 502). This exception is used to signal that the request could not be processed because of a server-side issue, often related to a gateway or proxy. """
[docs] class HttpProxyResponse(StreamingHttpResponse): """ A synchronous subclass of StreamingHttpResponse that handles HTTP responses in a proxy scenario, including cases where the response content is incomplete or partially received. This class is specifically designed to handle situations where: - Only the HTTP response headers are available, and content is streamed progressively. - The response content is incomplete, and the response is processed as it is received. It is ideal for proxy servers or intermediary systems that need to pass along HTTP responses from a remote server to the client while processing and streaming the data in chunks, minimizing memory usage by not requiring the full response to be loaded at once. Key Features: - **Streaming**: Supports efficient, memory-friendly streaming of content from the proxy server to the client. - **Partial Responses**: Handles scenarios where the response is incomplete or streaming, such as large files or slow connections. - **Real-Time Processing**: Allows the server to process parts of the response as they arrive without blocking or waiting for the full content. This class can be extended for custom logic related to response manipulation, such as modifying headers, handling chunk sizes, or managing specific proxy behaviors. Example Usage: ```py response = HttpProxyResponse( target_socket, payload_obj, content_obj, chunk_size, ) for content in response.iter_content(): # Content logic here pass ``` Attributes: target_socket: The socket already connected used for communication with the proxy server. payload_obj: The response payload associated with the HTTP response. content_obj: Content object with the initial or incomplete content. chunk_size: The streaming chunk size. """ def __init__( self, target_socket: xsocket, payload_obj: BaseResponsePayload, content_obj: Optional[Content] = None, chunk_size: int = STREAM_CHUNK_SIZE, ): """ Initialize the HttpProxyResponse object. Args: target_socket (xsocket): The live xsocket connection used to receive content. payload_obj (BaseResponsePayload): The response payload object containing headers and metadata. content_obj (Optional[Content]): The initial partial content object, if available. chunk_size (int): The size (in bytes) of each chunk to stream. Defaults to STREAM_CHUNK_SIZE. """ self.target_socket = target_socket self.payload_obj = payload_obj self.content_obj = content_obj self.chunk_size = chunk_size self.more_data = bytearray(b'') # List Buffer for additional data received during streaming # No need to call super() or set stream because we are directly modifying iter methods like iter_content/async_iter_content # rather than using the default ones wihich requires a stream
[docs] def iter_content(self) -> Generator[bytes, None, None]: """ A generator to stream the current content followed by additional data as it is received. Yields: bytes: The next chunk of data to stream. """ if self.content_obj and self.content_obj.data: yield self.content_obj.data # Yield current data. while True: data = self.recv_more(self.chunk_size) if data: yield data else: break
[docs] def recv_more(self, buffer: int = 4096) -> bytes: """ Receive additional data from the target socket. Args: buffer (int): The buffer size for receiving data. Defaults to 4096 bytes. Returns: bytes: The received data, or an empty byte string if the content is complete. """ # Set blocking to True just in case. self.target_socket.setblocking(True) try: # Calculate the current content length and the expected total length current_content_length = len(self.more_data) + (len(self.content_obj.data) if self.content_obj else 0) expected_content_length = self.payload_obj.get_header("content-length") # If the content length is known and fully received, stop receiving if expected_content_length and expected_content_length.isdigit(): if current_content_length >= int(expected_content_length): return b"" # Receive data from the socket data = SocketIO.receive(self.target_socket, bufsize=buffer, timeout=READ_TIMEOUT) self.more_data.extend(data) return data except (TimeoutError, ConnectionResetError, EOFError): # Handle errors silently pass except (OSError, socket.error) as e: raise BadGatewayError(f"Receiving additional data failed: {e}") return b""
# ASYNCHRONOUS IMPLEMENTATIONS
[docs] async def async_iter_content(self) -> AsyncGenerator[bytes, None]: """ An asynchronous generator to stream the current content followed by additional data as it is received. Yields: bytes: The next chunk of data to stream. """ # Set blocking to False just in case. self.target_socket.setblocking(False) if self.content_obj and self.content_obj.data: yield self.content_obj.data # Yield current data. while True: data = await self.async_recv_more(self.chunk_size) if data: yield data else: break
[docs] async def async_recv_more(self, buffer: int = 4096) -> bytes: """ Asynchronously receive additional data from the target socket. Args: buffer (int): The buffer size for receiving data. Defaults to 4096 bytes. Returns: bytes: The received data, or an empty byte string if the content is complete. """ try: # Calculate the current content length and the expected total length current_content_length = len(self.more_data) + (len(self.content_obj.data) if self.content_obj else 0) expected_content_length = self.payload_obj.get_header("content-length") # If the content length is known and fully received, stop receiving if expected_content_length and expected_content_length.isdigit(): if current_content_length >= int(expected_content_length): return b"" # Receive data from the socket data = await SocketIO.async_receive(self.target_socket, bufsize=buffer, timeout=READ_TIMEOUT) self.more_data.extend(data) return data except (TimeoutError, ConnectionResetError, EOFError): # Handle errors silently pass except (OSError, socket.error) as e: raise BadGatewayError(f"Receiving additional data failed: {e}") return b""
[docs] def __repr__(self): return f"<{self.__class__.__name__} (" f"'{self.status_code}'" f")>"
[docs] class HttpProxyHandler: """ HttpProxyHandler class to handle forwarding TCP requests to target hosts. This class supports both IPv4 and IPv6 and allows modification of headers before forwarding the data to the client. """ __slots__ = ("uses_ipv6", "target_host", "target_port", "uses_ssl") def __init__( self, target_host: str, target_port: int, uses_ipv6: bool = False, uses_ssl: bool = False, ): """ Initializes the ProxyHandler. Args: uses_ipv6 (bool): Boolean indicating whether to use IPv6. Defaults to False. """ assert isinstance(target_host, str), ( "Target host should be a string representing hostname or host ip address." ) assert isinstance(target_port, int), ( "Target port should be an integer for host port." ) self.uses_ipv6 = uses_ipv6 self.target_host = socket.gethostbyname(target_host) if is_domain(target_host) else target_host self.target_host = "127.0.0.1" if self.target_host.startswith("0.") else self.target_host # Windows compatibility self.target_port = target_port self.uses_ssl = uses_ssl if self.uses_ssl: raise ReverseProxyError("Reverse proxying to SSL servers not supported yet.")
[docs] def get_response( self, request: HttpRequest, client_socket: socket.socket, ) -> HttpProxyResponse: """ Handles the client connection and forwards the request to the target server and returns partial response. To receive more response, use method HttpPartialResponse.recvmore, this method may raise Args: request (HttpRequest): Client Http request object. client_socket (socket.socket): The socket object connected to the client. Returns: HttpProxyResponse: The partial unfinished response ready for processing but lacking full content. """ # Connect to the target server target_socket = self.connect_to_target() try: # Forward the client's request to the target server self.forward_request_to_target(request, client_socket, target_socket) # Receive partial response payload, partial_content = self.fetch_response_payload(target_socket) content_length_header = payload.get_header("content-length") chunk_size = None try: content_length_header = int(content_length_header) chunk_size = min(content_length_header, STREAM_CHUNK_SIZE) except (ValueError, TypeError): pass streaming_response = HttpProxyResponse( target_socket, payload_obj=payload, content_obj=Content(partial_content), chunk_size=chunk_size or STREAM_CHUNK_SIZE, ) return streaming_response except (OSError, ConnectionRefusedError, socket.error) as e: raise BadGatewayError(f"Connection was successful, but subsequent actions failed: {e}.")
[docs] def connect_to_target(self) -> xsocket: """ Connects to the target server. Returns: xsocket: The xsocket object connected to the target server. """ target_socket = create_xsocket( socket.AF_INET6 if self.uses_ipv6 else socket.AF_INET, socket.SOCK_STREAM, ) try: SocketIO.connect(target_socket, (self.target_host, self.target_port), timeout=CONNECT_TIMEOUT) return target_socket except (OSError, ConnectionRefusedError, socket.error) as e: raise BadGatewayError(f"Connection to remote server failed: {e}")
[docs] def forward_request_to_target( self, request: HttpRequest, client_socket: xsocket, target_socket: xsocket, ): """ Forwards the client's request to the target server. Args: request (HttpRequest): Client Http request object to forward to target server. client_socket (xsocket): The socket object connected to the client. target_socket (xsocket): The socket object connected to the target server. """ # Modify headers if needed self.modify_client_request(request) SocketIO.send(target_socket, request.raw)
[docs] def fetch_response_payload( self, target_socket: xsocket, ) -> Tuple[SimpleHttpResponsePayload, bytes]: """ Returns received Response Payload and leftover data from target server. Returns: SimpleHttpResponsePayload, bytes: This is SimpleHttpResponsePayload object and leftover_data (partial data sent from target server after headers). """ data = b"" leftover_data = b"" while True: part = SocketIO.receive(target_socket, bufsize=PAYLOAD_BUFFER_SIZE, timeout=READ_TIMEOUT) data = data.join([b"", part]) if b"\r\n\r\n" in data or not data: # content separator break try: topheader, data = data.split(b"\r\n", 1) data = data.split(b"\r\n\r\n", 1) except ValueError: raise BadGatewayError("Response payload is malformed.") if len(data) > 1: data, _leftover = data leftover_data = leftover_data.join([b"", _leftover]) else: data = data[0] # Parse headers from bytes - assuming it returns a dictionary with lists of values for each header headers: Dict[str, List[str]] = parse_headers_from_bytes(data) # Prepare containers for cleaned headers and set-cookie headers cleaned_headers = {} # Will hold only the latest value for each header (no duplicates) set_cookies = [] # List to hold all Set-Cookie headers # Iterate through each header to process them for header, values in headers.items(): # Special handling for 'Set-Cookie' headers as they can appear multiple times if header.lower() == "set-cookie": set_cookies.extend(values) # Add all 'Set-Cookie' values to the set_cookies list else: # For all other headers, keep only the last value (no duplicates) cleaned_headers[header] = values[-1] # At this point, `cleaned_headers` will contain headers with a single value each, # and `set_cookies` will contain all Set-Cookie values. # set topheader, headers in response payload payload = SimpleHttpResponsePayload(topheader.decode("utf-8"), cleaned_headers) for set_cookie in set_cookies: try: payload.cookies.load(set_cookie) except (CookieError, TypeError): pass return payload, leftover_data
[docs] def modify_client_request(self, request: HttpRequest): """ Modifies the Http request if needed. Args: request (HttpRequest): Client Http request object. """ # Implement your header modification logic here if SETTINGS["USE_DJANGO"]: DJANGO_REQUEST_FIXER_MIDDLEWARE.process_request(request)
# ASYNCHRONOUS IMPLEMENTATIONS
[docs] class AsyncHttpProxyHandler(HttpProxyHandler): """ An asynchronous HttpProxyHandler class to handle forwarding TCP requests to target hosts. This class supports both IPv4 and IPv6 and allows modification of headers before forwarding the data to the client. """
[docs] async def get_response( self, request: HttpRequest, client_socket: socket.socket, ) -> HttpProxyResponse: """ Asynchronously handles the client connection and forwards the request to the target server and returns partial response. To receive more response, use method HttpPartialResponse.recvmore, this method may raise Args: request (HttpRequest): Client Http request object. client_socket (socket.socket): The socket object connected to the client. Returns: HttpProxyResponse: The partial unfinished response ready for processing but lacking full content. """ # Connect to the target server target_socket = await self.connect_to_target() try: # Forward the client's request to the target server await self.forward_request_to_target(request, client_socket, target_socket) # Receive partial response payload, partial_content = await self.fetch_response_payload(target_socket) content_length_header = payload.get_header("content-length") chunk_size = None try: content_length_header = int(content_length_header) chunk_size = min(content_length_header, STREAM_CHUNK_SIZE) except (ValueError, TypeError): pass streaming_response = HttpProxyResponse( target_socket, payload_obj=payload, content_obj=Content(partial_content), chunk_size=chunk_size or STREAM_CHUNK_SIZE, ) return streaming_response except (OSError, ConnectionRefusedError, socket.error) as e: raise BadGatewayError(f"Connection was successful, but subsequent actions failed: {e}.")
[docs] async def connect_to_target(self) -> xsocket: """ Asynchronously connects to the target server. Returns: xsocket: The xsocket object connected to the target server. """ target_socket = create_xsocket( socket.AF_INET6 if self.uses_ipv6 else socket.AF_INET, socket.SOCK_STREAM, ) # Set socket blocking to False target_socket.setblocking(False) try: await SocketIO.async_connect(target_socket, (self.target_host, self.target_port), timeout=CONNECT_TIMEOUT) return target_socket except (OSError, ConnectionRefusedError, socket.error) as e: raise BadGatewayError(f"Connection to remote server failed: {e}")
[docs] async def forward_request_to_target( self, request: HttpRequest, client_socket: xsocket, target_socket: xsocket, ): """ Asynchronously forwards the client's request to the target server. Args: request (HttpRequest): Client Http request object to forward to target server. client_socket (xsocket): The socket object connected to the client. target_socket (xsocket): The socket object connected to the target server. """ # Modify headers if needed self.modify_client_request(request) raw_request = request.raw.join([b"", b"\r\n\r\n"]) await SocketIO.async_send(target_socket, raw_request)
[docs] async def fetch_response_payload( self, target_socket: xsocket, ) -> Tuple[SimpleHttpResponsePayload, bytes]: """ Returns received Response Payload and leftover data from target server. Returns: SimpleHttpResponsePayload, bytes: This is SimpleHttpResponsePayload object and leftover_data (partial data sent from target server after headers). """ data = b"" leftover_data = b"" while True: part = await SocketIO.async_receive(target_socket, bufsize=PAYLOAD_BUFFER_SIZE, timeout=READ_TIMEOUT) data = data.join([b"", part]) if b"\r\n\r\n" in data or not data: # content separator break try: topheader, data = data.split(b"\r\n", 1) data = data.split(b"\r\n\r\n", 1) except ValueError: raise BadGatewayError("Response payload is malformed.") if len(data) > 1: data, _leftover = data leftover_data = leftover_data.join([b"", _leftover]) else: data = data[0] # Parse headers from bytes - assuming it returns a dictionary with lists of values for each header headers: Dict[str, List[str]] = parse_headers_from_bytes(data) # Prepare containers for cleaned headers and set-cookie headers cleaned_headers = {} # Will hold only the latest value for each header (no duplicates) set_cookies = [] # List to hold all Set-Cookie headers # Iterate through each header to process them for header, values in headers.items(): # Special handling for 'Set-Cookie' headers as they can appear multiple times if header.lower() == "set-cookie": set_cookies.extend(values) # Add all 'Set-Cookie' values to the set_cookies list else: # For all other headers, keep only the last value (no duplicates) cleaned_headers[header] = values[-1] # At this point, `cleaned_headers` will contain headers with a single value each, # and `set_cookies` will contain all Set-Cookie values. # set topheader, headers in response payload payload = SimpleHttpResponsePayload(topheader.decode("utf-8"), cleaned_headers) for set_cookie in set_cookies: try: payload.cookies.load(set_cookie) except (CookieError, TypeError): pass return payload, leftover_data