    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
        self.config = config
        self.sess = tf.get_default_session()

        # Setup the policy
        self.observations = tf.placeholder(
            tf.float32, [None] + list(observation_space.shape))
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
        self.model = ModelCatalog.get_model({
            "obs": self.observations,
            "prev_actions": prev_actions,
            "prev_rewards": prev_rewards,
            "is_training": self._get_is_training_placeholder(),
        }, observation_space, logit_dim, self.config["model"])
        action_dist = dist_class(self.model.outputs)
        self.vf = self.model.value_function()
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        # Setup the policy loss
        if isinstance(action_space, gym.spaces.Box):
            ac_size = action_space.shape[0]
            actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
        elif isinstance(action_space, gym.spaces.Discrete):
            actions = tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for A3C.".format(
                    action_space))
        advantages = tf.placeholder(tf.float32, [None], name="advantages")
        self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
        self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
                            self.vf, self.config["vf_loss_coeff"],
                            self.config["entropy_coeff"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("obs", self.observations),
            ("actions", actions),
            ("prev_actions", prev_actions),
            ("prev_rewards", prev_rewards),
            ("advantages", advantages),
            ("value_targets", self.v_target),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=self.observations,
            action_sampler=action_dist.sample(),
            loss=self.model.loss() + self.loss.total_loss,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"])

        self.stats_fetches = {
            "stats": {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "policy_entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(self.v_target, self.vf),
            },
        }

        self.sess.run(tf.global_variables_initializer())
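
# --- Hedged sketch (not part of the RLlib snippet above) ---
# The A3C graph above consumes externally computed "advantages" and
# "value_targets" ("v_target") placeholders. Below is a minimal NumPy
# illustration of how such values are typically produced with generalized
# advantage estimation (GAE); the function name and the gamma/lam
# hyperparameters are assumptions, not taken from the snippet, and terminal
# masking is omitted for brevity.
import numpy as np

def gae_advantages_and_targets(rewards, values, bootstrap_value,
                               gamma=0.99, lam=1.0):
    """Return (advantages, value_targets) for a single rollout."""
    values_ext = np.append(values, bootstrap_value)           # V(s_0..s_T)
    deltas = rewards + gamma * values_ext[1:] - values_ext[:-1]
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(rewards))):                   # backward scan
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages, advantages + values_ext[:-1]

# Example: a 3-step rollout.
adv, targets = gae_advantages_and_targets([1.0, 0.0, 1.0], [0.5, 0.4, 0.3], 0.0)
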
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        """
        Arguments:
            observation_space: Environment observation space specification.
            action_space: Environment action space specification.
            config (dict): Configuration values for PPO graph.
            existing_inputs (list): Optional list of tuples that specify the
                placeholders upon which the graph should be built.
        """
        config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
        self.sess = tf.get_default_session()
        self.action_space = action_space
        self.config = config
        self.kl_coeff_val = self.config["kl_coeff"]
        self.kl_target = self.config["kl_target"]
        dist_cls, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        if existing_inputs:
            obs_ph, value_targets_ph, adv_ph, act_ph, \
                logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \
                existing_inputs[:8]
            existing_state_in = existing_inputs[8:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            obs_ph = tf.placeholder(
                tf.float32,
                name="obs",
                shape=(None, ) + observation_space.shape)
            adv_ph = tf.placeholder(
                tf.float32, name="advantages", shape=(None, ))
            act_ph = ModelCatalog.get_action_placeholder(action_space)
            logits_ph = tf.placeholder(
                tf.float32, name="logits", shape=(None, logit_dim))
            vf_preds_ph = tf.placeholder(
                tf.float32, name="vf_preds", shape=(None, ))
            value_targets_ph = tf.placeholder(
                tf.float32, name="value_targets", shape=(None, ))
            prev_actions_ph = ModelCatalog.get_action_placeholder(action_space)
            prev_rewards_ph = tf.placeholder(
                tf.float32, [None], name="prev_reward")
            existing_state_in = None
            existing_seq_lens = None
        self.observations = obs_ph
        self.prev_actions = prev_actions_ph
        self.prev_rewards = prev_rewards_ph

        self.loss_in = [
            ("obs", obs_ph),
            ("value_targets", value_targets_ph),
            ("advantages", adv_ph),
            ("actions", act_ph),
            ("logits", logits_ph),
            ("vf_preds", vf_preds_ph),
            ("prev_actions", prev_actions_ph),
            ("prev_rewards", prev_rewards_ph),
        ]
        self.model = ModelCatalog.get_model(
            {
                "obs": obs_ph,
                "prev_actions": prev_actions_ph,
                "prev_rewards": prev_rewards_ph,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            action_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)

        # KL Coefficient
        self.kl_coeff = tf.get_variable(
            initializer=tf.constant_initializer(self.kl_coeff_val),
            name="kl_coeff",
            shape=(),
            trainable=False,
            dtype=tf.float32)

        self.logits = self.model.outputs
        curr_action_dist = dist_cls(self.logits)
        self.sampler = curr_action_dist.sample()
        if self.config["use_gae"]:
            if self.config["vf_share_layers"]:
                self.value_function = self.model.value_function()
            else:
                vf_config = self.config["model"].copy()
                # Do not split the last layer of the value function into
                # mean parameters and standard deviation parameters and
                # do not make the standard deviations free variables.
                vf_config["free_log_std"] = False
                if vf_config["use_lstm"]:
                    vf_config["use_lstm"] = False
                    logger.warning(
                        "It is not recommended to use a LSTM model with "
                        "vf_share_layers=False (consider setting it to True). "
                        "If you want to not share layers, you can implement "
                        "a custom LSTM model that overrides the "
                        "value_function() method.")
                with tf.variable_scope("value_function"):
                    self.value_function = ModelCatalog.get_model({
                        "obs": obs_ph,
                        "prev_actions": prev_actions_ph,
                        "prev_rewards": prev_rewards_ph,
                        "is_training": self._get_is_training_placeholder(),
                    }, observation_space, action_space, 1, vf_config).outputs
                    self.value_function = tf.reshape(self.value_function, [-1])
        else:
            self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens)
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            mask = tf.ones_like(adv_ph, dtype=tf.bool)

        self.loss_obj = PPOLoss(
            action_space,
            value_targets_ph,
            adv_ph,
            act_ph,
            logits_ph,
            vf_preds_ph,
            curr_action_dist,
            self.value_function,
            self.kl_coeff,
            mask,
            entropy_coeff=self.config["entropy_coeff"],
            clip_param=self.config["clip_param"],
            vf_clip_param=self.config["vf_clip_param"],
            vf_loss_coeff=self.config["vf_loss_coeff"],
            use_gae=self.config["use_gae"])

        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=obs_ph,
            action_sampler=self.sampler,
            action_prob=curr_action_dist.sampled_action_prob(),
            loss=self.loss_obj.loss,
            model=self.model,
            loss_inputs=self.loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions_ph,
            prev_reward_input=prev_rewards_ph,
            seq_lens=self.model.seq_lens,
            max_seq_len=config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())
        self.explained_variance = explained_variance(value_targets_ph,
                                                     self.value_function)
        self.stats_fetches = {
            "cur_kl_coeff": self.kl_coeff,
            "cur_lr": tf.cast(self.cur_lr, tf.float64),
            "total_loss": self.loss_obj.loss,
            "policy_loss": self.loss_obj.mean_policy_loss,
            "vf_loss": self.loss_obj.mean_vf_loss,
            "vf_explained_var": self.explained_variance,
            "kl": self.loss_obj.mean_kl,
            "entropy": self.loss_obj.mean_entropy
        }
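
# --- Hedged sketch (not part of the RLlib snippet above) ---
# The PPO graph above keeps a non-trainable kl_coeff variable plus a kl_target
# taken from the config. Between optimization rounds the coefficient is adapted
# from the sampled mean KL. The rule below is the standard adaptive-KL
# heuristic; treat it as an approximation of what the policy graph's own update
# method does, and the function name as an assumption.
def adapt_kl_coeff(kl_coeff_val, sampled_kl, kl_target):
    if sampled_kl > 2.0 * kl_target:
        kl_coeff_val *= 1.5      # policy moved too far: strengthen the penalty
    elif sampled_kl < 0.5 * kl_target:
        kl_coeff_val *= 0.5      # policy barely moved: relax the penalty
    return kl_coeff_val
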
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
        self.config = config
        self.sess = tf.get_default_session()

        # Setup the policy
        self.observations = tf.placeholder(tf.float32, [None] +
                                           list(observation_space.shape))
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.prev_actions = ModelCatalog.get_action_placeholder(action_space)
        self.prev_rewards = tf.placeholder(tf.float32, [None],
                                           name="prev_reward")
        self.model = ModelCatalog.get_model(
            {
                "obs": self.observations,
                "prev_actions": self.prev_actions,
                "prev_rewards": self.prev_rewards,
                "is_training": self._get_is_training_placeholder(),
            }, observation_space, action_space, logit_dim,
            self.config["model"])
        action_dist = dist_class(self.model.outputs)
        self.vf = self.model.value_function()
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        # Setup the policy loss
        if isinstance(action_space, gym.spaces.Box):
            ac_size = action_space.shape[0]
            actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
        elif isinstance(action_space, gym.spaces.Discrete):
            actions = tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for A3C.".format(
                    action_space))
        advantages = tf.placeholder(tf.float32, [None], name="advantages")
        self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
        self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
                            self.vf, self.config["vf_loss_coeff"],
                            self.config["entropy_coeff"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("obs", self.observations),
            ("actions", actions),
            ("prev_actions", self.prev_actions),
            ("prev_rewards", self.prev_rewards),
            ("advantages", advantages),
            ("value_targets", self.v_target),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(self,
                               observation_space,
                               action_space,
                               self.sess,
                               obs_input=self.observations,
                               action_sampler=action_dist.sample(),
                               action_prob=action_dist.sampled_action_prob(),
                               loss=self.loss.total_loss,
                               model=self.model,
                               loss_inputs=loss_in,
                               state_inputs=self.model.state_in,
                               state_outputs=self.model.state_out,
                               prev_action_input=self.prev_actions,
                               prev_reward_input=self.prev_rewards,
                               seq_lens=self.model.seq_lens,
                               max_seq_len=self.config["model"]["max_seq_len"])

        self.stats_fetches = {
            "stats": {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "policy_entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(self.v_target, self.vf),
            },
        }

        self.sess.run(tf.global_variables_initializer())
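
# --- Hedged sketch (not part of the RLlib snippet above) ---
# "vf_explained_var" in the stats above compares value targets against value
# predictions. A minimal NumPy rendition of the usual explained-variance
# statistic; the -1.0 floor mirrors what RLlib's TF helper is believed to
# apply, but this sketch is not its implementation.
import numpy as np

def explained_variance_np(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    var_y = np.var(y_true)
    if var_y == 0:
        return np.nan                    # undefined when targets are constant
    return max(-1.0, 1.0 - np.var(y_true - y_pred) / var_y)
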
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()
        self.grads = None

        if isinstance(action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            output_hidden_shape = [action_space.n]
        elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            output_hidden_shape = action_space.nvec.astype(np.int32)
        elif self.config["vtrace"]:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for APPO + VTrace.",
                format(action_space))
        else:
            is_multidiscrete = False
            output_hidden_shape = 1

        # Policy network model
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        # Create input placeholders
        if existing_inputs:
            if self.config["vtrace"]:
                actions, dones, behaviour_logits, rewards, observations, \
                    prev_actions, prev_rewards = existing_inputs[:7]
                existing_state_in = existing_inputs[7:-1]
                existing_seq_lens = existing_inputs[-1]
            else:
                actions, dones, behaviour_logits, rewards, observations, \
                    prev_actions, prev_rewards, adv_ph, value_targets = \
                    existing_inputs[:9]
                existing_state_in = existing_inputs[9:-1]
                existing_seq_lens = existing_inputs[-1]
        else:
            actions = ModelCatalog.get_action_placeholder(action_space)
            dones = tf.placeholder(tf.bool, [None], name="dones")
            rewards = tf.placeholder(tf.float32, [None], name="rewards")
            behaviour_logits = tf.placeholder(
                tf.float32, [None, logit_dim], name="behaviour_logits")
            observations = tf.placeholder(
                tf.float32, [None] + list(observation_space.shape))
            existing_state_in = None
            existing_seq_lens = None

            if not self.config["vtrace"]:
                adv_ph = tf.placeholder(
                    tf.float32, name="advantages", shape=(None, ))
                value_targets = tf.placeholder(
                    tf.float32, name="value_targets", shape=(None, ))
        self.observations = observations

        # Unpack behaviour logits
        unpacked_behaviour_logits = tf.split(
            behaviour_logits, output_hidden_shape, axis=1)

        # Setup the policy
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
        self.model = ModelCatalog.get_model(
            {
                "obs": observations,
                "prev_actions": prev_actions,
                "prev_rewards": prev_rewards,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            action_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)
        unpacked_outputs = tf.split(
            self.model.outputs, output_hidden_shape, axis=1)

        dist_inputs = unpacked_outputs if is_multidiscrete else \
            self.model.outputs
        prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \
            behaviour_logits

        action_dist = dist_class(dist_inputs)
        prev_action_dist = dist_class(prev_dist_inputs)

        values = self.model.value_function()
        self.value_function = values
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        def make_time_major(tensor, drop_last=False):
            """Swaps batch and trajectory axis.
            Args:
                tensor: A tensor or list of tensors to reshape.
                drop_last: A bool indicating whether to drop the last
                trajectory item.
            Returns:
                res: A tensor with swapped axes or a list of tensors with
                swapped axes.
            """
            if isinstance(tensor, list):
                return [make_time_major(t, drop_last) for t in tensor]

            if self.model.state_init:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = self.config["sample_batch_size"]
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))

            # swap B and T axes
            res = tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

            if drop_last:
                return res[:-1]
            return res

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            mask = tf.ones_like(rewards)

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        if self.config["vtrace"]:
            logger.info("Using V-Trace surrogate loss (vtrace=True)")

            # Prepare actions for loss
            loss_actions = actions if is_multidiscrete else tf.expand_dims(
                actions, axis=1)

            self.loss = VTraceSurrogateLoss(
                actions=make_time_major(loss_actions, drop_last=True),
                prev_actions_logp=make_time_major(
                    prev_action_dist.logp(actions), drop_last=True),
                actions_logp=make_time_major(
                    action_dist.logp(actions), drop_last=True),
                action_kl=prev_action_dist.kl(action_dist),
                actions_entropy=make_time_major(
                    action_dist.entropy(), drop_last=True),
                dones=make_time_major(dones, drop_last=True),
                behaviour_logits=make_time_major(
                    unpacked_behaviour_logits, drop_last=True),
                target_logits=make_time_major(
                    unpacked_outputs, drop_last=True),
                discount=config["gamma"],
                rewards=make_time_major(rewards, drop_last=True),
                values=make_time_major(values, drop_last=True),
                bootstrap_value=make_time_major(values)[-1],
                valid_mask=make_time_major(mask, drop_last=True),
                vf_loss_coeff=self.config["vf_loss_coeff"],
                entropy_coeff=self.config["entropy_coeff"],
                clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
                clip_pg_rho_threshold=self.config[
                    "vtrace_clip_pg_rho_threshold"],
                clip_param=self.config["clip_param"])
        else:
            logger.info("Using PPO surrogate loss (vtrace=False)")
            self.loss = PPOSurrogateLoss(
                prev_actions_logp=make_time_major(
                    prev_action_dist.logp(actions)),
                actions_logp=make_time_major(action_dist.logp(actions)),
                action_kl=prev_action_dist.kl(action_dist),
                actions_entropy=make_time_major(action_dist.entropy()),
                values=make_time_major(values),
                valid_mask=make_time_major(mask),
                advantages=make_time_major(adv_ph),
                value_targets=make_time_major(value_targets),
                vf_loss_coeff=self.config["vf_loss_coeff"],
                entropy_coeff=self.config["entropy_coeff"],
                clip_param=self.config["clip_param"])

        # KL divergence between worker and learner logits for debugging
        model_dist = MultiCategorical(unpacked_outputs)
        behaviour_dist = MultiCategorical(unpacked_behaviour_logits)

        kls = model_dist.kl(behaviour_dist)
        if len(kls) > 1:
            self.KL_stats = {}

            for i, kl in enumerate(kls):
                self.KL_stats.update({
                    "mean_KL_{}".format(i): tf.reduce_mean(kl),
                    "max_KL_{}".format(i): tf.reduce_max(kl),
                    "median_KL_{}".format(i): tf.contrib.distributions.
                    percentile(kl, 50.0),
                })
        else:
            self.KL_stats = {
                "mean_KL": tf.reduce_mean(kls[0]),
                "max_KL": tf.reduce_max(kls[0]),
                "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0),
            }

        # Initialize TFPolicyGraph
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", observations),
            ("prev_actions", prev_actions),
            ("prev_rewards", prev_rewards),
        ]
        if not self.config["vtrace"]:
            loss_in.append(("advantages", adv_ph))
            loss_in.append(("value_targets", value_targets))
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=observations,
            action_sampler=action_dist.sample(),
            action_prob=action_dist.sampled_action_prob(),
            loss=self.loss.total_loss,
            model=self.model,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"],
            batch_divisibility_req=self.config["sample_batch_size"])

        self.sess.run(tf.global_variables_initializer())

        values_batched = make_time_major(
            values, drop_last=self.config["vtrace"])
        self.stats_fetches = {
            "stats": dict({
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self.loss.value_targets, [-1]),
                    tf.reshape(values_batched, [-1])),
            }, **self.KL_stats),
        }
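
# --- Hedged sketch (not part of the RLlib snippet above) ---
# make_time_major() above turns a flat [B * T, ...] rollout tensor into a
# time-major [T, B, ...] tensor (optionally dropping the last timestep) so the
# V-trace calculation can scan over time. The same reshape in NumPy, with a
# fixed T as in the non-recurrent branch; the helper name is an assumption.
import numpy as np

def make_time_major_np(flat, T, drop_last=False):
    B = flat.shape[0] // T
    time_major = flat.reshape((B, T) + flat.shape[1:]).swapaxes(0, 1)  # [T, B, ...]
    return time_major[:-1] if drop_last else time_major

flat = np.arange(6)                    # two rollouts of length three, flattened
print(make_time_major_np(flat, T=3))   # [[0 3] [1 4] [2 5]], shape (3, 2)
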
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()

        # Create input placeholders
        if existing_inputs:
            actions, dones, behaviour_logits, rewards, observations, \
                prev_actions, prev_rewards = existing_inputs[:7]
            existing_state_in = existing_inputs[7:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            if isinstance(action_space, gym.spaces.Discrete):
                ac_size = action_space.n
                actions = tf.placeholder(tf.int64, [None], name="ac")
            else:
                raise UnsupportedSpaceException(
                    "Action space {} is not supported for IMPALA.".format(
                        action_space))
            dones = tf.placeholder(tf.bool, [None], name="dones")
            rewards = tf.placeholder(tf.float32, [None], name="rewards")
            behaviour_logits = tf.placeholder(tf.float32, [None, ac_size],
                                              name="behaviour_logits")
            observations = tf.placeholder(tf.float32, [None] +
                                          list(observation_space.shape))
            existing_state_in = None
            existing_seq_lens = None

        # Setup the policy
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
        self.model = ModelCatalog.get_model(
            {
                "obs": observations,
                "prev_actions": prev_actions,
                "prev_rewards": prev_rewards,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)
        action_dist = dist_class(self.model.outputs)
        values = self.model.value_function()
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        def to_batches(tensor):
            if self.config["model"]["use_lstm"]:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = self.config["sample_batch_size"]
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
            # swap B and T axes
            return tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            mask = tf.ones_like(rewards)

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        self.loss = VTraceLoss(
            actions=to_batches(actions)[:-1],
            actions_logp=to_batches(action_dist.logp(actions))[:-1],
            actions_entropy=to_batches(action_dist.entropy())[:-1],
            dones=to_batches(dones)[:-1],
            behaviour_logits=to_batches(behaviour_logits)[:-1],
            target_logits=to_batches(self.model.outputs)[:-1],
            discount=config["gamma"],
            rewards=to_batches(rewards)[:-1],
            values=to_batches(values)[:-1],
            bootstrap_value=to_batches(values)[-1],
            valid_mask=to_batches(mask)[:-1],
            vf_loss_coeff=self.config["vf_loss_coeff"],
            entropy_coeff=self.config["entropy_coeff"],
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

        # KL divergence between worker and learner logits for debugging
        model_dist = Categorical(self.model.outputs)
        behaviour_dist = Categorical(behaviour_logits)
        self.KLs = model_dist.kl(behaviour_dist)
        self.mean_KL = tf.reduce_mean(self.KLs)
        self.max_KL = tf.reduce_max(self.KLs)
        self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0)

        # Initialize TFPolicyGraph
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", observations),
            ("prev_actions", prev_actions),
            ("prev_rewards", prev_rewards),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=observations,
            action_sampler=action_dist.sample(),
            loss=self.model.loss() + self.loss.total_loss,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"],
            batch_divisibility_req=self.config["sample_batch_size"])

        self.sess.run(tf.global_variables_initializer())

        self.stats_fetches = {
            "stats": {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self.loss.vtrace_returns.vs, [-1]),
                    tf.reshape(to_batches(values)[:-1], [-1])),
                "mean_KL": self.mean_KL,
                "max_KL": self.max_KL,
                "median_KL": self.median_KL,
            },
        }
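
# --- Hedged sketch (not part of the RLlib snippet above) ---
# The debugging block above reports the KL divergence between the behaviour
# policy (logits stored in the sample batch) and the current learner policy.
# For Categorical distributions that statistic reduces to the NumPy computation
# below (a sketch, not RLlib's Categorical.kl); in the snippet the learner
# distribution plays the role of p.
import numpy as np

def softmax(logits):
    logits = np.asarray(logits, dtype=np.float64)
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def categorical_kl(logits_p, logits_q):
    """KL(p || q) per row for [batch, num_actions] logits."""
    p, q = softmax(logits_p), softmax(logits_q)
    return np.sum(p * (np.log(p) - np.log(q)), axis=-1)
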
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        """
        Arguments:
            observation_space: Environment observation space specification.
            action_space: Environment action space specification.
            config (dict): Configuration values for PPO graph.
            existing_inputs (list): Optional list of tuples that specify the
                placeholders upon which the graph should be built.
        """
        config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
        self.sess = tf.get_default_session()
        self.action_space = action_space
        self.config = config
        self.kl_coeff_val = self.config["kl_coeff"]
        self.kl_target = self.config["kl_target"]
        dist_cls, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        if existing_inputs:
            obs_ph, value_targets_ph, adv_ph, act_ph, \
                logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \
                existing_inputs[:8]
            existing_state_in = existing_inputs[8:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            obs_ph = tf.placeholder(tf.float32,
                                    name="obs",
                                    shape=(None, ) + observation_space.shape)
            adv_ph = tf.placeholder(tf.float32,
                                    name="advantages",
                                    shape=(None, ))
            act_ph = ModelCatalog.get_action_placeholder(action_space)
            logits_ph = tf.placeholder(tf.float32,
                                       name="logits",
                                       shape=(None, logit_dim))
            vf_preds_ph = tf.placeholder(tf.float32,
                                         name="vf_preds",
                                         shape=(None, ))
            value_targets_ph = tf.placeholder(tf.float32,
                                              name="value_targets",
                                              shape=(None, ))
            prev_actions_ph = ModelCatalog.get_action_placeholder(action_space)
            prev_rewards_ph = tf.placeholder(tf.float32, [None],
                                             name="prev_reward")
            existing_state_in = None
            existing_seq_lens = None
        self.observations = obs_ph

        self.loss_in = [
            ("obs", obs_ph),
            ("value_targets", value_targets_ph),
            ("advantages", adv_ph),
            ("actions", act_ph),
            ("logits", logits_ph),
            ("vf_preds", vf_preds_ph),
            ("prev_actions", prev_actions_ph),
            ("prev_rewards", prev_rewards_ph),
        ]
        self.model = ModelCatalog.get_model(
            {
                "obs": obs_ph,
                "prev_actions": prev_actions_ph,
                "prev_rewards": prev_rewards_ph
            },
            observation_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)

        # KL Coefficient
        self.kl_coeff = tf.get_variable(initializer=tf.constant_initializer(
            self.kl_coeff_val),
                                        name="kl_coeff",
                                        shape=(),
                                        trainable=False,
                                        dtype=tf.float32)

        self.logits = self.model.outputs
        curr_action_dist = dist_cls(self.logits)
        self.sampler = curr_action_dist.sample()
        if self.config["use_gae"]:
            if self.config["vf_share_layers"]:
                self.value_function = self.model.value_function()
            else:
                vf_config = self.config["model"].copy()
                # Do not split the last layer of the value function into
                # mean parameters and standard deviation parameters and
                # do not make the standard deviations free variables.
                vf_config["free_log_std"] = False
                vf_config["use_lstm"] = False
                with tf.variable_scope("value_function"):
                    self.value_function = ModelCatalog.get_model(
                        {
                            "obs": obs_ph,
                            "prev_actions": prev_actions_ph,
                            "prev_rewards": prev_rewards_ph
                        }, observation_space, 1, vf_config).outputs
                    self.value_function = tf.reshape(self.value_function, [-1])
        else:
            self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens)
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            mask = tf.ones_like(adv_ph)

        self.loss_obj = PPOLoss(action_space,
                                value_targets_ph,
                                adv_ph,
                                act_ph,
                                logits_ph,
                                vf_preds_ph,
                                curr_action_dist,
                                self.value_function,
                                self.kl_coeff,
                                mask,
                                entropy_coeff=self.config["entropy_coeff"],
                                clip_param=self.config["clip_param"],
                                vf_clip_param=self.config["vf_clip_param"],
                                vf_loss_coeff=self.config["vf_loss_coeff"],
                                use_gae=self.config["use_gae"])

        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(self,
                               observation_space,
                               action_space,
                               self.sess,
                               obs_input=obs_ph,
                               action_sampler=self.sampler,
                               loss=self.model.loss() + self.loss_obj.loss,
                               loss_inputs=self.loss_in,
                               state_inputs=self.model.state_in,
                               state_outputs=self.model.state_out,
                               prev_action_input=prev_actions_ph,
                               prev_reward_input=prev_rewards_ph,
                               seq_lens=self.model.seq_lens,
                               max_seq_len=config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())
        self.explained_variance = explained_variance(value_targets_ph,
                                                     self.value_function)
        self.stats_fetches = {
            "cur_lr": tf.cast(self.cur_lr, tf.float64),
            "total_loss": self.loss_obj.loss,
            "policy_loss": self.loss_obj.mean_policy_loss,
            "vf_loss": self.loss_obj.mean_vf_loss,
            "vf_explained_var": self.explained_variance,
            "kl": self.loss_obj.mean_kl,
            "entropy": self.loss_obj.mean_entropy
        }
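
# --- Hedged sketch (not part of the RLlib snippet above) ---
# With an LSTM model, rollouts are padded to a common length and
# tf.sequence_mask(seq_lens, max_seq_len) marks which of the flattened
# [B * max_seq_len] timesteps are real, so padded steps can be excluded from
# the loss. A NumPy rendition of that masking step; the helper name is an
# assumption.
import numpy as np

def padding_mask(seq_lens):
    seq_lens = np.asarray(seq_lens)
    max_len = int(seq_lens.max())
    # mask[b, t] is True for t < seq_lens[b], like tf.sequence_mask
    mask = np.arange(max_len)[None, :] < seq_lens[:, None]
    return mask.reshape(-1)            # flatten to line up with [B * T] tensors

print(padding_mask([3, 1]))            # [ True  True  True  True False False]
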
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()
        self.grads = None

        if isinstance(action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            actions_shape = [None]
            output_hidden_shape = [action_space.n]
        elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            actions_shape = [None, len(action_space.nvec)]
            output_hidden_shape = action_space.nvec.astype(np.int32)
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for IMPALA.".format(
                    action_space))

        # Create input placeholders
        if existing_inputs:
            actions, dones, behaviour_logits, rewards, observations, \
                prev_actions, prev_rewards = existing_inputs[:7]
            existing_state_in = existing_inputs[7:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            actions = tf.placeholder(tf.int64, actions_shape, name="ac")
            dones = tf.placeholder(tf.bool, [None], name="dones")
            rewards = tf.placeholder(tf.float32, [None], name="rewards")
            behaviour_logits = tf.placeholder(
                tf.float32, [None, sum(output_hidden_shape)],
                name="behaviour_logits")
            observations = tf.placeholder(
                tf.float32, [None] + list(observation_space.shape))
            existing_state_in = None
            existing_seq_lens = None

        # Unpack behaviour logits
        unpacked_behaviour_logits = tf.split(
            behaviour_logits, output_hidden_shape, axis=1)

        # Setup the policy
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
        self.model = ModelCatalog.get_model(
            {
                "obs": observations,
                "prev_actions": prev_actions,
                "prev_rewards": prev_rewards,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)
        unpacked_outputs = tf.split(
            self.model.outputs, output_hidden_shape, axis=1)

        dist_inputs = unpacked_outputs if is_multidiscrete else \
            self.model.outputs
        action_dist = dist_class(dist_inputs)

        values = self.model.value_function()
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        def make_time_major(tensor, drop_last=False):
            """Swaps batch and trajectory axis.
            Args:
                tensor: A tensor or list of tensors to reshape.
                drop_last: A bool indicating whether to drop the last
                trajectory item.
            Returns:
                res: A tensor with swapped axes or a list of tensors with
                swapped axes.
            """
            if isinstance(tensor, list):
                return [make_time_major(t, drop_last) for t in tensor]

            if self.model.state_init:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = self.config["sample_batch_size"]
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))

            # swap B and T axes
            res = tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

            if drop_last:
                return res[:-1]
            return res

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            mask = tf.ones_like(rewards, dtype=tf.bool)

        # Prepare actions for loss
        loss_actions = actions if is_multidiscrete else tf.expand_dims(
            actions, axis=1)

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        self.loss = VTraceLoss(
            actions=make_time_major(loss_actions, drop_last=True),
            actions_logp=make_time_major(
                action_dist.logp(actions), drop_last=True),
            actions_entropy=make_time_major(
                action_dist.entropy(), drop_last=True),
            dones=make_time_major(dones, drop_last=True),
            behaviour_logits=make_time_major(
                unpacked_behaviour_logits, drop_last=True),
            target_logits=make_time_major(unpacked_outputs, drop_last=True),
            discount=config["gamma"],
            rewards=make_time_major(rewards, drop_last=True),
            values=make_time_major(values, drop_last=True),
            bootstrap_value=make_time_major(values)[-1],
            valid_mask=make_time_major(mask, drop_last=True),
            vf_loss_coeff=self.config["vf_loss_coeff"],
            entropy_coeff=self.config["entropy_coeff"],
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

        # KL divergence between worker and learner logits for debugging
        model_dist = MultiCategorical(unpacked_outputs)
        behaviour_dist = MultiCategorical(unpacked_behaviour_logits)

        kls = model_dist.kl(behaviour_dist)
        if len(kls) > 1:
            self.KL_stats = {}

            for i, kl in enumerate(kls):
                self.KL_stats.update({
                    "mean_KL_{}".format(i): tf.reduce_mean(kl),
                    "max_KL_{}".format(i): tf.reduce_max(kl),
                    "median_KL_{}".format(i): tf.contrib.distributions.
                    percentile(kl, 50.0),
                })
        else:
            self.KL_stats = {
                "mean_KL": tf.reduce_mean(kls[0]),
                "max_KL": tf.reduce_max(kls[0]),
                "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0),
            }

        # Initialize TFPolicyGraph
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", observations),
            ("prev_actions", prev_actions),
            ("prev_rewards", prev_rewards),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=observations,
            action_sampler=action_dist.sample(),
            action_prob=action_dist.sampled_action_prob(),
            loss=self.loss.total_loss,
            model=self.model,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"],
            batch_divisibility_req=self.config["sample_batch_size"])

        self.sess.run(tf.global_variables_initializer())

        self.stats_fetches = {
            "stats": dict({
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self.loss.vtrace_returns.vs, [-1]),
                    tf.reshape(make_time_major(values, drop_last=True), [-1])),
            }, **self.KL_stats),
        }
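
# --- Hedged sketch (not part of the RLlib snippet above) ---
# For MultiDiscrete action spaces the model emits one concatenated logit
# vector, and tf.split(..., output_hidden_shape, axis=1) above slices it back
# into one logit block per sub-action. The equivalent NumPy slicing; the
# helper name is an assumption.
import numpy as np

def unpack_logits(logits, output_hidden_shape):
    """Split [batch, sum(sizes)] logits into a list of [batch, size_i] arrays."""
    split_points = np.cumsum(output_hidden_shape)[:-1]
    return np.split(logits, split_points, axis=1)

logits = np.zeros((4, 5))                                  # e.g. MultiDiscrete([2, 3])
print([p.shape for p in unpack_logits(logits, [2, 3])])    # [(4, 2), (4, 3)]
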
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()

        # Setup the policy
        self.observations = tf.placeholder(
            tf.float32, [None] + list(observation_space.shape))
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.model = ModelCatalog.get_model(self.observations, logit_dim,
                                            self.config["model"])
        action_dist = dist_class(self.model.outputs)
        values = tf.reshape(
            linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
            [-1])
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        # Setup the policy loss
        if isinstance(action_space, gym.spaces.Discrete):
            ac_size = action_space.n
            actions = tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for IMPALA.".format(
                    action_space))
        dones = tf.placeholder(tf.bool, [None], name="dones")
        rewards = tf.placeholder(tf.float32, [None], name="rewards")
        behaviour_logits = tf.placeholder(
            tf.float32, [None, ac_size], name="behaviour_logits")

        def to_batches(tensor):
            if self.config["model"]["use_lstm"]:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = (self.config["sample_batch_size"] //
                     self.config["num_envs_per_worker"])
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
            # swap B and T axes
            return tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        self.loss = VTraceLoss(
            actions=to_batches(actions)[:-1],
            actions_logp=to_batches(action_dist.logp(actions))[:-1],
            actions_entropy=to_batches(action_dist.entropy())[:-1],
            dones=to_batches(dones)[:-1],
            behaviour_logits=to_batches(behaviour_logits)[:-1],
            target_logits=to_batches(self.model.outputs)[:-1],
            discount=config["gamma"],
            rewards=to_batches(rewards)[:-1],
            values=to_batches(values)[:-1],
            bootstrap_value=to_batches(values)[-1],
            vf_loss_coeff=self.config["vf_loss_coeff"],
            entropy_coeff=self.config["entropy_coeff"],
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", self.observations),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=self.observations,
            action_sampler=action_dist.sample(),
            loss=self.loss.total_loss,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())

        self.stats_fetches = {
            "stats": {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self.loss.vtrace_returns.vs, [-1]),
                    tf.reshape(to_batches(values)[:-1], [-1])),
            },
        }
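
# --- Hedged sketch (not part of the RLlib snippet above) ---
# Inside VTraceLoss, the value targets follow the V-trace recursion from
# Espeholt et al. (2018), which is what "vtrace_returns.vs" in the stats above
# refers to. A single-rollout NumPy sketch with constant discount and no
# truncation masking, illustrating the role of the clipped importance ratios
# rho and c; the function name and defaults are assumptions.
import numpy as np

def vtrace_targets(rewards, values, bootstrap_value, log_rhos,
                   gamma=0.99, clip_rho=1.0, clip_c=1.0):
    rhos = np.minimum(clip_rho, np.exp(log_rhos))   # clipped IS ratios rho_t
    cs = np.minimum(clip_c, np.exp(log_rhos))       # clipped trace cutters c_t
    values_ext = np.append(values, bootstrap_value)
    deltas = rhos * (rewards + gamma * values_ext[1:] - values_ext[:-1])
    vs_minus_v = np.zeros(len(rewards) + 1)
    for t in reversed(range(len(rewards))):         # backward recursion
        vs_minus_v[t] = deltas[t] + gamma * cs[t] * vs_minus_v[t + 1]
    return vs_minus_v[:-1] + values                 # the v_s targets
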
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        """
        Arguments:
            observation_space: Environment observation space specification.
            action_space: Environment action space specification.
            config (dict): Configuration values for PPORND graph.
            existing_inputs (list): Optional list of tuples that specify the
                placeholders upon which the graph should be built.
        """
        config = dict(DEFAULT_CONFIG, **config)
        self.sess = tf.get_default_session()
        self.action_space = action_space
        self.config = config
        self.kl_coeff_val = self.config["kl_coeff"]
        self.kl_target = self.config["kl_target"]
        dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space)

        if existing_inputs:
            obs_ph, value_targets_ph, adv_ph, act_ph, \
                logits_ph, vf_preds_ph = existing_inputs[:6]
            # TODO: add adv_ph_int
            existing_state_in = existing_inputs[6:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            obs_ph = tf.placeholder(tf.float32,
                                    name="obs",
                                    shape=(None, ) + observation_space.shape)
            adv_ph = tf.placeholder(tf.float32,
                                    name="advantages",
                                    shape=(None, ))
            adv_int_ph = tf.placeholder(tf.float32,
                                        name="advantages_int",
                                        shape=(None, ))
            act_ph = ModelCatalog.get_action_placeholder(action_space)
            logits_ph = tf.placeholder(tf.float32,
                                       name="logits",
                                       shape=(None, logit_dim))
            vf_preds_ph = tf.placeholder(tf.float32,
                                         name="vf_preds",
                                         shape=(None, ))
            value_targets_ph = tf.placeholder(tf.float32,
                                              name="value_targets",
                                              shape=(None, ))
            existing_state_in = None
            existing_seq_lens = None
        self.observations = obs_ph

        self.loss_in = [
            ("obs", obs_ph),
            ("value_targets", value_targets_ph),
            ("advantages", adv_ph),
            ("actions", act_ph),
            ("logits", logits_ph),
            ("vf_preds", vf_preds_ph),
        ]
        self.model = ModelCatalog.get_model(obs_ph,
                                            logit_dim,
                                            self.config["model"],
                                            state_in=existing_state_in,
                                            seq_lens=existing_seq_lens)

        # KL Coefficient
        self.kl_coeff = tf.get_variable(initializer=tf.constant_initializer(
            self.kl_coeff_val),
                                        name="kl_coeff",
                                        shape=(),
                                        trainable=False,
                                        dtype=tf.float32)

        self.logits = self.model.outputs
        curr_action_dist = dist_cls(self.logits)
        self.sampler = curr_action_dist.sample()
        if self.config["use_gae"]:
            if self.config["vf_share_layers"]:
                self.value_function = tf.reshape(
                    linear(self.model.last_layer, 1, "value",
                           normc_initializer(1.0)), [-1])
            else:
                vf_config = self.config["model"].copy()
                # Do not split the last layer of the value function into
                # mean parameters and standard deviation parameters and
                # do not make the standard deviations free variables.
                vf_config["free_log_std"] = False
                vf_config["use_lstm"] = False
                with tf.variable_scope("value_function"):
                    self.value_function = ModelCatalog.get_model(
                        obs_ph, 1, vf_config).outputs
                    self.value_function = tf.reshape(self.value_function, [-1])
        else:
            self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])

        # TODO: add another head in the policy network for estimating value of intrinsic reward

        # RND target network
        with tf.variable_scope("rnd_target"):
            modelconfig = self.config["model"].copy()
            modelconfig["free_log_std"] = False
            modelconfig["use_lstm"] = False
            self.rnd_target = ModelCatalog.get_model(
                obs_ph, self.config["embedding_size"], modelconfig).outputs
            # self.rnd_target = tf.reshape(self.rnd_target, [-1])  # TODO: necessary?

        # RND predictor network
        with tf.variable_scope("rnd_predictor"):
            modelconfig = self.config["model"].copy()
            modelconfig["free_log_std"] = False
            modelconfig["use_lstm"] = False
            self.rnd_predictor = ModelCatalog.get_model(
                obs_ph, self.config["embedding_size"], modelconfig).outputs

        self.loss_obj = PPORNDLoss(
            action_space,
            value_targets_ph,
            adv_ph,
            adv_int_ph,
            act_ph,
            logits_ph,
            vf_preds_ph,
            curr_action_dist,
            self.value_function,
            self.kl_coeff,
            self.rnd_target,
            self.rnd_predictor,
            # TODO: valid_mask??
            entropy_coeff=self.config["entropy_coeff"],
            clip_param=self.config["clip_param"],
            vf_clip_param=self.config["vf_clip_param"],
            vf_loss_coeff=self.config["vf_loss_coeff"],
            use_gae=self.config["use_gae"])

        # NOTE: the following assignments appear to be leftover defaults from
        # the PPORNDLoss signature; the stray trailing commas would make most
        # of them 1-tuples and none of them are used below, so they are kept
        # only as a comment:
        # entropy_coeff=0, clip_param=0.1, vf_clip_param=0.1,
        # vf_loss_coeff=1.0, use_gae=True, rnd_pred_update_prop=0.25

        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(self,
                               observation_space,
                               action_space,
                               self.sess,
                               obs_input=obs_ph,
                               action_sampler=self.sampler,
                               loss=self.loss_obj.loss,
                               loss_inputs=self.loss_in,
                               state_inputs=self.model.state_in,
                               state_outputs=self.model.state_out,
                               seq_lens=self.model.seq_lens,
                               max_seq_len=config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())
        self.explained_variance = explained_variance(value_targets_ph,
                                                     self.value_function)
        self.stats_fetches = {
            "cur_lr": tf.cast(self.cur_lr, tf.float64),
            "total_loss": self.loss_obj.loss,
            "policy_loss": self.loss_obj.mean_policy_loss,
            "vf_loss": self.loss_obj.mean_vf_loss,
            "vf_explained_var": self.explained_variance,
            "kl": self.loss_obj.mean_kl,
            "entropy": self.loss_obj.mean_entropy
        }
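
# --- Hedged sketch (not part of the snippet above) ---
# In RND, the intrinsic reward is the prediction error between a fixed,
# randomly initialized target embedding and a trained predictor embedding of
# the same observation; rnd_target and rnd_predictor above are those two
# heads. A NumPy sketch of the reward itself (reward normalization and the
# predictor's training step are omitted); the function name is an assumption.
import numpy as np

def rnd_intrinsic_reward(target_embedding, predictor_embedding):
    """Per-observation squared error between the two embeddings."""
    diff = np.asarray(target_embedding) - np.asarray(predictor_embedding)
    return np.sum(diff ** 2, axis=-1)
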