1
0
Fork 0
forked from s3lph/matemat

Merge branch 'staging' into 'master'

Increased test coverage

See merge request s3lph/matemat!40
This commit is contained in:
s3lph 2018-08-28 21:23:08 +00:00
commit f959fd6ef4
13 changed files with 294 additions and 29 deletions

View file

@ -300,10 +300,6 @@ class MatematDatabase(object):
'is_admin': is_admin, 'is_admin': is_admin,
'is_member': is_member 'is_member': is_member
}) })
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(
f'change_user should affect 1 users row, but affected {affected}')
# Only update the actual user object after the changes in the database succeeded # Only update the actual user object after the changes in the database succeeded
user.name = name user.name = name
user.email = email user.email = email

View file

@ -84,6 +84,12 @@ class DatabaseTest(unittest.TestCase):
with self.db as db: with self.db as db:
with db.transaction() as c: with db.transaction() as c:
u = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') u = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
# Attempt touchkey login without a set touchkey
try:
db.login('testuser', touchkey='0123')
self.fail()
except AuthenticationError as e:
self.assertEqual('Touchkey not set', e.msg)
# Add a touchkey without using the provided function # Add a touchkey without using the provided function
c.execute('''UPDATE users SET touchkey = :tkhash WHERE user_id = :user_id''', { c.execute('''UPDATE users SET touchkey = :tkhash WHERE user_id = :user_id''', {
'tkhash': crypt.crypt('0123', crypt.mksalt()), 'tkhash': crypt.crypt('0123', crypt.mksalt()),
@ -172,6 +178,9 @@ class DatabaseTest(unittest.TestCase):
self.assertFalse(checkuser.is_admin) self.assertFalse(checkuser.is_admin)
self.assertFalse(checkuser.is_member) self.assertFalse(checkuser.is_member)
self.assertEqual(4200, checkuser.balance) self.assertEqual(4200, checkuser.balance)
# Balance change without an agent must fail
with self.assertRaises(ValueError):
db.change_user(user, None, balance=0)
user.id = -1 user.id = -1
with self.assertRaises(DatabaseConsistencyError): with self.assertRaises(DatabaseConsistencyError):
db.change_user(user, agent, is_member='True') db.change_user(user, agent, is_member='True')
@ -360,3 +369,10 @@ class DatabaseTest(unittest.TestCase):
self.assertEqual(60, c.fetchone()[0]) self.assertEqual(60, c.fetchone()[0])
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [fritzmate.id]) c.execute('''SELECT stock FROM products WHERE product_id = ?''', [fritzmate.id])
self.assertEqual(10, c.fetchone()[0]) self.assertEqual(10, c.fetchone()[0])
user1.id = -1
clubmate.id = -1
with self.assertRaises(DatabaseConsistencyError):
db.increment_consumption(user1, florapowermate)
with self.assertRaises(DatabaseConsistencyError):
db.increment_consumption(user2, clubmate)

View file

@ -103,3 +103,20 @@ class DatabaseTest(unittest.TestCase):
with db.transaction(): with db.transaction():
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
self.db.close() self.db.close()
def test_use_before_open(self):
with self.assertRaises(RuntimeError):
with self.db.transaction():
pass
with self.assertRaises(RuntimeError):
self.db.close()
with self.assertRaises(RuntimeError):
_ = self.db._user_version
with self.assertRaises(RuntimeError):
self.db._user_version = 42
def test_setup_prevent_downgrade(self):
self.db._sqlite_db = sqlite3.connect(':memory:')
self.db._user_version = 1337
with self.assertRaises(RuntimeError):
self.db._setup()

View file

View file

@ -0,0 +1,15 @@
import unittest
from matemat.exceptions import AuthenticationError
class TestAuthenticationError(unittest.TestCase):
def test_msg(self):
e = AuthenticationError('testmsg')
self.assertEqual('testmsg', e.msg)
def test_str(self):
e = AuthenticationError('testmsg')
self.assertEqual('AuthenticationError: testmsg', str(e))

View file

@ -0,0 +1,15 @@
import unittest
from matemat.exceptions import DatabaseConsistencyError
class TestDatabaseConsistencyError(unittest.TestCase):
def test_msg(self):
e = DatabaseConsistencyError('testmsg')
self.assertEqual('testmsg', e.msg)
def test_str(self):
e = DatabaseConsistencyError('testmsg')
self.assertEqual('DatabaseConsistencyError: testmsg', str(e))

View file

@ -0,0 +1,19 @@
import unittest
from matemat.exceptions import HttpException
class TestHttpException(unittest.TestCase):
def test_all_args(self):
e = HttpException(1337, 'Foo Bar', 'Lorem Ipsum Dolor Sit Amet')
self.assertEqual(1337, e.status)
self.assertEqual('Foo Bar', e.title)
self.assertEqual('Lorem Ipsum Dolor Sit Amet', e.message)
def test_default_args(self):
e = HttpException()
self.assertEqual(500, e.status)
self.assertEqual('An error occurred', e.title)
self.assertIsNone(e.message)

View file

@ -131,3 +131,9 @@ class TestCurrencyFormat(unittest.TestCase):
parse_chf('13,37') parse_chf('13,37')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
parse_chf('CHF 13,37') parse_chf('CHF 13,37')
def test_parse_frac_negative(self):
with self.assertRaises(ValueError):
parse_chf('13.-7')
with self.assertRaises(ValueError):
parse_chf('CHF 13.-7')

View file

@ -455,15 +455,9 @@ class HttpHandler(BaseHTTPRequestHandler):
self.send_response(404) self.send_response(404)
self.end_headers() self.end_headers()
# noinspection PyPep8Naming def _handle_request(self, method: str, path: str, args: RequestArguments):
def do_GET(self) -> None:
"""
Called by BasicHTTPRequestHandler for GET requests.
"""
try: try:
# Parse the request and hand it to the handle function self._handle(method, path, args)
path, args = parse_args(self.path)
self._handle('GET', path, args)
# Special handling for some errors # Special handling for some errors
except HttpException as e: except HttpException as e:
self.send_error(e.status, e.title, e.message) self.send_error(e.status, e.title, e.message)
@ -483,6 +477,19 @@ class HttpHandler(BaseHTTPRequestHandler):
self.send_error(500, 'Internal Server Error') self.send_error(500, 'Internal Server Error')
self.server.logger.exception('', e.args, e) self.server.logger.exception('', e.args, e)
# noinspection PyPep8Naming
def do_GET(self) -> None:
"""
Called by BasicHTTPRequestHandler for GET requests.
"""
# Parse the request and hand it to the handle function
try:
path, args = parse_args(self.path)
self._handle_request('GET', path, args)
except ValueError as e:
self.send_error(400, 'Bad Request')
self.server.logger.debug('', exc_info=e)
# noinspection PyPep8Naming # noinspection PyPep8Naming
def do_POST(self) -> None: def do_POST(self) -> None:
""" """
@ -492,29 +499,18 @@ class HttpHandler(BaseHTTPRequestHandler):
# Read the POST body, if it exists, and its MIME type is application/x-www-form-urlencoded # Read the POST body, if it exists, and its MIME type is application/x-www-form-urlencoded
clen: int = int(str(self.headers.get('Content-Length', failobj='0'))) clen: int = int(str(self.headers.get('Content-Length', failobj='0')))
if clen > _MAX_POST: if clen > _MAX_POST:
raise ValueError('Request too big') # Return a 413 error page if the request size exceeds boundaries
self.send_error(413, 'Payload Too Large')
self.server.logger.debug('', exc_info=HttpException(413, 'Payload Too Large'))
return
ctype: str = self.headers.get('Content-Type', failobj='application/octet-stream') ctype: str = self.headers.get('Content-Type', failobj='application/octet-stream')
post: bytes = self.rfile.read(clen) post: bytes = self.rfile.read(clen)
path, args = parse_args(self.path, postbody=post, enctype=ctype)
# Parse the request and hand it to the handle function # Parse the request and hand it to the handle function
self._handle('POST', path, args) path, args = parse_args(self.path, postbody=post, enctype=ctype)
# Special handling for some errors self._handle_request('POST', path, args)
except HttpException as e:
if 500 <= e.status < 600:
self.send_error(500, 'Internal Server Error')
self.server.logger.exception('', exc_info=e)
else:
self.server.logger.debug('', exc_info=e)
except PermissionError as e:
self.send_error(403, 'Forbidden')
self.server.logger.debug('', exc_info=e)
except ValueError as e: except ValueError as e:
self.send_error(400, 'Bad Request') self.send_error(400, 'Bad Request')
self.server.logger.debug('', exc_info=e) self.server.logger.debug('', exc_info=e)
except BaseException as e:
# Generic error handling
self.send_error(500, 'Internal Server Error')
self.server.logger.exception('', e.args, e)
@property @property
def session_vars(self) -> Dict[str, Any]: def session_vars(self) -> Dict[str, Any]:

View file

@ -241,3 +241,12 @@ class TestConfig(TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# The filename is only a placeholder, file content is determined by mocking open # The filename is only a placeholder, file content is determined by mocking open
parse_config_file('test') parse_config_file('test')
def test_parse_config_not_a_filename(self):
"""
Test type checking for the config filenames
"""
with self.assertRaises(TypeError):
parse_config_file(42)
with self.assertRaises(TypeError):
parse_config_file(['config', 42])

View file

@ -0,0 +1,108 @@
from typing import Any, Dict, Union
from matemat.exceptions import HttpException
from matemat.webserver import HttpHandler, RequestArguments, PageletResponse
from matemat.webserver.test.abstract_httpd_test import AbstractHttpdTest, test_pagelet
@test_pagelet('/just/testing/http_exception')
def test_pagelet_http_exception(method: str,
path: str,
args: RequestArguments,
session_vars: Dict[str, Any],
headers: Dict[str, str],
pagelet_variables: Dict[str, str]) -> Union[bytes, str, PageletResponse]:
raise HttpException(int(str(args.exc)), 'Test Exception')
@test_pagelet('/just/testing/value_error')
def test_pagelet_value_error(method: str,
path: str,
args: RequestArguments,
session_vars: Dict[str, Any],
headers: Dict[str, str],
pagelet_variables: Dict[str, str]) -> Union[bytes, str, PageletResponse]:
raise ValueError('test')
@test_pagelet('/just/testing/permission_error')
def test_pagelet_permission_error(method: str,
path: str,
args: RequestArguments,
session_vars: Dict[str, Any],
headers: Dict[str, str],
pagelet_variables: Dict[str, str]) -> Union[bytes, str, PageletResponse]:
raise PermissionError('test')
@test_pagelet('/just/testing/other_error')
def test_pagelet_other_error(method: str,
path: str,
args: RequestArguments,
session_vars: Dict[str, Any],
headers: Dict[str, str],
pagelet_variables: Dict[str, str]) -> Union[bytes, str, PageletResponse]:
raise TypeError('test')
class TestHttpd(AbstractHttpdTest):
def test_httpd_get_illegal_path(self):
self.client_sock.set_request(b'GET /foo?bar?baz HTTP/1.1\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(400, packet.statuscode)
def test_httpd_post_illegal_path(self):
self.client_sock.set_request(b'POST /foo?bar?baz HTTP/1.1\r\n'
b'Content-Length: 0\r\n'
b'Content-Type: application/x-www-form-urlencoded\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(400, packet.statuscode)
def test_httpd_post_illegal_header(self):
self.client_sock.set_request(b'POST /foo?bar=baz HTTP/1.1\r\n'
b'Content-Length: 0\r\n'
b'Content-Type: application/octet-stream\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(400, packet.statuscode)
def test_httpd_post_request_too_big(self):
self.client_sock.set_request(b'POST /foo?bar=baz HTTP/1.1\r\n'
b'Content-Length: 1000001\r\n'
b'Content-Type: application/octet-stream\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(413, packet.statuscode)
def test_httpd_exception_http_400(self):
self.client_sock.set_request(b'GET /just/testing/http_exception?exc=400 HTTP/1.1\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(400, packet.statuscode)
def test_httpd_exception_http_500(self):
self.client_sock.set_request(b'GET /just/testing/http_exception?exc=500 HTTP/1.1\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(500, packet.statuscode)
def test_httpd_exception_value_error(self):
self.client_sock.set_request(b'GET /just/testing/value_error HTTP/1.1\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(400, packet.statuscode)
def test_httpd_exception_permission_error(self):
self.client_sock.set_request(b'GET /just/testing/permission_error HTTP/1.1\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(403, packet.statuscode)
def test_httpd_exception_other_error(self):
self.client_sock.set_request(b'GET /just/testing/other_error HTTP/1.1\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
self.assertEqual(500, packet.statuscode)

View file

@ -204,6 +204,32 @@ class TestParseRequest(unittest.TestCase):
self.assertEqual(b'1337', args['bar'].get_bytes()) self.assertEqual(b'1337', args['bar'].get_bytes())
self.assertEqual('Hello, World!', args['baz'].get_str()) self.assertEqual('Hello, World!', args['baz'].get_str())
def test_parse_post_multipart_names(self):
"""
Test that multipart names work both with and without quotation marks
"""
path, args = parse_args('/',
postbody=b'--testBoundary1337\r\n'
b'Content-Disposition: form-data; name="foo"\r\n'
b'Content-Type: text/plain\r\n\r\n'
b'42\r\n'
b'--testBoundary1337\r\n'
b'Content-Disposition: form-data; name=bar\r\n'
b'Content-Type: text/plain\r\n\r\n'
b'Hello, World!\r\n'
b'--testBoundary1337--\r\n',
enctype='multipart/form-data; boundary=testBoundary1337')
self.assertEqual('/', path)
self.assertEqual(2, len(args))
self.assertIn('foo', args)
self.assertIn('bar', args)
self.assertTrue(args['foo'].is_scalar)
self.assertTrue(args['bar'].is_scalar)
self.assertEqual('text/plain', args['foo'].get_content_type())
self.assertEqual('text/plain', args['bar'].get_content_type())
self.assertEqual('42', args['foo'].get_str())
self.assertEqual('Hello, World!', args['bar'].get_str())
def test_parse_post_multipart_zero_arg(self): def test_parse_post_multipart_zero_arg(self):
""" """
Test that a multipart POST request with an empty argument is parsed correctly. Test that a multipart POST request with an empty argument is parsed correctly.

View file

@ -63,6 +63,28 @@ def serve_test_pagelet_fail(method: str,
raise HttpException(599, 'Error expected during unit testing') raise HttpException(599, 'Error expected during unit testing')
# noinspection PyTypeChecker
@test_pagelet('/just/testing/serve_pagelet_empty')
def serve_test_pagelet_empty(method: str,
path: str,
args: RequestArguments,
session_vars: Dict[str, Any],
headers: Dict[str, str],
pagelet_variables: Dict[str, str]) -> Union[bytes, str, PageletResponse]:
return PageletResponse()
# noinspection PyTypeChecker
@test_pagelet('/just/testing/serve_pagelet_type_error')
def serve_test_pagelet_fail(method: str,
path: str,
args: RequestArguments,
session_vars: Dict[str, Any],
headers: Dict[str, str],
pagelet_variables: Dict[str, str]) -> Union[bytes, str, PageletResponse]:
return 42
class TestServe(AbstractHttpdTest): class TestServe(AbstractHttpdTest):
""" """
Test cases for the content serving of the web server. Test cases for the content serving of the web server.
@ -147,6 +169,18 @@ class TestServe(AbstractHttpdTest):
# Make sure the response body was rendered correctly by the templating engine # Make sure the response body was rendered correctly by the templating engine
self.assertEqual(b'Hello, World!', packet.body) self.assertEqual(b'Hello, World!', packet.body)
def test_serve_pagelet_empty(self):
# Call the test pagelet that redirects to another path
self.client_sock.set_request(b'GET /just/testing/serve_pagelet_empty HTTP/1.1\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
# Make sure the correct pagelet was called
self.assertEqual('serve_test_pagelet_empty', packet.pagelet)
self.assertEqual(200, packet.statuscode)
# Make sure the response body was rendered correctly by the templating engine
self.assertEqual(b'', packet.body)
def test_serve_static_ok(self): def test_serve_static_ok(self):
# Request a static resource # Request a static resource
self.client_sock.set_request(b'GET /static_resource.txt HTTP/1.1\r\n\r\n') self.client_sock.set_request(b'GET /static_resource.txt HTTP/1.1\r\n\r\n')
@ -249,6 +283,14 @@ class TestServe(AbstractHttpdTest):
packet = self.client_sock.get_response() packet = self.client_sock.get_response()
self.assertEqual('application/x-foo-bar', packet.headers['Content-Type']) self.assertEqual('application/x-foo-bar', packet.headers['Content-Type'])
def test_serve_pagelet_type_error(self):
# A 500 error should be returned if a pagelet returns an invalid type
self.client_sock.set_request(b'GET /just/testing/serve_pagelet_type_error HTTP/1.1\r\n\r\n')
HttpHandler(self.client_sock, ('::1', 45678), self.server)
packet = self.client_sock.get_response()
# Make sure a 500 header is served
self.assertEqual(500, packet.statuscode)
def test_serve_static_mime_extension(self): def test_serve_static_mime_extension(self):
# The correct Content-Type should be guessed by file extension primarily # The correct Content-Type should be guessed by file extension primarily
self.client_sock.set_request(b'GET /teststyle.css HTTP/1.1\r\n\r\n') self.client_sock.set_request(b'GET /teststyle.css HTTP/1.1\r\n\r\n')