Example #1
0
    def get_tile_params_at_coord(self, plocs: torch.Tensor):
        """Return the parameters of the tiles that contain each of the locations in plocs.

        Args:
            plocs: ``(n, 2)`` tensor of coordinates on the same device as
                ``self.locs``. Column 0 is located among the
                ``self.n_tiles_h`` tile edges, column 1 among the
                ``self.n_tiles_w`` tile edges; each tile spans
                ``self.tile_slen``.

        Returns:
            Dict mapping every key ``k`` of ``self`` to
            ``v[:, row, col, :, :]`` for the tile holding each location,
            reshaped to ``(n, -1)``.
        """
        assert len(plocs.shape) == 2 and plocs.shape[1] == 2
        assert plocs.device == self.locs.device
        n_total = len(plocs)
        slen = self.n_tiles_h * self.tile_slen
        wlen = self.n_tiles_w * self.tile_slen
        # Left edges of the tiles along each axis.
        x_coords = torch.arange(0,
                                slen,
                                self.tile_slen,
                                device=self.locs.device).long()
        y_coords = torch.arange(0,
                                wlen,
                                self.tile_slen,
                                device=self.locs.device).long()

        # Tiles are half-open intervals [edge_i, edge_{i+1}). With right=True,
        # searchsorted counts the edges <= coordinate, so subtracting 1 gives
        # the containing tile. The default left search put a coordinate lying
        # exactly on an edge into the previous tile, and mapped 0 to index -1,
        # which silently wrapped around to the last tile.
        x_indx = torch.searchsorted(x_coords.contiguous(),
                                    plocs[:, 0].contiguous(),
                                    right=True) - 1
        y_indx = torch.searchsorted(y_coords.contiguous(),
                                    plocs[:, 1].contiguous(),
                                    right=True) - 1

        return {
            k: v[:, x_indx, y_indx, :, :].reshape(n_total, -1)
            for k, v in self.items()
        }
Example #2
0
def _cdf_distance(u_values, v_values):
    """Return the L1 distance between the empirical CDFs of two 1-D samples.

    Both samples are uniformly weighted. The result is
    ``sum(|F_u(x_i) - F_v(x_i)| * (x_{i+1} - x_i))`` over the sorted pooled
    values ``x``, i.e. the 1-Wasserstein distance between the samples.
    """
    u_sorted = u_values[torch.argsort(u_values)]
    v_sorted = v_values[torch.argsort(v_values)]

    pooled, _ = torch.sort(torch.cat((u_values, v_values)))
    # Width of each interval between consecutive pooled values.
    widths = torch_diff(pooled)
    # Both CDFs are constant on each interval; evaluate them at its left end.
    left_ends = pooled[:-1]

    u_cdf = torch.searchsorted(u_sorted, left_ends, right=True) / u_values.shape[0]
    v_cdf = torch.searchsorted(v_sorted, left_ends, right=True) / v_values.shape[0]

    return torch.sum(torch.mul(torch.abs(u_cdf - v_cdf), widths))
def _torch_cdf_distance(tensor_a, tensor_b):
    """
    Torch implementation of SciPy's ``_cdf_distance`` with p=1 (the
    1-Wasserstein distance), with batch and multivariate support.

    Expects ``tensor_a`` and ``tensor_b`` to share the shape
    ``(batch_size, segment_length, n_features)``. For example, a single batch
    of 10 time series with lengths of 12 has shape ``(1, 12, 10)``.

    The distance is computed independently for every (batch, feature) pair
    over the ``segment_length`` axis. The mean over all pairs is returned as
    a 0-d tensor. CDFs are evaluated in the dtype of the data, so float64
    inputs keep full double precision.

    Raises:
        AssertionError: if the two tensors have different shapes.
    """
    assert tensor_a.shape == tensor_b.shape, 'tensor_a and tensor_b have different shape'

    # Move the time axis last, giving (batch, n_features, segment_length), so
    # that each series is a row and searchsorted batches over the leading dims.
    tensor_a = tensor_a.transpose(-1, -2)
    tensor_b = tensor_b.transpose(-1, -2)
    n_samples = tensor_a.shape[-1]

    # Pool and sort both samples. Consecutive differences are the widths of
    # the intervals on which both empirical CDFs are constant.
    all_values, _ = torch.sort(torch.cat((tensor_a, tensor_b), dim=-1), dim=-1)
    deltas = all_values[..., 1:] - all_values[..., :-1]

    sorted_a, _ = torch.sort(tensor_a, dim=-1)
    sorted_b, _ = torch.sort(tensor_b, dim=-1)

    # One contiguous copy of the query, shared by both searches. Otherwise
    # searchsorted copies the strided slice internally on every call.
    query = all_values[..., :-1].contiguous()
    a_cdf_index = torch.searchsorted(sorted_a, query, right=True)
    b_cdf_index = torch.searchsorted(sorted_b, query, right=True)

    # Empirical CDFs, computed in the data dtype. Integer / int division
    # would otherwise default to float32.
    a_cdf = a_cdf_index.to(deltas.dtype) / n_samples
    b_cdf = b_cdf_index.to(deltas.dtype) / n_samples

    cdf_distance = torch.sum(torch.abs(a_cdf - b_cdf) * deltas, dim=-1)
    return cdf_distance.mean()
    def test_searchsorted_cpu(self):
        """``jt.searchsorted`` must agree with ``torch.searchsorted`` on CPU.

        Uses sorted random 1-D and 2-D inputs (10 entries per dimension) and
        checks both ``right=True`` and ``right=False``.
        """
        for ndim in (1, 2):
            shape = (10, ) * ndim
            boundaries = np.sort(np.random.rand(*shape), -1)
            queries = np.random.rand(*shape)
            bound_jt, query_jt = jt.array(boundaries), jt.array(queries)
            bound_tc, query_tc = torch.from_numpy(boundaries), torch.from_numpy(queries)

            for right in (True, False):
                expected = torch.searchsorted(bound_tc, query_tc, right=right)
                actual = jt.searchsorted(bound_jt, query_jt, right=right)
                assert np.allclose(actual.numpy(), expected.data)
def _torch_cdf_distance(tensor_a, tensor_b):
    """
    Torch implementation of SciPy's ``_cdf_distance`` with p=1 (the
    1-Wasserstein distance) between two 1-D samples.

    input: tensor_a, tensor_b. 1-D samples, possibly of different lengths.
    output: 0-d tensor with the computed distance between the samples.

    CDFs are evaluated in the dtype of the data, so float64 inputs match
    SciPy to double precision.
    """
    # Flatten to 1-D. Like the original ``(1, n)`` reshape, this rejects
    # inputs whose element count differs from their first dimension.
    tensor_a = torch.reshape(tensor_a, (tensor_a.shape[0], ))
    tensor_b = torch.reshape(tensor_b, (tensor_b.shape[0], ))

    # Pool and sort both samples. Consecutive differences are the widths of
    # the intervals on which both empirical CDFs are constant.
    all_values, _ = torch.sort(torch.cat((tensor_a, tensor_b)))
    deltas = all_values[1:] - all_values[:-1]
    query = all_values[:-1]

    sorted_a, _ = torch.sort(tensor_a)
    sorted_b, _ = torch.sort(tensor_b)

    # Empirical CDFs at the left end of each interval, in the data dtype.
    a_cdf = torch.searchsorted(sorted_a, query, right=True).to(deltas.dtype) / tensor_a.shape[0]
    b_cdf = torch.searchsorted(sorted_b, query, right=True).to(deltas.dtype) / tensor_b.shape[0]

    return torch.sum(torch.abs(a_cdf - b_cdf) * deltas)
Example #6
0
    def sample(self, labels, index_positive, optimizer):
        """Sample the class centers for this step and install them in ``optimizer``.

        All positive classes in ``labels[index_positive]`` are kept. If room
        remains, random other classes fill the sample up to
        ``self.num_sample``. Positive labels are remapped in place to their
        row in the sampled index. The sampled rows of ``self.weight`` and of
        its ``exp_avg`` / ``exp_avg_sq`` buffers are installed as the first
        parameter of the optimizer's last param group.

        Raises:
            NotImplementedError: if ``optimizer`` is not Adam/AdamW. This is a
                ``RuntimeError`` subclass. The check runs before ``self`` or
                ``labels`` are modified.
        """
        # Only Adam-style optimizers hold the exp_avg / exp_avg_sq state that
        # is swapped below. Reject anything else before mutating any state.
        if not isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
            raise NotImplementedError(
                "partial fc sampling only supports torch.optim.Adam/AdamW, got "
                f"{type(optimizer).__name__}")

        self.step += 1
        positive = torch.unique(labels[index_positive], sorted=True).cuda()
        if self.num_sample - positive.size(0) >= 0:
            # Positives get score 2.0, above any torch.rand draw in [0, 1),
            # so topk always keeps every one of them.
            perm = torch.rand(size=[self.num_local]).cuda()
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1].cuda()
            index = index.sort()[0].cuda()
        else:
            index = positive
        self.weight_index = index
        # index is sorted, so searchsorted gives each positive label's row in
        # the sampled subset.
        labels[index_positive] = torch.searchsorted(index,
                                                    labels[index_positive])
        self.weight_activated = torch.nn.Parameter(
            self.weight[self.weight_index])
        self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index]
        self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[
            self.weight_index]

        # TODO the params of partial fc must be last in the params list
        optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
        optimizer.param_groups[-1]["params"][0] = self.weight_activated
        optimizer.state[self.weight_activated][
            "exp_avg"] = self.weight_activated_exp_avg
        optimizer.state[self.weight_activated][
            "exp_avg_sq"] = self.weight_activated_exp_avg_sq
        optimizer.state[self.weight_activated]["step"] = self.step
Example #7
0
    def forward(self, x):
        """Apply the per-dimension rational-quadratic spline to ``x``.

        Args:
            x: ``(ndata, ndim)`` 2-D tensor. Column ``d`` is transformed with
                the knots ``xx[d]``, ``yy[d]`` and slopes ``delta[d]``
                returned by ``self._prepare()``.

        Returns:
            ``(y, logderiv)``, both shaped like ``x``. For ``x <= xx[d, 0]``
            and ``x > xx[d, -1]`` the map is linear with the end slope, and
            ``logderiv`` is read from ``self.logderiv``.
        """
        xx, yy, delta = self._prepare() #(ndim, nknot)

        index = torch.searchsorted(xx.detach(), x.T.contiguous().detach()).T
        y = torch.zeros_like(x)
        logderiv = torch.zeros_like(x)

        # dims[i, d] == d is the spline dimension of each element of x. It is
        # built once and on x's device; per-branch CPU copies broke
        # boolean-mask indexing when x lives on CUDA.
        dims = torch.arange(self.ndim, device=x.device).view(1, -1).expand(len(x), -1)

        # linear extrapolation at or below the first knot
        select0 = index == 0
        dim = dims[select0]
        y[select0] = yy[dim, 0] + (x[select0]-xx[dim, 0]) * delta[dim, 0]
        logderiv[select0] = self.logderiv[dim, 0]
        # linear extrapolation above the last knot
        selectn = index == self.nknot
        dim = dims[selectn]
        y[selectn] = yy[dim, -1] + (x[selectn]-xx[dim, -1]) * delta[dim, -1]
        logderiv[selectn] = self.logderiv[dim, -1]

        # rational quadratic spline on interval (xx[index-1], xx[index]]
        select = ~(select0 | selectn)
        index = index[select]
        dim = dims[select]
        xi = (x[select] - xx[dim, index-1]) / (xx[dim, index] - xx[dim, index-1])
        s = (yy[dim, index]-yy[dim, index-1]) / (xx[dim, index]-xx[dim, index-1])
        xi1_xi = xi*(1-xi)
        denominator = s + (delta[dim, index]+delta[dim, index-1]-2*s)*xi1_xi
        xi2 = xi**2

        y[select] = yy[dim, index-1] + ((yy[dim, index]-yy[dim, index-1]) * (s*xi2+delta[dim, index-1]*xi1_xi)) / denominator
        logderiv[select] = 2*torch.log(s) + torch.log(delta[dim, index]*xi2 + 2*s*xi1_xi + delta[dim, index-1]*(1-xi)**2) - 2 * torch.log(denominator)

        return y, logderiv
Example #8
0
    def __getitem__(self, k):
        """Return the channels of item ``k`` (after remapping), optionally truncated.

        When both ``self.prefix`` and ``self.only_prefix`` are truthy, every
        tensor channel whose name contains ``"source"`` is cut just after the
        first token whose cumulative duration (from ``"dur_source"``) reaches
        ``self.prefix``. The kept length is stored under ``"prefix"``.
        Otherwise the dataset entry itself is returned.
        """
        channels = self.dataset[self.remapping[k]]

        if not (self.prefix and self.only_prefix):
            return channels

        durations = channels["dur_source"]
        assert durations.sum() >= self.prefix

        # First token whose cumulative end time is >= prefix (left search).
        cut_after = torch.searchsorted(durations.cumsum(dim=-1),
                                       torch.tensor(self.prefix))
        keep = cut_after + 1

        truncated = {
            name: (value[:keep]
                   if isinstance(value, torch.Tensor) and "source" in name
                   else value)
            for name, value in channels.items()
        }
        truncated["prefix"] = keep
        return truncated
Example #9
0
def sample_pdf(bins, weights, N_samples):
    """Invert the piecewise-linear CDF defined by ``bins`` and ``weights``.

    Args:
        bins: ``(batch, n_bins)`` bin edges.
        weights: ``(batch, n_bins - 1)`` non-negative weight of each bin.
        N_samples: number of evenly spaced quantiles in ``[0, 1]`` to invert.

    Returns:
        ``(batch, N_samples)`` samples, on the device of ``weights``.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf],
                    -1)  # (batch, len(bins))

    # Evenly spaced quantiles, created on cdf's device and dtype. A CPU ``u``
    # cannot be searched against, or combined with, a CUDA ``cdf``.
    u = torch.linspace(0., 1., steps=N_samples, device=cdf.device, dtype=cdf.dtype)
    u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous()

    # Invert CDF: bracket each quantile by the CDF knots around it.
    inds = torch.searchsorted(cdf.detach(), u, right=True)
    below = (inds - 1).clamp(min=0)
    above = inds.clamp(max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    # Avoid dividing by ~0 for (nearly) empty bins.
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
Example #10
0
    def sample(self, total_label):
        """
        Keep every positive class center on this rank and fill up to a fixed
        ``num_sample`` with randomly chosen negative class centers.

        total_label: tensor
            Labels gathered across all GPUs. Modified in place. Labels outside
            ``[class_start, class_start + num_local)`` become -1, and the rest
            are shifted to be local to this shard. When sampling is active
            (``int(sample_rate) != 1``) they are further remapped to positions
            in the sampled index.
        """
        lo = self.class_start
        in_shard = (lo <= total_label) & (total_label < lo + self.num_local)
        total_label[~in_shard] = -1
        total_label[in_shard] -= lo
        if int(self.sample_rate) == 1:
            return

        positive = torch.unique(total_label[in_shard], sorted=True)
        if self.num_sample - positive.size(0) < 0:
            index = positive
        else:
            # Positives get score 2.0, above any torch.rand draw, so topk keeps them all.
            scores = torch.rand(size=[self.num_local], device=self.device)
            scores[positive] = 2.0
            index = torch.topk(scores, k=self.num_sample)[1].sort()[0]
        self.index = index
        total_label[in_shard] = torch.searchsorted(index, total_label[in_shard])
        self.sub_weight = Parameter(self.weight[index])
        self.sub_weight_mom = self.weight_mom[index]
def is_neighbor(rowptr, col, a, b):
    """Return True if node ``b`` is in the neighbour list of node ``a``.

    Uses a binary search, O(log(d_bar)). Assumes ``get_neighbors`` returns a
    1-D tensor sorted ascending, as searchsorted requires. TODO confirm for
    the CSR layout in use.
    """
    a_neighbs = get_neighbors(rowptr, col, a)
    if a_neighbs.numel() == 0:
        # Previously ``a_neighbs[-1]`` raised IndexError here.
        return False
    # Leftmost slot where b could be inserted; b is present iff that slot
    # already holds b. (The old check returned the truthiness of a
    # right-sided insertion index, which is non-zero for any b >= the
    # smallest neighbour, so non-neighbours were reported as neighbours.)
    pos = int(torch.searchsorted(a_neighbs, b))
    return pos < a_neighbs.numel() and bool(a_neighbs[pos] == b)
Example #12
0
def batch_sample(batch, num, baseline=False, deterministic=False):
    """Sample ``num[b]`` distinct indices from every batch segment.

    Args:
        batch (tensor): int64 batch id of every element, in increasing order.
        num (tensor): number of samples to draw for each batch id.
        baseline (bool): use the pure PyTorch implementation (forwarded to
            ``batch_sort``).
        deterministic (bool): draw the random keys under seed 0, then reseed
            torch's global RNG from ``random.randint(0, 9999)``.
    Returns:
        index (tensor): the sampled indices (int64)
    """
    assert batch[-1] + 1 == len(num), "num of batch does not match!"
    assert batch.dtype in [torch.int64], "unsupported data type for `batch`!"
    assert batch.device == num.device, "`batch` and `num` must be on the same device!"

    device = batch.device
    # One random key per element. batch_sort presumably orders elements by key
    # within each segment, giving a random permutation. TODO confirm.
    if deterministic:
        torch.manual_seed(0)
    keys = torch.rand(batch.shape, device=device)
    if deterministic:
        torch.manual_seed(random.randint(0, 9999))

    # Position of each segment's first element, repeated once per sample.
    seg_start = torch.searchsorted(batch, torch.arange(len(num), device=device))
    start_ind = torch.repeat_interleave(seg_start, num)
    # Exclusive prefix sum of num: where each segment's samples begin in the output.
    out_offset = torch.cumsum(num, dim=0) - num
    rank_in_seg = torch.arange(len(start_ind), device=device) - \
        torch.repeat_interleave(out_offset, num)
    return batch_sort(keys, batch, baseline=baseline)[start_ind + rank_in_seg]
Example #13
0
    def sample_ancestral_index(self, log_weights):
        """
        Sample ancestral indices from per-particle log weights.

        Args:
            log_weights: ``(sample_dim, batch_dim)`` tensor of log weights,
                normalised with a softmax over dim 0.

        Returns:
            ``(sample_dim, batch_dim)`` tensor of ancestor indices.

        Raises:
            ValueError: if ``self.strategy`` is neither ``'systematic'`` nor
                ``'multinomial'``. Previously an error was printed and the
                return then failed with UnboundLocalError.
        """
        sample_dim, batch_dim = log_weights.shape
        if self.strategy == 'systematic':
            positions = (self.uniformer.sample(
                (batch_dim, )) + self.spacing) / self.S
            normalized_weights = F.softmax(log_weights, 0)
            cumsums = torch.cumsum(normalized_weights.transpose(0, 1), dim=1)
            # Divide by each row's maximum so every row's CDF ends at exactly 1.
            (normalizers, _) = torch.max(input=cumsums, dim=1, keepdim=True)
            normalized_cumsums = cumsums / normalizers  ## B * S

            ancestral_index = torch.searchsorted(normalized_cumsums, positions)
            assert ancestral_index.shape == (
                batch_dim, sample_dim
            ), "ERROR! systematic resampling resulted unexpected index shape."
            ancestral_index = ancestral_index.transpose(0, 1)

        elif self.strategy == 'multinomial':
            normalized_weights = F.softmax(log_weights, 0)
            ancestral_index = Categorical(normalized_weights.transpose(
                0, 1)).sample((sample_dim, ))
        else:
            raise ValueError(
                f"unexpected resampling strategy: {self.strategy!r}")
        return ancestral_index
    def elliptical_slice(
        self,
        initial_theta: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw an elliptical-slice proposal starting from ``initial_theta``.

        A direction ``x1`` is drawn with ``torch.randn_like``. The admissible
        angle intervals come from ``self.find_rotated_intersections``. A new
        angle is drawn uniformly over the union of those intervals and passed
        to ``self._sample_on_slice``.
        """
        x1 = torch.randn_like(initial_theta)

        # we need to draw the new angle t_new
        # start by finding the slices
        rotation_angle, rotation_slices = self.find_rotated_intersections(
            initial_theta, x1)

        rotation_slices = rotation_slices.reshape(-1, 2)  # (n_slices, [start, end])
        rotation_lengths = rotation_slices[:, 1] - rotation_slices[:, 0]

        # cum_lengths[i] is the total length of the slices before slice i.
        cum_lengths = torch.cat((
            torch.zeros(1,
                        device=rotation_lengths.device,
                        dtype=rotation_lengths.dtype),
            torch.cumsum(rotation_lengths, 0),
        ))
        # Uniform position along the concatenated slices, drawn on
        # cum_lengths' device and dtype (a CPU draw breaks CUDA inputs).
        random_angle = torch.rand(
            1, device=cum_lengths.device,
            dtype=cum_lengths.dtype) * cum_lengths[-1]
        # right=True maps an angle equal to a slice start (including 0) to
        # that slice. The left search returned -1 for 0, which picked the
        # last slice with a negative offset. The clamp guards against the
        # product rounding up to the total length.
        idx = (torch.searchsorted(cum_lengths, random_angle, right=True) -
               1).clamp(0, rotation_slices.shape[0] - 1)
        t_new = rotation_slices[
            idx, 0] + random_angle - cum_lengths[idx] + rotation_angle

        return self._sample_on_slice(t_new, initial_theta, x1)
Example #15
0
def energy_spectrum(vel):
    """
    Compute energy spectrum given a velocity field
    :param vel: tensor of shape (N, 3, res, res, res)
    :return spec: tensor of shape(N, res/2)
    :return k: tensor of shape (res/2,), frequencies corresponding to spec

    Modes are binned by their radial wavenumber |k| into shells centred on
    k = 1, ..., res/2. The shell edges are 0, 1.5, 2.5, ..., res/2 - 0.5 and
    res/2. Energy is summed per shell and scaled by
    2*pi*k**2 / (number of modes in the shell).

    NOTE(review): the res/2 output length relies on some mode lying beyond
    the last edge. Otherwise ``bincount`` is shorter and the slicing below
    shrinks the result. An empty shell gives 0/0 = nan. Confirm for the grids
    in use.
    """
    device = vel.device
    res = vel.shape[-2:]  # only the last two sizes are read; a cubic grid is assumed

    assert(res[0] == res[1])
    r = res[0]
    k_end = int(r/2)
    vel_ = pad_rfft3(vel, onesided=False) # (N, 3, res, res, res, 2)
    # Squared magnitude of each coefficient (norm over the trailing size-2
    # dim, presumably (real, imag); TODO confirm pad_rfft3 layout), divided by r**3.
    uu_ = (torch.norm(vel_, dim=-1) / r**3)**2
    e_ = torch.sum(uu_, dim=1)  # (N, res, res, res)
    k = fftfreqs(res).to(device) # (3, res, res, res)
    rad = torch.norm(k, dim=0) # (res, res, res)
    # Shell centres 1..k_end; edges are 0, the midpoints between centres, and k_end.
    k_bin = torch.arange(k_end, device=device).float()+1
    bins = torch.zeros(k_end+1).to(device)
    bins[1:-1] = (k_bin[1:]+k_bin[:-1])/2
    bins[-1] = k_bin[-1]
    bins = bins.unsqueeze(0)
    # NOTE(review): bins is (1, k_end+1) at this point, so ``bins[1:]`` selects
    # no rows and this line is a no-op. If the intent was ``bins[:, 1:]``,
    # fixing it would change which shell the boundary modes fall into.
    bins[1:] += 1e-3
    # Shell index of every mode: 0 for |k| == 0, k_end+1 beyond the last edge.
    inds = torch.searchsorted(bins, rad.flatten().unsqueeze(0)).squeeze().int()
    # bincount = torch.histc(inds.cpu(), bins=bins.shape[1]+1).to(device)
    bincount = torch.bincount(inds)
    # Group modes by shell: sort by shell index, take a cumulative sum, and
    # difference it at the last position of each shell.
    asort = torch.argsort(inds.squeeze())
    sorted_e_ = e_.view(e_.shape[0], -1)[:, asort]
    csum_e_ = torch.cumsum(sorted_e_, dim=1)
    binloc = torch.cumsum(bincount, dim=0).long()-1  # last sorted position of each shell
    # Energy of shells 1.. (shell 0, the |k| == 0 mode, drops out of the difference)
    spec_ = csum_e_[:,binloc[1:]] - csum_e_[:,binloc[:-1]]
    # Drop the trailing shell (modes beyond the last edge).
    spec_ = spec_[:, :-1]
    spec_ = spec_ * 2 * np.pi * (k_bin.float()**2) / bincount[1:-1].float()
    return spec_, k_bin
Example #16
0
    def sample_fine(self, rays, weights):
        """
        Importance-sample fine depths along each ray from the coarse weights.

        :param rays: (B, 8) rays laid out as [origins (3), directions (3), near (1), far (1)]
        :param weights: (B, Kc) coarse-sample weights (detached; no gradient flows back)
        :return: (B, Kf - Kfd) sampled depths
        """
        batch = rays.shape[0]
        n_draw = self.n_fine - self.n_fine_depth

        # Piecewise-constant pdf over the coarse bins, and its CDF with a leading 0.
        w = weights.detach() + 1e-5  # Prevent division by zero
        pdf = w / torch.sum(w, -1, keepdim=True)  # (B, Kc)
        cdf = torch.cat([torch.zeros_like(pdf[:, :1]), torch.cumsum(pdf, -1)], -1)  # (B, Kc+1)

        u = torch.rand(batch, n_draw, dtype=torch.float32, device=rays.device)
        bin_idx = (torch.searchsorted(cdf, u, right=True).float() - 1.0).clamp_min(0.0)
        # Jitter uniformly inside the chosen coarse bin, in normalised [0, 1] units.
        z_steps = (bin_idx + torch.rand_like(bin_idx)) / self.n_coarse  # (B, Kf)

        near, far = rays[:, -2:-1], rays[:, -1:]  # (B, 1)
        if self.lindisp:
            # Linear in disparity (1 / depth).
            return 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)
        # Linear in depth.
        return near * (1 - z_steps) + far * z_steps
Example #17
0
def systematic(w: torch.Tensor,
               normalized=False,
               u: Union[torch.Tensor, float] = None) -> torch.Tensor:
    """
    Performs systematic resampling on either a 1D or 2D array.

    Args:
        w: the log weights to use for resampling, shaped ``(n,)`` or
            ``(batch, n)``. If ``normalized`` is False they are passed through
            ``normalize`` first. Their running sum is used as a CDF, so they
            are presumably probabilities after that step (TODO confirm).
        normalized: whether the weights are normalized or not.
        u: parameter for overriding the sampled offset, only used for testing.
            Either a tensor broadcastable to ``(batch, 1)`` or a plain float.
            A float previously crashed on ``u.dtype``.

    Returns:
        Integer indices with the same shape as ``w``.
    """

    is_1d = w.dim() == 1

    if is_1d:
        w = w.unsqueeze(0)

    shape = (w.shape[0], 1)
    if u is None:
        u = torch.empty(shape, device=w.device).uniform_()
    elif not isinstance(u, torch.Tensor):
        u = torch.full(shape, float(u), device=w.device)
    w = normalize(w) if not normalized else w

    n = w.shape[1]
    index_range = torch.arange(n, dtype=u.dtype, device=w.device).unsqueeze(0)

    # One point per stratum: (i + u) / n for i = 0..n-1.
    probs = (index_range + u) / n
    cumsum = w.cumsum(-1)

    # Force the CDF to end at exactly 1 so round-off cannot yield index n.
    cumsum[..., -1] = 1.0
    res = torch.searchsorted(cumsum, probs)

    return res.squeeze(0) if is_1d else res
Example #18
0
def sample_cdf(z_vals, weights, det=False, N_samples=None):
    """Importance-sample depths by inverting the CDF built from ``weights``.

    Args:
        z_vals: ``(batch, n)`` depths. The bin values are the midpoints of
            consecutive depths (``n - 1`` of them).
        weights: per-bin weights. The CDF, with its leading 0, is indexed into
            the midpoints, so presumably ``(batch, n - 2)``. TODO confirm
            with callers.
        det: use evenly spaced quantiles instead of uniform random ones.
        N_samples: samples per row; defaults to ``config.N_samples``.

    Returns:
        ``(batch, N_samples)`` samples on the device of ``weights``.
    """
    if N_samples is None:
        N_samples = config.N_samples
    bins = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
    weights = weights + 1e-5  # avoid nans for all-zero weights
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Create u directly on cdf's device. ``u.cuda(cdf.device)`` failed for
    # CPU tensors and required a CUDA runtime even when none was used.
    if det:
        u = torch.linspace(0., 1., N_samples, device=cdf.device)
        u = torch.broadcast_to(u, list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(cdf.shape[0], N_samples, device=cdf.device)
    u = u.contiguous()

    # Bracket each quantile between the CDF knots below and above it.
    idxs = torch.searchsorted(cdf, u, right=True)
    below = (idxs - 1).clamp(min=0)
    above = idxs.clamp(max=cdf.shape[-1] - 1)

    cdf_below = torch.gather(cdf, dim=1, index=below)
    cdf_above = torch.gather(cdf, dim=1, index=above)

    bin_below = torch.gather(bins, dim=1, index=below)
    bin_above = torch.gather(bins, dim=1, index=above)

    denom = cdf_above - cdf_below
    denom = torch.clamp(denom, 1e-5, 99999)
    t = (u - cdf_below) / denom
    return bin_below + t * (bin_above - bin_below)
Example #19
0
def sample_pdf(bins, weights, num_samples, det=False):
    """Sample ``num_samples`` points per row by inverting a piecewise-linear CDF.

    ``weights`` define a piecewise-constant pdf over ``bins``. Quantiles are
    evenly spaced (``det=True``) or uniform random. Each is mapped back
    through the CDF with linear interpolation inside its bin.

    Per the original author's note, this was checked line-by-line against a
    reference but not integration-tested with the training routines.
    """
    weights = weights + 1e-5  # avoid NaNs when a row is all zeros
    pdf = weights / weights.sum(-1).unsqueeze(-1)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat((torch.zeros_like(cdf[..., :1]), cdf), -1)

    sample_shape = list(cdf.shape[:-1]) + [num_samples]
    if det:
        u = torch.linspace(0.0, 1.0, num_samples).to(weights).expand(sample_shape)
    else:
        u = torch.rand(sample_shape).to(weights)

    # Bracket each quantile between the CDF knots below and above it.
    inds = torch.searchsorted(cdf.contiguous(), u.contiguous(), right=True)
    lower = torch.max(torch.zeros_like(inds), inds - 1)
    upper = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    bracket = torch.stack((lower, upper), -1)

    cdf_g = gather_cdf_util(cdf, bracket)
    bins_g = gather_cdf_util(bins, bracket)

    cdf_lo, cdf_hi = cdf_g[..., 0], cdf_g[..., 1]
    bin_lo, bin_hi = bins_g[..., 0], bins_g[..., 1]
    denom = cdf_hi - cdf_lo
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_lo) / denom
    return bin_lo + t * (bin_hi - bin_lo)
Example #20
0
    def _interp(self, xq, y):
        """Piecewise-linearly interpolate ``y``, sampled at ``self.x``, at ``xq``.

        If ``self.y_is_given`` is set, ``self.y`` replaces the ``y`` argument.
        Queries outside the grid are linearly extrapolated from the first or
        last interval.
        """
        x = self.x
        if self.y_is_given:
            y = self.y

        n_knots = x.shape[-1]
        # Right-hand knot of each query's interval, kept in [1, n_knots - 1].
        right = torch.searchsorted(x.detach(), xq.detach(), right=False).clamp(1, n_knots - 1)
        left = right - 1  # (nrq) from (0 to nr-2)

        if torch.numel(xq) <= torch.numel(x):
            # Few queries: gather only the endpoints they need.
            x_left = torch.gather(x, -1, left)
            x_right = torch.gather(x, -1, right)
            y_left = y[..., left].contiguous()
            y_right = y[..., right].contiguous()
            t = (xq - x_left) / (x_right - x_left)  # (nrq,)
            return y_left + (y_right - y_left) * t

        # Many queries: precompute per-interval differences once, then gather.
        y_lo = y[..., :-1]  # (*BY, nr-1)
        x_lo = x[..., :-1]  # (nr-1)
        dy = y[..., 1:] - y_lo
        dx = x[..., 1:] - x_lo
        t = (xq - torch.gather(x_lo, -1, left)) / torch.gather(dx, -1, left)  # (nrq)
        out = dy[..., left] * t
        out += y_lo[..., left]
        return out
Example #21
0
 def __call__(self, x):
     """Evaluate the piecewise-linear function at ``x``.

     Returns ``ys[-1]`` when the segment index reaches ``self.n`` and
     ``ys[0]`` left of the first knot. Otherwise it interpolates with the
     segment's slope.
     """
     idx = torch.searchsorted(self.xs, x, right=True) - 1
     if idx >= self.n:
         return self.ys[-1]
     if idx < 0:
         # Clamp to the first value, mirroring ys[-1] above. This was
         # ``self[0]``, which indexes the object itself rather than ys.
         return self.ys[0]
     return self.ys[idx] + self.slopes[idx] * (x - self.xs[idx])
Example #22
0
def emd1D(u_values,
          v_values,
          u_weights=None,
          v_weights=None,
          p=1,
          require_sort=True):
    """Return the integral over q in [0, 1] of |F_u^-1(q) - F_v^-1(q)|**p.

    This is the p-th power of the p-Wasserstein distance between the two
    weighted empirical distributions, computed over the last dimension.

    Args:
        u_values, v_values: ``(..., n)`` and ``(..., m)`` samples. Leading
            (batch) dimensions must match.
        u_weights, v_weights: per-sample weights, either 1-D or with the full
            shape of the values (broadcast via ``expand_as``). Default: uniform.
        p: exponent of the ground cost.
        require_sort: sort values (permuting their weights alongside). If
            False, the values must already be sorted.

    Returns:
        Tensor of shape ``u_values.shape[:-1]``.
    """
    n = u_values.shape[-1]
    m = v_values.shape[-1]

    device = u_values.device
    dtype = u_values.dtype

    # Give the weights the full shape of the values, so that per-row sorting
    # (gather) and the batched searchsorted below work for batched inputs.
    if u_weights is None:
        u_weights = torch.full(u_values.shape, 1 / n, dtype=dtype, device=device)
    else:
        u_weights = u_weights.expand_as(u_values)

    if v_weights is None:
        v_weights = torch.full(v_values.shape, 1 / m, dtype=dtype, device=device)
    else:
        v_weights = v_weights.expand_as(v_values)

    if require_sort:
        u_values, u_sorter = torch.sort(u_values, -1)
        v_values, v_sorter = torch.sort(v_values, -1)

        # Permute each row's weights by its own sort order. Fancy indexing
        # ``w[..., sorter]`` crossed every weight row with every sorter row.
        u_weights = torch.gather(u_weights, -1, u_sorter)
        v_weights = torch.gather(v_weights, -1, v_sorter)

    u_cdf = torch.cumsum(u_weights, -1)
    v_cdf = torch.cumsum(v_weights, -1)

    # Every quantile level at which either inverse CDF jumps.
    cdf_axis, _ = torch.sort(torch.cat((u_cdf, v_cdf), -1), -1)

    # Inverse CDFs at each level. The clip guards against round-off placing a
    # level just above a CDF's final value.
    u_index = torch.searchsorted(u_cdf, cdf_axis)
    v_index = torch.searchsorted(v_cdf, cdf_axis)

    u_icdf = torch.gather(u_values, -1, u_index.clip(0, n - 1))
    v_icdf = torch.gather(v_values, -1, v_index.clip(0, m - 1))

    # Width of each quantile interval (the first one starts at 0).
    cdf_axis = torch.nn.functional.pad(cdf_axis, (1, 0))
    delta = cdf_axis[..., 1:] - cdf_axis[..., :-1]

    diff = torch.abs(u_icdf - v_icdf)
    if p == 1:
        return torch.sum(delta * diff, dim=-1)
    if p == 2:
        return torch.sum(delta * torch.square(diff), dim=-1)
    return torch.sum(delta * torch.pow(diff, p), dim=-1)
Example #23
0
 def est_rank(self, user_id, pos_rat, candidate_items, sample_size):
     """Look up the rank weight of each positive item, estimated from a candidate sample.

     The positive score's position among the sorted candidate scores is
     rescaled to the full item set via
     (rank - 1)/N = (est_rank - 1)/sample_size and used to index
     ``self._rank_weight_pre``.
     """
     n_candidates = candidate_items.shape[1]
     users = user_id.repeat(n_candidates, 1).T  # one user id per candidate
     scores = self.inference(users, candidate_items)
     ascending, _ = torch.sort(scores)
     position = torch.searchsorted(ascending, pos_rat.unsqueeze(-1))
     est = (position * (self.num_item - 1) / sample_size).floor().long()
     return self._rank_weight_pre[est]
Example #24
0
def _searchsorted(sorted_sequence, values):
    """``searchsorted`` that also works on PyTorch versions older than 1.6.0.

    Uses ``th.searchsorted`` when ``_TORCH_HAS_SEARCHSORTED`` is set.
    Otherwise it computes the result with NumPy on the CPU and moves it back
    to the device of ``values``.
    """
    if not _TORCH_HAS_SEARCHSORTED:
        positions = np.searchsorted(sorted_sequence.cpu().numpy(),
                                    values.cpu().numpy())
        return th.from_numpy(positions).to(values.device)
    return th.searchsorted(sorted_sequence, values)
Example #25
0
def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None):
    """
    ``np.searchsorted`` / ``torch.searchsorted`` chosen by the type of ``a``.

    Args:
        a: sorted sequence (or, with ``sorter``, unsorted), as a numpy array
            or torch tensor.
        v: values to locate.
        right: if True, return the last suitable index (numpy ``side="right"``).
        sorter: optional indices that sort ``a``. The torch path now honours
            it too; previously it was silently ignored there, which gave wrong
            indices for an unsorted ``a``.
    """
    side = "right" if right else "left"
    if isinstance(a, np.ndarray):
        return np.searchsorted(a, v, side, sorter)  # type: ignore
    if hasattr(torch, "searchsorted"):
        if sorter is None:
            return torch.searchsorted(a, v, right=right)  # type: ignore
        # torch requires an int64 sorter on the same device as ``a``.
        sorter_t = torch.as_tensor(sorter, dtype=torch.long, device=a.device)
        return torch.searchsorted(a, v, right=right, sorter=sorter_t)  # type: ignore
    # if using old PyTorch, will convert to numpy array then compute
    ret = np.searchsorted(a.cpu().numpy(), v.cpu().numpy(), side, sorter)  # type: ignore
    ret, *_ = convert_to_dst_type(ret, a)
    return ret
Example #26
0
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    """Invert the piecewise-linear CDF defined by ``bins`` and ``weights``.

    Args:
        bins: ``(batch, n_bins)`` bin edges.
        weights: ``(batch, n_bins - 1)`` non-negative weight of each bin.
        N_samples: number of quantiles to invert per row.
        det: evenly spaced quantiles in [0, 1] instead of uniform random ones.
        pytest: replace ``u`` with NumPy values drawn after
            ``np.random.seed(0)``, for reproducible tests.

    Returns:
        ``(batch, N_samples)`` samples, on the device of ``weights``.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    # probability distribution function
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    # cumulative distribution function
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf],
                    -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # u is drawn on the CPU so the random streams stay unchanged. It must
    # then move to cdf's device: a CPU u cannot be searched against, or
    # combined with, a CUDA cdf.
    u = u.to(cdf.device).contiguous()

    # Invert CDF: find where each u falls on the CDF's y axis (from y axis of
    # CDF back to x axis) and bracket it by the knots below and above.
    inds = torch.searchsorted(cdf, u, right=True)
    below = (inds - 1).clamp(min=0)
    above = inds.clamp(max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # CDF values and bin edges at the bracketing knots.
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    # Probability mass of the bracketing bin; avoid dividing by ~0.
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    # t: fraction of the bracket that u lies past its lower bound.
    t = (u - cdf_g[..., 0]) / denom
    return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
Example #27
0
    def convert_to_trainId(self, mask):
        """Map every raw label id in ``mask`` to its train id via ``self.id_to_trainId``.

        Uses the sorted-keys + searchsorted lookup from
        https://stackoverflow.com/questions/47171356
        Every value in ``mask`` is expected to be a key. A value that is not
        a key maps to the entry of the next larger key, and one larger than
        every key raises IndexError.

        Returns:
            int64 tensor of train ids with the shape of ``mask``.
        """
        # Build the lookup table on mask's device; CPU tables cannot be
        # searched with or indexed by a CUDA mask.
        keys = torch.tensor(list(self.id_to_trainId.keys()), device=mask.device)
        values = torch.tensor(list(self.id_to_trainId.values()), device=mask.device)
        order = keys.argsort()

        sorted_keys = keys[order]
        sorted_values = values[order]
        return sorted_values[torch.searchsorted(sorted_keys, mask)].to(torch.int64)
Example #28
0
    def forward(self, t):
        """Map positions ``t`` in [0, 1] to points on the polyline ``self.segs``.

        ``self.segs[0]`` is read as the start point and ``self.segs[1:]`` as
        the per-segment displacements, so the vertices are
        ``segs.cumsum(0)``. ``t`` is located among the cumulative segment
        lengths normalised to end at 1.

        NOTE(review): ``extra`` is in normalised units but scales a unit
        direction vector, so points are arc-length exact only when the total
        length is 1. Confirm whether scaling by the total length was intended.
        """
        assert self.segs is not None, f'segments have not been initialized'

        lens = self.segs[1:].norm(2, dim=-1).cumsum(0)
        lens = lens / lens[-1]

        # Index of the segment containing each t.
        inds = torch.searchsorted(lens, t)
        # Prepend 0 so lens[i] is the normalised length before segment i.
        # The zero is created on lens' device; a CPU zero breaks CUDA inputs.
        lens = torch.cat([torch.zeros(1, dtype=lens.dtype, device=lens.device), lens])
        extra = t - lens[inds]

        return self.segs.cumsum(0)[inds] + extra.unsqueeze(-1) * \
               F.normalize(self.segs[1:])[inds.clamp(max=self.n_seg-1)]
Example #29
0
def center_distogram_torch(distogram,
                           bins=DISTANCE_THRESHOLDS,
                           min_t=1.,
                           center="median",
                           wide="std"):
    """ Returns the central estimate of a distogram. Median for now.
        Inputs:
        * distogram: (N x N x B) where B is the number of buckets.
                     supports batched predictions (batch, N, N, B).
        * bins: (B,) containing the cutoffs for the different buckets
        * min_t: float. lower bound for distances (currently not used).
        * center: "median" or "mean".
        * wide: "var", "std", or anything else for zero dispersion.
        Outputs:
        * central, dispersion, weights. The median path squeezes singleton
          dims, so shapes may differ from distogram.shape[:-1] when N or
          batch is 1.
        Raises ValueError for any other ``center`` (previously ``central``
        was left unbound).
        TODO: return confidence/weights
    """
    shape = distogram.shape
    # threshold to weights and find mean value of each bin
    # NOTE(review): ``bins[2] - bins[2]`` is always 0, so n_bins is just a copy
    # of bins. The intended half-bin shift was presumably
    # ``bins[2] - bins[1]``. Confirm before changing; it alters every output.
    n_bins = bins - 0.5 * (bins[2] - bins[2])
    n_bins[0] = 1.5
    # TODO: adapt so that mean option considers IGNORE_INDEX
    # (the last bucket's value is currently left unchanged)
    n_bins = n_bins.to(distogram.device)
    # calculate measures of centrality and dispersion -
    if center == "median":
        cum_dist = torch.cumsum(distogram, dim=-1)
        medium = 0.5 * torch.ones(*cum_dist.shape[:-1],
                                  device=cum_dist.device).unsqueeze(dim=-1)
        # First bucket whose cumulative probability reaches 0.5.
        central = torch.searchsorted(cum_dist, medium).squeeze()
        # searchsorted returns B when a row never reaches 0.5; clamp to the
        # last bucket index.
        central = n_bins[central.clamp(max=DISTOGRAM_BUCKETS - 1).long()]
    elif center == "mean":
        central = (distogram * n_bins).sum(dim=-1)
    else:
        raise ValueError(f"unsupported center: {center!r}")
    # create mask for last class - (IGNORE_INDEX)
    mask = (central <= bins[-2].item()).float()
    # mask diagonal to 0 dist
    diag = np.arange(shape[-2])
    if len(central.shape) == 3:
        central[:, diag, diag] = 0.
    else:
        central[diag, diag] = 0.
    # provide weights
    if wide == "var":
        dispersion = (distogram *
                      (n_bins - central.unsqueeze(-1))**2).sum(dim=-1)
    elif wide == "std":
        dispersion = (distogram *
                      (n_bins - central.unsqueeze(-1))**2).sum(dim=-1).sqrt()
    else:
        dispersion = torch.zeros_like(central, device=central.device)
    # rescale to 0-1. lower std / var  --> weight=1
    weights = mask / (1 + dispersion)

    return central, dispersion, weights
Example #30
0
def _vector(weights: torch.Tensor, u: torch.Tensor):
    """
    Systematic resampling of a 1-D array of weights.

    :param weights: The weights to use for resampling. Their running sum is
        used as the CDF, with the final entry forced to 1. They are presumably
        normalised probabilities despite the "log weights" wording elsewhere
        (TODO confirm).
    :param u: The shared offset added to every stratum index.
    :return: Resampled indices
    """
    n = weights.shape[0]
    offsets = torch.arange(n, dtype=u.dtype, device=weights.device)
    probs = (offsets + u) / n  # one point per stratum

    cdf = weights.cumsum(0)
    cdf[..., -1] = 1.  # guard against round-off so no index reaches n
    return torch.searchsorted(cdf, probs)