def _write_label_list(label_dir, list_path):
    """Write the full path of every entry in ``label_dir`` to ``list_path``.

    One path per line. Names are sorted because ``os.listdir`` returns them
    in arbitrary order, and a sorted list is reproducible across runs.
    """
    names = sorted(os.listdir(label_dir))
    with open(list_path, 'w') as f:
        f.writelines(os.path.join(label_dir, name) + '\n' for name in names)


def main():
    """Run inference on a list of portrait images, then write a label list.

    Reads options via ``get_parser()``, builds a PSPNet or PSANet, loads the
    checkpoint at ``args.model_path`` and calls ``test`` on every image
    named in the list file ``args.image``. The image name is the first
    whitespace-separated field of each non-blank line, resolved relative to
    the portrait data root. Afterwards it writes ``label/training.txt`` if
    ``args.image`` is named ``training.txt``; otherwise it writes
    ``label/validation.txt``. The written file lists the label files of the
    corresponding split.

    Raises:
        ValueError: if ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # NOTE(review): dataset root is hard-coded to one user's scratch space.
    data_root = '/scratch/mw3706/dim/Deep_Image_Matting_Reproduce/pspnet/data/portrait/'
    label_root = os.path.join(data_root, 'label')

    # Per-channel normalization statistics, rescaled from [0, 1] to [0, 255].
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, pretrained=False)
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, compact=args.compact,
                       shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax, pretrained=False)
    else:
        # Without this branch `model` would be unbound on the next line.
        raise ValueError("unsupported arch '{}'; expected 'psp' or 'psa'".format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    # NOTE(review): cuDNN autotuning is disabled here; reason not recorded — confirm.
    cudnn.benchmark = False
    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    # torch.load unpickles the file: only point this at trusted checkpoints.
    checkpoint = torch.load(args.model_path)
    # strict=False: keys that do not match the model are silently ignored.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    # Switch to inference mode once, instead of once per image.
    model.eval()

    # Read the list up front so the file is closed before inference starts.
    with open(args.image) as f:
        image_lines = f.read().splitlines()
    for line in image_lines:
        fields = line.split()
        if not fields:  # skip blank lines rather than failing with IndexError
            continue
        image = os.path.join(data_root, fields[0])
        test(model, image, args.classes, mean, std, args.base_size,
             args.test_h, args.test_w, args.scales, colors)

    if os.path.basename(args.image) == 'training.txt':
        _write_label_list(os.path.join(label_root, 'train_label'),
                          os.path.join(label_root, 'training.txt'))
    else:
        _write_label_list(os.path.join(label_root, 'val_label'),
                          os.path.join(label_root, 'validation.txt'))
# Example no. 2
# 0
def main():
    """Run inference on every ``scene*/color/*00.jpg`` frame under ``args.image``.

    Reads options via ``get_parser()``, builds a PSPNet or PSANet, loads the
    checkpoint at ``args.model_path`` and calls ``test`` once per matching
    image. Images are processed in sorted path order.

    Raises:
        ValueError: if ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel normalization statistics, rescaled from [0, 1] to [0, 255].
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       pretrained=False)
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       pretrained=False)
    else:
        # Without this branch `model` would be unbound on the next line.
        raise ValueError(
            "unsupported arch '{}'; expected 'psp' or 'psa'".format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    # torch.load unpickles the file: only point this at trusted checkpoints.
    checkpoint = torch.load(args.model_path)
    # strict=False: keys that do not match the model are silently ignored.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    # Switch to inference mode once, instead of once per image.
    model.eval()

    # glob returns matches in arbitrary order; sort for reproducible runs.
    pattern = os.path.join(args.image, 'scene*', 'color', '*00.jpg')
    paths = sorted(glob.glob(pattern))
    if not paths:
        logger.warning("=> no images matched '{}'".format(pattern))
    for path in paths:
        test(model, path, args.classes, mean, std, args.base_size,
             args.test_h, args.test_w, args.scales, colors)