Example #1
def train():
    # Run on the first GPU when one is visible, otherwise fall back to CPU.
    device = 'gpu:0' if tf.test.is_gpu_available() else 'cpu'
    args = parse_train_arguments()

    with tf.device(device):

        losses = {
            'train_loss': Mean(name='train_loss'),
            'train_mse': Mean(name='train_mse'),
            'train_psnr': Mean(name='train_psnr'),
            'train_ssim': Mean(name='train_ssim')
        }

        # Collect every PNG/JPG/BMP image under the training dataset directory.
        train_dataset_path = glob(os.path.join(args.train_dataset_base_path, '**/**.png'), recursive=True) + \
                             glob(os.path.join(args.train_dataset_base_path, '**/**.jpg'), recursive=True) + \
                             glob(os.path.join(args.train_dataset_base_path, '**/**.bmp'), recursive=True)

        dataset = load_simulation_data(train_dataset_path, args.batch_size * args.batches, args.patch_size,
                                       args.radious, args.epsilon)

        model, optimizer, initial_epoch, clip_norms = load_model(args.checkpoint_directory, args.restore_model,
                                                                 args.learning_rate)

        for epoch in range(initial_epoch, args.epochs):
            total_clip_norms = [tf.cast(0, dtype=tf.float32), tf.cast(0, dtype=tf.float32)]
            batched_dataset = dataset.batch(args.batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
            progress_bar = tqdm(batched_dataset, total=args.batches)

            for index, data_batch in enumerate(progress_bar):
                dnet_new_norm, snet_new_norm = train_step(model, optimizer, data_batch, losses, clip_norms,
                                                          args.radious)
                on_batch_end(epoch, index, dnet_new_norm, snet_new_norm, total_clip_norms, losses, progress_bar)

            on_epoch_end(model, optimizer, epoch, losses, clip_norms, total_clip_norms,
                         args.checkpoint_directory, args.checkpoint_frequency)
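
The helpers train_step, on_batch_end and on_epoch_end are not shown in this snippet, but the Mean metrics in the losses dictionary above follow the usual tf.keras.metrics lifecycle: update once per batch, read and reset once per epoch. A minimal, self-contained sketch of that lifecycle (the loss values are illustrative only):

import tensorflow as tf
from tensorflow.keras.metrics import Mean

train_loss = Mean(name='train_loss')

for batch_loss in (0.9, 0.7, 0.5):       # stand-ins for per-batch loss values
    train_loss.update_state(batch_loss)  # accumulate the running average

print(float(train_loss.result()))        # ~0.7, the value reported for the epoch
train_loss.reset_states()                # clear state before the next epoch
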
Example #2
def __init__(self):
    # loss function for training
    self.loss_object = SparseCategoricalCrossentropy()
    # optimizer for training
    self.optimizer = Adam()
    # metrics to measure the loss and the accuracy of the model
    self.train_loss = Mean(name='train_loss')
    self.train_accuracy = SparseCategoricalAccuracy(name='train_accuracy')
    self.test_loss = Mean(name='test_loss')
    self.test_accuracy = SparseCategoricalAccuracy(name='test_accuracy')
Example #3
def build(self, input_shape):
    neg_log_perplexity = functools.partial(
        padded_neg_log_perplexity, vocab_size=self.vocab_size)
    self.metric_mean_fns = [
        (Mean('accuracy'), padded_accuracy),
        (Mean('accuracy_top5'), padded_accuracy_top5),
        (Mean('accuracy_per_sequence'), padded_sequence_accuracy),
        (Mean('neg_log_perplexity'), neg_log_perplexity),
    ]
    super().build(input_shape)
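
The snippet above only builds the (Mean, metric_fn) pairs; how they are consumed is not shown. In a Keras layer they would typically be driven from call(), roughly as in this hedged sketch; the input layout and the return values of the padded_* helpers are assumptions:

def call(self, inputs):
    # Hypothetical: the layer is assumed to receive (logits, targets).
    logits, targets = inputs[0], inputs[1]
    for mean, metric_fn in self.metric_mean_fns:
        # Each padded_* helper is assumed to return (values, weights);
        # Mean then keeps a running weighted average of the values.
        self.add_metric(mean(*metric_fn(logits, targets)))
    return logits
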
Example #4
def test():
    args = parse_test_arguments()

    model = load_test_model(args.checkpoint_directory)

    validation_dataset = load_validation_dataset(
        validation_clean_dataset_path=args.validation_clean_dataset_path,
        validation_noisy_dataset_path=args.validation_noisy_dataset_path)

    losses = {
        'validation_mse': Mean(name='validation_mse'),
        'validation_psnr': Mean(name='validation_psnr'),
        'validation_ssim': Mean(name='validation_ssim')
    }

    validation_progress_bar = tqdm(validation_dataset.batch(args.batch_size))

    for validation_data_batch in validation_progress_bar:
        validate(model, validation_data_batch, losses)

    for name, metric in losses.items():
        print("{}: {}".format(name, float(metric.result())))
Example #5
def train():
    device = 'gpu:0' if tf.test.is_gpu_available() else 'cpu'

    args = parse_train_arguments()

    with tf.device(device):

        best_losses = {'validation_psnr': 0, 'validation_ssim': 0}

        losses = {
            'train_loss': Mean(name='train_loss'),
            'train_mse': Mean(name='train_mse'),
            'validation_mse': Mean(name='validation_mse'),
            'validation_psnr': Mean(name='validation_psnr'),
            'validation_ssim': Mean(name='validation_ssim')
        }

        # Clean (GT) images are collected first; each noisy counterpart is
        # derived from the same path via the GT -> NOISY naming convention.
        train_dataset_path = glob(os.path.join(args.train_dataset_base_path,
                                               '**/*GT*.PNG'),
                                  recursive=True)
        train_noisy_dataset_path = [
            path_image.replace('GT', 'NOISY')
            for path_image in train_dataset_path
        ]

        dataset = load_train_data(
            train_dataset_path=train_dataset_path,
            train_noisy_dataset_path=train_noisy_dataset_path,
            data_length=args.batch_size * args.batches,
            patch_size=args.patch_size,
            radious=args.radious,
            epsilon=args.epsilon)

        validation_dataset = load_validation_dataset(
            validation_clean_dataset_path=args.validation_clean_dataset_path,
            validation_noisy_dataset_path=args.validation_noisy_dataset_path)

        model, optimizer, initial_epoch, clip_norms = load_model(
            checkpoint_directory=args.checkpoint_directory,
            restore_model=args.restore_model,
            learning_rate=args.learning_rate)

        for epoch in range(initial_epoch, args.epochs + 1):

            total_clip_norms = [
                tf.cast(0, dtype=tf.float32),
                tf.cast(0, dtype=tf.float32)
            ]
            batched_dataset = dataset.batch(args.batch_size).prefetch(
                buffer_size=tf.data.experimental.AUTOTUNE)
            progress_bar = tqdm(batched_dataset, total=args.batches)

            for index, data_batch in enumerate(progress_bar):
                dnet_new_norm, snet_new_norm = train_step(
                    model, optimizer, data_batch, losses, clip_norms,
                    args.radious)
                on_batch_end(epoch, index, dnet_new_norm, snet_new_norm,
                             total_clip_norms, losses, progress_bar,
                             args.batches)

            validation_progress_bar = tqdm(
                validation_dataset.batch(args.batch_size))

            for validation_data_batch in validation_progress_bar:
                validation_step(model, validation_data_batch, losses)

            on_epoch_end(model, optimizer, epoch, losses, best_losses,
                         clip_norms, total_clip_norms,
                         args.checkpoint_directory)
Example #6
class CustomTrainer(Trainer):
    """
    This trainer follows https://www.tensorflow.org/beta/tutorials/quickstart/advanced .
    """
    def __init__(self):
        # loss function for training
        self.loss_object = SparseCategoricalCrossentropy()
        # optimizer for training
        self.optimizer = Adam()
        # metrics to measure the loss and the accuracy of the model
        self.train_loss = Mean(name='train_loss')
        self.train_accuracy = SparseCategoricalAccuracy(name='train_accuracy')
        self.test_loss = Mean(name='test_loss')
        self.test_accuracy = SparseCategoricalAccuracy(name='test_accuracy')

    def __str__(self):
        return "CustomTrainer"

    @tf.function
    def train_step(self, inputs, classes, model):
        with tf.GradientTape() as tape:
            predictions = model(inputs)
            loss = self.loss_object(classes, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, model.trainable_variables))

        self.train_loss(loss)
        self.train_accuracy(classes, predictions)

    @tf.function
    def test_step(self, inputs, classes, model):
        predictions = model(inputs)
        loss = self.loss_object(classes, predictions)

        self.test_loss(loss)
        self.test_accuracy(classes, predictions)

    def print_summary(self, epoch):
        template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
        print(
            template.format(epoch + 1, self.train_loss.result(),
                            self.train_accuracy.result() * 100,
                            self.test_loss.result(),
                            self.test_accuracy.result() * 100))

    def train_epoch(self, model, train_ds, test_ds, epoch,
                    data_train: DataSpec, data_test: DataSpec):
        for train_inputs, train_classes in train_ds:
            self.train_step(train_inputs, train_classes, model)

        for test_inputs, test_classes in test_ds:
            self.test_step(test_inputs, test_classes, model)

        if VERBOSE_MODEL_TRAINING:
            self.print_summary(epoch)

        # never terminate earlier
        return False

    def train(self, model, data_train: DataSpec, data_test: DataSpec, epochs,
              batch_size):
        x_train = data_train.x()
        y_train = data_train.y()
        x_test = data_test.x()
        y_test = data_test.y()

        y_train = to_classes(y_train)
        y_test = to_classes(y_test)

        train_ds = to_dataset(x_train, y_train, batch_size)
        test_ds = to_dataset(x_test, y_test, batch_size)

        for epoch in range(epochs):
            stop = self.train_epoch(model, train_ds, test_ds, epoch,
                                    data_train, data_test)
            if stop:
                break

        # TODO return the history
        return None
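
One difference from the linked tutorial: there the four metrics are reset at the start of every epoch, so the summary reports per-epoch rather than cumulative values. If that behaviour is wanted here, a small helper along these lines could be called at the top of train_epoch (a sketch, not part of the original class):

    def reset_metrics(self):
        # Clear accumulated state so each epoch's summary stands on its own.
        self.train_loss.reset_states()
        self.train_accuracy.reset_states()
        self.test_loss.reset_states()
        self.test_accuracy.reset_states()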