Example #1
File: losses.py Project: yyht/rrws
def get_defensive_wake_phi_loss(generative_model,
                                inference_network,
                                obs,
                                delta,
                                num_particles=1):
    """Wake-phi loss with a defensive proposal: with probability delta each
    latent is drawn uniformly over the mixture components instead of from the
    inference network.

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
    """
    num_mixtures = inference_network.num_mixtures
    batch_size = len(obs)
    latent_dist = inference_network.get_latent_dist(obs)
    latent = inference_network.sample_from_latent_dist(latent_dist,
                                                       num_particles)
    # [num_particles, batch_size, num_mixtures]: one-hot samples from a
    # uniform categorical over the mixture components
    latent_uniform = torch.distributions.OneHotCategorical(
        logits=torch.ones((num_mixtures, ))).sample(
            (num_particles, batch_size))

    # stack the q-samples and the uniform samples along a new leading dim
    # ([2, num_particles, batch_size, num_mixtures]) and, per element, pick
    # the uniform sample with probability delta (Bernoulli index 1) and the
    # q-sample otherwise (index 0)
    catted = torch.cat([x.unsqueeze(0) for x in [latent, latent_uniform]],
                       dim=0)
    indices = torch.distributions.Bernoulli(probs=torch.tensor(delta)).sample(
        (num_particles, batch_size)).unsqueeze(0).unsqueeze(-1).expand(
            1, num_particles, batch_size, num_mixtures).long()
    latent_mixture = torch.gather(catted, 0, indices).squeeze(0)
    log_p = generative_model.get_log_prob(latent_mixture, obs).transpose(0, 1)
    log_latent = inference_network.get_log_prob_from_latent_dist(
        latent_dist, latent).transpose(0, 1)

    # log of the defensive mixture density:
    # log((1 - delta) * q(z | x) + delta / num_mixtures)
    log_uniform = torch.ones_like(log_latent) * (-np.log(num_mixtures))
    log_q_mixture = util.logaddexp(log_latent + np.log(1 - delta),
                                   log_uniform + np.log(delta))
    log_q = inference_network.get_log_prob_from_latent_dist(
        latent_dist, latent_mixture).transpose(0, 1)
    log_weight = log_p - log_q_mixture
    normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
    return torch.mean(-torch.sum(normalized_weight.detach() * log_q, dim=1))
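
These snippets all come from the same losses.py and assume torch, numpy (imported as np), and the project's util module are in scope. Two util helpers used throughout can be sketched as follows; this is their assumed behaviour, not the project's actual implementation:

import torch

def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)); assumed behaviour of util.logaddexp.
    return torch.logsumexp(torch.stack([a, b]), dim=0)

def exponentiate_and_normalize(values, dim=0):
    # Softmax of log-weights along dim, i.e. self-normalized importance
    # weights; assumed behaviour of util.exponentiate_and_normalize.
    return torch.softmax(values, dim=dim)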
Example #2
def get_wake_phi_loss_from_log_weight_and_log_q(log_weight, log_q):
    """Returns:
        loss: scalar that we call .backward() on and step the optimizer.
    """
    normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
    return torch.mean(-torch.sum(normalized_weight.detach() * log_q, dim=1))
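
A minimal, self-contained way to exercise this loss with dummy tensors, assuming util.exponentiate_and_normalize is a softmax over the particle dimension (the random tensors below stand in for real model outputs):

import torch
import torch.nn.functional as F

batch_size, num_particles = 4, 8
log_weight = torch.randn(batch_size, num_particles)                  # stands in for log p(z, x) - log q(z | x)
log_q = torch.randn(batch_size, num_particles, requires_grad=True)   # stands in for log q(z | x)

normalized_weight = F.softmax(log_weight, dim=1)  # assumed util.exponentiate_and_normalize
loss = torch.mean(-torch.sum(normalized_weight.detach() * log_q, dim=1))
loss.backward()  # gradients flow only through log_q; the weights are detached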
Example #3
def get_thermo_loss_different_samples(generative_model,
                                      inference_network,
                                      obs,
                                      partition=None,
                                      num_particles=1,
                                      integration='left'):
    """Thermo loss gradient estimator computed using two set of importance
        samples.

    Args:
        generative_model: models.GenerativeModel object
        inference_network: models.InferenceNetwork object
        obs: tensor of shape [batch_size]
        partition: partition of [0, 1];
            tensor of shape [num_partitions + 1] where partition[0] is zero and
            partition[-1] is one;
            see https://en.wikipedia.org/wiki/Partition_of_an_interval
        num_particles: int
        integration: left, right or trapz

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
        elbo: average elbo over data
    """

    log_weights, log_ps, log_qs, heated_normalized_weights = [], [], [], []
    for _ in range(2):
        log_weight, log_p, log_q = get_log_weight_log_p_log_q(
            generative_model, inference_network, obs, num_particles)
        log_weights.append(log_weight)
        log_ps.append(log_p)
        log_qs.append(log_q)

        heated_log_weight = log_weight.unsqueeze(-1) * partition
        heated_normalized_weights.append(
            util.exponentiate_and_normalize(heated_log_weight, dim=1))

    w_detached = heated_normalized_weights[0].detach()
    thermo_logp = partition * log_ps[0].unsqueeze(-1) + \
        (1 - partition) * log_qs[0].unsqueeze(-1)
    wf = heated_normalized_weights[1] * log_weights[1].unsqueeze(-1)

    if num_particles == 1:
        correction = 1
    else:
        correction = num_particles / (num_particles - 1)

    thing_to_add = correction * torch.sum(
        w_detached *
        (log_weight.unsqueeze(-1) -
         torch.sum(wf, dim=1, keepdim=True)).detach() * thermo_logp,
        dim=1)

    multiplier = torch.zeros_like(partition)
    if integration == 'trapz':
        multiplier[0] = 0.5 * (partition[1] - partition[0])
        multiplier[1:-1] = 0.5 * (partition[2:] - partition[0:-2])
        multiplier[-1] = 0.5 * (partition[-1] - partition[-2])
    elif integration == 'left':
        multiplier[:-1] = partition[1:] - partition[:-1]
    elif integration == 'right':
        multiplier[1:] = partition[1:] - partition[:-1]

    loss = -torch.mean(
        torch.sum(multiplier *
                  (thing_to_add +
                   torch.sum(w_detached * log_weight.unsqueeze(-1), dim=1)),
                  dim=1))

    log_evidence = torch.logsumexp(log_weight, dim=1) - np.log(num_particles)
    elbo = torch.mean(log_evidence)

    return loss, elbo
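
The multiplier tensor implements left and right Riemann weights and trapezoidal weights over the partition. A small standalone check of the three rules on a toy partition:

import torch

def quadrature_weights(partition, integration='left'):
    # Same three cases as in the loss above, factored out for inspection.
    multiplier = torch.zeros_like(partition)
    if integration == 'trapz':
        multiplier[0] = 0.5 * (partition[1] - partition[0])
        multiplier[1:-1] = 0.5 * (partition[2:] - partition[0:-2])
        multiplier[-1] = 0.5 * (partition[-1] - partition[-2])
    elif integration == 'left':
        multiplier[:-1] = partition[1:] - partition[:-1]
    elif integration == 'right':
        multiplier[1:] = partition[1:] - partition[:-1]
    return multiplier

partition = torch.tensor([0.0, 0.25, 0.5, 1.0])
for rule in ['left', 'right', 'trapz']:
    weights = quadrature_weights(partition, rule)
    # Each rule's weights sum to the length of the interval, here 1.0.
    print(rule, weights.tolist(), float(weights.sum()))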
Example #4
def get_thermo_loss_from_log_weight_log_p_log_q(log_weight,
                                                log_p,
                                                log_q,
                                                partition,
                                                num_particles=1,
                                                integration='left'):
    """Args:
        log_weight: tensor of shape [batch_size, num_particles]
        log_p: tensor of shape [batch_size, num_particles]
        log_q: tensor of shape [batch_size, num_particles]
        partition: partition of [0, 1];
            tensor of shape [num_partitions + 1] where partition[0] is zero and
            partition[-1] is one;
            see https://en.wikipedia.org/wiki/Partition_of_an_interval
        num_particles: int
        integration: left, right or trapz

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
        elbo: average elbo over data
    """
    heated_log_weight = log_weight.unsqueeze(-1) * partition
    heated_normalized_weight = util.exponentiate_and_normalize(
        heated_log_weight, dim=1)
    thermo_logp = partition * log_p.unsqueeze(-1) + \
        (1 - partition) * log_q.unsqueeze(-1)

    wf = heated_normalized_weight * log_weight.unsqueeze(-1)
    w_detached = heated_normalized_weight.detach()
    if num_particles == 1:
        correction = 1
    else:
        correction = num_particles / (num_particles - 1)

    thing_to_add = correction * torch.sum(
        w_detached * (log_weight.unsqueeze(-1) -
                      torch.sum(wf, dim=1, keepdim=True)).detach() *
        (thermo_logp -
         torch.sum(thermo_logp * w_detached, dim=1, keepdim=True)),
        dim=1)

    multiplier = torch.zeros_like(partition)
    if integration == 'trapz':
        multiplier[0] = 0.5 * (partition[1] - partition[0])
        multiplier[1:-1] = 0.5 * (partition[2:] - partition[0:-2])
        multiplier[-1] = 0.5 * (partition[-1] - partition[-2])
    elif integration == 'left':
        multiplier[:-1] = partition[1:] - partition[:-1]
    elif integration == 'right':
        multiplier[1:] = partition[1:] - partition[:-1]

    loss = -torch.mean(
        torch.sum(multiplier *
                  (thing_to_add +
                   torch.sum(w_detached * log_weight.unsqueeze(-1), dim=1)),
                  dim=1))

    log_evidence = torch.logsumexp(log_weight, dim=1) - np.log(num_particles)
    elbo = torch.mean(log_evidence)

    return loss, elbo
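
Both thermo losses take a partition tensor that must start at 0 and end at 1. Two common ways to construct one are uniform spacing and a schedule skewed toward 0; the helper names below are illustrative and not part of the project API:

import torch

def uniform_partition(num_partitions):
    # num_partitions intervals -> num_partitions + 1 points from 0 to 1.
    return torch.linspace(0, 1, steps=num_partitions + 1)

def log_spaced_partition(num_partitions, log10_beta_min=-10):
    # 0 followed by geometrically spaced points from 10**log10_beta_min to 1,
    # so the region near beta = 0 is covered densely.
    betas = torch.logspace(log10_beta_min, 0, steps=num_partitions)
    return torch.cat([torch.zeros(1), betas])

partition = uniform_partition(10)  # shape [11], partition[0] == 0, partition[-1] == 1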
Example #5
def get_mean_stds(generative_model, inference_network, num_mc_samples, obss,
                  num_particles):
    """Estimates the mean and std of several gradient estimators (VIMCO,
    REINFORCE, wake-phi, sleep, ...) over num_mc_samples Monte Carlo samples.
    """
    vimco_grad = util.OnlineMeanStd()
    vimco_one_grad = util.OnlineMeanStd()
    reinforce_grad = util.OnlineMeanStd()
    reinforce_one_grad = util.OnlineMeanStd()
    two_grad = util.OnlineMeanStd()
    log_evidence_stats = util.OnlineMeanStd()
    log_evidence_grad = util.OnlineMeanStd()
    wake_phi_loss_grad = util.OnlineMeanStd()
    log_Q_grad = util.OnlineMeanStd()
    sleep_loss_grad = util.OnlineMeanStd()

    for mc_sample_idx in range(num_mc_samples):
        util.print_with_time('MC sample {}'.format(mc_sample_idx))
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)
        log_evidence = torch.logsumexp(log_weight, dim=1) - \
            np.log(num_particles)
        avg_log_evidence = torch.mean(log_evidence)
        log_Q = torch.sum(log_q, dim=1)
        avg_log_Q = torch.mean(log_Q)
        # REINFORCE surrogate: score-function term plus the log-evidence term
        reinforce_one = torch.mean(log_evidence.detach() * log_Q)
        reinforce = reinforce_one + avg_log_evidence
        # VIMCO surrogate: the control variate for particle i is the
        # log-sum-exp of the other particles' log weights together with
        # their mean (which stands in for the held-out particle i)
        vimco_one = 0
        for i in range(num_particles):
            log_weight_ = log_weight[:, util.range_except(num_particles, i)]
            control_variate = torch.logsumexp(
                torch.cat([log_weight_, torch.mean(log_weight_, dim=1,
                                                   keepdim=True)], dim=1),
                dim=1)
            vimco_one = vimco_one + (log_evidence.detach() -
                                     control_variate.detach()) * log_q[:, i]
        vimco_one = torch.mean(vimco_one)
        vimco = vimco_one + avg_log_evidence
        normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
        wake_phi_loss = torch.mean(
            -torch.sum(normalized_weight.detach() * log_q, dim=1))

        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco.backward(retain_graph=True)
        vimco_grad.update([param.grad for param in
                           inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco_one.backward(retain_graph=True)
        vimco_one_grad.update([param.grad for param in
                               inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce.backward(retain_graph=True)
        reinforce_grad.update([param.grad for param in
                               inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce_one.backward(retain_graph=True)
        reinforce_one_grad.update([param.grad for param in
                                   inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_evidence.backward(retain_graph=True)
        two_grad.update([param.grad for param in
                         inference_network.parameters()])
        log_evidence_grad.update([param.grad for param in
                                  generative_model.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        wake_phi_loss.backward(retain_graph=True)
        wake_phi_loss_grad.update([param.grad for param in
                                   inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_Q.backward(retain_graph=True)
        log_Q_grad.update([param.grad for param in
                           inference_network.parameters()])

        log_evidence_stats.update([avg_log_evidence.unsqueeze(0)])

        sleep_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_particles * len(obss))
        inference_network.zero_grad()
        generative_model.zero_grad()
        sleep_loss.backward()
        sleep_loss_grad.update([param.grad for param in
                                inference_network.parameters()])

    return list(map(
        lambda x: x.avg_of_means_stds(),
        [vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
         two_grad, log_evidence_stats, log_evidence_grad, wake_phi_loss_grad,
         log_Q_grad, sleep_loss_grad]))
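
util.OnlineMeanStd is a project helper; from its use above, it accumulates lists of gradient tensors via update and reports summary statistics via avg_of_means_stds. A rough Welford-style sketch of such an accumulator, under those assumptions and not matching the project's exact implementation:

import torch

class OnlineMeanStd:
    def __init__(self):
        self.count = 0
        self.means = None  # running element-wise means, one tensor per tracked quantity
        self.m2s = None    # running sums of squared deviations (Welford)

    def update(self, tensors):
        tensors = [t.detach().flatten() for t in tensors]
        if self.means is None:
            self.means = [torch.zeros_like(t) for t in tensors]
            self.m2s = [torch.zeros_like(t) for t in tensors]
        self.count += 1
        for t, mean, m2 in zip(tensors, self.means, self.m2s):
            delta = t - mean
            mean += delta / self.count
            m2 += delta * (t - mean)

    def avg_of_means_stds(self):
        # Average the element-wise means and stds across all tracked tensors.
        mean = torch.cat(self.means).mean().item()
        var = torch.cat([m2 / max(self.count - 1, 1) for m2 in self.m2s])
        return mean, var.sqrt().mean().item()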
Example #6
def train_mws(
    generative_model,
    inference_network,
    data_loader,
    num_iterations,
    memory_size,
    true_cluster_cov,
    test_data_loader,
    test_num_particles,
    true_generative_model,
    checkpoint_path,
    reweighted=False,
):
    """Memoised wake-sleep (MWS) training loop: every observation keeps a
    small memory of its highest-scoring latents under the generative model;
    the memory is refreshed in the wake phase and used in the remember phase
    to update the generative model (theta) and inference network (phi).
    """
    optimizer = torch.optim.Adam(
        itertools.chain(generative_model.parameters(),
                        inference_network.parameters()))
    (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        train_kl_memory_ps,
        train_kl_memory_ps_true,
    ) = ([], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [])

    memory = {}
    data_loader_iter = iter(data_loader)

    for iteration in range(num_iterations):
        # get obs
        try:
            obs = next(data_loader_iter)
        except StopIteration:
            data_loader_iter = iter(data_loader)
            obs = next(data_loader_iter)

        theta_loss = 0
        phi_loss = 0
        for single_obs in obs:
            # key to index memory
            single_obs_key = tuple(single_obs.tolist())

            # populate memory if empty
            if (single_obs_key not in memory) or len(
                    memory[single_obs_key]) == 0:
                # batch shape [1] and event shape [num_data]
                latent_dist = inference_network.get_latent_dist(
                    single_obs.unsqueeze(0))

                # HACK
                while True:
                    # [memory_size, num_data]
                    latent = inference_network.sample_from_latent_dist(
                        latent_dist, memory_size).squeeze(1)
                    # list of M \in {1, ..., memory_size} elements
                    # could be less than memory_size because
                    # sampled elements can be duplicate
                    memory[single_obs_key] = list(
                        set([tuple(x.tolist()) for x in latent]))

                    if len(memory[single_obs_key]) == memory_size:
                        break

            # WAKE
            # batch shape [1] and event shape [num_data]
            latent_dist = inference_network.get_latent_dist(
                single_obs.unsqueeze(0))
            # [1, 1, num_data] -> [num_data]
            latent = inference_network.sample_from_latent_dist(latent_dist,
                                                               1).view(-1)
            # set (of size memory_size + 1) of tuples (of length num_data)
            memoized_latent_plus_current_latent = set(
                memory.get(single_obs_key, []) + [tuple(latent.tolist())])

            # [memory_size + 1, 1, num_data]
            memoized_latent_plus_current_latent_tensor = torch.tensor(
                list(memoized_latent_plus_current_latent),
                device=single_obs.device).unsqueeze(1)
            # [memory_size + 1]
            log_p_tensor = generative_model.get_log_prob(
                memoized_latent_plus_current_latent_tensor,
                single_obs.unsqueeze(0)).squeeze(-1)

            # this takes the longest
            # map each latent (a tuple of length num_data) to its log joint
            # probability under the generative model
            log_p = {
                mem_latent: lp
                for mem_latent, lp in zip(memoized_latent_plus_current_latent,
                                          log_p_tensor)
            }

            # update memory: keep the memory_size latents with the highest
            # log joint probability
            memory[single_obs_key] = sorted(
                memoized_latent_plus_current_latent,
                key=log_p.get)[-memory_size:]

            # REMEMBER
            # []
            if reweighted:
                memory_log_weight = torch.stack(
                    list(map(log_p.get,
                             memory[single_obs_key])))  # [memory_size]
                memory_weight_normalized = util.exponentiate_and_normalize(
                    memory_log_weight, dim=0)  # [memory_size]
                memory_latent = torch.tensor(
                    memory[single_obs_key])  # [memory_size, num_data]
                inference_network_log_prob = inference_network.get_log_prob_from_latent_dist(
                    latent_dist,
                    memory_latent[:, None, :]).squeeze(-1)  # [memory_size]

                theta_loss += -torch.sum(
                    memory_log_weight *
                    memory_weight_normalized.detach()) / len(obs)
                phi_loss += -torch.sum(
                    inference_network_log_prob *
                    memory_weight_normalized.detach()) / len(obs)
            else:
                remembered_latent_id_dist = torch.distributions.Categorical(
                    logits=torch.tensor(
                        list(map(log_p.get, memory[single_obs_key]))))
                remembered_latent_id = remembered_latent_id_dist.sample()
                remembered_latent_id_log_prob = remembered_latent_id_dist.log_prob(
                    remembered_latent_id)
                remembered_latent = memory[single_obs_key][
                    remembered_latent_id]
                remembered_latent_tensor = torch.tensor(
                    [remembered_latent], device=single_obs.device)
                # []
                theta_loss += -(log_p.get(remembered_latent) -
                                remembered_latent_id_log_prob.detach()) / len(
                                    obs)
                # []
                phi_loss += -inference_network.get_log_prob_from_latent_dist(
                    latent_dist, remembered_latent_tensor).view(()) / len(obs)

            # SLEEP
            # TODO

        optimizer.zero_grad()
        theta_loss.backward()
        phi_loss.backward()
        optimizer.step()

        theta_losses.append(theta_loss.item())
        phi_losses.append(phi_loss.item())
        cluster_cov_distances.append(
            torch.norm(true_cluster_cov -
                       generative_model.get_cluster_cov()).item())

        if iteration % 100 == 0:  # test every 100 iterations
            (
                test_log_p,
                test_log_p_true,
                test_kl_qp,
                test_kl_pq,
                test_kl_qp_true,
                test_kl_pq_true,
                _,
                _,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, None, test_data_loader)
            test_log_ps.append(test_log_p)
            test_log_ps_true.append(test_log_p_true)
            test_kl_qps.append(test_kl_qp)
            test_kl_pqs.append(test_kl_pq)
            test_kl_qps_true.append(test_kl_qp_true)
            test_kl_pqs_true.append(test_kl_pq_true)

            (
                train_log_p,
                train_log_p_true,
                train_kl_qp,
                train_kl_pq,
                train_kl_qp_true,
                train_kl_pq_true,
                train_kl_memory_p,
                train_kl_memory_p_true,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, memory, data_loader)
            train_log_ps.append(train_log_p)
            train_log_ps_true.append(train_log_p_true)
            train_kl_qps.append(train_kl_qp)
            train_kl_pqs.append(train_kl_pq)
            train_kl_qps_true.append(train_kl_qp_true)
            train_kl_pqs_true.append(train_kl_pq_true)
            train_kl_memory_ps.append(train_kl_memory_p)
            train_kl_memory_ps_true.append(train_kl_memory_p_true)

            util.save_checkpoint(
                checkpoint_path,
                generative_model,
                inference_network,
                theta_losses,
                phi_losses,
                cluster_cov_distances,
                test_log_ps,
                test_log_ps_true,
                test_kl_qps,
                test_kl_pqs,
                test_kl_qps_true,
                test_kl_pqs_true,
                train_log_ps,
                train_log_ps_true,
                train_kl_qps,
                train_kl_pqs,
                train_kl_qps_true,
                train_kl_pqs_true,
                train_kl_memory_ps,
                train_kl_memory_ps_true,
                memory,
                None,
                None,
            )

        util.print_with_time(
            "it. {} | theta loss = {:.2f} | phi loss = {:.2f}".format(
                iteration, theta_loss, phi_loss))

        # if iteration % 200 == 0:
        #     z = inference_network.get_latent_dist(obs).sample()
        #     util.save_plot("images/mws/iteration_{}.png".format(iteration),
        #                    obs[:3], z[:3])

    return (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        train_kl_memory_ps,
        train_kl_memory_ps_true,
        memory,
        None,
        None,
    )
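
The core of the wake phase above is the per-observation memory update: propose a latent, union it with the stored latents, score everything under the generative model, and keep the memory_size highest-scoring entries. A stripped-down sketch of that step with plain Python containers and a stand-in scoring function (all names here are illustrative):

import torch

def update_memory(memory, obs_key, proposed_latent, log_joint_fn, memory_size):
    # Union the proposal with the stored latents, score each candidate, and
    # keep the memory_size highest-scoring latents (stored as tuples).
    candidates = set(memory.get(obs_key, []) + [tuple(proposed_latent.tolist())])
    scores = {latent: float(log_joint_fn(torch.tensor(latent)))
              for latent in candidates}
    memory[obs_key] = sorted(candidates, key=scores.get)[-memory_size:]
    return memory[obs_key]

# Toy usage with a quadratic "log joint" that prefers latents near zero.
memory = {}
obs_key = (0.0, 1.0)
for _ in range(5):
    proposal = torch.randint(-3, 4, (2,)).float()
    update_memory(memory, obs_key, proposal,
                  lambda z: -z.pow(2).sum(), memory_size=3)
print(memory[obs_key])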