def finalize(self, key, interval):
    """Finalize the selection stored for coordinate ``key``.

    The selection in ``self[key]`` (labels, integer positions, slices,
    ranges, or nested iterables of those, as flattened by ``self.expand``)
    is resolved against ``interval``, and ``self[key]`` is replaced by:

    * ``slice(None)`` when every element of ``interval`` is selected;
    * a tuple of labels, ordered as in ``interval``, when all selected
      values are strings;
    * otherwise the single item from ``compact_indexes`` if it yields exactly
      one item, or else the plain set of selected values.

    Nothing is done when ``self[key]`` is ``None`` or ``slice(None)``.

    Raises
    ------
    KeyError
        If ``key`` is not in the coordinates.
    ValueError
        If a string value is not an element of ``interval``.
    """
    if key not in self:
        raise KeyError("Unknown key %s" % key)
    selection = self[key]
    if selection is None or selection == slice(None):
        return

    domain = tuple(interval)
    chosen = set()
    for item in self.expand(selection):
        if isinstance(item, str):
            if item not in domain:
                raise ValueError("Value %s not in interval" % item)
            chosen.add(item)
        elif isinstance(item, int):
            # Integers are positions inside the interval.
            chosen.add(domain[item])
        else:
            assert isinstance(item, (slice, range)), "Trivial assertion"
            if isinstance(item, range):
                item = slice(item.start, item.stop, item.step)
            chosen.update(domain[item])

    assert chosen <= set(domain), "Trivial assertion"
    if chosen == set(domain):
        chosen = slice(None)
    elif isiterable(chosen, str):
        chosen = tuple(sorted(chosen, key=domain.index))
    else:
        compacted = tuple(compact_indexes(sorted(chosen)))
        if len(compacted) == 1:
            (chosen, ) = compacted
        # NOTE(review): with several compacted chunks the raw set is stored;
        # confirm that the setter (or callers) expect a set here.
    self[key] = chosen
def get_axes(self, *axes):
    """Return the field axes corresponding to the given axes/dimensions.

    ``"all"`` selects every axis of the field. Otherwise the arguments are
    expanded with ``self.lattice.expand`` and only the names present in
    ``self.axes`` are kept. The result is a sorted tuple; no
    de-duplication is performed here.

    Raises
    ------
    TypeError
        If the arguments are not all strings.
    """
    if not isiterable(axes, str):
        raise TypeError("The arguments need to be a list of strings")
    if "all" in axes:
        return tuple(sorted(self.axes))
    expanded = self.lattice.expand(*axes)
    return tuple(sorted(name for name in expanded if name in self.axes))
def expand(cls, *indexes):
    """Flatten arbitrarily nested index specifications.

    ``int``, ``str``, ``slice``, ``range`` and ``None`` items are yielded
    unchanged; any other iterable is flattened recursively.

    Raises
    ------
    TypeError
        For an item that is neither one of the scalar kinds nor iterable.
    """
    atomic = (int, str, slice, range, type(None))
    for index in indexes:
        if isinstance(index, atomic):
            yield index
            continue
        if not isiterable(index):
            raise TypeError("Unexpected type %s" % type(index))
        yield from cls.expand(*index)
def get_indexes(self, *axes):
    """Return the field indexes corresponding to the given axes/indexes/dimensions.

    ``"all"`` selects every index of the field. Arguments that already
    name a field index are kept as they are. The remaining arguments are
    expanded with ``self.lattice.expand``, and every field index whose axis
    (per ``self.index_to_axis``) is among the expanded names is added.
    The result is a sorted tuple without duplicates.

    Raises
    ------
    TypeError
        If the arguments are not all strings.
    """
    if not isiterable(axes, str):
        raise TypeError("The arguments need to be a list of strings")
    if "all" in axes:
        return tuple(sorted(self.indexes))

    selected = {name for name in axes if name in self.indexes}
    remaining = set(axes) - selected
    wanted_axes = tuple(self.lattice.expand(remaining))
    for index in self.indexes:
        if self.index_to_axis(index) in wanted_axes:
            selected.add(index)
    return tuple(sorted(selected))
def get_indexes(self, coords):
    "Returns the indexes of the values of coords"
    # Maps each entry of `coords` against the values stored in `self`:
    # - str entries (and iterables of str) are kept as given;
    # - int entries (and iterables of int) are replaced by their position
    #   via `self[key].index(...)`;
    # - entries already equal to `self[key]` are left as copied from coords.
    # Raises ValueError when a requested value is not available in `self`.
    if self == coords:
        # NOTE(review): this returns a plain `{}` while the general path
        # returns `indexes.cleaned()`; confirm callers treat both the same.
        return {}
    indexes = coords.copy()
    for key, val in coords.items():
        if self[key] == val:
            # Already matching: the entry is left as copied from coords.
            continue
        if val is None:
            # None is only accepted where self holds a single int/str value.
            if isinstance(self[key], (int, str)):
                continue
            raise ValueError("None can only be assigned axis of size one")
        if self[key] == slice(None):
            # Full range stored in self: the requested value is kept as is.
            continue
        if self[key] is None:
            indexes[key] = None
            continue
        if isinstance(self[key], (str, int)):
            # self holds a single value that differs from val.
            raise ValueError(
                "Key %s with value %s is not compatible with %s" %
                (key, val, self[key]))
        if isinstance(val, (str, int)):
            if val not in self[key]:
                raise ValueError("%s not in field coordinates" % (val))
            if isinstance(val, int):
                # Integer values are translated to their position in self.
                indexes[key] = self[key].index(val)
            continue
        if isiterable(self[key], str):
            # Label selections stay as labels; they only need to be a subset.
            if set(val) <= set(self[key]):
                continue
            raise ValueError("%s not in field coordinates" %
                             (set(val).difference(self[key])))
        assert isiterable(self[key], int), "Unexpected value %s" % self[key]
        if set(val) <= set(self[key]):
            # Integer selections are translated to positions in self.
            indexes[key] = tuple(self[key].index(idx) for idx in val)
            continue
        raise ValueError("%s not in field coordinates" %
                         (set(val).difference(self[key])))
    return indexes.cleaned()
def dims(self, value):
    """Set the lattice dimensions.

    Accepted forms of ``value``:

    * falsy: reset to an empty ``LatticeAxes``;
    * ``dict``/``MappingProxyType``: passed to ``self.dims.reset``. Then the
      label ``"dirs"`` (all dimension names) is added. With more than one
      dimension, the groups ``"time"`` (first dimension) and ``"space"``
      (the other dimensions) are also added. Each of these is set only if
      not already defined;
    * ``int`` n: n dimensions of size 1;
    * iterable of ints: the sizes, named after
      ``Lattice.default_dims_labels`` when there are enough default names,
      otherwise ``dim0``, ``dim1``, ...;
    * iterable of strings: dimensions with those names and size 1.

    Raises
    ------
    RuntimeError
        If the lattice is frozen.
    ValueError
        For a negative number of dimensions.
    TypeError
        For any other type.
    """
    if self.frozen:
        raise RuntimeError(
            "The lattice has been frozen and dims cannot be changed")

    if not value:
        self._dims = LatticeAxes(lattice=self)
    elif isinstance(value, (dict, MappingProxyType)):
        self.dims.reset(value)
        # NOTE(review): if `labels`/`groups` are plain dict subclasses,
        # `setdefault` bypasses any overridden `__setitem__` validation.
        directions = list(self.dims)
        self.labels.setdefault("dirs", directions)
        if len(directions) > 1:
            first, *rest = directions
            self.groups.setdefault("time", (first, ))
            self.groups.setdefault("space", tuple(rest))
    elif isinstance(value, int):
        if value < 0:
            raise ValueError("Non-positive number of dims")
        self.dims = [1] * value
    elif isiterable(value, int):
        defaults = Lattice.default_dims_labels
        if len(value) <= len(defaults):
            self.dims = {defaults[pos]: size for pos, size in enumerate(value)}
        else:
            self.dims = {"dim%d" % pos: size for pos, size in enumerate(value)}
    elif isiterable(value, str):
        self.dims = dict.fromkeys(value, 1)
    else:
        raise TypeError("Not allowed type %s for dims" % type(value))
def get_input_axes(cls, *axes, **kwargs):
    """Normalize the different ways of passing axes.

    Axes can be given positionally (``*axes``) or through the ``axes=`` or
    ``axis=`` keyword, but at most one of these forms may be used.

    Returns
    -------
    (axes, kwargs)
        ``axes`` is the given value; a single string is wrapped into a
        1-tuple. ``kwargs`` holds the remaining keyword arguments, with
        ``axes``/``axis`` removed.

    Raises
    ------
    ValueError
        If more than one of the forms is used.
    TypeError
        If the resulting axes are not strings.
    """
    given = (bool(axes), "axes" in kwargs, "axis" in kwargs)
    if given.count(True) > 1:
        raise ValueError(
            "Only one between *axes, axes= or axis= can be used")
    # Both keys are popped so neither leaks into the returned kwargs.
    axes = kwargs.pop("axis", kwargs.pop("axes", axes))
    if isinstance(axes, str):
        axes = (axes, )
    if not isiterable(axes, str):
        # Wrap in a 1-tuple: a bare tuple operand would be spread by `%`,
        # breaking the message (or raising a formatting TypeError).
        raise TypeError("Type for axes not valid. %s" % (axes, ))
    return axes, kwargs
def reorder_label(self, label, label_order=None, **kwargs):
    """Change the order of the values of ``label``.

    Parameters
    ----------
    label: str
        A label of the field, i.e. a name whose range (``self.get_range``)
        consists of strings.
    label_order: optional
        The new order. When None, ``Permutation(rng, label=label)`` is built
        from the label's range.
    kwargs:
        Forwarded to ``self.copy``. A ``labels_order`` mapping, if given, is
        copied and extended with ``label``; the caller's mapping is not
        modified.

    Returns
    -------
    A copy of the field. The copy is unchanged when the label has at most
    one value.

    Raises
    ------
    KeyError
        If ``label`` is not a label of the field.
    """
    rng = self.get_range(label)
    if not isiterable(rng, str):
        raise KeyError("%s is not a label of the field" % label)
    if len(rng) <= 1:
        return self.copy()
    if label_order is None:
        label_order = Permutation(rng, label=label)
    # Copy so that a mapping passed in by the caller is not mutated.
    labels_order = dict(kwargs.pop("labels_order", {}))
    labels_order[label] = label_order
    return self.copy(labels_order=labels_order, **kwargs)
def transpose(self, *axes, **axes_order):
    """
    Transposes the matrix/tensor indexes of the field.

    *NOTE*: this is conceptually different from numpy.transpose
    where all the axes are transposed.

    Parameters
    ----------
    axes: str
        If given, only the listed axes are transposed,
        otherwise all the tensorial axes are changed.
        By default the order of the indexes is inverted.
    axes_order: dict
        Same as axes, but specifying the reordering of the indexes.
        The key must be one of the axis and the value the order using
        an index per repetition of the axis numbering from 0,1,...
        Any iterable is accepted and it is normalized to a tuple.

    Raises
    ------
    KeyError
        If a key of axes_order is not an axis of the field.
    TypeError
        If a value of axes_order is not iterable.
    ValueError
        If a value of axes_order is not a permutation of 0..count-1.
    """
    counts = dict(self.axes_counts)

    # Validate and normalize every requested order to a tuple. The
    # normalized value must be stored: comparing a list with a tuple below
    # would never detect identities, and a generator would reach the
    # backend already exhausted.
    normalized = {}
    for axis, val in axes_order.items():
        if axis not in counts:
            raise KeyError("Axis %s not in field" % (axis))
        if not isiterable(val):
            raise TypeError("Type of value for axis %s not valid" % (axis))
        val = tuple(val)
        if len(val) != counts[axis]:
            raise ValueError(
                "%d indexes have been given for axis %s but it has count %d"
                % (len(val), axis, counts[axis]))
        if set(val) != set(range(counts[axis])):
            raise ValueError(
                "%s has been given for axis %s. Not a permutation of %s." %
                (val, axis, tuple(range(counts[axis]))))
        normalized[axis] = val
    axes_order = normalized

    if not axes and not axes_order:
        axes = ("dofs", )
    # Axes with a single repetition cannot be transposed.
    axes = [
        axis for axis in self.get_axes(*axes)
        if axis not in axes_order and counts[axis] > 1
    ]
    # Identity permutations are no-ops.
    axes_order = {
        axis: val
        for axis, val in axes_order.items()
        if val != tuple(range(counts[axis]))
    }

    if not axes and not axes_order:
        return self.copy()
    return self.copy(
        self.backend.transpose(self.indexes_order, axes=axes, **axes_order))
def dofs(self, value):
    """Set the lattice degrees of freedom (dofs).

    Accepted forms of ``value``:

    * falsy: reset to an empty ``LatticeAxes``;
    * ``dict``/``MappingProxyType``: passed to ``self.dofs.reset``;
    * ``str``: name of an entry of ``Lattice.theories``. Its ``labels`` and
      ``groups`` entries update ``self.labels``/``self.groups``, and the
      remaining items become the dofs;
    * ``int`` n: n dofs of size 1;
    * iterable of ints: sizes of dofs named ``dof0``, ``dof1``, ...;
    * iterable of strings: dofs with those names and size 1.

    Raises
    ------
    RuntimeError
        If the lattice is frozen.
    ValueError
        For an unknown theory name or a negative number of dofs.
    TypeError
        For any other type.
    """
    if self.frozen:
        raise RuntimeError(
            "The lattice has been frozen and dofs cannot be changed")
    if not value:
        self._dofs = LatticeAxes(lattice=self)
        return
    if isinstance(value, (dict, MappingProxyType)):
        self.dofs.reset(value)
        return
    if isinstance(value, str):
        # Explicit check: an `assert` here would be stripped under `python -O`.
        if value not in Lattice.theories:
            raise ValueError("Unknown dofs name %s" % value)
        # Shallow copy so popping "labels"/"groups" leaves the registry intact.
        value = Lattice.theories[value].copy()
        labels = value.pop("labels", {})
        groups = value.pop("groups", {})
        self.dofs = value
        self.labels.update(labels)
        self.groups.update(groups)
        return
    if isinstance(value, int):
        if value < 0:
            raise ValueError("Non-positive number of dofs")
        self.dofs = [1] * value
        return
    if isiterable(value, int):
        self.dofs = {"dof%d" % i: v for i, v in enumerate(value)}
        return
    if isiterable(value, str):
        self.dofs = {v: 1 for v in value}
        return
    raise TypeError("Not allowed type %s for dofs" % type(value))
def expand(self, *dimensions):
    """Expand dimensions into the fundamental dimensions and degrees of freedom.

    Names found in ``self.axes`` are yielded directly. Other known names are
    expanded recursively through ``self[name]``, and iterables are flattened.

    Raises
    ------
    ValueError
        For a string that is not one of ``self.keys()``.
    TypeError
        For anything that is neither a string nor an iterable.
    """
    for dim in dimensions:
        if not isinstance(dim, str):
            if not isiterable(dim):
                raise TypeError("Unexpected type %s with value %s" %
                                (type(dim), dim))
            yield from self.expand(*dim)
            continue
        if dim not in self.keys():
            raise ValueError("Unknown dimension: %s" % dim)
        if dim in self.axes:
            yield dim
        else:
            yield from self.expand(self[dim])
def __setitem__(self, key, val):
    """Define a group, or resize every member of an existing group.

    * ``groups[key] = n`` with an int, for an existing group, sets
      ``self.lattice[k] = n`` for each member ``k`` of the group.
    * Otherwise ``val`` must be a string or an iterable of strings. It is
      always stored as a tuple. When a lattice is attached, every name must
      be one of its keys.

    Raises
    ------
    TypeError
        If ``val`` is not a string or an iterable of strings.
    ValueError
        If some names are not keys of the attached lattice.
    """
    if key in self and isinstance(val, int):
        for member in self[key]:
            self.lattice[member] = val
        return
    if isinstance(val, str):
        val = (val, )
    if not isiterable(val, str):
        raise TypeError("Groups value can only be a list of strings")
    # Always store a tuple, so the stored type does not depend on whether a
    # lattice is attached.
    val = tuple(val)
    if self.lattice is not None:
        keys = set(self.lattice.keys())
        if not keys >= set(val):
            raise ValueError("%s are not lattice keys" %
                             set(val).difference(keys))
    super().__setitem__(key, val)
def __setitem__(self, key, val):
    """Assign the tuple of labels ``val`` to ``key``.

    A single string is treated as a one-element tuple. The labels must be
    unique within ``val``. They must also not already be in use (as
    reported by ``self.labels()``), except for those currently stored
    under ``key`` itself.

    Raises
    ------
    TypeError
        If ``val`` is not a string or an iterable of strings.
    ValueError
        For repeated labels or labels already used elsewhere.
    """
    values = (val, ) if isinstance(val, str) else val
    if not isiterable(values, str):
        raise TypeError("Labels value can only be a list of strings")
    values = tuple(values)
    if len(set(values)) != len(values):
        raise ValueError("%s contains repeated labels" % (values, ))

    in_use = set(self.labels())
    if key in self:
        # Labels currently owned by this key may be reassigned to it.
        in_use = in_use.difference(self[key])
    clash = in_use.intersection(values)
    if clash:
        raise ValueError("%s are labels already in use" % clash)
    super().__setitem__(key, values)
def format_coords(cls, *keys, **coords):
    """Split ``keys`` into plain string keys and coordinate mappings.

    Parameters
    ----------
    keys:
        Each item may be ``None`` (ignored), a string (collected as a key),
        a mapping (``dict`` or ``MappingProxyType``, merged into the
        coordinates), or an iterable of such items (processed recursively).
    coords:
        Initial coordinates, wrapped into ``Coordinates``.

    Returns
    -------
    (tuple, Coordinates)
        The distinct string keys in first-seen order, and the merged
        coordinates.

    Raises
    ------
    TypeError
        For an item that is neither None, str, a mapping nor an iterable.
    """
    from types import MappingProxyType

    # A dict is used as an insertion-ordered set. The previous `set` made the
    # order of the returned keys depend on the string hash seed.
    args = {}
    coords = Coordinates(coords)
    for key in keys:
        if key is None:
            continue
        if isinstance(key, str):
            args[key] = None
        elif isinstance(key, (dict, MappingProxyType)):
            # Read-only mappings are merged too; otherwise they would be
            # iterated as a sequence of keys and their values dropped.
            coords.update(key)
        else:
            if not isiterable(key):
                raise TypeError(
                    "keys can be str, dict or iterables. %s not accepted." %
                    key)
            _args, _coords = cls.format_coords(*key)
            coords.update(_coords)
            args.update(dict.fromkeys(_args))
    return tuple(args), coords
def isdistributed(val):
    """Tell whether ``val`` is a Dask distributed object.

    The check is delegated to ``isiterable(val, Future)``. It presumably
    returns True when ``val`` is an iterable whose elements are all
    ``Future`` objects. NOTE(review): confirm against ``isiterable``
    whether a bare ``Future`` also qualifies.
    """
    result = isiterable(val, Future)
    return result