# -*- coding: utf-8 -*-
# This file is part of Shuup.
#
# Copyright (c) 2012-2017, Shoop Commerce Ltd. All rights reserved.
#
# This source code is licensed under the OSL-3.0 license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import unicode_literals, with_statement
import datetime
import itertools
from operator import iand, ior
import dateutil.parser
import six
import xlrd
from django.db.models import AutoField, ForeignKey, Q
from django.db.models.fields import BooleanField, DateField, DateTimeField
from django.db.models.fields.related import RelatedField
from django.db.transaction import atomic
from django.utils.text import force_text
from django.utils.translation import ugettext_lazy as _
from enumfields import EnumIntegerField
from shuup.importer._mapper import RelatedMapper
from shuup.importer.exceptions import ImporterError
from shuup.importer.importing.meta import ImportMetaBase
from shuup.importer.importing.session import DataImporterRowSession
from shuup.importer.utils import copy_update, fold_mapping_name
from shuup.importer.utils.importer import ImportMode
class DataImporter(object):
    """
    Base class for data importers.

    Subclasses set ``model`` and may provide an import meta class via a
    ``get_import_meta`` classmethod on the model (see ``meta_class_getter_name``).
    """
    meta_class_getter_name = "get_import_meta"
    meta_base_class = ImportMetaBase

    # Class-level defaults only; instances get their own copies in __init__
    # (these dicts/sets are mutated through `self` during an import).
    extra_matches = {}
    unique_fields = {}
    unmatched_fields = set()
    relation_map_cache = {}
    model = None

    def __init__(self, data, shop, language):
        """
        :param data: Sequence of row dicts to import; the first row's keys
                     define the available columns.
        :param shop: Shop the import is scoped to.
        :param language: Language code used for translated fields.
        """
        # BUGFIX: previously these were shared, class-level mutables that were
        # mutated via `self` (e.g. `self.unique_fields[...] = ...`), leaking
        # state between importer instances. Bind per-instance copies instead.
        self.extra_matches = {}
        self.unique_fields = {}
        self.unmatched_fields = set()
        self.relation_map_cache = {}
        self.shop = shop
        self.data = data
        # NOTE(review): assumes at least one data row; raises IndexError on
        # an empty import -- confirm callers validate this upstream.
        self.data_keys = data[0].keys()
        self.language = language
        meta_class_getter = getattr(self.model, self.meta_class_getter_name, None)
        meta_class = meta_class_getter() if meta_class_getter else self.meta_base_class
        self._meta = (meta_class(self, self.model) if meta_class else None)
        self.field_defaults = self._meta.get_import_defaults()
[docs] def process_data(self):
mapping = self.create_mapping()
data_map = self.map_data_to_fields(mapping)
return data_map
[docs] def create_mapping(self):
mapping = {}
aliases = self._meta.field_aliases
for model in self.get_related_models():
for field, mode in self._get_fields_with_modes(model):
map_base = self._get_map_base(field, mode)
if isinstance(field, RelatedField) and not field.null:
map_base["priority"] -= 10
# Figure out names
names = [field.name]
if field.verbose_name:
names.append(field.verbose_name)
# find aliases
this_aliases = aliases.get(field.name)
if this_aliases:
if isinstance(this_aliases, six.string_types):
this_aliases = [this_aliases]
names.extend(this_aliases)
# Assign into mapping
for name in names:
if map_base.get("translated"):
mapping[name] = copy_update(map_base, lang=self.language)
else:
mapping[name] = map_base
mapping = dict((fold_mapping_name(mname), mdata) for (mname, mdata) in six.iteritems(mapping))
self.mapping = mapping
return mapping
[docs] def map_data_to_fields(self, model_mapping):
"""
Map fields
If field is not found it will be saved into unmapped
:return:
"""
# reset unmatched here
self.unmatched_fields = set()
data_map = {}
for field_name in sorted(self.data_keys):
mfname = fold_mapping_name(field_name)
mapped_value = model_mapping.get(mfname)
if not mapped_value:
for fld, opt in six.iteritems(model_mapping):
matcher = opt.get("matcher")
if matcher and (matcher(field_name) or matcher(mfname)):
mapped_value = opt
break
if mapped_value:
data_map[field_name] = mapped_value
if mapped_value.get("keyable"):
self.unique_fields[field_name] = mapped_value
elif not mapped_value and not self._meta.has_post_save_handler(field_name):
self.unmatched_fields.add(field_name)
self.data_map = data_map
return data_map
[docs] def manually_match(self, imported_field_name, target_field_name):
if target_field_name == "0": # nothing was selected
return
target_model, shuup_field_name = target_field_name.split(":")
mapping = self.mapping.get(shuup_field_name)
mapping["matcher"] = self.matcher
mapping["setter"] = self.set_extra_match
self.extra_matches[target_field_name] = imported_field_name
self.mapping[shuup_field_name] = mapping
return self.mapping
    def do_remap(self):
        """Re-run column matching, e.g. after manual matches have been added."""
        self.map_data_to_fields(self.mapping)
[docs] def matcher(self, value):
for original_field, new_field in six.iteritems(self.extra_matches):
if new_field == value:
return True
return False
[docs] def do_import(self, import_mode):
self.import_mode = import_mode
self.other_log_messages = []
self.new_objects = []
self.updated_objects = []
self.log_messages = []
for row in self.data:
self.process_row(row)
[docs] def resolve_object(self, cls, value):
try:
value = int(value)
return cls.objects.get(pk=value)
except:
name_fields = ["name", "title"]
q = Q()
for f in name_fields:
if hasattr(cls, "_parler_meta"):
f = "translations__%s" % f
q |= Q(**{f: value})
return cls.objects.get(q)
def _resolve_obj(self, row):
obj = self._find_matching_object(row, self.shop)
if not obj:
if self.import_mode == ImportMode.UPDATE:
self.other_log_messages.append(_("Row ignored (no existing item and creating new is not allowed)."))
return (None, True)
self.target_model = self.find_matching_model(row)
obj = self.target_model(**self.field_defaults)
new = True
else:
new = False
if self.import_mode == ImportMode.CREATE:
self.other_log_messages.append(
_("Row ignored (object already exists (%(object_name)s with id: %(object_id)s).") % {
"object_name": str(obj),
"object_id": obj.pk
}
)
return (None, False)
if hasattr(obj, "_parler_meta"):
obj.set_current_language(self.language)
return (obj, new)
def _row_valid(self, mapping, value, obj):
if not mapping.get("writable"):
return False
if obj.pk and value is None: # Don't empty fields
return False
return True
    @atomic
    def process_row(self, row):
        """
        Import a single `row` dict inside a database transaction.

        Skips empty rows and rows rejected by `_resolve_obj`; otherwise
        assigns each mapped value onto the object (in mapping-priority
        order) and finally saves through `save_row`.

        :param row: A row dict (column name -> raw cell value).
        """
        if all((not val) for val in row.values()):  # Empty row, skip it
            return
        obj, new = self._resolve_obj(row)
        if not obj:
            return
        row_session = DataImporterRowSession(self, row, obj, self.shop)
        # Process columns in (priority, name) order so required relations
        # (negative priority, see `create_mapping`) are handled first.
        for fname, mapping in sorted(six.iteritems(self.data_map), key=lambda x: (x[1].get("priority"), x[0])):
            field = mapping.get("field")
            if not field:
                continue
            if field.name in self._meta.fields_to_skip:
                continue
            value = orig_value = row.get(fname)
            if not self._row_valid(mapping, value, obj):
                continue
            value = self._handle_special_row_values(mapping, value)
            setter = mapping.get("setter")
            if setter:
                # A custom setter takes full responsibility for assigning the
                # (possibly relation-resolved) value.
                value, has_related = self._handle_related_value(field, mapping, orig_value, row_session, obj, value)
                setter(row_session, value, mapping)
                continue
            value, has_related = self._handle_related_value(field, mapping, orig_value, row_session, obj, value)
            if has_related:
                # Relation was fully handled (or could not be resolved).
                continue
            if field and not field.blank and value in (None, ""):
                continue  # Skip fields that require a value but don't have one in the original data.
            self._handle_row_field(field, mapping, orig_value, row_session, obj, value)
        self.save_row(new, row_session)
    def _handle_related_value(self, field, mapping, orig_value, row_session, obj, value):
        """
        Resolve FK / M2M / enum-label values for `field`.

        :return: (value, has_related) tuple; ``has_related`` is True when the
                 value has been fully handled here (or cannot be assigned) and
                 the caller should not set it as a plain field.
        """
        has_related = False
        if mapping.get("fk"):
            value = self._handle_row_fk_value(field, orig_value, row_session, value)
            # A required relation that could not be resolved must not be
            # assigned; flag it as handled so the caller skips the field.
            if not field.null and value is None:
                has_related = True
        elif mapping.get("m2m"):
            # M2M assignment is deferred until after save (via the row session).
            self._handle_row_m2m_value(field, orig_value, row_session, obj, value)
            has_related = True
        elif mapping.get("is_enum_field"):
            # Map a human-readable choice label back to its enum value.
            for k, v in field.get_choices():
                if fold_mapping_name(force_text(v)) == fold_mapping_name(orig_value):
                    value = k
                    break
        return (value, has_related)
    def _handle_special_row_values(self, mapping, value):
        """
        Normalize spreadsheet-originated values before assignment.

        Decodes Excel float dates into ``datetime`` objects and collapses
        whole-number floats (e.g. ``5.0``) into plain ints.
        """
        if mapping.get("datatype") in ["datetime", "date"]:
            if isinstance(value, float):  # Sort of terrible
                # Excel stores dates as floats; decode using the workbook's
                # datemode so 1900/1904 epochs are handled correctly.
                value = datetime.datetime(*xlrd.xldate_as_tuple(value, self.data.meta["xls_datemode"]))
        if isinstance(value, float):
            if int(value) == value:
                value = int(value)
        return value
def _handle_row_field(self, field, mapping, orig_value, row_session, target, value):
value = self._get_field_choices_value(field, value)
# Ensure the datetime, date, and boolean values are presented properly
if isinstance(field, DateField) or isinstance(field, DateTimeField):
try:
value = dateutil.parser.parse(value)
except ValueError:
# todo: Handle these somehow
value = None
elif isinstance(field, BooleanField):
if not value or value == "" or value == " ":
value = False
if mapping.get("fk") and value.pk:
setattr(target, field.name, value)
else:
try:
value = field.to_python(value)
except Exception as exc:
row_session.log(
_("Error setting value for field %(field_name)s. (%(exception)s)") % {
"field_name": (field.verbose_name or field.name),
"exception": exc
}
)
else:
value = self._meta.mutate_normal_field_set(row_session, field, value, original=orig_value)
setattr(target, field.name, value)
def _get_field_choices_value(self, field, value):
if field.choices:
for (ck, cv) in field.choices:
if value in (ck, cv):
value = ck
break
return value
def _handle_row_m2m_value(self, field, orig_value, row_session, target, value):
value = self.process_related_value(row_session, field, value, multi=True)
if orig_value and not value:
row_session.log(
_("Couldn't set value %(original_value)s for field %(field_name)s.") % {
"original_value": orig_value,
"field_name": (field.verbose_name or field.name)
}
)
row_session.defer("m2m_%s" % field.name, target, {field.name: value})
def _handle_row_fk_value(self, field, orig_value, row_session, value):
value = self.process_related_value(row_session, field, value, multi=False)
if orig_value and not value:
row_session.log(
_("Couldn't set value %(original_value)s for field %(field_name)s.") % {
"original_value": orig_value,
"field_name": (field.verbose_name or field.name)
}
)
return value
    def save_row(self, new, row_session):
        """
        Save the row session's object and run post-save hooks.

        Appends the saved instance to ``new_objects``/``updated_objects``
        and collects any per-row log messages; an `ImporterError` aborts
        only this row and is reported via ``other_log_messages``.
        """
        self._meta.presave_hook(row_session)
        try:
            row_session.save()
            self._meta.postsave_hook(row_session)
            (self.new_objects if new else self.updated_objects).append(row_session.instance)
            # Run any meta-declared per-field post-save handlers.
            for post_save_handler, fields in six.iteritems(self._meta.post_save_handlers):
                if hasattr(self._meta, post_save_handler):
                    func = getattr(self._meta, post_save_handler)
                    func(fields, row_session)
            if row_session.log_messages:
                self.log_messages.append({
                    "instance": row_session.instance,
                    "messages": row_session.log_messages
                })
        except ImporterError as e:
            # NOTE(review): relies on ImporterError exposing a `message`
            # attribute -- confirm against shuup.importer.exceptions.
            self.other_log_messages.append(e.message)
[docs] def get_fields_for_mapping(self, only_non_mapped=True):
"""
Get fields for manual mapping
:return: List of fields `module_name.Model:field` or empty list
:rtype: list
"""
fields = []
mapped_keys = [k for k in self.data_map]
for model in self.get_related_models():
for field in model._meta.local_fields:
if only_non_mapped and field.name in mapped_keys:
continue
model_field = "%s:%s" % (model.__name__, field.name)
fields.append((model_field, field.verbose_name))
if hasattr(model, "_parler_meta"):
for field in model._parler_meta.root_model._meta.get_fields():
if only_non_mapped and field.name in mapped_keys:
continue
model_field = "%s:%s" % (model.__name__, field.name)
fields.append((model_field, field.verbose_name))
return fields
def _get_map_base(self, field, mode):
is_translation = (mode == 2)
is_m2m = (mode == 1)
is_fk = isinstance(field, ForeignKey)
is_enum_field = isinstance(field, EnumIntegerField)
return {
"name": field.verbose_name or field.name,
"id": field.name,
"field": field,
"keyable": field.unique,
"writable": field.editable and not isinstance(field, AutoField),
"pk": bool(field.primary_key),
"translated": is_translation,
"priority": 0,
"m2m": is_m2m,
"fk": is_fk,
"is_enum_field": is_enum_field,
}
def _find_matching_object(self, row, shop):
"""
Find object that matches the given row and shop
:return: Found object or ``None``
"""
field_map_values = [(fname, mapping, row.get(fname)) for (fname, mapping) in six.iteritems(self.unique_fields)]
row_keys = dict((mapping["field"].name, value) for (fname, mapping, value) in field_map_values if value)
if row_keys:
qs = [Q(**{fname: value}) for (fname, value) in six.iteritems(row_keys)]
if "shop" in [field.name for field in self.model._meta.local_fields]:
qs &= Q(shop=shop)
and_query = six.moves.reduce(iand, [Q()] + qs)
or_query = six.moves.reduce(ior, [Q()] + qs)
try:
return self.model.objects.get(and_query)
except: # Found multiple or zero -- not okay
pass
return self.model.objects.filter(or_query).first()
return None
def _get_fields_with_modes(self, model):
return itertools.chain(
zip(model._meta.local_fields, itertools.repeat(0)),
zip(model._meta.local_many_to_many, itertools.repeat(1)),
zip((f for f in model._parler_meta.root_model._meta.get_fields()
if f.name not in ("id", "master", "language_code")), itertools.repeat(2))
if hasattr(model, "_parler_meta") else ()
)
    def get_row_model(self, row):
        """
        Get the model that matches the row.

        Can be used in cases where you have multiple types of data in the
        same import; the base implementation always returns ``self.model``.

        :param row: A row dict
        """
        return self.model
    @property
    def is_multi_model(self):
        # True when this importer deals with more than one related model.
        return (len(self.get_related_models()) > 1)
[docs] def find_matching_model(self, row):
if not self.is_multi_model:
return self.model
return self.get_row_model(row)