Example #1
def eval(opt):
    Utils.set_seeds(opt)
    device = Utils.get_device(opt.cuda)
    set_paths(opt)

    # DATASET
    dataset = get_aligned_dataset(opt, "val")
    input_dataset = CropDataset(dataset, lambda x: x[0:dataset.A_nc, :, :])

    # GENERATOR
    G = Unet(opt, opt.generator_channels, dataset.A_nc, dataset.B_nc).to(device)
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
    G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model)))
    G.eval()

    # EVALUATE: Generate some images using test set and noise as conditional input
    G_input_data = DataLoader(GeneratorInputDataset(input_dataset, G_noise), num_workers=int(opt.workers),
                              batch_size=opt.batchSize, shuffle=False)
    G_inputs = InfiniteDataSampler(G_input_data)

    generate_images(G, G_inputs, opt.gen_path, 100, device,
                    lambda x: Utils.create_image_pair(x, dataset.A_nc, dataset.B_nc))

    # EVALUATE for Cityscapes
    if opt.dataset == "cityscapes":
        writer = SummaryWriter(opt.log_path)
        val_input_data = DataLoader(dataset, num_workers=int(opt.workers),
                                    batch_size=opt.batchSize)

        pixel_error = get_pixel_acc(opt, device, G, val_input_data, G_noise)
        print("VALIDATION PERFORMANCE Pixel: " + str(pixel_error))
        writer.add_scalar("val_pix", pixel_error)

        L2_error = get_L2(opt, device, G, val_input_data, G_noise)
        print("VALIDATION PERFORMANCE L2: " + str(L2_error))
        writer.add_scalar("val_L2", L2_error)
Example #2
def train(opt):
    Utils.set_seeds(opt)
    device = Utils.get_device(opt.cuda)
    set_paths(opt)

    # DATA
    dataset = get_aligned_dataset(opt, "train")
    nc = dataset.A_nc + dataset.B_nc

    # Warn if the desired number of joint samples exceeds the dataset size; in that case, use the whole dataset as paired
    if opt.num_joint_samples > len(dataset):
        print("WARNING: Cannot train with " + str(opt.num_joint_samples) +
              " samples, dataset has only size of " + str(len(dataset)) +
              ". Using full dataset!")
        opt.num_joint_samples = len(dataset)

    # Joint samples
    dataset_train = Subset(dataset, range(opt.num_joint_samples))
    train_joint = InfiniteDataSampler(
        DataLoader(dataset_train,
                   num_workers=int(opt.workers),
                   batch_size=opt.batchSize,
                   shuffle=True,
                   drop_last=True))

    if opt.factorGAN == 1:
        # For marginals, take full dataset and crop
        train_a = InfiniteDataSampler(
            DataLoader(CropDataset(dataset, lambda x: x[0:dataset.A_nc, :, :]),
                       num_workers=int(opt.workers),
                       batch_size=opt.batchSize,
                       shuffle=True))
        train_b = InfiniteDataSampler(
            DataLoader(CropDataset(dataset, lambda x: x[dataset.A_nc:, :, :]),
                       num_workers=int(opt.workers),
                       batch_size=opt.batchSize,
                       shuffle=True))

    # SETUP GENERATOR MODEL
    G = ConvGenerator(opt, opt.generator_channels, opt.loadSize, nc).to(device)
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz),
                                                  torch.Tensor([1] * opt.nz))
    G_opt = Utils.create_optim(G.parameters(), opt)

    # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
    G_input_data = DataLoader(GeneratorInputDataset(None, G_noise),
                              num_workers=int(opt.workers),
                              batch_size=opt.batchSize,
                              shuffle=True)
    G_inputs = InfiniteDataSampler(G_input_data)
    G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G,
                                     device)

    # SETUP DISCRIMINATOR(S)
    if opt.factorGAN == 1:
        # Setup disc networks
        D1 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.A_nc,
                               opt.disc_channels).to(device)
        D2 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.B_nc,
                               opt.disc_channels).to(device)
        # If our dependency discriminators are only defined on classifier probabilities, integrate classification into the discriminator network as a first step
        if opt.use_real_dep_disc == 1:
            DP = ConvDiscriminator(
                opt.loadSize,
                opt.loadSize,
                nc,
                opt.disc_channels,
                spectral_norm=(opt.lipschitz_p == 1)).to(device)
        else:
            DP = lambda x: 0

        DQ = ConvDiscriminator(opt.loadSize, opt.loadSize, nc,
                               opt.disc_channels).to(device)
        # Report the parameter count of one marginal discriminator
        print(sum(p.numel() for p in D1.parameters()))

        # Prepare discriminators for training method
        # Marginal discriminators
        D1_setup = DiscriminatorSetup(
            "D1",
            D1,
            Utils.create_optim(D1.parameters(), opt),
            train_a,
            G_outputs,
            crop_fake=lambda x: x[:, 0:dataset.A_nc, :, :])
        D2_setup = DiscriminatorSetup(
            "D2",
            D2,
            Utils.create_optim(D2.parameters(), opt),
            train_b,
            G_outputs,
            crop_fake=lambda x: x[:, dataset.A_nc:, :, :])
        D_setups = [D1_setup, D2_setup]

        # Dependency discriminators
        shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(
            x, [dataset.A_nc])
        if opt.use_real_dep_disc:
            DP_setup = DependencyDiscriminatorSetup(
                "DP", DP, Utils.create_optim(DP.parameters(), opt),
                train_joint, shuffle_batch_func)
        else:
            DP_setup = None

        DQ_setup = DependencyDiscriminatorSetup(
            "DQ", DQ, Utils.create_optim(DQ.parameters(), opt), G_outputs,
            shuffle_batch_func)
        D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)]
    else:
        D = ConvDiscriminator(opt.loadSize, opt.loadSize, nc,
                              opt.disc_channels).to(device)
        # Report the discriminator parameter count
        print(sum(p.numel() for p in D.parameters()))
        D_setups = [
            DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt),
                               train_joint, G_outputs)
        ]
        D_dep_setups = []

    # RUN TRAINING
    training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups,
                                       D_dep_setups, device, opt.log_path)
    torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G"))
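Helper classes such as InfiniteDataSampler and TransformDataSampler are imported from the project and not defined in these snippets. As a minimal sketch of the infinite-sampling idea only, assuming the sampler simply restarts the finite DataLoader whenever it is exhausted (the project's actual implementation may differ):

class InfiniteSamplerSketch:
    """Illustrative only: cycles a finite DataLoader forever."""

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.iterator = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # Restart from the beginning of the underlying DataLoader
            self.iterator = iter(self.data_loader)
            return next(self.iterator)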
Example #3
def eval(opt):
    device = Utils.get_device(opt.cuda)
    set_paths(opt)

    # Get test dataset
    dataset = get_aligned_dataset(opt, "val")
    nc = dataset.A_nc + dataset.B_nc

    # SETUP GENERATOR MODEL
    G = ConvGenerator(opt, opt.generator_channels, opt.loadSize, nc).to(device)
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz),
                                                  torch.Tensor([1] * opt.nz))

    # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
    G_input_data = DataLoader(GeneratorInputDataset(None, G_noise),
                              num_workers=int(opt.workers),
                              batch_size=opt.batchSize,
                              shuffle=True)
    G_inputs = InfiniteDataSampler(G_input_data)
    G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G,
                                     device)
    G.load_state_dict(
        torch.load(os.path.join(opt.experiment_path, opt.eval_model)))
    G.eval()

    # EVALUATE
    # GENERATE EXAMPLES
    generate_images(
        G, G_inputs, opt.gen_path, 1000, device,
        lambda x: Utils.create_image_pair(x, dataset.A_nc, dataset.B_nc))

    # COMPUTE LS DISTANCE
    # Partition the validation data into a test-train and a test-test split
    test_train_samples = int(0.8 * float(len(dataset)))
    test_test_samples = len(dataset) - test_train_samples
    print("VALIDATION SAMPLES: " + str(test_train_samples))
    print("TEST SAMPLES: " + str(test_test_samples))
    real_test_train_loader = DataLoader(Subset(dataset,
                                               range(test_train_samples)),
                                        num_workers=int(opt.workers),
                                        batch_size=opt.batchSize,
                                        shuffle=True,
                                        drop_last=True)
    real_test_test_loader = DataLoader(Subset(
        dataset, range(test_train_samples, len(dataset))),
                                       num_workers=int(opt.workers),
                                       batch_size=opt.batchSize)

    # Initialise classifier
    classifier_factory = lambda: ConvDiscriminator(
        opt.loadSize, opt.loadSize, nc,
        filters=opt.ls_channels, spectral_norm=False).to(device)
    # Compute metric
    losses = LS.compute_ls_metric(classifier_factory, real_test_train_loader,
                                  real_test_test_loader, G_outputs,
                                  opt.ls_runs, device)

    # WRITE RESULTS INTO CSV FOR LATER ANALYSIS
    file_existed = os.path.exists(os.path.join(opt.experiment_path, "LS.csv"))
    with open(os.path.join(opt.experiment_path, "LS.csv"), "a") as csv_file:
        writer = csv.writer(csv_file)
        model = "factorGAN" if opt.factorGAN else "gan"
        if not file_existed:
            writer.writerow([
                "LS", "Model", "Samples", "Dataset", "Samples_Validation",
                "Samples_Test"
            ])
        for val in losses:
            writer.writerow([
                val, model, opt.num_joint_samples, opt.dataset,
                test_train_samples, test_test_samples
            ])
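Since the CSV written above has fixed column names ("LS", "Model", "Samples", ...), the results can be aggregated afterwards with standard tooling. A small, illustrative aggregation (not part of the project; the function name and file path are assumptions):

import csv
from collections import defaultdict

def summarise_ls(path):
    # Average the LS values per (Model, Samples) combination.
    sums, counts = defaultdict(float), defaultdict(int)
    with open(path) as csv_file:
        for row in csv.DictReader(csv_file):
            key = (row["Model"], row["Samples"])
            sums[key] += float(row["LS"])
            counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}

# e.g. summarise_ls("experiments/cityscapes_run/LS.csv")  # hypothetical path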
Example #4
def train(opt):
    print("Using " + str(opt.num_joint_samples) + " joint samples!")
    Utils.set_seeds(opt)
    device = Utils.get_device(opt.cuda)

    # DATA
    MNIST_dim = 784
    dataset = datasets.MNIST('datasets', train=True, download=True)

    # Create partitions of stacked MNIST
    dataset_joint = DoubleMNISTDataset(
        dataset,
        range(opt.num_joint_samples),
        same_digit_prob=opt.mnist_same_digit_prob)
    train_joint = InfiniteDataSampler(
        DataLoader(dataset_joint,
                   num_workers=int(opt.workers),
                   batch_size=opt.batchSize,
                   shuffle=True))
    if opt.factorGAN == 1:
        # For marginals, take full dataset and crop it
        full_dataset = DoubleMNISTDataset(
            dataset, None, same_digit_prob=opt.mnist_same_digit_prob)
        train_x1 = InfiniteDataSampler(
            DataLoader(CropDataset(full_dataset, lambda x: x[:MNIST_dim]),
                       num_workers=int(opt.workers),
                       batch_size=opt.batchSize,
                       shuffle=True))
        train_x2 = InfiniteDataSampler(
            DataLoader(CropDataset(full_dataset, lambda x: x[MNIST_dim:]),
                       num_workers=int(opt.workers),
                       batch_size=opt.batchSize,
                       shuffle=True))

    # SETUP GENERATOR MODEL
    G = FCGenerator(opt, 2 * MNIST_dim).to(device)
    G.train()
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz),
                                                  torch.Tensor([1] * opt.nz))
    G_opt = Utils.create_optim(G.parameters(), opt)

    # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
    G_input_data = DataLoader(GeneratorInputDataset(None, G_noise),
                              num_workers=int(opt.workers),
                              batch_size=opt.batchSize,
                              shuffle=True)
    G_inputs = InfiniteDataSampler(G_input_data)
    G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G,
                                     device)

    # SETUP DISCRIMINATOR(S)
    if opt.factorGAN == 1:
        # Setup disc networks
        D1 = FCDiscriminator(MNIST_dim).to(device)
        D2 = FCDiscriminator(MNIST_dim).to(device)
        # If our dependency discriminators are only defined on classifier probabilities, integrate classification into the discriminator network as a first step
        if opt.use_real_dep_disc == 1:
            DP = FCDiscriminator(
                2 * MNIST_dim, spectral_norm=(opt.lipschitz_p == 1)).to(device)
        else:
            DP = lambda x: 0

        DQ = FCDiscriminator(2 * MNIST_dim).to(device)

        # Prepare discriminators for training method
        # Marginal discriminators
        D1_setup = DiscriminatorSetup("D1",
                                      D1,
                                      Utils.create_optim(D1.parameters(), opt),
                                      train_x1,
                                      G_outputs,
                                      crop_fake=lambda x: x[:, :MNIST_dim])
        D2_setup = DiscriminatorSetup("D2",
                                      D2,
                                      Utils.create_optim(D2.parameters(), opt),
                                      train_x2,
                                      G_outputs,
                                      crop_fake=lambda x: x[:, MNIST_dim:])
        D_setups = [D1_setup, D2_setup]

        # Dependency discriminators
        shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(
            x, marginal_index=MNIST_dim)

        if opt.use_real_dep_disc:
            DP_setup = DependencyDiscriminatorSetup(
                "DP", DP, Utils.create_optim(DP.parameters(), opt),
                train_joint, shuffle_batch_func)
        else:
            DP_setup = None
        DQ_setup = DependencyDiscriminatorSetup(
            "DQ", DQ, Utils.create_optim(DQ.parameters(), opt), G_outputs,
            shuffle_batch_func)
        D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)]
    else:
        D = FCDiscriminator(2 * MNIST_dim).to(device)
        D_setups = [
            DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt),
                               train_joint, G_outputs)
        ]
        D_dep_setups = []

    # RUN TRAINING
    training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups,
                                       D_dep_setups, device, opt.log_path)
    torch.save(G.state_dict(), os.path.join(opt.out_path, "G"))
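DoubleMNISTDataset itself is not shown in these snippets. From the crops at x[:MNIST_dim] and x[MNIST_dim:] above, and the view(-1, 1, 56, 28) reshape used during evaluation, each sample is two flattened 28x28 digits stacked vertically into one 1568-dimensional vector. A toy illustration of that pairing (the real dataset class may differ in detail):

import torch

def stack_digits(img_upper, img_lower):
    # img_upper, img_lower: float tensors of shape (28, 28) with values in [0, 1]
    # Returns a vector of length 2 * 784 = 1568, upper digit first,
    # matching the MNIST_dim-based cropping used above.
    return torch.cat([img_upper.reshape(-1), img_lower.reshape(-1)])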
Example #5
def eval(opt):
    print("EVALUATING MNIST MODEL...")
    MNIST_dim = 784
    device = Utils.get_device(opt.cuda)

    # Train and save a digit classification model, needed for factorGAN variants and evaluation
    classifier = MNIST.main(opt)
    classifier.to(device)
    classifier.eval()

    # SETUP GENERATOR MODEL
    G = FCGenerator(opt, 2 * MNIST_dim).to(device)
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz),
                                                  torch.Tensor([1] * opt.nz))
    # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
    G_input_data = DataLoader(GeneratorInputDataset(None, G_noise),
                              num_workers=int(opt.workers),
                              batch_size=opt.batchSize,
                              shuffle=True)
    G_inputs = InfiniteDataSampler(G_input_data)

    G.load_state_dict(
        torch.load(os.path.join(opt.experiment_path, opt.eval_model)))
    G.eval()

    # EVALUATE: Save images to eyeball them + FID for marginals + Class probability correlations
    writer = SummaryWriter(opt.log_path)

    test_mnist = datasets.MNIST('datasets', train=False, download=True)
    test_dataset = DoubleMNISTDataset(
        test_mnist, None, same_digit_prob=opt.mnist_same_digit_prob)
    test_dataset_loader = DataLoader(test_dataset,
                                     num_workers=int(opt.workers),
                                     batch_size=opt.batchSize,
                                     shuffle=True)
    transform_func = lambda x: x.view(-1, 1, 56, 28)
    Visualisation.generate_images(G, G_inputs, opt.gen_path, len(test_dataset),
                                  device, transform_func)

    crop_upper = lambda x: x[:, :, :28, :]
    crop_lower = lambda x: x[:, :, 28:, :]
    fid_upper = FID.evaluate_MNIST(opt,
                                   classifier,
                                   test_dataset_loader,
                                   opt.gen_path,
                                   device,
                                   crop_real=crop_upper,
                                   crop_fake=crop_upper)
    fid_lower = FID.evaluate_MNIST(opt,
                                   classifier,
                                   test_dataset_loader,
                                   opt.gen_path,
                                   device,
                                   crop_real=crop_lower,
                                   crop_fake=crop_lower)
    print("FID Upper Digit: " + str(fid_upper))
    print("FID Lower Digit: " + str(fid_lower))
    writer.add_scalar("FID_lower", fid_lower)
    writer.add_scalar("FID_upper", fid_upper)

    # ESTIMATE QUALITY OF DEPENDENCY MODELLING
    # If dependencies are modelled perfectly, cp(...) = cq(...) for all inputs on the test set, so we report the average absolute difference |cp - cq| as a dependency metric
    # Get joint distribution of real class indices in the data
    test_dataset = DoubleMNISTDataset(
        test_mnist,
        None,
        same_digit_prob=opt.mnist_same_digit_prob,
        deterministic=True,
        return_labels=True)
    test_it = DataLoader(test_dataset)
    real_class_probs = np.zeros((10, 10))
    for sample in test_it:
        _, d1, d2 = sample
        real_class_probs[d1, d2] += 1
    real_class_probs /= np.sum(real_class_probs)

    # Compute marginal distribution of real class indices from joint one
    real_class_probs_upper = np.sum(real_class_probs, axis=1)  # a
    real_class_probs_lower = np.sum(real_class_probs, axis=0)  # b
    real_class_probs_marginal = real_class_probs_upper * np.reshape(
        real_class_probs_lower, [-1, 1])

    # Get joint distribution of class indices on generated data (using classifier predictions)
    fake_class_probs = get_class_prob_matrix(G, G_inputs, classifier,
                                             len(test_dataset), device)
    # Compute marginal distribution of class indices on generated data
    fake_class_probs_upper = np.sum(fake_class_probs, axis=1)
    fake_class_probs_lower = np.sum(fake_class_probs, axis=0)
    fake_class_probs_marginal = fake_class_probs_upper * np.reshape(
        fake_class_probs_lower, [-1, 1])

    # Compute average of |cp(...) - cq(...)|
    cp = np.divide(real_class_probs, real_class_probs_marginal + 0.001)
    cq = np.divide(fake_class_probs, fake_class_probs_marginal + 0.001)

    diff_metric = np.mean(np.abs(cp - cq))

    print("Dependency cp/cq diff metric: " + str(diff_metric))
    writer.add_scalar("Diff-Dep", diff_metric)

    return fid_upper, fid_lower
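The cp/cq metric above compares a joint class distribution against the product of its marginals, once for real and once for generated pairs. A toy example with made-up numbers that follows the same operations as the code:

import numpy as np

def dependency_ratio(joint):
    # joint: (10, 10) matrix of class-pair probabilities that sums to 1
    upper = np.sum(joint, axis=1)
    lower = np.sum(joint, axis=0)
    marginal = upper * np.reshape(lower, [-1, 1])  # same outer product as above
    return np.divide(joint, marginal + 0.001)

real = np.full((10, 10), 0.01)   # digits drawn independently of each other
fake = np.eye(10) / 10.0         # generator always pairs identical digits
print(np.mean(np.abs(dependency_ratio(real) - dependency_ratio(fake))))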
Example #6
def train(opt):
    Utils.set_seeds(opt)
    device = Utils.get_device(opt.cuda)
    set_paths(opt)

    if opt.num_joint_songs > 100:
        print("ERROR: Cannot train with " + str(opt.num_joint_songs) +
              " samples, dataset has only size of 100")
        return

    # Partition into paired and unpaired songs
    idx = list(range(100))
    random.shuffle(idx)

    # Joint samples
    dataset_train = MUSDBDataset(opt, idx[:opt.num_joint_songs], "paired")
    train_joint = InfiniteDataSampler(
        DataLoader(dataset_train,
                   num_workers=int(opt.workers),
                   batch_size=opt.batchSize,
                   shuffle=True,
                   drop_last=True))

    if opt.factorGAN == 1:
        # For marginals, take full dataset
        mix_dataset = MUSDBDataset(opt, idx, "mix")

        acc_dataset = MUSDBDataset(opt, idx, "accompaniment")
        acc_loader = InfiniteDataSampler(
            DataLoader(acc_dataset,
                       num_workers=int(opt.workers),
                       batch_size=opt.batchSize,
                       shuffle=True,
                       drop_last=True))

        vocal_dataset = MUSDBDataset(opt, idx, "vocals")
        vocal_loader = InfiniteDataSampler(
            DataLoader(vocal_dataset,
                       num_workers=int(opt.workers),
                       batch_size=opt.batchSize,
                       shuffle=True,
                       drop_last=True))
    else:  # For the normal GAN, take only the few joint songs
        mix_dataset = MUSDBDataset(opt, idx[:opt.num_joint_songs], "mix")

    # SETUP GENERATOR MODEL
    # 1 input channel (mixture), 1 output channel (mask)
    G = Unet(opt, opt.generator_channels, 1, 1).to(device)
    G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz),
                                                  torch.Tensor([1] * opt.nz))
    G_opt = Utils.create_optim(G.parameters(), opt)

    # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
    G_input_data = DataLoader(GeneratorInputDataset(mix_dataset, G_noise),
                              num_workers=int(opt.workers),
                              batch_size=opt.batchSize,
                              shuffle=True,
                              drop_last=True)
    G_inputs = InfiniteDataSampler(G_input_data)
    G_filled_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data),
                                            G, device)

    # SETUP DISCRIMINATOR(S)
    # Only keep the sources, not the mixture, for the dependency discriminators
    crop_mix = lambda x: x[:, 1:, :, :]
    if opt.factorGAN == 1:
        # Setup marginal disc networks
        D_voc = ConvDiscriminator(opt.input_height, opt.input_width, 1,
                                  opt.disc_channels).to(device)
        D_acc = ConvDiscriminator(opt.input_height, opt.input_width, 1,
                                  opt.disc_channels).to(device)

        D_acc_setup = DiscriminatorSetup("D_acc",
                                         D_acc,
                                         Utils.create_optim(
                                             D_acc.parameters(), opt),
                                         acc_loader,
                                         G_filled_outputs,
                                         crop_fake=lambda x: x[:, 1:2, :, :])

        D_voc_setup = DiscriminatorSetup("D_voc",
                                         D_voc,
                                         Utils.create_optim(
                                             D_voc.parameters(), opt),
                                         vocal_loader,
                                         G_filled_outputs,
                                         crop_fake=lambda x: x[:, 2:3, :, :])
        # Marginal discriminator
        D_setups = [D_acc_setup, D_voc_setup]

        # If our dependency discriminators are only defined on classifier probabilities, integrate classification into the discriminator network as a first step
        if opt.use_real_dep_disc == 1:
            DP = ConvDiscriminator(
                opt.input_height,
                opt.input_width,
                2,
                opt.disc_channels,
                spectral_norm=(opt.lipschitz_p == 1)).to(device)
        else:
            DP = lambda x: 0

        DQ = ConvDiscriminator(opt.input_height, opt.input_width, 2,
                               opt.disc_channels).to(device)

        # Dependency discriminators
        # Randomly mixes different sources together
        # (e.g. accompaniment from one song with vocals from another)
        shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, 1)

        if opt.use_real_dep_disc:
            DP_setup = DependencyDiscriminatorSetup("DP",
                                                    DP,
                                                    Utils.create_optim(
                                                        DP.parameters(), opt),
                                                    train_joint,
                                                    shuffle_batch_func,
                                                    crop_func=crop_mix)
        else:
            DP_setup = None

        DQ_setup = DependencyDiscriminatorSetup("DQ",
                                                DQ,
                                                Utils.create_optim(
                                                    DQ.parameters(), opt),
                                                G_filled_outputs,
                                                shuffle_batch_func,
                                                crop_func=crop_mix)
        D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)]
    else:
        D = ConvDiscriminator(opt.input_height, opt.input_width, 2,
                              opt.disc_channels).to(device)

        D_setup = DiscriminatorSetup("D",
                                     D,
                                     Utils.create_optim(D.parameters(), opt),
                                     train_joint,
                                     G_filled_outputs,
                                     crop_real=crop_mix,
                                     crop_fake=crop_mix)
        D_setups = [D_setup]
        D_dep_setups = []

    # RUN TRAINING
    training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups,
                                       D_dep_setups, device, opt.log_path)
    torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G"))
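Utils.shuffle_batch_dims is what turns real or generated joint samples into "independence" samples for the dependency discriminators, but it is not shown in these snippets. A hypothetical sketch of the idea for the two-source case above, permuting one source across the batch so that sources from different songs get paired:

import torch

def shuffle_second_source(batch):
    # batch: tensor of shape (B, 2, H, W) holding [accompaniment, vocals] per item
    perm = torch.randperm(batch.shape[0])
    shuffled = batch.clone()
    shuffled[:, 1] = batch[perm, 1]  # pair each accompaniment with another song's vocals
    return shuffled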