def transform_means(means, size, method='sigmoid'):
    """
    Transforms raw parameters for the index tuples (with values in (-inf, inf)) into
    parameters within the bound of the dimensions of the tensor.

    In the case of a templated sparse layer, these parameters and the corresponding
    size tuple describe only the learned subtensor.

    :param means: (..., rank) tensor of raw parameter values
    :param size: Tuple describing the tensor dimensions.
    :param method: How to squash the raw values into bounds: 'sigmoid' (default),
        'clamp' (hard clip to [0, 1] before scaling) or 'modulo' (wrap around the
        upper bound).
    :return: (..., rank) tensor with each coordinate d in [0, size[d] - 1]
    """

    # Compute upper bounds (inclusive): one less than each dimension size.
    # Use means.device rather than the bare 'cuda' string so that tensors on
    # non-default CUDA devices (e.g. cuda:1) end up on the correct device.
    s = torch.tensor(list(size), dtype=torch.float, device=means.device) - 1
    s = util.unsqueezen(s, len(means.size()) - 1)
    s = s.expand_as(means)

    if method == 'modulo':
        # remainder already maps the values into [0, s); no extra scaling needed
        return means.remainder(s)

    # Scale to [0, 1] ...
    if method == 'clamp':
        means = means.clamp(0.0, 1.0)
    else:
        means = torch.sigmoid(means)

    # ... then up to the (inclusive) upper bounds.
    return means * s
def transform_sigmas(sigmas, size, min_sigma=EPSILON):
    """
    Transforms raw parameters for the conv matrices (with values in (-inf, inf)) into
    positive values, scaled proportional to the dimensions of the tensor. Note: each
    sigma is parametrized by a single value, which is expanded to a vector to fit the
    diagonal of the covariance matrix.

    In the case of a templated sparse layer, these parameters and the corresponding
    size tuple describe only the learned subtensor.

    :param sigmas: (...,) tensor of raw sigma values
    :param size: Tuple describing the tensor dimensions.
    :param min_sigma: Minimal sigma value.
    :return: (..., rank) sigma values
    """
    ssize = sigmas.size()
    r = len(size)

    # Map raw values to strictly positive ones. SIGMA_BOOST shifts the softplus so
    # that raw values near zero already yield usefully large sigmas; min_sigma keeps
    # the result bounded away from zero.
    sigmas = F.softplus(sigmas + SIGMA_BOOST) + min_sigma

    # Expand the single sigma per tuple to one value per tensor dimension
    # (the diagonal of the covariance matrix).
    sigmas = sigmas.unsqueeze(-1).expand(*(ssize + (r,)))

    # Scale each sigma proportionally to the size of its dimension.
    # Use sigmas.device rather than the bare 'cuda' string so that tensors on
    # non-default CUDA devices (e.g. cuda:1) end up on the correct device.
    s = torch.tensor(list(size), dtype=torch.float, device=sigmas.device)
    s = util.unsqueezen(s, len(sigmas.size()) - 1)
    s = s.expand_as(sigmas)

    return sigmas * s
def ngenerate(means, gadditional, ladditional, rng=None, relative_range=None, seed=None, cuda=False, fm=None):
    """
    Generates random integer index tuples based on continuous parameters.

    NOTE(review): this definition is shadowed by the second `ngenerate` defined
    later in this module, so callers only ever see the later one. Consider
    deleting this version or merging the two.

    :param means: (..., rank) tensor of continuous index parameters.
    :param gadditional: Number of tuples to sample uniformly over the whole tensor.
    :param ladditional: Number of tuples to sample from a small window around each mean.
    :param rng: Tuple of index bounds (the size of the tensor being indexed).
    :param relative_range: Size of the local sampling window around each mean.
    :param seed: Optional manual seed for reproducible sampling.
    :param cuda: If True, sample on the GPU.
    :param fm: Optional precomputed floor mask; derived from the rank when None.
    :return: Long tensor of integer index tuples, shape (..., -1, rank).
    """
    rank = means.size(-1)
    pref = means.size()[:-1]

    FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    if seed is not None:
        torch.manual_seed(seed)

    # Generate neighbor tuples: the 2^rank integer points surrounding each
    # continuous mean (every floor/ceil combination over the coordinates).
    if fm is None:
        fm = floor_mask(rank, cuda)

    size = pref + (2 ** rank, rank)
    fm = util.unsqueezen(fm, len(size) - 2).expand(size)

    neighbor_ints = means.data.unsqueeze(-2).expand(*size).contiguous()
    neighbor_ints[fm] = neighbor_ints[fm].floor()
    neighbor_ints[~fm] = neighbor_ints[~fm].ceil()
    neighbor_ints = neighbor_ints.long()

    # Sample uniformly from all integer tuples in the tensor.
    gsize = pref + (gadditional, rank)
    global_ints = FT(*gsize)
    global_ints.uniform_()
    global_ints *= (1.0 - EPSILON)

    # tuple() guards against a torch.Size being passed, which the tensor
    # constructor would interpret as a shape rather than as content
    # (same fix as in the later ngenerate overload).
    rng = FT(tuple(rng))

    rngxp = util.unsqueezen(rng, len(gsize) - 1).expand_as(global_ints)
    global_ints = torch.floor(global_ints * rngxp).long()

    # Sample uniformly from a small range around each (rounded) mean.
    lsize = pref + (ladditional, rank)
    local_ints = FT(*lsize)
    local_ints.uniform_()
    local_ints *= (1.0 - EPSILON)

    rngxp = util.unsqueezen(rng, len(lsize) - 1).expand_as(local_ints)  # bounds of the tensor
    rrng = FT(relative_range)  # bounds of the range from which to sample
    rrng = util.unsqueezen(rrng, len(lsize) - 1).expand_as(local_ints)

    mns_expand = means.round().unsqueeze(-2).expand_as(local_ints)

    # Upper and lower bounds of each local window.
    lower = mns_expand - rrng * 0.5
    upper = mns_expand + rrng * 0.5

    # Shift any window that falls outside the tensor back into bounds.
    idxs = lower < 0.0
    lower[idxs] = 0.0

    idxs = upper > rngxp
    lower[idxs] = rngxp[idxs] - rrng[idxs]

    local_ints = (local_ints * rrng + lower).long()

    # Combine all indices sampled within a chunk.
    sampled = torch.cat([neighbor_ints, global_ints, local_ints], dim=-2)
    fsize = pref[:-1] + (-1, rank)

    return sampled.view(*fsize)
def ngenerate(means, gadditional, ladditional, rng=None, relative_range=None, seed=None, cuda=False, fm=None, epsilon=EPSILON):
    """
    Generates random integer index tuples based on continuous parameters.

    :param means: (..., rank) tensor of continuous index parameters.
    :param gadditional: Number of tuples to sample uniformly over the whole tensor.
    :param ladditional: Number of tuples to sample from a small window around each mean.
    :param rng: Tuple of index bounds (the size of the tensor being indexed).
    :param relative_range: Size of the local sampling window around each mean.
    :param seed: Optional manual seed for reproducible sampling.
    :param cuda: If True, sample on the GPU.
    :param fm: Optional precomputed floor mask; derived from the rank when None.
    :param epsilon: The random numbers are based on uniform samples in (0, 1-epsilon).
        Note that in some cases epsilon needs to be relatively big (e.g. 1e-5).
    :return: Long tensor of integer index tuples, shape (..., -1, rank).
    """
    rank = means.size(-1)
    pref = means.size()[:-1]

    FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # The tuple() is there in case a torch.Size() object is passed (which causes
    # torch to interpret the argument as the size of the tensor rather than its
    # content).
    rng = FT(tuple(rng))

    # Index bound with unsqueezed dims for broadcasting against sampled tuples.
    bounds = util.unsqueezen(rng, len(pref) + 1).long()

    if seed is not None:
        torch.manual_seed(seed)

    # Generate neighbor tuples: the 2^rank integer points surrounding each
    # continuous mean (every floor/ceil combination over the coordinates).
    if fm is None:
        fm = floor_mask(rank, cuda)

    size = pref + (2 ** rank, rank)
    fm = util.unsqueezen(fm, len(size) - 2).expand(size)

    neighbor_ints = means.data.unsqueeze(-2).expand(*size).contiguous()
    neighbor_ints[fm] = neighbor_ints[fm].floor()
    neighbor_ints[~fm] = neighbor_ints[~fm].ceil()
    neighbor_ints = neighbor_ints.long()

    # NOTE(review): these bounds checks are stripped under `python -O`;
    # raise an exception instead if they must always run.
    assert (neighbor_ints >= bounds).sum() == 0, \
        'One of the neighbor indices is outside the tensor bounds'

    # Sample uniformly from all integer tuples in the tensor.
    gsize = pref + (gadditional, rank)
    global_ints = FT(*gsize)
    global_ints.uniform_()
    global_ints *= (1.0 - epsilon)

    rngxp = util.unsqueezen(rng, len(gsize) - 1).expand_as(global_ints)
    global_ints = torch.floor(global_ints * rngxp).long()

    assert (global_ints >= bounds).sum() == 0, \
        'One of the global sampled indices is outside the tensor bounds'

    # Sample uniformly from a small range around each (rounded) mean.
    lsize = pref + (ladditional, rank)
    local_ints = FT(*lsize)
    local_ints.uniform_()
    local_ints *= (1.0 - epsilon)

    rngxp = util.unsqueezen(rng, len(lsize) - 1).expand_as(local_ints)  # bounds of the tensor
    rrng = FT(relative_range)  # bounds of the range from which to sample
    rrng = util.unsqueezen(rrng, len(lsize) - 1).expand_as(local_ints)

    mns_expand = means.round().unsqueeze(-2).expand_as(local_ints)

    # Upper and lower bounds of each local window.
    lower = mns_expand - rrng * 0.5
    upper = mns_expand + rrng * 0.5

    # Shift any window that falls outside the tensor back into bounds.
    idxs = lower < 0.0
    lower[idxs] = 0.0

    idxs = upper > rngxp
    lower[idxs] = rngxp[idxs] - rrng[idxs]

    cached = local_ints.clone()  # pre-scale samples, kept for the diagnostic below
    local_ints = (local_ints * rrng + lower).long()

    assert (local_ints >= bounds).sum() == 0, \
        'One of the local sampled indices is outside the tensor bounds (this may mean the epsilon is too small)' \
        f'\n max sampled {(cached * rrng).max().item()}, rounded {(cached * rrng).max().long().item()} max lower limit {lower.max().item()}' \
        f'\n sum {((cached * rrng).max() + lower.max()).item()}' \
        f'\n rounds to {((cached * rrng).max() + lower.max()).long().item()}'

    # Combine all indices sampled within a chunk.
    sampled = torch.cat([neighbor_ints, global_ints, local_ints], dim=-2)
    fsize = pref[:-1] + (-1, rank)

    return sampled.view(*fsize)