Example #1
0
def get_grads_correct(seed):
    """Compute the reference RWS gradients for theta and phi.

    Runs the wake-theta and wake-phi passes separately, zeroing all
    gradients before each backward pass, and clones the resulting
    parameter gradients without stepping the optimizers.

    Args:
        seed: integer seed forwarded to util.set_seed for reproducibility.

    Returns:
        tuple (theta_grads_correct, phi_grads_correct): lists of cloned
        gradient tensors for the generative model and inference network.
    """
    util.set_seed(seed)

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    # Wake-theta pass: gradients of the generative model parameters.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    # retain_graph so the wake-phi backward below can reuse the graph
    wake_theta_loss.backward(retain_graph=True)
    theta_grads_correct = [param.grad.clone()
                           for param in generative_model.parameters()]
    # in rws, we step as we compute the grads
    # optimizer_theta.step()

    # Wake-phi pass: gradients of the inference network parameters.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()
    phi_grads_correct = [param.grad.clone()
                         for param in inference_network.parameters()]
    # in rws, we step as we compute the grads
    # optimizer_phi.step()
    return theta_grads_correct, phi_grads_correct
Example #2
0
def get_grads_in_one_no_zeroing(seed):
    """Compute theta and phi gradients with a single gradient harvest.

    Unlike get_grads_correct, gradients are NOT zeroed between the
    wake-theta and wake-phi backward passes, and both gradient lists
    are read only after both backward passes have run.

    Args:
        seed: integer seed forwarded to util.set_seed for reproducibility.

    Returns:
        tuple (theta_grads_in_one, phi_grads_in_one): lists of cloned
        gradient tensors for the generative model and inference network.
    """
    util.set_seed(seed)

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    # Wake-theta backward; retain_graph so wake-phi can reuse the graph.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    wake_theta_loss.backward(retain_graph=True)

    # optimizer_phi.zero_grad() -> don't zero phi grads
    # optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()

    # only get the grads in the end!
    theta_grads_in_one = [param.grad.clone()
                          for param in generative_model.parameters()]
    phi_grads_in_one = [param.grad.clone()
                        for param in inference_network.parameters()]

    # in pyro, we want step to be in a different stage
    # optimizer_theta.step()
    # optimizer_phi.step()
    return theta_grads_in_one, phi_grads_in_one
Example #3
0
def train_rws_(
    generative_model,
    inference_network,
    data_loader,
    test_data_loader,
    num_particles,
    test_num_particles,
    num_iterations,
    log_interval,
):
    """Train the inference network with wake-phi (RWS) updates only.

    Only optimizer_phi is created and stepped; the generative model's
    parameters are read (for log probs and device) but never updated here.

    Returns:
        (log_p, phi_losses): the test-set log_p from eval_gen_inf_ after
        training, and the per-iteration wake-phi losses.
    """
    # NOTE(review): theta_losses, log_ps and kls are initialized but never
    # appended to, so the log line below always prints "N/A" for them.
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    theta_losses, phi_losses, log_ps, kls = [], [], [], []
    device = next(generative_model.parameters()).device
    if device.type == "cuda":
        # Reset the peak-memory counter so the logged figure reflects
        # only this training run.
        torch.cuda.reset_max_memory_allocated(device=device)

    iteration = 0
    while iteration < num_iterations:
        # Re-enter the data loader each time it is exhausted (epoch loop).
        for obs_id, obs in data_loader:
            # log_p / log_q are transposed to put particles last —
            # assumed (batch, num_particles) layout; TODO confirm against
            # get_log_prob's output convention.
            latent_dist = inference_network.get_latent_dist(obs_id)
            latent = inference_network.sample_from_latent_dist(latent_dist, num_particles)
            log_p = generative_model.get_log_prob(latent, obs, obs_id).transpose(0, 1)
            log_q = inference_network.get_log_prob_from_latent_dist(latent_dist, latent).transpose(
                0, 1
            )
            log_weight = log_p - log_q

            # wake phi
            optimizer_phi.zero_grad()
            wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(log_weight, log_q)
            wake_phi_loss.backward()
            optimizer_phi.step()

            phi_losses.append(wake_phi_loss.item())
            iteration += 1
            # by this time, we have gone through `iteration` iterations
            if iteration % log_interval == 0:
                util.logging.info(
                    "it. {}/{} | "
                    "phi loss = {:.2f} | last log_p = {} | "
                    "last kl = {} | GPU memory = {:.2f} MB".format(
                        iteration,
                        num_iterations,
                        phi_losses[-1],
                        "N/A" if len(log_ps) == 0 else log_ps[-1],
                        "N/A" if len(kls) == 0 else kls[-1],
                        (
                            torch.cuda.max_memory_allocated(device=device) / 1e6
                            if device.type == "cuda"
                            else 0
                        ),
                    )
                )
            # Break mid-epoch; the while condition then ends training.
            if iteration == num_iterations:
                break

    # kl is computed but discarded — only log_p is returned.
    log_p, kl = eval_gen_inf_(
        generative_model, inference_network, test_data_loader, test_num_particles
    )
    return log_p, phi_losses
Example #4
0
def get_grads_correct_sleep(seed):
    """Compute reference gradients where phi mixes wake and sleep losses.

    Theta gradients come from the wake-theta loss alone. Phi gradients
    are a convex combination of the wake-phi and sleep-phi gradients,
    with mixing weight wake_factor = 0.755. Each pass zeroes all grads
    first; no optimizer steps are taken.

    Args:
        seed: integer seed forwarded to util.set_seed for reproducibility.

    Returns:
        tuple (theta_grads_correct, phi_grads_correct): lists of cloned
        (and for phi, blended) gradient tensors.
    """
    util.set_seed(seed)

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    # Wake-theta pass.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    wake_theta_loss.backward(retain_graph=True)
    theta_grads_correct = [param.grad.clone()
                           for param in generative_model.parameters()]
    # in rws, we step as we compute the grads
    # optimizer_theta.step()

    # Wake-phi pass.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()
    wake_phi_grads_correct = [param.grad.clone()
                              for param in inference_network.parameters()]
    # in rws, we step as we compute the grads
    # optimizer_phi.step()

    # Sleep-phi pass.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    sleep_phi_loss = losses.get_sleep_loss(
        generative_model, inference_network, num_samples=num_particles)
    sleep_phi_loss.backward()
    sleep_phi_grads_correct = [param.grad.clone()
                               for param in inference_network.parameters()]

    # Blend the two phi gradient estimates.
    wake_factor = 0.755
    phi_grads_correct = [
        wake_factor * wake_grad + (1 - wake_factor) * sleep_grad
        for wake_grad, sleep_grad in zip(wake_phi_grads_correct,
                                         sleep_phi_grads_correct)
    ]

    return theta_grads_correct, phi_grads_correct
Example #5
0
def train_thermo_wake(generative_model,
                      inference_network,
                      data_loader,
                      num_iterations,
                      num_particles,
                      partition,
                      optim_kwargs,
                      callback=None):
    """Train with a thermodynamic theta step and a wake-phi step.

    Each iteration: (1) update the generative model by backpropagating
    the thermodynamic loss, (2) update the inference network with the
    standard wake-phi loss computed from the same log weights.

    Args:
        generative_model: model whose parameters form the theta group.
        inference_network: model whose parameters form the phi group.
        data_loader: iterable of observation batches; re-iterated until
            num_iterations total updates have been performed.
        num_iterations: total number of parameter updates.
        num_particles: number of importance-sampling particles.
        partition: temperature partition passed to the thermo loss.
        optim_kwargs: keyword arguments forwarded to both Adam optimizers.
        callback: optional; invoked every iteration with the losses,
            models and optimizers (see the call below for the signature).
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters(),
                                     **optim_kwargs)
    optimizer_theta = torch.optim.Adam(generative_model.parameters(),
                                       **optim_kwargs)

    iteration = 0
    while iteration < num_iterations:
        for obs in iter(data_loader):
            log_weight, log_p, log_q = losses.get_log_weight_log_p_log_q(
                generative_model, inference_network, obs, num_particles)

            # wake theta
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            thermo_loss, elbo = \
                losses.get_thermo_loss_from_log_weight_log_p_log_q(
                    log_weight, log_p, log_q, partition, num_particles)
            # retain_graph: the wake-phi backward below reuses this graph.
            thermo_loss.backward(retain_graph=True)
            optimizer_theta.step()

            # wake phi
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
                log_weight, log_q)
            wake_phi_loss.backward()
            optimizer_phi.step()

            if callback is not None:
                callback(iteration, thermo_loss.item(), wake_phi_loss.item(),
                         elbo.item(), generative_model, inference_network,
                         optimizer_theta, optimizer_phi)

            iteration += 1
            # by this time, we have gone through `iteration` iterations
            if iteration == num_iterations:
                break
Example #6
0
def train_wake_wake(generative_model,
                    inference_network,
                    obss_data_loader,
                    num_iterations,
                    num_particles,
                    callback=None):
    """Train with the wake-wake (RWS) scheme.

    Each iteration performs a wake-theta update of the generative model
    followed by a wake-phi update of the inference network, reusing the
    same log weights for both. The data loader is restarted whenever it
    is exhausted so exactly num_iterations updates are performed.

    Returns:
        (optimizer_theta, optimizer_phi): the two Adam optimizers.
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    batch_iter = iter(obss_data_loader)

    for iteration in range(num_iterations):
        # Draw the next batch, restarting the loader when exhausted.
        try:
            batch = next(batch_iter)
        except StopIteration:
            batch_iter = iter(obss_data_loader)
            batch = next(batch_iter)

        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, batch, num_particles)

        # Wake-theta update (generative model). retain_graph so the
        # wake-phi backward below can reuse the same graph.
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()

        # Wake-phi update (inference network).
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
            log_weight, log_q)
        wake_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), wake_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
Example #7
0
def get_grads_weird_detach_sleep(seed):
    """Compute theta and phi grads in one combined backward sweep.

    Uses get_log_weight_and_log_q_weird_detach (defined elsewhere in this
    file) instead of losses.get_log_weight_and_log_q, blends the wake-phi
    and sleep-phi losses before backpropagating, and reads all gradients
    only after both backward passes have run. No optimizer steps occur.

    Args:
        seed: integer seed forwarded to util.set_seed.

    Returns:
        tuple (theta_grads_in_one, phi_grads_in_one): lists of cloned
        gradient tensors for the generative model and inference network.
    """
    util.set_seed(seed)

    theta_grads_in_one = []
    phi_grads_in_one = []

    log_weight, log_q = get_log_weight_and_log_q_weird_detach(
        generative_model, inference_network, obs, num_particles)

    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)

    # optimizer_phi.zero_grad() -> don't zero phi grads
    # optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)

    sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_particles)

    # Convex combination of the wake and sleep phi objectives.
    wake_factor = 0.755
    phi_loss = wake_factor * wake_phi_loss + (1 - wake_factor) * sleep_phi_loss

    # Both backwards run before any grads are read; gradients from the
    # two passes accumulate wherever the graphs overlap. retain_graph
    # keeps the shared graph alive for the second backward.
    wake_theta_loss.backward(retain_graph=True)
    phi_loss.backward()

    # only get the grads in the end!
    theta_grads_in_one = [
        parameter.grad.clone() for parameter in generative_model.parameters()
    ]
    phi_grads_in_one = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]

    # in pyro, we want step to be in a different stage
    # optimizer_theta.step()
    # optimizer_phi.step()
    return theta_grads_in_one, phi_grads_in_one
Example #8
0
File: train.py Project: yyht/rrws
def train_wake_wake(generative_model,
                    inference_network,
                    true_generative_model,
                    batch_size,
                    num_iterations,
                    num_particles,
                    callback=None):
    """Wake-wake (RWS) training on synthetic data from the true model.

    Every iteration samples a fresh batch of observations from
    true_generative_model, then performs a wake-theta update of the
    generative model followed by a wake-phi update of the inference
    network, reusing the same log weights for both.

    Returns:
        (optimizer_theta, optimizer_phi): the two Adam optimizers.
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())

    for iteration in range(num_iterations):
        # Fresh synthetic batch for this iteration.
        observations = [true_generative_model.sample_obs()
                        for _ in range(batch_size)]

        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, observations, num_particles)

        # Wake-theta update (generative model). retain_graph so the
        # wake-phi backward below can reuse the same graph.
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()

        # Wake-phi update (inference network).
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
            log_weight, log_q)
        wake_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), wake_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
Example #9
0
def train_rws(
    generative_model,
    inference_network,
    data_loader,
    num_iterations,
    num_particles,
    true_cluster_cov,
    test_data_loader,
    test_num_particles,
    true_generative_model,
    checkpoint_path,
):
    """Full RWS training loop with periodic evaluation and checkpointing.

    Each iteration performs a wake-theta update followed by a wake-phi
    update, records losses and the distance of the learned cluster
    covariance from true_cluster_cov, and every 100 iterations evaluates
    on both the test and train loaders and saves a checkpoint.

    Returns:
        A 20-tuple of metric histories mirroring the fields passed to
        util.save_checkpoint (three slots are always None).
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    # 17 parallel metric histories, appended to during training.
    (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    ) = ([], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [])
    data_loader_iter = iter(data_loader)

    for iteration in range(num_iterations):
        # get obs — restart the loader when it is exhausted.
        try:
            obs = next(data_loader_iter)
        except StopIteration:
            data_loader_iter = iter(data_loader)
            obs = next(data_loader_iter)

        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obs, num_particles)

        # wake theta — retain_graph so wake phi can reuse the same graph.
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()

        # wake phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
            log_weight, log_q)
        wake_phi_loss.backward()
        optimizer_phi.step()

        theta_losses.append(wake_theta_loss.item())
        phi_losses.append(wake_phi_loss.item())
        # Frobenius distance between learned and true cluster covariance.
        cluster_cov_distances.append(
            torch.norm(true_cluster_cov -
                       generative_model.get_cluster_cov()).item())
        if iteration % 100 == 0:  # test every 100 iterations
            # Test-set evaluation: eval_gen_inf returns 8 values here
            # (no reweighted KLs without reweighted_kl=True).
            (
                test_log_p,
                test_log_p_true,
                test_kl_qp,
                test_kl_pq,
                test_kl_qp_true,
                test_kl_pq_true,
                _,
                _,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, None, test_data_loader)
            test_log_ps.append(test_log_p)
            test_log_ps_true.append(test_log_p_true)
            test_kl_qps.append(test_kl_qp)
            test_kl_pqs.append(test_kl_pq)
            test_kl_qps_true.append(test_kl_qp_true)
            test_kl_pqs_true.append(test_kl_pq_true)

            # Train-set evaluation: reweighted_kl=True adds two extra
            # return values (the reweighted KLs).
            (
                train_log_p,
                train_log_p_true,
                train_kl_qp,
                train_kl_pq,
                train_kl_qp_true,
                train_kl_pq_true,
                _,
                _,
                reweighted_train_kl_qp,
                reweighted_train_kl_qp_true,
            ) = models.eval_gen_inf(
                true_generative_model,
                generative_model,
                inference_network,
                None,
                data_loader,
                num_particles=num_particles,
                reweighted_kl=True,
            )
            train_log_ps.append(train_log_p)
            train_log_ps_true.append(train_log_p_true)
            train_kl_qps.append(train_kl_qp)
            train_kl_pqs.append(train_kl_pq)
            train_kl_qps_true.append(train_kl_qp_true)
            train_kl_pqs_true.append(train_kl_pq_true)
            reweighted_train_kl_qps.append(reweighted_train_kl_qp)
            reweighted_train_kl_qps_true.append(reweighted_train_kl_qp_true)

            # Persist everything collected so far (three slots unused).
            util.save_checkpoint(
                checkpoint_path,
                generative_model,
                inference_network,
                theta_losses,
                phi_losses,
                cluster_cov_distances,
                test_log_ps,
                test_log_ps_true,
                test_kl_qps,
                test_kl_pqs,
                test_kl_qps_true,
                test_kl_pqs_true,
                train_log_ps,
                train_log_ps_true,
                train_kl_qps,
                train_kl_pqs,
                train_kl_qps_true,
                train_kl_pqs_true,
                None,
                None,
                None,
                reweighted_train_kl_qps,
                reweighted_train_kl_qps_true,
            )

        # NOTE(review): the losses are formatted as tensors here, not via
        # .item() as in the appends above — works for 0-dim tensors, but
        # .item() would be more explicit; verify on the torch version used.
        util.print_with_time(
            "it. {} | theta loss = {:.2f} | phi loss = {:.2f}".format(
                iteration, wake_theta_loss, wake_phi_loss))

        # if iteration % 200 == 0:
        #     z = inference_network.get_latent_dist(obs).sample()
        #     util.save_plot("images/rws/iteration_{}.png".format(iteration),
        #                    obs[:3], z[:3])

    return (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        None,
        None,
        None,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    )