diff --git a/matemat/db/facade.py b/matemat/db/facade.py index 9d4d924..b166a32 100644 --- a/matemat/db/facade.py +++ b/matemat/db/facade.py @@ -577,6 +577,8 @@ class MatematDatabase(object): user.balance = old_balance + amount def check_receipt_due(self, user: User) -> bool: + if not isinstance(user.receipt_pref, ReceiptPreference): + raise TypeError() if user.receipt_pref == ReceiptPreference.NONE or user.email is None: return False with self.db.transaction() as c: @@ -587,14 +589,10 @@ class MatematDatabase(object): ON r.user_id = u.user_id WHERE u.user_id = :user_id ''', [user.id]) - last_receipt: int = c.fetchone()[0] - if user.receipt_pref == ReceiptPreference.MONTHLY: - date_diff: int = timedelta(days=31).total_seconds() - elif user.receipt_pref == ReceiptPreference.YEARLY: - date_diff = timedelta(days=365).total_seconds() - else: - raise ValueError() - return datetime.utcnow().timestamp() > last_receipt + date_diff + last_receipt: datetime = datetime.fromtimestamp(c.fetchone()[0]) + next_receipt_due: datetime = user.receipt_pref.next_receipt_due(last_receipt) + + return datetime.utcnow() > next_receipt_due def create_receipt(self, user: User, write: bool = False) -> Receipt: transactions: List[Transaction] = [] diff --git a/matemat/db/primitives/ReceiptPreference.py b/matemat/db/primitives/ReceiptPreference.py index 77120f1..fb15f18 100644 --- a/matemat/db/primitives/ReceiptPreference.py +++ b/matemat/db/primitives/ReceiptPreference.py @@ -1,5 +1,10 @@ +from typing import Callable + from enum import Enum +from datetime import datetime, timedelta +from matemat.util.monthdelta import add_months + class ReceiptPreference(Enum): """ @@ -10,8 +15,10 @@ class ReceiptPreference(Enum): e = object.__new__(cls) # The enum's internal value e._value_: int = args[0] + # The function calculating the date after which a new receipt is due. + e._datefunc: Callable[[datetime], datetime] = args[1] # The human-readable description - e._human_readable: str = args[1] + e._human_readable: str = args[2] return e @property @@ -21,17 +28,35 @@ class ReceiptPreference(Enum): """ return self._human_readable + def next_receipt_due(self, d: datetime) -> datetime: + return self._datefunc(d) + """ No receipts should be generated. """ - NONE = 0, 'No receipts' + NONE = 0, (lambda d: None), 'No receipts' + + """ + A receipt should be generated once a week. + """ + WEEKLY = 1, (lambda d: d + timedelta(weeks=1)), 'Weekly' """ A receipt should be generated once a month. """ - MONTHLY = 2, 'Aggregated, monthly' + MONTHLY = 2, (lambda d: add_months(d, 1)), 'Monthly' + + """ + A receipt should be generated once every three month. + """ + QUARTERLY = 3, (lambda d: add_months(d, 3)), 'Quarterly' + + """ + A receipt should be generated once every six month. + """ + BIANNUALLY = 4, (lambda d: add_months(d, 6)), 'Biannually' """ A receipt should be generated once a year. """ - YEARLY = 3, 'Aggregated, yearly' + YEARLY = 5, (lambda d: add_months(d, 12)), 'Annually' diff --git a/matemat/db/test/test_facade.py b/matemat/db/test/test_facade.py index d7b71bc..70cb4dc 100644 --- a/matemat/db/test/test_facade.py +++ b/matemat/db/test/test_facade.py @@ -486,7 +486,7 @@ class DatabaseTest(unittest.TestCase): self.assertFalse(self.db.check_receipt_due(user4)) self.assertTrue(self.db.check_receipt_due(user5)) self.assertFalse(self.db.check_receipt_due(user6)) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): self.db.check_receipt_due(user7) def test_create_receipt(self): diff --git a/matemat/util/monthdelta.py b/matemat/util/monthdelta.py new file mode 100644 index 0000000..49c1fa0 --- /dev/null +++ b/matemat/util/monthdelta.py @@ -0,0 +1,29 @@ + +from typing import Tuple + +from datetime import datetime, timedelta +import calendar + + +def add_months(d: datetime, months: int) -> datetime: + """ + Add the given number of months to the passed date, considering the varying numbers of days in a month. + :param d: The date time to add to. + :param months: The number of months to add to. + :return: A datetime object offset by the requested number of months. + """ + if not isinstance(d, datetime) or not isinstance(months, int): + raise TypeError() + if months < 0: + raise ValueError('Can only add a positive number of months.') + nextmonth: Tuple[int, int] = (d.year, d.month) + days: int = 0 + # Iterate the months between the passed date and the target month + for i in range(months): + days += calendar.monthlen(*nextmonth) + nextmonth = calendar.nextmonth(*nextmonth) + # Set the day of month temporarily to 1, then add the day offset to reach the 1st of the target month + newdate: datetime = d.replace(day=1) + timedelta(days=days) + # Re-set the day of month to the intended value, but capped by the max. day in the target month + newdate = newdate.replace(day=min(d.day, calendar.monthlen(newdate.year, newdate.month))) + return newdate diff --git a/matemat/util/test/test_monthdelta.py b/matemat/util/test/test_monthdelta.py new file mode 100644 index 0000000..98def9a --- /dev/null +++ b/matemat/util/test/test_monthdelta.py @@ -0,0 +1,61 @@ + +import unittest + +from datetime import datetime + +from matemat.util.monthdelta import add_months + + +class TestMonthDelta(unittest.TestCase): + + def test_monthdelta_zero(self): + date = datetime(2018, 9, 8, 13, 37, 42) + offset_date = date + self.assertEqual(offset_date, add_months(date, 0)) + + def test_monthdelta_one(self): + date = datetime(2018, 9, 8, 13, 37, 42) + offset_date = date.replace(month=10) + self.assertEqual(offset_date, add_months(date, 1)) + + def test_monthdelta_two(self): + date = datetime(2018, 9, 8, 13, 37, 42) + offset_date = date.replace(month=11) + self.assertEqual(offset_date, add_months(date, 2)) + + def test_monthdelta_yearwrap(self): + date = datetime(2018, 9, 8, 13, 37, 42) + offset_date = date.replace(year=2019, month=1) + self.assertEqual(offset_date, add_months(date, 4)) + + def test_monthdelta_yearwrap_five(self): + date = datetime(2018, 9, 8, 13, 37, 42) + offset_date = date.replace(year=2023, month=3) + self.assertEqual(offset_date, add_months(date, 54)) + + def test_monthdelta_rounddown_31_30(self): + date = datetime(2018, 3, 31, 13, 37, 42) + offset_date = date.replace(month=4, day=30) + self.assertEqual(offset_date, add_months(date, 1)) + + def test_monthdelta_rounddown_feb(self): + date = datetime(2018, 1, 31, 13, 37, 42) + offset_date = date.replace(month=2, day=28) + self.assertEqual(offset_date, add_months(date, 1)) + + def test_monthdelta_rounddown_feb_leap(self): + date = datetime(2020, 1, 31, 13, 37, 42) + offset_date = date.replace(month=2, day=29) + self.assertEqual(offset_date, add_months(date, 1)) + + def test_fail_negative(self): + date = datetime(2020, 1, 31, 13, 37, 42) + with self.assertRaises(ValueError): + add_months(date, -1) + + def test_fail_type(self): + date = datetime(2020, 1, 31, 13, 37, 42) + with self.assertRaises(TypeError): + add_months(date, 1.2) + with self.assertRaises(TypeError): + add_months(42, 1)