Merge branch 'staging' into 'master'

Receipt generation

See merge request s3lph/matemat!43
This commit is contained in:
s3lph 2018-09-09 21:08:43 +00:00
commit 3995245127
30 changed files with 1307 additions and 64 deletions

2
doc

@ -1 +1 @@
Subproject commit 0cf3d59c8b37f84e915f5e30e7447f0611cc1238 Subproject commit 411880ae72b3a2204fed4b945bdb3a15d3ece364

View file

@ -4,11 +4,14 @@ from typing import List, Optional, Any, Type
import crypt import crypt
from hmac import compare_digest from hmac import compare_digest
from datetime import datetime
import logging
from matemat.db.primitives import User, Product from matemat.db.primitives import User, Product, ReceiptPreference, Receipt,\
Transaction, Consumption, Deposit, Modification
from matemat.exceptions import AuthenticationError, DatabaseConsistencyError from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
from matemat.db import DatabaseWrapper from matemat.db import DatabaseWrapper
from matemat.db.wrapper import Transaction from matemat.db.wrapper import DatabaseTransaction
class MatematDatabase(object): class MatematDatabase(object):
@ -46,7 +49,7 @@ class MatematDatabase(object):
# Pass context manager stuff through to the database wrapper # Pass context manager stuff through to the database wrapper
self.db.__exit__(exc_type, exc_val, exc_tb) self.db.__exit__(exc_type, exc_val, exc_tb)
def transaction(self, exclusive: bool = True) -> Transaction: def transaction(self, exclusive: bool = True) -> DatabaseTransaction:
""" """
Begin a new SQLite3 transaction (exclusive by default). You should never need to use the returned object (a Begin a new SQLite3 transaction (exclusive by default). You should never need to use the returned object (a
Transaction instance). It is provided in case there is a real need for it (e.g. for unit testing). Transaction instance). It is provided in case there is a real need for it (e.g. for unit testing).
@ -83,15 +86,19 @@ class MatematDatabase(object):
users: List[User] = [] users: List[User] = []
with self.db.transaction(exclusive=False) as c: with self.db.transaction(exclusive=False) as c:
for row in c.execute(''' for row in c.execute('''
SELECT user_id, username, email, is_admin, is_member, balance SELECT user_id, username, email, is_admin, is_member, balance, receipt_pref
FROM users FROM users
WHERE touchkey IS NOT NULL OR NOT :must_have_touchkey WHERE touchkey IS NOT NULL OR NOT :must_have_touchkey
''', { ''', {
'must_have_touchkey': with_touchkey 'must_have_touchkey': with_touchkey
}): }):
# Decompose each row and put the values into a User object # Decompose each row and put the values into a User object
user_id, username, email, is_admin, is_member, balance = row user_id, username, email, is_admin, is_member, balance, receipt_p = row
users.append(User(user_id, username, balance, email, is_admin, is_member)) try:
receipt_pref: ReceiptPreference = ReceiptPreference(receipt_p)
except ValueError:
raise DatabaseConsistencyError(f'{receipt_p} is not a valid ReceiptPreference')
users.append(User(user_id, username, balance, email, is_admin, is_member, receipt_pref))
return users return users
def get_user(self, uid: int) -> User: def get_user(self, uid: int) -> User:
@ -101,14 +108,22 @@ class MatematDatabase(object):
""" """
with self.db.transaction(exclusive=False) as c: with self.db.transaction(exclusive=False) as c:
# Fetch all values to construct the user # Fetch all values to construct the user
c.execute('SELECT user_id, username, email, is_admin, is_member, balance FROM users WHERE user_id = ?', c.execute('''
SELECT user_id, username, email, is_admin, is_member, balance, receipt_pref
FROM users
WHERE user_id = ?
''',
[uid]) [uid])
row = c.fetchone() row = c.fetchone()
if row is None: if row is None:
raise ValueError(f'No user with user ID {uid} exists.') raise ValueError(f'No user with user ID {uid} exists.')
# Unpack the row and construct the user # Unpack the row and construct the user
user_id, username, email, is_admin, is_member, balance = row user_id, username, email, is_admin, is_member, balance, receipt_p = row
return User(user_id, username, balance, email, is_admin, is_member) try:
receipt_pref: ReceiptPreference = ReceiptPreference(receipt_p)
except ValueError:
raise DatabaseConsistencyError(f'{receipt_p} is not a valid ReceiptPreference')
return User(user_id, username, balance, email, is_admin, is_member, receipt_pref)
def create_user(self, def create_user(self,
username: str, username: str,
@ -136,8 +151,8 @@ class MatematDatabase(object):
raise ValueError(f'A user with the name \'{username}\' already exists.') raise ValueError(f'A user with the name \'{username}\' already exists.')
# Insert the user into the database. # Insert the user into the database.
c.execute(''' c.execute('''
INSERT INTO users (username, email, password, balance, is_admin, is_member, lastchange) INSERT INTO users (username, email, password, balance, is_admin, is_member, lastchange, created)
VALUES (:username, :email, :pwhash, 0, :admin, :member, STRFTIME('%s', 'now')) VALUES (:username, :email, :pwhash, 0, :admin, :member, STRFTIME('%s', 'now'), STRFTIME('%s', 'now'))
''', { ''', {
'username': username, 'username': username,
'email': email, 'email': email,
@ -165,14 +180,14 @@ class MatematDatabase(object):
raise ValueError('Exactly one of password and touchkey must be provided') raise ValueError('Exactly one of password and touchkey must be provided')
with self.db.transaction(exclusive=False) as c: with self.db.transaction(exclusive=False) as c:
c.execute(''' c.execute('''
SELECT user_id, username, email, password, touchkey, is_admin, is_member, balance SELECT user_id, username, email, password, touchkey, is_admin, is_member, balance, receipt_pref
FROM users FROM users
WHERE username = ? WHERE username = ?
''', [username]) ''', [username])
row = c.fetchone() row = c.fetchone()
if row is None: if row is None:
raise AuthenticationError('User does not exist') raise AuthenticationError('User does not exist')
user_id, username, email, pwhash, tkhash, admin, member, balance = row user_id, username, email, pwhash, tkhash, admin, member, balance, receipt_p = row
if password is not None and not compare_digest(crypt.crypt(password, pwhash), pwhash): if password is not None and not compare_digest(crypt.crypt(password, pwhash), pwhash):
raise AuthenticationError('Password mismatch') raise AuthenticationError('Password mismatch')
elif touchkey is not None \ elif touchkey is not None \
@ -181,7 +196,11 @@ class MatematDatabase(object):
raise AuthenticationError('Touchkey mismatch') raise AuthenticationError('Touchkey mismatch')
elif touchkey is not None and tkhash is None: elif touchkey is not None and tkhash is None:
raise AuthenticationError('Touchkey not set') raise AuthenticationError('Touchkey not set')
return User(user_id, username, balance, email, admin, member) try:
receipt_pref: ReceiptPreference = ReceiptPreference(receipt_p)
except ValueError:
raise DatabaseConsistencyError(f'{receipt_p} is not a valid ReceiptPreference')
return User(user_id, username, balance, email, admin, member, receipt_pref)
def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True) -> None: def change_password(self, user: User, oldpass: str, newpass: str, verify_password: bool = True) -> None:
""" """
@ -243,8 +262,7 @@ class MatematDatabase(object):
'tkhash': tkhash 'tkhash': tkhash
}) })
def change_user(self, user: User, agent: Optional[User], **kwargs)\ def change_user(self, user: User, agent: Optional[User], **kwargs) -> None:
-> None:
""" """
Commit changes to the user in the database. If writing the requested changes succeeded, the values are updated Commit changes to the user in the database. If writing the requested changes succeeded, the values are updated
in the provided user object. Otherwise the user object is left untouched. The user to update is identified by in the provided user object. Otherwise the user object is left untouched. The user to update is identified by
@ -258,9 +276,11 @@ class MatematDatabase(object):
# Resolve the values to change # Resolve the values to change
name: str = kwargs['name'] if 'name' in kwargs else user.name name: str = kwargs['name'] if 'name' in kwargs else user.name
email: str = kwargs['email'] if 'email' in kwargs else user.email email: str = kwargs['email'] if 'email' in kwargs else user.email
balance: int = kwargs['balance'] if 'balance' in kwargs else user.balance
is_admin: bool = kwargs['is_admin'] if 'is_admin' in kwargs else user.is_admin is_admin: bool = kwargs['is_admin'] if 'is_admin' in kwargs else user.is_admin
is_member: bool = kwargs['is_member'] if 'is_member' in kwargs else user.is_member is_member: bool = kwargs['is_member'] if 'is_member' in kwargs else user.is_member
balance: int = kwargs['balance'] if 'balance' in kwargs else user.balance
balance_reason: Optional[str] = kwargs['balance_reason'] if 'balance_reason' in kwargs else None
receipt_pref: ReceiptPreference = kwargs['receipt_pref'] if 'receipt_pref' in kwargs else user.receipt_pref
with self.db.transaction() as c: with self.db.transaction() as c:
c.execute('SELECT balance FROM users WHERE user_id = :user_id', {'user_id': user.id}) c.execute('SELECT balance FROM users WHERE user_id = :user_id', {'user_id': user.id})
row = c.fetchone() row = c.fetchone()
@ -278,11 +298,13 @@ class MatematDatabase(object):
'value': balance - oldbalance, 'value': balance - oldbalance,
'old_balance': oldbalance 'old_balance': oldbalance
}) })
# TODO: Implement reason field
c.execute(''' c.execute('''
INSERT INTO modifications (ta_id, agent_id, reason) INSERT INTO modifications (ta_id, agent, reason)
VALUES (last_insert_rowid(), :agent_id, NULL) VALUES (last_insert_rowid(), :agent, :reason)
''', {'agent_id': agent.id}) ''', {
'agent': agent.name,
'reason': balance_reason
})
c.execute(''' c.execute('''
UPDATE users SET UPDATE users SET
username = :username, username = :username,
@ -290,6 +312,7 @@ class MatematDatabase(object):
balance = :balance, balance = :balance,
is_admin = :is_admin, is_admin = :is_admin,
is_member = :is_member, is_member = :is_member,
receipt_pref = :receipt_pref,
lastchange = STRFTIME('%s', 'now') lastchange = STRFTIME('%s', 'now')
WHERE user_id = :user_id WHERE user_id = :user_id
''', { ''', {
@ -298,7 +321,8 @@ class MatematDatabase(object):
'email': email, 'email': email,
'balance': balance, 'balance': balance,
'is_admin': is_admin, 'is_admin': is_admin,
'is_member': is_member 'is_member': is_member,
'receipt_pref': receipt_pref.value
}) })
# Only update the actual user object after the changes in the database succeeded # Only update the actual user object after the changes in the database succeeded
user.name = name user.name = name
@ -306,6 +330,7 @@ class MatematDatabase(object):
user.balance = balance user.balance = balance
user.is_admin = is_admin user.is_admin = is_admin
user.is_member = is_member user.is_member = is_member
user.receipt_pref = receipt_pref
def delete_user(self, user: User) -> None: def delete_user(self, user: User) -> None:
""" """
@ -460,10 +485,10 @@ class MatematDatabase(object):
'old_balance': user.balance 'old_balance': user.balance
}) })
c.execute(''' c.execute('''
INSERT INTO consumptions (ta_id, product_id) INSERT INTO consumptions (ta_id, product)
VALUES (last_insert_rowid(), :product_id) VALUES (last_insert_rowid(), :product)
''', { ''', {
'product_id': product.id 'product': product.name
}) })
# Subtract the price from the user's account balance. # Subtract the price from the user's account balance.
c.execute(''' c.execute('''
@ -491,6 +516,9 @@ class MatematDatabase(object):
if affected != 1: if affected != 1:
raise DatabaseConsistencyError( raise DatabaseConsistencyError(
f'increment_consumption should affect 1 products row, but affected {affected}') f'increment_consumption should affect 1 products row, but affected {affected}')
# Reflect the change in the user and product objects
user.balance -= price
product.stock -= 1
def restock(self, product: Product, count: int) -> None: def restock(self, product: Product, count: int) -> None:
""" """
@ -511,6 +539,8 @@ class MatematDatabase(object):
affected = c.execute('SELECT changes()').fetchone()[0] affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1: if affected != 1:
raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}') raise DatabaseConsistencyError(f'restock should affect 1 products row, but affected {affected}')
# Reflect the change in the product object
product.stock += count
def deposit(self, user: User, amount: int) -> None: def deposit(self, user: User, amount: int) -> None:
""" """
@ -523,13 +553,19 @@ class MatematDatabase(object):
if amount < 0: if amount < 0:
raise ValueError('Cannot deposit a negative value') raise ValueError('Cannot deposit a negative value')
with self.db.transaction() as c: with self.db.transaction() as c:
c.execute('''SELECT balance FROM users WHERE user_id = :user_id''',
[user.id])
row = c.fetchone()
if row is None:
raise DatabaseConsistencyError(f'No such user: {user.id}')
old_balance: int = row[0]
c.execute(''' c.execute('''
INSERT INTO transactions (user_id, value, old_balance) INSERT INTO transactions (user_id, value, old_balance)
VALUES (:user_id, :value, :old_balance) VALUES (:user_id, :value, :old_balance)
''', { ''', {
'user_id': user.id, 'user_id': user.id,
'value': amount, 'value': amount,
'old_balance': user.balance 'old_balance': old_balance
}) })
c.execute(''' c.execute('''
INSERT INTO deposits (ta_id) INSERT INTO deposits (ta_id)
@ -546,3 +582,84 @@ class MatematDatabase(object):
affected = c.execute('SELECT changes()').fetchone()[0] affected = c.execute('SELECT changes()').fetchone()[0]
if affected != 1: if affected != 1:
raise DatabaseConsistencyError(f'deposit should affect 1 users row, but affected {affected}') raise DatabaseConsistencyError(f'deposit should affect 1 users row, but affected {affected}')
# Reflect the change in the user object
user.balance = old_balance + amount
def check_receipt_due(self, user: User) -> bool:
if not isinstance(user.receipt_pref, ReceiptPreference):
raise TypeError()
if user.receipt_pref == ReceiptPreference.NONE or user.email is None:
return False
with self.db.transaction() as c:
c.execute('''
SELECT COALESCE(MAX(r.date), u.created)
FROM users AS u
LEFT JOIN receipts AS r
ON r.user_id = u.user_id
WHERE u.user_id = :user_id
''', [user.id])
last_receipt: datetime = datetime.fromtimestamp(c.fetchone()[0])
next_receipt_due: datetime = user.receipt_pref.next_receipt_due(last_receipt)
return datetime.utcnow() > next_receipt_due
def create_receipt(self, user: User, write: bool = False) -> Receipt:
transactions: List[Transaction] = []
with self.db.transaction() as cursor:
cursor.execute('''
SELECT COALESCE(MAX(r.date), u.created), COALESCE(MAX(r.last_ta_id), 0)
FROM users AS u
LEFT JOIN receipts AS r
ON r.user_id = u.user_id
WHERE u.user_id = :user_id
''', [user.id])
row = cursor.fetchone()
if row is None:
raise DatabaseConsistencyError(f'No such user: {user.id}')
fromdate, min_id = row
created: datetime = datetime.fromtimestamp(fromdate)
cursor.execute('''
SELECT
t.ta_id, t.value, t.old_balance, COALESCE(t.date, 0),
c.ta_id, d.ta_id, m.ta_id, c.product, m.agent, m.reason
FROM transactions AS t
LEFT JOIN consumptions AS c
ON t.ta_id = c.ta_id
LEFT JOIN deposits AS d
ON t.ta_id = d.ta_id
LEFT JOIN modifications AS m
ON t.ta_id = m.ta_id
WHERE t.user_id = :user_id
AND t.ta_id > :min_id
ORDER BY t.date ASC
''', {
'user_id': user.id,
'min_id': min_id
})
rows = cursor.fetchall()
for row in rows:
ta_id, value, old_balance, date, c, d, m, c_prod, m_agent, m_reason = row
if c == ta_id:
t: Transaction = Consumption(ta_id, user, value, old_balance, datetime.fromtimestamp(date), c_prod)
elif d == ta_id:
t = Deposit(ta_id, user, value, old_balance, datetime.fromtimestamp(date))
elif m == ta_id:
t = Modification(ta_id, user, value, old_balance, datetime.fromtimestamp(date), m_agent, m_reason)
else:
t = Transaction(ta_id, user, value, old_balance, datetime.fromtimestamp(date))
transactions.append(t)
if write:
cursor.execute('''
INSERT INTO receipts (user_id, first_ta_id, last_ta_id)
VALUES (:user_id, :first_ta, :last_ta)
''', {
'user_id': user.id,
'first_ta': transactions[0].id,
'last_ta': transactions[-1].id
})
cursor.execute('''SELECT last_insert_rowid()''')
receipt_id: int = int(cursor.fetchone()[0])
else:
receipt_id = -1
receipt = Receipt(receipt_id, transactions, user, created, datetime.utcnow())
return receipt

View file

@ -113,3 +113,91 @@ def migrate_schema_1_to_2(c: sqlite3.Cursor):
ta_id -= 1 ta_id -= 1
# Drop the old consumption table # Drop the old consumption table
c.execute('DROP TABLE consumption') c.execute('DROP TABLE consumption')
def migrate_schema_2_to_3(c: sqlite3.Cursor):
# Add missing columns to users table
c.execute('ALTER TABLE users ADD COLUMN receipt_pref INTEGER(1) NOT NULL DEFAULT 0')
c.execute('''ALTER TABLE users ADD COLUMN created INTEGER(8) NOT NULL DEFAULT 0''')
# Guess creation date based on the oldest entry in the database related to the user ( -1 minute for good measure)
c.execute('''
UPDATE users
SET created = COALESCE(
(SELECT MIN(t.date)
FROM transactions AS t
WHERE t.user_id = users.user_id),
lastchange) - 60
''')
# Fix ON DELETE in transactions table
c.execute('''
CREATE TABLE transactions_new (
ta_id INTEGER PRIMARY KEY,
user_id INTEGER DEFAULT NULL,
value INTEGER(8) NOT NULL,
old_balance INTEGER(8) NOT NULL,
date INTEGER(8) DEFAULT (STRFTIME('%s', 'now')),
FOREIGN KEY (user_id) REFERENCES users(user_id)
ON DELETE SET NULL ON UPDATE CASCADE
)
''')
c.execute('INSERT INTO transactions_new SELECT * FROM transactions')
c.execute('DROP TABLE transactions')
c.execute('ALTER TABLE transactions_new RENAME TO transactions')
# Change consumptions table
c.execute('''
CREATE TABLE consumptions_new (
ta_id INTEGER PRIMARY KEY,
product TEXT NOT NULL,
FOREIGN KEY (ta_id) REFERENCES transactions(ta_id)
ON DELETE CASCADE ON UPDATE CASCADE
)
''')
c.execute('''
INSERT INTO consumptions_new (ta_id, product)
SELECT c.ta_id, COALESCE(p.name, '<unknown>')
FROM consumptions as c
LEFT JOIN products as p
ON c.product_id = p.product_id
''')
c.execute('DROP TABLE consumptions')
c.execute('ALTER TABLE consumptions_new RENAME TO consumptions')
# Change modifications table
c.execute('''
CREATE TABLE modifications_new (
ta_id INTEGER NOT NULL,
agent TEXT NOT NULL,
reason TEXT DEFAULT NULL,
PRIMARY KEY (ta_id),
FOREIGN KEY (ta_id) REFERENCES transactions(ta_id)
ON DELETE CASCADE ON UPDATE CASCADE
)
''')
c.execute('''
INSERT INTO modifications_new (ta_id, agent, reason)
SELECT m.ta_id, COALESCE(u.username, '<unknown>'), m.reason
FROM modifications as m
LEFT JOIN users as u
ON u.user_id = m.agent_id
''')
c.execute('DROP TABLE modifications')
c.execute('ALTER TABLE modifications_new RENAME TO modifications')
# Create missing table
c.execute('''
CREATE TABLE receipts ( -- receipts sent to the users
receipt_id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
first_ta_id INTEGER NOT NULL,
last_ta_id INTEGER NOT NULL,
date INTEGER(8) DEFAULT (STRFTIME('%s', 'now')),
FOREIGN KEY (user_id) REFERENCES users(user_id)
ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (first_ta_id) REFERENCES transactions(ta_id)
ON DELETE SET NULL ON UPDATE CASCADE,
FOREIGN KEY (last_ta_id) REFERENCES transactions(ta_id)
ON DELETE SET NULL ON UPDATE CASCADE
)
''')

View file

@ -0,0 +1,17 @@
from typing import List
from dataclasses import dataclass
from datetime import datetime
from matemat.db.primitives import User, Transaction
@dataclass
class Receipt:
id: int
transactions: List[Transaction]
user: User
from_date: datetime
to_date: datetime

View file

@ -0,0 +1,62 @@
from typing import Callable
from enum import Enum
from datetime import datetime, timedelta
from matemat.util.monthdelta import add_months
class ReceiptPreference(Enum):
"""
A user's preference for the frequency of receiving receipts.
"""
def __new__(cls, *args, **kwargs):
e = object.__new__(cls)
# The enum's internal value
e._value_: int = args[0]
# The function calculating the date after which a new receipt is due.
e._datefunc: Callable[[datetime], datetime] = args[1]
# The human-readable description
e._human_readable: str = args[2]
return e
@property
def human_readable(self) -> str:
"""
A human-readable description of the receipt preference, to be displayed in the UI.
"""
return self._human_readable
def next_receipt_due(self, d: datetime) -> datetime:
return self._datefunc(d)
"""
No receipts should be generated.
"""
NONE = 0, (lambda d: None), 'No receipts'
"""
A receipt should be generated once a week.
"""
WEEKLY = 1, (lambda d: d + timedelta(weeks=1)), 'Weekly'
"""
A receipt should be generated once a month.
"""
MONTHLY = 2, (lambda d: add_months(d, 1)), 'Monthly'
"""
A receipt should be generated once every three month.
"""
QUARTERLY = 3, (lambda d: add_months(d, 3)), 'Quarterly'
"""
A receipt should be generated once every six month.
"""
BIANNUALLY = 4, (lambda d: add_months(d, 6)), 'Biannually'
"""
A receipt should be generated once a year.
"""
YEARLY = 5, (lambda d: add_months(d, 12)), 'Annually'

View file

@ -0,0 +1,74 @@
from typing import Optional
from dataclasses import dataclass
from datetime import datetime
from matemat.db.primitives import User
from matemat.util.currency_format import format_chf
@dataclass(frozen=True)
class Transaction:
id: int
user: User
value: int
old_balance: int
date: datetime
@property
def receipt_date(self) -> str:
if self.date == datetime.fromtimestamp(0):
return '<unknown> '
date: str = self.date.strftime('%d.%m.%Y, %H:%M')
return date
@property
def receipt_value(self) -> str:
value: str = format_chf(self.value, with_currencysign=False, plus_sign=True).rjust(8)
return value
@property
def receipt_description(self) -> str:
return 'Unidentified transaction'
@property
def receipt_message(self) -> Optional[str]:
return None
@dataclass(frozen=True)
class Consumption(Transaction):
product: str
@property
def receipt_description(self) -> str:
return self.product
@dataclass(frozen=True)
class Deposit(Transaction):
@property
def receipt_description(self) -> str:
return 'Deposit'
@dataclass(frozen=True)
class Modification(Transaction):
agent: str
reason: Optional[str]
@property
def receipt_description(self) -> str:
return f'Balance modified by {self.agent}'
@property
def receipt_message(self) -> Optional[str]:
if self.reason is None:
return None
else:
return f'Reason: «{self.reason}»'

View file

@ -2,6 +2,7 @@
from typing import Optional from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
from matemat.db.primitives.ReceiptPreference import ReceiptPreference
@dataclass @dataclass
@ -17,6 +18,7 @@ class User:
:param email: The user's e-mail address (optional). :param email: The user's e-mail address (optional).
:param admin: Whether the user is an administrator. :param admin: Whether the user is an administrator.
:param member: Whether the user is a member. :param member: Whether the user is a member.
:param receipt_pref: The user's preference on how often to receive transaction receipts.
""" """
id: int id: int
@ -25,3 +27,4 @@ class User:
email: Optional[str] = None email: Optional[str] = None
is_admin: bool = False is_admin: bool = False
is_member: bool = False is_member: bool = False
receipt_pref: ReceiptPreference = ReceiptPreference.NONE

View file

@ -4,3 +4,6 @@ This package provides the 'primitive types' the Matemat software deals with - na
from .User import User from .User import User
from .Product import Product from .Product import Product
from .ReceiptPreference import ReceiptPreference
from .Transaction import Transaction, Consumption, Deposit, Modification
from .Receipt import Receipt

View file

@ -101,3 +101,80 @@ SCHEMAS[2] = [
ON DELETE CASCADE ON UPDATE CASCADE ON DELETE CASCADE ON UPDATE CASCADE
); );
'''] ''']
SCHEMAS[3] = [
'''
CREATE TABLE users (
user_id INTEGER PRIMARY KEY,
username TEXT UNIQUE NOT NULL,
email TEXT DEFAULT NULL,
password TEXT NOT NULL,
touchkey TEXT DEFAULT NULL,
is_admin INTEGER(1) NOT NULL DEFAULT 0,
is_member INTEGER(1) NOT NULL DEFAULT 1,
balance INTEGER(8) NOT NULL DEFAULT 0,
lastchange INTEGER(8) NOT NULL DEFAULT 0,
receipt_pref INTEGER(1) NOT NULL DEFAULT 0,
created INTEGER(8) NOT NULL DEFAULT 0
);
''',
'''
CREATE TABLE products (
product_id INTEGER PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
stock INTEGER(8) NOT NULL DEFAULT 0,
price_member INTEGER(8) NOT NULL,
price_non_member INTEGER(8) NOT NULL
);
''',
'''
CREATE TABLE transactions ( -- "superclass" of the following 3 tables
ta_id INTEGER PRIMARY KEY,
user_id INTEGER DEFAULT NULL,
value INTEGER(8) NOT NULL,
old_balance INTEGER(8) NOT NULL,
date INTEGER(8) DEFAULT (STRFTIME('%s', 'now')),
FOREIGN KEY (user_id) REFERENCES users(user_id)
ON DELETE SET NULL ON UPDATE CASCADE
);
''',
'''
CREATE TABLE consumptions ( -- transactions involving buying a product
ta_id INTEGER PRIMARY KEY,
product TEXT NOT NULL,
FOREIGN KEY (ta_id) REFERENCES transactions(ta_id)
ON DELETE CASCADE ON UPDATE CASCADE
);
''',
'''
CREATE TABLE deposits ( -- transactions involving depositing cash
ta_id INTEGER PRIMARY KEY,
FOREIGN KEY (ta_id) REFERENCES transactions(ta_id)
ON DELETE CASCADE ON UPDATE CASCADE
);
''',
'''
CREATE TABLE modifications ( -- transactions involving balance modification by an admin
ta_id INTEGER NOT NULL,
agent TEXT NOT NULL,
reason TEXT DEFAULT NULL,
PRIMARY KEY (ta_id),
FOREIGN KEY (ta_id) REFERENCES transactions(ta_id)
ON DELETE CASCADE ON UPDATE CASCADE
);
''',
'''
CREATE TABLE receipts ( -- receipts sent to the users
receipt_id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
first_ta_id INTEGER NOT NULL,
last_ta_id INTEGER NOT NULL,
date INTEGER(8) DEFAULT (STRFTIME('%s', 'now')),
FOREIGN KEY (user_id) REFERENCES users(user_id)
ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (first_ta_id) REFERENCES transactions(ta_id)
ON DELETE SET NULL ON UPDATE CASCADE,
FOREIGN KEY (last_ta_id) REFERENCES transactions(ta_id)
ON DELETE SET NULL ON UPDATE CASCADE
);
''']

View file

@ -2,9 +2,11 @@
import unittest import unittest
import crypt import crypt
from datetime import datetime, timedelta
from matemat.db import MatematDatabase from matemat.db import MatematDatabase
from matemat.db.primitives import User from matemat.db.primitives import User, Product, ReceiptPreference, Receipt,\
Transaction, Modification, Deposit, Consumption
from matemat.exceptions import AuthenticationError, DatabaseConsistencyError from matemat.exceptions import AuthenticationError, DatabaseConsistencyError
@ -24,12 +26,13 @@ class DatabaseTest(unittest.TestCase):
self.assertEqual('testuser@example.com', row[2]) self.assertEqual('testuser@example.com', row[2])
self.assertEqual(0, row[5]) self.assertEqual(0, row[5])
self.assertEqual(1, row[6]) self.assertEqual(1, row[6])
self.assertEqual(ReceiptPreference.NONE.value, row[7])
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com') db.create_user('testuser', 'supersecurepassword2', 'testuser2@example.com')
def test_get_user(self) -> None: def test_get_user(self) -> None:
with self.db as db: with self.db as db:
with db.transaction(exclusive=False): with db.transaction() as c:
created = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', created = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com',
admin=True, member=False) admin=True, member=False)
user = db.get_user(created.id) user = db.get_user(created.id)
@ -37,8 +40,14 @@ class DatabaseTest(unittest.TestCase):
self.assertEqual('testuser@example.com', user.email) self.assertEqual('testuser@example.com', user.email)
self.assertEqual(False, user.is_member) self.assertEqual(False, user.is_member)
self.assertEqual(True, user.is_admin) self.assertEqual(True, user.is_admin)
self.assertEqual(ReceiptPreference.NONE, user.receipt_pref)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
db.get_user(-1) db.get_user(-1)
# Write an invalid receipt preference to the database
c.execute('UPDATE users SET receipt_pref = 42 WHERE user_id = ?',
[user.id])
with self.assertRaises(DatabaseConsistencyError):
db.get_user(user.id)
def test_list_users(self) -> None: def test_list_users(self) -> None:
with self.db as db: with self.db as db:
@ -49,7 +58,8 @@ class DatabaseTest(unittest.TestCase):
testuser: User = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True) testuser: User = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
db.change_touchkey(testuser, '', 'touchkey', verify_password=False) db.change_touchkey(testuser, '', 'touchkey', verify_password=False)
db.create_user('anothertestuser', 'otherpassword', 'anothertestuser@example.com', False, True) db.create_user('anothertestuser', 'otherpassword', 'anothertestuser@example.com', False, True)
db.create_user('yatu', 'igotapasswordtoo', 'yatu@example.com', False, False) u = db.create_user('yatu', 'igotapasswordtoo', 'yatu@example.com', False, False)
db.change_user(u, agent=None, receipt_pref=ReceiptPreference.WEEKLY)
users = db.list_users() users = db.list_users()
users_with_touchkey = db.list_users(with_touchkey=True) users_with_touchkey = db.list_users(with_touchkey=True)
self.assertEqual(3, len(users)) self.assertEqual(3, len(users))
@ -67,6 +77,7 @@ class DatabaseTest(unittest.TestCase):
self.assertEqual('yatu@example.com', user.email) self.assertEqual('yatu@example.com', user.email)
self.assertFalse(user.is_member) self.assertFalse(user.is_member)
self.assertFalse(user.is_admin) self.assertFalse(user.is_admin)
self.assertEqual(ReceiptPreference.WEEKLY, user.receipt_pref)
usercheck[user.id] = 1 usercheck[user.id] = 1
self.assertEqual(3, len(usercheck)) self.assertEqual(3, len(usercheck))
self.assertEqual(1, len(users_with_touchkey)) self.assertEqual(1, len(users_with_touchkey))
@ -91,7 +102,7 @@ class DatabaseTest(unittest.TestCase):
except AuthenticationError as e: except AuthenticationError as e:
self.assertEqual('Touchkey not set', e.msg) self.assertEqual('Touchkey not set', e.msg)
# Add a touchkey without using the provided function # Add a touchkey without using the provided function
c.execute('''UPDATE users SET touchkey = :tkhash WHERE user_id = :user_id''', { c.execute('''UPDATE users SET touchkey = :tkhash, receipt_pref = 2 WHERE user_id = :user_id''', {
'tkhash': crypt.crypt('0123', crypt.mksalt()), 'tkhash': crypt.crypt('0123', crypt.mksalt()),
'user_id': u.id 'user_id': u.id
}) })
@ -99,6 +110,7 @@ class DatabaseTest(unittest.TestCase):
self.assertEqual(u.id, user.id) self.assertEqual(u.id, user.id)
user = db.login('testuser', touchkey='0123') user = db.login('testuser', touchkey='0123')
self.assertEqual(u.id, user.id) self.assertEqual(u.id, user.id)
self.assertEqual(ReceiptPreference.MONTHLY, user.receipt_pref)
with self.assertRaises(AuthenticationError): with self.assertRaises(AuthenticationError):
# Inexistent user should fail # Inexistent user should fail
db.login('nooone', 'supersecurepassword') db.login('nooone', 'supersecurepassword')
@ -166,18 +178,24 @@ class DatabaseTest(unittest.TestCase):
with self.db as db: with self.db as db:
agent = db.create_user('admin', 'supersecurepassword', 'admin@example.com', True, True) agent = db.create_user('admin', 'supersecurepassword', 'admin@example.com', True, True)
user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True) user = db.create_user('testuser', 'supersecurepassword', 'testuser@example.com', True, True)
db.change_user(user, agent, email='newaddress@example.com', is_admin=False, is_member=False, balance=4200) db.change_user(user, agent, email='newaddress@example.com', is_admin=False, is_member=False, balance=4200,
balance_reason='This is a reason!', receipt_pref=ReceiptPreference.MONTHLY)
# Changes must be reflected in the passed user object # Changes must be reflected in the passed user object
self.assertEqual('newaddress@example.com', user.email) self.assertEqual('newaddress@example.com', user.email)
self.assertFalse(user.is_admin) self.assertFalse(user.is_admin)
self.assertFalse(user.is_member) self.assertFalse(user.is_member)
self.assertEqual(4200, user.balance) self.assertEqual(4200, user.balance)
self.assertEqual(ReceiptPreference.MONTHLY, user.receipt_pref)
# Changes must be reflected in the database # Changes must be reflected in the database
checkuser = db.get_user(user.id) checkuser = db.get_user(user.id)
self.assertEqual('newaddress@example.com', user.email) self.assertEqual('newaddress@example.com', user.email)
self.assertFalse(checkuser.is_admin) self.assertFalse(checkuser.is_admin)
self.assertFalse(checkuser.is_member) self.assertFalse(checkuser.is_member)
self.assertEqual(4200, checkuser.balance) self.assertEqual(4200, checkuser.balance)
self.assertEqual(ReceiptPreference.MONTHLY, checkuser.receipt_pref)
with db.transaction(exclusive=False) as c:
c.execute('SELECT reason FROM modifications LIMIT 1')
self.assertEqual('This is a reason!', c.fetchone()[0])
# Balance change without an agent must fail # Balance change without an agent must fail
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
db.change_user(user, None, balance=0) db.change_user(user, None, balance=0)
@ -376,3 +394,186 @@ class DatabaseTest(unittest.TestCase):
db.increment_consumption(user1, florapowermate) db.increment_consumption(user1, florapowermate)
with self.assertRaises(DatabaseConsistencyError): with self.assertRaises(DatabaseConsistencyError):
db.increment_consumption(user2, clubmate) db.increment_consumption(user2, clubmate)
def test_check_receipt_due(self):
with self.db as db:
# Receipt preference set to 0
user0 = db.create_user('user0', 'supersecurepassword', 'user0@example.com', True, True)
# No email, no receipts
user1 = db.create_user('user1', 'supersecurepassword', None, True, True)
db.change_user(user1, agent=None, receipt_pref=ReceiptPreference.MONTHLY)
# Should receive a receipt, has never received a receipt before
user2 = db.create_user('user2', 'supersecurepassword', 'user2@example.com', True, True)
db.change_user(user2, agent=None, receipt_pref=ReceiptPreference.MONTHLY)
# Should receive a receipt, has received receipts before
user3 = db.create_user('user3', 'supersecurepassword', 'user3@example.com', True, True)
db.change_user(user3, agent=None, receipt_pref=ReceiptPreference.MONTHLY)
# Shouldn't receive a receipt, a month hasn't passed since the last receipt
user4 = db.create_user('user4', 'supersecurepassword', 'user4@example.com', True, True)
db.change_user(user4, agent=None, receipt_pref=ReceiptPreference.MONTHLY)
# Should receive a receipt, has been more than a year since the last receipt
user5 = db.create_user('user5', 'supersecurepassword', 'user5@example.com', True, True)
db.change_user(user5, agent=None, receipt_pref=ReceiptPreference.YEARLY)
# Shouldn't receive a receipt, a year hasn't passed since the last receipt
user6 = db.create_user('user6', 'supersecurepassword', 'user6@example.com', True, True)
db.change_user(user6, agent=None, receipt_pref=ReceiptPreference.YEARLY)
# Invalid receipt preference, should raise a ValueError
user7 = db.create_user('user7', 'supersecurepassword', 'user7@example.com', True, True)
user7.receipt_pref = 42
twoyears: int = int((datetime.utcnow() - timedelta(days=730)).timestamp())
halfyear: int = int((datetime.utcnow() - timedelta(days=183)).timestamp())
twomonths: int = int((datetime.utcnow() - timedelta(days=61)).timestamp())
halfmonth: int = int((datetime.utcnow() - timedelta(days=15)).timestamp())
with db.transaction() as c:
# Fix creation date for user2
c.execute('''
UPDATE users SET created = :twomonths WHERE user_id = :user2
''', {
'twomonths': twomonths,
'user2': user2.id
})
# Create transactions
c.execute('''
INSERT INTO transactions (ta_id, user_id, value, old_balance, date) VALUES
(1, :user0, 4200, 0, :twomonths),
(2, :user0, 100, 4200, :halfmonth),
(3, :user1, 4200, 0, :twomonths),
(4, :user1, 100, 4200, :halfmonth),
(5, :user2, 4200, 0, :twomonths),
(6, :user2, 100, 4200, :halfmonth),
(7, :user3, 4200, 0, :twomonths),
(8, :user3, 100, 4200, :halfmonth),
(9, :user4, 4200, 0, :twomonths),
(10, :user4, 100, 4200, :halfmonth),
(11, :user5, 4200, 0, :twoyears),
(12, :user5, 100, 4200, :halfyear),
(13, :user6, 4200, 0, :twoyears),
(14, :user6, 100, 4200, :halfyear)
''', {
'twoyears': twoyears,
'halfyear': halfyear,
'twomonths': twomonths,
'halfmonth': halfmonth,
'user0': user0.id,
'user1': user1.id,
'user2': user2.id,
'user3': user3.id,
'user4': user4.id,
'user5': user5.id,
'user6': user6.id
})
# Create receipts
c.execute('''
INSERT INTO receipts (user_id, first_ta_id, last_ta_id, date) VALUES
(:user3, 7, 7, :twomonths),
(:user4, 9, 9, :halfmonth),
(:user5, 11, 11, :twoyears),
(:user6, 13, 13, :halfyear)
''', {
'twoyears': twoyears,
'halfyear': halfyear,
'twomonths': twomonths,
'halfmonth': halfmonth,
'user3': user3.id,
'user4': user4.id,
'user5': user5.id,
'user6': user6.id
})
self.assertFalse(db.check_receipt_due(user0))
self.assertFalse(self.db.check_receipt_due(user1))
self.assertTrue(self.db.check_receipt_due(user2))
self.assertTrue(self.db.check_receipt_due(user3))
self.assertFalse(self.db.check_receipt_due(user4))
self.assertTrue(self.db.check_receipt_due(user5))
self.assertFalse(self.db.check_receipt_due(user6))
with self.assertRaises(TypeError):
self.db.check_receipt_due(user7)
def test_create_receipt(self):
with self.db as db:
now: datetime = datetime.utcnow()
admin: User = db.create_user('admin', 'supersecurepassword', 'admin@example.com', True, True)
user: User = db.create_user('user', 'supersecurepassword', 'user@example.com', True, True)
product: Product = db.create_product('Flora Power Mate', 200, 200)
# Create some transactions
db.change_user(user, agent=admin,
receipt_pref=ReceiptPreference.MONTHLY,
balance=4200, balance_reason='Here\'s a gift!')
db.increment_consumption(user, product)
db.deposit(user, 1337)
receipt1: Receipt = db.create_receipt(user, write=True)
with db.transaction() as c:
c.execute('SELECT COUNT(receipt_id) FROM receipts')
self.assertEqual(1, c.fetchone()[0])
db.increment_consumption(user, product)
db.change_user(user, agent=admin, balance=4200)
with db.transaction() as c:
# Unknown transactions
c.execute('''
INSERT INTO transactions (user_id, value, old_balance)
SELECT user_id, 500, balance
FROM users
WHERE user_id = :id
''', [user.id])
receipt2: Receipt = db.create_receipt(user, write=False)
with db.transaction() as c:
c.execute('SELECT COUNT(receipt_id) FROM receipts')
self.assertEqual(1, c.fetchone()[0])
self.assertEqual(user, receipt1.user)
self.assertEqual(3, len(receipt1.transactions))
self.assertIsInstance(receipt1.transactions[0], Modification)
t10: Modification = receipt1.transactions[0]
self.assertEqual(user, receipt1.user)
self.assertEqual(4200, t10.value)
self.assertEqual(0, t10.old_balance)
self.assertEqual(admin.name, t10.agent)
self.assertEqual('Here\'s a gift!', t10.reason)
self.assertIsInstance(receipt1.transactions[1], Consumption)
t11: Consumption = receipt1.transactions[1]
self.assertEqual(user, receipt1.user)
self.assertEqual(-200, t11.value)
self.assertEqual(4200, t11.old_balance)
self.assertEqual('Flora Power Mate', t11.product)
self.assertIsInstance(receipt1.transactions[2], Deposit)
t12: Deposit = receipt1.transactions[2]
self.assertEqual(user, receipt1.user)
self.assertEqual(1337, t12.value)
self.assertEqual(4000, t12.old_balance)
self.assertEqual(user, receipt2.user)
self.assertEqual(3, len(receipt2.transactions))
self.assertIsInstance(receipt2.transactions[0], Consumption)
t20: Consumption = receipt2.transactions[0]
self.assertEqual(user, receipt2.user)
self.assertEqual(-200, t20.value)
self.assertEqual(5337, t20.old_balance)
self.assertEqual('Flora Power Mate', t20.product)
self.assertIsInstance(receipt2.transactions[1], Modification)
t21: Modification = receipt2.transactions[1]
self.assertEqual(user, receipt2.user)
self.assertEqual(-937, t21.value)
self.assertEqual(5137, t21.old_balance)
self.assertEqual(admin.name, t21.agent)
self.assertEqual(None, t21.reason)
self.assertIs(type(receipt2.transactions[2]), Transaction)
t22: Transaction = receipt2.transactions[2]
self.assertEqual(user, receipt2.user)
self.assertEqual(500, t22.value)
self.assertEqual(4200, t22.old_balance)
# TODO: Test cases for primitive object vs database row mismatch

View file

@ -52,7 +52,10 @@ class TestMigrations(unittest.TestCase):
cursor.execute('PRAGMA user_version = 1') cursor.execute('PRAGMA user_version = 1')
# Kick off the migration # Kick off the migration
schema_version = self.db.SCHEMA_VERSION
self.db.SCHEMA_VERSION = 2
self.db._setup() self.db._setup()
self.db.SCHEMA_VERSION = schema_version
# Test whether the new tables were created # Test whether the new tables were created
cursor.execute('PRAGMA table_info(transactions)') cursor.execute('PRAGMA table_info(transactions)')
@ -122,3 +125,61 @@ class TestMigrations(unittest.TestCase):
ON t.ta_id = c.ta_id ON t.ta_id = c.ta_id
WHERE t.user_id = 3 AND c.product_id = 2 AND t.value = -150''') WHERE t.user_id = 3 AND c.product_id = 2 AND t.value = -150''')
self.assertEqual(4, cursor.fetchone()[0]) self.assertEqual(4, cursor.fetchone()[0])
def test_upgrade_2_to_3(self):
# Setup test db with example entries covering - hopefully - all cases
self._initialize_db(2)
cursor: sqlite3.Cursor = self.db._sqlite_db.cursor()
cursor.execute('''
INSERT INTO users VALUES
(1, 'testadmin', 'a@b.c', '$2a$10$herebehashes', NULL, 1, 1, 1337, 0),
(2, 'testuser', NULL, '$2a$10$herebehashes', '$2a$10$herebehashes', 0, 1, 4242, 0),
(3, 'alien', NULL, '$2a$10$herebehashes', '$2a$10$herebehashes', 0, 0, 1234, 0),
(4, 'neverused', NULL, '$2a$10$herebehashes', '$2a$10$herebehashes', 0, 0, 1234, 1234)
''')
cursor.execute('''
INSERT INTO products VALUES
(1, 'Club Mate', 42, 200, 250),
(2, 'Flora Power Mate', 10, 100, 150)
''')
cursor.execute('''
INSERT INTO transactions VALUES
(1, 1, 4200, 0, 1000), -- deposit
(2, 2, 1337, 0, 1001), -- modification
(3, 3, 1337, 0, 1002), -- modification with deleted agent
(4, 2, -200, 1337, 1003), -- consumption
(5, 3, -200, 1337, 1004) -- consumption with deleted product
''')
cursor.execute('''INSERT INTO deposits VALUES (1)''')
cursor.execute('''
INSERT INTO modifications VALUES
(2, 1, 'Account migration'),
(3, 42, 'You can''t find out who i am... MUAHAHAHA!!!')''')
cursor.execute('''INSERT INTO consumptions VALUES (4, 2), (5, 42)''')
cursor.execute('''PRAGMA user_version = 2''')
# Kick off the migration
schema_version = self.db.SCHEMA_VERSION
self.db.SCHEMA_VERSION = 3
self.db._setup()
self.db.SCHEMA_VERSION = schema_version
# Make sure the receipts table was created
cursor.execute('''SELECT COUNT(receipt_id) FROM receipts''')
self.assertEqual(0, cursor.fetchone()[0])
# Make sure users.created was populated with the expected values
cursor.execute('''SELECT u.created FROM users AS u ORDER BY u.user_id ASC''')
self.assertEqual([(940,), (941,), (942,), (1174,)], cursor.fetchall())
# Make sure the modifications table was changed to contain the username, or a fallback
cursor.execute('''SELECT agent FROM modifications WHERE ta_id = 2''')
self.assertEqual('testadmin', cursor.fetchone()[0])
cursor.execute('''SELECT agent FROM modifications WHERE ta_id = 3''')
self.assertEqual('<unknown>', cursor.fetchone()[0])
# Make sure the consumptions table was changed to contain the product name, or a fallback
cursor.execute('''SELECT product FROM consumptions WHERE ta_id = 4''')
self.assertEqual('Flora Power Mate', cursor.fetchone()[0])
cursor.execute('''SELECT product FROM consumptions WHERE ta_id = 5''')
self.assertEqual('<unknown>', cursor.fetchone()[0])

View file

@ -53,12 +53,12 @@ class DatabaseTest(unittest.TestCase):
with self.db as db: with self.db as db:
with db.transaction() as c: with db.transaction() as c:
c.execute(''' c.execute('''
INSERT INTO users VALUES (1, 'testuser', NULL, 'supersecurepassword', NULL, 1, 1, 0, 42) INSERT INTO users VALUES (1, 'testuser', NULL, 'supersecurepassword', NULL, 1, 1, 0, 42, 0, 0)
''') ''')
c = db._sqlite_db.cursor() c = db._sqlite_db.cursor()
c.execute("SELECT * FROM users") c.execute("SELECT * FROM users")
user = c.fetchone() user = c.fetchone()
self.assertEqual((1, 'testuser', None, 'supersecurepassword', None, 1, 1, 0, 42), user) self.assertEqual((1, 'testuser', None, 'supersecurepassword', None, 1, 1, 0, 42, 0, 0), user)
def test_transaction_rollback(self) -> None: def test_transaction_rollback(self) -> None:
""" """
@ -67,9 +67,9 @@ class DatabaseTest(unittest.TestCase):
with self.db as db: with self.db as db:
try: try:
with db.transaction() as c: with db.transaction() as c:
c.execute(""" c.execute('''
INSERT INTO users VALUES (1, 'testuser', NULL, 'supersecurepassword', NULL, 1, 1, 0, 42) INSERT INTO users VALUES (1, 'testuser', NULL, 'supersecurepassword', NULL, 1, 1, 0, 42, 0, 0)
""") ''')
raise ValueError('This should trigger a rollback') raise ValueError('This should trigger a rollback')
except ValueError as e: except ValueError as e:
if str(e) != 'This should trigger a rollback': if str(e) != 'This should trigger a rollback':

View file

@ -6,10 +6,10 @@ import sqlite3
from matemat.exceptions import DatabaseConsistencyError from matemat.exceptions import DatabaseConsistencyError
from matemat.db.schemas import SCHEMAS from matemat.db.schemas import SCHEMAS
from matemat.db.migrations import migrate_schema_1_to_2 from matemat.db.migrations import migrate_schema_1_to_2, migrate_schema_2_to_3
class Transaction(object): class DatabaseTransaction(object):
def __init__(self, db: sqlite3.Connection, exclusive: bool = True) -> None: def __init__(self, db: sqlite3.Connection, exclusive: bool = True) -> None:
self._db: sqlite3.Connection = db self._db: sqlite3.Connection = db
@ -43,7 +43,7 @@ class Transaction(object):
class DatabaseWrapper(object): class DatabaseWrapper(object):
SCHEMA_VERSION = 2 SCHEMA_VERSION = 3
def __init__(self, filename: str) -> None: def __init__(self, filename: str) -> None:
self._filename: str = filename self._filename: str = filename
@ -56,13 +56,20 @@ class DatabaseWrapper(object):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close() self.close()
def transaction(self, exclusive: bool = True) -> Transaction: def transaction(self, exclusive: bool = True) -> DatabaseTransaction:
if self._sqlite_db is None: if self._sqlite_db is None:
raise RuntimeError(f'Database connection to {self._filename} is not established.') raise RuntimeError(f'Database connection to {self._filename} is not established.')
return Transaction(self._sqlite_db, exclusive) return DatabaseTransaction(self._sqlite_db, exclusive)
def _setup(self) -> None: def _setup(self) -> None:
# Enable foreign key enforcement
cursor = self._sqlite_db.cursor()
cursor.execute('PRAGMA foreign_keys = 1')
# Create or update schemas if necessary
with self.transaction() as c: with self.transaction() as c:
# Defer foreign key enforcement in the setup transaction
c.execute('PRAGMA defer_foreign_keys = 1')
version: int = self._user_version version: int = self._user_version
if version < 1: if version < 1:
# Don't use executescript, as it issues a COMMIT first # Don't use executescript, as it issues a COMMIT first
@ -77,8 +84,10 @@ class DatabaseWrapper(object):
def _upgrade(self, from_version: int, to_version: int) -> None: def _upgrade(self, from_version: int, to_version: int) -> None:
with self.transaction() as c: with self.transaction() as c:
# Note to future s3lph: If there are further migrations, also consider upgrades like 1 -> 3 # Note to future s3lph: If there are further migrations, also consider upgrades like 1 -> 3
if from_version == 1 and to_version == 2: if from_version == 1 and to_version >= 2:
migrate_schema_1_to_2(c) migrate_schema_1_to_2(c)
if from_version <= 2 and to_version >= 3:
migrate_schema_2_to_3(c)
def connect(self) -> None: def connect(self) -> None:
if self.is_connected(): if self.is_connected():

View file

@ -1,11 +1,12 @@
def format_chf(value: int, with_currencysign: bool = True) -> str: def format_chf(value: int, with_currencysign: bool = True, plus_sign: bool = False) -> str:
""" """
Formats a centime value into a commonly understood representation ("CHF -13.37"). Formats a centime value into a commonly understood representation ("CHF -13.37").
:param value: The value to format, in centimes. :param value: The value to format, in centimes.
:param with_currencysign: Whether to include the currency prefix ("CHF ") in the output. :param with_currencysign: Whether to include the currency prefix ("CHF ") in the output.
:param plus_sign: Whether to denote positive values with an explicit "+" sign before the value.
:return: A human-readable string representation. :return: A human-readable string representation.
""" """
sign: str = '' sign: str = ''
@ -13,6 +14,8 @@ def format_chf(value: int, with_currencysign: bool = True) -> str:
# As // and % round towards -Inf, convert into a positive value and prepend the negative sign # As // and % round towards -Inf, convert into a positive value and prepend the negative sign
sign = '-' sign = '-'
value = -value value = -value
elif plus_sign:
sign = '+'
# Split into full francs and fractions (centimes) # Split into full francs and fractions (centimes)
full: int = value // 100 full: int = value // 100
frac: int = value % 100 frac: int = value % 100

View file

@ -0,0 +1,29 @@
from typing import Tuple
from datetime import datetime, timedelta
import calendar
def add_months(d: datetime, months: int) -> datetime:
"""
Add the given number of months to the passed date, considering the varying numbers of days in a month.
:param d: The date time to add to.
:param months: The number of months to add to.
:return: A datetime object offset by the requested number of months.
"""
if not isinstance(d, datetime) or not isinstance(months, int):
raise TypeError()
if months < 0:
raise ValueError('Can only add a positive number of months.')
nextmonth: Tuple[int, int] = (d.year, d.month)
days: int = 0
# Iterate the months between the passed date and the target month
for i in range(months):
days += calendar.monthlen(*nextmonth)
nextmonth = calendar.nextmonth(*nextmonth)
# Set the day of month temporarily to 1, then add the day offset to reach the 1st of the target month
newdate: datetime = d.replace(day=1) + timedelta(days=days)
# Re-set the day of month to the intended value, but capped by the max. day in the target month
newdate = newdate.replace(day=min(d.day, calendar.monthlen(newdate.year, newdate.month)))
return newdate

View file

@ -9,38 +9,56 @@ class TestCurrencyFormat(unittest.TestCase):
def test_format_zero(self): def test_format_zero(self):
self.assertEqual('CHF 0.00', format_chf(0)) self.assertEqual('CHF 0.00', format_chf(0))
self.assertEqual('0.00', format_chf(0, False)) self.assertEqual('0.00', format_chf(0, False))
self.assertEqual('CHF +0.00', format_chf(0, plus_sign=True))
self.assertEqual('+0.00', format_chf(0, False, plus_sign=True))
def test_format_positive_full(self): def test_format_positive_full(self):
self.assertEqual('CHF 42.00', format_chf(4200)) self.assertEqual('CHF 42.00', format_chf(4200))
self.assertEqual('42.00', format_chf(4200, False)) self.assertEqual('42.00', format_chf(4200, False))
self.assertEqual('CHF +42.00', format_chf(4200, plus_sign=True))
self.assertEqual('+42.00', format_chf(4200, False, plus_sign=True))
def test_format_negative_full(self): def test_format_negative_full(self):
self.assertEqual('CHF -42.00', format_chf(-4200)) self.assertEqual('CHF -42.00', format_chf(-4200))
self.assertEqual('-42.00', format_chf(-4200, False)) self.assertEqual('-42.00', format_chf(-4200, False))
self.assertEqual('CHF -42.00', format_chf(-4200, plus_sign=True))
self.assertEqual('-42.00', format_chf(-4200, False, plus_sign=True))
def test_format_positive_frac(self): def test_format_positive_frac(self):
self.assertEqual('CHF 13.37', format_chf(1337)) self.assertEqual('CHF 13.37', format_chf(1337))
self.assertEqual('13.37', format_chf(1337, False)) self.assertEqual('13.37', format_chf(1337, False))
self.assertEqual('CHF +13.37', format_chf(1337, plus_sign=True))
self.assertEqual('+13.37', format_chf(1337, False, plus_sign=True))
def test_format_negative_frac(self): def test_format_negative_frac(self):
self.assertEqual('CHF -13.37', format_chf(-1337)) self.assertEqual('CHF -13.37', format_chf(-1337))
self.assertEqual('-13.37', format_chf(-1337, False)) self.assertEqual('-13.37', format_chf(-1337, False))
self.assertEqual('CHF -13.37', format_chf(-1337, plus_sign=True))
self.assertEqual('-13.37', format_chf(-1337, False, plus_sign=True))
def test_format_pad_left_positive(self): def test_format_pad_left_positive(self):
self.assertEqual('CHF 0.01', format_chf(1)) self.assertEqual('CHF 0.01', format_chf(1))
self.assertEqual('0.01', format_chf(1, False)) self.assertEqual('0.01', format_chf(1, False))
self.assertEqual('CHF +0.01', format_chf(1, plus_sign=True))
self.assertEqual('+0.01', format_chf(1, False, plus_sign=True))
def test_format_pad_left_negative(self): def test_format_pad_left_negative(self):
self.assertEqual('CHF -0.01', format_chf(-1)) self.assertEqual('CHF -0.01', format_chf(-1))
self.assertEqual('-0.01', format_chf(-1, False)) self.assertEqual('-0.01', format_chf(-1, False))
self.assertEqual('CHF -0.01', format_chf(-1, plus_sign=True))
self.assertEqual('-0.01', format_chf(-1, False, plus_sign=True))
def test_format_pad_right_positive(self): def test_format_pad_right_positive(self):
self.assertEqual('CHF 4.20', format_chf(420)) self.assertEqual('CHF 4.20', format_chf(420))
self.assertEqual('4.20', format_chf(420, False)) self.assertEqual('4.20', format_chf(420, False))
self.assertEqual('CHF +4.20', format_chf(420, plus_sign=True))
self.assertEqual('+4.20', format_chf(420, False, plus_sign=True))
def test_format_pad_right_negative(self): def test_format_pad_right_negative(self):
self.assertEqual('CHF -4.20', format_chf(-420)) self.assertEqual('CHF -4.20', format_chf(-420))
self.assertEqual('-4.20', format_chf(-420, False)) self.assertEqual('-4.20', format_chf(-420, False))
self.assertEqual('CHF -4.20', format_chf(-420, plus_sign=True))
self.assertEqual('-4.20', format_chf(-420, False, plus_sign=True))
def test_parse_empty(self): def test_parse_empty(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -52,20 +70,29 @@ class TestCurrencyFormat(unittest.TestCase):
def test_parse_zero(self): def test_parse_zero(self):
self.assertEqual(0, parse_chf('CHF0')) self.assertEqual(0, parse_chf('CHF0'))
self.assertEqual(0, parse_chf('CHF-0'))
self.assertEqual(0, parse_chf('CHF+0'))
self.assertEqual(0, parse_chf('CHF 0')) self.assertEqual(0, parse_chf('CHF 0'))
self.assertEqual(0, parse_chf('CHF +0'))
self.assertEqual(0, parse_chf('CHF -0')) self.assertEqual(0, parse_chf('CHF -0'))
self.assertEqual(0, parse_chf('CHF 0.')) self.assertEqual(0, parse_chf('CHF 0.'))
self.assertEqual(0, parse_chf('CHF 0.0')) self.assertEqual(0, parse_chf('CHF 0.0'))
self.assertEqual(0, parse_chf('CHF 0.00')) self.assertEqual(0, parse_chf('CHF 0.00'))
self.assertEqual(0, parse_chf('CHF +0.'))
self.assertEqual(0, parse_chf('CHF +0.0'))
self.assertEqual(0, parse_chf('CHF +0.00'))
self.assertEqual(0, parse_chf('CHF -0.')) self.assertEqual(0, parse_chf('CHF -0.'))
self.assertEqual(0, parse_chf('CHF -0.0')) self.assertEqual(0, parse_chf('CHF -0.0'))
self.assertEqual(0, parse_chf('CHF -0.00')) self.assertEqual(0, parse_chf('CHF -0.00'))
self.assertEqual(0, parse_chf('0')) self.assertEqual(0, parse_chf('0'))
self.assertEqual(0, parse_chf('0')) self.assertEqual(0, parse_chf('+0'))
self.assertEqual(0, parse_chf('-0')) self.assertEqual(0, parse_chf('-0'))
self.assertEqual(0, parse_chf('0.')) self.assertEqual(0, parse_chf('0.'))
self.assertEqual(0, parse_chf('0.0')) self.assertEqual(0, parse_chf('0.0'))
self.assertEqual(0, parse_chf('0.00')) self.assertEqual(0, parse_chf('0.00'))
self.assertEqual(0, parse_chf('+0.'))
self.assertEqual(0, parse_chf('+0.0'))
self.assertEqual(0, parse_chf('+0.00'))
self.assertEqual(0, parse_chf('-0.')) self.assertEqual(0, parse_chf('-0.'))
self.assertEqual(0, parse_chf('-0.0')) self.assertEqual(0, parse_chf('-0.0'))
self.assertEqual(0, parse_chf('-0.00')) self.assertEqual(0, parse_chf('-0.00'))
@ -80,6 +107,16 @@ class TestCurrencyFormat(unittest.TestCase):
self.assertEqual(4200, parse_chf('CHF 42.0')) self.assertEqual(4200, parse_chf('CHF 42.0'))
self.assertEqual(4200, parse_chf('42.0')) self.assertEqual(4200, parse_chf('42.0'))
def test_parse_positive_full_with_sign(self):
self.assertEqual(4200, parse_chf('CHF +42.00'))
self.assertEqual(4200, parse_chf('+42.00'))
self.assertEqual(4200, parse_chf('CHF +42'))
self.assertEqual(4200, parse_chf('+42'))
self.assertEqual(4200, parse_chf('CHF +42.'))
self.assertEqual(4200, parse_chf('+42.'))
self.assertEqual(4200, parse_chf('CHF +42.0'))
self.assertEqual(4200, parse_chf('+42.0'))
def test_parse_negative_full(self): def test_parse_negative_full(self):
self.assertEqual(-4200, parse_chf('CHF -42.00')) self.assertEqual(-4200, parse_chf('CHF -42.00'))
self.assertEqual(-4200, parse_chf('-42.00')) self.assertEqual(-4200, parse_chf('-42.00'))
@ -94,6 +131,10 @@ class TestCurrencyFormat(unittest.TestCase):
self.assertEqual(1337, parse_chf('CHF 13.37')) self.assertEqual(1337, parse_chf('CHF 13.37'))
self.assertEqual(1337, parse_chf('13.37')) self.assertEqual(1337, parse_chf('13.37'))
def test_parse_positive_frac_with_sign(self):
self.assertEqual(1337, parse_chf('CHF +13.37'))
self.assertEqual(1337, parse_chf('+13.37'))
def test_parse_negative_frac(self): def test_parse_negative_frac(self):
self.assertEqual(-1337, parse_chf('CHF -13.37')) self.assertEqual(-1337, parse_chf('CHF -13.37'))
self.assertEqual(-1337, parse_chf('-13.37')) self.assertEqual(-1337, parse_chf('-13.37'))
@ -102,6 +143,10 @@ class TestCurrencyFormat(unittest.TestCase):
self.assertEqual(1, parse_chf('CHF 0.01')) self.assertEqual(1, parse_chf('CHF 0.01'))
self.assertEqual(1, parse_chf('0.01')) self.assertEqual(1, parse_chf('0.01'))
def test_parse_pad_left_positive_with_sign(self):
self.assertEqual(1, parse_chf('CHF +0.01'))
self.assertEqual(1, parse_chf('+0.01'))
def test_parse_pad_left_negative(self): def test_parse_pad_left_negative(self):
self.assertEqual(-1, parse_chf('CHF -0.01')) self.assertEqual(-1, parse_chf('CHF -0.01'))
self.assertEqual(-1, parse_chf('-0.01')) self.assertEqual(-1, parse_chf('-0.01'))
@ -112,6 +157,12 @@ class TestCurrencyFormat(unittest.TestCase):
self.assertEqual(420, parse_chf('CHF 4.2')) self.assertEqual(420, parse_chf('CHF 4.2'))
self.assertEqual(420, parse_chf('4.2')) self.assertEqual(420, parse_chf('4.2'))
def test_parse_pad_right_positive_with_sign(self):
self.assertEqual(420, parse_chf('CHF +4.20'))
self.assertEqual(420, parse_chf('+4.20'))
self.assertEqual(420, parse_chf('CHF +4.2'))
self.assertEqual(420, parse_chf('+4.2'))
def test_parse_pad_right_negative(self): def test_parse_pad_right_negative(self):
self.assertEqual(-420, parse_chf('CHF -4.20')) self.assertEqual(-420, parse_chf('CHF -4.20'))
self.assertEqual(-420, parse_chf('-4.20')) self.assertEqual(-420, parse_chf('-4.20'))
@ -121,10 +172,14 @@ class TestCurrencyFormat(unittest.TestCase):
def test_parse_too_many_decimals(self): def test_parse_too_many_decimals(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
parse_chf('123.456') parse_chf('123.456')
with self.assertRaises(ValueError):
parse_chf('-123.456')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
parse_chf('CHF 0.456') parse_chf('CHF 0.456')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
parse_chf('CHF 0.450') parse_chf('CHF 0.450')
with self.assertRaises(ValueError):
parse_chf('CHF +0.456')
def test_parse_wrong_separator(self): def test_parse_wrong_separator(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -137,3 +192,7 @@ class TestCurrencyFormat(unittest.TestCase):
parse_chf('13.-7') parse_chf('13.-7')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
parse_chf('CHF 13.-7') parse_chf('CHF 13.-7')
with self.assertRaises(ValueError):
parse_chf('+13.-7')
with self.assertRaises(ValueError):
parse_chf('CHF -13.-7')

View file

@ -0,0 +1,61 @@
import unittest
from datetime import datetime
from matemat.util.monthdelta import add_months
class TestMonthDelta(unittest.TestCase):
def test_monthdelta_zero(self):
date = datetime(2018, 9, 8, 13, 37, 42)
offset_date = date
self.assertEqual(offset_date, add_months(date, 0))
def test_monthdelta_one(self):
date = datetime(2018, 9, 8, 13, 37, 42)
offset_date = date.replace(month=10)
self.assertEqual(offset_date, add_months(date, 1))
def test_monthdelta_two(self):
date = datetime(2018, 9, 8, 13, 37, 42)
offset_date = date.replace(month=11)
self.assertEqual(offset_date, add_months(date, 2))
def test_monthdelta_yearwrap(self):
date = datetime(2018, 9, 8, 13, 37, 42)
offset_date = date.replace(year=2019, month=1)
self.assertEqual(offset_date, add_months(date, 4))
def test_monthdelta_yearwrap_five(self):
date = datetime(2018, 9, 8, 13, 37, 42)
offset_date = date.replace(year=2023, month=3)
self.assertEqual(offset_date, add_months(date, 54))
def test_monthdelta_rounddown_31_30(self):
date = datetime(2018, 3, 31, 13, 37, 42)
offset_date = date.replace(month=4, day=30)
self.assertEqual(offset_date, add_months(date, 1))
def test_monthdelta_rounddown_feb(self):
date = datetime(2018, 1, 31, 13, 37, 42)
offset_date = date.replace(month=2, day=28)
self.assertEqual(offset_date, add_months(date, 1))
def test_monthdelta_rounddown_feb_leap(self):
date = datetime(2020, 1, 31, 13, 37, 42)
offset_date = date.replace(month=2, day=29)
self.assertEqual(offset_date, add_months(date, 1))
def test_fail_negative(self):
date = datetime(2020, 1, 31, 13, 37, 42)
with self.assertRaises(ValueError):
add_months(date, -1)
def test_fail_type(self):
date = datetime(2020, 1, 31, 13, 37, 42)
with self.assertRaises(TypeError):
add_months(date, 1.2)
with self.assertRaises(TypeError):
add_months(42, 1)

View file

@ -8,5 +8,5 @@ server will attempt to serve the request with a static resource in a previously
from .requestargs import RequestArgument, RequestArguments from .requestargs import RequestArgument, RequestArguments
from .responses import PageletResponse, RedirectResponse, TemplateResponse from .responses import PageletResponse, RedirectResponse, TemplateResponse
from .httpd import MatematWebserver, HttpHandler, pagelet, pagelet_init from .httpd import MatematWebserver, HttpHandler, pagelet, pagelet_init, pagelet_cron
from .config import parse_config_file from .config import parse_config_file

View file

@ -11,6 +11,7 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
from http.cookies import SimpleCookie from http.cookies import SimpleCookie
from uuid import uuid4 from uuid import uuid4
from datetime import datetime, timedelta from datetime import datetime, timedelta
from threading import Event, Timer, Thread
import jinja2 import jinja2
@ -41,6 +42,9 @@ _PAGELET_PATHS: Dict[str, Callable[[str, # HTTP method (GET, POST, ...)
# The pagelet initialization functions, to be executed upon startup # The pagelet initialization functions, to be executed upon startup
_PAGELET_INIT_FUNCTIONS: Set[Callable[[Dict[str, str], logging.Logger], None]] = set() _PAGELET_INIT_FUNCTIONS: Set[Callable[[Dict[str, str], logging.Logger], None]] = set()
_PAGELET_CRON_STATIC_EVENT: Event = Event()
_PAGELET_CRON_RUNNER: Callable[[Callable[[Dict[str, str], jinja2.Environment, logging.Logger], None]], None] = None
# Inactivity timeout for client sessions # Inactivity timeout for client sessions
_SESSION_TIMEOUT: int = 3600 _SESSION_TIMEOUT: int = 3600
_MAX_POST: int = 1_000_000 _MAX_POST: int = 1_000_000
@ -118,6 +122,90 @@ def pagelet_init(fun: Callable[[Dict[str, str], logging.Logger], None]):
_PAGELET_INIT_FUNCTIONS.add(fun) _PAGELET_INIT_FUNCTIONS.add(fun)
class _GlobalEventTimer(Thread):
"""
A timer similar to threading.Timer, except that waits on an externally supplied threading.Event instance,
therefore allowing all timers waiting on the same event to be cancelled at once.
"""
def __init__(self, interval: float, event: Event, fun, *args, **kwargs):
"""
Create a new _GlobalEventTimer.
:param interval: The delay after which to run the function.
:param event: The external threading.Event to wait on.
:param fun: The function to call.
:param args: The positional arguments to pass to the function.
:param kwargs: The keyword arguments to pass to the function.
"""
Thread.__init__(self)
self.interval = interval
self.fun = fun
self.args = args if args is not None else []
self.kwargs = kwargs if kwargs is not None else {}
self.event = event
def run(self):
self.event.wait(self.interval)
if not self.event.is_set():
self.fun(*self.args, **self.kwargs)
# Do NOT call event.set(), as done in threading.Timer, as that would cancel all other timers
def pagelet_cron(weeks: int = 0,
days: int = 0,
hours: int = 0,
seconds: int = 0,
minutes: int = 0,
milliseconds: int = 0,
microseconds: int = 0):
"""
Annotate a function to act as a pagelet cron function. The function will be called in a regular interval, defined
by the arguments passed to the decorator, which are passed to a timedelta object.
The function must have the following signature:
(config: Dict[str, str], jinja_env: jinja2.Environment, logger: logging.Logger) -> None
config: The mutable dictionary of variables read from the [Pagelets] section of the configuration file.
jinja_env: The Jinja2 environment used by the web server.
logger: The server's logger instance.
returns: Nothing.
:param weeks: Number of weeks in the interval.
:param days: Number of days in the interval.
:param hours: Number of hours in the interval.
:param seconds: Number of seconds in the interval.
:param minutes: Number of minutes in the interval.
:param milliseconds: Number of milliseconds in the interval.
:param microseconds: Number of microseconds in the interval.
"""
def cron_wrapper(fun: Callable[[Dict[str, str], jinja2.Environment, logging.Logger], None]):
# Create the timedelta object
delta: timedelta = timedelta(weeks=weeks,
days=days,
hours=hours,
seconds=seconds,
minutes=minutes,
milliseconds=milliseconds,
microseconds=microseconds)
# This function is called once in the specified interval
def cron():
# Set a new timer
t: Timer = _GlobalEventTimer(delta.total_seconds(), _PAGELET_CRON_STATIC_EVENT, cron)
t.start()
# Have the cron job be picked up by the cron runner provided by the web server
if _PAGELET_CRON_RUNNER is not None:
_PAGELET_CRON_RUNNER(fun)
# Set a timer to run the cron job after the specified interval
timer: Timer = _GlobalEventTimer(delta.total_seconds(), _PAGELET_CRON_STATIC_EVENT, cron)
timer.start()
return cron_wrapper
class MatematHTTPServer(HTTPServer): class MatematHTTPServer(HTTPServer):
""" """
A http.server.HTTPServer subclass that acts as a container for data that must be persistent between requests. A http.server.HTTPServer subclass that acts as a container for data that must be persistent between requests.
@ -212,10 +300,14 @@ class MatematWebserver(object):
running. If any exception is raised in the initialization phase, the program is terminated with a non-zero running. If any exception is raised in the initialization phase, the program is terminated with a non-zero
exit code. exit code.
""" """
global _PAGELET_CRON_RUNNER
try:
try: try:
# Run all pagelet initialization functions # Run all pagelet initialization functions
for fun in _PAGELET_INIT_FUNCTIONS: for fun in _PAGELET_INIT_FUNCTIONS:
fun(self._httpd.pagelet_variables, self._httpd.logger) fun(self._httpd.pagelet_variables, self._httpd.logger)
# Set pagelet cron runner to self
_PAGELET_CRON_RUNNER = self._cron_runner
except BaseException as e: except BaseException as e:
# If an error occurs, log it and terminate # If an error occurs, log it and terminate
self._httpd.logger.exception(e) self._httpd.logger.exception(e)
@ -223,6 +315,19 @@ class MatematWebserver(object):
raise e raise e
# If pagelet initialization went fine, start the HTTP server # If pagelet initialization went fine, start the HTTP server
self._httpd.serve_forever() self._httpd.serve_forever()
finally:
# Cancel all cron timers at once when the webserver is shutting down
_PAGELET_CRON_STATIC_EVENT.set()
def _cron_runner(self, fun: Callable[[Dict[str, str], jinja2.Environment, logging.Logger], None]):
self._httpd.logger.info('Executing cron job "%s"', fun.__name__)
try:
fun(self._httpd.pagelet_variables,
self._httpd.jinja_env,
self._httpd.logger)
self._httpd.logger.info('Completed cron job "%s"', fun.__name__)
except BaseException as e:
self._httpd.logger.exception('Cron job "%s" failed:', fun.__name__, exc_info=e)
class HttpHandler(BaseHTTPRequestHandler): class HttpHandler(BaseHTTPRequestHandler):

View file

@ -15,3 +15,4 @@ from .admin import admin
from .moduser import moduser from .moduser import moduser
from .modproduct import modproduct from .modproduct import modproduct
from .userbootstrap import userbootstrap from .userbootstrap import userbootstrap
from .receipt_smtp_cron import receipt_smtp_cron

View file

@ -5,7 +5,7 @@ import magic
from matemat.webserver import pagelet, RequestArguments, PageletResponse, RedirectResponse, TemplateResponse from matemat.webserver import pagelet, RequestArguments, PageletResponse, RedirectResponse, TemplateResponse
from matemat.db import MatematDatabase from matemat.db import MatematDatabase
from matemat.db.primitives import User from matemat.db.primitives import User, ReceiptPreference
from matemat.exceptions import DatabaseConsistencyError, HttpException from matemat.exceptions import DatabaseConsistencyError, HttpException
@ -47,7 +47,8 @@ def admin(method: str,
# Render the "Admin/Settings" page # Render the "Admin/Settings" page
return TemplateResponse('admin.html', return TemplateResponse('admin.html',
authuser=user, authlevel=authlevel, users=users, products=products, authuser=user, authlevel=authlevel, users=users, products=products,
setupname=config['InstanceName']) receipt_preference_class=ReceiptPreference,
setupname=config['InstanceName'], config_smtp_enabled=config['SmtpSendReceipts'])
def handle_change(args: RequestArguments, user: User, db: MatematDatabase, config: Dict[str, str]) -> None: def handle_change(args: RequestArguments, user: User, db: MatematDatabase, config: Dict[str, str]) -> None:
@ -73,9 +74,13 @@ def handle_change(args: RequestArguments, user: User, db: MatematDatabase, confi
# An empty e-mail field should be interpreted as NULL # An empty e-mail field should be interpreted as NULL
if len(email) == 0: if len(email) == 0:
email = None email = None
# Attempt to update username and e-mail
try: try:
db.change_user(user, agent=None, name=username, email=email) receipt_pref = ReceiptPreference(int(str(args.receipt_pref)))
except ValueError:
return
# Attempt to update username, e-mail and receipt preference
try:
db.change_user(user, agent=None, name=username, email=email, receipt_pref=receipt_pref)
except DatabaseConsistencyError: except DatabaseConsistencyError:
return return

View file

@ -23,6 +23,31 @@ def initialization(config: Dict[str, str],
if 'DatabaseFile' not in config: if 'DatabaseFile' not in config:
config['DatabaseFile'] = './matemat.db' config['DatabaseFile'] = './matemat.db'
logger.warning('Property \'DatabaseFile\' not set, using \'./matemat.db\'') logger.warning('Property \'DatabaseFile\' not set, using \'./matemat.db\'')
if 'SmtpSendReceipts' not in config:
config['SmtpSendReceipts'] = '0'
logger.warning('Property \'SmtpSendReceipts\' not set, using \'0\'')
if config['SmtpSendReceipts'] == '1':
if 'SmtpFrom' not in config:
logger.fatal('\'SmtpSendReceipts\' set to \'1\', but \'SmtpFrom\' missing.')
raise KeyError()
if 'SmtpSubj' not in config:
logger.fatal('\'SmtpSendReceipts\' set to \'1\', but \'SmtpSubj\' missing.')
raise KeyError()
if 'SmtpHost' not in config:
logger.fatal('\'SmtpSendReceipts\' set to \'1\', but \'SmtpHost\' missing.')
raise KeyError()
if 'SmtpPort' not in config:
logger.fatal('\'SmtpSendReceipts\' set to \'1\', but \'SmtpPort\' missing.')
raise KeyError()
if 'SmtpUser' not in config:
logger.fatal('\'SmtpSendReceipts\' set to \'1\', but \'SmtpUser\' missing.')
raise KeyError()
if 'SmtpPass' not in config:
logger.fatal('\'SmtpSendReceipts\' set to \'1\', but \'SmtpPass\' missing.')
raise KeyError()
if 'SmtpEnforceTLS' not in config:
config['SmtpEnforceTLS'] = '1'
logger.warning('Property \'SmtpEnforceTLS\' not set, using \'1\'')
with MatematDatabase(config['DatabaseFile']): with MatematDatabase(config['DatabaseFile']):
# Connect to the database to create it and perform any schema migrations # Connect to the database to create it and perform any schema migrations
pass pass

View file

@ -1,11 +1,11 @@
from typing import Any, Dict, Union from typing import Any, Dict, Optional, Union
import os import os
import magic import magic
from matemat.webserver import pagelet, RequestArguments, PageletResponse, RedirectResponse, TemplateResponse from matemat.webserver import pagelet, RequestArguments, PageletResponse, RedirectResponse, TemplateResponse
from matemat.db import MatematDatabase from matemat.db import MatematDatabase
from matemat.db.primitives import User from matemat.db.primitives import User, ReceiptPreference
from matemat.exceptions import DatabaseConsistencyError, HttpException from matemat.exceptions import DatabaseConsistencyError, HttpException
from matemat.util.currency_format import parse_chf from matemat.util.currency_format import parse_chf
@ -56,7 +56,8 @@ def moduser(method: str,
# Render the "Modify User" page # Render the "Modify User" page
return TemplateResponse('moduser.html', return TemplateResponse('moduser.html',
authuser=authuser, user=user, authlevel=authlevel, authuser=authuser, user=user, authlevel=authlevel,
setupname=config['InstanceName']) receipt_preference_class=ReceiptPreference,
setupname=config['InstanceName'], config_smtp_enabled=config['SmtpSendReceipts'])
def handle_change(args: RequestArguments, user: User, authuser: User, db: MatematDatabase, config: Dict[str, str]) \ def handle_change(args: RequestArguments, user: User, authuser: User, db: MatematDatabase, config: Dict[str, str]) \
@ -87,13 +88,24 @@ def handle_change(args: RequestArguments, user: User, authuser: User, db: Matema
# Admin requested update of the user's details # Admin requested update of the user's details
elif change == 'update': elif change == 'update':
# Only write a change if all properties of the user are present in the request arguments # Only write a change if all properties of the user are present in the request arguments
if 'username' not in args or 'email' not in args or 'password' not in args or 'balance' not in args: if 'username' not in args or \
'email' not in args or \
'password' not in args or \
'balance' not in args or \
'receipt_pref' not in args:
return return
# Read the properties from the request arguments # Read the properties from the request arguments
username = str(args.username) username = str(args.username)
email = str(args.email) email = str(args.email)
try:
receipt_pref = ReceiptPreference(int(str(args.receipt_pref)))
except ValueError:
return
password = str(args.password) password = str(args.password)
balance = parse_chf(str(args.balance)) balance = parse_chf(str(args.balance))
balance_reason: Optional[str] = str(args.reason)
if balance_reason == '':
balance_reason = None
is_member = 'ismember' in args is_member = 'ismember' in args
is_admin = 'isadmin' in args is_admin = 'isadmin' in args
# An empty e-mail field should be interpreted as NULL # An empty e-mail field should be interpreted as NULL
@ -106,7 +118,7 @@ def handle_change(args: RequestArguments, user: User, authuser: User, db: Matema
db.change_password(user, '', password, verify_password=False) db.change_password(user, '', password, verify_password=False)
# Write the user detail changes # Write the user detail changes
db.change_user(user, agent=authuser, name=username, email=email, is_member=is_member, is_admin=is_admin, db.change_user(user, agent=authuser, name=username, email=email, is_member=is_member, is_admin=is_admin,
balance=balance) balance=balance, balance_reason=balance_reason, receipt_pref=receipt_pref)
except DatabaseConsistencyError: except DatabaseConsistencyError:
return return
# If a new avatar was uploaded, process it # If a new avatar was uploaded, process it

View file

@ -0,0 +1,98 @@
from typing import Dict, List, Tuple
import logging
import smtplib as smtp
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from jinja2 import Environment, Template
from matemat.webserver import pagelet_cron
from matemat.db import MatematDatabase
from matemat.db.primitives import User, Receipt
from matemat.util.currency_format import format_chf
@pagelet_cron(hours=6)
def receipt_smtp_cron(config: Dict[str, str],
jinja_env: Environment,
logger: logging.Logger) -> None:
if config['SmtpSendReceipts'] != '1':
# Sending receipts via mail is disabled
return
logger.info('Searching users due for receipts.')
receipts: List[Receipt] = []
# Connect to the database
with MatematDatabase(config['DatabaseFile']) as db:
users: List[User] = db.list_users()
for user in users:
if user.email is None:
logger.debug('User "%s" has no e-mail address.', user.name)
if db.check_receipt_due(user):
logger.info('Generating receipt for user "%s".', user.name)
# Generate receipts that are due
receipt: Receipt = db.create_receipt(user, write=True)
receipts.append(receipt)
else:
logger.debug('No receipt due for user "%s".', user.name)
# Send all generated receipts via e-mailgi
if len(receipts) > 0:
_send_receipt_mails(receipts, jinja_env, logger, config)
def _send_receipt_mails(receipts: List[Receipt],
jinja_env: Environment,
logger: logging.Logger,
config: Dict[str, str]) -> None:
mails: List[Tuple[str, MIMEMultipart]] = []
for receipt in receipts:
if receipt.user.email is None:
continue
# Create a new message object
msg: MIMEMultipart = MIMEMultipart()
msg['From'] = config['SmtpFrom']
msg['To'] = receipt.user.email
msg['Subject'] = config['SmtpSubj']
# Format the receipt properties for the text representation
fdate: str = receipt.from_date.strftime('%d.%m.%Y, %H:%M')
tdate: str = receipt.to_date.strftime('%d.%m.%Y, %H:%M')
username: str = receipt.user.name.rjust(40)
if len(receipt.transactions) == 0:
fbal: str = format_chf(receipt.user.balance).rjust(12)
else:
fbal = format_chf(receipt.transactions[0].old_balance).rjust(12)
tbal: str = format_chf(receipt.user.balance).rjust(12)
# Render the receipt
template: Template = jinja_env.get_template('receipt.txt')
rendered: str = template.render(fdate=fdate, tdate=tdate, user=username, fbal=fbal, tbal=tbal,
receipt_id=receipt.id, transactions=receipt.transactions,
instance_name=config['InstanceName'])
# Put the rendered receipt in the message body
body: MIMEText = MIMEText(rendered)
msg.attach(body)
mails.append((receipt.user.email, msg))
# Connect to the SMTP Server
con: smtp.SMTP = smtp.SMTP(config['SmtpHost'], config['SmtpPort'])
try:
# Attempt to upgrade to a TLS connection
try:
con.starttls()
except BaseException:
# If STARTTLS failed, only continue if explicitly requested by configuration
if config['SmtpEnforceTLS'] != '0':
logger.error('STARTTLS not supported by SMTP server, aborting!')
return
else:
logger.warning('Sending e-mails in plain text as requested by SmtpEnforceTLS=0.')
# Send SMTP login credentials
con.login(config['SmtpUser'], config['SmtpPass'])
# Send the e-mails
for to, msg in mails:
logger.info('Sending mail to %s', to)
con.sendmail(config['SmtpFrom'], to, msg.as_string())
except smtp.SMTPException as e:
logger.exception('Exception while sending receipt e-mails', exc_info=e)
finally:
con.close()

View file

@ -30,6 +30,15 @@ Name=Matemat
UploadDir= /var/test/static/upload UploadDir= /var/test/static/upload
DatabaseFile=/var/test/db/test.db DatabaseFile=/var/test/db/test.db
SmtpSendReceipts=1
SmtpEnforceTLS=0
SmtpFrom=matemat@example.com
SmtpSubj=Matemat Receipt
SmtpHost=smtp.example.com
SmtpPort=587
SmtpUser=matemat@example.com
SmtpPass=SuperSecurePassword
[HttpHeaders] [HttpHeaders]
Content-Security-Policy = default-src: 'self'; Content-Security-Policy = default-src: 'self';
X-I-Am-A-Header = andthisismyvalue X-I-Am-A-Header = andthisismyvalue
@ -42,6 +51,8 @@ Port=443
[Pagelets] [Pagelets]
Name=Matemat (Unit Test 2) Name=Matemat (Unit Test 2)
SmtpSendReceipts=1
[HttpHeaders] [HttpHeaders]
X-I-Am-A-Header = andthisismyothervalue X-I-Am-A-Header = andthisismyothervalue
''' '''
@ -153,6 +164,14 @@ class TestConfig(TestCase):
self.assertEqual('Matemat\n(Unit Test)', config['pagelet_variables']['Name']) self.assertEqual('Matemat\n(Unit Test)', config['pagelet_variables']['Name'])
self.assertEqual('/var/test/static/upload', config['pagelet_variables']['UploadDir']) self.assertEqual('/var/test/static/upload', config['pagelet_variables']['UploadDir'])
self.assertEqual('/var/test/db/test.db', config['pagelet_variables']['DatabaseFile']) self.assertEqual('/var/test/db/test.db', config['pagelet_variables']['DatabaseFile'])
self.assertEqual('1', config['pagelet_variables']['SmtpSendReceipts'])
self.assertEqual('0', config['pagelet_variables']['SmtpEnforceTLS'])
self.assertEqual('matemat@example.com', config['pagelet_variables']['SmtpFrom'])
self.assertEqual('Matemat Receipt', config['pagelet_variables']['SmtpSubj'])
self.assertEqual('smtp.example.com', config['pagelet_variables']['SmtpHost'])
self.assertEqual('587', config['pagelet_variables']['SmtpPort'])
self.assertEqual('matemat@example.com', config['pagelet_variables']['SmtpUser'])
self.assertEqual('SuperSecurePassword', config['pagelet_variables']['SmtpPass'])
self.assertIsInstance(config['headers'], dict) self.assertIsInstance(config['headers'], dict)
self.assertEqual(2, len(config['headers'])) self.assertEqual(2, len(config['headers']))
self.assertEqual('default-src: \'self\';', config['headers']['Content-Security-Policy']) self.assertEqual('default-src: \'self\';', config['headers']['Content-Security-Policy'])

View file

@ -0,0 +1,65 @@
from typing import Dict
import unittest
import logging
from threading import Lock, Thread, Timer
from time import sleep
import jinja2
from matemat.webserver import MatematWebserver, pagelet_cron
lock: Lock = Lock()
cron1called: int = 0
cron2called: int = 0
@pagelet_cron(seconds=4)
def cron1(config: Dict[str, str],
jinja_env: jinja2.Environment,
logger: logging.Logger) -> None:
global cron1called
with lock:
cron1called += 1
@pagelet_cron(seconds=3)
def cron2(config: Dict[str, str],
jinja_env: jinja2.Environment,
logger: logging.Logger) -> None:
global cron2called
with lock:
cron2called += 1
class TestPageletCron(unittest.TestCase):
def setUp(self):
self.srv = MatematWebserver('::1', 0, '/nonexistent', '/nonexistent', {}, {},
logging.NOTSET, logging.NullHandler())
self.srv_port = int(self.srv._httpd.socket.getsockname()[1])
self.timer = Timer(10.0, self.srv._httpd.shutdown)
self.timer.start()
def tearDown(self):
self.timer.cancel()
if self.srv is not None:
self.srv._httpd.socket.close()
def test_cron(self):
"""
Test that the cron functions are called properly.
"""
thread = Thread(target=self.srv.start)
thread.start()
sleep(12)
self.srv._httpd.shutdown()
with lock:
self.assertEqual(2, cron1called)
self.assertEqual(3, cron2called)
# Make sure the cron threads were stopped
sleep(5)
with lock:
self.assertEqual(2, cron1called)
self.assertEqual(3, cron2called)

View file

@ -8,6 +8,15 @@
<label for="admin-myaccount-email">E-Mail: </label> <label for="admin-myaccount-email">E-Mail: </label>
<input id="admin-myaccount-email" type="text" name="email" value="{% if authuser.email is not none %}{{ authuser.email }}{% endif %}" /><br/> <input id="admin-myaccount-email" type="text" name="email" value="{% if authuser.email is not none %}{{ authuser.email }}{% endif %}" /><br/>
<label for="admin-myaccount-receipt-pref">Receipts: </label>
<select id="admin-myaccount-receipt-pref" name="receipt_pref">
{% for pref in receipt_preference_class %}
<option value="{{ pref.value }}" {% if authuser.receipt_pref == pref %} selected="selected" {% endif %}>{{ pref.human_readable }}</option>
{% endfor %}
</select>
{% if config_smtp_enabled != '1' %}Sending receipts is disabled by your administrator.{% endif %}
<br/>
<label for="admin-myaccount-ismember">Member: </label> <label for="admin-myaccount-ismember">Member: </label>
<input id="admin-myaccount-ismember" type="checkbox" disabled="disabled" {% if authuser.is_member %} checked="checked" {% endif %}/><br/> <input id="admin-myaccount-ismember" type="checkbox" disabled="disabled" {% if authuser.is_member %} checked="checked" {% endif %}/><br/>

View file

@ -20,6 +20,15 @@
<label for="moduser-account-password">Password: </label> <label for="moduser-account-password">Password: </label>
<input id="moduser-account-password" type="password" name="password" /><br/> <input id="moduser-account-password" type="password" name="password" /><br/>
<label for="moduser-account-receipt-pref">Receipts: </label>
<select id="moduser-account-receipt-pref" name="receipt_pref">
{% for pref in receipt_preference_class %}
<option value="{{ pref.value }}" {% if user.receipt_pref == pref %} selected="selected" {% endif %}>{{ pref.human_readable }}</option>
{% endfor %}
</select>
{% if config_smtp_enabled != '1' %}Sending receipts is disabled by your administrator.{% endif %}
<br/>
<label for="moduser-account-ismember">Member: </label> <label for="moduser-account-ismember">Member: </label>
<input id="moduser-account-ismember" name="ismember" type="checkbox" {% if user.is_member %} checked="checked" {% endif %}/><br/> <input id="moduser-account-ismember" name="ismember" type="checkbox" {% if user.is_member %} checked="checked" {% endif %}/><br/>
@ -29,6 +38,9 @@
<label for="moduser-account-balance">Balance: </label> <label for="moduser-account-balance">Balance: </label>
CHF <input id="moduser-account-balance" name="balance" type="number" step="0.01" value="{{ user.balance|chf(False) }}" /><br/> CHF <input id="moduser-account-balance" name="balance" type="number" step="0.01" value="{{ user.balance|chf(False) }}" /><br/>
<label for="moduser-account-balance-reason">Reason for balance modification: </label>
<input id="moduser-account-balance-reason" type="text" name="reason" placeholder="Shows up on receipt" /><br/>
<label for="moduser-account-avatar"> <label for="moduser-account-avatar">
<img height="150" src="/upload/thumbnails/users/{{ user.id }}.png" alt="Avatar of {{ user.name }}" /> <img height="150" src="/upload/thumbnails/users/{{ user.id }}.png" alt="Avatar of {{ user.name }}" />
</label><br/> </label><br/>

26
templates/receipt.txt Normal file
View file

@ -0,0 +1,26 @@
===================================================================
MATEMAT RECEIPT
===================================================================
User: {{ user|safe }}
Accounting period: {{ fdate|safe }} -- {{ tdate|safe }}
Opening balance: {{ fbal|safe }}
Transactions:
{% for t in transactions %}
{% include 'transaction.txt' %}
{% endfor %}
------------
Closing balance: {{ tbal|safe }}
===================================================================
{{ instance_name|striptags }}{% if receipt_id > 1 %}
Receipt N° {{ receipt_id|safe }}{% endif %}
This receipt is only provided for informational purposes and has no
legal force.

View file

@ -0,0 +1,2 @@
{{ t.receipt_date|safe }} {{ t.receipt_description.ljust(36)|safe }} {% if t.receipt_message is none %}{{ t.receipt_value.rjust(8)|safe }}{% else %}
{{ t.receipt_message.ljust(36)|safe }} {{ t.receipt_value.rjust(8)|safe }}{% endif %}