def _get_edge_ixs(self, reid_embeddings):
    """
    Build the graph's edge index from temporally valid node pairs.

    Two nodes are connected when they lie in different frames and are at most
    ``self.max_frame_dist`` frames apart. Outside inference mode, and only when
    ``dataset_params["top_k_nns"]`` is set, those candidate edges are further
    pruned with a KNN mask computed over pairwise reid-embedding distances.

    Args:
        reid_embeddings: torch.Tensor with shape (num_nodes, reid_embeds_dim)

    Returns:
        torch.Tensor with shape (2, num_edges)
    """
    frame_nums = torch.from_numpy(self.graph_df.frame.values)
    # GPU is requested only at inference time and never for MOT17-03 frames.
    # NOTE(review): presumably a memory workaround for that sequence -- confirm.
    on_gpu = (
        self.inference_mode
        and "MOT17-03" not in self.graph_df["frame_path"].iloc[0]
    )
    edge_ixs = get_time_valid_conn_ixs(
        frame_num=frame_nums,
        max_frame_dist=self.max_frame_dist,
        use_cuda=on_gpu,
    )

    # At inference, KNN pruning must not happen here: it is computed
    # independently for each sequence chunk.
    if self.inference_mode or self.dataset_params["top_k_nns"] is None:
        return edge_ixs

    src, dst = edge_ixs[0], edge_ixs[1]
    reid_dists = F.pairwise_distance(reid_embeddings[src], reid_embeddings[dst])
    keep = get_knn_mask(
        pwise_dist=reid_dists,
        edge_ixs=edge_ixs,
        num_nodes=self.graph_df.shape[0],
        top_k_nns=self.dataset_params["top_k_nns"],
        reciprocal_k_nns=self.dataset_params["reciprocal_k_nns"],
        symmetric_edges=False,
        use_cuda=self.inference_mode,
    )
    return edge_ixs.T[keep].T
def _get_edge_ixs(self, reid_embeddings):
    """
    Construct graph edges between nodes with valid time connections.

    Candidate edges join nodes that are not in the same frame and are at most
    ``self.max_frame_dist`` frames apart. They are then optionally pruned:

    * by feet velocity, when ``dataset_params["max_feet_vel"]`` is set: an edge
      is kept only if the Euclidean norm of its ``norm_feet_x_dists`` /
      ``norm_feet_y_dists`` features (from ``compute_edge_feats_dict``) is
      strictly below the threshold;
    * by reid-embedding KNNs, when ``dataset_params["top_k_nns"]`` is set and
      we are not in inference mode (at inference, KNN pruning is computed
      independently for sequence chunks).

    Args:
        reid_embeddings: torch.Tensor with shape (num_nodes, reid_embeds_dim)

    Returns:
        Tuple ``(edge_ixs, edge_feats_dict)``. ``edge_ixs`` is a torch.Tensor
        with shape (2, num_edges). ``edge_feats_dict`` maps feature names to
        per-edge tensors kept aligned with ``edge_ixs``, or is None when
        velocity pruning is disabled (no edge features are computed then).
    """

    def _apply_mask(feats, mask):
        # Keep every per-edge feature aligned with the pruned edge_ixs.
        return {name: vals[mask] for name, vals in feats.items()}

    edge_ixs = get_time_valid_conn_ixs(
        frame_num=torch.from_numpy(self.graph_df.frame.values),
        max_frame_dist=self.max_frame_dist,
        # GPU only at inference, and never for MOT17-03 frames.
        # NOTE(review): presumably a memory workaround -- confirm.
        use_cuda=self.inference_mode
        and self.graph_df["frame_path"].iloc[0].find("MOT17-03") == -1,
    )

    edge_feats_dict = None
    max_feet_vel = self.dataset_params.get("max_feet_vel")
    if max_feet_vel is not None:
        # Prune edges whose implied feet displacement is too large.
        edge_feats_dict = compute_edge_feats_dict(
            edge_ixs=edge_ixs,
            det_df=self.graph_df,
            fps=self.seq_info_dict["fps"],
            use_cuda=self.inference_mode,
        )
        feet_vel = torch.sqrt(
            edge_feats_dict["norm_feet_x_dists"] ** 2
            + edge_feats_dict["norm_feet_y_dists"] ** 2
        )
        vel_mask = feet_vel < max_feet_vel
        # edge_ixs was built with a use_cuda flag that differs from the one
        # used for the features (MOT17-03 at inference), so the two may live
        # on different devices; a mask on another device cannot index a
        # tensor. Moving the mask is a no-op when devices already match.
        edge_ixs = edge_ixs.T[vel_mask.to(edge_ixs.device)].T
        edge_feats_dict = _apply_mask(edge_feats_dict, vel_mask)

    # During inference, top k nns must not be done here, as it is computed
    # independently for sequence chunks.
    if not self.inference_mode and self.dataset_params["top_k_nns"] is not None:
        reid_pwise_dist = F.pairwise_distance(
            reid_embeddings[edge_ixs[0]], reid_embeddings[edge_ixs[1]]
        )
        k_nns_mask = get_knn_mask(
            pwise_dist=reid_pwise_dist,
            edge_ixs=edge_ixs,
            num_nodes=self.graph_df.shape[0],
            top_k_nns=self.dataset_params["top_k_nns"],
            reciprocal_k_nns=self.dataset_params["reciprocal_k_nns"],
            symmetric_edges=False,
            use_cuda=False,  # this branch only runs outside inference mode
        )
        edge_ixs = edge_ixs.T[k_nns_mask].T
        if edge_feats_dict is not None:
            edge_feats_dict = _apply_mask(edge_feats_dict, k_nns_mask)

    return edge_ixs, edge_feats_dict
def _predict_edges(self, subgraph):
    """
    Predict edge values for a subgraph (a batch of frames of the sequence).

    The subgraph's edges are first pruned with a KNN mask over
    ``subgraph.reid_emb_dists``. ``subgraph.edge_index``, ``subgraph.edge_attr``
    and (if present) ``subgraph.edge_labels`` are overwritten in place with
    their pruned versions.

    Args:
        subgraph: Graph object corresponding to a subset of frames

    Returns:
        Tuple ``(edge_preds, mask)``. ``edge_preds`` has one entry per edge of
        the subgraph *before* pruning. For edges kept by the KNN mask it holds
        the sigmoid of the model output, or the ground-truth label when
        ``self.use_gt``; pruned edges get 0. ``mask`` is the KNN mask (True for
        kept edges), unless ``eval_params['set_pruned_edges_to_inactive']`` is
        set, in which case it is all True.
    """
    dsp = self.dataset_params
    keep = get_knn_mask(
        pwise_dist=subgraph.reid_emb_dists,
        edge_ixs=subgraph.edge_index,
        num_nodes=subgraph.num_nodes,
        top_k_nns=dsp['top_k_nns'],
        use_cuda=True,
        reciprocal_k_nns=dsp['reciprocal_k_nns'],
        symmetric_edges=True,
    )

    subgraph.edge_index = subgraph.edge_index.T[keep].T
    subgraph.edge_attr = subgraph.edge_attr[keep]
    if hasattr(subgraph, 'edge_labels'):
        subgraph.edge_labels = subgraph.edge_labels[keep]

    if self.use_gt:
        # Oracle / debugging mode: use ground truth instead of the model.
        kept_preds = subgraph.edge_labels
    else:
        with torch.no_grad():
            logits = self.graph_model(subgraph)['classified_edges'][-1]
            kept_preds = torch.sigmoid(logits.view(-1))

    # Scatter the kept-edge predictions back onto the full, pre-pruning edge set.
    edge_preds = torch.zeros(keep.shape[0]).to(kept_preds.device)
    edge_preds[keep] = kept_preds

    # Unless pruned edges are explicitly treated as inactive, pruning an edge
    # counts as not predicting a value for it at all.
    if self.eval_params['set_pruned_edges_to_inactive']:
        out_mask = torch.ones_like(keep)
    else:
        out_mask = keep
    return edge_preds, out_mask
def _predict_edges(self, subgraph):
    """
    Predict edge values for a subgraph (i.e. batch of frames) from the entire sequence.

    The subgraph's edges are first pruned in place with a KNN mask computed
    over ``subgraph.reid_emb_dists``. The mask is applied to ``edge_index``,
    ``edge_attr`` and (if present) ``edge_labels``. When the subgraph also
    carries ``"bb-bb"`` and ``"joint-joint"`` edge sets, the same mask is
    applied to the bb-bb edges and expanded to the joint-joint edges.

    Args:
        subgraph: Graph object corresponding to a subset of frames

    Returns:
        Tuple ``(edge_preds, mask)``. ``edge_preds`` has one entry per edge of
        the subgraph *before* pruning. For edges kept by the KNN mask it holds
        the sigmoid of the model output, or the ground-truth label when
        ``self.use_gt``; pruned edges get 0. ``mask`` is the KNN mask (True for
        kept edges), unless ``eval_params["set_pruned_edges_to_inactive"]`` is
        set, in which case it is all True.
    """
    # Prune graph edges
    knn_mask = get_knn_mask(
        pwise_dist=subgraph.reid_emb_dists,
        edge_ixs=subgraph.edge_index,
        num_nodes=subgraph.num_nodes,
        top_k_nns=self.dataset_params["top_k_nns"],
        use_cuda=True,
        reciprocal_k_nns=self.dataset_params["reciprocal_k_nns"],
        symmetric_edges=True,
    )
    subgraph.edge_index = subgraph.edge_index.T[knn_mask].T
    subgraph.edge_attr = subgraph.edge_attr[knn_mask]
    if hasattr(subgraph, "edge_labels"):
        subgraph.edge_labels = subgraph.edge_labels[knn_mask]

    if hasattr(subgraph, "edge_index_bb-bb"):
        # Joint-joint edges per bb-bb edge. This must be computed before the
        # bb-bb edges are pruned. An empty bb-bb edge set would otherwise
        # raise ZeroDivisionError.
        num_bb_edges = subgraph["edge_index_bb-bb"].shape[1]
        num_joint_edges = subgraph["edge_index_joint-joint"].shape[1]
        joints_per_bb = num_joint_edges // num_bb_edges if num_bb_edges else 0

        subgraph["edge_index_bb-bb"] = subgraph["edge_index_bb-bb"].T[knn_mask].T
        subgraph["edge_attr_bb-bb"] = subgraph["edge_attr_bb-bb"][knn_mask]

        # Repeat each bb-level mask entry joints_per_bb times, consecutively.
        # NOTE(review): assumes joint-joint edges are stored in contiguous
        # groups of joints_per_bb, one group per bb-bb edge, in bb-bb order.
        joint_knn_mask = knn_mask.repeat_interleave(joints_per_bb)
        subgraph["edge_index_joint-joint"] = (
            subgraph["edge_index_joint-joint"].T[joint_knn_mask].T
        )
        subgraph["edge_attr_joint-joint"] = subgraph["edge_attr_joint-joint"][
            joint_knn_mask
        ]

    # Predict active edges
    if self.use_gt:
        # For debugging purposes and obtaining oracle results
        pruned_edge_preds = subgraph.edge_labels
    else:
        with torch.no_grad():
            result = self.graph_model(subgraph)["classified_edges"][-1]
            # When the model returns a dict of per-edge-type outputs, only the
            # "bb-bb" entry is scored here.
            if isinstance(result, dict):
                result = result["bb-bb"]
            pruned_edge_preds = torch.sigmoid(result.view(-1))

    # Scatter kept-edge predictions back onto the full, pre-pruning edge set.
    edge_preds = torch.zeros(knn_mask.shape[0], device=pruned_edge_preds.device)
    edge_preds[knn_mask] = pruned_edge_preds

    if self.eval_params["set_pruned_edges_to_inactive"]:
        return edge_preds, torch.ones_like(knn_mask)
    # In this case, pruning an edge counts as not predicting a value for it at all
    return edge_preds, knn_mask