forked from s3lph/matemat
Merge branch 'db-tests' into 'master'
Additional unit tests for the Database Facade See merge request s3lph/matemat!1
This commit is contained in:
commit
8059765410
11 changed files with 570 additions and 187 deletions
|
@ -8,4 +8,4 @@ test:
|
||||||
- pip3 install wheel
|
- pip3 install wheel
|
||||||
- pip3 install -r requirements.txt
|
- pip3 install -r requirements.txt
|
||||||
- python3-coverage run --branch -m unittest discover matemat
|
- python3-coverage run --branch -m unittest discover matemat
|
||||||
- python3-coverage report -m
|
- python3-coverage report -m --include 'matemat/*' --omit '*/test_*.py'
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
|
|
||||||
from .database import Database
|
from .wrapper import DatabaseWrapper
|
||||||
|
from .facade import DatabaseFacade as Database
|
||||||
|
|
|
@ -1,142 +1,31 @@
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import apsw
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
from matemat.primitives import User, Product
|
from matemat.primitives import User, Product
|
||||||
from matemat.exceptions import AuthenticationException
|
from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
|
||||||
|
from matemat.db import DatabaseWrapper
|
||||||
|
|
||||||
|
|
||||||
class Transaction(object):
|
class DatabaseFacade(object):
|
||||||
|
|
||||||
def __init__(self, db: apsw.Connection, wrapper: 'Database', exclusive: bool = True):
|
|
||||||
self._db: apsw.Connection = db
|
|
||||||
self._cursor = None
|
|
||||||
self._excl = exclusive
|
|
||||||
self._wrapper: Database = wrapper
|
|
||||||
self._is_dummy: bool = False
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
if self._wrapper._in_transaction:
|
|
||||||
self._is_dummy = True
|
|
||||||
return self._db.cursor()
|
|
||||||
else:
|
|
||||||
self._is_dummy = False
|
|
||||||
self._cursor = self._db.cursor()
|
|
||||||
self._wrapper._in_transaction = True
|
|
||||||
if self._excl:
|
|
||||||
self._cursor.execute('BEGIN EXCLUSIVE')
|
|
||||||
else:
|
|
||||||
self._cursor.execute('BEGIN')
|
|
||||||
return self._cursor
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
if self._is_dummy:
|
|
||||||
return
|
|
||||||
if exc_type is None:
|
|
||||||
self._cursor.execute('COMMIT')
|
|
||||||
else:
|
|
||||||
self._cursor.execute('ROLLBACK')
|
|
||||||
self._wrapper._in_transaction = False
|
|
||||||
|
|
||||||
|
|
||||||
class Database(object):
|
|
||||||
SCHEMA_VERSION = 1
|
|
||||||
|
|
||||||
SCHEMA = '''
|
|
||||||
CREATE TABLE users (
|
|
||||||
user_id INTEGER PRIMARY KEY,
|
|
||||||
username TEXT NOT NULL,
|
|
||||||
email TEXT DEFAULT NULL,
|
|
||||||
password TEXT NOT NULL,
|
|
||||||
touchkey TEXT DEFAULT NULL,
|
|
||||||
is_admin INTEGER(1) NOT NULL DEFAULT 0,
|
|
||||||
is_member INTEGER(1) NOT NULL DEFAULT 1,
|
|
||||||
balance INTEGER(8) NOT NULL DEFAULT 0,
|
|
||||||
lastchange INTEGER(8) NOT NULL DEFAULT 0
|
|
||||||
);
|
|
||||||
CREATE TABLE products (
|
|
||||||
product_id INTEGER PRIMARY KEY,
|
|
||||||
name TEXT UNIQUE NOT NULL,
|
|
||||||
stock INTEGER(8) NOT NULL DEFAULT 0,
|
|
||||||
price_member INTEGER(8) NOT NULL,
|
|
||||||
price_non_member INTEGER(8) NOT NULL
|
|
||||||
);
|
|
||||||
CREATE TABLE consumption (
|
|
||||||
user_id INTEGER NOT NULL,
|
|
||||||
product_id INTEGER NOT NULL,
|
|
||||||
count INTEGER(8) NOT NULL DEFAULT 0,
|
|
||||||
PRIMARY KEY (user_id, product_id),
|
|
||||||
FOREIGN KEY (user_id) REFERENCES users(user_id)
|
|
||||||
ON DELETE CASCADE ON UPDATE CASCADE,
|
|
||||||
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
|
||||||
ON DELETE CASCADE ON UPDATE CASCADE
|
|
||||||
)
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, filename: str):
|
def __init__(self, filename: str):
|
||||||
self._filename: str = filename
|
self.db: DatabaseWrapper = DatabaseWrapper(filename)
|
||||||
self._sqlite_db: apsw.Connection = None
|
|
||||||
self._in_transaction: bool = False
|
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.connect()
|
self.db.__enter__()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.close()
|
self.db.__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
def transaction(self, exclusive: bool = True) -> Transaction:
|
def transaction(self, exclusive: bool = True):
|
||||||
return Transaction(self._sqlite_db, self, exclusive)
|
return self.db.transaction(exclusive=exclusive)
|
||||||
|
|
||||||
def _setup(self):
|
|
||||||
with self.transaction() as c:
|
|
||||||
version: int = self._user_version
|
|
||||||
if version < 1:
|
|
||||||
c.execute(self.SCHEMA)
|
|
||||||
elif version < self.SCHEMA_VERSION:
|
|
||||||
self._upgrade(old=version, new=self.SCHEMA_VERSION)
|
|
||||||
self._user_version = self.SCHEMA_VERSION
|
|
||||||
|
|
||||||
def _upgrade(self, old: int, new: int):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
if self.is_connected():
|
|
||||||
raise RuntimeError(f'Database connection to {self._filename} is already established.')
|
|
||||||
self._sqlite_db = apsw.Connection(self._filename)
|
|
||||||
self._setup()
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
if self._sqlite_db is None:
|
|
||||||
raise RuntimeError(f'Database connection to {self._filename} is not established.')
|
|
||||||
if self.in_transaction():
|
|
||||||
raise RuntimeError(f'A transaction is still ongoing.')
|
|
||||||
self._sqlite_db.close()
|
|
||||||
self._sqlite_db = None
|
|
||||||
|
|
||||||
def in_transaction(self) -> bool:
|
|
||||||
return self._in_transaction
|
|
||||||
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self._sqlite_db is not None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _user_version(self) -> int:
|
|
||||||
cursor = self._sqlite_db.cursor()
|
|
||||||
cursor.execute('PRAGMA user_version')
|
|
||||||
version = int(cursor.fetchone()[0])
|
|
||||||
return version
|
|
||||||
|
|
||||||
@_user_version.setter
|
|
||||||
def _user_version(self, version: int):
|
|
||||||
cursor = self._sqlite_db.cursor()
|
|
||||||
cursor.execute(f'PRAGMA user_version = {version}')
|
|
||||||
|
|
||||||
def list_users(self) -> List[User]:
|
def list_users(self) -> List[User]:
|
||||||
users: List[User] = []
|
users: List[User] = []
|
||||||
with self.transaction(exclusive=False) as c:
|
with self.db.transaction(exclusive=False) as c:
|
||||||
for row in c.execute('''
|
for row in c.execute('''
|
||||||
SELECT user_id, username, email, is_admin, is_member
|
SELECT user_id, username, email, is_admin, is_member
|
||||||
FROM users
|
FROM users
|
||||||
|
@ -153,7 +42,7 @@ class Database(object):
|
||||||
member: bool = True) -> User:
|
member: bool = True) -> User:
|
||||||
pwhash: str = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12))
|
pwhash: str = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12))
|
||||||
user_id: int = -1
|
user_id: int = -1
|
||||||
with self.transaction() as c:
|
with self.db.transaction() as c:
|
||||||
c.execute('SELECT user_id FROM users WHERE username = ?', [username])
|
c.execute('SELECT user_id FROM users WHERE username = ?', [username])
|
||||||
if c.fetchone() is not None:
|
if c.fetchone() is not None:
|
||||||
raise ValueError(f'A user with the name \'{username}\' already exists.')
|
raise ValueError(f'A user with the name \'{username}\' already exists.')
|
||||||
|
@ -171,33 +60,35 @@ class Database(object):
|
||||||
user_id = int(c.fetchone()[0])
|
user_id = int(c.fetchone()[0])
|
||||||
return User(user_id, username, email, admin, member)
|
return User(user_id, username, email, admin, member)
|
||||||
|
|
||||||
def login(self, username: str, password: str) -> Optional[User]:
|
def login(self, username: str, password: Optional[str] = None, touchkey: Optional[str] = None) -> User:
|
||||||
with self.transaction(exclusive=False) as c:
|
if (password is None) == (touchkey is None):
|
||||||
|
raise ValueError('Exactly one of password and touchkey must be provided')
|
||||||
|
with self.db.transaction(exclusive=False) as c:
|
||||||
c.execute('''
|
c.execute('''
|
||||||
SELECT user_id, username, email, password, is_admin, is_member
|
SELECT user_id, username, email, password, touchkey, is_admin, is_member
|
||||||
FROM users
|
FROM users
|
||||||
WHERE username = ?
|
WHERE username = ?
|
||||||
''', [username])
|
''', [username])
|
||||||
row = c.fetchone()
|
row = c.fetchone()
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
raise AuthenticationError('User does not exist')
|
||||||
user_id, username, email, pwhash, admin, member = row
|
user_id, username, email, pwhash, tkhash, admin, member = row
|
||||||
if not bcrypt.checkpw(password.encode('utf-8'), pwhash):
|
if password is not None and not bcrypt.checkpw(password.encode('utf-8'), pwhash):
|
||||||
return None
|
raise AuthenticationError('Password mismatch')
|
||||||
|
elif touchkey is not None and not bcrypt.checkpw(touchkey.encode('utf-8'), tkhash):
|
||||||
|
raise AuthenticationError('Touchkey mismatch')
|
||||||
return User(user_id, username, email, admin, member)
|
return User(user_id, username, email, admin, member)
|
||||||
|
|
||||||
def change_password(self, user: User, oldpass: str, newpass: str, newpass2: str, verify_password: bool = True):
|
def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True):
|
||||||
if newpass != newpass2:
|
with self.db.transaction() as c:
|
||||||
raise ValueError('New passwords don\'t match.')
|
|
||||||
with self.transaction() as c:
|
|
||||||
c.execute('''
|
c.execute('''
|
||||||
SELECT password FROM users WHERE user_id = ?
|
SELECT password FROM users WHERE user_id = ?
|
||||||
''', [user.id])
|
''', [user.id])
|
||||||
row = c.fetchone()
|
row = c.fetchone()
|
||||||
if row is None:
|
if row is None:
|
||||||
raise AuthenticationException('User does not exist in database.')
|
raise AuthenticationError('User does not exist in database.')
|
||||||
if verify_password and not bcrypt.checkpw(oldpass.encode('utf-8'), row[0]):
|
if verify_password and not bcrypt.checkpw(oldpass.encode('utf-8'), row[0]):
|
||||||
raise AuthenticationException('Old password does not match.')
|
raise AuthenticationError('Old password does not match.')
|
||||||
pwhash: str = bcrypt.hashpw(newpass.encode('utf-8'), bcrypt.gensalt(12))
|
pwhash: str = bcrypt.hashpw(newpass.encode('utf-8'), bcrypt.gensalt(12))
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE users SET password = :pwhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id
|
UPDATE users SET password = :pwhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id
|
||||||
|
@ -207,15 +98,15 @@ class Database(object):
|
||||||
})
|
})
|
||||||
|
|
||||||
def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True):
|
def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True):
|
||||||
with self.transaction() as c:
|
with self.db.transaction() as c:
|
||||||
c.execute('''
|
c.execute('''
|
||||||
SELECT password FROM users WHERE user_id = ?
|
SELECT password FROM users WHERE user_id = ?
|
||||||
''', [user.id])
|
''', [user.id])
|
||||||
row = c.fetchone()
|
row = c.fetchone()
|
||||||
if row is None:
|
if row is None:
|
||||||
raise AuthenticationException('User does not exist in database.')
|
raise AuthenticationError('User does not exist in database.')
|
||||||
if verify_password and not bcrypt.checkpw(password.encode('utf-8'), row[0]):
|
if verify_password and not bcrypt.checkpw(password.encode('utf-8'), row[0]):
|
||||||
raise AuthenticationException('Password does not match.')
|
raise AuthenticationError('Password does not match.')
|
||||||
tkhash: str = bcrypt.hashpw(touchkey.encode('utf-8'), bcrypt.gensalt(12)) if touchkey is not None else None
|
tkhash: str = bcrypt.hashpw(touchkey.encode('utf-8'), bcrypt.gensalt(12)) if touchkey is not None else None
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE users SET touchkey = :tkhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id
|
UPDATE users SET touchkey = :tkhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id
|
||||||
|
@ -225,7 +116,7 @@ class Database(object):
|
||||||
})
|
})
|
||||||
|
|
||||||
def change_user(self, user: User):
|
def change_user(self, user: User):
|
||||||
with self.transaction() as c:
|
with self.db.transaction() as c:
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE users SET
|
UPDATE users SET
|
||||||
email = :email,
|
email = :email,
|
||||||
|
@ -239,19 +130,27 @@ class Database(object):
|
||||||
'is_admin': user.is_admin,
|
'is_admin': user.is_admin,
|
||||||
'is_member': user.is_member
|
'is_member': user.is_member
|
||||||
})
|
})
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(
|
||||||
|
f'change_user should affect 1 users row, but affected {affected}')
|
||||||
|
|
||||||
def delete_user(self, user: User):
|
def delete_user(self, user: User):
|
||||||
with self.transaction() as c:
|
with self.db.transaction() as c:
|
||||||
c.execute('''
|
c.execute('''
|
||||||
DELETE FROM users
|
DELETE FROM users
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
''', [user.id])
|
''', [user.id])
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(
|
||||||
|
f'delete_user should affect 1 users row, but affected {affected}')
|
||||||
|
|
||||||
def list_products(self) -> List[Product]:
|
def list_products(self) -> List[Product]:
|
||||||
products: List[Product] = []
|
products: List[Product] = []
|
||||||
with self.transaction(exclusive=False) as c:
|
with self.db.transaction(exclusive=False) as c:
|
||||||
for row in c.execute('''
|
for row in c.execute('''
|
||||||
SELECT product_id, name, price_member, price_external
|
SELECT product_id, name, price_member, price_non_member
|
||||||
FROM products
|
FROM products
|
||||||
'''):
|
'''):
|
||||||
product_id, name, price_member, price_external = row
|
product_id, name, price_member, price_external = row
|
||||||
|
@ -260,7 +159,7 @@ class Database(object):
|
||||||
|
|
||||||
def create_product(self, name: str, price_member: int, price_non_member: int) -> Product:
|
def create_product(self, name: str, price_member: int, price_non_member: int) -> Product:
|
||||||
product_id: int = -1
|
product_id: int = -1
|
||||||
with self.transaction() as c:
|
with self.db.transaction() as c:
|
||||||
c.execute('SELECT product_id FROM products WHERE name = ?', [name])
|
c.execute('SELECT product_id FROM products WHERE name = ?', [name])
|
||||||
if c.fetchone() is not None:
|
if c.fetchone() is not None:
|
||||||
raise ValueError(f'A product with the name \'{name}\' already exists.')
|
raise ValueError(f'A product with the name \'{name}\' already exists.')
|
||||||
|
@ -277,41 +176,47 @@ class Database(object):
|
||||||
return Product(product_id, name, price_member, price_non_member)
|
return Product(product_id, name, price_member, price_non_member)
|
||||||
|
|
||||||
def change_product(self, product: Product):
|
def change_product(self, product: Product):
|
||||||
if product.id == -1:
|
with self.db.transaction() as c:
|
||||||
raise ValueError('Invalid product ID')
|
|
||||||
with self.transaction() as c:
|
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE products
|
UPDATE products
|
||||||
SET
|
SET
|
||||||
name = :name,
|
name = :name,
|
||||||
price_member = :price_member,
|
price_member = :price_member,
|
||||||
price_non_member = :price_non_member
|
price_non_member = :price_non_member
|
||||||
WHERE product_id = :product_is
|
WHERE product_id = :product_id
|
||||||
''', {
|
''', {
|
||||||
'product_id': product.id,
|
'product_id': product.id,
|
||||||
|
'name': product.name,
|
||||||
'price_member': product.price_member,
|
'price_member': product.price_member,
|
||||||
'price_non_member': product.price_non_member
|
'price_non_member': product.price_non_member
|
||||||
})
|
})
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(
|
||||||
|
f'change_product should affect 1 products row, but affected {affected}')
|
||||||
|
|
||||||
def delete_product(self, product: Product):
|
def delete_product(self, product: Product):
|
||||||
if product.id == -1:
|
with self.db.transaction() as c:
|
||||||
raise ValueError('Invalid product ID')
|
|
||||||
with self.transaction() as c:
|
|
||||||
c.execute('''
|
c.execute('''
|
||||||
DELETE FROM products
|
DELETE FROM products
|
||||||
WHERE product_id = ?
|
WHERE product_id = ?
|
||||||
''', [product.id])
|
''', [product.id])
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(
|
||||||
|
f'delete_product should affect 1 products row, but affected {affected}')
|
||||||
|
|
||||||
def increment_consumption(self, user: User, product: Product, count: int = 1):
|
def increment_consumption(self, user: User, product: Product, count: int = 1):
|
||||||
if product.id == -1:
|
with self.db.transaction() as c:
|
||||||
raise ValueError('Invalid product ID')
|
|
||||||
with self.transaction() as c:
|
|
||||||
c.execute('''
|
c.execute('''
|
||||||
SELECT count
|
SELECT count
|
||||||
FROM consumption
|
FROM consumption
|
||||||
WHERE user_id = :user_id
|
WHERE user_id = :user_id
|
||||||
AND product_id = :product_id
|
AND product_id = :product_id
|
||||||
''')
|
''', {
|
||||||
|
'user_id': user.id,
|
||||||
|
'product_id': product.id
|
||||||
|
})
|
||||||
row = c.fetchone()
|
row = c.fetchone()
|
||||||
if row is None:
|
if row is None:
|
||||||
c.execute('''
|
c.execute('''
|
||||||
|
@ -332,6 +237,10 @@ class Database(object):
|
||||||
'product_id': product.id,
|
'product_id': product.id,
|
||||||
'count': count
|
'count': count
|
||||||
})
|
})
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(
|
||||||
|
f'increment_consumption should affect 1 consumption row, but affected {affected}')
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET balance = balance - :cost
|
SET balance = balance - :cost
|
||||||
|
@ -339,6 +248,10 @@ class Database(object):
|
||||||
'user_id': user.id,
|
'user_id': user.id,
|
||||||
'cost': count * product.price_member if user.is_member else count * product.price_non_member
|
'cost': count * product.price_member if user.is_member else count * product.price_non_member
|
||||||
})
|
})
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(
|
||||||
|
f'increment_consumption should affect 1 users row, but affected {affected}')
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE products
|
UPDATE products
|
||||||
SET stock = stock - :count
|
SET stock = stock - :count
|
||||||
|
@ -347,11 +260,13 @@ class Database(object):
|
||||||
'product_id': product.id,
|
'product_id': product.id,
|
||||||
'count': count
|
'count': count
|
||||||
})
|
})
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(
|
||||||
|
f'increment_consumption should affect 1 products row, but affected {affected}')
|
||||||
|
|
||||||
def restock(self, product: Product, count: int):
|
def restock(self, product: Product, count: int):
|
||||||
if product.id == -1:
|
with self.db.transaction() as c:
|
||||||
raise ValueError('Invalid product ID')
|
|
||||||
with self.transaction() as c:
|
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE products
|
UPDATE products
|
||||||
SET stock = stock + :count
|
SET stock = stock + :count
|
||||||
|
@ -360,9 +275,14 @@ class Database(object):
|
||||||
'product_id': product.id,
|
'product_id': product.id,
|
||||||
'count': count
|
'count': count
|
||||||
})
|
})
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}')
|
||||||
|
|
||||||
def deposit(self, user: User, amount: int):
|
def deposit(self, user: User, amount: int):
|
||||||
with self.transaction() as c:
|
if amount < 0:
|
||||||
|
raise ValueError('Cannot deposit a negative value')
|
||||||
|
with self.db.transaction() as c:
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET balance = balance + :amount
|
SET balance = balance + :amount
|
||||||
|
@ -371,3 +291,6 @@ class Database(object):
|
||||||
'user_id': user.id,
|
'user_id': user.id,
|
||||||
'amount': amount
|
'amount': amount
|
||||||
})
|
})
|
||||||
|
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||||
|
if affected != 1:
|
||||||
|
raise DatabaseConsistencyError(f'deposit should affect 1 users row, but affected {affected}')
|
313
matemat/db/test/test_facade.py
Normal file
313
matemat/db/test/test_facade.py
Normal file
|
@ -0,0 +1,313 @@
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
|
||||||
|
from matemat.db import Database
|
||||||
|
from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.db = Database(':memory:')
|
||||||
|
|
||||||
|
def test_create_user(self):
|
||||||
|
with self.db as db:
|
||||||
|
with db.transaction(exclusive=False) as c:
|
||||||
|
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||||
|
c.execute("SELECT * FROM users")
|
||||||
|
row = c.fetchone()
|
||||||
|
self.assertEqual('testuser', row[1])
|
||||||
|
self.assertEqual('testuser@example.com', row[2])
|
||||||
|
self.assertEqual(0, row[5])
|
||||||
|
self.assertEqual(1, row[6])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com')
|
||||||
|
|
||||||
|
def test_list_users(self):
|
||||||
|
with self.db as db:
|
||||||
|
users = db.list_users()
|
||||||
|
self.assertEqual(0, len(users))
|
||||||
|
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
|
||||||
|
db.create_user('anothertestuser', 'otherpassword', 'anothertestuser@example.com', False, True)
|
||||||
|
db.create_user('yatu', 'igotapasswordtoo', 'yatu@example.com', False, False)
|
||||||
|
users = db.list_users()
|
||||||
|
self.assertEqual(3, len(users))
|
||||||
|
usercheck = {}
|
||||||
|
for user in users:
|
||||||
|
if user.name == 'testuser':
|
||||||
|
self.assertEqual('testuser@example.com', user.email)
|
||||||
|
self.assertTrue(user.is_member)
|
||||||
|
self.assertTrue(user.is_admin)
|
||||||
|
elif user.name == 'anothertestuser':
|
||||||
|
self.assertEqual('anothertestuser@example.com', user.email)
|
||||||
|
self.assertTrue(user.is_member)
|
||||||
|
self.assertFalse(user.is_admin)
|
||||||
|
elif user.name == 'yatu':
|
||||||
|
self.assertEqual('yatu@example.com', user.email)
|
||||||
|
self.assertFalse(user.is_member)
|
||||||
|
self.assertFalse(user.is_admin)
|
||||||
|
usercheck[user.id] = 1
|
||||||
|
self.assertEqual(3, len(usercheck))
|
||||||
|
|
||||||
|
def test_login(self):
|
||||||
|
with self.db as db:
|
||||||
|
with db.transaction() as c:
|
||||||
|
u = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||||
|
# Add a touchkey without using the provided function
|
||||||
|
c.execute('''UPDATE users SET touchkey = :tkhash WHERE user_id = :user_id''', {
|
||||||
|
'tkhash': bcrypt.hashpw(b'0123', bcrypt.gensalt(12)),
|
||||||
|
'user_id': u.id
|
||||||
|
})
|
||||||
|
user = db.login('testuser', 'supersecurepassword')
|
||||||
|
self.assertEqual(u.id, user.id)
|
||||||
|
user = db.login('testuser', touchkey='0123')
|
||||||
|
self.assertEqual(u.id, user.id)
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
# Inexistent user should fail
|
||||||
|
db.login('nooone', 'supersecurepassword')
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
# Wrong password should fail
|
||||||
|
db.login('testuser', 'anothersecurepassword')
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
# Wrong touchkey should fail
|
||||||
|
db.login('testuser', touchkey='0124')
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
# No password or touchkey should fail
|
||||||
|
db.login('testuser')
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
# Both password and touchkey should fail
|
||||||
|
db.login('testuser', password='supersecurepassword', touchkey='0123')
|
||||||
|
|
||||||
|
def test_change_password(self):
|
||||||
|
with self.db as db:
|
||||||
|
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||||
|
db.login('testuser', 'supersecurepassword')
|
||||||
|
# Normal password change should succeed
|
||||||
|
db.change_password(user, 'supersecurepassword', 'mynewpassword')
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
db.login('testuser', 'supersecurepassword')
|
||||||
|
db.login('testuser', 'mynewpassword')
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
# Should fail due to wrong old password
|
||||||
|
db.change_password(user, 'iforgotmypassword', 'mynewpassword')
|
||||||
|
db.login('testuser', 'mynewpassword')
|
||||||
|
# This should pass even though the old password is not known (admin password reset)
|
||||||
|
db.change_password(user, '42', 'adminpasswordreset', verify_password=False)
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
db.login('testuser', 'mynewpassword')
|
||||||
|
db.login('testuser', 'adminpasswordreset')
|
||||||
|
user._user_id = -1
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
# Password change for an inexistent user should fail
|
||||||
|
db.change_password(user, 'adminpasswordreset', 'passwordwithoutuser')
|
||||||
|
|
||||||
|
def test_change_touchkey(self):
|
||||||
|
with self.db as db:
|
||||||
|
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||||
|
db.change_touchkey(user, 'supersecurepassword', '0123')
|
||||||
|
db.login('testuser', touchkey='0123')
|
||||||
|
# Normal touchkey change should succeed
|
||||||
|
db.change_touchkey(user, 'supersecurepassword', touchkey='4567')
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
db.login('testuser', touchkey='0123')
|
||||||
|
db.login('testuser', touchkey='4567')
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
# Should fail due to wrong old password
|
||||||
|
db.change_touchkey(user, 'iforgotmypassword', '89ab')
|
||||||
|
db.login('testuser', touchkey='4567')
|
||||||
|
# This should pass even though the old password is not known (admin password reset)
|
||||||
|
db.change_touchkey(user, '42', '89ab', verify_password=False)
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
db.login('testuser', touchkey='4567')
|
||||||
|
db.login('testuser', touchkey='89ab')
|
||||||
|
user._user_id = -1
|
||||||
|
with self.assertRaises(AuthenticationError):
|
||||||
|
# Touchkey change for an inexistent user should fail
|
||||||
|
db.change_touchkey(user, '89ab', '048c')
|
||||||
|
|
||||||
|
def test_change_user(self):
|
||||||
|
with self.db as db:
|
||||||
|
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
|
||||||
|
user.email = 'newaddress@example.com'
|
||||||
|
user.is_admin = False
|
||||||
|
user.is_member = False
|
||||||
|
db.change_user(user)
|
||||||
|
checkuser = db.login('testuser', 'supersecurepassword')
|
||||||
|
self.assertEqual('newaddress@example.com', checkuser.email)
|
||||||
|
self.assertFalse(checkuser.is_admin)
|
||||||
|
self.assertFalse(checkuser.is_member)
|
||||||
|
user._user_id = -1
|
||||||
|
with self.assertRaises(DatabaseConsistencyError):
|
||||||
|
db.change_user(user)
|
||||||
|
|
||||||
|
def test_delete_user(self):
|
||||||
|
with self.db as db:
|
||||||
|
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
|
||||||
|
db.login('testuser', 'supersecurepassword')
|
||||||
|
db.delete_user(user)
|
||||||
|
try:
|
||||||
|
# Should fail, as the user does not exist anymore
|
||||||
|
db.login('testuser', 'supersecurepassword')
|
||||||
|
except AuthenticationError as e:
|
||||||
|
self.assertEqual('User does not exist', e.msg)
|
||||||
|
with self.assertRaises(DatabaseConsistencyError):
|
||||||
|
# Should fail, as the user does not exist anymore
|
||||||
|
db.delete_user(user)
|
||||||
|
|
||||||
|
def test_create_product(self):
|
||||||
|
with self.db as db:
|
||||||
|
with db.transaction() as c:
|
||||||
|
db.create_product('Club Mate', 200, 200)
|
||||||
|
c.execute("SELECT * FROM products")
|
||||||
|
row = c.fetchone()
|
||||||
|
self.assertEqual('Club Mate', row[1])
|
||||||
|
self.assertEqual(0, row[2])
|
||||||
|
self.assertEqual(200, row[3])
|
||||||
|
self.assertEqual(200, row[4])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
db.create_product('Club Mate', 250, 250)
|
||||||
|
|
||||||
|
def test_list_products(self):
|
||||||
|
with self.db as db:
|
||||||
|
# Test empty list
|
||||||
|
products = db.list_products()
|
||||||
|
self.assertEqual(0, len(products))
|
||||||
|
db.create_product('Club Mate', 200, 200)
|
||||||
|
db.create_product('Flora Power Mate', 200, 200)
|
||||||
|
db.create_product('Fritz Mate', 200, 250)
|
||||||
|
products = db.list_products()
|
||||||
|
self.assertEqual(3, len(products))
|
||||||
|
productcheck = {}
|
||||||
|
for product in products:
|
||||||
|
if product.name == 'Club Mate':
|
||||||
|
self.assertEqual(200, product.price_member)
|
||||||
|
self.assertEqual(200, product.price_non_member)
|
||||||
|
elif product.name == 'Flora Power Mate':
|
||||||
|
self.assertEqual(200, product.price_member)
|
||||||
|
self.assertEqual(200, product.price_non_member)
|
||||||
|
elif product.name == 'Fritz Mate':
|
||||||
|
self.assertEqual(200, product.price_member)
|
||||||
|
self.assertEqual(250, product.price_non_member)
|
||||||
|
productcheck[product.id] = 1
|
||||||
|
self.assertEqual(3, len(productcheck))
|
||||||
|
|
||||||
|
def test_change_product(self):
|
||||||
|
with self.db as db:
|
||||||
|
product = db.create_product('Club Mate', 200, 200)
|
||||||
|
product.name = 'Flora Power Mate'
|
||||||
|
product.price_member = 150
|
||||||
|
product.price_non_member = 250
|
||||||
|
db.change_product(product)
|
||||||
|
checkproduct = db.list_products()[0]
|
||||||
|
self.assertEqual('Flora Power Mate', checkproduct.name)
|
||||||
|
self.assertEqual(150, checkproduct.price_member)
|
||||||
|
self.assertEqual(250, checkproduct.price_non_member)
|
||||||
|
product._product_id = -1
|
||||||
|
with self.assertRaises(DatabaseConsistencyError):
|
||||||
|
db.change_product(product)
|
||||||
|
product2 = db.create_product('Club Mate', 200, 200)
|
||||||
|
product2.name = 'Flora Power Mate'
|
||||||
|
with self.assertRaises(DatabaseConsistencyError):
|
||||||
|
# Should fail, as a product with the same name already exists.
|
||||||
|
db.change_product(product2)
|
||||||
|
|
||||||
|
def test_delete_product(self):
|
||||||
|
with self.db as db:
|
||||||
|
product = db.create_product('Club Mate', 200, 200)
|
||||||
|
product2 = db.create_product('Flora Power Mate', 200, 200)
|
||||||
|
|
||||||
|
self.assertEqual(2, len(db.list_products()))
|
||||||
|
db.delete_product(product)
|
||||||
|
|
||||||
|
# Only the other product should remain in the DB
|
||||||
|
products = db.list_products()
|
||||||
|
self.assertEqual(1, len(products))
|
||||||
|
self.assertEqual(product2, products[0])
|
||||||
|
|
||||||
|
with self.assertRaises(DatabaseConsistencyError):
|
||||||
|
# Should fail, as the product does not exist anymore
|
||||||
|
db.delete_product(product)
|
||||||
|
|
||||||
|
def test_deposit(self):
|
||||||
|
with self.db as db:
|
||||||
|
with db.transaction() as c:
|
||||||
|
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
|
||||||
|
user2 = db.create_user('testuser2', 'supersecurepassword', 'testuser@example.com', True, True)
|
||||||
|
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user.id])
|
||||||
|
self.assertEqual(0, c.fetchone()[0])
|
||||||
|
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id])
|
||||||
|
self.assertEqual(0, c.fetchone()[0])
|
||||||
|
db.deposit(user, 1337)
|
||||||
|
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user.id])
|
||||||
|
self.assertEqual(1337, c.fetchone()[0])
|
||||||
|
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id])
|
||||||
|
self.assertEqual(0, c.fetchone()[0])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
# Should fail, negative amount
|
||||||
|
db.deposit(user, -42)
|
||||||
|
user._user_id = -1
|
||||||
|
with self.assertRaises(DatabaseConsistencyError):
|
||||||
|
# Should fail, user id -1 does not exist
|
||||||
|
db.deposit(user, 42)
|
||||||
|
|
||||||
|
def test_restock(self):
|
||||||
|
with self.db as db:
|
||||||
|
with db.transaction() as c:
|
||||||
|
product = db.create_product('Club Mate', 200, 200)
|
||||||
|
product2 = db.create_product('Flora Power Mate', 200, 200)
|
||||||
|
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product.id])
|
||||||
|
self.assertEqual(0, c.fetchone()[0])
|
||||||
|
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product2.id])
|
||||||
|
self.assertEqual(0, c.fetchone()[0])
|
||||||
|
db.restock(product, 42)
|
||||||
|
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product.id])
|
||||||
|
self.assertEqual(42, c.fetchone()[0])
|
||||||
|
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [product2.id])
|
||||||
|
self.assertEqual(0, c.fetchone()[0])
|
||||||
|
product._product_id = -1
|
||||||
|
with self.assertRaises(DatabaseConsistencyError):
|
||||||
|
# Should fail, product id -1 does not exist
|
||||||
|
db.restock(product, 42)
|
||||||
|
|
||||||
|
def test_consumption(self):
|
||||||
|
with self.db as db:
|
||||||
|
# Set up test case
|
||||||
|
user1 = db.create_user('user1', 'supersecurepassword', 'testuser@example.com', member=True)
|
||||||
|
user2 = db.create_user('user2', 'supersecurepassword', 'testuser@example.com', member=False)
|
||||||
|
user3 = db.create_user('user3', 'supersecurepassword', 'testuser@example.com', member=False)
|
||||||
|
db.deposit(user1, 1337)
|
||||||
|
db.deposit(user2, 4242)
|
||||||
|
db.deposit(user3, 1234)
|
||||||
|
clubmate = db.create_product('Club Mate', 200, 200)
|
||||||
|
florapowermate = db.create_product('Flora Power Mate', 150, 250)
|
||||||
|
fritzmate = db.create_product('Fritz Mate', 200, 200)
|
||||||
|
db.restock(clubmate, 50)
|
||||||
|
db.restock(florapowermate, 70)
|
||||||
|
db.restock(fritzmate, 10)
|
||||||
|
|
||||||
|
# user1 is somewhat addicted to caffeine
|
||||||
|
db.increment_consumption(user1, clubmate, 1)
|
||||||
|
db.increment_consumption(user1, clubmate, 2)
|
||||||
|
db.increment_consumption(user1, florapowermate, 3)
|
||||||
|
|
||||||
|
# user2 is reeeally addicted
|
||||||
|
db.increment_consumption(user2, clubmate, 7)
|
||||||
|
db.increment_consumption(user2, florapowermate, 3)
|
||||||
|
db.increment_consumption(user2, florapowermate, 4)
|
||||||
|
|
||||||
|
with db.transaction(exclusive=False) as c:
|
||||||
|
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user1.id])
|
||||||
|
self.assertEqual(1337 - 200 * 3 - 150 * 3, c.fetchone()[0])
|
||||||
|
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user2.id])
|
||||||
|
self.assertEqual(4242 - 200 * 7 - 250 * 7, c.fetchone()[0])
|
||||||
|
c.execute('''SELECT balance FROM users WHERE user_id = ?''', [user3.id])
|
||||||
|
self.assertEqual(1234, c.fetchone()[0])
|
||||||
|
|
||||||
|
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [clubmate.id])
|
||||||
|
self.assertEqual(40, c.fetchone()[0])
|
||||||
|
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [florapowermate.id])
|
||||||
|
self.assertEqual(60, c.fetchone()[0])
|
||||||
|
c.execute('''SELECT stock FROM products WHERE product_id = ?''', [fritzmate.id])
|
||||||
|
self.assertEqual(10, c.fetchone()[0])
|
|
@ -1,20 +1,20 @@
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from matemat.db import Database
|
from matemat.db import DatabaseWrapper
|
||||||
|
|
||||||
|
|
||||||
class DatabaseTest(unittest.TestCase):
|
class DatabaseTest(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.db = Database(':memory:')
|
self.db = DatabaseWrapper(':memory:')
|
||||||
|
|
||||||
def test_create_schema(self):
|
def test_create_schema(self):
|
||||||
"""
|
"""
|
||||||
Test creation of database schema in an empty database
|
Test creation of database schema in an empty database
|
||||||
"""
|
"""
|
||||||
with self.db as db:
|
with self.db as db:
|
||||||
self.assertEqual(Database.SCHEMA_VERSION, db._user_version)
|
self.assertEqual(DatabaseWrapper.SCHEMA_VERSION, db._user_version)
|
||||||
|
|
||||||
def test_in_transaction(self):
|
def test_in_transaction(self):
|
||||||
"""
|
"""
|
||||||
|
@ -74,22 +74,3 @@ class DatabaseTest(unittest.TestCase):
|
||||||
c = db._sqlite_db.cursor()
|
c = db._sqlite_db.cursor()
|
||||||
c.execute("SELECT * FROM users")
|
c.execute("SELECT * FROM users")
|
||||||
self.assertIsNone(c.fetchone())
|
self.assertIsNone(c.fetchone())
|
||||||
|
|
||||||
def test_create_user(self):
|
|
||||||
with self.db as db:
|
|
||||||
with db.transaction() as c:
|
|
||||||
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
|
||||||
c.execute("SELECT * FROM users")
|
|
||||||
row = c.fetchone()
|
|
||||||
if row is None:
|
|
||||||
self.fail()
|
|
||||||
self.assertEqual('testuser', row[1])
|
|
||||||
self.assertEqual('testuser@example.com', row[2])
|
|
||||||
self.assertEqual(0, row[5])
|
|
||||||
self.assertEqual(1, row[6])
|
|
||||||
|
|
||||||
def test_create_existing_user(self):
|
|
||||||
with self.db as db:
|
|
||||||
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com')
|
|
133
matemat/db/wrapper.py
Normal file
133
matemat/db/wrapper.py
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
|
||||||
|
import apsw
|
||||||
|
|
||||||
|
from matemat.exceptions import DatabaseConsistencyError
|
||||||
|
|
||||||
|
|
||||||
|
class Transaction(object):
|
||||||
|
|
||||||
|
def __init__(self, db: apsw.Connection, wrapper: 'DatabaseWrapper', exclusive: bool = True):
|
||||||
|
self._db: apsw.Connection = db
|
||||||
|
self._cursor = None
|
||||||
|
self._excl = exclusive
|
||||||
|
self._wrapper: DatabaseWrapper = wrapper
|
||||||
|
self._is_dummy: bool = False
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if self._wrapper._in_transaction:
|
||||||
|
self._is_dummy = True
|
||||||
|
return self._db.cursor()
|
||||||
|
else:
|
||||||
|
self._is_dummy = False
|
||||||
|
self._cursor = self._db.cursor()
|
||||||
|
self._wrapper._in_transaction = True
|
||||||
|
if self._excl:
|
||||||
|
self._cursor.execute('BEGIN EXCLUSIVE')
|
||||||
|
else:
|
||||||
|
self._cursor.execute('BEGIN')
|
||||||
|
return self._cursor
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if self._is_dummy:
|
||||||
|
return
|
||||||
|
if exc_type is None:
|
||||||
|
self._cursor.execute('COMMIT')
|
||||||
|
else:
|
||||||
|
self._cursor.execute('ROLLBACK')
|
||||||
|
self._wrapper._in_transaction = False
|
||||||
|
if exc_type == apsw.ConstraintError:
|
||||||
|
raise DatabaseConsistencyError(str(exc_val))
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseWrapper(object):
|
||||||
|
SCHEMA_VERSION = 1
|
||||||
|
|
||||||
|
SCHEMA = '''
|
||||||
|
CREATE TABLE users (
|
||||||
|
user_id INTEGER PRIMARY KEY,
|
||||||
|
username TEXT NOT NULL,
|
||||||
|
email TEXT DEFAULT NULL,
|
||||||
|
password TEXT NOT NULL,
|
||||||
|
touchkey TEXT DEFAULT NULL,
|
||||||
|
is_admin INTEGER(1) NOT NULL DEFAULT 0,
|
||||||
|
is_member INTEGER(1) NOT NULL DEFAULT 1,
|
||||||
|
balance INTEGER(8) NOT NULL DEFAULT 0,
|
||||||
|
lastchange INTEGER(8) NOT NULL DEFAULT 0
|
||||||
|
);
|
||||||
|
CREATE TABLE products (
|
||||||
|
product_id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT UNIQUE NOT NULL,
|
||||||
|
stock INTEGER(8) NOT NULL DEFAULT 0,
|
||||||
|
price_member INTEGER(8) NOT NULL,
|
||||||
|
price_non_member INTEGER(8) NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE consumption (
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
product_id INTEGER NOT NULL,
|
||||||
|
count INTEGER(8) NOT NULL DEFAULT 0,
|
||||||
|
PRIMARY KEY (user_id, product_id),
|
||||||
|
FOREIGN KEY (user_id) REFERENCES users(user_id)
|
||||||
|
ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
||||||
|
ON DELETE CASCADE ON UPDATE CASCADE
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, filename: str):
|
||||||
|
self._filename: str = filename
|
||||||
|
self._sqlite_db: apsw.Connection = None
|
||||||
|
self._in_transaction: bool = False
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.connect()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def transaction(self, exclusive: bool = True) -> Transaction:
|
||||||
|
return Transaction(self._sqlite_db, self, exclusive)
|
||||||
|
|
||||||
|
def _setup(self):
|
||||||
|
with self.transaction() as c:
|
||||||
|
version: int = self._user_version
|
||||||
|
if version < 1:
|
||||||
|
c.execute(self.SCHEMA)
|
||||||
|
elif version < self.SCHEMA_VERSION:
|
||||||
|
self._upgrade(old=version, new=self.SCHEMA_VERSION)
|
||||||
|
self._user_version = self.SCHEMA_VERSION
|
||||||
|
|
||||||
|
def _upgrade(self, old: int, new: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
if self.is_connected():
|
||||||
|
raise RuntimeError(f'Database connection to {self._filename} is already established.')
|
||||||
|
self._sqlite_db = apsw.Connection(self._filename)
|
||||||
|
self._setup()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self._sqlite_db is None:
|
||||||
|
raise RuntimeError(f'Database connection to {self._filename} is not established.')
|
||||||
|
if self.in_transaction():
|
||||||
|
raise RuntimeError(f'A transaction is still ongoing.')
|
||||||
|
self._sqlite_db.close()
|
||||||
|
self._sqlite_db = None
|
||||||
|
|
||||||
|
def in_transaction(self) -> bool:
|
||||||
|
return self._in_transaction
|
||||||
|
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self._sqlite_db is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _user_version(self) -> int:
|
||||||
|
cursor = self._sqlite_db.cursor()
|
||||||
|
cursor.execute('PRAGMA user_version')
|
||||||
|
version = int(cursor.fetchone()[0])
|
||||||
|
return version
|
||||||
|
|
||||||
|
@_user_version.setter
|
||||||
|
def _user_version(self, version: int):
|
||||||
|
cursor = self._sqlite_db.cursor()
|
||||||
|
cursor.execute(f'PRAGMA user_version = {version}')
|
|
@ -1,11 +1,11 @@
|
||||||
|
|
||||||
class AuthenticationException(BaseException):
|
class AuthenticationError(BaseException):
|
||||||
|
|
||||||
def __init__(self, msg: str = None):
|
def __init__(self, msg: str = None):
|
||||||
self._msg = msg
|
self._msg = msg
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f'AuthenticationException: {self._msg}'
|
return f'AuthenticationError: {self._msg}'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def msg(self) -> str:
|
def msg(self) -> str:
|
12
matemat/exceptions/DatabaseConsistencyError.py
Normal file
12
matemat/exceptions/DatabaseConsistencyError.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
|
||||||
|
class DatabaseConsistencyError(BaseException):
|
||||||
|
|
||||||
|
def __init__(self, msg: str = None):
|
||||||
|
self._msg = msg
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f'DatabaseConsistencyError: {self._msg}'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def msg(self) -> str:
|
||||||
|
return self._msg
|
|
@ -1,2 +1,3 @@
|
||||||
|
|
||||||
from .AuthenticatonException import AuthenticationException
|
from .AuthenticatonError import AuthenticationError
|
||||||
|
from .DatabaseConsistencyError import DatabaseConsistencyError
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
class Product(object):
|
class Product(object):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -11,6 +10,14 @@ class Product(object):
|
||||||
self._price_member: int = price_member
|
self._price_member: int = price_member
|
||||||
self._price_non_member: int = price_non_member
|
self._price_non_member: int = price_non_member
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if other is None or not isinstance(other, Product):
|
||||||
|
return False
|
||||||
|
return self._product_id == other._product_id \
|
||||||
|
and self._name == other._name \
|
||||||
|
and self._price_member == other._price_member \
|
||||||
|
and self._price_non_member == other._price_non_member
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> int:
|
def id(self) -> int:
|
||||||
return self._product_id
|
return self._product_id
|
||||||
|
@ -19,6 +26,11 @@ class Product(object):
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: str):
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def price_member(self) -> int:
|
def price_member(self) -> int:
|
||||||
return self._price_member
|
return self._price_member
|
||||||
|
|
|
@ -10,14 +10,21 @@ class User(object):
|
||||||
email: Optional[str] = None,
|
email: Optional[str] = None,
|
||||||
admin: bool = False,
|
admin: bool = False,
|
||||||
member: bool = True):
|
member: bool = True):
|
||||||
if user_id == -1:
|
|
||||||
raise ValueError('Invalid user ID')
|
|
||||||
self._user_id: int = user_id
|
self._user_id: int = user_id
|
||||||
self._username: str = username
|
self._username: str = username
|
||||||
self._email: Optional[str] = email
|
self._email: Optional[str] = email
|
||||||
self._admin: bool = admin
|
self._admin: bool = admin
|
||||||
self._member: bool = member
|
self._member: bool = member
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if other is None or not isinstance(other, User):
|
||||||
|
return False
|
||||||
|
return self._user_id == other._user_id \
|
||||||
|
and self._username == other._username \
|
||||||
|
and self._email == other._email \
|
||||||
|
and self._admin == other._admin \
|
||||||
|
and self._member == other._member
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> int:
|
def id(self) -> int:
|
||||||
return self._user_id
|
return self._user_id
|
||||||
|
|
Loading…
Reference in a new issue