def flatten_result(source):
    """
    Expand a parsed source sequence into every reg-exp possibility.

    Returns a pair ``(strings, args)`` of equal length: ``strings[i]`` is one
    candidate string and ``args[i]`` is the list of argument names that
    candidate uses.
    """
    # Trivial cases: "nothing" and a single group.
    if source is None:
        return [''], [[]]
    if isinstance(source, Group):
        group_params = [] if source[1] is None else [source[1]]
        return [source[0]], [group_params]

    result = ['']
    result_args = [[]]
    pos = last = 0
    for pos, elt in enumerate(source):
        # Plain strings are gathered lazily, the next time a non-string
        # element (or the end of the sequence) is reached.
        if isinstance(elt, six.string_types):
            continue

        piece = ''.join(source[last:pos])
        param = None
        if isinstance(elt, Group):
            piece += elt[0]
            param = elt[1]
        last = pos + 1

        # Append the pending literal text (and the group's parameter, if
        # any) to every candidate built so far.
        result = [prefix + piece for prefix in result]
        if param:
            for collected in result_args:
                collected.append(param)

        if isinstance(elt, (Choice, NonCapture)):
            # A NonCapture is a single alternative; a Choice is iterated to
            # obtain its alternatives.
            alternatives = [elt] if isinstance(elt, NonCapture) else elt
            inner_result, inner_args = [], []
            for alternative in alternatives:
                res, args = flatten_result(alternative)
                inner_result.extend(res)
                inner_args.extend(args)

            # Cross every existing candidate with every inner alternative.
            combined = [
                (outer + inner, outer_args[:] + extra_args)
                for outer, outer_args in zip(result, result_args)
                for inner, extra_args in zip(inner_result, inner_args)
            ]
            result = [text for text, _ in combined]
            result_args = [arg_list for _, arg_list in combined]

    # Flush any trailing literal text.
    if pos >= last:
        remainder = ''.join(source[last:])
        result = [prefix + remainder for prefix in result]
    return result, result_args
def as_sql(self):
    """
    Build the INSERT statement(s) for ``self.query``.

    Returns a list of ``(sql, params)`` pairs: a single pair when the rows
    can be bulk-inserted or when the new id is returned by the INSERT
    itself, otherwise one pair per object.
    """
    connection = self.connection
    ops = connection.ops
    # quote_name_unless_alias() is not needed here: everything quoted below
    # is a plain table or column name, so the cheaper call is enough.
    quote = ops.quote_name
    meta = self.query.model._meta

    sql = ['INSERT INTO %s' % quote(meta.db_table)]

    has_fields = bool(self.query.fields)
    fields = self.query.fields if has_fields else [meta.pk]
    sql.append('(%s)' % ', '.join([quote(f.column) for f in fields]))

    if has_fields:
        values = []
        for obj in self.query.objs:
            row = []
            for f in fields:
                if self.query.raw:
                    raw_value = getattr(obj, f.attname)
                else:
                    raw_value = f.pre_save(obj, True)
                row.append(f.get_db_prep_save(raw_value, connection=connection))
            values.append(row)
        params = values
    else:
        # No explicit fields: insert the backend's default pk marker for
        # each object, with no bound parameters.
        values = [[ops.pk_default_value()] for obj in self.query.objs]
        params = [[]]
        fields = [None]

    has_custom_placeholder = any(hasattr(f, "get_placeholder") for f in fields)
    can_bulk = (not has_custom_placeholder and not self.return_id and
                connection.features.has_bulk_insert)

    if can_bulk:
        placeholders = [["%s"] * len(fields)]
    else:
        placeholders = [
            [self.placeholder(f, v) for f, v in zip(fields, row)]
            for row in values
        ]

    if self.return_id and connection.features.can_return_id_from_insert:
        # Only the first row is used on this path.
        params = params[0]
        pk_column = "%s.%s" % (quote(meta.db_table), quote(meta.pk.column))
        sql.append("VALUES (%s)" % ", ".join(placeholders[0]))
        returning_fmt, returning_params = ops.return_insert_id()
        sql.append(returning_fmt % pk_column)
        params += returning_params
        return [(" ".join(sql), tuple(params))]

    if not can_bulk:
        return [
            (" ".join(sql + ["VALUES (%s)" % ", ".join(row_placeholders)]), row_params)
            for row_placeholders, row_params in zip(placeholders, params)
        ]

    sql.append(ops.bulk_insert_sql(fields, len(values)))
    flat_params = tuple([v for row in values for v in row])
    return [(" ".join(sql), flat_params)]
def resolve_columns(self, row, fields=()):
    """
    Convert a raw result row into a tuple of resolved Python values.

    This routine is necessary so that distances and geometries returned
    from extra selection SQL get resolved appropriately into Python objects.

    ``row`` is the raw row sequence and ``fields`` the fields matching the
    columns that follow the extra-select columns. Extra-select values are
    always passed through ``self.query.convert_values``. The remaining
    columns are converted only on Oracle or when the query has a truthy
    ``geo_values`` attribute; otherwise they are copied through unchanged.
    When ``row`` and ``fields`` differ in length the shorter side is padded
    with ``None``.
    """
    # ``map(None, a, b)`` was the Python 2 spelling of "zip, padding the
    # shorter side with None"; it raises TypeError on Python 3.
    # zip_longest gives the same pairing on both versions.
    try:
        from itertools import zip_longest
    except ImportError:  # Python 2
        from itertools import izip_longest as zip_longest

    aliases = list(self.query.extra_select)

    # Have to set a starting row number offset that is used for
    # determining the correct starting row index -- needed for
    # doing pagination with Oracle.
    rn_offset = 0
    if self.connection.ops.oracle:
        if self.query.high_mark is not None or self.query.low_mark:
            rn_offset = 1
    index_start = rn_offset + len(aliases)

    # Converting any extra selection values (e.g., geometries and
    # distance objects added by GeoQuerySet methods).
    values = [
        self.query.convert_values(
            v, self.query.extra_select_fields.get(a, None), self.connection)
        for v, a in zip(row[rn_offset:index_start], aliases)
    ]

    if self.connection.ops.oracle or getattr(self.query, 'geo_values', False):
        # We resolve the rest of the columns if we're on Oracle or if
        # the `geo_values` attribute is defined.
        for value, field in zip_longest(row[index_start:], fields):
            values.append(self.query.convert_values(value, field, self.connection))
    else:
        values.extend(row[index_start:])
    return tuple(values)
def results_iter(self):
    """
    Returns an iterator over the results from executing this query.

    Each yielded row has been passed through ``self.resolve_columns`` when
    the compiler defines it, and has had its aggregate columns passed
    through ``self.query.resolve_aggregate`` when the query selects
    aggregates.
    """
    can_resolve = hasattr(self, 'resolve_columns')
    has_aggregates = bool(self.query.aggregate_select)
    fields = None

    # Set transaction dirty if we're using SELECT FOR UPDATE to ensure
    # a subsequent commit/rollback is executed, so any database locks
    # are released.
    if self.query.select_for_update and transaction.is_managed(self.using):
        transaction.set_dirty(self.using)

    for chunk in self.execute_sql(MULTI):
        for row in chunk:
            if can_resolve and fields is None:
                # Worked out lazily, on the first row, because
                # related_select_fields isn't populated until
                # execute_sql() has been called.
                if self.query.select_fields:
                    fields = self.query.select_fields + self.query.related_select_fields
                else:
                    fields = self.query.model._meta.fields
                # Deferred fields were not selected, so they must not be
                # handed to `resolve_columns`.
                only_load = self.deferred_to_columns()
                if only_load:
                    db_table = self.query.model._meta.db_table
                    if db_table in only_load:
                        loaded_columns = only_load[db_table]
                        fields = [f for f in fields if f.column in loaded_columns]
                    else:
                        fields = []
            if can_resolve:
                row = self.resolve_columns(row, fields)

            if has_aggregates:
                # Aggregate columns sit after the extra-select and
                # regular select columns.
                aggregate_start = len(self.query.extra_select) + len(self.query.select)
                aggregate_end = aggregate_start + len(self.query.aggregate_select)
                leading = tuple(row[:aggregate_start])
                resolved = tuple([
                    self.query.resolve_aggregate(value, aggregate, self.connection)
                    for (alias, aggregate), value in zip(
                        self.query.aggregate_select.items(),
                        row[aggregate_start:aggregate_end])
                ])
                trailing = tuple(row[aggregate_end:])
                row = leading + resolved + trailing

            yield row
def normalize(pattern):
    """
    Given a reg-exp pattern, normalizes it to an iterable of forms that
    suffice for reverse matching. This does the following:

    (1) For any repeating sections, keeps the minimum number of occurrences
        permitted (this means zero for optional groups).
    (2) If an optional group includes parameters, include one occurrence of
        that group (along with the zero occurrence case from step (1)).
    (3) Select the first (essentially an arbitrary) element from any character
        class. Select an arbitrary character for any unordered class (e.g. '.'
        or '\\w') in the pattern.
    (4) Ignore comments and any of the reg-exp flags that won't change
        what we construct ("iLmsu"). "(?x)" is an error, however.
    (5) Raise an error on all other non-capturing (?...) forms (e.g.
        look-ahead and look-behind matches) and any disjunctive ('|')
        constructs.

    Django's URLs for forward resolving are either all positional arguments or
    all keyword arguments. That is assumed here, as well. Although reverse
    resolving can be done using positional args when keyword args are
    specified, the two cannot be mixed in the same reverse() call.

    Returns a list of ``(string, args)`` pairs built from the output of
    ``flatten_result``. An empty pattern, or one that uses '|', yields
    ``[('', [])]``.

    Raises ``ValueError`` for a ``(?...)`` construct that is not one of the
    ignorable flags/comments, a ``(?:...)`` group, or a ``(?P<name>...)`` /
    ``(?P=name)`` form.
    """
    # Do a linear scan to work out the special features of this pattern. The
    # idea is that we scan once here and collect all the information we need to
    # make future decisions.
    result = []
    # Stack of indexes into `result` at which each currently open
    # non-capturing group started.
    non_capturing_groups = []
    consume_next = True
    # Yields (character, escaped) pairs, as the unpacking below shows.
    pattern_iter = next_char(iter(pattern))
    # Counter used to name positional groups "_0", "_1", ...
    num_args = 0

    # A "while" loop is used here because later on we need to be able to peek
    # at the next character and possibly go around without consuming another
    # one at the top of the loop.
    try:
        ch, escaped = next(pattern_iter)
    except StopIteration:
        return [('', [])]

    try:
        while True:
            if escaped:
                result.append(ch)
            elif ch == '.':
                # Replace "any character" with an arbitrary representative.
                result.append(".")
            elif ch == '|':
                # FIXME: One day we should do this, but not in 1.0.
                raise NotImplementedError
            elif ch == "^":
                pass
            elif ch == '$':
                break
            elif ch == ')':
                # This can only be the end of a non-capturing group, since all
                # other unescaped parentheses are handled by the grouping
                # section later (and the full group is handled there).
                #
                # We regroup everything inside the non-capturing group so that
                # it can be quantified, if necessary.
                start = non_capturing_groups.pop()
                inner = NonCapture(result[start:])
                result = result[:start] + [inner]
            elif ch == '[':
                # Replace ranges with the first character in the range.
                ch, escaped = next(pattern_iter)
                result.append(ch)
                ch, escaped = next(pattern_iter)
                # Skip the remainder of the class, up to the unescaped ']'.
                while escaped or ch != ']':
                    ch, escaped = next(pattern_iter)
            elif ch == '(':
                # Some kind of group.
                ch, escaped = next(pattern_iter)
                if ch != '?' or escaped:
                    # A positional group
                    name = "_%d" % num_args
                    num_args += 1
                    result.append(Group((("%%(%s)s" % name), name)))
                    # Presumably skips to the group's closing parenthesis --
                    # walk_to_end is defined elsewhere; confirm there.
                    walk_to_end(ch, pattern_iter)
                else:
                    ch, escaped = next(pattern_iter)
                    if ch in "iLmsu#":
                        # All of these are ignorable. Walk to the end of the
                        # group.
                        walk_to_end(ch, pattern_iter)
                    elif ch == ':':
                        # Non-capturing group
                        non_capturing_groups.append(len(result))
                    elif ch != 'P':
                        # Anything else, other than a named group, is something
                        # we cannot reverse.
                        raise ValueError("Non-reversible reg-exp portion: '(?%s'" % ch)
                    else:
                        ch, escaped = next(pattern_iter)
                        if ch not in ('<', '='):
                            raise ValueError("Non-reversible reg-exp portion: '(?P%s'" % ch)
                        # We are in a named capturing group. Extract the name
                        # and then skip to the end.
                        if ch == '<':
                            terminal_char = '>'
                        # We are in a named backreference.
                        else:
                            terminal_char = ')'
                        name = []
                        ch, escaped = next(pattern_iter)
                        while ch != terminal_char:
                            name.append(ch)
                            ch, escaped = next(pattern_iter)
                        param = ''.join(name)
                        # Named backreferences have already consumed the
                        # parenthesis.
                        if terminal_char != ')':
                            result.append(Group((("%%(%s)s" % param), param)))
                            walk_to_end(ch, pattern_iter)
                        else:
                            # A backreference reuses the name in the output
                            # string but contributes no new parameter.
                            result.append(Group((("%%(%s)s" % param), None)))
            elif ch in "*?+{":
                # Quantifiers affect the previous item in the result list.
                count, ch = get_quantifier(ch, pattern_iter)
                if ch:
                    # We had to look ahead, but it wasn't needed to compute
                    # the quantifier, so use this character next time around
                    # the main loop.
                    consume_next = False

                if count == 0:
                    if contains(result[-1], Group):
                        # If we are quantifying a capturing group (or
                        # something containing such a group) and the minimum is
                        # zero, we must also handle the case of one occurrence
                        # being present. All the quantifiers (except {0,0},
                        # which we conveniently ignore) that have a 0 minimum
                        # also allow a single occurrence.
                        result[-1] = Choice([None, result[-1]])
                    else:
                        result.pop()
                elif count > 1:
                    result.extend([result[-1]] * (count - 1))
            else:
                # Anything else is a literal.
                result.append(ch)

            if consume_next:
                ch, escaped = next(pattern_iter)
            else:
                consume_next = True
    except StopIteration:
        # Ran off the end of the pattern: the scan is complete.
        pass
    except NotImplementedError:
        # A case of using the disjunctive form. No results for you!
        return [('', [])]

    return list(zip(*flatten_result(result)))
def _ogrinspect(data_source, model_name, geom_name='geom', layer_key=0, srid=None,
                multi_geom=False, name_field=None, imports=True,
                decimal=False, blank=False, null=False):
    """
    Generator behind `ogrinspect`: yields, line by line, the source of a
    GeoDjango model corresponding to one layer of the given data source.
    See the `ogrinspect` docstring for more details.

    Raises TypeError if `data_source` is neither a string nor a DataSource,
    or if a field in the layer has a type that is not handled here.
    """
    # NOTE(review): the generated lines below start with a single space,
    # exactly as the literals appear in the source under review -- confirm
    # that this is the intended indentation of the generated model body.

    # Accept either a path string or an already opened DataSource.
    if isinstance(data_source, str):
        data_source = DataSource(data_source)
    elif not isinstance(data_source, DataSource):
        raise TypeError('Data source parameter must be a string or a DataSource object.')

    # The requested layer, and the names of all OGR fields it holds.
    layer = data_source[layer_key]
    ogr_fields = layer.fields

    def lowered_names(option):
        # A list/tuple selects those fields; any other truthy value selects
        # every field in the layer; a falsy value selects none.
        if isinstance(option, (list, tuple)):
            selected = option
        elif option:
            selected = ogr_fields
        else:
            selected = []
        return [s.lower() for s in selected]

    null_fields = lowered_names(null)
    blank_fields = lowered_names(blank)
    decimal_fields = lowered_names(decimal)

    def option_suffix(field_name):
        # Builds the ", null=True, blank=True" style suffix for a field
        # (empty string when neither option applies).
        lowered = field_name.lower()
        options = []
        if lowered in null_fields:
            options.append('null=True')
        if lowered in blank_fields:
            options.append('blank=True')
        return (', ' + ', '.join(options)) if options else ''

    # For those wishing to disable the imports.
    if imports:
        yield '# This is an auto-generated Django model module created by ogrinspect.'
        yield 'from djangocg.contrib.gis.db import models'
        yield ''

    yield 'class %s(models.Model):' % model_name

    # Field types whose generated line takes only the optional keywords.
    keyword_only_templates = (
        (OFTInteger, ' %s = models.IntegerField(%s)'),
        (OFTDate, ' %s = models.DateField(%s)'),
        (OFTDateTime, ' %s = models.DateTimeField(%s)'),
        (OFTTime, ' %s = models.TimeField(%s)'),
    )

    columns = zip(ogr_fields, layer.field_widths, layer.field_precisions, layer.field_types)
    for field_name, width, precision, field_type in columns:
        # The model field name; a trailing underscore gets 'field' appended.
        lowered = field_name.lower()
        mfield = lowered + 'field' if lowered.endswith('_') else lowered

        # The keyword args string; [2:] strips its leading ', ' when it is
        # the only content between the parentheses.
        kwargs_str = option_suffix(field_name)

        if field_type is OFTReal:
            # By default OFTReals are mapped to `FloatField`, however, they
            # may also be mapped to `DecimalField` if specified in the
            # `decimal` keyword.
            if lowered in decimal_fields:
                yield ' %s = models.DecimalField(max_digits=%d, decimal_places=%d%s)' % (mfield, width, precision, kwargs_str)
            else:
                yield ' %s = models.FloatField(%s)' % (mfield, kwargs_str[2:])
        elif field_type is OFTString:
            yield ' %s = models.CharField(max_length=%s%s)' % (mfield, width, kwargs_str)
        else:
            for ogr_type, template in keyword_only_templates:
                if field_type is ogr_type:
                    yield template % (mfield, kwargs_str[2:])
                    break
            else:
                raise TypeError('Unknown field type %s in %s' % (field_type, mfield))

    # TODO: Autodetection of multigeometry types (see #7218).
    gtype = layer.geom_type
    if multi_geom and gtype.num in (1, 2, 3):
        geom_field = 'Multi%s' % gtype.django
    else:
        geom_field = gtype.django

    # Setting up the SRID keyword string. An explicitly passed SRID is
    # always written out; one detected from the layer is omitted when it is
    # 4326, and -1 is used when none can be determined.
    if srid is not None:
        srid_str = 'srid=%s' % srid
    else:
        detected = None if layer.srs is None else layer.srs.srid
        if detected is None:
            srid_str = 'srid=-1'
        elif detected == 4326:
            # WGS84 is already the default.
            srid_str = ''
        else:
            srid_str = 'srid=%s' % detected

    yield ' %s = models.%s(%s)' % (geom_name, geom_field, srid_str)
    yield ' objects = models.GeoManager()'

    if name_field:
        yield ''
        yield ' def __str__(self): return self.%s' % name_field
def __init__(self, *args, **kwargs):
    """
    Populate the instance's fields from positional and/or keyword arguments.

    Positional arguments are assigned to ``self._meta.fields`` in order.
    Every remaining field takes its value from ``kwargs`` (by field name
    for a related object, otherwise by attname) or falls back to
    ``field.get_default()``. Keyword arguments that name a ``property`` on
    the class are assigned through that property.

    Sends the ``pre_init`` signal before doing anything and the
    ``post_init`` signal once initialisation is complete.

    Raises ``IndexError`` if more positional arguments than fields are
    given, and ``TypeError`` if a keyword argument matches neither a field
    nor a property.
    """
    signals.pre_init.send(sender=self.__class__, args=args, kwargs=kwargs)

    # Set up the storage for instance state
    self._state = ModelState()

    # There is a rather weird disparity here; if kwargs, it's set, then args
    # overrides it. It should be one or the other; don't duplicate the work
    # The reason for the kwargs check is that standard iterator passes in by
    # args, and instantiation for iteration is 33% faster.
    args_len = len(args)
    if args_len > len(self._meta.fields):
        # Daft, but matches old exception sans the err msg.
        raise IndexError("Number of args exceeds number of fields")

    # A single shared iterator: the positional loop below consumes one
    # field per argument, and the keyword/default loop further down picks
    # up from wherever it stopped.
    fields_iter = iter(self._meta.fields)
    if not kwargs:
        # The ordering of the zip calls matter - zip throws StopIteration
        # when an iter throws it. So if the first iter throws it, the second
        # is *not* consumed. We rely on this, so don't change the order
        # without changing the logic.
        for val, field in zip(args, fields_iter):
            setattr(self, field.attname, val)
    else:
        # Slower, kwargs-ready version.
        for val, field in zip(args, fields_iter):
            setattr(self, field.attname, val)
            # The positional value wins: drop any keyword for the same field.
            kwargs.pop(field.name, None)
            # Maintain compatibility with existing calls.
            if isinstance(field.rel, ManyToOneRel):
                kwargs.pop(field.attname, None)

    # Now we're left with the unprocessed fields that *must* come from
    # keywords, or default.

    for field in fields_iter:
        is_related_object = False
        # This slightly odd construct is so that we can access any
        # data-descriptor object (DeferredAttribute) without triggering its
        # __get__ method.
        if (field.attname not in kwargs and
                isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)):
            # This field will be populated on request.
            continue
        if kwargs:
            if isinstance(field.rel, ManyToOneRel):
                try:
                    # Assume object instance was passed in.
                    rel_obj = kwargs.pop(field.name)
                    is_related_object = True
                except KeyError:
                    try:
                        # Object instance wasn't passed in -- must be an ID.
                        val = kwargs.pop(field.attname)
                    except KeyError:
                        val = field.get_default()
                else:
                    # Object instance was passed in. Special case: You can
                    # pass in "None" for related objects if it's allowed.
                    if rel_obj is None and field.null:
                        val = None
            else:
                try:
                    val = kwargs.pop(field.attname)
                except KeyError:
                    # This is done with an exception rather than the
                    # default argument on pop because we don't want
                    # get_default() to be evaluated, and then not used.
                    # Refs #12057.
                    val = field.get_default()
        else:
            val = field.get_default()
        if is_related_object:
            # If we are passed a related instance, set it using the
            # field.name instead of field.attname (e.g. "user" instead of
            # "user_id") so that the object gets properly cached (and type
            # checked) by the RelatedObjectDescriptor.
            setattr(self, field.name, rel_obj)
        else:
            setattr(self, field.attname, val)

    if kwargs:
        # Iterate over a copy of the keys because matching entries are
        # popped from kwargs inside the loop.
        for prop in list(kwargs):
            try:
                if isinstance(getattr(self.__class__, prop), property):
                    setattr(self, prop, kwargs.pop(prop))
            except AttributeError:
                # Not an attribute of the class at all; left in kwargs so
                # that it is reported just below.
                pass
        if kwargs:
            raise TypeError("'%s' is an invalid keyword argument for this function" % list(kwargs)[0])
    super(Model, self).__init__()
    signals.post_init.send(sender=self.__class__, instance=self)
def get_columns(self, with_aliases=False):
    """
    Return the list of columns to use in the select statement. If no
    columns have been specified, returns all columns relating to fields in
    the model.

    If 'with_aliases' is true, any column names that are duplicated
    (without the table names) are given unique aliases. This is needed in
    some cases to avoid ambiguity with nested queries.

    As a side effect, the set of select aliases that was collected is
    stored on ``self._select_aliases``.

    This routine is overridden from Query to handle customized selection of
    geometry columns.
    """
    quote_unless_alias = self.quote_name_unless_alias
    quote = self.connection.ops.quote_name

    # Extra-select entries come first, each wrapped in its select format.
    result = []
    for alias, col in six.iteritems(self.query.extra_select):
        formatted = self.get_extra_select_format(alias) % col[0]
        result.append('(%s) AS %s' % (formatted, quote(alias)))
    aliases = set(self.query.extra_select.keys())
    col_aliases = aliases.copy() if with_aliases else set()

    if self.query.select:
        only_load = self.deferred_to_columns()
        # This loop customized for GeoQuery.
        for col, field in zip(self.query.select, self.query.select_fields):
            if not isinstance(col, (list, tuple)):
                # Not an (alias, column) pair: the object renders itself.
                result.append(col.as_sql(quote_unless_alias, self.connection))
                if hasattr(col, 'alias'):
                    aliases.add(col.alias)
                    col_aliases.add(col.alias)
                continue

            alias, column = col
            table = self.query.alias_map[alias].table_name
            if table in only_load and column not in only_load[table]:
                # Deferred column: not selected.
                continue

            r = self.get_field_select(field, alias, column)
            if not with_aliases:
                result.append(r)
                aliases.add(r)
                col_aliases.add(column)
            elif column in col_aliases:
                # Name clash: generate a unique alias.
                c_alias = 'Col%d' % len(col_aliases)
                result.append('%s AS %s' % (r, c_alias))
                aliases.add(c_alias)
                col_aliases.add(c_alias)
            else:
                result.append('%s AS %s' % (r, quote(column)))
                aliases.add(r)
                col_aliases.add(column)
    elif self.query.default_cols:
        cols, new_aliases = self.get_default_columns(with_aliases, col_aliases)
        result.extend(cols)
        aliases.update(new_aliases)

    # Aggregates, aliased (with a truncated name) when they have an alias.
    max_name_length = self.connection.ops.max_name_length()
    aggregate_columns = []
    for alias, aggregate in self.query.aggregate_select.items():
        aggregate_sql = (self.get_extra_select_format(alias) %
                         aggregate.as_sql(quote_unless_alias, self.connection))
        if alias is not None:
            alias_sql = ' AS %s' % quote_unless_alias(truncate_name(alias, max_name_length))
        else:
            alias_sql = ''
        aggregate_columns.append('%s%s' % (aggregate_sql, alias_sql))
    result.extend(aggregate_columns)

    # This loop customized for GeoQuery.
    related = zip(self.query.related_select_cols, self.query.related_select_fields)
    for (table, col), field in related:
        r = self.get_field_select(field, table, col)
        if with_aliases and col in col_aliases:
            c_alias = 'Col%d' % len(col_aliases)
            result.append('%s AS %s' % (r, c_alias))
            aliases.add(c_alias)
            col_aliases.add(c_alias)
        else:
            result.append(r)
            aliases.add(r)
            col_aliases.add(col)

    self._select_aliases = aliases
    return result