Beispiel #1
0
def transform_means(means, size, method='sigmoid'):
    """
    Transforms raw parameters for the index tuples (with values in (-inf, inf)) into parameters within the bounds of
    the dimensions of the tensor.

    In the case of a templated sparse layer, these parameters and the corresponding size tuple describe only the
    learned subtensor.

    :param means: (..., rank) tensor of raw parameter values.
    :param size: Tuple describing the tensor dimensions; its length must equal rank.
    :param method: How raw values are mapped into range:
        'modulo' takes the remainder with respect to (size - 1), giving values in [0, size - 1);
        'clamp' clips to [0, 1] and scales by (size - 1);
        any other value (default 'sigmoid') squashes with a sigmoid and scales by (size - 1).
        NOTE(review): with 'modulo', a dimension of size 1 gives a zero divisor (likely NaN) — confirm whether
        this can occur.
    :return: (..., rank) tensor; along the last dimension, entry i lies in [0, size[i] - 1].
    """
    # Largest valid index per dimension (size - 1). Using means.device (instead of a bare 'cuda'/'cpu' string)
    # keeps the bounds on the same GPU as means when several GPUs are in use.
    s = torch.tensor(list(size), dtype=torch.float, device=means.device) - 1
    # expand_as prepends the leading dimensions of means to the (rank,) bounds.
    s = s.expand_as(means)

    if method == 'modulo':
        return means.remainder(s)

    # Squash to [0, 1] first, then stretch to [0, size - 1].
    if method == 'clamp':
        means = means.clamp(0.0, 1.0)
    else:
        means = torch.sigmoid(means)

    return means * s
Beispiel #2
0
def transform_sigmas(sigmas, size, min_sigma=EPSILON):
    """
    Transforms raw parameters for the covariance matrices (with values in (-inf, inf)) into positive values, scaled
    proportional to the dimensions of the tensor. Note: each sigma is parametrized by a single value, which is
    expanded to a vector to fit the diagonal of the covariance matrix.

    In the case of a templated sparse layer, these parameters and the corresponding size tuple describe only the
    learned subtensor.

    :param sigmas: (...) tensor of raw sigma values (one value per covariance matrix).
    :param size: Tuple describing the tensor dimensions; its length r is the rank.
    :param min_sigma: Minimal sigma value, added after the softplus and before scaling by size.
    :return: (..., r) sigma values; entry i of the last dimension is the positive sigma multiplied by size[i].
    """
    # softplus maps to (0, inf); adding min_sigma bounds the result away from zero. SIGMA_BOOST is a module-level
    # offset applied to the raw values before the softplus.
    sigmas = F.softplus(sigmas + SIGMA_BOOST) + min_sigma

    # Per-dimension scale factors. sigmas.device (rather than a bare 'cuda'/'cpu' string) keeps the GPU index.
    s = torch.tensor(list(size), dtype=torch.float, device=sigmas.device)

    # (...) -> (..., 1), broadcast against the (r,) scales -> (..., r).
    return sigmas.unsqueeze(-1) * s
Beispiel #3
0
def ngenerate(means,
              gadditional,
              ladditional,
              rng=None,
              relative_range=None,
              seed=None,
              cuda=False,
              fm=None):
    """
    Generates random integer index tuples based on continuous parameters.

    For each continuous index tuple in `means`, three groups of integer tuples are produced and concatenated:

    * the 2**rank neighbors, obtained by flooring or ceiling each coordinate (as selected by `fm`);
    * `gadditional` tuples sampled uniformly from the whole tensor, whose bounds are `rng`;
    * `ladditional` tuples sampled uniformly from a window of size `relative_range` around the rounded mean. The
      window is shifted back inside [0, rng) when it sticks out.

    :param means: (..., c, rank) tensor of continuous index tuples.
    :param gadditional: Number of globally sampled tuples per mean.
    :param ladditional: Number of locally sampled tuples per mean.
    :param rng: Sequence of `rank` upper bounds (the tensor size); a torch.Size is accepted. Required despite the
        None default.
    :param relative_range: Sequence of `rank` window sizes for local sampling; a torch.Size is accepted. Required
        despite the None default.
    :param seed: If given, passed to torch.manual_seed before sampling.
    :param cuda: Whether the sampled tensors are created as CUDA tensors; must match the device of `means`, since
        all groups are concatenated.
    :param fm: Boolean floor mask broadcastable to (2**rank, rank). True entries are floored, False entries are
        ceiled. Computed with floor_mask(rank, cuda) when None.
    :return: LongTensor of shape means.size()[:-2] + (c * (2**rank + gadditional + ladditional), rank). All tuples
        generated for the c means of a chunk are merged along the second-to-last dimension.
    """
    rank = means.size(-1)
    pref = means.size()[:-1]

    FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # tuple() guards against torch.Size arguments, which the legacy constructor would interpret as the *shape* of a
    # new tensor rather than as its contents.
    rng = FT(tuple(rng))

    if seed is not None:
        torch.manual_seed(seed)

    # -- Neighbor tuples: floor or ceil every coordinate of each mean.
    if fm is None:
        fm = floor_mask(rank, cuda)

    size = pref + (2**rank, rank)
    fm = util.unsqueezen(fm, len(size) - 2).expand(size)

    # .data detaches from autograd; contiguous() materializes the expanded view so it can be written in place.
    neighbor_ints = means.data.unsqueeze(-2).expand(*size).contiguous()

    neighbor_ints[fm] = neighbor_ints[fm].floor()
    neighbor_ints[~fm] = neighbor_ints[~fm].ceil()

    neighbor_ints = neighbor_ints.long()

    # -- Global tuples: uniform over the whole tensor.
    gsize = pref + (gadditional, rank)
    global_ints = FT(*gsize)

    global_ints.uniform_()
    # Shrink samples to [0, 1 - EPSILON) so float rounding in sample * rng cannot reach rng.
    global_ints *= (1.0 - EPSILON)

    rngxp = util.unsqueezen(rng, len(gsize) - 1).expand_as(global_ints)

    global_ints = torch.floor(global_ints * rngxp).long()

    # -- Local tuples: uniform over a small window around the rounded mean.
    lsize = pref + (ladditional, rank)
    local_ints = FT(*lsize)

    local_ints.uniform_()
    local_ints *= (1.0 - EPSILON)

    rngxp = util.unsqueezen(rng, len(lsize) - 1).expand_as(local_ints)  # bounds of the tensor

    rrng = FT(tuple(relative_range))  # bounds of the range from which to sample
    rrng = util.unsqueezen(rrng, len(lsize) - 1).expand_as(local_ints)

    mns_expand = means.round().unsqueeze(-2).expand_as(local_ints)

    # Window bounds, centered on the rounded mean.
    lower = mns_expand - rrng * 0.5
    upper = mns_expand + rrng * 0.5

    # Shift windows that stick out of the tensor back inside.
    idxs = lower < 0.0
    lower[idxs] = 0.0

    idxs = upper > rngxp
    lower[idxs] = rngxp[idxs] - rrng[idxs]

    local_ints = (local_ints * rrng + lower).long()

    ints = torch.cat([neighbor_ints, global_ints, local_ints], dim=-2)

    # Merge the per-mean sample dimension into the chunk dimension.
    fsize = pref[:-1] + (-1, rank)
    return ints.view(*fsize)
Beispiel #4
0
def ngenerate(means,
              gadditional,
              ladditional,
              rng=None,
              relative_range=None,
              seed=None,
              cuda=False,
              fm=None,
              epsilon=EPSILON):
    """
    Generates random integer index tuples based on continuous parameters.

    For each continuous index tuple in `means`, three groups of integer tuples are produced and concatenated:

    * the 2**rank neighbors, obtained by flooring or ceiling each coordinate (as selected by `fm`);
    * `gadditional` tuples sampled uniformly from the whole tensor, whose bounds are `rng`;
    * `ladditional` tuples sampled uniformly from a window of size `relative_range` around the rounded mean. The
      window is shifted back inside [0, rng) when it sticks out.

    Each group is checked with an assert against the upper bounds `rng`.

    :param means: (..., c, rank) tensor of continuous index tuples.
    :param gadditional: Number of globally sampled tuples per mean.
    :param ladditional: Number of locally sampled tuples per mean.
    :param rng: Sequence of `rank` upper bounds (the tensor size); a torch.Size is accepted. Required despite the
        None default.
    :param relative_range: Sequence of `rank` window sizes for local sampling; a torch.Size is accepted. Required
        despite the None default.
    :param seed: If given, passed to torch.manual_seed before sampling.
    :param cuda: Whether the sampled tensors are created as CUDA tensors; must match the device of `means`, since
        all groups are concatenated.
    :param fm: Boolean floor mask broadcastable to (2**rank, rank). True entries are floored, False entries are
        ceiled. Computed with floor_mask(rank, cuda) when None.
    :param epsilon: The random numbers are based on uniform samples in [0, 1 - epsilon), so that float rounding
        cannot produce an index equal to the bound. Note that in some cases epsilon needs to be relatively big
        (e.g. 1e-5).
    :return: LongTensor of shape means.size()[:-2] + (c * (2**rank + gadditional + ladditional), rank). All tuples
        generated for the c means of a chunk are merged along the second-to-last dimension.
    """
    rank = means.size(-1)
    pref = means.size()[:-1]

    FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # tuple() guards against torch.Size arguments, which the legacy constructor would interpret as the *shape* of a
    # new tensor rather than as its contents.
    rng = FT(tuple(rng))

    # Integer bounds with unsqueezed leading dims, for broadcasting in the bound checks below.
    bounds = util.unsqueezen(rng, len(pref) + 1).long()

    if seed is not None:
        torch.manual_seed(seed)

    # -- Neighbor tuples: floor or ceil every coordinate of each mean.
    if fm is None:
        fm = floor_mask(rank, cuda)

    size = pref + (2**rank, rank)
    fm = util.unsqueezen(fm, len(size) - 2).expand(size)

    # .data detaches from autograd; contiguous() materializes the expanded view so it can be written in place.
    neighbor_ints = means.data.unsqueeze(-2).expand(*size).contiguous()

    neighbor_ints[fm] = neighbor_ints[fm].floor()
    neighbor_ints[~fm] = neighbor_ints[~fm].ceil()

    neighbor_ints = neighbor_ints.long()

    assert not (neighbor_ints >= bounds).any(), 'One of the neighbor indices is outside the tensor bounds'

    # -- Global tuples: uniform over the whole tensor.
    gsize = pref + (gadditional, rank)
    global_ints = FT(*gsize)

    global_ints.uniform_()
    global_ints *= (1.0 - epsilon)

    rngxp = util.unsqueezen(rng, len(gsize) - 1).expand_as(global_ints)

    global_ints = torch.floor(global_ints * rngxp).long()

    assert not (global_ints >= bounds).any(), 'One of the global sampled indices is outside the tensor bounds'

    # -- Local tuples: uniform over a small window around the rounded mean.
    lsize = pref + (ladditional, rank)
    local_ints = FT(*lsize)

    local_ints.uniform_()
    local_ints *= (1.0 - epsilon)

    rngxp = util.unsqueezen(rng, len(lsize) - 1).expand_as(local_ints)  # bounds of the tensor

    rrng = FT(tuple(relative_range))  # bounds of the range from which to sample
    rrng = util.unsqueezen(rrng, len(lsize) - 1).expand_as(local_ints)

    mns_expand = means.round().unsqueeze(-2).expand_as(local_ints)

    # Window bounds, centered on the rounded mean.
    lower = mns_expand - rrng * 0.5
    upper = mns_expand + rrng * 0.5

    # Shift windows that stick out of the tensor back inside.
    idxs = lower < 0.0
    lower[idxs] = 0.0

    idxs = upper > rngxp
    lower[idxs] = rngxp[idxs] - rrng[idxs]

    # Offsets within the window. Keeping this intermediate (instead of cloning the raw samples) lets the assertion
    # message below report it without an extra copy on every call.
    sampled = local_ints * rrng
    local_ints = (sampled + lower).long()

    assert not (local_ints >= bounds).any(), f'One of the local sampled indices is outside the tensor bounds (this may mean the epsilon is too small)' \
        f'\n max sampled  {sampled.max().item()}, rounded {sampled.max().long().item()}  max lower limit {lower.max().item()}' \
        f'\n sum          {(sampled.max() + lower.max()).item()}' \
        f'\n rounds to    {(sampled.max() + lower.max()).long().item()}'

    ints = torch.cat([neighbor_ints, global_ints, local_ints], dim=-2)

    # Merge the per-mean sample dimension into the chunk dimension.
    fsize = pref[:-1] + (-1, rank)
    return ints.view(*fsize)