class CartLine(Model):
    """One line of a shopping cart: a product variant and the quantity wanted."""

    __tablename__ = "checkout_cartline"
    cart_id = Column(db.Integer())
    quantity = Column(db.Integer())
    variant_id = Column(db.Integer())

    def __repr__(self):
        return "CartLine(variant={}, quantity={})".format(self.variant, self.quantity)

    @property
    def variant(self):
        """The ProductVariant referenced by ``variant_id``."""
        return ProductVariant.get_by_id(self.variant_id)

    @property
    def product(self):
        """The product that owns this line's variant."""
        variant = self.variant
        return variant.product

    @property
    def category(self):
        """The category of this line's product."""
        product = self.product
        return product.category

    @property
    def is_shipping_required(self):
        """Whether this line's variant needs shipping."""
        variant = self.variant
        return variant.is_shipping_required

    @property
    def subtotal(self):
        """Variant price multiplied by the line quantity."""
        unit_price = self.variant.price
        return unit_price * self.quantity
class ShippingMethod(Model):
    """A selectable shipping option with a title and a price."""

    __tablename__ = "checkout_shippingmethod"
    title = Column(db.String(255), nullable=False)
    price = Column(db.DECIMAL(10, 2))

    def __str__(self):
        # Rendered as "<title> $<price>".
        return " $".join((self.title, str(self.price)))
class Conversation(Model):
    """An entry in a user's message box.

    NOTE(review): rows of one exchange appear to be tied together via
    ``shared_id`` — confirm against the views that create them.
    """

    __tablename__ = "conversation_dialogue"
    # This table effectively acts as each user's message box.
    user_id = Column(db.Integer(), nullable=False)  # the user to whom the conversation is addressed
    from_user_id = Column(db.Integer())  # the user who sent the message
    to_user_id = db.Column(db.Integer())
    shared_id = db.Column(db.String(100), nullable=False)
    subject = db.Column(db.String(255))
    trash = db.Column(db.Boolean, default=False, nullable=False)
    draft = db.Column(db.Boolean, default=False, nullable=False)
    unread = db.Column(db.Boolean, default=False, nullable=False)

    @staticmethod
    def _user_by_id(user_id):
        """Return the User whose id equals ``user_id``, or None."""
        return User.query.filter_by(id=user_id).first()

    @property
    def last_message(self):
        """The Message with the highest id in this conversation, or None."""
        newest_first = Message.query.filter_by(conversation_id=self.id).order_by(
            Message.id.desc()
        )
        return newest_first.first()

    @property
    def from_user(self):
        """The sending User."""
        return self._user_by_id(self.from_user_id)

    @property
    def to_user(self):
        """The receiving User."""
        return self._user_by_id(self.to_user_id)
class ProductImage(Model):
    """An image file attached to a product, stored relative to ``STATIC_DIR``."""

    __tablename__ = "product_image"
    image = Column(db.String(255))
    order = Column(db.Integer())
    product_id = Column(db.Integer())

    def __str__(self):
        return url_for("static", filename=self.image, _external=True)

    @staticmethod
    def clear_mc(target):
        """Drop the MC_KEY_PRODUCT_IMAGES cache entry of ``target``'s product."""
        rdb.delete(MC_KEY_PRODUCT_IMAGES.format(target.product_id))

    @classmethod
    def __flush_insert_event__(cls, target):
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        # Changing ``image`` or ``order`` affects the cached images of the
        # product as well, so invalidate on update like on insert/delete.
        super().__flush_after_update_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        super().__flush_delete_event__(target)
        target.clear_mc(target)
        # Also remove the file from disk. A row without an image path has no
        # file; joining None onto STATIC_DIR would raise TypeError.
        if target.image:
            image_file = current_app.config["STATIC_DIR"] / target.image
            if image_file.exists():
                image_file.unlink()
class DashboardMenu(Model):
    """An entry of the dashboard menu; ``parent_id == 0`` marks a top-level entry."""

    __tablename__ = "management_dashboard"
    title = Column(db.String(255), nullable=False)
    order = Column(db.Integer(), default=0)
    endpoint = Column(db.String(255))
    icon_cls = Column(db.String(255))
    parent_id = Column(db.Integer(), default=0)

    def __str__(self):
        return self.title

    @property
    def children(self):
        """Direct sub-entries, sorted by ``order`` just like ``first_level_items``."""
        return (
            DashboardMenu.query.filter(DashboardMenu.parent_id == self.id)
            .order_by("order")
            .all()
        )

    @classmethod
    def first_level_items(cls):
        """Top-level entries sorted by ``order``."""
        return cls.query.filter(cls.parent_id == 0).order_by("order").all()

    def is_active(self):
        """True if the request path contains this entry's endpoint or any descendant's."""
        # Substring match: an endpoint counts as active anywhere inside the path.
        if self.endpoint and self.endpoint in request.path:
            return True
        return any(child.is_active() for child in self.children)

    def get_url(self):
        """'#' for entries with children, the dashboard endpoint URL otherwise, else None."""
        if self.children:
            return "#"
        if self.endpoint:
            return url_for("dashboard." + self.endpoint)
        return None
class Page(Model):
    """A static content page addressable by numeric id or by slug."""

    __tablename__ = "public_page"
    title = Column(db.String(255), nullable=False)
    slug = Column(db.String(255))
    content = Column(db.Text())
    is_visible = Column(db.Boolean(), default=True)
    if Config.USE_REDIS:
        content = PropsItem("content", "")

    def get_absolute_url(self):
        identity = self.slug or self.id
        return url_for("public.show_page", identity=identity)

    @classmethod
    @cache(MC_KEY_PAGE_ID.format("{identity}"))
    def get_by_identity(cls, identity):
        """Look a page up by id when ``identity`` is int-like, otherwise by slug (cached)."""
        try:
            int(identity)
        except ValueError:
            return Page.query.filter(Page.slug == identity).first()
        return Page.get_by_id(identity)

    @property
    def url(self):
        return self.get_absolute_url()

    def __str__(self):
        return self.title

    @staticmethod
    def clear_mc(target):
        """Drop the get_by_identity cache entries for ``target``'s id and slug."""
        rdb.delete(MC_KEY_PAGE_ID.format(target.id))
        rdb.delete(MC_KEY_PAGE_ID.format(target.slug))

    @classmethod
    def __flush_insert_event__(cls, target):
        # A lookup of this slug/id may have cached "not found" before the insert.
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        super().__flush_after_update_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        # Without this, a deleted page keeps being served from the cache.
        super().__flush_delete_event__(target)
        target.clear_mc(target)
class PluginRegistry(Model):
    """Persisted enabled flag for a plugin, keyed by its unique name."""

    __tablename__ = "plugin_registry"
    name = Column(db.String(100), unique=True)
    enabled = Column(db.Boolean(), default=True)

    @property
    def info(self):
        """Metadata the app's plugin manager holds for this plugin, or ``{}``."""
        metadata_by_name = current_app.pluggy.plugin_metadata
        return metadata_by_name.get(self.name, {})
class Collection(Model):
    """A named group of products, linked through ProductCollection rows."""

    __tablename__ = "product_collection"
    title = Column(db.String(255), nullable=False)
    background_img = Column(db.String(255))

    def __str__(self):
        return self.title

    def get_absolute_url(self):
        return url_for("product.show_collection", id=self.id)

    @property
    def background_img_url(self):
        return url_for("static", filename=self.background_img)

    def _product_ids(self):
        """Ids of the products currently linked to this collection."""
        rows = (
            ProductCollection.query.with_entities(ProductCollection.product_id)
            .filter_by(collection_id=self.id)
            .all()
        )
        return {product_id for (product_id,) in rows}

    @property
    def products(self):
        return Product.query.filter(Product.id.in_(list(self._product_ids()))).all()

    @property
    def attr_filter(self):
        """Set of product attributes used by the product types of this collection."""
        attr_filter = set()
        for product in self.products:
            product_type = product.product_type
            # ProductType.delete() resets product_type_id to 0; presumably
            # product_type is then None (confirm on Product) — skip such products.
            if product_type is None:
                continue
            attr_filter.update(product_type.product_attributes)
        return attr_filter

    def update_products(self, new_products):
        """Make the linked products exactly ``new_products`` (int-convertible ids) and commit."""
        origin_ids = self._product_ids()
        new_ids = {int(i) for i in new_products}
        for product_id in origin_ids - new_ids:
            link = ProductCollection.query.filter_by(
                collection_id=self.id, product_id=product_id
            ).first()
            # The link may have vanished since origin_ids was read.
            if link is not None:
                link.delete(commit=False)
        for product_id in new_ids - origin_ids:
            db.session.add(
                ProductCollection(collection_id=self.id, product_id=product_id)
            )
        db.session.commit()

    def delete(self):
        """Delete the collection, its product links and its background image file."""
        need_del = ProductCollection.query.filter_by(collection_id=self.id).all()
        for item in need_del:
            item.delete(commit=False)
        db.session.delete(self)
        db.session.commit()
        if self.background_img:
            image = current_app.config["STATIC_DIR"] / self.background_img
            if image.exists():
                image.unlink()
class Message(Model):
    """A single message posted in a conversation."""

    __tablename__ = "conversation_messages"
    conversation_id = Column(db.Integer(), nullable=False)
    user_id = Column(db.Integer(), nullable=False)
    message = Column(db.Text, nullable=False)

    @property
    def user(self):
        """The User who wrote this message, or None if missing."""
        author_query = User.query.filter_by(id=self.user_id)
        return author_query.first()
class AttributeChoiceValue(Model):
    """One selectable value belonging to a ProductAttribute."""

    __tablename__ = "product_attribute_value"
    title = Column(db.String(255), nullable=False)
    attribute_id = Column(db.Integer())

    def __str__(self):
        return self.title

    @property
    def attribute(self):
        """The ProductAttribute identified by ``attribute_id``."""
        owner_id = self.attribute_id
        return ProductAttribute.get_by_id(owner_id)
class OrderPayment(Model):
    """A payment for an order; ``status`` holds a PaymentStatusKinds value."""

    # NOTE(review): SOURCE contains a second OrderPayment class mapping the same
    # "order_payment" table; if both end up in one module/MetaData, SQLAlchemy
    # will refuse the duplicate table definition — confirm which one is live.
    __tablename__ = "order_payment"
    order_id = Column(db.Integer())
    status = Column(TINYINT)
    total = Column(db.DECIMAL(10, 2))
    delivery = Column(db.DECIMAL(10, 2))
    description = Column(db.Text())
    customer_ip_address = Column(db.String(100))
    token = Column(db.String(100))
    payment_method = Column(db.String(255))
    payment_no = Column(db.String(255), unique=True)
    paid_at = Column(db.DateTime())

    def pay_success(self, paid_at):
        """Mark the payment confirmed at ``paid_at`` and propagate success to its order."""
        self.paid_at = paid_at
        self.status = PaymentStatusKinds.confirmed.value
        # Stage this payment in the session so Order.pay_success's commit
        # persists it too, even if the instance was not attached yet.
        self.save(commit=False)
        order = Order.get_by_id(self.order_id)
        order.pay_success(payment=self)

    @property
    def status_human(self):
        """Name of the PaymentStatusKinds member stored in ``status``."""
        return PaymentStatusKinds(int(self.status)).name
class Site(Model):
    """Storefront settings: header text, description and the top/bottom menus."""

    __tablename__ = "public_setting"
    header_text = Column(db.String(255), nullable=False)
    description = Column(db.Text())
    top_menu_id = Column(db.Integer())
    bottom_menu_id = Column(db.Integer())

    @cache(MC_KEY_MENU_ITEMS.format("{self.id}", "{menu_id}"))
    def get_menu_items(self, menu_id):
        """Top-level MenuItems placed at ``menu_id``, sorted by ``order`` (cached).

        MenuItem declares no ``site_id`` column; its placement is stored in
        ``position`` (1 = top, 2 = bottom), so ``menu_id`` is matched against it.
        NOTE(review): assumes top_menu_id/bottom_menu_id hold those position values.
        """
        return (
            MenuItem.query.filter(MenuItem.position == menu_id)
            .filter(MenuItem.parent_id == 0)
            .order_by(MenuItem.order)
            .all()
        )

    @property
    def top_menu_items(self):
        return self.get_menu_items(self.top_menu_id)

    @property
    def bottom_menu_items(self):
        return self.get_menu_items(self.bottom_menu_id)
class ProductCollection(Model):
    """Link row between a Collection and a Product."""

    __tablename__ = "product_collection_product"
    product_id = Column(db.Integer())
    collection_id = Column(db.Integer())

    @classmethod
    @cache_by_args(
        MC_KEY_COLLECTION_PRODUCTS.format("{collection_id}", "{page}"))
    def get_product_by_collection(cls, collection_id, page):
        """Build the (cached) template context for one page of a collection's products."""
        collection = Collection.get_by_id(collection_id)
        at_ids = (ProductCollection.query.with_entities(
            ProductCollection.product_id).filter(
                ProductCollection.collection_id == collection.id).all())
        query = Product.query.filter(Product.id.in_(id for id, in at_ids))
        ctx, query = get_product_list_context(query, collection)
        pagination = query.paginate(page, per_page=16)
        # The query object is dropped before the context is cached.
        del pagination.query
        ctx.update(object=collection,
                   pagination=pagination,
                   products=pagination.items)
        return ctx

    @staticmethod
    def clear_mc(target):
        """Drop every cached page of ``target``'s collection listing."""
        keys = rdb.keys(
            MC_KEY_COLLECTION_PRODUCTS.format(target.collection_id, "*"))
        for key in keys:
            rdb.delete(key)

    @classmethod
    def __flush_insert_event__(cls, target):
        # Call the base hook like every other flush hook in this model set.
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        super().__flush_after_update_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        super().__flush_delete_event__(target)
        target.clear_mc(target)
class Setting(Model):
    """A key/value application setting; ``key`` is the primary key."""

    __tablename__ = "management_setting"
    id = None  # drop the inherited id column: ``key`` is the primary key
    key = Column(db.String(255), primary_key=True)
    value = Column(db.PickleType, nullable=False)
    name = Column(db.String(255), nullable=False)
    description = Column(db.Text, nullable=False)
    value_type = Column(db.Enum(SettingValueType), nullable=False)
    extra = Column(db.PickleType)

    @classmethod
    def get_settings(cls):
        """Return all stored settings as a ``{key: value}`` dict."""
        return {setting.key: setting.value for setting in cls.query.all()}

    @classmethod
    def update(cls, settings):
        """Store new values for existing settings in the database and commit.

        Keys are lower-cased before lookup. Only the database is written; no
        cache is touched here.

        :param settings: A dictionary mapping setting keys to new values.
        :raises KeyError: if a key matches no stored setting. All keys are
            checked before any value is changed.
        """
        pending = []
        for key, value in settings.items():
            setting = cls.query.filter(Setting.key == key.lower()).first()
            if setting is None:
                raise KeyError(f"Unknown setting: {key!r}")
            pending.append((setting, value))
        for setting, value in pending:
            setting.value = value
            db.session.add(setting)
        db.session.commit()

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.key}>"
class OrderPayment(Model):
    """A payment for an order; ``status`` holds a PaymentStatusKinds value."""

    __tablename__ = "order_payment"
    order_id = Column(db.Integer())
    status = Column(db.Integer)
    total = Column(db.DECIMAL(10, 2))
    delivery = Column(db.DECIMAL(10, 2))
    description = Column(db.Text())
    customer_ip_address = Column(db.String(100))
    token = Column(db.String(100))
    payment_method = Column(db.String(255))
    payment_no = Column(db.String(255), unique=True)
    paid_at = Column(db.DateTime())

    def pay_success(self, paid_at):
        """Confirm this payment at ``paid_at`` and let the owning order react."""
        self.paid_at = paid_at
        self.status = PaymentStatusKinds.confirmed.value
        # Stage without committing; the order's pay_success commits.
        self.save(commit=False)
        Order.get_by_id(self.order_id).pay_success(payment=self)

    @property
    def status_human(self):
        """Name of the PaymentStatusKinds member stored in ``status``."""
        kind = PaymentStatusKinds(int(self.status))
        return kind.name
class OrderLine(Model):
    """One purchased item of an order.

    Product name, SKU and unit price are copied from the variant when the
    order is created.
    """

    __tablename__ = "order_line"
    product_name = Column(db.String(255))
    product_sku = Column(db.String(100))
    quantity = Column(db.Integer())
    unit_price_net = Column(db.DECIMAL(10, 2))
    is_shipping_required = Column(db.Boolean(), default=True)
    order_id = Column(db.Integer())
    variant_id = Column(db.Integer())
    product_id = Column(db.Integer())

    @property
    def variant(self):
        """The ProductVariant referenced by ``variant_id``."""
        wanted_id = self.variant_id
        return ProductVariant.get_by_id(wanted_id)

    def get_total(self):
        """Net line total: unit price times quantity."""
        price, count = self.unit_price_net, self.quantity
        return price * count
class MenuItem(Model):
    """A storefront menu entry linking to a URL, page, category or collection."""

    __tablename__ = "public_menuitem"
    title = Column(db.String(255), nullable=False)
    order = Column(db.Integer(), default=0)
    url_ = Column("url", db.String(255))
    category_id = Column(db.Integer(), default=0)
    collection_id = Column(db.Integer(), default=0)
    position = Column(db.Integer(), default=0)  # position of the item in the site: 1 = top, 2 = bottom
    page_id = Column(db.Integer(), default=0)
    parent_id = Column(db.Integer(), default=0)

    def __str__(self):
        return self.title

    @property
    def parent(self):
        return MenuItem.get_by_id(self.parent_id)

    @property
    @cache(MC_KEY_MENU_ITEM_CHILDREN.format("{self.id}"))
    def children(self):
        """Direct sub-items sorted by ``order`` (cached per item id)."""
        return (
            MenuItem.query.filter(MenuItem.parent_id == self.id).order_by("order").all()
        )

    @property
    def linked_object_url(self):
        """URL of the linked page, category or collection (checked in that order), else None."""
        if self.page_id:
            page = Page.get_by_id(self.page_id)
            # The linked page may have been deleted; avoid AttributeError on None.
            return page.url if page is not None else None
        elif self.category_id:
            return url_for("product.show_category", id=self.category_id)
        elif self.collection_id:
            return url_for("product.show_collection", id=self.collection_id)
        return None

    @property
    def url(self):
        """An explicit ``url_`` wins over the linked object's URL."""
        return self.url_ or self.linked_object_url

    @classmethod
    def first_level_items(cls):
        return cls.query.filter(cls.parent_id == 0).order_by("order").all()
class UserAddress(Model):
    """A postal address with contact details, owned by a user."""

    __tablename__ = "account_address"
    user_id = Column(db.Integer())
    province = Column(db.String(255))
    city = Column(db.String(255))
    district = Column(db.String(255))
    address = Column(db.String(255))
    contact_name = Column(db.String(255))
    contact_phone = Column(db.String(80))

    @property
    def full_address(self):
        """Province+city+district, street, contact name and phone, joined with ``<br>``.

        Unset (NULL) parts render as empty strings rather than the text "None".
        NOTE(review): values are interpolated without HTML escaping — if this is
        rendered with ``|safe``, escape user input first.
        """
        province, city, district, address, name, phone = (
            "" if part is None else part
            for part in (
                self.province,
                self.city,
                self.district,
                self.address,
                self.contact_name,
                self.contact_phone,
            )
        )
        return f"{province}{city}{district}<br>{address}<br>{name}<br>{phone}"

    # NOTE(review): a hybrid_property with no SQL expression means class-level
    # access runs User.get_by_id(UserAddress.user_id); a plain @property may be intended.
    @hybrid_property
    def user(self):
        return User.get_by_id(self.user_id)

    def __str__(self):
        return self.full_address
class ProductTypeAttributes(Model):
    """Link between a product type and the product attributes it stores.

    The stored attributes include both user-selectable and non-selectable ones.
    """
    __tablename__ = "product_type_attribute"
    product_type_id = Column(db.Integer())  # ProductType.id (plain integer, no FK declared)
    product_attribute_id = Column(db.Integer())  # ProductAttribute.id (plain integer, no FK declared)
class OrderEvent(Model):
    """Log entry recording something that happened to an order.

    Order's state-transition methods create these with ``type_`` set to an
    ``OrderEvents`` member value.
    """
    __tablename__ = "order_event"
    order_id = Column(db.Integer())  # Order.id the event belongs to
    user_id = Column(db.Integer())  # Order's transition methods pass the order's user_id
    type_ = Column("type", TINYINT())  # stored in DB column "type"; an OrderEvents value
class OrderNote(Model):
    """Free-text note attached to an order (e.g. the customer's checkout note)."""
    __tablename__ = "order_note"
    order_id = Column(db.Integer())  # Order.id the note belongs to
    user_id = Column(db.Integer())  # author of the note
    content = Column(db.Text())
    is_public = Column(db.Boolean(), default=True)  # visibility flag; defaults to public
class ProductType(Model):
    """A kind of product, defining which attributes its products and variants use."""

    __tablename__ = "product_type"
    title = Column(db.String(255), nullable=False)
    has_variants = Column(db.Boolean(), default=True)
    is_shipping_required = Column(db.Boolean(), default=False)

    def __str__(self):
        return self.title

    @staticmethod
    def _attributes_via(link_model, product_type_id):
        """ProductAttributes linked to ``product_type_id`` through ``link_model`` rows."""
        rows = (
            link_model.query.with_entities(link_model.product_attribute_id)
            .filter(link_model.product_type_id == product_type_id)
            .all()
        )
        return ProductAttribute.query.filter(
            ProductAttribute.id.in_([attr_id for (attr_id,) in rows])
        ).all()

    @property
    def product_attributes(self):
        return self._attributes_via(ProductTypeAttributes, self.id)

    @property
    def variant_attributes(self):
        return self._attributes_via(ProductTypeVariantAttributes, self.id)

    @property
    def variant_attr_id(self):
        """Id of the first variant attribute, or None if there is none."""
        # Evaluate the property once: every access runs two queries.
        variant_attributes = self.variant_attributes
        return variant_attributes[0].id if variant_attributes else None

    def update_product_attr(self, new_attrs):
        """Make the linked product attributes exactly ``new_attrs`` (int-convertible ids)."""
        rows = (
            ProductTypeAttributes.query.with_entities(
                ProductTypeAttributes.product_attribute_id
            )
            .filter_by(product_type_id=self.id)
            .all()
        )
        origin_ids = {attr_id for (attr_id,) in rows}
        new_ids = {int(i) for i in new_attrs}
        for attr_id in origin_ids - new_ids:
            link = ProductTypeAttributes.query.filter_by(
                product_type_id=self.id, product_attribute_id=attr_id
            ).first()
            if link is not None:
                link.delete(commit=False)
        for attr_id in new_ids - origin_ids:
            db.session.add(
                ProductTypeAttributes(
                    product_type_id=self.id, product_attribute_id=attr_id
                )
            )
        db.session.commit()

    def update_variant_attr(self, variant_attr):
        """Set the single variant attribute of this type, creating the link if needed."""
        origin_attr = ProductTypeVariantAttributes.query.filter_by(
            product_type_id=self.id
        ).first()
        if origin_attr:
            origin_attr.product_attribute_id = variant_attr
            origin_attr.save()
        else:
            ProductTypeVariantAttributes.create(
                product_type_id=self.id, product_attribute_id=variant_attr
            )

    def delete(self):
        """Delete the type and its attribute links; its products get product_type_id 0."""
        need_del_product_attrs = ProductTypeAttributes.query.filter_by(
            product_type_id=self.id
        ).all()
        need_del_variant_attrs = ProductTypeVariantAttributes.query.filter_by(
            product_type_id=self.id
        ).all()
        for item in itertools.chain(need_del_product_attrs, need_del_variant_attrs):
            item.delete(commit=False)
        need_update_products = Product.query.filter_by(product_type_id=self.id).all()
        for product in need_update_products:
            product.product_type_id = 0
            db.session.add(product)
        db.session.delete(self)
        db.session.commit()
class Order(Model):
    """A placed order, created from a cart by ``create_whole_order``.

    ``status`` holds an OrderStatusKinds value and ``ship_status`` a
    ShipStatusKinds value; each state transition also records an OrderEvent.
    """

    __tablename__ = "order_order"
    token = Column(db.String(100), unique=True)
    shipping_address = Column(db.String(255))
    user_id = Column(db.Integer())
    total_net = Column(db.DECIMAL(10, 2))
    discount_amount = Column(db.DECIMAL(10, 2), default=0)
    discount_name = Column(db.String(100))
    voucher_id = Column(db.Integer())
    shipping_price_net = Column(db.DECIMAL(10, 2))
    status = Column(TINYINT())
    shipping_method_name = Column(db.String(100))
    shipping_method_id = Column(db.Integer())
    ship_status = Column(TINYINT())

    def __str__(self):
        return f"#{self.identity}"

    @classmethod
    def create_whole_order(cls, cart, note=None):
        """Turn ``cart`` into an Order, reserve stock and delete the cart.

        Returns ``(order, "success")`` on success, or ``(False, message)`` when
        stock, the voucher or order creation fails. Variants are only modified
        after every check passed, so a failed attempt leaves no dirty stock
        allocations behind in the session.
        """
        # Step1, certify stock and voucher (read-only).
        allocations = {}  # variant id -> [variant, total quantity over all cart lines]
        to_update_orderlines = []
        total_net = 0
        for line in cart.lines:
            variant = ProductVariant.get_by_id(line.variant.id)
            entry = allocations.setdefault(variant.id, [variant, 0])
            # Check the cumulative quantity so a variant that appears on several
            # lines cannot be oversold.
            result, msg = entry[0].check_enough_stock(entry[1] + line.quantity)
            if result is False:
                return result, msg
            entry[1] += line.quantity
            orderline = OrderLine(
                variant_id=variant.id,
                quantity=line.quantity,
                product_name=variant.display_product(),
                product_sku=variant.sku,
                product_id=variant.product_id,
                unit_price_net=variant.price,
                is_shipping_required=variant.is_shipping_required,
            )
            to_update_orderlines.append(orderline)
            total_net += orderline.get_total()

        voucher = None
        if cart.voucher_code:
            voucher = Voucher.get_by_code(cart.voucher_code)
            if voucher is None:
                return False, "The voucher code does not exist"
            try:
                voucher.check_available(cart)
            except Exception as e:
                return False, str(e)

        # Step2, create Order obj
        try:
            shipping_method_id = None
            shipping_method_title = None
            shipping_method_price = 0
            shipping_address = None

            if cart.shipping_method_id:
                shipping_method = ShippingMethod.get_by_id(cart.shipping_method_id)
                shipping_method_id = shipping_method.id
                shipping_method_title = shipping_method.title
                shipping_method_price = shipping_method.price
                shipping_address = UserAddress.get_by_id(
                    cart.shipping_address_id
                ).full_address

            order = cls.create(
                user_id=current_user.id,
                token=str(uuid4()),
                shipping_method_id=shipping_method_id,
                shipping_method_name=shipping_method_title,
                shipping_price_net=shipping_method_price,
                shipping_address=shipping_address,
                status=OrderStatusKinds.unfulfilled.value,
                total_net=total_net,
            )
        except Exception as e:
            return False, str(e)

        # Step3, reserve stock and persist everything else in one commit.
        for variant, quantity in allocations.values():
            variant.quantity_allocated += quantity
            db.session.add(variant)

        if note:
            order_note = OrderNote(order_id=order.id, user_id=current_user.id, content=note)
            db.session.add(order_note)
        if voucher:
            order.voucher_id = voucher.id
            order.discount_amount = voucher.get_vouchered_price(cart)
            order.discount_name = voucher.title
            voucher.used += 1
            db.session.add(order)
            db.session.add(voucher)
        for orderline in to_update_orderlines:
            orderline.order_id = order.id
            db.session.add(orderline)
        for line in cart.lines:
            db.session.delete(line)
        db.session.delete(cart)

        db.session.commit()
        return order, "success"

    def get_absolute_url(self):
        return url_for("order.show", token=self.token)

    @property
    def identity(self):
        """Last dash-separated segment of the UUID token."""
        return self.token.split("-")[-1]

    @property
    def total(self):
        return self.total_net + self.shipping_price_net - self.discount_amount

    @property
    def status_human(self):
        return OrderStatusKinds(int(self.status)).name

    @property
    def total_human(self):
        return "$" + str(self.total)

    @classmethod
    def get_current_user_orders(cls):
        """Orders of the logged-in user, newest first; [] for anonymous users."""
        if current_user.is_authenticated:
            orders = (
                cls.query.filter_by(user_id=current_user.id)
                .order_by(Order.id.desc())
                .all()
            )
        else:
            orders = []
        return orders

    @classmethod
    def get_user_orders(cls, user_id):
        return cls.query.filter_by(user_id=user_id).all()

    @property
    def is_shipping_required(self):
        return any(line.is_shipping_required for line in self.lines)

    @property
    def is_self_order(self):
        # Same guard as get_current_user_orders: anonymous users may expose no id.
        return current_user.is_authenticated and self.user_id == current_user.id

    @property
    def lines(self):
        return OrderLine.query.filter(OrderLine.order_id == self.id).all()

    @property
    def notes(self):
        return OrderNote.query.filter(OrderNote.order_id == self.id).all()

    @property
    def user(self):
        return User.get_by_id(self.user_id)

    @property
    def payment(self):
        return OrderPayment.query.filter_by(order_id=self.id).first()

    def pay_success(self, payment):
        """Mark the order fulfilled and turn reserved stock into sold stock."""
        self.status = OrderStatusKinds.fulfilled.value
        # to resolve another instance with key is already present in this session
        local_obj = db.session.merge(self)
        db.session.add(local_obj)

        for line in self.lines:
            variant = line.variant
            variant.quantity_allocated -= line.quantity
            variant.quantity -= line.quantity
            db.session.add(variant)

        db.session.commit()

        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.payment_captured.value,
        )

    def cancel(self):
        """Cancel the order, releasing reserved stock if it still holds any.

        A canceled order has already released its reservation, and pay_success
        consumes it, so neither is released again.
        NOTE(review): assumes completed/shipped orders went through pay_success.
        """
        current = None if self.status is None else int(self.status)
        holds_allocation = current not in (
            OrderStatusKinds.canceled.value,
            OrderStatusKinds.fulfilled.value,
            OrderStatusKinds.completed.value,
            OrderStatusKinds.shipped.value,
        )
        self.status = OrderStatusKinds.canceled.value
        db.session.add(self)
        if holds_allocation:
            for line in self.lines:
                variant = line.variant
                variant.quantity_allocated -= line.quantity
                db.session.add(variant)
        db.session.commit()

        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.order_canceled.value,
        )

    def complete(self):
        self.update(status=OrderStatusKinds.completed.value)
        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.order_completed.value,
        )

    def draft(self):
        self.update(status=OrderStatusKinds.draft.value)
        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.draft_created.value,
        )

    def delivered(self):
        self.update(
            status=OrderStatusKinds.shipped.value,
            ship_status=ShipStatusKinds.delivered.value,
        )
        OrderEvent.create(
            order_id=self.id,
            user_id=self.user_id,
            type_=OrderEvents.order_delivered.value,
        )
class SaleProduct(Model):
    """Link row putting a single product on a Sale."""
    __tablename__ = "discount_sale_product"
    sale_id = Column(db.Integer())  # Sale.id
    product_id = Column(db.Integer())  # Product.id
class SaleCategory(Model):
    """Link row putting a whole category on a Sale."""
    __tablename__ = "discount_sale_category"
    sale_id = Column(db.Integer())  # Sale.id
    category_id = Column(db.Integer())  # Category.id
class Voucher(Model):
    """A discount code redeemable at checkout.

    ``type_`` (DB column ``type``) holds a VoucherTypeKinds value selecting what
    the discount applies to; ``discount_value_type`` holds a
    DiscountValueTypeKinds value (fixed amount or percent).
    """

    __tablename__ = "discount_voucher"
    type_ = Column("type", TINYINT())
    title = Column(db.String(255))
    code = Column(db.String(16), unique=True)
    usage_limit = Column(db.Integer())
    used = Column(db.Integer(), default=0)
    start_date = Column(db.Date())
    end_date = Column(db.Date())
    discount_value_type = Column(TINYINT())
    discount_value = Column(db.DECIMAL(10, 2))
    limit = Column(db.DECIMAL(10, 2))
    category_id = Column(db.Integer())
    product_id = Column(db.Integer())

    def __str__(self):
        return self.title

    @property
    def type_human(self):
        return VoucherTypeKinds(int(self.type_)).name

    @property
    def discount_value_type_human(self):
        return DiscountValueTypeKinds(int(self.discount_value_type)).name

    @property
    def validity_period(self):
        """'MM/DD/YYYY - MM/DD/YYYY' when both dates are set, else ''."""
        if self.start_date and self.end_date:
            return (
                self.start_date.strftime("%m/%d/%Y")
                + " - "
                + self.end_date.strftime("%m/%d/%Y")
            )
        return ""

    @classmethod
    def generate_code(cls):
        """Return a random 16-letter uppercase code not used by any voucher yet."""
        while True:
            code = "".join(random.choices(string.ascii_uppercase, k=16))
            if not cls.query.filter_by(code=code).first():
                return code

    @staticmethod
    def _as_date(value):
        """Normalize a date/datetime to a date so it compares with today's date."""
        return value.date() if isinstance(value, datetime) else value

    def check_available(self, cart=None):
        """Return True if the voucher is usable now (and for ``cart``), else raise.

        The valid period includes both start_date and end_date.

        :raises Exception: with a user-facing message describing the problem.
        """
        # The date columns hold dates; comparing them with datetime.now()
        # directly would raise TypeError.
        today = datetime.now().date()
        if self.start_date and self._as_date(self.start_date) > today:
            raise Exception("The voucher code can not use now, please retry later")
        if self.end_date and self._as_date(self.end_date) < today:
            raise Exception("The voucher code has expired")
        # Exhausted once the recorded uses reach the limit.
        if self.usage_limit and (self.used or 0) >= self.usage_limit:
            raise Exception("This voucher code has been used out")
        if cart:
            self.check_available_by_cart(cart)
        return True

    def check_available_by_cart(self, cart):
        """Raise if ``cart`` does not meet this voucher's type-specific minimum (``limit``)."""
        if self.type_ == VoucherTypeKinds.value.value:
            if self.limit and cart.subtotal < self.limit:
                raise Exception(
                    f"The order total amount is not enough({self.limit}) to use this voucher code"
                )
        elif self.type_ == VoucherTypeKinds.shipping.value:
            if self.limit and cart.shipping_method_price < self.limit:
                raise Exception(
                    f"The order shipping price is not enough({self.limit}) to use this voucher code"
                )
        elif self.type_ == VoucherTypeKinds.product.value:
            product = Product.get_by_id(self.product_id)
            # got any product in cart, should be zero
            if cart.get_product_price(self.product_id) == 0:
                raise Exception(f"This Voucher Code should be used for {product.title}")
            if self.limit and cart.get_product_price(self.product_id) < self.limit:
                raise Exception(
                    f"The product {product.title} total amount is not enough({self.limit}) to use this voucher code"
                )
        elif self.type_ == VoucherTypeKinds.category.value:
            category = Category.get_by_id(self.category_id)
            if cart.get_category_price(self.category_id) == 0:
                raise Exception(f"This Voucher Code should be used for {category.title}")
            if self.limit and cart.get_category_price(self.category_id) < self.limit:
                raise Exception(
                    f"The category {category.title} total amount is not enough({self.limit}) to use this voucher code"
                )

    @classmethod
    def get_by_code(cls, code):
        return cls.query.filter_by(code=code).first()

    def get_vouchered_price(self, cart):
        """Discount amount this voucher grants for ``cart`` (0 for unknown types)."""
        if self.type_ == VoucherTypeKinds.value.value:
            return self.get_voucher_from_price(cart.subtotal)
        elif self.type_ == VoucherTypeKinds.shipping.value:
            return self.get_voucher_from_price(cart.shipping_method_price)
        elif self.type_ == VoucherTypeKinds.product.value:
            return self.get_voucher_from_price(cart.get_product_price(self.product_id))
        elif self.type_ == VoucherTypeKinds.category.value:
            return self.get_voucher_from_price(cart.get_category_price(self.category_id))
        return 0

    def get_voucher_from_price(self, price):
        """Discount amount for a base ``price``; a fixed discount is capped at ``price``."""
        if self.discount_value_type == DiscountValueTypeKinds.fixed.value:
            return self.discount_value if price > self.discount_value else price
        elif self.discount_value_type == DiscountValueTypeKinds.percent.value:
            price = price * self.discount_value / 100
            return Decimal(price).quantize(Decimal("0.00"))
        # Unknown discount type: no discount rather than None, which would
        # break Order.total arithmetic.
        return 0
class Sale(Model):
    """A catalogue discount applied to selected products and/or categories."""

    __tablename__ = "discount_sale"
    discount_value_type = Column(TINYINT())
    title = Column(db.String(255))
    discount_value = Column(db.DECIMAL(10, 2))

    def __str__(self):
        return self.title

    @property
    def discount_value_type_label(self):
        return DiscountValueTypeKinds(int(self.discount_value_type)).name

    @classmethod
    def get_discounted_price(cls, product):
        """Return the discount amount a sale grants ``product`` (0 if none applies).

        A sale attached to the product itself takes precedence over one attached
        to its category.
        """
        sale_product = SaleProduct.query.filter_by(product_id=product.id).first()
        if sale_product:
            sale = Sale.get_by_id(sale_product.sale_id)
        else:
            category = product.category
            # Products without a category cannot match a category sale.
            sale_category = (
                SaleCategory.query.filter_by(category_id=category.id).first()
                if category is not None
                else None
            )
            sale = Sale.get_by_id(sale_category.sale_id) if sale_category else None
        if sale is None:
            return 0
        if sale.discount_value_type == DiscountValueTypeKinds.fixed.value:
            return sale.discount_value
        elif sale.discount_value_type == DiscountValueTypeKinds.percent.value:
            price = product.basic_price * sale.discount_value / 100
            return Decimal(price).quantize(Decimal("0.00"))
        # Unknown discount type: no discount instead of an implicit None.
        return 0

    @property
    def categories(self):
        at_ids = (
            SaleCategory.query.with_entities(SaleCategory.category_id)
            .filter(SaleCategory.sale_id == self.id)
            .all()
        )
        return Category.query.filter(Category.id.in_([i for (i,) in at_ids])).all()

    @property
    def products_ids(self):
        """One-element ``(product_id,)`` rows for the products on this sale."""
        return (
            SaleProduct.query.with_entities(SaleProduct.product_id)
            .filter(SaleProduct.sale_id == self.id)
            .all()
        )

    @property
    def products(self):
        return Product.query.filter(
            Product.id.in_([i for (i,) in self.products_ids])
        ).all()

    def _sync_links(self, link_model, target_field, new_ids):
        """Make this sale's ``link_model`` rows point at exactly ``new_ids``, then commit."""
        target_col = getattr(link_model, target_field)
        rows = link_model.query.with_entities(target_col).filter_by(sale_id=self.id).all()
        origin_ids = {i for (i,) in rows}
        wanted = {int(i) for i in new_ids}
        for obsolete in origin_ids - wanted:
            link = link_model.query.filter_by(
                sale_id=self.id, **{target_field: obsolete}
            ).first()
            if link is not None:
                link.delete(commit=False)
        for missing in wanted - origin_ids:
            db.session.add(link_model(sale_id=self.id, **{target_field: missing}))
        db.session.commit()

    def update_categories(self, category_ids):
        self._sync_links(SaleCategory, "category_id", category_ids)

    def update_products(self, product_ids):
        self._sync_links(SaleProduct, "product_id", product_ids)

    @staticmethod
    def clear_mc(target):
        # Sale changes can affect many products (directly or through categories),
        # so every cached product discount is dropped rather than a targeted subset.
        keys = rdb.keys(MC_KEY_PRODUCT_DISCOUNT_PRICE.format("*"))
        for key in keys:
            rdb.delete(key)

    @classmethod
    def __flush_insert_event__(cls, target):
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        super().__flush_after_update_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        super().__flush_delete_event__(target)
        target.clear_mc(target)
class ProductTypeVariantAttributes(Model):
    """Link between a product type and the SKU (variant) attributes users can choose from."""
    __tablename__ = "product_type_variant_attribute"
    product_type_id = Column(db.Integer())  # ProductType.id (plain integer, no FK declared)
    product_attribute_id = Column(db.Integer())  # ProductAttribute.id (plain integer, no FK declared)
class ProductVariant(Model):
    """A purchasable SKU of a product with its own stock and optional price override."""

    __tablename__ = "product_variant"
    sku = Column(db.String(32), unique=True)
    title = Column(db.String(255))
    price_override = Column(db.DECIMAL(10, 2), default=0.00)
    quantity = Column(db.Integer(), default=0)
    quantity_allocated = Column(db.Integer(), default=0)
    product_id = Column(db.Integer(), default=0)
    attributes = Column(MutableDict.as_mutable(db.JSON()))

    def __str__(self):
        return self.title or self.sku

    def display_product(self):
        return f"{self.product} ({str(self)})"

    @property
    def sku_id(self):
        """Second dash-separated segment of the SKU."""
        return self.sku.split("-")[1]

    @sku_id.setter
    def sku_id(self, data):
        # Intentionally ignores assignments.
        pass

    @property
    def is_shipping_required(self):
        return self.product.product_type.is_shipping_required

    @property
    def quantity_available(self):
        """Unallocated stock, never below zero."""
        return max(self.quantity - self.quantity_allocated, 0)

    @property
    def is_in_stock(self):
        return self.quantity_available > 0

    @property
    def stock(self):
        """Unallocated stock; unlike quantity_available this may be negative."""
        return self.quantity - self.quantity_allocated

    @property
    def price(self):
        """The override price if set (non-zero), else the product's price."""
        return self.price_override or self.product.price

    @property
    def product(self):
        return Product.get_by_id(self.product_id)

    def get_absolute_url(self):
        return url_for("product.show", id=self.product.id)

    @property
    def attribute_map(self):
        """Map ProductAttribute -> chosen AttributeChoiceValue for this variant.

        ``attributes`` maps attribute ids to choice-value ids. The JSON column
        is nullable; a variant without attributes yields an empty map.
        """
        return {
            ProductAttribute.get_by_id(attr_id): AttributeChoiceValue.get_by_id(value_id)
            for attr_id, value_id in (self.attributes or {}).items()
        }

    def check_enough_stock(self, quantity):
        """Return ``(True, "success")`` or ``(False, message)`` if stock < ``quantity``."""
        if self.stock < quantity:
            return False, f"{self.display_product()} has not enough stock"
        return True, "success"

    @staticmethod
    def clear_mc(target):
        rdb.delete(MC_KEY_PRODUCT_VARIANT.format(target.product_id))

    @classmethod
    def __flush_insert_event__(cls, target):
        super().__flush_insert_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_after_update_event__(cls, target):
        super().__flush_after_update_event__(target)
        target.clear_mc(target)

    @classmethod
    def __flush_delete_event__(cls, target):
        super().__flush_delete_event__(target)
        target.clear_mc(target)
class ProductAttribute(Model):
    """A product attribute (e.g. selectable option) with a list of choice values."""

    __tablename__ = "product_attribute"
    title = Column(db.String(255), nullable=False)

    def __str__(self):
        return self.title

    @property
    @cache(MC_KEY_ATTRIBUTE_VALUES.format("{self.id}"))
    def values(self):
        """AttributeChoiceValues of this attribute (cached per attribute id)."""
        return AttributeChoiceValue.query.filter(
            AttributeChoiceValue.attribute_id == self.id
        ).all()

    @property
    def values_label(self):
        return ",".join([value.title for value in self.values])

    @property
    def types(self):
        at_ids = (
            ProductTypeAttributes.query.with_entities(ProductTypeAttributes.product_type_id)
            .filter_by(product_attribute_id=self.id)
            .all()
        )
        return ProductType.query.filter(ProductType.id.in_([i for (i,) in at_ids])).all()

    @property
    def types_label(self):
        return ",".join([t.title for t in self.types])

    def update_values(self, new_values):
        """Make this attribute's choice values exactly the titles in ``new_values``."""
        current_values = self.values
        current_titles = {value.title for value in current_values}
        wanted = set(new_values)
        for value in current_values:
            if value.title not in wanted:
                value.delete(commit=False)
        for title in wanted - current_titles:
            db.session.add(AttributeChoiceValue(title=title, attribute_id=self.id))
        db.session.commit()
        # Only this class's flush hooks clear the ``values`` cache, and this
        # attribute row itself is not modified here (AttributeChoiceValue has
        # no cache hooks), so invalidate explicitly.
        rdb.delete(MC_KEY_ATTRIBUTE_VALUES.format(self.id))

    def update_types(self, new_types):
        """Make the linked product types exactly ``new_types`` (int-convertible ids)."""
        rows = (
            ProductTypeAttributes.query.with_entities(ProductTypeAttributes.product_type_id)
            .filter_by(product_attribute_id=self.id)
            .all()
        )
        origin_ids = {type_id for (type_id,) in rows}
        new_ids = {int(i) for i in new_types}
        for type_id in origin_ids - new_ids:
            link = ProductTypeAttributes.query.filter_by(
                product_attribute_id=self.id, product_type_id=type_id
            ).first()
            if link is not None:
                link.delete(commit=False)
        for type_id in new_ids - origin_ids:
            db.session.add(
                ProductTypeAttributes(product_attribute_id=self.id, product_type_id=type_id)
            )
        db.session.commit()

    def delete(self):
        """Delete the attribute, its type links and its choice values."""
        need_del_product_attrs = ProductTypeAttributes.query.filter_by(
            product_attribute_id=self.id
        ).all()
        need_del_variant_attrs = ProductTypeVariantAttributes.query.filter_by(
            product_attribute_id=self.id
        ).all()
        for item in itertools.chain(
            need_del_product_attrs, need_del_variant_attrs, self.values
        ):
            item.delete(commit=False)
        db.session.delete(self)
        db.session.commit()

    @classmethod
    def __flush_after_update_event__(cls, target):
        super().__flush_after_update_event__(target)
        rdb.delete(MC_KEY_ATTRIBUTE_VALUES.format(target.id))

    @classmethod
    def __flush_delete_event__(cls, target):
        super().__flush_delete_event__(target)
        rdb.delete(MC_KEY_ATTRIBUTE_VALUES.format(target.id))