forked from s3lph/matemat
135 lines
4.3 KiB
Python
135 lines
4.3 KiB
Python
|
|
from typing import Any
|
|
|
|
import apsw
|
|
|
|
from matemat.exceptions import DatabaseConsistencyError
|
|
|
|
|
|
class Transaction(object):
    """Context manager for a (possibly nested) transaction on an apsw connection.

    Entering the outermost ``Transaction`` of a ``DatabaseWrapper`` issues
    ``BEGIN EXCLUSIVE`` (or a plain ``BEGIN`` when ``exclusive`` is False) and
    yields the cursor that issued it.  Leaving it issues ``COMMIT`` when the
    ``with`` body completed normally, or ``ROLLBACK`` when the body raised.

    While the wrapper already has a transaction open, entering another
    ``Transaction`` only yields a fresh cursor: the enclosing transaction alone
    decides whether the work is committed or rolled back.

    An ``apsw.ConstraintError`` escaping the body, or raised by ``COMMIT``, is
    re-raised as ``DatabaseConsistencyError`` after the rollback.
    """

    def __init__(self, db: 'apsw.Connection', wrapper: 'DatabaseWrapper', exclusive: bool = True) -> None:
        """
        :param db: Connection the transaction runs on.
        :param wrapper: Owner of ``db``; its ``_in_transaction`` flag records whether a
            non-nested transaction is currently open.
        :param exclusive: Issue ``BEGIN EXCLUSIVE`` instead of a plain ``BEGIN``.
        """
        self._db: apsw.Connection = db
        # Cursor that issued BEGIN; only set for a non-nested transaction.
        self._cursor = None
        self._excl = exclusive
        self._wrapper: DatabaseWrapper = wrapper
        # True when this instance is nested inside an already-open transaction.
        self._is_dummy: bool = False

    def __enter__(self) -> Any:
        """Begin the transaction (unless nested) and return a cursor for it."""
        if self._wrapper._in_transaction:
            # Nested use: piggy-back on the transaction that is already open.
            self._is_dummy = True
            return self._db.cursor()
        self._is_dummy = False
        self._cursor = self._db.cursor()
        self._cursor.execute('BEGIN EXCLUSIVE' if self._excl else 'BEGIN')
        # Only flag the wrapper once BEGIN has succeeded: if BEGIN raises,
        # __exit__ is never called and the flag would otherwise stay set forever.
        self._wrapper._in_transaction = True
        return self._cursor

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Commit on success, roll back on error, and always clear the wrapper flag.

        :raises DatabaseConsistencyError: if the body or the COMMIT raised
            ``apsw.ConstraintError``.
        """
        if self._is_dummy:
            # The enclosing (real) transaction decides about COMMIT/ROLLBACK.
            return
        try:
            if exc_type is None:
                try:
                    self._cursor.execute('COMMIT')
                except apsw.Error as commit_error:
                    # A failed COMMIT can leave the transaction open; close it so the
                    # connection is usable for the next BEGIN.
                    self._rollback_quietly()
                    if isinstance(commit_error, apsw.ConstraintError):
                        raise DatabaseConsistencyError(str(commit_error)) from commit_error
                    raise
            else:
                self._rollback_quietly()
        finally:
            self._wrapper._in_transaction = False
        if exc_type is not None and issubclass(exc_type, apsw.ConstraintError):
            raise DatabaseConsistencyError(str(exc_val)) from exc_val

    def _rollback_quietly(self) -> None:
        """Issue ROLLBACK, ignoring errors from the ROLLBACK itself.

        Only called while another exception is already propagating.  SQLite may
        have rolled the transaction back on its own for some errors, in which case
        ROLLBACK fails; that secondary error must not mask the original one.
        """
        try:
            self._cursor.execute('ROLLBACK')
        except apsw.Error:
            pass
|
|
|
|
|
|
class DatabaseWrapper(object):
    """Owns the apsw connection to a matemat SQLite database and its schema.

    Usable as a context manager: ``connect()`` on entry, ``close()`` on exit.
    ``transaction()`` returns a ``Transaction`` context manager for this connection.
    """

    # Schema version stored in ``PRAGMA user_version`` once setup has completed.
    SCHEMA_VERSION = 1

    # DDL executed on a database whose user_version is < 1.
    # NOTE(review): SQLite does not enforce FOREIGN KEY clauses (including the
    # ON DELETE/UPDATE CASCADE actions below) unless ``PRAGMA foreign_keys = ON``
    # is set on the connection; nothing visible here sets it — confirm intent.
    SCHEMA = '''
    CREATE TABLE users (
        user_id INTEGER PRIMARY KEY,
        username TEXT NOT NULL,
        email TEXT DEFAULT NULL,
        password TEXT NOT NULL,
        touchkey TEXT DEFAULT NULL,
        is_admin INTEGER(1) NOT NULL DEFAULT 0,
        is_member INTEGER(1) NOT NULL DEFAULT 1,
        balance INTEGER(8) NOT NULL DEFAULT 0,
        lastchange INTEGER(8) NOT NULL DEFAULT 0
    );
    CREATE TABLE products (
        product_id INTEGER PRIMARY KEY,
        name TEXT UNIQUE NOT NULL,
        stock INTEGER(8) NOT NULL DEFAULT 0,
        price_member INTEGER(8) NOT NULL,
        price_non_member INTEGER(8) NOT NULL
    );
    CREATE TABLE consumption (
        user_id INTEGER NOT NULL,
        product_id INTEGER NOT NULL,
        count INTEGER(8) NOT NULL DEFAULT 0,
        PRIMARY KEY (user_id, product_id),
        FOREIGN KEY (user_id) REFERENCES users(user_id)
            ON DELETE CASCADE ON UPDATE CASCADE,
        FOREIGN KEY (product_id) REFERENCES products(product_id)
            ON DELETE CASCADE ON UPDATE CASCADE
    )
    '''

    def __init__(self, filename: str) -> None:
        """
        :param filename: Path of the SQLite database file; nothing is opened yet.
        """
        self._filename: str = filename
        # None while disconnected.
        self._sqlite_db: apsw.Connection = None
        # Set/cleared by Transaction while a non-nested transaction is open.
        self._in_transaction: bool = False

    def __enter__(self) -> 'DatabaseWrapper':
        self.connect()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.close()

    def transaction(self, exclusive: bool = True) -> 'Transaction':
        """Return a transaction context manager bound to this connection.

        :param exclusive: Use ``BEGIN EXCLUSIVE`` rather than a plain ``BEGIN``.
        :raises RuntimeError: if the database is not connected.
        """
        if not self.is_connected():
            raise RuntimeError(f'Database connection to {self._filename} is not established.')
        return Transaction(self._sqlite_db, self, exclusive)

    def _setup(self) -> None:
        """Create or upgrade the schema inside one exclusive transaction.

        :raises RuntimeError: if the database's schema version is newer than
            ``SCHEMA_VERSION`` (it is left untouched in that case).
        """
        with self.transaction() as c:
            version: int = self._user_version
            if version > self.SCHEMA_VERSION:
                # Writing SCHEMA_VERSION back would silently "downgrade" a database
                # created by a newer release without migrating any data.
                raise RuntimeError(
                    f'Database {self._filename} has schema version {version}, '
                    f'but this version supports at most {self.SCHEMA_VERSION}.')
            if version < 1:
                c.execute(self.SCHEMA)
            elif version < self.SCHEMA_VERSION:
                self._upgrade(old=version, new=self.SCHEMA_VERSION)
            self._user_version = self.SCHEMA_VERSION

    def _upgrade(self, old: int, new: int) -> None:
        """Migrate the schema from version ``old`` to ``new``.

        No migrations exist yet (``SCHEMA_VERSION`` is 1), so this is a no-op.
        """
        pass

    def connect(self) -> None:
        """Open the database file and make sure its schema is current.

        On failure the connection is closed again, so the wrapper stays disconnected.

        :raises RuntimeError: if already connected, or if the schema is too new.
        """
        if self.is_connected():
            raise RuntimeError(f'Database connection to {self._filename} is already established.')
        self._sqlite_db = apsw.Connection(self._filename)
        try:
            self._setup()
        except BaseException:
            # Don't leave a half-initialised connection behind.
            db, self._sqlite_db = self._sqlite_db, None
            self._in_transaction = False
            db.close()
            raise

    def close(self) -> None:
        """Close the connection.

        :raises RuntimeError: if not connected, or while a transaction is open.
        """
        if not self.is_connected():
            raise RuntimeError(f'Database connection to {self._filename} is not established.')
        if self.in_transaction():
            raise RuntimeError('A transaction is still ongoing.')
        self._sqlite_db.close()
        self._sqlite_db = None

    def in_transaction(self) -> bool:
        """Return whether a non-nested transaction is currently open."""
        return self._in_transaction

    def is_connected(self) -> bool:
        """Return whether a connection is currently open."""
        return self._sqlite_db is not None

    @property
    def _user_version(self) -> int:
        """Schema version stored in ``PRAGMA user_version``."""
        cursor = self._sqlite_db.cursor()
        cursor.execute('PRAGMA user_version')
        return int(cursor.fetchone()[0])

    @_user_version.setter
    def _user_version(self, version: int) -> None:
        # PRAGMA values are not bindable SQL parameters, so the value is
        # interpolated; int() guarantees only an integer literal reaches the SQL.
        cursor = self._sqlite_db.cursor()
        cursor.execute(f'PRAGMA user_version = {int(version)}')
|