def _unpack_observations(self, input_dict):
    """Return a copy of ``input_dict`` with unpacked observation entries.

    The copy carries two observation views:

    * ``"obs"``: the result of ``restore_original_dimensions`` applied to
      ``input_dict["obs"]`` using this model's observation space and
      framework.
    * ``"obs_flat"``: ``flatten(input_dict["obs"], self.framework)`` when
      the observation tensor has more than two dimensions, otherwise the
      observation tensor itself.

    Arguments:
        input_dict (dict): dictionary of input tensors; must contain an
            "obs" entry exposing a ``shape`` attribute.

    Returns:
        dict: a copy of ``input_dict`` with "obs" replaced and "obs_flat"
            added. Entries are only written to the copy, never to
            ``input_dict``.
    """
    restored = input_dict.copy()
    obs = input_dict["obs"]
    # The observation space lives on `self.obs_space`, the attribute
    # __call__ reads for this same unpacking step; the previous
    # `self.observation_space` did not match it.
    restored["obs"] = restore_original_dimensions(
        obs, self.obs_space, self.framework)
    if len(obs.shape) > 2:
        restored["obs_flat"] = flatten(obs, self.framework)
    else:
        restored["obs_flat"] = obs
    return restored
def __call__(self, input_dict, state=None, seq_lens=None):
    """Run the model's forward pass on the given inputs and validate it.

    This is the entry point RLlib uses to execute the forward pass.
    Nested observation tensors are unpacked first, then forward() is
    invoked inside ``self.context()``. Custom models should override
    forward() rather than this method.

    Arguments:
        input_dict (dict): dictionary of input tensors, including "obs",
            "prev_action", "prev_reward", "is_training"
        state (list): list of state tensors with sizes matching those
            returned by get_initial_state + the batch dimension. A falsy
            value is passed to forward() as an empty list.
        seq_lens (Tensor): 1d tensor holding input sequence lengths

    Returns:
        (outputs, state): The model output tensor of size
            [BATCH, output_spec.size] or a list of tensors corresponding
            to output_spec.shape_list, and a list of state tensors of
            [BATCH, state_size_i].

    Raises:
        ValueError: if forward() does not return a 2-element list/tuple,
            if the output has no ``shape`` attribute, if the output is
            not 2-dimensional with ``self.num_outputs`` columns, or if
            the returned state is not a list.
    """
    unpacked = input_dict.copy()
    obs = input_dict["obs"]
    unpacked["obs"] = restore_original_dimensions(
        obs, self.obs_space, self.framework)
    # Observations with more than two dimensions are also exposed after
    # being passed through flatten(); anything else is reused as-is.
    unpacked["obs_flat"] = (
        flatten(obs, self.framework) if len(obs.shape) > 2 else obs)

    with self.context():
        result = self.forward(unpacked, state or [], seq_lens)

    # forward() has to hand back exactly an (output, state) pair.
    if not isinstance(result, (list, tuple)) or len(result) != 2:
        raise ValueError(
            "forward() must return a tuple of (output, state) tensors, "
            "got {}".format(result))
    outputs, state_out = result

    try:
        out_shape = outputs.shape
    except AttributeError:
        raise ValueError("Output is not a tensor: {}".format(outputs))

    if len(out_shape) != 2 or out_shape[1] != self.num_outputs:
        raise ValueError(
            "Expected output shape of [None, {}], got {}".format(
                self.num_outputs, out_shape))
    if not isinstance(state_out, list):
        raise ValueError(
            "State output is not a list: {}".format(state_out))

    # Keep the most recent output on the instance.
    self._last_output = outputs
    return outputs, state_out