Exemplo n.º 1
0
    def test(self,
        refer_model=None,
        batch_size=4,
        datapath_test='./images/val_dir',
        crops_per_image=1,
        log_test_path="./images/test/"
    ):
        """Plot the network's output on a set of test images (no training).

        Builds a data loader over ``datapath_test`` and hands it, together
        with this model, to ``plot_test_images``, which receives
        ``log_test_path`` as its output location and epoch index 0.

        :param refer_model: optional extra model forwarded to
            ``plot_test_images`` as ``refer_model`` (presumably a reference
            model for comparison plots — confirm in ``plot_test_images``)
        :param int batch_size: batch size for the data loader
        :param str datapath_test: directory holding the test images
        :param int crops_per_image: crops per image, passed to the data loader
        :param str log_test_path: output location passed to ``plot_test_images``
        """

        # Create data loaders
        loader = DataLoader(
            datapath_test, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image
        )
        print(">> Ploting test images")
        plot_test_images(self, loader, datapath_test, log_test_path, 0, refer_model=refer_model)
Exemplo n.º 2
0
    def train_generator(self,
                        epochs,
                        batch_size,
                        workers,
                        dataname,
                        datapath_train,
                        datapath_validation=None,
                        datapath_test=None,
                        steps_per_epoch=1000,
                        steps_per_validation=1000,
                        crops_per_image=2,
                        log_weight_path='./data/weights/',
                        log_tensorboard_path='./data/logs/',
                        log_tensorboard_name='SRResNet',
                        log_tensorboard_update_freq=10000,
                        log_test_path="./images/samples/"):
        """Train only the generator part of the network (SRResNet stage).

        Fits ``self.generator`` with ``fit_generator`` using the loss it was
        compiled with (presumably MSE — confirm where the generator is
        compiled).

        :param int epochs: number of Keras epochs to train for
        :param int batch_size: batch size for the data loaders
        :param int workers: number of loader workers; multiprocessing is
            enabled when this is greater than 1
        :param str dataname: prefix of the checkpointed weight file
        :param str datapath_train: directory with training images
        :param str datapath_validation: directory with validation images,
            or None to train without validation
        :param str datapath_test: directory with images plotted at the end
            of every epoch, or None to skip plotting
        :param int steps_per_epoch: training batches per epoch
        :param int steps_per_validation: validation batches per epoch
        :param int crops_per_image: crops per image, passed to the data loaders
        :param str log_weight_path: directory for the checkpointed weights
        :param str log_tensorboard_path: root directory for TensorBoard logs;
            falsy to disable TensorBoard logging
        :param str log_tensorboard_name: sub-directory for this run's logs
        :param int log_tensorboard_update_freq: passed to TensorBoard as
            ``update_freq``
        :param str log_test_path: output location passed to ``plot_test_images``
        """

        # Create data loaders
        train_loader = DataLoader(datapath_train, batch_size, self.height_hr,
                                  self.width_hr, self.upscaling_factor,
                                  crops_per_image)
        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image)

        callbacks = []

        # Callback: tensorboard
        if log_tensorboard_path:
            tensorboard = TensorBoard(log_dir=os.path.join(
                log_tensorboard_path, log_tensorboard_name),
                                      histogram_freq=0,
                                      batch_size=batch_size,
                                      write_graph=False,
                                      write_grads=False,
                                      update_freq=log_tensorboard_update_freq)
            callbacks.append(tensorboard)
        else:
            print(
                ">> Not logging to tensorboard since no log_tensorboard_path is set"
            )

        # Callback: save weights after each epoch. 'val_loss' is only present
        # in the epoch logs when validation data is supplied; monitoring it
        # without validation makes ModelCheckpoint(save_best_only=True) skip
        # every save, so fall back to the training loss in that case.
        monitor = 'val_loss' if validation_loader is not None else 'loss'
        modelcheckpoint = ModelCheckpoint(
            os.path.join(log_weight_path,
                         dataname + '_{}X'.format(self.upscaling_factor)),
            monitor=monitor,
            save_best_only=True,
            save_weights_only=True)
        callbacks.append(modelcheckpoint)

        # Callback: test images plotting
        if datapath_test is not None:
            testplotting = LambdaCallback(on_epoch_end=lambda epoch, logs:
                                          plot_test_images(self,
                                                           train_loader,
                                                           datapath_test,
                                                           log_test_path,
                                                           epoch,
                                                           name='SRResNet'))
            callbacks.append(testplotting)

        # Fit the model
        self.generator.fit_generator(train_loader,
                                     steps_per_epoch=steps_per_epoch,
                                     epochs=epochs,
                                     validation_data=validation_loader,
                                     validation_steps=steps_per_validation,
                                     callbacks=callbacks,
                                     use_multiprocessing=workers > 1,
                                     workers=workers,
                                     verbose=1)
Exemplo n.º 3
0
    def train_srgan(
        self,
        epochs,
        batch_size,
        dataname,
        datapath_train,
        datapath_validation=None,
        steps_per_validation=1000,
        datapath_test=None,
        workers=4,
        max_queue_size=10,
        first_epoch=0,
        print_frequency=1,
        crops_per_image=2,
        log_weight_frequency=None,
        log_weight_path='./data/weights/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_name='SRGAN',
        log_tensorboard_update_freq=10000,
        log_test_frequency=500,
        log_test_path="./images/samples/",
    ):
        """Train the SRGAN network

        Each "epoch" here is a single update iteration: one batch is drawn,
        the discriminator is trained on real HR images (target 1) and on
        generator outputs (target 0), then the combined ``self.srgan`` model
        is trained on the same batch with targets ``[real, VGG features of
        the HR images]``.

        :param int epochs: how many epochs (update iterations) to train for
        :param int batch_size: batch size for the data loaders; also the
            leading dimension of the discriminator targets
        :param str dataname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for validation images, or None
            to skip validation
        :param int steps_per_validation: batches evaluated per validation run
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int workers: number of workers preparing training batches
        :param int max_queue_size: maximum number of prepared batches queued
        :param int first_epoch: epoch index to start counting from (e.g. when resuming)
        :param int print_frequency: how often (in epochs) to print progress to
            terminal. Warning: will run validation inference!
        :param int crops_per_image: crops per image, passed to the data loaders
        :param int log_weight_frequency: how often (in epochs) should network
            weights be saved. None for never
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent;
            falsy to disable TensorBoard logging
        :param str log_tensorboard_name: what folder should tf logs be saved under
        :param int log_tensorboard_update_freq: passed to TensorBoard as ``update_freq``
        :param int log_test_frequency: how often (in epochs) should test images
            be plotted (validation runs on ``print_frequency`` instead)
        :param str log_test_path: where should test results be saved
        """

        # Create train data loader
        loader = DataLoader(datapath_train, batch_size, self.height_hr,
                            self.width_hr, self.upscaling_factor,
                            crops_per_image)

        # Validation data loader (None when no validation path is given)
        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image)

        # Callback: tensorboard (None when logging is disabled, so the loop
        # below can skip it instead of hitting an undefined name)
        tensorboard = None
        if log_tensorboard_path:
            tensorboard = TensorBoard(log_dir=os.path.join(
                log_tensorboard_path, log_tensorboard_name),
                                      histogram_freq=0,
                                      batch_size=batch_size,
                                      write_graph=False,
                                      write_grads=False,
                                      update_freq=log_tensorboard_update_freq)
            tensorboard.set_model(self.srgan)
        else:
            print(
                ">> Not logging to tensorboard since no log_tensorboard_path is set"
            )

        def named_logs(model, logs):
            """Transform train_on_batch return value to dict expected by on_epoch_end callback"""
            # train_on_batch returns a bare scalar when the model has a single
            # loss and no metrics; atleast_1d makes it zip-able either way.
            return dict(zip(model.metrics_names, np.atleast_1d(logs)))

        def format_losses(names, values):
            """Render (metric name, value) pairs as 'name=value' with 4 decimals."""
            return ", ".join(
                "{}={:.4f}".format(k, v)
                for k, v in zip(names, np.atleast_1d(values)))

        # VALID / FAKE targets for discriminator: its output shape with the
        # batch dimension fixed to batch_size
        discriminator_output_shape = list(self.discriminator.output_shape)
        discriminator_output_shape[0] = batch_size
        discriminator_output_shape = tuple(discriminator_output_shape)
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)

        # Losses accumulated since the last progress report, and the time
        # that reporting interval started
        print_losses = {"G": [], "D": []}
        start_epoch = datetime.datetime.now()

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(loader,
                                   use_multiprocessing=True,
                                   shuffle=True)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        try:
            output_generator = enqueuer.get()

            # Each epoch == "update iteration" as defined in the paper
            for epoch in range(first_epoch, int(epochs) + first_epoch):

                # Train discriminator
                imgs_lr, imgs_hr = next(output_generator)
                generated_hr = self.generator.predict(imgs_lr)
                real_loss = self.discriminator.train_on_batch(imgs_hr, real)
                fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
                discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

                # Train generator
                features_hr = self.vgg.predict(self.preprocess_vgg(imgs_hr))
                generator_loss = self.srgan.train_on_batch(imgs_lr,
                                                           [real, features_hr])

                # Callbacks
                if tensorboard is not None:
                    tensorboard.on_epoch_end(
                        epoch, named_logs(self.srgan, generator_loss))

                # Save losses
                print_losses['G'].append(generator_loss)
                print_losses['D'].append(discriminator_loss)

                # Show the progress
                if epoch % print_frequency == 0:
                    g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                    d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                    print(
                        "\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}"
                        .format(
                            epoch, epochs + first_epoch,
                            (datetime.datetime.now() - start_epoch).seconds,
                            format_losses(self.srgan.metrics_names, g_avg_loss),
                            format_losses(self.discriminator.metrics_names,
                                          d_avg_loss)))
                    print_losses = {"G": [], "D": []}

                    # Run validation inference if specified
                    if validation_loader is not None:
                        validation_losses = self.generator.evaluate_generator(
                            validation_loader,
                            steps=steps_per_validation,
                            use_multiprocessing=workers > 1,
                            workers=workers)
                        print(">> Validation Losses: {}".format(
                            format_losses(self.generator.metrics_names,
                                          validation_losses)))

                    # Restart the timer so the next report covers exactly the
                    # epochs since this one (matching the averaged losses)
                    start_epoch = datetime.datetime.now()

                # If test images are supplied, run model on them and save to log_test_path
                if datapath_test and epoch % log_test_frequency == 0:
                    plot_test_images(self, loader, datapath_test, log_test_path,
                                     epoch)

                # Check if we should save the network weights
                if log_weight_frequency and epoch % log_weight_frequency == 0:
                    self.save_weights(os.path.join(log_weight_path, dataname))
        finally:
            # Shut down the batch-preparing workers even if training raises
            enqueuer.stop()
Exemplo n.º 4
0
    def train(self,
              epochs,
              dataname,
              datapath,
              batch_size=1,
              test_images=None,
              test_frequency=50,
              test_path="./images/samples/",
              weight_frequency=None,
              weight_path='./data/weights/',
              print_frequency=1):
        """Train the SRGAN network

        Each epoch is one update iteration: the discriminator is trained on a
        batch of real HR images (target 1) and on generator outputs (target 0),
        then the combined ``self.srgan`` model is trained on a fresh batch.

        :param int epochs: how many epochs to train the network for
        :param str dataname: name to use for storing model weights etc.
        :param str datapath: path for the image files to use for training
        :param int batch_size: how large mini-batches to use
        :param list test_images: list of image paths to perform testing on
        :param int test_frequency: how often (in epochs) should testing be performed
        :param str test_path: where should test results be saved
        :param int weight_frequency: how often (in epochs) should network weights
            (and the recorded losses) be saved. None for never
        :param str weight_path: where should network weights be saved
        :param int print_frequency: how often (in epochs) to print progress to terminal
        """

        # Create data loader
        loader = DataLoader(datapath, self.height_hr, self.width_hr,
                            self.height_lr, self.width_lr,
                            self.upscaling_factor)

        # Shape of output from discriminator, with the batch dimension fixed
        discriminator_output_shape = list(self.discriminator.output_shape)
        discriminator_output_shape[0] = batch_size
        discriminator_output_shape = tuple(discriminator_output_shape)

        # VALID / FAKE targets for discriminator. Generated images must be
        # labelled 0; labelling them 1 would teach the discriminator that
        # every image is real, removing the adversarial signal.
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)

        def format_losses(names, values):
            """Render (metric name, value) pairs as 'name=value' in scientific notation."""
            # train_on_batch returns a bare scalar for a single-loss model
            # without metrics; atleast_1d makes it zip-able either way.
            return ", ".join(
                "{}={:.3e}".format(k, v)
                for k, v in zip(names, np.atleast_1d(values)))

        # Each epoch == "update iteration" as defined in the paper
        losses = []
        start_epoch = datetime.datetime.now()
        for epoch in range(epochs):

            # Train discriminator
            imgs_hr, imgs_lr = loader.load_batch(batch_size)
            generated_hr = self.generator.predict(imgs_lr)
            real_loss = self.discriminator.train_on_batch(imgs_hr, real)
            fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
            discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

            # Train generator
            imgs_hr, imgs_lr = loader.load_batch(batch_size)
            features_hr = self.vgg.predict(imgs_hr)
            generator_loss = self.srgan.train_on_batch([imgs_lr, imgs_hr],
                                                       [real, features_hr])

            # Save losses
            losses.append({
                'generator': generator_loss,
                'discriminator': discriminator_loss
            })

            # Plot the progress
            if epoch % print_frequency == 0:
                print(
                    "Epoch {}/{} | Time: {}s\n>> Generator: {}\n>> Discriminator: {}\n"
                    .format(
                        epoch, epochs,
                        (datetime.datetime.now() - start_epoch).seconds,
                        format_losses(self.srgan.metrics_names, generator_loss),
                        format_losses(self.discriminator.metrics_names,
                                      discriminator_loss)))
                # Restart the timer so each report measures the time since
                # the previous report
                start_epoch = datetime.datetime.now()

            # If test images are supplied, show them to the user
            if test_images and epoch % test_frequency == 0:
                plot_test_images(self, loader, test_images, test_path, epoch)

            # Check if we should save the network weights
            if weight_frequency and epoch % weight_frequency == 0:

                # Save the network weights
                self.save_weights(os.path.join(weight_path, dataname))

                # Save the recorded losses; the context manager closes the
                # file so the data is flushed to disk and no handle leaks
                with open(os.path.join(weight_path, dataname + '_losses.p'),
                          'wb') as losses_file:
                    pickle.dump(losses, losses_file)