Code Example #1
File: local_runner.py  Project: wjssx/garage
class LocalRunner:
    """Base class of local runner.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        max_cpus (int): The maximum number of parallel sampler workers.

    Note:
        To use TensorFlow environments, policies, or algorithms,
        please use LocalTFRunner().

    Examples:
        | # to train
        | runner = LocalRunner()
        | env = Env(...)
        | policy = Policy(...)
        | algo = Algo(
        |         env=env,
        |         policy=policy,
        |         ...)
        | runner.setup(algo, env)
        | runner.train(n_epochs=100, batch_size=4000)

        | # to resume immediately.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume()

        | # to resume with modified training arguments.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume(n_epochs=20)

    """
    def __init__(self, snapshot_config, max_cpus=1):
        self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                        snapshot_config.snapshot_mode,
                                        snapshot_config.snapshot_gap)

        parallel_sampler.initialize(max_cpus)

        seed = get_seed()
        if seed is not None:
            parallel_sampler.set_seed(seed)

        self._has_setup = False
        self._plot = False

        self._setup_args = None
        self._train_args = None
        self._stats = ExperimentStats(total_itr=0,
                                      total_env_steps=0,
                                      total_epoch=0,
                                      last_path=None)

        self._algo = None
        self._env = None
        self._policy = None
        self._sampler = None
        self._plotter = None

        self._start_time = None
        self._itr_start_time = None
        self.step_itr = None
        self.step_path = None

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.

        """
        self._algo = algo
        self._env = env
        self._policy = self._algo.policy

        if sampler_args is None:
            sampler_args = {}
        if sampler_cls is None:
            sampler_cls = algo.sampler_cls
        self._sampler = sampler_cls(algo, env, **sampler_args)

        self._has_setup = True

        self._setup_args = SetupArgs(sampler_cls=sampler_cls,
                                     sampler_args=sampler_args,
                                     seed=get_seed())

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self._sampler.start_worker()
        if self._plot:
            # pylint: disable=import-outside-toplevel
            from garage.tf.plotter import Plotter
            self._plotter = Plotter(self._env, self._policy)
            self._plotter.start()

    def _shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self._sampler.shutdown_worker()
        if self._plot:
            self._plotter.close()

    def obtain_samples(self, itr, batch_size=None):
        """Obtain one batch of samples.

        Args:
            itr (int): Index of iteration (epoch).
            batch_size (int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            list[dict]: One batch of samples.

        """
        paths = self._sampler.obtain_samples(
            itr, (batch_size or self._train_args.batch_size))

        self._stats.total_env_steps += sum([len(p['rewards']) for p in paths])

        return paths

    def save(self, epoch):
        """Save snapshot of current batch.

        Args:
            epoch (int): Epoch.

        Raises:
            NotSetupError: If save() is called before the runner is set up.

        """
        if not self._has_setup:
            raise NotSetupError('Use setup() to setup runner before saving.')

        logger.log('Saving snapshot...')

        params = dict()
        # Save arguments
        params['setup_args'] = self._setup_args
        params['train_args'] = self._train_args
        params['stats'] = self._stats

        # Save states
        params['env'] = self._env
        params['algo'] = self._algo

        self._snapshotter.save_itr_params(epoch, params)

        logger.log('Saved')

    def restore(self, from_dir, from_epoch='last'):
        """Restore experiment from snapshot.

        Args:
            from_dir (str): Directory of the pickle file
                to resume experiment from.
            from_epoch (str or int): The epoch to restore from.
                Can be 'first', 'last' or a number.
                Not applicable when snapshot_mode='last'.

        Returns:
            TrainArgs: Arguments for train().

        """
        saved = self._snapshotter.load(from_dir, from_epoch)

        self._setup_args = saved['setup_args']
        self._train_args = saved['train_args']
        self._stats = saved['stats']

        set_seed(self._setup_args.seed)

        self.setup(env=saved['env'],
                   algo=saved['algo'],
                   sampler_cls=self._setup_args.sampler_cls,
                   sampler_args=self._setup_args.sampler_args)

        n_epochs = self._train_args.n_epochs
        last_epoch = self._stats.total_epoch
        last_itr = self._stats.total_itr
        total_env_steps = self._stats.total_env_steps
        batch_size = self._train_args.batch_size
        store_paths = self._train_args.store_paths
        pause_for_plot = self._train_args.pause_for_plot

        fmt = '{:<20} {:<15}'
        logger.log('Restore from snapshot saved in %s' %
                   self._snapshotter.snapshot_dir)
        logger.log(fmt.format('-- Train Args --', '-- Value --'))
        logger.log(fmt.format('n_epochs', n_epochs))
        logger.log(fmt.format('last_epoch', last_epoch))
        logger.log(fmt.format('batch_size', batch_size))
        logger.log(fmt.format('store_paths', store_paths))
        logger.log(fmt.format('pause_for_plot', pause_for_plot))
        logger.log(fmt.format('-- Stats --', '-- Value --'))
        logger.log(fmt.format('last_itr', last_itr))
        logger.log(fmt.format('total_env_steps', total_env_steps))

        self._train_args.start_epoch = last_epoch + 1
        return copy.copy(self._train_args)

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot (bool): Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self._start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time))
        logger.log(tabular)
        if self._plot:
            self._plotter.update_plot(self._policy, self._algo.max_path_length)
            if pause_for_plot:
                input('Plotting evaluation run: Press Enter to continue...')

    def train(self,
              n_epochs,
              batch_size,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs (int): Number of epochs.
            batch_size (int): Number of environment steps in one batch.
            plot (bool): Visualize policy by doing rollout after each epoch.
            store_paths (bool): Save paths in snapshot.
            pause_for_plot (bool): Pause for plot.

        Raises:
            NotSetupError: If train() is called before setup().

        Returns:
            float: The average return in last epoch cycle.

        """
        if not self._has_setup:
            raise NotSetupError('Use setup() to setup runner before training.')

        # Save arguments for restore
        self._train_args = TrainArgs(n_epochs=n_epochs,
                                     batch_size=batch_size,
                                     plot=plot,
                                     store_paths=store_paths,
                                     pause_for_plot=pause_for_plot,
                                     start_epoch=0)

        self._plot = plot

        return self._algo.train(self)

    def step_epochs(self):
        """Step through each epoch.

        This function returns a magic generator. When iterated through, this
        generator automatically performs services such as snapshotting and log
        management. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        self._start_worker()
        self._start_time = time.time()
        self.step_itr = self._stats.total_itr
        self.step_path = None

        # Used by integration tests to ensure examples can run one epoch.
        n_epochs = int(
            os.environ.get('GARAGE_EXAMPLE_TEST_N_EPOCHS',
                           self._train_args.n_epochs))

        logger.log('Obtaining samples...')

        for epoch in range(self._train_args.start_epoch, n_epochs):
            self._itr_start_time = time.time()
            with logger.prefix('epoch #%d | ' % epoch):
                yield epoch
                save_path = (self.step_path
                             if self._train_args.store_paths else None)

                self._stats.last_path = save_path
                self._stats.total_epoch = epoch
                self._stats.total_itr = self.step_itr

                self.save(epoch)
                self.log_diagnostics(self._train_args.pause_for_plot)
                logger.dump_all(self.step_itr)
                tabular.clear()

    def resume(self,
               n_epochs=None,
               batch_size=None,
               plot=None,
               store_paths=None,
               pause_for_plot=None):
        """Resume from restored experiment.

        This method provides the same interface as train().

        If not specified, an argument will default to the
        saved arguments from the last call to train().

        Args:
            n_epochs (int): Number of epochs.
            batch_size (int): Number of environment steps in one batch.
            plot (bool): Visualize policy by doing rollout after each epoch.
            store_paths (bool): Save paths in snapshot.
            pause_for_plot (bool): Pause for plot.

        Raises:
            NotSetupError: If resume() is called before restore().

        Returns:
            float: The average return in last epoch cycle.

        """
        if self._train_args is None:
            raise NotSetupError('You must call restore() before resume().')

        self._train_args.n_epochs = n_epochs or self._train_args.n_epochs
        self._train_args.batch_size = batch_size or self._train_args.batch_size

        if plot is not None:
            self._train_args.plot = plot
        if store_paths is not None:
            self._train_args.store_paths = store_paths
        if pause_for_plot is not None:
            self._train_args.pause_for_plot = pause_for_plot

        return self._algo.train(self)

    def get_env_copy(self):
        """Get a copy of the environment.

        Returns:
            garage.envs.GarageEnv: An environment instance.

        """
        return pickle.loads(pickle.dumps(self._env))

    @property
    def total_env_steps(self):
        """Total environment steps collected.

        Returns:
            int: Total environment steps collected.

        """
        return self._stats.total_env_steps
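
The Examples block in step_epochs() above only shows the loop body. Below is a minimal sketch of the algorithm-side train() that would drive it, assuming a hypothetical MyAlgo class that provides the attributes setup() expects (policy and sampler_cls) and an illustrative train_once() update step; it is not part of garage.

# Illustrative sketch, not garage code: how LocalRunner.train() ->
# algo.train(runner) typically drives runner.step_epochs().
class MyAlgo:

    def __init__(self, env, policy, sampler_cls):
        self.env = env
        self.policy = policy            # read by LocalRunner.setup()
        self.sampler_cls = sampler_cls  # used when setup() gets no sampler_cls

    def train_once(self, itr, paths):
        # Placeholder policy update from one batch of paths.
        return 0.0

    def train(self, runner):
        last_return = None
        for epoch in runner.step_epochs():
            # step_path is what save()/store_paths will snapshot.
            runner.step_path = runner.obtain_samples(runner.step_itr)
            last_return = self.train_once(runner.step_itr, runner.step_path)
            runner.step_itr += 1
        return last_return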
Code Example #2
class LocalRunner:
    """Base class of local runner.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        max_cpus (int): The maximum number of parallel sampler workers.

    Note:
        To use TensorFlow environments, policies, or algorithms,
        please use LocalTFRunner().

    Examples:
        | # to train
        | runner = LocalRunner()
        | env = Env(...)
        | policy = Policy(...)
        | algo = Algo(
        |         env=env,
        |         policy=policy,
        |         ...)
        | runner.setup(algo, env)
        | runner.train(n_epochs=100, batch_size=4000)

        | # to resume immediately.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume()

        | # to resume with modified training arguments.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume(n_epochs=20)

    """

    def __init__(self, snapshot_config, max_cpus=1):
        self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                        snapshot_config.snapshot_mode,
                                        snapshot_config.snapshot_gap)

        if max_cpus > 1:
            from garage.sampler import singleton_pool
            singleton_pool.initialize(max_cpus)
        self.has_setup = False
        self.plot = False

        self._setup_args = None
        self.train_args = None

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.

        """
        self.algo = algo
        self.env = env
        self.policy = self.algo.policy

        if sampler_args is None:
            sampler_args = {}
        if sampler_cls is None:
            sampler_cls = algo.sampler_cls
        self.sampler = sampler_cls(algo, env, **sampler_args)

        self.has_setup = True

        self._setup_args = types.SimpleNamespace(sampler_cls=sampler_cls,
                                                 sampler_args=sampler_args)

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self.sampler.start_worker()
        if self.plot:
            from garage.tf.plotter import Plotter
            self.plotter = Plotter(self.env, self.policy)
            self.plotter.start()

    def _shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr, batch_size=None):
        """Obtain one batch of samples.

        Args:
            itr(int): Index of iteration (epoch).
            batch_size(int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            One batch of samples.

        """
        if self.train_args.n_epoch_cycles == 1:
            logger.log('Obtaining samples...')
        return self.sampler.obtain_samples(
            itr, (batch_size or self.train_args.batch_size))

    def save(self, epoch, paths=None):
        """Save snapshot of current batch.

        Args:
            epoch (int): Index of iteration (epoch).
            paths (dict): Batch of samples after preprocessing. If None,
                no paths will be logged to the snapshot.

        """
        if not self.has_setup:
            raise Exception('Use setup() to setup runner before saving.')

        logger.log('Saving snapshot...')

        params = dict()
        # Save arguments
        params['setup_args'] = self._setup_args
        params['train_args'] = self.train_args

        # Save states
        params['env'] = self.env
        params['algo'] = self.algo
        if paths:
            params['paths'] = paths
        params['last_epoch'] = epoch
        self._snapshotter.save_itr_params(epoch, params)

        logger.log('Saved')

    def restore(self, from_dir, from_epoch='last'):
        """Restore experiment from snapshot.

        Args:
            from_dir (str): Directory of the pickle file
                to resume experiment from.
            from_epoch (str or int): The epoch to restore from.
                Can be 'first', 'last' or a number.
                Not applicable when snapshot_mode='last'.

        Returns:
            A SimpleNamespace for train()'s arguments.

        """
        saved = self._snapshotter.load(from_dir, from_epoch)

        self._setup_args = saved['setup_args']
        self.train_args = saved['train_args']

        self.setup(env=saved['env'],
                   algo=saved['algo'],
                   sampler_cls=self._setup_args.sampler_cls,
                   sampler_args=self._setup_args.sampler_args)

        n_epochs = self.train_args.n_epochs
        last_epoch = saved['last_epoch']
        n_epoch_cycles = self.train_args.n_epoch_cycles
        batch_size = self.train_args.batch_size
        store_paths = self.train_args.store_paths
        pause_for_plot = self.train_args.pause_for_plot

        fmt = '{:<20} {:<15}'
        logger.log('Restore from snapshot saved in %s' %
                   self._snapshotter.snapshot_dir)
        logger.log(fmt.format('Train Args', 'Value'))
        logger.log(fmt.format('n_epochs', n_epochs))
        logger.log(fmt.format('last_epoch', last_epoch))
        logger.log(fmt.format('n_epoch_cycles', n_epoch_cycles))
        logger.log(fmt.format('batch_size', batch_size))
        logger.log(fmt.format('store_paths', store_paths))
        logger.log(fmt.format('pause_for_plot', pause_for_plot))

        self.train_args.start_epoch = last_epoch + 1
        return copy.copy(self.train_args)

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot(bool): Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self._start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time))
        logger.log(tabular)
        if self.plot:
            self.plotter.update_plot(self.policy, self.algo.max_path_length)
            if pause_for_plot:
                input('Plotting evaluation run: Press Enter to continue...')

    def train(self,
              n_epochs,
              batch_size,
              n_epoch_cycles=1,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs(int): Number of epochs.
            batch_size(int): Number of environment steps in one batch.
            n_epoch_cycles(int): Number of batches of samples in each epoch.
                This is only useful for off-policy algorithms.
                For on-policy algorithms this value should always be 1.
            plot(bool): Visualize policy by doing rollout after each epoch.
            store_paths(bool): Save paths in snapshot.
            pause_for_plot(bool): Pause for plot.

        Returns:
            The average return in last epoch cycle.

        """
        if not self.has_setup:
            raise Exception('Use setup() to setup runner before training.')

        # Save arguments for restore
        self.train_args = types.SimpleNamespace(n_epochs=n_epochs,
                                                n_epoch_cycles=n_epoch_cycles,
                                                batch_size=batch_size,
                                                plot=plot,
                                                store_paths=store_paths,
                                                pause_for_plot=pause_for_plot,
                                                start_epoch=0)

        self.plot = plot

        return self.algo.train(self)

    def step_epochs(self):
        """Step through each epoch.

        This function returns a magic generator. When iterated through, this
        generator automatically performs services such as snapshotting and log
        management. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        try:
            self._start_worker()
            self._start_time = time.time()
            self.step_itr = (self.train_args.start_epoch *
                             self.train_args.n_epoch_cycles)
            self.step_path = None

            for epoch in range(self.train_args.start_epoch,
                               self.train_args.n_epochs):
                self._itr_start_time = time.time()
                with logger.prefix('epoch #%d | ' % epoch):
                    yield epoch
                    save_path = (self.step_path
                                 if self.train_args.store_paths else None)
                    self.save(epoch, save_path)
                    self.log_diagnostics(self.train_args.pause_for_plot)
                    logger.dump_all(self.step_itr)
                    tabular.clear()
        finally:
            self._shutdown_worker()

    def resume(self,
               n_epochs=None,
               batch_size=None,
               n_epoch_cycles=None,
               plot=None,
               store_paths=None,
               pause_for_plot=None):
        """Resume from restored experiment.

        This method provides the same interface as train().

        If not specified, an argument will default to the
        saved arguments from the last call to train().

        Returns:
            The average return in last epoch cycle.

        """
        if self.train_args is None:
            raise Exception('You must call restore() before resume().')

        self.train_args.n_epochs = n_epochs or self.train_args.n_epochs
        self.train_args.batch_size = batch_size or self.train_args.batch_size
        self.train_args.n_epoch_cycles = (n_epoch_cycles
                                          or self.train_args.n_epoch_cycles)

        if plot is not None:
            self.train_args.plot = plot
        if store_paths is not None:
            self.train_args.store_paths = store_paths
        if pause_for_plot is not None:
            self.train_args.pause_for_plot = pause_for_plot

        return self.algo.train(self)
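
Restoring and resuming with this variant follows the pattern in the class docstring. A minimal sketch, where snapshot_config and resume_from_dir are placeholders for a real SnapshotConfig and a directory written by an earlier run, not values defined in this file:

# Illustrative resume sketch; snapshot_config and resume_from_dir are
# placeholders supplied by the surrounding experiment code.
runner = LocalRunner(snapshot_config)
saved_args = runner.restore(resume_from_dir, from_epoch='last')
# Keep the saved batch_size but train for 20 more epochs than before.
runner.resume(n_epochs=saved_args.n_epochs + 20)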
Code Example #3
class OffPolicyRLAlgorithm(RLAlgorithm):
    """This class implements OffPolicyRLAlgorithm."""
    def __init__(
        self,
        env,
        policy,
        qf,
        replay_buffer,
        use_target=False,
        discount=0.99,
        n_epochs=500,
        n_epoch_cycles=20,
        max_path_length=100,
        n_train_steps=50,
        buffer_batch_size=64,
        min_buffer_size=int(1e4),
        rollout_batch_size=1,
        reward_scale=1.,
        input_include_goal=False,
        smooth_return=True,
        sampler_cls=None,
        sampler_args=None,
        force_batch_sampler=False,
        plot=False,
        pause_for_plot=False,
        exploration_strategy=None,
    ):
        """Construct an OffPolicyRLAlgorithm class."""
        self.env = env
        self.policy = policy
        self.qf = qf
        self.replay_buffer = replay_buffer
        self.n_epochs = n_epochs
        self.n_epoch_cycles = n_epoch_cycles
        self.n_train_steps = n_train_steps
        self.buffer_batch_size = buffer_batch_size
        self.use_target = use_target
        self.discount = discount
        self.min_buffer_size = min_buffer_size
        self.rollout_batch_size = rollout_batch_size
        self.reward_scale = reward_scale
        self.evaluate = False
        self.input_include_goal = input_include_goal
        self.smooth_return = smooth_return
        if sampler_cls is None:
            if policy.vectorized and not force_batch_sampler:
                sampler_cls = OffPolicyVectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.max_path_length = max_path_length
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.es = exploration_strategy
        self.init_opt()

    def start_worker(self, sess):
        """Initialize sampler and plotter."""
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env, self.policy, sess)
            self.plotter.start()

    def shutdown_worker(self):
        """Close sampler and plotter."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr):
        """Sample data for this iteration."""
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        """Process samples from rollout paths."""
        return self.sampler.process_samples(itr, paths)

    def log_diagnostics(self, paths):
        """Log diagnostic information on current paths."""
        self.policy.log_diagnostics(paths)
        self.qf.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure.

        If using tensorflow, this may
        include declaring all the variables and compiling functions.
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """Return data saved in the snapshot for this iteration."""
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        """Optimize policy network."""
        raise NotImplementedError
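
The three NotImplementedError methods above are the hook points a concrete algorithm must fill in. Below is a minimal skeleton of such a subclass with placeholder bodies and illustrative snapshot fields; it is a sketch, not a working algorithm.

# Skeleton sketch of a concrete off-policy algorithm built on the hooks above.
class MyOffPolicyAlgo(OffPolicyRLAlgorithm):

    def init_opt(self):
        # Build optimizers and, if use_target, target-network update ops.
        pass

    def get_itr_snapshot(self, itr, samples_data):
        # Data to pickle for this iteration (illustrative field choice).
        return dict(itr=itr, policy=self.policy, env=self.env)

    def optimize_policy(self, itr, samples_data):
        # Sample self.buffer_batch_size transitions from self.replay_buffer
        # and take self.n_train_steps gradient steps on self.qf / self.policy.
        pass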
Code Example #4
File: batch_polopt.py  Project: Mee321/HAPG_exp
class BatchPolopt(RLAlgorithm):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo,
    etc.
    """
    def __init__(self,
                 env,
                 policy,
                 baseline,
                 scope=None,
                 n_itr=500,
                 start_itr=0,
                 batch_size=5000,
                 max_path_length=500,
                 discount=0.99,
                 gae_lambda=1,
                 plot=False,
                 pause_for_plot=False,
                 center_adv=True,
                 positive_adv=False,
                 store_paths=False,
                 whole_paths=True,
                 fixed_horizon=False,
                 sampler_cls=None,
                 sampler_args=None,
                 force_batch_sampler=False,
                 **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified
         if running multiple algorithms simultaneously, each using different
         environments and policies.
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have
         mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are
         always positive. When used in conjunction with center_adv the
         advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = OnPolicyVectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.init_opt()

    def start_worker(self, sess):
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env, self.policy, sess)
            self.plotter.start()

    def shutdown_worker(self):
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        return self.sampler.process_samples(itr, paths)

    def train(self, sess=None):
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker(sess)
        start_time = time.time()
        last_average_return = None
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
        if created_session:
            sess.close()
        return last_average_return

    def log_diagnostics(self, paths):
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError
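
get_itr_snapshot() is abstract here, but train() above treats its return value as a dict that logger.save_itr_params() can pickle (and adds a "paths" key when store_paths is set). A minimal sketch of what a subclass might return; the field names are illustrative, not a required schema:

# Illustrative only: one plausible get_itr_snapshot() for a BatchPolopt
# subclass, matching how train() above consumes the result.
def get_itr_snapshot(self, itr, samples_data):
    return dict(
        itr=itr,
        policy=self.policy,      # current policy parameters
        baseline=self.baseline,  # value-function baseline
        env=self.env,            # environment, so rollouts can be reproduced
    )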
Code Example #5
class LocalRunner:
    """This class implements a local runner for tensorflow algorithms.

    A local runner provides a default tensorflow session using python context.
    This is useful for those experiment components (e.g. policy) that require a
    tensorflow session during construction.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Examples:
        with LocalRunner() as runner:
            env = gym.make('CartPole-v1')
            policy = CategoricalMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=(32, 32))
            algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                max_kl_step=0.01)
            runner.setup(algo, env)
            runner.train(n_epochs=100, batch_size=4000)

    """
    def __init__(self, sess=None, max_cpus=1):
        """Create a new local runner.

        Args:
            max_cpus: The maximum number of parallel sampler workers.
            sess: An optional tensorflow session.
                  A new session will be created immediately if not provided.

        Note:
            The local runner will set up a joblib task pool of size max_cpus
            possibly later used by BatchSampler. If BatchSampler is not used,
            the processes in the pool will remain dormant.

            This setup is required to use tensorflow in a multiprocess
            environment before a tensorflow session is created
            because tensorflow is not fork-safe.

            See https://github.com/tensorflow/tensorflow/issues/2448.

        """
        if max_cpus > 1:
            from garage.sampler import singleton_pool
            singleton_pool.initialize(max_cpus)
        self.sess = sess or tf.Session()
        self.has_setup = False
        self.plot = False

    def __enter__(self):
        """Set self.sess as the default session.

        Returns:
            This local runner.

        """
        if tf.get_default_session() is not self.sess:
            self.sess.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave session."""
        if tf.get_default_session() is self.sess:
            self.sess.__exit__(exc_type, exc_val, exc_tb)

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo: An algorithm instance.
            env: An environment instance.
            sampler_cls: A sampler class.
            sampler_args: Arguments to be passed to sampler constructor.

        """
        self.algo = algo
        self.env = env
        self.policy = self.algo.policy

        if sampler_args is None:
            sampler_args = {}

        if sampler_cls is None:
            from garage.tf.algos.batch_polopt import BatchPolopt
            if isinstance(algo, BatchPolopt):
                if self.policy.vectorized:
                    from garage.tf.samplers import OnPolicyVectorizedSampler
                    sampler_cls = OnPolicyVectorizedSampler
                else:
                    from garage.tf.samplers import BatchSampler
                    sampler_cls = BatchSampler
            else:
                from garage.tf.samplers import OffPolicyVectorizedSampler
                sampler_cls = OffPolicyVectorizedSampler

        self.sampler = sampler_cls(algo, **sampler_args)

        self.initialize_tf_vars()
        self.has_setup = True

    def initialize_tf_vars(self):
        """Initialize all uninitialized variables in session."""
        self.sess.run(
            tf.variables_initializer([
                v for v in tf.global_variables()
                if v.name.split(':')[0] in str(
                    self.sess.run(tf.report_uninitialized_variables()))
            ]))

    def start_worker(self):
        """Start Plotter and Sampler workers."""
        self.sampler.start_worker()
        if self.plot:
            from garage.tf.plotter import Plotter
            self.plotter = Plotter(self.env, self.policy)
            self.plotter.start()

    def shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr, batch_size):
        """Obtain one batch of samples.

        Args:
            itr: Index of iteration (epoch).
            batch_size: Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            One batch of samples.

        """
        if self.n_epoch_cycles == 1:
            logger.log("Obtaining samples...")
        return self.sampler.obtain_samples(itr, batch_size)

    def save_snapshot(self, itr, paths=None):
        """Save snapshot of current batch.

        Args:
            itr: Index of iteration (epoch).
            paths: Batch of samples after preprocessed.

        """
        logger.log("Saving snapshot...")
        params = self.algo.get_itr_snapshot(itr, paths)
        if paths:
            params["paths"] = paths
        logger.save_itr_params(itr, params)
        logger.log("Saved")

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot: Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self.start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self.itr_start_time))
        logger.dump_tabular(with_prefix=False)
        if self.plot:
            self.plotter.update_plot(self.policy, self.algo.max_path_length)
            if pause_for_plot:
                input("Plotting evaluation run: Press Enter to " "continue...")

    def train(self,
              n_epochs,
              n_epoch_cycles=1,
              batch_size=None,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs: Number of epochs.
            n_epoch_cycles: Number of batches of samples in each epoch.
                This is only useful for off-policy algorithms.
                For on-policy algorithms this value should always be 1.
            batch_size: Number of steps in batch.
            plot: Visualize policy by doing rollout after each epoch.
            store_paths: Save paths in snapshot.
            pause_for_plot: Pause for plot.

        Returns:
            The average return in last epoch cycle.

        """
        assert self.has_setup, "Use Runner.setup() to setup runner " \
                               "before training."
        if batch_size is None:
            from garage.tf.samplers import OffPolicyVectorizedSampler
            if isinstance(self.sampler, OffPolicyVectorizedSampler):
                batch_size = self.algo.max_path_length
            else:
                batch_size = 40 * self.algo.max_path_length

        self.n_epoch_cycles = n_epoch_cycles

        self.plot = plot
        self.start_worker()
        self.start_time = time.time()

        itr = 0
        last_return = None
        for epoch in range(n_epochs):
            self.itr_start_time = time.time()
            paths = None
            with logger.prefix('epoch #%d | ' % epoch):
                for cycle in range(n_epoch_cycles):
                    paths = self.obtain_samples(itr, batch_size)
                    paths = self.sampler.process_samples(itr, paths)
                    last_return = self.algo.train_once(itr, paths)
                    itr += 1
                self.save_snapshot(epoch, paths if store_paths else None)
                self.log_diagnostics(pause_for_plot)

        self.shutdown_worker()
        return last_return
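
initialize_tf_vars() above initializes only the variables the session reports as uninitialized, which is what lets policy weights loaded before setup() survive. Below is a standalone sketch of the same pattern with the name matching made explicit (TF1-style API, as used throughout this example):

# Standalone sketch of the "initialize only what is uninitialized" pattern
# from initialize_tf_vars() above (TF1 API).
import tensorflow as tf


def initialize_uninitialized_vars(sess):
    # report_uninitialized_variables() yields variable names as bytes.
    uninitialized = {
        name.decode()
        for name in sess.run(tf.report_uninitialized_variables())
    }
    to_init = [
        v for v in tf.global_variables()
        if v.name.split(':')[0] in uninitialized
    ]
    sess.run(tf.variables_initializer(to_init))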
Code Example #6
class BatchPolopt(RLAlgorithm):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo,
    etc.
    """
    def __init__(self,
                 env,
                 policy,
                 baseline,
                 scope=None,
                 n_itr=500,
                 max_samples=None,
                 start_itr=0,
                 batch_size=5000,
                 max_path_length=500,
                 discount=0.99,
                 gae_lambda=1,
                 plot=False,
                 pause_for_plot=False,
                 center_adv=True,
                 positive_adv=False,
                 store_paths=False,
                 paths_h5_filename=None,
                 whole_paths=True,
                 fixed_horizon=False,
                 sampler_cls=None,
                 sampler_args=None,
                 force_batch_sampler=False,
                 play_every_itr=None,
                 record_every_itr=None,
                 record_end_ep_num=3,
                 **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified
         if running multiple algorithms simultaneously, each using different
         environments and policies.
        :param n_itr: Max number of iterations.
        :param max_samples: If not None, exit once this many environment
         samples have been collected (overrides n_itr).
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have
         mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are
         always positive. When used in conjunction with center_adv the
         advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :return:
        """
        self.args = locals()
        del self.args["kwargs"]
        del self.args["self"]
        self.args = {**self.args, **kwargs}  #merging dicts

        self.env = env
        try:
            self.env.env.save_dyn_params(
                filename=logger.get_snapshot_dir().rstrip(os.sep) + os.sep +
                "dyn_params.yaml")
        except Exception:
            print("WARNING: BatchPolOpt: couldn't save dynamics params")
            # import pdb; pdb.set_trace()
        from gym.wrappers import Monitor
        # self.env_rec = Monitor(self.env.env, logger.get_snapshot_dir() + os.sep + "videos", force=True)

        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.max_samples = max_samples
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        self.play_every_itr = play_every_itr
        self.record_every_itr = record_every_itr
        self.record_end_ep_num = record_end_ep_num
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = OnPolicyVectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.init_opt()

        ## Initialization of HDF5 logging of trajectories
        if self.store_paths:
            self.h5_prepare_file(filename=paths_h5_filename, args=self.args)

        ## Initialize cleaner if we close
        atexit.register(self.clean_at_exit)

    def record_policy(self,
                      env,
                      policy,
                      itr,
                      n_rollout=1,
                      path=None,
                      postfix=""):
        # Rollout
        if path is None:
            path = logger.get_snapshot_dir().rstrip(
                os.sep) + os.sep + "videos" + os.sep + "itr_%05d%s.mp4" % (
                    itr, postfix)
        path_directory = path.rsplit(os.sep, 1)[0]
        if not os.path.exists(path_directory):
            os.makedirs(path_directory, exist_ok=True)
        for _ in range(n_rollout):
            obs = env.reset()
            recorder = VideoRecorder(env.env, path=path)
            while True:
                # env.render()
                # import pdb; pdb.set_trace()
                action, _ = policy.get_action(obs)
                obs, _, done, _ = env.step(action)
                recorder.capture_frame()
                if done:
                    break
            recorder.close()

    def play_policy(self, env, policy, n_rollout=2):
        # Rollout
        for _ in range(n_rollout):
            obs = env.reset()
            while True:
                env.render()
                action, _ = policy.get_action(obs)
                obs, _, done, _ = env.step(action)
                if done:
                    break

    @staticmethod
    def register_exit_handler(handler_fn):
        # Save will be executed upon normal exit of interpreter
        # NOTE: The functions registered via this module are not called when
        # the program is killed by a signal not handled by Python
        atexit.register(handler_fn)

    def clean_at_exit(self):
        # self.hdf.close()
        pass

    def h5_prepare_file(self, filename, args):
        # Assuming the following structure / indexing of the H5 file
        # teacher_info/
        #   - [teacher_indx]:
        #        - description
        #        - params
        # traj_data/
        #   - [teacher_indx] * [iter_indx] * traj_data

        # Making names and opening h5 file
        if filename is None:
            self.h5_filename = (logger.get_snapshot_dir() + os.sep +
                                "trajectories.h5")
        else:  # capability to store multiple teachers in a single file
            self.h5_filename = filename
        if not self.h5_filename.endswith('.h5'):
            self.h5_filename += '.h5'

        if os.path.exists(self.h5_filename):
            # input("WARNING: output file %s already exists and will be appended. Press ENTER to continue. (exit with ctrl-C)" % self.h5_filename)
            print(
                "WARNING: output file %s already exists and will be appended" %
                self.h5_filename)
        self.hdf = h5py.File(self.h5_filename, "a")

        # Creating proper groups
        groups = list(self.hdf.keys())
        # Groups to create: tuples of (group_name, structure_description)
        create_groups = [("teacher_info", "Runs indices(Teachers)"),
                         ("traj_data",
                          "Runs(Teachers) x Iterations x Trajectories x Data")]

        for group in create_groups:
            if group[0] not in groups:
                self.hdf.create_group(group[0])
                self.hdf[group[0]].attrs["structure"] = np.string_(group[1])

        # Checking if other teachers' results already exist in the h5 file
        # If they exist - just append
        teacher_indices = list(self.hdf["traj_data"].keys())
        if not teacher_indices:
            self.teacher_indx = 0
        else:
            teacher_indices = [int(indx) for indx in teacher_indices]
            teacher_indices = np.sort(teacher_indices)
            self.teacher_indx = teacher_indices[-1] + 1
            print("%s : Appended teacher index: " % self.__class__.__name__,
                  self.teacher_indx)

        self.hdf.create_group("traj_data/" +
                              h5u.indx2str(self.teacher_indx))  #Teacher group

        ## Saving info about the teacher
        teacher_info_group = "teacher_info/" + h5u.indx2str(
            self.teacher_indx) + "/"
        self.hdf.create_group(teacher_info_group)  #Teacher group
        h5u.add_dict(self.hdf, self.args, groupname=teacher_info_group)

        return self.hdf

    def start_worker(self, sess):
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env, self.policy, sess)
            self.plotter.start()

    def shutdown_worker(self):
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        return self.sampler.process_samples(itr, paths)

    def log_env_info(self, env_infos, prefix=""):
        # Logging rewards
        rew_dic = env_infos["rewards"]
        for key in rew_dic.keys():
            rew_sums = np.sum(rew_dic[key], axis=1)
            logger.record_tabular("rewards/" + key + "_avg", np.mean(rew_sums))
            logger.record_tabular("rewards/" + key + "_std", np.std(rew_sums))

    def train(self, sess=None):
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()
            sess.run(tf.global_variables_initializer())

        # Initialize some missing variables
        uninitialized_vars = []
        for var in tf.all_variables():
            try:
                sess.run(var)
            except tf.errors.FailedPreconditionError:
                print("Uninitialized var: ", var)
                uninitialized_vars.append(var)
        init_new_vars_op = tf.initialize_variables(uninitialized_vars)
        sess.run(init_new_vars_op)

        self.start_worker(sess)
        start_time = time.time()
        last_average_return = None
        samples_total = 0
        for itr in range(self.start_itr, self.n_itr):
            if (self.max_samples is not None
                    and samples_total >= self.max_samples):
                print("WARNING: Total max num of samples collected: %d >= %d" %
                      (samples_total, self.max_samples))
                break
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                samples_total += self.batch_size
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                # import pdb; pdb.set_trace()
                if self.store_paths:
                    ## WARN: Beware that data is saved to hdf in float32 by default
                    # see param float_nptype
                    h5u.append_train_iter_data(h5file=self.hdf,
                                               data=samples_data["paths"],
                                               data_group="traj_data/",
                                               teacher_indx=self.teacher_indx,
                                               itr=None,
                                               float_nptype=np.float32)
                    # params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                self.log_env_info(samples_data["env_infos"])
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input(
                            "Plotting evaluation run: Press Enter to continue..."
                        )
                # Showing policy from time to time
                if self.record_every_itr is not None and self.record_every_itr > 0 and itr % self.record_every_itr == 0:
                    self.record_policy(env=self.env,
                                       policy=self.policy,
                                       itr=itr)
                if self.play_every_itr is not None and self.play_every_itr > 0 and itr % self.play_every_itr == 0:
                    self.play_policy(env=self.env, policy=self.policy)

        # Recording a few episodes at the end
        if self.record_end_ep_num is not None:
            for i in range(self.record_end_ep_num):
                self.record_policy(env=self.env,
                                   policy=self.policy,
                                   itr=itr,
                                   postfix="_%02d" % i)

        # Reporting termination criteria
        if itr >= self.n_itr - 1:
            print(
                "TERM CRITERIA: Max number of iterations reached itr: %d , itr_max: %d"
                % (itr, self.n_itr - 1))
        if (self.max_samples is not None
                and samples_total >= self.max_samples):
            print(
                "TERM CRITERIA: Total max num of samples collected: %d >= %d" %
                (samples_total, self.max_samples))

        self.shutdown_worker()
        if created_session:
            sess.close()

    def log_diagnostics(self, paths):
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

        path_lengths = [path["returns"].size for path in paths]
        logger.record_tabular('ep_len_avg', np.mean(path_lengths))
        logger.record_tabular('ep_len_std', np.std(path_lengths))

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError
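
A note on the fragment above: log_diagnostics() simply reduces per-path arrays to scalar statistics before logging them. The standalone sketch below restates that calculation with plain NumPy; the paths structure (a list of dicts, each with a per-step "returns" array) is assumed from the fragment, not guaranteed by any API.

import numpy as np

def episode_length_stats(paths):
    """Per-episode length statistics, as logged by log_diagnostics() above."""
    lengths = [path["returns"].size for path in paths]
    return {"ep_len_avg": float(np.mean(lengths)),
            "ep_len_std": float(np.std(lengths))}

# Tiny usage example with fake paths:
fake_paths = [{"returns": np.zeros(10)}, {"returns": np.zeros(25)}]
print(episode_length_stats(fake_paths))  # {'ep_len_avg': 17.5, 'ep_len_std': 7.5}
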
Code Example #7
File: local_tf_runner.py  Project: wyjw/garage
class LocalRunner:
    """This class implements a local runner for tensorflow algorithms.

    A local runner provides a default tensorflow session using python context.
    This is useful for those experiment components (e.g. policy) that require a
    tensorflow session during construction.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner
    and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        max_cpus (int): The maximum number of parallel sampler workers.
        sess (tf.Session): An optional tensorflow session.
              A new session will be created immediately if not provided.

    Note:
        The local runner will set up a joblib task pool of size max_cpus,
        which may later be used by BatchSampler. If BatchSampler is not
        used, the processes in the pool will remain dormant.

        This setup is required to use tensorflow in a multiprocess
        environment before a tensorflow session is created
        because tensorflow is not fork-safe.

        See https://github.com/tensorflow/tensorflow/issues/2448.

    Examples:
        with LocalRunner() as runner:
            env = gym.make('CartPole-v1')
            policy = CategoricalMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=(32, 32))
            algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                max_kl_step=0.01)
            runner.setup(algo, env)
            runner.train(n_epochs=100, batch_size=4000)

    """
    def __init__(self, snapshot_config=None, sess=None, max_cpus=1):
        if snapshot_config:
            self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                            snapshot_config.snapshot_mode,
                                            snapshot_config.snapshot_gap)
        else:
            self._snapshotter = Snapshotter()

        if max_cpus > 1:
            from garage.sampler import singleton_pool
            singleton_pool.initialize(max_cpus)
        self.sess = sess or tf.Session()
        self.sess_entered = False
        self.has_setup = False
        self.plot = False

        self._setup_args = None
        self.train_args = None

    def __enter__(self):
        """Set self.sess as the default session.

        Returns:
            This local runner.

        """
        if tf.get_default_session() is not self.sess:
            self.sess.__enter__()
            self.sess_entered = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave session."""
        if tf.get_default_session() is self.sess and self.sess_entered:
            self.sess.__exit__(exc_type, exc_val, exc_tb)
            self.sess_entered = False

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.

        """
        self.algo = algo
        self.env = env
        self.policy = self.algo.policy

        if sampler_args is None:
            sampler_args = {}

        if sampler_cls is None:
            from garage.tf.algos.batch_polopt import BatchPolopt
            if isinstance(algo, BatchPolopt):
                if self.policy.vectorized:
                    from garage.tf.samplers import OnPolicyVectorizedSampler
                    sampler_cls = OnPolicyVectorizedSampler
                else:
                    from garage.tf.samplers import BatchSampler
                    sampler_cls = BatchSampler
            else:
                from garage.tf.samplers import OffPolicyVectorizedSampler
                sampler_cls = OffPolicyVectorizedSampler

        self.sampler = sampler_cls(algo, env, **sampler_args)

        self.initialize_tf_vars()
        logger.log(self.sess.graph)
        self.has_setup = True

        self._setup_args = types.SimpleNamespace(sampler_cls=sampler_cls,
                                                 sampler_args=sampler_args)

    def initialize_tf_vars(self):
        """Initialize all uninitialized variables in session."""
        with tf.name_scope('initialize_tf_vars'):
            uninited_set = [
                e.decode()
                for e in self.sess.run(tf.report_uninitialized_variables())
            ]
            self.sess.run(
                tf.variables_initializer([
                    v for v in tf.global_variables()
                    if v.name.split(':')[0] in uninited_set
                ]))

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self.sampler.start_worker()
        if self.plot:
            from garage.tf.plotter import Plotter
            self.plotter = Plotter(self.env, self.policy)
            self.plotter.start()

    def _shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr, batch_size):
        """Obtain one batch of samples.

        Args:
            itr(int): Index of iteration (epoch).
            batch_size(int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            One batch of samples.

        """
        if self.train_args.n_epoch_cycles == 1:
            logger.log('Obtaining samples...')
        return self.sampler.obtain_samples(itr, batch_size)

    def save(self, epoch, paths=None):
        """Save snapshot of current batch.

        Args:
            epoch(int): Index of the epoch to save.
            paths(dict): Batch of samples after preprocessing. If None,
                no paths will be logged to the snapshot.

        """
        if not self.has_setup:
            raise Exception('Use setup() to setup runner before saving.')

        logger.log('Saving snapshot...')

        params = dict()
        # Save arguments
        params['setup_args'] = self._setup_args
        params['train_args'] = self.train_args

        # Save states
        params['env'] = self.env
        params['algo'] = self.algo
        if paths:
            params['paths'] = paths
        params['last_epoch'] = epoch
        self._snapshotter.save_itr_params(epoch, params)

        logger.log('Saved')

    def restore(self, from_dir, from_epoch='last'):
        """Restore experiment from snapshot.

        Args:
            from_dir (str): Directory of the pickle file
                to resume experiment from.
            from_epoch (str or int): The epoch to restore from.
                Can be 'first', 'last' or a number.
                Not applicable when snapshot_mode='last'.

        Returns:
            A SimpleNamespace for train()'s arguments.

        Examples:
            1. Resume experiment immediately.
            with LocalRunner() as runner:
                runner.restore(resume_from_dir)
                runner.resume()

            2. Resume experiment with modified training arguments.
            with LocalRunner() as runner:
                runner.restore(resume_from_dir)
                runner.resume(n_epochs=20)

        Note:
            When resuming via the command line, new snapshots will be
            saved into the SAME directory if not specified otherwise.

            When resuming programmatically, the snapshot directory should
            be specified manually or through the run_experiment() interface.

        """
        saved = self._snapshotter.load(from_dir, from_epoch)

        self._setup_args = saved['setup_args']
        self.train_args = saved['train_args']

        self.setup(env=saved['env'],
                   algo=saved['algo'],
                   sampler_cls=self._setup_args.sampler_cls,
                   sampler_args=self._setup_args.sampler_args)

        n_epochs = self.train_args.n_epochs
        last_epoch = saved['last_epoch']
        n_epoch_cycles = self.train_args.n_epoch_cycles
        batch_size = self.train_args.batch_size
        store_paths = self.train_args.store_paths
        pause_for_plot = self.train_args.pause_for_plot

        fmt = '{:<20} {:<15}'
        logger.log('Restore from snapshot saved in %s' %
                   self._snapshotter.snapshot_dir)
        logger.log(fmt.format('Train Args', 'Value'))
        logger.log(fmt.format('n_epochs', n_epochs))
        logger.log(fmt.format('last_epoch', last_epoch))
        logger.log(fmt.format('n_epoch_cycles', n_epoch_cycles))
        logger.log(fmt.format('batch_size', batch_size))
        logger.log(fmt.format('store_paths', store_paths))
        logger.log(fmt.format('pause_for_plot', pause_for_plot))

        self.train_args.start_epoch = last_epoch + 1
        return copy.copy(self.train_args)

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot(bool): Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self._start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time))
        logger.log(tabular)
        if self.plot:
            self.plotter.update_plot(self.policy, self.algo.max_path_length)
            if pause_for_plot:
                input('Plotting evaluation run: Press Enter to continue...')

    def train(self,
              n_epochs,
              batch_size,
              n_epoch_cycles=1,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs(int): Number of epochs.
            batch_size(int): Number of environment steps in one batch.
            n_epoch_cycles(int): Number of batches of samples in each epoch.
                This is only useful for off-policy algorithms.
                For on-policy algorithms this value should always be 1.
            plot(bool): Visualize policy by doing rollout after each epoch.
            store_paths(bool): Save paths in snapshot.
            pause_for_plot(bool): Pause for plot.

        Returns:
            The average return in last epoch cycle.

        """
        if not self.has_setup:
            raise Exception('Use setup() to setup runner before training.')

        # Save arguments for restore
        self.train_args = types.SimpleNamespace(n_epochs=n_epochs,
                                                n_epoch_cycles=n_epoch_cycles,
                                                batch_size=batch_size,
                                                plot=plot,
                                                store_paths=store_paths,
                                                pause_for_plot=pause_for_plot,
                                                start_epoch=0)

        self.plot = plot

        return self.algo.train(self, batch_size)

    def step_epochs(self):
        """Generator for training.

        This function serves as a generator. It is used to separate
        services such as snapshotting, sampler control from the actual
        training loop. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        try:
            self._start_worker()
            self._start_time = time.time()
            self.step_itr = (self.train_args.start_epoch *
                             self.train_args.n_epoch_cycles)
            self.step_path = None

            for epoch in range(self.train_args.start_epoch,
                               self.train_args.n_epochs):
                self._itr_start_time = time.time()
                with logger.prefix('epoch #%d | ' % epoch):
                    yield epoch
                    save_path = (self.step_path
                                 if self.train_args.store_paths else None)
                    self.save(epoch, save_path)
                    self.log_diagnostics(self.train_args.pause_for_plot)
                    logger.dump_all(self.step_itr)
                    tabular.clear()
        finally:
            self._shutdown_worker()

    def resume(self,
               n_epochs=None,
               batch_size=None,
               n_epoch_cycles=None,
               plot=None,
               store_paths=None,
               pause_for_plot=None):
        """Resume from restored experiment.

        This method provides the same interface as train().

        If not specified, an argument will default to the
        saved arguments from the last call to train().

        Returns:
            The average return in last epoch cycle.

        """
        assert self.train_args is not None, (
            'You must call restore() before resume().')

        self.train_args.n_epochs = n_epochs or self.train_args.n_epochs
        self.train_args.batch_size = batch_size or self.train_args.batch_size
        self.train_args.n_epoch_cycles = (n_epoch_cycles
                                          or self.train_args.n_epoch_cycles)

        if plot is not None:
            self.train_args.plot = plot
        if store_paths is not None:
            self.train_args.store_paths = store_paths
        if pause_for_plot is not None:
            self.train_args.pause_for_plot = pause_for_plot

        return self.algo.train(self, self.train_args.batch_size)
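
The step_epochs() generator earlier in this example is what concrete algorithms drive from their train() methods: the training code runs in the caller between yields, while snapshotting and logging happen inside the generator after control returns. Below is a minimal, framework-free sketch of that same pattern; save_fn and log_fn are hypothetical placeholders, not garage APIs.

import time

def step_epochs(n_epochs, save_fn, log_fn, start_epoch=0):
    """Sketch of the generator pattern used by LocalRunner.step_epochs()."""
    start_time = time.time()
    for epoch in range(start_epoch, n_epochs):
        epoch_start = time.time()
        yield epoch                      # caller trains for one epoch here
        save_fn(epoch)                   # e.g. snapshot the current params
        log_fn(epoch, time.time() - epoch_start, time.time() - start_time)

# Usage, mirroring the Examples block in the step_epochs() docstring:
for epoch in step_epochs(
        n_epochs=3,
        save_fn=lambda e: None,
        log_fn=lambda e, et, tt: print(f"epoch {e}: {et:.4f}s / {tt:.4f}s")):
    pass  # obtain_samples(...) / train_once(...) would go here
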
Code Example #8
class DDPG(RLAlgorithm):
    """
    A DDPG model based on https://arxiv.org/pdf/1509.02971.pdf.

    Example:
        $ python garage/examples/tf/ddpg_pendulum.py
    """

    def __init__(self,
                 env,
                 actor,
                 critic,
                 n_epochs=500,
                 n_epoch_cycles=20,
                 n_rollout_steps=100,
                 n_train_steps=50,
                 reward_scale=1.,
                 batch_size=64,
                 target_update_tau=0.01,
                 discount=0.99,
                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 actor_weight_decay=0,
                 critic_weight_decay=0,
                 replay_buffer_size=int(1e6),
                 min_buffer_size=10000,
                 exploration_strategy=None,
                 plot=False,
                 pause_for_plot=False,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 use_her=False,
                 clip_obs=np.inf,
                 clip_pos_returns=True,
                 clip_return=None,
                 replay_k=4,
                 max_action=None,
                 name=None):
        """
        Construct class.

        Args:
            env(): Environment.
            actor(garage.tf.policies.ContinuousMLPPolicy): Policy network.
            critic(garage.tf.q_functions.ContinuousMLPQFunction): Q-value
                network.
            n_epochs(int, optional): Number of epochs.
            n_epoch_cycles(int, optional): Number of epoch cycles.
            n_rollout_steps(int, optional): Number of rollout steps, i.e. the
                time horizon of a rollout.
            n_train_steps(int, optional): Number of train steps.
            reward_scale(float): The scaling factor applied to the rewards
                when training.
            batch_size(int): Number of samples for each minibatch.
            target_update_tau(float): Interpolation parameter for the soft
                target update.
            discount(float): Discount factor for the cumulative return.
            actor_lr(float): Learning rate for training the policy network.
            critic_lr(float): Learning rate for training the Q-value network.
            actor_weight_decay(float): L2 weight decay factor for parameters
                of the policy network.
            critic_weight_decay(float): L2 weight decay factor for parameters
                of the Q-value network.
            replay_buffer_size(int): Size of the replay buffer.
            min_buffer_size(int): Minimum size of the replay buffer to start
                training.
            exploration_strategy(): Exploration strategy used to randomize
                the action.
            plot(boolean): Whether to visualize the policy performance after
                each eval_interval.
            pause_for_plot(boolean): Whether or not to pause before
                continuing when plotting.
            actor_optimizer(): Optimizer for training the policy network.
            critic_optimizer(): Optimizer for training the Q-function network.
            use_her(boolean): Whether or not to use HER for the replay buffer.
            clip_obs(float): Clip observations to be in [-clip_obs, clip_obs].
            clip_pos_returns(boolean): Whether or not to clip positive
                returns.
            clip_return(float): Clip returns to be in [-clip_return,
                clip_return].
            replay_k(int): The ratio between HER replays and regular replays.
                Only used when use_her is True.
            max_action(float): Maximum action magnitude.
            name(str): Name of the algorithm shown in the computation graph.
        """
        self.env = env

        self.input_dims = configure_dims(env)
        action_bound = env.action_space.high
        self.max_action = action_bound if max_action is None else max_action

        self.actor = actor
        self.critic = critic
        self.n_epochs = n_epochs
        self.n_epoch_cycles = n_epoch_cycles
        self.n_rollout_steps = n_rollout_steps
        self.n_train_steps = n_train_steps
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.tau = target_update_tau
        self.discount = discount
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.actor_weight_decay = actor_weight_decay
        self.critic_weight_decay = critic_weight_decay
        self.replay_buffer_size = replay_buffer_size
        self.min_buffer_size = min_buffer_size
        self.es = exploration_strategy
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.name = name
        self.use_her = use_her
        self.evaluate = False
        self.replay_k = replay_k
        self.clip_return = (
            1. / (1. - self.discount)) if clip_return is None else clip_return
        self.clip_obs = clip_obs
        self.clip_pos_returns = clip_pos_returns
        self.success_history = deque(maxlen=100)
        self._initialize()

    @overrides
    def train(self, sess=None):
        """
        Training process of DDPG algorithm.

        Args:
            sess: A TensorFlow session for executing ops.
        """
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        # Start plotter
        if self.plot:
            self.plotter = Plotter(self.env, self.actor, sess)
            self.plotter.start()

        sess.run(tf.global_variables_initializer())
        self.f_init_target()

        observation = self.env.reset()
        if self.es:
            self.es.reset()

        episode_reward = 0.
        episode_step = 0
        episode_rewards = []
        episode_steps = []
        episode_actor_losses = []
        episode_critic_losses = []
        episodes = 0
        epoch_ys = []
        epoch_qs = []

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            self.success_history.clear()
            for epoch_cycle in pyprind.prog_bar(range(self.n_epoch_cycles)):
                if self.use_her:
                    successes = []
                    for rollout in range(self.n_rollout_steps):
                        o = np.clip(observation["observation"], -self.clip_obs,
                                    self.clip_obs)
                        g = np.clip(observation["desired_goal"],
                                    -self.clip_obs, self.clip_obs)
                        obs_goal = np.concatenate((o, g), axis=-1)
                        action = self.es.get_action(rollout, obs_goal,
                                                    self.actor)

                        next_observation, reward, terminal, info = self.env.step(  # noqa: E501
                            action)
                        if 'is_success' in info:
                            successes.append([info["is_success"]])
                        episode_reward += reward
                        episode_step += 1

                        info_dict = {
                            "info_{}".format(key): info[key].reshape(1)
                            for key in info.keys()
                        }
                        self.replay_buffer.add_transition(
                            observation=observation['observation'],
                            action=action,
                            goal=observation['desired_goal'],
                            achieved_goal=observation['achieved_goal'],
                            **info_dict,
                        )

                        observation = next_observation

                        if rollout == self.n_rollout_steps - 1:
                            self.replay_buffer.add_transition(
                                observation=observation['observation'],
                                achieved_goal=observation['achieved_goal'])

                            episode_rewards.append(episode_reward)
                            episode_steps.append(episode_step)
                            episode_reward = 0.
                            episode_step = 0
                            episodes += 1

                            observation = self.env.reset()
                            if self.es:
                                self.es.reset()

                    successful = np.array(successes)[-1, :]
                    success_rate = np.mean(successful)
                    self.success_history.append(success_rate)

                    for train_itr in range(self.n_train_steps):
                        self.evaluate = True
                        critic_loss, y, q, action_loss = self._learn()

                        episode_actor_losses.append(action_loss)
                        episode_critic_losses.append(critic_loss)
                        epoch_ys.append(y)
                        epoch_qs.append(q)

                    self.f_update_target()
                else:
                    for rollout in range(self.n_rollout_steps):
                        action = self.es.get_action(rollout, observation,
                                                    self.actor)
                        assert action.shape == self.env.action_space.shape

                        next_observation, reward, terminal, info = self.env.step(  # noqa: E501
                            action)
                        episode_reward += reward
                        episode_step += 1

                        self.replay_buffer.add_transition(
                            observation=observation,
                            action=action,
                            reward=reward * self.reward_scale,
                            terminal=terminal,
                            next_observation=next_observation,
                        )

                        observation = next_observation

                        if terminal or rollout == self.n_rollout_steps - 1:
                            episode_rewards.append(episode_reward)
                            episode_steps.append(episode_step)
                            episode_reward = 0.
                            episode_step = 0
                            episodes += 1

                            observation = self.env.reset()
                            if self.es:
                                self.es.reset()

                    for train_itr in range(self.n_train_steps):
                        if self.replay_buffer.size >= self.min_buffer_size:
                            self.evaluate = True
                            critic_loss, y, q, action_loss = self._learn()

                            episode_actor_losses.append(action_loss)
                            episode_critic_losses.append(critic_loss)
                            epoch_ys.append(y)
                            epoch_qs.append(q)

            logger.log("Training finished")
            logger.log("Saving snapshot")
            itr = epoch * self.n_epoch_cycles + epoch_cycle
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            if self.evaluate:
                logger.record_tabular('Epoch', epoch)
                logger.record_tabular('Episodes', episodes)
                logger.record_tabular('AverageReturn',
                                      np.mean(episode_rewards))
                logger.record_tabular('StdReturn', np.std(episode_rewards))
                logger.record_tabular('Policy/AveragePolicyLoss',
                                      np.mean(episode_actor_losses))
                logger.record_tabular('QFunction/AverageQFunctionLoss',
                                      np.mean(episode_critic_losses))
                logger.record_tabular('QFunction/AverageQ', np.mean(epoch_qs))
                logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                logger.record_tabular('QFunction/AverageAbsQ',
                                      np.mean(np.abs(epoch_qs)))
                logger.record_tabular('QFunction/AverageY', np.mean(epoch_ys))
                logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                logger.record_tabular('QFunction/AverageAbsY',
                                      np.mean(np.abs(epoch_ys)))
                if self.use_her:
                    logger.record_tabular('AverageSuccessRate',
                                          np.mean(self.success_history))

                # Uncomment the following to reset the running statistics each
                # epoch (per-epoch averages); recommended when self.use_her is
                # True.
                # episode_rewards = []
                # episode_actor_losses = []
                # episode_critic_losses = []
                # epoch_ys = []
                # epoch_qs = []

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.actor, self.n_rollout_steps)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        if self.plot:
            self.plotter.close()
        if created_session:
            sess.close()

    def _initialize(self):
        with tf.name_scope(self.name, "DDPG"):
            with tf.name_scope("setup_networks"):
                """Set up the actor, critic and target network."""
                # Set up the actor and critic network
                self.actor._build_net(trainable=True)
                self.critic._build_net(trainable=True)

                # Create target actor and critic network
                target_actor = copy(self.actor)
                target_critic = copy(self.critic)

                # Set up the target network
                target_actor.name = "TargetActor"
                target_actor._build_net(trainable=False)
                target_critic.name = "TargetCritic"
                target_critic._build_net(trainable=False)

            input_shapes = dims_to_shapes(self.input_dims)

            # Initialize replay buffer
            if self.use_her:
                buffer_shapes = {
                    key: (self.n_rollout_steps + 1
                          if key == "observation" or key == "achieved_goal"
                          else self.n_rollout_steps, *input_shapes[key])
                    for key, val in input_shapes.items()
                }

                replay_buffer = HerReplayBuffer(
                    buffer_shapes=buffer_shapes,
                    size_in_transitions=self.replay_buffer_size,
                    time_horizon=self.n_rollout_steps,
                    sample_transitions=make_her_sample(
                        self.replay_k, self.env.compute_reward))
            else:
                replay_buffer = ReplayBuffer(
                    buffer_shapes=input_shapes,
                    max_buffer_size=self.replay_buffer_size)

            # Set up target init and update function
            with tf.name_scope("setup_target"):
                actor_init_ops, actor_update_ops = get_target_ops(
                    self.actor.global_vars, target_actor.global_vars, self.tau)
                critic_init_ops, critic_update_ops = get_target_ops(
                    self.critic.global_vars, target_critic.global_vars,
                    self.tau)
                target_init_op = actor_init_ops + critic_init_ops
                target_update_op = actor_update_ops + critic_update_ops

            f_init_target = tensor_utils.compile_function(
                inputs=[], outputs=target_init_op)
            f_update_target = tensor_utils.compile_function(
                inputs=[], outputs=target_update_op)

            with tf.name_scope("inputs"):
                obs_dim = (
                    self.input_dims["observation"] + self.input_dims["goal"]
                ) if self.use_her else self.input_dims["observation"]
                y = tf.placeholder(tf.float32, shape=(None, 1), name="input_y")
                obs = tf.placeholder(
                    tf.float32,
                    shape=(None, obs_dim),
                    name="input_observation")
                actions = tf.placeholder(
                    tf.float32,
                    shape=(None, self.input_dims["action"]),
                    name="input_action")

            # Set up actor training function
            next_action = self.actor.get_action_sym(obs, name="actor_action")
            next_qval = self.critic.get_qval_sym(
                obs, next_action, name="actor_qval")
            with tf.name_scope("action_loss"):
                action_loss = -tf.reduce_mean(next_qval)
                if self.actor_weight_decay > 0.:
                    actor_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.actor_weight_decay),
                        weights_list=self.actor.regularizable_vars)
                    action_loss += actor_reg

            with tf.name_scope("minimize_action_loss"):
                actor_train_op = self.actor_optimizer(
                    self.actor_lr, name="ActorOptimizer").minimize(
                        action_loss, var_list=self.actor.trainable_vars)

            f_train_actor = tensor_utils.compile_function(
                inputs=[obs], outputs=[actor_train_op, action_loss])

            # Set up critic training function
            qval = self.critic.get_qval_sym(obs, actions, name="q_value")
            with tf.name_scope("qval_loss"):
                qval_loss = tf.reduce_mean(tf.squared_difference(y, qval))
                if self.critic_weight_decay > 0.:
                    critic_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.critic_weight_decay),
                        weights_list=self.critic.regularizable_vars)
                    qval_loss += critic_reg

            with tf.name_scope("minimize_critic_loss"):
                critic_train_op = self.critic_optimizer(
                    self.critic_lr, name="CriticOptimizer").minimize(
                        qval_loss, var_list=self.critic.trainable_vars)

            f_train_critic = tensor_utils.compile_function(
                inputs=[y, obs, actions],
                outputs=[critic_train_op, qval_loss, qval])

            self.f_train_actor = f_train_actor
            self.f_train_critic = f_train_critic
            self.f_init_target = f_init_target
            self.f_update_target = f_update_target
            self.replay_buffer = replay_buffer
            self.target_critic = target_critic
            self.target_actor = target_actor

    def _learn(self):
        """
        Perform one step of algorithm optimization.

        Returns:
            qval_loss: Loss of the Q value predicted by the Q network.
            ys: The TD targets (y values).
            qval: Q value predicted by the Q network.
            action_loss: Loss of the action predicted by the policy network.

        """
        if self.use_her:
            transitions = self.replay_buffer.sample(self.batch_size)
            observations = transitions["observation"]
            rewards = transitions["reward"]
            actions = transitions["action"]
            next_observations = transitions["next_observation"]
            goals = transitions["goal"]

            next_inputs = np.concatenate((next_observations, goals), axis=-1)
            inputs = np.concatenate((observations, goals), axis=-1)

            rewards = rewards.reshape(-1, 1)

            target_actions, _ = self.target_actor.get_actions(next_inputs)
            target_qvals = self.target_critic.get_qval(next_inputs,
                                                       target_actions)

            clip_range = (-self.clip_return, 0.
                          if self.clip_pos_returns else np.inf)
            ys = np.clip(rewards + self.discount * target_qvals, clip_range[0],
                         clip_range[1])

            _, qval_loss, qval = self.f_train_critic(ys, inputs, actions)
            _, action_loss = self.f_train_actor(inputs)
        else:
            transitions = self.replay_buffer.sample(self.batch_size)
            observations = transitions["observation"]
            rewards = transitions["reward"]
            actions = transitions["action"]
            terminals = transitions["terminal"]
            next_observations = transitions["next_observation"]

            rewards = rewards.reshape(-1, 1)
            terminals = terminals.reshape(-1, 1)

            target_actions, _ = self.target_actor.get_actions(
                next_observations)
            target_qvals = self.target_critic.get_qval(next_observations,
                                                       target_actions)

            ys = rewards + (1.0 - terminals) * self.discount * target_qvals

            _, qval_loss, qval = self.f_train_critic(ys, observations, actions)
            _, action_loss = self.f_train_actor(observations)
            self.f_update_target()

        return qval_loss, ys, qval, action_loss

    def get_itr_snapshot(self, itr):
        return dict(itr=itr, policy=self.actor, env=self.env)
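
The non-HER branch of _learn() above is the standard DDPG update: a TD(0) target built from the target networks, followed by a Polyak-averaged ("soft") target update. The sketch below restates just those two calculations in plain NumPy; it is illustrative only and does not reuse any garage identifiers.

import numpy as np

def soft_update(target_params, online_params, tau):
    """Polyak averaging (the role of f_update_target above):
    target <- tau * online + (1 - tau) * target."""
    return [(1.0 - tau) * t + tau * o
            for t, o in zip(target_params, online_params)]

def td_targets(rewards, terminals, target_qvals, discount):
    """TD(0) targets as computed in the non-HER branch of _learn():
    y = r + (1 - done) * gamma * Q_target(s', mu_target(s'))."""
    rewards = rewards.reshape(-1, 1)
    terminals = terminals.reshape(-1, 1)
    return rewards + (1.0 - terminals) * discount * target_qvals

# Small usage example with made-up numbers:
ys = td_targets(rewards=np.array([1.0, 0.0]),
                terminals=np.array([0.0, 1.0]),
                target_qvals=np.array([[2.0], [5.0]]),
                discount=0.99)
print(ys)  # [[2.98], [0.  ]]
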