Example #1
def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        target_encoding = self.get_target_coordinates_encoding(observations)
        x: Union[torch.Tensor, List[torch.Tensor]]
        x = [target_encoding]

        # if observations["rgb"].shape[0] != 1:
        #     print("rgb", (observations["rgb"][...,0,0,:].unsqueeze(-2).unsqueeze(-2) == observations["rgb"][...,0,0,:]).float().mean())
        #     if "depth" in observations:
        #         print("depth", (observations["depth"][...,0,0,:].unsqueeze(-2).unsqueeze(-2) == observations["depth"][...,0,0,:]).float().mean())

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            if self.sensor_fusion:
                perception_embed = self.sensor_fuser(perception_embed)
            x = [perception_embed] + x

        x = torch.cat(x, dim=-1)
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"),
                                                  masks)

        ac_output = ActorCriticOutput(distributions=self.actor(x),
                                      values=self.critic(x),
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
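# A minimal sketch (not from the original source) of the `masks` convention the
# recurrent models above rely on: masks is 0 at episode boundaries, so
# multiplying the hidden state by masks zeroes it for samplers that just reset.
import torch

hidden = torch.randn(1, 3, 8)                        # [layers, samplers, hidden]
masks = torch.tensor([0.0, 1.0, 1.0]).view(1, 3, 1)  # 0 => new episode for that sampler
hidden = hidden * masks                              # sampler 0's state is zeroed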
Example #2
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        out = self.linear(cast(torch.Tensor, observations[self.input_uuid]))

        main_logits = out[..., :self.num_actions]
        aux_logits = out[..., self.num_actions:-1]
        values = out[..., -1:]

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(
                distributions=cast(
                    DistributionType, CategoricalDistr(
                        logits=main_logits)),  # step x sampler x ...
                values=cast(torch.FloatTensor,
                            values.view(values.shape[:2] +
                                        (-1, ))),  # step x sampler x flattened
                extras={
                    "auxiliary_distributions":
                    CategoricalDistr(logits=aux_logits)
                },
            ),
            None,
        )
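# A minimal sketch (toy sizes, not from the original source) of the
# single-linear-head layout above: one Linear produces
# num_actions + num_aux_actions + 1 outputs, which are sliced into the main
# logits, the auxiliary logits, and the value.
import torch
import torch.nn as nn

num_actions, num_aux_actions = 4, 3
linear = nn.Linear(16, num_actions + num_aux_actions + 1)

obs = torch.randn(5, 2, 16)               # [steps, samplers, data]
out = linear(obs)
main_logits = out[..., :num_actions]      # [5, 2, 4]
aux_logits = out[..., num_actions:-1]     # [5, 2, 3]
values = out[..., -1:]                    # [5, 2, 1]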
Example #3
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
        else:
            # TODO manage blindness for all agents simultaneously or separate?
            raise NotImplementedError()

        # TODO alternative where all agents consume all observations
        x, rnn_hidden_states = self.state_encoder(perception_embed,
                                                  memory.tensor("rnn"), masks)

        dists, vals = self.actor_critic(x)

        return (
            ActorCriticOutput(
                distributions=dists,
                values=vals,
                extras={},
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
Example #4
    def forward(self, observations, memory, prev_actions, masks):
        dists, values = self.head(observations[self.input_uuid])

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(distributions=dists, values=values, extras={},),
            None,
        )
Example #5
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Processes input batched observations to produce new actor and critic
        values. Processes input batched observations (along with prior hidden
        states, previous actions, and masks denoting which recurrent hidden
        states should be masked) and returns an `ActorCriticOutput` object
        containing the model's policy (distribution over actions) and
        evaluation of the current state (value).

        # Parameters
        observations : Batched input observations.
        memory : `Memory` containing the hidden states from initial timepoints.
        prev_actions : Tensor of previous actions taken.
        masks : Masks applied to hidden states. See `RNNStateEncoder`.

        # Returns
        Tuple of the `ActorCriticOutput` and recurrent hidden state.
        """

        arm2obj_dist = self.get_relative_distance_embedding(
            observations["relative_agent_arm_to_obj"])
        obj2goal_dist = self.get_relative_distance_embedding(
            observations["relative_obj_to_goal"])

        perception_embed = self.visual_encoder(observations)

        pickup_bool = observations["pickedup_object"]
        after_pickup = pickup_bool == 1  # before-pickup is the complement, implicit given our initialization
        # Use the arm-to-object embedding until pickup, then switch to the
        # object-to-goal embedding. Clone so the masked write below does not
        # mutate arm2obj_dist in place.
        distances = arm2obj_dist.clone()
        distances[after_pickup] = obj2goal_dist[after_pickup]

        x = [distances, perception_embed]

        x_cat = torch.cat(x, dim=-1)
        x_out, rnn_hidden_states = self.state_encoder(x_cat,
                                                      memory.tensor("rnn"),
                                                      masks)

        actor_out = self.actor(x_out)
        critic_out = self.critic(x_out)
        actor_critic_output = ActorCriticOutput(distributions=actor_out,
                                                values=critic_out,
                                                extras={})

        updated_memory = memory.set_tensor("rnn", rnn_hidden_states)

        return (
            actor_critic_output,
            updated_memory,
        )
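# A minimal sketch (toy shapes, not from the original source) of the masked
# switch above: the arm-to-object embedding is used until the object is picked
# up, after which the object-to-goal embedding takes over, per step and sampler.
import torch

arm2obj = torch.randn(4, 2, 8)            # [steps, samplers, embedding]
obj2goal = torch.randn(4, 2, 8)
picked_up = torch.tensor([[0, 1], [0, 1], [1, 1], [1, 1]]).bool()

distances = arm2obj.clone()
distances[picked_up] = obj2goal[picked_up]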
Example #6
    def forward(self, observations, memory, prev_actions, masks):
        out = self.linear(observations[self.input_uuid])

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(
                # ensure [steps, samplers, ...]
                distributions=CategoricalDistr(logits=out[..., :-1]),
                # ensure [steps, samplers, flattened]
                values=cast(torch.FloatTensor,
                            out[..., -1:].view(*out.shape[:2], -1)),
                extras={},
            ),
            None,
        )
Example #7
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        x = self.goal_visual_encoder(observations)
        x, rnn_hidden_states = self.state_encoder(
            x, memory.tensor(self.memory_key), masks)
        return (
            ActorCriticOutput(distributions=self.actor(x),
                              values=self.critic(x),
                              extras={}),
            memory.set_tensor(self.memory_key, rnn_hidden_states),
        )
Example #8
    def forward(  # type:ignore
        self,
        observations: Dict[str, Union[torch.FloatTensor, Dict[str, Any]]],
        memory: Memory,
        prev_actions: Any,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        means = self.actor(observations[self.input_uuid])
        values = self.critic(observations[self.input_uuid])

        return (
            ActorCriticOutput(
                cast(DistributionType, GaussianDistr(loc=means, scale=self.action_std)),
                values,
                {},
            ),
            None,  # no Memory
        )
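# A minimal sketch of the continuous-control head above, substituting
# torch.distributions.Normal for allenact's GaussianDistr (assumed to behave
# analogously): the actor outputs means and the standard deviation is fixed.
import torch

means = torch.zeros(4, 2, 3)              # [steps, samplers, action_dims]
action_std = torch.full((3,), 0.5)
dist = torch.distributions.Normal(loc=means, scale=action_std)
actions = dist.sample()                   # [4, 2, 3]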
Example #9
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        x = self.goal_visual_encoder(observations)
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"),
                                                  masks)
        return (
            ActorCriticOutput(
                distributions=self.actor(x),
                values=self.critic(x),
                extras={"auxiliary_distributions": self.auxiliary_actor(x)}
                if self.include_auxiliary_head else {},
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
Example #10
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Processes input batched observations to produce new actor and critic
        values. Processes input batched observations (along with prior hidden
        states, previous actions, and masks denoting which recurrent hidden
        states should be masked) and returns an `ActorCriticOutput` object
        containing the model's policy (distribution over actions) and
        evaluation of the current state (value).

        # Parameters
        observations : Batched input observations.
        memory : `Memory` containing the hidden states from initial timepoints.
        prev_actions : Tensor of previous actions taken.
        masks : Masks applied to hidden states. See `RNNStateEncoder`.

        # Returns
        Tuple of the `ActorCriticOutput` and recurrent hidden state.
        """
        target_encoding = self.get_object_type_encoding(
            cast(Dict[str, torch.FloatTensor], observations)
        )
        x = [target_encoding]

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            x = [perception_embed] + x

        x_cat = torch.cat(x, dim=-1)  # type: ignore
        x_out, rnn_hidden_states = self.state_encoder(
            x_cat, memory.tensor("rnn"), masks
        )

        return (
            ActorCriticOutput(
                distributions=self.actor(x_out), values=self.critic(x_out), extras={}
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
Example #11
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        cur_img = observations[self.rgb_uuid]
        unshuffled_img = observations[self.unshuffled_rgb_uuid]
        concat_img = torch.cat((cur_img, unshuffled_img), dim=-1)

        x = self.visual_encoder({self.concat_rgb_uuid: concat_img})
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"),
                                                  masks)

        ac_output = ActorCriticOutput(distributions=self.actor(x),
                                      values=self.critic(x),
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
Example #12
    def adapt_result(ac_output, hidden_states, num_steps, num_samplers,
                     num_agents, num_layers, observations):  # type: ignore
        distributions = CategoricalDistr(
            logits=ac_output.distributions.logits.view(num_steps, num_samplers,
                                                       -1), )
        values = ac_output.values.view(num_steps, num_samplers, num_agents)
        extras = ac_output.extras  # ignore shape
        # TODO confirm the shape of the auxiliary distribution is the same as the actor's
        if "auxiliary_distributions" in extras:
            extras["auxiliary_distributions"] = CategoricalDistr(
                logits=extras["auxiliary_distributions"].logits.view(
                    num_steps,
                    num_samplers,
                    -1  # assume single-agent
                ), )

        hidden_states = hidden_states.view(num_layers,
                                           num_samplers * num_agents, -1)

        # Unflatten all observation batch dims
        def recursively_adapt_observations(obs):
            for entry in obs:
                if isinstance(obs[entry], Dict):
                    recursively_adapt_observations(obs[entry])
                else:
                    assert isinstance(obs[entry], torch.Tensor)
                    if entry in ["minigrid_ego_image", "minigrid_mission"]:
                        final_dims = obs[entry].shape[
                            1:]  # assumes no agents dim in observations!
                        obs[entry] = obs[entry].view(num_steps,
                                                     num_samplers * num_agents,
                                                     *final_dims)

        recursively_adapt_observations(observations)

        return (
            ActorCriticOutput(distributions=distributions,
                              values=values,
                              extras=extras),
            hidden_states,
        )
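# A minimal sketch (toy sizes, not from the original source) of the reshaping
# adapt_result performs: flat rollout outputs are unflattened to
# [steps, samplers, ...] and hidden states to [layers, samplers * agents, -1].
import torch

num_steps, num_samplers, num_agents, num_layers = 4, 3, 1, 1

flat_logits = torch.randn(num_steps * num_samplers, 5)
logits = flat_logits.view(num_steps, num_samplers, -1)  # [4, 3, 5]

flat_hidden = torch.randn(num_layers, num_samplers * num_agents, 8)
hidden_states = flat_hidden.view(num_layers, num_samplers * num_agents, -1)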
Example #13
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        out = self.linear(cast(torch.Tensor, observations[self.key]))

        assert len(out.shape) in [
            3,
            4,
        ], "observations must be [step, sampler, data] or [step, sampler, agent, data]"

        if len(out.shape) == 3:
            # [step, sampler, data] -> [step, sampler, agent, data]
            out = out.unsqueeze(-2)

        main_logits = out[..., :self.num_actions]
        aux_logits = out[..., self.num_actions:-1]
        values = out[..., -1:]

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(
                distributions=cast(
                    DistributionType, CategoricalDistr(
                        logits=main_logits)),  # step x sampler x ...
                values=cast(torch.FloatTensor,
                            values.view(values.shape[:2] +
                                        (-1, ))),  # step x sampler x flattened
                extras={
                    "auxiliary_distributions":
                    CategoricalDistr(logits=aux_logits),
                },
            ),
            None,
        )
Example #14
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        cur_img_resnet = observations[self.rgb_uuid]
        unshuffled_img_resnet = observations[self.unshuffled_rgb_uuid]
        concat_img = torch.cat(
            (
                cur_img_resnet,
                unshuffled_img_resnet,
                cur_img_resnet * unshuffled_img_resnet,
            ),
            dim=-3,
        )
        batch_shape, features_shape = concat_img.shape[:-3], concat_img.shape[
            -3:]
        concat_img_reshaped = concat_img.view(-1, *features_shape)
        attention_probs = torch.softmax(
            self.visual_attention(concat_img_reshaped).view(
                concat_img_reshaped.shape[0], -1),
            dim=-1,
        ).view(concat_img_reshaped.shape[0], 1,
               *concat_img_reshaped.shape[-2:])
        x = ((self.visual_encoder(concat_img_reshaped) *
              attention_probs).mean(-1).mean(-1))
        x = x.view(*batch_shape, -1)

        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"),
                                                  masks)

        ac_output = ActorCriticOutput(distributions=self.actor(x),
                                      values=self.critic(x),
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
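# A minimal sketch (toy sizes, not from the original source) of the spatial
# attention pooling above: a 1-channel map is softmax-normalized over all H*W
# positions and used to weight the encoder features before averaging them
# spatially. For brevity the attention here is computed from the features
# themselves rather than from a separate image tensor.
import torch
import torch.nn as nn

b, c, h, w = 2, 16, 7, 7
visual_attention = nn.Conv2d(c, 1, kernel_size=1)
features = torch.randn(b, c, h, w)

attn = torch.softmax(visual_attention(features).view(b, -1), dim=-1)
attn = attn.view(b, 1, h, w)
pooled = (features * attn).mean(-1).mean(-1)  # [b, c]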
Example #15
    def forward(
        self,
        observations: ObservationType,
        recurrent_hidden_states: torch.FloatTensor,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ):
        (
            observations,
            recurrent_hidden_states,
            prev_actions,
            masks,
            num_steps,
            num_samplers,
            num_agents,
            num_layers,
        ) = self.adapt_inputs(observations, recurrent_hidden_states,
                              prev_actions, masks)

        if self.lang_model != "gru":
            ac_output, hidden_states = self.forward_loop(
                observations=observations,
                recurrent_hidden_states=recurrent_hidden_states,
                prev_actions=prev_actions,
                masks=masks,  # type: ignore
            )

            return self.adapt_result(
                ac_output,
                hidden_states[-1:],
                num_steps,
                num_samplers,
                num_agents,
                num_layers,
                observations,
            )

        assert recurrent_hidden_states.shape[0] == 1

        images = cast(torch.FloatTensor, observations["minigrid_ego_image"])
        if self.use_cnn2:
            images_shape = images.shape
            # noinspection PyArgumentList
            images = images + torch.LongTensor(
                [0, 11, 22]).view(  # type:ignore
                    1, 1, 1, 3).to(images.device)
            images = self.semantic_embedding(images).view(  # type:ignore
                *images_shape[:3], 24)
        images = images.permute(0, 3, 1, 2).float()  # type:ignore

        _, nsamplers, _ = recurrent_hidden_states.shape
        rollouts_len = images.shape[0] // nsamplers

        masks = cast(torch.FloatTensor,
                     masks.view(rollouts_len, nsamplers, *masks.shape[1:]))
        instrs: Optional[torch.Tensor] = None
        if "minigrid_mission" in observations and self.use_instr:
            instrs = cast(torch.FloatTensor, observations["minigrid_mission"])
            instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

        needs_instr_reset_mask = masks != 1.0
        needs_instr_reset_mask[0] = True  # always recompute embeddings at the first step
        needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)
        blocking_inds: List[int] = np.where(
            needs_instr_reset_mask.view(rollouts_len,
                                        -1).any(-1).cpu().numpy())[0].tolist()
        blocking_inds.append(rollouts_len)

        instr_embeddings: Optional[torch.Tensor] = None
        if self.use_instr:
            instr_reset_multi_inds = list((int(a), int(b)) for a, b in zip(
                *np.where(needs_instr_reset_mask.cpu().numpy())))
            time_ind_to_which_need_instr_reset: List[List] = [
                [] for _ in range(rollouts_len)
            ]
            reset_multi_ind_to_index = {
                mi: i
                for i, mi in enumerate(instr_reset_multi_inds)
            }
            for a, b in instr_reset_multi_inds:
                time_ind_to_which_need_instr_reset[a].append(b)

            unique_instr_embeddings = self._get_instr_embedding(
                instrs[needs_instr_reset_mask])

            instr_embeddings_list = [unique_instr_embeddings[:nsamplers]]
            current_instr_embeddings_list = list(instr_embeddings_list[-1])

            for time_ind in range(1, rollouts_len):
                if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                    instr_embeddings_list.append(instr_embeddings_list[-1])
                else:
                    for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                            time_ind]:
                        current_instr_embeddings_list[
                            sampler_needing_reset_ind] = unique_instr_embeddings[
                                reset_multi_ind_to_index[(
                                    time_ind, sampler_needing_reset_ind)]]

                    instr_embeddings_list.append(
                        torch.stack(current_instr_embeddings_list, dim=0))

            instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

        # The following code can be used to compute the instr_embeddings in another way
        # and thus verify that the above logic is (more likely to be) correct
        # needs_instr_reset_mask = (masks != 1.0)
        # needs_instr_reset_mask[0] *= 0
        # needs_instr_reset_inds = needs_instr_reset_mask.view(nrollouts, -1).any(-1).cpu().numpy()
        #
        # # Get inds where a new task has started
        # blocking_inds: List[int] = np.where(needs_instr_reset_inds)[0].tolist()
        # blocking_inds.append(needs_instr_reset_inds.shape[0])
        # if blocking_inds[0] != 0:
        #     blocking_inds.insert(0, 0)
        # if self.use_instr:
        #     instr_embeddings_list = []
        #     for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
        #         instr_embeddings_list.append(
        #             self._get_instr_embedding(instrs[ind0])
        #             .unsqueeze(0)
        #             .repeat(ind1 - ind0, 1, 1)
        #         )
        #     tmp_instr_embeddings = torch.cat(instr_embeddings_list, dim=0)
        # assert (instr_embeddings - tmp_instr_embeddings).abs().max().item() < 1e-6

        # Embed images
        # images = images.view(nrollouts, nsamplers, *images.shape[1:])
        image_embeddings = self.image_conv(images)
        if self.arch.startswith("expert_filmcnn"):
            instr_embeddings_flatter = instr_embeddings.view(
                -1, *instr_embeddings.shape[2:])
            for controller in self.controllers:
                image_embeddings = controller(image_embeddings,
                                              instr_embeddings_flatter)
            image_embeddings = F.relu(self.film_pool(image_embeddings))

        image_embeddings = image_embeddings.view(rollouts_len, nsamplers, -1)

        if self.use_instr and self.lang_model == "attgru":
            raise NotImplementedError("Currently attgru is not implemented.")

        memory = None
        if self.use_memory:
            assert recurrent_hidden_states.shape[0] == 1
            hidden = (
                recurrent_hidden_states[:, :, :self.semi_memory_size],
                recurrent_hidden_states[:, :, self.semi_memory_size:],
            )
            embeddings_list = []
            for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
                hidden = (hidden[0] * masks[ind0], hidden[1] * masks[ind0])
                rnn_out, hidden = self.memory_rnn(image_embeddings[ind0:ind1],
                                                  hidden)
                embeddings_list.append(rnn_out)

            # embedding = hidden[0]
            embedding = torch.cat(embeddings_list, dim=0)
            memory = torch.cat(hidden, dim=-1)
        else:
            embedding = image_embeddings

        if self.use_instr and "filmcnn" not in self.arch:
            embedding = torch.cat((embedding, instr_embeddings), dim=-1)

        if hasattr(self, "aux_info") and self.aux_info:
            extra_predictions = {
                info: self.extra_heads[info](embedding)
                for info in self.extra_heads
            }
        else:
            extra_predictions = dict()

        embedding = embedding.view(rollouts_len * nsamplers, -1)

        ac_output = ActorCriticOutput(
            distributions=CategoricalDistr(logits=self.actor(embedding), ),
            values=self.critic(embedding),
            extras=extra_predictions if not self.include_auxiliary_head else {
                **extra_predictions,
                "auxiliary_distributions":
                CategoricalDistr(logits=self.aux(embedding)),
            },
        )
        hidden_states = memory

        return self.adapt_result(
            ac_output,
            hidden_states,
            num_steps,
            num_samplers,
            num_agents,
            num_layers,
            observations,
        )
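# A minimal sketch (toy sizes, GRU in place of the original LSTM-style memory,
# not from the original source) of the segmented evaluation above: the rollout
# is split at reset timesteps (blocking_inds) and the RNN runs over each
# contiguous chunk, with the hidden state masked at every chunk boundary.
import torch

T, N, H = 6, 2, 8
rnn = torch.nn.GRU(H, H)
x = torch.randn(T, N, H)
masks = torch.ones(T, N, 1)
masks[0] = 0.0
masks[3, 1] = 0.0
blocking_inds = [0, 3, 6]                 # chunk starts plus the final step count

hidden = torch.zeros(1, N, H)
outs = []
for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
    out, hidden = rnn(x[ind0:ind1], hidden * masks[ind0])
    outs.append(out)
embedding = torch.cat(outs, dim=0)        # [T, N, H]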
Example #16
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Processes input batched observations to produce new actor and critic
        values. Processes input batched observations (along with prior hidden
        states, previous actions, and masks denoting which recurrent hidden
        states should be masked) and returns an `ActorCriticOutput` object
        containing the model's policy (distribution over actions) and
        evaluation of the current state (value).

        # Parameters
        observations : Batched input observations.
        memory : `Memory` containing the hidden states from initial timepoints.
        prev_actions : Tensor of previous actions taken.
        masks : Masks applied to hidden states. See `RNNStateEncoder`.

        # Returns
        Tuple of the `ActorCriticOutput` and recurrent hidden state.
        """

        # 1.1 use perception model (i.e. encoder) to get observation embeddings
        obs_embeds = self.forward_encoder(observations)
        # 1.2 use embedding model to get prev_action embeddings
        prev_actions_embeds = self.prev_action_embedder(prev_actions)
        joint_embeds = torch.cat((obs_embeds, prev_actions_embeds),
                                 dim=-1)  # (T, N, *)

        # 2. use RNNs to get single/multiple beliefs
        beliefs_dict = {}
        for key, model in self.state_encoders.items():
            beliefs_dict[key], rnn_hidden_states = model(
                joint_embeds, memory.tensor(key), masks)
            memory.set_tensor(key, rnn_hidden_states)  # update memory here

        # 3. fuse beliefs for multiple belief models
        beliefs, task_weights = self.fuse_beliefs(beliefs_dict,
                                                  obs_embeds)  # fused beliefs

        # 4. prepare output
        extras = ({
            aux_uuid: {
                "beliefs":
                (beliefs_dict[aux_uuid] if self.multiple_beliefs else beliefs),
                "obs_embeds":
                obs_embeds,
                "aux_model": (self.aux_models[aux_uuid]
                              if aux_uuid in self.aux_models else None),
            }
            for aux_uuid in self.auxiliary_uuids
        } if self.auxiliary_uuids is not None else {})

        if self.multiple_beliefs:
            extras[MultiAuxTaskNegEntropyLoss.UUID] = task_weights

        actor_critic_output = ActorCriticOutput(
            distributions=self.actor(beliefs),
            values=self.critic(beliefs),
            extras=extras,
        )

        return actor_critic_output, memory
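# A minimal sketch of fusing multiple task beliefs; the weighting scheme is an
# assumption (the actual fuse_beliefs implementation may differ): per-belief
# weights are softmax-normalized and used to average the stacked beliefs.
import torch

beliefs_dict = {"td": torch.randn(4, 2, 8), "inv": torch.randn(4, 2, 8)}
stacked = torch.stack(list(beliefs_dict.values()), dim=-1)  # [T, N, H, K]
weights = torch.softmax(torch.randn(4, 2, stacked.shape[-1]), dim=-1)  # [T, N, K]
fused = (stacked * weights.unsqueeze(-2)).sum(-1)           # [T, N, H]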
Example #17
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        in_walkthrough_phase_mask = observations[
            self.in_walkthrough_phase_uuid]
        in_unshuffle_phase_mask = ~in_walkthrough_phase_mask
        in_walkthrough_float = in_walkthrough_phase_mask.float()
        in_unshuffle_float = in_unshuffle_phase_mask.float()

        # Don't reset hidden state at start of the unshuffle task
        masks_no_unshuffle_reset = (masks.bool()
                                    | in_unshuffle_phase_mask).float()
        masks_with_unshuffle_reset = masks.float()
        del masks  # ensure we don't accidentally use `masks` when we want `masks_no_unshuffle_reset`

        # Visual features
        cur_img_resnet = observations[self.rgb_uuid]
        unshuffled_img_resnet = observations[self.unshuffled_rgb_uuid]
        concat_img = torch.cat(
            (
                cur_img_resnet,
                unshuffled_img_resnet,
                cur_img_resnet * unshuffled_img_resnet,
            ),
            dim=-3,
        )
        batch_shape, features_shape = concat_img.shape[:-3], concat_img.shape[
            -3:]
        concat_img_reshaped = concat_img.view(-1, *features_shape)
        attention_probs = torch.softmax(
            self.visual_attention(concat_img_reshaped).view(
                concat_img_reshaped.shape[0], -1),
            dim=-1,
        ).view(concat_img_reshaped.shape[0], 1,
               *concat_img_reshaped.shape[-2:])
        vis_features = ((self.visual_encoder(concat_img_reshaped) *
                         attention_probs).mean(-1).mean(-1))
        vis_features = vis_features.view(*batch_shape, -1)

        # Various embeddings
        prev_action_embeddings = self.prev_action_embedder(
            ((~masks_with_unshuffle_reset.bool()).long() *
             (prev_actions.unsqueeze(-1) + 1))).squeeze(-2)
        is_walkthrough_phase_embedding = self.is_walkthrough_phase_embedder(
            in_walkthrough_phase_mask.long()).squeeze(-2)

        to_cat = [
            vis_features,
            prev_action_embeddings,
            is_walkthrough_phase_embedding,
        ]

        rnn_hidden_states = memory.tensor("rnn")
        rnn_outs = []
        obs_for_rnn = torch.cat(to_cat, dim=-1)
        last_walkthrough_encoding = memory.tensor("walkthrough_encoding")

        for step in range(masks_with_unshuffle_reset.shape[0]):
            rnn_out, rnn_hidden_states = self.state_encoder(
                torch.cat(
                    (
                        obs_for_rnn[step:step + 1],
                        last_walkthrough_encoding *
                        masks_no_unshuffle_reset[step:step + 1],
                    ),
                    dim=-1,
                ),
                rnn_hidden_states,
                masks_with_unshuffle_reset[step:step + 1],
            )
            rnn_outs.append(rnn_out)

            walkthrough_encoding, _ = self.walkthrough_encoder(
                rnn_out,
                last_walkthrough_encoding,
                masks_no_unshuffle_reset[step:step + 1],
            )
            last_walkthrough_encoding = (
                last_walkthrough_encoding * in_unshuffle_float[step:step + 1] +
                walkthrough_encoding * in_walkthrough_float[step:step + 1])

        memory = memory.set_tensor("walkthrough_encoding",
                                   last_walkthrough_encoding)

        rnn_out = torch.cat(rnn_outs, dim=0)
        walkthrough_dist, walkthrough_vals = self.walkthrough_ac(rnn_out)
        unshuffle_dist, unshuffle_vals = self.unshuffle_ac(rnn_out)

        assert len(in_walkthrough_float.shape) == len(
            walkthrough_dist.logits.shape)

        if self.walkthrough_good_action_logits is not None:
            walkthrough_logits = (
                walkthrough_dist.logits +
                self.walkthrough_good_action_logits.view(
                    *((1, ) * (len(walkthrough_dist.logits.shape) - 1)), -1))
        else:
            walkthrough_logits = walkthrough_dist.logits

        actor = CategoricalDistr(
            logits=in_walkthrough_float * walkthrough_logits +
            in_unshuffle_float * unshuffle_dist.logits)
        values = (in_walkthrough_float * walkthrough_vals +
                  in_unshuffle_float * unshuffle_vals)

        ac_output = ActorCriticOutput(distributions=actor,
                                      values=values,
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
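# A minimal sketch (toy shapes, not from the original source) of the phase
# gating above: binary walkthrough/unshuffle indicators select between the two
# heads' logits (and, identically, their values) at every step.
import torch

in_walkthrough = torch.tensor([[1.0], [0.0]])  # [steps, 1] phase indicator
walkthrough_logits = torch.randn(2, 5)
unshuffle_logits = torch.randn(2, 5)

mixed_logits = (in_walkthrough * walkthrough_logits +
                (1 - in_walkthrough) * unshuffle_logits)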
Example #18
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        in_walkthrough_phase_mask = observations[
            self.in_walkthrough_phase_uuid]
        in_unshuffle_phase_mask = ~in_walkthrough_phase_mask
        in_walkthrough_float = in_walkthrough_phase_mask.float()
        in_unshuffle_float = in_unshuffle_phase_mask.float()

        # Don't reset hidden state at start of the unshuffle task
        masks_no_unshuffle_reset = (masks.bool()
                                    | in_unshuffle_phase_mask).float()

        cur_img = observations[self.rgb_uuid]
        unshuffled_img = observations[self.unshuffled_rgb_uuid]
        concat_img = torch.cat((cur_img, unshuffled_img), dim=-1)

        # Various embeddings
        vis_features = self.visual_encoder({self.concat_rgb_uuid: concat_img})
        prev_action_embeddings = self.prev_action_embedder(
            ((~masks.bool()).long() *
             (prev_actions.unsqueeze(-1) + 1))).squeeze(-2)
        is_walkthrough_phase_embedding = self.is_walkthrough_phase_embedder(
            in_walkthrough_phase_mask.long()).squeeze(-2)

        to_cat = [
            vis_features,
            prev_action_embeddings,
            is_walkthrough_phase_embedding,
        ]

        rnn_hidden_states = memory.tensor("rnn")
        rnn_outs = []
        obs_for_rnn = torch.cat(to_cat, dim=-1)
        last_walkthrough_encoding = memory.tensor("walkthrough_encoding")

        for step in range(masks.shape[0]):
            rnn_out, rnn_hidden_states = self.state_encoder(
                torch.cat(
                    (obs_for_rnn[step:step + 1], last_walkthrough_encoding),
                    dim=-1),
                rnn_hidden_states,
                masks[step:step + 1],
            )
            rnn_outs.append(rnn_out)

            walkthrough_encoding, _ = self.walkthrough_encoder(
                rnn_out,
                last_walkthrough_encoding,
                masks_no_unshuffle_reset[step:step + 1],
            )
            last_walkthrough_encoding = (
                last_walkthrough_encoding * in_unshuffle_float[step:step + 1] +
                walkthrough_encoding * in_walkthrough_float[step:step + 1])

        memory = memory.set_tensor("walkthrough_encoding",
                                   last_walkthrough_encoding)

        rnn_out = torch.cat(rnn_outs, dim=0)
        walkthrough_dist, walkthrough_vals = self.walkthrough_ac(rnn_out)
        unshuffle_dist, unshuffle_vals = self.unshuffle_ac(rnn_out)

        assert len(in_walkthrough_float.shape) == len(
            walkthrough_dist.logits.shape)

        if self.walkthrough_good_action_logits is not None:
            walkthrough_logits = (
                walkthrough_dist.logits +
                self.walkthrough_good_action_logits.view(
                    *((1, ) * (len(walkthrough_dist.logits.shape) - 1)), -1))
        else:
            walkthrough_logits = walkthrough_dist.logits

        actor = CategoricalDistr(
            logits=in_walkthrough_float * walkthrough_logits +
            in_unshuffle_float * unshuffle_dist.logits)
        values = (in_walkthrough_float * walkthrough_vals +
                  in_unshuffle_float * unshuffle_vals)

        ac_output = ActorCriticOutput(distributions=actor,
                                      values=values,
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
Example #19
    def forward_loop(
        self,
        observations: ObservationType,
        recurrent_hidden_states: torch.FloatTensor,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ):
        results = []
        images = cast(torch.FloatTensor,
                      observations["minigrid_ego_image"]).float()
        instrs: Optional[torch.Tensor] = None
        if "minigrid_mission" in observations:
            instrs = cast(torch.Tensor, observations["minigrid_mission"])

        _, nsamplers, _ = recurrent_hidden_states.shape
        rollouts_len = images.shape[0] // nsamplers
        obs = babyai.rl.DictList()

        images = images.view(rollouts_len, nsamplers, *images.shape[1:])
        masks = masks.view(rollouts_len, nsamplers,
                           *masks.shape[1:])  # type:ignore

        # needs_reset = (masks != 1.0).view(nrollouts, -1).any(-1)
        if instrs is not None:
            instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

        needs_instr_reset_mask = masks != 1.0
        needs_instr_reset_mask[0] = True  # always recompute embeddings at the first step
        needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)
        instr_embeddings: Optional[torch.Tensor] = None
        if self.use_instr:
            instr_reset_multi_inds = list((int(a), int(b)) for a, b in zip(
                *np.where(needs_instr_reset_mask.cpu().numpy())))
            time_ind_to_which_need_instr_reset: List[List] = [
                [] for _ in range(rollouts_len)
            ]
            reset_multi_ind_to_index = {
                mi: i
                for i, mi in enumerate(instr_reset_multi_inds)
            }
            for a, b in instr_reset_multi_inds:
                time_ind_to_which_need_instr_reset[a].append(b)

            unique_instr_embeddings = self._get_instr_embedding(
                instrs[needs_instr_reset_mask])

            instr_embeddings_list = [unique_instr_embeddings[:nsamplers]]
            current_instr_embeddings_list = list(instr_embeddings_list[-1])

            for time_ind in range(1, rollouts_len):
                if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                    instr_embeddings_list.append(instr_embeddings_list[-1])
                else:
                    for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                            time_ind]:
                        current_instr_embeddings_list[
                            sampler_needing_reset_ind] = unique_instr_embeddings[
                                reset_multi_ind_to_index[(
                                    time_ind, sampler_needing_reset_ind)]]

                    instr_embeddings_list.append(
                        torch.stack(current_instr_embeddings_list, dim=0))

            instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

        assert recurrent_hidden_states.shape[0] == 1
        memory = recurrent_hidden_states[0]
        # instr_embedding: Optional[torch.Tensor] = None
        for i in range(rollouts_len):
            obs.image = images[i]
            if "minigrid_mission" in observations:
                obs.instr = instrs[i]

            # reset = needs_reset[i].item()
            # if self.baby_ai_model.use_instr and (reset or i == 0):
            #     instr_embedding = self.baby_ai_model._get_instr_embedding(obs.instr)

            results.append(
                self.forward_once(
                    obs,
                    memory=memory * masks[i],
                    instr_embedding=instr_embeddings[i]
                    if instr_embeddings is not None else None,
                ))
            memory = results[-1]["memory"]

        embedding = torch.cat([r["embedding"] for r in results], dim=0)

        extra_predictions_list = [r["extra_predictions"] for r in results]
        extra_predictions = {
            key: torch.cat([ep[key] for ep in extra_predictions_list], dim=0)
            for key in extra_predictions_list[0]
        }
        return (
            ActorCriticOutput(
                distributions=CategoricalDistr(logits=self.actor(embedding), ),
                values=self.critic(embedding),
                extras=extra_predictions
                if not self.include_auxiliary_head else {
                    **extra_predictions,
                    "auxiliary_distributions":
                    cast(Any, CategoricalDistr(logits=self.aux(embedding))),
                },
            ),
            torch.stack([r["memory"] for r in results], dim=0),
        )
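# A minimal sketch (toy sizes, not from the original source) of the stepwise
# loop in forward_loop: the memory is multiplied by the per-step mask before
# each call so samplers whose episode just ended start from a zeroed state.
import torch

T, N, H = 3, 2, 8
cell = torch.nn.GRUCell(H, H)
inputs = torch.randn(T, N, H)
masks = torch.tensor([[[0.0], [0.0]], [[1.0], [0.0]], [[1.0], [1.0]]])  # [T, N, 1]
memory = torch.zeros(N, H)

outs = []
for t in range(T):
    memory = cell(inputs[t], memory * masks[t])
    outs.append(memory)
embedding = torch.stack(outs, dim=0)      # [T, N, H]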