def forward(self, s, g):
    """Gate privileged goal information through a learned per-element capacity mask.

    Args:
        s: state features tensor.
        g: privileged goal input; when ``None`` the output is purely a
           sample from the prior.

    Returns:
        Convex combination of the sigmoided privileged encoding and a prior
        sample, weighted elementwise by the capacity gate ``d_cap``.
    """
    # No privileged input available -> fall back entirely to the prior.
    # NOTE(review): `prior` is called as a free function here, while the
    # sibling forward() uses `self.prior(s)` -- confirm which is intended.
    if g is None:
        return prior(s).sample()

    priv_logits = self._priv(s, g)
    d_cap_logits = self._d_cap(s)
    d_cap = torch.sigmoid(d_cap_logits)
    priv = torch.sigmoid(priv_logits)
    # d_cap == 1 passes privileged info through; d_cap == 0 substitutes a
    # fresh prior sample, per element.
    x = d_cap * priv + (1 - d_cap) * prior(s).sample()

    if AuxLosses.is_active():
        # NOTE(review): logsigmoid is applied to `priv` (already a sigmoid
        # output), not to `priv_logits` -- verify this is intentional.
        priv_log_prob = F.logsigmoid(priv)
        log_d_cap = F.logsigmoid(d_cap_logits)
        # Information cost on the gate, weighted by self.beta.
        AuxLosses.register_loss(
            "information",
            (
                -d_cap * log_d_cap
                + (1 - d_cap) * (1 + priv_log_prob)
                - (log_d_cap + priv_log_prob)
            ).mean(),
            self.beta,
        )

    return x
def _get_tgt_encoding_hand_code(self, observations, gps_pred, compass_pred):
    """Build the goal embedding from a hand-coded egomotion update.

    The initial pointgoal reading is advanced by the predicted gps/compass
    egomotion. When a ground-truth gps-based pointgoal is present, the
    prediction is kept only where its error is under the threshold.
    """
    start_goal = observations[self.goal_sensor_uuid]
    estimated_goal = self._update_pg(start_goal, gps_pred, compass_pred)

    if "pointgoal_with_gps" not in observations:
        goal = estimated_goal
    else:
        gt_goal = observations["pointgoal_with_gps"]
        ego_error = (estimated_goal - gt_goal).norm(dim=-1, keepdim=True)
        if AuxLosses.is_active():
            AuxLosses.register_loss("egomotion_error", ego_error.mean(), 0.1)
        # Fall back to ground truth wherever the prediction drifted beyond
        # the allowed error.
        keep_prediction = ego_error.detach() < self.ego_error_threshold
        goal = torch.where(keep_prediction, estimated_goal.detach(), gt_goal)

    # Back-propping from the policy into the goal seems kinda odd -- doesn't
    # make sense to move the goal to make the policy more/less likely to
    # predict a given action. We also have prefect supervision on what this
    # should be!
    goal_observations = _to_mag_and_unit_vec(goal.detach())
    return self.tgt_embeding(goal_observations)
def forward(self, s, obs):
    """Encode privileged goal information into a stochastic bottleneck sample.

    Args:
        s: state features tensor.
        obs: observation dict; must contain "pointgoal_with_gps" for the
            privileged path, otherwise a prior sample is returned.

    Returns:
        A reparameterized sample from Normal(mu, softplus(sigma)) computed
        from the state and the embedded privileged pointgoal.
    """
    if "pointgoal_with_gps" not in obs:
        # No privileged signal -> sample from the state-conditioned prior.
        return self.prior(s).sample()

    if self.use_odometry:
        privileged_info = obs["pointgoal_with_gps_compass"]
    else:
        privileged_info = obs["pointgoal_with_gps"]

    # Split the encoder output in half along the last dim: mean and
    # pre-softplus scale.
    mu, sigma = torch.chunk(
        self.encoder(
            torch.cat(
                [
                    s,
                    self.priv_embed(
                        _to_mag_and_unit_vec(privileged_info)
                    ),
                ],
                -1,
            )
        ),
        2,
        s.dim() - 1,
    )

    if not self.use_info_bot:
        # Disable the bottleneck by forcing a fixed Gaussian.
        # NOTE(review): in-place fill_ on a non-leaf tensor severs the
        # gradient path through the encoder -- presumably intentional.
        # Also sigma is filled *before* softplus, so the effective std is
        # softplus(1) ~= 1.31, not 1 -- confirm this is intended.
        mu.fill_(0.0)
        sigma.fill_(1.0)

    sigma = F.softplus(sigma)
    dist = torch.distributions.Normal(mu, sigma)
    x = dist.rsample()
    # The following code block is for running with Selective Noise Injection
    # if self.training:
    #     x = dist.rsample()
    # else:
    #     x = dist.mean

    if AuxLosses.is_active():
        # KL(posterior || prior) is the information cost, weighted by beta.
        AuxLosses.register_loss(
            "information",
            _subsampled_mean(
                torch.distributions.kl_divergence(dist, self.prior(s))
            ),
            # torch.distributions.kl_divergence(dist, self.prior(s)).mean(),
            self.beta,
        )

    return x
def evaluate_actions(
    self, observations, prev_observations, rnn_hidden_states, prev_actions, masks, action
):
    """Score `action` under the current policy.

    Returns:
        (value, action_log_probs, distribution_entropy) for the PPO update.
    """
    features, _ = self.net(
        observations, prev_observations, rnn_hidden_states, prev_actions, masks
    )
    value = self.critic(features)

    if not self.supervise_stop:
        dist = self.action_distribution(features)
        return value, dist.log_probs(action), dist.entropy().mean()

    # Factorized policy: a binary stop head plus a head over non-stop actions.
    stop_dist = self.stop_action_distribution(features)
    move_dist = self.non_stop_action_distribution(features)

    logp_stop = stop_dist.log_probs(torch.full_like(action, 1))
    # Non-stop actions are shifted down by one for the move head; the max
    # keeps the (masked-out) stop rows in range.
    shifted_action = torch.max(action - 1, torch.zeros_like(action))
    logp_move = stop_dist.log_probs(torch.full_like(action, 0)) + move_dist.log_probs(
        shifted_action
    )
    action_log_probs = torch.where(action == 0, logp_stop, logp_move)

    # Entropy of the factorized distribution:
    #   H = -[ p(stop) log p(stop)
    #          + sum_a p(~stop) p(a) (log p(~stop) + log p(a)) ]
    stop_term = stop_dist.probs[:, -1] * stop_dist.logits[:, -1]
    move_term = (
        stop_dist.probs[:, 0:1]
        * move_dist.probs
        * (stop_dist.logits[:, 0:1] + move_dist.logits)
    ).sum(-1)
    distribution_entropy = -1.0 * (stop_term + move_term).mean()

    # Supervise the stop head; class 1 (the stop class, per the log-prob
    # computation above) is down-weighted by 1/sqrt(100) = 0.1.
    stop_loss = F.cross_entropy(
        stop_dist.logits,
        observations["stop_oracle"].long().squeeze(-1),
        weight=torch.tensor(
            [1.0, 1.0 / np.sqrt(100.0)], device=features.device
        ),
    )
    AuxLosses.register_loss("stop_loss", stop_loss)

    return value, action_log_probs, distribution_entropy
def forward(self, s, obs):
    """Combine the parent's privileged embedding with a predicted-goal embedding.

    The gps/compass heads regress egomotion from the state; the initial
    pointgoal is advanced by the predicted gps and embedded. Auxiliary
    losses supervise both heads.
    """
    parent_emb = super().forward(s, obs)

    gps_pred = self.gps_head(s)
    compass_pred = self.compass_head(s)
    goal_pred = _update_pg_gps(obs["pointgoal"], gps_pred)
    goal_emb = self.predicted_embed(_to_mag_and_unit_vec(goal_pred.detach()))

    if AuxLosses.is_active():
        # Weight 0.0 -- presumably tracked for monitoring only.
        AuxLosses.register_loss(
            "egomotion_error",
            torch.norm(goal_pred - obs["pointgoal_with_gps"], dim=-1).mean(),
            0.0,
        )
        AuxLosses.register_loss(
            "compass_loss",
            _subsampled_mean(
                _angular_distance_loss(compass_pred, obs["compass"])
            ),
        )
        AuxLosses.register_loss(
            "gps_loss",
            _subsampled_mean(
                F.mse_loss(gps_pred, obs["gps"], reduction="none").mean(-1)
            ),
        )

    return self.combine_layer(torch.cat([parent_emb, goal_emb], dim=-1))
def forward(self, observations, prev_observations, rnn_hidden_states, prev_actions, masks):
    """Policy-net forward over a (possibly time-batched) rollout.

    Visual features are obtained one of three ways:
      1. precomputed features supplied in `observations`;
      2. an update-time T x N batch, stitching prev/current frames into a
         single (T + 1) * N encoder pass;
      3. a single acting step, concatenating prev and current frames on the
         batch dim.

    Returns:
        (features, updated rnn_hidden_states)
    """
    if AuxLosses.is_active():
        # Expose raw observations for auxiliary-loss computation elsewhere.
        AuxLosses.obs = observations

    depth_flag = False
    rgb_flag = False
    if "depth" in observations:
        depth_flag = True
    if "rgb" in observations:
        rgb_flag = True

    if "visual_features" in observations:
        # Case 1: features already computed (e.g. cached by the storage).
        visual_features = observations["visual_features"]
        prev_visual_features = observations["prev_visual_features"]
    elif masks.size(0) != rnn_hidden_states.size(1):
        # Case 2: update-time batch of T steps x N envs. Only the first
        # previous frame per env is prepended; by construction the "prev"
        # slice [: T*N] and the "current" slice [-T*N :] of the encoder
        # output overlap by (T - 1) * N rows.
        obs_input = {}
        N = rnn_hidden_states.size(1)
        T = masks.size(0) // N
        if depth_flag:
            prev_obs = prev_observations["depth"].view(T, N, *prev_observations["depth"].size()[1:])
            obs = observations["depth"].view(T, N, *observations["depth"].size()[1:])
            obs_input["depth"] = torch.cat((prev_obs[0:1], obs), dim=0)
            obs_input["depth"] = obs_input["depth"].view((T + 1) * N, *obs_input["depth"].size()[2:])
        if rgb_flag:
            prev_obs = prev_observations["rgb"].view(T, N, *prev_observations["rgb"].size()[1:])
            obs = observations["rgb"].view(T, N, *observations["rgb"].size()[1:])
            obs_input["rgb"] = torch.cat((prev_obs[0:1], obs), dim=0)
            obs_input["rgb"] = obs_input["rgb"].view((T + 1) * N, *obs_input["rgb"].size()[2:])
        obs_features = self.visual_encoder(obs_input)
        prev_visual_features = obs_features[:T*N, :, :, :]
        visual_features = obs_features[-T*N:, :, :, :]
    else:
        # Case 3: single acting step; encode prev and current together and
        # split the result in half.
        obs_input = {}
        if depth_flag:
            obs_input["depth"] = torch.cat((prev_observations["depth"], observations["depth"]), dim=0)
        if rgb_flag:
            obs_input["rgb"] = torch.cat((prev_observations["rgb"], observations["rgb"]), dim=0)
        obs_features = self.visual_encoder(obs_input)
        prev_visual_features, visual_features = obs_features.split(obs_features.size()[0] // 2, dim=0)

    visual_features = self.compression(visual_features)
    visual_emb = self.visual_fc(visual_features)
    # difference of frames (unit 1); zeroed at episode starts via masks
    flow_emb = self.visual_flow_encoder(
        (visual_features - self.compression(prev_visual_features))
        * masks.view(-1, 1, 1, 1)
    )
    # Shift action ids by +1 so a masked (episode-start) step maps to 0.
    prev_actions = self.prev_action_embedding(
        ((prev_actions.float() + 1) * masks).long().squeeze(-1)
    )
    context_emb = prev_actions + self._tgt_proj(observations["pointgoal"])
    x, rnn_hidden_states = self.state_encoder(
        torch.cat([visual_emb, flow_emb], dim=-1) + context_emb,
        rnn_hidden_states,
        masks,
    )
    tgt_encoding = self.get_tgt_encoding(observations, x)
    x = torch.cat([x, tgt_encoding], dim=-1)
    x = self.goal_mem_layer(x)

    if AuxLosses.is_active():
        n = rnn_hidden_states.size(1)
        t = int(x.size(0) / n)
        # Supervise per-step egomotion (2D gps delta + compass delta)
        # predicted from the flow embedding.
        delta_ego = self.delta_egomotion_predictor(flow_emb).view(t, n, 3)
        gps_gt = observations["gps"].view(t, n, 2)
        compass_gt = observations["compass"].view(t, n, 1)
        masks = masks.view(t, n, 1)
        gt_delta = gps_gt[1:] - gps_gt[:-1]
        # NOTE(review): presumably rotates the world-frame delta into the
        # previous step's heading frame -- confirm _update_pg_gps_compass.
        gt_delta = _update_pg_gps_compass(
            gt_delta.view((t - 1) * n, 2),
            torch.zeros_like(gt_delta).view((t - 1) * n, 2),
            compass_gt[:-1].view((t - 1) * n, 1),
        ).view(t - 1, n, 2)
        # Episode-start rows (mask 0) are excluded from both losses.
        AuxLosses.register_loss(
            "delta_gps",
            _subsampled_mean(
                torch.masked_select(
                    F.mse_loss(
                        delta_ego[1:, :, 0:2], gt_delta, reduction="none"
                    ).mean(dim=-1),
                    masks[1:, :, 0].bool(),
                )
            ),
        )
        AuxLosses.register_loss(
            "delta_compass",
            _subsampled_mean(
                torch.masked_select(
                    _angular_distance_loss(
                        delta_ego[1:, :, 2:], compass_gt[1:] - compass_gt[:-1]
                    ),
                    masks[1:, :, 0].bool(),
                )
            ),
        )

    return x, rnn_hidden_states
def forward(self, observations, rnn_hidden_states, prev_actions, masks):
    """Single-step policy-net forward; returns (features, new hidden states)."""
    if AuxLosses.is_active():
        AuxLosses.obs = observations

    # Reuse cached visual features when the rollout storage provides them.
    if "visual_features" in observations:
        raw_feats = observations["visual_features"]
    else:
        raw_feats = self.visual_encoder(observations)
    feats = self.compression(raw_feats)
    visual_emb = self.visual_fc(feats)

    # Frame-difference ("flow") embedding, zeroed at episode starts.
    prev_feats = self.compression(observations["prev_visual_features"])
    flow_emb = self.visual_flow_encoder((feats - prev_feats) * masks.view(-1, 1, 1, 1))

    # Shift action ids by +1 so a masked (episode-start) step maps to 0.
    prev_actions = self.prev_action_embedding(
        ((prev_actions.float() + 1) * masks).long().squeeze(-1)
    )
    context_emb = prev_actions + self._tgt_proj(observations["pointgoal"])

    x, rnn_hidden_states = self.state_encoder(
        torch.cat([visual_emb, flow_emb], dim=-1) + context_emb,
        rnn_hidden_states,
        masks,
    )

    tgt_encoding = self.get_tgt_encoding(observations, x)
    x = self.goal_mem_layer(torch.cat([x, tgt_encoding], dim=-1))

    if AuxLosses.is_active():
        n = rnn_hidden_states.size(1)
        t = int(x.size(0) / n)

        # Per-step egomotion (dx, dy, dtheta) predicted from the flow embedding.
        delta_ego = self.delta_egomotion_predictor(flow_emb).view(t, n, 3)
        gps_gt = observations["gps"].view(t, n, 2)
        compass_gt = observations["compass"].view(t, n, 1)
        masks = masks.view(t, n, 1)

        # Ground-truth step displacement; presumably rotated into the
        # previous step's heading frame by _update_pg_gps_compass -- confirm.
        gt_delta = gps_gt[1:] - gps_gt[:-1]
        gt_delta = _update_pg_gps_compass(
            gt_delta.view((t - 1) * n, 2),
            torch.zeros_like(gt_delta).view((t - 1) * n, 2),
            compass_gt[:-1].view((t - 1) * n, 1),
        ).view(t - 1, n, 2)

        # Episode-start rows (mask 0) are excluded from both losses.
        valid = masks[1:, :, 0].bool()
        gps_err = F.mse_loss(
            delta_ego[1:, :, 0:2], gt_delta, reduction="none"
        ).mean(dim=-1)
        AuxLosses.register_loss(
            "delta_gps",
            _subsampled_mean(torch.masked_select(gps_err, valid)),
        )
        compass_err = _angular_distance_loss(
            delta_ego[1:, :, 2:], compass_gt[1:] - compass_gt[:-1]
        )
        AuxLosses.register_loss(
            "delta_compass",
            _subsampled_mean(torch.masked_select(compass_err, valid)),
        )

    return x, rnn_hidden_states
def seq_forward(self, x, context, mems, masks):
    r"""Forward for a sequence of length T

    Args:
        x: (T, N, -1) Tensor that has been flattened to (T * N, -1)
        hidden_states: The starting hidden state.
        masks: The masks to be applied to hidden state at every timestep.
            A (T, N) tensor flatten to (T * N)
    """
    # x is a (T, N, -1) tensor flattened to (T * N, -1)
    n = mems.size(1)
    t = int(x.size(0) / n)

    # unflatten
    x = x.view(t, n, x.size(1))
    context = context.view(t, n, context.size(1))
    masks = masks.view(t, n)

    if self.self_sup:
        # Collect per-episode lengths (episode boundaries are mask == 0)
        # to bound the two-stream prediction offset k.
        ep_lens = []
        for i in range(n):
            last_zero = 0
            has_zeros = ((
                masks[1:-1, i] == 0.0).nonzero().squeeze(-1).cpu().unbind(0))
            for z in has_zeros:
                z = z.item() + 1
                ep_lens.append(z - last_zero)
                last_zero = z
            ep_lens.append(t - last_zero)
        # k is sampled up to ~80% of the mean episode length, capped by
        # max_self_sup_K and floored at 2.
        k = random.randint(
            1,
            max(
                min(self.max_self_sup_K, int(0.8 * np.mean(np.array(ep_lens)))),
                2),
        )
    else:
        k = None

    content, mems, query = self.transformer.transformer_seq_forward(
        x, context, mems, masks, two_stream_k=k)

    if self.self_sup:
        # CPC|A-style contrastive loss: the query should score the true
        # feature (positive) above randomly shuffled features (negatives).
        positives = x
        negatives = []
        for _ in range(3):
            negative_inds = torch.randperm(t * n, device=x.device)
            negatives.append(
                torch.gather(
                    x.view(t * n, -1),
                    dim=0,
                    index=negative_inds.view(t * n, 1).expand(t * n, x.size(-1)),
                ).view(t, n, -1))
        negatives = torch.stack(negatives, dim=-1)

        # Dot each query against the positive and the 3 negatives.
        positives = torch.einsum("...i, ...i -> ...", positives, query)
        negatives = torch.einsum("...ik, ...i -> ...k", negatives, query)
        # NOTE(review): positives.unsqueeze(-1) is (t, n, 1) while negatives
        # is (t, n, 3); torch.stack requires equal shapes, so torch.cat along
        # dim=-1 looks like the intent -- verify before relying on this path.
        cpc_logits = torch.stack([positives.unsqueeze(-1), negatives], dim=-1)

        valid_modeling_queries = torch.ones(t, n, device=query.device, dtype=torch.bool)
        # The first k timesteps have no k-step-back context to predict from.
        valid_modeling_queries[0:k] = 0
        for i in range(n):
            has_zeros_batch = ((
                masks[:, i] == 0.0).nonzero().squeeze(-1).cpu().unbind(0))
            for z in has_zeros_batch:
                # NOTE(review): `z` is unused and the row slice is the
                # constant n:n+k; invalidating the k steps after each episode
                # boundary (z:z+k) looks like the intent -- confirm.
                valid_modeling_queries[n:n + k, i] = 0

        # Class 0 (the positive) is the cross-entropy target.
        # NOTE(review): F.cross_entropy expects classes on dim 1, but
        # cpc_logits has classes on the last dim here -- check the layout.
        cpc_loss = torch.masked_select(
            F.cross_entropy(
                cpc_logits,
                torch.zeros(t, n, dtype=torch.long, device=cpc_logits.device),
                reduction="none",
            ),
            valid_modeling_queries,
        ).mean()
        AuxLosses.register_loss("CPC|A", cpc_loss, 0.2)

    content = content.view(t * n, -1)
    return content, mems