Example #1
def train_recClass(GAN, recClass, A_data_batch, A_labels_batch, l_recClass,
                   rec_weight):
    X_noise = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size,
                                      GAN.noise_dim)
    X_noise2 = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size,
                                       GAN.noise_dim)
    recClass_loss = recClass.train_on_batch(
        [X_noise, A_data_batch, X_noise2],
        A_labels_batch,
        sample_weight=np.ones(GAN.batch_size) * rec_weight)
    l_recClass.appendleft(recClass_loss)
    return l_recClass
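# Hypothetical usage sketch for train_recClass above. `recClass`,
# `X_source_train`, `Y_source_train` and `data_utils` are assumed to come from
# the surrounding training script; the GAN stand-in and all values below are
# illustrative only.
from collections import deque
from types import SimpleNamespace

import numpy as np

GAN = SimpleNamespace(noise_scale=0.5, batch_size=32, noise_dim=(100, ))
l_recClass = deque(maxlen=100)  # rolling window of the most recent losses

for _ in range(200):  # e.g. one epoch of batches
    A_data_batch, A_labels_batch, _ = next(
        data_utils.gen_batch(X_source_train, Y_source_train, GAN.batch_size))
    l_recClass = train_recClass(GAN, recClass, A_data_batch, A_labels_batch,
                                l_recClass, rec_weight=1.0)

print("mean recClass loss over the window: %.4f" % np.mean(l_recClass))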
Example #2
def train_gen_zclass(generator_model, DCGAN_model, zclass_model, disc_type,
                     deterministic, noise_dim, noise_scale, batch_size, l_gen,
                     l_zclass, X_source, Y_source, n_classes):
    X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
    X_source_batch2, Y_source_batch2, idx_source_batch2 = next(
        data_utils.gen_batch(X_source, Y_source, batch_size))
    if disc_type == "simple_disc":
        gen_loss = DCGAN_model.train_on_batch(
            [X_gen, X_source_batch2],
            np.ones(X_gen.shape[0]))  # TRYING SAME BATCH OF DISC
    elif disc_type == "nclass_disc":

        # The commented lines below would derive one-hot "virtual labels"
        # from the discriminator's class predictions:
        #(disc_p, class_p) = DCGAN_model.predict_on_batch(X_source_batch2)
        #idx = np.argmax(class_p, axis=1)
        #virtual_labels = (idx[:, None] == np.arange(n_classes)) * 1

        # For now, use all-zero class targets. Note: this function has no GAN
        # object, so the batch_size and n_classes arguments are used directly.
        virtual_labels = np.zeros([batch_size, n_classes])
        gen_loss = DCGAN_model.train_on_batch(
            [X_gen, X_source_batch2],
            [np.ones(X_gen.shape[0]), virtual_labels])
        #gen_loss = gen_loss[0]
    l_gen.appendleft(gen_loss)
    if not deterministic:
        zclass_loss = zclass_model.train_on_batch([X_gen, X_source_batch2],
                                                  [X_gen])
    else:
        zclass_loss = 0.0
    l_zclass.appendleft(zclass_loss)
    return l_gen, l_zclass
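# The commented-out lines in the "nclass_disc" branch above hint at turning
# the discriminator's class predictions into one-hot "virtual labels". A
# self-contained NumPy sketch of that conversion (all arrays below are random
# stand-ins for illustration):
import numpy as np

n_classes, batch_size = 10, 4

class_p = np.random.rand(batch_size, n_classes)   # stand-in for predicted
class_p /= class_p.sum(axis=1, keepdims=True)     # class probabilities

idx = np.argmax(class_p, axis=1)                  # predicted class per sample
virtual_labels = (idx[:, None] == np.arange(n_classes)).astype(np.float32)

assert virtual_labels.shape == (batch_size, n_classes)
assert np.array_equal(virtual_labels.argmax(axis=1), idx)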
Example #3
def testing_class_accuracy(GANs, classificator_model, generator_model,
                           vis_samples, noise_dim, noise_scale, data, labels):
    acc = []
    loss = []
    for GAN in GANs:
        if GAN.dir == 'BtoA':
            # testing accuracy of trained classifier
            X_noise = data_utils.sample_noise(GAN.noise_scale, vis_samples,
                                              GAN.noise_dim)
            Xsource_dataset_mapped = GAN.generator_model.predict(
                [X_noise, data[:vis_samples]], batch_size=1000)
            true_labels = labels[:vis_samples]
            p1 = GAN.classificator_model.predict(Xsource_dataset_mapped,
                                                 batch_size=1000,
                                                 verbose=1)
            score1 = np.sum(
                np.argmax(true_labels, axis=1) == np.argmax(
                    p1, axis=1)) / float(true_labels.shape[0])
            print(
                '\n Classifier accuracy on full target domain:  %.2f%%'
                % (100 * score1))

        if GAN.dir == 'AtoB':
            X_noise = data_utils.sample_noise(GAN.noise_scale, vis_samples,
                                              GAN.noise_dim)
            Xsource_dataset_mapped = data[:vis_samples]
            true_labels = labels[:vis_samples]
            p2 = GAN.classificator_model.predict(Xsource_dataset_mapped,
                                                 batch_size=1000,
                                                 verbose=1)
            score2 = np.sum(
                np.argmax(true_labels, axis=1) == np.argmax(
                    p2, axis=1)) / float(true_labels.shape[0])
            print(
                '\n Classifier accuracy on full target domain:  %.2f%%'
                % (100 * score2))

    res = []
    for x in np.arange(0, 1.1, 0.1):
        res.append((x,
                    np.sum(
                        np.argmax(true_labels, axis=1) == np.argmax(
                            p1 * x + p2 * (1 - x), axis=1)) /
                    float(true_labels.shape[0])))
    for (x, score) in res:
        print("\n Coeff: %f - score: %.2f" % (x, score * 100))
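# The final loop above searches for the best convex combination of the two
# classifiers' softmax outputs. The same idea as a standalone NumPy sketch,
# with random stand-ins for p1, p2 and the one-hot true_labels:
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_classes = 1000, 10

true_labels = np.eye(n_classes)[rng.randint(n_classes, size=n_samples)]
p1 = rng.dirichlet(np.ones(n_classes), size=n_samples)  # classifier 1 output
p2 = rng.dirichlet(np.ones(n_classes), size=n_samples)  # classifier 2 output

best = (None, -1.0)
for x in np.arange(0, 1.1, 0.1):
    blended = p1 * x + p2 * (1 - x)
    score = np.mean(
        np.argmax(true_labels, axis=1) == np.argmax(blended, axis=1))
    print("Coeff: %.1f - score: %.2f" % (x, score * 100))
    if score > best[1]:
        best = (x, score)
print("best coefficient: %.1f (%.2f%%)" % (best[0], best[1] * 100))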
Example #4
def train_rec(GAN, rec1, rec2, A_data_batch, B_data_batch, l_rec1, l_rec2,
              rec_weight):
    X_noise = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size,
                                      GAN.noise_dim)
    X_noise2 = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size,
                                       GAN.noise_dim)
    rec_loss = rec1.train_on_batch([X_noise, A_data_batch, X_noise2],
                                   A_data_batch,
                                   sample_weight=np.ones(GAN.batch_size) *
                                   rec_weight)
    rec_loss2 = rec2.train_on_batch([X_noise, B_data_batch, X_noise2],
                                    B_data_batch,
                                    sample_weight=np.ones(GAN.batch_size) *
                                    rec_weight)
    l_rec1.appendleft(rec_loss)
    l_rec2.appendleft(rec_loss2)
    return l_rec1, l_rec2
Example #5
def testing_class_accuracy(GANs, classificator_model, generator_model,
                           vis_samples, noise_dim, noise_scale, data, labels):
    acc = []
    loss = []
    for GAN in GANs:
        if GAN.dir == 'BtoA':
            # testing accuracy of trained classifier
            X_noise = data_utils.sample_noise(GAN.noise_scale, vis_samples,
                                              GAN.noise_dim)
            Xsource_dataset_mapped = GAN.generator_model.predict(
                [X_noise, data[:vis_samples]], batch_size=1000)
            loss4, acc4 = GAN.classificator_model.evaluate(
                Xsource_dataset_mapped, labels[:vis_samples],
                batch_size=1000, verbose=0)

        if GAN.dir == 'AtoB':
            X_noise = data_utils.sample_noise(GAN.noise_scale, vis_samples,
                                              GAN.noise_dim)
            Xsource_dataset_mapped = data[:vis_samples]
            loss4, acc4 = GAN.classificator_model.evaluate(
                Xsource_dataset_mapped, labels[:vis_samples],
                batch_size=1000, verbose=0)
        acc.append(acc4)
        loss.append(loss4)
    # expects two GANs (one per direction), hence the fixed indices below
    print('\n Classifier accuracy and loss on full target domain:  '
          '%.2f%% / %.5f /// %.2f%% / %.5f' %
          ((100 * acc[0]), loss[0], (100 * acc[1]), loss[1]))
Example #6
def evaluating_GENned(noise_scale, noise_dim, X_source_test, Y_source_test,
                      classifier, gen_model):
    #converting source test set into source-GENned set(passing through the GEN)
    n = data_utils.sample_noise(noise_scale, X_source_test.shape[0], noise_dim)
    X_gen_test = gen_model.predict([n, X_source_test],
                                   batch_size=512,
                                   verbose=0)
    loss4, acc4 = classifier.evaluate(X_gen_test,
                                      Y_source_test,
                                      batch_size=256,
                                      verbose=0)
    print('\n Classifier Accuracy on source-GENned domain:  %.2f%%' %
          (100 * acc4))
Example #7
def train_class(GAN, l_class, A_data_batch, A_labels_batch):
    if GAN.dir == 'AtoB':
        X_noise = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size,
                                          GAN.noise_dim)
        if GAN.data_aug:
            x_dest_batch = GAN.generator_model.predict(
                [X_noise, datagen.output(A_data_batch)])
        else:
            x_dest_batch = GAN.generator_model.predict([X_noise, A_data_batch])
        # No label smoothing. Inverted training w.r.t. AtoB, because the
        # labels of A are available.
        class_loss = GAN.classificator_model.train_on_batch(x_dest_batch,
                                                            A_labels_batch)
    elif GAN.dir == 'BtoA':
        class_loss = GAN.classificator_model.train_on_batch(A_data_batch,
                                                            A_labels_batch)
    l_class.appendleft(class_loss[0])
    return l_class
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print("%s: %s" % (key, kwargs[key]))
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    X_real_train = data_utils.load_image_dataset(dset, img_dim,
                                                 image_dim_ordering)

    # Get the full real image dimension
    img_dim = X_real_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    if generator == "upsampling":
        generator_model = models.generator_upsampling(noise_dim,
                                                      img_dim,
                                                      bn_mode,
                                                      dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim,
                                                  img_dim,
                                                  bn_mode,
                                                  batch_size,
                                                  dset=dset)
    discriminator_model = models.discriminator(img_dim, bn_mode, dset=dset)
    DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim,
                               img_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            if gen_iterations < 25 or gen_iterations % 500 == 0:
                disc_iterations = 100
            else:
                disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [
                        np.clip(w, clamp_lower, clamp_upper) for w in weights
                    ]
                    l.set_weights(weights)

                X_real_batch = next(
                    data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    noise_scale=noise_scale)

                # Update the discriminator
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(
                    X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size,
                                            noise_dim)

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                  -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            gen_iterations += 1
            batch_counter += 1
            progbar.add(batch_size,
                        values=[("Loss_D", -np.mean(list_disc_loss_real) -
                                 np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)])

            # Save images for visualization ~2 times per epoch
            if batch_counter % (n_batch_per_epoch / 2) == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                batch_size, noise_dim,
                                                image_dim_ordering)

        print('\nEpoch %s/%s, Time: %s' %
              (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, e)
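# Hedged sketch of how this train() might be invoked. The keyword names match
# the reads at the top of the function; every value below is a hypothetical
# placeholder, not a recommended setting.
params = {
    "generator": "upsampling",
    "dset": "mnist",
    "img_dim": 28,
    "nb_epoch": 10,
    "batch_size": 32,
    "n_batch_per_epoch": 200,
    "disc_iterations": 5,
    "bn_mode": 2,
    "noise_dim": 100,
    "noise_scale": 0.5,
    "lr_D": 5e-5,
    "lr_G": 5e-5,
    "opt_D": "RMSprop",
    "opt_G": "RMSprop",
    "clamp_lower": -0.01,
    "clamp_upper": 0.01,
    "image_dim_ordering": "th",
}
train(**params)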
Example #9
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    simple_disc = kwargs["simple_disc"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print("%s: %s" % (key, kwargs[key]))
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    if dset == "mnistM":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnist')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnistM')
        #code.interact(local=locals())
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim,
                                                       image_dim_ordering,
                                                       dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim,
                                                     image_dim_ordering,
                                                     dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim,
                                                     image_dim_ordering,
                                                     dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train, Y_source_train, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='washington12classes')
        X_dest_train, Y_dest_train, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train, Y_source_train, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Wash_12class_LMDB')
    elif dset == "OfficeDslrToAmazon":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeDslr')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeAmazon')
    elif dset == "bedrooms":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='bedrooms_small')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='bedrooms')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Vand_12class_LMDB_Background')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Vand_12class_LMDB')
    else:
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1
    img_source_dim = X_source_train.shape[-3:]  # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    # Create optimizers
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    if generator == "upsampling":
        generator_model = models.generator_upsampling_mnistM(noise_dim,
                                                             img_source_dim,
                                                             img_dest_dim,
                                                             bn_mode,
                                                             deterministic,
                                                             pureGAN,
                                                             inject_noise,
                                                             wd,
                                                             dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim,
                                                  img_dest_dim,
                                                  bn_mode,
                                                  batch_size,
                                                  dset=dset)

    if simple_disc:
        discriminator_model = models.discriminator_naive(
            img_dest_dim, bn_mode, model, wd, inject_noise, n_classes, use_mbd)
        DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model,
                                         noise_dim, img_source_dim)
    elif discriminator == "disc_resnet":
        discriminator_model = models.discriminatorResNet(
            img_dest_dim, bn_mode, model, wd, monsterClass, inject_noise,
            n_classes, use_mbd)
        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   noise_dim, img_source_dim, img_dest_dim,
                                   monsterClass)
    else:
        discriminator_model = models.disc1(img_dest_dim, bn_mode, model, wd,
                                           monsterClass, inject_noise,
                                           n_classes, use_mbd)
        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   noise_dim, img_source_dim, img_dest_dim,
                                   monsterClass)

    ####special options for bedrooms dataset:
    if dset == "bedrooms":
        generator_model = models.generator_dcgan(noise_dim, img_source_dim,
                                                 img_dest_dim, bn_mode,
                                                 deterministic, pureGAN,
                                                 inject_noise, wd)
        discriminator_model = models.discriminator_naive(
            img_dest_dim,
            bn_mode,
            model,
            wd,
            inject_noise,
            n_classes,
            use_mbd,
            model_name="discriminator_naive")
        DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model,
                                         noise_dim, img_source_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)

    models.make_trainable(discriminator_model, False)
    #discriminator_model.trainable = False
    if model == 'wgan':
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
        models.make_trainable(discriminator_model, True)
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
    if model == 'lsgan':
        if simple_disc:
            DCGAN_model.compile(loss=['mse'], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['mse'], optimizer=opt_D)
        elif monsterClass:
            DCGAN_model.compile(loss=['categorical_crossentropy'],
                                optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['categorical_crossentropy'],
                                        optimizer=opt_D)
        else:
            DCGAN_model.compile(loss=['mse', 'categorical_crossentropy'],
                                loss_weights=[1.0, 1.0],
                                optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(
                loss=['mse', 'categorical_crossentropy'],
                loss_weights=[1.0, 1.0],
                optimizer=opt_D)

    visualize = True

    if resume:  ########loading previous saved model weights
        data_utils.load_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, name)

    #####################
    ###classifier
    #####################
    if not ((dset == 'mnistM') or (dset == 'bedrooms')):
        classifier, GenToClassifierModel = classifier_build_test(
            img_dest_dim,
            n_classes,
            generator_model,
            noise_dim,
            noise_scale,
            img_source_dim,
            opt_C,
            X_source_test,
            Y_source_test,
            X_dest_test,
            Y_dest_test,
            wd=0.0001)

    gen_iterations = 0
    max_history_size = int(history_size * batch_size)
    img_buffer = ImageHistoryBuffer((0, ) + img_source_dim, max_history_size,
                                    batch_size, n_classes)
    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:
            if no_supertrain is None:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                elif gen_iterations % 500 == 0:
                    disc_iterations = 10
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            list_gen_loss = []

            for disc_it in range(disc_iterations):

                # Clip discriminator weights
                #for l in discriminator_model.layers:
                #    weights = l.get_weights()
                #    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                #    l.set_weights(weights)

                X_dest_batch, Y_dest_batch, idx_dest_batch = next(
                    data_utils.gen_batch(X_dest_train, Y_dest_train,
                                         batch_size))
                X_source_batch, Y_source_batch, idx_source_batch = next(
                    data_utils.gen_batch(X_source_train, Y_source_train,
                                         batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_dest_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    X_source_batch,
                    noise_scale=noise_scale)
                if model == 'wgan':
                    # Update the discriminator
                    current_labels_real = -np.ones(X_disc_real.shape[0])
                    current_labels_gen = np.ones(X_disc_gen.shape[0])
                if model == 'lsgan':
                    if simple_disc:
                        current_labels_real = np.ones(X_disc_real.shape[0])
                        #current_labels_gen = -np.ones(X_disc_gen.shape[0])
                        current_labels_gen = np.zeros(X_disc_gen.shape[0])
                    elif monsterClass:
                        # real domain gets [labels, 0...0], fake domain gets
                        # [0...0, labels]
                        current_labels_real = np.concatenate(
                            (Y_dest_batch,
                             np.zeros((X_disc_real.shape[0], n_classes))),
                            axis=1)
                        current_labels_gen = np.concatenate((np.zeros(
                            (X_disc_real.shape[0], n_classes)),
                                                             Y_source_batch),
                                                            axis=1)
                    else:
                        current_labels_real = [
                            np.ones(X_disc_real.shape[0]), Y_dest_batch
                        ]
                        Y_fake_batch = (1.0 / n_classes) * np.ones(
                            [X_disc_gen.shape[0], n_classes])
                        current_labels_gen = [
                            np.zeros(X_disc_gen.shape[0]), Y_fake_batch
                        ]
                #label smoothing
                #current_labels_real = np.multiply(current_labels_real, lsmooth) #usually lsmooth = 0.7
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, current_labels_real)
                img_buffer.add_to_buffer(X_disc_gen, current_labels_gen,
                                         batch_size)
                bufferImages, bufferLabels = img_buffer.get_from_buffer(
                    batch_size)
                disc_loss_gen = discriminator_model.train_on_batch(
                    bufferImages, bufferLabels)

                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            X_source_batch2, Y_source_batch2, idx_source_batch2 = next(
                data_utils.gen_batch(X_source_train, Y_source_train,
                                     batch_size))
            #            w1 = classifier.get_weights() #FOR DEBUG
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen, X_source_batch2],
                                                      -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                if simple_disc:
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        np.ones(X_gen.shape[0]))  #TRYING SAME BATCH OF DISC?
                elif monsterClass:
                    labels_gen = np.concatenate(
                        (Y_source_batch2,
                         np.zeros((X_disc_real.shape[0], n_classes))),
                        axis=1)
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2], labels_gen)
                else:
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        [np.ones(X_gen.shape[0]), Y_source_batch2])
            # Debug helpers (disabled): compare the classifier weights before
            # and after the generator update to check the classifier is frozen.
            # gen_loss2 = GenToClassifierModel.train_on_batch(
            #     [X_gen, X_source_batch2], Y_source_batch2)
            # w2 = classifier.get_weights()
            # for a, b in zip(w1, w2):
            #     if np.all(a == b):
            #         print("no bug in GEN model update")
            #     else:
            #         print("BUG IN GEN MODEL UPDATE")
            list_gen_loss.append(gen_loss)

            gen_iterations += 1
            batch_counter += 1

            progbar.add(batch_size,
                        values=[("Loss_D", 0.5 * np.mean(list_disc_loss_real) +
                                 0.5 * np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", np.mean(list_gen_loss))])

            # plot images 1 times per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
                X_source_batch_plot, Y_source_batch_plot, idx_source_plot = next(
                    data_utils.gen_batch(X_source_test,
                                         Y_source_test,
                                         batch_size=32))
                data_utils.plot_generated_batch(X_dest_test,
                                                X_source_test,
                                                generator_model,
                                                noise_dim,
                                                image_dim_ordering,
                                                idx_source_plot,
                                                batch_size=32)
            if gen_iterations % (n_batch_per_epoch * 5) == 0:
                if visualize:
                    BIG_ASS_VISUALIZATION_slerp(X_source_train[1],
                                                generator_model, noise_dim)

        print("Dest labels:")
        print(Y_dest_test[idx_source_plot].argmax(1))
        print("Source labels:")
        print(Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' %
              (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, e, name)
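# The loop above pushes each generated batch into an ImageHistoryBuffer and
# trains the discriminator on samples drawn from it, so the critic also sees
# images produced by earlier generator states. Below is a minimal NumPy-only
# sketch of that replay-buffer pattern; it is an illustration, not the
# repository's ImageHistoryBuffer class.
import numpy as np


class SimpleImageHistoryBuffer(object):
    """Keep a pool of recently generated images/labels and sample
    discriminator batches from it."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.images = None
        self.labels = None

    def add(self, images, labels):
        if self.images is None:
            self.images, self.labels = images.copy(), labels.copy()
        else:
            self.images = np.concatenate([self.images, images])[-self.max_size:]
            self.labels = np.concatenate([self.labels, labels])[-self.max_size:]

    def sample(self, batch_size):
        idx = np.random.randint(0, self.images.shape[0], size=batch_size)
        return self.images[idx], self.labels[idx]


# usage sketch with dummy data
buffer = SimpleImageHistoryBuffer(max_size=10 * 32)
buffer.add(np.zeros((32, 3, 28, 28)), np.zeros((32, 10)))
fake_images, fake_labels = buffer.sample(32)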
def eval(**kwargs):

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    cont_dim = (kwargs["cont_dim"],)
    cat_dim = (kwargs["cat_dim"],)
    noise_dim = (kwargs["noise_dim"],)
    bn_mode = kwargs["bn_mode"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    epoch = kwargs["epoch"]

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    if dset == "RGZ":
        X_real_train = data_utils.load_RGZ(img_dim, image_dim_ordering)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_dim_ordering)
    img_dim = X_real_train.shape[-3:]

    # Load generator model
    generator_model = models.load("generator_%s" % generator,
                                  cat_dim,
                                  cont_dim,
                                  noise_dim,
                                  img_dim,
                                  bn_mode,
                                  batch_size,
                                  dset=dset)

    # Load colorization model
    generator_model.load_weights("../../models/%s/gen_weights_epoch%s.h5" %
                                 (model_name, epoch))

    X_plot = []
    # Vary the categorical variable
    for i in range(cat_dim[0]):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = data_utils.sample_noise(noise_scale, batch_size, cont_dim)
        X_cont = np.repeat(X_cont[:1, :], batch_size, axis=0)  # fix continuous noise
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, i] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)

        if image_dim_ordering == "th":
            X_gen = X_gen.transpose(0,2,3,1)

        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(8,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying categorical factor", fontsize=28, labelpad=60)

    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig("../../figures/varying_categorical.png")
    plt.clf()
    plt.close()

    # Vary the continuous variables
    X_plot = []
    # First get the extent of the noise sampling
    x = np.ravel(data_utils.sample_noise(noise_scale, batch_size * 20000, cont_dim))
    # Define interpolation points
    x = np.linspace(x.min(), x.max(), num=batch_size)
    for i in range(batch_size):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = np.concatenate([np.array([x[i], x[j]]).reshape(1, -1) for j in range(batch_size)], axis=0)
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, 1] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)
        if image_dim_ordering == "th":
            X_gen = X_gen.transpose(0,2,3,1)
        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(10,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying continuous factor 1", fontsize=28, labelpad=60)
    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.xlabel("Varying continuous factor 2", fontsize=28, labelpad=60)
    plt.annotate('', xy=(1, -0.05), xycoords='axes fraction', xytext=(0, -0.05),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig("../../figures/varying_continuous.png")
    plt.clf()
    plt.close()
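# Sketch of the 2-D continuous-code grid built in eval() above, in plain
# NumPy. It assumes sample_noise draws Gaussian noise with standard deviation
# noise_scale; the values are illustrative.
import numpy as np

noise_scale, batch_size = 0.5, 8

# estimate the extent of the sampled noise, then spread interpolation points
samples = np.random.normal(scale=noise_scale, size=batch_size * 20000)
x = np.linspace(samples.min(), samples.max(), num=batch_size)

# row i of the image grid: factor 1 fixed at x[i], factor 2 sweeping over x
i = 0
X_cont = np.concatenate(
    [np.array([x[i], x[j]]).reshape(1, -1) for j in range(batch_size)], axis=0)
assert X_cont.shape == (batch_size, 2)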
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print("%s: %s" % (key, kwargs[key]))
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    X_real_train = data_utils.load_image_dataset(dset, img_dim, image_dim_ordering)

    # Get the full real image dimension
    img_dim = X_real_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    if generator == "upsampling":
        generator_model = models.generator_upsampling(noise_dim, img_dim, bn_mode, dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim, img_dim, bn_mode, batch_size, dset=dset)
    discriminator_model = models.discriminator(img_dim, bn_mode)
    DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim, img_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            if gen_iterations < 25 or gen_iterations % 500 == 0:
                disc_iterations = 100
            else:
                disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                    l.set_weights(weights)

                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    noise_scale=noise_scale)

                # Update the discriminator
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            gen_iterations += 1
            batch_counter += 1
            progbar.add(batch_size, values=[("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)),
                                            ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                            ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                            ("Loss_G", -gen_loss)])

            # Save images for visualization ~2 times per epoch
            if batch_counter % (n_batch_per_epoch / 2) == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                batch_size, noise_dim, image_dim_ordering)

        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e)
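# The inner weight-clipping loop above is the standard WGAN critic constraint.
# It can be factored into a small helper; a minimal sketch, assuming a plain
# Keras model:
import numpy as np


def clip_model_weights(model, clamp_lower, clamp_upper):
    """Clip every weight tensor of `model` into [clamp_lower, clamp_upper]."""
    for layer in model.layers:
        clipped = [np.clip(w, clamp_lower, clamp_upper)
                   for w in layer.get_weights()]
        layer.set_weights(clipped)

# inside the critic loop one would then call, e.g.:
# clip_model_weights(discriminator_model, clamp_lower, clamp_upper)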
Example #12
def trainClassAux(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    noClass = kwargs["noClass"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    C_weight = kwargs["C_weight"]
    monsterClass = kwargs["monsterClass"]
    pretrained = kwargs["pretrained"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print("%s: %s" % (key, kwargs[key]))
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    if dset == "mnistM":
        X_source_train,Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnist')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnistM')
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_12class_LMDB')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB_Background')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Wash_Color_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
    else:
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1
    img_source_dim = X_source_train.shape[-3:] # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:] 

    X_source_train.flags.writeable = False
    X_source_test.flags.writeable = False
    X_dest_train.flags.writeable = False
    X_dest_test.flags.writeable = False
    # Create optimizers (build the classifier-weighted optimizer before opt_G
    # is replaced by the optimizer object itself)
    opt_G_C = data_utils.get_optimizer(opt_G, lr_G * C_weight)
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    if generator == "upsampling":
        generator_model = models.generator_upsampling_mnistM(noise_dim, img_source_dim,img_dest_dim, bn_mode,deterministic,inject_noise,wd, dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim, img_dest_dim, bn_mode, batch_size, dset=dset)

    discriminator_model = models.discriminator(img_dest_dim, bn_mode,model,wd,monsterClass,inject_noise,n_classes)
    DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model, noise_dim, img_source_dim)
    # img_dest_dim because it is actually the generated image dim, which is
    # equal to dest_dim
    classifier = models.resnet(img_dest_dim, n_classes, pretrained, wd=0.0001)
 
    GenToClassifierModel = models.GenToClassifierModel(generator_model, classifier, noise_dim, img_source_dim)

    ############################
    # Load weights
    ############################
    if resume:
        data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name)
    #if pretrained:
    #    model_path = "../../models/DCGAN"
    #    class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
    #    classifier.load_weights(class_weights_path)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)

    if model == 'wgan':
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
        models.make_trainable(discriminator_model, False)
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    if model == 'lsgan':
        discriminator_model.compile(loss='mse', optimizer=opt_D)
        models.make_trainable(discriminator_model, False)
        DCGAN_model.compile(loss='mse',  optimizer=opt_G)

    # the classifier's optimizer is never actually used: the classifier is
    # frozen below and only guides the generator via GenToClassifierModel
    classifier.compile(loss='categorical_crossentropy',
                       optimizer=opt_C,
                       metrics=['accuracy'])
    models.make_trainable(classifier, False)
    GenToClassifierModel.compile(loss='categorical_crossentropy',
                                 optimizer=opt_G,
                                 metrics=['accuracy'])


    #######################
    # Train classifier
    #######################
#    if not pretrained: 
#        #print ("Testing accuracy on target domain test set before training:")
#        loss1,acc1 =classifier.evaluate(X_dest_test, Y_dest_test,batch_size=256, verbose=0)
#        print('\n Classifier Accuracy on target domain test set before training: %.2f%%' % (100 * acc1))
#        classifier.fit(X_dest_train, Y_dest_train, validation_split=0.1, batch_size=512, nb_epoch=10, verbose=1)
#        print ("\n Testing accuracy on target domain test set AFTER training:")
#    else:
#        print ("Loaded pretrained classifier, computing accuracy on target domain test set:")
#    loss2,acc2 = classifier.evaluate(X_dest_test, Y_dest_test,batch_size=512, verbose=0)
#    print('\n Classifier Accuracy on target domain test set after training:  %.2f%%' % (100 * acc2))
    #print ("Testing accuracy on source domain test set:")
#    loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
#    print('\n Classifier Accuracy on source domain test set:  %.2f%%' % (100 * acc3))
#    evaluating_GENned(noise_scale, noise_dim, X_source_test, Y_source_test, classifier, generator_model)

#    model_path = "../../models/DCGAN"
#    class_weights_path = os.path.join(model_path, 'VandToVand_5epochs.h5')
#    classifier.save_weights(class_weights_path, overwrite=True) 
    #################
    # GAN training
    ################
    gen_iterations = 0 
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:
            if no_supertrain is None:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                elif gen_iterations % 500 == 0:
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            list_gen_loss = []
            list_class_loss_real = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
#                for l in discriminator_model.layers:
#                    weights = l.get_weights()
#                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
#                    l.set_weights(weights)

                X_dest_batch, Y_dest_batch,idx_dest_batch = next(data_utils.gen_batch(X_dest_train, Y_dest_train, batch_size))
                X_source_batch, Y_source_batch,idx_source_batch = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_dest_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    X_source_batch,
                                                                    noise_scale=noise_scale)
                if model == 'wgan':
                    # Update the discriminator
                    disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                    disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                if model == 'lsgan':
                    disc_loss_real = discriminator_model.train_on_batch(X_disc_real, np.ones(X_disc_real.shape[0]))
                    disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.zeros(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)


            #######################
            # 2) Train the generator with GAN loss
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            source_images = X_source_train[np.random.randint(0,X_source_train.shape[0],size=batch_size),:,:,:]
            X_source_batch2, Y_source_batch2,idx_source_batch2 = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))
            # Freeze the discriminator
 #           discriminator_model.trainable = False
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen,X_source_batch2], -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen,X_source_batch2], np.ones(X_gen.shape[0]))
            list_gen_loss.append(gen_loss)


            #######################
            # 3) Train the generator with Classifier loss
            #######################
            w1 = classifier.get_weights() #FOR DEBUG
            if not noClass:
                new_gen_loss = GenToClassifierModel.train_on_batch([X_gen,X_source_batch2], Y_source_batch2)
                list_class_loss_real.append(new_gen_loss)
            else:
                list_class_loss_real.append(0.0)
            w2 = classifier.get_weights() #FOR DEBUG
            for a, b in zip(w1, w2):
                if np.all(a == b):
                    print("no bug in GEN model update")
                else:
                    print("BUG IN GEN MODEL UPDATE")

            gen_iterations += 1
            batch_counter += 1

            progbar.add(batch_size, values=[("Loss_D", 0.5*np.mean(list_disc_loss_real) + 0.5*np.mean(list_disc_loss_gen)),
                                            ("Loss_D_real", np.mean(list_disc_loss_real)),
                                            ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                            ("Loss_G", np.mean(list_gen_loss)),
                                            ("Loss_classifier", np.mean(list_class_loss_real))])

            # plot images 1 times per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
          #      train_WGAN.plot_images(X_dest_batch)
                X_dest_batch_plot,Y_dest_batch_plot,idx_dest_plot = next(data_utils.gen_batch(X_dest_train,Y_dest_train, batch_size=32))
                X_source_batch_plot,Y_source_batch_plot,idx_source_plot = next(data_utils.gen_batch(X_source_train,Y_source_train, batch_size=32))

                data_utils.plot_generated_batch(X_dest_train,X_source_train, generator_model,
                                                 noise_dim, image_dim_ordering,idx_source_plot,batch_size=32)
        print ("Dest labels:") 
        print (Y_dest_train[idx_source_plot].argmax(1))
        print ("Source labels:") 
        print (Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e, name)
        evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)


        loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
        print('\n Classifier Accuracy on source domain test set:  %.2f%%' % (100 * acc3))
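
# The w1/w2 comparison in the training loop above is a sanity check that the frozen
# classifier does not receive updates when the generator is trained through it. A
# minimal reusable version of that check (the helper name is ours, not part of the
# original code):
import numpy as np

def assert_weights_unchanged(weights_before, weights_after, name="classifier"):
    # Compare every weight tensor; any mismatch means the model was still trainable
    # inside the stacked generator+classifier graph when it should have been frozen.
    for before, after in zip(weights_before, weights_after):
        if not np.array_equal(before, after):
            raise AssertionError("%s weights changed during the generator update" % name)

# Usage sketch: w1 = classifier.get_weights(); ...train generator...; assert_weights_unchanged(w1, classifier.get_weights())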
def eval(**kwargs):

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    cont_dim = (kwargs["cont_dim"],)
    cat_dim = (kwargs["cat_dim"],)
    noise_dim = (kwargs["noise_dim"],)
    bn_mode = kwargs["bn_mode"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    eval_epoch = kwargs["eval_epoch"]

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and rescale data
    if dset == "RGZ":
        X_real_train = data_utils.load_RGZ(img_dim, image_data_format)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    img_dim = X_real_train.shape[-3:]

    # Load generator model
    generator_model = models.load("generator_%s" % generator,
                                  cat_dim,
                                  cont_dim,
                                  noise_dim,
                                  img_dim,
                                  bn_mode,
                                  batch_size,
                                  dset=dset)

    # Load pre-trained generator weights
    generator_model.load_weights("../../models/%s/gen_weights_epoch%05d.h5" %
                                 (model_name, eval_epoch))

    X_plot = []
    # Vary the categorical variable
    for i in range(cat_dim[0]):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = data_utils.sample_noise(noise_scale, batch_size, cont_dim)
        X_cont = np.repeat(X_cont[:1, :], batch_size, axis=0)  # fix continuous noise
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, i] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)

        if image_data_format == "channels_first":
            X_gen = X_gen.transpose(0,2,3,1)

        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(8,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying categorical factor", fontsize=28, labelpad=60)

    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig(os.path.join("../../figures", model_name, "varying_categorical.png"))
    plt.clf()
    plt.close()

    # Vary the continuous variables
    X_plot = []
    # First get the extent of the noise sampling
    x = np.ravel(data_utils.sample_noise(noise_scale, batch_size * 20000, cont_dim))
    # Define interpolation points
    x = np.linspace(x.min(), x.max(), num=batch_size)
    for i in range(batch_size):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = np.concatenate([np.array([x[i], x[j]]).reshape(1, -1) for j in range(batch_size)], axis=0)
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, 1] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)
        if image_data_format == "channels_first":
            X_gen = X_gen.transpose(0,2,3,1)
        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(10,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying continuous factor 1", fontsize=28, labelpad=60)
    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.xlabel("Varying continuous factor 2", fontsize=28, labelpad=60)
    plt.annotate('', xy=(1, -0.05), xycoords='axes fraction', xytext=(0, -0.05),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig(os.path.join("../../figures", model_name, "varying_continuous.png"))
    plt.clf()
    plt.close()
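
# For reference, the categorical code fed to the generator in eval() above is a plain
# one-hot batch in which every row selects the same class. A small stand-alone sketch
# of that step (the function name is ours, values are illustrative):
import numpy as np

def one_hot_batch(category, batch_size, cat_dim):
    # One row per sample; only the chosen category column is set to 1.
    X_cat = np.zeros((batch_size, cat_dim), dtype="float32")
    X_cat[:, category] = 1.0
    return X_cat

# one_hot_batch(3, 4, 10) -> array of shape (4, 10) with column 3 equal to 1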
def train_toy(**kwargs):
    """
    Train model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("toy_MLP")

    # Load and rescale data
    X_real_train = data_utils.load_toy()

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    generator_model = models.generator_toy(noise_dim)
    discriminator_model = models.discriminator_toy()
    GAN_model = models.GAN_toy(generator_model, discriminator_model, noise_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    GAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    #################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [
                        np.clip(w, clamp_lower, clamp_upper) for w in weights
                    ]
                    l.set_weights(weights)

                X_real_batch = next(
                    data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    noise_scale=noise_scale)

                # Update the discriminator
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(
                    X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = GAN_model.train_on_batch(X_gen,
                                                -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            batch_counter += 1
            progbar.add(batch_size,
                        values=[("Loss_D", -np.mean(list_disc_loss_real) -
                                 np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)])

            # # Save images for visualization
            if gen_iterations % 50 == 0:
                data_utils.plot_generated_toy_batch(X_real_train,
                                                    generator_model,
                                                    discriminator_model,
                                                    noise_dim, gen_iterations)
            gen_iterations += 1

        print('\nEpoch %s/%s, Time: %s' %
              (e + 1, nb_epoch, time.time() - start))
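
# The inner disc_it loop above clips every critic weight into [clamp_lower, clamp_upper],
# which is how the original WGAN enforces the Lipschitz constraint. A minimal stand-alone
# sketch of that clipping step (the helper name and default bounds are ours):
import numpy as np

def clip_discriminator_weights(model, clamp_lower=-0.01, clamp_upper=0.01):
    # Clip the weights of every layer in place; call this after each critic update.
    for layer in model.layers:
        clipped = [np.clip(w, clamp_lower, clamp_upper) for w in layer.get_weights()]
        layer.set_weights(clipped)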
Beispiel #15
0
def trainDECO(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    noClass = kwargs["noClass"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    monsterClass = kwargs["monsterClass"]
    pretrained = kwargs["pretrained"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    if dset == "mnistM":
        X_source_train,Y_source_train, _, _, n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnist')
        X_dest_train,Y_dest_train, _, _,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnistM')
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_12class_LMDB')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB_Background')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Wash_Color_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
    else:
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1

    # Get the full real image dimension
    img_source_dim = X_source_train.shape[-3:] # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:] 

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    generator_model = models.generator_upsampling_mnistM(noise_dim, img_source_dim,img_dest_dim, bn_mode,deterministic,inject_noise,wd, dset=dset)
    classifier = models.resnet(img_dest_dim, n_classes, wd=0.0001)  # img_dest_dim because the classifier consumes generated images, whose size matches the destination domain
    GenToClassifierModel = models.GenToClassifierModel(generator_model, classifier, noise_dim, img_source_dim)

    #################
    # Load weights
    ################
    if pretrained:
        model_path = "../../models/DCGAN"
        class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
        classifier.load_weights(class_weights_path)
 #   if resume: ########loading previous saved model weights
 #       data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name)

    #######################
    # Compile models
    #######################
    generator_model.compile(loss='mse', optimizer=opt_G)
#    classifier.trainable = True  # freezing is handled below with models.make_trainable
    classifier.compile(loss='categorical_crossentropy', optimizer=opt_C, metrics=['accuracy'])  # opt_C is only used when the classifier itself is fit below

    models.make_trainable(classifier, False)
    GenToClassifierModel.compile(loss='categorical_crossentropy', optimizer=opt_G,metrics=['accuracy'])

    #######################
    # Train classifier
    #######################
    if not pretrained:
        loss1, acc1 = classifier.evaluate(X_dest_test, Y_dest_test, batch_size=256, verbose=0)
        print('\n Classifier Accuracy on target domain test set before training: %.2f%%' % (100.0 * acc1))
        classifier.fit(X_dest_train, Y_dest_train, validation_split=0.1, batch_size=512, nb_epoch=20, verbose=1)
    else:
        print("Loaded pretrained classifier, computing accuracy on target domain test set:")
    loss2, acc2 = classifier.evaluate(X_dest_test, Y_dest_test, batch_size=512, verbose=0)
    print('\n Classifier Accuracy on target domain test set after training:  %.2f%%' % (100.0 * acc2))
    loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test, batch_size=512, verbose=0)
    print('\n Classifier Accuracy on source domain test set:  %.2f%%' % (100.0 * acc3))
    evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)

#    model_path = "../../models/DCGAN"
#    class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
#    classifier.save_weights(class_weights_path, overwrite=True)

#    models.make_trainable(classifier, False)
    #classifier.trainable = False # I wanna freeze the classifier without any more training updates
#    classifier.compile(loss='categorical_crossentropy', optimizer=opt_C,metrics=['accuracy']) 
    #################
    #  DECO training
    ################
    gen_iterations = 0
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()
        ###################################
        # 1) Sample batches (there is no discriminator phase in DECO training)
        ###################################
        list_class_loss_real = []
        X_dest_batch, Y_dest_batch,idx_dest_batch = next(data_utils.gen_batch(X_dest_train, Y_dest_train, batch_size))
        X_source_batch, Y_source_batch,idx_source_batch = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))
        #######################
        # 2) Train the generator
        #######################
        X_gen = data_utils.sample_noise(noise_scale, batch_size*n_batch_per_epoch, noise_dim)
        X_source_batch2, Y_source_batch2,idx_source_batch2 = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size*n_batch_per_epoch))
        GenToClassifierModel.fit([X_gen, X_source_batch2], Y_source_batch2, batch_size=256, nb_epoch=1, verbose=1)

       
        gen_iterations += 1
        batch_counter += 1

        # plot images once per epoch
        X_dest_batch_plot,Y_dest_batch_plot,idx_dest_plot = next(data_utils.gen_batch(X_dest_train,Y_dest_train, batch_size=32))
        X_source_batch_plot,Y_source_batch_plot,idx_source_plot = next(data_utils.gen_batch(X_source_train,Y_source_train, batch_size=32))

        data_utils.plot_generated_batch(X_dest_train,X_source_train, generator_model,
                                                 noise_dim, image_dim_ordering,idx_source_plot,batch_size=32)
        print ("Dest labels:") 
        print (Y_dest_train[idx_source_plot].argmax(1))
        print ("Source labels:") 
        print (Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        #data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e, name)
        evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)

        loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test, batch_size=512, verbose=0)
        print('\n Classifier Accuracy on source domain test set:  %.2f%%' % (100.0 * acc3))
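
# trainDECO relies on freezing the classifier before compiling the stacked
# generator-to-classifier model, so only the generator receives gradients from the
# classification loss. models.make_trainable is not shown in this listing; a common
# Keras implementation (an assumption on our part, not the original code) is:
def make_trainable(model, flag):
    # Toggle trainability of the model and all of its layers; the change takes effect
    # the next time a model containing these layers is compiled.
    model.trainable = flag
    for layer in model.layers:
        layer.trainable = flag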
Beispiel #16
0
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    disc_type = kwargs["disc_type"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    if dset == "mnistM":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnist')
        #        X_source_train=np.concatenate([X_source_train,X_source_train,X_source_train], axis=1)
        #        X_source_test=np.concatenate([X_source_test,X_source_test,X_source_test], axis=1)
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnistM')
    elif dset == "OfficeDslrToAmazon":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeDslr')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeAmazon')
    else:
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1

    img_source_dim = X_source_train.shape[-3:]  # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    # Create optimizers
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_GC = data_utils.get_optimizer('Adam', lr_G / 10.0)
    opt_C = data_utils.get_optimizer('Adam', lr_D)
    opt_Z = data_utils.get_optimizer('Adam', lr_G)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    generator_model = models.generator_google_mnistM(noise_dim, img_source_dim,
                                                     img_dest_dim,
                                                     deterministic, pureGAN,
                                                     wd)
    #    discriminator_model = models.discriminator_google_mnistM(img_dest_dim, wd)
    discriminator_model = models.discriminator_dcgan(img_dest_dim, wd,
                                                     n_classes, disc_type)
    classificator_model = models.classificator_google_mnistM(
        img_dest_dim, n_classes, wd)
    DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model,
                                     noise_dim, img_source_dim)
    zclass_model = z_coerence(generator_model,
                              img_source_dim,
                              bn_mode,
                              wd,
                              inject_noise,
                              n_classes,
                              noise_dim,
                              model_name="zClass")
    #    GenToClassifier_model = models.GenToClassifierModel(generator_model, classificator_model, noise_dim, img_source_dim)
    #disc_penalty_model = models.disc_penalty(discriminator_model,noise_dim,img_source_dim,opt_D,model_name="disc_penalty_model")

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)

    models.make_trainable(discriminator_model, False)
    models.make_trainable(classificator_model, False)
    #    models.make_trainable(disc_penalty_model, False)
    if model == 'wgan':
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
        models.make_trainable(discriminator_model, True)
        #     models.make_trainable(disc_penalty_model, True)
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
    if model == 'lsgan':
        if disc_type == "simple_disc":
            DCGAN_model.compile(loss=['mse'], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['mse'], optimizer=opt_D)
        elif disc_type == "nclass_disc":
            DCGAN_model.compile(loss=['mse', 'categorical_crossentropy'],
                                loss_weights=[1.0, 0.1],
                                optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(
                loss=['mse', 'categorical_crossentropy'],
                loss_weights=[1.0, 0.1],
                optimizer=opt_D)
#    GenToClassifier_model.compile(loss='categorical_crossentropy', optimizer=opt_GC)
    models.make_trainable(classificator_model, True)
    classificator_model.compile(loss='categorical_crossentropy',
                                metrics=['accuracy'],
                                optimizer=opt_C)
    zclass_model.compile(loss=['mse'], optimizer=opt_Z)

    visualize = True
    ########
    # Build combined train+test arrays of the target domain for global testing
    ########
    Xtarget_dataset = np.concatenate([X_dest_train, X_dest_test], axis=0)
    Ytarget_dataset = np.concatenate([Y_dest_train, Y_dest_test], axis=0)

    if resume:  ########loading previous saved model weights and checking actual performance
        data_utils.load_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, name, classificator_model,
                                      zclass_model)
#        data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name,classificator_model)
#        loss4, acc4 = classificator_model.evaluate(Xtarget_dataset, Ytarget_dataset,batch_size=1024, verbose=0)
#        print('\n Classifier Accuracy on full target domain:  %.2f%%' % (100 * acc4))

    else:
        X_gen = data_utils.sample_noise(noise_scale, X_source_train.shape[0],
                                        noise_dim)
        zclass_loss = zclass_model.fit([X_gen, X_source_train], [X_gen],
                                       batch_size=256,
                                       epochs=10)
    # (the zclass regression model above is pre-trained only when not resuming)

    gen_iterations = 0
    max_history_size = int(history_size * batch_size)
    img_buffer = ImageHistoryBuffer((0, ) + img_source_dim, max_history_size,
                                    batch_size, n_classes)
    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:
            if no_supertrain is None:
                if ((gen_iterations < 25) and (not resume)) or (gen_iterations % 500 == 0):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = deque(10 * [0], 10)
            list_disc_loss_gen = deque(10 * [0], 10)
            list_gen_loss = deque(10 * [0], 10)
            list_zclass_loss = deque(10 * [0], 10)
            list_classifier_loss = deque(10 * [0], 10)
            list_gp_loss = deque(10 * [0], 10)
            for disc_it in range(disc_iterations):
                X_dest_batch, Y_dest_batch, idx_dest_batch = next(
                    data_utils.gen_batch(X_dest_train, Y_dest_train,
                                         batch_size))
                X_source_batch, Y_source_batch, idx_source_batch = next(
                    data_utils.gen_batch(X_source_train, Y_source_train,
                                         batch_size))
                ##########
                # Create a batch to feed the discriminator model
                #########
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_dest_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    X_source_batch,
                    noise_scale=noise_scale)

                # Update the discriminator
                if model == 'wgan':
                    current_labels_real = -np.ones(X_disc_real.shape[0])
                    current_labels_gen = np.ones(X_disc_gen.shape[0])
                elif model == 'lsgan':
                    if disc_type == "simple_disc":
                        current_labels_real = np.ones(X_disc_real.shape[0])
                        current_labels_gen = np.zeros(X_disc_gen.shape[0])
                    elif disc_type == "nclass_disc":
                        virtual_real_labels = np.zeros(
                            [X_disc_gen.shape[0], n_classes])
                        current_labels_real = [
                            np.ones(X_disc_real.shape[0]), virtual_real_labels
                        ]
                        current_labels_gen = [
                            np.zeros(X_disc_gen.shape[0]), Y_source_batch
                        ]
                ##############
                #Train the disc on gen-buffered samples and on current real samples
                ##############
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, current_labels_real)
                img_buffer.add_to_buffer(X_disc_gen, current_labels_gen,
                                         batch_size)
                bufferImages, bufferLabels = img_buffer.get_from_buffer(
                    batch_size)
                disc_loss_gen = discriminator_model.train_on_batch(
                    bufferImages, bufferLabels)

                #if not isinstance(disc_loss_real, collections.Iterable): disc_loss_real = [disc_loss_real]
                #if not isinstance(disc_loss_real, collections.Iterable): disc_loss_gen = [disc_loss_gen]
                if disc_type == "simple_disc":
                    list_disc_loss_real.appendleft(disc_loss_real)
                    list_disc_loss_gen.appendleft(disc_loss_gen)
                elif disc_type == "nclass_disc":
                    list_disc_loss_real.appendleft(disc_loss_real[0])
                    list_disc_loss_gen.appendleft(disc_loss_gen[0])
                #############
                ####Train the discriminator w.r.t gradient penalty
                #############
                #gp_loss = disc_penalty_model.train_on_batch([X_disc_real,X_disc_gen],current_labels_real) #dummy labels,not used in the loss function
                #list_gp_loss.appendleft(gp_loss)

            ################
            # Classifier training outside the disc loop (run it only once even when disc_iterations > 1)
            #################
            class_loss_gen = classificator_model.train_on_batch(
                X_disc_gen, Y_source_batch * 0.7)  # label smoothing
            list_classifier_loss.appendleft(class_loss_gen[1])
            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            X_source_batch2, Y_source_batch2, idx_source_batch2 = next(
                data_utils.gen_batch(X_source_train, Y_source_train,
                                     batch_size))
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen, X_source_batch2],
                                                      -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                if disc_type == "simple_disc":
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        np.ones(X_gen.shape[0]))  #TRYING SAME BATCH OF DISC
                elif disc_type == "nclass_disc":
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        [np.ones(X_gen.shape[0]), Y_source_batch2])
                    gen_loss = gen_loss[0]
            list_gen_loss.appendleft(gen_loss)

            zclass_loss = zclass_model.train_on_batch([X_gen, X_source_batch2],
                                                      [X_gen])
            list_zclass_loss.appendleft(zclass_loss)
            ##############
            #Train the generator w.r.t the aux classifier:
            #############
            #            GenToClassifier_model.train_on_batch([X_gen,X_source_batch2],Y_source_batch2)

            # TODO: also classify with the discriminator, using one class for real samples and n classes for fakes

            gen_iterations += 1
            batch_counter += 1

            progbar.add(batch_size,
                        values=[("Loss_D_real", np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", np.mean(list_gen_loss)),
                                ("Loss_Z", np.mean(list_zclass_loss)),
                                ("Loss_Classifier",
                                 np.mean(list_classifier_loss))])

            # plot images once per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
                X_source_batch_plot, Y_source_batch_plot, idx_source_plot = next(
                    data_utils.gen_batch(X_source_test,
                                         Y_source_test,
                                         batch_size=32))
                data_utils.plot_generated_batch(X_dest_test,
                                                X_source_test,
                                                generator_model,
                                                noise_dim,
                                                image_dim_ordering,
                                                idx_source_plot,
                                                batch_size=32)
            if gen_iterations % (n_batch_per_epoch * 5) == 0:
                if visualize:
                    BIG_ASS_VISUALIZATION_slerp(X_source_train[1],
                                                generator_model, noise_dim)
        #    if (e % 20) == 0:
        #        lr_decay([discriminator_model,DCGAN_model,classificator_model],decay_value=0.95)

        print("Dest labels:")
        print(Y_dest_test[idx_source_plot].argmax(1))
        print("Source labels:")
        print(Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' %
              (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, e, name,
                                      classificator_model, zclass_model)

        #testing accuracy of trained classifier
        loss4, acc4 = classificator_model.evaluate(Xtarget_dataset,
                                                   Ytarget_dataset,
                                                   batch_size=1024,
                                                   verbose=0)
        print(
            '\n Classifier Accuracy and loss on full target domain:  %.2f%% / %.5f'
            % ((100 * acc4), loss4))
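
# ImageHistoryBuffer above keeps a pool of previously generated images so the
# discriminator also sees older fakes (the replay trick popularised by SimGAN). The
# class itself is defined elsewhere; below is a simplified sketch with the same
# interface, written as an assumption about its behaviour and covering only the
# simple_disc label format:
import numpy as np

class SimpleImageHistoryBuffer(object):
    def __init__(self, max_size):
        self.images = None
        self.labels = None
        self.max_size = max_size

    def add_to_buffer(self, images, labels, batch_size):
        # Append the newest fakes and drop the oldest ones beyond max_size
        # (batch_size is kept only to mirror the original call signature).
        labels = np.asarray(labels)
        if self.images is None:
            self.images, self.labels = images.copy(), labels.copy()
        else:
            self.images = np.concatenate([self.images, images])[-self.max_size:]
            self.labels = np.concatenate([self.labels, labels])[-self.max_size:]

    def get_from_buffer(self, batch_size):
        # Return a random batch drawn from the whole history.
        idx = np.random.randint(0, self.images.shape[0], size=batch_size)
        return self.images[idx], self.labels[idx]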
Beispiel #17
0
def train_gan(GAN, disc_iters, A_data, A_labels, B_data, B_labels, batch_counter, l_disc_real, l_disc_gen, l_gen):
    if GAN.dir == 'BtoA':
        for disc_it in range(disc_iters):
            A_data_batch, A_labels_batch, B_data_batch, B_labels_batch = get_batch(A_data, A_labels, B_data, B_labels, GAN.batch_size)
            X_source_batch = B_data_batch
            #Y_source_batch = B_labels_batch
            X_dest_batch = A_data_batch
            Y_dest_batch = A_labels_batch

            ##########
            # Create a batch to feed the discriminator model
            #########
            X_noise = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size, GAN.noise_dim)
            gen_output = GAN.generator_model.predict([X_noise,X_source_batch])
            #X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_dest_batch, GAN.generator_model, batch_counter, GAN.batch_size,
            #                                                    GAN.noise_dim, X_source_batch, noise_scale=GAN.noise_scale)
            if GAN.disc_type == "simple_disc":
                current_labels_real = np.ones(GAN.batch_size)
                current_labels_gen = np.zeros(GAN.batch_size)
            if GAN.disc_type == ("nclass_disc"):
                current_labels_real = np.ones(GAN.batch_size) 
                current_labels_gen = np.zeros(GAN.batch_size) 
                #class_p = GAN.discriminator2.predict_on_batch(X_disc_gen)
                #idx = np.argmax(class_p, axis=1)
                #virtual_labels = (idx[:, None]) == np.arange(GAN.n_classes) * 1
            ##############
            # Train the disc on gen-buffered samples and on current real samples
            ##############
            disc_loss_real = GAN.discriminator_model.train_on_batch(X_dest_batch, current_labels_real)
            GAN.img_buffer.add_to_buffer(gen_output, current_labels_gen, GAN.batch_size)
            bufferImages, bufferLabels = GAN.img_buffer.get_from_buffer(GAN.batch_size)
            disc_loss_gen = GAN.discriminator_model.train_on_batch(bufferImages, bufferLabels)
            disc2_loss = GAN.discriminator2.train_on_batch(X_dest_batch, Y_dest_batch * 1.0)  # label smoothing via GAN.lsmooth currently disabled; trains the discriminator/classifier model
            disc2_entropyloss = 0.0  # GAN.discriminator2.train_on_batch(X_disc_gen, virtual_labels * 0.7)  # would train on dest samples, since this is the BtoA direction

            l_disc_real.appendleft(disc_loss_real)
            l_disc_gen.appendleft(disc_loss_gen)
        #Train the GENERATOR, it is the same on both AtoB and BtoA:
        X_noise = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size, GAN.noise_dim)
        if GAN.disc_type == "simple_disc":                
            gen_loss =  GAN.DCGAN_model.train_on_batch([X_noise,X_source_batch], np.ones(GAN.batch_size)) #TRYING SAME BATCH OF DISC
        elif GAN.disc_type == "nclass_disc":
            gen_loss =  GAN.DCGAN_model.train_on_batch([X_noise,X_source_batch], np.ones(GAN.batch_size)) #TRYING SAME BATCH OF DISC
            #gen_loss =  GAN.DCGAN_model.train_on_batch([X_noise,X_source_batch], [np.ones(GAN.batch_size),Y_virtual_labels])
            #gen_loss = gen_loss[0]
        l_gen.appendleft(gen_loss)

    elif GAN.dir == 'AtoB':
        for disc_it in range(disc_iters):
            A_data_batch, A_labels_batch, B_data_batch, B_labels_batch = get_batch(A_data, A_labels, B_data, B_labels, GAN.batch_size)
            X_source_batch = A_data_batch
            Y_source_batch = A_labels_batch
            X_dest_batch = B_data_batch
            #Y_dest_batch = B_labels_batch
            X_noise = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size, GAN.noise_dim)
            gen_output = GAN.generator_model.predict([X_noise,X_source_batch])
            #X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_dest_batch, GAN.generator_model, batch_counter, GAN.batch_size,
            #                                                    GAN.noise_dim, X_source_batch, noise_scale=GAN.noise_scale)
            if GAN.disc_type == "simple_disc":
                current_labels_real = np.ones(GAN.batch_size)
                current_labels_gen = np.zeros(GAN.batch_size)
            if GAN.disc_type == ("nclass_disc"):
                current_labels_real = np.ones(GAN.batch_size) 
                current_labels_gen = np.zeros(GAN.batch_size) 
            ##############
            #Train the disc on gen-buffered samples and on current real samples
            ##############
            disc_loss_real = GAN.discriminator_model.train_on_batch(X_dest_batch, current_labels_real)
            GAN.img_buffer.add_to_buffer(gen_output,current_labels_gen, GAN.batch_size)
            bufferImages, bufferLabels = GAN.img_buffer.get_from_buffer(GAN.batch_size)
            disc_loss_gen = GAN.discriminator_model.train_on_batch(bufferImages, bufferLabels)
            #code.interact(local=locals())
            disc2_loss = GAN.discriminator2.train_on_batch(gen_output, Y_source_batch * 1.0)  # label smoothing via GAN.lsmooth currently disabled; trains the discriminator/classifier model on generated images

            l_disc_real.appendleft(disc_loss_real)
            l_disc_gen.appendleft(disc_loss_gen)

        #Train the GENERATOR, it is the same on both AtoB and BtoA:
        X_noise = data_utils.sample_noise(GAN.noise_scale, GAN.batch_size, GAN.noise_dim)
        if GAN.disc_type == "simple_disc":                
            gen_loss =  GAN.DCGAN_model.train_on_batch([X_noise,X_source_batch], np.ones(GAN.batch_size)) #TRYING SAME BATCH OF DISC
        elif GAN.disc_type == "nclass_disc":
            gen_loss =  GAN.DCGAN_model.train_on_batch([X_noise,X_source_batch], np.ones(GAN.batch_size)) #TRYING SAME BATCH OF DISC
            #gen_loss =  GAN.DCGAN_model.train_on_batch([X_noise,X_source_batch], [np.ones(GAN.batch_size),Y_virtual_labels])
            #gen_loss = gen_loss[0]
        l_gen.appendleft(gen_loss)
    return A_data_batch, A_labels_batch, B_data_batch, B_labels_batch
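
# get_batch is not defined in this snippet; from its use above it appears to draw an
# independent random batch from each domain. A hypothetical sketch matching that call
# signature (an assumption, not the original helper):
import numpy as np

def get_batch(A_data, A_labels, B_data, B_labels, batch_size):
    # The two domains are unpaired, so each side is sampled independently.
    idx_a = np.random.randint(0, A_data.shape[0], size=batch_size)
    idx_b = np.random.randint(0, B_data.shape[0], size=batch_size)
    return A_data[idx_a], A_labels[idx_a], B_data[idx_b], B_labels[idx_b]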
def train_toy(**kwargs):
    """
    Train model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("toy_MLP")

    # Load and rescale data
    X_real_train = data_utils.load_toy()

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    generator_model = models.generator_toy(noise_dim)
    discriminator_model = models.discriminator_toy()
    GAN_model = models.GAN_toy(generator_model, discriminator_model, noise_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    GAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    #################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                    l.set_weights(weights)

                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    noise_scale=noise_scale)

                # Update the discriminator
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = GAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            batch_counter += 1
            progbar.add(batch_size, values=[("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)),
                                            ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                            ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                            ("Loss_G", -gen_loss)])

            # # Save images for visualization
            if gen_iterations % 50 == 0:
                data_utils.plot_generated_toy_batch(X_real_train, generator_model,
                                                    discriminator_model, noise_dim, gen_iterations)
            gen_iterations += 1

        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
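
# models.wasserstein is referenced but not shown in this listing; the usual Keras critic
# loss for WGAN (an assumption about its contents, not the original source) is the mean
# of label times prediction, so the -1/+1 targets above produce the two sides of the
# Wasserstein estimate:
from keras import backend as K

def wasserstein(y_true, y_pred):
    return K.mean(y_true * y_pred)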
Beispiel #19
0
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_data_format = kwargs["image_data_format"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    model_name = kwargs["model_name"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and normalize data
    X_real_train, X_batch_gen = data_utils.load_image_dataset(
        dset, img_dim, image_data_format, batch_size)

    # Get the full real image dimension
    img_dim = X_real_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    if generator == "upsampling":
        generator_model = models.generator_upsampling(noise_dim,
                                                      img_dim,
                                                      dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim,
                                                  img_dim,
                                                  batch_size,
                                                  dset=dset)
    discriminator_model = models.discriminator(img_dim)
    DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim,
                               img_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    disc_losses = []
    disc_losses_real = []
    disc_losses_gen = []
    gen_losses = []

    #################
    # Start training
    ################

    try:

        for e in range(nb_epoch):

            print('--------------------------------------------')
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d}\n'.format(
                datetime.datetime.now(), e + 1, nb_epoch))

            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            start = time.time()

            disc_loss_batch = 0
            disc_loss_real_batch = 0
            disc_loss_gen_batch = 0
            gen_loss_batch = 0

            for batch_counter in range(n_batch_per_epoch):

                if gen_iterations < 25 or gen_iterations % 500 == 0:
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

                ###################################
                # 1) Train the critic / discriminator
                ###################################
                list_disc_loss_real = []
                list_disc_loss_gen = []
                for disc_it in range(disc_iterations):

                    # Clip discriminator weights
                    for l in discriminator_model.layers:
                        weights = l.get_weights()
                        weights = [
                            np.clip(w, clamp_lower, clamp_upper)
                            for w in weights
                        ]
                        l.set_weights(weights)

                    X_real_batch = next(
                        data_utils.gen_batch(X_real_train, X_batch_gen,
                                             batch_size))

                    # Create a batch to feed the discriminator model
                    X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                        X_real_batch,
                        generator_model,
                        batch_counter,
                        batch_size,
                        noise_dim,
                        noise_scale=noise_scale)

                    # Update the discriminator
                    disc_loss_real = discriminator_model.train_on_batch(
                        X_disc_real, -np.ones(X_disc_real.shape[0]))
                    disc_loss_gen = discriminator_model.train_on_batch(
                        X_disc_gen, np.ones(X_disc_gen.shape[0]))
                    list_disc_loss_real.append(disc_loss_real)
                    list_disc_loss_gen.append(disc_loss_gen)

                #######################
                # 2) Train the generator
                #######################
                X_gen = data_utils.sample_noise(noise_scale, batch_size,
                                                noise_dim)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      -np.ones(X_gen.shape[0]))
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                gen_iterations += 1

                disc_loss_batch += -np.mean(list_disc_loss_real) - np.mean(
                    list_disc_loss_gen)
                disc_loss_real_batch += -np.mean(list_disc_loss_real)
                disc_loss_gen_batch += np.mean(list_disc_loss_gen)
                gen_loss_batch += -gen_loss

                progbar.add(batch_size,
                            values=[
                                ("Loss_D", -np.mean(list_disc_loss_real) -
                                 np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)
                            ])

                # # Save images for visualization ~2 times per epoch
                # if batch_counter % (n_batch_per_epoch / 2) == 0:
                #     data_utils.plot_generated_batch(X_real_batch, generator_model,
                #                                     batch_size, noise_dim, image_data_format)

            disc_losses.append(disc_loss_batch / n_batch_per_epoch)
            disc_losses_real.append(disc_loss_real_batch / n_batch_per_epoch)
            disc_losses_gen.append(disc_loss_gen_batch / n_batch_per_epoch)
            gen_losses.append(gen_loss_batch / n_batch_per_epoch)

            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                e, batch_size, noise_dim,
                                                image_data_format, model_name)
                data_utils.plot_losses(disc_losses, disc_losses_real,
                                       disc_losses_gen, gen_losses, model_name)

            # Save model weights (by default, every 5 epochs)
            data_utils.save_model_weights(generator_model, discriminator_model,
                                          DCGAN_model, e,
                                          save_weights_every_n_epochs,
                                          save_only_last_n_weights, model_name)

            end = time.time()
            print('\nEpoch %s/%s END, Time: %s' %
                  (e + 1, nb_epoch, end - start))
            start = end

    except KeyboardInterrupt:
        pass

    gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
    print("Saving", gen_weights_path)
    generator_model.save(gen_weights_path, overwrite=True)
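
# The call above saves the full generator (architecture plus weights), so it can be
# restored later without rebuilding the graph. A minimal usage sketch, assuming the
# same model_name used during training:
from keras.models import load_model

def load_latest_generator(model_name):
    # Path convention mirrors the save call above.
    return load_model('../../models/%s/generator_latest.h5' % model_name)

# Example:
#   gen = load_latest_generator(model_name)
#   X_fake = gen.predict(data_utils.sample_noise(noise_scale, batch_size, noise_dim))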