Code Example #1
    def forward(self, s, g):
        if g is None:
            return prior(s).sample()

        priv_logits = self._priv(s, g)
        d_cap_logits = self._d_cap(s)

        d_cap = torch.sigmoid(d_cap_logits)
        priv = torch.sigmoid(priv_logits)
        # Gate between the privileged estimate and a sample from the prior.
        x = d_cap * priv + (1 - d_cap) * prior(s).sample()

        if AuxLosses.is_active():
            priv_log_prob = F.logsigmoid(priv_logits)
            log_d_cap = F.logsigmoid(d_cap_logits)
            AuxLosses.register_loss(
                "information",
                (
                    -d_cap * log_d_cap
                    + (1 - d_cap) * (1 + priv_log_prob)
                    - (log_d_cap + priv_log_prob)
                ).mean(),
                self.beta,
            )

        return x
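
All of these excerpts assume `import torch` and `import torch.nn.functional as F` at module level, and they rely on an `AuxLosses` registry that is never shown. The sketch below is a minimal guess at what such a registry might look like, assuming it simply accumulates named, weighted scalar losses; the method names `activate`, `deactivate`, and `reduce` are my own placeholders, not necessarily the original API.

    import torch


    class AuxLosses:
        """Hypothetical global registry for weighted auxiliary losses (sketch only)."""

        _losses = {}
        _loss_alphas = {}
        _active = False
        obs = None  # some of the examples stash the current observations here

        @classmethod
        def is_active(cls):
            return cls._active

        @classmethod
        def activate(cls):
            cls._losses.clear()
            cls._loss_alphas.clear()
            cls._active = True

        @classmethod
        def deactivate(cls):
            cls._active = False

        @classmethod
        def register_loss(cls, name, loss, alpha=1.0):
            # Store the scalar loss together with its weight; later calls overwrite.
            cls._losses[name] = loss
            cls._loss_alphas[name] = alpha

        @classmethod
        def reduce(cls):
            # Weighted sum of everything registered since the last activate().
            return sum(cls._loss_alphas[k] * v for k, v in cls._losses.items())
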
Code Example #2
    def _get_tgt_encoding_hand_code(self, observations, gps_pred,
                                    compass_pred):
        initial_pg = observations[self.goal_sensor_uuid]
        pred_pg = self._update_pg(initial_pg, gps_pred, compass_pred)

        if "pointgoal_with_gps" in observations:
            true_pg = observations["pointgoal_with_gps"]

            error = (pred_pg - true_pg).norm(dim=-1, keepdim=True)

            if AuxLosses.is_active():
                AuxLosses.register_loss("egomotion_error", error.mean(), 0.1)

            valid_ego_preds = error.detach() < self.ego_error_threshold

            # Use the egomotion-integrated estimate only where it is still within the
            # error threshold; otherwise fall back to the ground-truth point-goal.
            pg = torch.where(valid_ego_preds, pred_pg.detach(), true_pg)
        else:
            pg = pred_pg

        # Back-propping from the policy into the goal seems kind of odd -- it doesn't make
        # sense to move the goal just to make the policy more/less likely to predict a
        # given action. We also have perfect supervision on what this should be!
        goal_observations = _to_mag_and_unit_vec(pg.detach())

        return self.tgt_embeding(goal_observations)
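
`_to_mag_and_unit_vec` is used here and in the later examples but never defined. A minimal sketch, assuming it converts a Cartesian (x, y) point-goal into the usual (distance, unit-direction) encoding for PointNav:

    import torch


    def _to_mag_and_unit_vec(pg: torch.Tensor) -> torch.Tensor:
        # pg: (..., 2) goal vector expressed in the agent's frame.
        mag = torch.norm(pg, dim=-1, keepdim=True)
        # The clamp avoids a division by zero once the agent is on top of the goal.
        unit = pg / torch.clamp(mag, min=1e-5)
        return torch.cat([mag, unit], dim=-1)  # (..., 3): distance plus unit direction
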
Code Example #3
    def forward(self, s, obs):
        priv_emb = super().forward(s, obs)

        gps = self.gps_head(s)
        compass = self.compass_head(s)
        pg = _update_pg_gps(obs["pointgoal"], gps)

        embed_pg = self.predicted_embed(_to_mag_and_unit_vec(pg.detach()))

        if AuxLosses.is_active():
            AuxLosses.register_loss(
                "egomotion_error",
                torch.norm(pg - obs["pointgoal_with_gps"], dim=-1).mean(),
                0.0,
            )
            AuxLosses.register_loss(
                "compass_loss",
                _subsampled_mean(_angular_distance_loss(compass, obs["compass"])),
            )
            AuxLosses.register_loss(
                "gps_loss",
                _subsampled_mean(
                    F.mse_loss(gps, obs["gps"], reduction="none").mean(-1)
                ),
            )

        return self.combine_layer(torch.cat([priv_emb, embed_pg], dim=-1))
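
Two more undefined helpers appear in this example: `_angular_distance_loss` and `_subsampled_mean`. Plausible sketches follow, assuming the former measures wrapped angular error and the latter averages only a random fraction of the per-element losses; the 10% rate is an assumption, not the original setting.

    import torch


    def _angular_distance_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Wrap the error into (-pi, pi] before squaring so that predicting +pi for a
        # target of -pi costs (almost) nothing.
        diff = torch.atan2(torch.sin(pred - target), torch.cos(pred - target))
        return diff.pow(2).mean(dim=-1)


    def _subsampled_mean(per_elem_loss: torch.Tensor, p: float = 0.1) -> torch.Tensor:
        # Average a random subset of elements instead of all of them; cheaper and
        # adds a little gradient noise to the auxiliary objectives.
        keep = torch.rand_like(per_elem_loss) < p
        return per_elem_loss[keep].mean() if keep.any() else per_elem_loss.mean()
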
Code Example #4
    def forward(self, s, obs):
        if "pointgoal_with_gps" not in obs:
            return self.prior(s).sample()
        
        if self.use_odometry:
            privileged_info = obs["pointgoal_with_gps_compass"]
        else:
            privileged_info = obs["pointgoal_with_gps"]

        mu, sigma = torch.chunk(
            self.encoder(
                torch.cat(
                    [
                        s,
                        self.priv_embed(
                            _to_mag_and_unit_vec(privileged_info)
                        ),
                    ],
                    -1,
                )
            ),
            2,
            s.dim() - 1,
        )

        if not self.use_info_bot:
            mu.fill_(0.0)
            sigma.fill_(1.0)

        sigma = F.softplus(sigma)
        dist = torch.distributions.Normal(mu, sigma)

        x = dist.rsample()

        # The following code block is for running with Selective Noise Injection:
        # if self.training:
        #     x = dist.rsample()
        # else:
        #     x = dist.mean

        if AuxLosses.is_active():
            AuxLosses.register_loss(
                "information",
                _subsampled_mean(
                    torch.distributions.kl_divergence(dist, self.prior(s))
                ),
                # torch.distributions.kl_divergence(dist, self.prior(s)).mean(),
                self.beta,
            )

        return x
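
`self.prior(s)` is used both as a fallback sample and as the KL target for the bottleneck. A minimal sketch under the assumption that it is a unit Gaussian sized to the latent; the name `make_prior` and the `hidden_size` argument are placeholders, not the original code:

    import torch


    def make_prior(s: torch.Tensor, hidden_size: int) -> torch.distributions.Normal:
        # Standard Normal prior with one latent vector per batch element, matching
        # the device and dtype of the state features s.
        mu = torch.zeros(s.size(0), hidden_size, device=s.device, dtype=s.dtype)
        return torch.distributions.Normal(mu, torch.ones_like(mu))

Under that assumption, the registered "information" loss is the KL between the posterior over the latent and N(0, I), weighted by `self.beta`; the commented-out block above would instead sample only during training and use the posterior mean at evaluation time.
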
Code Example #5
    def forward(self, observations, prev_observations, rnn_hidden_states, prev_actions, masks):
        if AuxLosses.is_active():
            AuxLosses.obs = observations

        depth_flag = "depth" in observations
        rgb_flag = "rgb" in observations

        if "visual_features" in observations:
            visual_features = observations["visual_features"]
            prev_visual_features = observations["prev_visual_features"]
        
        elif masks.size(0) != rnn_hidden_states.size(1):
            # Update phase: inputs are a rollout of T steps x N envs flattened to
            # (T * N, ...), so build a (T + 1)-frame stack per env and encode it in one pass.
            obs_input = {}
            N = rnn_hidden_states.size(1)
            T = masks.size(0) // N

            if depth_flag:
                prev_obs = prev_observations["depth"].view(T, N, *prev_observations["depth"].size()[1:])
                obs = observations["depth"].view(T, N, *observations["depth"].size()[1:])
                obs_input["depth"] = torch.cat((prev_obs[0:1], obs), dim=0)
                obs_input["depth"] = obs_input["depth"].view((T + 1) * N, *obs_input["depth"].size()[2:])

            if rgb_flag:
                prev_obs = prev_observations["rgb"].view(T, N, *prev_observations["rgb"].size()[1:])
                obs = observations["rgb"].view(T, N, *observations["rgb"].size()[1:])
                obs_input["rgb"] = torch.cat((prev_obs[0:1], obs), dim=0)
                obs_input["rgb"] = obs_input["rgb"].view((T + 1) * N, *obs_input["rgb"].size()[2:])

            obs_features = self.visual_encoder(obs_input)
            prev_visual_features = obs_features[:T*N, :, :, :]
            visual_features = obs_features[-T*N:, :, :, :]
                
        else:
            obs_input = {}

            if depth_flag:
                obs_input["depth"] = torch.cat((prev_observations["depth"], observations["depth"]), dim=0)
            if rgb_flag:
                obs_input["rgb"] = torch.cat((prev_observations["rgb"], observations["rgb"]), dim=0)

            obs_features = self.visual_encoder(obs_input)
            prev_visual_features, visual_features = obs_features.split(obs_features.size()[0] // 2, dim=0)

        visual_features = self.compression(visual_features)

        visual_emb = self.visual_fc(visual_features)
	
        # difference of frames (unit 1); masks zero the difference at episode starts
        flow_emb = self.visual_flow_encoder(
            (visual_features - self.compression(prev_visual_features))
            * masks.view(-1, 1, 1, 1)
        )

        prev_actions = self.prev_action_embedding(
            ((prev_actions.float() + 1) * masks).long().squeeze(-1)
        )

        context_emb = prev_actions + self._tgt_proj(observations["pointgoal"])
        x, rnn_hidden_states = self.state_encoder(
            torch.cat([visual_emb, flow_emb], dim=-1) + context_emb,
            rnn_hidden_states,
            masks,
        )

        tgt_encoding = self.get_tgt_encoding(observations, x)

        x = torch.cat([x, tgt_encoding], dim=-1)
        x = self.goal_mem_layer(x)

        if AuxLosses.is_active():
            n = rnn_hidden_states.size(1)
            t = int(x.size(0) / n)

            delta_ego = self.delta_egomotion_predictor(flow_emb).view(t, n, 3)
            gps_gt = observations["gps"].view(t, n, 2)
            compass_gt = observations["compass"].view(t, n, 1)
            masks = masks.view(t, n, 1)

            gt_delta = gps_gt[1:] - gps_gt[:-1]
            gt_delta = _update_pg_gps_compass(
                gt_delta.view((t - 1) * n, 2),
                torch.zeros_like(gt_delta).view((t - 1) * n, 2),
                compass_gt[:-1].view((t - 1) * n, 1),
            ).view(t - 1, n, 2)
            AuxLosses.register_loss(
                "delta_gps",
                _subsampled_mean(
                    torch.masked_select(
                        F.mse_loss(
                            delta_ego[1:, :, 0:2], gt_delta, reduction="none"
                        ).mean(dim=-1),
                        masks[1:, :, 0].bool(),
                    )
                ),
            )

            AuxLosses.register_loss(
                "delta_compass",
                _subsampled_mean(
                    torch.masked_select(
                        _angular_distance_loss(
                            delta_ego[1:, :, 2:], compass_gt[1:] - compass_gt[:-1]
                        ),
                        masks[1:, :, 0].bool(),
                    )
                ),
            )

        return x, rnn_hidden_states
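
`_update_pg_gps_compass` is called here with a zero translation, so it effectively rotates the ground-truth GPS delta into the agent's previous heading. A sketch under that reading; the exact sign and axis conventions in the original code may differ:

    import torch


    def _update_pg_gps_compass(
        pg: torch.Tensor, gps: torch.Tensor, compass: torch.Tensor
    ) -> torch.Tensor:
        # Express the vector `pg` (global frame) relative to a pose given by a
        # translation `gps` and a heading `compass`: translate, then rotate by -heading.
        delta = pg - gps
        cos_t = torch.cos(compass[..., 0])
        sin_t = torch.sin(compass[..., 0])
        x = cos_t * delta[..., 0] + sin_t * delta[..., 1]
        y = -sin_t * delta[..., 0] + cos_t * delta[..., 1]
        return torch.stack([x, y], dim=-1)
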
Code Example #6
    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
        if AuxLosses.is_active():
            AuxLosses.obs = observations

        if "visual_features" in observations:
            visual_features = observations["visual_features"]
        else:
            visual_features = self.visual_encoder(observations)

        visual_features = self.compression(visual_features)

        visual_emb = self.visual_fc(visual_features)
        # difference of frames; masks zero the difference at episode starts
        flow_emb = self.visual_flow_encoder(
            (visual_features - self.compression(observations["prev_visual_features"]))
            * masks.view(-1, 1, 1, 1)
        )

        prev_actions = self.prev_action_embedding(
            ((prev_actions.float() + 1) * masks).long().squeeze(-1)
        )

        context_emb = prev_actions + self._tgt_proj(observations["pointgoal"])
        x, rnn_hidden_states = self.state_encoder(
            torch.cat([visual_emb, flow_emb], dim=-1) + context_emb,
            rnn_hidden_states,
            masks,
        )

        tgt_encoding = self.get_tgt_encoding(observations, x)

        x = torch.cat([x, tgt_encoding], dim=-1)
        x = self.goal_mem_layer(x)

        if AuxLosses.is_active():
            n = rnn_hidden_states.size(1)
            t = int(x.size(0) / n)

            delta_ego = self.delta_egomotion_predictor(flow_emb).view(t, n, 3)
            gps_gt = observations["gps"].view(t, n, 2)
            compass_gt = observations["compass"].view(t, n, 1)
            masks = masks.view(t, n, 1)

            gt_delta = gps_gt[1:] - gps_gt[:-1]
            gt_delta = _update_pg_gps_compass(
                gt_delta.view((t - 1) * n, 2),
                torch.zeros_like(gt_delta).view((t - 1) * n, 2),
                compass_gt[:-1].view((t - 1) * n, 1),
            ).view(t - 1, n, 2)
            AuxLosses.register_loss(
                "delta_gps",
                _subsampled_mean(
                    torch.masked_select(
                        F.mse_loss(
                            delta_ego[1:, :, 0:2], gt_delta, reduction="none"
                        ).mean(dim=-1),
                        masks[1:, :, 0].bool(),
                    )
                ),
            )

            AuxLosses.register_loss(
                "delta_compass",
                _subsampled_mean(
                    torch.masked_select(
                        _angular_distance_loss(
                            delta_ego[1:, :, 2:], compass_gt[1:] - compass_gt[:-1]
                        ),
                        masks[1:, :, 0].bool(),
                    )
                ),
            )

        return x, rnn_hidden_states
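
None of the excerpts show how the registered auxiliary losses are consumed by the trainer. Reusing the hypothetical `AuxLosses` sketch from after Code Example #1, a toy update step could fold them into the main objective like this; all tensors here are dummies, purely for illustration:

    import torch

    # Activate the registry, run a forward pass that registers losses, then add
    # their weighted sum to the main objective before backpropagating.
    AuxLosses.activate()

    pred = torch.randn(8, 2, requires_grad=True)
    target = torch.zeros(8, 2)
    AuxLosses.register_loss("egomotion_error", torch.norm(pred - target, dim=-1).mean(), 0.1)

    main_loss = pred.pow(2).mean()
    total_loss = main_loss + AuxLosses.reduce()
    total_loss.backward()

    AuxLosses.deactivate()
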