def process(network_info, images_dir, output_dir, threshold=0.8):
    session, input_tensor, output_tensor, img_size, cls_labels = load_from_xml(
        network_info)

    print("Input tensor: {}".format(input_tensor))
    print("Output tensor: {}".format(output_tensor))
    print("Image size: {}".format(img_size))
    print("Classes: {}".format(cls_labels))
    print()

    print("Parsing source directory... (this can take some time)")
    if img_size[2] == 3:
        img_type = 'rgb'
    else:
        img_type = 'greyscale'
    gen = InferenceGenerator(images_dir, 64, img_size, img_type)

    print("Files: {}".format(len(gen.filenames)))
    print("Batches: {}".format(len(gen)))

    workers = np.min((multiprocessing.cpu_count(), 8))
    print("Workers: {}".format(workers))
    print()

    filenames = []
    cls_index = []
    cls_names = []
    score = []

    enq = OrderedEnqueuer(gen, use_multiprocessing=True)
    enq.start(workers=workers, max_queue_size=multiprocessing.cpu_count() * 4)
    output_generator = enq.get()

    for i in range(len(gen)):
        print("\r{} / {}".format(i, len(gen)), end='')
        batch_filenames, batch_images = next(output_generator)
        result = session.run(output_tensor, feed_dict={input_tensor: batch_images})
        cls = np.argmax(result, axis=1)
        scr = np.max(result, axis=1)
        cls_name = [cls_labels[c] for c in cls]

        filenames.extend(batch_filenames)
        cls_index.extend(cls)
        cls_names.extend(cls_name)
        score.extend(scr)

    enq.stop()

    print()
    print("Done")
    print("See {} for results".format(output_dir))

    parents = [Path(f).parent.name for f in filenames]
    files = [Path(f).name for f in filenames]

    df = pd.DataFrame(
        data={
            'filename': filenames,
            'parent': parents,
            'file': files,
            'class': cls_names,
            'class_index': cls_index,
            'score': score
        })

    os.makedirs(output_dir, exist_ok=True)
    df.to_csv(os.path.join(output_dir, "inference.csv"))
def queue_train_generator(train_gen, workers=1, use_multiprocessing=False,
                          max_queue_size=10, use_sequence_api=True):
    # All the queue handling follows
    # https://github.com/keras-team/keras/blob/master/keras/engine/training_generator.py
    if workers > 0:
        if use_sequence_api:
            enqueuer = OrderedEnqueuer(
                train_gen,
                use_multiprocessing=use_multiprocessing,
                # TODO: add a parameter to control this
                shuffle=False,
            )
        else:
            enqueuer = GeneratorEnqueuer(
                train_gen,
                use_multiprocessing=use_multiprocessing,
            )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        train_generator = enqueuer.get()
    else:
        if use_sequence_api:
            train_generator = iter_sequence_infinite(train_gen)
        else:
            train_generator = train_gen
    return train_generator
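# A minimal usage sketch for the helper above (illustrative only): `ToyBatches`
# is a made-up Sequence used for demonstration; any tensorflow.keras.utils.Sequence
# that yields (x, y) batches would be consumed the same way.
import numpy as np
from tensorflow.keras.utils import Sequence

class ToyBatches(Sequence):
    """Serves random (x, y) batches purely for demonstration."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return np.random.rand(8, 32), np.random.rand(8, 1)

train_generator = queue_train_generator(ToyBatches(),
                                        workers=2,
                                        use_multiprocessing=False,
                                        max_queue_size=4,
                                        use_sequence_api=True)
x_batch, y_batch = next(train_generator)  # one batch prepared by the background workers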
def get_enqueuer(csv, batch_size, FLAGS, tokenizer_wrapper, augmenter=None):
    data_generator = AugmentedImageSequence(
        dataset_csv_file=csv,
        class_names=FLAGS.csv_label_columns,
        tokenizer_wrapper=tokenizer_wrapper,
        source_image_dir=FLAGS.image_directory,
        batch_size=batch_size,
        target_size=FLAGS.image_target_size,
        augmenter=augmenter,
        shuffle_on_epoch_end=True,
    )
    enqueuer = OrderedEnqueuer(data_generator,
                               use_multiprocessing=False,
                               shuffle=False)
    return enqueuer, data_generator.steps
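# A sketch of how the enqueuer returned above is typically driven (illustrative
# only: `FLAGS`, `tokenizer_wrapper`, and "train.csv" are project-specific
# stand-ins, and the (images, targets) batch structure is assumed here).
enqueuer, steps = get_enqueuer("train.csv", batch_size=16, FLAGS=FLAGS,
                               tokenizer_wrapper=tokenizer_wrapper)
enqueuer.start(workers=2, max_queue_size=8)
batches = enqueuer.get()          # generator that yields batches in order
for _ in range(steps):
    images, targets = next(batches)
    # ... run one training / evaluation step on this batch ...
enqueuer.stop()                   # shut the workers down when done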
def __init__(self, data_generator, batch_size, num_samples, output_dir,
             input_shape, n_classes):
    self.batch_size = batch_size
    self.num_samples = num_samples
    self.tensorboard_writer = tf.summary.create_file_writer(
        output_dir + "/diagnose/", flush_millis=10000)
    self.data_generator = data_generator
    self.input_shape = input_shape
    self.colors = np.array([[255, 255, 0],
                            [255, 0, 0],
                            [0, 255, 0],
                            [0, 0, 255],
                            [0, 0, 0]])
    self.color_dict = {0: (0, 0, 0), 1: (0, 255, 0)}
    self.n_classes = n_classes
    self.colors = self.colors[:self.n_classes]

    is_sequence = isinstance(self.data_generator, Sequence)
    if is_sequence:
        self.enqueuer = OrderedEnqueuer(self.data_generator,
                                        use_multiprocessing=True,
                                        shuffle=False)
    else:
        self.enqueuer = GeneratorEnqueuer(self.data_generator,
                                          use_multiprocessing=True,
                                          wait_time=0.01)
    self.enqueuer.start(workers=4, max_queue_size=4)
def __init__(self, data_generator, batch_size, num_samples, output_dir,
             normalization_mean, start_index=0):
    super().__init__()
    self.data_generator = data_generator
    self.batch_size = batch_size
    self.num_samples = num_samples
    self.tensorboard_writer = TensorboardWriter(output_dir)
    self.normalization_mean = normalization_mean
    self.start_index = start_index

    is_sequence = isinstance(self.data_generator, Sequence)
    if is_sequence:
        self.enqueuer = OrderedEnqueuer(self.data_generator,
                                        use_multiprocessing=False,
                                        shuffle=False)
    else:
        self.enqueuer = GeneratorEnqueuer(self.data_generator,
                                          use_multiprocessing=False)
    self.enqueuer.start(workers=1, max_queue_size=4)
def generator_fn():
    generator = OrderedEnqueuer(Generator(f_names), True)
    generator.start(workers=8, max_queue_size=10)
    while True:
        image, y_true_1, y_true_2, y_true_3 = generator.get().__next__()
        yield image, y_true_1, y_true_2, y_true_3
def train_esrgan(self,
                 epochs=None, batch_size=16,
                 modelname=None,
                 datapath_train=None,
                 datapath_validation=None,
                 steps_per_validation=1000,
                 datapath_test=None,
                 workers=4, max_queue_size=10,
                 first_epoch=0,
                 print_frequency=1,
                 crops_per_image=2,
                 log_weight_frequency=None,
                 log_weight_path='./model/',
                 log_tensorboard_path='./data/logs/',
                 log_tensorboard_update_freq=10,
                 log_test_frequency=500,
                 log_test_path="./images/samples/",
                 media_type='i'
                 ):
    """Train the ESRGAN network

    :param int epochs: how many epochs to train the network for
    :param str modelname: name to use for storing model weights etc.
    :param str datapath_train: path for the image files to use for training
    :param str datapath_test: path for the image files to use for testing / plotting
    :param int print_frequency: how often (in epochs) to print progress to terminal. Warning: will run validation inference!
    :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
    :param str log_weight_path: where should network weights be saved
    :param int log_test_frequency: how often (in epochs) should testing & validation be performed
    :param str log_test_path: where should test results be saved
    :param str log_tensorboard_path: where should tensorflow logs be sent
    """

    # Create data loaders
    train_loader = DataLoader(
        datapath_train, batch_size,
        self.height_hr, self.width_hr,
        self.upscaling_factor,
        crops_per_image,
        media_type,
        self.channels,
        self.colorspace
    )

    # Validation data loader
    validation_loader = None
    if datapath_validation is not None:
        validation_loader = DataLoader(
            datapath_validation, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image,
            media_type,
            self.channels,
            self.colorspace
        )

    test_loader = None
    if datapath_test is not None:
        test_loader = DataLoader(
            datapath_test, 1,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            1,
            media_type,
            self.channels,
            self.colorspace
        )

    # Use several workers on CPU for preparing batches
    enqueuer = OrderedEnqueuer(
        train_loader,
        use_multiprocessing=True,
        shuffle=True
    )
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()

    # Callback: tensorboard
    if log_tensorboard_path:
        tensorboard = TensorBoard(
            log_dir=os.path.join(log_tensorboard_path, modelname),
            histogram_freq=0,
            batch_size=batch_size,
            write_graph=True,
            write_grads=True,
            update_freq=log_tensorboard_update_freq
        )
        tensorboard.set_model(self.esrgan)
    else:
        print(">> Not logging to tensorboard since no log_tensorboard_path is set")

    # Learning rate scheduler
    def lr_scheduler(epoch, lr):
        factor = 0.5
        decay_step = [50000, 100000, 200000, 300000]
        if epoch in decay_step and epoch:
            return lr * factor
        return lr

    lr_scheduler_gan = LearningRateScheduler(lr_scheduler, verbose=1)
    lr_scheduler_gan.set_model(self.esrgan)
    lr_scheduler_gen = LearningRateScheduler(lr_scheduler, verbose=0)
    lr_scheduler_gen.set_model(self.generator)
    lr_scheduler_dis = LearningRateScheduler(lr_scheduler, verbose=0)
    lr_scheduler_dis.set_model(self.discriminator)
    lr_scheduler_ra = LearningRateScheduler(lr_scheduler, verbose=0)
    lr_scheduler_ra.set_model(self.ra_discriminator)

    # Callback: format input value
    def named_logs(model, logs):
        """Transform train_on_batch return value to dict expected by on_batch_end callback"""
        result = {}
        for l in zip(model.metrics_names, logs):
            result[l[0]] = l[1]
        return result

    # Shape of output from discriminator
    disciminator_output_shape = list(self.ra_discriminator.output_shape)
    disciminator_output_shape[0] = batch_size
    disciminator_output_shape = tuple(disciminator_output_shape)

    # VALID / FAKE targets for discriminator
    real = np.ones(disciminator_output_shape)
    fake = np.zeros(disciminator_output_shape)

    # Each epoch == "update iteration" as defined in the paper
    print_losses = {"GAN": [], "D": []}
    start_epoch = datetime.datetime.now()

    # Random images to go through
    # idxs = np.random.randint(0, len(train_loader), epochs)

    # Loop through epochs / iterations
    for epoch in range(first_epoch, int(epochs) + first_epoch):
        lr_scheduler_gan.on_epoch_begin(epoch)
        lr_scheduler_ra.on_epoch_begin(epoch)
        lr_scheduler_dis.on_epoch_begin(epoch)
        lr_scheduler_gen.on_epoch_begin(epoch)

        # Start epoch time
        if epoch % print_frequency == 0:
            print("\nEpoch {}/{}:".format(epoch + 1, epochs + first_epoch))
            start_epoch = datetime.datetime.now()

        # Train discriminator
        self.discriminator.trainable = True
        self.ra_discriminator.trainable = True

        imgs_lr, imgs_hr = next(output_generator)
        generated_hr = self.generator.predict(imgs_lr)

        real_loss = self.ra_discriminator.train_on_batch([imgs_hr, generated_hr], real)
        # print("Real: ", real_loss)
        fake_loss = self.ra_discriminator.train_on_batch([generated_hr, imgs_hr], fake)
        # print("Fake: ", fake_loss)
        discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

        # Train generator
        self.discriminator.trainable = False
        self.ra_discriminator.trainable = False

        # for _ in tqdm(range(10), ncols=60, desc=">> Training generator"):
        imgs_lr, imgs_hr = next(output_generator)
        gan_loss = self.esrgan.train_on_batch([imgs_lr, imgs_hr], [imgs_hr, real, imgs_hr])

        # Callbacks
        logs = named_logs(self.esrgan, gan_loss)
        tensorboard.on_epoch_end(epoch, logs)

        # Save losses
        print_losses['GAN'].append(gan_loss)
        print_losses['D'].append(discriminator_loss)

        # Show the progress
        if epoch % print_frequency == 0:
            g_avg_loss = np.array(print_losses['GAN']).mean(axis=0)
            d_avg_loss = np.array(print_losses['D']).mean(axis=0)
            print(">> Time: {}s\n>> GAN: {}\n>> Discriminator: {}".format(
                (datetime.datetime.now() - start_epoch).seconds,
                ", ".join(["{}={:.4f}".format(k, v)
                           for k, v in zip(self.esrgan.metrics_names, g_avg_loss)]),
                ", ".join(["{}={:.4f}".format(k, v)
                           for k, v in zip(self.discriminator.metrics_names, d_avg_loss)])
            ))
            print_losses = {"GAN": [], "D": []}

            # Run validation inference if specified
            if datapath_validation:
                validation_losses = self.generator.evaluate_generator(
                    validation_loader,
                    steps=steps_per_validation,
                    use_multiprocessing=workers > 1,
                    workers=workers
                )
                print(">> Validation Losses: {}".format(
                    ", ".join(["{}={:.4f}".format(k, v)
                               for k, v in zip(self.generator.metrics_names, validation_losses)])
                ))

        # If test images are supplied, run model on them and save to log_test_path
        if datapath_test and epoch % log_test_frequency == 0:
            plot_test_images(self.generator, test_loader, datapath_test,
                             log_test_path, epoch, modelname,
                             channels=self.channels, colorspace=self.colorspace)

        # Check if we should save the network weights
        if log_weight_frequency and epoch % log_weight_frequency == 0:
            # Save the network weights
            self.save_weights(os.path.join(log_weight_path, modelname))
def train_generator(self,
                    epochs=None, batch_size=None,
                    workers=None,
                    max_queue_size=None,
                    modelname=None,
                    datapath_train=None,
                    datapath_validation='../',
                    datapath_test='../',
                    steps_per_epoch=None,
                    steps_per_validation=None,
                    crops_per_image=None,
                    print_frequency=None,
                    log_weight_path='./model/',
                    log_tensorboard_path='./logs/',
                    log_tensorboard_update_freq=None,
                    log_test_path="./test/",
                    media_type='i'
                    ):
    """Trains the generator part of the network with MSE loss"""

    # Create data loaders
    train_loader = DataLoader(
        datapath_train, batch_size,
        self.height_hr, self.width_hr,
        self.upscaling_factor,
        crops_per_image,
        media_type,
        self.channels,
        self.colorspace
    )

    validation_loader = None
    if datapath_validation is not None:
        validation_loader = DataLoader(
            datapath_validation, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image,
            media_type,
            self.channels,
            self.colorspace
        )

    test_loader = None
    if datapath_test is not None:
        test_loader = DataLoader(
            datapath_test, 1,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            1,
            media_type,
            self.channels,
            self.colorspace
        )

    # Callback: tensorboard
    callbacks = []
    if log_tensorboard_path:
        tensorboard = TensorBoard(
            log_dir=os.path.join(log_tensorboard_path, modelname),
            histogram_freq=0,
            batch_size=batch_size,
            write_graph=True,
            write_grads=True,
            update_freq=log_tensorboard_update_freq
        )
        callbacks.append(tensorboard)
    else:
        print(">> Not logging to tensorboard since no log_tensorboard_path is set")

    # Callback: stop training when a monitored quantity has stopped improving
    earlystopping = EarlyStopping(
        monitor='val_loss',
        patience=500, verbose=1,
        restore_best_weights=True
    )
    callbacks.append(earlystopping)

    # Callback: save weights after each epoch
    modelcheckpoint = ModelCheckpoint(
        os.path.join(log_weight_path, modelname + '_{}X.h5'.format(self.upscaling_factor)),
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True
    )
    callbacks.append(modelcheckpoint)

    # Callback: reduce lr when a monitored quantity has stopped improving
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                  patience=50, min_lr=1e-5, verbose=1)
    callbacks.append(reduce_lr)

    # Learning rate scheduler
    def lr_scheduler(epoch, lr):
        factor = 0.5
        decay_step = 100  # 100 epochs * 2000 steps per epoch = 2e5 steps
        if epoch % decay_step == 0 and epoch:
            return lr * factor
        return lr

    lr_scheduler = LearningRateScheduler(lr_scheduler, verbose=1)
    callbacks.append(lr_scheduler)

    # Callback: test images plotting
    if datapath_test is not None:
        testplotting = LambdaCallback(
            on_epoch_end=lambda epoch, logs: None
            if ((epoch + 1) % print_frequency != 0) else plot_test_images(
                self.generator,
                test_loader,
                datapath_test,
                log_test_path,
                epoch + 1,
                name=modelname,
                channels=self.channels,
                colorspace=self.colorspace))
        callbacks.append(testplotting)

    # Use several workers on CPU for preparing batches
    enqueuer = OrderedEnqueuer(
        train_loader,
        use_multiprocessing=True
    )
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()

    # Fit the model
    self.generator.fit_generator(
        output_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=validation_loader,
        validation_steps=steps_per_validation,
        callbacks=callbacks,
        use_multiprocessing=False,  # the enqueuer already runs workers>1; single GPU
        workers=1
    )
def train(self,
          epochs=50,
          batch_size=8,
          steps_per_epoch=5,
          steps_per_validation=5,
          crops_per_image=4,
          print_frequency=5,
          log_tensorboard_update_freq=10,
          workers=4,
          max_queue_size=5,
          model_name='ESPCN',
          datapath_train='../../../videos_harmonic/MYANMAR_2160p/train/',
          datapath_validation='../../../videos_harmonic/MYANMAR_2160p/validation/',
          datapath_test='../../../videos_harmonic/MYANMAR_2160p/test/',
          log_weight_path='../model/',
          log_tensorboard_path='../logs/',
          log_test_path='../test/',
          media_type='i'):

    # Create data loaders
    train_loader = DataLoader(datapath_train, batch_size, self.height_hr,
                              self.width_hr, self.upscaling_factor,
                              crops_per_image, media_type, self.channels,
                              self.colorspace)

    validation_loader = None
    if datapath_validation is not None:
        validation_loader = DataLoader(datapath_validation, batch_size,
                                       self.height_hr, self.width_hr,
                                       self.upscaling_factor, crops_per_image,
                                       media_type, self.channels,
                                       self.colorspace)

    test_loader = None
    if datapath_test is not None:
        test_loader = DataLoader(datapath_test, 1, self.height_hr,
                                 self.width_hr, self.upscaling_factor, 1,
                                 media_type, self.channels, self.colorspace)

    # Callback: tensorboard
    callbacks = []
    if log_tensorboard_path:
        tensorboard = TensorBoard(log_dir=os.path.join(log_tensorboard_path, model_name),
                                  histogram_freq=0,
                                  batch_size=batch_size,
                                  write_graph=True,
                                  write_grads=True,
                                  update_freq=log_tensorboard_update_freq)
        callbacks.append(tensorboard)
    else:
        print(">> Not logging to tensorboard since no log_tensorboard_path is set")

    # Callback: stop training when a monitored quantity has stopped improving
    earlystopping = EarlyStopping(monitor='val_loss',
                                  patience=100,
                                  verbose=1,
                                  restore_best_weights=True)
    callbacks.append(earlystopping)

    # Callback: reduce lr when a monitored quantity has stopped improving
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.5,
                                  patience=100,
                                  min_lr=1e-4,
                                  verbose=1)
    callbacks.append(reduce_lr)

    # Callback: save weights after each epoch
    modelcheckpoint = ModelCheckpoint(
        os.path.join(log_weight_path, model_name + '_{}X.h5'.format(self.upscaling_factor)),
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True)
    callbacks.append(modelcheckpoint)

    # Callback: test images plotting
    if datapath_test is not None:
        testplotting = LambdaCallback(
            on_epoch_end=lambda epoch, logs: None
            if ((epoch + 1) % print_frequency != 0) else plot_test_images(
                self.model,
                test_loader,
                datapath_test,
                log_test_path,
                epoch + 1,
                name=model_name,
                channels=self.channels,
                colorspace=self.colorspace))
        callbacks.append(testplotting)

    # Use several workers on CPU for preparing batches
    enqueuer = OrderedEnqueuer(train_loader, use_multiprocessing=True)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()

    self.model.fit_generator(
        output_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=validation_loader,
        validation_steps=steps_per_validation,
        callbacks=callbacks,
        # shuffle=True,
        use_multiprocessing=False,  # workers>1 handled by the enqueuer
        workers=1
    )
def predict_generator(self,
                      generator,
                      steps,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      verbose=0):
    """Generates predictions for the input samples from a data generator.

    The generator should return the same kind of data as accepted by
    `predict_on_batch`.

    generator = DataGenerator class that returns:
        x = Input data as a 3D Tensor (batch_size, max_input_len, dim_features)
        x_len = 1D array with the length of each data in batch_size

    # Arguments
        generator: Generator yielding batches of input samples
            or an instance of Sequence (tensorflow.keras.utils.Sequence)
            object in order to avoid duplicate data
            when using multiprocessing.
        steps: Total number of steps (batches of samples)
            to yield from `generator` before stopping.
        max_queue_size: Maximum size for the generator queue.
        workers: Maximum number of processes to spin up
            when using process-based threading.
        use_multiprocessing: If `True`, use process-based threading.
            Note that because this implementation relies on multiprocessing,
            you should not pass non-picklable arguments to the generator
            as they can't be passed easily to children processes.
        verbose: verbosity mode, 0 or 1.

    # Returns
        A numpy array(s) of predictions.

    # Raises
        ValueError: In case the generator yields data in an invalid format.
    """
    self.model_pred._make_predict_function()
    is_sequence = isinstance(generator, Sequence)

    allab_outs = []
    steps_done = 0
    enqueuer = None

    try:
        if is_sequence:
            enqueuer = OrderedEnqueuer(
                generator, use_multiprocessing=use_multiprocessing)
        else:
            enqueuer = GeneratorEnqueuer(
                generator, use_multiprocessing=use_multiprocessing)

        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

        if verbose == 1:
            progbar = Progbar(target=steps)

        while steps_done < steps:
            x = next(output_generator)
            outs = self.predict_on_batch(x)

            if not isinstance(outs, list):
                outs = [outs]

            for i, out in enumerate(outs):
                allab_outs.append([int(c) for c in out])

            steps_done += 1
            if verbose == 1:
                progbar.update(steps_done)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    return allab_outs
def generator_fn():
    generator = OrderedEnqueuer(Generator(file_names), True)
    generator.start(workers=min(os.cpu_count() - 2, config.batch_size))
    while True:
        image, y_true_1, y_true_2, y_true_3 = generator.get().__next__()
        yield image, y_true_1, y_true_2, y_true_3
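# A sketch of how a generator_fn like the ones above can be consumed through
# tf.data (illustrative only: the YOLO-style shapes and dtypes in
# output_signature are placeholders, not this project's real tensor specs).
import tensorflow as tf

dataset = tf.data.Dataset.from_generator(
    generator_fn,
    output_signature=(
        tf.TensorSpec(shape=(None, 416, 416, 3), dtype=tf.float32),    # image batch
        tf.TensorSpec(shape=(None, 52, 52, 3, 85), dtype=tf.float32),  # y_true_1
        tf.TensorSpec(shape=(None, 26, 26, 3, 85), dtype=tf.float32),  # y_true_2
        tf.TensorSpec(shape=(None, 13, 13, 3, 85), dtype=tf.float32),  # y_true_3
    ))

for image, y_true_1, y_true_2, y_true_3 in dataset.take(1):
    pass  # one prepared batch pulled from the enqueuer-backed generator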
def train_srgan(self,
                epochs, batch_size,
                dataname,
                datapath_train,
                datapath_validation=None,
                steps_per_validation=1000,
                datapath_test=None,
                workers=4, max_queue_size=10,
                first_epoch=0,
                print_frequency=1,
                crops_per_image=2,
                log_weight_frequency=None,
                log_weight_path='./data/weights/',
                log_tensorboard_path='./data/logs/',
                log_tensorboard_name='SRGAN',
                log_tensorboard_update_freq=10000,
                log_test_frequency=500,
                log_test_path="./images/samples/",
                ):
    """Train the SRGAN network

    :param int epochs: how many epochs to train the network for
    :param str dataname: name to use for storing model weights etc.
    :param str datapath_train: path for the image files to use for training
    :param str datapath_test: path for the image files to use for testing / plotting
    :param int print_frequency: how often (in epochs) to print progress to terminal. Warning: will run validation inference!
    :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
    :param str log_weight_path: where should network weights be saved
    :param int log_test_frequency: how often (in epochs) should testing & validation be performed
    :param str log_test_path: where should test results be saved
    :param str log_tensorboard_path: where should tensorflow logs be sent
    :param str log_tensorboard_name: what folder should tf logs be saved under
    """

    # Create train data loader
    loader = DataLoader(datapath_train, batch_size,
                        self.height_hr, self.width_hr,
                        self.upscaling_factor,
                        crops_per_image)

    # Validation data loader
    if datapath_validation is not None:
        validation_loader = DataLoader(datapath_validation, batch_size,
                                       self.height_hr, self.width_hr,
                                       self.upscaling_factor,
                                       crops_per_image)

    # Use several workers on CPU for preparing batches
    enqueuer = OrderedEnqueuer(loader, use_multiprocessing=True, shuffle=True)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()

    # Callback: tensorboard
    if log_tensorboard_path:
        tensorboard = TensorBoard(log_dir=os.path.join(log_tensorboard_path, log_tensorboard_name),
                                  histogram_freq=0,
                                  batch_size=batch_size,
                                  write_graph=False,
                                  write_grads=False,
                                  update_freq=log_tensorboard_update_freq)
        tensorboard.set_model(self.srgan)
    else:
        print(">> Not logging to tensorboard since no log_tensorboard_path is set")

    # Callback: format input value
    def named_logs(model, logs):
        """Transform train_on_batch return value to dict expected by on_batch_end callback"""
        result = {}
        for l in zip(model.metrics_names, logs):
            result[l[0]] = l[1]
        return result

    # Shape of output from discriminator
    disciminator_output_shape = list(self.discriminator.output_shape)
    disciminator_output_shape[0] = batch_size
    disciminator_output_shape = tuple(disciminator_output_shape)

    # VALID / FAKE targets for discriminator
    real = np.ones(disciminator_output_shape)
    fake = np.zeros(disciminator_output_shape)

    # Each epoch == "update iteration" as defined in the paper
    print_losses = {"G": [], "D": []}
    start_epoch = datetime.datetime.now()

    # Random images to go through
    idxs = np.random.randint(0, len(loader), epochs)

    # Loop through epochs / iterations
    for epoch in range(first_epoch, int(epochs) + first_epoch):

        # Start epoch time
        if epoch % (print_frequency + 1) == 0:
            start_epoch = datetime.datetime.now()

        # Train discriminator
        imgs_lr, imgs_hr = next(output_generator)
        generated_hr = self.generator.predict(imgs_lr)
        real_loss = self.discriminator.train_on_batch(imgs_hr, real)
        fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
        discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

        # Train generator
        features_hr = self.vgg.predict(self.preprocess_vgg(imgs_hr))
        generator_loss = self.srgan.train_on_batch(imgs_lr, [real, features_hr])

        # Callbacks
        logs = named_logs(self.srgan, generator_loss)
        tensorboard.on_epoch_end(epoch, logs)

        # Save losses
        print_losses['G'].append(generator_loss)
        print_losses['D'].append(discriminator_loss)

        # Show the progress
        if epoch % print_frequency == 0:
            g_avg_loss = np.array(print_losses['G']).mean(axis=0)
            d_avg_loss = np.array(print_losses['D']).mean(axis=0)
            print("\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}".format(
                epoch, epochs + first_epoch,
                (datetime.datetime.now() - start_epoch).seconds,
                ", ".join(["{}={:.4f}".format(k, v)
                           for k, v in zip(self.srgan.metrics_names, g_avg_loss)]),
                ", ".join(["{}={:.4f}".format(k, v)
                           for k, v in zip(self.discriminator.metrics_names, d_avg_loss)])
            ))
            print_losses = {"G": [], "D": []}

            # Run validation inference if specified
            if datapath_validation:
                validation_losses = self.generator.evaluate_generator(
                    validation_loader,
                    steps=steps_per_validation,
                    use_multiprocessing=workers > 1,
                    workers=workers)
                print(">> Validation Losses: {}".format(
                    ", ".join(["{}={:.4f}".format(k, v)
                               for k, v in zip(self.generator.metrics_names, validation_losses)])
                ))

        # If test images are supplied, run model on them and save to log_test_path
        if datapath_test and epoch % log_test_frequency == 0:
            plot_test_images(self, loader, datapath_test, log_test_path, epoch)

        # Check if we should save the network weights
        if log_weight_frequency and epoch % log_weight_frequency == 0:
            # Save the network weights
            self.save_weights(os.path.join(log_weight_path, dataname))
def main(dataname, expr):
    # Metrics to keep track of
    train_accuracy = tf.keras.metrics.CategoricalAccuracy()
    test_accuracy = tf.keras.metrics.CategoricalAccuracy()
    train_loss = tf.keras.metrics.Mean()
    test_loss = tf.keras.metrics.Mean()

    X, Y = get_imgset_lblset(dataname)
    np.random.seed(1218)
    np.random.shuffle(X)
    np.random.seed(1218)
    np.random.shuffle(Y)

    num_classes = len(Y[0])
    print("num_classes", num_classes)
    batch_size = 16
    scores = []

    kf = KFold(n_splits=3)
    kf.get_n_splits(X)

    for k, (train_index, test_index) in enumerate(kf.split(X)):
        print("# of train: ", len(train_index))
        print("# of test: ", len(test_index))

        x_train = X[train_index]
        y_train = Y[train_index]
        x_test = X[test_index]
        y_test = Y[test_index]

        print("y_train", count_class(y_train))
        print("y_test", count_class(y_test))

        test_indices = np.arange(len(x_test))

        # Get the training data generator. We are not using a validation generator
        # because the data is already loaded in memory and we don't have to perform
        # any extra operation apart from loading the validation images and labels.
        ds = DataGenerator(x_train,
                           y_train,
                           num_classes=num_classes,
                           batch_size=batch_size,
                           jsd=expr)
        enqueuer = OrderedEnqueuer(ds, use_multiprocessing=False)
        enqueuer.start(workers=1)
        train_ds = enqueuer.get()

        plot_name = f"history_{dataname}_fold_{k}.png"
        history = utils.CTLHistory(expr=expr, filename=plot_name)
        pt = utils.ProgressTracker()

        nb_train_steps = int(np.ceil(len(x_train) / batch_size))
        nb_test_steps = int(np.ceil(len(x_test) / batch_size))

        starting_epoch = 0
        nb_epochs = 100

        save_dir_path = os.path.join("./checkpoints", f"{dataname}_fold_{k}")
        if os.path.exists(save_dir_path):
            shutil.rmtree(save_dir_path)

        total_steps = nb_train_steps * nb_epochs

        # Get the optimizer.
        # SGD with cosine lr is causing NaNs. Need to investigate more.
        optim = optimizers.Adam(learning_rate=0.0001)

        model = models.get_resnet50(num_classes)

        checkpoint_prefix = os.path.join(save_dir_path, "ckpt")
        checkpoint = tf.train.Checkpoint(optimizer=optim, model=model)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, directory=save_dir_path, max_to_keep=5)

        train_step_fn = train_step(False)

        for epoch in range(starting_epoch, nb_epochs):
            pbar = Progbar(target=nb_train_steps, interval=0.5, width=30)

            # Train for an epoch and keep track of
            # loss and accuracy for each batch.
            for bno, (images, labels) in enumerate(train_ds):
                if bno == nb_train_steps:
                    break

                if expr:
                    # Get the batch data
                    clean, aug1, aug2 = images
                    const05 = 0.5
                    randnum = np.random.uniform(0.0, 1.0)
                    if randnum > const05:
                        my_input = aug1
                    else:
                        my_input = aug2
                    # loss_value, y_pred_clean = train_step_fn(model, clean, aug1, aug2, labels, optim)
                    loss_value, y_pred_clean = train_step_fn(
                        model, my_input, labels, optim)
                else:
                    clean = images
                    loss_value, y_pred_clean = train_step_fn(
                        model, clean, labels, optim)

                # Record batch loss and batch accuracy
                train_loss(loss_value)
                train_accuracy(labels, y_pred_clean)
                pbar.update(bno + 1)

            # Validate after each epoch
            for bno in range(nb_test_steps):
                # Get the indices for the current batch
                indices = test_indices[bno * batch_size:(bno + 1) * batch_size]

                # Get the data
                images, labels = x_test[indices], y_test[indices]

                # Get the predictions and loss for this batch
                loss_value, y_pred = validate_step(model, images, labels)

                # Record batch loss and accuracy
                test_loss(loss_value)
                test_accuracy(labels, y_pred)

            # Get training and validation stats after one epoch is completed
            loss = train_loss.result()
            acc = train_accuracy.result()
            val_loss = test_loss.result()
            val_acc = test_accuracy.result()

            improved = pt.check_update(val_loss)

            # Check whether the performance of the model has improved or not
            if improved:
                print("Saving model checkpoint.")
                checkpoint.save(checkpoint_prefix)

            history.update([loss, acc], [val_loss, val_acc])

            # Print loss and accuracy values for each epoch
            # for both the training and the validation sets
            print(f"""Epoch: {epoch+1}
            train_loss: {loss:.6f}  train_acc: {acc*100:.2f}%
            test_loss: {val_loss:.6f}  test_acc: {val_acc*100:.2f}%\n""")

            history.plot_and_save(initial_epoch=starting_epoch)

            # Reset the metrics for the next epoch
            train_loss.reset_states()
            train_accuracy.reset_states()
            test_loss.reset_states()
            test_accuracy.reset_states()

        scores.append(pt.loss)
        clear_session()

    min_score = min(scores)

    if expr:
        SAVED = "saved/exp"
    else:
        SAVED = "saved/default"

    with open(f"{SAVED}/{dataname}_result.json", "w+") as jf:
        myd = {
            'dataname': dataname,
            'num_classes': num_classes,
            'score': float(min_score),
            'fold': scores.index(min_score)
        }
        json.dump(myd, jf)

    print("scores : ", scores)
def train(training_data,
          validation_data,
          batch_size=32,
          nb_epochs=100,
          min_lr=1e-5,
          max_lr=1.0,
          save_dir_path=""):

    x_train, y_train, y_train_cat = training_data
    x_test, y_test, y_test_cat = validation_data

    test_indices = np.arange(len(x_test))

    # Get the training data generator. We are not using a validation generator
    # because the data is already loaded in memory and we don't have to perform
    # any extra operation apart from loading the validation images and labels.
    ds = DataGenerator(x_train, y_train_cat, batch_size=batch_size)
    enqueuer = OrderedEnqueuer(ds, use_multiprocessing=True)
    enqueuer.start(workers=multiprocessing.cpu_count())
    train_ds = enqueuer.get()

    # Get the total number of training and validation steps
    nb_train_steps = int(np.ceil(len(x_train) / batch_size))
    nb_test_steps = int(np.ceil(len(x_test) / batch_size))

    global total_steps, lr_max, lr_min
    total_steps = nb_train_steps
    lr_max = max_lr
    lr_min = min_lr

    # Get the optimizer
    optim = optimizers.SGD(learning_rate=get_lr(0))

    # Checkpoint prefix
    checkpoint_prefix = os.path.join(save_dir_path, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=optim, model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=save_dir_path,
                                                    max_to_keep=10)

    # Check for previous checkpoints, if any
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Checkpoint restored from {}".format(
            checkpoint_manager.latest_checkpoint))
        starting_epoch = checkpoint.save_counter.numpy()
    else:
        print("Initializing from scratch.")
        starting_epoch = 0

    # Sanity check for the epoch number. For example, if someone restored a
    # checkpoint from the 15th epoch and wants to train for two more epochs,
    # then we need to encode this logic explicitly in the for loop.
    if nb_epochs <= starting_epoch:
        nb_epochs += starting_epoch

    for epoch in range(starting_epoch, nb_epochs):
        pbar = Progbar(target=nb_train_steps, interval=0.5, width=30)

        # Train for an epoch and keep track of
        # loss and accuracy for each batch.
        for bno, (images, labels) in enumerate(train_ds):
            if bno == nb_train_steps:
                break

            # Get the batch data
            clean, aug1, aug2 = images
            loss_value, y_pred_clean = train_step(clean, aug1, aug2, labels, optim)

            # Record batch loss and batch accuracy
            train_loss(loss_value)
            train_accuracy(labels, y_pred_clean)
            pbar.update(bno + 1)

        # Validate after each epoch
        for bno in range(nb_test_steps):
            # Get the indices for the current batch
            indices = test_indices[bno * batch_size:(bno + 1) * batch_size]

            # Get the data
            images, labels = x_test[indices], y_test_cat[indices]

            # Get the predictions and loss for this batch
            loss_value, y_pred = validate_step(images, labels)

            # Record batch loss and accuracy
            test_loss(loss_value)
            test_accuracy(labels, y_pred)

        # Get training and validation stats after one epoch is completed
        loss = train_loss.result()
        acc = train_accuracy.result()
        val_loss = test_loss.result()
        val_acc = test_accuracy.result()

        # Record values in the history object
        history.update([loss, acc], [val_loss, val_acc])

        # Print loss and accuracy values for each epoch
        # for both the training and the validation sets
        print(f"""Epoch: {epoch+1}
        train_loss: {loss:.6f}  train_acc: {acc*100:.2f}%
        test_loss: {val_loss:.6f}  test_acc: {val_acc*100:.2f}%\n""")

        # Get the model progress
        improved, stop_training = es.check_progress(val_loss)

        # Check whether the performance of the model has improved or not
        if improved:
            print("Saving model checkpoint.")
            checkpoint.save(checkpoint_prefix)

        if stop_training:
            break

        # Plot and save progression
        history.plot_and_save(initial_epoch=starting_epoch)

        # Reset the losses and accuracies
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

    print("")
    print("*" * 78)
    print("")