Example #1
def main():
    """Inference for semantic segmentation.
  """
    # Retreve experiment configurations.
    args = parse_args('Inference for semantic segmentation.')
    config.network.kmeans_num_clusters = separate_comma(
        args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config)
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config)
    else:
        raise ValueError('Unsupported prediction type: ' +
                         config.network.prediction_types)

    embedding_model = embedding_model.to("cuda:0")
    prediction_model = prediction_model.to("cuda:0")
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    #prediction_model.load_state_dict(
    #    torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Load memory prototypes.
    semantic_memory_prototypes, semantic_memory_prototype_labels = None, None
    if args.semantic_memory_dir is not None:
        semantic_memory_prototypes, semantic_memory_prototype_labels = (
            segsort_others.load_memory_banks(args.semantic_memory_dir))
        semantic_memory_prototypes = semantic_memory_prototypes.to("cuda:0")
        semantic_memory_prototype_labels = semantic_memory_prototype_labels.to(
            "cuda:0")

        # Remove ignore class.
        valid_prototypes = torch.ne(
            semantic_memory_prototype_labels,
            config.dataset.semantic_ignore_index).nonzero()
        valid_prototypes = valid_prototypes.view(-1)
        semantic_memory_prototypes = torch.index_select(
            semantic_memory_prototypes, 0, valid_prototypes)
        semantic_memory_prototype_labels = torch.index_select(
            semantic_memory_prototype_labels, 0, valid_prototypes)

    # Start inference.
    for data_index in tqdm(range(len(test_dataset))):
        # Image path.
        image_path = test_image_paths[data_index]
        base_name = os.path.basename(image_path).replace('.jpg', '.png')

        # Image resolution.
        image_batch, label_batch, _ = test_dataset[data_index]
        image_h, image_w = image_batch['image'].shape[-2:]

        # Resize the input image.
        if config.test.image_size > 0:
            image_batch['image'] = transforms.resize_with_interpolation(
                image_batch['image'].transpose(1, 2, 0),
                config.test.image_size,
                method='bilinear').transpose(2, 0, 1)
            for lab_name in ['semantic_label', 'instance_label']:
                label_batch[lab_name] = transforms.resize_with_interpolation(
                    label_batch[lab_name],
                    config.test.image_size,
                    method='nearest')
        resize_image_h, resize_image_w = image_batch['image'].shape[-2:]

        # Crop and Pad the input image.
        image_batch['image'] = transforms.resize_with_pad(
            image_batch['image'].transpose(1, 2, 0),
            config.test.crop_size,
            image_pad_value=0).transpose(2, 0, 1)
        image_batch['image'] = torch.FloatTensor(
            image_batch['image'][np.newaxis, ...]).to("cuda:0")
        pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

        # Create fake labels whose padding is set to the ignore index (255),
        # so clustering skips the padded regions.
        fake_label_batch = {}
        for label_name in ['semantic_label', 'instance_label']:
            lab = np.zeros((resize_image_h, resize_image_w), dtype=np.uint8)
            lab = transforms.resize_with_pad(
                lab,
                config.test.crop_size,
                image_pad_value=config.dataset.semantic_ignore_index)

            fake_label_batch[label_name] = torch.LongTensor(
                lab[np.newaxis, ...]).to("cuda:0")

        # Move the label batch to the GPU.
        for k, v in label_batch.items():
            label_batch[k] = torch.LongTensor(v[np.newaxis, ...]).to("cuda:0")

        # Compute the ending indices of the patches.
        stride_h, stride_w = config.test.stride
        crop_h, crop_w = config.test.crop_size
        npatches_h = math.ceil(1.0 * (pad_image_h - crop_h) / stride_h) + 1
        npatches_w = math.ceil(1.0 * (pad_image_w - crop_w) / stride_w) + 1
        patch_ind_h = np.linspace(crop_h,
                                  pad_image_h,
                                  npatches_h,
                                  dtype=np.int32)
        patch_ind_w = np.linspace(crop_w,
                                  pad_image_w,
                                  npatches_w,
                                  dtype=np.int32)

        # Create placeholders for full-resolution embeddings.
        embeddings = {}
        counts = torch.FloatTensor(1, 1, pad_image_h,
                                   pad_image_w).zero_().to("cuda:0")
        with torch.no_grad():
            for ind_h in patch_ind_h:
                for ind_w in patch_ind_w:
                    sh, eh = ind_h - crop_h, ind_h
                    sw, ew = ind_w - crop_w, ind_w
                    crop_image_batch = {
                        k: v[:, :, sh:eh, sw:ew]
                        for k, v in image_batch.items()
                    }

                    # Feed-forward.
                    crop_embeddings = embedding_model.generate_embeddings(
                        crop_image_batch, resize_as_input=True)

                    # Normalize and accumulate the crop embeddings.
                    for name in crop_embeddings:
                        if crop_embeddings[name] is None:
                            continue
                        crop_emb = crop_embeddings[name].to("cuda:0")
                        if name in ['embedding']:
                            crop_emb = common_utils.normalize_embedding(
                                crop_emb.permute(0, 2, 3, 1).contiguous())
                            crop_emb = crop_emb.permute(0, 3, 1, 2)
                        else:
                            continue

                        if name not in embeddings.keys():
                            embeddings[name] = torch.FloatTensor(
                                1, crop_emb.shape[1], pad_image_h,
                                pad_image_w).zero_().to("cuda:0")
                        embeddings[name][:, :, sh:eh, sw:ew] += crop_emb
                    counts[:, :, sh:eh, sw:ew] += 1

        for k in embeddings.keys():
            embeddings[k] /= counts

        # KMeans.
        lab_div = config.network.label_divisor
        fake_sem_lab = fake_label_batch['semantic_label']
        fake_inst_lab = fake_label_batch['instance_label']
        clustering_outputs = embedding_model.generate_clusters(
            embeddings.get('embedding', None), fake_sem_lab, fake_inst_lab)
        embeddings.update(clustering_outputs)

        # Generate predictions.
        outputs = prediction_model(
            embeddings, {
                'semantic_memory_prototype': semantic_memory_prototypes,
                'semantic_memory_prototype_label':
                semantic_memory_prototype_labels
            },
            with_loss=False,
            with_prediction=True)

        # Save semantic predictions.
        semantic_pred = outputs.get('semantic_prediction', None)
        if semantic_pred is not None:
            semantic_pred = (semantic_pred.view(
                resize_image_h,
                resize_image_w).cpu().data.numpy().astype(np.uint8))
            semantic_pred = cv2.resize(semantic_pred, (image_w, image_h),
                                       interpolation=cv2.INTER_NEAREST)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)
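
A note on the sliding-window scheme in Example #1: the patch end indices come
from np.linspace, so the last window always ends exactly at the padded border,
and overlapping predictions are averaged through a count map. Below is a
minimal, self-contained NumPy sketch of the same idea; the function and
variable names are illustrative, not from this repository.

import math

import numpy as np


def sliding_window_average(image, crop, stride, score_fn):
    """Average per-pixel scores over overlapping crops (illustrative)."""
    h, w = image.shape[-2:]
    n_h = math.ceil((h - crop) / stride) + 1
    n_w = math.ceil((w - crop) / stride) + 1
    # Spread end indices evenly so the last window ends at the border.
    ends_h = np.linspace(crop, h, n_h, dtype=np.int32)
    ends_w = np.linspace(crop, w, n_w, dtype=np.int32)
    out = np.zeros((h, w), dtype=np.float32)
    counts = np.zeros((h, w), dtype=np.float32)
    for eh in ends_h:
        for ew in ends_w:
            sh, sw = eh - crop, ew - crop
            out[sh:eh, sw:ew] += score_fn(image[..., sh:eh, sw:ew])
            counts[sh:eh, sw:ew] += 1
    # Every pixel is covered at least once, so counts is never zero.
    return out / counts


# Usage with a dummy scorer on a random single-channel "image".
rng = np.random.default_rng(0)
scores = sliding_window_average(rng.random((1, 64, 64)), crop=32, stride=16,
                                score_fn=lambda c: c.mean(axis=0))
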
Example #2
def main():
    """Training for softmax classifier only.
  """
    # Retreve experiment configurations.
    args = parse_args('Training for softmax classifier only.')

    # Retrieve GPU information.
    device_ids = [int(i) for i in config.gpus.split(',')]
    gpu_ids = [torch.device('cuda', i) for i in device_ids]
    num_gpus = len(gpu_ids)

    # Create logger and tensorboard writer.
    summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir)
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)

    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    optimizer_path_template = os.path.join(args.snapshot_dir,
                                           'model-{:d}.state.pth')

    # Create data loaders.
    train_dataset = ListTagClassifierDataset(
        data_dir=args.data_dir,
        data_list=args.data_list,
        img_mean=config.network.pixel_means,
        img_std=config.network.pixel_stds,
        size=config.train.crop_size,
        random_crop=config.train.random_crop,
        random_scale=config.train.random_scale,
        random_mirror=config.train.random_mirror,
        random_grayscale=True,
        random_blur=True,
        training=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=config.train.shuffle,
        num_workers=num_gpus * config.num_threads,
        drop_last=False,
        collate_fn=train_dataset.collate_fn)

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    if config.network.prediction_types == 'softmax_classifier':
        prediction_model = softmax_classifier(config).cuda()
    else:
        raise ValueError('Unsupported prediction type: ' +
                         config.network.prediction_types)

    # Use customized optimizer and pass lr=1 to support different lr for
    # different weights.
    optimizer = SGD(embedding_model.get_params_lr() +
                    prediction_model.get_params_lr(),
                    lr=1,
                    momentum=config.train.momentum,
                    weight_decay=config.train.weight_decay)
    optimizer.zero_grad()

    # Load pre-trained weights.
    curr_iter = config.train.begin_iteration
    if config.network.pretrained:
        print('Loading pre-trained model: {:s}'.format(
            config.network.pretrained))
        embedding_model.load_state_dict(torch.load(
            config.network.pretrained)['embedding_model'],
                                        resume=True)
    else:
        raise ValueError('Pre-trained model is required.')

    # Distribute model weights to multi-gpus.
    embedding_model = DataParallel(embedding_model,
                                   device_ids=device_ids,
                                   gather_output=False)
    prediction_model = DataParallel(prediction_model,
                                    device_ids=device_ids,
                                    gather_output=False)

    embedding_model.eval()
    prediction_model.train()
    print(embedding_model)
    print(prediction_model)

    # Create memory bank.
    memory_banks = {}

    # Start training.
    train_iterator = iter(train_loader)
    iterator_index = 0
    pbar = tqdm(range(curr_iter, config.train.max_iteration))
    for curr_iter in pbar:
        # Check whether the remaining data are enough for another step;
        # otherwise, re-initialize the data iterator.
        if iterator_index + num_gpus >= len(train_loader):
            train_iterator = iter(train_loader)
            iterator_index = 0

        # Feed-forward.
        image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu(
            train_iterator, gpu_ids)
        iterator_index += num_gpus

        # Generate embeddings, clustering and prototypes.
        with torch.no_grad():
            embeddings = embedding_model(*zip(image_batch, label_batch))

        # Compute loss.
        outputs = prediction_model(*zip(embeddings, label_batch))
        outputs = scatter_gather.gather(outputs, gpu_ids[0])
        losses = []
        for k in ['sem_ann_loss']:
            loss = outputs.get(k, None)
            if loss is not None:
                outputs[k] = loss.mean()
                losses.append(outputs[k])
        loss = sum(losses)
        acc = outputs['accuracy'].mean()

        # Backward propagation.
        if config.train.lr_policy == 'step':
            lr = train_utils.lr_step(config.train.base_lr, curr_iter,
                                     config.train.decay_iterations,
                                     config.train.warmup_iteration)
        else:
            lr = train_utils.lr_poly(config.train.base_lr, curr_iter,
                                     config.train.max_iteration,
                                     config.train.warmup_iteration)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step(lr)

        # Snapshot the trained model.
        if ((curr_iter + 1) % config.train.snapshot_step == 0
                or curr_iter == config.train.max_iteration - 1):
            model_state_dict = {
                'embedding_model': embedding_model.module.state_dict(),
                'prediction_model': prediction_model.module.state_dict()
            }
            torch.save(model_state_dict, model_path_template.format(curr_iter))
            torch.save(optimizer.state_dict(),
                       optimizer_path_template.format(curr_iter))

        # Print loss in the progress bar.
        line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
            loss.item(), acc.item(), lr)
        pbar.set_description(line)
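
The learning-rate helpers train_utils.lr_step and train_utils.lr_poly are not
shown in these examples. A plausible sketch of the poly schedule with linear
warmup follows; the signature and the power of 0.9 are assumptions, not the
repository's actual definition.

def lr_poly(base_lr, curr_iter, max_iter, warmup_iter=0, power=0.9):
    """Polynomial learning-rate decay with linear warmup (assumed form)."""
    if warmup_iter > 0 and curr_iter < warmup_iter:
        # Ramp linearly from 0 up to the decayed value at the warmup boundary.
        return base_lr * ((1 - warmup_iter / max_iter)**power) * (
            curr_iter / warmup_iter)
    return base_lr * ((1 - curr_iter / max_iter)**power)
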
Example #3
def main():
    """Inference for semantic segmentation.
  """
    # Retreve experiment configurations.
    args = parse_args('Inference for semantic segmentation.')

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')
    if not os.path.isdir(semantic_dir):
        os.makedirs(semantic_dir)
    if not os.path.isdir(semantic_rgb_dir):
        os.makedirs(semantic_rgb_dir)

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    prediction_model = softmax_classifier(config).cuda()
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    prediction_model.load_state_dict(
        torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Start inference.
    for data_index in range(len(test_dataset)):
        # Image path.
        image_path = test_image_paths[data_index]
        base_name = os.path.basename(image_path).replace('.jpg', '.png')

        # Image resolution.
        image_batch, _, _ = test_dataset[data_index]
        image_h, image_w = image_batch['image'].shape[-2:]

        # Resize the input image.
        if config.test.image_size > 0:
            image_batch['image'] = transforms.resize_with_interpolation(
                image_batch['image'].transpose(1, 2, 0),
                config.test.image_size,
                method='bilinear').transpose(2, 0, 1)
        resize_image_h, resize_image_w = image_batch['image'].shape[-2:]

        # Crop and Pad the input image.
        image_batch['image'] = transforms.resize_with_pad(
            image_batch['image'].transpose(1, 2, 0),
            config.test.crop_size,
            image_pad_value=0).transpose(2, 0, 1)
        image_batch['image'] = torch.FloatTensor(
            image_batch['image'][np.newaxis, ...]).cuda()
        pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

        # Compute the ending indices of the patches.
        stride_h, stride_w = config.test.stride
        crop_h, crop_w = config.test.crop_size
        npatches_h = math.ceil(1.0 * (pad_image_h - crop_h) / stride_h) + 1
        npatches_w = math.ceil(1.0 * (pad_image_w - crop_w) / stride_w) + 1
        patch_ind_h = np.linspace(crop_h,
                                  pad_image_h,
                                  npatches_h,
                                  dtype=np.int32)
        patch_ind_w = np.linspace(crop_w,
                                  pad_image_w,
                                  npatches_w,
                                  dtype=np.int32)

        # Create placeholders for full-resolution outputs.
        outputs = {}
        with torch.no_grad():
            for ind_h in patch_ind_h:
                for ind_w in patch_ind_w:
                    sh, eh = ind_h - crop_h, ind_h
                    sw, ew = ind_w - crop_w, ind_w
                    crop_image_batch = {
                        k: v[:, :, sh:eh, sw:ew]
                        for k, v in image_batch.items()
                    }

                    # Feed-forward.
                    crop_embeddings = embedding_model(crop_image_batch,
                                                      resize_as_input=True)
                    crop_outputs = prediction_model(crop_embeddings)

                    for name, crop_out in crop_outputs.items():

                        if crop_out is not None:
                            if name not in outputs.keys():
                                output_shape = list(crop_out.shape)
                                output_shape[-2:] = pad_image_h, pad_image_w
                                outputs[name] = torch.zeros(
                                    output_shape, dtype=crop_out.dtype).cuda()
                            outputs[name][..., sh:eh, sw:ew] += crop_out

        # Save semantic predictions.
        semantic_logits = outputs.get('semantic_logit', None)
        if semantic_logits is not None:
            semantic_prob = F.softmax(semantic_logits, dim=1)
            semantic_prob = semantic_prob[
                0, :, :resize_image_h, :resize_image_w]
            semantic_prob = semantic_prob.data.cpu().numpy().astype(np.float32)

            # DenseCRF post-processing.
            image = image_batch['image'][0].data.cpu().numpy().astype(
                np.float32)
            image = image.transpose(1, 2, 0)
            image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            image += np.reshape(config.network.pixel_means, (1, 1, 3))
            image = image * 255
            image = image.astype(np.uint8)
            image = image[:resize_image_h, :resize_image_w, :]

            semantic_prob = postprocessor(image, semantic_prob)

            #semantic_pred = torch.argmax(semantic_logits, 1)
            semantic_pred = np.argmax(semantic_prob, axis=0).astype(np.uint8)
            #semantic_pred = (semantic_pred.view(pad_image_h, pad_image_w)
            #                              .cpu()
            #                              .data
            #                              .numpy()
            #                              .astype(np.uint8))
            #semantic_pred = semantic_pred[:resize_image_h, :resize_image_w]
            semantic_pred = cv2.resize(semantic_pred, (image_w, image_h),
                                       interpolation=cv2.INTER_NEAREST)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)

            # Clean GPU memory cache to save more space.
            outputs = {}
            crop_embeddings = {}
            crop_outputs = {}
            torch.cuda.empty_cache()
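
The DenseCRF class used in Examples #3 and #5 is defined outside these
snippets. Below is a hedged sketch of such a wrapper, consistent with the
constructor arguments above and built on the pydensecrf package; the wrapper
itself is an assumption, while the pydensecrf calls are the library's actual
API.

import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax


class DenseCRF(object):
    """Fully-connected CRF over softmax probabilities (sketch)."""

    def __init__(self, iter_max, pos_xy_std, pos_w, bi_xy_std, bi_rgb_std,
                 bi_w):
        self.iter_max = iter_max
        self.pos_xy_std = pos_xy_std
        self.pos_w = pos_w
        self.bi_xy_std = bi_xy_std
        self.bi_rgb_std = bi_rgb_std
        self.bi_w = bi_w

    def __call__(self, image, prob):
        # image: (H, W, 3) uint8; prob: (C, H, W) float32 softmax scores.
        c, h, w = prob.shape
        crf = dcrf.DenseCRF2D(w, h, c)
        crf.setUnaryEnergy(unary_from_softmax(prob))
        # Smoothness kernel over positions only.
        crf.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
        # Appearance kernel over positions and RGB values.
        crf.addPairwiseBilateral(sxy=self.bi_xy_std, srgb=self.bi_rgb_std,
                                 rgbim=np.ascontiguousarray(image),
                                 compat=self.bi_w)
        q = crf.inference(self.iter_max)
        return np.array(q).reshape(c, h, w)
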
Example #4
def main():
    """Inference for semantic segmentation.
  """
    # Retreve experiment configurations.
    args = parse_args('Inference for semantic segmentation.')

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')
    os.makedirs(semantic_dir, exist_ok=True)
    os.makedirs(semantic_rgb_dir, exist_ok=True)

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    prediction_model = softmax_classifier(config).cuda()
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    prediction_model.load_state_dict(
        torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Start inference.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            image_batch, label_batch, _ = test_dataset[data_index]
            image_h, image_w = image_batch['image'].shape[-2:]
            batches = other_utils.create_image_pyramid(
                image_batch,
                label_batch,
                scales=[0.5, 0.75, 1, 1.25, 1.5],
                is_flip=True)

            semantic_logits = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and Pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).cuda()
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                # Compute the ending indices of the patches.
                stride_h, stride_w = config.test.stride
                crop_h, crop_w = config.test.crop_size
                npatches_h = math.ceil(1.0 *
                                       (pad_image_h - crop_h) / stride_h) + 1
                npatches_w = math.ceil(1.0 *
                                       (pad_image_w - crop_w) / stride_w) + 1
                patch_ind_h = np.linspace(crop_h,
                                          pad_image_h,
                                          npatches_h,
                                          dtype=np.int32)
                patch_ind_w = np.linspace(crop_w,
                                          pad_image_w,
                                          npatches_w,
                                          dtype=np.int32)

                # Create placeholders for full-resolution logits.
                semantic_logit = torch.FloatTensor(
                    1, config.dataset.num_classes, pad_image_h,
                    pad_image_w).zero_().to("cuda:0")
                counts = torch.FloatTensor(1, 1, pad_image_h,
                                           pad_image_w).zero_().to("cuda:0")
                for ind_h in patch_ind_h:
                    for ind_w in patch_ind_w:
                        sh, eh = ind_h - crop_h, ind_h
                        sw, ew = ind_w - crop_w, ind_w
                        crop_image_batch = {
                            k: v[:, :, sh:eh, sw:ew]
                            for k, v in image_batch.items()
                        }

                        # Feed-forward.
                        crop_embeddings = embedding_model(crop_image_batch,
                                                          resize_as_input=True)
                        crop_outputs = prediction_model(crop_embeddings)
                        semantic_logit[..., sh:eh, sw:ew] += crop_outputs[
                            'semantic_logit'].to("cuda:0")
                        counts[..., sh:eh, sw:ew] += 1
                semantic_logit /= counts
                semantic_logit = semantic_logit[
                    ..., :resize_image_h, :resize_image_w]
                semantic_logit = F.interpolate(semantic_logit,
                                               size=(image_h, image_w),
                                               mode='bilinear')
                semantic_logit = F.softmax(semantic_logit, dim=1)
                semantic_logit = semantic_logit.data.cpu().numpy().astype(
                    np.float32)
                if data_info['is_flip']:
                    semantic_logit = semantic_logit[..., ::-1]
                semantic_logits.append(semantic_logit)

            # Save semantic predictions.
            semantic_logits = np.concatenate(semantic_logits, axis=0)
            semantic_logits = np.sum(semantic_logits, axis=0)
            if semantic_logits is not None:
                semantic_pred = np.argmax(semantic_logits,
                                          axis=0).astype(np.uint8)

                semantic_pred_name = os.path.join(semantic_dir, base_name)
                Image.fromarray(semantic_pred,
                                mode='L').save(semantic_pred_name)

                semantic_pred_rgb = color_map[semantic_pred]
                semantic_pred_rgb_name = os.path.join(semantic_rgb_dir,
                                                      base_name)
                Image.fromarray(semantic_pred_rgb,
                                mode='RGB').save(semantic_pred_rgb_name)

            # Clean GPU memory cache to save more space.
            outputs = {}
            crop_embeddings = {}
            crop_outputs = {}
            torch.cuda.empty_cache()
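
Example #4 folds multi-scale and horizontal-flip test-time augmentation into
the tiling loop. The core pattern, isolated into a self-contained sketch: the
single model callable mapping a (1, 3, h, w) image to (1, C, h, w) logits is
an illustrative stand-in for the embedding/prediction pair, not the
repository's interface.

import torch
import torch.nn.functional as F


def tta_softmax(model, image, scales=(0.5, 0.75, 1.0, 1.25, 1.5), flip=True):
    """Sum softmax scores over scales and horizontal flips (sketch)."""
    _, _, h, w = image.shape
    total = 0.
    for s in scales:
        x = F.interpolate(image, scale_factor=s, mode='bilinear',
                          align_corners=False)
        variants = [x, torch.flip(x, dims=[3])] if flip else [x]
        for i, v in enumerate(variants):
            logit = model(v)
            if i == 1:
                # Undo the flip before accumulating.
                logit = torch.flip(logit, dims=[3])
            logit = F.interpolate(logit, size=(h, w), mode='bilinear',
                                  align_corners=False)
            total = total + F.softmax(logit, dim=1)
    return total
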
Example #5
def main():
    """Generate pseudo labels by softmax classifier.
  """
    # Retreve experiment configurations.
    args = parse_args('Generate pseudo labels by softmax classifier.')

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    prediction_model = softmax_classifier(config).cuda()
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    prediction_model.load_state_dict(
        torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Start inference.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            original_image_batch, original_label_batch, _ = test_dataset[
                data_index]
            image_h, image_w = original_image_batch['image'].shape[-2:]

            lab_tags = np.unique(original_label_batch['semantic_label'])
            lab_tags = lab_tags[lab_tags < config.dataset.num_classes]
            label_tags = np.zeros((config.dataset.num_classes,), dtype=bool)
            label_tags[lab_tags] = True
            label_tags = torch.from_numpy(label_tags).cuda()

            # Build an image pyramid for multi-scale and flip inference.
            batches = other_utils.create_image_pyramid(original_image_batch,
                                                       original_label_batch,
                                                       scales=[0.75, 1],
                                                       is_flip=True)

            affs = []
            semantic_probs = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and Pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).cuda()
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                embeddings = embedding_model(image_batch, resize_as_input=True)
                outputs = prediction_model(embeddings)

                embs = embeddings[
                    'embedding'][:, :, :resize_image_h, :resize_image_w]
                semantic_logit = outputs['semantic_logit'][
                    ..., :resize_image_h, :resize_image_w]
                if data_info['is_flip']:
                    embs = torch.flip(embs, dims=[3])
                    semantic_logit = torch.flip(semantic_logit, dims=[3])
                embs = F.interpolate(embs,
                                     size=(image_h // 8, image_w // 8),
                                     mode='bilinear')
                embs = embs / torch.norm(embs, dim=1, keepdim=True)
                embs_flat = embs.view(embs.shape[1], -1)
                aff = torch.matmul(embs_flat.t(),
                                   embs_flat).mul_(5).add_(-5).exp_()
                affs.append(aff)

                semantic_logit = F.interpolate(semantic_logit,
                                               size=(image_h // 8,
                                                     image_w // 8),
                                               mode='bilinear')
                #semantic_prob = F.softmax(semantic_logit, dim=1)
                #semantic_probs.append(semantic_prob)
                semantic_probs.append(semantic_logit)

            cat_semantic_probs = torch.cat(semantic_probs, dim=0)
            #semantic_probs, _ = torch.max(cat_semantic_probs, dim=0)
            #semantic_probs[0] = torch.min(cat_semantic_probs[:, 0, :, :], dim=0)[0]
            semantic_probs = torch.mean(cat_semantic_probs, dim=0)
            semantic_probs = F.softmax(semantic_probs, dim=0)

            # Normalize the CAM scores.
            max_prob = torch.max(semantic_probs.view(21, -1), dim=1)[0]
            cam_full_arr = semantic_probs / max_prob.view(21, 1, 1)

            cam_shape = cam_full_arr.shape[-2:]
            label_tags = (~label_tags).view(-1, 1,
                                            1).expand(-1, cam_shape[0],
                                                      cam_shape[1])
            cam_full_arr = cam_full_arr.masked_fill(label_tags, 0)
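            # TH is presumably a module-level constant: a fixed score assigned
            # to the background channel before the random walk.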
            if TH is not None:
                cam_full_arr[0] = TH

            aff = torch.mean(torch.stack(affs, dim=0), dim=0)

            # Start random walk.
            aff_mat = aff**20

            trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
            for _ in range(WALK_STEPS):
                trans_mat = torch.matmul(trans_mat, trans_mat)

            cam_vec = cam_full_arr.view(21, -1)
            cam_rw = torch.matmul(cam_vec, trans_mat)
            cam_rw = cam_rw.view(21, cam_shape[0], cam_shape[1])

            cam_rw = cam_rw.data.cpu().numpy()
            cam_rw = cv2.resize(cam_rw.transpose(1, 2, 0),
                                dsize=(image_w, image_h),
                                interpolation=cv2.INTER_LINEAR)
            cam_rw_pred = np.argmax(cam_rw, axis=-1).astype(np.uint8)

            # CRF
            #image = image_batch['image'].data.cpu().numpy().astype(np.float32)
            #image = image[0, :, :image_h, :image_w].transpose(1, 2, 0)
            #image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            #image += np.reshape(config.network.pixel_means, (1, 1, 3))
            #image = image * 255
            #image = image.astype(np.uint8)
            #cam_rw = postprocessor(image, cam_rw.transpose(2,0,1))

            #cam_rw_pred = np.argmax(cam_rw, axis=0).astype(np.uint8)

            # Save semantic predictions.
            semantic_pred = cam_rw_pred

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)
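
The affinity construction in Examples #5 and #7, embs_flat.t() @ embs_flat
followed by .mul_(5).add_(-5).exp_(), computes exp(5 * (cos_sim - 1)) between
every pair of pixel embeddings: 1 for identical directions, decaying toward 0
as they diverge. A small standalone sketch (names are illustrative):

import torch
import torch.nn.functional as F


def pixel_affinity(embs):
    """exp(5 * (cos_sim - 1)) between all pixel pairs (sketch)."""
    flat = embs.view(embs.shape[1], -1)  # (C, H*W)
    cos = flat.t() @ flat  # cosine similarity of unit-length vectors
    return (5.0 * (cos - 1.0)).exp()


emb = F.normalize(torch.randn(1, 8, 4, 4), dim=1)  # unit length per pixel
aff = pixel_affinity(emb)  # (16, 16), values in (0, 1], 1 on the diagonal
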
Example #6
def main():
    """Training for pixel-wise embeddings by pixel-segment
  contrastive learning loss for DensePose.
  """
    # Retreve experiment configurations.
    args = parse_args('Training for pixel-wise embeddings for DensePose.')

    # Retrieve GPU information.
    device_ids = [int(i) for i in config.gpus.split(',')]
    gpu_ids = [torch.device('cuda', i) for i in device_ids]
    num_gpus = len(gpu_ids)

    # Create logger and tensorboard writer.
    summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir)
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)

    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    optimizer_path_template = os.path.join(args.snapshot_dir,
                                           'model-{:d}.state.pth')

    # Create data loaders.
    train_dataset = DenseposeDataset(data_dir=args.data_dir,
                                     data_list=args.data_list,
                                     img_mean=config.network.pixel_means,
                                     img_std=config.network.pixel_stds,
                                     size=config.train.crop_size,
                                     random_crop=config.train.random_crop,
                                     random_scale=config.train.random_scale,
                                     random_mirror=config.train.random_mirror,
                                     training=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=config.train.shuffle,
        num_workers=num_gpus * config.num_threads,
        drop_last=False,
        collate_fn=train_dataset.collate_fn)

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config).cuda()
    else:
        raise ValueError('Unsupported prediction type: ' +
                         config.network.prediction_types)

    # Use synchronized batch normalization.
    if config.network.use_syncbn:
        embedding_model = convert_model(embedding_model).cuda()
        prediction_model = convert_model(prediction_model).cuda()

    # Use customized optimizer and pass lr=1 to support different lr for
    # different weights.
    optimizer = SGD(embedding_model.get_params_lr() +
                    prediction_model.get_params_lr(),
                    lr=1,
                    momentum=config.train.momentum,
                    weight_decay=config.train.weight_decay)
    optimizer.zero_grad()

    # Load pre-trained weights.
    curr_iter = config.train.begin_iteration
    if config.train.resume:
        model_path = model_path_template.format(curr_iter)
        print('Resume training from {:s}'.format(model_path))
        embedding_model.load_state_dict(
            torch.load(model_path)['embedding_model'], resume=True)
        prediction_model.load_state_dict(
            torch.load(model_path)['prediction_model'], resume=True)
        optimizer.load_state_dict(
            torch.load(optimizer_path_template.format(curr_iter)))
    elif config.network.pretrained:
        print('Loading pre-trained model: {:s}'.format(
            config.network.pretrained))
        embedding_model.load_state_dict(torch.load(config.network.pretrained))
    else:
        print('Training from scratch')

    # Distribute model weights to multi-gpus.
    embedding_model = DataParallel(embedding_model,
                                   device_ids=device_ids,
                                   gather_output=False)
    prediction_model = DataParallel(prediction_model,
                                    device_ids=device_ids,
                                    gather_output=False)
    if config.network.use_syncbn:
        patch_replication_callback(embedding_model)
        patch_replication_callback(prediction_model)

    for module in embedding_model.modules():
        if isinstance(module, _BatchNorm) or isinstance(module, _ConvNd):
            print(module.training, module)
    print(embedding_model)
    print(prediction_model)

    # Create memory bank.
    memory_banks = {}

    # Start training.
    train_iterator = iter(train_loader)
    iterator_index = 0
    pbar = tqdm(range(curr_iter, config.train.max_iteration))
    for curr_iter in pbar:
        # Check whether the remaining data are enough for another step;
        # otherwise, re-initialize the data iterator.
        if iterator_index + num_gpus >= len(train_loader):
            train_iterator = iter(train_loader)
            iterator_index = 0

        # Feed-forward.
        image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu(
            train_iterator, gpu_ids)
        iterator_index += num_gpus

        # Generate embeddings, clustering and prototypes.
        embeddings = embedding_model(*zip(image_batch, label_batch))

        # Synchronize cluster indices and compute prototypes.
        c_inds = [emb['cluster_index'] for emb in embeddings]
        cb_inds = [emb['cluster_batch_index'] for emb in embeddings]
        cs_labs = [emb['cluster_semantic_label'] for emb in embeddings]
        ci_labs = [emb['cluster_instance_label'] for emb in embeddings]
        c_embs = [emb['cluster_embedding'] for emb in embeddings]
        c_embs_with_loc = [
            emb['cluster_embedding_with_loc'] for emb in embeddings
        ]
        (prototypes, prototypes_with_loc, prototype_semantic_labels,
         prototype_instance_labels, prototype_batch_indices,
         cluster_indices) = (
             model_utils.gather_clustering_and_update_prototypes(
                 c_embs, c_embs_with_loc, c_inds, cb_inds, cs_labs, ci_labs,
                 'cuda:{:d}'.format(num_gpus - 1)))

        for i in range(len(label_batch)):
            label_batch[i]['prototype'] = prototypes[i]
            label_batch[i]['prototype_with_loc'] = prototypes_with_loc[i]
            label_batch[i][
                'prototype_semantic_label'] = prototype_semantic_labels[i]
            label_batch[i][
                'prototype_instance_label'] = prototype_instance_labels[i]
            label_batch[i]['prototype_batch_index'] = prototype_batch_indices[
                i]
            embeddings[i]['cluster_index'] = cluster_indices[i]

        #semantic_tags = model_utils.gather_and_update_datas(
        #    [lab['semantic_tag'] for lab in label_batch],
        #    'cuda:{:d}'.format(num_gpus-1))
        #for i in range(len(label_batch)):
        #  label_batch[i]['semantic_tag'] = semantic_tags[i]
        #  label_batch[i]['prototype_semantic_tag'] = torch.index_select(
        #      semantic_tags[i],
        #      0,
        #      label_batch[i]['prototype_batch_index'])

        # Add memory bank to label batch.
        for k in memory_banks.keys():
            for i in range(len(label_batch)):
                assert (label_batch[i].get(k, None) is None)
                label_batch[i][k] = [m.to(gpu_ids[i]) for m in memory_banks[k]]

        # Compute loss.
        outputs = prediction_model(*zip(embeddings, label_batch))
        outputs = scatter_gather.gather(outputs, gpu_ids[0])
        losses = []
        for k in [
                'sem_ann_loss', 'sem_occ_loss', 'img_sim_loss', 'feat_aff_loss'
        ]:
            loss = outputs.get(k, None)
            if loss is not None:
                outputs[k] = loss.mean()
                losses.append(outputs[k])
        loss = sum(losses)
        acc = outputs['accuracy'].mean()

        # Write to tensorboard summary.
        writer = (summary_writer if curr_iter %
                  config.train.tensorboard_step == 0 else None)
        if writer is not None:
            summary_vis = []
            summary_val = {}
            # Gather labels to cpu.
            cpu_label_batch = scatter_gather.gather(label_batch, -1)
            summary_vis.append(
                vis_utils.convert_label_to_color(
                    cpu_label_batch['semantic_label'], color_map))
            summary_vis.append(
                vis_utils.convert_label_to_color(
                    cpu_label_batch['instance_label'], color_map))

            # Gather outputs to cpu.
            vis_names = ['embedding']
            cpu_embeddings = scatter_gather.gather(
                [{k: emb.get(k, None)
                  for k in vis_names} for emb in embeddings], -1)
            for vis_name in vis_names:
                if cpu_embeddings.get(vis_name, None) is not None:
                    summary_vis.append(
                        vis_utils.embedding_to_rgb(cpu_embeddings[vis_name],
                                                   'pca'))

            val_names = [
                'sem_ann_loss', 'sem_occ_loss', 'img_sim_loss',
                'feat_aff_loss', 'accuracy'
            ]
            for val_name in val_names:
                if outputs.get(val_name, None) is not None:
                    summary_val[val_name] = outputs[val_name].mean().to('cpu')

            vis_utils.write_image_to_tensorboard(summary_writer, summary_vis,
                                                 summary_vis[-1].shape[-2:],
                                                 curr_iter)
            vis_utils.write_scalars_to_tensorboard(summary_writer, summary_val,
                                                   curr_iter)

        # Backward propagation.
        if config.train.lr_policy == 'step':
            lr = train_utils.lr_step(config.train.base_lr, curr_iter,
                                     config.train.decay_iterations,
                                     config.train.warmup_iteration)
        else:
            lr = train_utils.lr_poly(config.train.base_lr, curr_iter,
                                     config.train.max_iteration,
                                     config.train.warmup_iteration)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step(lr)

        # Update memory banks.
        with torch.no_grad():
            for k in label_batch[0].keys():
                if 'prototype' in k and 'memory' not in k:
                    memory = label_batch[0][k].clone().detach()
                    memory_key = 'memory_' + k
                    if memory_key not in memory_banks.keys():
                        memory_banks[memory_key] = []
                    memory_banks[memory_key].append(memory)
                    if len(memory_banks[memory_key]
                           ) > config.train.memory_bank_size:
                        memory_banks[memory_key] = memory_banks[memory_key][1:]

            # Update batch labels.
            for k in ['memory_prototype_batch_index']:
                memory_labels = memory_banks.get(k, None)
                if memory_labels is not None:
                    for i, memory_label in enumerate(memory_labels):
                        memory_labels[i] += config.train.batch_size * num_gpus

        # Snapshot the trained model.
        if ((curr_iter + 1) % config.train.snapshot_step == 0
                or curr_iter == config.train.max_iteration - 1):
            model_state_dict = {
                'embedding_model': embedding_model.module.state_dict(),
                'prediction_model': prediction_model.module.state_dict()
            }
            torch.save(model_state_dict, model_path_template.format(curr_iter))
            torch.save(optimizer.state_dict(),
                       optimizer_path_template.format(curr_iter))

        # Print loss in the progress bar.
        line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
            loss.item(), acc.item(), lr)
        pbar.set_description(line)
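
The memory-bank update in Example #6 keeps, per key, at most
config.train.memory_bank_size detached prototype tensors and drops the oldest
one first. The same FIFO behavior in a compact sketch; this class is
illustrative, as the repository itself uses plain dicts of lists.

from collections import deque

import torch


class MemoryBank(object):
    """FIFO bank of detached prototype tensors, at most max_size per key."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.banks = {}

    def push(self, key, tensor):
        bank = self.banks.setdefault(key, deque(maxlen=self.max_size))
        # The deque evicts the oldest entry automatically once full.
        bank.append(tensor.clone().detach())

    def get(self, key):
        return list(self.banks.get(key, ()))


bank = MemoryBank(max_size=2)
for step in range(3):
    bank.push('memory_prototype', torch.full((4,), float(step)))
assert len(bank.get('memory_prototype')) == 2  # the step-0 entry was evicted
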
Example #7
def main():
    """Generate pseudo labels from CAM by random walk and CRF.
  """
    # Retreve experiment configurations.
    args = parse_args(
        'Generate pseudo labels from CAM by random walk and CRF.')
    config.network.kmeans_num_clusters = separate_comma(
        args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config)
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' +
                         config.network.backbone_types)

    embedding_model = embedding_model.to("cuda:0")
    embedding_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)

    # Start inference.
    for data_index in tqdm(range(len(test_dataset))):
        # Image path.
        image_path = test_image_paths[data_index]
        base_name = os.path.basename(image_path).replace('.jpg', '.png')

        # Image resolution.
        image_batch, label_batch, _ = test_dataset[data_index]
        image_h, image_w = image_batch['image'].shape[-2:]

        # Load precomputed CAM scores.
        sem_labs = np.unique(label_batch['semantic_label'])
        #cam = np.load(os.path.join('/home/twke/repos/SEAM/outputs/train+/cam', base_name.replace('.png', '.npy')), allow_pickle=True).item()
        cam = np.load(os.path.join(args.cam_dir,
                                   base_name.replace('.png', '.npy')),
                      allow_pickle=True).item()
        cam_full_arr = np.zeros((21, image_h, image_w), np.float32)
        for k, v in cam.items():
            cam_full_arr[k + 1] = v
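        # ALPHA is presumably a module-level constant controlling how sharply
        # the background score falls off.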
        cam_full_arr[0] = np.power(
            1 - np.max(cam_full_arr[1:], axis=0, keepdims=True), ALPHA)
        cam_full_arr = torch.from_numpy(cam_full_arr).cuda()

        # Build an image pyramid for flip augmentation.
        batches = other_utils.create_image_pyramid(image_batch,
                                                   label_batch,
                                                   scales=[1],
                                                   is_flip=True)

        affs = []
        for image_batch, label_batch, data_info in batches:
            resize_image_h, resize_image_w = image_batch['image'].shape[-2:]
            # Crop and Pad the input image.
            image_batch['image'] = transforms.resize_with_pad(
                image_batch['image'].transpose(1, 2, 0),
                config.test.crop_size,
                image_pad_value=0).transpose(2, 0, 1)
            for lab_name in ['semantic_label', 'instance_label']:
                label_batch[lab_name] = transforms.resize_with_pad(
                    label_batch[lab_name],
                    config.test.crop_size,
                    image_pad_value=255)
            image_batch['image'] = torch.FloatTensor(
                image_batch['image'][np.newaxis, ...]).to("cuda:0")
            for k, v in label_batch.items():
                label_batch[k] = torch.LongTensor(v[np.newaxis,
                                                    ...]).to("cuda:0")
            pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

            with torch.no_grad():
                embeddings = embedding_model(image_batch,
                                             label_batch,
                                             resize_as_input=True)
                embs = embeddings[
                    'embedding'][:, :, :resize_image_h, :resize_image_w]
                if data_info['is_flip']:
                    embs = torch.flip(embs, dims=[3])
                embs = F.interpolate(embs,
                                     size=(image_h // 8, image_w // 8),
                                     mode='bilinear')
                embs = embs / torch.norm(embs, dim=1, keepdim=True)
                embs_flat = embs.view(embs.shape[1], -1)
                aff = torch.matmul(embs_flat.t(),
                                   embs_flat).mul_(5).add_(-5).exp_()
                affs.append(aff)

        aff = torch.mean(torch.stack(affs, dim=0), dim=0)
        cam_full_arr = F.interpolate(cam_full_arr.unsqueeze(0),
                                     scale_factor=1 / 8.,
                                     mode='bilinear').squeeze(0)
        cam_shape = cam_full_arr.shape[-2:]

        # Start random walk.
        aff_mat = aff**20

        trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
        for _ in range(WALK_STEPS):
            trans_mat = torch.matmul(trans_mat, trans_mat)

        cam_vec = cam_full_arr.view(21, -1)
        cam_rw = torch.matmul(cam_vec, trans_mat)
        cam_rw = cam_rw.view(21, cam_shape[0], cam_shape[1])

        cam_rw = cam_rw.data.cpu().numpy()
        cam_rw = cv2.resize(cam_rw.transpose(1, 2, 0),
                            dsize=(image_w, image_h),
                            interpolation=cv2.INTER_LINEAR)
        cam_rw_pred = np.argmax(cam_rw, axis=-1).astype(np.uint8)

        # CRF
        image = image_batch['image'].data.cpu().numpy().astype(np.float32)
        image = image[0, :, :image_h, :image_w].transpose(1, 2, 0)
        image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
        image += np.reshape(config.network.pixel_means, (1, 1, 3))
        image = image * 255
        image = image.astype(np.uint8)
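        # The CRF refines the random-walk scores against the raw RGB image,
        # so the normalized network input is converted back to uint8 RGB by
        # undoing the mean/std normalization. Note that image_batch here
        # still holds the last entry of the image pyramid above.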
        cam_rw = postprocessor(image, cam_rw.transpose(2, 0, 1))

        cam_rw_pred = np.argmax(cam_rw, axis=0).astype(np.uint8)

        # Save semantic predictions.
        semantic_pred = cam_rw_pred

        semantic_pred_name = os.path.join(semantic_dir, base_name)
        if not os.path.isdir(os.path.dirname(semantic_pred_name)):
            os.makedirs(os.path.dirname(semantic_pred_name))
        Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

        semantic_pred_rgb = color_map[semantic_pred]
        semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
        if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
            os.makedirs(os.path.dirname(semantic_pred_rgb_name))
        Image.fromarray(semantic_pred_rgb,
                        mode='RGB').save(semantic_pred_rgb_name)
Example #8
0
def main():
    """Generate pseudo labels by random walk and CRF for DensePose.
  """
    # Retreve experiment configurations.
    args = parse_args(
        'Generate pseudo labels by random walk and CRF for DensePose.')
    config.network.kmeans_num_clusters = separate_comma(
        args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Not support ' + config.network.backbone_types)

    embedding_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)

    # Start inference.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            image_batch, label_batch, _ = test_dataset[data_index]
            image_h, image_w = image_batch['image'].shape[-2:]

            lab_tags = np.unique(label_batch['semantic_label'])
            lab_tags = lab_tags[lab_tags < config.dataset.num_classes]
            label_tags = np.zeros((config.dataset.num_classes, ),
                                  dtype=bool)
            label_tags[lab_tags] = True
            label_tags = torch.from_numpy(label_tags).cuda()
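            # label_tags is a multi-hot vector over the valid classes: an
            # entry is True iff that class occurs in the ground-truth
            # semantic label. It is used below to zero out CAM channels of
            # classes absent from the image.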

            # Pad the input image.
            image_batch['image'] = transforms.resize_with_pad(
                image_batch['image'].transpose(1, 2, 0),
                config.test.crop_size,
                image_pad_value=0).transpose(2, 0, 1)
            image_batch['image'] = torch.FloatTensor(
                image_batch['image'][np.newaxis, ...]).cuda()
            pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

            original_semantic_label = label_batch['semantic_label'].copy()
            lab = label_batch['semantic_label']
            lab[lab == 255] = config.dataset.num_classes
            for lab_name in label_batch.keys():
                label_batch[lab_name] = torch.LongTensor(
                    label_batch[lab_name][np.newaxis, ...]).cuda()

            embeddings = embedding_model.generate_embeddings(
                image_batch, resize_as_input=True)
            embeddings['embedding'] = F.interpolate(embeddings['embedding'],
                                                    size=(pad_image_h // 2,
                                                          pad_image_w // 2),
                                                    mode='bilinear')
            embeddings['embedding'] = (
                embeddings['embedding'][:, :, :image_h // 2, :image_w // 2])

            # Create affinity matrix.
            embs = embeddings['embedding']
            embs = F.interpolate(embs,
                                 size=(image_h // 8, image_w // 8),
                                 mode='bilinear')
            embs = embs / torch.norm(embs, dim=1, keepdim=True)
            embs_flat = embs.view(embs.shape[1], -1)
            aff = torch.matmul(embs_flat.t(),
                               embs_flat).mul_(5).add_(-5).exp_()
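            # With L2-normalized embeddings, embs_flat.t() @ embs_flat is a
            # pairwise cosine-similarity matrix, so the affinity equals
            # exp(5 * (sim - 1)): similarities in [-1, 1] map to (0, 1],
            # peaking at 1 for identical embeddings.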

            # Assign unknown labels to nearest neighbor.
            size = embeddings['embedding'].shape[-2:]
            s_lab = common_utils.resize_labels(label_batch['semantic_label'],
                                               size)
            i_lab = common_utils.resize_labels(label_batch['instance_label'],
                                               size)
            clusterings = embedding_model.generate_clusters(
                embeddings['embedding'], s_lab, i_lab)

            s_labs, c_inds = segsort_common.prepare_prototype_labels(
                clusterings['cluster_semantic_label'],
                clusterings['cluster_index'],
                clusterings['cluster_semantic_label'].max() + 1)
            embs = clusterings['cluster_embedding']
            protos = segsort_common.calculate_prototypes_from_labels(
                embs, c_inds)
            s_tags = model_utils.gather_multiset_labels_per_batch_by_nearest_neighbor(
                embs,
                protos,
                s_labs,
                torch.zeros_like(clusterings['cluster_semantic_label']),
                torch.zeros_like(s_labs),
                num_classes=config.dataset.num_classes,
                top_k=1,
                threshold=-1,
                label_divisor=config.network.label_divisor)
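            # s_tags presumably one-hot encodes, per clustered pixel, the
            # class of its nearest prototype (top_k=1); segment_mean then
            # averages these votes within each cluster, and the result is
            # renormalized into per-cluster class probabilities below.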
            s_probs = common_utils.segment_mean(s_tags.float(), c_inds)
            s_probs = s_probs / s_probs.sum(dim=1, keepdim=True)
            semantic_probs = torch.index_select(s_probs, 0, c_inds)
            semantic_probs = semantic_probs.view(1, image_h // 2, image_w // 2,
                                                 -1)
            semantic_probs = semantic_probs.permute(0, 3, 1, 2).contiguous()
            semantic_probs = F.interpolate(semantic_probs,
                                           size=(image_h // 8, image_w // 8),
                                           mode='bilinear')
            max_prob = torch.max(semantic_probs.view(15, -1), dim=1)[0]
            cam_full_arr = semantic_probs / max_prob.view(15, 1, 1)
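            # Rescale each class map by its spatial maximum so every class
            # peaks at 1, mirroring the usual CAM normalization; the
            # hard-coded 15 presumably equals config.dataset.num_classes for
            # DensePose.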

            cam_shape = cam_full_arr.shape[-2:]
            label_tags = (~label_tags).view(-1, 1, 1).expand(
                -1, cam_shape[0], cam_shape[1])
            cam_full_arr = cam_full_arr.masked_fill(label_tags, 0)
            if TH is not None:
                cam_full_arr[0] = TH

            # Start random walk.
            aff_mat = aff**20

            trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
            for _ in range(WALK_STEPS):
                trans_mat = torch.matmul(trans_mat, trans_mat)

            cam_vec = cam_full_arr.view(15, -1)
            cam_rw = torch.matmul(cam_vec, trans_mat)
            cam_rw = cam_rw.view(15, cam_shape[0], cam_shape[1])

            cam_rw = cam_rw.data.cpu().numpy()
            cam_rw = cv2.resize(cam_rw.transpose(1, 2, 0),
                                dsize=(image_w, image_h),
                                interpolation=cv2.INTER_LINEAR)
            cam_rw_pred = np.argmax(cam_rw, axis=-1).astype(np.uint8)

            # CRF
            image = image_batch['image'].data.cpu().numpy().astype(np.float32)
            image = image[0, :, :image_h, :image_w].transpose(1, 2, 0)
            image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            image += np.reshape(config.network.pixel_means, (1, 1, 3))
            image = image * 255
            image = image.astype(np.uint8)
            cam_rw = postprocessor(image, cam_rw.transpose(2, 0, 1))

            cam_rw_pred = np.argmax(cam_rw, axis=0).astype(np.uint8)

            # Ignore regions.
            ignore_mask = original_semantic_label == 255
            cam_rw_pred[ignore_mask] = 255

            # Save semantic predictions.
            semantic_pred = cam_rw_pred

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)
Example #9
0
def main():
    """Generate pseudo labels by nearest neighbor retrievals.
  """
    # Retreve experiment configurations.
    args = parse_args('Generate pseudo labels by nearest neighbor retrievals.')
    config.network.kmeans_num_clusters = separate_comma(
        args.kmeans_num_clusters)
    config.network.label_divisor = args.label_divisor

    # Create directories to save results.
    semantic_dir = os.path.join(args.save_dir, 'semantic_gray')
    semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color')

    # Create color map.
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)
    color_map = color_map.numpy()

    # Create data loaders.
    test_dataset = ListDataset(data_dir=args.data_dir,
                               data_list=args.data_list,
                               img_mean=config.network.pixel_means,
                               img_std=config.network.pixel_stds,
                               size=None,
                               random_crop=False,
                               random_scale=False,
                               random_mirror=False,
                               training=False)
    test_image_paths = test_dataset.image_paths

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Not support ' + config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config)
    else:
        raise ValueError('Not support ' + config.network.prediction_types)

    embedding_model = embedding_model.to("cuda:0")
    prediction_model = prediction_model.to("cuda:0")
    embedding_model.eval()
    prediction_model.eval()

    # Load trained weights.
    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    save_iter = config.train.max_iteration - 1
    embedding_model.load_state_dict(torch.load(
        model_path_template.format(save_iter))['embedding_model'],
                                    resume=True)
    #prediction_model.load_state_dict(
    #    torch.load(model_path_template.format(save_iter))['prediction_model'])

    # Define CRF.
    postprocessor = DenseCRF(
        iter_max=args.crf_iter_max,
        pos_xy_std=args.crf_pos_xy_std,
        pos_w=args.crf_pos_w,
        bi_xy_std=args.crf_bi_xy_std,
        bi_rgb_std=args.crf_bi_rgb_std,
        bi_w=args.crf_bi_w,
    )

    # Load memory prototypes.
    semantic_memory_prototypes, semantic_memory_prototype_labels = None, None
    if args.semantic_memory_dir is not None:
        semantic_memory_prototypes, semantic_memory_prototype_labels = (
            segsort_others.load_memory_banks(args.semantic_memory_dir))
        semantic_memory_prototypes = semantic_memory_prototypes.to("cuda:0")
        semantic_memory_prototype_labels = semantic_memory_prototype_labels.to(
            "cuda:0")

        # Remove ignore class.
        valid_prototypes = torch.ne(
            semantic_memory_prototype_labels,
            config.dataset.semantic_ignore_index).nonzero()
        valid_prototypes = valid_prototypes.view(-1)
        semantic_memory_prototypes = torch.index_select(
            semantic_memory_prototypes, 0, valid_prototypes)
        semantic_memory_prototype_labels = torch.index_select(
            semantic_memory_prototype_labels, 0, valid_prototypes)

    # Start inference.
    with torch.no_grad():
        for data_index in tqdm(range(len(test_dataset))):
            # Image path.
            image_path = test_image_paths[data_index]
            base_name = os.path.basename(image_path).replace('.jpg', '.png')

            # Image resolution.
            original_image_batch, original_label_batch, _ = test_dataset[
                data_index]
            image_h, image_w = original_image_batch['image'].shape[-2:]
            batches = other_utils.create_image_pyramid(original_image_batch,
                                                       original_label_batch,
                                                       scales=[0.5, 1, 1.5, 2],
                                                       is_flip=True)

            lab_tags = np.unique(original_label_batch['semantic_label'])
            lab_tags = lab_tags[lab_tags < config.dataset.num_classes]
            label_tags = np.zeros((config.dataset.num_classes, ),
                                  dtype=bool)
            label_tags[lab_tags] = True

            semantic_topks = []
            for image_batch, label_batch, data_info in batches:
                resize_image_h, resize_image_w = image_batch['image'].shape[
                    -2:]
                # Crop and pad the input image.
                image_batch['image'] = transforms.resize_with_pad(
                    image_batch['image'].transpose(1, 2, 0),
                    config.test.crop_size,
                    image_pad_value=0).transpose(2, 0, 1)
                image_batch['image'] = torch.FloatTensor(
                    image_batch['image'][np.newaxis, ...]).to("cuda:0")
                pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

                # Create the fake labels where clustering ignores 255.
                fake_label_batch = {}
                for label_name in ['semantic_label', 'instance_label']:
                    lab = np.zeros((resize_image_h, resize_image_w),
                                   dtype=np.uint8)
                    lab = transforms.resize_with_pad(
                        lab,
                        config.test.crop_size,
                        image_pad_value=config.dataset.semantic_ignore_index)

                    fake_label_batch[label_name] = torch.LongTensor(
                        lab[np.newaxis, ...]).to("cuda:0")

                # Create the ending index of each patch.
                stride_h, stride_w = config.test.stride
                crop_h, crop_w = config.test.crop_size
                npatches_h = math.ceil(1.0 *
                                       (pad_image_h - crop_h) / stride_h) + 1
                npatches_w = math.ceil(1.0 *
                                       (pad_image_w - crop_w) / stride_w) + 1
                patch_ind_h = np.linspace(crop_h,
                                          pad_image_h,
                                          npatches_h,
                                          dtype=np.int32)
                patch_ind_w = np.linspace(crop_w,
                                          pad_image_w,
                                          npatches_w,
                                          dtype=np.int32)
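                # Tile the padded image with overlapping crop-sized windows:
                # npatches_* counts how many windows of the given stride are
                # needed, and np.linspace yields each window's bottom/right
                # end index so the last patch always flushes with the image
                # border.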

                # Create place holder for full-resolution embeddings.
                embeddings = {}
                counts = torch.FloatTensor(1, 1, pad_image_h,
                                           pad_image_w).zero_().to("cuda:0")
                for ind_h in patch_ind_h:
                    for ind_w in patch_ind_w:
                        sh, eh = ind_h - crop_h, ind_h
                        sw, ew = ind_w - crop_w, ind_w
                        crop_image_batch = {
                            k: v[:, :, sh:eh, sw:ew]
                            for k, v in image_batch.items()
                        }

                        # Feed-forward.
                        crop_embeddings = embedding_model.generate_embeddings(
                            crop_image_batch, resize_as_input=True)

                        # Initialize embedding.
                        for name in crop_embeddings:
                            if crop_embeddings[name] is None:
                                continue
                            crop_emb = crop_embeddings[name].to("cuda:0")
                            if name in ['embedding']:
                                crop_emb = common_utils.normalize_embedding(
                                    crop_emb.permute(0, 2, 3, 1).contiguous())
                                crop_emb = crop_emb.permute(0, 3, 1, 2)
                            else:
                                continue

                            if name not in embeddings.keys():
                                embeddings[name] = torch.FloatTensor(
                                    1, crop_emb.shape[1], pad_image_h,
                                    pad_image_w).zero_().to("cuda:0")
                            embeddings[name][:, :, sh:eh, sw:ew] += crop_emb
                        counts[:, :, sh:eh, sw:ew] += 1

                for k in embeddings.keys():
                    embeddings[k] /= counts
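                # Pixels covered by several crops accumulated multiple
                # embedding predictions; dividing by the per-pixel visit
                # counts averages the overlaps into one full-resolution
                # embedding map.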

                # KMeans.
                lab_div = config.network.label_divisor
                fake_sem_lab = fake_label_batch['semantic_label'][
                    ..., :resize_image_h, :resize_image_w]
                fake_inst_lab = fake_label_batch['instance_label'][
                    ..., :resize_image_h, :resize_image_w]
                embs = embeddings['embedding'][
                    ..., :resize_image_h, :resize_image_w]
                clustering_outputs = embedding_model.generate_clusters(
                    embs, fake_sem_lab, fake_inst_lab)
                embeddings.update(clustering_outputs)

                # Generate predictions.
                outputs = prediction_model(
                    embeddings,
                    {'semantic_memory_prototype': semantic_memory_prototypes,
                     'semantic_memory_prototype_label':
                         semantic_memory_prototype_labels},
                    with_loss=False,
                    with_prediction=True)
                semantic_topk = common_utils.one_hot(
                    outputs['semantic_score'],
                    config.dataset.num_classes).float()
                semantic_topk = torch.mean(semantic_topk, dim=1)
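                # outputs['semantic_score'] presumably holds the labels of
                # the k nearest memory prototypes per pixel; one_hot expands
                # them to class indicators, and the mean over dim 1 turns
                # the k retrievals into a soft class distribution instead of
                # a hard nearest-neighbor vote.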
                semantic_topk = semantic_topk.view(resize_image_h,
                                                   resize_image_w, -1)
                semantic_topk = (semantic_topk.data.cpu().numpy().astype(
                    np.float32))
                semantic_topk = cv2.resize(semantic_topk, (image_w, image_h),
                                           interpolation=cv2.INTER_LINEAR)
                if data_info['is_flip']:
                    semantic_topk = semantic_topk[:, ::-1]
                semantic_topks.append(semantic_topk)

            # Save semantic predictions.
            semantic_topks = np.stack(semantic_topks,
                                      axis=0).astype(np.float32)
            semantic_prob = np.mean(semantic_topks, axis=0)
            semantic_prob = semantic_prob.transpose(2, 0, 1)

            # Normalize prediction.
            C, H, W = semantic_prob.shape
            max_prob = np.max(np.reshape(semantic_prob, (C, -1)), axis=1)
            max_prob = np.maximum(max_prob, 0.15)
            max_prob = np.reshape(max_prob, (C, 1, 1))
            max_prob[~label_tags, :, :] = 1
            semantic_prob = semantic_prob / max_prob
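            # Per-class max normalization with a floor of 0.15 prevents
            # weakly activated classes from being amplified into noise;
            # classes absent from the image-level tags keep their raw
            # (unscaled) scores since their max_prob is set to 1.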

            # DenseCRF post-processing.
            image = original_image_batch['image'].astype(np.float32)
            image = image.transpose(1, 2, 0)
            image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
            image += np.reshape(config.network.pixel_means, (1, 1, 3))
            image = image * 255
            image = image.astype(np.uint8)

            semantic_prob = postprocessor(image, semantic_prob)

            semantic_pred = np.argmax(semantic_prob, axis=0).astype(np.uint8)

            semantic_pred_name = os.path.join(semantic_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_name)):
                os.makedirs(os.path.dirname(semantic_pred_name))
            Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

            semantic_pred_rgb = color_map[semantic_pred]
            semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
            if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
                os.makedirs(os.path.dirname(semantic_pred_rgb_name))
            Image.fromarray(semantic_pred_rgb,
                            mode='RGB').save(semantic_pred_rgb_name)