def __init__(self):
   """Build the model, optimizer, loss metric and distributed training data.

   The distributed dataset is produced by passing `dataset_fn` to the
   strategy's `distribute_datasets_from_function`. The base
   `standard_runner.StandardTrainer` is then initialized with
   `use_tpu_summary_optimization=True`.
   """
   # Strategy object from tf.distribute.get_strategy(); used below to build
   # the distributed training dataset.
   self.strategy = tf.distribute.get_strategy()
   # NOTE(review): `create_model` is defined elsewhere in the file (not
   # visible in this chunk).
   self.model = create_model()
   self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
   # The optimizer's `iterations` counter doubles as the global step; no
   # separate step variable is created here.
   self.global_step = self.optimizer.iterations
   # Running mean of the training loss, tracked in float32.
   self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
   # NOTE(review): `dataset_fn` is presumably a module-level input function
   # defined elsewhere in this file -- confirm.
   distributed_dataset = self.strategy.distribute_datasets_from_function(
       dataset_fn)
   trainer_options = standard_runner.StandardTrainerOptions(
       use_tpu_summary_optimization=True)
   # Explicit base-class call (rather than super()) kept as in the original.
   standard_runner.StandardTrainer.__init__(
       self, distributed_dataset, options=trainer_options)
# Example No. 2
# 0
 def test_trainer_with_tpu_summary_optimization(self):
     """Check train() with TPU summary optimization enabled.

     A `TestTrainer` is built with `use_tpu_summary_optimization=True`.
     The test then asserts that `train(tf.constant(10))` returns a value
     equal to 10.
     """
     # NOTE(review): presumably TestTrainer.train returns the number of
     # steps run -- confirm against the TestTrainer definition.
     trainer = TestTrainer(
         standard_runner.StandardTrainerOptions(
             use_tpu_summary_optimization=True))
     num_steps = tf.constant(10)
     self.assertEqual(trainer.train(num_steps), 10)