def test_tensorflow_cluster_config(self):
    """A worker/ps cluster dict survives a TensorflowClusterConfig round trip.

    Builds the config from a plain dict with three worker hosts and two ps
    hosts, serialises it back with ``to_dict`` and checks the result matches
    the input.
    """
    worker_hosts = [
        "worker0.example.com:2222",
        "worker1.example.com:2222",
        "worker2.example.com:2222",
    ]
    ps_hosts = ["ps0.example.com:2222", "ps1.example.com:2222"]
    expected = {"worker": worker_hosts, "ps": ps_hosts}

    cluster_config = TensorflowClusterConfig.from_dict(expected)
    assert_equal_dict(expected, cluster_config.to_dict())
def test_run_config(self):
    """RunConfig survives a from_dict/to_dict round trip at each nesting level.

    The same dict is checked three times, growing between checks:
    scalar settings only, then with a nested ``session`` config, then also
    with a nested ``cluster`` config.
    """
    def check_round_trip(expected):
        # Rebuild from the dict and compare its serialised form to the input.
        restored = RunConfig.from_dict(expected)
        assert_equal_dict(restored.to_dict(), expected)

    run_settings = {
        'tf_random_seed': 100,
        'save_summary_steps': 100,
        'save_checkpoints_secs': 600,
        'save_checkpoints_steps': None,
        'keep_checkpoint_max': 5,
        'keep_checkpoint_every_n_hours': 10000,
    }
    check_round_trip(run_settings)

    # Nest a default session config.
    run_settings['session'] = SessionConfig().to_dict()
    check_round_trip(run_settings)

    # Nest a minimal cluster config on top of the session config.
    cluster = TensorflowClusterConfig(worker=[TaskType.WORKER], ps=[TaskType.PS])
    run_settings['cluster'] = cluster.to_dict()
    check_round_trip(run_settings)
def get_cluster(self):
    """Build the TensorFlow cluster spec for this job's pods.

    The master entry always has exactly one address (task index 0). The worker
    and PS entries have one address per replica, with the replica counts read
    from ``self.spec.cluster_def``. A task type missing from that mapping
    yields an empty list rather than being omitted.

    Each address is resolved by asking ``self.pod_manager`` for the pod's job
    name and passing that name to ``self._get_pod_address``.

    Returns:
        The result of ``TensorflowClusterConfig.from_dict(...).to_dict()``
        applied to a dict keyed by ``TaskType.MASTER``, ``TaskType.WORKER``
        and ``TaskType.PS``, each mapped to a list of pod addresses.
    """
    # self.spec.cluster_def unpacks into a (counts, flag) pair. Only the
    # per-task-type counts are needed here; the flag was previously bound to
    # an unused local.
    cluster_def, _ = self.spec.cluster_def

    def addresses_for(task_type, count):
        # One pod address per replica index 0..count-1, in index order.
        return [
            self._get_pod_address(
                self.pod_manager.get_job_name(task_type=task_type, task_idx=idx))
            for idx in range(count)
        ]

    # Entries are built in MASTER, WORKER, PS order, matching the original
    # sequence of pod_manager / _get_pod_address calls.
    cluster_config = {
        TaskType.MASTER: addresses_for(TaskType.MASTER, 1),
        TaskType.WORKER: addresses_for(
            TaskType.WORKER, cluster_def.get(TaskType.WORKER, 0)),
        TaskType.PS: addresses_for(
            TaskType.PS, cluster_def.get(TaskType.PS, 0)),
    }
    return TensorflowClusterConfig.from_dict(cluster_config).to_dict()