Example #1
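All four examples are inference entry points from the same semantic-segmentation codebase; helpers such as parse_args, config, ListDataset, vis_utils, transforms, other_utils, common_utils, and the model builders are project-specific modules assumed to be importable. The third-party imports the snippets rely on are, at minimum:

import math
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
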
def main():
    """Inference for semantic segmentation.
  """
    # Retrieve experiment configurations.
    args = parse_args('Inference for semantic segmentation.')

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')
    os.makedirs(semantic_dir, exist_ok=True)
    os.makedirs(semantic_rgb_dir, exist_ok=True)

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone: ' +
                         config.network.backbone_types)

    prediction_model = softmax_classifier(config).cuda()
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    prediction_model.load_state_dict(
        torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Start inferencing.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            image_batch, label_batch, _ = test_dataset[data_index]
            image_h, image_w = image_batch['image'].shape[-2:]
            batches = other_utils.create_image_pyramid(
                image_batch,
                label_batch,
                scales=[0.5, 0.75, 1, 1.25, 1.5],
                is_flip=True)

            semantic_logits = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and Pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).cuda()
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                # Create the ending index of each patch.
                stride_h, stride_w = config.test.stride
                crop_h, crop_w = config.test.crop_size
                npatches_h = math.ceil(1.0 *
                                       (pad_image_h - crop_h) / stride_h) + 1
                npatches_w = math.ceil(1.0 *
                                       (pad_image_w - crop_w) / stride_w) + 1
                patch_ind_h = np.linspace(crop_h,
                                          pad_image_h,
                                          npatches_h,
                                          dtype=np.int32)
                patch_ind_w = np.linspace(crop_w,
                                          pad_image_w,
                                          npatches_w,
                                          dtype=np.int32)

                # Create placeholders for full-resolution logits and counts.
                semantic_logit = torch.FloatTensor(
                    1, config.dataset.num_classes, pad_image_h,
                    pad_image_w).zero_().to("cuda:0")
                counts = torch.FloatTensor(1, 1, pad_image_h,
                                           pad_image_w).zero_().to("cuda:0")
                for ind_h in patch_ind_h:
                    for ind_w in patch_ind_w:
                        sh, eh = ind_h - crop_h, ind_h
                        sw, ew = ind_w - crop_w, ind_w
                        crop_image_batch = {
                            k: v[:, :, sh:eh, sw:ew]
                            for k, v in image_batch.items()
                        }

                        # Feed-forward.
                        crop_embeddings = embedding_model(crop_image_batch,
                                                          resize_as_input=True)
                        crop_outputs = prediction_model(crop_embeddings)
                        semantic_logit[..., sh:eh, sw:ew] += crop_outputs[
                            'semantic_logit'].to("cuda:0")
                        counts[..., sh:eh, sw:ew] += 1
                semantic_logit /= counts
                semantic_logit = semantic_logit[
                    ..., :resize_image_h, :resize_image_w]
                semantic_logit = F.interpolate(semantic_logit,
                                               size=(image_h, image_w),
                                               mode='bilinear')
                semantic_logit = F.softmax(semantic_logit, dim=1)
                semantic_logit = semantic_logit.data.cpu().numpy().astype(
                    np.float32)
                if data_info['is_flip']:
                    semantic_logit = semantic_logit[..., ::-1]
                semantic_logits.append(semantic_logit)

            # Save semantic predictions.
            semantic_logits = np.concatenate(semantic_logits, axis=0)
            semantic_logits = np.sum(semantic_logits, axis=0)
            semantic_pred = np.argmax(semantic_logits,
                                      axis=0).astype(np.uint8)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)

            # Drop references to per-image tensors so the CUDA caching
            # allocator can release their memory.
            crop_embeddings = {}
            crop_outputs = {}
            torch.cuda.empty_cache()
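
The sliding-window logic above (and again in Example #4) tiles the padded image with overlapping crops whose end indices come from np.linspace, so the last crop always ends exactly at the padded border. A minimal, self-contained sketch of that arithmetic with toy sizes (not the repo's configuration):

import math

import numpy as np


def patch_end_indices(pad_size, crop_size, stride):
    # Number of crops needed to cover the axis at the given stride.
    npatches = math.ceil(1.0 * (pad_size - crop_size) / stride) + 1
    # End index of each crop; the start of a crop is end - crop_size.
    return np.linspace(crop_size, pad_size, npatches, dtype=np.int32)


# Toy example: a 1000-px padded axis, 512-px crops, 341-px stride.
print(patch_end_indices(1000, 512, 341))  # -> [ 512  756 1000]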
Example #2
def main():
    """Generate pseudo labels from CAM by random walk and CRF.
  """
    # Retrieve experiment configurations.
    args = parse_args(
        'Generate pseudo labels from CAM by random walk and CRF.')
    config.network.kmeans_num_clusters = separate_comma(
        args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config)
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config)
    else:
        raise ValueError('Unsupported backbone: ' +
                         config.network.backbone_types)

    embedding_model = embedding_model.to("cuda:0")
    embedding_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)

    # Start inferencing.
    for data_index in tqdm(range(len(test_dataset))):
        # Image path.
        image_path = test_image_paths[data_index]
        base_name = os.path.basename(image_path).replace('.jpg', '.png')

        # Image resolution.
        image_batch, label_batch, _ = test_dataset[data_index]
        image_h, image_w = image_batch['image'].shape[-2:]

        # Load class activation maps (CAM).
        cam = np.load(os.path.join(args.cam_dir,
                                   base_name.replace('.png', '.npy')),
                      allow_pickle=True).item()
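        # 21 channels: background + 20 foreground (PASCAL VOC-style labels).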
        cam_full_arr = np.zeros((21, image_h, image_w), np.float32)
        for k, v in cam.items():
            cam_full_arr[k + 1] = v
        cam_full_arr[0] = np.power(
            1 - np.max(cam_full_arr[1:], axis=0, keepdims=True), ALPHA)
        cam_full_arr = torch.from_numpy(cam_full_arr).cuda()

        # Build a single-scale, flip-augmented image pyramid.
        batches = other_utils.create_image_pyramid(image_batch,
                                                   label_batch,
                                                   scales=[1],
                                                   is_flip=True)

        affs = []
        for image_batch, label_batch, data_info in batches:
            resize_image_h, resize_image_w = image_batch['image'].shape[-2:]
            # Crop and Pad the input image.
            image_batch['image'] = transforms.resize_with_pad(
                image_batch['image'].transpose(1, 2, 0),
                config.test.crop_size,
                image_pad_value=0).transpose(2, 0, 1)
            for lab_name in ['semantic_label', 'instance_label']:
                label_batch[lab_name] = transforms.resize_with_pad(
                    label_batch[lab_name],
                    config.test.crop_size,
                    image_pad_value=255)
            image_batch['image'] = torch.FloatTensor(
                image_batch['image'][np.newaxis, ...]).to("cuda:0")
            for k, v in label_batch.items():
                label_batch[k] = torch.LongTensor(v[np.newaxis,
                                                    ...]).to("cuda:0")
            pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

            with torch.no_grad():
                embeddings = embedding_model(image_batch,
                                             label_batch,
                                             resize_as_input=True)
                embs = embeddings[
                    'embedding'][:, :, :resize_image_h, :resize_image_w]
                if data_info['is_flip']:
                    embs = torch.flip(embs, dims=[3])
                embs = F.interpolate(embs,
                                     size=(image_h // 8, image_w // 8),
                                     mode='bilinear')
                embs = embs / torch.norm(embs, dim=1, keepdim=True)
                embs_flat = embs.view(embs.shape[1], -1)
                aff = torch.matmul(embs_flat.t(),
                                   embs_flat).mul_(5).add_(-5).exp_()
                affs.append(aff)

        aff = torch.mean(torch.stack(affs, dim=0), dim=0)
        cam_full_arr = F.interpolate(cam_full_arr.unsqueeze(0),
                                     scale_factor=1 / 8.,
                                     mode='bilinear').squeeze(0)
        cam_shape = cam_full_arr.shape[-2:]

        # Start random walk.
        aff_mat = aff**20

        trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
        for _ in range(WALK_STEPS):
            trans_mat = torch.matmul(trans_mat, trans_mat)

        cam_vec = cam_full_arr.view(21, -1)
        cam_rw = torch.matmul(cam_vec, trans_mat)
        cam_rw = cam_rw.view(21, cam_shape[0], cam_shape[1])

        cam_rw = cam_rw.data.cpu().numpy()
        cam_rw = cv2.resize(cam_rw.transpose(1, 2, 0),
                            dsize=(image_w, image_h),
                            interpolation=cv2.INTER_LINEAR)
        cam_rw_pred = np.argmax(cam_rw, axis=-1).astype(np.uint8)

        # DenseCRF refinement on the recovered uint8 RGB image.
        image = image_batch['image'].data.cpu().numpy().astype(np.float32)
        image = image[0, :, :image_h, :image_w].transpose(1, 2, 0)
        image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
        image += np.reshape(config.network.pixel_means, (1, 1, 3))
        image = image * 255
        image = image.astype(np.uint8)
        cam_rw = postprocessor(image, cam_rw.transpose(2, 0, 1))

        cam_rw_pred = np.argmax(cam_rw, axis=0).astype(np.uint8)

        # Save semantic predictions.
        semantic_pred = cam_rw_pred

        semantic_pred_name = os.path.join(semantic_dir, base_name)
        if not os.path.isdir(os.path.dirname(semantic_pred_name)):
            os.makedirs(os.path.dirname(semantic_pred_name))
        Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

        semantic_pred_rgb = color_map[semantic_pred]
        semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
        if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
            os.makedirs(os.path.dirname(semantic_pred_rgb_name))
        Image.fromarray(semantic_pred_rgb,
                        mode='RGB').save(semantic_pred_rgb_name)
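
For reference, the random-walk step above in isolation: affinities are sharpened, column-normalized into a transition matrix, and repeatedly squared (squaring k times propagates 2**k walk steps) before being applied to the flattened CAM. A toy-sized sketch (WALK_STEPS is a module-level constant in the original; 4 is an assumed value, and the random affinities are stand-ins for the embedding-derived ones):

import torch

WALK_STEPS = 4  # assumed toy value; defined at module level in the original

num_classes, num_pixels = 21, 16
aff = torch.rand(num_pixels, num_pixels)  # stand-in pairwise affinities

# Sharpen and column-normalize into a transition matrix.
aff_mat = aff**20
trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)

# Squaring k times propagates 2**k random-walk steps.
for _ in range(WALK_STEPS):
    trans_mat = torch.matmul(trans_mat, trans_mat)

cam_vec = torch.rand(num_classes, num_pixels)  # flattened per-class CAM
cam_rw = torch.matmul(cam_vec, trans_mat)      # propagated class scores
print(cam_rw.shape)                            # torch.Size([21, 16])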
Example #3
def main():
    """Generate pseudo labels by softmax classifier.
  """
    # Retrieve experiment configurations.
    args = parse_args('Generate pseudo labels by softmax classifier.')

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone: ' +
                         config.network.backbone_types)

    prediction_model = softmax_classifier(config).cuda()
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    prediction_model.load_state_dict(
        torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Start inferencing.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            original_image_batch, original_label_batch, _ = test_dataset[
                data_index]
            image_h, image_w = original_image_batch['image'].shape[-2:]

            lab_tags = np.unique(original_label_batch['semantic_label'])
            lab_tags = lab_tags[lab_tags < config.dataset.num_classes]
            label_tags = np.zeros((config.dataset.num_classes, ),
                                  dtype=bool)
            label_tags[lab_tags] = True
            label_tags = torch.from_numpy(label_tags).cuda()

            # Build a multi-scale, flipped image pyramid.
            batches = other_utils.create_image_pyramid(original_image_batch,
                                                       original_label_batch,
                                                       scales=[0.75, 1],
                                                       is_flip=True)

            affs = []
            semantic_probs = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and Pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).cuda()
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                embeddings = embedding_model(image_batch, resize_as_input=True)
                outputs = prediction_model(embeddings)

                embs = embeddings[
                    'embedding'][:, :, :resize_image_h, :resize_image_w]
                semantic_logit = outputs['semantic_logit'][
                    ..., :resize_image_h, :resize_image_w]
                if data_info['is_flip']:
                    embs = torch.flip(embs, dims=[3])
                    semantic_logit = torch.flip(semantic_logit, dims=[3])
                embs = F.interpolate(embs,
                                     size=(image_h // 8, image_w // 8),
                                     mode='bilinear')
                embs = embs / torch.norm(embs, dim=1, keepdim=True)
                embs_flat = embs.view(embs.shape[1], -1)
                aff = torch.matmul(embs_flat.t(),
                                   embs_flat).mul_(5).add_(-5).exp_()
                affs.append(aff)

                semantic_logit = F.interpolate(semantic_logit,
                                               size=(image_h // 8,
                                                     image_w // 8),
                                               mode='bilinear')
                semantic_probs.append(semantic_logit)

            # Average logits over scales and flips, then softmax over
            # classes (rather than averaging per-scale probabilities).
            cat_semantic_probs = torch.cat(semantic_probs, dim=0)
            semantic_probs = torch.mean(cat_semantic_probs, dim=0)
            semantic_probs = F.softmax(semantic_probs, dim=0)

            # Normalize the CAM so that each class channel peaks at 1.
            max_prob = torch.max(semantic_probs.view(21, -1), dim=1)[0]
            cam_full_arr = semantic_probs / max_prob.view(21, 1, 1)

            cam_shape = cam_full_arr.shape[-2:]
            label_tags = (~label_tags).view(-1, 1,
                                            1).expand(-1, cam_shape[0],
                                                      cam_shape[1])
            cam_full_arr = cam_full_arr.masked_fill(label_tags, 0)
            if TH is not None:
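                # Fix the background channel at the constant threshold TH.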
                cam_full_arr[0] = TH

            aff = torch.mean(torch.stack(affs, dim=0), dim=0)

            # Start random walk.
            aff_mat = aff**20

            trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
            for _ in range(WALK_STEPS):
                trans_mat = torch.matmul(trans_mat, trans_mat)

            cam_vec = cam_full_arr.view(21, -1)
            cam_rw = torch.matmul(cam_vec, trans_mat)
            cam_rw = cam_rw.view(21, cam_shape[0], cam_shape[1])

            cam_rw = cam_rw.data.cpu().numpy()
            cam_rw = cv2.resize(cam_rw.transpose(1, 2, 0),
                                dsize=(image_w, image_h),
                                interpolation=cv2.INTER_LINEAR)
            cam_rw_pred = np.argmax(cam_rw, axis=-1).astype(np.uint8)

            # CRF
            #image = image_batch['image'].data.cpu().numpy().astype(np.float32)
            #image = image[0, :, :image_h, :image_w].transpose(1, 2, 0)
            #image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            #image += np.reshape(config.network.pixel_means, (1, 1, 3))
            #image = image * 255
            #image = image.astype(np.uint8)
            #cam_rw = postprocessor(image, cam_rw.transpose(2,0,1))

            #cam_rw_pred = np.argmax(cam_rw, axis=0).astype(np.uint8)

            # Save semantic predictions.
            semantic_pred = cam_rw_pred

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)
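
The pixel-pair affinity used in Examples #2 and #3 is an exponentiated, scaled cosine similarity between L2-normalized embeddings, exp(5 * (cos - 1)), which maps similarity 1 to affinity 1 and decays toward exp(-10) for opposite embeddings. A minimal sketch with toy shapes:

import torch

emb = torch.randn(1, 32, 8, 8)                    # (N=1, C, H, W) embeddings
emb = emb / torch.norm(emb, dim=1, keepdim=True)  # L2-normalize each pixel
emb_flat = emb.view(emb.shape[1], -1)             # (C, H*W)

# Cosine similarity in [-1, 1] -> affinity in (0, 1].
aff = torch.matmul(emb_flat.t(), emb_flat).mul_(5).add_(-5).exp_()
print(aff.shape)  # torch.Size([64, 64]) == (H*W, H*W)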
Example #4
def main():
    """Generate pseudo labels by nearest neighbor retrievals.
  """
    # Retrieve experiment configurations.
    args = parse_args('Generate pseudo labels by nearest neighbor retrievals.')
    config.network.kmeans_num_clusters = separate_comma(
        args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone: ' +
                         config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config)
    else:
        raise ValueError('Unsupported prediction type: ' +
                         config.network.prediction_types)

    embedding_model = embedding_model.to("cuda:0")
    prediction_model = prediction_model.to("cuda:0")
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
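    # No trained weights are loaded for the prediction model: segsort
    # predicts by non-parametric nearest-neighbor retrieval.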
    #prediction_model.load_state_dict(
    #    torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Load memory prototypes.
    semantic_memory_prototypes, semantic_memory_prototype_labels = None, None
    if args.semantic_memory_dir is not None:
        semantic_memory_prototypes, semantic_memory_prototype_labels = (
            segsort_others.load_memory_banks(args.semantic_memory_dir))
        semantic_memory_prototypes = semantic_memory_prototypes.to("cuda:0")
        semantic_memory_prototype_labels = semantic_memory_prototype_labels.to(
            "cuda:0")

        # Remove ignore class.
        valid_prototypes = torch.ne(
            semantic_memory_prototype_labels,
            config.dataset.semantic_ignore_index).nonzero()
        valid_prototypes = valid_prototypes.view(-1)
        semantic_memory_prototypes = torch.index_select(
            semantic_memory_prototypes, 0, valid_prototypes)
        semantic_memory_prototype_labels = torch.index_select(
            semantic_memory_prototype_labels, 0, valid_prototypes)

    # Start inferencing.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            original_image_batch, original_label_batch, _ = test_dataset[
                data_index]
            image_h, image_w = original_image_batch['image'].shape[-2:]
            batches = other_utils.create_image_pyramid(original_image_batch,
                                                       original_label_batch,
                                                       scales=[0.5, 1, 1.5, 2],
                                                       is_flip=True)

            lab_tags = np.unique(original_label_batch['semantic_label'])
            lab_tags = lab_tags[lab_tags < config.dataset.num_classes]
            label_tags = np.zeros((config.dataset.num_classes, ),
                                  dtype=bool)
            label_tags[lab_tags] = True

            semantic_topks = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and Pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).to("cuda:0")
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                # Create fake labels so that clustering treats the
                # padded region as ignore (255).
                fake_label_batch = {}
                for label_name in ['semantic_label', 'instance_label']:
                    lab = np.zeros((resize_image_h, resize_image_w),
                                   dtype=np.uint8)
                    lab = transforms.resize_with_pad(
                        lab,
                        config.test.crop_size,
                        image_pad_value=config.dataset.semantic_ignore_index)

                    fake_label_batch[label_name] = torch.LongTensor(
                        lab[np.newaxis, ...]).to("cuda:0")

                # Create the ending index of each patch.
                stride_h, stride_w = config.test.stride
                crop_h, crop_w = config.test.crop_size
                npatches_h = math.ceil(1.0 *
                                       (pad_image_h - crop_h) / stride_h) + 1
                npatches_w = math.ceil(1.0 *
                                       (pad_image_w - crop_w) / stride_w) + 1
                patch_ind_h = np.linspace(crop_h,
                                          pad_image_h,
                                          npatches_h,
                                          dtype=np.int32)
                patch_ind_w = np.linspace(crop_w,
                                          pad_image_w,
                                          npatches_w,
                                          dtype=np.int32)

                # Create placeholders for full-resolution embeddings.
                embeddings = {}
                counts = torch.FloatTensor(1, 1, pad_image_h,
                                           pad_image_w).zero_().to("cuda:0")
                for ind_h in patch_ind_h:
                    for ind_w in patch_ind_w:
                        sh, eh = ind_h - crop_h, ind_h
                        sw, ew = ind_w - crop_w, ind_w
                        crop_image_batch = {
                            k: v[:, :, sh:eh, sw:ew]
                            for k, v in image_batch.items()
                        }

                        # Feed-forward.
                        crop_embeddings = embedding_model.generate_embeddings(
                            crop_image_batch, resize_as_input=True)

                        # Initialize embedding.
                        for name in crop_embeddings:
                            if crop_embeddings[name] is None:
                                continue
                            crop_emb = crop_embeddings[name].to("cuda:0")
                            if name in ['embedding']:
                                crop_emb = common_utils.normalize_embedding(
                                    crop_emb.permute(0, 2, 3, 1).contiguous())
                                crop_emb = crop_emb.permute(0, 3, 1, 2)
                            else:
                                continue

                            if name not in embeddings.keys():
                                embeddings[name] = torch.FloatTensor(
                                    1, crop_emb.shape[1], pad_image_h,
                                    pad_image_w).zero_().to("cuda:0")
                            embeddings[name][:, :, sh:eh, sw:ew] += crop_emb
                        counts[:, :, sh:eh, sw:ew] += 1

                for k in embeddings.keys():
                    embeddings[k] /= counts

                # KMeans.
                lab_div = config.network.label_divisor
                fake_sem_lab = fake_label_batch['semantic_label'][
                    ..., :resize_image_h, :resize_image_w]
                fake_inst_lab = fake_label_batch['instance_label'][
                    ..., :resize_image_h, :resize_image_w]
                embs = embeddings['embedding'][
                    ..., :resize_image_h, :resize_image_w]
                clustering_outputs = embedding_model.generate_clusters(
                    embs, fake_sem_lab, fake_inst_lab)
                embeddings.update(clustering_outputs)

                # Generate predictions.
                outputs = prediction_model(embeddings, {
                    'semantic_memory_prototype':
                    semantic_memory_prototypes,
                    'semantic_memory_prototype_label':
                    semantic_memory_prototype_labels
                },
                                           with_loss=False,
                                           with_prediction=True)
                semantic_topk = common_utils.one_hot(
                    outputs['semantic_score'],
                    config.dataset.num_classes).float()
                semantic_topk = torch.mean(semantic_topk, dim=1)
                semantic_topk = semantic_topk.view(resize_image_h,
                                                   resize_image_w, -1)
                semantic_topk = (semantic_topk.data.cpu().numpy().astype(
                    np.float32))
                semantic_topk = cv2.resize(semantic_topk, (image_w, image_h),
                                           interpolation=cv2.INTER_LINEAR)
                if data_info['is_flip']:
                    semantic_topk = semantic_topk[:, ::-1]
                semantic_topks.append(semantic_topk)

            # Save semantic predictions.
            semantic_topks = np.stack(semantic_topks,
                                      axis=0).astype(np.float32)
            semantic_prob = np.mean(semantic_topks, axis=0)
            semantic_prob = semantic_prob.transpose(2, 0, 1)

            # Normalize prediction.
            C, H, W = semantic_prob.shape
            max_prob = np.max(np.reshape(semantic_prob, (C, -1)), axis=1)
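            # Floor each class peak at 0.15 so channels with uniformly
            # low scores are not over-amplified by the division below.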
            max_prob = np.maximum(max_prob, 0.15)
            max_prob = np.reshape(max_prob, (C, 1, 1))
            max_prob[~label_tags, :, :] = 1
            semantic_prob = semantic_prob / max_prob

            # DenseCRF post-processing.
            image = original_image_batch['image'].astype(np.float32)
            image = image.transpose(1, 2, 0)
            image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            image += np.reshape(config.network.pixel_means, (1, 1, 3))
            image = image * 255
            image = image.astype(np.uint8)

            semantic_prob = postprocessor(image, semantic_prob)

            semantic_pred = np.argmax(semantic_prob, axis=0).astype(np.uint8)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)
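
DenseCRF here is a project-specific wrapper; the repo's actual class is not shown. A plausible minimal implementation on top of pydensecrf, matching the constructor arguments and the postprocessor(image, prob) call used above (a hedged sketch, not the repo's code):

import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax


class DenseCRF:
    """Fully-connected CRF with Gaussian and bilateral pairwise terms."""

    def __init__(self, iter_max, pos_xy_std, pos_w, bi_xy_std, bi_rgb_std,
                 bi_w):
        self.iter_max = iter_max
        self.pos_xy_std = pos_xy_std
        self.pos_w = pos_w
        self.bi_xy_std = bi_xy_std
        self.bi_rgb_std = bi_rgb_std
        self.bi_w = bi_w

    def __call__(self, image, probmap):
        # image: (H, W, 3) uint8. probmap: (C, H, W) float32 probabilities.
        num_classes, h, w = probmap.shape
        crf = dcrf.DenseCRF2D(w, h, num_classes)
        crf.setUnaryEnergy(unary_from_softmax(probmap))
        # Smoothness kernel (spatial proximity only).
        crf.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
        # Appearance kernel (spatial proximity + color similarity).
        crf.addPairwiseBilateral(sxy=self.bi_xy_std,
                                 srgb=self.bi_rgb_std,
                                 rgbim=np.ascontiguousarray(image),
                                 compat=self.bi_w)
        q = crf.inference(self.iter_max)
        return np.array(q).reshape(num_classes, h, w)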