def test_get_distribution_strategy_tpu(self):
    """One worker using the common 'TPU' config should yield 'tpu'."""
    tpu_config = machine_config.COMMON_MACHINE_CONFIGS['TPU']
    kwargs = {'worker_count': 1, 'worker_config': tpu_config}
    self.assertEqual('tpu', models.get_distribution_strategy_str(kwargs))
def test_get_distribution_strategy_one_device(self):
    """Empty run kwargs should yield the 'one_device' strategy."""
    self.assertEqual('one_device', models.get_distribution_strategy_str({}))
def test_get_distribution_strategy_mirror(self):
    """A 'K80_4X' chief config with no worker settings should yield 'mirror'."""
    chief = machine_config.COMMON_MACHINE_CONFIGS['K80_4X']
    kwargs = {'chief_config': chief}
    self.assertEqual('mirror', models.get_distribution_strategy_str(kwargs))
def test_get_distribution_strategy_multi_mirror(self):
    """Setting only worker_count=1 should yield the 'multi_mirror' strategy."""
    kwargs = {'worker_count': 1}
    self.assertEqual(
        'multi_mirror', models.get_distribution_strategy_str(kwargs))