Example #1
def radial_distance_loss_from_samples_for_gauss_dist(outs, mean, std,
                                                     n_test_samples,
                                                     adv_samples):
    if (n_test_samples is not None) and n_test_samples > 0:
        test_samples = get_gauss_samples(n_test_samples, mean, std)
    else:
        test_samples = None
    if adv_samples is not None:
        if test_samples is not None:
            test_samples = th.cat((test_samples, adv_samples), dim=0)
        else:
            test_samples = adv_samples
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    return radial_distance_loss_from_samples_for_test_samples(
        outs, gauss_samples, test_samples)
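# The snippets in these examples rely on a shared helper, get_gauss_samples, and on
# the aliases th (torch), np (numpy) and ot (the POT package). The helper itself is
# not shown here; a minimal sketch of its assumed behaviour (i.i.d. draws from a
# diagonal Gaussian, with a guessed truncation convention) would be:
import numpy as np
import ot
import torch as th


def get_gauss_samples_sketch(n_samples, mean, std, truncate_to=None):
    # standard-normal noise of shape (n_samples, n_dims)
    noise = th.randn(n_samples, len(mean), device=mean.device)
    if truncate_to is not None:
        # assumption: clip the noise to +/- truncate_to standard deviations
        noise = th.clamp(noise, -truncate_to, truncate_to)
    return mean.unsqueeze(0) + std.unsqueeze(0) * noise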
Example #2
def get_amp_phase_samples(n_samples, mean, std, phase_dist, truncate_to):
    assert phase_dist in ["gauss", "uni"]
    i_half = len(mean) // 2

    amps = get_gauss_samples(n_samples,
                             mean[:i_half],
                             std[:i_half],
                             truncate_to=truncate_to)
    #amps = th.abs(amps)
    if phase_dist == 'uni':
        phases = get_uniform_samples(n_samples, mean[i_half:],
                                     std[i_half:] * 2 * np.pi)
    else:
        assert phase_dist == 'gauss'
        phases = get_gauss_samples(n_samples,
                                   mean[i_half:],
                                   std[i_half:] * 0.5 * np.pi,
                                   truncate_to=truncate_to * 0.5 * np.pi)

    return amps, phases
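# The amplitudes and phases returned above are separate tensors; the source does not
# show how they are consumed. One plausible (assumed) use is to recombine them into
# the real and imaginary parts of complex Fourier coefficients:
def amp_phase_to_real_imag_sketch(amps, phases):
    # real and imaginary parts of amps * exp(i * phases)
    return amps * th.cos(phases), amps * th.sin(phases)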
Example #3
def optimize_v(outs_main, means_per_cluster, stds_per_cluster, v, i_class,
               n_wanted_stds, norm_std_to):
    means_reduced, stds_reduced, largest_stds = reduce_dims_to_large_stds(
        means_per_cluster, stds_per_cluster, i_class, n_wanted_stds, norm_std_to=norm_std_to)
    outs_reduced = outs_main.index_select(dim=1, index=largest_stds)
    sample_fn_dt = lambda: get_gauss_samples(
        len(outs_reduced) * 1, means_reduced.detach(), stds_reduced.detach())
    bincounts, bin_dev, n_updates = optimize_v_adaptively(
        outs_reduced.detach(), v, sample_fn_dt, sample_fn_dt,
        bin_dev_threshold=0.75, bin_dev_iters=5)
    return bincounts, bin_dev, n_updates
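# reduce_dims_to_large_stds and optimize_v_adaptively are not shown. The dimension
# selection feeding index_select above can be sketched with plain torch (hypothetical
# helper, only meant to illustrate the pattern):
def select_largest_std_dims_sketch(outs, stds, n_wanted_stds):
    # indices of the n_wanted_stds dimensions with the largest standard deviation
    largest_stds = th.argsort(stds, descending=True)[:n_wanted_stds]
    return outs.index_select(dim=1, index=largest_stds), largest_stds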
Example #4
def sample_wavelet(n_samples,
                   mean,
                   log_std,
                   truncate_to=3,
                   convert_fn=convert_haar_wavelet):
    wavelet_samples = []
    # sample the first (coarsest) coefficient separately, outside the loop
    mean_sample = get_gauss_samples(n_samples,
                                    mean[0:1],
                                    th.exp(log_std[0:1]),
                                    truncate_to=truncate_to)
    wavelet_samples.append(mean_sample)
    for i_exp in range(int(np.log2(len(mean)))):
        i_start = int(2**i_exp)
        i_stop = int(2**(i_exp + 1))
        this_mean = mean[i_start:i_stop]
        this_log_std = log_std[i_start:i_stop]
        this_samples = get_gauss_samples(n_samples,
                                         this_mean,
                                         th.exp(this_log_std),
                                         truncate_to=truncate_to)
        wavelet_samples.append(this_samples)
    return convert_fn(wavelet_samples)
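# convert_haar_wavelet is not shown. The per-scale samples collected above (shapes
# (n, 1), (n, 1), (n, 2), (n, 4), ...) look like Haar coefficients, so a converter
# of this kind might perform an orthonormal inverse Haar transform; a sketch:
def convert_haar_wavelet_sketch(wavelet_samples):
    approx = wavelet_samples[0]
    for detail in wavelet_samples[1:]:
        # each parent coefficient splits into two children: (a + d) and (a - d)
        left = (approx + detail) / np.sqrt(2)
        right = (approx - detail) / np.sqrt(2)
        approx = th.stack((left, right), dim=2).view(len(approx), -1)
    return approx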
Example #5
def ot_emd_loss(outs, mean, std):
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
    del gauss_samples
    diffs = th.sum(diffs * diffs, dim=2)

    transport_mat = ot.emd([], [], var_to_np(diffs))
    # the EMD solver sometimes returns spuriously small entries; zero them out
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))

    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    eps = 1e-6
    loss = th.sqrt(th.sum(transport_mat * diffs) + eps)
    return loss
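# ot.emd from the POT library returns the exact optimal transport plan for a given
# cost matrix; passing empty lists for the weights (as above) means uniform weights
# over both sample sets. A tiny standalone illustration of that call:
cost = np.array([[0.0, 1.0],
                 [1.0, 0.0]])
plan = ot.emd([], [], cost)  # uniform weights -> plan [[0.5, 0.0], [0.0, 0.5]]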
Example #6
def get_perturbations(l_out, norm, perturb_dist):
    dims = int(np.prod(l_out.size()[1:]))
    mean = th.zeros(dims)
    std = th.ones(dims)
    mean = th.autograd.Variable(mean)
    std = th.autograd.Variable(std)
    _, mean, std = ensure_on_same_device(l_out, mean, std)
    if perturb_dist == 'uniform':
        perturbations = get_uniform_samples(l_out.size()[0], mean, std)
    else:
        assert perturb_dist == 'gaussian'
        perturbations = get_gauss_samples(l_out.size()[0], mean, std)

    perturbations = norm * (perturbations / th.sqrt(th.sum(perturbations ** 2, dim=1, keepdim=True)))
    perturbations = perturbations.view(l_out.size())
    return perturbations
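# Usage sketch for get_perturbations (dummy tensor, assuming the helpers used above
# are importable): every flattened perturbation ends up with L2 norm equal to `norm`.
l_out_demo = th.zeros(4, 3, 8, 8)
pert_demo = get_perturbations(l_out_demo, norm=0.5, perturb_dist='gaussian')
# th.sqrt(th.sum(pert_demo.view(4, -1) ** 2, dim=1)) gives 0.5 for every row
# (up to floating point)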
Example #7
def ot_euclidean_loss(outs, mean, std, normalize_by_global_emp_std=False):
    gauss_samples = get_gauss_samples(len(outs), mean, std)

    diffs = outs.unsqueeze(1) - gauss_samples.unsqueeze(0)
    del gauss_samples
    if normalize_by_global_emp_std:
        global_emp_std = th.mean(th.std(outs, dim=0))
        diffs = diffs / global_emp_std
    diffs = th.sqrt(th.clamp(th.sum(diffs * diffs, dim=2), min=1e-6))

    transport_mat = ot.emd([], [], var_to_np(diffs))
    # the EMD solver sometimes returns spuriously small entries; zero them out
    transport_mat = transport_mat * (transport_mat > (1.0 / (diffs.numel())))

    transport_mat = np_to_var(transport_mat, dtype=np.float32)
    diffs, transport_mat = ensure_on_same_device(diffs, transport_mat)
    loss = th.sum(transport_mat * diffs)
    return loss
Example #8
def sample_hierarchically(n_samples, mean, log_stds):
    cur_mean = mean
    covs = th.zeros((len(cur_mean), len(cur_mean)), dtype=th.float32)
    hierarchical_samples = []
    for i_exp in range(int(np.log2(len(cur_mean))) + 1):
        cur_mean = th.stack(th.chunk(cur_mean, int(2**i_exp)))
        this_mean = th.mean(cur_mean, dim=1, keepdim=True)
        cur_mean = cur_mean - this_mean
        cur_mean = cur_mean.view(-1)
        this_log_std = log_stds[i_exp]
        # sample this level: one value per chunk, centered on the chunk means
        this_samples = get_gauss_samples(n_samples, this_mean.squeeze(-1),
                                         th.exp(this_log_std).squeeze(-1))
        hierarchical_samples.append(this_samples)
        # accumulate the implied covariance: each level adds its variance to its chunk's diagonal block
        for i_part in range(2**i_exp):
            i_1, i_2 = int((i_part / 2**i_exp) * len(covs)), int(
                ((i_part + 1) / 2**i_exp) * len(covs))
            covs[i_1:i_2, i_1:i_2] += (th.exp(this_log_std[i_part])**2)
    samples = convert_hierarchical_samples_to_samples(hierarchical_samples,
                                                      len(mean))
    return samples, covs
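# convert_hierarchical_samples_to_samples is not shown. Given the per-level shapes
# produced above ((n, 1), (n, 2), ..., (n, len(mean))), a plausible (assumed)
# reconstruction broadcasts each level back to full dimensionality and sums:
def convert_hierarchical_samples_sketch(hierarchical_samples, n_dims):
    total = 0
    for level in hierarchical_samples:
        # repeat each per-chunk sample across all dimensions of its chunk
        total = total + level.repeat_interleave(n_dims // level.shape[1], dim=1)
    return total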
Example #9
def sliced_from_samples_for_gauss_dist(outs, mean, std, n_dirs, adv_dirs,
                                       **kwargs):
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    return sliced_from_samples(outs, gauss_samples, n_dirs, adv_dirs, **kwargs)
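# sliced_from_samples is not shown. A standalone sketch of the sliced distance it
# presumably computes: project both (equally sized) sample sets onto random unit
# directions, then compare the sorted 1D projections; adv_dirs would supply extra,
# e.g. adversarially chosen, directions:
def sliced_distance_sketch(samples_a, samples_b, n_dirs):
    dirs = th.randn(n_dirs, samples_a.shape[1], device=samples_a.device)
    dirs = dirs / th.norm(dirs, dim=1, keepdim=True)  # unit-length directions
    proj_a = th.sort(samples_a @ dirs.t(), dim=0)[0]  # sorted projections per direction
    proj_b = th.sort(samples_b @ dirs.t(), dim=0)[0]
    return th.mean((proj_a - proj_b) ** 2)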
Example #10
def ot_euclidean_energy_loss(outs, mean, std):
    gauss_samples = get_gauss_samples(len(outs), mean, std)
    o1, o2 = th.chunk(outs, 2, dim=0)
    g1, g2 = th.chunk(gauss_samples, 2, dim=0)
    return ot_eucledian_energy_loss(o1, o2, g1, g2)
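# ot_eucledian_energy_loss is not shown (the identifier's spelling is kept as in the
# source). Splitting both sample sets into two independent halves is exactly what an
# energy distance estimator needs; a hedged sketch of such an estimator:
def energy_distance_sketch(o1, o2, g1, g2):
    def mean_dist(x, y):
        # mean pairwise Euclidean distance between two sample sets
        return th.mean(th.norm(x.unsqueeze(1) - y.unsqueeze(0), dim=2))
    # 2 E||O - G|| - E||O - O'|| - E||G - G'||
    return 2 * mean_dist(o1, g1) - mean_dist(o1, o2) - mean_dist(g1, g2)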