From 411372cc21eb27ae99c52dc8491ec099c4cd549f Mon Sep 17 00:00:00 2001 From: s3lph Date: Wed, 30 May 2018 01:57:26 +0200 Subject: [PATCH 1/3] Split database class into Facade (matemat database API) and Wrapper (internals). --- matemat/db/__init__.py | 3 +- matemat/db/{database.py => facade.py} | 237 +++++++----------- matemat/db/test/test_facade.py | 69 +++++ .../{test_database.py => test_wrapper.py} | 25 +- matemat/db/wrapper.py | 129 ++++++++++ ...atonException.py => AuthenticatonError.py} | 4 +- .../exceptions/DatabaseConsistencyError.py | 12 + matemat/exceptions/__init__.py | 3 +- 8 files changed, 306 insertions(+), 176 deletions(-) rename matemat/db/{database.py => facade.py} (60%) create mode 100644 matemat/db/test/test_facade.py rename matemat/db/test/{test_database.py => test_wrapper.py} (71%) create mode 100644 matemat/db/wrapper.py rename matemat/exceptions/{AuthenticatonException.py => AuthenticatonError.py} (61%) create mode 100644 matemat/exceptions/DatabaseConsistencyError.py diff --git a/matemat/db/__init__.py b/matemat/db/__init__.py index a0697c9..0f2bc0a 100644 --- a/matemat/db/__init__.py +++ b/matemat/db/__init__.py @@ -1,2 +1,3 @@ -from .database import Database +from .wrapper import DatabaseWrapper +from .facade import DatabaseFacade as Database diff --git a/matemat/db/database.py b/matemat/db/facade.py similarity index 60% rename from matemat/db/database.py rename to matemat/db/facade.py index c297721..6056e7c 100644 --- a/matemat/db/database.py +++ b/matemat/db/facade.py @@ -1,142 +1,31 @@ from typing import List, Optional -import apsw import bcrypt from matemat.primitives import User, Product -from matemat.exceptions import AuthenticationException +from matemat.exceptions import AuthenticationError, DatabaseConsistencyError +from matemat.db import DatabaseWrapper -class Transaction(object): - - def __init__(self, db: apsw.Connection, wrapper: 'Database', exclusive: bool = True): - self._db: apsw.Connection = db - self._cursor = None - self._excl = exclusive - self._wrapper: Database = wrapper - self._is_dummy: bool = False - - def __enter__(self): - if self._wrapper._in_transaction: - self._is_dummy = True - return self._db.cursor() - else: - self._is_dummy = False - self._cursor = self._db.cursor() - self._wrapper._in_transaction = True - if self._excl: - self._cursor.execute('BEGIN EXCLUSIVE') - else: - self._cursor.execute('BEGIN') - return self._cursor - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._is_dummy: - return - if exc_type is None: - self._cursor.execute('COMMIT') - else: - self._cursor.execute('ROLLBACK') - self._wrapper._in_transaction = False - - -class Database(object): - SCHEMA_VERSION = 1 - - SCHEMA = ''' - CREATE TABLE users ( - user_id INTEGER PRIMARY KEY, - username TEXT NOT NULL, - email TEXT DEFAULT NULL, - password TEXT NOT NULL, - touchkey TEXT DEFAULT NULL, - is_admin INTEGER(1) NOT NULL DEFAULT 0, - is_member INTEGER(1) NOT NULL DEFAULT 1, - balance INTEGER(8) NOT NULL DEFAULT 0, - lastchange INTEGER(8) NOT NULL DEFAULT 0 - ); - CREATE TABLE products ( - product_id INTEGER PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - stock INTEGER(8) NOT NULL DEFAULT 0, - price_member INTEGER(8) NOT NULL, - price_non_member INTEGER(8) NOT NULL - ); - CREATE TABLE consumption ( - user_id INTEGER NOT NULL, - product_id INTEGER NOT NULL, - count INTEGER(8) NOT NULL DEFAULT 0, - PRIMARY KEY (user_id, product_id), - FOREIGN KEY (user_id) REFERENCES users(user_id) - ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (product_id) REFERENCES products(product_id) - ON DELETE CASCADE ON UPDATE CASCADE - ) - ''' +class DatabaseFacade(object): def __init__(self, filename: str): - self._filename: str = filename - self._sqlite_db: apsw.Connection = None - self._in_transaction: bool = False + self.db: DatabaseWrapper = DatabaseWrapper(filename) def __enter__(self): - self.connect() + self.db.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + self.db.__exit__(exc_type, exc_val, exc_tb) - def transaction(self, exclusive: bool = True) -> Transaction: - return Transaction(self._sqlite_db, self, exclusive) - - def _setup(self): - with self.transaction() as c: - version: int = self._user_version - if version < 1: - c.execute(self.SCHEMA) - elif version < self.SCHEMA_VERSION: - self._upgrade(old=version, new=self.SCHEMA_VERSION) - self._user_version = self.SCHEMA_VERSION - - def _upgrade(self, old: int, new: int): - pass - - def connect(self): - if self.is_connected(): - raise RuntimeError(f'Database connection to {self._filename} is already established.') - self._sqlite_db = apsw.Connection(self._filename) - self._setup() - - def close(self): - if self._sqlite_db is None: - raise RuntimeError(f'Database connection to {self._filename} is not established.') - if self.in_transaction(): - raise RuntimeError(f'A transaction is still ongoing.') - self._sqlite_db.close() - self._sqlite_db = None - - def in_transaction(self) -> bool: - return self._in_transaction - - def is_connected(self) -> bool: - return self._sqlite_db is not None - - @property - def _user_version(self) -> int: - cursor = self._sqlite_db.cursor() - cursor.execute('PRAGMA user_version') - version = int(cursor.fetchone()[0]) - return version - - @_user_version.setter - def _user_version(self, version: int): - cursor = self._sqlite_db.cursor() - cursor.execute(f'PRAGMA user_version = {version}') + def transaction(self, exclusive: bool = True): + return self.db.transaction(exclusive=exclusive) def list_users(self) -> List[User]: users: List[User] = [] - with self.transaction(exclusive=False) as c: + with self.db.transaction(exclusive=False) as c: for row in c.execute(''' SELECT user_id, username, email, is_admin, is_member FROM users @@ -153,7 +42,7 @@ class Database(object): member: bool = True) -> User: pwhash: str = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12)) user_id: int = -1 - with self.transaction() as c: + with self.db.transaction() as c: c.execute('SELECT user_id FROM users WHERE username = ?', [username]) if c.fetchone() is not None: raise ValueError(f'A user with the name \'{username}\' already exists.') @@ -167,37 +56,43 @@ class Database(object): 'admin': admin, 'member': member }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'create_user should affect 1 users row, but affected {affected}') c.execute('SELECT last_insert_rowid()') user_id = int(c.fetchone()[0]) return User(user_id, username, email, admin, member) - def login(self, username: str, password: str) -> Optional[User]: - with self.transaction(exclusive=False) as c: + def login(self, username: str, password: Optional[str] = None, touchkey: Optional[str] = None) -> User: + if (password is None) == (touchkey is None): + raise ValueError('Exactly one of password and touchkey must be provided') + with self.db.transaction(exclusive=False) as c: c.execute(''' - SELECT user_id, username, email, password, is_admin, is_member + SELECT user_id, username, email, password, touchkey, is_admin, is_member FROM users WHERE username = ? ''', [username]) row = c.fetchone() if row is None: - return None - user_id, username, email, pwhash, admin, member = row - if not bcrypt.checkpw(password.encode('utf-8'), pwhash): - return None + raise AuthenticationError('User does not exist') + user_id, username, email, pwhash, tkhash, admin, member = row + if password is not None and not bcrypt.checkpw(password.encode('utf-8'), pwhash): + raise AuthenticationError('Password mismatch') + elif touchkey is not None and not bcrypt.checkpw(touchkey.encode('utf-8'), tkhash): + raise AuthenticationError('Touchkey mismatch') return User(user_id, username, email, admin, member) - def change_password(self, user: User, oldpass: str, newpass: str, newpass2: str, verify_password: bool = True): - if newpass != newpass2: - raise ValueError('New passwords don\'t match.') - with self.transaction() as c: + def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True): + with self.db.transaction() as c: c.execute(''' SELECT password FROM users WHERE user_id = ? ''', [user.id]) row = c.fetchone() if row is None: - raise AuthenticationException('User does not exist in database.') + raise AuthenticationError('User does not exist in database.') if verify_password and not bcrypt.checkpw(oldpass.encode('utf-8'), row[0]): - raise AuthenticationException('Old password does not match.') + raise AuthenticationError('Old password does not match.') pwhash: str = bcrypt.hashpw(newpass.encode('utf-8'), bcrypt.gensalt(12)) c.execute(''' UPDATE users SET password = :pwhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id @@ -205,17 +100,21 @@ class Database(object): 'user_id': user.id, 'pwhash': pwhash }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'change_password should affect 1 users row, but affected {affected}') def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True): - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' SELECT password FROM users WHERE user_id = ? ''', [user.id]) row = c.fetchone() if row is None: - raise AuthenticationException('User does not exist in database.') + raise AuthenticationError('User does not exist in database.') if verify_password and not bcrypt.checkpw(password.encode('utf-8'), row[0]): - raise AuthenticationException('Password does not match.') + raise AuthenticationError('Password does not match.') tkhash: str = bcrypt.hashpw(touchkey.encode('utf-8'), bcrypt.gensalt(12)) if touchkey is not None else None c.execute(''' UPDATE users SET touchkey = :tkhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id @@ -223,9 +122,13 @@ class Database(object): 'user_id': user.id, 'tkhash': tkhash }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'change_touchkey should affect 1 users row, but affected {affected}') def change_user(self, user: User): - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' UPDATE users SET email = :email, @@ -239,17 +142,25 @@ class Database(object): 'is_admin': user.is_admin, 'is_member': user.is_member }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'change_user should affect 1 users row, but affected {affected}') def delete_user(self, user: User): - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' DELETE FROM users WHERE user_id = ? ''', [user.id]) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'delete_user should affect 1 users row, but affected {affected}') def list_products(self) -> List[Product]: products: List[Product] = [] - with self.transaction(exclusive=False) as c: + with self.db.transaction(exclusive=False) as c: for row in c.execute(''' SELECT product_id, name, price_member, price_external FROM products @@ -260,7 +171,7 @@ class Database(object): def create_product(self, name: str, price_member: int, price_non_member: int) -> Product: product_id: int = -1 - with self.transaction() as c: + with self.db.transaction() as c: c.execute('SELECT product_id FROM products WHERE name = ?', [name]) if c.fetchone() is not None: raise ValueError(f'A product with the name \'{name}\' already exists.') @@ -273,13 +184,17 @@ class Database(object): 'price_non_member': price_non_member }) c.execute('SELECT last_insert_rowid()') + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'create_product should affect 1 products row, but affected {affected}') product_id = int(c.fetchone()[0]) return Product(product_id, name, price_member, price_non_member) def change_product(self, product: Product): if product.id == -1: raise ValueError('Invalid product ID') - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' UPDATE products SET @@ -292,20 +207,24 @@ class Database(object): 'price_member': product.price_member, 'price_non_member': product.price_non_member }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'change_product should affect 1 products row, but affected {affected}') def delete_product(self, product: Product): - if product.id == -1: - raise ValueError('Invalid product ID') - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' DELETE FROM products WHERE product_id = ? ''', [product.id]) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'delete_product should affect 1 products row, but affected {affected}') def increment_consumption(self, user: User, product: Product, count: int = 1): - if product.id == -1: - raise ValueError('Invalid product ID') - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' SELECT count FROM consumption @@ -332,6 +251,10 @@ class Database(object): 'product_id': product.id, 'count': count }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'increment_consumption should affect 1 consumption row, but affected {affected}') c.execute(''' UPDATE users SET balance = balance - :cost @@ -339,6 +262,10 @@ class Database(object): 'user_id': user.id, 'cost': count * product.price_member if user.is_member else count * product.price_non_member }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'increment_consumption should affect 1 users row, but affected {affected}') c.execute(''' UPDATE products SET stock = stock - :count @@ -347,11 +274,15 @@ class Database(object): 'product_id': product.id, 'count': count }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'increment_consumption should affect 1 products row, but affected {affected}') def restock(self, product: Product, count: int): if product.id == -1: raise ValueError('Invalid product ID') - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' UPDATE products SET stock = stock + :count @@ -360,9 +291,12 @@ class Database(object): 'product_id': product.id, 'count': count }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}') def deposit(self, user: User, amount: int): - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' UPDATE users SET balance = balance + :amount @@ -371,3 +305,6 @@ class Database(object): 'user_id': user.id, 'amount': amount }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError(f'deposit should affect 1 users row, but affected {affected}') diff --git a/matemat/db/test/test_facade.py b/matemat/db/test/test_facade.py new file mode 100644 index 0000000..049edcb --- /dev/null +++ b/matemat/db/test/test_facade.py @@ -0,0 +1,69 @@ + +import unittest + +from matemat.db import Database +from matemat.exceptions import AuthenticationError, DatabaseConsistencyError + + +class DatabaseTest(unittest.TestCase): + + def setUp(self): + self.db = Database(':memory:') + + def test_create_user(self): + with self.db as db: + with db.transaction() as c: + db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + c.execute("SELECT * FROM users") + row = c.fetchone() + if row is None: + self.fail() + self.assertEqual('testuser', row[1]) + self.assertEqual('testuser@example.com', row[2]) + self.assertEqual(0, row[5]) + self.assertEqual(1, row[6]) + + def test_create_existing_user(self): + with self.db as db: + db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + with self.assertRaises(ValueError): + db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com') + + def test_change_password(self): + with self.db as db: + user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + db.login('testuser', 'supersecurepassword') + # Normal password change should succeed + db.change_password(user, 'supersecurepassword', 'mynewpassword') + with self.assertRaises(AuthenticationError): + db.login('testuser', 'supersecurepassword') + db.login('testuser', 'mynewpassword') + with self.assertRaises(AuthenticationError): + # Should fail due to wrong old password + db.change_password(user, 'iforgotmypassword', 'mynewpassword') + db.login('testuser', 'mynewpassword') + # This should pass even though the old password is not known (admin password reset) + db.change_password(user, '42', 'adminpasswordreset', verify_password=False) + with self.assertRaises(AuthenticationError): + db.login('testuser', 'mynewpassword') + db.login('testuser', 'adminpasswordreset') + + def test_change_touchkey(self): + with self.db as db: + user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + db.change_touchkey(user, 'supersecurepassword', '0123') + db.login('testuser', touchkey='0123') + # Normal touchkey change should succeed + db.change_touchkey(user, 'supersecurepassword', touchkey='4567') + with self.assertRaises(AuthenticationError): + db.login('testuser', touchkey='0123') + db.login('testuser', touchkey='4567') + with self.assertRaises(AuthenticationError): + # Should fail due to wrong old password + db.change_touchkey(user, 'iforgotmypassword', '89ab') + db.login('testuser', touchkey='4567') + # This should pass even though the old password is not known (admin password reset) + db.change_touchkey(user, '42', '89ab', verify_password=False) + with self.assertRaises(AuthenticationError): + db.login('testuser', touchkey='4567') + db.login('testuser', touchkey='89ab') diff --git a/matemat/db/test/test_database.py b/matemat/db/test/test_wrapper.py similarity index 71% rename from matemat/db/test/test_database.py rename to matemat/db/test/test_wrapper.py index a03fbd4..334ecdf 100644 --- a/matemat/db/test/test_database.py +++ b/matemat/db/test/test_wrapper.py @@ -1,20 +1,20 @@ import unittest -from matemat.db import Database +from matemat.db import DatabaseWrapper class DatabaseTest(unittest.TestCase): def setUp(self): - self.db = Database(':memory:') + self.db = DatabaseWrapper(':memory:') def test_create_schema(self): """ Test creation of database schema in an empty database """ with self.db as db: - self.assertEqual(Database.SCHEMA_VERSION, db._user_version) + self.assertEqual(DatabaseWrapper.SCHEMA_VERSION, db._user_version) def test_in_transaction(self): """ @@ -74,22 +74,3 @@ class DatabaseTest(unittest.TestCase): c = db._sqlite_db.cursor() c.execute("SELECT * FROM users") self.assertIsNone(c.fetchone()) - - def test_create_user(self): - with self.db as db: - with db.transaction() as c: - db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') - c.execute("SELECT * FROM users") - row = c.fetchone() - if row is None: - self.fail() - self.assertEqual('testuser', row[1]) - self.assertEqual('testuser@example.com', row[2]) - self.assertEqual(0, row[5]) - self.assertEqual(1, row[6]) - - def test_create_existing_user(self): - with self.db as db: - db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') - with self.assertRaises(ValueError): - db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com') diff --git a/matemat/db/wrapper.py b/matemat/db/wrapper.py new file mode 100644 index 0000000..3b92936 --- /dev/null +++ b/matemat/db/wrapper.py @@ -0,0 +1,129 @@ + +import apsw + + +class Transaction(object): + + def __init__(self, db: apsw.Connection, wrapper: 'DatabaseWrapper', exclusive: bool = True): + self._db: apsw.Connection = db + self._cursor = None + self._excl = exclusive + self._wrapper: DatabaseWrapper = wrapper + self._is_dummy: bool = False + + def __enter__(self): + if self._wrapper._in_transaction: + self._is_dummy = True + return self._db.cursor() + else: + self._is_dummy = False + self._cursor = self._db.cursor() + self._wrapper._in_transaction = True + if self._excl: + self._cursor.execute('BEGIN EXCLUSIVE') + else: + self._cursor.execute('BEGIN') + return self._cursor + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._is_dummy: + return + if exc_type is None: + self._cursor.execute('COMMIT') + else: + self._cursor.execute('ROLLBACK') + self._wrapper._in_transaction = False + + +class DatabaseWrapper(object): + SCHEMA_VERSION = 1 + + SCHEMA = ''' + CREATE TABLE users ( + user_id INTEGER PRIMARY KEY, + username TEXT NOT NULL, + email TEXT DEFAULT NULL, + password TEXT NOT NULL, + touchkey TEXT DEFAULT NULL, + is_admin INTEGER(1) NOT NULL DEFAULT 0, + is_member INTEGER(1) NOT NULL DEFAULT 1, + balance INTEGER(8) NOT NULL DEFAULT 0, + lastchange INTEGER(8) NOT NULL DEFAULT 0 + ); + CREATE TABLE products ( + product_id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + stock INTEGER(8) NOT NULL DEFAULT 0, + price_member INTEGER(8) NOT NULL, + price_non_member INTEGER(8) NOT NULL + ); + CREATE TABLE consumption ( + user_id INTEGER NOT NULL, + product_id INTEGER NOT NULL, + count INTEGER(8) NOT NULL DEFAULT 0, + PRIMARY KEY (user_id, product_id), + FOREIGN KEY (user_id) REFERENCES users(user_id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (product_id) REFERENCES products(product_id) + ON DELETE CASCADE ON UPDATE CASCADE + ) + ''' + + def __init__(self, filename: str): + self._filename: str = filename + self._sqlite_db: apsw.Connection = None + self._in_transaction: bool = False + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def transaction(self, exclusive: bool = True) -> Transaction: + return Transaction(self._sqlite_db, self, exclusive) + + def _setup(self): + with self.transaction() as c: + version: int = self._user_version + if version < 1: + c.execute(self.SCHEMA) + elif version < self.SCHEMA_VERSION: + self._upgrade(old=version, new=self.SCHEMA_VERSION) + self._user_version = self.SCHEMA_VERSION + + def _upgrade(self, old: int, new: int): + pass + + def connect(self): + if self.is_connected(): + raise RuntimeError(f'Database connection to {self._filename} is already established.') + self._sqlite_db = apsw.Connection(self._filename) + self._setup() + + def close(self): + if self._sqlite_db is None: + raise RuntimeError(f'Database connection to {self._filename} is not established.') + if self.in_transaction(): + raise RuntimeError(f'A transaction is still ongoing.') + self._sqlite_db.close() + self._sqlite_db = None + + def in_transaction(self) -> bool: + return self._in_transaction + + def is_connected(self) -> bool: + return self._sqlite_db is not None + + @property + def _user_version(self) -> int: + cursor = self._sqlite_db.cursor() + cursor.execute('PRAGMA user_version') + version = int(cursor.fetchone()[0]) + return version + + @_user_version.setter + def _user_version(self, version: int): + cursor = self._sqlite_db.cursor() + cursor.execute(f'PRAGMA user_version = {version}') diff --git a/matemat/exceptions/AuthenticatonException.py b/matemat/exceptions/AuthenticatonError.py similarity index 61% rename from matemat/exceptions/AuthenticatonException.py rename to matemat/exceptions/AuthenticatonError.py index 55eed7c..4494a33 100644 --- a/matemat/exceptions/AuthenticatonException.py +++ b/matemat/exceptions/AuthenticatonError.py @@ -1,11 +1,11 @@ -class AuthenticationException(BaseException): +class AuthenticationError(BaseException): def __init__(self, msg: str = None): self._msg = msg def __str__(self) -> str: - return f'AuthenticationException: {self._msg}' + return f'AuthenticationErro: {self._msg}' @property def msg(self) -> str: diff --git a/matemat/exceptions/DatabaseConsistencyError.py b/matemat/exceptions/DatabaseConsistencyError.py new file mode 100644 index 0000000..d07d7a1 --- /dev/null +++ b/matemat/exceptions/DatabaseConsistencyError.py @@ -0,0 +1,12 @@ + +class DatabaseConsistencyError(BaseException): + + def __init__(self, msg: str = None): + self._msg = msg + + def __str__(self) -> str: + return f'DatabaseConsistencyError: {self._msg}' + + @property + def msg(self) -> str: + return self._msg diff --git a/matemat/exceptions/__init__.py b/matemat/exceptions/__init__.py index 08617cb..cacbbb9 100644 --- a/matemat/exceptions/__init__.py +++ b/matemat/exceptions/__init__.py @@ -1,2 +1,3 @@ -from .AuthenticatonException import AuthenticationException +from .AuthenticatonError import AuthenticationError +from .DatabaseConsistencyError import DatabaseConsistencyError From c238e8e9c8ebae68a9c3f36f045662f19e57ccd7 Mon Sep 17 00:00:00 2001 From: s3lph Date: Wed, 30 May 2018 02:09:32 +0200 Subject: [PATCH 2/3] test_facade: Added a login test case --- matemat/db/test/test_facade.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/matemat/db/test/test_facade.py b/matemat/db/test/test_facade.py index 049edcb..859ea37 100644 --- a/matemat/db/test/test_facade.py +++ b/matemat/db/test/test_facade.py @@ -1,6 +1,8 @@ import unittest +import bcrypt + from matemat.db import Database from matemat.exceptions import AuthenticationError, DatabaseConsistencyError @@ -29,6 +31,36 @@ class DatabaseTest(unittest.TestCase): with self.assertRaises(ValueError): db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com') + def test_login(self): + with self.db as db: + with db.transaction() as c: + u = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + # Add a touchkey without using the provided function + c.execute('''UPDATE users SET touchkey = :tkhash WHERE user_id = :user_id''', { + 'tkhash': bcrypt.hashpw(b'0123', bcrypt.gensalt(12)), + 'user_id': u.id + }) + user = db.login('testuser', 'supersecurepassword') + self.assertEqual(u.id, user.id) + user = db.login('testuser', touchkey='0123') + self.assertEqual(u.id, user.id) + with self.assertRaises(AuthenticationError): + # Inexistent user should fail + db.login('nooone', 'supersecurepassword') + with self.assertRaises(AuthenticationError): + # Wrong password should fail + db.login('testuser', 'anothersecurepassword') + with self.assertRaises(AuthenticationError): + # Wrong touchkey should fail + db.login('testuser', touchkey='0124') + with self.assertRaises(ValueError): + # No password or touchkey should fail + db.login('testuser') + with self.assertRaises(ValueError): + # Both password and touchkey should fail + db.login('testuser', password='supersecurepassword', touchkey='0123') + + def test_change_password(self): with self.db as db: user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') From 01b0b95652848193ba697e6ea779e122b82e87fd Mon Sep 17 00:00:00 2001 From: s3lph Date: Tue, 5 Jun 2018 19:14:35 +0200 Subject: [PATCH 3/3] Added lots of unit tests for the database facade, and already fixed some bugs. --- .gitlab-ci.yml | 2 +- matemat/db/facade.py | 32 +--- matemat/db/test/test_facade.py | 228 ++++++++++++++++++++++- matemat/db/wrapper.py | 4 + matemat/exceptions/AuthenticatonError.py | 2 +- matemat/primitives/Product.py | 14 +- matemat/primitives/User.py | 11 +- 7 files changed, 257 insertions(+), 36 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d783cc3..c1ca31e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -8,4 +8,4 @@ test: - pip3 install wheel - pip3 install -r requirements.txt - python3-coverage run --branch -m unittest discover matemat - - python3-coverage report -m + - python3-coverage report -m --include 'matemat/*' --omit '*/test_*.py' diff --git a/matemat/db/facade.py b/matemat/db/facade.py index 6056e7c..abd1829 100644 --- a/matemat/db/facade.py +++ b/matemat/db/facade.py @@ -56,10 +56,6 @@ class DatabaseFacade(object): 'admin': admin, 'member': member }) - affected = c.execute('SELECT changes()').fetchone()[0] - if affected != 1: - raise DatabaseConsistencyError( - f'create_user should affect 1 users row, but affected {affected}') c.execute('SELECT last_insert_rowid()') user_id = int(c.fetchone()[0]) return User(user_id, username, email, admin, member) @@ -100,10 +96,6 @@ class DatabaseFacade(object): 'user_id': user.id, 'pwhash': pwhash }) - affected = c.execute('SELECT changes()').fetchone()[0] - if affected != 1: - raise DatabaseConsistencyError( - f'change_password should affect 1 users row, but affected {affected}') def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True): with self.db.transaction() as c: @@ -122,10 +114,6 @@ class DatabaseFacade(object): 'user_id': user.id, 'tkhash': tkhash }) - affected = c.execute('SELECT changes()').fetchone()[0] - if affected != 1: - raise DatabaseConsistencyError( - f'change_touchkey should affect 1 users row, but affected {affected}') def change_user(self, user: User): with self.db.transaction() as c: @@ -162,7 +150,7 @@ class DatabaseFacade(object): products: List[Product] = [] with self.db.transaction(exclusive=False) as c: for row in c.execute(''' - SELECT product_id, name, price_member, price_external + SELECT product_id, name, price_member, price_non_member FROM products '''): product_id, name, price_member, price_external = row @@ -184,16 +172,10 @@ class DatabaseFacade(object): 'price_non_member': price_non_member }) c.execute('SELECT last_insert_rowid()') - affected = c.execute('SELECT changes()').fetchone()[0] - if affected != 1: - raise DatabaseConsistencyError( - f'create_product should affect 1 products row, but affected {affected}') product_id = int(c.fetchone()[0]) return Product(product_id, name, price_member, price_non_member) def change_product(self, product: Product): - if product.id == -1: - raise ValueError('Invalid product ID') with self.db.transaction() as c: c.execute(''' UPDATE products @@ -201,9 +183,10 @@ class DatabaseFacade(object): name = :name, price_member = :price_member, price_non_member = :price_non_member - WHERE product_id = :product_is + WHERE product_id = :product_id ''', { 'product_id': product.id, + 'name': product.name, 'price_member': product.price_member, 'price_non_member': product.price_non_member }) @@ -230,7 +213,10 @@ class DatabaseFacade(object): FROM consumption WHERE user_id = :user_id AND product_id = :product_id - ''') + ''', { + 'user_id': user.id, + 'product_id': product.id + }) row = c.fetchone() if row is None: c.execute(''' @@ -280,8 +266,6 @@ class DatabaseFacade(object): f'increment_consumption should affect 1 products row, but affected {affected}') def restock(self, product: Product, count: int): - if product.id == -1: - raise ValueError('Invalid product ID') with self.db.transaction() as c: c.execute(''' UPDATE products @@ -296,6 +280,8 @@ class DatabaseFacade(object): raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}') def deposit(self, user: User, amount: int): + if amount < 0: + raise ValueError('Cannot deposit a negative value') with self.db.transaction() as c: c.execute(''' UPDATE users diff --git a/matemat/db/test/test_facade.py b/matemat/db/test/test_facade.py index 859ea37..03c893f 100644 --- a/matemat/db/test/test_facade.py +++ b/matemat/db/test/test_facade.py @@ -14,22 +14,42 @@ class DatabaseTest(unittest.TestCase): def test_create_user(self): with self.db as db: - with db.transaction() as c: + with db.transaction(exclusive=False) as c: db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') c.execute("SELECT * FROM users") row = c.fetchone() - if row is None: - self.fail() self.assertEqual('testuser', row[1]) self.assertEqual('testuser@example.com', row[2]) self.assertEqual(0, row[5]) self.assertEqual(1, row[6]) + with self.assertRaises(ValueError): + db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com') - def test_create_existing_user(self): + def test_list_users(self): with self.db as db: - db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') - with self.assertRaises(ValueError): - db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com') + users = db.list_users() + self.assertEqual(0, len(users)) + db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True) + db.create_user('anothertestuser', 'otherpassword', 'anothertestuser@example.com', False, True) + db.create_user('yatu', 'igotapasswordtoo', 'yatu@example.com', False, False) + users = db.list_users() + self.assertEqual(3, len(users)) + usercheck = {} + for user in users: + if user.name == 'testuser': + self.assertEqual('testuser@example.com', user.email) + self.assertTrue(user.is_member) + self.assertTrue(user.is_admin) + elif user.name == 'anothertestuser': + self.assertEqual('anothertestuser@example.com', user.email) + self.assertTrue(user.is_member) + self.assertFalse(user.is_admin) + elif user.name == 'yatu': + self.assertEqual('yatu@example.com', user.email) + self.assertFalse(user.is_member) + self.assertFalse(user.is_admin) + usercheck[user.id] = 1 + self.assertEqual(3, len(usercheck)) def test_login(self): with self.db as db: @@ -60,7 +80,6 @@ class DatabaseTest(unittest.TestCase): # Both password and touchkey should fail db.login('testuser', password='supersecurepassword', touchkey='0123') - def test_change_password(self): with self.db as db: user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') @@ -79,6 +98,10 @@ class DatabaseTest(unittest.TestCase): with self.assertRaises(AuthenticationError): db.login('testuser', 'mynewpassword') db.login('testuser', 'adminpasswordreset') + user._user_id = -1 + with self.assertRaises(AuthenticationError): + # Password change for an inexistent user should fail + db.change_password(user, 'adminpasswordreset', 'passwordwithoutuser') def test_change_touchkey(self): with self.db as db: @@ -99,3 +122,192 @@ class DatabaseTest(unittest.TestCase): with self.assertRaises(AuthenticationError): db.login('testuser', touchkey='4567') db.login('testuser', touchkey='89ab') + user._user_id = -1 + with self.assertRaises(AuthenticationError): + # Touchkey change for an inexistent user should fail + db.change_touchkey(user, '89ab', '048c') + + def test_change_user(self): + with self.db as db: + user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True) + user.email = 'newaddress@example.com' + user.is_admin = False + user.is_member = False + db.change_user(user) + checkuser = db.login('testuser', 'supersecurepassword') + self.assertEqual('newaddress@example.com', checkuser.email) + self.assertFalse(checkuser.is_admin) + self.assertFalse(checkuser.is_member) + user._user_id = -1 + with self.assertRaises(DatabaseConsistencyError): + db.change_user(user) + + def test_delete_user(self): + with self.db as db: + user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True) + db.login('testuser', 'supersecurepassword') + db.delete_user(user) + try: + # Should fail, as the user does not exist anymore + db.login('testuser', 'supersecurepassword') + except AuthenticationError as e: + self.assertEqual('User does not exist', e.msg) + with self.assertRaises(DatabaseConsistencyError): + # Should fail, as the user does not exist anymore + db.delete_user(user) + + def test_create_product(self): + with self.db as db: + with db.transaction() as c: + db.create_product('Club Mate', 200, 200) + c.execute("SELECT * FROM products") + row = c.fetchone() + self.assertEqual('Club Mate', row[1]) + self.assertEqual(0, row[2]) + self.assertEqual(200, row[3]) + self.assertEqual(200, row[4]) + with self.assertRaises(ValueError): + db.create_product('Club Mate', 250, 250) + + def test_list_products(self): + with self.db as db: + # Test empty list + products = db.list_products() + self.assertEqual(0, len(products)) + db.create_product('Club Mate', 200, 200) + db.create_product('Flora Power Mate', 200, 200) + db.create_product('Fritz Mate', 200, 250) + products = db.list_products() + self.assertEqual(3, len(products)) + productcheck = {} + for product in products: + if product.name == 'Club Mate': + self.assertEqual(200, product.price_member) + self.assertEqual(200, product.price_non_member) + elif product.name == 'Flora Power Mate': + self.assertEqual(200, product.price_member) + self.assertEqual(200, product.price_non_member) + elif product.name == 'Fritz Mate': + self.assertEqual(200, product.price_member) + self.assertEqual(250, product.price_non_member) + productcheck[product.id] = 1 + self.assertEqual(3, len(productcheck)) + + def test_change_product(self): + with self.db as db: + product = db.create_product('Club Mate', 200, 200) + product.name = 'Flora Power Mate' + product.price_member = 150 + product.price_non_member = 250 + db.change_product(product) + checkproduct = db.list_products()[0] + self.assertEqual('Flora Power Mate', checkproduct.name) + self.assertEqual(150, checkproduct.price_member) + self.assertEqual(250, checkproduct.price_non_member) + product._product_id = -1 + with self.assertRaises(DatabaseConsistencyError): + db.change_product(product) + product2 = db.create_product('Club Mate', 200, 200) + product2.name = 'Flora Power Mate' + with self.assertRaises(DatabaseConsistencyError): + # Should fail, as a product with the same name already exists. + db.change_product(product2) + + def test_delete_product(self): + with self.db as db: + product = db.create_product('Club Mate', 200, 200) + product2 = db.create_product('Flora Power Mate', 200, 200) + + self.assertEqual(2, len(db.list_products())) + db.delete_product(product) + + # Only the other product should remain in the DB + products = db.list_products() + self.assertEqual(1, len(products)) + self.assertEqual(product2, products[0]) + + with self.assertRaises(DatabaseConsistencyError): + # Should fail, as the product does not exist anymore + db.delete_product(product) + + def test_deposit(self): + with self.db as db: + with db.transaction() as c: + user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True) + user2 = db.create_user('testuser2', 'supersecurepassword', 'testuser@example.com', True, True) + c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user.id]) + self.assertEqual(0, c.fetchone()[0]) + c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id]) + self.assertEqual(0, c.fetchone()[0]) + db.deposit(user, 1337) + c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user.id]) + self.assertEqual(1337, c.fetchone()[0]) + c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id]) + self.assertEqual(0, c.fetchone()[0]) + with self.assertRaises(ValueError): + # Should fail, negative amount + db.deposit(user, -42) + user._user_id = -1 + with self.assertRaises(DatabaseConsistencyError): + # Should fail, user id -1 does not exist + db.deposit(user, 42) + + def test_restock(self): + with self.db as db: + with db.transaction() as c: + product = db.create_product('Club Mate', 200, 200) + product2 = db.create_product('Flora Power Mate', 200, 200) + c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product.id]) + self.assertEqual(0, c.fetchone()[0]) + c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product2.id]) + self.assertEqual(0, c.fetchone()[0]) + db.restock(product, 42) + c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product.id]) + self.assertEqual(42, c.fetchone()[0]) + c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product2.id]) + self.assertEqual(0, c.fetchone()[0]) + product._product_id = -1 + with self.assertRaises(DatabaseConsistencyError): + # Should fail, product id -1 does not exist + db.restock(product, 42) + + def test_consumption(self): + with self.db as db: + # Set up test case + user1 = db.create_user('user1', 'supersecurepassword', 'testuser@example.com', member=True) + user2 = db.create_user('user2', 'supersecurepassword', 'testuser@example.com', member=False) + user3 = db.create_user('user3', 'supersecurepassword', 'testuser@example.com', member=False) + db.deposit(user1, 1337) + db.deposit(user2, 4242) + db.deposit(user3, 1234) + clubmate = db.create_product('Club Mate', 200, 200) + florapowermate = db.create_product('Flora Power Mate', 150, 250) + fritzmate = db.create_product('Fritz Mate', 200, 200) + db.restock(clubmate, 50) + db.restock(florapowermate, 70) + db.restock(fritzmate, 10) + + # user1 is somewhat addicted to caffeine + db.increment_consumption(user1, clubmate, 1) + db.increment_consumption(user1, clubmate, 2) + db.increment_consumption(user1, florapowermate, 3) + + # user2 is reeeally addicted + db.increment_consumption(user2, clubmate, 7) + db.increment_consumption(user2, florapowermate, 3) + db.increment_consumption(user2, florapowermate, 4) + + with db.transaction(exclusive=False) as c: + c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user1.id]) + self.assertEqual(1337 - 200 * 3 - 150 * 3, c.fetchone()[0]) + c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id]) + self.assertEqual(4242 - 200 * 7 - 250 * 7, c.fetchone()[0]) + c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user3.id]) + self.assertEqual(1234, c.fetchone()[0]) + + c.execute('''SELECT stock FROM products WHERE product_id = ?''', [clubmate.id]) + self.assertEqual(40, c.fetchone()[0]) + c.execute('''SELECT stock FROM products WHERE product_id = ?''', [florapowermate.id]) + self.assertEqual(60, c.fetchone()[0]) + c.execute('''SELECT stock FROM products WHERE product_id = ?''', [fritzmate.id]) + self.assertEqual(10, c.fetchone()[0]) diff --git a/matemat/db/wrapper.py b/matemat/db/wrapper.py index 3b92936..16b9b97 100644 --- a/matemat/db/wrapper.py +++ b/matemat/db/wrapper.py @@ -1,6 +1,8 @@ import apsw +from matemat.exceptions import DatabaseConsistencyError + class Transaction(object): @@ -33,6 +35,8 @@ class Transaction(object): else: self._cursor.execute('ROLLBACK') self._wrapper._in_transaction = False + if exc_type == apsw.ConstraintError: + raise DatabaseConsistencyError(str(exc_val)) class DatabaseWrapper(object): diff --git a/matemat/exceptions/AuthenticatonError.py b/matemat/exceptions/AuthenticatonError.py index 4494a33..5618550 100644 --- a/matemat/exceptions/AuthenticatonError.py +++ b/matemat/exceptions/AuthenticatonError.py @@ -5,7 +5,7 @@ class AuthenticationError(BaseException): self._msg = msg def __str__(self) -> str: - return f'AuthenticationErro: {self._msg}' + return f'AuthenticationError: {self._msg}' @property def msg(self) -> str: diff --git a/matemat/primitives/Product.py b/matemat/primitives/Product.py index f3d1674..2f08c7e 100644 --- a/matemat/primitives/Product.py +++ b/matemat/primitives/Product.py @@ -1,4 +1,3 @@ - class Product(object): def __init__(self, @@ -11,6 +10,14 @@ class Product(object): self._price_member: int = price_member self._price_non_member: int = price_non_member + def __eq__(self, other): + if other is None or not isinstance(other, Product): + return False + return self._product_id == other._product_id \ + and self._name == other._name \ + and self._price_member == other._price_member \ + and self._price_non_member == other._price_non_member + @property def id(self) -> int: return self._product_id @@ -19,6 +26,11 @@ class Product(object): def name(self) -> str: return self._name + @name.setter + def name(self, name: str): + self._name = name + + @property def price_member(self) -> int: return self._price_member diff --git a/matemat/primitives/User.py b/matemat/primitives/User.py index 18fabb1..847c05b 100644 --- a/matemat/primitives/User.py +++ b/matemat/primitives/User.py @@ -10,14 +10,21 @@ class User(object): email: Optional[str] = None, admin: bool = False, member: bool = True): - if user_id == -1: - raise ValueError('Invalid user ID') self._user_id: int = user_id self._username: str = username self._email: Optional[str] = email self._admin: bool = admin self._member: bool = member + def __eq__(self, other): + if other is None or not isinstance(other, User): + return False + return self._user_id == other._user_id \ + and self._username == other._username \ + and self._email == other._email \ + and self._admin == other._admin \ + and self._member == other._member + @property def id(self) -> int: return self._user_id