class K_DCGAN(data_model): def __init__( self, flag="deconv", epoch=100000, img_shape=[256, 256, 3], learning_rate=1e-3, white_bk=True, name="gan_unet_model", loss=[ 'categorical_crossentropy', # generator l1_loss, # K_dcgan 'binary_crossentropy' #dis criminator ]): self.loss = loss assert (flag == "deconv") or ( flag == "upsampling"), "Only support flag for deconv or upsampling" self.flag = flag self.img_dim = img_shape bk = "bk" if white_bk: bk = "wh" data_model.__init__(self, name + "bk%s_lr_%s_img_dim%s_loss_%s" % (bk, learning_rate, img_shape[0], loss), "DCGAN", img_shape=img_shape, epochs=epoch) assert len(loss) == 3 # training params self.patch_size = [64, 64] self.n_batch_per_epoch = self.batch_size * 100 self.learning_rate = learning_rate self.losses = loss # init all need dir and model self.build() self.disc_weights_path = os.path.join(self.weight_path, "disc_weight_epoch.h5") self.gen_weights_path = os.path.join(self.weight_path, "gen_weight_epoch.h5") self.DCGAN_weights_path = os.path.join(self.weight_path, "DCGAN_weight_epoch.h5") check_folders(self.weight_path) def build(self): generator_loss, dc_gan_loss, discriminator_loss = self.losses self.generator = load_model("generator_unet_%s" % (self.flag), self.img_dim, 64, True, True, self.batch_size) nb_patch, img_shape_disc = get_nb_patch(self.img_dim, self.patch_size) self.discriminator = load_model("DCGAN_discriminator", self.img_dim, nb_patch, -1, True, self.batch_size) opt_dcgan, opt_discriminator = Adam( lr=self.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08),\ Adam(lr=self.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08) self.generator.compile(loss=generator_loss, optimizer=opt_discriminator) self.discriminator.trainable = False self.DCGAN_model = DCGAN(self.generator, self.discriminator, self.img_dim, self.patch_size, "channels_last") loss = [dc_gan_loss, discriminator_loss] loss_weight = [1E1, 1] self.DCGAN_model.compile(loss=loss, loss_weights=loss_weight, optimizer=opt_dcgan) self.discriminator.trainable = True self.discriminator.compile(loss=discriminator_loss, optimizer=opt_discriminator) def log_checkpoint(self, epoch, batch, loss): log_path = os.path.join(self.weight_path, "checkpoint") prev_epochs = 0 if os.path.isfile(log_path): with open(log_path, "w+") as f: line = f.readline() if "Epoch" in line: line = f.readline().split(" ") prev_epochs = int(line[4]) with open(log_path, "w+") as f: f.write("Model_Name {} ".format(self.title)) f.write("Epoch {} in batch {}".format(epoch + prev_epochs, batch)) f.write("\n") f.write("Losses: {}".format(loss)) def save(self): if not os.path.exists(self.gen_weights_path): h5py.File(self.gen_weights_path) h5py.File(self.disc_weights_path) h5py.File(self.DCGAN_weights_path) self.generator.save_weights(self.gen_weights_path, overwrite=True) self.discriminator.save_weights(self.disc_weights_path, overwrite=True) self.DCGAN_model.save_weights(self.DCGAN_weights_path, overwrite=True) def test_img(self): # pick a random index idx = rnd.choice([i for i in range(0, len(self.data['X']))]) X, y = self.get_data(idx) # normalized images self.load() X_pred = self.generator.predict(np.array([X])) X = image.array_to_img(inverse_normalization(X, self.max, self.min)) y = image.array_to_img(inverse_normalization(y, self.max, self.min)) X_pred = image.array_to_img( inverse_normalization(X_pred[0], self.max, self.min)) suffix = "End_test" result = np.hstack((X, y, X_pred)) # check_folders("../figures/%s" % (self.title)) # plt.imshow(result) # plt.savefig( # "../figures/%s/current_batch_%s.png" % # (self.title, suffix)) # plt.axis("off") # plt.show() def load(self): '''Load models weight from log/${model_name}''' if os.path.exists(self.gen_weights_path): self.generator.load_weights(self.gen_weights_path) self.discriminator.load_weights(self.disc_weights_path) self.DCGAN_model.load_weights(self.DCGAN_weights_path) else: raise FileNotFoundError("No Previous Model Found") print("Loading model from {}".format([ self.disc_weights_path, self.gen_weights_path, self.DCGAN_weights_path ])) def summary(self, name="DCGAN"): if name == "Generator": self.generator.summary() elif name == "Discriminator": self.discriminator.summary() else: self.DCGAN_model.summary() @timeit(log_info="Training pix2pix") def train(self, label_smoothing=False, retrain=False): gen_loss, disc_loss = 100, 100 n_batch_per_epoch = self.n_batch_per_epoch total_epoch = n_batch_per_epoch * self.batch_size try: if retrain: print("Found prev_trained models ...") self.load() print("Retrain the model ") except FileNotFoundError: print("No previous model found start train a new model") try: os.system("clear") for e in range(self.nb_epochs): batch_counter = 1 start = time() progbar = generic_utils.Progbar(total_epoch) for X, y in self.gen_batch(self.batch_size): X_disc, y_disc = self.get_disc_batch( X, y, self.generator, batch_counter, self.patch_size, label_smoothing=label_smoothing, label_flipping=0) disc_loss = self.discriminator.train_on_batch( X_disc, y_disc) X_gen_target, Y_gen = next(self.gen_batch(self.batch_size)) self.generator.train_on_batch(X_gen_target, Y_gen) y_gen = np.zeros((Y_gen.shape[0], 2), dtype=np.uint8) y_gen[:, 1] = 1 self.discriminator.trainable = False gen_loss = self.DCGAN_model.train_on_batch( X_gen_target, [Y_gen, y_gen]) self.DCGAN_model.trainable = True batch_counter += 1 progbar.add(self.batch_size, values=[("D logloss", disc_loss), ("G tot", gen_loss[0]), ("G L1", gen_loss[1]), ("G logloss", gen_loss[2])]) if batch_counter % (n_batch_per_epoch / 2) == 0: # Get new images from validation plot_generated_batch(X, y, self.generator, self.batch_size, "training", self.title, self) # get next validation batches X_test, y_test = next( self.gen_batch(self.batch_size, validation=True)) plot_generated_batch(X_test, y_test, self.generator, self.batch_size, "validation", self.title, self) if batch_counter >= n_batch_per_epoch: break print("") t_time = time() - start print('Epoch %s/%s, Time: %s ms' % (e + 1, self.nb_epochs, round(t_time, 2)), end="\r") if e % 5 == 0: self.save() self.log_checkpoint(e, batch_counter, [("D logloss", disc_loss), ("G tot", gen_loss[0]), ("G L1", gen_loss[1]), ("G logloss", gen_loss[2])]) except KeyboardInterrupt: print( "\nInterruption occured.... Saving the model Epochs:{}".format( e)) self.save() self.log_checkpoint(e, batch_counter, [("D logloss", disc_loss), ("G tot", gen_loss[0]), ("G L1", gen_loss[1]), ("G logloss", gen_loss[2])])
image_data_format, label_smoothing=label_smoothing, label_flipping=label_flipping) # Update the discriminator disc_loss = discriminator_model.train_on_batch(X_disc, y_disc) # Create a batch to feed the generator model X_gen_target, X_gen = next( gen_batch(X_full_train, X_sketch_train, batch_size)) y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8) y_gen[:, 1] = 1 # Freeze the discriminator discriminator_model.trainable = False gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen]) # Unfreeze the discriminator discriminator_model.trainable = True batch_counter += 1 progbar.add(batch_size, values=[("D logloss", disc_loss), ("G tot", gen_loss[0]), ("G L1", gen_loss[1]), ("G logloss", gen_loss[2])]) # Save images for visualization if batch_counter % (n_batch_per_epoch / 2) == 0: # Get new images from validation plot_generated_batch(X_full_batch, X_sketch_batch, generator_model, batch_size,