コード例 #1
0
ファイル: mot_graph.py プロジェクト: songheony/AAA-multi
    def _get_edge_ixs(self, reid_embeddings):
        """
        Build the edge index of the graph from temporally valid node pairs, optionally pruned with reid KNN.

        Candidate edges come from ``get_time_valid_conn_ixs`` over the detections' frame numbers and
        ``self.max_frame_dist`` (pairs of nodes not in the same frame and not too far apart in time).
        Outside of inference mode, and only when ``dataset_params["top_k_nns"]`` is not None, the candidates
        are further filtered with ``get_knn_mask`` using the pairwise distance between the reid embeddings
        of each edge's two endpoints.

        Args:
            reid_embeddings: torch.tensor with shape (num_nodes, reid_embeds_dim)

        Returns:
            torch.tensor with shape (2, num_edges)
        """
        frame_nums = torch.from_numpy(self.graph_df.frame.values)
        # Only use the GPU in inference mode, and never for the MOT17-03 sequence (hard-coded exception; the
        # reason is not visible here). The short-circuit means 'frame_path' is only read in inference mode.
        on_gpu = (self.inference_mode
                  and "MOT17-03" not in self.graph_df["frame_path"].iloc[0])
        edge_ixs = get_time_valid_conn_ixs(
            frame_num=frame_nums,
            max_frame_dist=self.max_frame_dist,
            use_cuda=on_gpu,
        )

        # In inference mode, top k nns are computed independently for each sequence chunk, not here.
        if self.inference_mode:
            return edge_ixs

        top_k = self.dataset_params["top_k_nns"]
        if top_k is None:
            return edge_ixs

        src_embeds = reid_embeddings[edge_ixs[0]]
        dst_embeds = reid_embeddings[edge_ixs[1]]
        keep = get_knn_mask(
            pwise_dist=F.pairwise_distance(src_embeds, dst_embeds),
            edge_ixs=edge_ixs,
            num_nodes=self.graph_df.shape[0],
            top_k_nns=top_k,
            reciprocal_k_nns=self.dataset_params["reciprocal_k_nns"],
            symmetric_edges=False,
            use_cuda=self.inference_mode,
        )
        # Select the columns (edges) of the (2, num_edges) index whose mask entry is set.
        return edge_ixs.T[keep].T
コード例 #2
0
    def _get_edge_ixs(self, reid_embeddings):
        """
        Constructs graph edges by taking pairs of nodes with valid time connections (not in same frame, not too far
        apart in time), optionally pruning them by feet velocity and by KNNs according to reid embeddings.

        Velocity pruning runs only when ``dataset_params["max_feet_vel"]`` is present and not None. It keeps
        edges whose magnitude of the ``norm_feet_x_dists`` / ``norm_feet_y_dists`` features (from
        ``compute_edge_feats_dict``) is strictly below that threshold. KNN pruning runs only outside inference
        mode and when ``dataset_params["top_k_nns"]`` is not None.

        Args:
            reid_embeddings: torch.tensor with shape (num_nodes, reid_embeds_dim)

        Returns:
            tuple ``(edge_ixs, edge_feats_dict)``:
                edge_ixs: torch.tensor with shape (2, num_edges)
                edge_feats_dict: dict mapping feature name -> tensor indexed per edge, filtered with the same
                    masks as ``edge_ixs``; ``None`` when velocity pruning is disabled.
        """

        def _keep_edges(mask, ixs, feats):
            # Apply one boolean edge mask to both the (2, num_edges) index and every edge feature.
            ixs = ixs.T[mask].T
            if feats is not None:
                feats = {name: vals[mask] for name, vals in feats.items()}
            return ixs, feats

        edge_ixs = get_time_valid_conn_ixs(
            frame_num=torch.from_numpy(self.graph_df.frame.values),
            max_frame_dist=self.max_frame_dist,
            # GPU only in inference mode, and never for MOT17-03 (hard-coded exception; reason not visible here).
            use_cuda=self.inference_mode
            and "MOT17-03" not in self.graph_df["frame_path"].iloc[0],
        )

        edge_feats_dict = None
        max_feet_vel = self.dataset_params.get("max_feet_vel")
        if max_feet_vel is not None:  # Optional parameter: prune graph edges based on feet velocity
            edge_feats_dict = compute_edge_feats_dict(
                edge_ixs=edge_ixs,
                det_df=self.graph_df,
                fps=self.seq_info_dict["fps"],
                use_cuda=self.inference_mode,
            )
            feet_vel = torch.sqrt(
                edge_feats_dict["norm_feet_x_dists"] ** 2
                + edge_feats_dict["norm_feet_y_dists"] ** 2
            )
            edge_ixs, edge_feats_dict = _keep_edges(
                feet_vel < max_feet_vel, edge_ixs, edge_feats_dict
            )

        # During inference, top k nns must not be done here, as it is computed independently for sequence chunks
        if not self.inference_mode and self.dataset_params["top_k_nns"] is not None:
            reid_pwise_dist = F.pairwise_distance(
                reid_embeddings[edge_ixs[0]], reid_embeddings[edge_ixs[1]]
            )
            k_nns_mask = get_knn_mask(
                pwise_dist=reid_pwise_dist,
                edge_ixs=edge_ixs,
                num_nodes=self.graph_df.shape[0],
                top_k_nns=self.dataset_params["top_k_nns"],
                reciprocal_k_nns=self.dataset_params["reciprocal_k_nns"],
                symmetric_edges=False,
                use_cuda=self.inference_mode,
            )
            edge_ixs, edge_feats_dict = _keep_edges(
                k_nns_mask, edge_ixs, edge_feats_dict
            )

        return edge_ixs, edge_feats_dict
コード例 #3
0
    def _predict_edges(self, subgraph):
        """
        Predicts edge values for a subgraph (i.e. batch of frames) from the entire sequence.

        Edges are first pruned with ``get_knn_mask`` over ``subgraph.reid_emb_dists``. The pruned subgraph is then
        scored by ``self.graph_model``, or, when ``self.use_gt`` is set, its ``edge_labels`` are used as
        predictions. ``subgraph`` is modified in place: ``edge_index``, ``edge_attr`` and, if present,
        ``edge_labels`` are replaced by their pruned versions.

        Args:
            subgraph: Graph Object corresponding to a subset of frames

        Returns:
            tuple ``(edge_preds, mask)``:
                edge_preds: torch.Tensor with one value per edge of the *unpruned* subgraph; edges removed by
                    KNN pruning are set to 0.
                mask: if ``eval_params['set_pruned_edges_to_inactive']`` is set, a tensor of ones shaped like
                    the KNN mask (pruned edges count as predicted inactive); otherwise the KNN mask itself, which
                    is True for the edges that were kept (pruned edges count as not predicted at all).
        """
        # Prune graph edges
        knn_mask = get_knn_mask(
            pwise_dist=subgraph.reid_emb_dists,
            edge_ixs=subgraph.edge_index,
            num_nodes=subgraph.num_nodes,
            top_k_nns=self.dataset_params['top_k_nns'],
            use_cuda=True,
            reciprocal_k_nns=self.dataset_params['reciprocal_k_nns'],
            symmetric_edges=True)
        subgraph.edge_index = subgraph.edge_index.T[knn_mask].T
        subgraph.edge_attr = subgraph.edge_attr[knn_mask]
        if hasattr(subgraph, 'edge_labels'):
            subgraph.edge_labels = subgraph.edge_labels[knn_mask]

        # Predict active edges
        if self.use_gt:  # For debugging purposes and obtaining oracle results
            pruned_edge_preds = subgraph.edge_labels

        else:
            with torch.no_grad():
                pruned_edge_preds = torch.sigmoid(
                    self.graph_model(subgraph)['classified_edges'][-1].view(
                        -1))

        # Scatter the predictions back onto the full (unpruned) edge set. The buffer is allocated directly on the
        # predictions' device instead of being built on the host and copied over.
        edge_preds = torch.zeros(knn_mask.shape[0],
                                 device=pruned_edge_preds.device)
        edge_preds[knn_mask] = pruned_edge_preds

        if self.eval_params['set_pruned_edges_to_inactive']:
            return edge_preds, torch.ones_like(knn_mask)

        else:
            return edge_preds, knn_mask  # In this case, pruning an edge counts as not predicting a value for it at all
コード例 #4
0
    def _predict_edges(self, subgraph):
        """
        Predicts edge values for a subgraph (i.e. batch of frames) from the entire sequence.

        Edges are first pruned with ``get_knn_mask`` over ``subgraph.reid_emb_dists``. If the subgraph also carries
        ``edge_index_bb-bb`` / ``edge_index_joint-joint`` edge sets, those are pruned with the same mask. The
        pruned subgraph is then scored by ``self.graph_model``, or, when ``self.use_gt`` is set, its
        ``edge_labels`` are used as predictions. ``subgraph`` is modified in place.

        Args:
            subgraph: Graph Object corresponding to a subset of frames

        Returns:
            tuple ``(edge_preds, mask)``:
                edge_preds: torch.Tensor with one value per edge of the *unpruned* subgraph; edges removed by
                    KNN pruning are set to 0.
                mask: if ``eval_params["set_pruned_edges_to_inactive"]`` is set, a tensor of ones shaped like
                    the KNN mask (pruned edges count as predicted inactive); otherwise the KNN mask itself, which
                    is True for the edges that were kept (pruned edges count as not predicted at all).
        """
        # Prune graph edges
        knn_mask = get_knn_mask(
            pwise_dist=subgraph.reid_emb_dists,
            edge_ixs=subgraph.edge_index,
            num_nodes=subgraph.num_nodes,
            top_k_nns=self.dataset_params["top_k_nns"],
            use_cuda=True,
            reciprocal_k_nns=self.dataset_params["reciprocal_k_nns"],
            symmetric_edges=True,
        )
        subgraph.edge_index = subgraph.edge_index.T[knn_mask].T
        subgraph.edge_attr = subgraph.edge_attr[knn_mask]
        if hasattr(subgraph, "edge_labels"):
            subgraph.edge_labels = subgraph.edge_labels[knn_mask]

        if hasattr(subgraph, "edge_index_bb-bb"):
            # Computed before pruning, while both edge sets still have their full sizes.
            num_bb_edges = subgraph["edge_index_bb-bb"].shape[1]
            num_joint_edges = subgraph["edge_index_joint-joint"].shape[1]
            # Guard the division so that an edgeless subgraph does not raise ZeroDivisionError.
            joints_per_bb = num_joint_edges // num_bb_edges if num_bb_edges > 0 else 0

            subgraph["edge_index_bb-bb"] = subgraph["edge_index_bb-bb"].T[
                knn_mask].T
            subgraph["edge_attr_bb-bb"] = subgraph["edge_attr_bb-bb"][knn_mask]

            # Repeat every bb-bb mask entry joints_per_bb times. This assumes joint-joint edges are stored as
            # contiguous runs of joints_per_bb edges, in the same order as the bb-bb edges (TODO confirm).
            joint_knn_mask = (knn_mask.reshape((-1, 1)).repeat(
                (1, joints_per_bb)).reshape((-1)))

            subgraph["edge_index_joint-joint"] = (
                subgraph["edge_index_joint-joint"].T[joint_knn_mask].T)
            subgraph["edge_attr_joint-joint"] = subgraph[
                "edge_attr_joint-joint"][joint_knn_mask]

        # Predict active edges
        if self.use_gt:  # For debugging purposes and obtaining oracle results
            pruned_edge_preds = subgraph.edge_labels

        else:
            with torch.no_grad():
                result = self.graph_model(subgraph)["classified_edges"][-1]
                # The model may return a dict keyed by edge type; only the bb-bb scores are used here.
                if isinstance(result, dict):
                    result = result["bb-bb"]
                pruned_edge_preds = torch.sigmoid(result.view(-1))

        # Scatter the predictions back onto the full (unpruned) edge set. The buffer is allocated directly on the
        # predictions' device instead of being built on the host and copied over.
        edge_preds = torch.zeros(knn_mask.shape[0],
                                 device=pruned_edge_preds.device)
        edge_preds[knn_mask] = pruned_edge_preds

        if self.eval_params["set_pruned_edges_to_inactive"]:
            return edge_preds, torch.ones_like(knn_mask)

        else:
            return (
                edge_preds,
                knn_mask,
            )  # In this case, pruning an edge counts as not predicting a value for it at all