diff --git a/matemat/db/database.py b/matemat/db/database.py index 1b49c76..5c41300 100644 --- a/matemat/db/database.py +++ b/matemat/db/database.py @@ -254,7 +254,7 @@ class Database(object): FROM products '''): product_id, name, price_member, price_external = row - products.append(Product(product_id, name, price_member, price_external)) + products.append(Product.from_database(product_id, name, price_member, price_external)) return products def create_product(self, name: str, price_member: int, price_non_member: int) -> Product: @@ -273,9 +273,11 @@ class Database(object): }) c.execute('SELECT last_insert_rowid()') product_id = int(c.fetchone()[0]) - return Product(product_id, name, price_member, price_non_member) + return Product.from_database(product_id, name, price_member, price_non_member) def change_product(self, product: Product): + if product.id == -1: + raise ValueError('Invalid product ID') with self.transaction() as c: c.execute(''' UPDATE products @@ -291,6 +293,8 @@ class Database(object): }) def delete_product(self, product: Product): + if product.id == -1: + raise ValueError('Invalid product ID') with self.transaction() as c: c.execute(''' DELETE FROM products @@ -298,6 +302,8 @@ class Database(object): ''', [product.id]) def increment_consumption(self, user: User, product: Product, count: int = 1): + if product.id == -1: + raise ValueError('Invalid product ID') with self.transaction() as c: c.execute(''' SELECT count @@ -342,6 +348,8 @@ class Database(object): }) def restock(self, product: Product, count: int): + if product.id == -1: + raise ValueError('Invalid product ID') with self.transaction() as c: c.execute(''' UPDATE products diff --git a/matemat/primitives/Product.py b/matemat/primitives/Product.py index 9a97050..a761064 100644 --- a/matemat/primitives/Product.py +++ b/matemat/primitives/Product.py @@ -1,28 +1,24 @@ -from inspect import stack - - class Product(object): - def __init__(self): - print(stack()) - self._product_id: int = 0 - self._name: str = '' - self._price_member: int = 0 - self._price_non_member: int = 0 - raise NotImplementedError('This shoudt not be called!') + def __init__(self, + product_id: int = -1, + name: str = '', + price_member: int = 0, + price_non_member: int = 0): + self._product_id: int = product_id + self._name: str = name + self._price_member: int = price_member + self._price_non_member: int = price_non_member - # def __init__(self, - # product_id: int, - # name: str, - # price_member: int, - # price_non_member: int): - # if product_id == -1: - # raise ValueError('Invalid product ID') - # self._product_id: int = product_id - # self._name: str = name - # self._price_member: int = price_member - # self._price_non_member: int = price_non_member + @classmethod + def from_database(cls, + product_id: int, + name: str, + price_member: int, + price_non_member: int) -> 'Product': + product = cls(product_id, name, price_member, price_non_member) + return product @property def id(self) -> int: