Example #1
File: mbpo.py Project: anyboby/mbpo
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat)
            start_samples = self.sampler._total_samples
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    print('[ MBPO ] log_dir: {} | ratio: {}'.format(
                        self._log_dir, self._real_ratio))
                    print(
                        '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {})'
                        .format(self._epoch, self._model_train_freq,
                                self._timestep, self._total_timestep,
                                self._train_steps_this_epoch,
                                self._num_train_steps))

                    model_train_metrics = self._train_model(
                        batch_size=256,
                        max_epochs=None,
                        holdout_ratio=0.2,
                        max_t=self._max_model_t)
                    model_metrics.update(model_train_metrics)
                    gt.stamp('epoch_train_model')

                    self._set_rollout_length()
                    if self._rollout_batch_size > 30000:
                        # Split oversized rollout batches into mini-batches; use
                        # `_` so the outer `for i in count()` index is not shadowed.
                        factor = self._rollout_batch_size // 30000 + 1
                        mini_batch = self._rollout_batch_size // factor
                        for _ in range(factor):
                            # Note: only the last mini-batch's metrics are kept.
                            model_rollout_metrics = self._rollout_model(
                                rollout_batch_size=mini_batch,
                                deterministic=self._deterministic)
                    else:
                        model_rollout_metrics = self._rollout_model(
                            rollout_batch_size=self._rollout_batch_size,
                            deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                # Note: the capability check is on the evaluation environment while
                # rendering goes through the training environment; later examples
                # carry a TODO about making this consistent.
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}
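Note: each of these `_train` implementations is a generator that yields one flat diagnostics dict per epoch and finally `{'done': True, ...}`. Below is a minimal driver sketch for consuming such a generator; the `algorithm` argument is a hypothetical stand-in for an instantiated RLAlgorithm subclass (in the actual projects a Ray Tune trainable typically advances the generator instead).

def run_training(algorithm):
    """Drive a softlearning-style RLAlgorithm whose _train() is a generator."""
    for diagnostics in algorithm._train():
        if diagnostics.get('done'):
            print('Training finished.')
            break
        # One dict per epoch with flat keys such as 'epoch', 'timesteps_total',
        # 'evaluation/<metric>', 'times/<stamp>' and 'model/<metric>'.
        print(diagnostics.get('epoch'), diagnostics.get('timesteps_total'))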
Example #2
    def _train(self, env, policy, pool, initial_exploration_policy=None):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(env, initial_exploration_policy,
                                           pool)

        self.sampler.initialize(env, policy, pool)
        evaluation_env = env.copy() if self._eval_n_episodes else None

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            start_samples = self.sampler._total_samples
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if samples_now >= start_samples + self._epoch_length:
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy, evaluation_env)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths, env)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_env)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_env, 'render_rollouts'):
                # TODO(hartikainen): Make this consistent such that there's no
                # need for the hasattr check.
                env.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()
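All of these examples share the same gtimer instrumentation: reset the timing root, wrap the epoch loop in `gt.timed_for`, stamp named sections, and later read per-iteration durations from `gt.get_times().stamps.itrs`. A minimal standalone sketch of that pattern, with the loop bodies replaced by `time.sleep` stand-ins:

import time

import gtimer as gt

gt.reset_root()
gt.rename_root('RLAlgorithm')
gt.set_def_unique(False)  # allow re-using the same stamp name every iteration

for _ in gt.timed_for(range(3)):
    time.sleep(0.01)   # stand-in for environment sampling
    gt.stamp('sample')
    time.sleep(0.02)   # stand-in for gradient updates
    gt.stamp('train')

# stamps.itrs maps each stamp name to a list of per-iteration durations, which
# is why the examples log time_diagnostics[key][-1] (the most recent epoch).
time_diagnostics = gt.get_times().stamps.itrs
print({key: times[-1] for key, times in time_diagnostics.items()})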
Example #3
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        import gtimer as gt
        from itertools import count
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        training_metrics = [0 for _ in range(self._num_goals)]

        if not self._training_started:
            self._init_training()

            for i in range(self._num_goals):
                self._initial_exploration_hook(
                    training_environment, self._initial_exploration_policy, i)

        self._initialize_samplers()
        self._sample_count = 0

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        print("starting_training")
        self._training_before_hook()
        import time

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')
            start_samples = sum([self._samplers[i]._total_samples for i in range(self._num_goals)])
            sample_times = []
            for i in count():
                samples_now = sum([self._samplers[i]._total_samples for i in range(self._num_goals)])
                self._timestep = samples_now - start_samples

                if samples_now >= start_samples + self._epoch_length and self.ready_to_train:
                    break

                t0 = time.time()
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')
                sample_times.append(time.time() - t0)
                t0 = time.time()
                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')
                # print("Train time: ", time.time() - t0)

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            # TODO diagnostics per goal
            print("Average Sample Time: ", np.mean(np.array(sample_times)))
            print("Step count", self._sample_count)
            training_paths_per_policy = self._training_paths()
            # self.sampler.get_last_n_paths(
            #     math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths_per_policy = self._evaluation_paths()
            gt.stamp('evaluation_paths')

            training_metrics_per_policy = self._evaluate_rollouts(
                training_paths_per_policy, training_environment)
            gt.stamp('training_metrics')

            if evaluation_paths_per_policy:
                evaluation_metrics_per_policy = self._evaluate_rollouts(
                    evaluation_paths_per_policy, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics_per_policy = [{} for _ in range(self._num_goals)]

            self._epoch_after_hook(training_paths_per_policy)
            gt.stamp('epoch_after_hook')

            t0 = time.time()

            sampler_diagnostics_per_policy = [
                self._samplers[i].get_diagnostics() for i in range(self._num_goals)]

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batches=self._evaluation_batches(),
                training_paths_per_policy=training_paths_per_policy,
                evaluation_paths_per_policy=evaluation_paths_per_policy)

            time_diagnostics = gt.get_times().stamps.itrs

            print("Basic diagnostics: ", time.time() - t0)
            print("Sample count: ", self._sample_count)

            diagnostics.update(OrderedDict((
                *(
                    (f'times/{key}', time_diagnostics[key][-1])
                    for key in sorted(time_diagnostics.keys())
                ),
                ('epoch', self._epoch),
                ('timestep', self._timestep),
                ('timesteps_total', self._total_timestep),
                ('train-steps', self._num_train_steps),
            )))

            print("Other basic diagnostics: ", time.time() - t0)
            for i, (evaluation_metrics, training_metrics, sampler_diagnostics) in (
                enumerate(zip(evaluation_metrics_per_policy,
                              training_metrics_per_policy,
                              sampler_diagnostics_per_policy))):
                diagnostics.update(OrderedDict((
                    *(
                        (f'evaluation_{i}/{key}', evaluation_metrics[key])
                        for key in sorted(evaluation_metrics.keys())
                    ),
                    *(
                        (f'training_{i}/{key}', training_metrics[key])
                        for key in sorted(training_metrics.keys())
                    ),
                    *(
                        (f'sampler_{i}/{key}', sampler_diagnostics[key])
                        for key in sorted(sampler_diagnostics.keys())
                    ),
                )))

            # if self._eval_render_kwargs and hasattr(
            #         evaluation_environment, 'render_rollouts'):
            #     # TODO(hartikainen): Make this consistent such that there's no
            #     # need for the hasattr check.
            #     training_environment.render_rollouts(evaluation_paths)

            yield diagnostics
            print("Diagnostic time: ",  time.time() - t0)

        for i in range(self._num_goals):
            self._samplers[i].terminate()

        self._training_after_hook()

        del evaluation_paths_per_policy

        yield {'done': True, **diagnostics}
Example #4
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        import time
        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')
            start_samples = self.sampler._total_samples

            sample_times = []
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')
                t0 = time.time()
                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')
                sample_times.append(time.time() - t0)
                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            print("Average Sample Time: ", np.mean(np.array(sample_times)))

            training_paths = self._training_paths()
            # self.sampler.get_last_n_paths(
            #     math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')

            #should_save_path = (
            #    self._path_save_frequency > 0
            #    and self._epoch % self._path_save_frequency == 0)
            #if should_save_path:
            #    import pickle
            #    for i, path in enumerate(training_paths):
            #        #path.pop('images')
            #        path_file_name = f'training_path_{self._epoch}_{i}.pkl'
            #        path_file_path = os.path.join(
            #            os.getcwd(), 'paths', path_file_name)
            #        if not os.path.exists(os.path.dirname(path_file_path)):
            #            os.makedirs(os.path.dirname(path_file_path))
            #        with open(path_file_path, 'wb' ) as f:
            #            pickle.dump(path, f)

            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            obs = self._pool.last_n_batch(
                self._pool.size)['observations']['state_observation']
            plt.cla()
            plt.clf()
            plt.xlim(-20, 20)
            plt.ylim(-20, 20)
            plt.plot(obs[:, 0], obs[:, 1])
            plt.savefig('traj_plot_%d.png' % (self._epoch))
            if self._rnd_int_rew_coeff:
                errors = []
                for i in np.arange(-20, 20, 0.5):
                    error = []
                    for j in np.arange(-20, 20, 0.5):
                        curr_pos = np.array([i, j])
                        err = self._session.run(
                            self._rnd_errors, {
                                self._placeholders['observations']['state_observation']:
                                [curr_pos]
                            })[0]
                        error.append(err)
                    errors.append(error)
                plt.cla()
                plt.clf()
                plt.imshow(np.asarray(errors)[:, :, 0])
                plt.savefig('errors_%d.png' % (self._epoch))
            if self._eval_render_kwargs and hasattr(evaluation_environment,
                                                    'render_rollouts'):
                # TODO(hartikainen): Make this consistent such that there's no
                # need for the hasattr check.
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        del evaluation_paths

        yield {'done': True, **diagnostics}
Example #5
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            update_diagnostics = []

            start_samples = self.sampler._total_samples
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    update_diagnostics.append(
                        self._do_training_repeats(
                            timestep=self._total_timestep))

                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            update_diagnostics = tree.map_structure(lambda *d: np.mean(d),
                                                    *update_diagnostics)

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment,
                                                       self._total_timestep,
                                                       evaluation_type='train')
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths,
                    evaluation_environment,
                    self._total_timestep,
                    evaluation_type='evaluation')
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = {
                key: times[-1]
                for key, times in gt.get_times().stamps.itrs.items()
            }

            # TODO(hartikainen/tf2): Fix the naming of training/update
            # diagnostics/metric
            diagnostics.update((
                ('evaluation', evaluation_metrics),
                ('training', training_metrics),
                ('update', update_diagnostics),
                ('times', time_diagnostics),
                ('sampler', sampler_diagnostics),
                ('epoch', self._epoch),
                ('timestep', self._timestep),
                ('total_timestep', self._total_timestep),
                ('num_train_steps', self._num_train_steps),
            ))

            if self._eval_render_kwargs and hasattr(evaluation_environment,
                                                    'render_rollouts'):
                # TODO(hartikainen): Make this consistent such that there's no
                # need for the hasattr check.
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        yield {'done': True, **diagnostics}
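Example #5 averages the per-update diagnostics returned by `_do_training_repeats` with `tree.map_structure`, presumably the dm-tree package used by the TF2 branch of softlearning. A tiny sketch of what that leaf-wise reduction does, using made-up metric names:

import numpy as np
import tree  # dm-tree

# Two per-update diagnostic dicts with an identical nested structure.
update_diagnostics = [
    {'Q_loss': 1.0, 'policy': {'entropy': 0.5}},
    {'Q_loss': 3.0, 'policy': {'entropy': 0.7}},
]

# map_structure applies the lambda leaf-wise across all structures, so each
# metric is averaged over the updates performed during the epoch.
mean_diagnostics = tree.map_structure(lambda *d: np.mean(d), *update_diagnostics)
print(mean_diagnostics)  # {'Q_loss': 2.0, 'policy': {'entropy': 0.6}}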
Example #6
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """

        #### pool is e.g. simple_replay_pool
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool

        if not self._training_started:
            #### perform some initial steps (gather samples) using initial policy
            ######  fills pool with _n_initial_exploration_steps samples
            self._initial_exploration_hook(training_environment, self._policy,
                                           pool)

        #### set up sampler with train env and actual policy (may be different from initial exploration policy)
        ######## note: sampler is set up with the pool that may be already filled from initial exploration hook
        self.sampler.initialize(training_environment, policy, pool)
        self.model_sampler.initialize(self.fake_env, policy, self.model_pool)
        rollout_dkl_lim = self.model_sampler.compute_dynamics_dkl(
            obs_batch=self._pool.rand_batch_from_archive(
                5000, fields=['observations'])['observations'],
            depth=self._rollout_schedule[2])
        self.model_sampler.set_rollout_dkl(rollout_dkl_lim)
        self.initial_model_dkl = self.model_sampler.dyn_dkl
        #### reset gtimer (for coverage of project development)
        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)
        self.policy_epoch = 0  ### count policy updates
        self.new_real_samples = 0
        self.last_eval_step = 0
        self.diag_counter = 0
        running_diag = {}
        self.approx_model_batch = self.batch_size_policy - self.min_real_samples_per_epoch  ### some size to start off

        #### not implemented, could train policy before hook
        self._training_before_hook()

        #### iterate over epochs, gt.timed_for to create loop with gt timestamps
        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            #### do something at beginning of epoch (in this case reset self._train_steps_this_epoch=0)
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            #### util class Progress, e.g. for plotting a progress bar
            #######   note: sampler may already contain samples in its pool from initial_exploration_hook or previous epochs
            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat /
                                               self._train_every_n_steps)

            samples_added = 0
            #=====================================================================#
            #            Rollout model                                            #
            #=====================================================================#
            model_samples = None
            keep_rolling = True
            model_metrics = {}
            #### start model rollout
            if self._real_ratio < 1.0:  #if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                #=====================================================================#
                #                           Model Rollouts                            #
                #=====================================================================#
                if self.rollout_mode == 'schedule':
                    self._set_rollout_length()

                while keep_rolling:
                    ep_b = self._pool.epoch_batch(
                        batch_size=self._rollout_batch_size,
                        epochs=self._pool.epochs_list,
                        fields=['observations', 'pi_infos'])
                    kls = np.clip(self._policy.compute_DKL(
                        ep_b['observations'], ep_b['mu'], ep_b['log_std']),
                                  a_min=0,
                                  a_max=None)
                    btz_dist = self._pool.boltz_dist(kls,
                                                     alpha=self.policy_alpha)
                    btz_b = self._pool.distributed_batch_from_archive(
                        self._rollout_batch_size,
                        btz_dist,
                        fields=['observations', 'pi_infos'])
                    start_states, mus, logstds = btz_b['observations'], btz_b[
                        'mu'], btz_b['log_std']
                    btz_kl = np.clip(self._policy.compute_DKL(
                        start_states, mus, logstds),
                                     a_min=0,
                                     a_max=None)

                    self.model_sampler.reset(start_states)
                    if self.rollout_mode == 'uncertainty':
                        self.model_sampler.set_max_uncertainty(
                            self.max_tddyn_err)

                    for i in count():
                        # print(f'Model Sampling step Nr. {i+1}')

                        _, _, _, info = self.model_sampler.sample(
                            max_samples=int(self.approx_model_batch -
                                            samples_added))

                        if self.model_sampler._total_samples + samples_added >= .99 * self.approx_model_batch:
                            keep_rolling = False
                            break

                        if info['alive_ratio'] <= 0.1: break

                    ### diagnostics for rollout ###
                    rollout_diagnostics = self.model_sampler.finish_all_paths()
                    if self.rollout_mode == 'iv_gae':
                        keep_rolling = self.model_pool.size + samples_added <= .99 * self.approx_model_batch

                    ######################################################################
                    ### get model_samples, get() invokes the inverse variance rollouts ###
                    model_samples_new, buffer_diagnostics_new = self.model_pool.get(
                    )
                    model_samples = [
                        np.concatenate((o, n), axis=0)
                        for o, n in zip(model_samples, model_samples_new)
                    ] if model_samples else model_samples_new

                    ######################################################################
                    ### diagnostics
                    new_n_samples = len(model_samples_new[0]) + EPS
                    diag_weight_old = samples_added / (new_n_samples +
                                                       samples_added)
                    diag_weight_new = new_n_samples / (new_n_samples +
                                                       samples_added)
                    model_metrics = update_dict(model_metrics,
                                                rollout_diagnostics,
                                                weight_a=diag_weight_old,
                                                weight_b=diag_weight_new)
                    model_metrics = update_dict(model_metrics,
                                                buffer_diagnostics_new,
                                                weight_a=diag_weight_old,
                                                weight_b=diag_weight_new)
                    ### run diagnostics on model data
                    if buffer_diagnostics_new['poolm_batch_size'] > 0:
                        model_data_diag = self._policy.run_diagnostics(
                            model_samples_new)
                        model_data_diag = {
                            k + '_m': v
                            for k, v in model_data_diag.items()
                        }
                        model_metrics = update_dict(model_metrics,
                                                    model_data_diag,
                                                    weight_a=diag_weight_old,
                                                    weight_b=diag_weight_new)

                    samples_added += new_n_samples
                    model_metrics.update({'samples_added': samples_added})
                    ######################################################################

                ## for debugging
                model_metrics.update({
                    'cached_var':
                    np.mean(self.fake_env._model.scaler_out.cached_var)
                })
                model_metrics.update({
                    'cached_mu':
                    np.mean(self.fake_env._model.scaler_out.cached_mu)
                })

                print('Rollouts finished')
                gt.stamp('epoch_rollout_model')

            #=====================================================================#
            #  Sample                                                             #
            #=====================================================================#
            n_real_samples = self.model_sampler.dyn_dkl / self.initial_model_dkl * self.min_real_samples_per_epoch
            n_real_samples = max(n_real_samples, 1000)
            # n_real_samples = self.min_real_samples_per_epoch ### for ablation

            model_metrics.update({'n_real_samples': n_real_samples})
            start_samples = self.sampler._total_samples
            ### train for epoch_length ###
            for i in count():

                #### _timestep is within an epoch
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                #### not implemented atm
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                ##### sampling from the real world! #####
                _, _, _, _ = self._do_sampling(timestep=self.policy_epoch)
                gt.stamp('sample')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

                if self.ready_to_train or self._timestep > n_real_samples:
                    self.sampler.finish_all_paths(append_val=True,
                                                  append_cval=True,
                                                  reset_path=False)
                    self.new_real_samples += self._timestep
                    break

            #=====================================================================#
            #  Train model                                                        #
            #=====================================================================#
            if self.new_real_samples > 2048 and self._real_ratio < 1.0:
                model_diag = self.train_model(min_epochs=1, max_epochs=10)
                self.new_real_samples = 0
                model_metrics.update(model_diag)

            #=====================================================================#
            #  Get Buffer Data                                                    #
            #=====================================================================#
            real_samples, buf_diag = self._pool.get()

            ### run diagnostics on real data
            policy_diag = self._policy.run_diagnostics(real_samples)
            policy_diag = {k + '_r': v for k, v in policy_diag.items()}
            model_metrics.update(policy_diag)
            model_metrics.update(buf_diag)

            #=====================================================================#
            #  Update Policy                                                      #
            #=====================================================================#
            train_samples = [
                np.concatenate((r, m), axis=0)
                for r, m in zip(real_samples, model_samples)
            ] if model_samples else real_samples
            self._policy.update_real_c(real_samples)
            self._policy.update_policy(train_samples)
            self._policy.update_critic(
                train_samples,
                train_vc=(train_samples[-3] >
                          0).any())  ### only train vc if there are any costs

            if self._real_ratio < 1.0:
                self.approx_model_batch = self.batch_size_policy - n_real_samples  #self.model_sampler.dyn_dkl/self.initial_model_dkl * self.min_real_samples_per_epoch

            self.policy_epoch += 1
            self.max_tddyn_err *= self.max_tddyn_err_decay
            #### log policy diagnostics
            self._policy.log()

            gt.stamp('train')
            #=====================================================================#
            #  Log performance and stats                                          #
            #=====================================================================#

            self.sampler.log()
            # write results to file, ray prints for us, so no need to print from logger
            logger_diagnostics = self.logger.dump_tabular(
                output_dir=self._log_dir, print_out=False)
            #=====================================================================#

            if self._total_timestep // self.eval_every_n_steps > self.last_eval_step:
                evaluation_paths = self._evaluation_paths(
                    policy, evaluation_environment)
                gt.stamp('evaluation_paths')

                self.last_eval_step = self._total_timestep // self.eval_every_n_steps
            else:
                evaluation_paths = []

            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
                diag_obs_batch = np.concatenate(([
                    evaluation_paths[i]['observations']
                    for i in range(len(evaluation_paths))
                ]),
                                                axis=0)
            else:
                evaluation_metrics = {}
                diag_obs_batch = []

            gt.stamp('epoch_after_hook')

            new_diagnostics = {}

            time_diagnostics = gt.get_times().stamps.itrs

            # add diagnostics from logger
            new_diagnostics.update(logger_diagnostics)

            new_diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            #### updating and averaging
            old_ts_diag = running_diag.get('timestep', 0)
            new_ts_diag = self._total_timestep - self.diag_counter - old_ts_diag
            w_olddiag = old_ts_diag / (new_ts_diag + old_ts_diag)
            w_newdiag = new_ts_diag / (new_ts_diag + old_ts_diag)
            running_diag = update_dict(running_diag,
                                       new_diagnostics,
                                       weight_a=w_olddiag,
                                       weight_b=w_newdiag)
            running_diag.update({'timestep': new_ts_diag + old_ts_diag})
            ####

            if new_ts_diag + old_ts_diag > self.eval_every_n_steps:
                running_diag.update({
                    'epoch': self._epoch,
                    'timesteps_total': self._total_timestep,
                    'train-steps': self._num_train_steps,
                })
                self.diag_counter = self._total_timestep
                diag = running_diag.copy()
                running_diag = {}
                yield diag

            if self._total_timestep >= self.n_env_interacts:
                self.sampler.terminate()

                self._training_after_hook()

                self._training_progress.close()
                print("###### DONE ######")
                yield {'done': True, **running_diag}

                break
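Example #6 repeatedly merges rollout and buffer diagnostics with an `update_dict(old, new, weight_a=..., weight_b=...)` helper whose weights are proportional to the number of samples each dict describes. The helper itself is not shown in the excerpt; the following is only a hypothetical sketch consistent with how it is called, not the project's actual implementation:

def update_dict(old, new, weight_a=0.5, weight_b=0.5):
    """Hypothetical weighted merge: keys present in both dicts are combined as
    weight_a * old + weight_b * new; keys only in `new` are added unchanged."""
    merged = dict(old)
    for key, value in new.items():
        if key in merged:
            merged[key] = weight_a * merged[key] + weight_b * value
        else:
            merged[key] = value
    return merged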
Example #7
        algorithm="SAC",
        version="normal",
        layer_size=256,
        replay_buffer_size=int(1E6),
        algorithm_kwargs=dict(
            num_epochs=3,
            num_eval_steps_per_epoch=1000,
            num_trains_per_train_loop=100,
            num_expl_steps_per_train_loop=1000,
            min_num_steps_before_training=1000,
            max_path_length=300,
            batch_size=50,
        ),
        trainer_kwargs=dict(
            discount=0.99,
            soft_target_tau=5e-3,
            target_update_period=1,
            policy_lr=3E-4,
            qf_lr=3E-4,
            reward_scale=1,
            use_automatic_entropy_tuning=True,
        ),
    )
    setup_logger('name-of-experiment',
                 variant=variant,
                 log_dir='d:/tmp2',
                 to_file_only=False)
    ptu.set_gpu_mode(True)  # optionally set the GPU (default=False)
    gt.reset_root()  # for interactive restarts
    experiment(variant)
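Example #7 is cut off at the top: the opening `variant = dict(` and the imports are missing from the excerpt. Below is a hedged sketch of the surrounding launcher script, assuming the standard rlkit conventions the excerpt appears to follow (`setup_logger` from rlkit.launchers.launcher_util, `ptu` as rlkit.torch.pytorch_util) and a user-defined `experiment` function:

import gtimer as gt
import rlkit.torch.pytorch_util as ptu
from rlkit.launchers.launcher_util import setup_logger


def experiment(variant):
    # Build environments, networks and the SAC training loop from `variant`.
    ...


if __name__ == "__main__":
    variant = dict(
        algorithm="SAC",
        version="normal",
        layer_size=256,
        replay_buffer_size=int(1E6),
        # algorithm_kwargs / trainer_kwargs as in the excerpt above
    )
    setup_logger('name-of-experiment', variant=variant, log_dir='d:/tmp2')
    ptu.set_gpu_mode(True)  # optionally set the GPU (default=False)
    gt.reset_root()         # for interactive restarts
    experiment(variant)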
Example #8
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        # policy = self._policy
        pool = self._pool
        model_metrics = {}

        # if not self._training_started:
        self._init_training()

        # TODO: change policy to placeholder or a function
        def get_action(state, hidden, deterministic=False):
            return self.get_action_meta(state, hidden, deterministic)

        def make_init_hidden(batch_size=1):
            return self.make_init_hidden(batch_size)

        self.sampler.initialize(training_environment,
                                (get_action, make_init_hidden), pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        # self._training_before_hook()

        #### model training
        print('[ MOPO ] log_dir: {} | ratio: {}'.format(
            self._log_dir, self._real_ratio))
        print(
            '[ MOPO ] Training model at epoch {} | freq {} | timestep {} (total: {})'
            .format(self._epoch, self._model_train_freq, self._timestep,
                    self._total_timestep))
        # train dynamics model offline
        max_epochs = 1 if self._model.model_loaded else None
        model_train_metrics = self._train_model(batch_size=256,
                                                max_epochs=max_epochs,
                                                holdout_ratio=0.2,
                                                max_t=self._max_model_t)

        model_metrics.update(model_train_metrics)
        self._log_model()
        gt.stamp('epoch_train_model')
        ####
        tester.time_step_holder.set_time(0)
        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat)
            start_samples = self.sampler._total_samples
            training_logs = {}
            for timestep in count():
                self._timestep = timestep
                if (timestep >= self._epoch_length and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                ## model rollouts
                if timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    self._set_rollout_length()
                    self._reallocate_model_pool()
                    model_rollout_metrics = self._rollout_model(
                        rollout_batch_size=self._rollout_batch_size,
                        deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    self._training_progress.resume()

                ## train actor and critic
                if self.ready_to_train:
                    # print('[ DEBUG ]: ready to train at timestep: {}'.format(timestep))
                    training_logs = self._do_training_repeats(
                        timestep=timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            # evaluate the policies
            evaluation_paths = self._evaluation_paths(
                (lambda _state, _hidden: get_action(_state, _hidden, True),
                 make_init_hidden), evaluation_environment)
            gt.stamp('evaluation_paths')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict(
                    (*(('evaluation/{}'.format(key), evaluation_metrics[key])
                       for key in sorted(evaluation_metrics.keys())),
                     *(('times/{}'.format(key), time_diagnostics[key][-1])
                       for key in sorted(time_diagnostics.keys())),
                     *(('sampler/{}'.format(key), sampler_diagnostics[key])
                       for key in sorted(sampler_diagnostics.keys())),
                     *(('model/{}'.format(key), model_metrics[key])
                       for key in sorted(model_metrics.keys())),
                     ('epoch', self._epoch), ('timestep', self._timestep),
                     ('timesteps_total',
                      self._total_timestep), ('train-steps',
                                              self._num_train_steps),
                     *(('training/{}'.format(key), training_logs[key])
                       for key in sorted(training_logs.keys())))))
            diagnostics['perf/AverageReturn'] = diagnostics[
                'evaluation/return-average']
            diagnostics['perf/AverageLength'] = diagnostics[
                'evaluation/episode-length-avg']
            if self.min_ret != self.max_ret:
                diagnostics['perf/NormalizedReturn'] = (diagnostics['perf/AverageReturn'] - self.min_ret) \
                                                       / (self.max_ret - self.min_ret)
            # diagnostics['keys/logp_pi'] =  diagnostics['training/sac_pi/logp_pi']
            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            ## ensure we did not collect any more data
            assert self._pool.size == self._init_pool_size
            for k, v in diagnostics.items():
                # print('[ DEBUG ] epoch: {} diagnostics k: {}, v: {}'.format(self._epoch, k, v))
                self._writer.add_scalar(k, v, self._epoch)
            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}
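Example #8 pushes every diagnostic into TensorBoard via `self._writer.add_scalar(k, v, self._epoch)`. Assuming `self._writer` is a standard `torch.utils.tensorboard.SummaryWriter` (the writer's construction is not shown in the excerpt), the logging pattern looks like this in isolation:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./logs')  # plays the role of self._writer

diagnostics = {'evaluation/return-average': 123.4, 'times/train': 0.8}
epoch = 0
for key, value in diagnostics.items():
    # Scalar values only, matching the flat keys built in the epoch loop above.
    writer.add_scalar(key, value, epoch)
writer.close()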
Example #9
File: mopo.py Project: numahha/mopo
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        if not self._training_started:
            self._init_training()

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            if self._epoch % 200 == 0:
                #### model training
                print('[ MOPO ] log_dir: {} | ratio: {}'.format(
                    self._log_dir, self._real_ratio))
                print(
                    '[ MOPO ] Training model at epoch {} | freq {} | timestep {} (total: {})'
                    .format(self._epoch, self._model_train_freq,
                            self._timestep, self._total_timestep))
                max_epochs = 1 if self._model.model_loaded else None
                model_train_metrics = self._train_model(
                    batch_size=256,
                    max_epochs=max_epochs,
                    holdout_ratio=0.2,
                    max_t=self._max_model_t)
                model_metrics.update(model_train_metrics)
                self._log_model()
                gt.stamp('epoch_train_model')
                ####

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat)
            start_samples = self.sampler._total_samples
            for timestep in count():
                self._timestep = timestep

                if (timestep >= self._epoch_length and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                ## model rollouts
                if timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    self._set_rollout_length()
                    self._reallocate_model_pool()
                    model_rollout_metrics = self._rollout_model(
                        rollout_batch_size=self._rollout_batch_size,
                        deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    self._training_progress.resume()

                ## train actor and critic
                if self.ready_to_train:
                    self._do_training_repeats(timestep=timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))

            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            ## ensure we did not collect any more data
            assert self._pool.size == self._init_pool_size

            yield diagnostics

        epi_ret = self._rollout_model_for_eval(
            self._training_environment.reset)
        np.savetxt("EEepi_ret__fin.csv", epi_ret, delimiter=',')

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}
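
The _set_rollout_length call above is defined elsewhere in mopo.py; the following is only a sketch of the kind of linear schedule MBPO/MOPO-style code typically uses, with the (min_epoch, max_epoch, min_length, max_length) breakpoints chosen purely as placeholder assumptions:

def linear_rollout_length(epoch, min_epoch=20, max_epoch=100,
                          min_length=1, max_length=15):
    # Linearly interpolate the model rollout length between two epoch
    # breakpoints; epochs outside the range are clipped to the endpoints.
    if epoch <= min_epoch:
        return min_length
    if epoch >= max_epoch:
        return max_length
    frac = (epoch - min_epoch) / (max_epoch - min_epoch)
    return int(min_length + frac * (max_length - min_length))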
Example No. 10
0
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            print('starting epoch', self._epoch)
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            start_samples = self.sampler._total_samples
            print('start samples', start_samples)
            for i in count():
                samples_now = self.sampler._total_samples
                # print('samples now', samples_now)
                self._timestep = samples_now - start_samples
                samples_remaining = (start_samples + self._epoch_length
                                     - samples_now)
                if samples_remaining % 100 == 0:
                    print('samples needed', samples_remaining)

                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            print('after hook', self._epoch)

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')

            should_save_path = (self._path_save_frequency > 0 and
                                self._epoch % self._path_save_frequency == 0)
            if should_save_path:
                import pickle
                for i, path in enumerate(training_paths):
                    #path.pop('images')
                    path_file_name = f'training_path_{self._epoch}_{i}.pkl'
                    path_file_path = os.path.join(os.getcwd(), 'paths',
                                                  path_file_name)
                    if not os.path.exists(os.path.dirname(path_file_path)):
                        os.makedirs(os.path.dirname(path_file_path))
                    with open(path_file_path, 'wb') as f:
                        pickle.dump(path, f)

            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                # TODO(hartikainen): Make this consistent such that there's no
                # need for the hasattr check.
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        yield {'done': True, **diagnostics}
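
The inner loop above ends an epoch only once the sampler has collected _epoch_length fresh environment steps since the epoch started and the algorithm is ready to train. A small standalone sketch of that stop condition; the argument names are illustrative, not the actual class attributes:

def epoch_finished(total_samples_now, start_samples, epoch_length,
                   ready_to_train=True):
    # Mirrors the break condition in the loop above: the epoch ends only
    # after epoch_length new samples AND once training is possible.
    return (total_samples_now >= start_samples + epoch_length
            and ready_to_train)

# e.g. with epoch_length=1000, the 1000th new sample ends the epoch
assert epoch_finished(total_samples_now=5000, start_samples=4000,
                      epoch_length=1000)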
Example No. 11
0
def train_model(variant):
    gt.reset_root()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = f"./output/train_out_{timestamp}/"

    setup_logger('name-of-experiment',
                 variant=variant,
                 snapshot_mode='gap_and_last',
                 snapshot_gap=20,
                 log_dir=log_dir)

    expl_env_kwargs = variant['expl_env_kwargs']
    eval_env_kwargs = variant['eval_env_kwargs']
    trainer_kwargs = variant['trainer_kwargs']

    df_ret_train, df_ret_val, df_feature = load_dataset()
    df_ret_train.to_csv(os.path.join(log_dir, 'df_ret_train.csv'))
    df_ret_val.to_csv(os.path.join(log_dir, 'df_ret_val.csv'))
    df_feature.to_csv(os.path.join(log_dir, 'df_feature.csv'))
    expl_env = NormalizedBoxEnv(
        gym.make('MarketEnv-v0',
                 returns=df_ret_train,
                 features=df_feature,
                 **expl_env_kwargs))

    eval_env = NormalizedBoxEnv(
        gym.make('MarketEnv-v0',
                 returns=df_ret_val,
                 features=df_feature,
                 **eval_env_kwargs))

    def post_epoch_func(self, epoch):
        progress_csv = os.path.join(log_dir, 'progress.csv')
        df = pd.read_csv(progress_csv)
        kpis = ['cagr', 'dd', 'mdd', 'wealths', 'std']
        srcs = ['evaluation', 'exploration']
        n = 50
        for kpi in kpis:
            series = map(lambda s: df[f'{s}/env_infos/final/{kpi} Mean'], srcs)
            plot_ma(series=series, lables=srcs, title=kpi, n=n)
            plt.savefig(os.path.join(log_dir, f'{kpi}.png'))
            plt.close()

    trainer = get_trainer(env=eval_env, **trainer_kwargs)
    policy = trainer.policy
    eval_policy = MakeDeterministic(policy)
    #eval_policy = policy
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.post_epoch_funcs = [
        post_epoch_func,
    ]
    algorithm.to(ptu.device)
    algorithm.train()
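
train_model reads a handful of keys from the variant dict; below is a hypothetical minimal variant showing only the keys used above. All values are placeholder assumptions, not tuned settings:

variant = {
    'expl_env_kwargs': {},           # passed to the exploration MarketEnv-v0
    'eval_env_kwargs': {},           # passed to the evaluation MarketEnv-v0
    'trainer_kwargs': {},            # forwarded to get_trainer(...)
    'replay_buffer_size': int(1e6),  # capacity of EnvReplayBuffer
    'algorithm_kwargs': {            # forwarded to TorchBatchRLAlgorithm
        'num_epochs': 100,           # placeholder value
    },
}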
Example No. 12
0
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """

        #### pool is e.g. simple_replay_pool
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        #### init Qs for SAC
        if not self._training_started:
            self._init_training()

            #### perform some initial steps (gather samples) using initial policy
            ######  fills pool with _n_initial_exploration_steps samples
            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        #### set up sampler with train env and actual policy (may be different from initial exploration policy)
        ######## note: sampler is set up with the pool that may be already filled from initial exploration hook
        self.sampler.initialize(training_environment, policy, pool)

        #### reset gtimer (used for timing the stages of the training loop)
        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        #### not implemented, could train policy before hook
        self._training_before_hook()

        #### iterate over epochs, gt.timed_for to create loop with gt timestamps
        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            #### do something at beginning of epoch (in this case reset self._train_steps_this_epoch=0)
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            #### util class Progress, e.g. for plotting a progress bar
            #######   note: sampler may already contain samples in its pool from initial_exploration_hook or previous epochs
            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat /
                                               self._train_every_n_steps)
            start_samples = self.sampler._total_samples

            ### train for epoch_length ###
            for i in count():

                #### _timestep is within an epoch
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                #### check if you're at the end of an epoch to train
                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                #### not implemented atm
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                #### start model rollout
                if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    print('[ MBPO ] log_dir: {} | ratio: {}'.format(
                        self._log_dir, self._real_ratio))
                    print(
                        '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {})'
                        .format(self._epoch, self._model_train_freq,
                                self._timestep, self._total_timestep,
                                self._train_steps_this_epoch,
                                self._num_train_steps))

                    #### train the model with input:(obs, act), outputs: (rew, delta_obs), inputs are divided into sets with holdout_ratio
                    #@anyboby debug
                    samples = self._pool.return_all_samples()
                    self.fake_env.reset_model()
                    model_train_metrics = self.fake_env.train(
                        samples,
                        batch_size=512,
                        max_epochs=None,
                        holdout_ratio=0.2,
                        max_t=self._max_model_t)
                    model_metrics.update(model_train_metrics)
                    gt.stamp('epoch_train_model')

                    #### rollout model env ####
                    self._set_rollout_length()
                    self._reallocate_model_pool(
                        use_mjc_model_pool=self.use_mjc_state_model)
                    model_rollout_metrics = self._rollout_model(
                        rollout_batch_size=self._rollout_batch_size,
                        deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)
                    ###########################

                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                ##### sampling from the real world! #####
                ##### _total_timestep % train_every_n_steps is checked inside _do_sampling
                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                ### n_train_repeat from config ###
                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        ### this is where we yield the episode diagnostics to tune trial runner ###
        yield {'done': True, **diagnostics}
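
The comment above notes that fake_env.train splits its (obs, act) -> (rew, delta_obs) data into training and holdout sets according to holdout_ratio. A minimal sketch of that kind of split; the function name, the single shuffle, and the numpy-array inputs are assumptions, not the fake_env API:

import numpy as np

def split_holdout(inputs, targets, holdout_ratio=0.2, seed=0):
    # inputs/targets are assumed to be numpy arrays of equal length.
    # Shuffle once, reserve the first holdout_ratio fraction for
    # early stopping / model selection, train on the remainder.
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(inputs))
    n_holdout = int(holdout_ratio * len(inputs))
    holdout_idx, train_idx = perm[:n_holdout], perm[n_holdout:]
    return ((inputs[train_idx], targets[train_idx]),
            (inputs[holdout_idx], targets[holdout_idx]))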
Example No. 13
0
    def _train(self):
        """Return a generator that runs the standard RL loop."""
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            update_diagnostics = []

            start_samples = self.sampler._total_samples
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    repeat_diagnostics = self._do_training_repeats(
                        timestep=self._total_timestep)
                    if repeat_diagnostics is not None:
                        update_diagnostics.append(repeat_diagnostics)

                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            update_diagnostics = tree.map_structure(lambda *d: np.mean(d),
                                                    *update_diagnostics)

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            """
            evaluation_paths = self._evaluation_paths(
                policy, evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(
                training_paths,
                training_environment,
                self._total_timestep,
                evaluation_type='train')
            gt.stamp('training_metrics')
            if False: # evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths,
                    evaluation_environment,
                    self._total_timestep,
                    evaluation_type='evaluation')
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}
            """
            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(iteration=self._total_timestep,
                                               batch=self._evaluation_batch(),
                                               training_paths=training_paths,
                                               evaluation_paths=None)

            time_diagnostics = {
                key: times[-1]
                for key, times in gt.get_times().stamps.itrs.items()
            }

            # TODO(hartikainen/tf2): Fix the naming of training/update
            # diagnostics/metric
            diagnostics.update((
                ('evaluation', None),
                ('training', None),
                ('update', update_diagnostics),
                ('times', time_diagnostics),
                ('sampler', sampler_diagnostics),
                ('epoch', self._epoch),
                ('timestep', self._timestep),
                ('total_timestep', self._total_timestep),
                ('num_train_steps', self._num_train_steps),
            ))

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        yield {'done': True, **diagnostics}
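
The update-diagnostics averaging above relies on dm-tree's map_structure, which applies a function leaf-wise across identically shaped nested structures. A small self-contained illustration of that pattern; the diagnostic keys below are made up:

import numpy as np
import tree  # dm-tree

per_update = [
    {'Q_loss': 1.0, 'policy': {'entropy': 0.5}},
    {'Q_loss': 3.0, 'policy': {'entropy': 1.5}},
]
# Average each leaf across the per-update dicts, preserving the nesting.
averaged = tree.map_structure(lambda *leaves: np.mean(leaves), *per_update)
# averaged == {'Q_loss': 2.0, 'policy': {'entropy': 1.0}}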