Esempio n. 1
0
    def _configure_thresholds(self):

        # Each bootstrap sample splits the reference samples into a sub-reference sample (x)
        # and an extended test window (y). The extended test window will be treated as W overlapping
        # test windows of size W (so 2W-1 test samples in total)

        w_size = self.window_size
        etw_size = 2 * w_size - 1  # etw = extended test window
        nkc_size = self.n - self.n_kernel_centers  # nkc = non-kernel-centers
        rw_size = nkc_size - etw_size  # rw = ref-window

        perms = [torch.randperm(nkc_size) for _ in range(self.n_bootstraps)]
        x_inds_all = [perm[:rw_size] for perm in perms]
        y_inds_all = [perm[rw_size:] for perm in perms]

        # For stability in high dimensions we don't divide H by (pi*sigma^2)^(d/2)
        # Results in an alternative test-stat of LSDD*(pi*sigma^2)^(d/2). Same p-vals etc.
        H = GaussianRBF(np.sqrt(2.) * self.kernel.sigma)(self.kernel_centers,
                                                         self.kernel_centers)

        # Compute lsdds for first test-window. We infer regularisation constant lambda here.
        y_inds_all_0 = [y_inds[:w_size] for y_inds in y_inds_all]
        lsdds_0, H_lam_inv = permed_lsdds(
            self.k_xc,
            x_inds_all,
            y_inds_all_0,
            H,
            lam_rd_max=self.lambda_rd_max,
        )

        # Can compute threshold for first window
        thresholds = [quantile(lsdds_0, 1 - self.fpr)]
        # And now to iterate through the other W-1 overlapping windows
        p_bar = tqdm(range(1, w_size),
                     "Computing thresholds") if self.verbose else range(
                         1, w_size)
        for w in p_bar:
            y_inds_all_w = [y_inds[w:(w + w_size)] for y_inds in y_inds_all]
            lsdds_w, _ = permed_lsdds(self.k_xc,
                                      x_inds_all,
                                      y_inds_all_w,
                                      H,
                                      H_lam_inv=H_lam_inv)
            thresholds.append(quantile(lsdds_w, 1 - self.fpr))
            x_inds_all = [
                x_inds_all[i] for i in range(len(x_inds_all))
                if lsdds_w[i] < thresholds[-1]
            ]
            y_inds_all = [
                y_inds_all[i] for i in range(len(y_inds_all))
                if lsdds_w[i] < thresholds[-1]
            ]

        self.thresholds = thresholds
        self.H_lam_inv = H_lam_inv
Esempio n. 2
0
def test_quantile(quantile_params):
    """Check `quantile` on an evenly spaced sample and on invalid inputs.

    `quantile_params` unpacks into the quantile `type` and a flag saying
    whether the sample is passed in sorted order; both are forwarded to
    `quantile` as keyword arguments.
    """
    q_type, is_sorted = quantile_params

    # Mid-points of 1e6 equal-width bins on [0, 1]; shuffled when the
    # unsorted code path is under test.
    n_points = 1e6
    sample = (0.5 + torch.arange(n_points)) / n_points
    if not is_sorted:
        shuffle = torch.randperm(len(sample))
        sample = sample[shuffle]

    # The estimated quantile should match the requested probability.
    for prob in (0.001, 0.999):
        estimate = quantile(sample, prob, type=q_type, sorted=is_sorted)
        np.testing.assert_almost_equal(estimate, prob, decimal=6)

    # A constant sample has the same value at every quantile.
    constant_estimate = quantile(torch.ones(100), 0.42, type=q_type, sorted=is_sorted)
    assert constant_estimate == 1

    # Presumably rejected because 10 points are too few for the 0.999
    # quantile - verify against `quantile`.
    with pytest.raises(ValueError):
        quantile(torch.ones(10), 0.999, type=q_type, sorted=is_sorted)

    # Two-dimensional input is expected to be rejected as well.
    with pytest.raises(ValueError):
        quantile(torch.ones(100, 100), 0.5, type=q_type, sorted=is_sorted)
Esempio n. 3
0
    def _configure_thresholds(self):

        # Each bootstrap sample splits the reference samples into a sub-reference sample (x)
        # and an extended test window (y). The extended test window will be treated as W overlapping
        # test windows of size W (so 2W-1 test samples in total)

        w_size = self.window_size
        etw_size = 2*w_size-1  # etw = extended test window
        rw_size = self.n - etw_size  # rw = sub-ref window

        perms = [torch.randperm(self.n) for _ in range(self.n_bootstraps)]
        x_inds_all = [perm[:-etw_size] for perm in perms]
        y_inds_all = [perm[-etw_size:] for perm in perms]

        if self.verbose:
            print("Generating permutations of kernel matrix..")
        # Need to compute mmd for each bs for each of W overlapping windows
        # Most of the computation can be done once however
        # We avoid summing the rw_size^2 submatrix for each bootstrap sample by instead computing the full
        # sum once and then subtracting the relavent parts (k_xx_sum = k_full_sum - 2*k_xy_sum - k_yy_sum).
        # We also reduce computation of k_xy_sum from O(nW) to O(W) by caching column sums

        k_full_sum = zero_diag(self.k_xx).sum()
        k_xy_col_sums_all = [
            self.k_xx[x_inds][:, y_inds].sum(0) for x_inds, y_inds in
            (tqdm(zip(x_inds_all, y_inds_all), total=self.n_bootstraps) if self.verbose else
                zip(x_inds_all, y_inds_all))
        ]
        k_xx_sums_all = [(
            k_full_sum - zero_diag(self.k_xx[y_inds][:, y_inds]).sum() - 2*k_xy_col_sums.sum()
        )/(rw_size*(rw_size-1)) for y_inds, k_xy_col_sums in zip(y_inds_all, k_xy_col_sums_all)]
        k_xy_col_sums_all = [k_xy_col_sums/(rw_size*w_size) for k_xy_col_sums in k_xy_col_sums_all]

        # Now to iterate through the W overlapping windows
        thresholds = []
        p_bar = tqdm(range(w_size), "Computing thresholds") if self.verbose else range(w_size)
        for w in p_bar:
            y_inds_all_w = [y_inds[w:w+w_size] for y_inds in y_inds_all]  # test windows of size w_size
            mmds = [(
                k_xx_sum +
                zero_diag(self.k_xx[y_inds_w][:, y_inds_w]).sum()/(w_size*(w_size-1)) -
                2*k_xy_col_sums[w:w+w_size].sum()
            ) for k_xx_sum, y_inds_w, k_xy_col_sums in zip(k_xx_sums_all, y_inds_all_w, k_xy_col_sums_all)
            ]
            mmds = torch.tensor(mmds)  # an mmd for each bootstrap sample

            # Now we discard all bootstrap samples for which mmd is in top (1/ert)% and record the thresholds
            thresholds.append(quantile(mmds, 1-self.fpr))
            y_inds_all = [y_inds_all[i] for i in range(len(y_inds_all)) if mmds[i] < thresholds[-1]]
            k_xx_sums_all = [
                k_xx_sums_all[i] for i in range(len(k_xx_sums_all)) if mmds[i] < thresholds[-1]
            ]
            k_xy_col_sums_all = [
                k_xy_col_sums_all[i] for i in range(len(k_xy_col_sums_all)) if mmds[i] < thresholds[-1]
            ]

        self.thresholds = thresholds