Example #1
def test(args):
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # model
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
    else:
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        # strict=False, so that it is compatible with old pytorch saved models
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    print(model)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    if not args.ms:
        scales = [1.0]
    num_classes = datasets[args.dataset.lower()].NUM_CLASS
    evaluator = MultiEvalModule(model,
                                num_classes,
                                scales=scales,
                                flip=args.ms).cuda()
    evaluator.eval()

    img = input_transform(Image.open(
        args.input_path).convert('RGB')).unsqueeze(0)

    with torch.no_grad():
        output = evaluator.parallel_forward(img)[0]
        predict = torch.max(output, 1)[1].cpu().numpy()
    mask = utils.get_mask_pallete(predict, args.dataset)
    mask.save(args.save_path)
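
All of these examples rely on MultiEvalModule for multi-scale, optionally flipped inference. As a rough orientation, a minimal standalone sketch of that idea follows; it is an assumption about the behaviour, not the library implementation, and it assumes model(x) returns a plain logits tensor (real encoding models may return a tuple):

import torch
import torch.nn.functional as F

@torch.no_grad()
def multiscale_flip_predict(model, image, scales=(0.75, 1.0, 1.25), flip=True):
    # image: (N, 3, H, W) normalized tensor; returns an (N, H, W) label map
    n, c, h, w = image.shape
    fused = None
    for s in scales:
        size = (int(h * s), int(w * s))
        x = F.interpolate(image, size=size, mode='bilinear', align_corners=False)
        logits = model(x)
        if flip:
            # average logits with a horizontally flipped pass
            logits = logits + torch.flip(model(torch.flip(x, dims=[3])), dims=[3])
        up = F.interpolate(logits, size=(h, w), mode='bilinear', align_corners=False)
        fused = up if fused is None else fused + up
    return fused.argmax(dim=1)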
Example #2
def test(args):
    # output folder
    outdir = args.save_folder
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    testset = get_segmentation_dataset(args.dataset,
                                       split=args.split,
                                       mode=args.mode,
                                       transform=input_transform)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)
    # model
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
    else:
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       multi_grid=args.multi_grid,
                                       stride=args.stride,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        # strict=False, so that it is compatible with old pytorch saved models
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    # print(model)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    if not args.ms:
        scales = [1.0]
    evaluator = MultiEvalModule(model,
                                testset.num_class,
                                scales=scales,
                                flip=args.ms).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    for i, (image, dst) in enumerate(tbar):
        if 'val' in args.mode:
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' %
                                     (pixAcc, mIoU))
        else:
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                # predicts = [testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                #             for output in outputs]
                predicts = [
                    torch.softmax(output, 1).cpu().numpy()
                    for output in outputs
                ]
            import numpy as np
            from PIL import Image
            for predict, impath in zip(predicts, dst):
                # mask = utils.get_mask_pallete(predict, args.dataset)
                # save the class-1 softmax probability as an 8-bit grayscale image
                mask = Image.fromarray(
                    (predict[0, 1, :, :] * 255).astype(np.uint8))
                outname = os.path.splitext(impath)[0] + '.bmp'
                mask.save(os.path.join(outdir, outname))
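
Note the custom collate_fn used by every dataloader here: test images differ in size, so they cannot be stacked into a single tensor. A hypothetical collate in the spirit of test_batchify_fn (the real function lives in the encoding codebase) simply keeps the batch as parallel lists:

def list_collate(batch):
    # batch: list of (image_tensor, target_or_path) samples of varying sizes
    images = [sample[0] for sample in batch]
    extras = [sample[1] for sample in batch]
    return images, extras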
Example #3
def test(args):
    # output folder
    outdir = 'outdir'
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    if args.eval:
        testset = get_dataset(args.dataset,
                              split='val',
                              mode='testval',
                              transform=input_transform)
    elif args.test_val:
        testset = get_dataset(args.dataset,
                              split='val',
                              mode='test',
                              transform=input_transform)
    else:
        testset = get_dataset(args.dataset,
                              split='test',
                              mode='test',
                              transform=input_transform)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)
    # model
    pretrained = args.resume is None and args.verify is None
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=pretrained)
        model.base_size = args.base_size
        model.crop_size = args.crop_size
    else:
        model_kwargs = {}
        if args.choice_indices is not None:
            assert 'alone_resnest50' in args.backbone
            model_kwargs['choice_indices'] = args.choice_indices
        model = get_segmentation_model(
            args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=torch.nn.BatchNorm2d if args.acc_bn else SyncBatchNorm,
            base_size=args.base_size,
            crop_size=args.crop_size,
            **model_kwargs)

    # resuming checkpoint
    if args.verify is not None and os.path.isfile(args.verify):
        print("=> loading checkpoint '{}'".format(args.verify))
        model.load_state_dict(torch.load(args.verify, map_location='cpu'))
    elif args.resume is not None and os.path.isfile(args.resume):
        checkpoint = torch.load(args.resume, map_location='cpu')
        # strict=False, so that it is compatible with old pytorch saved models
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}'".format(args.resume))
    elif not pretrained:
        raise RuntimeError("=> no checkpoint found")

    print(model)
    if args.acc_bn:
        from encoding.utils.precise_bn import update_bn_stats
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_dataset(args.dataset,
                               split=args.train_split,
                               mode='train',
                               **data_kwargs)
        trainloader = data.DataLoader(ReturnFirstClosure(trainset),
                                      batch_size=args.batch_size,
                                      drop_last=True,
                                      shuffle=True,
                                      **loader_kwargs)
        print('Resetting BN statistics')
        #model.apply(reset_bn_statistics)
        model.cuda()
        update_bn_stats(model, trainloader)

    if args.export:
        torch.save(model.state_dict(), args.export + '.pth')
        return

    scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    for i, (image, dst) in enumerate(tbar):
        if args.eval:
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' %
                                     (pixAcc, mIoU))
        else:
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                predicts = [
                    testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                    for output in outputs
                ]
            for predict, impath in zip(predicts, dst):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))

    if args.eval:
        print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
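
The --acc_bn path above recomputes BatchNorm running statistics on training data before evaluation ("precise BN"). A minimal sketch of the same idea in plain torch.nn, assuming the loader yields image tensors (the real helper is encoding.utils.precise_bn.update_bn_stats):

import torch

@torch.no_grad()
def recompute_bn_stats(model, loader, num_batches=200):
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None  # None => cumulative moving average
    model.train()              # BN updates running stats only in train mode
    for i, images in enumerate(loader):
        model(images.cuda())
        if i + 1 >= num_batches:
            break
    model.eval()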
Example #4
def test(args):
    # output folder
    outdir = '%s/msdanet_vis' % (args.dataset)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    if args.eval:
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='testval',
                                           transform=input_transform)
    else:  # set split='test' for test set
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='vis',
                                           transform=input_transform)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
    else:
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        # strict=False, so that it is compatible with old pytorch saved models
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    print(model)
    num_class = testset.num_class
    evaluator = MultiEvalModule(model,
                                testset.num_class,
                                multi_scales=args.multi_scales).cuda()
    evaluator.eval()

    tbar = tqdm(test_data)

    def eval_batch(image, dst, evaluator, eval_mode):
        if eval_mode:
            # evaluation mode on validation set
            targets = dst
            outputs = evaluator.parallel_forward(image)

            batch_inter, batch_union, batch_correct, batch_label = 0, 0, 0, 0
            for output, target in zip(outputs, targets):
                correct, labeled = utils.batch_pix_accuracy(
                    output.data.cpu(), target)
                inter, union = utils.batch_intersection_union(
                    output.data.cpu(), target, testset.num_class)
                batch_correct += correct
                batch_label += labeled
                batch_inter += inter
                batch_union += union
            return batch_correct, batch_label, batch_inter, batch_union
        else:
            # Visualize and dump the results
            im_paths = dst
            outputs = evaluator.parallel_forward(image)
            predicts = [
                torch.max(output, 1)[1].cpu().numpy() + testset.pred_offset
                for output in outputs
            ]
            for predict, impath in zip(predicts, im_paths):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))
            # dummy outputs for compatibility with eval mode
            return 0, 0, 0, 0

    total_inter, total_union, total_correct, total_label = \
        np.int64(0), np.int64(0), np.int64(0), np.int64(0)
    for i, (image, dst) in enumerate(tbar):
        if torch_ver == "0.3":
            image = Variable(image, volatile=True)
            correct, labeled, inter, union = eval_batch(
                image, dst, evaluator, args.eval)
        else:
            with torch.no_grad():
                correct, labeled, inter, union = eval_batch(
                    image, dst, evaluator, args.eval)
        pixAcc, mIoU, IoU = 0, 0, 0
        if args.eval:
            total_correct += correct.astype('int64')
            total_label += labeled.astype('int64')
            total_inter += inter.astype('int64')
            total_union += union.astype('int64')
            pixAcc = np.float64(1.0) * total_correct / (
                np.spacing(1, dtype=np.float64) + total_label)
            IoU = np.float64(1.0) * total_inter / (
                np.spacing(1, dtype=np.float64) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
    return pixAcc, mIoU, IoU, num_class
Example #5
def test(args):
    # output folder
    outdir = 'outdir'
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    if args.eval:
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           transform=input_transform,
                                           return_file=True)
    else:
        testset = get_segmentation_dataset(args.dataset,
                                           split='test',
                                           mode='test',
                                           transform=input_transform)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)
    # model
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
    else:
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        # copy matching parameters one by one, so that it is compatible with old pytorch saved models
        pretrained_dict = checkpoint['state_dict']
        model_dict = model.state_dict()

        for name, param in pretrained_dict.items():
            if name not in model_dict:
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            model_dict[name].copy_(param)

        #model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    print(model)

    # count parameter number
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters: %d" % pytorch_total_params)

    evaluator = MultiEvalModule(model, testset.num_class).cuda()
    evaluator.eval()

    tbar = tqdm(test_data)

    def eval_batch(image, dst, im_paths, evaluator, eval_mode):
        if eval_mode:
            # evaluation mode on validation set
            targets = dst
            outputs = evaluator.parallel_forward(image)
            batch_inter, batch_union, batch_correct, batch_label = 0, 0, 0, 0
            for output, target in zip(outputs, targets):
                correct, labeled = utils.batch_pix_accuracy(
                    output.data.cpu(), target)
                inter, union = utils.batch_intersection_union(
                    output.data.cpu(), target, testset.num_class)
                batch_correct += correct
                batch_label += labeled
                batch_inter += inter
                batch_union += union

            # save outputs
            predicts = [
                torch.max(output, 1)[1].cpu().numpy()  # + testset.pred_offset
                for output in outputs
            ]
            for predict, impath, target in zip(predicts, im_paths, targets):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))

                # save ground truth into png format
                target = target.data.cpu().numpy()
                target = utils.get_mask_pallete(target, args.dataset)
                outname = os.path.splitext(impath)[0] + '_gtruth.png'
                target.save(os.path.join(outdir, outname))

            return batch_correct, batch_label, batch_inter, batch_union
        else:
            # test mode, dump the results
            im_paths = dst
            outputs = evaluator.parallel_forward(image)
            predicts = [
                torch.max(output, 1)[1].cpu().numpy()  # + testset.pred_offset
                for output in outputs
            ]
            for predict, impath in zip(predicts, im_paths):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))
            # dummy outputs for compatibility with eval mode
            return 0, 0, 0, 0

    total_inter, total_union, total_correct, total_label = \
        np.int64(0), np.int64(0), np.int64(0), np.int64(0)
    for i, (image, dst, img_paths) in enumerate(tbar):
        if torch_ver == "0.3":
            image = Variable(image, volatile=True)
            correct, labeled, inter, union = eval_batch(
                image, dst, img_paths, evaluator, args.eval)
        else:
            with torch.no_grad():
                correct, labeled, inter, union = eval_batch(
                    image, dst, img_paths, evaluator, args.eval)
        if args.eval:
            total_correct += correct.astype('int64')
            total_label += labeled.astype('int64')
            total_inter += inter.astype('int64')
            total_union += union.astype('int64')
            pixAcc = np.float64(1.0) * total_correct / (
                np.spacing(1, dtype=np.float64) + total_label)
            IoU = np.float64(1.0) * total_inter / (
                np.spacing(1, dtype=np.float64) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
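
Examples #4 and #5 accumulate per-class intersection and union counts batch by batch and derive pixAcc and mIoU from the running totals. A minimal NumPy sketch of that counting, assuming integer label maps in which 0 marks ignored pixels (the real helpers are utils.batch_pix_accuracy and utils.batch_intersection_union):

import numpy as np

def pixel_scores(pred, target, num_class):
    # pred, target: 2-D integer label maps; label 0 is treated as ignore
    valid = target > 0
    correct = np.sum((pred == target) & valid)
    labeled = np.sum(valid)
    p = pred * valid
    inter_map = p * (p == target)  # keeps the label wherever the prediction is right
    area_inter, _ = np.histogram(inter_map, bins=num_class, range=(1, num_class))
    area_pred, _ = np.histogram(p, bins=num_class, range=(1, num_class))
    area_lab, _ = np.histogram(target, bins=num_class, range=(1, num_class))
    area_union = area_pred + area_lab - area_inter
    return correct, labeled, area_inter, area_union

pixAcc is then total correct / total labeled, and mIoU is the mean over classes of total inter / total union.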
Example #6
def test(args):
    # output folder
    outdir = 'outdir'
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    data_kwargs = {'root': args.data_root}
    if args.eval:
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='testval',
                                           transform=input_transform,
                                           **data_kwargs)
    elif args.test_val:
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='test',
                                           transform=input_transform,
                                           **data_kwargs)
    else:
        testset = get_segmentation_dataset(args.dataset,
                                           split='test',
                                           mode='test',
                                           transform=input_transform,
                                           **data_kwargs)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)
    # model
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
        #model.base_size = args.base_size
        #model.crop_size = args.crop_size
    else:
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        # strict=False, so that it is compatible with old pytorch saved models
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    print(model)
    # scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
    #     [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
    scales = [1.0]
    evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    for i, (image, dst) in enumerate(tbar):
        if args.eval:
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' %
                                     (pixAcc, mIoU))
        else:
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                predicts = [
                    testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                    for output in outputs
                ]
            for predict, impath in zip(predicts, dst):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))
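
The predicted masks are written out as palette PNGs via utils.get_mask_pallete. A rough PIL sketch of what that amounts to (the 3-colour palette is made up for illustration):

import numpy as np
from PIL import Image

def save_palette_mask(label_map, path):
    # label_map: 2-D integer array of class indices
    mask = Image.fromarray(label_map.astype(np.uint8))   # mode 'L'
    palette = [0, 0, 0,   128, 0, 0,   0, 128, 0]        # RGB for classes 0, 1, 2
    mask.putpalette(palette + [0] * (768 - len(palette)))
    mask.save(path)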
Example #7
def test(args):
    directory = "runs/val_summary/%s/%s/%s/" % (args.dataset, args.model,
                                                args.resume)
    if not os.path.exists(directory):
        os.makedirs(directory)
    writer = SummaryWriter(directory)
    # output folder
    outdir = 'outdir'
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    if args.eval:
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='testval',
                                           transform=input_transform)
    elif args.test_val:
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='test',
                                           transform=input_transform)
    else:
        testset = get_segmentation_dataset(args.dataset,
                                           split='test',
                                           mode='test',
                                           transform=input_transform)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)

    Norm_method = torch.nn.BatchNorm2d
    # model
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
        #model.base_size = args.base_size
        #model.crop_size = args.crop_size
    else:
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       multi_grid=args.multi_grid,
                                       num_center=args.num_center,
                                       norm_layer=Norm_method,
                                       root=args.backbone_path,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        # strict=False, so that it is compatible with old pytorch saved models
        # checkpoints saved from a wrapped model carry key prefixes that the
        # bare model does not expect, so strip them before loading
        old_state_dict = checkpoint['state_dict']
        new_state_dict = dict()
        for k, v in old_state_dict.items():
            if k.startswith('model.module.'):
                new_state_dict[k[len('model.module.'):]] = v
            elif k.startswith('module.'):
                new_state_dict[k[len('module.'):]] = v
            else:
                new_state_dict[k] = v

        model.load_state_dict(new_state_dict)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    print(model)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.75, 1.0, 1.25, 1.5, 1.75, 2.0]

    if args.dataset == 'ade20k':
        scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
    if not args.ms:
        scales = [1.0]

    if args.dataset == 'ade20k':
        evaluator = MultiEvalModule2(model,
                                     testset.num_class,
                                     scales=scales,
                                     flip=args.ms).cuda()
    else:
        evaluator = MultiEvalModule(model,
                                    testset.num_class,
                                    scales=scales,
                                    flip=args.ms).cuda()

    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    for i, (image, dst) in enumerate(tbar):
        if args.eval:
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' %
                                     (pixAcc, mIoU))
                writer.add_scalar('pixAcc', pixAcc, i)
                writer.add_scalar('mIoU', mIoU, i)
        else:
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                predicts = [
                    testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                    for output in outputs
                ]
            for predict, impath in zip(predicts, dst):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))
    writer.close()
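
The key remapping in Example #7 is a recurring chore: weights saved from a DataParallel-wrapped (or otherwise wrapped) model carry key prefixes that a bare model does not expect. A small standalone helper in the same spirit (the prefix list is an assumption):

def strip_prefixes(state_dict, prefixes=('model.module.', 'module.')):
    out = {}
    for k, v in state_dict.items():
        for p in prefixes:
            if k.startswith(p):
                k = k[len(p):]
                break
        out[k] = v
    return out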
Example #8
def test(args):
    # output folder
    outdir = args.save_folder
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    testset = get_segmentation_dataset(args.dataset,
                                       split=args.split,
                                       mode=args.mode,
                                       transform=input_transform)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)
    # model
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
    else:
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       multi_grid=args.multi_grid,
                                       stride=args.stride,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        # strict=False, so that it is compatible with old pytorch saved models
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    # print(model)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    if not args.ms:
        scales = [1.0]
    evaluator = MultiEvalModule(model,
                                testset.num_class,
                                scales=scales,
                                flip=args.ms).cuda()
    evaluator.eval()
    tbar = tqdm(test_data)
    total_inter, total_union, total_correct, total_label = 0, 0, 0, 0

    result = []
    for i, (image, dst) in enumerate(tbar):
        # print(dst)
        with torch.no_grad():
            if i > 20:
                st = time.time()
            outputs = evaluator.forward(image[0].unsqueeze(0).cuda())

            if i > 20:
                result.append(1 / (time.time() - st))
                print(np.mean(result), np.std(result))

            if 'val' in args.mode:
                # compute image IoU metric
                inter, union, area_pred, area_lab = batch_intersection_union(
                    outputs, dst[0], testset.num_class)
                total_label += area_lab
                total_inter += inter
                total_union += union

                class_pixAcc = 1.0 * inter / (np.spacing(1) + area_lab)
                class_IoU = 1.0 * inter / (np.spacing(1) + union)
                print("img Classes pixAcc:", class_pixAcc)
                print("img Classes IoU:", class_IoU)
            else:
                # save prediction results
                predict = testset.make_pred(
                    torch.max(outputs, 1)[1].cpu().numpy())
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(dst[0])[0] + '.png'
                mask.save(os.path.join(outdir, outname))

    if 'val' in args.mode:
        # compute set IoU metric
        pixAcc = 1.0 * total_inter / (np.spacing(1) + total_label)
        IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
        mIoU = IoU.mean()

        print("set Classes pixAcc:", pixAcc)
        print("set Classes IoU:", IoU)
        print("set mean IoU:", mIoU)
Example #9
def semseg(input_path, output_path=None, with_L0=False):
    """
    param:
        input_path: str, path of input image
        output_path: str, path to save output image
    return: tuple, [animal_name, "background"] if pixels of "background" dominate,
                   ["background", animal_name] else.
    """
    sys.argv = sys.argv[:1]
    option = Options()
    args = option.parse()
    args.aux = True
    args.se_loss = True
    args.resume = "./checkpoints/encnet_jpu_res101_pcontext.pth.tar"  # model checkpoint
    torch.manual_seed(args.seed)

    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])

    # use L0_smooth to transform the original picture
    if with_L0:
        mid_result = os.path.join(os.path.dirname(input_path), "L0_result.png")
        L0_smooth(input_path, mid_result)
        input_path = mid_result

    # model
    model = get_segmentation_model(args.model,
                                   dataset=args.dataset,
                                   backbone=args.backbone,
                                   dilated=args.dilated,
                                   lateral=args.lateral,
                                   jpu=args.jpu,
                                   aux=args.aux,
                                   se_loss=args.se_loss,
                                   norm_layer=BatchNorm,
                                   base_size=args.base_size,
                                   crop_size=args.crop_size)
    # resuming checkpoint
    if args.resume is None or not os.path.isfile(args.resume):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.resume))
    checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
    # strict=False, so that it is compatible with old pytorch saved models
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    print("semseg model loaded successfully!")
    scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    if not args.ms:
        scales = [1.0]
    num_classes = datasets[args.dataset.lower()].NUM_CLASS
    evaluator = MultiEvalModule(model,
                                num_classes,
                                scales=scales,
                                flip=args.ms).cuda()
    evaluator.eval()
    classes = np.array([
        'empty', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
        'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car',
        'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup',
        'curtain', 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass',
        'ground', 'horse', 'keyboard', 'light', 'motorbike', 'mountain',
        'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',
        'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
        'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
        'window', 'wood'
    ])
    animals = ['bird', 'cat', 'cow', 'dog', 'horse', 'mouse', 'sheep']
    img = input_transform(Image.open(input_path).convert('RGB')).unsqueeze(0)

    with torch.no_grad():
        output = evaluator.parallel_forward(img)[0]
        predict = torch.max(output, 1)[1].cpu().numpy() + 1
    pred_idx = np.unique(predict)
    pred_label = classes[pred_idx]
    print("[SemSeg] ", input_path, ": ", pred_label, sep='')

    main_pixels = 0
    main_idx = -1
    for idx, label in zip(pred_idx, pred_label):
        if label in animals:
            pixels = np.sum(predict == idx)
            if pixels > main_pixels:
                main_pixels = pixels
                main_idx = idx
    background_pixels = np.sum(predict != main_idx)

    main_animal = classes[main_idx]
    predict[predict != main_idx] = 29
    mask_matrix = predict.copy()

    if output_path is not None:
        mask_matrix[np.where(mask_matrix != 29)] = 1
        mask_matrix[np.where(mask_matrix == 29)] = 0
        mask = utils.get_mask_pallete(mask_matrix, args.dataset)
        mask.save(output_path)

    if main_idx < 29:
        return predict, (main_animal, "background")
    else:
        return predict, ("background", main_animal)
Example #10
def test(args):
    # output folder
    outdir = args.save_folder
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    testset = get_segmentation_dataset(args.dataset,
                                       split=args.split,
                                       mode=args.mode,
                                       transform=input_transform)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)
    # model
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
    else:
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       multi_grid=args.multi_grid,
                                       stride=args.stride,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        # strict=False, so that it is compatible with old pytorch saved models
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    # print(model)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    if not args.ms:
        scales = [1.0]
    evaluator = MultiEvalModule(model,
                                testset.num_class,
                                scales=scales,
                                flip=args.ms).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    total_inter, total_union, total_correct, total_label, all_label = 0, 0, 0, 0, 0

    # for i, (image, dst) in enumerate(tbar):
    #     # print(dst)
    #     with torch.no_grad():
    #         outputs = evaluator.parallel_forward(image)[0]
    #         correct, labeled = batch_pix_accuracy(outputs, dst[0])
    #         total_correct += correct
    #         all_label += labeled
    #         img_pixAcc = 1.0 * correct / (np.spacing(1) + labeled)

    #         inter, union, area_pred, area_lab = batch_intersection_union(outputs, dst[0], testset.num_class)
    #         total_label += area_lab
    #         total_inter += inter
    #         total_union += union

    #         class_pixAcc = 1.0 * inter / (np.spacing(1) + area_lab)
    #         class_IoU = 1.0 * inter / (np.spacing(1) + union)
    #         class_mIoU = class_IoU.mean()
    #         print("img pixAcc:", img_pixAcc)
    #         print("img Classes pixAcc:", class_pixAcc)
    #         print("img Classes IoU:", class_IoU)
    # total_pixAcc = 1.0 * total_correct / (np.spacing(1) + all_label)
    # pixAcc = 1.0 * total_inter / (np.spacing(1) + total_label)
    # IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
    # mIoU = IoU.mean()

    # print("set pixAcc:", pixAcc)
    # print("set Classes pixAcc:", pixAcc)
    # print("set Classes IoU:", IoU)
    # print("set mean IoU:", mIoU)

    for i, (image, dst) in enumerate(tbar):
        if 'val' in args.mode:
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' %
                                     (pixAcc, mIoU))
        else:
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
Example #11
    if args.acc_bn:
        from encoding.utils.precise_bn import update_bn_stats
        data_kwargs = {'transform': input_transform,
                       'base_size': args.base_size,
                       'crop_size': args.crop_size}
        trainset = get_dataset(args.dataset, split=args.train_split, mode='train', **data_kwargs)
        trainloader = data.DataLoader(ReturnFirstClosure(trainset), batch_size=args.batch_size,
                                      drop_last=True, shuffle=True, **loader_kwargs)
        print('Resetting BN statistics')
        #model.apply(reset_bn_statistics)
        model.cuda()
        update_bn_stats(model, trainloader)

    if args.export:
        torch.save(model.state_dict(), args.export + '.pth')
        return

    scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    for i, (image, dst) in enumerate(tbar):
        if args.eval:
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
        else:
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                predicts = [testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                            for output in outputs]
            for predict, impath in zip(predicts, dst):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))