def unbalanced_transport(to_move, to_match, cover_fraction):
    # pairwise squared euclidean distances between all points of both sets
    all_diffs = to_move.unsqueeze(1) - to_match.unsqueeze(0)
    all_diffs = th.sum(all_diffs * all_diffs, dim=2)
    # for every point of to_move, sort candidate matches by distance
    sorted_diffs, i_sorted = th.sort(all_diffs, dim=1)
    del all_diffs

    n_to_cover = int(np.round(cover_fraction * to_match.size()[0]))
    n_max_cut_off = n_to_cover
    # guard against a zero step when fewer than 20 points have to be covered
    n_step = max(1, n_max_cut_off // 20)
    cut_off_found = False
    uniques_so_far = []
    # enlarge the cut-off until enough distinct points of to_match are reachable
    for i_cut_off in range(0, n_max_cut_off, n_step):
        i_sorted_part = i_sorted[:,
                        i_cut_off:i_cut_off + n_step].contiguous().view(-1)
        this_unique_inds = np.unique(var_to_np(i_sorted_part))
        uniques_so_far = np.unique(
            np.concatenate((uniques_so_far, this_unique_inds)))
        if len(uniques_so_far) > n_to_cover:
            i_cut_off = i_cut_off + n_step
            i_cut_off = i_cut_off * 2
            cut_off_found = True
            break

    if not cut_off_found:
        i_cut_off = n_max_cut_off
    i_cut_off = np.minimum(i_cut_off, n_max_cut_off)

    i_sorted_part = i_sorted[:, :i_cut_off].contiguous().view(-1)

    unique_inds = np.unique(var_to_np(i_sorted_part))
    unique_inds = np_to_var(unique_inds, dtype=np.int64).cuda()
    part_to_match = to_match[unique_inds]

    part_cover_fraction = float(n_to_cover / float(part_to_match.size()[0]))
    assert cover_fraction > 0 and cover_fraction <= 1

    t_mat, diffs = unbalanced_transport_mat_squared_diff(
        to_move, part_to_match,
        cover_fraction=part_cover_fraction,
        return_diffs=True)

    t_mat, diffs, mask = only_used_tmat_diffs(t_mat, diffs)
    used_sample_inds = unique_inds[mask == 0]
    t_mat = t_mat[:-1]
    diffs = diffs[:-1]
    t_mat = t_mat / th.sum(t_mat)
    loss = th.sum(t_mat * diffs)
    # start with every point of to_match rejected, then clear the used ones
    rejected_mask = th.ones_like(to_match[:, 0] > 0)
    rejected_mask[used_sample_inds] = 0
    return loss, rejected_mask
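For orientation, a minimal usage sketch (not from the original source): it assumes unbalanced_transport and its helpers are in scope and that a CUDA device is available, since the function moves indices to the GPU internally; the shapes and cover fraction are illustrative only.

to_move = th.randn(128, 2).cuda()
to_match = th.randn(512, 2).cuda()
loss, rejected_mask = unbalanced_transport(to_move, to_match, cover_fraction=0.25)
# loss: transport cost over the covered part of to_match
# rejected_mask: 1 for points of to_match that received no mass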
Example #2
def train_epoch(self,
                inputs,
                targets,
                inputs_u,
                linear_weights_u,
                trans_loss_function,
                directions_adv,
                n_dir_matrices=1):
    loss = 0
    n_examples = 0
    for batch_X, batch_y in self.iterator.get_batches(
            inputs, targets, inputs_u, linear_weights_u):
        if n_dir_matrices > 0:
            # sample fresh random direction matrices for this batch and
            # optionally append the adversarially learned directions
            dir_mats = [
                sample_directions(self.means_per_dim.size()[1],
                                  True,
                                  cuda=batch_X.is_cuda)
                for _ in range(n_dir_matrices)
            ]
            directions = th.cat(dir_mats, dim=0)
            if directions_adv is not None:
                directions = th.cat((directions, directions_adv), dim=0)
        else:
            directions = directions_adv
        batch_loss = train_on_batch(batch_X, self.model,
                                    self.means_per_dim, self.stds_per_dim,
                                    batch_y, self.optimizer, directions,
                                    trans_loss_function)
        loss = loss + batch_loss * len(batch_X)
        n_examples = n_examples + batch_X.size()[0]
    # example-weighted mean loss over the epoch
    mean_loss = var_to_np(loss / n_examples)[0]
    return mean_loss
Example #3
def optimize_v_adaptively(outs, v, sample_fn_opt, sample_fn_bin_dev,
                          bin_dev_threshold,
                          bin_dev_iters):
    # Optimize V
    n_updates_total = 0
    outs = outs.detach()
    gauss_samples = sample_fn_bin_dev()
    diffs = th.sum((outs.unsqueeze(dim=1) - gauss_samples.unsqueeze(dim=0)) ** 2, dim=2)
    # heuristic initial learning rate scaled by the mean nearest-sample distance
    init_lr = float(var_to_np(th.mean(th.min(diffs, dim=1)[0]))[0] * len(outs) / 50)
    optim_v_orig = th.optim.SGD([v], lr=init_lr)
    optim_v = ScheduledOptimizer(DivideSqrtUpdates(), optim_v_orig, True)

    i_updates, avg_v = optimize_v_optimizer(
        v, optim_v, outs.detach(), sample_fn_opt, max_iters=25)
    v.data = avg_v
    n_updates_total += i_updates + 1
    for _ in range(10):
        v.data = avg_v
        # keep re-optimizing v until the matched samples are spread evenly
        # enough across outputs (low mean absolute bin deviation)
        bincounts = sample_match_and_bincount(outs, v, sample_fn_bin_dev, iters=bin_dev_iters)
        bin_dev = np.mean(np.abs(bincounts - np.mean(bincounts)))
        if bin_dev < bin_dev_threshold:
            break
        i_updates, avg_v = optimize_v_optimizer(
            v, optim_v, outs.detach(), sample_fn_opt, max_iters=20)
        n_updates_total += i_updates + 1
        v.data = avg_v
    return bincounts, bin_dev, n_updates_total
Example #4
def transport_mat_from_diffs(diffs):
    transport_mat = ot.emd([], [], var_to_np(diffs))
    # zero out the spuriously small entries the EMD solver sometimes returns
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))

    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    return transport_mat
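As a point of reference, a stand-alone sketch of the same idea using torch and the POT library directly, without the project's var_to_np / np_to_var / ensure_on_same_device wrappers; the sample sizes are arbitrary.

import numpy as np
import ot
import torch as th

samples_a = th.randn(64, 2)
samples_b = th.randn(64, 2)
diffs = th.sum((samples_a.unsqueeze(1) - samples_b.unsqueeze(0)) ** 2, dim=2)

# [] means uniform source / target weights in POT's ot.emd
t_np = ot.emd([], [], diffs.numpy().astype(np.float64))
# drop the spuriously small entries, as transport_mat_from_diffs does
t_np = t_np * (t_np > 1.0 / diffs.numel())
transport_mat = th.from_numpy(t_np).float()
emd_cost = th.sum(transport_mat * diffs)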
Example #5
def collect_out_to_samples(diffs, v):
    # assign each sample to the output that minimizes diffs - v
    min_diffs, inds = th.min(diffs - v.unsqueeze(1), dim=0)
    inds = var_to_np(inds)
    i_example_to_i_samples = [[] for _ in range(diffs.size()[0])]
    i_example_to_diffs = [[] for _ in range(diffs.size()[0])]
    for i_sample, i_out in enumerate(inds):
        i_example_to_i_samples[i_out].append(i_sample)
        i_example_to_diffs[i_out].append(min_diffs[i_sample])
    return i_example_to_i_samples, i_example_to_diffs
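A small hypothetical call, assuming collect_out_to_samples and the var_to_np helper are in scope; diffs is (n_outputs, n_samples) and v holds one potential value per output.

diffs = th.sum((th.randn(4, 2).unsqueeze(1) - th.randn(10, 2).unsqueeze(0)) ** 2, dim=2)
v = th.zeros(4)
i_example_to_i_samples, i_example_to_diffs = collect_out_to_samples(diffs, v)
# i_example_to_i_samples[k] lists the sample indices whose minimizer of
# diffs[:, i_sample] - v is output k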
Example #6
def ot_euclidean_transport_mat(samples_a, samples_b):
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sqrt(th.clamp(th.sum(diffs * diffs, dim=2), min=1e-6))

    transport_mat = ot.emd([], [], var_to_np(diffs))
    # zero out the spuriously small entries the EMD solver sometimes returns
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))

    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    return transport_mat
Example #7
def ot_emd_loss_for_samples(samples_a, samples_b):
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sum(diffs * diffs, dim=2)

    transport_mat = ot.emd([], [], var_to_np(diffs))
    # zero out the spuriously small entries the EMD solver sometimes returns
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))

    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    eps = 1e-6
    loss = th.sqrt(th.sum(transport_mat * diffs) + eps)
    return loss
Example #8
def ot_emd_loss(outs, mean, std):
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
    del gauss_samples
    diffs = th.sum(diffs * diffs, dim=2)

    transport_mat = ot.emd([], [], var_to_np(diffs))
    # zero out the spuriously small entries the EMD solver sometimes returns
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))

    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    eps = 1e-6
    loss = th.sqrt(th.sum(transport_mat * diffs) + eps)
    return loss
Example #9
def unbalanced_transport_mat_squared_diff(samples_a, samples_b, cover_fraction,
                                          return_diffs=False):
    diffs = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    diffs = th.sum(diffs * diffs, dim=2)
    # add a dummy point with distance 0 to every target point
    dummy = th.zeros_like(diffs[0:1, :])
    diffs = th.cat((diffs, dummy), dim=0)
    # source marginal: cover_fraction of uniform mass on the real points,
    # the remaining 1 - cover_fraction on the dummy point
    a = np.ones(len(samples_a)) / len(samples_a) * cover_fraction
    a = np.concatenate((a, [1 - cover_fraction]))
    transport_mat = ot.emd(a, [], var_to_np(diffs))
    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    transport_mat, diffs = ensure_on_same_device(transport_mat, diffs)
    if return_diffs:
        return transport_mat, diffs
    else:
        return transport_mat
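A stand-alone sketch of the dummy-point construction above, written against torch and POT directly rather than the project's wrappers; the sample counts and cover_fraction are illustrative.

import numpy as np
import ot
import torch as th

samples_a = th.randn(50, 2)
samples_b = th.randn(200, 2)
cover_fraction = 0.25

diffs = th.sum((samples_a.unsqueeze(1) - samples_b.unsqueeze(0)) ** 2, dim=2)
# dummy source point with zero cost to every target absorbs the uncovered mass
diffs = th.cat((diffs, th.zeros_like(diffs[0:1, :])), dim=0)
a = np.ones(len(samples_a)) / len(samples_a) * cover_fraction
a = np.concatenate((a, [1 - cover_fraction]))  # sums to 1
t_np = ot.emd(a, [], diffs.numpy().astype(np.float64))  # [] = uniform target marginal
transport_mat = th.from_numpy(t_np).float()
# rows :-1 describe how samples_a covers a fraction of samples_b
partial_cost = th.sum(transport_mat[:-1] * diffs[:-1])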
Example #10
def get_batch(
    inputs,
    targets,
    rng,
    batch_size,
    with_replacement,
    i_class='all',
):
    if i_class == 'all':
        indices = list(range(len(inputs)))
    else:
        # targets are one-hot encoded; restrict sampling to the given class
        indices = np.flatnonzero(var_to_np(targets[:, i_class]) == 1)
    batch_inds = rng.choice(indices, size=batch_size, replace=with_replacement)
    th_inds = np_to_var(batch_inds, dtype=np.int64)
    th_inds, _ = ensure_on_same_device(th_inds, inputs)
    batch_X = inputs[th_inds]
    batch_y = targets[th_inds]
    return th_inds, batch_X, batch_y
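A hypothetical call, assuming np_to_var and get_batch are in scope and the targets are one-hot; the shapes, seed, and class index are made up for illustration.

rng = np.random.RandomState(20394)
inputs = np_to_var(np.random.randn(1000, 32).astype(np.float32), dtype=np.float32)
targets = np_to_var(np.eye(10, dtype=np.float32)[np.random.randint(0, 10, 1000)],
                    dtype=np.float32)
th_inds, batch_X, batch_y = get_batch(
    inputs, targets, rng, batch_size=64,
    with_replacement=False, i_class=3)  # draw only examples of class 3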
Example #11
def ot_euclidean_loss(outs, mean, std, normalize_by_global_emp_std=False):
    gauss_samples = get_gauss_samples(len(outs), mean, std)

    diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
    del gauss_samples
    if normalize_by_global_emp_std:
        # scale distances by the mean per-dimension std of the outputs
        global_emp_std = th.mean(th.std(outs, dim=0))
        diffs = diffs / global_emp_std
    diffs = th.sqrt(th.clamp(th.sum(diffs * diffs, dim=2), min=1e-6))

    transport_mat = ot.emd([], [], var_to_np(diffs))
    # zero out the spuriously small entries the EMD solver sometimes returns
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))

    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    loss = th.sum(transport_mat * diffs)
    return loss
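For completeness, a self-contained sketch of the same loss without the get_gauss_samples / var_to_np / np_to_var helpers, using torch and POT directly; the shapes and the standard-normal prior are illustrative.

import numpy as np
import ot
import torch as th

outs = th.randn(128, 16)                                # network outputs
mean, std = th.zeros(16), th.ones(16)
gauss_samples = th.randn(len(outs), 16) * std + mean    # stand-in for get_gauss_samples

diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
diffs = th.sqrt(th.clamp(th.sum(diffs * diffs, dim=2), min=1e-6))
t_np = ot.emd([], [], diffs.detach().numpy().astype(np.float64))
transport_mat = th.from_numpy(t_np).float()
loss = th.sum(transport_mat * diffs)   # the quantity ot_euclidean_loss returns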