def _raw_construct(self, zsnhats):
    """Build the tensor by contracting ``zs`` with ``v`` vector operands.

    The contraction always uses ``self.raw_einstr`` with
    ``optimize=self.raw_einpath``. The result is scaled by ``self._pow2``.
    One of three strategies is chosen:

    * ``self.nlow == 0``: contract with ``v`` copies of ``nhats``; nothing
      is lowered.
    * ``M * dim < dim ** self.v``: lower ``nhats`` first by multiplying it
      elementwise with ``flat_metric(dim)[np.newaxis]``. Then contract with
      ``nup`` plain copies followed by ``nlow`` lowered copies.
    * otherwise: contract with ``v`` plain copies and pass the result to
      ``self._rl_construct``, which lowers the contracted tensor instead.

    Args:
        zsnhats: a pair ``(zs, nhats)``. ``nhats`` must have a 2-element
            ``.shape``, unpacked as ``(M, dim)``. Presumably ``M`` is the
            number of particles and ``dim`` the vector dimension; confirm
            against the callers.

    Returns:
        ``self._pow2`` times the (possibly lowered) contraction result.
    """
    zs, nhats = zsnhats
    num_rows, dim = nhats.shape

    def contract(operands):
        # Every strategy shares the same subscripts and precomputed path.
        return einsum(self.raw_einstr, zs, *operands, optimize=self.raw_einpath)

    # Guard: no index needs lowering.
    if self.nlow == 0:
        return self._pow2 * contract([nhats] * self.v)

    # Cost heuristic: M * dim is the size of nhats. dim ** v is presumably
    # the size of the rank-v contracted tensor. When nhats is smaller, it
    # is cheaper to lower nhats before contracting.
    if num_rows * dim < dim ** self.v:
        lowered = nhats * flat_metric(dim)[np.newaxis]
        operands = [nhats] * self.nup + [lowered] * self.nlow
        return self._pow2 * contract(operands)

    # Otherwise, contract first and lower the resulting tensor afterwards.
    tensor = contract([nhats] * self.v)
    return self._pow2 * self._rl_construct(tensor)
def _rl_construct(self, tensor):
    """Contract ``tensor`` with ``self._rl_diff`` copies of the flat metric.

    The metric is built once as ``flat_metric(len(tensor))``, where
    ``len(tensor)`` is the size of the tensor's first axis. It is passed
    ``self._rl_diff`` times to ``c_einsum`` together with ``tensor``, using
    the subscripts in ``self.rl_einstr``. ``_raw_construct`` calls this to
    lower the contracted tensor when that is judged cheaper than lowering
    ``nhats``.

    Args:
        tensor: an array-like that supports ``len``. Presumably a square
            rank-v array whose every axis has the same size; TODO confirm.

    Returns:
        The ``c_einsum`` contraction result.
    """
    # Using c_einsum directly (no path optimization) is fine here, since
    # it ends up being used anyway.
    subscripts = self.rl_einstr
    metric = flat_metric(len(tensor))
    metric_operands = [metric] * self._rl_diff
    return c_einsum(subscripts, tensor, *metric_operands)