Added lots of missing type annotations.

This commit is contained in:
s3lph 2018-06-06 12:59:49 +02:00
parent edaf9afc8b
commit 33888fe597
10 changed files with 86 additions and 78 deletions

View file

@ -23,4 +23,4 @@ codestyle:
- pip3 install wheel pycodestyle mypy
- pip3 install -r requirements.txt
- pycodestyle --ignore=E501 matemat
- mypy matemat
- mypy --ignore-missing-imports --strict -p matemat

View file

@ -1,5 +1,5 @@
from typing import List, Optional
from typing import List, Optional, Any, Type
import bcrypt
@ -10,17 +10,17 @@ from matemat.db import DatabaseWrapper
class DatabaseFacade(object):
def __init__(self, filename: str):
def __init__(self, filename: str) -> None:
self.db: DatabaseWrapper = DatabaseWrapper(filename)
def __enter__(self):
def __enter__(self) -> 'DatabaseFacade':
self.db.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Type, exc_val: Any, exc_tb: Any) -> None:
self.db.__exit__(exc_type, exc_val, exc_tb)
def transaction(self, exclusive: bool = True):
def transaction(self, exclusive: bool = True) -> Any:
return self.db.transaction(exclusive=exclusive)
def list_users(self) -> List[User]:
@ -79,7 +79,7 @@ class DatabaseFacade(object):
raise AuthenticationError('Touchkey mismatch')
return User(user_id, username, email, admin, member)
def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True):
def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True) -> None:
with self.db.transaction() as c:
c.execute('''
SELECT password FROM users WHERE user_id = ?
@ -97,7 +97,7 @@ class DatabaseFacade(object):
'pwhash': pwhash
})
def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True):
def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True) -> None:
with self.db.transaction() as c:
c.execute('''
SELECT password FROM users WHERE user_id = ?
@ -115,7 +115,7 @@ class DatabaseFacade(object):
'tkhash': tkhash
})
def change_user(self, user: User):
def change_user(self, user: User) -> None:
with self.db.transaction() as c:
c.execute('''
UPDATE users SET
@ -135,7 +135,7 @@ class DatabaseFacade(object):
raise DatabaseConsistencyError(
f'change_user should affect 1 users row, but affected {affected}')
def delete_user(self, user: User):
def delete_user(self, user: User) -> None:
with self.db.transaction() as c:
c.execute('''
DELETE FROM users
@ -175,7 +175,7 @@ class DatabaseFacade(object):
product_id = int(c.fetchone()[0])
return Product(product_id, name, price_member, price_non_member)
def change_product(self, product: Product):
def change_product(self, product: Product) -> None:
with self.db.transaction() as c:
c.execute('''
UPDATE products
@ -195,7 +195,7 @@ class DatabaseFacade(object):
raise DatabaseConsistencyError(
f'change_product should affect 1 products row, but affected {affected}')
def delete_product(self, product: Product):
def delete_product(self, product: Product) -> None:
with self.db.transaction() as c:
c.execute('''
DELETE FROM products
@ -206,7 +206,7 @@ class DatabaseFacade(object):
raise DatabaseConsistencyError(
f'delete_product should affect 1 products row, but affected {affected}')
def increment_consumption(self, user: User, product: Product, count: int = 1):
def increment_consumption(self, user: User, product: Product, count: int = 1) -> None:
with self.db.transaction() as c:
c.execute('''
SELECT count
@ -265,7 +265,7 @@ class DatabaseFacade(object):
raise DatabaseConsistencyError(
f'increment_consumption should affect 1 products row, but affected {affected}')
def restock(self, product: Product, count: int):
def restock(self, product: Product, count: int) -> None:
with self.db.transaction() as c:
c.execute('''
UPDATE products
@ -279,7 +279,7 @@ class DatabaseFacade(object):
if affected != 1:
raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}')
def deposit(self, user: User, amount: int):
def deposit(self, user: User, amount: int) -> None:
if amount < 0:
raise ValueError('Cannot deposit a negative value')
with self.db.transaction() as c:

View file

@ -9,10 +9,10 @@ from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
class DatabaseTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.db = Database(':memory:')
def test_create_user(self):
def test_create_user(self) -> None:
with self.db as db:
with db.transaction(exclusive=False) as c:
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
@ -25,7 +25,7 @@ class DatabaseTest(unittest.TestCase):
with self.assertRaises(ValueError):
db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com')
def test_list_users(self):
def test_list_users(self) -> None:
with self.db as db:
users = db.list_users()
self.assertEqual(0, len(users))
@ -51,7 +51,7 @@ class DatabaseTest(unittest.TestCase):
usercheck[user.id] = 1
self.assertEqual(3, len(usercheck))
def test_login(self):
def test_login(self) -> None:
with self.db as db:
with db.transaction() as c:
u = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
@ -80,7 +80,7 @@ class DatabaseTest(unittest.TestCase):
# Both password and touchkey should fail
db.login('testuser', password='supersecurepassword', touchkey='0123')
def test_change_password(self):
def test_change_password(self) -> None:
with self.db as db:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
db.login('testuser', 'supersecurepassword')
@ -103,7 +103,7 @@ class DatabaseTest(unittest.TestCase):
# Password change for an inexistent user should fail
db.change_password(user, 'adminpasswordreset', 'passwordwithoutuser')
def test_change_touchkey(self):
def test_change_touchkey(self) -> None:
with self.db as db:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
db.change_touchkey(user, 'supersecurepassword', '0123')
@ -127,7 +127,7 @@ class DatabaseTest(unittest.TestCase):
# Touchkey change for an inexistent user should fail
db.change_touchkey(user, '89ab', '048c')
def test_change_user(self):
def test_change_user(self) -> None:
with self.db as db:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
user.email = 'newaddress@example.com'
@ -142,7 +142,7 @@ class DatabaseTest(unittest.TestCase):
with self.assertRaises(DatabaseConsistencyError):
db.change_user(user)
def test_delete_user(self):
def test_delete_user(self) -> None:
with self.db as db:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
db.login('testuser', 'supersecurepassword')
@ -156,7 +156,7 @@ class DatabaseTest(unittest.TestCase):
# Should fail, as the user does not exist anymore
db.delete_user(user)
def test_create_product(self):
def test_create_product(self) -> None:
with self.db as db:
with db.transaction() as c:
db.create_product('Club Mate', 200, 200)
@ -169,7 +169,7 @@ class DatabaseTest(unittest.TestCase):
with self.assertRaises(ValueError):
db.create_product('Club Mate', 250, 250)
def test_list_products(self):
def test_list_products(self) -> None:
with self.db as db:
# Test empty list
products = db.list_products()
@ -193,7 +193,7 @@ class DatabaseTest(unittest.TestCase):
productcheck[product.id] = 1
self.assertEqual(3, len(productcheck))
def test_change_product(self):
def test_change_product(self) -> None:
with self.db as db:
product = db.create_product('Club Mate', 200, 200)
product.name = 'Flora Power Mate'
@ -213,7 +213,7 @@ class DatabaseTest(unittest.TestCase):
# Should fail, as a product with the same name already exists.
db.change_product(product2)
def test_delete_product(self):
def test_delete_product(self) -> None:
with self.db as db:
product = db.create_product('Club Mate', 200, 200)
product2 = db.create_product('Flora Power Mate', 200, 200)
@ -230,7 +230,7 @@ class DatabaseTest(unittest.TestCase):
# Should fail, as the product does not exist anymore
db.delete_product(product)
def test_deposit(self):
def test_deposit(self) -> None:
with self.db as db:
with db.transaction() as c:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
@ -252,7 +252,7 @@ class DatabaseTest(unittest.TestCase):
# Should fail, user id -1 does not exist
db.deposit(user, 42)
def test_restock(self):
def test_restock(self) -> None:
with self.db as db:
with db.transaction() as c:
product = db.create_product('Club Mate', 200, 200)
@ -271,7 +271,7 @@ class DatabaseTest(unittest.TestCase):
# Should fail, product id -1 does not exist
db.restock(product, 42)
def test_consumption(self):
def test_consumption(self) -> None:
with self.db as db:
# Set up test case
user1 = db.create_user('user1', 'supersecurepassword', 'testuser@example.com', member=True)

View file

@ -6,17 +6,17 @@ from matemat.db import DatabaseWrapper
class DatabaseTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.db = DatabaseWrapper(':memory:')
def test_create_schema(self):
def test_create_schema(self) -> None:
"""
Test creation of database schema in an empty database
"""
with self.db as db:
self.assertEqual(DatabaseWrapper.SCHEMA_VERSION, db._user_version)
def test_in_transaction(self):
def test_in_transaction(self) -> None:
"""
Test transaction tracking
"""
@ -26,7 +26,7 @@ class DatabaseTest(unittest.TestCase):
self.assertTrue(db.in_transaction())
self.assertFalse(db.in_transaction())
def test_transaction_nesting(self):
def test_transaction_nesting(self) -> None:
"""
Inner transactions should not do anything
"""
@ -43,7 +43,7 @@ class DatabaseTest(unittest.TestCase):
self.assertTrue(db.in_transaction())
self.assertFalse(db.in_transaction())
def test_transaction_commit(self):
def test_transaction_commit(self) -> None:
"""
If no error occurs, actions in a transaction should be committed.
"""
@ -57,7 +57,7 @@ class DatabaseTest(unittest.TestCase):
user = c.fetchone()
self.assertEqual((1, 'testuser', None, 'supersecurepassword', None, 1, 1, 0, 42), user)
def test_transaction_rollback(self):
def test_transaction_rollback(self) -> None:
"""
If an error occurs in a transaction, actions should be rolled back.
"""

View file

@ -1,4 +1,6 @@
from typing import Any
import apsw
from matemat.exceptions import DatabaseConsistencyError
@ -6,14 +8,14 @@ from matemat.exceptions import DatabaseConsistencyError
class Transaction(object):
def __init__(self, db: apsw.Connection, wrapper: 'DatabaseWrapper', exclusive: bool = True):
def __init__(self, db: apsw.Connection, wrapper: 'DatabaseWrapper', exclusive: bool = True) -> None:
self._db: apsw.Connection = db
self._cursor = None
self._excl = exclusive
self._wrapper: DatabaseWrapper = wrapper
self._is_dummy: bool = False
def __enter__(self):
def __enter__(self) -> Any:
if self._wrapper._in_transaction:
self._is_dummy = True
return self._db.cursor()
@ -27,7 +29,7 @@ class Transaction(object):
self._cursor.execute('BEGIN')
return self._cursor
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._is_dummy:
return
if exc_type is None:
@ -73,22 +75,22 @@ class DatabaseWrapper(object):
)
'''
def __init__(self, filename: str):
def __init__(self, filename: str) -> None:
self._filename: str = filename
self._sqlite_db: apsw.Connection = None
self._in_transaction: bool = False
def __enter__(self):
def __enter__(self) -> 'DatabaseWrapper':
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def transaction(self, exclusive: bool = True) -> Transaction:
return Transaction(self._sqlite_db, self, exclusive)
def _setup(self):
def _setup(self) -> None:
with self.transaction() as c:
version: int = self._user_version
if version < 1:
@ -97,16 +99,16 @@ class DatabaseWrapper(object):
self._upgrade(old=version, new=self.SCHEMA_VERSION)
self._user_version = self.SCHEMA_VERSION
def _upgrade(self, old: int, new: int):
def _upgrade(self, old: int, new: int) -> None:
pass
def connect(self):
def connect(self) -> None:
if self.is_connected():
raise RuntimeError(f'Database connection to {self._filename} is already established.')
self._sqlite_db = apsw.Connection(self._filename)
self._setup()
def close(self):
def close(self) -> None:
if self._sqlite_db is None:
raise RuntimeError(f'Database connection to {self._filename} is not established.')
if self.in_transaction():
@ -128,6 +130,6 @@ class DatabaseWrapper(object):
return version
@_user_version.setter
def _user_version(self, version: int):
def _user_version(self, version: int) -> None:
cursor = self._sqlite_db.cursor()
cursor.execute(f'PRAGMA user_version = {version}')

View file

@ -1,12 +1,16 @@
from typing import Optional
class AuthenticationError(BaseException):
def __init__(self, msg: str = None):
self._msg = msg
def __init__(self, msg: Optional[str] = None) -> None:
super().__init__()
self._msg: Optional[str] = msg
def __str__(self) -> str:
return f'AuthenticationError: {self._msg}'
@property
def msg(self) -> str:
def msg(self) -> Optional[str]:
return self._msg

View file

@ -1,12 +1,15 @@
from typing import Optional
class DatabaseConsistencyError(BaseException):
def __init__(self, msg: str = None):
self._msg = msg
def __init__(self, msg: Optional[str] = None) -> None:
self._msg: Optional[str] = msg
def __str__(self) -> str:
return f'DatabaseConsistencyError: {self._msg}'
@property
def msg(self) -> str:
def msg(self) -> Optional[str]:
return self._msg

View file

@ -1,16 +1,20 @@
from typing import Any
class Product(object):
def __init__(self,
product_id: int,
name: str,
price_member: int,
price_non_member: int):
price_non_member: int) -> None:
self._product_id: int = product_id
self._name: str = name
self._price_member: int = price_member
self._price_non_member: int = price_non_member
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if other is None or not isinstance(other, Product):
return False
return self._product_id == other._product_id \
@ -27,7 +31,7 @@ class Product(object):
return self._name
@name.setter
def name(self, name: str):
def name(self, name: str) -> None:
self._name = name
@property
@ -35,7 +39,7 @@ class Product(object):
return self._price_member
@price_member.setter
def price_member(self, price: int):
def price_member(self, price: int) -> None:
self._price_member = price
@property
@ -43,5 +47,5 @@ class Product(object):
return self._price_non_member
@price_non_member.setter
def price_non_member(self, price: int):
def price_non_member(self, price: int) -> None:
self._price_non_member = price

View file

@ -1,5 +1,5 @@
from typing import Optional
from typing import Optional, Any
class User(object):
@ -9,14 +9,14 @@ class User(object):
username: str,
email: Optional[str] = None,
admin: bool = False,
member: bool = True):
member: bool = True) -> None:
self._user_id: int = user_id
self._username: str = username
self._email: Optional[str] = email
self._admin: bool = admin
self._member: bool = member
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if other is None or not isinstance(other, User):
return False
return self._user_id == other._user_id \
@ -34,15 +34,11 @@ class User(object):
return self._username
@property
def email(self) -> str:
def email(self) -> Optional[str]:
return self._email
@email.setter
def email(self, email: str):
self._email = email
@email.setter
def email(self, email: str):
def email(self, email: str) -> None:
self._email = email
@property
@ -50,7 +46,7 @@ class User(object):
return self._admin
@is_admin.setter
def is_admin(self, admin: bool):
def is_admin(self, admin: bool) -> None:
self._admin = admin
@property
@ -58,5 +54,5 @@ class User(object):
return self._member
@is_member.setter
def is_member(self, member: bool):
def is_member(self, member: bool) -> None:
self._member = member

View file

@ -1,7 +1,6 @@
from typing import Tuple, Dict, Optional
from typing import Tuple, Dict
import socket
from http.server import HTTPServer, BaseHTTPRequestHandler
from http.cookies import SimpleCookie
from uuid import uuid4
@ -12,16 +11,16 @@ from matemat import __version__ as matemat_version
class MatematWebserver(object):
def __init__(self):
def __init__(self) -> None:
self._httpd = HTTPServer(('', 8080), HttpHandler)
def start(self):
def start(self) -> None:
self._httpd.serve_forever()
class HttpHandler(BaseHTTPRequestHandler):
def __init__(self, request: socket.socket, client_address: Tuple[str, int], server: HTTPServer):
def __init__(self, request: bytes, client_address: Tuple[str, int], server: HTTPServer) -> None:
super().__init__(request, client_address, server)
self._session_vars: Dict[str, Tuple[datetime, Dict[str, object]]] = dict()
print(self._session_vars)
@ -30,7 +29,7 @@ class HttpHandler(BaseHTTPRequestHandler):
def server_version(self) -> str:
return f'matemat/{matemat_version}'
def start_session(self) -> Optional[Tuple[str, datetime]]:
def start_session(self) -> Tuple[str, datetime]:
now = datetime.utcnow()
cookiestring = '\n'.join(self.headers.get_all('Cookie', failobj=[]))
cookie = SimpleCookie()
@ -42,13 +41,13 @@ class HttpHandler(BaseHTTPRequestHandler):
raise TimeoutError('Session timed out')
elif session_id not in self._session_vars:
self._session_vars[session_id] = (now + timedelta(hours=1)), dict()
return session_id
return session_id, now
def end_session(self, session_id: str):
def end_session(self, session_id: str) -> None:
if session_id in self._session_vars:
del self._session_vars[session_id]
def do_GET(self):
def do_GET(self) -> None:
try:
session_id, timeout = self.start_session()
except TimeoutError: