Example #1
class MetaRLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
            self,
            env,
            policy,
            train_tasks,
            eval_tasks,
            meta_batch=64,
            num_iterations=100,
            num_train_steps_per_itr=1000,
            num_tasks_sample=100,
            num_steps_per_task=100,
            num_evals=10,
            num_steps_per_eval=1000,
            batch_size=1024,
            embedding_batch_size=1024,
            embedding_mini_batch_size=1024,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=1000000,
            reward_scale=1,
            train_embedding_source='posterior_only',
            eval_embedding_source='initial_pool',
            eval_deterministic=True,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=False,
            obs_emb_dim=0):
        """
        Base class for Meta RL Algorithms
        :param env: training env
        :param policy: policy conditioned on a latent variable z, which the RL algorithm is responsible for feeding in
        :param train_tasks: list of tasks used for training
        :param eval_tasks: list of tasks used for eval
        :param meta_batch: number of tasks used for meta-update
        :param num_iterations: number of meta-training iterations (outer loop)
        :param num_train_steps_per_itr: number of meta-updates performed per iteration
        :param num_tasks_sample: number of train tasks to sample for data collection each iteration
        :param num_steps_per_task: number of transitions to collect per task
        :param num_evals: number of independent evaluation runs, with separate task encodings
        :param num_steps_per_eval: number of transitions to sample for evaluation
        :param batch_size: size of batches used to compute RL update
        :param embedding_batch_size: size of batches used to compute embedding
        :param embedding_mini_batch_size: size of batch used for encoder update
        :param max_path_length: max episode length
        :param discount: discount factor
        :param replay_buffer_size: max replay buffer size
        :param reward_scale: scaling factor applied to rewards
        :param render: whether to render the environment during data collection
        :param save_replay_buffer: whether to include the replay buffer in extra saved data
        :param save_algorithm: whether to include the algorithm object in extra saved data
        :param save_environment: whether to include the environment in saved snapshots
        :param obs_emb_dim: dimension of the observation embedding stored alongside raw observations in the replay buffers
        """
        self.env = env
        self.policy = policy
        self.exploration_policy = policy  # a separate policy could be used purely for exploration rather than also solving tasks; currently not used
        self.train_tasks = train_tasks
        self.eval_tasks = eval_tasks
        self.meta_batch = meta_batch
        self.num_iterations = num_iterations
        self.num_train_steps_per_itr = num_train_steps_per_itr
        self.num_tasks_sample = num_tasks_sample
        self.num_steps_per_task = num_steps_per_task
        self.num_evals = num_evals
        self.num_steps_per_eval = num_steps_per_eval
        self.batch_size = batch_size
        self.embedding_batch_size = embedding_batch_size
        self.embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self.discount = discount
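        # NOTE: the total replay buffer size is split evenly across train tasks
        # and additionally capped at 1000 transitions per task below.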
        self.replay_buffer_size = min(
            int(replay_buffer_size / (len(train_tasks))), 1000)
        self.reward_scale = reward_scale
        self.train_embedding_source = train_embedding_source
        self.eval_embedding_source = eval_embedding_source  # TODO: add options for computing embeddings on train tasks too
        self.eval_deterministic = eval_deterministic
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment

        self.eval_sampler = InPlacePathSampler(
            env=env,
            policy=policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
        )

        # separate replay buffers for
        # - training RL update
        # - training encoder update
        # - testing encoder
        self.replay_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                   env,
                                                   self.train_tasks,
                                                   state_dim=obs_emb_dim)

        self.enc_replay_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                       env,
                                                       self.train_tasks,
                                                       state_dim=obs_emb_dim)
        self.eval_enc_replay_buffer = MultiTaskReplayBuffer(
            self.replay_buffer_size,
            env,
            self.eval_tasks,
            state_dim=obs_emb_dim)

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []

    def make_exploration_policy(self, policy):
        return policy

    def make_eval_policy(self, policy):
        return policy

    def sample_task(self, is_eval=False):
        '''
        sample task randomly
        '''
        if is_eval:
            idx = np.random.randint(len(self.eval_tasks))
        else:
            idx = np.random.randint(len(self.train_tasks))
        return idx

    def train(self):
        '''
        meta-training loop
        '''
        self.pretrain()
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()
        self.train_obs = self._start_new_rollout()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(
                range(self.num_iterations),
                save_itrs=True,
        ):
            self._start_epoch(it_)
            self.training_mode(True)
            if it_ == 0:
                print('collecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:
                    print('train task', idx)
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data_sampling_from_prior(
                        num_samples=self.max_path_length * 10,
                        resample_z_every_n=self.max_path_length,
                        eval_task=False)
                """
                for idx in self.eval_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    # TODO: make number of initial trajectories a parameter
                    self.collect_data_sampling_from_prior(num_samples=self.max_path_length * 20,
                                                          resample_z_every_n=self.max_path_length,
                                                          eval_task=True)
                """

            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.randint(len(self.train_tasks))
                self.task_idx = idx
                self.env.reset_task(idx)

                # TODO: there may be more permutations of sampling/adding to encoding buffer we may wish to try
                if self.train_embedding_source == 'initial_pool':
                    # embeddings are computed using only the initial pool of data
                    # sample data from posterior to train RL algorithm
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=False)
                elif self.train_embedding_source == 'posterior_only':
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        eval_task=False,
                        add_to_enc_buffer=True)
                elif self.train_embedding_source == 'online_exploration_trajectories':
                    # embeddings are computed using only data collected using the prior
                    # sample data from posterior to train RL algorithm
                    self.enc_replay_buffer.task_buffers[idx].clear()
                    # resamples using current policy, conditioned on prior

                    self.collect_data_sampling_from_prior(
                        num_samples=self.num_steps_per_task,
                        resample_z_every_n=self.max_path_length,
                        add_to_enc_buffer=True)

                    self.env.reset_task(idx)

                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=False,
                        viz=True)

                elif self.train_embedding_source == 'online_on_policy_trajectories':
                    # sample from prior, then sample more from the posterior
                    # embeddings computed from both prior and posterior data
                    self.enc_replay_buffer.task_buffers[idx].clear()
                    self.collect_data_online(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=True)
                else:
                    raise Exception(
                        "Invalid option for computing train embedding {}".
                        format(self.train_embedding_source))

            # Sample train tasks and compute gradient updates on parameters.
            for train_step in range(self.num_train_steps_per_itr):
                indices = np.random.choice(self.train_tasks, self.meta_batch)
                self._do_training(indices, train_step)
                self._n_train_steps_total += 1
            gt.stamp('train')

            #self.training_mode(False)

            # eval
            self._try_to_eval(it_)
            gt.stamp('eval')

            self._end_epoch()

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass

    def sample_z_from_prior(self):
        """
        Samples z from the prior distribution, which can be either a delta function at 0 or a standard Gaussian
        depending on whether we use the information bottleneck.
        :return: latent z as a Numpy array
        """
        pass

    def sample_z_from_posterior(self, idx, eval_task):
        """
        Samples z from the posterior distribution given data from task idx, where data comes from the encoding buffer
        :param idx: index of the task whose encoder-buffer data is used to compute the posterior
        :param eval_task: whether or not the task is an eval task
        :return: latent z as a Numpy array
        """
        pass

    # TODO: maybe find a better name for resample_z_every_n?
    def collect_data_sampling_from_prior(self,
                                         num_samples=1,
                                         resample_z_every_n=None,
                                         eval_task=False,
                                         add_to_enc_buffer=True):
        # do not resample z if resample_z_every_n is None
        if resample_z_every_n is None:
            self.policy.clear_z()
            self.collect_data(self.policy,
                              num_samples=num_samples,
                              eval_task=eval_task,
                              add_to_enc_buffer=add_to_enc_buffer)
        else:
            # collects more data in batches of resample_z_every_n until done
            while num_samples > 0:
                self.collect_data_sampling_from_prior(
                    num_samples=min(resample_z_every_n, num_samples),
                    resample_z_every_n=None,
                    eval_task=eval_task,
                    add_to_enc_buffer=add_to_enc_buffer)
                num_samples -= resample_z_every_n

    def collect_data_from_task_posterior(self,
                                         idx,
                                         num_samples=1,
                                         resample_z_every_n=None,
                                         eval_task=False,
                                         add_to_enc_buffer=True,
                                         viz=False):
        # do not resample z if resample_z_every_n is None
        if resample_z_every_n is None:
            self.sample_z_from_posterior(idx, eval_task=eval_task)
            self.collect_data(self.policy,
                              num_samples=num_samples,
                              eval_task=eval_task,
                              add_to_enc_buffer=add_to_enc_buffer,
                              viz=viz)
        else:
            # collects more data in batches of resample_z_every_n until done
            while num_samples > 0:
                self.collect_data_from_task_posterior(
                    idx=idx,
                    num_samples=min(resample_z_every_n, num_samples),
                    resample_z_every_n=None,
                    eval_task=eval_task,
                    add_to_enc_buffer=add_to_enc_buffer,
                    viz=viz)
                num_samples -= resample_z_every_n

    # collect num_samples transitions with z sampled from the prior, then num_samples more with z sampled from the posterior
    def collect_data_online(self,
                            idx,
                            num_samples,
                            eval_task=False,
                            add_to_enc_buffer=True):
        self.collect_data_sampling_from_prior(
            num_samples=num_samples,
            resample_z_every_n=self.max_path_length,
            eval_task=eval_task,
            add_to_enc_buffer=True)
        self.env.reset_task(idx)
        self.collect_data_from_task_posterior(
            idx=idx,
            num_samples=num_samples,
            resample_z_every_n=self.max_path_length,
            eval_task=eval_task,
            add_to_enc_buffer=add_to_enc_buffer,
            viz=True)

    # TODO: since switching tasks now resets the environment, episode termination is not
    # handled correctly. We also aren't using the episodes anywhere, but we should probably
    # change this to gather paths until we have more samples than num_samples, so that every
    # episode terminates cleanly when intended.

    # @profile
    def collect_data(self,
                     agent,
                     num_samples=1,
                     max_resets=None,
                     eval_task=False,
                     add_to_enc_buffer=True,
                     viz=False):
        '''
        collect data from current env in batch mode
        with given policy
        '''

        images = []

        env_time = self.env.time
        rews = []
        terms = []
        n_resets = 0

        for _ in range(num_samples):
            action, agent_info = self._get_action_and_info(
                agent, self.train_obs)
            if self.render:
                self.env.render()

            next_ob, raw_reward, terminal, env_info = (self.env.step(action))
            if viz:
                images.append(next_ob)
                # vis.image(next_ob[-1])

            reward = raw_reward
            rews += [reward]
            terms += [terminal]

            terminal = np.array([terminal])
            reward = np.array([reward])
            self._handle_step(
                self.task_idx,
                np.concatenate(
                    [self.train_obs.flatten()[None], agent_info['obs_emb']],
                    axis=-1),
                action,
                reward,
                np.concatenate([
                    next_ob.flatten()[None],
                    np.zeros_like(agent_info['obs_emb'])
                ],
                               axis=-1),
                terminal,
                eval_task=eval_task,
                add_to_enc_buffer=add_to_enc_buffer,
                agent_info=agent_info,
                env_info=env_info,
            )

            # TODO: use masking here to handle terminal episodes
            if terminal or len(
                    self._current_path_builder) >= self.max_path_length:
                self._handle_rollout_ending(eval_task=eval_task)
                self.train_obs = self._start_new_rollout()
                n_resets += 1

                if _ + self.max_path_length > num_samples - 1:
                    break
                if max_resets is not None and n_resets > max_resets:
                    break

            else:
                self.train_obs = next_ob

        if viz and np.random.random() < 0.3:
            vis.images(np.stack(images)[:, -1:])
            vis.line(np.array([rews, terms]).T,
                     opts=dict(width=400, height=320))
            vis.text('', opts=dict(width=10000, height=5))
            # vis.video(np.stack(images))

        if not eval_task:
            self._n_env_steps_total += num_samples
            gt.stamp('sample')

    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        return (
            # len(self._exploration_paths) > 0
            # and
            self.replay_buffer.num_steps_can_sample(self.task_idx) >=
            self.batch_size)

    def _can_train(self):
        return all([
            self.replay_buffer.num_steps_can_sample(idx) >= self.batch_size
            for idx in self.train_tasks
        ])

    def _get_action_and_info(self, agent, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        agent.set_num_steps_total(self._n_env_steps_total)
        return agent.get_action(observation)

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self):
        ret = self.env.reset()
        if isinstance(ret, tuple):
            ret = ret[0]
        return ret

    # not used
    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob.reshape(-1),
                action,
                reward,
                next_ob.reshape(-1),
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
        self,
        task_idx,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        agent_info,
        env_info,
        eval_task=False,
        add_to_enc_buffer=True,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            task=task_idx,
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        if eval_task:
            self.eval_enc_replay_buffer.add_sample(
                task=task_idx,
                observation=observation,
                action=action,
                reward=reward,
                terminal=terminal,
                next_observation=next_observation,
                agent_info=agent_info,
                env_info=env_info,
            )
        else:
            self.replay_buffer.add_sample(
                task=task_idx,
                observation=observation,
                action=action,
                reward=reward,
                terminal=terminal,
                next_observation=next_observation,
                agent_info=agent_info,
                env_info=env_info,
            )
            if add_to_enc_buffer:
                self.enc_replay_buffer.add_sample(
                    task=task_idx,
                    observation=observation,
                    action=action,
                    reward=reward,
                    terminal=terminal,
                    next_observation=next_observation,
                    agent_info=agent_info,
                    env_info=env_info,
                )

    def _handle_rollout_ending(self, eval_task=False):
        """
        Implement anything that needs to happen after every rollout.
        """
        if eval_task:
            self.eval_enc_replay_buffer.terminate_episode(self.task_idx)
        else:
            self.replay_buffer.terminate_episode(self.task_idx)
            self.enc_replay_buffer.terminate_episode(self.task_idx)

        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            # self._exploration_paths.append(
            #     self._current_path_builder.get_all_stacked()
            # )
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self, indices, train_step):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
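For reference, here is a minimal sketch of a concrete subclass showing which hooks MetaRLAlgorithm expects to be filled in. The class name MyMetaRL and every method body below are hypothetical placeholders for illustration only; a real implementation would put the encoder inference and the RL update here.

class MyMetaRL(MetaRLAlgorithm):
    """Illustrative stub only -- all bodies below are placeholders."""

    def training_mode(self, mode):
        # toggle train/eval mode on the policy networks
        # (assumes a torch-style .train() on the policy)
        self.policy.train(mode)

    def evaluate(self, epoch):
        # roll out self.eval_sampler on the eval tasks and log returns (omitted)
        pass

    def _do_training(self, indices, train_step):
        # sample a batch from self.replay_buffer for each task in `indices`
        # and take one gradient step (omitted)
        pass

    def sample_z_from_prior(self):
        # e.g. z ~ N(0, I); a zero vector stands in here
        return np.zeros(1)

    def sample_z_from_posterior(self, idx, eval_task):
        # infer z from context stored in the encoder replay buffer (omitted)
        return np.zeros(1)

Instantiating such a subclass with an env, a policy, and task index lists, then calling .train(), runs the meta-training loop above.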
Example #2
class OfflineMetaRLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(self,
                 env,
                 agent,
                 train_tasks,
                 eval_tasks,
                 goal_radius,
                 eval_deterministic=True,
                 render=False,
                 render_eval_paths=False,
                 plotter=None,
                 **kwargs):
        """
        :param env: training env
        :param agent: agent conditioned on a latent variable z, which the RL algorithm is responsible for feeding in
        :param train_tasks: list of tasks used for training
        :param eval_tasks: list of tasks used for eval
        :param goal_radius: reward threshold for defining sparse rewards

        see default experiment config file for descriptions of the rest of the arguments
        """
        self.env = env
        self.agent = agent
        self.train_tasks = train_tasks
        self.eval_tasks = eval_tasks
        self.goal_radius = goal_radius

        self.meta_batch = kwargs['meta_batch']
        self.batch_size = kwargs['batch_size']
        self.num_iterations = kwargs['num_iterations']
        self.num_train_steps_per_itr = kwargs['num_train_steps_per_itr']
        self.num_initial_steps = kwargs['num_initial_steps']
        self.num_tasks_sample = kwargs['num_tasks_sample']
        self.num_steps_prior = kwargs['num_steps_prior']
        self.num_steps_posterior = kwargs['num_steps_posterior']
        self.num_extra_rl_steps_posterior = kwargs[
            'num_extra_rl_steps_posterior']
        self.num_evals = kwargs['num_evals']
        self.num_steps_per_eval = kwargs['num_steps_per_eval']
        self.embedding_batch_size = kwargs['embedding_batch_size']
        self.embedding_mini_batch_size = kwargs['embedding_mini_batch_size']
        self.max_path_length = kwargs['max_path_length']
        self.discount = kwargs['discount']
        self.replay_buffer_size = kwargs['replay_buffer_size']
        self.reward_scale = kwargs['reward_scale']
        self.update_post_train = kwargs['update_post_train']
        self.num_exp_traj_eval = kwargs['num_exp_traj_eval']
        self.save_replay_buffer = kwargs['save_replay_buffer']
        self.save_algorithm = kwargs['save_algorithm']
        self.save_environment = kwargs['save_environment']
        self.dump_eval_paths = kwargs['dump_eval_paths']
        self.data_dir = kwargs['data_dir']
        self.train_epoch = kwargs['train_epoch']
        self.eval_epoch = kwargs['eval_epoch']
        self.sample = kwargs['sample']
        self.n_trj = kwargs['n_trj']
        self.allow_eval = kwargs['allow_eval']
        self.mb_replace = kwargs['mb_replace']

        self.eval_deterministic = eval_deterministic
        self.render = render
        self.eval_statistics = None
        self.render_eval_paths = render_eval_paths
        self.plotter = plotter

        self.train_buffer = MultiTaskReplayBuffer(self.replay_buffer_size, env,
                                                  self.train_tasks,
                                                  self.goal_radius)
        self.eval_buffer = MultiTaskReplayBuffer(self.replay_buffer_size, env,
                                                 self.eval_tasks,
                                                 self.goal_radius)
        self.replay_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                   env, self.train_tasks,
                                                   self.goal_radius)
        self.enc_replay_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                       env, self.train_tasks,
                                                       self.goal_radius)
        # offline sampler which samples from the train/eval buffer
        self.offline_sampler = OfflineInPlacePathSampler(
            env=env, policy=agent, max_path_length=self.max_path_length)
        # online sampler used for evaluation when collecting on-policy context (for offline context, use self.offline_sampler)
        self.sampler = InPlacePathSampler(env=env,
                                          policy=agent,
                                          max_path_length=self.max_path_length)

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []
        self.init_buffer()

    def init_buffer(self):
        train_trj_paths = []
        eval_trj_paths = []
        # trj entry format: [obs, action, reward, new_obs]
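        # expected on-disk layout (inferred from the glob patterns below):
        #   <data_dir>/goal_idx<task_idx>/trj_evalsample<n>_step<epoch>.npy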
        if self.sample:
            for n in range(self.n_trj):
                if self.train_epoch is None:
                    train_trj_paths += glob.glob(
                        os.path.join(self.data_dir, "goal_idx*",
                                     "trj_evalsample%d_step*.npy" % (n)))
                else:
                    train_trj_paths += glob.glob(
                        os.path.join(
                            self.data_dir, "goal_idx*",
                            "trj_evalsample%d_step%d.npy" %
                            (n, self.train_epoch)))
                if self.eval_epoch is None:
                    eval_trj_paths += glob.glob(
                        os.path.join(self.data_dir, "goal_idx*",
                                     "trj_evalsample%d_step*.npy" % (n)))
                else:
                    eval_trj_paths += glob.glob(
                        os.path.join(
                            self.data_dir, "goal_idx*",
                            "trj_evalsample%d_step%d.npy" %
                            (n, self.eval_epoch)))
        else:
            if self.train_epoch is None:
                train_trj_paths = glob.glob(
                    os.path.join(self.data_dir, "goal_idx*",
                                 "trj_eval[0-%d]_step*.npy" % (self.n_trj)))
            else:
                train_trj_paths = glob.glob(
                    os.path.join(
                        self.data_dir, "goal_idx*",
                        "trj_eval[0-%d]_step%d.npy" %
                        (self.n_trj, self.train_epoch)))
            if self.eval_epoch is None:
                eval_trj_paths = glob.glob(
                    os.path.join(self.data_dir, "goal_idx*",
                                 "trj_eval[0-%d]_step*.npy" % (self.n_trj)))
            else:
                eval_trj_paths = glob.glob(
                    os.path.join(
                        self.data_dir, "goal_idx*",
                        "trj_eval[0-%d]_step%d.npy" %
                        (self.n_trj, self.eval_epoch)))

        train_paths = [
            train_trj_path for train_trj_path in train_trj_paths
            if int(train_trj_path.split('/')[-2].split('goal_idx')[-1]) in
            self.train_tasks
        ]
        train_task_idxs = [
            int(train_trj_path.split('/')[-2].split('goal_idx')[-1])
            for train_trj_path in train_trj_paths
            if int(train_trj_path.split('/')[-2].split('goal_idx')[-1]) in
            self.train_tasks
        ]
        eval_paths = [
            eval_trj_path for eval_trj_path in eval_trj_paths
            if int(eval_trj_path.split('/')[-2].split('goal_idx')[-1]) in
            self.eval_tasks
        ]
        eval_task_idxs = [
            int(eval_trj_path.split('/')[-2].split('goal_idx')[-1])
            for eval_trj_path in eval_trj_paths
            if int(eval_trj_path.split('/')[-2].split('goal_idx')[-1]) in
            self.eval_tasks
        ]

        obs_train_lst = []
        action_train_lst = []
        reward_train_lst = []
        next_obs_train_lst = []
        terminal_train_lst = []
        task_train_lst = []
        obs_eval_lst = []
        action_eval_lst = []
        reward_eval_lst = []
        next_obs_eval_lst = []
        terminal_eval_lst = []
        task_eval_lst = []

        for train_path, train_task_idx in zip(train_paths, train_task_idxs):
            trj_npy = np.load(train_path, allow_pickle=True)
            obs_train_lst += list(trj_npy[:, 0])
            action_train_lst += list(trj_npy[:, 1])
            reward_train_lst += list(trj_npy[:, 2])
            next_obs_train_lst += list(trj_npy[:, 3])
            terminal = [0 for _ in range(trj_npy.shape[0])]
            terminal[-1] = 1
            terminal_train_lst += terminal
            task_train = [train_task_idx for _ in range(trj_npy.shape[0])]
            task_train_lst += task_train
        for eval_path, eval_task_idx in zip(eval_paths, eval_task_idxs):
            trj_npy = np.load(eval_path, allow_pickle=True)
            obs_eval_lst += list(trj_npy[:, 0])
            action_eval_lst += list(trj_npy[:, 1])
            reward_eval_lst += list(trj_npy[:, 2])
            next_obs_eval_lst += list(trj_npy[:, 3])
            terminal = [0 for _ in range(trj_npy.shape[0])]
            terminal[-1] = 1
            terminal_eval_lst += terminal
            task_eval = [eval_task_idx for _ in range(trj_npy.shape[0])]
            task_eval_lst += task_eval

        # load training buffer
        for i, (
                task_train,
                obs,
                action,
                reward,
                next_obs,
                terminal,
        ) in enumerate(
                zip(
                    task_train_lst,
                    obs_train_lst,
                    action_train_lst,
                    reward_train_lst,
                    next_obs_train_lst,
                    terminal_train_lst,
                )):
            self.train_buffer.add_sample(
                task_train,
                obs,
                action,
                reward,
                terminal,
                next_obs,
                env_info={},
            )

        # load evaluation buffer
        for i, (
                task_eval,
                obs,
                action,
                reward,
                next_obs,
                terminal,
        ) in enumerate(
                zip(
                    task_eval_lst,
                    obs_eval_lst,
                    action_eval_lst,
                    reward_eval_lst,
                    next_obs_eval_lst,
                    terminal_eval_lst,
                )):
            self.eval_buffer.add_sample(
                task_eval,
                obs,
                action,
                reward,
                terminal,
                next_obs,
                env_info={},
            )

    def _try_to_eval(self, epoch):
        #logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)
            #params = self.get_epoch_snapshot(epoch)
            #logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys
            logger.record_tabular("Number of train steps total",
                                  self._n_train_steps_total)
            logger.record_tabular("Number of env steps total",
                                  self._n_env_steps_total)
            logger.record_tabular("Number of rollouts total",
                                  self._n_rollouts_total)

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        # eval collects its own context, so can eval any time
        return True

    def _can_train(self):
        return all([
            self.replay_buffer.num_steps_can_sample(idx) >= self.batch_size
            for idx in self.train_tasks
        ])

    def _get_action_and_info(self, agent, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        agent.set_num_steps_total(self._n_env_steps_total)
        return agent.get_action(observation)

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    ##### Snapshotting utils #####
    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.agent,
        )
        if self.save_environment:
            data_to_save['env'] = self.env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    def _do_eval(self, indices, epoch, buffer):
        final_returns = []
        online_returns = []
        for idx in indices:
            all_rets = []
            for r in range(self.num_evals):
                paths = self.collect_paths(idx, epoch, r, buffer)
                all_rets.append(
                    [eval_util.get_average_returns([p]) for p in paths])
            final_returns.append(np.mean([a[-1] for a in all_rets]))
            # record online returns for the first n trajectories
            n = min([len(a) for a in all_rets])
            all_rets = [a[:n] for a in all_rets]
            all_rets = np.mean(np.stack(all_rets),
                               axis=0)  # avg return per nth rollout
            online_returns.append(all_rets)
        n = min([len(t) for t in online_returns])
        online_returns = [t[:n] for t in online_returns]
        return final_returns, online_returns

    def test(self, log_dir, end_point=-1):
        assert os.path.exists(log_dir)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(range(self.num_iterations), save_itrs=True):
            self._start_epoch(it_)

            if it_ == 0:
                print('collecting initial pool of data for test')
                # temp for evaluating
                for idx in self.train_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data(self.num_initial_steps,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.choice(self.train_tasks, 1)[0]
                self.task_idx = idx
                self.env.reset_task(idx)
                self.enc_replay_buffer.task_buffers[idx].clear()

                # collect some trajectories with z ~ prior
                if self.num_steps_prior > 0:
                    self.collect_data(self.num_steps_prior,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
                # collect some trajectories with z ~ posterior
                if self.num_steps_posterior > 0:
                    self.collect_data(self.num_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer)
                # even if encoder is trained only on samples from the prior, the policy needs to learn to handle z ~ posterior
                if self.num_extra_rl_steps_posterior > 0:
                    self.collect_data(self.num_extra_rl_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer,
                                      add_to_enc_buffer=False)

            print([
                self.replay_buffer.task_buffers[idx]._size
                for idx in self.train_tasks
            ])
            print([
                self.enc_replay_buffer.task_buffers[idx]._size
                for idx in self.train_tasks
            ])

            for train_step in range(self.num_train_steps_per_itr):
                self._n_train_steps_total += 1

            gt.stamp('train')
            # eval
            self.training_mode(False)
            if it_ % 5 == 0 and it_ > end_point:
                status = self.load_epoch_model(it_, log_dir)
                if status:
                    self._try_to_eval(it_)
            gt.stamp('eval')
            self._end_epoch()

    def train(self):
        '''
        meta-training loop
        '''
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(range(self.num_iterations), save_itrs=True):
            self._start_epoch(it_)
            self.training_mode(True)
            if it_ == 0:
                print('collecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data(self.num_initial_steps,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.choice(self.train_tasks, 1)[0]
                self.task_idx = idx
                self.env.reset_task(idx)
                self.enc_replay_buffer.task_buffers[idx].clear()

                # collect some trajectories with z ~ prior
                if self.num_steps_prior > 0:
                    self.collect_data(self.num_steps_prior,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
                # collect some trajectories with z ~ posterior
                if self.num_steps_posterior > 0:
                    self.collect_data(self.num_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer)
                # even if encoder is trained only on samples from the prior, the policy needs to learn to handle z ~ posterior
                if self.num_extra_rl_steps_posterior > 0:
                    self.collect_data(self.num_extra_rl_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer,
                                      add_to_enc_buffer=False)

            indices_lst = []
            z_means_lst = []
            z_vars_lst = []
            # Sample train tasks and compute gradient updates on parameters.
            for train_step in range(self.num_train_steps_per_itr):
                indices = np.random.choice(self.train_tasks,
                                           self.meta_batch,
                                           replace=self.mb_replace)
                z_means, z_vars = self._do_training(indices, zloss=True)
                indices_lst.append(indices)
                z_means_lst.append(z_means)
                z_vars_lst.append(z_vars)
                self._n_train_steps_total += 1

            indices = np.concatenate(indices_lst)
            z_means = np.concatenate(z_means_lst)
            z_vars = np.concatenate(z_vars_lst)
            data_dict = self.data_dict(indices, z_means, z_vars)
            logger.save_itr_data(it_, **data_dict)
            gt.stamp('train')
            self.training_mode(False)
            # eval
            params = self.get_epoch_snapshot(it_)
            logger.save_itr_params(it_, params)

            if self.allow_eval:
                logger.save_extra_data(self.get_extra_data_to_save(it_))
                self._try_to_eval(it_)
                gt.stamp('eval')
            self._end_epoch()

    def data_dict(self, indices, z_means, z_vars):
        data_dict = {}
        data_dict['task_idx'] = indices
        for i in range(z_means.shape[1]):
            data_dict['z_means%d' % i] = list(z_means[:, i])
        for i in range(z_vars.shape[1]):
            data_dict['z_vars%d' % i] = list(z_vars[:, i])
        return data_dict

    def evaluate(self, epoch):

        if self.eval_statistics is None:
            self.eval_statistics = OrderedDict()

        ### sample trajectories from prior for debugging / visualization
        if self.dump_eval_paths:
            # 100 arbitrarily chosen for visualizations of point_robot trajectories
            # just want stochasticity of z, not the policy
            self.agent.clear_z()
            prior_paths, _ = self.offline_sampler.obtain_samples(
                buffer=self.train_buffer,
                deterministic=self.eval_deterministic,
                max_samples=self.max_path_length * 20,
                accum_context=False,
                resample=1)
            logger.save_extra_data(
                prior_paths,
                path='eval_trajectories/prior-epoch{}'.format(epoch))

        ### train tasks
        # eval on a subset of train tasks for speed

        # environments with exactly two train and two eval tasks (e.g. forward/backward direction envs)
        if len(self.train_tasks) == 2 and len(self.eval_tasks) == 2:
            indices = self.train_tasks
            eval_util.dprint('evaluating on {} train tasks'.format(
                len(indices)))
            ### eval train tasks with posterior sampled from the training replay buffer
            train_returns = []
            for idx in indices:
                self.task_idx = idx
                self.env.reset_task(idx)
                paths = []
                for _ in range(self.num_steps_per_eval //
                               self.max_path_length):
                    context = self.sample_context(idx)
                    self.agent.infer_posterior(context, idx)
                    p, _ = self.offline_sampler.obtain_samples(
                        buffer=self.train_buffer,
                        deterministic=self.eval_deterministic,
                        max_samples=self.max_path_length,
                        accum_context=False,
                        max_trajs=1,
                        resample=np.inf)
                    paths += p

                if self.sparse_rewards:
                    for p in paths:
                        sparse_rewards = np.stack([
                            e['sparse_reward'] for e in p['env_infos']
                        ]).reshape(-1, 1)
                        p['rewards'] = sparse_rewards

                train_returns.append(eval_util.get_average_returns(paths))
            ### eval train tasks with on-policy data to match eval of test tasks
            train_final_returns, train_online_returns = self._do_eval(
                indices, epoch, buffer=self.train_buffer)
            eval_util.dprint('train online returns')
            eval_util.dprint(train_online_returns)

            ### test tasks
            eval_util.dprint('evaluating on {} test tasks'.format(
                len(self.eval_tasks)))
            test_final_returns, test_online_returns = self._do_eval(
                self.eval_tasks, epoch, buffer=self.eval_buffer)
            eval_util.dprint('test online returns')
            eval_util.dprint(test_online_returns)

            # save the final posterior
            self.agent.log_diagnostics(self.eval_statistics)

            if hasattr(self.env, "log_diagnostics"):
                self.env.log_diagnostics(paths, prefix=None)

            avg_train_online_return = np.mean(np.stack(train_online_returns),
                                              axis=0)
            avg_test_online_return = np.mean(np.stack(test_online_returns),
                                             axis=0)
            for pos, i in enumerate(indices):
                self.eval_statistics[
                    f'AverageTrainReturn_train_task{i}'] = train_returns[pos]
                self.eval_statistics[
                    f'AverageReturn_all_train_task{i}'] = train_final_returns[
                        pos]
                self.eval_statistics[
                    f'AverageReturn_all_test_tasks{i}'] = test_final_returns[
                        pos]

        # all other environments
        else:
            indices = np.random.choice(self.train_tasks, len(self.eval_tasks))
            eval_util.dprint('evaluating on {} train tasks'.format(
                len(indices)))
            ### eval train tasks with posterior sampled from the training replay buffer
            train_returns = []
            for idx in indices:
                self.task_idx = idx
                self.env.reset_task(idx)
                paths = []
                for _ in range(self.num_steps_per_eval //
                               self.max_path_length):
                    context = self.sample_context(idx)
                    self.agent.infer_posterior(context, idx)
                    p, _ = self.offline_sampler.obtain_samples(
                        buffer=self.train_buffer,
                        deterministic=self.eval_deterministic,
                        max_samples=self.max_path_length,
                        accum_context=False,
                        max_trajs=1,
                        resample=np.inf)
                    paths += p

                if self.sparse_rewards:
                    for p in paths:
                        sparse_rewards = np.stack([
                            e['sparse_reward'] for e in p['env_infos']
                        ]).reshape(-1, 1)
                        p['rewards'] = sparse_rewards

                train_returns.append(eval_util.get_average_returns(paths))
            train_returns = np.mean(train_returns)
            ### eval train tasks with on-policy data to match eval of test tasks
            train_final_returns, train_online_returns = self._do_eval(
                indices, epoch, buffer=self.train_buffer)
            eval_util.dprint('train online returns')
            eval_util.dprint(train_online_returns)

            ### test tasks
            eval_util.dprint('evaluating on {} test tasks'.format(
                len(self.eval_tasks)))
            test_final_returns, test_online_returns = self._do_eval(
                self.eval_tasks, epoch, buffer=self.eval_buffer)
            eval_util.dprint('test online returns')
            eval_util.dprint(test_online_returns)

            # save the final posterior
            self.agent.log_diagnostics(self.eval_statistics)

            if hasattr(self.env, "log_diagnostics"):
                self.env.log_diagnostics(paths, prefix=None)

            avg_train_return = np.mean(train_final_returns)
            avg_test_return = np.mean(test_final_returns)
            avg_train_online_return = np.mean(np.stack(train_online_returns),
                                              axis=0)
            avg_test_online_return = np.mean(np.stack(test_online_returns),
                                             axis=0)
            self.eval_statistics[
                'AverageTrainReturn_all_train_tasks'] = train_returns
            self.eval_statistics[
                'AverageReturn_all_train_tasks'] = avg_train_return
            self.eval_statistics[
                'AverageReturn_all_test_tasks'] = avg_test_return

            self.loss['train_returns'] = train_returns
            self.loss['avg_train_return'] = avg_train_return
            self.loss['avg_test_return'] = avg_test_return
            self.loss['avg_train_online_return'] = np.mean(
                avg_train_online_return)
            self.loss['avg_test_online_return'] = np.mean(
                avg_test_online_return)

        logger.save_extra_data(avg_train_online_return,
                               path='online-train-epoch{}'.format(epoch))
        logger.save_extra_data(avg_test_online_return,
                               path='online-test-epoch{}'.format(epoch))

        for key, value in self.eval_statistics.items():
            logger.record_tabular(key, value)
        self.eval_statistics = None

        if self.render_eval_paths:
            self.env.render_paths(paths)

        if self.plotter:
            self.plotter.draw()

    def collect_paths(self, idx, epoch, run, buffer):
        self.task_idx = idx
        self.env.reset_task(idx)

        self.agent.clear_z()
        paths = []
        num_transitions = 0
        # num_trajs = 0
        while num_transitions < self.num_steps_per_eval:
            path, num = self.offline_sampler.obtain_samples(
                buffer=buffer,
                deterministic=self.eval_deterministic,
                max_samples=self.num_steps_per_eval - num_transitions,
                max_trajs=1,
                accum_context=True,
                rollout=True)
            paths += path
            num_transitions += num

        if self.sparse_rewards:
            for p in paths:
                sparse_rewards = np.stack([
                    e['sparse_reward'] for e in p['env_infos']
                ]).reshape(-1, 1)
                p['rewards'] = sparse_rewards

        goal = self.env._goal
        for path in paths:
            path['goal'] = goal  # goal

        # save the paths for visualization, only useful for point mass
        if self.dump_eval_paths:
            logger.save_extra_data(
                paths,
                path='eval_trajectories/task{}-epoch{}-run{}'.format(
                    idx, epoch, run))

        return paths

    def collect_data(self,
                     num_samples,
                     resample_z_rate,
                     update_posterior_rate,
                     buffer,
                     add_to_enc_buffer=True):
        '''
        get trajectories by sampling from the given offline buffer in batch mode
        collect complete trajectories until the number of collected transitions >= num_samples

        :param num_samples: total number of transitions to sample
        :param resample_z_rate: how often to resample latent context z (in units of trajectories)
        :param update_posterior_rate: how often to update q(z | c), from which z is sampled (in units of trajectories)
        :param buffer: offline buffer from which transitions are drawn
        :param add_to_enc_buffer: whether to add collected data to encoder replay buffer
        '''
        # start from the prior
        self.agent.clear_z()

        num_transitions = 0
        while num_transitions < num_samples:
            paths, n_samples = self.offline_sampler.obtain_samples(
                buffer=buffer,
                max_samples=num_samples - num_transitions,
                max_trajs=update_posterior_rate,
                accum_context=False,
                resample=resample_z_rate,
                rollout=False)
            num_transitions += n_samples
            self.replay_buffer.add_paths(self.task_idx, paths)
            if add_to_enc_buffer:
                self.enc_replay_buffer.add_paths(self.task_idx, paths)
            if update_posterior_rate != np.inf:
                context = self.sample_context(self.task_idx)
                self.agent.infer_posterior(context,
                                           task_indices=np.array(
                                               [self.task_idx]))
        self._n_env_steps_total += num_transitions
        gt.stamp('sample')

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def _do_training(self, indices, **kwargs):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
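To make the offline-data assumptions concrete, the sketch below writes a single dummy trajectory file in the layout that init_buffer globs for and slices by column. Only the directory pattern, the file-name pattern (for the sample=True case), and the [obs, action, reward, next_obs] row format are taken from the code above; the helper name write_dummy_trajectory and all placeholder values and dimensions are hypothetical.

import os
import numpy as np

def write_dummy_trajectory(data_dir, task_idx=0, n=0, epoch=0,
                           path_len=5, obs_dim=3, act_dim=1):
    # layout expected by init_buffer (with sample=True):
    #   <data_dir>/goal_idx<task_idx>/trj_evalsample<n>_step<epoch>.npy
    goal_dir = os.path.join(data_dir, "goal_idx%d" % task_idx)
    os.makedirs(goal_dir, exist_ok=True)
    # one row per transition: [obs, action, reward, next_obs], stored as an
    # object array so that np.load(..., allow_pickle=True)[:, k] recovers each column
    trj = np.empty((path_len, 4), dtype=object)
    for i in range(path_len):
        trj[i, 0] = np.zeros(obs_dim)   # obs
        trj[i, 1] = np.zeros(act_dim)   # action
        trj[i, 2] = 0.0                 # reward
        trj[i, 3] = np.zeros(obs_dim)   # next_obs
    np.save(os.path.join(goal_dir,
                         "trj_evalsample%d_step%d.npy" % (n, epoch)), trj)

Writing one such file per goal index in train_tasks and eval_tasks would give init_buffer something to load, assuming the algorithm is constructed with data_dir pointing at the same directory, sample=True, n_trj=1, and train_epoch=eval_epoch=0 (or an epoch matching the file names).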