Example #1
def train(args):
  import models
  import numpy as np
  # np.random.seed(1234)

  if args.dataset == 'mnist':
    n_dim, n_out, n_channels = 28, 10, 1
    X_train, y_train, X_val, y_val, _, _ = data.load_mnist()
  elif args.dataset == 'random':
    n_dim, n_out, n_channels = 2, 2, 1
    X_train, y_train = data.load_noise(n=1000, d=n_dim)
    X_val, y_val = X_train, y_train
  else:
    raise ValueError('Invalid dataset name: %s' % args.dataset)

  # set up optimization params
  opt_params = { 'lr' : args.lr, 'c' : args.c, 'n_critic' : args.n_critic }

  # create model
  if args.model == 'dcgan':
    model = models.DCGAN(n_dim=n_dim, n_chan=n_channels, opt_params=opt_params)
  elif args.model == 'wdcgan':
    model = models.WDCGAN(n_dim=n_dim, n_chan=n_channels, opt_params=opt_params)    
  else:
    raise ValueError('Invalid model')
  
  # train model
  model.fit(X_train, X_val, 
            n_epoch=args.epochs, n_batch=args.n_batch,
            logdir=args.logdir)
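
A minimal invocation sketch for the example above. The argparse flags mirror the attributes the code reads from args; the default values are placeholders, and 'c'/'n_critic' are assumed to be WGAN hyperparameters (the weight-clipping bound and the number of critic updates per generator update):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist', choices=['mnist', 'random'])
parser.add_argument('--model', default='dcgan', choices=['dcgan', 'wdcgan'])
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--c', type=float, default=0.01)       # assumed: WGAN clipping bound
parser.add_argument('--n_critic', type=int, default=5)     # assumed: critic steps per generator step
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--n_batch', type=int, default=128)
parser.add_argument('--logdir', default='logs')
train(parser.parse_args())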
Example #2
def train(args):

    #load dataset
    if args.dataset == 'mnist':
        n_dim, n_out, n_channels = 28, 10, 1
        X_train, y_train, X_val, y_val, _, _ = data.load_mnist()

    elif args.dataset == 'random':
        n_dim, n_out, n_channels = 2, 2, 1
        X_train, y_train = data.load_noise(n=1000, d=n_dim)
        X_val, y_val = X_train, y_train

    # extensible: additional datasets can be added here
    elif args.dataset == 'malware_clean_data':
        n_dim, n_channels = 64, 1
        xtrain_mal, ytrain_mal, xtrain_ben, ytrain_ben, xtest_mal, ytest_mal, xtest_ben, ytest_ben = \
            data.load_Malware_clean_ApkToImage()
        if args.same_train_data:
            X_train, y_train, X_val, y_val = xtrain_mal, ytrain_mal, xtrain_ben, ytrain_ben
        else:
            # without this branch X_train and X_val would be undefined below
            raise ValueError('malware_clean_data currently requires same_train_data')
    else:
        raise ValueError('Invalid dataset name: %s' % args.dataset)

    # set up optimization params
    opt_params = {'lr': args.lr, 'c': args.c, 'n_critic': args.n_critic}

    # create model
    if args.model == 'dcgan':
        model = models.DCGAN(n_dim=n_dim,
                             n_chan=n_channels,
                             opt_params=opt_params)
    elif args.model == 'wdcgan':
        model = models.WDCGAN(n_dim=n_dim,
                              n_chan=n_channels,
                              opt_params=opt_params)
    else:
        raise ValueError('Invalid model')

    # train model

    model.fit(X_train,
              y_train,
              X_val,
              y_val,
              n_epoch=args.epochs,
              n_batch=args.n_batch,
              logdir=args.logdir)
Example #3
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    image_data_format = kwargs["image_data_format"]
    generator_type = kwargs["generator_type"]
    dset = kwargs["dset"]
    use_identity_image = kwargs["use_identity_image"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    augment_data = kwargs["augment_data"]
    model_name = kwargs["model_name"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    use_mbd = kwargs["use_mbd"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping_prob = kwargs["label_flipping_prob"]
    use_l1_weighted_loss = kwargs["use_l1_weighted_loss"]
    prev_model = kwargs["prev_model"]
    change_model_name_to_prev_model = kwargs["change_model_name_to_prev_model"]
    discriminator_optimizer = kwargs["discriminator_optimizer"]
    n_run_of_gen_for_1_run_of_disc = kwargs["n_run_of_gen_for_1_run_of_disc"]
    load_all_data_at_once = kwargs["load_all_data_at_once"]
    MAX_FRAMES_PER_GIF = kwargs["MAX_FRAMES_PER_GIF"]
    dont_train = kwargs["dont_train"]

    # batch_size = args.batch_size
    # n_batch_per_epoch = args.n_batch_per_epoch
    # nb_epoch = args.nb_epoch
    # save_weights_every_n_epochs = args.save_weights_every_n_epochs
    # generator_type = args.generator_type
    # patch_size = args.patch_size
    # label_smoothing = False
    # label_flipping_prob = False
    # dset = args.dset
    # use_mbd = False

    if dont_train:
        # Get the number of non overlapping patch and the size of input image to the discriminator
        nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)
        generator_model = models.load("generator_unet_%s" % generator_type,
                                      img_dim,
                                      nb_patch,
                                      use_mbd,
                                      batch_size,
                                      model_name)
        generator_model.compile(loss='mae', optimizer='adam')
        return generator_model

    # Check and make the dataset
    # If .h5 file of dset is not present, try making it
    if load_all_data_at_once:
        if not os.path.exists("../../data/processed/%s_data.h5" % dset):
            print("dset %s_data.h5 not present in '../../data/processed'!" % dset)
            if not os.path.exists("../../data/%s/" % dset):
                print("dset folder %s not present in '../../data'!\n\nERROR: Dataset .h5 file not made, and dataset not available in '../../data/'.\n\nQuitting." % dset)
                return
            else:
                if not os.path.exists("../../data/%s/train" % dset) or not os.path.exists("../../data/%s/val" % dset) or not os.path.exists("../../data/%s/test" % dset):
                    print("'train', 'val' or 'test' folders not present in dset folder '../../data/%s'!\n\nERROR: Dataset must contain 'train', 'val' and 'test' folders.\n\nQuitting." % dset)
                    return
                else:
                    print("Making %s dataset" % dset)
                    subprocess.call(['python3', '../data/make_dataset.py', '../../data/%s' % dset, '3'])
                    print("Done!")
    else:
        if not os.path.exists(dset):
            print("dset does not exist! Given:", dset)
            return
        if not os.path.exists(os.path.join(dset, 'train')):
            print("dset does not contain a 'train' dir! Given dset:", dset)
            return
        if not os.path.exists(os.path.join(dset, 'val')):
            print("dset does not contain a 'val' dir! Given dset:", dset)
            return

    epoch_size = n_batch_per_epoch * batch_size

    init_epoch = 0

    if prev_model:
        print('\n\nLoading prev_model from', prev_model, '...\n\n')
        prev_model_latest_gen = sorted(glob.glob(os.path.join('../../models/', prev_model, '*gen*epoch*.h5')))[-1]
        prev_model_latest_disc = sorted(glob.glob(os.path.join('../../models/', prev_model, '*disc*epoch*.h5')))[-1]
        prev_model_latest_DCGAN = sorted(glob.glob(os.path.join('../../models/', prev_model, '*DCGAN*epoch*.h5')))[-1]
        print(prev_model_latest_gen, prev_model_latest_disc, prev_model_latest_DCGAN)
        if change_model_name_to_prev_model:
            # Find prev model name, epoch
            model_name = prev_model_latest_DCGAN.split('models')[-1].split('/')[1]
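            # the weight filenames embed the epoch as a zero-padded 5-digit number (see the save paths below)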
            init_epoch = int(prev_model_latest_DCGAN.split('epoch')[1][:5]) + 1

    # img_dim = X_target_train.shape[-3:]
    # img_dim = (256, 256, 3)

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        if discriminator_optimizer == 'sgd':
            opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        elif discriminator_optimizer == 'adam':
            opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        else:
            # otherwise opt_discriminator would be undefined below
            raise ValueError('Invalid discriminator_optimizer: %s' % discriminator_optimizer)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator_type,
                                      img_dim,
                                      nb_patch,
                                      use_mbd,
                                      batch_size,
                                      model_name)

        generator_model.compile(loss='mae', optimizer=opt_dcgan)

        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          img_dim_disc,
                                          nb_patch,
                                          use_mbd,
                                          batch_size,
                                          model_name)

        # Freeze the discriminator inside the combined DCGAN model so that generator updates leave its weights unchanged
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   img_dim,
                                   patch_size,
                                   image_data_format)

        if use_l1_weighted_loss:
            loss = [l1_weighted_loss, 'binary_crossentropy']
        else:
            loss = [l1_loss, 'binary_crossentropy']

        loss_weights = [1E1, 1]
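        # the L1 reconstruction term is weighted 10x the adversarial (log-loss) term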
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        # Load prev_model
        if prev_model:
            generator_model.load_weights(prev_model_latest_gen)
            discriminator_model.load_weights(prev_model_latest_disc)
            DCGAN_model.load_weights(prev_model_latest_DCGAN)

        # Load .h5 data all at once
        print('\n\nLoading data...\n\n')
        check_this_process_memory()

        if load_all_data_at_once:
            X_target_train, X_sketch_train, X_target_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
            check_this_process_memory()
            print('X_target_train: %.4f' % (X_target_train.nbytes/2**30), "GB")
            print('X_sketch_train: %.4f' % (X_sketch_train.nbytes/2**30), "GB")
            print('X_target_val: %.4f' % (X_target_val.nbytes/2**30), "GB")
            print('X_sketch_val: %.4f' % (X_sketch_val.nbytes/2**30), "GB")

            # To generate training data
            X_target_batch_gen_train, X_sketch_batch_gen_train = data_utils.data_generator(X_target_train, X_sketch_train, batch_size, augment_data=augment_data)
            X_target_batch_gen_val, X_sketch_batch_gen_val = data_utils.data_generator(X_target_val, X_sketch_val, batch_size, augment_data=False)

        # Load data from images through an ImageDataGenerator
        else:
            X_batch_gen_train = data_utils.data_generator_from_dir(os.path.join(dset, 'train'), target_size=(img_dim[0], 2*img_dim[1]), batch_size=batch_size)
            X_batch_gen_val = data_utils.data_generator_from_dir(os.path.join(dset, 'val'), target_size=(img_dim[0], 2*img_dim[1]), batch_size=batch_size)

        check_this_process_memory()

        # Setup environment (logging directory etc)
        general_utils.setup_logging(**kwargs)

        # Losses
        disc_losses = []
        gen_total_losses = []
        gen_L1_losses = []
        gen_log_losses = []

        # Start training
        print("\n\nStarting training...\n\n")

        # For each epoch
        for e in range(nb_epoch):
            
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            gen_total_loss_epoch = 0
            gen_L1_loss_epoch = 0
            gen_log_loss_epoch = 0
            start = time.time()
            
            # For each batch
            # for X_target_batch, X_sketch_batch in data_utils.gen_batch(X_target_train, X_sketch_train, batch_size):
            for batch in range(n_batch_per_epoch):
                
                # Create a batch to feed the discriminator model
                if load_all_data_at_once:
                    X_target_batch_train, X_sketch_batch_train = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                else:
                    X_target_batch_train, X_sketch_batch_train = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim,
                                                                                                                   augment_data=augment_data,
                                                                                                                   use_identity_image=use_identity_image)

                X_disc, y_disc = data_utils.get_disc_batch(X_target_batch_train,
                                                           X_sketch_batch_train,
                                                           generator_model,
                                                           batch_counter,
                                                           patch_size,
                                                           image_data_format,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping_prob=label_flipping_prob)
                
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                
                # Create a batch to feed the generator model
                if load_all_data_at_once:
                    X_gen_target, X_gen_sketch = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                else:
                    X_gen_target, X_gen_sketch = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim,
                                                                                                   augment_data=augment_data,
                                                                                                   use_identity_image=use_identity_image)

                # Label generated samples as 'real' (one-hot [0, 1]) so the generator is trained to fool the discriminator
                y_gen_target = np.zeros((X_gen_target.shape[0], 2), dtype=np.uint8)
                y_gen_target[:, 1] = 1
                
                # Freeze the discriminator
                discriminator_model.trainable = False
                
                # Train generator
                for _ in range(n_run_of_gen_for_1_run_of_disc-1):
                    gen_loss = DCGAN_model.train_on_batch(X_gen_sketch, [X_gen_target, y_gen_target])
                    gen_total_loss_epoch += gen_loss[0]/n_run_of_gen_for_1_run_of_disc
                    gen_L1_loss_epoch += gen_loss[1]/n_run_of_gen_for_1_run_of_disc
                    gen_log_loss_epoch += gen_loss[2]/n_run_of_gen_for_1_run_of_disc
                    if load_all_data_at_once:
                        X_gen_target, X_gen_sketch = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                    else:
                        X_gen_target, X_gen_sketch = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim,
                                                                                                       augment_data=augment_data,
                                                                                                       use_identity_image=use_identity_image)

                gen_loss = DCGAN_model.train_on_batch(X_gen_sketch, [X_gen_target, y_gen_target])
                
                # Add losses
                gen_total_loss_epoch += gen_loss[0]/n_run_of_gen_for_1_run_of_disc
                gen_L1_loss_epoch += gen_loss[1]/n_run_of_gen_for_1_run_of_disc
                gen_log_loss_epoch += gen_loss[2]/n_run_of_gen_for_1_run_of_disc
                
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                
                # Progress
                # progbar.add(batch_size, values=[("D logloss", disc_loss),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G L1", gen_loss[1]),
                #                                 ("G logloss", gen_loss[2])])
                print("Epoch", str(init_epoch+e+1), "batch", str(batch+1), "D_logloss", disc_loss, "G_tot", gen_loss[0], "G_L1", gen_loss[1], "G_log", gen_loss[2])
            
            gen_total_loss = gen_total_loss_epoch/n_batch_per_epoch
            gen_L1_loss = gen_L1_loss_epoch/n_batch_per_epoch
            gen_log_loss = gen_log_loss_epoch/n_batch_per_epoch
            disc_losses.append(disc_loss)
            gen_total_losses.append(gen_total_loss)
            gen_L1_losses.append(gen_L1_loss)
            gen_log_losses.append(gen_log_loss)
            
            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_target_batch_train, X_sketch_batch_train, generator_model, batch_size, image_data_format,
                                                model_name, "training", init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Get new images for validation
                if load_all_data_at_once:
                    X_target_batch_val, X_sketch_batch_val = next(X_target_batch_gen_val), next(X_sketch_batch_gen_val)
                else:
                    X_target_batch_val, X_sketch_batch_val = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_val, img_dim=img_dim, augment_data=False, use_identity_image=use_identity_image)
                # Predict and validate
                data_utils.plot_generated_batch(X_target_batch_val, X_sketch_batch_val, generator_model, batch_size, image_data_format,
                                                model_name, "validation", init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Plot losses
                data_utils.plot_losses(disc_losses, gen_total_losses, gen_L1_losses, gen_log_losses, model_name, init_epoch)
            
            # Save weights
            if (e + 1) % save_weights_every_n_epochs == 0:
                # Delete all but the last n weights
                purge_weights(save_only_last_n_weights, model_name)
                # Save gen weights
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", gen_weights_path)
                generator_model.save_weights(gen_weights_path, overwrite=True)
                # Save disc weights
                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", disc_weights_path)
                discriminator_model.save_weights(disc_weights_path, overwrite=True)
                # Save DCGAN weights
                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", DCGAN_weights_path)
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

            check_this_process_memory()
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d} END, Time taken: {3:.4f} seconds'.format(datetime.datetime.now(), init_epoch + e + 1, init_epoch + nb_epoch, time.time() - start))
            print('------------------------------------------------------------------------------------')

    except KeyboardInterrupt:
        pass

    # SAVE THE MODEL

    try:
        # Save the model as it is, so that it can be loaded using -
        # ```from keras.models import load_model; gen = load_model('generator_latest.h5')```
        gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
        print("Saving", gen_weights_path)
        generator_model.save(gen_weights_path, overwrite=True)
    
        # Save model as json string
        generator_model_json_string = generator_model.to_json()
        print("Saving", '../../models/%s/generator_latest.txt' % model_name)
        with open('../../models/%s/generator_latest.txt' % model_name, 'w') as outfile:
            outfile.write(generator_model_json_string)
    
        # Save model as json
        generator_model_json_data = json.loads(generator_model_json_string)
        print("Saving", '../../models/%s/generator_latest.json' % model_name)
        with open('../../models/%s/generator_latest.json' % model_name, 'w') as outfile:
            json.dump(generator_model_json_data, outfile)

    except Exception:
        print(sys.exc_info()[0])

    print("Done.")

    return generator_model
Example #4
def main(dataset, batch_size, patch_size, epochs, label_smoothing,
         label_flipping):
    print(project_dir)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
    sess = tf.Session(config=config)
    K.tensorflow_backend.set_session(
        sess)  # set this TensorFlow session as the default session for Keras

    image_data_format = "channels_first"
    K.set_image_data_format(image_data_format)

    save_images_every_n_batches = 30
    save_model_every_n_epochs = 0

    # configuration parameters
    print("Config params:")
    print("  dataset = {}".format(dataset))
    print("  batch_size = {}".format(batch_size))
    print("  patch_size = {}".format(patch_size))
    print("  epochs = {}".format(epochs))
    print("  label_smoothing = {}".format(label_smoothing))
    print("  label_flipping = {}".format(label_flipping))
    print("  save_images_every_n_batches = {}".format(
        save_images_every_n_batches))
    print("  save_model_every_n_epochs = {}".format(save_model_every_n_epochs))

    model_name = datetime.strftime(datetime.now(), '%y%m%d-%H%M')
    model_dir = os.path.join(project_dir, "models", model_name)
    fig_dir = os.path.join(project_dir, "reports", "figures")
    logs_dir = os.path.join(project_dir, "reports", "logs", model_name)

    os.makedirs(model_dir)

    # Load and rescale data
    ds_train_gen = data_utils.DataGenerator(file_path=dataset,
                                            dataset_type="train",
                                            batch_size=batch_size)
    ds_train_disc = data_utils.DataGenerator(file_path=dataset,
                                             dataset_type="train",
                                             batch_size=batch_size)
    ds_val = data_utils.DataGenerator(file_path=dataset,
                                      dataset_type="val",
                                      batch_size=batch_size)
    enq_train_gen = OrderedEnqueuer(ds_train_gen,
                                    use_multiprocessing=True,
                                    shuffle=True)
    enq_train_disc = OrderedEnqueuer(ds_train_disc,
                                     use_multiprocessing=True,
                                     shuffle=True)
    enq_val = OrderedEnqueuer(ds_val, use_multiprocessing=True, shuffle=False)

    # Infer the image shape (channels, height, width) from the first training batch
    img_dim = ds_train_gen[0][0].shape[-3:]

    n_batch_per_epoch = len(ds_train_gen)
    epoch_size = n_batch_per_epoch * batch_size

    print("Derived params:")
    print("  n_batch_per_epoch = {}".format(n_batch_per_epoch))
    print("  epoch_size = {}".format(epoch_size))
    print("  n_batches_val = {}".format(len(ds_val)))

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size)

    tensorboard = TensorBoard(log_dir=logs_dir,
                              histogram_freq=0,
                              batch_size=batch_size,
                              write_graph=True,
                              write_grads=True,
                              update_freq='batch')

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # Load generator model
        generator_model = models.generator_unet_upsampling(img_dim)
        generator_model.summary()
        plot_model(generator_model,
                   to_file=os.path.join(fig_dir, "generator_model.png"),
                   show_shapes=True,
                   show_layer_names=True)

        # Load discriminator model
        # TODO: modify disc to accept real input as well
        discriminator_model = models.DCGAN_discriminator(
            img_dim_disc, nb_patch)
        discriminator_model.summary()
        plot_model(discriminator_model,
                   to_file=os.path.join(fig_dir, "discriminator_model.png"),
                   show_shapes=True,
                   show_layer_names=True)

        # TODO: pretty sure this is unnecessary
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        # L1 loss applies to generated image, cross entropy applies to predicted label
        loss = [models.l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        tensorboard.set_model(DCGAN_model)

        # Start training
        enq_train_gen.start(workers=1, max_queue_size=20)
        enq_train_disc.start(workers=1, max_queue_size=20)
        enq_val.start(workers=1, max_queue_size=20)
        out_train_gen = enq_train_gen.get()
        out_train_disc = enq_train_disc.get()
        out_val = enq_val.get()

        print("Start training")
        for e in range(1, epochs + 1):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            start = time.time()

            for batch_counter in range(1, n_batch_per_epoch + 1):
                X_transformed_batch, X_orig_batch = next(out_train_disc)

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_transformed_batch,
                    X_orig_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(out_train_gen)
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                # Set labels to 1 (real) to maximize the discriminator loss
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                metrics = [("D logloss", disc_loss), ("G tot", gen_loss[0]),
                           ("G L1", gen_loss[1]), ("G logloss", gen_loss[2])]
                progbar.add(batch_size, values=metrics)

                logs = {k: v for (k, v) in metrics}
                logs["size"] = batch_size

                tensorboard.on_batch_end(batch_counter, logs=logs)

                # Save images for visualization
                if batch_counter % save_images_every_n_batches == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_training.png"))
                    X_transformed_batch, X_orig_batch = next(out_val)
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_validation.png"))

            print("")
            print('Epoch %s/%s, Time: %s' % (e, epochs, time.time() - start))
            tensorboard.on_epoch_end(e, logs=logs)

            if (save_model_every_n_epochs >= 1 and e % save_model_every_n_epochs == 0) or \
                    (e == epochs):
                print("Saving model for epoch {}...".format(e), end="")
                sys.stdout.flush()
                gen_weights_path = os.path.join(
                    model_dir, 'gen_weights_epoch{:03d}.h5'.format(e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    model_dir, 'disc_weights_epoch{:03d}.h5'.format(e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    model_dir, 'DCGAN_weights_epoch{:03d}.h5'.format(e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
                print("done")

    except KeyboardInterrupt:
        pass

    enq_train_gen.stop()
    enq_train_disc.stop()
    enq_val.stop()
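
A minimal call sketch for the example above (likely wired to a CLI in the original); the dataset path is a placeholder and patch_size is assumed to be a 2-element size accepted by data_utils.get_nb_patch:

main(dataset='path/to/dataset.h5',
     batch_size=4,
     patch_size=(64, 64),
     epochs=100,
     label_smoothing=False,
     label_flipping=0)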
Example #5
def gen_theano_fn(args):
    """
    Generate the networks and returns the train functions
    """
    if args.verbose:
        print 'Creating networks...'

    # Setup input variables
    inpt_noise = T.matrix('inpt_noise')
    inpt_image = T.tensor4('inpt_image')
    corr_mask = T.matrix('corr_mask')  # corruption mask
    corr_image = T.tensor4('corr_image')
    if args.captions:
        inpt_embd = T.matrix('inpt_embedding')

    # Shared variable for image reconstruction
    reconstr_noise_shrd = theano.shared(
        np.random.uniform(-1., 1., size=(1, 100)).astype(theano.config.floatX))

    # Build generator and discriminator
    if args.captions:
        cond_gen_dc_gan = models.CaptionGenOnlyDCGAN(args)
        generator, lyr_gen_noise, lyr_gen_embd = cond_gen_dc_gan.init_generator(
            first_layer=64, input_var=None, embedding_var=None)
        discriminator = cond_gen_dc_gan.init_discriminator(first_layer=128,
                                                           input_var=None)
    else:
        dc_gan = models.DCGAN(args)
        generator = dc_gan.init_generator(first_layer=64, input_var=None)
        discriminator = dc_gan.init_discriminator(first_layer=128,
                                                  input_var=None)

    # Get images from generator (for training and outputing images)
    if args.captions:
        image_fake = lyr.get_output(generator,
                                    inputs={
                                        lyr_gen_noise: inpt_noise,
                                        lyr_gen_embd: inpt_embd
                                    })
        image_fake_det = lyr.get_output(generator,
                                        inputs={
                                            lyr_gen_noise: inpt_noise,
                                            lyr_gen_embd: inpt_embd
                                        },
                                        deterministic=True)
        image_reconstr = lyr.get_output(generator,
                                        inputs={
                                            lyr_gen_noise: reconstr_noise_shrd,
                                            lyr_gen_embd: inpt_embd
                                        },
                                        deterministic=True)
    else:
        image_fake = lyr.get_output(generator, inputs=inpt_noise)
        image_fake_det = lyr.get_output(generator,
                                        inputs=inpt_noise,
                                        deterministic=True)
        image_reconstr = lyr.get_output(generator,
                                        inputs=reconstr_noise_shrd,
                                        deterministic=True)

    # Get probabilities from discriminator
    probs_real = lyr.get_output(discriminator, inputs=inpt_image)
    probs_fake = lyr.get_output(discriminator, inputs=image_fake)
    probs_fake_det = lyr.get_output(discriminator,
                                    inputs=image_fake_det,
                                    deterministic=True)
    probs_reconstr = lyr.get_output(discriminator,
                                    inputs=image_reconstr,
                                    deterministic=True)

    # Calc loss for discriminator
    # minimize prob of error on true images
    d_loss_real = -T.mean(T.log(probs_real))
    # minimize prob of error on fake images
    d_loss_fake = -T.mean(T.log(1 - probs_fake))
    loss_discr = d_loss_real + d_loss_fake

    # Calc loss for generator
    # minimize the error of the discriminator on fake images
    loss_gener = -T.mean(T.log(probs_fake))
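    # (non-saturating heuristic: minimize -log D(G(z)) rather than log(1 - D(G(z))))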

    # Create params dict for both discriminator and generator
    params_discr = lyr.get_all_params(discriminator, trainable=True)
    params_gener = lyr.get_all_params(generator, trainable=True)

    # Set update rules for params using adam
    updates_discr = lasagne.updates.adam(loss_discr,
                                         params_discr,
                                         learning_rate=0.001,
                                         beta1=0.9)
    updates_gener = lasagne.updates.adam(loss_gener,
                                         params_gener,
                                         learning_rate=0.0005,
                                         beta1=0.6)

    # Contextual and perceptual losses for image reconstruction
    contx_loss = T.mean(
        lasagne.objectives.squared_error(image_reconstr * corr_mask,
                                         corr_image * corr_mask))
    prcpt_loss = T.mean(T.log(1 - probs_reconstr))

    # Total loss
    lbda = 10.0**-5
    reconstr_loss = contx_loss + lbda * prcpt_loss

    # Set update rule that will change the input noise
    grad = T.grad(reconstr_loss, reconstr_noise_shrd)
    lr = 0.9
    update_rule = reconstr_noise_shrd - lr * grad
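    # applying this update performs one plain gradient-descent step on the latent vector, as in GAN-based inpainting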

    if args.verbose:
        print 'Networks created.'

    # Compile Theano functions
    print 'compiling...'

    if args.captions:
        train_d = theano.function([inpt_image, inpt_noise, inpt_embd],
                                  loss_discr,
                                  updates=updates_discr)
        print '- 1 of 4 compiled.'
        train_g = theano.function([inpt_noise, inpt_embd],
                                  loss_gener,
                                  updates=updates_gener)
        print '- 2 of 4 compiled.'
        predict = theano.function([inpt_noise, inpt_embd],
                                  [image_fake_det, probs_fake_det])
        print '- 3 of 4 compiled.'
        reconstr = theano.function(
            [corr_image, corr_mask, inpt_embd],
            [reconstr_noise_shrd, image_reconstr, reconstr_loss, grad],
            updates=[(reconstr_noise_shrd, update_rule)])
        print '- 4 of 4 compiled.'
    else:
        train_d = theano.function([inpt_image, inpt_noise],
                                  loss_discr,
                                  updates=updates_discr)
        print '- 1 of 4 compiled.'
        train_g = theano.function([inpt_noise],
                                  loss_gener,
                                  updates=updates_gener)
        print '- 2 of 4 compiled.'
        predict = theano.function([inpt_noise],
                                  [image_fake_det, probs_fake_det])
        print '- 3 of 4 compiled.'
        reconstr = theano.function(
            [corr_image, corr_mask],
            [reconstr_noise_shrd, image_reconstr, reconstr_loss, grad],
            updates=[(reconstr_noise_shrd, update_rule)])
        print '- 4 of 4 compiled.'

    print 'compiled.'

    return train_d, train_g, predict, reconstr, reconstr_noise_shrd, (
        discriminator, generator)
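
A sketch of how the compiled functions might be used, assuming args.captions is False; 'args' is the parsed options namespace and 'real_batch' a hypothetical float32 image batch shaped like the discriminator input:

import numpy as np

train_d, train_g, predict, reconstr, reconstr_noise_shrd, (disc, gen) = gen_theano_fn(args)
noise = np.random.uniform(-1., 1., size=(64, 100)).astype('float32')
d_loss = train_d(real_batch, noise)   # one discriminator update on real and generated images
g_loss = train_g(noise)               # one generator update
images, probs = predict(noise)        # deterministic samples and their discriminator scores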
Example #6
def train(args):
  import models
  import numpy as np
  np.random.seed(1234)

  if args.dataset == 'digits':
    n_dim, n_out, n_channels = 8, 10, 1
    X_train, y_train, X_val, y_val = data.load_digits()
  elif args.dataset == 'mnist':
    n_dim, n_out, n_channels = 28, 10, 1
    X_train, y_train, X_val, y_val, _, _ = data.load_mnist()
  elif args.dataset == 'svhn':
    n_dim, n_out, n_channels = 32, 10, 3
    X_train, y_train, X_val, y_val = data.load_svhn()
    X_train, y_train, X_val, y_val = data.prepare_dataset(X_train, y_train, X_val, y_val)
  elif args.dataset == 'cifar10':
    n_dim, n_out, n_channels = 32, 10, 3
    X_train, y_train, X_val, y_val = data.load_cifar10()
    X_train, y_train, X_val, y_val = data.prepare_dataset(X_train, y_train, X_val, y_val)
  elif args.dataset == 'random':
    n_dim, n_out, n_channels = 2, 2, 1
    X_train, y_train = data.load_noise(n=1000, d=n_dim)
    X_val, y_val = X_train, y_train
  else:
    raise ValueError('Invalid dataset name: %s' % args.dataset)
  print 'dataset loaded, dim:', X_train.shape

  # set up optimization params
  p = { 'lr' : args.lr, 'b1': args.b1, 'b2': args.b2 }

  # create model
  if args.model == 'softmax':
    model = models.Softmax(n_dim=n_dim, n_out=n_out, n_superbatch=args.n_superbatch, 
                           opt_alg=args.alg, opt_params=p)
  elif args.model == 'mlp':
    model = models.MLP(n_dim=n_dim, n_out=n_out, n_superbatch=args.n_superbatch, 
                       opt_alg=args.alg, opt_params=p)
  elif args.model == 'cnn':
    model = models.CNN(n_dim=n_dim, n_out=n_out, n_chan=n_channels, model=args.dataset,
                       n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)  
  elif args.model == 'kcnn':
    model = models.KCNN(n_dim=n_dim, n_out=n_out, n_chan=n_channels, model=args.dataset,
                       n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)    
  elif args.model == 'resnet':
    model = models.Resnet(n_dim=n_dim, n_out=n_out, n_chan=n_channels,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)    
  elif args.model == 'vae':
    model = models.VAE(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p,
                          model='bernoulli' if args.dataset in ('digits', 'mnist') 
                                            else 'gaussian')    
  elif args.model == 'convvae':
    model = models.ConvVAE(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p,
                          model='bernoulli' if args.dataset in ('digits', 'mnist') 
                                            else 'gaussian')    
  elif args.model == 'convadgm':
    model = models.ConvADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p,
                          model='bernoulli' if args.dataset in ('digits', 'mnist') 
                                            else 'gaussian')    
  elif args.model == 'sbn':
    model = models.SBN(n_dim=n_dim, n_out=n_out, n_chan=n_channels,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)      
  elif args.model == 'adgm':
    model = models.ADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p,
                          model='bernoulli' if args.dataset in ('digits', 'mnist') 
                                            else 'gaussian')
  elif args.model == 'hdgm':
    model = models.HDGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)        
  elif args.model == 'dadgm':
    model = models.DADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p) 
  elif args.model == 'dcgan':
    model = models.DCGAN(n_dim=n_dim, n_out=n_out, n_chan=n_channels,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)   
  elif args.model == 'ssadgm':
    X_train_lbl, y_train_lbl, X_train_unl, y_train_unl \
      = data.split_semisup(X_train, y_train, n_lbl=args.n_labeled)
    model = models.SSADGM(X_labeled=X_train_lbl, y_labeled=y_train_lbl, n_out=n_out,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)
    X_train, y_train = X_train_unl, y_train_unl
  else:
    raise ValueError('Invalid model')
  
  # train model
  model.fit(X_train, y_train, X_val, y_val, 
            n_epoch=args.epochs, n_batch=args.n_batch,
            logname=args.logname)
Example #7
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    lastLayerActivation = kwargs["lastLayerActivation"]
    PercentageOfTrianable = kwargs["PercentageOfTrianable"]
    SpecificPathStr = kwargs["SpecificPathStr"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    #general_utils.setup_logging(model_name)

    # Load and rescale data
    #X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
    img_dim = (256,256,3) # Manual entry

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        """
        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim,
                                      nb_patch,
                                      bn_mode,
                                      use_mbd,
                                      batch_size)
        """
        generator_model = CreatErrorMapModel(input_shape=img_dim,
                                             lastLayerActivation=lastLayerActivation,
                                             PercentageOfTrianable=PercentageOfTrianable)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          img_dim_disc,
                                          nb_patch,
                                          bn_mode,
                                          use_mbd,
                                          batch_size)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
#-------------------------------------------------------------------------------
        logpath = os.path.join('../../log', 'DepthMapWith' + lastLayerActivation + str(PercentageOfTrianable) + 'UnTr' + SpecificPathStr)
        modelPath = os.path.join('../../models', 'DepthMapwith' + lastLayerActivation + str(PercentageOfTrianable) + 'Untr' + SpecificPathStr)
        os.makedirs(logpath, exist_ok=True)
        os.makedirs(modelPath, exist_ok=True)

#-----------------------PreTraining Depth Map-------------------------------------
        nb_train_samples = 2000
        nb_validation_samples = 100  # value was missing in the original; placeholder
        epochs = 20
        history = generator_model.fit_generator(  # 'whole_model' in the original is undefined; presumably the generator
            data_utils.facades_generator(img_dim, batch_size=batch_size),
            samples_per_epoch=nb_train_samples,
            epochs=epochs,
            validation_data=data_utils.facades_generator(img_dim, batch_size=batch_size),
            nb_val_samples=nb_validation_samples,
            callbacks=[
                keras.callbacks.ModelCheckpoint(os.path.join(modelPath, 'DepthMap_weightsBestLoss.h5'), monitor='val_loss', verbose=1, save_best_only=True),
                keras.callbacks.ModelCheckpoint(os.path.join(modelPath, 'DepthMap_weightsBestAcc.h5'), monitor='acc', verbose=1, save_best_only=True),
                keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.1, patience=2, verbose=1, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0),
                keras.callbacks.TensorBoard(log_dir=logpath, histogram_freq=0, batch_size=batch_size, write_graph=True, write_grads=False, write_images=True, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)])
#------------------------------------------------------------------------------------
#------------------------------------------------------------------------------------


        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   img_dim,
                                   patch_size,
                                   image_data_format)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.facades_generator(img_dim,batch_size=batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch,
                                                           X_sketch_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           patch_size,
                                                           image_data_format,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc) # X_disc, y_disc
                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(data_utils.facades_generator(img_dim,batch_size=batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    # Get new images from validation
                    figure_name = "training_"+str(e)
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_data_format, figure_name)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
Example #8
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)
    print "hi"

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
        dset, image_dim_ordering)
    img_dim = X_full_train.shape[-3:]
    print "data loaded in memory"

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_dim_ordering)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator, img_dim,
                                      nb_patch, bn_mode, use_mbd, batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, bn_mode, use_mbd,
                                          batch_size)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_dim_ordering)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        gen_loss = None
        disc_loss = None

        iter_num = 102
        weights_path = "/home/abhik/pix2pix/src/model/weights/gen_weights_iter%s_epoch30.h5" % (
            str(iter_num - 1))
        print weights_path
        generator_model.load_weights(weights_path)

        #discriminator_model.load_weights("disc_weights1.2.h5")

        #DCGAN_model.load_weights("DCGAN_weights1.2.h5")
        print("Weights Loaded for iter - %d" % iter_num)

        # Running average
        losses_list = list()
        # loss_list = list()
        # prev_avg = 0

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            # global disc_n, disc_prev_avg, gen1_n, gen1_prev_avg, gen2_n, gen2_prev_avg, gen3_n, gen3_prev_avg

            # disc_n = 1
            # disc_prev_avg = 0

            # gen1_n = 1
            # gen1_prev_avg = 0

            # gen2_n = 1
            # gen2_prev_avg = 0

            # gen3_n = 1
            # gen3_prev_avg = 0

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch,
                    X_sketch_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    image_dim_ordering,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train,
                                         batch_size))
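                # label the generated samples as 'real' so the generator is trained to fool the discriminator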
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                # Running average
                # loss_list.append(disc_loss)
                # loss_list_n = len(loss_list)
                # new_avg = ((loss_list_n-1)*prev_avg + disc_loss)/loss_list_n
                # prev_avg = new_avg

                # disc_avg, gen1_avg, gen2_avg, gen3_avg = running_avg(disc_loss, gen_loss[0], gen_loss[1], gen_loss[2])

                # print("running disc loss", new_avg)
                # print(disc_loss, gen_loss)
                # print ("all losses", disc_avg, gen1_avg, gen2_avg, gen3_avg)
                # print("")

                batch_counter += 1
                progbar.add(batch_size,
                            values=[("D logloss", disc_loss),
                                    ("G tot", gen_loss[0]),
                                    ("G L1", gen_loss[1]),
                                    ("G logloss", gen_loss[2])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch // 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_dim_ordering, "training", iter_num)
                    X_full_batch, X_sketch_batch = next(
                        data_utils.gen_batch(X_full_val, X_sketch_val,
                                             batch_size))
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_dim_ordering, "validation", iter_num)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))

            #Running average
            disc_avg, gen1_avg, gen2_avg, gen3_avg = running_avg(
                disc_loss, gen_loss[0], gen_loss[1], gen_loss[2])

            # Validation loss: the generator takes the sketch as input and the
            # full image as target, matching the training call above
            y_gen_val = np.zeros((X_sketch_batch.shape[0], 2), dtype=np.uint8)
            y_gen_val[:, 1] = 1
            val_loss = DCGAN_model.test_on_batch(X_sketch_batch,
                                                 [X_full_batch, y_gen_val])

            # Save losses for plotting
            losses = [
                e + 1, iter_num, disc_loss, gen_loss[0], gen_loss[1],
                gen_loss[2], disc_avg, gen1_avg, gen2_avg, gen3_avg,
                val_loss[0], val_loss[1], val_loss[2]
            ]
            losses_list.append(losses)

            if (e + 1) % 5 == 0:
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_iter%s_epoch%s.h5' %
                    (model_name, iter_num, e + 1))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_iter%s_epoch%s.h5' %
                    (model_name, iter_num, e + 1))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_iter%s_epoch%s.h5' %
                    (model_name, iter_num, e + 1))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

                loss_array = np.asarray(losses_list)
                print(loss_array.shape)  # one 13-element row per epoch

                loss_path = os.path.join(
                    '../../losses/loss_iter%s_epoch%s.csv' % (iter_num, e + 1))
                np.savetxt(loss_path, loss_array, fmt='%.5f', delimiter=',')
                np.savetxt('test.csv', loss_array, fmt='%.5f', delimiter=',')

    except KeyboardInterrupt:
        pass
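Note: the epoch loop above calls a running_avg helper that is not shown in this snippet. A minimal sketch consistent with its call site (four scalar losses in, four running means out; the module-level state is an assumption):

# Hypothetical running-average helper matching running_avg(...) above
disc_n, disc_prev_avg = 0, 0.0
gen1_n, gen1_prev_avg = 0, 0.0
gen2_n, gen2_prev_avg = 0, 0.0
gen3_n, gen3_prev_avg = 0, 0.0

def running_avg(disc_loss, gen1_loss, gen2_loss, gen3_loss):
    global disc_n, disc_prev_avg, gen1_n, gen1_prev_avg
    global gen2_n, gen2_prev_avg, gen3_n, gen3_prev_avg
    # Incremental mean update: avg += (x - avg) / n
    disc_n += 1
    disc_prev_avg += (disc_loss - disc_prev_avg) / disc_n
    gen1_n += 1
    gen1_prev_avg += (gen1_loss - gen1_prev_avg) / gen1_n
    gen2_n += 1
    gen2_prev_avg += (gen2_loss - gen2_prev_avg) / gen2_n
    gen3_n += 1
    gen3_prev_avg += (gen3_loss - gen3_prev_avg) / gen3_n
    return disc_prev_avg, gen1_prev_avg, gen2_prev_avg, gen3_prev_avg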
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    patch_size = kwargs["patch_size"]
    image_data_format = kwargs["image_data_format"]
    generator_type = kwargs["generator_type"]
    dset = kwargs["dset"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    use_mbd = kwargs["use_mbd"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping_prob = kwargs["label_flipping_prob"]
    use_l1_weighted_loss = kwargs["use_l1_weighted_loss"]
    prev_model = kwargs["prev_model"]
    discriminator_optimizer = kwargs["discriminator_optimizer"]
    n_run_of_gen_for_1_run_of_disc = kwargs["n_run_of_gen_for_1_run_of_disc"]
    MAX_FRAMES_PER_GIF = kwargs["MAX_FRAMES_PER_GIF"]

    # batch_size = args.batch_size
    # n_batch_per_epoch = args.n_batch_per_epoch
    # nb_epoch = args.nb_epoch
    # save_weights_every_n_epochs = args.save_weights_every_n_epochs
    # generator_type = args.generator_type
    # patch_size = args.patch_size
    # label_smoothing = False
    # label_flipping_prob = False
    # dset = args.dset
    # use_mbd = False

    # Check and make the dataset
    # If .h5 file of dset is not present, try making it
    if not os.path.exists("../../data/processed/%s_data.h5" % dset):
        print("dset %s_data.h5 not present in '../../data/processed'!" % dset)
        if not os.path.exists("../../data/%s/" % dset):
            print(
                "dset folder %s not present in '../../data'!\n\nERROR: Dataset .h5 file not made, and dataset not available in '../../data/'.\n\nQuitting."
                % dset)
            return
        else:
            if not os.path.exists(
                    "../../data/%s/train" % dset) or not os.path.exists(
                        "../../data/%s/val" % dset) or not os.path.exists(
                            "../../data/%s/test" % dset):
                print(
                    "'train', 'val' or 'test' folders not present in dset folder '../../data/%s'!\n\nERROR: Dataset must contain 'train', 'val' and 'test' folders.\n\nQuitting."
                    % dset)
                return
            else:
                print("Making %s dataset" % dset)
                subprocess.call([
                    'python3', '../data/make_dataset.py',
                    '../../data/%s' % dset, '3'
                ])
                print("Done!")

    epoch_size = n_batch_per_epoch * batch_size

    init_epoch = 0

    if prev_model:
        print('\n\nLoading prev_model from', prev_model, '...\n\n')
        prev_model_latest_gen = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*gen*.h5')))[-1]
        prev_model_latest_disc = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*disc*.h5')))[-1]
        prev_model_latest_DCGAN = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*DCGAN*.h5')))[-1]
        # Find prev model name, epoch
        model_name = prev_model_latest_DCGAN.split('models')[-1].split('/')[1]
        init_epoch = int(prev_model_latest_DCGAN.split('epoch')[1][:5]) + 1

    # Setup environment (logging directory etc), if no prev_model is mentioned
    general_utils.setup_logging(model_name)

    # img_dim = X_full_train.shape[-3:]
    img_dim = (256, 256, 3)

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        if discriminator_optimizer == 'sgd':
            opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        elif discriminator_optimizer == 'adam':
            opt_discriminator = Adam(lr=1E-3,
                                     beta_1=0.9,
                                     beta_2=0.999,
                                     epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator_type,
                                      img_dim, nb_patch, use_mbd, batch_size,
                                      model_name)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)

        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, use_mbd, batch_size,
                                          model_name)

        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        if use_l1_weighted_loss:
            loss = [l1_weighted_loss, 'binary_crossentropy']
        else:
            loss = [l1_loss, 'binary_crossentropy']

        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        # Load prev_model
        if prev_model:
            generator_model.load_weights(prev_model_latest_gen)
            discriminator_model.load_weights(prev_model_latest_disc)
            DCGAN_model.load_weights(prev_model_latest_DCGAN)

        # Load and rescale data
        print('\n\nLoading data...\n\n')
        X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
            dset, image_data_format)
        check_this_process_memory()
        print('X_full_train: %.4f' % (X_full_train.nbytes / 2**30), "GB")
        print('X_sketch_train: %.4f' % (X_sketch_train.nbytes / 2**30), "GB")
        print('X_full_val: %.4f' % (X_full_val.nbytes / 2**30), "GB")
        print('X_sketch_val: %.4f' % (X_sketch_val.nbytes / 2**30), "GB")

        # Losses
        disc_losses = []
        gen_total_losses = []
        gen_L1_losses = []
        gen_log_losses = []

        # Start training
        print("\n\nStarting training\n\n")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            gen_total_loss_epoch = 0
            gen_L1_loss_epoch = 0
            gen_log_loss_epoch = 0
            start = time.time()
            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch,
                    X_sketch_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    image_data_format,
                    label_smoothing=label_smoothing,
                    label_flipping_prob=label_flipping_prob)
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train,
                                         batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1
                # Freeze the discriminator
                discriminator_model.trainable = False
                # Train generator
                for _ in range(n_run_of_gen_for_1_run_of_disc - 1):
                    gen_loss = DCGAN_model.train_on_batch(
                        X_gen, [X_gen_target, y_gen])
                    gen_total_loss_epoch += gen_loss[0] / n_run_of_gen_for_1_run_of_disc
                    gen_L1_loss_epoch += gen_loss[1] / n_run_of_gen_for_1_run_of_disc
                    gen_log_loss_epoch += gen_loss[2] / n_run_of_gen_for_1_run_of_disc
                    X_gen_target, X_gen = next(
                        data_utils.gen_batch(X_full_train, X_sketch_train,
                                             batch_size))
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Add losses
                gen_total_loss_epoch += gen_loss[0] / n_run_of_gen_for_1_run_of_disc
                gen_L1_loss_epoch += gen_loss[1] / n_run_of_gen_for_1_run_of_disc
                gen_log_loss_epoch += gen_loss[2] / n_run_of_gen_for_1_run_of_disc
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                # Progress
                # progbar.add(batch_size, values=[("D logloss", disc_loss),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G L1", gen_loss[1]),
                #                                 ("G logloss", gen_loss[2])])
                print("Epoch", str(init_epoch + e + 1), "batch",
                      str(batch_counter + 1), "D_logloss", disc_loss, "G_tot",
                      gen_loss[0], "G_L1", gen_loss[1], "G_log", gen_loss[2])
                batch_counter += 1
                if batch_counter >= n_batch_per_epoch:
                    break
            gen_total_loss = gen_total_loss_epoch / n_batch_per_epoch
            gen_L1_loss = gen_L1_loss_epoch / n_batch_per_epoch
            gen_log_loss = gen_log_loss_epoch / n_batch_per_epoch
            disc_losses.append(disc_loss)
            gen_total_losses.append(gen_total_loss)
            gen_L1_losses.append(gen_L1_loss)
            gen_log_losses.append(gen_log_loss)
            check_this_process_memory()
            print('Epoch %s/%s, Time: %.4f' % (init_epoch + e + 1, init_epoch +
                                               nb_epoch, time.time() - start))
            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_full_batch, X_sketch_batch,
                                                generator_model, batch_size,
                                                image_data_format, model_name,
                                                "training", init_epoch + e + 1,
                                                MAX_FRAMES_PER_GIF)
                # Get new images from validation
                X_full_batch, X_sketch_batch = next(
                    data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                data_utils.plot_generated_batch(X_full_batch, X_sketch_batch,
                                                generator_model, batch_size,
                                                image_data_format, model_name,
                                                "validation",
                                                init_epoch + e + 1,
                                                MAX_FRAMES_PER_GIF)
                # Plot losses
                data_utils.plot_losses(disc_losses, gen_total_losses,
                                       gen_L1_losses, gen_log_losses,
                                       model_name, init_epoch)
            # Save weights
            if (e + 1) % save_weights_every_n_epochs == 0:
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                generator_model.save_weights(gen_weights_path, overwrite=True)
                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)
                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
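The function above selects between l1_loss and l1_weighted_loss, neither of which is defined in the snippet. A minimal Keras-backend sketch (the weighting scheme is an assumption, not the repository's exact formulation):

from keras import backend as K

def l1_loss(y_true, y_pred):
    # Mean absolute error over all pixels
    return K.mean(K.abs(y_pred - y_true))

def l1_weighted_loss(y_true, y_pred):
    # Example weighting: up-weight pixels whose target value is far from the
    # background level, so sparse foreground detail dominates the loss
    weights = 1.0 + 4.0 * K.cast(K.greater(K.abs(y_true), 0.1), K.floatx())
    return K.mean(weights * K.abs(y_pred - y_true))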
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    celebA_img_dim = kwargs["celebA_img_dim"]
    cont_dim = (kwargs["cont_dim"], )
    cat_dim = (kwargs["cat_dim"], )
    noise_dim = (kwargs["noise_dim"], )
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    load_from_dir = kwargs["load_from_dir"]
    target_size = kwargs["target_size"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and rescale data
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(celebA_img_dim,
                                              image_data_format)
    elif dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    else:
        X_batch_gen = data_utils.data_generator_from_dir(
            dset, target_size, batch_size)
        X_real_train = next(X_batch_gen)
    img_dim = X_real_train.shape[-3:]

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=1E-4,
                                 beta_1=0.5,
                                 beta_2=0.999,
                                 epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      cat_dim,
                                      cont_dim,
                                      noise_dim,
                                      img_dim,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          cat_dim,
                                          cont_dim,
                                          noise_dim,
                                          img_dim,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   cat_dim, cont_dim, noise_dim)

        list_losses = [
            'binary_crossentropy', 'categorical_crossentropy', gaussian_loss
        ]
        list_weights = [1, 1, 1]
        DCGAN_model.compile(loss=list_losses,
                            loss_weights=list_weights,
                            optimizer=opt_dcgan)

        # Multiple discriminator losses
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses,
                                    loss_weights=list_weights,
                                    optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        if not load_from_dir:
            X_batch_gen = data_utils.gen_batch(X_real_train, batch_size)

        # Start training
        print("Start training")

        disc_total_losses = []
        disc_log_losses = []
        disc_cat_losses = []
        disc_cont_losses = []
        gen_total_losses = []
        gen_log_losses = []
        gen_cat_losses = []
        gen_cont_losses = []

        start = time.time()

        for e in range(nb_epoch):

            print('--------------------------------------------')
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d}\n'.format(
                datetime.datetime.now(), e + 1, nb_epoch))

            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1

            disc_total_loss_batch = 0
            disc_log_loss_batch = 0
            disc_cat_loss_batch = 0
            disc_cont_loss_batch = 0
            gen_total_loss_batch = 0
            gen_log_loss_batch = 0
            gen_cat_loss_batch = 0
            gen_cont_loss_batch = 0

            for batch_counter in range(n_batch_per_epoch):

                # Load data
                X_real_batch = next(X_batch_gen)

                # Create a batch to feed the discriminator model
                X_disc, y_disc, y_cat, y_cont = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    cat_dim,
                    cont_dim,
                    noise_dim,
                    noise_scale=noise_scale,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(
                    X_disc, [y_disc, y_cat, y_cont])

                # Create a batch to feed the generator model
                X_gen, y_gen, y_cat, y_cont, y_cont_target = data_utils.get_gen_batch(
                    batch_size,
                    cat_dim,
                    cont_dim,
                    noise_dim,
                    noise_scale=noise_scale)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(
                    [y_cat, y_cont, X_gen], [y_gen, y_cat, y_cont_target])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                progbar.add(batch_size,
                            values=[("D tot", disc_loss[0]),
                                    ("D log", disc_loss[1]),
                                    ("D cat", disc_loss[2]),
                                    ("D cont", disc_loss[3]),
                                    ("G tot", gen_loss[0]),
                                    ("G log", gen_loss[1]),
                                    ("G cat", gen_loss[2]),
                                    ("G cont", gen_loss[3])])

                disc_total_loss_batch += disc_loss[0]
                disc_log_loss_batch += disc_loss[1]
                disc_cat_loss_batch += disc_loss[2]
                disc_cont_loss_batch += disc_loss[3]
                gen_total_loss_batch += gen_loss[0]
                gen_log_loss_batch += gen_loss[1]
                gen_cat_loss_batch += gen_loss[2]
                gen_cont_loss_batch += gen_loss[3]

                # # Save images for visualization
                # if batch_counter % (n_batch_per_epoch / 2) == 0:
                #     data_utils.plot_generated_batch(X_real_batch, generator_model, e,
                #                                     batch_size, cat_dim, cont_dim, noise_dim,
                #                                     image_data_format, model_name)

            disc_total_losses.append(disc_total_loss_batch / n_batch_per_epoch)
            disc_log_losses.append(disc_log_loss_batch / n_batch_per_epoch)
            disc_cat_losses.append(disc_cat_loss_batch / n_batch_per_epoch)
            disc_cont_losses.append(disc_cont_loss_batch / n_batch_per_epoch)
            gen_total_losses.append(gen_total_loss_batch / n_batch_per_epoch)
            gen_log_losses.append(gen_log_loss_batch / n_batch_per_epoch)
            gen_cat_losses.append(gen_cat_loss_batch / n_batch_per_epoch)
            gen_cont_losses.append(gen_cont_loss_batch / n_batch_per_epoch)

            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                e, batch_size, cat_dim,
                                                cont_dim, noise_dim,
                                                image_data_format, model_name)
                data_utils.plot_losses(disc_total_losses, disc_log_losses,
                                       disc_cat_losses, disc_cont_losses,
                                       gen_total_losses, gen_log_losses,
                                       gen_cat_losses, gen_cont_losses,
                                       model_name)

            if (e + 1) % save_weights_every_n_epochs == 0:

                print("Saving weights...")

                # Delete all but the last n weights
                general_utils.purge_weights(save_only_last_n_weights,
                                            model_name)

                # Save weights
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%05d.h5' %
                    (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%05d.h5' %
                    (model_name, e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%05d.h5' %
                    (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

            end = time.time()
            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, end - start))
            start = end

    except KeyboardInterrupt:
        pass

    gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
    print("Saving", gen_weights_path)
    generator_model.save(gen_weights_path, overwrite=True)
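list_losses above pairs the two standard cross-entropies with a custom gaussian_loss for the continuous InfoGAN code. It is not defined in the snippet; a common unit-variance formulation (an assumption here) is:

from keras import backend as K

def gaussian_loss(y_true, y_pred):
    # Negative log-likelihood of y_true under N(y_pred, I), dropping constants;
    # with unit variance this reduces to half the mean squared error
    return K.mean(0.5 * K.square(y_true - y_pred))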
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    cont_dim = (kwargs["cont_dim"], )
    cat_dim = (kwargs["cat_dim"], )
    noise_dim = (kwargs["noise_dim"], )
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(img_dim, image_data_format)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    img_dim = X_real_train.shape[-3:]

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=1E-4,
                                 beta_1=0.5,
                                 beta_2=0.999,
                                 epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      cat_dim,
                                      cont_dim,
                                      noise_dim,
                                      img_dim,
                                      bn_mode,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          cat_dim,
                                          cont_dim,
                                          noise_dim,
                                          img_dim,
                                          bn_mode,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   cat_dim, cont_dim, noise_dim)

        list_losses = [
            'binary_crossentropy', 'categorical_crossentropy', gaussian_loss
        ]
        list_weights = [1, 1, 1]
        DCGAN_model.compile(loss=list_losses,
                            loss_weights=list_weights,
                            optimizer=opt_dcgan)

        # Multiple discriminator losses
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses,
                                    loss_weights=list_weights,
                                    optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc, y_cat, y_cont = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    cat_dim,
                    cont_dim,
                    noise_dim,
                    noise_scale=noise_scale,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(
                    X_disc, [y_disc, y_cat, y_cont])

                # Create a batch to feed the generator model
                X_gen, y_gen, y_cat, y_cont, y_cont_target = data_utils.get_gen_batch(
                    batch_size,
                    cat_dim,
                    cont_dim,
                    noise_dim,
                    noise_scale=noise_scale)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(
                    [y_cat, y_cont, X_gen], [y_gen, y_cat, y_cont_target])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size,
                            values=[("D tot", disc_loss[0]),
                                    ("D log", disc_loss[1]),
                                    ("D cat", disc_loss[2]),
                                    ("D cont", disc_loss[3]),
                                    ("G tot", gen_loss[0]),
                                    ("G log", gen_loss[1]),
                                    ("G cat", gen_loss[2]),
                                    ("G cont", gen_loss[3])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch // 2) == 0:
                    data_utils.plot_generated_batch(X_real_batch,
                                                    generator_model,
                                                    batch_size, cat_dim,
                                                    cont_dim, noise_dim,
                                                    image_data_format)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%s.h5' %
                    (model_name, e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%s.h5' %
                    (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
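Every loop in these examples draws from data_utils.gen_batch, an infinite random-batch generator. A minimal sketch consistent with the paired usage above (the sampling strategy is an assumption):

import numpy as np

def gen_batch(X1, X2, batch_size):
    # Yield aligned random batches forever; callers either iterate directly
    # or pull a single batch with next(...)
    while True:
        idx = np.random.choice(X1.shape[0], batch_size, replace=False)
        yield X1[idx], X2[idx]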
Example #12
0
def main(args):
    if args.mnist:
        # Normalize image for MNIST
        # normalize = Normalize(mean=(0.1307,), std=(0.3081,))
        normalize = None
        args.input_size = 784
    elif args.cifar:
        normalize = utils.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        args.input_size = 32 * 32 * 3
    else:
        # Normalize image for ImageNet
        normalize = utils.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
        args.input_size = 150528

    # Load data
    train_loader, test_loader = utils.get_data(args)

    # The unknown model to attack
    unk_model = utils.load_unk_model(args)

    # Try Whitebox Untargeted first
    if args.debug:
        ipdb.set_trace()

    if args.train_vae:
        encoder, decoder, vae = train_mnist_vae(args)
    else:
        encoder, decoder, vae = None, None, None

    if args.train_ae:
        encoder, decoder, ae = train_mnist_ae(args)
    else:
        encoder, decoder, ae = None, None, None

    # Add A Flow
    norm_flow = None
    if args.use_flow:
        # norm_flow = flows.NormalizingFlow(30, args.latent).to(args.device)
        norm_flow = flows.Planar
    # Test white box
    if args.white:
        # Choose Attack Function
        if args.no_pgd_optim:
            white_attack_func = attacks.L2_white_box_generator
        else:
            white_attack_func = attacks.PGD_white_box_generator

        # Choose Dataset
        if args.mnist:
            G = models.Generator(input_size=784).to(args.device)
        elif args.cifar:
            if args.vanilla_G:
                G = models.DCGAN().to(args.device)
                G = nn.DataParallel(G.generator)
            else:
                G = models.ConvGenerator(models.Bottleneck, [6, 12, 24, 16],
                                         growth_rate=12,
                                         flows=norm_flow,
                                         use_flow=args.use_flow,
                                         deterministic=args.deterministic_G).to(args.device)
                G = nn.DataParallel(G)
            nc, h, w = 3, 32, 32

        if args.run_baseline:
            attacks.whitebox_pgd(args, unk_model)

        # Note: nc, h, w are only assigned in the CIFAR branch above
        pred, delta = white_attack_func(args, train_loader, test_loader,
                                        unk_model, G, nc, h, w)

    # Blackbox Attack model
    model = models.GaussianPolicy(args.input_size,
                                  400,
                                  args.latent_size,
                                  decode=False).to(args.device)

    # Control Variate
    cv = to_cuda(models.FC(args.input_size, args.classes))
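utils.Normalize above bakes the dataset statistics into the model so the attack itself can operate on raw [0, 1] pixels. A sketch of such a module (an assumption about utils, not its actual code):

import torch
import torch.nn as nn

class Normalize(nn.Module):
    # Per-channel (x - mean) / std, applied inside the graph
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer('mean', torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer('std', torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std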
Example #13
0
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    # right strip '/' to avoid empty '/' dir
    save_dir = kwargs["save_dir"].rstrip('/')
    # join name with current datetime
    save_dir = '_'.join(
        [save_dir,
         datetime.datetime.now().strftime("%I:%M%p-%B%d-%Y/")])

    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # save the config in save dir
    with open('{0}job_config.json'.format(save_dir), 'w') as fp:
        json.dump(kwargs, fp, sort_keys=True, indent=4)

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
        dset, image_data_format)
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator, img_dim,
                                      nb_patch, bn_mode, use_mbd, batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, bn_mode, use_mbd,
                                          batch_size)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch,
                    X_sketch_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    image_data_format,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train,
                                         batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size,
                            values=[("D logloss", disc_loss),
                                    ("G tot", gen_loss[0]),
                                    ("G L1", gen_loss[1]),
                                    ("G logloss", gen_loss[2])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch // 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_data_format,
                        "{:03}_EPOCH_TRAIN".format(e + 1), save_dir)
                    X_full_batch, X_sketch_batch = next(
                        data_utils.gen_batch(X_full_val, X_sketch_val,
                                             batch_size))
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_data_format,
                        "{:03}_EPOCH_VALID".format(e + 1), save_dir)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                pass
                # save models
                # gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                # generator_model.save_weights(gen_weights_path, overwrite=True)

                # disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                # discriminator_model.save_weights(disc_weights_path, overwrite=True)

                # DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                # DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass

    # save models
    DCGAN_model.save(save_dir + 'DCGAN.h5')
    generator_model.save(save_dir + 'GENERATOR.h5')
    discriminator_model.save(save_dir + 'DISCRIMINATOR.h5')
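data_utils.get_disc_batch is shared by the pix2pix-style examples but never shown. A condensed sketch of the alternating real/fake scheme its arguments imply (patch extraction is omitted; the label handling is an assumption):

import numpy as np

def get_disc_batch(X_full_batch, X_sketch_batch, generator_model, batch_counter,
                   patch_size, image_data_format,
                   label_smoothing=False, label_flipping=0):
    if batch_counter % 2 == 0:
        # Even steps: fake images from the generator, labeled [1, 0]
        X_disc = generator_model.predict(X_sketch_batch)
        y_disc = np.zeros((X_disc.shape[0], 2), dtype=np.float32)
        y_disc[:, 0] = 1
    else:
        # Odd steps: real images, labeled [0, 1] (optionally smoothed)
        X_disc = X_full_batch
        y_disc = np.zeros((X_disc.shape[0], 2), dtype=np.float32)
        y_disc[:, 1] = (np.random.uniform(0.9, 1.0, X_disc.shape[0])
                        if label_smoothing else 1)
    if label_flipping > 0 and np.random.rand() < label_flipping:
        y_disc = y_disc[:, ::-1]  # occasionally swap real/fake labels
    return X_disc, y_disc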
Example #14
0
import os

import models
import utils

if __name__ == "__main__":
    opt = utils.get_options()
    model = models.DCGAN(opt)
    dataloader = utils.create_dataloader(opt)

    os.makedirs("./results", exist_ok=True)

    G_losses, D_losses = [], []

    print("Start Training...")

    for epoch in range(1, opt.num_epochs + 1):
        for iters, (data, _) in enumerate(dataloader):
            model.set_input(data)
            model.optimize_parameters()

            loss_G, loss_D = model.get_losses()
            G_losses.append(loss_G)
            D_losses.append(loss_D)

            if iters % 50 == 0:
                print("Epoch: %d/%d\tIter: %d/%d\tLoss_G: %.4f\tLoss_D: %.4f" %
                      (epoch, opt.num_epochs, iters, len(dataloader), loss_G,
                       loss_D))

            if (iters % 500 == 0) or ((epoch == opt.num_epochs) and
                                      (iters == len(dataloader) - 1)):
                # The source snippet is truncated here; the closing condition
                # above is an assumption, and saving sample images from a fixed
                # noise batch is the usual continuation
                pass
Example #15
0
def train(cat_dim,
          noise_dim,
          batch_size,
          n_batch_per_epoch,
          nb_epoch,
          dset="mnist"):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: cat_dim, noise_dim, batch_size, n_batch_per_epoch, nb_epoch, dset
          (model hyperparameters)
    """
    general_utils.setup_logging("IG")
    # Load and rescale data
    if dset == "mnist":
        print("loading mnist data")
        X_real_train, Y_real_train, X_real_test, Y_real_test = data_utils.load_mnist(
        )
        # pick 1000 sample for testing
        # X_real_test = X_real_test[-1000:]
        # Y_real_test = Y_real_test[-1000:]

    img_dim = X_real_train.shape[-3:]
    epoch_size = n_batch_per_epoch * batch_size

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=2E-4,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_deconv",
                                      cat_dim,
                                      noise_dim,
                                      img_dim,
                                      batch_size,
                                      dset=dset)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          cat_dim,
                                          noise_dim,
                                          img_dim,
                                          batch_size,
                                          dset=dset)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        # Freeze the discriminator while the generator is learning
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   cat_dim, noise_dim)

        list_losses = ['binary_crossentropy', 'categorical_crossentropy']
        list_weights = [1, 1]
        DCGAN_model.compile(loss=list_losses,
                            loss_weights=list_weights,
                            optimizer=opt_dcgan)

        # Multiple discriminator losses
        # allow the discriminator to learn again
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses,
                                    loss_weights=list_weights,
                                    optimizer=opt_discriminator)
        # Start training
        print("Start training")
        for e in range(nb_epoch + 1):
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()
            print("Epoch: {}".format(e))
            for X_real_batch, Y_real_batch in zip(
                    data_utils.gen_batch(X_real_train, batch_size),
                    data_utils.gen_batch(Y_real_train, batch_size)):

                # Create a batch to feed the discriminator model
                X_disc_fake, y_disc_fake, noise_sample = data_utils.get_disc_batch(
                    X_real_batch,
                    Y_real_batch,
                    generator_model,
                    batch_size,
                    cat_dim,
                    noise_dim,
                    type="fake")
                X_disc_real, y_disc_real = data_utils.get_disc_batch(
                    X_real_batch,
                    Y_real_batch,
                    generator_model,
                    batch_size,
                    cat_dim,
                    noise_dim,
                    type="real")

                # Update the discriminator
                disc_loss_fake = discriminator_model.train_on_batch(
                    X_disc_fake, [y_disc_fake, Y_real_batch])
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, [y_disc_real, Y_real_batch])
                # train_on_batch returns a list for multi-output models, so
                # '+' concatenates the fake and real loss lists here
                disc_loss = disc_loss_fake + disc_loss_real
                # Create a batch to feed the generator model
                # X_noise, y_gen = data_utils.get_gen_batch(batch_size, cat_dim, noise_dim)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(
                    [Y_real_batch, noise_sample], [y_disc_real, Y_real_batch])
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                # training validation
                p_real_batch, p_Y_batch = discriminator_model.predict(
                    X_real_batch, batch_size=batch_size)
                acc_train = data_utils.accuracy(p_Y_batch, Y_real_batch)
                batch_counter += 1
                # progbar.add(batch_size, values=[("D tot", disc_loss[0]),
                #                                 ("D cat", disc_loss[2]),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G cat", gen_loss[2]),
                #                                 ("P Real:", p_real_batch),
                #                                 ("Q acc", acc_train)])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch // 2) == 0 and e % 10 == 0:
                    data_utils.plot_generated_batch(X_real_batch,
                                                    generator_model,
                                                    batch_size, cat_dim,
                                                    noise_dim, e)
                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' %
                  (e + 1, nb_epoch, time.time() - start))
            _, p_Y_test = discriminator_model.predict(
                X_real_test, batch_size=X_real_test.shape[0])
            acc_test = data_utils.accuracy(p_Y_test, Y_real_test)
            print("Epoch: {} Accuracy: {}".format(e + 1, acc_test))
            if e % 1000 == 0:
                gen_weights_path = os.path.join(
                    '../../models/IG/gen_weights.h5')
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/IG/disc_weights.h5')
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/IG/DCGAN_weights.h5')
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
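data_utils.accuracy above scores the categorical head of the discriminator against the true labels; a one-line sketch consistent with its call sites (an assumption):

import numpy as np

def accuracy(y_pred, y_true):
    # Fraction of samples whose predicted class matches the true class
    return np.mean(np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1))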
Example #16
0
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    pretrained_model_path = kwargs["pretrained_model_path"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
    img_dim = X_full_train.shape[-3:]


    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        load_pretrained = False
        if pretrained_model_path:
            load_pretrained = True

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim,
                                      nb_patch,
                                      bn_mode,
                                      use_mbd,
                                      batch_size,
                                      load_pretrained)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          img_dim_disc,
                                          nb_patch,
                                          bn_mode,
                                          use_mbd,
                                          batch_size,
                                          load_pretrained)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   img_dim,
                                   patch_size,
                                   image_data_format)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch,
                                                           X_sketch_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           patch_size,
                                                           image_data_format,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch // 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_data_format, "training")
                    X_full_batch, X_sketch_batch = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_data_format, "validation")

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            if e % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
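Several of these examples size their PatchGAN discriminator via data_utils.get_nb_patch. A sketch of the computation (assuming non-overlapping patches; the channels_first branch is an assumption):

def get_nb_patch(img_dim, patch_size, image_data_format):
    # Count non-overlapping patches and derive the discriminator input shape
    if image_data_format == "channels_last":
        nb_patch = (img_dim[0] // patch_size[0]) * (img_dim[1] // patch_size[1])
        img_dim_disc = (patch_size[0], patch_size[1], img_dim[-1])
    else:  # channels_first
        nb_patch = (img_dim[1] // patch_size[0]) * (img_dim[2] // patch_size[1])
        img_dim_disc = (img_dim[0], patch_size[0], patch_size[1])
    return nb_patch, img_dim_disc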
Example #17
0
# Create optimizers
G_opt = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
D_opt = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Load generator model
generator_model = models.load("generator", img_dim=(256, 256, 3))
# generator_model.load_weights('./models/pix2pix/gen_weights_epoch_6.h5')
generator_model.compile(loss='mae', optimizer=G_opt)

# Load discriminator model
discriminator_model = models.load("discriminator", img_dim=(256, 256, 3))
# discriminator_model.load_weights('./models/pix2pix/disc_weights_epoch_6.h5')
discriminator_model.trainable = False

DCGAN_model = models.DCGAN(generator_model,
                           discriminator_model,
                           img_dim=(256, 256, 3))
# DCGAN_model.load_weights('./models/pix2pix/DCGAN_weights_epoch_6.h5')

loss = [l1_loss, 'binary_crossentropy']
# loss = [perceptual_loss, 'binary_crossentropy']
loss_weights = [1E1, 1]
DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=G_opt)

discriminator_model.trainable = True
discriminator_model.compile(loss='binary_crossentropy', optimizer=D_opt)

# Start training
print("Start training")
for e in range(1, nb_epoch + 1):
    # Initialize progbar and batch counter