Example #1
    def _get_diagnostics(self):
        timer.start_timer('logging', unique=False)
        algo_log = OrderedDict()
        append_log(algo_log,
                   self.replay_buffer.get_diagnostics(),
                   prefix='replay_buffer/')
        append_log(algo_log, self.trainer.get_diagnostics(), prefix='trainer/')
        # Eval
        if self.epoch % self._eval_epoch_freq == 0:
            self._prev_eval_log = OrderedDict()
            eval_diag = self.eval_data_collector.get_diagnostics()
            self._prev_eval_log.update(eval_diag)
            append_log(algo_log, eval_diag, prefix='eval/')
            eval_paths = self.eval_data_collector.get_epoch_paths()
            for fn in self._eval_get_diag_fns:
                addl_diag = fn(eval_paths)
                self._prev_eval_log.update(addl_diag)
                append_log(algo_log, addl_diag, prefix='eval/')
        else:
            append_log(algo_log, self._prev_eval_log, prefix='eval/')

        append_log(algo_log, _get_epoch_timings())
        algo_log['epoch'] = self.epoch
        try:
            import os
            import psutil
            process = psutil.Process(os.getpid())
            algo_log['RAM Usage (Mb)'] = int(process.memory_info().rss /
                                             1000000)
        except ImportError:
            pass
        timer.stop_timer('logging')
        return algo_log
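For reference, append_log is used above to fold one diagnostics dict into another while namespacing its keys. A minimal sketch of what such a helper presumably does (the actual rlkit implementation may differ):

def append_log(log_dict, to_add, prefix=''):
    # Copy every diagnostic into the running log, namespaced by the prefix,
    # e.g. 'Q Loss' with prefix='trainer/' becomes 'trainer/Q Loss'.
    for key, value in to_add.items():
        log_dict[prefix + key] = value
    return log_dict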
Example #2
    def run(self):
        if self.progress_csv_file_name != 'progress.csv':
            logger.remove_tabular_output(
                'progress.csv',
                relative_to_snapshot_dir=True,
            )
            logger.add_tabular_output(
                self.progress_csv_file_name,
                relative_to_snapshot_dir=True,
            )
        timer.return_global_times = True
        for _ in range(self.num_iters):
            self._begin_epoch()
            timer.start_timer('saving')
            logger.save_itr_params(self.epoch, self._get_snapshot())
            timer.stop_timer('saving')
            log_dict, _ = self._train()
            logger.record_dict(log_dict)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            self._end_epoch()
        logger.save_itr_params(self.epoch, self._get_snapshot())
        if self.progress_csv_file_name != 'progress.csv':
            logger.remove_tabular_output(
                self.progress_csv_file_name,
                relative_to_snapshot_dir=True,
            )
            logger.add_tabular_output(
                'progress.csv',
                relative_to_snapshot_dir=True,
            )
Example #3
    def _get_diagnostics(self):
        timer.start_timer('logging', unique=False)
        algo_log = OrderedDict()
        append_log(algo_log, self.trainer.get_diagnostics(), prefix='trainer/')
        append_log(algo_log, _get_epoch_timings())
        algo_log['epoch'] = self.epoch
        timer.stop_timer('logging')
        return algo_log
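_get_epoch_timings() is assumed here to read back the global timer driven by start_timer/stop_timer and report per-phase wall-clock times. A hypothetical sketch, assuming the timer exposes a get_times() mapping from phase name to accumulated seconds:

def _get_epoch_timings():
    # Report accumulated per-phase times ('training', 'saving', 'logging', ...)
    # under a common 'time/' namespace so they group together in the CSV log.
    times = OrderedDict()
    for name, seconds in timer.get_times().items():
        times['time/{} (s)'.format(name)] = seconds
    return times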
Example #4
    def train(self):
        timer.return_global_times = True
        for _ in range(self.num_epochs):
            self._begin_epoch()
            timer.start_timer('saving')
            logger.save_itr_params(self.epoch, self._get_snapshot())
            timer.stop_timer('saving')
            log_dict, _ = self._train()
            logger.record_dict(log_dict)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            self._end_epoch()
        logger.save_itr_params(self.epoch, self._get_snapshot())
Example #5
    def _train(self):
        done = (self.epoch == self.num_iters)
        if done:
            return OrderedDict(), done

        timer.start_timer('training', unique=False)
        for _ in range(self.num_epochs_per_iter):
            for batch in self.data_loader:
                self.trainer.train_from_torch(batch)
        timer.stop_timer('training')
        log_stats = self._get_diagnostics()
        return log_stats, False
Example #6
    def train_from_torch(self, batch):
        timer.start_timer('vae training', unique=False)
        losses, stats = self.compute_loss(
            batch,
            skip_statistics=not self._need_to_update_eval_statistics,
        )
        self.vae_optimizer.zero_grad()
        losses.vae_loss.backward()
        self.vae_optimizer.step()

        if self._need_to_update_eval_statistics:
            self.eval_statistics = stats
            self._need_to_update_eval_statistics = False
            self.example_obs_batch = batch['raw_next_observations']
        timer.stop_timer('vae training')
Example #7
    def train_from_torch(self, batch):
        timer.start_timer('sac training', unique=False)
        losses, stats = self.compute_loss(
            batch,
            skip_statistics=not self._need_to_update_eval_statistics,
        )
        """
        Update networks
        """
        if self.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            losses.alpha_loss.backward()
            self.alpha_optimizer.step()

        self.policy_optimizer.zero_grad()
        losses.policy_loss.backward()
        self.policy_optimizer.step()

        self.qf1_optimizer.zero_grad()
        losses.qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        losses.qf2_loss.backward()
        self.qf2_optimizer.step()

        if self._reward_scale in {
                'auto_normalize_by_max_magnitude',
                'auto_normalize_by_max_magnitude_times_10',
        }:
            rewards = batch['rewards']
            self._reward_normalizer = (
                self._reward_normalizer * self._reward_tracking_momentum +
                rewards.abs().max() * (1 - self._reward_tracking_momentum))
        elif isinstance(self._reward_scale, Number):
            pass
        else:
            raise NotImplementedError()

        self._n_train_steps_total += 1
        self.try_update_target_networks()
        if self._need_to_update_eval_statistics:
            self.eval_statistics = stats
            # Compute statistics using only one batch per epoch
            self._need_to_update_eval_statistics = False
        timer.stop_timer('sac training')
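The losses object unpacked above is assumed to be a plain named container with one differentiable loss per network plus the entropy temperature; a minimal sketch of such a container:

from collections import namedtuple

# Hypothetical container matching the fields accessed in the update above.
SACLosses = namedtuple(
    'SACLosses',
    ['policy_loss', 'qf1_loss', 'qf2_loss', 'alpha_loss'],
)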
Example #8
    def train_from_torch(self, batch):
        timer.start_timer('vae training', unique=False)
        self.vae.train()
        loss, stats = self.compute_loss(
            batch,
            skip_statistics=not self._need_to_update_eval_statistics,
        )
        self.vae_optimizer.zero_grad()
        loss.backward()
        self.vae_optimizer.step()

        if self._need_to_update_eval_statistics:
            self.eval_statistics = stats
            self._need_to_update_eval_statistics = False
            self.example_batch = batch
            self.eval_statistics['num_train_batches'] = self._num_train_batches
        self._num_train_batches += 1
        timer.stop_timer('vae training')
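The _need_to_update_eval_statistics flag used by these trainers ensures diagnostics are computed from only one batch per epoch; presumably the trainer re-arms it when the epoch ends. A minimal sketch of such a hook, stated as an assumption about the surrounding trainer class:

    def end_epoch(self, epoch):
        # Re-arm the flag so the first batch of the next epoch
        # recomputes eval_statistics.
        self._need_to_update_eval_statistics = True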
Example #9
    def _train_lifelong(self):
        done = (self.epoch == self.num_epochs)
        if done:
            return OrderedDict(), done

        self.training_mode(False)
        if self.min_num_steps_before_training > 0 and self.epoch == 0:
            self.expl_data_collector.collect_new_steps(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            init_expl_paths = self.expl_data_collector.get_epoch_paths()
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        num_trains_per_expl_step = self.num_trains_per_train_loop // self.num_expl_steps_per_train_loop
        timer.start_timer('evaluation sampling')
        if self.epoch % self._eval_epoch_freq == 0:
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        timer.stop_timer('evaluation sampling')

        if not self._eval_only:
            for _ in range(self.num_train_loops_per_epoch):
                for _ in range(self.num_expl_steps_per_train_loop):
                    timer.start_timer('exploration sampling', unique=False)
                    self.expl_data_collector.collect_new_steps(
                        self.max_path_length,
                        1,  # num steps
                        discard_incomplete_paths=False,
                    )
                    timer.stop_timer('exploration sampling')

                    timer.start_timer('training', unique=False)
                    self.training_mode(True)
                    for _ in range(num_trains_per_expl_step):
                        train_data = self.replay_buffer.random_batch(
                            self.batch_size)
                        self.trainer.train(train_data)
                    timer.stop_timer('training')
                    self.training_mode(False)

            timer.start_timer('replay buffer data storing', unique=False)
            new_expl_paths = self.expl_data_collector.get_epoch_paths()
            self.replay_buffer.add_paths(new_expl_paths)
            timer.stop_timer('replay buffer data storing')

        log_stats = self._get_diagnostics()

        return log_stats, False
Example #10
    def train_from_torch(self, batch):
        timer.start_timer('pg training', unique=False)
        losses, stats = self.compute_loss(
            batch,
            skip_statistics=not self._need_to_update_eval_statistics,
        )
        """
        Update networks
        """
        self.policy_optimizer.zero_grad()
        losses.policy_loss.backward()
        self.policy_optimizer.step()

        obs = batch['observations']
        returns = batch['returns'][:, 0]

        for _ in range(self.vf_iters_per_step):
            v = self.vf(obs)[:, 0]
            v_error = (v - returns)**2
            vf_loss = v_error.mean()

            self.vf_optimizer.zero_grad()
            vf_loss.backward()
            self.vf_optimizer.step()

        self._n_train_steps_total += 1
        if self._need_to_update_eval_statistics:
            self.eval_statistics = stats
            # Compute statistics using only one batch per epoch
            self._need_to_update_eval_statistics = False

            stats["VF Loss"] = np.mean(ptu.get_numpy(vf_loss))
            stats.update(create_stats_ordered_dict(
                "VF",
                ptu.get_numpy(v),
            ))

        timer.stop_timer('pg training')

        self.replay_buffer.empty_buffer()
Example #11
    def _train(self):
        done = (self.epoch == self.num_epochs)
        if done:
            return OrderedDict(), done

        timer.start_timer('evaluation sampling')
        if self.epoch % self._eval_epoch_freq == 0:
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        timer.stop_timer('evaluation sampling')

        if not self._eval_only:
            for _ in range(self.num_train_loops_per_epoch):
                timer.start_timer('training', unique=False)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)
                timer.stop_timer('training')
        log_stats = self._get_diagnostics()
        return log_stats, False
Example #12
    def _train_batch(self):
        done = (self.epoch == self.num_epochs)
        if done:
            return OrderedDict(), done

        if self.epoch == 0 and self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        timer.start_timer('evaluation sampling')
        if self.epoch % self._eval_epoch_freq == 0:
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        timer.stop_timer('evaluation sampling')

        if not self._eval_only:
            for _ in range(self.num_train_loops_per_epoch):
                timer.start_timer('exploration sampling', unique=False)
                new_env_mod_parms = dict()
                if self.env_mod_dist:
                    current = max(
                        1, self.epoch /
                        (self.mod_env_epoch_schedule * self.num_epochs))
                    for k, v in self.env_mod_params.items():
                        lbound, ubound = v
                        low = current + (1.0 - current) * lbound
                        high = current + (1.0 - current) * ubound
                        new_env_mod_parms[k] = np.random.uniform(low, high)

                else:
                    current = max(
                        1, self.epoch /
                        (self.mod_env_epoch_schedule * self.num_epochs))
                    for k, v in self.env_mod_params.items():
                        new_env_mod_parms[k] = 1.0 * current + (1.0 -
                                                                current) * v
                self.expl_data_collector._env = self.env_class(
                    new_env_mod_parms)

                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                )
                timer.stop_timer('exploration sampling')

                timer.start_timer('replay buffer data storing', unique=False)
                self.replay_buffer.add_paths(new_expl_paths)
                timer.stop_timer('replay buffer data storing')

                timer.start_timer('training', unique=False)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)
                timer.stop_timer('training')
        log_stats = self._get_diagnostics()
        return log_stats, False
Example #13
    def _end_epoch(self):
        timer.start_timer('vae training')
        self._train_vae(self.epoch)
        timer.stop_timer('vae training')
        super()._end_epoch()
Example #14
    def train(self):
        # first train only the Q function
        iteration = 0
        timer.return_global_times = True
        timer.reset()
        for i in range(self.num_batches):
            if self.use_meta_learning_buffer:
                train_data = self.meta_replay_buffer.sample_meta_batch(
                    rl_batch_size=self.batch_size,
                    meta_batch_size=self.meta_batch_size,
                    embedding_batch_size=self.task_embedding_batch_size,
                )
                train_data = np_to_pytorch_batch(train_data)
            else:
                task_indices = np.random.choice(
                    self.train_tasks, self.meta_batch_size,
                )
                train_data = self.replay_buffer.sample_batch(
                    task_indices,
                    self.batch_size,
                )
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                train_data['observations'] = obs
                train_data['next_observations'] = next_obs
                train_data['context'] = (
                    self.task_embedding_replay_buffer.sample_context(
                        task_indices,
                        self.task_embedding_batch_size,
                    ))
            timer.start_timer('train', unique=False)
            self.trainer.train_from_torch(train_data)
            timer.stop_timer('train')
            if i % self.logging_period == 0 or i == self.num_batches - 1:
                stats_with_prefix = add_prefix(
                    self.trainer.eval_statistics, prefix="trainer/")
                self.trainer.end_epoch(iteration)
                logger.record_dict(stats_with_prefix)
                timer.start_timer('extra_fns', unique=False)
                for fn in self._extra_eval_fns:
                    extra_stats = fn()
                    logger.record_dict(extra_stats)
                timer.stop_timer('extra_fns')


                # TODO: evaluate during offline RL
                # eval_stats = self.get_eval_statistics()
                # eval_stats_with_prefix = add_prefix(eval_stats, prefix="eval/")
                # logger.record_dict(eval_stats_with_prefix)

                logger.record_tabular('iteration', iteration)
                logger.record_dict(_get_epoch_timings())
                try:
                    import os
                    import psutil
                    process = psutil.Process(os.getpid())
                    logger.record_tabular('RAM Usage (Mb)', int(process.memory_info().rss / 1000000))
                except ImportError:
                    pass
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                iteration += 1
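The psutil-based RAM logging appears verbatim in Examples #1 and #14; it could be factored into a small helper. A sketch with a hypothetical name, keeping the guarded import so the metric is simply skipped when psutil is not installed:

def get_ram_usage_mb():
    # Resident set size of the current process in megabytes,
    # or None when psutil is unavailable.
    try:
        import os
        import psutil
    except ImportError:
        return None
    process = psutil.Process(os.getpid())
    return int(process.memory_info().rss / 1000000)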