def wasserstein(p, u_values, v_values, u_weights=None, v_weights=None):
    r"""
    Compute the p-Wasserstein distance between two one-dimensional
    distributions :math:`u` and :math:`v`.

    This implementation is an adaptation of the scipy implementation of
    `scipy.stats._cdf_distance` for torch tensors.

    Parameters
    ----------
    u_values, v_values : torch tensors of shape (ns,) and (nt,)
        Values observed in the (empirical) distributions.
    u_weights, v_weights : array_like, optional
        Weight for each value. If unspecified, each value is assigned the
        same weight. `u_weights` (resp. `v_weights`) must have the same
        length as `u_values` (resp. `v_values`). If the weight sum differs
        from 1, it must still be positive and finite so that the weights can
        be normalized to sum to 1.

    Returns
    -------
    distance : float
        The computed distance between the distributions.
    """
    u_sorter = torch.argsort(u_values)
    v_sorter = torch.argsort(v_values)

    all_values = torch.cat((u_values, v_values))
    all_values, _ = torch.sort(all_values)

    # Compute the differences between pairs of successive values of u and v.
    deltas = torch_diff(all_values)

    # Get the respective positions of the values of u and v among the values
    # of both distributions.
    u_cdf_indices = searchsorted(u_values[u_sorter], all_values[:-1], side='right')
    v_cdf_indices = searchsorted(v_values[v_sorter], all_values[:-1], side='right')

    # Calculate the CDFs of u and v using their weights, if specified.
    if u_weights is None:
        u_cdf = u_cdf_indices.float() / u_values.shape[0]
    else:
        u_sorted_cumweights = torch.cat(
            (torch.zeros(1, device=u_values.device),
             torch.cumsum(u_weights[u_sorter], dim=0)))
        u_cdf = u_sorted_cumweights[u_cdf_indices] / u_sorted_cumweights[-1]

    if v_weights is None:
        v_cdf = v_cdf_indices.float() / v_values.shape[0]
    else:
        v_sorted_cumweights = torch.cat(
            (torch.zeros(1, device=v_values.device),
             torch.cumsum(v_weights[v_sorter], dim=0)))
        v_cdf = v_sorted_cumweights[v_cdf_indices] / v_sorted_cumweights[-1]

    # Compute the value of the integral based on the CDFs, following
    # scipy.stats._cdf_distance: integrate |U - V|^p against the value
    # differences, then take the p-th root.
    return torch.sum(torch.mul(torch.pow(torch.abs(u_cdf - v_cdf), p),
                               deltas)) ** (1.0 / p)
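# A minimal usage sketch for `wasserstein` above (hypothetical demo, not part
# of the original API). It assumes `torch` is imported and that the
# `searchsorted` and `torch_diff` helpers used by `wasserstein` accept 1-D
# tensors, as the function itself does. For two samples differing only by a
# location shift, W_1 should come out roughly equal to the shift.
def _demo_wasserstein():
    torch.manual_seed(0)
    u = torch.randn(512)
    v = torch.randn(512) + 0.5
    print('W_1:', wasserstein(1, u, v).item())  # roughly 0.5
    print('W_2:', wasserstein(2, u, v).item())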
def ks2(data1, data2):
    """Two-sample Kolmogorov-Smirnov statistic for row-wise batched data."""
    n1 = data1.shape[1]
    n2 = data2.shape[1]
    data1 = data1.sort()[0]
    data2 = data2.sort()[0]
    data_all = torch.cat([data1, data2], dim=1)
    cdf1 = searchsorted(data1, data_all, side='right') / (1.0 * n1)
    cdf2 = searchsorted(data2, data_all, side='right') / (1.0 * n2)
    d = (cdf1 - cdf2).abs().max()
    return d
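# A minimal sketch of `ks2` (hypothetical demo): compare one row of
# standard-normal samples against a shifted row; the statistic grows with
# the shift. Assumes `torch` and the `searchsorted` used above are in scope.
def _demo_ks2():
    torch.manual_seed(0)
    x = torch.randn(1, 1000)
    y = torch.randn(1, 1000) + 1.0
    print('KS statistic:', ks2(x, y).item())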
def energy_spectrum(vel):
    """
    Compute energy spectrum given a velocity field.

    :param vel: tensor of shape (N, 3, res, res, res)
    :return spec: tensor of shape (N, res/2)
    :return k: tensor of shape (res/2,), frequencies corresponding to spec
    """
    device = vel.device
    res = vel.shape[-2:]
    assert res[0] == res[1]
    r = res[0]
    k_end = int(r / 2)
    vel_ = pad_rfft3(vel, onesided=False)  # (N, 3, res, res, res, 2)
    uu_ = (torch.norm(vel_, dim=-1) / r**3)**2
    e_ = torch.sum(uu_, dim=1)  # (N, res, res, res)
    k = fftfreqs(res).to(device)  # (3, res, res, res)
    rad = torch.norm(k, dim=0)  # (res, res, res)
    k_bin = torch.arange(k_end, device=device).float() + 1
    bins = torch.zeros(k_end + 1).to(device)
    bins[1:-1] = (k_bin[1:] + k_bin[:-1]) / 2
    bins[-1] = k_bin[-1]
    # nudge the upper bin edges before adding the batch dimension (doing it
    # after unsqueeze(0), as originally written, addressed the empty slice
    # bins[1:] along the batch dimension and was a no-op)
    bins[1:] += 1e-3
    bins = bins.unsqueeze(0)
    inds = searchsorted(bins, rad.flatten().unsqueeze(0)).squeeze().int()
    # bincount = torch.histc(inds.cpu(), bins=bins.shape[1]+1).to(device)
    bincount = torch.bincount(inds)
    asort = torch.argsort(inds.squeeze())
    sorted_e_ = e_.view(e_.shape[0], -1)[:, asort]
    csum_e_ = torch.cumsum(sorted_e_, dim=1)
    binloc = torch.cumsum(bincount, dim=0).long() - 1
    spec_ = csum_e_[:, binloc[1:]] - csum_e_[:, binloc[:-1]]
    spec_ = spec_[:, :-1]
    spec_ = spec_ * 2 * np.pi * (k_bin.float()**2) / bincount[1:-1].float()
    return spec_, k_bin
def resample_pdf(probs, samples, n_samples=32, deterministic=False):
    sampled_depth, sampled_idx, sampled_dists = samples

    # compute CDF
    pdf = probs / (probs.sum(-1, keepdim=True) + 1e-7)
    cdf = torch.cat([torch.zeros_like(pdf[..., :1]), torch.cumsum(pdf, -1)], -1)

    # generate random samples
    z = torch.arange(n_samples, device=cdf.device, dtype=cdf.dtype).expand(
        cdf.size(0), n_samples).contiguous()
    if deterministic:
        z = z + 0.5
    else:
        z = z + z.clone().uniform_()
    z = z / float(n_samples)

    # inverse transform sampling
    inds = searchsorted(cdf, z) - 1
    inds_miss = inds.eq(sampled_idx.size(1))
    inds_safe = inds.clamp(max=sampled_idx.size(1) - 1)
    resampled_below = cdf.gather(1, inds_safe)
    resampled_above = cdf.gather(1, inds_safe + 1)
    resampled_idx = sampled_idx.gather(1, inds_safe).masked_fill(inds_miss, -1)
    resampled_depth = sampled_depth.gather(1, inds_safe).masked_fill(inds_miss, MAX_DEPTH)
    resampled_dists = sampled_dists.gather(1, inds_safe).masked_fill(inds_miss, 0.0)

    # reparameterization
    resampled_depth = ((z - resampled_below) /
                       (resampled_above - resampled_below + 1e-7) - 0.5) \
                      * resampled_dists + resampled_depth
    # return the distances, not a duplicate of the depths
    return resampled_depth, resampled_idx, resampled_dists
def sample_pdf(bins, weights, N_samples, det=False):
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    # Invert CDF
    u = u.contiguous()
    inds = searchsorted(cdf, u, side='right')
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
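# A minimal sketch of inverse-transform sampling with `sample_pdf` above
# (hypothetical demo): 64 bins on [0, 1] with weights peaked at 0.5, so the
# resampled points should cluster around 0.5.
def _demo_sample_pdf():
    bins = torch.linspace(0., 1., steps=65).expand(4, 65)      # bin edges
    weights = torch.exp(-((torch.linspace(0., 1., steps=64) - 0.5) ** 2) / 0.01)
    weights = weights.expand(4, 64)                            # peaked pdf
    samples = sample_pdf(bins, weights, N_samples=128, det=True)
    print(samples.shape, samples.mean().item())  # (4, 128), mean near 0.5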
def test_searchsorted_output_dtype(device):
    B = 100
    A = 50
    V = 12

    a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
    v = torch.rand(B, A, device=device)

    out = searchsorted(a, v)
    out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
    assert out.dtype == torch.long
    np.testing.assert_array_equal(out.cpu().numpy(), out_np)

    out = torch.empty(v.shape, dtype=torch.long, device=device)
    searchsorted(a, v, out)
    assert out.dtype == torch.long
    np.testing.assert_array_equal(out.cpu().numpy(), out_np)
def compute_pad_amounts(self, batch_index, batch_size):
    """Compute padding needed to form dense minibatch."""
    helper_index = torch.arange(batch_size + 1, device=batch_index.device)
    helper_index = helper_index.unsqueeze(0).contiguous().int()
    batch_index = batch_index.unsqueeze(0).contiguous().int()
    start_index = searchsorted(batch_index, helper_index).squeeze(0)
    batch_count = start_index[1:] - start_index[:-1]
    pad = list((batch_count.max() - batch_count).cpu().numpy())
    batch_count = list(batch_count.cpu().numpy())
    return batch_count, pad
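# A minimal sketch for `compute_pad_amounts` (hypothetical demo): `self` is
# unused, so the method can be exercised directly. A sorted batch index of
# [0, 0, 0, 1, 1, 2] with batch_size 3 yields per-example counts [3, 2, 1]
# and padding [0, 1, 2] to reach the largest count.
def _demo_compute_pad_amounts():
    batch_index = torch.tensor([0, 0, 0, 1, 1, 2])
    counts, pad = compute_pad_amounts(None, batch_index, batch_size=3)
    print(counts, pad)  # [3, 2, 1] [0, 1, 2]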
def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
    """
    Sample @N_importance samples from @bins with distribution defined by @weights.

    Inputs:
        bins: (N_rays, N_samples_+1) where N_samples_ is "the number of
              coarse samples per ray - 2"
        weights: (N_rays, N_samples_)
        N_importance: the number of samples to draw from the distribution
        det: deterministic or not
        eps: a small number to prevent division by zero

    Outputs:
        samples: the sampled samples
    """
    N_rays, N_samples_ = weights.shape
    weights = weights + eps  # prevent division by zero (don't do inplace op!)
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # (N_rays, N_samples_)
    cdf = torch.cumsum(pdf, -1)  # (N_rays, N_samples_), cumulative distribution function
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)
    # (N_rays, N_samples_+1), padded to 0~1 inclusive

    if det:
        u = torch.linspace(0, 1, N_importance, device=bins.device)
        u = u.expand(N_rays, N_importance)
    else:
        u = torch.rand(N_rays, N_importance, device=bins.device)
    u = u.contiguous()

    inds = searchsorted(cdf, u, side='right')
    below = torch.clamp_min(inds - 1, 0)
    above = torch.clamp_max(inds, N_samples_)

    inds_sampled = torch.stack([below, above], -1).view(N_rays, 2 * N_importance)
    cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
    bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    # denom equal to 0 means a bin has weight 0, in which case it will not be
    # sampled anyway, therefore any value for it is fine (set to 1 here)
    denom[denom < eps] = 1

    samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0])
    return samples
def test_searchsorted_correct(Ba, Bv, A, V, side, device):
    if Ba > 1 and Bv > 1 and Ba != Bv:
        return
    for test in range(nrepeat):
        a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
        v = torch.rand(Bv, V, device=device)
        out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(), side=side)
        out = searchsorted(a, v, side=side).cpu().numpy()
        np.testing.assert_array_equal(out, out_np)
def sample_pdf(bins, weights, args):
    """Hierarchical sampling."""
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat(
        [torch.zeros_like(cdf[..., :1], device=args.default_device), cdf],
        -1)  # (batch, len(bins))

    # Take uniform samples
    u = torch.linspace(0., 1., steps=args.number_fine_samples,
                       device=args.default_device)
    u = u.expand(list(cdf.shape[:-1]) + [args.number_fine_samples])

    # Invert CDF
    u = u.contiguous()
    # inds = torch.searchsorted(cdf, u, right=True)
    inds = searchsorted(cdf, u, side='right')
    below = torch.max(torch.zeros_like(inds - 1, device=args.default_device),
                      inds - 1)
    above = torch.min(
        (cdf.shape[-1] - 1) * torch.ones_like(inds, device=args.default_device),
        inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5,
                        torch.ones_like(denom, device=args.default_device),
                        denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # Invert CDF
    u = u.contiguous()
    inds = searchsorted(cdf, u, side='right')
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
def sample_pdf_2(bins, weights, num_samples, det=False):
    r"""sample_pdf function from another concurrent pytorch implementation
    by yenchenlin (https://github.com/yenchenlin/nerf-pytorch).
    """
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # (batchsize, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0.0, 1.0, steps=num_samples,
                           dtype=weights.dtype, device=weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [num_samples])
    else:
        u = torch.rand(
            list(cdf.shape[:-1]) + [num_samples],
            dtype=weights.dtype,
            device=weights.device,
        )

    # Invert CDF
    u = u.contiguous()
    cdf = cdf.contiguous()
    inds = torchsearchsorted.searchsorted(cdf, u, side="right")
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack((below, above), dim=-1)  # (batchsize, num_samples, 2)

    matched_shape = (inds_g.shape[0], inds_g.shape[1], cdf.shape[-1])
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples
def forward(self, decoys: DecoyBatch):
    # separation = decoys.receivers - graphs.senders - 1
    # separation_cls = searchsorted(separation, self.bins, side='right') - 1
    separation = (decoys.senders - decoys.receivers + 1).float().unsqueeze_(0)
    separation_cls = (self.bins.numel() - 1) - searchsorted(
        self.bins, separation).squeeze_(0).long()
    separation_onehot = torch.zeros(decoys.num_edges,
                                    self.bins.numel(),
                                    device=decoys.senders.device)
    separation_onehot.scatter_(value=1.,
                               index=separation_cls.unsqueeze_(1),
                               dim=1)
    decoys = decoys.evolve(edge_features=torch.cat(
        (decoys.edge_features, separation_onehot), dim=1))
    return decoys
def __call__(self, batch):
    """Compute the (generalized) sliced Wasserstein distance between the
    object dataset and the provided batch.

    batch: torch.Tensor (num_samples, ) + sample_shape
        the batch of samples for which to compute the GSW distance to the
        dataset."""
    # update the target if required
    if (self.target_percentiles is None) or not self.manual_refresh:
        self.refresh()

    # bring the target percentiles to the batch device (if not done already)
    self.target_percentiles = [
        t.to(batch.device) for t in self.target_percentiles
    ]

    # if the batch is too small, we may have to reduce the number of
    # percentiles
    num_percentiles = min(batch.shape[0], self.num_percentiles)
    if num_percentiles != self.num_percentiles:
        indices = searchsorted(
            self.percentiles[None, :],
            torch.linspace(0, 100, num_percentiles)[None, :]).long()
    else:
        indices = Ellipsis
    percentiles = self.percentiles[indices].squeeze()

    loss = torch.tensor(0., device=batch.device)
    for (projector_id, target_percentiles) in zip(self.projector_ids,
                                                  self.target_percentiles):
        # get the projector
        projector = self.projectors[projector_id]
        test_percentiles = sketch(projector, batch, percentiles)
        loss = loss + torch.nn.MSELoss()(
            target_percentiles[indices].squeeze(),
            test_percentiles.to(batch.device))
    loss = loss / self.batchsize
    return loss
def sample_pdf_torch(bins, weights, num_samples, det=False):
    # TESTED (carefully, line-to-line), but chances of bugs persist; hasn't
    # been integration-tested with training routines.

    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / weights.sum(-1).unsqueeze(-1)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat((torch.zeros_like(cdf[..., :1]), cdf), -1)

    # Take uniform samples
    if det:
        u = torch.linspace(0.0, 1.0, num_samples).to(weights)
        u = u.expand(list(cdf.shape[:-1]) + [num_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [num_samples]).to(weights)

    # Invert CDF
    inds = torchsearchsorted.searchsorted(cdf.contiguous(), u.contiguous(),
                                          side="right")
    below = torch.max(torch.zeros_like(inds), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack((below, above), -1)

    cdf_g = gather_cdf_util_torch(cdf, inds_g)
    bins_g = gather_cdf_util_torch(bins, inds_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples
          + ': Training set loss = ' + str(int(loss_data*1e5)/1e5))

# ---------------------------------------------------------------------------
# spectrum 1

# extract model
spec_1 = rest_spec_model_1.spec
# RV_pred_1 = rv_model_1.rv
# RV_pred_1, RV_pred_2 = rv_model.forward()
RV_pred_1 = rv_model_1.forward()
RV_pred_2 = rv_model_2.forward()

# RV shift
doppler_shift = torch.sqrt((1 - RV_pred_1/const_c)/(1 + RV_pred_1/const_c))
new_wavelength = torch.t(torch.ger(wave, doppler_shift)).contiguous()  # torch.ger = outer product
ind = searchsorted(wave_cat, new_wavelength).type(torch.LongTensor)

# fix border indexing problem
ind[ind == num_pixel - 1] = num_pixel - 2

# calculate adjacent gradient
slopes = (spec_1[1:] - spec_1[:-1])/(wave[1:] - wave[:-1])

# linear interpolation
spec_shifted_recovered_1 = spec_1[ind] + slopes[ind]*(new_wavelength - wave[ind])

# ---------------------------------------------------------------------------
# spectrum 2
spec_2 = rest_spec_model_2.spec
# RV_pred_2 = rv_model_2.rv
doppler_shift = torch.sqrt((1 - RV_pred_2/const_c)/(1 + RV_pred_2/const_c))
def histogramdd(sample, bins=None, ranges=None, weights=None, edges=None,
                device=None):
    custom_edges = False
    D = sample.size(0)
    if device is None:
        device = sample.device
    if bins is None:
        if edges is None:
            bins = 10
            custom_edges = False
        else:
            try:
                bins = edges.size(1) - 1
            except AttributeError:
                bins = torch.empty(D)
                for i in range(len(edges)):
                    bins[i] = edges[i].size(0) - 1
                bins = bins.to(device)
            custom_edges = True
    try:
        M = bins.size(0)
        if M != D:
            raise ValueError(
                'The dimension of bins must be equal to the dimension of the '
                'sample x.')
    except AttributeError:
        # bins is either an integer or a list
        if isinstance(bins, int):
            bins = torch.full([D], bins, dtype=torch.long, device=device)
        elif torch.is_tensor(bins[0]):
            custom_edges = True
            edges = bins
            bins = torch.empty(D, dtype=torch.long)
            for i in range(len(edges)):
                bins[i] = edges[i].size(0) - 1
            bins = bins.to(device)
        else:
            bins = torch.as_tensor(bins)
    if bins.dim() == 2:
        custom_edges = True
        edges = bins
        bins = torch.full([D], bins.size(1) - 1, dtype=torch.long, device=device)
    if custom_edges:
        if not torch.is_tensor(edges):
            # pad ragged edge lists into one rectangular tensor, repeating
            # the last edge of each dimension
            m = max(i.size(0) for i in edges)
            tmp = torch.empty([D, m])
            for i in range(D):
                s = edges[i].size(0)
                tmp[i, :] = edges[i][-1]
                tmp[i, :s] = edges[i][:]
            edges = tmp.to(device)
        k = searchsorted(edges, sample)
    else:
        if ranges is None:
            ranges = torch.empty(2, D, device=device)
            ranges[0, :] = torch.min(sample, 1)[0]
            ranges[1, :] = torch.max(sample, 1)[0]
        tranges = torch.empty_like(ranges)
        tranges[1, :] = bins / (ranges[1, :] - ranges[0, :])
        tranges[0, :] = 1 - ranges[0, :] * tranges[1, :]
        k = torch.addcmul(tranges[0, :].reshape(-1, 1), sample,
                          tranges[1, :].reshape(-1, 1)).long()  # get the right index
        k = torch.max(k, torch.tensor(0, device=device))  # underflow bin
        k = torch.min(k, (bins + 1).reshape(-1, 1))  # overflow bin
    multiindex = torch.ones_like(bins)
    multiindex[1:] = torch.cumprod(torch.flip(bins[1:], [0]) + 2, -1).long()
    multiindex = torch.flip(multiindex, [0])
    l = torch.sum(k * multiindex.reshape(-1, 1), 0)
    hist = torch.bincount(l, weights=weights,
                          minlength=(multiindex[0] * (bins[0] + 2)).item())
    hist = hist.reshape(tuple(bins + 2))
    """
    # alternative implementation via searchsorted, kept for reference:
    m, index = l.sort()
    r = torch.arange((bins.size(1) + 1)**bins.size(0), device=device)
    hist = searchsorted(m.reshape(1, -1), r.reshape(1, -1), side='right')
    hist[0, 1:] = hist[0, 1:] - hist[0, :-1]
    hist = hist.reshape(tuple(torch.full([bins.size(0)], bins.size(1) + 1,
                                         dtype=int, device=device)))
    """
    return hist
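# A minimal sketch for `histogramdd` (hypothetical demo): histogram 10000
# standard-normal points in 2-D with 10 bins per dimension. The returned
# grid has bins + 2 entries per dimension because under/overflow bins are
# kept at each end.
def _demo_histogramdd():
    sample = torch.randn(2, 10000)  # (D, N) layout, dimensions first
    hist = histogramdd(sample, bins=10)
    print(hist.shape)  # torch.Size([12, 12])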
test_CPU = None
test_GPU = None
for ntest in range(ntests):
    print("Looking for %dx%d values in %dx%d entries" % (
        nrows_v, nvalues, nrows_a, nsorted_values))

    # generate a matrix with sorted rows
    a = torch.randn(nrows_a, nsorted_values, device='cpu')
    a = torch.sort(a, dim=1)[0]
    # generate a matrix of values to searchsort
    v = torch.randn(nrows_v, nvalues, device='cpu')

    t0 = time.time()
    test_CPU = searchsorted(a, v, test_CPU)
    print('CPU: searchsorted in %0.3fms' % (1000 * (time.time() - t0)))

    if not torch.cuda.is_available():
        print('CUDA is not available on this machine, cannot go further.')
        continue
    else:
        # now do the same on GPU
        a = a.to('cuda')
        v = v.to('cuda')

        # launch searchsorted on those
        t0 = time.time()
        test_GPU = searchsorted(a, v, test_GPU)
        print('GPU: searchsorted in %0.3fms' % (1000 * (time.time() - t0)))
def forward(ctx, x, y, xnew, out=None):
    """
    Linear 1D interpolation on the GPU for Pytorch.

    This function returns interpolated values of a set of 1-D functions at
    the desired query points `xnew`. It works similarly to Matlab™ or scipy
    functions with the `linear` interpolation mode, except that it
    parallelises over any number of desired interpolation problems. The
    code will run on GPU if all the tensors provided are on a cuda device.

    Parameters
    ----------
    x : (N,) or (D, N) Pytorch Tensor
        A 1-D or 2-D tensor of real values.
    y : (N,) or (D, N) Pytorch Tensor
        A 1-D or 2-D tensor of real values. The length of `y` along its
        last dimension must be the same as that of `x`.
    xnew : (P,) or (D, P) Pytorch Tensor
        A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if
        _both_ `x` and `y` are 1-D. Otherwise, its length along the first
        dimension must be the same as that of whichever of `x` and `y` is
        2-D.
    out : Pytorch Tensor, same shape as `xnew`
        Tensor for the output. If None: allocated automatically.
    """
    # checking availability of the searchsorted pytorch module
    if not SEARCHSORTED_AVAILABLE:
        raise Exception(
            'The interp1d function depends on the torchsearchsorted '
            'module, which is not available.\n'
            'You must get it at '
            'https://github.com/aliutkus/torchsearchsorted')

    # making the vectors at least 2D
    is_flat = {}
    require_grad = {}
    v = {}
    device = []
    eps = torch.finfo(y.dtype).eps
    for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items():
        assert len(vec.shape) <= 2, 'interp1d: all inputs must be at most 2-D.'
        if len(vec.shape) == 1:
            v[name] = vec[None, :]
        else:
            v[name] = vec
        is_flat[name] = v[name].shape[0] == 1
        require_grad[name] = vec.requires_grad
        device = list(set(device + [str(vec.device)]))
    assert len(device) == 1, 'All parameters must be on the same device.'
    device = device[0]

    # checking the dimensions
    assert (v['x'].shape[1] == v['y'].shape[1]
            and (v['x'].shape[0] == v['y'].shape[0]
                 or v['x'].shape[0] == 1
                 or v['y'].shape[0] == 1)), (
        "x and y must have the same number of columns, and either the same "
        "number of rows or one of them having only one row.")

    reshaped_xnew = False
    if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1)
            and (v['xnew'].shape[0] > 1)):
        # if there is only one row for both x and y, there is no need to
        # loop over the rows of xnew because they will all have to face the
        # same interpolation problem. We should just stack them together to
        # call interp1d and put them back in place afterwards.
        original_xnew_shape = v['xnew'].shape
        v['xnew'] = v['xnew'].contiguous().view(1, -1)
        reshaped_xnew = True

    # identify the dimensions of output and check if the one provided is ok
    D = max(v['x'].shape[0], v['xnew'].shape[0])
    shape_ynew = (D, v['xnew'].shape[-1])
    if out is not None:
        if out.numel() != shape_ynew[0] * shape_ynew[1]:
            # The output provided is of incorrect shape.
            # Going for a new one
            out = None
        else:
            ynew = out.reshape(shape_ynew)
    if out is None:
        ynew = torch.zeros(*shape_ynew, device=device)

    # moving everything to the desired device in case it was not there
    # already (not handling the case where things do not fit entirely,
    # the user will do it if required.)
    for name in v:
        v[name] = v[name].to(device)

    # calling searchsorted on the x values.
    ind = ynew.long()
    searchsorted(v['x'].contiguous(), v['xnew'].contiguous(), ind)

    # the `-1` is because searchsorted looks for the index where the values
    # must be inserted to preserve order. And we want the index of the
    # preceding value.
    ind -= 1

    # we clamp the index, because the number of intervals is x.shape - 1,
    # and the left neighbour should hence be at most number of intervals
    # - 1, i.e. number of columns in x - 2
    ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1)

    # helper function to select stuff according to the found indices.
    def sel(name):
        if is_flat[name]:
            return v[name].contiguous().view(-1)[ind]
        return torch.gather(v[name], 1, ind)

    # activating gradient storing for everything now
    enable_grad = False
    saved_inputs = []
    for name in ['x', 'y', 'xnew']:
        if require_grad[name]:
            enable_grad = True
            saved_inputs += [v[name]]
        else:
            saved_inputs += [None]

    # assuming x are sorted in dimension 1, computing the slopes for the
    # segments
    is_flat['slopes'] = is_flat['x']
    # now that we have found the indices of the neighbors, we start
    # building the output. Hence, we also start activating gradient
    # tracking.
    with torch.enable_grad() if enable_grad else contextlib.suppress():
        v['slopes'] = ((v['y'][:, 1:] - v['y'][:, :-1]) /
                       (eps + v['x'][:, 1:] - v['x'][:, :-1]))

        # now build the linear interpolation
        ynew = sel('y') + sel('slopes') * (v['xnew'] - sel('x'))

        if reshaped_xnew:
            ynew = ynew.view(original_xnew_shape)

    ctx.save_for_backward(ynew, *saved_inputs)
    return ynew
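# A minimal usage sketch (hypothetical): assuming this `forward` is the
# method of a torch.autograd.Function named `Interp1d` (the class name is an
# assumption, not confirmed by the snippet), it would be invoked through
# `.apply` rather than called directly.
def _demo_interp1d():
    x = torch.linspace(0, 1, 100)[None, :]  # (1, N) sorted abscissae
    y = torch.sin(2 * np.pi * x)            # (1, N) ordinates
    xnew = torch.rand(1, 30)                # (1, P) query points
    ynew = Interp1d.apply(x, y, xnew)       # `Interp1d` assumed defined
    print(ynew.shape)  # torch.Size([1, 30])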
side = 'right'

# generate a matrix with sorted rows
a = torch.randn(nrows_a, nsorted_values, device='cpu')
a = torch.sort(a, dim=1)[0]
# generate a matrix of values to searchsort
v = torch.randn(nrows_v, nvalues, device='cpu')

# a = torch.tensor([[0., 1.]])
# v = torch.tensor([[1.]])

t0 = time.time()
test_NP = torch.tensor(numpy_searchsorted(a, v, side))
print('NUMPY: searchsorted in %0.3fms' % (1000 * (time.time() - t0)))

t0 = time.time()
test_CPU = searchsorted(a, v, test_CPU, side)
print('CPU: searchsorted in %0.3fms' % (1000 * (time.time() - t0)))

# compute the difference between both
error_CPU = torch.norm(test_NP.double() - test_CPU.double()).numpy()
if error_CPU:
    import ipdb
    ipdb.set_trace()
print('    difference between CPU and NUMPY: %0.3f' % error_CPU)

if not torch.cuda.is_available():
    print('CUDA is not available on this machine, cannot go further.')
    continue
else:
    # now move the data to the GPU
    a = a.to('cuda')
    v = v.to('cuda')
def searchsorted_synchronized(a, v, out=None, side='left'):
    """Call searchsorted, then synchronize CUDA so that wall-clock timings
    measure kernel execution rather than just the asynchronous launch."""
    out = searchsorted(a, v, out, side)
    torch.cuda.synchronize()
    return out
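# A minimal sketch (hypothetical demo, requires a CUDA device): time the
# synchronized wrapper so that the measurement covers the whole kernel, not
# just its launch. Tensor sizes here are arbitrary.
def _demo_benchmark_gpu():
    import time
    a = torch.sort(torch.randn(5000, 300, device='cuda'), dim=1)[0]
    v = torch.randn(5000, 100, device='cuda')
    t0 = time.time()
    out = searchsorted_synchronized(a, v)
    print('GPU (synchronized): %0.3fms' % (1000 * (time.time() - t0)))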