Example #1
    def _init(self, inputs, num_outputs, options):
        x = inputs
        with tf.name_scope("convnet"):
            for i in range(4):
                x = tf.nn.elu(
                    conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
            r, c = x.shape[1].value, x.shape[2].value
            x = tf.reshape(x, [-1, r * c * 32])
            fc1 = linear(x, 256, "fc1")
            fc2 = linear(x, num_outputs, "fc2", normc_initializer(0.01))
            return fc2, fc1
Example #2
        def value_function(self):
            assert self.cur_instance, "must call forward first"

            with self._branch_variable_scope("value_function"):
                # Simple case: sharing the feature layer
                if self.model_config["vf_share_layers"]:
                    return tf.reshape(
                        linear(self.cur_instance.last_layer, 1,
                               "value_function", normc_initializer(1.0)), [-1])

                # Create a new separate model with no RNN state, etc.
                branch_model_config = self.model_config.copy()
                branch_model_config["free_log_std"] = False
                if branch_model_config["use_lstm"]:
                    branch_model_config["use_lstm"] = False
                    logger.warning(
                        "It is not recommended to use a LSTM model with "
                        "vf_share_layers=False (consider setting it to True). "
                        "If you want to not share layers, you can implement "
                        "a custom LSTM model that overrides the "
                        "value_function() method.")
                branch_instance = self.legacy_model_cls(
                    self.cur_instance.input_dict,
                    self.obs_space,
                    self.action_space,
                    1,
                    branch_model_config,
                    state_in=None,
                    seq_lens=None)
                return tf.reshape(branch_instance.outputs, [-1])
Example #3
            def vf_template(last_layer, input_dict):
                with tf.variable_scope(self.variable_scope):
                    with tf.variable_scope("value_function"):
                        # Simple case: sharing the feature layer
                        if model_config["vf_share_layers"]:
                            return tf.reshape(
                                linear(last_layer, 1, "value_function",
                                       normc_initializer(1.0)), [-1])

                        # Create a new separate model with no RNN state, etc.
                        branch_model_config = model_config.copy()
                        branch_model_config["free_log_std"] = False
                        if branch_model_config["use_lstm"]:
                            branch_model_config["use_lstm"] = False
                            logger.warning(
                                "It is not recommended to use a LSTM model "
                                "with vf_share_layers=False (consider "
                                "setting it to True). If you want to not "
                                "share layers, you can implement a custom "
                                "LSTM model that overrides the "
                                "value_function() method.")
                        branch_instance = legacy_model_cls(
                            input_dict,
                            obs_space,
                            action_space,
                            1,
                            branch_model_config,
                            state_in=None,
                            seq_lens=None)
                        return tf.reshape(branch_instance.outputs, [-1])
Example #4
    def _setup_graph(self, ob_space, ac_space):
        self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
            ac_space, self.config["model"])
        self._model = LSTM(self.x, self.logit_dim, {})

        self.state_in = self._model.state_in
        self.state_out = self._model.state_out

        self.logits = self._model.outputs
        self.action_dist = dist_class(self.logits)
        # with tf.variable_scope("vf"):
        #     vf_model = ModelCatalog.get_model(self.x, 1)
        self.vf = tf.reshape(
            linear(self._model.last_layer, 1, "value", normc_initializer(1.0)),
            [-1])

        self.sample = self.action_dist.sample()
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
        self.global_step = tf.get_variable("global_step", [],
                                           tf.int32,
                                           initializer=tf.constant_initializer(
                                               0, dtype=tf.int32),
                                           trainable=False)
Example #5
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
        self.config = config
        self.sess = tf.get_default_session()

        # Setup the policy
        self.observations = tf.placeholder(
            tf.float32, [None] + list(observation_space.shape))
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.model = ModelCatalog.get_model(self.observations, logit_dim,
                                            self.config["model"])
        action_dist = dist_class(self.model.outputs)
        self.vf = tf.reshape(
            linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
            [-1])
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        # Setup the policy loss
        if isinstance(action_space, gym.spaces.Box):
            ac_size = action_space.shape[0]
            actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
        elif isinstance(action_space, gym.spaces.Discrete):
            actions = tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for A3C.".format(
                    action_space))
        advantages = tf.placeholder(tf.float32, [None], name="advantages")
        v_target = tf.placeholder(tf.float32, [None], name="v_target")
        self.loss = A3CLoss(action_dist, actions, advantages, v_target,
                            self.vf, self.config["vf_loss_coeff"],
                            self.config["entropy_coeff"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("obs", self.observations),
            ("actions", actions),
            ("advantages", advantages),
            ("value_targets", v_target),
        ]
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=self.observations,
            action_sampler=action_dist.sample(),
            loss=self.loss.total_loss,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())
Example #6
File: model.py Project: jamescasbon/ray
    def value_function(self):
        """Builds the value function output.

        This method can be overridden to customize the implementation of the
        value function (e.g., not sharing hidden layers).

        Returns:
            Tensor of size [BATCH_SIZE] for the value function.
        """
        return tf.reshape(
            linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
Example #7
File: model.py Project: zofuthan/ray
    def value_function(self):
        """Builds the value function output.

        This method can be overridden to customize the implementation of the
        value function (e.g., not sharing hidden layers).

        Returns:
            Tensor of size [BATCH_SIZE] for the value function.
        """
        return tf.reshape(
            linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
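The docstring in Examples #6 and #7 notes that value_function() can be overridden, for example to avoid sharing hidden layers with the policy. Below is a rough sketch of such an override against the _build_layers_v2-era Model API; the class name, the 64-unit hidden size, and the "vf_hidden"/"vf_out" layer names are hypothetical, and the import paths may differ between the older Ray versions shown on this page.

import tensorflow as tf

from ray.rllib.models.model import Model  # path as in older Ray versions
from ray.rllib.models.misc import linear, normc_initializer


class SeparateVFModel(Model):  # hypothetical model, for illustration only
    def _build_layers_v2(self, input_dict, num_outputs, options):
        self._obs = input_dict["obs"]  # keep a handle for the value branch
        hidden = tf.nn.relu(
            linear(self._obs, 64, "fc1", normc_initializer(1.0)))
        logits = linear(hidden, num_outputs, "fc_out",
                        normc_initializer(0.01))
        return logits, hidden

    def value_function(self):
        # Build a separate value branch instead of reusing self.last_layer.
        with tf.variable_scope("value_function"):
            vf_hidden = tf.nn.relu(
                linear(self._obs, 64, "vf_hidden", normc_initializer(1.0)))
            return tf.reshape(
                linear(vf_hidden, 1, "vf_out", normc_initializer(1.0)), [-1])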
Example #8
File: lstm.py Project: robertnishihara/ray
    def _build_layers_v2(self, input_dict, num_outputs, options):
        cell_size = options.get("lstm_cell_size")
        if options.get("lstm_use_prev_action_reward"):
            action_dim = int(
                np.product(
                    input_dict["prev_actions"].get_shape().as_list()[1:]))
            features = tf.concat(
                [
                    input_dict["obs"],
                    tf.reshape(
                        tf.cast(input_dict["prev_actions"], tf.float32),
                        [-1, action_dim]),
                    tf.reshape(input_dict["prev_rewards"], [-1, 1]),
                ],
                axis=1)
        else:
            features = input_dict["obs"]
        last_layer = add_time_dimension(features, self.seq_lens)

        # Setup the LSTM cell
        lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.c], name="c")
            h_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.h], name="h")
            self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = rnn.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(
            lstm,
            last_layer,
            initial_state=state_in,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        self.state_out = list(lstm_state)

        # Compute outputs
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
        return logits, last_layer
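For context, the LSTM wrappers above are normally enabled through the model section of the agent config rather than instantiated directly. A hedged sketch of the relevant options follows; the key names match the ones read via options.get() in the snippets, while the values are illustrative defaults only.

# Illustrative values; keys mirror the options read in the LSTM snippets.
model_config = {
    "use_lstm": True,                     # wrap the base model with an LSTM
    "lstm_cell_size": 256,                # cell size queried via options.get()
    "lstm_use_prev_action_reward": True,  # concat prev action/reward to obs
    "max_seq_len": 20,                    # used by TFPolicyGraph for padding
}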
Example #9
File: lstm.py Project: zdpau/ray-1
    def _build_layers_v2(self, input_dict, num_outputs, options):
        cell_size = options.get("lstm_cell_size")
        if options.get("lstm_use_prev_action_reward"):
            action_dim = int(
                np.product(
                    input_dict["prev_actions"].get_shape().as_list()[1:]))
            features = tf.concat(
                [
                    input_dict["obs"],
                    tf.reshape(
                        tf.cast(input_dict["prev_actions"], tf.float32),
                        [-1, action_dim]),
                    tf.reshape(input_dict["prev_rewards"], [-1, 1]),
                ],
                axis=1)
        else:
            features = input_dict["obs"]
        last_layer = add_time_dimension(features, self.seq_lens)

        # Setup the LSTM cell
        lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.c], name="c")
            h_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.h], name="h")
            self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = rnn.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(
            lstm,
            last_layer,
            initial_state=state_in,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        self.state_out = list(lstm_state)

        # Compute outputs
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
        return logits, last_layer
Example #10
    def _setup_graph(self, ob_space, ac_space):
        self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
        self._model = ModelCatalog.get_model(
            self.registry, self.x, self.logit_dim, self.config["model"])
        self.logits = self._model.outputs
        self.curr_dist = dist_class(self.logits)
        self.vf = tf.reshape(linear(self._model.last_layer, 1, "value",
                                    normc_initializer(1.0)), [-1])

        self.sample = self.curr_dist.sample()
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
        self.global_step = tf.get_variable(
            "global_step", [], tf.int32,
            initializer=tf.constant_initializer(0, dtype=tf.int32),
            trainable=False)
Example #11
File: lstm.py Project: zhiyun/ray
    def _build_layers(self, inputs, num_outputs, options):
        cell_size = options.get("lstm_cell_size", 256)
        use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
                         distutils.version.LooseVersion("1.0.0"))
        last_layer = add_time_dimension(inputs, self.seq_lens)

        # Setup the LSTM cell
        if use_tf100_api:
            lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        else:
            lstm = rnn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(tf.float32, [None, lstm.state_size.c],
                                  name="c")
            h_in = tf.placeholder(tf.float32, [None, lstm.state_size.h],
                                  name="h")
            self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        if use_tf100_api:
            state_in = rnn.LSTMStateTuple(c_in, h_in)
        else:
            state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm,
                                                 last_layer,
                                                 initial_state=state_in,
                                                 sequence_length=self.seq_lens,
                                                 time_major=False)
        self.state_out = list(lstm_state)

        # Compute outputs
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
        return logits, last_layer
Example #12
    def _init(self, inputs, num_outputs, options):
        use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
                         distutils.version.LooseVersion("1.0.0"))

        self.x = x = inputs
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # Introduce a "fake" batch dimension of 1 after flatten so that we can
        # do LSTM over the time dim.
        x = tf.expand_dims(flatten(x), [0])

        size = 256
        if use_tf100_api:
            lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
        else:
            lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
        step_size = tf.shape(self.x)[:1]

        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]
        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]

        if use_tf100_api:
            state_in = rnn.LSTMStateTuple(c_in, h_in)
        else:
            state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm,
                                                 x,
                                                 initial_state=state_in,
                                                 sequence_length=step_size,
                                                 time_major=False)
        lstm_c, lstm_h = lstm_state
        x = tf.reshape(lstm_out, [-1, size])
        logits = linear(x, num_outputs, "action", normc_initializer(0.01))
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        return logits, x
Example #13
File: lstm.py Project: adgirish/ray
    def _init(self, inputs, num_outputs, options):
        use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
                         distutils.version.LooseVersion("1.0.0"))

        self.x = x = inputs
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # Introduce a "fake" batch dimension of 1 after flatten so that we can
        # do LSTM over the time dim.
        x = tf.expand_dims(flatten(x), [0])

        size = 256
        if use_tf100_api:
            lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
        else:
            lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
        step_size = tf.shape(self.x)[:1]

        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]
        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]

        if use_tf100_api:
            state_in = rnn.LSTMStateTuple(c_in, h_in)
        else:
            state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm, x,
                                                 initial_state=state_in,
                                                 sequence_length=step_size,
                                                 time_major=False)
        lstm_c, lstm_h = lstm_state
        x = tf.reshape(lstm_out, [-1, size])
        logits = linear(x, num_outputs, "action", normc_initializer(0.01))
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        return logits, x
Example #14
File: lstm.py Project: yuanfeng0905/ray
    def _build_layers(self, inputs, num_outputs, options):
        cell_size = options.get("lstm_cell_size", 256)
        last_layer = add_time_dimension(inputs, self.seq_lens)

        # Setup the LSTM cell
        lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.c], name="c")
            h_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.h], name="h")
            self.state_in = [c_in, h_in]

        # Setup LSTM outputs (wire the c/h placeholders in as the initial state)
        state_in = rnn.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(
            lstm,
            last_layer,
            initial_state=state_in,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        self.state_out = list(lstm_state)

        # Compute outputs
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
        return logits, last_layer
Example #15
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
        self.config = config
        self.sess = tf.get_default_session()

        # Setup the policy
        self.observations = tf.placeholder(
            tf.float32, [None] + list(observation_space.shape))
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
        self.model = ModelCatalog.get_model({
            "obs": self.observations,
            "prev_actions": prev_actions,
            "prev_rewards": prev_rewards
        }, observation_space, logit_dim, self.config["model"])
        action_dist = dist_class(self.model.outputs)
        self.vf = tf.reshape(
            linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
            [-1])
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        # Setup the policy loss
        if isinstance(action_space, gym.spaces.Box):
            ac_size = action_space.shape[0]
            actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
        elif isinstance(action_space, gym.spaces.Discrete):
            actions = tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for A3C.".format(
                    action_space))
        advantages = tf.placeholder(tf.float32, [None], name="advantages")
        self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
        self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
                            self.vf, self.config["vf_loss_coeff"],
                            self.config["entropy_coeff"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("obs", self.observations),
            ("actions", actions),
            ("prev_actions", prev_actions),
            ("prev_rewards", prev_rewards),
            ("advantages", advantages),
            ("value_targets", self.v_target),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=self.observations,
            action_sampler=action_dist.sample(),
            loss=self.loss.total_loss,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"])

        self.stats_fetches = {
            "stats": {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "policy_entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(self.v_target, self.vf),
            },
        }

        self.sess.run(tf.global_variables_initializer())
Example #16
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        """
        Arguments:
            observation_space: Environment observation space specification.
            action_space: Environment action space specification.
            config (dict): Configuration values for PPO graph.
            existing_inputs (list): Optional list of tuples that specify the
                placeholders upon which the graph should be built.
        """
        config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
        self.sess = tf.get_default_session()
        self.action_space = action_space
        self.config = config
        self.kl_coeff_val = self.config["kl_coeff"]
        self.kl_target = self.config["kl_target"]
        dist_cls, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        if existing_inputs:
            obs_ph, value_targets_ph, adv_ph, act_ph, \
                logits_ph, vf_preds_ph = existing_inputs[:6]
            existing_state_in = existing_inputs[6:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            obs_ph = tf.placeholder(tf.float32,
                                    name="obs",
                                    shape=(None, ) + observation_space.shape)
            adv_ph = tf.placeholder(tf.float32,
                                    name="advantages",
                                    shape=(None, ))
            act_ph = ModelCatalog.get_action_placeholder(action_space)
            logits_ph = tf.placeholder(tf.float32,
                                       name="logits",
                                       shape=(None, logit_dim))
            vf_preds_ph = tf.placeholder(tf.float32,
                                         name="vf_preds",
                                         shape=(None, ))
            value_targets_ph = tf.placeholder(tf.float32,
                                              name="value_targets",
                                              shape=(None, ))
            existing_state_in = None
            existing_seq_lens = None
        self.observations = obs_ph

        self.loss_in = [
            ("obs", obs_ph),
            ("value_targets", value_targets_ph),
            ("advantages", adv_ph),
            ("actions", act_ph),
            ("logits", logits_ph),
            ("vf_preds", vf_preds_ph),
        ]
        self.model = ModelCatalog.get_model(obs_ph,
                                            logit_dim,
                                            self.config["model"],
                                            state_in=existing_state_in,
                                            seq_lens=existing_seq_lens)

        # KL Coefficient
        self.kl_coeff = tf.get_variable(initializer=tf.constant_initializer(
            self.kl_coeff_val),
                                        name="kl_coeff",
                                        shape=(),
                                        trainable=False,
                                        dtype=tf.float32)

        self.logits = self.model.outputs
        curr_action_dist = dist_cls(self.logits)
        self.sampler = curr_action_dist.sample()
        if self.config["use_gae"]:
            if self.config["vf_share_layers"]:
                self.value_function = tf.reshape(
                    linear(self.model.last_layer, 1, "value",
                           normc_initializer(1.0)), [-1])
            else:
                vf_config = self.config["model"].copy()
                # Do not split the last layer of the value function into
                # mean parameters and standard deviation parameters and
                # do not make the standard deviations free variables.
                vf_config["free_log_std"] = False
                vf_config["use_lstm"] = False
                with tf.variable_scope("value_function"):
                    self.value_function = ModelCatalog.get_model(
                        obs_ph, 1, vf_config).outputs
                    self.value_function = tf.reshape(self.value_function, [-1])
        else:
            self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens)
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            mask = tf.ones_like(adv_ph)

        self.loss_obj = PPOLoss(action_space,
                                value_targets_ph,
                                adv_ph,
                                act_ph,
                                logits_ph,
                                vf_preds_ph,
                                curr_action_dist,
                                self.value_function,
                                self.kl_coeff,
                                mask,
                                entropy_coeff=self.config["entropy_coeff"],
                                clip_param=self.config["clip_param"],
                                vf_clip_param=self.config["vf_clip_param"],
                                vf_loss_coeff=self.config["vf_loss_coeff"],
                                use_gae=self.config["use_gae"])

        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(self,
                               observation_space,
                               action_space,
                               self.sess,
                               obs_input=obs_ph,
                               action_sampler=self.sampler,
                               loss=self.loss_obj.loss,
                               loss_inputs=self.loss_in,
                               state_inputs=self.model.state_in,
                               state_outputs=self.model.state_out,
                               seq_lens=self.model.seq_lens,
                               max_seq_len=config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())
        self.explained_variance = explained_variance(value_targets_ph,
                                                     self.value_function)
        self.stats_fetches = {
            "cur_lr": tf.cast(self.cur_lr, tf.float64),
            "total_loss": self.loss_obj.loss,
            "policy_loss": self.loss_obj.mean_policy_loss,
            "vf_loss": self.loss_obj.mean_vf_loss,
            "vf_explained_var": self.explained_variance,
            "kl": self.loss_obj.mean_kl,
            "entropy": self.loss_obj.mean_entropy
        }
Example #17
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()

        # Create input placeholders
        if existing_inputs:
            actions, dones, behaviour_logits, rewards, observations = \
                existing_inputs[:5]
            existing_state_in = existing_inputs[5:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            if isinstance(action_space, gym.spaces.Discrete):
                ac_size = action_space.n
                actions = tf.placeholder(tf.int64, [None], name="ac")
            else:
                raise UnsupportedSpaceException(
                    "Action space {} is not supported for IMPALA.".format(
                        action_space))
            dones = tf.placeholder(tf.bool, [None], name="dones")
            rewards = tf.placeholder(tf.float32, [None], name="rewards")
            behaviour_logits = tf.placeholder(
                tf.float32, [None, ac_size], name="behaviour_logits")
            observations = tf.placeholder(
                tf.float32, [None] + list(observation_space.shape))
            existing_state_in = None
            existing_seq_lens = None

        # Setup the policy
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.model = ModelCatalog.get_model(
            observations,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)
        action_dist = dist_class(self.model.outputs)
        values = tf.reshape(
            linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
            [-1])
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        def to_batches(tensor):
            if self.config["model"]["use_lstm"]:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = self.config["sample_batch_size"]
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
            # swap B and T axes
            return tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            mask = tf.ones_like(rewards)

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        self.loss = VTraceLoss(
            actions=to_batches(actions)[:-1],
            actions_logp=to_batches(action_dist.logp(actions))[:-1],
            actions_entropy=to_batches(action_dist.entropy())[:-1],
            dones=to_batches(dones)[:-1],
            behaviour_logits=to_batches(behaviour_logits)[:-1],
            target_logits=to_batches(self.model.outputs)[:-1],
            discount=config["gamma"],
            rewards=to_batches(rewards)[:-1],
            values=to_batches(values)[:-1],
            bootstrap_value=to_batches(values)[-1],
            valid_mask=to_batches(mask)[:-1],
            vf_loss_coeff=self.config["vf_loss_coeff"],
            entropy_coeff=self.config["entropy_coeff"],
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", observations),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=observations,
            action_sampler=action_dist.sample(),
            loss=self.loss.total_loss,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())

        self.stats_fetches = {
            "stats": {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self.loss.vtrace_returns.vs, [-1]),
                    tf.reshape(to_batches(values)[:-1], [-1])),
            },
        }
Example #18
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None,
                 unsupType='action',
                 designHead='universe'):
        """
        Arguments:
            observation_space: Environment observation space specification.
            action_space: Environment action space specification.
            config (dict): Configuration values for PPO graph.
            existing_inputs (list): Optional list of tuples that specify the
                placeholders upon which the graph should be built.
        """
        self.unsup = unsupType is not None
        # self.cur_batch = None
        # self.cur_sample_batch = {}

        predictor = None
        numaction = action_space.n

        config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
        self.sess = tf.get_default_session()
        self.action_space = action_space
        self.config = config
        self.kl_coeff_val = self.config["kl_coeff"]
        self.kl_target = self.config["kl_target"]

        dist_cls, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        if existing_inputs:
            obs_ph, value_targets_ph, adv_ph, act_ph, \
                logits_ph, vf_preds_ph, phi1, phi2, asample = existing_inputs[:9]

            # TODO: updates to account for s1, s2 and asample
            existing_state_in = existing_inputs[9:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            obs_ph = tf.placeholder(tf.float32,
                                    name="obs",
                                    shape=(None, ) + observation_space.shape)
            adv_ph = tf.placeholder(tf.float32,
                                    name="advantages",
                                    shape=(None, ))
            act_ph = ModelCatalog.get_action_placeholder(action_space)
            logits_ph = tf.placeholder(tf.float32,
                                       name="logits",
                                       shape=(None, logit_dim))
            vf_preds_ph = tf.placeholder(tf.float32,
                                         name="vf_preds",
                                         shape=(None, ))
            value_targets_ph = tf.placeholder(tf.float32,
                                              name="value_targets",
                                              shape=(None, ))
            phi1 = tf.placeholder(tf.float32,
                                  shape=(None, ) + observation_space.shape,
                                  name="phi1")
            phi2 = tf.placeholder(tf.float32,
                                  shape=(None, ) + observation_space.shape,
                                  name="phi2")
            asample = tf.placeholder(tf.float32,
                                     shape=(None, numaction),
                                     name="asample")

            existing_state_in = None
            existing_seq_lens = None
        self.observations = obs_ph

        if self.unsup:
            with tf.variable_scope("predictor"):
                if 'state' in unsupType:
                    self.local_ap_network = predictor = StatePredictor(
                        phi1, phi2, asample, observation_space, numaction,
                        designHead, unsupType)
                else:
                    self.local_ap_network = predictor = StateActionPredictor(
                        phi1, phi2, asample, observation_space, numaction,
                        designHead)

        self.model = pi = ModelCatalog.get_model(obs_ph,
                                                 logit_dim,
                                                 self.config["model"],
                                                 state_in=existing_state_in,
                                                 seq_lens=existing_seq_lens)

        # KL Coefficient
        self.kl_coeff = tf.get_variable(initializer=tf.constant_initializer(
            self.kl_coeff_val),
                                        name="kl_coeff",
                                        shape=(),
                                        trainable=False,
                                        dtype=tf.float32)

        self.logits = self.model.outputs
        curr_action_dist = dist_cls(self.logits)
        self.sampler = curr_action_dist.sample()

        if self.config["use_gae"]:
            if self.config["vf_share_layers"]:
                self.value_function = tf.reshape(
                    linear(self.model.last_layer, 1, "value",
                           normc_initializer(1.0)), [-1])
            else:
                vf_config = self.config["model"].copy()
                # Do not split the last layer of the value function into
                # mean parameters and standard deviation parameters and
                # do not make the standard deviations free variables.
                vf_config["free_log_std"] = False
                vf_config["use_lstm"] = False
                with tf.variable_scope("value_function"):
                    self.value_function = ModelCatalog.get_model(
                        obs_ph, 1, vf_config).outputs
                    self.value_function = tf.reshape(self.value_function, [-1])
        else:
            self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])

        self.loss_obj = PPOLoss(action_space,
                                value_targets_ph,
                                adv_ph,
                                act_ph,
                                logits_ph,
                                vf_preds_ph,
                                curr_action_dist,
                                self.value_function,
                                self.kl_coeff,
                                unsupType,
                                predictor,
                                entropy_coeff=self.config["entropy_coeff"],
                                clip_param=self.config["clip_param"],
                                vf_clip_param=self.config["vf_clip_param"],
                                vf_loss_coeff=self.config["vf_loss_coeff"],
                                use_gae=self.config["use_gae"])

        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])

        self.loss_in = [
            ("obs", obs_ph),
            ("value_targets", value_targets_ph),
            ("advantages", adv_ph),
            ("actions", act_ph),
            ("logits", logits_ph),
            ("vf_preds", vf_preds_ph),
            ("s1", phi1),
            ("s2", phi2),
            ("asample", asample),
        ]

        self.extra_inputs = ["s1", "s2", "asample"]

        # TODO: testing to see if this lets me pass inputs to ICM
        # self.variables = ray.experimental.TensorFlowVariables(self.loss_in, self.sess)

        TFPolicyGraph.__init__(self,
                               observation_space,
                               action_space,
                               self.sess,
                               obs_input=obs_ph,
                               action_sampler=self.sampler,
                               loss=self.loss_obj.loss,
                               loss_inputs=self.loss_in,
                               state_inputs=self.model.state_in,
                               state_outputs=self.model.state_out,
                               seq_lens=self.model.seq_lens,
                               max_seq_len=config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())
        self.explained_variance = explained_variance(value_targets_ph,
                                                     self.value_function)
        self.stats_fetches = {
            "cur_lr": tf.cast(self.cur_lr, tf.float64),
            "total_loss": self.loss_obj.loss,
            "policy_loss": self.loss_obj.mean_policy_loss,
            "vf_loss": self.loss_obj.mean_vf_loss,
            "vf_explained_var": self.explained_variance,
            "kl": self.loss_obj.mean_kl,
            "entropy": self.loss_obj.mean_entropy
        }
Example #19
    def _build_layers_v2(self, input_dict, num_outputs, options):
        def spy(sequences, state_in, state_out, seq_lens):
            if len(sequences) == 1:
                return 0  # don't capture inference inputs
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite
            ray.experimental.internal_kv._internal_kv_put(
                "rnn_spy_in_{}".format(RNNSpyModel.capture_index),
                pickle.dumps({
                    "sequences": sequences,
                    "state_in": state_in,
                    "state_out": state_out,
                    "seq_lens": seq_lens
                }),
                overwrite=True)
            RNNSpyModel.capture_index += 1
            return 0

        features = input_dict["obs"]
        cell_size = 3
        last_layer = add_time_dimension(features, self.seq_lens)

        # Setup the LSTM cell
        lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.c], name="c")
            h_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.h], name="h")
        self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = rnn.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(
            lstm,
            last_layer,
            initial_state=state_in,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        self.state_out = list(lstm_state)
        spy_fn = tf.py_func(
            spy, [
                last_layer,
                self.state_in,
                self.state_out,
                self.seq_lens,
            ],
            tf.int64,
            stateful=True)

        # Compute outputs
        with tf.control_dependencies([spy_fn]):
            last_layer = tf.reshape(lstm_out, [-1, cell_size])
            logits = linear(last_layer, num_outputs, "action",
                            normc_initializer(0.01))
        return logits, last_layer
Example #20
    def _build_layers(self, inputs, num_outputs, options):
        cell_size = options.get("lstm_cell_size", 256)
        use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
                         distutils.version.LooseVersion("1.0.0"))
        last_layer = add_time_dimension(inputs, self.seq_lens)

        # Setup the LSTM cell
        if use_tf100_api:
            lstm1 = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
            lstm2 = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        else:
            lstm1 = rnn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
            lstm2 = rnn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm1.state_size.c, np.float32),
            np.zeros(lstm1.state_size.h, np.float32),
            np.zeros(lstm2.state_size.c, np.float32),
            np.zeros(lstm2.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c1_in, h1_in, c2_in, h2_in = self.state_in
        else:
            c1_in = tf.placeholder(tf.float32, [None, lstm1.state_size.c],
                                   name="c1")
            h1_in = tf.placeholder(tf.float32, [None, lstm1.state_size.h],
                                   name="h1")
            c2_in = tf.placeholder(tf.float32, [None, lstm2.state_size.c],
                                   name="c2")
            h2_in = tf.placeholder(tf.float32, [None, lstm2.state_size.h],
                                   name="h2")
            self.state_in = [c1_in, h1_in, c2_in, h2_in]

        # Setup LSTM outputs
        if use_tf100_api:
            state1_in = rnn.LSTMStateTuple(c1_in, h1_in)
            state2_in = rnn.LSTMStateTuple(c2_in, h2_in)
        else:
            state1_in = rnn.rnn_cell.LSTMStateTuple(c1_in, h1_in)
            state2_in = rnn.rnn_cell.LSTMStateTuple(c2_in, h2_in)
        lstm1_out, lstm1_state = tf.nn.dynamic_rnn(
            lstm1,
            last_layer,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        with tf.variable_scope("value_function"):
            lstm2_out, lstm2_state = tf.nn.dynamic_rnn(
                lstm2,
                last_layer,
                sequence_length=self.seq_lens,
                time_major=False,
                dtype=tf.float32)

        self.value_function = tf.reshape(
            linear(tf.reshape(lstm2_out, [-1, cell_size]), 1, "vf",
                   normc_initializer(0.01)), [-1])

        self.state_out = list(lstm1_state) + list(lstm2_state)

        # Compute outputs
        last_layer = tf.reshape(lstm1_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))

        return logits, last_layer
Example #21
    def _build_layers_v2(self, input_dict, num_outputs, options):
        def spy(sequences, state_in, state_out, seq_lens):
            if len(sequences) == 1:
                return 0  # don't capture inference inputs
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite
            ray.experimental.internal_kv._internal_kv_put(
                "rnn_spy_in_{}".format(RNNSpyModel.capture_index),
                pickle.dumps({
                    "sequences": sequences,
                    "state_in": state_in,
                    "state_out": state_out,
                    "seq_lens": seq_lens
                }),
                overwrite=True)
            RNNSpyModel.capture_index += 1
            return 0

        features = input_dict["obs"]
        cell_size = 3
        last_layer = add_time_dimension(features, self.seq_lens)

        # Setup the LSTM cell
        lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(tf.float32, [None, lstm.state_size.c],
                                  name="c")
            h_in = tf.placeholder(tf.float32, [None, lstm.state_size.h],
                                  name="h")
        self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = rnn.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm,
                                                 last_layer,
                                                 initial_state=state_in,
                                                 sequence_length=self.seq_lens,
                                                 time_major=False,
                                                 dtype=tf.float32)

        self.state_out = list(lstm_state)
        spy_fn = tf.py_func(spy, [
            last_layer,
            self.state_in,
            self.state_out,
            self.seq_lens,
        ],
                            tf.int64,
                            stateful=True)

        # Compute outputs
        with tf.control_dependencies([spy_fn]):
            last_layer = tf.reshape(lstm_out, [-1, cell_size])
            logits = linear(last_layer, num_outputs, "action",
                            normc_initializer(0.01))
        return logits, last_layer
Example #22
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
        self.config = config
        self.sess = tf.get_default_session()

        # Setup the policy
        self.observations = tf.placeholder(
            tf.float32, [None] + list(observation_space.shape))
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.model = ModelCatalog.get_model(
            self.observations, logit_dim, self.config["model"])
        action_dist = dist_class(self.model.outputs)
        self.vf = tf.reshape(
            linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
            [-1])
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
        is_training = tf.placeholder_with_default(True, ())

        # Setup the policy loss
        if isinstance(action_space, gym.spaces.Box):
            ac_size = action_space.shape[0]
            actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
        elif isinstance(action_space, gym.spaces.Discrete):
            actions = tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for A3C.".format(
                    action_space))
        advantages = tf.placeholder(tf.float32, [None], name="advantages")
        v_target = tf.placeholder(tf.float32, [None], name="v_target")
        self.loss = A3CLoss(
            action_dist, actions, advantages, v_target, self.vf,
            self.config["vf_loss_coeff"], self.config["entropy_coeff"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("obs", self.observations),
            ("actions", actions),
            ("advantages", advantages),
            ("value_targets", v_target),
        ]
        for i, ph in enumerate(self.model.state_in):
            loss_in.append(("state_in_{}".format(i), ph))
        self.state_in = self.model.state_in
        self.state_out = self.model.state_out
        TFPolicyGraph.__init__(
            self, observation_space, action_space, self.sess,
            obs_input=self.observations, action_sampler=action_dist.sample(),
            loss=self.loss.total_loss, loss_inputs=loss_in,
            is_training=is_training, state_inputs=self.state_in,
            state_outputs=self.state_out,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"])

        if self.config.get("summarize"):
            bs = tf.to_float(tf.shape(self.observations)[0])
            tf.summary.scalar("model/policy_graph", self.loss.pi_loss / bs)
            tf.summary.scalar("model/value_loss", self.loss.vf_loss / bs)
            tf.summary.scalar("model/entropy", self.loss.entropy / bs)
            tf.summary.scalar("model/grad_gnorm", tf.global_norm(self._grads))
            tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list))
            self.summary_op = tf.summary.merge_all()

        self.sess.run(tf.global_variables_initializer())
Example #23
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()

        # Setup the policy
        self.observations = tf.placeholder(tf.float32, [None] +
                                           list(observation_space.shape))
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.model = ModelCatalog.get_model(self.observations, logit_dim,
                                            self.config["model"])
        action_dist = dist_class(self.model.outputs)
        values = tf.reshape(
            linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
            [-1])
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        # Setup the policy loss
        if isinstance(action_space, gym.spaces.Box):
            ac_size = action_space.shape[0]
            actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
        elif isinstance(action_space, gym.spaces.Discrete):
            ac_size = action_space.n
            actions = tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for IMPALA.".format(
                    action_space))
        dones = tf.placeholder(tf.bool, [None], name="dones")
        rewards = tf.placeholder(tf.float32, [None], name="rewards")
        behaviour_logits = tf.placeholder(tf.float32, [None, ac_size],
                                          name="behaviour_logits")

        def to_batches(tensor):
            if self.config["model"]["use_lstm"]:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = (self.config["sample_batch_size"] //
                     self.config["num_envs_per_worker"])
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
            # swap B and T axes
            return tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

        if self.config["clip_rewards"]:
            clipped_rewards = tf.clip_by_value(rewards, -1, 1)
        else:
            clipped_rewards = rewards

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        self.loss = VTraceLoss(
            actions=to_batches(actions)[:-1],
            actions_logp=to_batches(action_dist.logp(actions))[:-1],
            actions_entropy=to_batches(action_dist.entropy())[:-1],
            dones=to_batches(dones)[:-1],
            behaviour_logits=to_batches(behaviour_logits)[:-1],
            target_logits=to_batches(self.model.outputs)[:-1],
            discount=config["gamma"],
            rewards=to_batches(clipped_rewards)[:-1],
            values=to_batches(values)[:-1],
            bootstrap_value=to_batches(values)[-1],
            vf_loss_coeff=self.config["vf_loss_coeff"],
            entropy_coeff=self.config["entropy_coeff"],
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", self.observations),
        ]
        TFPolicyGraph.__init__(self,
                               observation_space,
                               action_space,
                               self.sess,
                               obs_input=self.observations,
                               action_sampler=action_dist.sample(),
                               loss=self.loss.total_loss,
                               loss_inputs=loss_in,
                               state_inputs=self.model.state_in,
                               state_outputs=self.model.state_out,
                               seq_lens=self.model.seq_lens,
                               max_seq_len=self.config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())