def _cluster(self, semantic_logits, offset_logits):
    """Grow candidate instance clusters from raw positions and from voted positions.

    Args:
        semantic_logits: per-point class scores; the argmax over dim 1 gives
            each point's predicted semantic label.
        offset_logits: per-point offset votes, added to the raw positions for
            the second grow pass.

    Returns:
        (all_clusters, cluster_type): a list of per-cluster point-index
        tensors moved to ``self.device``, and a uint8 tensor with one entry
        per cluster — 0 for clusters grown on raw positions, 1 for clusters
        grown on vote-shifted positions.
    """
    labels = semantic_logits.argmax(dim=1)
    batch = self.input.batch.to(self.device)
    ignore = self._stuff_classes.to(self.device)
    radius = self.opt.cluster_radius_search

    # Pass 1: grow on the raw coordinates.
    clusters_pos = region_grow(
        self.raw_pos,
        labels,
        batch,
        ignore_labels=ignore,
        radius=radius,
    )
    # Pass 2: grow on the vote-shifted coordinates, with a capped
    # neighbourhood sample size.
    clusters_votes = region_grow(
        self.raw_pos + offset_logits,
        labels,
        batch,
        ignore_labels=ignore,
        radius=radius,
        nsample=200,
    )

    all_clusters = [c.to(self.device) for c in clusters_pos + clusters_votes]
    # Mark the vote-grown tail of the list with type 1.
    cluster_type = torch.zeros(len(all_clusters), dtype=torch.uint8).to(self.device)
    cluster_type[len(clusters_pos):] = 1
    return all_clusters, cluster_type
def forward(self, epoch=-1, **kwargs):
    """Run the backbone, the semantic/offset heads, and optionally cluster scoring.

    Args:
        epoch: current training epoch. Clustering and scoring run when
            ``epoch == -1`` (the default, i.e. inference) or once
            ``epoch > self.opt.prepare_epoch``.
        **kwargs: unused; kept for interface compatibility with callers.

    Side effects:
        Sets ``self.output`` to a ``PanopticResults``, dumps debug visuals via
        ``self._dump_visuals`` and computes the loss via ``self._compute_loss``.
    """
    # Backbone
    backbone_features = self.Backbone(self.input).x

    # Semantic and offset heads
    semantic_logits = self.Semantic(backbone_features)
    offset_logits = self.Offset(backbone_features)

    # Grouping and scoring
    cluster_scores = None
    all_clusters = None
    cluster_type = None
    if epoch == -1 or epoch > self.opt.prepare_epoch:  # Active by default
        predicted_labels = torch.max(semantic_logits, 1)[1]
        # NOTE(review): this duplicates self._cluster but runs region_grow on
        # CPU tensors and omits the nsample cap on the vote pass — confirm
        # which variant is intended and unify.
        clusters_pos = region_grow(
            self.raw_pos.cpu(),
            predicted_labels.cpu(),
            self.input.batch.cpu(),
            ignore_labels=self._stuff_classes.cpu(),
            radius=self.opt.cluster_radius_search,
        )
        clusters_votes = region_grow(
            self.raw_pos.cpu() + offset_logits.cpu(),
            predicted_labels.cpu(),
            self.input.batch.cpu(),
            ignore_labels=self._stuff_classes.cpu(),
            radius=self.opt.cluster_radius_search,
        )

        all_clusters = clusters_pos + clusters_votes
        all_clusters = [c.to(self.device) for c in all_clusters]
        # 0 = grown from raw positions, 1 = grown from vote-shifted positions.
        cluster_type = torch.zeros(len(all_clusters), dtype=torch.uint8).to(self.device)
        cluster_type[len(clusters_pos):] = 1

        if len(all_clusters):
            x = []
            coords = []
            batch = []
            for i, cluster in enumerate(all_clusters):
                x.append(backbone_features[cluster])
                coords.append(self.input.coords[cluster])
                # Must be an integer tensor: it becomes the scatter index
                # (score_backbone_out.batch) for the max-pooling below.
                batch.append(i * torch.ones(cluster.shape[0], dtype=torch.long))
            batch_cluster = Data(
                x=torch.cat(x).cpu(),
                coords=torch.cat(coords).cpu(),
                batch=torch.cat(batch).cpu(),
            )
            score_backbone_out = self.Scorer(batch_cluster)
            if self._scorer_is_encoder:
                cluster_feats = score_backbone_out.x
            else:
                # scatter_max returns a (values, argmax) tuple; keep only the
                # max-pooled per-cluster features for the scoring head.
                cluster_feats = scatter_max(
                    score_backbone_out.x, score_backbone_out.batch.long(), dim=0
                )[0]
            cluster_scores = self.ScorerHead(cluster_feats)

    self.output = PanopticResults(
        semantic_logits=semantic_logits,
        offset_logits=offset_logits,
        clusters=all_clusters,
        cluster_scores=cluster_scores,
        cluster_type=cluster_type,
    )

    # Sets visual data for debugging
    self._dump_visuals(epoch)

    # Compute loss
    self._compute_loss()