Example 1
def calculate_prototypes_from_labels(embeddings, labels, max_label=None):
    """Calculates prototypes from labels.

  This function calculates prototypes (mean directions) from embedding
  features for each label. It is also used as the M-step in k-means
  clustering.

  Args:
    embeddings: A 2-D or 4-D float tensor with feature embedding in the
      last dimension (embedding_dim).
    labels: An N-D long label map for each embedding pixel.
    max_label: The maximum value of the label map. Calculated on-the-fly
      if not specified.

  Returns:
    A 2-D float tensor with shape `[num_prototypes, embedding_dim]`.
  """
    embeddings = embeddings.view(-1, embeddings.shape[-1])

    if max_label is None:
        max_label = labels.max() + 1
    prototypes = torch.zeros((max_label, embeddings.shape[-1]),
                             dtype=embeddings.dtype,
                             device=embeddings.device)
    labels = labels.view(-1, 1).expand(-1, embeddings.shape[-1])
    prototypes.scatter_add_(0, labels, embeddings)
    prototypes = common_utils.normalize_embedding(prototypes)

    return prototypes
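The scatter_add_ call above sums the embedding of every pixel into its label's row, and the final normalization turns those sums into mean directions. A minimal self-contained sketch of the same computation, assuming common_utils.normalize_embedding is plain L2 normalization over the last axis:

import torch

embeddings = torch.randn(6, 4)                    # 6 pixels, 4-D embeddings
labels = torch.tensor([0, 0, 1, 1, 2, 2])         # label per pixel

num_prototypes = int(labels.max()) + 1
prototypes = torch.zeros(num_prototypes, embeddings.shape[-1])
index = labels.view(-1, 1).expand(-1, embeddings.shape[-1])
prototypes.scatter_add_(0, index, embeddings)     # sum embeddings per label
prototypes /= prototypes.norm(dim=-1, keepdim=True).clamp(min=1e-12)
print(prototypes.shape)                           # torch.Size([3, 4])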
Example 2
def embedding_to_rgb(embeddings, project_type='pca'):
    """Project high-dimension embeddings to RGB colors.

  Args:
    embeddings: A 4-D float tensor with shape
      `[batch_size, embedding_dim, height, width]`.
    project_type: pca | random.

  Returns:
    A 4-D uint8 tensor with shape `[batch_size, 3, height, width]`.
  """
    # Transform NCHW to NHWC.
    embeddings = embeddings.permute(0, 2, 3, 1).contiguous()
    embeddings = common_utils.normalize_embedding(embeddings)

    N, H, W, C = embeddings.shape
    if project_type == 'pca':
        rgb = common_utils.pca(embeddings, 3)
    elif project_type == 'random':
        random_inds = torch.randint(0,
                                    C, (3, ),
                                    dtype=torch.long,
                                    device=embeddings.device)
        rgb = torch.index_select(embeddings, -1, random_inds)
    else:
        raise NotImplementedError()

    # Normalize per image.
    rgb = rgb.view(N, -1, 3)
    rgb -= torch.min(rgb, 1, keepdim=True)[0]
    rgb /= torch.max(rgb, 1, keepdim=True)[0]
    rgb *= 255
    rgb = rgb.byte()

    # Transform NHWC to NCHW.
    rgb = rgb.view(N, H, W, 3)
    rgb = rgb.permute(0, 3, 1, 2).contiguous()

    return rgb
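The per-image normalization in the middle of embedding_to_rgb is worth isolating: each image's projected channels are rescaled independently so every visualization spans the full byte range. A standalone sketch with illustrative shapes (not repo code):

import torch

rgb = torch.randn(2, 8, 8, 3)                    # [batch, height, width, 3]
flat = rgb.view(2, -1, 3)
flat = flat - flat.min(dim=1, keepdim=True)[0]   # shift per-image min to 0
flat = flat / flat.max(dim=1, keepdim=True)[0]   # scale per-image max to 1
out = (flat * 255).byte().view(2, 8, 8, 3).permute(0, 3, 1, 2)
print(out.shape, out.dtype)                      # torch.Size([2, 3, 8, 8]) torch.uint8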
Example 3
def segment_by_kmeans(embeddings,
                      labels=None,
                      num_clusters=[5, 5],
                      cluster_indices=None,
                      local_features=None,
                      ignore_index=None,
                      iterations=10):
    """Segment image into prototypes by Spherical KMeans Clustering.

  This function conducts Spherical KMeans Clustering within
  each image.

  Args:
    embeddings: A 4-D float tensor of shape
      `[batch_size, num_channels, height, width]`.
    labels: A 3-D long tensor of shape
      `[batch_size, height, width]`.
    num_clusters: A list of two integers indicating the number of clusters
      along the height and width axes.
    cluster_indices: A 3-D long tensor of shape
      `[batch_size, height, width]`.
    local_features: A 4-D float tensor of shape
      `[batch_size, height, width, 2]`.
    ignore_index: An integer denoting the index of the ignored class.
    iterations: An integer indicating the number of k-means iterations.

  Returns:
    embeddings: A 2-D float tensor of shape `[num_pixels, embedding_dim]`.
    embeddings_with_loc: A 2-D float tensor of shape
      `[num_pixels, embedding_dim + 2]`.
    labels: A 1-D long tensor of per-pixel labels.
    cluster_indices: A 1-D long tensor of per-pixel cluster indices.
    batch_indices: A 1-D long tensor of per-pixel image indices.
  """
    # Convert embeddings from NCHW to NHWC.
    embeddings = embeddings.permute(0, 2, 3, 1).contiguous()
    N, H, W, C = embeddings.shape

    # L-2 normalize the embeddings.
    embeddings = common_utils.normalize_embedding(embeddings)

    # Generate location features.
    if local_features is None:
        local_features = generate_location_features((H, W),
                                                    device=embeddings.device,
                                                    feature_type='float')
        local_features -= 0.5
        local_features = local_features.view(1, H, W, 2).expand(N, H, W, 2)

    # Create initial cluster labels.
    if cluster_indices is None:
        cluster_indices = initialize_cluster_labels(num_clusters, (H, W),
                                                    device=embeddings.device)
        cluster_indices = cluster_indices.view(1, H, W).expand(N, H, W)

    # Create dummy labels if not provided.
    if labels is None:
        labels = torch.zeros((N, H, W),
                             dtype=torch.long,
                             device=embeddings.device)

    # Perform KMeans clustering per image.
    gathered_datas = {
        'labels': [],
        'cluster_indices': [],
        'batch_indices': [],
        'embeddings': [],
        'embeddings_with_loc': []
    }
    for batch_index in range(N):
        # Prepare data for each image.
        cur_labels = labels[batch_index].view(-1)
        cur_cluster_indices = cluster_indices[batch_index].view(-1)
        _, cur_cluster_indices = torch.unique(cur_cluster_indices,
                                              return_inverse=True)

        cur_num_clusters = cur_cluster_indices.max() + 1

        cur_embeddings = embeddings[batch_index].view(-1, C)
        cur_local_features = (local_features[batch_index].view(
            -1, local_features.shape[-1]))
        cur_embeddings_with_loc = torch.cat(
            [cur_embeddings, cur_local_features], -1)
        cur_embeddings_with_loc = common_utils.normalize_embedding(
            cur_embeddings_with_loc)

        # Remove ignore label.
        if ignore_index is not None:
            valid_pixel_indices = torch.ne(cur_labels, ignore_index)
            valid_pixel_indices = valid_pixel_indices.nonzero().view(-1)
            cur_labels = torch.index_select(cur_labels, 0, valid_pixel_indices)
            cur_cluster_indices = torch.index_select(cur_cluster_indices, 0,
                                                     valid_pixel_indices)
            cur_embeddings = torch.index_select(cur_embeddings, 0,
                                                valid_pixel_indices)
            cur_embeddings_with_loc = torch.index_select(
                cur_embeddings_with_loc, 0, valid_pixel_indices)

        # KMeans clustering.
        if cur_embeddings.shape[0] > 0:
            cur_cluster_indices = kmeans_with_initial_labels(
                cur_embeddings_with_loc, cur_cluster_indices, cur_num_clusters,
                iterations)

        # Small hack to keep batch indices unique across GPUs.
        gpu_id = cur_cluster_indices.device.index
        _batch_index = batch_index + (N * gpu_id)

        # Add offset to labels to separate different images.
        cur_batch_indices = torch.zeros_like(cur_cluster_indices)
        cur_batch_indices.fill_(_batch_index)

        # Gather from each image.
        gathered_datas['labels'].append(cur_labels)
        gathered_datas['cluster_indices'].append(cur_cluster_indices)
        gathered_datas['batch_indices'].append(cur_batch_indices)
        gathered_datas['embeddings'].append(cur_embeddings)
        gathered_datas['embeddings_with_loc'].append(cur_embeddings_with_loc)

    # Concatenate results from each image.
    labels = torch.cat(gathered_datas['labels'], 0)
    embeddings = torch.cat(gathered_datas['embeddings'], 0)
    embeddings_with_loc = torch.cat(gathered_datas['embeddings_with_loc'], 0)
    cluster_indices = torch.cat(gathered_datas['cluster_indices'], 0)
    batch_indices = torch.cat(gathered_datas['batch_indices'], 0)

    # Partition segments by image.
    lab_div = cluster_indices.max() + 1
    cluster_indices = batch_indices * lab_div + cluster_indices
    _, cluster_indices = torch.unique(cluster_indices, return_inverse=True)

    # Partition segments by ground-truth labels.
    _, cluster_indices = prepare_prototype_labels(labels, cluster_indices,
                                                  labels.max() + 1)

    return embeddings, embeddings_with_loc,\
           labels, cluster_indices, batch_indices
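segment_by_kmeans delegates the actual clustering to kmeans_with_initial_labels, which is not shown here. Below is a hedged sketch of the alternation it presumably performs (the helper's exact signature and behavior are assumptions): an M-step that recomputes unit-length mean directions and an E-step that reassigns each point to its most cosine-similar prototype.

import torch

def spherical_kmeans(features, labels, num_clusters, iterations=10):
    # Spherical k-means operates on L2-normalized features.
    features = features / features.norm(dim=-1, keepdim=True).clamp(min=1e-12)
    for _ in range(iterations):
        # M-step: sum features per cluster, then renormalize to unit length.
        protos = torch.zeros(num_clusters, features.shape[-1])
        protos.scatter_add_(0, labels.view(-1, 1).expand_as(features), features)
        protos = protos / protos.norm(dim=-1, keepdim=True).clamp(min=1e-12)
        # E-step: pick the prototype with the highest cosine similarity.
        labels = torch.argmax(features @ protos.t(), dim=1)
    return labels

feats = torch.randn(100, 16)
init = torch.randint(0, 4, (100,))
print(spherical_kmeans(feats, init, 4).shape)    # torch.Size([100])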
Example 4
def main():
    """Inference for semantic segmentation.
  """
    # Retrieve experiment configurations.
    args = parse_args('Inference for semantic segmentation.')
    config.network.kmeans_num_clusters = separate_comma(
        args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config)
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported ' + config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config)
    else:
        raise ValueError('Unsupported ' + config.network.prediction_types)

    embedding_model = embedding_model.to("cuda:0")
    prediction_model = prediction_model.to("cuda:0")
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    #prediction_model.load_state_dict(
    #    torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Load memory prototypes.
    semantic_memory_prototypes, semantic_memory_prototype_labels = None, None
    if args.semantic_memory_dir is not None:
        semantic_memory_prototypes, semantic_memory_prototype_labels = (
            segsort_others.load_memory_banks(args.semantic_memory_dir))
        semantic_memory_prototypes = semantic_memory_prototypes.to("cuda:0")
        semantic_memory_prototype_labels = semantic_memory_prototype_labels.to(
            "cuda:0")

        # Remove ignore class.
        valid_prototypes = torch.ne(
            semantic_memory_prototype_labels,
            config.dataset.semantic_ignore_index).nonzero()
        valid_prototypes = valid_prototypes.view(-1)
        semantic_memory_prototypes = torch.index_select(
            semantic_memory_prototypes, 0, valid_prototypes)
        semantic_memory_prototype_labels = torch.index_select(
            semantic_memory_prototype_labels, 0, valid_prototypes)

    # Start inference.
    for data_index in tqdm(range(len(test_dataset))):
        # Image path.
        image_path = test_image_paths[data_index]
        base_name = os.path.basename(image_path).replace('.jpg', '.png')

        # Image resolution.
        image_batch, label_batch, _ = test_dataset[data_index]
        image_h, image_w = image_batch['image'].shape[-2:]

        # Resize the input image.
        if config.test.image_size > 0:
            image_batch['image'] = transforms.resize_with_interpolation(
                image_batch['image'].transpose(1, 2, 0),
                config.test.image_size,
                method='bilinear').transpose(2, 0, 1)
            for lab_name in ['semantic_label', 'instance_label']:
                label_batch[lab_name] = transforms.resize_with_interpolation(
                    label_batch[lab_name],
                    config.test.image_size,
                    method='nearest')
        resize_image_h, resize_image_w = image_batch['image'].shape[-2:]

        # Crop and pad the input image.
        image_batch['image'] = transforms.resize_with_pad(
            image_batch['image'].transpose(1, 2, 0),
            config.test.crop_size,
            image_pad_value=0).transpose(2, 0, 1)
        image_batch['image'] = torch.FloatTensor(
            image_batch['image'][np.newaxis, ...]).to("cuda:0")
        pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

        # Create fake labels whose padded regions are set to the ignore index.
        fake_label_batch = {}
        for label_name in ['semantic_label', 'instance_label']:
            lab = np.zeros((resize_image_h, resize_image_w), dtype=np.uint8)
            lab = transforms.resize_with_pad(
                lab,
                config.test.crop_size,
                image_pad_value=config.dataset.semantic_ignore_index)

            fake_label_batch[label_name] = torch.LongTensor(
                lab[np.newaxis, ...]).to("cuda:0")

        # Move the label batch to GPU.
        for k, v in label_batch.items():
            label_batch[k] = torch.LongTensor(v[np.newaxis, ...]).to("cuda:0")

        # Create the ending index of each patch.
        stride_h, stride_w = config.test.stride
        crop_h, crop_w = config.test.crop_size
        npatches_h = math.ceil(1.0 * (pad_image_h - crop_h) / stride_h) + 1
        npatches_w = math.ceil(1.0 * (pad_image_w - crop_w) / stride_w) + 1
        patch_ind_h = np.linspace(crop_h,
                                  pad_image_h,
                                  npatches_h,
                                  dtype=np.int32)
        patch_ind_w = np.linspace(crop_w,
                                  pad_image_w,
                                  npatches_w,
                                  dtype=np.int32)

        # Create placeholders for full-resolution embeddings.
        embeddings = {}
        counts = torch.FloatTensor(1, 1, pad_image_h,
                                   pad_image_w).zero_().to("cuda:0")
        with torch.no_grad():
            for ind_h in patch_ind_h:
                for ind_w in patch_ind_w:
                    sh, eh = ind_h - crop_h, ind_h
                    sw, ew = ind_w - crop_w, ind_w
                    crop_image_batch = {
                        k: v[:, :, sh:eh, sw:ew]
                        for k, v in image_batch.items()
                    }

                    # Feed-forward.
                    crop_embeddings = embedding_model.generate_embeddings(
                        crop_image_batch, resize_as_input=True)

                    # Initialize embedding.
                    for name in crop_embeddings:
                        if crop_embeddings[name] is None:
                            continue
                        crop_emb = crop_embeddings[name].to("cuda:0")
                        if name in ['embedding']:
                            crop_emb = common_utils.normalize_embedding(
                                crop_emb.permute(0, 2, 3, 1).contiguous())
                            crop_emb = crop_emb.permute(0, 3, 1, 2)
                        else:
                            continue

                        if name not in embeddings.keys():
                            embeddings[name] = torch.FloatTensor(
                                1, crop_emb.shape[1], pad_image_h,
                                pad_image_w).zero_().to("cuda:0")
                        embeddings[name][:, :, sh:eh, sw:ew] += crop_emb
                    counts[:, :, sh:eh, sw:ew] += 1

        for k in embeddings.keys():
            embeddings[k] /= counts

        # KMeans.
        lab_div = config.network.label_divisor
        fake_sem_lab = fake_label_batch['semantic_label']
        fake_inst_lab = fake_label_batch['instance_label']
        clustering_outputs = embedding_model.generate_clusters(
            embeddings.get('embedding', None), fake_sem_lab, fake_inst_lab)
        embeddings.update(clustering_outputs)

        # Generate predictions.
        outputs = prediction_model(
            embeddings, {
                'semantic_memory_prototype': semantic_memory_prototypes,
                'semantic_memory_prototype_label':
                semantic_memory_prototype_labels
            },
            with_loss=False,
            with_prediction=True)

        # Save semantic predictions.
        semantic_pred = outputs.get('semantic_prediction', None)
        if semantic_pred is not None:
            semantic_pred = (semantic_pred.view(
                resize_image_h,
                resize_image_w).cpu().data.numpy().astype(np.uint8))
            semantic_pred = cv2.resize(semantic_pred, (image_w, image_h),
                                       interpolation=cv2.INTER_NEAREST)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)
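The sliding-window logic in main hinges on the patch-index arithmetic: np.linspace produces the ending row/column of each crop, so consecutive crops overlap and the last crop sits flush with the padded border. A tiny self-contained check with made-up sizes:

import math
import numpy as np

pad_h, crop_h, stride_h = 700, 512, 256
npatches_h = math.ceil(1.0 * (pad_h - crop_h) / stride_h) + 1
patch_ind_h = np.linspace(crop_h, pad_h, npatches_h, dtype=np.int32)
print(patch_ind_h)    # [512 700] -> crops cover rows 0:512 and 188:700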
Example 5
    def generate_clusters(self,
                          embeddings,
                          semantic_labels,
                          instance_labels,
                          local_features=None):
        """Perform Spherical KMeans clustering within each image.
    
    We squeeze the numerical values of pixel-wise embeddings
    when concatenating with location coordinates. It provides
    better performance for feature affinity regularization, 
    where labels are propagated to unlabeled segments using
    nearest neighbor retrievals.

    Args:
      embeddings: A 4-D float tensor of shape
        `[batch_size, channels, height, width]`.
      semantic_labels: A 3-D long tensor of shape
        `[batch_size, height, width]`.
      instance_labels: A 3-D long tensor of shape
        `[batch_size, height, width]`.
      local_features: A 4-D float tensor of shape
        `[batch_size, height, width, channels]`.

    Returns:
      A dict with entry `cluster_embedding`, `cluster_embedding_with_loc`,
      `cluster_semantic_label`, `cluster_instance_label`, `cluster_index`
      and `cluster_batch_index`.
    """

        if semantic_labels is not None and instance_labels is not None:
            labels = semantic_labels * self.label_divisor + instance_labels
            ignore_index = labels.max() + 1
            labels = labels.masked_fill(
                semantic_labels == self.semantic_ignore_index, ignore_index)
        else:
            labels = None
            ignore_index = None

        # Spherical KMeans clustering.
        (cluster_embeddings, cluster_embeddings_with_loc, cluster_labels,
         cluster_indices,
         cluster_batch_indices) = (segsort_common.segment_by_kmeans(
             embeddings,
             labels,
             self.kmeans_num_clusters,
             local_features=local_features,
             ignore_index=ignore_index,
             iterations=self.kmeans_iterations))

        # Scale down the magnitude of pixel-wise embeddings (by 0.1)
        # when concatenating with location features.
        if local_features is not None:
            if semantic_labels is not None:
                valid_pixels = semantic_labels != self.semantic_ignore_index
                valid_pixels = valid_pixels.view(-1).nonzero().view(-1)
                local_features = local_features.view(-1,
                                                     local_features.shape[-1])
                local_features = torch.index_select(local_features, 0,
                                                    valid_pixels)
            cluster_embeddings_with_loc = torch.cat(
                [cluster_embeddings * 0.1, local_features], dim=-1)
            cluster_embeddings_with_loc = common_utils.normalize_embedding(
                cluster_embeddings_with_loc)

        cluster_semantic_labels = cluster_labels // self.label_divisor
        cluster_instance_labels = cluster_labels % self.label_divisor

        outputs = {
            'cluster_embedding': cluster_embeddings,
            'cluster_embedding_with_loc': cluster_embeddings_with_loc,
            'cluster_semantic_label': cluster_semantic_labels,
            'cluster_instance_label': cluster_instance_labels,
            'cluster_index': cluster_indices,
            'cluster_batch_index': cluster_batch_indices,
        }

        return outputs
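The packing and unpacking of panoptic ids around segment_by_kmeans is plain integer arithmetic. An illustrative sketch with made-up values (the label_divisor here is arbitrary):

import torch

label_divisor = 1000
semantic = torch.tensor([7, 7, 12])
instance = torch.tensor([0, 1, 0])

panoptic = semantic * label_divisor + instance   # tensor([ 7000,  7001, 12000])
print(panoptic // label_divisor)                 # tensor([ 7,  7, 12])
print(panoptic % label_divisor)                  # tensor([0, 1, 0])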
Example 6
def main():
    """Generate pseudo labels by nearest neighbor retrievals.
  """
    # Retrieve experiment configurations.
    args = parse_args('Generate pseudo labels by nearest neighbor retrievals.')
    config.network.kmeans_num_clusters = separate_comma(
        args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported ' + config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config)
    else:
        raise ValueError('Unsupported ' + config.network.prediction_types)

    embedding_model = embedding_model.to("cuda:0")
    prediction_model = prediction_model.to("cuda:0")
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    #prediction_model.load_state_dict(
    #    torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Load memory prototypes.
    semantic_memory_prototypes, semantic_memory_prototype_labels = None, None
    if args.semantic_memory_dir is not None:
        semantic_memory_prototypes, semantic_memory_prototype_labels = (
            segsort_others.load_memory_banks(args.semantic_memory_dir))
        semantic_memory_prototypes = semantic_memory_prototypes.to("cuda:0")
        semantic_memory_prototype_labels = semantic_memory_prototype_labels.to(
            "cuda:0")

        # Remove ignore class.
        valid_prototypes = torch.ne(
            semantic_memory_prototype_labels,
            config.dataset.semantic_ignore_index).nonzero()
        valid_prototypes = valid_prototypes.view(-1)
        semantic_memory_prototypes = torch.index_select(
            semantic_memory_prototypes, 0, valid_prototypes)
        semantic_memory_prototype_labels = torch.index_select(
            semantic_memory_prototype_labels, 0, valid_prototypes)

    # Start inference.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            original_image_batch, original_label_batch, _ = test_dataset[
                data_index]
            image_h, image_w = original_image_batch['image'].shape[-2:]
            batches = other_utils.create_image_pyramid(original_image_batch,
                                                       original_label_batch,
                                                       scales=[0.5, 1, 1.5, 2],
                                                       is_flip=True)

            lab_tags = np.unique(original_label_batch['semantic_label'])
            lab_tags = lab_tags[lab_tags < config.dataset.num_classes]
            label_tags = np.zeros((config.dataset.num_classes, ), dtype=bool)
            label_tags[lab_tags] = True

            semantic_topks = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).to("cuda:0")
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                # Create fake labels whose padded regions are set to the ignore index.
                fake_label_batch = {}
                for label_name in ['semantic_label', 'instance_label']:
                    lab = np.zeros((resize_image_h, resize_image_w),
                                   dtype=np.uint8)
                    lab = transforms.resize_with_pad(
                        lab,
                        config.test.crop_size,
                        image_pad_value=config.dataset.semantic_ignore_index)

                    fake_label_batch[label_name] = torch.LongTensor(
                        lab[np.newaxis, ...]).to("cuda:0")

                # Move the label batch to GPU (disabled here).
                #for k, v in label_batch.items():
                #  label_batch[k] = torch.LongTensor(v[np.newaxis, ...]).to("cuda:0")

                # Create the ending index of each patch.
                stride_h, stride_w = config.test.stride
                crop_h, crop_w = config.test.crop_size
                npatches_h = math.ceil(1.0 *
                                       (pad_image_h - crop_h) / stride_h) + 1
                npatches_w = math.ceil(1.0 *
                                       (pad_image_w - crop_w) / stride_w) + 1
                patch_ind_h = np.linspace(crop_h,
                                          pad_image_h,
                                          npatches_h,
                                          dtype=np.int32)
                patch_ind_w = np.linspace(crop_w,
                                          pad_image_w,
                                          npatches_w,
                                          dtype=np.int32)

                # Create placeholders for full-resolution embeddings.
                embeddings = {}
                counts = torch.FloatTensor(1, 1, pad_image_h,
                                           pad_image_w).zero_().to("cuda:0")
                for ind_h in patch_ind_h:
                    for ind_w in patch_ind_w:
                        sh, eh = ind_h - crop_h, ind_h
                        sw, ew = ind_w - crop_w, ind_w
                        crop_image_batch = {
                            k: v[:, :, sh:eh, sw:ew]
                            for k, v in image_batch.items()
                        }

                        # Feed-forward.
                        crop_embeddings = embedding_model.generate_embeddings(
                            crop_image_batch, resize_as_input=True)

                        # Initialize embedding.
                        for name in crop_embeddings:
                            if crop_embeddings[name] is None:
                                continue
                            crop_emb = crop_embeddings[name].to("cuda:0")
                            if name in ['embedding']:
                                crop_emb = common_utils.normalize_embedding(
                                    crop_emb.permute(0, 2, 3, 1).contiguous())
                                crop_emb = crop_emb.permute(0, 3, 1, 2)
                            else:
                                continue

                            if name not in embeddings.keys():
                                embeddings[name] = torch.FloatTensor(
                                    1, crop_emb.shape[1], pad_image_h,
                                    pad_image_w).zero_().to("cuda:0")
                            embeddings[name][:, :, sh:eh, sw:ew] += crop_emb
                        counts[:, :, sh:eh, sw:ew] += 1

                for k in embeddings.keys():
                    embeddings[k] /= counts

                # KMeans.
                lab_div = config.network.label_divisor
                fake_sem_lab = fake_label_batch['semantic_label'][
                    ..., :resize_image_h, :resize_image_w]
                fake_inst_lab = fake_label_batch['instance_label'][
                    ..., :resize_image_h, :resize_image_w]
                embs = embeddings['embedding'][
                    ..., :resize_image_h, :resize_image_w]
                clustering_outputs = embedding_model.generate_clusters(
                    embs, fake_sem_lab, fake_inst_lab)
                embeddings.update(clustering_outputs)

                # Generate predictions.
                outputs = prediction_model(embeddings, {
                    'semantic_memory_prototype':
                    semantic_memory_prototypes,
                    'semantic_memory_prototype_label':
                    semantic_memory_prototype_labels
                },
                                           with_loss=False,
                                           with_prediction=True)
                semantic_topk = common_utils.one_hot(
                    outputs['semantic_score'],
                    config.dataset.num_classes).float()
                semantic_topk = torch.mean(semantic_topk, dim=1)
                semantic_topk = semantic_topk.view(resize_image_h,
                                                   resize_image_w, -1)
                #print(semantic_topk.shape)
                semantic_topk = (semantic_topk.data.cpu().numpy().astype(
                    np.float32))
                semantic_topk = cv2.resize(semantic_topk, (image_w, image_h),
                                           interpolation=cv2.INTER_LINEAR)
                if data_info['is_flip']:
                    semantic_topk = semantic_topk[:, ::-1]
                semantic_topks.append(semantic_topk)

            # Save semantic predictions.
            semantic_topks = np.stack(semantic_topks,
                                      axis=0).astype(np.float32)
            #print(semantic_topks.shape)
            semantic_prob = np.mean(semantic_topks, axis=0)
            semantic_prob = semantic_prob.transpose(2, 0, 1)

            # Normalize prediction.
            C, H, W = semantic_prob.shape
            max_prob = np.max(np.reshape(semantic_prob, (C, -1)), axis=1)
            max_prob = np.maximum(max_prob, 0.15)
            max_prob = np.reshape(max_prob, (C, 1, 1))
            max_prob[~label_tags, :, :] = 1
            semantic_prob = semantic_prob / max_prob

            # DenseCRF post-processing.
            image = original_image_batch['image'].astype(np.float32)
            image = image.transpose(1, 2, 0)
            image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            image += np.reshape(config.network.pixel_means, (1, 1, 3))
            image = image * 255
            image = image.astype(np.uint8)

            semantic_prob = postprocessor(image, semantic_prob)

            semantic_pred = np.argmax(semantic_prob, axis=0).astype(np.uint8)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)
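The per-class normalization before DenseCRF is the subtle step in this script: each class's score map is divided by its spatial maximum (floored at 0.15), while classes absent from the image-level tags keep their raw scores. A standalone sketch with illustrative shapes:

import numpy as np

num_classes, H, W = 3, 4, 4
semantic_prob = np.random.rand(num_classes, H, W).astype(np.float32)
label_tags = np.array([True, False, True])       # classes present in the image

max_prob = np.max(semantic_prob.reshape(num_classes, -1), axis=1)
max_prob = np.maximum(max_prob, 0.15).reshape(num_classes, 1, 1)
max_prob[~label_tags, :, :] = 1                  # absent classes stay untouched
semantic_prob = semantic_prob / max_prob
print(semantic_prob.shape)                       # (3, 4, 4)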