def test_subset_matrix():
    """Check that ``subset_matrix`` selects the requested rows/columns and rejects non-2D input."""
    # Multiplication table: entry (r, c) equals r * c, so the expected value of
    # every selected entry can be derived directly from the chosen indices.
    base = tf.range(5)
    mat = base[None, :] * base[:, None]
    row_inds = [2, 3]
    col_inds = [2, 1, 4]

    sub_mat = subset_matrix(mat, tf.constant(row_inds), tf.constant(col_inds))

    assert sub_mat.shape == (2, 3)
    for row_pos, row_ind in enumerate(row_inds):
        for col_pos, col_ind in enumerate(col_inds):
            assert sub_mat[row_pos, col_pos] == row_ind * col_ind

    # Inputs that are not rank 2 (first rank 3, then rank 1) must raise a ValueError.
    for bad_shape in ((10, 10, 10), (10,)):
        with pytest.raises(ValueError):
            subset_matrix(tf.ones(bad_shape), row_inds, col_inds)
def _configure_ref_subset(self):
    """Draw a reference/test split that does not cause an initial detection.

    Repeatedly permutes the ``self.n`` reference indices, keeps the leading
    ``n - (2 * window_size - 1)`` indices as the reference subset and the
    trailing ``window_size`` indices as the initial test window, and computes
    an initial MMD value between the two. The split is redrawn for as long as
    that value is greater than or equal to ``self.get_threshold(0)``.

    Sets ``self.ref_inds``, ``self.init_test_inds``, ``self.test_window``,
    ``self.k_xx_sub``, ``self.k_xx_sub_sum`` and ``self.k_xy``.
    """
    w_size = self.window_size
    # Reference window = everything except an extended test window of 2W - 1 samples.
    rw_size = self.n - (2 * w_size - 1)
    # Number of off-diagonal entries in the reference and test sub-matrices.
    xx_norm = rw_size * (rw_size - 1)
    yy_norm = w_size * (w_size - 1)

    while True:
        # Make split
        perm = tf.random.shuffle(tf.range(self.n))
        self.ref_inds = perm[:rw_size]
        self.init_test_inds = perm[-w_size:]
        self.test_window = tf.gather(self.x_ref, self.init_test_inds)

        # Reference-vs-reference term, taken from the existing matrix self.k_xx.
        self.k_xx_sub = subset_matrix(self.k_xx, self.ref_inds, self.ref_inds)
        self.k_xx_sub_sum = tf.reduce_sum(zero_diag(self.k_xx_sub)) / xx_norm

        # Cross term and test-vs-test term are evaluated with the kernel itself.
        x_sub = tf.gather(self.x_ref, self.ref_inds)
        self.k_xy = self.kernel(x_sub, self.test_window)
        k_yy = self.kernel(self.test_window, self.test_window)
        k_yy_term = tf.reduce_sum(zero_diag(k_yy)) / yy_norm

        mmd_init = self.k_xx_sub_sum + k_yy_term - 2 * tf.reduce_mean(self.k_xy)

        # Accept the split unless it already meets or exceeds the first threshold.
        if not (mmd_init >= self.get_threshold(0)):
            break
def _configure_thresholds(self):
    """Estimate one detection threshold per overlapping test window by bootstrapping.

    Each bootstrap sample is a random permutation of the ``self.n`` reference
    indices, split into a sub-reference sample (x) and an extended test window
    (y) of ``2W - 1`` indices, where ``W = self.window_size``. The extended
    test window is treated as ``W`` overlapping test windows of size ``W``.

    For every window position an MMD value is computed for each remaining
    bootstrap sample from entries of ``self.k_xx``; the ``1 - self.fpr``
    quantile of those values is recorded as the threshold, and only samples
    whose MMD is strictly below it are carried over to the next position.

    Reads ``self.window_size``, ``self.n``, ``self.n_bootstraps``,
    ``self.k_xx``, ``self.fpr`` and ``self.verbose``. Stores the list of
    thresholds (one per window position) in ``self.thresholds``.
    """
    w_size = self.window_size
    etw_size = 2 * w_size - 1  # etw = extended test window
    rw_size = self.n - etw_size  # rw = ref window

    perms = [tf.random.shuffle(tf.range(self.n)) for _ in range(self.n_bootstraps)]
    x_inds_all = [perm[:-etw_size] for perm in perms]
    y_inds_all = [perm[-etw_size:] for perm in perms]

    if self.verbose:
        print("Generating permutations of kernel matrix..")

    # Need to compute mmd for each bootstrap sample for each of W overlapping windows.
    # Most of the computation can be done once however.
    # We avoid summing the rw_size^2 submatrix for each bootstrap sample by instead computing the
    # full sum once and then subtracting the relevant parts
    # (k_xx_sum = k_full_sum - 2*k_xy_sum - k_yy_sum).
    # We also reduce computation of k_xy_sum from O(nW) to O(W) by caching column sums.
    k_full_sum = tf.reduce_sum(zero_diag(self.k_xx))

    ind_pairs = zip(x_inds_all, y_inds_all)
    if self.verbose:
        ind_pairs = tqdm(ind_pairs, total=self.n_bootstraps)
    k_xy_col_sums_all = [
        tf.reduce_sum(subset_matrix(self.k_xx, x_inds, y_inds), axis=0)
        for x_inds, y_inds in ind_pairs
    ]

    k_xx_sums_all = [
        (
            k_full_sum
            - tf.reduce_sum(zero_diag(subset_matrix(self.k_xx, y_inds, y_inds)))
            - 2 * tf.reduce_sum(k_xy_col_sums)
        ) / (rw_size * (rw_size - 1))
        for y_inds, k_xy_col_sums in zip(y_inds_all, k_xy_col_sums_all)
    ]
    # Normalise the cached column sums only after the raw sums were used above.
    k_xy_col_sums_all = [
        k_xy_col_sums / (rw_size * w_size) for k_xy_col_sums in k_xy_col_sums_all
    ]

    # Now to iterate through the W overlapping windows
    thresholds = []
    p_bar = tqdm(range(w_size), "Computing thresholds") if self.verbose else range(w_size)
    for w in p_bar:
        y_inds_all_w = [y_inds[w:w + w_size] for y_inds in y_inds_all]  # test windows of size W
        mmd_per_sample = [
            (
                k_xx_sum
                + tf.reduce_sum(zero_diag(subset_matrix(self.k_xx, y_inds_w, y_inds_w)))
                / (w_size * (w_size - 1))
                - 2 * tf.reduce_sum(k_xy_col_sums[w:w + w_size])
            )
            for k_xx_sum, y_inds_w, k_xy_col_sums in zip(
                k_xx_sums_all, y_inds_all_w, k_xy_col_sums_all
            )
        ]
        mmds = tf.concat(mmd_per_sample, axis=0)  # an mmd for each bootstrap sample

        # Record the (1 - fpr) quantile as this window's threshold, then discard every
        # bootstrap sample whose mmd is not strictly below it.
        threshold = quantile(mmds, 1 - self.fpr)
        thresholds.append(threshold)

        # Evaluate the keep/discard decision once per sample and reuse it for all three
        # parallel lists, so they stay aligned without repeated tensor indexing.
        keep = [bool(mmd < threshold) for mmd in mmd_per_sample]
        y_inds_all = [y_inds for y_inds, kept in zip(y_inds_all, keep) if kept]
        k_xx_sums_all = [k_xx_sum for k_xx_sum, kept in zip(k_xx_sums_all, keep) if kept]
        k_xy_col_sums_all = [
            k_xy_col_sums for k_xy_col_sums, kept in zip(k_xy_col_sums_all, keep) if kept
        ]

    self.thresholds = thresholds