def train_toy(**kwargs):
    """
    Train model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("toy_MLP")

    # Load and rescale data
    X_real_train = data_utils.load_toy()

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    generator_model = models.generator_toy(noise_dim)
    discriminator_model = models.discriminator_toy()
    GAN_model = models.GAN_toy(generator_model, discriminator_model, noise_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
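    # Keras captures each layer's `trainable` flag at compile time: freezing the
    # discriminator before compiling the stacked GAN model means the GAN update
    # only adjusts the generator, while the separately compiled discriminator
    # remains trainable on its own.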
    discriminator_model.trainable = False
    GAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    #################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
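                # (WGAN Lipschitz constraint: clamp every critic weight to [clamp_lower, clamp_upper])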
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                    l.set_weights(weights)

                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    noise_scale=noise_scale)

                # Update the discriminator
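                # Targets of -1 (real) and +1 (generated) follow the usual Keras WGAN
                # convention where the critic loss is mean(y_true * y_pred).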
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = GAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            batch_counter += 1
            progbar.add(batch_size, values=[("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)),
                                            ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                            ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                            ("Loss_G", -gen_loss)])

            # Save images for visualization
            if gen_iterations % 50 == 0:
                data_utils.plot_generated_toy_batch(X_real_train, generator_model,
                                                    discriminator_model, noise_dim, gen_iterations)
            gen_iterations += 1

        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
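
The Wasserstein loss referenced as models.wasserstein is not shown in these examples. A minimal sketch of how it is usually written for Keras, consistent with the -1/+1 targets used above (the source repository's implementation may differ slightly):

import keras.backend as K

def wasserstein(y_true, y_pred):
    # With targets of -1 for real and +1 for generated samples, minimizing this
    # mean pushes the critic score up on real data and down on generated data.
    return K.mean(y_true * y_pred)
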
Example #2
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    # if dset == "celebA":
    #     X_real_train = data_utils.load_celebA(img_dim, image_dim_ordering)
    # if dset == "mnist":
    #     X_real_train, _, _, _ = data_utils.load_mnist(image_dim_ordering)
    # img_dim = X_real_train.shape[-3:]
    img_dim = (3, 64, 64)
    noise_dim = (100, )

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      noise_dim,
                                      img_dim,
                                      bn_mode,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          noise_dim,
                                          img_dim,
                                          bn_mode,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)
        # Load the generator weights saved every 5 epochs (epochs 200 to 350) and
        # generate a batch of images from each checkpoint
        for e in range(200, 355, 5):
            gen_weights_path = os.path.join(
                '../../CNN/gen_weights_epoch%s.h5' % (e))
            # gen_weight_file  = h5py.File(gen_weights_path, 'r')
            file_path_to_save_img = 'GeneratedImages1234/Epoch_%s/' % (e)
            os.mkdir(file_path_to_save_img)
            # generate images
            generator_model.load_weights(gen_weights_path)
            generator_model.compile(loss='mse', optimizer=opt_discriminator)
            noise_z = np.random.normal(scale=0.5, size=(32, noise_dim[0]))
            X_generated = generator_model.predict(noise_z)
            # print('Epoch%s.png' % (i))
            X_gen = inverse_normalization(X_generated)
            for img in range(X_gen.shape[0]):
                ret = X_gen[img].transpose(1, 2, 0)
                fig = plt.figure(frameon=False)
                fig.set_size_inches(64, 64)
                ax = plt.Axes(fig, [0., 0., 1., 1.])
                ax.set_axis_off()
                fig.add_axes(ax)
                ax.imshow(ret, aspect='auto')
                fig.savefig(file_path_to_save_img + 'retina_%s.png' % (img),
                            dpi=1)
                plt.clf()
                plt.close()

            # Xg = X_gen[:8]
            # Xr = X_gen[8:]
            #
            # if image_dim_ordering == "tf":
            #     X = np.concatenate((Xg, Xr), axis=0)
            #     list_rows = []
            #     for i in range(int(X.shape[0] / 4)):
            #         Xr = np.concatenate([X[k] for k in range(4 * i, 4 * (i + 1))], axis=1)
            #         list_rows.append(Xr)
            #
            #     Xr = np.concatenate(list_rows, axis=0)
            #
            # if image_dim_ordering == "th":
            #     X = np.concatenate((Xg, Xr), axis=0)
            #     list_rows = []
            #     for i in range(int(X.shape[0] / 4)):
            #         Xr = np.concatenate([X[k] for k in range(4 * i, 4 * (i + 1))], axis=2)
            #         list_rows.append(Xr)
            #
            #     Xr = np.concatenate(list_rows, axis=1)
            #     Xr = Xr.transpose(1,2,0)
            #
            # if Xr.shape[-1] == 1:
            #     plt.imshow(Xr[:, :, 0], cmap="gray")
            # else:
            #     plt.imshow(Xr)
            # plt.savefig(file_path_to_save_img+'Epoch%s.png' % (e))
            # plt.clf()
            # plt.close()

        # generator_model.load_weights('gen_weights_epoch245.h5')
        # generator_model.compile(loss='mse', optimizer=opt_discriminator)
        # discriminator_model.trainable = False
        #
        # DCGAN_model = models.DCGAN(generator_model,
        #                            discriminator_model,
        #                            noise_dim,
        #                            img_dim)
        #
        # loss = ['binary_crossentropy']
        # loss_weights = [1]
        # DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)
        #
        # discriminator_model.trainable = True
        # discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        # noise_z = np.random.normal(scale=0.5, size=(32, noise_dim[0]))
        # X_generated = generator_model.predict(noise_z)
        #
        # X_gen = inverse_normalization(X_generated)
        #
        # Xg = X_gen[:8]
        # Xr = X_gen[8:]
        #
        # if image_dim_ordering == "tf":
        #     X = np.concatenate((Xg, Xr), axis=0)
        #     list_rows = []
        #     for i in range(int(X.shape[0] / 4)):
        #         Xr = np.concatenate([X[k] for k in range(4 * i, 4 * (i + 1))], axis=1)
        #         list_rows.append(Xr)
        #
        #     Xr = np.concatenate(list_rows, axis=0)
        #
        # if image_dim_ordering == "th":
        #     X = np.concatenate((Xg, Xr), axis=0)
        #     list_rows = []
        #     for i in range(int(X.shape[0] / 4)):
        #         Xr = np.concatenate([X[k] for k in range(4 * i, 4 * (i + 1))], axis=2)
        #         list_rows.append(Xr)
        #
        #     Xr = np.concatenate(list_rows, axis=1)
        #     Xr = Xr.transpose(1,2,0)
        #
        # if Xr.shape[-1] == 1:
        #     plt.imshow(Xr[:, :, 0], cmap="gray")
        # else:
        #     plt.imshow(Xr)
        # plt.savefig("current_batch.png")
        # plt.clf()
        # plt.close()

        # gen_loss = 100
        # disc_loss = 100
        #
        # # Start training
        # print("Start training")
        # k = 0
        # for e in range(nb_epoch):
        #     # Initialize progbar and batch counter
        #     progbar = generic_utils.Progbar(epoch_size)
        #     batch_counter = 1
        #     start = time.time()
        #
        #     for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):
        #
        #         # Create a batch to feed the discriminator model
        #         X_disc, y_disc = data_utils.get_disc_batch(X_real_batch,
        #                                                    generator_model,
        #                                                    batch_counter,
        #                                                    batch_size,
        #                                                    noise_dim,
        #                                                    noise_scale=noise_scale,
        #                                                    label_smoothing=label_smoothing,
        #                                                    label_flipping=label_flipping)
        #
        #         # Update the discriminator
        #         disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
        #
        #         # Create a batch to feed the generator model
        #         X_gen, y_gen = data_utils.get_gen_batch(batch_size, noise_dim, noise_scale=noise_scale)
        #
        #         # Freeze the discriminator
        #         discriminator_model.trainable = False
        #         gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
        #         # Unfreeze the discriminator
        #         discriminator_model.trainable = True
        #
        #         batch_counter += 1
        #         progbar.add(batch_size, values=[("D logloss", disc_loss),
        #                                         ("G logloss", gen_loss)])
        #
        #         # Save images for visualization
        #         if batch_counter % 100 == 0:
        #             data_utils.plot_generated_batch(X_real_batch, generator_model,
        #                                             batch_size, noise_dim, image_dim_ordering,k)
        #             k = k +1
        #         if batch_counter >= n_batch_per_epoch:
        #             break
        #
        #     print("")
        #     print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
        #
        #     if e % 5 == 0:
        #         gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
        #         generator_model.save_weights(gen_weights_path, overwrite=True)
        #
        #         disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
        #         discriminator_model.save_weights(disc_weights_path, overwrite=True)
        #
        #         DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
        #         DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
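
inverse_normalization is called in the snippet above but not defined here. For a tanh-output generator trained on data scaled to [-1, 1], it typically maps images back to [0, 1] for plotting; a minimal sketch under that assumption:

import numpy as np

def inverse_normalization(X):
    # Map generator outputs from [-1, 1] back to [0, 1] for display
    return (X + 1.0) / 2.0
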
Example #3
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(img_dim, image_data_format)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    img_dim = X_real_train.shape[-3:]
    noise_dim = (100, )

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      noise_dim,
                                      img_dim,
                                      bn_mode,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          noise_dim,
                                          img_dim,
                                          bn_mode,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   noise_dim, img_dim)

        loss = ['binary_crossentropy']
        loss_weights = [1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    noise_scale=noise_scale,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen, y_gen = data_utils.get_gen_batch(
                    batch_size, noise_dim, noise_scale=noise_scale)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size,
                            values=[("D logloss", disc_loss),
                                    ("G logloss", gen_loss)])

                # Save images for visualization
                if batch_counter % 100 == 0:
                    data_utils.plot_generated_batch(X_real_batch,
                                                    generator_model,
                                                    batch_size, noise_dim,
                                                    image_data_format)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%s.h5' %
                    (model_name, e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%s.h5' %
                    (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
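
data_utils.get_disc_batch is not shown here; label_smoothing and label_flipping are standard GAN stabilization tricks. A sketch of how such a helper is commonly written, with names and exact behavior assumed rather than taken from the source repository:

import numpy as np

def get_disc_batch(X_real_batch, generator_model, batch_counter, batch_size,
                   noise_dim, noise_scale=0.5, label_smoothing=False,
                   label_flipping=0.0):
    # Alternate between a fully generated batch and a fully real batch
    if batch_counter % 2 == 0:
        noise = np.random.normal(scale=noise_scale,
                                 size=(batch_size, noise_dim[0]))
        X_disc = generator_model.predict(noise)
        y_disc = np.zeros(X_disc.shape[0])  # target 0 for generated samples
    else:
        X_disc = X_real_batch
        if label_smoothing:
            # Soft targets in [0.9, 1.0] instead of hard 1s for real samples
            y_disc = np.random.uniform(0.9, 1.0, X_disc.shape[0])
        else:
            y_disc = np.ones(X_disc.shape[0])  # target 1 for real samples
    # With probability label_flipping, flip the targets of this whole batch
    if label_flipping > 0 and np.random.binomial(1, label_flipping):
        y_disc = 1.0 - y_disc
    return X_disc, y_disc
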
Example #4
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    disc_type = kwargs["disc_type"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    if dset == "mnistM":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnist')
        #        X_source_train=np.concatenate([X_source_train,X_source_train,X_source_train], axis=1)
        #        X_source_test=np.concatenate([X_source_test,X_source_test,X_source_test], axis=1)
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnistM')
    elif dset == "OfficeDslrToAmazon":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeDslr')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeAmazon')
    else:
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1

    img_source_dim = X_source_train.shape[-3:]  # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    # Create optimizers
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_GC = data_utils.get_optimizer('Adam', lr_G / 10.0)
    opt_C = data_utils.get_optimizer('Adam', lr_D)
    opt_Z = data_utils.get_optimizer('Adam', lr_G)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    generator_model = models.generator_google_mnistM(noise_dim, img_source_dim,
                                                     img_dest_dim,
                                                     deterministic, pureGAN,
                                                     wd)
    #    discriminator_model = models.discriminator_google_mnistM(img_dest_dim, wd)
    discriminator_model = models.discriminator_dcgan(img_dest_dim, wd,
                                                     n_classes, disc_type)
    classificator_model = models.classificator_google_mnistM(
        img_dest_dim, n_classes, wd)
    DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model,
                                     noise_dim, img_source_dim)
    zclass_model = z_coerence(generator_model,
                              img_source_dim,
                              bn_mode,
                              wd,
                              inject_noise,
                              n_classes,
                              noise_dim,
                              model_name="zClass")
    #    GenToClassifier_model = models.GenToClassifierModel(generator_model, classificator_model, noise_dim, img_source_dim)
    #disc_penalty_model = models.disc_penalty(discriminator_model,noise_dim,img_source_dim,opt_D,model_name="disc_penalty_model")

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)

    models.make_trainable(discriminator_model, False)
    models.make_trainable(classificator_model, False)
    #    models.make_trainable(disc_penalty_model, False)
    if model == 'wgan':
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
        models.make_trainable(discriminator_model, True)
        #     models.make_trainable(disc_penalty_model, True)
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
    if model == 'lsgan':
        if disc_type == "simple_disc":
            DCGAN_model.compile(loss=['mse'], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['mse'], optimizer=opt_D)
        elif disc_type == "nclass_disc":
            DCGAN_model.compile(loss=['mse', 'categorical_crossentropy'],
                                loss_weights=[1.0, 0.1],
                                optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(
                loss=['mse', 'categorical_crossentropy'],
                loss_weights=[1.0, 0.1],
                optimizer=opt_D)
#    GenToClassifier_model.compile(loss='categorical_crossentropy', optimizer=opt_GC)
    models.make_trainable(classificator_model, True)
    classificator_model.compile(loss='categorical_crossentropy',
                                metrics=['accuracy'],
                                optimizer=opt_C)
    zclass_model.compile(loss=['mse'], optimizer=opt_Z)

    visualize = True
    ########
    #MAKING TRAIN+TEST numpy array for global testing:
    ########
    Xtarget_dataset = np.concatenate([X_dest_train, X_dest_test], axis=0)
    Ytarget_dataset = np.concatenate([Y_dest_train, Y_dest_test], axis=0)

    if resume:  ########loading previous saved model weights and checking actual performance
        data_utils.load_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, name, classificator_model,
                                      zclass_model)
#        data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name,classificator_model)
#        loss4, acc4 = classificator_model.evaluate(Xtarget_dataset, Ytarget_dataset,batch_size=1024, verbose=0)
#        print('\n Classifier Accuracy on full target domain:  %.2f%%' % (100 * acc4))

    else:
        # Train the zclass regression model only when not resuming
        X_gen = data_utils.sample_noise(noise_scale, X_source_train.shape[0],
                                        noise_dim)
        zclass_loss = zclass_model.fit([X_gen, X_source_train], [X_gen],
                                       batch_size=256,
                                       epochs=10)

    gen_iterations = 0
    max_history_size = int(history_size * batch_size)
    img_buffer = ImageHistoryBuffer((0, ) + img_source_dim, max_history_size,
                                    batch_size, n_classes)
    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:
            if no_supertrain is None:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                elif gen_iterations % 500 == 0:
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
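            # Rolling windows over the 10 most recent values of each loss, used for smoother progress reporting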
            list_disc_loss_real = deque(10 * [0], 10)
            list_disc_loss_gen = deque(10 * [0], 10)
            list_gen_loss = deque(10 * [0], 10)
            list_zclass_loss = deque(10 * [0], 10)
            list_classifier_loss = deque(10 * [0], 10)
            list_gp_loss = deque(10 * [0], 10)
            for disc_it in range(disc_iterations):
                X_dest_batch, Y_dest_batch, idx_dest_batch = next(
                    data_utils.gen_batch(X_dest_train, Y_dest_train,
                                         batch_size))
                X_source_batch, Y_source_batch, idx_source_batch = next(
                    data_utils.gen_batch(X_source_train, Y_source_train,
                                         batch_size))
                ##########
                # Create a batch to feed the discriminator model
                #########
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_dest_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    X_source_batch,
                    noise_scale=noise_scale)

                # Update the discriminator
                if model == 'wgan':
                    current_labels_real = -np.ones(X_disc_real.shape[0])
                    current_labels_gen = np.ones(X_disc_gen.shape[0])
                elif model == 'lsgan':
                    if disc_type == "simple_disc":
                        current_labels_real = np.ones(X_disc_real.shape[0])
                        current_labels_gen = np.zeros(X_disc_gen.shape[0])
                    elif disc_type == "nclass_disc":
                        virtual_real_labels = np.zeros(
                            [X_disc_gen.shape[0], n_classes])
                        current_labels_real = [
                            np.ones(X_disc_real.shape[0]), virtual_real_labels
                        ]
                        current_labels_gen = [
                            np.zeros(X_disc_gen.shape[0]), Y_source_batch
                        ]
                ##############
                #Train the disc on gen-buffered samples and on current real samples
                ##############
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, current_labels_real)
                img_buffer.add_to_buffer(X_disc_gen, current_labels_gen,
                                         batch_size)
                bufferImages, bufferLabels = img_buffer.get_from_buffer(
                    batch_size)
                disc_loss_gen = discriminator_model.train_on_batch(
                    bufferImages, bufferLabels)

                #if not isinstance(disc_loss_real, collections.Iterable): disc_loss_real = [disc_loss_real]
                #if not isinstance(disc_loss_real, collections.Iterable): disc_loss_gen = [disc_loss_gen]
                if disc_type == "simple_disc":
                    list_disc_loss_real.appendleft(disc_loss_real)
                    list_disc_loss_gen.appendleft(disc_loss_gen)
                elif disc_type == "nclass_disc":
                    list_disc_loss_real.appendleft(disc_loss_real[0])
                    list_disc_loss_gen.appendleft(disc_loss_gen[0])
                #############
                ####Train the discriminator w.r.t gradient penalty
                #############
                #gp_loss = disc_penalty_model.train_on_batch([X_disc_real,X_disc_gen],current_labels_real) #dummy labels,not used in the loss function
                #list_gp_loss.appendleft(gp_loss)

            ################
            ### Classifier training outside the critic loop (train it once per generator step, even when disc_iterations > 1)
            #################
            class_loss_gen = classificator_model.train_on_batch(
                X_disc_gen, Y_source_batch * 0.7)  # label smoothing
            list_classifier_loss.appendleft(class_loss_gen[1])
            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            X_source_batch2, Y_source_batch2, idx_source_batch2 = next(
                data_utils.gen_batch(X_source_train, Y_source_train,
                                     batch_size))
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen, X_source_batch2],
                                                      -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                if disc_type == "simple_disc":
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        np.ones(X_gen.shape[0]))  #TRYING SAME BATCH OF DISC
                elif disc_type == "nclass_disc":
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        [np.ones(X_gen.shape[0]), Y_source_batch2])
                    gen_loss = gen_loss[0]
            list_gen_loss.appendleft(gen_loss)

            zclass_loss = zclass_model.train_on_batch([X_gen, X_source_batch2],
                                                      [X_gen])
            list_zclass_loss.appendleft(zclass_loss)
            ##############
            #Train the generator w.r.t the aux classifier:
            #############
            #            GenToClassifier_model.train_on_batch([X_gen,X_source_batch2],Y_source_batch2)

            # TODO: try classifying with the discriminator as well, using one class for real samples and N classes for fake ones

            gen_iterations += 1
            batch_counter += 1

            progbar.add(batch_size,
                        values=[("Loss_D_real", np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", np.mean(list_gen_loss)),
                                ("Loss_Z", np.mean(list_zclass_loss)),
                                ("Loss_Classifier",
                                 np.mean(list_classifier_loss))])

            # Plot images once per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
                X_source_batch_plot, Y_source_batch_plot, idx_source_plot = next(
                    data_utils.gen_batch(X_source_test,
                                         Y_source_test,
                                         batch_size=32))
                data_utils.plot_generated_batch(X_dest_test,
                                                X_source_test,
                                                generator_model,
                                                noise_dim,
                                                image_dim_ordering,
                                                idx_source_plot,
                                                batch_size=32)
            if gen_iterations % (n_batch_per_epoch * 5) == 0:
                if visualize:
                    BIG_ASS_VISUALIZATION_slerp(X_source_train[1],
                                                generator_model, noise_dim)
        #    if (e % 20) == 0:
        #        lr_decay([discriminator_model,DCGAN_model,classificator_model],decay_value=0.95)

        print("Dest labels:")
        print(Y_dest_test[idx_source_plot].argmax(1))
        print("Source labels:")
        print(Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' %
              (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, e, name,
                                      classificator_model, zclass_model)

        # Evaluate the trained classifier on the full target domain
        loss4, acc4 = classificator_model.evaluate(Xtarget_dataset,
                                                   Ytarget_dataset,
                                                   batch_size=1024,
                                                   verbose=0)
        print(
            '\n Classifier Accuracy and loss on full target domain:  %.2f%% / %.5f'
            % ((100 * acc4), loss4))
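
The ImageHistoryBuffer used in the last two examples keeps a pool of previously generated samples so the critic also trains on older fakes (the history-buffer trick from SimGAN); its implementation is not shown. A minimal sketch under the assumption of plain array labels, not taken from the source repository:

import numpy as np

class ImageHistoryBuffer(object):
    """Pool of previously generated images (and their targets) for critic training."""

    def __init__(self, shape, max_size, batch_size, n_classes):
        # shape is (0,) + img_dim, i.e. the pool starts empty
        self.images = np.zeros(shape)
        self.labels = np.zeros((0,))
        self.max_size = max_size
        self.batch_size = batch_size

    def add_to_buffer(self, images, labels, nb_to_add):
        self.images = np.concatenate([self.images, images[:nb_to_add]], axis=0)
        self.labels = np.concatenate([self.labels, labels[:nb_to_add]], axis=0)
        if self.images.shape[0] > self.max_size:
            # Drop the oldest samples once the pool is full
            self.images = self.images[-self.max_size:]
            self.labels = self.labels[-self.max_size:]

    def get_from_buffer(self, nb_to_get):
        # Sample a random batch from the pool
        idx = np.random.randint(0, self.images.shape[0], size=nb_to_get)
        return self.images[idx], self.labels[idx]
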
Example #5
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    simple_disc = kwargs["simple_disc"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    if dset == "mnistM":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnist')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnistM')
        #code.interact(local=locals())
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim,
                                                       image_dim_ordering,
                                                       dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim,
                                                     image_dim_ordering,
                                                     dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim,
                                                     image_dim_ordering,
                                                     dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train, Y_source_train, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='washington12classes')
        X_dest_train, Y_dest_train, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train, Y_source_train, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Wash_12class_LMDB')
    elif dset == "OfficeDslrToAmazon":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeDslr')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeAmazon')
    elif dset == "bedrooms":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='bedrooms_small')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='bedrooms')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Vand_12class_LMDB_Background')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Vand_12class_LMDB')
    else:
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1
    img_source_dim = X_source_train.shape[-3:]  # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    # Create optimizers
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    if generator == "upsampling":
        generator_model = models.generator_upsampling_mnistM(noise_dim,
                                                             img_source_dim,
                                                             img_dest_dim,
                                                             bn_mode,
                                                             deterministic,
                                                             pureGAN,
                                                             inject_noise,
                                                             wd,
                                                             dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim,
                                                  img_dest_dim,
                                                  bn_mode,
                                                  batch_size,
                                                  dset=dset)

    if simple_disc:
        discriminator_model = models.discriminator_naive(
            img_dest_dim, bn_mode, model, wd, inject_noise, n_classes, use_mbd)
        DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model,
                                         noise_dim, img_source_dim)
    elif discriminator == "disc_resnet":
        discriminator_model = models.discriminatorResNet(
            img_dest_dim, bn_mode, model, wd, monsterClass, inject_noise,
            n_classes, use_mbd)
        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   noise_dim, img_source_dim, img_dest_dim,
                                   monsterClass)
    else:
        discriminator_model = models.disc1(img_dest_dim, bn_mode, model, wd,
                                           monsterClass, inject_noise,
                                           n_classes, use_mbd)
        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   noise_dim, img_source_dim, img_dest_dim,
                                   monsterClass)

    ####special options for bedrooms dataset:
    if dset == "bedrooms":
        generator_model = models.generator_dcgan(noise_dim, img_source_dim,
                                                 img_dest_dim, bn_mode,
                                                 deterministic, pureGAN,
                                                 inject_noise, wd)
        discriminator_model = models.discriminator_naive(
            img_dest_dim,
            bn_mode,
            model,
            wd,
            inject_noise,
            n_classes,
            use_mbd,
            model_name="discriminator_naive")
        DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model,
                                         noise_dim, img_source_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)

    models.make_trainable(discriminator_model, False)
    #discriminator_model.trainable = False
    if model == 'wgan':
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
        models.make_trainable(discriminator_model, True)
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
    if model == 'lsgan':
        if simple_disc:
            DCGAN_model.compile(loss=['mse'], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['mse'], optimizer=opt_D)
        elif monsterClass:
            DCGAN_model.compile(loss=['categorical_crossentropy'],
                                optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['categorical_crossentropy'],
                                        optimizer=opt_D)
        else:
            DCGAN_model.compile(loss=['mse', 'categorical_crossentropy'],
                                loss_weights=[1.0, 1.0],
                                optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(
                loss=['mse', 'categorical_crossentropy'],
                loss_weights=[1.0, 1.0],
                optimizer=opt_D)

    visualize = True

    if resume:  ########loading previous saved model weights
        data_utils.load_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, name)

    #####################
    ###classifier
    #####################
    if not ((dset == 'mnistM') or (dset == 'bedrooms')):
        classifier, GenToClassifierModel = classifier_build_test(
            img_dest_dim,
            n_classes,
            generator_model,
            noise_dim,
            noise_scale,
            img_source_dim,
            opt_C,
            X_source_test,
            Y_source_test,
            X_dest_test,
            Y_dest_test,
            wd=0.0001)

    gen_iterations = 0
    max_history_size = int(history_size * batch_size)
    img_buffer = ImageHistoryBuffer((0, ) + img_source_dim, max_history_size,
                                    batch_size, n_classes)
    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:
            if no_supertrain is None:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                elif gen_iterations % 500 == 0:
                    disc_iterations = 10
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            list_gen_loss = []

            for disc_it in range(disc_iterations):

                # Clip discriminator weights
                #for l in discriminator_model.layers:
                #    weights = l.get_weights()
                #    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                #    l.set_weights(weights)

                X_dest_batch, Y_dest_batch, idx_dest_batch = next(
                    data_utils.gen_batch(X_dest_train, Y_dest_train,
                                         batch_size))
                X_source_batch, Y_source_batch, idx_source_batch = next(
                    data_utils.gen_batch(X_source_train, Y_source_train,
                                         batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_dest_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    X_source_batch,
                    noise_scale=noise_scale)
                if model == 'wgan':
                    # Update the discriminator
                    current_labels_real = -np.ones(X_disc_real.shape[0])
                    current_labels_gen = np.ones(X_disc_gen.shape[0])
                if model == 'lsgan':
                    if simple_disc:  # simple binary targets: 1 for real, 0 for generated
                        current_labels_real = np.ones(X_disc_real.shape[0])
                        #current_labels_gen = -np.ones(X_disc_gen.shape[0])
                        current_labels_gen = np.zeros(X_disc_gen.shape[0])
                    elif monsterClass:  # real-domain target is [labels, 0...0]; fake-domain target is [0...0, labels]
                        current_labels_real = np.concatenate(
                            (Y_dest_batch,
                             np.zeros((X_disc_real.shape[0], n_classes))),
                            axis=1)
                        current_labels_gen = np.concatenate((np.zeros(
                            (X_disc_real.shape[0], n_classes)),
                                                             Y_source_batch),
                                                            axis=1)
                    else:
                        current_labels_real = [
                            np.ones(X_disc_real.shape[0]), Y_dest_batch
                        ]
                        Y_fake_batch = (1.0 / n_classes) * np.ones(
                            [X_disc_gen.shape[0], n_classes])
                        current_labels_gen = [
                            np.zeros(X_disc_gen.shape[0]), Y_fake_batch
                        ]
                #label smoothing
                #current_labels_real = np.multiply(current_labels_real, lsmooth) #usually lsmooth = 0.7
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, current_labels_real)
                img_buffer.add_to_buffer(X_disc_gen, current_labels_gen,
                                         batch_size)
                bufferImages, bufferLabels = img_buffer.get_from_buffer(
                    batch_size)
                disc_loss_gen = discriminator_model.train_on_batch(
                    bufferImages, bufferLabels)

                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            X_source_batch2, Y_source_batch2, idx_source_batch2 = next(
                data_utils.gen_batch(X_source_train, Y_source_train,
                                     batch_size))
            #            w1 = classifier.get_weights() #FOR DEBUG
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen, X_source_batch2],
                                                      -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                if simple_disc:
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        np.ones(X_gen.shape[0]))  #TRYING SAME BATCH OF DISC?
                elif monsterClass:
                    labels_gen = np.concatenate(
                        (Y_source_batch2,
                         np.zeros((X_disc_real.shape[0], n_classes))),
                        axis=1)
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2], labels_gen)
                else:
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        [np.ones(X_gen.shape[0]), Y_source_batch2])
#            gen_loss2 = GenToClassifierModel.train_on_batch([X_gen,X_source_batch2], Y_source_batch2)

#            w2 = classifier.get_weights() #FOR DEBUG
#           for a,b in zip(w1, w2):
#              if np.all(a == b):
#                 print "no bug in GEN model update"
#            else:
#               print "BUG IN GEN MODEL UPDATE"
            list_gen_loss.append(gen_loss)

            gen_iterations += 1
            batch_counter += 1

            progbar.add(batch_size,
                        values=[("Loss_D", 0.5 * np.mean(list_disc_loss_real) +
                                 0.5 * np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", np.mean(list_gen_loss))])

            # Plot images once per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
                X_source_batch_plot, Y_source_batch_plot, idx_source_plot = next(
                    data_utils.gen_batch(X_source_test,
                                         Y_source_test,
                                         batch_size=32))
                data_utils.plot_generated_batch(X_dest_test,
                                                X_source_test,
                                                generator_model,
                                                noise_dim,
                                                image_dim_ordering,
                                                idx_source_plot,
                                                batch_size=32)
            if gen_iterations % (n_batch_per_epoch * 5) == 0:
                if visualize:
                    BIG_ASS_VISUALIZATION_slerp(X_source_train[1],
                                                generator_model, noise_dim)

        print("Dest labels:")
        print(Y_dest_test[idx_source_plot].argmax(1))
        print("Source labels:")
        print(Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' %
              (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, e, name)
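
data_utils.gen_batch and data_utils.sample_noise are used throughout these examples but not shown. Minimal sketches of what such helpers typically look like for the single-array form used in the first three examples (assumed, not the source implementation):

import numpy as np

def sample_noise(noise_scale, batch_size, noise_dim):
    # Gaussian noise vectors fed to the generator
    return np.random.normal(scale=noise_scale, size=(batch_size, noise_dim[0]))

def gen_batch(X, batch_size):
    # Infinite generator yielding random mini-batches from X
    while True:
        idx = np.random.choice(X.shape[0], batch_size, replace=False)
        yield X[idx]
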
def train_toy(**kwargs):
    """
    Train model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("toy_MLP")

    # Load and rescale data
    X_real_train = data_utils.load_toy()

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    generator_model = models.generator_toy(noise_dim)
    discriminator_model = models.discriminator_toy()
    GAN_model = models.GAN_toy(generator_model, discriminator_model, noise_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    GAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    #################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [
                        np.clip(w, clamp_lower, clamp_upper) for w in weights
                    ]
                    l.set_weights(weights)

                X_real_batch = next(
                    data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    noise_scale=noise_scale)

                # Update the discriminator
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(
                    X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = GAN_model.train_on_batch(X_gen,
                                                -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            batch_counter += 1
            progbar.add(batch_size,
                        values=[("Loss_D", -np.mean(list_disc_loss_real) -
                                 np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)])

            # # Save images for visualization
            if gen_iterations % 50 == 0:
                data_utils.plot_generated_toy_batch(X_real_train,
                                                    generator_model,
                                                    discriminator_model,
                                                    noise_dim, gen_iterations)
            gen_iterations += 1

        print('\nEpoch %s/%s, Time: %s' %
              (e + 1, nb_epoch, time.time() - start))
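
# ---------------------------------------------------------------------------
# Note: train_toy above compiles its models against models.wasserstein, which
# is defined elsewhere in the repository. A minimal sketch of such a critic
# loss, assuming the +/-1 label convention used in the training loop above
# (illustrative only, not necessarily the repository's exact implementation):
import keras.backend as K


def wasserstein_sketch(y_true, y_pred):
    # With real samples labelled -1 and generated samples labelled +1,
    # minimizing mean(y_true * y_pred) drives the critic to maximize
    # D(real) - D(fake), its estimate of the Wasserstein distance.
    return K.mean(y_true * y_pred)
# ---------------------------------------------------------------------------
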
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_dim_ordering)
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_dim_ordering)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim,
                                      nb_patch,
                                      bn_mode,
                                      use_mbd,
                                      batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          img_dim_disc,
                                          nb_patch,
                                          bn_mode,
                                          use_mbd,
                                          batch_size)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   img_dim,
                                   patch_size,
                                   image_dim_ordering)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch,
                                                           X_sketch_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           patch_size,
                                                           image_dim_ordering,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_dim_ordering, "training")
                    X_full_batch, X_sketch_batch = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_dim_ordering, "validation")

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
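
# ---------------------------------------------------------------------------
# Note: the DCGAN_model.compile call above uses l1_loss, assumed to be defined
# at module level. A minimal sketch of an L1 reconstruction loss as commonly
# used in pix2pix-style setups (illustrative, not necessarily this
# repository's exact definition):
import keras.backend as K


def l1_loss_sketch(y_true, y_pred):
    # Mean absolute error between target and generated images; in the compile
    # call above this term is weighted 10:1 against the adversarial log-loss.
    return K.mean(K.abs(y_true - y_pred))
# ---------------------------------------------------------------------------
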
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    disc_type = kwargs["disc_type"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]
    data_aug = kwargs["data_aug"]
    disc_iters = kwargs["disc_iterations"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")
    # Some extra parameters:

    noise_dim = (noise_dim,)
    name1 = name + '1'
    name2 = name + '2'
    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")
    gen_iterations = 0
    # Loading data
    A_data, A_labels, B_data, B_labels, n_classes, img_A_dim, img_B_dim = load_data(
        img_dim, image_dim_ordering, dset)
     
    # Setup GAN1
    deterministic1 = False
    opt_D1, opt_G1, opt_C1, opt_Z1 = build_opt(opt_D, opt_G, lr_D, lr_G)
    (generator_model1, discriminator_model1, discriminator_class1,
     classificator_model1, DCGAN_model1, zclass_model1) = load_compile_models(
        noise_dim, img_A_dim, img_B_dim, deterministic1, pureGAN, wd,
        'mse', 'categorical_crossentropy', disc_type, n_classes,
        opt_D1, opt_G1, opt_C1, opt_Z1)
    load_pretrained_weights(generator_model1, discriminator_model1,
                            discriminator_class1, DCGAN_model1, name1,
                            B_data, B_labels, noise_scale,
                            classificator_model1, resume=resume)
    img_buffer1, datagen1 = load_buffer_and_augmentation(history_size,
                                                         batch_size,
                                                         img_A_dim, n_classes)
    # Temporary settings:
    gen_entropy1 = None
    GAN1 = _GAN(generator_model1, discriminator_model1, discriminator_class1,
                DCGAN_model1, gen_entropy1, classificator_model1, batch_size,
                img_A_dim, img_B_dim, noise_dim, noise_scale, lr_D, lr_G,
                deterministic1, inject_noise, model, lsmooth, img_buffer1,
                datagen1, disc_type, data_aug, n_classes, disc_iters, name1,
                dir='AtoB')
    pretrain_disc(GAN1, A_data, A_labels, B_data, B_labels,
                  pretrain_iters=500, resume=resume)
    #####################

    # Setup GAN2
    deterministic2 = True
    opt_D2, opt_G2, opt_C2, opt_Z2 = build_opt(opt_D, opt_G, lr_D, lr_G)
    (generator_model2, discriminator_model2, discriminator_class2,
     classificator_model2, DCGAN_model2, zclass_model2) = load_compile_models(
        noise_dim, img_B_dim, img_A_dim, deterministic2, pureGAN, wd,
        'mse', 'categorical_crossentropy', disc_type, n_classes,
        opt_D2, opt_G2, opt_C2, opt_Z2)
    load_pretrained_weights(generator_model2, discriminator_model2,
                            discriminator_class2, DCGAN_model2, name2,
                            B_data, B_labels, noise_scale,
                            classificator_model2, resume=resume)
    img_buffer2, datagen2 = load_buffer_and_augmentation(history_size,
                                                         batch_size,
                                                         img_B_dim, n_classes)

    # Temporary settings:
    gen_entropy2 = None
    GAN2 = _GAN(generator_model2, discriminator_model2, discriminator_class2,
                DCGAN_model2, gen_entropy2, classificator_model2, batch_size,
                img_B_dim, img_A_dim, noise_dim, noise_scale, lr_D, lr_G,
                deterministic2, inject_noise, model, lsmooth, img_buffer2,
                datagen2, disc_type, data_aug, n_classes, disc_iters, name2,
                dir='BtoA')
    pretrain_disc(GAN2, A_data, A_labels, B_data, B_labels,
                  pretrain_iters=500, resume=resume)


    #################
    # Start training
    #################
    for e in range(1, nb_epoch + 1):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size,interval=0.2)
        batch_counter = 1
        start = time.time()
        while batch_counter < n_batch_per_epoch:
            l_disc_real1, l_disc_gen1, l_gen1, l_z1, l_class1 = get_loss_list()
            A_data_batch, A_labels_batch, B_data_batch, B_labels_batch = train_gan(
                GAN1, GAN1.disc_iters, A_data, A_labels, B_data, B_labels,
                batch_counter, l_disc_real1, l_disc_gen1, l_gen1)
            l_class1 = train_class(GAN1, l_class1, A_data_batch, A_labels_batch)

            l_disc_real2, l_disc_gen2, l_gen2, l_z2, l_class2 = get_loss_list()
            A_data_batch, A_labels_batch, B_data_batch, B_labels_batch = train_gan(
                GAN2, GAN2.disc_iters, A_data, A_labels, B_data, B_labels,
                batch_counter, l_disc_real2, l_disc_gen2, l_gen2)
            l_class2 = train_class(GAN2, l_class2, A_data_batch, A_labels_batch)
            batch_counter, gen_iterations = visualize_save_stuffs(
                [GAN1, GAN2], progbar, gen_iterations, batch_counter,
                n_batch_per_epoch, l_disc_real1, l_disc_gen1, l_gen1, l_class1,
                l_disc_real2, l_disc_gen2, l_gen2, l_class2,
                A_data, A_labels, B_data, B_labels, start, e)


        testing_class_accuracy([GAN1,GAN2],GAN1.classificator_model, GAN1.generator_model,
                               5000, GAN1.noise_dim, GAN1.noise_scale, B_data, B_labels)
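
# ---------------------------------------------------------------------------
# Note: get_loss_list above is assumed to be a small helper defined elsewhere.
# Judging from its call sites, it returns fresh per-batch loss accumulators;
# a hypothetical sketch matching the unpacking in the training loop:
def get_loss_list_sketch():
    # One list per tracked quantity: D(real), D(fake), G, z-class, classifier.
    return [], [], [], [], []
# ---------------------------------------------------------------------------
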
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    # Strip any trailing '/' so the joined name has no empty directory component
    save_dir = kwargs["save_dir"].rstrip('/')
    # join name with current datetime
    save_dir = '_'.join(
        [save_dir,
         datetime.datetime.now().strftime("%I:%M%p-%B%d-%Y/")])

    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # save the config in save dir
    with open('{0}job_config.json'.format(save_dir), 'w') as fp:
        json.dump(kwargs, fp, sort_keys=True, indent=4)

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
        dset, image_data_format)
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator, img_dim,
                                      nb_patch, bn_mode, use_mbd, batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, bn_mode, use_mbd,
                                          batch_size)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch,
                    X_sketch_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    image_data_format,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train,
                                         batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size,
                            values=[("D logloss", disc_loss),
                                    ("G tot", gen_loss[0]),
                                    ("G L1", gen_loss[1]),
                                    ("G logloss", gen_loss[2])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_data_format,
                        "{:03}_EPOCH_TRAIN".format(e + 1), save_dir)
                    X_full_batch, X_sketch_batch = next(
                        data_utils.gen_batch(X_full_val, X_sketch_val,
                                             batch_size))
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_data_format,
                        "{:03}_EPOCH_VALID".format(e + 1), save_dir)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                pass
                # save models
                # gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                # generator_model.save_weights(gen_weights_path, overwrite=True)

                # disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                # discriminator_model.save_weights(disc_weights_path, overwrite=True)

                # DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                # DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass

    # save models
    DCGAN_model.save(save_dir + 'DCGAN.h5')
    generator_model.save(save_dir + 'GENERATOR.h5')
    discriminator_model.save(save_dir + 'DISCRIMINATOR.h5')
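
# ---------------------------------------------------------------------------
# Hypothetical invocation of the train() variant above, listing the keyword
# arguments it unpacks. All values are illustrative assumptions, not defaults
# taken from the repository:
example_pix2pix_kwargs = dict(
    batch_size=4,
    n_batch_per_epoch=100,
    nb_epoch=100,
    model_name="pix2pix",
    generator="upsampling",            # selects "generator_unet_upsampling"
    image_data_format="channels_last",
    img_dim=(256, 256, 3),             # overwritten from the loaded data
    patch_size=(64, 64),
    bn_mode=2,
    use_label_smoothing=False,
    label_flipping=0,
    dset="facades",                    # any dataset known to data_utils.load_data
    use_mbd=False,
    save_dir="./experiments")
# train(**example_pix2pix_kwargs)
# ---------------------------------------------------------------------------
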
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(img_dim, image_dim_ordering)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_dim_ordering)
    img_dim = X_real_train.shape[-3:]
    noise_dim = (100,)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      noise_dim,
                                      img_dim,
                                      bn_mode,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          noise_dim,
                                          img_dim,
                                          bn_mode,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   noise_dim,
                                   img_dim)

        loss = ['binary_crossentropy']
        loss_weights = [1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_real_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           batch_size,
                                                           noise_dim,
                                                           noise_scale=noise_scale,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen, y_gen = data_utils.get_gen_batch(batch_size, noise_dim, noise_scale=noise_scale)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G logloss", gen_loss)])

                # Save images for visualization
                if batch_counter % 100 == 0:
                    data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                    batch_size, noise_dim, image_dim_ordering)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
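
# ---------------------------------------------------------------------------
# Note: data_utils.get_gen_batch above is assumed to pair a noise batch with
# "real" labels so that the frozen discriminator pushes the generator towards
# realistic samples. A minimal sketch under that assumption (illustrative,
# not the repository's actual code):
import numpy as np


def get_gen_batch_sketch(batch_size, noise_dim, noise_scale=0.5):
    # noise_dim is a tuple such as (100,). Sample Gaussian latent vectors with
    # standard deviation noise_scale and label them as real (1) so the
    # generator update pushes D towards classifying fakes as real.
    X_gen = np.random.normal(scale=noise_scale,
                             size=(batch_size,) + noise_dim)
    y_gen = np.ones(batch_size, dtype=np.uint8)
    return X_gen, y_gen
# ---------------------------------------------------------------------------
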
def eval(**kwargs):

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    cont_dim = (kwargs["cont_dim"],)
    cat_dim = (kwargs["cat_dim"],)
    noise_dim = (kwargs["noise_dim"],)
    bn_mode = kwargs["bn_mode"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    epoch = kwargs["epoch"]

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    if dset == "RGZ":
        X_real_train = data_utils.load_RGZ(img_dim, image_dim_ordering)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_dim_ordering)
    img_dim = X_real_train.shape[-3:]

    # Load generator model
    generator_model = models.load("generator_%s" % generator,
                                  cat_dim,
                                  cont_dim,
                                  noise_dim,
                                  img_dim,
                                  bn_mode,
                                  batch_size,
                                  dset=dset)

    # Load colorization model
    generator_model.load_weights("../../models/%s/gen_weights_epoch%s.h5" %
                                 (model_name, epoch))

    X_plot = []
    # Vary the categorical variable
    for i in range(cat_dim[0]):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = data_utils.sample_noise(noise_scale, batch_size, cont_dim)
        X_cont = np.repeat(X_cont[:1, :], batch_size, axis=0)  # fix continuous noise
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, i] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)

        if image_dim_ordering == "th":
            X_gen = X_gen.transpose(0,2,3,1)

        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(8,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying categorical factor", fontsize=28, labelpad=60)

    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig("../../figures/varying_categorical.png")
    plt.clf()
    plt.close()

    # Vary the continuous variables
    X_plot = []
    # First get the extent of the noise sampling
    x = np.ravel(data_utils.sample_noise(noise_scale, batch_size * 20000, cont_dim))
    # Define interpolation points
    x = np.linspace(x.min(), x.max(), num=batch_size)
    for i in range(batch_size):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = np.concatenate([np.array([x[i], x[j]]).reshape(1, -1) for j in range(batch_size)], axis=0)
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, 1] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)
        if image_dim_ordering == "th":
            X_gen = X_gen.transpose(0,2,3,1)
        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(10,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying continuous factor 1", fontsize=28, labelpad=60)
    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.xlabel("Varying continuous factor 2", fontsize=28, labelpad=60)
    plt.annotate('', xy=(1, -0.05), xycoords='axes fraction', xytext=(0, -0.05),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig("../../figures/varying_continuous.png")
    plt.clf()
    plt.close()
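
# ---------------------------------------------------------------------------
# Hypothetical invocation of eval() above, listing the keyword arguments it
# unpacks. Values are illustrative assumptions, not repository defaults:
example_eval_kwargs = dict(
    batch_size=32,
    generator="upsampling",
    model_name="InfoGAN",
    image_dim_ordering="th",
    img_dim=28,                 # overwritten from the loaded data
    cont_dim=2,                 # number of continuous latent codes
    cat_dim=10,                 # number of categorical classes
    noise_dim=64,
    bn_mode=1,
    noise_scale=0.5,
    dset="mnist",
    epoch=50)                   # which gen_weights_epoch*.h5 checkpoint to load
# eval(**example_eval_kwargs)
# ---------------------------------------------------------------------------
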
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)
    print("hi")

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
        dset, image_dim_ordering)
    img_dim = X_full_train.shape[-3:]
    print("data loaded in memory")

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_dim_ordering)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator, img_dim,
                                      nb_patch, bn_mode, use_mbd, batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, bn_mode, use_mbd,
                                          batch_size)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_dim_ordering)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        gen_loss = None
        disc_loss = None

        iter_num = 102
        weights_path = "/home/abhik/pix2pix/src/model/weights/gen_weights_iter%s_epoch30.h5" % (
            str(iter_num - 1))
        print(weights_path)
        generator_model.load_weights(weights_path)

        #discriminator_model.load_weights("disc_weights1.2.h5")

        #DCGAN_model.load_weights("DCGAN_weights1.2.h5")
        print("Weights Loaded for iter - %d" % iter_num)

        # Running average
        losses_list = list()
        # loss_list = list()
        # prev_avg = 0

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            # global disc_n, disc_prev_avg, gen1_n, gen1_prev_avg, gen2_n, gen2_prev_avg, gen3_n, gen3_prev_avg

            # disc_n = 1
            # disc_prev_avg = 0

            # gen1_n = 1
            # gen1_prev_avg = 0

            # gen2_n = 1
            # gen2_prev_avg = 0

            # gen3_n = 1
            # gen3_prev_avg = 0

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch,
                    X_sketch_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    image_dim_ordering,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train,
                                         batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                # Running average
                # loss_list.append(disc_loss)
                # loss_list_n = len(loss_list)
                # new_avg = ((loss_list_n-1)*prev_avg + disc_loss)/loss_list_n
                # prev_avg = new_avg

                # disc_avg, gen1_avg, gen2_avg, gen3_avg = running_avg(disc_loss, gen_loss[0], gen_loss[1], gen_loss[2])

                # print("running disc loss", new_avg)
                # print(disc_loss, gen_loss)
                # print ("all losses", disc_avg, gen1_avg, gen2_avg, gen3_avg)
                # print("")

                batch_counter += 1
                progbar.add(batch_size,
                            values=[("D logloss", disc_loss),
                                    ("G tot", gen_loss[0]),
                                    ("G L1", gen_loss[1]),
                                    ("G logloss", gen_loss[2])])

                # Saving data for plotting
                # losses = [e+1, batch_counter, disc_loss, gen_loss[0], gen_loss[1], gen_loss[2], disc_avg, gen1_avg, gen2_avg, gen3_avg, iter_num]
                # losses_list.append(losses)

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_dim_ordering, "training", iter_num)
                    X_full_batch, X_sketch_batch = next(
                        data_utils.gen_batch(X_full_val, X_sketch_val,
                                             batch_size))
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_dim_ordering, "validation", iter_num)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))

            #Running average
            disc_avg, gen1_avg, gen2_avg, gen3_avg = running_avg(
                disc_loss, gen_loss[0], gen_loss[1], gen_loss[2])

            #Validation loss
            y_gen_val = np.zeros((X_sketch_batch.shape[0], 2), dtype=np.uint8)
            y_gen_val[:, 1] = 1
            val_loss = DCGAN_model.test_on_batch(X_full_batch,
                                                 [X_sketch_batch, y_gen_val])
            # print "val_loss ===" + str(val_loss)

            #logging
            # Saving data for plotting
            losses = [
                e + 1, iter_num, disc_loss, gen_loss[0], gen_loss[1],
                gen_loss[2], disc_avg, gen1_avg, gen2_avg, gen3_avg,
                val_loss[0], val_loss[1], val_loss[2]
            ]
            losses_list.append(losses)

            if (e + 1) % 5 == 0:
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_iter%s_epoch%s.h5' %
                    (model_name, iter_num, e + 1))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_iter%s_epoch%s.h5' %
                    (model_name, iter_num, e + 1))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_iter%s_epoch%s.h5' %
                    (model_name, iter_num, e + 1))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

                loss_array = np.asarray(losses_list)
                print(loss_array.shape)  # one 13-element row per epoch so far

                loss_path = os.path.join(
                    '../../losses/loss_iter%s_epoch%s.csv' % (iter_num, e + 1))
                np.savetxt(loss_path, loss_array, fmt='%.5f', delimiter=',')
                np.savetxt('test.csv', loss_array, fmt='%.5f', delimiter=',')

    except KeyboardInterrupt:
        pass
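
# ---------------------------------------------------------------------------
# Note: running_avg above is assumed to keep incremental means of the four
# tracked losses, along the lines of the commented-out formula
# new_avg = ((n - 1) * prev_avg + x) / n. A self-contained sketch of that idea
# (hypothetical, not the repository's exact implementation):
_running_avg_state = {"n": 0, "avg": [0.0, 0.0, 0.0, 0.0]}


def running_avg_sketch(disc_loss, gen_tot, gen_l1, gen_logloss):
    # Update each running mean incrementally without storing the history.
    _running_avg_state["n"] += 1
    n = _running_avg_state["n"]
    new_vals = [disc_loss, gen_tot, gen_l1, gen_logloss]
    _running_avg_state["avg"] = [
        ((n - 1) * a + x) / n
        for a, x in zip(_running_avg_state["avg"], new_vals)
    ]
    return tuple(_running_avg_state["avg"])
# ---------------------------------------------------------------------------
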
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    prob = kwargs["prob"]
    training_data_file = kwargs["training_data_file"]
    experiment = kwargs["experiment"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(experiment)

    # Create a batch generator for the color data
    DataAug = batch_utils.AugDataGenerator(training_data_file,
                                           batch_size=batch_size,
                                           prob=prob,
                                           dset="training_color")
    DataAug.add_transform("h_flip")

    # Load the array of quantized ab value
    q_ab = np.load("../../data/processed/pts_in_hull.npy")
    nb_q = q_ab.shape[0]
    nb_neighbors = 10
    # Fit a NN to q_ab
    nn_finder = nn.NearestNeighbors(n_neighbors=nb_neighbors, algorithm='ball_tree').fit(q_ab)

    # Load the color prior factor that encourages rare colors
    prior_factor = np.load("../../data/processed/training_64_prior_factor.npy")

    # Load and rescale data
    print("Loading data")
    with h5py.File(training_data_file, "r") as hf:
        X_train = hf["training_lab_data"][:100]
        c, h, w = X_train.shape[1:]
    print("Data loaded")

    for f in glob.glob("*.h5"):
        os.remove(f)

    for f in glob.glob("../../reports/figures/*.png"):
        os.remove(f)

    try:

        # Create optimizers
        # opt = SGD(lr=5E-4, momentum=0.9, nesterov=True)
        opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load colorizer model
        color_model = models.load("simple_colorful", nb_q, (1, h, w), batch_size)
        color_model.compile(loss='categorical_crossentropy_color', optimizer=opt)

        color_model.summary()
        from keras.utils.visualize_util import plot
        plot(color_model, to_file='colorful.png')

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for batch in DataAug.gen_batch_colorful(X_train, nn_finder, nb_q, prior_factor):

                X_batch_black, X_batch_color, Y_batch = batch
                # X = color_model.predict(X_batch_black)

                # print color_model.evaluate(X_batch_black, Y_batch)
                # X = color_model.predict(X_batch_black)
                # print X[0, 0, 0, :]

                train_loss = color_model.train_on_batch(X_batch_black / 100., Y_batch)

                batch_counter += 1
                progbar.add(batch_size, values=[("loss", train_loss)])

                if batch_counter >= n_batch_per_epoch:
                    break
            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Format X_colorized
            X_colorized = color_model.predict(X_batch_black / 100.)[:, :, :, :-1]
            X_colorized = X_colorized.reshape((batch_size * h * w, nb_q))
            X_colorized = q_ab[np.argmax(X_colorized, 1)]
            X_a = X_colorized[:, 0].reshape((batch_size, 1, h, w))
            X_b = X_colorized[:, 1].reshape((batch_size, 1, h, w))
            X_colorized = np.concatenate((X_batch_black, X_a, X_b), axis=1).transpose(0, 2, 3, 1)
            X_colorized = [np.expand_dims(color.lab2rgb(im), 0) for im in X_colorized]
            X_colorized = np.concatenate(X_colorized, 0).transpose(0, 3, 1, 2)

            X_batch_color = [np.expand_dims(color.lab2rgb(im.transpose(1, 2, 0)), 0) for im in X_batch_color]
            X_batch_color = np.concatenate(X_batch_color, 0).transpose(0, 3, 1, 2)

            print(X_batch_color.shape, X_colorized.shape, X_batch_black.shape)

            for i, img in enumerate(X_colorized[:min(32, batch_size)]):
                arr = np.concatenate([X_batch_color[i], np.repeat(X_batch_black[i] / 100., 3, axis=0), img], axis=2)
                np.save("../../reports/gen_image_%s.npy" % i, arr)

            plt.figure(figsize=(20,20))
            list_img = glob.glob("../../reports/*.npy")
            list_img = [np.load(im) for im in list_img]
            list_img = [np.concatenate(list_img[4 * i: 4 * (i + 1)], axis=2) for i in range(len(list_img) / 4)]
            arr = np.concatenate(list_img, axis=1)
            plt.imshow(arr.transpose(1,2,0))
            ax = plt.gca()
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            plt.tight_layout()
            plt.savefig("../../reports/figures/fig_epoch%s.png" % e)
            plt.clf()
            plt.close()

    except KeyboardInterrupt:
        pass
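
# ---------------------------------------------------------------------------
# Hypothetical invocation of the colorization train() above. The values and
# the HDF5 path are illustrative assumptions, not repository defaults:
example_colorful_kwargs = dict(
    batch_size=16,
    n_batch_per_epoch=100,
    nb_epoch=50,
    prob=0.5,                  # augmentation probability for AugDataGenerator
    training_data_file="../../data/processed/training_data.h5",
    experiment="simple_colorful")
# train(**example_colorful_kwargs)
# ---------------------------------------------------------------------------
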
def trainClassAux(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    noClass = kwargs["noClass"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    C_weight = kwargs["C_weight"]
    monsterClass = kwargs["monsterClass"]
    pretrained = kwargs["pretrained"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    if dset == "mnistM":
        X_source_train,Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnist')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnistM')
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_12class_LMDB')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB_Background')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Wash_Color_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
    else:
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1
    img_source_dim = X_source_train.shape[-3:]  # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    X_source_train.flags.writeable = False
    X_source_test.flags.writeable = False
    X_dest_train.flags.writeable = False
    X_dest_test.flags.writeable = False
    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_G_C = data_utils.get_optimizer(opt_G, lr_G*C_weight)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    if generator == "upsampling":
        generator_model = models.generator_upsampling_mnistM(
            noise_dim, img_source_dim, img_dest_dim, bn_mode, deterministic,
            inject_noise, wd, dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim, img_dest_dim,
                                                  bn_mode, batch_size,
                                                  dset=dset)

    discriminator_model = models.discriminator(img_dest_dim, bn_mode, model,
                                               wd, monsterClass, inject_noise,
                                               n_classes)
    DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model,
                                     noise_dim, img_source_dim)
    # img_dest_dim is used because the classifier consumes generated images,
    # whose dimensions match the destination domain
    classifier = models.resnet(img_dest_dim, n_classes, pretrained, wd=0.0001)
 
    GenToClassifierModel = models.GenToClassifierModel(generator_model, classifier, noise_dim, img_source_dim)

    ############################
    # Load weights
    ############################
    if resume:
        data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name)
    #if pretrained:
    #    model_path = "../../models/DCGAN"
    #    class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
    #    classifier.load_weights(class_weights_path)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)

    if model == 'wgan':
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
        models.make_trainable(discriminator_model, False)
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    if model == 'lsgan':
        discriminator_model.compile(loss='mse', optimizer=opt_D)
        models.make_trainable(discriminator_model, False)
        DCGAN_model.compile(loss='mse',  optimizer=opt_G)

    # opt_C is required by compile() but is never actually used for updates here
    classifier.compile(loss='categorical_crossentropy', optimizer=opt_C,
                       metrics=['accuracy'])
    models.make_trainable(classifier, False)
    GenToClassifierModel.compile(loss='categorical_crossentropy', optimizer=opt_G,metrics=['accuracy'])


    #######################
    # Train classifier
    #######################
#    if not pretrained: 
#        #print ("Testing accuracy on target domain test set before training:")
#        loss1,acc1 =classifier.evaluate(X_dest_test, Y_dest_test,batch_size=256, verbose=0)
#        print('\n Classifier Accuracy on target domain test set before training: %.2f%%' % (100 * acc1))
#        classifier.fit(X_dest_train, Y_dest_train, validation_split=0.1, batch_size=512, nb_epoch=10, verbose=1)
#        print ("\n Testing accuracy on target domain test set AFTER training:")
#    else:
#        print ("Loaded pretrained classifier, computing accuracy on target domain test set:")
#    loss2,acc2 = classifier.evaluate(X_dest_test, Y_dest_test,batch_size=512, verbose=0)
#    print('\n Classifier Accuracy on target domain test set after training:  %.2f%%' % (100 * acc2))
    #print ("Testing accuracy on source domain test set:")
#    loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
#    print('\n Classifier Accuracy on source domain test set:  %.2f%%' % (100 * acc3))
#    evaluating_GENned(noise_scale, noise_dim, X_source_test, Y_source_test, classifier, generator_model)

#    model_path = "../../models/DCGAN"
#    class_weights_path = os.path.join(model_path, 'VandToVand_5epochs.h5')
#    classifier.save_weights(class_weights_path, overwrite=True) 
    #################
    # GAN training
    ################
    gen_iterations = 0 
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:
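            # "Supertrain" schedule: run many extra critic iterations at the start
            # of training and every 500 generator iterations (standard WGAN
            # practice), unless no_supertrain is set.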
            if no_supertrain is None:
                if ((gen_iterations < 25) and (not resume)) or (gen_iterations % 500 == 0):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            list_gen_loss = []
            list_class_loss_real = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
#                for l in discriminator_model.layers:
#                    weights = l.get_weights()
#                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
#                    l.set_weights(weights)

                X_dest_batch, Y_dest_batch,idx_dest_batch = next(data_utils.gen_batch(X_dest_train, Y_dest_train, batch_size))
                X_source_batch, Y_source_batch,idx_source_batch = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_dest_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    X_source_batch,
                                                                    noise_scale=noise_scale)
                # Update the discriminator
                if model == 'wgan':
                    disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                    disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                if model == 'lsgan':
                    disc_loss_real = discriminator_model.train_on_batch(X_disc_real, np.ones(X_disc_real.shape[0])) 
                    disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.zeros(X_disc_gen.shape[0]))
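                # Label conventions: the WGAN critic uses -1 (real) / +1 (fake)
                # targets with the Wasserstein loss, while LSGAN uses 1 (real) /
                # 0 (fake) targets with an MSE loss.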
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)


            #######################
            # 2) Train the generator with GAN loss
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            source_images = X_source_train[np.random.randint(0, X_source_train.shape[0], size=batch_size), :, :, :]  # sampled but not used below
            X_source_batch2, Y_source_batch2,idx_source_batch2 = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))
            # Freeze the discriminator
 #           discriminator_model.trainable = False
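            # (Freezing here is unnecessary: the discriminator was already
            # non-trainable when DCGAN_model was compiled above.)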
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen,X_source_batch2], -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen,X_source_batch2], np.ones(X_gen.shape[0]))
            list_gen_loss.append(gen_loss)


            #######################
            # 3) Train the generator with Classifier loss
            #######################
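            # The classifier is frozen inside GenToClassifierModel, so only the
            # generator is updated here; the w1/w2 comparison below verifies that
            # the classifier weights do not change.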
            w1 = classifier.get_weights() #FOR DEBUG
            if not noClass:
                new_gen_loss = GenToClassifierModel.train_on_batch([X_gen,X_source_batch2], Y_source_batch2)
                list_class_loss_real.append(new_gen_loss)
            else:
                list_class_loss_real.append(0.0)
            w2 = classifier.get_weights() #FOR DEBUG
            for a, b in zip(w1, w2):
                if np.all(a == b):
                    print("no bug in GEN model update")
                else:
                    print("BUG IN GEN MODEL UPDATE")

            gen_iterations += 1
            batch_counter += 1

            progbar.add(batch_size, values=[("Loss_D", 0.5*np.mean(list_disc_loss_real) + 0.5*np.mean(list_disc_loss_gen)),
                                            ("Loss_D_real", np.mean(list_disc_loss_real)),
                                            ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                            ("Loss_G", np.mean(list_gen_loss)),
                                            ("Loss_classifier", np.mean(list_class_loss_real))])

            # plot images once per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
          #      train_WGAN.plot_images(X_dest_batch)
                X_dest_batch_plot,Y_dest_batch_plot,idx_dest_plot = next(data_utils.gen_batch(X_dest_train,Y_dest_train, batch_size=32))
                X_source_batch_plot,Y_source_batch_plot,idx_source_plot = next(data_utils.gen_batch(X_source_train,Y_source_train, batch_size=32))

                data_utils.plot_generated_batch(X_dest_train,X_source_train, generator_model,
                                                 noise_dim, image_dim_ordering,idx_source_plot,batch_size=32)
        print ("Dest labels:") 
        print (Y_dest_train[idx_source_plot].argmax(1))
        print ("Source labels:") 
        print (Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e, name)
        evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)


        loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
        print('\n Classifier Accuracy on source domain test set:  %.2f%%' % (100 * acc3))
image_data_format = 'channels_last'
img_dim = 256
patch_size = [128, 128]
bn_mode = 2
label_smoothing = False
label_flipping = 0
data_folder = '/Lab1/Lab6/data_pix2pix/data/processed/'
dset = 'chest_xray'
use_mbd = False
do_plot = False
logging_dir = './pix2pix/logging_dir_pix2pix/'

epoch_size = n_batch_per_epoch * batch_size

# Setup environment (logging directory etc)
setup_logging(model_name, logging_dir=logging_dir)

# Load and rescale data
X_full_train, X_sketch_train, X_full_val, X_sketch_val = load_data(
    data_folder, dset, image_data_format)
img_dim = X_full_train.shape[-3:]

# Get the number of non overlapping patch and the size of input image to the discriminator
nb_patch, img_dim_disc = get_nb_patch(img_dim, patch_size, image_data_format)

try:

    # Create optimizers
    opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
Example No. 16
def trainDECO(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    noClass = kwargs["noClass"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    monsterClass = kwargs["monsterClass"]
    pretrained = kwargs["pretrained"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    if dset == "mnistM":
        X_source_train,Y_source_train, _, _, n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnist')
        X_dest_train,Y_dest_train, _, _,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnistM')
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_12class_LMDB')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB_Background')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Wash_Color_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
    else:
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1 

    # Get the full real image dimension
    img_source_dim = X_source_train.shape[-3:] # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:] 

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    generator_model = models.generator_upsampling_mnistM(noise_dim, img_source_dim,img_dest_dim, bn_mode,deterministic,inject_noise,wd, dset=dset)
    classifier = models.resnet(img_dest_dim, n_classes, wd=0.0001)  # img_dest_dim because the classifier sees generated images, whose dimension equals the destination dimension
    GenToClassifierModel = models.GenToClassifierModel(generator_model, classifier, noise_dim, img_source_dim)

    #################
    # Load weights
    ################
    if pretrained:
        model_path = "../../models/DCGAN"
        class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
        classifier.load_weights(class_weights_path)
 #   if resume: ########loading previous saved model weights
 #       data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name)

    #######################
    # Compile models
    #######################
    generator_model.compile(loss='mse', optimizer=opt_G)
#    classifier.trainable = True  # I want to freeze the classifier so it gets no further training updates
    classifier.compile(loss='categorical_crossentropy', optimizer=opt_C, metrics=['accuracy'])  # opt_C is only used if the classifier is fit below (i.e. when not pretrained)

    models.make_trainable(classifier, False)
    GenToClassifierModel.compile(loss='categorical_crossentropy', optimizer=opt_G,metrics=['accuracy'])
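    # The generator is trained through the frozen classifier: the crossentropy
    # gradient flows back into the generator while the classifier weights stay fixed.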

    #######################
    # Train classifier
    #######################
    if not pretrained:
        loss1, acc1 = classifier.evaluate(X_dest_test, Y_dest_test, batch_size=256, verbose=0)
        print('\n Classifier Accuracy on target domain test set before training: %.2f%%' % (100.0 * acc1))
        classifier.fit(X_dest_train, Y_dest_train, validation_split=0.1, batch_size=512, nb_epoch=20, verbose=1)
    else:
        print("Loaded pretrained classifier, computing accuracy on target domain test set:")
    loss2, acc2 = classifier.evaluate(X_dest_test, Y_dest_test, batch_size=512, verbose=0)
    print('\n Classifier Accuracy on target domain test set after training: %.2f%%' % (100.0 * acc2))
    loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test, batch_size=512, verbose=0)
    print('\n Classifier Accuracy on source domain test set: %.2f%%' % (100.0 * acc3))
    evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)

#    model_path = "../../models/DCGAN"
#    class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
#    classifier.save_weights(class_weights_path, overwrite=True)

#    models.make_trainable(classifier, False)
    #classifier.trainable = False # I wanna freeze the classifier without any more training updates
#    classifier.compile(loss='categorical_crossentropy', optimizer=opt_C,metrics=['accuracy']) 
    #################
    #  DECO training
    ################
    gen_iterations = 0
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()
        #######################
        # 1) Sample source / destination batches
        #######################
        list_class_loss_real = []
        X_dest_batch, Y_dest_batch, idx_dest_batch = next(data_utils.gen_batch(X_dest_train, Y_dest_train, batch_size))
        X_source_batch, Y_source_batch, idx_source_batch = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))
        #######################
        # 2) Train the generator
        #######################
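        # DECO training: one large batch of (noise, source image) pairs per epoch,
        # fit against the source labels through the frozen classifier.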
        X_gen = data_utils.sample_noise(noise_scale, batch_size*n_batch_per_epoch, noise_dim)
        X_source_batch2, Y_source_batch2,idx_source_batch2 = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size*n_batch_per_epoch))
        GenToClassifierModel.fit([X_gen, X_source_batch2], Y_source_batch2, batch_size=256, nb_epoch=1, verbose=1)

       
        gen_iterations += 1
        batch_counter += 1

        # plot images once per epoch
        X_dest_batch_plot,Y_dest_batch_plot,idx_dest_plot = next(data_utils.gen_batch(X_dest_train,Y_dest_train, batch_size=32))
        X_source_batch_plot,Y_source_batch_plot,idx_source_plot = next(data_utils.gen_batch(X_source_train,Y_source_train, batch_size=32))

        data_utils.plot_generated_batch(X_dest_train,X_source_train, generator_model,
                                                 noise_dim, image_dim_ordering,idx_source_plot,batch_size=32)
        print ("Dest labels:") 
        print (Y_dest_train[idx_source_plot].argmax(1))
        print ("Source labels:") 
        print (Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        #data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e, name)
        evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)

        loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
        print('\n Classifier Accuracy on source domain test set: %.2f%%' % (100.0 * acc3))
Example No. 17
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    celebA_img_dim = kwargs["celebA_img_dim"]
    cont_dim = (kwargs["cont_dim"], )
    cat_dim = (kwargs["cat_dim"], )
    noise_dim = (kwargs["noise_dim"], )
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    load_from_dir = kwargs["load_from_dir"]
    target_size = kwargs["target_size"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and rescale data
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(celebA_img_dim,
                                              image_data_format)
    elif dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    else:
        X_batch_gen = data_utils.data_generator_from_dir(
            dset, target_size, batch_size)
        X_real_train = next(X_batch_gen)
    img_dim = X_real_train.shape[-3:]
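    # When streaming from a directory, a single batch is drawn above only to infer
    # img_dim; training batches are pulled from the generator inside the epoch loop.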

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=1E-4,
                                 beta_1=0.5,
                                 beta_2=0.999,
                                 epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      cat_dim,
                                      cont_dim,
                                      noise_dim,
                                      img_dim,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          cat_dim,
                                          cont_dim,
                                          noise_dim,
                                          img_dim,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   cat_dim, cont_dim, noise_dim)

        list_losses = [
            'binary_crossentropy', 'categorical_crossentropy', gaussian_loss
        ]
        list_weights = [1, 1, 1]
        DCGAN_model.compile(loss=list_losses,
                            loss_weights=list_weights,
                            optimizer=opt_dcgan)

        # Multiple discriminator losses
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses,
                                    loss_weights=list_weights,
                                    optimizer=opt_discriminator)
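        # InfoGAN-style heads: binary crossentropy for the real/fake output,
        # categorical crossentropy for the categorical latent code, and a Gaussian
        # loss for the continuous latent code, with equal weights.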

        gen_loss = 100
        disc_loss = 100

        if not load_from_dir:
            X_batch_gen = data_utils.gen_batch(X_real_train, batch_size)

        # Start training
        print("Start training")

        disc_total_losses = []
        disc_log_losses = []
        disc_cat_losses = []
        disc_cont_losses = []
        gen_total_losses = []
        gen_log_losses = []
        gen_cat_losses = []
        gen_cont_losses = []

        start = time.time()

        for e in range(nb_epoch):

            print('--------------------------------------------')
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d}\n'.format(
                datetime.datetime.now(), e + 1, nb_epoch))

            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1

            disc_total_loss_batch = 0
            disc_log_loss_batch = 0
            disc_cat_loss_batch = 0
            disc_cont_loss_batch = 0
            gen_total_loss_batch = 0
            gen_log_loss_batch = 0
            gen_cat_loss_batch = 0
            gen_cont_loss_batch = 0

            for batch_counter in range(n_batch_per_epoch):

                # Load data
                X_real_batch = next(X_batch_gen)

                # Create a batch to feed the discriminator model
                X_disc, y_disc, y_cat, y_cont = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    cat_dim,
                    cont_dim,
                    noise_dim,
                    noise_scale=noise_scale,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)
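                # label_smoothing and label_flipping are standard discriminator
                # regularizers: soft real/fake targets and occasionally swapped
                # labels.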

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(
                    X_disc, [y_disc, y_cat, y_cont])

                # Create a batch to feed the generator model
                X_gen, y_gen, y_cat, y_cont, y_cont_target = data_utils.get_gen_batch(
                    batch_size,
                    cat_dim,
                    cont_dim,
                    noise_dim,
                    noise_scale=noise_scale)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(
                    [y_cat, y_cont, X_gen], [y_gen, y_cat, y_cont_target])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                progbar.add(batch_size,
                            values=[("D tot", disc_loss[0]),
                                    ("D log", disc_loss[1]),
                                    ("D cat", disc_loss[2]),
                                    ("D cont", disc_loss[3]),
                                    ("G tot", gen_loss[0]),
                                    ("G log", gen_loss[1]),
                                    ("G cat", gen_loss[2]),
                                    ("G cont", gen_loss[3])])

                disc_total_loss_batch += disc_loss[0]
                disc_log_loss_batch += disc_loss[1]
                disc_cat_loss_batch += disc_loss[2]
                disc_cont_loss_batch += disc_loss[3]
                gen_total_loss_batch += gen_loss[0]
                gen_log_loss_batch += gen_loss[1]
                gen_cat_loss_batch += gen_loss[2]
                gen_cont_loss_batch += gen_loss[3]

                # # Save images for visualization
                # if batch_counter % (n_batch_per_epoch / 2) == 0:
                #     data_utils.plot_generated_batch(X_real_batch, generator_model, e,
                #                                     batch_size, cat_dim, cont_dim, noise_dim,
                #                                     image_data_format, model_name)

            disc_total_losses.append(disc_total_loss_batch / n_batch_per_epoch)
            disc_log_losses.append(disc_log_loss_batch / n_batch_per_epoch)
            disc_cat_losses.append(disc_cat_loss_batch / n_batch_per_epoch)
            disc_cont_losses.append(disc_cont_loss_batch / n_batch_per_epoch)
            gen_total_losses.append(gen_total_loss_batch / n_batch_per_epoch)
            gen_log_losses.append(gen_log_loss_batch / n_batch_per_epoch)
            gen_cat_losses.append(gen_cat_loss_batch / n_batch_per_epoch)
            gen_cont_losses.append(gen_cont_loss_batch / n_batch_per_epoch)

            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                e, batch_size, cat_dim,
                                                cont_dim, noise_dim,
                                                image_data_format, model_name)
                data_utils.plot_losses(disc_total_losses, disc_log_losses,
                                       disc_cat_losses, disc_cont_losses,
                                       gen_total_losses, gen_log_losses,
                                       gen_cat_losses, gen_cont_losses,
                                       model_name)

            if (e + 1) % save_weights_every_n_epochs == 0:

                print("Saving weights...")

                # Delete all but the last n weights
                general_utils.purge_weights(save_only_last_n_weights,
                                            model_name)

                # Save weights
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%05d.h5' %
                    (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%05d.h5' %
                    (model_name, e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%05d.h5' %
                    (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

            end = time.time()
            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, end - start))
            start = end

    except KeyboardInterrupt:
        pass

    gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
    print("Saving", gen_weights_path)
    generator_model.save(gen_weights_path, overwrite=True)
Example No. 18
def train(**kwargs):
    """
    Train standard DCGAN model
    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_rec = kwargs["lr_D"]
    opt_rec = kwargs["opt_rec"]
    lr_G = kwargs["lr_G"]
    lr_D = kwargs["lr_D"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic1 = kwargs["deterministic1"]
    deterministic2 = kwargs["deterministic2"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    disc_type = kwargs["disc_type"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]
    data_aug = kwargs["data_aug"]
    disc_iters = kwargs["disc_iterations"]
    class_weight = kwargs["class_weight"]
    reconst_w = kwargs["reconst_w"]
    rec = kwargs["rec"]
    reconstClass = kwargs["reconstClass"]
    pretrained = kwargs["pretrained"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")
    #####some extra parameters:

    noise_dim = (noise_dim, )
    name1 = name + '1'
    name2 = name + '2'
    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")
    gen_iterations = 0

    # Loading data
    A_data, A_labels, B_data, B_labels, n_classes, img_A_dim, img_B_dim = load_data(
        img_dim, image_dim_ordering, dset)

    test_data, test_labels = load_testset(img_dim, image_dim_ordering, dset)

    if deterministic1 is None:
        deterministic1 = False
    if deterministic2 is None:
        deterministic2 = False

    opt_D1, opt_G1, opt_C1, opt_Z1, opt_rec = build_opt(
        opt_D, opt_G, lr_D, lr_G, lr_rec, opt_rec)
    generator_model1, discriminator_model1, discriminator_class1, classificator_model1, classificator_model2, DCGAN_model1, GenClass_model1, generator_ss1, discriminator_ss1, DCGAN_ss1, zclass_model1 = load_compile_models(
        noise_dim,
        img_A_dim,
        img_B_dim,
        deterministic1,
        pureGAN,
        wd,
        'mse',
        'categorical_crossentropy',
        disc_type,
        n_classes,
        opt_D1,
        opt_G1,
        opt_C1,
        opt_Z1,
        suffix=None,
        pretrained=pretrained)
    load_pretrained_weights(generator_model1,
                            discriminator_model1,
                            discriminator_class1,
                            DCGAN_model1,
                            name1,
                            B_data,
                            B_labels,
                            noise_scale,
                            classificator_model1,
                            resume=resume)
    img_buffer1, datagen1 = load_buffer_and_augmentation(
        history_size, batch_size, img_A_dim, n_classes)

    GAN1 = _GAN(generator_model1,
                discriminator_model1,
                discriminator_class1,
                DCGAN_model1,
                GenClass_model1,
                classificator_model1,
                classificator_model2,
                generator_ss1,
                discriminator_ss1,
                DCGAN_ss1,
                batch_size,
                img_A_dim,
                img_B_dim,
                noise_dim,
                noise_scale,
                lr_D,
                lr_G,
                deterministic1,
                inject_noise,
                model,
                lsmooth,
                img_buffer1,
                datagen1,
                disc_type,
                data_aug,
                n_classes,
                disc_iters,
                name1,
                dir='AtoB')
    pretrain_disc(GAN1,
                  A_data,
                  A_labels,
                  B_data,
                  B_labels,
                  class_weight,
                  pretrain_iters=500,
                  resume=resume)
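    # Presumably warms up the discriminator for a fixed number of iterations
    # (pretrain_iters=500) before the adversarial loop below starts.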
    #####################

    accuracy = []
    for e in range(1, nb_epoch + 1):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size, interval=0.2)
        batch_counter = 1
        start = time.time()
        while batch_counter < n_batch_per_epoch:
            l_disc_real1, l_disc_gen1, l_gen1, l_disc_real1_ss, l_disc_gen1_ss, l_disc_ss1, l_disc_ss2, l_disc_ss3, l_disc_ss4, l_disc_ss5, l_gen1_ss, l_z1, l_class1, l_rec1, l_GenClass1, _ = get_loss_list(
            )
            A_data_batch, A_labels_batch, B_data_batch, B_labels_batch = train_gan(
                GAN1, GAN1.disc_iters, A_data, A_labels, B_data, B_labels,
                batch_counter, l_disc_real1, l_disc_gen1, l_gen1,
                l_disc_real1_ss, l_disc_gen1_ss, l_disc_ss1, l_disc_ss2,
                l_disc_ss3, l_disc_ss4, l_disc_ss5, l_gen1_ss, l_GenClass1,
                class_weight)

            if rec:
                train_rec(GAN1, rec1, rec2, A_data_batch, B_data_batch, l_rec1,
                          l_rec2, reconst_w)  #BRINGING US TO L.A.? :)
#             if reconstClass > 0.0:
#                 train_recClass(GAN1,recClass, A_data_batch, A_labels_batch,  l_recClass, reconstClass)

            l_class1 = train_class(GAN1, l_class1, l_rec1, A_data_batch,
                                   A_labels_batch, B_data_batch, test_data,
                                   test_labels)
            #             l_class2 = train_class(GAN2, l_class2,  A_data_batch, A_labels_batch)

            #             dummy = GAN1.discriminator2.predict(B_data_batch)
            #             print(dummy)

            batch_counter, gen_iterations = visualize_save_stuffs(
                [GAN1], progbar, gen_iterations, batch_counter,
                n_batch_per_epoch, l_disc_real1, l_disc_gen1, l_gen1,
                l_disc_real1_ss, l_disc_gen1_ss, l_gen1_ss, l_class1, A_data,
                A_labels, B_data, B_labels, start, e, l_rec1, l_GenClass1)

        acc = testing_class_accuracy([GAN1], GAN1.classificator_model,
                                     GAN1.generator_model, test_data.shape[0],
                                     GAN1.noise_dim, GAN1.noise_scale,
                                     test_data, test_labels)

        X_noise = sample_noise(GAN1.noise_scale, A_data.shape[0],
                               GAN1.noise_dim)
        gen_output = GAN1.generator_model.predict([X_noise, A_data])
        np.save('MnistM', gen_output)
        #        testing_class_accuracy([GAN1],GAN1.classificator_model, GAN1.generator_model,
        #                               5000, GAN1.noise_dim, GAN1.noise_scale, B_data, B_labels)
        accuracy = np.append(accuracy, acc)
        np.save('accuracy', accuracy)
Example No. 19
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val, target_train, target_val = data_utils.load_data(
        dset, image_data_format)
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # DCGAN_model = models.DCGAN(generator_model,
        #                           discriminator_model,
        #                            img_dim,
        #                            patch_size,
        #                            image_data_format)
        ##########################################################################
        classifier_model = models.Pereira_classifier(img_dim)
        #classifier_model  = models.MyResNet18(img_dim)
        #classifier_model  = models.MyDensNet121(img_dim)
        #classifier_model  = models.MyNASNetMobile(img_dim)

        #########################################################################
        loss = [keras.losses.categorical_crossentropy]
        loss_weights = [1]
        classifier_model.compile(loss=loss,
                                 loss_weights=loss_weights,
                                 optimizer=opt_dcgan)

        class_loss = 100
        disc_loss = 100
        max_accval = 0
        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch, Y_target in data_utils.gen_batch(
                    X_full_train, X_sketch_train, target_train, batch_size):

                class_loss = classifier_model.train_on_batch(
                    X_sketch_batch, Y_target)

                # Unfreeze the discriminator

                batch_counter += 1
                progbar.add(batch_size, values=[("class_loss", class_loss)])

                # Save images for visualization

                if batch_counter >= n_batch_per_epoch:
                    X_full_batch, X_sketch_batch, Y_target_val = next(
                        data_utils.gen_batch(X_full_val, X_sketch_val,
                                             target_val,
                                             int(X_sketch_val.shape[0])))
                    y_pred = classifier_model.predict(X_sketch_batch)
                    y_predd = np.argmax(y_pred, axis=1)
                    y_true = np.argmax(Y_target_val, axis=1)
                    #print(y_true.shape)
                    accval = (sum(
                        (y_predd == y_true)) / y_predd.shape[0] * 100)
                    if (accval > max_accval):
                        max_accval = accval

                    print('valacc=%.2f' % (accval))
                    print('max_accval=%.2f' % (max_accval))

                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))
    except KeyboardInterrupt:
        pass
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    X_real_train = data_utils.load_image_dataset(dset, img_dim,
                                                 image_dim_ordering)

    # Get the full real image dimension
    img_dim = X_real_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    if generator == "upsampling":
        generator_model = models.generator_upsampling(noise_dim,
                                                      img_dim,
                                                      bn_mode,
                                                      dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim,
                                                  img_dim,
                                                  bn_mode,
                                                  batch_size,
                                                  dset=dset)
    discriminator_model = models.discriminator(img_dim, bn_mode, dset=dset)
    DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim,
                               img_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
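    # A minimal sketch of the custom `models.wasserstein` loss assumed above
    # (the usual Keras form; not confirmed by this snippet):
    #
    #     import keras.backend as K
    #
    #     def wasserstein(y_true, y_pred):
    #         return K.mean(y_true * y_pred)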

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            if gen_iterations < 25 or gen_iterations % 500 == 0:
                disc_iterations = 100
            else:
                disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
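                # (clipping to [clamp_lower, clamp_upper] enforces the Lipschitz
                # constraint of the original WGAN formulation)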
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [
                        np.clip(w, clamp_lower, clamp_upper) for w in weights
                    ]
                    l.set_weights(weights)

                X_real_batch = next(
                    data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    noise_dim,
                    noise_scale=noise_scale)

                # Update the discriminator
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(
                    X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                  -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            gen_iterations += 1
            batch_counter += 1
            progbar.add(batch_size,
                        values=[("Loss_D", -np.mean(list_disc_loss_real) -
                                 np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)])

            # Save images for visualization ~2 times per epoch
            if batch_counter % (n_batch_per_epoch / 2) == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                batch_size, noise_dim,
                                                image_dim_ordering)

        print('\nEpoch %s/%s, Time: %s' %
              (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, e)
def train(**kwargs):
    """
    Train model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    data_file = kwargs["data_file"]
    nb_neighbors = kwargs["nb_neighbors"]
    model_name = kwargs["model_name"]
    training_mode = kwargs["training_mode"]
    epoch_size = n_batch_per_epoch * batch_size
    img_size = int(os.path.basename(data_file).split("_")[1])

    # Setup directories to save model, architecture etc
    general_utils.setup_logging(model_name)

    # Create a batch generator for the color data
    DataGen = batch_utils.DataGenerator(data_file,
                                        batch_size=batch_size,
                                        dset="training")
    c, h, w = DataGen.get_config()["data_shape"][1:]

    # Load the array of quantized ab value
    q_ab = np.load("../../data/processed/pts_in_hull.npy")
    nb_q = q_ab.shape[0]
    # Fit a NN to q_ab
    nn_finder = nn.NearestNeighbors(n_neighbors=nb_neighbors, algorithm='ball_tree').fit(q_ab)
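    # nn_finder maps each ab color pair to its nb_neighbors nearest bins among the
    # nb_q quantized colors, used to build soft-encoded targets (following the
    # "Colorful Image Colorization" recipe).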

    # Load the color prior factor that encourages rare colors
    prior_factor = np.load("../../data/processed/CelebA_%s_prior_factor.npy" % img_size)

    # Load and rescale data
    if training_mode == "in_memory":
        with h5py.File(data_file, "r") as hf:
            X_train = hf["training_lab_data"][:]

    # Remove possible previous figures to avoid confusion
    for f in glob.glob("../../figures/*.png"):
        os.remove(f)

    try:

        # Create optimizers
        opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load colorizer model
        color_model = models.load(model_name, nb_q, (1, h, w), batch_size)
        color_model.compile(loss='categorical_crossentropy_color', optimizer=opt)

        color_model.summary()
        from keras.utils.visualize_util import plot
        plot(color_model, to_file='../../figures/colorful.png', show_shapes=True, show_layer_names=True)

        # Actual training loop
        for epoch in range(nb_epoch):

            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            # Choose Batch Generation mode
            if training_mode == "in_memory":
                BatchGen = DataGen.gen_batch_in_memory(X_train, nn_finder, nb_q, prior_factor)
            else:
                BatchGen = DataGen.gen_batch(nn_finder, nb_q, prior_factor)

            for batch in BatchGen:

                X_batch_black, X_batch_color, Y_batch = batch
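                # X_batch_black is the Lab lightness channel in [0, 100]; it is
                # divided by 100 below to normalize the input to [0, 1].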

                train_loss = color_model.train_on_batch(X_batch_black / 100., Y_batch)

                batch_counter += 1
                progbar.add(batch_size, values=[("loss", train_loss)])

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (epoch + 1, nb_epoch, time.time() - start))

            # Plot some data with original, b and w and colorized versions side by side
            general_utils.plot_batch(color_model, q_ab, X_batch_black, X_batch_color,
                                     batch_size, h, w, nb_q, epoch)

            # Save weights every 5 epoch
            if epoch % 5 == 0:
                weights_path = os.path.join('../../models/%s/%s_weights_epoch%s.h5' %
                                            (model_name, model_name, epoch))
                color_model.save_weights(weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
Example No. 22
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    image_data_format = kwargs["image_data_format"]
    generator_type = kwargs["generator_type"]
    dset = kwargs["dset"]
    use_identity_image = kwargs["use_identity_image"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    augment_data = kwargs["augment_data"]
    model_name = kwargs["model_name"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    use_mbd = kwargs["use_mbd"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping_prob = kwargs["label_flipping_prob"]
    use_l1_weighted_loss = kwargs["use_l1_weighted_loss"]
    prev_model = kwargs["prev_model"]
    change_model_name_to_prev_model = kwargs["change_model_name_to_prev_model"]
    discriminator_optimizer = kwargs["discriminator_optimizer"]
    n_run_of_gen_for_1_run_of_disc = kwargs["n_run_of_gen_for_1_run_of_disc"]
    load_all_data_at_once = kwargs["load_all_data_at_once"]
    MAX_FRAMES_PER_GIF = kwargs["MAX_FRAMES_PER_GIF"]
    dont_train = kwargs["dont_train"]

    # batch_size = args.batch_size
    # n_batch_per_epoch = args.n_batch_per_epoch
    # nb_epoch = args.nb_epoch
    # save_weights_every_n_epochs = args.save_weights_every_n_epochs
    # generator_type = args.generator_type
    # patch_size = args.patch_size
    # label_smoothing = False
    # label_flipping_prob = False
    # dset = args.dset
    # use_mbd = False

    if dont_train:
        # Get the number of non overlapping patch and the size of input image to the discriminator
        nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)
        generator_model = models.load("generator_unet_%s" % generator_type,
                                      img_dim,
                                      nb_patch,
                                      use_mbd,
                                      batch_size,
                                      model_name)
        generator_model.compile(loss='mae', optimizer='adam')
        return generator_model

    # Check and make the dataset
    # If .h5 file of dset is not present, try making it
    if load_all_data_at_once:
        if not os.path.exists("../../data/processed/%s_data.h5" % dset):
            print("dset %s_data.h5 not present in '../../data/processed'!" % dset)
            if not os.path.exists("../../data/%s/" % dset):
                print("dset folder %s not present in '../../data'!\n\nERROR: Dataset .h5 file not made, and dataset not available in '../../data/'.\n\nQuitting." % dset)
                return
            else:
                if not os.path.exists("../../data/%s/train" % dset) or not os.path.exists("../../data/%s/val" % dset) or not os.path.exists("../../data/%s/test" % dset):
                    print("'train', 'val' or 'test' folders not present in dset folder '../../data/%s'!\n\nERROR: Dataset must contain 'train', 'val' and 'test' folders.\n\nQuitting." % dset)
                    return
                else:
                    print("Making %s dataset" % dset)
                    subprocess.call(['python3', '../data/make_dataset.py', '../../data/%s' % dset, '3'])
                    print("Done!")
    else:
        if not os.path.exists(dset):
            print("dset does not exist! Given:", dset)
            return
        if not os.path.exists(os.path.join(dset, 'train')):
            print("dset does not contain a 'train' dir! Given dset:", dset)
            return
        if not os.path.exists(os.path.join(dset, 'val')):
            print("dset does not contain a 'val' dir! Given dset:", dset)
            return

    epoch_size = n_batch_per_epoch * batch_size

    init_epoch = 0

    if prev_model:
        print('\n\nLoading prev_model from', prev_model, '...\n\n')
        prev_model_latest_gen = sorted(glob.glob(os.path.join('../../models/', prev_model, '*gen*epoch*.h5')))[-1]
        prev_model_latest_disc = sorted(glob.glob(os.path.join('../../models/', prev_model, '*disc*epoch*.h5')))[-1]
        prev_model_latest_DCGAN = sorted(glob.glob(os.path.join('../../models/', prev_model, '*DCGAN*epoch*.h5')))[-1]
        print(prev_model_latest_gen, prev_model_latest_disc, prev_model_latest_DCGAN)
        if change_model_name_to_prev_model:
            # Find prev model name, epoch
            model_name = prev_model_latest_DCGAN.split('models')[-1].split('/')[1]
            init_epoch = int(prev_model_latest_DCGAN.split('epoch')[1][:5]) + 1

    # img_dim = X_target_train.shape[-3:]
    # img_dim = (256, 256, 3)

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)
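    # The PatchGAN discriminator scores nb_patch non-overlapping patches of size
    # patch_size instead of the full image.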

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        if discriminator_optimizer == 'sgd':
            opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        elif discriminator_optimizer == 'adam':
            opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator_type,
                                      img_dim,
                                      nb_patch,
                                      use_mbd,
                                      batch_size,
                                      model_name)

        generator_model.compile(loss='mae', optimizer=opt_dcgan)

        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          img_dim_disc,
                                          nb_patch,
                                          use_mbd,
                                          batch_size,
                                          model_name)

        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   img_dim,
                                   patch_size,
                                   image_data_format)

        if use_l1_weighted_loss:
            loss = [l1_weighted_loss, 'binary_crossentropy']
        else:
            loss = [l1_loss, 'binary_crossentropy']

        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)
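        # pix2pix-style objective: a (possibly weighted) L1 reconstruction term
        # scaled by 10, plus the adversarial binary crossentropy term.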

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        # Load prev_model
        if prev_model:
            generator_model.load_weights(prev_model_latest_gen)
            discriminator_model.load_weights(prev_model_latest_disc)
            DCGAN_model.load_weights(prev_model_latest_DCGAN)

        # Load .h5 data all at once
        print('\n\nLoading data...\n\n')
        check_this_process_memory()

        if load_all_data_at_once:
            X_target_train, X_sketch_train, X_target_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
            check_this_process_memory()
            print('X_target_train: %.4f' % (X_target_train.nbytes/2**30), "GB")
            print('X_sketch_train: %.4f' % (X_sketch_train.nbytes/2**30), "GB")
            print('X_target_val: %.4f' % (X_target_val.nbytes/2**30), "GB")
            print('X_sketch_val: %.4f' % (X_sketch_val.nbytes/2**30), "GB")

            # To generate training data
            X_target_batch_gen_train, X_sketch_batch_gen_train = data_utils.data_generator(X_target_train, X_sketch_train, batch_size, augment_data=augment_data)
            X_target_batch_gen_val, X_sketch_batch_gen_val = data_utils.data_generator(X_target_val, X_sketch_val, batch_size, augment_data=False)

        # Load data from images through an ImageDataGenerator
        else:
            X_batch_gen_train = data_utils.data_generator_from_dir(os.path.join(dset, 'train'), target_size=(img_dim[0], 2*img_dim[1]), batch_size=batch_size)
            X_batch_gen_val = data_utils.data_generator_from_dir(os.path.join(dset, 'val'), target_size=(img_dim[0], 2*img_dim[1]), batch_size=batch_size)

        check_this_process_memory()

        # Setup environment (logging directory etc)
        general_utils.setup_logging(**kwargs)

        # Losses
        disc_losses = []
        gen_total_losses = []
        gen_L1_losses = []
        gen_log_losses = []

        # Start training
        print("\n\nStarting training...\n\n")

        # For each epoch
        for e in range(nb_epoch):
            
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            gen_total_loss_epoch = 0
            gen_L1_loss_epoch = 0
            gen_log_loss_epoch = 0
            start = time.time()
            
            # For each batch
            # for X_target_batch, X_sketch_batch in data_utils.gen_batch(X_target_train, X_sketch_train, batch_size):
            for batch in range(n_batch_per_epoch):
                
                # Create a batch to feed the discriminator model
                if load_all_data_at_once:
                    X_target_batch_train, X_sketch_batch_train = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                else:
                    X_target_batch_train, X_sketch_batch_train = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim,
                                                                                                                   augment_data=augment_data,
                                                                                                                   use_identity_image=use_identity_image)

                X_disc, y_disc = data_utils.get_disc_batch(X_target_batch_train,
                                                           X_sketch_batch_train,
                                                           generator_model,
                                                           batch_counter,
                                                           patch_size,
                                                           image_data_format,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping_prob=label_flipping_prob)
                
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                
                # Create a batch to feed the generator model
                if load_all_data_at_once:
                    X_gen_target, X_gen_sketch = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                else:
                    X_gen_target, X_gen_sketch = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim,
                                                                                                   augment_data=augment_data,
                                                                                                   use_identity_image=use_identity_image)

                y_gen_target = np.zeros((X_gen_target.shape[0], 2), dtype=np.uint8)
                y_gen_target[:, 1] = 1
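                # The discriminator outputs a 2-way softmax per sample; column 1 is
                # used as the "real" class, so the generator (trained through the
                # frozen discriminator below) is pushed to make its outputs be
                # classified as real.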
                
                # Freeze the discriminator
                discriminator_model.trainable = False
                
                # Train generator
                for _ in range(n_run_of_gen_for_1_run_of_disc-1):
                    gen_loss = DCGAN_model.train_on_batch(X_gen_sketch, [X_gen_target, y_gen_target])
                    gen_total_loss_epoch += gen_loss[0]/n_run_of_gen_for_1_run_of_disc
                    gen_L1_loss_epoch += gen_loss[1]/n_run_of_gen_for_1_run_of_disc
                    gen_log_loss_epoch += gen_loss[2]/n_run_of_gen_for_1_run_of_disc
                    if load_all_data_at_once:
                        X_gen_target, X_gen_sketch = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                    else:
                        X_gen_target, X_gen_sketch = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim,
                                                                                                       augment_data=augment_data,
                                                                                                       use_identity_image=use_identity_image)

                gen_loss = DCGAN_model.train_on_batch(X_gen_sketch, [X_gen_target, y_gen_target])
                
                # Add losses
                gen_total_loss_epoch += gen_loss[0]/n_run_of_gen_for_1_run_of_disc
                gen_L1_loss_epoch += gen_loss[1]/n_run_of_gen_for_1_run_of_disc
                gen_log_loss_epoch += gen_loss[2]/n_run_of_gen_for_1_run_of_disc
                
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                
                # Progress
                # progbar.add(batch_size, values=[("D logloss", disc_loss),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G L1", gen_loss[1]),
                #                                 ("G logloss", gen_loss[2])])
                print("Epoch", str(init_epoch+e+1), "batch", str(batch+1), "D_logloss", disc_loss, "G_tot", gen_loss[0], "G_L1", gen_loss[1], "G_log", gen_loss[2])
            
            gen_total_loss = gen_total_loss_epoch/n_batch_per_epoch
            gen_L1_loss = gen_L1_loss_epoch/n_batch_per_epoch
            gen_log_loss = gen_log_loss_epoch/n_batch_per_epoch
            disc_losses.append(disc_loss)
            gen_total_losses.append(gen_total_loss)
            gen_L1_losses.append(gen_L1_loss)
            gen_log_losses.append(gen_log_loss)
            
            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_target_batch_train, X_sketch_batch_train, generator_model, batch_size, image_data_format,
                                                model_name, "training", init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Get new images for validation
                if load_all_data_at_once:
                    X_target_batch_val, X_sketch_batch_val = next(X_target_batch_gen_val), next(X_sketch_batch_gen_val)
                else:
                    X_target_batch_val, X_sketch_batch_val = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_val, img_dim=img_dim, augment_data=False, use_identity_image=use_identity_image)
                # Predict and validate
                data_utils.plot_generated_batch(X_target_batch_val, X_sketch_batch_val, generator_model, batch_size, image_data_format,
                                                model_name, "validation", init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Plot losses
                data_utils.plot_losses(disc_losses, gen_total_losses, gen_L1_losses, gen_log_losses, model_name, init_epoch)
            
            # Save weights
            if (e + 1) % save_weights_every_n_epochs == 0:
                # Delete all but the last n weights
                purge_weights(save_only_last_n_weights, model_name)
                # Save gen weights
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", gen_weights_path)
                generator_model.save_weights(gen_weights_path, overwrite=True)
                # Save disc weights
                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", disc_weights_path)
                discriminator_model.save_weights(disc_weights_path, overwrite=True)
                # Save DCGAN weights
                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", DCGAN_weights_path)
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

            check_this_process_memory()
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d} END, Time taken: {3:.4f} seconds'.format(datetime.datetime.now(), init_epoch + e + 1, init_epoch + nb_epoch, time.time() - start))
            print('------------------------------------------------------------------------------------')

    except KeyboardInterrupt:
        pass

    # SAVE THE MODEL

    try:
        # Save the model as it is, so that it can be loaded using -
        # ```from keras.models import load_model; gen = load_model('generator_latest.h5')```
        gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
        print("Saving", gen_weights_path)
        generator_model.save(gen_weights_path, overwrite=True)
    
        # Save model as json string
        generator_model_json_string = generator_model.to_json()
        print("Saving", '../../models/%s/generator_latest.txt' % model_name)
        with open('../../models/%s/generator_latest.txt' % model_name, 'w') as outfile:
            outfile.write(generator_model_json_string)
    
        # Save model as json
        generator_model_json_data = json.loads(generator_model_json_string)
        print("Saving", '../../models/%s/generator_latest.json' % model_name)
        with open('../../models/%s/generator_latest.json' % model_name, 'w') as outfile:
            json.dump(generator_model_json_data, outfile)

    except Exception:
        print("Could not save the generator model:", sys.exc_info()[0])

    print("Done.")

    return generator_model
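

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original source): l1_loss and
# l1_weighted_loss are referenced when compiling DCGAN_model above but are
# defined elsewhere in the repository. A minimal Keras-style L1 term,
# consistent with the 'mae' generator objective, could look like the helper
# below; l1_loss_sketch is a hypothetical name chosen to avoid clashing with
# the real functions.
# ---------------------------------------------------------------------------
from keras import backend as K


def l1_loss_sketch(y_true, y_pred):
    # Mean absolute error between the target and the generated image.
    return K.mean(K.abs(y_pred - y_true))

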
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    patch_size = kwargs["patch_size"]
    image_data_format = kwargs["image_data_format"]
    generator_type = kwargs["generator_type"]
    dset = kwargs["dset"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    use_mbd = kwargs["use_mbd"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping_prob = kwargs["label_flipping_prob"]
    use_l1_weighted_loss = kwargs["use_l1_weighted_loss"]
    prev_model = kwargs["prev_model"]
    discriminator_optimizer = kwargs["discriminator_optimizer"]
    n_run_of_gen_for_1_run_of_disc = kwargs["n_run_of_gen_for_1_run_of_disc"]
    MAX_FRAMES_PER_GIF = kwargs["MAX_FRAMES_PER_GIF"]

    # batch_size = args.batch_size
    # n_batch_per_epoch = args.n_batch_per_epoch
    # nb_epoch = args.nb_epoch
    # save_weights_every_n_epochs = args.save_weights_every_n_epochs
    # generator_type = args.generator_type
    # patch_size = args.patch_size
    # label_smoothing = False
    # label_flipping_prob = False
    # dset = args.dset
    # use_mbd = False

    # Check and make the dataset
    # If .h5 file of dset is not present, try making it
    if not os.path.exists("../../data/processed/%s_data.h5" % dset):
        print("dset %s_data.h5 not present in '../../data/processed'!" % dset)
        if not os.path.exists("../../data/%s/" % dset):
            print(
                "dset folder %s not present in '../../data'!\n\nERROR: Dataset .h5 file not made, and dataset not available in '../../data/'.\n\nQuitting."
                % dset)
            return
        else:
            if not os.path.exists(
                    "../../data/%s/train" % dset) or not os.path.exists(
                        "../../data/%s/val" % dset) or not os.path.exists(
                            "../../data/%s/test" % dset):
                print(
                    "'train', 'val' or 'test' folders not present in dset folder '../../data/%s'!\n\nERROR: Dataset must contain 'train', 'val' and 'test' folders.\n\nQuitting."
                    % dset)
                return
            else:
                print("Making %s dataset" % dset)
                subprocess.call([
                    'python3', '../data/make_dataset.py',
                    '../../data/%s' % dset, '3'
                ])
                print("Done!")

    epoch_size = n_batch_per_epoch * batch_size

    init_epoch = 0

    if prev_model:
        print('\n\nLoading prev_model from', prev_model, '...\n\n')
        prev_model_latest_gen = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*gen*.h5')))[-1]
        prev_model_latest_disc = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*disc*.h5')))[-1]
        prev_model_latest_DCGAN = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*DCGAN*.h5')))[-1]
        # Find prev model name, epoch
        model_name = prev_model_latest_DCGAN.split('models')[-1].split('/')[1]
        init_epoch = int(prev_model_latest_DCGAN.split('epoch')[1][:5]) + 1

    # Setup environment (logging directory etc); if a prev_model was given, model_name was set to its name above
    general_utils.setup_logging(model_name)

    # img_dim = X_full_train.shape[-3:]
    img_dim = (256, 256, 3)

    # Get the number of non-overlapping patches and the input image size for the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        if discriminator_optimizer == 'sgd':
            opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        elif discriminator_optimizer == 'adam':
            opt_discriminator = Adam(lr=1E-3,
                                     beta_1=0.9,
                                     beta_2=0.999,
                                     epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator_type,
                                      img_dim, nb_patch, use_mbd, batch_size,
                                      model_name)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)

        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, use_mbd, batch_size,
                                          model_name)

        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        if use_l1_weighted_loss:
            loss = [l1_weighted_loss, 'binary_crossentropy']
        else:
            loss = [l1_loss, 'binary_crossentropy']

        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        # Load prev_model
        if prev_model:
            generator_model.load_weights(prev_model_latest_gen)
            discriminator_model.load_weights(prev_model_latest_disc)
            DCGAN_model.load_weights(prev_model_latest_DCGAN)

        # Load and rescale data
        print('\n\nLoading data...\n\n')
        X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
            dset, image_data_format)
        check_this_process_memory()
        print('X_full_train: %.4f' % (X_full_train.nbytes / 2**30), "GB")
        print('X_sketch_train: %.4f' % (X_sketch_train.nbytes / 2**30), "GB")
        print('X_full_val: %.4f' % (X_full_val.nbytes / 2**30), "GB")
        print('X_sketch_val: %.4f' % (X_sketch_val.nbytes / 2**30), "GB")

        # Losses
        disc_losses = []
        gen_total_losses = []
        gen_L1_losses = []
        gen_log_losses = []

        # Start training
        print("\n\nStarting training\n\n")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            gen_total_loss_epoch = 0
            gen_L1_loss_epoch = 0
            gen_log_loss_epoch = 0
            start = time.time()
            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch,
                    X_sketch_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    image_data_format,
                    label_smoothing=label_smoothing,
                    label_flipping_prob=label_flipping_prob)
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train,
                                         batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1
                # Freeze the discriminator
                discriminator_model.trainable = False
                # Train generator
                for _ in range(n_run_of_gen_for_1_run_of_disc - 1):
                    gen_loss = DCGAN_model.train_on_batch(
                        X_gen, [X_gen_target, y_gen])
                    gen_total_loss_epoch += gen_loss[
                        0] / n_run_of_gen_for_1_run_of_disc
                    gen_L1_loss_epoch += gen_loss[
                        1] / n_run_of_gen_for_1_run_of_disc
                    gen_log_loss_epoch += gen_loss[
                        2] / n_run_of_gen_for_1_run_of_disc
                    X_gen_target, X_gen = next(
                        data_utils.gen_batch(X_full_train, X_sketch_train,
                                             batch_size))
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Add losses
                gen_total_loss_epoch += gen_loss[
                    0] / n_run_of_gen_for_1_run_of_disc
                gen_L1_loss_epoch += gen_loss[
                    1] / n_run_of_gen_for_1_run_of_disc
                gen_log_loss_epoch += gen_loss[
                    2] / n_run_of_gen_for_1_run_of_disc
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                # Progress
                # progbar.add(batch_size, values=[("D logloss", disc_loss),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G L1", gen_loss[1]),
                #                                 ("G logloss", gen_loss[2])])
                print("Epoch", str(init_epoch + e + 1), "batch",
                      str(batch_counter + 1), "D_logloss", disc_loss, "G_tot",
                      gen_loss[0], "G_L1", gen_loss[1], "G_log", gen_loss[2])
                batch_counter += 1
                if batch_counter >= n_batch_per_epoch:
                    break
            gen_total_loss = gen_total_loss_epoch / n_batch_per_epoch
            gen_L1_loss = gen_L1_loss_epoch / n_batch_per_epoch
            gen_log_loss = gen_log_loss_epoch / n_batch_per_epoch
            disc_losses.append(disc_loss)
            gen_total_losses.append(gen_total_loss)
            gen_L1_losses.append(gen_L1_loss)
            gen_log_losses.append(gen_log_loss)
            check_this_process_memory()
            print('Epoch %s/%s, Time: %.4f' % (init_epoch + e + 1, init_epoch +
                                               nb_epoch, time.time() - start))
            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_full_batch, X_sketch_batch,
                                                generator_model, batch_size,
                                                image_data_format, model_name,
                                                "training", init_epoch + e + 1,
                                                MAX_FRAMES_PER_GIF)
                # Get new images from validation
                X_full_batch, X_sketch_batch = next(
                    data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                data_utils.plot_generated_batch(X_full_batch, X_sketch_batch,
                                                generator_model, batch_size,
                                                image_data_format, model_name,
                                                "validation",
                                                init_epoch + e + 1,
                                                MAX_FRAMES_PER_GIF)
                # Plot losses
                data_utils.plot_losses(disc_losses, gen_total_losses,
                                       gen_L1_losses, gen_log_losses,
                                       model_name, init_epoch)
            # Save weights
            if (e + 1) % save_weights_every_n_epochs == 0:
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                generator_model.save_weights(gen_weights_path, overwrite=True)
                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)
                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
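

# ---------------------------------------------------------------------------
# Hedged usage sketch (not from the original source): the keyword names below
# mirror the kwargs rolled out at the top of train(); the values are purely
# illustrative assumptions and would normally come from the experiment's
# argument parser.
# ---------------------------------------------------------------------------
example_train_params = {
    "patch_size": [64, 64],
    "image_data_format": "channels_last",
    "generator_type": "upsampling",
    "dset": "facades",
    "batch_size": 4,
    "n_batch_per_epoch": 100,
    "nb_epoch": 400,
    "model_name": "pix2pix_example",
    "save_weights_every_n_epochs": 10,
    "visualize_images_every_n_epochs": 10,
    "use_mbd": False,
    "use_label_smoothing": False,
    "label_flipping_prob": 0.0,
    "use_l1_weighted_loss": False,
    "prev_model": None,
    "discriminator_optimizer": "adam",
    "n_run_of_gen_for_1_run_of_disc": 1,
    "MAX_FRAMES_PER_GIF": 100,
}
# train(**example_train_params)

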
def train(cat_dim,
          noise_dim,
          batch_size,
          n_batch_per_epoch,
          nb_epoch,
          dset="mnist"):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    general_utils.setup_logging("IG")
    # Load and rescale data
    if dset == "mnist":
        print("loading mnist data")
        X_real_train, Y_real_train, X_real_test, Y_real_test = data_utils.load_mnist(
        )
        # pick 1000 sample for testing
        # X_real_test = X_real_test[-1000:]
        # Y_real_test = Y_real_test[-1000:]

    img_dim = X_real_train.shape[-3:]
    epoch_size = n_batch_per_epoch * batch_size

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=2E-4,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_deconv",
                                      cat_dim,
                                      noise_dim,
                                      img_dim,
                                      batch_size,
                                      dset=dset)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          cat_dim,
                                          noise_dim,
                                          img_dim,
                                          batch_size,
                                          dset=dset)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        # Stop the discriminator from learning while the generator is learning
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   cat_dim, noise_dim)

        list_losses = ['binary_crossentropy', 'categorical_crossentropy']
        list_weights = [1, 1]
        DCGAN_model.compile(loss=list_losses,
                            loss_weights=list_weights,
                            optimizer=opt_dcgan)

        # Multiple discriminator losses
        # allow the discriminator to learn again
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses,
                                    loss_weights=list_weights,
                                    optimizer=opt_discriminator)
        # Start training
        print("Start training")
        for e in range(nb_epoch + 1):
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()
            print("Epoch: {}".format(e))
            for X_real_batch, Y_real_batch in zip(
                    data_utils.gen_batch(X_real_train, batch_size),
                    data_utils.gen_batch(Y_real_train, batch_size)):

                # Create a batch to feed the discriminator model
                X_disc_fake, y_disc_fake, noise_sample = data_utils.get_disc_batch(
                    X_real_batch,
                    Y_real_batch,
                    generator_model,
                    batch_size,
                    cat_dim,
                    noise_dim,
                    type="fake")
                X_disc_real, y_disc_real = data_utils.get_disc_batch(
                    X_real_batch,
                    Y_real_batch,
                    generator_model,
                    batch_size,
                    cat_dim,
                    noise_dim,
                    type="real")

                # Update the discriminator
                disc_loss_fake = discriminator_model.train_on_batch(
                    X_disc_fake, [y_disc_fake, Y_real_batch])
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, [y_disc_real, Y_real_batch])
                disc_loss = disc_loss_fake + disc_loss_real
                # Create a batch to feed the generator model
                # X_noise, y_gen = data_utils.get_gen_batch(batch_size, cat_dim, noise_dim)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(
                    [Y_real_batch, noise_sample], [y_disc_real, Y_real_batch])
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                # training validation
                p_real_batch, p_Y_batch = discriminator_model.predict(
                    X_real_batch, batch_size=batch_size)
                acc_train = data_utils.accuracy(p_Y_batch, Y_real_batch)
                batch_counter += 1
                # progbar.add(batch_size, values=[("D tot", disc_loss[0]),
                #                                 ("D cat", disc_loss[2]),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G cat", gen_loss[2]),
                #                                 ("P Real:", p_real_batch),
                #                                 ("Q acc", acc_train)])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch // 2) == 0 and e % 10 == 0:
                    data_utils.plot_generated_batch(X_real_batch,
                                                    generator_model,
                                                    batch_size, cat_dim,
                                                    noise_dim, e)
                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))
            _, p_Y_test = discriminator_model.predict(
                X_real_test, batch_size=X_real_test.shape[0])
            acc_test = data_utils.accuracy(p_Y_test, Y_real_test)
            print("Epoch: {} Accuracy: {}".format(e + 1, acc_test))
            if e % 1000 == 0:
                gen_weights_path = os.path.join(
                    '../../models/IG/gen_weights.h5')
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/IG/disc_weights.h5')
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/IG/DCGAN_weights.h5')
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
def eval(**kwargs):

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    cont_dim = (kwargs["cont_dim"],)
    cat_dim = (kwargs["cat_dim"],)
    noise_dim = (kwargs["noise_dim"],)
    bn_mode = kwargs["bn_mode"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    eval_epoch = kwargs["eval_epoch"]

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and rescale data
    if dset == "RGZ":
        X_real_train = data_utils.load_RGZ(img_dim, image_data_format)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    img_dim = X_real_train.shape[-3:]

    # Load generator model
    generator_model = models.load("generator_%s" % generator,
                                  cat_dim,
                                  cont_dim,
                                  noise_dim,
                                  img_dim,
                                  bn_mode,
                                  batch_size,
                                  dset=dset)

    # Load colorization model
    generator_model.load_weights("../../models/%s/gen_weights_epoch%05d.h5" %
                                 (model_name, eval_epoch))

    X_plot = []
    # Vary the categorical variable
    for i in range(cat_dim[0]):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = data_utils.sample_noise(noise_scale, batch_size, cont_dim)
        X_cont = np.repeat(X_cont[:1, :], batch_size, axis=0)  # fix continuous noise
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, i] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)

        if image_data_format == "channels_first":
            X_gen = X_gen.transpose(0,2,3,1)

        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(8,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying categorical factor", fontsize=28, labelpad=60)

    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig(os.path.join("../../figures", model_name, "varying_categorical.png"))
    plt.clf()
    plt.close()

    # Vary the continuous variables
    X_plot = []
    # First get the extent of the noise sampling
    x = np.ravel(data_utils.sample_noise(noise_scale, batch_size * 20000, cont_dim))
    # Define interpolation points
    x = np.linspace(x.min(), x.max(), num=batch_size)
    for i in range(batch_size):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = np.concatenate([np.array([x[i], x[j]]).reshape(1, -1) for j in range(batch_size)], axis=0)
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, 1] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)
        if image_data_format == "channels_first":
            X_gen = X_gen.transpose(0,2,3,1)
        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(10,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying continuous factor 1", fontsize=28, labelpad=60)
    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.xlabel("Varying continuous factor 2", fontsize=28, labelpad=60)
    plt.annotate('', xy=(1, -0.05), xycoords='axes fraction', xytext=(0, -0.05),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig(os.path.join("../../figures", model_name, "varying_continuous.png"))
    plt.clf()
    plt.close()
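

# ---------------------------------------------------------------------------
# Hedged usage sketch (not from the original source): illustrative values for
# the kwargs unpacked at the top of eval(); the actual values would normally
# come from the experiment's argument parser.
# ---------------------------------------------------------------------------
example_eval_params = {
    "batch_size": 32,
    "generator": "upsampling",
    "model_name": "IG_mnist",
    "image_data_format": "channels_last",
    "img_dim": 28,
    "cont_dim": 2,
    "cat_dim": 10,
    "noise_dim": 64,
    "bn_mode": 2,
    "noise_scale": 0.5,
    "dset": "mnist",
    "eval_epoch": 50,
}
# eval(**example_eval_params)

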
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    X_real_train = data_utils.load_image_dataset(dset, img_dim, image_dim_ordering)

    # Get the full real image dimension
    img_dim = X_real_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    if generator == "upsampling":
        generator_model = models.generator_upsampling(noise_dim, img_dim, bn_mode, dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim, img_dim, bn_mode, batch_size, dset=dset)
    discriminator_model = models.discriminator(img_dim, bn_mode)
    DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim, img_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            if gen_iterations < 25 or gen_iterations % 500 == 0:
                disc_iterations = 100
            else:
                disc_iterations = kwargs["disc_iterations"]
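            # This schedule follows the original WGAN recipe: the critic is trained
            # for many more iterations (100) at the very start and then periodically
            # every 500 generator updates, so that it stays close to optimal before
            # each generator step.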

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                    l.set_weights(weights)

                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    noise_scale=noise_scale)

                # Update the discriminator
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            gen_iterations += 1
            batch_counter += 1
            progbar.add(batch_size, values=[("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)),
                                            ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                            ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                            ("Loss_G", -gen_loss)])

            # Save images for visualization ~2 times per epoch
            if batch_counter % (n_batch_per_epoch // 2) == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                batch_size, noise_dim, image_dim_ordering)

        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e)
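

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original source): models.wasserstein is used
# above to compile both the critic and the DCGAN model, but its definition is
# not shown in this excerpt. A common Keras formulation consistent with the
# -1 (real) / +1 (generated) targets used in the training loop would be the
# helper below; wasserstein_loss_sketch is a hypothetical name.
# ---------------------------------------------------------------------------
from keras import backend as K


def wasserstein_loss_sketch(y_true, y_pred):
    # Minimising mean(y_true * y_pred) raises the critic score on real samples
    # (labelled -1) and lowers it on generated samples (labelled +1), which,
    # together with weight clipping, approximates the Wasserstein-1 objective.
    return K.mean(y_true * y_pred)

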
def train(model_name, **kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: model_name (str, keras model name)
          **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    nb_classes = kwargs["nb_classes"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    prob = kwargs["prob"]
    do_plot = kwargs["do_plot"]
    data_file = kwargs["data_file"]
    semi_super_file = kwargs["semi_super_file"]
    pretr_weights_file = kwargs["pretr_weights_file"]
    normalisation_style = kwargs["normalisation_style"]
    objective = kwargs["objective"]
    experiment = kwargs["experiment"]
    list_folds = kwargs["list_folds"]

    # Setup environment (logging directory etc)
    general_utils.setup_logging(experiment)

    # Compile model.
    # opt = RMSprop(lr=5E-6, rho=0.9, epsilon=1e-06)
    opt = SGD(lr=5e-4, decay=1e-6, momentum=0.9, nesterov=True)
    # opt = Adam(lr=1E-5, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Batch generator
    DataAug = batch_utils.AugDataGenerator(data_file,
                                           batch_size=batch_size,
                                           prob=prob,
                                           dset="train",
                                           maxproc=4,
                                           num_cached=60,
                                           random_augm=False,
                                           hdf5_file_semi=semi_super_file)
    DataAug.add_transform("h_flip")
    # DataAug.add_transform("v_flip")
    # DataAug.add_transform("fixed_rot", angle=40)
    DataAug.add_transform("random_rot", angle=40)
    # DataAug.add_transform("fixed_tr", tr_x=40, tr_y=40)
    DataAug.add_transform("random_tr", tr_x=40, tr_y=40)
    # DataAug.add_transform("fixed_blur", kernel_size=5)
    DataAug.add_transform("random_blur", kernel_size=5)
    # DataAug.add_transform("fixed_erode", kernel_size=4)
    DataAug.add_transform("random_erode", kernel_size=3)
    # DataAug.add_transform("fixed_dilate", kernel_size=4)
    DataAug.add_transform("random_dilate", kernel_size=3)
    # DataAug.add_transform("fixed_crop", pos_x=10, pos_y=10, crop_size_x=200, crop_size_y=200)
    DataAug.add_transform("random_crop", min_crop_size=140, max_crop_size=160)
    # DataAug.add_transform("hist_equal")
    # DataAug.add_transform("random_occlusion", occ_size_x=100, occ_size_y=100)

    epoch_size = n_batch_per_epoch * batch_size

    general_utils.pretty_print("Load all data...")

    with h5py.File(data_file, "r") as hf:
        X = hf["train_data"][:, :, :, :]
        y = hf["train_label"][:].astype(np.uint8)
        y = np_utils.to_categorical(y,
                                    nb_classes=nb_classes)  # Format for keras

        try:
            for fold in list_folds:

                min_valid_loss = 100

                # Save losses
                list_train_loss = []
                list_valid_loss = []

                # Load valid data in memory for fast error evaluation
                idx_valid = hf["valid_fold%s" % fold][:]
                idx_train = hf["train_fold%s" % fold][:]
                X_valid = X[idx_valid]
                y_valid = y[idx_valid]

                # Normalise
                X_valid = normalisation(X_valid, normalisation_style)

                # Compile model
                general_utils.pretty_print("Compiling...")
                model = models.load(model_name,
                                    nb_classes,
                                    X_valid.shape[-3:],
                                    pretr_weights_file=pretr_weights_file)
                model.compile(optimizer=opt, loss=objective)

                # Save architecture
                json_string = model.to_json()
                with open(os.path.join(data_dir, '%s_archi.json' % model.name),
                          'w') as f:
                    f.write(json_string)

                for e in range(nb_epoch):
                    # Initialize progbar and batch counter
                    progbar = generic_utils.Progbar(epoch_size)
                    batch_counter = 1
                    l_train_loss = []
                    start = time.time()

                    for X_train, y_train in DataAug.gen_batch_inmemory(
                            X, y, idx_train=idx_train):

                        if do_plot:
                            general_utils.plot_batch(X_train,
                                                     np.argmax(y_train, 1),
                                                     batch_size)

                        # Normalise
                        X_train = normalisation(X_train, normalisation_style)

                        train_loss = model.train_on_batch(X_train, y_train)
                        l_train_loss.append(train_loss)
                        batch_counter += 1
                        progbar.add(batch_size,
                                    values=[("train loss", train_loss)])
                        if batch_counter >= n_batch_per_epoch:
                            break
                    print("")
                    print('Epoch %s/%s, Time: %s' %
                          (e + 1, nb_epoch, time.time() - start))
                    y_valid_pred = model.predict(X_valid,
                                                 verbose=0,
                                                 batch_size=16)
                    train_loss = float(np.mean(
                        l_train_loss))  # use float to make it json saveable
                    valid_loss = log_loss(y_valid, y_valid_pred)
                    print("Train loss:", train_loss, "valid loss:", valid_loss)
                    list_train_loss.append(train_loss)
                    list_valid_loss.append(valid_loss)

                    # Record experimental data in a dict
                    d_log = {}
                    d_log["fold"] = fold
                    d_log["nb_classes"] = nb_classes
                    d_log["batch_size"] = batch_size
                    d_log["n_batch_per_epoch"] = n_batch_per_epoch
                    d_log["nb_epoch"] = nb_epoch
                    d_log["epoch_size"] = epoch_size
                    d_log["prob"] = prob
                    d_log["optimizer"] = opt.get_config()
                    d_log["augmentator_config"] = DataAug.get_config()
                    d_log["train_loss"] = list_train_loss
                    d_log["valid_loss"] = list_valid_loss

                    json_file = os.path.join(
                        exp_dir, 'experiment_log_fold%s.json' % fold)
                    general_utils.save_exp_log(json_file, d_log)

                    # Only save the best epoch
                    if valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
                        trained_weights_path = os.path.join(
                            exp_dir,
                            '%s_weights_fold%s.h5' % (model.name, fold))
                        model.save_weights(trained_weights_path,
                                           overwrite=True)

        except KeyboardInterrupt:
            pass
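

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original source): the normalisation() helper
# used above is defined elsewhere in the repository, and the set of supported
# normalisation_style values is not visible here. The helper below is a purely
# hypothetical stand-in illustrating the kind of per-style scaling it might
# apply.
# ---------------------------------------------------------------------------
import numpy as np


def normalisation_sketch(X, normalisation_style):
    # Hypothetical style names, for illustration only.
    if normalisation_style == "divide_255":
        return X.astype("float32") / 255.0
    if normalisation_style == "standardise":
        X = X.astype("float32")
        return (X - X.mean()) / (X.std() + 1e-8)
    return X

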
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    pretrained_model_path = kwargs["pretrained_model_path"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
    img_dim = X_full_train.shape[-3:]


    # Get the number of non-overlapping patches and the input image size for the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        load_pretrained = False
        if pretrained_model_path:
            load_pretrained = True

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim,
                                      nb_patch,
                                      bn_mode,
                                      use_mbd,
                                      batch_size,
                                      load_pretrained)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          img_dim_disc,
                                          nb_patch,
                                          bn_mode,
                                          use_mbd,
                                          batch_size,
                                          load_pretrained)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   img_dim,
                                   patch_size,
                                   image_data_format)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch,
                                                           X_sketch_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           patch_size,
                                                           image_data_format,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch // 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_data_format, "training")
                    X_full_batch, X_sketch_batch = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_data_format, "validation")

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_data_format = kwargs["image_data_format"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    model_name = kwargs["model_name"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and normalize data
    X_real_train, X_batch_gen = data_utils.load_image_dataset(
        dset, img_dim, image_data_format, batch_size)

    # Get the full real image dimension
    img_dim = X_real_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    if generator == "upsampling":
        generator_model = models.generator_upsampling(noise_dim,
                                                      img_dim,
                                                      dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim,
                                                  img_dim,
                                                  batch_size,
                                                  dset=dset)
    discriminator_model = models.discriminator(img_dim)
    DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim,
                               img_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
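
    # Note: Keras takes the trainable flag into account at compile time, so
    # compiling DCGAN_model while the discriminator is frozen means only the
    # generator weights are updated through DCGAN_model, while the separately
    # compiled discriminator_model still updates its own weights.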

    # Global iteration counter for generator updates
    gen_iterations = 0

    disc_losses = []
    disc_losses_real = []
    disc_losses_gen = []
    gen_losses = []

    #################
    # Start training
    #################

    try:

        for e in range(nb_epoch):

            print('--------------------------------------------')
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d}\n'.format(
                datetime.datetime.now(), e + 1, nb_epoch))

            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            start = time.time()

            disc_loss_batch = 0
            disc_loss_real_batch = 0
            disc_loss_gen_batch = 0
            gen_loss_batch = 0

            for batch_counter in range(n_batch_per_epoch):

                if gen_iterations < 25 or gen_iterations % 500 == 0:
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]
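
                # This schedule follows the WGAN training heuristic: run many
                # more critic updates (100) during the first 25 generator
                # iterations and every 500th one thereafter, so the critic
                # stays close to optimal; otherwise use the configured
                # kwargs["disc_iterations"] (the WGAN paper uses 5).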

                ###################################
                # 1) Train the critic / discriminator
                ###################################
                list_disc_loss_real = []
                list_disc_loss_gen = []
                for disc_it in range(disc_iterations):

                    # Clip discriminator weights
                    for l in discriminator_model.layers:
                        weights = l.get_weights()
                        weights = [
                            np.clip(w, clamp_lower, clamp_upper)
                            for w in weights
                        ]
                        l.set_weights(weights)
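                    # Clipping keeps the critic weights inside a small box,
                    # which approximately enforces the Lipschitz constraint
                    # that the Wasserstein objective relies on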

                    X_real_batch = next(
                        data_utils.gen_batch(X_real_train, X_batch_gen,
                                             batch_size))

                    # Create a batch to feed the discriminator model
                    X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                        X_real_batch,
                        generator_model,
                        batch_counter,
                        batch_size,
                        noise_dim,
                        noise_scale=noise_scale)

                    # Update the discriminator
                    disc_loss_real = discriminator_model.train_on_batch(
                        X_disc_real, -np.ones(X_disc_real.shape[0]))
                    disc_loss_gen = discriminator_model.train_on_batch(
                        X_disc_gen, np.ones(X_disc_gen.shape[0]))
                    list_disc_loss_real.append(disc_loss_real)
                    list_disc_loss_gen.append(disc_loss_gen)

                #######################
                # 2) Train the generator
                #######################
                X_gen = data_utils.sample_noise(noise_scale, batch_size,
                                                noise_dim)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      -np.ones(X_gen.shape[0]))
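                # Targets of -1: assuming a Wasserstein loss of the form
                # mean(y_true * y_pred), this update raises the critic score
                # of generated samples, i.e. trains the generator to make
                # them look real to the critic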
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                gen_iterations += 1

                disc_loss_batch += -np.mean(list_disc_loss_real) - np.mean(
                    list_disc_loss_gen)
                disc_loss_real_batch += -np.mean(list_disc_loss_real)
                disc_loss_gen_batch += np.mean(list_disc_loss_gen)
                gen_loss_batch += -gen_loss

                progbar.add(batch_size,
                            values=[
                                ("Loss_D", -np.mean(list_disc_loss_real) -
                                 np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)
                            ])

                # # Save images for visualization ~2 times per epoch
                # if batch_counter % (n_batch_per_epoch / 2) == 0:
                #     data_utils.plot_generated_batch(X_real_batch, generator_model,
                #                                     batch_size, noise_dim, image_data_format)

            disc_losses.append(disc_loss_batch / n_batch_per_epoch)
            disc_losses_real.append(disc_loss_real_batch / n_batch_per_epoch)
            disc_losses_gen.append(disc_loss_gen_batch / n_batch_per_epoch)
            gen_losses.append(gen_loss_batch / n_batch_per_epoch)

            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                e, batch_size, noise_dim,
                                                image_data_format, model_name)
                data_utils.plot_losses(disc_losses, disc_losses_real,
                                       disc_losses_gen, gen_losses, model_name)

            # Save model weights (by default, every 5 epochs)
            data_utils.save_model_weights(generator_model, discriminator_model,
                                          DCGAN_model, e,
                                          save_weights_every_n_epochs,
                                          save_only_last_n_weights, model_name)

            end = time.time()
            print('\nEpoch %s/%s END, Time: %s' %
                  (e + 1, nb_epoch, end - start))
            start = end

    except KeyboardInterrupt:
        pass

    gen_weights_path = os.path.join('../../models', model_name, 'generator_latest.h5')
    print("Saving", gen_weights_path)
    generator_model.save(gen_weights_path, overwrite=True)
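

#######################
# Usage sketch
#######################
# Illustrative only: every hyperparameter value below (including the dataset
# and model names) is a hypothetical placeholder, not an original experiment
# setting.
if __name__ == '__main__':

    train(generator="deconv",
          dset="celebA",
          img_dim=64,
          nb_epoch=100,
          batch_size=32,
          n_batch_per_epoch=200,
          noise_dim=100,
          noise_scale=0.5,
          disc_iterations=5,
          lr_D=5E-5,
          lr_G=5E-5,
          opt_D="RMSprop",
          opt_G="RMSprop",
          clamp_lower=-0.01,
          clamp_upper=0.01,
          image_data_format="channels_last",
          save_weights_every_n_epochs=5,
          save_only_last_n_weights=3,
          visualize_images_every_n_epochs=1,
          model_name="WGAN_sketch")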