Example #1
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        output, new_state = self.forward_rnn(
            add_time_dimension(input_dict["obs_flat"], seq_lens), state,
            seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state
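The forward() wrapper above only handles the batch/time reshaping; the recurrent logic goes into forward_rnn(). As a rough sketch of what such a subclass could look like (the class name, constructor signature, cell size, and the choice of a Keras LSTM core are illustrative assumptions, not taken from the examples on this page):

# Minimal illustrative sketch only; in practice you would subclass the RLlib
# recurrent model base class whose forward() is shown above.
import numpy as np
import tensorflow as tf


class MyRNNModel:  # hypothetical name
    def __init__(self, obs_dim, num_outputs, cell_size=64):
        self.num_outputs = num_outputs
        self.cell_size = cell_size
        # return_sequences=True keeps the [batch, time, units] layout that
        # forward() later flattens back to [batch * time, num_outputs].
        self.lstm = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True)
        self.logits_layer = tf.keras.layers.Dense(num_outputs)

    def get_initial_state(self):
        # One hidden vector and one cell vector per sequence.
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    def forward_rnn(self, inputs, state, seq_lens):
        # inputs already has the time dimension added by forward():
        # shape [batch, max_seq_len, obs_dim]. state is [h, c], each of
        # shape [batch, cell_size].
        outputs, h, c = self.lstm(inputs, initial_state=state)
        logits = self.logits_layer(outputs)
        return logits, [h, c]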
Example #2
  def _build(self, input_dict, num_outputs, options):
    if options.get("use_lstm"):
      cell_size = options.get("lstm_cell_size")
      self.state_init = (
          np.zeros([cell_size], np.float32),
      )
      self.state_in = (
          tf.placeholder(tf.float32, [None, cell_size], name="state_in"),
      )
    else:
      self.state_init = ()
      self.state_in = ()

    # Embed observations and previous actions; the batch layout differs
    # between imitation-learning data and RLlib rollout data.
    if self._imitation:
      obs_embed = ssbm_spaces.slippi_conv_list[0].embed(input_dict["obs"])
      prev_actions = input_dict["prev_actions"]
    else:
      obs_embed = input_dict["obs"]
      prev_actions = tf.unstack(input_dict["prev_actions"], axis=-1)
    
    action_config = nest.flatten(ssbm_actions.repeated_simple_controller_config)
    prev_actions_embed = tf.concat([
        conv.embed(action) for conv, action
        in zip(action_config, prev_actions)], -1)
    
    prev_rewards_embed = tf.expand_dims(input_dict["prev_rewards"], -1)
    inputs = tf.concat([obs_embed, prev_actions_embed, prev_rewards_embed], -1)

    trunk = snt.nets.MLP(
        output_sizes=options["fcnet_hiddens"],
        activation=getattr(tf.nn, options["fcnet_activation"]),
        activate_final=True)

    if not self._imitation:
      # RLlib rollout batches arrive flattened to [B*T, ...]; add the time
      # axis before the recurrent core (imitation batches are already
      # time-major, see time_major below).
      inputs = lstm.add_time_dimension(inputs, self.seq_lens)

    trunk_outputs = snt.BatchApply(trunk)(inputs)
    
    if options.get("use_lstm"):
      gru = snt.GRU(cell_size)
      core_outputs, state_out = tf.nn.dynamic_rnn(
          gru,
          trunk_outputs,
          initial_state=self.state_in[0],
          sequence_length=None if self._imitation else self.seq_lens,
          time_major=self._imitation)
      self.state_out = [state_out]
    else:
      core_outputs = trunk_outputs
      self.state_out = []

    # Per-timestep policy logits and value estimates (self._value_head is
    # assumed to be defined elsewhere in the class).
    self._logit_head = snt.Linear(num_outputs, name="logit_head")
    logits = snt.BatchApply(self._logit_head)(core_outputs)
    self.values = tf.squeeze(snt.BatchApply(self._value_head)(core_outputs), -1)
    
    return logits, core_outputs
Example #3
    def _build_layers_v2(self, input_dict, num_outputs, options):
        options = options["custom_options"]
        rnn_units = options["rnn_units"]
        rnn_output_units = options["rnn_output_units"]
        mlp_hidden_units = options["mlp_hidden_units"]
        vf_share_layers = options["vf_share_layers"]
        use_linear_baseline = options["linear_baseline"]
        rnn_output_activation = options.get("rnn_output_activation", "tanh")
        mlp_activation = options.get("mlp_activation", "tanh")

        with tf.name_scope("rnn_inputs"):
            rnn_inputs = tf.concat([
                input_dict["obs_buffer"], input_dict["logits_buffer"],
                tf.expand_dims(input_dict["reward_buffer"], axis=1)
            ], axis=1)
            with tf.name_scope("add_time_dimension"):
                rnn_inputs = add_time_dimension(rnn_inputs, self.seq_lens)
        mlp_inputs = input_dict["obs"]
        model_inputs = {
            "rnn": rnn_inputs,
            "seq_lens": self.seq_lens,
            "mlp": mlp_inputs
        }
        model = KerasTESPWithAdapPolicy(
            rnn_units,
            rnn_output_units,
            rnn_output_activation,
            mlp_hidden_units + [num_outputs],
            mlp_activation,
            vf_share_layers=vf_share_layers,
            custom_params=self.custom_params["policy"],
            use_linear_baseline=use_linear_baseline)
        self.policy_model = model
        if use_linear_baseline:
            output = model(model_inputs)
        else:
            output, self._value_function = model(model_inputs)

        # last_layer = model.mlp.layers[-1]

        # self.rnn_state_in = self.keras_model.initial_state
        # if not nest.is_sequence(self.rnn_state_in):
        #     self.rnn_state_in = [self.rnn_state_in]
        self.rnn_state_out = nest.flatten(self.policy_model.state)
        self.rnn_state_out_init = [
            np.zeros(state.shape.as_list(), dtype=state.dtype.as_numpy_dtype)
            for state in self.rnn_state_out
        ]

        return output, None
Example #4
    def call(self, inputs, seqlens=None, initial_state=None):
        features = inputs['obs']

        if self._use_conv:
            features = self.conv_layer(features)

        # Reshape flat [B*T, ...] features to [B, T, ...] before the RNN.
        features = add_time_dimension(features, seqlens)
        self.features = features
        latent, *rnn_state = self.rnn(features, initial_state=initial_state)
        self.latent = latent
        # Flatten back to [B*T, units] for the output head.
        latent = tf.reshape(latent, [-1, latent.shape[-1]])

        state_out = list(rnn_state)

        # latent = self.dense_layer(latent)
        logits = self.output_layer(latent)

        output = {'latent': latent, 'logits': logits, 'state_out': state_out}

        self.output_tensors = output

        return output
Example #5
    def _build_layers_v2(self, input_dict, num_outputs, options):
        def spy(sequences, state_in, state_out, seq_lens):
            if len(sequences) == 1:
                return 0  # don't capture inference inputs
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite
            ray.experimental.internal_kv._internal_kv_put(
                "rnn_spy_in_{}".format(RNNSpyModel.capture_index),
                pickle.dumps({
                    "sequences": sequences,
                    "state_in": state_in,
                    "state_out": state_out,
                    "seq_lens": seq_lens
                }),
                overwrite=True)
            RNNSpyModel.capture_index += 1
            return 0

        features = input_dict["obs"]
        cell_size = 3
        last_layer = add_time_dimension(features, self.seq_lens)

        # Setup the LSTM cell
        lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(tf.float32, [None, lstm.state_size.c],
                                  name="c")
            h_in = tf.placeholder(tf.float32, [None, lstm.state_size.h],
                                  name="h")
        self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = rnn.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm,
                                                 last_layer,
                                                 initial_state=state_in,
                                                 sequence_length=self.seq_lens,
                                                 time_major=False,
                                                 dtype=tf.float32)

        self.state_out = list(lstm_state)
        spy_fn = tf.py_func(
            spy, [
                last_layer,
                self.state_in,
                self.state_out,
                self.seq_lens,
            ],
            tf.int64,
            stateful=True)

        # Compute outputs
        with tf.control_dependencies([spy_fn]):
            last_layer = tf.reshape(lstm_out, [-1, cell_size])
            logits = linear(last_layer, num_outputs, "action",
                            normc_initializer(0.01))
        return logits, last_layer