def call(self, dataset: tf.data.Dataset):
    r"""
    Perform the adversarial training.

    Args:
        dataset (:py:class:`tf.data.Dataset`): The adversarial training dataset.

    """
    current_epoch = self._current_epoch()

    self._update_global_batch_size(
        dataset, [self._d_loss, self._g_loss, self._e_loss]
    )

    dataset = wrap(
        dataset.unbatch().batch(self._global_batch_size, drop_remainder=True)
    )

    samples = next(iter(dataset.take(1)))
    gen_inputs = samples[1]

    with self._train_summary_writer.as_default():
        self._log("real_x", samples[0][0])
        self._log("real_y", samples[0][1])

        for epoch in tf.range(current_epoch, self._epochs):
            distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
                dataset
            )

            for example in distribute_dataset:
                d_loss, g_loss, e_loss, fake, generator_of_encoder = self._train_step(
                    example
                )
                self._global_step.assign_add(1)

                if tf.equal(tf.math.mod(self._global_step, 10), 0):
                    tf.print(
                        f"[{self._global_step.numpy()}] g_loss: {g_loss} - "
                        f"d_loss: {d_loss} - e_loss: {e_loss}"
                    )
                    self._measure_performance(
                        tf.data.Dataset.from_tensor_slices(example).batch(
                            self._global_batch_size
                        )
                    )

            self._epoch_completed(epoch + 1)

            if self._log_eval_mode == LogEvalMode.TEST:
                self._log("generator", self._generator(gen_inputs, training=False))
                self._log(
                    "generator_of_encoder",
                    self._generator(
                        self._encoder(samples[0][0], training=False), training=False
                    ),
                )
            elif self._log_eval_mode == LogEvalMode.TRAIN:
                self._log("generator", fake)
                self._log("generator_of_encoder", generator_of_encoder)
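
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer): the loop above re-batches the
# dataset to the global batch size, distributes it with
# `experimental_distribute_dataset`, and runs one optimization step per batch.
# The minimal example below shows that same pattern with a toy Keras model.
# Names such as `toy_model` and the batch sizes are assumptions made for
# illustration; it uses whatever `tf.distribute` strategy is currently active
# (the default no-op strategy when none is in scope) and assumes TensorFlow 2.x.
def _example_distributed_training_loop():
    import tensorflow as tf

    strategy = tf.distribute.get_strategy()
    global_batch_size = 16

    # Toy dataset already batched with the global batch size.
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random.normal((64, 4)), tf.zeros((64, 1)))
    ).batch(global_batch_size, drop_remainder=True)
    distribute_dataset = strategy.experimental_distribute_dataset(dataset)

    toy_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam()
    # Sum reduction plus manual scaling by the global batch size keeps the
    # gradient magnitude correct when a batch is split across replicas.
    loss_fn = tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.SUM
    )

    @tf.function
    def train_step(features, labels):
        with tf.GradientTape() as tape:
            predictions = toy_model(features, training=True)
            loss = loss_fn(labels, predictions) / global_batch_size
        gradients = tape.gradient(loss, toy_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, toy_model.trainable_variables))
        return loss

    for features, labels in distribute_dataset:
        # One step per replica, then reduce the per-replica losses to a scalar.
        per_replica_loss = strategy.run(train_step, args=(features, labels))
        loss = strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None
        )
        tf.print("loss:", loss)
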
def call(self, train_set, validation_set):
    """
    Start the training.

    Args:
        train_set (:py:obj:`tf.data.Dataset`): Training dataset.
        validation_set (:py:obj:`tf.data.Dataset`): Validation dataset.

    """
    current_epoch = self._current_epoch()
    self._update_global_batch_size(train_set, self._loss)

    with self._eval_summary_writer.as_default():
        self._measure_performance(validation_set)

    # need to use the global batch size in the training set
    train_set = wrap(
        train_set.unbatch().batch(
            self._global_batch_size, drop_remainder=tf.distribute.has_strategy()
        )
    )

    samples = train_set.take(1)

    with self._train_summary_writer.as_default():
        for epoch in tf.range(current_epoch, self._epochs):
            distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
                train_set
            )

            for example in distribute_dataset:
                loss = self._train_step(example)
                self._global_step.assign_add(1)

                if tf.equal(tf.math.mod(self._global_step, 10), 0):
                    tf.print(f"[{self._global_step.numpy()}] loss: {loss}")
                    self._measure_performance(
                        tf.data.Dataset.from_tensor_slices(example).batch(
                            self._global_batch_size
                        )
                    )
                    self._log("input_x", example[0])
                    self._log("input_y", example[1])

            self._epoch_completed(epoch + 1)

            with self._eval_summary_writer.as_default():
                self._measure_performance(validation_set)
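
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer): why the training set is
# re-batched with `unbatch().batch(global_batch_size, drop_remainder=...)`.
# Under a distribution strategy every batch must have the *global* batch size
# so it can be split evenly across replicas; a partial final batch is dropped.
# `per_replica_batch` and `n_replicas` below are illustrative assumptions.
def _example_rebatch_to_global_batch_size():
    import tensorflow as tf

    per_replica_batch = 8
    n_replicas = 2  # e.g. two GPUs under MirroredStrategy
    global_batch_size = per_replica_batch * n_replicas

    # A toy dataset batched with the per-replica size (100 is not a multiple
    # of 16, so the last global batch would otherwise be partial).
    dataset = tf.data.Dataset.range(100).batch(per_replica_batch)

    # Re-batch with the global size; drop_remainder guarantees every batch
    # can be divided evenly among the replicas.
    dataset = dataset.unbatch().batch(global_batch_size, drop_remainder=True)

    for batch in dataset:
        tf.print(tf.shape(batch))  # always [16]
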
def call(
    self,
    training_set: tf.data.Dataset,
    validation_set: tf.data.Dataset,
    log_freq: int = 10,
    measure_performance_freq: int = 10,
):
    """
    Start the training.

    Args:
        training_set (:py:obj:`tf.data.Dataset`): Training dataset.
        validation_set (:py:obj:`tf.data.Dataset`): Validation dataset.
        log_freq (int): Specifies how many steps to run before logging the losses,
            e.g. `log_freq=10` logs every 10 steps of training.
            Pass `log_freq<=0` in case you don't want to log.
        measure_performance_freq (int): Specifies how many steps to run before
            measuring the performance, e.g. `measure_performance_freq=10`
            measures performance every 10 steps of training.
            Pass `measure_performance_freq<=0` in case you don't want to measure
            performance.

    """
    # set the context properties
    self._context.training_set = training_set
    self._context.validation_set = validation_set

    current_epoch = self._current_epoch()
    self._update_global_batch_size(training_set, self._loss)

    # measure performance on the validation set
    with self._eval_summary_writer.as_default():
        self._context.dataset = validation_set
        self._measure_performance()

    # need to use the global batch size in the training set
    training_set = wrap(
        training_set.unbatch().batch(
            self._global_batch_size, drop_remainder=tf.distribute.has_strategy()
        )
    )

    with self._train_summary_writer.as_default():
        # notify on train start
        self._on_train_start()

        for _ in tf.range(current_epoch, self._epochs):
            distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
                training_set
            )

            # notify on epoch start
            self._on_epoch_start()

            for example in distribute_dataset:
                self._context.current_batch = self.local_example(example, (1, 1))

                # notify on batch start
                self._on_batch_start()

                # perform training step
                loss = self._train_step(example)

                # increase global step
                self._global_step.assign_add(1)

                # log loss if needed
                if log_freq > 0 and tf.equal(
                    tf.math.mod(self._global_step, log_freq), 0
                ):
                    tf.print(f"[{self._global_step.numpy()}] loss: {loss}")

                # measure performance
                # this can also be moved to on_batch_end
                self._measure_performance_if_needed(
                    example, measure_performance_freq
                )

                # notify on batch end
                self._on_batch_end()

            # notify on epoch end
            self._on_epoch_end()

            with self._eval_summary_writer.as_default():
                self._context.dataset = validation_set
                self._measure_performance()

    # final callback
    self._on_train_end()
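
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer): how `log_freq` and
# `measure_performance_freq` gate work on the global step. A non-positive
# frequency disables the action entirely; otherwise it fires every `freq`
# steps. The variable names below are assumptions for illustration.
def _example_step_frequency_gating():
    import tensorflow as tf

    global_step = tf.Variable(0, dtype=tf.int64)
    log_freq = 10

    for _ in range(25):
        global_step.assign_add(1)
        if log_freq > 0 and tf.equal(tf.math.mod(global_step, log_freq), 0):
            tf.print("logging at step", global_step)
    # Prints at steps 10 and 20; passing log_freq <= 0 never logs.
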
def call(
    self,
    dataset: tf.data.Dataset,
    log_freq: int = 10,
    measure_performance_freq: int = 10,
):
    r"""
    Perform the adversarial training.

    Args:
        dataset (:py:class:`tf.data.Dataset`): The adversarial training dataset.
        log_freq (int): Specifies how many steps to run before logging the losses,
            e.g. `log_freq=10` logs every 10 steps of training.
            Pass `log_freq<=0` in case you don't want to log.
        measure_performance_freq (int): Specifies how many steps to run before
            measuring the performance, e.g. `measure_performance_freq=10`
            measures performance every 10 steps of training.
            Pass `measure_performance_freq<=0` in case you don't want to measure
            performance.

    """
    current_epoch = self._current_epoch()

    self._update_global_batch_size(
        dataset,
        [self._discriminator_loss, self._generator_loss, self._encoder_loss],
    )

    dataset = wrap(
        dataset.unbatch().batch(self._global_batch_size, drop_remainder=True)
    )

    samples = next(iter(dataset.take(1)))
    self._context.generator_inputs = samples[1]
    self._context.encoder_inputs = samples[0][0]

    with self._train_summary_writer.as_default():
        # notify on train start event
        self._on_train_start()

        for _ in tf.range(current_epoch, self._epochs):
            distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
                dataset
            )

            # notify on epoch start event
            self._on_epoch_start()

            for example in distribute_dataset:
                # perform training step
                d_loss, g_loss, e_loss, fake, generator_of_encoder = self._train_step(
                    example
                )

                # increase global step
                self._global_step.assign_add(1)

                # setup fake_samples
                self._context.fake_samples = fake
                self._context.generator_of_encoder = generator_of_encoder

                # log losses if needed
                if log_freq > 0 and tf.equal(
                    tf.math.mod(self._global_step, log_freq), 0
                ):
                    tf.print(
                        f"[{self._global_step.numpy()}] g_loss: {g_loss} - "
                        f"d_loss: {d_loss} - e_loss: {e_loss}"
                    )

                # measure performance if needed
                self._measure_performance_if_needed(
                    example, measure_performance_freq
                )

                # notify on batch end event
                self._on_batch_end()

            # notify on epoch end event
            self._on_epoch_end()

    # notify on training end event
    self._on_train_end()
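
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer): the summary-writer pattern
# used above. Opening a writer with `as_default()` makes every summary logged
# inside that block land in the writer's own log directory, which is how the
# train and eval streams stay separated. The log paths and scalar values here
# are illustrative assumptions.
def _example_summary_writers():
    import tensorflow as tf

    train_writer = tf.summary.create_file_writer("logs/train")
    eval_writer = tf.summary.create_file_writer("logs/eval")
    global_step = tf.Variable(0, dtype=tf.int64)

    with train_writer.as_default():
        tf.summary.scalar("g_loss", 0.7, step=global_step)
        tf.summary.scalar("d_loss", 0.5, step=global_step)

    with eval_writer.as_default():
        tf.summary.scalar("accuracy", 0.9, step=global_step)
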