def test_loss_on_layer(self):
    class MyLayer(layers.Layer):
        def call(self, inputs):
            self.add_loss(math_ops.reduce_sum(inputs))
            return inputs

    inputs = Input((3,))
    layer = MyLayer()
    outputs = layer(inputs)
    model = Model(inputs, outputs)
    self.assertEqual(len(model.losses), 1)
    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
    loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
    self.assertEqual(loss, 2 * 3)
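
# The same behavior outside the test harness; a minimal standalone sketch using
# plain tf.keras (SumLossLayer is a name introduced here, and tf.reduce_sum
# stands in for the internal math_ops helper):
import numpy as np
import tensorflow as tf

class SumLossLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        # contributes sum(inputs) as an extra loss term on top of the compiled loss
        self.add_loss(tf.reduce_sum(inputs))
        return inputs

demo_inputs = tf.keras.Input((3,))
demo_model = tf.keras.Model(demo_inputs, SumLossLayer()(demo_inputs))
demo_model.compile('sgd', 'mse')
# MSE(ones, ones) == 0, so the reported loss is sum(ones((2, 3))) == 6
demo_loss = demo_model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))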
class BaseKerasModel(BaseModel):
    model = None
    tensorboard = None
    train_names = ['train_loss', 'train_mse', 'train_mae']
    val_names = ['val_loss', 'val_mse', 'val_mae']
    counter = 0
    inputs = None
    hidden_layer = None
    outputs = None

    def __init__(self, use_default_dense=True, activation='relu',
                 kernel_regularizer=tf.keras.regularizers.l1(0.001)):
        super().__init__()
        if use_default_dense:
            self.activation = activation
            self.kernel_regularizer = kernel_regularizer

    def create_input_layer(self, input_placeholder: BaseInputFormatter):
        """Creates the input layer of the Keras model."""
        self.inputs = tf.keras.Input(
            shape=input_placeholder.get_input_state_dimension())
        return self.inputs

    def create_hidden_layers(self, input_layer=None):
        if input_layer is None:
            input_layer = self.inputs
        hidden_layer = tf.keras.layers.Dropout(0.3)(input_layer)
        hidden_layer = tf.keras.layers.Dense(
            128, kernel_regularizer=self.kernel_regularizer,
            activation=self.activation)(hidden_layer)
        hidden_layer = tf.keras.layers.Dropout(0.4)(hidden_layer)
        hidden_layer = tf.keras.layers.Dense(
            64, kernel_regularizer=self.kernel_regularizer,
            activation=self.activation)(hidden_layer)
        hidden_layer = tf.keras.layers.Dropout(0.3)(hidden_layer)
        hidden_layer = tf.keras.layers.Dense(
            32, kernel_regularizer=self.kernel_regularizer,
            activation=self.activation)(hidden_layer)
        hidden_layer = tf.keras.layers.Dropout(0.1)(hidden_layer)
        self.hidden_layer = hidden_layer
        return self.hidden_layer

    def create_output_layer(self, output_formatter: BaseOutputFormatter,
                            hidden_layer=None):
        # sigmoid/tanh all you want on self.model
        if hidden_layer is None:
            hidden_layer = self.hidden_layer
        self.outputs = tf.keras.layers.Dense(
            output_formatter.get_model_output_dimension()[0],
            activation='tanh')(hidden_layer)
        self.model = Model(inputs=self.inputs, outputs=self.outputs)
        return self.outputs

    def write_log(self, callback, names, logs, batch_no, eval=False):
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            tag_name = name
            if eval:
                tag_name = 'eval_' + tag_name
            summary_value.tag = tag_name
            callback.writer.add_summary(summary, batch_no)
            callback.writer.flush()

    def finalize_model(self, logname=None):
        if logname is None:
            logname = str(int(random() * 1000))
        loss, loss_weights = self.create_loss()
        self.model.compile(tf.keras.optimizers.Nadam(lr=0.001), loss=loss,
                           loss_weights=loss_weights,
                           metrics=[tf.keras.metrics.mean_absolute_error,
                                    tf.keras.metrics.binary_accuracy])
        log_name = './logs/' + logname
        self.logger.info("log_name: " + log_name)
        self.tensorboard = tf.keras.callbacks.TensorBoard(
            log_dir=log_name,
            histogram_freq=1,
            write_images=False,
            batch_size=1000,
        )
        self.tensorboard.set_model(self.model)
        self.logger.info("Model has been finalized")

    def fit(self, x, y, batch_size=1):
        if self.counter % 200 == 0:
            logs = self.model.evaluate(x, y, batch_size=batch_size, verbose=1)
            self.write_log(self.tensorboard, self.model.metrics_names, logs,
                           self.counter, eval=True)
            print('step:', self.counter)
        else:
            logs = self.model.train_on_batch(x, y)
            self.write_log(self.tensorboard, self.model.metrics_names, logs,
                           self.counter)
        self.counter += 1

    def predict(self, arr):
        return self.model.predict(arr)

    def save(self, file_path):
        self.model.save_weights(filepath=file_path, overwrite=True)

    def load(self, file_path):
        self.model.load_weights(filepath=os.path.abspath(file_path))

    def create_loss(self):
        return 'mean_absolute_error', None
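
# Hypothetical usage sketch for BaseKerasModel (not from the original source).
# SimpleInputFormatter/SimpleOutputFormatter are stand-ins invented here for the
# project's BaseInputFormatter/BaseOutputFormatter; the sketch also assumes the
# BaseModel parent provides self.logger and that a TF1.x-style tf.keras is in
# use, where the TensorBoard callback exposes a .writer attribute.
class SimpleInputFormatter:
    def get_input_state_dimension(self):
        return (10,)

class SimpleOutputFormatter:
    def get_model_output_dimension(self):
        return (3,)

base_model = BaseKerasModel()
base_model.create_input_layer(SimpleInputFormatter())
base_model.create_hidden_layers()
base_model.create_output_layer(SimpleOutputFormatter())
base_model.finalize_model(logname="demo")
base_model.fit(np.ones((4, 10)), np.zeros((4, 3)), batch_size=4)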
class RetroCycleGAN:
    def __init__(self, save_index="0", save_folder="./", generator_size=32,
                 discriminator_size=64, word_vector_dimensions=300,
                 discriminator_lr=0.0001, generator_lr=0.0001,
                 lambda_cycle=1, lambda_id_weight=0.01,
                 one_way_mm=True, cycle_mm=True, cycle_dis=True, id_loss=True,
                 cycle_mm_w=2, cycle_loss=True):
        self.cycle_mm = cycle_mm
        self.cycle_dis = cycle_dis
        self.cycle_mae = cycle_loss
        self.id_loss = id_loss
        self.one_way_mm = one_way_mm
        self.cycle_mm_w = cycle_mm_w if self.cycle_mm else 0
        self.save_folder = save_folder

        # Input shape
        self.word_vector_dimensions = word_vector_dimensions
        self.embeddings_dimensionality = (self.word_vector_dimensions,)
        self.save_index = save_index

        # Number of filters in the first layer of G and D
        self.gf = generator_size
        self.df = discriminator_size

        # Loss weights
        self.lambda_cycle = lambda_cycle if self.cycle_mae else 0  # cycle-consistency loss
        self.lambda_id = lambda_id_weight if self.id_loss else 0   # identity loss

        self.d_lr = discriminator_lr
        self.g_lr = generator_lr
        self.d_A = self.build_discriminator(name="word_vector_discriminator")
        self.d_B = self.build_discriminator(name="retrofitted_word_vector_discriminator")
        self.d_ABBA = self.build_c_discriminator(name="cycle_cond_discriminator_unfit")
        self.d_BAAB = self.build_c_discriminator(name="cycle_cond_discriminator_fit")
        # Best combo so far: SGD, gaussian, dropout, 5, 0.5, mml(0,5,.5),
        # 3x1024 gen, 2x1024, no normalization

        # -------------------------
        # Construct Computational
        # Graph of Generators
        # -------------------------

        # Build the generators
        self.g_AB = self.build_generator(name="to_retro_generator")
        self.g_BA = self.build_generator(name="from_retro_generator")

        # Input word vectors from both domains
        unfit_wv = Input(shape=self.embeddings_dimensionality, name="plain_word_vector")
        fit_wv = Input(shape=self.embeddings_dimensionality, name="retrofitted_word_vector")

        # Translate vectors to the other domain
        fake_B = self.g_AB(unfit_wv)
        fake_A = self.g_BA(fit_wv)
        # Translate vectors back to the original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        # Identity mapping of vectors
        unfit_wv_id = self.g_BA(unfit_wv)
        fit_wv_id = self.g_AB(fit_wv)

        # For the combined model we will only train the generators;
        # the discriminators determine validity of translated vectors
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators.
        # It does A->B->A (left) and B->A->B (right).
        self.d_A.trainable = False
        self.d_B.trainable = False
        self.combined = Model(inputs=[unfit_wv, fit_wv],
                              outputs=[valid_A, valid_B,         # for the bce calculation
                                       reconstr_A, reconstr_B,   # for the mae calculation
                                       reconstr_A, reconstr_B,   # for the max margin calculation
                                       unfit_wv_id, fit_wv_id],  # for the id loss calculation
                              name="combinedmodel")

        log_path = './logs'
        callback = keras.callbacks.TensorBoard(log_dir=log_path)
        callback.set_model(self.combined)
        self.combined_callback = callback

    def compile_all(self, optimizer="sgd"):

        def max_margin_loss(y_true, y_pred):
            cost = 0
            sim_neg = 25
            sim_margin = 1
            for i in range(0, sim_neg):
                new_true = tf.random.shuffle(y_true)
                normalize_a = tf.nn.l2_normalize(y_true)
                normalize_b = tf.nn.l2_normalize(y_pred)
                normalize_c = tf.nn.l2_normalize(new_true)
                minimize = tf.reduce_sum(tf.multiply(normalize_a, normalize_b))
                maximize = tf.reduce_sum(tf.multiply(normalize_a, normalize_c))
                mg = sim_margin - minimize + maximize
                cost += tf.keras.backend.clip(mg, 0, 1000)
            return cost / (sim_neg * 1.0)

        def create_opt(lr=0.1):
            if optimizer == "adam":
                return tf.optimizers.Adam(lr=lr, epsilon=1e-10)
            raise KeyError("Could not find the optimizer")

        self.d_A.compile(loss='binary_crossentropy',
                         optimizer=create_opt(self.d_lr), metrics=['accuracy'])
        self.d_ABBA.compile(loss='binary_crossentropy',
                            optimizer=create_opt(self.d_lr), metrics=['accuracy'])
        self.d_BAAB.compile(loss='binary_crossentropy',
                            optimizer=create_opt(self.d_lr), metrics=['accuracy'])
        self.d_B.compile(loss='binary_crossentropy',
                         optimizer=create_opt(self.d_lr), metrics=['accuracy'])
        self.g_AB.compile(loss=max_margin_loss, optimizer=create_opt(self.g_lr))
        self.g_BA.compile(loss=max_margin_loss, optimizer=create_opt(self.g_lr))
        self.combined.compile(loss=['binary_crossentropy', 'binary_crossentropy',
                                    'mae', 'mae',
                                    max_margin_loss, max_margin_loss,
                                    'mae', 'mae'],
                              loss_weights=[1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.cycle_mm_w, self.cycle_mm_w,
                                            self.lambda_id, self.lambda_id],
                              optimizer=create_opt(self.g_lr))
        self.g_AB.summary()
        self.d_A.summary()
        self.combined.summary()

    def build_generator(self, name, hidden_dim=2048):
        """Dense generator: input -> encoder -> decoder -> output projection."""

        def dense(layer_input, hidden_dim, normalization=True, dropout=True,
                  dropout_percentage=0.2):
            d = Dense(hidden_dim, activation="relu")(layer_input)
            if normalization:
                d = BatchNormalization()(d)
            if dropout:
                d = Dropout(dropout_percentage)(d)
            return d

        inpt = Input(shape=self.embeddings_dimensionality)
        encoder = dense(inpt, hidden_dim, normalization=False, dropout=True,
                        dropout_percentage=0.2)
        decoder = dense(encoder, hidden_dim, normalization=False, dropout=True,
                        dropout_percentage=0.2)
        output = Dense(self.word_vector_dimensions)(decoder)
        return Model(inpt, output, name=name)

    def build_discriminator(self, name, hidden_dim=2048):

        def d_layer(layer_input, hidden_dim, normalization=True, dropout=True,
                    dropout_percentage=0.3):
            """Discriminator layer"""
            d = Dense(hidden_dim, activation="relu")(layer_input)
            if normalization:
                d = BatchNormalization()(d)
            if dropout:
                d = Dropout(dropout_percentage)(d)
            return d

        inpt = Input(shape=self.embeddings_dimensionality)
        d1 = d_layer(inpt, hidden_dim, normalization=False, dropout=True,
                     dropout_percentage=0.3)
        d1 = d_layer(d1, hidden_dim, normalization=True, dropout=True,
                     dropout_percentage=0.3)
        validity = Dense(1, activation="sigmoid", dtype='float32')(d1)
        return Model(inpt, validity, name=name)

    def build_c_discriminator(self, name, hidden_dim=2048):

        def d_layer(layer_input, hidden_dim, normalization=True, dropout=True,
                    dropout_percentage=0.3):
            """Discriminator layer"""
            d = Dense(hidden_dim, activation="relu")(layer_input)
            if normalization:
                d = BatchNormalization()(d)
            if dropout:
                d = Dropout(dropout_percentage)(d)
            return d

        # The conditional discriminator sees two concatenated 300-d vectors.
        inpt = Input(shape=(600,))
        d1 = d_layer(inpt, hidden_dim, normalization=False, dropout=True,
                     dropout_percentage=0.3)
        d1 = d_layer(d1, hidden_dim, normalization=True, dropout=True,
                     dropout_percentage=0.3)
        validity = Dense(1, activation="sigmoid", dtype='float32')(d1)
        return Model(inpt, validity, name=name)

    def load_weights(self, preface="", folder=None):
        if folder is None:
            folder = self.save_folder
        try:
            self.g_AB.reset_states()
            self.g_BA.reset_states()
            self.combined.reset_states()
            self.d_B.reset_states()
            self.d_A.reset_states()
            self.d_A.load_weights(os.path.join(folder, preface + "fromretrodis.h5"))
            self.d_B.load_weights(os.path.join(folder, preface + "toretrodis.h5"))
            self.g_AB.load_weights(os.path.join(folder, preface + "toretrogen.h5"))
            self.g_BA.load_weights(os.path.join(folder, preface + "fromretrogen.h5"))
            self.combined.load_weights(os.path.join(folder, preface + "combined_model.h5"))
        except Exception as e:
            print(e)

    def train(self, epochs, dataset, save_folder, name, batch_size=1,
              cache=False, epochs_per_checkpoint=4, dis_train_amount=3):
        wandb.init(project="retrogan", dir=save_folder)
        wandb.run.name = name
        wandb.run.save()
        self.name = name
        start_time = datetime.datetime.now()
        res = []
        X_train, Y_train = tools.load_all_words_dataset_final(
            dataset["original"], dataset["retrofitted"],
            save_folder=save_folder, cache=cache)
        print("Shapes of training data:", X_train.shape, Y_train.shape)
        print(X_train)
        print(Y_train)
        print("*" * 100)

        def load_batch(batch_size=32, always_random=False):
            def _int_load():
                iterable = list(Y_train.index)
                shuffle(iterable)
                batches = []
                print("Prefetching batches")
                for ndx in tqdm(range(0, len(iterable), batch_size)):
                    try:
                        ixs = iterable[ndx:min(ndx + batch_size, len(iterable))]
                        if always_random:
                            ixs = list(np.array(iterable)[
                                random.sample(range(0, len(iterable)), batch_size)])
                        imgs_A = X_train.loc[ixs]
                        imgs_B = Y_train.loc[ixs]
                        if np.isnan(imgs_A).any().any() or np.isnan(imgs_B).any().any():
                            continue
                        batches.append((imgs_A, imgs_B))
                    except Exception:
                        print("Skipping batch")
                return batches

            batches = _int_load()
            print("Beginning iteration")
            for i in tqdm(range(0, len(batches)), ncols=30):
                imgs_A, imgs_B = batches[i]
                yield (np.array(imgs_A.values, dtype=np.float32),
                       np.array(imgs_B.values, dtype=np.float32))

        self.compile_all("adam")

        def train_(training_epochs, always_random=False):
            global_step = 0
            for epoch in range(training_epochs):
                for batch_i, (imgs_A, imgs_B) in enumerate(
                        load_batch(batch_size, always_random=always_random)):
                    global_step += 1
                    fake_B = self.g_AB.predict(imgs_A)
                    fake_A = self.g_BA.predict(imgs_B)
                    fake_ABBA = self.g_BA.predict(fake_B)
                    fake_BAAB = self.g_AB.predict(fake_A)

                    # Train the discriminators
                    # (original vectors = real / translated = fake)
                    dA_loss = None
                    dB_loss = None
                    valid = np.ones((imgs_A.shape[0],))
                    fake = np.zeros((imgs_A.shape[0],))
                    for _ in range(int(dis_train_amount)):
                        dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                        dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                        if dA_loss is None:
                            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
                        else:
                            dA_loss += 0.5 * np.add(dA_loss_real, dA_loss_fake)
                        dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                        dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                        if dB_loss is None:
                            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)
                        else:
                            dB_loss += 0.5 * np.add(dB_loss_real, dB_loss_fake)
                    d_loss = (1.0 / dis_train_amount) * 0.5 * np.add(dA_loss, dB_loss)

                    def CycleCondLoss(d_ground, d_approx):
                        l = tf.math.log(d_ground) + tf.math.log(1 - d_approx)
                        return -1 * tf.reduce_mean(l)

                    # Train the cycle discriminators
                    d_cycle_dis = 0
                    g_cycle_dis = 0
                    if self.cycle_dis:
                        with tf.GradientTape() as tape:
                            dA = self.d_ABBA(tf.concat([fake_B, imgs_A], 1))
                            dA_r = self.d_ABBA(tf.concat([fake_B, fake_ABBA], 1))
                            la = CycleCondLoss(dA, dA_r)
                        tga = tape.gradient(la, self.d_ABBA.trainable_variables)
                        self.d_ABBA.optimizer.apply_gradients(
                            zip(tga, self.d_ABBA.trainable_variables))
                        d_cycle_dis += la

                        with tf.GradientTape() as tape:
                            dB = self.d_BAAB(tf.concat([fake_A, imgs_B], 1))
                            dB_r = self.d_BAAB(tf.concat([fake_A, fake_BAAB], 1))
                            lb = CycleCondLoss(dB, dB_r)
                        tgb = tape.gradient(lb, self.d_BAAB.trainable_variables)
                        self.d_BAAB.optimizer.apply_gradients(
                            zip(tgb, self.d_BAAB.trainable_variables))
                        d_cycle_dis += lb

                        with tf.GradientTape() as tape:
                            fake_B = self.g_AB(imgs_A)
                            fake_A = self.g_BA(imgs_B)
                            fake_ABBA = self.g_BA(fake_B)
                            fake_BAAB = self.g_AB(fake_A)
                            dB = self.d_BAAB(tf.concat([fake_A, imgs_B], 1))
                            dB_r = self.d_BAAB(tf.concat([fake_A, fake_BAAB], 1))
                            dA = self.d_ABBA(tf.concat([fake_B, imgs_A], 1))
                            dA_r = self.d_ABBA(tf.concat([fake_B, fake_ABBA], 1))
                            la = CycleCondLoss(dA, dA_r)
                            lb = CycleCondLoss(dB, dB_r)
                        tga = tape.gradient((la + lb) / 2.0,
                                            self.combined.trainable_variables)
                        self.combined.optimizer.apply_gradients(
                            zip(tga, self.combined.trainable_variables))
                        g_cycle_dis += (la + lb) / 2.0

                    # Calculate the max margin loss for A->B, B->A
                    mm_b_loss = 0
                    mm_a_loss = 0
                    if self.one_way_mm:
                        mm_a_loss = self.g_AB.train_on_batch(imgs_A, imgs_B)
                        mm_b_loss = self.g_BA.train_on_batch(imgs_B, imgs_A)
                    # Calculate the cycle A->B->A, B->A->B with max margin, and mae
                    g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                          [valid, valid,
                                                           imgs_A, imgs_B,
                                                           imgs_A, imgs_B,
                                                           imgs_A, imgs_B])

                    def named_logs(model, logs):
                        result = {}
                        for l in zip(model.metrics_names, logs):
                            result[l[0]] = l[1]
                        return result

                    r = named_logs(self.combined, g_loss)
                    r.update({'mma': mm_a_loss, 'mmb': mm_b_loss})
                    elapsed_time = datetime.datetime.now() - start_time
                    if batch_i % 50 == 0 and batch_i != 0:
                        print("\n[Epoch %d/%d] [Batch %d] [D loss: %f, acc: %3d%%] "
                              "[G loss: %05f, adv: %05f, recon: %05f, recon_mm: %05f, "
                              "id: %05f][mma: %05f, mmb: %05f] time: %s "
                              % (epoch, training_epochs, batch_i,
                                 d_loss[0], 100 * d_loss[1],
                                 g_loss[0],
                                 np.mean(g_loss[1:3]),
                                 np.mean(g_loss[3:5]),
                                 np.mean(g_loss[5:7]),
                                 np.mean(g_loss[7:9]),
                                 mm_a_loss, mm_b_loss,
                                 elapsed_time))
                        scalars = {
                            "epoch": epoch,
                            "global_step": global_step,
                            "discriminator_loss": d_loss[0],
                            "discriminator_acc": d_loss[1],
                            "combined_loss": g_loss[0] + g_cycle_dis + d_cycle_dis,
                            "loss": g_loss[0] + d_loss[0],
                            "cycle_da": g_loss[1],
                            "cycle_db": g_loss[2],
                            "cycle_dis": d_cycle_dis,
                            "cycle_gen_condis": g_cycle_dis,
                            "MM_ABBA_CYCLE": g_loss[5],
                            "MM_BAAB_CYCLE": g_loss[6],
                            "abba_mae": g_loss[3],
                            "baab_mae": g_loss[4],
                            "idloss_ab": g_loss[7],
                            "idloss_ba": g_loss[8],
                            "mm_ab_loss": mm_a_loss,
                            "mm_ba_loss": mm_b_loss,
                        }
                        wandb.log(scalars, step=global_step)
                print("\n")
                sl, sv, c = self.test(dataset)
                if epoch % epochs_per_checkpoint == 0 and epoch != 0:
                    self.save_model(name="checkpoint")
                res.append((sl, sv, c))
                wandb.log({"simlex": sl, "simverb": sv, "card": c, "epoch": epoch})
                print(res)
                print("\n")

        print("Actual training")
        train_(epochs)
        print("Final performance")
        sl, sv, c = self.test(dataset)
        res.append((sl, sv, c))
        self.save_model(name="final")
        return res

    def test(self, dataset, simlex="testing/SimLex-999.txt",
             simverb="testing/SimVerb-3500.txt", card="testing/card660.tsv",
             fasttext="fasttext_model/cc.en.300.bin", prefix="en_"):
        sl = tools.test_sem(self.g_AB, dataset, dataset_location=simlex,
                            fast_text_location=fasttext, prefix=prefix, pt=False)[0]
        sv = tools.test_sem(self.g_AB, dataset, dataset_location=simverb,
                            fast_text_location=fasttext, prefix=prefix, pt=False)[0]
        c = tools.test_sem(self.g_AB, dataset, dataset_location=card,
                           fast_text_location=fasttext, prefix=prefix, pt=False)[0]
        return sl, sv, c

    def save_model(self, name=""):
        self.d_A.save(os.path.join(self.save_folder, name + "fromretrodis.h5"),
                      include_optimizer=False)
        self.d_B.save(os.path.join(self.save_folder, name + "toretrodis.h5"),
                      include_optimizer=False)
        self.g_AB.save(os.path.join(self.save_folder, name + "toretrogen.h5"),
                       include_optimizer=False)
        self.g_BA.save(os.path.join(self.save_folder, name + "fromretrogen.h5"),
                       include_optimizer=False)
        self.combined.save(os.path.join(self.save_folder, name + "combined_model.h5"),
                           include_optimizer=False)
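
# Hypothetical usage sketch for RetroCycleGAN (not from the original source).
# The folder and file names below are invented placeholders; train() expects a
# dict with "original" and "retrofitted" embedding files in whatever format
# tools.load_all_words_dataset_final reads, and assumes wandb is configured.
rcgan = RetroCycleGAN(save_folder="models/retrogan",
                      generator_lr=0.00005, discriminator_lr=0.0001)
retro_dataset = {"original": "unfitted.hd5clean", "retrofitted": "fitted.hd5clean"}
retro_results = rcgan.train(epochs=50, dataset=retro_dataset,
                            save_folder="models/retrogan",
                            name="retrogan_run_0", batch_size=32)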
(train_x1, train_x2, train_fov, train_y,
 test_x1, test_x2, test_fov, test_y, images) = get_dataset()

train_batch_size = 64

# train
sum_logs = []
for batch in range(50000001):
    idx = np.random.randint(0, len(train_x1), train_batch_size)
    images_idx_x1 = train_x1[idx]
    images_idx_x2 = train_x2[idx]
    images_x1 = images[images_idx_x1] / 255.
    images_x2 = images[images_idx_x2] / 255.
    images_fov = train_fov[idx]
    result = train_y[idx]
    logs = model.train_on_batch(x=[images_x1, images_x2, images_fov], y=result)
    sum_logs.append(logs)
    if batch % 200 == 0 and batch > 0:
        # check model on the validation data
        valid_idx = np.random.randint(0, len(test_x1), train_batch_size)
        valid_images_idx_x1 = test_x1[valid_idx]
        valid_images_idx_x2 = test_x2[valid_idx]
        valid_images_x1 = images[valid_images_idx_x1] / 255.
        valid_images_x2 = images[valid_images_idx_x2] / 255.
        valid_images_fov = test_fov[valid_idx]
        valid_result = test_y[valid_idx]
        v_loss = model.test_on_batch(
            x=[valid_images_x1, valid_images_x2, valid_images_fov],
            y=valid_result)
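        # sum_logs is accumulated above but never summarized in this excerpt;
        # a minimal reporting sketch (this is an assumption, not the original
        # author's logging code):
        avg_logs = np.mean(sum_logs, axis=0)
        print('batch', batch, 'train:', avg_logs, 'valid:', v_loss)
        sum_logs = []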
class FMatrixGanModel:
    """
    Defines the complete model with generator, regressor and discriminator.
    This includes the low level training and prediction methods for this model,
    like the GAN training.
    """

    def __init__(self, params, model_folder, img_size):
        """
        Inits the model.
        :param params: Hyperparameters
        :param model_folder: Folder path, in which all results and temporary data of the model is stored.
        :param img_size: (image_width, image_height), defining the size of the input images.
        """
        if not isinstance(params, Params):
            params = Params(params)
        self.params = params
        self.model_folder = model_folder

        # inputs
        input_shape = (img_size[0], img_size[1], 1)
        img_A, img_B = Input(shape=input_shape), Input(shape=input_shape)

        # --- build models
        discriminator_model, frozen_discriminator_model = build_discriminator_models(
            img_size, params)
        generator_with_regressor_model, generator_model, generator_with_output_model, regressor_model = \
            build_generator_with_regressor_models(img_size, params)

        # --- models
        self.discriminator = discriminator_model
        self.regressor = regressor_model
        self.generator = generator_model
        self.generator_with_output = generator_with_output_model
        self.generator_with_regressor = generator_with_regressor_model

        # model: GAN without regressor and without output
        fake_B = generator_model(img_A)
        gan_out = frozen_discriminator_model(fake_B)
        self.gan = Model(inputs=img_A, outputs=gan_out)

        # model: GAN with regressor
        if params['use_images']:
            fake_B, fmatrix = generator_with_regressor_model([img_A, img_B])
            gan_out = frozen_discriminator_model(fake_B)
            self.gan_with_regressor = Model(inputs=[img_A, img_B],
                                            outputs=[gan_out, fmatrix])
        else:
            fake_B, fmatrix = generator_with_regressor_model(img_A)
            gan_out = frozen_discriminator_model(fake_B)
            self.gan_with_regressor = Model(inputs=img_A,
                                            outputs=[gan_out, fmatrix])

        # --- compile models
        self.discriminator.compile(
            loss='binary_crossentropy',
            optimizer=Adam(lr=params['lr_D'], beta_1=0.9, beta_2=0.999, epsilon=1e-08),
            metrics=['accuracy'])
        self.regressor.compile(
            loss='mean_squared_error',
            optimizer=Adam(lr=params['lr_R'], beta_1=0.9, beta_2=0.999, epsilon=1e-08))

        # generators do not need to be compiled as they are compiled within the GANs
        if params['freeze_discriminator']:
            frozen_discriminator_model.trainable = False

        self.gan.compile(
            loss='binary_crossentropy',
            optimizer=Adam(lr=params['lr_G'], beta_1=0.9, beta_2=0.999, epsilon=1e-08))

        loss_weights = params['generator_loss_weights']
        assert len(loss_weights) == 2
        self.gan_with_regressor.compile(
            loss=['binary_crossentropy', 'mean_squared_error'],
            loss_weights=loss_weights,
            optimizer=Adam(lr=params['lr_G'], beta_1=0.9, beta_2=0.999, epsilon=1e-08))

        self.__models_with_weights = [
            self.generator, self.regressor, self.discriminator
        ]

    def generate_img(self, img_A):
        """
        Generates an image from img_A using the generator and its current weights.
        :param img_A: Input image to the generator. Dimension: (img_width, img_height)
        :return: The generated image
        """
        img_A = _img_to_img_batch(img_A)
        img_B = self.generator.predict(img_A)
        return img_B[0]  # only 1 sample in batch

    def generate_regression_input(self, img_pair):
        """
        Generates the regression input for the given image pair using the current weights.
        :param img_pair: Input image pair. Dimension: (img_width, img_height, 2)
        :return: The regression input which can be passed into the regressor.
            This is a list of inputs which may include the image pair, the
            bottleneck and the derived feature layers.
        """
        img_A, img_B = _img_pair_to_img_batches(img_pair)
        generator_output, *regression_input = self.generator_with_output.predict(img_A)
        # for each of the elements in the regression input, only select the first
        # sample, because batches (containing one sample each) are returned
        regression_input = [batch[0] for batch in regression_input]
        if self.params['use_images']:
            regression_input.append(img_pair)
        return regression_input

    def regress_from_img(self, img_pair):
        """
        Regresses the fundamental matrix from the given image pair using the current weights.
        :param img_pair: Input image pair. Dimension: (img_width, img_height, 2)
        :return: Fundamental matrix. Dimension: (3, 3)
        """
        img_A, img_B = _img_pair_to_img_batches(img_pair)
        if self.params['use_images']:
            # as the regressor also uses images, img_A and img_B are needed
            gen_img, F = self.generator_with_regressor.predict([img_A, img_B])
        else:
            # the regressor uses no images, only input from the generator,
            # and the generator only needs img_A
            gen_img, F = self.generator_with_regressor.predict(img_A)
        return F[0]  # only 1 sample in batch

    def regress_from_regression_input(self, regression_input):
        """
        Regresses the fundamental matrix from the given regression inputs using
        the current weights (of the regressor).
        :param regression_input: Regression input which was generated using generate_regression_input.
        :return: Fundamental matrix. Dimension: (3, 3)
        """
        regression_input = _regression_input_to_batch(regression_input)
        F_batch = self.regress_from_regression_input_batch(regression_input)
        return F_batch[0]  # only 1 sample in batch

    def regress_from_regression_input_batch(self, regression_input_batch):
        """
        Regresses the fundamental matrices for a batch of regression inputs
        using the current weights (of the regressor).
        :param regression_input_batch: Batch of regression inputs which were generated using generate_regression_input.
        :return: Batch of fundamental matrices. Dimension: (None, 3, 3)
        """
        F_batch = self.regressor.predict(regression_input_batch)
        return F_batch

    def train_GAN(self, img_pair, epochs, discr_iterations, plot_interval=None,
                  img_path_prefix=None, check_img_mse=False, verbose=0):
        """
        Trains the GAN for the given image pair.
        :param img_pair: Image pair to train the GAN.
        :param epochs: Number of training epochs.
        :param discr_iterations: Number of discriminator iterations in each epoch.
        :param plot_interval: How often the generated image is plotted.
            - None: no plotting
            - positive integer: plot every nth epoch
            - -1: plot only after the last epoch.
        :param img_path_prefix: Prefix for plotted image files. If None: image is only plotted but not saved.
        :param check_img_mse: bool - Check and store the image mean squared error of the generated image in the history.
        :param verbose: Verbosity level: 0 to 2
        :return: History of the GAN training, dictionary of lists
        """
        return self.__do_train_gan_for_sample(
            img_pair,
            epochs=epochs,
            discr_iterations=discr_iterations,
            plot_interval=plot_interval,
            img_path_prefix=img_path_prefix,
            check_img_mse=check_img_mse,
            verbose=verbose)

    def train_GAN_and_regressor(self, img_pair, F_true, epochs, discr_iterations,
                                plot_interval=None, img_path_prefix=None,
                                check_img_mse=False, verbose=0):
        """
        Train the GAN and the regressor together using a combined loss for generator and regressor.
        :param img_pair: Image pair to train.
        :param F_true: Ground truth fundamental matrix.
        :param epochs: Number of training epochs.
        :param discr_iterations: Number of discriminator iterations in each epoch.
        :param plot_interval: How often the generated image is plotted.
            - None: no plotting
            - positive integer: plot every nth epoch
            - -1: plot only after the last epoch.
        :param img_path_prefix: Prefix for plotted image files. If None: image is only plotted but not saved.
        :param check_img_mse: bool - Check and store the image mean squared error of the generated image in the history.
        :param verbose: Verbosity level: 0 to 2
        :return: History of the training, dictionary of lists
        """
        return self.__do_train_gan_for_sample(
            img_pair,
            F_true=F_true,
            epochs=epochs,
            discr_iterations=discr_iterations,
            plot_interval=plot_interval,
            img_path_prefix=img_path_prefix,
            check_img_mse=check_img_mse,
            verbose=verbose)

    def train_regressor(self, regression_input, F_true):
        """
        Trains the regressor for the given regression input and F_true.
        The regressor is only trained for one epoch on that single sample.
        :param regression_input: Regression input to train for.
        :param F_true: Ground truth fundamental matrix.
        :return: History of the training, dictionary of lists
        """
        return self.train_regressor_batch(
            _regression_input_to_batch(regression_input), _F_to_F_batch(F_true))

    def train_regressor_batch(self, regression_input_batch, F_true_batch):
        """
        Trains the regressor for the given regression input and F_true batch.
        The regressor is only trained for one epoch on that batch.
        :param regression_input_batch: Batch of regression inputs to train for.
        :param F_true_batch: Batch of ground truth fundamental matrices.
        :return: History of the training, dictionary of lists
        """
        loss = self.regressor.train_on_batch(regression_input_batch, F_true_batch)
        return loss

    def update_regressor_lr(self, update_fn):
        """
        Updates the current learning rate of the regressor using the given
        update function. The update function gets the old lr as input and
        should return the new lr, which is then set as the regressor lr.
        :param update_fn: Function applied to compute the new lr:
            update_fn(old_lr: float) -> new_lr: float
        :return: New learning rate which was set.
        """
        old_lr = float(K.get_value(self.regressor.optimizer.lr))
        new_lr = update_fn(old_lr)
        K.set_value(self.regressor.optimizer.lr, new_lr)
        return new_lr

    def save_weights(self, file_prefix=None):
        """
        Save all model weights. Multiple files will be stored, as this model
        has multiple sub models. All files are stored within the model folder.
        :param file_prefix: If defined, use this as prefix for the weight file names.
            If None: store temporary weights.
        """
        if file_prefix is None:
            file_prefix = TMP_WEIGHTS_FILE_PREFIX
        for i, model in enumerate(self.__models_with_weights):
            model.save_weights(self.model_folder + '/' + file_prefix + ('_%d.h5' % i))

    def load_weights(self, file_prefix=None, remove=False):
        """
        Loads all model weights. All files are loaded from within the model folder.
        :param file_prefix: If defined, use this as prefix for the weight file names.
            If None: load temporary weights.
        :param remove: If True, remove the loaded weight files.
        """
        if file_prefix is None:
            file_prefix = TMP_WEIGHTS_FILE_PREFIX
        for i, model in enumerate(self.__models_with_weights):
            file = self.model_folder + '/' + file_prefix + ('_%d.h5' % i)
            model.load_weights(file)
            if remove:
                os.remove(file)

    def plot_models(self, file_prefix):
        """
        Plots all sub models.
        :param file_prefix:
        """
        print('Plotting model with file prefix %s' % file_prefix)
        plot_model(self.generator, to_file=file_prefix + 'generator.png',
                   show_shapes=True)
        plot_model(self.generator_with_output,
                   to_file=file_prefix + 'generator_with_output.png',
                   show_shapes=True)
        plot_model(self.generator_with_regressor,
                   to_file=file_prefix + 'generator_with_regressor.png',
                   show_shapes=True)
        plot_model(self.discriminator, to_file=file_prefix + '_discriminator.png',
                   show_shapes=True)
        plot_model(self.regressor, to_file=file_prefix + 'regressor.png',
                   show_shapes=True)
        plot_model(self.gan, to_file=file_prefix + 'gan.png', show_shapes=True)
        plot_model(self.gan_with_regressor,
                   to_file=file_prefix + 'gan_with_regressor.png',
                   show_shapes=True)

    # verbose=0 -> no logging
    # verbose=1 -> only show current epoch
    # verbose=2 -> show epoch results and details
    # plot_interval: None -> disabled, >0 every i epochs, -1 only at the end of all epochs
    def __do_train_gan_for_sample(self, img_pair_sample, F_true=None, epochs=1,
                                  discr_iterations=1, plot_interval=None,
                                  img_path_prefix=None, check_img_mse=False,
                                  verbose=0):
        img_A, img_B = _img_pair_to_img_batches(img_pair_sample)
        if F_true is not None:
            F_true = _F_to_F_batch(F_true)

        if plot_interval is not None and img_path_prefix is not None:
            # Save original images for later debugging
            _save_imgs(img_A, img_B, img_path_prefix + 'img_A.png',
                       img_path_prefix + 'img_B.png')

        valid = np.array([1])
        fake = np.array([0])

        generator_loss_history = []
        generator_gen_loss_history = []
        generator_F_history = []
        discriminator_history = []
        discriminator_real_history = []
        discriminator_fake_history = []
        img_mse_history = []

        fake_B = self.generator.predict(img_A)  # generate fake B for 1st epoch
        for epoch in range(1, epochs + 1):
            if verbose == 1:
                print(('--> GAN epoch %d/%d' % (epoch, epochs)).ljust(100), end='\r')
            elif verbose > 1:
                print('--> GAN epoch %d/%d' % (epoch, epochs))

            # --- train discriminator
            if verbose >= 2:
                print('-----> Train D...'.ljust(100), end='\r')
            discr_input = np.concatenate([img_B, fake_B])
            discr_target = np.concatenate([valid, fake])
            for it in range(1, discr_iterations + 1):
                discriminator_loss_real, discriminator_loss_fake = \
                    self.discriminator.train_on_batch(discr_input, discr_target)
                discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2
                if verbose >= 2:
                    print('-----> D iteration %d/%d [loss: %f, real_loss: %f, fake_loss: %f]'
                          .ljust(100)
                          % (it, discr_iterations, discriminator_loss,
                             discriminator_loss_real, discriminator_loss_fake),
                          end='\r')
            discriminator_history.append(discriminator_loss)
            discriminator_real_history.append(discriminator_loss_real)
            discriminator_fake_history.append(discriminator_loss_fake)

            # --- train generator
            if verbose >= 2:
                print('-----> Train G...'.ljust(100), end='\r')
            if F_true is None:
                generator_loss = self.gan.train_on_batch(img_A, valid)
                generator_loss_history.append(generator_loss)
                if verbose == 1:
                    print(('--> GAN epoch %d/%d [D - loss: %f] [G - loss: %f]'
                           % (epoch, epochs, discriminator_loss, generator_loss)).ljust(100),
                          end='\r')
                elif verbose > 1:
                    print(('---> [D - loss: %f] [G - loss: %f]'
                           % (discriminator_loss, generator_loss)).ljust(100))
            else:
                if self.params['use_images']:
                    loss, generator_loss, fmatrix_loss = self.gan_with_regressor.train_on_batch(
                        [img_A, img_B], [valid, F_true])
                else:
                    loss, generator_loss, fmatrix_loss = self.gan_with_regressor.train_on_batch(
                        img_A, [valid, F_true])
                generator_loss_history.append(loss)
                generator_gen_loss_history.append(generator_loss)
                generator_F_history.append(fmatrix_loss)
                if verbose == 1:
                    print(('--> GAN epoch %d/%d [D - loss: %f] [G - loss: %f, gen_loss: %f, F_loss: %f]'
                           % (epoch, epochs, discriminator_loss, loss,
                              generator_loss, fmatrix_loss)).ljust(100),
                          end='\r')
                elif verbose > 1:
                    print(('---> [D - loss: %f] [G - loss: %f, gen_loss: %f, F_loss: %f]'
                           % (discriminator_loss, loss, generator_loss,
                              fmatrix_loss)).ljust(100))

            # Generate for the next epoch and for results checking
            # (so that the image does not have to be generated twice)
            fake_B = self.generator.predict(img_A)
            if check_img_mse:
                img_mse = _calc_image_mse(img_A, img_B, fake_B)
                img_mse_history.append(img_mse)
                if verbose >= 2:
                    print('---> [image_mse: %f]' % img_mse)

            if plot_interval is not None and plot_interval != -1 \
                    and epoch % plot_interval == 0:
                if img_path_prefix is not None:
                    img_path = img_path_prefix + ('generated_B_%04d.png' % epoch)
                else:
                    img_path = None
                _plot_img(img_A, img_B, fake_B, img_path)

        if plot_interval == -1:
            if img_path_prefix is not None:
                img_path = img_path_prefix + 'generated_B.png'
            else:
                img_path = None
            _plot_img(img_A, img_B, fake_B, img_path)

        if F_true is None:
            return {
                'discriminator_loss': discriminator_history,
                'discriminator_loss_real': discriminator_real_history,
                'discriminator_loss_fake': discriminator_fake_history,
                'generator_loss': generator_loss_history,
                'img_mse': img_mse_history
            }
        else:
            return {
                'discriminator_loss': discriminator_history,
                'discriminator_loss_real': discriminator_real_history,
                'discriminator_loss_fake': discriminator_fake_history,
                'generator_loss': generator_loss_history,
                'generator_loss_gen': generator_gen_loss_history,
                'generator_F_loss': generator_F_history,
                'img_mse': img_mse_history
            }
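
# The module-level helpers used by FMatrixGanModel (_img_to_img_batch,
# _img_pair_to_img_batches, _F_to_F_batch, _regression_input_to_batch) are not
# part of this excerpt. A minimal sketch of what they plausibly do, with shapes
# inferred from the docstrings; the actual implementations may differ:
import numpy as np

def _img_to_img_batch(img):
    # (w, h) -> (1, w, h, 1)
    return img.reshape((1,) + img.shape + (1,))

def _img_pair_to_img_batches(img_pair):
    # (w, h, 2) -> two batches of shape (1, w, h, 1)
    return _img_to_img_batch(img_pair[:, :, 0]), _img_to_img_batch(img_pair[:, :, 1])

def _F_to_F_batch(F):
    # (3, 3) -> (1, 3, 3)
    return F.reshape((1, 3, 3))

def _regression_input_to_batch(regression_input):
    # wrap each regression input element in a batch of size 1
    return [np.expand_dims(el, axis=0) for el in regression_input]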
class DCGan:
    def __init__(self, input_shape, config):
        self.config = config
        self.config['model']['generator']['z_size'] = self.config['data']['z_size']
        self.input_shape = input_shape
        self._build_model()

    def _build_model(self):
        # Generator
        model_input = Input((self.config['data']['z_size'],), name='gan_input')
        self.generator = DCGan._build_generator(self.input_shape,
                                                self.config['model']['generator'])

        # Discriminator
        self.discriminator = DCGan._build_discriminator(
            self.input_shape, self.config['model']['discriminator'])

        # GAN
        model_output = self.discriminator(self.generator(model_input))
        self.gan = Model(model_input, model_output)

        # The models are deliberately left uncompiled here. Already with a basic
        # GAN setup we break the use of model.fit and related utilities, and the
        # compile-with-frozen-discriminator approach hits a problem possibly
        # related to a bug with nested models
        # (https://github.com/keras-team/keras/issues/10074).
        # Not compiling the discriminator in fact doesn't trigger the error anymore.

    def train(self, train_ds, validation_ds, nb_epochs: int, log_dir,
              checkpoint_dir, is_tfdataset=False):
        callbacks = []

        # tensorboard
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                              histogram_freq=1)
        callbacks.append(tensorboard_callback)

        # checkpoints
        if checkpoint_dir:
            cp_callback = tf.keras.callbacks.ModelCheckpoint(
                checkpoint_dir, save_weights_only=True, verbose=1,
                period=self.config['training']['checkpoint_steps'])
            callbacks.append(cp_callback)

        # plotting callback
        plot_callback = PlotData(validation_ds, self.generator, log_dir)
        callbacks.append(plot_callback)

        # training
        batch_size = self.config['training']['batch_size']
        z_dim = self.config['data']['z_size']
        for epoch in range(nb_epochs):
            if is_tfdataset:
                for x in train_ds:
                    train_batch = x.numpy()
                    break
            else:
                idx = np.random.randint(0, train_ds.shape[0], batch_size)
                train_batch = train_ds[idx]
            self.train_discriminator(train_batch, batch_size, z_dim)
            self.train_generator(batch_size, z_dim)
            # TODO add validation step

    # Train with pure TF, because the Keras-only path above doesn't work
    def _train(self, train_ds, validation_ds, nb_epochs: int, log_dir,
               checkpoint_dir, is_tfdataset=False, restore_latest_checkpoint=True):
        batch_size = self.config['training']['batch_size']
        z_dim = self.config['data']['z_size']
        noise = tf.random.normal([batch_size, z_dim])

        plot_summary_writer = tf.summary.create_file_writer(str(log_dir / 'plot'))
        train_summary_writer = tf.summary.create_file_writer(str(log_dir / 'train'))

        # optimizers
        generator_optimizer = tf.keras.optimizers.Adam(
            self.config['training']['generator']['learning_rate'])
        discriminator_optimizer = tf.keras.optimizers.Adam(
            self.config['training']['discriminator']['learning_rate'])

        # checkpoints
        if checkpoint_dir:
            checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                             generator_optimizer=generator_optimizer,
                                             discriminator_optimizer=discriminator_optimizer,
                                             generator=self.generator,
                                             discriminator=self.discriminator)
            ckpt_manager = tf.train.CheckpointManager(
                checkpoint, os.path.join(checkpoint_dir, "ckpt"),
                max_to_keep=self.config['training']['checkpoints_to_keep'])
            if restore_latest_checkpoint and ckpt_manager.latest_checkpoint:
                checkpoint.restore(ckpt_manager.latest_checkpoint)
                print(f"Restored from {ckpt_manager.latest_checkpoint}")
            else:
                print("Initializing from scratch.")

        # train loop
        for epoch in tqdm(range(nb_epochs)):
            gen_losses = []
            disc_losses = []
            for ds_batch in train_ds:
                gen_loss, disc_loss = train_step(
                    ds_batch, self.generator, self.discriminator,
                    generator_optimizer=generator_optimizer,
                    discriminator_optimizer=discriminator_optimizer,
                    batch_size=batch_size, noise_dim=z_dim)
                gen_losses.append(gen_loss)
                disc_losses.append(disc_loss)

            # Loss summary
            avg_gen_loss = np.mean(gen_losses)
            avg_disc_loss = np.mean(disc_losses)
            with train_summary_writer.as_default():
                tf.summary.scalar("Average Gen Loss", avg_gen_loss, step=epoch)
                tf.summary.scalar("Average Disc Loss", avg_disc_loss, step=epoch)

            # Plot sample data
            with plot_summary_writer.as_default():
                predictions = self.generator(noise)
                tf.summary.image("Sample Generated", predictions, step=epoch)
                tf.summary.image("Sample Input",
                                 [ds_batch[np.random.randint(len(ds_batch))]],
                                 step=epoch)

            # checkpoint
            if checkpoint_dir:
                checkpoint.step.assign_add(1)
                ckpt_step = int(checkpoint.step)
                if ckpt_step % self.config['training']['checkpoint_steps'] == 0:
                    save_path = ckpt_manager.save()
                    print(f"Saved checkpoint for step {ckpt_step}: {save_path}")

    # takes a latent vector and generates an image
    @staticmethod
    def _build_generator(img_shape, config):
        latent_vector = Input((config['z_size'],), name='generator_input')
        init_shape = tuple([get_initial_size(d, config['num_conv_blocks'])
                            for d in img_shape[:-1]] + [config['init_filters']])
        x = Dense(np.prod(init_shape))(latent_vector)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)
        x = Reshape(init_shape)(x)
        for i in range(config['num_conv_blocks'] - 1):
            x = upscale(filters=config['init_filters'] // 2 ** i,
                        kernel_size=config['kernel_size'],
                        strides=config['strides'],
                        upscale_method=config['upscale_method'],
                        activation='relu')(x)
        # last upscale layer
        model_output = upscale(filters=config['n_channels'],
                               kernel_size=config['kernel_size'],
                               strides=config['strides'],
                               upscale_method=config['upscale_method'],
                               activation='tanh')(x)
        return Model(latent_vector, model_output)

    @staticmethod
    def _build_discriminator(img_shape, config):
        model_input = Input(shape=img_shape, name="discriminator_input")
        x = model_input
        for i in range(config['num_conv_blocks']):
            x = conv(filters=config['init_filters'] * (2 ** i),
                     kernel_size=config['kernel_size'],
                     strides=config['strides'])(x)
        features = Flatten()(x)
        model_output = Dense(1, activation='sigmoid')(features)
        return Model(model_input, model_output)

    def train_discriminator(self, true_imgs, batch_size: int, z_dim: int):
        # Train on real images with target [1,1,...,1], since they are real
        # and we want our generated examples to look like them
        self.discriminator.train_on_batch(true_imgs, np.ones((batch_size, 1)))
        # Train on generated images with target [0,0,...,0], since they are fake
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = self.generator.predict(noise)
        self.discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))

    def train_generator(self, batch_size: int, z_dim: int):
        # Train on noise input with target [1,1,...,1], since we want the
        # discriminator to believe these are real images
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        self.gan.train_on_batch(noise, np.ones((batch_size, 1)))

    def setup_dataset(self, dataset):
        # prefetch lets the dataset fetch batches in the background
        # while the model is training
        dataset = dataset.shuffle(self.config['data']['buffer_size']) \
            .batch(self.config['training']['batch_size'], drop_remainder=True) \
            .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return dataset
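
# `train_step`, called in `_train` above, is not part of this excerpt. A minimal
# sketch of a standard DCGAN train step with GradientTape; the non-saturating
# BCE loss formulation here is an assumption, not necessarily the original
# implementation:
import tensorflow as tf

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

@tf.function
def train_step(images, generator, discriminator, generator_optimizer,
               discriminator_optimizer, batch_size, noise_dim):
    noise = tf.random.normal([batch_size, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        # the generator wants fakes classified as real
        gen_loss = cross_entropy(tf.ones_like(fake_output), fake_output)
        # the discriminator wants reals as 1 and fakes as 0
        disc_loss = (cross_entropy(tf.ones_like(real_output), real_output)
                     + cross_entropy(tf.zeros_like(fake_output), fake_output))
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    return gen_loss, disc_loss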
class FusionModel:
    def __init__(self, config, load_weight_path=None, ab_loss='mse'):
        img_shape = (config.IMAGE_SIZE, config.IMAGE_SIZE)

        # Creating generator and discriminator
        optimizer = Adam(0.00002, 0.5)
        self.foreground_generator = instance_network(img_shape)
        self.fusion_discriminator = discriminator_network(img_shape)
        self.fusion_discriminator.compile(loss=wasserstein_loss_dummy,
                                          optimizer=optimizer)
        self.fusion_generator = fusion_network(img_shape, config.BATCH_SIZE)
        self.fusion_generator.compile(loss=[ab_loss, 'kld'], optimizer=optimizer)

        if load_weight_path:
            chroma_gan = load_model(load_weight_path)
            chroma_gan_layers = [layer.name for layer in chroma_gan.layers]

            print('Loading chroma GAN parameters into the instance network...')
            instance_layer_names = [layer.name
                                    for layer in self.foreground_generator.layers]
            for i, layer_name in enumerate(instance_layer_names):
                if layer_name == 'fg_model_3':
                    print('model 3 skip')
                    continue
                if len(layer_name) < 2:
                    continue
                if layer_name[:3] == 'fg_':
                    try:
                        j = chroma_gan_layers.index(layer_name[3:])
                        self.foreground_generator.layers[i].set_weights(
                            chroma_gan.layers[j].get_weights())
                        print(f'Successfully set weights for layer {layer_name}')
                    except ValueError:
                        print(f'Layer {layer_name} not found in chroma gan.')
                    except Exception as e:
                        print(e)

            print('Loading chroma GAN parameters into the fusion network...')
            fusion_layer_names = [layer.name
                                  for layer in self.fusion_generator.layers]
            for i, layer_name in enumerate(fusion_layer_names):
                if layer_name == 'model_3':
                    print('model 3 skip')
                    continue
                try:
                    j = chroma_gan_layers.index(layer_name)
                    self.fusion_generator.layers[i].set_weights(
                        chroma_gan.layers[j].get_weights())
                    print(f'Successfully set weights for layer {layer_name}')
                except ValueError:
                    print(f'Layer {layer_name} not found in chroma gan.')
                except Exception as e:
                    print(e)

        # Fg = instance prediction
        fg_img_l = Input(shape=(*img_shape, 1, MAX_INSTANCES))
        # self.foreground_generator.trainable = False
        fg_model_3, fg_conv2d_11, fg_conv2d_13, fg_conv2d_15, fg_conv2d_17 = \
            self.foreground_generator(fg_img_l)

        # Fusion prediction
        fusion_img_l = Input(shape=(*img_shape, 1))
        fusion_img_real_ab = Input(shape=(*img_shape, 2))
        fg_bbox = Input(shape=(4, MAX_INSTANCES))
        fg_mask = Input(shape=(*img_shape, MAX_INSTANCES))
        self.fusion_generator.trainable = False
        fusion_img_pred_ab, fusion_class_vec = self.fusion_generator([
            fusion_img_l, fg_model_3, fg_conv2d_11, fg_conv2d_13,
            fg_conv2d_15, fg_conv2d_17, fg_bbox, fg_mask])

        dis_pred_ab = self.fusion_discriminator([fusion_img_pred_ab, fusion_img_l])
        dis_real_ab = self.fusion_discriminator([fusion_img_real_ab, fusion_img_l])

        # Sample the gradient penalty
        img_ab_interp_samples = RandomWeightedAverage()(
            [fusion_img_pred_ab, fusion_img_real_ab])
        dis_interp_ab = self.fusion_discriminator(
            [img_ab_interp_samples, fusion_img_l])
        partial_gp_loss = partial(gradient_penalty_loss,
                                  averaged_samples=img_ab_interp_samples,
                                  gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
        partial_gp_loss.__name__ = 'gradient_penalty'

        # Compile D and G as well as the combined model
        self.discriminator_model = Model(
            inputs=[fusion_img_l, fusion_img_real_ab, fg_img_l, fg_bbox, fg_mask],
            outputs=[dis_real_ab, dis_pred_ab, dis_interp_ab])
        self.discriminator_model.compile(
            optimizer=optimizer,
            loss=[wasserstein_loss_dummy, wasserstein_loss_dummy, partial_gp_loss],
            loss_weights=[-1.0, 1.0, 1.0])

        self.fusion_generator.trainable = True
        self.fusion_discriminator.trainable = False
        self.combined = Model(
            inputs=[fusion_img_l, fg_img_l, fg_bbox, fg_mask],
            outputs=[fusion_img_pred_ab, fusion_class_vec, dis_pred_ab])
        self.combined.compile(loss=[ab_loss, 'kld', wasserstein_loss_dummy],
                              loss_weights=[1.0, 0.003, -0.1],
                              optimizer=optimizer)

        # Monitoring
        self.callback = TensorBoard(config.LOG_DIR)
        self.callback.set_model(self.combined)
        self.train_names = ['loss', 'mse_loss', 'kullback_loss', 'wasserstein_loss']
        self.disc_names = ['disc_loss', 'disc_valid', 'disc_fake', 'disc_gp']

        self.test_loss_array = []
        self.g_loss_array = []

    def train(self, data: Data, test_data, log, config, skip_to_after_epoch=None):
        # Load VGG network
        VGG_modelF = applications.vgg16.VGG16(weights='imagenet', include_top=True)

        # Real, fake and dummy targets for the discriminator
        positive_y = np.ones((data.batch_size, 1), dtype=np.float32)
        negative_y = -positive_y
        dummy_y = np.zeros((data.batch_size, 1), dtype=np.float32)

        # total number of batches in one epoch
        total_batch = int(data.size / data.batch_size)
        print(f'batch_size={data.batch_size} * total_batch={total_batch}')

        def save_path(model_type, epoch):
            return os.path.join(config.MODEL_DIR,
                                f"fusion_{model_type}Epoch{epoch}.h5")

        if skip_to_after_epoch:
            start_epoch = skip_to_after_epoch + 1
            print(f"Loading weights from epoch {skip_to_after_epoch}")
            self.combined.load_weights(save_path("combined", skip_to_after_epoch))
            self.fusion_discriminator.load_weights(
                save_path("discriminator", skip_to_after_epoch))
        else:
            start_epoch = 0

        for epoch in range(start_epoch, config.NUM_EPOCHS):
            for batch in tqdm(range(total_batch)):
                train_batch = data.generate_batch()
                resized_l = train_batch.resized_images.l
                resized_ab = train_batch.resized_images.ab

                # ground-truth VGG class distribution
                predictVGG = VGG_modelF.predict(np.tile(resized_l, [1, 1, 1, 3]))

                # train generator
                g_loss = self.combined.train_on_batch(
                    [resized_l, train_batch.instances.l,
                     train_batch.instances.bbox, train_batch.instances.mask],
                    [resized_ab, predictVGG, positive_y])
                # train discriminator
                d_loss = self.discriminator_model.train_on_batch(
                    [resized_l, resized_ab, train_batch.instances.l,
                     train_batch.instances.bbox, train_batch.instances.mask],
                    [positive_y, negative_y, dummy_y])

                # update log files
                write_log(self.callback, self.train_names, g_loss,
                          (epoch * total_batch + batch + 1))
                write_log(self.callback, self.disc_names, d_loss,
                          (epoch * total_batch + batch + 1))

                if batch % 10 == 0:
                    print(f"[Epoch {epoch}] [Batch {batch}/{total_batch}] "
                          f"[generator loss: {g_loss[0]:08f}] "
                          f"[discriminator loss: {d_loss[0]:08f}]")

            print('Saving models...')
            self.combined.save(save_path("combined", epoch))
            self.fusion_discriminator.save(save_path("discriminator", epoch))
            print('Models saved.')

            print('Sampling test images...')
            # sample images after each epoch
            self.sample_images(test_data, epoch, config)

    def sample_images(self, test_data: Data, epoch, config):
        total_batch = int(ceil(test_data.size / test_data.batch_size))
        for _ in range(total_batch):
            # load test data
            test_batch = test_data.generate_batch()

            # predict AB channels
            fg_model_3, fg_conv2d_11, fg_conv2d_13, fg_conv2d_15, fg_conv2d_17 = \
                self.foreground_generator.predict(test_batch.instances.l)
            fusion_img_pred_ab, _ = self.fusion_generator.predict([
                test_batch.resized_images.l, fg_model_3, fg_conv2d_11,
                fg_conv2d_13, fg_conv2d_15, fg_conv2d_17,
                test_batch.instances.bbox, test_batch.instances.mask])

            # save results
            for i in range(test_data.batch_size):
                original_full_img = test_batch.images.full[i]
                height, width, _ = original_full_img.shape
                pred_ab = cv2.resize(deprocess_float2int(fusion_img_pred_ab[i]),
                                     (width, height))
                reconstruct_and_save(test_batch.images.l[i], pred_ab,
                                     f'epoch{epoch}_{test_batch.file_names[i]}',
                                     config)
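
# The WGAN-GP helpers referenced above (wasserstein_loss_dummy,
# gradient_penalty_loss, RandomWeightedAverage) are not part of this excerpt.
# A minimal sketch of the standard Keras WGAN-GP formulations; signatures are
# inferred from the call sites, so the original definitions may differ:
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

def wasserstein_loss_dummy(y_true, y_pred):
    # Wasserstein critic loss; the sign comes in through the y_true targets
    return K.mean(y_true * y_pred)

def gradient_penalty_loss(y_true, y_pred, averaged_samples,
                          gradient_penalty_weight):
    # penalize the critic's gradient norm deviating from 1 at interpolated samples
    gradients = K.gradients(y_pred, averaged_samples)[0]
    gradients_sqr_sum = K.sum(K.square(gradients),
                              axis=np.arange(1, len(gradients.shape)))
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    return gradient_penalty_weight * K.mean(K.square(1 - gradient_l2_norm))

class RandomWeightedAverage(tf.keras.layers.Layer):
    # random convex combination of the predicted and real samples
    def call(self, inputs):
        alpha = K.random_uniform((K.shape(inputs[0])[0], 1, 1, 1))
        return alpha * inputs[0] + (1 - alpha) * inputs[1]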