Example #1
    def forward(self, input_dict, state, seq_lens):
        # If the original (nested) observations were already restored
        # upstream (both `obs` and `obs_flat` are present), use them as-is;
        # otherwise, unpack the flat obs tensor back into its original
        # (dict/tuple) structure.
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS],
                                                   self.obs_space, "tf")
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(orig_obs):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
                outs.append(cnn_out)
            elif i in self.one_hot:
                # One-hot encode int (Discrete/MultiDiscrete) components;
                # float components are assumed to be one-hot already and
                # are passed through unchanged.
                if component.dtype in [tf.int32, tf.int64, tf.uint8]:
                    outs.append(
                        one_hot(component, self.original_space.spaces[i]))
                else:
                    outs.append(component)
            else:
                outs.append(tf.reshape(component, [-1, self.flatten[i]]))
        # Concat all outputs and the non-image inputs.
        out = tf.concat(outs, axis=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None)

        # No logits/value branches.
        if not self.logits_and_value_model:
            return out, []

        # Logits and value branches.
        logits, values = self.logits_and_value_model(out)
        self._value_out = tf.reshape(values, [-1])
        return logits, []
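This forward pass belongs to a complex-input style model: image components go through per-component CNNs, discrete components get one-hot encoded, and everything else is flattened before a shared FC stack. As a minimal sketch of how such a custom model might be plugged into RLlib (assuming the ModelV2 API; "my_complex_net" and ComplexInputNetwork are placeholder names for this sketch):

    from ray.rllib.models import ModelCatalog

    # Register the custom model class under a lookup name ...
    ModelCatalog.register_custom_model("my_complex_net", ComplexInputNetwork)

    # ... then select it via the trainer's model config.
    config = {
        "model": {
            "custom_model": "my_complex_net",
        },
    }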
Example #2
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.model_config["lstm_use_prev_action"]:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(prev_a, self.action_space)
            prev_a_r.append(
                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim]))
        if self.model_config["lstm_use_prev_reward"]:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, 1]))

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Then through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)
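The two model_config flags checked above correspond to standard RLlib model-config keys. A minimal config sketch that would exercise both concat branches (assuming RLlib's built-in LSTM wrapper):

    config = {
        "model": {
            "use_lstm": True,
            # Append the (one-hot) previous action to the LSTM input.
            "lstm_use_prev_action": True,
            # Append the previous (scalar) reward to the LSTM input.
            "lstm_use_prev_reward": True,
        },
    }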
Example #3
    def call(self, input_dict: SampleBatch) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        assert input_dict[SampleBatch.SEQ_LENS] is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.use_n_prev_actions:
            if isinstance(self.action_space, Discrete):
                for i in range(self.use_n_prev_actions):
                    prev_a_r.append(
                        one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, i],
                                self.action_space))
            elif isinstance(self.action_space, MultiDiscrete):
                for i in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                    # `one_hot` expects integer indices, so cast to int32
                    # (not float32) before encoding.
                    prev_a_r.append(
                        one_hot(
                            tf.cast(
                                input_dict[SampleBatch.PREV_ACTIONS]
                                [:, i:i + self.action_space.shape[0]],
                                tf.int32), self.action_space))
            else:
                prev_a_r.append(
                    tf.reshape(
                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
                                tf.float32),
                        [-1, self.use_n_prev_actions * self.action_dim]))
        if self.use_n_prev_rewards:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, self.use_n_prev_rewards]))

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Collect all memory (state-in) tensors from the input batch.
        memory_ins = [
            s for k, s in input_dict.items() if k.startswith("state_in_")
        ]
        model_out, memory_outs, value_outs = self.base_model([wrapped_out] +
                                                             memory_ins)
        return model_out, memory_outs, {
            SampleBatch.VF_PREDS: tf.reshape(value_outs, [-1])
        }
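For intuition on the Discrete branch above: with a Discrete(4) action space and use_n_prev_actions=3, SampleBatch.PREV_ACTIONS arrives as an int tensor of shape [B, 3], and the loop emits three [B, 4] one-hot tensors that get concatenated onto the wrapped output. A standalone sketch of that encoding in plain TensorFlow (hypothetical values):

    import tensorflow as tf

    prev_actions = tf.constant([[2, 0, 3]])  # [B=1, n_prev=3], Discrete(4)
    outs = [tf.one_hot(prev_actions[:, i], depth=4) for i in range(3)]
    encoded = tf.concat(outs, axis=1)  # shape [1, 12]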
Example #4
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.use_n_prev_actions:
            if isinstance(self.action_space, Discrete):
                for i in range(self.use_n_prev_actions):
                    prev_a_r.append(
                        one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, i],
                                self.action_space))
            elif isinstance(self.action_space, MultiDiscrete):
                for i in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                    # `one_hot` expects integer indices, so cast to int32
                    # (not float32) before encoding.
                    prev_a_r.append(
                        one_hot(
                            tf.cast(
                                input_dict[SampleBatch.PREV_ACTIONS]
                                [:, i:i + self.action_space.shape[0]],
                                tf.int32), self.action_space))
            else:
                prev_a_r.append(
                    tf.reshape(
                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
                                tf.float32),
                        [-1, self.use_n_prev_actions * self.action_dim]))
        if self.use_n_prev_rewards:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, self.use_n_prev_rewards]))

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Then through our GTrXL.
        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        return model_out, memory_outs
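This is the attention (GTrXL) analogue of the LSTM wrapper in Example #2; here, use_n_prev_actions / use_n_prev_rewards are driven by the attention-specific model-config keys. A minimal config sketch (assuming RLlib's built-in attention wrapper):

    config = {
        "model": {
            "use_attention": True,
            # Append the last 3 one-hot actions and the last 5 rewards
            # to the transformer's input.
            "attention_use_n_prev_actions": 3,
            "attention_use_n_prev_rewards": 5,
        },
    }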
Example #5
    def call(self, input_dict: SampleBatch) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        assert input_dict.get("seq_lens") is not None
        # Push obs through underlying (wrapped) model first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.lstm_use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(prev_a, self.action_space)
            prev_a_r.append(
                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim]))
        if self.lstm_use_prev_reward:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, 1]))

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Infer the max. sequence length: total number of rows (B*T)
        # divided by the number of sequences (B).
        max_seq_len = tf.shape(wrapped_out)[0] // tf.shape(
            input_dict["seq_lens"])[0]
        wrapped_out_plus_time_dim = add_time_dimension(
            wrapped_out, max_seq_len=max_seq_len, framework="tf")
        model_out, value_out, h, c = self._rnn_model([
            wrapped_out_plus_time_dim, input_dict["seq_lens"],
            input_dict["state_in_0"], input_dict["state_in_1"]
        ])
        model_out_no_time_dim = tf.reshape(
            model_out, tf.concat([[-1], tf.shape(model_out)[2:]], axis=0))
        return model_out_no_time_dim, [h, c], {
            SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])
        }
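add_time_dimension folds the flat [B*T, F] batch back into [B, T, F] so the RNN can consume it; max_seq_len is recovered as (B*T) // B from the seq_lens tensor. A plain-TensorFlow sketch of the same reshape (hypothetical shapes):

    import tensorflow as tf

    wrapped_out = tf.zeros([6, 8])   # [B*T, F] with B=2, T=3, F=8
    seq_lens = tf.constant([3, 3])   # one entry per sequence -> B=2
    max_seq_len = tf.shape(wrapped_out)[0] // tf.shape(seq_lens)[0]  # T=3
    with_time = tf.reshape(
        wrapped_out, [-1, max_seq_len, wrapped_out.shape[-1]])  # [2, 3, 8]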
Example #6
    def forward(self, input_dict, states, seq_lens):
        obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
        rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
        actions = one_hot(input_dict["prev_n_actions"], self.action_space)
        out, self._last_value = self.base_model([obs, actions, rewards])
        return out, []
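Example #6 relies on RLlib's trajectory view API: the "prev_n_obs", "prev_n_rewards", and "prev_n_actions" keys only exist in input_dict because the model declares matching view requirements. A sketch of those declarations in the model's __init__ (num_frames is an assumed hyperparameter; the shift strings use RLlib's "-n:-1" range syntax):

    from ray.rllib.policy.view_requirement import ViewRequirement

    # Inside the model's __init__():
    self.view_requirements["prev_n_obs"] = ViewRequirement(
        data_col="obs", shift="-{}:-1".format(num_frames), space=obs_space)
    self.view_requirements["prev_n_rewards"] = ViewRequirement(
        data_col="rewards", shift="-{}:-1".format(num_frames))
    self.view_requirements["prev_n_actions"] = ViewRequirement(
        data_col="actions", shift="-{}:-1".format(num_frames),
        space=self.action_space)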