diff --git a/doc b/doc index 5335524..2ce0e26 160000 --- a/doc +++ b/doc @@ -1 +1 @@ -Subproject commit 5335524d3e57c7551f31c7e21fc04c464b23429a +Subproject commit 2ce0e26b101192b92061c299ffc0e5524104f215 diff --git a/matemat/db/facade.py b/matemat/db/facade.py index 67f1c5f..64418d9 100644 --- a/matemat/db/facade.py +++ b/matemat/db/facade.py @@ -232,7 +232,7 @@ class MatematDatabase(object): 'tkhash': tkhash }) - def change_user(self, user: User, **kwargs)\ + def change_user(self, user: User, agent: User, **kwargs)\ -> None: """ Commit changes to the user in the database. If writing the requested changes succeeded, the values are updated @@ -240,6 +240,7 @@ class MatematDatabase(object): the ID field in the provided user object. :param user: The user object to update and to identify the requested user by. + :param agent: The user that is performing the change. :param kwargs: The properties to change. :raises DatabaseConsistencyError: If the user represented by the object does not exist. """ @@ -250,6 +251,25 @@ class MatematDatabase(object): is_admin: bool = kwargs['is_admin'] if 'is_admin' in kwargs else user.is_admin is_member: bool = kwargs['is_member'] if 'is_member' in kwargs else user.is_member with self.db.transaction() as c: + c.execute('SELECT balance FROM users WHERE user_id = :user_id', {'user_id': user.id}) + row = c.fetchone() + if row is None: + raise DatabaseConsistencyError(f'User with ID {user.id} does not exist') + oldbalance: int = row[0] + if balance != oldbalance: + c.execute(''' + INSERT INTO transactions (user_id, value, old_balance) + VALUES (:user_id, :value, :old_balance) + ''', { + 'user_id': user.id, + 'value': balance - oldbalance, + 'old_balance': oldbalance + }) + # TODO: Implement reason field + c.execute(''' + INSERT INTO modifications (ta_id, agent_id, reason) + VALUES (last_insert_rowid(), :agent_id, NULL) + ''', {'agent_id': agent.id}) c.execute(''' UPDATE users SET username = :username, @@ -411,60 +431,38 @@ class MatematDatabase(object): raise DatabaseConsistencyError( f'delete_product should affect 1 products row, but affected {affected}') - def increment_consumption(self, user: User, product: Product, count: int = 1) -> None: + def increment_consumption(self, user: User, product: Product) -> None: """ Decrement the user's balance by the price of the product, decrement the products stock, and create an entry in the statistics table. + :param user: The user buying a product. :param product: The product the user is buying. - :param count: How many units of the product the user is buying, defaults to 1. :raises DatabaseConsistencyError: If the user or the product does not exist in the database. """ + price: int = product.price_member if user.is_member else product.price_non_member with self.db.transaction() as c: - # Retrieve the consumption entry for the (user, product) pair, if any. c.execute(''' - SELECT count - FROM consumption - WHERE user_id = :user_id - AND product_id = :product_id + INSERT INTO transactions (user_id, value, old_balance) + VALUES (:user_id, :value, :old_balance) ''', { 'user_id': user.id, + 'value': -price, + 'old_balance': user.balance + }) + c.execute(''' + INSERT INTO consumptions (ta_id, product_id) + VALUES (last_insert_rowid(), :product_id) + ''', { 'product_id': product.id }) - row = c.fetchone() - if row is None: - # If the entry does not exist, create a new one. - c.execute(''' - INSERT INTO consumption (user_id, product_id, count) - VALUES (:user_id, :product_id, :count) - ''', { - 'user_id': user.id, - 'product_id': product.id, - 'count': count - }) - else: - # If the entry exists, update the consumption count. - c.execute(''' - UPDATE consumption - SET count = count + :count - WHERE user_id = :user_id AND product_id = :product_id - ''', { - 'user_id': user.id, - 'product_id': product.id, - 'count': count - }) - # Make sure exactly one consumption row was updated/inserted. - affected = c.execute('SELECT changes()').fetchone()[0] - if affected != 1: - raise DatabaseConsistencyError( - f'increment_consumption should affect 1 consumption row, but affected {affected}') - # Compute the cost of the transaction and subtract it from the user's account balance. + # Subtract the price from the user's account balance. c.execute(''' UPDATE users SET balance = balance - :cost WHERE user_id = :user_id''', { 'user_id': user.id, - 'cost': count * product.price_member if user.is_member else count * product.price_non_member + 'cost': price }) # Make sure exactly one user row was updated. affected = c.execute('SELECT changes()').fetchone()[0] @@ -474,11 +472,10 @@ class MatematDatabase(object): # Subtract the number of purchased units from the product's stock. c.execute(''' UPDATE products - SET stock = stock - :count + SET stock = stock - 1 WHERE product_id = :product_id ''', { 'product_id': product.id, - 'count': count }) # Make sure exactly one product row was updated. affected = c.execute('SELECT changes()').fetchone()[0] @@ -509,6 +506,7 @@ class MatematDatabase(object): def deposit(self, user: User, amount: int) -> None: """ Update the account balance of a user. + :param user: The user to update the account balance for. :param amount: The amount to add to the account balance. :raises DatabaseConsistencyError: If the user represented by the object does not exist. @@ -516,6 +514,18 @@ class MatematDatabase(object): if amount < 0: raise ValueError('Cannot deposit a negative value') with self.db.transaction() as c: + c.execute(''' + INSERT INTO transactions (user_id, value, old_balance) + VALUES (:user_id, :value, :old_balance) + ''', { + 'user_id': user.id, + 'value': amount, + 'old_balance': user.balance + }) + c.execute(''' + INSERT INTO deposits (ta_id) + VALUES (last_insert_rowid()) + ''') c.execute(''' UPDATE users SET balance = balance + :amount diff --git a/matemat/db/migrations.py b/matemat/db/migrations.py new file mode 100644 index 0000000..28274d6 --- /dev/null +++ b/matemat/db/migrations.py @@ -0,0 +1,115 @@ + +from typing import Dict + +import sqlite3 + + +def migrate_schema_1_to_2(c: sqlite3.Cursor): + # Create missing tables + c.execute(''' + CREATE TABLE transactions ( + ta_id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + value INTEGER(8) NOT NULL, + old_balance INTEGER(8) NOT NULL, + date INTEGER(8) DEFAULT (STRFTIME('%s', 'now')), + FOREIGN KEY (user_id) REFERENCES users(user_id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + ''') + c.execute(''' + CREATE TABLE consumptions ( + ta_id INTEGER PRIMARY KEY, + product_id INTEGER DEFAULT NULL, + FOREIGN KEY (ta_id) REFERENCES transactions(ta_id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (product_id) REFERENCES products(product_id) + ON DELETE SET NULL ON UPDATE CASCADE + ); + ''') + c.execute(''' + CREATE TABLE deposits ( + ta_id INTEGER PRIMARY KEY, + FOREIGN KEY (ta_id) REFERENCES transactions(ta_id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + ''') + c.execute(''' + CREATE TABLE modifications ( + ta_id INTEGER NOT NULL, + agent_id INTEGER NOT NULL, + reason TEXT DEFAULT NULL, + PRIMARY KEY (ta_id), + FOREIGN KEY (ta_id) REFERENCES transactions(ta_id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (agent_id) REFERENCES users(user_id) + ON DELETE SET NULL ON UPDATE CASCADE + ); + ''') + + # + # Convert entries from the old consumption table into entries for the new consumptions table + # + + # Fetch current users, their balance and membership status + c.execute('SELECT user_id, balance, is_member FROM users') + balances: Dict[int, int] = dict() + memberships: Dict[int, bool] = dict() + for user_id, balance, member in c: + balances[user_id] = balance + memberships[user_id] = bool(member) + + # Fetch current products and their prices + c.execute('SELECT product_id, price_member, price_non_member FROM products') + prices_member: Dict[int, int] = dict() + prices_non_member: Dict[int, int] = dict() + for product_id, price_member, price_non_member in c: + prices_member[product_id] = price_member + prices_non_member[product_id] = price_non_member + + # As the following migration does reverse insertions, compute the max. primary key that can occur, and + # further down count downward from there + c.execute('SELECT SUM(count) FROM consumption') + ta_id: int = c.fetchone()[0] + + # Iterate (users x products) + for user_id in balances.keys(): + member: bool = memberships[user_id] + for product_id in prices_member: + price: int = prices_member[product_id] if member else prices_non_member[product_id] + + # Select the number of items the user has bought from this product + c.execute(''' + SELECT consumption.count FROM consumption + WHERE user_id = :user_id AND product_id = :product_id + ''', { + 'user_id': user_id, + 'product_id': product_id + }) + row = c.fetchone() + if row is not None: + count: int = row[0] + # Insert one row per bought item, setting the date to NULL, as it is not known + for _ in range(count): + # This migration "goes back in time", so after processing a purchase entry, "locally + # refund" the payment to reconstruct the "older" entries + balances[user_id] += price + # Insert into base table + c.execute(''' + INSERT INTO transactions (ta_id, user_id, value, old_balance, date) + VALUES (:ta_id, :user_id, :value, :old_balance, NULL) + ''', { + 'ta_id': ta_id, + 'user_id': user_id, + 'value': -price, + 'old_balance': balances[user_id] + }) + # Insert into specialization table + c.execute('INSERT INTO consumptions (ta_id, product_id) VALUES (:ta_id, :product_id)', { + 'ta_id': ta_id, + 'product_id': product_id + }) + # Decrement the transaction table insertion primary key + ta_id -= 1 + # Drop the old consumption table + c.execute('DROP TABLE consumption') diff --git a/matemat/db/schemas.py b/matemat/db/schemas.py new file mode 100644 index 0000000..5434ec1 --- /dev/null +++ b/matemat/db/schemas.py @@ -0,0 +1,103 @@ +from typing import Dict, List + +SCHEMAS: Dict[int, List[str]] = dict() + +SCHEMAS[1] = [ + ''' + CREATE TABLE users ( + user_id INTEGER PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + email TEXT DEFAULT NULL, + password TEXT NOT NULL, + touchkey TEXT DEFAULT NULL, + is_admin INTEGER(1) NOT NULL DEFAULT 0, + is_member INTEGER(1) NOT NULL DEFAULT 1, + balance INTEGER(8) NOT NULL DEFAULT 0, + lastchange INTEGER(8) NOT NULL DEFAULT 0 + ); + ''', + ''' + CREATE TABLE products ( + product_id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + stock INTEGER(8) NOT NULL DEFAULT 0, + price_member INTEGER(8) NOT NULL, + price_non_member INTEGER(8) NOT NULL + ); + ''', + ''' + CREATE TABLE consumption ( + user_id INTEGER NOT NULL, + product_id INTEGER NOT NULL, + count INTEGER(8) NOT NULL DEFAULT 0, + PRIMARY KEY (user_id, product_id), + FOREIGN KEY (user_id) REFERENCES users(user_id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (product_id) REFERENCES products(product_id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + '''] + +SCHEMAS[2] = [ + ''' + CREATE TABLE users ( + user_id INTEGER PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + email TEXT DEFAULT NULL, + password TEXT NOT NULL, + touchkey TEXT DEFAULT NULL, + is_admin INTEGER(1) NOT NULL DEFAULT 0, + is_member INTEGER(1) NOT NULL DEFAULT 1, + balance INTEGER(8) NOT NULL DEFAULT 0, + lastchange INTEGER(8) NOT NULL DEFAULT 0 + ); + ''', + ''' + CREATE TABLE products ( + product_id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + stock INTEGER(8) NOT NULL DEFAULT 0, + price_member INTEGER(8) NOT NULL, + price_non_member INTEGER(8) NOT NULL + ); + ''', + ''' + CREATE TABLE transactions ( -- "superclass" of the following 3 tables + ta_id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + value INTEGER(8) NOT NULL, + old_balance INTEGER(8) NOT NULL, + date INTEGER(8) DEFAULT (STRFTIME('%s', 'now')), + FOREIGN KEY (user_id) REFERENCES users(user_id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + ''', + ''' + CREATE TABLE consumptions ( -- transactions involving buying a product + ta_id INTEGER PRIMARY KEY, + product_id INTEGER DEFAULT NULL, + FOREIGN KEY (ta_id) REFERENCES transactions(ta_id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (product_id) REFERENCES products(product_id) + ON DELETE SET NULL ON UPDATE CASCADE + ); + ''', + ''' + CREATE TABLE deposits ( -- transactions involving depositing cash + ta_id INTEGER PRIMARY KEY, + FOREIGN KEY (ta_id) REFERENCES transactions(ta_id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + ''', + ''' + CREATE TABLE modifications ( -- transactions involving balance modification by an admin + ta_id INTEGER NOT NULL, + agent_id INTEGER NOT NULL, + reason TEXT DEFAULT NULL, + PRIMARY KEY (ta_id), + FOREIGN KEY (ta_id) REFERENCES transactions(ta_id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (agent_id) REFERENCES users(user_id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + '''] diff --git a/matemat/db/test/test_facade.py b/matemat/db/test/test_facade.py index d15fa64..04c0eb4 100644 --- a/matemat/db/test/test_facade.py +++ b/matemat/db/test/test_facade.py @@ -143,18 +143,23 @@ class DatabaseTest(unittest.TestCase): def test_change_user(self) -> None: with self.db as db: + agent = db.create_user('admin', 'supersecurepassword', 'admin@example.com', True, True) user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True) - user.email = 'newaddress@example.com' - user.is_admin = False - user.is_member = False - db.change_user(user) - checkuser = db.login('testuser', 'supersecurepassword') - self.assertEqual('newaddress@example.com', checkuser.email) + db.change_user(user, agent, email='newaddress@example.com', is_admin=False, is_member=False, balance=4200) + # Changes must be reflected in the passed user object + self.assertEqual('newaddress@example.com', user.email) + self.assertFalse(user.is_admin) + self.assertFalse(user.is_member) + self.assertEqual(4200, user.balance) + # Changes must be reflected in the database + checkuser = db.get_user(user.id) + self.assertEqual('newaddress@example.com', user.email) self.assertFalse(checkuser.is_admin) self.assertFalse(checkuser.is_member) + self.assertEqual(4200, checkuser.balance) user._user_id = -1 with self.assertRaises(DatabaseConsistencyError): - db.change_user(user) + db.change_user(user, agent, is_member='True') def test_delete_user(self) -> None: with self.db as db: @@ -221,14 +226,18 @@ class DatabaseTest(unittest.TestCase): def test_change_product(self) -> None: with self.db as db: product = db.create_product('Club Mate', 200, 200) - product.name = 'Flora Power Mate' - product.price_member = 150 - product.price_non_member = 250 - db.change_product(product) - checkproduct = db.list_products()[0] + db.change_product(product, name='Flora Power Mate', price_member=150, price_non_member=250, stock=42) + # Changes must be reflected in the passed object + self.assertEqual('Flora Power Mate', product.name) + self.assertEqual(150, product.price_member) + self.assertEqual(250, product.price_non_member) + self.assertEqual(42, product.stock) + # Changes must be reflected in the database + checkproduct = db.get_product(product.id) self.assertEqual('Flora Power Mate', checkproduct.name) self.assertEqual(150, checkproduct.price_member) self.assertEqual(250, checkproduct.price_non_member) + self.assertEqual(42, checkproduct.stock) product._product_id = -1 with self.assertRaises(DatabaseConsistencyError): db.change_product(product) @@ -313,14 +322,14 @@ class DatabaseTest(unittest.TestCase): db.restock(fritzmate, 10) # user1 is somewhat addicted to caffeine - db.increment_consumption(user1, clubmate, 1) - db.increment_consumption(user1, clubmate, 2) - db.increment_consumption(user1, florapowermate, 3) + for _ in range(3): + db.increment_consumption(user1, clubmate) + db.increment_consumption(user1, florapowermate) # user2 is reeeally addicted - db.increment_consumption(user2, clubmate, 7) - db.increment_consumption(user2, florapowermate, 3) - db.increment_consumption(user2, florapowermate, 4) + for _ in range(7): + db.increment_consumption(user2, clubmate) + db.increment_consumption(user2, florapowermate) with db.transaction(exclusive=False) as c: c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user1.id]) diff --git a/matemat/db/test/test_migrations.py b/matemat/db/test/test_migrations.py new file mode 100644 index 0000000..0488a96 --- /dev/null +++ b/matemat/db/test/test_migrations.py @@ -0,0 +1,124 @@ + +import unittest + +import sqlite3 + +from matemat.db import DatabaseWrapper +from matemat.db.schemas import SCHEMAS + + +class TestMigrations(unittest.TestCase): + + def setUp(self): + # Create an in-memory database for testing + self.db = DatabaseWrapper(':memory:') + + def _initialize_db(self, schema_version: int): + self.db._sqlite_db = sqlite3.connect(':memory:') + cursor: sqlite3.Cursor = self.db._sqlite_db.cursor() + cursor.execute('BEGIN EXCLUSIVE') + for cmd in SCHEMAS[schema_version]: + cursor.execute(cmd) + cursor.execute('COMMIT') + + def test_downgrade_fail(self): + # Test that downgrades are forbidden + self.db.SCHEMA_VERSION = 1 + self.db._sqlite_db = sqlite3.connect(':memory:') + self.db._sqlite_db.execute('PRAGMA user_version = 2') + with self.assertRaises(RuntimeError): + with self.db: + pass + + def test_upgrade_1_to_2(self): + # Setup test db with example entries covering - hopefully - all cases + self._initialize_db(1) + cursor: sqlite3.Cursor = self.db._sqlite_db.cursor() + cursor.execute(''' + INSERT INTO users VALUES + (1, 'testadmin', 'a@b.c', '$2a$10$herebehashes', NULL, 1, 1, 1337, 0), + (2, 'testuser', NULL, '$2a$10$herebehashes', '$2a$10$herebehashes', 0, 1, 4242, 0), + (3, 'alien', NULL, '$2a$10$herebehashes', '$2a$10$herebehashes', 0, 0, 1234, 0) + ''') + cursor.execute(''' + INSERT INTO products VALUES + (1, 'Club Mate', 42, 200, 250), + (2, 'Flora Power Mate (1/4l)', 10, 100, 150) + ''') + cursor.execute(''' + INSERT INTO consumption VALUES + (1, 1, 5), (1, 2, 3), (2, 2, 10), (3, 1, 3), (3, 2, 4) + ''') + cursor.execute('PRAGMA user_version = 1') + + # Kick off the migration + self.db._setup() + + # Test whether the new tables were created + cursor.execute('PRAGMA table_info(transactions)') + self.assertNotEqual(0, len(cursor.fetchall())) + cursor.execute('PRAGMA table_info(consumptions)') + self.assertNotEqual(0, len(cursor.fetchall())) + cursor.execute('PRAGMA table_info(deposits)') + self.assertNotEqual(0, len(cursor.fetchall())) + cursor.execute('PRAGMA table_info(modifications)') + self.assertNotEqual(0, len(cursor.fetchall())) + # Test whether the old consumption table was dropped + cursor.execute('PRAGMA table_info(consumption)') + self.assertEqual(0, len(cursor.fetchall())) + + # Test number of entries in the new tables + cursor.execute('SELECT COUNT(ta_id) FROM transactions') + self.assertEqual(25, cursor.fetchone()[0]) + cursor.execute('SELECT COUNT(ta_id) FROM consumptions') + self.assertEqual(25, cursor.fetchone()[0]) + cursor.execute('SELECT COUNT(ta_id) FROM deposits') + self.assertEqual(0, cursor.fetchone()[0]) + cursor.execute('SELECT COUNT(ta_id) FROM modifications') + self.assertEqual(0, cursor.fetchone()[0]) + + # The (user_id=2 x product_id=1) combination should never appear + cursor.execute(''' + SELECT COUNT(t.ta_id) + FROM transactions AS t + LEFT JOIN consumptions AS c + ON t.ta_id = c.ta_id + WHERE t.user_id = 2 AND c.product_id = 1''') + self.assertEqual(0, cursor.fetchone()[0]) + + # Test that one entry per consumption was created, and their values match the negative price + cursor.execute(''' + SELECT COUNT(t.ta_id) + FROM transactions AS t + LEFT JOIN consumptions AS c + ON t.ta_id = c.ta_id + WHERE t.user_id = 1 AND c.product_id = 1 AND t.value = -200''') + self.assertEqual(5, cursor.fetchone()[0]) + cursor.execute(''' + SELECT COUNT(t.ta_id) + FROM transactions AS t + LEFT JOIN consumptions AS c + ON t.ta_id = c.ta_id + WHERE t.user_id = 1 AND c.product_id = 2 AND t.value = -100''') + self.assertEqual(3, cursor.fetchone()[0]) + cursor.execute(''' + SELECT COUNT(t.ta_id) + FROM transactions AS t + LEFT JOIN consumptions AS c + ON t.ta_id = c.ta_id + WHERE t.user_id = 2 AND c.product_id = 2 AND t.value = -100''') + self.assertEqual(10, cursor.fetchone()[0]) + cursor.execute(''' + SELECT COUNT(t.ta_id) + FROM transactions AS t + LEFT JOIN consumptions AS c + ON t.ta_id = c.ta_id + WHERE t.user_id = 3 AND c.product_id = 1 AND t.value = -250''') + self.assertEqual(3, cursor.fetchone()[0]) + cursor.execute(''' + SELECT COUNT(t.ta_id) + FROM transactions AS t + LEFT JOIN consumptions AS c + ON t.ta_id = c.ta_id + WHERE t.user_id = 3 AND c.product_id = 2 AND t.value = -150''') + self.assertEqual(4, cursor.fetchone()[0]) diff --git a/matemat/db/test/test_wrapper.py b/matemat/db/test/test_wrapper.py index e5a52c2..1b38f60 100644 --- a/matemat/db/test/test_wrapper.py +++ b/matemat/db/test/test_wrapper.py @@ -1,6 +1,8 @@ import unittest +import sqlite3 + from matemat.db import DatabaseWrapper diff --git a/matemat/db/wrapper.py b/matemat/db/wrapper.py index 34af430..c0cddc4 100644 --- a/matemat/db/wrapper.py +++ b/matemat/db/wrapper.py @@ -4,6 +4,8 @@ from typing import Any, Optional import sqlite3 from matemat.exceptions import DatabaseConsistencyError +from matemat.db.schemas import SCHEMAS +from matemat.db.migrations import migrate_schema_1_to_2 class Transaction(object): @@ -39,38 +41,8 @@ class Transaction(object): class DatabaseWrapper(object): - SCHEMA_VERSION = 1 - SCHEMA = ''' - CREATE TABLE users ( - user_id INTEGER PRIMARY KEY, - username TEXT UNIQUE NOT NULL, - email TEXT DEFAULT NULL, - password TEXT NOT NULL, - touchkey TEXT DEFAULT NULL, - is_admin INTEGER(1) NOT NULL DEFAULT 0, - is_member INTEGER(1) NOT NULL DEFAULT 1, - balance INTEGER(8) NOT NULL DEFAULT 0, - lastchange INTEGER(8) NOT NULL DEFAULT 0 - ); - CREATE TABLE products ( - product_id INTEGER PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - stock INTEGER(8) NOT NULL DEFAULT 0, - price_member INTEGER(8) NOT NULL, - price_non_member INTEGER(8) NOT NULL - ); - CREATE TABLE consumption ( - user_id INTEGER NOT NULL, - product_id INTEGER NOT NULL, - count INTEGER(8) NOT NULL DEFAULT 0, - PRIMARY KEY (user_id, product_id), - FOREIGN KEY (user_id) REFERENCES users(user_id) - ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (product_id) REFERENCES products(product_id) - ON DELETE CASCADE ON UPDATE CASCADE - ); - ''' + SCHEMA_VERSION = 2 def __init__(self, filename: str) -> None: self._filename: str = filename @@ -92,13 +64,20 @@ class DatabaseWrapper(object): with self.transaction() as c: version: int = self._user_version if version < 1: - c.executescript(self.SCHEMA) + # Don't use executescript, as it issues a COMMIT first + for command in SCHEMAS[self.SCHEMA_VERSION]: + c.execute(command) elif version < self.SCHEMA_VERSION: - self._upgrade(old=version, new=self.SCHEMA_VERSION) + self._upgrade(from_version=version, to_version=self.SCHEMA_VERSION) + elif version > self.SCHEMA_VERSION: + raise RuntimeError('Database schema is newer than supported by this version of Matemat.') self._user_version = self.SCHEMA_VERSION - def _upgrade(self, old: int, new: int) -> None: - pass + def _upgrade(self, from_version: int, to_version: int) -> None: + with self.transaction() as c: + # Note to future s3lph: If there are further migrations, also consider upgrades like 1 -> 3 + if from_version == 1 and to_version == 2: + migrate_schema_1_to_2(c) def connect(self) -> None: if self.is_connected(): diff --git a/matemat/webserver/pagelets/buy.py b/matemat/webserver/pagelets/buy.py index 24ac60c..3b68cee 100644 --- a/matemat/webserver/pagelets/buy.py +++ b/matemat/webserver/pagelets/buy.py @@ -17,12 +17,8 @@ def buy(method: str, with MatematDatabase(config['DatabaseFile']) as db: uid: int = session_vars['authenticated_user'] user = db.get_user(uid) - if 'n' in args: - n = int(str(args.n)) - else: - n = 1 if 'pid' in args: pid = int(str(args.pid)) product = db.get_product(pid) - db.increment_consumption(user, product, n) + db.increment_consumption(user, product) return RedirectResponse('/') diff --git a/matemat/webserver/test/abstract_httpd_test.py b/matemat/webserver/test/abstract_httpd_test.py index 5680836..aadc346 100644 --- a/matemat/webserver/test/abstract_httpd_test.py +++ b/matemat/webserver/test/abstract_httpd_test.py @@ -119,10 +119,12 @@ class MockServer: # Set up logger self.logger: logging.Logger = logging.getLogger('matemat unit test') self.logger.setLevel(logging.DEBUG) + # Disable logging + self.logger.addHandler(logging.NullHandler()) # Initalize a log handler to stderr and set the log format - sh: logging.StreamHandler = logging.StreamHandler() - sh.setFormatter(logging.Formatter('%(asctime)s %(name)s [%(levelname)s]: %(message)s')) - self.logger.addHandler(sh) + # sh: logging.StreamHandler = logging.StreamHandler() + # sh.setFormatter(logging.Formatter('%(asctime)s %(name)s [%(levelname)s]: %(message)s')) + # self.logger.addHandler(sh) class MockSocket(bytes):