def _raw_construct(self, zsnhats):
    """Contract weights with unit vectors, lowering indices when required.

    Args:
        zsnhats: pair ``(zs, nhats)``. ``nhats`` is unpacked as a 2-D array
            of shape ``(M, dim)``. ``zs`` is passed straight to ``einsum`` as
            its first operand (presumably one weight per row of ``nhats`` --
            confirm against callers).

    Returns:
        ``self._pow2`` times the contraction described by ``self.raw_einstr``.
        When ``self.nlow > 0`` the result has lowered indices. Either the last
        ``self.nlow`` vector operands are lowered before contracting, or the
        contracted tensor is passed through ``self._rl_construct``.
    """
    zs, nhats = zsnhats
    n_rows, dim = nhats.shape
    v = self.v

    def contract(*vectors):
        # Shared einsum call used by every branch below.
        return einsum(self.raw_einstr, zs, *vectors,
                      optimize=self.raw_einpath)

    # No lowering requested: contract the raw vectors directly.
    if self.nlow == 0:
        return self._pow2 * contract(*([nhats] * v))

    # Heuristic carried over from the original code: if the nhats array
    # (n_rows * dim entries) is not smaller than dim**v, lower after
    # contracting rather than lowering every vector first.
    if n_rows * dim >= dim ** v:
        return self._pow2 * self._rl_construct(contract(*([nhats] * v)))

    # Otherwise lower nhats up front. flat_metric(dim) is broadcast across
    # rows; presumably it is the diagonal of the flat metric -- confirm.
    lowered = nhats * flat_metric(dim)[np.newaxis]
    operands = [nhats] * self.nup + [lowered] * self.nlow
    return self._pow2 * contract(*operands)
def _efm_compute(self, efms_dict):
    """Contract the EFMs selected by ``self.efm_spec``.

    Args:
        efms_dict: mapping from each signature listed in ``self.efm_spec`` to
            its array. A missing signature propagates the mapping's lookup
            error (``KeyError`` for a plain dict).

    Returns:
        The ``einsum`` of the selected arrays, taken in ``efm_spec`` order,
        using subscripts ``self.efm_einstr`` and ``optimize=self.efm_einpath``.
    """
    operands = tuple(efms_dict[signature] for signature in self.efm_spec)
    return einsum(self.efm_einstr, *operands, optimize=self.efm_einpath)
def _efp_compute(self, zs, thetas_dict):
    """Contract the theta arrays for ``self.weights`` with ``self._n`` copies of ``zs``.

    Args:
        zs: operand repeated ``self._n`` times after the theta operands.
        thetas_dict: mapping from each entry of ``self.weights`` to its array;
            a missing weight propagates the mapping's lookup error.

    Returns:
        ``einsum(self.einstr, *thetas, *([zs] * self._n))`` computed with
        ``optimize=self.einpath``.
    """
    theta_operands = [thetas_dict[weight] for weight in self.weights]
    zs_operands = self._n * [zs]
    return einsum(self.einstr, *(theta_operands + zs_operands),
                  optimize=self.einpath)