def _compute_full_miou(self): if self._full_vote_miou is not None: return has_prediction = self._test_area.prediction_count > 0 log.info( "Computing full res mIoU, we have predictions for %.2f%% of the points." % (torch.sum(has_prediction) / (1.0 * has_prediction.shape[0]) * 100)) self._test_area = self._test_area.to("cpu") # Full res interpolation full_pred = knn_interpolate( self._test_area.votes[has_prediction], self._test_area.pos[has_prediction], self._test_area.pos, k=1, ) # Full res pred c = ConfusionMatrix(self._num_classes) c.count_predicted_batch(self._test_area.y.numpy(), torch.argmax(full_pred, 1).numpy()) self._full_vote_miou = c.get_average_intersection_union() * 100
def _predict_full_res(self):
    """Upsample the averaged votes onto the full-resolution cloud.

    Only points whose vote count is positive contribute. Their summed votes
    are divided by their counts, then interpolated with ``self._k`` neighbours
    onto every position of ``self._raw_data``. The interpolated votes are
    stored in ``self._full_res_preds``; ``self._votes`` is not modified.
    """
    voted = self._vote_counts > 0
    counts = self._vote_counts[voted].unsqueeze(-1)
    mean_votes = self._votes[voted].div(counts)
    self._full_res_preds = knn_interpolate(
        mean_votes,
        self._raw_data.pos[voted],
        self._raw_data.pos,
        k=self._k,
    )
def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None):
    """Fuse high-resolution features with upsampled low-resolution features.

    ``x_sub`` (features located at ``pos_sub``) goes through ``self.mlp_sub``
    (which, per the original comment, reduces the number of features). The
    result is interpolated onto ``pos`` with 3 nearest neighbours, using
    ``batch_sub``/``batch`` as the batch vectors, and added to ``self.mlp(x)``.
    """
    low_res = self.mlp_sub(x_sub)
    upsampled = knn_interpolate(
        low_res,
        pos_sub,
        pos,
        batch_x=batch_sub,
        batch_y=batch,
        k=3,
    )
    high_res = self.mlp(x)
    return high_res + upsampled
def _predict_full_res(self):
    """Predict full-resolution labels for every scan from the accumulated votes.

    For each scan key in ``self._votes``:

    - average the summed votes of points whose vote count is positive;
    - interpolate those averages (k=1) onto all positions of
      ``self._raw_datas[id_scan]``;
    - store the per-point argmax in ``self._full_preds[id_scan]``.

    The vote accumulators in ``self._votes`` are left untouched, so repeated
    calls or later vote accumulation do not divide the sums a second time.
    """
    for id_scan in self._votes:
        counts = self._vote_counts[id_scan]
        has_prediction = counts > 0
        # Out-of-place average: an in-place `/=` on self._votes would corrupt the
        # running sums for any subsequent accumulation or repeated call.
        mean_votes = self._votes[id_scan][has_prediction] / counts[has_prediction].unsqueeze(-1)

        # Upsample and predict
        raw_pos = self._raw_datas[id_scan].pos
        full_pred = knn_interpolate(
            mean_votes,
            raw_pos[has_prediction],
            raw_pos,
            k=1,
        )
        self._full_preds[id_scan] = full_pred.argmax(-1)
def upsample(self, y, coarse_nodes, coarse_batch, fine):
    """Interpolate ``y`` from the coarse nodes onto the nodes of ``fine``.

    Node positions are the first two columns of ``coarse_nodes`` and of
    ``fine.x``. The 3-nearest-neighbour interpolation runs on CPU copies of
    all inputs, and the result is moved back to ``y``'s device.
    """
    target_device = y.device
    cpu_result = knn_interpolate(
        y.cpu(),
        coarse_nodes[:, :2].cpu(),
        fine.x[:, :2].cpu(),
        batch_x=coarse_batch.cpu(),
        batch_y=fine.batch.cpu(),
        k=3,
    )
    return cpu_result.to(target_device)