class Setting(Model):
    """Key/value application setting stored in ``management_setting``.

    ``key`` is the primary key of the table.
    """

    __tablename__ = "management_setting"

    # NOTE(review): presumably masks an ``id`` column inherited from Model so
    # that ``key`` is the only primary key -- confirm against Model.
    id = None
    key = Column(db.String(255), primary_key=True)
    value = Column(db.PickleType, nullable=False)
    name = Column(db.String(255), nullable=False)
    description = Column(db.Text, nullable=False)
    value_type = Column(db.Enum(SettingValueType), nullable=False)
    extra = Column(db.PickleType)

    @classmethod
    def get_settings(cls):
        """Return every stored setting as a ``{key: value}`` dictionary."""
        return {setting.key: setting.value for setting in cls.query.all()}

    @classmethod
    def update(cls, settings):
        """Store new values for existing settings in the database.

        Every key is looked up (lower-cased) before any row is modified, so
        an unknown key aborts the whole update instead of failing half-way
        through with an ``AttributeError`` on ``None``.

        :param settings: A dictionary with setting items (key -> new value).
        :raises KeyError: If a key does not match a stored setting.
        """
        resolved = []
        for key, value in settings.items():
            setting = cls.query.filter(Setting.key == key.lower()).first()
            if setting is None:
                raise KeyError(key)
            resolved.append((setting, value))

        # update the database in a single transaction
        for setting, value in resolved:
            setting.value = value
            db.session.add(setting)
        db.session.commit()

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.key}>"
class Conversation(Model):
    """A conversation entry stored in ``conversation_dialogue``."""

    __tablename__ = "conversation_dialogue"

    # this row is one user's message-box entry; user_id is that user
    user_id = Column(db.Integer(), nullable=False)
    # presumably the sending user (resolved by ``from_user``)
    from_user_id = Column(db.Integer())
    # presumably the receiving user (resolved by ``to_user``)
    to_user_id = db.Column(db.Integer())
    shared_id = db.Column(db.String(100), nullable=False)
    subject = db.Column(db.String(255))
    trash = db.Column(db.Boolean, default=False, nullable=False)
    draft = db.Column(db.Boolean, default=False, nullable=False)
    unread = db.Column(db.Boolean, default=False, nullable=False)

    @property
    def last_message(self):
        """Return the message with the highest id in this conversation."""
        newest_first = Message.query.filter_by(conversation_id=self.id).order_by(
            Message.id.desc()
        )
        return newest_first.first()

    @property
    def from_user(self):
        """Return the ``User`` whose id equals ``from_user_id`` (or None)."""
        matching = User.query.filter_by(id=self.from_user_id)
        return matching.first()

    @property
    def to_user(self):
        """Return the ``User`` whose id equals ``to_user_id`` (or None)."""
        matching = User.query.filter_by(id=self.to_user_id)
        return matching.first()
class DashboardMenu(Model):
    """Entry of the dashboard menu, stored in ``management_dashboard``."""

    __tablename__ = "management_dashboard"

    title = Column(db.String(255), nullable=False)
    order = Column(db.Integer(), default=0)
    endpoint = Column(db.String(255))
    icon_cls = Column(db.String(255))
    # 0 is what ``first_level_items`` treats as "no parent"
    parent_id = Column(db.Integer(), default=0)

    def __str__(self):
        return self.title

    @property
    def children(self):
        """Return the entries whose ``parent_id`` is this entry's id."""
        direct_children = DashboardMenu.query.filter(
            DashboardMenu.parent_id == self.id
        )
        return direct_children.all()

    @classmethod
    def first_level_items(cls):
        """Return all entries with ``parent_id == 0`` sorted by ``order``."""
        top_level = cls.query.filter(cls.parent_id == 0)
        return top_level.order_by("order").all()

    def is_active(self):
        """Tell whether this entry, or any descendant, matches the request.

        An entry matches when its endpoint is a substring of ``request.path``.
        """
        matches_request = bool(self.endpoint) and self.endpoint in request.path
        return matches_request or any(child.is_active() for child in self.children)

    def get_url(self):
        """Return ``"#"`` for entries with children, else the endpoint URL.

        Returns None when the entry has neither children nor an endpoint.
        """
        if self.children:
            return "#"
        if not self.endpoint:
            return None
        return url_for("dashboard." + self.endpoint)
class Page(Model):
    """Public content page, stored in ``public_page``."""

    __tablename__ = "public_page"

    title = Column(db.String(255), nullable=False)
    slug = Column(db.String(255))
    content = Column(db.Text())
    is_visible = Column(db.Boolean(), default=True)
    # When Redis is enabled the content attribute is replaced by a PropsItem.
    if Config.USE_REDIS:
        content = PropsItem("content", "")

    def get_absolute_url(self):
        """Return the public URL, identified by slug when set, else by id."""
        return url_for("public.show_page", identity=self.slug or self.id)

    @classmethod
    @cache(MC_KEY_PAGE_ID.format("{identity}"))
    def get_by_identity(cls, identity):
        """Look a page up by id when ``identity`` parses as int, else by slug."""
        try:
            int(identity)
        except ValueError:
            by_slug = Page.query.filter(Page.slug == identity)
            return by_slug.first()
        else:
            return Page.get_by_id(identity)

    @property
    def url(self):
        """Shortcut for :meth:`get_absolute_url`."""
        return self.get_absolute_url()

    def __str__(self):
        return self.title

    @classmethod
    def __flush_after_update_event__(cls, target):
        """Drop the cached lookups for both the id and the slug of ``target``."""
        super().__flush_after_update_event__(target)
        for identity in (target.id, target.slug):
            rdb.delete(MC_KEY_PAGE_ID.format(identity))
class Collection(Model):
    """Product collection, stored in ``product_collection``."""

    __tablename__ = "product_collection"

    title = Column(db.String(255), nullable=False)
    # path of the image relative to the static directory (see ``delete``)
    background_img = Column(db.String(255))

    def __str__(self):
        return self.title

    def get_absolute_url(self):
        """Return the URL of the collection page."""
        return url_for("product.show_collection", id=self.id)

    @property
    def background_img_url(self):
        """Return the static URL of the background image."""
        return url_for("static", filename=self.background_img)

    @property
    def products(self):
        """Return the products linked to this collection."""
        link_rows = (
            ProductCollection.query.with_entities(ProductCollection.product_id)
            .filter_by(collection_id=self.id)
            .all()
        )
        linked = Product.query.filter(
            Product.id.in_(product_id for product_id, in link_rows)
        )
        return linked.all()

    @property
    def attr_filter(self):
        """Return the set of product attributes used by the linked products."""
        return {
            attr
            for product in self.products
            for attr in product.product_type.product_attributes
        }

    def update_products(self, new_products):
        """Make the linked products equal to ``new_products`` (ids).

        Links that are no longer wanted are deleted, missing ones are added,
        and everything is committed at the end.
        """
        link_rows = (
            ProductCollection.query.with_entities(ProductCollection.product_id)
            .filter_by(collection_id=self.id)
            .all()
        )
        current_ids = set(product_id for product_id, in link_rows)
        wanted_ids = set(int(raw_id) for raw_id in new_products)

        for product_id in current_ids - wanted_ids:
            stale_link = ProductCollection.query.filter_by(
                collection_id=self.id, product_id=product_id
            ).first()
            stale_link.delete(commit=False)
        for product_id in wanted_ids - current_ids:
            db.session.add(
                ProductCollection(collection_id=self.id, product_id=product_id)
            )
        db.session.commit()

    def delete(self):
        """Delete the collection, its product links and its image file."""
        links = ProductCollection.query.filter_by(collection_id=self.id).all()
        for link in links:
            link.delete(commit=False)
        db.session.delete(self)
        db.session.commit()

        # The file is removed only after the database commit went through.
        if not self.background_img:
            return
        image_path = current_app.config["STATIC_DIR"] / self.background_img
        if image_path.exists():
            image_path.unlink()
class OrderLine(Model):
    """One line of an order, stored in ``order_line``."""

    __tablename__ = "order_line"

    product_name = Column(db.String(255))
    product_sku = Column(db.String(100))
    quantity = Column(db.Integer())
    unit_price_net = Column(db.DECIMAL(10, 2))
    is_shipping_required = Column(db.Boolean(), default=True)
    order_id = Column(db.Integer())
    variant_id = Column(db.Integer())

    @property
    def variant(self):
        """Return the ``ProductVariant`` referenced by ``variant_id``."""
        referenced = ProductVariant.get_by_id(self.variant_id)
        return referenced

    def get_total(self):
        """Return the unit price multiplied by the quantity."""
        line_total = self.unit_price_net * self.quantity
        return line_total
class ShippingMethod(Model):
    """Shipping option, stored in ``checkout_shippingmethod``."""

    __tablename__ = "checkout_shippingmethod"

    title = Column(db.String(255), nullable=False)
    price = Column(db.DECIMAL(10, 2))

    def __str__(self):
        """Render as ``"<title> $<price>"``."""
        return "".join((self.title, " $", str(self.price)))
class ProductImage(Model):
    """Image attached to a product, stored in ``product_image``."""

    __tablename__ = "product_image"

    # path of the file relative to the static directory
    image = Column(db.String(255))
    order = Column(db.Integer())
    product_id = Column(db.Integer())

    def __str__(self):
        """Return the absolute static URL of the image."""
        return url_for("static", filename=self.image, _external=True)

    @staticmethod
    def clear_mc(target):
        """Drop the cached image list of the product ``target`` belongs to."""
        cache_key = MC_KEY_PRODUCT_IMAGES.format(target.product_id)
        rdb.delete(cache_key)

    @classmethod
    def __flush_insert_event__(cls, target):
        """Invalidate the product's image cache after an insert."""
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        """Invalidate the cache and remove the image file from disk."""
        super().__flush_delete_event__(target)
        target.clear_mc(target)
        stored_file = current_app.config["STATIC_DIR"] / target.image
        if not stored_file.exists():
            return
        stored_file.unlink()
class PluginRegistry(Model):
    """Registered plugin and its enabled flag, stored in ``plugin_registry``."""

    __tablename__ = "plugin_registry"

    name = Column(db.String(100), unique=True)
    enabled = Column(db.Boolean(), default=True)

    @property
    def info(self):
        """Return the plugin's metadata dict, or ``{}`` when none is known."""
        metadata_by_name = current_app.pluggy.plugin_metadata
        return metadata_by_name.get(self.name, {})
class OrderPayment(Model):
    """Payment record of an order, stored in ``order_payment``."""

    __tablename__ = "order_payment"

    order_id = Column(db.Integer())
    # holds a ``PaymentStatusKinds`` value (see ``pay_success``)
    status = Column(TINYINT)
    total = Column(db.DECIMAL(10, 2))
    delivery = Column(db.DECIMAL(10, 2))
    description = Column(db.Text())
    customer_ip_address = Column(db.String(100))
    token = Column(db.String(100))
    payment_method = Column(db.String(255))
    payment_no = Column(db.String(255), unique=True)
    paid_at = Column(db.DateTime())

    def pay_success(self, paid_at):
        """Mark the payment confirmed at ``paid_at`` and notify its order."""
        self.status = PaymentStatusKinds.confirmed.value
        self.paid_at = paid_at
        paid_order = Order.get_by_id(self.order_id)
        paid_order.pay_success(payment=self)
class AttributeChoiceValue(Model):
    """Selectable value of a product attribute (``product_attribute_value``)."""

    __tablename__ = "product_attribute_value"

    title = Column(db.String(255), nullable=False)
    attribute_id = Column(db.Integer())

    def __str__(self):
        return self.title

    @property
    def attribute(self):
        """Return the ``ProductAttribute`` referenced by ``attribute_id``."""
        owner = ProductAttribute.get_by_id(self.attribute_id)
        return owner
class MenuItem(Model):
    """Entry of a public site menu, stored in ``public_menuitem``."""

    __tablename__ = "public_menuitem"

    title = Column(db.String(255), nullable=False)
    order = Column(db.Integer(), default=0)
    # explicit URL; mapped to the "url" column
    url_ = Column("url", db.String(255))
    category_id = Column(db.Integer(), default=0)
    collection_id = Column(db.Integer(), default=0)
    # position of the item on the site: 1 is top, 2 is bottom
    position = Column(db.Integer(), default=0)
    page_id = Column(db.Integer(), default=0)
    parent_id = Column(db.Integer(), default=0)

    def __str__(self):
        return self.title

    @property
    def parent(self):
        """Return the ``MenuItem`` referenced by ``parent_id``."""
        return MenuItem.get_by_id(self.parent_id)

    @property
    @cache(MC_KEY_MENU_ITEM_CHILDREN.format("{self.id}"))
    def children(self):
        """Return the items whose parent is this item, sorted by ``order``."""
        direct_children = MenuItem.query.filter(MenuItem.parent_id == self.id)
        return direct_children.order_by("order").all()

    @property
    def linked_object_url(self):
        """Return the URL of the linked page, category or collection.

        The first non-empty reference wins, in that order; None when the
        item links to nothing.
        """
        if self.page_id:
            return Page.get_by_id(self.page_id).url
        if self.category_id:
            return url_for("product.show_category", id=self.category_id)
        if self.collection_id:
            return url_for("product.show_collection", id=self.collection_id)
        return None

    @property
    def url(self):
        """Return the explicit URL when set, else the linked object's URL."""
        return self.url_ or self.linked_object_url

    @classmethod
    def first_level_items(cls):
        """Return all items with ``parent_id == 0`` sorted by ``order``."""
        top_level = cls.query.filter(cls.parent_id == 0)
        return top_level.order_by("order").all()
class OrderPayment(Model):
    """Payment record of an order, stored in ``order_payment``."""

    __tablename__ = "order_payment"

    order_id = Column(db.Integer())
    # holds a ``PaymentStatusKinds`` value (see ``status_human``)
    status = Column(db.Integer)
    total = Column(db.DECIMAL(10, 2))
    delivery = Column(db.DECIMAL(10, 2))
    description = Column(db.Text())
    customer_ip_address = Column(db.String(100))
    token = Column(db.String(100))
    payment_method = Column(db.String(255))
    payment_no = Column(db.String(255), unique=True)
    paid_at = Column(db.DateTime())

    def pay_success(self, paid_at):
        """Mark the payment confirmed at ``paid_at`` and notify its order.

        The payment itself is saved without committing before the order's
        ``pay_success`` is called.
        """
        self.status = PaymentStatusKinds.confirmed.value
        self.paid_at = paid_at
        self.save(commit=False)
        paid_order = Order.get_by_id(self.order_id)
        paid_order.pay_success(payment=self)

    @property
    def status_human(self):
        """Return the ``PaymentStatusKinds`` member name for ``status``."""
        status_kind = PaymentStatusKinds(int(self.status))
        return status_kind.name
class Site(Model):
    """Site-wide public settings, stored in ``public_setting``."""

    __tablename__ = "public_setting"

    header_text = Column(db.String(255), nullable=False)
    description = Column(db.Text())
    top_menu_id = Column(db.Integer())
    bottom_menu_id = Column(db.Integer())

    @cache(MC_KEY_MENU_ITEMS.format("{self.id}", "{menu_id}"))
    def get_menu_items(self, menu_id):
        """Return the first-level menu items of menu ``menu_id``, by order.

        ``MenuItem`` declares no ``site_id`` column; the menu an item belongs
        to is stored in ``MenuItem.position``, so that column is filtered on.

        :param menu_id: Value matched against ``MenuItem.position``.
        :return: List of ``MenuItem`` rows with ``parent_id == 0``.
        """
        first_level = MenuItem.query.filter(MenuItem.position == menu_id).filter(
            MenuItem.parent_id == 0
        )
        return first_level.order_by(MenuItem.order).all()

    @property
    def top_menu_items(self):
        """Return the menu items selected by ``top_menu_id``."""
        return self.get_menu_items(self.top_menu_id)

    @property
    def bottom_menu_items(self):
        """Return the menu items selected by ``bottom_menu_id``."""
        return self.get_menu_items(self.bottom_menu_id)
class UserAddress(Model):
    """Postal address of a user, stored in ``account_address``."""

    __tablename__ = "account_address"

    user_id = Column(db.Integer())
    province = Column(db.String(255))
    city = Column(db.String(255))
    district = Column(db.String(255))
    address = Column(db.String(255))
    contact_name = Column(db.String(255))
    contact_phone = Column(db.String(80))

    @property
    def full_address(self):
        """Return the address as an HTML snippet with ``<br>`` separators."""
        rendered = f"{self.province}{self.city}{self.district}<br>{self.address}<br>{self.contact_name}<br>{self.contact_phone}"
        return rendered

    @hybrid_property
    def user(self):
        """Return the ``User`` referenced by ``user_id``."""
        owner = User.get_by_id(self.user_id)
        return owner

    def __str__(self):
        return self.full_address
class User(Model, UserMixin):
    """Shop account, stored in ``account_user``."""

    __tablename__ = "account_user"

    username = Column(db.String(80), unique=True, nullable=False, comment="user`s name")
    email = Column(db.String(80), unique=True, nullable=False)
    #: The hashed password
    _password = Column("password", db.String(128))
    nick_name = Column(db.String(255))
    is_active = Column(db.Boolean(), default=False)
    open_id = Column(db.String(80), index=True)
    session_key = Column(db.String(80), index=True)

    def __init__(self, username, email, password, **kwargs):
        """Create a user; ``password`` is hashed by the ``password`` setter."""
        super().__init__(username=username, email=email, password=password, **kwargs)

    def __str__(self):
        return self.username

    @hybrid_property
    def password(self):
        """Return the stored password hash."""
        return self._password

    @password.setter
    def password(self, value):
        """Hash ``value`` and store the hash as text.

        ``generate_password_hash`` returns ``bytes`` on Python 3; storing
        bytes in a String column is backend dependent, so it is decoded.
        """
        hashed = bcrypt.generate_password_hash(value)
        if isinstance(hashed, bytes):
            hashed = hashed.decode("utf-8")
        self._password = hashed

    @property
    def avatar(self):
        """Return the Gravatar image for the user's e-mail address."""
        return Gravatar(self.email).get_image()

    def check_password(self, value):
        """Check password."""
        return bcrypt.check_password_hash(self.password, value)

    @property
    def addresses(self):
        """Return a query over this user's ``UserAddress`` rows."""
        return UserAddress.query.filter_by(user_id=self.id)

    @property
    def is_active_human(self):
        """Return ``"Y"`` when the account is active, else ``"N"``."""
        return "Y" if self.is_active else "N"

    @property
    def roles(self):
        """Return the ``Role`` rows linked to this user through ``UserRole``."""
        link_rows = (
            UserRole.query.with_entities(UserRole.role_id)
            .filter_by(user_id=self.id)
            .all()
        )
        role_ids = [role_id for role_id, in link_rows]
        return Role.query.filter(Role.id.in_(role_ids)).all()

    def delete(self):
        """Delete the user's addresses, then the user itself."""
        for addr in self.addresses:
            addr.delete()
        return super().delete()

    def can(self, permissions):
        """Tell whether the user's roles together grant ``permissions``.

        :param permissions: Bit mask of the required permissions.
        :return: True when every required bit is set in the union of the
            permissions of all roles; False when the user has no role.
        """
        # ``roles`` runs database queries on every access -- read it once.
        roles = self.roles
        if not roles:
            return False
        all_perms = reduce(or_, (role.permissions for role in roles))
        return all_perms & permissions == permissions

    def can_admin(self):
        """Tell whether the user holds ``Permission.ADMINISTER``."""
        return self.can(Permission.ADMINISTER)

    def can_edit(self):
        """Tell whether the user holds ``Permission.EDITOR``."""
        return self.can(Permission.EDITOR)
class Order(Model):
    """Customer order, stored in ``order_order``."""

    __tablename__ = "order_order"
    token = Column(db.String(100), unique=True)
    shipping_address = Column(db.String(255))
    user_id = Column(db.Integer())
    total_net = Column(db.DECIMAL(10, 2))
    discount_amount = Column(db.DECIMAL(10, 2), default=0)
    discount_name = Column(db.String(100))
    voucher_id = Column(db.Integer())
    shipping_price_net = Column(db.DECIMAL(10, 2))
    # holds an ``OrderStatusKinds`` value
    status = Column(TINYINT())
    shipping_method_name = Column(db.String(100))
    shipping_method_id = Column(db.Integer())
    # holds a ``ShipStatusKinds`` value (see ``delivered``)
    ship_status = Column(TINYINT())

    def __str__(self):
        return f"#{self.identity}"

    @classmethod
    def create_whole_order(cls, cart, note=None):
        """Turn ``cart`` into an order for the current user.

        :param cart: Cart whose lines, voucher code and shipping choices are
            used; the cart and its lines are deleted on success.
        :param note: Optional text stored as an ``OrderNote``.
        :return: ``(order, "success")`` on success, otherwise
            ``(False, message)``.
        """
        # Step 1: check stock for every line and validate the voucher.
        to_update_variants = []
        to_update_orderlines = []
        total_net = 0
        for line in cart.lines:
            variant = ProductVariant.get_by_id(line.variant.id)
            result, msg = variant.check_enough_stock(line.quantity)
            if result is False:
                return result, msg
            # NOTE(review): the allocation is changed on the loaded variant
            # before later lines are validated; an early return leaves it
            # modified in the session -- confirm nothing commits afterwards.
            variant.quantity_allocated += line.quantity
            to_update_variants.append(variant)
            # NOTE(review): OrderLine as declared in this file shows no
            # ``product_id`` column -- confirm this keyword is accepted.
            orderline = OrderLine(
                variant_id=variant.id,
                quantity=line.quantity,
                product_name=variant.display_product(),
                product_sku=variant.sku,
                product_id=variant.sku.split("-")[0],
                unit_price_net=variant.price,
                is_shipping_required=variant.is_shipping_required,
            )
            to_update_orderlines.append(orderline)
            total_net += orderline.get_total()
        voucher = None
        if cart.voucher_code:
            voucher = Voucher.get_by_code(cart.voucher_code)
            try:
                voucher.check_available(cart)
            except Exception as e:
                # Any failure (including an unknown code, where ``voucher``
                # is None) is reported through the returned message.
                return False, str(e)
        # Step 2: create the Order row.
        try:
            shipping_method_id = None
            shipping_method_title = None
            shipping_method_price = 0
            shipping_address = None
            if cart.shipping_method_id:
                shipping_method = ShippingMethod.get_by_id(
                    cart.shipping_method_id)
                shipping_method_id = shipping_method.id
                shipping_method_title = shipping_method.title
                shipping_method_price = shipping_method.price
                shipping_address = UserAddress.get_by_id(
                    cart.shipping_address_id).full_address
            order = cls.create(
                user_id=current_user.id,
                token=str(uuid4()),
                shipping_method_id=shipping_method_id,
                shipping_method_name=shipping_method_title,
                shipping_price_net=shipping_method_price,
                shipping_address=shipping_address,
                status=OrderStatusKinds.unfulfilled.value,
                total_net=total_net,
            )
        except Exception as e:
            return False, str(e)
        # Step 3: note, voucher usage, stock allocation, lines, cart cleanup.
        if note:
            order_note = OrderNote(order_id=order.id,
                                   user_id=current_user.id,
                                   content=note)
            db.session.add(order_note)
        if voucher:
            order.voucher_id = voucher.id
            order.discount_amount = voucher.get_vouchered_price(cart)
            order.discount_name = voucher.title
            voucher.used += 1
            db.session.add(order)
            db.session.add(voucher)
        for variant in to_update_variants:
            db.session.add(variant)
        for orderline in to_update_orderlines:
            orderline.order_id = order.id
            db.session.add(orderline)
        for line in cart.lines:
            db.session.delete(line)
        db.session.delete(cart)
        db.session.commit()
        return order, "success"

    def get_absolute_url(self):
        """Return the URL of the order page, addressed by token."""
        return url_for("order.show", token=self.token)

    @property
    def identity(self):
        """Return the part of the token after its last ``-``."""
        return self.token.split("-")[-1]

    @property
    def total(self):
        """Return net total plus shipping price minus discount."""
        return self.total_net + self.shipping_price_net - self.discount_amount

    @property
    def status_human(self):
        """Return the ``OrderStatusKinds`` member name for ``status``."""
        return OrderStatusKinds(int(self.status)).name

    @property
    def total_human(self):
        """Return the total prefixed with ``$``."""
        return "$" + str(self.total)

    @classmethod
    def get_current_user_orders(cls):
        """Return the logged-in user's orders, newest first, else ``[]``."""
        if current_user.is_authenticated:
            orders = (cls.query.filter_by(user_id=current_user.id).order_by(
                Order.id.desc()).all())
        else:
            orders = []
        return orders

    @classmethod
    def get_user_orders(cls, user_id):
        """Return all orders of the user with id ``user_id``."""
        return cls.query.filter_by(user_id=user_id).all()

    @property
    def is_shipping_required(self):
        """Tell whether any line of the order requires shipping."""
        return any(line.is_shipping_required for line in self.lines)

    @property
    def is_self_order(self):
        """Tell whether the order belongs to the current user."""
        return self.user_id == current_user.id

    @property
    def lines(self):
        """Return the ``OrderLine`` rows of this order."""
        return OrderLine.query.filter(OrderLine.order_id == self.id).all()

    @property
    def notes(self):
        """Return the ``OrderNote`` rows of this order."""
        return OrderNote.query.filter(OrderNote.order_id == self.id).all()

    @property
    def user(self):
        """Return the ``User`` referenced by ``user_id``."""
        return User.get_by_id(self.user_id)

    @property
    def payment(self):
        """Return the first ``OrderPayment`` of this order, or None."""
        return OrderPayment.query.filter_by(order_id=self.id).first()

    def pay_success(self, payment):
        """Mark the order fulfilled, book the stock out, record the event.

        :param payment: The confirming payment; not read by this method.
        """
        self.status = OrderStatusKinds.fulfilled.value
        # Merge first: avoids the "another instance with this key is already
        # present in this session" error when adding the order.
        local_obj = db.session.merge(self)
        db.session.add(local_obj)
        for line in self.lines:
            variant = line.variant
            # The sold quantity leaves both the allocation and the stock.
            variant.quantity_allocated -= line.quantity
            variant.quantity -= line.quantity
            db.session.add(variant)
        db.session.commit()
        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.payment_captured.value,
        )

    def cancel(self):
        """Mark the order canceled, release its allocation, record the event."""
        self.status = OrderStatusKinds.canceled.value
        db.session.add(self)
        for line in self.lines:
            variant = line.variant
            variant.quantity_allocated -= line.quantity
            db.session.add(variant)
        db.session.commit()
        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.order_canceled.value,
        )

    def complete(self):
        """Set the status to completed and record the event."""
        self.update(status=OrderStatusKinds.completed.value)
        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.order_completed.value,
        )

    def draft(self):
        """Set the status to draft and record the event."""
        self.update(status=OrderStatusKinds.draft.value)
        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.draft_created.value,
        )

    def delivered(self):
        """Set status shipped / ship status delivered and record the event."""
        self.update(
            status=OrderStatusKinds.shipped.value,
            ship_status=ShipStatusKinds.delivered.value,
        )
        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.order_delivered.value,
        )
class ProductType(Model):
    """Product type with its attribute assignments (``product_type``)."""

    __tablename__ = "product_type"

    title = Column(db.String(255), nullable=False)
    has_variants = Column(db.Boolean(), default=True)
    is_shipping_required = Column(db.Boolean(), default=False)

    def __str__(self):
        return self.title

    @property
    def product_attributes(self):
        """Return the attributes linked through ``ProductTypeAttributes``."""
        link_rows = (
            ProductTypeAttributes.query.with_entities(
                ProductTypeAttributes.product_attribute_id
            )
            .filter(ProductTypeAttributes.product_type_id == self.id)
            .all()
        )
        linked = ProductAttribute.query.filter(
            ProductAttribute.id.in_(attr_id for attr_id, in link_rows)
        )
        return linked.all()

    @property
    def variant_attributes(self):
        """Return the attributes linked through ``ProductTypeVariantAttributes``."""
        link_rows = (
            ProductTypeVariantAttributes.query.with_entities(
                ProductTypeVariantAttributes.product_attribute_id
            )
            .filter(ProductTypeVariantAttributes.product_type_id == self.id)
            .all()
        )
        linked = ProductAttribute.query.filter(
            ProductAttribute.id.in_(attr_id for attr_id, in link_rows)
        )
        return linked.all()

    @property
    def variant_attr_id(self):
        """Return the id of the first variant attribute, or None."""
        if not self.variant_attributes:
            return None
        return self.variant_attributes[0].id

    def update_product_attr(self, new_attrs):
        """Make the linked product attributes equal to ``new_attrs`` (ids).

        Stale links are deleted, missing ones added, then all is committed.
        """
        link_rows = (
            ProductTypeAttributes.query.with_entities(
                ProductTypeAttributes.product_attribute_id
            )
            .filter_by(product_type_id=self.id)
            .all()
        )
        current_ids = set(attr_id for attr_id, in link_rows)
        wanted_ids = set(int(raw_id) for raw_id in new_attrs)

        for attr_id in current_ids - wanted_ids:
            stale_link = ProductTypeAttributes.query.filter_by(
                product_type_id=self.id, product_attribute_id=attr_id
            ).first()
            stale_link.delete(commit=False)
        for attr_id in wanted_ids - current_ids:
            db.session.add(
                ProductTypeAttributes(
                    product_type_id=self.id, product_attribute_id=attr_id
                )
            )
        db.session.commit()

    def update_variant_attr(self, variant_attr):
        """Point the variant-attribute link at ``variant_attr``.

        An existing link row is updated and saved; otherwise one is created.
        """
        existing_link = ProductTypeVariantAttributes.query.filter_by(
            product_type_id=self.id
        ).first()
        if not existing_link:
            ProductTypeVariantAttributes.create(
                product_type_id=self.id, product_attribute_id=variant_attr
            )
            return
        existing_link.product_attribute_id = variant_attr
        existing_link.save()

    def delete(self):
        """Delete the type and its attribute links; detach its products.

        Products of this type get ``product_type_id`` set to 0.
        """
        product_attr_links = ProductTypeAttributes.query.filter_by(
            product_type_id=self.id
        ).all()
        variant_attr_links = ProductTypeVariantAttributes.query.filter_by(
            product_type_id=self.id
        ).all()
        for link in itertools.chain(product_attr_links, variant_attr_links):
            link.delete(commit=False)

        orphaned_products = Product.query.filter_by(product_type_id=self.id).all()
        for product in orphaned_products:
            product.product_type_id = 0
            db.session.add(product)

        db.session.delete(self)
        db.session.commit()
class Product(Model):
    """Catalogue product, stored in ``product_product``."""

    __tablename__ = "product_product"

    title = Column(db.String(255), nullable=False)
    on_sale = Column(db.Boolean(), default=True)
    rating = Column(db.DECIMAL(8, 2), default=5.0)
    sold_count = Column(db.Integer(), default=0)
    review_count = Column(db.Integer(), default=0)
    basic_price = Column(db.DECIMAL(10, 2))
    category_id = Column(db.Integer())
    is_featured = Column(db.Boolean(), default=False)
    product_type_id = Column(db.Integer())
    # maps attribute id -> attribute value id (see ``attribute_map``)
    attributes = Column(MutableDict.as_mutable(db.JSON()))
    description = Column(db.Text())
    # When Redis is enabled the description is replaced by a PropsItem.
    if Config.USE_REDIS:
        description = PropsItem("description")

    def __str__(self):
        return self.title

    def __iter__(self):
        """Iterate over the product's variants.

        The property holding the variants is named ``variant``; the class
        has no ``variants`` attribute.
        """
        return iter(self.variant)

    def get_absolute_url(self):
        """Return the URL of the product page."""
        return url_for("product.show", id=self.id)

    @property
    @cache(MC_KEY_PRODUCT_IMAGES.format("{self.id}"))
    def images(self):
        """Return the ``ProductImage`` rows of this product (cached)."""
        return ProductImage.query.filter(ProductImage.product_id == self.id).all()

    @property
    def first_img(self):
        """Return the URL of the first image, or ``""`` without images."""
        images = self.images
        return str(images[0]) if images else ""

    @property
    def is_in_stock(self):
        """Tell whether at least one variant is in stock."""
        return any(variant.is_in_stock for variant in self)

    @property
    def category(self):
        """Return the ``Category`` referenced by ``category_id``."""
        return Category.get_by_id(self.category_id)

    @property
    def product_type(self):
        """Return the ``ProductType`` referenced by ``product_type_id``."""
        return ProductType.get_by_id(self.product_type_id)

    @property
    def is_discounted(self):
        """Tell whether a positive discount applies to the product."""
        return float(self.discounted_price) > 0

    @property
    @cache(MC_KEY_PRODUCT_DISCOUNT_PRICE.format("{self.id}"))
    def discounted_price(self):
        """Return the discount amount computed by ``Sale`` (cached)."""
        # Imported locally, presumably to avoid a circular import.
        from flaskshop.discount.models import Sale

        return Sale.get_discounted_price(self)

    @property
    def price(self):
        """Return the basic price reduced by the discount, if any."""
        if self.is_discounted:
            return self.basic_price - self.discounted_price
        return self.basic_price

    @property
    def price_human(self):
        """Return the price prefixed with ``$``."""
        return "$" + str(self.price)

    @property
    def on_sale_human(self):
        """Return ``"Y"`` when the product is on sale, else ``"N"``."""
        return "Y" if self.on_sale else "N"

    @property
    @cache(MC_KEY_PRODUCT_VARIANT.format("{self.id}"))
    def variant(self):
        """Return the ``ProductVariant`` rows of this product (cached)."""
        return ProductVariant.query.filter(
            ProductVariant.product_id == self.id
        ).all()

    @property
    def attribute_map(self):
        """Return ``{ProductAttribute: AttributeChoiceValue}`` for ``attributes``."""
        return {
            ProductAttribute.get_by_id(attr_id): AttributeChoiceValue.get_by_id(value_id)
            for attr_id, value_id in self.attributes.items()
        }

    @classmethod
    # caching is currently disabled:
    # @cache(MC_KEY_FEATURED_PRODUCTS.format("{num}"))
    def get_featured_product(cls, num=8):
        """Return up to ``num`` featured products (used on the home page)."""
        return cls.query.filter_by(is_featured=True).limit(num).all()

    def update_images(self, new_images):
        """Make the product's images equal to ``new_images`` (image ids).

        Images no longer listed are deleted; newly listed images are
        attached to this product. Everything is committed at the end.
        """
        # Select the image ids: selecting ``product_id`` would only yield
        # this product's own id and break the diff below.
        image_rows = (
            ProductImage.query.with_entities(ProductImage.id)
            .filter_by(product_id=self.id)
            .all()
        )
        current_ids = {image_id for image_id, in image_rows}
        wanted_ids = {int(raw_id) for raw_id in new_images}

        for image_id in current_ids - wanted_ids:
            ProductImage.get_by_id(image_id).delete(commit=False)
        for image_id in wanted_ids - current_ids:
            image = ProductImage.get_by_id(image_id)
            image.product_id = self.id
            image.save(commit=False)
        db.session.commit()

    def update_attributes(self, attr_values):
        """Pair the type's attribute ids with ``attr_values``, in order."""
        attr_entries = [str(item.id) for item in self.product_type.product_attributes]
        self.attributes = dict(zip(attr_entries, attr_values))

    def generate_variants(self):
        """Create the initial variants of the product.

        Without variants a single ``<id>-1337`` SKU is created; otherwise
        one variant per value of the first variant attribute, numbering the
        SKUs upwards from 1337.
        """
        if not self.product_type.has_variants:
            ProductVariant.create(sku=str(self.id) + "-1337", product_id=self.id)
            return
        variant_attribute = self.product_type.variant_attributes[0]
        for sku_id, value in enumerate(variant_attribute.values, start=1337):
            ProductVariant.create(
                sku=str(self.id) + "-" + str(sku_id),
                title=value.title,
                product_id=self.id,
                attributes={str(variant_attribute.id): str(value.id)},
            )

    def delete(self):
        """Delete the product with its images, variants and collection links."""
        collection_links = ProductCollection.query.filter_by(
            product_id=self.id
        ).all()
        for item in itertools.chain(self.images, self.variant, collection_links):
            item.delete(commit=False)
        db.session.delete(self)
        db.session.commit()

    @staticmethod
    def clear_mc(target):
        """Drop the cached discount of ``target`` and all featured lists."""
        rdb.delete(MC_KEY_PRODUCT_DISCOUNT_PRICE.format(target.id))
        for key in rdb.keys(MC_KEY_FEATURED_PRODUCTS.format("*")):
            rdb.delete(key)

    @staticmethod
    def clear_category_cache(target):
        """Drop every cached product list of the category of ``target``."""
        pattern = MC_KEY_CATEGORY_PRODUCTS.format(target.category_id, "*")
        for key in rdb.keys(pattern):
            rdb.delete(key)

    @classmethod
    def __flush_insert_event__(cls, target):
        """Add the new product to the search index when ES is enabled."""
        super().__flush_insert_event__(target)
        if current_app.config["USE_ES"]:
            from flaskshop.public.search import Item

            Item.add(target)

    @classmethod
    def __flush_before_update_event__(cls, target):
        """Clear the category cache before the update is written."""
        super().__flush_before_update_event__(target)
        target.clear_category_cache(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        """Clear caches and refresh the search index after an update."""
        super().__flush_after_update_event__(target)
        target.clear_mc(target)
        target.clear_category_cache(target)
        if current_app.config["USE_ES"]:
            from flaskshop.public.search import Item

            Item.update_item(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        """Clear caches and remove the product from the search index.

        The index is only touched when ES is enabled, matching the insert
        and update hooks.
        """
        super().__flush_delete_event__(target)
        target.clear_mc(target)
        target.clear_category_cache(target)
        if current_app.config["USE_ES"]:
            from flaskshop.public.search import Item

            Item.delete(target)
class Cart(Model):
    """Shopping cart of a user, stored in ``checkout_cart``."""

    __tablename__ = "checkout_cart"

    user_id = Column(db.Integer())
    voucher_code = Column(db.String(255))
    # total quantity over all lines (see ``update_quantity``)
    quantity = Column(db.Integer())
    shipping_address_id = Column(db.Integer())
    shipping_method_id = Column(db.Integer())

    @property
    def subtotal(self):
        """Return the sum of the line subtotals."""
        return sum(line.subtotal for line in self)

    @property
    def total(self):
        """Return subtotal plus shipping price minus voucher discount."""
        return self.subtotal + self.shipping_method_price - self.discount_amount

    @property
    def discount_amount(self):
        """Return the voucher discount, or 0 without a usable voucher.

        A voucher code that no longer resolves to a voucher yields 0
        instead of failing on ``None``.
        """
        voucher = self.voucher
        if voucher is None:
            return 0
        return voucher.get_vouchered_price(self)

    @property
    def lines(self):
        """Return the ``CartLine`` rows of this cart."""
        return CartLine.query.filter(CartLine.cart_id == self.id).all()

    @classmethod
    @cache(MC_KEY_CART_BY_USER.format("{user_id}"))
    def get_cart_by_user_id(cls, user_id):
        """Return the first cart of user ``user_id`` (cached), or None."""
        return cls.query.filter_by(user_id=user_id).first()

    @classmethod
    def get_current_user_cart(cls):
        """Return the logged-in user's cart, or None."""
        if current_user.is_authenticated:
            return cls.get_cart_by_user_id(current_user.id)
        return None

    @classmethod
    def add_to_currentuser_cart(cls, quantity, variant_id):
        """Add ``quantity`` of variant ``variant_id`` to the current user's cart.

        When stock is insufficient a warning is flashed and nothing is
        changed. The cart is created on first use; an existing line for the
        variant is increased, otherwise a new line is created.
        """
        cart = cls.get_current_user_cart()
        variant = ProductVariant.get_by_id(variant_id)
        result, msg = variant.check_enough_stock(quantity)
        if result is False:
            flash(msg, "warning")
            return

        # Compare with None explicitly: truthiness would go through
        # ``__len__`` and treat an existing cart without lines as missing.
        if cart is not None:
            cart.quantity += quantity
            cart.save()
        else:
            cart = cls.create(user_id=current_user.id, quantity=quantity)

        line = CartLine.query.filter_by(
            cart_id=cart.id, variant_id=variant_id
        ).first()
        if line:
            line.update(quantity=quantity + line.quantity)
        else:
            CartLine.create(
                variant_id=variant_id, quantity=quantity, cart_id=cart.id
            )

    def get_product_price(self, product_id):
        """Return the summed subtotal of the lines of product ``product_id``."""
        price = 0
        for line in self:
            if line.product.id == product_id:
                price += line.subtotal
        return price

    def get_category_price(self, category_id):
        """Return the summed subtotal of the lines in category ``category_id``."""
        price = 0
        for line in self:
            if line.category.id == category_id:
                price += line.subtotal
        return price

    @property
    def is_shipping_required(self):
        """Tell whether any line requires shipping."""
        return any(line.is_shipping_required for line in self)

    @property
    def shipping_method(self):
        """Return the ``ShippingMethod`` referenced by ``shipping_method_id``."""
        return ShippingMethod.get_by_id(self.shipping_method_id)

    @property
    def shipping_method_price(self):
        """Return the price of the shipping method, or 0 when none is set."""
        # Read the property once: every access performs a lookup.
        shipping_method = self.shipping_method
        if shipping_method:
            return shipping_method.price
        return 0

    @property
    def voucher(self):
        """Return the ``Voucher`` for ``voucher_code``, or None."""
        if self.voucher_code:
            return Voucher.get_by_code(self.voucher_code)
        return None

    def __repr__(self):
        return f"Cart(quantity={self.quantity})"

    def __iter__(self):
        """Iterate over the cart's lines."""
        return iter(self.lines)

    def __len__(self):
        """Return the number of lines."""
        return len(self.lines)

    def update_quantity(self):
        """Recompute ``quantity`` from the lines; delete the cart when empty.

        :return: The recomputed quantity.
        """
        self.quantity = sum(line.quantity for line in self)
        if self.quantity == 0:
            self.delete()
        else:
            self.save()
        return self.quantity

    # The cached lookup is keyed by the cart owner's id, so the hooks below
    # invalidate by ``target.user_id`` rather than by whoever is logged in.

    @classmethod
    def __flush_insert_event__(cls, target):
        """Invalidate the owner's cached cart after an insert."""
        super().__flush_insert_event__(target)
        rdb.delete(MC_KEY_CART_BY_USER.format(target.user_id))

    @classmethod
    def __flush_after_update_event__(cls, target):
        """Invalidate the owner's cached cart after an update."""
        super().__flush_after_update_event__(target)
        rdb.delete(MC_KEY_CART_BY_USER.format(target.user_id))

    @classmethod
    def __flush_delete_event__(cls, target):
        """Invalidate the owner's cached cart after a delete."""
        super().__flush_delete_event__(target)
        rdb.delete(MC_KEY_CART_BY_USER.format(target.user_id))
class Sale(Model):
    """Sale applying a discount to products or categories (``discount_sale``)."""

    __tablename__ = "discount_sale"

    # holds a ``DiscountValueTypeKinds`` value
    discount_value_type = Column(TINYINT())
    title = Column(db.String(255))
    discount_value = Column(db.DECIMAL(10, 2))

    def __str__(self):
        return self.title

    @property
    def discount_value_type_label(self):
        """Return the ``DiscountValueTypeKinds`` member name."""
        return DiscountValueTypeKinds(int(self.discount_value_type)).name

    @classmethod
    def get_discounted_price(cls, product):
        """Return the discount amount that applies to ``product``.

        A sale linked to the product itself takes precedence over a sale
        linked to its category.

        :param product: Product to price.
        :return: The fixed discount value, or the percentage of the basic
            price rounded to two decimals; 0 when no sale applies or the
            discount type is unknown.
        """
        sale = None
        sale_product = SaleProduct.query.filter_by(product_id=product.id).first()
        if sale_product:
            sale = Sale.get_by_id(sale_product.sale_id)
        else:
            # A product may have no category; skip the category lookup then.
            category = product.category
            if category is not None:
                sale_category = SaleCategory.query.filter_by(
                    category_id=category.id
                ).first()
                if sale_category:
                    sale = Sale.get_by_id(sale_category.sale_id)
        if sale is None:
            return 0

        if sale.discount_value_type == DiscountValueTypeKinds.fixed.value:
            return sale.discount_value
        if sale.discount_value_type == DiscountValueTypeKinds.percent.value:
            price = product.basic_price * sale.discount_value / 100
            return Decimal(price).quantize(Decimal("0.00"))
        # Unknown discount type: no discount, so callers always get a number.
        return 0

    @property
    def categories(self):
        """Return the categories linked to this sale."""
        link_rows = (
            SaleCategory.query.with_entities(SaleCategory.category_id)
            .filter(SaleCategory.sale_id == self.id)
            .all()
        )
        category_ids = [category_id for category_id, in link_rows]
        return Category.query.filter(Category.id.in_(category_ids)).all()

    @property
    def products_ids(self):
        """Return the linked product ids as a list of one-element rows."""
        return (
            SaleProduct.query.with_entities(SaleProduct.product_id)
            .filter(SaleProduct.sale_id == self.id)
            .all()
        )

    @property
    def products(self):
        """Return the products linked to this sale."""
        product_ids = [product_id for product_id, in self.products_ids]
        return Product.query.filter(Product.id.in_(product_ids)).all()

    def update_categories(self, category_ids):
        """Make the linked categories equal to ``category_ids``; commit."""
        link_rows = (
            SaleCategory.query.with_entities(SaleCategory.category_id)
            .filter_by(sale_id=self.id)
            .all()
        )
        current_ids = {category_id for category_id, in link_rows}
        wanted_ids = {int(raw_id) for raw_id in category_ids}

        for category_id in current_ids - wanted_ids:
            SaleCategory.query.filter_by(
                sale_id=self.id, category_id=category_id
            ).first().delete(commit=False)
        for category_id in wanted_ids - current_ids:
            db.session.add(SaleCategory(sale_id=self.id, category_id=category_id))
        db.session.commit()

    def update_products(self, product_ids):
        """Make the linked products equal to ``product_ids``; commit."""
        link_rows = (
            SaleProduct.query.with_entities(SaleProduct.product_id)
            .filter_by(sale_id=self.id)
            .all()
        )
        current_ids = {product_id for product_id, in link_rows}
        wanted_ids = {int(raw_id) for raw_id in product_ids}

        for product_id in current_ids - wanted_ids:
            SaleProduct.query.filter_by(
                sale_id=self.id, product_id=product_id
            ).first().delete(commit=False)
        for product_id in wanted_ids - current_ids:
            db.session.add(SaleProduct(sale_id=self.id, product_id=product_id))
        db.session.commit()

    @staticmethod
    def clear_mc(target):
        """Drop every cached product discount.

        A sale change can affect products directly and through categories;
        tracking all of those states is complex, so all keys are deleted
        rather than only those of ``target.products_ids``.
        """
        for key in rdb.keys(MC_KEY_PRODUCT_DISCOUNT_PRICE.format("*")):
            rdb.delete(key)

    @classmethod
    def __flush_insert_event__(cls, target):
        """Invalidate cached discounts after an insert."""
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        """Invalidate cached discounts after an update."""
        super().__flush_after_update_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        """Invalidate cached discounts after a delete."""
        super().__flush_delete_event__(target)
        target.clear_mc(target)
class ProductVariant(Model):
    """Purchasable variant of a product, stored in ``product_variant``."""

    __tablename__ = "product_variant"

    sku = Column(db.String(32), unique=True)
    title = Column(db.String(255))
    # a non-zero value replaces the product price (see ``price``)
    price_override = Column(db.DECIMAL(10, 2), default=0.00)
    quantity = Column(db.Integer(), default=0)
    quantity_allocated = Column(db.Integer(), default=0)
    product_id = Column(db.Integer(), default=0)
    attributes = Column(MutableDict.as_mutable(db.JSON()))

    def __str__(self):
        return self.title or self.sku

    def display_product(self):
        """Return ``"<product> (<variant>)"``."""
        return f"{self.product} ({str(self)})"

    @property
    def sku_id(self):
        """Return the second ``-``-separated segment of the SKU."""
        segments = self.sku.split("-")
        return segments[1]

    @sku_id.setter
    def sku_id(self, data):
        """Accept and ignore assignments to ``sku_id``."""
        pass

    @property
    def is_shipping_required(self):
        """Return the shipping flag of the product's type."""
        product_type = self.product.product_type
        return product_type.is_shipping_required

    @property
    def quantity_available(self):
        """Return the unallocated quantity, never below 0."""
        return max(self.quantity - self.quantity_allocated, 0)

    @property
    def is_in_stock(self):
        """Tell whether any unallocated quantity is left."""
        return self.quantity_available > 0

    @property
    def stock(self):
        """Return quantity minus allocated quantity (may be negative)."""
        return self.quantity - self.quantity_allocated

    @property
    def price(self):
        """Return the override price when set, else the product's price."""
        return self.price_override or self.product.price

    @property
    def product(self):
        """Return the ``Product`` referenced by ``product_id``."""
        return Product.get_by_id(self.product_id)

    def get_absolute_url(self):
        """Return the URL of the owning product's page."""
        return url_for("product.show", id=self.product.id)

    @property
    def attribute_map(self):
        """Return ``{ProductAttribute: AttributeChoiceValue}`` for ``attributes``."""
        return {
            ProductAttribute.get_by_id(attr_id): AttributeChoiceValue.get_by_id(value_id)
            for attr_id, value_id in self.attributes.items()
        }

    def check_enough_stock(self, quantity):
        """Return ``(True, "success")`` when ``quantity`` is in stock.

        Otherwise ``(False, message)`` naming the variant.
        """
        if self.stock < quantity:
            shortage_msg = f"{self.display_product()} has not enough stock"
            return False, shortage_msg
        return True, "success"

    @staticmethod
    def clear_mc(target):
        """Drop the cached variant list of the product of ``target``."""
        cache_key = MC_KEY_PRODUCT_VARIANT.format(target.product_id)
        rdb.delete(cache_key)

    @classmethod
    def __flush_insert_event__(cls, target):
        """Invalidate the variant cache after an insert."""
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        """Invalidate the variant cache after an update."""
        super().__flush_after_update_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        """Invalidate the variant cache after a delete."""
        super().__flush_delete_event__(target)
        target.clear_mc(target)
class Voucher(Model):
    """Discount voucher stored in ``discount_voucher`` and looked up by ``code``.

    ``type_`` holds a ``VoucherTypeKinds`` value (value / shipping / product /
    category) and decides which cart amount the voucher applies to;
    ``discount_value_type`` holds a ``DiscountValueTypeKinds`` value
    (fixed / percent) and decides how ``discount_value`` is interpreted.
    """

    __tablename__ = "discount_voucher"
    # Python attribute is ``type_``; the underlying column is named "type".
    type_ = Column("type", TINYINT())
    title = Column(db.String(255))
    code = Column(db.String(16), unique=True)
    usage_limit = Column(db.Integer())
    used = Column(db.Integer(), default=0)
    start_date = Column(db.Date())
    end_date = Column(db.Date())
    discount_value_type = Column(TINYINT())
    discount_value = Column(db.DECIMAL(10, 2))
    # Minimum amount the relevant cart total must reach (falsy = no minimum).
    limit = Column(db.DECIMAL(10, 2))
    category_id = Column(db.Integer())
    product_id = Column(db.Integer())

    def __str__(self):
        return self.title

    @property
    def type_human(self):
        """Name of the ``VoucherTypeKinds`` member matching ``type_``."""
        return VoucherTypeKinds(int(self.type_)).name

    @property
    def discount_value_type_human(self):
        """Name of the ``DiscountValueTypeKinds`` member for this voucher."""
        return DiscountValueTypeKinds(int(self.discount_value_type)).name

    @property
    def validity_period(self):
        """``"MM/DD/YYYY - MM/DD/YYYY"`` or ``""`` when either date is unset."""
        if self.start_date and self.end_date:
            start = self.start_date.strftime("%m/%d/%Y")
            end = self.end_date.strftime("%m/%d/%Y")
            return f"{start} - {end}"
        return ""

    @classmethod
    def generate_code(cls):
        """Return a random 16-letter uppercase code not yet used by a voucher."""
        # Loop instead of recursing so repeated collisions cannot grow the stack.
        while True:
            code = "".join(random.choices(string.ascii_uppercase, k=16))
            if not cls.query.filter_by(code=code).first():
                return code

    def check_available(self, cart=None):
        """Validate that the voucher can be used right now.

        :param cart: optional cart; when given, the cart-specific rules of
            :meth:`check_available_by_cart` are enforced as well.
        :returns: ``True`` when every check passes.
        :raises Exception: with a user-facing message when a check fails.
        """
        # start_date / end_date are Date columns: compare against a date, not
        # a datetime (mixing the two raises TypeError).
        today = datetime.now().date()
        if self.start_date and self.start_date > today:
            raise Exception(
                "The voucher code can not use now, please retry later")
        if self.end_date and self.end_date < today:
            raise Exception("The voucher code has expired")
        # Exhausted as soon as ``used`` reaches the limit, not one use later.
        if self.usage_limit and (self.used or 0) >= self.usage_limit:
            raise Exception("This voucher code has been used out")
        if cart:
            self.check_available_by_cart(cart)
        return True

    def check_available_by_cart(self, cart):
        """Check the voucher's minimum-amount / applicability rules for ``cart``.

        :raises Exception: with a user-facing message when the cart does not
            qualify for this voucher.
        """
        if self.type_ == VoucherTypeKinds.value.value:
            if self.limit and cart.subtotal < self.limit:
                raise Exception(
                    f"The order total amount is not enough({self.limit}) to use this voucher code"
                )
        elif self.type_ == VoucherTypeKinds.shipping.value:
            if self.limit and cart.shipping_method_price < self.limit:
                raise Exception(
                    f"The order shipping price is not enough({self.limit}) to use this voucher code"
                )
        elif self.type_ == VoucherTypeKinds.product.value:
            product = Product.get_by_id(self.product_id)
            product_price = cart.get_product_price(self.product_id)
            # A zero price means the product is not in the cart at all.
            if product_price == 0:
                raise Exception(
                    f"This Voucher Code should be used for {product.title}")
            if self.limit and product_price < self.limit:
                raise Exception(
                    f"The product {product.title} total amount is not enough({self.limit}) to use this voucher code"
                )
        elif self.type_ == VoucherTypeKinds.category.value:
            category = Category.get_by_id(self.category_id)
            category_price = cart.get_category_price(self.category_id)
            if category_price == 0:
                raise Exception(
                    f"This Voucher Code should be used for {category.title}")
            if self.limit and category_price < self.limit:
                raise Exception(
                    f"The category {category.title} total amount is not enough({self.limit}) to use this voucher code"
                )

    @classmethod
    def get_by_code(cls, code):
        """Return the voucher with the given ``code`` or ``None``."""
        return cls.query.filter_by(code=code).first()

    def get_vouchered_price(self, cart):
        """Return the discount amount this voucher grants on ``cart``.

        The base amount depends on ``type_``; an unrecognised type yields 0.
        """
        if self.type_ == VoucherTypeKinds.value.value:
            return self.get_voucher_from_price(cart.subtotal)
        elif self.type_ == VoucherTypeKinds.shipping.value:
            return self.get_voucher_from_price(cart.shipping_method_price)
        elif self.type_ == VoucherTypeKinds.product.value:
            return self.get_voucher_from_price(
                cart.get_product_price(self.product_id))
        elif self.type_ == VoucherTypeKinds.category.value:
            return self.get_voucher_from_price(
                cart.get_category_price(self.category_id))
        return 0

    def get_voucher_from_price(self, price):
        """Return the discount to subtract from ``price``.

        Fixed vouchers discount ``discount_value`` capped at ``price``;
        percent vouchers discount ``price * discount_value / 100`` quantized
        to two decimals. Any other ``discount_value_type`` falls through and
        yields ``None``.
        """
        if self.discount_value_type == DiscountValueTypeKinds.fixed.value:
            # Never discount more than the amount being discounted.
            return self.discount_value if price > self.discount_value else price
        elif self.discount_value_type == DiscountValueTypeKinds.percent.value:
            price = price * self.discount_value / 100
            return Decimal(price).quantize(Decimal("0.00"))
class ProductAttribute(Model):
    """Named product attribute stored in ``product_attribute``.

    Its choice values are ``AttributeChoiceValue`` rows and its links to
    product types are ``ProductTypeAttributes`` rows; the ``values`` lookup
    is cached under ``MC_KEY_ATTRIBUTE_VALUES``.
    """

    __tablename__ = "product_attribute"
    title = Column(db.String(255), nullable=False)

    def __str__(self):
        return self.title

    @property
    @cache(MC_KEY_ATTRIBUTE_VALUES.format("{self.id}"))
    def values(self):
        """All ``AttributeChoiceValue`` rows belonging to this attribute."""
        return AttributeChoiceValue.query.filter(
            AttributeChoiceValue.attribute_id == self.id).all()

    @property
    def values_label(self):
        """Comma-joined titles of this attribute's values."""
        return ",".join(value.title for value in self.values)

    @property
    def types(self):
        """``ProductType`` rows linked to this attribute."""
        rows = (ProductTypeAttributes.query.with_entities(
            ProductTypeAttributes.product_type_id).filter_by(
                product_attribute_id=self.id).all())
        # in_() needs a real sequence: newer SQLAlchemy rejects generators.
        type_ids = [type_id for (type_id,) in rows]
        if not type_ids:
            return []
        return ProductType.query.filter(ProductType.id.in_(type_ids)).all()

    @property
    def types_label(self):
        """Comma-joined titles of the linked product types."""
        return ",".join(t.title for t in self.types)

    def update_values(self, new_values):
        """Make this attribute's values match ``new_values`` and commit.

        :param new_values: iterable of value titles that should exist.
            Existing values whose title is missing are deleted; titles
            without an existing value are created.
        """
        current_values = self.values  # evaluate the cached lookup once
        wanted_titles = set(new_values)
        existing_titles = {value.title for value in current_values}
        for value in current_values:
            if value.title not in wanted_titles:
                value.delete(commit=False)
        for title in wanted_titles - existing_titles:
            db.session.add(
                AttributeChoiceValue(title=title, attribute_id=self.id))
        db.session.commit()

    def update_types(self, new_types):
        """Make the linked product types match ``new_types`` and commit.

        :param new_types: iterable of product type ids (anything ``int()``
            accepts).
        """
        rows = (ProductTypeAttributes.query.with_entities(
            ProductTypeAttributes.product_type_id).filter_by(
                product_attribute_id=self.id).all())
        origin_ids = {type_id for (type_id,) in rows}
        wanted_ids = {int(type_id) for type_id in new_types}
        for type_id in origin_ids - wanted_ids:
            link = ProductTypeAttributes.query.filter_by(
                product_attribute_id=self.id,
                product_type_id=type_id).first()
            # The link may already be gone if another request removed it.
            if link is not None:
                link.delete(commit=False)
        for type_id in wanted_ids - origin_ids:
            db.session.add(
                ProductTypeAttributes(product_attribute_id=self.id,
                                      product_type_id=type_id))
        db.session.commit()

    def delete(self):
        """Delete the attribute with its type links and values, then commit."""
        need_del_product_attrs = ProductTypeAttributes.query.filter_by(
            product_attribute_id=self.id).all()
        need_del_variant_attrs = ProductTypeVariantAttributes.query.filter_by(
            product_attribute_id=self.id).all()
        for item in itertools.chain(need_del_product_attrs,
                                    need_del_variant_attrs, self.values):
            item.delete(commit=False)
        db.session.delete(self)
        db.session.commit()

    @classmethod
    def __flush_after_update_event__(cls, target):
        """Run the base hook, then drop the cached values of ``target``."""
        super().__flush_after_update_event__(target)
        rdb.delete(MC_KEY_ATTRIBUTE_VALUES.format(target.id))

    @classmethod
    def __flush_delete_event__(cls, target):
        """Run the base hook, then drop the cached values of ``target``."""
        super().__flush_delete_event__(target)
        rdb.delete(MC_KEY_ATTRIBUTE_VALUES.format(target.id))
class Category(Model):
    """Product category stored in ``product_category``.

    Categories form a tree through ``parent_id`` (0 marks a first-level
    category). The ``children`` lookup and the paginated product listing are
    cached under ``MC_KEY_CATEGORY_CHILDREN`` / ``MC_KEY_CATEGORY_PRODUCTS``.
    """

    __tablename__ = "product_category"
    title = Column(db.String(255), nullable=False)
    parent_id = Column(db.Integer(), default=0)
    background_img = Column(db.String(255))

    def __str__(self):
        return self.title

    def get_absolute_url(self):
        """URL of this category's listing page."""
        return url_for("product.show_category", id=self.id)

    @property
    def background_img_url(self):
        """Static URL of the background image."""
        return url_for("static", filename=self.background_img)

    def _self_and_child_ids(self):
        """Ids of the direct children followed by this category's own id."""
        return [child.id for child in self.children] + [self.id]

    @property
    def products(self):
        """Products in this category or in one of its direct children."""
        return Product.query.filter(
            Product.category_id.in_(self._self_and_child_ids())).all()

    @property
    @cache(MC_KEY_CATEGORY_CHILDREN.format("{self.id}"))
    def children(self):
        """Categories whose ``parent_id`` is this category's id."""
        return Category.query.filter(Category.parent_id == self.id).all()

    @property
    def parent(self):
        """The parent category looked up by ``parent_id``."""
        return Category.get_by_id(self.parent_id)

    @property
    def attr_filter(self):
        """Set of product attributes used by the product types in ``products``."""
        attr_filter = set()
        for product in self.products:
            attr_filter.update(product.product_type.product_attributes)
        return attr_filter

    @classmethod
    @cache_by_args(MC_KEY_CATEGORY_PRODUCTS.format("{category_id}", "{page}"))
    def get_product_by_category(cls, category_id, page):
        """Build the template context for one page of a category's products.

        :param category_id: id of the category to list.
        :param page: 1-based page number (16 products per page).
        :returns: the context from ``get_product_list_context`` extended with
            ``object``, ``pagination`` and ``products``.
        """
        category = Category.get_by_id(category_id)
        query = Product.query.filter(
            Product.category_id.in_(category._self_and_child_ids()))
        ctx, query = get_product_list_context(query, category)
        # Keyword arguments work on both Flask-SQLAlchemy 2.x and 3.x
        # (3.x made paginate()'s parameters keyword-only).
        pagination = query.paginate(page=page, per_page=16)
        # Drop the query object before the result is handed to the cache.
        del pagination.query
        ctx.update(object=category,
                   pagination=pagination,
                   products=pagination.items)
        return ctx

    @classmethod
    def first_level_items(cls):
        """All categories without a parent (``parent_id == 0``)."""
        return cls.query.filter(cls.parent_id == 0).all()

    def delete(self):
        """Delete the category, detaching its children and products.

        Direct children become first-level categories, products of this
        category get ``category_id = 0``, and the background image file is
        removed from ``STATIC_DIR`` after the commit.
        """
        # Read the path before the row is deleted and committed.
        background_img = self.background_img
        for child in self.children:
            child.parent_id = 0
            db.session.add(child)
        need_update_products = Product.query.filter_by(
            category_id=self.id).all()
        for product in need_update_products:
            product.category_id = 0
            db.session.add(product)
        db.session.delete(self)
        db.session.commit()
        if background_img:
            image = current_app.config["STATIC_DIR"] / background_img
            if image.exists():
                image.unlink()

    @staticmethod
    def clear_mc(target):
        """Invalidate the cache entries affected by a change to ``target``."""
        rdb.delete(MC_KEY_CATEGORY_CHILDREN.format(target.id))
        # The parent's cached ``children`` list contains this category, so it
        # is stale too (0 / None means there is no parent).
        if target.parent_id:
            rdb.delete(MC_KEY_CATEGORY_CHILDREN.format(target.parent_id))
        for key in rdb.keys(MC_KEY_CATEGORY_PRODUCTS.format(target.id, "*")):
            rdb.delete(key)

    @classmethod
    def __flush_insert_event__(cls, target):
        """Run the base hook, then invalidate caches (a new child changes
        its parent's ``children`` list)."""
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        """Run the base hook, then invalidate the caches for ``target``."""
        super().__flush_after_update_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        """Run the base hook, then invalidate the caches for ``target``."""
        super().__flush_delete_event__(target)
        target.clear_mc(target)
class Role(Model): __tablename__ = "account_role" name = Column(db.String(80), unique=True) permissions = Column(db.Integer(), default=Permission.LOGIN)