        opt.use_SN = True
        opt.correlation_renormalize = True
        opt.NL_use_mask = True
        opt.NL_fusion_method = "combine"
        opt.non_local = "Setting_42"
        opt.name = "mapping_scratch"
        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch")


if __name__ == "__main__":
    opt = TestOptions().parse(save=False)
    parameter_set(opt)

    model = Pix2PixHDModel_Mapping()
    model.initialize(opt)
    model.eval()

    # Create the output sub-directories if they do not exist yet.
    if not os.path.exists(opt.outputs_dir + "/" + "input_image"):
        os.makedirs(opt.outputs_dir + "/" + "input_image")
    if not os.path.exists(opt.outputs_dir + "/" + "restored_image"):
        os.makedirs(opt.outputs_dir + "/" + "restored_image")
    if not os.path.exists(opt.outputs_dir + "/" + "origin"):
        os.makedirs(opt.outputs_dir + "/" + "origin")

    input_loader = os.listdir(opt.test_input)
    dataset_size = len(input_loader)
def test(input_opts):
    opt = TestOptions().parse(_input_opts=input_opts, save=False)
    parameter_set(opt)

    model = Pix2PixHDModel_Mapping()
    model.initialize(opt)
    model.eval()

    # Create the output sub-directories if they do not exist yet.
    if not os.path.exists(opt.outputs_dir + "/" + "input_image"):
        os.makedirs(opt.outputs_dir + "/" + "input_image")
    if not os.path.exists(opt.outputs_dir + "/" + "restored_image"):
        os.makedirs(opt.outputs_dir + "/" + "restored_image")
    if not os.path.exists(opt.outputs_dir + "/" + "origin"):
        os.makedirs(opt.outputs_dir + "/" + "origin")

    input_loader = sorted(os.listdir(opt.test_input))
    dataset_size = len(input_loader)

    if opt.test_mask != "":
        mask_loader = sorted(os.listdir(opt.test_mask))
        dataset_size = len(mask_loader)

    img_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    mask_transform = transforms.ToTensor()

    for i in range(dataset_size):
        input_name = input_loader[i]
        input_file = os.path.join(opt.test_input, input_name)
        if not os.path.isfile(input_file):
            print("Skipping non-file %s" % input_name)
            continue

        input = Image.open(input_file).convert("RGB")
        print("Now you are processing %s" % input_name)

        if opt.NL_use_mask:
            # Scratch restoration: synthesize the degraded input from the photo and its scratch mask.
            mask_name = mask_loader[i]
            mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert("RGB")
            origin = input
            input = irregular_hole_synthesize(input, mask)
            mask = mask_transform(mask)
            mask = mask[:1, :, :]  # Convert to a single-channel mask
            mask = mask.unsqueeze(0)
            input = img_transform(input)
            input = input.unsqueeze(0)
        else:
            # Quality restoration only: resize/crop according to the test mode and pass an empty mask.
            if opt.test_mode == "Scale":
                input = data_transforms(input, scale=True)
            if opt.test_mode == "Full":
                input = data_transforms(input, scale=False)
            if opt.test_mode == "Crop":
                input = data_transforms_rgb_old(input)
            origin = input
            input = img_transform(input)
            input = input.unsqueeze(0)
            mask = torch.zeros_like(input)  # Necessary placeholder input

        try:
            generated = model.inference(input, mask)
        except Exception as ex:
            print("Skip %s due to an error:\n%s" % (input_name, str(ex)))
            continue

        # Always write results as PNG, even for JPEG inputs.
        if input_name.endswith(".jpg"):
            input_name = input_name[:-4] + ".png"

        # Images come out of the model in [-1, 1]; map them back to [0, 1] before saving.
        vutils.save_image(
            (input + 1.0) / 2.0,
            opt.outputs_dir + "/input_image/" + input_name,
            nrow=1,
            padding=0,
            normalize=True,
        )
        vutils.save_image(
            (generated.data.cpu() + 1.0) / 2.0,
            opt.outputs_dir + "/restored_image/" + input_name,
            nrow=1,
            padding=0,
            normalize=True,
        )
        origin.save(opt.outputs_dir + "/origin/" + input_name)
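
# Example invocation (a minimal sketch, kept commented out so the module stays import-safe).
# It assumes `input_opts` is a list of argv-style strings that
# TestOptions().parse(_input_opts=...) can consume; that behaviour and the exact flag
# names below are assumptions based on the `opt.*` attributes used above, not confirmed here.
#
# test(
#     [
#         "--test_input", "./test_images/old",   # folder of photos to restore (assumed flag)
#         "--test_mask", "",                      # empty string skips the mask/scratch branch
#         "--outputs_dir", "./output",            # parent of input_image/, restored_image/, origin/
#         "--test_mode", "Full",                  # one of "Scale", "Full", "Crop" handled in the loop
#     ]
# )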