def get_tile_params_at_coord(self, plocs: torch.Tensor):
    """Return the parameters of the tiles that contain each of the locations in plocs."""
    assert len(plocs.shape) == 2 and plocs.shape[1] == 2
    assert plocs.device == self.locs.device
    n_total = len(plocs)
    slen = self.n_tiles_h * self.tile_slen
    wlen = self.n_tiles_w * self.tile_slen
    # coordinates on tiles.
    x_coords = torch.arange(0, slen, self.tile_slen, device=self.locs.device).long()
    y_coords = torch.arange(0, wlen, self.tile_slen, device=self.locs.device).long()

    x_indx = torch.searchsorted(x_coords.contiguous(), plocs[:, 0].contiguous()) - 1
    y_indx = torch.searchsorted(y_coords.contiguous(), plocs[:, 1].contiguous()) - 1

    return {
        k: v[:, x_indx, y_indx, :, :].reshape(n_total, -1) for k, v in self.items()
    }
def _cdf_distance(u_values, v_values):
    u_sorter = torch.argsort(u_values)
    v_sorter = torch.argsort(v_values)

    all_values = torch.cat((u_values, v_values))
    all_values, _ = torch.sort(all_values)

    # Compute the differences between pairs of successive values of u and v.
    deltas = torch_diff(all_values)

    # Get the respective positions of the values of u and v among the values of
    # both distributions.
    u_cdf_indices = torch.searchsorted(u_values[u_sorter], all_values[:-1], right=True)
    v_cdf_indices = torch.searchsorted(v_values[v_sorter], all_values[:-1], right=True)

    # Calculate the CDFs of u and v using their weights, if specified.
    u_cdf = u_cdf_indices / u_values.shape[0]
    v_cdf = v_cdf_indices / v_values.shape[0]

    return torch.sum(torch.mul(torch.abs(u_cdf - v_cdf), deltas))
def _torch_cdf_distance(tensor_a, tensor_b):
    """
    Torch implementation of _cdf_distance for the Wasserstein distance.

    input: tensor_a, tensor_b
    output: cdf_loss, the computed distance between the tensors.

    # Note: this function differs from the SciPy result by approximately 10^-9.
    Updated for batch support | 29/03/2022
    Updated for multivariate time series support | 29/03/2022

    Expects tensor_a and tensor_b to be of shape (batch_size, segment_length, n_features);
    for example, a single batch of 10 time series of length 12 should have shape=(1, 12, 10).
    """
    batch_size = tensor_a.shape[0]
    assert tensor_a.shape == tensor_b.shape, 'tensor_a and tensor_b have different shapes'

    # It is necessary to reshape the tensors to match the dimensions of Scipy.
    tensor_a = torch.reshape(torch.swapaxes(
        tensor_a, -1, -2), (batch_size, tensor_a.shape[2], tensor_a.shape[1]))
    tensor_b = torch.reshape(torch.swapaxes(
        tensor_b, -1, -2), (batch_size, tensor_b.shape[2], tensor_b.shape[1]))

    # Create sorters:
    sorter_a = torch.argsort(tensor_a, dim=-1)
    sorter_b = torch.argsort(tensor_b, dim=-1)

    # We append both tensors and sort them
    all_values = torch.cat((tensor_a, tensor_b), dim=-1)
    all_values, idx = torch.sort(all_values, dim=-1)

    # Calculate the n-th discrete difference along the given axis (equivalent to np.diff())
    deltas = all_values[:, :, 1:] - all_values[:, :, :-1]

    sorted_a, idx = torch.sort(tensor_a, dim=-1)
    sorted_b, idx = torch.sort(tensor_b, dim=-1)

    # Get the respective positions of the values of u and v among the values of
    # both distributions.
    # TODO: torch.searchsorted() expects contiguous arrays; passing non-contiguous arrays
    # slows performance due to a data copy. A fix doesn't seem trivial.
    a_cdf_index = torch.searchsorted(sorted_a.flatten(start_dim=2), all_values[:, :, :-1], right=True)
    b_cdf_index = torch.searchsorted(sorted_b.flatten(start_dim=2), all_values[:, :, :-1], right=True)

    # Compute the cdf
    a_cdf = a_cdf_index / tensor_a.shape[-1]
    b_cdf = b_cdf_index / tensor_b.shape[-1]

    # And the distance between them
    cdf_distance = torch.sum(torch.mul(torch.abs((a_cdf - b_cdf)), deltas), dim=-1)
    cdf_loss = cdf_distance.mean()
    return cdf_loss
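# Usage sketch (not from the original source) for the batched routine above, assuming
# _torch_cdf_distance is in scope; per-feature results are cross-checked against
# scipy.stats.wasserstein_distance, which this function is meant to mirror.
import numpy as np
import torch
from scipy.stats import wasserstein_distance

# Hypothetical data: one batch of 12-step series with 10 features,
# shape (batch_size, segment_length, n_features).
a = torch.rand(1, 12, 10)
b = torch.rand(1, 12, 10)

loss = _torch_cdf_distance(a, b)  # scalar: mean over batch and features

# Reference: average the per-feature 1D Wasserstein distances with SciPy.
ref = np.mean([
    wasserstein_distance(a[0, :, f].numpy(), b[0, :, f].numpy())
    for f in range(a.shape[-1])
])
print(float(loss), ref)  # should agree up to floating-point error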
def test_searchsorted_cpu(self):
    for i in range(1, 3):
        s = np.sort(np.random.rand(*((10, ) * i)), -1)
        v = np.random.rand(*((10, ) * i))
        s_jt = jt.array(s)
        v_jt = jt.array(v)
        s_tc = torch.from_numpy(s)
        v_tc = torch.from_numpy(v)

        y_tc = torch.searchsorted(s_tc, v_tc, right=True)
        y_jt = jt.searchsorted(s_jt, v_jt, right=True)
        assert np.allclose(y_jt.numpy(), y_tc.data)

        y_jt = jt.searchsorted(s_jt, v_jt, right=False)
        y_tc = torch.searchsorted(s_tc, v_tc, right=False)
        assert np.allclose(y_jt.numpy(), y_tc.data)
def _torch_cdf_distance(tensor_a, tensor_b):
    """
    Torch implementation of _cdf_distance for the Wasserstein distance.

    input: tensor_a, tensor_b
    output: cdf_loss, the computed distance between the tensors.

    # Note: this function differs from the SciPy result by approximately 10^-9.
    """
    # It is necessary to reshape the tensors to match the dimensions of Scipy.
    tensor_a = torch.reshape(tensor_a, (1, tensor_a.shape[0]))
    tensor_b = torch.reshape(tensor_b, (1, tensor_b.shape[0]))

    # Create sorters:
    sorter_a = torch.argsort(tensor_a)
    sorter_b = torch.argsort(tensor_b)

    # We append both tensors and sort them
    all_values = torch.cat((tensor_a, tensor_b), dim=1)
    all_values, idx = torch.sort(all_values, dim=1)

    # Calculate the n-th discrete difference along the given axis (equivalent to np.diff())
    deltas = all_values[0, 1:] - all_values[0, :-1]

    sorted_a, idx = torch.sort(tensor_a, dim=1)
    sorted_b, idx = torch.sort(tensor_b, dim=1)

    # Get the respective positions of the values of u and v among the values of
    # both distributions.
    a_cdf_index = torch.searchsorted(sorted_a.flatten(), all_values[0, :-1], right=True)
    b_cdf_index = torch.searchsorted(sorted_b.flatten(), all_values[0, :-1], right=True)

    # Compute the cdf
    a_cdf = a_cdf_index / tensor_a.shape[1]
    b_cdf = b_cdf_index / tensor_b.shape[1]

    # And the distance between them
    cdf_distance = torch.sum(torch.mul(torch.abs((a_cdf - b_cdf)), deltas), dim=-1)
    cdf_loss = cdf_distance.mean()
    return cdf_loss
def sample(self, labels, index_positive, optimizer):
    self.step += 1
    positive = torch.unique(labels[index_positive], sorted=True).cuda()
    if self.num_sample - positive.size(0) >= 0:
        perm = torch.rand(size=[self.num_local]).cuda()
        perm[positive] = 2.0
        index = torch.topk(perm, k=self.num_sample)[1].cuda()
        index = index.sort()[0].cuda()
    else:
        index = positive
    self.weight_index = index

    labels[index_positive] = torch.searchsorted(index, labels[index_positive])

    self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
    self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index]
    self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[self.weight_index]

    if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
        # TODO the params of partial fc must be last in the params list
        optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
        optimizer.param_groups[-1]["params"][0] = self.weight_activated
        optimizer.state[self.weight_activated]["exp_avg"] = self.weight_activated_exp_avg
        optimizer.state[self.weight_activated]["exp_avg_sq"] = self.weight_activated_exp_avg_sq
        optimizer.state[self.weight_activated]["step"] = self.step
    else:
        raise
def forward(self, x):
    # x: (ndata, ndim) 2d array
    xx, yy, delta = self._prepare()  # (ndim, nknot)
    index = torch.searchsorted(xx.detach(), x.T.contiguous().detach()).T
    y = torch.zeros_like(x)
    logderiv = torch.zeros_like(x)

    # linear extrapolation
    select0 = index == 0
    dim = torch.repeat_interleave(torch.arange(self.ndim).view(1, -1), len(x), dim=0)[select0]
    y[select0] = yy[dim, 0] + (x[select0] - xx[dim, 0]) * delta[dim, 0]
    logderiv[select0] = self.logderiv[dim, 0]

    selectn = index == self.nknot
    dim = torch.repeat_interleave(torch.arange(self.ndim).view(1, -1), len(x), dim=0)[selectn]
    y[selectn] = yy[dim, -1] + (x[selectn] - xx[dim, -1]) * delta[dim, -1]
    logderiv[selectn] = self.logderiv[dim, -1]

    # rational quadratic spline
    select = ~(select0 | selectn)
    index = index[select]
    dim = torch.repeat_interleave(torch.arange(self.ndim).view(1, -1), len(x), dim=0)[select]
    xi = (x[select] - xx[dim, index - 1]) / (xx[dim, index] - xx[dim, index - 1])
    s = (yy[dim, index] - yy[dim, index - 1]) / (xx[dim, index] - xx[dim, index - 1])
    xi1_xi = xi * (1 - xi)
    denominator = s + (delta[dim, index] + delta[dim, index - 1] - 2 * s) * xi1_xi
    xi2 = xi ** 2
    y[select] = yy[dim, index - 1] + ((yy[dim, index] - yy[dim, index - 1]) *
                                      (s * xi2 + delta[dim, index - 1] * xi1_xi)) / denominator
    logderiv[select] = 2 * torch.log(s) + torch.log(
        delta[dim, index] * xi2 + 2 * s * xi1_xi + delta[dim, index - 1] * (1 - xi) ** 2
    ) - 2 * torch.log(denominator)

    return y, logderiv
def __getitem__(self, k):
    k = self.remapping[k]
    channels = self.dataset[k]

    if self.prefix and self.only_prefix:
        dur_channel = channels["dur_source"]
        assert dur_channel.sum() >= self.prefix

        token_times = dur_channel.cumsum(dim=-1)
        cut_after = torch.searchsorted(token_times, torch.tensor(self.prefix))

        r = {}
        for channel_name, value in channels.items():
            if isinstance(value, torch.Tensor) and "source" in channel_name:
                # if self.filter_short: assert value.size(0) >= self.prefix
                r[channel_name] = value[:cut_after + 1]
            else:
                r[channel_name] = value
        r["prefix"] = cut_after + 1
    else:
        r = channels

    return r
def sample_pdf(bins, weights, N_samples):
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    u = torch.linspace(0., 1., steps=N_samples)
    u = u.expand(list(cdf.shape[:-1]) + [N_samples])

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf.detach(), u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
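# Usage sketch (not from the original source) for the inverse-CDF sampler above; this
# variant draws u with torch.linspace, so the samples are deterministic. The bins and
# weights below are made-up values; bins needs one more entry than weights so that the
# padded CDF and the bin edges line up.
import torch

bins = torch.linspace(2.0, 6.0, steps=6).unsqueeze(0)    # (1, 6) bin edges along a ray
weights = torch.tensor([[0.1, 0.4, 0.3, 0.15, 0.05]])    # (1, 5) coarse densities

samples = sample_pdf(bins, weights, N_samples=8)         # (1, 8)
print(samples)  # more samples land where the weights are large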
def sample(self, total_label):
    """
    Sample all positive class centers in each rank, and randomly select negative class
    centers to fill a fixed `num_sample`.

    total_label: tensor
        Label after all gather, which crosses all GPUs.
    """
    index_positive = (self.class_start <= total_label) & (
        total_label < self.class_start + self.num_local)
    total_label[~index_positive] = -1
    total_label[index_positive] -= self.class_start

    if int(self.sample_rate) != 1:
        positive = torch.unique(total_label[index_positive], sorted=True)
        if self.num_sample - positive.size(0) >= 0:
            perm = torch.rand(size=[self.num_local], device=self.device)
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1]
            index = index.sort()[0]
        else:
            index = positive
        self.index = index
        total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
        self.sub_weight = Parameter(self.weight[index])
        self.sub_weight_mom = self.weight_mom[index]
def is_neighbor(rowptr, col, a, b):
    # O(log(d_bar)): binary search for b in the sorted neighbor list of a.
    a_neighbs = get_neighbors(rowptr, col, a)
    # Note: the insertion index alone is not a membership test, so also compare the
    # element just left of the insertion point against b.
    idx = torch.searchsorted(a_neighbs, b, right=True)
    if b <= a_neighbs[-1] and idx > 0 and a_neighbs[idx - 1] == b:
        return True
    else:
        return False
def batch_sample(batch, num, baseline=False, deterministic=False):
    """
    sample indices in batch (not repeat)
    Args:
        batch (tensor): the batch index in increasing order
        num (tensor): the num of samples for each batch
        baseline (bool): baseline is to use pure pytorch implementation
        deterministic (bool): whether produce the same result (fix the seed)
    Returns:
        index (tensor): the sampled indices (int64)
    """
    assert batch[-1] + 1 == len(num), "num of batch does not match!"
    assert batch.dtype in [torch.int64], "unsupported data type for `batch`!"
    assert batch.device == num.device, "`batch` and `num` must be on the same device!"
    device = batch.device

    if deterministic:
        torch.manual_seed(0)
    value = torch.rand(batch.shape, device=device)
    if deterministic:
        torch.manual_seed(random.randint(0, 9999))

    start_ind = torch.repeat_interleave(
        torch.searchsorted(batch, torch.arange(len(num), device=device)), num)
    ind = torch.arange(len(start_ind), device=device) - \
        torch.repeat_interleave(torch.cat([torch.tensor([0], device=device),
                                           torch.cumsum(num, dim=0)[:-1]], dim=0), num)
    index_out = batch_sort(value, batch, baseline=baseline)[start_ind + ind]
    return index_out
def sample_ancestral_index(self, log_weights):
    """
    sample ancestral indices
    """
    sample_dim, batch_dim = log_weights.shape
    if self.strategy == 'systematic':
        positions = (self.uniformer.sample((batch_dim, )) + self.spacing) / self.S
        # weights = log_weights.exp()
        normalized_weights = F.softmax(log_weights, 0)
        cumsums = torch.cumsum(normalized_weights.transpose(0, 1), dim=1)
        (normalizers, _) = torch.max(input=cumsums, dim=1, keepdim=True)
        normalized_cumsums = cumsums / normalizers  ## B * S

        ancestral_index = torch.searchsorted(normalized_cumsums, positions)
        assert ancestral_index.shape == (batch_dim, sample_dim), \
            "ERROR! systematic resampling resulted unexpected index shape."
        ancestral_index = ancestral_index.transpose(0, 1)
    elif self.strategy == 'multinomial':
        normalized_weights = F.softmax(log_weights, 0)
        ancestral_index = Categorical(normalized_weights.transpose(0, 1)).sample((sample_dim, ))
    else:
        print("ERROR! unexpected resampling strategy.")
    return ancestral_index
def elliptical_slice(
    self,
    initial_theta: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    x1 = torch.randn_like(initial_theta)

    # we need to draw the new angle t_new
    # start by finding the slices
    rotation_angle, rotation_slices = self.find_rotated_intersections(initial_theta, x1)

    rotation_slices = rotation_slices.reshape(-1, 2)
    rotation_lengths = rotation_slices[:, 1] - rotation_slices[:, 0]

    # now construct the rotation angle
    cum_lengths = torch.cat((
        torch.zeros(1, device=rotation_lengths.device, dtype=rotation_lengths.dtype),
        torch.cumsum(rotation_lengths, 0),
    ))
    random_angle = torch.rand(1) * cum_lengths[-1]
    idx = torch.searchsorted(cum_lengths, random_angle) - 1
    t_new = rotation_slices[idx, 0] + random_angle - cum_lengths[idx] + rotation_angle

    return self._sample_on_slice(t_new, initial_theta, x1)
def energy_spectrum(vel):
    """
    Compute energy spectrum given a velocity field
    :param vel: tensor of shape (N, 3, res, res, res)
    :return spec: tensor of shape (N, res/2)
    :return k: tensor of shape (res/2,), frequencies corresponding to spec
    """
    device = vel.device
    res = vel.shape[-2:]

    assert(res[0] == res[1])
    r = res[0]
    k_end = int(r / 2)
    vel_ = pad_rfft3(vel, onesided=False)  # (N, 3, res, res, res, 2)
    uu_ = (torch.norm(vel_, dim=-1) / r**3)**2
    e_ = torch.sum(uu_, dim=1)  # (N, res, res, res)
    k = fftfreqs(res).to(device)  # (3, res, res, res)
    rad = torch.norm(k, dim=0)  # (res, res, res)
    k_bin = torch.arange(k_end, device=device).float() + 1
    bins = torch.zeros(k_end + 1).to(device)
    bins[1:-1] = (k_bin[1:] + k_bin[:-1]) / 2
    bins[-1] = k_bin[-1]
    bins = bins.unsqueeze(0)
    bins[1:] += 1e-3
    inds = torch.searchsorted(bins, rad.flatten().unsqueeze(0)).squeeze().int()
    # bincount = torch.histc(inds.cpu(), bins=bins.shape[1]+1).to(device)
    bincount = torch.bincount(inds)
    asort = torch.argsort(inds.squeeze())
    sorted_e_ = e_.view(e_.shape[0], -1)[:, asort]
    csum_e_ = torch.cumsum(sorted_e_, dim=1)
    binloc = torch.cumsum(bincount, dim=0).long() - 1
    spec_ = csum_e_[:, binloc[1:]] - csum_e_[:, binloc[:-1]]
    spec_ = spec_[:, :-1]
    spec_ = spec_ * 2 * np.pi * (k_bin.float()**2) / bincount[1:-1].float()
    return spec_, k_bin
def sample_fine(self, rays, weights):
    """
    Weighted stratified (importance) sample
    :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
    :param weights (B, Kc)
    :return (B, Kf-Kfd)
    """
    device = rays.device
    B = rays.shape[0]

    weights = weights.detach() + 1e-5  # Prevent division by zero
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # (B, Kc)
    cdf = torch.cumsum(pdf, -1)  # (B, Kc)
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (B, Kc+1)

    u = torch.rand(B, self.n_fine - self.n_fine_depth,
                   dtype=torch.float32, device=device)  # (B, Kf)
    inds = torch.searchsorted(cdf, u, right=True).float() - 1.0  # (B, Kf)
    inds = torch.clamp_min(inds, 0.0)

    z_steps = (inds + torch.rand_like(inds)) / self.n_coarse  # (B, Kf)

    near, far = rays[:, -2:-1], rays[:, -1:]  # (B, 1)
    if not self.lindisp:  # Use linear sampling in depth space
        z_samp = near * (1 - z_steps) + far * z_steps  # (B, Kf)
    else:  # Use linear sampling in disparity space
        z_samp = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)  # (B, Kf)
    return z_samp
def systematic(w: torch.Tensor, normalized=False, u: Union[torch.Tensor, float] = None) -> torch.Tensor:
    """
    Performs systematic resampling on either a 1D or 2D array.
    Args:
        w: the log weights to use for resampling.
        normalized: whether the weights are normalized or not.
        u: parameter for overriding the sampled index, only used for testing.
    """
    is_1d = w.dim() == 1

    if is_1d:
        w = w.unsqueeze(0)

    shape = (w.shape[0], 1)
    u = u if u is not None else (torch.empty(shape, device=w.device)).uniform_()
    w = normalize(w) if not normalized else w

    n = w.shape[1]
    index_range = torch.arange(n, dtype=u.dtype, device=w.device).unsqueeze(0)

    probs = (index_range + u) / n
    cumsum = w.cumsum(-1)
    cumsum[..., -1] = 1.0

    res = torch.searchsorted(cumsum, probs)

    return res.squeeze(0) if is_1d else res
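# Usage sketch (not from the original source), assuming systematic() above is in scope;
# passing normalized=True lets us feed plain normalized weights directly and skip the
# module's normalize() helper, which this snippet does not define.
import torch

w = torch.tensor([0.1, 0.2, 0.3, 0.4])   # hypothetical normalized particle weights

idx = systematic(w, normalized=True)     # (4,) resampled ancestor indices
print(idx)                               # heavier particles are selected more often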
def sample_cdf(z_vals, weights, det=False):
    bins = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    if det:
        u = torch.linspace(0., 1., config.N_samples)
        u = torch.broadcast_to(u, list(cdf.shape[:-1]) + [config.N_samples])
    else:
        u = torch.rand(cdf.shape[0], config.N_samples)
    u = u.cuda(cdf.device)

    idxs = torch.searchsorted(cdf, u, right=True)
    below = torch.maximum(torch.zeros_like(idxs), idxs - 1).long()
    above = torch.minimum(torch.ones_like(idxs) * (cdf.shape[-1] - 1), idxs).long()
    # idxs_g = torch.stack([below, above], -1)

    cdf_below = torch.gather(cdf, dim=1, index=below)
    cdf_above = torch.gather(cdf, dim=1, index=above)
    bin_below = torch.gather(bins, dim=1, index=below)
    bin_above = torch.gather(bins, dim=1, index=above)

    denom = cdf_above - cdf_below
    denom = torch.clamp(denom, 1e-5, 99999)
    t = (u - cdf_below) / denom
    samples = bin_below + t * (bin_above - bin_below)
    return samples
def sample_pdf(bins, weights, num_samples, det=False):
    # TESTED (Carefully, line-to-line).
    # But chances of bugs persist; haven't integration-tested with
    # training routines.

    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / weights.sum(-1).unsqueeze(-1)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat((torch.zeros_like(cdf[..., :1]), cdf), -1)

    # Take uniform samples
    if det:
        u = torch.linspace(0.0, 1.0, num_samples).to(weights)
        u = u.expand(list(cdf.shape[:-1]) + [num_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [num_samples]).to(weights)

    # Invert CDF
    inds = torch.searchsorted(cdf.contiguous(), u.contiguous(), right=True)
    below = torch.max(torch.zeros_like(inds), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack((below, above), -1)
    orig_inds_shape = inds_g.shape

    cdf_g = gather_cdf_util(cdf, inds_g)
    bins_g = gather_cdf_util(bins, inds_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
def _interp(self, xq, y):
    if self.y_is_given:
        y = self.y
    x = self.x
    nr = x.shape[-1]
    idxr = torch.searchsorted(x.detach(), xq.detach(), right=False)  # (nrq)
    idxr = torch.clamp(idxr, 1, nr - 1)
    idxl = idxr - 1  # (nrq) from (0 to nr-2)

    if torch.numel(xq) > torch.numel(x):
        yl = y[..., :-1]  # (*BY, nr-1)
        xl = x[..., :-1]  # (nr-1)
        dy = y[..., 1:] - yl  # (*BY, nr-1)
        dx = x[..., 1:] - xl  # (nr-1)
        t = (xq - torch.gather(xl, -1, idxl)) / torch.gather(dx, -1, idxl)  # (nrq)
        yq = dy[..., idxl] * t
        yq += yl[..., idxl]
        return yq
    else:
        xl = torch.gather(x, -1, idxl)
        xr = torch.gather(x, -1, idxr)
        yl = y[..., idxl].contiguous()
        yr = y[..., idxr].contiguous()

        dxrl = xr - xl  # (nrq,)
        dyrl = yr - yl  # (nbatch, nrq)
        t = (xq - xl) / dxrl  # (nrq,)
        yq = yl + dyrl * t
        return yq
def __call__(self, x):
    idx = torch.searchsorted(self.xs, x, right=True) - 1
    if idx >= self.n:
        return self.ys[-1]
    elif idx < 0:
        # x falls before the first knot: clamp to the first value
        return self.ys[0]
    else:
        return self.ys[idx] + self.slopes[idx] * (x - self.xs[idx])
def emd1D(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    n = u_values.shape[-1]
    m = v_values.shape[-1]

    device = u_values.device
    dtype = u_values.dtype

    if u_weights is None:
        u_weights = torch.full((n, ), 1 / n, dtype=dtype, device=device)

    if v_weights is None:
        v_weights = torch.full((m, ), 1 / m, dtype=dtype, device=device)

    if require_sort:
        u_values, u_sorter = torch.sort(u_values, -1)
        v_values, v_sorter = torch.sort(v_values, -1)

        u_weights = u_weights[..., u_sorter]
        v_weights = v_weights[..., v_sorter]

    zero = torch.zeros(1, dtype=dtype, device=device)

    u_cdf = torch.cumsum(u_weights, -1)
    v_cdf = torch.cumsum(v_weights, -1)

    cdf_axis, _ = torch.sort(torch.cat((u_cdf, v_cdf), -1), -1)

    u_index = torch.searchsorted(u_cdf, cdf_axis)
    v_index = torch.searchsorted(v_cdf, cdf_axis)

    u_icdf = torch.gather(u_values, -1, u_index.clip(0, n - 1))
    v_icdf = torch.gather(v_values, -1, v_index.clip(0, m - 1))

    cdf_axis = torch.nn.functional.pad(cdf_axis, (1, 0))
    delta = cdf_axis[..., 1:] - cdf_axis[..., :-1]

    if p == 1:
        return torch.sum(delta * torch.abs(u_icdf - v_icdf), axis=-1)
    if p == 2:
        return torch.sum(delta * torch.square(u_icdf - v_icdf), axis=-1)
    return torch.sum(delta * torch.pow(torch.abs(u_icdf - v_icdf), p), axis=-1)
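# Usage sketch (not from the original source) for emd1D above; for p=1 with uniform
# weights the result should match scipy.stats.wasserstein_distance up to float error.
import torch
from scipy.stats import wasserstein_distance

u = torch.randn(100)          # made-up 1D samples
v = torch.randn(120) + 0.5

w1 = emd1D(u, v, p=1)
ref = wasserstein_distance(u.numpy(), v.numpy())
print(float(w1), ref)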
def est_rank(self, user_id, pos_rat, candidate_items, sample_size):
    # (rank - 1)/N = (est_rank - 1)/sample_size
    candidate_rat = self.inference(
        user_id.repeat(candidate_items.shape[1], 1).T, candidate_items)
    sorted_seq, _ = torch.sort(candidate_rat)
    quick_r = torch.searchsorted(sorted_seq, pos_rat.unsqueeze(-1))
    r = ((quick_r) * (self.num_item - 1) / sample_size).floor().long()
    return self._rank_weight_pre[r]
def _searchsorted(sorted_sequence, values):
    # searchsorted is introduced to PyTorch in 1.6.0
    if _TORCH_HAS_SEARCHSORTED:
        return th.searchsorted(sorted_sequence, values)
    else:
        device = values.device
        return th.from_numpy(
            np.searchsorted(sorted_sequence.cpu().numpy(),
                            values.cpu().numpy())).to(device)
def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None):
    side = "right" if right else "left"
    if isinstance(a, np.ndarray):
        return np.searchsorted(a, v, side, sorter)  # type: ignore
    if hasattr(torch, "searchsorted"):
        return torch.searchsorted(a, v, right=right)  # type: ignore
    # if using old PyTorch, will convert to numpy array then compute
    ret = np.searchsorted(a.cpu().numpy(), v.cpu().numpy(), side, sorter)  # type: ignore
    ret, *_ = convert_to_dst_type(ret, a)
    return ret
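# Usage sketch (not from the original source) for the backend-agnostic wrapper above:
# NumPy inputs take the np.searchsorted branch, torch tensors the torch.searchsorted branch.
import numpy as np
import torch

edges_np = np.array([0.0, 1.0, 2.0, 3.0])
edges_th = torch.tensor([0.0, 1.0, 2.0, 3.0])

print(searchsorted(edges_np, np.array([0.5, 2.5])))        # array([1, 3])
print(searchsorted(edges_th, torch.tensor([0.5, 2.5])))    # tensor([1, 3])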
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    # probability distribution function
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)  # cumulative distribution function
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # Invert CDF
    # see zhihu's explanation
    u = u.contiguous()
    # find index of each element in u in cdf (from y axis of CDF to x axis)
    inds = torch.searchsorted(cdf, u, right=True)
    # the upper and lower bound of indices
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    # find correspondence of indices in cdf_g and bins_g
    # (CDF values and coarse sample results)
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    # get delta probability by difference
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    # t: fraction of the interval between the lower and upper bound occupied by u
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
def convert_to_trainId(self, mask):
    """
    https://stackoverflow.com/questions/47171356
    """
    k = torch.tensor(list(self.id_to_trainId.keys()))
    v = torch.tensor(list(self.id_to_trainId.values()))
    sidx = k.argsort()

    ks = k[sidx]
    vs = v[sidx]
    return vs[torch.searchsorted(ks, mask)].to(torch.int64)
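# Standalone sketch (not from the original source) of the same lookup-table trick:
# searchsorted locates each mask value among the sorted keys, then the matching
# values are gathered. The id_to_trainId mapping below is hypothetical.
import torch

id_to_trainId = {7: 0, 8: 1, 11: 2, 26: 3}
k = torch.tensor(list(id_to_trainId.keys()))
v = torch.tensor(list(id_to_trainId.values()))
sidx = k.argsort()
ks, vs = k[sidx], v[sidx]

mask = torch.tensor([[7, 26], [11, 8]])
print(vs[torch.searchsorted(ks, mask)].to(torch.int64))   # tensor([[0, 3], [2, 1]])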
def forward(self, t):
    assert self.segs is not None, f'segments have not been initialized'
    lens = self.segs[1:].norm(2, dim=-1).cumsum(0)
    lens = lens / lens[-1]
    inds = torch.searchsorted(lens, t)
    lens = torch.cat([torch.zeros(1, dtype=lens.dtype), lens])
    extra = t - lens[inds]
    return self.segs.cumsum(0)[inds] + extra.unsqueeze(-1) * \
        F.normalize(self.segs[1:])[inds.clamp(max=self.n_seg - 1)]
def center_distogram_torch(distogram, bins=DISTANCE_THRESHOLDS, min_t=1., center="median", wide="std"):
    """ Returns the central estimate of a distogram. Median for now.
        Inputs:
        * distogram: (N x N x B) where B is the number of buckets.
          supports batched predictions (batch, N, N, B).
        * bins: (B,) containing the cutoffs for the different buckets
        * min_t: float. lower bound for distances.
        TODO: return confidence/weights
    """
    shape = distogram.shape
    # threshold to weights and find mean value of each bin
    # (shift the cutoffs by half a bin width to get bin centers)
    n_bins = bins - 0.5 * (bins[2] - bins[1])
    n_bins[0] = 1.5
    # TODO: adapt so that mean option considers IGNORE_INDEX
    n_bins[-1] = n_bins[-1]
    n_bins = n_bins.to(distogram.device)

    # calculate measures of centrality and dispersion -
    if center == "median":
        cum_dist = torch.cumsum(distogram, dim=-1)
        medium = 0.5 * torch.ones(*cum_dist.shape[:-1], device=cum_dist.device).unsqueeze(dim=-1)
        central = torch.searchsorted(cum_dist, medium).squeeze()
        central = n_bins[torch.minimum(central, torch.tensor(DISTOGRAM_BUCKETS - 1)).long()]
    elif center == "mean":
        central = (distogram * n_bins).sum(dim=-1)

    # create mask for last class - (IGNORE_INDEX)
    mask = (central <= bins[-2].item()).float()
    # mask diagonal to 0 dist
    diag = np.arange(shape[-2])
    if len(central.shape) == 3:
        central[:, diag, diag] = 0.
    else:
        central[diag, diag] = 0.

    # provide weights
    if wide == "var":
        dispersion = (distogram * (n_bins - central.unsqueeze(-1))**2).sum(dim=-1)
    elif wide == "std":
        dispersion = (distogram * (n_bins - central.unsqueeze(-1))**2).sum(dim=-1).sqrt()
    else:
        dispersion = torch.zeros_like(central, device=central.device)

    # rescale to 0-1. lower std / var --> weight = 1
    weights = mask / (1 + dispersion)
    return central, dispersion, weights
def _vector(weights: torch.Tensor, u: torch.Tensor):
    """
    Performs systematic resampling of a 1D array of log weights.
    :param weights: The weights to use for resampling
    :return: Resampled indices
    """
    n = weights.shape[0]
    probs = (torch.arange(n, dtype=u.dtype, device=weights.device) + u) / n

    cumsum = weights.cumsum(0)
    cumsum[..., -1] = 1.

    return torch.searchsorted(cumsum, probs)