diff --git a/matemat/webserver/httpd.py b/matemat/webserver/httpd.py index 20af55c..2f3f3cb 100644 --- a/matemat/webserver/httpd.py +++ b/matemat/webserver/httpd.py @@ -16,8 +16,15 @@ from datetime import datetime, timedelta from matemat import __version__ as matemat_version +# +# Python internal class hacks +# + # Enable IPv6 support (IPv6/IPv4 dual-stack support should be implicitly enabled) TCPServer.address_family = socket.AF_INET6 +# Redirect internal logging to somewhere else, or, for now, silently discard (TODO: logger will come later) +BaseHTTPRequestHandler.log_request = lambda self, code='-', size='-': None +BaseHTTPRequestHandler.log_error = lambda self, fstring='', *args: None # Dictionary to hold registered pagelet paths and their handler functions @@ -159,7 +166,7 @@ class HttpHandler(BaseHTTPRequestHandler): if session_id in self.server.session_vars: del self.server.session_vars[session_id] - def _handle(self, method: str, path: str, args: Dict[str, str]) -> None: + def _handle(self, method: str, path: str, args: Dict[str, Union[str, List[str]]]) -> None: """ Handle a HTTP request by either dispatching it to the appropriate pagelet or by serving a static resource. 
@@ -279,18 +286,18 @@ class HttpHandler(BaseHTTPRequestHandler): self._handle('GET', path, args) # Special handling for some errors except PermissionError as e: - self.send_error(403, 'Forbidden') + self.send_response(403, 'Forbidden') self.end_headers() print(e) traceback.print_tb(e.__traceback__) except ValueError as e: - self.send_header(400, 'Bad Request') + self.send_response(400, 'Bad Request') self.end_headers() print(e) traceback.print_tb(e.__traceback__) except BaseException as e: # Generic error handling - self.send_error(500, 'Internal Server Error') + self.send_response(500, 'Internal Server Error') self.end_headers() print(e) traceback.print_tb(e.__traceback__) @@ -304,26 +311,26 @@ class HttpHandler(BaseHTTPRequestHandler): # Read the POST body, if it exists, and its MIME type is application/x-www-form-urlencoded clen: str = self.headers.get('Content-Length', failobj='0') ctype: str = self.headers.get('Content-Type', failobj='application/octet-stream') - post = '' + post: str = '' if ctype == 'application/x-www-form-urlencoded': - post: str = self.rfile.read(int(clen)).decode('utf-8') + post = self.rfile.read(int(clen)).decode('utf-8') # Parse the request and hand it to the handle function path, args = self._parse_args(self.path, postbody=post) self._handle('POST', path, args) # Special handling for some errors except PermissionError as e: - self.send_error(403, 'Forbidden') + self.send_response(403, 'Forbidden') self.end_headers() print(e) traceback.print_tb(e.__traceback__) except ValueError as e: - self.send_header(400, 'Bad Request') + self.send_response(400, 'Bad Request') self.end_headers() print(e) traceback.print_tb(e.__traceback__) except BaseException as e: # Generic error handling - self.send_error(500, 'Internal Server Error') + self.send_response(500, 'Internal Server Error') self.end_headers() print(e) traceback.print_tb(e.__traceback__) diff --git a/matemat/webserver/test/__init__.py b/matemat/webserver/test/__init__.py new file mode 
100644 index 0000000..e69de29 diff --git a/matemat/webserver/test/abstract_httpd_test.py b/matemat/webserver/test/abstract_httpd_test.py new file mode 100644 index 0000000..806a312 --- /dev/null +++ b/matemat/webserver/test/abstract_httpd_test.py @@ -0,0 +1,163 @@ + +from typing import Any, Dict, Tuple + +import unittest.mock +from io import BytesIO + +from abc import ABC +from datetime import datetime +from http.server import HTTPServer + + +class HttpResponse: + """ + A really basic HTTP response container and parser class, just good enough for unit testing an HTTP server, if that. + + Usage: + response = HttpResponse() + while response.parse_phase != 'done': + response.parse(fragment) + print(response.statuscode) + """ + + def __init__(self) -> None: + # The HTTP status code of the response + self.statuscode: int = 0 + # HTTP headers set in the response + self.headers: Dict[str, str] = { + 'Content-Length': 0 + } + # The response body. Only UTF-8 strings are supported + self.body: str = '' + # Parsing phase, one of 'begin', 'hdr', 'body' or 'done' + self.parse_phase = 'begin' + # Buffer for incomplete lines + self.buffer: bytes = bytes() + + def parse(self, fragment: bytes) -> None: + """ + Parse a new fragment of data. This function does nothing if the parsed HTTP response is already complete. + + :param fragment: The data fragment to parse.
+ """ + # response packet complete, nothing to do + if self.parse_phase == 'done': + return + # If in the body phase, simply decode and append to the body, while the body is not complete yet + elif self.parse_phase == 'body': + self.body += fragment.decode('utf-8') + if len(self.body) >= int(self.headers['Content-Length']): + self.parse_phase = 'done' + return + if b'\r\n' not in fragment: + # If the fragment does not contain a CR-LF, add it to the buffer, we only want to parse whole lines + self.buffer = self.buffer + fragment + else: + if not fragment.endswith(b'\r\n'): + # Special treatment for no trailing CR-LF: Add remainder to buffer + head, tail = fragment.rsplit(b'\r\n', 1) + data: str = (self.buffer + head).decode('utf-8') + self.buffer = tail + else: + data: str = (self.buffer + fragment).decode('utf-8') + self.buffer = bytes() + # Iterate the lines that are ready to be parsed + for line in data.split('\r\n'): + # The 'begin' phase indicates that the parser is waiting for the HTTP status line + if self.parse_phase == 'begin': + if line.startswith('HTTP/'): + # Parse the statuscode and advance to header parsing + _, statuscode, _ = line.split(' ', 2) + self.statuscode = int(statuscode) + self.parse_phase = 'hdr' + elif self.parse_phase == 'hdr': + # Parse a header line and add it to the header dict + if len(line) > 0: + k, v = line.split(':', 1) + self.headers[k.strip()] = v.strip() + else: + # Empty line separates header from body + self.parse_phase = 'body' + elif self.parse_phase == 'body': + # if there is a remainder in the data packet, it is (part of) the body, add to body string + self.body += line + if len(self.body) >= int(self.headers['Content-Length']): + self.parse_phase = 'done' + + +class MockServer: + """ + A mock implementation of http.server.HTTPServer. Only used for matemat-specific storage. 
+ """ + + def __init__(self, webroot: str = '/var/matemat/webroot') -> None: + # Session timeout and variables for all sessions + self.session_vars: Dict[str, Tuple[datetime, Dict[str, Any]]] = dict() + # Webroot for statically served content + self.webroot: str = webroot + + +class MockSocket(bytes): + """ + A mock implementation of a socket.socket for http.server.BaseHTTPRequestHandler. + + The bytes inheritance is due to a broken type annotation in BaseHTTPRequestHandler. + """ + + def __init__(self) -> None: + super().__init__() + # The raw request bytes + self.__request = bytes() + # The parsed response + self.__packet = HttpResponse() + + def set_request(self, request: bytes) -> None: + """ + Sets the HTTP request to send to the server. + + :param request: The request + """ + self.__request: bytes = request + + def makefile(self, mode: str, size: int) -> BytesIO: + """ + Required by http.server.BaseHTTPRequestHandler. + + :return: A dummy buffer IO object instead of a network socket file handle. + """ + return BytesIO(self.__request) + + def sendall(self, b: bytes) -> None: + """ + Required by http.server.BaseHTTPRequestHandler. + + :param b: The data to send to the client. Will be parsed directly instead. + """ + self.__packet.parse(b) + + def get_response(self) -> HttpResponse: + """ + Fetches the parsed HTTP response generated by the server. + + :return: The response object. + """ + return self.__packet + + +class AbstractHttpdTest(ABC, unittest.TestCase): + """ + An abstract test case that can be inherited by test case classes that want to test part of the webserver's core + functionality. + + Usage (subclass test method): + + self.client_sock.set_request(b'GET /just/testing/sessions HTTP/1.1\r\n\r\n') + handler = HttpHandler(self.client_sock, ('::1', 45678), self.server) + packet = self.client_sock.get_response() + + TODO(s3lph): This could probably go here instead.
+ """ + + def setUp(self) -> None: + self.server: HTTPServer = MockServer() + self.client_sock: MockSocket = MockSocket() diff --git a/matemat/webserver/test/test_session.py b/matemat/webserver/test/test_session.py new file mode 100644 index 0000000..af51a75 --- /dev/null +++ b/matemat/webserver/test/test_session.py @@ -0,0 +1,53 @@ + +from typing import Any, Dict + +from datetime import datetime, timedelta + +from matemat.webserver.httpd import HttpHandler, pagelet +from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest + + +@pagelet('/just/testing/sessions') +def test_pagelet(method: str, path: str, args: Dict[str, str], session_vars: Dict[str, Any], headers: Dict[str, str]): + session_vars['test'] = 'hello, world!' + headers['Content-Type'] = 'text/plain' + return 200, 'session test' + + +class TestSession(AbstractHttpdTest): + """ + Test session handling of the Matemat webserver. + """ + + def test_create_new_session(self): + # Reference date to make sure the session expiry lies in the future + refdate = datetime.utcnow() + timedelta(seconds=3500) + # Send a mock GET request for '/just/testing/sessions' + self.client_sock.set_request(b'GET /just/testing/sessions HTTP/1.1\r\n\r\n') + # Trigger request handling + handler = HttpHandler(self.client_sock, ('::1', 45678), self.server) + # Fetch the parsed response + packet = self.client_sock.get_response() + # Make sure a full HTTP response was parsed + self.assertEqual('done', packet.parse_phase) + # Make sure the request was served by the test pagelet + self.assertEqual('session test', packet.body) + self.assertEqual(200, packet.statuscode) + + session_id: str = list(handler.server.session_vars.keys())[0] + # Make sure a cookie was set - assuming that only one was set + self.assertIn('Set-Cookie', packet.headers) + # Split into the cookie itself and its expiry attribute + cookie, expiry = packet.headers['Set-Cookie'].split(';') + cookie: str = cookie.strip() + expiry: str = expiry.strip() + # Make sure the
'matemat_session_id' cookie was set to the session ID string + self.assertEqual(f'matemat_session_id={session_id}', cookie) + # Make sure the session expires in about one hour + self.assertTrue(expiry.startswith('expires=')) + _, expdatestr = expiry.split('=', 1) + expdate = datetime.strptime(expdatestr, '%a, %d %b %Y %H:%M:%S GMT') + self.assertTrue(expdate > refdate) + # Make sure the session exists on the server + self.assertIn('test', handler.session_vars) + self.assertEqual('hello, world!', handler.session_vars['test'])