Example #1
class TaskEmbeddingRunner(LocalRunner):
    def setup(self, algo, env, batch_size, max_path_length):
        n_envs = batch_size // max_path_length
        return super().setup(algo=algo,
                             env=env,
                             sampler_cls=TaskEmbeddingSampler,
                             sampler_args=dict(n_envs=n_envs))

    def train(self,
              n_epochs,
              n_epoch_cycles=1,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        return super().train(n_epochs,
                             n_epoch_cycles=n_epoch_cycles,
                             plot=plot,
                             store_paths=store_paths,
                             pause_for_plot=pause_for_plot)

    def start_worker(self):
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env, self.policy, rollout=rollout)
            self.plotter.start()
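A minimal usage sketch for the runner above, assuming a garage-style snapshot configuration and an algorithm that exposes max_path_length; MultiTaskEnv, make_policy, and make_algo are hypothetical placeholders rather than names from the example:

# Hypothetical usage sketch -- MultiTaskEnv, make_policy and make_algo are
# placeholders; only the runner API comes from the class above.
runner = TaskEmbeddingRunner(snapshot_config)   # snapshot_config per LocalRunner
env = MultiTaskEnv(...)                         # must expose env.task_space
policy = make_policy(env)
algo = make_algo(env=env, policy=policy, max_path_length=100)

# setup() computes n_envs = batch_size // max_path_length and wires in
# TaskEmbeddingSampler before delegating to LocalRunner.setup().
runner.setup(algo=algo, env=env, batch_size=4000, max_path_length=100)
runner.train(n_epochs=100, n_epoch_cycles=1)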
Example #2
class NPOTaskEmbedding(BatchPolopt, Serializable):
    """
    Natural Policy Optimization with Task Embeddings
    """
    def __init__(self,
                 name="NPOTaskEmbedding",
                 pg_loss=PGLoss.VANILLA,
                 kl_constraint=None,
                 optimizer=LbfgsOptimizer,
                 optimizer_args=dict(),
                 step_size=0.01,
                 num_minibatches=None,
                 num_opt_epochs=None,
                 policy=None,
                 policy_ent_coeff=1e-2,
                 embedding_ent_coeff=1e-5,
                 inference=None,
                 inference_optimizer=LbfgsOptimizer,
                 inference_optimizer_args=dict(),
                 inference_ce_coeff=1e-3,
                 use_softplus_entropy=False,
                 save_sample_frequency=0,
                 stop_ce_gradient=False,
                 **kwargs):
        Serializable.quick_init(self, locals())
        assert kwargs['env'].task_space
        assert isinstance(policy, StochasticMultitaskPolicy)
        assert isinstance(inference, StochasticEmbedding)

        self.name = name
        self._name_scope = tf.name_scope(self.name)

        self._pg_loss = pg_loss
        self._policy_opt_inputs = None
        self._inference_opt_inputs = None
        self._use_softplus_entropy = use_softplus_entropy

        self._save_sample_frequency = save_sample_frequency
        self._stop_ce_gradient = stop_ce_gradient

        with tf.name_scope(self.name):
            # Optimizer for policy and embedding networks
            self.optimizer = optimizer(**optimizer_args)
            self.step_size = float(step_size)
            self.policy_ent_coeff = float(policy_ent_coeff)
            self.embedding_ent_coeff = float(embedding_ent_coeff)

            self.inference = inference
            self.inference_ce_coeff = inference_ce_coeff
            self.inference_optimizer = inference_optimizer(
                **inference_optimizer_args)

            sampler_cls = TaskEmbeddingSampler
            sampler_args = dict(inference=self.inference)
            super(NPOTaskEmbedding, self).__init__(sampler_cls=sampler_cls,
                                                   sampler_args=sampler_args,
                                                   policy=policy,
                                                   **kwargs)

    @overrides
    def start_worker(self, sess):
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env,
                                   self.policy,
                                   sess=sess,
                                   rollout=rollout)
            self.plotter.start()

    @overrides
    def init_opt(self):
        if self.policy.recurrent:
            raise NotImplementedError

        # Input variables
        (pol_loss_inputs, pol_opt_inputs, infer_loss_inputs,
         infer_opt_inputs) = self._build_inputs()

        self._policy_opt_inputs = pol_opt_inputs
        self._inference_opt_inputs = infer_opt_inputs

        # Jointly optimize policy and embedding network
        pol_loss, pol_kl, embed_kl = self._build_policy_loss(pol_loss_inputs)
        self.optimizer.update_opt(loss=pol_loss,
                                  target=self.policy,
                                  leq_constraint=(pol_kl, self.step_size),
                                  inputs=flatten_inputs(
                                      self._policy_opt_inputs),
                                  constraint_name="mean_kl")

        # Optimize inference distribution separately (supervised learning)
        infer_loss, infer_kl = self._build_inference_loss(infer_loss_inputs)
        self.inference_optimizer.update_opt(loss=infer_loss,
                                            target=self.inference,
                                            inputs=flatten_inputs(
                                                self._inference_opt_inputs))

        return dict()

    #### Loss function network ################################################

    def _build_inputs(self):
        """
        Builds input variables (and trivial views thereof) for the loss
        function network
        """

        observation_space = self.policy.observation_space
        action_space = self.policy.action_space
        task_space = self.policy.task_space
        latent_space = self.policy.latent_space
        trajectory_space = self.inference.input_space

        policy_dist = self.policy._dist
        embed_dist = self.policy.embedding._dist
        infer_dist = self.inference._dist

        with tf.name_scope("inputs"):
            obs_var = observation_space.new_tensor_variable(
                'obs',
                extra_dims=1 + 1,
            )

            task_var = task_space.new_tensor_variable(
                'task',
                extra_dims=1 + 1,
            )

            action_var = action_space.new_tensor_variable(
                'action',
                extra_dims=1 + 1,
            )

            reward_var = tensor_utils.new_tensor(
                'reward',
                ndim=1 + 1,
                dtype=tf.float32,
            )

            latent_var = latent_space.new_tensor_variable(
                'latent',
                extra_dims=1 + 1,
            )

            baseline_var = tensor_utils.new_tensor(
                'baseline',
                ndim=1 + 1,
                dtype=tf.float32,
            )

            trajectory_var = trajectory_space.new_tensor_variable(
                'trajectory',
                extra_dims=1 + 1,
            )

            valid_var = tf.placeholder(tf.float32,
                                       shape=[None, None],
                                       name="valid")

            # Policy state (for RNNs)
            policy_state_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name=k)
                for k, shape in self.policy.state_info_specs
            }
            policy_state_info_vars_list = [
                policy_state_info_vars[k] for k in self.policy.state_info_keys
            ]

            # Old policy distribution (for KL)
            policy_old_dist_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='policy_old_%s' % k)
                for k, shape in policy_dist.dist_info_specs
            }
            policy_old_dist_info_vars_list = [
                policy_old_dist_info_vars[k]
                for k in policy_dist.dist_info_keys
            ]

            # Embedding state (for RNNs)
            embed_state_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='embed_%s' % k)
                for k, shape in self.policy.embedding.state_info_specs
            }
            embed_state_info_vars_list = [
                embed_state_info_vars[k]
                for k in self.policy.embedding.state_info_keys
            ]

            # Old embedding distribution (for KL)
            embed_old_dist_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='embed_old_%s' % k)
                for k, shape in embed_dist.dist_info_specs
            }
            embed_old_dist_info_vars_list = [
                embed_old_dist_info_vars[k] for k in embed_dist.dist_info_keys
            ]

            # Inference distribution state (for RNNs)
            infer_state_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='infer_%s' % k)
                for k, shape in self.inference.state_info_specs
            }
            infer_state_info_vars_list = [
                infer_state_info_vars[k]
                for k in self.inference.state_info_keys
            ]

            # Old inference distribution (for KL)
            infer_old_dist_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='infer_old_%s' % k)
                for k, shape in infer_dist.dist_info_specs
            }
            infer_old_dist_info_vars_list = [
                infer_old_dist_info_vars[k] for k in infer_dist.dist_info_keys
            ]

            # Flattened view
            with tf.name_scope("flat"):
                obs_flat = flatten_batch(obs_var, name="obs_flat")
                task_flat = flatten_batch(task_var, name="task_flat")
                action_flat = flatten_batch(action_var, name="action_flat")
                reward_flat = flatten_batch(reward_var, name="reward_flat")
                latent_flat = flatten_batch(latent_var, name="latent_flat")
                trajectory_flat = flatten_batch(trajectory_var,
                                                name="trajectory_flat")
                valid_flat = flatten_batch(valid_var, name="valid_flat")
                policy_state_info_vars_flat = flatten_batch_dict(
                    policy_state_info_vars, name="policy_state_info_vars_flat")
                policy_old_dist_info_vars_flat = flatten_batch_dict(
                    policy_old_dist_info_vars,
                    name="policy_old_dist_info_vars_flat")
                embed_state_info_vars_flat = flatten_batch_dict(
                    embed_state_info_vars, name="embed_state_info_vars_flat")
                embed_old_dist_info_vars_flat = flatten_batch_dict(
                    embed_old_dist_info_vars,
                    name="embed_old_dist_info_vars_flat")
                infer_state_info_vars_flat = flatten_batch_dict(
                    infer_state_info_vars, name="infer_state_info_vars_flat")
                infer_old_dist_info_vars_flat = flatten_batch_dict(
                    infer_old_dist_info_vars,
                    name="infer_old_dist_info_vars_flat")

            # Valid view
            with tf.name_scope("valid"):
                action_valid = filter_valids(action_flat,
                                             valid_flat,
                                             name="action_valid")
                policy_state_info_vars_valid = filter_valids_dict(
                    policy_state_info_vars_flat,
                    valid_flat,
                    name="policy_state_info_vars_valid")
                policy_old_dist_info_vars_valid = filter_valids_dict(
                    policy_old_dist_info_vars_flat,
                    valid_flat,
                    name="policy_old_dist_info_vars_valid")
                embed_old_dist_info_vars_valid = filter_valids_dict(
                    embed_old_dist_info_vars_flat,
                    valid_flat,
                    name="embed_old_dist_info_vars_valid")
                infer_old_dist_info_vars_valid = filter_valids_dict(
                    infer_old_dist_info_vars_flat,
                    valid_flat,
                    name="infer_old_dist_info_vars_valid")

        # Policy and embedding network loss and optimizer inputs
        pol_flat = graph_inputs(
            "PolicyLossInputsFlat",
            obs_var=obs_flat,
            task_var=task_flat,
            action_var=action_flat,
            reward_var=reward_flat,
            latent_var=latent_flat,
            trajectory_var=trajectory_flat,
            valid_var=valid_flat,
            policy_state_info_vars=policy_state_info_vars_flat,
            policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
            embed_state_info_vars=embed_state_info_vars_flat,
            embed_old_dist_info_vars=embed_old_dist_info_vars_flat,
        )
        pol_valid = graph_inputs(
            "PolicyLossInputsValid",
            action_var=action_valid,
            policy_state_info_vars=policy_state_info_vars_valid,
            policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
            embed_old_dist_info_vars=embed_old_dist_info_vars_valid,
        )
        policy_loss_inputs = graph_inputs(
            "PolicyLossInputs",
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            trajectory_var=trajectory_var,
            task_var=task_var,
            latent_var=latent_var,
            valid_var=valid_var,
            policy_state_info_vars=policy_state_info_vars,
            policy_old_dist_info_vars=policy_old_dist_info_vars,
            embed_state_info_vars=embed_state_info_vars,
            embed_old_dist_info_vars=embed_old_dist_info_vars,
            flat=pol_flat,
            valid=pol_valid,
        )
        # Special variant for the optimizer
        # * Uses lists instead of dicts for the distribution parameters
        # * Omits flats and valids
        # TODO: eliminate
        policy_opt_inputs = graph_inputs(
            "PolicyOptInputs",
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            trajectory_var=trajectory_var,
            task_var=task_var,
            latent_var=latent_var,
            valid_var=valid_var,
            policy_state_info_vars_list=policy_state_info_vars_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
            embed_state_info_vars_list=embed_state_info_vars_list,
            embed_old_dist_info_vars_list=embed_old_dist_info_vars_list,
        )

        # Inference network loss and optimizer inputs
        infer_flat = graph_inputs(
            "InferenceLossInputsFlat",
            latent_var=latent_flat,
            trajectory_var=trajectory_flat,
            valid_var=valid_flat,
            infer_state_info_vars=infer_state_info_vars_flat,
            infer_old_dist_info_vars=infer_old_dist_info_vars_flat,
        )
        infer_valid = graph_inputs(
            "InferenceLossInputsValid",
            infer_old_dist_info_vars=infer_old_dist_info_vars_valid,
        )
        inference_loss_inputs = graph_inputs(
            "InferenceLossInputs",
            latent_var=latent_var,
            trajectory_var=trajectory_var,
            valid_var=valid_var,
            infer_state_info_vars=infer_state_info_vars,
            infer_old_dist_info_vars=infer_old_dist_info_vars,
            flat=infer_flat,
            valid=infer_valid,
        )
        # Special variant for the optimizer
        # * Uses lists instead of dicts for the distribution parameters
        # * Omits flats and valids
        # TODO: eliminate
        inference_opt_inputs = graph_inputs(
            "InferenceOptInputs",
            latent_var=latent_var,
            trajectory_var=trajectory_var,
            valid_var=valid_var,
            infer_state_info_vars_list=infer_state_info_vars_list,
            infer_old_dist_info_vars_list=infer_old_dist_info_vars_list,
        )

        return (policy_loss_inputs, policy_opt_inputs, inference_loss_inputs,
                inference_opt_inputs)

    def _build_policy_loss(self, i):
        """ Build policy network loss """
        pol_dist = self.policy._dist

        # Entropy terms
        embedding_entropy, inference_ce, policy_entropy = \
            self._build_entropy_terms(i)

        # Augment the path rewards with entropy terms
        with tf.name_scope("augmented_rewards"):
            rewards = i.reward_var \
                      - (self.inference_ce_coeff * inference_ce) \
                      + (self.policy_ent_coeff * policy_entropy)

        with tf.name_scope("policy_loss"):
            with tf.name_scope("advantages"):
                advantages = compute_advantages(self.discount,
                                                self.gae_lambda,
                                                self.max_path_length,
                                                i.baseline_var,
                                                rewards,
                                                name="advantages")

                # Flatten and filter valids
                adv_flat = flatten_batch(advantages, name="adv_flat")
                adv_valid = filter_valids(adv_flat,
                                          i.flat.valid_var,
                                          name="adv_valid")

            policy_dist_info_flat = self.policy.dist_info_sym(
                i.flat.task_var,
                i.flat.obs_var,
                i.flat.policy_state_info_vars,
                name="policy_dist_info_flat")
            policy_dist_info_valid = filter_valids_dict(
                policy_dist_info_flat,
                i.flat.valid_var,
                name="policy_dist_info_valid")

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                with tf.name_scope("center_adv"):
                    mean, var = tf.nn.moments(adv_valid, axes=[0])
                    adv_valid = tf.nn.batch_normalization(
                        adv_valid, mean, var, 0, 1, eps)
            if self.positive_adv:
                with tf.name_scope("positive_adv"):
                    m = tf.reduce_min(adv_valid)
                    adv_valid = (adv_valid - m) + eps

            # Calculate loss function and KL divergence
            with tf.name_scope("kl"):
                kl = pol_dist.kl_sym(
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                )
                pol_mean_kl = tf.reduce_mean(kl)

            # Calculate surrogate loss
            with tf.name_scope("surr_loss"):
                lr = pol_dist.likelihood_ratio_sym(
                    i.valid.action_var,
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                    name="lr")

                # Policy gradient surrogate objective
                surr_vanilla = lr * adv_valid

                if self._pg_loss == PGLoss.VANILLA:
                    # VPG, TRPO use the standard surrogate objective
                    surr_obj = tf.identity(surr_vanilla, name="surr_obj")
                elif self._pg_loss == PGLoss.CLIP:
                    # PPO uses a surrogate objective with clipped LR
                    lr_clip = tf.clip_by_value(lr,
                                               1 - self.step_size,
                                               1 + self.step_size,
                                               name="lr_clip")
                    surr_clip = lr_clip * adv_valid
                    surr_obj = tf.minimum(surr_vanilla,
                                          surr_clip,
                                          name="surr_obj")
                else:
                    raise NotImplementedError("Unknown PGLoss")

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                surr_loss = -tf.reduce_mean(surr_obj)

                # Embedding entropy bonus
                surr_loss -= self.embedding_ent_coeff * embedding_entropy

            embed_mean_kl = self._build_embedding_kl(i)

        # Diagnostic functions
        self.f_policy_kl = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            pol_mean_kl,
            log_name="f_policy_kl")

        self.f_rewards = tensor_utils.compile_function(flatten_inputs(
            self._policy_opt_inputs),
                                                       rewards,
                                                       log_name="f_rewards")

        # returns = self._build_returns(rewards)
        returns = discounted_returns(self.discount,
                                     self.max_path_length,
                                     rewards,
                                     name="returns")
        self.f_returns = tensor_utils.compile_function(flatten_inputs(
            self._policy_opt_inputs),
                                                       returns,
                                                       log_name="f_returns")

        return surr_loss, pol_mean_kl, embed_mean_kl

    def _build_entropy_terms(self, i):
        """ Calculate entropy terms """

        with tf.name_scope("entropy_terms"):
            # 1. Embedding distribution total entropy
            with tf.name_scope('embedding_entropy'):
                task_dim = self.policy.task_space.flat_dim
                all_task_one_hots = tf.one_hot(np.arange(task_dim),
                                               task_dim,
                                               name="all_task_one_hots")
                all_task_entropies = self.policy.embedding.entropy_sym(
                    all_task_one_hots)

                if self._use_softplus_entropy:
                    all_task_entropies = tf.nn.softplus(all_task_entropies)

                embedding_entropy = tf.reduce_mean(all_task_entropies,
                                                   name="embedding_entropy")

            # 2. Inference distribution cross-entropy (negative log-likelihood)
            with tf.name_scope('inference_ce'):
                traj_ll_flat = self.inference.log_likelihood_sym(
                    i.flat.trajectory_var,
                    self.policy._embedding.latent_sym(i.flat.task_var),
                    name="traj_ll_flat")
                traj_ll = tf.reshape(traj_ll_flat, [-1, self.max_path_length],
                                     name="traj_ll")
                inference_ce_raw = -traj_ll
                inference_ce = tf.clip_by_value(inference_ce_raw, -3, 3)

                if self._use_softplus_entropy:
                    inference_ce = tf.nn.softplus(inference_ce)

                if self._stop_ce_gradient:
                    inference_ce = tf.stop_gradient(inference_ce)

            # 3. Policy path entropies
            with tf.name_scope('policy_entropy'):
                policy_entropy_flat = self.policy.entropy_sym(
                    i.task_var, i.obs_var, name="policy_entropy_flat")
                policy_entropy = tf.reshape(policy_entropy_flat,
                                            [-1, self.max_path_length],
                                            name="policy_entropy")

                if self._use_softplus_entropy:
                    policy_entropy = tf.nn.softplus(policy_entropy)

        # Diagnostic functions
        self.f_task_entropies = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            all_task_entropies,
            log_name="f_task_entropies")
        self.f_embedding_entropy = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            embedding_entropy,
            log_name="f_embedding_entropy")
        self.f_inference_ce = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            tf.reduce_mean(inference_ce * i.valid_var),
            log_name="f_inference_ce")
        self.f_policy_entropy = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            tf.reduce_mean(policy_entropy * i.valid_var),
            log_name="f_policy_entropy")

        return embedding_entropy, inference_ce, policy_entropy

    def _build_embedding_kl(self, i):
        dist = self.policy._embedding._dist
        with tf.name_scope("embedding_kl"):
            # new distribution
            embed_dist_info_flat = self.policy._embedding.dist_info_sym(
                i.flat.task_var,
                i.flat.embed_state_info_vars,
                name="embed_dist_info_flat")
            embed_dist_info_valid = filter_valids_dict(
                embed_dist_info_flat,
                i.flat.valid_var,
                name="embed_dist_info_valid")

            # calculate KL divergence
            kl = dist.kl_sym(i.valid.embed_old_dist_info_vars,
                             embed_dist_info_valid)
            mean_kl = tf.reduce_mean(kl)

            # Diagnostic function
            self.f_embedding_kl = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                mean_kl,
                log_name="f_embedding_kl")

            return mean_kl

    def _build_inference_loss(self, i):
        """ Build loss function for the inference network """

        infer_dist = self.inference._dist
        with tf.name_scope("infer_loss"):
            traj_ll_flat = self.inference.log_likelihood_sym(
                i.flat.trajectory_var, i.flat.latent_var, name="traj_ll_flat")
            traj_ll = tf.reshape(traj_ll_flat, [-1, self.max_path_length],
                                 name="traj_ll")

            # Calculate loss
            traj_gammas = tf.constant(float(self.discount),
                                      dtype=tf.float32,
                                      shape=[self.max_path_length])
            traj_discounts = tf.cumprod(traj_gammas,
                                        exclusive=True,
                                        name="traj_discounts")
            discount_traj_ll = traj_discounts * traj_ll
            discount_traj_ll_flat = flatten_batch(discount_traj_ll,
                                                  name="discount_traj_ll_flat")
            discount_traj_ll_valid = filter_valids(
                discount_traj_ll_flat,
                i.flat.valid_var,
                name="discount_traj_ll_valid")

            with tf.name_scope("loss"):
                infer_loss = -tf.reduce_mean(discount_traj_ll_valid,
                                             name="infer_loss")

            with tf.name_scope("kl"):
                # Calculate predicted embedding distributions for each timestep
                infer_dist_info_flat = self.inference.dist_info_sym(
                    i.flat.trajectory_var,
                    i.flat.infer_state_info_vars,
                    name="infer_dist_info_flat")

                infer_dist_info_valid = filter_valids_dict(
                    infer_dist_info_flat,
                    i.flat.valid_var,
                    name="infer_dist_info_valid")

                # Calculate KL divergence
                kl = infer_dist.kl_sym(i.valid.infer_old_dist_info_vars,
                                       infer_dist_info_valid)
                infer_kl = tf.reduce_mean(kl, name="infer_kl")

            return infer_loss, infer_kl

    #### Sampling and training ################################################
    def _policy_opt_input_values(self, samples_data):
        """ Map rollout samples to the policy optimizer inputs """

        policy_state_info_list = [
            samples_data["agent_infos"][k] for k in self.policy.state_info_keys
        ]
        policy_old_dist_info_list = [
            samples_data["agent_infos"][k]
            for k in self.policy._dist.dist_info_keys
        ]
        embed_state_info_list = [
            samples_data["latent_infos"][k]
            for k in self.policy.embedding.state_info_keys
        ]
        embed_old_dist_info_list = [
            samples_data["latent_infos"][k]
            for k in self.policy.embedding._dist.dist_info_keys
        ]
        policy_opt_input_values = self._policy_opt_inputs._replace(
            obs_var=samples_data["observations"],
            action_var=samples_data["actions"],
            reward_var=samples_data["rewards"],
            baseline_var=samples_data["baselines"],
            trajectory_var=samples_data["trajectories"],
            task_var=samples_data["tasks"],
            latent_var=samples_data["latents"],
            valid_var=samples_data["valids"],
            policy_state_info_vars_list=policy_state_info_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_list,
            embed_state_info_vars_list=embed_state_info_list,
            embed_old_dist_info_vars_list=embed_old_dist_info_list,
        )
        return flatten_inputs(policy_opt_input_values)

    def _inference_opt_input_values(self, samples_data):
        """ Map rollout samples to the inference optimizer inputs """

        infer_state_info_list = [
            samples_data["trajectory_infos"][k]
            for k in self.inference.state_info_keys
        ]
        infer_old_dist_info_list = [
            samples_data["trajectory_infos"][k]
            for k in self.inference._dist.dist_info_keys
        ]
        inference_opt_input_values = self._inference_opt_inputs._replace(
            latent_var=samples_data["latents"],
            trajectory_var=samples_data["trajectories"],
            valid_var=samples_data["valids"],
            infer_state_info_vars_list=infer_state_info_list,
            infer_old_dist_info_vars_list=infer_old_dist_info_list,
        )

        return flatten_inputs(inference_opt_input_values)

    def evaluate(self, policy_opt_input_values, samples_data):
        # Everything else
        rewards_tensor = self.f_rewards(*policy_opt_input_values)
        returns_tensor = self.f_returns(*policy_opt_input_values)
        returns_tensor = np.squeeze(returns_tensor)  # TODO
        # TODO: check the squeeze/dimension handling for both convolutions

        paths = samples_data['paths']
        valids = samples_data['valids']
        baselines = [path['baselines'] for path in paths]
        env_rewards = [path['rewards'] for path in paths]
        env_rewards = tensor_utils.concat_tensor_list(env_rewards.copy())
        env_returns = [path['returns'] for path in paths]
        env_returns = tensor_utils.concat_tensor_list(env_returns.copy())
        env_average_discounted_return = \
            np.mean([path["returns"][0] for path in paths])

        # Recompute parts of samples_data
        aug_rewards = []
        aug_returns = []
        for rew, ret, val, path in zip(rewards_tensor, returns_tensor, valids,
                                       paths):
            path['rewards'] = rew[val.astype(np.bool)]
            path['returns'] = ret[val.astype(np.bool)]
            aug_rewards.append(path['rewards'])
            aug_returns.append(path['returns'])
        aug_rewards = tensor_utils.concat_tensor_list(aug_rewards)
        aug_returns = tensor_utils.concat_tensor_list(aug_returns)
        samples_data['rewards'] = aug_rewards
        samples_data['returns'] = aug_returns

        # Calculate effect of the entropy terms
        d_rewards = np.mean(aug_rewards - env_rewards)
        logger.record_tabular('Policy/EntRewards', d_rewards)

        aug_average_discounted_return = \
            np.mean([path["returns"][0] for path in paths])
        d_returns = np.mean(aug_average_discounted_return -
                            env_average_discounted_return)
        logger.record_tabular('Policy/EntReturns', d_returns)

        # Calculate explained variance
        ev = special.explained_variance_1d(np.concatenate(baselines),
                                           aug_returns)
        logger.record_tabular('Baseline/ExplainedVariance', ev)

        inference_rmse = (samples_data['trajectory_infos']['mean'] -
                          samples_data['latents'])**2.
        inference_rmse = np.sqrt(inference_rmse.mean())
        logger.record_tabular('Inference/RMSE', inference_rmse)

        inference_rrse = rrse(samples_data['latents'],
                              samples_data['trajectory_infos']['mean'])
        logger.record_tabular('Inference/RRSE', inference_rrse)

        embed_ent = self.f_embedding_entropy(*policy_opt_input_values)
        logger.record_tabular('Embedding/Entropy', embed_ent)

        infer_ce = self.f_inference_ce(*policy_opt_input_values)
        logger.record_tabular('Inference/CrossEntropy', infer_ce)

        pol_ent = self.f_policy_entropy(*policy_opt_input_values)
        logger.record_tabular('Policy/Entropy', pol_ent)

        task_ents = self.f_task_entropies(*policy_opt_input_values)
        tasks = samples_data["tasks"][:, 0, :]
        _, task_indices = np.nonzero(tasks)
        path_lengths = np.sum(samples_data["valids"], axis=1)
        for t in range(self.policy.task_space.flat_dim):
            lengths = path_lengths[task_indices == t]
            completed = lengths < self.max_path_length
            pct_completed = np.mean(completed)
            num_samples = np.sum(lengths)
            num_trajs = lengths.shape[0]
            logger.record_tabular('Tasks/EpisodeLength/t={}'.format(t),
                                  np.mean(lengths))
            logger.record_tabular('Tasks/CompletionRate/t={}'.format(t),
                                  pct_completed)
            # logger.record_tabular('Tasks/NumSamples/t={}'.format(t),
            #                       num_samples)
            # logger.record_tabular('Tasks/NumTrajs/t={}'.format(t), num_trajs)
            logger.record_tabular('Tasks/Entropy/t={}'.format(t), task_ents[t])

        return samples_data

    def visualize_distribution(self, samples_data):
        """ Visualize embedding distribution """

        # distributions
        num_tasks = self.policy.task_space.flat_dim
        all_tasks = np.eye(num_tasks, num_tasks)
        _, latent_infos = self.policy._embedding.get_latents(all_tasks)
        for i in range(self.policy.latent_space.flat_dim):
            log_stds = latent_infos["log_std"][:, i]
            if self.policy.embedding._std_parameterization == "exp":
                stds = np.exp(log_stds)
            elif self.policy.embedding._std_parameterization == "softplus":
                stds = np.log(1. + log_stds)
            else:
                raise NotImplementedError
            logger.record_histogram_by_type("normal",
                                            shape=[1000, num_tasks],
                                            key="Embedding/i={}".format(i),
                                            mean=latent_infos["mean"][:, i],
                                            stddev=stds)

        num_traj = self.batch_size // self.max_path_length
        # # samples
        # latents = samples_data["latents"][:num_traj, 0]
        # for i in range(self.policy.latent_space.flat_dim):
        #     logger.record_histogram("Embedding/samples/i={}".format(i),
        #                             latents[:, i])

        # action distributions
        actions = samples_data["actions"][:num_traj, ...]
        logger.record_histogram("Actions", actions)

    def train_policy_and_embedding_networks(self, policy_opt_input_values):
        """ Joint optimization of policy and embedding networks """

        logger.log("Computing loss before")
        loss_before = self.optimizer.loss(policy_opt_input_values)

        logger.log("Computing KL before")
        policy_kl_before = self.f_policy_kl(*policy_opt_input_values)
        embed_kl_before = self.f_embedding_kl(*policy_opt_input_values)

        logger.log("Optimizing")
        self.optimizer.optimize(policy_opt_input_values)

        logger.log("Computing KL after")
        policy_kl = self.f_policy_kl(*policy_opt_input_values)
        embed_kl = self.f_embedding_kl(*policy_opt_input_values)

        logger.log("Computing loss after")
        loss_after = self.optimizer.loss(policy_opt_input_values)

        logger.record_tabular('Policy/LossBefore', loss_before)
        logger.record_tabular('Policy/LossAfter', loss_after)
        logger.record_tabular('Policy/KLBefore', policy_kl_before)
        logger.record_tabular('Policy/KL', policy_kl)
        logger.record_tabular('Policy/dLoss', loss_before - loss_after)
        logger.record_tabular('Embedding/KLBefore', embed_kl_before)
        logger.record_tabular('Embedding/KL', embed_kl)

        return loss_after

    def train_inference_network(self, inference_opt_input_values):
        """ Optimize inference network """

        logger.log("Optimizing inference network...")
        infer_loss_before = self.inference_optimizer.loss(
            inference_opt_input_values)
        logger.record_tabular('Inference/Loss', infer_loss_before)
        self.inference_optimizer.optimize(inference_opt_input_values)
        infer_loss_after = self.inference_optimizer.loss(
            inference_opt_input_values)
        logger.record_tabular('Inference/dLoss',
                              infer_loss_before - infer_loss_after)

        return infer_loss_after

    def save_samples(self, itr, samples_data):
        with open(osp.join(logger.get_snapshot_dir(), 'samples_%i.pkl' % itr),
                  "wb") as fout:
            pickle.dump(samples_data, fout)

    @overrides
    def optimize_policy(self, itr, **kwargs):
        paths = self.obtain_samples(itr)

        samples_data = self.process_samples(itr, paths)
        if self._save_sample_frequency > 0 and itr % self._save_sample_frequency == 0:
            self.save_samples(itr, samples_data)
        self.log_diagnostics(paths)

        policy_opt_input_values = self._policy_opt_input_values(samples_data)
        inference_opt_input_values = self._inference_opt_input_values(
            samples_data)

        self.train_policy_and_embedding_networks(policy_opt_input_values)
        self.train_inference_network(inference_opt_input_values)

        samples_data = self.evaluate(policy_opt_input_values, samples_data)
        self.visualize_distribution(samples_data)

        # Fit baseline
        logger.log("Fitting baseline...")
        if hasattr(self.baseline, 'fit_with_samples'):
            self.baseline.fit_with_samples(paths, samples_data)
        else:
            self.baseline.fit(paths)

        return self.get_itr_snapshot(itr, samples_data)

    @overrides
    def train(self, sess=None):
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())

        self.start_worker(sess)
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                params = self.optimize_policy(itr)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
                logger.log("Saving snapshot...")
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('IterTime', time.time() - itr_start_time)
                logger.record_tabular('Time', time.time() - start_time)
                logger.dump_tabular()
        self.shutdown_worker()
        if created_session:
            sess.close()

    @overrides
    def get_itr_snapshot(self, itr, _samples_data):
        return dict(
            itr=itr,
            policy=self.policy,
            # baseline=self.baseline,
            env=self.env,
            inference=self.inference,
        )
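A hedged construction sketch for the algorithm above. The env, policy, inference, and baseline objects are assumed to already exist with the interfaces asserted in __init__ (env.task_space, StochasticMultitaskPolicy, StochasticEmbedding); baseline, max_path_length, and discount are assumed to be accepted by BatchPolopt via **kwargs, and the coefficient values are illustrative rather than recommended settings:

# Hypothetical wiring sketch -- only the keyword names come from the
# constructor above; the objects passed in are assumed to exist.
algo = NPOTaskEmbedding(
    env=env,
    policy=policy,            # StochasticMultitaskPolicy
    inference=inference,      # StochasticEmbedding over trajectory windows
    baseline=baseline,        # forwarded to BatchPolopt via **kwargs
    pg_loss=PGLoss.CLIP,      # PPO-style clipped surrogate objective
    step_size=0.2,            # clip range for CLIP; also the KL bound
    policy_ent_coeff=1e-2,
    embedding_ent_coeff=1e-5,
    inference_ce_coeff=1e-3,
    max_path_length=100,
    discount=0.99,
)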
Example #3
class LocalRunner:
    """Base class of local runner.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        max_cpus (int): The maximum number of parallel sampler workers.

    Note:
        To use TensorFlow environments, policies, or algorithms, please use
        LocalTFRunner().

    Examples:
        | # to train
        | runner = LocalRunner()
        | env = Env(...)
        | policy = Policy(...)
        | algo = Algo(
        |         env=env,
        |         policy=policy,
        |         ...)
        | runner.setup(algo, env)
        | runner.train(n_epochs=100, batch_size=4000)

        | # to resume immediately.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume()

        | # to resume with modified training arguments.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume(n_epochs=20)

    """
    def __init__(self, snapshot_config, max_cpus=1):
        self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                        snapshot_config.snapshot_mode,
                                        snapshot_config.snapshot_gap)

        parallel_sampler.initialize(max_cpus)

        seed = get_seed()
        if seed is not None:
            parallel_sampler.set_seed(seed)

        self._has_setup = False
        self._plot = False

        self._setup_args = None
        self._train_args = None
        self._stats = ExperimentStats(total_itr=0,
                                      total_env_steps=0,
                                      total_epoch=0,
                                      last_path=None)

        self._algo = None
        self._env = None
        self._policy = None
        self._sampler = None
        self._plotter = None

        self._start_time = None
        self._itr_start_time = None
        self.step_itr = None
        self.step_path = None

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.

        """
        self._algo = algo
        self._env = env
        self._policy = self._algo.policy

        if sampler_args is None:
            sampler_args = {}
        if sampler_cls is None:
            sampler_cls = algo.sampler_cls
        self._sampler = sampler_cls(algo, env, **sampler_args)

        self._has_setup = True

        self._setup_args = SetupArgs(sampler_cls=sampler_cls,
                                     sampler_args=sampler_args,
                                     seed=get_seed())

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self._sampler.start_worker()
        if self._plot:
            # pylint: disable=import-outside-toplevel
            from garage.tf.plotter import Plotter
            self._plotter = Plotter(self._env, self._policy)
            self._plotter.start()

    def _shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self._sampler.shutdown_worker()
        if self._plot:
            self._plotter.close()

    def obtain_samples(self, itr, batch_size=None):
        """Obtain one batch of samples.

        Args:
            itr (int): Index of iteration (epoch).
            batch_size (int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            list[dict]: One batch of samples.

        """
        paths = self._sampler.obtain_samples(
            itr, (batch_size or self._train_args.batch_size))

        self._stats.total_env_steps += sum([len(p['rewards']) for p in paths])

        return paths

    def save(self, epoch):
        """Save snapshot of current batch.

        Args:
            epoch (int): Epoch.

        Raises:
            NotSetupError: if save() is called before the runner is set up.

        """
        if not self._has_setup:
            raise NotSetupError(
                'Use setup() to set up the runner before saving.')

        logger.log('Saving snapshot...')

        params = dict()
        # Save arguments
        params['setup_args'] = self._setup_args
        params['train_args'] = self._train_args
        params['stats'] = self._stats

        # Save states
        params['env'] = self._env
        params['algo'] = self._algo

        self._snapshotter.save_itr_params(epoch, params)

        logger.log('Saved')

    def restore(self, from_dir, from_epoch='last'):
        """Restore experiment from snapshot.

        Args:
            from_dir (str): Directory of the pickle file
                to resume experiment from.
            from_epoch (str or int): The epoch to restore from.
                Can be 'first', 'last' or a number.
                Not applicable when snapshot_mode='last'.

        Returns:
            TrainArgs: Arguments for train().

        """
        saved = self._snapshotter.load(from_dir, from_epoch)

        self._setup_args = saved['setup_args']
        self._train_args = saved['train_args']
        self._stats = saved['stats']

        set_seed(self._setup_args.seed)

        self.setup(env=saved['env'],
                   algo=saved['algo'],
                   sampler_cls=self._setup_args.sampler_cls,
                   sampler_args=self._setup_args.sampler_args)

        n_epochs = self._train_args.n_epochs
        last_epoch = self._stats.total_epoch
        last_itr = self._stats.total_itr
        total_env_steps = self._stats.total_env_steps
        batch_size = self._train_args.batch_size
        store_paths = self._train_args.store_paths
        pause_for_plot = self._train_args.pause_for_plot

        fmt = '{:<20} {:<15}'
        logger.log('Restore from snapshot saved in %s' %
                   self._snapshotter.snapshot_dir)
        logger.log(fmt.format('-- Train Args --', '-- Value --'))
        logger.log(fmt.format('n_epochs', n_epochs))
        logger.log(fmt.format('last_epoch', last_epoch))
        logger.log(fmt.format('batch_size', batch_size))
        logger.log(fmt.format('store_paths', store_paths))
        logger.log(fmt.format('pause_for_plot', pause_for_plot))
        logger.log(fmt.format('-- Stats --', '-- Value --'))
        logger.log(fmt.format('last_itr', last_itr))
        logger.log(fmt.format('total_env_steps', total_env_steps))

        self._train_args.start_epoch = last_epoch + 1
        return copy.copy(self._train_args)

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot (bool): Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self._start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time))
        logger.log(tabular)
        if self._plot:
            self._plotter.update_plot(self._policy, self._algo.max_path_length)
            if pause_for_plot:
                input('Plotting evaluation run: Press Enter to continue...')

    def train(self,
              n_epochs,
              batch_size,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs (int): Number of epochs.
            batch_size (int): Number of environment steps in one batch.
            plot (bool): Visualize policy by doing rollout after each epoch.
            store_paths (bool): Save paths in snapshot.
            pause_for_plot (bool): Pause for plot.

        Raises:
            NotSetupError: If train() is called before setup().

        Returns:
            float: The average return in last epoch cycle.

        """
        if not self._has_setup:
            raise NotSetupError(
                'Use setup() to set up the runner before training.')

        # Save arguments for restore
        self._train_args = TrainArgs(n_epochs=n_epochs,
                                     batch_size=batch_size,
                                     plot=plot,
                                     store_paths=store_paths,
                                     pause_for_plot=pause_for_plot,
                                     start_epoch=0)

        self._plot = plot

        return self._algo.train(self)

    def step_epochs(self):
        """Step through each epoch.

        This function returns a magic generator. When iterated through, this
        generator automatically performs services such as snapshotting and log
        management. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        self._start_worker()
        self._start_time = time.time()
        self.step_itr = self._stats.total_itr
        self.step_path = None

        # Used by integration tests to ensure examples can run one epoch.
        n_epochs = int(
            os.environ.get('GARAGE_EXAMPLE_TEST_N_EPOCHS',
                           self._train_args.n_epochs))

        logger.log('Obtaining samples...')

        for epoch in range(self._train_args.start_epoch, n_epochs):
            self._itr_start_time = time.time()
            with logger.prefix('epoch #%d | ' % epoch):
                yield epoch
                save_path = (self.step_path
                             if self._train_args.store_paths else None)

                self._stats.last_path = save_path
                self._stats.total_epoch = epoch
                self._stats.total_itr = self.step_itr

                self.save(epoch)
                self.log_diagnostics(self._train_args.pause_for_plot)
                logger.dump_all(self.step_itr)
                tabular.clear()

    def resume(self,
               n_epochs=None,
               batch_size=None,
               plot=None,
               store_paths=None,
               pause_for_plot=None):
        """Resume from restored experiment.

        This method provides the same interface as train().

        If not specified, an argument will default to the
        saved arguments from the last call to train().

        Args:
            n_epochs (int): Number of epochs.
            batch_size (int): Number of environment steps in one batch.
            plot (bool): Visualize policy by doing rollout after each epoch.
            store_paths (bool): Save paths in snapshot.
            pause_for_plot (bool): Pause for plot.

        Raises:
            NotSetupError: If resume() is called before restore().

        Returns:
            float: The average return in last epoch cycle.

        """
        if self._train_args is None:
            raise NotSetupError('You must call restore() before resume().')

        self._train_args.n_epochs = n_epochs or self._train_args.n_epochs
        self._train_args.batch_size = batch_size or self._train_args.batch_size

        if plot is not None:
            self._train_args.plot = plot
        if store_paths is not None:
            self._train_args.store_paths = store_paths
        if pause_for_plot is not None:
            self._train_args.pause_for_plot = pause_for_plot

        return self._algo.train(self)

    def get_env_copy(self):
        """Get a copy of the environment.

        Returns:
            garage.envs.GarageEnv: An environment instance.

        """
        return pickle.loads(pickle.dumps(self._env))

    @property
    def total_env_steps(self):
        """Total environment steps collected.

        Returns:
            int: Total environment steps collected.

        """
        return self._stats.total_env_steps
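To make the step_epochs() contract concrete, here is a hedged skeleton of an algorithm driving this runner; MyAlgo and its train_once() are hypothetical, while the step_itr/step_path protocol is taken from the generator's docstring above:

# Hypothetical algorithm skeleton -- only the runner-facing protocol
# (step_epochs, obtain_samples, step_itr, step_path) is taken from above.
class MyAlgo:
    sampler_cls = None  # assumed; or pass sampler_cls= to runner.setup()

    def __init__(self, env, policy, max_path_length=100):
        self.env = env
        self.policy = policy
        self.max_path_length = max_path_length

    def train(self, runner):
        last_return = None
        for epoch in runner.step_epochs():
            runner.step_path = runner.obtain_samples(runner.step_itr)
            last_return = self.train_once(runner.step_itr, runner.step_path)
            runner.step_itr += 1
        return last_return

    def train_once(self, itr, paths):
        # Placeholder: update the policy from one batch of paths and
        # return the average undiscounted return.
        raise NotImplementedError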
Example #4
class OffPolicyRLAlgorithm(RLAlgorithm):
    """This class implements OffPolicyRLAlgorithm."""
    def __init__(
        self,
        env,
        policy,
        qf,
        replay_buffer,
        use_target=False,
        discount=0.99,
        n_epochs=500,
        n_epoch_cycles=20,
        max_path_length=100,
        n_train_steps=50,
        buffer_batch_size=64,
        min_buffer_size=int(1e4),
        rollout_batch_size=1,
        reward_scale=1.,
        input_include_goal=False,
        smooth_return=True,
        sampler_cls=None,
        sampler_args=None,
        force_batch_sampler=False,
        plot=False,
        pause_for_plot=False,
        exploration_strategy=None,
    ):
        """Construct an OffPolicyRLAlgorithm class."""
        self.env = env
        self.policy = policy
        self.qf = qf
        self.replay_buffer = replay_buffer
        self.n_epochs = n_epochs
        self.n_epoch_cycles = n_epoch_cycles
        self.n_train_steps = n_train_steps
        self.buffer_batch_size = buffer_batch_size
        self.use_target = use_target
        self.discount = discount
        self.min_buffer_size = min_buffer_size
        self.rollout_batch_size = rollout_batch_size
        self.reward_scale = reward_scale
        self.evaluate = False
        self.input_include_goal = input_include_goal
        self.smooth_return = smooth_return
        if sampler_cls is None:
            if policy.vectorized and not force_batch_sampler:
                sampler_cls = OffPolicyVectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.max_path_length = max_path_length
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.es = exploration_strategy
        self.init_opt()

    def start_worker(self, sess):
        """Initialize sampler and plotter."""
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env, self.policy, sess)
            self.plotter.start()

    def shutdown_worker(self):
        """Close sampler and plotter."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr):
        """Sample data for this iteration."""
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        """Process samples from rollout paths."""
        return self.sampler.process_samples(itr, paths)

    def log_diagnostics(self, paths):
        """Log diagnostic information on current paths."""
        self.policy.log_diagnostics(paths)
        self.qf.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure.

        If using tensorflow, this may
        include declaring all the variables and compiling functions.
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """Return data saved in the snapshot for this iteration."""
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        """Optimize policy network."""
        raise NotImplementedError
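A hedged sketch of what a concrete subclass has to provide; only the three abstract methods come from the base class above, and the bodies below are placeholders rather than a working implementation:

# Hypothetical subclass sketch -- graph construction and the actual update
# rule are elided; only the method signatures match the base class above.
class MyOffPolicyAlgo(OffPolicyRLAlgorithm):
    def init_opt(self):
        # Build the Q-function / policy training ops here (e.g. a
        # Bellman-error loss and target-network update ops).
        self._train_ops = []

    def optimize_policy(self, itr, samples_data):
        # Placeholder: draw self.buffer_batch_size transitions from
        # self.replay_buffer and apply the update rule here.
        pass

    def get_itr_snapshot(self, itr, samples_data):
        return dict(itr=itr, policy=self.policy, qf=self.qf, env=self.env)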
Example #5
class LocalRunner:
    """Base class of local runner.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        max_cpus (int): The maximum number of parallel sampler workers.

    Note:
        For the use of any TensorFlow environments, policies and algorithms,
        please use LocalTFRunner().

    Examples:
        | # to train
        | runner = LocalRunner()
        | env = Env(...)
        | policy = Policy(...)
        | algo = Algo(
        |         env=env,
        |         policy=policy,
        |         ...)
        | runner.setup(algo, env)
        | runner.train(n_epochs=100, batch_size=4000)

        | # to resume immediately.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume()

        | # to resume with modified training arguments.
        | runner = LocalRunner()
        | runner.restore(resume_from_dir)
        | runner.resume(n_epochs=20)

    """

    def __init__(self, snapshot_config, max_cpus=1):
        self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                        snapshot_config.snapshot_mode,
                                        snapshot_config.snapshot_gap)

        if max_cpus > 1:
            from garage.sampler import singleton_pool
            singleton_pool.initialize(max_cpus)
        self.has_setup = False
        self.plot = False

        self._setup_args = None
        self.train_args = None

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.

        """
        self.algo = algo
        self.env = env
        self.policy = self.algo.policy

        if sampler_args is None:
            sampler_args = {}
        if sampler_cls is None:
            sampler_cls = algo.sampler_cls
        self.sampler = sampler_cls(algo, env, **sampler_args)

        self.has_setup = True

        self._setup_args = types.SimpleNamespace(sampler_cls=sampler_cls,
                                                 sampler_args=sampler_args)

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self.sampler.start_worker()
        if self.plot:
            from garage.tf.plotter import Plotter
            self.plotter = Plotter(self.env, self.policy)
            self.plotter.start()

    def _shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr, batch_size=None):
        """Obtain one batch of samples.

        Args:
            itr(int): Index of iteration (epoch).
            batch_size(int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            One batch of samples.

        """
        if self.train_args.n_epoch_cycles == 1:
            logger.log('Obtaining samples...')
        return self.sampler.obtain_samples(
            itr, (batch_size or self.train_args.batch_size))

    def save(self, epoch, paths=None):
        """Save snapshot of current batch.

        Args:
            epoch (int): Index of the epoch to save.
            paths (dict): Batch of samples after preprocessing. If None,
                no paths will be logged to the snapshot.

        """
        if not self.has_setup:
            raise Exception('Use setup() to set up the runner before saving.')

        logger.log('Saving snapshot...')

        params = dict()
        # Save arguments
        params['setup_args'] = self._setup_args
        params['train_args'] = self.train_args

        # Save states
        params['env'] = self.env
        params['algo'] = self.algo
        if paths:
            params['paths'] = paths
        params['last_epoch'] = epoch
        self._snapshotter.save_itr_params(epoch, params)

        logger.log('Saved')

    def restore(self, from_dir, from_epoch='last'):
        """Restore experiment from snapshot.

        Args:
            from_dir (str): Directory of the pickle file
                to resume experiment from.
            from_epoch (str or int): The epoch to restore from.
                Can be 'first', 'last' or a number.
                Not applicable when snapshot_mode='last'.

        Returns:
            A SimpleNamespace for train()'s arguments.

        """
        saved = self._snapshotter.load(from_dir, from_epoch)

        self._setup_args = saved['setup_args']
        self.train_args = saved['train_args']

        self.setup(env=saved['env'],
                   algo=saved['algo'],
                   sampler_cls=self._setup_args.sampler_cls,
                   sampler_args=self._setup_args.sampler_args)

        n_epochs = self.train_args.n_epochs
        last_epoch = saved['last_epoch']
        n_epoch_cycles = self.train_args.n_epoch_cycles
        batch_size = self.train_args.batch_size
        store_paths = self.train_args.store_paths
        pause_for_plot = self.train_args.pause_for_plot

        fmt = '{:<20} {:<15}'
        logger.log('Restore from snapshot saved in %s' %
                   self._snapshotter.snapshot_dir)
        logger.log(fmt.format('Train Args', 'Value'))
        logger.log(fmt.format('n_epochs', n_epochs))
        logger.log(fmt.format('last_epoch', last_epoch))
        logger.log(fmt.format('n_epoch_cycles', n_epoch_cycles))
        logger.log(fmt.format('batch_size', batch_size))
        logger.log(fmt.format('store_paths', store_paths))
        logger.log(fmt.format('pause_for_plot', pause_for_plot))

        self.train_args.start_epoch = last_epoch + 1
        return copy.copy(self.train_args)

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot(bool): Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self._start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time))
        logger.log(tabular)
        if self.plot:
            self.plotter.update_plot(self.policy, self.algo.max_path_length)
            if pause_for_plot:
                input('Plotting evaluation run: Press Enter to continue...')

    def train(self,
              n_epochs,
              batch_size,
              n_epoch_cycles=1,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs(int): Number of epochs.
            batch_size(int): Number of environment steps in one batch.
            n_epoch_cycles(int): Number of batches of samples in each epoch.
                This is only useful for off-policy algorithms.
                For on-policy algorithms this value should always be 1.
            plot(bool): Visualize policy by doing rollout after each epoch.
            store_paths(bool): Save paths in snapshot.
            pause_for_plot(bool): Pause for plot.

        Returns:
            The average return in last epoch cycle.

        """
        if not self.has_setup:
            raise Exception('Use setup() to set up the runner before training.')

        # Save arguments for restore
        self.train_args = types.SimpleNamespace(n_epochs=n_epochs,
                                                n_epoch_cycles=n_epoch_cycles,
                                                batch_size=batch_size,
                                                plot=plot,
                                                store_paths=store_paths,
                                                pause_for_plot=pause_for_plot,
                                                start_epoch=0)

        self.plot = plot

        return self.algo.train(self)

    def step_epochs(self):
        """Step through each epoch.

        This function returns a generator. When iterated through, the
        generator automatically performs services such as snapshotting and
        log management. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        try:
            self._start_worker()
            self._start_time = time.time()
            self.step_itr = (self.train_args.start_epoch *
                             self.train_args.n_epoch_cycles)
            self.step_path = None

            for epoch in range(self.train_args.start_epoch,
                               self.train_args.n_epochs):
                self._itr_start_time = time.time()
                with logger.prefix('epoch #%d | ' % epoch):
                    yield epoch
                    save_path = (self.step_path
                                 if self.train_args.store_paths else None)
                    self.save(epoch, save_path)
                    self.log_diagnostics(self.train_args.pause_for_plot)
                    logger.dump_all(self.step_itr)
                    tabular.clear()
        finally:
            self._shutdown_worker()

    def resume(self,
               n_epochs=None,
               batch_size=None,
               n_epoch_cycles=None,
               plot=None,
               store_paths=None,
               pause_for_plot=None):
        """Resume from restored experiment.

        This method provides the same interface as train().

        Any argument not specified defaults to the value
        saved from the last call to train().

        Returns:
            The average return in last epoch cycle.

        """
        if self.train_args is None:
            raise Exception('You must call restore() before resume().')

        self.train_args.n_epochs = n_epochs or self.train_args.n_epochs
        self.train_args.batch_size = batch_size or self.train_args.batch_size
        self.train_args.n_epoch_cycles = (n_epoch_cycles
                                          or self.train_args.n_epoch_cycles)

        if plot is not None:
            self.train_args.plot = plot
        if store_paths is not None:
            self.train_args.store_paths = store_paths
        if pause_for_plot is not None:
            self.train_args.pause_for_plot = pause_for_plot

        return self.algo.train(self)
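
Following the usage shown in the step_epochs() docstring above, an algorithm's train() method drives this runner roughly as sketched below. MyAlgo and its train_once() are hypothetical placeholders, not part of the runner API:

class MyAlgo:
    """Minimal sketch of an algorithm that consumes LocalRunner.step_epochs()."""

    sampler_cls = None  # a concrete algorithm would name its sampler here

    def train(self, runner):
        last_return = None
        for epoch in runner.step_epochs():
            # obtain_samples() falls back to train_args.batch_size when no
            # explicit batch_size argument is given.
            runner.step_path = runner.obtain_samples(runner.step_itr)
            last_return = self.train_once(runner.step_itr, runner.step_path)
            runner.step_itr += 1
        return last_return

    def train_once(self, itr, paths):
        # Hypothetical: one optimization pass over the sampled paths.
        raise NotImplementedError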
Example No. 6
class BatchPolopt(RLAlgorithm):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo,
    etc.
    """
    def __init__(self,
                 env,
                 policy,
                 baseline,
                 scope=None,
                 n_itr=500,
                 start_itr=0,
                 batch_size=5000,
                 max_path_length=500,
                 discount=0.99,
                 gae_lambda=1,
                 plot=False,
                 pause_for_plot=False,
                 center_adv=True,
                 positive_adv=False,
                 store_paths=False,
                 whole_paths=True,
                 fixed_horizon=False,
                 sampler_cls=None,
                 sampler_args=None,
                 force_batch_sampler=False,
                 **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified
         if running multiple algorithms simultaneously, each using different
         environments and policies.
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have
         mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are
         always positive. When used in conjunction with center_adv the
         advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = OnPolicyVectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.init_opt()

    def start_worker(self, sess):
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env, self.policy, sess)
            self.plotter.start()

    def shutdown_worker(self):
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        return self.sampler.process_samples(itr, paths)

    def train(self, sess=None):
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker(sess)
        start_time = time.time()
        last_average_return = None
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
        if created_session:
            sess.close()
        return last_average_return

    def log_diagnostics(self, paths):
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError
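
The discount, gae_lambda, center_adv and positive_adv parameters documented above describe how per-path advantages are typically computed before optimize_policy() is called. The sampler that performs this step is not shown here, so the following is only a self-contained NumPy sketch of the standard generalized advantage estimation recipe those parameters refer to:

import numpy as np

def gae_advantages(rewards, values, discount=0.99, gae_lambda=1.0,
                   center_adv=True, positive_adv=False):
    """Illustrative GAE computation for a single path.

    rewards: shape (T,); values: shape (T + 1,), including a bootstrap value.
    """
    deltas = rewards + discount * values[1:] - values[:-1]
    adv = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + discount * gae_lambda * running
        adv[t] = running
    if center_adv:
        # Rescale to mean 0 and standard deviation 1.
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    if positive_adv:
        # Shift so all advantages are positive.
        adv = adv - adv.min() + 1e-8
    return adv

# Example: a 3-step path with a zero bootstrap value at the end.
print(gae_advantages(np.array([1., 0., 1.]), np.array([0.5, 0.4, 0.2, 0.0])))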
Example No. 7
class LocalRunner:
    """This class implements a local runner for tensorflow algorithms.

    A local runner provides a default tensorflow session using python context.
    This is useful for those experiment components (e.g. policy) that require a
    tensorflow session during construction.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Examples:
        with LocalRunner() as runner:
            env = gym.make('CartPole-v1')
            policy = CategoricalMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=(32, 32))
            algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                max_kl_step=0.01)
            runner.setup(algo, env)
            runner.train(n_epochs=100, batch_size=4000)

    """
    def __init__(self, sess=None, max_cpus=1):
        """Create a new local runner.

        Args:
            max_cpus: The maximum number of parallel sampler workers.
            sess: An optional tensorflow session.
                  A new session will be created immediately if not provided.

        Note:
            The local runner will set up a joblib task pool of size max_cpus
            possibly later used by BatchSampler. If BatchSampler is not used,
            the processes in the pool will remain dormant.

            This setup is required to use tensorflow in a multiprocess
            environment before a tensorflow session is created
            because tensorflow is not fork-safe.

            See https://github.com/tensorflow/tensorflow/issues/2448.

        """
        if max_cpus > 1:
            from garage.sampler import singleton_pool
            singleton_pool.initialize(max_cpus)
        self.sess = sess or tf.Session()
        self.has_setup = False
        self.plot = False

    def __enter__(self):
        """Set self.sess as the default session.

        Returns:
            This local runner.

        """
        if tf.get_default_session() is not self.sess:
            self.sess.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave session."""
        if tf.get_default_session() is self.sess:
            self.sess.__exit__(exc_type, exc_val, exc_tb)

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo: An algorithm instance.
            env: An environment instance.
            sampler_cls: A sampler class.
            sampler_args: Arguments to be passed to sampler constructor.

        """
        self.algo = algo
        self.env = env
        self.policy = self.algo.policy

        if sampler_args is None:
            sampler_args = {}

        if sampler_cls is None:
            from garage.tf.algos.batch_polopt import BatchPolopt
            if isinstance(algo, BatchPolopt):
                if self.policy.vectorized:
                    from garage.tf.samplers import OnPolicyVectorizedSampler
                    sampler_cls = OnPolicyVectorizedSampler
                else:
                    from garage.tf.samplers import BatchSampler
                    sampler_cls = BatchSampler
            else:
                from garage.tf.samplers import OffPolicyVectorizedSampler
                sampler_cls = OffPolicyVectorizedSampler

        self.sampler = sampler_cls(algo, **sampler_args)

        self.initialize_tf_vars()
        self.has_setup = True

    def initialize_tf_vars(self):
        """Initialize all uninitialized variables in session."""
        self.sess.run(
            tf.variables_initializer([
                v for v in tf.global_variables()
                if v.name.split(':')[0] in str(
                    self.sess.run(tf.report_uninitialized_variables()))
            ]))

    def start_worker(self):
        """Start Plotter and Sampler workers."""
        self.sampler.start_worker()
        if self.plot:
            from garage.tf.plotter import Plotter
            self.plotter = Plotter(self.env, self.policy)
            self.plotter.start()

    def shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr, batch_size):
        """Obtain one batch of samples.

        Args:
            itr: Index of iteration (epoch).
            batch_size: Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            One batch of samples.

        """
        if self.n_epoch_cycles == 1:
            logger.log("Obtaining samples...")
        return self.sampler.obtain_samples(itr, batch_size)

    def save_snapshot(self, itr, paths=None):
        """Save snapshot of current batch.

        Args:
            itr: Index of iteration (epoch).
            paths: Batch of samples after preprocessing.

        """
        logger.log("Saving snapshot...")
        params = self.algo.get_itr_snapshot(itr, paths)
        if paths:
            params["paths"] = paths
        logger.save_itr_params(itr, params)
        logger.log("Saved")

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot: Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self.start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self.itr_start_time))
        logger.dump_tabular(with_prefix=False)
        if self.plot:
            self.plotter.update_plot(self.policy, self.algo.max_path_length)
            if pause_for_plot:
                input("Plotting evaluation run: Press Enter to " "continue...")

    def train(self,
              n_epochs,
              n_epoch_cycles=1,
              batch_size=None,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs: Number of epochs.
            n_epoch_cycles: Number of batches of samples in each epoch.
                This is only useful for off-policy algorithms.
                For on-policy algorithms this value should always be 1.
            batch_size: Number of steps in batch.
            plot: Visualize policy by doing rollout after each epoch.
            store_paths: Save paths in snapshot.
            pause_for_plot: Pause for plot.

        Returns:
            The average return in last epoch cycle.

        """
        assert self.has_setup, "Use Runner.setup() to set up the runner " \
                               "before training."
        if batch_size is None:
            from garage.tf.samplers import OffPolicyVectorizedSampler
            if isinstance(self.sampler, OffPolicyVectorizedSampler):
                batch_size = self.algo.max_path_length
            else:
                batch_size = 40 * self.algo.max_path_length

        self.n_epoch_cycles = n_epoch_cycles

        self.plot = plot
        self.start_worker()
        self.start_time = time.time()

        itr = 0
        last_return = None
        for epoch in range(n_epochs):
            self.itr_start_time = time.time()
            paths = None
            with logger.prefix('epoch #%d | ' % epoch):
                for cycle in range(n_epoch_cycles):
                    paths = self.obtain_samples(itr, batch_size)
                    paths = self.sampler.process_samples(itr, paths)
                    last_return = self.algo.train_once(itr, paths)
                    itr += 1
                self.save_snapshot(epoch, paths if store_paths else None)
                self.log_diagnostics(pause_for_plot)

        self.shutdown_worker()
        return last_return
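
The loop in train() above calls the sampler once per cycle and saves a snapshot once per epoch, which is why n_epoch_cycles is only useful for off-policy algorithms. A tiny standalone illustration of that bookkeeping (the numbers are made up):

n_epochs, n_epoch_cycles = 3, 2  # illustrative values

itr = 0
for epoch in range(n_epochs):
    for cycle in range(n_epoch_cycles):
        # paths = runner.obtain_samples(itr, batch_size) would run here,
        # followed by sampler.process_samples() and algo.train_once().
        itr += 1
    # runner.save_snapshot(epoch, ...) and log_diagnostics() run once here.
print(itr)  # 6 -> the sampler was invoked n_epochs * n_epoch_cycles times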
Example No. 8
class LocalTFRunner(LocalRunner):
    """This class implements a local runner for TensorFlow algorithms.

    A local runner provides a default TensorFlow session using python context.
    This is useful for those experiment components (e.g. policy) that require a
    TensorFlow session during construction.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        max_cpus (int): The maximum number of parallel sampler workers.
        sess (tf.Session): An optional TensorFlow session.
              A new session will be created immediately if not provided.

    Note:
        The local runner will set up a joblib task pool of size max_cpus
        possibly later used by BatchSampler. If BatchSampler is not used,
        the processes in the pool will remain dormant.

        This setup is required to use TensorFlow in a multiprocess
        environment before a TensorFlow session is created
        because TensorFlow is not fork-safe. See
        https://github.com/tensorflow/tensorflow/issues/2448.

        When resuming via the command line, new snapshots will be
        saved into the SAME directory if not specified.

        When resuming programmatically, the snapshot directory should be
        specified manually or through the run_experiment() interface.

    Examples:
        # to train
        with LocalTFRunner() as runner:
            env = gym.make('CartPole-v1')
            policy = CategoricalMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=(32, 32))
            algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                max_kl_step=0.01)
            runner.setup(algo, env)
            runner.train(n_epochs=100, batch_size=4000)

        # to resume immediately.
        with LocalTFRunner() as runner:
            runner.restore(resume_from_dir)
            runner.resume()

        # to resume with modified training arguments.
        with LocalTFRunner() as runner:
            runner.restore(resume_from_dir)
            runner.resume(n_epochs=20)

    """
    def __init__(self, snapshot_config, sess=None, max_cpus=1):
        super().__init__(snapshot_config=snapshot_config, max_cpus=max_cpus)
        self.sess = sess or tf.compat.v1.Session()
        self.sess_entered = False

    def __enter__(self):
        """Set self.sess as the default session.

        Returns:
            LocalTFRunner: This local runner.

        """
        if tf.compat.v1.get_default_session() is not self.sess:
            self.sess.__enter__()
            self.sess_entered = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave session.

        Args:
            exc_type (str): Type.
            exc_val (object): Value.
            exc_tb (object): Traceback.

        """
        if tf.compat.v1.get_default_session(
        ) is self.sess and self.sess_entered:
            self.sess.__exit__(exc_type, exc_val, exc_tb)
            self.sess_entered = False

    def make_sampler(self,
                     sampler_cls,
                     *,
                     seed=None,
                     n_workers=psutil.cpu_count(logical=False),
                     max_path_length=None,
                     worker_class=DefaultWorker,
                     sampler_args=None,
                     worker_args=None):
        """Construct a Sampler from a Sampler class.

        Args:
            sampler_cls (type): The type of sampler to construct.
            seed (int): Seed to use in sampler workers.
            max_path_length (int): Maximum path length to be sampled by the
                sampler. Paths longer than this will be truncated.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the sampler should use.
            sampler_args (dict or None): Additional arguments that should be
                passed to the sampler.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker.

        Returns:
            sampler_cls: An instance of the sampler class.

        """
        # pylint: disable=useless-super-delegation
        return super().make_sampler(
            sampler_cls,
            seed=seed,
            n_workers=n_workers,
            max_path_length=max_path_length,
            worker_class=TFWorkerClassWrapper(worker_class),
            sampler_args=sampler_args,
            worker_args=worker_args)

    def setup(self,
              algo,
              env,
              sampler_cls=None,
              sampler_args=None,
              n_workers=psutil.cpu_count(logical=False),
              worker_class=DefaultWorker,
              worker_args=None):
        """Set up runner and sessions for algorithm and environment.

        This method saves algo and env within runner and creates a sampler,
        and initializes all uninitialized variables in session.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the sampler should use.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker.

        """
        self.initialize_tf_vars()
        logger.log(self.sess.graph)
        super().setup(algo, env, sampler_cls, sampler_args, n_workers,
                      worker_class, worker_args)

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self._sampler.start_worker()
        if self._plot:
            # pylint: disable=import-outside-toplevel
            from garage.tf.plotter import Plotter
            self._plotter = Plotter(self.get_env_copy(), self._algo.policy)
            self._plotter.start()

    def initialize_tf_vars(self):
        """Initialize all uninitialized variables in session."""
        with tf.name_scope('initialize_tf_vars'):
            uninited_set = [
                e.decode() for e in self.sess.run(
                    tf.compat.v1.report_uninitialized_variables())
            ]
            self.sess.run(
                tf.compat.v1.variables_initializer([
                    v for v in tf.compat.v1.global_variables()
                    if v.name.split(':')[0] in uninited_set
                ]))
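
initialize_tf_vars() above only initializes variables that the session reports as uninitialized, which is what lets weights loaded before setup() survive. A standalone sketch of that pattern using the same tf.compat.v1 API (variable names and values are illustrative):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # graph mode, as in the runner above
sess = tf.compat.v1.Session()

a = tf.compat.v1.Variable(1.0, name='a')
sess.run(a.initializer)                   # 'a' already holds a value
b = tf.compat.v1.Variable(2.0, name='b')  # 'b' does not yet

uninited = [e.decode() for e in
            sess.run(tf.compat.v1.report_uninitialized_variables())]
sess.run(tf.compat.v1.variables_initializer(
    [v for v in tf.compat.v1.global_variables()
     if v.name.split(':')[0] in uninited]))

print(sess.run([a, b]))  # [1.0, 2.0] -- 'a' keeps its pre-existing value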
Example No. 9
class BatchPolopt(RLAlgorithm):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo,
    etc.
    """
    def __init__(self,
                 env,
                 policy,
                 baseline,
                 scope=None,
                 n_itr=500,
                 max_samples=None,
                 start_itr=0,
                 batch_size=5000,
                 max_path_length=500,
                 discount=0.99,
                 gae_lambda=1,
                 plot=False,
                 pause_for_plot=False,
                 center_adv=True,
                 positive_adv=False,
                 store_paths=False,
                 paths_h5_filename=None,
                 whole_paths=True,
                 fixed_horizon=False,
                 sampler_cls=None,
                 sampler_args=None,
                 force_batch_sampler=False,
                 play_every_itr=None,
                 record_every_itr=None,
                 record_end_ep_num=3,
                 **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified if
         running multiple algorithms
        simultaneously, each using different environments and policies
        :param n_itr: Maximum number of iterations.
        :param max_samples: If not None, exit once this many environment
         samples have been collected (overrides n_itr).
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have
         mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are
         always positive. When used in conjunction with center_adv the
         advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :return:
        """
        self.args = locals()
        del self.args["kwargs"]
        del self.args["self"]
        self.args = {**self.args, **kwargs}  #merging dicts

        self.env = env
        try:
            self.env.env.save_dyn_params(
                filename=logger.get_snapshot_dir().rstrip(os.sep) + os.sep +
                "dyn_params.yaml")
        except Exception:
            print("WARNING: BatchPolOpt: couldn't save dynamics params")
            # import pdb; pdb.set_trace()
        from gym.wrappers import Monitor
        # self.env_rec = Monitor(self.env.env, logger.get_snapshot_dir() + os.sep + "videos", force=True)

        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.max_samples = max_samples
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        self.play_every_itr = play_every_itr
        self.record_every_itr = record_every_itr
        self.record_end_ep_num = record_end_ep_num
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = OnPolicyVectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.init_opt()

        ## Initialization of HDF5 logging of trajectories
        if self.store_paths:
            self.h5_prepare_file(filename=paths_h5_filename, args=self.args)

        ## Initialize cleaner if we close
        atexit.register(self.clean_at_exit)

    def record_policy(self,
                      env,
                      policy,
                      itr,
                      n_rollout=1,
                      path=None,
                      postfix=""):
        # Rollout
        if path is None:
            path = logger.get_snapshot_dir().rstrip(
                os.sep) + os.sep + "videos" + os.sep + "itr_%05d%s.mp4" % (
                    itr, postfix)
        path_directory = path.rsplit(os.sep, 1)[0]
        if not os.path.exists(path_directory):
            os.makedirs(path_directory, exist_ok=True)
        for _ in range(n_rollout):
            obs = env.reset()
            recorder = VideoRecorder(env.env, path=path)
            while True:
                # env.render()
                # import pdb; pdb.set_trace()
                action, _ = policy.get_action(obs)
                obs, _, done, _ = env.step(action)
                recorder.capture_frame()
                if done:
                    break
            recorder.close()

    def play_policy(self, env, policy, n_rollout=2):
        # Rollout
        for _ in range(n_rollout):
            obs = env.reset()
            while True:
                env.render()
                action, _ = policy.get_action(obs)
                obs, _, done, _ = env.step(action)
                if done:
                    break

    @staticmethod
    def register_exit_handler(handler_fn):
        # Save will be executed upon normal exit of interpreter
        # NOTE: The functions registered via this module are not called when
        # the program is killed by a signal not handled by Python
        atexit.register(handler_fn)

    def clean_at_exit(self):
        # self.hdf.close()
        pass

    def h5_prepare_file(self, filename, args):
        # Assuming the following structure / indexing of the H5 file
        # teacher_info/
        #   - [teacher_indx]:
        #        - description
        #        - params
        # traj_data/
        #   - [teacher_indx] * [iter_indx] * traj_data

        # Making names and opening h5 file
        if filename is None:
            self.h5_filename = logger.get_snapshot_dir(
            ) + os.sep + "trajectories.h5"
        else:  #capability to store multiple teachers in a single file
            self.h5_filename = filename
        if not self.h5_filename.endswith('.h5'):
            self.h5_filename += '.h5'

        if os.path.exists(self.h5_filename):
            # input("WARNING: output file %s already exists and will be appended. Press ENTER to continue. (exit with ctrl-C)" % self.h5_filename)
            print(
                "WARNING: output file %s already exists and will be appended" %
                self.h5_filename)
        self.hdf = h5py.File(self.h5_filename, "a")

        # Creating proper groups
        groups = list(self.hdf.keys())
        # Groups to create: tuples: (group_name, structure_description)
        create_groups = [("teacher_info", "Runs indices(Teachers)"),
                         ("traj_data",
                          "Runs(Teachers) x Iterations x Trajectories x Data")]

        for group in create_groups:
            if group[0] not in groups:
                self.hdf.create_group(group[0])
                self.hdf[group[0]].attrs["structure"] = np.string_(group[1])

        # Checking if other teachers' results already exist in the h5 file
        # If they exist - just append
        teacher_indices = list(self.hdf["traj_data"].keys())
        if not teacher_indices:
            self.teacher_indx = 0
        else:
            teacher_indices = [int(indx) for indx in teacher_indices]
            teacher_indices = np.sort(teacher_indices)
            self.teacher_indx = teacher_indices[-1] + 1
            print("%s : Appended teacher index: " % self.__class__.__name__,
                  self.teacher_indx)

        self.hdf.create_group("traj_data/" +
                              h5u.indx2str(self.teacher_indx))  #Teacher group

        ## Saving info about the teacher
        teacher_info_group = "teacher_info/" + h5u.indx2str(
            self.teacher_indx) + "/"
        self.hdf.create_group(teacher_info_group)  #Teacher group
        h5u.add_dict(self.hdf, self.args, groupname=teacher_info_group)

        return self.hdf

    def start_worker(self, sess):
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env, self.policy, sess)
            self.plotter.start()

    def shutdown_worker(self):
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        return self.sampler.process_samples(itr, paths)

    def log_env_info(self, env_infos, prefix=""):
        # Logging rewards
        rew_dic = env_infos["rewards"]
        for key in rew_dic.keys():
            rew_sums = np.sum(rew_dic[key], axis=1)
            logger.record_tabular("rewards/" + key + "_avg", np.mean(rew_sums))
            logger.record_tabular("rewards/" + key + "_std", np.std(rew_sums))

    def train(self, sess=None):
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()
            sess.run(tf.global_variables_initializer())

        # Initialize some missing variables
        uninitialized_vars = []
        for var in tf.all_variables():
            try:
                sess.run(var)
            except tf.errors.FailedPreconditionError:
                print("Uninitialized var: ", var)
                uninitialized_vars.append(var)
        init_new_vars_op = tf.initialize_variables(uninitialized_vars)
        sess.run(init_new_vars_op)

        self.start_worker(sess)
        start_time = time.time()
        last_average_return = None
        samples_total = 0
        for itr in range(self.start_itr, self.n_itr):
            if (self.max_samples is not None
                    and samples_total >= self.max_samples):
                print("WARNING: Total max num of samples collected: %d >= %d" %
                      (samples_total, self.max_samples))
                break
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                samples_total += self.batch_size
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                # import pdb; pdb.set_trace()
                if self.store_paths:
                    ## WARN: Beware that data is saved to hdf in float32 by default
                    # see param float_nptype
                    h5u.append_train_iter_data(h5file=self.hdf,
                                               data=samples_data["paths"],
                                               data_group="traj_data/",
                                               teacher_indx=self.teacher_indx,
                                               itr=None,
                                               float_nptype=np.float32)
                    # params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                self.log_env_info(samples_data["env_infos"])
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input(
                            "Plotting evaluation run: Press Enter to continue..."
                        )
                # Showing policy from time to time
                if self.record_every_itr is not None and self.record_every_itr > 0 and itr % self.record_every_itr == 0:
                    self.record_policy(env=self.env,
                                       policy=self.policy,
                                       itr=itr)
                if self.play_every_itr is not None and self.play_every_itr > 0 and itr % self.play_every_itr == 0:
                    self.play_policy(env=self.env, policy=self.policy)

        # Recording a few episodes at the end
        if self.record_end_ep_num is not None:
            for i in range(self.record_end_ep_num):
                self.record_policy(env=self.env,
                                   policy=self.policy,
                                   itr=itr,
                                   postfix="_%02d" % i)

        # Reporting termination criteria
        if itr >= self.n_itr - 1:
            print(
                "TERM CRITERIA: Max number of iterations reached itr: %d , itr_max: %d"
                % (itr, self.n_itr - 1))
        if (self.max_samples is not None
                and samples_total >= self.max_samples):
            print(
                "TERM CRITERIA: Total max num of samples collected: %d >= %d" %
                (samples_total, self.max_samples))

        self.shutdown_worker()
        if created_session:
            sess.close()
        return last_average_return

    def log_diagnostics(self, paths):
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

        path_lengths = [path["returns"].size for path in paths]
        logger.record_tabular('ep_len_avg', np.mean(path_lengths))
        logger.record_tabular('ep_len_std', np.std(path_lengths))

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError
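
The h5_prepare_file() method above relies on helper functions from an h5u module that is not reproduced here. The following is a self-contained h5py sketch of the same group layout described in its comment; the file name and the zero-padded index format are assumptions for illustration only:

import h5py

with h5py.File('trajectories_example.h5', 'a') as hdf:
    # Top-level groups, mirroring the structure documented above.
    for name, structure in (('teacher_info', 'Runs indices(Teachers)'),
                            ('traj_data',
                             'Runs(Teachers) x Iterations x Trajectories x Data')):
        if name not in hdf:
            grp = hdf.create_group(name)
            grp.attrs['structure'] = structure  # h5py stores str attrs directly

    # Append a new teacher: one past the largest existing index, or 0.
    existing = sorted(int(k) for k in hdf['traj_data'].keys())
    teacher_indx = existing[-1] + 1 if existing else 0
    hdf.create_group('traj_data/%05d' % teacher_indx)
    hdf.create_group('teacher_info/%05d' % teacher_indx)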
Example No. 10
class LocalRunner:
    """This class implements a local runner for tensorflow algorithms.

    A local runner provides a default tensorflow session using python context.
    This is useful for those experiment components (e.g. policy) that require a
    tensorflow session during construction.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        max_cpus (int): The maximum number of parallel sampler workers.
        sess (tf.Session): An optional tensorflow session.
              A new session will be created immediately if not provided.

    Note:
        The local runner will set up a joblib task pool of size max_cpus
        possibly later used by BatchSampler. If BatchSampler is not used,
        the processes in the pool will remain dormant.

        This setup is required to use tensorflow in a multiprocess
        environment before a tensorflow session is created
        because tensorflow is not fork-safe.

        See https://github.com/tensorflow/tensorflow/issues/2448.

    Examples:
        with LocalRunner() as runner:
            env = gym.make('CartPole-v1')
            policy = CategoricalMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=(32, 32))
            algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                max_kl_step=0.01)
            runner.setup(algo, env)
            runner.train(n_epochs=100, batch_size=4000)

    """
    def __init__(self, snapshot_config=None, sess=None, max_cpus=1):
        if snapshot_config:
            self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                            snapshot_config.snapshot_mode,
                                            snapshot_config.snapshot_gap)
        else:
            self._snapshotter = Snapshotter()

        if max_cpus > 1:
            from garage.sampler import singleton_pool
            singleton_pool.initialize(max_cpus)
        self.sess = sess or tf.Session()
        self.sess_entered = False
        self.has_setup = False
        self.plot = False

        self._setup_args = None
        self.train_args = None

    def __enter__(self):
        """Set self.sess as the default session.

        Returns:
            This local runner.

        """
        if tf.get_default_session() is not self.sess:
            self.sess.__enter__()
            self.sess_entered = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave session."""
        if tf.get_default_session() is self.sess and self.sess_entered:
            self.sess.__exit__(exc_type, exc_val, exc_tb)
            self.sess_entered = False

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.

        """
        self.algo = algo
        self.env = env
        self.policy = self.algo.policy

        if sampler_args is None:
            sampler_args = {}

        if sampler_cls is None:
            from garage.tf.algos.batch_polopt import BatchPolopt
            if isinstance(algo, BatchPolopt):
                if self.policy.vectorized:
                    from garage.tf.samplers import OnPolicyVectorizedSampler
                    sampler_cls = OnPolicyVectorizedSampler
                else:
                    from garage.tf.samplers import BatchSampler
                    sampler_cls = BatchSampler
            else:
                from garage.tf.samplers import OffPolicyVectorizedSampler
                sampler_cls = OffPolicyVectorizedSampler

        self.sampler = sampler_cls(algo, env, **sampler_args)

        self.initialize_tf_vars()
        logger.log(self.sess.graph)
        self.has_setup = True

        self._setup_args = types.SimpleNamespace(sampler_cls=sampler_cls,
                                                 sampler_args=sampler_args)

    def initialize_tf_vars(self):
        """Initialize all uninitialized variables in session."""
        with tf.name_scope('initialize_tf_vars'):
            uninited_set = [
                e.decode()
                for e in self.sess.run(tf.report_uninitialized_variables())
            ]
            self.sess.run(
                tf.variables_initializer([
                    v for v in tf.global_variables()
                    if v.name.split(':')[0] in uninited_set
                ]))

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self.sampler.start_worker()
        if self.plot:
            from garage.tf.plotter import Plotter
            self.plotter = Plotter(self.env, self.policy)
            self.plotter.start()

    def _shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr, batch_size):
        """Obtain one batch of samples.

        Args:
            itr(int): Index of iteration (epoch).
            batch_size(int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            One batch of samples.

        """
        if self.train_args.n_epoch_cycles == 1:
            logger.log('Obtaining samples...')
        return self.sampler.obtain_samples(itr, batch_size)

    def save(self, epoch, paths=None):
        """Save snapshot of current batch.

        Args:
            epoch (int): Index of the epoch to save.
            paths (dict): Batch of samples after preprocessing. If None,
                no paths will be logged to the snapshot.

        """
        if not self.has_setup:
            raise Exception('Use setup() to set up the runner before saving.')

        logger.log('Saving snapshot...')

        params = dict()
        # Save arguments
        params['setup_args'] = self._setup_args
        params['train_args'] = self.train_args

        # Save states
        params['env'] = self.env
        params['algo'] = self.algo
        if paths:
            params['paths'] = paths
        params['last_epoch'] = epoch
        self._snapshotter.save_itr_params(epoch, params)

        logger.log('Saved')

    def restore(self, from_dir, from_epoch='last'):
        """Restore experiment from snapshot.

        Args:
            from_dir (str): Directory of the pickle file
                to resume experiment from.
            from_epoch (str or int): The epoch to restore from.
                Can be 'first', 'last' or a number.
                Not applicable when snapshot_mode='last'.

        Returns:
            A SimpleNamespace for train()'s arguments.

        Examples:
            1. Resume experiment immediately.
            with LocalRunner() as runner:
                runner.restore(resume_from_dir)
                runner.resume()

            2. Resume experiment with modified training arguments.
             with LocalRunner() as runner:
                runner.restore(resume_from_dir)
                runner.resume(n_epochs=20)

        Note:
            When resuming via the command line, new snapshots will be
            saved into the SAME directory if not specified.

            When resuming programmatically, the snapshot directory should be
            specified manually or through the run_experiment() interface.

        """
        saved = self._snapshotter.load(from_dir, from_epoch)

        self._setup_args = saved['setup_args']
        self.train_args = saved['train_args']

        self.setup(env=saved['env'],
                   algo=saved['algo'],
                   sampler_cls=self._setup_args.sampler_cls,
                   sampler_args=self._setup_args.sampler_args)

        n_epochs = self.train_args.n_epochs
        last_epoch = saved['last_epoch']
        n_epoch_cycles = self.train_args.n_epoch_cycles
        batch_size = self.train_args.batch_size
        store_paths = self.train_args.store_paths
        pause_for_plot = self.train_args.pause_for_plot

        fmt = '{:<20} {:<15}'
        logger.log('Restore from snapshot saved in %s' %
                   self._snapshotter.snapshot_dir)
        logger.log(fmt.format('Train Args', 'Value'))
        logger.log(fmt.format('n_epochs', n_epochs))
        logger.log(fmt.format('last_epoch', last_epoch))
        logger.log(fmt.format('n_epoch_cycles', n_epoch_cycles))
        logger.log(fmt.format('batch_size', batch_size))
        logger.log(fmt.format('store_paths', store_paths))
        logger.log(fmt.format('pause_for_plot', pause_for_plot))

        self.train_args.start_epoch = last_epoch + 1
        return copy.copy(self.train_args)

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot(bool): Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self._start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time))
        logger.log(tabular)
        if self.plot:
            self.plotter.update_plot(self.policy, self.algo.max_path_length)
            if pause_for_plot:
                input('Plotting evaluation run: Press Enter to continue...')

    def train(self,
              n_epochs,
              batch_size,
              n_epoch_cycles=1,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs(int): Number of epochs.
            batch_size(int): Number of environment steps in one batch.
            n_epoch_cycles(int): Number of batches of samples in each epoch.
                This is only useful for off-policy algorithms.
                For on-policy algorithms this value should always be 1.
            plot(bool): Visualize policy by doing rollout after each epoch.
            store_paths(bool): Save paths in snapshot.
            pause_for_plot(bool): Pause for plot.

        Returns:
            The average return in the last epoch cycle.

        """
        if not self.has_setup:
            raise Exception('Use setup() to set up the runner before training.')

        # Save arguments for restore
        self.train_args = types.SimpleNamespace(n_epochs=n_epochs,
                                                n_epoch_cycles=n_epoch_cycles,
                                                batch_size=batch_size,
                                                plot=plot,
                                                store_paths=store_paths,
                                                pause_for_plot=pause_for_plot,
                                                start_epoch=0)

        self.plot = plot

        return self.algo.train(self, batch_size)

    def step_epochs(self):
        """Generator for training.

        This function serves as a generator. It separates services such
        as snapshotting and sampler control from the actual training loop.
        It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        try:
            self._start_worker()
            self._start_time = time.time()
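            # step_itr counts sampler iterations; there are n_epoch_cycles
            # of them per epoch, so resuming from a saved epoch has to
            # fast-forward the counter accordingly.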
            self.step_itr = (self.train_args.start_epoch *
                             self.train_args.n_epoch_cycles)
            self.step_path = None

            for epoch in range(self.train_args.start_epoch,
                               self.train_args.n_epochs):
                self._itr_start_time = time.time()
                with logger.prefix('epoch #%d | ' % epoch):
                    yield epoch
                    save_path = (self.step_path
                                 if self.train_args.store_paths else None)
                    self.save(epoch, save_path)
                    self.log_diagnostics(self.train_args.pause_for_plot)
                    logger.dump_all(self.step_itr)
                    tabular.clear()
        finally:
            self._shutdown_worker()

    def resume(self,
               n_epochs=None,
               batch_size=None,
               n_epoch_cycles=None,
               plot=None,
               store_paths=None,
               pause_for_plot=None):
        """Resume from restored experiment.

        This method provides the same interface as train().

        If an argument is not specified, it defaults to the value
        saved from the last call to train().

        Returns:
            The average return in the last epoch cycle.

        """
        assert self.train_args is not None, (
            'You must call restore() before resume().')

        self.train_args.n_epochs = n_epochs or self.train_args.n_epochs
        self.train_args.batch_size = batch_size or self.train_args.batch_size
        self.train_args.n_epoch_cycles = (n_epoch_cycles
                                          or self.train_args.n_epoch_cycles)

        if plot is not None:
            self.train_args.plot = plot
        if store_paths is not None:
            self.train_args.store_paths = store_paths
        if pause_for_plot is not None:
            self.train_args.pause_for_plot = pause_for_plot

        return self.algo.train(self, self.train_args.batch_size)
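
A minimal usage sketch of the restore()/resume() workflow above. It assumes the runner class is importable as garage.experiment.LocalRunner and that resume_from_dir points at a snapshot directory from a previous run; the path below is hypothetical.

from garage.experiment import LocalRunner

resume_from_dir = 'data/local/experiment/my_exp'  # hypothetical path

with LocalRunner() as runner:
    # restore() reloads setup_args/train_args and calls setup() again.
    saved_args = runner.restore(resume_from_dir)
    # resume() keeps the saved arguments unless they are overridden here.
    runner.resume(n_epochs=saved_args.n_epochs + 10)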
Example No. 11
class DDPG(RLAlgorithm):
    """
    A DDPG model based on https://arxiv.org/pdf/1509.02971.pdf.

    Example:
        $ python garage/examples/tf/ddpg_pendulum.py
    """

    def __init__(self,
                 env,
                 actor,
                 critic,
                 n_epochs=500,
                 n_epoch_cycles=20,
                 n_rollout_steps=100,
                 n_train_steps=50,
                 reward_scale=1.,
                 batch_size=64,
                 target_update_tau=0.01,
                 discount=0.99,
                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 actor_weight_decay=0,
                 critic_weight_decay=0,
                 replay_buffer_size=int(1e6),
                 min_buffer_size=10000,
                 exploration_strategy=None,
                 plot=False,
                 pause_for_plot=False,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 use_her=False,
                 clip_obs=np.inf,
                 clip_pos_returns=True,
                 clip_return=None,
                 replay_k=4,
                 max_action=None,
                 name=None):
        """
        Construct class.

        Args:
            env(): Environment.
            actor(garage.tf.policies.ContinuousMLPPolicy): Policy network.
            critic(garage.tf.q_functions.ContinuousMLPQFunction): Q-value
                network.
            n_epochs(int, optional): Number of epochs.
            n_epoch_cycles(int, optional): Number of epoch cycles.
            n_rollout_steps(int, optional): Number of rollout steps, i.e.
                the time horizon of a rollout.
            n_train_steps(int, optional): Number of train steps.
            reward_scale(float): The scaling factor applied to the rewards
                during training.
            batch_size(int): Number of samples for each minibatch.
            target_update_tau(float): Interpolation parameter for the soft
                target update.
            discount(float): Discount factor for the cumulative return.
            actor_lr(float): Learning rate for training the policy network.
            critic_lr(float): Learning rate for training the Q-value network.
            actor_weight_decay(float): L2 weight decay factor for parameters
                of the policy network.
            critic_weight_decay(float): L2 weight decay factor for parameters
                of the Q-value network.
            replay_buffer_size(int): Size of the replay buffer.
            min_buffer_size(int): Minimum size of the replay buffer before
                training starts.
            exploration_strategy(): Exploration strategy used to randomize
                the action.
            plot(bool): Whether to visualize the policy performance after
                each eval_interval.
            pause_for_plot(bool): Whether or not to pause before continuing
                when plotting.
            actor_optimizer(): Optimizer for training the policy network.
            critic_optimizer(): Optimizer for training the Q-function
                network.
            use_her(bool): Whether or not to use HER for the replay buffer.
            clip_obs(float): Clip observations to be in [-clip_obs,
                clip_obs].
            clip_pos_returns(bool): Whether or not to clip positive returns.
            clip_return(float): Clip returns to be in [-clip_return,
                clip_return].
            replay_k(int): The ratio between HER replays and regular
                replays. Only used when use_her is True.
            max_action(float): Maximum action magnitude.
            name(str): Name of the algorithm shown in the computation graph.
        """
        self.env = env

        self.input_dims = configure_dims(env)
        action_bound = env.action_space.high
        self.max_action = action_bound if max_action is None else max_action

        self.actor = actor
        self.critic = critic
        self.n_epochs = n_epochs
        self.n_epoch_cycles = n_epoch_cycles
        self.n_rollout_steps = n_rollout_steps
        self.n_train_steps = n_train_steps
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.tau = target_update_tau
        self.discount = discount
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.actor_weight_decay = actor_weight_decay
        self.critic_weight_decay = critic_weight_decay
        self.replay_buffer_size = replay_buffer_size
        self.min_buffer_size = min_buffer_size
        self.es = exploration_strategy
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.name = name
        self.use_her = use_her
        self.evaluate = False
        self.replay_k = replay_k
        self.clip_return = (
            1. / (1. - self.discount)) if clip_return is None else clip_return
        self.clip_obs = clip_obs
        self.clip_pos_returns = clip_pos_returns
        self.success_history = deque(maxlen=100)
        self._initialize()

    @overrides
    def train(self, sess=None):
        """
        Training process of the DDPG algorithm.

        Args:
            sess: A TensorFlow session for executing ops.
        """
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        # Start plotter
        if self.plot:
            self.plotter = Plotter(self.env, self.actor, sess)
            self.plotter.start()

        sess.run(tf.global_variables_initializer())
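        # Hard-copy the freshly initialized online weights into the target
        # networks before training starts.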
        self.f_init_target()

        observation = self.env.reset()
        if self.es:
            self.es.reset()

        episode_reward = 0.
        episode_step = 0
        episode_rewards = []
        episode_steps = []
        episode_actor_losses = []
        episode_critic_losses = []
        episodes = 0
        epoch_ys = []
        epoch_qs = []

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            self.success_history.clear()
            for epoch_cycle in pyprind.prog_bar(range(self.n_epoch_cycles)):
                if self.use_her:
                    successes = []
                    for rollout in range(self.n_rollout_steps):
                        o = np.clip(observation["observation"], -self.clip_obs,
                                    self.clip_obs)
                        g = np.clip(observation["desired_goal"],
                                    -self.clip_obs, self.clip_obs)
                        obs_goal = np.concatenate((o, g), axis=-1)
                        action = self.es.get_action(rollout, obs_goal,
                                                    self.actor)

                        next_observation, reward, terminal, info = self.env.step(  # noqa: E501
                            action)
                        if 'is_success' in info:
                            successes.append([info["is_success"]])
                        episode_reward += reward
                        episode_step += 1

                        info_dict = {
                            "info_{}".format(key): info[key].reshape(1)
                            for key in info.keys()
                        }
                        self.replay_buffer.add_transition(
                            observation=observation['observation'],
                            action=action,
                            goal=observation['desired_goal'],
                            achieved_goal=observation['achieved_goal'],
                            **info_dict,
                        )

                        observation = next_observation

                        if rollout == self.n_rollout_steps - 1:
                            self.replay_buffer.add_transition(
                                observation=observation['observation'],
                                achieved_goal=observation['achieved_goal'])

                            episode_rewards.append(episode_reward)
                            episode_steps.append(episode_step)
                            episode_reward = 0.
                            episode_step = 0
                            episodes += 1

                            observation = self.env.reset()
                            if self.es:
                                self.es.reset()

                    successful = np.array(successes)[-1, :]
                    success_rate = np.mean(successful)
                    self.success_history.append(success_rate)

                    for train_itr in range(self.n_train_steps):
                        self.evaluate = True
                        critic_loss, y, q, action_loss = self._learn()

                        episode_actor_losses.append(action_loss)
                        episode_critic_losses.append(critic_loss)
                        epoch_ys.append(y)
                        epoch_qs.append(q)

                    self.f_update_target()
                else:
                    for rollout in range(self.n_rollout_steps):
                        action = self.es.get_action(rollout, observation,
                                                    self.actor)
                        assert action.shape == self.env.action_space.shape

                        next_observation, reward, terminal, info = self.env.step(  # noqa: E501
                            action)
                        episode_reward += reward
                        episode_step += 1

                        self.replay_buffer.add_transition(
                            observation=observation,
                            action=action,
                            reward=reward * self.reward_scale,
                            terminal=terminal,
                            next_observation=next_observation,
                        )

                        observation = next_observation

                        if terminal or rollout == self.n_rollout_steps - 1:
                            episode_rewards.append(episode_reward)
                            episode_steps.append(episode_step)
                            episode_reward = 0.
                            episode_step = 0
                            episodes += 1

                            observation = self.env.reset()
                            if self.es:
                                self.es.reset()

                    for train_itr in range(self.n_train_steps):
                        if self.replay_buffer.size >= self.min_buffer_size:
                            self.evaluate = True
                            critic_loss, y, q, action_loss = self._learn()

                            episode_actor_losses.append(action_loss)
                            episode_critic_losses.append(critic_loss)
                            epoch_ys.append(y)
                            epoch_qs.append(q)

            logger.log("Training finished")
            logger.log("Saving snapshot")
            itr = epoch * self.n_epoch_cycles + epoch_cycle
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            if self.evaluate:
                logger.record_tabular('Epoch', epoch)
                logger.record_tabular('Episodes', episodes)
                logger.record_tabular('AverageReturn',
                                      np.mean(episode_rewards))
                logger.record_tabular('StdReturn', np.std(episode_rewards))
                logger.record_tabular('Policy/AveragePolicyLoss',
                                      np.mean(episode_actor_losses))
                logger.record_tabular('QFunction/AverageQFunctionLoss',
                                      np.mean(episode_critic_losses))
                logger.record_tabular('QFunction/AverageQ', np.mean(epoch_qs))
                logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                logger.record_tabular('QFunction/AverageAbsQ',
                                      np.mean(np.abs(epoch_qs)))
                logger.record_tabular('QFunction/AverageY', np.mean(epoch_ys))
                logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                logger.record_tabular('QFunction/AverageAbsY',
                                      np.mean(np.abs(epoch_ys)))
                if self.use_her:
                    logger.record_tabular('AverageSuccessRate',
                                          np.mean(self.success_history))

                # Uncomment the following to reset these statistics and
                # report per-epoch averages (recommended when use_her=True).
                # episode_rewards = []
                # episode_actor_losses = []
                # episode_critic_losses = []
                # epoch_ys = []
                # epoch_qs = []

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.actor, self.n_rollout_steps)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        if self.plot:
            self.plotter.close()
        if created_session:
            sess.close()

    def _initialize(self):
        with tf.name_scope(self.name, "DDPG"):
            with tf.name_scope("setup_networks"):
                """Set up the actor, critic and target network."""
                # Set up the actor and critic network
                self.actor._build_net(trainable=True)
                self.critic._build_net(trainable=True)

                # Create target actor and critic network
                target_actor = copy(self.actor)
                target_critic = copy(self.critic)

                # Set up the target network
                target_actor.name = "TargetActor"
                target_actor._build_net(trainable=False)
                target_critic.name = "TargetCritic"
                target_critic._build_net(trainable=False)

            input_shapes = dims_to_shapes(self.input_dims)

            # Initialize replay buffer
            if self.use_her:
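                # Each episode contributes n_rollout_steps transitions, but
                # `observation` and `achieved_goal` also store the final
                # state, so those keys get n_rollout_steps + 1 slots.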
                buffer_shapes = {
                    key: (self.n_rollout_steps + 1
                          if key == "observation" or key == "achieved_goal"
                          else self.n_rollout_steps, *input_shapes[key])
                    for key in input_shapes
                }

                replay_buffer = HerReplayBuffer(
                    buffer_shapes=buffer_shapes,
                    size_in_transitions=self.replay_buffer_size,
                    time_horizon=self.n_rollout_steps,
                    sample_transitions=make_her_sample(
                        self.replay_k, self.env.compute_reward))
            else:
                replay_buffer = ReplayBuffer(
                    buffer_shapes=input_shapes,
                    max_buffer_size=self.replay_buffer_size)

            # Set up target init and update function
            with tf.name_scope("setup_target"):
                actor_init_ops, actor_update_ops = get_target_ops(
                    self.actor.global_vars, target_actor.global_vars, self.tau)
                critic_init_ops, critic_update_ops = get_target_ops(
                    self.critic.global_vars, target_critic.global_vars,
                    self.tau)
                target_init_op = actor_init_ops + critic_init_ops
                target_update_op = actor_update_ops + critic_update_ops

            f_init_target = tensor_utils.compile_function(
                inputs=[], outputs=target_init_op)
            f_update_target = tensor_utils.compile_function(
                inputs=[], outputs=target_update_op)

            with tf.name_scope("inputs"):
                obs_dim = (
                    self.input_dims["observation"] + self.input_dims["goal"]
                ) if self.use_her else self.input_dims["observation"]
                y = tf.placeholder(tf.float32, shape=(None, 1), name="input_y")
                obs = tf.placeholder(
                    tf.float32,
                    shape=(None, obs_dim),
                    name="input_observation")
                actions = tf.placeholder(
                    tf.float32,
                    shape=(None, self.input_dims["action"]),
                    name="input_action")

            # Set up actor training function
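            # Deterministic policy gradient: train the actor to maximize
            # Q(s, mu(s)) by minimizing -mean(Q(s, mu(s))), plus optional
            # L2 regularization on the actor weights.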
            next_action = self.actor.get_action_sym(obs, name="actor_action")
            next_qval = self.critic.get_qval_sym(
                obs, next_action, name="actor_qval")
            with tf.name_scope("action_loss"):
                action_loss = -tf.reduce_mean(next_qval)
                if self.actor_weight_decay > 0.:
                    actor_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.actor_weight_decay),
                        weights_list=self.actor.regularizable_vars)
                    action_loss += actor_reg

            with tf.name_scope("minimize_action_loss"):
                actor_train_op = self.actor_optimizer(
                    self.actor_lr, name="ActorOptimizer").minimize(
                        action_loss, var_list=self.actor.trainable_vars)

            f_train_actor = tensor_utils.compile_function(
                inputs=[obs], outputs=[actor_train_op, action_loss])

            # Set up critic training function
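            # The critic minimizes the mean squared error between its
            # prediction Q(s, a) and the Bellman target y computed in
            # _learn() and fed in at training time.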
            qval = self.critic.get_qval_sym(obs, actions, name="q_value")
            with tf.name_scope("qval_loss"):
                qval_loss = tf.reduce_mean(tf.squared_difference(y, qval))
                if self.critic_weight_decay > 0.:
                    critic_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.critic_weight_decay),
                        weights_list=self.critic.regularizable_vars)
                    qval_loss += critic_reg

            with tf.name_scope("minimize_critic_loss"):
                critic_train_op = self.critic_optimizer(
                    self.critic_lr, name="CriticOptimizer").minimize(
                        qval_loss, var_list=self.critic.trainable_vars)

            f_train_critic = tensor_utils.compile_function(
                inputs=[y, obs, actions],
                outputs=[critic_train_op, qval_loss, qval])

            self.f_train_actor = f_train_actor
            self.f_train_critic = f_train_critic
            self.f_init_target = f_init_target
            self.f_update_target = f_update_target
            self.replay_buffer = replay_buffer
            self.target_critic = target_critic
            self.target_actor = target_actor

    def _learn(self):
        """
        Perform one optimization step for the critic and the actor.

        Returns:
            qval_loss: Loss of the Q-values predicted by the Q-network.
            ys: Bellman targets used to train the Q-network.
            qval: Q-values predicted by the Q-network.
            action_loss: Loss of the actions predicted by the policy network.

        """
        if self.use_her:
            transitions = self.replay_buffer.sample(self.batch_size)
            observations = transitions["observation"]
            rewards = transitions["reward"]
            actions = transitions["action"]
            next_observations = transitions["next_observation"]
            goals = transitions["goal"]

            next_inputs = np.concatenate((next_observations, goals), axis=-1)
            inputs = np.concatenate((observations, goals), axis=-1)

            rewards = rewards.reshape(-1, 1)

            target_actions, _ = self.target_actor.get_actions(next_inputs)
            target_qvals = self.target_critic.get_qval(next_inputs,
                                                       target_actions)

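            # With the sparse {-1, 0} rewards typically used in HER goal
            # environments, the discounted return is bounded below by
            # -1 / (1 - discount); targets are clipped to that range (and
            # to at most 0 when clip_pos_returns is set).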
            clip_range = (-self.clip_return, 0.
                          if self.clip_pos_returns else np.inf)
            ys = np.clip(rewards + self.discount * target_qvals, clip_range[0],
                         clip_range[1])

            _, qval_loss, qval = self.f_train_critic(ys, inputs, actions)
            _, action_loss = self.f_train_actor(inputs)
        else:
            transitions = self.replay_buffer.sample(self.batch_size)
            observations = transitions["observation"]
            rewards = transitions["reward"]
            actions = transitions["action"]
            terminals = transitions["terminal"]
            next_observations = transitions["next_observation"]

            rewards = rewards.reshape(-1, 1)
            terminals = terminals.reshape(-1, 1)

            target_actions, _ = self.target_actor.get_actions(
                next_observations)
            target_qvals = self.target_critic.get_qval(next_observations,
                                                       target_actions)

            ys = rewards + (1.0 - terminals) * self.discount * target_qvals

            _, qval_loss, qval = self.f_train_critic(ys, observations, actions)
            _, action_loss = self.f_train_actor(observations)
            self.f_update_target()

        return qval_loss, ys, qval, action_loss

    def get_itr_snapshot(self, itr):
        return dict(itr=itr, policy=self.actor, env=self.env)
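
The target networks above are wired up through get_target_ops, whose definition is not shown in this example. A minimal sketch, assuming TF1-style variables and the same (main_vars, target_vars, tau) arguments; the function name and body are illustrative, not the library's actual implementation.

import tensorflow as tf


def get_target_ops_sketch(main_vars, target_vars, tau):
    """Hard-copy init ops plus Polyak-averaged soft-update ops."""
    init_ops = []
    update_ops = []
    for main_var, target_var in zip(main_vars, target_vars):
        # Initialization copies the online weights into the target network.
        init_ops.append(tf.assign(target_var, main_var))
        # Soft update: target <- tau * main + (1 - tau) * target.
        update_ops.append(
            tf.assign(target_var, tau * main_var + (1. - tau) * target_var))
    return init_ops, update_ops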
Example No. 12
class DDPG(RLAlgorithm):
    """
    A DDPG model based on https://arxiv.org/pdf/1509.02971.pdf.

    Example:
        $ python garage/examples/tf/ddpg_pendulum.py
    """
    def __init__(self,
                 env,
                 actor,
                 critic,
                 n_epochs=500,
                 n_epoch_cycles=20,
                 n_rollout_steps=100,
                 n_train_steps=50,
                 reward_scale=1.,
                 batch_size=64,
                 target_update_tau=0.01,
                 discount=0.99,
                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 actor_weight_decay=0,
                 critic_weight_decay=0,
                 replay_buffer_size=int(1e6),
                 min_buffer_size=10000,
                 exploration_strategy=None,
                 plot=False,
                 pause_for_plot=False,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 name=None):
        """
        Construct class.

        Args:
            env(): Environment.
            actor(garage.tf.policies.ContinuousMLPPolicy): Policy network.
            critic(garage.tf.q_functions.ContinuousMLPQFunction): Q-value
                network.
            n_epochs(int, optional): Number of epochs.
            n_epoch_cycles(int, optional): Number of epoch cycles.
            n_rollout_steps(int, optional): Number of rollout steps.
            n_train_steps(int, optional): Number of train steps.
            reward_scale(float): The scaling factor applied to the rewards
                during training.
            batch_size(int): Number of samples for each minibatch.
            target_update_tau(float): Interpolation parameter for the soft
                target update.
            discount(float): Discount factor for the cumulative return.
            actor_lr(float): Learning rate for training the policy network.
            critic_lr(float): Learning rate for training the Q-value network.
            actor_weight_decay(float): L2 weight decay factor for parameters
                of the policy network.
            critic_weight_decay(float): L2 weight decay factor for parameters
                of the Q-value network.
            replay_buffer_size(int): Size of the replay buffer.
            min_buffer_size(int): Minimum size of the replay buffer before
                training starts.
            exploration_strategy(): Exploration strategy.
            plot(bool): Whether to visualize the policy performance after
                each eval_interval.
            pause_for_plot(bool): Whether to pause before continuing when
                plotting.
            actor_optimizer(): Optimizer for training the policy network.
            critic_optimizer(): Optimizer for training the Q-function
                network.
            name(str): Name of the algorithm shown in the computation graph.
        """
        self.env = env

        self.observation_dim = flat_dim(env.observation_space)
        self.action_dim = flat_dim(env.action_space)
        _, self.action_bound = bounds(env.action_space)

        self.actor = actor
        self.critic = critic
        self.n_epochs = n_epochs
        self.n_epoch_cycles = n_epoch_cycles
        self.n_rollout_steps = n_rollout_steps
        self.n_train_steps = n_train_steps
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.tau = target_update_tau
        self.discount = discount
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.actor_weight_decay = actor_weight_decay
        self.critic_weight_decay = critic_weight_decay
        self.replay_buffer_size = replay_buffer_size
        self.min_buffer_size = min_buffer_size
        self.es = exploration_strategy
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.name = name
        self._initialize()

    @overrides
    def train(self, sess=None):
        """
        Training process of the DDPG algorithm.

        Args:
            sess: A TensorFlow session for executing ops.
        """
        replay_buffer = self.opt_info["replay_buffer"]
        f_init_target = self.opt_info["f_init_target"]
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        # Start plotter
        if self.plot:
            self.plotter = Plotter(self.env, self.actor, sess)
            self.plotter.start()

        sess.run(tf.global_variables_initializer())
        f_init_target()

        observation = self.env.reset()
        if self.es:
            self.es.reset()

        episode_reward = 0.
        episode_step = 0
        episode_rewards = []
        episode_steps = []
        episode_actor_losses = []
        episode_critic_losses = []
        episodes = 0
        epoch_ys = []
        epoch_qs = []

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_cycle in pyprind.prog_bar(range(self.n_epoch_cycles)):
                for rollout in range(self.n_rollout_steps):
                    action = self.es.get_action(rollout, observation,
                                                self.actor)
                    assert action.shape == self.env.action_space.shape

                    next_observation, reward, terminal, info = self.env.step(
                        action)
                    episode_reward += reward
                    episode_step += 1

                    replay_buffer.add_transition(observation, action,
                                                 reward * self.reward_scale,
                                                 terminal, next_observation)

                    observation = next_observation

                    if terminal:
                        episode_rewards.append(episode_reward)
                        episode_steps.append(episode_step)
                        episode_reward = 0.
                        episode_step = 0
                        episodes += 1

                        observation = self.env.reset()
                        if self.es:
                            self.es.reset()

                for train_itr in range(self.n_train_steps):
                    if replay_buffer.size >= self.min_buffer_size:
                        critic_loss, y, q, action_loss = self._learn()

                        episode_actor_losses.append(action_loss)
                        episode_critic_losses.append(critic_loss)
                        epoch_ys.append(y)
                        epoch_qs.append(q)

            logger.log("Training finished")
            if replay_buffer.size >= self.min_buffer_size:
                logger.record_tabular('Epoch', epoch)
                logger.record_tabular('Episodes', episodes)
                logger.record_tabular('AverageReturn',
                                      np.mean(episode_rewards))
                logger.record_tabular('StdReturn', np.std(episode_rewards))
                logger.record_tabular('Policy/AveragePolicyLoss',
                                      np.mean(episode_actor_losses))
                logger.record_tabular('QFunction/AverageQFunctionLoss',
                                      np.mean(episode_critic_losses))
                logger.record_tabular('QFunction/AverageQ', np.mean(epoch_qs))
                logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                logger.record_tabular('QFunction/AverageAbsQ',
                                      np.mean(np.abs(epoch_qs)))
                logger.record_tabular('QFunction/AverageY', np.mean(epoch_ys))
                logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                logger.record_tabular('QFunction/AverageAbsY',
                                      np.mean(np.abs(epoch_ys)))

                # Uncomment the following to reset these statistics and
                # report per-epoch averages.
                # episode_rewards = []
                # episode_actor_losses = []
                # episode_critic_losses = []
                # epoch_ys = []
                # epoch_qs = []

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.actor, self.n_rollout_steps)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        if self.plot:
            self.plotter.shutdown()
        if created_session:
            sess.close()

    def _initialize(self):
        with tf.name_scope(self.name, "DDPG"):
            with tf.name_scope("setup_networks"):
                """Set up the actor, critic and target network."""
                # Set up the actor and critic network
                self.actor._build_net(trainable=True)
                self.critic._build_net(trainable=True)

                # Create target actor and critic network
                target_actor = copy(self.actor)
                target_critic = copy(self.critic)

                # Set up the target network
                target_actor.name = "TargetActor"
                target_actor._build_net(trainable=False)
                target_critic.name = "TargetCritic"
                target_critic._build_net(trainable=False)

            # Initialize replay buffer
            replay_buffer = ReplayBuffer(self.replay_buffer_size,
                                         self.observation_dim, self.action_dim)

            # Set up target init and update function
            with tf.name_scope("setup_target"):
                actor_init_ops, actor_update_ops = get_target_ops(
                    self.actor.global_vars, target_actor.global_vars, self.tau)
                critic_init_ops, critic_update_ops = get_target_ops(
                    self.critic.global_vars, target_critic.global_vars,
                    self.tau)
                target_init_op = actor_init_ops + critic_init_ops
                target_update_op = actor_update_ops + critic_update_ops

            f_init_target = tensor_utils.compile_function(
                inputs=[], outputs=target_init_op)
            f_update_target = tensor_utils.compile_function(
                inputs=[], outputs=target_update_op)

            with tf.name_scope("inputs"):
                y = tf.placeholder(tf.float32, shape=(None, 1), name="input_y")
                obs = tf.placeholder(tf.float32,
                                     shape=(None, self.observation_dim),
                                     name="input_observation")
                actions = tf.placeholder(tf.float32,
                                         shape=(None, self.action_dim),
                                         name="input_action")

            # Set up actor training function
            next_action = self.actor.get_action_sym(obs, name="actor_action")
            next_qval = self.critic.get_qval_sym(obs,
                                                 next_action,
                                                 name="actor_qval")
            with tf.name_scope("action_loss"):
                action_loss = -tf.reduce_mean(next_qval)
                if self.actor_weight_decay > 0.:
                    actor_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.actor_weight_decay),
                        weights_list=self.actor.regularizable_vars)
                    action_loss += actor_reg

            with tf.name_scope("minimize_action_loss"):
                actor_train_op = self.actor_optimizer(
                    self.actor_lr, name="ActorOptimizer").minimize(
                        action_loss, var_list=self.actor.trainable_vars)

            f_train_actor = tensor_utils.compile_function(
                inputs=[obs], outputs=[actor_train_op, action_loss])

            # Set up critic training function
            qval = self.critic.get_qval_sym(obs, actions, name="q_value")
            with tf.name_scope("qval_loss"):
                qval_loss = tf.reduce_mean(tf.squared_difference(y, qval))
                if self.critic_weight_decay > 0.:
                    critic_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.critic_weight_decay),
                        weights_list=self.critic.regularizable_vars)
                    qval_loss += critic_reg

            with tf.name_scope("minimize_critic_loss"):
                critic_train_op = self.critic_optimizer(
                    self.critic_lr, name="CriticOptimizer").minimize(
                        qval_loss, var_list=self.critic.trainable_vars)

            f_train_critic = tensor_utils.compile_function(
                inputs=[y, obs, actions],
                outputs=[critic_train_op, qval_loss, qval])

            self.opt_info = dict(f_train_actor=f_train_actor,
                                 f_train_critic=f_train_critic,
                                 f_init_target=f_init_target,
                                 f_update_target=f_update_target,
                                 replay_buffer=replay_buffer,
                                 target_critic=target_critic,
                                 target_actor=target_actor)

    def _learn(self):
        """
        Perform one optimization step for the critic and the actor.

        Returns:
            qval_loss: Loss of the Q-values predicted by the Q-network.
            ys: Bellman targets used to train the Q-network.
            qval: Q-values predicted by the Q-network.
            action_loss: Loss of the actions predicted by the policy network.

        """
        replay_buffer = self.opt_info["replay_buffer"]
        target_actor = self.opt_info["target_actor"]
        target_critic = self.opt_info["target_critic"]
        f_train_critic = self.opt_info["f_train_critic"]
        f_train_actor = self.opt_info["f_train_actor"]
        f_update_target = self.opt_info["f_update_target"]

        transitions = replay_buffer.random_sample(self.batch_size)
        observations = transitions["observations"]
        rewards = transitions["rewards"]
        actions = transitions["actions"]
        terminals = transitions["terminals"]
        next_observations = transitions["next_observations"]

        rewards = rewards.reshape(-1, 1)
        terminals = terminals.reshape(-1, 1)

        target_actions = target_actor.get_actions(next_observations)
        target_qvals = target_critic.get_qval(next_observations,
                                              target_actions)

        ys = rewards + (1.0 - terminals) * self.discount * target_qvals

        _, qval_loss, qval = f_train_critic(ys, observations, actions)
        _, action_loss = f_train_actor(observations)
        f_update_target()

        return qval_loss, ys, qval, action_loss
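
For reference, the Bellman target computed in _learn() stops bootstrapping on terminal transitions: ys = r + (1 - terminal) * discount * Q'(s', mu'(s')). A small self-contained NumPy check of that line; the values are made up for illustration.

import numpy as np

rewards = np.array([[1.0], [0.5]])        # batch of two rewards
terminals = np.array([[0.0], [1.0]])      # second transition is terminal
target_qvals = np.array([[10.0], [8.0]])  # Q'(s', mu'(s')) from the targets
discount = 0.99

ys = rewards + (1.0 - terminals) * discount * target_qvals
# First row bootstraps: 1.0 + 0.99 * 10.0 = 10.9
# Second row is terminal, so it reduces to the reward: 0.5
print(ys)  # [[10.9] [ 0.5]]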