def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    if not self.is_blind:
        perception_embed = self.visual_encoder(observations)
    else:
        # TODO manage blindness for all agents simultaneously or separate?
        raise NotImplementedError()

    # TODO alternative where all agents consume all observations
    x, rnn_hidden_states = self.state_encoder(
        perception_embed, memory.tensor("rnn"), masks
    )

    dists, vals = self.actor_critic(x)

    return (
        ActorCriticOutput(distributions=dists, values=vals, extras={}),
        memory.set_tensor("rnn", rnn_hidden_states),
    )
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    out, recurrent_hidden_states = self.baby_ai_model.forward(
        observations=observations,
        recurrent_hidden_states=cast(
            torch.FloatTensor, memory.tensor(self.memory_key)
        ),
        prev_actions=prev_actions,
        masks=masks,
    )
    return out, memory.set_tensor(self.memory_key, recurrent_hidden_states)
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    x = self.goal_visual_encoder(observations)
    x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"), masks)
    return (
        ActorCriticOutput(
            distributions=self.actor(x), values=self.critic(x), extras={}
        ),
        memory.set_tensor("rnn", rnn_hidden_states),
    )
def forward( # type:ignore self, observations: ObservationType, memory: Memory, prev_actions: torch.Tensor, masks: torch.FloatTensor, ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]: """Processes input batched observations to produce new actor and critic values. Processes input batched observations (along with prior hidden states, previous actions, and masks denoting which recurrent hidden states should be masked) and returns an `ActorCriticOutput` object containing the model's policy (distribution over actions) and evaluation of the current state (value). # Parameters observations : Batched input observations. rnn_hidden_states : Hidden states from initial timepoints. prev_actions : Tensor of previous actions taken. masks : Masks applied to hidden states. See `RNNStateEncoder`. # Returns Tuple of the `ActorCriticOutput` and recurrent hidden state. """ target_encoding = self.get_object_type_encoding( cast(Dict[str, torch.FloatTensor], observations) ) x = [target_encoding] if not self.is_blind: perception_embed = self.visual_encoder(observations) x = [perception_embed] + x x_cat = torch.cat(x, dim=-1) # type: ignore x_out, rnn_hidden_states = self.state_encoder( x_cat, memory.tensor("rnn"), masks ) return ( ActorCriticOutput( distributions=self.actor(x_out), values=self.critic(x_out), extras={} ), memory.set_tensor("rnn", rnn_hidden_states), )
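# Usage sketch (illustration only, not part of the original models): every
# forward method above shares the same calling convention, so a single rollout
# step looks roughly like the snippet below. `model`, `obs_batch`,
# `rollout_memory`, `prev_actions`, and `not_done_masks` are hypothetical
# placeholders assumed to be built elsewhere by the training loop, and the
# `.sample()` call assumes the returned distribution exposes the usual
# torch.distributions interface.
#
#     ac_out, rollout_memory = model.forward(
#         observations=obs_batch,
#         memory=rollout_memory,      # must already hold an "rnn" tensor
#         prev_actions=prev_actions,
#         masks=not_done_masks,       # zeros where an episode just restarted
#     )
#     actions = ac_out.distributions.sample()  # policy head
#     values = ac_out.values                   # critic head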