def forward(self, s, g):
    if g is None:
        return prior(s).sample()

    priv_logits = self._priv(s, g)
    d_cap_logits = self._d_cap(s)

    d_cap = torch.sigmoid(d_cap_logits)
    priv = torch.sigmoid(priv_logits)

    # Gate between the privileged encoding and a sample from the prior.
    x = d_cap * priv + (1 - d_cap) * prior(s).sample()

    if AuxLosses.is_active():
        priv_log_prob = F.logsigmoid(priv)
        log_d_cap = F.logsigmoid(d_cap_logits)

        AuxLosses.register_loss(
            "information",
            (
                -d_cap * log_d_cap
                + (1 - d_cap) * (1 + priv_log_prob)
                - (log_d_cap + priv_log_prob)
            ).mean(),
            self.beta,
        )

    return x
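# All of the forward passes in this section report their regularizers through
# the AuxLosses registry defined elsewhere in the repo. The class below is only
# a minimal sketch of the interface those calls assume (is_active, register_loss,
# the obs attribute, and a reduce step in the trainer); its internals are an
# assumption, not the repo's actual implementation.
import torch


class AuxLosses:
    _losses = {}
    _loss_alphas = {}
    _is_active = False
    obs = None  # populated at the top of the policy forward passes below

    @classmethod
    def is_active(cls):
        return cls._is_active

    @classmethod
    def activate(cls):
        cls._is_active = True
        cls._losses.clear()
        cls._loss_alphas.clear()

    @classmethod
    def deactivate(cls):
        cls._is_active = False

    @classmethod
    def register_loss(cls, name, loss, alpha=1.0):
        # Store the scalar loss and its weight; the trainer sums
        # alpha * loss over everything registered during the update.
        cls._losses[name] = loss
        cls._loss_alphas[name] = alpha

    @classmethod
    def reduce(cls):
        return torch.stack(
            [alpha * cls._losses[name] for name, alpha in cls._loss_alphas.items()]
        ).sum()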
def _get_tgt_encoding_hand_code(self, observations, gps_pred, compass_pred):
    initial_pg = observations[self.goal_sensor_uuid]
    pred_pg = self._update_pg(initial_pg, gps_pred, compass_pred)

    if "pointgoal_with_gps" in observations:
        true_pg = observations["pointgoal_with_gps"]
        error = (pred_pg - true_pg).norm(dim=-1, keepdim=True)

        if AuxLosses.is_active():
            AuxLosses.register_loss("egomotion_error", error.mean(), 0.1)

        valid_ego_preds = error.detach() < self.ego_error_threshold
        pg = torch.where(valid_ego_preds, pred_pg.detach(), true_pg)
    else:
        pg = pred_pg

    # Back-propping from the policy into the goal seems kinda odd -- it doesn't
    # make sense to move the goal to make the policy more/less likely to predict
    # a given action. We also have perfect supervision on what this should be!
    goal_observations = _to_mag_and_unit_vec(pg.detach())

    return self.tgt_embeding(goal_observations)
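# _to_mag_and_unit_vec is defined elsewhere in the repo. A plausible sketch,
# assuming the goal is a Cartesian vector on the last dimension: split it into
# its magnitude and unit direction so the embedding layer does not have to
# learn to normalize distances itself. The exact definition may differ.
import torch


def _to_mag_and_unit_vec(pg, eps=1e-5):
    # (..., D) -> (..., 1 + D): [ |pg|, pg / |pg| ]
    mag = torch.norm(pg, dim=-1, keepdim=True)
    return torch.cat([mag, pg / (mag + eps)], dim=-1)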
def forward(self, s, obs):
    priv_emb = super().forward(s, obs)

    gps = self.gps_head(s)
    compass = self.compass_head(s)

    pg = _update_pg_gps(obs["pointgoal"], gps)
    embed_pg = self.predicted_embed(_to_mag_and_unit_vec(pg.detach()))

    if AuxLosses.is_active():
        AuxLosses.register_loss(
            "egomotion_error",
            torch.norm(pg - obs["pointgoal_with_gps"], dim=-1).mean(),
            0.0,
        )
        AuxLosses.register_loss(
            "compass_loss",
            _subsampled_mean(_angular_distance_loss(compass, obs["compass"])),
        )
        AuxLosses.register_loss(
            "gps_loss",
            _subsampled_mean(
                F.mse_loss(gps, obs["gps"], reduction="none").mean(-1)
            ),
        )

    return self.combine_layer(torch.cat([priv_emb, embed_pg], dim=-1))
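# Hedged sketches of two helpers used by the auxiliary losses above; both are
# assumptions about their behavior rather than the repo's exact definitions.
# _subsampled_mean averages a random subset of the per-element losses to keep
# the auxiliary terms cheap; _angular_distance_loss penalizes the wrapped
# angular difference so that, e.g., -pi and pi count as the same heading.
import torch


def _subsampled_mean(x, p=0.1):
    keep = torch.rand_like(x) < p
    # Fall back to the full mean if the random subsample is empty.
    return x[keep].mean() if keep.any() else x.mean()


def _angular_distance_loss(pred, target):
    diff = torch.atan2(torch.sin(pred - target), torch.cos(pred - target))
    return diff.pow(2).mean(dim=-1)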
def forward(self, s, obs):
    if "pointgoal_with_gps" not in obs:
        return self.prior(s).sample()

    if self.use_odometry:
        privileged_info = obs["pointgoal_with_gps_compass"]
    else:
        privileged_info = obs["pointgoal_with_gps"]

    mu, sigma = torch.chunk(
        self.encoder(
            torch.cat(
                [
                    s,
                    self.priv_embed(_to_mag_and_unit_vec(privileged_info)),
                ],
                -1,
            )
        ),
        2,
        s.dim() - 1,
    )

    if not self.use_info_bot:
        mu.fill_(0.0)
        sigma.fill_(1.0)

    sigma = F.softplus(sigma)
    dist = torch.distributions.Normal(mu, sigma)

    x = dist.rsample()
    # The following code block is for running with Selective Noise Injection:
    # if self.training:
    #     x = dist.rsample()
    # else:
    #     x = dist.mean

    if AuxLosses.is_active():
        AuxLosses.register_loss(
            "information",
            _subsampled_mean(
                torch.distributions.kl_divergence(dist, self.prior(s))
            ),
            # torch.distributions.kl_divergence(dist, self.prior(s)).mean(),
            self.beta,
        )

    return x
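# self.prior(s) above is a distribution head defined elsewhere; the information
# loss is the beta-weighted KL between Normal(mu, softplus(sigma)) and that
# prior. A minimal sketch, assuming a fixed unit Gaussian over the latent
# dimension (the repo may instead condition the prior on s):
import torch
import torch.nn as nn


class UnitGaussianPrior(nn.Module):
    """Hypothetical prior head: a fixed N(0, I) over the latent dims."""

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

    def forward(self, s):
        loc = torch.zeros(*s.size()[:-1], self.latent_dim, device=s.device)
        return torch.distributions.Normal(loc, torch.ones_like(loc))


# With a prior of this form, beta trades off task performance against how much
# privileged goal information is allowed to leak through the sampled latent x.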
def forward(self, observations, prev_observations, rnn_hidden_states, prev_actions, masks):
    if AuxLosses.is_active():
        AuxLosses.obs = observations

    depth_flag = "depth" in observations
    rgb_flag = "rgb" in observations

    if "visual_features" in observations:
        visual_features = observations["visual_features"]
        prev_visual_features = observations["prev_visual_features"]
    elif masks.size(0) != rnn_hidden_states.size(1):
        # Update pass: the batch is T steps for each of N environments.
        obs_input = {}
        N = rnn_hidden_states.size(1)
        T = masks.size(0) // N

        if depth_flag:
            prev_obs = prev_observations["depth"].view(T, N, *prev_observations["depth"].size()[1:])
            obs = observations["depth"].view(T, N, *observations["depth"].size()[1:])
            # Prepend the frame preceding step 0 so all T + 1 distinct frames are encoded once.
            obs_input["depth"] = torch.cat((prev_obs[0:1], obs), dim=0)
            obs_input["depth"] = obs_input["depth"].view((T + 1) * N, *obs_input["depth"].size()[2:])

        if rgb_flag:
            prev_obs = prev_observations["rgb"].view(T, N, *prev_observations["rgb"].size()[1:])
            obs = observations["rgb"].view(T, N, *observations["rgb"].size()[1:])
            obs_input["rgb"] = torch.cat((prev_obs[0:1], obs), dim=0)
            obs_input["rgb"] = obs_input["rgb"].view((T + 1) * N, *obs_input["rgb"].size()[2:])

        obs_features = self.visual_encoder(obs_input)
        prev_visual_features = obs_features[:T * N, :, :, :]
        visual_features = obs_features[-T * N:, :, :, :]
    else:
        # Acting pass: encode the previous and current frames in one batch.
        obs_input = {}
        if depth_flag:
            obs_input["depth"] = torch.cat((prev_observations["depth"], observations["depth"]), dim=0)
        if rgb_flag:
            obs_input["rgb"] = torch.cat((prev_observations["rgb"], observations["rgb"]), dim=0)

        obs_features = self.visual_encoder(obs_input)
        prev_visual_features, visual_features = obs_features.split(obs_features.size()[0] // 2, dim=0)

    visual_features = self.compression(visual_features)
    visual_emb = self.visual_fc(visual_features)

    # Difference of frames (unit 1), zeroed at episode boundaries by masks.
    flow_emb = self.visual_flow_encoder(
        (visual_features - self.compression(prev_visual_features))
        * masks.view(-1, 1, 1, 1)
    )

    prev_actions = self.prev_action_embedding(
        ((prev_actions.float() + 1) * masks).long().squeeze(-1)
    )
    context_emb = prev_actions + self._tgt_proj(observations["pointgoal"])

    x, rnn_hidden_states = self.state_encoder(
        torch.cat([visual_emb, flow_emb], dim=-1) + context_emb,
        rnn_hidden_states,
        masks,
    )

    tgt_encoding = self.get_tgt_encoding(observations, x)
    x = torch.cat([x, tgt_encoding], dim=-1)
    x = self.goal_mem_layer(x)

    if AuxLosses.is_active():
        n = rnn_hidden_states.size(1)
        t = int(x.size(0) / n)

        delta_ego = self.delta_egomotion_predictor(flow_emb).view(t, n, 3)

        gps_gt = observations["gps"].view(t, n, 2)
        compass_gt = observations["compass"].view(t, n, 1)
        masks = masks.view(t, n, 1)

        gt_delta = gps_gt[1:] - gps_gt[:-1]
        gt_delta = _update_pg_gps_compass(
            gt_delta.view((t - 1) * n, 2),
            torch.zeros_like(gt_delta).view((t - 1) * n, 2),
            compass_gt[:-1].view((t - 1) * n, 1),
        ).view(t - 1, n, 2)

        AuxLosses.register_loss(
            "delta_gps",
            _subsampled_mean(
                torch.masked_select(
                    F.mse_loss(
                        delta_ego[1:, :, 0:2], gt_delta, reduction="none"
                    ).mean(dim=-1),
                    masks[1:, :, 0].bool(),
                )
            ),
        )
        AuxLosses.register_loss(
            "delta_compass",
            _subsampled_mean(
                torch.masked_select(
                    _angular_distance_loss(
                        delta_ego[1:, :, 2:], compass_gt[1:] - compass_gt[:-1]
                    ),
                    masks[1:, :, 0].bool(),
                )
            ),
        )

    return x, rnn_hidden_states
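# _update_pg_gps_compass is defined elsewhere; above it is called with a zero
# GPS argument, so only the rotation matters. One plausible reading (an
# assumption, not the repo's definition): translate a world-frame vector by
# the agent's GPS position and rotate it by the negative compass heading so
# the result is expressed in the agent's egocentric frame.
import torch


def _update_pg_gps_compass(pg, gps, compass):
    rel = pg - gps
    cos_t = torch.cos(-compass[..., 0])
    sin_t = torch.sin(-compass[..., 0])
    x = cos_t * rel[..., 0] - sin_t * rel[..., 1]
    y = sin_t * rel[..., 0] + cos_t * rel[..., 1]
    return torch.stack([x, y], dim=-1)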
def forward(self, observations, rnn_hidden_states, prev_actions, masks):
    if AuxLosses.is_active():
        AuxLosses.obs = observations

    if "visual_features" in observations:
        visual_features = observations["visual_features"]
    else:
        visual_features = self.visual_encoder(observations)

    visual_features = self.compression(visual_features)
    visual_emb = self.visual_fc(visual_features)

    flow_emb = self.visual_flow_encoder(
        (visual_features - self.compression(observations["prev_visual_features"]))
        * masks.view(-1, 1, 1, 1)
    )

    prev_actions = self.prev_action_embedding(
        ((prev_actions.float() + 1) * masks).long().squeeze(-1)
    )
    context_emb = prev_actions + self._tgt_proj(observations["pointgoal"])

    x, rnn_hidden_states = self.state_encoder(
        torch.cat([visual_emb, flow_emb], dim=-1) + context_emb,
        rnn_hidden_states,
        masks,
    )

    tgt_encoding = self.get_tgt_encoding(observations, x)
    x = torch.cat([x, tgt_encoding], dim=-1)
    x = self.goal_mem_layer(x)

    if AuxLosses.is_active():
        n = rnn_hidden_states.size(1)
        t = int(x.size(0) / n)

        delta_ego = self.delta_egomotion_predictor(flow_emb).view(t, n, 3)

        gps_gt = observations["gps"].view(t, n, 2)
        compass_gt = observations["compass"].view(t, n, 1)
        masks = masks.view(t, n, 1)

        gt_delta = gps_gt[1:] - gps_gt[:-1]
        gt_delta = _update_pg_gps_compass(
            gt_delta.view((t - 1) * n, 2),
            torch.zeros_like(gt_delta).view((t - 1) * n, 2),
            compass_gt[:-1].view((t - 1) * n, 1),
        ).view(t - 1, n, 2)

        AuxLosses.register_loss(
            "delta_gps",
            _subsampled_mean(
                torch.masked_select(
                    F.mse_loss(
                        delta_ego[1:, :, 0:2], gt_delta, reduction="none"
                    ).mean(dim=-1),
                    masks[1:, :, 0].bool(),
                )
            ),
        )
        AuxLosses.register_loss(
            "delta_compass",
            _subsampled_mean(
                torch.masked_select(
                    _angular_distance_loss(
                        delta_ego[1:, :, 2:], compass_gt[1:] - compass_gt[:-1]
                    ),
                    masks[1:, :, 0].bool(),
                )
            ),
        )

    return x, rnn_hidden_states