def train(args):
    """Fit a DCGAN or WDCGAN chosen by ``args.model`` on ``args.dataset``.

    Supported datasets are 'mnist' and 'random'; for 'random' the training
    set is reused as the validation set. Only the images (not the labels)
    are handed to ``model.fit``.

    ``args`` must provide: dataset, model, lr, c, n_critic, epochs,
    n_batch, logdir. The WDCGAN variant receives ``c`` and ``n_critic``
    through ``opt_params`` (presumably a Wasserstein-style clip value and
    critic-step count -- confirm in ``models``).

    Raises:
        ValueError: for an unknown dataset or model name.
    """
    import models
    import numpy as np  # only referenced by the disabled seeding line below
    # np.random.seed(1234)

    dataset = args.dataset
    if dataset == 'random':
        n_dim, n_out, n_channels = 2, 2, 1
        X_train, y_train = data.load_noise(n=1000, d=n_dim)
        X_val, y_val = X_train, y_train
    elif dataset == 'mnist':
        n_dim, n_out, n_channels = 28, 10, 1
        X_train, y_train, X_val, y_val, _, _ = data.load_mnist()
    else:
        raise ValueError('Invalid dataset name: %s' % dataset)

    # optimisation hyper-parameters forwarded to the model
    opt_params = dict(lr=args.lr, c=args.c, n_critic=args.n_critic)

    # model key -> class name inside the ``models`` module
    class_names = {'dcgan': 'DCGAN', 'wdcgan': 'WDCGAN'}
    if args.model not in class_names:
        raise ValueError('Invalid model')
    model_cls = getattr(models, class_names[args.model])
    model = model_cls(n_dim=n_dim, n_chan=n_channels, opt_params=opt_params)

    model.fit(X_train, X_val, n_epoch=args.epochs, n_batch=args.n_batch,
              logdir=args.logdir)
def train(args):
    """Load a dataset and train a DCGAN or WDCGAN on it.

    Args:
        args: parsed command-line namespace. Reads ``dataset``, ``model``,
            ``lr``, ``c``, ``n_critic``, ``epochs``, ``n_batch``, ``logdir``
            and, for the 'malware_clean_data' dataset, ``same_train_data``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.model`` is not recognised.
        NotImplementedError: if 'malware_clean_data' is requested with a
            falsy ``args.same_train_data`` -- no train/validation split is
            defined for that case (previously the function continued with
            ``X_train``/``X_val`` unbound and crashed at ``model.fit``).
    """
    # load dataset
    if args.dataset == 'mnist':
        n_dim, n_out, n_channels = 28, 10, 1
        X_train, y_train, X_val, y_val, _, _ = data.load_mnist()
    elif args.dataset == 'random':
        n_dim, n_out, n_channels = 2, 2, 1
        X_train, y_train = data.load_noise(n=1000, d=n_dim)
        X_val, y_val = X_train, y_train
    # extensible: add further datasets as additional branches here
    elif args.dataset == 'malware_clean_data':
        n_dim, n_channels = 64, 1
        # Check before the (potentially expensive) load so we fail fast.
        if not args.same_train_data:
            raise NotImplementedError(
                'malware_clean_data currently requires args.same_train_data')
        # The loader also returns a held-out test split, which is unused here.
        (xtrain_mal, ytrain_mal, xtrain_ben, ytrain_ben,
         _, _, _, _) = data.load_Malware_clean_ApkToImage()
        # Malware training samples are fitted; benign training samples act as
        # the "validation" set. NOTE(review): confirm this pairing is intended.
        X_train, y_train, X_val, y_val = xtrain_mal, ytrain_mal, xtrain_ben, ytrain_ben
    else:
        raise ValueError('Invalid dataset name: %s' % args.dataset)

    # set up optimization params
    opt_params = {'lr': args.lr, 'c': args.c, 'n_critic': args.n_critic}

    # create model
    if args.model == 'dcgan':
        model = models.DCGAN(n_dim=n_dim, n_chan=n_channels, opt_params=opt_params)
    elif args.model == 'wdcgan':
        model = models.WDCGAN(n_dim=n_dim, n_chan=n_channels, opt_params=opt_params)
    else:
        raise ValueError('Invalid model')

    # train model
    model.fit(X_train, y_train, X_val, y_val, n_epoch=args.epochs,
              n_batch=args.n_batch, logdir=args.logdir)
def _latest_weights(prev_model, pattern):
    """Return the last weights file in ``../../models/<prev_model>/`` matching *pattern*.

    ``train`` writes weight files with a zero-padded epoch (``epoch%05d``)
    right after a fixed prefix, so for files it wrote the lexicographically
    last match is the most recent epoch.

    Raises:
        FileNotFoundError: if nothing matches (previously a bare IndexError).
    """
    model_dir = os.path.join('../../models/', prev_model)
    matches = sorted(glob.glob(os.path.join(model_dir, pattern)))
    if not matches:
        raise FileNotFoundError('No weights matching %r found in %s' % (pattern, model_dir))
    return matches[-1]


def train(**kwargs):
    """Train a conditional DCGAN that maps sketch images to target images.

    The generator is a U-Net (``generator_unet_<generator_type>``). The
    discriminator works on ``nb_patch`` patches computed by
    ``data_utils.get_nb_patch``. All hyperparameters are keyword arguments,
    and every key read below is required (a missing key raises ``KeyError``).

    Data sources:
      * ``load_all_data_at_once`` true: arrays are loaded from
        ``../../data/processed/<dset>_data.h5``. If that file is missing, it
        is built with ``../data/make_dataset.py`` from
        ``../../data/<dset>/{train,val,test}``.
      * otherwise: batches are streamed from the ``<dset>/train`` and
        ``<dset>/val`` image directories.

    If ``dont_train`` is true, only the generator is built, compiled and
    returned. If ``prev_model`` is given, the newest gen/disc/DCGAN weights
    under ``../../models/<prev_model>/`` are loaded. When
    ``change_model_name_to_prev_model`` is also set, training continues under
    that run's name and after its last saved epoch.

    Ctrl-C (KeyboardInterrupt) ends training early. The generator is still
    saved to ``../../models/<model_name>/generator_latest.{h5,txt,json}``.

    Returns:
        The generator model. Returns ``None`` if a dataset check fails (a
        message is printed) or if training was interrupted before the
        generator was built.

    Raises:
        ValueError: if ``discriminator_optimizer`` is not 'sgd' or 'adam'.
        FileNotFoundError: if ``prev_model`` has no matching weight files.
    """
    # Roll out the parameters
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    image_data_format = kwargs["image_data_format"]
    generator_type = kwargs["generator_type"]
    dset = kwargs["dset"]
    use_identity_image = kwargs["use_identity_image"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    augment_data = kwargs["augment_data"]
    model_name = kwargs["model_name"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    use_mbd = kwargs["use_mbd"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping_prob = kwargs["label_flipping_prob"]
    use_l1_weighted_loss = kwargs["use_l1_weighted_loss"]
    prev_model = kwargs["prev_model"]
    change_model_name_to_prev_model = kwargs["change_model_name_to_prev_model"]
    discriminator_optimizer = kwargs["discriminator_optimizer"]
    n_run_of_gen_for_1_run_of_disc = kwargs["n_run_of_gen_for_1_run_of_disc"]
    load_all_data_at_once = kwargs["load_all_data_at_once"]
    MAX_FRAMES_PER_GIF = kwargs["MAX_FRAMES_PER_GIF"]
    dont_train = kwargs["dont_train"]

    # Number of non-overlapping patches and the discriminator's input size
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    if dont_train:
        generator_model = models.load("generator_unet_%s" % generator_type, img_dim, nb_patch,
                                      use_mbd, batch_size, model_name)
        generator_model.compile(loss='mae', optimizer='adam')
        return generator_model

    # Fail fast: an unknown value used to surface much later as a NameError
    # on opt_discriminator, after the dataset had already been built.
    if discriminator_optimizer not in ('sgd', 'adam'):
        raise ValueError("discriminator_optimizer must be 'sgd' or 'adam', got %r"
                         % (discriminator_optimizer,))

    # Check and make the dataset
    if load_all_data_at_once:
        # If the .h5 file of dset is not present, try making it
        if not os.path.exists("../../data/processed/%s_data.h5" % dset):
            print("dset %s_data.h5 not present in '../../data/processed'!" % dset)
            if not os.path.exists("../../data/%s/" % dset):
                print("dset folder %s not present in '../../data'!\n\nERROR: Dataset .h5 file not made, and dataset not available in '../../data/'.\n\nQuitting." % dset)
                return
            if not all(os.path.exists("../../data/%s/%s" % (dset, split))
                       for split in ('train', 'val', 'test')):
                print("'train', 'val' or 'test' folders not present in dset folder '../../data/%s'!\n\nERROR: Dataset must contain 'train', 'val' and 'test' folders.\n\nQuitting." % dset)
                return
            print("Making %s dataset" % dset)
            subprocess.call(['python3', '../data/make_dataset.py', '../../data/%s' % dset, '3'])
            print("Done!")
    else:
        if not os.path.exists(dset):
            print("dset does not exist! Given:", dset)
            return
        if not os.path.exists(os.path.join(dset, 'train')):
            print("dset does not contain a 'train' dir! Given dset:", dset)
            return
        if not os.path.exists(os.path.join(dset, 'val')):
            print("dset does not contain a 'val' dir! Given dset:", dset)
            return

    init_epoch = 0
    if prev_model:
        print('\n\nLoading prev_model from', prev_model, '...\n\n')
        prev_model_latest_gen = _latest_weights(prev_model, '*gen*epoch*.h5')
        prev_model_latest_disc = _latest_weights(prev_model, '*disc*epoch*.h5')
        prev_model_latest_DCGAN = _latest_weights(prev_model, '*DCGAN*epoch*.h5')
        print(prev_model_latest_gen, prev_model_latest_disc, prev_model_latest_DCGAN)
        if change_model_name_to_prev_model:
            # Find prev model name, and resume after its last saved epoch
            model_name = prev_model_latest_DCGAN.split('models')[-1].split('/')[1]
            init_epoch = int(prev_model_latest_DCGAN.split('epoch')[1][:5]) + 1

    # Bound up front so an interrupt before model creation can be detected below.
    generator_model = None
    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        if discriminator_optimizer == 'sgd':
            opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        else:  # 'adam' (validated above)
            opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator_type, img_dim, nb_patch,
                                      use_mbd, batch_size, model_name)
        generator_model.compile(loss='mae', optimizer=opt_dcgan)

        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc, nb_patch,
                                          use_mbd, batch_size, model_name)

        # The discriminator is frozen inside the combined model, so compiling
        # DCGAN_model only trains the generator's weights.
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim,
                                   patch_size, image_data_format)

        if use_l1_weighted_loss:
            loss = [l1_weighted_loss, 'binary_crossentropy']
        else:
            loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        # Load prev_model
        if prev_model:
            generator_model.load_weights(prev_model_latest_gen)
            discriminator_model.load_weights(prev_model_latest_disc)
            DCGAN_model.load_weights(prev_model_latest_DCGAN)

        print('\n\nLoading data...\n\n')
        check_this_process_memory()

        if load_all_data_at_once:
            # Load .h5 data all at once
            X_target_train, X_sketch_train, X_target_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
            check_this_process_memory()
            print('X_target_train: %.4f' % (X_target_train.nbytes/2**30), "GB")
            print('X_sketch_train: %.4f' % (X_sketch_train.nbytes/2**30), "GB")
            print('X_target_val: %.4f' % (X_target_val.nbytes/2**30), "GB")
            print('X_sketch_val: %.4f' % (X_sketch_val.nbytes/2**30), "GB")

            # To generate training data
            X_target_batch_gen_train, X_sketch_batch_gen_train = data_utils.data_generator(
                X_target_train, X_sketch_train, batch_size, augment_data=augment_data)
            X_target_batch_gen_val, X_sketch_batch_gen_val = data_utils.data_generator(
                X_target_val, X_sketch_val, batch_size, augment_data=False)
        else:
            # Load data from images through an ImageDataGenerator
            X_batch_gen_train = data_utils.data_generator_from_dir(
                os.path.join(dset, 'train'), target_size=(img_dim[0], 2*img_dim[1]), batch_size=batch_size)
            X_batch_gen_val = data_utils.data_generator_from_dir(
                os.path.join(dset, 'val'), target_size=(img_dim[0], 2*img_dim[1]), batch_size=batch_size)

        check_this_process_memory()

        def _next_train_batch():
            """Return one (target, sketch) training batch from the source set up above."""
            if load_all_data_at_once:
                return next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
            return data_utils.load_data_from_data_generator_from_dir(
                X_batch_gen_train, img_dim=img_dim, augment_data=augment_data,
                use_identity_image=use_identity_image)

        # Setup environment (logging directory etc)
        general_utils.setup_logging(**kwargs)

        # Per-epoch loss histories
        disc_losses = []
        gen_total_losses = []
        gen_L1_losses = []
        gen_log_losses = []

        print("\n\nStarting training...\n\n")

        for e in range(nb_epoch):
            batch_counter = 0
            gen_total_loss_epoch = 0
            gen_L1_loss_epoch = 0
            gen_log_loss_epoch = 0
            start = time.time()

            for batch in range(n_batch_per_epoch):
                # Update the discriminator on one batch
                X_target_batch_train, X_sketch_batch_train = _next_train_batch()
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_target_batch_train, X_sketch_batch_train, generator_model,
                    batch_counter, patch_size, image_data_format,
                    label_smoothing=label_smoothing, label_flipping_prob=label_flipping_prob)
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Generator targets: one-hot rows with column 1 set
                # (presumably the "real" class -- confirm against get_disc_batch).
                X_gen_target, X_gen_sketch = _next_train_batch()
                y_gen_target = np.zeros((X_gen_target.shape[0], 2), dtype=np.uint8)
                y_gen_target[:, 1] = 1

                # Freeze the discriminator while the generator trains
                discriminator_model.trainable = False
                # n_run_of_gen_for_1_run_of_disc generator steps in total,
                # each on a fresh batch; losses are averaged over the steps.
                for _ in range(n_run_of_gen_for_1_run_of_disc - 1):
                    gen_loss = DCGAN_model.train_on_batch(X_gen_sketch, [X_gen_target, y_gen_target])
                    gen_total_loss_epoch += gen_loss[0]/n_run_of_gen_for_1_run_of_disc
                    gen_L1_loss_epoch += gen_loss[1]/n_run_of_gen_for_1_run_of_disc
                    gen_log_loss_epoch += gen_loss[2]/n_run_of_gen_for_1_run_of_disc
                    X_gen_target, X_gen_sketch = _next_train_batch()
                gen_loss = DCGAN_model.train_on_batch(X_gen_sketch, [X_gen_target, y_gen_target])
                gen_total_loss_epoch += gen_loss[0]/n_run_of_gen_for_1_run_of_disc
                gen_L1_loss_epoch += gen_loss[1]/n_run_of_gen_for_1_run_of_disc
                gen_log_loss_epoch += gen_loss[2]/n_run_of_gen_for_1_run_of_disc
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                print("Epoch", str(init_epoch+e+1), "batch", str(batch+1), "D_logloss", disc_loss,
                      "G_tot", gen_loss[0], "G_L1", gen_loss[1], "G_log", gen_loss[2])

                # The counter is handed to get_disc_batch; it was previously
                # never advanced, so every call saw 0.
                batch_counter += 1

            gen_total_loss = gen_total_loss_epoch/n_batch_per_epoch
            gen_L1_loss = gen_L1_loss_epoch/n_batch_per_epoch
            gen_log_loss = gen_log_loss_epoch/n_batch_per_epoch
            # Only the last batch's discriminator loss is recorded per epoch
            disc_losses.append(disc_loss)
            gen_total_losses.append(gen_total_loss)
            gen_L1_losses.append(gen_L1_loss)
            gen_log_losses.append(gen_log_loss)

            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_target_batch_train, X_sketch_batch_train, generator_model,
                                                batch_size, image_data_format, model_name, "training",
                                                init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Get new images for validation
                if load_all_data_at_once:
                    X_target_batch_val, X_sketch_batch_val = next(X_target_batch_gen_val), next(X_sketch_batch_gen_val)
                else:
                    X_target_batch_val, X_sketch_batch_val = data_utils.load_data_from_data_generator_from_dir(
                        X_batch_gen_val, img_dim=img_dim, augment_data=False,
                        use_identity_image=use_identity_image)
                # Predict and validate
                data_utils.plot_generated_batch(X_target_batch_val, X_sketch_batch_val, generator_model,
                                                batch_size, image_data_format, model_name, "validation",
                                                init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Plot losses
                data_utils.plot_losses(disc_losses, gen_total_losses, gen_L1_losses, gen_log_losses,
                                       model_name, init_epoch)

            # Save weights
            if (e + 1) % save_weights_every_n_epochs == 0:
                # Delete all but the last n weights
                purge_weights(save_only_last_n_weights, model_name)
                # All three files share the same epoch/loss suffix
                suffix = 'epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (
                    init_epoch + e, disc_losses[-1], gen_total_losses[-1],
                    gen_L1_losses[-1], gen_log_losses[-1])
                for prefix, net in (('gen', generator_model),
                                    ('disc', discriminator_model),
                                    ('DCGAN', DCGAN_model)):
                    weights_path = os.path.join('../../models/%s/%s_weights_%s' % (model_name, prefix, suffix))
                    print("Saving", weights_path)
                    net.save_weights(weights_path, overwrite=True)

            check_this_process_memory()
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d} END, Time taken: {3:.4f} seconds'.format(
                datetime.datetime.now(), init_epoch + e + 1, init_epoch + nb_epoch, time.time() - start))
            print('------------------------------------------------------------------------------------')

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C stops training but still saves the generator below
        pass

    # SAVE THE MODEL
    if generator_model is not None:
        latest_path = '../../models/%s/generator_latest' % model_name
        try:
            # Save the model as it is, so that it can be loaded using -
            # ```from keras.models import load_model; gen = load_model('generator_latest.h5')```
            print("Saving", latest_path + '.h5')
            generator_model.save(latest_path + '.h5', overwrite=True)

            # Save model as json string
            generator_model_json_string = generator_model.to_json()
            print("Saving", latest_path + '.txt')
            with open(latest_path + '.txt', 'w') as outfile:
                outfile.write(generator_model_json_string)

            # Save model as json
            generator_model_json_data = json.loads(generator_model_json_string)
            print("Saving", latest_path + '.json')
            with open(latest_path + '.json', 'w') as outfile:
                json.dump(generator_model_json_data, outfile)
        except Exception as err:
            # Best-effort save: report the failure but still return the model
            print("Could not save generator:", type(err).__name__, err)

    print("Done.")

    return generator_model
def main(dataset, batch_size, patch_size, epochs, label_smoothing, label_flipping):
    """Train a U-Net generator against a patch discriminator (TF1 / Keras).

    Args:
        dataset: path handed to ``data_utils.DataGenerator(file_path=...)``
            for the "train" and "val" splits.
        batch_size: samples per batch for every data generator.
        patch_size: discriminator patch size (see ``data_utils.get_nb_patch``).
        epochs: number of training epochs.
        label_smoothing, label_flipping: forwarded to
            ``data_utils.get_disc_batch``.

    Side effects:
      * Creates ``<project_dir>/models/<yymmdd-HHMM>/``. ``os.makedirs``
        raises if it already exists.
      * Writes model diagrams to ``reports/figures``.
      * Writes TensorBoard logs and preview images to
        ``reports/logs/<model_name>``.
      * Saves generator/discriminator/combined weights after the final epoch,
        and every ``save_model_every_n_epochs`` epochs when that is >= 1.

    Ctrl-C stops training early without a final save. The background data
    workers are shut down in every case.
    """
    print(project_dir)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
    sess = tf.Session(config=config)
    K.tensorflow_backend.set_session(sess)  # set this TensorFlow session as the default session for Keras

    image_data_format = "channels_first"
    K.set_image_data_format(image_data_format)

    save_images_every_n_batches = 30
    save_model_every_n_epochs = 0  # 0 disables periodic saving; the final epoch is always saved

    # configuration parameters
    print("Config params:")
    print(" dataset = {}".format(dataset))
    print(" batch_size = {}".format(batch_size))
    print(" patch_size = {}".format(patch_size))
    print(" epochs = {}".format(epochs))
    print(" label_smoothing = {}".format(label_smoothing))
    print(" label_flipping = {}".format(label_flipping))
    print(" save_images_every_n_batches = {}".format(save_images_every_n_batches))
    print(" save_model_every_n_epochs = {}".format(save_model_every_n_epochs))

    model_name = datetime.strftime(datetime.now(), '%y%m%d-%H%M')
    model_dir = os.path.join(project_dir, "models", model_name)
    fig_dir = os.path.join(project_dir, "reports", "figures")
    logs_dir = os.path.join(project_dir, "reports", "logs", model_name)

    os.makedirs(model_dir)

    # Load and rescale data. The generator and discriminator use separate
    # shuffled streams of the training split.
    ds_train_gen = data_utils.DataGenerator(file_path=dataset, dataset_type="train", batch_size=batch_size)
    ds_train_disc = data_utils.DataGenerator(file_path=dataset, dataset_type="train", batch_size=batch_size)
    ds_val = data_utils.DataGenerator(file_path=dataset, dataset_type="val", batch_size=batch_size)
    enq_train_gen = OrderedEnqueuer(ds_train_gen, use_multiprocessing=True, shuffle=True)
    enq_train_disc = OrderedEnqueuer(ds_train_disc, use_multiprocessing=True, shuffle=True)
    enq_val = OrderedEnqueuer(ds_val, use_multiprocessing=True, shuffle=False)

    img_dim = ds_train_gen[0][0].shape[-3:]
    n_batch_per_epoch = len(ds_train_gen)
    epoch_size = n_batch_per_epoch * batch_size

    print("Derived params:")
    print(" n_batch_per_epoch = {}".format(n_batch_per_epoch))
    print(" epoch_size = {}".format(epoch_size))
    print(" n_batches_val = {}".format(len(ds_val)))

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size)

    tensorboard = TensorBoard(log_dir=logs_dir, histogram_freq=0, batch_size=batch_size,
                              write_graph=True, write_grads=True, update_freq='batch')

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.generator_unet_upsampling(img_dim)
        generator_model.summary()
        plot_model(generator_model, to_file=os.path.join(fig_dir, "generator_model.png"),
                   show_shapes=True, show_layer_names=True)

        # Load discriminator model
        # TODO: modify disc to accept real input as well
        discriminator_model = models.DCGAN_discriminator(img_dim_disc, nb_patch)
        discriminator_model.summary()
        plot_model(discriminator_model, to_file=os.path.join(fig_dir, "discriminator_model.png"),
                   show_shapes=True, show_layer_names=True)

        # TODO: pretty sure this is unnecessary
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        # Frozen inside the combined model so DCGAN_model only trains the generator
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim, patch_size, image_data_format)

        # L1 loss applies to generated image, cross entropy applies to predicted label
        loss = [models.l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        tensorboard.set_model(DCGAN_model)

        # Start training
        enq_train_gen.start(workers=1, max_queue_size=20)
        enq_train_disc.start(workers=1, max_queue_size=20)
        enq_val.start(workers=1, max_queue_size=20)
        out_train_gen = enq_train_gen.get()
        out_train_disc = enq_train_disc.get()
        out_val = enq_val.get()

        print("Start training")
        for e in range(1, epochs + 1):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            start = time.time()
            logs = {}  # keeps on_epoch_end valid even if an epoch has no batches

            for batch_counter in range(1, n_batch_per_epoch + 1):
                X_transformed_batch, X_orig_batch = next(out_train_disc)

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_transformed_batch, X_orig_batch, generator_model, batch_counter, patch_size,
                    label_smoothing=label_smoothing, label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(out_train_gen)
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                # Set labels to 1 (real) to maximize the discriminator loss
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                metrics = [("D logloss", disc_loss),
                           ("G tot", gen_loss[0]),
                           ("G L1", gen_loss[1]),
                           ("G logloss", gen_loss[2])]
                progbar.add(batch_size, values=metrics)

                logs = {k: v for (k, v) in metrics}
                logs["size"] = batch_size

                tensorboard.on_batch_end(batch_counter, logs=logs)

                # Save images for visualization
                if batch_counter % save_images_every_n_batches == 0:
                    # Plot the current training batch, then a fresh validation batch
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_training.png"))
                    X_transformed_batch, X_orig_batch = next(out_val)
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_validation.png"))

            print("")
            print('Epoch %s/%s, Time: %s' % (e, epochs, time.time() - start))
            tensorboard.on_epoch_end(e, logs=logs)

            if (save_model_every_n_epochs >= 1 and e % save_model_every_n_epochs == 0) or \
                    (e == epochs):
                print("Saving model for epoch {}...".format(e), end="")
                sys.stdout.flush()
                gen_weights_path = os.path.join(model_dir, 'gen_weights_epoch{:03d}.h5'.format(e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(model_dir, 'disc_weights_epoch{:03d}.h5'.format(e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join(model_dir, 'DCGAN_weights_epoch{:03d}.h5'.format(e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
                print("done")

    except KeyboardInterrupt:
        pass
    finally:
        # Stop the multiprocessing workers on any exit path (previously they
        # leaked on non-KeyboardInterrupt errors). Skip enqueuers that were
        # never started, since stop() on those fails.
        for enqueuer in (enq_train_gen, enq_train_disc, enq_val):
            if enqueuer.is_running():
                enqueuer.stop()
def gen_theano_fn(args):
    """Build the generator/discriminator networks and compile Theano functions.

    Args:
        args: namespace. Reads ``args.verbose`` (extra progress messages) and
            ``args.captions`` (when true, the generator is also conditioned on
            an embedding matrix). It is passed on to the model constructors.

    Returns:
        ``(train_d, train_g, predict, reconstr, reconstr_noise_shrd,
        (discriminator, generator))`` where, with ``[embd]`` present only
        when ``args.captions`` is true:

        * ``train_d(images, noise[, embd])`` takes one Adam step on the
          discriminator and returns its loss.
        * ``train_g(noise[, embd])`` takes one Adam step on the generator and
          returns its loss.
        * ``predict(noise[, embd])`` returns deterministic fake images and the
          discriminator's probabilities for them.
        * ``reconstr(corr_image, corr_mask[, embd])`` returns
          ``[noise, image, loss, grad]`` and then moves the shared noise
          ``reconstr_noise_shrd`` one gradient step (lr 0.9) downhill.
    """
    if args.verbose:
        print('Creating networks...')

    # Setup input variables
    inpt_noise = T.matrix('inpt_noise')
    inpt_image = T.tensor4('inpt_image')
    corr_mask = T.matrix('corr_mask')  # corruption mask
    corr_image = T.tensor4('corr_image')
    if args.captions:
        inpt_embd = T.matrix('inpt_embedding')

    # Shared 1x100 latent vector that `reconstr` optimises in place
    reconstr_noise_shrd = theano.shared(
        np.random.uniform(-1., 1., size=(1, 100)).astype(theano.config.floatX))

    # Build generator and discriminator
    if args.captions:
        cond_gen_dc_gan = models.CaptionGenOnlyDCGAN(args)
        generator, lyr_gen_noise, lyr_gen_embd = cond_gen_dc_gan.init_generator(
            first_layer=64, input_var=None, embedding_var=None)
        discriminator = cond_gen_dc_gan.init_discriminator(first_layer=128, input_var=None)
    else:
        dc_gan = models.DCGAN(args)
        generator = dc_gan.init_generator(first_layer=64, input_var=None)
        discriminator = dc_gan.init_discriminator(first_layer=128, input_var=None)

    def _gen_inputs(noise):
        """Map *noise* (plus the embedding when captioned) onto the generator's inputs."""
        if args.captions:
            return {lyr_gen_noise: noise, lyr_gen_embd: inpt_embd}
        return noise

    # Get images from generator (for training and outputting images)
    image_fake = lyr.get_output(generator, inputs=_gen_inputs(inpt_noise))
    image_fake_det = lyr.get_output(generator, inputs=_gen_inputs(inpt_noise), deterministic=True)
    image_reconstr = lyr.get_output(generator, inputs=_gen_inputs(reconstr_noise_shrd), deterministic=True)

    # Get probabilities from discriminator
    probs_real = lyr.get_output(discriminator, inputs=inpt_image)
    probs_fake = lyr.get_output(discriminator, inputs=image_fake)
    probs_fake_det = lyr.get_output(discriminator, inputs=image_fake_det, deterministic=True)
    probs_reconstr = lyr.get_output(discriminator, inputs=image_reconstr, deterministic=True)

    # Discriminator loss: minimize prob of error on true and on fake images
    d_loss_real = -T.mean(T.log(probs_real))
    d_loss_fake = -T.mean(T.log(1 - probs_fake))
    loss_discr = d_loss_real + d_loss_fake

    # Generator loss: minimize the error of the discriminator on fake images
    loss_gener = -T.mean(T.log(probs_fake))

    # Create params dict for both discriminator and generator
    params_discr = lyr.get_all_params(discriminator, trainable=True)
    params_gener = lyr.get_all_params(generator, trainable=True)

    # Set update rules for params using adam
    updates_discr = lasagne.updates.adam(loss_discr, params_discr, learning_rate=0.001, beta1=0.9)
    updates_gener = lasagne.updates.adam(loss_gener, params_gener, learning_rate=0.0005, beta1=0.6)

    # Losses for reconstructing a corrupted image: contextual (squared error on
    # the pixels kept by corr_mask) plus a small perceptual term from the
    # discriminator's probability for the reconstruction.
    contx_loss = T.mean(
        lasagne.objectives.squared_error(image_reconstr * corr_mask, corr_image * corr_mask))
    prcpt_loss = T.mean(T.log(1 - probs_reconstr))

    # Total loss
    lbda = 10.0**-5
    reconstr_loss = contx_loss + lbda * prcpt_loss

    # Set update rule that will change the input noise
    grad = T.grad(reconstr_loss, reconstr_noise_shrd)
    lr = 0.9
    update_rule = reconstr_noise_shrd - lr * grad

    if args.verbose:
        print('Networks created.')

    # Captioned models take the embedding as an extra trailing input.
    cond_inputs = [inpt_embd] if args.captions else []

    # Compile Theano functions
    print('compiling...')
    train_d = theano.function([inpt_image, inpt_noise] + cond_inputs, loss_discr, updates=updates_discr)
    print('- 1 of 4 compiled.')
    train_g = theano.function([inpt_noise] + cond_inputs, loss_gener, updates=updates_gener)
    print('- 2 of 4 compiled.')
    predict = theano.function([inpt_noise] + cond_inputs, [image_fake_det, probs_fake_det])
    print('- 3 of 4 compiled.')
    reconstr = theano.function(
        [corr_image, corr_mask] + cond_inputs,
        [reconstr_noise_shrd, image_reconstr, reconstr_loss, grad],
        updates=[(reconstr_noise_shrd, update_rule)])
    print('- 4 of 4 compiled.')
    print('compiled.')

    return train_d, train_g, predict, reconstr, reconstr_noise_shrd, (discriminator, generator)
def train(args):
    """Load ``args.dataset``, build the model named by ``args.model`` and fit it.

    Datasets: 'digits', 'mnist', 'svhn', 'cifar10' (the last two pass through
    ``data.prepare_dataset``) and 'random' (training data reused as
    validation data). The generative models are given ``model='bernoulli'``
    for 'digits'/'mnist' and ``'gaussian'`` otherwise. For 'ssadgm' the
    training set is split into labelled and unlabelled parts; the model is
    built on the labelled part and fitted on the unlabelled part.

    ``np.random.seed(1234)`` is set first so runs are repeatable.

    Raises:
        ValueError: for an unknown dataset or model name.
    """
    import models
    import numpy as np
    np.random.seed(1234)

    if args.dataset == 'digits':
        n_dim, n_out, n_channels = 8, 10, 1
        X_train, y_train, X_val, y_val = data.load_digits()
    elif args.dataset == 'mnist':
        n_dim, n_out, n_channels = 28, 10, 1
        X_train, y_train, X_val, y_val, _, _ = data.load_mnist()
    elif args.dataset == 'svhn':
        n_dim, n_out, n_channels = 32, 10, 3
        X_train, y_train, X_val, y_val = data.load_svhn()
        X_train, y_train, X_val, y_val = data.prepare_dataset(X_train, y_train, X_val, y_val)
    elif args.dataset == 'cifar10':
        n_dim, n_out, n_channels = 32, 10, 3
        X_train, y_train, X_val, y_val = data.load_cifar10()
        X_train, y_train, X_val, y_val = data.prepare_dataset(X_train, y_train, X_val, y_val)
    elif args.dataset == 'random':
        n_dim, n_out, n_channels = 2, 2, 1
        X_train, y_train = data.load_noise(n=1000, d=n_dim)
        X_val, y_val = X_train, y_train
    else:
        raise ValueError('Invalid dataset name: %s' % args.dataset)
    # Python 3 print call (the statement form is a SyntaxError on Python 3)
    print('dataset loaded, dim:', X_train.shape)

    # set up optimization params
    p = {'lr': args.lr, 'b1': args.b1, 'b2': args.b2}

    # Keyword arguments shared by every model constructor
    common = dict(n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)
    # Output likelihood for the generative models (binary-valued for digit datasets)
    likelihood = 'bernoulli' if args.dataset in ('digits', 'mnist') else 'gaussian'

    # create model
    if args.model == 'softmax':
        model = models.Softmax(n_dim=n_dim, n_out=n_out, **common)
    elif args.model == 'mlp':
        model = models.MLP(n_dim=n_dim, n_out=n_out, **common)
    elif args.model == 'cnn':
        model = models.CNN(n_dim=n_dim, n_out=n_out, n_chan=n_channels, model=args.dataset, **common)
    elif args.model == 'kcnn':
        model = models.KCNN(n_dim=n_dim, n_out=n_out, n_chan=n_channels, model=args.dataset, **common)
    elif args.model == 'resnet':
        model = models.Resnet(n_dim=n_dim, n_out=n_out, n_chan=n_channels, **common)
    elif args.model == 'vae':
        model = models.VAE(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                           model=likelihood, **common)
    elif args.model == 'convvae':
        model = models.ConvVAE(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                               model=likelihood, **common)
    elif args.model == 'convadgm':
        model = models.ConvADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                                model=likelihood, **common)
    elif args.model == 'sbn':
        model = models.SBN(n_dim=n_dim, n_out=n_out, n_chan=n_channels, **common)
    elif args.model == 'adgm':
        model = models.ADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                            model=likelihood, **common)
    elif args.model == 'hdgm':
        model = models.HDGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch, **common)
    elif args.model == 'dadgm':
        model = models.DADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, **common)
    elif args.model == 'dcgan':
        model = models.DCGAN(n_dim=n_dim, n_out=n_out, n_chan=n_channels, **common)
    elif args.model == 'ssadgm':
        X_train_lbl, y_train_lbl, X_train_unl, y_train_unl \
            = data.split_semisup(X_train, y_train, n_lbl=args.n_labeled)
        model = models.SSADGM(X_labeled=X_train_lbl, y_labeled=y_train_lbl, n_out=n_out, **common)
        X_train, y_train = X_train_unl, y_train_unl
    else:
        raise ValueError('Invalid model')

    # train model
    model.fit(X_train, y_train, X_val, y_val,
              n_epoch=args.epochs, n_batch=args.n_batch, logname=args.logname)
def train(**kwargs):
    """Pre-train an error/depth-map generator, then train it adversarially.

    This variant streams batches from ``data_utils.facades_generator`` instead
    of loading the dataset in memory, and builds the generator with
    ``CreatErrorMapModel`` on a fixed 256x256x3 input.

    Phase 1 fits the generator alone (MAE loss) with Keras callbacks that
    checkpoint weights into a run-specific ``../../models/...`` folder and write
    TensorBoard logs into ``../../log/...``.
    Phase 2 alternates discriminator and generator updates through the
    combined DCGAN model, saving all three models' weights every 5 epochs.

    args: **kwargs (dict) keyword arguments that specify the model
        hyperparameters; every key read below is required.

    A KeyboardInterrupt stops training and is swallowed.
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]  # unused here; still required by the interface
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]  # unused: in-memory loading is disabled in this variant
    use_mbd = kwargs["use_mbd"]
    lastLayerActivation = kwargs["lastLayerActivation"]
    PercentageOfTrianable = kwargs["PercentageOfTrianable"]
    SpecificPathStr = kwargs["SpecificPathStr"]

    epoch_size = n_batch_per_epoch * batch_size

    # The img_dim kwarg is deliberately overridden: input size is entered manually.
    img_dim = (256, 256, 3)  # Manual entry

    # Number of non-overlapping patches and the discriminator's input size
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Generator: error/depth-map model (replaces the stock U-Net loader)
        generator_model = CreatErrorMapModel(input_shape=img_dim,
                                             lastLayerActivation=lastLayerActivation,
                                             PercentageOfTrianable=PercentageOfTrianable)
        # Discriminator
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc, nb_patch,
                                          bn_mode, use_mbd, batch_size)
        generator_model.compile(loss='mae', optimizer=opt_discriminator)

        # ----------------------- Output directories -----------------------
        run_tag = lastLayerActivation + str(PercentageOfTrianable)
        logpath = os.path.join('../../log', 'DepthMapWith' + run_tag + 'UnTr' + SpecificPathStr)
        modelPath = os.path.join('../../models', 'DepthMapwith' + run_tag + 'Untr' + SpecificPathStr)
        os.makedirs(logpath, exist_ok=True)
        os.makedirs(modelPath, exist_ok=True)

        # --------------------- Pre-training the depth map ---------------------
        nb_train_samples = 2000
        nb_validation_samples = epochs = 20

        def _input_target_pairs():
            # facades_generator yields (target, input) -- the adversarial loop
            # below unpacks it as ``X_gen_target, X_gen`` -- whereas Keras
            # expects (input, target), so swap each pair.
            for target_batch, input_batch in data_utils.facades_generator(img_dim, batch_size=batch_size):
                yield input_batch, target_batch

        # NOTE(review): the generator is compiled without metrics, so the 'acc'
        # checkpoint below may never trigger -- confirm the intended metric.
        generator_model.fit_generator(
            _input_target_pairs(),
            samples_per_epoch=nb_train_samples,
            epochs=epochs,
            validation_data=_input_target_pairs(),
            nb_val_samples=nb_validation_samples,
            callbacks=[
                keras.callbacks.ModelCheckpoint(os.path.join(modelPath, 'DepthMap_weightsBestLoss.h5'),
                                                monitor='val_loss', verbose=1, save_best_only=True),
                keras.callbacks.ModelCheckpoint(os.path.join(modelPath, 'DepthMap_weightsBestAcc.h5'),
                                                monitor='acc', verbose=1, save_best_only=True),
                keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.1, patience=2, verbose=1,
                                                  mode='auto', epsilon=0.0001, cooldown=0, min_lr=0),
                keras.callbacks.TensorBoard(log_dir=logpath, histogram_freq=0, batch_size=batch_size,
                                            write_graph=True, write_grads=False, write_images=True,
                                            embeddings_freq=0, embeddings_layer_names=None,
                                            embeddings_metadata=None),
            ],
        )

        # ------------------------ Adversarial training ------------------------
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim, patch_size, image_data_format)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.facades_generator(img_dim, batch_size=batch_size):
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch, X_sketch_batch, generator_model,
                                                           batch_counter, patch_size, image_data_format,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(data_utils.facades_generator(img_dim, batch_size=batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Train the generator through DCGAN with the discriminator frozen
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images of the current training batch for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    figure_name = "training_" + str(e)
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_data_format, figure_name)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Checkpoint all three models every 5 epochs (starting at epoch 0)
            if e % 5 == 0:
                for prefix, net in (("gen", generator_model),
                                    ("disc", discriminator_model),
                                    ("DCGAN", DCGAN_model)):
                    weights_path = os.path.join('../../models/%s/%s_weights_epoch%s.h5' % (model_name, prefix, e))
                    net.save_weights(weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
def train(**kwargs):
    """Resume pix2pix training from a saved generator checkpoint.

    Loads the full train/val data in memory, rebuilds the generator,
    discriminator and combined DCGAN, and loads generator weights from a
    previous iteration. It then alternates discriminator/generator updates.
    After each epoch it records training losses, running averages and a
    validation loss. Every 5 epochs it saves weights and writes the loss
    history to CSV.

    args: **kwargs (dict) keyword arguments that specify the model
        hyperparameters. Optional keys (defaults keep the previous behavior):
        - ``iter_num`` (default 102): tag used in saved file names; the
          default resume path points at iteration ``iter_num - 1``.
        - ``resume_weights_path``: generator weights file to load.

    A KeyboardInterrupt stops training and is swallowed.
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    iter_num = kwargs.get("iter_num", 102)
    weights_path = kwargs.get(
        "resume_weights_path",
        "/home/abhik/pix2pix/src/model/weights/gen_weights_iter%s_epoch30.h5" % (str(iter_num - 1)))

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_dim_ordering)
    img_dim = X_full_train.shape[-3:]
    print("data loaded in memory")

    # Number of non-overlapping patches and the discriminator's input size
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_dim_ordering)

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator and discriminator models
        generator_model = models.load("generator_unet_%s" % generator, img_dim, nb_patch,
                                      bn_mode, use_mbd, batch_size)
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc, nb_patch,
                                          bn_mode, use_mbd, batch_size)
        generator_model.compile(loss='mae', optimizer=opt_discriminator)

        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim, patch_size, image_dim_ordering)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = None
        disc_loss = None

        # Resume the generator from the previous iteration
        print(weights_path)
        generator_model.load_weights(weights_path)
        print("Weights Loaded for iter - %d" % iter_num)

        # One row per epoch (13 columns):
        # epoch, iter, D, G_tot, G_L1, G_log, D_avg, G1_avg, G2_avg, G3_avg,
        # val_tot, val_L1, val_log
        losses_list = []

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch, X_sketch_batch, generator_model,
                                                           batch_counter, patch_size, image_dim_ordering,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model (input = sketch, target = full)
                X_gen_target, X_gen = next(data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Train the generator through DCGAN with the discriminator frozen
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save training and validation images for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model, batch_size,
                                                    image_dim_ordering, "training", iter_num)
                    X_full_vis, X_sketch_vis = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(X_full_vis, X_sketch_vis, generator_model, batch_size,
                                                    image_dim_ordering, "validation", iter_num)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Running average (state kept by running_avg)
            disc_avg, gen1_avg, gen2_avg, gen3_avg = running_avg(disc_loss, gen_loss[0], gen_loss[1], gen_loss[2])

            # Validation loss on a fresh validation batch; input/target ordering
            # must match training: the sketch is the generator input and the full
            # image is the L1 target.
            X_full_val_batch, X_sketch_val_batch = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
            y_gen_val = np.zeros((X_sketch_val_batch.shape[0], 2), dtype=np.uint8)
            y_gen_val[:, 1] = 1
            val_loss = DCGAN_model.test_on_batch(X_sketch_val_batch, [X_full_val_batch, y_gen_val])

            losses_list.append([
                e + 1, iter_num, disc_loss, gen_loss[0], gen_loss[1], gen_loss[2],
                disc_avg, gen1_avg, gen2_avg, gen3_avg,
                val_loss[0], val_loss[1], val_loss[2],
            ])

            # Every 5 epochs: checkpoint weights and dump the loss history
            if (e + 1) % 5 == 0:
                for prefix, net in (("gen", generator_model),
                                    ("disc", discriminator_model),
                                    ("DCGAN", DCGAN_model)):
                    net_weights_path = os.path.join('../../models/%s/%s_weights_iter%s_epoch%s.h5'
                                                    % (model_name, prefix, iter_num, e + 1))
                    net.save_weights(net_weights_path, overwrite=True)

                loss_array = np.asarray(losses_list)
                print(loss_array.shape)
                loss_path = os.path.join('../../losses/loss_iter%s_epoch%s.csv' % (iter_num, e + 1))
                np.savetxt(loss_path, loss_array, fmt='%.5f', delimiter=',')
                np.savetxt('test.csv', loss_array, fmt='%.5f', delimiter=',')

    except KeyboardInterrupt:
        pass
def train(**kwargs):
    """Train a pix2pix model, optionally resuming from a previous run.

    If ``../../data/processed/<dset>_data.h5`` is missing, the function tries
    to build it from ``../../data/<dset>/{train,val,test}``. If that is not
    possible, it prints an error and returns ``None``.

    When ``prev_model`` is given, the latest generator, discriminator and
    DCGAN weights are loaded from ``../../models/<prev_model>/``. Training
    then continues from the epoch after the one encoded in the checkpoint
    file name.

    Per epoch it records the mean discriminator loss and the mean generator
    total/L1/log losses. It optionally plots images and loss curves, and
    periodically saves weights with these losses encoded in the file names.

    args: **kwargs (dict) keyword arguments that specify the model
        hyperparameters; every key read below is required.

    Raises:
        ValueError: unknown ``discriminator_optimizer``.
        FileNotFoundError: ``prev_model`` has no matching checkpoint files.

    A KeyboardInterrupt stops training and is swallowed.
    """
    # Roll out the parameters
    patch_size = kwargs["patch_size"]
    image_data_format = kwargs["image_data_format"]
    generator_type = kwargs["generator_type"]
    dset = kwargs["dset"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    use_mbd = kwargs["use_mbd"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping_prob = kwargs["label_flipping_prob"]
    use_l1_weighted_loss = kwargs["use_l1_weighted_loss"]
    prev_model = kwargs["prev_model"]
    discriminator_optimizer = kwargs["discriminator_optimizer"]
    n_run_of_gen_for_1_run_of_disc = kwargs["n_run_of_gen_for_1_run_of_disc"]
    MAX_FRAMES_PER_GIF = kwargs["MAX_FRAMES_PER_GIF"]

    # Check and make the dataset: if the .h5 file of dset is not present, try making it
    if not os.path.exists("../../data/processed/%s_data.h5" % dset):
        print("dset %s_data.h5 not present in '../../data/processed'!" % dset)
        if not os.path.exists("../../data/%s/" % dset):
            print(
                "dset folder %s not present in '../../data'!\n\nERROR: Dataset .h5 file not made, and dataset not available in '../../data/'.\n\nQuitting."
                % dset)
            return
        if not all(os.path.exists("../../data/%s/%s" % (dset, split)) for split in ("train", "val", "test")):
            print(
                "'train', 'val' or 'test' folders not present in dset folder '../../data/%s'!\n\nERROR: Dataset must contain 'train', 'val' and 'test' folders.\n\nQuitting."
                % dset)
            return
        print("Making %s dataset" % dset)
        subprocess.call(['python3', '../data/make_dataset.py', '../../data/%s' % dset, '3'])
        print("Done!")

    init_epoch = 0

    if prev_model:
        print('\n\nLoading prev_model from', prev_model, '...\n\n')

        def latest_weights(prefix):
            # Match on the exact file prefix: a loose '*disc*' pattern also hits
            # 'gen_weights_..._discLoss...' files (the loss values are encoded in
            # every file name) and would load generator weights into the
            # discriminator. Names are zero-padded by epoch, so lexicographic
            # order is epoch order.
            pattern = os.path.join('../../models/', prev_model, '%s_*.h5' % prefix)
            candidates = sorted(glob.glob(pattern))
            if not candidates:
                raise FileNotFoundError("No checkpoint matching %r" % pattern)
            return candidates[-1]

        prev_model_latest_gen = latest_weights('gen_weights')
        prev_model_latest_disc = latest_weights('disc_weights')
        prev_model_latest_DCGAN = latest_weights('DCGAN_weights')

        # Find prev model name (its folder) and the epoch to resume from
        model_name = os.path.basename(os.path.dirname(prev_model_latest_DCGAN))
        init_epoch = int(os.path.basename(prev_model_latest_DCGAN).split('epoch')[1][:5]) + 1

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    img_dim = (256, 256, 3)

    # Number of non-overlapping patches and the discriminator's input size
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        if discriminator_optimizer == 'sgd':
            opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        elif discriminator_optimizer == 'adam':
            opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        else:
            raise ValueError("discriminator_optimizer must be 'sgd' or 'adam', got %r" % (discriminator_optimizer,))

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator_type, img_dim, nb_patch,
                                      use_mbd, batch_size, model_name)
        generator_model.compile(loss='mae', optimizer=opt_discriminator)

        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc, nb_patch,
                                          use_mbd, batch_size, model_name)

        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim, patch_size, image_data_format)

        if use_l1_weighted_loss:
            loss = [l1_weighted_loss, 'binary_crossentropy']
        else:
            loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        # Load prev_model
        if prev_model:
            generator_model.load_weights(prev_model_latest_gen)
            discriminator_model.load_weights(prev_model_latest_disc)
            DCGAN_model.load_weights(prev_model_latest_DCGAN)

        # Load and rescale data
        print('\n\nLoading data...\n\n')
        X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
        check_this_process_memory()
        print('X_full_train: %.4f' % (X_full_train.nbytes / 2**30), "GB")
        print('X_sketch_train: %.4f' % (X_sketch_train.nbytes / 2**30), "GB")
        print('X_full_val: %.4f' % (X_full_val.nbytes / 2**30), "GB")
        print('X_sketch_val: %.4f' % (X_sketch_val.nbytes / 2**30), "GB")

        # Per-epoch mean losses
        disc_losses = []
        gen_total_losses = []
        gen_L1_losses = []
        gen_log_losses = []

        # At least one generator step per discriminator step (0 used to divide by zero)
        n_gen_runs = max(1, n_run_of_gen_for_1_run_of_disc)

        # Start training
        print("\n\nStarting training\n\n")
        for e in range(nb_epoch):
            batch_counter = 0
            disc_loss_epoch = 0
            gen_total_loss_epoch = 0
            gen_L1_loss_epoch = 0
            gen_log_loss_epoch = 0
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch, X_sketch_batch, generator_model,
                                                           batch_counter, patch_size, image_data_format,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping_prob=label_flipping_prob)
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                disc_loss_epoch += disc_loss

                # Train the generator n_gen_runs times on fresh batches, discriminator frozen
                discriminator_model.trainable = False
                for _ in range(n_gen_runs):
                    X_gen_target, X_gen = next(data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                    y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                    y_gen[:, 1] = 1
                    gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                    # Each generator run contributes 1/n_gen_runs of this batch's loss
                    gen_total_loss_epoch += gen_loss[0] / n_gen_runs
                    gen_L1_loss_epoch += gen_loss[1] / n_gen_runs
                    gen_log_loss_epoch += gen_loss[2] / n_gen_runs
                discriminator_model.trainable = True

                print("Epoch", str(init_epoch + e + 1), "batch", str(batch_counter + 1),
                      "D_logloss", disc_loss, "G_tot", gen_loss[0], "G_L1", gen_loss[1], "G_log", gen_loss[2])

                batch_counter += 1
                if batch_counter >= n_batch_per_epoch:
                    break

            # Average over the batches actually seen (the generator may run out early)
            n_batches = max(1, batch_counter)
            disc_losses.append(disc_loss_epoch / n_batches)
            gen_total_losses.append(gen_total_loss_epoch / n_batches)
            gen_L1_losses.append(gen_L1_loss_epoch / n_batches)
            gen_log_losses.append(gen_log_loss_epoch / n_batches)

            check_this_process_memory()
            print('Epoch %s/%s, Time: %.4f' % (init_epoch + e + 1, init_epoch + nb_epoch, time.time() - start))

            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model, batch_size,
                                                image_data_format, model_name, "training",
                                                init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Get new images from validation
                X_full_batch, X_sketch_batch = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model, batch_size,
                                                image_data_format, model_name, "validation",
                                                init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Plot losses
                data_utils.plot_losses(disc_losses, gen_total_losses, gen_L1_losses, gen_log_losses,
                                       model_name, init_epoch)

            # Save weights; the epoch index and latest losses are encoded in the name
            if (e + 1) % save_weights_every_n_epochs == 0:
                suffix = 'epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (
                    init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1])
                for prefix, net in (("gen", generator_model),
                                    ("disc", discriminator_model),
                                    ("DCGAN", DCGAN_model)):
                    weights_path = os.path.join('../../models/%s/%s_weights_%s' % (model_name, prefix, suffix))
                    net.save_weights(weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
def train(**kwargs):
    """Train an InfoGAN-style DCGAN with categorical and continuous latent codes.

    Data comes from CelebA or MNIST (loaded in memory), or from the directory
    named by ``dset`` via ``data_utils.data_generator_from_dir``.

    Each epoch runs ``n_batch_per_epoch`` discriminator/generator updates and
    records the epoch-mean of the total, log, categorical and continuous
    losses for both networks. It periodically plots samples and loss curves
    and saves weights, keeping only the last ``save_only_last_n_weights``
    sets. The full generator model is always saved at the end, including
    after a KeyboardInterrupt.

    args: **kwargs (dict) keyword arguments that specify the model
        hyperparameters; every key read below is required.
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    celebA_img_dim = kwargs["celebA_img_dim"]
    cont_dim = (kwargs["cont_dim"], )
    cat_dim = (kwargs["cat_dim"], )
    noise_dim = (kwargs["noise_dim"], )
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    load_from_dir = kwargs["load_from_dir"]
    target_size = kwargs["target_size"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and rescale data. Only the directory case creates a batch generator here.
    X_batch_gen = None
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(celebA_img_dim, image_data_format)
    elif dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    else:
        X_batch_gen = data_utils.data_generator_from_dir(dset, target_size, batch_size)
        X_real_train = next(X_batch_gen)  # one batch, only used to infer img_dim
    img_dim = X_real_train.shape[-3:]

    # Defined up front so the final save can tell whether the model was ever built
    generator_model = None

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator and discriminator models
        generator_model = models.load("generator_%s" % generator, cat_dim, cont_dim, noise_dim, img_dim,
                                      batch_size, dset=dset, use_mbd=use_mbd)
        discriminator_model = models.load("DCGAN_discriminator", cat_dim, cont_dim, noise_dim, img_dim,
                                          batch_size, dset=dset, use_mbd=use_mbd)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model, cat_dim, cont_dim, noise_dim)

        list_losses = ['binary_crossentropy', 'categorical_crossentropy', gaussian_loss]
        list_weights = [1, 1, 1]
        DCGAN_model.compile(loss=list_losses, loss_weights=list_weights, optimizer=opt_dcgan)

        # Multiple discriminator losses
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses, loss_weights=list_weights, optimizer=opt_discriminator)

        # In-memory datasets never create a directory generator above, so they
        # must always be batched from X_real_train (this used to be a NameError
        # when load_from_dir was set for celebA/mnist).
        if not load_from_dir or X_batch_gen is None:
            X_batch_gen = data_utils.gen_batch(X_real_train, batch_size)

        # Loss order returned by train_on_batch: total, log, categorical, continuous
        disc_names = ("D tot", "D log", "D cat", "D cont")
        gen_names = ("G tot", "G log", "G cat", "G cont")
        disc_history = [[] for _ in disc_names]
        gen_history = [[] for _ in gen_names]

        # Start training
        print("Start training")
        start = time.time()
        for e in range(nb_epoch):
            print('--------------------------------------------')
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d}\n'.format(datetime.datetime.now(), e + 1, nb_epoch))

            progbar = generic_utils.Progbar(epoch_size)
            disc_sums = [0] * len(disc_names)
            gen_sums = [0] * len(gen_names)

            for batch_counter in range(n_batch_per_epoch):
                # Load data
                X_real_batch = next(X_batch_gen)

                # Create a batch to feed the discriminator model
                X_disc, y_disc, y_cat, y_cont = data_utils.get_disc_batch(
                    X_real_batch, generator_model, batch_counter, batch_size, cat_dim, cont_dim, noise_dim,
                    noise_scale=noise_scale, label_smoothing=label_smoothing, label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, [y_disc, y_cat, y_cont])

                # Create a batch to feed the generator model
                X_gen, y_gen, y_cat, y_cont, y_cont_target = data_utils.get_gen_batch(
                    batch_size, cat_dim, cont_dim, noise_dim, noise_scale=noise_scale)

                # Train the generator through DCGAN with the discriminator frozen
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch([y_cat, y_cont, X_gen], [y_gen, y_cat, y_cont_target])
                discriminator_model.trainable = True

                progbar.add(batch_size, values=list(zip(disc_names, disc_loss)) + list(zip(gen_names, gen_loss)))

                for i in range(len(disc_names)):
                    disc_sums[i] += disc_loss[i]
                    gen_sums[i] += gen_loss[i]

            # Epoch means
            for history, total in zip(disc_history, disc_sums):
                history.append(total / n_batch_per_epoch)
            for history, total in zip(gen_history, gen_sums):
                history.append(total / n_batch_per_epoch)

            # Save images and loss curves for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model, e, batch_size, cat_dim, cont_dim,
                                                noise_dim, image_data_format, model_name)
                data_utils.plot_losses(*disc_history, *gen_history, model_name)

            if (e + 1) % save_weights_every_n_epochs == 0:
                print("Saving weights...")
                # Delete all but the last n weights
                general_utils.purge_weights(save_only_last_n_weights, model_name)
                # Save weights
                for prefix, net in (("gen", generator_model),
                                    ("disc", discriminator_model),
                                    ("DCGAN", DCGAN_model)):
                    weights_path = os.path.join('../../models/%s/%s_weights_epoch%05d.h5' % (model_name, prefix, e))
                    net.save_weights(weights_path, overwrite=True)

            end = time.time()
            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, end - start))
            start = end

    except KeyboardInterrupt:
        pass

    # Always persist the latest generator, even after an interrupt
    if generator_model is not None:
        gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
        print("Saving", gen_weights_path)
        generator_model.save(gen_weights_path, overwrite=True)
def train(**kwargs):
    """Train an InfoGAN-style DCGAN on CelebA or MNIST held in memory.

    Alternates discriminator and generator updates for up to
    ``n_batch_per_epoch`` batches per epoch. It plots samples twice per
    epoch and saves generator, discriminator and DCGAN weights every 5
    epochs (starting at epoch 0).

    args: **kwargs (dict) keyword arguments that specify the model
        hyperparameters; every key read below is required.

    Raises:
        ValueError: ``dset`` is neither ``"celebA"`` nor ``"mnist"``.

    A KeyboardInterrupt stops training and is swallowed.
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    cont_dim = (kwargs["cont_dim"], )
    cat_dim = (kwargs["cat_dim"], )
    noise_dim = (kwargs["noise_dim"], )
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data (previously an unknown dset fell through and
    # crashed later with a NameError on X_real_train)
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(img_dim, image_data_format)
    elif dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    else:
        raise ValueError("Unsupported dset %r: expected 'celebA' or 'mnist'" % (dset,))
    img_dim = X_real_train.shape[-3:]

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator and discriminator models
        generator_model = models.load("generator_%s" % generator, cat_dim, cont_dim, noise_dim, img_dim,
                                      bn_mode, batch_size, dset=dset, use_mbd=use_mbd)
        discriminator_model = models.load("DCGAN_discriminator", cat_dim, cont_dim, noise_dim, img_dim,
                                          bn_mode, batch_size, dset=dset, use_mbd=use_mbd)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model, cat_dim, cont_dim, noise_dim)

        list_losses = ['binary_crossentropy', 'categorical_crossentropy', gaussian_loss]
        list_weights = [1, 1, 1]
        DCGAN_model.compile(loss=list_losses, loss_weights=list_weights, optimizer=opt_dcgan)

        # Multiple discriminator losses
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses, loss_weights=list_weights, optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):
                # Create a batch to feed the discriminator model
                X_disc, y_disc, y_cat, y_cont = data_utils.get_disc_batch(
                    X_real_batch, generator_model, batch_counter, batch_size, cat_dim, cont_dim, noise_dim,
                    noise_scale=noise_scale, label_smoothing=label_smoothing, label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, [y_disc, y_cat, y_cont])

                # Create a batch to feed the generator model
                X_gen, y_gen, y_cat, y_cont, y_cont_target = data_utils.get_gen_batch(
                    batch_size, cat_dim, cont_dim, noise_dim, noise_scale=noise_scale)

                # Train the generator through DCGAN with the discriminator frozen
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch([y_cat, y_cont, X_gen], [y_gen, y_cat, y_cont_target])
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D tot", disc_loss[0]),
                                                ("D log", disc_loss[1]),
                                                ("D cat", disc_loss[2]),
                                                ("D cont", disc_loss[3]),
                                                ("G tot", gen_loss[0]),
                                                ("G log", gen_loss[1]),
                                                ("G cat", gen_loss[2]),
                                                ("G cont", gen_loss[3])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    data_utils.plot_generated_batch(X_real_batch, generator_model, batch_size, cat_dim,
                                                    cont_dim, noise_dim, image_data_format)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Checkpoint all three models every 5 epochs (starting at epoch 0)
            if e % 5 == 0:
                for prefix, net in (("gen", generator_model),
                                    ("disc", discriminator_model),
                                    ("DCGAN", DCGAN_model)):
                    weights_path = os.path.join('../../models/%s/%s_weights_epoch%s.h5' % (model_name, prefix, e))
                    net.save_weights(weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
def main(args):
    """Prepare data, the attacked model and the white-/black-box attack components.

    Selects an input normalisation and the flattened input size from the
    dataset flags (setting ``args.input_size``), loads the data loaders and
    the unknown model under attack, optionally trains an MNIST VAE and/or AE,
    optionally runs the white-box generator attack, and finally builds the
    black-box Gaussian policy and its control-variate network.

    args:
        args: parsed command-line namespace (mutated: ``input_size`` is set).
    """
    if args.mnist:
        # MNIST images are used without normalisation
        normalize = None
        args.input_size = 784
    elif args.cifar:
        normalize = utils.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        args.input_size = 32 * 32 * 3
    else:
        # Normalize image for ImageNet
        normalize = utils.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
        args.input_size = 150528
    # NOTE(review): `normalize` is not used further inside this function.

    # Load data
    train_loader, test_loader = utils.get_data(args)

    # The unknown model to attack
    unk_model = utils.load_unk_model(args)

    # Try Whitebox Untargeted first
    if args.debug:
        ipdb.set_trace()

    # Optional auxiliary autoencoders. Everything is initialised once so that
    # training only the VAE no longer has its encoder/decoder reset to None
    # by the AE branch; if both are trained the AE's encoder/decoder win, as before.
    encoder, decoder, vae, ae = None, None, None, None
    if args.train_vae:
        encoder, decoder, vae = train_mnist_vae(args)
    if args.train_ae:
        encoder, decoder, ae = train_mnist_ae(args)

    # Add A Flow
    norm_flow = None
    if args.use_flow:
        norm_flow = flows.Planar

    # Test white box
    if args.white:
        # Choose Attack Function
        if args.no_pgd_optim:
            white_attack_func = attacks.L2_white_box_generator
        else:
            white_attack_func = attacks.PGD_white_box_generator

        # Choose Dataset; nc, h, w give the (channels, height, width) image shape
        if args.mnist:
            G = models.Generator(input_size=784).to(args.device)
            # MNIST: single-channel 28x28 images (784 = 1 * 28 * 28)
            nc, h, w = 1, 28, 28
        elif args.cifar:
            if args.vanilla_G:
                G = models.DCGAN().to(args.device)
                G = nn.DataParallel(G.generator)
            else:
                G = models.ConvGenerator(
                    models.Bottleneck, [6, 12, 24, 16], growth_rate=12,
                    flows=norm_flow, use_flow=args.use_flow,
                    deterministic=args.deterministic_G).to(args.device)
                G = nn.DataParallel(G)
            nc, h, w = 3, 32, 32
        # NOTE(review): no generator is built for the ImageNet setting, so the
        # white-box call below cannot run there.

        if args.run_baseline:
            attacks.whitebox_pgd(args, unk_model)

        pred, delta = white_attack_func(args, train_loader, test_loader,
                                        unk_model, G, nc, h, w)

    # Blackbox Attack model
    model = models.GaussianPolicy(args.input_size, 400, args.latent_size,
                                  decode=False).to(args.device)

    # Control Variate
    cv = to_cuda(models.FC(args.input_size, args.classes))
def train(**kwargs):
    """
    Train model

    Load the whole train data in memory for faster operations. The job
    configuration, sample images (training and validation) and the final
    models are written to a timestamped directory derived from
    ``kwargs["save_dir"]``. A KeyboardInterrupt stops training early; the
    models built so far are still saved.

    args:
        **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    # kwargs["img_dim"] is not read: the image shape is taken from the loaded data.

    # Right strip '/' to avoid an empty '/' dir, then join with the current
    # datetime. The trailing '/' in the timestamp makes save_dir a directory path.
    # NOTE(review): "%I:%M" puts ':' in the directory name, which is not a
    # valid path character on Windows.
    save_dir = kwargs["save_dir"].rstrip('/')
    save_dir = '_'.join(
        [save_dir, datetime.datetime.now().strftime("%I:%M%p-%B%d-%Y/")])
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # save the config in save dir
    with open(os.path.join(save_dir, 'job_config.json'), 'w') as fp:
        json.dump(kwargs, fp, sort_keys=True, indent=4)

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
        dset, image_data_format)
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    # Bound before the try so the final save can tell which models exist when
    # training is interrupted while they are still being built.
    generator_model = discriminator_model = DCGAN_model = None

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim, nb_patch, bn_mode, use_mbd,
                                      batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, bn_mode, use_mbd,
                                          batch_size)

        # NOTE(review): the generator is compiled with the discriminator's
        # optimizer instance; it is never trained directly in this function.
        generator_model.compile(loss='mae', optimizer=opt_discriminator)

        # Freeze the discriminator while the combined model is built/compiled
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim,
                                   patch_size, image_data_format)

        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch, X_sketch_batch, generator_model,
                    batch_counter, patch_size, image_data_format,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                # Adversarial targets for the generator step: column 1 set to 1
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_data_format,
                        "{:03}_EPOCH_TRAIN".format(e + 1), save_dir)
                    # Get new images from validation
                    X_full_batch, X_sketch_batch = next(
                        data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model,
                        batch_size, image_data_format,
                        "{:03}_EPOCH_VALID".format(e + 1), save_dir)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
    except KeyboardInterrupt:
        # Ctrl-C ends training early; whatever models exist are saved below.
        pass

    # save models (skipping any that were never constructed)
    for net, file_name in ((DCGAN_model, 'DCGAN.h5'),
                           (generator_model, 'GENERATOR.h5'),
                           (discriminator_model, 'DISCRIMINATOR.h5')):
        if net is not None:
            net.save(os.path.join(save_dir, file_name))
import os import models import utils if __name__ == "__main__": opt = utils.get_options() model = models.DCGAN(opt) dataloader = utils.create_dataloader(opt) os.makedirs("./results", exist_ok=True) G_losses, D_losses = [], [] print("Start Training...") for epoch in range(1, opt.num_epochs + 1): for iters, (data, _) in enumerate(dataloader): model.set_input(data) model.optimize_parameters() loss_G, loss_D = model.get_losses() G_losses.append(loss_G) D_losses.append(loss_D) if iters % 50 == 0: print("Epoch: %d/%d\tIter: %d/%d\tLoss_G: %.4f\tLoss_D: %.4f" % (epoch, opt.num_epochs, iters, len(dataloader), loss_G, loss_D)) if (iters % 500 == 0) or ((epoch == opt.num_epochs) and
def train(cat_dim, noise_dim, batch_size, n_batch_per_epoch, nb_epoch,
          dset="mnist"):
    """
    Train model

    Load the whole train data in memory for faster operations. Each batch
    updates the discriminator on a fake and a real batch, then updates the
    combined model with the discriminator frozen. Test accuracy of the
    discriminator's class output is printed after every epoch. A
    KeyboardInterrupt stops training quietly.

    args:
        cat_dim: categorical-code size, forwarded to ``models.load`` and the
            ``data_utils`` batch helpers.
        noise_dim: noise size, forwarded the same way.
        batch_size: number of samples per training batch.
        n_batch_per_epoch: batches processed before an epoch ends.
        nb_epoch: the epoch loop runs for e = 0 .. nb_epoch inclusive.
        dset (str): dataset name; only "mnist" is supported.

    raises:
        ValueError: if ``dset`` is not a supported dataset.
    """
    general_utils.setup_logging("IG")

    # Load and rescale data
    if dset == "mnist":
        print("loading mnist data")
        X_real_train, Y_real_train, X_real_test, Y_real_test = data_utils.load_mnist()
    else:
        # Previously this fell through to a NameError on X_real_train.
        raise ValueError("Unsupported dataset: %s" % dset)

    img_dim = X_real_train.shape[-3:]

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=2E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_deconv", cat_dim, noise_dim,
                                      img_dim, batch_size, dset=dset)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", cat_dim,
                                          noise_dim, img_dim, batch_size,
                                          dset=dset)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)

        # stop the discriminator from learning while the generator is learning
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   cat_dim, noise_dim)

        list_losses = ['binary_crossentropy', 'categorical_crossentropy']
        list_weights = [1, 1]
        DCGAN_model.compile(loss=list_losses, loss_weights=list_weights,
                            optimizer=opt_dcgan)

        # Multiple discriminator losses; allow the discriminator to learn again
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses, loss_weights=list_weights,
                                    optimizer=opt_discriminator)

        # Start training
        print("Start training")
        # NOTE(review): runs nb_epoch + 1 epochs (e = 0..nb_epoch), so the last
        # progress line reads "Epoch <nb_epoch+1>/<nb_epoch>".
        for e in range(nb_epoch + 1):
            batch_counter = 1
            start = time.time()
            print("Epoch: {}".format(e))

            # NOTE(review): assumes the two gen_batch streams yield aligned
            # X/Y batches (e.g. same order) — confirm in data_utils.
            for X_real_batch, Y_real_batch in zip(
                    data_utils.gen_batch(X_real_train, batch_size),
                    data_utils.gen_batch(Y_real_train, batch_size)):

                # Create batches to feed the discriminator model
                X_disc_fake, y_disc_fake, noise_sample = data_utils.get_disc_batch(
                    X_real_batch, Y_real_batch, generator_model, batch_size,
                    cat_dim, noise_dim, type="fake")
                X_disc_real, y_disc_real = data_utils.get_disc_batch(
                    X_real_batch, Y_real_batch, generator_model, batch_size,
                    cat_dim, noise_dim, type="real")

                # Update the discriminator
                disc_loss_fake = discriminator_model.train_on_batch(
                    X_disc_fake, [y_disc_fake, Y_real_batch])
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, [y_disc_real, Y_real_batch])
                # train_on_batch returns a list of scalars for this two-output
                # model; combine element-wise (list "+" would concatenate them).
                disc_loss = [fake + real for fake, real
                             in zip(disc_loss_fake, disc_loss_real)]

                # Freeze the discriminator and train the generator through it
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(
                    [Y_real_batch, noise_sample], [y_disc_real, Y_real_batch])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                # training validation (disc_loss / gen_loss / acc_train are
                # kept for logging; nothing below reads them)
                p_real_batch, p_Y_batch = discriminator_model.predict(
                    X_real_batch, batch_size=batch_size)
                acc_train = data_utils.accuracy(p_Y_batch, Y_real_batch)

                batch_counter += 1

                # Save images for visualization
                if batch_counter % (n_batch_per_epoch / 2) == 0 and e % 10 == 0:
                    data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                    batch_size, cat_dim,
                                                    noise_dim, e)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Test accuracy of the discriminator's class output (whole test set
            # in a single batch)
            _, p_Y_test = discriminator_model.predict(
                X_real_test, batch_size=X_real_test.shape[0])
            acc_test = data_utils.accuracy(p_Y_test, Y_real_test)
            print("Epoch: {} Accuracy: {}".format(e + 1, acc_test))

            if e % 1000 == 0:
                gen_weights_path = os.path.join('../../models/IG/gen_weights.h5')
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/IG/disc_weights.h5')
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/IG/DCGAN_weights.h5')
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        # Ctrl-C ends training early without a traceback.
        pass
def train(**kwargs):
    """
    Train the pix2pix-style DCGAN, optionally flagging pretrained weights.

    The whole training set is kept in memory. Each batch updates the
    discriminator once, then the combined model with the discriminator frozen.
    On every fifth epoch (starting at epoch 0) the generator, discriminator and
    combined-model weights are written under ``../../models/<model_name>/``.
    A KeyboardInterrupt ends training quietly.

    args:
        **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Unpack the hyperparameters; keys are looked up in this exact order, so a
    # missing key raises KeyError for the first absent one. "img_dim" is
    # required but its value is superseded by the shape of the loaded data.
    (batch_size, n_batch_per_epoch, nb_epoch, model_name, generator,
     image_data_format, _, patch_size, bn_mode, label_smoothing,
     label_flipping, dset, use_mbd, pretrained_model_path) = (
        kwargs[key] for key in (
            "batch_size", "n_batch_per_epoch", "nb_epoch", "model_name",
            "generator", "image_data_format", "img_dim", "patch_size",
            "bn_mode", "use_label_smoothing", "label_flipping", "dset",
            "use_mbd", "pretrained_model_path"))

    epoch_size = n_batch_per_epoch * batch_size

    # Logging directory etc.
    general_utils.setup_logging(model_name)

    # Load and rescale data; the discriminator input size follows the patch grid
    full_train, sketch_train, full_val, sketch_val = data_utils.load_data(
        dset, image_data_format)
    img_dim = full_train.shape[-3:]
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    try:
        gan_optimizer = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        disc_optimizer = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Any non-empty path switches the loaders into "pretrained" mode
        load_pretrained = bool(pretrained_model_path)

        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim, nb_patch, bn_mode, use_mbd,
                                      batch_size, load_pretrained)
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, bn_mode, use_mbd,
                                          batch_size, load_pretrained)

        # The generator is compiled with the discriminator's optimizer instance
        generator_model.compile(loss='mae', optimizer=disc_optimizer)

        # The combined model is built and compiled with the discriminator frozen
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim,
                                   patch_size, image_data_format)
        DCGAN_model.compile(loss=[l1_loss, 'binary_crossentropy'],
                            loss_weights=[1E1, 1],
                            optimizer=gan_optimizer)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=disc_optimizer)

        # (model, weights-path template) pairs written every fifth epoch
        checkpoints = (
            (generator_model, '../../models/%s/gen_weights_epoch%s.h5'),
            (discriminator_model, '../../models/%s/disc_weights_epoch%s.h5'),
            (DCGAN_model, '../../models/%s/DCGAN_weights_epoch%s.h5'),
        )

        print("Start training")
        for epoch in range(nb_epoch):
            progbar = generic_utils.Progbar(epoch_size)
            batch_idx = 1
            epoch_start = time.time()

            for full_batch, sketch_batch in data_utils.gen_batch(
                    full_train, sketch_train, batch_size):
                # --- discriminator step ---
                disc_inputs, disc_targets = data_utils.get_disc_batch(
                    full_batch, sketch_batch, generator_model, batch_idx,
                    patch_size, image_data_format,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)
                disc_loss = discriminator_model.train_on_batch(disc_inputs,
                                                               disc_targets)

                # --- generator step (through the frozen discriminator) ---
                gen_targets, gen_inputs = next(
                    data_utils.gen_batch(full_train, sketch_train, batch_size))
                gen_labels = np.zeros((gen_inputs.shape[0], 2), dtype=np.uint8)
                gen_labels[:, 1] = 1

                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(gen_inputs,
                                                      [gen_targets, gen_labels])
                discriminator_model.trainable = True

                batch_idx += 1
                progbar.add(batch_size, values=[
                    ("D logloss", disc_loss),
                    ("G tot", gen_loss[0]),
                    ("G L1", gen_loss[1]),
                    ("G logloss", gen_loss[2]),
                ])

                # Sample images from the current training batch and a fresh
                # validation batch
                if batch_idx % (n_batch_per_epoch / 2) == 0:
                    data_utils.plot_generated_batch(full_batch, sketch_batch,
                                                    generator_model, batch_size,
                                                    image_data_format, "training")
                    val_full, val_sketch = next(
                        data_utils.gen_batch(full_val, sketch_val, batch_size))
                    data_utils.plot_generated_batch(val_full, val_sketch,
                                                    generator_model, batch_size,
                                                    image_data_format, "validation")

                if batch_idx >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (epoch + 1, nb_epoch,
                                             time.time() - epoch_start))

            if epoch % 5 == 0:
                for net, template in checkpoints:
                    net.save_weights(os.path.join(template % (model_name, epoch)),
                                     overwrite=True)
    except KeyboardInterrupt:
        pass
# Create optimizers G_opt = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08) D_opt = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08) # Load generator model generator_model = models.load("generator", img_dim=(256, 256, 3)) # generator_model.load_weights('./models/pix2pix/gen_weights_epoch_6.h5') generator_model.compile(loss='mae', optimizer=G_opt) # Load discriminator model discriminator_model = models.load("discriminator", img_dim=(256, 256, 3)) # discriminator_model.load_weights('./models/pix2pix/disc_weights_epoch_6.h5') discriminator_model.trainable = False DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim=(256, 256, 3)) # DCGAN_model.load_weights('./models/pix2pix/DCGAN_weights_epoch_6.h5') loss = [l1_loss, 'binary_crossentropy'] # loss = [perceptual_loss, 'binary_crossentropy'] loss_weights = [1E1, 1] DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=G_opt) discriminator_model.trainable = True discriminator_model.compile(loss='binary_crossentropy', optimizer=D_opt) # Start training print("Start training") for e in range(1, nb_epoch + 1): # Initialize progbar and batch counter