Example #1
0
    def log(self, iteration, train_stats):
        """Checkpoint the agent and/or log evaluation statistics.

        Every ``args.save_interval`` iterations the agent's ``state_dict`` is
        written to ``<tb_logger.full_output_folder>/models/agent<iteration>.pt``.

        Every ``args.log_interval`` iterations ``self.evaluate()`` is run. When
        ``args.log_tensorboard`` is set, figures, returns, success rate,
        losses and weight/gradient summaries are written to TensorBoard, all
        with ``self._n_rl_update_steps_total`` as the global step. Bookkeeping
        scalars are always written and a one-line summary is printed.

        :param iteration: index of the current training iteration.
        :param train_stats: mapping of loss name to value. Reads ``'qf_loss'``
            when ``args.policy == 'dqn'``; otherwise ``'qf1_loss'``,
            ``'qf2_loss'``, ``'policy_loss'`` and ``'alpha_entropy_loss'``.
        :return: None
        """
        # --- save model ---
        if iteration % self.args.save_interval == 0:
            save_path = os.path.join(self.tb_logger.full_output_folder,
                                     'models')
            # makedirs(exist_ok=True) also creates missing parent folders and
            # avoids the race between an existence check and the creation.
            os.makedirs(save_path, exist_ok=True)
            torch.save(
                self.agent.state_dict(),
                os.path.join(save_path, "agent{0}.pt".format(iteration)))

        if iteration % self.args.log_interval != 0:
            return

        # --- evaluate ---
        is_dqn = self.args.policy == 'dqn'
        if is_dqn:
            # the DQN evaluation does not return action log-probabilities
            (returns, success_rate, observations, rewards,
             reward_preds) = self.evaluate()
        else:
            (returns, success_rate, log_probs, observations, rewards,
             reward_preds) = self.evaluate()

        writer = self.tb_logger.writer
        step = self._n_rl_update_steps_total

        if self.args.log_tensorboard:
            # --- figures for a few randomly drawn evaluation tasks ---
            # np.random.choice samples with replacement here, so the same
            # task may be drawn more than once.
            tasks_to_vis = np.random.choice(self.args.num_eval_tasks, 5)
            for i, task in enumerate(tasks_to_vis):
                self.env.reset(task)
                if PLOT_VIS:
                    writer.add_figure(
                        'policy_vis/task_{}'.format(i),
                        utl_eval.plot_rollouts(observations[task, :],
                                               self.env),
                        step)
                writer.add_figure(
                    'reward_prediction_train/task_{}'.format(i),
                    utl_eval.plot_rew_pred_vs_rew(rewards[task, :],
                                                  reward_preds[task, :]),
                    step)

            # --- returns and success rate ---
            if self.args.max_rollouts_per_task > 1:
                for episode_idx in range(self.args.max_rollouts_per_task):
                    writer.add_scalar(
                        'returns_multi_episode/episode_{}'.format(
                            episode_idx + 1),
                        np.mean(returns[:, episode_idx]), step)
                writer.add_scalar('returns_multi_episode/sum',
                                  np.mean(np.sum(returns, axis=-1)), step)
                writer.add_scalar('returns_multi_episode/success_rate',
                                  np.mean(success_rate), step)
            else:
                writer.add_scalar('returns/returns_mean', np.mean(returns),
                                  step)
                writer.add_scalar('returns/returns_std', np.std(returns),
                                  step)
                writer.add_scalar('returns/success_rate',
                                  np.mean(success_rate), step)

            # --- losses, plus the (tag, network) pairs to summarise ---
            if is_dqn:
                writer.add_scalar('rl_losses/qf_loss_vs_n_updates',
                                  train_stats['qf_loss'], step)
                networks = [
                    ('q_network', self.agent.qf),
                    ('q_target', self.agent.target_qf),
                ]
            else:
                writer.add_scalar('policy/log_prob', np.mean(log_probs),
                                  step)
                for loss_name in ('qf1_loss', 'qf2_loss', 'policy_loss',
                                  'alpha_entropy_loss'):
                    writer.add_scalar('rl_losses/' + loss_name,
                                      train_stats[loss_name], step)
                networks = [
                    ('q1_network', self.agent.qf1),
                    ('q1_target', self.agent.qf1_target),
                    ('q2_network', self.agent.qf2),
                    ('q2_target', self.agent.qf2_target),
                    ('policy', self.agent.policy),
                ]

            # --- weights and gradients ---
            for tag, network in networks:
                params = list(network.parameters())
                # weight summary: mean of the first parameter tensor only
                writer.add_scalar('weights/' + tag, params[0].mean(), step)
                if params[0].grad is not None:
                    # gradient summary: sum of per-parameter gradient means;
                    # parameters without a gradient are skipped instead of
                    # raising AttributeError on `None.mean()`
                    writer.add_scalar(
                        'gradients/' + tag,
                        sum(p.grad.mean() for p in params
                            if p.grad is not None),
                        step)

        # --- bookkeeping scalars (written even without tensorboard figures) ---
        for k, v in [
            ('num_rl_updates', step),
            ('time_elapsed', time.time() - self._start_time),
            ('iteration', iteration),
        ]:
            writer.add_scalar(k, v, step)
        self.tb_logger.finish_iteration(iteration)

        print(
            "Iteration -- {}, Success rate -- {:.3f}, Avg. return -- {:.3f}, Elapsed time {:5d}[s]"
            .format(iteration, np.mean(success_rate),
                    np.mean(np.sum(returns, axis=-1)),
                    int(time.time() - self._start_time)))
Example #2
0
def train(vae, dataset, args):
    """Train the VAE on an offline multi-task dataset.

    Runs ``args.num_iters`` epochs. In each epoch the trajectory indices are
    shuffled and split into batches of ``args.vae_batch_num_rollouts_per_task``;
    for every trajectory batch the tasks are shuffled and split into batches
    of ``args.tasks_batch_size``. One optimizer step is taken per
    (trajectory batch, task batch) pair on the weighted sum of the reward
    reconstruction, state reconstruction and KL terms from ``update_step``.
    Epoch-averaged losses are printed and, if ``args.log_tensorboard`` is set,
    written to TensorBoard. Model weights are saved every
    ``args.save_interval`` epochs when ``args.save_model`` is set.

    :param vae: model passed to ``update_step``; its ``optimizer``,
        ``encoder``, ``reward_decoder`` and ``state_decoder`` attributes are
        used here.
    :param dataset: list with one entry per task. Items 0-3 of each entry are
        used as observations, actions, rewards and next observations; each is
        indexed by trajectory along axis 1. Only as many trajectories as the
        smallest task provides are used per epoch.
    :param args: configuration namespace (``num_iters``, batch sizes, loss
        coefficients, logging and saving options).
    :return: None
    """
    writer = SummaryWriter(
        args.full_save_path) if args.log_tensorboard else None

    num_tasks = len(dataset)

    start_time = time.time()
    total_updates = 0
    try:
        for iter_ in range(args.num_iters):
            rollouts_per_batch = args.vae_batch_num_rollouts_per_task
            # every task must supply the same trajectory indices, so the task
            # with the fewest trajectories bounds what can be sampled
            min_rollouts = min(d[0].shape[1] for d in dataset)
            n_batches = int(np.ceil(min_rollouts / rollouts_per_batch))
            traj_permutation = np.random.permutation(min_rollouts)

            # running sums over the epoch
            loss_tr, rew_loss_tr, state_loss_tr, kl_loss_tr = 0, 0, 0, 0
            n_updates = 0
            for i in tqdm(range(n_batches), desc="Epoch {}".format(iter_)):
                # slice ends are clipped to the array length, so the final
                # (possibly shorter) batch needs no special handling
                traj_indices = traj_permutation[i * rollouts_per_batch:
                                                (i + 1) * rollouts_per_batch]

                n_task_batches = int(
                    np.ceil(num_tasks / args.tasks_batch_size))
                task_permutation = np.random.permutation(num_tasks)

                for j in range(n_task_batches):  # run over tasks
                    task_indices = task_permutation[
                        j * args.tasks_batch_size:
                        (j + 1) * args.tasks_batch_size]

                    # stack the selected trajectories of all tasks in the
                    # batch along the trajectory axis (dim=1)
                    obs, actions, rewards, next_obs = (
                        torch.cat([
                            ptu.FloatTensor(
                                dataset[idx][field][:, traj_indices, :])
                            for idx in task_indices
                        ], dim=1)
                        for field in range(4))

                    rew_recon_loss, state_recon_loss, kl_term = update_step(
                        vae, obs, actions, rewards, next_obs, args)

                    # take average (this is the expectation over p(M))
                    loss = args.rew_loss_coeff * rew_recon_loss + \
                           args.state_loss_coeff * state_recon_loss + \
                           args.kl_weight * kl_term
                    # update
                    vae.optimizer.zero_grad()
                    loss.backward()
                    vae.optimizer.step()

                    n_updates += 1
                    loss_tr += loss.item()
                    rew_loss_tr += rew_recon_loss.item()
                    state_loss_tr += state_recon_loss.item()
                    kl_loss_tr += kl_term.item()

            # an epoch without any update (e.g. a task with no trajectories)
            # reports zeros instead of raising ZeroDivisionError
            denom = max(n_updates, 1)
            mean_loss = loss_tr / denom
            mean_rew_loss = rew_loss_tr / denom
            mean_state_loss = state_loss_tr / denom
            mean_kl = kl_loss_tr / denom

            print(
                'Elapsed time: {:.2f}, loss: {:.4f} -- rew_loss: {:.4f} -- state_loss: {:.4f} -- kl: {:.4f}'
                .format(time.time() - start_time, mean_loss, mean_rew_loss,
                        mean_state_loss, mean_kl))

            total_updates += n_updates
            # log tb
            if writer is not None:
                writer.add_scalar('loss/vae_loss', mean_loss, total_updates)
                writer.add_scalar('loss/rew_recon_loss', mean_rew_loss,
                                  total_updates)
                writer.add_scalar('loss/state_recon_loss', mean_state_loss,
                                  total_updates)
                writer.add_scalar('loss/kl', mean_kl, total_updates)
                if args.env_name != 'GridNavi-v2':  # TODO: eval for gridworld domain
                    rewards_eval, reward_preds_eval = eval_vae(
                        dataset, vae, args)
                    for task in range(NUM_EVAL_TASKS):
                        writer.add_figure(
                            'reward_prediction/task_{}'.format(task),
                            utl_eval.plot_rew_pred_vs_rew(
                                rewards_eval[task, :],
                                reward_preds_eval[task, :]),
                            total_updates)

            if args.save_model and (iter_ + 1) % args.save_interval == 0:
                save_path = os.path.join(os.getcwd(), args.full_save_path,
                                         'models')
                # creates missing parents and tolerates an existing directory
                os.makedirs(save_path, exist_ok=True)
                torch.save(
                    vae.encoder.state_dict(),
                    os.path.join(save_path,
                                 "encoder{0}.pt".format(iter_ + 1)))
                if vae.reward_decoder is not None:
                    torch.save(
                        vae.reward_decoder.state_dict(),
                        os.path.join(
                            save_path,
                            "reward_decoder{0}.pt".format(iter_ + 1)))
                if vae.state_decoder is not None:
                    torch.save(
                        vae.state_decoder.state_dict(),
                        os.path.join(
                            save_path,
                            "state_decoder{0}.pt".format(iter_ + 1)))
    finally:
        # flush pending events and release the event file, even on error
        if writer is not None:
            writer.close()