Split database class into Facade (matemat database API) and Wrapper (internals).
This commit is contained in:
parent
e94f8bd29d
commit
411372cc21
8 changed files with 306 additions and 176 deletions
|
@ -1,2 +1,3 @@
|
|||
|
||||
from .database import Database
|
||||
from .wrapper import DatabaseWrapper
|
||||
from .facade import DatabaseFacade as Database
|
||||
|
|
|
@ -1,142 +1,31 @@
|
|||
|
||||
from typing import List, Optional
|
||||
|
||||
import apsw
|
||||
import bcrypt
|
||||
|
||||
from matemat.primitives import User, Product
|
||||
from matemat.exceptions import AuthenticationException
|
||||
from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
|
||||
from matemat.db import DatabaseWrapper
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
|
||||
def __init__(self, db: apsw.Connection, wrapper: 'Database', exclusive: bool = True):
|
||||
self._db: apsw.Connection = db
|
||||
self._cursor = None
|
||||
self._excl = exclusive
|
||||
self._wrapper: Database = wrapper
|
||||
self._is_dummy: bool = False
|
||||
|
||||
def __enter__(self):
|
||||
if self._wrapper._in_transaction:
|
||||
self._is_dummy = True
|
||||
return self._db.cursor()
|
||||
else:
|
||||
self._is_dummy = False
|
||||
self._cursor = self._db.cursor()
|
||||
self._wrapper._in_transaction = True
|
||||
if self._excl:
|
||||
self._cursor.execute('BEGIN EXCLUSIVE')
|
||||
else:
|
||||
self._cursor.execute('BEGIN')
|
||||
return self._cursor
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._is_dummy:
|
||||
return
|
||||
if exc_type is None:
|
||||
self._cursor.execute('COMMIT')
|
||||
else:
|
||||
self._cursor.execute('ROLLBACK')
|
||||
self._wrapper._in_transaction = False
|
||||
|
||||
|
||||
class Database(object):
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
SCHEMA = '''
|
||||
CREATE TABLE users (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
email TEXT DEFAULT NULL,
|
||||
password TEXT NOT NULL,
|
||||
touchkey TEXT DEFAULT NULL,
|
||||
is_admin INTEGER(1) NOT NULL DEFAULT 0,
|
||||
is_member INTEGER(1) NOT NULL DEFAULT 1,
|
||||
balance INTEGER(8) NOT NULL DEFAULT 0,
|
||||
lastchange INTEGER(8) NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE products (
|
||||
product_id INTEGER PRIMARY KEY,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
stock INTEGER(8) NOT NULL DEFAULT 0,
|
||||
price_member INTEGER(8) NOT NULL,
|
||||
price_non_member INTEGER(8) NOT NULL
|
||||
);
|
||||
CREATE TABLE consumption (
|
||||
user_id INTEGER NOT NULL,
|
||||
product_id INTEGER NOT NULL,
|
||||
count INTEGER(8) NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (user_id, product_id),
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE
|
||||
)
|
||||
'''
|
||||
class DatabaseFacade(object):
|
||||
|
||||
def __init__(self, filename: str):
|
||||
self._filename: str = filename
|
||||
self._sqlite_db: apsw.Connection = None
|
||||
self._in_transaction: bool = False
|
||||
self.db: DatabaseWrapper = DatabaseWrapper(filename)
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
self.db.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
self.db.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def transaction(self, exclusive: bool = True) -> Transaction:
|
||||
return Transaction(self._sqlite_db, self, exclusive)
|
||||
|
||||
def _setup(self):
|
||||
with self.transaction() as c:
|
||||
version: int = self._user_version
|
||||
if version < 1:
|
||||
c.execute(self.SCHEMA)
|
||||
elif version < self.SCHEMA_VERSION:
|
||||
self._upgrade(old=version, new=self.SCHEMA_VERSION)
|
||||
self._user_version = self.SCHEMA_VERSION
|
||||
|
||||
def _upgrade(self, old: int, new: int):
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected():
|
||||
raise RuntimeError(f'Database connection to {self._filename} is already established.')
|
||||
self._sqlite_db = apsw.Connection(self._filename)
|
||||
self._setup()
|
||||
|
||||
def close(self):
|
||||
if self._sqlite_db is None:
|
||||
raise RuntimeError(f'Database connection to {self._filename} is not established.')
|
||||
if self.in_transaction():
|
||||
raise RuntimeError(f'A transaction is still ongoing.')
|
||||
self._sqlite_db.close()
|
||||
self._sqlite_db = None
|
||||
|
||||
def in_transaction(self) -> bool:
|
||||
return self._in_transaction
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self._sqlite_db is not None
|
||||
|
||||
@property
|
||||
def _user_version(self) -> int:
|
||||
cursor = self._sqlite_db.cursor()
|
||||
cursor.execute('PRAGMA user_version')
|
||||
version = int(cursor.fetchone()[0])
|
||||
return version
|
||||
|
||||
@_user_version.setter
|
||||
def _user_version(self, version: int):
|
||||
cursor = self._sqlite_db.cursor()
|
||||
cursor.execute(f'PRAGMA user_version = {version}')
|
||||
def transaction(self, exclusive: bool = True):
|
||||
return self.db.transaction(exclusive=exclusive)
|
||||
|
||||
def list_users(self) -> List[User]:
|
||||
users: List[User] = []
|
||||
with self.transaction(exclusive=False) as c:
|
||||
with self.db.transaction(exclusive=False) as c:
|
||||
for row in c.execute('''
|
||||
SELECT user_id, username, email, is_admin, is_member
|
||||
FROM users
|
||||
|
@ -153,7 +42,7 @@ class Database(object):
|
|||
member: bool = True) -> User:
|
||||
pwhash: str = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12))
|
||||
user_id: int = -1
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('SELECT user_id FROM users WHERE username = ?', [username])
|
||||
if c.fetchone() is not None:
|
||||
raise ValueError(f'A user with the name \'{username}\' already exists.')
|
||||
|
@ -167,37 +56,43 @@ class Database(object):
|
|||
'admin': admin,
|
||||
'member': member
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'create_user should affect 1 users row, but affected {affected}')
|
||||
c.execute('SELECT last_insert_rowid()')
|
||||
user_id = int(c.fetchone()[0])
|
||||
return User(user_id, username, email, admin, member)
|
||||
|
||||
def login(self, username: str, password: str) -> Optional[User]:
|
||||
with self.transaction(exclusive=False) as c:
|
||||
def login(self, username: str, password: Optional[str] = None, touchkey: Optional[str] = None) -> User:
|
||||
if (password is None) == (touchkey is None):
|
||||
raise ValueError('Exactly one of password and touchkey must be provided')
|
||||
with self.db.transaction(exclusive=False) as c:
|
||||
c.execute('''
|
||||
SELECT user_id, username, email, password, is_admin, is_member
|
||||
SELECT user_id, username, email, password, touchkey, is_admin, is_member
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
''', [username])
|
||||
row = c.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
user_id, username, email, pwhash, admin, member = row
|
||||
if not bcrypt.checkpw(password.encode('utf-8'), pwhash):
|
||||
return None
|
||||
raise AuthenticationError('User does not exist')
|
||||
user_id, username, email, pwhash, tkhash, admin, member = row
|
||||
if password is not None and not bcrypt.checkpw(password.encode('utf-8'), pwhash):
|
||||
raise AuthenticationError('Password mismatch')
|
||||
elif touchkey is not None and not bcrypt.checkpw(touchkey.encode('utf-8'), tkhash):
|
||||
raise AuthenticationError('Touchkey mismatch')
|
||||
return User(user_id, username, email, admin, member)
|
||||
|
||||
def change_password(self, user: User, oldpass: str, newpass: str, newpass2: str, verify_password: bool = True):
|
||||
if newpass != newpass2:
|
||||
raise ValueError('New passwords don\'t match.')
|
||||
with self.transaction() as c:
|
||||
def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True):
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
SELECT password FROM users WHERE user_id = ?
|
||||
''', [user.id])
|
||||
row = c.fetchone()
|
||||
if row is None:
|
||||
raise AuthenticationException('User does not exist in database.')
|
||||
raise AuthenticationError('User does not exist in database.')
|
||||
if verify_password and not bcrypt.checkpw(oldpass.encode('utf-8'), row[0]):
|
||||
raise AuthenticationException('Old password does not match.')
|
||||
raise AuthenticationError('Old password does not match.')
|
||||
pwhash: str = bcrypt.hashpw(newpass.encode('utf-8'), bcrypt.gensalt(12))
|
||||
c.execute('''
|
||||
UPDATE users SET password = :pwhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id
|
||||
|
@ -205,17 +100,21 @@ class Database(object):
|
|||
'user_id': user.id,
|
||||
'pwhash': pwhash
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'change_password should affect 1 users row, but affected {affected}')
|
||||
|
||||
def change_touchkey(self, user: User, password: str, touchkey: Optional[str], verify_password: bool = True):
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
SELECT password FROM users WHERE user_id = ?
|
||||
''', [user.id])
|
||||
row = c.fetchone()
|
||||
if row is None:
|
||||
raise AuthenticationException('User does not exist in database.')
|
||||
raise AuthenticationError('User does not exist in database.')
|
||||
if verify_password and not bcrypt.checkpw(password.encode('utf-8'), row[0]):
|
||||
raise AuthenticationException('Password does not match.')
|
||||
raise AuthenticationError('Password does not match.')
|
||||
tkhash: str = bcrypt.hashpw(touchkey.encode('utf-8'), bcrypt.gensalt(12)) if touchkey is not None else None
|
||||
c.execute('''
|
||||
UPDATE users SET touchkey = :tkhash, lastchange = STRFTIME('%s', 'now') WHERE user_id = :user_id
|
||||
|
@ -223,9 +122,13 @@ class Database(object):
|
|||
'user_id': user.id,
|
||||
'tkhash': tkhash
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'change_touchkey should affect 1 users row, but affected {affected}')
|
||||
|
||||
def change_user(self, user: User):
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
UPDATE users SET
|
||||
email = :email,
|
||||
|
@ -239,17 +142,25 @@ class Database(object):
|
|||
'is_admin': user.is_admin,
|
||||
'is_member': user.is_member
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'change_user should affect 1 users row, but affected {affected}')
|
||||
|
||||
def delete_user(self, user: User):
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
DELETE FROM users
|
||||
WHERE user_id = ?
|
||||
''', [user.id])
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'delete_user should affect 1 users row, but affected {affected}')
|
||||
|
||||
def list_products(self) -> List[Product]:
|
||||
products: List[Product] = []
|
||||
with self.transaction(exclusive=False) as c:
|
||||
with self.db.transaction(exclusive=False) as c:
|
||||
for row in c.execute('''
|
||||
SELECT product_id, name, price_member, price_external
|
||||
FROM products
|
||||
|
@ -260,7 +171,7 @@ class Database(object):
|
|||
|
||||
def create_product(self, name: str, price_member: int, price_non_member: int) -> Product:
|
||||
product_id: int = -1
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('SELECT product_id FROM products WHERE name = ?', [name])
|
||||
if c.fetchone() is not None:
|
||||
raise ValueError(f'A product with the name \'{name}\' already exists.')
|
||||
|
@ -273,13 +184,17 @@ class Database(object):
|
|||
'price_non_member': price_non_member
|
||||
})
|
||||
c.execute('SELECT last_insert_rowid()')
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'create_product should affect 1 products row, but affected {affected}')
|
||||
product_id = int(c.fetchone()[0])
|
||||
return Product(product_id, name, price_member, price_non_member)
|
||||
|
||||
def change_product(self, product: Product):
|
||||
if product.id == -1:
|
||||
raise ValueError('Invalid product ID')
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
UPDATE products
|
||||
SET
|
||||
|
@ -292,20 +207,24 @@ class Database(object):
|
|||
'price_member': product.price_member,
|
||||
'price_non_member': product.price_non_member
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'change_product should affect 1 products row, but affected {affected}')
|
||||
|
||||
def delete_product(self, product: Product):
|
||||
if product.id == -1:
|
||||
raise ValueError('Invalid product ID')
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
DELETE FROM products
|
||||
WHERE product_id = ?
|
||||
''', [product.id])
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'delete_product should affect 1 products row, but affected {affected}')
|
||||
|
||||
def increment_consumption(self, user: User, product: Product, count: int = 1):
|
||||
if product.id == -1:
|
||||
raise ValueError('Invalid product ID')
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
SELECT count
|
||||
FROM consumption
|
||||
|
@ -332,6 +251,10 @@ class Database(object):
|
|||
'product_id': product.id,
|
||||
'count': count
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'increment_consumption should affect 1 consumption row, but affected {affected}')
|
||||
c.execute('''
|
||||
UPDATE users
|
||||
SET balance = balance - :cost
|
||||
|
@ -339,6 +262,10 @@ class Database(object):
|
|||
'user_id': user.id,
|
||||
'cost': count * product.price_member if user.is_member else count * product.price_non_member
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'increment_consumption should affect 1 users row, but affected {affected}')
|
||||
c.execute('''
|
||||
UPDATE products
|
||||
SET stock = stock - :count
|
||||
|
@ -347,11 +274,15 @@ class Database(object):
|
|||
'product_id': product.id,
|
||||
'count': count
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(
|
||||
f'increment_consumption should affect 1 products row, but affected {affected}')
|
||||
|
||||
def restock(self, product: Product, count: int):
|
||||
if product.id == -1:
|
||||
raise ValueError('Invalid product ID')
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
UPDATE products
|
||||
SET stock = stock + :count
|
||||
|
@ -360,9 +291,12 @@ class Database(object):
|
|||
'product_id': product.id,
|
||||
'count': count
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}')
|
||||
|
||||
def deposit(self, user: User, amount: int):
|
||||
with self.transaction() as c:
|
||||
with self.db.transaction() as c:
|
||||
c.execute('''
|
||||
UPDATE users
|
||||
SET balance = balance + :amount
|
||||
|
@ -371,3 +305,6 @@ class Database(object):
|
|||
'user_id': user.id,
|
||||
'amount': amount
|
||||
})
|
||||
affected = c.execute('SELECT changes()').fetchone()[0]
|
||||
if affected != 1:
|
||||
raise DatabaseConsistencyError(f'deposit should affect 1 users row, but affected {affected}')
|
69
matemat/db/test/test_facade.py
Normal file
69
matemat/db/test/test_facade.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
|
||||
import unittest
|
||||
|
||||
from matemat.db import Database
|
||||
from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
|
||||
|
||||
|
||||
class DatabaseTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.db = Database(':memory:')
|
||||
|
||||
def test_create_user(self):
|
||||
with self.db as db:
|
||||
with db.transaction() as c:
|
||||
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||
c.execute("SELECT * FROM users")
|
||||
row = c.fetchone()
|
||||
if row is None:
|
||||
self.fail()
|
||||
self.assertEqual('testuser', row[1])
|
||||
self.assertEqual('testuser@example.com', row[2])
|
||||
self.assertEqual(0, row[5])
|
||||
self.assertEqual(1, row[6])
|
||||
|
||||
def test_create_existing_user(self):
|
||||
with self.db as db:
|
||||
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||
with self.assertRaises(ValueError):
|
||||
db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com')
|
||||
|
||||
def test_change_password(self):
|
||||
with self.db as db:
|
||||
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||
db.login('testuser', 'supersecurepassword')
|
||||
# Normal password change should succeed
|
||||
db.change_password(user, 'supersecurepassword', 'mynewpassword')
|
||||
with self.assertRaises(AuthenticationError):
|
||||
db.login('testuser', 'supersecurepassword')
|
||||
db.login('testuser', 'mynewpassword')
|
||||
with self.assertRaises(AuthenticationError):
|
||||
# Should fail due to wrong old password
|
||||
db.change_password(user, 'iforgotmypassword', 'mynewpassword')
|
||||
db.login('testuser', 'mynewpassword')
|
||||
# This should pass even though the old password is not known (admin password reset)
|
||||
db.change_password(user, '42', 'adminpasswordreset', verify_password=False)
|
||||
with self.assertRaises(AuthenticationError):
|
||||
db.login('testuser', 'mynewpassword')
|
||||
db.login('testuser', 'adminpasswordreset')
|
||||
|
||||
def test_change_touchkey(self):
|
||||
with self.db as db:
|
||||
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||
db.change_touchkey(user, 'supersecurepassword', '0123')
|
||||
db.login('testuser', touchkey='0123')
|
||||
# Normal touchkey change should succeed
|
||||
db.change_touchkey(user, 'supersecurepassword', touchkey='4567')
|
||||
with self.assertRaises(AuthenticationError):
|
||||
db.login('testuser', touchkey='0123')
|
||||
db.login('testuser', touchkey='4567')
|
||||
with self.assertRaises(AuthenticationError):
|
||||
# Should fail due to wrong old password
|
||||
db.change_touchkey(user, 'iforgotmypassword', '89ab')
|
||||
db.login('testuser', touchkey='4567')
|
||||
# This should pass even though the old password is not known (admin password reset)
|
||||
db.change_touchkey(user, '42', '89ab', verify_password=False)
|
||||
with self.assertRaises(AuthenticationError):
|
||||
db.login('testuser', touchkey='4567')
|
||||
db.login('testuser', touchkey='89ab')
|
|
@ -1,20 +1,20 @@
|
|||
|
||||
import unittest
|
||||
|
||||
from matemat.db import Database
|
||||
from matemat.db import DatabaseWrapper
|
||||
|
||||
|
||||
class DatabaseTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.db = Database(':memory:')
|
||||
self.db = DatabaseWrapper(':memory:')
|
||||
|
||||
def test_create_schema(self):
|
||||
"""
|
||||
Test creation of database schema in an empty database
|
||||
"""
|
||||
with self.db as db:
|
||||
self.assertEqual(Database.SCHEMA_VERSION, db._user_version)
|
||||
self.assertEqual(DatabaseWrapper.SCHEMA_VERSION, db._user_version)
|
||||
|
||||
def test_in_transaction(self):
|
||||
"""
|
||||
|
@ -74,22 +74,3 @@ class DatabaseTest(unittest.TestCase):
|
|||
c = db._sqlite_db.cursor()
|
||||
c.execute("SELECT * FROM users")
|
||||
self.assertIsNone(c.fetchone())
|
||||
|
||||
def test_create_user(self):
|
||||
with self.db as db:
|
||||
with db.transaction() as c:
|
||||
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||
c.execute("SELECT * FROM users")
|
||||
row = c.fetchone()
|
||||
if row is None:
|
||||
self.fail()
|
||||
self.assertEqual('testuser', row[1])
|
||||
self.assertEqual('testuser@example.com', row[2])
|
||||
self.assertEqual(0, row[5])
|
||||
self.assertEqual(1, row[6])
|
||||
|
||||
def test_create_existing_user(self):
|
||||
with self.db as db:
|
||||
db.create_user('testuser', 'supersecurepassword', 'testuser@example.com')
|
||||
with self.assertRaises(ValueError):
|
||||
db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com')
|
129
matemat/db/wrapper.py
Normal file
129
matemat/db/wrapper.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
|
||||
import apsw
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
|
||||
def __init__(self, db: apsw.Connection, wrapper: 'DatabaseWrapper', exclusive: bool = True):
|
||||
self._db: apsw.Connection = db
|
||||
self._cursor = None
|
||||
self._excl = exclusive
|
||||
self._wrapper: DatabaseWrapper = wrapper
|
||||
self._is_dummy: bool = False
|
||||
|
||||
def __enter__(self):
|
||||
if self._wrapper._in_transaction:
|
||||
self._is_dummy = True
|
||||
return self._db.cursor()
|
||||
else:
|
||||
self._is_dummy = False
|
||||
self._cursor = self._db.cursor()
|
||||
self._wrapper._in_transaction = True
|
||||
if self._excl:
|
||||
self._cursor.execute('BEGIN EXCLUSIVE')
|
||||
else:
|
||||
self._cursor.execute('BEGIN')
|
||||
return self._cursor
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._is_dummy:
|
||||
return
|
||||
if exc_type is None:
|
||||
self._cursor.execute('COMMIT')
|
||||
else:
|
||||
self._cursor.execute('ROLLBACK')
|
||||
self._wrapper._in_transaction = False
|
||||
|
||||
|
||||
class DatabaseWrapper(object):
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
SCHEMA = '''
|
||||
CREATE TABLE users (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
email TEXT DEFAULT NULL,
|
||||
password TEXT NOT NULL,
|
||||
touchkey TEXT DEFAULT NULL,
|
||||
is_admin INTEGER(1) NOT NULL DEFAULT 0,
|
||||
is_member INTEGER(1) NOT NULL DEFAULT 1,
|
||||
balance INTEGER(8) NOT NULL DEFAULT 0,
|
||||
lastchange INTEGER(8) NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE products (
|
||||
product_id INTEGER PRIMARY KEY,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
stock INTEGER(8) NOT NULL DEFAULT 0,
|
||||
price_member INTEGER(8) NOT NULL,
|
||||
price_non_member INTEGER(8) NOT NULL
|
||||
);
|
||||
CREATE TABLE consumption (
|
||||
user_id INTEGER NOT NULL,
|
||||
product_id INTEGER NOT NULL,
|
||||
count INTEGER(8) NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (user_id, product_id),
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE
|
||||
)
|
||||
'''
|
||||
|
||||
def __init__(self, filename: str):
|
||||
self._filename: str = filename
|
||||
self._sqlite_db: apsw.Connection = None
|
||||
self._in_transaction: bool = False
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def transaction(self, exclusive: bool = True) -> Transaction:
|
||||
return Transaction(self._sqlite_db, self, exclusive)
|
||||
|
||||
def _setup(self):
|
||||
with self.transaction() as c:
|
||||
version: int = self._user_version
|
||||
if version < 1:
|
||||
c.execute(self.SCHEMA)
|
||||
elif version < self.SCHEMA_VERSION:
|
||||
self._upgrade(old=version, new=self.SCHEMA_VERSION)
|
||||
self._user_version = self.SCHEMA_VERSION
|
||||
|
||||
def _upgrade(self, old: int, new: int):
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected():
|
||||
raise RuntimeError(f'Database connection to {self._filename} is already established.')
|
||||
self._sqlite_db = apsw.Connection(self._filename)
|
||||
self._setup()
|
||||
|
||||
def close(self):
|
||||
if self._sqlite_db is None:
|
||||
raise RuntimeError(f'Database connection to {self._filename} is not established.')
|
||||
if self.in_transaction():
|
||||
raise RuntimeError(f'A transaction is still ongoing.')
|
||||
self._sqlite_db.close()
|
||||
self._sqlite_db = None
|
||||
|
||||
def in_transaction(self) -> bool:
|
||||
return self._in_transaction
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self._sqlite_db is not None
|
||||
|
||||
@property
|
||||
def _user_version(self) -> int:
|
||||
cursor = self._sqlite_db.cursor()
|
||||
cursor.execute('PRAGMA user_version')
|
||||
version = int(cursor.fetchone()[0])
|
||||
return version
|
||||
|
||||
@_user_version.setter
|
||||
def _user_version(self, version: int):
|
||||
cursor = self._sqlite_db.cursor()
|
||||
cursor.execute(f'PRAGMA user_version = {version}')
|
|
@ -1,11 +1,11 @@
|
|||
|
||||
class AuthenticationException(BaseException):
|
||||
class AuthenticationError(BaseException):
|
||||
|
||||
def __init__(self, msg: str = None):
|
||||
self._msg = msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'AuthenticationException: {self._msg}'
|
||||
return f'AuthenticationErro: {self._msg}'
|
||||
|
||||
@property
|
||||
def msg(self) -> str:
|
12
matemat/exceptions/DatabaseConsistencyError.py
Normal file
12
matemat/exceptions/DatabaseConsistencyError.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
|
||||
class DatabaseConsistencyError(BaseException):
|
||||
|
||||
def __init__(self, msg: str = None):
|
||||
self._msg = msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'DatabaseConsistencyError: {self._msg}'
|
||||
|
||||
@property
|
||||
def msg(self) -> str:
|
||||
return self._msg
|
|
@ -1,2 +1,3 @@
|
|||
|
||||
from .AuthenticatonException import AuthenticationException
|
||||
from .AuthenticatonError import AuthenticationError
|
||||
from .DatabaseConsistencyError import DatabaseConsistencyError
|
||||
|
|
Loading…
Reference in a new issue