Example #1
    def train_batch(self, start_epoch):
        self._current_path_builder = PathBuilder()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            set_to_train_mode(self.training_env)
            observation = self._start_new_rollout()
            # This implementation is rather naive. If you want to (e.g.)
            # parallelize data collection, this would be the place to do it.
            for _ in range(self.num_env_steps_per_epoch):
                observation = self._take_step_in_env(observation)
            gt.stamp('sample')

            self._try_to_train()
            gt.stamp('train')

            set_to_eval_mode(self.env)
            self._try_to_eval(epoch)
            gt.stamp('eval')

            self._try_to_fit(epoch)
            gt.stamp('env_fit')
            self.logger.record_tabular(self.env_loss_key, self.env_loss)

            self._end_epoch(epoch)
            self.logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #2
    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            set_to_train_mode(self.training_env)
            observation = self._start_new_rollout()
            for _ in range(self.num_env_steps_per_epoch):
                observation = self._take_step_in_env(observation)
                gt.stamp('sample')

                self._try_to_fit(epoch)
                gt.stamp('env_fit')

                self._try_to_train()
                gt.stamp('train')

            self.logger.record_tabular(self.env_loss_key, self.env_loss)

            set_to_eval_mode(self.env)
            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch(epoch)

            self.logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #3
 def _log_stats(self, epoch):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
     """
     Evaluation
     """
     logger.record_dict(self.get_evaluation_diagnostics(), prefix='eval/')
     """
     Misc
     """
     gt.stamp('logging')
Example #4
    def start_training(self, start_epoch=0):
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            steps_this_epoch = 0
            steps_since_train_call = 0
            while steps_this_epoch < self.min_steps_per_epoch:
                task_params = self.train_task_params_sampler.sample()
                rollout_len = self.do_task_rollout(task_params)

                steps_this_epoch += rollout_len
                steps_since_train_call += rollout_len

                if steps_since_train_call > self.min_steps_between_train_calls:
                    steps_since_train_call = 0
                    gt.stamp('sample')
                    self._try_to_train(epoch)
                    gt.stamp('train')

            gt.stamp('sample')
            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch()
Example #5
 def _log_stats(self, epoch):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Replay Buffer
     """
     logger.record_dict(self.replay_buffer.get_diagnostics(),
                        prefix='replay_buffer/')
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
     """
     Exploration
     """
     logger.record_dict(self.expl_data_collector.get_diagnostics(),
                        prefix='exploration/')
     expl_paths = self.expl_data_collector.get_epoch_paths()
     if hasattr(self.expl_env, 'get_diagnostics'):
         logger.record_dict(
             self.expl_env.get_diagnostics(expl_paths),
             prefix='exploration/',
         )
     logger.record_dict(
         eval_util.get_generic_path_information(expl_paths),
         prefix="exploration/",
     )
     """
     Evaluation
     """
     logger.record_dict(
         self.eval_data_collector.get_diagnostics(),
         prefix='evaluation/',
     )
     eval_paths = self.eval_data_collector.get_epoch_paths()
     if hasattr(self.eval_env, 'get_diagnostics'):
         logger.record_dict(
             self.eval_env.get_diagnostics(eval_paths),
             prefix='evaluation/',
         )
     logger.record_dict(
         eval_util.get_generic_path_information(eval_paths),
         prefix="evaluation/",
     )
     """
     Misc
     """
     gt.stamp('logging')
     logger.record_dict(_get_epoch_timings())
     logger.record_tabular('Epoch', epoch)
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
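The `_get_epoch_timings()` helper used above is not shown in these examples. A minimal sketch, assuming it simply summarizes the per-iteration gtimer stamps the same way the later examples read them (`gt.get_times().stamps.itrs` and `gt.get_times().total`):

from collections import OrderedDict

import gtimer as gt


def _get_epoch_timings():
    # Report the most recent duration of every stamp, plus epoch and total time.
    times_itrs = gt.get_times().stamps.itrs
    times = OrderedDict()
    epoch_time = 0
    for key in sorted(times_itrs):
        time = times_itrs[key][-1]
        epoch_time += time
        times['time/{} (s)'.format(key)] = time
    times['time/epoch (s)'] = epoch_time
    times['time/total (s)'] = gt.get_times().total
    return times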
Example #6
    def _train(self):
        """Called by superclass BaseRLAlgorithm, conducts the training loop.

        Before training (i.e., the minimum number of steps before trainnig) Get
        new paths for _exploration_, with noise added (in the case of DDPG).
        Add the paths to replay buffer.

        Then we begin the actual cycle of evaluation and exploration. Each
        epoch consists of an evaluator data collector collecting paths
        (discarding incomplete ones), and then exploration data collection, and
        only exploration data is added to the buffer. The number of training
        loops is 1 by default so usually it will be one cycle of (evaluate,
        explore). Each explore, though, will do a bunch of training loops,
        e.g., 1000 by default.

        When we talk about 'steps' we really should be talking about training
        (or exploration) steps, right? The evaluation steps is for reporting
        results.
        """
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                )
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)
                gt.stamp('training', unique=False)
                self.training_mode(False)

            self._end_epoch(epoch)
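The public entry point that invokes `_train` is not shown; in the `BaseRLAlgorithm` pattern referenced by the docstring it is usually a thin wrapper. A minimal sketch, assuming the `_start_epoch` attribute used in the loop above:

    def train(self, start_epoch=0):
        # Hypothetical wrapper: record where to resume from, then run the
        # subclass-specific loop shown above.
        self._start_epoch = start_epoch
        self._train()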
Example #7
    def _train(self):
        self.training_mode(False)
        if self.min_num_steps_before_training > 0:
            self.expl_data_collector.collect_new_steps(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            init_expl_paths = self.expl_data_collector.get_epoch_paths()
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

            gt.stamp("initial exploration", unique=True)

        num_trains_per_expl_step = (self.num_trains_per_train_loop //
                                    self.num_expl_steps_per_train_loop)
        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp("evaluation sampling")

            for _ in range(self.num_train_loops_per_epoch):
                for _ in range(self.num_expl_steps_per_train_loop):
                    self.expl_data_collector.collect_new_steps(
                        self.max_path_length,
                        1,  # num steps
                        discard_incomplete_paths=False,
                    )
                    gt.stamp("exploration sampling", unique=False)

                    self.training_mode(True)
                    for _ in range(num_trains_per_expl_step):
                        train_data = self.replay_buffer.random_batch(
                            self.batch_size)
                        self.trainer.train(train_data)
                    gt.stamp("training", unique=False)
                    self.training_mode(False)

            new_expl_paths = self.expl_data_collector.get_epoch_paths()
            self.replay_buffer.add_paths(new_expl_paths)
            gt.stamp("data storing", unique=False)

            self._end_epoch(epoch)
Example #8
    def _end_epoch(self, epoch):
        snapshot = self._get_snapshot()
        # only save params for the first gpu
        if ptu.dist_rank == 0:
            logger.save_itr_params(epoch, snapshot)
        gt.stamp("saving")
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example #9
    def _end_epoch(self, epoch):
        #print ("core/rl_algorithm, _end_epoch(): ", "epoch: ", epoch)
        snapshot = self._get_snapshot()
        #print ("core/rl_algorithm, _end_epoch(): ", "snapshot: ", snapshot)
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example #10
    def _end_epoch(self, epoch):
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch)

        if self.collect_actions and epoch % self.collect_actions_every == 0:
            self._log_actions(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
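The `post_epoch_funcs` hook iterated at the end of `_end_epoch` receives the algorithm instance and the epoch index. A hedged usage sketch, assuming `post_epoch_funcs` is a plain list attribute (the callback below is made up for illustration):

def print_progress(algorithm, epoch):
    # Hypothetical post-epoch callback; invoked as post_epoch_func(self, epoch).
    if epoch % 10 == 0:
        print('finished epoch', epoch)

algorithm.post_epoch_funcs.append(print_progress)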
Example #11
    def _end_epoch(self, epoch):

        self._log_stats(epoch)
        if epoch > 0:
            snapshot = self._get_snapshot(epoch)
            logger.save_itr_params(epoch + 1, snapshot)
        gt.stamp('saving', unique=False)

        self.trainer.end_epoch(epoch)

        logger.record_dict(_get_epoch_timings())
        logger.record_tabular('Epoch', epoch)

        write_header = (epoch == 0)
        logger.dump_tabular(with_prefix=False, with_timestamp=False,
                            write_header=write_header)
Example #12
    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        observation = self._start_new_rollout()
        #observation = self.concat_state_z(state, self.curr_z)

        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            for _ in range(self.num_env_steps_per_epoch):
                # TODO: append the latent variable here
                action, agent_info = self._get_action_and_info(observation)
                if self.render:
                    self.training_env.render()

                next_state, raw_reward, terminal, env_info = (
                    self.training_env.step(action))

                # print (terminal)
                next_ob = self.concat_state_z(next_state, self.curr_z)
                self._n_env_steps_total += 1
                reward = raw_reward * self.reward_scale
                terminal = np.array([terminal])
                reward = np.array([reward])
                self._handle_step(
                    observation,
                    action,
                    reward,
                    next_ob,
                    terminal,
                    agent_info=agent_info,
                    env_info=env_info,
                )
                if terminal or len(
                        self._current_path_builder) >= self.max_path_length:
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                    #print ('starting new rollout')
                else:
                    observation = next_ob

                gt.stamp('sample')
                self._try_to_train()
                gt.stamp('train')
Example #13
    def collect_data(self,
                     num_samples,
                     resample_z_rate,
                     update_posterior_rate,
                     add_to_enc_buffer=True):
        '''
        get trajectories from current env in batch mode with given policy
        collect complete trajectories until the number of collected transitions >= num_samples

        :param agent: policy to rollout
        :param num_samples: total number of transitions to sample
        :param resample_z_rate: how often to resample latent context z (in units of trajectories)
        :param update_posterior_rate: how often to update q(z | c) from which z is sampled (in units of trajectories)
        :param add_to_enc_buffer: whether to add collected data to encoder replay buffer
        '''
        # start from the prior
        self.agent.clear_z()

        num_transitions = 0
        num_ep_succ = num_ep_fail = num_eps = 0
        if self.eval_statistics is None:
            self.eval_statistics = OrderedDict()
        while num_transitions < num_samples:
            paths, n_samples = self.sampler.obtain_samples(
                max_samples=num_samples - num_transitions,
                max_trajs=update_posterior_rate,
                accum_context=False,
                resample=resample_z_rate)
            num_transitions += n_samples
            self.replay_buffer.add_paths(self.task_idx, paths)
            if add_to_enc_buffer:
                self.enc_replay_buffer.add_paths(self.task_idx, paths)
            if update_posterior_rate != np.inf:
                context = self.sample_context(self.task_idx)
                self.agent.infer_posterior(context)
            terms = np.concatenate([p['terminals'] for p in paths])
            task_agn_rew = np.array([
                info['task_agn_rew'] for p in paths for info in p['env_infos']
            ])
            num_ep_succ += np.sum(task_agn_rew[np.where(terms)[0]] == 0.)
            num_ep_fail += np.sum(task_agn_rew[np.where(terms)[0]] == 1.)
            num_eps += np.sum(terms)
        self.eval_statistics['Episode Success Rate'] = num_ep_succ / num_eps
        self.eval_statistics['Episode Failure Rate'] = num_ep_fail / num_eps
        self._n_env_steps_total += num_transitions
        gt.stamp('sample')
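For context, a hypothetical call site for `collect_data`; the attribute names `num_initial_steps`, `num_steps_posterior`, and `update_post_train` are assumptions for illustration, not taken from the example above:

        # Hypothetical warm-up collection: resample z every trajectory and
        # never update q(z | c) (update_posterior_rate = np.inf).
        self.collect_data(self.num_initial_steps, 1, np.inf)
        # Hypothetical posterior-conditioned collection: update q(z | c) every
        # `update_post_train` trajectories and skip the encoder buffer.
        self.collect_data(self.num_steps_posterior, 1, self.update_post_train,
                          add_to_enc_buffer=False)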
Example #14
    def _end_epoch(self, epoch):
        if not self.trainer.discrete:
            snapshot = self._get_snapshot()
            logger.save_itr_params(epoch, snapshot)
            # if snapshot['evaluation/Average Returns'] >= self.best_reward:
            #     self.best_reward = snapshot['evaluation/Average Returns']

            gt.stamp('saving')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example #15
    def train(self):
        if self.min_num_steps_before_training > 0:
            for _ in range(0, self.min_num_steps_before_training,
                           self.max_path_length):
                patch_trajectory = rollout(self.expl_env, self.trainer.policy,
                                           self.trainer.qf1, self.trainer.qf2,
                                           self.max_path_length,
                                           self.rnn_seq_len)
                self.replay_buffer.add_trajectory(patch_trajectory)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            rewards, seen_area, total_rotate, right_rotate = eval_rollout(
                self.eval_env, self.trainer.eval_policy, epoch,
                self.num_eval_steps_per_epoch)
            self.writer.add_scalar('eval/mean_reward', np.mean(rewards), epoch)
            self.writer.add_scalar('eval/mean_seen_area', np.mean(seen_area),
                                   epoch)
            self.writer.add_scalar('eval/max_reward', np.max(rewards), epoch)
            self.writer.add_scalar('eval/max_seen_area', np.max(seen_area),
                                   epoch)
            self.writer.add_scalar('eval/min_reward', np.min(rewards), epoch)
            self.writer.add_scalar('eval/min_seen_area', np.min(seen_area),
                                   epoch)
            self.writer.add_scalar(
                'eval/mean_rotate_ratio',
                abs(0.5 - np.sum(right_rotate) / np.sum(total_rotate)), epoch)

            gt.stamp('evaluation sampling', unique=False)

            for _ in range(self.num_train_loops_per_epoch):
                for _ in range(0, self.num_expl_steps_per_train_loop,
                               self.max_path_length):
                    patch_trajectory = rollout(self.expl_env,
                                               self.trainer.policy,
                                               self.trainer.qf1,
                                               self.trainer.qf2,
                                               self.max_path_length,
                                               self.rnn_seq_len)
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_trajectory(patch_trajectory)
                    gt.stamp('data storing', unique=False)

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_batch_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_batch_data)
                gt.stamp('training', unique=False)
                self.training_mode(False)

            self._end_epoch()
Example #16
    def _train(self):
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)
            self.estimate_obs_stats(init_expl_paths[0]['observations'],
                                    init_flag=True)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self.eval_data_collector.collect_normalized_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
                input_mean=self._obs_mean,
                input_std=self._obs_std,
            )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                new_expl_paths = self.expl_data_collector.collect_normalized_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                    input_mean=self._obs_mean,
                    input_std=self._obs_std,
                )
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.estimate_obs_stats(train_data['observations'],
                                            init_flag=False)
                    train_data['observations'] = self.apply_normalize_obs(
                        train_data['observations'])
                    self.trainer.train(train_data)
                gt.stamp('training', unique=False)
                self.training_mode(False)

            self._end_epoch(epoch)
            if self.save_frequency > 0:
                if epoch % self.save_frequency == 0:
                    self.trainer.save_models(epoch)
                    self.replay_buffer.save_buffer(epoch)
Example #17
    def _end_epoch(self, epoch, play=0, train_step=0):
        snapshot = self._get_snapshot()
        snapshot['epoch'] = epoch
        snapshot['play'] = play
        snapshot['train_step'] = train_step

        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer_expl.end_epoch(epoch)
        self.replay_buffer_eval.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example #18
    def play_one_step(self, observation=None, env=None):
        if self.environment_farming:
            observation = env.get_last_observation()

        if env is None:
            env = self.training_env

        action, agent_info = self._get_action_and_info(observation)

        if self.render and not self.environment_farming:
            env.render()

        next_ob, raw_reward, terminal, env_info = (env.step(action))

        self._n_env_steps_total += 1
        reward = raw_reward * self.reward_scale
        terminal = np.array([terminal])
        reward = np.array([reward])
        self._handle_step(observation,
                          action,
                          reward,
                          next_ob,
                          terminal,
                          agent_info=agent_info,
                          env_info=env_info,
                          env=env)

        if not self.environment_farming:
            current_path_builder = self._current_path_builder
        else:
            current_path_builder = env.get_current_path_builder()

        if terminal or len(current_path_builder) >= self.max_path_length:
            self._handle_rollout_ending(env)
            observation = self._start_new_rollout(env)
        else:
            observation = next_ob

        if self.environment_farming:
            self.farmer.add_free_env(env)

        gt.stamp('sample')

        return observation
Example #19
    def _train(self):
        st = time.time()
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                runtime_policy=self.pretrain_policy,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)
        self.total_train_expl_time += time.time() - st
        self.trainer.buffer = self.replay_buffer  # TODO: find a cleaner way of doing this
        self.training_mode(True)
        for _ in range(self.num_pretrain_steps):
            train_data = self.replay_buffer.random_batch(self.batch_size)
            self.trainer.train(train_data)
        self.training_mode(False)

        for epoch in gt.timed_for(
            range(self._start_epoch, self.num_epochs),
            save_itrs=True,
        ):
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
            )
            gt.stamp("evaluation sampling")
            st = time.time()
            for _ in range(self.num_train_loops_per_epoch):
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                )
                gt.stamp("exploration sampling", unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp("data storing", unique=False)

                self.training_mode(True)
                for train_step in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(self.batch_size)
                    self.trainer.train(train_data)
                gt.stamp("training", unique=False)
                self.training_mode(False)

            if self.eval_buffer:
                eval_data = self.eval_buffer.random_batch(self.batch_size)
                self.trainer.evaluate(eval_data, buffer_data=False)
                eval_data = self.replay_buffer.random_batch(self.batch_size)
                self.trainer.evaluate(eval_data, buffer_data=True)
            self.total_train_expl_time += time.time() - st

            self._end_epoch(epoch)
Example #20
    def train_from_paths(self, paths, train_discrim=True, train_policy=True):
        """
        Reading new paths: append latent to state
        Note that is equivalent to on-policy when latent buffer size = sum of paths length
        """

        epoch_obs, epoch_next_obs, epoch_latents = [], [], []

        for path in paths:
            obs = path['observations']
            next_obs = path['next_observations']
            actions = path['actions']
            latents = path.get('latents', None)
            path_len = len(obs) - self.empowerment_horizon + 1

            obs_latents = np.concatenate([obs, latents], axis=-1)
            log_probs = self.control_policy.get_log_probs(
                ptu.from_numpy(obs_latents),
                ptu.from_numpy(actions),
            )
            log_probs = ptu.get_numpy(log_probs)

            for t in range(path_len):
                self.add_sample(
                    obs[t],
                    next_obs[t + self.empowerment_horizon - 1],
                    next_obs[t],
                    actions[t],
                    latents[t],
                    logprob=log_probs[t],
                )

                epoch_obs.append(obs[t:t + 1])
                epoch_next_obs.append(next_obs[t + self.empowerment_horizon -
                                               1:t + self.empowerment_horizon])
                epoch_latents.append(np.expand_dims(latents[t], axis=0))

        self._epoch_size = len(epoch_obs)

        gt.stamp('policy training', unique=False)

        self.train_from_torch(None)
Example #21
    def _log_stats(self, epoch):
        logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
        """
        Replay Buffer
        """
        logger.record_dict(self.replay_buffer.get_diagnostics(),
                           prefix='replay_buffer/')
        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
        """
        Exploration
        """
        logger.record_dict(self.expl_data_collector.get_diagnostics(),
                           prefix='exploration/')
        expl_paths = self.expl_data_collector.get_epoch_paths()
        logger.record_dict(
            eval_util.get_generic_path_information(expl_paths),
            prefix="exploration/",
        )
        """
        Remote Evaluation
        """
        logger.record_dict(
            ray.get(self.remote_eval_data_collector.get_diagnostics.remote()),
            prefix='remote_evaluation/',
        )
        remote_eval_paths = ray.get(
            self.remote_eval_data_collector.get_epoch_paths.remote())
        logger.record_dict(
            eval_util.get_generic_path_information(remote_eval_paths),
            prefix="remote_evaluation/",
        )

        logger.record_dict(self.check_q_funct_estimate(remote_eval_paths),
                           prefix="check_estimate/")
        """
        Misc
        """
        gt.stamp('logging')
Example #22
    def train_from_torch(self, batch):
        gt.blank_stamp()
        losses, stats = self.compute_loss(
            batch,
            skip_statistics=not self._need_to_update_eval_statistics,
        )
        """
        Update networks
        """
        if self.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            losses.alpha_loss.backward()
            self.alpha_optimizer.step()

        self.policy_optimizer.zero_grad()
        losses.policy_loss.backward()
        self.policy_optimizer.step()

        self.qf1_optimizer.zero_grad()
        losses.qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        losses.qf2_loss.backward()
        self.qf2_optimizer.step()

        self.rf1_optimizer.zero_grad()
        losses.rf1_loss.backward()
        self.rf1_optimizer.step()

        self.rf2_optimizer.zero_grad()
        losses.rf2_loss.backward()
        self.rf2_optimizer.step()

        self._n_train_steps_total += 1

        self.try_update_target_networks()
        if self._need_to_update_eval_statistics:
            self.eval_statistics = stats
            # Compute statistics using only one batch per epoch
            self._need_to_update_eval_statistics = False
        gt.stamp('sac training', unique=False)
Example #23
    def _end_epoch(self, epoch):
        snapshot = self._get_snapshot()
        # print("what is in a snapshot", snapshot)
        # logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)

        ####################################
        if (epoch % self.save_interval == 0) and self.save_dir != "":
            save_path = os.path.join(self.save_dir, self.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            torch.save(
                snapshot,
                os.path.join(
                    save_path,
                    self.env_name + "_{}.{}.pt".format(self.save_name, epoch)))

        replaybuffer_save = False
        if replaybuffer_save:
            if (epoch == 0 or epoch == 100
                    or epoch == 200) and self.save_dir != "":
                save_path = os.path.join(self.save_dir, self.algo)
                buffer_file = os.path.join(
                    save_path, self.env_name +
                    "_{}.replaybuffer.{}".format(self.save_name, epoch))
                self.replay_buffer._env_info_keys = list(
                    self.replay_buffer.env_info_sizes)
                with open(buffer_file, "wb") as save_file:
                    pickle.dump(self.replay_buffer, save_file)
Example #24
    def _train(self):
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                )
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)

                # logging a snapshot of the replay buffer
                log_data = self.replay_buffer.random_batch(1000)
                log_goal_dict = self.replay_buffer._batch_obs_dict(
                    log_data['indices'])
                log_file_name = osp.join(logger._snapshot_dir,
                                         'buffer_%d.pkl' % epoch)
                pickle.dump(log_goal_dict['achieved_goal'].squeeze(),
                            open(log_file_name, 'wb'))

                gt.stamp('training', unique=False)
                self.training_mode(False)

            self._end_epoch(epoch)
Example #25
    def _train(self):
        """Training of the policy implemented by child class."""
        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp('evaluation sampling')

            self.training_mode(True)
            for _ in range(self.num_train_loops_per_epoch):
                train_data = self.offline_data.get_batch(self.batch_size)
                self.trainer.train(train_data)
            gt.stamp('training', unique=False)
            self.training_mode(False)

            self._end_epoch(epoch)
Example #26
    def collect_data(self, num_samples, resample_z_rate, update_posterior_rate, add_to_enc_buffer=True):  # collect num_samples transitions in the current env with the current self.agent.policy
        '''
        get trajectories from current env in batch mode with given policy
        collect complete trajectories until the number of collected transitions >= num_samples

        :param agent: policy to rollout
        :param num_samples: total number of transitions to sample
        :param resample_z_rate: how often to resample latent context z (in units of trajectories), i.e., every how many trajectories z is re-sampled with a forward pass through q(z|c)
        :param update_posterior_rate: how often to update the inference network q(z | c), from which z is sampled (in units of trajectories)
        :param add_to_enc_buffer: whether to add collected data to encoder replay buffer
        '''
        # start from the prior
        self.agent.clear_z()

        num_transitions = 0
        while num_transitions < num_samples:  # obtain_samples returns the collected paths and the total number of steps
            paths, n_samples = self.sampler.obtain_samples(
                max_samples=num_samples - num_transitions,  # maximum total number of steps
                max_trajs=update_posterior_rate,  # maximum number of trajectories
                accum_context=False,
                resample=resample_z_rate)  # how often z is re-sampled from c
            num_transitions += n_samples  # accumulate the number of sampled steps
            self.replay_buffer.add_paths(self.task_idx, paths)  # add this task's paths to the replay buffer
            print("\n    buffer", self.task_idx, "size:",
                  self.replay_buffer.task_buffers[self.task_idx].size())
            # time.sleep(1)
            # print("task id:", self.task_idx)
            # print("buffer ", self.task_idx, ":", self.replay_buffer.task_buffers[self.task_idx])
            # print("buffer ", self.task_idx, ":", self.replay_buffer.task_buffers[self.task_idx].__dict__.items())
            # print("buffer ", self.task_idx, ":", self.replay_buffer.task_buffers[self.task_idx])
            if add_to_enc_buffer:  # optionally also add the data to the encoder buffer
                self.enc_replay_buffer.add_paths(self.task_idx, paths)
                # print("enc_buffer ", self.task_idx, ":", self.enc_replay_buffer.task_buffers[self.task_idx].__dict__.items())
                # print("enc_buffer ", self.task_idx, ":", self.enc_replay_buffer.task_buffers[self.task_idx])
                print("enc_buffer", self.task_idx, "size:",
                      self.enc_replay_buffer.task_buffers[self.task_idx].size())
                # time.sleep(1)
            if update_posterior_rate != np.inf:  # update the posterior over z from the collected context
                context = self.prepare_context(self.task_idx)
                self.agent.infer_posterior(context)
        self._n_env_steps_total += num_transitions
        gt.stamp('sample')
Example #27
    def _train(self):
        # Pretrain the model at the beginning of training until convergence
        # Note that convergence is measured against a holdout set of max size 8192
        if self.train_at_start:
            self.model_trainer.train_from_buffer(
                self.replay_buffer,
                max_grad_steps=self.model_max_grad_steps,
                epochs_since_last_update=self.model_epochs_since_last_update,
            )
        gt.stamp('model training', unique=False)

        for epoch in gt.timed_for(
            range(self._start_epoch, self.num_epochs),
            save_itrs=True,
        ):
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp('evaluation sampling')

            self.training_mode(True)
            for _ in range(self.num_train_loops_per_epoch):
                for t in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(self.batch_size)
                    self.trainer.train(train_data)
                    gt.stamp('policy training', unique=False)
            self.training_mode(False)

            self._end_epoch(epoch)
Example #28
    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        if self.epoch_list is not None:
            iters = list(self.epoch_list)
        else:
            iters = list(range(start_epoch, self.num_epochs, self.epoch_freq))
        if self.num_epochs - 1 not in iters and self.num_epochs - 1 > iters[-1]:
            iters.append(self.num_epochs - 1)
        for epoch in gt.timed_for(
                iters,
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            env_utils.mode(self.training_env, 'train')
            observation = self._start_new_rollout()
            for _ in range(self.num_env_steps_per_epoch):
                if self.do_training:
                    observation = self._take_step_in_env(observation)

                gt.stamp('sample')
                self._try_to_train()
                gt.stamp('train')
            env_utils.mode(self.env, 'eval')
            # TODO steven: move dump_tabular to be conditionally called in
            # end_epoch and move post_epoch after eval
            self._post_epoch(epoch)
            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch()
Example #29
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        evaluation_env = deep_clone(env) if self._eval_n_episodes else None
        # TODO: use Ezpickle to deep_clone???
        # evaluation_env = env

        with tf_utils.get_default_session().as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(
                    range(self._n_epochs + 1), save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(
                            iteration=t + epoch * self._epoch_length,
                            batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(policy, evaluation_env)
                gt.stamp('eval')

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)

                time_itrs = gt.get_times().stamps.itrs
                time_eval = time_itrs['eval'][-1]
                time_total = gt.get_times().total
                time_train = time_itrs.get('train', [0])[-1]
                time_sample = time_itrs.get('sample', [0])[-1]

                logger.record_tabular('time-train', time_train)
                logger.record_tabular('time-eval', time_eval)
                logger.record_tabular('time-sample', time_sample)
                logger.record_tabular('time-total', time_total)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()

            self.sampler.terminate()
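The timing reads in this example (`gt.get_times().stamps.itrs` and `gt.get_times().total`) can be exercised outside any RL code. A standalone sketch, assuming the `gtimer` package is importable as `gt`:

import time

import gtimer as gt

for itr in gt.timed_for(range(3), save_itrs=True):
    time.sleep(0.01)
    gt.stamp('sample', unique=False)   # time since the previous stamp
    time.sleep(0.02)
    gt.stamp('train', unique=False)

times_itrs = gt.get_times().stamps.itrs
print(times_itrs['train'][-1])  # duration of 'train' in the last iteration
print(gt.get_times().total)     # total wall-clock time so far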
Example #30
    def _train(self):
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        best_eval_return = -np.inf
        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            eval_paths = self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            if self.save_best:
                eval_returns = [sum(path["rewards"]) for path in eval_paths]
                eval_avg_return = np.mean(eval_returns)
                if eval_avg_return > best_eval_return:
                    best_eval_return = eval_avg_return
                    snapshot = self._get_snapshot()
                    file_name = osp.join(logger._snapshot_dir, 'best.pkl')
                    torch.save(snapshot, file_name)
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                )
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)
                gt.stamp('training', unique=False)
                self.training_mode(False)

            self._end_epoch(epoch)
Example #31
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        evaluation_env = deep_clone(env) if self._eval_n_episodes else None

        with tf_utils.get_default_session().as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(
                    range(self._n_epochs + 1), save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(
                            iteration=t + epoch * self._epoch_length,
                            batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(policy, evaluation_env)
                gt.stamp('eval')

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)

                time_itrs = gt.get_times().stamps.itrs
                time_eval = time_itrs['eval'][-1]
                time_total = gt.get_times().total
                time_train = time_itrs.get('train', [0])[-1]
                time_sample = time_itrs.get('sample', [0])[-1]

                logger.record_tabular('time-train', time_train)
                logger.record_tabular('time-eval', time_eval)
                logger.record_tabular('time-sample', time_sample)
                logger.record_tabular('time-total', time_total)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
Example #32
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """

        self._init_training(env, policy, pool)

        with self._sess.as_default():
            observation = env.reset()
            policy.reset()

            path_length = 0
            path_return = 0
            last_path_return = 0
            max_path_return = -np.inf
            n_episodes = 0
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(
                    range(self._n_epochs + 1), save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                if self.iter_callback is not None:
                    self.iter_callback(locals(), globals())

                for t in range(self._epoch_length):
                    iteration = t + epoch * self._epoch_length

                    action, _ = policy.get_action(observation)
                    next_ob, reward, terminal, info = env.step(action)
                    path_length += 1
                    path_return += reward

                    self.pool.add_sample(
                        observation,
                        action,
                        reward,
                        terminal,
                        next_ob,
                    )

                    if terminal or path_length >= self._max_path_length:
                        observation = env.reset()
                        policy.reset()
                        path_length = 0
                        max_path_return = max(max_path_return, path_return)
                        last_path_return = path_return

                        path_return = 0
                        n_episodes += 1

                    else:
                        observation = next_ob
                    gt.stamp('sample')

                    if self.pool.size >= self._min_pool_size:
                        for i in range(self._n_train_repeat):
                            batch = self.pool.random_batch(self._batch_size)
                            self._do_training(iteration, batch)

                    gt.stamp('train')

                self._evaluate(epoch)

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)
                times_itrs = gt.get_times().stamps.itrs

                eval_time = times_itrs['eval'][-1] if epoch > 1 else 0
                total_time = gt.get_times().total
                logger.record_tabular('time-train', times_itrs['train'][-1])
                logger.record_tabular('time-eval', eval_time)
                logger.record_tabular('time-sample', times_itrs['sample'][-1])
                logger.record_tabular('time-total', total_time)
                logger.record_tabular('epoch', epoch)
                logger.record_tabular('episodes', n_episodes)
                logger.record_tabular('max-path-return', max_path_return)
                logger.record_tabular('last-path-return', last_path_return)
                logger.record_tabular('pool-size', self.pool.size)

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()

                gt.stamp('eval')

            env.terminate()