Esempio n. 1
0
 def _get_embeddings_and_geodists_for_mesh(
         self, embedder: nn.Module,
         mesh_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Produces embeddings and geodesic distance tensors for a given mesh. May
     subsample the mesh, if it contains too many vertices (controlled by
     SHAPE_CYCLE_LOSS_MAX_NUM_VERTICES parameter, read here as
     `self.max_num_vertices`).

     Args:
         embedder (nn.Module): module that computes embeddings for mesh vertices
         mesh_name (str): mesh name
     Return:
         embeddings (torch.Tensor of size [N, D]): embeddings for selected mesh
             vertices (N = number of selected vertices, D = embedding space dim)
         geodists (torch.Tensor of size [N, N]): geodesic distances for the
             selected mesh vertices (N = number of selected vertices)
     """
     embeddings = embedder(mesh_name)
     # `indices` is None when no subsampling is applied (handled below);
     # otherwise it is assumed to be a 1-D index tensor of selected vertices
     # - TODO confirm against sample_random_indices
     indices = sample_random_indices(embeddings.shape[0],
                                     self.max_num_vertices,
                                     embeddings.device)
     mesh = create_mesh(mesh_name, embeddings.device)
     geodists = mesh.geodists
     if indices is not None:
         embeddings = embeddings[indices]
         # Pairwise distances among the selected vertices:
         # result[a, b] = geodists[indices[a], indices[b]].
         # Broadcasting a column index against a row index gives the same
         # result as indexing with torch.meshgrid(indices, indices) ("ij"),
         # without meshgrid's warning about a missing `indexing` argument.
         geodists = geodists[indices[:, None], indices[None, :]]
     return embeddings, geodists
    def evaluate(self):
        """
        Evaluates cross-mesh correspondences using annotated key vertices.

        For every ordered pair of distinct meshes (mesh_1, mesh_2), each key
        vertex of mesh_1 is matched to the vertex of mesh_2 whose embedding
        has the largest dot product with the key vertex embedding. The
        geodesic distance on mesh_2 between that match and mesh_2's key vertex
        of the same name is the geodesic error (GE). It is also mapped to a
        geodesic point similarity score GPS = exp(-GE^2 / (2 * sigma^2)).

        Every key vertex name of a mesh is looked up in the key vertices of
        every other mesh, so a missing name raises KeyError.

        Return:
            ge_mean_global (float): per-mesh GE values averaged over meshes
            gps_mean_global (float): per-mesh GPS values averaged over meshes
            per_mesh_metrics (dict): {"GE": {mesh_name: float},
                "GPS": {mesh_name: float}}; each value averages the metric
                over all other meshes. With fewer than two meshes these
                averages are NaN (mean of an empty tensor).
        """
        # Width of the Gaussian that maps geodesic errors to GPS scores
        gps_sigma = 0.255
        ge_per_mesh = {}
        gps_per_mesh = {}
        # Only Python floats are returned, so no autograd graph is needed
        with torch.no_grad():
            # Compute each mesh's vertex embeddings once, not once per pair
            mesh_embeddings = {
                mesh_name: self.embedder(mesh_name)
                for mesh_name in self.mesh_names
            }
            for mesh_name_1 in self.mesh_names:
                avg_errors = []
                avg_gps = []
                embeddings_1 = mesh_embeddings[mesh_name_1]
                keyvertices_1 = self.mesh_keyvertices[mesh_name_1]
                keyvertex_names_1 = list(keyvertices_1.keys())
                keyvertex_indices_1 = [
                    keyvertices_1[name] for name in keyvertex_names_1
                ]
                # one embedding row per key vertex of mesh_1
                keyvertex_embeddings_1 = embeddings_1[keyvertex_indices_1]
                for mesh_name_2 in self.mesh_names:
                    if mesh_name_1 == mesh_name_2:
                        continue
                    embeddings_2 = mesh_embeddings[mesh_name_2]
                    keyvertices_2 = self.mesh_keyvertices[mesh_name_2]
                    sim_matrix_12 = keyvertex_embeddings_1.mm(embeddings_2.T)
                    # best-matching mesh_2 vertex for each mesh_1 key vertex
                    vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(
                        dim=1)
                    mesh_2 = create_mesh(mesh_name_2, embeddings_2.device)
                    # distance on mesh_2 from each match to the same-named
                    # key vertex of mesh_2
                    geodists = mesh_2.geodists[
                        vertices_2_matching_keyvertices_1,
                        [keyvertices_2[name] for name in keyvertex_names_1], ]
                    gps = (-(geodists**2) / (2 * (gps_sigma**2))).exp()
                    avg_errors.append(geodists.mean().item())
                    avg_gps.append(gps.mean().item())

                ge_per_mesh[mesh_name_1] = torch.as_tensor(
                    avg_errors).mean().item()
                gps_per_mesh[mesh_name_1] = torch.as_tensor(
                    avg_gps).mean().item()
        ge_mean_global = torch.as_tensor(list(
            ge_per_mesh.values())).mean().item()
        gps_mean_global = torch.as_tensor(list(
            gps_per_mesh.values())).mean().item()
        per_mesh_metrics = {
            "GE": ge_per_mesh,
            "GPS": gps_per_mesh,
        }
        return ge_mean_global, gps_mean_global, per_mesh_metrics
    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        interpolator: BilinearInterpolationHelper,
        embedder: nn.Module,
    ) -> Dict[str, torch.Tensor]:
        """
        Produces losses for estimated embeddings given annotated vertices.
        Embeddings for all the vertices of a mesh are computed by the embedder.
        Embeddings for observed pixels are estimated by a predictor.
        For each annotated point, the loss is a soft cross-entropy between:
         1) a target distribution over mesh vertices: softmax of the negated
            geodesic distances from the ground truth vertex to every mesh
            vertex, divided by `self.geodist_gauss_sigma`;
         2) a predicted distribution over mesh vertices: log-softmax of the
            negated squared Euclidean distances between the estimated point
            embedding and every mesh vertex embedding, divided by
            `self.embdist_gauss_sigma`.
        Per-point losses are averaged over the annotated points of each mesh.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            interpolator (BilinearInterpolationHelper): bilinear interpolation helper
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            dict(str -> tensor): losses keyed by mesh name; every mesh in
                `embedder.mesh_names` gets an entry, and meshes without
                annotated points get `self.fake_value(...)`
        """
        losses = {}
        for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
            mesh_id = mesh_id_tensor.item()
            mesh_name = MeshCatalog.get_mesh_name(mesh_id)
            # valid points are those that fall into estimated bbox
            # and correspond to the current mesh
            j_valid = interpolator.j_valid * (
                packed_annotations.vertex_mesh_ids_gt == mesh_id)
            # extract estimated embeddings for valid points
            # -> tensor [J, D]
            vertex_embeddings_i = normalize_embeddings(
                interpolator.extract_at_points(
                    densepose_predictor_outputs.embedding,
                    slice_fine_segm=slice(None),
                    w_ylo_xlo=interpolator.w_ylo_xlo[:, None],
                    w_ylo_xhi=interpolator.w_ylo_xhi[:, None],
                    w_yhi_xlo=interpolator.w_yhi_xlo[:, None],
                    w_yhi_xhi=interpolator.w_yhi_xhi[:, None],
                )[j_valid, :])
            # extract vertex ids for valid points
            # -> tensor [J]
            vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]
            # embeddings for all mesh vertices
            # -> tensor [K, D]
            mesh_vertex_embeddings = embedder(mesh_name)
            # target distribution: softmax over negated geodesic distances
            # from each GT vertex to all mesh vertices -> tensor [J, K]
            mesh = create_mesh(mesh_name, mesh_vertex_embeddings.device)
            geodist_softmax_values = F.softmax(
                mesh.geodists[vertex_indices_i] / (-self.geodist_gauss_sigma),
                dim=1)
            # predicted log-distribution: log-softmax over negated squared
            # embedding distances -> tensor [J, K]
            embdist_logsoftmax_values = F.log_softmax(
                squared_euclidean_distance_matrix(vertex_embeddings_i,
                                                  mesh_vertex_embeddings) /
                (-self.embdist_gauss_sigma),
                dim=1,
            )
            # soft cross-entropy per point (sum over K), mean over points
            losses[mesh_name] = (-geodist_softmax_values *
                                 embdist_logsoftmax_values).sum(1).mean()

        # Meshes without annotated points still get an entry. Pass the name of
        # the mesh being filled in: `mesh_id` here would be a stale value from
        # the loop above (a different mesh), or unbound if that loop never ran.
        # NOTE(review): assumes fake_value takes a mesh name - confirm.
        for mesh_name in embedder.mesh_names:
            if mesh_name not in losses:
                losses[mesh_name] = self.fake_value(
                    densepose_predictor_outputs, embedder, mesh_name)
        return losses