Example 1
    def call(self, dataset: tf.data.Dataset):
        r"""
        Perform the adversarial training.

        Args:
            dataset (:py:class:`tf.data.Dataset`): The adversarial training dataset.
        """
        current_epoch = self._current_epoch()

        self._update_global_batch_size(
            dataset, [self._d_loss, self._g_loss, self._e_loss]
        )

        # Rebatch with the global batch size required by the distribution strategy.
        dataset = wrap(
            dataset.unbatch().batch(self._global_batch_size, drop_remainder=True)
        )

        # Take a single batch to extract the fixed inputs used for logging.
        samples = next(iter(dataset.take(1)))
        gen_inputs = samples[1]

        with self._train_summary_writer.as_default():
            self._log("real_x", samples[0][0])
            self._log("real_y", samples[0][1])

            for epoch in tf.range(current_epoch, self._epochs):
                distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
                    dataset
                )

                for example in distribute_dataset:
                    d_loss, g_loss, e_loss, fake, generator_of_encoder = self._train_step(
                        example
                    )
                    self._global_step.assign_add(1)

                    # Log the losses and measure performance every 10 steps.
                    if tf.equal(tf.math.mod(self._global_step, 10), 0):
                        tf.print(
                            f"[{self._global_step.numpy()}] g_loss: {g_loss} - "
                            f"d_loss: {d_loss} - e_loss: {e_loss}"
                        )
                        self._measure_performance(
                            tf.data.Dataset.from_tensor_slices(example).batch(
                                self._global_batch_size
                            )
                        )

                self._epoch_completed(epoch + 1)
                if self._log_eval_mode == LogEvalMode.TEST:
                    self._log("generator", self._generator(gen_inputs, training=False))

                    self._log(
                        "generator_of_encoder",
                        self._generator(
                            self._encoder(samples[0][0], training=False), training=False
                        ),
                    )
                elif self._log_eval_mode == LogEvalMode.TRAIN:
                    self._log("generator", fake)
                    self._log("generator_of_encoder", generator_of_encoder)
Example 2
    def call(self, train_set, validation_set):
        """
        Start the training.

        Args:
            train_set (:py:obj:`tf.data.Dataset`): Training dataset.
            validation_set (:py:obj:`tf.data.Dataset`): Validation dataset.
        """
        current_epoch = self._current_epoch()
        self._update_global_batch_size(train_set, self._loss)
        with self._eval_summary_writer.as_default():
            self._measure_performance(validation_set)

        # Rebatch the training set with the global batch size required by
        # the distribution strategy.
        train_set = wrap(train_set.unbatch().batch(
            self._global_batch_size,
            drop_remainder=tf.distribute.has_strategy()))

        with self._train_summary_writer.as_default():
            for epoch in tf.range(current_epoch, self._epochs):
                distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
                    train_set)

                for example in distribute_dataset:
                    loss = self._train_step(example)
                    self._global_step.assign_add(1)
                    # Log the loss and measure performance every 10 steps.
                    if tf.equal(tf.math.mod(self._global_step, 10), 0):
                        tf.print(f"[{self._global_step.numpy()}] loss: {loss}")
                        self._measure_performance(
                            tf.data.Dataset.from_tensor_slices(example).batch(
                                self._global_batch_size))
                        self._log("input_x", example[0])
                        self._log("input_y", example[1])

                self._epoch_completed(epoch + 1)
                with self._eval_summary_writer.as_default():
                    self._measure_performance(validation_set)
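
A quick usage sketch for this classifier-style trainer: it consumes plain (input_x, input_y) datasets, as the example[0]/example[1] logging above shows. The data shapes and the ClassifierTrainer name below are assumptions for illustration.

import tensorflow as tf

# Hypothetical supervised data: 28x28 grayscale images with integer labels.
x = tf.random.normal((256, 28, 28, 1))
y = tf.random.uniform((256,), maxval=10, dtype=tf.int32)

# Elements are (input_x, input_y), matching example[0] / example[1] above.
train_set = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
validation_set = tf.data.Dataset.from_tensor_slices((x[:64], y[:64])).batch(32)

# trainer = ClassifierTrainer(...)    # hypothetical construction
# trainer(train_set, validation_set)  # invokes the call() shown above
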
Example 3
    def call(
        self,
        training_set: tf.data.Dataset,
        validation_set: tf.data.Dataset,
        log_freq: int = 10,
        measure_performance_freq: int = 10,
    ):
        """
        Start the training.

        Args:
            training_set (:py:obj:`tf.data.Dataset`): Training dataset.
            validation_set (:py:obj:`tf.data.Dataset`): Validation dataset.
            log_freq (int): Specifies how many steps to run before logging the losses,
                e.g. `log_freq=10` logs every 10 steps of training.
                Pass `log_freq<=0` if you don't want to log.
            measure_performance_freq (int): Specifies how many steps to run before
                measuring the performance, e.g. `measure_performance_freq=10`
                measures performance every 10 steps of training.
                Pass `measure_performance_freq<=0` if you don't want to measure
                performance.

        """
        # set the context properties
        self._context.training_set = training_set
        self._context.validation_set = validation_set

        current_epoch = self._current_epoch()
        self._update_global_batch_size(training_set, self._loss)

        # measure performance on the validation set
        with self._eval_summary_writer.as_default():
            self._context.dataset = validation_set
            self._measure_performance()

        # Rebatch the training set with the global batch size required by
        # the distribution strategy.
        training_set = wrap(training_set.unbatch().batch(
            self._global_batch_size,
            drop_remainder=tf.distribute.has_strategy()))

        with self._train_summary_writer.as_default():

            # notify on train start
            self._on_train_start()

            for _ in tf.range(current_epoch, self._epochs):
                distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
                    training_set)

                # notify on epoch start
                self._on_epoch_start()

                for example in distribute_dataset:

                    # store the current (local) batch in the shared context
                    self._context.current_batch = self.local_example(
                        example, (1, 1))

                    # notify on batch start
                    self._on_batch_start()

                    # perform training step
                    loss = self._train_step(example)

                    # increase global step
                    self._global_step.assign_add(1)

                    # log loss if needed
                    if log_freq > 0 and tf.equal(
                            tf.math.mod(self._global_step, log_freq), 0):
                        tf.print(f"[{self._global_step.numpy()}] loss: {loss}")

                    # measure performance
                    # this can also be moved to on_batch_end
                    self._measure_performance_if_needed(
                        example, measure_performance_freq)

                    # notify on batch end
                    self._on_batch_end()

                # notify on epoch end
                self._on_epoch_end()

                with self._eval_summary_writer.as_default():
                    self._context.dataset = validation_set
                    self._measure_performance()

            # final callback
            self._on_train_end()
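
The log_freq gate above fires on every log_freq-th value of the global step and is skipped entirely when log_freq <= 0. The standalone snippet below reproduces just that check so you can see which steps would log; the variable names are illustrative and not part of the trainer.

import tensorflow as tf

log_freq = 10
global_step = tf.Variable(0, dtype=tf.int64)

for _ in range(25):
    global_step.assign_add(1)
    # Same gate as the trainer: fire on every log_freq-th step, never if <= 0.
    if log_freq > 0 and tf.equal(tf.math.mod(global_step, log_freq), 0):
        tf.print("would log at step", global_step)  # fires at steps 10 and 20
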
Example 4
    def call(
        self,
        dataset: tf.data.Dataset,
        log_freq: int = 10,
        measure_performance_freq: int = 10,
    ):
        r"""
        Perform the adversarial training.

        Args:
            dataset (:py:class:`tf.data.Dataset`): The adversarial training dataset.
            log_freq (int): Specifies how many steps to run before logging the losses,
                e.g. `log_freq=10` logs every 10 steps of training.
                Pass `log_freq<=0` if you don't want to log.
            measure_performance_freq (int): Specifies how many steps to run before
                measuring the performance, e.g. `measure_performance_freq=10`
                measures performance every 10 steps of training.
                Pass `measure_performance_freq<=0` if you don't want to measure
                performance.

        """
        current_epoch = self._current_epoch()

        self._update_global_batch_size(
            dataset,
            [
                self._discriminator_loss, self._generator_loss,
                self._encoder_loss
            ],
        )

        # Rebatch with the global batch size required by the distribution strategy.
        dataset = wrap(dataset.unbatch().batch(self._global_batch_size,
                                               drop_remainder=True))

        # Take a single batch to extract the fixed inputs used for logging.
        samples = next(iter(dataset.take(1)))

        self._context.generator_inputs = samples[1]
        self._context.encoder_inputs = samples[0][0]

        with self._train_summary_writer.as_default():

            # notify on train start event
            self._on_train_start()

            for _ in tf.range(current_epoch, self._epochs):

                distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
                    dataset)

                # notify on epoch start event
                self._on_epoch_start()

                for example in distribute_dataset:

                    # perform training step
                    d_loss, g_loss, e_loss, fake, generator_of_encoder = self._train_step(
                        example)

                    # increase global step
                    self._global_step.assign_add(1)

                    # setup fake_samples
                    self._context.fake_samples = fake
                    self._context.generator_of_encoder = generator_of_encoder

                    # Log losses
                    if log_freq > 0 and tf.equal(
                            tf.math.mod(self._global_step, log_freq), 0):
                        tf.print(
                            f"[{self._global_step.numpy()}] g_loss: {g_loss} - "
                            f"d_loss: {d_loss} - e_loss: {e_loss}")

                    # measure performance if needed
                    self._measure_performance_if_needed(
                        example, measure_performance_freq)

                    # notify on batch end event
                    self._on_batch_end()

                # notify on epoch end event
                self._on_epoch_end()

            # notify on training end event
            self._on_train_end()
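
Compared to Example 1, this variant threads the generator inputs, encoder inputs, fake samples, and generator-of-encoder outputs through self._context, so the _on_* hooks can expose them to callbacks. The sketch below shows a callback reading that context; the class shape and the on_batch_end hook name are assumptions for illustration, not an API documented by the example.

class FakeSampleLogger:
    """Illustrative callback: inspects the fake samples stored by the trainer."""

    def on_batch_end(self, context):
        # context.fake_samples and context.generator_of_encoder are assigned
        # by the training loop above before _on_batch_end() fires.
        if getattr(context, "fake_samples", None) is not None:
            print("fake batch shape:", context.fake_samples.shape)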