# This file is part of Shuup.
#
# Copyright (c) 2012-2021, Shuup Commerce Inc. All rights reserved.
#
# This source code is licensed under the OSL-3.0 license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import unicode_literals
from collections import defaultdict
from decimal import Decimal
from django.utils.translation import ugettext as _
from itertools import chain
from shuup.core.fields.utils import ensure_decimal_places
from shuup.utils.money import Money
from ._line_tax import LineTax
class TaxSummary(list):
    """
    List of `TaxSummaryLine` objects, one per tax (plus an optional
    "Untaxed" line), as built by `from_line_taxes`.
    """

    @classmethod
    def from_line_taxes(cls, line_taxes, untaxed):
        """
        Create TaxSummary from LineTaxes.

        The tax amount, the raw base amount and the rounded base amount
        of the given line taxes are summed separately for each tax, and
        one summary line is created per tax.  If `untaxed` is truthy, an
        extra line with tax rate 0 and no tax id is appended for it.

        :param line_taxes: List of line taxes to summarize.
        :type line_taxes: list[LineTax]
        :param untaxed: Sum of taxless prices that have no taxes added.
        :type untaxed: shuup.core.pricing.TaxlessPrice
        :return: Summary lines sorted by `TaxSummaryLine.get_sort_key`.
        :rtype: TaxSummary
        """
        zero_amount = Money(0, untaxed.currency)
        tax_amount_by_tax = defaultdict(lambda: zero_amount)
        raw_base_amount_by_tax = defaultdict(lambda: zero_amount)
        base_amount_by_tax = defaultdict(lambda: zero_amount)

        for line_tax in line_taxes:
            assert isinstance(line_tax, LineTax)
            tax_amount_by_tax[line_tax.tax] += line_tax.amount
            raw_base_amount_by_tax[line_tax.tax] += line_tax.base_amount
            # The rounded total is a sum of per-line rounded base
            # amounts, so it may differ from rounding the raw total.
            base_amount_by_tax[line_tax.tax] += line_tax.base_amount.as_rounded()

        lines = [
            TaxSummaryLine.from_tax(tax, base_amount_by_tax[tax], raw_base_amount_by_tax[tax], tax_amount)
            for (tax, tax_amount) in tax_amount_by_tax.items()
        ]

        if untaxed:
            lines.append(
                TaxSummaryLine(
                    tax_id=None,
                    tax_code="",
                    tax_name=_("Untaxed"),
                    tax_rate=Decimal(0),
                    based_on=untaxed.amount.as_rounded(),
                    raw_based_on=untaxed.amount,
                    tax_amount=zero_amount,
                )
            )

        return cls(sorted(lines, key=TaxSummaryLine.get_sort_key))

    def __repr__(self):
        """
        Return the plain list representation wrapped in the class name,
        e.g. ``TaxSummary([...])``.
        """
        super_repr = super(TaxSummary, self).__repr__()
        return "%s(%s)" % (type(self).__name__, super_repr)
class TaxSummaryLine(object):
    """
    One line of a tax summary: identifying fields of a tax together
    with the amount the tax is based on, the tax amount and their
    rounded sum (`taxful`).
    """

    # Attribute names serialized by `to_dict`, in this order.
    _FIELDS = ["tax_id", "tax_code", "tax_name", "tax_rate", "raw_based_on", "based_on", "tax_amount", "taxful"]
    # Fields that are allowed to hold `Money` values; `_serialize_field`
    # rejects a `Money` value found in any other field.
    _MONEY_FIELDS = {"tax_amount", "taxful", "based_on", "raw_based_on"}

    @classmethod
    def from_tax(cls, tax, based_on, raw_based_on, tax_amount):
        """
        Create a summary line using the id, code, name and rate of `tax`.

        :param tax: Object providing `id`, `code`, `name` and `rate`.
        :param based_on: Rounded amount the tax is based on.
        :param raw_based_on: Unrounded amount the tax is based on.
        :param tax_amount: Amount of the tax.
        :rtype: TaxSummaryLine
        """
        return cls(
            tax_id=tax.id,
            tax_code=tax.code,
            tax_name=tax.name,
            tax_rate=tax.rate,
            based_on=based_on,
            raw_based_on=raw_based_on,
            tax_amount=tax_amount,
        )

    def __init__(self, tax_id, tax_code, tax_name, tax_rate, based_on, raw_based_on, tax_amount):
        """
        Store the given fields and compute `taxful`.

        The three amounts are passed through `ensure_decimal_places`
        before being stored.
        """
        self.tax_id = tax_id
        self.tax_code = tax_code
        self.tax_name = tax_name
        self.tax_rate = tax_rate
        self.raw_based_on = ensure_decimal_places(raw_based_on)
        self.based_on = ensure_decimal_places(based_on)
        self.tax_amount = ensure_decimal_places(tax_amount)
        # NOTE: the sum uses the `tax_amount` argument as passed in, not
        # the value stored in `self.tax_amount`.
        self.taxful = (self.raw_based_on + tax_amount).as_rounded()

    def get_sort_key(self):
        """
        Return a key that sorts lines by descending tax rate, then by
        tax name.

        A missing (``None``) tax rate is treated as 0, the same way
        `__repr__` treats it.

        :rtype: tuple
        """
        # Negate *after* the fallback: ``-self.tax_rate or 0`` would
        # apply the unary minus first and fail on a ``None`` rate.
        return (-(self.tax_rate or 0), self.tax_name)

    def __repr__(self):
        return "<{} {}/{}/{:.3%} based_on={} tax_amount={})>".format(
            type(self).__name__, self.tax_id, self.tax_code, float(self.tax_rate or 0), self.based_on, self.tax_amount
        )

    def to_dict(self):
        """
        Serialize the fields listed in `_FIELDS` into a dict.

        `Money` values produce two entries: the field itself and a
        ``<field>_currency`` entry (see `_serialize_field`).

        :rtype: dict
        """
        return dict(chain.from_iterable(self._serialize_field(x) for x in self._FIELDS))

    def _serialize_field(self, key):
        """
        Return the attribute `key` as a list of ``(name, value)`` pairs.

        :raises TypeError:
          If the value is `Money` but `key` is not in `_MONEY_FIELDS`.
        """
        value = getattr(self, key)
        if isinstance(value, Money):
            if key not in self._MONEY_FIELDS:
                raise TypeError('Error! Non-price field "%s" has %r.' % (key, value))
            return [(key, value.value), (key + "_currency", value.currency)]
        return [(key, value)]