def predict(image, mask, root_path, AI_directory_path, model_type="life"):
    """Inpaint one image with a pretrained PConvUNet and save the result.

    Args:
        image: Image path/name handed to ``Places2``. The output filename is
            derived from it by stripping its final extension and appending
            ``'result.jpg'``.
        mask: Mask path/name handed to ``Places2``.
        root_path: Root directory handed to ``Places2``.
        AI_directory_path: Checkpoint path passed to ``load_ckpt``.
        model_type: Currently unused; kept so existing callers keep working.

    Returns:
        The path passed to ``evaluate`` as the output image location.
    """
    import os  # local import: the file's import block is not visible here

    # Prefer the GPU, but fall back to CPU instead of crashing on hosts
    # without CUDA (behavior is unchanged when CUDA is available).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    size = (256, 256)
    img_transform = transforms.Compose([
        transforms.Resize(size=size),
        transforms.ToTensor(),
        transforms.Normalize(mean=opt.MEAN, std=opt.STD),
    ])
    # Masks are resized and converted to tensors but, unlike images, not
    # normalized.
    mask_transform = transforms.Compose([
        transforms.Resize(size=size),
        transforms.ToTensor(),
    ])
    # NOTE(review): this Places2 call (root, image, mask, img_tf, mask_tf)
    # differs from the CLI code's call (root, mask_root, img_tf, mask_tf,
    # 'val'); confirm which signature Places2 really has.
    dataset_val = Places2(root_path, image, mask, img_transform, mask_transform)

    model = PConvUNet().to(device)
    load_ckpt(AI_directory_path, [('model', model)])
    model.eval()

    # splitext strips only the final extension, so dots in directory names or
    # stems survive. The previous image.split('.')[0] turned './data/img.png'
    # into '' and always wrote/returned 'result.jpg'.
    result_path = os.path.splitext(image)[0] + 'result.jpg'
    evaluate(model, dataset_val, device, result_path)
    return result_path
def main(argv=None):
    """Command-line entry point: evaluate a PConvUNet checkpoint.

    Builds the 'val' ``Places2`` dataset, loads the checkpoint given by
    ``--snapshot`` and calls ``evaluate``, which writes to ``'result.jpg'``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # evaluation options
    parser.add_argument('--root', type=str, default='./data',
                        help='dataset root passed to Places2')
    # NOTE(review): the empty default is passed straight to load_ckpt;
    # presumably a real checkpoint path must always be supplied.
    parser.add_argument('--snapshot', type=str, default='',
                        help='checkpoint path passed to load_ckpt')
    parser.add_argument('--image_size', type=int, default=256,
                        help='side length of the square resize')
    parser.add_argument('--mask_root', type=str, default='./mask',
                        help='mask root passed to Places2')
    args = parser.parse_args(argv)

    # Prefer the GPU, but fall back to CPU on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    size = (args.image_size, args.image_size)
    img_transform = transforms.Compose([
        transforms.Resize(size=size),
        transforms.ToTensor(),
        transforms.Normalize(mean=opt.MEAN, std=opt.STD),
    ])
    # Masks are resized and converted to tensors but not normalized.
    mask_transform = transforms.Compose([
        transforms.Resize(size=size),
        transforms.ToTensor(),
    ])
    # NOTE(review): argument order differs from the Places2 call in
    # predict(); confirm against the Places2 definition.
    dataset_val = Places2(args.root, args.mask_root, img_transform,
                          mask_transform, 'val')

    model = PConvUNet().to(device)
    load_ckpt(args.snapshot, [('model', model)])
    model.eval()
    evaluate(model, dataset_val, device, 'result.jpg')


# Guard the CLI so importing this module (e.g. to call predict()) does not
# parse the host program's sys.argv or kick off an evaluation run.
if __name__ == '__main__':
    main()