def merge_samples(i_example_to_i_samples, i_example_to_i_samples_2,
                  gauss_samples, gauss_samples_2):
    # Collect all sample indices referenced from the first pool
    inds_a = np.sort(np.concatenate(i_example_to_i_samples))
    th_inds_a = np_to_var(inds_a, dtype=np.int64)
    th_inds_a, _ = ensure_on_same_device(th_inds_a, gauss_samples)
    samples_a = gauss_samples[th_inds_a]
    # The second pool may be unused (empty index lists)
    inds_b = np.sort(np.concatenate(i_example_to_i_samples_2))
    if len(inds_b) > 0:
        th_inds_b = np_to_var(inds_b, dtype=np.int64)
        th_inds_b, _ = ensure_on_same_device(th_inds_b, gauss_samples_2)
        samples_b = gauss_samples_2[th_inds_b]
    # Map original sample indices to positions in the merged sample tensor;
    # positions for the second pool are offset by the size of the first
    a_dict = dict([(val, i) for i, val in enumerate(inds_a)])
    b_dict = dict([(val, i + len(a_dict)) for i, val in enumerate(inds_b)])
    # merge per-example index lists
    i_example_to_i_samples_merged = []
    for i_example in range(len(i_example_to_i_samples)):
        a_examples = [a_dict[i] for i in i_example_to_i_samples[i_example]]
        b_examples = [b_dict[i] for i in i_example_to_i_samples_2[i_example]]
        i_example_to_i_samples_merged.append(a_examples + b_examples)
    if len(inds_b) > 0:
        all_samples = th.cat((samples_a, samples_b), dim=0)
    else:
        all_samples = samples_a
    return all_samples, i_example_to_i_samples_merged
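# Hedged usage sketch (illustration only, not part of the original module):
# merging two Gaussian sample pools where each example references samples by
# index. The shapes and index lists below are assumptions for demonstration.
def _example_merge_samples():
    import numpy as np
    import torch as th
    gauss_samples = th.randn(6, 2)
    gauss_samples_2 = th.randn(4, 2)
    # example 0 uses samples 0,1 of pool 1 and sample 2 of pool 2;
    # example 1 uses sample 3 of pool 1 and samples 0,1 of pool 2
    i_example_to_i_samples = [np.array([0, 1]), np.array([3])]
    i_example_to_i_samples_2 = [np.array([2]), np.array([0, 1])]
    all_samples, merged = merge_samples(
        i_example_to_i_samples, i_example_to_i_samples_2,
        gauss_samples, gauss_samples_2)
    # all_samples stacks the used samples of both pools; merged holds, per
    # example, positions into all_samples (pool-2 positions are offset)
    print(all_samples.shape, merged)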
def transport_mat_from_diffs(diffs):
    # Exact (EMD) optimal transport plan for the given cost matrix;
    # empty weight lists mean uniform marginals in POT's ot.emd
    transport_mat = ot.emd([], [], var_to_np(diffs))
    # sometimes weird low values, try to prevent them: genuine entries of a
    # plan with uniform marginals are >= 1/n, so anything below 1/numel
    # is numerical noise and gets zeroed
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    return transport_mat
def ot_euclidean_transport_mat(samples_a, samples_b):
    # Pairwise euclidean distances, clamped away from zero so the sqrt
    # stays differentiable; plan computation is shared with
    # transport_mat_from_diffs above
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sqrt(th.clamp(th.sum(diffs * diffs, dim=2), min=1e-6))
    return transport_mat_from_diffs(diffs)
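# Hedged sketch (assumed shapes, not from the original source): an euclidean
# OT plan between two small random point clouds. With uniform marginals and
# equal set sizes, the exact plan is a permutation matrix scaled by 1/n.
def _example_euclidean_transport_mat():
    import torch as th
    samples_a = th.randn(8, 3)
    samples_b = th.randn(8, 3)
    t_mat = ot_euclidean_transport_mat(samples_a, samples_b)
    print(t_mat.shape)          # (8, 8)
    print(float(t_mat.sum()))   # total transported mass, ~1.0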
def unbalanced_transport(to_move, to_match, cover_fraction):
    assert cover_fraction > 0 and cover_fraction <= 1
    # Squared euclidean distances from every point to move to every
    # candidate match point
    all_diffs = to_move.unsqueeze(1) - to_match.unsqueeze(0)
    all_diffs = th.sum(all_diffs * all_diffs, dim=2)
    sorted_diffs, i_sorted = th.sort(all_diffs, dim=1)
    del all_diffs
    n_to_cover = int(np.round(cover_fraction * to_match.size()[0]))
    n_max_cut_off = n_to_cover
    # Grow the cut-off in steps until enough distinct candidates are covered
    # (max(1, ...) guards against a zero step for small n_max_cut_off)
    n_step = max(1, n_max_cut_off // 20)
    cut_off_found = False
    uniques_so_far = np.array([], dtype=np.int64)
    for i_cut_off in range(0, n_max_cut_off, n_step):
        i_sorted_part = i_sorted[:, i_cut_off:i_cut_off + n_step].contiguous().view(-1)
        this_unique_inds = np.unique(var_to_np(i_sorted_part))
        uniques_so_far = np.unique(
            np.concatenate((uniques_so_far, this_unique_inds)))
        if len(uniques_so_far) > n_to_cover:
            # double the found cut-off as a safety margin
            i_cut_off = i_cut_off + n_step
            i_cut_off = i_cut_off * 2
            cut_off_found = True
            break
    if not cut_off_found:
        i_cut_off = n_max_cut_off
    i_cut_off = np.minimum(i_cut_off, n_max_cut_off)
    i_sorted_part = i_sorted[:, :i_cut_off].contiguous().view(-1)
    unique_inds = np.unique(var_to_np(i_sorted_part))
    unique_inds = np_to_var(unique_inds, dtype=np.int64)
    unique_inds, _ = ensure_on_same_device(unique_inds, to_match)
    part_to_match = to_match[unique_inds]
    part_cover_fraction = float(n_to_cover / float(part_to_match.size()[0]))
    t_mat, diffs = unbalanced_transport_mat_squared_diff(
        to_move, part_to_match,
        cover_fraction=part_cover_fraction, return_diffs=True)
    t_mat, diffs, mask = only_used_tmat_diffs(t_mat, diffs)
    # mask ^ 1 inverts the byte mask: indices of match samples actually used
    used_sample_inds = unique_inds[mask ^ 1]
    # drop the dummy row that absorbed the unmatched mass
    t_mat = t_mat[:-1]
    diffs = diffs[:-1]
    t_mat = t_mat / th.sum(t_mat)
    loss = th.sum(t_mat * diffs)
    # mark all match samples as rejected, then clear the ones that were used
    rejected_mask = th.ones_like(to_match[:, 0] > 0)
    rejected_mask[used_sample_inds] = 0
    return loss, rejected_mask
def ot_emd_loss_for_samples(samples_a, samples_b):
    # Squared euclidean costs between the two sample sets
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sum(diffs * diffs, dim=2)
    transport_mat = transport_mat_from_diffs(diffs)
    # eps keeps the sqrt differentiable when the transport cost is near zero
    eps = 1e-6
    loss = th.sqrt(th.sum(transport_mat * diffs) + eps)
    return loss
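# Hedged sketch (assumed shapes, assuming a PyTorch version where tensors
# track gradients directly): EMD loss between two empirical sample sets.
# Gradients flow through the cost matrix; the transport plan itself is
# computed in numpy and treated as a constant.
def _example_emd_loss_for_samples():
    import torch as th
    samples_a = th.randn(16, 4, requires_grad=True)
    samples_b = th.randn(16, 4)
    loss = ot_emd_loss_for_samples(samples_a, samples_b)
    loss.backward()  # d(loss)/d(samples_a) via the squared-distance costs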
def get_batches(self, X, y, shuffle):
    n_trials = len(X)
    batches = get_balanced_batches(n_trials, batch_size=self.batch_size,
                                   rng=self.rng, shuffle=shuffle)
    for batch_inds in batches:
        batch_inds = np_to_var(batch_inds, dtype=np.int64)
        if X.is_cuda:
            batch_inds = batch_inds.cuda()
        batch_X = X[batch_inds]
        batch_y = y[batch_inds]
        yield (batch_X, batch_y)
def ot_emd_loss(outs, mean, std):
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    # squared euclidean costs between outputs and Gaussian reference samples
    diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
    del gauss_samples
    diffs = th.sum(diffs * diffs, dim=2)
    transport_mat = transport_mat_from_diffs(diffs)
    # eps keeps the sqrt differentiable when the transport cost is near zero
    eps = 1e-6
    loss = th.sqrt(th.sum(transport_mat * diffs) + eps)
    return loss
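# Hedged sketch (assumed setup): ot_emd_loss as a prior-matching penalty that
# pulls encoder outputs towards N(mean, std). It assumes get_gauss_samples,
# as used above, accepts per-dimension mean/std vectors; `outs` here are
# random stand-ins for encoder activations.
def _example_ot_emd_loss():
    import torch as th
    outs = th.randn(32, 8)
    mean = th.zeros(8)
    std = th.ones(8)
    loss = ot_emd_loss(outs, mean, std)
    print(float(loss))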
def unbalanced_transport_mat_squared_diff(samples_a, samples_b, cover_fraction,
                                          return_diffs=False):
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sum(diffs * diffs, dim=2)
    # add dummy point with distance 0 to everything; it absorbs the mass
    # that is not supposed to be transported
    dummy = th.zeros_like(diffs[0:1, :])
    diffs = th.cat((diffs, dummy), dim=0)
    # real points carry cover_fraction of the total mass, the dummy point
    # carries the remaining 1 - cover_fraction
    a = np.ones(len(samples_a)) / len(samples_a) * cover_fraction
    a = np.concatenate((a, [1 - cover_fraction]))
    # empty second weight list means uniform marginals for samples_b
    transport_mat = ot.emd(a, [], var_to_np(diffs))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    transport_mat, diffs = ensure_on_same_device(transport_mat, diffs)
    if return_diffs:
        return transport_mat, diffs
    else:
        return transport_mat
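# Hedged sketch (assumed shapes): the dummy-point construction lets only a
# cover_fraction of the mass flow between real points; the rest flows to the
# dummy row at zero cost.
def _example_unbalanced_tmat():
    import torch as th
    samples_a = th.randn(10, 2)
    samples_b = th.randn(10, 2)
    t_mat = unbalanced_transport_mat_squared_diff(
        samples_a, samples_b, cover_fraction=0.5)
    # last row is the dummy point absorbing the uncovered 1 - cover_fraction
    print(t_mat.shape)               # (11, 10)
    print(float(t_mat[:-1].sum()))   # ~0.5, the covered mass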
def get_batch(inputs, targets, rng, batch_size, with_replacement,
              i_class='all'):
    if i_class == 'all':
        indices = list(range(len(inputs)))
    else:
        indices = np.flatnonzero(var_to_np(targets[:, i_class]) == 1)
    batch_inds = rng.choice(indices, size=batch_size,
                            replace=with_replacement)
    th_inds = np_to_var(batch_inds, dtype=np.int64)
    th_inds, _ = ensure_on_same_device(th_inds, inputs)
    batch_X = inputs[th_inds]
    batch_y = targets[th_inds]
    return th_inds, batch_X, batch_y
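# Hedged sketch (assuming one-hot targets, as the flatnonzero check above
# implies): drawing a class-conditional batch. For i_class=0, only rows whose
# one-hot target has a 1 in column 0 are eligible.
def _example_get_batch():
    import numpy as np
    import torch as th
    inputs = th.randn(100, 5)
    targets = th.zeros(100, 2)
    targets[:50, 0] = 1
    targets[50:, 1] = 1
    rng = np.random.RandomState(0)
    inds, batch_X, batch_y = get_batch(
        inputs, targets, rng, batch_size=16,
        with_replacement=True, i_class=0)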
def ot_euclidean_loss(outs, mean, std, normalize_by_global_emp_std=False):
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
    del gauss_samples
    if normalize_by_global_emp_std:
        # rescale distances by the mean per-dimension std of the outputs
        global_emp_std = th.mean(th.std(outs, dim=0))
        diffs = diffs / global_emp_std
    # euclidean distances, clamped away from zero to keep sqrt differentiable
    diffs = th.sqrt(th.clamp(th.sum(diffs * diffs, dim=2), min=1e-6))
    transport_mat = transport_mat_from_diffs(diffs)
    loss = th.sum(transport_mat * diffs)
    return loss
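# Hedged sketch (assumed setup, same get_gauss_samples assumption as above):
# euclidean OT loss against a Gaussian reference. With
# normalize_by_global_emp_std=True, distances are rescaled by the mean
# per-dimension std of `outs`, reducing sensitivity to the output scale.
def _example_ot_euclidean_loss():
    import torch as th
    outs = th.randn(32, 8) * 3.0  # deliberately mis-scaled encodings
    loss = ot_euclidean_loss(outs, th.zeros(8), th.ones(8),
                             normalize_by_global_emp_std=True)
    print(float(loss))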