diff --git a/matemat/db/__init__.py b/matemat/db/__init__.py index a0697c9..0f2bc0a 100644 --- a/matemat/db/__init__.py +++ b/matemat/db/__init__.py @@ -1,2 +1,3 @@ -from .database import Database +from .wrapper import DatabaseWrapper +from .facade import DatabaseFacade as Database diff --git a/matemat/db/database.py b/matemat/db/facade.py similarity index 60% rename from matemat/db/database.py rename to matemat/db/facade.py index c297721..6056e7c 100644 --- a/matemat/db/database.py +++ b/matemat/db/facade.py @@ -1,142 +1,31 @@ from typing import List, Optional -import apsw import bcrypt from matemat.primitives import User, Product -from matemat.exceptions import AuthenticationException +from matemat.exceptions import AuthenticationError, DatabaseConsistencyError +from matemat.db import DatabaseWrapper -class Transaction(object): - - def __init__(self, db: apsw.Connection, wrapper: 'Database', exclusive: bool = True): - self._db: apsw.Connection = db - self._cursor = None - self._excl = exclusive - self._wrapper: Database = wrapper - self._is_dummy: bool = False - - def __enter__(self): - if self._wrapper._in_transaction: - self._is_dummy = True - return self._db.cursor() - else: - self._is_dummy = False - self._cursor = self._db.cursor() - self._wrapper._in_transaction = True - if self._excl: - self._cursor.execute('BEGIN EXCLUSIVE') - else: - self._cursor.execute('BEGIN') - return self._cursor - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._is_dummy: - return - if exc_type is None: - self._cursor.execute('COMMIT') - else: - self._cursor.execute('ROLLBACK') - self._wrapper._in_transaction = False - - -class Database(object): - SCHEMA_VERSION = 1 - - SCHEMA = ''' - CREATE TABLE users ( - user_id INTEGER PRIMARY KEY, - username TEXT NOT NULL, - email TEXT DEFAULT NULL, - password TEXT NOT NULL, - touchkey TEXT DEFAULT NULL, - is_admin INTEGER(1) NOT NULL DEFAULT 0, - is_member INTEGER(1) NOT NULL DEFAULT 1, - balance INTEGER(8) NOT NULL DEFAULT 0, - lastchange INTEGER(8) NOT NULL DEFAULT 0 - ); - CREATE TABLE products ( - product_id INTEGER PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - stock INTEGER(8) NOT NULL DEFAULT 0, - price_member INTEGER(8) NOT NULL, - price_non_member INTEGER(8) NOT NULL - ); - CREATE TABLE consumption ( - user_id INTEGER NOT NULL, - product_id INTEGER NOT NULL, - count INTEGER(8) NOT NULL DEFAULT 0, - PRIMARY KEY (user_id, product_id), - FOREIGN KEY (user_id) REFERENCES users(user_id) - ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (product_id) REFERENCES products(product_id) - ON DELETE CASCADE ON UPDATE CASCADE - ) - ''' +class DatabaseFacade(object): def __init__(self, filename: str): - self._filename: str = filename - self._sqlite_db: apsw.Connection = None - self._in_transaction: bool = False + self.db: DatabaseWrapper = DatabaseWrapper(filename) def __enter__(self): - self.connect() + self.db.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + self.db.__exit__(exc_type, exc_val, exc_tb) - def transaction(self, exclusive: bool = True) -> Transaction: - return Transaction(self._sqlite_db, self, exclusive) - - def _setup(self): - with self.transaction() as c: - version: int = self._user_version - if version < 1: - c.execute(self.SCHEMA) - elif version < self.SCHEMA_VERSION: - self._upgrade(old=version, new=self.SCHEMA_VERSION) - self._user_version = self.SCHEMA_VERSION - - def _upgrade(self, old: int, new: int): - pass - - def connect(self): - if self.is_connected(): - raise RuntimeError(f'Database connection to {self._filename} is already established.') - self._sqlite_db = apsw.Connection(self._filename) - self._setup() - - def close(self): - if self._sqlite_db is None: - raise RuntimeError(f'Database connection to {self._filename} is not established.') - if self.in_transaction(): - raise RuntimeError(f'A transaction is still ongoing.') - self._sqlite_db.close() - self._sqlite_db = None - - def in_transaction(self) -> bool: - return self._in_transaction - - def is_connected(self) -> bool: - return self._sqlite_db is not None - - @property - def _user_version(self) -> int: - cursor = self._sqlite_db.cursor() - cursor.execute('PRAGMA user_version') - version = int(cursor.fetchone()[0]) - return version - - @_user_version.setter - def _user_version(self, version: int): - cursor = self._sqlite_db.cursor() - cursor.execute(f'PRAGMA user_version = {version}') + def transaction(self, exclusive: bool = True): + return self.db.transaction(exclusive=exclusive) def list_users(self) -> List[User]: users: List[User] = [] - with self.transaction(exclusive=False) as c: + with self.db.transaction(exclusive=False) as c: for row in c.execute(''' SELECT user_id, username, email, is_admin, is_member FROM users @@ -153,7 +42,7 @@ class Database(object): member: bool = True) -> User: pwhash: str = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12)) user_id: int = -1 - with self.transaction() as c: + with self.db.transaction() as c: c.execute('SELECT user_id FROM users WHERE username = ?', [username]) if c.fetchone() is not None: raise ValueError(f'A user with the name \'{username}\' already exists.') @@ -167,37 +56,43 @@ class Database(object): 'admin': admin, 'member': member }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'create_user should affect 1 users row, but affected {affected}') c.execute('SELECT last_insert_rowid()') user_id = int(c.fetchone()[0]) return User(user_id, username, email, admin, member) - def login(self, username: str, password: str) -> Optional[User]: - with self.transaction(exclusive=False) as c: + def login(self, username: str, password: Optional[str] = None, touchkey: Optional[str] = None) -> User: + if (password is None) == (touchkey is None): + raise ValueError('Exactly one of password and touchkey must be provided') + with self.db.transaction(exclusive=False) as c: c.execute(''' - SELECT user_id, username, email, password, is_admin, is_member + SELECT user_id, username, email, password, touchkey, is_admin, is_member FROM users WHERE username = ? ''', [username]) row = c.fetchone() if row is None: - return None - user_id, username, email, pwhash, admin, member = row - if not bcrypt.checkpw(password.encode('utf-8'), pwhash): - return None + raise AuthenticationError('User does not exist') + user_id, username, email, pwhash, tkhash, admin, member = row + if password is not None and not bcrypt.checkpw(password.encode('utf-8'), pwhash): + raise AuthenticationError('Password mismatch') + elif touchkey is not None and not bcrypt.checkpw(touchkey.encode('utf-8'), tkhash): + raise AuthenticationError('Touchkey mismatch') return User(user_id, username, email, admin, member) - def change_password(self, user: User, oldpass: str, newpass: str, newpass2: str, verify_password: bool = True): - if newpass != newpass2: - raise ValueError('New passwords don\'t match.') - with self.transaction() as c: + def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True): + with self.db.transaction() as c: c.execute(''' SELECT password FROM users WHERE user_id = ? ''', [user.id]) row = c.fetchone() if row is None: - raise AuthenticationException('User does not exist in database.') + raise AuthenticationError('User does not exist in database.') if verify_password and not bcrypt.checkpw(oldpass.encode('utf-8'), row[0]): - raise AuthenticationException('Old password does not match.') + raise AuthenticationError('Old password does not match.') pwhash: str = bcrypt.hashpw(newpass.encode('utf-8'), bcrypt.gensalt(12)) c.execute(''' UPDATE users SET password = :pwhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id @@ -205,17 +100,21 @@ class Database(object): 'user_id': user.id, 'pwhash': pwhash }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'change_password should affect 1 users row, but affected {affected}') def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True): - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' SELECT password FROM users WHERE user_id = ? ''', [user.id]) row = c.fetchone() if row is None: - raise AuthenticationException('User does not exist in database.') + raise AuthenticationError('User does not exist in database.') if verify_password and not bcrypt.checkpw(password.encode('utf-8'), row[0]): - raise AuthenticationException('Password does not match.') + raise AuthenticationError('Password does not match.') tkhash: str = bcrypt.hashpw(touchkey.encode('utf-8'), bcrypt.gensalt(12)) if touchkey is not None else None c.execute(''' UPDATE users SET touchkey = :tkhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id @@ -223,9 +122,13 @@ class Database(object): 'user_id': user.id, 'tkhash': tkhash }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'change_touchkey should affect 1 users row, but affected {affected}') def change_user(self, user: User): - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' UPDATE users SET email = :email, @@ -239,17 +142,25 @@ class Database(object): 'is_admin': user.is_admin, 'is_member': user.is_member }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'change_user should affect 1 users row, but affected {affected}') def delete_user(self, user: User): - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' DELETE FROM users WHERE user_id = ? ''', [user.id]) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'delete_user should affect 1 users row, but affected {affected}') def list_products(self) -> List[Product]: products: List[Product] = [] - with self.transaction(exclusive=False) as c: + with self.db.transaction(exclusive=False) as c: for row in c.execute(''' SELECT product_id, name, price_member, price_external FROM products @@ -260,7 +171,7 @@ class Database(object): def create_product(self, name: str, price_member: int, price_non_member: int) -> Product: product_id: int = -1 - with self.transaction() as c: + with self.db.transaction() as c: c.execute('SELECT product_id FROM products WHERE name = ?', [name]) if c.fetchone() is not None: raise ValueError(f'A product with the name \'{name}\' already exists.') @@ -273,13 +184,17 @@ class Database(object): 'price_non_member': price_non_member }) c.execute('SELECT last_insert_rowid()') + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'create_product should affect 1 products row, but affected {affected}') product_id = int(c.fetchone()[0]) return Product(product_id, name, price_member, price_non_member) def change_product(self, product: Product): if product.id == -1: raise ValueError('Invalid product ID') - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' UPDATE products SET @@ -292,20 +207,24 @@ class Database(object): 'price_member': product.price_member, 'price_non_member': product.price_non_member }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'change_product should affect 1 products row, but affected {affected}') def delete_product(self, product: Product): - if product.id == -1: - raise ValueError('Invalid product ID') - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' DELETE FROM products WHERE product_id = ? ''', [product.id]) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'delete_product should affect 1 products row, but affected {affected}') def increment_consumption(self, user: User, product: Product, count: int = 1): - if product.id == -1: - raise ValueError('Invalid product ID') - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' SELECT count FROM consumption @@ -332,6 +251,10 @@ class Database(object): 'product_id': product.id, 'count': count }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'increment_consumption should affect 1 consumption row, but affected {affected}') c.execute(''' UPDATE users SET balance = balance - :cost @@ -339,6 +262,10 @@ class Database(object): 'user_id': user.id, 'cost': count * product.price_member if user.is_member else count * product.price_non_member }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'increment_consumption should affect 1 users row, but affected {affected}') c.execute(''' UPDATE products SET stock = stock - :count @@ -347,11 +274,15 @@ class Database(object): 'product_id': product.id, 'count': count }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError( + f'increment_consumption should affect 1 products row, but affected {affected}') def restock(self, product: Product, count: int): if product.id == -1: raise ValueError('Invalid product ID') - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' UPDATE products SET stock = stock + :count @@ -360,9 +291,12 @@ class Database(object): 'product_id': product.id, 'count': count }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}') def deposit(self, user: User, amount: int): - with self.transaction() as c: + with self.db.transaction() as c: c.execute(''' UPDATE users SET balance = balance + :amount @@ -371,3 +305,6 @@ class Database(object): 'user_id': user.id, 'amount': amount }) + affected = c.execute('SELECT changes()').fetchone()[0] + if affected != 1: + raise DatabaseConsistencyError(f'deposit should affect 1 users row, but affected {affected}') diff --git a/matemat/db/test/test_facade.py b/matemat/db/test/test_facade.py new file mode 100644 index 0000000..049edcb --- /dev/null +++ b/matemat/db/test/test_facade.py @@ -0,0 +1,69 @@ + +import unittest + +from matemat.db import Database +from matemat.exceptions import AuthenticationError, DatabaseConsistencyError + + +class DatabaseTest(unittest.TestCase): + + def setUp(self): + self.db = Database(':memory:') + + def test_create_user(self): + with self.db as db: + with db.transaction() as c: + db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + c.execute("SELECT * FROM users") + row = c.fetchone() + if row is None: + self.fail() + self.assertEqual('testuser', row[1]) + self.assertEqual('testuser@example.com', row[2]) + self.assertEqual(0, row[5]) + self.assertEqual(1, row[6]) + + def test_create_existing_user(self): + with self.db as db: + db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + with self.assertRaises(ValueError): + db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com') + + def test_change_password(self): + with self.db as db: + user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + db.login('testuser', 'supersecurepassword') + # Normal password change should succeed + db.change_password(user, 'supersecurepassword', 'mynewpassword') + with self.assertRaises(AuthenticationError): + db.login('testuser', 'supersecurepassword') + db.login('testuser', 'mynewpassword') + with self.assertRaises(AuthenticationError): + # Should fail due to wrong old password + db.change_password(user, 'iforgotmypassword', 'mynewpassword') + db.login('testuser', 'mynewpassword') + # This should pass even though the old password is not known (admin password reset) + db.change_password(user, '42', 'adminpasswordreset', verify_password=False) + with self.assertRaises(AuthenticationError): + db.login('testuser', 'mynewpassword') + db.login('testuser', 'adminpasswordreset') + + def test_change_touchkey(self): + with self.db as db: + user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') + db.change_touchkey(user, 'supersecurepassword', '0123') + db.login('testuser', touchkey='0123') + # Normal touchkey change should succeed + db.change_touchkey(user, 'supersecurepassword', touchkey='4567') + with self.assertRaises(AuthenticationError): + db.login('testuser', touchkey='0123') + db.login('testuser', touchkey='4567') + with self.assertRaises(AuthenticationError): + # Should fail due to wrong old password + db.change_touchkey(user, 'iforgotmypassword', '89ab') + db.login('testuser', touchkey='4567') + # This should pass even though the old password is not known (admin password reset) + db.change_touchkey(user, '42', '89ab', verify_password=False) + with self.assertRaises(AuthenticationError): + db.login('testuser', touchkey='4567') + db.login('testuser', touchkey='89ab') diff --git a/matemat/db/test/test_database.py b/matemat/db/test/test_wrapper.py similarity index 71% rename from matemat/db/test/test_database.py rename to matemat/db/test/test_wrapper.py index a03fbd4..334ecdf 100644 --- a/matemat/db/test/test_database.py +++ b/matemat/db/test/test_wrapper.py @@ -1,20 +1,20 @@ import unittest -from matemat.db import Database +from matemat.db import DatabaseWrapper class DatabaseTest(unittest.TestCase): def setUp(self): - self.db = Database(':memory:') + self.db = DatabaseWrapper(':memory:') def test_create_schema(self): """ Test creation of database schema in an empty database """ with self.db as db: - self.assertEqual(Database.SCHEMA_VERSION, db._user_version) + self.assertEqual(DatabaseWrapper.SCHEMA_VERSION, db._user_version) def test_in_transaction(self): """ @@ -74,22 +74,3 @@ class DatabaseTest(unittest.TestCase): c = db._sqlite_db.cursor() c.execute("SELECT * FROM users") self.assertIsNone(c.fetchone()) - - def test_create_user(self): - with self.db as db: - with db.transaction() as c: - db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') - c.execute("SELECT * FROM users") - row = c.fetchone() - if row is None: - self.fail() - self.assertEqual('testuser', row[1]) - self.assertEqual('testuser@example.com', row[2]) - self.assertEqual(0, row[5]) - self.assertEqual(1, row[6]) - - def test_create_existing_user(self): - with self.db as db: - db.create_user('testuser', 'supersecurepassword', 'testuser@example.com') - with self.assertRaises(ValueError): - db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com') diff --git a/matemat/db/wrapper.py b/matemat/db/wrapper.py new file mode 100644 index 0000000..3b92936 --- /dev/null +++ b/matemat/db/wrapper.py @@ -0,0 +1,129 @@ + +import apsw + + +class Transaction(object): + + def __init__(self, db: apsw.Connection, wrapper: 'DatabaseWrapper', exclusive: bool = True): + self._db: apsw.Connection = db + self._cursor = None + self._excl = exclusive + self._wrapper: DatabaseWrapper = wrapper + self._is_dummy: bool = False + + def __enter__(self): + if self._wrapper._in_transaction: + self._is_dummy = True + return self._db.cursor() + else: + self._is_dummy = False + self._cursor = self._db.cursor() + self._wrapper._in_transaction = True + if self._excl: + self._cursor.execute('BEGIN EXCLUSIVE') + else: + self._cursor.execute('BEGIN') + return self._cursor + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._is_dummy: + return + if exc_type is None: + self._cursor.execute('COMMIT') + else: + self._cursor.execute('ROLLBACK') + self._wrapper._in_transaction = False + + +class DatabaseWrapper(object): + SCHEMA_VERSION = 1 + + SCHEMA = ''' + CREATE TABLE users ( + user_id INTEGER PRIMARY KEY, + username TEXT NOT NULL, + email TEXT DEFAULT NULL, + password TEXT NOT NULL, + touchkey TEXT DEFAULT NULL, + is_admin INTEGER(1) NOT NULL DEFAULT 0, + is_member INTEGER(1) NOT NULL DEFAULT 1, + balance INTEGER(8) NOT NULL DEFAULT 0, + lastchange INTEGER(8) NOT NULL DEFAULT 0 + ); + CREATE TABLE products ( + product_id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + stock INTEGER(8) NOT NULL DEFAULT 0, + price_member INTEGER(8) NOT NULL, + price_non_member INTEGER(8) NOT NULL + ); + CREATE TABLE consumption ( + user_id INTEGER NOT NULL, + product_id INTEGER NOT NULL, + count INTEGER(8) NOT NULL DEFAULT 0, + PRIMARY KEY (user_id, product_id), + FOREIGN KEY (user_id) REFERENCES users(user_id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (product_id) REFERENCES products(product_id) + ON DELETE CASCADE ON UPDATE CASCADE + ) + ''' + + def __init__(self, filename: str): + self._filename: str = filename + self._sqlite_db: apsw.Connection = None + self._in_transaction: bool = False + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def transaction(self, exclusive: bool = True) -> Transaction: + return Transaction(self._sqlite_db, self, exclusive) + + def _setup(self): + with self.transaction() as c: + version: int = self._user_version + if version < 1: + c.execute(self.SCHEMA) + elif version < self.SCHEMA_VERSION: + self._upgrade(old=version, new=self.SCHEMA_VERSION) + self._user_version = self.SCHEMA_VERSION + + def _upgrade(self, old: int, new: int): + pass + + def connect(self): + if self.is_connected(): + raise RuntimeError(f'Database connection to {self._filename} is already established.') + self._sqlite_db = apsw.Connection(self._filename) + self._setup() + + def close(self): + if self._sqlite_db is None: + raise RuntimeError(f'Database connection to {self._filename} is not established.') + if self.in_transaction(): + raise RuntimeError(f'A transaction is still ongoing.') + self._sqlite_db.close() + self._sqlite_db = None + + def in_transaction(self) -> bool: + return self._in_transaction + + def is_connected(self) -> bool: + return self._sqlite_db is not None + + @property + def _user_version(self) -> int: + cursor = self._sqlite_db.cursor() + cursor.execute('PRAGMA user_version') + version = int(cursor.fetchone()[0]) + return version + + @_user_version.setter + def _user_version(self, version: int): + cursor = self._sqlite_db.cursor() + cursor.execute(f'PRAGMA user_version = {version}') diff --git a/matemat/exceptions/AuthenticatonException.py b/matemat/exceptions/AuthenticatonError.py similarity index 61% rename from matemat/exceptions/AuthenticatonException.py rename to matemat/exceptions/AuthenticatonError.py index 55eed7c..4494a33 100644 --- a/matemat/exceptions/AuthenticatonException.py +++ b/matemat/exceptions/AuthenticatonError.py @@ -1,11 +1,11 @@ -class AuthenticationException(BaseException): +class AuthenticationError(BaseException): def __init__(self, msg: str = None): self._msg = msg def __str__(self) -> str: - return f'AuthenticationException: {self._msg}' + return f'AuthenticationErro: {self._msg}' @property def msg(self) -> str: diff --git a/matemat/exceptions/DatabaseConsistencyError.py b/matemat/exceptions/DatabaseConsistencyError.py new file mode 100644 index 0000000..d07d7a1 --- /dev/null +++ b/matemat/exceptions/DatabaseConsistencyError.py @@ -0,0 +1,12 @@ + +class DatabaseConsistencyError(BaseException): + + def __init__(self, msg: str = None): + self._msg = msg + + def __str__(self) -> str: + return f'DatabaseConsistencyError: {self._msg}' + + @property + def msg(self) -> str: + return self._msg diff --git a/matemat/exceptions/__init__.py b/matemat/exceptions/__init__.py index 08617cb..cacbbb9 100644 --- a/matemat/exceptions/__init__.py +++ b/matemat/exceptions/__init__.py @@ -1,2 +1,3 @@ -from .AuthenticatonException import AuthenticationException +from .AuthenticatonError import AuthenticationError +from .DatabaseConsistencyError import DatabaseConsistencyError