示例#1
0
class ExtractSegmentationMask:
    """Sample transform that masks out pixels of unwanted semantic classes.

    A pretrained segmentation network (weights loaded from the
    ``ade20k-resnet50dilated-ppm_deepsup`` checkpoint directory) predicts one
    class per pixel. Pixels predicted as an ignored class become ``False`` in
    ``sample['mask']``.
    """

    def __init__(self):
        """
        Load the class-name table and build the pretrained segmentation
        module. The semantic map tells us where to ignore the image, since
        depth info is only available on mountains.
        """
        # Map the integer in the first csv column to the first of the
        # ';'-separated names in the sixth column; the header row is skipped.
        with open('semseg/object150_info.csv') as info_file:
            rows = csv.reader(info_file)
            next(rows)
            self.names = {
                int(row[0]): row[5].split(";")[0]
                for row in rows
            }

        # Network Builders
        self.net_encoder = ModelBuilder.build_encoder(
            arch='resnet50dilated',
            fc_dim=2048,
            weights=
            'semseg/ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth',
        )
        self.net_decoder = ModelBuilder.build_decoder(
            arch='ppm_deepsup',
            fc_dim=2048,
            num_class=150,
            weights=
            'semseg/ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',
            use_softmax=True,
        )

        # SegmentationModule takes a criterion as its third argument even
        # though this class only ever runs it in eval mode.
        self.crit = torch.nn.NLLLoss(ignore_index=-1)
        self.segmentation_module = SegmentationModule(
            self.net_encoder, self.net_decoder, self.crit)
        self.segmentation_module.eval()
        self.segmentation_module.to(device=defs.get_dev())

    def visualize_result(self, img, pred, index=None):
        """Show ``img`` next to the colorized prediction ``pred``.

        When ``index`` is given, every pixel whose predicted class differs
        from ``index`` is set to -1 in a copy of ``pred`` and the class name
        is printed.
        """
        palette = scipy.io.loadmat('semseg/color150.mat')['colors']

        # filter prediction class if requested
        if index is not None:
            pred = pred.copy()  # leave the caller's array untouched
            pred[pred != index] = -1
            # self.names is keyed by the csv index, looked up at index + 1
            print(f'{self.names[index + 1]}:')

        # colorize prediction and scale to [0, 1] floats
        colorized = colorEncode(pred, palette).astype(np.float32) / 255

        # place image and prediction side by side and display them
        side_by_side = np.concatenate((img, colorized), axis=1)
        plt.imshow(side_by_side)
        plt.show()

    def __call__(self, sample):
        """Add (or refine) ``sample['mask']`` and return ``sample``.

        ``sample['image']`` is segmented; pixels predicted as one of the
        ignored classes are ``False`` in the mask, all others ``True``. An
        existing ``sample['mask']`` is AND-ed with the new mask.
        NOTE(review): assumes ``sample['image']`` is laid out (C, H, W) --
        confirm against the caller.
        """
        image = sample['image']
        # remember the incoming spatial size so the prediction can be
        # resized back to it
        old_shape = (image.shape[1], image.shape[2])

        resizer = ResizeToAlmostResolution(180, 224)
        image = resizer({'image': image})['image']
        output_size = image.shape[1:]

        # Run the segmentation on a batch of one, at the resized resolution.
        with torch.no_grad():
            scores = self.segmentation_module({'img_data': image[None]},
                                              segSize=output_size)

        # Index of the highest-scoring class for each pixel
        pred = torch.max(scores, dim=1)[1]
        pred = TF.resize(pred, old_shape, Image.NEAREST)

        # other irrelevant classes: 1, 4, 12, 20, 25, 83, 116, 126, 127.
        # see csv in semseg folder.
        bad_classes = torch.Tensor([2]).to(device=defs.get_dev())
        # True wherever the predicted class is not one of bad_classes
        mask = ~(pred[..., None] == bad_classes).any(-1)

        sample['mask'] = sample['mask'] & mask if 'mask' in sample else mask
        return sample
示例#2
0
def inference_prob(
    img,
    device,
    select_model_option="ade20k-resnet50dilated-ppm_deepsup"
):  # select_model_option = "ade20k-mobilenetv2dilated-c1_deepsup" / "ade20k-hrnetv2"
    '''Load the config, build the segmentation network and a test loader.

    The global ``cfg_ss`` is updated in place (config file merged, lower-cased
    architecture names, checkpoint paths, ``list_test``).

    Input:
        img - the path of our target image
        device - Current device running; the module is moved to it
        select_model_option - name of NN we use; selects
            ss/config/<select_model_option>.yaml
    Output:
        (segmentation_module, loader_test) - the network on ``device`` and a
        DataLoader over the single test image
    Raises:
        AssertionError - if the encoder or decoder checkpoint is missing
    '''
    cfg_ss.merge_from_file("ss/config/" + select_model_option + ".yaml")

    # The returned logger was never used, so it is no longer bound to a name.
    # NOTE(review): presumably the call itself configures logging -- confirm.
    setup_logger(distributed_rank=0)  # TODO

    cfg_ss.MODEL.arch_encoder = cfg_ss.MODEL.arch_encoder.lower()
    cfg_ss.MODEL.arch_decoder = cfg_ss.MODEL.arch_decoder.lower()

    # paths of model weights, relative to the 'ss/' directory
    cfg_ss.MODEL.weights_encoder = os.path.join(
        'ss/' + cfg_ss.DIR, 'encoder_' + cfg_ss.TEST.checkpoint)
    cfg_ss.MODEL.weights_decoder = os.path.join(
        'ss/' + cfg_ss.DIR, 'decoder_' + cfg_ss.TEST.checkpoint)

    # Kept as an assert so callers still see AssertionError on a missing file.
    assert os.path.exists(cfg_ss.MODEL.weights_encoder) and os.path.exists(
        cfg_ss.MODEL.weights_decoder), "checkpoint does not exist!"

    # testing image list: exactly one entry, the image passed in
    cfg_ss.list_test = [{'fpath_img': img}]

    # exist_ok avoids the check-then-create race of isdir() + makedirs()
    os.makedirs(cfg_ss.TEST.result, exist_ok=True)

    if torch.cuda.is_available():
        # torch.cuda.set_device rejects CPU devices, so skip it when the
        # caller explicitly asked for the CPU on a CUDA-capable machine.
        targets_cpu = (isinstance(device, (str, torch.device))
                       and torch.device(device).type == 'cpu')
        if not targets_cpu:
            torch.cuda.set_device(device)

    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg_ss.MODEL.arch_encoder,
        fc_dim=cfg_ss.MODEL.fc_dim,
        weights=cfg_ss.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg_ss.MODEL.arch_decoder,
        fc_dim=cfg_ss.MODEL.fc_dim,
        num_class=cfg_ss.DATASET.num_class,
        weights=cfg_ss.MODEL.weights_decoder,
        use_softmax=True)

    crit = nn.NLLLoss(ignore_index=-1)

    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

    # Dataset and Loader
    dataset_test = TestDataset(cfg_ss.list_test, cfg_ss.DATASET)
    loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=cfg_ss.TEST.batch_size,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=5,
        drop_last=True)

    segmentation_module.to(device)

    # The caller runs the actual inference loop with these two objects.
    return segmentation_module, loader_test
def main(cfg, gpus):
    """Build the segmentation network from ``cfg`` and run the training loop.

    ``cfg`` supplies the MODEL, DATASET and TRAIN settings read below; ``gpus``
    is the sequence of GPU ids, whose length sets the loader batch size and
    decides whether the module is wrapped for multi-GPU use. A checkpoint is
    written after every epoch.
    """
    # Network Builders
    encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder,
    )
    decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
    )
    criterion = nn.NLLLoss(ignore_index=-1)

    # Decoders whose name ends in 'deepsup' get the extra deep-supervision
    # scale as a fourth positional argument.
    module_args = [encoder, decoder, criterion]
    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        module_args.append(cfg.TRAIN.deep_sup_scale)
    segmentation_module = SegmentationModule(*module_args)

    # Dataset and Loader
    train_set = TrainDataset(
        cfg.DATASET.root_dataset,
        cfg.DATASET.list_train,
        cfg.DATASET,
        batch_per_gpu=cfg.TRAIN.batch_size_per_gpu,
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True,
    )
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # one iterator over the loader, shared by every epoch
    batches = iter(train_loader)

    # load nets into gpu
    multi_gpu = len(gpus) > 1
    if multi_gpu:
        segmentation_module = UserScatteredDataParallel(
            segmentation_module, device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)

    # NOTE(review): `cuda` is not defined in this function; presumably a
    # module-level device object -- confirm.
    segmentation_module.to(device=cuda)

    # Set up optimizers
    nets = (encoder, decoder, criterion)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        # train() and checkpoint() are handed epoch + 1
        epoch_number = epoch + 1
        train(segmentation_module, batches, optimizers, history,
              epoch_number, cfg)

        # checkpointing
        checkpoint(nets, history, cfg, epoch_number)

    print('Training Done!')