def eval(opt):
    """Evaluate a trained Unet generator on the validation split.

    Loads the generator weights from ``opt.experiment_path/opt.eval_model``,
    writes 100 generated image pairs to ``opt.gen_path`` via
    ``generate_images`` and, when ``opt.dataset == "cityscapes"``, also
    reports the results of ``get_pixel_acc`` and ``get_L2`` on stdout and as
    TensorBoard scalars ("val_pix", "val_L2") under ``opt.log_path``.

    :param opt: experiment options namespace (seeds, device, paths, sizes)
    """
    Utils.set_seeds(opt)
    device = Utils.get_device(opt.cuda)
    set_paths(opt)

    # DATASET
    dataset = get_aligned_dataset(opt, "val")
    # Generator input: only the first A_nc channels of each validation sample
    input_dataset = CropDataset(dataset, lambda x: x[0:dataset.A_nc, :, :])

    # GENERATOR
    G = Unet(opt, opt.generator_channels, dataset.A_nc, dataset.B_nc).to(device)
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
    # map_location lets a checkpoint saved on a GPU be loaded on the current
    # device (e.g. evaluating on a CPU-only machine).
    G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model), map_location=device))
    G.eval()

    # EVALUATE: Generate some images using test set and noise as conditional input
    G_input_data = DataLoader(GeneratorInputDataset(input_dataset, G_noise),
                              num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=False)
    G_inputs = InfiniteDataSampler(G_input_data)
    generate_images(G, G_inputs, opt.gen_path, 100, device,
                    lambda x: Utils.create_image_pair(x, dataset.A_nc, dataset.B_nc))

    # EVALUATE for Cityscapes
    if opt.dataset == "cityscapes":
        writer = SummaryWriter(opt.log_path)
        try:
            val_input_data = DataLoader(dataset, num_workers=int(opt.workers), batch_size=opt.batchSize)

            pixel_acc = get_pixel_acc(opt, device, G, val_input_data, G_noise)
            print("VALIDATION PERFORMANCE Pixel: " + str(pixel_acc))
            writer.add_scalar("val_pix", pixel_acc)

            L2_error = get_L2(opt, device, G, val_input_data, G_noise)
            print("VALIDATION PERFORMANCE L2: " + str(L2_error))
            writer.add_scalar("val_L2", L2_error)
        finally:
            # close() flushes buffered events to the event file; without it
            # the logged scalars may never reach disk.
            writer.close()
def _make_factorgan_discriminator_setups(opt, dataset, nc, device, train_a, train_b, train_joint, G_outputs):
    """Build the FactorGAN discriminators and return ``(D_setups, D_dep_setups)``.

    D1 / D2 are marginal discriminators over the first ``A_nc`` and the
    remaining ``B_nc`` channels respectively. DQ is a dependency
    discriminator over all ``nc`` channels of generator output. DP is its
    real-data counterpart, built only when ``opt.use_real_dep_disc == 1``
    (otherwise the pair receives ``None`` for DP).
    """
    # Evaluate the flag once so DP's construction and its setup always agree.
    use_real_dep_disc = opt.use_real_dep_disc == 1

    # Networks are created in a fixed order (D1, D2, DP, DQ) so that seeded
    # weight initialisation stays reproducible.
    D1 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.A_nc, opt.disc_channels).to(device)
    D2 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.B_nc, opt.disc_channels).to(device)
    DP = None
    if use_real_dep_disc:
        DP = ConvDiscriminator(
            opt.loadSize, opt.loadSize, nc, opt.disc_channels,
            spectral_norm=(opt.lipschitz_p == 1)).to(device)
    DQ = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, opt.disc_channels).to(device)
    print(sum(p.numel() for p in D1.parameters()))

    # Marginal discriminators: real marginals come from train_a / train_b,
    # fake marginals are channel crops of the joint generator output.
    D_setups = [
        DiscriminatorSetup(
            "D1", D1, Utils.create_optim(D1.parameters(), opt), train_a, G_outputs,
            crop_fake=lambda x: x[:, 0:dataset.A_nc, :, :]),
        DiscriminatorSetup(
            "D2", D2, Utils.create_optim(D2.parameters(), opt), train_b, G_outputs,
            crop_fake=lambda x: x[:, dataset.A_nc:, :, :]),
    ]

    # Dependency discriminators. The shuffle function is parameterised by the
    # A/B channel split; presumably it shuffles one channel group across the
    # batch (see Utils.shuffle_batch_dims) — TODO confirm.
    def shuffle_batch_func(x):
        return Utils.shuffle_batch_dims(x, [dataset.A_nc])

    DP_setup = None
    if use_real_dep_disc:
        DP_setup = DependencyDiscriminatorSetup(
            "DP", DP, Utils.create_optim(DP.parameters(), opt), train_joint, shuffle_batch_func)
    DQ_setup = DependencyDiscriminatorSetup(
        "DQ", DQ, Utils.create_optim(DQ.parameters(), opt), G_outputs, shuffle_batch_func)
    D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)]
    return D_setups, D_dep_setups


def _make_joint_discriminator_setups(opt, nc, device, train_joint, G_outputs):
    """Build the single joint discriminator of the plain GAN baseline; return ``(D_setups, [])``."""
    D = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, opt.disc_channels).to(device)
    print(sum(p.numel() for p in D.parameters()))
    D_setups = [
        DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt), train_joint, G_outputs)
    ]
    return D_setups, []


def train(opt):
    """Train a ConvGenerator on the aligned training set and save its weights.

    With ``opt.factorGAN == 1``, FactorGAN is used: marginal discriminators
    on the A and B channel slices (trained on the full dataset) plus a
    dependency discriminator pair. Otherwise a single joint discriminator is
    trained on the paired subset. The first ``opt.num_joint_samples`` examples
    form the paired (joint) training set. The generator's ``state_dict`` is
    saved to ``opt.experiment_path/G``.

    Side effect: ``opt.num_joint_samples`` is clamped in place to the dataset
    size if it is larger.

    :param opt: experiment options namespace
    """
    Utils.set_seeds(opt)
    device = Utils.get_device(opt.cuda)
    set_paths(opt)

    # DATA
    dataset = get_aligned_dataset(opt, "train")
    nc = dataset.A_nc + dataset.B_nc

    # Warning if desired number of joint samples is larger than dataset, in that case, use whole dataset as paired
    if opt.num_joint_samples > len(dataset):
        print("WARNING: Cannot train with " + str(opt.num_joint_samples) + " samples, dataset has only size of " + str(len(dataset)) + ". Using full dataset!")
        opt.num_joint_samples = len(dataset)

    # Joint samples
    dataset_train = Subset(dataset, range(opt.num_joint_samples))
    train_joint = InfiniteDataSampler(
        DataLoader(dataset_train, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True))

    train_a = train_b = None
    if opt.factorGAN == 1:
        # For marginals, take full dataset and crop
        train_a = InfiniteDataSampler(
            DataLoader(CropDataset(dataset, lambda x: x[0:dataset.A_nc, :, :]),
                       num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True))
        train_b = InfiniteDataSampler(
            DataLoader(CropDataset(dataset, lambda x: x[dataset.A_nc:, :, :]),
                       num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True))

    # SETUP GENERATOR MODEL
    G = ConvGenerator(opt, opt.generator_channels, opt.loadSize, nc).to(device)
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
    G_opt = Utils.create_optim(G.parameters(), opt)

    # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
    G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers),
                              batch_size=opt.batchSize, shuffle=True)
    G_inputs = InfiniteDataSampler(G_input_data)
    G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device)

    # SETUP DISCRIMINATOR(S)
    if opt.factorGAN == 1:
        D_setups, D_dep_setups = _make_factorgan_discriminator_setups(
            opt, dataset, nc, device, train_a, train_b, train_joint, G_outputs)
    else:
        D_setups, D_dep_setups = _make_joint_discriminator_setups(opt, nc, device, train_joint, G_outputs)

    # RUN TRAINING
    training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path)
    torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G"))
def eval(opt):
    """Evaluate a trained ConvGenerator with the LS metric and log results to CSV.

    Loads generator weights from ``opt.experiment_path/opt.eval_model`` and
    writes 1000 generated image pairs to ``opt.gen_path``. It then splits the
    validation set 80/20 into the two real-data loaders passed to
    ``LS.compute_ls_metric`` (together with generator output and
    ``opt.ls_runs``). Each value returned is appended as one row to
    ``LS.csv`` in the experiment directory; a header row is written first if
    the file did not exist.

    NOTE(review): unlike other entry points, this function does not call
    ``Utils.set_seeds(opt)``, so generator sampling and the LS computation
    are not seeded here — confirm this is intended.

    :param opt: experiment options namespace
    """
    device = Utils.get_device(opt.cuda)
    set_paths(opt)

    # Get test dataset
    dataset = get_aligned_dataset(opt, "val")
    nc = dataset.A_nc + dataset.B_nc

    # SETUP GENERATOR MODEL
    G = ConvGenerator(opt, opt.generator_channels, opt.loadSize, nc).to(device)
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
    # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
    G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers),
                              batch_size=opt.batchSize, shuffle=True)
    G_inputs = InfiniteDataSampler(G_input_data)
    G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device)

    # map_location lets a checkpoint saved on a GPU be loaded on the current device.
    G.load_state_dict(
        torch.load(os.path.join(opt.experiment_path, opt.eval_model), map_location=device))
    G.eval()

    # EVALUATE
    # GENERATE EXAMPLES
    generate_images(
        G, G_inputs, opt.gen_path, 1000, device,
        lambda x: Utils.create_image_pair(x, dataset.A_nc, dataset.B_nc))

    # COMPUTE LS DISTANCE
    # Partition into test train and test test (first 80% / remaining 20%)
    test_train_samples = int(0.8 * float(len(dataset)))
    test_test_samples = len(dataset) - test_train_samples
    print("VALIDATION SAMPLES: " + str(test_train_samples))
    print("TEST SAMPLES: " + str(test_test_samples))
    real_test_train_loader = DataLoader(Subset(dataset, range(test_train_samples)),
                                        num_workers=int(opt.workers), batch_size=opt.batchSize,
                                        shuffle=True, drop_last=True)
    real_test_test_loader = DataLoader(Subset(dataset, range(test_train_samples, len(dataset))),
                                       num_workers=int(opt.workers), batch_size=opt.batchSize)

    # Factory for the classifier(s) used by the LS computation
    def classifier_factory():
        return ConvDiscriminator(opt.loadSize, opt.loadSize, nc, filters=opt.ls_channels,
                                 spectral_norm=False).to(device)

    # Compute metric
    losses = LS.compute_ls_metric(classifier_factory, real_test_train_loader, real_test_test_loader,
                                  G_outputs, opt.ls_runs, device)

    # WRITE RESULTS INTO CSV FOR LATER ANALYSIS
    csv_path = os.path.join(opt.experiment_path, "LS.csv")
    file_existed = os.path.exists(csv_path)
    model = "factorGAN" if opt.factorGAN else "gan"
    # newline="" is required by the csv module; without it rows get an extra
    # blank line on platforms that translate "\n" (e.g. Windows).
    with open(csv_path, "a", newline="") as csv_file:
        writer = csv.writer(csv_file)
        if not file_existed:
            writer.writerow([
                "LS", "Model", "Samples", "Dataset", "Samples_Validation",
                "Samples_Test"
            ])
        writer.writerows(
            [val, model, opt.num_joint_samples, opt.dataset, test_train_samples, test_test_samples]
            for val in losses)