Example #1
 def __init__(self,
              env,
              policy,
              n_itr=500,
              max_path_length=500,
              discount=0.99,
              sigma0=1.,
              batch_size=None,
              plot=False,
              **kwargs):
     """
     :param n_itr: Number of iterations.
     :param max_path_length: Maximum length of a single rollout.
     :param batch_size: Number of trajectory samples drawn from the
      parameter distribution; when set, n_samples is ignored.
     :param discount: Discount.
     :param plot: Plot evaluation run after each iteration.
     :param sigma0: Initial standard deviation for the parameter
      distribution.
     :return:
     """
     Serializable.quick_init(self, locals())
     self.env = env
     self.policy = policy
     self.plot = plot
     self.sigma0 = sigma0
     self.discount = discount
     self.max_path_length = max_path_length
     self.n_itr = n_itr
     self.batch_size = batch_size
     self.plotter = Plotter()
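The constructor above only instantiates the Plotter; the plotting lifecycle that the later examples follow is init_plot once, update_plot every iteration, and close at the end. A minimal sketch of that lifecycle, assuming a garage-style env and policy are already constructed:

    from garage.plotter import Plotter

    n_itr, max_path_length = 500, 500   # matching the defaults above
    plotter = Plotter()
    plotter.init_plot(env, policy)      # spawn the plotting worker once
    for itr in range(n_itr):
        # ... sample trajectories and update the policy here ...
        plotter.update_plot(policy, max_path_length)  # replay current policy
    plotter.close()                     # shut the worker down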
Example #2
 def _start_worker(self):
     """Start Plotter and Sampler workers."""
     if self._plot:
         # pylint: disable=import-outside-toplevel
         from garage.plotter import Plotter
         self._plotter = Plotter()
         self._plotter.init_plot(self.get_env_copy(), self._algo.policy)
Example #3
 def _start_worker(self):
     """Start Plotter and Sampler workers."""
     self._sampler.start_worker()
     if self._plot:
         # pylint: disable=import-outside-toplevel
         from garage.tf.plotter import Plotter
         self._plotter = Plotter(self.get_env_copy(),
                                 self._algo.policy,
                                 sess=tf.compat.v1.get_default_session())
         self._plotter.start()
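Note the two Plotter APIs in play: garage.plotter.Plotter (Example #2) is constructed with no arguments and wired up via init_plot(env, policy), while garage.tf.plotter.Plotter (Example #3) takes the environment, policy, and a TensorFlow session in its constructor and is launched with start().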
Example #4
 def __init__(self,
              env,
              policy,
              n_itr=500,
              max_path_length=500,
              discount=0.99,
              init_std=1.,
              n_samples=100,
              batch_size=None,
              best_frac=0.05,
              extra_std=1.,
              extra_decay_time=100,
              plot=False,
              n_evals=1,
              play_every_itr=None,
              play_rollouts_num=3,
              **kwargs):
     """
     :param n_itr: Number of iterations.
     :param max_path_length: Maximum length of a single rollout.
     :param batch_size: Number of trajectory samples drawn from the
      parameter distribution; when set, n_samples is ignored.
     :param discount: Discount.
     :param plot: Plot evaluation run after each iteration.
     :param init_std: Initial standard deviation for the parameter
      distribution.
     :param extra_std: Decaying standard deviation added to the parameter
      distribution at each iteration.
     :param extra_decay_time: Number of iterations over which the extra std
      decays.
     :param n_samples: Number of samples drawn from the parameter
      distribution per iteration.
     :param best_frac: Fraction of the best-performing samples used to refit
      the distribution.
     :param n_evals: Number of evaluations per sampled parameter set; the
      returned score is the mean minus the standard error of the evaluations.
     :return:
     """
     Serializable.quick_init(self, locals())
     self.env = env
     self.policy = policy
     self.batch_size = batch_size
     self.plot = plot
     self.extra_decay_time = extra_decay_time
     self.extra_std = extra_std
     self.best_frac = best_frac
     self.n_samples = n_samples
     self.init_std = init_std
     self.discount = discount
     self.max_path_length = max_path_length
     self.n_itr = n_itr
     self.n_evals = n_evals
     self.plotter = Plotter()
     self.play_every_itr = play_every_itr
     self.play_rollouts_num = play_rollouts_num
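Example #4 stores extra_std and extra_decay_time but does not show how they are consumed. Below is a sketch of how such a decaying exploration std is commonly folded into the sampling std in cross-entropy-method implementations; this is an assumption about the surrounding code, not part of this class, and the helper name is hypothetical:

    import numpy as np

    def sample_std(cur_std, extra_std, extra_decay_time, itr):
        # Hypothetical helper: the exploration bonus decays linearly to zero
        # over extra_decay_time iterations and is combined in variance space.
        extra_frac = max(1.0 - itr / extra_decay_time, 0.0)
        return np.sqrt(np.square(cur_std) + np.square(extra_std) * extra_frac)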
Example #5
    def train(self):
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #{} | '.format(itr)):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log('Saving snapshot...')
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params['algo'] = self
                if self.store_paths:
                    params['paths'] = samples_data['paths']
                snapshotter.save_itr_params(itr, params)
                logger.log('saved')
                logger.log(tabular)
                if self.plot:
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input('Plotting evaluation run: Press Enter to '
                              'continue...')

        plotter.close()
        self.shutdown_worker()
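Example #6 below is the same training loop written against the older logging API: logger.save_itr_params and logger.dump_tabular replace the snapshotter and tabular calls used above.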
Example #6
    def train(self):
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        plotter.close()
        self.shutdown_worker()
Example #7
    def train(self):
        address = ('localhost', 6000)
        conn = Client(address)
        try:
            plotter = Plotter()
            if self.plot:
                plotter.init_plot(self.env, self.policy)
            conn.send(ExpLifecycle.START)
            self.start_worker()
            self.init_opt()
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #{} | '.format(itr)):
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.sampler.obtain_samples(itr)
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log('saving snapshot...')
                    params = self.get_itr_snapshot(itr, samples_data)
                    self.current_itr = itr + 1
                    params['algo'] = self
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    snapshotter.save_itr_params(itr, params)
                    logger.log('saved')
                    logger.log(tabular)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        plotter.update_plot(self.policy, self.max_path_length)
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')

            conn.send(ExpLifecycle.SHUTDOWN)
            plotter.close()
            self.shutdown_worker()
        finally:
            conn.close()
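Example #7 wraps the same loop with a multiprocessing.connection.Client that reports ExpLifecycle markers at each phase. The receiving end is not shown; here is a minimal sketch of a listener process, under the assumption that it merely logs the events (the loop body is illustrative, not garage code):

    from multiprocessing.connection import Listener

    listener = Listener(('localhost', 6000))   # same address as Client above
    conn = listener.accept()
    try:
        while True:
            msg = conn.recv()     # blocks until train() sends the next marker
            print('lifecycle event:', msg)
    except EOFError:              # raised once train() closes its connection
        pass
    finally:
        conn.close()
        listener.close()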
Example #8
    def __init__(self,
                 env_spec,
                 policy,
                 baseline,
                 n_samples,
                 gae_lambda=1,
                 max_path_length=500,
                 discount=0.99,
                 init_std=1,
                 best_frac=0.05,
                 extra_std=1.,
                 extra_decay_time=100,
                 **kwargs):
        self.env_spec = env_spec
        self.policy = policy
        self.baseline = baseline
        self.n_samples = n_samples
        self.extra_decay_time = extra_decay_time
        self.extra_std = extra_std
        self.best_frac = best_frac
        self.init_std = init_std
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.max_path_length = max_path_length
        self.plotter = Plotter()

        # epoch-wise
        self.cur_std = self.init_std
        self.cur_mean = self.policy.get_param_values()
        # epoch-cycle-wise
        self.cur_params = self.cur_mean
        self.all_returns = []
        self.all_params = [self.cur_mean.copy()]
        # fixed
        self.n_best = int(n_samples * best_frac)
        assert self.n_best >= 1, (
            "n_samples is too low. Make sure that n_samples * best_frac >= 1")
        self.n_params = len(self.cur_mean)
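With the default best_frac=0.05 shown above, the assertion requires n_samples >= 20: int(100 * 0.05) keeps 5 elite samples, int(20 * 0.05) keeps exactly 1, and anything below 20 truncates to 0 and fails.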
Example #9
class LocalRunner:
    """Base class of local runner.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.

    Note:
        For the use of any TensorFlow environments, policies and algorithms,
        please use LocalTFRunner().

    Examples:
        | # to train
        | runner = LocalRunner()
        | env = Env(...)
        | policy = Policy(...)
        | algo = Algo(
        |         env=env,
        |         policy=policy,
        |         ...)
        | runner.setup(algo, env)
        | runner.train(n_epochs=100, batch_size=4000)

        | # to resume immediately.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume()

        | # to resume with modified training arguments.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume(n_epochs=20)

    """
    def __init__(self, snapshot_config):
        self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                        snapshot_config.snapshot_mode,
                                        snapshot_config.snapshot_gap)

        self._has_setup = False
        self._plot = False

        self._setup_args = None
        self._train_args = None
        self._stats = ExperimentStats(total_itr=0,
                                      total_env_steps=0,
                                      total_epoch=0,
                                      last_path=None)

        self._algo = None
        self._env = None
        self._sampler = None
        self._plotter = None

        self._start_time = None
        self._itr_start_time = None
        self.step_itr = None
        self.step_path = None

        # only used for off-policy algorithms
        self.enable_logging = True

        self._n_workers = None
        self._worker_class = None
        self._worker_args = None

    def make_sampler(self,
                     sampler_cls,
                     *,
                     seed=None,
                     n_workers=psutil.cpu_count(logical=False),
                     max_episode_length=None,
                     worker_class=None,
                     sampler_args=None,
                     worker_args=None):
        """Construct a Sampler from a Sampler class.

        Args:
            sampler_cls (type): The type of sampler to construct.
            seed (int): Seed to use in sampler workers.
            max_episode_length (int): Maximum path length to be sampled by the
                sampler. Paths longer than this will be truncated.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the Sampler should use.
            sampler_args (dict or None): Additional arguments that should be
                passed to the sampler.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker.

        Raises:
            ValueError: If `max_episode_length` isn't passed and the algorithm
                doesn't contain a `max_episode_length` field, or if the
                algorithm doesn't have a policy field.

        Returns:
            sampler_cls: An instance of the sampler class.

        """
        policy = getattr(self._algo, 'exploration_policy', None)
        if policy is None:
            policy = getattr(self._algo, 'policy', None)
        if policy is None:
            raise ValueError('If the runner is used to construct a sampler, '
                             'the algorithm must have a `policy` or '
                             '`exploration_policy` field.')
        if max_episode_length is None:
            if hasattr(self._algo, 'max_episode_length'):
                max_episode_length = self._algo.max_episode_length
        if max_episode_length is None:
            raise ValueError('If `sampler_cls` is specified in runner.setup, '
                             'the algorithm must specify `max_episode_length`')
        if worker_class is None:
            worker_class = getattr(self._algo, 'worker_cls', DefaultWorker)
        if seed is None:
            seed = get_seed()
        if sampler_args is None:
            sampler_args = {}
        if worker_args is None:
            worker_args = {}

        return sampler_cls.from_worker_factory(WorkerFactory(
            seed=seed,
            max_episode_length=max_episode_length,
            n_workers=n_workers,
            worker_class=worker_class,
            worker_args=worker_args),
                                               agents=policy,
                                               envs=self._env)

    def setup(self,
              algo,
              env,
              sampler_cls=None,
              sampler_args=None,
              n_workers=psutil.cpu_count(logical=False),
              worker_class=DefaultWorker,
              worker_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the sampler should use.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker.

        Raises:
            ValueError: If sampler_cls is passed and the algorithm doesn't
                contain a `max_episode_length` field.

        """
        self._algo = algo
        self._env = env
        self._n_workers = n_workers
        self._worker_class = worker_class
        if sampler_args is None:
            sampler_args = {}
        if sampler_cls is None:
            sampler_cls = getattr(algo, 'sampler_cls', None)
        if worker_args is None:
            worker_args = {}

        self._worker_args = worker_args
        if sampler_cls is None:
            self._sampler = None
        else:
            self._sampler = self.make_sampler(sampler_cls,
                                              sampler_args=sampler_args,
                                              n_workers=n_workers,
                                              worker_class=worker_class,
                                              worker_args=worker_args)

        self._has_setup = True

        self._setup_args = SetupArgs(sampler_cls=sampler_cls,
                                     sampler_args=sampler_args,
                                     seed=get_seed())

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        if self._plot:
            # pylint: disable=import-outside-toplevel
            from garage.plotter import Plotter
            self._plotter = Plotter()
            self._plotter.init_plot(self.get_env_copy(), self._algo.policy)

    def _shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        if self._sampler is not None:
            self._sampler.shutdown_worker()
        if self._plot:
            self._plotter.close()

    def obtain_trajectories(self,
                            itr,
                            batch_size=None,
                            agent_update=None,
                            env_update=None):
        """Obtain one batch of trajectories.

        Args:
            itr (int): Index of iteration (epoch).
            batch_size (int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Raises:
            ValueError: Raised if the runner was initialized without a sampler,
                        or batch_size wasn't provided here or to train.

        Returns:
            TrajectoryBatch: Batch of trajectories.

        """
        if self._sampler is None:
            raise ValueError('Runner was not initialized with `sampler_cls`. '
                             'Either provide `sampler_cls` to runner.setup, '
                             'or set `algo.sampler_cls`.')
        if batch_size is None and self._train_args.batch_size is None:
            raise ValueError('Runner was not initialized with `batch_size`. '
                             'Either provide `batch_size` to runner.train, '
                             'or pass `batch_size` to runner.obtain_samples.')
        paths = None
        if agent_update is None:
            agent_update = self._algo.policy.get_param_values()
        paths = self._sampler.obtain_samples(
            itr, (batch_size or self._train_args.batch_size),
            agent_update=agent_update,
            env_update=env_update)
        self._stats.total_env_steps += sum(paths.lengths)
        return paths

    def obtain_samples(self,
                       itr,
                       batch_size=None,
                       agent_update=None,
                       env_update=None):
        """Obtain one batch of samples.

        Args:
            itr (int): Index of iteration (epoch).
            batch_size (int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Raises:
            ValueError: Raised if the runner was initialized without a sampler,
                        or batch_size wasn't provided here or to train.

        Returns:
            list[dict]: One batch of samples.

        """
        trajs = self.obtain_trajectories(itr, batch_size, agent_update,
                                         env_update)
        return trajs.to_trajectory_list()

    def save(self, epoch):
        """Save snapshot of current batch.

        Args:
            epoch (int): Epoch.

        Raises:
            NotSetupError: if save() is called before the runner is set up.

        """
        if not self._has_setup:
            raise NotSetupError(
                'Use setup() to set up the runner before saving.')

        logger.log('Saving snapshot...')

        params = dict()
        # Save arguments
        params['setup_args'] = self._setup_args
        params['train_args'] = self._train_args
        params['stats'] = self._stats

        # Save states
        params['env'] = self._env
        params['algo'] = self._algo
        params['n_workers'] = self._n_workers
        params['worker_class'] = self._worker_class
        params['worker_args'] = self._worker_args

        self._snapshotter.save_itr_params(epoch, params)

        logger.log('Saved')

    def restore(self, from_dir, from_epoch='last'):
        """Restore experiment from snapshot.

        Args:
            from_dir (str): Directory of the pickle file
                to resume experiment from.
            from_epoch (str or int): The epoch to restore from.
                Can be 'first', 'last' or a number.
                Not applicable when snapshot_mode='last'.

        Returns:
            TrainArgs: Arguments for train().

        """
        saved = self._snapshotter.load(from_dir, from_epoch)

        self._setup_args = saved['setup_args']
        self._train_args = saved['train_args']
        self._stats = saved['stats']

        set_seed(self._setup_args.seed)

        self.setup(env=saved['env'],
                   algo=saved['algo'],
                   sampler_cls=self._setup_args.sampler_cls,
                   sampler_args=self._setup_args.sampler_args,
                   n_workers=saved['n_workers'],
                   worker_class=saved['worker_class'],
                   worker_args=saved['worker_args'])

        n_epochs = self._train_args.n_epochs
        last_epoch = self._stats.total_epoch
        last_itr = self._stats.total_itr
        total_env_steps = self._stats.total_env_steps
        batch_size = self._train_args.batch_size
        store_paths = self._train_args.store_paths
        pause_for_plot = self._train_args.pause_for_plot

        fmt = '{:<20} {:<15}'
        logger.log('Restore from snapshot saved in %s' %
                   self._snapshotter.snapshot_dir)
        logger.log(fmt.format('-- Train Args --', '-- Value --'))
        logger.log(fmt.format('n_epochs', n_epochs))
        logger.log(fmt.format('last_epoch', last_epoch))
        logger.log(fmt.format('batch_size', batch_size))
        logger.log(fmt.format('store_paths', store_paths))
        logger.log(fmt.format('pause_for_plot', pause_for_plot))
        logger.log(fmt.format('-- Stats --', '-- Value --'))
        logger.log(fmt.format('last_itr', last_itr))
        logger.log(fmt.format('total_env_steps', total_env_steps))

        self._train_args.start_epoch = last_epoch + 1
        return copy.copy(self._train_args)

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot (bool): Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self._start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time))
        tabular.record('TotalEnvSteps', self._stats.total_env_steps)
        logger.log(tabular)

        if self._plot:
            self._plotter.update_plot(self._algo.policy,
                                      self._algo.max_episode_length)
            if pause_for_plot:
                input('Plotting evaluation run: Press Enter to continue...')

    def train(self,
              n_epochs,
              batch_size=None,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs (int): Number of epochs.
            batch_size (int or None): Number of environment steps in one batch.
            plot (bool): Visualize policy by doing rollout after each epoch.
            store_paths (bool): Save paths in snapshot.
            pause_for_plot (bool): Pause for plot.

        Raises:
            NotSetupError: If train() is called before setup().

        Returns:
            float: The average return in last epoch cycle.

        """
        if not self._has_setup:
            raise NotSetupError(
                'Use setup() to set up the runner before training.')

        # Save arguments for restore
        self._train_args = TrainArgs(n_epochs=n_epochs,
                                     batch_size=batch_size,
                                     plot=plot,
                                     store_paths=store_paths,
                                     pause_for_plot=pause_for_plot,
                                     start_epoch=0)

        self._plot = plot

        average_return = self._algo.train(self)
        self._shutdown_worker()

        return average_return

    def step_epochs(self):
        """Step through each epoch.

        This function returns a magic generator. When iterated through, this
        generator automatically performs services such as snapshotting and log
        management. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        self._start_worker()
        self._start_time = time.time()
        self.step_itr = self._stats.total_itr
        self.step_path = None

        # Used by integration tests to ensure examples can run one epoch.
        n_epochs = int(
            os.environ.get('GARAGE_EXAMPLE_TEST_N_EPOCHS',
                           self._train_args.n_epochs))

        logger.log('Obtaining samples...')

        for epoch in range(self._train_args.start_epoch, n_epochs):
            self._itr_start_time = time.time()
            with logger.prefix('epoch #%d | ' % epoch):
                yield epoch
                save_path = (self.step_path
                             if self._train_args.store_paths else None)

                self._stats.last_path = save_path
                self._stats.total_epoch = epoch
                self._stats.total_itr = self.step_itr

                self.save(epoch)

                if self.enable_logging:
                    self.log_diagnostics(self._train_args.pause_for_plot)
                    logger.dump_all(self.step_itr)
                    tabular.clear()

    def resume(self,
               n_epochs=None,
               batch_size=None,
               plot=None,
               store_paths=None,
               pause_for_plot=None):
        """Resume from restored experiment.

        This method provides the same interface as train().

        If not specified, an argument will default to the
        saved arguments from the last call to train().

        Args:
            n_epochs (int): Number of epochs.
            batch_size (int): Number of environment steps in one batch.
            plot (bool): Visualize policy by doing rollout after each epoch.
            store_paths (bool): Save paths in snapshot.
            pause_for_plot (bool): Pause for plot.

        Raises:
            NotSetupError: If resume() is called before restore().

        Returns:
            float: The average return in last epoch cycle.

        """
        if self._train_args is None:
            raise NotSetupError('You must call restore() before resume().')

        self._train_args.n_epochs = n_epochs or self._train_args.n_epochs
        self._train_args.batch_size = batch_size or self._train_args.batch_size

        if plot is not None:
            self._train_args.plot = plot
        if store_paths is not None:
            self._train_args.store_paths = store_paths
        if pause_for_plot is not None:
            self._train_args.pause_for_plot = pause_for_plot

        average_return = self._algo.train(self)
        self._shutdown_worker()

        return average_return

    def get_env_copy(self):
        """Get a copy of the environment.

        Returns:
            garage.envs.GarageEnv: An environment instance.

        """
        if self._env:
            return cloudpickle.loads(cloudpickle.dumps(self._env))
        else:
            return None

    @property
    def total_env_steps(self):
        """Total environment steps collected.

        Returns:
            int: Total environment steps collected.

        """
        return self._stats.total_env_steps
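A usage sketch for the sampler plumbing above, assuming snapshot_config, algo and env are already constructed (the import path is assumed from the garage API; LocalSampler stands in for any Sampler subclass): setup() falls back to algo.sampler_cls when sampler_cls is omitted, otherwise make_sampler builds the requested class.

    from garage.sampler import LocalSampler  # assumed import path

    runner = LocalRunner(snapshot_config)
    runner.setup(algo, env, sampler_cls=LocalSampler, n_workers=4)
    runner.train(n_epochs=100, batch_size=4000)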
Example #10
    def __init__(self,
                 env,
                 policy,
                 qf,
                 es,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-4,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each
         epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout.
        :param qf_weight_decay: Weight decay factor for parameters of the Q
         function.
        :param qf_update_method: Online optimization method for training Q
         function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the
         policy.
        :param policy_update_method: Online optimization method for training
         the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the
         policy.
        :param soft_target_tau: Interpolation parameter for doing the soft
         target update.
        :param n_updates_per_sample: Number of Q function and policy updates
         per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when
         training
        :param include_horizon_terminal_transitions: whether to include
         transitions with terminal=True because the horizon was reached. This
         might make the Q-value backup less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each
         eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting
        :return:
        """
        self.env = env
        self.input_dims = configure_dims(env)
        self.policy = policy
        self.qf = qf
        self.es = es
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            parse_update_method(
                qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            parse_update_method(
                policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = \
            include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        self.opt_info = None
        self.plotter = Plotter()
Example #11
class DDPG(RLAlgorithm):
    """
    Deep Deterministic Policy Gradient.
    """

    def __init__(self,
                 env,
                 policy,
                 qf,
                 es,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-4,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each
         epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout.
        :param qf_weight_decay: Weight decay factor for parameters of the Q
         function.
        :param qf_update_method: Online optimization method for training Q
         function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the
         policy.
        :param policy_update_method: Online optimization method for training
         the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the
         policy.
        :param soft_target_tau: Interpolation parameter for doing the soft
         target update.
        :param n_updates_per_sample: Number of Q function and policy updates
         per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when
         training
        :param include_horizon_terminal_transitions: whether to include
         transitions with terminal=True because the horizon was reached. This
         might make the Q-value backup less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each
         eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting
        :return:
        """
        self.env = env
        self.input_dims = configure_dims(env)
        self.policy = policy
        self.qf = qf
        self.es = es
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            parse_update_method(
                qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            parse_update_method(
                policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = \
            include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        self.opt_info = None
        self.plotter = Plotter()

    def start_worker(self):
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            self.plotter.init_plot(self.env, self.policy)

    @overrides
    def train(self):
        # This seems like a rather sequential method
        input_shapes = dims_to_shapes(self.input_dims)
        pool = ReplayBuffer(
            buffer_shapes=input_shapes,
            max_buffer_size=self.replay_pool_size,
        )
        self.start_worker()

        self.init_opt()
        itr = 0
        path_length = 0
        path_return = 0
        terminal = False
        observation = self.env.reset()

        sample_policy = pickle.loads(pickle.dumps(self.policy))

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # Execute policy
                if terminal:  # or path_length > self.max_path_length:
                    # Note that if the last time step ends an episode, the very
                    # last state and observation will be ignored and not added
                    # to the replay pool
                    observation = self.env.reset()
                    self.es.reset()
                    sample_policy.reset()
                    self.es_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                action = self.es.get_action(
                    itr, observation, policy=sample_policy)

                next_observation, reward, terminal, _ = self.env.step(action)
                path_length += 1
                path_return += reward

                if not terminal and path_length >= self.max_path_length:
                    terminal = True
                    # only include the terminal transition in this case if the
                    # flag was set
                    if self.include_horizon_terminal_transitions:
                        pool.add_transition(
                            observation=observation,
                            action=action,
                            reward=reward * self.scale_reward,
                            terminal=terminal,
                            next_observation=next_observation)
                else:
                    pool.add_transition(
                        observation=observation,
                        action=action,
                        reward=reward * self.scale_reward,
                        terminal=terminal,
                        next_observation=next_observation)

                observation = next_observation

                if pool.size >= self.min_pool_size:
                    for update_itr in range(self.n_updates_per_sample):
                        # Train policy
                        batch = pool.sample(self.batch_size)
                        self.do_training(itr, batch)
                    sample_policy.set_param_values(
                        self.policy.get_param_values())

                itr += 1

            logger.log("Training finished")
            if pool.size >= self.min_pool_size:
                self.evaluate(epoch, pool)
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
        self.env.close()
        self.policy.terminate()
        self.plotter.close()

    def init_opt(self):

        # First, create "target" policy and Q functions
        target_policy = pickle.loads(pickle.dumps(self.policy))
        target_qf = pickle.loads(pickle.dumps(self.qf))

        # y need to be computed first
        obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )

        # The yi values are computed separately as above and then passed to
        # the training functions below
        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )
        yvar = TT.vector('ys')

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([TT.sum(TT.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)

        qf_loss = TT.mean(TT.square(yvar - qval))
        qf_reg_loss = qf_loss + qf_weight_decay_term

        policy_weight_decay_term = 0.5 * self.policy_weight_decay * sum([
            TT.sum(TT.square(param))
            for param in self.policy.get_params(regularizable=True)
        ])
        policy_qval = self.qf.get_qval_sym(
            obs, self.policy.get_action_sym(obs), deterministic=True)
        policy_surr = -TT.mean(policy_qval)

        policy_reg_surr = policy_surr + policy_weight_decay_term

        qf_updates = self.qf_update_method(
            qf_reg_loss, self.qf.get_params(trainable=True))
        policy_updates = self.policy_update_method(
            policy_reg_surr, self.policy.get_params(trainable=True))

        f_train_qf = tensor_utils.compile_function(
            inputs=[yvar, obs, action],
            outputs=[qf_loss, qval],
            updates=qf_updates)

        f_train_policy = tensor_utils.compile_function(
            inputs=[obs], outputs=policy_surr, updates=policy_updates)

        self.opt_info = dict(
            f_train_qf=f_train_qf,
            f_train_policy=f_train_policy,
            target_qf=target_qf,
            target_policy=target_policy,
        )

    def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch, "observation", "action", "reward", "next_observation",
            "terminal")

        rewards = rewards.reshape(-1, )
        terminals = terminals.reshape(-1, )

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. - terminals) * self.discount * next_qvals

        f_train_qf = self.opt_info["f_train_qf"]
        f_train_policy = self.opt_info["f_train_policy"]

        qf_loss, qval = f_train_qf(ys, obs, actions)

        policy_surr = f_train_policy(obs)

        target_policy.set_param_values(
            target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
            self.policy.get_param_values() * self.soft_target_tau)
        target_qf.set_param_values(
            target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
            self.qf.get_param_values() * self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.policy_surr_averages.append(policy_surr)
        self.q_averages.append(qval)
        self.y_averages.append(ys)

    def evaluate(self, epoch, pool):
        logger.log("Collecting samples for evaluation")
        paths = parallel_sampler.sample_paths(
            policy_params=self.policy.get_param_values(),
            max_samples=self.eval_samples,
            max_path_length=self.max_path_length,
        )

        average_discounted_return = np.mean([
            special.discount_return(path["rewards"], self.discount)
            for path in paths
        ])

        returns = [sum(path["rewards"]) for path in paths]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(
            np.square(np.concatenate([path["actions"] for path in paths])))

        policy_reg_param_norm = np.linalg.norm(
            self.policy.get_param_values(regularizable=True))
        qfun_reg_param_norm = np.linalg.norm(
            self.qf.get_param_values(regularizable=True))

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn', np.std(returns))
        logger.record_tabular('MaxReturn', np.max(returns))
        logger.record_tabular('MinReturn', np.min(returns))
        if self.es_path_returns:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn', np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn', np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn', np.min(self.es_path_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        logger.record_tabular('AverageAction', average_action)

        logger.record_tabular('PolicyRegParamNorm', policy_reg_param_norm)
        logger.record_tabular('QFunRegParamNorm', qfun_reg_param_norm)

        self.policy.log_diagnostics(paths)

        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.es_path_returns = []

    def update_plot(self):
        if self.plot:
            self.plotter.update_plot(self.policy, self.max_path_length)

    def get_epoch_snapshot(self, epoch):
        return dict(
            env=self.env,
            epoch=epoch,
            qf=self.qf,
            policy=self.policy,
            target_qf=self.opt_info["target_qf"],
            target_policy=self.opt_info["target_policy"],
            es=self.es,
        )
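The target-network update in do_training is plain Polyak averaging. Pulled out as a standalone helper for clarity (a sketch assuming flat NumPy parameter vectors, which is what get_param_values() returns here):

    def soft_update(target_params, source_params, tau):
        # target <- (1 - tau) * target + tau * source, with tau = soft_target_tau
        return target_params * (1.0 - tau) + source_params * tau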
Example #12
class CMAES(RLAlgorithm, Serializable):
    def __init__(self,
                 env,
                 policy,
                 n_itr=500,
                 max_path_length=500,
                 discount=0.99,
                 sigma0=1.,
                 batch_size=None,
                 plot=False,
                 **kwargs):
        """
        :param n_itr: Number of iterations.
        :param max_path_length: Maximum length of a single rollout.
        :param batch_size: Number of trajectory samples drawn from the
         parameter distribution; when set, n_samples is ignored.
        :param discount: Discount.
        :param plot: Plot evaluation run after each iteration.
        :param sigma0: Initial standard deviation for the parameter
         distribution.
        :return:
        """
        Serializable.quick_init(self, locals())
        self.env = env
        self.policy = policy
        self.plot = plot
        self.sigma0 = sigma0
        self.discount = discount
        self.max_path_length = max_path_length
        self.n_itr = n_itr
        self.batch_size = batch_size
        self.plotter = Plotter()

    def train(self):

        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()
        es = cma.CMAEvolutionStrategy(cur_mean, cur_std)

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            self.plotter.init_plot(self.env, self.policy)

        itr = 0
        while itr < self.n_itr and not es.stop():

            if self.batch_size is None:
                # Sample from multivariate normal distribution.
                xs = es.ask()
                xs = np.asarray(xs)
                # For each sample, do a rollout.
                infos = (stateful_pool.singleton_pool.run_map(
                    sample_return,
                    [(x, self.max_path_length, self.discount) for x in xs]))
            else:
                cum_len = 0
                infos = []
                xss = []
                done = False
                while not done:
                    sbs = stateful_pool.singleton_pool.n_parallel * 2
                    # Sample from multivariate normal distribution.
                    # You want to ask for sbs samples here.
                    xs = es.ask(sbs)
                    xs = np.asarray(xs)

                    xss.append(xs)
                    sinfos = stateful_pool.singleton_pool.run_map(
                        sample_return,
                        [(x, self.max_path_length, self.discount) for x in xs])
                    for info in sinfos:
                        infos.append(info)
                        cum_len += len(info['returns'])
                        if cum_len >= self.batch_size:
                            xs = np.concatenate(xss)
                            done = True
                            break

            # Evaluate fitness of samples (negative as it is minimization
            # problem).
            fs = -np.array([info['returns'][0] for info in infos])
            # When batching, you could have generated too many samples compared
            # to the actual evaluations. So we cut it off in this case.
            xs = xs[:len(fs)]
            # Update CMA-ES params based on sample fitness.
            es.tell(xs, fs)

            logger.push_prefix('itr #{} | '.format(itr))
            tabular.record('Iteration', itr)
            tabular.record('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [info['undiscounted_return'] for info in infos])
            tabular.record('AverageReturn', np.mean(undiscounted_returns))
            tabular.record('StdReturn', np.std(undiscounted_returns))
            tabular.record('MaxReturn', np.max(undiscounted_returns))
            tabular.record('MinReturn', np.min(undiscounted_returns))
            tabular.record('AverageDiscountedReturn', np.mean(fs))
            tabular.record('AvgTrajLen',
                           np.mean([len(info['returns']) for info in infos]))
            self.policy.log_diagnostics(infos)
            snapshotter.save_itr_params(
                itr, dict(
                    itr=itr,
                    policy=self.policy,
                    env=self.env,
                ))
            logger.log(tabular)
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
            logger.pop_prefix()
            # Update iteration.
            itr += 1

        # Set final params.
        self.policy.set_param_values(es.result()[0])
        parallel_sampler.terminate_task()
        self.plotter.close()
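Stripped of the parallel sampler and logging, the ask/tell pattern in Example #12 reduces to the following sketch against the cma package (the toy objective and problem dimension are placeholders; note that newer cma versions expose es.result as a property rather than the es.result() call used above):

    import cma
    import numpy as np

    def rollout_score(params):
        # Stand-in for sample_return: the real code rolls out the policy.
        return -np.sum(np.square(params))

    es = cma.CMAEvolutionStrategy(np.zeros(5), 1.0)
    while not es.stop():
        xs = es.ask()                               # sample candidate params
        fs = [-rollout_score(np.asarray(x)) for x in xs]  # CMA-ES minimizes
        es.tell(xs, fs)                             # update mean / covariance
    best_params = es.result[0]                      # best parameters found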
Example #13
class TFTrainer(Trainer):
    """This class implements a trainer for TensorFlow algorithms.

    A trainer provides a default TensorFlow session using a Python context
    manager. This is useful for those experiment components (e.g. policy)
    that require a TensorFlow session during construction.

    Use trainer.setup(algo, env) to setup algorithm and environment for trainer
    and trainer.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by Trainer to create the snapshotter.
            If None, it will create one with default settings.
        sess (tf.Session): An optional TensorFlow session.
              A new session will be created immediately if not provided.

    Note:
        When resuming via the command line, new snapshots will be
        saved into the SAME directory if not specified.

        When resuming programmatically, the snapshot directory should be
        specified manually or through the @wrap_experiment interface.

    Examples:
        # to train
        with TFTrainer() as trainer:
            env = gym.make('CartPole-v1')
            policy = CategoricalMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=(32, 32))
            algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                max_episode_length=100,
                discount=0.99,
                max_kl_step=0.01)
            trainer.setup(algo, env)
            trainer.train(n_epochs=100, batch_size=4000)

        # to resume immediately.
        with TFTrainer() as trainer:
            trainer.restore(resume_from_dir)
            trainer.resume()

        # to resume with modified training arguments.
        with TFTrainer() as trainer:
            trainer.restore(resume_from_dir)
            trainer.resume(n_epochs=20)

    """
    def __init__(self, snapshot_config, sess=None):
        super().__init__(snapshot_config=snapshot_config)
        self.sess = sess or tf.compat.v1.Session()
        self.sess_entered = False

    def __enter__(self):
        """Set self.sess as the default session.

        Returns:
            TFTrainer: This trainer.

        """
        if tf.compat.v1.get_default_session() is not self.sess:
            self.sess.__enter__()
            self.sess_entered = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave session.

        Args:
            exc_type (str): Type.
            exc_val (object): Value.
            exc_tb (object): Traceback.

        """
        if tf.compat.v1.get_default_session(
        ) is self.sess and self.sess_entered:
            self.sess.__exit__(exc_type, exc_val, exc_tb)
            self.sess_entered = False

    def make_sampler(self,
                     sampler_cls,
                     *,
                     seed=None,
                     n_workers=psutil.cpu_count(logical=False),
                     max_episode_length=None,
                     worker_class=None,
                     sampler_args=None,
                     worker_args=None):
        """Construct a Sampler from a Sampler class.

        Args:
            sampler_cls (type): The type of sampler to construct.
            seed (int): Seed to use in sampler workers.
            max_episode_length (int): Maximum episode length to be sampled by
                the sampler. Paths longer than this will be truncated.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the sampler should use.
            sampler_args (dict or None): Additional arguments that should be
                passed to the sampler.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker.

        Returns:
            sampler_cls: An instance of the sampler class.

        """
        if worker_class is None:
            worker_class = getattr(self._algo, 'worker_cls', DefaultWorker)
        # pylint: disable=useless-super-delegation
        return super().make_sampler(
            sampler_cls,
            seed=seed,
            n_workers=n_workers,
            max_episode_length=max_episode_length,
            worker_class=TFWorkerClassWrapper(worker_class),
            sampler_args=sampler_args,
            worker_args=worker_args)

    def setup(self,
              algo,
              env,
              sampler_cls=None,
              sampler_args=None,
              n_workers=psutil.cpu_count(logical=False),
              worker_class=None,
              worker_args=None):
        """Set up trainer and sessions for algorithm and environment.

        This method saves algo and env within trainer and creates a sampler,
        and initializes all uninitialized variables in session.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (RLAlgorithm): An algorithm instance.
            env (Environment): An environment instance.
            sampler_cls (type): A class which implements :class:`Sampler`
            sampler_args (dict): Arguments to be passed to sampler constructor.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the sampler should use.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker.

        """
        self.initialize_tf_vars()
        logger.log(self.sess.graph)
        super().setup(algo, env, sampler_cls, sampler_args, n_workers,
                      worker_class, worker_args)

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self._sampler.start_worker()
        if self._plot:
            # pylint: disable=import-outside-toplevel
            from garage.tf.plotter import Plotter
            self._plotter = Plotter(self.get_env_copy(),
                                    self._algo.policy,
                                    sess=tf.compat.v1.get_default_session())
            self._plotter.start()

    def initialize_tf_vars(self):
        """Initialize all uninitialized variables in session."""
        with tf.name_scope('initialize_tf_vars'):
            uninited_set = [
                e.decode() for e in self.sess.run(
                    tf.compat.v1.report_uninitialized_variables())
            ]
            self.sess.run(
                tf.compat.v1.variables_initializer([
                    v for v in tf.compat.v1.global_variables()
                    if v.name.split(':')[0] in uninited_set
                ]))
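# A minimal, self-contained sketch of the selective-initialization trick that
# initialize_tf_vars() above relies on: only variables reported as
# uninitialized are initialized, so values already loaded into the session
# (e.g. restored policy weights) are preserved. Variable names and values here
# are illustrative; assumes TensorFlow 2.x with the v1 compatibility layer.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session()
a = tf.compat.v1.Variable(0.0, name='a')
b = tf.compat.v1.Variable(1.0, name='b')

sess.run(a.initializer)  # pretend 'a' was restored from a snapshot
sess.run(a.assign(42.0))

# report_uninitialized_variables() yields op names as bytes (no ':0' suffix).
uninited = {name.decode() for name in
            sess.run(tf.compat.v1.report_uninitialized_variables())}
sess.run(tf.compat.v1.variables_initializer([
    v for v in tf.compat.v1.global_variables()
    if v.name.split(':')[0] in uninited
]))

print(sess.run(a))  # 42.0 -- preserved by the selective init
print(sess.run(b))  # 1.0  -- freshly initialized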
Example #14
class CEM(RLAlgorithm, Serializable):
    def __init__(self,
                 env,
                 policy,
                 n_itr=500,
                 max_path_length=500,
                 discount=0.99,
                 init_std=1.,
                 n_samples=100,
                 batch_size=None,
                 best_frac=0.05,
                 extra_std=1.,
                 extra_decay_time=100,
                 plot=False,
                 n_evals=1,
                 play_every_itr=None,
                 play_rollouts_num=3,
                 **kwargs):
        """
        :param n_itr: Number of iterations.
        :param max_path_length: Maximum length of a single rollout.
        :param batch_size: # of samples from trajs from param distribution,
         when this is set, n_samples is ignored
        :param discount: Discount.
        :param plot: Plot evaluation run after each iteration.
        :param init_std: Initial std for param distribution
        :param extra_std: Decaying std added to param distribution at each
         iteration
        :param extra_decay_time: Iterations that it takes to decay extra std
        :param n_samples: #of samples from param distribution
        :param best_frac: Best fraction of the sampled params
        :param n_evals: # of evals per sample from the param distr. returned
         score is mean - stderr of evals
        :return:
        """
        Serializable.quick_init(self, locals())
        self.env = env
        self.policy = policy
        self.batch_size = batch_size
        self.plot = plot
        self.extra_decay_time = extra_decay_time
        self.extra_std = extra_std
        self.best_frac = best_frac
        self.n_samples = n_samples
        self.init_std = init_std
        self.discount = discount
        self.max_path_length = max_path_length
        self.n_itr = n_itr
        self.n_evals = n_evals
        self.plotter = Plotter()
        self.play_every_itr = play_every_itr
        self.play_rollouts_num = play_rollouts_num

    def play_policy(self, env, policy, n_rollout=3):
        """Render n_rollout episodes of the given policy in the environment."""
        for _ in range(n_rollout):
            obs = env.reset()
            while True:
                env.render()
                action, _ = policy.get_action(obs)
                # Classic 4-tuple Gym step API: (obs, reward, done, info).
                obs, _, done, _ = env.step(action)
                if done:
                    break

    def train(self, sess=None):
        # Create a session if the caller did not provide one, and remember
        # whether we own it so it can be closed when training finishes.
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            self.plotter.init_plot(self.env, self.policy)

        cur_std = self.init_std
        cur_mean = self.policy.get_param_values()
        n_best = max(1, int(self.n_samples * self.best_frac))

        for itr in range(self.n_itr):
            # sample around the current distribution
            extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0)
            sample_std = np.sqrt(
                np.square(cur_std) +
                np.square(self.extra_std) * extra_var_mult)
            if self.batch_size is None:
                criterion = 'paths'
                threshold = self.n_samples
            else:
                criterion = 'samples'
                threshold = self.batch_size
            infos = stateful_pool.singleton_pool.run_collect(
                _worker_rollout_policy,
                threshold=threshold,
                args=(dict(
                    cur_mean=cur_mean,
                    sample_std=sample_std,
                    max_path_length=self.max_path_length,
                    discount=self.discount,
                    criterion=criterion,
                    n_evals=self.n_evals), ))
            xs = np.asarray([info[0] for info in infos])
            paths = [info[1] for info in infos]

            fs = np.array([path['returns'][0] for path in paths])
            best_inds = (-fs).argsort()[:n_best]
            best_xs = xs[best_inds]
            cur_mean = best_xs.mean(axis=0)
            cur_std = best_xs.std(axis=0)
            best_x = best_xs[0]
            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [path['undiscounted_return'] for path in paths])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            logger.record_tabular('StdReturn', np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn', np.max(undiscounted_returns))
            logger.record_tabular('MinReturn', np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn', np.mean(fs))
            logger.record_tabular('NumTrajs', len(paths))
            paths = list(chain(
                *[d['full_paths']
                  for d in paths]))  # flatten paths for the case n_evals > 1
            logger.record_tabular(
                'AvgTrajLen',
                np.mean([len(path['returns']) for path in paths]))

            self.policy.set_param_values(best_x)
            self.policy.log_diagnostics(paths)
            logger.save_itr_params(
                itr,
                dict(
                    itr=itr,
                    policy=self.policy,
                    env=self.env,
                    cur_mean=cur_mean,
                    cur_std=cur_std,
                ))
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
            
            # Show the current policy from time to time.
            if (self.play_every_itr is not None and self.play_every_itr > 0
                    and itr % self.play_every_itr == 0):
                self.play_policy(env=self.env,
                                 policy=self.policy,
                                 n_rollout=self.play_rollouts_num)

        parallel_sampler.terminate_task()
        self.plotter.close()
        if created_session:
            sess.close()
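# A minimal, self-contained sketch of the cross-entropy update that train()
# above performs on the policy parameters: sample candidates around the
# current mean, keep the best fraction, and refit the sampling distribution.
# A toy quadratic objective stands in for rollout returns; all names here
# (f, mean, std, ...) are illustrative.
import numpy as np


def f(x):
    return -np.sum((x - 3.0)**2)  # maximized at x == 3


dim, n_samples, best_frac, n_itr = 5, 100, 0.05, 50
mean, std = np.zeros(dim), np.ones(dim)
n_best = max(1, int(n_samples * best_frac))

for _ in range(n_itr):
    xs = mean + std * np.random.randn(n_samples, dim)  # sample candidates
    fs = np.array([f(x) for x in xs])  # evaluate each candidate
    best_xs = xs[(-fs).argsort()[:n_best]]  # keep the elite fraction
    mean, std = best_xs.mean(axis=0), best_xs.std(axis=0)  # refit

print(mean)  # close to [3., 3., 3., 3., 3.]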