def unbalanced_transport(to_move, to_match, cover_fraction):
    assert cover_fraction > 0 and cover_fraction <= 1
    # Pairwise squared euclidean distances between to_move and to_match
    all_diffs = to_move.unsqueeze(1) - to_match.unsqueeze(0)
    all_diffs = th.sum(all_diffs * all_diffs, dim=2)
    sorted_diffs, i_sorted = th.sort(all_diffs, dim=1)
    del all_diffs
    n_to_cover = int(np.round(cover_fraction * to_match.size()[0]))
    n_max_cut_off = n_to_cover
    # step size for growing the candidate set; keep at least 1 to avoid an empty range
    n_step = max(1, n_max_cut_off // 20)
    cut_off_found = False
    uniques_so_far = []
    # Grow the cut-off over the sorted neighbours until enough unique
    # to_match samples are covered
    for i_cut_off in range(0, n_max_cut_off, n_step):
        i_sorted_part = i_sorted[:, i_cut_off:i_cut_off + n_step].contiguous().view(-1)
        this_unique_inds = np.unique(var_to_np(i_sorted_part))
        uniques_so_far = np.unique(
            np.concatenate((uniques_so_far, this_unique_inds)))
        if len(uniques_so_far) > n_to_cover:
            # take the columns seen so far and double the cut-off
            i_cut_off = i_cut_off + n_step
            i_cut_off = i_cut_off * 2
            cut_off_found = True
            break
    if not cut_off_found:
        i_cut_off = n_max_cut_off
    i_cut_off = np.minimum(i_cut_off, n_max_cut_off)
    # Restrict to_match to the candidate samples within the cut-off
    i_sorted_part = i_sorted[:, :i_cut_off].contiguous().view(-1)
    unique_inds = np.unique(var_to_np(i_sorted_part))
    unique_inds = np_to_var(unique_inds, dtype=np.int64).cuda()
    part_to_match = to_match[unique_inds]
    part_cover_fraction = float(n_to_cover / float(part_to_match.size()[0]))
    t_mat, diffs = unbalanced_transport_mat_squared_diff(
        to_move, part_to_match, cover_fraction=part_cover_fraction,
        return_diffs=True)
    t_mat, diffs, mask = only_used_tmat_diffs(t_mat, diffs)
    used_sample_inds = unique_inds[mask ^ 1]
    # Remove the dummy point (last row) before computing the loss
    t_mat = t_mat[:-1]
    diffs = diffs[:-1]
    t_mat = t_mat / th.sum(t_mat)
    loss = th.sum(t_mat * diffs)
    rejected_mask = th.ones_like(to_match[:, 0] > 0)
    rejected_mask[used_sample_inds] = 0
    return loss, rejected_mask
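# Usage sketch (not part of the original code): how unbalanced_transport might be
# called with hypothetical tensors. Shapes and the 0.75 cover fraction are
# assumptions, and a GPU is assumed because unbalanced_transport moves its
# candidate indices to CUDA internally.
def _example_unbalanced_transport():
    outs = th.randn(128, 10).cuda()           # points to move, (n_move, n_dims)
    gauss_samples = th.randn(256, 10).cuda()  # reference points, (n_match, n_dims)
    # match outs to the closest 75% of the reference samples only
    loss, rejected_mask = unbalanced_transport(outs, gauss_samples,
                                               cover_fraction=0.75)
    # rejected_mask is 1 for reference samples that were not covered
    return loss, rejected_mask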
def train_epoch(self, inputs, targets, inputs_u, linear_weights_u,
                trans_loss_function, directions_adv, n_dir_matrices=1):
    loss = 0
    n_examples = 0
    for batch_X, batch_y in self.iterator.get_batches(
            inputs, targets, inputs_u, linear_weights_u):
        if n_dir_matrices > 0:
            # Sample fresh random projection directions for every batch
            dir_mats = [sample_directions(
                self.means_per_dim.size()[1], True, cuda=batch_X.is_cuda)
                for _ in range(n_dir_matrices)]
            directions = th.cat(dir_mats, dim=0)
            if directions_adv is not None:
                directions = th.cat((directions, directions_adv), dim=0)
        else:
            directions = directions_adv
        batch_loss = train_on_batch(batch_X, self.model, self.means_per_dim,
                                    self.stds_per_dim, batch_y,
                                    self.optimizer, directions,
                                    trans_loss_function)
        # Accumulate the example-weighted loss for the epoch mean
        loss = loss + batch_loss * len(batch_X)
        n_examples = n_examples + batch_X.size()[0]
    mean_loss = var_to_np(loss / n_examples)[0]
    return mean_loss
def optimize_v_adaptively(outs, v, sample_fn_opt, sample_fn_bin_dev,
                          bin_dev_threshold, bin_dev_iters):
    # Optimize v
    n_updates_total = 0
    outs = outs.detach()
    gauss_samples = sample_fn_bin_dev()
    diffs = th.sum(
        (outs.unsqueeze(dim=1) - gauss_samples.unsqueeze(dim=0)) ** 2, dim=2)
    # Scale the initial learning rate to the magnitude of the distances
    init_lr = float(var_to_np(th.mean(th.min(diffs, dim=1)[0]))[0] *
                    len(outs) / 50)
    optim_v_orig = th.optim.SGD([v], lr=init_lr)
    optim_v = ScheduledOptimizer(DivideSqrtUpdates(), optim_v_orig, True)
    i_updates, avg_v = optimize_v_optimizer(
        v, optim_v, outs.detach(), sample_fn_opt, max_iters=25)
    v.data = avg_v
    n_updates_total += i_updates + 1
    # Keep optimizing (at most 10 more rounds) until the bin deviation
    # falls below the threshold
    for _ in range(10):
        v.data = avg_v
        bincounts = sample_match_and_bincount(outs, v, sample_fn_bin_dev,
                                              iters=bin_dev_iters)
        bin_dev = np.mean(np.abs(bincounts - np.mean(bincounts)))
        if bin_dev < bin_dev_threshold:
            break
        i_updates, avg_v = optimize_v_optimizer(
            v, optim_v, outs.detach(), sample_fn_opt, max_iters=20)
        n_updates_total += i_updates + 1
    v.data = avg_v
    return bincounts, bin_dev, n_updates_total
def transport_mat_from_diffs(diffs):
    transport_mat = ot.emd([], [], var_to_np(diffs))
    # sometimes weird low values, try to prevent them
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    return transport_mat
def collect_out_to_samples(diffs, v):
    # For every sample, find the output it is assigned to under the dual weights v
    min_diffs, inds = th.min(diffs - v.unsqueeze(1), dim=0)
    inds = var_to_np(inds)
    i_example_to_i_samples = [[] for _ in range(diffs.size()[0])]
    i_example_to_diffs = [[] for _ in range(diffs.size()[0])]
    for i_sample, i_out in enumerate(inds):
        i_example_to_i_samples[i_out].append(i_sample)
        i_example_to_diffs[i_out].append(min_diffs[i_sample])
    return i_example_to_i_samples, i_example_to_diffs
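# Minimal sketch (an assumption, not from the original code) of the expected
# inputs of collect_out_to_samples: `diffs` is a pairwise cost matrix of shape
# (n_outs, n_samples) and `v` holds one dual weight per output.
def _example_collect_out_to_samples():
    n_outs, n_samples, n_dims = 4, 12, 2
    outs = th.randn(n_outs, n_dims)
    samples = th.randn(n_samples, n_dims)
    diffs = th.sum((outs.unsqueeze(1) - samples.unsqueeze(0)) ** 2, dim=2)
    v = th.zeros(n_outs)  # all-zero weights -> plain nearest-output assignment
    i_example_to_i_samples, i_example_to_diffs = collect_out_to_samples(diffs, v)
    # i_example_to_i_samples[i] lists the samples assigned to output i
    return i_example_to_i_samples, i_example_to_diffs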
def ot_euclidean_transport_mat(samples_a, samples_b):
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sqrt(th.clamp(th.sum(diffs * diffs, dim=2), min=1e-6))
    transport_mat = ot.emd([], [], var_to_np(diffs))
    # sometimes weird low values, try to prevent them
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    return transport_mat
def ot_emd_loss_for_samples(samples_a, samples_b):
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sum(diffs * diffs, dim=2)
    transport_mat = ot.emd([], [], var_to_np(diffs))
    # sometimes weird low values, try to prevent them
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    eps = 1e-6
    loss = th.sqrt(th.sum(transport_mat * diffs) + eps)
    return loss
def ot_emd_loss(outs, mean, std):
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
    del gauss_samples
    diffs = th.sum(diffs * diffs, dim=2)
    transport_mat = ot.emd([], [], var_to_np(diffs))
    # sometimes weird low values, try to prevent them
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    eps = 1e-6
    loss = th.sqrt(th.sum(transport_mat * diffs) + eps)
    return loss
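# Hedged usage sketch: `encoder` and `optimizer` are placeholder names, not
# defined in this file, and all tensors are assumed to live on the same device.
# It shows how ot_emd_loss could serve as a distribution-matching objective that
# pulls encoder outputs towards a Gaussian with the given mean and std.
def _example_ot_emd_loss_step(encoder, optimizer, batch_X, n_dims):
    outs = encoder(batch_X)       # (n_examples, n_dims)
    mean = th.zeros(n_dims)
    std = th.ones(n_dims)
    loss = ot_emd_loss(outs, mean, std)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss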
def unbalanced_transport_mat_squared_diff(samples_a, samples_b, cover_fraction,
                                          return_diffs=False):
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sum(diffs * diffs, dim=2)
    # add dummy point with distance 0 to everything
    dummy = th.zeros_like(diffs[0:1, :])
    diffs = th.cat((diffs, dummy), dim=0)
    # samples_a carries cover_fraction of the total mass and the dummy point the
    # rest, so only that fraction of samples_b gets matched to real points
    a = np.ones(len(samples_a)) / len(samples_a) * cover_fraction
    a = np.concatenate((a, [1 - cover_fraction]))
    transport_mat = ot.emd(a, [], var_to_np(diffs))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    transport_mat, diffs = ensure_on_same_device(transport_mat, diffs)
    if return_diffs:
        return transport_mat, diffs
    else:
        return transport_mat
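# Illustrative sketch (an assumption, not part of the original code): with
# cover_fraction=0.5, half of the mass of samples_b is absorbed by the dummy
# row, so only the closer half of samples_b is matched to samples_a.
def _example_unbalanced_transport_mat():
    samples_a = th.randn(8, 3)
    samples_b = th.randn(16, 3)
    t_mat = unbalanced_transport_mat_squared_diff(samples_a, samples_b,
                                                  cover_fraction=0.5)
    # t_mat has shape (len(samples_a) + 1, len(samples_b)); the last row is the
    # dummy point that collects the uncovered samples_b mass
    return t_mat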
def get_batch(inputs, targets, rng, batch_size, with_replacement,
              i_class='all'):
    if i_class == 'all':
        indices = list(range(len(inputs)))
    else:
        # restrict to examples of the requested class (one-hot targets)
        indices = np.flatnonzero(var_to_np(targets[:, i_class]) == 1)
    batch_inds = rng.choice(indices, size=batch_size,
                            replace=with_replacement)
    th_inds = np_to_var(batch_inds, dtype=np.int64)
    th_inds, _ = ensure_on_same_device(th_inds, inputs)
    batch_X = inputs[th_inds]
    batch_y = targets[th_inds]
    return th_inds, batch_X, batch_y
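# Usage sketch with hypothetical data (not from the original code): draw a
# class-pure batch from one-hot targets using a seeded numpy RandomState.
def _example_get_batch():
    inputs = np_to_var(np.random.randn(100, 10), dtype=np.float32)
    targets = np_to_var(np.eye(4)[np.random.randint(0, 4, size=100)],
                        dtype=np.float32)
    rng = np.random.RandomState(39482)
    th_inds, batch_X, batch_y = get_batch(inputs, targets, rng, batch_size=32,
                                          with_replacement=True, i_class=2)
    return th_inds, batch_X, batch_y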
def ot_euclidean_loss(outs, mean, std, normalize_by_global_emp_std=False):
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
    del gauss_samples
    if normalize_by_global_emp_std:
        # scale distances by the mean per-dimension std of the outputs
        global_emp_std = th.mean(th.std(outs, dim=0))
        diffs = diffs / global_emp_std
    diffs = th.sqrt(th.clamp(th.sum(diffs * diffs, dim=2), min=1e-6))
    transport_mat = ot.emd([], [], var_to_np(diffs))
    # sometimes weird low values, try to prevent them
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    loss = th.sum(transport_mat * diffs)
    return loss