def _parse_kwargs(self, wrap):
    """Collect the Brillouin zone object, its parent and a usable `wrap` function.

    Parameters
    ----------
    wrap : callable or None
        post-processing function. ``None`` is replaced by a function that
        returns its first argument unchanged; any other callable is passed
        through ``allow_kwargs("parent", "k", "weight")``.

    Returns
    -------
    bz
        ``self._obj``
    parent
        ``self._obj.parent``
    wrap
        the default identity function or the ``allow_kwargs``-wrapped callable
    """
    bz = self._obj
    parent = bz.parent

    if wrap is None:
        # Callers always receive a callable: default to handing the value back
        def wrap(v, parent=None, k=None, weight=None):
            return v
        return bz, parent, wrap

    return bz, parent, allow_kwargs("parent", "k", "weight")(wrap)
def _parse_kwargs(self, wrap, eta=False, eta_key=""):
    """Normalize the common keyword arguments of the Brillouin zone dispatchers.

    Parameters
    ----------
    wrap : callable or None
        post-processing function; ``None`` selects a function returning its
        first argument unchanged. Any other callable is passed through
        ``allow_kwargs("parent", "k", "weight")``.
    eta : bool, optional
        forwarded as the last argument of `tqdm_eta` (presumably toggles the
        progress bar).
    eta_key : str, optional
        suffix of the `tqdm_eta` description, which becomes
        ``"<class name of bz>.<eta_key>"``.

    Returns
    -------
    bz, parent, wrap, eta
        ``self._obj``, its ``parent``, the wrap callable and the object
        returned by `tqdm_eta`.
    """
    bz = self._obj
    parent = bz.parent

    if wrap is not None:
        wrap = allow_kwargs("parent", "k", "weight")(wrap)
    else:
        # Callers always receive a callable: default to handing the value back
        def wrap(v, parent=None, k=None, weight=None):
            return v

    label = "{}.{}".format(bz.__class__.__name__, eta_key)
    progress = tqdm_eta(len(bz), label, "k", eta)
    return bz, parent, wrap, progress
def _call(self, *args, **kwargs):
    """Lazily evaluate the parent method at every k-point.

    ``getattr(self.parent, self._bz_attr)`` is called once per k-point in
    ``self.k`` (with ``k=`` set). Each result is passed through
    ``wrap(value, parent=..., k=..., weight=...)`` before it is yielded.

    The keyword arguments ``wrap`` (default `_do_nothing`) and ``eta``
    (default ``False``, forwarded to `tqdm_eta`) are consumed here. All other
    arguments are forwarded to the parent method.

    Yields
    ------
    the wrapped value for each k-point, in the order of ``self.k``.
    """
    func = getattr(self.parent, self._bz_attr)
    wrap = allow_kwargs('parent', 'k', 'weight')(kwargs.pop('wrap', _do_nothing))
    eta = tqdm_eta(len(self), self.__class__.__name__ + '.asyield', 'k', kwargs.pop('eta', False))
    parent = self.parent
    try:
        for k, w in zip(self.k, self.weight):
            yield wrap(func(*args, k=k, **kwargs), parent=parent, k=k, weight=w)
            eta.update()
    finally:
        # Also reached when the consumer stops iterating early (generator is
        # closed) or an evaluation raises, so the progress bar is always closed.
        eta.close()
def _call(self, *args, **kwargs):
    """Evaluate the parent method at every k-point and stack the results.

    ``getattr(self.parent, self._bz_attr)`` is called once per k-point in
    ``self.k`` (with ``k=`` set). Each result is passed through
    ``wrap(value, parent=..., k=..., weight=...)``.

    The keyword arguments ``wrap`` (default `_do_nothing`) and ``eta``
    (default ``False``, forwarded to `tqdm_eta`) are consumed here. All other
    arguments are forwarded to the parent method. At least one k-point is
    required, since the first value determines the output layout.

    Returns
    -------
    numpy.ndarray
        array of shape ``(len(self),) + v.shape`` with the dtype of the first
        wrapped value ``v``. Entry ``i`` holds the value of k-point ``i``.
    """
    func = getattr(self.parent, self._bz_attr)
    wrap = allow_kwargs('parent', 'k', 'weight')(kwargs.pop('wrap', _do_nothing))
    eta = tqdm_eta(len(self), self.__class__.__name__ + '.asarray', 'k', kwargs.pop('eta', False))
    parent = self.parent
    k = self.k
    w = self.weight
    try:
        # The first k-point determines the shape and dtype of the output
        v = wrap(func(*args, k=k[0], **kwargs), parent=parent, k=k[0], weight=w[0])
        # For 0-d values v.shape == (), so this yields shape (len(self),)
        a = np.empty((len(self),) + v.shape, dtype=v.dtype)
        a[0] = v
        del v
        # Count the first k-point too, so the progress bar reaches len(k)
        eta.update()
        for i in range(1, len(k)):
            a[i] = wrap(func(*args, k=k[i], **kwargs), parent=parent, k=k[i], weight=w[i])
            eta.update()
    finally:
        eta.close()
    return a
def _call(self, *args, **kwargs):
    """Evaluate the parent method at every k-point and store the values on a `Grid`.

    The k-points must form the regular grid described by ``self._diag``,
    ``self._displ`` (which must be all zero), ``self._size`` and
    ``self._centered``. Each value is placed at the integer index computed
    by the local ``k2idx``.

    Keyword arguments consumed here (all others are forwarded to
    ``getattr(self.parent, self._bz_attr)`` together with ``k=``):

    data_axis : int or None, default None
        when None every k-point must produce a single value. Otherwise the
        values of a k-point are laid out along this grid axis, multiplied by
        the ``self.grid`` weight of the k-index along that axis, and the axis
        length becomes ``len(v)``.
    grid_unit : str, default 'b'
        ``'b'`` uses ``diag(self._size)`` as the grid cell. Any other value
        uses ``parent.sc.rcell`` (rows scaled by ``self._size``) divided by
        ``units('Ang', grid_unit)``.
    wrap : callable, default `_do_nothing`
        post-processing of each value, called with ``parent``, ``k`` and
        ``weight`` keywords.
    eta : bool, default False
        forwarded to `tqdm_eta`.

    Returns
    -------
    Grid

    Raises
    ------
    SislError
        if any displacement is non-zero, or if `data_axis` is None and the
        first k-point value has more than one element.
    """
    data_axis = kwargs.pop('data_axis', None)
    grid_unit = kwargs.pop('grid_unit', 'b')

    func = getattr(self.parent, self._bz_attr)
    wrap = allow_kwargs('parent', 'k', 'weight')(kwargs.pop('wrap', _do_nothing))
    # Pop now so the flag is never forwarded to `func`. The progress bar itself
    # is only created once the set-up below has succeeded.
    eta_flag = kwargs.pop('eta', False)
    parent = self.parent
    k = self.k
    w = self.weight

    # Extract information from the MP grid, these values
    # define the Grid size, etc.
    if not np.all(self._displ == 0):
        raise SislError(self.__class__.__name__ + '.{} requires the displacement to be 0 for all k-points.'.format(self._bz_attr))
    diag = self._diag.copy()
    displ = self._displ.copy()
    size = self._size.copy()
    steps = size / diag
    if self._centered:
        offset = np.where(diag % 2 == 0, steps, steps / 2)
    else:
        offset = np.where(diag % 2 == 0, steps / 2, steps)

    # Fold the +0.5 shift into the offset: instead of
    #    (_in_primitive(k) + 0.5 - offset) / steps
    # k2idx computes
    #    (_in_primitive(k) - offset') / steps   with offset' = offset - 0.5
    offset -= 0.5

    # Time-reversal symmetry axis; only values >= 0 are treated as an axis below
    trs_axis = self._trs

    _in_primitive = self.in_primitive
    _rint = np.rint
    _int32 = np.int32

    def k2idx(kpt):
        # Integer grid index of a k-point. `offset` is read at call time, so the
        # in-place TRS correction further down is taken into account.
        return _rint((_in_primitive(kpt) - offset) / steps).astype(_int32)

    # Create cell from the reciprocal cell.
    if grid_unit == 'b':
        cell = np.diag(self._size)
    else:
        # Scale reciprocal lattice vector i (row i) by size[i], the Cartesian
        # analogue of diag(size) above; reshape(1, -1) would scale columns.
        cell = parent.sc.rcell * self._size.reshape(-1, 1) / units('Ang', grid_unit)

    # Find the grid origo
    origo = -(cell * 0.5).sum(0)

    eta = tqdm_eta(len(self), self.__class__.__name__ + '.asgrid', 'k', eta_flag)
    try:
        # Calculate first k-point (to get size and dtype)
        v = wrap(func(*args, k=k[0], **kwargs), parent=parent, k=k[0], weight=w[0])
        eta.update()

        if data_axis is None:
            if v.size != 1:
                raise SislError(self.__class__.__name__ + '.{} requires one value per-kpoint because of the 3D grid values'.format(self._bz_attr))
        else:
            # Weights of the k-indices along the data axis
            weights = self.grid(diag[data_axis], displ[data_axis], size[data_axis],
                                centered=self._centered, trs=trs_axis == data_axis)[1]

            # Correct the Grid size
            diag[data_axis] = len(v)
            # Create the orthogonal cell direction to ensure it is orthogonal
            # Since array axis is cyclic for negative numbers, we simply do this
            cell[data_axis, :] = cross(cell[data_axis - 1, :], cell[data_axis - 2, :])
            # Check whether we should rotate it
            if cart2spher(cell[data_axis, :])[2] > pi / 4:
                cell[data_axis, :] *= -1

        # Correct cell for the grid
        if trs_axis >= 0:
            origo[trs_axis] = 0.
            # Correct offset since we only have the positive half
            if self._diag[trs_axis] % 2 == 0 and not self._centered:
                offset[trs_axis] = steps[trs_axis] / 2
            else:
                offset[trs_axis] = 0.

            # Find number of points
            if trs_axis != data_axis:
                diag[trs_axis] = len(self.grid(diag[trs_axis], displ[trs_axis], size[trs_axis],
                                               centered=self._centered, trs=True)[1])

        # Create the grid in the reciprocal cell
        sc = SuperCell(cell, origo=origo)
        grid = Grid(diag, sc=sc, dtype=v.dtype)

        def store(kpt, value):
            # Keys are tuples: assuming Grid.__setitem__ forwards to numpy, a
            # bare index array would be fancy indexing along axis 0 and a list
            # containing a slice is non-tuple sequence indexing.
            idx = k2idx(kpt).tolist()
            if data_axis is None:
                grid[tuple(idx)] = value
            else:
                weight = weights[idx[data_axis]]
                idx[data_axis] = slice(None)
                grid[tuple(idx)] = value * weight

        store(k[0], v)
        del v

        # Now perform calculation
        for i in range(1, len(k)):
            store(k[i], wrap(func(*args, k=k[i], **kwargs), parent=parent, k=k[i], weight=w[i]))
            eta.update()
    finally:
        eta.close()
    return grid