def forward(self, input_dict, state, seq_lens):
    """Encode each observation component and merge the results.

    Components whose index is in `self.cnns` go through that CNN, those in
    `self.one_hot` are one-hot encoded when their dtype is int32, int64 or
    uint8 (and passed through as-is otherwise), and all remaining ones are
    reshaped to `[-1, self.flatten[i]]`. Everything is concatenated along
    axis 1 and pushed through `self.post_fc_stack`.

    Args:
        input_dict: Must contain `SampleBatch.OBS`. If "obs_flat" is also
            present, the observation is used as-is; otherwise it is passed
            through `restore_original_dimensions` first.
        state: Not used by this method.
        seq_lens: Not used by this method.

    Returns:
        `(out, [])`, where `out` is the post-FC-stack output if
        `self.logits_and_value_model` is falsy, else the logits. In the
        latter case the flattened values are stored in `self._value_out`.
    """
    if SampleBatch.OBS not in input_dict or "obs_flat" not in input_dict:
        orig_obs = restore_original_dimensions(
            input_dict[SampleBatch.OBS], self.obs_space, "tf")
    else:
        orig_obs = input_dict[SampleBatch.OBS]

    # One encoded tensor per observation component, in iteration order.
    encoded = []
    for idx, part in enumerate(orig_obs):
        if idx in self.cnns:
            # Image component: run it through its dedicated CNN.
            features, _ = self.cnns[idx]({SampleBatch.OBS: part})
        elif idx not in self.one_hot:
            # Neither image nor one-hot component: flatten it.
            features = tf.reshape(part, [-1, self.flatten[idx]])
        elif part.dtype in (tf.int32, tf.int64, tf.uint8):
            features = one_hot(part, self.original_space.spaces[idx])
        else:
            # One-hot component with a non-integer dtype: pass through.
            features = part
        encoded.append(features)

    # Merge all component encodings into one tensor.
    merged = tf.concat(encoded, axis=1)
    # Optional FC stack (may be an empty stack).
    merged, _ = self.post_fc_stack({SampleBatch.OBS: merged}, [], None)

    if self.logits_and_value_model:
        # Logits and value branches.
        logits, values = self.logits_and_value_model(merged)
        self._value_out = tf.reshape(values, [-1])
        return logits, []

    # No logits/value branches configured.
    return merged, []
def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    """Run the wrapped net, append prev action/reward, then call the parent.

    Args:
        input_dict: Passed to `self._wrapped_forward`. Read for
            `SampleBatch.PREV_ACTIONS` / `SampleBatch.PREV_REWARDS` when
            the matching `model_config` flag is set. Its "obs_flat" entry
            is overwritten with the (possibly extended) wrapped output.
        state: Forwarded unchanged to the parent class' `forward()`.
        seq_lens: Must not be None; forwarded to the parent `forward()`.

    Returns:
        Whatever the parent class' `forward()` returns.
    """
    assert seq_lens is not None

    # The wrapped ("unwrapped" net's) forward pass comes first.
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

    # Optional extra inputs to concatenate onto the wrapped output.
    extras = []
    if self.model_config["lstm_use_prev_action"]:
        last_action = input_dict[SampleBatch.PREV_ACTIONS]
        if isinstance(self.action_space, (Discrete, MultiDiscrete)):
            last_action = one_hot(last_action, self.action_space)
        last_action = tf.cast(last_action, tf.float32)
        extras.append(tf.reshape(last_action, [-1, self.action_dim]))
    if self.model_config["lstm_use_prev_reward"]:
        last_reward = tf.cast(
            input_dict[SampleBatch.PREV_REWARDS], tf.float32)
        extras.append(tf.reshape(last_reward, [-1, 1]))

    if extras:
        wrapped_out = tf.concat([wrapped_out, *extras], axis=1)

    # Hand the result to the parent class (the LSTM) via "obs_flat".
    input_dict["obs_flat"] = wrapped_out
    return super().forward(input_dict, state, seq_lens)
def call(self, input_dict: SampleBatch) -> \
        (TensorType, List[TensorType], Dict[str, TensorType]):
    """Run the wrapped Keras model, add prev actions/rewards, call base model.

    Args:
        input_dict: Must hold a non-None `SampleBatch.SEQ_LENS` entry. Read
            for `SampleBatch.PREV_ACTIONS` / `SampleBatch.PREV_REWARDS`
            when `self.use_n_prev_actions` / `self.use_n_prev_rewards` are
            truthy, and for every entry whose key starts with "state_in_".

    Returns:
        `(model_out, memory_outs, extra)`, where `extra` maps
        `SampleBatch.VF_PREDS` to the base model's value output reshaped
        to `[-1]`.
    """
    assert input_dict[SampleBatch.SEQ_LENS] is not None

    # The underlying (wrapped) model sees the inputs first.
    wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

    # Optional extra inputs to concatenate onto the wrapped output.
    extras = []
    if self.use_n_prev_actions:
        if isinstance(self.action_space, Discrete):
            # One one-hot tensor per previous-action column.
            extras.extend(
                one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, step],
                        self.action_space)
                for step in range(self.use_n_prev_actions))
        elif isinstance(self.action_space, MultiDiscrete):
            # NOTE(review): the slice is cast to float32 before `one_hot`
            # and the loop bound is `use_n_prev_actions` (stepping by
            # `action_space.shape[0]`) -- confirm both against `one_hot`
            # and the column layout of PREV_ACTIONS.
            for start in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                stop = start + self.action_space.shape[0]
                action_slice = tf.cast(
                    input_dict[SampleBatch.PREV_ACTIONS][:, start:stop],
                    tf.float32)
                extras.append(one_hot(action_slice, self.action_space))
        else:
            flat_actions = tf.cast(
                input_dict[SampleBatch.PREV_ACTIONS], tf.float32)
            extras.append(
                tf.reshape(
                    flat_actions,
                    [-1, self.use_n_prev_actions * self.action_dim]))
    if self.use_n_prev_rewards:
        flat_rewards = tf.cast(
            input_dict[SampleBatch.PREV_REWARDS], tf.float32)
        extras.append(
            tf.reshape(flat_rewards, [-1, self.use_n_prev_rewards]))

    if extras:
        wrapped_out = tf.concat([wrapped_out, *extras], axis=1)

    # Collect all memory inputs ("state_in_*") in dict iteration order.
    memory_ins = []
    for key, value in input_dict.items():
        if key.startswith("state_in_"):
            memory_ins.append(value)

    model_out, memory_outs, value_outs = self.base_model(
        [wrapped_out] + memory_ins)
    extra_outs = {SampleBatch.VF_PREDS: tf.reshape(value_outs, [-1])}
    return model_out, memory_outs, extra_outs
def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    """Run the wrapped net, add prev actions/rewards, then run `self.gtrxl`.

    Args:
        input_dict: Passed to `self._wrapped_forward`. Read for
            `SampleBatch.PREV_ACTIONS` / `SampleBatch.PREV_REWARDS` when
            `self.use_n_prev_actions` / `self.use_n_prev_rewards` are
            truthy. Its "obs_flat" and "obs" entries are both overwritten
            with the (possibly extended) wrapped output.
        state: Forwarded unchanged to `self.gtrxl`.
        seq_lens: Must not be None; forwarded to `self.gtrxl`.

    Returns:
        `(model_out, memory_outs)`: the output of `self._logits_branch`
        applied to the GTrXL features (also stored in `self._features`),
        and the memory outputs returned by `self.gtrxl`.
    """
    assert seq_lens is not None

    # The wrapped ("unwrapped" net's) forward pass comes first.
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

    # Optional extra inputs to concatenate onto the wrapped output.
    extras = []
    if self.use_n_prev_actions:
        if isinstance(self.action_space, Discrete):
            # One one-hot tensor per previous-action column.
            extras.extend(
                one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, step],
                        self.action_space)
                for step in range(self.use_n_prev_actions))
        elif isinstance(self.action_space, MultiDiscrete):
            # NOTE(review): the slice is cast to float32 before `one_hot`
            # and the loop bound is `use_n_prev_actions` (stepping by
            # `action_space.shape[0]`) -- confirm both against `one_hot`
            # and the column layout of PREV_ACTIONS.
            for start in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                stop = start + self.action_space.shape[0]
                action_slice = tf.cast(
                    input_dict[SampleBatch.PREV_ACTIONS][:, start:stop],
                    tf.float32)
                extras.append(one_hot(action_slice, self.action_space))
        else:
            flat_actions = tf.cast(
                input_dict[SampleBatch.PREV_ACTIONS], tf.float32)
            extras.append(
                tf.reshape(
                    flat_actions,
                    [-1, self.use_n_prev_actions * self.action_dim]))
    if self.use_n_prev_rewards:
        flat_rewards = tf.cast(
            input_dict[SampleBatch.PREV_REWARDS], tf.float32)
        extras.append(
            tf.reshape(flat_rewards, [-1, self.use_n_prev_rewards]))

    if extras:
        wrapped_out = tf.concat([wrapped_out, *extras], axis=1)

    # Feed the result to the GTrXL under both observation keys.
    input_dict["obs_flat"] = wrapped_out
    input_dict["obs"] = wrapped_out
    self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)

    return self._logits_branch(self._features), memory_outs
def call(self, input_dict: SampleBatch) -> \
        (TensorType, List[TensorType], Dict[str, TensorType]):
    """Run the wrapped Keras model, add prev action/reward, run the RNN.

    Args:
        input_dict: Must hold a non-None "seq_lens" entry plus
            "state_in_0" and "state_in_1". Read for
            `SampleBatch.PREV_ACTIONS` / `SampleBatch.PREV_REWARDS` when
            `self.lstm_use_prev_action` / `self.lstm_use_prev_reward` are
            truthy.

    Returns:
        `(model_out, [h, c], extra)`: the RNN output with its first two
        dimensions merged into one, the two state outputs of
        `self._rnn_model`, and a dict mapping `SampleBatch.VF_PREDS` to
        the value output reshaped to `[-1]`.
    """
    assert input_dict.get("seq_lens") is not None

    # The underlying (wrapped) model sees the inputs first.
    wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

    # Optional extra inputs to concatenate onto the wrapped output.
    extras = []
    if self.lstm_use_prev_action:
        last_action = input_dict[SampleBatch.PREV_ACTIONS]
        if isinstance(self.action_space, (Discrete, MultiDiscrete)):
            last_action = one_hot(last_action, self.action_space)
        last_action = tf.cast(last_action, tf.float32)
        extras.append(tf.reshape(last_action, [-1, self.action_dim]))
    if self.lstm_use_prev_reward:
        last_reward = tf.cast(
            input_dict[SampleBatch.PREV_REWARDS], tf.float32)
        extras.append(tf.reshape(last_reward, [-1, 1]))

    if extras:
        wrapped_out = tf.concat([wrapped_out, *extras], axis=1)

    # Rows of the flat batch divided by the number of sequences.
    n_rows = tf.shape(wrapped_out)[0]
    n_seqs = tf.shape(input_dict["seq_lens"])[0]
    max_seq_len = n_rows // n_seqs
    timed_inputs = add_time_dimension(
        wrapped_out, max_seq_len=max_seq_len, framework="tf")

    rnn_inputs = [
        timed_inputs,
        input_dict["seq_lens"],
        input_dict["state_in_0"],
        input_dict["state_in_1"],
    ]
    model_out, value_out, h, c = self._rnn_model(rnn_inputs)

    # Merge the first two dimensions again, keeping all trailing ones.
    flat_shape = tf.concat([[-1], tf.shape(model_out)[2:]], axis=0)
    flat_out = tf.reshape(model_out, flat_shape)

    extra_outs = {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])}
    return flat_out, [h, c], extra_outs
def forward(self, input_dict, states, seq_lens):
    """Feed the "prev_n_*" inputs through `self.base_model`.

    Args:
        input_dict: Must contain "prev_n_obs", "prev_n_rewards" and
            "prev_n_actions". Observations and rewards are cast to
            float32; actions are one-hot encoded with `self.action_space`.
        states: Not used by this method.
        seq_lens: Not used by this method.

    Returns:
        `(out, [])`, where `out` is the first output of `self.base_model`.
        The second output is stored in `self._last_value`.
    """
    stacked_obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
    stacked_rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
    stacked_actions = one_hot(input_dict["prev_n_actions"], self.action_space)

    # The base model is called with inputs ordered [obs, actions, rewards].
    model_inputs = [stacked_obs, stacked_actions, stacked_rewards]
    logits, value = self.base_model(model_inputs)
    self._last_value = value
    return logits, []