Example #1
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        target_encoding = self.get_target_coordinates_encoding(observations)
        x: Union[torch.Tensor, List[torch.Tensor]]
        x = [target_encoding]

        # if observations["rgb"].shape[0] != 1:
        #     print("rgb", (observations["rgb"][...,0,0,:].unsqueeze(-2).unsqueeze(-2) == observations["rgb"][...,0,0,:]).float().mean())
        #     if "depth" in observations:
        #         print("depth", (observations["depth"][...,0,0,:].unsqueeze(-2).unsqueeze(-2) == observations["depth"][...,0,0,:]).float().mean())

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            if self.sensor_fusion:
                perception_embed = self.sensor_fuser(perception_embed)
            x = [perception_embed] + x

        x = torch.cat(x, dim=-1)
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"), masks)

        ac_output = ActorCriticOutput(
            distributions=self.actor(x), values=self.critic(x), extras={}
        )

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
Example #2
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
        else:
            # TODO manage blindness for all agents simultaneously or separately?
            raise NotImplementedError()

        # TODO alternative where all agents consume all observations
        x, rnn_hidden_states = self.state_encoder(perception_embed,
                                                  memory.tensor("rnn"), masks)

        dists, vals = self.actor_critic(x)

        return (
            ActorCriticOutput(
                distributions=dists,
                values=vals,
                extras={},
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
Example #3
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        out = self.linear(cast(torch.Tensor, observations[self.key]))

        assert len(out.shape) in [
            3,
            4,
        ], "observations must be [step, sampler, data] or [step, sampler, agent, data]"

        if len(out.shape) == 3:
            # [step, sampler, data] -> [step, sampler, agent, data]
            out = out.unsqueeze(-2)

        main_logits = out[..., :self.num_actions]
        aux_logits = out[..., self.num_actions:-1]
        values = out[..., -1:]

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(
                distributions=CategoricalDistr(logits=main_logits),
                values=typing.cast(torch.FloatTensor, values),
                extras={
                    "auxiliary_distributions":
                    CategoricalDistr(logits=aux_logits)
                },
            ),
            None,
        )
Example #4
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        target_encoding = self.get_target_coordinates_encoding(observations)
        x: Union[torch.Tensor, List[torch.Tensor]]
        x = [target_encoding]

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            if self.sensor_fusion:
                perception_embed = self.sensor_fuser(perception_embed)
            x = [perception_embed] + x

        x = torch.cat(x, dim=-1)
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"),
                                                  masks)

        ac_output = ActorCriticOutput(distributions=self.actor(x),
                                      values=self.critic(x),
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
Example #5
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        x = self.goal_visual_encoder(observations)
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"), masks)
        return (
            ActorCriticOutput(
                distributions=self.actor(x), values=self.critic(x), extras={}
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
Example #6
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Processes input batched observations to produce new actor and critic
        values. Processes input batched observations (along with prior hidden
        states, previous actions, and masks denoting which recurrent hidden
        states should be masked) and returns an `ActorCriticOutput` object
        containing the model's policy (distribution over actions) and
        evaluation of the current state (value).

        # Parameters
        observations : Batched input observations.
        rnn_hidden_states : Hidden states from initial timepoints.
        prev_actions : Tensor of previous actions taken.
        masks : Masks applied to hidden states. See `RNNStateEncoder`.
        # Returns
        Tuple of the `ActorCriticOutput` and recurrent hidden state.
        """
        target_encoding = self.get_object_type_encoding(
            cast(Dict[str, torch.FloatTensor], observations)
        )
        x = [target_encoding]

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            x = [perception_embed] + x

        x_cat = torch.cat(x, dim=-1)  # type: ignore
        x_out, rnn_hidden_states = self.state_encoder(
            x_cat, memory.tensor("rnn"), masks
        )

        return (
            ActorCriticOutput(
                distributions=self.actor(x_out), values=self.critic(x_out), extras={}
            ),
            memory.set_tensor("rnn", rnn_hidden_states),
        )
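The docstring above summarizes the contract these `forward` methods share: batched observations and recurrent state go in, and an `ActorCriticOutput` (a distribution over actions plus a value estimate) together with updated recurrent state comes out. Below is a minimal standalone sketch of that pattern in plain PyTorch; it deliberately avoids the `Memory` and `ActorCriticOutput` wrappers, and every module name and size in it is hypothetical.

import torch
import torch.nn as nn


class TinyActorCriticSketch(nn.Module):
    """Illustrative only: hypothetical names and sizes, not an AllenAct module."""

    def __init__(self, obs_dim=128, goal_dim=32, hidden_dim=64, num_actions=6):
        super().__init__()
        self.visual_encoder = nn.Linear(obs_dim, hidden_dim)
        self.goal_encoder = nn.Linear(goal_dim, hidden_dim)
        self.state_encoder = nn.GRU(2 * hidden_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, num_actions)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, obs, goal, rnn_hidden, masks):
        # obs: [steps, samplers, obs_dim], goal: [steps, samplers, goal_dim]
        # rnn_hidden: [1, samplers, hidden_dim]
        # masks: [steps, samplers, 1], zero wherever an episode has just reset
        x = torch.cat([self.visual_encoder(obs), self.goal_encoder(goal)], dim=-1)
        outputs = []
        for t in range(x.shape[0]):
            # Zero the hidden state of any sampler whose episode reset at step t.
            rnn_hidden = rnn_hidden * masks[t].unsqueeze(0)
            out, rnn_hidden = self.state_encoder(x[t : t + 1], rnn_hidden)
            outputs.append(out)
        features = torch.cat(outputs, dim=0)  # [steps, samplers, hidden_dim]
        return self.actor(features), self.critic(features), rnn_hidden


# Example shapes: 4 steps, 2 samplers.
model = TinyActorCriticSketch()
logits, values, h = model(
    torch.randn(4, 2, 128),
    torch.randn(4, 2, 32),
    torch.zeros(1, 2, 64),
    torch.ones(4, 2, 1),
)

Multiplying the hidden state by the mask at each step mirrors how the examples above pass `masks` to their state encoders so that recurrent state is dropped at episode boundaries.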
Example #7
    def adapt_result(ac_output, hidden_states, num_steps, num_samplers,
                     num_agents, num_layers, observations):  # type: ignore
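        # Reshape the flattened [steps * samplers, ...] outputs back to
        # [steps, samplers, agents, ...].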
        distributions = CategoricalDistr(
            logits=ac_output.distributions.logits.view(num_steps, num_samplers,
                                                       num_agents, -1), )
        values = ac_output.values.view(num_steps, num_samplers, num_agents, 1)
        extras = ac_output.extras  # ignore shape
        # TODO confirm the shape of the auxiliary distribution is the same as the actor's
        if "auxiliary_distributions" in extras:
            extras["auxiliary_distributions"] = CategoricalDistr(
                logits=extras["auxiliary_distributions"].logits.view(
                    num_steps, num_samplers, num_agents, -1), )

        hidden_states = hidden_states.view(num_layers,
                                           num_samplers * num_agents, -1)

        # Unflatten all observation batch dims
        def recursively_adapt_observations(obs, num_steps, num_samplers,
                                           num_agents):
            for entry in obs:
                if isinstance(obs[entry], Dict):
                    recursively_adapt_observations(obs[entry], num_steps,
                                                   num_samplers, num_agents)
                else:
                    assert isinstance(obs[entry], torch.Tensor)
                    if entry in ["minigrid_ego_image", "minigrid_mission"]:
                        final_dims = obs[entry].shape[
                            1:]  # assumes no agents dim in observations!
                        obs[entry] = obs[entry].view(num_steps,
                                                     num_samplers * num_agents,
                                                     *final_dims)

        recursively_adapt_observations(observations, num_steps, num_samplers,
                                       num_agents)

        return (
            ActorCriticOutput(distributions=distributions,
                              values=values,
                              extras=extras),
            hidden_states,
        )
Example #8
    def forward(self, observations, memory, prev_actions, masks):
        out = self.linear(observations[self.input_uuid])

        assert len(out.shape) in [
            3,
            4,
        ], "observations must be [step, sampler, data] or [step, sampler, agent, data]"

        if len(out.shape) == 3:
            # [step, sampler, data] -> [step, sampler, agent, data]
            out = out.unsqueeze(-2)

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(
                distributions=CategoricalDistr(logits=out[..., :-1]),
                values=cast(torch.FloatTensor, out[..., -1:]),
                extras={},
            ),
            None,
        )
Example #9
    def forward(
        self,
        observations: ObservationType,
        recurrent_hidden_states: torch.FloatTensor,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ):
        (
            observations,
            recurrent_hidden_states,
            prev_actions,
            masks,
            num_steps,
            num_samplers,
            num_agents,
            num_layers,
        ) = self.adapt_inputs(observations, recurrent_hidden_states,
                              prev_actions, masks)

        if self.lang_model != "gru":
            ac_output, hidden_states = self.forward_loop(
                observations=observations,
                recurrent_hidden_states=recurrent_hidden_states,
                prev_actions=prev_actions,
                masks=masks,
            )

            return self.adapt_result(
                ac_output,
                hidden_states[-1:],
                num_steps,
                num_samplers,
                num_agents,
                num_layers,
                observations,
            )

        assert recurrent_hidden_states.shape[0] == 1

        images = cast(torch.FloatTensor, observations["minigrid_ego_image"])
        if self.use_cnn2:
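            # Offset the three image channels into separate index ranges so a
            # single embedding table can embed them jointly.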
            images_shape = images.shape
            images = images + torch.LongTensor(
                [0, 11, 22]).view(  # type:ignore
                    1, 1, 1, 3).to(images.device)
            images = self.semantic_embedding(images).view(  # type:ignore
                *images_shape[:3], 24)
        images = images.permute(0, 3, 1, 2).float()  # type:ignore

        _, nsamplers, _ = recurrent_hidden_states.shape
        rollouts_len = images.shape[0] // nsamplers

        masks = typing.cast(
            torch.FloatTensor,
            masks.view(rollouts_len, nsamplers, *masks.shape[1:]))
        instrs: Optional[torch.Tensor] = None
        if "minigrid_mission" in observations and self.use_instr:
            instrs = cast(torch.FloatTensor, observations["minigrid_mission"])
            instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

        # An instruction embedding must be recomputed wherever an episode has
        # just reset (mask == 0); the first step is always treated as a reset.
        needs_instr_reset_mask = masks != 1.0
        needs_instr_reset_mask[0] = 1
        needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)
        blocking_inds: List[int] = np.where(
            needs_instr_reset_mask.view(rollouts_len,
                                        -1).any(-1).cpu().numpy())[0].tolist()
        blocking_inds.append(rollouts_len)

        instr_embeddings: Optional[torch.Tensor] = None
        if self.use_instr:
            instr_reset_multi_inds = list((int(a), int(b)) for a, b in zip(
                *np.where(needs_instr_reset_mask.cpu().numpy())))
            time_ind_to_which_need_instr_reset: List[List] = [
                [] for _ in range(rollouts_len)
            ]
            reset_multi_ind_to_index = {
                mi: i
                for i, mi in enumerate(instr_reset_multi_inds)
            }
            for a, b in instr_reset_multi_inds:
                time_ind_to_which_need_instr_reset[a].append(b)

            unique_instr_embeddings = self._get_instr_embedding(
                instrs[needs_instr_reset_mask])

            instr_embeddings_list = []
            instr_embeddings_list.append(unique_instr_embeddings[:nsamplers])
            current_instr_embeddings_list = list(instr_embeddings_list[-1])

            for time_ind in range(1, rollouts_len):
                if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                    instr_embeddings_list.append(instr_embeddings_list[-1])
                else:
                    for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                            time_ind]:
                        current_instr_embeddings_list[
                            sampler_needing_reset_ind] = unique_instr_embeddings[
                                reset_multi_ind_to_index[(
                                    time_ind, sampler_needing_reset_ind)]]

                    instr_embeddings_list.append(
                        torch.stack(current_instr_embeddings_list, dim=0))

            instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

        # The following code can be used to compute the instr_embeddings in another way
        # and thus verify that the above logic is (more likely to be) correct
        # needs_instr_reset_mask = (masks != 1.0)
        # needs_instr_reset_mask[0] *= 0
        # needs_instr_reset_inds = needs_instr_reset_mask.view(nrollouts, -1).any(-1).cpu().numpy()
        #
        # # Get inds where a new task has started
        # blocking_inds: List[int] = np.where(needs_instr_reset_inds)[0].tolist()
        # blocking_inds.append(needs_instr_reset_inds.shape[0])
        # if nrollouts != 1:
        #     pdb.set_trace()
        # if blocking_inds[0] != 0:
        #     blocking_inds.insert(0, 0)
        # if self.use_instr:
        #     instr_embeddings_list = []
        #     for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
        #         instr_embeddings_list.append(
        #             self._get_instr_embedding(instrs[ind0])
        #             .unsqueeze(0)
        #             .repeat(ind1 - ind0, 1, 1)
        #         )
        #     tmp_instr_embeddings = torch.cat(instr_embeddings_list, dim=0)
        # assert (instr_embeddings - tmp_instr_embeddings).abs().max().item() < 1e-6

        # Embed images
        # images = images.view(nrollouts, nsamplers, *images.shape[1:])
        image_embeddings = self.image_conv(images)
        if self.arch.startswith("expert_filmcnn"):
            instr_embeddings_flatter = instr_embeddings.view(
                -1, *instr_embeddings.shape[2:])
            for controller in self.controllers:
                image_embeddings = controller(image_embeddings,
                                              instr_embeddings_flatter)
            image_embeddings = F.relu(self.film_pool(image_embeddings))

        image_embeddings = image_embeddings.view(rollouts_len, nsamplers, -1)

        if self.use_instr and self.lang_model == "attgru":
            raise NotImplementedError("Currently attgru is not implemented.")

        memory = None
        if self.use_memory:
            assert recurrent_hidden_states.shape[0] == 1
            hidden = (
                recurrent_hidden_states[:, :, :self.semi_memory_size],
                recurrent_hidden_states[:, :, self.semi_memory_size:],
            )
            embeddings_list = []
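            # Step the memory RNN over contiguous chunks of timesteps between
            # episode resets, zeroing the hidden state at each chunk boundary.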
            for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
                hidden = (hidden[0] * masks[ind0], hidden[1] * masks[ind0])
                rnn_out, hidden = self.memory_rnn(image_embeddings[ind0:ind1],
                                                  hidden)
                embeddings_list.append(rnn_out)

            # embedding = hidden[0]
            embedding = torch.cat(embeddings_list, dim=0)
            memory = torch.cat(hidden, dim=-1)
        else:
            embedding = image_embeddings

        if self.use_instr and "filmcnn" not in self.arch:
            embedding = torch.cat((embedding, instr_embeddings), dim=-1)

        if hasattr(self, "aux_info") and self.aux_info:
            extra_predictions = {
                info: self.extra_heads[info](embedding)
                for info in self.extra_heads
            }
        else:
            extra_predictions = dict()

        embedding = embedding.view(rollouts_len * nsamplers, -1)

        ac_output = ActorCriticOutput(
            distributions=CategoricalDistr(logits=self.actor(embedding), ),
            values=self.critic(embedding),
            extras=extra_predictions if not self.include_auxiliary_head else {
                **extra_predictions,
                "auxiliary_distributions":
                CategoricalDistr(logits=self.aux(embedding)),
            },
        )
        hidden_states = memory

        return self.adapt_result(
            ac_output,
            hidden_states,
            num_steps,
            num_samplers,
            num_agents,
            num_layers,
            observations,
        )
Example #10
    def forward_loop(
        self,
        observations: ObservationType,
        recurrent_hidden_states: torch.FloatTensor,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ):
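        # Fallback path that steps the recurrent model one timestep at a time
        # (used when the language model is not a plain GRU).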
        results = []
        images = cast(torch.FloatTensor,
                      observations["minigrid_ego_image"]).float()
        instrs: Optional[torch.Tensor] = None
        if "minigrid_mission" in observations:
            instrs = cast(torch.Tensor, observations["minigrid_mission"])

        _, nsamplers, _ = recurrent_hidden_states.shape
        rollouts_len = images.shape[0] // nsamplers
        obs = babyai.rl.DictList()

        images = images.view(rollouts_len, nsamplers, *images.shape[1:])
        masks = masks.view(rollouts_len, nsamplers,
                           *masks.shape[1:])  # type:ignore

        # needs_reset = (masks != 1.0).view(nrollouts, -1).any(-1)
        if instrs is not None:
            instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

        # As in `forward` above: recompute the instruction embedding wherever
        # an episode has just reset, always including the first step.
        needs_instr_reset_mask = masks != 1.0
        needs_instr_reset_mask[0] = 1
        needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)
        instr_embeddings: Optional[torch.Tensor] = None
        if self.use_instr:
            instr_reset_multi_inds = list((int(a), int(b)) for a, b in zip(
                *np.where(needs_instr_reset_mask.cpu().numpy())))
            time_ind_to_which_need_instr_reset: List[List] = [
                [] for _ in range(rollouts_len)
            ]
            reset_multi_ind_to_index = {
                mi: i
                for i, mi in enumerate(instr_reset_multi_inds)
            }
            for a, b in instr_reset_multi_inds:
                time_ind_to_which_need_instr_reset[a].append(b)

            unique_instr_embeddings = self._get_instr_embedding(
                instrs[needs_instr_reset_mask])

            instr_embeddings_list = []
            instr_embeddings_list.append(unique_instr_embeddings[:nsamplers])
            current_instr_embeddings_list = list(instr_embeddings_list[-1])

            for time_ind in range(1, rollouts_len):
                if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                    instr_embeddings_list.append(instr_embeddings_list[-1])
                else:
                    for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                            time_ind]:
                        current_instr_embeddings_list[
                            sampler_needing_reset_ind] = unique_instr_embeddings[
                                reset_multi_ind_to_index[(
                                    time_ind, sampler_needing_reset_ind)]]

                    instr_embeddings_list.append(
                        torch.stack(current_instr_embeddings_list, dim=0))

            instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

        assert recurrent_hidden_states.shape[0] == 1
        memory = recurrent_hidden_states[0]
        # instr_embedding: Optional[torch.Tensor] = None
        for i in range(rollouts_len):
            obs.image = images[i]
            if "minigrid_mission" in observations:
                obs.instr = instrs[i]

            # reset = needs_reset[i].item()
            # if self.baby_ai_model.use_instr and (reset or i == 0):
            #     instr_embedding = self.baby_ai_model._get_instr_embedding(obs.instr)

            results.append(
                self.forward_once(obs,
                                  memory=memory * masks[i],
                                  instr_embedding=instr_embeddings[i]))
            memory = results[-1]["memory"]

        embedding = torch.cat([r["embedding"] for r in results], dim=0)

        extra_predictions_list = [r["extra_predictions"] for r in results]
        extra_predictions = {
            key: torch.cat([ep[key] for ep in extra_predictions_list], dim=0)
            for key in extra_predictions_list[0]
        }
        return (
            ActorCriticOutput(
                distributions=CategoricalDistr(logits=self.actor(embedding), ),
                values=self.critic(embedding),
                extras=extra_predictions
                if not self.include_auxiliary_head else {
                    **extra_predictions,
                    "auxiliary_distributions":
                    CategoricalDistr(logits=self.aux(embedding)),
                },
            ),
            torch.stack([r["memory"] for r in results], dim=0),
        )