def _evaluate_graph_in_batches(self):
    """Feed the entire sequence graph through the MPN in overlapping frame windows.

    A sliding window is moved over ``self.full_graph.frames``. Each window spans
    ``self.full_graph.frames_per_graph`` consecutive entries of that list, and
    the window advances by one entry per step. For example, with
    frames_per_graph=15 and frames 1, 2, 3, ..., the windows are 1-15, 2-16,
    3-17, and so on. For every window, the subgraph induced by the detections
    whose frame falls inside it is built and passed to ``self._predict_edges``.

    Because windows overlap, an edge usually receives several predictions.
    They are averaged over the windows in which ``_predict_edges`` marked the
    edge as predicted (its returned ``pred_mask``). Edges that never receive a
    prediction end up with a score of 0.

    If the sequence has fewer frames than ``frames_per_graph``, a single window
    spanning all frames is used. Otherwise no window would be evaluated and
    every edge would silently get a score of 0.

    The averaged scores are stored in ``self.full_graph.graph_obj.edge_preds``.
    The graph is then passed to ``to_undirected_graph`` (updating
    ``edge_preds`` and ``edge_labels``) and to ``to_lightweight_graph``.
    """
    # NOTE(review): the device is hard-coded to CUDA. The graph tensors indexed
    # with the CUDA masks below presumably live there too; confirm with caller.
    device = torch.device('cuda')
    graph_obj = self.full_graph.graph_obj

    all_frames = np.array(self.full_graph.frames)
    frame_num_per_node = torch.from_numpy(
        self.full_graph.graph_df.frame.values).to(device)

    # Without clamping, a sequence shorter than the nominal window makes the
    # zip below empty, so no predictions would be made at all.
    frames_per_graph = min(self.full_graph.frames_per_graph, len(all_frames))

    # Running sum of predictions and number of predictions per edge.
    overall_edge_preds = torch.zeros(graph_obj.num_edges, device=device)
    overall_num_preds = overall_edge_preds.clone()

    # Iterate over overlapping windows of (start_frame, end_frame).
    for start_frame, end_frame in zip(all_frames,
                                      all_frames[frames_per_graph - 1:]):
        # Sanity check: the window contains exactly frames_per_graph frames
        # (this fails if self.full_graph.frames is not sorted ascending).
        assert ((start_frame <= all_frames) &
                (all_frames <= end_frame)).sum() == frames_per_graph

        # Build the subgraph induced by the detections inside this window.
        nodes_mask = ((start_frame <= frame_num_per_node) &
                      (frame_num_per_node <= end_frame))
        if not nodes_mask.any():
            # No detections in this window: nothing to predict.
            continue
        edges_mask = (nodes_mask[graph_obj.edge_index[0]] &
                      nodes_mask[graph_obj.edge_index[1]])

        # Index of the first node in the window. Edge indices are shifted by it
        # so they address the subgraph's node rows. This assumes the window's
        # nodes are contiguous in the node ordering.
        node_offset = torch.where(nodes_mask)[0][0]
        subgraph = Graph(
            x=graph_obj.x[nodes_mask],
            edge_attr=graph_obj.edge_attr[edges_mask],
            reid_emb_dists=graph_obj.reid_emb_dists[edges_mask],
            edge_index=graph_obj.edge_index.T[edges_mask].T - node_offset)
        if hasattr(graph_obj, 'edge_labels'):
            subgraph.edge_labels = graph_obj.edge_labels[edges_mask]

        # Predict edge values for the current window.
        edge_preds, pred_mask = self._predict_edges(subgraph=subgraph)

        # Accumulate sum and count over the same set of edges (those flagged by
        # pred_mask) so the average below is consistent.
        predicted_ixs = torch.where(edges_mask)[0][pred_mask]
        overall_edge_preds[predicted_ixs] += edge_preds[pred_mask]
        overall_num_preds[predicted_ixs] += 1

    # Average predictions over all windows. Edges never predicted give
    # 0 / 0 = NaN here and are treated as inactive (score 0).
    final_edge_preds = overall_edge_preds / overall_num_preds
    final_edge_preds[torch.isnan(final_edge_preds)] = 0
    graph_obj.edge_preds = final_edge_preds

    # The original comment says directed-edge pairs are averaged as well;
    # presumably that happens inside to_undirected_graph (not visible here).
    to_undirected_graph(self.full_graph,
                        attrs_to_update=('edge_preds', 'edge_labels'))
    to_lightweight_graph(self.full_graph)
def _evaluate_graph_in_batches(self, subseq_graph, frames_per_graph):
    """Feed a (sub)sequence graph through the MPN in overlapping frame windows.

    A sliding window is moved over ``subseq_graph.frames``. Each window spans
    ``frames_per_graph`` consecutive entries of that list, and the window
    advances by one entry per step. For example, with frames_per_graph=15 and
    frames 1, 2, 3, ..., the windows are 1-15, 2-16, 3-17, and so on. For every
    window, the subgraph induced by the detections whose frame falls inside it
    is built and passed to ``self._predict_edges``. When
    ``self.dataset_params["combined_graph"]`` is truthy, a ``MultiGraph`` is
    built that also contains joint nodes plus bb-joint and joint-joint edges.
    Otherwise a plain ``Graph`` of bounding-box nodes is built.

    Because windows overlap, an edge usually receives several predictions.
    They are averaged over the windows in which ``_predict_edges`` marked the
    edge as predicted (its returned ``pred_mask``). Edges that never receive a
    prediction end up with a score of 0.

    Args:
        subseq_graph: graph wrapper exposing ``frames``, ``graph_df`` (with a
            ``frame`` column, one row per node) and ``graph_obj`` (the tensors
            that are evaluated).
        frames_per_graph: number of consecutive entries of
            ``subseq_graph.frames`` covered by each window. If the subsequence
            has fewer frames, a single window spanning all of them is used
            instead of evaluating nothing.

    Returns:
        The same ``subseq_graph`` object. Its ``graph_obj.edge_preds`` has been
        set, and it has been passed through ``to_undirected_graph`` (updating
        ``edge_preds`` and ``edge_labels``) and ``to_lightweight_graph``.

    Raises:
        ValueError: in combined-graph mode, if ``edge_index_bb-joint`` does not
            have one column per joint node, or ``edge_index_joint-joint`` does
            not have ``joints_per_bb`` columns per box-box edge.
    """
    import warnings

    # NOTE(review): the device is hard-coded to CUDA. The graph tensors indexed
    # with the CUDA masks below presumably live there too; confirm with caller.
    device = torch.device("cuda")
    graph_obj = subseq_graph.graph_obj
    combined_graph = self.dataset_params["combined_graph"]

    all_frames = np.array(subseq_graph.frames)
    frame_num_per_node = torch.from_numpy(
        subseq_graph.graph_df.frame.values).to(device)

    # Without clamping, a subsequence shorter than the nominal window makes the
    # zip below empty, so no predictions would be made at all.
    frames_per_graph = min(frames_per_graph, len(all_frames))

    if combined_graph:
        # Joint nodes per bounding box; constant for the whole graph.
        joints_per_bb = graph_obj.x_joint.shape[0] // graph_obj.x.shape[0]
        if joints_per_bb != 17:
            warnings.warn(
                f"Expected 17 joints per bounding box, got {joints_per_bb} "
                f"({graph_obj.x_joint.shape[0]} joint nodes for "
                f"{graph_obj.x.shape[0]} box nodes).")

    # Running sum of predictions and number of predictions per edge.
    overall_edge_preds = torch.zeros(graph_obj.num_edges, device=device)
    overall_num_preds = overall_edge_preds.clone()

    # Iterate over overlapping windows of (start_frame, end_frame).
    for start_frame, end_frame in zip(all_frames,
                                      all_frames[frames_per_graph - 1:]):
        # Sanity check: the window contains exactly frames_per_graph frames
        # (this fails if subseq_graph.frames is not sorted ascending).
        assert ((start_frame <= all_frames) &
                (all_frames <= end_frame)).sum() == frames_per_graph

        # Select the detections (and edges between them) inside this window.
        nodes_mask = ((start_frame <= frame_num_per_node) &
                      (frame_num_per_node <= end_frame))
        if not nodes_mask.any():
            # No detections in this window: nothing to predict.
            continue
        edges_mask = (nodes_mask[graph_obj.edge_index[0]] &
                      nodes_mask[graph_obj.edge_index[1]])

        # Index of the first node in the window. Edge indices are shifted by it
        # so they address the subgraph's node rows. This assumes the window's
        # nodes are contiguous in the node ordering.
        node_offset = torch.where(nodes_mask)[0][0]

        if combined_graph:
            # Each box's / box-edge's mask entry is repeated joints_per_bb
            # times. This assumes the joints of box i are stored in consecutive
            # rows i*joints_per_bb ... (i+1)*joints_per_bb - 1, and likewise for
            # the joint-joint edges of each box-box edge.
            joint_mask = (nodes_mask.reshape((-1, 1))
                          .repeat((1, joints_per_bb)).reshape(-1))
            joint_edge_mask = (edges_mask.reshape((-1, 1))
                               .repeat((1, joints_per_bb)).reshape(-1))
            # First joint row in the window, used to shift joint indices.
            joint_offset = torch.where(joint_mask)[0][0]

            bb_joint_index = graph_obj["edge_index_bb-joint"]
            if joint_mask.shape[0] != bb_joint_index.shape[1]:
                raise ValueError(
                    f"edge_index_bb-joint has {bb_joint_index.shape[1]} "
                    f"edges, expected one per joint node "
                    f"({joint_mask.shape[0]}).")
            joint_joint_index = graph_obj["edge_index_joint-joint"]
            if joint_edge_mask.shape[0] != joint_joint_index.shape[1]:
                raise ValueError(
                    f"edge_index_joint-joint has {joint_joint_index.shape[1]} "
                    f"edges, expected {joints_per_bb} per box-box edge "
                    f"({joint_edge_mask.shape[0]}).")
            bb_joint_index = bb_joint_index[:, joint_mask]

            xs = {
                "bb": graph_obj.x[nodes_mask],
                "joint": graph_obj.x_joint[joint_mask],
            }
            edge_attrs = {
                "bb-bb": graph_obj.edge_attr[edges_mask],
                "bb-joint": graph_obj["edge_attr_bb-joint"][joint_mask],
                "joint-joint":
                    graph_obj["edge_attr_joint-joint"][joint_edge_mask],
            }
            edge_indices = {
                "bb-bb": graph_obj.edge_index.T[edges_mask].T - node_offset,
                # Row 0 addresses box nodes, row 1 addresses joint nodes.
                "bb-joint": torch.stack((bb_joint_index[0] - node_offset,
                                         bb_joint_index[1] - joint_offset)),
                "joint-joint":
                    joint_joint_index[:, joint_edge_mask] - joint_offset,
            }
            subgraph = MultiGraph(
                node_types=self.graph_model.node_types,
                edge_attrs=edge_attrs,
                edge_indices=edge_indices,
                xs=xs,
                x=graph_obj.x[nodes_mask],
                edge_attr=graph_obj.edge_attr[edges_mask],
                reid_emb_dists=graph_obj.reid_emb_dists[edges_mask],
                edge_index=graph_obj.edge_index.T[edges_mask].T - node_offset,
            )
        else:
            subgraph = Graph(
                x=graph_obj.x[nodes_mask],
                edge_attr=graph_obj.edge_attr[edges_mask],
                reid_emb_dists=graph_obj.reid_emb_dists[edges_mask],
                edge_index=graph_obj.edge_index.T[edges_mask].T - node_offset,
            )
        if hasattr(graph_obj, "edge_labels"):
            subgraph.edge_labels = graph_obj.edge_labels[edges_mask]

        # Predict edge values for the current window.
        edge_preds, pred_mask = self._predict_edges(subgraph=subgraph)

        # Accumulate sum and count over the same set of edges (those flagged by
        # pred_mask) so the average below is consistent.
        predicted_ixs = torch.where(edges_mask)[0][pred_mask]
        overall_edge_preds[predicted_ixs] += edge_preds[pred_mask]
        overall_num_preds[predicted_ixs] += 1

    # Average predictions over all windows. Edges never predicted give
    # 0 / 0 = NaN here and are treated as inactive (score 0).
    final_edge_preds = overall_edge_preds / overall_num_preds
    final_edge_preds[torch.isnan(final_edge_preds)] = 0
    graph_obj.edge_preds = final_edge_preds

    # The original comment says directed-edge pairs are averaged as well;
    # presumably that happens inside to_undirected_graph (not visible here).
    to_undirected_graph(subseq_graph,
                        attrs_to_update=("edge_preds", "edge_labels"))
    to_lightweight_graph(subseq_graph)
    return subseq_graph