Example #1
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        out = self.linear(cast(torch.Tensor, observations[self.input_uuid]))

        main_logits = out[..., :self.num_actions]
        aux_logits = out[..., self.num_actions:-1]
        values = out[..., -1:]

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(
                distributions=cast(
                    DistributionType, CategoricalDistr(
                        logits=main_logits)),  # step x sampler x ...
                values=cast(torch.FloatTensor,
                            values.view(values.shape[:2] +
                                        (-1, ))),  # step x sampler x flattened
                extras={
                    "auxiliary_distributions":
                    CategoricalDistr(logits=aux_logits)
                },
            ),
            None,
        )
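The pattern in Example #1 packs the action logits, auxiliary logits, and value estimate into one linear projection and slices them apart afterwards. A minimal stand-alone sketch of that slicing, with hypothetical sizes and plain torch.distributions.Categorical standing in for CategoricalDistr:

import torch
import torch.nn as nn
from torch.distributions import Categorical

num_actions, num_aux_actions, feat_dim = 5, 3, 16
# A single head emits [main logits | auxiliary logits | value] as one tensor.
head = nn.Linear(feat_dim, num_actions + num_aux_actions + 1)

obs = torch.randn(4, 2, feat_dim)        # [steps, samplers, features]
out = head(obs)

main_logits = out[..., :num_actions]     # [steps, samplers, num_actions]
aux_logits = out[..., num_actions:-1]    # [steps, samplers, num_aux_actions]
values = out[..., -1:]                   # [steps, samplers, 1]

dist = Categorical(logits=main_logits)   # stand-in for CategoricalDistr
print(dist.sample().shape, aux_logits.shape, values.view(*values.shape[:2], -1).shape)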
Example #2
class TupleCategoricalDistr(Distr):
    def __init__(self, probs=None, logits=None, validate_args=None):
        self.dists = CategoricalDistr(probs=probs,
                                      logits=logits,
                                      validate_args=validate_args)

    def log_prob(self, actions: Tuple[torch.LongTensor,
                                      ...]) -> torch.FloatTensor:
        # flattened output [steps, samplers, num_agents]
        return self.dists.log_prob(torch.stack(actions, dim=-1))

    def entropy(self) -> torch.FloatTensor:
        # flattened output [steps, samplers, num_agents]
        return self.dists.entropy()

    def sample(
        self, sample_shape=torch.Size()) -> Tuple[torch.LongTensor, ...]:
        # split and remove trailing singleton dim
        res = self.dists.sample(sample_shape).split(1, dim=-1)
        return tuple([r.view(r.shape[:2]) for r in res])

    def mode(self) -> Tuple[torch.LongTensor, ...]:
        # split and remove trailing singleton dim
        res = self.dists.mode().split(1, dim=-1)
        return tuple([r.view(r.shape[:2]) for r in res])
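TupleCategoricalDistr treats the trailing tensor dimension as a tuple of independent categorical heads: samples are split into a per-head tuple, and log_prob stacks the tuple back before delegating. A rough sketch of that stack/split bookkeeping, assuming plain torch.distributions.Categorical in place of CategoricalDistr:

import torch
from torch.distributions import Categorical

steps, samplers, num_heads, num_actions = 4, 2, 3, 5
logits = torch.randn(steps, samplers, num_heads, num_actions)
dists = Categorical(logits=logits)          # one categorical per trailing head

sample = dists.sample()                     # [steps, samplers, num_heads]
# Split into a tuple of per-head actions, dropping the trailing singleton dim.
actions = tuple(a.view(a.shape[:2]) for a in sample.split(1, dim=-1))

# Stacking the tuple recovers a per-head log-probability tensor.
log_probs = dists.log_prob(torch.stack(actions, dim=-1))
print([a.shape for a in actions], log_probs.shape)   # 3 x [4, 2], then [4, 2, 3]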
Example #3
    def group_loss(
        self,
        distribution: CategoricalDistr,
        expert_actions: torch.Tensor,
        expert_actions_masks: torch.Tensor,
    ):
        assert isinstance(distribution, CategoricalDistr) or (
            isinstance(distribution, ConditionalDistr)
            and isinstance(distribution.distr, CategoricalDistr)
        ), "This implementation only supports (groups of) `CategoricalDistr`"

        expert_successes = expert_actions_masks.sum()

        log_probs = distribution.log_prob(cast(torch.LongTensor, expert_actions))
        assert (
            log_probs.shape[: len(expert_actions_masks.shape)]
            == expert_actions_masks.shape
        )

        # Add dimensions to `expert_actions_masks` on the right to allow for masking
        # if necessary.
        len_diff = len(log_probs.shape) - len(expert_actions_masks.shape)
        assert len_diff >= 0
        expert_actions_masks = expert_actions_masks.view(
            *expert_actions_masks.shape, *((1,) * len_diff)
        )

        group_loss = -(expert_actions_masks * log_probs).sum() / torch.clamp(
            expert_successes, min=1
        )

        return group_loss, expert_successes
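group_loss above is a masked negative log-likelihood: only steps with a valid expert action contribute, and the sum is normalized by the number of such steps (clamped to at least 1 so an all-masked batch does not divide by zero). A small self-contained sketch of the same computation with made-up shapes:

import torch
from torch.distributions import Categorical

steps, samplers, num_actions = 5, 3, 4
logits = torch.randn(steps, samplers, num_actions)
expert_actions = torch.randint(num_actions, (steps, samplers))
# 1 where an expert action is available for that step/sampler, 0 otherwise.
expert_actions_masks = torch.randint(2, (steps, samplers)).float()

log_probs = Categorical(logits=logits).log_prob(expert_actions)  # [steps, samplers]
expert_successes = expert_actions_masks.sum()

group_loss = -(expert_actions_masks * log_probs).sum() / torch.clamp(
    expert_successes, min=1
)
print(group_loss.item(), int(expert_successes.item()))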
Example #4
    def forward(self, x):
        out = self.master_and_critic(x)

        master_logits = out[..., :-1]
        values = out[..., -1:]
        # noinspection PyArgumentList

        cond1 = ConditionalDistr(
            distr_conditioned_on_input_fn_or_instance=CategoricalDistr(
                logits=master_logits
            ),
            action_group_name="higher",
        )
        cond2 = ConditionalDistr(
            distr_conditioned_on_input_fn_or_instance=lambda *args, **kwargs: ConditionedLinearActorCriticHead.lower_policy(
                self, *args, **kwargs
            ),
            action_group_name="lower",
            state_embedding=x,
        )

        return (
            SequentialDistr(cond1, cond2),
            values.view(*values.shape[:2], -1),  # [steps, samplers, flattened]
        )
Example #5
    def adapt_result(ac_output, hidden_states, num_steps, num_samplers,
                     num_agents, num_layers, observations):  # type: ignore
        distributions = CategoricalDistr(
            logits=ac_output.distributions.logits.view(num_steps, num_samplers,
                                                       -1), )
        values = ac_output.values.view(num_steps, num_samplers, num_agents)
        extras = ac_output.extras  # ignore shape
        # TODO confirm the shape of the auxiliary distribution is the same as the actor's
        if "auxiliary_distributions" in extras:
            extras["auxiliary_distributions"] = CategoricalDistr(
                logits=extras["auxiliary_distributions"].logits.view(
                    num_steps,
                    num_samplers,
                    -1  # assume single-agent
                ), )

        hidden_states = hidden_states.view(num_layers,
                                           num_samplers * num_agents, -1)

        # Unflatten all observation batch dims
        def recursively_adapt_observations(obs):
            for entry in obs:
                if isinstance(obs[entry], Dict):
                    recursively_adapt_observations(obs[entry])
                else:
                    assert isinstance(obs[entry], torch.Tensor)
                    if entry in ["minigrid_ego_image", "minigrid_mission"]:
                        final_dims = obs[entry].shape[
                            1:]  # assumes no agents dim in observations!
                        obs[entry] = obs[entry].view(num_steps,
                                                     num_samplers * num_agents,
                                                     *final_dims)

        recursively_adapt_observations(observations)

        return (
            ActorCriticOutput(distributions=distributions,
                              values=values,
                              extras=extras),
            hidden_states,
        )
Example #6
    def forward(self, x) -> Tuple[CategoricalDistr, torch.Tensor]:
        out = self.actor_and_critic(x)

        logits = out[..., :-1]
        values = out[..., -1:]
        # noinspection PyArgumentList
        return (
            # logits are [step, sampler, ...]
            CategoricalDistr(logits=logits),
            # values are [step, sampler, flattened]
            values.view(*values.shape[:2], -1),
        )
Example #7
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        out = self.linear(cast(torch.Tensor, observations[self.key]))

        assert len(out.shape) in [
            3,
            4,
        ], "observations must be [step, sampler, data] or [step, sampler, agent, data]"

        if len(out.shape) == 3:
            # [step, sampler, data] -> [step, sampler, agent, data]
            out = out.unsqueeze(-2)

        main_logits = out[..., :self.num_actions]
        aux_logits = out[..., self.num_actions:-1]
        values = out[..., -1:]

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(
                distributions=cast(
                    DistributionType, CategoricalDistr(
                        logits=main_logits)),  # step x sampler x ...
                values=cast(torch.FloatTensor,
                            values.view(values.shape[:2] +
                                        (-1, ))),  # step x sampler x flattened
                extras={
                    "auxiliary_distributions":
                    CategoricalDistr(logits=aux_logits),
                },
            ),
            None,
        )
Example #8
    def forward(self, observations, memory, prev_actions, masks):
        out = self.linear(observations[self.input_uuid])

        # noinspection PyArgumentList
        return (
            ActorCriticOutput(
                # ensure [steps, samplers, ...]
                distributions=CategoricalDistr(logits=out[..., :-1]),
                # ensure [steps, samplers, flattened]
                values=cast(torch.FloatTensor,
                            out[..., -1:].view(*out.shape[:2], -1)),
                extras={},
            ),
            None,
        )
Example #9
    def loss(  # type: ignore
            self, step_count: int, batch: ObservationType,
            actor_critic_output: ActorCriticOutput[CategoricalDistr], *args,
            **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
        task_weights = actor_critic_output.extras[self.UUID]
        task_weights = task_weights.view(-1, self.num_tasks)
        entropy = CategoricalDistr(task_weights).entropy()

        avg_loss = (-entropy).mean()
        avg_task_weights = task_weights.mean(dim=0)  # (K)

        outputs = {"entropy_loss": cast(torch.Tensor, avg_loss).item()}
        for i in range(self.num_tasks):
            outputs["weight_" + self.task_names[i]] = cast(
                torch.Tensor, avg_task_weights[i]).item()

        return (
            avg_loss,
            outputs,
        )
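The loss above maximizes the entropy of the per-step task weights (by minimizing negative entropy), which keeps task selection from collapsing onto a single task. A rough sketch of the computation, assuming the weights are already probabilities over num_tasks tasks (as the positional CategoricalDistr(task_weights) call implies):

import torch
from torch.distributions import Categorical

num_tasks = 3
# Hypothetical per-step task weights, normalized to probabilities.
task_weights = torch.softmax(torch.randn(8, num_tasks), dim=-1)

entropy = Categorical(probs=task_weights).entropy()  # [8]
entropy_loss = (-entropy).mean()                     # minimized => entropy maximized
avg_task_weights = task_weights.mean(dim=0)          # (num_tasks,)
print(entropy_loss.item(), avg_task_weights.tolist())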
Example #10
    def forward(self, x: torch.FloatTensor):  # type: ignore
        x = self.linear(x)  # type:ignore

        # noinspection PyArgumentList
        return CategoricalDistr(logits=x)  # logits are [step, sampler, ...]
Example #11
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        in_walkthrough_phase_mask = observations[
            self.in_walkthrough_phase_uuid]
        in_unshuffle_phase_mask = ~in_walkthrough_phase_mask
        in_walkthrough_float = in_walkthrough_phase_mask.float()
        in_unshuffle_float = in_unshuffle_phase_mask.float()

        # Don't reset hidden state at start of the unshuffle task
        masks_no_unshuffle_reset = (masks.bool()
                                    | in_unshuffle_phase_mask).float()
        masks_with_unshuffle_reset = masks.float()
        del masks  # Just to make sure we don't accidentally use `masks` when we want `masks_no_unshuffle_reset`

        # Visual features
        cur_img_resnet = observations[self.rgb_uuid]
        unshuffled_img_resnet = observations[self.unshuffled_rgb_uuid]
        concat_img = torch.cat(
            (
                cur_img_resnet,
                unshuffled_img_resnet,
                cur_img_resnet * unshuffled_img_resnet,
            ),
            dim=-3,
        )
        batch_shape, features_shape = concat_img.shape[:-3], concat_img.shape[
            -3:]
        concat_img_reshaped = concat_img.view(-1, *features_shape)
        attention_probs = torch.softmax(
            self.visual_attention(concat_img_reshaped).view(
                concat_img_reshaped.shape[0], -1),
            dim=-1,
        ).view(concat_img_reshaped.shape[0], 1,
               *concat_img_reshaped.shape[-2:])
        vis_features = ((self.visual_encoder(concat_img_reshaped) *
                         attention_probs).mean(-1).mean(-1))
        vis_features = vis_features.view(*batch_shape, -1)

        # Various embeddings
        prev_action_embeddings = self.prev_action_embedder(
            ((~masks_with_unshuffle_reset.bool()).long() *
             (prev_actions.unsqueeze(-1) + 1))).squeeze(-2)
        is_walkthrough_phase_embedding = self.is_walkthrough_phase_embedder(
            in_walkthrough_phase_mask.long()).squeeze(-2)

        to_cat = [
            vis_features,
            prev_action_embeddings,
            is_walkthrough_phase_embedding,
        ]

        rnn_hidden_states = memory.tensor("rnn")
        rnn_outs = []
        obs_for_rnn = torch.cat(to_cat, dim=-1)
        last_walkthrough_encoding = memory.tensor("walkthrough_encoding")

        for step in range(masks_with_unshuffle_reset.shape[0]):
            rnn_out, rnn_hidden_states = self.state_encoder(
                torch.cat(
                    (
                        obs_for_rnn[step:step + 1],
                        last_walkthrough_encoding *
                        masks_no_unshuffle_reset[step:step + 1],
                    ),
                    dim=-1,
                ),
                rnn_hidden_states,
                masks_with_unshuffle_reset[step:step + 1],
            )
            rnn_outs.append(rnn_out)

            walkthrough_encoding, _ = self.walkthrough_encoder(
                rnn_out,
                last_walkthrough_encoding,
                masks_no_unshuffle_reset[step:step + 1],
            )
            last_walkthrough_encoding = (
                last_walkthrough_encoding * in_unshuffle_float[step:step + 1] +
                walkthrough_encoding * in_walkthrough_float[step:step + 1])

        memory = memory.set_tensor("walkthrough_encoding",
                                   last_walkthrough_encoding)

        rnn_out = torch.cat(rnn_outs, dim=0)
        walkthrough_dist, walkthrough_vals = self.walkthrough_ac(rnn_out)
        unshuffle_dist, unshuffle_vals = self.unshuffle_ac(rnn_out)

        assert len(in_walkthrough_float.shape) == len(
            walkthrough_dist.logits.shape)

        if self.walkthrough_good_action_logits is not None:
            walkthrough_logits = (
                walkthrough_dist.logits +
                self.walkthrough_good_action_logits.view(
                    *((1, ) * (len(walkthrough_dist.logits.shape) - 1)), -1))
        else:
            walkthrough_logits = walkthrough_dist.logits

        actor = CategoricalDistr(
            logits=in_walkthrough_float * walkthrough_logits +
            in_unshuffle_float * unshuffle_dist.logits)
        values = (in_walkthrough_float * walkthrough_vals +
                  in_unshuffle_float * unshuffle_vals)

        ac_output = ActorCriticOutput(distributions=actor,
                                      values=values,
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
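The two-phase model above (and the near-identical one in the next example) evaluates separate walkthrough and unshuffle heads on the shared RNN output and blends them with the binary phase indicator, so only the head for the active phase influences each step. A minimal sketch of that blending with hypothetical shapes, using plain torch.distributions.Categorical in place of CategoricalDistr:

import torch
import torch.nn as nn
from torch.distributions import Categorical

steps, samplers, hidden, num_actions = 4, 2, 8, 6
rnn_out = torch.randn(steps, samplers, hidden)
# 1.0 while in the walkthrough phase, 0.0 during the unshuffle phase.
in_walkthrough = torch.randint(2, (steps, samplers, 1)).float()
in_unshuffle = 1.0 - in_walkthrough

walkthrough_actor = nn.Linear(hidden, num_actions)
unshuffle_actor = nn.Linear(hidden, num_actions)
walkthrough_critic = nn.Linear(hidden, 1)
unshuffle_critic = nn.Linear(hidden, 1)

logits = (in_walkthrough * walkthrough_actor(rnn_out)
          + in_unshuffle * unshuffle_actor(rnn_out))
values = (in_walkthrough * walkthrough_critic(rnn_out)
          + in_unshuffle * unshuffle_critic(rnn_out))

actor = Categorical(logits=logits)
print(actor.sample().shape, values.shape)  # [4, 2] and [4, 2, 1]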
Example #12
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        in_walkthrough_phase_mask = observations[
            self.in_walkthrough_phase_uuid]
        in_unshuffle_phase_mask = ~in_walkthrough_phase_mask
        in_walkthrough_float = in_walkthrough_phase_mask.float()
        in_unshuffle_float = in_unshuffle_phase_mask.float()

        # Don't reset hidden state at start of the unshuffle task
        masks_no_unshuffle_reset = (masks.bool()
                                    | in_unshuffle_phase_mask).float()

        cur_img = observations[self.rgb_uuid]
        unshuffled_img = observations[self.unshuffled_rgb_uuid]
        concat_img = torch.cat((cur_img, unshuffled_img), dim=-1)

        # Various embeddings
        vis_features = self.visual_encoder({self.concat_rgb_uuid: concat_img})
        prev_action_embeddings = self.prev_action_embedder(
            ((~masks.bool()).long() *
             (prev_actions.unsqueeze(-1) + 1))).squeeze(-2)
        is_walkthrough_phase_embedding = self.is_walkthrough_phase_embedder(
            in_walkthrough_phase_mask.long()).squeeze(-2)

        to_cat = [
            vis_features,
            prev_action_embeddings,
            is_walkthrough_phase_embedding,
        ]

        rnn_hidden_states = memory.tensor("rnn")
        rnn_outs = []
        obs_for_rnn = torch.cat(to_cat, dim=-1)
        last_walkthrough_encoding = memory.tensor("walkthrough_encoding")

        for step in range(masks.shape[0]):
            rnn_out, rnn_hidden_states = self.state_encoder(
                torch.cat(
                    (obs_for_rnn[step:step + 1], last_walkthrough_encoding),
                    dim=-1),
                rnn_hidden_states,
                masks[step:step + 1],
            )
            rnn_outs.append(rnn_out)

            walkthrough_encoding, _ = self.walkthrough_encoder(
                rnn_out,
                last_walkthrough_encoding,
                masks_no_unshuffle_reset[step:step + 1],
            )
            last_walkthrough_encoding = (
                last_walkthrough_encoding * in_unshuffle_float[step:step + 1] +
                walkthrough_encoding * in_walkthrough_float[step:step + 1])

        memory = memory.set_tensor("walkthrough_encoding",
                                   last_walkthrough_encoding)

        rnn_out = torch.cat(rnn_outs, dim=0)
        walkthrough_dist, walkthrough_vals = self.walkthrough_ac(rnn_out)
        unshuffle_dist, unshuffle_vals = self.unshuffle_ac(rnn_out)

        assert len(in_walkthrough_float.shape) == len(
            walkthrough_dist.logits.shape)

        if self.walkthrough_good_action_logits is not None:
            walkthrough_logits = (
                walkthrough_dist.logits +
                self.walkthrough_good_action_logits.view(
                    *((1, ) * (len(walkthrough_dist.logits.shape) - 1)), -1))
        else:
            walkthrough_logits = walkthrough_dist.logits

        actor = CategoricalDistr(
            logits=in_walkthrough_float * walkthrough_logits +
            in_unshuffle_float * unshuffle_dist.logits)
        values = (in_walkthrough_float * walkthrough_vals +
                  in_unshuffle_float * unshuffle_vals)

        ac_output = ActorCriticOutput(distributions=actor,
                                      values=values,
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
Example #14
    def forward(
        self,
        observations: ObservationType,
        recurrent_hidden_states: torch.FloatTensor,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ):
        (
            observations,
            recurrent_hidden_states,
            prev_actions,
            masks,
            num_steps,
            num_samplers,
            num_agents,
            num_layers,
        ) = self.adapt_inputs(observations, recurrent_hidden_states,
                              prev_actions, masks)

        if self.lang_model != "gru":
            ac_output, hidden_states = self.forward_loop(
                observations=observations,
                recurrent_hidden_states=recurrent_hidden_states,
                prev_actions=prev_actions,
                masks=masks,  # type: ignore
            )

            return self.adapt_result(
                ac_output,
                hidden_states[-1:],
                num_steps,
                num_samplers,
                num_agents,
                num_layers,
                observations,
            )

        assert recurrent_hidden_states.shape[0] == 1

        images = cast(torch.FloatTensor, observations["minigrid_ego_image"])
        if self.use_cnn2:
            images_shape = images.shape
            # noinspection PyArgumentList
            images = images + torch.LongTensor(
                [0, 11, 22]).view(  # type:ignore
                    1, 1, 1, 3).to(images.device)
            images = self.semantic_embedding(images).view(  # type:ignore
                *images_shape[:3], 24)
        images = images.permute(0, 3, 1, 2).float()  # type:ignore

        _, nsamplers, _ = recurrent_hidden_states.shape
        rollouts_len = images.shape[0] // nsamplers

        masks = cast(torch.FloatTensor,
                     masks.view(rollouts_len, nsamplers, *masks.shape[1:]))
        instrs: Optional[torch.Tensor] = None
        if "minigrid_mission" in observations and self.use_instr:
            instrs = cast(torch.FloatTensor, observations["minigrid_mission"])
            instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

        needs_instr_reset_mask = masks != 1.0
        needs_instr_reset_mask[0] = 1
        needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)
        blocking_inds: List[int] = np.where(
            needs_instr_reset_mask.view(rollouts_len,
                                        -1).any(-1).cpu().numpy())[0].tolist()
        blocking_inds.append(rollouts_len)

        instr_embeddings: Optional[torch.Tensor] = None
        if self.use_instr:
            instr_reset_multi_inds = list((int(a), int(b)) for a, b in zip(
                *np.where(needs_instr_reset_mask.cpu().numpy())))
            time_ind_to_which_need_instr_reset: List[List] = [
                [] for _ in range(rollouts_len)
            ]
            reset_multi_ind_to_index = {
                mi: i
                for i, mi in enumerate(instr_reset_multi_inds)
            }
            for a, b in instr_reset_multi_inds:
                time_ind_to_which_need_instr_reset[a].append(b)

            unique_instr_embeddings = self._get_instr_embedding(
                instrs[needs_instr_reset_mask])

            instr_embeddings_list = [unique_instr_embeddings[:nsamplers]]
            current_instr_embeddings_list = list(instr_embeddings_list[-1])

            for time_ind in range(1, rollouts_len):
                if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                    instr_embeddings_list.append(instr_embeddings_list[-1])
                else:
                    for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                            time_ind]:
                        current_instr_embeddings_list[
                            sampler_needing_reset_ind] = unique_instr_embeddings[
                                reset_multi_ind_to_index[(
                                    time_ind, sampler_needing_reset_ind)]]

                    instr_embeddings_list.append(
                        torch.stack(current_instr_embeddings_list, dim=0))

            instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

        # The following code can be used to compute the instr_embeddings in another way
        # and thus verify that the above logic is (more likely to be) correct
        # needs_instr_reset_mask = (masks != 1.0)
        # needs_instr_reset_mask[0] *= 0
        # needs_instr_reset_inds = needs_instr_reset_mask.view(nrollouts, -1).any(-1).cpu().numpy()
        #
        # # Get inds where a new task has started
        # blocking_inds: List[int] = np.where(needs_instr_reset_inds)[0].tolist()
        # blocking_inds.append(needs_instr_reset_inds.shape[0])
        # if nrollouts != 1:
        #     pdb.set_trace()
        # if blocking_inds[0] != 0:
        #     blocking_inds.insert(0, 0)
        # if self.use_instr:
        #     instr_embeddings_list = []
        #     for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
        #         instr_embeddings_list.append(
        #             self._get_instr_embedding(instrs[ind0])
        #             .unsqueeze(0)
        #             .repeat(ind1 - ind0, 1, 1)
        #         )
        #     tmp_instr_embeddings = torch.cat(instr_embeddings_list, dim=0)
        # assert (instr_embeddings - tmp_instr_embeddings).abs().max().item() < 1e-6

        # Embed images
        # images = images.view(nrollouts, nsamplers, *images.shape[1:])
        image_embeddings = self.image_conv(images)
        if self.arch.startswith("expert_filmcnn"):
            instr_embeddings_flatter = instr_embeddings.view(
                -1, *instr_embeddings.shape[2:])
            for controller in self.controllers:
                image_embeddings = controller(image_embeddings,
                                              instr_embeddings_flatter)
            image_embeddings = F.relu(self.film_pool(image_embeddings))

        image_embeddings = image_embeddings.view(rollouts_len, nsamplers, -1)

        if self.use_instr and self.lang_model == "attgru":
            raise NotImplementedError("Currently attgru is not implemented.")

        memory = None
        if self.use_memory:
            assert recurrent_hidden_states.shape[0] == 1
            hidden = (
                recurrent_hidden_states[:, :, :self.semi_memory_size],
                recurrent_hidden_states[:, :, self.semi_memory_size:],
            )
            embeddings_list = []
            for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
                hidden = (hidden[0] * masks[ind0], hidden[1] * masks[ind0])
                rnn_out, hidden = self.memory_rnn(image_embeddings[ind0:ind1],
                                                  hidden)
                embeddings_list.append(rnn_out)

            # embedding = hidden[0]
            embedding = torch.cat(embeddings_list, dim=0)
            memory = torch.cat(hidden, dim=-1)
        else:
            embedding = image_embeddings

        if self.use_instr and not "filmcnn" in self.arch:
            embedding = torch.cat((embedding, instr_embeddings), dim=-1)

        if hasattr(self, "aux_info") and self.aux_info:
            extra_predictions = {
                info: self.extra_heads[info](embedding)
                for info in self.extra_heads
            }
        else:
            extra_predictions = dict()

        embedding = embedding.view(rollouts_len * nsamplers, -1)

        ac_output = ActorCriticOutput(
            distributions=CategoricalDistr(logits=self.actor(embedding), ),
            values=self.critic(embedding),
            extras=extra_predictions if not self.include_auxiliary_head else {
                **extra_predictions,
                "auxiliary_distributions":
                CategoricalDistr(logits=self.aux(embedding)),
            },
        )
        hidden_states = memory

        return self.adapt_result(
            ac_output,
            hidden_states,
            num_steps,
            num_samplers,
            num_agents,
            num_layers,
            observations,
        )
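One technique in Example #14 worth isolating: when recurrent memory is used, the rollout is cut at the time steps where any sampler resets (blocking_inds), and the hidden state is re-masked before each block is fed through the RNN in a single call. A simplified sketch of that block-wise rollout with a plain nn.GRU and hypothetical shapes:

import torch
import torch.nn as nn

steps, samplers, feat, hidden = 6, 2, 4, 8
inputs = torch.randn(steps, samplers, feat)
masks = torch.ones(steps, samplers, 1)
masks[3] = 0.0                         # an episode reset at step 3

gru = nn.GRU(feat, hidden)
h = torch.zeros(1, samplers, hidden)

# Steps where at least one sampler resets delimit the blocks.
reset_steps = (masks == 0.0).view(steps, -1).any(-1).nonzero().view(-1).tolist()
blocking_inds = [0] + [i for i in reset_steps if i != 0] + [steps]

outs = []
for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
    h = h * masks[ind0].view(1, samplers, 1)   # zero the hidden state at resets
    out, h = gru(inputs[ind0:ind1], h)
    outs.append(out)

embedding = torch.cat(outs, dim=0)             # [steps, samplers, hidden]
print(embedding.shape)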
Example #15
    def forward_loop(
        self,
        observations: ObservationType,
        recurrent_hidden_states: torch.FloatTensor,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ):
        results = []
        images = cast(torch.FloatTensor,
                      observations["minigrid_ego_image"]).float()
        instrs: Optional[torch.Tensor] = None
        if "minigrid_mission" in observations:
            instrs = cast(torch.Tensor, observations["minigrid_mission"])

        _, nsamplers, _ = recurrent_hidden_states.shape
        rollouts_len = images.shape[0] // nsamplers
        obs = babyai.rl.DictList()

        images = images.view(rollouts_len, nsamplers, *images.shape[1:])
        masks = masks.view(rollouts_len, nsamplers,
                           *masks.shape[1:])  # type:ignore

        # needs_reset = (masks != 1.0).view(nrollouts, -1).any(-1)
        if instrs is not None:
            instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

        needs_instr_reset_mask = masks != 1.0
        needs_instr_reset_mask[0] = 1
        needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)
        instr_embeddings: Optional[torch.Tensor] = None
        if self.use_instr:
            instr_reset_multi_inds = list((int(a), int(b)) for a, b in zip(
                *np.where(needs_instr_reset_mask.cpu().numpy())))
            time_ind_to_which_need_instr_reset: List[List] = [
                [] for _ in range(rollouts_len)
            ]
            reset_multi_ind_to_index = {
                mi: i
                for i, mi in enumerate(instr_reset_multi_inds)
            }
            for a, b in instr_reset_multi_inds:
                time_ind_to_which_need_instr_reset[a].append(b)

            unique_instr_embeddings = self._get_instr_embedding(
                instrs[needs_instr_reset_mask])

            instr_embeddings_list = [unique_instr_embeddings[:nsamplers]]
            current_instr_embeddings_list = list(instr_embeddings_list[-1])

            for time_ind in range(1, rollouts_len):
                if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                    instr_embeddings_list.append(instr_embeddings_list[-1])
                else:
                    for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                            time_ind]:
                        current_instr_embeddings_list[
                            sampler_needing_reset_ind] = unique_instr_embeddings[
                                reset_multi_ind_to_index[(
                                    time_ind, sampler_needing_reset_ind)]]

                    instr_embeddings_list.append(
                        torch.stack(current_instr_embeddings_list, dim=0))

            instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

        assert recurrent_hidden_states.shape[0] == 1
        memory = recurrent_hidden_states[0]
        # instr_embedding: Optional[torch.Tensor] = None
        for i in range(rollouts_len):
            obs.image = images[i]
            if "minigrid_mission" in observations:
                obs.instr = instrs[i]

            # reset = needs_reset[i].item()
            # if self.baby_ai_model.use_instr and (reset or i == 0):
            #     instr_embedding = self.baby_ai_model._get_instr_embedding(obs.instr)

            results.append(
                self.forward_once(obs,
                                  memory=memory * masks[i],
                                  instr_embedding=instr_embeddings[i]))
            memory = results[-1]["memory"]

        embedding = torch.cat([r["embedding"] for r in results], dim=0)

        extra_predictions_list = [r["extra_predictions"] for r in results]
        extra_predictions = {
            key: torch.cat([ep[key] for ep in extra_predictions_list], dim=0)
            for key in extra_predictions_list[0]
        }
        return (
            ActorCriticOutput(
                distributions=CategoricalDistr(logits=self.actor(embedding), ),
                values=self.critic(embedding),
                extras=extra_predictions
                if not self.include_auxiliary_head else {
                    **extra_predictions,
                    "auxiliary_distributions":
                    cast(Any, CategoricalDistr(logits=self.aux(embedding))),
                },
            ),
            torch.stack([r["memory"] for r in results], dim=0),
        )
Example #16
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Processes input batched observations to produce new actor and critic
        values. Processes input batched observations (along with prior hidden
        states, previous actions, and masks denoting which recurrent hidden
        states should be masked) and returns an `ActorCriticOutput` object
        containing the model's policy (distribution over actions) and
        evaluation of the current state (value).

        # Parameters
        observations : Batched input observations.
        memory : `Memory` containing the hidden states from initial timepoints.
        prev_actions : Tensor of previous actions taken.
        masks : Masks applied to hidden states. See `RNNStateEncoder`.

        # Returns
        Tuple of the `ActorCriticOutput` and recurrent hidden state.
        """

        arm2obj_dist = self.relative_dist_embedding_pick(
            observations["relative_agent_arm_to_obj"])
        obj2goal_dist = self.relative_dist_embedding_drop(
            observations["relative_obj_to_goal"])

        perception_embed_pick = self.visual_encoder_pick(observations)
        perception_embed_drop = self.visual_encoder_drop(observations)

        pickup_bool = observations["pickedup_object"]
        before_pickup = pickup_bool == 0  # not used because of our initialization
        after_pickup = pickup_bool == 1
        distances = arm2obj_dist
        distances[after_pickup] = obj2goal_dist[after_pickup]

        perception_embed = perception_embed_pick
        perception_embed[after_pickup] = perception_embed_drop[after_pickup]

        x = [distances, perception_embed]

        x_cat = torch.cat(x, dim=-1)  # type: ignore
        x_out, rnn_hidden_states = self.state_encoder(x_cat,
                                                      memory.tensor("rnn"),
                                                      masks)
        actor_out_pick = self.actor_pick(x_out)
        critic_out_pick = self.critic_pick(x_out)

        actor_out_drop = self.actor_drop(x_out)
        critic_out_drop = self.critic_drop(x_out)

        actor_out = actor_out_pick
        actor_out[after_pickup] = actor_out_drop[after_pickup]
        critic_out = critic_out_pick
        critic_out[after_pickup] = critic_out_drop[after_pickup]

        actor_out = CategoricalDistr(logits=actor_out)
        actor_critic_output = ActorCriticOutput(distributions=actor_out,
                                                values=critic_out,
                                                extras={})
        updated_memory = memory.set_tensor("rnn", rnn_hidden_states)

        return (
            actor_critic_output,
            updated_memory,
        )
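Example #16 switches between its pick and drop heads by boolean-mask assignment rather than convex mixing: the pick outputs are computed everywhere and then overwritten in place wherever the object has already been picked up. A tiny sketch of that selection with hypothetical shapes:

import torch
import torch.nn as nn

steps, samplers, hidden, num_actions = 4, 2, 8, 6
x_out = torch.randn(steps, samplers, hidden)
pickup_bool = torch.randint(2, (steps, samplers))   # 1 once the object is held
after_pickup = pickup_bool == 1

actor_pick = nn.Linear(hidden, num_actions)
actor_drop = nn.Linear(hidden, num_actions)

actor_out = actor_pick(x_out)
# Overwrite the entries whose sampler is already in the "drop" stage.
actor_out[after_pickup] = actor_drop(x_out)[after_pickup]
print(actor_out.shape)   # [4, 2, 6]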
Example #17
    def lower_policy(self, *args, **kwargs):
        assert "higher" in kwargs
        assert "state_embedding" in kwargs
        emb = self.embed_higher(kwargs["higher"])
        logits = self.actor(torch.cat([emb, kwargs["state_embedding"]], dim=-1))
        return CategoricalDistr(logits=logits)
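lower_policy is the second stage of the conditioned head built in Example #4: the sampled higher-level action is embedded and concatenated with the state embedding before producing lower-level logits. A rough self-contained sketch of that two-level sampling path, with hypothetical sizes and plain torch.distributions.Categorical instead of the ConditionalDistr/SequentialDistr machinery:

import torch
import torch.nn as nn
from torch.distributions import Categorical

steps, samplers, hidden = 4, 2, 16
num_higher, num_lower, emb_dim = 3, 5, 8

state_embedding = torch.randn(steps, samplers, hidden)
higher_head = nn.Linear(hidden, num_higher)
embed_higher = nn.Embedding(num_higher, emb_dim)
lower_head = nn.Linear(emb_dim + hidden, num_lower)

# Sample the higher-level ("master") action first ...
higher_action = Categorical(logits=higher_head(state_embedding)).sample()  # [steps, samplers]

# ... then condition the lower-level policy on it, as lower_policy does.
emb = embed_higher(higher_action)                          # [steps, samplers, emb_dim]
lower_logits = lower_head(torch.cat([emb, state_embedding], dim=-1))
lower_action = Categorical(logits=lower_logits).sample()   # [steps, samplers]
print(higher_action.shape, lower_action.shape)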