Example #1
0
def get_grads_correct(seed):
    """Compute reference RWS gradients for theta and phi, one loss at a time.

    Seeds the RNG, draws one shared set of particle weights, then
    backpropagates the wake-theta and wake-phi losses separately, zeroing
    all gradients before each backward pass so the two snapshots cannot
    contaminate each other.

    Relies on module-level `util`, `losses`, `generative_model`,
    `inference_network`, `obs`, `num_particles`, `optimizer_theta` and
    `optimizer_phi`.

    Args:
        seed: int RNG seed, so results are reproducible across calls.

    Returns:
        (theta_grads_correct, phi_grads_correct): lists of cloned gradient
        tensors for the generative model and the inference network.
    """
    util.set_seed(seed)

    # One shared forward pass; both losses reuse this computation graph.
    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    # wake-theta: gradients w.r.t. the generative model
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    # retain_graph: the wake-phi backward below reuses the same graph
    wake_theta_loss.backward(retain_graph=True)
    theta_grads_correct = [parameter.grad.clone() for parameter in
                           generative_model.parameters()]
    # in rws, we step as we compute the grads
    # optimizer_theta.step()

    # wake-phi: gradients w.r.t. the inference network
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()
    phi_grads_correct = [parameter.grad.clone() for parameter in
                         inference_network.parameters()]
    # in rws, we step as we compute the grads
    # optimizer_phi.step()
    return theta_grads_correct, phi_grads_correct
Example #2
0
def get_grads_in_one_no_zeroing(seed):
    """Compute theta and phi gradients with back-to-back backward passes,
    WITHOUT zeroing gradients between the two.

    Seeds the RNG, runs one shared forward pass, backpropagates the
    wake-theta loss, then backpropagates the wake-phi loss on top of the
    accumulated gradients, and only snapshots both models' gradients at the
    end. Intended to check that the combined snapshot matches the
    separately-computed gradients (presumably each loss only populates one
    model's gradients — verify against `losses`).

    Relies on module-level `util`, `losses`, `generative_model`,
    `inference_network`, `obs`, `num_particles`, `optimizer_theta` and
    `optimizer_phi`.

    Args:
        seed: int RNG seed, so results are comparable with other variants.

    Returns:
        (theta_grads_in_one, phi_grads_in_one): lists of cloned gradient
        tensors for the generative model and the inference network.
    """
    util.set_seed(seed)

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    # retain_graph: the wake-phi backward below reuses the same graph
    wake_theta_loss.backward(retain_graph=True)

    # optimizer_phi.zero_grad() -> don't zero phi grads
    # optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()

    # only get the grads in the end!
    theta_grads_in_one = [parameter.grad.clone() for parameter in
                          generative_model.parameters()]
    phi_grads_in_one = [parameter.grad.clone() for parameter in
                        inference_network.parameters()]

    # in pyro, we want step to be in a different stage
    # optimizer_theta.step()
    # optimizer_phi.step()
    return theta_grads_in_one, phi_grads_in_one
Example #3
0
File: grad_check.py  Project: yyht/rrws
def get_grads_correct_sleep(seed, wake_factor=0.755):
    """Compute reference gradients for wake-theta and a wake/sleep-mixed phi.

    Seeds the RNG, draws one shared set of particle weights, then computes
    three gradient snapshots from freshly zeroed gradients each time:
    wake-theta (generative model), wake-phi and sleep-phi (inference
    network). The phi gradients are the convex combination
    `wake_factor * wake + (1 - wake_factor) * sleep`.

    Relies on module-level `util`, `losses`, `generative_model`,
    `inference_network`, `obs`, `num_particles`, `optimizer_theta` and
    `optimizer_phi`.

    Args:
        seed: int RNG seed.
        wake_factor: float in [0, 1], weight of the wake-phi gradient in the
            mix (default 0.755, the previously hard-coded value).

    Returns:
        (theta_grads_correct, phi_grads_correct): lists of gradient tensors
        for the generative model and the inference network.
    """
    util.set_seed(seed)

    # One shared forward pass for both wake losses.
    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    # wake-theta: gradients w.r.t. the generative model
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    # retain_graph: the wake-phi backward below reuses the same graph
    wake_theta_loss.backward(retain_graph=True)
    theta_grads_correct = [
        parameter.grad.clone() for parameter in generative_model.parameters()
    ]
    # in rws, we step as we compute the grads
    # optimizer_theta.step()

    # wake-phi: gradients w.r.t. the inference network
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()
    wake_phi_grads_correct = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]
    # in rws, we step as we compute the grads
    # optimizer_phi.step()

    # sleep-phi: gradients from samples dreamt by the generative model
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_particles)
    sleep_phi_loss.backward()
    sleep_phi_grads_correct = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]

    # mix the two phi estimators
    phi_grads_correct = [
        wake_factor * wake_phi_grad_correct +
        (1 - wake_factor) * sleep_phi_grad_correct
        for wake_phi_grad_correct, sleep_phi_grad_correct in zip(
            wake_phi_grads_correct, sleep_phi_grads_correct)
    ]

    return theta_grads_correct, phi_grads_correct
Example #4
0
File: train.py  Project: JunLi-Galios/tvo
def train_wake_wake(generative_model,
                    inference_network,
                    data_loader,
                    num_iterations,
                    num_particles,
                    optim_kwargs,
                    callback=None):
    """Train the generative model and inference network with wake-wake (RWS).

    Cycles through `data_loader` until `num_iterations` gradient steps have
    been taken. Each step draws one shared set of particle weights, takes an
    Adam step on the generative model (wake-theta loss), then an Adam step
    on the inference network (wake-phi loss).

    Args:
        generative_model: model updated by the wake-theta step.
        inference_network: model updated by the wake-phi step.
        data_loader: iterable yielding observation batches.
        num_iterations: int, total number of gradient steps.
        num_particles: int, importance-sampling particles per observation.
        optim_kwargs: dict of keyword arguments passed to both Adam
            optimizers.
        callback: optional; called every iteration with (iteration,
            wake_theta_loss, wake_phi_loss, elbo, generative_model,
            inference_network, optimizer_theta, optimizer_phi).
    """
    optimizer_theta = torch.optim.Adam(generative_model.parameters(),
                                       **optim_kwargs)
    optimizer_phi = torch.optim.Adam(inference_network.parameters(),
                                     **optim_kwargs)

    iteration = 0
    while iteration < num_iterations:
        # one pass over the data; the outer loop restarts it until the
        # iteration budget is spent
        for obs in data_loader:
            log_weight, log_q = losses.get_log_weight_and_log_q(
                generative_model, inference_network, obs, num_particles)

            # wake theta: step the generative model
            optimizer_theta.zero_grad()
            optimizer_phi.zero_grad()
            wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
                log_weight)
            # keep the graph alive for the wake-phi backward below
            wake_theta_loss.backward(retain_graph=True)
            optimizer_theta.step()

            # wake phi: step the inference network
            optimizer_theta.zero_grad()
            optimizer_phi.zero_grad()
            wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
                log_weight, log_q)
            wake_phi_loss.backward()
            optimizer_phi.step()

            if callback is not None:
                callback(iteration,
                         wake_theta_loss.item(), wake_phi_loss.item(),
                         elbo.item(), generative_model, inference_network,
                         optimizer_theta, optimizer_phi)

            iteration += 1
            # stop exactly at the iteration budget, mid-epoch if necessary
            if iteration == num_iterations:
                return
Example #5
0
def train_defensive_wake_wake(delta,
                              generative_model,
                              inference_network,
                              obss_data_loader,
                              num_iterations,
                              num_particles,
                              callback=None):
    """Wake-wake training where the phi step uses a defensive wake-phi loss.

    Each iteration fetches a batch (restarting the loader when exhausted),
    takes one Adam step on the generative model from the wake-theta loss,
    then one Adam step on the inference network from the defensive wake-phi
    loss parameterized by `delta`.

    Args:
        delta: mixing parameter forwarded to the defensive wake-phi loss.
        generative_model: model updated by the wake-theta step.
        inference_network: model updated by the wake-phi step.
        obss_data_loader: iterable of observation batches; cycled as needed.
        num_iterations: int, total number of gradient steps.
        num_particles: int, importance-sampling particles per observation.
        callback: optional per-iteration hook (iteration, losses, elbo,
            models, optimizers).

    Returns:
        (optimizer_theta, optimizer_phi): the Adam optimizers after training.
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    obss_iter = iter(obss_data_loader)

    def next_batch():
        # pull the next batch, restarting the loader once it runs dry
        nonlocal obss_iter
        try:
            return next(obss_iter)
        except StopIteration:
            obss_iter = iter(obss_data_loader)
            return next(obss_iter)

    for iteration in range(num_iterations):
        obss = next_batch()

        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)

        # wake theta: step the generative model
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        # keep the graph alive in case the phi loss shares it
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()

        # wake phi: step the inference network with the defensive loss
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_phi_loss = losses.get_defensive_wake_phi_loss(
            generative_model, inference_network, obss, delta, num_particles)
        wake_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), wake_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
Example #6
0
def get_log_p_and_kl(generative_model, inference_network, obs, num_particles):
    """Estimate the log marginal likelihood and a KL diagnostic per datapoint.

    Draws `num_particles` importance-sampling particles via the inference
    network and computes:
      * log_p: the IWAE-style evidence estimate
        logsumexp(log_weight) - log(num_particles), and
      * kl: log_p - elbo, where elbo is the mean log weight — the gap
        between the tighter and looser evidence bounds.
        NOTE(review): this gap is a KL-divergence-style diagnostic between
        the inference network and the posterior only in expectation —
        confirm the intended interpretation against the callers.

    Args:
        generative_model: models.GenerativeModel object
        inference_network: models.InferenceNetwork object
        obs: tensor of shape [batch_size, num_data * num_dim]
        num_particles: int

    Returns:
        log_p: tensor [batch_size]
        kl: tensor [batch_size]
    """
    log_weight, _ = losses.get_log_weight_and_log_q(generative_model,
                                                    inference_network, obs,
                                                    num_particles)

    # IWAE bound: average importance weights in log space across particles
    log_p = torch.logsumexp(log_weight, dim=1) - math.log(num_particles)
    # standard ELBO estimate: mean of log weights across particles
    elbo = torch.mean(log_weight, dim=1)
    kl = log_p - elbo

    return log_p, kl
Example #7
0
File: train.py  Project: yyht/rrws
def train_wake_wake(generative_model,
                    inference_network,
                    true_generative_model,
                    batch_size,
                    num_iterations,
                    num_particles,
                    callback=None):
    """Wake-wake (RWS) training on freshly sampled synthetic data.

    Each iteration samples a new batch of observations from
    `true_generative_model`, then takes one wake-theta Adam step on the
    generative model and one wake-phi Adam step on the inference network,
    both derived from the same set of particle weights.

    Args:
        generative_model: model updated by the wake-theta step.
        inference_network: model updated by the wake-phi step.
        true_generative_model: source of synthetic observations
            (via `sample_obs()`).
        batch_size: int, observations sampled per iteration.
        num_iterations: int, number of gradient steps.
        num_particles: int, importance-sampling particles per observation.
        callback: optional per-iteration hook (iteration, losses, elbo,
            models, optimizers).

    Returns:
        (optimizer_theta, optimizer_phi): the Adam optimizers after training.
    """
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    optimizer_phi = torch.optim.Adam(inference_network.parameters())

    for iteration in range(num_iterations):
        # fresh synthetic batch every iteration
        obss = [true_generative_model.sample_obs() for _ in range(batch_size)]

        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)

        # wake theta: step the generative model
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        # keep the graph alive for the wake-phi backward below
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()

        # wake phi: step the inference network
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
            log_weight, log_q)
        wake_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), wake_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
Example #8
0
def get_mean_stds(generative_model, inference_network, num_mc_samples, obss,
                  num_particles):
    """Collect mean/std statistics for several gradient estimators.

    For each of `num_mc_samples` Monte Carlo repetitions, one shared set of
    particle weights is drawn and the VIMCO, REINFORCE, log-evidence,
    wake-phi, log-Q and sleep objectives are built from it. Each objective
    is backpropagated from freshly zeroed gradients and the resulting
    parameter gradients are fed into `util.OnlineMeanStd` accumulators.

    Args:
        generative_model: model whose log-evidence gradients are tracked.
        inference_network: model whose gradients are tracked for every
            other estimator.
        num_mc_samples: int, number of independent repetitions.
        obss: batch of observations (a sequence; its length scales the
            sleep-loss sample count).
        num_particles: int, importance-sampling particles per observation.

    Returns:
        list with one `avg_of_means_stds()` summary per accumulator, in the
        order: vimco_grad, vimco_one_grad, reinforce_grad,
        reinforce_one_grad, two_grad, log_evidence_stats, log_evidence_grad,
        wake_phi_loss_grad, log_Q_grad, sleep_loss_grad.
    """
    vimco_grad = util.OnlineMeanStd()
    vimco_one_grad = util.OnlineMeanStd()
    reinforce_grad = util.OnlineMeanStd()
    reinforce_one_grad = util.OnlineMeanStd()
    two_grad = util.OnlineMeanStd()
    log_evidence_stats = util.OnlineMeanStd()
    log_evidence_grad = util.OnlineMeanStd()
    wake_phi_loss_grad = util.OnlineMeanStd()
    log_Q_grad = util.OnlineMeanStd()
    sleep_loss_grad = util.OnlineMeanStd()

    def _zero_and_backward(loss, retain_graph=True):
        # every estimator starts from clean gradients so snapshots don't mix
        inference_network.zero_grad()
        generative_model.zero_grad()
        loss.backward(retain_graph=retain_graph)

    def _grads(model):
        # gradient snapshot for one model (no clone: consumed immediately)
        return [param.grad for param in model.parameters()]

    for mc_sample_idx in range(num_mc_samples):
        util.print_with_time('MC sample {}'.format(mc_sample_idx))
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)
        log_evidence = torch.logsumexp(log_weight, dim=1) - \
            np.log(num_particles)
        avg_log_evidence = torch.mean(log_evidence)
        log_Q = torch.sum(log_q, dim=1)
        avg_log_Q = torch.mean(log_Q)
        # REINFORCE surrogate: detached-evidence score-function term
        reinforce_one = torch.mean(log_evidence.detach() * log_Q)
        reinforce = reinforce_one + avg_log_evidence
        # VIMCO surrogate: per-particle leave-one-out control variates
        vimco_one = 0
        for i in range(num_particles):
            log_weight_ = log_weight[:, util.range_except(num_particles, i)]
            control_variate = torch.logsumexp(
                torch.cat([log_weight_, torch.mean(log_weight_, dim=1,
                                                   keepdim=True)], dim=1),
                dim=1)
            vimco_one = vimco_one + (log_evidence.detach() -
                                     control_variate.detach()) * log_q[:, i]
        vimco_one = torch.mean(vimco_one)
        vimco = vimco_one + avg_log_evidence
        normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
        wake_phi_loss = torch.mean(
            -torch.sum(normalized_weight.detach() * log_q, dim=1))

        _zero_and_backward(vimco)
        vimco_grad.update(_grads(inference_network))

        _zero_and_backward(vimco_one)
        vimco_one_grad.update(_grads(inference_network))

        _zero_and_backward(reinforce)
        reinforce_grad.update(_grads(inference_network))

        _zero_and_backward(reinforce_one)
        reinforce_one_grad.update(_grads(inference_network))

        # avg_log_evidence feeds two trackers: phi grads and theta grads
        _zero_and_backward(avg_log_evidence)
        two_grad.update(_grads(inference_network))
        log_evidence_grad.update(_grads(generative_model))

        _zero_and_backward(wake_phi_loss)
        wake_phi_loss_grad.update(_grads(inference_network))

        _zero_and_backward(avg_log_Q)
        log_Q_grad.update(_grads(inference_network))

        log_evidence_stats.update([avg_log_evidence.unsqueeze(0)])

        sleep_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_particles * len(obss))
        # final backward of the iteration; the shared graph can be freed now
        _zero_and_backward(sleep_loss, retain_graph=False)
        sleep_loss_grad.update(_grads(inference_network))

    return list(map(
        lambda x: x.avg_of_means_stds(),
        [vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
         two_grad, log_evidence_stats, log_evidence_grad, wake_phi_loss_grad,
         log_Q_grad, sleep_loss_grad]))
Example #9
0
def train_rws(
    generative_model,
    inference_network,
    data_loader,
    num_iterations,
    num_particles,
    true_cluster_cov,
    test_data_loader,
    test_num_particles,
    true_generative_model,
    checkpoint_path,
):
    """Run RWS (wake-theta / wake-phi) training with periodic evaluation.

    Every iteration draws a minibatch (restarting `data_loader` when
    exhausted), takes one Adam step on the generative model from the
    wake-theta loss and one on the inference network from the wake-phi
    loss, and records the losses plus the distance between the model's
    cluster covariance and `true_cluster_cov`. Every 100 iterations the
    models are evaluated on both the test and train loaders via
    `models.eval_gen_inf` and a checkpoint is written with
    `util.save_checkpoint`.

    NOTE(review): `test_num_particles` is accepted but never used in this
    body — confirm whether the test-set evaluation should receive it.

    Returns:
        20-tuple mirroring the `util.save_checkpoint` arguments:
        (theta_losses, phi_losses, cluster_cov_distances, test_log_ps,
        test_log_ps_true, test_kl_qps, test_kl_pqs, test_kl_qps_true,
        test_kl_pqs_true, train_log_ps, train_log_ps_true, train_kl_qps,
        train_kl_pqs, train_kl_qps_true, train_kl_pqs_true, None, None,
        None, reweighted_train_kl_qps, reweighted_train_kl_qps_true).
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    # metric histories, appended to during training and saved in checkpoints
    (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    ) = ([], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [])
    data_loader_iter = iter(data_loader)

    for iteration in range(num_iterations):
        # get obs, restarting the data loader when it is exhausted
        try:
            obs = next(data_loader_iter)
        except StopIteration:
            data_loader_iter = iter(data_loader)
            obs = next(data_loader_iter)

        # one shared forward pass; both losses reuse this graph
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obs, num_particles)

        # wake theta: step the generative model
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        # retain_graph: the wake-phi backward below reuses the same graph
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()

        # wake phi: step the inference network
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
            log_weight, log_q)
        wake_phi_loss.backward()
        optimizer_phi.step()

        theta_losses.append(wake_theta_loss.item())
        phi_losses.append(wake_phi_loss.item())
        # how far the learned cluster covariance is from the ground truth
        cluster_cov_distances.append(
            torch.norm(true_cluster_cov -
                       generative_model.get_cluster_cov()).item())
        if iteration % 100 == 0:  # test every 100 iterations
            (
                test_log_p,
                test_log_p_true,
                test_kl_qp,
                test_kl_pq,
                test_kl_qp_true,
                test_kl_pq_true,
                _,
                _,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, None, test_data_loader)
            test_log_ps.append(test_log_p)
            test_log_ps_true.append(test_log_p_true)
            test_kl_qps.append(test_kl_qp)
            test_kl_pqs.append(test_kl_pq)
            test_kl_qps_true.append(test_kl_qp_true)
            test_kl_pqs_true.append(test_kl_pq_true)

            # same evaluation on the training data, additionally with
            # reweighted KL estimates
            (
                train_log_p,
                train_log_p_true,
                train_kl_qp,
                train_kl_pq,
                train_kl_qp_true,
                train_kl_pq_true,
                _,
                _,
                reweighted_train_kl_qp,
                reweighted_train_kl_qp_true,
            ) = models.eval_gen_inf(
                true_generative_model,
                generative_model,
                inference_network,
                None,
                data_loader,
                num_particles=num_particles,
                reweighted_kl=True,
            )
            train_log_ps.append(train_log_p)
            train_log_ps_true.append(train_log_p_true)
            train_kl_qps.append(train_kl_qp)
            train_kl_pqs.append(train_kl_pq)
            train_kl_qps_true.append(train_kl_qp_true)
            train_kl_pqs_true.append(train_kl_pq_true)
            reweighted_train_kl_qps.append(reweighted_train_kl_qp)
            reweighted_train_kl_qps_true.append(reweighted_train_kl_qp_true)

            # persist models and every metric history collected so far
            util.save_checkpoint(
                checkpoint_path,
                generative_model,
                inference_network,
                theta_losses,
                phi_losses,
                cluster_cov_distances,
                test_log_ps,
                test_log_ps_true,
                test_kl_qps,
                test_kl_pqs,
                test_kl_qps_true,
                test_kl_pqs_true,
                train_log_ps,
                train_log_ps_true,
                train_kl_qps,
                train_kl_pqs,
                train_kl_qps_true,
                train_kl_pqs_true,
                None,
                None,
                None,
                reweighted_train_kl_qps,
                reweighted_train_kl_qps_true,
            )

        util.print_with_time(
            "it. {} | theta loss = {:.2f} | phi loss = {:.2f}".format(
                iteration, wake_theta_loss, wake_phi_loss))

        # if iteration % 200 == 0:
        #     z = inference_network.get_latent_dist(obs).sample()
        #     util.save_plot("images/rws/iteration_{}.png".format(iteration),
        #                    obs[:3], z[:3])

    return (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        None,
        None,
        None,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    )