def majority_label_from_topk(top_k_labels, num_classes=None):
    """Pick the most frequent label among each query's top-k retrievals.

    Args:
      top_k_labels: A 2-D long tensor with shape `[num_queries, top_k]`.
      num_classes: Forwarded unchanged to `common_utils.one_hot` as its
        second argument (defaults to `None`).

    Returns:
      A 1-D long tensor with shape `[num_queries]`.
    """
    # Summing the one-hot encodings over the top-k axis gives per-class
    # vote counts for every query.
    votes = common_utils.one_hot(top_k_labels, num_classes).sum(dim=1)
    return votes.argmax(dim=1)
def find_majority_label_index(semantic_labels, cluster_labels):
    """Select pixels whose semantic label is the majority label of their cluster.

    Args:
      semantic_labels: An N-D long tensor for semantic labels.
      cluster_labels: An N-D long tensor for cluster labels.

    Returns:
      select_pixel_indices: A 2-D long tensor (as produced by `nonzero()`)
        holding the flattened indices of pixels that carry the majority
        semantic label of their cluster.
      majority_semantic_labels: A 1-D long tensor with the majority semantic
        label of each cluster, of length `[num_clusters]`.
    """
    flat_semantic = semantic_labels.view(-1)
    flat_cluster = cluster_labels.view(-1)
    num_clusters = flat_cluster.max() + 1
    num_classes = flat_semantic.max() + 1

    # Build a [num_clusters, num_classes] histogram by scatter-adding each
    # pixel's one-hot semantic label into the row of its cluster.
    one_hot_semantic = common_utils.one_hot(flat_semantic, num_classes)
    histogram = torch.zeros((num_clusters, num_classes),
                            dtype=torch.long,
                            device=flat_semantic.device)
    row_index = flat_cluster.view(-1, 1).expand(-1, num_classes)
    histogram = histogram.scatter_add_(0, row_index, one_hot_semantic)

    majority_semantic_labels = histogram.argmax(dim=1)

    # Broadcast every cluster's majority label back to its pixels and keep
    # the pixels that agree with it.
    per_pixel_majority = majority_semantic_labels.gather(0, flat_cluster)
    agrees = per_pixel_majority.eq(flat_semantic)
    select_pixel_indices = agrees.nonzero()

    return select_pixel_indices, majority_semantic_labels
def main():
    """Inference for semantic segmentation.

    Pipeline, per test image:
      1. Optionally resize, then pad the image to at least the crop size.
      2. Slide a crop window over the padded image, run the embedding model
         on every crop and average the overlapping normalized embeddings.
      3. Cluster the embeddings and let the prediction model retrieve labels
         from the (optional) semantic memory prototypes.
      4. Turn the retrieved labels into per-class probabilities, refine them
         with DenseCRF and save both a gray-scale label map and a
         color-coded visualization under `args.save_dir`.
    """
    # Retrieve experiment configurations.
    args = parse_args('Inference for semantic segmentation.')
    config.network.kmeans_num_clusters = separate_comma(args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Output directories; they are created when predictions are saved.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(
        data_dir=args.data_dir,
        data_list=args.data_list,
        img_mean=config.network.pixel_means,
        img_std=config.network.pixel_stds,
        size=None,
        random_crop=False,
        random_scale=False,
        random_mirror=False,
        training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Not support ' + config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config)
    else:
        raise ValueError('Not support ' + config.network.prediction_types)

    embedding_model = embedding_model.to("cuda:0")
    prediction_model = prediction_model.to("cuda:0")
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights. The checkpoint is deserialized once and shared
    # by both models instead of being read from disk twice.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    checkpoint = torch.load(model_path_template.format(save_iter))
    # NOTE(review): `resume` is not a keyword of the stock
    # `torch.nn.Module.load_state_dict`; presumably the embedding model
    # overrides it -- confirm.
    embedding_model.load_state_dict(checkpoint['embedding_model'], resume=True)
    prediction_model.load_state_dict(checkpoint['prediction_model'])

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,)

    # Load memory prototypes.
    semantic_memory_prototypes, semantic_memory_prototype_labels = None, None
    if args.semantic_memory_dir is not None:
        semantic_memory_prototypes, semantic_memory_prototype_labels = (
            segsort_others.load_memory_banks(args.semantic_memory_dir))
        semantic_memory_prototypes = semantic_memory_prototypes.to("cuda:0")
        semantic_memory_prototype_labels = (
            semantic_memory_prototype_labels.to("cuda:0"))

        # Remove ignore class.
        valid_prototypes = torch.ne(
            semantic_memory_prototype_labels,
            config.dataset.semantic_ignore_index).nonzero()
        valid_prototypes = valid_prototypes.view(-1)
        semantic_memory_prototypes = torch.index_select(
            semantic_memory_prototypes, 0, valid_prototypes)
        semantic_memory_prototype_labels = torch.index_select(
            semantic_memory_prototype_labels, 0, valid_prototypes)

    # Start inferencing.
    for data_index in tqdm(range(len(test_dataset))):
        # Image path.
        image_path = test_image_paths[data_index]
        base_name = os.path.basename(image_path).replace('.jpg', '.png')

        # Image resolution.
        image_batch, label_batch, _ = test_dataset[data_index]
        image_h, image_w = image_batch['image'].shape[-2:]

        # Resize the input image.
        if config.test.image_size > 0:
            image_batch['image'] = transforms.resize_with_interpolation(
                image_batch['image'].transpose(1, 2, 0),
                config.test.image_size,
                method='bilinear').transpose(2, 0, 1)
            for lab_name in ['semantic_label', 'instance_label']:
                label_batch[lab_name] = transforms.resize_with_interpolation(
                    label_batch[lab_name],
                    config.test.image_size,
                    method='nearest')
        resize_image_h, resize_image_w = image_batch['image'].shape[-2:]

        # Crop and Pad the input image.
        image_batch['image'] = transforms.resize_with_pad(
            image_batch['image'].transpose(1, 2, 0),
            config.test.crop_size,
            image_pad_value=0).transpose(2, 0, 1)
        image_batch['image'] = torch.FloatTensor(
            image_batch['image'][np.newaxis, ...]).to("cuda:0")
        pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

        # Create the fake labels where clustering ignores 255: real pixels
        # get label 0 and the padded border gets the ignore index.
        fake_label_batch = {}
        for label_name in ['semantic_label', 'instance_label']:
            lab = np.zeros((resize_image_h, resize_image_w), dtype=np.uint8)
            lab = transforms.resize_with_pad(
                lab,
                config.test.crop_size,
                image_pad_value=config.dataset.semantic_ignore_index)
            fake_label_batch[label_name] = torch.LongTensor(
                lab[np.newaxis, ...]).to("cuda:0")

        # Put label batch on gpu 0.
        for k, v in label_batch.items():
            label_batch[k] = torch.LongTensor(v[np.newaxis, ...]).to("cuda:0")

        # Create the ending index of each patch.
        stride_h, stride_w = config.test.stride
        crop_h, crop_w = config.test.crop_size
        npatches_h = math.ceil(1.0 * (pad_image_h-crop_h) / stride_h) + 1
        npatches_w = math.ceil(1.0 * (pad_image_w-crop_w) / stride_w) + 1
        patch_ind_h = np.linspace(
            crop_h, pad_image_h, npatches_h, dtype=np.int32)
        patch_ind_w = np.linspace(
            crop_w, pad_image_w, npatches_w, dtype=np.int32)

        # Create place holder for full-resolution embeddings. `counts`
        # records how many patches covered each pixel so that overlapping
        # predictions can be averaged afterwards.
        embeddings = {}
        counts = torch.FloatTensor(
            1, 1, pad_image_h, pad_image_w).zero_().to("cuda:0")
        with torch.no_grad():
            for ind_h in patch_ind_h:
                for ind_w in patch_ind_w:
                    sh, eh = ind_h - crop_h, ind_h
                    sw, ew = ind_w - crop_w, ind_w
                    crop_image_batch = {
                        k: v[:, :, sh:eh, sw:ew]
                        for k, v in image_batch.items()}

                    # Feed-forward.
                    crop_embeddings = embedding_model.generate_embeddings(
                        crop_image_batch, resize_as_input=True)

                    # Initialize embedding. Only the 'embedding' entry is
                    # accumulated; every other output is skipped.
                    for name in crop_embeddings:
                        if crop_embeddings[name] is None:
                            continue
                        crop_emb = crop_embeddings[name].to("cuda:0")
                        if name in ['embedding']:
                            crop_emb = common_utils.normalize_embedding(
                                crop_emb.permute(0, 2, 3, 1).contiguous())
                            crop_emb = crop_emb.permute(0, 3, 1, 2)
                        else:
                            continue

                        if name not in embeddings.keys():
                            embeddings[name] = torch.FloatTensor(
                                1,
                                crop_emb.shape[1],
                                pad_image_h,
                                pad_image_w).zero_().to("cuda:0")
                        embeddings[name][:, :, sh:eh, sw:ew] += crop_emb
                    counts[:, :, sh:eh, sw:ew] += 1

        # Average the accumulated embeddings over overlapping patches.
        for k in embeddings.keys():
            embeddings[k] /= counts

        # KMeans.
        fake_sem_lab = fake_label_batch['semantic_label']
        fake_inst_lab = fake_label_batch['instance_label']
        clustering_outputs = embedding_model.generate_clusters(
            embeddings.get('embedding', None),
            fake_sem_lab,
            fake_inst_lab)
        embeddings.update(clustering_outputs)

        # Generate predictions.
        outputs = prediction_model(
            embeddings,
            {'semantic_memory_prototype': semantic_memory_prototypes,
             'semantic_memory_prototype_label':
                 semantic_memory_prototype_labels},
            with_loss=False, with_prediction=True)
        # Presumably `[num_pixels, top_k]` retrieved labels -- TODO confirm.
        semantic_topk = outputs['semantic_score']

        # DenseCRF post-processing: the fraction of retrievals voting for
        # each class is used as the per-pixel class probability.
        semantic_prob = common_utils.one_hot(
            semantic_topk, max_label=config.dataset.num_classes)
        semantic_prob = (
            semantic_prob.sum(dim=1).float() / semantic_topk.shape[1])
        semantic_prob = semantic_prob.view(resize_image_h, resize_image_w, -1)
        semantic_prob = semantic_prob.data.cpu().numpy().astype(np.float32)
        semantic_prob = semantic_prob.transpose(2, 0, 1)

        # Undo the input normalization to recover a uint8 image for the CRF.
        image = image_batch['image'].data.cpu().numpy().astype(np.float32)
        image = image[0, :, :resize_image_h, :resize_image_w].transpose(1, 2, 0)
        image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
        image += np.reshape(config.network.pixel_means, (1, 1, 3))
        image = image * 255
        image = image.astype(np.uint8)

        semantic_prob = postprocessor(image, semantic_prob)

        # Save semantic predictions at the original image resolution.
        semantic_pred = np.argmax(semantic_prob, axis=0).astype(np.uint8)
        semantic_pred = cv2.resize(
            semantic_pred,
            (image_w, image_h),
            interpolation=cv2.INTER_NEAREST)

        semantic_pred_name = os.path.join(semantic_dir, base_name)
        os.makedirs(os.path.dirname(semantic_pred_name), exist_ok=True)
        Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

        semantic_pred_rgb = color_map[semantic_pred]
        semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
        os.makedirs(os.path.dirname(semantic_pred_rgb_name), exist_ok=True)
        Image.fromarray(semantic_pred_rgb, mode='RGB').save(
            semantic_pred_rgb_name)
def gather_multiset_labels_per_batch_by_nearest_neighbor(
        embeddings,
        prototypes,
        semantic_prototype_labels,
        batch_embedding_labels,
        batch_prototype_labels,
        num_classes=21,
        top_k=3,
        threshold=0.95,
        label_divisor=255):
    """Assign labels to unlabelled pixels from nearest-neighbor labeled
    segments, which is useful in feature affinity regularization.

    Each pixel retrieves its `top_k` most similar prototypes (largest inner
    product) among prototypes that share its batch index and carry a valid
    semantic label (`< num_classes`). Retrievals whose similarity is below
    `threshold` are discarded.

    Args:
      embeddings: A float tensor indicates pixel-wise embeddings, whose last
        dimension denotes feature channels.
      prototypes: A 2-D float tensor indicates segment prototypes of shape
        `[num_segments, channels]`.
      semantic_prototype_labels: A 1-D tensor of integer class ids indicates
        segment semantic labels of shape `[num_segments]`. Entries
        `>= num_classes` are treated as ignored.
      batch_embedding_labels: A 1-D long tensor indicates pixel-wise batch
        indices, which should include the same number of pixels as
        `embeddings`.
      batch_prototype_labels: A 1-D long tensor indicates segment batch
        indices, which should include the same number of segments as
        `prototypes`.
      num_classes: An integer indicates the number of semantic categories.
      top_k: An integer indicates top-K retrievals.
      threshold: A float indicates the confidence (similarity) threshold.
      label_divisor: An integer indicates the ignored label index. It is
        not used by this function and is kept for interface compatibility.

    Return:
      A 2-D long tensor of shape `[num_pixels, num_classes]`. If entry i's
      value is 1, at least one retained nearest-neighbor segment is of
      category i.
    """
    embeddings = embeddings.view(-1, embeddings.shape[-1])
    prototypes = prototypes.view(-1, embeddings.shape[-1])
    num_pixels = embeddings.shape[0]

    # A prototype is a candidate for a pixel only when both come from the
    # same batch entry and the prototype has a valid semantic label.
    batch_affinity = torch.eq(batch_embedding_labels.view(-1, 1),
                              batch_prototype_labels.view(1, -1))
    valid_prototypes = (semantic_prototype_labels < num_classes).view(1, -1)
    label_affinity = batch_affinity & valid_prototypes

    # Compute similarities and retrieve nearest neighbors. Non-candidates
    # are pushed to -inf so they can never pass `threshold`, whatever the
    # scale of the similarities is.
    similarities = torch.mm(embeddings, prototypes.t())
    similarities = similarities.masked_fill(~label_affinity, float('-inf'))
    nn_sims, nn_inds = torch.topk(similarities, top_k, dim=1)

    # Labels of the retrieved prototypes; rejected retrievals are mapped to
    # the extra class `num_classes`, which is dropped before returning.
    set_labels = semantic_prototype_labels.view(-1)[nn_inds].long()
    set_labels = set_labels.masked_fill(nn_sims < threshold, num_classes)

    # Multi-hot encode the retrieved label set of every pixel directly,
    # without materializing a per-retrieval one-hot tensor.
    multi_hot = torch.zeros((num_pixels, num_classes + 1),
                            dtype=torch.long,
                            device=set_labels.device)
    multi_hot.scatter_(1, set_labels, 1)

    return multi_hot[:, :num_classes]