def tr_to_tensor(factors):
    """Rebuild the full tensor from its tensor-ring (TR) cores.

    The cores are contracted from left to right as a chain of matrix
    products; the last core is then contracted over both of its rank axes
    at once, which closes the ring.

    Parameters
    ----------
    factors : list of 3D-arrays
        TR factors (TR-cores). The size of axis 1 of each core becomes the
        size of the corresponding mode of the output.

    Returns
    -------
    output_tensor : ndarray
        tensor whose TR decomposition was given by 'factors'
    """
    mode_sizes = [core.shape[1] for core in factors]
    head, tail = factors[0], factors[-1]

    # Running product, kept as a matrix whose columns index the open rank.
    partial = tl.reshape(head, (-1, head.shape[2]))

    # Absorb every interior core: contract over the shared rank, then fold
    # the newly produced mode into the rows again.
    for core in factors[1:-1]:
        left_rank, _, right_rank = core.shape
        partial = tl.dot(partial, tl.reshape(core, (left_rank, -1)))
        partial = tl.reshape(partial, (-1, right_rank))

    # Split off the leading axis of the running product (sized by the last
    # core's trailing rank) and move it next to the open rank so that both
    # can be contracted with the last core in a single matrix product.
    closing_rank, bond_rank = tail.shape[2], tail.shape[0]
    partial = tl.reshape(partial, (closing_rank, -1, bond_rank))
    partial = tl.moveaxis(partial, 0, -1)
    partial = tl.reshape(partial, (-1, bond_rank * closing_rank))

    # Lay the last core out as (both ranks flattened, mode size).
    tail_matrix = tl.reshape(tl.moveaxis(tail, -1, 1), (-1, mode_sizes[-1]))

    return tl.reshape(tl.dot(partial, tail_matrix), mode_sizes)
def convert_mps_back_to_mpo(mpo, mps):
    """Fold a list of MPS cores back into cores shaped like the given MPO cores.

    Each core of ``mps`` has its axes 1 and 2 swapped and is then passed to
    ``tl.base.partial_fold`` with ``mode=2`` and the shape of the core at
    the same position in ``mpo``.

    Parameters
    ----------
    mpo : sequence of arrays
        Reference cores. Only their ``.shape`` is read; they are not
        modified. Must hold at least as many cores as ``mps``.
    mps : sequence of arrays
        Cores to fold back (presumably produced by ``convert_mpo_to_mps``
        -- confirm against callers).

    Returns
    -------
    list
        The newly folded cores, one per entry of ``mps``.
    """
    new_mpo = []
    for i, core in enumerate(mps):
        core = tl.moveaxis(core, 1, 2)
        core = tl.base.partial_fold(core, mode=2, shape=mpo[i].shape)
        new_mpo.append(core)
    # Return the folded cores; the previous code returned the input `mpo`
    # and silently discarded the conversion result.
    return new_mpo
def convert_mpo_to_mps(mpo):
    """Turn every core of ``mpo`` into its MPS counterpart.

    Each core is partially unfolded with ``tl.base.partial_unfold`` along
    ``mode=2``, after which axes 1 and 2 of the result are swapped.

    Parameters
    ----------
    mpo : sequence of arrays
        Cores to convert; accessed by integer index.

    Returns
    -------
    list
        The converted cores, in the same order as ``mpo``.
    """
    return [
        tl.moveaxis(tl.base.partial_unfold(mpo[idx], mode=2), 1, 2)
        for idx in range(len(mpo))
    ]
def convolve(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Convolution of any specified order, wrapper on torch's F.convNd

    Parameters
    ----------
    x : torch.Tensor or FactorizedTensor
        input tensor
    weight : torch.Tensor or factorized tensor
        convolutional weights. Anything that is not a ``torch.Tensor`` is
        first made dense through its ``to_tensor()`` method; for a
        ``TTTensor`` the last axis of the dense result is additionally
        moved to the front.
    bias : torch.Tensor or None, optional
        forwarded unchanged to the underlying convolution, by default None
    stride : int, optional
        by default 1
    padding : int, optional
        by default 0
    dilation : int, optional
        by default 1
    groups : int, optional
        by default 1

    Returns
    -------
    torch.Tensor
        `x` convolved with `weight`

    Raises
    ------
    ValueError
        If no convolution is registered in ``_CONVOLUTION`` for the order
        of the (dense) weight.
    """
    if not torch.is_tensor(weight):
        if isinstance(weight, TTTensor):
            weight = tl.moveaxis(weight.to_tensor(), -1, 0)
        else:
            weight = weight.to_tensor()

    # Guard only the dispatch lookup: a KeyError raised while densifying the
    # weight or inside the convolution itself must not be reported as an
    # unsupported order.
    try:
        conv_fn = _CONVOLUTION[weight.ndim - 2]
    except KeyError:
        raise ValueError(f'Got tensor of order={weight.ndim} but pytorch only supports up to 3rd order (3D) Convs.') from None

    return conv_fn(x, weight, bias=bias, stride=stride, padding=padding,
                   dilation=dilation, groups=groups)
def convert_mps_core_back_to_mpo_core(mpo_core, mps_core):
    """Fold one MPS core back into the shape of a reference MPO core.

    Parameters
    ----------
    mpo_core : array
        Reference core; only its ``.shape`` is read.
    mps_core : array
        Core to fold back.

    Returns
    -------
    array
        ``mps_core`` with axes 1 and 2 swapped, then folded by
        ``tl.base.partial_fold`` (``mode=2``) to ``mpo_core.shape``.
    """
    swapped = tl.moveaxis(mps_core, 1, 2)
    target_shape = mpo_core.shape
    return tl.base.partial_fold(swapped, mode=2, shape=target_shape)
def kernel_to_tensor(factorization, kernel):
    """Prepare a convolutional kernel for factorization.

    For the ``'tt'`` factorization (name compared case-insensitively) the
    first axis of ``kernel`` is moved to the last position; for any other
    factorization the kernel is returned unchanged.
    """
    if factorization.lower() != 'tt':
        return kernel
    return tl.moveaxis(kernel, 0, -1)
def tensor_to_kernel(factorization, tensor):
    """Recover a convolutional kernel from a reconstructed tensor.

    For the ``'tt'`` factorization (name compared case-insensitively) the
    last axis of ``tensor`` is moved to the first position; for any other
    factorization the tensor is returned unchanged.
    """
    if factorization.lower() != 'tt':
        return tensor
    return tl.moveaxis(tensor, -1, 0)