Beispiel #1
0
def majority_label_from_topk(top_k_labels, num_classes=None):
  """Return the most frequent label among each query's top-k retrievals.

  Args:
    top_k_labels: A 2-D long tensor of shape `[num_queries, top_k]` with the
      labels of the retrieved neighbors of every query.
    num_classes: Forwarded unchanged as the second argument of
      `common_utils.one_hot`; `None` leaves the choice to that helper.

  Returns:
    A 1-D long tensor of shape `[num_queries]` holding, per query, the index
    with the largest vote count.
  """
  # Vote histogram per query: encode the retrieved labels with
  # `common_utils.one_hot` and add up the votes over the top-k axis (dim 1).
  # NOTE(review): assumes the helper appends a trailing class axis.
  votes = torch.sum(common_utils.one_hot(top_k_labels, num_classes), dim=1)
  return votes.argmax(dim=1)
Beispiel #2
0
def find_majority_label_index(semantic_labels, cluster_labels):
  """Finds indices of pixels that belong to their cluster's majority label.

  Each cluster is assigned the semantic label that occurs most often among
  its pixels (ties resolve to the smallest label, per `torch.argmax`). A
  pixel is selected when its own semantic label equals the label assigned
  to its cluster.

  Args:
    semantic_labels: An N-D long tensor of non-negative semantic labels.
    cluster_labels: An N-D long tensor of non-negative cluster labels with
      the same number of elements as `semantic_labels`.

  Returns:
    select_pixel_indices: A 2-D long tensor of shape `[num_selected, 1]`
      with the flattened indices of pixels that belong to the majority
      label of their cluster.
    majority_semantic_labels: A 1-D long tensor of length `num_clusters`
      (`cluster_labels.max() + 1`) with the majority semantic label of each
      cluster.
  """
  semantic_labels = semantic_labels.reshape(-1)
  cluster_labels = cluster_labels.reshape(-1)
  num_clusters = int(cluster_labels.max()) + 1
  num_classes = int(semantic_labels.max()) + 1

  # Count (cluster, class) co-occurrences with a single bincount over joint
  # ids. This yields exactly the counts of scatter-adding a one-hot encoding
  # per cluster, but never materializes a `[num_pixels, num_classes]`
  # one-hot matrix (which gets large when a big ignore label inflates
  # `num_classes`).
  joint_ids = cluster_labels * num_classes + semantic_labels
  accumulate_semantic_labels = torch.bincount(
      joint_ids, minlength=num_clusters * num_classes).view(
          num_clusters, num_classes)
  majority_semantic_labels = torch.argmax(accumulate_semantic_labels, 1)

  # Broadcast every cluster's majority label back to its pixels and keep the
  # pixels whose own label agrees.
  cluster_semantic_labels = torch.gather(majority_semantic_labels, 0,
                                         cluster_labels)
  select_pixel_indices = torch.eq(cluster_semantic_labels,
                                  semantic_labels).nonzero()

  return select_pixel_indices, majority_semantic_labels
Beispiel #3
0
def main():
  """Inference for semantic segmentation.

  Parses the command line and updates `config`, builds the test dataset,
  the embedding and prediction models and the DenseCRF post-processor,
  loads the snapshot of the last training iteration and (optionally) the
  semantic memory prototypes. Then, for every test image, it runs
  sliding-window embedding extraction, clustering, prediction and DenseCRF,
  and writes gray-scale and color predictions under `args.save_dir`.
  """
  # Retrieve experiment configurations.
  args = parse_args('Inference for semantic segmentation.')
  config.network.kmeans_num_clusters = separate_comma(args.kmeans_num_clusters)
  config.network.label_divisor = args.label_divisor

  # Create directories to save results. Every output is written as
  # `<dir>/<basename>`, so both directories can be created once up front.
  semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
  semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')
  os.makedirs(semantic_dir, exist_ok=True)
  os.makedirs(semantic_rgb_dir, exist_ok=True)

  # Create color map.
  color_map = vis_utils.load_color_map(config.dataset.color_map_path)
  color_map = color_map.numpy()

  # Create data loaders.
  test_dataset = ListDataset(
      data_dir=args.data_dir,
      data_list=args.data_list,
      img_mean=config.network.pixel_means,
      img_std=config.network.pixel_stds,
      size=None,
      random_crop=False,
      random_scale=False,
      random_mirror=False,
      training=False)
  test_image_paths = test_dataset.image_paths

  # Create models.
  if config.network.backbone_types == 'panoptic_pspnet_101':
    embedding_model = resnet_101_pspnet(config).cuda()
  elif config.network.backbone_types == 'panoptic_deeplab_101':
    embedding_model = resnet_101_deeplab(config).cuda()
  else:
    raise ValueError('Not support ' + config.network.backbone_types)

  if config.network.prediction_types == 'segsort':
    prediction_model = segsort(config)
  else:
    raise ValueError('Not support ' + config.network.prediction_types)

  embedding_model = embedding_model.to("cuda:0")
  prediction_model = prediction_model.to("cuda:0")
  embedding_model.eval()
  prediction_model.eval()

  # Load trained weights of the last training iteration. Both state dicts
  # are stored in the same checkpoint file, so read it from disk only once.
  model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
  save_iter = config.train.max_iteration - 1
  checkpoint = torch.load(model_path_template.format(save_iter))
  # NOTE(review): `resume` is not a parameter of the stock
  # `torch.nn.Module.load_state_dict`; presumably the embedding model
  # overrides it -- confirm.
  embedding_model.load_state_dict(checkpoint['embedding_model'], resume=True)
  prediction_model.load_state_dict(checkpoint['prediction_model'])
  del checkpoint  # Release the loaded checkpoint before inference.

  # Define CRF.
  postprocessor = DenseCRF(
      iter_max=args.crf_iter_max,
      pos_xy_std=args.crf_pos_xy_std,
      pos_w=args.crf_pos_w,
      bi_xy_std=args.crf_bi_xy_std,
      bi_rgb_std=args.crf_bi_rgb_std,
      bi_w=args.crf_bi_w,)

  # Load memory prototypes.
  semantic_memory_prototypes, semantic_memory_prototype_labels = None, None
  if args.semantic_memory_dir is not None:
    semantic_memory_prototypes, semantic_memory_prototype_labels = (
      segsort_others.load_memory_banks(args.semantic_memory_dir))
    semantic_memory_prototypes = semantic_memory_prototypes.to("cuda:0")
    semantic_memory_prototype_labels = semantic_memory_prototype_labels.to("cuda:0")

    # Remove ignore class.
    valid_prototypes = torch.ne(
        semantic_memory_prototype_labels,
        config.dataset.semantic_ignore_index).nonzero()
    valid_prototypes = valid_prototypes.view(-1)
    semantic_memory_prototypes = torch.index_select(
        semantic_memory_prototypes,
        0,
        valid_prototypes)
    semantic_memory_prototype_labels = torch.index_select(
        semantic_memory_prototype_labels,
        0,
        valid_prototypes)

  # Start inferencing.
  for data_index in tqdm(range(len(test_dataset))):
    # Image path.
    image_path = test_image_paths[data_index]
    base_name = os.path.basename(image_path).replace('.jpg', '.png')

    # Image resolution.
    image_batch, label_batch, _ = test_dataset[data_index]
    image_h, image_w = image_batch['image'].shape[-2:]

    # Resize the input image.
    if config.test.image_size > 0:
      image_batch['image'] = transforms.resize_with_interpolation(
          image_batch['image'].transpose(1, 2, 0),
          config.test.image_size,
          method='bilinear').transpose(2, 0, 1)
      for lab_name in ['semantic_label', 'instance_label']:
        label_batch[lab_name] = transforms.resize_with_interpolation(
            label_batch[lab_name],
            config.test.image_size,
            method='nearest')
    resize_image_h, resize_image_w = image_batch['image'].shape[-2:]

    # Crop and Pad the input image.
    image_batch['image'] = transforms.resize_with_pad(
        image_batch['image'].transpose(1, 2, 0),
        config.test.crop_size,
        image_pad_value=0).transpose(2, 0, 1)
    image_batch['image'] = torch.FloatTensor(
        image_batch['image'][np.newaxis, ...]).to("cuda:0")
    pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

    # Create the fake labels where clustering ignores 255: the real image
    # area is 0 and the padded area gets the dataset's ignore index.
    fake_label_batch = {}
    for label_name in ['semantic_label', 'instance_label']:
      lab = np.zeros((resize_image_h, resize_image_w),
                     dtype=np.uint8)
      lab = transforms.resize_with_pad(
          lab,
          config.test.crop_size,
          image_pad_value=config.dataset.semantic_ignore_index)

      fake_label_batch[label_name] = torch.LongTensor(
          lab[np.newaxis, ...]).to("cuda:0")

    # Move the label batch to cuda:0.
    # NOTE(review): `label_batch` is not read after this point in `main`.
    for k, v in label_batch.items():
      label_batch[k] = torch.LongTensor(v[np.newaxis, ...]).to("cuda:0")

    # Create the ending index of each patch.
    stride_h, stride_w = config.test.stride
    crop_h, crop_w = config.test.crop_size
    npatches_h = math.ceil(1.0 * (pad_image_h-crop_h) / stride_h) + 1
    npatches_w = math.ceil(1.0 * (pad_image_w-crop_w) / stride_w) + 1
    patch_ind_h = np.linspace(
        crop_h, pad_image_h, npatches_h, dtype=np.int32)
    patch_ind_w = np.linspace(
        crop_w, pad_image_w, npatches_w, dtype=np.int32)

    # Create place holder for full-resolution embeddings. `counts` tracks how
    # many patches covered each pixel so overlapping patches are averaged.
    embeddings = {}
    counts = torch.FloatTensor(
        1, 1, pad_image_h, pad_image_w).zero_().to("cuda:0")
    with torch.no_grad():
      for ind_h in patch_ind_h:
        for ind_w in patch_ind_w:
          sh, eh = ind_h - crop_h, ind_h
          sw, ew = ind_w - crop_w, ind_w
          crop_image_batch = {
            k: v[:, :, sh:eh, sw:ew] for k, v in image_batch.items()}

          # Feed-forward.
          crop_embeddings = embedding_model.generate_embeddings(
              crop_image_batch, resize_as_input=True)

          # Accumulate the normalized 'embedding' output; other outputs are
          # skipped.
          for name in crop_embeddings:
            if crop_embeddings[name] is None:
              continue
            crop_emb = crop_embeddings[name].to("cuda:0")
            if name in ['embedding']:
              crop_emb = common_utils.normalize_embedding(
                  crop_emb.permute(0, 2, 3, 1).contiguous())
              crop_emb = crop_emb.permute(0, 3, 1, 2)
            else:
              continue

            if name not in embeddings:
              embeddings[name] = torch.FloatTensor(
                  1,
                  crop_emb.shape[1],
                  pad_image_h,
                  pad_image_w).zero_().to("cuda:0")
            embeddings[name][:, :, sh:eh, sw:ew] += crop_emb
          counts[:, :, sh:eh, sw:ew] += 1

    for k in embeddings:
      embeddings[k] /= counts

    # KMeans.
    fake_sem_lab = fake_label_batch['semantic_label']
    fake_inst_lab = fake_label_batch['instance_label']
    clustering_outputs = embedding_model.generate_clusters(
        embeddings.get('embedding', None),
        fake_sem_lab,
        fake_inst_lab)
    embeddings.update(clustering_outputs)

    # Generate predictions.
    outputs = prediction_model(
        embeddings,
        {'semantic_memory_prototype': semantic_memory_prototypes,
         'semantic_memory_prototype_label': semantic_memory_prototype_labels},
        with_loss=False, with_prediction=True)
    semantic_topk = outputs['semantic_score']

    # DenseCRF post-processing. The top-k labels are turned into per-class
    # vote frequencies that serve as the CRF's unary probabilities.
    # NOTE(review): the reshape to the resized (un-padded) resolution
    # presumably relies on padded pixels (ignore label) being dropped by the
    # prediction model -- confirm.
    semantic_prob = common_utils.one_hot(
        semantic_topk, max_label=config.dataset.num_classes)
    semantic_prob = semantic_prob.sum(dim=1).float() / semantic_topk.shape[1]
    semantic_prob = semantic_prob.view(resize_image_h, resize_image_w, -1)
    semantic_prob = semantic_prob.data.cpu().numpy().astype(np.float32)
    semantic_prob = semantic_prob.transpose(2, 0, 1)

    # Recover the un-normalized uint8 image (un-padded region) for the CRF.
    image = image_batch['image'].data.cpu().numpy().astype(np.float32)
    image = image[0, :, :resize_image_h, :resize_image_w].transpose(1, 2, 0)
    image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
    image += np.reshape(config.network.pixel_means, (1, 1, 3))
    image = image * 255
    image = image.astype(np.uint8)

    semantic_prob = postprocessor(image, semantic_prob)

    # Save semantic predictions at the original image resolution.
    semantic_pred = np.argmax(semantic_prob, axis=0).astype(np.uint8)
    semantic_pred = cv2.resize(
        semantic_pred,
        (image_w, image_h),
        interpolation=cv2.INTER_NEAREST)

    semantic_pred_name = os.path.join(semantic_dir, base_name)
    Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

    semantic_pred_rgb = color_map[semantic_pred]
    semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
    Image.fromarray(semantic_pred_rgb, mode='RGB').save(
        semantic_pred_rgb_name)
Beispiel #4
0
def gather_multiset_labels_per_batch_by_nearest_neighbor(
    embeddings,
    prototypes,
    semantic_prototype_labels,
    batch_embedding_labels,
    batch_prototype_labels,
    num_classes=21,
    top_k=3,
    threshold=0.95,
    label_divisor=255):
  """Assigned labels for unlabelled pixels by nearest-neighbor
  labeled segments, which is useful in feature affinity regularization.

  For every pixel, the `top_k` most similar prototypes (by dot product) are
  retrieved among those from the same batch element whose semantic label is
  below `num_classes`. Retrievals with similarity below `threshold` are
  discarded. The labels of the remaining retrievals form the pixel's label
  multiset.

  Args:
    embeddings: A float tensor indicates pixel-wise embeddings, whose
      last dimension denotes feature channels.
    prototypes: A 2-D float tensor indicates segment prototypes
      of shape `[num_segments, channels]`.
    semantic_prototype_labels: A 1-D integer tensor indicates segment
      semantic labels of shape `[num_segments]`. Labels `>= num_classes`
      are never retrieved.
    batch_embedding_labels: A 1-D long tensor indicates pixel-wise
      batch indices, which should include the same number of pixels
      as `embeddings`.
    batch_prototype_labels: A 1-D long tensor indicates segment
      batch indices, which should include the same number of segments
      as `prototypes`.
    num_classes: An integer indicates the number of semantic categories.
    top_k: An integer indicates top-K retrievals. Clamped to the number of
      prototypes.
    threshold: A float indicates the confidence threshold.
    label_divisor: Unused; kept for backward compatibility of the
      signature.

  Return:
    A 2-D long tensor of shape `[num_pixels, num_classes]`. If entry i's
    value is 1, one of the retained nearest-neighbor segments is of
    category i.
  """
  embeddings = embeddings.view(-1, embeddings.shape[-1])
  prototypes = prototypes.view(-1, embeddings.shape[-1])
  num_pixels = embeddings.shape[0]
  num_prototypes = prototypes.shape[0]

  if num_pixels == 0 or num_prototypes == 0:
    # Nothing can be retrieved: every pixel gets an empty multiset.
    return torch.zeros((num_pixels, num_classes),
                       dtype=torch.long,
                       device=embeddings.device)

  # A pixel may only retrieve prototypes from its own batch element that
  # carry a valid semantic label.
  batch_affinity = torch.eq(batch_embedding_labels.view(-1, 1),
                            batch_prototype_labels.view(1, -1))
  valid_prototypes = (semantic_prototype_labels < num_classes).view(1, -1)
  label_affinity = batch_affinity & valid_prototypes

  # Rank disallowed pairs below every real similarity.
  dists = torch.mm(embeddings, prototypes.t())
  min_dist = dists.min()
  dists = torch.where(label_affinity, dists, min_dist - 1)

  # `torch.topk` raises when k exceeds the number of candidates.
  k = min(top_k, num_prototypes)
  nn_dists, nn_inds = torch.topk(dists, k, dim=1)
  setsemantic_labels = torch.gather(
      semantic_prototype_labels.view(1, -1).expand(num_pixels, -1),
      1, nn_inds)

  # Drop low-confidence retrievals, and also retrievals that only appear
  # because of the fill value above: `min_dist - 1` stays below `threshold`
  # for cosine similarities, but not for arbitrary (unnormalized) embeddings.
  # Dropped entries are mapped to the extra bin `num_classes`.
  nn_allowed = torch.gather(label_affinity, 1, nn_inds)
  setsemantic_labels = setsemantic_labels.masked_fill(
      (~nn_allowed) | (nn_dists < threshold), num_classes)

  # Multi-hot encode the retained labels; the extra bin is sliced off.
  setsemantic_labels_2d = torch.zeros((num_pixels, num_classes + 1),
                                      dtype=torch.long,
                                      device=embeddings.device)
  setsemantic_labels_2d.scatter_(1, setsemantic_labels.long(), 1)
  setsemantic_labels_2d = setsemantic_labels_2d[:, :num_classes]

  return setsemantic_labels_2d