def __init__(self):
    """Set up a trainer whose base StandardTrainer uses TPU summary optimization.

    Creates the model, an RMSprop optimizer (learning rate 0.1), a mean
    ``train_loss`` metric and a distributed training dataset. It then
    initializes ``standard_runner.StandardTrainer`` with options that set
    ``use_tpu_summary_optimization=True``.
    """
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    # Reuse the optimizer's iteration counter as the global step instead of
    # tracking a separate step variable.
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    # NOTE(review): `dataset_fn` is defined elsewhere in this module and is
    # not visible here.
    dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    trainer_options = standard_runner.StandardTrainerOptions(
        use_tpu_summary_optimization=True)
    standard_runner.StandardTrainer.__init__(
        self,
        dataset,
        options=trainer_options)
def test_trainer_with_tpu_summary_optimization(self):
    """Run a TestTrainer with TPU summary optimization and check train()'s result.

    Builds a ``TestTrainer`` from ``StandardTrainerOptions`` that set
    ``use_tpu_summary_optimization=True``. It calls ``train`` with
    ``tf.constant(10)`` (presumably the number of steps) and asserts that
    the returned value equals 10.
    """
    trainer_options = standard_runner.StandardTrainerOptions(
        use_tpu_summary_optimization=True)
    trainer = TestTrainer(trainer_options)
    num_steps = tf.constant(10)
    result = trainer.train(num_steps)
    self.assertEqual(result, 10)