Example 1
    def forward(self, input_dict, state, seq_lens):
        if "internal_state" in input_dict["obs"]:
            self.internal_states = input_dict["obs"]["internal_state"]

        pi_obs_inputs = input_dict["obs"][self.pi_obs_key]
        vf_obs_inputs = input_dict["obs"][self.vf_obs_key]
        if self.use_lstm:
            policy_out, self._value_out, *state_out = self.base_model([
                add_time_dimension(pi_obs_inputs, seq_lens),
                add_time_dimension(vf_obs_inputs, seq_lens),
                seq_lens,
                *state])
            policy_out = tf.reshape(policy_out, [-1, self.num_outputs])
        else:
            policy_out, self._value_out = self.base_model([pi_obs_inputs, vf_obs_inputs])
            state_out = state

        self.unmasked_policy_logits = policy_out

        if self.mask_invalid_actions:
            # mask out invalid actions: add log(0) (clamped to float32.min) to
            # their logits so their probabilities become effectively zero
            self.valid_actions_masks = input_dict["obs"]["valid_actions_mask"]
            inf_mask = tf.maximum(tf.math.log(self.valid_actions_masks), tf.float32.min)
            self.masked_policy_logits = policy_out + inf_mask
        else:
            self.masked_policy_logits = policy_out

        return self.masked_policy_logits, state_out
Example 2
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn()"""

        # first we add the time dimension for each object
        input_dict["obs_vision"] = add_time_dimension(input_dict["obs"][0], seq_lens)
        input_dict["obs_messages"] = add_time_dimension(input_dict["obs"][1], seq_lens)

        output, new_state = self.forward_rnn(input_dict, state, seq_lens)

        return tf.reshape(output, [-1, self.num_outputs]), new_state
Example 3
    def test_add_time_dimension(self):
        """Test add_time_dimension gives sequential data along the time dimension"""

        B, T, F = np.random.choice(
            np.asarray(list(range(8, 32)),
                       dtype=np.int32),  # use int32 for seq_lens
            size=3,
            replace=False,
        )

        inputs_numpy = np.repeat(np.arange(B * T)[:, np.newaxis],
                                 repeats=F,
                                 axis=-1).astype(np.int32)
        check(inputs_numpy.shape, (B * T, F))

        time_shift_diff_batch_major = np.ones(shape=(B, T - 1, F),
                                              dtype=np.int32)
        time_shift_diff_time_major = np.ones(shape=(T - 1, B, F),
                                             dtype=np.int32)

        if tf is not None:
            # Test tensorflow batch-major
            padded_inputs = tf.constant(inputs_numpy)
            batch_major_outputs = add_time_dimension(padded_inputs,
                                                     max_seq_len=T,
                                                     framework="tf",
                                                     time_major=False)
            check(batch_major_outputs.shape.as_list(), [B, T, F])
            time_shift_diff = (batch_major_outputs[:, 1:] -
                               batch_major_outputs[:, :-1])
            check(time_shift_diff, time_shift_diff_batch_major)

        if torch is not None:
            # Test torch batch-major
            padded_inputs = torch.from_numpy(inputs_numpy)
            batch_major_outputs = add_time_dimension(padded_inputs,
                                                     max_seq_len=T,
                                                     framework="torch",
                                                     time_major=False)
            check(batch_major_outputs.shape, (B, T, F))
            time_shift_diff = (batch_major_outputs[:, 1:] -
                               batch_major_outputs[:, :-1])
            check(time_shift_diff, time_shift_diff_batch_major)

            # Test torch time-major
            padded_inputs = torch.from_numpy(inputs_numpy)
            time_major_outputs = add_time_dimension(padded_inputs,
                                                    max_seq_len=T,
                                                    framework="torch",
                                                    time_major=True)
            check(time_major_outputs.shape, (T, B, F))
            time_shift_diff = time_major_outputs[1:] - time_major_outputs[:-1]
            check(time_shift_diff, time_shift_diff_time_major)
Example 4
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().
        You should implement forward_rnn() in your subclass."""
        if self.use_prev_action:
            output, new_state = self.forward_rnn(
                add_time_dimension(input_dict["obs"], seq_lens), state,
                seq_lens, add_time_dimension(input_dict["prev_action"], seq_lens))
        else:
            output, new_state = self.forward_rnn(
                add_time_dimension(input_dict["obs"], seq_lens), state,
                seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state
Example 5
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn()"""
        # first we add the time dimension for each object
        if isinstance(input_dict["obs"], dict):
            padded_obs = add_time_dimension(input_dict["obs"]["obs"], seq_lens)
        else:
            padded_obs = add_time_dimension(input_dict["obs"], seq_lens)

        if self.use_prev_action:
            padded_action = add_time_dimension(input_dict["prev_actions"], seq_lens)
            padded_obs = tf.concat([padded_obs, padded_action], axis=-1)

        output, new_state = self.forward_rnn(padded_obs, state, seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state
Example 6
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        assert seq_lens is not None
        padded_inputs = input_dict["obs_flat"]
        max_seq_len = tf.shape(padded_inputs)[0] // tf.shape(seq_lens)[0]
        output, new_state = self.forward_rnn(
            add_time_dimension(
                padded_inputs,
                max_seq_len=max_seq_len,
                framework="tf",
            ),
            state,
            seq_lens,
        )
        output = tf.reshape(output, [-1, self.num_outputs])

        action_mask = input_dict["obs"]["action_mask"]
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        output = output + inf_mask

        return output, new_state
Example 7
    def forward(self, input_dict, state, seq_lens):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        x = input_dict["obs"]["conv_features"]
        x = self.shared_layers(x)

        if type(input_dict["prev_rewards"]) != torch.Tensor:
            input_dict["prev_rewards"] = torch.tensor(
                input_dict["prev_rewards"], device=device)

        last_reward = torch.reshape(input_dict["prev_rewards"],
                                    [-1, 1]).float()

        if type(input_dict["prev_actions"]) != torch.Tensor:
            prev_actions = np.array(input_dict["prev_actions"], dtype=np.int)
        else:
            prev_actions = np.array(input_dict["prev_actions"].cpu().numpy(),
                                    dtype=np.int)
        prev_actions = np.expand_dims(prev_actions, 0)

        one_hot_prev_actions = torch.cat(
            [nn.functional.one_hot(torch.tensor(a), 6) for a in prev_actions],
            axis=-1)

        x = torch.cat((x, input_dict["obs"]["fc_features"], last_reward,
                       one_hot_prev_actions.float().to(device)),
                      dim=1)

        output, new_state = self.forward_rnn(
            add_time_dimension(x.float(), seq_lens, framework="torch"), state,
            seq_lens)

        return torch.reshape(output, [-1, self.num_outputs]), new_state
Example 8
    def forward(self, input_dict, state, seq_lens):
        obs_inputs = input_dict["obs"][self._obs_key]
        self.valid_actions_masks = input_dict["obs"]["valid_actions_mask"]
        # self.valid_actions_masks = tf.Print(input_dict["obs"]["valid_actions_mask"], [input_dict["obs"]["valid_actions_mask"]], message="valid_act_mask: ")

        if self.use_lstm:
            obs_inputs_time_dist = add_time_dimension(obs_inputs, seq_lens)
            # obs_inputs_time_dist_check = tf.debugging.check_numerics(
            #     obs_inputs_time_dist, "nan found in obs_inputs_time_dist", name=None
            # )

            # seq_lens = tf.debugging.check_numerics(
            #     seq_lens, "nan found in seq_lens", name=None
            # )
            # state_checks = []
            # for i in range(len(state)):
            #     state_checks.append(tf.debugging.check_numerics(
            #         state[i], f"nan found in state[{i}]", name=None
            #     ))

            # with tf.control_dependencies([obs_inputs_time_dist_check, *state_checks]):
            base_model_out, *state_out = self._base_model(
                [obs_inputs_time_dist, seq_lens, *state])

            # base_model_out = tf.Print(base_model_out, state_out,
            #          message="state_out: ")

            return tf.reshape(base_model_out,
                              [-1, *self._base_model_out_shape]), state_out
        else:
            base_model_out = self._base_model([obs_inputs])
            state_out = state
            return base_model_out, state_out
Example 9
    def forward(
            self, input_dict: Dict[str, torch.Tensor],
            state: List[torch.Tensor],
            seq_lens: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:

        # first apply the cnn
        x = input_dict['obs'].float().permute(0, 3, 1, 2) / 255.0
        x = self.cnn(x)

        # add time
        x_flat = x.view(x.shape[0], -1)

        # pylint: disable=too-many-function-args,missing-kwoa
        x = add_time_dimension(x_flat, seq_lens, "torch")

        # apply lstm
        # pylint: disable=no-member
        x, state_out = self.lstm(
            x, (torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)))

        # pylint: disable=no-member
        x = torch.reshape(x, [-1, self.lstm_cell_size])

        return self._forward_helper(x), [
            torch.squeeze(state_out[0], 0),
            torch.squeeze(state_out[1], 0)
        ]
Example 10
    def forward(self, input_dict, state, seq_lens):
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        output, new_state = self.forward_rnn(
            add_time_dimension(input_dict["obs"].float(),
                               seq_lens,
                               framework="torch"), input_dict["prev_actions"],
            state, seq_lens)
        return torch.reshape(output, [-1, self.num_outputs]), new_state
Example 11
    def _build_layers_v2(self, input_dict, num_outputs, options):
        # Hard deprecate this class. All Models should use the ModelV2
        # API from here on.
        deprecation_warning("Model->LSTM", "RecurrentNetwork", error=False)

        cell_size = options.get("lstm_cell_size")
        if options.get("lstm_use_prev_action_reward"):
            action_dim = int(
                np.prod(
                    input_dict["prev_actions"].get_shape().as_list()[1:]))
            features = tf.concat(
                [
                    input_dict["obs"],
                    tf.reshape(
                        tf.cast(input_dict["prev_actions"], tf.float32),
                        [-1, action_dim]),
                    tf.reshape(input_dict["prev_rewards"], [-1, 1]),
                ],
                axis=1)
        else:
            features = input_dict["obs"]
        last_layer = add_time_dimension(features, self.seq_lens)

        # Setup the LSTM cell
        lstm = tf1.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf1.placeholder(
                tf.float32, [None, lstm.state_size.c], name="c")
            h_in = tf1.placeholder(
                tf.float32, [None, lstm.state_size.h], name="h")
            self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = tf1.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf1.nn.dynamic_rnn(
            lstm,
            last_layer,
            initial_state=state_in,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        self.state_out = list(lstm_state)

        # Compute outputs
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
        return logits, last_layer
Example 12
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        output, new_state = self.forward_rnn(
            add_time_dimension(input_dict["obs_flat"],
                               seq_lens,
                               framework="tf"), state, seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state
Example 13
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        output, new_state = self.forward_rnn(
            add_time_dimension(
                input_dict["obs_flat"].float(), seq_lens, framework="torch"),
            state, seq_lens)
        return torch.reshape(output, [-1, self.num_outputs]), new_state
Example 14
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        assert seq_lens is not None
        padded_inputs = input_dict["obs_flat"]
        max_seq_len = tf.shape(padded_inputs)[0] // tf.shape(seq_lens)[0]
        output, new_state = self.forward_rnn(
            add_time_dimension(
                padded_inputs, max_seq_len=max_seq_len, framework="tf"), state,
            seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state
Example 15
    def forward(self, input_dict, state, seq_lens):
        x = input_dict["obs"]["conv_features"]
        x = self.shared_conv_layers(x)
        x = torch.cat((x, input_dict["obs"]["fc_features"]), dim=1)
        x = self.shared_fc_layers(x)

        output, new_state = self.forward_rnn(
            add_time_dimension(x.float(), seq_lens, framework="torch"),
            state,
            seq_lens
        )

        return torch.reshape(output, [-1, self.num_outputs]), new_state
Example 16
    def forward(self, input_dict, state, seq_lens):
        """
        Evaluate the model.
        Adds time dimension to batch before sending inputs to forward_rnn()
        :param input_dict: The input tensors.
        :param state: The model state.
        :param seq_lens: LSTM sequence lengths.
        :return: The policy logits and state.
        """
        trunk = self.encoder_model(input_dict["obs"]["curr_obs"])
        new_dict = {"curr_obs": add_time_dimension(trunk, seq_lens)}

        output, new_state = self.forward_rnn(new_dict, state, seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state
Example 17
    def call(
        self, input_dict: SampleBatch
    ) -> (TensorType, List[TensorType], Dict[str, TensorType]):
        assert input_dict.get(SampleBatch.SEQ_LENS) is not None
        # Push obs through underlying (wrapped) model first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.lstm_use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(prev_a, self.action_space)
            prev_a_r.append(
                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
            )
        if self.lstm_use_prev_reward:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), [-1, 1]
                )
            )

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        max_seq_len = (
            tf.shape(wrapped_out)[0] // tf.shape(input_dict[SampleBatch.SEQ_LENS])[0]
        )
        wrapped_out_plus_time_dim = add_time_dimension(
            wrapped_out, max_seq_len=max_seq_len, framework="tf"
        )
        model_out, value_out, h, c = self._rnn_model(
            [
                wrapped_out_plus_time_dim,
                input_dict[SampleBatch.SEQ_LENS],
                input_dict["state_in_0"],
                input_dict["state_in_1"],
            ]
        )
        model_out_no_time_dim = tf.reshape(
            model_out, tf.concat([[-1], tf.shape(model_out)[2:]], axis=0)
        )
        return (
            model_out_no_time_dim,
            [h, c],
            {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])},
        )
Example 18
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        flat_inputs = input_dict["obs_flat"].float()
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0]
        self.time_major = self.model_config.get("_time_major", False)
        inputs = add_time_dimension(
            flat_inputs,
            max_seq_len=max_seq_len,
            framework="torch",
            time_major=self.time_major,
        )
        output, new_state = self.forward_rnn(inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        return output, new_state
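Most of the forward() implementations above delegate the recurrent part to a forward_rnn() method that, as their docstrings note, you implement in a subclass; the listing never shows one. Below is a minimal sketch built on RLlib's torch RecurrentNetwork base class. The class and attribute names (MyTorchRNN, fc, lstm, logits, value_branch) are illustrative placeholders, not code taken from the examples above.

import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN


class MyTorchRNN(TorchRNN, nn.Module):
    """Illustrative sketch: fc -> LSTM -> policy logits, plus a value branch."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, fc_size=64, lstm_size=64):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.obs_size = int(np.prod(obs_space.shape))
        self.lstm_size = lstm_size
        self.fc = nn.Linear(self.obs_size, fc_size)
        # batch_first=True because forward() hands us [B, T, F] inputs.
        self.lstm = nn.LSTM(fc_size, lstm_size, batch_first=True)
        self.logits = nn.Linear(lstm_size, num_outputs)
        self.value_branch = nn.Linear(lstm_size, 1)
        self._features = None

    def get_initial_state(self):
        # One zero tensor each for the LSTM's h and c states.
        return [self.fc.weight.new_zeros(self.lstm_size),
                self.fc.weight.new_zeros(self.lstm_size)]

    def value_function(self):
        return torch.reshape(self.value_branch(self._features), [-1])

    def forward_rnn(self, inputs, state, seq_lens):
        # inputs: [B, T, obs_size]; forward() already added the time dimension
        # via add_time_dimension() and will flatten the output back to
        # [B * T, num_outputs].
        x = nn.functional.relu(self.fc(inputs))
        self._features, (h, c) = self.lstm(
            x, (torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)))
        return self.logits(self._features), [
            torch.squeeze(h, 0), torch.squeeze(c, 0)
        ]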
Example 19
    def forward(self, input_dict, state, seq_lens):
        flat_inputs = input_dict["obs"]["state"].float()
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0]
        self.time_major = self.model_config.get("_time_major", False)
        inputs = add_time_dimension(
            flat_inputs,
            max_seq_len=max_seq_len,
            framework="torch",
            time_major=self.time_major,
        )

        action_logits, new_state = self.forward_rnn(inputs, state, seq_lens)
        action_logits = torch.reshape(action_logits, [-1, self.num_outputs])

        action_mask = input_dict["obs"]["action_mask"]
        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
        return action_logits + inf_mask, new_state
Example 20
    def forward(self, input_dict, state, seq_lens):
        """
        First evaluate non-LSTM parts of model. Then add a time dimension to the batch before
        sending inputs to forward_rnn(), which evaluates the LSTM parts of the model.
        :param input_dict: The input tensors.
        :param state: The model state.
        :param seq_lens: LSTM sequence lengths.
        :return: The agent's own action logits and the new model state.
        """
        # Evaluate non-lstm layers
        actor_critic_fc_output, moa_fc_output = self.moa_encoder_model(
            input_dict["obs"]["curr_obs"])

        rnn_input_dict = {
            "ac_trunk": actor_critic_fc_output,
            "prev_moa_trunk": state[5],
            "other_agent_actions": input_dict["obs"]["other_agent_actions"],
            "visible_agents": input_dict["obs"]["visible_agents"],
            "prev_actions": input_dict["prev_actions"],
        }

        # Add time dimension to rnn inputs
        for k, v in rnn_input_dict.items():
            rnn_input_dict[k] = add_time_dimension(v, seq_lens)

        output, new_state = self.forward_rnn(rnn_input_dict, state, seq_lens)
        action_logits = tf.reshape(output, [-1, self.num_outputs])
        counterfactuals = tf.reshape(
            self._counterfactuals,
            [
                -1, self._counterfactuals.shape[-2],
                self._counterfactuals.shape[-1]
            ],
        )
        new_state.extend([action_logits, moa_fc_output])

        self.compute_influence_reward(input_dict, state[4], counterfactuals)

        return action_logits, new_state
Example 21
    def _build_layers_v2(self, input_dict, num_outputs, options):
        # Previously, a new class object was created during
        # deserialization and this `capture_index`
        # variable would be refreshed between class instantiations.
        # This behavior is no longer the case, so we manually refresh
        # the variable.
        RNNSpyModel.capture_index = 0

        def spy(sequences, state_in, state_out, seq_lens):
            if len(sequences) == 1:
                return 0  # don't capture inference inputs
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite
            ray.experimental.internal_kv._internal_kv_put(
                "rnn_spy_in_{}".format(RNNSpyModel.capture_index),
                pickle.dumps({
                    "sequences": sequences,
                    "state_in": state_in,
                    "state_out": state_out,
                    "seq_lens": seq_lens
                }),
                overwrite=True)
            RNNSpyModel.capture_index += 1
            return 0

        features = input_dict["obs"]
        cell_size = 3
        last_layer = add_time_dimension(features, self.seq_lens)

        # Setup the LSTM cell
        lstm = tf.nn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(tf.float32, [None, lstm.state_size.c],
                                  name="c")
            h_in = tf.placeholder(tf.float32, [None, lstm.state_size.h],
                                  name="h")
        self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm,
                                                 last_layer,
                                                 initial_state=state_in,
                                                 sequence_length=self.seq_lens,
                                                 time_major=False,
                                                 dtype=tf.float32)

        self.state_out = list(lstm_state)
        spy_fn = tf.py_func(spy, [
            last_layer,
            self.state_in,
            self.state_out,
            self.seq_lens,
        ],
                            tf.int64,
                            stateful=True)

        # Compute outputs
        with tf.control_dependencies([spy_fn]):
            last_layer = tf.reshape(lstm_out, [-1, cell_size])
            logits = linear(last_layer, num_outputs, "action",
                            normc_initializer(0.01))
        return logits, last_layer
Example 22
    def forward(self, input_dict, state, seq_lens):
        """
        Adds time dimension to batch and does forward inference
        """
        prev_actions = tf.cast(input_dict["prev_actions"][:, 0],
                               dtype=tf.int32)
        prev_rewards = input_dict["prev_rewards"]
        lstm_state = state[:2]
        if self.use_receiver_bias:
            receiver_bias_state = state[2:4]

        obs_dict = input_dict["obs"]
        inputs = add_time_dimension(obs_dict["obs"], seq_lens)
        if self.use_comm:
            extra_inputs = obs_dict["message"]
        else:
            extra_inputs = tf.zeros_like(obs_dict["message"])
        outputs = self.rnn_model(
            [inputs, extra_inputs, prev_actions, prev_rewards, seq_lens] +
            lstm_state)

        if self.use_receiver_bias:
            extra_inputs = tf.zeros_like(obs_dict["message"])
            self.no_message_outputs = self.rnn_model(
                [inputs, extra_inputs, prev_actions, prev_rewards, seq_lens] +
                receiver_bias_state)

        if self.use_cpc:
            (
                model_out,
                self._value_out,
                h,
                c,
                self._cpc_ins,
                self._cpc_preds,
                *self._unscaled_message_p,
            ) = outputs
        else:
            model_out, self._value_out, h, c, *self._unscaled_message_p = outputs

        next_states = [h, c]
        if self.use_receiver_bias:
            next_states.extend(self.no_message_outputs[2:4])

        if self.use_inference_policy:
            if self.pm_type == "moving_avg":
                action_logits = model_out[..., :-self.message_size]
                unscaled_message_logits = model_out[..., -self.message_size:]
                avg_message_logits = tf.math.log(
                    self._avg_message_p) - tf.math.log(1 - self._avg_message_p)
                scaled_message_logits = unscaled_message_logits - avg_message_logits
                model_out = tf.keras.layers.Concatenate()(
                    [action_logits, scaled_message_logits])
            elif self.pm_type == "hyper_nn":
                action_logits = model_out[..., :-self.message_size]
                unscaled_message_logits = model_out[..., -self.message_size:]
                scaled_message_logits = unscaled_message_logits - self._pm_logits
                model_out = tf.keras.layers.Concatenate()(
                    [action_logits, scaled_message_logits])
            else:
                raise NotImplementedError("Wrong type for inference_policy")

        self._model_out = tf.reshape(model_out, [-1, self.num_outputs])
        return self._model_out, next_states
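All of the snippets above are custom RLlib models, so for completeness here is a hedged sketch of how such a model is typically registered and selected through the model config. The key "my_rnn" and the MyTorchRNN class from the previous sketch are placeholders; algorithm/trainer class names vary across Ray versions and are omitted.

from ray.rllib.models import ModelCatalog

# Register the (hypothetical) custom recurrent model under a string key.
ModelCatalog.register_custom_model("my_rnn", MyTorchRNN)

config = {
    "framework": "torch",
    "model": {
        "custom_model": "my_rnn",
        # Upper bound on the T dimension produced by add_time_dimension();
        # longer episodes are chopped into chunks of at most max_seq_len.
        "max_seq_len": 20,
    },
}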