# Example #1
# 0
def mutual_info_metric_cars3d(vae, shapes_dataset):
    """Estimate a mutual-information disentanglement metric on Cars3D-style data.

    Every item of ``shapes_dataset`` is encoded in order (batches of 1000,
    no shuffling) to get the parameters of q(z|x). One latent sample is drawn
    per item. Entropies of the latent dimensions are then estimated with
    ``estimate_entropies`` in two ways:

    * marginal: once over the whole dataset;
    * conditional: for each of the two grid factors, averaged over the
      subsets that share one value of that factor.

    The dataset is reshaped to a ``(-1, 4, 24)`` factor grid, so
    ``len(shapes_dataset)`` must be a multiple of 96. The full 183-object set
    has 17568 = 183 * 4 * 24 items. Items are viewed as ``(3, 64, 64)`` and
    moved to CUDA.

    Args:
        vae: model providing ``z_dim``, ``q_dist`` (``nparams``, ``sample``),
            ``encoder`` and ``var_clipping``. When clipping is enabled it must
            also provide ``lowerbound_ln_var`` / ``upperbound_ln_var``. The
            model is switched to eval mode and left there.
        shapes_dataset: dataset whose items are image tensors.

    Returns:
        ``(metric, marginal_entropies, cond_entropies)``:

        * ``metric`` is whatever ``compute_metric_faces`` returns;
        * ``marginal_entropies`` is the CPU result of ``estimate_entropies``
          over all samples (presumably shape ``(K,)``);
        * ``cond_entropies`` is a ``(3, K)`` CPU tensor. Only rows 0 and 1
          are filled.

    Raises:
        RuntimeError: from ``view`` if the dataset size is not a multiple of
            the 4 * 24 grid.
    """
    dataset_loader = DataLoader(shapes_dataset,
                                batch_size=1000,
                                num_workers=1,
                                shuffle=False)

    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.z_dim  # number of latent variables
    q_dist = vae.q_dist
    nparams = q_dist.nparams
    vae.eval()

    print('Computing q(z|x) distributions.')
    # Uninitialised buffer; every row is overwritten by the loop below.
    qz_params = torch.Tensor(N, K, nparams)

    n = 0
    with torch.no_grad():
        for xs in dataset_loader:
            batch_size = xs.size(0)
            xs = xs.view(batch_size, 3, 64, 64).cuda()
            qz_params[n:n + batch_size] = vae.encoder.forward(xs).view(
                batch_size, K, nparams)
            if vae.var_clipping:
                # Clamp parameter index 1 into the configured ln-variance range.
                qz_params[n:n + batch_size, ..., 1] = torch.clamp(
                    qz_params[n:n + batch_size, ..., 1],
                    vae.lowerbound_ln_var, vae.upperbound_ln_var)
            n += batch_size

    # Lay the samples out on the factor grid. The leading axis is inferred
    # from N instead of hard-coding 183 objects.
    qz_params = qz_params.view(-1, 4, 24, K, nparams).cuda()
    qz_samples = q_dist.sample(params=qz_params)

    print('Estimating marginal entropies.')
    # marginal entropies
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1), qz_params.view(N, K, nparams),
        q_dist).cpu()

    cond_entropies = torch.zeros(3, K)

    # Each entry: (row of cond_entropies, progress label, grid axis, number of values).
    # NOTE(review): the size-4 axis is labelled 'yaw' and the size-24 axis
    # 'pitch'; confirm this matches the dataset's factor ordering.
    factors = ((0, 'yaw', 1, 4), (1, 'pitch', 2, 24))
    for row, label, axis, size in factors:
        print('Estimating conditional entropies for {}.'.format(label))
        for i in range(size):
            # Latents of every item whose factor on `axis` takes value i.
            samples_i = qz_samples.select(axis, i).contiguous()
            params_i = qz_params.select(axis, i).contiguous()
            cond_entropies_i = estimate_entropies(
                samples_i.view(N // size, K).transpose(0, 1),
                params_i.view(N // size, K, nparams), q_dist)
            cond_entropies[row] += cond_entropies_i.cpu() / size

    # NOTE(review): row 2 of cond_entropies is never written here, yet the
    # whole 3-row tensor is passed to compute_metric_faces (the helper the
    # faces variants use). Confirm its per-factor handling suits cars3d.
    metric = compute_metric_faces(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies
# Example #2
# 0
def mutual_info_metric_faces(vae, shapes_dataset):
    """Estimate a mutual-information disentanglement metric on the faces data.

    Every item of ``shapes_dataset`` is encoded in order (batches of 1000,
    no shuffling) with ``vae.encoder`` to get the parameters of q(z|x).
    One latent sample is drawn per item. Entropies of the latent dimensions
    are then estimated with ``estimate_entropies`` in two ways:

    * marginal: over the whole dataset;
    * conditional: for each of the azimuth / elevation / lighting grid axes,
      averaged over the subsets sharing one value of that axis.

    The dataset is reshaped to a ``(-1, 21, 11, 11)`` factor grid, so its
    length must be a multiple of 2541. Items are viewed as ``(1, 64, 64)``;
    this function does not move them to any device.

    Args:
        vae: model providing ``z_dim``, ``q_dist`` (``nparams``, ``sample``)
            and ``encoder``. It is switched to eval mode and left there.
        shapes_dataset: dataset whose items are image tensors.

    Returns:
        ``(metric, marginal_entropies, cond_entropies)``:

        * ``metric`` comes from ``compute_metric_faces``;
        * ``marginal_entropies`` is a CPU tensor;
        * ``cond_entropies`` is a ``(3, K)`` CPU tensor with one row per
          factor.

    Raises:
        RuntimeError: from ``view`` if the dataset size is not a multiple of
            the 21 * 11 * 11 grid.
    """
    dataset_loader = DataLoader(shapes_dataset,
                                batch_size=1000,
                                num_workers=1,
                                shuffle=False)

    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.z_dim  # number of latent variables
    q_dist = vae.q_dist
    nparams = q_dist.nparams
    vae.eval()

    print('Computing q(z|x) distributions.')
    # Uninitialised buffer; every row is overwritten by the loop below.
    qz_params = torch.Tensor(N, K, nparams)

    n = 0
    # `Variable(..., volatile=True)` is ignored since PyTorch 0.4;
    # no_grad() is what actually stops the encoder building a graph.
    with torch.no_grad():
        for xs in dataset_loader:
            batch_size = xs.size(0)
            xs = xs.view(batch_size, 1, 64, 64)
            qz_params[n:n + batch_size] = vae.encoder.forward(xs).view(
                batch_size, K, nparams)
            n += batch_size

    # Lay the samples out on the factor grid; the leading axis follows N.
    qz_params = qz_params.view(-1, 21, 11, 11, K, nparams)
    qz_samples = q_dist.sample(params=qz_params)

    print('Estimating marginal entropies.')
    # marginal entropies
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1), qz_params.view(N, K, nparams),
        q_dist).cpu()

    cond_entropies = torch.zeros(3, K)

    # Each entry: (row of cond_entropies, progress label, grid axis, number of values).
    factors = (
        (0, 'azimuth', 1, 21),
        (1, 'elevation', 2, 11),
        (2, 'lighting', 3, 11),
    )
    for row, label, axis, size in factors:
        print('Estimating conditional entropies for {}.'.format(label))
        for i in range(size):
            # Latents of every item whose factor on `axis` takes value i.
            samples_i = qz_samples.select(axis, i).contiguous()
            params_i = qz_params.select(axis, i).contiguous()
            cond_entropies_i = estimate_entropies(
                samples_i.view(N // size, K).transpose(0, 1),
                params_i.view(N // size, K, nparams), q_dist)
            cond_entropies[row] += cond_entropies_i.cpu() / size

    metric = compute_metric_faces(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies
# Example #3
# 0
def mutual_info_metric_faces(vae, shapes_dataset):
    """Estimate a mutual-information disentanglement metric on the faces data.

    This variant works for a model whose forward pass ``vae(xs)`` returns a
    4-tuple. The 2nd and 3rd items (``mu``, ``logvar``) are stacked as the
    parameters of a ``dist.Normal`` posterior over ``K = 10`` latents.

    One latent sample is drawn per item. Entropies are then estimated with
    ``estimate_entropies`` in two ways:

    * marginal: over the whole dataset;
    * conditional: for each of the azimuth / elevation / lighting grid axes,
      averaged over the subsets sharing one value of that axis.

    The dataset is read in order and reshaped to a ``(-1, 21, 11, 11)``
    factor grid, so its length must be a multiple of 2541. Items are viewed
    as ``(1, 64, 64)`` and moved to CUDA.

    Args:
        vae: callable model as described above. It is switched to eval mode
            and left there.
        shapes_dataset: dataset whose items are image tensors.

    Returns:
        ``(metric, marginal_entropies, cond_entropies)``:

        * ``metric`` comes from ``compute_metric_faces``;
        * ``marginal_entropies`` is a CPU tensor;
        * ``cond_entropies`` is a ``(3, K)`` CPU tensor with one row per
          factor.

    Raises:
        RuntimeError: from ``view`` if the dataset size is not a multiple of
            the 21 * 11 * 11 grid, or if ``mu``/``logvar`` do not hold K
            values per item.
    """
    dataset_loader = DataLoader(shapes_dataset, batch_size=1000, num_workers=1, shuffle=False)

    N = len(dataset_loader.dataset)  # number of data samples
    K = 10                    # number of latent variables
    q_dist = dist.Normal()
    nparams = q_dist.nparams
    vae.eval()

    print('Computing q(z|x) distributions.')
    # Uninitialised buffer; every row is overwritten by the loop below.
    qz_params = torch.Tensor(N, K, nparams)

    n = 0
    # Inference only: without no_grad() each forward pass built an unused
    # autograd graph.
    with torch.no_grad():
        for xs in dataset_loader:
            batch_size = xs.size(0)
            xs = xs.view(batch_size, 1, 64, 64).cuda()

            _z, mu, logvar, _y = vae(xs)
            # Stack mean and log-variance along a new last axis: (batch, K, 2).
            target = torch.cat([mu.view(batch_size, K, 1),
                                logvar.view(batch_size, K, 1)], dim=2)
            qz_params[n:n + batch_size] = target.view(batch_size, K, nparams)
            n += batch_size

    # Lay the samples out on the factor grid; the leading axis follows N.
    qz_params = qz_params.view(-1, 21, 11, 11, K, nparams).cuda()
    qz_samples = q_dist.sample(params=qz_params)

    print('Estimating marginal entropies.')
    # marginal entropies
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1),
        qz_params.view(N, K, nparams),
        q_dist).cpu()

    cond_entropies = torch.zeros(3, K)

    # Each entry: (row of cond_entropies, progress label, grid axis, number of values).
    factors = (
        (0, 'azimuth', 1, 21),
        (1, 'elevation', 2, 11),
        (2, 'lighting', 3, 11),
    )
    for row, label, axis, size in factors:
        print('Estimating conditional entropies for {}.'.format(label))
        for i in range(size):
            # Latents of every item whose factor on `axis` takes value i.
            samples_i = qz_samples.select(axis, i).contiguous()
            params_i = qz_params.select(axis, i).contiguous()
            cond_entropies_i = estimate_entropies(
                samples_i.view(N // size, K).transpose(0, 1),
                params_i.view(N // size, K, nparams),
                q_dist)
            cond_entropies[row] += cond_entropies_i.cpu() / size

    metric = compute_metric_faces(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies