From f702eccc57e38789687175cc118870a8611b999e Mon Sep 17 00:00:00 2001 From: s3lph Date: Wed, 27 Jun 2018 21:17:18 +0200 Subject: [PATCH 1/9] First implementation of multipart/form-data parsing --- matemat/webserver/httpd.py | 90 ++++++------- matemat/webserver/pagelets/__init__.py | 1 + matemat/webserver/pagelets/login.py | 19 ++- matemat/webserver/pagelets/logout.py | 9 +- matemat/webserver/pagelets/main.py | 8 +- matemat/webserver/pagelets/touchkey.py | 18 ++- matemat/webserver/pagelets/upload_test.py | 28 ++++ matemat/webserver/test/abstract_httpd_test.py | 28 ++-- matemat/webserver/test/test_post.py | 123 ++++++++++++++---- matemat/webserver/test/test_serve.py | 19 +-- matemat/webserver/test/test_session.py | 4 +- matemat/webserver/util.py | 118 +++++++++++++++++ 12 files changed, 357 insertions(+), 108 deletions(-) create mode 100644 matemat/webserver/pagelets/upload_test.py create mode 100644 matemat/webserver/util.py diff --git a/matemat/webserver/httpd.py b/matemat/webserver/httpd.py index 220849c..a4e9cca 100644 --- a/matemat/webserver/httpd.py +++ b/matemat/webserver/httpd.py @@ -1,12 +1,11 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import traceback import os import socket import mimetypes -import urllib.parse from socketserver import TCPServer from http.server import HTTPServer, BaseHTTPRequestHandler from http.cookies import SimpleCookie @@ -14,6 +13,7 @@ from uuid import uuid4 from datetime import datetime, timedelta from matemat import __version__ as matemat_version +from matemat.webserver.util import parse_args # @@ -28,12 +28,17 @@ BaseHTTPRequestHandler.log_error = lambda self, fstring='', *args: None # Dictionary to hold registered pagelet paths and their handler functions -_PAGELET_PATHS: Dict[str, Callable[[str, str, Dict[str, str], Dict[str, Any], Dict[str, str], bytes], - Tuple[int, Union[bytes, str]]]] = dict() +_PAGELET_PATHS: Dict[str, Callable[[str, # HTTP method (GET, POST, ...) + str, # Request path + Dict[str, Tuple[str, Union[bytes, str, List[str]]]], # args: (name, (type, value)) + Dict[str, Any], # Session vars + Dict[str, str]], # Response headers + Tuple[int, Union[bytes, str]]]] = dict() # Returns: (status code, response body) # Inactivity timeout for client sessions _SESSION_TIMEOUT: int = 3600 +_MAX_POST: int = 1_000_000 def pagelet(path: str): @@ -43,12 +48,17 @@ def pagelet(path: str): The function must have the following signature: - (method: str, path: str, args: Dict[str, Union[str, List[str]], session_vars: Dict[str, Any], - headers: Dict[str, str]) -> (int, Optional[Union[str, bytes]]) + (method: str, + path: str, + args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + session_vars: Dict[str, Any], + headers: Dict[str, str]) + -> (int, Optional[Union[str, bytes]]) method: The HTTP method (GET, POST) that was used. path: The path that was requested. - args: The arguments that were passed with the request (as GET or POST arguments). + args: The arguments that were passed with the request (as GET or POST arguments), each of which may be + either a str or bytes object, or a list of str. session_vars: The session storage. May be read from and written to. headers: The dictionary of HTTP response headers. Add headers you wish to send with the response. returns: A tuple consisting of the HTTP status code (as an int) and the response body (as str or bytes, @@ -56,7 +66,12 @@ def pagelet(path: str): :param path: The path to register the function for. """ - def http_handler(fun: Callable[[str, str, Dict[str, str], Dict[str, Any], Dict[str, str], bytes], + + def http_handler(fun: Callable[[str, + str, + Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + Dict[str, Any], + Dict[str, str]], Tuple[int, Union[bytes, str]]]): # Add the function to the dict of pagelets _PAGELET_PATHS[path] = fun @@ -166,7 +181,7 @@ class HttpHandler(BaseHTTPRequestHandler): if session_id in self.server.session_vars: del self.server.session_vars[session_id] - def _handle(self, method: str, path: str, args: Dict[str, Union[str, List[str]]]) -> None: + def _handle(self, method: str, path: str, args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]]) -> None: """ Handle a HTTP request by either dispatching it to the appropriate pagelet or by serving a static resource. @@ -238,7 +253,7 @@ class HttpHandler(BaseHTTPRequestHandler): mimetype = 'application/octet-stream' # Send content type and length header self.send_header('Content-Type', mimetype) - self.send_header('Content-Length', len(data)) + self.send_header('Content-Length', str(len(data))) self.end_headers() # Send the requested resource as response body self.wfile.write(data) @@ -247,36 +262,6 @@ class HttpHandler(BaseHTTPRequestHandler): self.send_response(404) self.end_headers() - @staticmethod - def _parse_args(request: str, postbody: Optional[str] = None) -> Tuple[str, Dict[str, Union[str, List[str]]]]: - """ - Given a HTTP request path, and optionally a HTTP POST body in application/x-www-form-urlencoded form, parse the - arguments and return them as a dictionary. - - If a key is used both in GET and in POST, the POST value takes precedence, and the GET value is discarded. - - :param request: The request string to parse. - :param postbody: The POST body to parse, defaults to None. - :return: A tuple consisting of the base path and a dictionary with the parsed key/value pairs. - """ - # Parse the request "URL" (i.e. only the path) - tokens = urllib.parse.urlparse(request) - # Parse the GET arguments - args = urllib.parse.parse_qs(tokens.query) - - if postbody is not None: - # Parse the POST body - postargs = urllib.parse.parse_qs(postbody) - # Write all POST values into the dict, overriding potential duplicates from GET - for k, v in postargs.items(): - args[k] = v - # urllib.parse.parse_qs turns ALL arguments into arrays. This turns arrays of length 1 into scalar values - for k, v in args.items(): - if len(v) == 1: - args[k] = v[0] - # Return the path and the parsed arguments - return tokens.path, args - # noinspection PyPep8Naming def do_GET(self) -> None: """ @@ -284,7 +269,7 @@ class HttpHandler(BaseHTTPRequestHandler): """ try: # Parse the request and hand it to the handle function - path, args = self._parse_args(self.path) + path, args = parse_args(self.path) self._handle('GET', path, args) # Special handling for some errors except PermissionError: @@ -305,25 +290,24 @@ class HttpHandler(BaseHTTPRequestHandler): """ try: # Read the POST body, if it exists, and its MIME type is application/x-www-form-urlencoded - clen: str = self.headers.get('Content-Length', failobj='0') + clen: int = int(str(self.headers.get('Content-Length', failobj='0'))) + if clen > _MAX_POST: + raise ValueError('Request too big') ctype: str = self.headers.get('Content-Type', failobj='application/octet-stream') - post: str = '' - if ctype == 'application/x-www-form-urlencoded': - post = self.rfile.read(int(clen)).decode('utf-8') + post: bytes = self.rfile.read(clen) + path, args = parse_args(self.path, postbody=post, enctype=ctype) # Parse the request and hand it to the handle function - path, args = self._parse_args(self.path, postbody=post) self._handle('POST', path, args) - # Special handling for some errors - except PermissionError as e: + # Special handling for some errors + except PermissionError: self.send_response(403, 'Forbidden') self.end_headers() - print(e) - traceback.print_tb(e.__traceback__) - except ValueError as e: + except ValueError: + self.send_response(400, 'Bad Request') + self.end_headers() + except TypeError: self.send_response(400, 'Bad Request') self.end_headers() - print(e) - traceback.print_tb(e.__traceback__) except BaseException as e: # Generic error handling self.send_response(500, 'Internal Server Error') diff --git a/matemat/webserver/pagelets/__init__.py b/matemat/webserver/pagelets/__init__.py index 9b926d6..71ded5e 100644 --- a/matemat/webserver/pagelets/__init__.py +++ b/matemat/webserver/pagelets/__init__.py @@ -8,3 +8,4 @@ from .main import main_page from .login import login_page from .logout import logout from .touchkey import touchkey_page +from .upload_test import upload_test diff --git a/matemat/webserver/pagelets/login.py b/matemat/webserver/pagelets/login.py index 876fd71..8fbe831 100644 --- a/matemat/webserver/pagelets/login.py +++ b/matemat/webserver/pagelets/login.py @@ -1,5 +1,5 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple, Union from matemat.exceptions import AuthenticationError from matemat.webserver import pagelet @@ -8,7 +8,12 @@ from matemat.db import MatematDatabase @pagelet('/login') -def login_page(method: str, path: str, args: Dict[str, str], session_vars: Dict[str, Any], headers: Dict[str, str]): +def login_page(method: str, + path: str, + args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + session_vars: Dict[str, Any], + headers: Dict[str, str])\ + -> Tuple[int, Optional[Union[str, bytes]]]: if 'user' in session_vars: headers['Location'] = '/' return 301, None @@ -38,13 +43,19 @@ def login_page(method: str, path: str, args: Dict[str, str], session_vars: Dict[ ''' return 200, data.format(msg=args['msg'] if 'msg' in args else '') elif method == 'POST': - print(args) + if 'username' not in args or not isinstance(args['username'], str): + return 400, None + if 'password' not in args or not isinstance(args['password'], str): + return 400, None + username: str = args['username'] + password: str = args['password'] with MatematDatabase('test.db') as db: try: - user: User = db.login(args['username'], args['password']) + user: User = db.login(username, password) except AuthenticationError: headers['Location'] = '/login?msg=Username%20or%20password%20wrong.%20Please%20try%20again.' return 301, bytes() session_vars['user'] = user headers['Location'] = '/' return 301, bytes() + return 405, None diff --git a/matemat/webserver/pagelets/logout.py b/matemat/webserver/pagelets/logout.py index 86095b0..53a292a 100644 --- a/matemat/webserver/pagelets/logout.py +++ b/matemat/webserver/pagelets/logout.py @@ -1,11 +1,16 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple, Union from matemat.webserver import pagelet @pagelet('/logout') -def logout(method: str, path: str, args: Dict[str, str], session_vars: Dict[str, Any], headers: Dict[str, str]): +def logout(method: str, + path: str, + args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + session_vars: Dict[str, Any], + headers: Dict[str, str])\ + -> Tuple[int, Optional[Union[str, bytes]]]: if 'user' in session_vars: del session_vars['user'] headers['Location'] = '/' diff --git a/matemat/webserver/pagelets/main.py b/matemat/webserver/pagelets/main.py index 2ead15d..d2dd208 100644 --- a/matemat/webserver/pagelets/main.py +++ b/matemat/webserver/pagelets/main.py @@ -1,5 +1,5 @@ -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from matemat.webserver import MatematWebserver, pagelet from matemat.primitives import User @@ -7,7 +7,11 @@ from matemat.db import MatematDatabase @pagelet('/') -def main_page(method: str, path: str, args: Dict[str, str], session_vars: Dict[str, Any], headers: Dict[str, str])\ +def main_page(method: str, + path: str, + args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + session_vars: Dict[str, Any], + headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: data = ''' diff --git a/matemat/webserver/pagelets/touchkey.py b/matemat/webserver/pagelets/touchkey.py index fd99fea..2a8202d 100644 --- a/matemat/webserver/pagelets/touchkey.py +++ b/matemat/webserver/pagelets/touchkey.py @@ -1,5 +1,5 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple, Union from matemat.exceptions import AuthenticationError from matemat.webserver import pagelet @@ -8,7 +8,12 @@ from matemat.db import MatematDatabase @pagelet('/touchkey') -def touchkey_page(method: str, path: str, args: Dict[str, str], session_vars: Dict[str, Any], headers: Dict[str, str]): +def touchkey_page(method: str, + path: str, + args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + session_vars: Dict[str, Any], + headers: Dict[str, str])\ + -> Tuple[int, Optional[Union[str, bytes]]]: if 'user' in session_vars: headers['Location'] = '/' return 301, bytes() @@ -37,12 +42,19 @@ def touchkey_page(method: str, path: str, args: Dict[str, str], session_vars: Di ''' return 200, data.format(username=args['username'] if 'username' in args else '') elif method == 'POST': + if 'username' not in args or not isinstance(args['username'], str): + return 400, None + if 'touchkey' not in args or not isinstance(args['touchkey'], str): + return 400, None + username: str = args['username'] + touchkey: str = args['touchkey'] with MatematDatabase('test.db') as db: try: - user: User = db.login(args['username'], touchkey=args['touchkey']) + user: User = db.login(username, touchkey=touchkey) except AuthenticationError: headers['Location'] = f'/touchkey?username={args["username"]}&msg=Please%20try%20again.' return 301, bytes() session_vars['user'] = user headers['Location'] = '/' return 301, None + return 405, None diff --git a/matemat/webserver/pagelets/upload_test.py b/matemat/webserver/pagelets/upload_test.py new file mode 100644 index 0000000..a6f1e85 --- /dev/null +++ b/matemat/webserver/pagelets/upload_test.py @@ -0,0 +1,28 @@ + +from typing import Any, Dict, Union + +from matemat.webserver import pagelet + + +@pagelet('/upload') +def upload_test(method: str, + path: str, + args: Dict[str, Union[str, bytes]], + session_vars: Dict[str, Any], + headers: Dict[str, str]): + if method == 'GET': + return 200, ''' + + + +
+ + + +
+ + + ''' + else: + headers['Content-Type'] = 'text/plain' + return 200, args.items().__str__() diff --git a/matemat/webserver/test/abstract_httpd_test.py b/matemat/webserver/test/abstract_httpd_test.py index de0daf6..b96767e 100644 --- a/matemat/webserver/test/abstract_httpd_test.py +++ b/matemat/webserver/test/abstract_httpd_test.py @@ -1,5 +1,5 @@ -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import unittest.mock from io import BytesIO @@ -31,8 +31,8 @@ class HttpResponse: 'Content-Length': 0 } self.pagelet: str = None - # The response body. Only UTF-8 strings are supported - self.body: str = '' + # The response body + self.body: bytes = bytes() # Parsing phase, one of 'begin', 'hdr', 'body' or 'done' self.parse_phase = 'begin' # Buffer for uncompleted lines @@ -55,7 +55,7 @@ class HttpResponse: return # If in the body phase, simply decode and append to the body, while the body is not complete yet elif self.parse_phase == 'body': - self.body += fragment.decode('utf-8') + self.body += fragment if len(self.body) >= int(self.headers['Content-Length']): self.__finalize() return @@ -66,24 +66,24 @@ class HttpResponse: if not fragment.endswith(b'\r\n'): # Special treatment for no trailing CR-LF: Add remainder to buffer head, tail = fragment.rsplit(b'\r\n', 1) - data: str = (self.buffer + head).decode('utf-8') + data: bytes = (self.buffer + head) self.buffer = tail else: - data: str = (self.buffer + fragment).decode('utf-8') + data: bytes = (self.buffer + fragment) self.buffer = bytes() # Iterate the lines that are ready to be parsed - for line in data.split('\r\n'): + for line in data.split(b'\r\n'): # The 'begin' phase indicates that the parser is waiting for the HTTP status line if self.parse_phase == 'begin': - if line.startswith('HTTP/'): + if line.startswith(b'HTTP/'): # Parse the statuscode and advance to header parsing - _, statuscode, _ = line.split(' ', 2) + _, statuscode, _ = line.decode('utf-8').split(' ', 2) self.statuscode = int(statuscode) self.parse_phase = 'hdr' elif self.parse_phase == 'hdr': # Parse a header line and add it to the header dict if len(line) > 0: - k, v = line.split(':', 1) + k, v = line.decode('utf-8').split(':', 1) self.headers[k.strip()] = v.strip() else: # Empty line separates header from body @@ -156,12 +156,16 @@ class MockSocket(bytes): def test_pagelet(path: str): - def with_testing_headers(fun: Callable[[str, str, Dict[str, str], Dict[str, Any], Dict[str, str]], + def with_testing_headers(fun: Callable[[str, + str, + Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + Dict[str, Any], + Dict[str, str]], Tuple[int, Union[bytes, str]]]): @pagelet(path) def testing_wrapper(method: str, path: str, - args: Dict[str, str], + args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], session_vars: Dict[str, Any], headers: Dict[str, str]): status, body = fun(method, path, args, session_vars, headers) diff --git a/matemat/webserver/test/test_post.py b/matemat/webserver/test/test_post.py index ad99247..511c6e3 100644 --- a/matemat/webserver/test/test_post.py +++ b/matemat/webserver/test/test_post.py @@ -1,14 +1,16 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple,Union from matemat.webserver.httpd import HttpHandler from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet +import codecs + @test_pagelet('/just/testing/post') def post_test_pagelet(method: str, path: str, - args: Dict[str, str], + args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], session_vars: Dict[str, Any], headers: Dict[str, str]): """ @@ -16,8 +18,13 @@ def post_test_pagelet(method: str, """ headers['Content-Type'] = 'text/plain' dump: str = '' - for k, v in args.items(): - dump += f'{k}: {v if isinstance(v, str) else ",".join(v)}\n' + for k, (t, v) in args.items(): + if t.startswith('text/'): + if isinstance(v, bytes): + v = v.decode('utf-8') + dump += f'{k}: {",".join(v) if isinstance(v, list) else v}\n' + else: + dump += f'{k}: {codecs.encode(v, "hex").decode("utf-8")}\n' return 200, dump @@ -26,7 +33,7 @@ class TestPost(AbstractHttpdTest): Test cases for the content serving of the web server. """ - def test_post_get_only_args(self): + def test_post_urlenc_get_only_args(self): """ Test a POST request that only contains GET arguments. """ @@ -38,17 +45,17 @@ class TestPost(AbstractHttpdTest): packet = self.client_sock.get_response() # Parse response body - lines: List[str] = packet.body.split('\n')[:-1] + lines: List[bytes] = packet.body.split(b'\n')[:-1] kv: Dict[str, str] = dict() for l in lines: - k, v = l.split(':', 1) + k, v = l.decode('utf-8').split(':', 1) kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') # Make sure the arguments were properly parsed self.assertEqual('bar', kv['foo']) self.assertEqual('1', kv['test']) - def test_post_post_only_args(self): + def test_post_urlenc_post_only_args(self): """ Test a POST request that only contains POST arguments (urlencoded). """ @@ -61,17 +68,17 @@ class TestPost(AbstractHttpdTest): packet = self.client_sock.get_response() # Parse response body - lines: List[str] = packet.body.split('\n')[:-1] + lines: List[bytes] = packet.body.split(b'\n')[:-1] kv: Dict[str, str] = dict() for l in lines: - k, v = l.split(':', 1) + k, v = l.decode('utf-8').split(':', 1) kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') # Make sure the arguments were properly parsed self.assertEqual('bar', kv['foo']) self.assertEqual('1', kv['test']) - def test_post_mixed_args(self): + def test_post_urlenc_mixed_args(self): """ Test that mixed POST and GET args are properly parsed, and that POST takes precedence over GET. """ @@ -84,10 +91,10 @@ class TestPost(AbstractHttpdTest): packet = self.client_sock.get_response() # Parse response body - lines: List[str] = packet.body.split('\n')[:-1] + lines: List[bytes] = packet.body.split(b'\n')[:-1] kv: Dict[str, str] = dict() for l in lines: - k, v = l.split(':', 1) + k, v = l.decode('utf-8').split(':', 1) kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') # Make sure the arguments were properly parsed @@ -95,7 +102,7 @@ class TestPost(AbstractHttpdTest): self.assertEqual('1', kv['gettest']) self.assertEqual('2', kv['posttest']) - def test_post_get_array(self): + def test_post_urlenc_get_array(self): """ Test a POST request that contains GET array arguments. """ @@ -107,17 +114,17 @@ class TestPost(AbstractHttpdTest): packet = self.client_sock.get_response() # Parse response body - lines: List[str] = packet.body.split('\n')[:-1] + lines: List[bytes] = packet.body.split(b'\n')[:-1] kv: Dict[str, str] = dict() for l in lines: - k, v = l.split(':', 1) + k, v = l.decode('utf-8').split(':', 1) kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') # Make sure the arguments were properly parsed self.assertListEqual(['bar', 'baz'], kv['foo']) self.assertEqual('1', kv['test']) - def test_post_post_array(self): + def test_post_urlenc_post_array(self): """ Test a POST request that contains POST array arguments. """ @@ -130,17 +137,17 @@ class TestPost(AbstractHttpdTest): packet = self.client_sock.get_response() # Parse response body - lines: List[str] = packet.body.split('\n')[:-1] + lines: List[bytes] = packet.body.split(b'\n')[:-1] kv: Dict[str, str] = dict() for l in lines: - k, v = l.split(':', 1) + k, v = l.decode('utf-8').split(':', 1) kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') # Make sure the arguments were properly parsed self.assertListEqual(['bar', 'baz'], kv['foo']) self.assertEqual('1', kv['test']) - def test_post_mixed_array(self): + def test_post_urlenc_mixed_array(self): """ Test a POST request that contains both GET and POST array arguments. """ @@ -153,13 +160,85 @@ class TestPost(AbstractHttpdTest): packet = self.client_sock.get_response() # Parse response body - lines: List[str] = packet.body.split('\n')[:-1] + lines: List[bytes] = packet.body.split(b'\n')[:-1] kv: Dict[str, str] = dict() for l in lines: - k, v = l.split(':', 1) + k, v = l.decode('utf-8').split(':', 1) kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') # Make sure the arguments were properly parsed self.assertListEqual(['postbar', 'postbaz'], kv['foo']) self.assertListEqual(['1', '42'], kv['gettest']) self.assertListEqual(['1', '2'], kv['posttest']) + + def test_post_no_body(self): + """ + Test a POST request that contains no headers or body. + """ + # Send POST request + self.client_sock.set_request(b'POST /just/testing/post?foo=bar HTTP/1.1\r\n\r\n') + HttpHandler(self.client_sock, ('::1', 45678), self.server) + packet = self.client_sock.get_response() + # Make sure a 400 Bad Request is returned + self.assertEqual(400, packet.statuscode) + + def test_post_multipart_post_only(self): + """ + Test a POST request with a miltipart/form-data body. + """ + # Send POST request + formdata = (b'------testboundary\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'------testboundary\r\n' + b'Content-Disposition: form-data; name="bar"; filename="foo.bar"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x80\x0b\x0c\x73\x0e\x0f\r\n' + b'------testboundary--\r\n') + + self.client_sock.set_request(f'POST /just/testing/post HTTP/1.1\r\n' + f'Content-Type: multipart/form-data; boundary=----testboundary\r\n' + f'Content-Length: {len(formdata)}\r\n\r\n'.encode('utf-8') + formdata) + HttpHandler(self.client_sock, ('::1', 45678), self.server) + packet = self.client_sock.get_response() + lines: List[bytes] = packet.body.split(b'\n')[:-1] + kv: Dict[str, Any] = dict() + for l in lines: + k, v = l.split(b':', 1) + kv[k.decode('utf-8').strip()] = v.strip() + self.assertIn('foo', kv) + self.assertIn('bar', kv) + self.assertEqual(kv['foo'], b'Hello, World!') + self.assertEqual(kv['bar'], b'00010203040506070809800b0c730e0f') + + def test_post_multipart_mixed(self): + """ + Test a POST request with a miltipart/form-data body. + """ + # Send POST request + formdata = (b'------testboundary\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'------testboundary\r\n' + b'Content-Disposition: form-data; name="bar"; filename="foo.bar"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x80\x0b\x0c\x73\x0e\x0f\r\n' + b'------testboundary--\r\n') + + self.client_sock.set_request(f'POST /just/testing/post?getfoo=bar&foo=thisshouldbegone HTTP/1.1\r\n' + f'Content-Type: multipart/form-data; boundary=----testboundary\r\n' + f'Content-Length: {len(formdata)}\r\n\r\n'.encode('utf-8') + formdata) + HttpHandler(self.client_sock, ('::1', 45678), self.server) + packet = self.client_sock.get_response() + lines: List[bytes] = packet.body.split(b'\n')[:-1] + kv: Dict[str, Any] = dict() + for l in lines: + k, v = l.split(b':', 1) + kv[k.decode('utf-8').strip()] = v.strip() + self.assertIn('foo', kv) + self.assertIn('bar', kv) + self.assertEqual(kv['getfoo'], b'bar') + self.assertEqual(kv['foo'], b'Hello, World!') + self.assertEqual(kv['bar'], b'00010203040506070809800b0c730e0f') diff --git a/matemat/webserver/test/test_serve.py b/matemat/webserver/test/test_serve.py index f3dc6be..0556764 100644 --- a/matemat/webserver/test/test_serve.py +++ b/matemat/webserver/test/test_serve.py @@ -1,5 +1,5 @@ -from typing import Any, Dict +from typing import Any, Dict, Union import os import os.path @@ -10,7 +10,7 @@ from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_p @test_pagelet('/just/testing/serve_pagelet_ok') def serve_test_pagelet_ok(method: str, path: str, - args: Dict[str, str], + args: Dict[str, Union[bytes, str]], session_vars: Dict[str, Any], headers: Dict[str, str]): headers['Content-Type'] = 'text/plain' @@ -20,7 +20,7 @@ def serve_test_pagelet_ok(method: str, @test_pagelet('/just/testing/serve_pagelet_fail') def serve_test_pagelet_fail(method: str, path: str, - args: Dict[str, str], + args: Dict[str, Union[bytes, str]], session_vars: Dict[str, Any], headers: Dict[str, str]): session_vars['test'] = 'hello, world!' @@ -54,7 +54,7 @@ class TestServe(AbstractHttpdTest): self.assertEqual('serve_test_pagelet_ok', packet.pagelet) # Make sure the expected content is served self.assertEqual(200, packet.statuscode) - self.assertEqual('serve test pagelet ok', packet.body) + self.assertEqual(b'serve test pagelet ok', packet.body) def test_serve_pagelet_fail(self): # Call the test pagelet that produces a 500 Internal Server Error result @@ -66,7 +66,7 @@ class TestServe(AbstractHttpdTest): self.assertEqual('serve_test_pagelet_fail', packet.pagelet) # Make sure the expected content is served self.assertEqual(500, packet.statuscode) - self.assertEqual('serve test pagelet fail', packet.body) + self.assertEqual(b'serve test pagelet fail', packet.body) def test_serve_static_ok(self): # Request a static resource @@ -78,7 +78,7 @@ class TestServe(AbstractHttpdTest): self.assertIsNone(packet.pagelet) # Make sure the expected content is served self.assertEqual(200, packet.statuscode) - self.assertEqual('static resource test', packet.body) + self.assertEqual(b'static resource test', packet.body) def test_serve_static_forbidden(self): # Request a static resource with lacking permissions @@ -90,7 +90,7 @@ class TestServe(AbstractHttpdTest): self.assertIsNone(packet.pagelet) # Make sure a 403 header is served self.assertEqual(403, packet.statuscode) - self.assertNotEqual('This should not be readable', packet.body) + self.assertNotEqual(b'This should not be readable', packet.body) def test_serve_not_found(self): # Request a nonexistent resource @@ -116,7 +116,10 @@ class TestServe(AbstractHttpdTest): def test_static_post_not_allowed(self): # Request a resource outside the webroot - self.client_sock.set_request(b'POST /iwanttouploadthis HTTP/1.1\r\n\r\nq=this%20should%20not%20be%20uploaded') + self.client_sock.set_request(b'POST /iwanttopostthis HTTP/1.1\r\n' + b'Content-Type: application/x-www-form-urlencoded\r\n' + b'Content-length: 37\r\n\r\n' + b'q=this%20should%20not%20be%20uploaded') HttpHandler(self.client_sock, ('::1', 45678), self.server) packet = self.client_sock.get_response() diff --git a/matemat/webserver/test/test_session.py b/matemat/webserver/test/test_session.py index b8e21cf..50ade85 100644 --- a/matemat/webserver/test/test_session.py +++ b/matemat/webserver/test/test_session.py @@ -1,5 +1,5 @@ -from typing import Any, Dict +from typing import Any, Dict, Union from datetime import datetime, timedelta from time import sleep @@ -11,7 +11,7 @@ from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_p @test_pagelet('/just/testing/sessions') def session_test_pagelet(method: str, path: str, - args: Dict[str, str], + args: Dict[str, Union[bytes, str]], session_vars: Dict[str, Any], headers: Dict[str, str]): session_vars['test'] = 'hello, world!' diff --git a/matemat/webserver/util.py b/matemat/webserver/util.py new file mode 100644 index 0000000..931f759 --- /dev/null +++ b/matemat/webserver/util.py @@ -0,0 +1,118 @@ + +from typing import Dict, List, Tuple, Optional, Union + +import urllib.parse + + +def _parse_multipart(body: bytes, boundary: str) -> Dict[str, List[Tuple[str, Union[bytes, str]]]]: + """ + Given a HTTP body with form-data in multipart form, and the multipart-boundary, parse the multipart items and + return them as a dictionary. + + :param body: The HTTP multipart/form-data body. + :param boundary: The multipart boundary. + :return: A dictionary of field names as key, and content types and field values as value. + """ + # Generate item header boundary and terminating boundary from general boundary string + _boundary = f'\r\n--{boundary}\r\n'.encode('utf-8') + _end_boundary = f'\r\n--{boundary}--\r\n'.encode('utf-8') + # Split at the end boundary and make sure there comes nothing after it + allparts = body.split(_end_boundary, 1) + if len(allparts) != 2 or allparts[1] != b'': + raise ValueError('Last boundary missing or corrupted') + # Split remaining body into its parts (appending a CRLF for the first boundary to match), and verify at least 1 part + # is there + parts: List[bytes] = (b'\r\n' + allparts[0]).split(_boundary) + if len(parts) < 1 or parts[0] != b'': + raise ValueError('First boundary missing or corrupted') + # Remove the first, empty part + parts = parts[1:] + + # Results go into this dict + args: Dict[str, List[Tuple[str, Union[bytes, str]]]] = dict() + + # Parse each multipart part + for part in parts: + # Parse multipart headers + hdr: Dict[str, str] = dict() + while True: + head, part = part.split(b'\r\n', 1) + # Break on header/body delimiter + if head == b'': + break + # Add header to hdr dict + hk, hv = head.decode('utf-8').split(':') + hdr[hk.strip()] = hv.strip() + # At least Content-Type and Content-Disposition must be present + if 'Content-Type' not in hdr or 'Content-Disposition' not in hdr: + raise ValueError('Missing Content-Type or Content-Disposition header') + # Extract Content-Disposition header value and its arguments + cd, *cdargs = hdr['Content-Disposition'].split(';') + # Content-Disposition MUST be form-data; everything else is rejected + if cd.strip() != 'form-data': + raise ValueError(f'Unknown Content-Disposition: cd') + # Extract the "name" header argument + for cdarg in cdargs: + k, v = cdarg.split('=', 1) + if k.strip() == 'name': + name: str = v.strip() + # Remove quotation marks around the name value + if name.startswith('"') and name.endswith('"'): + name = v[1:-1] + # Add the Content-Type and the content to the header, with the provided name + if name not in args: + args[name] = list() + args[name].append((hdr['Content-Type'].strip(), part)) + + return args + + +def parse_args(request: str, postbody: Optional[bytes] = None, enctype: str = 'text/plain') \ + -> Tuple[str, Dict[str, Tuple[str, Union[bytes, str, List[str]]]]]: + """ + Given a HTTP request path, and optionally a HTTP POST body in application/x-www-form-urlencoded or + multipart/form-data form, parse the arguments and return them as a dictionary. + + If a key is used both in GET and in POST, the POST value takes precedence, and the GET value is discarded. + + :param request: The request string to parse. + :param postbody: The POST body to parse, defaults to None. + :param enctype: Encoding of the POST body; supported values are application/x-www-form-urlencoded and + multipart/form-data. + :return: A tuple consisting of the base path and a dictionary with the parsed key/value pairs, and the value's + content type. + """ + # Parse the request "URL" (i.e. only the path) + tokens = urllib.parse.urlparse(request) + # Parse the GET arguments + getargs = urllib.parse.parse_qs(tokens.query) + + # TODO: { 'foo': [ ('text/plain', 'bar'), ('application/octet-stream', '\x80') ] } + # TODO: Use a @dataclass once Python 3.7 is out + args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]] = dict() + for k, v in getargs.items(): + args[k] = 'text/plain', v + + if postbody is not None: + if enctype == 'application/x-www-form-urlencoded': + # Parse the POST body + postargs = urllib.parse.parse_qs(postbody.decode('utf-8')) + # Write all POST values into the dict, overriding potential duplicates from GET + for k, v in postargs.items(): + args[k] = 'text/plain', v + elif enctype.startswith('multipart/form-data'): + # Parse the multipart boundary from the Content-Type header + boundary: str = enctype.split('boundary=')[1] + # Parse the multipart body + mpargs = _parse_multipart(postbody, boundary) + for k, v in mpargs.items(): + # TODO: Process all values, not just the first + args[k] = v[0] + else: + raise ValueError(f'Unsupported Content-Type: {enctype}') + # urllib.parse.parse_qs turns ALL arguments into arrays. This turns arrays of length 1 into scalar values + for (k, (ct, v)) in args.items(): + if len(v) == 1: + args[k] = ct, v[0] + # Return the path and the parsed arguments + return tokens.path, args From 5bb1dfad2176d5ae0c8325024b1673eb45e9c944 Mon Sep 17 00:00:00 2001 From: s3lph Date: Wed, 27 Jun 2018 21:20:36 +0200 Subject: [PATCH 2/9] Fixed a style error --- matemat/webserver/test/test_post.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matemat/webserver/test/test_post.py b/matemat/webserver/test/test_post.py index 511c6e3..0c6e3d2 100644 --- a/matemat/webserver/test/test_post.py +++ b/matemat/webserver/test/test_post.py @@ -1,5 +1,5 @@ -from typing import Any, Dict, List, Tuple,Union +from typing import Any, Dict, List, Tuple, Union from matemat.webserver.httpd import HttpHandler from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet From 118de8bf95f344ea0c567040334a229d089544f6 Mon Sep 17 00:00:00 2001 From: s3lph Date: Thu, 28 Jun 2018 23:58:01 +0200 Subject: [PATCH 3/9] New request parsing (WIP: Documentation) --- matemat/webserver/__init__.py | 1 + matemat/webserver/httpd.py | 14 +- matemat/webserver/pagelets/__init__.py | 1 - matemat/webserver/pagelets/login.py | 16 +- matemat/webserver/pagelets/logout.py | 4 +- matemat/webserver/pagelets/main.py | 4 +- matemat/webserver/pagelets/touchkey.py | 14 +- matemat/webserver/pagelets/upload_test.py | 28 -- matemat/webserver/requestargs.py | 121 ++++++ matemat/webserver/test/abstract_httpd_test.py | 6 +- matemat/webserver/test/test_parse_request.py | 347 ++++++++++++++++++ matemat/webserver/test/test_post.py | 84 ++--- matemat/webserver/test/test_requestargs.py | 204 ++++++++++ matemat/webserver/test/test_serve.py | 6 +- matemat/webserver/test/test_session.py | 4 +- matemat/webserver/util.py | 64 ++-- 16 files changed, 774 insertions(+), 144 deletions(-) delete mode 100644 matemat/webserver/pagelets/upload_test.py create mode 100644 matemat/webserver/requestargs.py create mode 100644 matemat/webserver/test/test_parse_request.py create mode 100644 matemat/webserver/test/test_requestargs.py diff --git a/matemat/webserver/__init__.py b/matemat/webserver/__init__.py index f4d86f3..1b4ab06 100644 --- a/matemat/webserver/__init__.py +++ b/matemat/webserver/__init__.py @@ -6,4 +6,5 @@ API that can be used by 'pagelets' - single pages of a web service. If a reques server will attempt to serve the request with a static resource in a previously configured webroot directory. """ +from .requestargs import RequestArgument from .httpd import MatematWebserver, HttpHandler, pagelet diff --git a/matemat/webserver/httpd.py b/matemat/webserver/httpd.py index a4e9cca..79efb98 100644 --- a/matemat/webserver/httpd.py +++ b/matemat/webserver/httpd.py @@ -1,5 +1,5 @@ -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, Tuple, Union import traceback @@ -13,6 +13,7 @@ from uuid import uuid4 from datetime import datetime, timedelta from matemat import __version__ as matemat_version +from matemat.webserver import RequestArgument from matemat.webserver.util import parse_args @@ -30,7 +31,7 @@ BaseHTTPRequestHandler.log_error = lambda self, fstring='', *args: None # Dictionary to hold registered pagelet paths and their handler functions _PAGELET_PATHS: Dict[str, Callable[[str, # HTTP method (GET, POST, ...) str, # Request path - Dict[str, Tuple[str, Union[bytes, str, List[str]]]], # args: (name, (type, value)) + Dict[str, RequestArgument], # args: (name, argument) Dict[str, Any], # Session vars Dict[str, str]], # Response headers Tuple[int, Union[bytes, str]]]] = dict() # Returns: (status code, response body) @@ -50,15 +51,14 @@ def pagelet(path: str): (method: str, path: str, - args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str]) -> (int, Optional[Union[str, bytes]]) method: The HTTP method (GET, POST) that was used. path: The path that was requested. - args: The arguments that were passed with the request (as GET or POST arguments), each of which may be - either a str or bytes object, or a list of str. + args: The arguments that were passed with the request (as GET or POST arguments). session_vars: The session storage. May be read from and written to. headers: The dictionary of HTTP response headers. Add headers you wish to send with the response. returns: A tuple consisting of the HTTP status code (as an int) and the response body (as str or bytes, @@ -69,7 +69,7 @@ def pagelet(path: str): def http_handler(fun: Callable[[str, str, - Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + Dict[str, RequestArgument], Dict[str, Any], Dict[str, str]], Tuple[int, Union[bytes, str]]]): @@ -181,7 +181,7 @@ class HttpHandler(BaseHTTPRequestHandler): if session_id in self.server.session_vars: del self.server.session_vars[session_id] - def _handle(self, method: str, path: str, args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]]) -> None: + def _handle(self, method: str, path: str, args: Dict[str, RequestArgument]) -> None: """ Handle a HTTP request by either dispatching it to the appropriate pagelet or by serving a static resource. diff --git a/matemat/webserver/pagelets/__init__.py b/matemat/webserver/pagelets/__init__.py index 71ded5e..9b926d6 100644 --- a/matemat/webserver/pagelets/__init__.py +++ b/matemat/webserver/pagelets/__init__.py @@ -8,4 +8,3 @@ from .main import main_page from .login import login_page from .logout import logout from .touchkey import touchkey_page -from .upload_test import upload_test diff --git a/matemat/webserver/pagelets/login.py b/matemat/webserver/pagelets/login.py index 8fbe831..f7813b4 100644 --- a/matemat/webserver/pagelets/login.py +++ b/matemat/webserver/pagelets/login.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from matemat.exceptions import AuthenticationError -from matemat.webserver import pagelet +from matemat.webserver import pagelet, RequestArgument from matemat.primitives import User from matemat.db import MatematDatabase @@ -10,7 +10,7 @@ from matemat.db import MatematDatabase @pagelet('/login') def login_page(method: str, path: str, - args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: @@ -43,15 +43,11 @@ def login_page(method: str, ''' return 200, data.format(msg=args['msg'] if 'msg' in args else '') elif method == 'POST': - if 'username' not in args or not isinstance(args['username'], str): - return 400, None - if 'password' not in args or not isinstance(args['password'], str): - return 400, None - username: str = args['username'] - password: str = args['password'] + username: RequestArgument = args['username'] + password: RequestArgument = args['password'] with MatematDatabase('test.db') as db: try: - user: User = db.login(username, password) + user: User = db.login(username.get_str(), password.get_str()) except AuthenticationError: headers['Location'] = '/login?msg=Username%20or%20password%20wrong.%20Please%20try%20again.' return 301, bytes() diff --git a/matemat/webserver/pagelets/logout.py b/matemat/webserver/pagelets/logout.py index 53a292a..b70d7c1 100644 --- a/matemat/webserver/pagelets/logout.py +++ b/matemat/webserver/pagelets/logout.py @@ -1,13 +1,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union -from matemat.webserver import pagelet +from matemat.webserver import pagelet, RequestArgument @pagelet('/logout') def logout(method: str, path: str, - args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: diff --git a/matemat/webserver/pagelets/main.py b/matemat/webserver/pagelets/main.py index d2dd208..2b9ce79 100644 --- a/matemat/webserver/pagelets/main.py +++ b/matemat/webserver/pagelets/main.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union -from matemat.webserver import MatematWebserver, pagelet +from matemat.webserver import MatematWebserver, pagelet, RequestArgument from matemat.primitives import User from matemat.db import MatematDatabase @@ -9,7 +9,7 @@ from matemat.db import MatematDatabase @pagelet('/') def main_page(method: str, path: str, - args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: diff --git a/matemat/webserver/pagelets/touchkey.py b/matemat/webserver/pagelets/touchkey.py index 2a8202d..22e3df4 100644 --- a/matemat/webserver/pagelets/touchkey.py +++ b/matemat/webserver/pagelets/touchkey.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from matemat.exceptions import AuthenticationError -from matemat.webserver import pagelet +from matemat.webserver import pagelet, RequestArgument from matemat.primitives import User from matemat.db import MatematDatabase @@ -10,7 +10,7 @@ from matemat.db import MatematDatabase @pagelet('/touchkey') def touchkey_page(method: str, path: str, - args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: @@ -42,15 +42,11 @@ def touchkey_page(method: str, ''' return 200, data.format(username=args['username'] if 'username' in args else '') elif method == 'POST': - if 'username' not in args or not isinstance(args['username'], str): - return 400, None - if 'touchkey' not in args or not isinstance(args['touchkey'], str): - return 400, None - username: str = args['username'] - touchkey: str = args['touchkey'] + username: RequestArgument = args['username'] + touchkey: RequestArgument = args['touchkey'] with MatematDatabase('test.db') as db: try: - user: User = db.login(username, touchkey=touchkey) + user: User = db.login(username.get_str(), touchkey=touchkey.get_str()) except AuthenticationError: headers['Location'] = f'/touchkey?username={args["username"]}&msg=Please%20try%20again.' return 301, bytes() diff --git a/matemat/webserver/pagelets/upload_test.py b/matemat/webserver/pagelets/upload_test.py deleted file mode 100644 index a6f1e85..0000000 --- a/matemat/webserver/pagelets/upload_test.py +++ /dev/null @@ -1,28 +0,0 @@ - -from typing import Any, Dict, Union - -from matemat.webserver import pagelet - - -@pagelet('/upload') -def upload_test(method: str, - path: str, - args: Dict[str, Union[str, bytes]], - session_vars: Dict[str, Any], - headers: Dict[str, str]): - if method == 'GET': - return 200, ''' - - - -
- - - -
- - - ''' - else: - headers['Content-Type'] = 'text/plain' - return 200, args.items().__str__() diff --git a/matemat/webserver/requestargs.py b/matemat/webserver/requestargs.py new file mode 100644 index 0000000..a35f759 --- /dev/null +++ b/matemat/webserver/requestargs.py @@ -0,0 +1,121 @@ + +from typing import List, Optional, Tuple, Union + + +class RequestArgument(object): + + def __init__(self, + name: str, + value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]] = None) -> None: + self.__name: str = name + self.__value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]] = None + if value is None: + self.__value = [] + else: + if isinstance(value, list): + if len(value) == 1: + self.__value = value[0] + else: + self.__value = value + else: + self.__value = value + + @property + def is_array(self) -> bool: + return isinstance(self.__value, list) + + @property + def is_scalar(self) -> bool: + return not isinstance(self.__value, list) + + @property + def is_view(self) -> bool: + return False + + @property + def name(self) -> str: + return self.__name + + def get_str(self, index: int = None) -> Optional[str]: + if self.is_array: + if index is None: + raise ValueError('index must not be None') + v: Tuple[str, Union[bytes, str]] = self.__value[index] + if isinstance(v[1], str): + return v[1] + elif isinstance(v[1], bytes): + return v[1].decode('utf-8') + else: + if index is not None: + raise ValueError('index must be None') + if isinstance(self.__value[1], str): + return self.__value[1] + elif isinstance(self.__value[1], bytes): + return self.__value[1].decode('utf-8') + + def get_bytes(self, index: int = None) -> Optional[bytes]: + if self.is_array: + if index is None: + raise ValueError('index must not be None') + v: Tuple[str, Union[bytes, str]] = self.__value[index] + if isinstance(v[1], bytes): + return v[1] + elif isinstance(v[1], str): + return v[1].encode('utf-8') + else: + if index is not None: + raise ValueError('index must be None') + if isinstance(self.__value[1], bytes): + return self.__value[1] + elif isinstance(self.__value[1], str): + return self.__value[1].encode('utf-8') + + def get_content_type(self, index: int = None) -> Optional[str]: + if self.is_array: + if index is None: + raise ValueError('index must not be None') + v: Tuple[str, Union[bytes, str]] = self.__value[index] + return v[0] + else: + if index is not None: + raise ValueError('index must be None') + return self.__value[0] + + def append(self, ctype: str, value: Union[str, bytes]): + if self.is_view: + raise TypeError('A RequestArgument view is immutable!') + if len(self) == 0: + self.__value = ctype, value + else: + if self.is_scalar: + self.__value = [self.__value] + self.__value.append((ctype, value)) + + def __len__(self): + return len(self.__value) if self.is_array else 1 + + def __iter__(self): + if self.is_scalar: + yield _View(self.__name, self.__value) + else: + # Typing helper + _value: List[Tuple[str, Union[bytes, str]]] = self.__value + for v in _value: + yield _View(self.__name, v) + + def __getitem__(self, index: Union[int, slice]): + if self.is_scalar: + if index == 0: + return _View(self.__name, self.__value) + raise ValueError('Scalar RequestArgument only indexable with 0') + return _View(self.__name, self.__value[index]) + + +class _View(RequestArgument): + + def __init__(self, name: str, value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]]): + super().__init__(name, value) + + @property + def is_view(self) -> bool: + return True diff --git a/matemat/webserver/test/abstract_httpd_test.py b/matemat/webserver/test/abstract_httpd_test.py index b96767e..daa1126 100644 --- a/matemat/webserver/test/abstract_httpd_test.py +++ b/matemat/webserver/test/abstract_httpd_test.py @@ -9,7 +9,7 @@ from abc import ABC from datetime import datetime from http.server import HTTPServer -from matemat.webserver.httpd import pagelet +from matemat.webserver import pagelet, RequestArgument class HttpResponse: @@ -158,14 +158,14 @@ def test_pagelet(path: str): def with_testing_headers(fun: Callable[[str, str, - Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + Dict[str, RequestArgument], Dict[str, Any], Dict[str, str]], Tuple[int, Union[bytes, str]]]): @pagelet(path) def testing_wrapper(method: str, path: str, - args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str]): status, body = fun(method, path, args, session_vars, headers) diff --git a/matemat/webserver/test/test_parse_request.py b/matemat/webserver/test/test_parse_request.py new file mode 100644 index 0000000..a533936 --- /dev/null +++ b/matemat/webserver/test/test_parse_request.py @@ -0,0 +1,347 @@ + +import unittest + +from matemat.webserver.util import parse_args + + +class TestParseRequest(unittest.TestCase): + + def test_parse_get_root(self): + path, args = parse_args('/') + self.assertEqual('/', path) + self.assertEqual(0, len(args)) + + def test_parse_get_no_args(self): + path, args = parse_args('/index.html') + self.assertEqual('/index.html', path) + self.assertEqual(0, len(args)) + + def test_parse_get_root_getargs(self): + path, args = parse_args('/?foo=42&bar=1337&baz=Hello,%20World!') + self.assertEqual('/', path) + self.assertEqual(3, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('bar', args.keys()) + self.assertIn('baz', args.keys()) + self.assertTrue(args['foo'].is_scalar) + self.assertTrue(args['bar'].is_scalar) + self.assertTrue(args['baz'].is_scalar) + self.assertEqual('text/plain', args['foo'].get_content_type()) + self.assertEqual('text/plain', args['bar'].get_content_type()) + self.assertEqual('text/plain', args['baz'].get_content_type()) + self.assertEqual('42', args['foo'].get_str()) + self.assertEqual('1337', args['bar'].get_str()) + self.assertEqual('Hello, World!', args['baz'].get_str()) + + def test_parse_get_getargs(self): + path, args = parse_args('/abc/def?foo=42&bar=1337&baz=Hello,%20World!') + self.assertEqual('/abc/def', path) + self.assertEqual(3, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('bar', args.keys()) + self.assertIn('baz', args.keys()) + self.assertTrue(args['foo'].is_scalar) + self.assertTrue(args['bar'].is_scalar) + self.assertTrue(args['baz'].is_scalar) + self.assertEqual('text/plain', args['foo'].get_content_type()) + self.assertEqual('text/plain', args['bar'].get_content_type()) + self.assertEqual('text/plain', args['baz'].get_content_type()) + self.assertEqual('42', args['foo'].get_str()) + self.assertEqual('1337', args['bar'].get_str()) + self.assertEqual('Hello, World!', args['baz'].get_str()) + + def test_parse_get_getarray(self): + path, args = parse_args('/abc/def?foo=42&foo=1337&baz=Hello,%20World!') + self.assertEqual('/abc/def', path) + self.assertEqual(2, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('baz', args.keys()) + self.assertTrue(args['foo'].is_array) + self.assertTrue(args['baz'].is_scalar) + self.assertEqual(2, len(args['foo'])) + self.assertEqual('42', args['foo'].get_str(0)) + self.assertEqual('1337', args['foo'].get_str(1)) + + def test_parse_get_zero_arg(self): + path, args = parse_args('/abc/def?foo=&bar=42') + self.assertEqual(2, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('bar', args.keys()) + self.assertTrue(args['foo'].is_scalar) + self.assertTrue(args['bar'].is_scalar) + self.assertEqual(1, len(args['foo'])) + self.assertEqual('', args['foo'].get_str()) + self.assertEqual('42', args['bar'].get_str()) + + def test_parse_get_urlencoded_encoding_fail(self): + with self.assertRaises(ValueError): + parse_args('/?foo=42&bar=%80&baz=Hello,%20World!') + + def test_parse_post_urlencoded(self): + path, args = parse_args('/', + postbody=b'foo=42&bar=1337&baz=Hello,%20World!', + enctype='application/x-www-form-urlencoded') + self.assertEqual('/', path) + self.assertEqual(3, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('bar', args.keys()) + self.assertIn('baz', args.keys()) + self.assertTrue(args['foo'].is_scalar) + self.assertTrue(args['bar'].is_scalar) + self.assertTrue(args['baz'].is_scalar) + self.assertEqual('text/plain', args['foo'].get_content_type()) + self.assertEqual('text/plain', args['bar'].get_content_type()) + self.assertEqual('text/plain', args['baz'].get_content_type()) + self.assertEqual('42', args['foo'].get_str()) + self.assertEqual('1337', args['bar'].get_str()) + self.assertEqual('Hello, World!', args['baz'].get_str()) + + def test_parse_post_urlencoded_array(self): + path, args = parse_args('/', + postbody=b'foo=42&foo=1337&baz=Hello,%20World!', + enctype='application/x-www-form-urlencoded') + self.assertEqual('/', path) + self.assertEqual(2, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('baz', args.keys()) + self.assertTrue(args['foo'].is_array) + self.assertTrue(args['baz'].is_scalar) + self.assertEqual(2, len(args['foo'])) + self.assertEqual('42', args['foo'].get_str(0)) + self.assertEqual('1337', args['foo'].get_str(1)) + + def test_parse_post_urlencoded_zero_arg(self): + path, args = parse_args('/abc/def', postbody=b'foo=&bar=42', enctype='application/x-www-form-urlencoded') + self.assertEqual(2, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('bar', args.keys()) + self.assertTrue(args['foo'].is_scalar) + self.assertTrue(args['bar'].is_scalar) + self.assertEqual(1, len(args['foo'])) + self.assertEqual('', args['foo'].get_str()) + self.assertEqual('42', args['bar'].get_str()) + + def test_parse_post_urlencoded_encoding_fail(self): + with self.assertRaises(ValueError): + parse_args('/', + postbody=b'foo=42&bar=%80&baz=Hello,%20World!', + enctype='application/x-www-form-urlencoded') + + def test_parse_post_multipart_no_args(self): + path, args = parse_args('/', + postbody=b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + self.assertEqual('/', path) + self.assertEqual(0, len(args)) + + def test_parse_post_multipart(self): + path, args = parse_args('/', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="bar"; filename="bar.bin"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + self.assertEqual('/', path) + self.assertEqual(3, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('bar', args.keys()) + self.assertIn('baz', args.keys()) + self.assertTrue(args['foo'].is_scalar) + self.assertTrue(args['bar'].is_scalar) + self.assertTrue(args['baz'].is_scalar) + self.assertEqual('text/plain', args['foo'].get_content_type()) + self.assertEqual('application/octet-stream', args['bar'].get_content_type()) + self.assertEqual('text/plain', args['baz'].get_content_type()) + self.assertEqual('42', args['foo'].get_str()) + self.assertEqual(b'1337', args['bar'].get_bytes()) + self.assertEqual('Hello, World!', args['baz'].get_str()) + + def test_parse_post_multipart_zero_arg(self): + path, args = parse_args('/abc/def', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="bar"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + self.assertEqual(2, len(args)) + self.assertIn('foo', args.keys()) + self.assertIn('bar', args.keys()) + self.assertTrue(args['foo'].is_scalar) + self.assertTrue(args['bar'].is_scalar) + self.assertEqual(1, len(args['foo'])) + self.assertEqual('', args['foo'].get_str()) + self.assertEqual('42', args['bar'].get_str()) + + def test_parse_post_multipart_broken_boundaries(self): + with self.assertRaises(ValueError): + # Boundary not defined in Content-Type + parse_args('/', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="bar"; filename="bar.bin"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data') + with self.assertRaises(ValueError): + # Corrupted "--" head at first boundary + parse_args('/', + postbody=b'-+testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="bar"; filename="bar.bin"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + with self.assertRaises(ValueError): + # Missing "--" tail at end boundary + parse_args('/', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="bar"; filename="bar.bin"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + with self.assertRaises(ValueError): + # Missing Content-Type header in one part + parse_args('/', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="bar"; filename="bar.bin"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + with self.assertRaises(ValueError): + # Missing Content-Disposition header in one part + parse_args('/', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + with self.assertRaises(ValueError): + # Missing form-data name argument + parse_args('/', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; filename="bar.bin"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + with self.assertRaises(ValueError): + # Unknown Content-Disposition + parse_args('/', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: attachment; name="bar"; filename="bar.bin"\r\n' + b'Content-Type: application/octet-stream\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + + def test_get_post_precedence_urlencoded(self): + path, args = parse_args('/foo?foo=thisshouldnotbethere&bar=isurvived', + postbody=b'foo=42&foo=1337&baz=Hello,%20World!', + enctype='application/x-www-form-urlencoded') + self.assertIn('foo', args) + self.assertIn('bar', args) + self.assertIn('baz', args) + self.assertEqual(2, len(args['foo'])) + self.assertEqual(1, len(args['bar'])) + self.assertEqual(1, len(args['baz'])) + self.assertEqual('42', args['foo'].get_str(0)) + self.assertEqual('1337', args['foo'].get_str(1)) + self.assertEqual('isurvived', args['bar'].get_str()) + self.assertEqual('Hello, World!', args['baz'].get_str()) + + def test_get_post_precedence_multipart(self): + path, args = parse_args('/foo?foo=thisshouldnotbethere&bar=isurvived', + postbody=b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'42\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="foo"; filename="bar.bin"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'1337\r\n' + b'--testBoundary1337\r\n' + b'Content-Disposition: form-data; name="baz"\r\n' + b'Content-Type: text/plain\r\n\r\n' + b'Hello, World!\r\n' + b'--testBoundary1337--\r\n', + enctype='multipart/form-data; boundary=testBoundary1337') + self.assertIn('foo', args) + self.assertIn('bar', args) + self.assertIn('baz', args) + self.assertEqual(2, len(args['foo'])) + self.assertEqual(1, len(args['bar'])) + self.assertEqual(1, len(args['baz'])) + self.assertEqual('42', args['foo'].get_str(0)) + self.assertEqual('1337', args['foo'].get_str(1)) + self.assertEqual('isurvived', args['bar'].get_str()) + self.assertEqual('Hello, World!', args['baz'].get_str()) diff --git a/matemat/webserver/test/test_post.py b/matemat/webserver/test/test_post.py index 0c6e3d2..9b2fe22 100644 --- a/matemat/webserver/test/test_post.py +++ b/matemat/webserver/test/test_post.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Tuple, Union -from matemat.webserver.httpd import HttpHandler +from matemat.webserver import HttpHandler, RequestArgument from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet import codecs @@ -10,7 +10,7 @@ import codecs @test_pagelet('/just/testing/post') def post_test_pagelet(method: str, path: str, - args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str]): """ @@ -18,13 +18,12 @@ def post_test_pagelet(method: str, """ headers['Content-Type'] = 'text/plain' dump: str = '' - for k, (t, v) in args.items(): - if t.startswith('text/'): - if isinstance(v, bytes): - v = v.decode('utf-8') - dump += f'{k}: {",".join(v) if isinstance(v, list) else v}\n' - else: - dump += f'{k}: {codecs.encode(v, "hex").decode("utf-8")}\n' + for k, ra in args.items(): + for a in ra: + if a.get_content_type().startswith('text/'): + dump += f'{k}: {a.get_str()}\n' + else: + dump += f'{k}: {codecs.encode(a.get_bytes(), "hex").decode("utf-8")}\n' return 200, dump @@ -118,10 +117,14 @@ class TestPost(AbstractHttpdTest): kv: Dict[str, str] = dict() for l in lines: k, v = l.decode('utf-8').split(':', 1) - kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') - + k = k.strip() + v = v.strip() + if k in kv: + kv[k] += f',{v}' + else: + kv[k] = v # Make sure the arguments were properly parsed - self.assertListEqual(['bar', 'baz'], kv['foo']) + self.assertEqual('bar,baz', kv['foo']) self.assertEqual('1', kv['test']) def test_post_urlenc_post_array(self): @@ -141,10 +144,14 @@ class TestPost(AbstractHttpdTest): kv: Dict[str, str] = dict() for l in lines: k, v = l.decode('utf-8').split(':', 1) - kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') - + k = k.strip() + v = v.strip() + if k in kv: + kv[k] += f',{v}' + else: + kv[k] = v # Make sure the arguments were properly parsed - self.assertListEqual(['bar', 'baz'], kv['foo']) + self.assertEqual('bar,baz', kv['foo']) self.assertEqual('1', kv['test']) def test_post_urlenc_mixed_array(self): @@ -164,12 +171,16 @@ class TestPost(AbstractHttpdTest): kv: Dict[str, str] = dict() for l in lines: k, v = l.decode('utf-8').split(':', 1) - kv[k.strip()] = v.strip() if ',' not in v else v.strip().split(',') - + k = k.strip() + v = v.strip() + if k in kv: + kv[k] += f',{v}' + else: + kv[k] = v # Make sure the arguments were properly parsed - self.assertListEqual(['postbar', 'postbaz'], kv['foo']) - self.assertListEqual(['1', '42'], kv['gettest']) - self.assertListEqual(['1', '2'], kv['posttest']) + self.assertEqual('postbar,postbaz', kv['foo']) + self.assertEqual('1,42', kv['gettest']) + self.assertEqual('1,2', kv['posttest']) def test_post_no_body(self): """ @@ -184,7 +195,7 @@ class TestPost(AbstractHttpdTest): def test_post_multipart_post_only(self): """ - Test a POST request with a miltipart/form-data body. + Test a POST request with a miutipart/form-data body. """ # Send POST request formdata = (b'------testboundary\r\n' @@ -211,34 +222,3 @@ class TestPost(AbstractHttpdTest): self.assertIn('bar', kv) self.assertEqual(kv['foo'], b'Hello, World!') self.assertEqual(kv['bar'], b'00010203040506070809800b0c730e0f') - - def test_post_multipart_mixed(self): - """ - Test a POST request with a miltipart/form-data body. - """ - # Send POST request - formdata = (b'------testboundary\r\n' - b'Content-Disposition: form-data; name="foo"\r\n' - b'Content-Type: text/plain\r\n\r\n' - b'Hello, World!\r\n' - b'------testboundary\r\n' - b'Content-Disposition: form-data; name="bar"; filename="foo.bar"\r\n' - b'Content-Type: application/octet-stream\r\n\r\n' - b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x80\x0b\x0c\x73\x0e\x0f\r\n' - b'------testboundary--\r\n') - - self.client_sock.set_request(f'POST /just/testing/post?getfoo=bar&foo=thisshouldbegone HTTP/1.1\r\n' - f'Content-Type: multipart/form-data; boundary=----testboundary\r\n' - f'Content-Length: {len(formdata)}\r\n\r\n'.encode('utf-8') + formdata) - HttpHandler(self.client_sock, ('::1', 45678), self.server) - packet = self.client_sock.get_response() - lines: List[bytes] = packet.body.split(b'\n')[:-1] - kv: Dict[str, Any] = dict() - for l in lines: - k, v = l.split(b':', 1) - kv[k.decode('utf-8').strip()] = v.strip() - self.assertIn('foo', kv) - self.assertIn('bar', kv) - self.assertEqual(kv['getfoo'], b'bar') - self.assertEqual(kv['foo'], b'Hello, World!') - self.assertEqual(kv['bar'], b'00010203040506070809800b0c730e0f') diff --git a/matemat/webserver/test/test_requestargs.py b/matemat/webserver/test/test_requestargs.py new file mode 100644 index 0000000..dcdde14 --- /dev/null +++ b/matemat/webserver/test/test_requestargs.py @@ -0,0 +1,204 @@ + +from typing import List + +import unittest + +from matemat.webserver import RequestArgument +# noinspection PyProtectedMember +from matemat.webserver.requestargs import _View + + +class TestRequestArguments(unittest.TestCase): + + def test_create_default(self): + ra = RequestArgument('foo') + self.assertEqual('foo', ra.name) + self.assertEqual(0, len(ra)) + self.assertFalse(ra.is_scalar) + self.assertTrue(ra.is_array) + self.assertFalse(ra.is_view) + + def test_create_str_scalar(self): + ra = RequestArgument('foo', ('text/plain', 'bar')) + self.assertEqual('foo', ra.name) + self.assertEqual(1, len(ra)) + self.assertTrue(ra.is_scalar) + self.assertFalse(ra.is_array) + self.assertEqual('bar', ra.get_str()) + self.assertEqual(b'bar', ra.get_bytes()) + self.assertEqual('text/plain', ra.get_content_type()) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_str(0)) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_bytes(0)) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_content_type(0)) + self.assertFalse(ra.is_view) + + def test_create_str_scalar_array(self): + ra = RequestArgument('foo', [('text/plain', 'bar')]) + self.assertEqual('foo', ra.name) + self.assertEqual(1, len(ra)) + self.assertTrue(ra.is_scalar) + self.assertFalse(ra.is_array) + self.assertEqual('bar', ra.get_str()) + self.assertEqual(b'bar', ra.get_bytes()) + self.assertEqual('text/plain', ra.get_content_type()) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_str(0)) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_bytes(0)) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_content_type(0)) + self.assertFalse(ra.is_view) + + def test_create_bytes_scalar(self): + ra = RequestArgument('foo', ('application/octet-stream', b'\x00\x80\xff\xfe')) + self.assertEqual('foo', ra.name) + self.assertEqual(1, len(ra)) + self.assertTrue(ra.is_scalar) + self.assertFalse(ra.is_array) + with self.assertRaises(UnicodeDecodeError): + ra.get_str() + self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes()) + self.assertEqual('application/octet-stream', ra.get_content_type()) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_str(0)) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_bytes(0)) + with self.assertRaises(ValueError): + self.assertEqual('bar', ra.get_content_type(0)) + self.assertFalse(ra.is_view) + + def test_create_array(self): + ra = RequestArgument('foo', [ + ('text/plain', 'bar'), + ('application/octet-stream', b'\x00\x80\xff\xfe') + ]) + self.assertEqual('foo', ra.name) + self.assertEqual(2, len(ra)) + self.assertFalse(ra.is_scalar) + self.assertTrue(ra.is_array) + with self.assertRaises(ValueError): + ra.get_str() + with self.assertRaises(ValueError): + ra.get_bytes() + with self.assertRaises(ValueError): + ra.get_content_type() + self.assertEqual('bar', ra.get_str(0)) + self.assertEqual(b'bar', ra.get_bytes(0)) + self.assertEqual('text/plain', ra.get_content_type(0)) + with self.assertRaises(UnicodeDecodeError): + ra.get_str(1) + self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes(1)) + self.assertEqual('application/octet-stream', ra.get_content_type(1)) + self.assertFalse(ra.is_view) + + def test_append_empty_str(self): + ra = RequestArgument('foo') + self.assertEqual(0, len(ra)) + self.assertFalse(ra.is_scalar) + + ra.append('text/plain', 'bar') + self.assertEqual(1, len(ra)) + self.assertTrue(ra.is_scalar) + self.assertEqual('bar', ra.get_str()) + self.assertEqual(b'bar', ra.get_bytes()) + self.assertEqual('text/plain', ra.get_content_type()) + self.assertFalse(ra.is_view) + + def test_append_empty_bytes(self): + ra = RequestArgument('foo') + self.assertEqual(0, len(ra)) + self.assertFalse(ra.is_scalar) + + ra.append('application/octet-stream', b'\x00\x80\xff\xfe') + self.assertEqual(1, len(ra)) + self.assertTrue(ra.is_scalar) + with self.assertRaises(UnicodeDecodeError): + ra.get_str() + self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes()) + self.assertEqual('application/octet-stream', ra.get_content_type()) + self.assertFalse(ra.is_view) + + def test_append_multiple(self): + ra = RequestArgument('foo') + self.assertEqual(0, len(ra)) + self.assertFalse(ra.is_scalar) + + ra.append('text/plain', 'bar') + self.assertEqual(1, len(ra)) + self.assertTrue(ra.is_scalar) + + ra.append('application/octet-stream', b'\x00\x80\xff\xfe') + self.assertEqual(2, len(ra)) + self.assertFalse(ra.is_scalar) + + ra.append('text/plain', 'Hello, World!') + self.assertEqual(3, len(ra)) + self.assertFalse(ra.is_scalar) + + def test_iterate_empty(self): + ra = RequestArgument('foo') + self.assertEqual(0, len(ra)) + for _ in ra: + self.fail() + + def test_iterate_scalar(self): + ra = RequestArgument('foo', ('text/plain', 'bar')) + self.assertTrue(ra.is_scalar) + count: int = 0 + for it in ra: + self.assertIsInstance(it, _View) + self.assertEqual('foo', it.name) + self.assertTrue(it.is_view) + self.assertTrue(it.is_scalar) + count += 1 + self.assertEqual(1, count) + + def test_iterate_array(self): + ra = RequestArgument('foo', [('text/plain', 'bar'), ('abc', b'def'), ('xyz', '1337')]) + self.assertFalse(ra.is_scalar) + items: List[str] = list() + for it in ra: + self.assertIsInstance(it, _View) + self.assertTrue(it.is_view) + self.assertTrue(it.is_scalar) + items.append(it.get_content_type()) + self.assertEqual(['text/plain', 'abc', 'xyz'], items) + + def test_iterate_sliced(self): + ra = RequestArgument('foo', [('a', 'b'), ('c', 'd'), ('e', 'f'), ('g', 'h'), ('i', 'j'), ('k', 'l')]) + self.assertFalse(ra.is_scalar) + items: List[str] = list() + for it in ra[1:4:2]: + self.assertIsInstance(it, _View) + self.assertTrue(it.is_view) + self.assertTrue(it.is_scalar) + items.append(it.get_content_type()) + self.assertEqual(['c', 'g'], items) + + def test_index_scalar(self): + ra = RequestArgument('foo', ('bar', 'baz')) + it = ra[0] + self.assertIsInstance(it, _View) + self.assertEqual('foo', it.name) + self.assertEqual('bar', it.get_content_type()) + self.assertEqual('baz', it.get_str()) + with self.assertRaises(ValueError): + _ = ra[1] + + def test_index_array(self): + ra = RequestArgument('foo', [('a', 'b'), ('c', 'd')]) + it = ra[1] + self.assertIsInstance(it, _View) + self.assertEqual('foo', it.name) + self.assertEqual('c', it.get_content_type()) + self.assertEqual('d', it.get_str()) + + def test_view_immutable(self): + ra = RequestArgument('foo', ('bar', 'baz')) + it = ra[0] + self.assertIsInstance(it, _View) + with self.assertRaises(TypeError): + it.append('foo', 'bar') diff --git a/matemat/webserver/test/test_serve.py b/matemat/webserver/test/test_serve.py index 0556764..7e159e3 100644 --- a/matemat/webserver/test/test_serve.py +++ b/matemat/webserver/test/test_serve.py @@ -3,14 +3,14 @@ from typing import Any, Dict, Union import os import os.path -from matemat.webserver.httpd import HttpHandler +from matemat.webserver import HttpHandler, RequestArgument from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet @test_pagelet('/just/testing/serve_pagelet_ok') def serve_test_pagelet_ok(method: str, path: str, - args: Dict[str, Union[bytes, str]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str]): headers['Content-Type'] = 'text/plain' @@ -20,7 +20,7 @@ def serve_test_pagelet_ok(method: str, @test_pagelet('/just/testing/serve_pagelet_fail') def serve_test_pagelet_fail(method: str, path: str, - args: Dict[str, Union[bytes, str]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str]): session_vars['test'] = 'hello, world!' diff --git a/matemat/webserver/test/test_session.py b/matemat/webserver/test/test_session.py index 50ade85..fe30529 100644 --- a/matemat/webserver/test/test_session.py +++ b/matemat/webserver/test/test_session.py @@ -4,14 +4,14 @@ from typing import Any, Dict, Union from datetime import datetime, timedelta from time import sleep -from matemat.webserver.httpd import HttpHandler +from matemat.webserver import HttpHandler, RequestArgument from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet @test_pagelet('/just/testing/sessions') def session_test_pagelet(method: str, path: str, - args: Dict[str, Union[bytes, str]], + args: Dict[str, RequestArgument], session_vars: Dict[str, Any], headers: Dict[str, str]): session_vars['test'] = 'hello, world!' diff --git a/matemat/webserver/util.py b/matemat/webserver/util.py index 931f759..85ef721 100644 --- a/matemat/webserver/util.py +++ b/matemat/webserver/util.py @@ -3,8 +3,10 @@ from typing import Dict, List, Tuple, Optional, Union import urllib.parse +from matemat.webserver import RequestArgument -def _parse_multipart(body: bytes, boundary: str) -> Dict[str, List[Tuple[str, Union[bytes, str]]]]: + +def _parse_multipart(body: bytes, boundary: str) -> List[RequestArgument]: """ Given a HTTP body with form-data in multipart form, and the multipart-boundary, parse the multipart items and return them as a dictionary. @@ -13,6 +15,8 @@ def _parse_multipart(body: bytes, boundary: str) -> Dict[str, List[Tuple[str, Un :param boundary: The multipart boundary. :return: A dictionary of field names as key, and content types and field values as value. """ + # Prepend a CRLF for the first boundary to match + body = b'\r\n' + body # Generate item header boundary and terminating boundary from general boundary string _boundary = f'\r\n--{boundary}\r\n'.encode('utf-8') _end_boundary = f'\r\n--{boundary}--\r\n'.encode('utf-8') @@ -20,16 +24,15 @@ def _parse_multipart(body: bytes, boundary: str) -> Dict[str, List[Tuple[str, Un allparts = body.split(_end_boundary, 1) if len(allparts) != 2 or allparts[1] != b'': raise ValueError('Last boundary missing or corrupted') - # Split remaining body into its parts (appending a CRLF for the first boundary to match), and verify at least 1 part - # is there - parts: List[bytes] = (b'\r\n' + allparts[0]).split(_boundary) + # Split remaining body into its parts, and verify at least 1 part is there + parts: List[bytes] = (allparts[0]).split(_boundary) if len(parts) < 1 or parts[0] != b'': raise ValueError('First boundary missing or corrupted') # Remove the first, empty part parts = parts[1:] # Results go into this dict - args: Dict[str, List[Tuple[str, Union[bytes, str]]]] = dict() + args: Dict[str, RequestArgument] = dict() # Parse each multipart part for part in parts: @@ -50,25 +53,29 @@ def _parse_multipart(body: bytes, boundary: str) -> Dict[str, List[Tuple[str, Un cd, *cdargs = hdr['Content-Disposition'].split(';') # Content-Disposition MUST be form-data; everything else is rejected if cd.strip() != 'form-data': - raise ValueError(f'Unknown Content-Disposition: cd') + raise ValueError(f'Unknown Content-Disposition: {cd}') # Extract the "name" header argument + has_name = False for cdarg in cdargs: k, v = cdarg.split('=', 1) if k.strip() == 'name': + has_name = True name: str = v.strip() # Remove quotation marks around the name value if name.startswith('"') and name.endswith('"'): name = v[1:-1] # Add the Content-Type and the content to the header, with the provided name if name not in args: - args[name] = list() - args[name].append((hdr['Content-Type'].strip(), part)) + args[name] = RequestArgument(name) + args[name].append(hdr['Content-Type'].strip(), part) + if not has_name: + raise ValueError('mutlipart/form-data part without name attribute') - return args + return list(args.values()) def parse_args(request: str, postbody: Optional[bytes] = None, enctype: str = 'text/plain') \ - -> Tuple[str, Dict[str, Tuple[str, Union[bytes, str, List[str]]]]]: + -> Tuple[str, Dict[str, RequestArgument]]: """ Given a HTTP request path, and optionally a HTTP POST body in application/x-www-form-urlencoded or multipart/form-data form, parse the arguments and return them as a dictionary. @@ -85,34 +92,41 @@ def parse_args(request: str, postbody: Optional[bytes] = None, enctype: str = 't # Parse the request "URL" (i.e. only the path) tokens = urllib.parse.urlparse(request) # Parse the GET arguments - getargs = urllib.parse.parse_qs(tokens.query) + if len(tokens.query) == 0: + getargs = dict() + else: + getargs = urllib.parse.parse_qs(tokens.query, strict_parsing=True, keep_blank_values=True, errors='strict') - # TODO: { 'foo': [ ('text/plain', 'bar'), ('application/octet-stream', '\x80') ] } - # TODO: Use a @dataclass once Python 3.7 is out - args: Dict[str, Tuple[str, Union[bytes, str, List[str]]]] = dict() + args: Dict[str, RequestArgument] = dict() for k, v in getargs.items(): - args[k] = 'text/plain', v + args[k] = RequestArgument(k) + for _v in v: + args[k].append('text/plain', _v) if postbody is not None: if enctype == 'application/x-www-form-urlencoded': # Parse the POST body - postargs = urllib.parse.parse_qs(postbody.decode('utf-8')) + pb: str = postbody.decode('utf-8') + if len(pb) == 0: + postargs = dict() + else: + postargs = urllib.parse.parse_qs(pb, strict_parsing=True, keep_blank_values=True, errors='strict') # Write all POST values into the dict, overriding potential duplicates from GET for k, v in postargs.items(): - args[k] = 'text/plain', v + args[k] = RequestArgument(k) + for _v in v: + args[k].append('text/plain', _v) elif enctype.startswith('multipart/form-data'): # Parse the multipart boundary from the Content-Type header - boundary: str = enctype.split('boundary=')[1] + try: + boundary: str = enctype.split('boundary=')[1].strip() + except IndexError: + raise ValueError('Multipart boundary in header not set or corrupted') # Parse the multipart body mpargs = _parse_multipart(postbody, boundary) - for k, v in mpargs.items(): - # TODO: Process all values, not just the first - args[k] = v[0] + for ra in mpargs: + args[ra.name] = ra else: raise ValueError(f'Unsupported Content-Type: {enctype}') - # urllib.parse.parse_qs turns ALL arguments into arrays. This turns arrays of length 1 into scalar values - for (k, (ct, v)) in args.items(): - if len(v) == 1: - args[k] = ct, v[0] # Return the path and the parsed arguments return tokens.path, args From 8898abc77b9b9b4c90635598f358e57e1dfc2232 Mon Sep 17 00:00:00 2001 From: s3lph Date: Fri, 29 Jun 2018 01:12:25 +0200 Subject: [PATCH 4/9] Documentation of the RequestArgument class. --- matemat/webserver/requestargs.py | 164 +++++++++++++++++++++++++++++-- matemat/webserver/util.py | 1 + 2 files changed, 159 insertions(+), 6 deletions(-) diff --git a/matemat/webserver/requestargs.py b/matemat/webserver/requestargs.py index a35f759..02df0b2 100644 --- a/matemat/webserver/requestargs.py +++ b/matemat/webserver/requestargs.py @@ -1,121 +1,273 @@ -from typing import List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Tuple, Union class RequestArgument(object): + """ + Container class for HTTP request arguments that simplifies dealing with + - scalar and array arguments: + Automatically converts between single values and arrays where necessary: Arrays with one element can be + accessed as scalars, and scalars can be iterated, yielding themselves as a single item. + - UTF-8 strings and binary data (e.g. file uploads): + All data can be retrieved both as a str (if utf-8 decoding is possible) and a bytes object. + + The objects returned from iteration or indexing are immutable views of (parts of) this object. + + Usage example: + + qsargs = urllib.parse.parse_qs(qs, strict_parsing=True, keep_blank_values=True, errors='strict') + args: Dict[str, RequestArgument] = dict() + for k, vs in qsargs: + args[k] = RequestArgument(k) + for v in vs: + # text/plain usually is a sensible choice for values decoded from urlencoded strings + # IF ALREADY IN STRING FORM (which parse_qs does)! + args[k].append('text/plain', v) + + if 'username' in args and args['username'].is_scalar: + username: str = args['username'].get_str() + else: + raise ValueError() + + """ def __init__(self, name: str, value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]] = None) -> None: + """ + Create a new RequestArgument with a name and optionally an initial value. + + :param name: The name for this argument, as provided via GET or POST. + :param value: The initial value, if any. Optional, initializes with empty array if omitted. + """ + # Assign name self.__name: str = name + # Initialize value self.__value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]] = None + # Default to empty array if value is None: self.__value = [] else: if isinstance(value, list): if len(value) == 1: + # An array of length 1 will be reduced to a scalar self.__value = value[0] else: + # Store the array self.__value = value else: + # Scalar value, simply store self.__value = value @property def is_array(self) -> bool: + """ + :return: True, if the value is a (possibly empty) array, False otherwise. + """ return isinstance(self.__value, list) @property def is_scalar(self) -> bool: + """ + :return: True, if the value is a single scalar value, False otherwise. + """ return not isinstance(self.__value, list) @property def is_view(self) -> bool: + """ + :return: True, if this instance is an immutable view, False otherwise. + """ return False @property def name(self) -> str: + """ + :return: The name of this argument. + """ return self.__name - def get_str(self, index: int = None) -> Optional[str]: + def get_str(self, index: int = None) -> str: + """ + Attempts to return a value as a string. If this instance is an scalar, no index must be provided. If this + instance is an array, an index must be provided. + + :param index: For array values: The index of the value to retrieve. For scalar values: None (default). + :return: An UTF-8 string representation of the requested value. + :raises UnicodeDecodeError: If the value cannot be decoded into an UTF-8 string. + :raises IndexError: If the index is out of bounds. + :raises ValueError: If this is an array value, and no index is provided, or if this is a scalar value and an + index is provided. + """ if self.is_array: + # instance is an array value if index is None: + # Needs an index for array values raise ValueError('index must not be None') + # Type hint; access array element v: Tuple[str, Union[bytes, str]] = self.__value[index] if isinstance(v[1], str): + # The value already is a string, return return v[1] elif isinstance(v[1], bytes): + # The value is a bytes object, attempt to decode return v[1].decode('utf-8') else: + # instance is a scalar value if index is not None: + # Must not have an index for array values raise ValueError('index must be None') if isinstance(self.__value[1], str): + # The value already is a string, return return self.__value[1] elif isinstance(self.__value[1], bytes): + # The value is a bytes object, attempt to decode return self.__value[1].decode('utf-8') - def get_bytes(self, index: int = None) -> Optional[bytes]: + def get_bytes(self, index: int = None) -> bytes: + """ + Attempts to return a value as a bytes object. If this instance is an scalar, no index must be provided. If + this instance is an array, an index must be provided. + + :param index: For array values: The index of the value to retrieve. For scalar values: None (default). + :return: A bytes object representation of the requested value. Strings will be encoded as UTF-8. + :raises IndexError: If the index is out of bounds. + :raises ValueError: If this is an array value, and no index is provided, or if this is a scalar value and an + index is provided. + """ if self.is_array: + # instance is an array value if index is None: + # Needs an index for array values raise ValueError('index must not be None') + # Type hint; access array element v: Tuple[str, Union[bytes, str]] = self.__value[index] if isinstance(v[1], bytes): + # The value already is a bytes object, return return v[1] elif isinstance(v[1], str): + # The value is a string, encode first return v[1].encode('utf-8') else: + # instance is a scalar value if index is not None: + # Must not have an index for array values raise ValueError('index must be None') if isinstance(self.__value[1], bytes): + # The value already is a bytes object, return return self.__value[1] elif isinstance(self.__value[1], str): + # The value is a string, encode first return self.__value[1].encode('utf-8') def get_content_type(self, index: int = None) -> Optional[str]: + """ + Attempts to retrieve a value's Content-Type. If this instance is an scalar, no index must be provided. If this + instance is an array, an index must be provided. + + :param index: For array values: The index of the value to retrieve. For scalar values: None (default). + :return: The Content-Type of the requested value, as sent by the client. Not necessarily trustworthy. + :raises IndexError: If the index is out of bounds. + :raises ValueError: If this is an array value, and no index is provided, or if this is a scalar value and an + index is provided. + """ if self.is_array: + # instance is an array value if index is None: + # Needs an index for array values raise ValueError('index must not be None') + # Type hint; access array element v: Tuple[str, Union[bytes, str]] = self.__value[index] + # Return the content type of the requested value return v[0] else: + # instance is a scalar value if index is not None: + # Must not have an index for array values raise ValueError('index must be None') + # Return the content type of the scalar value return self.__value[0] def append(self, ctype: str, value: Union[str, bytes]): + """ + Append a value to this instance. Turns an empty argument into a scalar and a scalar into an array. + + :param ctype: The Content-Type, as provided in the request. + :param value: The scalar value to append, either a string or bytes object. + :raises TypeError: If called on an immutable view. + """ if self.is_view: + # This is an immutable view, raise exception raise TypeError('A RequestArgument view is immutable!') if len(self) == 0: + # Turn an empty argument into a scalar self.__value = ctype, value else: + # First turn the scalar into a one-element array ... if self.is_scalar: self.__value = [self.__value] + # ... then append the new value self.__value.append((ctype, value)) - def __len__(self): + def __len__(self) -> int: + """ + :return: Number of values for this argument. + """ return len(self.__value) if self.is_array else 1 - def __iter__(self): + def __iter__(self) -> Iterator['RequestArgument']: + """ + Iterate the values of this argument. Each value is accessible as if it were a scalar RequestArgument in turn, + although they are immutable. + + :return: An iterator that yields immutable views of the values. + """ if self.is_scalar: + # If this is a scalar, yield an immutable view of the single value yield _View(self.__name, self.__value) else: # Typing helper _value: List[Tuple[str, Union[bytes, str]]] = self.__value for v in _value: + # If this is an array, yield an immutable scalar view for each (ctype, value) element in the array yield _View(self.__name, v) def __getitem__(self, index: Union[int, slice]): + """ + Index the argument with either an int or a slice. The returned values are represented as immutable + RequestArgument views. Scalar arguments may be indexed with int(0). + :param index: The index or slice. + :return: An immutable view of the indexed elements of this argument. + """ if self.is_scalar: + # Scalars may only be indexed with 0 if index == 0: + # Return an immutable view of the single scalar value return _View(self.__name, self.__value) raise ValueError('Scalar RequestArgument only indexable with 0') + # Pass the index or slice through to the array, packing the result in an immutable view return _View(self.__name, self.__value[index]) class _View(RequestArgument): + """ + This class represents an immutable view of a (subset of a) RequestArgument object. Should not be instantiated + directly. + """ - def __init__(self, name: str, value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]]): + def __init__(self, name: str, value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]])\ + -> None: + """ + Create a new immutable view of a (subset of a) RequestArgument. + + :param name: The name for this argument, same as in the original RequestArgument. + :param value: The values to represent in this view, obtained by e.g. indexing or slicing. + """ super().__init__(name, value) @property def is_view(self) -> bool: + """ + :return: True, if this instance is an immutable view, False otherwise. + """ return True diff --git a/matemat/webserver/util.py b/matemat/webserver/util.py index 85ef721..61cf430 100644 --- a/matemat/webserver/util.py +++ b/matemat/webserver/util.py @@ -69,6 +69,7 @@ def _parse_multipart(body: bytes, boundary: str) -> List[RequestArgument]: args[name] = RequestArgument(name) args[name].append(hdr['Content-Type'].strip(), part) if not has_name: + # Content-Disposition header without name attribute raise ValueError('mutlipart/form-data part without name attribute') return list(args.values()) From ab9e470c353d1b36fb55bd72dc0cba62277a22ba Mon Sep 17 00:00:00 2001 From: s3lph Date: Fri, 29 Jun 2018 01:22:12 +0200 Subject: [PATCH 5/9] Some more type hinting/safety. --- matemat/webserver/requestargs.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/matemat/webserver/requestargs.py b/matemat/webserver/requestargs.py index 02df0b2..f69a53d 100644 --- a/matemat/webserver/requestargs.py +++ b/matemat/webserver/requestargs.py @@ -43,7 +43,7 @@ class RequestArgument(object): # Assign name self.__name: str = name # Initialize value - self.__value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]] = None + self.__value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]] = [] # Default to empty array if value is None: self.__value = [] @@ -98,6 +98,7 @@ class RequestArgument(object): :raises IndexError: If the index is out of bounds. :raises ValueError: If this is an array value, and no index is provided, or if this is a scalar value and an index is provided. + :raises TypeError: If the requested value is neither a str nor a bytes object. """ if self.is_array: # instance is an array value @@ -123,6 +124,7 @@ class RequestArgument(object): elif isinstance(self.__value[1], bytes): # The value is a bytes object, attempt to decode return self.__value[1].decode('utf-8') + raise TypeError('Value is neither a str nor bytes') def get_bytes(self, index: int = None) -> bytes: """ @@ -134,6 +136,7 @@ class RequestArgument(object): :raises IndexError: If the index is out of bounds. :raises ValueError: If this is an array value, and no index is provided, or if this is a scalar value and an index is provided. + :raises TypeError: If the requested value is neither a str nor a bytes object. """ if self.is_array: # instance is an array value @@ -159,6 +162,7 @@ class RequestArgument(object): elif isinstance(self.__value[1], str): # The value is a string, encode first return self.__value[1].encode('utf-8') + raise TypeError('Value is neither a str nor bytes') def get_content_type(self, index: int = None) -> Optional[str]: """ @@ -177,16 +181,18 @@ class RequestArgument(object): # Needs an index for array values raise ValueError('index must not be None') # Type hint; access array element - v: Tuple[str, Union[bytes, str]] = self.__value[index] + va: Tuple[str, Union[bytes, str]] = self.__value[index] # Return the content type of the requested value - return v[0] + return va[0] else: # instance is a scalar value if index is not None: # Must not have an index for array values raise ValueError('index must be None') + # Type hint + vs: Tuple[str, Union[bytes, str]] = self.__value # Return the content type of the scalar value - return self.__value[0] + return vs[0] def append(self, ctype: str, value: Union[str, bytes]): """ @@ -227,8 +233,8 @@ class RequestArgument(object): yield _View(self.__name, self.__value) else: # Typing helper - _value: List[Tuple[str, Union[bytes, str]]] = self.__value - for v in _value: + vs: List[Tuple[str, Union[bytes, str]]] = self.__value + for v in vs: # If this is an array, yield an immutable scalar view for each (ctype, value) element in the array yield _View(self.__name, v) From 73c7dbe89f28dda245322153ffdc5404736eebe0 Mon Sep 17 00:00:00 2001 From: s3lph Date: Fri, 29 Jun 2018 18:11:26 +0200 Subject: [PATCH 6/9] Documentation of RequestArgument test cases. --- matemat/webserver/requestargs.py | 2 +- matemat/webserver/test/test_requestargs.py | 148 ++++++++++++++++++++- 2 files changed, 145 insertions(+), 5 deletions(-) diff --git a/matemat/webserver/requestargs.py b/matemat/webserver/requestargs.py index f69a53d..1a56aeb 100644 --- a/matemat/webserver/requestargs.py +++ b/matemat/webserver/requestargs.py @@ -250,7 +250,7 @@ class RequestArgument(object): if index == 0: # Return an immutable view of the single scalar value return _View(self.__name, self.__value) - raise ValueError('Scalar RequestArgument only indexable with 0') + raise IndexError('Scalar RequestArgument only indexable with 0') # Pass the index or slice through to the array, packing the result in an immutable view return _View(self.__name, self.__value[index]) diff --git a/matemat/webserver/test/test_requestargs.py b/matemat/webserver/test/test_requestargs.py index dcdde14..133ea4a 100644 --- a/matemat/webserver/test/test_requestargs.py +++ b/matemat/webserver/test/test_requestargs.py @@ -9,196 +9,336 @@ from matemat.webserver.requestargs import _View class TestRequestArguments(unittest.TestCase): + """ + Test cases for the RequestArgument class. + """ def test_create_default(self): + """ + Test creation of an empty RequestArgument + """ ra = RequestArgument('foo') + # Name must be set to 1st argument self.assertEqual('foo', ra.name) + # Must be a 0-length array self.assertEqual(0, len(ra)) self.assertFalse(ra.is_scalar) self.assertTrue(ra.is_array) + # Must not be a view self.assertFalse(ra.is_view) def test_create_str_scalar(self): + """ + Test creation of a scalar RequestArgument with string value. + """ ra = RequestArgument('foo', ('text/plain', 'bar')) + # Name must be set to 1st argument self.assertEqual('foo', ra.name) + # Must be a scalar, length 1 self.assertEqual(1, len(ra)) self.assertTrue(ra.is_scalar) self.assertFalse(ra.is_array) + # Scalar value must be representable both as str and bytes self.assertEqual('bar', ra.get_str()) self.assertEqual(b'bar', ra.get_bytes()) + # Content-Type must be set correctly self.assertEqual('text/plain', ra.get_content_type()) + # Using indices must result in an error with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_str(0)) with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_bytes(0)) with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_content_type(0)) + # Must not be a view self.assertFalse(ra.is_view) def test_create_str_scalar_array(self): + """ + Test creation of a scalar RequestArgument with string value, passing an array instead of a single tuple. + """ ra = RequestArgument('foo', [('text/plain', 'bar')]) + # Name must be set to 1st argument self.assertEqual('foo', ra.name) + # Must be a scalar, length 1 self.assertEqual(1, len(ra)) self.assertTrue(ra.is_scalar) self.assertFalse(ra.is_array) + # Scalar value must be representable both as str and bytes self.assertEqual('bar', ra.get_str()) self.assertEqual(b'bar', ra.get_bytes()) + # Content-Type must be set correctly self.assertEqual('text/plain', ra.get_content_type()) + # Using indices must result in an error with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_str(0)) with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_bytes(0)) with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_content_type(0)) + # Must not be a view self.assertFalse(ra.is_view) def test_create_bytes_scalar(self): + """ + Test creation of a scalar RequestArgument with bytes value. + """ ra = RequestArgument('foo', ('application/octet-stream', b'\x00\x80\xff\xfe')) + # Name must be set to 1st argument self.assertEqual('foo', ra.name) + # Must be a scalar, length 1 self.assertEqual(1, len(ra)) self.assertTrue(ra.is_scalar) self.assertFalse(ra.is_array) + # Conversion to UTF-8 string must fail; bytes representation must work with self.assertRaises(UnicodeDecodeError): ra.get_str() self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes()) + # Content-Type must be set correctly self.assertEqual('application/octet-stream', ra.get_content_type()) + # Using indices must result in an error with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_str(0)) with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_bytes(0)) with self.assertRaises(ValueError): self.assertEqual('bar', ra.get_content_type(0)) + # Must not be a view self.assertFalse(ra.is_view) def test_create_array(self): + """ + Test creation of an array RequestArgument with mixed str and bytes initial value. + """ ra = RequestArgument('foo', [ ('text/plain', 'bar'), ('application/octet-stream', b'\x00\x80\xff\xfe') ]) + # Name must be set to 1st argument self.assertEqual('foo', ra.name) + # Must be an array, length 2 self.assertEqual(2, len(ra)) self.assertFalse(ra.is_scalar) self.assertTrue(ra.is_array) + # Retrieving values without an index must fail with self.assertRaises(ValueError): ra.get_str() with self.assertRaises(ValueError): ra.get_bytes() with self.assertRaises(ValueError): ra.get_content_type() + # The first value must be representable both as str and bytes, and have ctype text/plain self.assertEqual('bar', ra.get_str(0)) self.assertEqual(b'bar', ra.get_bytes(0)) self.assertEqual('text/plain', ra.get_content_type(0)) + # Conversion of the second value to UTF-8 string must fail; bytes representation must work with self.assertRaises(UnicodeDecodeError): ra.get_str(1) self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes(1)) + # The second value's ctype must be correct self.assertEqual('application/octet-stream', ra.get_content_type(1)) + # Must not be a view self.assertFalse(ra.is_view) def test_append_empty_str(self): + """ + Test appending a str value to an empty RequestArgument. + """ + # Initialize the empty RequestArgument ra = RequestArgument('foo') self.assertEqual(0, len(ra)) self.assertFalse(ra.is_scalar) + # Append a string value ra.append('text/plain', 'bar') + # New length must be 1, empty array must be converted to scalar self.assertEqual(1, len(ra)) self.assertTrue(ra.is_scalar) + # Retrieval of the new value must work both in str and bytes representation self.assertEqual('bar', ra.get_str()) self.assertEqual(b'bar', ra.get_bytes()) + # Content type of the new value must be correct self.assertEqual('text/plain', ra.get_content_type()) + # Must not be a view self.assertFalse(ra.is_view) def test_append_empty_bytes(self): + """ + Test appending a bytes value to an empty RequestArgument. + """ + # Initialize the empty RequestArgument ra = RequestArgument('foo') self.assertEqual(0, len(ra)) self.assertFalse(ra.is_scalar) + # Append a bytes value ra.append('application/octet-stream', b'\x00\x80\xff\xfe') + # New length must be 1, empty array must be converted to scalar self.assertEqual(1, len(ra)) self.assertTrue(ra.is_scalar) + # Conversion of the new value to UTF-8 string must fail; bytes representation must work with self.assertRaises(UnicodeDecodeError): ra.get_str() self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes()) + # Content type of the new value must be correct self.assertEqual('application/octet-stream', ra.get_content_type()) + # Must not be a view self.assertFalse(ra.is_view) def test_append_multiple(self): + """ + Test appending multiple values to an empty RequestArgument. + """ + # Initialize the empty RequestArgument ra = RequestArgument('foo') self.assertEqual(0, len(ra)) self.assertFalse(ra.is_scalar) + # Append a first value ra.append('text/plain', 'bar') + # New length must be 1, empty array must be converted to scalar self.assertEqual(1, len(ra)) self.assertTrue(ra.is_scalar) + self.assertEqual(b'bar', ra.get_bytes()) + # Append a second value ra.append('application/octet-stream', b'\x00\x80\xff\xfe') + # New length must be 2, scalar must be converted to array self.assertEqual(2, len(ra)) self.assertFalse(ra.is_scalar) + self.assertEqual(b'bar', ra.get_bytes(0)) + self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes(1)) + # Append a third value ra.append('text/plain', 'Hello, World!') + # New length must be 3, array must remain array self.assertEqual(3, len(ra)) self.assertFalse(ra.is_scalar) + self.assertEqual(b'bar', ra.get_bytes(0)) + self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes(1)) + self.assertEqual(b'Hello, World!', ra.get_bytes(2)) def test_iterate_empty(self): + """ + Test iterating an empty RequestArgument. + """ ra = RequestArgument('foo') self.assertEqual(0, len(ra)) + # No value must be yielded from iterating an empty instance for _ in ra: self.fail() def test_iterate_scalar(self): + """ + Test iterating a scalar RequestArgument. + """ ra = RequestArgument('foo', ('text/plain', 'bar')) self.assertTrue(ra.is_scalar) + # Counter for the number of iterations count: int = 0 for it in ra: + # Make sure the yielded value is a scalar view and has the same name as the original instance self.assertIsInstance(it, _View) - self.assertEqual('foo', it.name) self.assertTrue(it.is_view) + self.assertEqual('foo', it.name) self.assertTrue(it.is_scalar) count += 1 + # Only one value must be yielded from iterating a scalar instance self.assertEqual(1, count) def test_iterate_array(self): + """ + Test iterating an array RequestArgument. + """ ra = RequestArgument('foo', [('text/plain', 'bar'), ('abc', b'def'), ('xyz', '1337')]) self.assertFalse(ra.is_scalar) + # Container to put the iterated ctypes into items: List[str] = list() for it in ra: + # Make sure the yielded values are scalar views and have the same name as the original instance self.assertIsInstance(it, _View) self.assertTrue(it.is_view) self.assertTrue(it.is_scalar) + # Collect the value's ctype items.append(it.get_content_type()) + # Compare collected ctypes with expected result self.assertEqual(['text/plain', 'abc', 'xyz'], items) - def test_iterate_sliced(self): + def test_slice(self): + """ + Test slicing an array RequestArgument. + """ ra = RequestArgument('foo', [('a', 'b'), ('c', 'd'), ('e', 'f'), ('g', 'h'), ('i', 'j'), ('k', 'l')]) - self.assertFalse(ra.is_scalar) + # Create the sliced view + sliced = ra[1:4:2] + # Make sure the sliced value is a view + self.assertIsInstance(sliced, _View) + self.assertTrue(sliced.is_view) + # Make sure the slice has the same name + self.assertEqual('foo', sliced.name) + # Make sure the slice has the expected shape (array of the 2nd and 4th scalar in the original) + self.assertTrue(sliced.is_array) + self.assertEqual(2, len(sliced)) + self.assertEqual('d', sliced.get_str(0)) + self.assertEqual('h', sliced.get_str(1)) + + def test_iterate_sliced(self): + """ + Test iterating a sliced array RequestArgument. + """ + ra = RequestArgument('foo', [('a', 'b'), ('c', 'd'), ('e', 'f'), ('g', 'h'), ('i', 'j'), ('k', 'l')]) + # Container to put the iterated ctypes into items: List[str] = list() + # Iterate the sliced view for it in ra[1:4:2]: + # Make sure the yielded values are scalar views and have the same name as the original instance self.assertIsInstance(it, _View) self.assertTrue(it.is_view) + self.assertEqual('foo', it.name) self.assertTrue(it.is_scalar) items.append(it.get_content_type()) + # Make sure the expected values are collected (array of the 2nd and 4th scalar in the original) self.assertEqual(['c', 'g'], items) def test_index_scalar(self): + """ + Test indexing of a scalar RequestArgument. + """ ra = RequestArgument('foo', ('bar', 'baz')) + # Index the scalar RequestArgument instance, obtaining an immutable view it = ra[0] + # Make sure the value is a scalar view with the same properties as the original instance self.assertIsInstance(it, _View) + self.assertTrue(it.is_scalar) self.assertEqual('foo', it.name) self.assertEqual('bar', it.get_content_type()) self.assertEqual('baz', it.get_str()) - with self.assertRaises(ValueError): + # Make sure other indices don't work + with self.assertRaises(IndexError): _ = ra[1] def test_index_array(self): + """ + Test indexing of an array RequestArgument. + """ ra = RequestArgument('foo', [('a', 'b'), ('c', 'd')]) + # Index the array RequestArgument instance, obtaining an immutable view it = ra[1] + # Make sure the value is a scalar view with the same properties as the value in the original instance self.assertIsInstance(it, _View) self.assertEqual('foo', it.name) self.assertEqual('c', it.get_content_type()) self.assertEqual('d', it.get_str()) def test_view_immutable(self): + """ + Test immutability of views. + """ ra = RequestArgument('foo', ('bar', 'baz')) + # Index the scalar RequestArgument instance, obtaining an immutable view it = ra[0] + # Make sure the returned value is a view self.assertIsInstance(it, _View) + # Make sure the returned value is immutable with self.assertRaises(TypeError): it.append('foo', 'bar') From 21a927046d43b37c883973c10efc54b43d70d646 Mon Sep 17 00:00:00 2001 From: s3lph Date: Fri, 29 Jun 2018 18:29:51 +0200 Subject: [PATCH 7/9] Reworked RequestArgument API to somewhat more lax concerning 0-indices, potentially leading to safer code. --- matemat/webserver/requestargs.py | 177 +++++++-------------- matemat/webserver/test/test_requestargs.py | 66 ++++---- 2 files changed, 94 insertions(+), 149 deletions(-) diff --git a/matemat/webserver/requestargs.py b/matemat/webserver/requestargs.py index 1a56aeb..dcad518 100644 --- a/matemat/webserver/requestargs.py +++ b/matemat/webserver/requestargs.py @@ -43,35 +43,31 @@ class RequestArgument(object): # Assign name self.__name: str = name # Initialize value - self.__value: Union[Tuple[str, Union[bytes, str]], List[Tuple[str, Union[bytes, str]]]] = [] + self.__value: List[Tuple[str, Union[bytes, str]]] = [] # Default to empty array if value is None: self.__value = [] else: if isinstance(value, list): - if len(value) == 1: - # An array of length 1 will be reduced to a scalar - self.__value = value[0] - else: - # Store the array - self.__value = value - else: - # Scalar value, simply store + # Store the array self.__value = value + else: + # Turn scalar into an array before storing + self.__value = [value] @property def is_array(self) -> bool: """ :return: True, if the value is a (possibly empty) array, False otherwise. """ - return isinstance(self.__value, list) + return len(self.__value) != 1 @property def is_scalar(self) -> bool: """ :return: True, if the value is a single scalar value, False otherwise. """ - return not isinstance(self.__value, list) + return len(self.__value) == 1 @property def is_view(self) -> bool: @@ -87,112 +83,70 @@ class RequestArgument(object): """ return self.__name - def get_str(self, index: int = None) -> str: + def get_str(self, index: int = 0) -> str: """ - Attempts to return a value as a string. If this instance is an scalar, no index must be provided. If this - instance is an array, an index must be provided. + Attempts to return a value as a string. The index defaults to 0. - :param index: For array values: The index of the value to retrieve. For scalar values: None (default). + :param index: The index of the value to retrieve. Default: 0. :return: An UTF-8 string representation of the requested value. :raises UnicodeDecodeError: If the value cannot be decoded into an UTF-8 string. :raises IndexError: If the index is out of bounds. - :raises ValueError: If this is an array value, and no index is provided, or if this is a scalar value and an - index is provided. + :raises ValueError: If the index is not an int. :raises TypeError: If the requested value is neither a str nor a bytes object. """ - if self.is_array: - # instance is an array value - if index is None: - # Needs an index for array values - raise ValueError('index must not be None') - # Type hint; access array element - v: Tuple[str, Union[bytes, str]] = self.__value[index] - if isinstance(v[1], str): - # The value already is a string, return - return v[1] - elif isinstance(v[1], bytes): - # The value is a bytes object, attempt to decode - return v[1].decode('utf-8') - else: - # instance is a scalar value - if index is not None: - # Must not have an index for array values - raise ValueError('index must be None') - if isinstance(self.__value[1], str): - # The value already is a string, return - return self.__value[1] - elif isinstance(self.__value[1], bytes): - # The value is a bytes object, attempt to decode - return self.__value[1].decode('utf-8') + if not isinstance(index, int): + # Index must be an int + raise ValueError('index must not be None') + # Type hint; access array element + v: Tuple[str, Union[bytes, str]] = self.__value[index] + if isinstance(v[1], str): + # The value already is a string, return + return v[1] + elif isinstance(v[1], bytes): + # The value is a bytes object, attempt to decode + return v[1].decode('utf-8') raise TypeError('Value is neither a str nor bytes') - def get_bytes(self, index: int = None) -> bytes: + def get_bytes(self, index: int = 0) -> bytes: """ - Attempts to return a value as a bytes object. If this instance is an scalar, no index must be provided. If - this instance is an array, an index must be provided. + Attempts to return a value as a bytes object. The index defaults to 0. - :param index: For array values: The index of the value to retrieve. For scalar values: None (default). + :param index: The index of the value to retrieve. Default: 0. :return: A bytes object representation of the requested value. Strings will be encoded as UTF-8. :raises IndexError: If the index is out of bounds. - :raises ValueError: If this is an array value, and no index is provided, or if this is a scalar value and an - index is provided. + :raises ValueError: If the index is not an int. :raises TypeError: If the requested value is neither a str nor a bytes object. """ - if self.is_array: - # instance is an array value - if index is None: - # Needs an index for array values - raise ValueError('index must not be None') - # Type hint; access array element - v: Tuple[str, Union[bytes, str]] = self.__value[index] - if isinstance(v[1], bytes): - # The value already is a bytes object, return - return v[1] - elif isinstance(v[1], str): - # The value is a string, encode first - return v[1].encode('utf-8') - else: - # instance is a scalar value - if index is not None: - # Must not have an index for array values - raise ValueError('index must be None') - if isinstance(self.__value[1], bytes): - # The value already is a bytes object, return - return self.__value[1] - elif isinstance(self.__value[1], str): - # The value is a string, encode first - return self.__value[1].encode('utf-8') + if not isinstance(index, int): + # Index must be a int + raise ValueError('index must not be None') + # Type hint; access array element + v: Tuple[str, Union[bytes, str]] = self.__value[index] + if isinstance(v[1], bytes): + # The value already is a bytes object, return + return v[1] + elif isinstance(v[1], str): + # The value is a string, encode first + return v[1].encode('utf-8') raise TypeError('Value is neither a str nor bytes') - def get_content_type(self, index: int = None) -> Optional[str]: + def get_content_type(self, index: int = 0) -> str: """ - Attempts to retrieve a value's Content-Type. If this instance is an scalar, no index must be provided. If this - instance is an array, an index must be provided. + Attempts to retrieve a value's Content-Type. The index defaults to 0. - :param index: For array values: The index of the value to retrieve. For scalar values: None (default). + :param index: The index of the value to retrieve. Default: 0. :return: The Content-Type of the requested value, as sent by the client. Not necessarily trustworthy. :raises IndexError: If the index is out of bounds. - :raises ValueError: If this is an array value, and no index is provided, or if this is a scalar value and an - index is provided. + :raises ValueError: If the index is not an int. """ - if self.is_array: - # instance is an array value - if index is None: - # Needs an index for array values - raise ValueError('index must not be None') - # Type hint; access array element - va: Tuple[str, Union[bytes, str]] = self.__value[index] - # Return the content type of the requested value - return va[0] - else: - # instance is a scalar value - if index is not None: - # Must not have an index for array values - raise ValueError('index must be None') - # Type hint - vs: Tuple[str, Union[bytes, str]] = self.__value - # Return the content type of the scalar value - return vs[0] + # instance is an array value + if not isinstance(index, int): + # Needs an index for array values + raise ValueError('index must not be None') + # Type hint; access array element + va: Tuple[str, Union[bytes, str]] = self.__value[index] + # Return the content type of the requested value + return va[0] def append(self, ctype: str, value: Union[str, bytes]): """ @@ -205,21 +159,13 @@ class RequestArgument(object): if self.is_view: # This is an immutable view, raise exception raise TypeError('A RequestArgument view is immutable!') - if len(self) == 0: - # Turn an empty argument into a scalar - self.__value = ctype, value - else: - # First turn the scalar into a one-element array ... - if self.is_scalar: - self.__value = [self.__value] - # ... then append the new value - self.__value.append((ctype, value)) + self.__value.append((ctype, value)) def __len__(self) -> int: """ :return: Number of values for this argument. """ - return len(self.__value) if self.is_array else 1 + return len(self.__value) def __iter__(self) -> Iterator['RequestArgument']: """ @@ -228,29 +174,18 @@ class RequestArgument(object): :return: An iterator that yields immutable views of the values. """ - if self.is_scalar: - # If this is a scalar, yield an immutable view of the single value - yield _View(self.__name, self.__value) - else: - # Typing helper - vs: List[Tuple[str, Union[bytes, str]]] = self.__value - for v in vs: - # If this is an array, yield an immutable scalar view for each (ctype, value) element in the array - yield _View(self.__name, v) + for v in self.__value: + # Yield an immutable scalar view for each (ctype, value) element in the array + yield _View(self.__name, v) def __getitem__(self, index: Union[int, slice]): """ Index the argument with either an int or a slice. The returned values are represented as immutable - RequestArgument views. Scalar arguments may be indexed with int(0). + RequestArgument views. + :param index: The index or slice. :return: An immutable view of the indexed elements of this argument. """ - if self.is_scalar: - # Scalars may only be indexed with 0 - if index == 0: - # Return an immutable view of the single scalar value - return _View(self.__name, self.__value) - raise IndexError('Scalar RequestArgument only indexable with 0') # Pass the index or slice through to the array, packing the result in an immutable view return _View(self.__name, self.__value[index]) diff --git a/matemat/webserver/test/test_requestargs.py b/matemat/webserver/test/test_requestargs.py index 133ea4a..3383863 100644 --- a/matemat/webserver/test/test_requestargs.py +++ b/matemat/webserver/test/test_requestargs.py @@ -43,13 +43,17 @@ class TestRequestArguments(unittest.TestCase): self.assertEqual(b'bar', ra.get_bytes()) # Content-Type must be set correctly self.assertEqual('text/plain', ra.get_content_type()) - # Using indices must result in an error - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_str(0)) - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_bytes(0)) - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_content_type(0)) + # Using 0 indices must yield the same results + self.assertEqual('bar', ra.get_str(0)) + self.assertEqual(b'bar', ra.get_bytes(0)) + self.assertEqual('text/plain', ra.get_content_type(0)) + # Using other indices must result in an error + with self.assertRaises(IndexError): + ra.get_str(1) + with self.assertRaises(IndexError): + ra.get_bytes(1) + with self.assertRaises(IndexError): + ra.get_content_type(1) # Must not be a view self.assertFalse(ra.is_view) @@ -69,13 +73,17 @@ class TestRequestArguments(unittest.TestCase): self.assertEqual(b'bar', ra.get_bytes()) # Content-Type must be set correctly self.assertEqual('text/plain', ra.get_content_type()) - # Using indices must result in an error - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_str(0)) - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_bytes(0)) - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_content_type(0)) + # Using 0 indices must yield the same results + self.assertEqual('bar', ra.get_str(0)) + self.assertEqual(b'bar', ra.get_bytes(0)) + self.assertEqual('text/plain', ra.get_content_type(0)) + # Using other indices must result in an error + with self.assertRaises(IndexError): + ra.get_str(1) + with self.assertRaises(IndexError): + ra.get_bytes(1) + with self.assertRaises(IndexError): + ra.get_content_type(1) # Must not be a view self.assertFalse(ra.is_view) @@ -96,13 +104,18 @@ class TestRequestArguments(unittest.TestCase): self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes()) # Content-Type must be set correctly self.assertEqual('application/octet-stream', ra.get_content_type()) - # Using indices must result in an error - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_str(0)) - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_bytes(0)) - with self.assertRaises(ValueError): - self.assertEqual('bar', ra.get_content_type(0)) + # Using 0 indices must yield the same results + with self.assertRaises(UnicodeDecodeError): + ra.get_str(0) + self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes(0)) + self.assertEqual('application/octet-stream', ra.get_content_type(0)) + # Using other indices must result in an error + with self.assertRaises(IndexError): + ra.get_str(1) + with self.assertRaises(IndexError): + ra.get_bytes(1) + with self.assertRaises(IndexError): + ra.get_content_type(1) # Must not be a view self.assertFalse(ra.is_view) @@ -120,13 +133,10 @@ class TestRequestArguments(unittest.TestCase): self.assertEqual(2, len(ra)) self.assertFalse(ra.is_scalar) self.assertTrue(ra.is_array) - # Retrieving values without an index must fail - with self.assertRaises(ValueError): - ra.get_str() - with self.assertRaises(ValueError): - ra.get_bytes() - with self.assertRaises(ValueError): - ra.get_content_type() + # Retrieving values without an index must yield the first element + self.assertEqual('bar', ra.get_str()) + self.assertEqual(b'bar', ra.get_bytes()) + self.assertEqual('text/plain', ra.get_content_type()) # The first value must be representable both as str and bytes, and have ctype text/plain self.assertEqual('bar', ra.get_str(0)) self.assertEqual(b'bar', ra.get_bytes(0)) From 2f927cec41d282a9b6e1485b80b8f5a015fde42b Mon Sep 17 00:00:00 2001 From: s3lph Date: Fri, 29 Jun 2018 21:56:22 +0200 Subject: [PATCH 8/9] Implemented a container for RequestArgument instances, with some more unit tests. --- matemat/webserver/__init__.py | 2 +- matemat/webserver/requestargs.py | 140 ++++++++++++++-- matemat/webserver/test/test_requestargs.py | 179 ++++++++++++++++++++- 3 files changed, 303 insertions(+), 18 deletions(-) diff --git a/matemat/webserver/__init__.py b/matemat/webserver/__init__.py index 1b4ab06..c52368e 100644 --- a/matemat/webserver/__init__.py +++ b/matemat/webserver/__init__.py @@ -6,5 +6,5 @@ API that can be used by 'pagelets' - single pages of a web service. If a reques server will attempt to serve the request with a static resource in a previously configured webroot directory. """ -from .requestargs import RequestArgument +from .requestargs import RequestArgument, RequestArguments from .httpd import MatematWebserver, HttpHandler, pagelet diff --git a/matemat/webserver/requestargs.py b/matemat/webserver/requestargs.py index dcad518..2150b31 100644 --- a/matemat/webserver/requestargs.py +++ b/matemat/webserver/requestargs.py @@ -1,5 +1,79 @@ -from typing import Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Tuple, Union + + +class RequestArguments(object): + """ + Container for HTTP Request arguments. + + Usage: + + # Create empty instance + ra = RequestArguments() + # Add an entry for the key 'foo' with the value 'bar' and Content-Type 'text/plain' + ra['foo'].append('text/plain', 'bar') + # Retrieve the value for the key 'foo', as a string... + foo = str(ra.foo) + # ... or as raw bytes + foo = bytes(ra.foo) + """ + + def __init__(self) -> None: + """ + Create an empty container instance. + """ + self.__container: Dict[str, RequestArgument] = dict() + + def __getitem__(self, key: str) -> 'RequestArgument': + """ + Retrieve the argument for the given name, creating it on the fly, if it doesn't exist. + + :param key: Name of the argument to retrieve. + :return: A RequestArgument instance. + :raises TypeError: If key is not a string. + """ + if not isinstance(key, str): + raise TypeError('key must be a str') + # Create empty argument, if it doesn't exist + if key not in self.__container: + self.__container[key] = RequestArgument(key) + # Return the argument for the name + return self.__container[key] + + def __getattr__(self, key: str) -> 'RequestArgument': + """ + Syntactic sugar for accessing values with a name that can be used in Python attributes. The value will be + returned as an immutable view. + + :param key: Name of the argument to retrieve. + :return: An immutable view of the RequestArgument instance. + """ + return _View.of(self.__container[key]) + + def __iter__(self) -> Iterator['RequestArguments']: + """ + Returns an iterator over the values in this instance. Values are represented as immutable views. + + :return: An iterator that yields immutable views of the values. + """ + for ra in self.__container.values(): + # Yield an immutable scalar view for each value + yield _View.of(ra) + + def __contains__(self, key: str) -> bool: + """ + Checks whether an argument with a given name exists in the RequestArguments instance. + + :param key: The name to check whether it exists. + :return: True, if present, False otherwise. + """ + return key in self.__container + + def __len__(self) -> int: + """ + :return: The number of arguments in this instance. + """ + return len(self.__container) class RequestArgument(object): @@ -16,18 +90,16 @@ class RequestArgument(object): Usage example: qsargs = urllib.parse.parse_qs(qs, strict_parsing=True, keep_blank_values=True, errors='strict') - args: Dict[str, RequestArgument] = dict() + args: RequestArguments for k, vs in qsargs: - args[k] = RequestArgument(k) + args[k].clear() for v in vs: # text/plain usually is a sensible choice for values decoded from urlencoded strings # IF ALREADY IN STRING FORM (which parse_qs does)! args[k].append('text/plain', v) - if 'username' in args and args['username'].is_scalar: - username: str = args['username'].get_str() - else: - raise ValueError() + if 'username' in args and args.username.is_scalar: + username = str(args.username) """ @@ -91,12 +163,12 @@ class RequestArgument(object): :return: An UTF-8 string representation of the requested value. :raises UnicodeDecodeError: If the value cannot be decoded into an UTF-8 string. :raises IndexError: If the index is out of bounds. - :raises ValueError: If the index is not an int. + :raises TypeError: If the index is not an int. :raises TypeError: If the requested value is neither a str nor a bytes object. """ if not isinstance(index, int): # Index must be an int - raise ValueError('index must not be None') + raise TypeError('index must be an int') # Type hint; access array element v: Tuple[str, Union[bytes, str]] = self.__value[index] if isinstance(v[1], str): @@ -107,6 +179,14 @@ class RequestArgument(object): return v[1].decode('utf-8') raise TypeError('Value is neither a str nor bytes') + def __str__(self) -> str: + """ + Attempts to return the first value as a string. + :return: An UTF-8 string representation of the first value. + :raises UnicodeDecodeError: If the value cannot be decoded into an UTF-8 string. + """ + return self.get_str() + def get_bytes(self, index: int = 0) -> bytes: """ Attempts to return a value as a bytes object. The index defaults to 0. @@ -114,12 +194,12 @@ class RequestArgument(object): :param index: The index of the value to retrieve. Default: 0. :return: A bytes object representation of the requested value. Strings will be encoded as UTF-8. :raises IndexError: If the index is out of bounds. - :raises ValueError: If the index is not an int. + :raises TypeError: If the index is not an int. :raises TypeError: If the requested value is neither a str nor a bytes object. """ if not isinstance(index, int): # Index must be a int - raise ValueError('index must not be None') + raise TypeError('index must be an int') # Type hint; access array element v: Tuple[str, Union[bytes, str]] = self.__value[index] if isinstance(v[1], bytes): @@ -130,6 +210,13 @@ class RequestArgument(object): return v[1].encode('utf-8') raise TypeError('Value is neither a str nor bytes') + def __bytes__(self) -> bytes: + """ + Attempts to return the first value as a bytes object. + :return: A bytes string representation of the first value. + """ + return self.get_bytes() + def get_content_type(self, index: int = 0) -> str: """ Attempts to retrieve a value's Content-Type. The index defaults to 0. @@ -137,18 +224,20 @@ class RequestArgument(object): :param index: The index of the value to retrieve. Default: 0. :return: The Content-Type of the requested value, as sent by the client. Not necessarily trustworthy. :raises IndexError: If the index is out of bounds. - :raises ValueError: If the index is not an int. + :raises TypeError: If the index is not an int. """ # instance is an array value if not isinstance(index, int): # Needs an index for array values - raise ValueError('index must not be None') + raise TypeError('index must be an int') # Type hint; access array element va: Tuple[str, Union[bytes, str]] = self.__value[index] # Return the content type of the requested value + if not isinstance(va[0], str): + raise TypeError('Content-Type is not a str') return va[0] - def append(self, ctype: str, value: Union[str, bytes]): + def append(self, ctype: str, value: Union[str, bytes]) -> None: """ Append a value to this instance. Turns an empty argument into a scalar and a scalar into an array. @@ -161,6 +250,17 @@ class RequestArgument(object): raise TypeError('A RequestArgument view is immutable!') self.__value.append((ctype, value)) + def clear(self) -> None: + """ + Remove all values from this instance. + + :raises TypeError: If called on an immutable view. + """ + if self.is_view: + # This is an immutable view, raise exception + raise TypeError('A RequestArgument view is immutable!') + self.__value.clear() + def __len__(self) -> int: """ :return: Number of values for this argument. @@ -178,7 +278,7 @@ class RequestArgument(object): # Yield an immutable scalar view for each (ctype, value) element in the array yield _View(self.__name, v) - def __getitem__(self, index: Union[int, slice]): + def __getitem__(self, index: Union[int, slice]) -> 'RequestArgument': """ Index the argument with either an int or a slice. The returned values are represented as immutable RequestArgument views. @@ -206,6 +306,16 @@ class _View(RequestArgument): """ super().__init__(name, value) + @staticmethod + def of(argument: 'RequestArgument') ->'RequestArgument': + """ + Create an immutable, unsliced view of an RequestArgument instance. + + :param argument: The RequestArgument instance to create a view of. + :return: An immutable view of the provided RequestArgument instance. + """ + return argument[:] + @property def is_view(self) -> bool: """ diff --git a/matemat/webserver/test/test_requestargs.py b/matemat/webserver/test/test_requestargs.py index 3383863..3e093a2 100644 --- a/matemat/webserver/test/test_requestargs.py +++ b/matemat/webserver/test/test_requestargs.py @@ -1,9 +1,10 @@ -from typing import List +from typing import Dict, List, Set, Tuple import unittest +import urllib.parse -from matemat.webserver import RequestArgument +from matemat.webserver import RequestArgument, RequestArguments # noinspection PyProtectedMember from matemat.webserver.requestargs import _View @@ -228,6 +229,56 @@ class TestRequestArguments(unittest.TestCase): self.assertEqual(b'\x00\x80\xff\xfe', ra.get_bytes(1)) self.assertEqual(b'Hello, World!', ra.get_bytes(2)) + def test_clear_empty(self): + """ + Test clearing an empty RequestArgument. + """ + # Initialize the empty RequestArgument + ra = RequestArgument('foo') + self.assertEqual(0, len(ra)) + self.assertFalse(ra.is_scalar) + ra.clear() + # Clearing an empty RequestArgument shouldn't have any effect + self.assertEqual('foo', ra.name) + self.assertEqual(0, len(ra)) + self.assertFalse(ra.is_scalar) + + def test_clear_scalar(self): + """ + Test clearing a scalar RequestArgument. + """ + # Initialize the scalar RequestArgument + ra = RequestArgument('foo', ('text/plain', 'bar')) + self.assertEqual(1, len(ra)) + self.assertTrue(ra.is_scalar) + ra.clear() + # Clearing a scalar RequestArgument should reduce its size to 0 + self.assertEqual('foo', ra.name) + self.assertEqual(0, len(ra)) + self.assertFalse(ra.is_scalar) + with self.assertRaises(IndexError): + ra.get_str() + + def test_clear_array(self): + """ + Test clearing an array RequestArgument. + """ + # Initialize the array RequestArgument + ra = RequestArgument('foo', [ + ('text/plain', 'bar'), + ('application/octet-stream', b'\x00\x80\xff\xfe'), + ('text/plain', 'baz'), + ]) + self.assertEqual(3, len(ra)) + self.assertFalse(ra.is_scalar) + ra.clear() + # Clearing an array RequestArgument should reduce its size to 0 + self.assertEqual('foo', ra.name) + self.assertEqual(0, len(ra)) + self.assertFalse(ra.is_scalar) + with self.assertRaises(IndexError): + ra.get_str() + def test_iterate_empty(self): """ Test iterating an empty RequestArgument. @@ -352,3 +403,127 @@ class TestRequestArguments(unittest.TestCase): # Make sure the returned value is immutable with self.assertRaises(TypeError): it.append('foo', 'bar') + with self.assertRaises(TypeError): + it.clear() + + def test_str_shorthand(self): + """ + Test the shorthand for get_str(0). + """ + ra = RequestArgument('foo', ('bar', 'baz')) + self.assertEqual('baz', str(ra)) + + def test_bytes_shorthand(self): + """ + Test the shorthand for get_bytes(0). + """ + ra = RequestArgument('foo', ('bar', b'\x00\x80\xff\xfe')) + self.assertEqual(b'\x00\x80\xff\xfe', bytes(ra)) + + # noinspection PyTypeChecker + def test_insert_garbage(self): + """ + Test proper handling with non-int indices and non-str/non-bytes data + :return: + """ + ra = RequestArgument('foo', 42) + with self.assertRaises(TypeError): + str(ra) + ra = RequestArgument('foo', (None, 42)) + with self.assertRaises(TypeError): + str(ra) + with self.assertRaises(TypeError): + bytes(ra) + with self.assertRaises(TypeError): + ra.get_content_type() + with self.assertRaises(TypeError): + ra.get_str('foo') + with self.assertRaises(TypeError): + ra.get_bytes('foo') + with self.assertRaises(TypeError): + ra.get_content_type('foo') + + def test_requestarguments_index(self): + """ + Make sure indexing a RequestArguments instance creates a new entry on the fly. + """ + ra = RequestArguments() + self.assertEqual(0, len(ra)) + self.assertFalse('foo' in ra) + # Create new entry + _ = ra['foo'] + self.assertEqual(1, len(ra)) + self.assertTrue('foo' in ra) + # Already exists, no new entry created + _ = ra['foo'] + self.assertEqual(1, len(ra)) + # Entry must be empty and mutable, and have the correct name + self.assertFalse(ra['foo'].is_view) + self.assertEqual(0, len(ra['foo'])) + self.assertEqual('foo', ra['foo'].name) + # Key must be a string + with self.assertRaises(TypeError): + # noinspection PyTypeChecker + _ = ra[42] + + def test_requestarguments_attr(self): + """ + Test attribute access syntactic sugar. + """ + ra = RequestArguments() + # Attribute should not exist yet + with self.assertRaises(KeyError): + _ = ra.foo + # Create entry + _ = ra['foo'] + # Creating entry should have created the attribute + self.assertEqual('foo', ra.foo.name) + # Attribute access should yield an immutable view + self.assertTrue(ra.foo.is_view) + + def test_requestarguments_iterate(self): + """ + Test iterating a RequestArguments instance. + """ + # Create an instance with some values + ra = RequestArguments() + ra['foo'].append('a', 'b') + ra['bar'].append('c', 'd') + ra['foo'].append('e', 'f') + # Container for test values (name, value) + items: Set[Tuple[str, str]] = set() + # Iterate RequestArguments instance, adding the name and value of each to the set + for a in ra: + items.add((a.name, str(a))) + # Compare result with expected value + self.assertEqual(2, len(items)) + self.assertIn(('foo', 'b'), items) + self.assertIn(('bar', 'd'), items) + + def test_requestarguments_full_use_case(self): + """ + Simulate a minimal RequestArguments use case. + """ + # Create empty RequestArguments instance + ra = RequestArguments() + # Parse GET request + getargs: Dict[str, List[str]] = urllib.parse.parse_qs('foo=42&bar=1337&foo=43&baz=Hello,%20World!') + # Insert GET arguments into RequestArguments + for k, vs in getargs.items(): + for v in vs: + ra[k].append('text/plain', v) + # Parse POST request + postargs: Dict[str, List[str]] = urllib.parse.parse_qs('foo=postfoo&postbar=42&foo=postfoo') + # Insert POST arguments into RequestArguments + for k, vs in postargs.items(): + # In this implementation, POST args replace GET args + ra[k].clear() + for v in vs: + ra[k].append('text/plain', v) + + # Someplace else: Use the RequestArguments instance. + self.assertEqual('1337', ra.bar.get_str()) + self.assertEqual('Hello, World!', ra.baz.get_str()) + self.assertEqual('42', ra.postbar.get_str()) + for a in ra.foo: + self.assertEqual('postfoo', a.get_str()) From 0fb60d1828695ec52d5008b6415441138058c604 Mon Sep 17 00:00:00 2001 From: s3lph Date: Fri, 29 Jun 2018 22:13:39 +0200 Subject: [PATCH 9/9] Using the new RequestArguments API throughout the project. --- matemat/webserver/httpd.py | 10 ++--- matemat/webserver/pagelets/login.py | 10 ++--- matemat/webserver/pagelets/logout.py | 4 +- matemat/webserver/pagelets/main.py | 6 +-- matemat/webserver/pagelets/touchkey.py | 12 +++-- matemat/webserver/test/abstract_httpd_test.py | 8 ++-- matemat/webserver/test/test_parse_request.py | 44 +++++++++---------- matemat/webserver/test/test_post.py | 12 ++--- matemat/webserver/test/test_serve.py | 8 ++-- matemat/webserver/test/test_session.py | 6 +-- matemat/webserver/util.py | 28 ++++++------ 11 files changed, 73 insertions(+), 75 deletions(-) diff --git a/matemat/webserver/httpd.py b/matemat/webserver/httpd.py index 79efb98..c59e3fc 100644 --- a/matemat/webserver/httpd.py +++ b/matemat/webserver/httpd.py @@ -13,7 +13,7 @@ from uuid import uuid4 from datetime import datetime, timedelta from matemat import __version__ as matemat_version -from matemat.webserver import RequestArgument +from matemat.webserver import RequestArguments from matemat.webserver.util import parse_args @@ -31,7 +31,7 @@ BaseHTTPRequestHandler.log_error = lambda self, fstring='', *args: None # Dictionary to hold registered pagelet paths and their handler functions _PAGELET_PATHS: Dict[str, Callable[[str, # HTTP method (GET, POST, ...) str, # Request path - Dict[str, RequestArgument], # args: (name, argument) + RequestArguments, # HTTP Request arguments Dict[str, Any], # Session vars Dict[str, str]], # Response headers Tuple[int, Union[bytes, str]]]] = dict() # Returns: (status code, response body) @@ -51,7 +51,7 @@ def pagelet(path: str): (method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str]) -> (int, Optional[Union[str, bytes]]) @@ -69,7 +69,7 @@ def pagelet(path: str): def http_handler(fun: Callable[[str, str, - Dict[str, RequestArgument], + RequestArguments, Dict[str, Any], Dict[str, str]], Tuple[int, Union[bytes, str]]]): @@ -181,7 +181,7 @@ class HttpHandler(BaseHTTPRequestHandler): if session_id in self.server.session_vars: del self.server.session_vars[session_id] - def _handle(self, method: str, path: str, args: Dict[str, RequestArgument]) -> None: + def _handle(self, method: str, path: str, args: RequestArguments) -> None: """ Handle a HTTP request by either dispatching it to the appropriate pagelet or by serving a static resource. diff --git a/matemat/webserver/pagelets/login.py b/matemat/webserver/pagelets/login.py index f7813b4..7d0cc2d 100644 --- a/matemat/webserver/pagelets/login.py +++ b/matemat/webserver/pagelets/login.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional, Tuple, Union from matemat.exceptions import AuthenticationError -from matemat.webserver import pagelet, RequestArgument +from matemat.webserver import pagelet, RequestArguments from matemat.primitives import User from matemat.db import MatematDatabase @@ -10,7 +10,7 @@ from matemat.db import MatematDatabase @pagelet('/login') def login_page(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: @@ -41,13 +41,11 @@ def login_page(method: str, ''' - return 200, data.format(msg=args['msg'] if 'msg' in args else '') + return 200, data.format(msg=str(args.msg) if 'msg' in args else '') elif method == 'POST': - username: RequestArgument = args['username'] - password: RequestArgument = args['password'] with MatematDatabase('test.db') as db: try: - user: User = db.login(username.get_str(), password.get_str()) + user: User = db.login(str(args.username), str(args.password)) except AuthenticationError: headers['Location'] = '/login?msg=Username%20or%20password%20wrong.%20Please%20try%20again.' return 301, bytes() diff --git a/matemat/webserver/pagelets/logout.py b/matemat/webserver/pagelets/logout.py index b70d7c1..beb86a3 100644 --- a/matemat/webserver/pagelets/logout.py +++ b/matemat/webserver/pagelets/logout.py @@ -1,13 +1,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union -from matemat.webserver import pagelet, RequestArgument +from matemat.webserver import pagelet, RequestArguments @pagelet('/logout') def logout(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: diff --git a/matemat/webserver/pagelets/main.py b/matemat/webserver/pagelets/main.py index 2b9ce79..e22c872 100644 --- a/matemat/webserver/pagelets/main.py +++ b/matemat/webserver/pagelets/main.py @@ -1,7 +1,7 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union -from matemat.webserver import MatematWebserver, pagelet, RequestArgument +from matemat.webserver import pagelet, RequestArguments from matemat.primitives import User from matemat.db import MatematDatabase @@ -9,7 +9,7 @@ from matemat.db import MatematDatabase @pagelet('/') def main_page(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: diff --git a/matemat/webserver/pagelets/touchkey.py b/matemat/webserver/pagelets/touchkey.py index 22e3df4..4de8009 100644 --- a/matemat/webserver/pagelets/touchkey.py +++ b/matemat/webserver/pagelets/touchkey.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from matemat.exceptions import AuthenticationError -from matemat.webserver import pagelet, RequestArgument +from matemat.webserver import pagelet, RequestArguments from matemat.primitives import User from matemat.db import MatematDatabase @@ -10,7 +10,7 @@ from matemat.db import MatematDatabase @pagelet('/touchkey') def touchkey_page(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str])\ -> Tuple[int, Optional[Union[str, bytes]]]: @@ -40,13 +40,11 @@ def touchkey_page(method: str, ''' - return 200, data.format(username=args['username'] if 'username' in args else '') + return 200, data.format(username=str(args.username) if 'username' in args else '') elif method == 'POST': - username: RequestArgument = args['username'] - touchkey: RequestArgument = args['touchkey'] with MatematDatabase('test.db') as db: try: - user: User = db.login(username.get_str(), touchkey=touchkey.get_str()) + user: User = db.login(str(args.username), touchkey=str(args.touchkey)) except AuthenticationError: headers['Location'] = f'/touchkey?username={args["username"]}&msg=Please%20try%20again.' return 301, bytes() diff --git a/matemat/webserver/test/abstract_httpd_test.py b/matemat/webserver/test/abstract_httpd_test.py index daa1126..103979b 100644 --- a/matemat/webserver/test/abstract_httpd_test.py +++ b/matemat/webserver/test/abstract_httpd_test.py @@ -1,5 +1,5 @@ -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, Tuple, Union import unittest.mock from io import BytesIO @@ -9,7 +9,7 @@ from abc import ABC from datetime import datetime from http.server import HTTPServer -from matemat.webserver import pagelet, RequestArgument +from matemat.webserver import pagelet, RequestArguments class HttpResponse: @@ -158,14 +158,14 @@ def test_pagelet(path: str): def with_testing_headers(fun: Callable[[str, str, - Dict[str, RequestArgument], + RequestArguments, Dict[str, Any], Dict[str, str]], Tuple[int, Union[bytes, str]]]): @pagelet(path) def testing_wrapper(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str]): status, body = fun(method, path, args, session_vars, headers) diff --git a/matemat/webserver/test/test_parse_request.py b/matemat/webserver/test/test_parse_request.py index a533936..0a94065 100644 --- a/matemat/webserver/test/test_parse_request.py +++ b/matemat/webserver/test/test_parse_request.py @@ -20,9 +20,9 @@ class TestParseRequest(unittest.TestCase): path, args = parse_args('/?foo=42&bar=1337&baz=Hello,%20World!') self.assertEqual('/', path) self.assertEqual(3, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('bar', args.keys()) - self.assertIn('baz', args.keys()) + self.assertIn('foo', args) + self.assertIn('bar', args) + self.assertIn('baz', args) self.assertTrue(args['foo'].is_scalar) self.assertTrue(args['bar'].is_scalar) self.assertTrue(args['baz'].is_scalar) @@ -37,9 +37,9 @@ class TestParseRequest(unittest.TestCase): path, args = parse_args('/abc/def?foo=42&bar=1337&baz=Hello,%20World!') self.assertEqual('/abc/def', path) self.assertEqual(3, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('bar', args.keys()) - self.assertIn('baz', args.keys()) + self.assertIn('foo', args) + self.assertIn('bar', args) + self.assertIn('baz', args) self.assertTrue(args['foo'].is_scalar) self.assertTrue(args['bar'].is_scalar) self.assertTrue(args['baz'].is_scalar) @@ -54,8 +54,8 @@ class TestParseRequest(unittest.TestCase): path, args = parse_args('/abc/def?foo=42&foo=1337&baz=Hello,%20World!') self.assertEqual('/abc/def', path) self.assertEqual(2, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('baz', args.keys()) + self.assertIn('foo', args) + self.assertIn('baz', args) self.assertTrue(args['foo'].is_array) self.assertTrue(args['baz'].is_scalar) self.assertEqual(2, len(args['foo'])) @@ -65,8 +65,8 @@ class TestParseRequest(unittest.TestCase): def test_parse_get_zero_arg(self): path, args = parse_args('/abc/def?foo=&bar=42') self.assertEqual(2, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('bar', args.keys()) + self.assertIn('foo', args) + self.assertIn('bar', args) self.assertTrue(args['foo'].is_scalar) self.assertTrue(args['bar'].is_scalar) self.assertEqual(1, len(args['foo'])) @@ -83,9 +83,9 @@ class TestParseRequest(unittest.TestCase): enctype='application/x-www-form-urlencoded') self.assertEqual('/', path) self.assertEqual(3, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('bar', args.keys()) - self.assertIn('baz', args.keys()) + self.assertIn('foo', args) + self.assertIn('bar', args) + self.assertIn('baz', args) self.assertTrue(args['foo'].is_scalar) self.assertTrue(args['bar'].is_scalar) self.assertTrue(args['baz'].is_scalar) @@ -102,8 +102,8 @@ class TestParseRequest(unittest.TestCase): enctype='application/x-www-form-urlencoded') self.assertEqual('/', path) self.assertEqual(2, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('baz', args.keys()) + self.assertIn('foo', args) + self.assertIn('baz', args) self.assertTrue(args['foo'].is_array) self.assertTrue(args['baz'].is_scalar) self.assertEqual(2, len(args['foo'])) @@ -113,8 +113,8 @@ class TestParseRequest(unittest.TestCase): def test_parse_post_urlencoded_zero_arg(self): path, args = parse_args('/abc/def', postbody=b'foo=&bar=42', enctype='application/x-www-form-urlencoded') self.assertEqual(2, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('bar', args.keys()) + self.assertIn('foo', args) + self.assertIn('bar', args) self.assertTrue(args['foo'].is_scalar) self.assertTrue(args['bar'].is_scalar) self.assertEqual(1, len(args['foo'])) @@ -152,9 +152,9 @@ class TestParseRequest(unittest.TestCase): enctype='multipart/form-data; boundary=testBoundary1337') self.assertEqual('/', path) self.assertEqual(3, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('bar', args.keys()) - self.assertIn('baz', args.keys()) + self.assertIn('foo', args) + self.assertIn('bar', args) + self.assertIn('baz', args) self.assertTrue(args['foo'].is_scalar) self.assertTrue(args['bar'].is_scalar) self.assertTrue(args['baz'].is_scalar) @@ -177,8 +177,8 @@ class TestParseRequest(unittest.TestCase): b'--testBoundary1337--\r\n', enctype='multipart/form-data; boundary=testBoundary1337') self.assertEqual(2, len(args)) - self.assertIn('foo', args.keys()) - self.assertIn('bar', args.keys()) + self.assertIn('foo', args) + self.assertIn('bar', args) self.assertTrue(args['foo'].is_scalar) self.assertTrue(args['bar'].is_scalar) self.assertEqual(1, len(args['foo'])) diff --git a/matemat/webserver/test/test_post.py b/matemat/webserver/test/test_post.py index 9b2fe22..0bc5d16 100644 --- a/matemat/webserver/test/test_post.py +++ b/matemat/webserver/test/test_post.py @@ -1,7 +1,7 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List -from matemat.webserver import HttpHandler, RequestArgument +from matemat.webserver import HttpHandler, RequestArguments from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet import codecs @@ -10,7 +10,7 @@ import codecs @test_pagelet('/just/testing/post') def post_test_pagelet(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str]): """ @@ -18,12 +18,12 @@ def post_test_pagelet(method: str, """ headers['Content-Type'] = 'text/plain' dump: str = '' - for k, ra in args.items(): + for ra in args: for a in ra: if a.get_content_type().startswith('text/'): - dump += f'{k}: {a.get_str()}\n' + dump += f'{a.name}: {a.get_str()}\n' else: - dump += f'{k}: {codecs.encode(a.get_bytes(), "hex").decode("utf-8")}\n' + dump += f'{a.name}: {codecs.encode(a.get_bytes(), "hex").decode("utf-8")}\n' return 200, dump diff --git a/matemat/webserver/test/test_serve.py b/matemat/webserver/test/test_serve.py index 7e159e3..722870b 100644 --- a/matemat/webserver/test/test_serve.py +++ b/matemat/webserver/test/test_serve.py @@ -1,16 +1,16 @@ -from typing import Any, Dict, Union +from typing import Any, Dict import os import os.path -from matemat.webserver import HttpHandler, RequestArgument +from matemat.webserver import HttpHandler, RequestArguments from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet @test_pagelet('/just/testing/serve_pagelet_ok') def serve_test_pagelet_ok(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str]): headers['Content-Type'] = 'text/plain' @@ -20,7 +20,7 @@ def serve_test_pagelet_ok(method: str, @test_pagelet('/just/testing/serve_pagelet_fail') def serve_test_pagelet_fail(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str]): session_vars['test'] = 'hello, world!' diff --git a/matemat/webserver/test/test_session.py b/matemat/webserver/test/test_session.py index fe30529..5cf408e 100644 --- a/matemat/webserver/test/test_session.py +++ b/matemat/webserver/test/test_session.py @@ -1,17 +1,17 @@ -from typing import Any, Dict, Union +from typing import Any, Dict from datetime import datetime, timedelta from time import sleep -from matemat.webserver import HttpHandler, RequestArgument +from matemat.webserver import HttpHandler, RequestArguments from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet @test_pagelet('/just/testing/sessions') def session_test_pagelet(method: str, path: str, - args: Dict[str, RequestArgument], + args: RequestArguments, session_vars: Dict[str, Any], headers: Dict[str, str]): session_vars['test'] = 'hello, world!' diff --git a/matemat/webserver/util.py b/matemat/webserver/util.py index 61cf430..2bc2244 100644 --- a/matemat/webserver/util.py +++ b/matemat/webserver/util.py @@ -1,9 +1,9 @@ -from typing import Dict, List, Tuple, Optional, Union +from typing import Dict, List, Tuple, Optional import urllib.parse -from matemat.webserver import RequestArgument +from matemat.webserver import RequestArguments, RequestArgument def _parse_multipart(body: bytes, boundary: str) -> List[RequestArgument]: @@ -76,7 +76,7 @@ def _parse_multipart(body: bytes, boundary: str) -> List[RequestArgument]: def parse_args(request: str, postbody: Optional[bytes] = None, enctype: str = 'text/plain') \ - -> Tuple[str, Dict[str, RequestArgument]]: + -> Tuple[str, RequestArguments]: """ Given a HTTP request path, and optionally a HTTP POST body in application/x-www-form-urlencoded or multipart/form-data form, parse the arguments and return them as a dictionary. @@ -98,11 +98,11 @@ def parse_args(request: str, postbody: Optional[bytes] = None, enctype: str = 't else: getargs = urllib.parse.parse_qs(tokens.query, strict_parsing=True, keep_blank_values=True, errors='strict') - args: Dict[str, RequestArgument] = dict() - for k, v in getargs.items(): - args[k] = RequestArgument(k) - for _v in v: - args[k].append('text/plain', _v) + args = RequestArguments() + for k, vs in getargs.items(): + args[k].clear() + for v in vs: + args[k].append('text/plain', v) if postbody is not None: if enctype == 'application/x-www-form-urlencoded': @@ -113,10 +113,10 @@ def parse_args(request: str, postbody: Optional[bytes] = None, enctype: str = 't else: postargs = urllib.parse.parse_qs(pb, strict_parsing=True, keep_blank_values=True, errors='strict') # Write all POST values into the dict, overriding potential duplicates from GET - for k, v in postargs.items(): - args[k] = RequestArgument(k) - for _v in v: - args[k].append('text/plain', _v) + for k, vs in postargs.items(): + args[k].clear() + for v in vs: + args[k].append('text/plain', v) elif enctype.startswith('multipart/form-data'): # Parse the multipart boundary from the Content-Type header try: @@ -126,7 +126,9 @@ def parse_args(request: str, postbody: Optional[bytes] = None, enctype: str = 't # Parse the multipart body mpargs = _parse_multipart(postbody, boundary) for ra in mpargs: - args[ra.name] = ra + args[ra.name].clear() + for a in ra: + args[ra.name].append(a.get_content_type(), bytes(a)) else: raise ValueError(f'Unsupported Content-Type: {enctype}') # Return the path and the parsed arguments