Merge branch 'db-tests' into 'master'

Additional unit tests for the Database Facade

See merge request s3lph/matemat!1
This commit is contained in:
s3lph 2018-06-05 17:34:00 +00:00
commit 8059765410
11 changed files with 570 additions and 187 deletions

View file

@ -8,4 +8,4 @@ test:
- pip3 install wheel - pip3 install wheel
- pip3 install -r requirements.txt - pip3 install -r requirements.txt
- python3-coverage run --branch -m unittest discover matemat - python3-coverage run --branch -m unittest discover matemat
- python3-coverage report -m - python3-coverage report -m --include 'matemat/*' --omit '*/test_*.py'

View file

@ -1,2 +1,3 @@
from .database import Database from .wrapper import DatabaseWrapper
from .facade import DatabaseFacade as Database

View file

@ -1,142 +1,31 @@
from typing import List, Optional from typing import List, Optional
import apsw
import bcrypt import bcrypt
from matemat.primitives import User, Product from matemat.primitives import User, Product
from matemat.exceptions import AuthenticationException from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
from matemat.db import DatabaseWrapper
class Transaction(object): class DatabaseFacade(object):
def __init__(self, db: apsw.Connection, wrapper: 'Database', exclusive: bool = True):
self._db: apsw.Connection = db
self._cursor = None
self._excl = exclusive
self._wrapper: Database = wrapper
self._is_dummy: bool = False
def __enter__(self):
if self._wrapper._in_transaction:
self._is_dummy = True
return self._db.cursor()
else:
self._is_dummy = False
self._cursor = self._db.cursor()
self._wrapper._in_transaction = True
if self._excl:
self._cursor.execute('BEGIN EXCLUSIVE')
else:
self._cursor.execute('BEGIN')
return self._cursor
def __exit__(self, exc_type, exc_val, exc_tb):
if self._is_dummy:
return
if exc_type is None:
self._cursor.execute('COMMIT')
else:
self._cursor.execute('ROLLBACK')
self._wrapper._in_transaction = False
class Database(object):
SCHEMA_VERSION = 1
SCHEMA = '''
CREATE TABLE users (
user_id INTEGER PRIMARY KEY,
username TEXT NOT NULL,
email TEXT DEFAULT NULL,
password TEXT NOT NULL,
touchkey TEXT DEFAULT NULL,
is_admin INTEGER(1) NOT NULL DEFAULT 0,
is_member INTEGER(1) NOT NULL DEFAULT 1,
balance INTEGER(8) NOT NULL DEFAULT 0,
lastchange INTEGER(8) NOT NULL DEFAULT 0
);
CREATE TABLE products (
product_id INTEGER PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
stock INTEGER(8) NOT NULL DEFAULT 0,
price_member INTEGER(8) NOT NULL,
price_non_member INTEGER(8) NOT NULL
);
CREATE TABLE consumption (
user_id INTEGER NOT NULL,
product_id INTEGER NOT NULL,
count INTEGER(8) NOT NULL DEFAULT 0,
PRIMARY KEY (user_id, product_id),
FOREIGN KEY (user_id) REFERENCES users(user_id)
ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (product_id) REFERENCES products(product_id)
ON DELETE CASCADE ON UPDATE CASCADE
)
'''
def __init__(self, filename: str): def __init__(self, filename: str):
self._filename: str = filename self.db: DatabaseWrapper = DatabaseWrapper(filename)
self._sqlite_db: apsw.Connection = None
self._in_transaction: bool = False
def __enter__(self): def __enter__(self):
self.connect() self.db.__enter__()
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.close() self.db.__exit__(exc_type, exc_val, exc_tb)
def transaction(self, exclusive: bool = True) -> Transaction: def transaction(self, exclusive: bool = True):
return Transaction(self._sqlite_db, self, exclusive) return self.db.transaction(exclusive=exclusive)
def _setup(self):
with self.transaction() as c:
version: int = self._user_version
if version < 1:
c.execute(self.SCHEMA)
elif version < self.SCHEMA_VERSION:
self._upgrade(old=version, new=self.SCHEMA_VERSION)
self._user_version = self.SCHEMA_VERSION
def _upgrade(self, old: int, new: int):
pass
def connect(self):
if self.is_connected():
raise RuntimeError(f'Database connection to {self._filename} is already established.')
self._sqlite_db = apsw.Connection(self._filename)
self._setup()
def close(self):
if self._sqlite_db is None:
raise RuntimeError(f'Database connection to {self._filename} is not established.')
if self.in_transaction():
raise RuntimeError(f'A transaction is still ongoing.')
self._sqlite_db.close()
self._sqlite_db = None
def in_transaction(self) -> bool:
return self._in_transaction
def is_connected(self) -> bool:
return self._sqlite_db is not None
@property
def _user_version(self) -> int:
cursor = self._sqlite_db.cursor()
cursor.execute('PRAGMA user_version')
version = int(cursor.fetchone()[0])
return version
@_user_version.setter
def _user_version(self, version: int):
cursor = self._sqlite_db.cursor()
cursor.execute(f'PRAGMA user_version = {version}')
def list_users(self) -> List[User]: def list_users(self) -> List[User]:
users: List[User] = [] users: List[User] = []
with self.transaction(exclusive=False) as c: with self.db.transaction(exclusive=False) as c:
for row in c.execute(''' for row in c.execute('''
SELECT user_id, username, email, is_admin, is_member SELECT user_id, username, email, is_admin, is_member
FROM users FROM users
@ -153,7 +42,7 @@ class Database(object):
member: bool = True) -> User: member: bool = True) -> User:
pwhash: str = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12)) pwhash: str = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12))
user_id: int = -1 user_id: int = -1
with self.transaction() as c: with self.db.transaction() as c:
c.execute('SELECT user_id FROM users WHERE username = ?', [username]) c.execute('SELECT user_id FROM users WHERE username = ?', [username])
if c.fetchone() is not None: if c.fetchone() is not None:
raise ValueError(f'A user with the name \'{username}\' already exists.') raise ValueError(f'A user with the name \'{username}\' already exists.')
@ -171,33 +60,35 @@ class Database(object):
user_id = int(c.fetchone()[0]) user_id = int(c.fetchone()[0])
return User(user_id, username, email, admin, member) return User(user_id, username, email, admin, member)
def login(self, username: str, password: str) -> Optional[User]: def login(self, username: str, password: Optional[str] = None, touchkey: Optional[str] = None) -> User:
with self.transaction(exclusive=False) as c: if (password is None) == (touchkey is None):
raise ValueError('Exactly one of password and touchkey must be provided')
with self.db.transaction(exclusive=False) as c:
c.execute(''' c.execute('''
SELECT user_id, username, email, password, is_admin, is_member SELECT user_id, username, email, password, touchkey, is_admin, is_member
FROM users FROM users
WHERE username = ? WHERE username = ?
''', [username]) ''', [username])
row = c.fetchone() row = c.fetchone()
if row is None: if row is None:
return None raise AuthenticationError('User does not exist')
user_id, username, email, pwhash, admin, member = row user_id, username, email, pwhash, tkhash, admin, member = row
if not bcrypt.checkpw(password.encode('utf-8'), pwhash): if password is not None and not bcrypt.checkpw(password.encode('utf-8'), pwhash):
return None raise AuthenticationError('Password mismatch')
elif touchkey is not None and not bcrypt.checkpw(touchkey.encode('utf-8'), tkhash):
raise AuthenticationError('Touchkey mismatch')
return User(user_id, username, email, admin, member) return User(user_id, username, email, admin, member)
def change_password(self, user: User, oldpass: str, newpass: str, newpass2: str, verify_password: bool = True): def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True):
if newpass != newpass2: with self.db.transaction() as c:
raise ValueError('New passwords don\'t match.')
with self.transaction() as c:
c.execute(''' c.execute('''
SELECT password FROM users WHERE user_id = ? SELECT password FROM users WHERE user_id = ?
''', [user.id]) ''', [user.id])
row = c.fetchone() row = c.fetchone()
if row is None: if row is None:
raise AuthenticationException('User does not exist in database.') raise AuthenticationError('User does not exist in database.')
if verify_password and not bcrypt.checkpw(oldpass.encode('utf-8'), row[0]): if verify_password and not bcrypt.checkpw(oldpass.encode('utf-8'), row[0]):
raise AuthenticationException('Old password does not match.') raise AuthenticationError('Old password does not match.')
pwhash: str = bcrypt.hashpw(newpass.encode('utf-8'), bcrypt.gensalt(12)) pwhash: str = bcrypt.hashpw(newpass.encode('utf-8'), bcrypt.gensalt(12))
c.execute(''' c.execute('''
UPDATE users SET password = :pwhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id UPDATE users SET password = :pwhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id
@ -207,15 +98,15 @@ class Database(object):
}) })
def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True): def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True):
with self.transaction() as c: with self.db.transaction() as c:
c.execute(''' c.execute('''
SELECT password FROM users WHERE user_id = ? SELECT password FROM users WHERE user_id = ?
''', [user.id]) ''', [user.id])
row = c.fetchone() row = c.fetchone()
if row is None: if row is None:
raise AuthenticationException('User does not exist in database.') raise AuthenticationError('User does not exist in database.')
if verify_password and not bcrypt.checkpw(password.encode('utf-8'), row[0]): if verify_password and not bcrypt.checkpw(password.encode('utf-8'), row[0]):
raise AuthenticationException('Password does not match.') raise AuthenticationError('Password does not match.')
tkhash: str = bcrypt.hashpw(touchkey.encode('utf-8'), bcrypt.gensalt(12)) if touchkey is not None else None tkhash: str = bcrypt.hashpw(touchkey.encode('utf-8'), bcrypt.gensalt(12)) if touchkey is not None else None
c.execute(''' c.execute('''
UPDATE users SET touchkey = :tkhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id UPDATE users SET touchkey = :tkhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id
@ -225,7 +116,7 @@ class Database(object):
}) })
def change_user(self, user: User): def change_user(self, user: User):
with self.transaction() as c: with self.db.transaction() as c:
c.execute(''' c.execute('''
UPDATE users SET UPDATE users SET
email = :email, email = :email,
@ -239,19 +130,27 @@ class Database(object):
'is_admin': user.is_admin, 'is_admin': user.is_admin,
'is_member': user.is_member 'is_member': user.is_member
}) })
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(
f'change_user should affect 1 users row, but affected {affected}')
def delete_user(self, user: User): def delete_user(self, user: User):
with self.transaction() as c: with self.db.transaction() as c:
c.execute(''' c.execute('''
DELETE FROM users DELETE FROM users
WHERE user_id = ? WHERE user_id = ?
''', [user.id]) ''', [user.id])
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(
f'delete_user should affect 1 users row, but affected {affected}')
def list_products(self) -> List[Product]: def list_products(self) -> List[Product]:
products: List[Product] = [] products: List[Product] = []
with self.transaction(exclusive=False) as c: with self.db.transaction(exclusive=False) as c:
for row in c.execute(''' for row in c.execute('''
SELECT product_id, name, price_member, price_external SELECT product_id, name, price_member, price_non_member
FROM products FROM products
'''): '''):
product_id, name, price_member, price_external = row product_id, name, price_member, price_external = row
@ -260,7 +159,7 @@ class Database(object):
def create_product(self, name: str, price_member: int, price_non_member: int) -> Product: def create_product(self, name: str, price_member: int, price_non_member: int) -> Product:
product_id: int = -1 product_id: int = -1
with self.transaction() as c: with self.db.transaction() as c:
c.execute('SELECT product_id FROM products WHERE name = ?', [name]) c.execute('SELECT product_id FROM products WHERE name = ?', [name])
if c.fetchone() is not None: if c.fetchone() is not None:
raise ValueError(f'A product with the name \'{name}\' already exists.') raise ValueError(f'A product with the name \'{name}\' already exists.')
@ -277,41 +176,47 @@ class Database(object):
return Product(product_id, name, price_member, price_non_member) return Product(product_id, name, price_member, price_non_member)
def change_product(self, product: Product): def change_product(self, product: Product):
if product.id == -1: with self.db.transaction() as c:
raise ValueError('Invalid product ID')
with self.transaction() as c:
c.execute(''' c.execute('''
UPDATE products UPDATE products
SET SET
name = :name, name = :name,
price_member = :price_member, price_member = :price_member,
price_non_member = :price_non_member price_non_member = :price_non_member
WHERE product_id = :product_is WHERE product_id = :product_id
''', { ''', {
'product_id': product.id, 'product_id': product.id,
'name': product.name,
'price_member': product.price_member, 'price_member': product.price_member,
'price_non_member': product.price_non_member 'price_non_member': product.price_non_member
}) })
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(
f'change_product should affect 1 products row, but affected {affected}')
def delete_product(self, product: Product): def delete_product(self, product: Product):
if product.id == -1: with self.db.transaction() as c:
raise ValueError('Invalid product ID')
with self.transaction() as c:
c.execute(''' c.execute('''
DELETE FROM products DELETE FROM products
WHERE product_id = ? WHERE product_id = ?
''', [product.id]) ''', [product.id])
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(
f'delete_product should affect 1 products row, but affected {affected}')
def increment_consumption(self, user: User, product: Product, count: int = 1): def increment_consumption(self, user: User, product: Product, count: int = 1):
if product.id == -1: with self.db.transaction() as c:
raise ValueError('Invalid product ID')
with self.transaction() as c:
c.execute(''' c.execute('''
SELECT count SELECT count
FROM consumption FROM consumption
WHERE user_id = :user_id WHERE user_id = :user_id
AND product_id = :product_id AND product_id = :product_id
''') ''', {
'user_id': user.id,
'product_id': product.id
})
row = c.fetchone() row = c.fetchone()
if row is None: if row is None:
c.execute(''' c.execute('''
@ -332,6 +237,10 @@ class Database(object):
'product_id': product.id, 'product_id': product.id,
'count': count 'count': count
}) })
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(
f'increment_consumption should affect 1 consumption row, but affected {affected}')
c.execute(''' c.execute('''
UPDATE users UPDATE users
SET balance = balance - :cost SET balance = balance - :cost
@ -339,6 +248,10 @@ class Database(object):
'user_id': user.id, 'user_id': user.id,
'cost': count * product.price_member if user.is_member else count * product.price_non_member 'cost': count * product.price_member if user.is_member else count * product.price_non_member
}) })
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(
f'increment_consumption should affect 1 users row, but affected {affected}')
c.execute(''' c.execute('''
UPDATE products UPDATE products
SET stock = stock - :count SET stock = stock - :count
@ -347,11 +260,13 @@ class Database(object):
'product_id': product.id, 'product_id': product.id,
'count': count 'count': count
}) })
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(
f'increment_consumption should affect 1 products row, but affected {affected}')
def restock(self, product: Product, count: int): def restock(self, product: Product, count: int):
if product.id == -1: with self.db.transaction() as c:
raise ValueError('Invalid product ID')
with self.transaction() as c:
c.execute(''' c.execute('''
UPDATE products UPDATE products
SET stock = stock + :count SET stock = stock + :count
@ -360,9 +275,14 @@ class Database(object):
'product_id': product.id, 'product_id': product.id,
'count': count 'count': count
}) })
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}')
def deposit(self, user: User, amount: int): def deposit(self, user: User, amount: int):
with self.transaction() as c: if amount < 0:
raise ValueError('Cannot deposit a negative value')
with self.db.transaction() as c:
c.execute(''' c.execute('''
UPDATE users UPDATE users
SET balance = balance + :amount SET balance = balance + :amount
@ -371,3 +291,6 @@ class Database(object):
'user_id': user.id, 'user_id': user.id,
'amount': amount 'amount': amount
}) })
affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1:
raise DatabaseConsistencyError(f'deposit should affect 1 users row, but affected {affected}')

View file

@ -0,0 +1,313 @@
import unittest
import bcrypt
from matemat.db import Database
from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
class DatabaseTest(unittest.TestCase):
def setUp(self):
self.db = Database(':memory:')
def test_create_user(self):
with self.db as db:
with db.transaction(exclusive=False) as c:
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
c.execute("SELECT * FROM users")
row = c.fetchone()
self.assertEqual('testuser', row[1])
self.assertEqual('testuser@example.com', row[2])
self.assertEqual(0, row[5])
self.assertEqual(1, row[6])
with self.assertRaises(ValueError):
db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com')
def test_list_users(self):
with self.db as db:
users = db.list_users()
self.assertEqual(0, len(users))
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
db.create_user('anothertestuser', 'otherpassword', 'anothertestuser@example.com', False, True)
db.create_user('yatu', 'igotapasswordtoo', 'yatu@example.com', False, False)
users = db.list_users()
self.assertEqual(3, len(users))
usercheck = {}
for user in users:
if user.name == 'testuser':
self.assertEqual('testuser@example.com', user.email)
self.assertTrue(user.is_member)
self.assertTrue(user.is_admin)
elif user.name == 'anothertestuser':
self.assertEqual('anothertestuser@example.com', user.email)
self.assertTrue(user.is_member)
self.assertFalse(user.is_admin)
elif user.name == 'yatu':
self.assertEqual('yatu@example.com', user.email)
self.assertFalse(user.is_member)
self.assertFalse(user.is_admin)
usercheck[user.id] = 1
self.assertEqual(3, len(usercheck))
def test_login(self):
with self.db as db:
with db.transaction() as c:
u = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
# Add a touchkey without using the provided function
c.execute('''UPDATE users SET touchkey = :tkhash WHERE user_id = :user_id''', {
'tkhash': bcrypt.hashpw(b'0123', bcrypt.gensalt(12)),
'user_id': u.id
})
user = db.login('testuser', 'supersecurepassword')
self.assertEqual(u.id, user.id)
user = db.login('testuser', touchkey='0123')
self.assertEqual(u.id, user.id)
with self.assertRaises(AuthenticationError):
# Inexistent user should fail
db.login('nooone', 'supersecurepassword')
with self.assertRaises(AuthenticationError):
# Wrong password should fail
db.login('testuser', 'anothersecurepassword')
with self.assertRaises(AuthenticationError):
# Wrong touchkey should fail
db.login('testuser', touchkey='0124')
with self.assertRaises(ValueError):
# No password or touchkey should fail
db.login('testuser')
with self.assertRaises(ValueError):
# Both password and touchkey should fail
db.login('testuser', password='supersecurepassword', touchkey='0123')
def test_change_password(self):
with self.db as db:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
db.login('testuser', 'supersecurepassword')
# Normal password change should succeed
db.change_password(user, 'supersecurepassword', 'mynewpassword')
with self.assertRaises(AuthenticationError):
db.login('testuser', 'supersecurepassword')
db.login('testuser', 'mynewpassword')
with self.assertRaises(AuthenticationError):
# Should fail due to wrong old password
db.change_password(user, 'iforgotmypassword', 'mynewpassword')
db.login('testuser', 'mynewpassword')
# This should pass even though the old password is not known (admin password reset)
db.change_password(user, '42', 'adminpasswordreset', verify_password=False)
with self.assertRaises(AuthenticationError):
db.login('testuser', 'mynewpassword')
db.login('testuser', 'adminpasswordreset')
user._user_id = -1
with self.assertRaises(AuthenticationError):
# Password change for an inexistent user should fail
db.change_password(user, 'adminpasswordreset', 'passwordwithoutuser')
def test_change_touchkey(self):
with self.db as db:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
db.change_touchkey(user, 'supersecurepassword', '0123')
db.login('testuser', touchkey='0123')
# Normal touchkey change should succeed
db.change_touchkey(user, 'supersecurepassword', touchkey='4567')
with self.assertRaises(AuthenticationError):
db.login('testuser', touchkey='0123')
db.login('testuser', touchkey='4567')
with self.assertRaises(AuthenticationError):
# Should fail due to wrong old password
db.change_touchkey(user, 'iforgotmypassword', '89ab')
db.login('testuser', touchkey='4567')
# This should pass even though the old password is not known (admin password reset)
db.change_touchkey(user, '42', '89ab', verify_password=False)
with self.assertRaises(AuthenticationError):
db.login('testuser', touchkey='4567')
db.login('testuser', touchkey='89ab')
user._user_id = -1
with self.assertRaises(AuthenticationError):
# Touchkey change for an inexistent user should fail
db.change_touchkey(user, '89ab', '048c')
def test_change_user(self):
with self.db as db:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
user.email = 'newaddress@example.com'
user.is_admin = False
user.is_member = False
db.change_user(user)
checkuser = db.login('testuser', 'supersecurepassword')
self.assertEqual('newaddress@example.com', checkuser.email)
self.assertFalse(checkuser.is_admin)
self.assertFalse(checkuser.is_member)
user._user_id = -1
with self.assertRaises(DatabaseConsistencyError):
db.change_user(user)
def test_delete_user(self):
with self.db as db:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
db.login('testuser', 'supersecurepassword')
db.delete_user(user)
try:
# Should fail, as the user does not exist anymore
db.login('testuser', 'supersecurepassword')
except AuthenticationError as e:
self.assertEqual('User does not exist', e.msg)
with self.assertRaises(DatabaseConsistencyError):
# Should fail, as the user does not exist anymore
db.delete_user(user)
def test_create_product(self):
with self.db as db:
with db.transaction() as c:
db.create_product('Club Mate', 200, 200)
c.execute("SELECT * FROM products")
row = c.fetchone()
self.assertEqual('Club Mate', row[1])
self.assertEqual(0, row[2])
self.assertEqual(200, row[3])
self.assertEqual(200, row[4])
with self.assertRaises(ValueError):
db.create_product('Club Mate', 250, 250)
def test_list_products(self):
with self.db as db:
# Test empty list
products = db.list_products()
self.assertEqual(0, len(products))
db.create_product('Club Mate', 200, 200)
db.create_product('Flora Power Mate', 200, 200)
db.create_product('Fritz Mate', 200, 250)
products = db.list_products()
self.assertEqual(3, len(products))
productcheck = {}
for product in products:
if product.name == 'Club Mate':
self.assertEqual(200, product.price_member)
self.assertEqual(200, product.price_non_member)
elif product.name == 'Flora Power Mate':
self.assertEqual(200, product.price_member)
self.assertEqual(200, product.price_non_member)
elif product.name == 'Fritz Mate':
self.assertEqual(200, product.price_member)
self.assertEqual(250, product.price_non_member)
productcheck[product.id] = 1
self.assertEqual(3, len(productcheck))
def test_change_product(self):
with self.db as db:
product = db.create_product('Club Mate', 200, 200)
product.name = 'Flora Power Mate'
product.price_member = 150
product.price_non_member = 250
db.change_product(product)
checkproduct = db.list_products()[0]
self.assertEqual('Flora Power Mate', checkproduct.name)
self.assertEqual(150, checkproduct.price_member)
self.assertEqual(250, checkproduct.price_non_member)
product._product_id = -1
with self.assertRaises(DatabaseConsistencyError):
db.change_product(product)
product2 = db.create_product('Club Mate', 200, 200)
product2.name = 'Flora Power Mate'
with self.assertRaises(DatabaseConsistencyError):
# Should fail, as a product with the same name already exists.
db.change_product(product2)
def test_delete_product(self):
with self.db as db:
product = db.create_product('Club Mate', 200, 200)
product2 = db.create_product('Flora Power Mate', 200, 200)
self.assertEqual(2, len(db.list_products()))
db.delete_product(product)
# Only the other product should remain in the DB
products = db.list_products()
self.assertEqual(1, len(products))
self.assertEqual(product2, products[0])
with self.assertRaises(DatabaseConsistencyError):
# Should fail, as the product does not exist anymore
db.delete_product(product)
def test_deposit(self):
with self.db as db:
with db.transaction() as c:
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
user2 = db.create_user('testuser2', 'supersecurepassword', 'testuser@example.com', True, True)
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user.id])
self.assertEqual(0, c.fetchone()[0])
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id])
self.assertEqual(0, c.fetchone()[0])
db.deposit(user, 1337)
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user.id])
self.assertEqual(1337, c.fetchone()[0])
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id])
self.assertEqual(0, c.fetchone()[0])
with self.assertRaises(ValueError):
# Should fail, negative amount
db.deposit(user, -42)
user._user_id = -1
with self.assertRaises(DatabaseConsistencyError):
# Should fail, user id -1 does not exist
db.deposit(user, 42)
def test_restock(self):
with self.db as db:
with db.transaction() as c:
product = db.create_product('Club Mate', 200, 200)
product2 = db.create_product('Flora Power Mate', 200, 200)
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product.id])
self.assertEqual(0, c.fetchone()[0])
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product2.id])
self.assertEqual(0, c.fetchone()[0])
db.restock(product, 42)
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product.id])
self.assertEqual(42, c.fetchone()[0])
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product2.id])
self.assertEqual(0, c.fetchone()[0])
product._product_id = -1
with self.assertRaises(DatabaseConsistencyError):
# Should fail, product id -1 does not exist
db.restock(product, 42)
def test_consumption(self):
with self.db as db:
# Set up test case
user1 = db.create_user('user1', 'supersecurepassword', 'testuser@example.com', member=True)
user2 = db.create_user('user2', 'supersecurepassword', 'testuser@example.com', member=False)
user3 = db.create_user('user3', 'supersecurepassword', 'testuser@example.com', member=False)
db.deposit(user1, 1337)
db.deposit(user2, 4242)
db.deposit(user3, 1234)
clubmate = db.create_product('Club Mate', 200, 200)
florapowermate = db.create_product('Flora Power Mate', 150, 250)
fritzmate = db.create_product('Fritz Mate', 200, 200)
db.restock(clubmate, 50)
db.restock(florapowermate, 70)
db.restock(fritzmate, 10)
# user1 is somewhat addicted to caffeine
db.increment_consumption(user1, clubmate, 1)
db.increment_consumption(user1, clubmate, 2)
db.increment_consumption(user1, florapowermate, 3)
# user2 is reeeally addicted
db.increment_consumption(user2, clubmate, 7)
db.increment_consumption(user2, florapowermate, 3)
db.increment_consumption(user2, florapowermate, 4)
with db.transaction(exclusive=False) as c:
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user1.id])
self.assertEqual(1337 - 200 * 3 - 150 * 3, c.fetchone()[0])
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id])
self.assertEqual(4242 - 200 * 7 - 250 * 7, c.fetchone()[0])
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user3.id])
self.assertEqual(1234, c.fetchone()[0])
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [clubmate.id])
self.assertEqual(40, c.fetchone()[0])
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [florapowermate.id])
self.assertEqual(60, c.fetchone()[0])
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [fritzmate.id])
self.assertEqual(10, c.fetchone()[0])

View file

@ -1,20 +1,20 @@
import unittest import unittest
from matemat.db import Database from matemat.db import DatabaseWrapper
class DatabaseTest(unittest.TestCase): class DatabaseTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.db = Database(':memory:') self.db = DatabaseWrapper(':memory:')
def test_create_schema(self): def test_create_schema(self):
""" """
Test creation of database schema in an empty database Test creation of database schema in an empty database
""" """
with self.db as db: with self.db as db:
self.assertEqual(Database.SCHEMA_VERSION, db._user_version) self.assertEqual(DatabaseWrapper.SCHEMA_VERSION, db._user_version)
def test_in_transaction(self): def test_in_transaction(self):
""" """
@ -74,22 +74,3 @@ class DatabaseTest(unittest.TestCase):
c = db._sqlite_db.cursor() c = db._sqlite_db.cursor()
c.execute("SELECT * FROM users") c.execute("SELECT * FROM users")
self.assertIsNone(c.fetchone()) self.assertIsNone(c.fetchone())
def test_create_user(self):
with self.db as db:
with db.transaction() as c:
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
c.execute("SELECT * FROM users")
row = c.fetchone()
if row is None:
self.fail()
self.assertEqual('testuser', row[1])
self.assertEqual('testuser@example.com', row[2])
self.assertEqual(0, row[5])
self.assertEqual(1, row[6])
def test_create_existing_user(self):
with self.db as db:
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
with self.assertRaises(ValueError):
db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com')

133
matemat/db/wrapper.py Normal file
View file

@ -0,0 +1,133 @@
import apsw
from matemat.exceptions import DatabaseConsistencyError
class Transaction(object):
def __init__(self, db: apsw.Connection, wrapper: 'DatabaseWrapper', exclusive: bool = True):
self._db: apsw.Connection = db
self._cursor = None
self._excl = exclusive
self._wrapper: DatabaseWrapper = wrapper
self._is_dummy: bool = False
def __enter__(self):
if self._wrapper._in_transaction:
self._is_dummy = True
return self._db.cursor()
else:
self._is_dummy = False
self._cursor = self._db.cursor()
self._wrapper._in_transaction = True
if self._excl:
self._cursor.execute('BEGIN EXCLUSIVE')
else:
self._cursor.execute('BEGIN')
return self._cursor
def __exit__(self, exc_type, exc_val, exc_tb):
if self._is_dummy:
return
if exc_type is None:
self._cursor.execute('COMMIT')
else:
self._cursor.execute('ROLLBACK')
self._wrapper._in_transaction = False
if exc_type == apsw.ConstraintError:
raise DatabaseConsistencyError(str(exc_val))
class DatabaseWrapper(object):
SCHEMA_VERSION = 1
SCHEMA = '''
CREATE TABLE users (
user_id INTEGER PRIMARY KEY,
username TEXT NOT NULL,
email TEXT DEFAULT NULL,
password TEXT NOT NULL,
touchkey TEXT DEFAULT NULL,
is_admin INTEGER(1) NOT NULL DEFAULT 0,
is_member INTEGER(1) NOT NULL DEFAULT 1,
balance INTEGER(8) NOT NULL DEFAULT 0,
lastchange INTEGER(8) NOT NULL DEFAULT 0
);
CREATE TABLE products (
product_id INTEGER PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
stock INTEGER(8) NOT NULL DEFAULT 0,
price_member INTEGER(8) NOT NULL,
price_non_member INTEGER(8) NOT NULL
);
CREATE TABLE consumption (
user_id INTEGER NOT NULL,
product_id INTEGER NOT NULL,
count INTEGER(8) NOT NULL DEFAULT 0,
PRIMARY KEY (user_id, product_id),
FOREIGN KEY (user_id) REFERENCES users(user_id)
ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (product_id) REFERENCES products(product_id)
ON DELETE CASCADE ON UPDATE CASCADE
)
'''
def __init__(self, filename: str):
self._filename: str = filename
self._sqlite_db: apsw.Connection = None
self._in_transaction: bool = False
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def transaction(self, exclusive: bool = True) -> Transaction:
return Transaction(self._sqlite_db, self, exclusive)
def _setup(self):
with self.transaction() as c:
version: int = self._user_version
if version < 1:
c.execute(self.SCHEMA)
elif version < self.SCHEMA_VERSION:
self._upgrade(old=version, new=self.SCHEMA_VERSION)
self._user_version = self.SCHEMA_VERSION
def _upgrade(self, old: int, new: int):
pass
def connect(self):
if self.is_connected():
raise RuntimeError(f'Database connection to {self._filename} is already established.')
self._sqlite_db = apsw.Connection(self._filename)
self._setup()
def close(self):
if self._sqlite_db is None:
raise RuntimeError(f'Database connection to {self._filename} is not established.')
if self.in_transaction():
raise RuntimeError(f'A transaction is still ongoing.')
self._sqlite_db.close()
self._sqlite_db = None
def in_transaction(self) -> bool:
return self._in_transaction
def is_connected(self) -> bool:
return self._sqlite_db is not None
@property
def _user_version(self) -> int:
cursor = self._sqlite_db.cursor()
cursor.execute('PRAGMA user_version')
version = int(cursor.fetchone()[0])
return version
@_user_version.setter
def _user_version(self, version: int):
cursor = self._sqlite_db.cursor()
cursor.execute(f'PRAGMA user_version = {version}')

View file

@ -1,11 +1,11 @@
class AuthenticationException(BaseException): class AuthenticationError(BaseException):
def __init__(self, msg: str = None): def __init__(self, msg: str = None):
self._msg = msg self._msg = msg
def __str__(self) -> str: def __str__(self) -> str:
return f'AuthenticationException: {self._msg}' return f'AuthenticationError: {self._msg}'
@property @property
def msg(self) -> str: def msg(self) -> str:

View file

@ -0,0 +1,12 @@
class DatabaseConsistencyError(BaseException):
def __init__(self, msg: str = None):
self._msg = msg
def __str__(self) -> str:
return f'DatabaseConsistencyError: {self._msg}'
@property
def msg(self) -> str:
return self._msg

View file

@ -1,2 +1,3 @@
from .AuthenticatonException import AuthenticationException from .AuthenticatonError import AuthenticationError
from .DatabaseConsistencyError import DatabaseConsistencyError

View file

@ -1,4 +1,3 @@
class Product(object): class Product(object):
def __init__(self, def __init__(self,
@ -11,6 +10,14 @@ class Product(object):
self._price_member: int = price_member self._price_member: int = price_member
self._price_non_member: int = price_non_member self._price_non_member: int = price_non_member
def __eq__(self, other):
if other is None or not isinstance(other, Product):
return False
return self._product_id == other._product_id \
and self._name == other._name \
and self._price_member == other._price_member \
and self._price_non_member == other._price_non_member
@property @property
def id(self) -> int: def id(self) -> int:
return self._product_id return self._product_id
@ -19,6 +26,11 @@ class Product(object):
def name(self) -> str: def name(self) -> str:
return self._name return self._name
@name.setter
def name(self, name: str):
self._name = name
@property @property
def price_member(self) -> int: def price_member(self) -> int:
return self._price_member return self._price_member

View file

@ -10,14 +10,21 @@ class User(object):
email: Optional[str] = None, email: Optional[str] = None,
admin: bool = False, admin: bool = False,
member: bool = True): member: bool = True):
if user_id == -1:
raise ValueError('Invalid user ID')
self._user_id: int = user_id self._user_id: int = user_id
self._username: str = username self._username: str = username
self._email: Optional[str] = email self._email: Optional[str] = email
self._admin: bool = admin self._admin: bool = admin
self._member: bool = member self._member: bool = member
def __eq__(self, other):
if other is None or not isinstance(other, User):
return False
return self._user_id == other._user_id \
and self._username == other._username \
and self._email == other._email \
and self._admin == other._admin \
and self._member == other._member
@property @property
def id(self) -> int: def id(self) -> int:
return self._user_id return self._user_id