Example 1
def process(network_info, images_dir, output_dir, threshold=0.8):
    session, input_tensor, output_tensor, img_size, cls_labels = load_from_xml(
        network_info)
    print("Input tensor: {}".format(input_tensor))
    print("Output tensor: {}".format(output_tensor))
    print("Image size: {}".format(img_size))
    print("Classes: {}".format(cls_labels))
    print()
    print("Parsing source directory... (this can take some time)")
    if img_size[2] == 3:
        img_type = 'rgb'
    else:
        img_type = 'greyscale'
    gen = InferenceGenerator(images_dir, 64, img_size, img_type)

    print("Files: {}".format(len(gen.filenames)))
    print("Batches: {}".format(len(gen)))

    workers = np.min((multiprocessing.cpu_count(), 8))
    print("Workers: {}".format(workers))
    print()

    filenames = []
    cls_index = []
    cls_names = []
    score = []
    enq = OrderedEnqueuer(gen, use_multiprocessing=True)
    enq.start(workers=workers, max_queue_size=multiprocessing.cpu_count() * 4)
    output_generator = enq.get()
    for i in range(len(gen)):
        print("\r{} / {}".format(i, len(gen)), end='')
        batch_filenames, batch_images = next(output_generator)
        result = session.run(output_tensor,
                             feed_dict={input_tensor: batch_images})
        cls = np.argmax(result, axis=1)
        scr = np.max(result, axis=1)
        cls_name = [cls_labels[c] for c in cls]  # avoid shadowing the outer loop index
        filenames.extend(batch_filenames)
        cls_index.extend(cls)
        cls_names.extend(cls_name)
        score.extend(scr)
    enq.stop()
    print()
    print("Done")
    print("See {} for results".format(output_dir))

    parents = [Path(f).parent.name for f in filenames]
    files = [Path(f).name for f in filenames]

    df = pd.DataFrame(
        data={
            'filename': filenames,
            'parent': parents,
            'file': files,
            'class': cls_names,
            'class_index': cls_index,
            'score': score
        })
    os.makedirs(output_dir, exist_ok=True)
    df.to_csv(os.path.join(output_dir, "inference.csv"))
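Example 1 above follows the usual OrderedEnqueuer lifecycle: wrap a keras.utils.Sequence, start the workers, pull batches from get(), and stop() when done. Below is a minimal, self-contained sketch of that same pattern, assuming a TensorFlow 2.x installation whose tf.keras.utils still provides OrderedEnqueuer; the DummySequence class and its batch shape are purely illustrative.

import numpy as np
from tensorflow.keras.utils import OrderedEnqueuer, Sequence

class DummySequence(Sequence):
    """Toy Sequence yielding random image batches, for illustration only."""
    def __len__(self):
        return 4  # batches per epoch

    def __getitem__(self, idx):
        return np.random.rand(8, 32, 32, 3).astype("float32")

seq = DummySequence()
enq = OrderedEnqueuer(seq, use_multiprocessing=False)
enq.start(workers=2, max_queue_size=8)
batches = enq.get()              # generator that loops over the Sequence
for _ in range(len(seq)):        # exactly one pass, as process() does above
    batch = next(batches)
enq.stop()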
Example 2
def queue_train_generator(train_gen,
                          workers=1,
                          use_multiprocessing=False,
                          max_queue_size=10,
                          use_sequence_api=True):
    # all the queue stuff is from https://github.com/keras-team/keras/blob/master/keras/engine/training_generator.py
    if workers > 0:
        if use_sequence_api:
            enqueuer = OrderedEnqueuer(
                train_gen,
                use_multiprocessing=use_multiprocessing,
                # TODO: add a parameter to control this
                shuffle=False,
            )
        else:
            enqueuer = GeneratorEnqueuer(
                train_gen,
                use_multiprocessing=use_multiprocessing,
            )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        train_generator = enqueuer.get()
    else:
        if use_sequence_api:
            train_generator = iter_sequence_infinite(train_gen)
        else:
            train_generator = train_gen
    return train_generator
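A hedged usage sketch for this helper, assuming train_gen is a keras.utils.Sequence that yields (x, y) batches:

# Illustrative only: train_gen is assumed to be a keras.utils.Sequence.
train_generator = queue_train_generator(train_gen,
                                        workers=4,
                                        use_multiprocessing=True,
                                        max_queue_size=16,
                                        use_sequence_api=True)
for step in range(len(train_gen)):        # one epoch's worth of batches
    x_batch, y_batch = next(train_generator)
    # model.train_on_batch(x_batch, y_batch) would go here

Note that the enqueuer created inside the helper is never returned, so the caller has no handle for calling stop() later; returning it alongside the generator would allow a clean shutdown.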
Example 3
def get_enqueuer(csv, batch_size, FLAGS, tokenizer_wrapper, augmenter=None):
    data_generator = AugmentedImageSequence(
        dataset_csv_file=csv,
        class_names=FLAGS.csv_label_columns,
        tokenizer_wrapper=tokenizer_wrapper,
        source_image_dir=FLAGS.image_directory,
        batch_size=batch_size,
        target_size=FLAGS.image_target_size,
        augmenter=augmenter,
        shuffle_on_epoch_end=True,
    )
    enqueuer = OrderedEnqueuer(data_generator,
                               use_multiprocessing=False,
                               shuffle=False)
    return enqueuer, data_generator.steps
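get_enqueuer does not start the enqueuer, so the caller is expected to do that. A sketch of the assumed calling pattern; train_csv, FLAGS and tokenizer_wrapper come from the surrounding script, and the (images, targets) unpacking is an assumption about what AugmentedImageSequence yields:

enqueuer, steps = get_enqueuer(train_csv, 16, FLAGS, tokenizer_wrapper)
enqueuer.start(workers=4, max_queue_size=8)
batches = enqueuer.get()
for _ in range(steps):
    images, targets = next(batches)
    # training / evaluation step goes here
enqueuer.stop()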
Example 4
    def __init__(self, data_generator, batch_size, num_samples, output_dir,
                 input_shape, n_classes):
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.tensorboard_writer = tf.summary.create_file_writer(
            output_dir + "/diagnose/", flush_millis=10000)
        self.data_generator = data_generator
        self.input_shape = input_shape
        self.colors = np.array([[255, 255, 0], [255, 0, 0], [0, 255, 0],
                                [0, 0, 255], [0, 0, 0]])
        self.color_dict = {0: (0, 0, 0), 1: (0, 255, 0)}
        self.n_classes = n_classes
        self.colors = self.colors[:self.n_classes]
        is_sequence = isinstance(self.data_generator, Sequence)
        if is_sequence:
            self.enqueuer = OrderedEnqueuer(self.data_generator,
                                            use_multiprocessing=True,
                                            shuffle=False)
        else:
            self.enqueuer = GeneratorEnqueuer(self.data_generator,
                                              use_multiprocessing=True,
                                              wait_time=0.01)
        self.enqueuer.start(workers=4, max_queue_size=4)
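The constructor above only starts the enqueuer; presumably the callback later drains it to build its diagnostic summaries. A minimal sketch of such a consumer, where the method name and the (inputs, masks) batch structure are assumptions rather than part of the original code:

    def _sample_batches(self):
        """Hypothetical helper: pull enough batches to cover num_samples."""
        output_generator = self.enqueuer.get()
        n_batches = int(np.ceil(self.num_samples / self.batch_size))
        samples = []
        for _ in range(n_batches):
            inputs, masks = next(output_generator)  # assumed batch structure
            samples.append((inputs, masks))
        return samples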
Example 5
    def __init__(self,
                 data_generator,
                 batch_size,
                 num_samples,
                 output_dir,
                 normalization_mean,
                 start_index=0):
        super().__init__()
        self.data_generator = data_generator
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.tensorboard_writer = TensorboardWriter(output_dir)
        self.normalization_mean = normalization_mean
        self.start_index = start_index
        is_sequence = isinstance(self.data_generator, Sequence)
        if is_sequence:
            self.enqueuer = OrderedEnqueuer(self.data_generator,
                                            use_multiprocessing=False,
                                            shuffle=False)
        else:
            self.enqueuer = GeneratorEnqueuer(self.data_generator,
                                              use_multiprocessing=False)
        self.enqueuer.start(workers=1, max_queue_size=4)
Example 6
def generator_fn():
    generator = OrderedEnqueuer(Generator(f_names), use_multiprocessing=True)
    generator.start(workers=8, max_queue_size=10)
    while True:
        image, y_true_1, y_true_2, y_true_3 = generator.get().__next__()
        yield image, y_true_1, y_true_2, y_true_3
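generator_fn is presumably handed to tf.data.Dataset.from_generator. A hedged sketch of that wrapping; the dtypes and shapes in the output_signature are placeholders for whatever Generator actually yields:

import tensorflow as tf

# Shapes and dtypes below are assumptions; adjust them to the real Generator output.
dataset = tf.data.Dataset.from_generator(
    generator_fn,
    output_signature=(
        tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.float32),  # image
        tf.TensorSpec(shape=None, dtype=tf.float32),                    # y_true_1
        tf.TensorSpec(shape=None, dtype=tf.float32),                    # y_true_2
        tf.TensorSpec(shape=None, dtype=tf.float32),                    # y_true_3
    ),
).prefetch(tf.data.AUTOTUNE)

Also note that the enqueuer started inside generator_fn is never stopped, so its workers live for as long as the dataset is in use.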
Example 7
    def train_esrgan(self, 
        epochs=None, batch_size=16, 
        modelname=None, 
        datapath_train=None,
        datapath_validation=None, 
        steps_per_validation=1000,
        datapath_test=None, 
        workers=4, max_queue_size=10,
        first_epoch=0,
        print_frequency=1,
        crops_per_image=2,
        log_weight_frequency=None, 
        log_weight_path='./model/', 
        log_tensorboard_path='./data/logs/',
        log_tensorboard_update_freq=10,
        log_test_frequency=500,
        log_test_path="./images/samples/", 
        media_type='i'        
    ):
        """Train the ESRGAN network

        :param int epochs: how many epochs to train the network for
        :param str modelname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int print_frequency: how often (in epochs) to print progress to terminal. Warning: will run validation inference!
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str log_weight_path: where should network weights be saved
        :param int log_test_frequency: how often (in epochs) should testing & validation be performed
        :param str log_test_path: where should test results be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent
        """

        
        
        # Create data loaders
        train_loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image,
            media_type,
            self.channels,
            self.colorspace
        )

        # Validation data loader
        validation_loader = None 
        if datapath_validation is not None:
            validation_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image,
                media_type,
                self.channels,
                self.colorspace
        )

        test_loader = None
        if datapath_test is not None:
            test_loader = DataLoader(
                datapath_test, 1,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                1,
                media_type,
                self.channels,
                self.colorspace
        )
    
        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            train_loader,
            use_multiprocessing=True,
            shuffle=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()
        
        # Callback: tensorboard
        if log_tensorboard_path:
            tensorboard = TensorBoard(
                log_dir=os.path.join(log_tensorboard_path, modelname),
                histogram_freq=0,
                batch_size=batch_size,
                write_graph=True,
                write_grads=True,
                update_freq=log_tensorboard_update_freq
            )
            tensorboard.set_model(self.esrgan)
        else:
            print(">> Not logging to tensorboard since no log_tensorboard_path is set")

        # Learning rate scheduler
        def lr_scheduler(epoch, lr):
            factor = 0.5
            decay_step =  [50000,100000,200000,300000]  
            if epoch in decay_step and epoch:
                return lr * factor
            return lr
        lr_scheduler_gan = LearningRateScheduler(lr_scheduler, verbose=1)
        lr_scheduler_gan.set_model(self.esrgan)
        lr_scheduler_gen = LearningRateScheduler(lr_scheduler, verbose=0)
        lr_scheduler_gen.set_model(self.generator)
        lr_scheduler_dis = LearningRateScheduler(lr_scheduler, verbose=0)
        lr_scheduler_dis.set_model(self.discriminator)
        lr_scheduler_ra = LearningRateScheduler(lr_scheduler, verbose=0)
        lr_scheduler_ra.set_model(self.ra_discriminator)

        
        # Callback: format input value
        def named_logs(model, logs):
            """Transform train_on_batch return value to dict expected by on_batch_end callback"""
            result = {}
            for l in zip(model.metrics_names, logs):
                result[l[0]] = l[1]
            return result

        # Shape of output from discriminator
        discriminator_output_shape = list(self.ra_discriminator.output_shape)
        discriminator_output_shape[0] = batch_size
        discriminator_output_shape = tuple(discriminator_output_shape)

        # VALID / FAKE targets for discriminator
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)
               

        # Each epoch == "update iteration" as defined in the paper        
        print_losses = {"GAN": [], "D": []}
        start_epoch = datetime.datetime.now()
        
        # Random images to go through
        #idxs = np.random.randint(0, len(train_loader), epochs)        
        
        # Loop through epochs / iterations
        for epoch in range(first_epoch, int(epochs)+first_epoch):
            lr_scheduler_gan.on_epoch_begin(epoch)
            lr_scheduler_ra.on_epoch_begin(epoch)
            lr_scheduler_dis.on_epoch_begin(epoch)
            lr_scheduler_gen.on_epoch_begin(epoch)

            # Start epoch time
            if epoch % print_frequency == 0:
                print("\nEpoch {}/{}:".format(epoch+1, epochs+first_epoch))
                start_epoch = datetime.datetime.now()            

            # Train discriminator 
            self.discriminator.trainable = True
            self.ra_discriminator.trainable = True
            
            imgs_lr, imgs_hr = next(output_generator)
            generated_hr = self.generator.predict(imgs_lr)

            real_loss = self.ra_discriminator.train_on_batch([imgs_hr,generated_hr], real)
            #print("Real: ",real_loss)
            fake_loss = self.ra_discriminator.train_on_batch([generated_hr,imgs_hr], fake)
            #print("Fake: ",fake_loss)
            discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

            # Train generator
            self.discriminator.trainable = False
            self.ra_discriminator.trainable = False
            
            #for _ in tqdm(range(10),ncols=60,desc=">> Training generator"):
            imgs_lr, imgs_hr = next(output_generator)
            gan_loss = self.esrgan.train_on_batch([imgs_lr,imgs_hr], [imgs_hr,real,imgs_hr])
     
            # Callbacks
            logs = named_logs(self.esrgan, gan_loss)
            if log_tensorboard_path:
                tensorboard.on_epoch_end(epoch, logs)
            

            # Save losses            
            print_losses['GAN'].append(gan_loss)
            print_losses['D'].append(discriminator_loss)

            # Show the progress
            if epoch % print_frequency == 0:
                g_avg_loss = np.array(print_losses['GAN']).mean(axis=0)
                d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                print(">> Time: {}s\n>> GAN: {}\n>> Discriminator: {}".format(
                    (datetime.datetime.now() - start_epoch).seconds,
                    ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.esrgan.metrics_names, g_avg_loss)]),
                    ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.discriminator.metrics_names, d_avg_loss)])
                ))
                print_losses = {"GAN": [], "D": []}

                # Run validation inference if specified
                if datapath_validation:
                    validation_losses = self.generator.evaluate_generator(
                        validation_loader,
                        steps=steps_per_validation,
                        use_multiprocessing=workers>1,
                        workers=workers
                    )
                    print(">> Validation Losses: {}".format(
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.generator.metrics_names, validation_losses)])
                    ))                

            # If test images are supplied, run model on them and save to log_test_path
            if datapath_test and epoch % log_test_frequency == 0:
                plot_test_images(self.generator, test_loader, datapath_test, log_test_path, epoch, modelname,
                                 channels=self.channels, colorspace=self.colorspace)

            # Check if we should save the network weights
            if log_weight_frequency and epoch % log_weight_frequency == 0:
                # Save the network weights
                self.save_weights(os.path.join(log_weight_path, modelname))
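For reference, a hedged invocation sketch; gan stands for an already-constructed instance of the surrounding ESRGAN wrapper class, and every value below is a placeholder:

# Illustrative call only; argument values are placeholders.
gan.train_esrgan(
    epochs=100000,
    batch_size=16,
    modelname='ESRGAN',
    datapath_train='./data/train/',
    datapath_validation='./data/validation/',
    datapath_test='./data/test/',
    workers=4,
    max_queue_size=10,
    print_frequency=100,
    log_weight_frequency=1000,
)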
Example 8
    def train_generator(self,
        epochs=None, batch_size=None,
        workers=None,
        max_queue_size=None,
        modelname=None, 
        datapath_train=None,
        datapath_validation='../',
        datapath_test='../',
        steps_per_epoch=None,
        steps_per_validation=None,
        crops_per_image=None,
        print_frequency=None,
        log_weight_path='./model/', 
        log_tensorboard_path='./logs/',
        log_tensorboard_update_freq=None,
        log_test_path="./test/",
        media_type='i'
    ):
        """Trains the generator part of the network with MSE loss"""


        # Create data loaders
        train_loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image,
            media_type,
            self.channels,
            self.colorspace
        )

        
        validation_loader = None 
        if datapath_validation is not None:
            validation_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image,
                media_type,
                self.channels,
                self.colorspace
        )

        test_loader = None
        if datapath_test is not None:
            test_loader = DataLoader(
                datapath_test, 1,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                1,
                media_type,
                self.channels,
                self.colorspace
        )

        
        # Callback: tensorboard
        callbacks = []
        if log_tensorboard_path:
            tensorboard = TensorBoard(
                log_dir=os.path.join(log_tensorboard_path, modelname),
                histogram_freq=0,
                batch_size=batch_size,
                write_graph=True,
                write_grads=True,
                update_freq=log_tensorboard_update_freq
            )
            callbacks.append(tensorboard)
        else:
            print(">> Not logging to tensorboard since no log_tensorboard_path is set")

        # Callback: Stop training when a monitored quantity has stopped improving
        earlystopping = EarlyStopping(
            monitor='val_loss',
            patience=500, verbose=1,
            restore_best_weights=True
        )
        callbacks.append(earlystopping)
        
        # Callback: save weights after each epoch
        modelcheckpoint = ModelCheckpoint(
            os.path.join(log_weight_path, modelname + '_{}X.h5'.format(self.upscaling_factor)), 
            monitor='val_loss', 
            save_best_only=True, 
            save_weights_only=True
        )
        callbacks.append(modelcheckpoint)

        # Callback: Reduce lr when a monitored quantity has stopped improving
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                    patience=50, min_lr=1e-5,verbose=1)
        callbacks.append(reduce_lr)

        # Learning rate scheduler
        def lr_scheduler(epoch, lr):
            factor = 0.5
            decay_step = 100 #100 epochs * 2000 step per epoch = 2x1e5
            if epoch % decay_step == 0 and epoch:
                return lr * factor
            return lr
        lr_scheduler = LearningRateScheduler(lr_scheduler, verbose=1)
        callbacks.append(lr_scheduler)


        # Callback: test images plotting
        if datapath_test is not None:
            testplotting = LambdaCallback(
                on_epoch_end=lambda epoch, logs: None if ((epoch+1) % print_frequency != 0 ) else plot_test_images(
                    self.generator,
                    test_loader,
                    datapath_test,
                    log_test_path,
                    epoch+1,
                    name=modelname,
                    channels=self.channels,
                    colorspace=self.colorspace))
            callbacks.append(testplotting)

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            train_loader,
            use_multiprocessing=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

                            
        # Fit the model
        self.generator.fit_generator(
            output_generator,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=validation_loader,
            validation_steps=steps_per_validation,
            callbacks=callbacks,
            use_multiprocessing=False, #workers>1 because single gpu
            workers=1
        )
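On newer TensorFlow releases fit_generator is deprecated; the same training call can be written with model.fit, which accepts generators and Sequence objects directly. A hedged sketch of the equivalent call on TF 2.x:

# Sketch of the equivalent call with the newer API (TF >= 2.1).
self.generator.fit(
    output_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_loader,
    validation_steps=steps_per_validation,
    callbacks=callbacks,
    workers=1,
    use_multiprocessing=False,
)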
Example 9
    def train(
            self,
            epochs=50,
            batch_size=8,
            steps_per_epoch=5,
            steps_per_validation=5,
            crops_per_image=4,
            print_frequency=5,
            log_tensorboard_update_freq=10,
            workers=4,
            max_queue_size=5,
            model_name='ESPCN',
            datapath_train='../../../videos_harmonic/MYANMAR_2160p/train/',
            datapath_validation='../../../videos_harmonic/MYANMAR_2160p/validation/',
            datapath_test='../../../videos_harmonic/MYANMAR_2160p/test/',
            log_weight_path='../model/',
            log_tensorboard_path='../logs/',
            log_test_path='../test/',
            media_type='i'):

        # Create data loaders
        train_loader = DataLoader(datapath_train, batch_size, self.height_hr,
                                  self.width_hr, self.upscaling_factor,
                                  crops_per_image, media_type, self.channels,
                                  self.colorspace)

        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image, media_type,
                                           self.channels, self.colorspace)

        test_loader = None
        if datapath_test is not None:
            test_loader = DataLoader(datapath_test, 1, self.height_hr,
                                     self.width_hr, self.upscaling_factor, 1,
                                     media_type, self.channels,
                                     self.colorspace)

        # Callback: tensorboard
        callbacks = []
        if log_tensorboard_path:
            tensorboard = TensorBoard(log_dir=os.path.join(
                log_tensorboard_path, model_name),
                                      histogram_freq=0,
                                      batch_size=batch_size,
                                      write_graph=True,
                                      write_grads=True,
                                      update_freq=log_tensorboard_update_freq)
            callbacks.append(tensorboard)
        else:
            print(
                ">> Not logging to tensorboard since no log_tensorboard_path is set"
            )

        # Callback: Stop training when a monitored quantity has stopped improving
        earlystopping = EarlyStopping(monitor='val_loss',
                                      patience=100,
                                      verbose=1,
                                      restore_best_weights=True)
        callbacks.append(earlystopping)

        # Callback: Reduce lr when a monitored quantity has stopped improving
        reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                      factor=0.5,
                                      patience=100,
                                      min_lr=1e-4,
                                      verbose=1)
        callbacks.append(reduce_lr)

        # Callback: save weights after each epoch
        modelcheckpoint = ModelCheckpoint(os.path.join(
            log_weight_path,
            model_name + '_{}X.h5'.format(self.upscaling_factor)),
                                          monitor='val_loss',
                                          save_best_only=True,
                                          save_weights_only=True)
        callbacks.append(modelcheckpoint)

        # Callback: test images plotting
        if datapath_test is not None:
            testplotting = LambdaCallback(
                on_epoch_end=lambda epoch, logs: None
                if ((epoch + 1) % print_frequency != 0) else plot_test_images(
                    self.model,
                    test_loader,
                    datapath_test,
                    log_test_path,
                    epoch + 1,
                    name=model_name,
                    channels=self.channels,
                    colorspace=self.colorspace))
            callbacks.append(testplotting)

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(train_loader, use_multiprocessing=True)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

        self.model.fit_generator(
            output_generator,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=validation_loader,
            validation_steps=steps_per_validation,
            callbacks=callbacks,
            #shuffle=True,
            use_multiprocessing=False,  #workers>1 
            workers=1  #workers
        )
Example 10
    def predict_generator(self,
                          generator,
                          steps,
                          max_queue_size=10,
                          workers=1,
                          use_multiprocessing=False,
                          verbose=0):
        """Generates predictions for the input samples from a data generator.
        The generator should return the same kind of data as accepted by `predict_on_batch`.

        generator = DataGenerator class that returns:
            x = Input data as a 3D Tensor (batch_size, max_input_len, dim_features)
            x_len = 1D array with the length of each data in batch_size

        # Arguments
            generator: Generator yielding batches of input samples
                    or an instance of Sequence (tensorflow.keras.utils.Sequence)
                    object in order to avoid duplicate data
                    when using multiprocessing.
            steps:
                Total number of steps (batches of samples)
                to yield from `generator` before stopping.
            max_queue_size:
                Maximum size for the generator queue.
            workers: Maximum number of processes to spin up
                when using process based threading
            use_multiprocessing: If `True`, use process based threading.
                Note that because this implementation relies on multiprocessing,
                you should not pass non picklable arguments to the generator
                as they can't be passed easily to children processes.
            verbose:
                verbosity mode, 0 or 1.

        # Returns
            A list with one decoded label sequence (a list of ints) per input sample.

        # Raises
            ValueError: In case the generator yields
                data in an invalid format.
        """

        self.model_pred._make_predict_function()
        is_sequence = isinstance(generator, Sequence)

        allab_outs = []
        steps_done = 0
        enqueuer = None

        try:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)

            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()

            if verbose == 1:
                progbar = Progbar(target=steps)

            while steps_done < steps:
                x = next(output_generator)
                outs = self.predict_on_batch(x)

                if not isinstance(outs, list):
                    outs = [outs]

                for out in outs:
                    allab_outs.append([int(c) for c in out])

                steps_done += 1
                if verbose == 1:
                    progbar.update(steps_done)

        finally:
            if enqueuer is not None:
                enqueuer.stop()

        return allab_outs
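A hedged usage sketch; ctc_model stands for an instance of the surrounding class and test_sequence for a keras.utils.Sequence yielding the (x, x_len) batches described in the docstring:

# Illustrative only: ctc_model and test_sequence are placeholders.
predictions = ctc_model.predict_generator(test_sequence,
                                          steps=len(test_sequence),
                                          max_queue_size=10,
                                          workers=2,
                                          use_multiprocessing=False,
                                          verbose=1)
# predictions holds one decoded label sequence (a list of ints) per input sample.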
Example 11
def generator_fn():
    generator = OrderedEnqueuer(Generator(file_names), use_multiprocessing=True)
    generator.start(workers=min(os.cpu_count() - 2, config.batch_size))
    while True:
        image, y_true_1, y_true_2, y_true_3 = generator.get().__next__()
        yield image, y_true_1, y_true_2, y_true_3
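min(os.cpu_count() - 2, config.batch_size) can drop to zero or below on small machines, and os.cpu_count() may return None. A defensive variant, as a sketch:

import os

def safe_worker_count(batch_size, reserve=2):
    """Clamp the enqueuer worker count to at least 1; os.cpu_count() may return None."""
    return max(1, min((os.cpu_count() or 1) - reserve, batch_size))

# e.g. generator.start(workers=safe_worker_count(config.batch_size), max_queue_size=10)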
Example 12
    def train_srgan(
        self,
        epochs,
        batch_size,
        dataname,
        datapath_train,
        datapath_validation=None,
        steps_per_validation=1000,
        datapath_test=None,
        workers=4,
        max_queue_size=10,
        first_epoch=0,
        print_frequency=1,
        crops_per_image=2,
        log_weight_frequency=None,
        log_weight_path='./data/weights/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_name='SRGAN',
        log_tensorboard_update_freq=10000,
        log_test_frequency=500,
        log_test_path="./images/samples/",
    ):
        """Train the SRGAN network

        :param int epochs: how many epochs to train the network for
        :param str dataname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int print_frequency: how often (in epochs) to print progress to terminal. Warning: will run validation inference!
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str log_weight_path: where should network weights be saved
        :param int log_test_frequency: how often (in epochs) should testing & validation be performed
        :param str log_test_path: where should test results be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent
        :param str log_tensorboard_name: what folder should tf logs be saved under        
        """

        # Create train data loader
        loader = DataLoader(datapath_train, batch_size, self.height_hr,
                            self.width_hr, self.upscaling_factor,
                            crops_per_image)

        # Validation data loader
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image)

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(loader,
                                   use_multiprocessing=True,
                                   shuffle=True)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

        # Callback: tensorboard
        if log_tensorboard_path:
            tensorboard = TensorBoard(log_dir=os.path.join(
                log_tensorboard_path, log_tensorboard_name),
                                      histogram_freq=0,
                                      batch_size=batch_size,
                                      write_graph=False,
                                      write_grads=False,
                                      update_freq=log_tensorboard_update_freq)
            tensorboard.set_model(self.srgan)
        else:
            print(
                ">> Not logging to tensorboard since no log_tensorboard_path is set"
            )

        # Callback: format input value
        def named_logs(model, logs):
            """Transform train_on_batch return value to dict expected by on_batch_end callback"""
            result = {}
            for l in zip(model.metrics_names, logs):
                result[l[0]] = l[1]
            return result

        # Shape of output from discriminator
        discriminator_output_shape = list(self.discriminator.output_shape)
        discriminator_output_shape[0] = batch_size
        discriminator_output_shape = tuple(discriminator_output_shape)

        # VALID / FAKE targets for discriminator
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)

        # Each epoch == "update iteration" as defined in the paper
        print_losses = {"G": [], "D": []}
        start_epoch = datetime.datetime.now()

        # Random images to go through
        idxs = np.random.randint(0, len(loader), epochs)

        # Loop through epochs / iterations
        for epoch in range(first_epoch, int(epochs) + first_epoch):

            # Start epoch time
            if epoch % (print_frequency + 1) == 0:
                start_epoch = datetime.datetime.now()

            # Train discriminator
            imgs_lr, imgs_hr = next(output_generator)
            generated_hr = self.generator.predict(imgs_lr)
            real_loss = self.discriminator.train_on_batch(imgs_hr, real)
            fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
            discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

            # Train generator
            features_hr = self.vgg.predict(self.preprocess_vgg(imgs_hr))
            generator_loss = self.srgan.train_on_batch(imgs_lr,
                                                       [real, features_hr])

            # Callbacks
            logs = named_logs(self.srgan, generator_loss)
            if log_tensorboard_path:
                tensorboard.on_epoch_end(epoch, logs)

            # Save losses
            print_losses['G'].append(generator_loss)
            print_losses['D'].append(discriminator_loss)

            # Show the progress
            if epoch % print_frequency == 0:
                g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                print(
                    "\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}"
                    .format(
                        epoch, epochs + first_epoch,
                        (datetime.datetime.now() - start_epoch).seconds,
                        ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.srgan.metrics_names, g_avg_loss)
                        ]), ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.discriminator.metrics_names, d_avg_loss)
                        ])))
                print_losses = {"G": [], "D": []}

                # Run validation inference if specified
                if datapath_validation:
                    validation_losses = self.generator.evaluate_generator(
                        validation_loader,
                        steps=steps_per_validation,
                        use_multiprocessing=workers > 1,
                        workers=workers)
                    print(">> Validation Losses: {}".format(", ".join([
                        "{}={:.4f}".format(k, v) for k, v in zip(
                            self.generator.metrics_names, validation_losses)
                    ])))

            # If test images are supplied, run model on them and save to log_test_path
            if datapath_test and epoch % log_test_frequency == 0:
                plot_test_images(self, loader, datapath_test, log_test_path,
                                 epoch)

            # Check if we should save the network weights
            if log_weight_frequency and epoch % log_weight_frequency == 0:

                # Save the network weights
                self.save_weights(os.path.join(log_weight_path, dataname))
Example 13
def main(dataname, expr):
    # metric to keep track of
    train_accuracy = tf.keras.metrics.CategoricalAccuracy()
    test_accuracy = tf.keras.metrics.CategoricalAccuracy()
    train_loss = tf.keras.metrics.Mean()
    test_loss = tf.keras.metrics.Mean()

    X, Y = get_imgset_lblset(dataname)
    np.random.seed(1218)
    np.random.shuffle(X)
    np.random.seed(1218)
    np.random.shuffle(Y)

    num_classes = len(Y[0])
    print("num_classes", num_classes)
    batch_size = 16

    scores = []
    kf = KFold(n_splits=3)
    kf.get_n_splits(X)
    for k, (train_index, test_index) in enumerate(kf.split(X)):
        print("# of train: ", len(train_index))
        print("# of test: ", len(test_index))
        x_train = X[train_index]
        y_train = Y[train_index]
        x_test = X[test_index]
        y_test = Y[test_index]

        print("y_train", count_class(y_train))
        print("y_test", count_class(y_test))

        test_indices = np.arange(len(x_test))

        # get the training data generator. We are not using validation generator because the
        # data is already loaded in memory and we don't have to perform any extra operation
        # apart from loading the validation images and validation labels.
        ds = DataGenerator(x_train,
                           y_train,
                           num_classes=num_classes,
                           batch_size=batch_size,
                           jsd=expr)
        enqueuer = OrderedEnqueuer(ds, use_multiprocessing=False)
        enqueuer.start(workers=1)
        train_ds = enqueuer.get()

        plot_name = f"history_{dataname}_fold_{k}.png"
        history = utils.CTLHistory(expr=expr, filename=plot_name)

        pt = utils.ProgressTracker()

        nb_train_steps = int(np.ceil(len(x_train) / batch_size))
        nb_test_steps = int(np.ceil(len(x_test) / batch_size))

        starting_epoch = 0
        nb_epochs = 100

        save_dir_path = os.path.join("./checkpoints", f"{dataname}_fold_{k}")
        if os.path.exists(save_dir_path):
            shutil.rmtree(save_dir_path)

        total_steps = nb_train_steps * nb_epochs

        # get the optimizer
        # SGD with cosine lr is causing NaNs. Need to investigate more
        optim = optimizers.Adam(learning_rate=0.0001)
        model = models.get_resnet50(num_classes)

        checkpoint_prefix = os.path.join(save_dir_path, "ckpt")
        checkpoint = tf.train.Checkpoint(optimizer=optim, model=model)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, directory=save_dir_path, max_to_keep=5)

        train_step_fn = train_step(False)
        for epoch in range(starting_epoch, nb_epochs):
            pbar = Progbar(target=nb_train_steps, interval=0.5, width=30)
            # Train for an epoch and keep track of
            # loss and accuracy for each batch.
            for bno, (images, labels) in enumerate(train_ds):
                if bno == nb_train_steps:
                    break

                if expr:
                    # Get the batch data
                    clean, aug1, aug2 = images

                    const05 = 0.5
                    randnum = np.random.uniform(0.0, 1.0)
                    if randnum > const05:
                        my_input = aug1
                    else:
                        my_input = aug2

                    # loss_value, y_pred_clean = train_step_fn(model, clean, aug1, aug2, labels, optim)
                    loss_value, y_pred_clean = train_step_fn(
                        model, my_input, labels, optim)
                else:
                    clean = images
                    loss_value, y_pred_clean = train_step_fn(
                        model, clean, labels, optim)

                # Record batch loss and batch accuracy
                train_loss(loss_value)
                train_accuracy(labels, y_pred_clean)
                pbar.update(bno + 1)

            # Validate after each epoch
            for bno in range(nb_test_steps):
                # Get the indices for the current batch
                indices = test_indices[bno * batch_size:(bno + 1) * batch_size]

                # Get the data
                images, labels = x_test[indices], y_test[indices]

                # Get the predictions and loss for this batch
                loss_value, y_pred = validate_step(model, images, labels)

                # Record batch loss and accuracy
                test_loss(loss_value)
                test_accuracy(labels, y_pred)

            # get training and validation stats
            # after one epoch is completed
            loss = train_loss.result()
            acc = train_accuracy.result()
            val_loss = test_loss.result()
            val_acc = test_accuracy.result()

            improved = pt.check_update(val_loss)
            # check if performance of model has improved or not
            if improved:
                print("Saving model checkpoint.")
                checkpoint.save(checkpoint_prefix)

            history.update([loss, acc], [val_loss, val_acc])
            # print loss values and accuracy values for each epoch
            # for both training as well as validation sets
            print(f"""Epoch: {epoch+1} 
                    train_loss: {loss:.6f}  train_acc: {acc*100:.2f}%  
                    test_loss:  {val_loss:.6f}  test_acc:  {val_acc*100:.2f}%\n"""
                  )

            history.plot_and_save(initial_epoch=starting_epoch)

            train_loss.reset_states()
            train_accuracy.reset_states()
            test_loss.reset_states()
            test_accuracy.reset_states()

        scores.append(pt.loss)
        clear_session()

    min_score = min(scores)
    if expr:
        SAVED = "saved/exp"
    else:
        SAVED = "saved/default"
    with open(f"{SAVED}/{dataname}_result.json", "w+") as jf:
        myd = {
            'dataname': dataname,
            'num_classes': num_classes,
            'score': float(min_score),
            'fold': scores.index(min_score)
        }
        json.dump(myd, jf)

    print("scores : ", scores)
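Example 13 saves a checkpoint whenever validation loss improves; a sketch of restoring the latest saved state afterwards with the same tf.train.Checkpoint machinery (optim, model and save_dir_path refer to the objects used above):

# Illustrative restore of the most recent checkpoint written during training.
checkpoint = tf.train.Checkpoint(optimizer=optim, model=model)
latest = tf.train.latest_checkpoint(save_dir_path)
if latest is not None:
    checkpoint.restore(latest).expect_partial()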
Example 14
def train(training_data,
          validation_data,
          batch_size=32,
          nb_epochs=100,
          min_lr=1e-5,
          max_lr=1.0,
          save_dir_path=""):

    x_train, y_train, y_train_cat = training_data
    x_test, y_test, y_test_cat = validation_data
    test_indices = np.arange(len(x_test))

    # get the training data generator. We are not using validation generator because the
    # data is already loaded in memory and we don't have to perform any extra operation
    # apart from loading the validation images and validation labels.
    ds = DataGenerator(x_train, y_train_cat, batch_size=batch_size)
    enqueuer = OrderedEnqueuer(ds, use_multiprocessing=True)
    enqueuer.start(workers=multiprocessing.cpu_count())
    train_ds = enqueuer.get()

    # get the total number of training and validation steps
    nb_train_steps = int(np.ceil(len(x_train) / batch_size))
    nb_test_steps = int(np.ceil(len(x_test) / batch_size))

    global total_steps, lr_max, lr_min
    total_steps = nb_train_steps
    lr_max = max_lr
    lr_min = min_lr

    # get the optimizer
    optim = optimizers.SGD(learning_rate=get_lr(0))

    # checkpoint prefix
    checkpoint_prefix = os.path.join(save_dir_path, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=optim, model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=save_dir_path,
                                                    max_to_keep=10)

    # check for previous checkpoints, if any
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Checkpoint restored from {}".format(
            checkpoint_manager.latest_checkpoint))
        starting_epoch = checkpoint.save_counter.numpy()
    else:
        print("Initializing from scratch.")
        starting_epoch = 0

    # sanity check for epoch number. For example, if someone restored a checkpoint
    # from 15th epoch and want to train for two more epochs, then we need to explicitly
    # encode this logic in the for loop
    if nb_epochs <= starting_epoch:
        nb_epochs += starting_epoch

    for epoch in range(starting_epoch, nb_epochs):
        pbar = Progbar(target=nb_train_steps, interval=0.5, width=30)

        # Train for an epoch and keep track of
        # loss and accuracy for each batch.
        for bno, (images, labels) in enumerate(train_ds):
            if bno == nb_train_steps:
                break

            # Get the batch data
            clean, aug1, aug2 = images
            loss_value, y_pred_clean = train_step(clean, aug1, aug2, labels,
                                                  optim)

            # Record batch loss and batch accuracy
            train_loss(loss_value)
            train_accuracy(labels, y_pred_clean)
            pbar.update(bno + 1)

        # Validate after each epoch
        for bno in range(nb_test_steps):
            # Get the indices for the current batch
            indices = test_indices[bno * batch_size:(bno + 1) * batch_size]

            # Get the data
            images, labels = x_test[indices], y_test_cat[indices]

            # Get the predictions and loss for this batch
            loss_value, y_pred = validate_step(images, labels)

            # Record batch loss and accuracy
            test_loss(loss_value)
            test_accuracy(labels, y_pred)

        # get training and validation stats
        # after one epoch is completed
        loss = train_loss.result()
        acc = train_accuracy.result()
        val_loss = test_loss.result()
        val_acc = test_accuracy.result()

        # record values in the history object
        history.update([loss, acc], [val_loss, val_acc])

        # print loss values and accuracy values for each epoch
        # for both training as well as validation sets
        print(f"""Epoch: {epoch+1} 
                train_loss: {loss:.6f}  train_acc: {acc*100:.2f}%  
                test_loss:  {val_loss:.6f}  test_acc:  {val_acc*100:.2f}%\n""")

        # get the model progress
        improved, stop_training = es.check_progress(val_loss)
        # check if performance of model has improved or not
        if improved:
            print("Saving model checkpoint.")
            checkpoint.save(checkpoint_prefix)
        if stop_training:
            break

        # plot and save progression
        history.plot_and_save(initial_epoch=starting_epoch)

        # Reset the losses and accuracy
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()
        print("")
        print("*" * 78)
        print("")
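Example 14 relies on a get_lr helper together with the globals total_steps, lr_max and lr_min set inside train; the helper itself is not shown. A plausible cosine-decay sketch consistent with those names, written with explicit parameters so it is self-contained (an assumption, not the author's code):

import numpy as np

def get_lr(step, total_steps=1000, lr_max=1.0, lr_min=1e-5):
    """Cosine decay from lr_max to lr_min over total_steps (illustrative only)."""
    progress = np.clip(step / float(total_steps), 0.0, 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + np.cos(np.pi * progress))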