def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
        else:
            # TODO manage blindness for all agents simultaneously or separate?
            raise NotImplementedError()

        # TODO alternative where all agents consume all observations
        x, rnn_hidden_states = self.state_encoder(perception_embed,
                                                  memory.tensor("rnn"), masks)

        dists, vals = self.actor_critic(x)

        return (
            ActorCriticOutput(
                distributions=dists,
                values=vals,
                extras={},
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
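
A rough sketch of how a forward pass like the one above could be driven with dummy inputs. The `SimpleMemory` stand-in only mimics the `tensor()`/`set_tensor()` accessors used in the example, and every tensor shape below is an illustrative assumption, not the library's actual `Memory` API or observation layout.

import torch


class SimpleMemory:
    """Minimal stand-in for the Memory object used above (illustration only)."""

    def __init__(self, tensors):
        self._tensors = dict(tensors)

    def tensor(self, key):
        return self._tensors[key]

    def set_tensor(self, key, value):
        self._tensors[key] = value
        return self


# Made-up shapes: one rollout step, four parallel samplers, 512-d hidden state.
num_steps, num_samplers, hidden_size = 1, 4, 512
observations = {"rgb": torch.rand(num_steps, num_samplers, 3, 224, 224)}
memory = SimpleMemory({"rnn": torch.zeros(1, num_samplers, hidden_size)})
prev_actions = torch.zeros(num_steps, num_samplers, dtype=torch.long)
masks = torch.ones(num_steps, num_samplers, 1)

# actor_critic_output, memory = model.forward(observations, memory, prev_actions, masks)
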
Example #2
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Processes input batched observations to produce new actor and critic
        values. Processes input batched observations (along with prior hidden
        states, previous actions, and masks denoting which recurrent hidden
        states should be masked) and returns an `ActorCriticOutput` object
        containing the model's policy (distribution over actions) and
        evaluation of the current state (value).

        # Parameters
        observations : Batched input observations.
        memory : `Memory` containing the hidden states from initial timepoints.
        prev_actions : Tensor of previous actions taken.
        masks : Masks applied to hidden states. See `RNNStateEncoder`.
        # Returns
        Tuple of the `ActorCriticOutput` and recurrent hidden state.
        """

        arm2obj_dist = self.get_relative_distance_embedding(
            observations["relative_agent_arm_to_obj"])
        obj2goal_dist = self.get_relative_distance_embedding(
            observations["relative_obj_to_goal"])

        perception_embed = self.visual_encoder(observations)

        pickup_bool = observations["pickedup_object"]
        before_pickup = pickup_bool == 0  # not used because of our initialization
        after_pickup = pickup_bool == 1
        # Use the arm-to-object distance before pickup, then switch to the
        # object-to-goal distance once the object has been picked up.
        distances = arm2obj_dist
        distances[after_pickup] = obj2goal_dist[after_pickup]

        x = [distances, perception_embed]

        x_cat = torch.cat(x, dim=-1)
        x_out, rnn_hidden_states = self.state_encoder(x_cat,
                                                      memory.tensor("rnn"),
                                                      masks)

        actor_out = self.actor(x_out)
        critic_out = self.critic(x_out)
        actor_critic_output = ActorCriticOutput(distributions=actor_out,
                                                values=critic_out,
                                                extras={})

        updated_memory = memory.set_tensor("rnn", rnn_hidden_states)

        return (
            actor_critic_output,
            updated_memory,
        )
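
The distance-switching logic in the example above selects between two embeddings per sample using a boolean mask. A self-contained sketch of that pattern, with made-up shapes, follows; the names mirror the example but are only for illustration.

import torch

# Two candidate embeddings per sample, e.g. arm-to-object vs. object-to-goal.
arm2obj_dist = torch.rand(8, 32)
obj2goal_dist = torch.rand(8, 32)

# 1 where the object has already been picked up, 0 otherwise.
pickup_bool = torch.randint(0, 2, (8,))
after_pickup = pickup_bool == 1

# Start from the pre-pickup embedding, then overwrite the rows where the
# object is already held. Note this mutates `distances` in place, and
# `distances` aliases `arm2obj_dist`; clone first if the original is needed.
distances = arm2obj_dist
distances[after_pickup] = obj2goal_dist[after_pickup]
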
Example #3
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        out, recurrent_hidden_states = self.baby_ai_model.forward(
            observations=observations,
            recurrent_hidden_states=cast(
                torch.FloatTensor, memory.tensor(self.memory_key)
            ),
            prev_actions=prev_actions,
            masks=masks,
        )
        return out, memory.set_tensor(self.memory_key, recurrent_hidden_states)

    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        x = self.goal_visual_encoder(observations)
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"), masks)
        return (
            ActorCriticOutput(
                distributions=self.actor(x),
                values=self.critic(x),
                extras={"auxiliary_distributions": self.auxiliary_actor(x)}
                if self.include_auxiliary_head
                else {},
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Processes input batched observations to produce new actor and critic
        values. Processes input batched observations (along with prior hidden
        states, previous actions, and masks denoting which recurrent hidden
        states should be masked) and returns an `ActorCriticOutput` object
        containing the model's policy (distribution over actions) and
        evaluation of the current state (value).

        # Parameters
        observations : Batched input observations.
        memory : `Memory` containing the hidden states from initial timepoints.
        prev_actions : Tensor of previous actions taken.
        masks : Masks applied to hidden states. See `RNNStateEncoder`.
        # Returns
        Tuple of the `ActorCriticOutput` and recurrent hidden state.
        """
        target_encoding = self.get_object_type_encoding(
            cast(Dict[str, torch.FloatTensor], observations)
        )
        x = [target_encoding]

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            x = [perception_embed] + x

        x_cat = torch.cat(x, dim=-1)  # type: ignore
        x_out, rnn_hidden_states = self.state_encoder(
            x_cat, memory.tensor("rnn"), masks
        )

        return (
            ActorCriticOutput(
                distributions=self.actor(x_out), values=self.critic(x_out), extras={}
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
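
The `masks` argument referenced in the docstrings above is what lets the recurrent state encoder reset hidden states at episode boundaries. As a rough illustration of that convention (a sketch of the usual masked GRU step, not AllenAct's actual `RNNStateEncoder` implementation; all sizes are made up):

import torch
import torch.nn as nn

gru = nn.GRU(input_size=1574, hidden_size=512)  # sizes are illustrative only

def masked_step(x, hidden, masks):
    # masks contains 0 wherever a new episode starts, so multiplying the
    # previous hidden state by it zeroes the memory for those samplers.
    hidden = hidden * masks.view(1, -1, 1)
    out, hidden = gru(x.unsqueeze(0), hidden)
    return out.squeeze(0), hidden

x = torch.rand(4, 1574)          # [num_samplers, input_size]
hidden = torch.zeros(1, 4, 512)  # [num_layers, num_samplers, hidden_size]
masks = torch.ones(4, 1)         # set to 0 at episode starts
out, hidden = masked_step(x, hidden, masks)
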
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        x = self.goal_visual_encoder(observations)

        x, rnn_hidden_states = self.state_encoder(x,
                                                  memory.tensor("rnn_hidden"),
                                                  masks)

        dists, vals = self.actor_critic(x)

        return (
            ActorCriticOutput(
                distributions=dists,
                values=vals,
                extras={},
            ),
            memory.set_tensor("rnn_hidden", rnn_hidden_states),
        )