Example #1
0
def train_iwae(algorithm,
               generative_model,
               inference_network,
               data_loader,
               num_iterations,
               num_particles,
               optim_kwargs,
               callback=None):
    """Jointly train a generative model and inference network on an IWAE loss.

    The parameters of both models are optimized by a single Adam optimizer.
    Batches are drawn from `data_loader`; whenever a pass over it ends, a new
    pass is started, until `num_iterations` optimizer steps have been taken.

    Args:
        algorithm: 'vimco' or 'reinforce'; selects `losses.get_vimco_loss`
            or `losses.get_reinforce_loss` respectively.
        generative_model: module whose `parameters()` are trained.
        inference_network: module whose `parameters()` are trained.
        data_loader: re-iterable source of observation batches; a fresh
            `for` loop is run over it on every pass.
        num_iterations: total number of optimizer steps to take.
        num_particles: number of particles passed to the loss function.
        optim_kwargs: keyword arguments forwarded to `torch.optim.Adam`.
        callback: optional callable invoked after every step as
            `callback(iteration, loss, elbo, generative_model,
            inference_network, optimizer)`, with `loss` and `elbo` converted
            via `.item()`.

    Returns:
        The `torch.optim.Adam` optimizer used for training.

    Raises:
        ValueError: if `algorithm` is not 'vimco' or 'reinforce', or if a pass
            over `data_loader` yields no batches while steps remain (which
            previously made this function loop forever).
    """
    # Fail fast: an unknown algorithm would otherwise surface later as an
    # UnboundLocalError on `loss`.
    if algorithm not in ('vimco', 'reinforce'):
        raise ValueError(
            "Unknown algorithm {!r}; expected 'vimco' or 'reinforce'".format(
                algorithm))

    parameters = itertools.chain(generative_model.parameters(),
                                 inference_network.parameters())
    optimizer = torch.optim.Adam(parameters, **optim_kwargs)

    iteration = 0
    while iteration < num_iterations:
        num_batches_this_pass = 0
        for obs in data_loader:
            num_batches_this_pass += 1
            optimizer.zero_grad()
            if algorithm == 'vimco':
                loss, elbo = losses.get_vimco_loss(generative_model,
                                                   inference_network, obs,
                                                   num_particles)
            else:  # 'reinforce' (validated above)
                loss, elbo = losses.get_reinforce_loss(generative_model,
                                                       inference_network, obs,
                                                       num_particles)
            loss.backward()
            optimizer.step()

            if callback is not None:
                callback(iteration, loss.item(), elbo.item(), generative_model,
                         inference_network, optimizer)

            iteration += 1
            # `iteration` steps have now been taken; stop mid-pass when done.
            if iteration >= num_iterations:
                break
        if num_batches_this_pass == 0:
            # A pass that yields nothing makes no progress, so the `while`
            # loop would never terminate (empty or one-shot iterators).
            raise ValueError(
                'data_loader yielded no batches with {} of {} iterations '
                'remaining'.format(num_iterations - iteration, num_iterations))

    return optimizer
Example #2
0
def train_iwae(algorithm,
               generative_model,
               inference_network,
               obss_data_loader,
               num_iterations,
               num_particles,
               callback=None):
    """Train using IWAE objective.

    Both models' parameters are optimized jointly by a single Adam optimizer
    (default hyperparameters). One batch is drawn per iteration; when the
    loader is exhausted, a new pass over it is started.

    Args:
        algorithm: 'reinforce', 'vimco' or 'concrete'; selects the matching
            `losses.get_*_loss` function.
        generative_model: module whose `parameters()` are trained.
        inference_network: module whose `parameters()` are trained.
        obss_data_loader: re-iterable source of observation batches.
        num_iterations: number of optimizer steps to take.
        num_particles: number of particles passed to the loss function.
        callback: optional callable invoked after every step as
            `callback(iteration, loss, elbo, generative_model,
            inference_network, optimizer)`, with `loss`/`elbo` converted via
            `.item()`.

    Returns:
        The `torch.optim.Adam` optimizer used for training.

    Raises:
        ValueError: if `algorithm` is unknown, or if `obss_data_loader`
            yields no batches at all when one is needed.
    """
    # Fail fast: an unknown algorithm would otherwise surface later as an
    # UnboundLocalError on `loss`.
    if algorithm not in ('vimco', 'reinforce', 'concrete'):
        raise ValueError(
            "Unknown algorithm {!r}; expected 'reinforce', 'vimco' or "
            "'concrete'".format(algorithm))

    parameters = itertools.chain.from_iterable(
        [x.parameters() for x in [generative_model, inference_network]])
    optimizer = torch.optim.Adam(parameters)
    obss_iter = iter(obss_data_loader)
    # Sentinel distinguishes "loader exhausted" from any real batch value.
    exhausted = object()

    for iteration in range(num_iterations):
        # get obss, restarting the loader at the end of each pass
        obss = next(obss_iter, exhausted)
        if obss is exhausted:
            obss_iter = iter(obss_data_loader)
            obss = next(obss_iter, exhausted)
            if obss is exhausted:
                # Previously a bare StopIteration escaped from here.
                raise ValueError('obss_data_loader yielded no batches')

        # wake theta
        optimizer.zero_grad()
        if algorithm == 'vimco':
            loss, elbo = losses.get_vimco_loss(generative_model,
                                               inference_network, obss,
                                               num_particles)
        elif algorithm == 'reinforce':
            loss, elbo = losses.get_reinforce_loss(generative_model,
                                                   inference_network, obss,
                                                   num_particles)
        else:  # 'concrete' (validated above)
            loss, elbo = losses.get_concrete_loss(generative_model,
                                                  inference_network, obss,
                                                  num_particles)
        loss.backward()
        optimizer.step()

        if callback is not None:
            callback(iteration, loss.item(), elbo.item(), generative_model,
                     inference_network, optimizer)

    return optimizer
Example #3
0
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        """Log, checkpoint and evaluate training progress after one step.

        Each action runs only on iterations that are multiples of its own
        interval: `self.logging_interval`, `self.checkpoint_interval` and
        `self.eval_interval`.

        Args:
            iteration: index of the training step that just finished.
            loss: scalar training loss of that step (logged with `{:.3f}`).
            elbo: scalar training ELBO of that step (logged with `{:.3f}`).
            generative_model: model being trained.
            inference_network: inference network being trained.
            optimizer: training optimizer; not used by this method.

        Raises:
            ValueError: on an evaluation iteration, if `self.train_mode` is
                neither 'vimco' nor 'reinforce'.
        """
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        if iteration % self.eval_interval == 0:
            # Validate before doing any work. With an unknown mode, the loop
            # below used to call `.backward()` on the `loss` argument itself.
            if self.train_mode not in ('vimco', 'reinforce'):
                raise ValueError(
                    "Unknown train_mode {!r}; expected 'vimco' or "
                    "'reinforce'".format(self.train_mode))

            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)

            # Repeat the gradient computation on `self.test_obs` 10 times and
            # accumulate the grads of all parameters. The recorded value is
            # presumably the gradient-estimator spread; confirm against
            # util.OnlineMeanStd.
            stats = util.OnlineMeanStd()
            for _ in range(10):
                generative_model.zero_grad()
                inference_network.zero_grad()
                # `test_loss` avoids shadowing the `loss` argument.
                if self.train_mode == 'vimco':
                    test_loss, _ = losses.get_vimco_loss(generative_model,
                                                         inference_network,
                                                         self.test_obs,
                                                         self.num_particles)
                else:  # 'reinforce' (validated above)
                    test_loss, _ = losses.get_reinforce_loss(
                        generative_model, inference_network, self.test_obs,
                        self.num_particles)
                test_loss.backward()
                stats.update([p.grad for p in generative_model.parameters()] +
                             [p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                    iteration, self.log_p_history[-1], self.kl_history[-1]))
Example #4
0
def train_iwae(algorithm,
               generative_model,
               inference_network,
               true_generative_model,
               batch_size,
               num_iterations,
               num_particles,
               callback=None):
    """Train using IWAE objective on freshly sampled synthetic data.

    Every iteration draws `batch_size` observations from
    `true_generative_model.sample_obs()` and takes one Adam step (default
    hyperparameters) on the parameters of both models.

    Args:
        algorithm: 'reinforce' or 'vimco'; selects the matching
            `losses.get_*_loss` function.
        generative_model: module whose `parameters()` are trained.
        inference_network: module whose `parameters()` are trained.
        true_generative_model: object providing `sample_obs()`, used to
            generate training data.
        batch_size: number of observations sampled per iteration.
        num_iterations: number of optimizer steps to take.
        num_particles: number of particles passed to the loss function.
        callback: optional callable invoked after every step as
            `callback(iteration, loss, elbo, generative_model,
            inference_network, optimizer)`, with `loss`/`elbo` converted via
            `.item()`.

    Returns:
        The `torch.optim.Adam` optimizer used for training.

    Raises:
        ValueError: if `algorithm` is not 'reinforce' or 'vimco'.
    """
    # Fail fast: an unknown algorithm would otherwise surface later as an
    # UnboundLocalError on `loss`.
    if algorithm not in ('vimco', 'reinforce'):
        raise ValueError(
            "Unknown algorithm {!r}; expected 'reinforce' or 'vimco'".format(
                algorithm))

    parameters = itertools.chain.from_iterable(
        [x.parameters() for x in [generative_model, inference_network]])
    optimizer = torch.optim.Adam(parameters)

    for iteration in range(num_iterations):
        # generate synthetic data
        obss = [true_generative_model.sample_obs() for _ in range(batch_size)]

        # wake theta
        optimizer.zero_grad()
        if algorithm == 'vimco':
            loss, elbo = losses.get_vimco_loss(generative_model,
                                               inference_network, obss,
                                               num_particles)
        else:  # 'reinforce' (validated above)
            loss, elbo = losses.get_reinforce_loss(generative_model,
                                                   inference_network, obss,
                                                   num_particles)
        loss.backward()
        optimizer.step()

        if callback is not None:
            callback(iteration, loss.item(), elbo.item(), generative_model,
                     inference_network, optimizer)

    return optimizer
Example #5
0
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        """Log, checkpoint and evaluate training progress after one step.

        Each action runs only on iterations that are multiples of its own
        interval: `self.logging_interval`, `self.checkpoint_interval` and
        `self.eval_interval`.

        Args:
            iteration: index of the training step that just finished.
            loss: scalar training loss of that step (logged with `{:.3f}`).
            elbo: scalar training ELBO of that step (logged with `{:.3f}`).
            generative_model: model being trained.
            inference_network: inference network being trained.
            optimizer: training optimizer; not used by this method.

        Raises:
            ValueError: on an evaluation iteration, if `self.train_mode` is
                neither 'vimco' nor 'reinforce'.
        """
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.model_folder, iteration)

        if iteration % self.eval_interval == 0:
            # Validate before doing any work. With an unknown mode, the loop
            # below used to call `.backward()` on the `loss` argument itself.
            if self.train_mode not in ('vimco', 'reinforce'):
                raise ValueError(
                    "Unknown train_mode {!r}; expected 'vimco' or "
                    "'reinforce'".format(self.train_mode))

            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_history.append(
                util.get_q_error(self.true_generative_model, inference_network,
                                 self.test_obss))

            # Repeat the gradient computation on `self.test_obss` 10 times and
            # accumulate the inference-network grads only.
            # NOTE(review): generative_model grads are not reset here, so they
            # accumulate across these backward passes; confirm the training
            # loop zeroes them before its next step.
            stats = util.OnlineMeanStd()
            for _ in range(10):
                inference_network.zero_grad()
                # `test_loss` avoids shadowing the `loss` argument.
                if self.train_mode == 'vimco':
                    test_loss, _ = losses.get_vimco_loss(generative_model,
                                                         inference_network,
                                                         self.test_obss,
                                                         self.num_particles)
                else:  # 'reinforce' (validated above)
                    test_loss, _ = losses.get_reinforce_loss(
                        generative_model, inference_network, self.test_obss,
                        self.num_particles)
                test_loss.backward()
                stats.update([p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1])
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}'.format(iteration, self.p_error_history[-1],
                                self.q_error_history[-1]))
Example #6
0
def train_vimco(
    generative_model,
    inference_network,
    data_loader,
    num_iterations,
    num_particles,
    true_cluster_cov,
    test_data_loader,
    test_num_particles,
    true_generative_model,
    checkpoint_path,
    test_interval=100,
):
    """Train a generative model and inference network with the VIMCO loss.

    A single Adam optimizer (default hyperparameters) updates the parameters
    of both models from one loss per iteration. Batches come from
    `data_loader`, which is restarted whenever a pass over it ends. Every
    `test_interval` iterations, starting at iteration 0, the models are
    evaluated with `models.eval_gen_inf` on the test and training loaders and
    a checkpoint is written with `util.save_checkpoint`.

    Args:
        generative_model: module with `parameters()` and `get_cluster_cov()`.
        inference_network: module whose `parameters()` are trained.
        data_loader: re-iterable source of training batches.
        num_iterations: number of optimizer steps to take.
        num_particles: number of particles for the loss and for the
            training-set evaluation.
        true_cluster_cov: tensor compared (Frobenius/2-norm via `torch.norm`)
            with `generative_model.get_cluster_cov()` every iteration.
        test_data_loader: loader used for the test-set evaluation.
        test_num_particles: accepted for interface compatibility; not used by
            this function.
        true_generative_model: reference model passed to the evaluation.
        checkpoint_path: destination passed to `util.save_checkpoint`.
        test_interval: evaluate and checkpoint every `test_interval`
            iterations. Defaults to 100, the previously hard-coded value.

    Returns:
        A 20-tuple of histories, in the same order as the arguments passed to
        `util.save_checkpoint` after the two models. Positions 15-17 are
        always None.

    Raises:
        ValueError: if `data_loader` yields no batches when one is needed.
    """
    optimizer = torch.optim.Adam(
        itertools.chain(generative_model.parameters(),
                        inference_network.parameters()))
    (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    ) = ([] for _ in range(17))  # 17 distinct empty lists
    data_loader_iter = iter(data_loader)
    # Sentinel distinguishes "loader exhausted" from any real batch value.
    exhausted = object()

    for iteration in range(num_iterations):
        # get obs, restarting the loader at the end of each pass
        obs = next(data_loader_iter, exhausted)
        if obs is exhausted:
            data_loader_iter = iter(data_loader)
            obs = next(data_loader_iter, exhausted)
            if obs is exhausted:
                # Previously a bare StopIteration escaped from here.
                raise ValueError('data_loader yielded no batches')

        # loss
        optimizer.zero_grad()
        loss, _ = losses.get_vimco_loss(generative_model, inference_network,
                                        obs, num_particles)
        # NOTE(review): nothing in this function backpropagates through this
        # graph a second time; retain_graph=True is kept in case the models
        # rely on it. Confirm before removing, since it holds graph memory.
        loss.backward(retain_graph=True)
        optimizer.step()

        loss_value = loss.item()
        # One loss drives both parameter sets, so the same value is recorded
        # as both the theta and the phi loss.
        theta_losses.append(loss_value)
        phi_losses.append(loss_value)
        cluster_cov_distances.append(
            torch.norm(true_cluster_cov -
                       generative_model.get_cluster_cov()).item())
        if iteration % test_interval == 0:
            (
                test_log_p,
                test_log_p_true,
                test_kl_qp,
                test_kl_pq,
                test_kl_qp_true,
                test_kl_pq_true,
                _,
                _,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, None, test_data_loader)
            test_log_ps.append(test_log_p)
            test_log_ps_true.append(test_log_p_true)
            test_kl_qps.append(test_kl_qp)
            test_kl_pqs.append(test_kl_pq)
            test_kl_qps_true.append(test_kl_qp_true)
            test_kl_pqs_true.append(test_kl_pq_true)

            (
                train_log_p,
                train_log_p_true,
                train_kl_qp,
                train_kl_pq,
                train_kl_qp_true,
                train_kl_pq_true,
                _,
                _,
                reweighted_train_kl_qp,
                reweighted_train_kl_qp_true,
            ) = models.eval_gen_inf(
                true_generative_model,
                generative_model,
                inference_network,
                None,
                data_loader,
                num_particles=num_particles,
                reweighted_kl=True,
            )
            train_log_ps.append(train_log_p)
            train_log_ps_true.append(train_log_p_true)
            train_kl_qps.append(train_kl_qp)
            train_kl_pqs.append(train_kl_pq)
            train_kl_qps_true.append(train_kl_qp_true)
            train_kl_pqs_true.append(train_kl_pq_true)
            reweighted_train_kl_qps.append(reweighted_train_kl_qp)
            reweighted_train_kl_qps_true.append(reweighted_train_kl_qp_true)

            # Positional order must match util.save_checkpoint's signature.
            util.save_checkpoint(
                checkpoint_path,
                generative_model,
                inference_network,
                theta_losses,
                phi_losses,
                cluster_cov_distances,
                test_log_ps,
                test_log_ps_true,
                test_kl_qps,
                test_kl_pqs,
                test_kl_qps_true,
                test_kl_pqs_true,
                train_log_ps,
                train_log_ps_true,
                train_kl_qps,
                train_kl_pqs,
                train_kl_qps_true,
                train_kl_pqs_true,
                None,
                None,
                None,
                reweighted_train_kl_qps,
                reweighted_train_kl_qps_true,
            )

        util.print_with_time("it. {} | theta loss = {:.2f}".format(
            iteration, loss_value))

    return (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        None,
        None,
        None,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    )