def train_toy(**kwargs):
    """Train a toy MLP WGAN (Wasserstein GAN with weight clipping).

    args: **kwargs (dict) keyword arguments that specify the model
        hyperparameters (batch_size, n_batch_per_epoch, nb_epoch, noise_dim,
        noise_scale, lr_D, lr_G, opt_D, opt_G, clamp_lower, clamp_upper,
        disc_iterations)
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        # FIX: was the Python 2-only statement `print key, kwargs[key]`
        # (a SyntaxError under Python 3). This form prints the same
        # "key value" text on both Python 2 and Python 3.
        print("%s %s" % (key, kwargs[key]))
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("toy_MLP")

    # Load and rescale data
    X_real_train = data_utils.load_toy()

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    generator_model = models.generator_toy(noise_dim)
    discriminator_model = models.discriminator_toy()
    GAN_model = models.GAN_toy(generator_model, discriminator_model, noise_dim)

    ############################
    # Compile models
    ############################
    # The generator is compiled alone only so it can run inference; its real
    # training signal comes through GAN_model with the discriminator frozen.
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    GAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    #################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:
            disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights (WGAN Lipschitz constraint)
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                    l.set_weights(weights)

                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    noise_scale=noise_scale)

                # Update the discriminator: real samples labelled -1,
                # generated samples labelled +1 (Wasserstein convention)
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator while the generator is updated
            discriminator_model.trainable = False
            gen_loss = GAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            batch_counter += 1
            progbar.add(batch_size, values=[("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)),
                                            ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                            ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                            ("Loss_G", -gen_loss)])

            # Save images for visualization (every 50 generator updates)
            if gen_iterations % 50 == 0:
                data_utils.plot_generated_toy_batch(X_real_train, generator_model,
                                                    discriminator_model, noise_dim, gen_iterations)
            gen_iterations += 1

        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
def train(**kwargs):
    """Replay saved generator checkpoints and dump their generated images.

    Despite its name, this variant performs no training: for each checkpoint
    epoch in range(200, 355, 5) it loads ../../CNN/gen_weights_epoch<e>.h5,
    samples 32 noise vectors, runs the generator, and writes each generated
    sample to GeneratedImages1234/Epoch_<e>/retina_<i>.png.

    args: **kwargs (dict) keyword arguments that specify the model
        hyperparameters; several (nb_epoch, label_smoothing, label_flipping,
        noise_scale, ...) are accepted only for signature compatibility with
        the other train() variants and are unused here.
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Dataset loading is disabled in this variant: the generator input shape
    # is hard-coded instead of being derived from the data.
    img_dim = (3, 64, 64)
    noise_dim = (100, )

    try:
        # Create optimizers (only opt_discriminator is used below, and only
        # because Keras requires an optimizer to compile the generator).
        opt_dcgan = Adam(lr=1E-3, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      noise_dim,
                                      img_dim,
                                      bn_mode,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)

        # Load discriminator model (built for parity with the training
        # variants; not used in this replay path)
        discriminator_model = models.load("DCGAN_discriminator",
                                          noise_dim,
                                          img_dim,
                                          bn_mode,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)

        # Replay the saved generator checkpoints
        for e in range(200, 355, 5):
            gen_weights_path = os.path.join(
                '../../CNN/gen_weights_epoch%s.h5' % (e))
            file_path_to_save_img = 'GeneratedImages1234/Epoch_%s/' % (e)
            # FIX: os.makedirs instead of os.mkdir so this does not crash
            # when the parent directory 'GeneratedImages1234' is missing.
            os.makedirs(file_path_to_save_img)

            # Generate images from this checkpoint
            generator_model.load_weights(gen_weights_path)
            generator_model.compile(loss='mse', optimizer=opt_discriminator)
            noise_z = np.random.normal(scale=0.5, size=(32, noise_dim[0]))
            X_generated = generator_model.predict(noise_z)
            X_gen = inverse_normalization(X_generated)

            # Save each generated sample as a border-less full-bleed PNG
            for img in range(X_gen.shape[0]):
                ret = X_gen[img].transpose(1, 2, 0)  # CHW -> HWC for imshow
                fig = plt.figure(frameon=False)
                fig.set_size_inches(64, 64)
                ax = plt.Axes(fig, [0., 0., 1., 1.])
                ax.set_axis_off()
                fig.add_axes(ax)
                # FIX: aspect='normal' was removed from matplotlib (raises
                # ValueError on >=3.0); 'auto' is the documented equivalent.
                ax.imshow(ret, aspect='auto')
                fig.savefig(file_path_to_save_img + 'retina_%s.png' % (img),
                            dpi=1)
                plt.clf()
                plt.close()

        # NOTE(review): a large block of commented-out training/plotting
        # scaffolding (duplicated by the next train() variant in this file)
        # was removed here; recover it from version control if ever needed.
    except KeyboardInterrupt:
        # Allow a clean manual stop mid-replay
        pass
def train(**kwargs):
    """Train a standard DCGAN (generator + discriminator) on celebA or MNIST.

    The full training set is held in memory for speed. Every 100 batches a
    visualization of generated samples is written out, and every 5 epochs the
    generator / discriminator / stacked-model weights are checkpointed under
    ../../models/<model_name>/.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Unpack hyperparameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Logging directory etc.
    general_utils.setup_logging(model_name)

    # Load the requested dataset; the actual image shape then replaces the
    # img_dim that was passed in.
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(img_dim, image_data_format)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    img_dim = X_real_train.shape[-3:]
    noise_dim = (100, )

    try:
        # Optimizers: Adam drives the stacked DCGAN (generator updates),
        # SGD drives the discriminator.
        opt_dcgan = Adam(lr=1E-3, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)

        generator_model = models.load("generator_%s" % generator,
                                      noise_dim, img_dim, bn_mode, batch_size,
                                      dset=dset, use_mbd=use_mbd)
        discriminator_model = models.load("DCGAN_discriminator",
                                          noise_dim, img_dim, bn_mode, batch_size,
                                          dset=dset, use_mbd=use_mbd)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)

        # Freeze D while compiling the stacked model so generator updates do
        # not touch discriminator weights; then re-enable and compile D alone.
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   noise_dim, img_dim)
        DCGAN_model.compile(loss=['binary_crossentropy'], loss_weights=[1],
                            optimizer=opt_dcgan)
        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        # Sentinel loss values shown before the first update
        gen_loss = 100
        disc_loss = 100

        print("Start training")
        for epoch in range(nb_epoch):
            progbar = generic_utils.Progbar(epoch_size)
            batch_idx = 1
            t_start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):
                # --- discriminator step ---
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_real_batch, generator_model, batch_idx, batch_size,
                    noise_dim, noise_scale=noise_scale,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # --- generator step (through the stacked model, D frozen) ---
                X_gen, y_gen = data_utils.get_gen_batch(
                    batch_size, noise_dim, noise_scale=noise_scale)
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
                discriminator_model.trainable = True

                batch_idx += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G logloss", gen_loss)])

                # Periodic visualization of generated samples
                if batch_idx % 100 == 0:
                    data_utils.plot_generated_batch(X_real_batch,
                                                    generator_model,
                                                    batch_size, noise_dim,
                                                    image_data_format)

                if batch_idx >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (epoch + 1, nb_epoch, time.time() - t_start))

            # Checkpoint every 5 epochs
            if epoch % 5 == 0:
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%s.h5' % (model_name, epoch))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%s.h5' % (model_name, epoch))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, epoch))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
    except KeyboardInterrupt:
        pass
def train(**kwargs):
    """Train a domain-adaptation GAN (source -> destination domain) plus an
    auxiliary classifier and a z-reconstruction ("zclass") model.

    Supports WGAN ('wgan') and LSGAN ('lsgan') objectives, a plain or
    per-class discriminator (disc_type 'simple_disc' / 'nclass_disc'), a
    history buffer of generated images for discriminator training, optional
    resume from saved weights, and periodic visualization / checkpointing.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    disc_type = kwargs["disc_type"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        # FIX: was the Python 2-only statement `print key, kwargs[key]`
        # (SyntaxError under Python 3); this prints identically on both.
        print("%s %s" % (key, kwargs[key]))
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data (source domain + destination domain)
    if dset == "mnistM":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnist')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnistM')
    elif dset == "OfficeDslrToAmazon":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeDslr')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeAmazon')
    else:
        # FIX: Python 2-only print statement made version-agnostic.
        # NOTE(review): execution still continues after this and will crash
        # later on undefined data arrays; consider raising ValueError here.
        print("dataset not supported")
    if n_classes1 != n_classes2:  # sanity check
        # FIX: Python 2-only print statement made version-agnostic.
        print("number of classes mismatch between source and dest domains")
    n_classes = n_classes1
    img_source_dim = X_source_train.shape[-3:]  # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    # Create optimizers
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_GC = data_utils.get_optimizer('Adam', lr_G / 10.0)
    opt_C = data_utils.get_optimizer('Adam', lr_D)
    opt_Z = data_utils.get_optimizer('Adam', lr_G)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    generator_model = models.generator_google_mnistM(noise_dim, img_source_dim,
                                                     img_dest_dim, deterministic,
                                                     pureGAN, wd)
    # discriminator_model = models.discriminator_google_mnistM(img_dest_dim, wd)
    discriminator_model = models.discriminator_dcgan(img_dest_dim, wd,
                                                     n_classes, disc_type)
    classificator_model = models.classificator_google_mnistM(
        img_dest_dim, n_classes, wd)
    DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model,
                                     noise_dim, img_source_dim)
    # FIX: zclass_model was constructed twice in a row with identical
    # arguments; the duplicate construction was removed.
    zclass_model = z_coerence(generator_model, img_source_dim, bn_mode, wd,
                              inject_noise, n_classes, noise_dim,
                              model_name="zClass")
    # GenToClassifier_model = models.GenToClassifierModel(generator_model, classificator_model, noise_dim, img_source_dim)
    # disc_penalty_model = models.disc_penalty(discriminator_model,noise_dim,img_source_dim,opt_D,model_name="disc_penalty_model")

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    models.make_trainable(discriminator_model, False)
    models.make_trainable(classificator_model, False)
    # models.make_trainable(disc_penalty_model, False)
    if model == 'wgan':
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
        models.make_trainable(discriminator_model, True)
        # models.make_trainable(disc_penalty_model, True)
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
    if model == 'lsgan':
        if disc_type == "simple_disc":
            DCGAN_model.compile(loss=['mse'], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['mse'], optimizer=opt_D)
        elif disc_type == "nclass_disc":
            DCGAN_model.compile(loss=['mse', 'categorical_crossentropy'],
                                loss_weights=[1.0, 0.1], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(
                loss=['mse', 'categorical_crossentropy'],
                loss_weights=[1.0, 0.1], optimizer=opt_D)
    # GenToClassifier_model.compile(loss='categorical_crossentropy', optimizer=opt_GC)
    models.make_trainable(classificator_model, True)
    classificator_model.compile(loss='categorical_crossentropy',
                                metrics=['accuracy'], optimizer=opt_C)
    zclass_model.compile(loss=['mse'], optimizer=opt_Z)
    visualize = True

    ########
    # MAKING TRAIN+TEST numpy array for global testing:
    ########
    Xtarget_dataset = np.concatenate([X_dest_train, X_dest_test], axis=0)
    Ytarget_dataset = np.concatenate([Y_dest_train, Y_dest_test], axis=0)

    if resume:
        # Loading previous saved model weights
        data_utils.load_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, name, classificator_model,
                                      zclass_model)
        # loss4, acc4 = classificator_model.evaluate(Xtarget_dataset, Ytarget_dataset,batch_size=1024, verbose=0)
        # print('\n Classifier Accuracy on full target domain: %.2f%%' % (100 * acc4))
    else:
        # Train the zclass regression model only when not resuming
        X_gen = data_utils.sample_noise(noise_scale, X_source_train.shape[0],
                                        noise_dim)
        zclass_loss = zclass_model.fit([X_gen, X_source_train], [X_gen],
                                       batch_size=256, epochs=10)

    gen_iterations = 0
    max_history_size = int(history_size * batch_size)
    img_buffer = ImageHistoryBuffer((0, ) + img_source_dim, max_history_size,
                                    batch_size, n_classes)

    #################
    # Start training
    ################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()
        while batch_counter < n_batch_per_epoch:
            if no_supertrain is None:
                # FIX: the warm-up assignment (disc_iterations = 100 for the
                # first 25 generator iterations) was immediately overwritten
                # by a following independent if/else; made it an if/elif/else
                # chain, matching the intent of the `else` branch below.
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                elif gen_iterations % 500 == 0:
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            # Rolling windows (length 10) used only for progress reporting
            list_disc_loss_real = deque(10 * [0], 10)
            list_disc_loss_gen = deque(10 * [0], 10)
            list_gen_loss = deque(10 * [0], 10)
            list_zclass_loss = deque(10 * [0], 10)
            list_classifier_loss = deque(10 * [0], 10)
            list_gp_loss = deque(10 * [0], 10)
            for disc_it in range(disc_iterations):
                X_dest_batch, Y_dest_batch, idx_dest_batch = next(
                    data_utils.gen_batch(X_dest_train, Y_dest_train, batch_size))
                X_source_batch, Y_source_batch, idx_source_batch = next(
                    data_utils.gen_batch(X_source_train, Y_source_train, batch_size))

                ##########
                # Create a batch to feed the discriminator model
                #########
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_dest_batch, generator_model, batch_counter, batch_size,
                    noise_dim, X_source_batch, noise_scale=noise_scale)

                # Build the target labels for the discriminator update:
                # WGAN uses -1/+1; LSGAN uses 1/0 (plus per-class targets for
                # the n-class discriminator head).
                if model == 'wgan':
                    current_labels_real = -np.ones(X_disc_real.shape[0])
                    current_labels_gen = np.ones(X_disc_gen.shape[0])
                elif model == 'lsgan':
                    if disc_type == "simple_disc":
                        current_labels_real = np.ones(X_disc_real.shape[0])
                        current_labels_gen = np.zeros(X_disc_gen.shape[0])
                    elif disc_type == "nclass_disc":
                        virtual_real_labels = np.zeros(
                            [X_disc_gen.shape[0], n_classes])
                        current_labels_real = [
                            np.ones(X_disc_real.shape[0]), virtual_real_labels
                        ]
                        current_labels_gen = [
                            np.zeros(X_disc_gen.shape[0]), Y_source_batch
                        ]

                ##############
                # Train the disc on gen-buffered samples and on current real samples
                ##############
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, current_labels_real)
                img_buffer.add_to_buffer(X_disc_gen, current_labels_gen,
                                         batch_size)
                bufferImages, bufferLabels = img_buffer.get_from_buffer(
                    batch_size)
                disc_loss_gen = discriminator_model.train_on_batch(
                    bufferImages, bufferLabels)
                # if not isinstance(disc_loss_real, collections.Iterable): disc_loss_real = [disc_loss_real]
                # if not isinstance(disc_loss_real, collections.Iterable): disc_loss_gen = [disc_loss_gen]
                if disc_type == "simple_disc":
                    list_disc_loss_real.appendleft(disc_loss_real)
                    list_disc_loss_gen.appendleft(disc_loss_gen)
                elif disc_type == "nclass_disc":
                    # train_on_batch returns [total, loss1, loss2]; keep total
                    list_disc_loss_real.appendleft(disc_loss_real[0])
                    list_disc_loss_gen.appendleft(disc_loss_gen[0])
                #############
                # Train the discriminator w.r.t gradient penalty (disabled)
                #############
                # gp_loss = disc_penalty_model.train_on_batch([X_disc_real,X_disc_gen],current_labels_real)  # dummy labels, not used in the loss function
                # list_gp_loss.appendleft(gp_loss)

            ################
            # CLASSIFIER TRAINING — outside the disc loop (train once even if disc_iter > 1)
            #################
            class_loss_gen = classificator_model.train_on_batch(
                X_disc_gen, Y_source_batch * 0.7)  # LABEL SMOOTHING
            list_classifier_loss.appendleft(class_loss_gen[1])

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            X_source_batch2, Y_source_batch2, idx_source_batch2 = next(
                data_utils.gen_batch(X_source_train, Y_source_train, batch_size))
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen, X_source_batch2],
                                                      -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                if disc_type == "simple_disc":
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        np.ones(X_gen.shape[0]))  # TRYING SAME BATCH OF DISC
                elif disc_type == "nclass_disc":
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        [np.ones(X_gen.shape[0]), Y_source_batch2])
                    gen_loss = gen_loss[0]
            list_gen_loss.appendleft(gen_loss)

            zclass_loss = zclass_model.train_on_batch([X_gen, X_source_batch2],
                                                      [X_gen])
            list_zclass_loss.appendleft(zclass_loss)

            ##############
            # Train the generator w.r.t the aux classifier (disabled):
            #############
            # GenToClassifier_model.train_on_batch([X_gen,X_source_batch2],Y_source_batch2)
            # I SHOULD TRY TO CLASSIFY EVEN ON DISCRIMINATOR, PUTTING ONE CLASS FOR REAL SAMPLES AND N CLASS FOR FAKE

            gen_iterations += 1
            batch_counter += 1
            progbar.add(batch_size,
                        values=[("Loss_D_real", np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", np.mean(list_gen_loss)),
                                ("Loss_Z", np.mean(list_zclass_loss)),
                                ("Loss_Classifier", np.mean(list_classifier_loss))])

            # plot images 1 times per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
                X_source_batch_plot, Y_source_batch_plot, idx_source_plot = next(
                    data_utils.gen_batch(X_source_test, Y_source_test,
                                         batch_size=32))
                data_utils.plot_generated_batch(X_dest_test, X_source_test,
                                                generator_model, noise_dim,
                                                image_dim_ordering,
                                                idx_source_plot, batch_size=32)
            if gen_iterations % (n_batch_per_epoch * 5) == 0:
                if visualize:
                    BIG_ASS_VISUALIZATION_slerp(X_source_train[1],
                                                generator_model, noise_dim)

        # if (e % 20) == 0:
        #     lr_decay([discriminator_model,DCGAN_model,classificator_model],decay_value=0.95)
        print("Dest labels:")
        print(Y_dest_test[idx_source_plot].argmax(1))
        print("Source labels:")
        print(Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch,
                                           time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model,
                                      DCGAN_model, e, name,
                                      classificator_model, zclass_model)

        # Testing accuracy of trained classifier on the full target domain
        loss4, acc4 = classificator_model.evaluate(Xtarget_dataset,
                                                   Ytarget_dataset,
                                                   batch_size=1024, verbose=0)
        print(
            '\n Classifier Accuracy and loss on full target domain: %.2f%% / %.5f%%'
            % ((100 * acc4), loss4))
def train(**kwargs):
    """Train an adversarial domain-adaptation DCGAN (source -> dest domain).

    Supports a 'wgan' or 'lsgan' objective, several discriminator variants
    (simple / resnet / monster-class) and an image-history buffer for the
    generated samples shown to the critic.

    args:
        **kwargs (dict): keyword arguments that specify the model
            hyperparameters; every key read below is required.
    """
    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    simple_disc = kwargs["simple_disc"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        # %-formatting keeps this line valid (and identical in output)
        # under both Python 2 and 3; the original used a Py2-only
        # `print key, kwargs[key]` statement.
        print("%s %s" % (key, kwargs[key]))
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data.  Labelled datasets also return one-hot
    # labels and a class count; unlabelled ones leave the counts at None.
    n_classes1 = n_classes2 = None
    if dset == "mnistM":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnist')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='mnistM')
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,
                                                       dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,
                                                     dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,
                                                       dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,
                                                     dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train, Y_source_train, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='washington12classes')
        X_dest_train, Y_dest_train, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train, Y_source_train, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Wash_12class_LMDB')
    elif dset == "OfficeDslrToAmazon":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeDslr')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='OfficeAmazon')
    elif dset == "bedrooms":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='bedrooms_small')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='bedrooms')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train, Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Vand_12class_LMDB_Background')
        X_dest_train, Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(
            img_dim, image_dim_ordering, dset='Vand_12class_LMDB')
    else:
        # BUGFIX: this used to print a warning and fall through, crashing
        # later with a NameError on X_source_train.  Fail fast instead.
        raise ValueError("dataset not supported: %s" % dset)
    if n_classes1 != n_classes2:  # sanity check
        # BUGFIX: was a swallowed print; a mismatch corrupts the label
        # concatenations below, so refuse to continue.
        raise ValueError("number of classes mismatch between source and dest domains")
    n_classes = n_classes1
    img_source_dim = X_source_train.shape[-3:]  # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    # Create optimizers
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    if generator == "upsampling":
        generator_model = models.generator_upsampling_mnistM(noise_dim, img_source_dim, img_dest_dim,
                                                             bn_mode, deterministic, pureGAN,
                                                             inject_noise, wd, dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim, img_dest_dim, bn_mode, batch_size,
                                                  dset=dset)
    if simple_disc:
        discriminator_model = models.discriminator_naive(img_dest_dim, bn_mode, model, wd,
                                                         inject_noise, n_classes, use_mbd)
        DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model, noise_dim,
                                         img_source_dim)
    elif discriminator == "disc_resnet":
        discriminator_model = models.discriminatorResNet(img_dest_dim, bn_mode, model, wd,
                                                         monsterClass, inject_noise, n_classes,
                                                         use_mbd)
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim,
                                   img_source_dim, img_dest_dim, monsterClass)
    else:
        discriminator_model = models.disc1(img_dest_dim, bn_mode, model, wd, monsterClass,
                                           inject_noise, n_classes, use_mbd)
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim,
                                   img_source_dim, img_dest_dim, monsterClass)
    #### special options for bedrooms dataset: override all three models
    if dset == "bedrooms":
        generator_model = models.generator_dcgan(noise_dim, img_source_dim, img_dest_dim, bn_mode,
                                                 deterministic, pureGAN, inject_noise, wd)
        discriminator_model = models.discriminator_naive(img_dest_dim, bn_mode, model, wd,
                                                         inject_noise, n_classes, use_mbd,
                                                         model_name="discriminator_naive")
        DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model, noise_dim,
                                         img_source_dim)

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    # Freeze the critic while compiling the stacked G+D model, then
    # unfreeze and compile it stand-alone (standard Keras GAN pattern).
    models.make_trainable(discriminator_model, False)
    if model == 'wgan':
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
        models.make_trainable(discriminator_model, True)
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
    if model == 'lsgan':
        if simple_disc:
            DCGAN_model.compile(loss=['mse'], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['mse'], optimizer=opt_D)
        elif monsterClass:
            # single 2*n_classes softmax head (real classes | fake classes)
            DCGAN_model.compile(loss=['categorical_crossentropy'], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['categorical_crossentropy'], optimizer=opt_D)
        else:
            # two heads: real/fake score (mse) + class prediction
            DCGAN_model.compile(loss=['mse', 'categorical_crossentropy'],
                                loss_weights=[1.0, 1.0], optimizer=opt_G)
            models.make_trainable(discriminator_model, True)
            discriminator_model.compile(loss=['mse', 'categorical_crossentropy'],
                                        loss_weights=[1.0, 1.0], optimizer=opt_D)
    visualize = True

    if resume:
        ######## loading previous saved model weights
        data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name)

    #####################
    ### classifier
    #####################
    if not ((dset == 'mnistM') or (dset == 'bedrooms')):
        classifier, GenToClassifierModel = classifier_build_test(
            img_dest_dim, n_classes, generator_model, noise_dim, noise_scale, img_source_dim,
            opt_C, X_source_test, Y_source_test, X_dest_test, Y_dest_test, wd=0.0001)

    gen_iterations = 0
    max_history_size = int(history_size * batch_size)
    img_buffer = ImageHistoryBuffer((0,) + img_source_dim, max_history_size, batch_size,
                                    n_classes)

    #################
    # Start training
    #################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()
        while batch_counter < n_batch_per_epoch:
            # WGAN-style critic schedule: long warm-up for the first 25
            # generator iterations, periodic boost every 500, otherwise
            # the configured count.
            if no_supertrain is None:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                elif gen_iterations % 500 == 0:
                    # BUGFIX: this was an independent `if/else` whose else
                    # branch unconditionally overwrote the 100 assigned in
                    # the warm-up branch above, making it dead code.
                    disc_iterations = 10
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            list_gen_loss = []
            for disc_it in range(disc_iterations):
                X_dest_batch, Y_dest_batch, idx_dest_batch = next(
                    data_utils.gen_batch(X_dest_train, Y_dest_train, batch_size))
                X_source_batch, Y_source_batch, idx_source_batch = next(
                    data_utils.gen_batch(X_source_train, Y_source_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_dest_batch, generator_model, batch_counter, batch_size, noise_dim,
                    X_source_batch, noise_scale=noise_scale)

                if model == 'wgan':
                    current_labels_real = -np.ones(X_disc_real.shape[0])
                    current_labels_gen = np.ones(X_disc_gen.shape[0])
                if model == 'lsgan':
                    if simple_disc:
                        current_labels_real = np.ones(X_disc_real.shape[0])
                        current_labels_gen = np.zeros(X_disc_gen.shape[0])
                    elif monsterClass:
                        # [labels | zeros] for real, [zeros | labels] for fake
                        current_labels_real = np.concatenate(
                            (Y_dest_batch, np.zeros((X_disc_real.shape[0], n_classes))), axis=1)
                        current_labels_gen = np.concatenate(
                            (np.zeros((X_disc_real.shape[0], n_classes)), Y_source_batch), axis=1)
                    else:
                        current_labels_real = [np.ones(X_disc_real.shape[0]), Y_dest_batch]
                        # uniform class distribution as the target for fakes
                        Y_fake_batch = (1.0 / n_classes) * np.ones(
                            [X_disc_gen.shape[0], n_classes])
                        current_labels_gen = [np.zeros(X_disc_gen.shape[0]), Y_fake_batch]

                # Update the discriminator: real batch directly, generated
                # batch routed through the image history buffer.
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real,
                                                                    current_labels_real)
                img_buffer.add_to_buffer(X_disc_gen, current_labels_gen, batch_size)
                bufferImages, bufferLabels = img_buffer.get_from_buffer(batch_size)
                disc_loss_gen = discriminator_model.train_on_batch(bufferImages, bufferLabels)
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            X_source_batch2, Y_source_batch2, idx_source_batch2 = next(
                data_utils.gen_batch(X_source_train, Y_source_train, batch_size))
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen, X_source_batch2],
                                                      -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                if simple_disc:
                    gen_loss = DCGAN_model.train_on_batch([X_gen, X_source_batch2],
                                                          np.ones(X_gen.shape[0]))
                elif monsterClass:
                    # generator tries to push fakes into the "real" label half
                    labels_gen = np.concatenate(
                        (Y_source_batch2, np.zeros((X_disc_real.shape[0], n_classes))), axis=1)
                    gen_loss = DCGAN_model.train_on_batch([X_gen, X_source_batch2], labels_gen)
                else:
                    gen_loss = DCGAN_model.train_on_batch(
                        [X_gen, X_source_batch2],
                        [np.ones(X_gen.shape[0]), Y_source_batch2])
            list_gen_loss.append(gen_loss)

            gen_iterations += 1
            batch_counter += 1
            progbar.add(batch_size,
                        values=[("Loss_D", 0.5 * np.mean(list_disc_loss_real) +
                                 0.5 * np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", np.mean(list_gen_loss))])

            # plot images 1 times per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
                X_source_batch_plot, Y_source_batch_plot, idx_source_plot = next(
                    data_utils.gen_batch(X_source_test, Y_source_test, batch_size=32))
                data_utils.plot_generated_batch(X_dest_test, X_source_test, generator_model,
                                                noise_dim, image_dim_ordering, idx_source_plot,
                                                batch_size=32)
            if gen_iterations % (n_batch_per_epoch * 5) == 0:
                if visualize:
                    BIG_ASS_VISUALIZATION_slerp(X_source_train[1], generator_model, noise_dim)

        print("Dest labels:")
        print(Y_dest_test[idx_source_plot].argmax(1))
        print("Source labels:")
        print(Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e, name)
def train_toy(**kwargs):
    """Train a toy-MLP WGAN (critic with weight clipping, Wasserstein loss).

    args:
        **kwargs (dict): keyword arguments that specify the model
            hyperparameters (batch_size, n_batch_per_epoch, nb_epoch,
            noise_dim, noise_scale, lr_D/lr_G, opt_D/opt_G,
            clamp_lower/clamp_upper, disc_iterations).
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        # %-formatting keeps this valid under both Python 2 and 3; the
        # original used a Py2-only `print key, kwargs[key]` statement.
        print("%s %s" % (key, kwargs[key]))
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("toy_MLP")

    # Load and rescale data
    X_real_train = data_utils.load_toy()

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    generator_model = models.generator_toy(noise_dim)
    discriminator_model = models.discriminator_toy()
    GAN_model = models.GAN_toy(generator_model, discriminator_model, noise_dim)

    ############################
    # Compile models
    ############################
    # Freeze the critic inside the stacked model, compile, then unfreeze
    # and compile the critic stand-alone (standard Keras GAN pattern).
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    GAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    #################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()
        while batch_counter < n_batch_per_epoch:
            disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):
                # Clip discriminator weights (WGAN Lipschitz constraint)
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                    l.set_weights(weights)

                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                    X_real_batch, generator_model, batch_counter, batch_size, noise_dim,
                    noise_scale=noise_scale)

                # Update the discriminator (-1 = real target, +1 = fake
                # target under the Wasserstein loss)
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(
                    X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = GAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            batch_counter += 1
            progbar.add(batch_size,
                        values=[("Loss_D", -np.mean(list_disc_loss_real) -
                                 np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)])

            # Save images for visualization (every 50 generator updates)
            if gen_iterations % 50 == 0:
                data_utils.plot_generated_toy_batch(X_real_train, generator_model,
                                                    discriminator_model, noise_dim,
                                                    gen_iterations)
            gen_iterations += 1

        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
def train(**kwargs):
    """
    Train model (pix2pix-style: U-Net generator + patch discriminator)

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_dim_ordering)
    # img_dim from kwargs is overwritten with the real data shape
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_dim_ordering)

    # The whole training loop runs inside try/except so Ctrl-C exits cleanly
    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator, img_dim, nb_patch,
                                      bn_mode, use_mbd, batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc, nb_patch,
                                          bn_mode, use_mbd, batch_size)

        # NOTE(review): the generator is compiled with opt_discriminator,
        # not opt_dcgan — looks intentional in the original but worth confirming.
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        # Freeze D while compiling the stacked G+D model
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim,
                                   patch_size, image_dim_ordering)

        # Combined objective: L1 reconstruction (weight 10) + adversarial logloss
        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        # Unfreeze D and compile it stand-alone
        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            # gen_batch here is an infinite generator; the loop is cut by
            # the `batch_counter >= n_batch_per_epoch` break below
            for X_full_batch, X_sketch_batch in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch, X_sketch_batch,
                                                           generator_model, batch_counter,
                                                           patch_size, image_dim_ordering,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                # 2-column one-hot: generator always targets the "real" class
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images for visualization (twice per epoch;
                # n_batch_per_epoch / 2 is integer division under Python 2)
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_dim_ordering, "training")
                    X_full_batch, X_sketch_batch = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_dim_ordering, "validation")

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Save model weights every 5 epochs
            if e % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        # Allow manual interruption without a traceback
        pass
def train(**kwargs):
    """
    Train standard DCGAN model (dual-GAN variant: GAN1 maps A->B,
    GAN2 maps B->A; each has its own discriminator, classifier and
    image history buffer)

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    disc_type = kwargs["disc_type"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]
    data_aug = kwargs["data_aug"]
    disc_iters = kwargs["disc_iterations"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print key, kwargs[key]
    print("\n")
    ##### some extra parameters:
    noise_dim = (noise_dim,)
    # weight-file name suffixes for the two GANs
    name1 = name + '1'
    name2 = name + '2'
    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")
    gen_iterations = 0
    # Loading data (A and B domains, shared class count and image dims)
    A_data, A_labels, B_data, B_labels, n_classes, img_A_dim, img_B_dim = load_data(
        img_dim, image_dim_ordering, dset)

    # Setup GAN1 (A -> B direction)
    deterministic1 = False
    opt_D1, opt_G1, opt_C1, opt_Z1 = build_opt(opt_D, opt_G, lr_D, lr_G)
    generator_model1, discriminator_model1, discriminator_class1, classificator_model1, DCGAN_model1, zclass_model1 = load_compile_models(
        noise_dim, img_A_dim, img_B_dim, deterministic1, pureGAN, wd, 'mse',
        'categorical_crossentropy', disc_type, n_classes, opt_D1, opt_G1, opt_C1, opt_Z1)
    load_pretrained_weights(generator_model1, discriminator_model1, discriminator_class1,
                            DCGAN_model1, name1, B_data, B_labels, noise_scale,
                            classificator_model1, resume=resume)
    img_buffer1, datagen1 = load_buffer_and_augmentation(history_size, batch_size,
                                                         img_A_dim, n_classes)
    ## temporary settings:
    gen_entropy1 = None
    # Bundle everything for GAN1 into a single state object
    GAN1 = _GAN(generator_model1, discriminator_model1, discriminator_class1, DCGAN_model1,
                gen_entropy1, classificator_model1, batch_size, img_A_dim, img_B_dim,
                noise_dim, noise_scale, lr_D, lr_G, deterministic1, inject_noise, model,
                lsmooth, img_buffer1, datagen1, disc_type, data_aug, n_classes, disc_iters,
                name1, dir='AtoB')
    # Warm up the discriminator before adversarial training
    pretrain_disc(GAN1, A_data, A_labels, B_data, B_labels, pretrain_iters=500, resume=resume)

    #####################
    # Setup GAN2 (B -> A direction; note deterministic is True here)
    deterministic2 = True
    opt_D2, opt_G2, opt_C2, opt_Z2 = build_opt(opt_D, opt_G, lr_D, lr_G)
    generator_model2, discriminator_model2, discriminator_class2, classificator_model2, DCGAN_model2, zclass_model2 = load_compile_models(
        noise_dim, img_B_dim, img_A_dim, deterministic2, pureGAN, wd, 'mse',
        'categorical_crossentropy', disc_type, n_classes, opt_D2, opt_G2, opt_C2, opt_Z2)
    # NOTE(review): weights for GAN2 are loaded against B_data/B_labels just
    # like GAN1 — confirm this asymmetry is intended.
    load_pretrained_weights(generator_model2, discriminator_model2, discriminator_class2,
                            DCGAN_model2, name2, B_data, B_labels, noise_scale,
                            classificator_model2, resume=resume)
    img_buffer2, datagen2 = load_buffer_and_augmentation(history_size, batch_size,
                                                         img_B_dim, n_classes)
    ## temporary settings:
    gen_entropy2 = None
    GAN2 = _GAN(generator_model2, discriminator_model2, discriminator_class2, DCGAN_model2,
                gen_entropy2, classificator_model2, batch_size, img_B_dim, img_A_dim,
                noise_dim, noise_scale, lr_D, lr_G, deterministic2, inject_noise, model,
                lsmooth, img_buffer2, datagen2, disc_type, data_aug, n_classes, disc_iters,
                name2, dir='BtoA')
    pretrain_disc(GAN2, A_data, A_labels, B_data, B_labels, pretrain_iters=500, resume=resume)

    ################
    # Training loop
    ##################
    for e in range(1, nb_epoch + 1):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size, interval=0.2)
        batch_counter = 1
        start = time.time()
        while batch_counter < n_batch_per_epoch:
            # Alternate one training round of each GAN per batch step
            l_disc_real1, l_disc_gen1, l_gen1, l_z1, l_class1 = get_loss_list()
            A_data_batch, A_labels_batch, B_data_batch, B_labels_batch = train_gan(
                GAN1, GAN1.disc_iters, A_data, A_labels, B_data, B_labels, batch_counter,
                l_disc_real1, l_disc_gen1, l_gen1)
            l_class1 = train_class(GAN1, l_class1, A_data_batch, A_labels_batch)
            l_disc_real2, l_disc_gen2, l_gen2, l_z2, l_class2 = get_loss_list()
            A_data_batch, A_labels_batch, B_data_batch, B_labels_batch = train_gan(
                GAN2, GAN2.disc_iters, A_data, A_labels, B_data, B_labels, batch_counter,
                l_disc_real2, l_disc_gen2, l_gen2)
            l_class2 = train_class(GAN2, l_class2, A_data_batch, A_labels_batch)
            # Updates the progbar and advances/returns the counters
            batch_counter, gen_iterations = visualize_save_stuffs(
                [GAN1, GAN2], progbar, gen_iterations, batch_counter, n_batch_per_epoch,
                l_disc_real1, l_disc_gen1, l_gen1, l_class1, l_disc_real2, l_disc_gen2,
                l_gen2, l_class2, A_data, A_labels, B_data, B_labels, start, e)
        # gen_iterations, batch_counter, idx, Yplot
        # End-of-epoch: evaluate the (GAN1) classifier on the B domain
        testing_class_accuracy([GAN1, GAN2], GAN1.classificator_model, GAN1.generator_model,
                               5000, GAN1.noise_dim, GAN1.noise_scale, B_data, B_labels)
def train(**kwargs):
    """
    Train model (pix2pix-style; saves config and final models to a
    timestamped save_dir)

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]

    # right strip '/' to avoid empty '/' dir
    save_dir = kwargs["save_dir"].rstrip('/')
    # join name with current datetime (trailing '/' makes it a directory prefix)
    save_dir = '_'.join(
        [save_dir, datetime.datetime.now().strftime("%I:%M%p-%B%d-%Y/")])

    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # save the config in save dir
    # NOTE(review): json.dump will fail if kwargs holds non-serializable values
    with open('{0}job_config.json'.format(save_dir), 'w') as fp:
        json.dump(kwargs, fp, sort_keys=True, indent=4)

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
        dset, image_data_format)
    # img_dim from kwargs is overwritten with the real data shape
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    # Training runs inside try/except so Ctrl-C still saves the models below
    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator, img_dim, nb_patch,
                                      bn_mode, use_mbd, batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc, nb_patch,
                                          bn_mode, use_mbd, batch_size)

        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        # Freeze D while compiling the stacked G+D model
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim,
                                   patch_size, image_data_format)

        # Combined objective: L1 reconstruction (weight 10) + adversarial logloss
        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        # Unfreeze D and compile it stand-alone
        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            # gen_batch is an infinite generator; the epoch is cut by the
            # `batch_counter >= n_batch_per_epoch` break below
            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch, X_sketch_batch, generator_model, batch_counter,
                    patch_size, image_data_format,
                    label_smoothing=label_smoothing, label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                # 2-column one-hot: generator always targets the "real" class
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images for visualization (twice per epoch)
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model, batch_size,
                        image_data_format, "{:03}_EPOCH_TRAIN".format(e + 1), save_dir)
                    X_full_batch, X_sketch_batch = next(
                        data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(
                        X_full_batch, X_sketch_batch, generator_model, batch_size,
                        image_data_format, "{:03}_EPOCH_VALID".format(e + 1), save_dir)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Per-epoch checkpointing is disabled; only the final save below runs
            if e % 5 == 0:
                pass
                # save models
                # gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                # generator_model.save_weights(gen_weights_path, overwrite=True)
                # disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                # discriminator_model.save_weights(disc_weights_path, overwrite=True)
                # DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                # DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        # Allow manual interruption; models are still saved below
        pass

    # save models (runs on normal completion and on KeyboardInterrupt)
    DCGAN_model.save(save_dir + 'DCGAN.h5')
    generator_model.save(save_dir + 'GENERATOR.h5')
    discriminator_model.save(save_dir + 'DISCRIMINATOR.h5')
def train(**kwargs):
    """
    Train a standard DCGAN (generator + discriminator).

    Loads the whole training set in memory for faster operations.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters.
        Keys read: batch_size, n_batch_per_epoch, nb_epoch, generator, model_name,
        image_dim_ordering, img_dim, bn_mode, label_smoothing, label_flipping,
        noise_scale, dset, use_mbd.
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    # NOTE(review): no fallback branch — an unrecognized `dset` leaves
    # X_real_train undefined and the next line raises NameError; verify callers.
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(img_dim, image_dim_ordering)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_dim_ordering)
    # img_dim is overwritten with the actual (C, H, W) shape of the loaded data
    img_dim = X_real_train.shape[-3:]
    noise_dim = (100,)  # fixed-size latent vector

    # try/except lets a Ctrl-C stop training cleanly
    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      noise_dim,
                                      img_dim,
                                      bn_mode,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          noise_dim,
                                          img_dim,
                                          bn_mode,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)

        # Compile the generator alone (loss is a placeholder: the generator is
        # only trained through the combined DCGAN model below)
        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        # Freeze D while compiling the combined model so that G-updates through
        # DCGAN_model do not touch D's weights
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim, img_dim)
        loss = ['binary_crossentropy']
        loss_weights = [1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)
        # Re-enable and compile D for its own training step
        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        # Sentinel values so the first progbar update has something to show
        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_real_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           batch_size,
                                                           noise_dim,
                                                           noise_scale=noise_scale,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen, y_gen = data_utils.get_gen_batch(batch_size, noise_dim, noise_scale=noise_scale)

                # Freeze the discriminator while updating the generator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G logloss", gen_loss)])

                # Save images for visualization (every 100 generator updates)
                if batch_counter % 100 == 0:
                    data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                    batch_size, noise_dim, image_dim_ordering)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Checkpoint every 5 epochs (including epoch 0)
            if e % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        # Allow the user to interrupt training without a traceback
        pass
def eval(**kwargs):
    """
    Qualitative evaluation of a trained InfoGAN-style generator.

    Loads saved generator weights, then produces two figures:
    1. rows generated while sweeping the categorical latent code,
    2. a grid generated while sweeping the two continuous latent codes.
    Figures are written to ../../figures/.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters.
        Keys read: batch_size, generator, model_name, image_dim_ordering, img_dim,
        cont_dim, cat_dim, noise_dim, bn_mode, noise_scale, dset, epoch.

    NOTE(review): this function shadows the builtin `eval` — consider renaming
    at the call sites if the module is ever imported elsewhere.
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    # Latent dimensions are wrapped into 1-tuples (shape specs)
    cont_dim = (kwargs["cont_dim"],)
    cat_dim = (kwargs["cat_dim"],)
    noise_dim = (kwargs["noise_dim"],)
    bn_mode = kwargs["bn_mode"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    epoch = kwargs["epoch"]  # which checkpoint epoch to load

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data (only used to recover the true image shape)
    if dset == "RGZ":
        X_real_train = data_utils.load_RGZ(img_dim, image_dim_ordering)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_dim_ordering)
    img_dim = X_real_train.shape[-3:]

    # Load generator model
    generator_model = models.load("generator_%s" % generator,
                                  cat_dim,
                                  cont_dim,
                                  noise_dim,
                                  img_dim,
                                  bn_mode,
                                  batch_size,
                                  dset=dset)

    # Load trained generator weights for the requested epoch
    generator_model.load_weights("../../models/%s/gen_weights_epoch%s.h5" % (model_name, epoch))

    X_plot = []
    # Vary the categorical variable: one row of images per category
    for i in range(cat_dim[0]):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = data_utils.sample_noise(noise_scale, batch_size, cont_dim)
        X_cont = np.repeat(X_cont[:1, :], batch_size, axis=0)  # fix continuous noise across the row
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, i] = 1  # one-hot: same categorical value for the whole row

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)

        # Convert Theano ordering (N, C, H, W) to (N, H, W, C) for plotting
        if image_dim_ordering == "th":
            X_gen = X_gen.transpose(0,2,3,1)

        # Concatenate the batch horizontally into one row image
        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    # Stack rows vertically
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(8,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying categorical factor", fontsize=28, labelpad=60)

    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig("../../figures/varying_categorical.png")
    plt.clf()
    plt.close()

    # Vary the continuous variables: batch_size x batch_size interpolation grid
    X_plot = []
    # First get the extent of the noise sampling (large sample to estimate range)
    x = np.ravel(data_utils.sample_noise(noise_scale, batch_size * 20000, cont_dim))
    # Define interpolation points
    x = np.linspace(x.min(), x.max(), num=batch_size)
    for i in range(batch_size):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        # Row i: first continuous factor fixed at x[i], second sweeps all x[j]
        X_cont = np.concatenate([np.array([x[i], x[j]]).reshape(1, -1) for j in range(batch_size)], axis=0)
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, 1] = 1  # always the same categorical value

        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)

        if image_dim_ordering == "th":
            X_gen = X_gen.transpose(0,2,3,1)

        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(10,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying continuous factor 1", fontsize=28, labelpad=60)

    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.xlabel("Varying continuous factor 2", fontsize=28, labelpad=60)
    plt.annotate('', xy=(1, -0.05), xycoords='axes fraction', xytext=(0, -0.05),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig("../../figures/varying_continuous.png")
    plt.clf()
    plt.close()
def train(**kwargs):
    """
    Train a pix2pix model (U-Net generator + PatchGAN discriminator),
    resuming from a previously saved generator checkpoint.

    Loads the whole train data in memory for faster operations.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters.
        Keys read: batch_size, n_batch_per_epoch, nb_epoch, model_name, generator,
        image_dim_ordering, img_dim, patch_size, bn_mode, use_label_smoothing,
        label_flipping, dset, use_mbd.
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    # NOTE: the CLI flag is "use_label_smoothing" here, unlike the sibling
    # train() above which reads "label_smoothing"
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    print "hi"
    # Load and rescale data (full images + paired sketches, train and val splits)
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_dim_ordering)
    # img_dim is overwritten with the actual (C, H, W) shape of the loaded data
    img_dim = X_full_train.shape[-3:]
    print "data loaded in memory"

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_dim_ordering)

    # try/except lets a Ctrl-C stop training cleanly
    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim,
                                      nb_patch,
                                      bn_mode,
                                      use_mbd,
                                      batch_size)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          img_dim_disc,
                                          nb_patch,
                                          bn_mode,
                                          use_mbd,
                                          batch_size)

        # Compile the generator alone (MAE placeholder; G is trained through DCGAN_model)
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        # Freeze D while compiling the combined model
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   img_dim,
                                   patch_size,
                                   image_dim_ordering)

        # pix2pix objective: L1 reconstruction (weight 10) + adversarial logloss (weight 1)
        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = None
        disc_loss = None

        # Resume from the checkpoint of the previous iteration
        # NOTE(review): iter_num and the absolute weights path are hard-coded —
        # should be parameterized; verify the path exists on the target machine.
        iter_num = 102
        weights_path = "/home/abhik/pix2pix/src/model/weights/gen_weights_iter%s_epoch30.h5" % (str(iter_num - 1))
        print weights_path
        generator_model.load_weights(weights_path)
        #discriminator_model.load_weights("disc_weights1.2.h5")
        #DCGAN_model.load_weights("DCGAN_weights1.2.h5")
        print("Weights Loaded for iter - %d" % iter_num)

        # Running average bookkeeping (per-epoch loss rows, saved as CSV below)
        losses_list = list()
        # loss_list = list()
        # prev_avg = 0

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            # global disc_n, disc_prev_avg, gen1_n, gen1_prev_avg, gen2_n, gen2_prev_avg, gen3_n, gen3_prev_avg
            # disc_n = 1
            # disc_prev_avg = 0
            # gen1_n = 1
            # gen1_prev_avg = 0
            # gen2_n = 1
            # gen2_prev_avg = 0
            # gen3_n = 1
            # gen3_prev_avg = 0

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch,
                                                           X_sketch_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           patch_size,
                                                           image_dim_ordering,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                # Generator targets: all batches labelled "real" (one-hot column 1)
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator while updating the generator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                # Running average
                # loss_list.append(disc_loss)
                # loss_list_n = len(loss_list)
                # new_avg = ((loss_list_n-1)*prev_avg + disc_loss)/loss_list_n
                # prev_avg = new_avg
                # disc_avg, gen1_avg, gen2_avg, gen3_avg = running_avg(disc_loss, gen_loss[0], gen_loss[1], gen_loss[2])
                # print("running disc loss", new_avg)
                # print(disc_loss, gen_loss)
                # print ("all losses", disc_avg, gen1_avg, gen2_avg, gen3_avg)
                # print("")

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Saving data for plotting
                # losses = [e+1, batch_counter, disc_loss, gen_loss[0], gen_loss[1], gen_loss[2], disc_avg, gen1_avg, gen2_avg, gen3_avg, iter_num]
                # losses_list.append(losses)

                # Save images for visualization (twice per epoch)
                if batch_counter % (n_batch_per_epoch / 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_dim_ordering, "training", iter_num)
                    X_full_batch, X_sketch_batch = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_dim_ordering, "validation", iter_num)

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Running average over the last batch's losses
            disc_avg, gen1_avg, gen2_avg, gen3_avg = running_avg(disc_loss, gen_loss[0], gen_loss[1], gen_loss[2])

            # Validation loss on the last validation batch fetched above
            y_gen_val = np.zeros((X_sketch_batch.shape[0], 2), dtype=np.uint8)
            y_gen_val[:, 1] = 1
            val_loss = DCGAN_model.test_on_batch(X_full_batch, [X_sketch_batch, y_gen_val])
            # print "val_loss ===" + str(val_loss)

            # logging
            # Saving data for plotting (one 13-entry row per epoch)
            losses = [e + 1, iter_num, disc_loss, gen_loss[0], gen_loss[1], gen_loss[2],
                      disc_avg, gen1_avg, gen2_avg, gen3_avg, val_loss[0], val_loss[1], val_loss[2]]
            losses_list.append(losses)

            # Checkpoint weights and dump the loss log every 5 epochs
            if (e + 1) % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_iter%s_epoch%s.h5' % (model_name, iter_num, e + 1))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_iter%s_epoch%s.h5' % (model_name, iter_num, e + 1))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_iter%s_epoch%s.h5' % (model_name, iter_num, e + 1))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

                loss_array = np.asarray(losses_list)
                print(loss_array.shape)  # (n_epochs_so_far, 13) — one 13-entry row per epoch (see `losses` above)
                loss_path = os.path.join('../../losses/loss_iter%s_epoch%s.csv' % (iter_num, e + 1))
                np.savetxt(loss_path, loss_array, fmt='%.5f', delimiter=',')
                np.savetxt('test.csv', loss_array, fmt='%.5f', delimiter=',')

    except KeyboardInterrupt:
        # Allow the user to interrupt training without a traceback
        pass
def train(**kwargs):
    """
    Train a "Colorful Image Colorization"-style model: predict quantized ab
    color bins for each pixel of a grayscale (L-channel) input.

    Loads the training data from an HDF5 file; writes per-epoch visualization
    figures to ../../reports/figures/.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters.
        Keys read: batch_size, n_batch_per_epoch, nb_epoch, prob,
        training_data_file, experiment.
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    prob = kwargs["prob"]  # augmentation probability, passed to AugDataGenerator
    training_data_file = kwargs["training_data_file"]
    experiment = kwargs["experiment"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(experiment)

    # Create a batch generator for the color data
    DataAug = batch_utils.AugDataGenerator(training_data_file,
                                           batch_size=batch_size,
                                           prob=prob,
                                           dset="training_color")
    DataAug.add_transform("h_flip")

    # Load the array of quantized ab values (color bin centers)
    q_ab = np.load("../../data/processed/pts_in_hull.npy")
    nb_q = q_ab.shape[0]  # number of color bins
    nb_neighbors = 10
    # Fit a NN to q_ab (used for soft-encoding ground-truth colors)
    nn_finder = nn.NearestNeighbors(n_neighbors=nb_neighbors, algorithm='ball_tree').fit(q_ab)

    # Load the color prior factor that encourages rare colors (class rebalancing)
    prior_factor = np.load("../../data/processed/training_64_prior_factor.npy")

    # Load and rescale data
    # NOTE(review): only the first 100 images are loaded — presumably a debug
    # limit left in; confirm before a real training run.
    print("Loading data")
    with h5py.File(training_data_file, "r") as hf:
        X_train = hf["training_lab_data"][:100]
        c, h, w = X_train.shape[1:]
    print("Data loaded")

    # Clean up stale artifacts from previous runs
    for f in glob.glob("*.h5"):
        os.remove(f)
    for f in glob.glob("../../reports/figures/*.png"):
        os.remove(f)

    # try/except lets a Ctrl-C stop training cleanly
    try:

        # Create optimizers
        # opt = SGD(lr=5E-4, momentum=0.9, nesterov=True)
        opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load colorizer model (input: single L channel of size h x w)
        color_model = models.load("simple_colorful", nb_q, (1, h, w), batch_size)
        # 'categorical_crossentropy_color' is a custom loss registered by the project
        color_model.compile(loss='categorical_crossentropy_color', optimizer=opt)

        color_model.summary()
        from keras.utils.visualize_util import plot
        plot(color_model, to_file='colorful.png')

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for batch in DataAug.gen_batch_colorful(X_train, nn_finder, nb_q, prior_factor):

                X_batch_black, X_batch_color, Y_batch = batch

                # X = color_model.predict(X_batch_black)
                # print color_model.evaluate(X_batch_black, Y_batch)
                # X = color_model.predict(X_batch_black)
                # print X[0, 0, 0, :]

                # L channel is in [0, 100]; divide by 100 to normalize to [0, 1]
                train_loss = color_model.train_on_batch(X_batch_black / 100., Y_batch)

                batch_counter += 1
                progbar.add(batch_size, values=[("loss", train_loss)])

                if batch_counter >= n_batch_per_epoch:
                    break
            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Format X_colorized: decode the last batch's predictions back to Lab/RGB
            # (drop the trailing channel with [:, :, :, :-1] — assumed padding/extra
            # output of the model; TODO confirm against the model definition)
            X_colorized = color_model.predict(X_batch_black / 100.)[:, :, :, :-1]
            X_colorized = X_colorized.reshape((batch_size * h * w, nb_q))
            # argmax over color bins, then look up the (a, b) bin centers
            X_colorized = q_ab[np.argmax(X_colorized, 1)]
            X_a = X_colorized[:, 0].reshape((batch_size, 1, h, w))
            X_b = X_colorized[:, 1].reshape((batch_size, 1, h, w))
            # Reassemble L + a + b and convert each image Lab -> RGB
            X_colorized = np.concatenate((X_batch_black, X_a, X_b), axis=1).transpose(0, 2, 3, 1)
            X_colorized = [np.expand_dims(color.lab2rgb(im), 0) for im in X_colorized]
            X_colorized = np.concatenate(X_colorized, 0).transpose(0, 3, 1, 2)

            # Ground-truth batch: Lab -> RGB for side-by-side comparison
            X_batch_color = [np.expand_dims(color.lab2rgb(im.transpose(1, 2, 0)), 0) for im in X_batch_color]
            X_batch_color = np.concatenate(X_batch_color, 0).transpose(0, 3, 1, 2)

            print X_batch_color.shape, X_colorized.shape, X_batch_black.shape

            # Save up to 32 triptychs: [ground truth | grayscale | prediction]
            for i, img in enumerate(X_colorized[:min(32, batch_size)]):
                arr = np.concatenate([X_batch_color[i],
                                      np.repeat(X_batch_black[i] / 100., 3, axis=0),
                                      img], axis=2)
                np.save("../../reports/gen_image_%s.npy" % i, arr)

            # Tile the saved triptychs into a single epoch figure
            plt.figure(figsize=(20,20))
            list_img = glob.glob("../../reports/*.npy")
            list_img = [np.load(im) for im in list_img]
            list_img = [np.concatenate(list_img[4 * i: 4 * (i + 1)], axis=2) for i in range(len(list_img) / 4)]
            arr = np.concatenate(list_img, axis=1)
            plt.imshow(arr.transpose(1,2,0))
            ax = plt.gca()
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            plt.tight_layout()
            plt.savefig("../../reports/figures/fig_epoch%s.png" % e)
            plt.clf()
            plt.close()

    except KeyboardInterrupt:
        # Allow the user to interrupt training without a traceback
        pass
def trainClassAux(**kwargs):
    """
    Train a domain-adaptation GAN (WGAN or LSGAN) whose generator is also
    trained with an auxiliary classification loss through a frozen classifier.

    The generator maps source-domain images (+ noise) to the target domain;
    the critic/discriminator is trained on real target images vs. generated
    ones; the classifier provides a label-preservation signal to the generator.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
        (dataset selection, optimizer names and learning rates, WGAN/LSGAN switch,
        warm-up control, resume/pretrained flags, etc. — see the unpacking below).
    """

    # Roll out the parameters
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]  # WGAN weight-clipping bounds (clipping loop is commented out below)
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]  # 'wgan' or 'lsgan'
    no_supertrain = kwargs["no_supertrain"]
    noClass = kwargs["noClass"]  # if set, skip the classifier loss on the generator
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]  # weight decay
    C_weight = kwargs["C_weight"]  # scaling for the classifier-loss learning rate
    monsterClass = kwargs["monsterClass"]
    pretrained = kwargs["pretrained"]

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print key, kwargs[key]
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data (source and target/destination domains)
    if dset == "mnistM":
        X_source_train,Y_source_train, X_source_test, Y_source_test, n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnist')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnistM')
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_12class_LMDB')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB_Background')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Wash_Color_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
    else:
        # NOTE(review): only prints — execution continues with undefined arrays
        # and will fail later; consider raising instead.
        print "dataset not supported"
    if n_classes1 != n_classes2:  # sanity check
        print "number of classes mismatch between source and dest domains"
    n_classes = n_classes1
    img_source_dim = X_source_train.shape[-3:]  # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    # Make the datasets read-only to catch accidental in-place modification
    X_source_train.flags.writeable = False
    X_source_test.flags.writeable = False
    X_dest_train.flags.writeable = False
    X_dest_test.flags.writeable = False

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    # NOTE(review): opt_G was replaced by an optimizer object on the previous
    # line, so get_optimizer here receives an object, not the name string —
    # verify get_optimizer tolerates this.
    opt_G_C = data_utils.get_optimizer(opt_G, lr_G*C_weight)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    if generator == "upsampling":
        generator_model = models.generator_upsampling_mnistM(noise_dim, img_source_dim,img_dest_dim, bn_mode,deterministic,inject_noise,wd, dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim, img_dest_dim, bn_mode, batch_size, dset=dset)
    discriminator_model = models.discriminator(img_dest_dim, bn_mode,model,wd,monsterClass,inject_noise,n_classes)
    DCGAN_model = models.DCGAN_naive(generator_model, discriminator_model, noise_dim, img_source_dim)
    classifier = models.resnet(img_dest_dim,n_classes,pretrained,wd=0.0001) #it is img_dest_dim because it is actually the generated image dim,that is equal to dest_dim
    GenToClassifierModel = models.GenToClassifierModel(generator_model, classifier, noise_dim, img_source_dim)

    ############################
    # Load weights
    ############################
    if resume:
        data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name)
    #if pretrained:
    #    model_path = "../../models/DCGAN"
    #    class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
    #    classifier.load_weights(class_weights_path)

    ############################
    # Compile models
    ############################
    # Generator compiled standalone with a placeholder loss (trained via the
    # combined models below)
    generator_model.compile(loss='mse', optimizer=opt_G)
    if model == 'wgan':
        discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
        # Freeze D inside the combined G+D model so G-updates don't touch D
        models.make_trainable(discriminator_model, False)
        DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    if model == 'lsgan':
        discriminator_model.compile(loss='mse', optimizer=opt_D)
        models.make_trainable(discriminator_model, False)
        DCGAN_model.compile(loss='mse', optimizer=opt_G)
    classifier.compile(loss='categorical_crossentropy', optimizer=opt_C,metrics=['accuracy'])  # it is actually never using optimizer
    # Freeze the classifier: only the generator learns from the classifier loss
    models.make_trainable(classifier, False)
    GenToClassifierModel.compile(loss='categorical_crossentropy', optimizer=opt_G,metrics=['accuracy'])

    #######################
    # Train classifier
    #######################
    # if not pretrained:
    #     #print ("Testing accuracy on target domain test set before training:")
    #     loss1,acc1 =classifier.evaluate(X_dest_test, Y_dest_test,batch_size=256, verbose=0)
    #     print('\n Classifier Accuracy on target domain test set before training: %.2f%%' % (100 * acc1))
    #     classifier.fit(X_dest_train, Y_dest_train, validation_split=0.1, batch_size=512, nb_epoch=10, verbose=1)
    #     print ("\n Testing accuracy on target domain test set AFTER training:")
    # else:
    #     print ("Loaded pretrained classifier, computing accuracy on target domain test set:")
    # loss2,acc2 = classifier.evaluate(X_dest_test, Y_dest_test,batch_size=512, verbose=0)
    # print('\n Classifier Accuracy on target domain test set after training: %.2f%%' % (100 * acc2))
    # #print ("Testing accuracy on source domain test set:")
    # loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
    # print('\n Classifier Accuracy on source domain test set: %.2f%%' % (100 * acc3))
    # evaluating_GENned(noise_scale, noise_dim, X_source_test, Y_source_test, classifier, generator_model)
    # model_path = "../../models/DCGAN"
    # class_weights_path = os.path.join(model_path, 'VandToVand_5epochs.h5')
    # classifier.save_weights(class_weights_path, overwrite=True)

    #################
    # GAN training
    ################
    gen_iterations = 0
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()
        while batch_counter < n_batch_per_epoch:
            # Decide how many critic updates per generator update:
            # WGAN-style warm-up (100 D-steps early on and every 500 G-steps).
            if no_supertrain is None:
                if (gen_iterations < 25) and (not resume):
                    disc_iterations = 100
                # NOTE(review): this `else` belongs to the `% 500` check, so it
                # overrides the `< 25` warm-up above unless gen_iterations is a
                # multiple of 500 — likely `elif` was intended; verify.
                if gen_iterations % 500 == 0:
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]
            else:
                if (gen_iterations <25) and (not resume):
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            list_gen_loss = []
            list_class_loss_real = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights (WGAN Lipschitz constraint —
                # currently disabled)
                # for l in discriminator_model.layers:
                #     weights = l.get_weights()
                #     weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                #     l.set_weights(weights)

                X_dest_batch, Y_dest_batch,idx_dest_batch = next(data_utils.gen_batch(X_dest_train, Y_dest_train, batch_size))
                X_source_batch, Y_source_batch,idx_source_batch = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_dest_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    X_source_batch,
                                                                    noise_scale=noise_scale)

                if model == 'wgan':
                    # Update the discriminator (WGAN targets: -1 real / +1 fake)
                    disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                    disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                if model == 'lsgan':
                    # LSGAN targets: 1 real / 0 fake
                    disc_loss_real = discriminator_model.train_on_batch(X_disc_real, np.ones(X_disc_real.shape[0]))
                    disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.zeros(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator with GAN loss
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
            # NOTE(review): source_images is assigned but never used below
            source_images = X_source_train[np.random.randint(0,X_source_train.shape[0],size=batch_size),:,:,:]
            X_source_batch2, Y_source_batch2,idx_source_batch2 = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))

            # Freeze the discriminator (already frozen inside DCGAN_model at compile time)
            # discriminator_model.trainable = False
            if model == 'wgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen,X_source_batch2], -np.ones(X_gen.shape[0]))
            if model == 'lsgan':
                gen_loss = DCGAN_model.train_on_batch([X_gen,X_source_batch2], np.ones(X_gen.shape[0]))
            list_gen_loss.append(gen_loss)

            #######################
            # 3) Train the generator with Classifier loss
            #######################
            w1 = classifier.get_weights() #FOR DEBUG
            if not noClass:
                # Push generated images through the frozen classifier; gradient
                # flows only into the generator
                new_gen_loss = GenToClassifierModel.train_on_batch([X_gen,X_source_batch2], Y_source_batch2)
                list_class_loss_real.append(new_gen_loss)
            else:
                list_class_loss_real.append(0.0)
            w2 = classifier.get_weights() #FOR DEBUG
            # Sanity check: the classifier must stay frozen — identical weights
            # before/after mean the generator update did not leak into it
            for a,b in zip(w1, w2):
                if np.all(a == b):
                    print "no bug in GEN model update"
                else:
                    print "BUG IN GEN MODEL UPDATE"

            gen_iterations += 1
            batch_counter += 1
            progbar.add(batch_size, values=[("Loss_D", 0.5*np.mean(list_disc_loss_real) + 0.5*np.mean(list_disc_loss_gen)),
                                            ("Loss_D_real", np.mean(list_disc_loss_real)),
                                            ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                            ("Loss_G", np.mean(list_gen_loss)),
                                            ("Loss_classifier", np.mean(list_class_loss_real))])

            # plot images 1 times per epoch
            if batch_counter % (n_batch_per_epoch) == 0:
                # train_WGAN.plot_images(X_dest_batch)
                X_dest_batch_plot,Y_dest_batch_plot,idx_dest_plot = next(data_utils.gen_batch(X_dest_train,Y_dest_train, batch_size=32))
                X_source_batch_plot,Y_source_batch_plot,idx_source_plot = next(data_utils.gen_batch(X_source_train,Y_source_train, batch_size=32))
                data_utils.plot_generated_batch(X_dest_train,X_source_train, generator_model, noise_dim, image_dim_ordering,idx_source_plot,batch_size=32)
                print ("Dest labels:")
                print (Y_dest_train[idx_source_plot].argmax(1))
                print ("Source labels:")
                print (Y_source_batch_plot.argmax(1))

        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e, name)
        # Per-epoch diagnostics on the source-domain test set
        # NOTE(review): original formatting was lost; these last three calls are
        # assumed to run once per epoch — confirm against the upstream script.
        evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)
        loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
        print('\n Classifier Accuracy on source domain test set: %.2f%%' % (100 * acc3))
# ---------------------------------------------------------------------------
# Script-level configuration for a pix2pix run on the chest x-ray dataset.
# These statements execute at import/run time (no enclosing function).
# ---------------------------------------------------------------------------
image_data_format = 'channels_last'
img_dim = 256
patch_size = [128, 128]  # PatchGAN discriminator patch size
bn_mode = 2
label_smoothing = False
label_flipping = 0
data_folder = '/Lab1/Lab6/data_pix2pix/data/processed/'
dset = 'chest_xray'
use_mbd = False  # mini-batch discrimination off
do_plot = False
logging_dir = './pix2pix/logging_dir_pix2pix/'
# NOTE(review): n_batch_per_epoch, batch_size and model_name are not defined
# in this span — presumably set earlier in the script; verify.
epoch_size = n_batch_per_epoch * batch_size

# Setup environment (logging directory etc)
setup_logging(model_name, logging_dir=logging_dir)

# Load and rescale data (full images + paired sketches, train and val splits)
X_full_train, X_sketch_train, X_full_val, X_sketch_val = load_data(data_folder, dset, image_data_format)
# img_dim is overwritten with the actual (C, H, W) shape of the loaded data
img_dim = X_full_train.shape[-3:]

# Get the number of non overlapping patch and the size of input image to the discriminator
nb_patch, img_dim_disc = get_nb_patch(img_dim, patch_size, image_data_format)

# NOTE(review): the matching except-clause of this try lies outside this span.
try:
    # Create optimizers
    opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
def trainDECO(**kwargs):
    """
    Train a generator (DECO-style) to map source-domain images toward the
    target domain, supervised by a classifier trained on the target domain.

    args: **kwargs (dict) keyword arguments that specify the model
          hyperparameters (dataset name, batch sizes, learning rates,
          optimizer names, flags such as `pretrained`, `noClass`, ...).
    """
    # Roll out the parameters
    # NOTE(review): several of these (generator, discriminator, clamp_lower,
    # clamp_upper, no_supertrain, noClass, model, monsterClass, resume) are
    # unpacked but never used below — presumably shared kwargs with sibling
    # train functions; confirm before pruning.
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic = kwargs["deterministic"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    noClass = kwargs["noClass"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    monsterClass = kwargs["monsterClass"]
    pretrained = kwargs["pretrained"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        # Python 2 print statement — this function is Python 2 only.
        print key, kwargs[key]
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data: each branch yields a source-domain and a
    # destination(target)-domain split.
    # NOTE(review): only the "Vand_Vand_12class_LMDB" and "Wash_Color_LMDB"
    # branches bind X_*_test / Y_*_test, which are used unconditionally
    # later — other dsets would raise NameError; confirm supported dsets.
    if dset == "mnistM":
        X_source_train,Y_source_train, _, _, n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnist')
        X_dest_train,Y_dest_train, _, _,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='mnistM')
    elif dset == "washington_vandal50k":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal50k')
    elif dset == "washington_vandal12classes":
        X_source_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classes')
    elif dset == "washington_vandal12classesNoBackground":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='washington12classes')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='vandal12classesNoBackground')
    elif dset == "Wash_Vand_12class_LMDB":
        X_source_train,Y_source_train,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_12class_LMDB')
        X_dest_train,Y_dest_train,n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Vand_Vand_12class_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB_Background')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Vand_12class_LMDB')
    elif dset == "Wash_Color_LMDB":
        X_source_train,Y_source_train,X_source_test, Y_source_test,n_classes1 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
        X_dest_train,Y_dest_train, X_dest_test, Y_dest_test, n_classes2 = data_utils.load_image_dataset(img_dim, image_dim_ordering,dset='Wash_Color_LMDB')
    else:
        # NOTE(review): execution continues after this message and will fail
        # on the undefined names below — consider raising instead.
        print "dataset not supported"
    if n_classes1 != n_classes2: #sanity check
        # Warning only; training proceeds with n_classes1 regardless.
        print "number of classes mismatch between source and dest domains"
    n_classes = n_classes1

    # Get the full real image dimension
    img_source_dim = X_source_train.shape[-3:] # is it backend agnostic?
    img_dest_dim = X_dest_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)
    # Classifier optimizer — never actually steps, since the classifier is
    # frozen before the DECO phase (see make_trainable below).
    opt_C = data_utils.get_optimizer('SGD', 0.01)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    generator_model = models.generator_upsampling_mnistM(noise_dim, img_source_dim,img_dest_dim, bn_mode,deterministic,inject_noise,wd, dset=dset)
    classifier = models.resnet(img_dest_dim,n_classes,wd=0.0001) #it is img_dest_dim because it is actually the generated image dim,that is equal to dest_dim
    # Stacked model: generator output feeds the (frozen) classifier.
    GenToClassifierModel = models.GenToClassifierModel(generator_model, classifier, noise_dim, img_source_dim)

    #################
    # Load weight
    ################
    if pretrained:
        model_path = "../../models/DCGAN"
        class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
        classifier.load_weights(class_weights_path)
    # if resume: ########loading previous saved model weights
    #     data_utils.load_model_weights(generator_model, discriminator_model, DCGAN_model, name)

    #######################
    # Compile models
    #######################
    generator_model.compile(loss='mse', optimizer=opt_G)
    # classifier.trainable = True # I wanna freeze the classifier without any training updates
    classifier.compile(loss='categorical_crossentropy', optimizer=opt_C,metrics=['accuracy']) # it is actually never using optimizer
    # Freeze classifier weights so only the generator learns in the stack.
    models.make_trainable(classifier, False)
    GenToClassifierModel.compile(loss='categorical_crossentropy', optimizer=opt_G,metrics=['accuracy'])

    #######################
    # Train classifier
    #######################
    if not pretrained:
        loss1,acc1 =classifier.evaluate(X_dest_test, Y_dest_test,batch_size=256, verbose=0)
        print('\n Classifier Accuracy on target domain test set before training: %.00f%%' % (100.0 * acc1))
        classifier.fit(X_dest_train, Y_dest_train, validation_split=0.1, batch_size=512, nb_epoch=20, verbose=1)
    else:
        print ("Loaded pretrained classifier, computing accuracy on target domain test set:")
    loss2,acc2 = classifier.evaluate(X_dest_test, Y_dest_test,batch_size=512, verbose=0)
    print('\n Classifier Accuracy on target domain test set after training: %.00f%%' % (100.0 * acc2))
    loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
    print('\n Classifier Accuracy on source domain test set: %.00f%%' % (100.0 * acc3))
    # Baseline: classifier accuracy on generator output before DECO training.
    evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)
    # model_path = "../../models/DCGAN"
    # class_weights_path = os.path.join(model_path, 'NoBackground_100epochs.h5')
    # classifier.save_weights(class_weights_path, overwrite=True)
    # models.make_trainable(classifier, False)
    #classifier.trainable = False # I wanna freeze the classifier without any more training updates
    # classifier.compile(loss='categorical_crossentropy', optimizer=opt_C,metrics=['accuracy'])

    #################
    # DECO training
    ################
    gen_iterations = 0
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()
        ###################################
        # 1) Train the critic / discriminator
        ###################################
        # NOTE(review): despite the heading, no discriminator is trained here;
        # these batches are drawn but only the plot/gen steps below use data.
        list_class_loss_real = []
        X_dest_batch, Y_dest_batch,idx_dest_batch = next(data_utils.gen_batch(X_dest_train, Y_dest_train, batch_size))
        X_source_batch, Y_source_batch,idx_source_batch = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size))
        #######################
        # 2) Train the generator
        #######################
        # One whole "epoch" of data is fed through a single fit() call.
        X_gen = data_utils.sample_noise(noise_scale, batch_size*n_batch_per_epoch, noise_dim)
        X_source_batch2, Y_source_batch2,idx_source_batch2 = next(data_utils.gen_batch(X_source_train, Y_source_train, batch_size*n_batch_per_epoch))
        GenToClassifierModel.fit([X_gen, X_source_batch2], Y_source_batch2, batch_size=256, nb_epoch=1, verbose=1)
        gen_iterations += 1
        batch_counter += 1
        # plot images 1 times per epoch
        X_dest_batch_plot,Y_dest_batch_plot,idx_dest_plot = next(data_utils.gen_batch(X_dest_train,Y_dest_train, batch_size=32))
        X_source_batch_plot,Y_source_batch_plot,idx_source_plot = next(data_utils.gen_batch(X_source_train,Y_source_train, batch_size=32))
        data_utils.plot_generated_batch(X_dest_train,X_source_train, generator_model, noise_dim, image_dim_ordering,idx_source_plot,batch_size=32)
        print ("Dest labels:")
        print (Y_dest_train[idx_source_plot].argmax(1))
        print ("Source labels:")
        print (Y_source_batch_plot.argmax(1))
        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
        # Save model weights (by default, every 5 epochs)
        #data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e, name)
        # Track how well the frozen classifier recognizes generated images.
        evaluating_GENned(noise_scale,noise_dim,X_source_test,Y_source_test,classifier,generator_model)
        loss3, acc3 = classifier.evaluate(X_source_test, Y_source_test,batch_size=512, verbose=0)
        print('\n Classifier Accuracy on source domain test set: %.00f%%' % (100.0 * acc3))
def train(**kwargs):
    """
    Train model (InfoGAN-style DCGAN with categorical + continuous codes).
    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model
          hyperparameters (batch size, dataset, code dimensions, save/plot
          cadence, ...).
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    celebA_img_dim = kwargs["celebA_img_dim"]
    # Code/noise dimensions are wrapped into 1-tuples (Keras shape tuples).
    cont_dim = (kwargs["cont_dim"], )
    cat_dim = (kwargs["cat_dim"], )
    noise_dim = (kwargs["noise_dim"], )
    label_smoothing = kwargs["label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    load_from_dir = kwargs["load_from_dir"]
    target_size = kwargs["target_size"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and rescale data
    if dset == "celebA":
        X_real_train = data_utils.load_celebA(celebA_img_dim, image_data_format)
    elif dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    else:
        # Stream batches from a directory; pull one batch just to get shapes.
        X_batch_gen = data_utils.data_generator_from_dir(
            dset, target_size, batch_size)
        X_real_train = next(X_batch_gen)
    img_dim = X_real_train.shape[-3:]

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=1E-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_%s" % generator,
                                      cat_dim,
                                      cont_dim,
                                      noise_dim,
                                      img_dim,
                                      batch_size,
                                      dset=dset,
                                      use_mbd=use_mbd)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          cat_dim,
                                          cont_dim,
                                          noise_dim,
                                          img_dim,
                                          batch_size,
                                          dset=dset,
                                          use_mbd=use_mbd)

        # NOTE(review): the generator is compiled with opt_discriminator and a
        # placeholder mse loss — it only trains through DCGAN_model below.
        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   cat_dim, cont_dim, noise_dim)
        # InfoGAN losses: adversarial + categorical code + continuous code.
        list_losses = [
            'binary_crossentropy', 'categorical_crossentropy', gaussian_loss
        ]
        list_weights = [1, 1, 1]
        DCGAN_model.compile(loss=list_losses,
                            loss_weights=list_weights,
                            optimizer=opt_dcgan)
        # Multiple discriminator losses
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses,
                                    loss_weights=list_weights,
                                    optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        if not load_from_dir:
            X_batch_gen = data_utils.gen_batch(X_real_train, batch_size)

        # Start training
        print("Start training")
        # Per-epoch averaged loss histories, used by plot_losses below.
        disc_total_losses = []
        disc_log_losses = []
        disc_cat_losses = []
        disc_cont_losses = []
        gen_total_losses = []
        gen_log_losses = []
        gen_cat_losses = []
        gen_cont_losses = []
        start = time.time()
        for e in range(nb_epoch):
            print('--------------------------------------------')
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d}\n'.format(
                datetime.datetime.now(), e + 1, nb_epoch))
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            # Running sums of the per-batch losses for this epoch.
            disc_total_loss_batch = 0
            disc_log_loss_batch = 0
            disc_cat_loss_batch = 0
            disc_cont_loss_batch = 0
            gen_total_loss_batch = 0
            gen_log_loss_batch = 0
            gen_cat_loss_batch = 0
            gen_cont_loss_batch = 0
            for batch_counter in range(n_batch_per_epoch):
                # Load data
                X_real_batch = next(X_batch_gen)

                # Create a batch to feed the discriminator model
                X_disc, y_disc, y_cat, y_cont = data_utils.get_disc_batch(
                    X_real_batch,
                    generator_model,
                    batch_counter,
                    batch_size,
                    cat_dim,
                    cont_dim,
                    noise_dim,
                    noise_scale=noise_scale,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(
                    X_disc, [y_disc, y_cat, y_cont])

                # Create a batch to feed the generator model
                X_gen, y_gen, y_cat, y_cont, y_cont_target = data_utils.get_gen_batch(
                    batch_size, cat_dim, cont_dim, noise_dim,
                    noise_scale=noise_scale)

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(
                    [y_cat, y_cont, X_gen], [y_gen, y_cat, y_cont_target])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                progbar.add(batch_size,
                            values=[("D tot", disc_loss[0]),
                                    ("D log", disc_loss[1]),
                                    ("D cat", disc_loss[2]),
                                    ("D cont", disc_loss[3]),
                                    ("G tot", gen_loss[0]),
                                    ("G log", gen_loss[1]),
                                    ("G cat", gen_loss[2]),
                                    ("G cont", gen_loss[3])])

                disc_total_loss_batch += disc_loss[0]
                disc_log_loss_batch += disc_loss[1]
                disc_cat_loss_batch += disc_loss[2]
                disc_cont_loss_batch += disc_loss[3]
                gen_total_loss_batch += gen_loss[0]
                gen_log_loss_batch += gen_loss[1]
                gen_cat_loss_batch += gen_loss[2]
                gen_cont_loss_batch += gen_loss[3]

                # # Save images for visualization
                # if batch_counter % (n_batch_per_epoch / 2) == 0:
                #     data_utils.plot_generated_batch(X_real_batch, generator_model, e,
                #                                     batch_size, cat_dim, cont_dim, noise_dim,
                #                                     image_data_format, model_name)

            # Average the accumulated batch losses over the epoch.
            disc_total_losses.append(disc_total_loss_batch / n_batch_per_epoch)
            disc_log_losses.append(disc_log_loss_batch / n_batch_per_epoch)
            disc_cat_losses.append(disc_cat_loss_batch / n_batch_per_epoch)
            disc_cont_losses.append(disc_cont_loss_batch / n_batch_per_epoch)
            gen_total_losses.append(gen_total_loss_batch / n_batch_per_epoch)
            gen_log_losses.append(gen_log_loss_batch / n_batch_per_epoch)
            gen_cat_losses.append(gen_cat_loss_batch / n_batch_per_epoch)
            gen_cont_losses.append(gen_cont_loss_batch / n_batch_per_epoch)

            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                e, batch_size, cat_dim,
                                                cont_dim, noise_dim,
                                                image_data_format, model_name)
                data_utils.plot_losses(disc_total_losses, disc_log_losses,
                                       disc_cat_losses, disc_cont_losses,
                                       gen_total_losses, gen_log_losses,
                                       gen_cat_losses, gen_cont_losses,
                                       model_name)

            if (e + 1) % save_weights_every_n_epochs == 0:
                print("Saving weights...")
                # Delete all but the last n weights
                general_utils.purge_weights(save_only_last_n_weights, model_name)
                # Save weights
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%05d.h5' %
                    (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)
                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%05d.h5' %
                    (model_name, e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)
                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%05d.h5' %
                    (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

            end = time.time()
            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, end - start))
            start = end

    except KeyboardInterrupt:
        # Allow Ctrl-C to stop training early; fall through to the final save.
        pass

    # Always persist the latest generator, even after an interrupt.
    gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
    print("Saving", gen_weights_path)
    generator_model.save(gen_weights_path, overwrite=True)
def train(**kwargs):
    """
    Train standard DCGAN model (dual-domain A->B GAN with a classifier and
    optional reconstruction / semi-supervised branches).

    args: **kwargs (dict) keyword arguments that specify the model
          hyperparameters.
    """
    # Roll out the parameters
    # NOTE(review): lr_rec is taken from kwargs["lr_D"], not a dedicated
    # "lr_rec" key — confirm this aliasing is intentional.
    generator = kwargs["generator"]
    discriminator = kwargs["discriminator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_rec = kwargs["lr_D"]
    opt_rec = kwargs["opt_rec"]
    lr_G = kwargs["lr_G"]
    lr_D = kwargs["lr_D"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    use_mbd = kwargs["use_mbd"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size
    deterministic1 = kwargs["deterministic1"]
    deterministic2 = kwargs["deterministic2"]
    inject_noise = kwargs["inject_noise"]
    model = kwargs["model"]
    no_supertrain = kwargs["no_supertrain"]
    pureGAN = kwargs["pureGAN"]
    lsmooth = kwargs["lsmooth"]
    disc_type = kwargs["disc_type"]
    resume = kwargs["resume"]
    name = kwargs["name"]
    wd = kwargs["wd"]
    history_size = kwargs["history_size"]
    monsterClass = kwargs["monsterClass"]
    data_aug = kwargs["data_aug"]
    disc_iters = kwargs["disc_iterations"]
    class_weight = kwargs["class_weight"]
    reconst_w = kwargs["reconst_w"]
    rec = kwargs["rec"]
    reconstClass = kwargs["reconstClass"]
    pretrained = kwargs["pretrained"]
    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    #####some extra parameters:
    noise_dim = (noise_dim, )
    name1 = name + '1'
    name2 = name + '2'

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")
    gen_iterations = 0

    # Loading data: A = source domain, B = target domain, plus a test split.
    A_data, A_labels, B_data, B_labels, n_classes, img_A_dim, img_B_dim = load_data(
        img_dim, image_dim_ordering, dset)
    test_data, test_labels = load_testset(img_dim, image_dim_ordering, dset)
    if deterministic1 is None:
        deterministic1 = False
    if deterministic2 is None:
        deterministic2 = False

    # Build the five optimizers (D, G, classifier, z-class, reconstruction).
    opt_D1, opt_G1, opt_C1, opt_Z1, opt_rec = build_opt(
        opt_D, opt_G, lr_D, lr_G, lr_rec, opt_rec)

    # Build and compile every sub-model of the GAN in one helper call.
    generator_model1, discriminator_model1, discriminator_class1, classificator_model1, classificator_model2, DCGAN_model1, GenClass_model1, generator_ss1, discriminator_ss1, DCGAN_ss1, zclass_model1 = load_compile_models(
        noise_dim, img_A_dim, img_B_dim, deterministic1, pureGAN, wd, 'mse',
        'categorical_crossentropy', disc_type, n_classes, opt_D1, opt_G1,
        opt_C1, opt_Z1, suffix=None, pretrained=pretrained)

    load_pretrained_weights(generator_model1, discriminator_model1,
                            discriminator_class1, DCGAN_model1, name1,
                            B_data, B_labels, noise_scale,
                            classificator_model1, resume=resume)

    # History buffer of past generated images + data-augmentation generator.
    img_buffer1, datagen1 = load_buffer_and_augmentation(
        history_size, batch_size, img_A_dim, n_classes)

    # Bundle all pieces into one GAN wrapper object (direction A->B).
    GAN1 = _GAN(generator_model1, discriminator_model1, discriminator_class1,
                DCGAN_model1, GenClass_model1, classificator_model1,
                classificator_model2, generator_ss1, discriminator_ss1,
                DCGAN_ss1, batch_size, img_A_dim, img_B_dim, noise_dim,
                noise_scale, lr_D, lr_G, deterministic1, inject_noise, model,
                lsmooth, img_buffer1, datagen1, disc_type, data_aug,
                n_classes, disc_iters, name1, dir='AtoB')

    # Warm up the discriminator before adversarial training starts.
    pretrain_disc(GAN1, A_data, A_labels, B_data, B_labels, class_weight,
                  pretrain_iters=500, resume=resume)

    #####################
    accuracy = []
    for e in range(1, nb_epoch + 1):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size, interval=0.2)
        batch_counter = 1
        start = time.time()
        while batch_counter < n_batch_per_epoch:
            # Fresh (empty) per-iteration loss accumulators.
            l_disc_real1, l_disc_gen1, l_gen1, l_disc_real1_ss, l_disc_gen1_ss, l_disc_ss1, l_disc_ss2, l_disc_ss3, l_disc_ss4, l_disc_ss5, l_gen1_ss, l_z1, l_class1, l_rec1, l_GenClass1, _ = get_loss_list(
            )
            # One adversarial step (disc_iters critic updates + one G update).
            A_data_batch, A_labels_batch, B_data_batch, B_labels_batch = train_gan(
                GAN1, GAN1.disc_iters, A_data, A_labels, B_data, B_labels,
                batch_counter, l_disc_real1, l_disc_gen1, l_gen1,
                l_disc_real1_ss, l_disc_gen1_ss, l_disc_ss1, l_disc_ss2,
                l_disc_ss3, l_disc_ss4, l_disc_ss5, l_gen1_ss, l_GenClass1,
                class_weight)
            if rec:
                # NOTE(review): rec1, rec2 and l_rec2 are not defined anywhere
                # in this function — this call raises NameError when rec is
                # truthy; confirm whether they were meant to come from
                # get_loss_list()/module scope.
                train_rec(GAN1, rec1, rec2, A_data_batch, B_data_batch,
                          l_rec1, l_rec2, reconst_w) #BRINGING US TO L.A.? :)
            # if reconstClass > 0.0:
            #     train_recClass(GAN1,recClass, A_data_batch, A_labels_batch, l_recClass, reconstClass)
            l_class1 = train_class(GAN1, l_class1, l_rec1, A_data_batch,
                                   A_labels_batch, B_data_batch, test_data,
                                   test_labels)
            # l_class2 = train_class(GAN2, l_class2, A_data_batch, A_labels_batch)
            # dummy = GAN1.discriminator2.predict(B_data_batch)
            # print(dummy)
            # Updates the progbar and advances/wraps the counters.
            batch_counter, gen_iterations = visualize_save_stuffs(
                [GAN1], progbar, gen_iterations, batch_counter,
                n_batch_per_epoch, l_disc_real1, l_disc_gen1, l_gen1,
                l_disc_real1_ss, l_disc_gen1_ss, l_gen1_ss, l_class1, A_data,
                A_labels, B_data, B_labels, start, e, l_rec1, l_GenClass1)
        # End of epoch: measure classifier accuracy on generated test images.
        acc = testing_class_accuracy([GAN1], GAN1.classificator_model,
                                     GAN1.generator_model,
                                     test_data.shape[0], GAN1.noise_dim,
                                     GAN1.noise_scale, test_data, test_labels)
        # Dump the full generated A->B dataset for this epoch.
        X_noise = sample_noise(GAN1.noise_scale, A_data.shape[0],
                               GAN1.noise_dim)
        gen_output = GAN1.generator_model.predict([X_noise, A_data])
        np.save('MnistM', gen_output)
        # testing_class_accuracy([GAN1],GAN1.classificator_model, GAN1.generator_model,
        #     5000, GAN1.noise_dim, GAN1.noise_scale, B_data, B_labels)
        accuracy = np.append(accuracy, acc)
        np.save('accuracy', accuracy)
def train(**kwargs):
    """
    Train model (classifier-only variant: trains a classifier on the sketch
    images and tracks best validation accuracy).
    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model
          hyperparameters.
    """
    # Roll out the parameters
    # NOTE(review): generator, patch_size, bn_mode, label_* and use_mbd are
    # unpacked but unused here — leftovers from the GAN variant of this
    # function; confirm before pruning.
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val, target_train, target_val = data_utils.load_data(
        dset, image_data_format)
    # The configured img_dim is overridden by the actual data shape.
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999,
                                 epsilon=1e-08)

        # DCGAN_model = models.DCGAN(generator_model,
        #                            discriminator_model,
        #                            img_dim,
        #                            patch_size,
        #                            image_data_format)

        ##########################################################################
        classifier_model = models.Pereira_classifier(img_dim)
        #classifier_model = models.MyResNet18(img_dim)
        #classifier_model = models.MyDensNet121(img_dim)
        #classifier_model = models.MyNASNetMobile(img_dim)
        #########################################################################
        loss = [keras.losses.categorical_crossentropy]
        loss_weights = [1]
        classifier_model.compile(loss=loss,
                                 loss_weights=loss_weights,
                                 optimizer=opt_dcgan)

        class_loss = 100
        disc_loss = 100
        max_accval = 0

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()
            for X_full_batch, X_sketch_batch, Y_target in data_utils.gen_batch(
                    X_full_train, X_sketch_train, target_train, batch_size):
                # Supervised update on the sketch -> label mapping.
                class_loss = classifier_model.train_on_batch(
                    X_sketch_batch, Y_target)
                # Unfreeze the discriminator
                batch_counter += 1
                progbar.add(batch_size, values=[("class_loss", class_loss)])

                # Save images for visualization
                if batch_counter >= n_batch_per_epoch:
                    # Validate once per epoch on the whole validation set.
                    X_full_batch, X_sketch_batch, Y_target_val = next(
                        data_utils.gen_batch(X_full_val, X_sketch_val,
                                             target_val,
                                             int(X_sketch_val.shape[0])))
                    y_pred = classifier_model.predict(X_sketch_batch)
                    y_predd = np.argmax(y_pred, axis=1)
                    y_true = np.argmax(Y_target_val, axis=1)
                    #print(y_true.shape)
                    # NOTE(review): under Python 2, sum(...) / y_predd.shape[0]
                    # is integer division and would truncate to 0 — this line
                    # assumes Python 3 true division; confirm interpreter.
                    accval = (sum(
                        (y_predd == y_true)) / y_predd.shape[0] * 100)
                    if (accval > max_accval):
                        max_accval = accval
                    print('valacc=%.2f' % (accval))
                    print('max_accval=%.2f' % (max_accval))
                    break
            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch,
                                             time.time() - start))

    except KeyboardInterrupt:
        # Allow Ctrl-C to end training early without a traceback.
        pass
def train(**kwargs): """ Train standard DCGAN model args: **kwargs (dict) keyword arguments that specify the model hyperparameters """ # Roll out the parameters generator = kwargs["generator"] dset = kwargs["dset"] img_dim = kwargs["img_dim"] nb_epoch = kwargs["nb_epoch"] batch_size = kwargs["batch_size"] n_batch_per_epoch = kwargs["n_batch_per_epoch"] bn_mode = kwargs["bn_mode"] noise_dim = kwargs["noise_dim"] noise_scale = kwargs["noise_scale"] lr_D = kwargs["lr_D"] lr_G = kwargs["lr_G"] opt_D = kwargs["opt_D"] opt_G = kwargs["opt_G"] clamp_lower = kwargs["clamp_lower"] clamp_upper = kwargs["clamp_upper"] image_dim_ordering = kwargs["image_dim_ordering"] epoch_size = n_batch_per_epoch * batch_size print("\nExperiment parameters:") for key in kwargs.keys(): print key, kwargs[key] print("\n") # Setup environment (logging directory etc) general_utils.setup_logging("DCGAN") # Load and normalize data X_real_train = data_utils.load_image_dataset(dset, img_dim, image_dim_ordering) # Get the full real image dimension img_dim = X_real_train.shape[-3:] # Create optimizers opt_G = data_utils.get_optimizer(opt_G, lr_G) opt_D = data_utils.get_optimizer(opt_D, lr_D) ####################### # Load models ####################### noise_dim = (noise_dim, ) if generator == "upsampling": generator_model = models.generator_upsampling(noise_dim, img_dim, bn_mode, dset=dset) else: generator_model = models.generator_deconv(noise_dim, img_dim, bn_mode, batch_size, dset=dset) discriminator_model = models.discriminator(img_dim, bn_mode, dset=dset) DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim, img_dim) ############################ # Compile models ############################ generator_model.compile(loss='mse', optimizer=opt_G) discriminator_model.trainable = False DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G) discriminator_model.trainable = True discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D) # Global iteration counter for 
generator updates gen_iterations = 0 ################# # Start training ################ for e in range(nb_epoch): # Initialize progbar and batch counter progbar = generic_utils.Progbar(epoch_size) batch_counter = 1 start = time.time() while batch_counter < n_batch_per_epoch: if gen_iterations < 25 or gen_iterations % 500 == 0: disc_iterations = 100 else: disc_iterations = kwargs["disc_iterations"] ################################### # 1) Train the critic / discriminator ################################### list_disc_loss_real = [] list_disc_loss_gen = [] for disc_it in range(disc_iterations): # Clip discriminator weights for l in discriminator_model.layers: weights = l.get_weights() weights = [ np.clip(w, clamp_lower, clamp_upper) for w in weights ] l.set_weights(weights) X_real_batch = next( data_utils.gen_batch(X_real_train, batch_size)) # Create a batch to feed the discriminator model X_disc_real, X_disc_gen = data_utils.get_disc_batch( X_real_batch, generator_model, batch_counter, batch_size, noise_dim, noise_scale=noise_scale) # Update the discriminator disc_loss_real = discriminator_model.train_on_batch( X_disc_real, -np.ones(X_disc_real.shape[0])) disc_loss_gen = discriminator_model.train_on_batch( X_disc_gen, np.ones(X_disc_gen.shape[0])) list_disc_loss_real.append(disc_loss_real) list_disc_loss_gen.append(disc_loss_gen) ####################### # 2) Train the generator ####################### X_gen = X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim) # Freeze the discriminator discriminator_model.trainable = False gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0])) # Unfreeze the discriminator discriminator_model.trainable = True gen_iterations += 1 batch_counter += 1 progbar.add(batch_size, values=[("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)), ("Loss_D_real", -np.mean(list_disc_loss_real)), ("Loss_D_gen", np.mean(list_disc_loss_gen)), ("Loss_G", -gen_loss)]) # Save images for visualization ~2 
times per epoch if batch_counter % (n_batch_per_epoch / 2) == 0: data_utils.plot_generated_batch(X_real_batch, generator_model, batch_size, noise_dim, image_dim_ordering) print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start)) # Save model weights (by default, every 5 epochs) data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e)
def train(**kwargs):
    """
    Train model (colorization network: predicts quantized ab color classes
    from a black-and-white L channel).

    args: **kwargs (dict) keyword arguments that specify the model
          hyperparameters (data file, number of NN color neighbors,
          training mode, ...).
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    data_file = kwargs["data_file"]
    nb_neighbors = kwargs["nb_neighbors"]
    model_name = kwargs["model_name"]
    training_mode = kwargs["training_mode"]
    epoch_size = n_batch_per_epoch * batch_size
    # Image size is encoded in the data file name ("..._<size>_...").
    img_size = int(os.path.basename(data_file).split("_")[1])

    # Setup directories to save model, architecture etc
    general_utils.setup_logging(model_name)

    # Create a batch generator for the color data
    DataGen = batch_utils.DataGenerator(data_file,
                                        batch_size=batch_size,
                                        dset="training")
    c, h, w = DataGen.get_config()["data_shape"][1:]

    # Load the array of quantized ab value
    q_ab = np.load("../../data/processed/pts_in_hull.npy")
    nb_q = q_ab.shape[0]

    # Fit a NN to q_ab: maps continuous ab values to nearest quantized bins.
    nn_finder = nn.NearestNeighbors(n_neighbors=nb_neighbors,
                                    algorithm='ball_tree').fit(q_ab)

    # Load the color prior factor that encourages rare colors
    prior_factor = np.load("../../data/processed/CelebA_%s_prior_factor.npy" %
                           img_size)

    # Load and rescale data
    if training_mode == "in_memory":
        with h5py.File(data_file, "r") as hf:
            X_train = hf["training_lab_data"][:]

    # Remove possible previous figures to avoid confusion
    for f in glob.glob("../../figures/*.png"):
        os.remove(f)

    try:
        # Create optimizers
        opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load colorizer model
        color_model = models.load(model_name, nb_q, (1, h, w), batch_size)
        # 'categorical_crossentropy_color' is a custom (project-registered)
        # loss name — presumably the prior-weighted crossentropy; confirm.
        color_model.compile(loss='categorical_crossentropy_color',
                            optimizer=opt)

        color_model.summary()
        from keras.utils.visualize_util import plot
        plot(color_model,
             to_file='../../figures/colorful.png',
             show_shapes=True,
             show_layer_names=True)

        # Actual training loop
        for epoch in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            # Choose Batch Generation mode
            if training_mode == "in_memory":
                BatchGen = DataGen.gen_batch_in_memory(
                    X_train, nn_finder, nb_q, prior_factor)
            else:
                BatchGen = DataGen.gen_batch(nn_finder, nb_q, prior_factor)

            for batch in BatchGen:
                X_batch_black, X_batch_color, Y_batch = batch
                # L channel is in [0, 100]; divide to normalize the input.
                train_loss = color_model.train_on_batch(
                    X_batch_black / 100., Y_batch)
                batch_counter += 1
                progbar.add(batch_size, values=[("loss", train_loss)])
                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (epoch + 1, nb_epoch,
                                             time.time() - start))

            # Plot some data with original, b and w and colorized versions side by side
            # (uses the last batch left over from the loop above)
            general_utils.plot_batch(color_model, q_ab, X_batch_black,
                                     X_batch_color, batch_size, h, w, nb_q,
                                     epoch)

            # Save weights every 5 epoch
            if epoch % 5 == 0:
                weights_path = os.path.join(
                    '../../models/%s/%s_weights_epoch%s.h5' %
                    (model_name, model_name, epoch))
                color_model.save_weights(weights_path, overwrite=True)

    except KeyboardInterrupt:
        # Allow Ctrl-C to end training early without a traceback.
        pass
def train(**kwargs):
    """Train the pix2pix-style GAN (generator U-Net + patch discriminator).

    Validates/builds the dataset, optionally resumes from a previous model's
    checkpoints, then alternates discriminator and generator updates. At the
    end (even after Ctrl-C) the generator is saved and returned.

    args:
        **kwargs (dict) keyword arguments that specify the model hyperparameters
            (see the parameter roll-out below for the expected keys)
    Returns:
        The compiled generator model.
    """
    # Roll out the parameters
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    image_data_format = kwargs["image_data_format"]
    generator_type = kwargs["generator_type"]
    dset = kwargs["dset"]
    use_identity_image = kwargs["use_identity_image"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    augment_data = kwargs["augment_data"]
    model_name = kwargs["model_name"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    use_mbd = kwargs["use_mbd"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping_prob = kwargs["label_flipping_prob"]
    use_l1_weighted_loss = kwargs["use_l1_weighted_loss"]
    prev_model = kwargs["prev_model"]
    change_model_name_to_prev_model = kwargs["change_model_name_to_prev_model"]
    discriminator_optimizer = kwargs["discriminator_optimizer"]
    n_run_of_gen_for_1_run_of_disc = kwargs["n_run_of_gen_for_1_run_of_disc"]
    load_all_data_at_once = kwargs["load_all_data_at_once"]
    MAX_FRAMES_PER_GIF = kwargs["MAX_FRAMES_PER_GIF"]
    dont_train = kwargs["dont_train"]
    # batch_size = args.batch_size
    # n_batch_per_epoch = args.n_batch_per_epoch
    # nb_epoch = args.nb_epoch
    # save_weights_every_n_epochs = args.save_weights_every_n_epochs
    # generator_type = args.generator_type
    # patch_size = args.patch_size
    # label_smoothing = False
    # label_flipping_prob = False
    # dset = args.dset
    # use_mbd = False

    # "dont_train" mode: just build and return a compiled generator (e.g. for inference)
    if dont_train:
        # Get the number of non overlapping patch and the size of input image to the discriminator
        nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)
        generator_model = models.load("generator_unet_%s" % generator_type, img_dim, nb_patch, use_mbd, batch_size, model_name)
        generator_model.compile(loss='mae', optimizer='adam')
        return generator_model

    # Check and make the dataset
    # If .h5 file of dset is not present, try making it
    if load_all_data_at_once:
        if not os.path.exists("../../data/processed/%s_data.h5" % dset):
            print("dset %s_data.h5 not present in '../../data/processed'!" % dset)
            if not os.path.exists("../../data/%s/" % dset):
                print("dset folder %s not present in '../../data'!\n\nERROR: Dataset .h5 file not made, and dataset not available in '../../data/'.\n\nQuitting." % dset)
                return
            else:
                if not os.path.exists("../../data/%s/train" % dset) or not os.path.exists("../../data/%s/val" % dset) or not os.path.exists("../../data/%s/test" % dset):
                    print("'train', 'val' or 'test' folders not present in dset folder '../../data/%s'!\n\nERROR: Dataset must contain 'train', 'val' and 'test' folders.\n\nQuitting." % dset)
                    return
                else:
                    print("Making %s dataset" % dset)
                    subprocess.call(['python3', '../data/make_dataset.py', '../../data/%s' % dset, '3'])
                    print("Done!")
    else:
        # Streaming mode: dset is a directory path containing train/ and val/
        if not os.path.exists(dset):
            print("dset does not exist! Given:", dset)
            return
        if not os.path.exists(os.path.join(dset, 'train')):
            print("dset does not contain a 'train' dir! Given dset:", dset)
            return
        if not os.path.exists(os.path.join(dset, 'val')):
            print("dset does not contain a 'val' dir! Given dset:", dset)
            return

    epoch_size = n_batch_per_epoch * batch_size

    init_epoch = 0
    if prev_model:
        print('\n\nLoading prev_model from', prev_model, '...\n\n')
        # Latest checkpoint of each sub-model: lexicographic sort works because
        # epoch numbers are zero-padded to 5 digits in the filenames below.
        prev_model_latest_gen = sorted(glob.glob(os.path.join('../../models/', prev_model, '*gen*epoch*.h5')))[-1]
        prev_model_latest_disc = sorted(glob.glob(os.path.join('../../models/', prev_model, '*disc*epoch*.h5')))[-1]
        prev_model_latest_DCGAN = sorted(glob.glob(os.path.join('../../models/', prev_model, '*DCGAN*epoch*.h5')))[-1]
        print(prev_model_latest_gen, prev_model_latest_disc, prev_model_latest_DCGAN)
        if change_model_name_to_prev_model:
            # Find prev model name, epoch (parsed out of the checkpoint path)
            model_name = prev_model_latest_DCGAN.split('models')[-1].split('/')[1]
            init_epoch = int(prev_model_latest_DCGAN.split('epoch')[1][:5]) + 1

    # img_dim = X_target_train.shape[-3:]
    # img_dim = (256, 256, 3)

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # NOTE(review): if discriminator_optimizer is neither 'sgd' nor 'adam',
        # opt_discriminator is never bound and the compile below raises NameError.
        if discriminator_optimizer == 'sgd':
            opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        elif discriminator_optimizer == 'adam':
            opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator_type, img_dim, nb_patch, use_mbd, batch_size, model_name)
        generator_model.compile(loss='mae', optimizer=opt_dcgan)

        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc, nb_patch, use_mbd, batch_size, model_name)

        # Freeze the discriminator while compiling the combined model so that
        # only the generator trains through DCGAN_model.
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model, img_dim, patch_size, image_data_format)
        if use_l1_weighted_loss:
            loss = [l1_weighted_loss, 'binary_crossentropy']
        else:
            loss = [l1_loss, 'binary_crossentropy']
        # L1 reconstruction term is weighted 10x vs the adversarial term
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)
        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        # Load prev_model
        if prev_model:
            generator_model.load_weights(prev_model_latest_gen)
            discriminator_model.load_weights(prev_model_latest_disc)
            DCGAN_model.load_weights(prev_model_latest_DCGAN)

        # Load .h5 data all at once
        print('\n\nLoading data...\n\n')
        check_this_process_memory()
        if load_all_data_at_once:
            X_target_train, X_sketch_train, X_target_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
            check_this_process_memory()
            print('X_target_train: %.4f' % (X_target_train.nbytes/2**30), "GB")
            print('X_sketch_train: %.4f' % (X_sketch_train.nbytes/2**30), "GB")
            print('X_target_val: %.4f' % (X_target_val.nbytes/2**30), "GB")
            print('X_sketch_val: %.4f' % (X_sketch_val.nbytes/2**30), "GB")
            # To generate training data
            X_target_batch_gen_train, X_sketch_batch_gen_train = data_utils.data_generator(X_target_train, X_sketch_train, batch_size, augment_data=augment_data)
            X_target_batch_gen_val, X_sketch_batch_gen_val = data_utils.data_generator(X_target_val, X_sketch_val, batch_size, augment_data=False)
        # Load data from images through an ImageDataGenerator
        else:
            # target_size width is 2*img_dim[1] — presumably input and target
            # images are stored side by side in one file; confirm in data_utils.
            X_batch_gen_train = data_utils.data_generator_from_dir(os.path.join(dset, 'train'), target_size=(img_dim[0], 2*img_dim[1]), batch_size=batch_size)
            X_batch_gen_val = data_utils.data_generator_from_dir(os.path.join(dset, 'val'), target_size=(img_dim[0], 2*img_dim[1]), batch_size=batch_size)
        check_this_process_memory()

        # Setup environment (logging directory etc)
        general_utils.setup_logging(**kwargs)

        # Losses
        disc_losses = []
        gen_total_losses = []
        gen_L1_losses = []
        gen_log_losses = []

        # Start training
        print("\n\nStarting training...\n\n")
        # For each epoch
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            gen_total_loss_epoch = 0
            gen_L1_loss_epoch = 0
            gen_log_loss_epoch = 0
            start = time.time()
            # For each batch
            # for X_target_batch, X_sketch_batch in data_utils.gen_batch(X_target_train, X_sketch_train, batch_size):
            for batch in range(n_batch_per_epoch):
                # Create a batch to feed the discriminator model
                if load_all_data_at_once:
                    X_target_batch_train, X_sketch_batch_train = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                else:
                    X_target_batch_train, X_sketch_batch_train = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim, augment_data=augment_data, use_identity_image=use_identity_image)
                X_disc, y_disc = data_utils.get_disc_batch(X_target_batch_train, X_sketch_batch_train, generator_model, batch_counter, patch_size, image_data_format, label_smoothing=label_smoothing, label_flipping_prob=label_flipping_prob)
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                # Create a batch to feed the generator model
                if load_all_data_at_once:
                    X_gen_target, X_gen_sketch = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                else:
                    X_gen_target, X_gen_sketch = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim, augment_data=augment_data, use_identity_image=use_identity_image)
                # Generator target labels: one-hot "real" class for the adversarial head
                y_gen_target = np.zeros((X_gen_target.shape[0], 2), dtype=np.uint8)
                y_gen_target[:, 1] = 1
                # Freeze the discriminator
                discriminator_model.trainable = False
                # Train generator (n_run_of_gen_for_1_run_of_disc updates per disc update;
                # each loss contribution is pre-divided so the epoch sums are averages)
                for _ in range(n_run_of_gen_for_1_run_of_disc-1):
                    gen_loss = DCGAN_model.train_on_batch(X_gen_sketch, [X_gen_target, y_gen_target])
                    gen_total_loss_epoch += gen_loss[0]/n_run_of_gen_for_1_run_of_disc
                    gen_L1_loss_epoch += gen_loss[1]/n_run_of_gen_for_1_run_of_disc
                    gen_log_loss_epoch += gen_loss[2]/n_run_of_gen_for_1_run_of_disc
                    if load_all_data_at_once:
                        X_gen_target, X_gen_sketch = next(X_target_batch_gen_train), next(X_sketch_batch_gen_train)
                    else:
                        X_gen_target, X_gen_sketch = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_train, img_dim=img_dim, augment_data=augment_data, use_identity_image=use_identity_image)
                gen_loss = DCGAN_model.train_on_batch(X_gen_sketch, [X_gen_target, y_gen_target])
                # Add losses
                gen_total_loss_epoch += gen_loss[0]/n_run_of_gen_for_1_run_of_disc
                gen_L1_loss_epoch += gen_loss[1]/n_run_of_gen_for_1_run_of_disc
                gen_log_loss_epoch += gen_loss[2]/n_run_of_gen_for_1_run_of_disc
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                # Progress
                # progbar.add(batch_size, values=[("D logloss", disc_loss),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G L1", gen_loss[1]),
                #                                 ("G logloss", gen_loss[2])])
                print("Epoch", str(init_epoch+e+1), "batch", str(batch+1), "D_logloss", disc_loss, "G_tot", gen_loss[0], "G_L1", gen_loss[1], "G_log", gen_loss[2])
            gen_total_loss = gen_total_loss_epoch/n_batch_per_epoch
            gen_L1_loss = gen_L1_loss_epoch/n_batch_per_epoch
            gen_log_loss = gen_log_loss_epoch/n_batch_per_epoch
            # disc_loss here is the last batch's loss, not an epoch average
            disc_losses.append(disc_loss)
            gen_total_losses.append(gen_total_loss)
            gen_L1_losses.append(gen_L1_loss)
            gen_log_losses.append(gen_log_loss)
            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_target_batch_train, X_sketch_batch_train, generator_model, batch_size, image_data_format, model_name, "training", init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Get new images for validation
                if load_all_data_at_once:
                    X_target_batch_val, X_sketch_batch_val = next(X_target_batch_gen_val), next(X_sketch_batch_gen_val)
                else:
                    X_target_batch_val, X_sketch_batch_val = data_utils.load_data_from_data_generator_from_dir(X_batch_gen_val, img_dim=img_dim, augment_data=False, use_identity_image=use_identity_image)
                # Predict and validate
                data_utils.plot_generated_batch(X_target_batch_val, X_sketch_batch_val, generator_model, batch_size, image_data_format, model_name, "validation", init_epoch + e + 1, MAX_FRAMES_PER_GIF)
                # Plot losses
                data_utils.plot_losses(disc_losses, gen_total_losses, gen_L1_losses, gen_log_losses, model_name, init_epoch)
            # Save weights
            if (e + 1) % save_weights_every_n_epochs == 0:
                # Delete all but the last n weights
                purge_weights(save_only_last_n_weights, model_name)
                # NOTE(review): checkpoints are named with init_epoch + e while the
                # progress logs use init_epoch + e + 1 — off-by-one between the two.
                # Save gen weights
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", gen_weights_path)
                generator_model.save_weights(gen_weights_path, overwrite=True)
                # Save disc weights
                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", disc_weights_path)
                discriminator_model.save_weights(disc_weights_path, overwrite=True)
                # Save DCGAN weights
                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5' % (model_name, init_epoch + e, disc_losses[-1], gen_total_losses[-1], gen_L1_losses[-1], gen_log_losses[-1]))
                print("Saving", DCGAN_weights_path)
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
            check_this_process_memory()
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d} END, Time taken: {3:.4f} seconds'.format(datetime.datetime.now(), init_epoch + e + 1, init_epoch + nb_epoch, time.time() - start))
            print('------------------------------------------------------------------------------------')
    except KeyboardInterrupt:
        # Fall through to the final save below on manual interruption
        pass

    # SAVE THE MODEL
    try:
        # Save the model as it is, so that it can be loaded using -
        # ```from keras.models import load_model; gen = load_model('generator_latest.h5')```
        gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
        print("Saving", gen_weights_path)
        generator_model.save(gen_weights_path, overwrite=True)
        # Save model as json string
        generator_model_json_string = generator_model.to_json()
        print("Saving", '../../models/%s/generator_latest.txt' % model_name)
        with open('../../models/%s/generator_latest.txt' % model_name, 'w') as outfile:
            a = outfile.write(generator_model_json_string)
        # Save model as json
        generator_model_json_data = json.loads(generator_model_json_string)
        print("Saving", '../../models/%s/generator_latest.json' % model_name)
        with open('../../models/%s/generator_latest.json' % model_name, 'w') as outfile:
            json.dump(generator_model_json_data, outfile)
    except:
        # NOTE(review): bare except — deliberately best-effort save, but it
        # swallows everything (incl. SystemExit); only the exception type is printed.
        print(sys.exc_info()[0])
    print("Done.")
    return generator_model
def train(**kwargs):
    """Train the pix2pix-style GAN with the whole dataset held in memory.

    Variant of the streaming trainer above: the .h5 dataset is loaded all at
    once and batches come from data_utils.gen_batch. Image size is fixed at
    (256, 256, 3).

    args:
        **kwargs (dict) keyword arguments that specify the model hyperparameters
            (see the parameter roll-out below for the expected keys)
    """
    # Roll out the parameters
    patch_size = kwargs["patch_size"]
    image_data_format = kwargs["image_data_format"]
    generator_type = kwargs["generator_type"]
    dset = kwargs["dset"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    use_mbd = kwargs["use_mbd"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping_prob = kwargs["label_flipping_prob"]
    use_l1_weighted_loss = kwargs["use_l1_weighted_loss"]
    prev_model = kwargs["prev_model"]
    discriminator_optimizer = kwargs["discriminator_optimizer"]
    n_run_of_gen_for_1_run_of_disc = kwargs["n_run_of_gen_for_1_run_of_disc"]
    MAX_FRAMES_PER_GIF = kwargs["MAX_FRAMES_PER_GIF"]
    # batch_size = args.batch_size
    # n_batch_per_epoch = args.n_batch_per_epoch
    # nb_epoch = args.nb_epoch
    # save_weights_every_n_epochs = args.save_weights_every_n_epochs
    # generator_type = args.generator_type
    # patch_size = args.patch_size
    # label_smoothing = False
    # label_flipping_prob = False
    # dset = args.dset
    # use_mbd = False

    # Check and make the dataset
    # If .h5 file of dset is not present, try making it
    if not os.path.exists("../../data/processed/%s_data.h5" % dset):
        print("dset %s_data.h5 not present in '../../data/processed'!" % dset)
        if not os.path.exists("../../data/%s/" % dset):
            print(
                "dset folder %s not present in '../../data'!\n\nERROR: Dataset .h5 file not made, and dataset not available in '../../data/'.\n\nQuitting."
                % dset)
            return
        else:
            if not os.path.exists(
                    "../../data/%s/train" % dset) or not os.path.exists(
                        "../../data/%s/val" % dset) or not os.path.exists(
                            "../../data/%s/test" % dset):
                print(
                    "'train', 'val' or 'test' folders not present in dset folder '../../data/%s'!\n\nERROR: Dataset must contain 'train', 'val' and 'test' folders.\n\nQuitting."
                    % dset)
                return
            else:
                print("Making %s dataset" % dset)
                subprocess.call([
                    'python3', '../data/make_dataset.py',
                    '../../data/%s' % dset, '3'
                ])
                print("Done!")

    epoch_size = n_batch_per_epoch * batch_size

    init_epoch = 0
    if prev_model:
        print('\n\nLoading prev_model from', prev_model, '...\n\n')
        # Latest checkpoints; lexicographic sort relies on zero-padded epoch numbers
        prev_model_latest_gen = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*gen*.h5')))[-1]
        prev_model_latest_disc = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*disc*.h5')))[-1]
        prev_model_latest_DCGAN = sorted(
            glob.glob(os.path.join('../../models/', prev_model,
                                   '*DCGAN*.h5')))[-1]
        # Find prev model name, epoch
        model_name = prev_model_latest_DCGAN.split('models')[-1].split('/')[1]
        init_epoch = int(prev_model_latest_DCGAN.split('epoch')[1][:5]) + 1

    # Setup environment (logging directory etc), if no prev_model is mentioned
    general_utils.setup_logging(model_name)

    # img_dim = X_full_train.shape[-3:]
    img_dim = (256, 256, 3)

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size,
                                                     image_data_format)

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # NOTE(review): if discriminator_optimizer is neither 'sgd' nor 'adam',
        # opt_discriminator is never defined and the compile below raises NameError.
        if discriminator_optimizer == 'sgd':
            opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        elif discriminator_optimizer == 'adam':
            opt_discriminator = Adam(lr=1E-3,
                                     beta_1=0.9,
                                     beta_2=0.999,
                                     epsilon=1e-08)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator_type,
                                      img_dim, nb_patch, use_mbd, batch_size,
                                      model_name)
        # NOTE(review): the generator is compiled with opt_discriminator here,
        # whereas the sibling trainer uses opt_dcgan — confirm this is intended.
        generator_model.compile(loss='mae', optimizer=opt_discriminator)

        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator", img_dim_disc,
                                          nb_patch, use_mbd, batch_size,
                                          model_name)

        # Freeze the discriminator while compiling the combined model so that
        # only the generator trains through DCGAN_model
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        if use_l1_weighted_loss:
            loss = [l1_weighted_loss, 'binary_crossentropy']
        else:
            loss = [l1_loss, 'binary_crossentropy']

        # L1 reconstruction term weighted 10x vs the adversarial term
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        # Load prev_model
        if prev_model:
            generator_model.load_weights(prev_model_latest_gen)
            discriminator_model.load_weights(prev_model_latest_disc)
            DCGAN_model.load_weights(prev_model_latest_DCGAN)

        # Load and rescale data
        print('\n\nLoading data...\n\n')
        X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
            dset, image_data_format)
        check_this_process_memory()
        print('X_full_train: %.4f' % (X_full_train.nbytes / 2**30), "GB")
        print('X_sketch_train: %.4f' % (X_sketch_train.nbytes / 2**30), "GB")
        print('X_full_val: %.4f' % (X_full_val.nbytes / 2**30), "GB")
        print('X_sketch_val: %.4f' % (X_sketch_val.nbytes / 2**30), "GB")

        # Losses
        disc_losses = []
        gen_total_losses = []
        gen_L1_losses = []
        gen_log_losses = []

        # Start training
        print("\n\nStarting training\n\n")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            gen_total_loss_epoch = 0
            gen_L1_loss_epoch = 0
            gen_log_loss_epoch = 0
            start = time.time()
            for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                    X_full_train, X_sketch_train, batch_size):
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_full_batch,
                    X_sketch_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    image_data_format,
                    label_smoothing=label_smoothing,
                    label_flipping_prob=label_flipping_prob)
                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(
                    data_utils.gen_batch(X_full_train, X_sketch_train,
                                         batch_size))
                # One-hot "real" label for the adversarial head
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1
                # Freeze the discriminator
                discriminator_model.trainable = False
                # Train generator (n_run_of_gen_for_1_run_of_disc updates per
                # disc update; each contribution pre-divided so sums are averages)
                for _ in range(n_run_of_gen_for_1_run_of_disc - 1):
                    gen_loss = DCGAN_model.train_on_batch(
                        X_gen, [X_gen_target, y_gen])
                    gen_total_loss_epoch += gen_loss[
                        0] / n_run_of_gen_for_1_run_of_disc
                    gen_L1_loss_epoch += gen_loss[
                        1] / n_run_of_gen_for_1_run_of_disc
                    gen_log_loss_epoch += gen_loss[
                        2] / n_run_of_gen_for_1_run_of_disc
                    X_gen_target, X_gen = next(
                        data_utils.gen_batch(X_full_train, X_sketch_train,
                                             batch_size))
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Add losses
                gen_total_loss_epoch += gen_loss[
                    0] / n_run_of_gen_for_1_run_of_disc
                gen_L1_loss_epoch += gen_loss[
                    1] / n_run_of_gen_for_1_run_of_disc
                gen_log_loss_epoch += gen_loss[
                    2] / n_run_of_gen_for_1_run_of_disc
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                # Progress
                # progbar.add(batch_size, values=[("D logloss", disc_loss),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G L1", gen_loss[1]),
                #                                 ("G logloss", gen_loss[2])])
                print("Epoch", str(init_epoch + e + 1), "batch",
                      str(batch_counter + 1), "D_logloss", disc_loss, "G_tot",
                      gen_loss[0], "G_L1", gen_loss[1], "G_log", gen_loss[2])
                batch_counter += 1
                if batch_counter >= n_batch_per_epoch:
                    break
            gen_total_loss = gen_total_loss_epoch / n_batch_per_epoch
            gen_L1_loss = gen_L1_loss_epoch / n_batch_per_epoch
            gen_log_loss = gen_log_loss_epoch / n_batch_per_epoch
            # disc_loss here is the last batch's loss, not an epoch average
            disc_losses.append(disc_loss)
            gen_total_losses.append(gen_total_loss)
            gen_L1_losses.append(gen_L1_loss)
            gen_log_losses.append(gen_log_loss)
            check_this_process_memory()
            print('Epoch %s/%s, Time: %.4f' %
                  (init_epoch + e + 1, init_epoch + nb_epoch,
                   time.time() - start))
            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_full_batch, X_sketch_batch,
                                                generator_model, batch_size,
                                                image_data_format, model_name,
                                                "training", init_epoch + e + 1,
                                                MAX_FRAMES_PER_GIF)
                # Get new images from validation
                X_full_batch, X_sketch_batch = next(
                    data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                data_utils.plot_generated_batch(X_full_batch, X_sketch_batch,
                                                generator_model, batch_size,
                                                image_data_format, model_name,
                                                "validation",
                                                init_epoch + e + 1,
                                                MAX_FRAMES_PER_GIF)
                # Plot losses
                data_utils.plot_losses(disc_losses, gen_total_losses,
                                       gen_L1_losses, gen_log_losses,
                                       model_name, init_epoch)
            # Save weights
            # NOTE(review): checkpoint names use init_epoch + e while logs use
            # init_epoch + e + 1 — off-by-one between the two.
            if (e + 1) % save_weights_every_n_epochs == 0:
                gen_weights_path = os.path.join(
                    '../../models/%s/gen_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                generator_model.save_weights(gen_weights_path, overwrite=True)
                disc_weights_path = os.path.join(
                    '../../models/%s/disc_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)
                DCGAN_weights_path = os.path.join(
                    '../../models/%s/DCGAN_weights_epoch%05d_discLoss%.04f_genTotL%.04f_genL1L%.04f_genLogL%.04f.h5'
                    % (model_name, init_epoch + e, disc_losses[-1],
                       gen_total_losses[-1], gen_L1_losses[-1],
                       gen_log_losses[-1]))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
    except KeyboardInterrupt:
        # Allow manual interruption without a traceback
        pass
def train(cat_dim, noise_dim, batch_size, n_batch_per_epoch, nb_epoch, dset="mnist"):
    """Train the InfoGAN-style model (categorical code + noise).

    Loads the whole train data in memory for faster operations.

    args:
        cat_dim: dimension of the categorical latent code
        noise_dim: dimension of the noise vector
        batch_size: samples per batch
        n_batch_per_epoch: number of batches per epoch
        nb_epoch: number of epochs
        dset: dataset name; only "mnist" is handled here
    """
    general_utils.setup_logging("IG")

    # Load and rescale data
    # NOTE(review): if dset != "mnist" nothing is loaded and the
    # X_real_train access below raises NameError.
    if dset == "mnist":
        print("loading mnist data")
        X_real_train, Y_real_train, X_real_test, Y_real_test = data_utils.load_mnist(
        )
        # pick 1000 sample for testing
        # X_real_test = X_real_test[-1000:]
        # Y_real_test = Y_real_test[-1000:]
    img_dim = X_real_train.shape[-3:]
    epoch_size = n_batch_per_epoch * batch_size

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        opt_discriminator = Adam(lr=2E-4,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-4, momentum=0.9, nesterov=True)

        # Load generator model
        generator_model = models.load("generator_deconv",
                                      cat_dim,
                                      noise_dim,
                                      img_dim,
                                      batch_size,
                                      dset=dset)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          cat_dim,
                                          noise_dim,
                                          img_dim,
                                          batch_size,
                                          dset=dset)

        generator_model.compile(loss='mse', optimizer=opt_discriminator)
        # stop the discriminator to learn while in generator is learning
        discriminator_model.trainable = False
        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   cat_dim, noise_dim)
        # Adversarial head + categorical (mutual-information) head
        list_losses = ['binary_crossentropy', 'categorical_crossentropy']
        list_weights = [1, 1]
        DCGAN_model.compile(loss=list_losses,
                            loss_weights=list_weights,
                            optimizer=opt_dcgan)
        # Multiple discriminator losses
        # allow the discriminator to learn again
        discriminator_model.trainable = True
        discriminator_model.compile(loss=list_losses,
                                    loss_weights=list_weights,
                                    optimizer=opt_discriminator)

        # Start training
        print("Start training")
        for e in range(nb_epoch + 1):
            # Initialize progbar and batch counter
            # progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()
            print("Epoch: {}".format(e))
            for X_real_batch, Y_real_batch in zip(
                    data_utils.gen_batch(X_real_train, batch_size),
                    data_utils.gen_batch(Y_real_train, batch_size)):
                # Create a batch to feed the discriminator model
                X_disc_fake, y_disc_fake, noise_sample = data_utils.get_disc_batch(
                    X_real_batch,
                    Y_real_batch,
                    generator_model,
                    batch_size,
                    cat_dim,
                    noise_dim,
                    type="fake")
                X_disc_real, y_disc_real = data_utils.get_disc_batch(
                    X_real_batch,
                    Y_real_batch,
                    generator_model,
                    batch_size,
                    cat_dim,
                    noise_dim,
                    type="real")
                # Update the discriminator
                disc_loss_fake = discriminator_model.train_on_batch(
                    X_disc_fake, [y_disc_fake, Y_real_batch])
                disc_loss_real = discriminator_model.train_on_batch(
                    X_disc_real, [y_disc_real, Y_real_batch])
                # NOTE(review): with multiple outputs train_on_batch returns a list,
                # so `+` concatenates the two loss lists rather than summing values.
                disc_loss = disc_loss_fake + disc_loss_real
                # Create a batch to feed the generator model
                # X_noise, y_gen = data_utils.get_gen_batch(batch_size, cat_dim, noise_dim)
                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(
                    [Y_real_batch, noise_sample], [y_disc_real, Y_real_batch])
                # Unfreeze the discriminator
                discriminator_model.trainable = True
                # training validation
                p_real_batch, p_Y_batch = discriminator_model.predict(
                    X_real_batch, batch_size=batch_size)
                acc_train = data_utils.accuracy(p_Y_batch, Y_real_batch)
                batch_counter += 1
                # progbar.add(batch_size, values=[("D tot", disc_loss[0]),
                #                                 ("D cat", disc_loss[2]),
                #                                 ("G tot", gen_loss[0]),
                #                                 ("G cat", gen_loss[2]),
                #                                 ("P Real:", p_real_batch),
                #                                 ("Q acc", acc_train)])
                # Save images for visualization
                # NOTE(review): `n_batch_per_epoch / 2` is a float under true
                # division, so this condition may never fire for odd values — confirm.
                if batch_counter % (n_batch_per_epoch / 2) == 0 and e % 10 == 0:
                    data_utils.plot_generated_batch(X_real_batch,
                                                    generator_model,
                                                    batch_size, cat_dim,
                                                    noise_dim, e)
                if batch_counter >= n_batch_per_epoch:
                    break
            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch,
                                             time.time() - start))
            # Evaluate categorical-code accuracy on the held-out test set
            _, p_Y_test = discriminator_model.predict(
                X_real_test, batch_size=X_real_test.shape[0])
            acc_test = data_utils.accuracy(p_Y_test, Y_real_test)
            print("Epoch: {} Accuracy: {}".format(e + 1, acc_test))
            # Periodic checkpoint (fixed filenames, overwritten each time)
            if e % 1000 == 0:
                gen_weights_path = os.path.join(
                    '../../models/IG/gen_weights.h5')
                generator_model.save_weights(gen_weights_path, overwrite=True)
                disc_weights_path = os.path.join(
                    '../../models/IG/disc_weights.h5')
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)
                DCGAN_weights_path = os.path.join(
                    '../../models/IG/DCGAN_weights.h5')
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
    except KeyboardInterrupt:
        # Allow manual interruption without a traceback
        pass
def eval(**kwargs):
    """Visualize what the InfoGAN-style generator learned.

    Loads generator weights from a given epoch, then saves two figures:
    one sweeping the categorical latent code (rows) and one sweeping the
    two continuous latent factors (grid).

    args:
        **kwargs (dict) keyword arguments that specify the model hyperparameters
            (batch_size, generator, model_name, image_data_format, img_dim,
             cont_dim, cat_dim, noise_dim, bn_mode, noise_scale, dset, eval_epoch)
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    generator = kwargs["generator"]
    model_name = kwargs["model_name"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    # Latent dimensions are wrapped into 1-tuples as expected by models.load
    cont_dim = (kwargs["cont_dim"],)
    cat_dim = (kwargs["cat_dim"],)
    noise_dim = (kwargs["noise_dim"],)
    bn_mode = kwargs["bn_mode"]
    noise_scale = kwargs["noise_scale"]
    dset = kwargs["dset"]
    eval_epoch = kwargs["eval_epoch"]

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and rescale data
    # NOTE(review): if dset is neither "RGZ" nor "mnist", X_real_train is
    # unbound and the shape access below raises NameError.
    if dset == "RGZ":
        X_real_train = data_utils.load_RGZ(img_dim, image_data_format)
    if dset == "mnist":
        X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
    img_dim = X_real_train.shape[-3:]

    # Load generator model
    generator_model = models.load("generator_%s" % generator,
                                  cat_dim,
                                  cont_dim,
                                  noise_dim,
                                  img_dim,
                                  bn_mode,
                                  batch_size,
                                  dset=dset)

    # Load trained weights for the requested epoch
    generator_model.load_weights("../../models/%s/gen_weights_epoch%05d.h5" %
                                 (model_name, eval_epoch))

    X_plot = []
    # Vary the categorical variable: one row of images per category
    for i in range(cat_dim[0]):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        X_cont = data_utils.sample_noise(noise_scale, batch_size, cont_dim)
        X_cont = np.repeat(X_cont[:1, :], batch_size, axis=0)  # fix continuous noise
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, i] = 1  # always the same categorical value
        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)
        if image_data_format == "channels_first":
            # Move channels last for matplotlib
            X_gen = X_gen.transpose(0,2,3,1)
        # Concatenate the batch into one wide row image
        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    # Stack the rows vertically into a single image
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(8,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying categorical factor", fontsize=28, labelpad=60)
    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig(os.path.join("../../figures", model_name, "varying_categorical.png"))
    plt.clf()
    plt.close()

    # Vary the continuous variables: batch_size x batch_size grid over the two factors
    X_plot = []
    # First get the extent of the noise sampling (empirical min/max of a large sample)
    x = np.ravel(data_utils.sample_noise(noise_scale, batch_size * 20000, cont_dim))
    # Define interpolation points
    x = np.linspace(x.min(), x.max(), num=batch_size)
    for i in range(batch_size):
        X_noise = data_utils.sample_noise(noise_scale, batch_size, noise_dim)
        # Row i: factor 1 fixed at x[i], factor 2 sweeps x[j] across the batch
        X_cont = np.concatenate([np.array([x[i], x[j]]).reshape(1, -1) for j in range(batch_size)], axis=0)
        X_cat = np.zeros((batch_size, cat_dim[0]), dtype='float32')
        X_cat[:, 1] = 1  # always the same categorical value
        X_gen = generator_model.predict([X_cat, X_cont, X_noise])
        X_gen = data_utils.inverse_normalization(X_gen)
        if image_data_format == "channels_first":
            X_gen = X_gen.transpose(0,2,3,1)
        X_gen = [X_gen[i] for i in range(len(X_gen))]
        X_plot.append(np.concatenate(X_gen, axis=1))
    X_plot = np.concatenate(X_plot, axis=0)

    plt.figure(figsize=(10,10))
    if X_plot.shape[-1] == 1:
        plt.imshow(X_plot[:, :, 0], cmap="gray")
    else:
        plt.imshow(X_plot)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel("Varying continuous factor 1", fontsize=28, labelpad=60)
    plt.annotate('', xy=(-0.05, 0), xycoords='axes fraction', xytext=(-0.05, 1),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.xlabel("Varying continuous factor 2", fontsize=28, labelpad=60)
    plt.annotate('', xy=(1, -0.05), xycoords='axes fraction', xytext=(0, -0.05),
                 arrowprops=dict(arrowstyle="-|>", color='k', linewidth=4))
    plt.tight_layout()
    plt.savefig(os.path.join("../../figures", model_name, "varying_continuous.png"))
    plt.clf()
    plt.close()
def train(**kwargs):
    """Train a standard DCGAN with the Wasserstein objective (weight clipping).

    Alternates between several critic updates (with weights clipped to
    [clamp_lower, clamp_upper]) and one generator update, following the
    WGAN training schedule (100 critic iterations for the first 25 generator
    updates and every 500th one thereafter).

    args:
        **kwargs (dict): hyperparameters — generator, dset, img_dim, nb_epoch,
            batch_size, n_batch_per_epoch, bn_mode, noise_dim, noise_scale,
            lr_D, lr_G, opt_D, opt_G, clamp_lower, clamp_upper,
            image_dim_ordering, disc_iterations.

    Side effects: writes log directories, visualization images and model
    weights via general_utils / data_utils helpers.
    """
    # Roll out the parameters
    generator = kwargs["generator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    bn_mode = kwargs["bn_mode"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        # Fixed: was a Python 2 print statement (`print key, kwargs[key]`),
        # a SyntaxError under Python 3 and inconsistent with the print()
        # calls used elsewhere in this function.
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging("DCGAN")

    # Load and normalize data
    X_real_train = data_utils.load_image_dataset(dset, img_dim, image_dim_ordering)

    # Get the full real image dimension
    img_dim = X_real_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim,)
    if generator == "upsampling":
        generator_model = models.generator_upsampling(noise_dim, img_dim, bn_mode, dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim, img_dim, bn_mode, batch_size, dset=dset)
    discriminator_model = models.discriminator(img_dim, bn_mode)
    DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim, img_dim)

    ############################
    # Compile models
    ############################
    # Generator loss is a placeholder (mse); it is trained through DCGAN_model.
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    #################
    for e in range(nb_epoch):
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch:

            # WGAN schedule: train the critic heavily early on and periodically
            if gen_iterations < 25 or gen_iterations % 500 == 0:
                disc_iterations = 100
            else:
                disc_iterations = kwargs["disc_iterations"]

            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):

                # Clip discriminator weights (WGAN Lipschitz constraint)
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                    l.set_weights(weights)

                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    noise_scale=noise_scale)

                # Update the discriminator (labels -1/+1 for the Wasserstein loss)
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            #######################
            # 2) Train the generator
            #######################
            X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

            # Freeze the discriminator while the generator trains through DCGAN
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            gen_iterations += 1
            batch_counter += 1
            progbar.add(batch_size,
                        values=[("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)])

            # Save images for visualization ~2 times per epoch.
            # Fixed: use integer division so the modulus stays an int under
            # Python 3 (true division made it a float), and guard against
            # n_batch_per_epoch < 2 (modulo by zero).
            if n_batch_per_epoch >= 2 and batch_counter % (n_batch_per_epoch // 2) == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                batch_size, noise_dim, image_dim_ordering)

        print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

        # Save model weights (by default, every 5 epochs)
        data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e)
def train(model_name, **kwargs):
    """
    Train model
    Load the whole train data in memory for faster operations

    args: model_name (str, keras model name)
          **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    nb_classes = kwargs["nb_classes"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    prob = kwargs["prob"]
    do_plot = kwargs["do_plot"]
    data_file = kwargs["data_file"]
    semi_super_file = kwargs["semi_super_file"]
    pretr_weights_file = kwargs["pretr_weights_file"]
    normalisation_style = kwargs["normalisation_style"]
    objective = kwargs["objective"]
    experiment = kwargs["experiment"]
    list_folds = kwargs["list_folds"]

    # Setup environment (logging directory etc)
    general_utils.setup_logging(experiment)

    # Compile model.
    # opt = RMSprop(lr=5E-6, rho=0.9, epsilon=1e-06)
    opt = SGD(lr=5e-4, decay=1e-6, momentum=0.9, nesterov=True)
    # opt = Adam(lr=1E-5, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Batch generator with on-the-fly data augmentation; non-commented
    # transforms below are the active augmentation pipeline
    DataAug = batch_utils.AugDataGenerator(data_file,
                                           batch_size=batch_size,
                                           prob=prob,
                                           dset="train",
                                           maxproc=4,
                                           num_cached=60,
                                           random_augm=False,
                                           hdf5_file_semi=semi_super_file)
    DataAug.add_transform("h_flip")
    # DataAug.add_transform("v_flip")
    # DataAug.add_transform("fixed_rot", angle=40)
    DataAug.add_transform("random_rot", angle=40)
    # DataAug.add_transform("fixed_tr", tr_x=40, tr_y=40)
    DataAug.add_transform("random_tr", tr_x=40, tr_y=40)
    # DataAug.add_transform("fixed_blur", kernel_size=5)
    DataAug.add_transform("random_blur", kernel_size=5)
    # DataAug.add_transform("fixed_erode", kernel_size=4)
    DataAug.add_transform("random_erode", kernel_size=3)
    # DataAug.add_transform("fixed_dilate", kernel_size=4)
    DataAug.add_transform("random_dilate", kernel_size=3)
    # DataAug.add_transform("fixed_crop", pos_x=10, pos_y=10, crop_size_x=200, crop_size_y=200)
    DataAug.add_transform("random_crop", min_crop_size=140, max_crop_size=160)
    # DataAug.add_transform("hist_equal")
    # DataAug.add_transform("random_occlusion", occ_size_x=100, occ_size_y=100)

    epoch_size = n_batch_per_epoch * batch_size

    general_utils.pretty_print("Load all data...")

    with h5py.File(data_file, "r") as hf:
        # Whole training set is loaded in memory (see docstring)
        X = hf["train_data"][:, :, :, :]
        y = hf["train_label"][:].astype(np.uint8)
        y = np_utils.to_categorical(y, nb_classes=nb_classes)  # Format for keras

        # KeyboardInterrupt aborts training cleanly (weights of completed
        # epochs have already been saved below)
        try:
            # Cross-validation: one independent model per fold
            for fold in list_folds:
                min_valid_loss = 100

                # Save losses
                list_train_loss = []
                list_valid_loss = []

                # Load valid data in memory for fast error evaluation
                idx_valid = hf["valid_fold%s" % fold][:]
                idx_train = hf["train_fold%s" % fold][:]
                X_valid = X[idx_valid]
                y_valid = y[idx_valid]

                # Normalise
                # NOTE(review): `normalisation` is defined elsewhere in this
                # module/project — not visible in this chunk
                X_valid = normalisation(X_valid, normalisation_style)

                # Compile model
                general_utils.pretty_print("Compiling...")
                model = models.load(model_name,
                                    nb_classes,
                                    X_valid.shape[-3:],
                                    pretr_weights_file=pretr_weights_file)
                model.compile(optimizer=opt, loss=objective)

                # Save architecture
                # NOTE(review): `data_dir` is a module-level name defined
                # elsewhere — confirm it exists before running
                json_string = model.to_json()
                with open(os.path.join(data_dir, '%s_archi.json' % model.name), 'w') as f:
                    f.write(json_string)

                for e in range(nb_epoch):
                    # Initialize progbar and batch counter
                    progbar = generic_utils.Progbar(epoch_size)
                    batch_counter = 1
                    l_train_loss = []
                    start = time.time()

                    for X_train, y_train in DataAug.gen_batch_inmemory(X, y, idx_train=idx_train):
                        if do_plot:
                            general_utils.plot_batch(X_train, np.argmax(y_train, 1), batch_size)

                        # Normalise
                        X_train = normalisation(X_train, normalisation_style)

                        train_loss = model.train_on_batch(X_train, y_train)
                        l_train_loss.append(train_loss)
                        batch_counter += 1
                        progbar.add(batch_size, values=[("train loss", train_loss)])
                        if batch_counter >= n_batch_per_epoch:
                            break

                    print("")
                    print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

                    # Evaluate on the held-out fold
                    y_valid_pred = model.predict(X_valid, verbose=0, batch_size=16)
                    train_loss = float(np.mean(l_train_loss))  # use float to make it json saveable
                    # NOTE(review): `log_loss` is imported elsewhere
                    # (presumably sklearn.metrics) — not visible in this chunk
                    valid_loss = log_loss(y_valid, y_valid_pred)
                    print("Train loss:", train_loss, "valid loss:", valid_loss)

                    list_train_loss.append(train_loss)
                    list_valid_loss.append(valid_loss)

                    # Record experimental data in a dict (re-written each epoch)
                    d_log = {}
                    d_log["fold"] = fold
                    d_log["nb_classes"] = nb_classes
                    d_log["batch_size"] = batch_size
                    d_log["n_batch_per_epoch"] = n_batch_per_epoch
                    d_log["nb_epoch"] = nb_epoch
                    d_log["epoch_size"] = epoch_size
                    d_log["prob"] = prob
                    d_log["optimizer"] = opt.get_config()
                    d_log["augmentator_config"] = DataAug.get_config()
                    d_log["train_loss"] = list_train_loss
                    d_log["valid_loss"] = list_valid_loss

                    # NOTE(review): `exp_dir` is a module-level name defined
                    # elsewhere — confirm it exists before running
                    json_file = os.path.join(exp_dir, 'experiment_log_fold%s.json' % fold)
                    general_utils.save_exp_log(json_file, d_log)

                    # Only save the best epoch
                    if valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
                        trained_weights_path = os.path.join(exp_dir, '%s_weights_fold%s.h5' % (model.name, fold))
                        model.save_weights(trained_weights_path, overwrite=True)
        except KeyboardInterrupt:
            pass
def train(**kwargs):
    """
    Train a pix2pix-style conditional GAN (U-Net generator + patch discriminator).
    Load the whole train data in memory for faster operations.

    Per batch: one discriminator update on a real/generated patch batch, then
    one generator update through the combined DCGAN model with a weighted
    L1 + adversarial loss (weights 10:1).

    args:
        **kwargs (dict): hyperparameters — batch_size, n_batch_per_epoch,
            nb_epoch, model_name, generator, image_data_format, img_dim,
            patch_size, bn_mode, use_label_smoothing, label_flipping, dset,
            use_mbd, pretrained_model_path.

    Side effects: writes visualization images and (every 5 epochs) generator /
    discriminator / combined weights under ../../models/<model_name>/.
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    model_name = kwargs["model_name"]
    generator = kwargs["generator"]
    image_data_format = kwargs["image_data_format"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    label_smoothing = kwargs["use_label_smoothing"]
    label_flipping = kwargs["label_flipping"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    pretrained_model_path = kwargs["pretrained_model_path"]

    epoch_size = n_batch_per_epoch * batch_size

    # Setup environment (logging directory etc)
    general_utils.setup_logging(model_name)

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(dset, image_data_format)
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size, image_data_format)

    # KeyboardInterrupt aborts training cleanly (completed epochs' weights
    # have already been written)
    try:

        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        load_pretrained = bool(pretrained_model_path)

        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim,
                                      nb_patch,
                                      bn_mode,
                                      use_mbd,
                                      batch_size,
                                      load_pretrained)
        # Load discriminator model
        discriminator_model = models.load("DCGAN_discriminator",
                                          img_dim_disc,
                                          nb_patch,
                                          bn_mode,
                                          use_mbd,
                                          batch_size,
                                          load_pretrained)

        # Generator loss is a placeholder (mae); it trains through DCGAN_model
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model,
                                   discriminator_model,
                                   img_dim,
                                   patch_size,
                                   image_data_format)

        # Combined loss: L1 reconstruction (weight 10) + adversarial logloss (weight 1)
        loss = [l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_full_batch, X_sketch_batch in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_full_batch,
                                                           X_sketch_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           patch_size,
                                                           image_data_format,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
                # Generator target: "real" one-hot labels to fool the discriminator
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                y_gen[:, 1] = 1

                # Freeze the discriminator while the generator trains
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G tot", gen_loss[0]),
                                                ("G L1", gen_loss[1]),
                                                ("G logloss", gen_loss[2])])

                # Save images for visualization ~2 times per epoch.
                # Fixed: use integer division so the modulus stays an int under
                # Python 3 (true division made it a float, which broke the
                # twice-per-epoch condition for odd n_batch_per_epoch), and
                # guard against n_batch_per_epoch < 2 (modulo by zero).
                if n_batch_per_epoch >= 2 and batch_counter % (n_batch_per_epoch // 2) == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_data_format, "training")
                    X_full_batch, X_sketch_batch = next(data_utils.gen_batch(X_full_val, X_sketch_val, batch_size))
                    data_utils.plot_generated_batch(X_full_batch, X_sketch_batch, generator_model,
                                                    batch_size, image_data_format, "validation")

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

            # Checkpoint all three models every 5 epochs
            if e % 5 == 0:
                gen_weights_path = os.path.join('../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join('../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
                discriminator_model.save_weights(disc_weights_path, overwrite=True)

                DCGAN_weights_path = os.path.join('../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
def train(**kwargs):
    """
    Train standard DCGAN model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters

    Wasserstein-GAN training loop with weight clipping: several critic updates
    (100 for the first 25 generator iterations and every 500th thereafter,
    otherwise kwargs["disc_iterations"]) then one generator update per batch.
    Per-epoch mean losses are accumulated for plotting; weights and figures
    are saved every save_weights_every_n_epochs /
    visualize_images_every_n_epochs epochs, and the generator is always saved
    as generator_latest.h5 on exit (including KeyboardInterrupt).
    """
    # Roll out the parameters
    generator = kwargs["generator"]
    dset = kwargs["dset"]
    img_dim = kwargs["img_dim"]
    nb_epoch = kwargs["nb_epoch"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    noise_dim = kwargs["noise_dim"]
    noise_scale = kwargs["noise_scale"]
    lr_D = kwargs["lr_D"]
    lr_G = kwargs["lr_G"]
    opt_D = kwargs["opt_D"]
    opt_G = kwargs["opt_G"]
    clamp_lower = kwargs["clamp_lower"]
    clamp_upper = kwargs["clamp_upper"]
    image_data_format = kwargs["image_data_format"]
    save_weights_every_n_epochs = kwargs["save_weights_every_n_epochs"]
    save_only_last_n_weights = kwargs["save_only_last_n_weights"]
    visualize_images_every_n_epochs = kwargs["visualize_images_every_n_epochs"]
    model_name = kwargs["model_name"]
    epoch_size = n_batch_per_epoch * batch_size

    print("\nExperiment parameters:")
    for key in kwargs.keys():
        print(key, kwargs[key])
    print("\n")

    # Setup environment (logging directory etc)
    general_utils.setup_logging(**kwargs)

    # Load and normalize data
    X_real_train, X_batch_gen = data_utils.load_image_dataset(
        dset, img_dim, image_data_format, batch_size)

    # Get the full real image dimension
    img_dim = X_real_train.shape[-3:]

    # Create optimizers
    opt_G = data_utils.get_optimizer(opt_G, lr_G)
    opt_D = data_utils.get_optimizer(opt_D, lr_D)

    #######################
    # Load models
    #######################
    noise_dim = (noise_dim, )
    if generator == "upsampling":
        generator_model = models.generator_upsampling(noise_dim, img_dim, dset=dset)
    else:
        generator_model = models.generator_deconv(noise_dim, img_dim, batch_size, dset=dset)
    discriminator_model = models.discriminator(img_dim)
    DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim, img_dim)

    ############################
    # Compile models
    ############################
    # Generator loss is a placeholder (mse); it is trained through DCGAN_model
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

    # Global iteration counter for generator updates
    gen_iterations = 0

    # Per-epoch mean losses, kept across the whole run for plot_losses
    disc_losses = []
    disc_losses_real = []
    disc_losses_gen = []
    gen_losses = []

    #################
    # Start training
    ################
    # KeyboardInterrupt stops training; the latest generator is still saved below
    try:
        for e in range(nb_epoch):
            print('--------------------------------------------')
            print('[{0:%Y/%m/%d %H:%M:%S}] Epoch {1:d}/{2:d}\n'.format(
                datetime.datetime.now(), e + 1, nb_epoch))

            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 0
            start = time.time()
            # Accumulators for this epoch's mean losses
            disc_loss_batch = 0
            disc_loss_real_batch = 0
            disc_loss_gen_batch = 0
            gen_loss_batch = 0

            for batch_counter in range(n_batch_per_epoch):

                # WGAN schedule: train the critic heavily early on and periodically
                if gen_iterations < 25 or gen_iterations % 500 == 0:
                    disc_iterations = 100
                else:
                    disc_iterations = kwargs["disc_iterations"]

                ###################################
                # 1) Train the critic / discriminator
                ###################################
                list_disc_loss_real = []
                list_disc_loss_gen = []
                for disc_it in range(disc_iterations):

                    # Clip discriminator weights (WGAN Lipschitz constraint)
                    for l in discriminator_model.layers:
                        weights = l.get_weights()
                        weights = [
                            np.clip(w, clamp_lower, clamp_upper) for w in weights
                        ]
                        l.set_weights(weights)

                    X_real_batch = next(
                        data_utils.gen_batch(X_real_train, X_batch_gen, batch_size))

                    # Create a batch to feed the discriminator model
                    X_disc_real, X_disc_gen = data_utils.get_disc_batch(
                        X_real_batch,
                        generator_model,
                        batch_counter,
                        batch_size,
                        noise_dim,
                        noise_scale=noise_scale)

                    # Update the discriminator (labels -1/+1 for the Wasserstein loss)
                    disc_loss_real = discriminator_model.train_on_batch(
                        X_disc_real, -np.ones(X_disc_real.shape[0]))
                    disc_loss_gen = discriminator_model.train_on_batch(
                        X_disc_gen, np.ones(X_disc_gen.shape[0]))
                    list_disc_loss_real.append(disc_loss_real)
                    list_disc_loss_gen.append(disc_loss_gen)

                #######################
                # 2) Train the generator
                #######################
                X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

                # Freeze the discriminator while the generator trains through DCGAN
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                gen_iterations += 1

                disc_loss_batch += -np.mean(list_disc_loss_real) - np.mean(
                    list_disc_loss_gen)
                disc_loss_real_batch += -np.mean(list_disc_loss_real)
                disc_loss_gen_batch += np.mean(list_disc_loss_gen)
                gen_loss_batch += -gen_loss

                progbar.add(batch_size,
                            values=[
                                ("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)),
                                ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                ("Loss_G", -gen_loss)
                            ])

                # # Save images for visualization ~2 times per epoch
                # if batch_counter % (n_batch_per_epoch / 2) == 0:
                #     data_utils.plot_generated_batch(X_real_batch, generator_model,
                #                                     batch_size, noise_dim, image_data_format)

            # Epoch-mean losses
            disc_losses.append(disc_loss_batch / n_batch_per_epoch)
            disc_losses_real.append(disc_loss_real_batch / n_batch_per_epoch)
            disc_losses_gen.append(disc_loss_gen_batch / n_batch_per_epoch)
            gen_losses.append(gen_loss_batch / n_batch_per_epoch)

            # Save images for visualization
            if (e + 1) % visualize_images_every_n_epochs == 0:
                data_utils.plot_generated_batch(X_real_batch, generator_model, e,
                                                batch_size, noise_dim,
                                                image_data_format, model_name)
                data_utils.plot_losses(disc_losses, disc_losses_real,
                                       disc_losses_gen, gen_losses, model_name)

            # Save model weights (by default, every 5 epochs)
            data_utils.save_model_weights(generator_model, discriminator_model,
                                          DCGAN_model, e,
                                          save_weights_every_n_epochs,
                                          save_only_last_n_weights, model_name)

            end = time.time()
            print('\nEpoch %s/%s END, Time: %s' % (e + 1, nb_epoch, end - start))
            start = end
    except KeyboardInterrupt:
        pass

    # Always persist the latest generator, even after an interrupt
    gen_weights_path = '../../models/%s/generator_latest.h5' % (model_name)
    print("Saving", gen_weights_path)
    generator_model.save(gen_weights_path, overwrite=True)