    def _train(self):
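        # Assumed module-level dependencies (not shown in this excerpt):
        # numpy as np, ray, gtimer as gt, tqdm's trange, and a torch helper
        # module aliased as ptu (providing state_dict_cpu).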

        batch_idxes = np.arange(self.num_tasks)

        gt.start()

        for epoch in gt.timed_for(
                trange(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            # Distribute the evaluation: ship the params of each needed
            # network to the remote path collector.

            params_list = []
            for net in self.policy.networks:
                params_list.append(ptu.state_dict_cpu(net))

            self.path_collector.set_policy_params(params_list)
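            # ptu.state_dict_cpu is assumed to return a copy of each network's
            # state dict with tensors moved to the CPU, roughly
            #   {k: v.detach().cpu() for k, v in net.state_dict().items()}
            # so that the params can be serialized and shipped to Ray workers.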

            evaluation_train_obj_id_list = []
            count = 0
            while count < len(self.train_goals):
                if len(self.train_goals) - count < self.num_workers:
                    evaluation_obj_id = self.path_collector.async_evaluate(
                        self.train_goals[count:])
                    count = len(self.train_goals)
                else:
                    evaluation_obj_id = self.path_collector.async_evaluate(
                        self.train_goals[count:count + self.num_workers])
                    count += self.num_workers
                evaluation_train_obj_id_list.extend(evaluation_obj_id)

            assert len(evaluation_train_obj_id_list) == len(
                self.train_goals
            ), f'{len(evaluation_train_obj_id_list)}, {len(self.train_goals)}'
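            # async_evaluate is assumed to return one Ray ObjectRef per goal;
            # the refs are only resolved with ray.get after the training loops
            # below, so evaluation overlaps with training.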

            evaluation_wd_obj_id_list = []
            count = 0
            while count < len(self.wd_goals):
                if len(self.wd_goals) - count < self.num_workers:
                    evaluation_obj_id = self.path_collector.async_evaluate(
                        self.wd_goals[count:])
                    count = len(self.wd_goals)
                else:
                    evaluation_obj_id = self.path_collector.async_evaluate(
                        self.wd_goals[count:count + self.num_workers])
                    count += self.num_workers
                evaluation_wd_obj_id_list.extend(evaluation_obj_id)

            assert len(evaluation_wd_obj_id_list) == len(self.wd_goals)
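            # The two dispatch loops above share the same chunking logic; a
            # hypothetical helper (not in the original code) could express it
            # more compactly, e.g.:
            #   def _dispatch_in_chunks(goals):
            #       obj_ids = []
            #       for start in range(0, len(goals), self.num_workers):
            #           obj_ids.extend(self.path_collector.async_evaluate(
            #               goals[start:start + self.num_workers]))
            #       return obj_ids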

            # evaluation_ood_obj_id_list = []
            # count = 0
            # while count < len(self.ood_goals) :
            #     if len(self.ood_goals) - count < self.num_workers:
            #         evaluation_obj_id = self.path_collector.async_evaluate(self.ood_goals[count:])
            #         count = len(self.ood_goals)
            #     else:
            #         evaluation_obj_id = self.path_collector.async_evaluate(self.ood_goals[count:count + self.num_workers])
            #         count += self.num_workers
            #     evaluation_ood_obj_id_list.extend(evaluation_obj_id)

            # assert len(evaluation_ood_obj_id_list) == len(self.ood_goals)

            gt.stamp('set_up_evaluation', unique=False)

            train_batch_obj_id = self.train_buffer.sample_training_data(
                batch_idxes)
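            # sample_training_data is assumed to run asynchronously on the
            # buffer (returning a Ray ObjectRef), so the first batch is already
            # being prepared while this epoch's evaluation jobs are in flight.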

            for _ in trange(self.num_train_loops_per_epoch):
                train_raw_batch = ray.get(train_batch_obj_id)

                gt.stamp('sample_training_data', unique=False)

                # This way, the data-sampling job for the next loop runs
                # while we train on the current batch.
                train_batch_obj_id = self.train_buffer.sample_training_data(
                    batch_idxes)

                gt.stamp('set_up_sampling', unique=False)

                train_data = self.construct_training_batch(train_raw_batch)
                gt.stamp('construct_training_batch', unique=False)

                self.policy.train(train_data)
                gt.stamp('training', unique=False)

            eval_train_returns = ray.get(evaluation_train_obj_id_list)

            self.avg_train_episode_returns = [
                item[0] for item in eval_train_returns
            ]
            self.final_train_achieved = [
                item[1] for item in eval_train_returns
            ]
            self.train_avg_returns = np.mean(self.avg_train_episode_returns)

            eval_wd_returns = ray.get(evaluation_wd_obj_id_list)

            self.avg_wd_episode_returns = [item[0] for item in eval_wd_returns]
            self.final_wd_achieved = [item[1] for item in eval_wd_returns]
            self.wd_avg_returns = np.mean(self.avg_wd_episode_returns)
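            # Per the indexing above, each evaluation result is assumed to be
            # a (average episode return, final achieved goal) pair.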

            # eval_ood_returns = ray.get(evaluation_ood_obj_id_list)

            # self.avg_ood_episode_returns = [item[0] for item in eval_ood_returns]
            # self.final_ood_achieved = [item[1] for item in eval_ood_returns]
            # self.ood_avg_returns = np.mean(self.avg_ood_episode_returns)

            gt.stamp('evaluation', unique=False)

            self._end_epoch(epoch)

    def _train(self):

        # Fill the replay buffer to a minimum before training starts
        if (self.min_num_steps_before_training >
                self.replay_buffer.num_steps_can_sample()):
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.trainer.policy,
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)
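            # end_epoch(-1) presumably clears the collector's per-epoch
            # statistics so these warm-up paths are not logged under epoch 0.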

        for epoch in gt.timed_for(
                trange(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):

            # To evaluate the policy remotely, we ship the policy params to
            # the remote evaluator. This could be made more efficient, but it
            # is currently very cheap because the networks are small.
            pol_state_dict = ptu.state_dict_cpu(self.trainer.policy)

            remote_eval_obj_id = self.remote_eval_data_collector.async_collect_new_paths.remote(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
                deterministic_pol=True,
                pol_state_dict=pol_state_dict)

            gt.stamp('remote evaluation submit')
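            # .remote(...) returns immediately with an ObjectRef; the
            # evaluation rollouts run on the Ray actor concurrently with the
            # training loops below and are only waited on at the end of the
            # epoch via ray.get.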

            for _ in range(self.num_train_loops_per_epoch):
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.trainer.policy,
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                    optimistic_exploration=self.optimistic_exp_hp['should_use'],
                    optimistic_exploration_kwargs=dict(
                        policy=self.trainer.policy,
                        qfs=[self.trainer.qf1, self.trainer.qf2],
                        hyper_params=self.optimistic_exp_hp))
                gt.stamp('exploration sampling', unique=False)
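                # When optimistic_exp_hp['should_use'] is set, the collector
                # is handed the policy and both Q-functions; it is assumed to
                # use them to build optimistic (upper-confidence-bound style)
                # exploratory actions, as in Optimistic Actor-Critic.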

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)
                gt.stamp('training', unique=False)
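                # random_batch (used in the loop above) is assumed to return a
                # dict of numpy arrays (observations, actions, rewards,
                # terminals, next_observations) of size batch_size, which
                # trainer.train consumes directly.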

            # Wait for eval to finish
            ray.get([remote_eval_obj_id])
            gt.stamp('remote evaluation wait')

            self._end_epoch(epoch)

    def _train(self):

        gt.reset()

        # -----------------------Imitation phase--------------------------------

        # Fill the replay buffer to a minimum before training starts.
        # Here, the tiMe policy is used as the sampler to collect
        # self.min_num_steps_before_training transitions, which are then used
        # to adapt the policy so that it matches the super Q function.

        init_paths, inferred_mdp = self.tiMe_data_collector.collect_new_paths(
            self.max_path_length,
            self.min_num_steps_before_training,
            discard_incomplete_paths=False,
        )
        self.replay_buffer.add_paths(init_paths)
        self.expl_data_collector.end_epoch(-1)
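        # inferred_mdp is assumed to be the task representation inferred by
        # the tiMe collector from the gathered transitions; it is handed to
        # the trainer below via update_inferred_mdp_target_policy.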

        # Imitation: train the policy network on the collected transitions.

        for _ in trange(self.num_pre_train):
            train_data = self.replay_buffer.random_batch(self.batch_size)
            self.trainer.train_to_imitate(train_data)

        self.replay_buffer.reset()
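        # train_to_imitate is assumed to regress the policy toward the actions
        # stored in the buffer (i.e. behavior cloning of the tiMe sampler);
        # the buffer is then reset so the imitation data is not reused in the
        # RL phase below.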

        # -------------------------------------------------------------------------

        self.trainer.update_inferred_mdp_target_policy(inferred_mdp)

        # -------------------------------------------------------------------------

        init_expl_paths = self.expl_data_collector.collect_new_paths(
            self.trainer.policy,
            self.max_path_length,
            self.min_num_steps_before_training,
            discard_incomplete_paths=False,
        )
        self.replay_buffer.add_paths(init_expl_paths)
        self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                trange(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):

            # To evaluate the policy remotely, we ship the policy params to
            # the remote evaluator. This could be made more efficient, but it
            # is currently very cheap because the networks are small.
            pol_state_dict = ptu.state_dict_cpu(self.trainer.policy)

            remote_eval_obj_id = self.remote_eval_data_collector.async_collect_new_paths.remote(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
                deterministic_pol=True,
                pol_state_dict=pol_state_dict)

            gt.stamp('remote evaluation submit')

            for _ in range(self.num_train_loops_per_epoch):
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.trainer.policy,
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                    optimistic_exploration=self.optimistic_exp_hp['should_use'],
                    optimistic_exploration_kwargs=dict(
                        policy=self.trainer.policy,
                        qfs=[self.trainer.qf1, self.trainer.qf2],
                        hyper_params=self.optimistic_exp_hp))
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train_qf1(train_data)

                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train_qf2_policy(train_data)

                gt.stamp('training', unique=False)
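                # Unlike the single trainer.train call in the previous _train,
                # this variant updates qf1 on one batch and qf2 together with
                # the policy on an independently sampled batch.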

            # Wait for eval to finish
            ray.get([remote_eval_obj_id])
            gt.stamp('remote evaluation wait')

            self._end_epoch(epoch)