Example #1
def train_and_save_policy(env_name, seed, save_period, total_timesteps, hidden_dim=100):
    save_dir = "trained_policy/%s/seed_%d_hidden_%d" % (env_name, seed, hidden_dim)
    os.makedirs(save_dir, exist_ok=True)
    if len(glob.glob("%s/step_*.pkl" % save_dir)) > 0:
        print("already trained: %s" % save_dir)
        return

    def callback(_locals, _globals):
        global n_steps
        model_filepath = "%s/step_%d.pkl" % (save_dir, n_steps + 1)

        if (n_steps + 1) % save_period == 0:
            print('Saving a model to %s' % model_filepath)
            model.save(model_filepath)

        n_steps += 1
        return True

    global n_steps
    n_steps = 0

    env = gym.make(precise_env_name(env_name))
    env.seed(seed)
    set_global_seeds(seed)
    model = SAC(env, ent_coef='auto', seed=seed, hidden_dim=hidden_dim)

    model.learn(total_timesteps=total_timesteps, log_interval=10, seed=seed, callback=callback)
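A minimal invocation sketch for the helper above; the environment name, seed, and step counts are illustrative placeholders, and gym, SAC, set_global_seeds and precise_env_name are assumed to be importable from the surrounding project:

if __name__ == "__main__":
    # Hypothetical values: one environment, one seed, a checkpoint every 100k steps.
    train_and_save_policy(env_name="HalfCheetah",   # precise_env_name() resolves this to a gym id
                          seed=0,
                          save_period=100000,
                          total_timesteps=1000000,
                          hidden_dim=100)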
Example #2
def initialize_policy(self):
    if self.args.policy == 'dqn':
        q_network = FlattenMlp(input_size=self.args.augmented_obs_dim,
                               output_size=self.args.act_space.n,
                               hidden_sizes=self.args.dqn_layers).to(ptu.device)
        self.agent = DQN(
            q_network,
            # optimiser_vae=self.optimizer_vae,
            lr=self.args.policy_lr,
            gamma=self.args.gamma,
            tau=self.args.soft_target_tau,
        ).to(ptu.device)
    else:
        # assert self.args.act_space.__class__.__name__ == "Box", (
        #     "Can't train SAC with discrete action space!")
        q1_network = FlattenMlp(
            input_size=self.args.augmented_obs_dim + self.args.action_dim,
            output_size=1,
            hidden_sizes=self.args.dqn_layers).to(ptu.device)
        q2_network = FlattenMlp(
            input_size=self.args.augmented_obs_dim + self.args.action_dim,
            output_size=1,
            hidden_sizes=self.args.dqn_layers).to(ptu.device)
        policy = TanhGaussianPolicy(
            obs_dim=self.args.augmented_obs_dim,
            action_dim=self.args.action_dim,
            hidden_sizes=self.args.policy_layers).to(ptu.device)
        self.agent = SAC(
            policy,
            q1_network,
            q2_network,
            actor_lr=self.args.actor_lr,
            critic_lr=self.args.critic_lr,
            gamma=self.args.gamma,
            tau=self.args.soft_target_tau,
            use_cql=self.args.use_cql if 'use_cql' in self.args else False,
            alpha_cql=self.args.alpha_cql if 'alpha_cql' in self.args else None,
            entropy_alpha=self.args.entropy_alpha,
            automatic_entropy_tuning=self.args.automatic_entropy_tuning,
            alpha_lr=self.args.alpha_lr,
            clip_grad_value=self.args.clip_grad_value,
        ).to(ptu.device)
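The DQN/SAC branch above is selected by the caller from the action space type (see Example #5). A compact, hedged illustration of that dispatch, using an isinstance check instead of the class-name comparison used there:

import gym

def choose_policy(act_space):
    # Discrete action spaces get DQN; continuous (Box) spaces get SAC.
    return 'dqn' if isinstance(act_space, gym.spaces.Discrete) else 'sac'

For example, choose_policy(gym.make('CartPole-v1').action_space) returns 'dqn', while a continuous Box action space yields 'sac'.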
Example #3
def load_trained_agent(env_name, trained_policy_seed, trained_policy_step, bias_offset=0, seed=0, hidden_dim=64):
    env = gym.make(precise_env_name(env_name))
    trained_agent = SAC.load("trained_policy/%s/seed_%d_hidden_%d/step_%d.pkl" % (env_name, trained_policy_seed, hidden_dim, trained_policy_step), env, seed=seed, hidden_dim=hidden_dim)
    parameters = trained_agent.get_parameters()
    for i, parameter in enumerate(parameters):
        name, value = parameter
        if name == 'actor/f2_log_std/bias:0':
            parameters[i] = (name, value + bias_offset)
    trained_agent.load_parameters(parameters)
    return trained_agent
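A hedged usage sketch for the loader above; the environment name, checkpoint step and offset are placeholders. bias_offset is added to the bias of the actor's log-std layer, so a positive value makes the loaded policy noisier and a negative value makes it more deterministic:

# Load a checkpoint produced by train_and_save_policy and widen its action noise.
agent = load_trained_agent(env_name="HalfCheetah",      # placeholder environment name
                           trained_policy_seed=0,
                           trained_policy_step=500000,  # placeholder checkpoint step
                           bias_offset=0.5,             # placeholder log-std shift
                           seed=1,
                           hidden_dim=100)              # must match the hidden_dim used at training time
obs = agent.get_env().reset()                           # assuming the stable-baselines-style API
action, _ = agent.predict(obs, deterministic=True)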
Example #4
def load_agent(args, agent_path):
    q1_network = FlattenMlp(input_size=args.obs_dim + args.action_dim,
                            output_size=1,
                            hidden_sizes=args.dqn_layers)
    q2_network = FlattenMlp(input_size=args.obs_dim + args.action_dim,
                            output_size=1,
                            hidden_sizes=args.dqn_layers)
    policy = TanhGaussianPolicy(obs_dim=args.obs_dim,
                                action_dim=args.action_dim,
                                hidden_sizes=args.policy_layers)
    agent = SAC(policy,
                q1_network,
                q2_network,
                actor_lr=args.actor_lr,
                critic_lr=args.critic_lr,
                gamma=args.gamma,
                tau=args.soft_target_tau,
                entropy_alpha=args.entropy_alpha,
                automatic_entropy_tuning=args.automatic_entropy_tuning,
                alpha_lr=args.alpha_lr).to(ptu.device)
    agent.load_state_dict(torch.load(agent_path))
    return agent
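A short sketch of reloading a saved agent for evaluation, assuming args is the same argument namespace used at training time (it must carry obs_dim, action_dim, dqn_layers, policy_layers and the SAC hyperparameters); the checkpoint path and observation are placeholders:

agent = load_agent(args, agent_path="models/agent2000.pt")  # placeholder checkpoint path
agent.train(False)                                          # evaluation mode
obs = ptu.from_numpy(env.reset()).reshape(1, -1)            # env: a gym env created elsewhere (placeholder)
action, _, _, log_prob = agent.act(obs=obs,
                                   deterministic=True,
                                   return_log_prob=True)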
Example #5
class OfflineMetaLearner:
    """
    Offline meta-learner, i.e. it never interacts with the environment.
    """
    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        self.args, env = off_utl.expand_args(self.args, include_act_space=True)
        if self.args.act_space.__class__.__name__ == "Discrete":
            self.args.policy = 'dqn'
        else:
            self.args.policy = 'sac'

        # load buffers with data
        if 'load_data' not in self.args or self.args.load_data:
            goals, augmented_obs_dim = self.load_buffer(
                env)  # env is input just for possible relabelling option
            self.args.augmented_obs_dim = augmented_obs_dim
            self.goals = goals

        # initialize policy
        self.initialize_policy()

        # load vae for inference in evaluation
        self.load_vae()

        # create environment for evaluation
        self.env = make_env(
            args.env_name,
            args.max_rollouts_per_task,
            presampled_tasks=args.presampled_tasks,
            seed=args.seed,
        )
        # n_tasks=self.args.num_eval_tasks)
        if self.args.env_name == 'GridNavi-v2':
            self.env.unwrapped.goals = [
                tuple(goal.astype(int)) for goal in self.goals
            ]

    def initialize_policy(self):
        if self.args.policy == 'dqn':
            q_network = FlattenMlp(input_size=self.args.augmented_obs_dim,
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers).to(
                                       ptu.device)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        else:
            # assert self.args.act_space.__class__.__name__ == "Box", (
            #     "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(
                input_size=self.args.augmented_obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers).to(ptu.device)
            q2_network = FlattenMlp(
                input_size=self.args.augmented_obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers).to(ptu.device)
            policy = TanhGaussianPolicy(
                obs_dim=self.args.augmented_obs_dim,
                action_dim=self.args.action_dim,
                hidden_sizes=self.args.policy_layers).to(ptu.device)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
                use_cql=self.args.use_cql if 'use_cql' in self.args else False,
                alpha_cql=self.args.alpha_cql
                if 'alpha_cql' in self.args else None,
                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr,
                clip_grad_value=self.args.clip_grad_value,
            ).to(ptu.device)

    def load_vae(self):
        self.vae = VAE(self.args)
        vae_models_path = os.path.join(self.args.vae_dir, self.args.env_name,
                                       self.args.vae_model_name, 'models')
        off_utl.load_trained_vae(self.vae, vae_models_path)

    def load_buffer(self, env):
        if self.args.hindsight_relabelling:  # load as numpy arrays, otherwise GPU memory blows up
            dataset, goals = off_utl.load_dataset(
                data_dir=self.args.relabelled_data_dir,
                args=self.args,
                num_tasks=self.args.num_train_tasks,
                allow_dense_data_loading=False,
                arr_type='numpy')
            dataset = off_utl.batch_to_trajectories(dataset, self.args)
            dataset, goals = off_utl.mix_task_rollouts(
                dataset, env, goals, self.args)  # reward relabelling
            dataset = off_utl.trajectories_to_batch(dataset)
        else:
            dataset, goals = off_utl.load_dataset(
                data_dir=self.args.relabelled_data_dir,
                args=self.args,
                num_tasks=self.args.num_train_tasks,
                allow_dense_data_loading=False,
                arr_type='numpy')
        augmented_obs_dim = dataset[0][0].shape[1]
        self.storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=max([d[0].shape[0] for d in dataset]),
            obs_dim=dataset[0][0].shape[1],
            action_space=self.args.act_space,
            tasks=range(len(goals)),
            trajectory_len=self.args.trajectory_len)
        for task, task_data in enumerate(dataset):
            self.storage.add_samples(task,
                                     observations=task_data[0],
                                     actions=task_data[1],
                                     rewards=task_data[2],
                                     next_observations=task_data[3],
                                     terminals=task_data[4])
        return goals, augmented_obs_dim

    def train(self):
        self._start_training()
        for iter_ in range(self.args.num_iters):
            self.training_mode(True)
            indices = np.random.choice(len(self.goals), self.args.meta_batch)
            train_stats = self.update(indices)

            self.training_mode(False)
            self.log(iter_ + 1, train_stats)

    def update(self, tasks):
        rl_losses_agg = {}
        for update in range(self.args.rl_updates_per_iter):
            # sample random RL batch
            obs, actions, rewards, next_obs, terms = self.sample_rl_batch(
                tasks, self.args.batch_size)
            # flatten out task dimension
            t, b, _ = obs.size()
            obs = obs.view(t * b, -1)
            actions = actions.view(t * b, -1)
            rewards = rewards.view(t * b, -1)
            next_obs = next_obs.view(t * b, -1)
            terms = terms.view(t * b, -1)

            # RL update
            rl_losses = self.agent.update(obs,
                                          actions,
                                          rewards,
                                          next_obs,
                                          terms,
                                          action_space=self.env.action_space)

            for k, v in rl_losses.items():
                if update == 0:  # first update: create the list
                    rl_losses_agg[k] = [v]
                else:  # append values
                    rl_losses_agg[k].append(v)
        # take mean
        for k in rl_losses_agg:
            rl_losses_agg[k] = np.mean(rl_losses_agg[k])
        self._n_rl_update_steps_total += self.args.rl_updates_per_iter

        return rl_losses_agg

    def evaluate(self):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps
        num_tasks = self.args.num_eval_tasks
        obs_size = self.env.unwrapped.observation_space.shape[0]

        returns_per_episode = np.zeros((num_tasks, num_episodes))
        success_rate = np.zeros(num_tasks)

        rewards = np.zeros((num_tasks, self.args.trajectory_len))
        reward_preds = np.zeros((num_tasks, self.args.trajectory_len))
        observations = np.zeros(
            (num_tasks, self.args.trajectory_len + 1, obs_size))
        if self.args.policy == 'sac':
            log_probs = np.zeros((num_tasks, self.args.trajectory_len))

        # This part is very specific for the Semi-Circle env
        # if self.args.env_name == 'PointRobotSparse-v0':
        #     reward_belief = np.zeros((num_tasks, self.args.trajectory_len))
        #
        #     low_x, high_x, low_y, high_y = -2., 2., -1., 2.
        #     resolution = 0.1
        #     grid_x = np.arange(low_x, high_x + resolution, resolution)
        #     grid_y = np.arange(low_y, high_y + resolution, resolution)
        #     centers_x = (grid_x[:-1] + grid_x[1:]) / 2
        #     centers_y = (grid_y[:-1] + grid_y[1:]) / 2
        #     yv, xv = np.meshgrid(centers_y, centers_x, sparse=False, indexing='ij')
        #     centers = np.vstack([xv.ravel(), yv.ravel()]).T
        #     n_grid_points = centers.shape[0]
        #     reward_belief_discretized = np.zeros((num_tasks, self.args.trajectory_len, centers.shape[0]))

        for task_loop_i, task in enumerate(
                self.env.unwrapped.get_all_eval_task_idx()):
            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            # get prior parameters
            with torch.no_grad():
                task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder.prior(
                    batch_size=1)

            observations[task_loop_i,
                         step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    augmented_obs = self.get_augmented_obs(
                        obs, task_mean, task_logvar)
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=augmented_obs,
                                                       deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(
                            obs=augmented_obs,
                            deterministic=self.args.eval_deterministic,
                            return_log_prob=True)

                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(
                        self.env, action.squeeze(dim=0))
                    running_reward += reward.item()
                    # done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True
                    # update encoding
                    task_sample, task_mean, task_logvar, hidden_state = self.update_encoding(
                        obs=next_obs,
                        action=action,
                        reward=reward,
                        done=done,
                        hidden_state=hidden_state)
                    rewards[task_loop_i, step] = reward.item()
                    reward_preds[task_loop_i, step] = ptu.get_numpy(
                        self.vae.reward_decoder(task_sample, next_obs, obs,
                                                action)[0, 0])

                    # This part is very specific for the Semi-Circle env
                    # if self.args.env_name == 'PointRobotSparse-v0':
                    #     reward_belief[task, step] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean, task_logvar, obs, next_obs, action)[0])
                    #
                    #     reward_belief_discretized[task, step, :] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean.repeat(n_grid_points, 1),
                    #                                        task_logvar.repeat(n_grid_points, 1),
                    #                                        None,
                    #                                        torch.cat((ptu.FloatTensor(centers),
                    #                                                   ptu.zeros(centers.shape[0], 1)), dim=-1).unsqueeze(0),
                    #                                        None)[:, 0])

                    observations[task_loop_i, step + 1, :] = ptu.get_numpy(
                        next_obs[0, :obs_size])
                    if self.args.policy != 'dqn':
                        log_probs[task_loop_i,
                                  step] = ptu.get_numpy(log_prob[0])

                    if "is_goal_state" in dir(
                            self.env.unwrapped
                    ) and self.env.unwrapped.is_goal_state():
                        success_rate[task_loop_i] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task_loop_i, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, observations, rewards, reward_preds
        # This part is very specific for the Semi-Circle env
        # elif self.args.env_name == 'PointRobotSparse-v0':
        #     return returns_per_episode, success_rate, log_probs, observations, \
        #            rewards, reward_preds, reward_belief, reward_belief_discretized, centers
        else:
            return returns_per_episode, success_rate, log_probs, observations, rewards, reward_preds

    def log(self, iteration, train_stats):
        # --- save model ---
        if iteration % self.args.save_interval == 0:
            save_path = os.path.join(self.tb_logger.full_output_folder,
                                     'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(
                self.agent.state_dict(),
                os.path.join(save_path, "agent{0}.pt".format(iteration)))

        if iteration % self.args.log_interval == 0:
            if self.args.policy == 'dqn':
                returns, success_rate, observations, rewards, reward_preds = self.evaluate(
                )
            # This part is super specific for the Semi-Circle env
            # elif self.args.env_name == 'PointRobotSparse-v0':
            #     returns, success_rate, log_probs, observations, \
            #     rewards, reward_preds, reward_belief, reward_belief_discretized, points = self.evaluate()
            else:
                returns, success_rate, log_probs, observations, rewards, reward_preds = self.evaluate(
                )

            if self.args.log_tensorboard:
                tasks_to_vis = np.random.choice(self.args.num_eval_tasks, 5)
                for i, task in enumerate(tasks_to_vis):
                    self.env.reset(task)
                    if PLOT_VIS:
                        self.tb_logger.writer.add_figure(
                            'policy_vis/task_{}'.format(i),
                            utl_eval.plot_rollouts(observations[task, :],
                                                   self.env),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_figure(
                        'reward_prediction_train/task_{}'.format(i),
                        utl_eval.plot_rew_pred_vs_rew(rewards[task, :],
                                                      reward_preds[task, :]),
                        self._n_rl_update_steps_total)
                    # self.tb_logger.writer.add_figure('reward_prediction_train/task_{}'.format(i),
                    #                                  utl_eval.plot_rew_pred_vs_reward_belief_vs_rew(rewards[task, :],
                    #                                                                                 reward_preds[task, :],
                    #                                                                                 reward_belief[task, :]),
                    #                                  self._n_rl_update_steps_total)
                    # if self.args.env_name == 'PointRobotSparse-v0':     # This part is super specific for the Semi-Circle env
                    #     for t in range(0, int(self.args.trajectory_len/4), 3):
                    #         self.tb_logger.writer.add_figure('discrete_belief_reward_pred_task_{}/timestep_{}'.format(i, t),
                    #                                          utl_eval.plot_discretized_belief_halfcircle(reward_belief_discretized[task, t, :],
                    #                                                                                      points, self.env,
                    #                                                                                      observations[task, :t+1]),
                    #                                          self._n_rl_update_steps_total)
                if self.args.max_rollouts_per_task > 1:
                    for episode_idx in range(self.args.max_rollouts_per_task):
                        self.tb_logger.writer.add_scalar(
                            'returns_multi_episode/episode_{}'.format(
                                episode_idx + 1),
                            np.mean(returns[:, episode_idx]),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/sum',
                        np.mean(np.sum(returns, axis=-1)),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/success_rate',
                        np.mean(success_rate), self._n_rl_update_steps_total)
                else:
                    self.tb_logger.writer.add_scalar(
                        'returns/returns_mean', np.mean(returns),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns/returns_std', np.std(returns),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns/success_rate', np.mean(success_rate),
                        self._n_rl_update_steps_total)
                if self.args.policy == 'dqn':
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf_loss_vs_n_updates',
                        train_stats['qf_loss'], self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q_network',
                        list(self.agent.qf.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q_network',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q_target',
                        list(self.agent.target_qf.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.target_qf.parameters()
                            )[0].grad is not None:
                        param_list = list(self.agent.target_qf.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q_target',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                else:
                    self.tb_logger.writer.add_scalar(
                        'policy/log_prob', np.mean(log_probs),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf1_loss', train_stats['qf1_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf2_loss', train_stats['qf2_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/policy_loss', train_stats['policy_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/alpha_entropy_loss',
                        train_stats['alpha_entropy_loss'],
                        self._n_rl_update_steps_total)

                    # weights and gradients
                    self.tb_logger.writer.add_scalar(
                        'weights/q1_network',
                        list(self.agent.qf1.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf1.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf1.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q1_network',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q1_target',
                        list(self.agent.qf1_target.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf1_target.parameters()
                            )[0].grad is not None:
                        param_list = list(self.agent.qf1_target.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q1_target',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q2_network',
                        list(self.agent.qf2.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf2.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf2.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q2_network',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q2_target',
                        list(self.agent.qf2_target.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf2_target.parameters()
                            )[0].grad is not None:
                        param_list = list(self.agent.qf2_target.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q2_target',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/policy',
                        list(self.agent.policy.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.policy.parameters()
                            )[0].grad is not None:
                        param_list = list(self.agent.policy.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/policy',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)

            for k, v in [
                ('num_rl_updates', self._n_rl_update_steps_total),
                ('time_elapsed', time.time() - self._start_time),
                ('iteration', iteration),
            ]:
                self.tb_logger.writer.add_scalar(k, v,
                                                 self._n_rl_update_steps_total)
            self.tb_logger.finish_iteration(iteration)

            print(
                "Iteration -- {}, Success rate -- {:.3f}, Avg. return -- {:.3f}, Elapsed time {:5d}[s]"
                .format(iteration, np.mean(success_rate),
                        np.mean(np.sum(returns, axis=-1)),
                        int(time.time() - self._start_time)))

    def sample_rl_batch(self, tasks, batch_size):
        ''' sample batch of unordered rl training data from a list/array of tasks '''
        # this batch consists of transitions sampled randomly from replay buffer
        batches = [
            ptu.np_to_pytorch_batch(self.storage.random_batch(
                task, batch_size)) for task in tasks
        ]
        unpacked = [utl.unpack_batch(batch) for batch in batches]
        # group elements together
        unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))]
        unpacked = [torch.cat(x, dim=0) for x in unpacked]
        return unpacked

    def _start_training(self):
        self._n_rl_update_steps_total = 0
        self._start_time = time.time()

    def training_mode(self, mode):
        self.agent.train(mode)

    def update_encoding(self, obs, action, reward, done, hidden_state):
        # reset hidden state of the recurrent net when the task is done
        hidden_state = self.vae.encoder.reset_hidden(hidden_state, done)
        with torch.no_grad():  # size should be (batch, dim)
            task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder(
                actions=action.float(),
                states=obs,
                rewards=reward,
                hidden_state=hidden_state,
                return_prior=False)

        return task_sample, task_mean, task_logvar, hidden_state

    @staticmethod
    def get_augmented_obs(obs, mean, logvar):
        mean = mean.reshape((-1, mean.shape[-1]))
        logvar = logvar.reshape((-1, logvar.shape[-1]))
        return torch.cat((obs, mean, logvar), dim=-1)

    def load_model(self, agent_path, device='cpu'):
        self.agent.load_state_dict(torch.load(agent_path, map_location=device))
        self.load_vae()
        self.training_mode(False)
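A minimal driver sketch for the learner above, assuming args is a config namespace produced by the project's argument-parsing utilities; the checkpoint path is a placeholder:

learner = OfflineMetaLearner(args)
learner.train()   # runs args.num_iters iterations of offline RL updates and periodic logging
# Optionally reload a saved agent and evaluate it:
# learner.load_model("models/agent500.pt", device="cpu")
# returns, success_rate, *rest = learner.evaluate()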
Example #6
def run(config):
    torch.set_num_threads(1)
    env_descr = 'map%i_%iagents_task%i' % (config.map_ind, config.num_agents,
                                           config.task_config)
    model_dir = Path('./models') / config.env_type / env_descr / config.model_name
    if not model_dir.exists():
        run_num = 1
    else:
        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in
                         model_dir.iterdir() if
                         str(folder.name).startswith('run')]
        if len(exst_run_nums) == 0:
            run_num = 1
        else:
            run_num = max(exst_run_nums) + 1
    curr_run = 'run%i' % run_num
    run_dir = model_dir / curr_run
    log_dir = run_dir / 'logs'
    os.makedirs(log_dir)
    logger = SummaryWriter(str(log_dir))

    torch.manual_seed(run_num)
    np.random.seed(run_num)
    env = make_parallel_env(config, run_num)
    if config.nonlinearity == 'relu':
        nonlin = torch.nn.functional.relu
    elif config.nonlinearity == 'leaky_relu':
        nonlin = torch.nn.functional.leaky_relu
    if config.intrinsic_reward == 0:
        n_intr_rew_types = 0
        sep_extr_head = True
    else:
        n_intr_rew_types = len(config.explr_types)
        sep_extr_head = False
    n_rew_heads = n_intr_rew_types + int(sep_extr_head)

    model = SAC.init_from_env(env,
                              nagents=config.num_agents,
                              tau=config.tau,
                              hard_update_interval=config.hard_update,
                              pi_lr=config.pi_lr,
                              q_lr=config.q_lr,
                              phi_lr=config.phi_lr,
                              adam_eps=config.adam_eps,
                              q_decay=config.q_decay,
                              phi_decay=config.phi_decay,
                              gamma_e=config.gamma_e,
                              gamma_i=config.gamma_i,
                              pol_hidden_dim=config.pol_hidden_dim,
                              critic_hidden_dim=config.critic_hidden_dim,
                              nonlin=nonlin,
                              reward_scale=config.reward_scale,
                              head_reward_scale=config.head_reward_scale,
                              beta=config.beta,
                              n_intr_rew_types=n_intr_rew_types,
                              sep_extr_head=sep_extr_head)
    replay_buffer = ReplayBuffer(config.buffer_length, model.nagents,
                                 env.state_space,
                                 env.observation_space,
                                 env.action_space)
    intr_rew_rms = [[RunningMeanStd()
                     for i in range(config.num_agents)]
                    for j in range(n_intr_rew_types)]
    eps_this_turn = 0  # episodes so far this turn
    active_envs = np.ones(config.n_rollout_threads)  # binary indicator of whether env is active
    env_times = np.zeros(config.n_rollout_threads, dtype=int)
    env_ep_extr_rews = np.zeros(config.n_rollout_threads)
    env_extr_rets = np.zeros(config.n_rollout_threads)
    env_ep_intr_rews = [[np.zeros(config.n_rollout_threads) for i in range(config.num_agents)]
                        for j in range(n_intr_rew_types)]
    recent_ep_extr_rews = deque(maxlen=100)
    recent_ep_intr_rews = [[deque(maxlen=100) for i in range(config.num_agents)]
                           for j in range(n_intr_rew_types)]
    recent_ep_lens = deque(maxlen=100)
    recent_found_treasures = [deque(maxlen=100) for i in range(config.num_agents)]
    meta_turn_rets = []
    extr_ret_rms = [RunningMeanStd() for i in range(n_rew_heads)]
    t = 0
    steps_since_update = 0

    state, obs = env.reset()

    while t < config.train_time:
        model.prep_rollouts(device='cuda' if config.gpu_rollout else 'cpu')
        # convert to torch tensor
        torch_obs = apply_to_all_elements(obs, lambda x: torch.tensor(x, dtype=torch.float32, device='cuda' if config.gpu_rollout else 'cpu'))
        # get actions as torch tensors
        torch_agent_actions = model.step(torch_obs, explore=True)
        # convert actions to numpy arrays
        agent_actions = apply_to_all_elements(torch_agent_actions, lambda x: x.cpu().data.numpy())
        # rearrange actions to be per environment
        actions = [[ac[i] for ac in agent_actions] for i in range(int(active_envs.sum()))]
        try:
            with timeout(seconds=1):
                next_state, next_obs, rewards, dones, infos = env.step(actions, env_mask=active_envs)
        # either environment got stuck or vizdoom crashed (vizdoom is unstable w/ multi-agent scenarios)
        except (TimeoutError, ViZDoomErrorException, ViZDoomIsNotRunningException, ViZDoomUnexpectedExitException) as e:
            print("Environments are broken...")
            env.close(force=True)
            print("Closed environments, starting new...")
            env = make_parallel_env(config, run_num)
            state, obs = env.reset()
            env_ep_extr_rews[active_envs.astype(bool)] = 0.0
            env_extr_rets[active_envs.astype(bool)] = 0.0
            for i in range(n_intr_rew_types):
                for j in range(config.num_agents):
                    env_ep_intr_rews[i][j][active_envs.astype(bool)] = 0.0
            env_times = np.zeros(config.n_rollout_threads, dtype=int)
            state = apply_to_all_elements(state, lambda x: x[active_envs.astype(bool)])
            obs = apply_to_all_elements(obs, lambda x: x[active_envs.astype(bool)])
            continue

        steps_since_update += int(active_envs.sum())
        if config.intrinsic_reward == 1:
            # if using state-visit counts, store state indices
            # shape = (n_envs, n_agents, n_inds)
            state_inds = np.array([i['visit_count_lookup'] for i in infos],
                                  dtype=int)
            state_inds_t = state_inds.transpose(1, 0, 2)
            novelties = get_count_based_novelties(env, state_inds_t, device='cpu')
            intr_rews = get_intrinsic_rewards(novelties, config, intr_rew_rms,
                                              update_irrms=True, active_envs=active_envs,
                                              device='cpu')
            intr_rews = apply_to_all_elements(intr_rews, lambda x: x.numpy().flatten())
        else:
            intr_rews = None
            state_inds = None
            state_inds_t = None

        replay_buffer.push(state, obs, agent_actions, rewards, next_state, next_obs, dones,
                           state_inds=state_inds)
        env_ep_extr_rews[active_envs.astype(bool)] += np.array(rewards)
        env_extr_rets[active_envs.astype(bool)] += np.array(rewards) * config.gamma_e**(env_times[active_envs.astype(bool)])
        env_times += active_envs.astype(int)
        if intr_rews is not None:
            for i in range(n_intr_rew_types):
                for j in range(config.num_agents):
                    env_ep_intr_rews[i][j][active_envs.astype(bool)] += intr_rews[i][j]
        over_time = env_times >= config.max_episode_length
        full_dones = np.zeros(config.n_rollout_threads)
        for i, env_i in enumerate(np.where(active_envs)[0]):
            full_dones[env_i] = dones[i]
        need_reset = np.logical_or(full_dones, over_time)
        # create masks ONLY for active envs
        active_over_time = env_times[active_envs.astype(bool)] >= config.max_episode_length
        active_need_reset = np.logical_or(dones, active_over_time)
        if any(need_reset):
            try:
                with timeout(seconds=1):
                    # reset any environments that are past the max number of time steps or done
                    state, obs = env.reset(need_reset=need_reset)
            # either environment got stuck or vizdoom crashed (vizdoom is unstable w/ multi-agent scenarios)
            except (TimeoutError, ViZDoomErrorException, ViZDoomIsNotRunningException, ViZDoomUnexpectedExitException) as e:
                print("Environments are broken...")
                env.close(force=True)
                print("Closed environments, starting new...")
                env = make_parallel_env(config, run_num)
                state, obs = env.reset()
                # other envs that were force reset (rest taken care of in subsequent code)
                other_reset = np.logical_not(need_reset)
                env_ep_extr_rews[other_reset.astype(bool)] = 0.0
                env_extr_rets[other_reset.astype(bool)] = 0.0
                for i in range(n_intr_rew_types):
                    for j in range(config.num_agents):
                        env_ep_intr_rews[i][j][other_reset.astype(bool)] = 0.0
                env_times = np.zeros(config.n_rollout_threads, dtype=int)
        else:
            state, obs = next_state, next_obs
        for env_i in np.where(need_reset)[0]:
            recent_ep_extr_rews.append(env_ep_extr_rews[env_i])
            meta_turn_rets.append(env_extr_rets[env_i])
            if intr_rews is not None:
                for j in range(n_intr_rew_types):
                    for k in range(config.num_agents):
                        # record intrinsic rewards per step (so we don't confuse shorter episodes with less intrinsic rewards)
                        recent_ep_intr_rews[j][k].append(env_ep_intr_rews[j][k][env_i] / env_times[env_i])
                        env_ep_intr_rews[j][k][env_i] = 0

            recent_ep_lens.append(env_times[env_i])
            env_times[env_i] = 0
            env_ep_extr_rews[env_i] = 0
            env_extr_rets[env_i] = 0
            eps_this_turn += 1

            if eps_this_turn + active_envs.sum() - 1 >= config.metapol_episodes:
                active_envs[env_i] = 0

        for i in np.where(active_need_reset)[0]:
            for j in range(config.num_agents):
                # len(infos) = number of active envs
                recent_found_treasures[j].append(infos[i]['n_found_treasures'][j])

        if eps_this_turn >= config.metapol_episodes:
            if not config.uniform_heads and n_rew_heads > 1:
                meta_turn_rets = np.array(meta_turn_rets)
                if all(errms.count < 1 for errms in extr_ret_rms):
                    for errms in extr_ret_rms:
                        errms.mean = meta_turn_rets.mean()
                extr_ret_rms[model.curr_pol_heads[0]].update(meta_turn_rets)
                for i in range(config.metapol_updates):
                    model.update_heads_onpol(meta_turn_rets, extr_ret_rms, logger=logger)
            pol_heads = model.sample_pol_heads(uniform=config.uniform_heads)
            model.set_pol_heads(pol_heads)
            eps_this_turn = 0
            meta_turn_rets = []
            active_envs = np.ones(config.n_rollout_threads)

        if any(need_reset):  # reset returns state and obs for all envs, so make sure we're only looking at active
            state = apply_to_all_elements(state, lambda x: x[active_envs.astype(bool)])
            obs = apply_to_all_elements(obs, lambda x: x[active_envs.astype(bool)])

        if (len(replay_buffer) >= max(config.batch_size,
                                      config.steps_before_update) and
                (steps_since_update >= config.steps_per_update)):
            steps_since_update = 0
            print('Updating at time step %i' % t)
            model.prep_training(device='cuda' if config.use_gpu else 'cpu')

            for u_i in range(config.num_updates):
                sample = replay_buffer.sample(config.batch_size,
                                              to_gpu=config.use_gpu,
                                              state_inds=(config.intrinsic_reward == 1))

                if config.intrinsic_reward == 0:  # no intrinsic reward
                    intr_rews = None
                    state_inds = None
                else:
                    sample, state_inds = sample
                    novelties = get_count_based_novelties(
                        env, state_inds,
                        device='cuda' if config.use_gpu else 'cpu')
                    intr_rews = get_intrinsic_rewards(novelties, config, intr_rew_rms,
                                                      update_irrms=False,
                                                      device='cuda' if config.use_gpu else 'cpu')

                model.update_critic(sample, logger=logger, intr_rews=intr_rews)
                model.update_policies(sample, logger=logger)
                model.update_all_targets()
            if len(recent_ep_extr_rews) > 10:
                logger.add_scalar('episode_rewards/extrinsic/mean',
                                  np.mean(recent_ep_extr_rews), t)
                logger.add_scalar('episode_lengths/mean',
                                  np.mean(recent_ep_lens), t)
                if config.intrinsic_reward == 1:
                    for i in range(n_intr_rew_types):
                        for j in range(config.num_agents):
                            logger.add_scalar('episode_rewards/intrinsic%i_agent%i/mean' % (i, j),
                                              np.mean(recent_ep_intr_rews[i][j]), t)
                for i in range(config.num_agents):
                    logger.add_scalar('agent%i/n_found_treasures' % i, np.mean(recent_found_treasures[i]), t)
                logger.add_scalar('total_n_found_treasures', sum(np.array(recent_found_treasures[i]) for i in range(config.num_agents)).mean(), t)

        if t % config.save_interval < config.n_rollout_threads:
            model.prep_training(device='cpu')
            os.makedirs(run_dir / 'incremental', exist_ok=True)
            model.save(run_dir / 'incremental' / ('model_%isteps.pt' % (t + 1)))
            model.save(run_dir / 'model.pt')

        t += active_envs.sum()
    model.prep_training(device='cpu')
    model.save(run_dir / 'model.pt')
    logger.close()
    env.close(force=(config.env_type == 'vizdoom'))
Example #7
class Learner:
    """
    Learner class.
    """
    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialise environment
        self.env = make_env(
            self.args.env_name,
            self.args.max_rollouts_per_task,
            seed=self.args.seed,
            n_tasks=1,
            modify_init_state_dist=self.args.modify_init_state_dist
            if 'modify_init_state_dist' in self.args else False,
            on_circle_init_state=self.args.on_circle_init_state
            if 'on_circle_init_state' in self.args else True)

        # save the buffer in a folder named after the seed and the task goal
        if hasattr(self.args, 'save_buffer') and self.args.save_buffer:
            env_dir = os.path.join(self.args.main_save_dir,
                                   '{}'.format(self.args.env_name))
            goal = self.env.unwrapped._goal
            self.output_dir = os.path.join(
                env_dir, self.args.save_dir,
                'seed_{}_'.format(self.args.seed) +
                off_utl.create_goal_path_ext_from_goal(goal))

        if self.args.save_models or self.args.save_buffer:
            os.makedirs(self.output_dir, exist_ok=True)
            config_utl.save_config_file(args, self.output_dir)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        # if not self.args.log_tensorboard:
        #     self.save_config_json_file()
        # unwrapped env to get some info about the environment
        unwrapped_env = self.env.unwrapped

        # compute the maximum trajectory length
        args.max_trajectory_len = unwrapped_env._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task
        self.args.max_trajectory_len = args.max_trajectory_len

        # get action / observation dimensions
        if isinstance(self.env.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.env.action_space.shape[0]
        self.args.obs_dim = self.env.observation_space.shape[0]
        self.args.num_states = unwrapped_env.num_states if hasattr(
            unwrapped_env, 'num_states') else None
        self.args.act_space = self.env.action_space

        # simulate env step to get reward types
        _, _, _, info = unwrapped_env.step(unwrapped_env.action_space.sample())
        reward_types = [
            reward_type for reward_type in list(info.keys())
            if reward_type.startswith('reward')
        ]

        # support dense-reward training (if the environment provides it)
        self.args.dense_train_sparse_test = self.args.dense_train_sparse_test \
            if 'dense_train_sparse_test' in self.args else False

        # initialize policy
        self.initialize_policy()
        # initialize buffer for RL updates
        self.policy_storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=int(self.args.policy_buffer_size),
            obs_dim=self.args.obs_dim,
            action_space=self.env.action_space,
            tasks=[0],
            trajectory_len=args.max_trajectory_len,
            num_reward_arrays=len(reward_types)
            if reward_types and self.args.dense_train_sparse_test else 1,
            reward_types=reward_types,
        )

        self.args.belief_reward = False  # initialize arg to not use belief rewards

    def initialize_policy(self):

        if self.args.policy == 'dqn':
            assert self.args.act_space.__class__.__name__ == "Discrete", (
                "Can't train DQN with continuous action space!")
            q_network = FlattenMlp(input_size=self.args.obs_dim,
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                gamma=self.args.gamma,
                eps_init=self.args.dqn_epsilon_init,
                eps_final=self.args.dqn_epsilon_final,
                exploration_iters=self.args.dqn_exploration_iters,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        # elif self.args.policy == 'ddqn':
        #     assert self.args.act_space.__class__.__name__ == "Discrete", (
        #         "Can't train DDQN with continuous action space!")
        #     q_network = FlattenMlp(input_size=self.args.obs_dim,
        #                            output_size=self.args.act_space.n,
        #                            hidden_sizes=self.args.dqn_layers)
        #     self.agent = DoubleDQN(
        #         q_network,
        #         # optimiser_vae=self.optimizer_vae,
        #         lr=self.args.policy_lr,
        #         eps_optim=self.args.dqn_eps,
        #         alpha_optim=self.args.dqn_alpha,
        #         gamma=self.args.gamma,
        #         eps_init=self.args.dqn_epsilon_init,
        #         eps_final=self.args.dqn_epsilon_final,
        #         exploration_iters=self.args.dqn_exploration_iters,
        #         tau=self.args.soft_target_tau,
        #     ).to(ptu.device)
        elif self.args.policy == 'sac':
            assert self.args.act_space.__class__.__name__ == "Box", (
                "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(input_size=self.args.obs_dim +
                                    self.args.action_dim,
                                    output_size=1,
                                    hidden_sizes=self.args.dqn_layers)
            q2_network = FlattenMlp(input_size=self.args.obs_dim +
                                    self.args.action_dim,
                                    output_size=1,
                                    hidden_sizes=self.args.dqn_layers)
            policy = TanhGaussianPolicy(obs_dim=self.args.obs_dim,
                                        action_dim=self.args.action_dim,
                                        hidden_sizes=self.args.policy_layers)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr).to(ptu.device)
        else:
            raise NotImplementedError

    def train(self):
        """
        meta-training loop
        """

        self._start_training()
        self.task_idx = 0
        for iter_ in range(self.args.num_iters):
            self.training_mode(True)
            if iter_ == 0:
                print('Collecting initial pool of data..')
                self.env.reset_task(idx=self.task_idx)
                self.collect_rollouts(
                    num_rollouts=self.args.num_init_rollouts_pool,
                    random_actions=True)
                print('Done!')
            # collect data from subset of train tasks

            self.env.reset_task(idx=self.task_idx)
            self.collect_rollouts(num_rollouts=self.args.num_rollouts_per_iter)
            # update
            train_stats = self.update([self.task_idx])
            self.training_mode(False)

            if self.args.policy == 'dqn':
                self.agent.set_exploration_parameter(iter_ + 1)
            # evaluate and log
            if (iter_ + 1) % self.args.log_interval == 0:
                self.log(iter_ + 1, train_stats)

    def update(self, tasks):
        '''
        RL updates.
        :param tasks: list/array of task indices; batches are sampled from these tasks
        :return: dict of RL losses averaged over this iteration's updates
        '''

        # --- RL TRAINING ---
        rl_losses_agg = {}
        for update in range(self.args.rl_updates_per_iter):
            # sample random RL batch
            obs, actions, rewards, next_obs, terms = self.sample_rl_batch(
                tasks, self.args.batch_size)
            # flatten out task dimension
            t, b, _ = obs.size()
            obs = obs.view(t * b, -1)
            actions = actions.view(t * b, -1)
            rewards = rewards.view(t * b, -1)
            next_obs = next_obs.view(t * b, -1)
            terms = terms.view(t * b, -1)

            # RL update
            rl_losses = self.agent.update(obs, actions, rewards, next_obs,
                                          terms)

            for k, v in rl_losses.items():
                if update == 0:  # first update: create the list
                    rl_losses_agg[k] = [v]
                else:  # append values
                    rl_losses_agg[k].append(v)
        # take mean
        for k in rl_losses_agg:
            rl_losses_agg[k] = np.mean(rl_losses_agg[k])
        self._n_rl_update_steps_total += self.args.rl_updates_per_iter

        return rl_losses_agg

    def evaluate(self, tasks):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps

        returns_per_episode = np.zeros((len(tasks), num_episodes))
        success_rate = np.zeros(len(tasks))

        if self.args.policy == 'dqn':
            values = np.zeros((len(tasks), self.args.max_trajectory_len))
        else:
            obs_size = self.env.unwrapped.observation_space.shape[0]
            observations = np.zeros(
                (len(tasks), self.args.max_trajectory_len + 1, obs_size))
            log_probs = np.zeros((len(tasks), self.args.max_trajectory_len))

        for task_idx, task in enumerate(tasks):

            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            if self.args.policy == 'sac':
                observations[task_idx,
                             step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=obs,
                                                       deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(
                            obs=obs,
                            deterministic=self.args.eval_deterministic,
                            return_log_prob=True)
                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(
                        self.env, action.squeeze(dim=0))
                    running_reward += reward.item()
                    if self.args.policy == 'dqn':
                        values[task_idx, step] = value.item()
                    else:
                        observations[task_idx, step + 1, :] = ptu.get_numpy(
                            next_obs[0, :obs_size])
                        log_probs[task_idx, step] = ptu.get_numpy(log_prob[0])

                    if "is_goal_state" in dir(
                            self.env.unwrapped
                    ) and self.env.unwrapped.is_goal_state():
                        success_rate[task_idx] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task_idx, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, values
        else:
            return returns_per_episode, success_rate, log_probs, observations

    def log(self, iteration, train_stats):
        # --- save models ---
        if iteration % self.args.save_interval == 0:
            if self.args.save_models:
                if self.args.log_tensorboard:
                    save_path = os.path.join(self.tb_logger.full_output_folder,
                                             'models')
                else:
                    save_path = os.path.join(self.output_dir, 'models')
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
                torch.save(
                    self.agent.state_dict(),
                    os.path.join(save_path, "agent{0}.pt".format(iteration)))
            if hasattr(self.args, 'save_buffer') and self.args.save_buffer:
                self.save_buffer()
        # evaluate to get more stats
        if self.args.policy == 'dqn':
            # get stats on train tasks
            returns_train, success_rate_train, values = self.evaluate([0])
        else:
            # get stats on train tasks
            returns_train, success_rate_train, log_probs, observations = self.evaluate(
                [0])

        if self.args.log_tensorboard:
            if self.args.policy != 'dqn':
                #     self.env.reset(0)
                #     self.tb_logger.writer.add_figure('policy_vis_train/task_0',
                #                                      utl_eval.plot_rollouts(observations[0, :], self.env),
                #                                      self._n_env_steps_total)
                #     obs, _, _, _, _ = self.sample_rl_batch(tasks=[0],
                #                                            batch_size=self.policy_storage.task_buffers[0].size())
                #     self.tb_logger.writer.add_figure('state_space_coverage/task_0',
                #                                      utl_eval.plot_visited_states(ptu.get_numpy(obs[0][:, :2]), self.env),
                #                                      self._n_env_steps_total)
                pass
            # some metrics
            self.tb_logger.writer.add_scalar(
                'metrics/successes_in_buffer',
                self._successes_in_buffer / self._n_env_steps_total,
                self._n_env_steps_total)

            if self.args.max_rollouts_per_task > 1:
                for episode_idx in range(self.args.max_rollouts_per_task):
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/episode_{}'.format(episode_idx +
                                                                  1),
                        np.mean(returns_train[:, episode_idx]),
                        self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'returns_multi_episode/sum',
                    np.mean(np.sum(returns_train, axis=-1)),
                    self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'returns_multi_episode/success_rate',
                    np.mean(success_rate_train), self._n_env_steps_total)
            else:
                self.tb_logger.writer.add_scalar('returns/returns_mean_train',
                                                 np.mean(returns_train),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns/returns_std_train',
                                                 np.std(returns_train),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns/success_rate_train',
                                                 np.mean(success_rate_train),
                                                 self._n_env_steps_total)
            # policy
            if self.args.policy == 'dqn':
                self.tb_logger.writer.add_scalar('policy/value_init',
                                                 np.mean(values[:, 0]),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'policy/value_halfway',
                    np.mean(values[:, int(values.shape[-1] / 2)]),
                    self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('policy/value_final',
                                                 np.mean(values[:, -1]),
                                                 self._n_env_steps_total)

                self.tb_logger.writer.add_scalar('policy/exploration_epsilon',
                                                 self.agent.eps,
                                                 self._n_env_steps_total)
                # RL losses
                self.tb_logger.writer.add_scalar(
                    'rl_losses/qf_loss_vs_n_updates', train_stats['qf_loss'],
                    self._n_rl_update_steps_total)
                self.tb_logger.writer.add_scalar(
                    'rl_losses/qf_loss_vs_n_env_steps', train_stats['qf_loss'],
                    self._n_env_steps_total)
            else:
                self.tb_logger.writer.add_scalar('policy/log_prob',
                                                 np.mean(log_probs),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/qf1_loss',
                                                 train_stats['qf1_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/qf2_loss',
                                                 train_stats['qf2_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/policy_loss',
                                                 train_stats['policy_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'rl_losses/alpha_entropy_loss',
                    train_stats['alpha_entropy_loss'], self._n_env_steps_total)

            # weights and gradients
            if self.args.policy == 'dqn':
                self.tb_logger.writer.add_scalar(
                    'weights/q_network',
                    list(self.agent.qf.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q_network',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/q_target',
                    list(self.agent.target_qf.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.target_qf.parameters())[0].grad is not None:
                    param_list = list(self.agent.target_qf.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q_target',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
            else:
                self.tb_logger.writer.add_scalar(
                    'weights/q1_network',
                    list(self.agent.qf1.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf1.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf1.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q1_network',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/q1_target',
                    list(self.agent.qf1_target.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf1_target.parameters()
                        )[0].grad is not None:
                    param_list = list(self.agent.qf1_target.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q1_target',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/q2_network',
                    list(self.agent.qf2.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf2.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf2.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q2_network',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/q2_target',
                    list(self.agent.qf2_target.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf2_target.parameters()
                        )[0].grad is not None:
                    param_list = list(self.agent.qf2_target.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q2_target',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/policy',
                    list(self.agent.policy.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.policy.parameters())[0].grad is not None:
                    param_list = list(self.agent.policy.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/policy',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)

        print(
            "Iteration -- {}, Success rate -- {:.3f}, Avg. return -- {:.3f}, Elapsed time {:5d}[s]"
            .format(iteration, np.mean(success_rate_train),
                    np.mean(np.sum(returns_train, axis=-1)),
                    int(time.time() - self._start_time)))
        # output to user
        # print("Iteration -- {:3d}, Num. RL updates -- {:6d}, Elapsed time {:5d}[s]".
        #       format(iteration,
        #              self._n_rl_update_steps_total,
        #              int(time.time() - self._start_time)))

    def training_mode(self, mode):
        self.agent.train(mode)

    def collect_rollouts(self, num_rollouts, random_actions=False):
        '''
        Collect rollouts for the current task and add them to the policy buffer.

        :param num_rollouts: number of rollouts to collect
        :param random_actions: whether to sample actions uniformly from the action space instead of using the policy
        :return:
        '''

        for rollout in range(num_rollouts):
            obs = ptu.from_numpy(self.env.reset(self.task_idx))
            obs = obs.reshape(-1, obs.shape[-1])
            done_rollout = False

            while not done_rollout:
                if random_actions:
                    if self.args.policy == 'dqn':
                        action = ptu.FloatTensor([[[
                            self.env.action_space.sample()
                        ]]]).long()  # Sample random action
                    else:
                        action = ptu.FloatTensor([
                            self.env.action_space.sample()
                        ])  # Sample random action
                else:
                    if self.args.policy == 'dqn':
                        action, _ = self.agent.act(obs=obs)  # DQN
                    else:
                        action, _, _, _ = self.agent.act(obs=obs)  # SAC
                # observe reward and next obs
                next_obs, reward, done, info = utl.env_step(
                    self.env, action.squeeze(dim=0))
                done_rollout = False if ptu.get_numpy(
                    done[0][0]) == 0. else True

                # add data to policy buffer - (s+, a, r, s'+, term)
                term = self.env.unwrapped.is_goal_state(
                ) if "is_goal_state" in dir(self.env.unwrapped) else False
                if self.args.dense_train_sparse_test:
                    rew_to_buffer = {
                        rew_type: rew
                        for rew_type, rew in info.items()
                        if rew_type.startswith('reward')
                    }
                else:
                    rew_to_buffer = ptu.get_numpy(reward.squeeze(dim=0))
                self.policy_storage.add_sample(
                    task=self.task_idx,
                    observation=ptu.get_numpy(obs.squeeze(dim=0)),
                    action=ptu.get_numpy(action.squeeze(dim=0)),
                    reward=rew_to_buffer,
                    terminal=np.array([term], dtype=float),
                    next_observation=ptu.get_numpy(next_obs.squeeze(dim=0)))

                # set: obs <- next_obs
                obs = next_obs.clone()

                # update statistics
                self._n_env_steps_total += 1
                if "is_goal_state" in dir(
                        self.env.unwrapped
                ) and self.env.unwrapped.is_goal_state():  # count successes
                    self._successes_in_buffer += 1
            self._n_rollouts_total += 1

    def sample_rl_batch(self, tasks, batch_size):
        ''' sample batch of unordered rl training data from a list/array of tasks '''
        # this batch consists of transitions sampled randomly from replay buffer
        batches = [
            ptu.np_to_pytorch_batch(
                self.policy_storage.random_batch(task, batch_size))
            for task in tasks
        ]
        unpacked = [utl.unpack_batch(batch) for batch in batches]
        # group elements together
        unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))]
        unpacked = [torch.cat(x, dim=0) for x in unpacked]
        return unpacked

    def _start_training(self):
        self._n_env_steps_total = 0
        self._n_rl_update_steps_total = 0
        self._n_vae_update_steps_total = 0
        self._n_rollouts_total = 0
        self._successes_in_buffer = 0

        self._start_time = time.time()

    def load_model(self, device='cpu', **kwargs):
        if "agent_path" in kwargs:
            self.agent.load_state_dict(
                torch.load(kwargs["agent_path"], map_location=device))
        self.training_mode(False)

    def save_buffer(self):
        size = self.policy_storage.task_buffers[0].size()
        np.save(os.path.join(self.output_dir, 'obs'),
                self.policy_storage.task_buffers[0]._observations[:size])
        np.save(os.path.join(self.output_dir, 'actions'),
                self.policy_storage.task_buffers[0]._actions[:size])
        if self.args.dense_train_sparse_test:
            for reward_type, reward_arr in self.policy_storage.task_buffers[
                    0]._rewards.items():
                np.save(os.path.join(self.output_dir, reward_type),
                        reward_arr[:size])
        else:
            np.save(os.path.join(self.output_dir, 'rewards'),
                    self.policy_storage.task_buffers[0]._rewards[:size])
        np.save(os.path.join(self.output_dir, 'next_obs'),
                self.policy_storage.task_buffers[0]._next_obs[:size])
        np.save(os.path.join(self.output_dir, 'terminals'),
                self.policy_storage.task_buffers[0]._terminals[:size])
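
The batching helpers above are the core of the per-task replay machinery: sample_rl_batch draws one batch per task, groups the fields across tasks and concatenates them, and update then flattens the task dimension before calling the agent. Below is a minimal, self-contained sketch of that pattern with toy tensors; the shapes and the helper name sample_rl_batch_sketch are illustrative only, and it assumes each per-task field carries a leading task dimension of size 1, so that torch.cat(..., dim=0) yields (num_tasks, batch, dim) tensors as the "t, b, _ = obs.size()" unpacking in update expects.

import torch


def sample_rl_batch_sketch(per_task_batches):
    # per_task_batches: list over tasks; each entry is a tuple of fields
    # (obs, actions, rewards, next_obs, terms), each shaped (1, batch, dim)
    num_fields = len(per_task_batches[0])
    # group elements together: one list per field, across tasks
    grouped = [[batch[i] for batch in per_task_batches] for i in range(num_fields)]
    # concatenate along the task dimension -> (num_tasks, batch, dim) per field
    return [torch.cat(field, dim=0) for field in grouped]


# toy data: 2 tasks, batch of 4, obs_dim 3, action_dim 2
per_task = [
    (torch.randn(1, 4, 3), torch.randn(1, 4, 2), torch.randn(1, 4, 1),
     torch.randn(1, 4, 3), torch.zeros(1, 4, 1))
    for _ in range(2)
]
obs, actions, rewards, next_obs, terms = sample_rl_batch_sketch(per_task)

# flatten out the task dimension exactly as update() does before the RL update
t, b, _ = obs.size()
obs = obs.view(t * b, -1)  # -> (2 * 4, 3)
assert obs.shape == (8, 3)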
Example #8
0
    def initialize_policy(self):

        if self.args.policy == 'dqn':
            assert self.args.act_space.__class__.__name__ == "Discrete", (
                "Can't train DQN with continuous action space!")
            q_network = FlattenMlp(input_size=self.args.obs_dim,
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                gamma=self.args.gamma,
                eps_init=self.args.dqn_epsilon_init,
                eps_final=self.args.dqn_epsilon_final,
                exploration_iters=self.args.dqn_exploration_iters,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        # elif self.args.policy == 'ddqn':
        #     assert self.args.act_space.__class__.__name__ == "Discrete", (
        #         "Can't train DDQN with continuous action space!")
        #     q_network = FlattenMlp(input_size=self.args.obs_dim,
        #                            output_size=self.args.act_space.n,
        #                            hidden_sizes=self.args.dqn_layers)
        #     self.agent = DoubleDQN(
        #         q_network,
        #         # optimiser_vae=self.optimizer_vae,
        #         lr=self.args.policy_lr,
        #         eps_optim=self.args.dqn_eps,
        #         alpha_optim=self.args.dqn_alpha,
        #         gamma=self.args.gamma,
        #         eps_init=self.args.dqn_epsilon_init,
        #         eps_final=self.args.dqn_epsilon_final,
        #         exploration_iters=self.args.dqn_exploration_iters,
        #         tau=self.args.soft_target_tau,
        #     ).to(ptu.device)
        elif self.args.policy == 'sac':
            assert self.args.act_space.__class__.__name__ == "Box", (
                "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(input_size=self.args.obs_dim +
                                    self.args.action_dim,
                                    output_size=1,
                                    hidden_sizes=self.args.dqn_layers)
            q2_network = FlattenMlp(input_size=self.args.obs_dim +
                                    self.args.action_dim,
                                    output_size=1,
                                    hidden_sizes=self.args.dqn_layers)
            policy = TanhGaussianPolicy(obs_dim=self.args.obs_dim,
                                        action_dim=self.args.action_dim,
                                        hidden_sizes=self.args.policy_layers)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr).to(ptu.device)
        else:
            raise NotImplementedError
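
For context on the critic sizes above: each SAC Q-network scores a concatenated (observation, action) pair, which is why its input size is obs_dim + action_dim while its output is a single scalar value. Below is a minimal plain-torch sketch of that sizing; make_q_network and the toy dimensions are chosen here for illustration and are not the repository's FlattenMlp.

import torch
import torch.nn as nn


def make_q_network(input_size, hidden_sizes):
    # simple MLP with a scalar Q-value head (stand-in for FlattenMlp)
    layers, last = [], input_size
    for h in hidden_sizes:
        layers += [nn.Linear(last, h), nn.ReLU()]
        last = h
    layers.append(nn.Linear(last, 1))
    return nn.Sequential(*layers)


obs_dim, action_dim = 4, 2
q1 = make_q_network(obs_dim + action_dim, hidden_sizes=(64, 64))

obs = torch.randn(8, obs_dim)
action = torch.randn(8, action_dim)
q_value = q1(torch.cat([obs, action], dim=-1))  # -> (8, 1)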
Example #9
0
class MetaLearnerSACMER:
    """
    Meta-Learner class.
    """
    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        # initialise environment
        self.env = make_env(self.args.env_name,
                            self.args.max_rollouts_per_task,
                            seed=self.args.seed,
                            n_tasks=self.args.num_tasks)

        # unwrapped env to get some info about the environment
        unwrapped_env = self.env.unwrapped
        # split to train/eval tasks
        shuffled_tasks = np.random.permutation(
            unwrapped_env.get_all_task_idx())
        self.train_tasks = shuffled_tasks[:self.args.num_train_tasks]
        if self.args.num_eval_tasks > 0:
            self.eval_tasks = shuffled_tasks[-self.args.num_eval_tasks:]
        else:
            self.eval_tasks = []
        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = unwrapped_env._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task
        self.args.max_trajectory_len = args.max_trajectory_len

        # get action / observation dimensions
        if isinstance(self.env.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.env.action_space.shape[0]
        self.args.obs_dim = self.env.observation_space.shape[0]
        self.args.num_states = unwrapped_env.num_states if hasattr(
            unwrapped_env, 'num_states') else None
        self.args.act_space = self.env.action_space

        # initialize policy
        self.initialize_policy()
        # initialize buffer for RL updates
        self.policy_storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=int(self.args.policy_buffer_size),
            obs_dim=self._get_augmented_obs_dim(),
            action_space=self.env.action_space,
            tasks=self.train_tasks,
            trajectory_len=args.max_trajectory_len,
        )
        self.current_experience_storage = None

        self.args.belief_reward = False  # initialize arg to not use belief rewards

    def initialize_current_experience_storage(self):
        self.current_experience_storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=int(self.args.num_tasks_sample *
                                       self.args.num_rollouts_per_iter *
                                       self.args.max_trajectory_len),
            obs_dim=self._get_augmented_obs_dim(),
            action_space=self.env.action_space,
            tasks=self.train_tasks,
            trajectory_len=self.args.max_trajectory_len,
        )

    def initialize_policy(self):

        if self.args.policy == 'dqn':
            assert self.args.act_space.__class__.__name__ == "Discrete", (
                "Can't train DQN with continuous action space!")
            q_network = FlattenMlp(input_size=self._get_augmented_obs_dim(),
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                eps_optim=self.args.dqn_eps,
                alpha_optim=self.args.dqn_alpha,
                gamma=self.args.gamma,
                eps_init=self.args.dqn_epsilon_init,
                eps_final=self.args.dqn_epsilon_final,
                exploration_iters=self.args.dqn_exploration_iters,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        elif self.args.policy == 'ddqn':
            assert self.args.act_space.__class__.__name__ == "Discrete", (
                "Can't train DDQN with continuous action space!")
            q_network = FlattenMlp(input_size=self._get_augmented_obs_dim(),
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers)
            self.agent = DoubleDQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                eps_optim=self.args.dqn_eps,
                alpha_optim=self.args.dqn_alpha,
                gamma=self.args.gamma,
                eps_init=self.args.dqn_epsilon_init,
                eps_final=self.args.dqn_epsilon_final,
                exploration_iters=self.args.dqn_exploration_iters,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        elif self.args.policy == 'sac':
            assert self.args.act_space.__class__.__name__ == "Box", (
                "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(input_size=self._get_augmented_obs_dim() +
                                    self.args.action_dim,
                                    output_size=1,
                                    hidden_sizes=self.args.dqn_layers)
            q2_network = FlattenMlp(input_size=self._get_augmented_obs_dim() +
                                    self.args.action_dim,
                                    output_size=1,
                                    hidden_sizes=self.args.dqn_layers)
            policy = TanhGaussianPolicy(obs_dim=self._get_augmented_obs_dim(),
                                        action_dim=self.args.action_dim,
                                        hidden_sizes=self.args.policy_layers)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr).to(ptu.device)
        else:
            raise NotImplementedError

    def train(self):
        """
        meta-training loop
        """

        self._start_training()
        for iter_ in range(self.args.num_iters):
            self.training_mode(True)
            # switch to belief reward
            if self.args.switch_to_belief_reward is not None and iter_ >= self.args.switch_to_belief_reward:
                self.args.belief_reward = True
            if iter_ == 0:
                print('Collecting initial pool of data..')
                for task in self.train_tasks:
                    self.task_idx = task
                    self.env.reset_task(idx=task)
                    # self.collect_rollouts(num_rollouts=self.args.num_init_rollouts_pool)
                    self.collect_rollouts(
                        num_rollouts=self.args.num_init_rollouts_pool,
                        random_actions=True)
                print('Done!')

            # collect data from subset of train tasks
            tasks_to_collect = [
                self.train_tasks[np.random.randint(len(self.train_tasks))]
                for _ in range(self.args.num_tasks_sample)
            ]
            self.initialize_current_experience_storage()
            for i in range(self.args.num_tasks_sample):
                task = tasks_to_collect[i]
                self.task_idx = task
                self.env.reset_task(idx=task)
                self.collect_rollouts(
                    num_rollouts=self.args.num_rollouts_per_iter)
            # update
            indices = np.random.choice(self.train_tasks, self.args.meta_batch)
            train_stats = self.update(indices,
                                      current_task_indices=tasks_to_collect)
            self.training_mode(False)

            if self.args.policy == 'dqn':
                self.agent.set_exploration_parameter(iter_ + 1)
            # evaluate and log
            if (iter_ + 1) % self.args.log_interval == 0:
                self.log(iter_ + 1, train_stats)

    def update(self, tasks, current_task_indices=None):
        '''
        Meta-update (MER-style; see the standalone Reptile sketch after this example).
        :param tasks: list/array of task indices; replay batches are sampled from these tasks
        :param current_task_indices: task indices collected in the most recent rollout phase
        :return: dict of RL losses, averaged over the updates of this iteration
        '''
        if current_task_indices is None:
            current_task_indices = []

        # --- RL TRAINING ---
        rl_losses_agg = {}
        prev_qf1 = copy.deepcopy(self.agent.qf1).to(ptu.device)
        prev_qf2 = copy.deepcopy(self.agent.qf2).to(ptu.device)
        prev_qf1_target = copy.deepcopy(self.agent.qf1_target).to(ptu.device)
        prev_qf2_target = copy.deepcopy(self.agent.qf2_target).to(ptu.device)
        prev_policy = copy.deepcopy(self.agent.policy).to(ptu.device)
        # prev_alpha = copy.deepcopy(self.agent.log_alpha_entropy)  # not relevant when alpha is const
        prev_nets = [
            prev_qf1, prev_qf2, prev_qf1_target, prev_qf2_target, prev_policy
        ]

        updates_from_current_experience = np.random.choice(
            self.args.rl_updates_per_iter, self.args.mer_s, replace=False)

        for update in range(self.args.rl_updates_per_iter):
            # sample random RL batch
            obs, actions, rewards, next_obs, terms = self.sample_rl_batch(
                tasks if update not in updates_from_current_experience else
                current_task_indices,
                self.args.batch_size,
                is_sample_from_current_experience=(
                    update in updates_from_current_experience))
            # flatten out task dimension
            t, b, _ = obs.size()
            obs = obs.view(t * b, -1)
            actions = actions.view(t * b, -1)
            rewards = rewards.view(t * b, -1)
            next_obs = next_obs.view(t * b, -1)
            terms = terms.view(t * b, -1)

            # RL update
            rl_losses = self.agent.update(obs, actions, rewards, next_obs,
                                          terms)

            for k, v in rl_losses.items():
                if update == 0:  # first iteration - create list
                    rl_losses_agg[k] = [v]
                else:  # append values
                    rl_losses_agg[k].append(v)

        # reptile update
        new_nets = [
            self.agent.qf1, self.agent.qf2, self.agent.qf1_target,
            self.agent.qf2_target, self.agent.policy
        ]
        for new_net, prev_net in zip(new_nets, prev_nets):
            ptu.soft_update_from_to(new_net, prev_net,
                                    self.args.mer_gamma_policy)
            ptu.copy_model_params_from_to(prev_net, new_net)

        # take mean
        for k in rl_losses_agg:
            rl_losses_agg[k] = np.mean(rl_losses_agg[k])
        self._n_rl_update_steps_total += self.args.rl_updates_per_iter

        train_stats = {
            **rl_losses_agg,
        }

        return train_stats

    def evaluate(self, tasks):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps

        returns_per_episode = np.zeros((len(tasks), num_episodes))
        success_rate = np.zeros(len(tasks))

        if self.args.policy == 'dqn':
            values = np.zeros((len(tasks), self.args.max_trajectory_len))
        else:
            rewards = np.zeros((len(tasks), self.args.max_trajectory_len))
            obs_size = self.env.unwrapped.observation_space.shape[0]
            observations = np.zeros(
                (len(tasks), self.args.max_trajectory_len + 1, obs_size))

        for task_idx, task in enumerate(tasks):
            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            if self.args.policy == 'dqn':
                raise NotImplementedError
            else:
                observations[task_idx,
                             step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    augmented_obs = self.get_augmented_obs(obs=obs)
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=augmented_obs,
                                                       deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(
                            obs=augmented_obs,
                            deterministic=self.args.eval_deterministic,
                            return_log_prob=True)
                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(
                        self.env, action.squeeze(dim=0))
                    running_reward += reward.item()

                    if self.args.policy == 'dqn':
                        values[task_idx, step] = value.item()
                    else:
                        rewards[task_idx, step] = reward.item()
                        observations[task_idx, step + 1, :] = ptu.get_numpy(
                            next_obs[0, :obs_size])

                    if "is_goal_state" in dir(
                            self.env.unwrapped
                    ) and self.env.unwrapped.is_goal_state():
                        success_rate[task_idx] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task_idx, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, values
        else:
            return returns_per_episode, success_rate, observations, rewards

    def log(self, iteration, train_stats):
        # --- save models ---
        if iteration % self.args.save_interval == 0:
            save_path = os.path.join(self.tb_logger.full_output_folder,
                                     'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(
                self.agent.state_dict(),
                os.path.join(save_path, "agent{0}.pt".format(iteration)))

        # evaluate to get more stats
        if self.args.policy == 'dqn':
            # get stats on train tasks
            returns_train, success_rate_train, values = self.evaluate(
                self.train_tasks)
        else:
            # get stats on train tasks
            returns_train, success_rate_train, observations, rewards_train = self.evaluate(
                self.train_tasks[:len(self.eval_tasks)])
            returns_eval, success_rate_eval, observations_eval, rewards_eval = self.evaluate(
                self.eval_tasks)

        if self.args.log_tensorboard:
            # some metrics
            self.tb_logger.writer.add_scalar(
                'metrics/successes_in_buffer',
                self._successes_in_buffer / self._n_env_steps_total,
                self._n_env_steps_total)

            if self.args.max_rollouts_per_task > 1:
                for episode_idx in range(self.args.max_rollouts_per_task):
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/episode_{}'.format(episode_idx +
                                                                  1),
                        np.mean(returns_train[:, episode_idx]),
                        self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'returns_multi_episode/sum',
                    np.mean(np.sum(returns_train, axis=-1)),
                    self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'returns_multi_episode/success_rate',
                    np.mean(success_rate_train), self._n_env_steps_total)
                if self.args.policy != 'dqn':
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/sum_eval',
                        np.mean(np.sum(returns_eval, axis=-1)),
                        self._n_env_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/success_rate_eval',
                        np.mean(success_rate_eval), self._n_env_steps_total)
            else:
                # self.tb_logger.writer.add_scalar('returns/returns_mean', np.mean(returns),
                #                                  self._n_env_steps_total)
                # self.tb_logger.writer.add_scalar('returns/returns_std', np.std(returns),
                #                                  self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns/returns_mean_train',
                                                 np.mean(returns_train),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns/returns_std_train',
                                                 np.std(returns_train),
                                                 self._n_env_steps_total)
                # self.tb_logger.writer.add_scalar('returns/success_rate', np.mean(success_rate),
                #                                  self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns/success_rate_train',
                                                 np.mean(success_rate_train),
                                                 self._n_env_steps_total)

            # policy
            if self.args.policy == 'dqn':
                self.tb_logger.writer.add_scalar('policy/value_init',
                                                 np.mean(values[:, 0]),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'policy/value_halfway',
                    np.mean(values[:, int(values.shape[-1] / 2)]),
                    self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('policy/value_final',
                                                 np.mean(values[:, -1]),
                                                 self._n_env_steps_total)

                self.tb_logger.writer.add_scalar('policy/exploration_epsilon',
                                                 self.agent.eps,
                                                 self._n_env_steps_total)
                # RL losses
                self.tb_logger.writer.add_scalar(
                    'rl_losses/qf_loss_vs_n_updates', train_stats['qf_loss'],
                    self._n_rl_update_steps_total)
                self.tb_logger.writer.add_scalar(
                    'rl_losses/qf_loss_vs_n_env_steps', train_stats['qf_loss'],
                    self._n_env_steps_total)
            else:
                self.tb_logger.writer.add_scalar('rl_losses/qf1_loss',
                                                 train_stats['qf1_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/qf2_loss',
                                                 train_stats['qf2_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/policy_loss',
                                                 train_stats['policy_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'rl_losses/alpha_loss', train_stats['alpha_entropy_loss'],
                    self._n_env_steps_total)

            # weights and gradients
            if self.args.policy == 'dqn':
                self.tb_logger.writer.add_scalar(
                    'weights/q_network',
                    list(self.agent.qf.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q_network',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/q_target',
                    list(self.agent.target_qf.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.target_qf.parameters())[0].grad is not None:
                    param_list = list(self.agent.target_qf.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q_target',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
            else:
                self.tb_logger.writer.add_scalar(
                    'weights/q1_network',
                    list(self.agent.qf1.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf1.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf1.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q1_network',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/q1_target',
                    list(self.agent.qf1_target.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf1_target.parameters()
                        )[0].grad is not None:
                    param_list = list(self.agent.qf1_target.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q1_target',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/q2_network',
                    list(self.agent.qf2.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf2.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf2.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q2_network',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/q2_target',
                    list(self.agent.qf2_target.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.qf2_target.parameters()
                        )[0].grad is not None:
                    param_list = list(self.agent.qf2_target.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/q2_target',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar(
                    'weights/policy',
                    list(self.agent.policy.parameters())[0].mean(),
                    self._n_env_steps_total)
                if list(self.agent.policy.parameters())[0].grad is not None:
                    param_list = list(self.agent.policy.parameters())
                    self.tb_logger.writer.add_scalar(
                        'gradients/policy',
                        sum([
                            param_list[i].grad.mean()
                            for i in range(len(param_list))
                        ]), self._n_env_steps_total)

            # log iteration
            self.tb_logger.writer.add_scalar('iteration', iteration,
                                             self._n_env_steps_total)

        # output to user
        # print("Iteration -- {:3d}, Num. RL updates -- {:6d}, Elapsed time {:5d}[s]".
        #       format(iteration,
        #              self._n_rl_update_steps_total,
        #              int(time.time() - self._start_time)))
        print(
            "Iteration -- {}, Success rate train -- {:.3f}, Success rate eval.-- {:.3f}, "
            "Avg. return train -- {:.3f}, Avg. return eval. -- {:.3f}, Elapsed time {:5d}[s]"
            .format(iteration, np.mean(success_rate_train),
                    np.mean(success_rate_eval),
                    np.mean(np.sum(returns_train, axis=-1)),
                    np.mean(np.sum(returns_eval, axis=-1)),
                    int(time.time() - self._start_time)))

    def training_mode(self, mode):
        # policy
        self.agent.train(mode)

    def collect_rollouts(self, num_rollouts, random_actions=False):
        '''
        Collect rollouts for the current task and add them to the policy and current-experience buffers.

        :param num_rollouts: number of rollouts to collect
        :param random_actions: whether to sample actions uniformly from the action space instead of using the policy
        :return:
        '''

        for rollout in range(num_rollouts):
            obs = ptu.from_numpy(self.env.reset(self.task_idx))
            obs = obs.reshape(-1, obs.shape[-1])
            done_rollout = False
            # self.policy_storage.reset_running_episode(self.task_idx)

            # if self.args.fixed_latent_params:
            #     assert 2 ** self.args.task_embedding_size >= self.args.num_tasks
            #     task_mean = ptu.FloatTensor(utl.vertices(self.args.task_embedding_size)[self.task_idx])
            #     task_logvar = -2. * ptu.ones_like(task_logvar)   # arbitrary negative enough number
            # add distribution parameters to observation - policy is conditioned on posterior
            augmented_obs = self.get_augmented_obs(obs=obs)

            while not done_rollout:
                if random_actions:
                    if self.args.policy == 'dqn':
                        action = ptu.FloatTensor([[
                            self.env.action_space.sample()
                        ]]).type(torch.long)  # Sample random action
                    else:
                        action = ptu.FloatTensor(
                            [self.env.action_space.sample()])
                else:
                    if self.args.policy == 'dqn':
                        action, _ = self.agent.act(obs=augmented_obs)  # DQN
                    else:
                        action, _, _, _ = self.agent.act(
                            obs=augmented_obs)  # SAC
                # observe reward and next obs
                next_obs, reward, done, info = utl.env_step(
                    self.env, action.squeeze(dim=0))
                done_rollout = False if ptu.get_numpy(
                    done[0][0]) == 0. else True

                # get augmented next obs
                augmented_next_obs = self.get_augmented_obs(obs=next_obs)

                # add data to policy buffer - (s+, a, r, s'+, term)
                term = self.env.unwrapped.is_goal_state(
                ) if "is_goal_state" in dir(self.env.unwrapped) else False
                self.policy_storage.add_sample(
                    task=self.task_idx,
                    observation=ptu.get_numpy(augmented_obs.squeeze(dim=0)),
                    action=ptu.get_numpy(action.squeeze(dim=0)),
                    reward=ptu.get_numpy(reward.squeeze(dim=0)),
                    terminal=np.array([term], dtype=float),
                    next_observation=ptu.get_numpy(
                        augmented_next_obs.squeeze(dim=0)))
                if not random_actions:
                    self.current_experience_storage.add_sample(
                        task=self.task_idx,
                        observation=ptu.get_numpy(
                            augmented_obs.squeeze(dim=0)),
                        action=ptu.get_numpy(action.squeeze(dim=0)),
                        reward=ptu.get_numpy(reward.squeeze(dim=0)),
                        terminal=np.array([term], dtype=float),
                        next_observation=ptu.get_numpy(
                            augmented_next_obs.squeeze(dim=0)))

                # set: obs <- next_obs
                obs = next_obs.clone()
                augmented_obs = augmented_next_obs.clone()

                # update statistics
                self._n_env_steps_total += 1
                if "is_goal_state" in dir(
                        self.env.unwrapped
                ) and self.env.unwrapped.is_goal_state():  # count successes
                    self._successes_in_buffer += 1
            self._n_rollouts_total += 1

    def get_augmented_obs(self, obs):
        augmented_obs = obs.clone()
        return augmented_obs

    def _get_augmented_obs_dim(self):
        dim = utl.get_dim(self.env.observation_space)
        return dim

    def sample_rl_batch(self,
                        tasks,
                        batch_size,
                        is_sample_from_current_experience=False):
        ''' sample batch of unordered rl training data from a list/array of tasks '''
        # this batch consists of transitions sampled randomly from replay buffer
        if is_sample_from_current_experience:
            batches = [
                ptu.np_to_pytorch_batch(
                    self.current_experience_storage.random_batch(
                        task, batch_size)) for task in tasks
            ]
        else:
            batches = [
                ptu.np_to_pytorch_batch(
                    self.policy_storage.random_batch(task, batch_size))
                for task in tasks
            ]
        unpacked = [utl.unpack_batch(batch) for batch in batches]
        # group elements together
        unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))]
        unpacked = [torch.cat(x, dim=0) for x in unpacked]
        return unpacked

    def _start_training(self):
        self._n_env_steps_total = 0
        self._n_rl_update_steps_total = 0
        self._n_rollouts_total = 0
        self._successes_in_buffer = 0

        self._start_time = time.time()

    def load_model(self, device='cpu', **kwargs):
        if "agent_path" in kwargs:
            self.agent.load_state_dict(
                torch.load(kwargs["agent_path"], map_location=device))
        self.training_mode(False)
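
The MER-specific piece of the class above is the Reptile-style meta-update in update(): the agent's networks are cloned before the inner RL updates (a fixed number of which, mer_s, are drawn from the current-experience buffer rather than the full replay buffer), and afterwards the pre-update weights are moved a fraction mer_gamma_policy of the way towards the post-update weights, which then become the live weights again. A minimal standalone sketch of that interpolation follows; reptile_interpolate and the toy nn.Linear network are illustrative stand-ins that mirror what ptu.soft_update_from_to and ptu.copy_model_params_from_to are assumed to do.

import copy

import torch
import torch.nn as nn


def reptile_interpolate(new_net, prev_net, gamma):
    # prev <- (1 - gamma) * prev + gamma * new
    # (mirrors ptu.soft_update_from_to(new_net, prev_net, gamma))
    with torch.no_grad():
        for p_new, p_prev in zip(new_net.parameters(), prev_net.parameters()):
            p_prev.mul_(1.0 - gamma).add_(gamma * p_new)
    # copy the interpolated weights back into the live network
    # (mirrors ptu.copy_model_params_from_to(prev_net, new_net))
    new_net.load_state_dict(prev_net.state_dict())


policy = nn.Linear(4, 2)             # stand-in for the SAC policy network
prev_policy = copy.deepcopy(policy)  # snapshot taken before the inner updates

# pretend the inner RL updates changed the live network
with torch.no_grad():
    for p in policy.parameters():
        p.add_(torch.randn_like(p))

reptile_interpolate(policy, prev_policy, gamma=0.1)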