def test(self, refer_model=None, batch_size=4, datapath_test='./images/val_dir', crops_per_image=1, log_test_path="./images/test/"):
    """Run the model on a folder of test images and plot the results.

    No training happens here: a DataLoader is built over ``datapath_test`` and
    handed, together with the same path, to ``plot_test_images`` with an
    epoch label of 0.

    :param refer_model: optional model forwarded to ``plot_test_images`` as
        ``refer_model``
    :param int batch_size: batch size passed to the DataLoader
    :param str datapath_test: directory containing the test images
    :param int crops_per_image: crops per image passed to the DataLoader
    :param str log_test_path: where the resulting plots are written
    """
    # Positional arguments in the exact order DataLoader expects them.
    loader_args = (
        datapath_test,
        batch_size,
        self.height_hr,
        self.width_hr,
        self.upscaling_factor,
        crops_per_image,
    )
    test_loader = DataLoader(*loader_args)

    print(">> Ploting test images")
    plot_test_images(
        self,
        test_loader,
        datapath_test,
        log_test_path,
        0,  # epoch label used when naming the plots
        refer_model=refer_model,
    )
def train_generator(self, epochs, batch_size, workers, dataname, datapath_train,
                    datapath_validation=None, datapath_test=None,
                    steps_per_epoch=1000, steps_per_validation=1000,
                    crops_per_image=2, log_weight_path='./data/weights/',
                    log_tensorboard_path='./data/logs/',
                    log_tensorboard_name='SRResNet',
                    log_tensorboard_update_freq=10000,
                    log_test_path="./images/samples/"):
    """Train the generator part of the network with MSE loss.

    Training is delegated to ``self.generator.fit_generator`` with a
    TensorBoard callback (optional), a weight checkpoint callback, and an
    optional test-image plotting callback.

    :param int epochs: number of epochs passed to ``fit_generator``
    :param int batch_size: mini-batch size used by the data loaders
    :param int workers: number of data workers; multiprocessing is enabled
        when this is greater than 1
    :param str dataname: prefix for the checkpointed weight file
    :param str datapath_train: path to the training images
    :param str datapath_validation: optional path to validation images
    :param str datapath_test: optional path to images plotted after each epoch
    :param int steps_per_epoch: training steps per epoch
    :param int steps_per_validation: validation steps per epoch
    :param int crops_per_image: crops taken from each image by the loaders
    :param str log_weight_path: directory where weights are checkpointed
    :param str log_tensorboard_path: TensorBoard log root; falsy disables it
    :param str log_tensorboard_name: sub-folder for the TensorBoard logs
    :param int log_tensorboard_update_freq: TensorBoard ``update_freq``
    :param str log_test_path: where test plots are saved
    """
    # Create data loaders
    train_loader = DataLoader(datapath_train, batch_size, self.height_hr,
                              self.width_hr, self.upscaling_factor,
                              crops_per_image)
    validation_loader = None
    if datapath_validation is not None:
        validation_loader = DataLoader(datapath_validation, batch_size,
                                       self.height_hr, self.width_hr,
                                       self.upscaling_factor, crops_per_image)

    # Callback: tensorboard
    callbacks = []
    if log_tensorboard_path:
        tensorboard = TensorBoard(
            log_dir=os.path.join(log_tensorboard_path, log_tensorboard_name),
            histogram_freq=0,
            batch_size=batch_size,
            write_graph=False,
            write_grads=False,
            update_freq=log_tensorboard_update_freq)
        callbacks.append(tensorboard)
    else:
        print(
            ">> Not logging to tensorboard since no log_tensorboard_path is set"
        )

    # Callback: save weights after each epoch.
    # 'val_loss' only exists when validation data is supplied. Without it,
    # ModelCheckpoint(save_best_only=True) finds no monitored value and skips
    # saving every epoch, so fall back to the training loss in that case.
    monitor = 'val_loss' if validation_loader is not None else 'loss'
    modelcheckpoint = ModelCheckpoint(
        os.path.join(log_weight_path,
                     dataname + '_{}X'.format(self.upscaling_factor)),
        monitor=monitor,
        save_best_only=True,
        save_weights_only=True)
    callbacks.append(modelcheckpoint)

    # Callback: test images plotting
    if datapath_test is not None:
        testplotting = LambdaCallback(
            on_epoch_end=lambda epoch, logs: plot_test_images(
                self, train_loader, datapath_test, log_test_path, epoch,
                name='SRResNet'))
        callbacks.append(testplotting)

    # Fit the model
    self.generator.fit_generator(train_loader,
                                 steps_per_epoch=steps_per_epoch,
                                 epochs=epochs,
                                 validation_data=validation_loader,
                                 validation_steps=steps_per_validation,
                                 callbacks=callbacks,
                                 use_multiprocessing=workers > 1,
                                 workers=workers,
                                 verbose=1)
def train_srgan(
    self,
    epochs,
    batch_size,
    dataname,
    datapath_train,
    datapath_validation=None,
    steps_per_validation=1000,
    datapath_test=None,
    workers=4,
    max_queue_size=10,
    first_epoch=0,
    print_frequency=1,
    crops_per_image=2,
    log_weight_frequency=None,
    log_weight_path='./data/weights/',
    log_tensorboard_path='./data/logs/',
    log_tensorboard_name='SRGAN',
    log_tensorboard_update_freq=10000,
    log_test_frequency=500,
    log_test_path="./images/samples/",
):
    """Train the SRGAN network.

    Each "epoch" is one update iteration: the discriminator is trained on one
    real batch and one generated batch, then the combined ``self.srgan`` model
    is trained on the same low-resolution batch.

    :param int epochs: how many epochs (update iterations) to train for
    :param int batch_size: mini-batch size used by the data loaders
    :param str dataname: name to use for storing model weights etc.
    :param str datapath_train: path for the image files to use for training
    :param str datapath_validation: optional path for validation images
    :param int steps_per_validation: batches evaluated per validation run
    :param str datapath_test: path for the image files to use for testing / plotting
    :param int workers: number of enqueuer workers preparing training batches
    :param int max_queue_size: maximum number of prepared batches queued
    :param int first_epoch: epoch index to start counting from (e.g. when resuming)
    :param int print_frequency: how often (in epochs) to print progress to terminal.
        Warning: will run validation inference!
    :param int crops_per_image: crops taken from each image by the loaders
    :param int log_weight_frequency: how often (in epochs) should network weights
        be saved. None for never
    :param str log_weight_path: where should network weights be saved
    :param str log_tensorboard_path: where should tensorflow logs be sent;
        falsy disables TensorBoard logging
    :param str log_tensorboard_name: what folder should tf logs be saved under
    :param int log_tensorboard_update_freq: TensorBoard ``update_freq``
    :param int log_test_frequency: how often (in epochs) should testing be performed
    :param str log_test_path: where should test results be saved
    """
    # Create train data loader
    loader = DataLoader(datapath_train, batch_size, self.height_hr,
                        self.width_hr, self.upscaling_factor, crops_per_image)

    # Validation data loader
    validation_loader = None
    if datapath_validation is not None:
        validation_loader = DataLoader(datapath_validation, batch_size,
                                       self.height_hr, self.width_hr,
                                       self.upscaling_factor, crops_per_image)

    # Callback: tensorboard (None when logging is disabled, so later calls
    # can be guarded instead of hitting an unbound name)
    tensorboard = None
    if log_tensorboard_path:
        tensorboard = TensorBoard(
            log_dir=os.path.join(log_tensorboard_path, log_tensorboard_name),
            histogram_freq=0,
            batch_size=batch_size,
            write_graph=False,
            write_grads=False,
            update_freq=log_tensorboard_update_freq)
        tensorboard.set_model(self.srgan)
    else:
        print(
            ">> Not logging to tensorboard since no log_tensorboard_path is set"
        )

    def named_logs(model, logs):
        """Transform train_on_batch return value to dict expected by on_batch_end callback"""
        return dict(zip(model.metrics_names, logs))

    # Shape of output from discriminator, with the batch dimension fixed
    disciminator_output_shape = list(self.discriminator.output_shape)
    disciminator_output_shape[0] = batch_size
    disciminator_output_shape = tuple(disciminator_output_shape)

    # VALID / FAKE targets for discriminator
    real = np.ones(disciminator_output_shape)
    fake = np.zeros(disciminator_output_shape)

    # Losses accumulated since the last progress print
    print_losses = {"G": [], "D": []}

    # Use several workers on CPU for preparing batches. Started as late as
    # possible and always stopped in `finally` so worker processes don't leak.
    enqueuer = OrderedEnqueuer(loader, use_multiprocessing=True, shuffle=True)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    try:
        output_generator = enqueuer.get()

        # Start of the current reporting window
        start_epoch = datetime.datetime.now()

        # Each epoch == "update iteration" as defined in the paper
        for epoch in range(first_epoch, int(epochs) + first_epoch):

            # Train discriminator
            imgs_lr, imgs_hr = next(output_generator)
            generated_hr = self.generator.predict(imgs_lr)
            real_loss = self.discriminator.train_on_batch(imgs_hr, real)
            fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
            discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

            # Train generator
            features_hr = self.vgg.predict(self.preprocess_vgg(imgs_hr))
            generator_loss = self.srgan.train_on_batch(imgs_lr, [real, features_hr])

            # Callbacks
            if tensorboard is not None:
                tensorboard.on_epoch_end(epoch, named_logs(self.srgan, generator_loss))

            # Save losses
            print_losses['G'].append(generator_loss)
            print_losses['D'].append(discriminator_loss)

            # Show the progress
            report = epoch % print_frequency == 0
            if report:
                g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                print(
                    "\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}"
                    .format(
                        epoch, epochs + first_epoch,
                        (datetime.datetime.now() - start_epoch).seconds,
                        ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.srgan.metrics_names, g_avg_loss)
                        ]),
                        ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.discriminator.metrics_names, d_avg_loss)
                        ])))
                print_losses = {"G": [], "D": []}

                # Run validation inference if specified
                if datapath_validation:
                    validation_losses = self.generator.evaluate_generator(
                        validation_loader,
                        steps=steps_per_validation,
                        use_multiprocessing=workers > 1,
                        workers=workers)
                    print(">> Validation Losses: {}".format(", ".join([
                        "{}={:.4f}".format(k, v) for k, v in zip(
                            self.generator.metrics_names, validation_losses)
                    ])))

            # If test images are supplied, run model on them and save to log_test_path
            if datapath_test and epoch % log_test_frequency == 0:
                plot_test_images(self, loader, datapath_test, log_test_path, epoch)

            # Check if we should save the network weights
            if log_weight_frequency and epoch % log_weight_frequency == 0:
                # Save the network weights
                self.save_weights(os.path.join(log_weight_path, dataname))

            # Restart the timer once per reporting window (after validation,
            # plotting and saving) so the printed time covers only the
            # training iterations since the previous report.
            if report:
                start_epoch = datetime.datetime.now()
    finally:
        enqueuer.stop()
        if tensorboard is not None:
            # Flush and close the TensorBoard event writer
            tensorboard.on_train_end(None)
def train(self, epochs, dataname, datapath, batch_size=1, test_images=None,
          test_frequency=50, test_path="./images/samples/",
          weight_frequency=None, weight_path='./data/weights/',
          print_frequency=1):
    """Train the SRGAN network

    Each epoch is one update iteration: the discriminator is trained on one
    real and one generated batch, then the combined ``self.srgan`` model is
    trained on a freshly loaded batch.

    :param int epochs: how many epochs to train the network for
    :param str dataname: name to use for storing model weights etc.
    :param str datapath: path for the image files to use for training
    :param int batch_size: how large mini-batches to use
    :param list test_images: list of image paths to perform testing on
    :param int test_frequency: how often (in epochs) should testing be performed
    :param str test_path: where should test results be saved
    :param int weight_frequency: how often (in epochs) should network weights
        be saved (the recorded losses are pickled at the same time). None for never
    :param str weight_path: where should network weights be saved
    :param int print_frequency: how often (in epochs) to print progress to terminal
    """
    # Create data loader.
    # NOTE(review): this call uses a different positional signature than the
    # DataLoader calls in train_srgan/train_generator and relies on
    # `load_batch`; confirm the current DataLoader still supports it.
    loader = DataLoader(datapath, self.height_hr, self.width_hr,
                        self.height_lr, self.width_lr, self.upscaling_factor)

    # Shape of output from discriminator, with the batch dimension fixed
    disciminator_output_shape = list(self.discriminator.output_shape)
    disciminator_output_shape[0] = batch_size
    disciminator_output_shape = tuple(disciminator_output_shape)

    # VALID / FAKE targets for discriminator. Generated images must be
    # labelled 0, otherwise the discriminator is taught that fakes are real.
    real = np.ones(disciminator_output_shape)
    fake = np.zeros(disciminator_output_shape)

    # Each epoch == "update iteration" as defined in the paper
    losses = []
    start_epoch = datetime.datetime.now()
    for epoch in range(epochs):

        # Train discriminator
        imgs_hr, imgs_lr = loader.load_batch(batch_size)
        generated_hr = self.generator.predict(imgs_lr)
        real_loss = self.discriminator.train_on_batch(imgs_hr, real)
        fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
        discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

        # Train generator
        imgs_hr, imgs_lr = loader.load_batch(batch_size)
        features_hr = self.vgg.predict(imgs_hr)
        generator_loss = self.srgan.train_on_batch([imgs_lr, imgs_hr],
                                                   [real, features_hr])

        # Save losses
        losses.append({
            'generator': generator_loss,
            'discriminator': discriminator_loss
        })

        # Plot the progress
        report = epoch % print_frequency == 0
        if report:
            print(
                "Epoch {}/{} | Time: {}s\n>> Generator: {}\n>> Discriminator: {}\n"
                .format(
                    epoch, epochs,
                    (datetime.datetime.now() - start_epoch).seconds,
                    ", ".join([
                        "{}={:.3e}".format(k, v) for k, v in zip(
                            self.srgan.metrics_names, generator_loss)
                    ]),
                    ", ".join([
                        "{}={:.3e}".format(k, v) for k, v in zip(
                            self.discriminator.metrics_names, discriminator_loss)
                    ])))

        # If test images are supplied, show them to the user
        if test_images and epoch % test_frequency == 0:
            plot_test_images(self, loader, test_images, test_path, epoch)

        # Check if we should save the network weights
        if weight_frequency and epoch % weight_frequency == 0:

            # Save the network weights
            self.save_weights(os.path.join(weight_path, dataname))

            # Save the recorded losses; `with` guarantees the file is closed
            with open(os.path.join(weight_path, dataname + '_losses.p'),
                      'wb') as losses_file:
                pickle.dump(losses, losses_file)

        # Restart the timer once per reporting window so the printed time
        # covers the iterations since the previous report.
        if report:
            start_epoch = datetime.datetime.now()