Example #1
    def __init__(self, args):

        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # calculate number of updates and keep count of frames/iterations
        self.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes
        self.frames = 0
        self.iter_idx = 0

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device=device,
            episodes_per_task=self.args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
        )

        # calculate the maximum trajectory length
        self.args.max_trajectory_len = self.envs._max_episode_steps
        self.args.max_trajectory_len *= self.args.max_rollouts_per_task

        # get policy input dimensions
        self.args.state_dim = self.envs.observation_space.shape[0]
        self.args.task_dim = self.envs.task_dim
        self.args.belief_dim = self.envs.belief_dim
        self.args.num_states = self.envs.num_states
        # get policy output (action) dimensions
        self.args.action_space = self.envs.action_space
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        elif isinstance(self.envs.action_space,
                        gym.spaces.multi_discrete.MultiDiscrete):
            self.args.action_dim = self.envs.action_space.nvec[0]
        else:
            self.args.action_dim = self.envs.action_space.shape[0]

        # initialise VAE and policy
        self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)
        self.policy_storage = self.initialise_policy_storage()
        self.policy = self.initialise_policy()
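
The branching on the action space above decides how many numbers the policy must output per step. A minimal standalone sketch of that logic (plain gym only; the helper name infer_action_dim is mine, not from the repo):

import gym

def infer_action_dim(action_space):
    # a discrete space is encoded as a single integer index per step
    if isinstance(action_space, gym.spaces.Discrete):
        return 1
    # for MultiDiscrete the example above only uses the first component's cardinality
    if isinstance(action_space, gym.spaces.MultiDiscrete):
        return int(action_space.nvec[0])
    # continuous (Box) spaces: one value per action dimension
    return action_space.shape[0]

print(infer_action_dim(gym.spaces.Discrete(4)))                        # 1
print(infer_action_dim(gym.spaces.Box(low=-1., high=1., shape=(6,))))  # 6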
Example #2
    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        # initialise environment
        self.env = make_env(self.args.env_name,
                            self.args.max_rollouts_per_task,
                            seed=self.args.seed,
                            n_tasks=self.args.num_tasks)

        # unwrapped env to get some info about the environment
        unwrapped_env = self.env.unwrapped
        # split to train/eval tasks
        shuffled_tasks = np.random.permutation(
            unwrapped_env.get_all_task_idx())
        self.train_tasks = shuffled_tasks[:self.args.num_train_tasks]
        if self.args.num_eval_tasks > 0:
            self.eval_tasks = shuffled_tasks[-self.args.num_eval_tasks:]
        else:
            self.eval_tasks = []
        # maximum trajectory length = steps per episode * rollouts per task
        self.args.max_trajectory_len = (unwrapped_env._max_episode_steps *
                                        self.args.max_rollouts_per_task)

        # get action / observation dimensions
        if isinstance(self.env.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.env.action_space.shape[0]
        self.args.obs_dim = self.env.observation_space.shape[0]
        self.args.num_states = unwrapped_env.num_states if hasattr(
            unwrapped_env, 'num_states') else None
        self.args.act_space = self.env.action_space

        # initialize policy
        self.initialize_policy()
        # initialize buffer for RL updates
        self.policy_storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=int(self.args.policy_buffer_size),
            obs_dim=self._get_augmented_obs_dim(),
            action_space=self.env.action_space,
            tasks=self.train_tasks,
            trajectory_len=args.max_trajectory_len,
        )
        self.current_experience_storage = None

        self.args.belief_reward = False  # initialize arg to not use belief rewards
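
The task split above just permutes all task indices once and slices the two ends. A hedged sketch of the same idea (the function name and the explicit non-overlap assumption are mine):

import numpy as np

def split_tasks(all_task_idx, num_train_tasks, num_eval_tasks, seed=0):
    # assumes num_train_tasks + num_eval_tasks <= len(all_task_idx),
    # otherwise the train and eval slices would overlap
    shuffled = np.random.RandomState(seed).permutation(all_task_idx)
    train_tasks = shuffled[:num_train_tasks]
    eval_tasks = shuffled[-num_eval_tasks:] if num_eval_tasks > 0 else []
    return train_tasks, eval_tasks

train_tasks, eval_tasks = split_tasks(list(range(10)),
                                      num_train_tasks=8, num_eval_tasks=2)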
Example #3
    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        self.args, env = off_utl.expand_args(self.args, include_act_space=True)
        if self.args.act_space.__class__.__name__ == "Discrete":
            self.args.policy = 'dqn'
        else:
            self.args.policy = 'sac'

        # load buffers with data
        if 'load_data' not in self.args or self.args.load_data:
            goals, augmented_obs_dim = self.load_buffer(
                env)  # env is passed only for the optional reward relabelling
            self.args.augmented_obs_dim = augmented_obs_dim
            self.goals = goals

        # initialize policy
        self.initialize_policy()

        # load vae for inference in evaluation
        self.load_vae()

        # create environment for evaluation
        self.env = make_env(
            args.env_name,
            args.max_rollouts_per_task,
            presampled_tasks=args.presampled_tasks,
            seed=args.seed,
        )
        # n_tasks=self.args.num_eval_tasks)
        if self.args.env_name == 'GridNavi-v2':
            self.env.unwrapped.goals = [
                tuple(goal.astype(int)) for goal in self.goals
            ]
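
For reference, the policy-family choice above (a class-name check on the action space) amounts to an isinstance test: discrete action spaces get a DQN-style learner, everything else gets SAC. A minimal sketch, with the helper name being mine:

import gym

def select_policy_family(act_space):
    return 'dqn' if isinstance(act_space, gym.spaces.Discrete) else 'sac'

print(select_policy_family(gym.spaces.Discrete(5)))                        # dqn
print(select_policy_family(gym.spaces.Box(low=-1., high=1., shape=(3,))))  # sac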
Example #4
    def __init__(self, args):
        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # count number of frames and number of meta-iterations
        self.frames = 0
        self.iter_idx = 0

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # calculate the maximum trajectory length
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # calculate number of meta updates
        self.args.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes

        # get action / observation dimensions
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]
        self.args.obs_dim = self.envs.observation_space.shape[0]
        self.args.num_states = (self.envs.num_states if
                                self.args.env_name.startswith('Grid') else None)
        self.args.act_space = self.envs.action_space

        self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)

        self.initialise_policy()
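
The meta-update count above is simply the total frame budget divided by the frames collected per update (num_processes parallel environments, each stepped policy_num_steps times). A worked example with illustrative values of my own choosing:

num_frames = 10_000_000   # total environment frames (illustrative)
policy_num_steps = 200    # steps per process between policy updates (illustrative)
num_processes = 16        # parallel environments (illustrative)
num_updates = int(num_frames) // policy_num_steps // num_processes
print(num_updates)        # 3125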
Example #5
class OfflineMetaLearner:
    """
    Off-line Meta-Learner class, i.e. no interaction with the environment.
    """
    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        self.args, env = off_utl.expand_args(self.args, include_act_space=True)
        if self.args.act_space.__class__.__name__ == "Discrete":
            self.args.policy = 'dqn'
        else:
            self.args.policy = 'sac'

        # load buffers with data
        if 'load_data' not in self.args or self.args.load_data:
            goals, augmented_obs_dim = self.load_buffer(
                env)  # env is passed only for the optional reward relabelling
            self.args.augmented_obs_dim = augmented_obs_dim
            self.goals = goals

        # initialize policy
        self.initialize_policy()

        # load vae for inference in evaluation
        self.load_vae()

        # create environment for evaluation
        self.env = make_env(
            args.env_name,
            args.max_rollouts_per_task,
            presampled_tasks=args.presampled_tasks,
            seed=args.seed,
        )
        # n_tasks=self.args.num_eval_tasks)
        if self.args.env_name == 'GridNavi-v2':
            self.env.unwrapped.goals = [
                tuple(goal.astype(int)) for goal in self.goals
            ]

    def initialize_policy(self):
        if self.args.policy == 'dqn':
            q_network = FlattenMlp(input_size=self.args.augmented_obs_dim,
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers).to(
                                       ptu.device)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        else:
            # assert self.args.act_space.__class__.__name__ == "Box", (
            #     "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(
                input_size=self.args.augmented_obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers).to(ptu.device)
            q2_network = FlattenMlp(
                input_size=self.args.augmented_obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers).to(ptu.device)
            policy = TanhGaussianPolicy(
                obs_dim=self.args.augmented_obs_dim,
                action_dim=self.args.action_dim,
                hidden_sizes=self.args.policy_layers).to(ptu.device)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
                use_cql=self.args.use_cql if 'use_cql' in self.args else False,
                alpha_cql=self.args.alpha_cql
                if 'alpha_cql' in self.args else None,
                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr,
                clip_grad_value=self.args.clip_grad_value,
            ).to(ptu.device)

    def load_vae(self):
        self.vae = VAE(self.args)
        vae_models_path = os.path.join(self.args.vae_dir, self.args.env_name,
                                       self.args.vae_model_name, 'models')
        off_utl.load_trained_vae(self.vae, vae_models_path)

    def load_buffer(self, env):
        if self.args.hindsight_relabelling:  # load as numpy (arr_type='numpy'), otherwise GPU memory blows up
            dataset, goals = off_utl.load_dataset(
                data_dir=self.args.relabelled_data_dir,
                args=self.args,
                num_tasks=self.args.num_train_tasks,
                allow_dense_data_loading=False,
                arr_type='numpy')
            dataset = off_utl.batch_to_trajectories(dataset, self.args)
            dataset, goals = off_utl.mix_task_rollouts(
                dataset, env, goals, self.args)  # reward relabelling
            dataset = off_utl.trajectories_to_batch(dataset)
        else:
            dataset, goals = off_utl.load_dataset(
                data_dir=self.args.relabelled_data_dir,
                args=self.args,
                num_tasks=self.args.num_train_tasks,
                allow_dense_data_loading=False,
                arr_type='numpy')
        augmented_obs_dim = dataset[0][0].shape[1]
        self.storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=max([d[0].shape[0] for d in dataset]),
            obs_dim=dataset[0][0].shape[1],
            action_space=self.args.act_space,
            tasks=range(len(goals)),
            trajectory_len=self.args.trajectory_len)
        for task, task_data in enumerate(dataset):
            self.storage.add_samples(task,
                                     observations=task_data[0],
                                     actions=task_data[1],
                                     rewards=task_data[2],
                                     next_observations=task_data[3],
                                     terminals=task_data[4])
        return goals, augmented_obs_dim

    def train(self):
        self._start_training()
        for iter_ in range(self.args.num_iters):
            self.training_mode(True)
            indices = np.random.choice(len(self.goals), self.args.meta_batch)
            train_stats = self.update(indices)

            self.training_mode(False)
            self.log(iter_ + 1, train_stats)

    def update(self, tasks):
        rl_losses_agg = {}
        for update in range(self.args.rl_updates_per_iter):
            # sample random RL batch
            obs, actions, rewards, next_obs, terms = self.sample_rl_batch(
                tasks, self.args.batch_size)
            # flatten out task dimension
            t, b, _ = obs.size()
            obs = obs.view(t * b, -1)
            actions = actions.view(t * b, -1)
            rewards = rewards.view(t * b, -1)
            next_obs = next_obs.view(t * b, -1)
            terms = terms.view(t * b, -1)

            # RL update
            rl_losses = self.agent.update(obs,
                                          actions,
                                          rewards,
                                          next_obs,
                                          terms,
                                          action_space=self.env.action_space)

            for k, v in rl_losses.items():
                if update == 0:  # first update - create the list
                    rl_losses_agg[k] = [v]
                else:  # append values
                    rl_losses_agg[k].append(v)
        # take mean
        for k in rl_losses_agg:
            rl_losses_agg[k] = np.mean(rl_losses_agg[k])
        self._n_rl_update_steps_total += self.args.rl_updates_per_iter

        return rl_losses_agg

    def evaluate(self):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps
        num_tasks = self.args.num_eval_tasks
        obs_size = self.env.unwrapped.observation_space.shape[0]

        returns_per_episode = np.zeros((num_tasks, num_episodes))
        success_rate = np.zeros(num_tasks)

        rewards = np.zeros((num_tasks, self.args.trajectory_len))
        reward_preds = np.zeros((num_tasks, self.args.trajectory_len))
        observations = np.zeros(
            (num_tasks, self.args.trajectory_len + 1, obs_size))
        if self.args.policy == 'sac':
            log_probs = np.zeros((num_tasks, self.args.trajectory_len))

        # This part is very specific to the Semi-Circle env
        # if self.args.env_name == 'PointRobotSparse-v0':
        #     reward_belief = np.zeros((num_tasks, self.args.trajectory_len))
        #
        #     low_x, high_x, low_y, high_y = -2., 2., -1., 2.
        #     resolution = 0.1
        #     grid_x = np.arange(low_x, high_x + resolution, resolution)
        #     grid_y = np.arange(low_y, high_y + resolution, resolution)
        #     centers_x = (grid_x[:-1] + grid_x[1:]) / 2
        #     centers_y = (grid_y[:-1] + grid_y[1:]) / 2
        #     yv, xv = np.meshgrid(centers_y, centers_x, sparse=False, indexing='ij')
        #     centers = np.vstack([xv.ravel(), yv.ravel()]).T
        #     n_grid_points = centers.shape[0]
        #     reward_belief_discretized = np.zeros((num_tasks, self.args.trajectory_len, centers.shape[0]))

        for task_loop_i, task in enumerate(
                self.env.unwrapped.get_all_eval_task_idx()):
            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            # get prior parameters
            with torch.no_grad():
                task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder.prior(
                    batch_size=1)

            observations[task_loop_i,
                         step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    augmented_obs = self.get_augmented_obs(
                        obs, task_mean, task_logvar)
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=augmented_obs,
                                                       deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(
                            obs=augmented_obs,
                            deterministic=self.args.eval_deterministic,
                            return_log_prob=True)

                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(
                        self.env, action.squeeze(dim=0))
                    running_reward += reward.item()
                    # done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True
                    # update encoding
                    task_sample, task_mean, task_logvar, hidden_state = self.update_encoding(
                        obs=next_obs,
                        action=action,
                        reward=reward,
                        done=done,
                        hidden_state=hidden_state)
                    rewards[task_loop_i, step] = reward.item()
                    reward_preds[task_loop_i, step] = ptu.get_numpy(
                        self.vae.reward_decoder(task_sample, next_obs, obs,
                                                action)[0, 0])

                    # This part is very specific to the Semi-Circle env
                    # if self.args.env_name == 'PointRobotSparse-v0':
                    #     reward_belief[task, step] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean, task_logvar, obs, next_obs, action)[0])
                    #
                    #     reward_belief_discretized[task, step, :] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean.repeat(n_grid_points, 1),
                    #                                        task_logvar.repeat(n_grid_points, 1),
                    #                                        None,
                    #                                        torch.cat((ptu.FloatTensor(centers),
                    #                                                   ptu.zeros(centers.shape[0], 1)), dim=-1).unsqueeze(0),
                    #                                        None)[:, 0])

                    observations[task_loop_i, step + 1, :] = ptu.get_numpy(
                        next_obs[0, :obs_size])
                    if self.args.policy != 'dqn':
                        log_probs[task_loop_i,
                                  step] = ptu.get_numpy(log_prob[0])

                    if "is_goal_state" in dir(
                            self.env.unwrapped
                    ) and self.env.unwrapped.is_goal_state():
                        success_rate[task_loop_i] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task_loop_i, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, observations, rewards, reward_preds
        # This part is very specific to the Semi-Circle env
        # elif self.args.env_name == 'PointRobotSparse-v0':
        #     return returns_per_episode, success_rate, log_probs, observations, \
        #            rewards, reward_preds, reward_belief, reward_belief_discretized, centers
        else:
            return returns_per_episode, success_rate, log_probs, observations, rewards, reward_preds

    def log(self, iteration, train_stats):
        # --- save model ---
        if iteration % self.args.save_interval == 0:
            save_path = os.path.join(self.tb_logger.full_output_folder,
                                     'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(
                self.agent.state_dict(),
                os.path.join(save_path, "agent{0}.pt".format(iteration)))

        if iteration % self.args.log_interval == 0:
            if self.args.policy == 'dqn':
                returns, success_rate, observations, rewards, reward_preds = self.evaluate(
                )
            # This part is very specific to the Semi-Circle env
            # elif self.args.env_name == 'PointRobotSparse-v0':
            #     returns, success_rate, log_probs, observations, \
            #     rewards, reward_preds, reward_belief, reward_belief_discretized, points = self.evaluate()
            else:
                returns, success_rate, log_probs, observations, rewards, reward_preds = self.evaluate(
                )

            if self.args.log_tensorboard:
                tasks_to_vis = np.random.choice(self.args.num_eval_tasks, 5)
                for i, task in enumerate(tasks_to_vis):
                    self.env.reset(task)
                    if PLOT_VIS:
                        self.tb_logger.writer.add_figure(
                            'policy_vis/task_{}'.format(i),
                            utl_eval.plot_rollouts(observations[task, :],
                                                   self.env),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_figure(
                        'reward_prediction_train/task_{}'.format(i),
                        utl_eval.plot_rew_pred_vs_rew(rewards[task, :],
                                                      reward_preds[task, :]),
                        self._n_rl_update_steps_total)
                    # self.tb_logger.writer.add_figure('reward_prediction_train/task_{}'.format(i),
                    #                                  utl_eval.plot_rew_pred_vs_reward_belief_vs_rew(rewards[task, :],
                    #                                                                                 reward_preds[task, :],
                    #                                                                                 reward_belief[task, :]),
                    #                                  self._n_rl_update_steps_total)
                    # if self.args.env_name == 'PointRobotSparse-v0':     # This part is very specific to the Semi-Circle env
                    #     for t in range(0, int(self.args.trajectory_len/4), 3):
                    #         self.tb_logger.writer.add_figure('discrete_belief_reward_pred_task_{}/timestep_{}'.format(i, t),
                    #                                          utl_eval.plot_discretized_belief_halfcircle(reward_belief_discretized[task, t, :],
                    #                                                                                      points, self.env,
                    #                                                                                      observations[task, :t+1]),
                    #                                          self._n_rl_update_steps_total)
                if self.args.max_rollouts_per_task > 1:
                    for episode_idx in range(self.args.max_rollouts_per_task):
                        self.tb_logger.writer.add_scalar(
                            'returns_multi_episode/episode_{}'.format(
                                episode_idx + 1),
                            np.mean(returns[:, episode_idx]),
                            self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/sum',
                        np.mean(np.sum(returns, axis=-1)),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns_multi_episode/success_rate',
                        np.mean(success_rate), self._n_rl_update_steps_total)
                else:
                    self.tb_logger.writer.add_scalar(
                        'returns/returns_mean', np.mean(returns),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns/returns_std', np.std(returns),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'returns/success_rate', np.mean(success_rate),
                        self._n_rl_update_steps_total)
                if self.args.policy == 'dqn':
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf_loss_vs_n_updates',
                        train_stats['qf_loss'], self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q_network',
                        list(self.agent.qf.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q_network',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q_target',
                        list(self.agent.target_qf.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.target_qf.parameters()
                            )[0].grad is not None:
                        param_list = list(self.agent.target_qf.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q_target',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                else:
                    self.tb_logger.writer.add_scalar(
                        'policy/log_prob', np.mean(log_probs),
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf1_loss', train_stats['qf1_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/qf2_loss', train_stats['qf2_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/policy_loss', train_stats['policy_loss'],
                        self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'rl_losses/alpha_entropy_loss',
                        train_stats['alpha_entropy_loss'],
                        self._n_rl_update_steps_total)

                    # weights and gradients
                    self.tb_logger.writer.add_scalar(
                        'weights/q1_network',
                        list(self.agent.qf1.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf1.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf1.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q1_network',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q1_target',
                        list(self.agent.qf1_target.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf1_target.parameters()
                            )[0].grad is not None:
                        param_list = list(self.agent.qf1_target.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q1_target',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q2_network',
                        list(self.agent.qf2.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf2.parameters())[0].grad is not None:
                        param_list = list(self.agent.qf2.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q2_network',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/q2_target',
                        list(self.agent.qf2_target.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.qf2_target.parameters()
                            )[0].grad is not None:
                        param_list = list(self.agent.qf2_target.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/q2_target',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)
                    self.tb_logger.writer.add_scalar(
                        'weights/policy',
                        list(self.agent.policy.parameters())[0].mean(),
                        self._n_rl_update_steps_total)
                    if list(self.agent.policy.parameters()
                            )[0].grad is not None:
                        param_list = list(self.agent.policy.parameters())
                        self.tb_logger.writer.add_scalar(
                            'gradients/policy',
                            sum([
                                param_list[i].grad.mean()
                                for i in range(len(param_list))
                            ]), self._n_rl_update_steps_total)

            for k, v in [
                ('num_rl_updates', self._n_rl_update_steps_total),
                ('time_elapsed', time.time() - self._start_time),
                ('iteration', iteration),
            ]:
                self.tb_logger.writer.add_scalar(k, v,
                                                 self._n_rl_update_steps_total)
            self.tb_logger.finish_iteration(iteration)

            print(
                "Iteration -- {}, Success rate -- {:.3f}, Avg. return -- {:.3f}, Elapsed time {:5d}[s]"
                .format(iteration, np.mean(success_rate),
                        np.mean(np.sum(returns, axis=-1)),
                        int(time.time() - self._start_time)))

    def sample_rl_batch(self, tasks, batch_size):
        ''' sample batch of unordered rl training data from a list/array of tasks '''
        # this batch consists of transitions sampled randomly from replay buffer
        batches = [
            ptu.np_to_pytorch_batch(self.storage.random_batch(
                task, batch_size)) for task in tasks
        ]
        unpacked = [utl.unpack_batch(batch) for batch in batches]
        # group elements together
        unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))]
        unpacked = [torch.cat(x, dim=0) for x in unpacked]
        return unpacked

    def _start_training(self):
        self._n_rl_update_steps_total = 0
        self._start_time = time.time()

    def training_mode(self, mode):
        self.agent.train(mode)

    def update_encoding(self, obs, action, reward, done, hidden_state):
        # reset hidden state of the recurrent net when the task is done
        hidden_state = self.vae.encoder.reset_hidden(hidden_state, done)
        with torch.no_grad():  # size should be (batch, dim)
            task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder(
                actions=action.float(),
                states=obs,
                rewards=reward,
                hidden_state=hidden_state,
                return_prior=False)

        return task_sample, task_mean, task_logvar, hidden_state

    @staticmethod
    def get_augmented_obs(obs, mean, logvar):
        mean = mean.reshape((-1, mean.shape[-1]))
        logvar = logvar.reshape((-1, logvar.shape[-1]))
        return torch.cat((obs, mean, logvar), dim=-1)

    def load_model(self, agent_path, device='cpu'):
        self.agent.load_state_dict(torch.load(agent_path, map_location=device))
        self.load_vae()
        self.training_mode(False)
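
get_augmented_obs above is what conditions the policy on the task belief: the raw observation is concatenated with the posterior mean and log-variance, so the augmented observation dimension is obs_dim + 2 * latent_dim. A quick shape check with illustrative dimensions (not taken from the example):

import torch

obs = torch.zeros(1, 4)      # obs_dim = 4 (illustrative)
mean = torch.zeros(1, 5)     # latent_dim = 5 (illustrative)
logvar = torch.zeros(1, 5)

augmented_obs = torch.cat((obs,
                           mean.reshape(-1, mean.shape[-1]),
                           logvar.reshape(-1, logvar.shape[-1])), dim=-1)
print(augmented_obs.shape)   # torch.Size([1, 14]) == obs_dim + 2 * latent_dim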
Example #6
class MetaLearner:
    """
    Meta-Learner class with the main training loop for variBAD.
    """
    def __init__(self, args):
        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # count number of frames and number of meta-iterations
        self.frames = 0
        self.iter_idx = 0

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # calculate the maximum trajectory length
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # calculate number of meta updates
        self.args.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes

        # get action / observation dimensions
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]
        self.args.obs_dim = self.envs.observation_space.shape[0]
        self.args.num_states = (self.envs.num_states if
                                self.args.env_name.startswith('Grid') else None)
        self.args.act_space = self.envs.action_space

        self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)

        self.initialise_policy()

    def initialise_policy(self):

        # initialise rollout storage for the policy
        self.policy_storage = OnlineStorage(
            self.args,
            self.args.policy_num_steps,
            self.args.num_processes,
            self.args.obs_dim,
            self.args.act_space,
            hidden_size=self.args.aggregator_hidden_size,
            latent_dim=self.args.latent_dim,
            normalise_observations=self.args.norm_obs_for_policy,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

        # initialise policy network
        input_dim = self.args.obs_dim * int(
            self.args.condition_policy_on_state)
        input_dim += (
            1 + int(not self.args.sample_embeddings)) * self.args.latent_dim

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        policy_net = Policy(
            state_dim=input_dim,
            action_space=self.args.act_space,
            init_std=self.args.policy_init_std,
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            normalise_actions=self.args.normalise_actions,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy trainer
        if self.args.policy == 'a2c':
            self.policy = A2C(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                optimiser_vae=self.vae.optimiser_vae,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                alpha=self.args.a2c_alpha,
            )
        elif self.args.policy == 'ppo':
            self.policy = PPO(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                optimiser_vae=self.vae.optimiser_vae,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError

    def train(self):
        """
        Given some stream of environments and a logger (tensorboard),
        (meta-)trains the policy.
        """

        start_time = time.time()

        # reset environments
        (prev_obs_raw, prev_obs_normalised) = self.envs.reset()
        prev_obs_raw = prev_obs_raw.to(device)
        prev_obs_normalised = prev_obs_normalised.to(device)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw)
        self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised)
        self.policy_storage.to(device)

        vae_is_pretrained = False
        for self.iter_idx in range(self.args.num_updates):

            # First, re-compute the hidden states given the current rollouts (since the VAE might've changed)
            # compute latent embedding (will return prior if current trajectory is empty)
            with torch.no_grad():
                latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory(
                )

            # check if we flushed the policy storage
            assert len(self.policy_storage.latent_mean) == 0

            # add this initial hidden state to the policy storage
            self.policy_storage.hidden_states[0].copy_(hidden_state)
            self.policy_storage.latent_samples.append(latent_sample.clone())
            self.policy_storage.latent_mean.append(latent_mean.clone())
            self.policy_storage.latent_logvar.append(latent_logvar.clone())

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        deterministic=False,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                    )
                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (
                    rew_raw, rew_normalised), done, infos = utl.env_step(
                        self.envs, action)
                tasks = torch.FloatTensor([info['task']
                                           for info in infos]).to(device)
                done = torch.from_numpy(np.array(
                    done, dtype=int)).to(device).float().view((-1, 1))

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                # compute next embedding (for next loop and/or value prediction bootstrap)
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=self.vae.encoder,
                    next_obs=next_obs_raw,
                    action=action,
                    reward=rew_raw,
                    done=done,
                    hidden_state=hidden_state)

                # before resetting, update the embedding and add to vae buffer
                # (last state might include useful task info)
                if not (self.args.disable_decoder
                        and self.args.disable_stochasticity_in_latent):
                    self.vae.rollout_storage.insert(prev_obs_raw.clone(),
                                                    action.detach().clone(),
                                                    next_obs_raw.clone(),
                                                    rew_raw.clone(),
                                                    done.clone(),
                                                    tasks.clone())

                # add the obs before reset to the policy storage
                # (only used to recompute embeddings if rlloss is backpropagated through encoder)
                self.policy_storage.next_obs_raw[step] = next_obs_raw.clone()
                self.policy_storage.next_obs_normalised[
                    step] = next_obs_normalised.clone()

                # reset environments that are done
                done_indices = np.argwhere(
                    done.cpu().detach().flatten()).flatten()
                if len(done_indices) == self.args.num_processes:
                    [next_obs_raw, next_obs_normalised] = self.envs.reset()
                    if not self.args.sample_embeddings:
                        latent_sample = latent_sample
                else:
                    for i in done_indices:
                        [next_obs_raw[i],
                         next_obs_normalised[i]] = self.envs.reset(index=i)
                        if not self.args.sample_embeddings:
                            latent_sample[i] = latent_sample[i]

                # add experience to policy buffer
                self.policy_storage.insert(
                    obs_raw=next_obs_raw,
                    obs_normalised=next_obs_normalised,
                    actions=action,
                    action_log_probs=action_log_prob,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=done,
                    hidden_states=hidden_state.squeeze(0).detach(),
                    latent_sample=latent_sample.detach(),
                    latent_mean=latent_mean.detach(),
                    latent_logvar=latent_logvar.detach(),
                )

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                self.frames += self.args.num_processes

            # --- UPDATE ---

            if self.args.precollect_len <= self.frames:
                # check if we are pre-training the VAE
                if self.args.pretrain_len > 0 and not vae_is_pretrained:
                    for _ in range(self.args.pretrain_len):
                        self.vae.compute_vae_loss(update=True)
                    vae_is_pretrained = True

                # otherwise do the normal update (policy + vae)
                else:

                    train_stats = self.update(
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar)

                    # log
                    run_stats = [action, action_log_prob, value]
                    if train_stats is not None:
                        self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()

    def encode_running_trajectory(self):
        """
        (Re-)Encodes (for each process) the entire current trajectory.
        Returns sample/mean/logvar and hidden state (if applicable) for the current timestep.
        :return:
        """

        # for each process, get the current batch (zero-padded obs/act/rew + length indicators)
        prev_obs, next_obs, act, rew, lens = self.vae.rollout_storage.get_running_batch(
        )

        # get embedding - will return (1+sequence_len) * batch * input_size -- includes the prior!
        all_latent_samples, all_latent_means, all_latent_logvars, all_hidden_states = self.vae.encoder(
            actions=act,
            states=next_obs,
            rewards=rew,
            hidden_state=None,
            return_prior=True)

        # get the embedding / hidden state of the current time step (need to do this since we zero-padded)
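        # (return_prior=True above prepends the prior at index 0, so index lens[i]
        # selects the posterior after the last real transition of process i)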
        latent_sample = (torch.stack([
            all_latent_samples[lens[i]][i] for i in range(len(lens))
        ])).detach().to(device)
        latent_mean = (torch.stack([
            all_latent_means[lens[i]][i] for i in range(len(lens))
        ])).detach().to(device)
        latent_logvar = (torch.stack([
            all_latent_logvars[lens[i]][i] for i in range(len(lens))
        ])).detach().to(device)
        hidden_state = (torch.stack([
            all_hidden_states[lens[i]][i] for i in range(len(lens))
        ])).detach().to(device)

        return latent_sample, latent_mean, latent_logvar, hidden_state

    def get_value(self, obs, latent_sample, latent_mean, latent_logvar):
        obs = utl.get_augmented_obs(self.args, obs, latent_sample, latent_mean,
                                    latent_logvar)
        return self.policy.actor_critic.get_value(obs).detach()

    def update(self, obs, latent_sample, latent_mean, latent_logvar):
        """
        Meta-update.
        Here the policy is updated for good average performance across tasks.
        :return:
        """
        # update policy (if we are not pre-training, have enough data in the vae buffer, and are not at iteration 0)
        if self.iter_idx >= self.args.pretrain_len and self.iter_idx > 0:

            # bootstrap next value prediction
            with torch.no_grad():
                next_value = self.get_value(obs=obs,
                                            latent_sample=latent_sample,
                                            latent_mean=latent_mean,
                                            latent_logvar=latent_logvar)

            # compute returns for current rollouts
            self.policy_storage.compute_returns(
                next_value,
                self.args.policy_use_gae,
                self.args.policy_gamma,
                self.args.policy_tau,
                use_proper_time_limits=self.args.use_proper_time_limits)

            # update agent (this will also call the VAE update!)
            policy_train_stats = self.policy.update(
                args=self.args,
                policy_storage=self.policy_storage,
                encoder=self.vae.encoder,
                rlloss_through_encoder=self.args.rlloss_through_encoder,
                compute_vae_loss=self.vae.compute_vae_loss)
        else:
            policy_train_stats = 0, 0, 0, 0

            # pre-train the VAE
            if self.iter_idx < self.args.pretrain_len:
                self.vae.compute_vae_loss(update=True)

        return policy_train_stats, None

    def log(self, run_stats, train_stats, start_time):
        train_stats, meta_train_stats = train_stats

        # --- visualise behaviour of policy ---

        if self.iter_idx % self.args.vis_interval == 0:
            obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            utl_eval.visualise_behaviour(
                args=self.args,
                policy=self.policy,
                image_folder=self.logger.full_output_folder,
                iter_idx=self.iter_idx,
                obs_rms=obs_rms,
                ret_rms=ret_rms,
                encoder=self.vae.encoder,
                reward_decoder=self.vae.reward_decoder,
                state_decoder=self.vae.state_decoder,
                task_decoder=self.vae.task_decoder,
                compute_rew_reconstruction_loss=self.vae.
                compute_rew_reconstruction_loss,
                compute_state_reconstruction_loss=self.vae.
                compute_state_reconstruction_loss,
                compute_task_reconstruction_loss=self.vae.
                compute_task_reconstruction_loss,
                compute_kl_loss=self.vae.compute_kl_loss,
            )

        # --- evaluate policy ----

        if self.iter_idx % self.args.eval_interval == 0:

            obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            returns_per_episode = utl_eval.evaluate(args=self.args,
                                                    policy=self.policy,
                                                    obs_rms=obs_rms,
                                                    ret_rms=ret_rms,
                                                    encoder=self.vae.encoder,
                                                    iter_idx=self.iter_idx)

            # log the return avg/std across tasks (=processes)
            returns_avg = returns_per_episode.mean(dim=0)
            returns_std = returns_per_episode.std(dim=0)
            for k in range(len(returns_avg)):
                self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1),
                                returns_avg[k], self.iter_idx)
                self.logger.add(
                    'return_avg_per_frame/episode_{}'.format(k + 1),
                    returns_avg[k], self.frames)
                self.logger.add('return_std_per_iter/episode_{}'.format(k + 1),
                                returns_std[k], self.iter_idx)
                self.logger.add(
                    'return_std_per_frame/episode_{}'.format(k + 1),
                    returns_std[k], self.frames)

            print(
                "Updates {}, num timesteps {}, FPS {}, {} \n Mean return (train): {:.5f} \n"
                .format(self.iter_idx, self.frames,
                        int(self.frames / (time.time() - start_time)),
                        self.vae.rollout_storage.prev_obs.shape,
                        returns_avg[-1].item()))

        # --- save models ---

        if self.iter_idx % self.args.save_interval == 0:
            save_path = os.path.join(self.logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(
                self.policy.actor_critic,
                os.path.join(save_path, "policy{0}.pt".format(self.iter_idx)))
            torch.save(
                self.vae.encoder,
                os.path.join(save_path, "encoder{0}.pt".format(self.iter_idx)))
            if self.vae.state_decoder is not None:
                torch.save(
                    self.vae.state_decoder,
                    os.path.join(save_path,
                                 "state_decoder{0}.pt".format(self.iter_idx)))
            if self.vae.reward_decoder is not None:
                torch.save(
                    self.vae.reward_decoder,
                    os.path.join(save_path,
                                 "reward_decoder{0}.pt".format(self.iter_idx)))
            if self.vae.task_decoder is not None:
                torch.save(
                    self.vae.task_decoder,
                    os.path.join(save_path,
                                 "task_decoder{0}.pt".format(self.iter_idx)))

            # save normalisation params of envs
            if self.args.norm_rew_for_policy:
                # save rolling mean and std
                rew_rms = self.envs.venv.ret_rms
                utl.save_obj(rew_rms, save_path,
                             "env_rew_rms{0}.pkl".format(self.iter_idx))
            if self.args.norm_obs_for_policy:
                obs_rms = self.envs.venv.obs_rms
                utl.save_obj(obs_rms, save_path,
                             "env_obs_rms{0}.pkl".format(self.iter_idx))

            # --- log some other things ---

        if self.iter_idx % self.args.log_interval == 0:

            self.logger.add('policy_losses/value_loss', train_stats[0],
                            self.iter_idx)
            self.logger.add('policy_losses/action_loss', train_stats[1],
                            self.iter_idx)
            self.logger.add('policy_losses/dist_entropy', train_stats[2],
                            self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx)

            self.logger.add('policy/action', run_stats[0][0].float().mean(),
                            self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd',
                                self.policy.actor_critic.dist.logstd.mean(),
                                self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(),
                            self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx)

            self.logger.add('encoder/latent_mean',
                            torch.cat(self.policy_storage.latent_mean).mean(),
                            self.iter_idx)
            self.logger.add(
                'encoder/latent_logvar',
                torch.cat(self.policy_storage.latent_logvar).mean(),
                self.iter_idx)

            # log the average weights and gradients of all models (where applicable)
            for model, name in [
                    [self.policy.actor_critic, 'policy'],
                    [self.vae.encoder, 'encoder'],
                    [self.vae.reward_decoder, 'reward_decoder'],
                    [self.vae.state_decoder, 'state_transition_decoder'],
                    [self.vae.task_decoder, 'task_decoder']]:
                if model is not None:
                    param_list = list(model.parameters())
                    param_mean = np.mean([
                        param.data.cpu().numpy().mean() for param in param_list
                    ])
                    self.logger.add('weights/{}'.format(name), param_mean,
                                    self.iter_idx)
                    if name == 'policy':
                        self.logger.add('weights/policy_std',
                                        param_list[0].data.mean(),
                                        self.iter_idx)
                    # only average over parameters that actually have a gradient
                    grads = [
                        param.grad.cpu().numpy().mean() for param in param_list
                        if param.grad is not None
                    ]
                    if len(grads) > 0:
                        self.logger.add('gradients/{}'.format(name),
                                        np.mean(grads), self.iter_idx)

    def load_and_render(self, load_iter):
        # NOTE: this path is hard-coded to a specific training run;
        # point it at your own log folder before calling this method.
        save_path = os.path.join(
            '/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahBlocks-v0/varibad_73__15:05_20:20:25',
            'models')
        self.policy.actor_critic = torch.load(
            os.path.join(save_path, "policy{0}.pt".format(load_iter)))
        self.vae.encoder = torch.load(
            os.path.join(save_path, "encoder{0}.pt".format(load_iter)))

        args = self.args
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        num_processes = 1
        num_episodes = 100
        num_steps = 1999

        # initialise environments
        envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=num_processes,  # 1
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = self.vae.encoder.prior(
            num_processes)

        for episode_idx in range(num_episodes):
            (prev_obs_raw, prev_obs_normalised) = envs.reset()
            prev_obs_raw = prev_obs_raw.to(device)
            prev_obs_normalised = prev_obs_normalised.to(device)
            for step_idx in range(num_steps):

                with torch.no_grad():
                    _, action, _ = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                        deterministic=True)

                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (
                    rew_raw,
                    rew_normalised), done, infos = utl.env_step(envs, action)
                # render (unwrap the vectorised / normalisation wrappers to reach the raw env)
                envs.venv.venv.envs[0].env.env.env.env.render()

                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=self.vae.encoder,
                    next_obs=next_obs_raw,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                if done[0]:
                    break
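
A minimal, self-contained sketch of the rollout pattern used in load_and_render above: reset a recurrent task belief to its prior once, then refresh it from each (obs, action, reward) transition before selecting the next action. BeliefEncoder and all dimensions below are illustrative stand-ins, not the repository's VaribadVAE encoder.

import torch
import torch.nn as nn

class BeliefEncoder(nn.Module):
    def __init__(self, obs_dim, action_dim, latent_dim, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gru = nn.GRU(obs_dim + action_dim + 1, hidden_dim)
        self.to_mean = nn.Linear(hidden_dim, latent_dim)
        self.to_logvar = nn.Linear(hidden_dim, latent_dim)

    def prior(self, batch_size):
        # all-zero hidden state = belief before observing any transition
        h = torch.zeros(1, batch_size, self.hidden_dim)
        mean, logvar = self.to_mean(h), self.to_logvar(h)
        sample = mean + torch.randn_like(mean) * (0.5 * logvar).exp()
        return sample, mean, logvar, h

    def forward(self, obs, action, reward, hidden_state):
        # one encoder step per environment step (analogous to update_encoding above)
        inp = torch.cat([obs, action, reward], dim=-1).unsqueeze(0)  # (1, B, D)
        _, h = self.gru(inp, hidden_state)
        mean, logvar = self.to_mean(h), self.to_logvar(h)
        sample = mean + torch.randn_like(mean) * (0.5 * logvar).exp()
        return sample, mean, logvar, h

# dummy rollout with a single process
encoder = BeliefEncoder(obs_dim=4, action_dim=2, latent_dim=5)
latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(batch_size=1)
for _ in range(3):
    obs, action, reward = torch.randn(1, 4), torch.randn(1, 2), torch.randn(1, 1)
    latent_sample, latent_mean, latent_logvar, hidden_state = encoder(
        obs, action, reward, hidden_state)
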
Example #7
0
class Learner:
    """
    Learner (no meta-learning), can be used to train avg/oracle/belief-oracle policies.
    """
    def __init__(self, args):

        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # calculate number of updates and keep count of frames/iterations
        self.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes
        self.frames = 0
        self.iter_idx = 0

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device=device,
            episodes_per_task=self.args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
        )

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # get policy input dimensions
        self.args.state_dim = self.envs.observation_space.shape[0]
        self.args.task_dim = self.envs.task_dim
        self.args.belief_dim = self.envs.belief_dim
        self.args.num_states = self.envs.num_states
        # get policy output (action) dimensions
        self.args.action_space = self.envs.action_space
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]

        # initialise policy
        self.policy_storage = self.initialise_policy_storage()
        self.policy = self.initialise_policy()

    def initialise_policy_storage(self):
        return OnlineStorage(
            args=self.args,
            num_steps=self.args.policy_num_steps,
            num_processes=self.args.num_processes,
            state_dim=self.args.state_dim,
            latent_dim=0,  # use metalearner.py if you want to use the VAE
            belief_dim=self.args.belief_dim,
            task_dim=self.args.task_dim,
            action_space=self.args.action_space,
            hidden_size=0,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

    def initialise_policy(self):

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        # initialise policy network
        policy_net = Policy(
            args=self.args,
            #
            pass_state_to_policy=self.args.pass_state_to_policy,
            # use metalearner.py if you want to use the VAE
            pass_latent_to_policy=False,
            pass_belief_to_policy=self.args.pass_belief_to_policy,
            pass_task_to_policy=self.args.pass_task_to_policy,
            dim_state=self.args.state_dim,
            dim_latent=0,
            dim_belief=self.args.belief_dim,
            dim_task=self.args.task_dim,
            #
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            policy_initialisation=self.args.policy_initialisation,
            #
            action_space=self.envs.action_space,
            init_std=self.args.policy_init_std,
            norm_actions_of_policy=self.args.norm_actions_of_policy,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy trainer
        if self.args.policy == 'a2c':
            policy = A2C(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
            )
        elif self.args.policy == 'ppo':
            policy = PPO(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError

        return policy

    def train(self):
        """ Main training loop """
        start_time = time.time()

        # reset environments
        state, belief, task = utl.reset_env(self.envs, self.args)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_state[0].copy_(state)

        # log once before training
        with torch.no_grad():
            self.log(None, None, start_time)

        for self.iter_idx in range(self.num_updates):

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        state=state,
                        belief=belief,
                        task=task,
                        deterministic=False)

                # observe reward and next obs
                [state, belief,
                 task], (rew_raw, rew_normalised), done, infos = utl.env_step(
                     self.envs, action, self.args)

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                # reset environments that are done
                done_indices = np.argwhere(done.flatten()).flatten()
                if len(done_indices) > 0:
                    state, belief, task = utl.reset_env(self.envs,
                                                        self.args,
                                                        indices=done_indices,
                                                        state=state)

                # add experience to policy buffer
                self.policy_storage.insert(
                    state=state,
                    belief=belief,
                    task=task,
                    actions=action,
                    action_log_probs=action_log_prob,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=torch.from_numpy(np.array(done,
                                                   dtype=float)).unsqueeze(1),
                )

                self.frames += self.args.num_processes

            # --- UPDATE ---

            train_stats = self.update(state=state, belief=belief, task=task)

            # log
            run_stats = [action, action_log_prob, value]
            if train_stats is not None:
                with torch.no_grad():
                    self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()

    def get_value(self, state, belief, task):
        return self.policy.actor_critic.get_value(state=state,
                                                  belief=belief,
                                                  task=task,
                                                  latent=None).detach()

    def update(self, state, belief, task):
        """
        Meta-update.
        Here the policy is updated for good average performance across tasks.
        :return:    policy_train_stats which are: value_loss_epoch, action_loss_epoch, dist_entropy_epoch, loss_epoch
        """
        # bootstrap next value prediction
        with torch.no_grad():
            next_value = self.get_value(state=state, belief=belief, task=task)

        # compute returns for current rollouts
        self.policy_storage.compute_returns(
            next_value,
            self.args.policy_use_gae,
            self.args.policy_gamma,
            self.args.policy_tau,
            use_proper_time_limits=self.args.use_proper_time_limits)

        policy_train_stats = self.policy.update(
            policy_storage=self.policy_storage)

        return policy_train_stats, None

    def log(self, run_stats, train_stats, start):
        """
        Evaluate policy, save model, write to tensorboard logger.
        """

        # --- visualise behaviour of policy ---

        if self.iter_idx % self.args.vis_interval == 0:

            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None
            utl_eval.visualise_behaviour(
                args=self.args,
                policy=self.policy,
                image_folder=self.logger.full_output_folder,
                iter_idx=self.iter_idx,
                ret_rms=ret_rms,
            )

        # --- evaluate policy ----

        if self.iter_idx % self.args.eval_interval == 0:

            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            returns_per_episode = utl_eval.evaluate(args=self.args,
                                                    policy=self.policy,
                                                    ret_rms=ret_rms,
                                                    iter_idx=self.iter_idx)

            # log the average return across tasks (=processes)
            returns_avg = returns_per_episode.mean(dim=0)
            returns_std = returns_per_episode.std(dim=0)
            for k in range(len(returns_avg)):
                self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1),
                                returns_avg[k], self.iter_idx)
                self.logger.add(
                    'return_avg_per_frame/episode_{}'.format(k + 1),
                    returns_avg[k], self.frames)
                self.logger.add('return_std_per_iter/episode_{}'.format(k + 1),
                                returns_std[k], self.iter_idx)
                self.logger.add(
                    'return_std_per_frame/episode_{}'.format(k + 1),
                    returns_std[k], self.frames)

            print(
                "Updates {}, num timesteps {}, FPS {} \n Mean return (train): {:.5f} \n"
                .format(self.iter_idx, self.frames,
                        int(self.frames / (time.time() - start)),
                        returns_avg[-1].item()))

        # save model
        if self.iter_idx % self.args.save_interval == 0:
            save_path = os.path.join(self.logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)

            idx_labels = ['']
            if self.args.save_intermediate_models:
                idx_labels.append(int(self.iter_idx))

            for idx_label in idx_labels:

                torch.save(self.policy.actor_critic,
                           os.path.join(save_path, f"policy{idx_label}.pt"))

                # save normalisation params of envs
                if self.args.norm_rew_for_policy:
                    rew_rms = self.envs.venv.ret_rms
                    utl.save_obj(rew_rms, save_path, f"env_rew_rms{idx_label}")
                # TODO: grab from policy and save?
                # if self.args.norm_obs_for_policy:
                #     obs_rms = self.envs.venv.obs_rms
                #     utl.save_obj(obs_rms, save_path, f"env_obs_rms{idx_label}")

        # --- log some other things ---

        if (self.iter_idx % self.args.log_interval == 0) and (train_stats
                                                              is not None):

            train_stats, _ = train_stats

            self.logger.add('policy_losses/value_loss', train_stats[0],
                            self.iter_idx)
            self.logger.add('policy_losses/action_loss', train_stats[1],
                            self.iter_idx)
            self.logger.add('policy_losses/dist_entropy', train_stats[2],
                            self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx)

            self.logger.add('policy/action', run_stats[0][0].float().mean(),
                            self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd',
                                self.policy.actor_critic.dist.logstd.mean(),
                                self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(),
                            self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx)

            param_list = list(self.policy.actor_critic.parameters())
            param_mean = np.mean([
                param_list[i].data.cpu().numpy().mean()
                for i in range(len(param_list))
            ])
            param_grad_mean = np.mean([
                param_list[i].grad.cpu().numpy().mean()
                for i in range(len(param_list))
            ])
            self.logger.add('weights/policy', param_mean, self.iter_idx)
            self.logger.add('weights/policy_std',
                            param_list[0].data.cpu().mean(), self.iter_idx)
            self.logger.add('gradients/policy', param_grad_mean, self.iter_idx)
            self.logger.add('gradients/policy_std',
                            param_list[0].grad.cpu().numpy().mean(),
                            self.iter_idx)
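
In update() above, compute_returns bootstraps from the value of the last state and (optionally) applies generalised advantage estimation before the A2C/PPO update. The following stand-alone sketch assumes the standard GAE recursion and a simplified flat (T, N) storage layout rather than the repository's OnlineStorage:

import torch

def compute_gae_returns(rewards, values, next_value, masks, gamma=0.99, tau=0.95):
    """rewards, values, masks: (T, N) tensors; next_value: (N,) bootstrap value."""
    T, N = rewards.shape
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)  # (T + 1, N)
    returns = torch.zeros(T, N)
    gae = torch.zeros(N)
    for t in reversed(range(T)):
        # masks[t] is 0 where the episode ended at step t, which cuts the bootstrap
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        gae = delta + gamma * tau * masks[t] * gae
        returns[t] = gae + values[t]
    return returns

# dummy usage: 5 steps, 2 parallel processes, no episode ends
returns = compute_gae_returns(torch.randn(5, 2), torch.randn(5, 2),
                              torch.randn(2), torch.ones(5, 2))
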
Example #8
0
    def __init__(self, args):

        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # calculate number of updates and keep count of frames/iterations
        self.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes
        self.frames = 0
        self.iter_idx = -1

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device=device,
            episodes_per_task=self.args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
            tasks=None)

        if self.args.single_task_mode:
            # get the current tasks (which will be num_process many different tasks)
            self.train_tasks = self.envs.get_task()
            # set the tasks to the first task (i.e. just a random task)
            self.train_tasks[1:] = self.train_tasks[0]
            # make it a list
            self.train_tasks = [t for t in self.train_tasks]
            # re-initialise environments with those tasks
            self.envs = make_vec_envs(
                env_name=args.env_name,
                seed=args.seed,
                num_processes=args.num_processes,
                gamma=args.policy_gamma,
                device=device,
                episodes_per_task=self.args.max_rollouts_per_task,
                normalise_rew=args.norm_rew_for_policy,
                ret_rms=None,
                tasks=self.train_tasks,
            )
            # save the training tasks so we can evaluate on the same envs later
            utl.save_obj(self.train_tasks, self.logger.full_output_folder,
                         "train_tasks")
        else:
            self.train_tasks = None

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # get policy input dimensions
        self.args.state_dim = self.envs.observation_space.shape[0]
        self.args.task_dim = self.envs.task_dim
        self.args.belief_dim = self.envs.belief_dim
        self.args.num_states = self.envs.num_states
        # get policy output (action) dimensions
        self.args.action_space = self.envs.action_space
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]

        # initialise policy
        self.policy_storage = self.initialise_policy_storage()
        self.policy = self.initialise_policy()
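
The single_task_mode branch above samples one task per worker and then overwrites every worker's task with the first one, so all processes train on the same randomly drawn task. A tiny illustration of that step; sample_task is a hypothetical stand-in for however the environment parametrises its tasks:

import numpy as np

def pin_to_single_task(num_processes, sample_task):
    tasks = [sample_task() for _ in range(num_processes)]
    # same effect as `self.train_tasks[1:] = self.train_tasks[0]` above
    tasks[1:] = [tasks[0]] * (num_processes - 1)
    return tasks

tasks = pin_to_single_task(4, lambda: np.random.uniform(-1.0, 1.0, size=2))
assert all(np.array_equal(t, tasks[0]) for t in tasks)
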
Example #9
0
    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialise environment
        self.env = make_env(
            self.args.env_name,
            self.args.max_rollouts_per_task,
            seed=self.args.seed,
            n_tasks=1,
            modify_init_state_dist=self.args.modify_init_state_dist
            if 'modify_init_state_dist' in self.args else False,
            on_circle_init_state=self.args.on_circle_init_state
            if 'on_circle_init_state' in self.args else True)

        # saving buffer with task in name folder
        if hasattr(self.args, 'save_buffer') and self.args.save_buffer:
            env_dir = os.path.join(self.args.main_save_dir,
                                   '{}'.format(self.args.env_name))
            goal = self.env.unwrapped._goal
            self.output_dir = os.path.join(
                env_dir, self.args.save_dir,
                'seed_{}_'.format(self.args.seed) +
                off_utl.create_goal_path_ext_from_goal(goal))

        if self.args.save_models or self.args.save_buffer:
            os.makedirs(self.output_dir, exist_ok=True)
            config_utl.save_config_file(args, self.output_dir)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        # unwrapped env to get some info about the environment
        unwrapped_env = self.env.unwrapped

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = unwrapped_env._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task
        self.args.max_trajectory_len = args.max_trajectory_len

        # get action / observation dimensions
        if isinstance(self.env.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.env.action_space.shape[0]
        self.args.obs_dim = self.env.observation_space.shape[0]
        self.args.num_states = unwrapped_env.num_states if hasattr(
            unwrapped_env, 'num_states') else None
        self.args.act_space = self.env.action_space

        # simulate env step to get reward types
        _, _, _, info = unwrapped_env.step(unwrapped_env.action_space.sample())
        reward_types = [
            reward_type for reward_type in list(info.keys())
            if reward_type.startswith('reward')
        ]

        # support dense rewards training (if exists)
        self.args.dense_train_sparse_test = self.args.dense_train_sparse_test \
            if 'dense_train_sparse_test' in self.args else False

        # initialize policy
        self.initialize_policy()
        # initialize buffer for RL updates
        self.policy_storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=int(self.args.policy_buffer_size),
            obs_dim=self.args.obs_dim,
            action_space=self.env.action_space,
            tasks=[0],
            trajectory_len=args.max_trajectory_len,
            num_reward_arrays=len(reward_types)
            if reward_types and self.args.dense_train_sparse_test else 1,
            reward_types=reward_types,
        )

        self.args.belief_reward = False  # initialize arg to not use belief rewards
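
The constructor above discovers which reward signals the environment exposes by taking one random step and keeping every info key that starts with 'reward'; those names then size the reward arrays of the replay buffer. A toy illustration with a fabricated info dict:

info = {'reward_dense': -0.31, 'reward_sparse': 0.0, 'is_success': False}
reward_types = [k for k in info.keys() if k.startswith('reward')]
assert reward_types == ['reward_dense', 'reward_sparse']
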
Example #10
0
class Learner:
    """
    Learner (no meta-learning), can be used to train Oracle policies.
    """
    def __init__(self, args):
        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # calculate number of meta updates
        self.args.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes

        # get action / observation dimensions
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]
        self.args.obs_dim = self.envs.observation_space.shape[0]
        self.args.num_states = (self.envs.num_states
                                if self.args.env_name.startswith('Grid') else None)
        self.args.act_space = self.envs.action_space

        self.initialise_policy()

        # count number of frames and updates
        self.frames = 0
        self.iter_idx = 0

    def initialise_policy(self):

        # variables for task encoder (used for oracle)
        state_dim = self.envs.observation_space.shape[0]

        # TODO: this isn't ideal, find a nicer way to get the task dimension!
        if 'BeliefOracle' in self.args.env_name:
            task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \
                       gym.make(self.args.env_name.replace('BeliefOracle', '')).observation_space.shape[0]
            latent_dim = self.args.latent_dim
            state_embedding_size = self.args.state_embedding_size
            use_task_encoder = True
        elif 'Oracle' in self.args.env_name:
            task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \
                       gym.make(self.args.env_name.replace('Oracle', '')).observation_space.shape[0]
            latent_dim = self.args.latent_dim
            state_embedding_size = self.args.state_embedding_size
            use_task_encoder = True
        else:
            task_dim = latent_dim = state_embedding_size = 0
            use_task_encoder = False

        # initialise rollout storage for the policy
        self.policy_storage = OnlineStorage(
            self.args,
            self.args.policy_num_steps,
            self.args.num_processes,
            self.args.obs_dim,
            self.args.act_space,
            hidden_size=0,
            latent_dim=self.args.latent_dim,
            normalise_observations=self.args.norm_obs_for_policy,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        # initialise policy network
        policy_net = Policy(
            # general
            state_dim=int(self.args.condition_policy_on_state) * state_dim,
            action_space=self.envs.action_space,
            init_std=self.args.policy_init_std,
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            use_task_encoder=use_task_encoder,
            # task encoding things (for oracle)
            task_dim=task_dim,
            latent_dim=latent_dim,
            state_embed_dim=state_embedding_size,
            #
            normalise_actions=self.args.normalise_actions,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy
        if self.args.policy == 'a2c':
            # initialise policy trainer (A2C)
            self.policy = A2C(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                alpha=self.args.a2c_alpha,
            )
        elif self.args.policy == 'ppo':
            # initialise policy network
            self.policy = PPO(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError

    def train(self):
        """
        Given some stream of environments and a logger (tensorboard),
        (meta-)trains the policy.
        """

        start_time = time.time()

        # reset environments
        (prev_obs_raw, prev_obs_normalised) = self.envs.reset()
        prev_obs_raw = prev_obs_raw.to(device)
        prev_obs_normalised = prev_obs_normalised.to(device)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw)
        self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised)
        self.policy_storage.to(device)

        for self.iter_idx in range(self.args.num_updates):

            # check if we flushed the policy storage
            assert len(self.policy_storage.latent_mean) == 0

            # rollouts policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        policy=self.policy,
                        args=self.args,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        deterministic=False)

                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (
                    rew_raw, rew_normalised), done, infos = utl.env_step(
                        self.envs, action)
                action = action.float()

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                # add the obs before reset to the policy storage
                self.policy_storage.next_obs_raw[step] = next_obs_raw.clone()
                self.policy_storage.next_obs_normalised[
                    step] = next_obs_normalised.clone()

                # reset environments that are done
                done_indices = np.argwhere(done.flatten()).flatten()
                if len(done_indices) == self.args.num_processes:
                    [next_obs_raw, next_obs_normalised] = self.envs.reset()
                else:
                    for i in done_indices:
                        [next_obs_raw[i],
                         next_obs_normalised[i]] = self.envs.reset(index=i)

                # add experience to policy buffer
                self.policy_storage.insert(
                    obs_raw=next_obs_raw.clone(),
                    obs_normalised=next_obs_normalised.clone(),
                    actions=action.clone(),
                    action_log_probs=action_log_prob.clone(),
                    rewards_raw=rew_raw.clone(),
                    rewards_normalised=rew_normalised.clone(),
                    value_preds=value.clone(),
                    masks=masks_done.clone(),
                    bad_masks=bad_masks.clone(),
                    done=torch.from_numpy(np.array(
                        done, dtype=float)).unsqueeze(1).clone(),
                )

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                self.frames += self.args.num_processes

            # --- UPDATE ---

            train_stats = self.update(
                prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw)

            # log
            run_stats = [action, action_log_prob, value]
            if train_stats is not None:
                self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()

    def get_value(self, obs):
        obs = utl.get_augmented_obs(args=self.args, obs=obs)
        return self.policy.actor_critic.get_value(obs).detach()

    def update(self, obs):
        """
        Meta-update.
        Here the policy is updated for good average performance across tasks.
        :return:    policy_train_stats which are: value_loss_epoch, action_loss_epoch, dist_entropy_epoch, loss_epoch
        """
        # bootstrap next value prediction
        with torch.no_grad():
            next_value = self.get_value(obs)

        # compute returns for current rollouts
        self.policy_storage.compute_returns(
            next_value,
            self.args.policy_use_gae,
            self.args.policy_gamma,
            self.args.policy_tau,
            use_proper_time_limits=self.args.use_proper_time_limits)

        policy_train_stats = self.policy.update(
            args=self.args, policy_storage=self.policy_storage)

        return policy_train_stats, None

    def log(self, run_stats, train_stats, start):
        """
        Evaluate policy, save model, write to tensorboard logger.
        """
        train_stats, meta_train_stats = train_stats

        # --- visualise behaviour of policy ---

        if self.iter_idx % self.args.vis_interval == 0:
            obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            utl_eval.visualise_behaviour(
                args=self.args,
                policy=self.policy,
                image_folder=self.logger.full_output_folder,
                iter_idx=self.iter_idx,
                obs_rms=obs_rms,
                ret_rms=ret_rms,
            )

        # --- evaluate policy ----

        if self.iter_idx % self.args.eval_interval == 0:

            obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            returns_per_episode = utl_eval.evaluate(args=self.args,
                                                    policy=self.policy,
                                                    obs_rms=obs_rms,
                                                    ret_rms=ret_rms,
                                                    iter_idx=self.iter_idx)

            # log the average return across tasks (=processes)
            returns_avg = returns_per_episode.mean(dim=0)
            returns_std = returns_per_episode.std(dim=0)
            for k in range(len(returns_avg)):
                self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1),
                                returns_avg[k], self.iter_idx)
                self.logger.add(
                    'return_avg_per_frame/episode_{}'.format(k + 1),
                    returns_avg[k], self.frames)
                self.logger.add('return_std_per_iter/episode_{}'.format(k + 1),
                                returns_std[k], self.iter_idx)
                self.logger.add(
                    'return_std_per_frame/episode_{}'.format(k + 1),
                    returns_std[k], self.frames)

            print(
                "Updates {}, num timesteps {}, FPS {} \n Mean return (train): {:.5f} \n"
                .format(self.iter_idx, self.frames,
                        int(self.frames / (time.time() - start)),
                        returns_avg[-1].item()))

        # save model
        if self.iter_idx % self.args.save_interval == 0:
            save_path = os.path.join(self.logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(
                self.policy.actor_critic,
                os.path.join(save_path, "policy{0}.pt".format(self.iter_idx)))

            # save normalisation params of envs
            if self.args.norm_rew_for_policy:
                # save rolling mean and std
                rew_rms = self.envs.venv.ret_rms
                utl.save_obj(rew_rms, save_path,
                             "env_rew_rms{0}.pkl".format(self.iter_idx))
            if self.args.norm_obs_for_policy:
                obs_rms = self.envs.venv.obs_rms
                utl.save_obj(obs_rms, save_path,
                             "env_obs_rms{0}.pkl".format(self.iter_idx))

        # --- log some other things ---

        if self.iter_idx % self.args.log_interval == 0:
            self.logger.add('policy_losses/value_loss', train_stats[0],
                            self.iter_idx)
            self.logger.add('policy_losses/action_loss', train_stats[1],
                            self.iter_idx)
            self.logger.add('policy_losses/dist_entropy', train_stats[2],
                            self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx)

            self.logger.add('policy/action', run_stats[0][0].float().mean(),
                            self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd',
                                self.policy.actor_critic.dist.logstd.mean(),
                                self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(),
                            self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx)

            param_list = list(self.policy.actor_critic.parameters())
            # move parameters to CPU / numpy so this also works when training on GPU
            param_mean = np.mean([
                param.data.cpu().numpy().mean() for param in param_list
            ])
            param_grad_mean = np.mean([
                param.grad.cpu().numpy().mean() for param in param_list
                if param.grad is not None
            ])
            self.logger.add('weights/policy', param_mean, self.iter_idx)
            self.logger.add('weights/policy_std', param_list[0].data.mean(),
                            self.iter_idx)
            self.logger.add('gradients/policy', param_grad_mean, self.iter_idx)
            self.logger.add('gradients/policy_std', param_list[0].grad.mean(),
                            self.iter_idx)
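
The ret_rms / obs_rms objects pickled in log() above hold the running statistics that VecNormalize-style wrappers use to normalise rewards and observations. A minimal sketch of such an object, assuming the usual parallel mean/variance update; this is not the repository's exact class:

import numpy as np

class RunningMeanStd:
    """Running mean and variance over batches of values."""
    def __init__(self, shape=()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = 1e-4

    def update(self, batch):
        batch = np.asarray(batch, dtype=np.float64)
        b_mean, b_var, b_count = batch.mean(axis=0), batch.var(axis=0), batch.shape[0]
        delta = b_mean - self.mean
        total = self.count + b_count
        new_mean = self.mean + delta * b_count / total
        m2 = (self.var * self.count + b_var * b_count
              + delta ** 2 * self.count * b_count / total)
        self.mean, self.var, self.count = new_mean, m2 / total, total

rms = RunningMeanStd()
rms.update(np.random.randn(256))
# reward normalisation typically divides by the running std only (no mean shift)
normalised_reward = 1.7 / np.sqrt(rms.var + 1e-8)
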
Example #11
0
class Learner:
    """
    Learner class.
    """

    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialise environment
        self.env = make_env(self.args.env_name,
                            self.args.max_rollouts_per_task,
                            seed=self.args.seed,
                            n_tasks=1,
                            modify_init_state_dist=self.args.modify_init_state_dist
                            if 'modify_init_state_dist' in self.args else False,
                            on_circle_init_state=self.args.on_circle_init_state
                            if 'on_circle_init_state' in self.args else True)

        # saving buffer with task in name folder
        if hasattr(self.args, 'save_buffer') and self.args.save_buffer:
            env_dir = os.path.join(self.args.main_save_dir,
                                   '{}'.format(self.args.env_name))
            goal = self.env.unwrapped._goal
            self.output_dir = os.path.join(env_dir, self.args.save_dir, 'seed_{}_'.format(self.args.seed) +
                                           off_utl.create_goal_path_ext_from_goal(goal))

        if self.args.save_models or self.args.save_buffer:
            os.makedirs(self.output_dir, exist_ok=True)
            config_utl.save_config_file(args, self.output_dir)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        # unwrapped env to get some info about the environment
        unwrapped_env = self.env.unwrapped

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = unwrapped_env._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task
        self.args.max_trajectory_len = args.max_trajectory_len

        # get action / observation dimensions
        if isinstance(self.env.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.env.action_space.shape[0]
        self.args.obs_dim = self.env.observation_space.shape[0]
        self.args.num_states = unwrapped_env.num_states if hasattr(unwrapped_env, 'num_states') else None
        self.args.act_space = self.env.action_space

        # simulate env step to get reward types
        _, _, _, info = unwrapped_env.step(unwrapped_env.action_space.sample())
        reward_types = [reward_type for reward_type in list(info.keys()) if reward_type.startswith('reward')]

        # support dense rewards training (if exists)
        self.args.dense_train_sparse_test = self.args.dense_train_sparse_test \
            if 'dense_train_sparse_test' in self.args else False

        # initialize policy
        self.initialize_policy()
        # initialize buffer for RL updates
        self.policy_storage = MultiTaskPolicyStorage(
            max_replay_buffer_size=int(self.args.policy_buffer_size),
            obs_dim=self.args.obs_dim,
            action_space=self.env.action_space,
            tasks=[0],
            trajectory_len=args.max_trajectory_len,
            num_reward_arrays=len(reward_types) if reward_types and self.args.dense_train_sparse_test else 1,
            reward_types=reward_types,
        )

        self.args.belief_reward = False     # initialize arg to not use belief rewards

    def initialize_policy(self):

        if self.args.policy == 'dqn':
            assert self.args.act_space.__class__.__name__ == "Discrete", (
                "Can't train DQN with continuous action space!")
            q_network = FlattenMlp(input_size=self.args.obs_dim,
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                gamma=self.args.gamma,
                eps_init=self.args.dqn_epsilon_init,
                eps_final=self.args.dqn_epsilon_final,
                exploration_iters=self.args.dqn_exploration_iters,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        # elif self.args.policy == 'ddqn':
        #     assert self.args.act_space.__class__.__name__ == "Discrete", (
        #         "Can't train DDQN with continuous action space!")
        #     q_network = FlattenMlp(input_size=self.args.obs_dim,
        #                            output_size=self.args.act_space.n,
        #                            hidden_sizes=self.args.dqn_layers)
        #     self.agent = DoubleDQN(
        #         q_network,
        #         # optimiser_vae=self.optimizer_vae,
        #         lr=self.args.policy_lr,
        #         eps_optim=self.args.dqn_eps,
        #         alpha_optim=self.args.dqn_alpha,
        #         gamma=self.args.gamma,
        #         eps_init=self.args.dqn_epsilon_init,
        #         eps_final=self.args.dqn_epsilon_final,
        #         exploration_iters=self.args.dqn_exploration_iters,
        #         tau=self.args.soft_target_tau,
        #     ).to(ptu.device)
        elif self.args.policy == 'sac':
            assert self.args.act_space.__class__.__name__ == "Box", (
                "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(input_size=self.args.obs_dim + self.args.action_dim,
                                    output_size=1,
                                    hidden_sizes=self.args.dqn_layers)
            q2_network = FlattenMlp(input_size=self.args.obs_dim + self.args.action_dim,
                                    output_size=1,
                                    hidden_sizes=self.args.dqn_layers)
            policy = TanhGaussianPolicy(obs_dim=self.args.obs_dim,
                                        action_dim=self.args.action_dim,
                                        hidden_sizes=self.args.policy_layers)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,

                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,

                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr
            ).to(ptu.device)
        else:
            raise NotImplementedError

    def train(self):
        """
        meta-training loop
        """

        self._start_training()
        self.task_idx = 0
        for iter_ in range(self.args.num_iters):
            self.training_mode(True)
            if iter_ == 0:
                print('Collecting initial pool of data..')
                self.env.reset_task(idx=self.task_idx)
                self.collect_rollouts(num_rollouts=self.args.num_init_rollouts_pool, random_actions=True)
                print('Done!')
            # collect data from subset of train tasks

            self.env.reset_task(idx=self.task_idx)
            self.collect_rollouts(num_rollouts=self.args.num_rollouts_per_iter)
            # update
            train_stats = self.update([self.task_idx])
            self.training_mode(False)

            if self.args.policy == 'dqn':
                self.agent.set_exploration_parameter(iter_ + 1)
            # evaluate and log
            if (iter_ + 1) % self.args.log_interval == 0:
                self.log(iter_ + 1, train_stats)

    def update(self, tasks):
        '''
        RL updates
        :param tasks: list/array of task indices. perform update based on the tasks
        :return:
        '''

        # --- RL TRAINING ---
        rl_losses_agg = {}
        for update in range(self.args.rl_updates_per_iter):
            # sample random RL batch
            obs, actions, rewards, next_obs, terms = self.sample_rl_batch(tasks, self.args.batch_size)
            # flatten out task dimension
            t, b, _ = obs.size()
            obs = obs.view(t * b, -1)
            actions = actions.view(t * b, -1)
            rewards = rewards.view(t * b, -1)
            next_obs = next_obs.view(t * b, -1)
            terms = terms.view(t * b, -1)

            # RL update
            rl_losses = self.agent.update(obs, actions, rewards, next_obs, terms)

            for k, v in rl_losses.items():
                if update == 0:     # first iterate - create list
                    rl_losses_agg[k] = [v]
                else:   # append values
                    rl_losses_agg[k].append(v)
        # take mean
        for k in rl_losses_agg:
            rl_losses_agg[k] = np.mean(rl_losses_agg[k])
        self._n_rl_update_steps_total += self.args.rl_updates_per_iter

        return rl_losses_agg
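
# update() above flattens the (task, batch, dim) sample into (task * batch, dim)
# before handing it to the agent, and averages every loss over the inner update
# loop. A stand-alone illustration of both steps with dummy data:
import numpy as np
import torch

obs = torch.randn(3, 16, 8)        # (num_tasks, batch_size, obs_dim)
t, b, _ = obs.size()
flat_obs = obs.view(t * b, -1)     # what the agent update actually consumes
assert flat_obs.shape == (48, 8)

rl_losses_per_update = [{'qf_loss': 0.5}, {'qf_loss': 0.3}]
rl_losses_agg = {}
for rl_losses in rl_losses_per_update:
    for k, v in rl_losses.items():
        rl_losses_agg.setdefault(k, []).append(v)
rl_losses_agg = {k: np.mean(v) for k, v in rl_losses_agg.items()}
assert abs(rl_losses_agg['qf_loss'] - 0.4) < 1e-8
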

    def evaluate(self, tasks):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps

        returns_per_episode = np.zeros((len(tasks), num_episodes))
        success_rate = np.zeros(len(tasks))

        if self.args.policy == 'dqn':
            values = np.zeros((len(tasks), self.args.max_trajectory_len))
        else:
            obs_size = self.env.unwrapped.observation_space.shape[0]
            observations = np.zeros((len(tasks), self.args.max_trajectory_len + 1, obs_size))
            log_probs = np.zeros((len(tasks), self.args.max_trajectory_len))

        for task_idx, task in enumerate(tasks):

            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            if self.args.policy == 'sac':
                observations[task_idx, step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=obs, deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(obs=obs,
                                                                deterministic=self.args.eval_deterministic,
                                                                return_log_prob=True)
                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(self.env, action.squeeze(dim=0))
                    running_reward += reward.item()
                    if self.args.policy == 'dqn':
                        values[task_idx, step] = value.item()
                    else:
                        observations[task_idx, step + 1, :] = ptu.get_numpy(next_obs[0, :obs_size])
                        log_probs[task_idx, step] = ptu.get_numpy(log_prob[0])

                    if "is_goal_state" in dir(self.env.unwrapped) and self.env.unwrapped.is_goal_state():
                        success_rate[task_idx] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task_idx, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, values
        else:
            return returns_per_episode, success_rate, log_probs, observations

    def log(self, iteration, train_stats):
        # --- save models ---
        if iteration % self.args.save_interval == 0:
            if self.args.save_models:
                if self.args.log_tensorboard:
                    save_path = os.path.join(self.tb_logger.full_output_folder, 'models')
                else:
                    save_path = os.path.join(self.output_dir, 'models')
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
                torch.save(self.agent.state_dict(), os.path.join(save_path, "agent{0}.pt".format(iteration)))
            if hasattr(self.args, 'save_buffer') and self.args.save_buffer:
                self.save_buffer()
        # evaluate to get more stats
        if self.args.policy == 'dqn':
            # get stats on train tasks
            returns_train, success_rate_train, values = self.evaluate([0])
        else:
            # get stats on train tasks
            returns_train, success_rate_train, log_probs, observations = self.evaluate([0])

        if self.args.log_tensorboard:
            if self.args.policy != 'dqn':
            #     self.env.reset(0)
            #     self.tb_logger.writer.add_figure('policy_vis_train/task_0',
            #                                      utl_eval.plot_rollouts(observations[0, :], self.env),
            #                                      self._n_env_steps_total)
            #     obs, _, _, _, _ = self.sample_rl_batch(tasks=[0],
            #                                            batch_size=self.policy_storage.task_buffers[0].size())
            #     self.tb_logger.writer.add_figure('state_space_coverage/task_0',
            #                                      utl_eval.plot_visited_states(ptu.get_numpy(obs[0][:, :2]), self.env),
            #                                      self._n_env_steps_total)
                pass
            # some metrics
            self.tb_logger.writer.add_scalar('metrics/successes_in_buffer',
                                             self._successes_in_buffer / self._n_env_steps_total,
                                             self._n_env_steps_total)

            if self.args.max_rollouts_per_task > 1:
                for episode_idx in range(self.args.max_rollouts_per_task):
                    self.tb_logger.writer.add_scalar('returns_multi_episode/episode_{}'.
                                                     format(episode_idx + 1),
                                                     np.mean(returns_train[:, episode_idx]),
                                                     self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns_multi_episode/sum',
                                                 np.mean(np.sum(returns_train, axis=-1)),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns_multi_episode/success_rate',
                                                 np.mean(success_rate_train),
                                                 self._n_env_steps_total)
            else:
                self.tb_logger.writer.add_scalar('returns/returns_mean_train', np.mean(returns_train),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns/returns_std_train', np.std(returns_train),
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('returns/success_rate_train', np.mean(success_rate_train),
                                                 self._n_env_steps_total)
            # policy
            if self.args.policy == 'dqn':
                self.tb_logger.writer.add_scalar('policy/value_init', np.mean(values[:, 0]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('policy/value_halfway', np.mean(values[:, int(values.shape[-1]/2)]), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('policy/value_final', np.mean(values[:, -1]), self._n_env_steps_total)

                self.tb_logger.writer.add_scalar('policy/exploration_epsilon', self.agent.eps, self._n_env_steps_total)
                # RL losses
                self.tb_logger.writer.add_scalar('rl_losses/qf_loss_vs_n_updates', train_stats['qf_loss'],
                                                 self._n_rl_update_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/qf_loss_vs_n_env_steps', train_stats['qf_loss'],
                                                 self._n_env_steps_total)
            else:
                self.tb_logger.writer.add_scalar('policy/log_prob', np.mean(log_probs), self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/qf1_loss', train_stats['qf1_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/qf2_loss', train_stats['qf2_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/policy_loss', train_stats['policy_loss'],
                                                 self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('rl_losses/alpha_entropy_loss', train_stats['alpha_entropy_loss'],
                                                 self._n_env_steps_total)

            # weights and gradients
            if self.args.policy == 'dqn':
                self.tb_logger.writer.add_scalar('weights/q_network',
                                                 list(self.agent.qf.parameters())[0].mean(),
                                                 self._n_env_steps_total)
                if list(self.agent.qf.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf.parameters())
                    self.tb_logger.writer.add_scalar('gradients/q_network',
                                                     sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                                                     self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('weights/q_target',
                                                 list(self.agent.target_qf.parameters())[0].mean(),
                                                 self._n_env_steps_total)
                if list(self.agent.target_qf.parameters())[0].grad is not None:
                    param_list = list(self.agent.target_qf.parameters())
                    self.tb_logger.writer.add_scalar('gradients/q_target',
                                                     sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                                                     self._n_env_steps_total)
            else:
                self.tb_logger.writer.add_scalar('weights/q1_network',
                                                 list(self.agent.qf1.parameters())[0].mean(),
                                                 self._n_env_steps_total)
                if list(self.agent.qf1.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf1.parameters())
                    self.tb_logger.writer.add_scalar('gradients/q1_network',
                                                     sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                                                     self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('weights/q1_target',
                                                 list(self.agent.qf1_target.parameters())[0].mean(),
                                                 self._n_env_steps_total)
                if list(self.agent.qf1_target.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf1_target.parameters())
                    self.tb_logger.writer.add_scalar('gradients/q1_target',
                                                     sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                                                     self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('weights/q2_network',
                                                 list(self.agent.qf2.parameters())[0].mean(),
                                                 self._n_env_steps_total)
                if list(self.agent.qf2.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf2.parameters())
                    self.tb_logger.writer.add_scalar('gradients/q2_network',
                                                     sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                                                     self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('weights/q2_target',
                                                 list(self.agent.qf2_target.parameters())[0].mean(),
                                                 self._n_env_steps_total)
                if list(self.agent.qf2_target.parameters())[0].grad is not None:
                    param_list = list(self.agent.qf2_target.parameters())
                    self.tb_logger.writer.add_scalar('gradients/q2_target',
                                                     sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                                                     self._n_env_steps_total)
                self.tb_logger.writer.add_scalar('weights/policy',
                                                 list(self.agent.policy.parameters())[0].mean(),
                                                 self._n_env_steps_total)
                if list(self.agent.policy.parameters())[0].grad is not None:
                    param_list = list(self.agent.policy.parameters())
                    self.tb_logger.writer.add_scalar('gradients/policy',
                                                     sum([param_list[i].grad.mean() for i in range(len(param_list))]),
                                                     self._n_env_steps_total)

        if self.args.log_tensorboard:
            self.tb_logger.finish_iteration(iteration)

        # output to user
        print("Iteration -- {}, Success rate -- {:.3f}, Avg. return -- {:.3f}, Elapsed time {:5d}[s]"
              .format(iteration, np.mean(success_rate_train), np.mean(np.sum(returns_train, axis=-1)),
                      int(time.time() - self._start_time)))

    def training_mode(self, mode):
        self.agent.train(mode)

    def collect_rollouts(self, num_rollouts, random_actions=False):
        '''
        Collect rollouts for the current task and add the transitions to the policy buffer.

        :param num_rollouts: number of rollouts (episodes) to collect
        :param random_actions: if True, sample actions uniformly from the action space
                               instead of querying the policy
        :return:
        '''

        for rollout in range(num_rollouts):
            obs = ptu.from_numpy(self.env.reset(self.task_idx))
            obs = obs.reshape(-1, obs.shape[-1])
            done_rollout = False

            while not done_rollout:
                if random_actions:
                    if self.args.policy == 'dqn':
                        action = ptu.FloatTensor([[[self.env.action_space.sample()]]]).long()   # Sample random action
                    else:
                        action = ptu.FloatTensor([self.env.action_space.sample()])  # Sample random action
                else:
                    if self.args.policy == 'dqn':
                        action, _ = self.agent.act(obs=obs)   # DQN
                    else:
                        action, _, _, _ = self.agent.act(obs=obs)   # SAC
                # observe reward and next obs
                next_obs, reward, done, info = utl.env_step(self.env, action.squeeze(dim=0))
                done_rollout = bool(ptu.get_numpy(done[0][0]) != 0.)

                # add data to policy buffer - (s+, a, r, s'+, term)
                term = self.env.unwrapped.is_goal_state() if "is_goal_state" in dir(self.env.unwrapped) else False
                if self.args.dense_train_sparse_test:
                    rew_to_buffer = {rew_type: rew for rew_type, rew in info.items()
                                     if rew_type.startswith('reward')}
                else:
                    rew_to_buffer = ptu.get_numpy(reward.squeeze(dim=0))
                self.policy_storage.add_sample(task=self.task_idx,
                                               observation=ptu.get_numpy(obs.squeeze(dim=0)),
                                               action=ptu.get_numpy(action.squeeze(dim=0)),
                                               reward=rew_to_buffer,
                                               terminal=np.array([term], dtype=float),
                                               next_observation=ptu.get_numpy(next_obs.squeeze(dim=0)))

                # set: obs <- next_obs
                obs = next_obs.clone()

                # update statistics
                self._n_env_steps_total += 1
                if "is_goal_state" in dir(self.env.unwrapped) and self.env.unwrapped.is_goal_state():  # count successes
                    self._successes_in_buffer += 1
            self._n_rollouts_total += 1

    def sample_rl_batch(self, tasks, batch_size):
        ''' sample batch of unordered rl training data from a list/array of tasks '''
        # this batch consists of transitions sampled randomly from replay buffer
        batches = [ptu.np_to_pytorch_batch(
            self.policy_storage.random_batch(task, batch_size)) for task in tasks]
        unpacked = [utl.unpack_batch(batch) for batch in batches]
        # group elements together
        unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))]
        unpacked = [torch.cat(x, dim=0) for x in unpacked]
        return unpacked
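
    # Hypothetical usage sketch for sample_rl_batch (the call mirrors the disabled
    # visualisation code in log(); the element order comes from utl.unpack_batch, so
    # the names after `obs` are illustrative assumptions based on the
    # (s+, a, r, s'+, term) tuple stored in collect_rollouts):
    #
    #     obs, actions, rewards, next_obs, terms = self.sample_rl_batch(
    #         tasks=[0], batch_size=self.policy_storage.task_buffers[0].size())
    #
    # Each returned tensor stacks the per-task batches along dim=0, i.e. its first
    # dimension is len(tasks) * batch_size.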

    def _start_training(self):
        self._n_env_steps_total = 0
        self._n_rl_update_steps_total = 0
        self._n_vae_update_steps_total = 0
        self._n_rollouts_total = 0
        self._successes_in_buffer = 0

        self._start_time = time.time()

    def load_model(self, device='cpu', **kwargs):
        if "agent_path" in kwargs:
            self.agent.load_state_dict(torch.load(kwargs["agent_path"], map_location=device))
        self.training_mode(False)

    def save_buffer(self):
        size = self.policy_storage.task_buffers[0].size()
        np.save(os.path.join(self.output_dir, 'obs'), self.policy_storage.task_buffers[0]._observations[:size])
        np.save(os.path.join(self.output_dir, 'actions'), self.policy_storage.task_buffers[0]._actions[:size])
        if self.args.dense_train_sparse_test:
            for reward_type, reward_arr in self.policy_storage.task_buffers[0]._rewards.items():
                np.save(os.path.join(self.output_dir, reward_type), reward_arr[:size])
        else:
            np.save(os.path.join(self.output_dir, 'rewards'), self.policy_storage.task_buffers[0]._rewards[:size])
        np.save(os.path.join(self.output_dir, 'next_obs'), self.policy_storage.task_buffers[0]._next_obs[:size])
        np.save(os.path.join(self.output_dir, 'terminals'), self.policy_storage.task_buffers[0]._terminals[:size])
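
# A minimal, standalone sketch (an addition, not part of the original class) showing
# how the arrays written by save_buffer() could be read back. File names mirror the
# np.save calls above; `output_dir` and `dense_reward_keys` are placeholders.
import os
import numpy as np

def load_saved_buffer(output_dir, dense_reward_keys=None):
    """Reload the transition arrays dumped by save_buffer()."""
    data = {
        'obs': np.load(os.path.join(output_dir, 'obs.npy')),
        'actions': np.load(os.path.join(output_dir, 'actions.npy')),
        'next_obs': np.load(os.path.join(output_dir, 'next_obs.npy')),
        'terminals': np.load(os.path.join(output_dir, 'terminals.npy')),
    }
    if dense_reward_keys:
        # dense_train_sparse_test case: one file per reward type
        data['rewards'] = {k: np.load(os.path.join(output_dir, k + '.npy'))
                           for k in dense_reward_keys}
    else:
        data['rewards'] = np.load(os.path.join(output_dir, 'rewards.npy'))
    return data
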
Example #12
0
class MetaLearner:
    """
    Meta-Learner class with the main training loop for variBAD.
    """
    def __init__(self, args):

        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # calculate number of updates and keep count of frames/iterations
        self.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes
        self.frames = 0
        self.iter_idx = -1

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device=device,
            episodes_per_task=self.args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
            tasks=None)

        if self.args.single_task_mode:
            # get the current tasks (which will be num_process many different tasks)
            self.train_tasks = self.envs.get_task()
            # set the tasks to the first task (i.e. just a random task)
            self.train_tasks[1:] = self.train_tasks[0]
            # make it a list
            self.train_tasks = [t for t in self.train_tasks]
            # re-initialise environments with those tasks
            self.envs = make_vec_envs(
                env_name=args.env_name,
                seed=args.seed,
                num_processes=args.num_processes,
                gamma=args.policy_gamma,
                device=device,
                episodes_per_task=self.args.max_rollouts_per_task,
                normalise_rew=args.norm_rew_for_policy,
                ret_rms=None,
                tasks=self.train_tasks)
            # save the training tasks so we can evaluate on the same envs later
            utl.save_obj(self.train_tasks, self.logger.full_output_folder,
                         "train_tasks")
        else:
            self.train_tasks = None

        # calculate what the maximum length of the trajectories is
        self.args.max_trajectory_len = self.envs._max_episode_steps
        self.args.max_trajectory_len *= self.args.max_rollouts_per_task

        # get policy input dimensions
        self.args.state_dim = self.envs.observation_space.shape[0]
        self.args.task_dim = self.envs.task_dim
        self.args.belief_dim = self.envs.belief_dim
        self.args.num_states = self.envs.num_states
        # get policy output (action) dimensions
        self.args.action_space = self.envs.action_space
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]

        # initialise VAE and policy
        self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)
        self.policy_storage = self.initialise_policy_storage()
        self.policy = self.initialise_policy()

    def initialise_policy_storage(self):
        return OnlineStorage(
            args=self.args,
            num_steps=self.args.policy_num_steps,
            num_processes=self.args.num_processes,
            state_dim=self.args.state_dim,
            latent_dim=self.args.latent_dim,
            belief_dim=self.args.belief_dim,
            task_dim=self.args.task_dim,
            action_space=self.args.action_space,
            hidden_size=self.args.encoder_gru_hidden_size,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

    def initialise_policy(self):

        # initialise policy network
        policy_net = Policy(
            args=self.args,
            #
            pass_state_to_policy=self.args.pass_state_to_policy,
            pass_latent_to_policy=self.args.pass_latent_to_policy,
            pass_belief_to_policy=self.args.pass_belief_to_policy,
            pass_task_to_policy=self.args.pass_task_to_policy,
            dim_state=self.args.state_dim,
            dim_latent=self.args.latent_dim * 2,
            dim_belief=self.args.belief_dim,
            dim_task=self.args.task_dim,
            #
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            policy_initialisation=self.args.policy_initialisation,
            #
            action_space=self.envs.action_space,
            init_std=self.args.policy_init_std,
        ).to(device)

        # initialise policy trainer
        if self.args.policy == 'a2c':
            policy = A2C(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                optimiser_vae=self.vae.optimiser_vae,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
            )
        elif self.args.policy == 'ppo':
            policy = PPO(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
                optimiser_vae=self.vae.optimiser_vae,
            )
        else:
            raise NotImplementedError

        return policy
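
    # Note: `self.args.policy` must be 'a2c' or 'ppo' here; any other value raises the
    # NotImplementedError above. Both trainers receive the VAE optimiser
    # (self.vae.optimiser_vae), which allows the encoder to be updated together with
    # the RL loss (cf. rlloss_through_encoder in update() below).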

    def train(self):
        """ Main Meta-Training loop """
        start_time = time.time()

        # reset environments
        prev_state, belief, task = utl.reset_env(self.envs, self.args)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_state[0].copy_(prev_state)

        # log once before training
        with torch.no_grad():
            self.log(None, None, start_time)

        for self.iter_idx in range(self.num_updates):

            # First, re-compute the hidden states given the current rollouts (since the VAE might've changed)
            with torch.no_grad():
                latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory(
                )

            # add this initial hidden state to the policy storage
            assert len(self.policy_storage.latent_mean
                       ) == 0  # make sure we emptied buffers
            self.policy_storage.hidden_states[0].copy_(hidden_state)
            self.policy_storage.latent_samples.append(latent_sample.clone())
            self.policy_storage.latent_mean.append(latent_mean.clone())
            self.policy_storage.latent_logvar.append(latent_logvar.clone())

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        state=prev_state,
                        belief=belief,
                        task=task,
                        deterministic=False,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                    )

                # take step in the environment
                [next_state, belief,
                 task], (rew_raw, rew_normalised), done, infos = utl.env_step(
                     self.envs, action, self.args)

                done = torch.from_numpy(np.array(
                    done, dtype=int)).to(device).float().view((-1, 1))
                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                with torch.no_grad():
                    # compute next embedding (for next loop and/or value prediction bootstrap)
                    latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                        encoder=self.vae.encoder,
                        next_obs=next_state,
                        action=action,
                        reward=rew_raw,
                        done=done,
                        hidden_state=hidden_state)

                # before resetting, update the embedding and add to vae buffer
                # (last state might include useful task info)
                if not (self.args.disable_decoder
                        and self.args.disable_kl_term):
                    self.vae.rollout_storage.insert(
                        prev_state.clone(),
                        action.detach().clone(), next_state.clone(),
                        rew_raw.clone(), done.clone(),
                        task.clone() if task is not None else None)

                # add the obs before reset to the policy storage
                self.policy_storage.next_state[step] = next_state.clone()

                # reset environments that are done
                done_indices = np.argwhere(done.cpu().flatten()).flatten()
                if len(done_indices) > 0:
                    next_state, belief, task = utl.reset_env(
                        self.envs,
                        self.args,
                        indices=done_indices,
                        state=next_state)

                # TODO: deal with resampling for the posterior sampling algorithm

                # add experience to policy buffer
                self.policy_storage.insert(
                    state=next_state,
                    belief=belief,
                    task=task,
                    actions=action,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=done,
                    hidden_states=hidden_state.squeeze(0),
                    latent_sample=latent_sample,
                    latent_mean=latent_mean,
                    latent_logvar=latent_logvar,
                )

                prev_state = next_state

                self.frames += self.args.num_processes

            # --- UPDATE ---

            if self.args.precollect_len <= self.frames:

                # check if we are pre-training the VAE
                if self.args.pretrain_len > self.iter_idx:
                    for p in range(self.args.num_vae_updates_per_pretrain):
                        self.vae.compute_vae_loss(
                            update=True,
                            pretrain_index=self.iter_idx *
                            self.args.num_vae_updates_per_pretrain + p)
                # otherwise do the normal update (policy + vae)
                else:

                    train_stats = self.update(state=prev_state,
                                              belief=belief,
                                              task=task,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar)

                    # log
                    run_stats = [
                        action, self.policy_storage.action_log_probs, value
                    ]
                    with torch.no_grad():
                        self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()

        self.envs.close()

    def encode_running_trajectory(self):
        """
        (Re-)Encodes (for each process) the entire current trajectory.
        Returns sample/mean/logvar and hidden state (if applicable) for the current timestep.
        :return:
        """

        # for each process, get the current batch (zero-padded obs/act/rew + length indicators)
        prev_obs, next_obs, act, rew, lens = self.vae.rollout_storage.get_running_batch(
        )

        # get embedding - will return (1+sequence_len) * batch * input_size -- includes the prior!
        all_latent_samples, all_latent_means, all_latent_logvars, all_hidden_states = self.vae.encoder(
            actions=act,
            states=next_obs,
            rewards=rew,
            hidden_state=None,
            return_prior=True)

        # get the embedding / hidden state of the current time step (need to do this since we zero-padded)
        latent_sample = (torch.stack([
            all_latent_samples[lens[i]][i] for i in range(len(lens))
        ])).to(device)
        latent_mean = (torch.stack([
            all_latent_means[lens[i]][i] for i in range(len(lens))
        ])).to(device)
        latent_logvar = (torch.stack([
            all_latent_logvars[lens[i]][i] for i in range(len(lens))
        ])).to(device)
        hidden_state = (torch.stack([
            all_hidden_states[lens[i]][i] for i in range(len(lens))
        ])).to(device)

        return latent_sample, latent_mean, latent_logvar, hidden_state
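
    # Indexing note: because return_prior=True the encoder outputs have length
    # (1 + sequence_len), with the prior at index 0. For a running trajectory of
    # length lens[i], all_latent_*[lens[i]][i] is therefore the posterior after the
    # last observed transition of process i (and simply the prior if lens[i] == 0).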

    def get_value(self, state, belief, task, latent_sample, latent_mean,
                  latent_logvar):
        latent = utl.get_latent_for_policy(self.args,
                                           latent_sample=latent_sample,
                                           latent_mean=latent_mean,
                                           latent_logvar=latent_logvar)
        return self.policy.actor_critic.get_value(state=state,
                                                  belief=belief,
                                                  task=task,
                                                  latent=latent).detach()

    def update(self, state, belief, task, latent_sample, latent_mean,
               latent_logvar):
        """
        Meta-update.
        Here the policy is updated for good average performance across tasks.
        :return:
        """
        # update policy (if we are not pre-training, have enough data in the vae buffer, and are not at iteration 0)
        if self.iter_idx >= self.args.pretrain_len and self.iter_idx > 0:

            # bootstrap next value prediction
            with torch.no_grad():
                next_value = self.get_value(state=state,
                                            belief=belief,
                                            task=task,
                                            latent_sample=latent_sample,
                                            latent_mean=latent_mean,
                                            latent_logvar=latent_logvar)

            # compute returns for current rollouts
            self.policy_storage.compute_returns(
                next_value,
                self.args.policy_use_gae,
                self.args.policy_gamma,
                self.args.policy_tau,
                use_proper_time_limits=self.args.use_proper_time_limits)

            # update agent (this will also call the VAE update!)
            policy_train_stats = self.policy.update(
                policy_storage=self.policy_storage,
                encoder=self.vae.encoder,
                rlloss_through_encoder=self.args.rlloss_through_encoder,
                compute_vae_loss=self.vae.compute_vae_loss)
        else:
            policy_train_stats = 0, 0, 0, 0

            # pre-train the VAE
            if self.iter_idx < self.args.pretrain_len:
                self.vae.compute_vae_loss(update=True)

        return policy_train_stats

    def log(self, run_stats, train_stats, start_time):

        # --- visualise behaviour of policy ---

        if (self.iter_idx + 1) % self.args.vis_interval == 0:
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None
            utl_eval.visualise_behaviour(
                args=self.args,
                policy=self.policy,
                image_folder=self.logger.full_output_folder,
                iter_idx=self.iter_idx,
                ret_rms=ret_rms,
                encoder=self.vae.encoder,
                reward_decoder=self.vae.reward_decoder,
                state_decoder=self.vae.state_decoder,
                task_decoder=self.vae.task_decoder,
                compute_rew_reconstruction_loss=self.vae.
                compute_rew_reconstruction_loss,
                compute_state_reconstruction_loss=self.vae.
                compute_state_reconstruction_loss,
                compute_task_reconstruction_loss=self.vae.
                compute_task_reconstruction_loss,
                compute_kl_loss=self.vae.compute_kl_loss,
                tasks=self.train_tasks,
            )

        # --- evaluate policy ----

        if (self.iter_idx + 1) % self.args.eval_interval == 0:

            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None
            returns_per_episode = utl_eval.evaluate(
                args=self.args,
                policy=self.policy,
                ret_rms=ret_rms,
                encoder=self.vae.encoder,
                iter_idx=self.iter_idx,
                tasks=self.train_tasks,
            )

            # log the return avg/std across tasks (=processes)
            returns_avg = returns_per_episode.mean(dim=0)
            returns_std = returns_per_episode.std(dim=0)
            for k in range(len(returns_avg)):
                self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1),
                                returns_avg[k], self.iter_idx)
                self.logger.add(
                    'return_avg_per_frame/episode_{}'.format(k + 1),
                    returns_avg[k], self.frames)
                self.logger.add('return_std_per_iter/episode_{}'.format(k + 1),
                                returns_std[k], self.iter_idx)
                self.logger.add(
                    'return_std_per_frame/episode_{}'.format(k + 1),
                    returns_std[k], self.frames)

            print(f"Updates {self.iter_idx}, "
                  f"Frames {self.frames}, "
                  f"FPS {int(self.frames / (time.time() - start_time))}, "
                  f"\n Mean return (train): {returns_avg[-1].item()} \n")

        # --- save models ---

        if (self.iter_idx + 1) % self.args.save_interval == 0:
            save_path = os.path.join(self.logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)

            idx_labels = ['']
            if self.args.save_intermediate_models:
                idx_labels.append(int(self.iter_idx))

            for idx_label in idx_labels:

                torch.save(self.policy.actor_critic,
                           os.path.join(save_path, f"policy{idx_label}.pt"))
                torch.save(self.vae.encoder,
                           os.path.join(save_path, f"encoder{idx_label}.pt"))
                if self.vae.state_decoder is not None:
                    torch.save(
                        self.vae.state_decoder,
                        os.path.join(save_path,
                                     f"state_decoder{idx_label}.pt"))
                if self.vae.reward_decoder is not None:
                    torch.save(
                        self.vae.reward_decoder,
                        os.path.join(save_path,
                                     f"reward_decoder{idx_label}.pt"))
                if self.vae.task_decoder is not None:
                    torch.save(
                        self.vae.task_decoder,
                        os.path.join(save_path, f"task_decoder{idx_label}.pt"))

                # save normalisation params of envs
                if self.args.norm_rew_for_policy:
                    rew_rms = self.envs.venv.ret_rms
                    utl.save_obj(rew_rms, save_path, f"env_rew_rms{idx_label}")
                # TODO: grab from policy and save?
                # if self.args.norm_obs_for_policy:
                #     obs_rms = self.envs.venv.obs_rms
                #     utl.save_obj(obs_rms, save_path, f"env_obs_rms{idx_label}")

        # --- log some other things ---

        if ((self.iter_idx + 1) % self.args.log_interval
                == 0) and (train_stats is not None):

            self.logger.add('environment/state_max',
                            self.policy_storage.prev_state.max(),
                            self.iter_idx)
            self.logger.add('environment/state_min',
                            self.policy_storage.prev_state.min(),
                            self.iter_idx)

            self.logger.add('environment/rew_max',
                            self.policy_storage.rewards_raw.max(),
                            self.iter_idx)
            self.logger.add('environment/rew_min',
                            self.policy_storage.rewards_raw.min(),
                            self.iter_idx)

            self.logger.add('policy_losses/value_loss', train_stats[0],
                            self.iter_idx)
            self.logger.add('policy_losses/action_loss', train_stats[1],
                            self.iter_idx)
            self.logger.add('policy_losses/dist_entropy', train_stats[2],
                            self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx)

            self.logger.add('policy/action', run_stats[0][0].float().mean(),
                            self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd',
                                self.policy.actor_critic.dist.logstd.mean(),
                                self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(),
                            self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx)

            self.logger.add('encoder/latent_mean',
                            torch.cat(self.policy_storage.latent_mean).mean(),
                            self.iter_idx)
            self.logger.add(
                'encoder/latent_logvar',
                torch.cat(self.policy_storage.latent_logvar).mean(),
                self.iter_idx)

            # log the average weights and gradients of all models (where applicable)
            for [model, name
                 ] in [[self.policy.actor_critic, 'policy'],
                       [self.vae.encoder, 'encoder'],
                       [self.vae.reward_decoder, 'reward_decoder'],
                       [self.vae.state_decoder, 'state_transition_decoder'],
                       [self.vae.task_decoder, 'task_decoder']]:
                if model is not None:
                    param_list = list(model.parameters())
                    param_mean = np.mean([
                        param_list[i].data.cpu().numpy().mean()
                        for i in range(len(param_list))
                    ])
                    self.logger.add('weights/{}'.format(name), param_mean,
                                    self.iter_idx)
                    if name == 'policy':
                        self.logger.add('weights/policy_std',
                                        param_list[0].data.mean(),
                                        self.iter_idx)
                    if param_list[0].grad is not None:
                        param_grad_mean = np.mean([
                            param_list[i].grad.cpu().numpy().mean()
                            for i in range(len(param_list))
                        ])
                        self.logger.add('gradients/{}'.format(name),
                                        param_grad_mean, self.iter_idx)
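
            # A small refactoring sketch (an assumption, not code from the original):
            # the repetitive per-model weight/gradient averaging above could be moved
            # into a helper and called once per (model, name) pair, e.g.:
            #
            #     def _log_param_stats(self, model, name):
            #         params = list(model.parameters())
            #         self.logger.add('weights/{}'.format(name),
            #                         np.mean([p.data.cpu().numpy().mean() for p in params]),
            #                         self.iter_idx)
            #         if params[0].grad is not None:
            #             self.logger.add('gradients/{}'.format(name),
            #                             np.mean([p.grad.cpu().numpy().mean() for p in params]),
            #                             self.iter_idx)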