예제 #1
0
 def test_get_distribution_strategy_tpu(self):
     """Checks that one worker using the 'TPU' config maps to 'tpu'."""
     tpu_config = machine_config.COMMON_MACHINE_CONFIGS['TPU']
     kwargs = {'worker_count': 1, 'worker_config': tpu_config}
     self.assertEqual('tpu', models.get_distribution_strategy_str(kwargs))
예제 #2
0
 def test_get_distribution_strategy_one_device(self):
     """Checks that empty run kwargs map to the 'one_device' strategy."""
     result = models.get_distribution_strategy_str({})
     self.assertEqual('one_device', result)
예제 #3
0
 def test_get_distribution_strategy_mirror(self):
     """Checks that a 'K80_4X' chief config with no workers maps to 'mirror'."""
     kwargs = {
         'chief_config': machine_config.COMMON_MACHINE_CONFIGS['K80_4X'],
     }
     self.assertEqual('mirror', models.get_distribution_strategy_str(kwargs))
예제 #4
0
 def test_get_distribution_strategy_multi_mirror(self):
     """Checks that worker_count=1 with no explicit configs maps to 'multi_mirror'."""
     kwargs = {'worker_count': 1}
     self.assertEqual(
         'multi_mirror', models.get_distribution_strategy_str(kwargs))