예제 #1
0
def tr_to_tensor(factors):
    """Reconstruct the full tensor represented by a tensor-ring (TR) decomposition.

    The cores are contracted left to right into a running matrix. The ring is
    then closed by contracting the last core against both the last open rank
    and the first core's leading rank.

    Parameters
    ----------
    factors : list of 3D-arrays
        TR cores (TR-factors). Core ``k`` is indexed as
        ``(rank_in, size_k, rank_out)``. The last core's ``rank_out`` closes the
        ring with the first core's ``rank_in``.

    Returns
    -------
    output_tensor : ndarray
        Full tensor of shape ``[core.shape[1] for core in factors]``.
    """
    shape = [core.shape[1] for core in factors]

    # Left-to-right sweep: after each step the running matrix has the current
    # trailing rank as its column dimension.
    head = factors[0]
    merged = tl.reshape(head, (-1, head.shape[2]))
    for core in factors[1:-1]:
        rank_in, _, rank_out = core.shape
        merged = tl.dot(merged, tl.reshape(core, (rank_in, -1)))
        merged = tl.reshape(merged, (-1, rank_out))

    # Close the ring: move the first core's leading rank next to the last open
    # rank, so that both are summed against the final core in a single product.
    # NOTE(review): with a single core the final reshape receives size**2
    # entries and fails, so at least two cores appear to be assumed.
    tail = factors[-1]
    tail_rank_in, tail_rank_out = tail.shape[0], tail.shape[2]
    merged = tl.reshape(merged, (tail_rank_out, -1, tail_rank_in))
    merged = tl.moveaxis(merged, 0, -1)
    merged = tl.reshape(merged, (-1, tail_rank_in * tail_rank_out))
    closing = tl.reshape(tl.moveaxis(tail, -1, 1), (-1, shape[-1]))
    return tl.reshape(tl.dot(merged, closing), shape)
예제 #2
0
def convert_mps_back_to_mpo(mpo, mps):
    """Fold every MPS core back into the shape of the matching MPO core.

    Intended as the inverse of ``convert_mpo_to_mps``. Each MPS core first has
    axis 1 moved to position 2. It is then re-folded with
    ``tl.base.partial_fold(..., mode=2)`` into the shape of ``mpo[i]``.

    Parameters
    ----------
    mpo : list of arrays
        Reference MPO cores. Only ``mpo[i].shape`` is read, as the target shape
        for core ``i`` of ``mps``. This list is not modified.
    mps : list of 3D arrays
        MPS cores to convert.

    Returns
    -------
    list of arrays
        A new list with one folded core per entry of ``mps``.

    Raises
    ------
    IndexError
        If ``mpo`` has fewer cores than ``mps``.
    """
    return [
        tl.base.partial_fold(tl.moveaxis(core, 1, 2), mode=2, shape=mpo[i].shape)
        for i, core in enumerate(mps)
    ]
예제 #3
0
def convert_mpo_to_mps(mpo):
    """Turn each MPO core into an MPS-style core.

    Each core is partially unfolded with ``tl.base.partial_unfold(core, mode=2)``.
    Axis 1 of the result is then moved to position 2.

    Parameters
    ----------
    mpo : list of arrays
        MPO cores (presumably 4D ``(rank, in, out, rank)``; TODO confirm).

    Returns
    -------
    list of arrays
        A new list of converted cores, in the same order as ``mpo``.
    """
    return [
        tl.moveaxis(tl.base.partial_unfold(mpo_core, mode=2), 1, 2)
        for mpo_core in mpo
    ]
예제 #4
0
def convolve(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Convolution of any supported order, wrapper on torch's F.convNd.

    The convolution order is taken from the (dense) weight. A weight with
    ``k + 2`` dimensions dispatches to ``_CONVOLUTION[k]``.

    Parameters
    ----------
    x : torch.Tensor
        input tensor, passed unchanged to the convolution
    weight : torch.Tensor or FactorizedTensor
        convolutional weights. A non-tensor weight is reconstructed with
        ``weight.to_tensor()``. For a ``TTTensor``, the last axis of the
        reconstruction is then moved to the front.
    bias : torch.Tensor, optional
        forwarded as ``bias``; by default None
    stride : int, optional
        by default 1
    padding : int, optional
        by default 0
    dilation : int, optional
        by default 1
    groups : int, optional
        by default 1

    Returns
    -------
    torch.Tensor
        `x` convolved with `weight`

    Raises
    ------
    ValueError
        If ``_CONVOLUTION`` has no entry for ``weight.ndim - 2``.
    """
    if not torch.is_tensor(weight):
        if isinstance(weight, TTTensor):
            # TT weights are stored with the kernel's leading axis moved last
            # (as in the 'tt' branch of kernel_to_tensor); restore the layout.
            weight = tl.moveaxis(weight.to_tensor(), -1, 0)
        else:
            weight = weight.to_tensor()

    # Guard only the dispatch lookup. A KeyError raised while reconstructing the
    # weight or inside the convolution must not be reported as an unsupported order.
    try:
        conv = _CONVOLUTION[weight.ndim - 2]
    except KeyError:
        raise ValueError(f'Got tensor of order={weight.ndim} but pytorch only supports up to 3rd order (3D) Convs.') from None

    return conv(x, weight, bias=bias, stride=stride, padding=padding,
                dilation=dilation, groups=groups)
예제 #5
0
def convert_mps_core_back_to_mpo_core(mpo_core, mps_core):
    """Fold a single MPS core back into the shape of its MPO counterpart.

    Axis 1 of ``mps_core`` is moved to position 2. The result is then re-folded
    with ``tl.base.partial_fold(..., mode=2)`` so that it takes the shape of
    ``mpo_core``. Only ``mpo_core.shape`` is read.
    """
    return tl.base.partial_fold(
        tl.moveaxis(mps_core, 1, 2),
        mode=2,
        shape=mpo_core.shape,
    )
예제 #6
0
def kernel_to_tensor(factorization, kernel):
    """Rearrange a convolutional kernel into the layout used for factorization.

    For ``'tt'`` (case-insensitive), axis 0 of ``kernel`` is moved to the last
    position. Any other factorization gets the same ``kernel`` object back.
    """
    if factorization.lower() != 'tt':
        return kernel
    return tl.moveaxis(kernel, 0, -1)
예제 #7
0
def tensor_to_kernel(factorization, tensor):
    """Turn a reconstructed factorized tensor back into a convolutional kernel.

    For ``'tt'`` (case-insensitive), the last axis of ``tensor`` is moved to
    position 0. Any other factorization gets the same ``tensor`` object back.
    """
    if factorization.lower() != 'tt':
        return tensor
    return tl.moveaxis(tensor, -1, 0)