Example #1
    def call(
        self, input_dict: SampleBatch
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        assert input_dict.get(SampleBatch.SEQ_LENS) is not None
        # Push obs through the underlying (wrapped) Keras model first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.use_n_prev_actions:
            if isinstance(self.action_space, Discrete):
                for i in range(self.use_n_prev_actions):
                    prev_a_r.append(
                        one_hot(
                            input_dict[SampleBatch.PREV_ACTIONS][:, i],
                            self.action_space,
                        ))
            elif isinstance(self.action_space, MultiDiscrete):
                for i in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                    prev_a_r.append(
                        one_hot(
                            tf.cast(
                                input_dict[SampleBatch.PREV_ACTIONS]
                                [:, i:i + self.action_space.shape[0]],
                                tf.float32,
                            ),
                            self.action_space,
                        ))
            else:
                prev_a_r.append(
                    tf.reshape(
                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
                                tf.float32),
                        [-1, self.use_n_prev_actions * self.action_dim],
                    ))
        if self.use_n_prev_rewards:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, self.use_n_prev_rewards],
                ))

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        memory_ins = [
            s for k, s in input_dict.items() if k.startswith("state_in_")
        ]
        model_out, memory_outs, value_outs = self.base_model([wrapped_out] +
                                                             memory_ins)
        return (
            model_out,
            memory_outs,
            {
                SampleBatch.VF_PREDS: tf.reshape(value_outs, [-1])
            },
        )
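The Discrete branch above one-hots each of the `use_n_prev_actions` previous actions separately and concatenates the results into one flat feature block. The shape arithmetic can be checked in isolation with plain TensorFlow; batch size, stack depth, and category count below are made-up illustration values:

import tensorflow as tf

# Hypothetical sizes: a stack of 3 previous actions from a Discrete(5) space.
batch, n_prev, n_categories = 4, 3, 5
prev_actions = tf.random.uniform(
    [batch, n_prev], maxval=n_categories, dtype=tf.int32)

# One-hot each previous action separately (as the Discrete loop above does),
# then concatenate along the feature axis.
cols = [tf.one_hot(prev_actions[:, i], depth=n_categories)
        for i in range(n_prev)]
stacked = tf.concat(cols, axis=1)
print(stacked.shape)  # (4, 15) == (batch, n_prev * n_categories)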
Example #2
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.model_config["lstm_use_prev_action"]:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(prev_a, self.action_space)
            prev_a_r.append(
                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim]))
        if self.model_config["lstm_use_prev_reward"]:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, 1]))

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Then through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)
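Example #2's wrapper is switched on through RLlib's model options; the `lstm_use_prev_action` and `lstm_use_prev_reward` flags read above are standard `model_config` keys. A minimal sketch of the relevant options (the cell size is an arbitrary example value):

model_config = {
    "use_lstm": True,              # wrap the default net in this LSTM wrapper
    "lstm_cell_size": 256,
    "lstm_use_prev_action": True,  # feed one-hot a_{t-1} into the LSTM input
    "lstm_use_prev_reward": True,  # feed r_{t-1} into the LSTM input
}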
Example #3
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.use_n_prev_actions:
            if isinstance(self.action_space, Discrete):
                for i in range(self.use_n_prev_actions):
                    prev_a_r.append(
                        one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, i],
                                self.action_space))
            elif isinstance(self.action_space, MultiDiscrete):
                for i in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                    prev_a_r.append(
                        one_hot(
                            tf.cast(
                                input_dict[SampleBatch.PREV_ACTIONS]
                                [:, i:i + self.action_space.shape[0]],
                                tf.float32), self.action_space))
            else:
                prev_a_r.append(
                    tf.reshape(
                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
                                tf.float32),
                        [-1, self.use_n_prev_actions * self.action_dim]))
        if self.use_n_prev_rewards:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, self.use_n_prev_rewards]))

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Then through our GTrXL.
        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        return model_out, memory_outs
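Examples #3 and #7 are the GTrXL (attention) counterparts of the LSTM wrapper; `self.use_n_prev_actions` and `self.use_n_prev_rewards` come from the attention-specific model options. A sketch of that configuration (all sizes illustrative):

model_config = {
    "use_attention": True,               # wrap the default net in a GTrXL stack
    "attention_num_transformer_units": 1,
    "attention_dim": 64,
    "attention_use_n_prev_actions": 2,   # -> self.use_n_prev_actions above
    "attention_use_n_prev_rewards": 2,   # -> self.use_n_prev_rewards above
}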
Example #4
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []

        # Prev actions.
        if self.model_config["lstm_use_prev_action"]:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            # If actions are not processed yet (i.e. still in the original
            # form in which they were sent to the environment):
            # Flatten/one-hot into a 1D array.
            if self.model_config["_disable_action_flattening"]:
                prev_a_r.append(
                    flatten_inputs_to_1d_tensor(
                        prev_a,
                        spaces_struct=self.action_space_struct,
                        time_axis=False,
                    )
                )
            # If actions are already flattened (but not one-hot'd yet!),
            # one-hot discrete/multi-discrete actions here.
            else:
                if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                    prev_a = one_hot(prev_a, self.action_space)
                prev_a_r.append(
                    tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
                )
        # Prev rewards.
        if self.model_config["lstm_use_prev_reward"]:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), [-1, 1]
                )
            )

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Push everything through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)
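Example #4's struct path hands nested, not-yet-flattened actions to `flatten_inputs_to_1d_tensor`. Its effect can be approximated in plain TensorFlow; the Dict action below, with one Discrete(4) and one 2-dim Box component, is a made-up illustration:

import tensorflow as tf

batch = 3
prev_a = {
    "gear": tf.constant([0, 2, 1]),          # Discrete(4) component (int ids)
    "steer": tf.random.uniform([batch, 2]),  # Box(2,) component
}
# Discrete parts get one-hot'd, Box parts are cast to float32, and everything
# is concatenated into a single [batch, 4 + 2] tensor.
flat = tf.concat(
    [tf.one_hot(prev_a["gear"], depth=4),
     tf.cast(prev_a["steer"], tf.float32)],
    axis=1)
print(flat.shape)  # (3, 6)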
Example #5
    def forward(self, input_dict, state, seq_lens):
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(
                input_dict[SampleBatch.OBS], self.processed_obs_space, tensorlib="tf"
            )
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(tree.flatten(orig_obs)):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i](SampleBatch({SampleBatch.OBS: component}))
                outs.append(cnn_out)
            elif i in self.one_hot:
                if "int" in component.dtype.name:
                    one_hot_in = {
                        SampleBatch.OBS: one_hot(
                            component, self.flattened_input_space[i]
                        )
                    }
                else:
                    one_hot_in = {SampleBatch.OBS: component}
                one_hot_out, _ = self.one_hot[i](SampleBatch(one_hot_in))
                outs.append(one_hot_out)
            else:
                nn_out, _ = self.flatten[i](
                    SampleBatch(
                        {
                            SampleBatch.OBS: tf.cast(
                                tf.reshape(component, [-1, self.flatten_dims[i]]),
                                tf.float32,
                            )
                        }
                    )
                )
                outs.append(nn_out)
        # Concat all outputs and the non-image inputs.
        out = tf.concat(outs, axis=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out}))

        # No logits/value branches.
        if not self.logits_and_value_model:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_and_value_model(out)
        self._value_out = tf.reshape(values, [-1])
        return logits, []
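Example #5 keys its sub-networks (`self.cnns`, `self.one_hot`, `self.flatten`) by the flat index that `tree.flatten` assigns to each observation component, so that ordering must be stable. A quick check of the ordering with dm-tree (the nested obs below is hypothetical):

import tree  # dm-tree, the nested-structure library RLlib uses

obs = {"camera": "img", "inventory": {"gold": 3, "wood": 7}}
# Depth-first traversal, visiting dict keys in sorted order:
print(tree.flatten(obs))  # ['img', 3, 7]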
Example #6
    def call(
        self, input_dict: SampleBatch
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        assert input_dict.get(SampleBatch.SEQ_LENS) is not None
        # Push obs through underlying (wrapped) model first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.lstm_use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(prev_a, self.action_space)
            prev_a_r.append(
                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
            )
        if self.lstm_use_prev_reward:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), [-1, 1]
                )
            )

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        max_seq_len = (
            tf.shape(wrapped_out)[0] // tf.shape(input_dict[SampleBatch.SEQ_LENS])[0]
        )
        wrapped_out_plus_time_dim = add_time_dimension(
            wrapped_out, max_seq_len=max_seq_len, framework="tf"
        )
        model_out, value_out, h, c = self._rnn_model(
            [
                wrapped_out_plus_time_dim,
                input_dict[SampleBatch.SEQ_LENS],
                input_dict["state_in_0"],
                input_dict["state_in_1"],
            ]
        )
        model_out_no_time_dim = tf.reshape(
            model_out, tf.concat([[-1], tf.shape(model_out)[2:]], axis=0)
        )
        return (
            model_out_no_time_dim,
            [h, c],
            {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])},
        )
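Example #6 derives `max_seq_len` by dividing the flat batch size by the number of sequences, folds the [B*T, F] features into [B, T, F] for the RNN, and flattens the output back afterwards. The same reshape round trip in isolation (sizes illustrative; `add_time_dimension` is approximated here by a plain reshape, which matches its behavior for equal-length sequences):

import tensorflow as tf

B, T, F = 2, 4, 8                     # 2 sequences of 4 steps, 8 features each
flat = tf.random.uniform([B * T, F])  # what the wrapped model emits
seq_lens = tf.constant([T, T])

max_seq_len = tf.shape(flat)[0] // tf.shape(seq_lens)[0]  # == T
with_time = tf.reshape(flat, [-1, max_seq_len, F])        # [B, T, F]
back = tf.reshape(
    with_time, tf.concat([[-1], tf.shape(with_time)[2:]], axis=0))
print(back.shape)  # (8, 8) == [B*T, F] again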
Example #7
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []

        # Prev actions.
        if self.use_n_prev_actions:
            prev_n_actions = input_dict[SampleBatch.PREV_ACTIONS]
            # If actions are not processed yet (i.e. still in the original
            # form in which they were sent to the environment):
            # Flatten/one-hot into a 1D array.
            if self.model_config["_disable_action_flattening"]:
                # Merge prev n actions into flat tensor.
                flat = flatten_inputs_to_1d_tensor(
                    prev_n_actions,
                    spaces_struct=self.action_space_struct,
                    time_axis=True,
                )
                # Fold time-axis into flattened data.
                flat = tf.reshape(flat, [tf.shape(flat)[0], -1])
                prev_a_r.append(flat)
            # If actions are already flattened (but not one-hot'd yet!),
            # one-hot discrete/multi-discrete actions here and concatenate the
            # n most recent actions together.
            else:
                if isinstance(self.action_space, Discrete):
                    for i in range(self.use_n_prev_actions):
                        prev_a_r.append(
                            one_hot(prev_n_actions[:, i], self.action_space))
                elif isinstance(self.action_space, MultiDiscrete):
                    for i in range(0, self.use_n_prev_actions,
                                   self.action_space.shape[0]):
                        prev_a_r.append(
                            one_hot(tf.cast(
                                prev_n_actions[:, i:i +
                                               self.action_space.shape[0]],
                                tf.float32),
                                    space=self.action_space))
                else:
                    prev_a_r.append(
                        tf.reshape(
                            tf.cast(prev_n_actions, tf.float32),
                            [-1, self.use_n_prev_actions * self.action_dim]))
        # Prev rewards.
        if self.use_n_prev_rewards:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, self.use_n_prev_rewards]))

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Then through our GTrXL.
        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        return model_out, memory_outs
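In Example #7's struct path, `flatten_inputs_to_1d_tensor(..., time_axis=True)` keeps the stack of n previous actions as its own axis, which is why the follow-up `tf.reshape` folds that axis into the feature dimension. The fold in isolation (sizes illustrative):

import tensorflow as tf

batch, n_prev, flat_dim = 4, 3, 6
flat = tf.random.uniform([batch, n_prev, flat_dim])  # per-step flattened actions
folded = tf.reshape(flat, [tf.shape(flat)[0], -1])   # [batch, n_prev * flat_dim]
print(folded.shape)  # (4, 18)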
Example #8

    def forward(self, input_dict, states, seq_lens):
        # Stacked previous n observations/rewards/actions, provided via the
        # trajectory view API (see the sketch below).
        obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
        rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
        actions = one_hot(input_dict["prev_n_actions"], self.action_space)
        out, self._last_value = self.base_model([obs, actions, rewards])
        return out, []
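The `prev_n_obs`, `prev_n_rewards`, and `prev_n_actions` columns read by Example #8 are not standard batch keys; RLlib's trajectory view API fills them according to view requirements the model registers in its constructor. A sketch of those registrations (the spaces and `num_frames` are placeholder values; the shift strings follow the pattern used by RLlib's trajectory-view example models):

from gym.spaces import Box, Discrete
from ray.rllib.policy.view_requirement import ViewRequirement

num_frames = 16
obs_space, action_space = Box(-1.0, 1.0, (4,)), Discrete(2)

# What the model's __init__ would add to self.view_requirements so that
# `input_dict` carries the stacked "prev_n_*" columns in forward():
view_requirements = {
    "prev_n_obs": ViewRequirement(
        data_col="obs", shift="-{}:0".format(num_frames - 1), space=obs_space),
    "prev_n_rewards": ViewRequirement(
        data_col="rewards", shift="-{}:-1".format(num_frames)),
    "prev_n_actions": ViewRequirement(
        data_col="actions", shift="-{}:-1".format(num_frames),
        space=action_space),
}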