コード例 #1
0
ファイル: train.py プロジェクト: yyht/rrws
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, control_variate):
        """Log, checkpoint and evaluate for one training iteration.

        Three independent actions are triggered whenever ``iteration`` is a
        multiple of the corresponding interval stored on the callback:

        * logging: print ``loss`` and ``elbo`` and append them to
          ``self.loss_history`` / ``self.elbo_history``;
        * checkpointing: save this callback object, the two models and the
          control variate under ``self.model_folder``;
        * evaluation: append the values returned by ``util.get_p_error`` and
          ``util.get_q_error`` to the error histories and print them.

        Args:
            iteration: index of the current iteration.
            loss: loss value to report (formatted with three decimals).
            elbo: elbo value to report (formatted with three decimals).
            generative_model: passed to ``util.save_models``,
                ``util.get_p_error`` and ``util.get_q_error``.
            inference_network: passed to ``util.save_models`` and
                ``util.get_q_error``.
            control_variate: passed to ``util.save_control_variate``.
        """
        # --- logging ---------------------------------------------------
        if iteration % self.logging_interval == 0:
            log_line = 'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                iteration, loss, elbo)
            util.print_with_time(log_line)
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        # --- checkpointing ---------------------------------------------
        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_filename(self.model_folder)
            util.save_object(self, stats_path)
            util.save_models(
                generative_model,
                inference_network,
                self.pcfg_path,
                self.model_folder,
            )
            util.save_control_variate(control_variate, self.model_folder)

        # --- evaluation ------------------------------------------------
        if iteration % self.eval_interval == 0:
            # Each metric is computed and stored before the next one is
            # evaluated, so a failure leaves earlier histories updated.
            p_error = util.get_p_error(
                self.true_generative_model, generative_model)
            self.p_error_history.append(p_error)

            q_error_to_true = util.get_q_error(
                self.true_generative_model, inference_network)
            self.q_error_to_true_history.append(q_error_to_true)

            q_error_to_model = util.get_q_error(
                generative_model, inference_network)
            self.q_error_to_model_history.append(q_error_to_model)

            eval_template = (
                'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}, '
                'q_error_to_model = {:.3f}')
            util.print_with_time(
                eval_template.format(
                    iteration,
                    self.p_error_history[-1],
                    self.q_error_to_true_history[-1],
                    self.q_error_to_model_history[-1],
                ))
コード例 #2
0
ファイル: train.py プロジェクト: yyht/rrws
    def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        """Log, checkpoint and evaluate for one training iteration.

        Each action runs only when ``iteration`` is a multiple of the
        matching interval attribute (``logging_interval``,
        ``checkpoint_interval``, ``eval_interval``).

        Args:
            iteration: index of the current iteration.
            wake_theta_loss: theta loss value to print and record.
            wake_phi_loss: phi loss value to print and record.
            elbo: elbo value to print and record.
            generative_model: passed to ``util.save_models``,
                ``util.get_p_error`` and ``util.get_q_error``.
            inference_network: passed to ``util.save_models`` and
                ``util.get_q_error``.
            optimizer_theta: accepted for the caller's convenience; not used
                in this method.
            optimizer_phi: accepted for the caller's convenience; not used
                in this method.
        """
        # Report the current losses and keep them for later inspection.
        if iteration % self.logging_interval == 0:
            log_template = (
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}')
            util.print_with_time(
                log_template.format(
                    iteration, wake_theta_loss, wake_phi_loss, elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.wake_phi_loss_history.append(wake_phi_loss)
            self.elbo_history.append(elbo)

        # Persist this callback object and both models.
        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_filename(self.model_folder)
            util.save_object(self, stats_path)
            util.save_models(
                generative_model,
                inference_network,
                self.pcfg_path,
                self.model_folder,
            )

        # Compute the error metrics one at a time, storing each result
        # before the next metric is evaluated.
        if iteration % self.eval_interval == 0:
            p_error = util.get_p_error(
                self.true_generative_model, generative_model)
            self.p_error_history.append(p_error)

            q_error_to_true = util.get_q_error(
                self.true_generative_model, inference_network)
            self.q_error_to_true_history.append(q_error_to_true)

            q_error_to_model = util.get_q_error(
                generative_model, inference_network)
            self.q_error_to_model_history.append(q_error_to_model)

            eval_template = (
                'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}, '
                'q_error_to_model = {:.3f}')
            util.print_with_time(
                eval_template.format(
                    iteration,
                    self.p_error_history[-1],
                    self.q_error_to_true_history[-1],
                    self.q_error_to_model_history[-1],
                ))
コード例 #3
0
    def __call__(self, iteration, theta_loss, phi_loss, generative_model,
                 inference_network, memory, optimizer):
        """Log, checkpoint and evaluate for one training iteration.

        Each action runs only when ``iteration`` is a multiple of the
        matching interval attribute (``logging_interval``,
        ``checkpoint_interval``, ``eval_interval``).

        Args:
            iteration: index of the current iteration; also forwarded to
                ``util.save_models`` at checkpoints.
            theta_loss: theta loss value to print and record.
            phi_loss: phi loss value to print and record.
            generative_model: passed to ``util.save_models`` and
                ``util.get_p_error``.
            inference_network: passed to ``util.save_models`` and
                ``util.get_q_error``.
            memory: forwarded to ``util.save_models`` at checkpoints.
            optimizer: accepted for the caller's convenience; not used in
                this method.
        """
        # Report the current losses and keep them for later inspection.
        if iteration % self.logging_interval == 0:
            log_line = (
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}'.format(
                    iteration, theta_loss, phi_loss))
            util.print_with_time(log_line)
            self.theta_loss_history.append(theta_loss)
            self.phi_loss_history.append(phi_loss)

        # Persist this callback object, both models and the memory.
        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_path)
            util.save_models(
                generative_model,
                inference_network,
                self.model_folder,
                iteration,
                memory,
            )

        # Compute the error metrics against the reference model.
        if iteration % self.eval_interval == 0:
            p_error = util.get_p_error(
                self.true_generative_model, generative_model)
            self.p_error_history.append(p_error)

            q_error = util.get_q_error(
                self.true_generative_model,
                inference_network,
                self.test_obss,
            )
            self.q_error_history.append(q_error)

            # TODO: also append a memory error to
            # self.memory_error_history, i.e.
            #     util.get_memory_error(
            #         self.true_generative_model, memory, generative_model,
            #         self.test_obss)
            eval_template = (
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}')
            util.print_with_time(
                eval_template.format(
                    iteration,
                    self.p_error_history[-1],
                    self.q_error_history[-1],
                ))
コード例 #4
0
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        """Log, checkpoint and evaluate for one training iteration.

        Each action runs only when ``iteration`` is a multiple of the
        matching interval attribute (``logging_interval``,
        ``checkpoint_interval``, ``eval_interval``). On evaluation steps the
        method additionally recomputes the training loss on
        ``self.test_obss`` ten times, backpropagates each one and feeds the
        inference network's parameter gradients to a
        ``util.OnlineMeanStd`` accumulator.

        Args:
            iteration: index of the current iteration; also forwarded to
                ``util.save_models`` at checkpoints.
            loss: loss value to print and record.
            elbo: elbo value to print and record.
            generative_model: passed to ``util.save_models``,
                ``util.get_p_error`` and the loss function.
            inference_network: passed to ``util.save_models``,
                ``util.get_q_error`` and the loss function; its gradients
                are reset and repopulated during evaluation.
            optimizer: accepted for the caller's convenience; not used in
                this method.

        Raises:
            ValueError: on an evaluation step, if ``self.train_mode`` is
                neither ``'vimco'`` nor ``'reinforce'``.
        """
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.model_folder, iteration)

        if iteration % self.eval_interval == 0:
            # Resolve the loss function before touching any history so an
            # unsupported mode fails fast. Previously an unknown mode fell
            # through both branches and ``backward()`` was called on the
            # caller's ``loss`` argument instead of a fresh evaluation loss.
            if self.train_mode == 'vimco':
                get_eval_loss = losses.get_vimco_loss
            elif self.train_mode == 'reinforce':
                get_eval_loss = losses.get_reinforce_loss
            else:
                raise ValueError(
                    'Unsupported train_mode {!r}; expected \'vimco\' or '
                    '\'reinforce\''.format(self.train_mode))

            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_history.append(
                util.get_q_error(self.true_generative_model, inference_network,
                                 self.test_obss))

            stats = util.OnlineMeanStd()
            for _ in range(10):
                # Clear gradients left by the previous pass so each update
                # sees the gradients of a single loss evaluation.
                inference_network.zero_grad()
                # Distinct local names: the ``loss``/``elbo`` parameters are
                # no longer shadowed by the evaluation values.
                eval_loss, _eval_elbo = get_eval_loss(
                    generative_model, inference_network, self.test_obss,
                    self.num_particles)
                eval_loss.backward()
                stats.update([p.grad for p in inference_network.parameters()])
            # Element 1 of avg_of_means_stds() is stored; presumably the
            # averaged standard deviation -- confirm against
            # util.OnlineMeanStd.
            self.grad_std_history.append(stats.avg_of_means_stds()[1])
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}'.format(iteration, self.p_error_history[-1],
                                self.q_error_history[-1]))