# Root of the portrait dataset; image paths in the list file and the label
# directories are resolved relative to this.
_PORTRAIT_ROOT = '/scratch/mw3706/dim/Deep_Image_Matting_Reproduce/pspnet/data/portrait'


def _write_label_list(label_dir, list_path):
    """Write the full path of every entry in ``label_dir`` to ``list_path``, one per line.

    Entries are sorted because ``os.listdir`` returns names in arbitrary
    order; sorting makes the generated list file reproducible.
    """
    with open(list_path, 'w') as f:
        for label in sorted(os.listdir(label_dir)):
            f.write(os.path.join(label_dir, label) + '\n')


def main(data_root=_PORTRAIT_ROOT):
    """Run segmentation inference on the images listed in ``args.image``.

    Parses CLI arguments into the module-global ``args``, builds the model
    selected by ``args.arch``, loads ``args.model_path``, and calls ``test``
    on every image named in the list file ``args.image``.  Afterwards it
    writes a label list file: ``label/training.txt`` when the input list is
    named ``training.txt``, otherwise ``label/validation.txt``.

    NOTE(review): a later ``def main()`` in this file rebinds the name
    ``main``, so this definition is shadowed at import time.

    Args:
        data_root: directory that image paths in the list file (first
            whitespace-separated field of each line) and the ``label``
            sub-directories are resolved against.

    Raises:
        ValueError: if ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    # Other functions in this module presumably read these as globals.
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Standard ImageNet per-channel mean/std, rescaled from [0, 1] to the
    # 0-255 pixel range (channel order presumably RGB — confirm in `test`).
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, pretrained=False)
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, compact=args.compact,
                       shrink_factor=args.shrink_factor, mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax, pretrained=False)
    else:
        # Without this branch `model` would be unbound and fail later with
        # an unhelpful NameError.
        raise ValueError(
            "Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = False  # was True

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    # strict=False: keys missing from / unexpected in the checkpoint are ignored.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    # Switch to inference mode once rather than on every image.
    model.eval()
    with open(args.image) as f:
        image_files = f.read().splitlines()
    for line in image_files:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines in the list file
        image_path = os.path.join(data_root, fields[0])
        test(model, image_path, args.classes, mean, std, args.base_size,
             args.test_h, args.test_w, args.scales, colors)

    label_root = os.path.join(data_root, 'label')
    if os.path.basename(args.image) == 'training.txt':
        _write_label_list(os.path.join(label_root, 'train_label'),
                          os.path.join(label_root, 'training.txt'))
    else:
        _write_label_list(os.path.join(label_root, 'val_label'),
                          os.path.join(label_root, 'validation.txt'))
def main():
    """Run segmentation inference on every ``scene*/color/*00.jpg`` under ``args.image``.

    Parses CLI arguments into the module-global ``args``, builds the model
    selected by ``args.arch``, loads ``args.model_path``, and calls ``test``
    on each matching image in sorted path order.

    Raises:
        ValueError: if ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    # Other functions in this module presumably read these as globals.
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Standard ImageNet per-channel mean/std, rescaled from [0, 1] to the
    # 0-255 pixel range (channel order presumably RGB — confirm in `test`).
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, pretrained=False)
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, compact=args.compact,
                       shrink_factor=args.shrink_factor, mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax, pretrained=False)
    else:
        # Without this branch `model` would be unbound and fail later with
        # an unhelpful NameError.
        raise ValueError(
            "Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    # strict=False: keys missing from / unexpected in the checkpoint are ignored.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    # glob.glob returns matches in arbitrary order; sort for reproducible runs.
    paths = sorted(glob.glob(
        os.path.join(args.image, 'scene*', 'color', '*00.jpg')))
    if not paths:
        logger.warning("No images matched under '{}'".format(args.image))

    # Switch to inference mode once rather than on every image.
    model.eval()
    for path in paths:
        test(model, path, args.classes, mean, std, args.base_size,
             args.test_h, args.test_w, args.scales, colors)