コード例 #1
0
 def test_tensorflow_cluster_config(self):
     """Check that a cluster dict survives a from_dict/to_dict round trip."""
     worker_hosts = [
         "worker0.example.com:2222",
         "worker1.example.com:2222",
         "worker2.example.com:2222",
     ]
     ps_hosts = ["ps0.example.com:2222", "ps1.example.com:2222"]
     expected = {"worker": worker_hosts, "ps": ps_hosts}

     # Build the config object from the dict, then serialize it back.
     round_tripped = TensorflowClusterConfig.from_dict(expected).to_dict()
     assert_equal_dict(expected, round_tripped)
コード例 #2
0
    def test_run_config(self):
        """Check RunConfig dict round trips as session and cluster sections are added."""
        expected = {
            'tf_random_seed': 100,
            'save_summary_steps': 100,
            'save_checkpoints_secs': 600,
            'save_checkpoints_steps': None,
            'keep_checkpoint_max': 5,
            'keep_checkpoint_every_n_hours': 10000,
        }

        def assert_round_trip():
            # Rebuild the config from the current dict and compare its
            # serialized form against that same dict.
            rebuilt = RunConfig.from_dict(expected)
            assert_equal_dict(rebuilt.to_dict(), expected)

        # Base options only.
        assert_round_trip()

        # With a session section added.
        expected['session'] = SessionConfig().to_dict()
        assert_round_trip()

        # With a cluster section added on top.
        cluster = TensorflowClusterConfig(
            worker=[TaskType.WORKER], ps=[TaskType.PS])
        expected['cluster'] = cluster.to_dict()
        assert_round_trip()
コード例 #3
0
    def get_cluster(self):
        """Return the cluster definition as a dict of pod addresses per task type.

        One address is produced for the master (task index 0), and one per
        index for the worker and ps task types, using the counts read from
        ``self.spec.cluster_def`` (a missing task type counts as 0). The
        mapping is passed through ``TensorflowClusterConfig.from_dict`` and
        returned via ``to_dict()``.
        """
        # cluster_def unpacks into a pair; only the first element (looked up
        # by task type to get a count) is needed here.
        replica_counts, _ = self.spec.cluster_def

        def pod_addresses(task_type, count):
            # Resolve the job name for each task index, then its pod address.
            return [
                self._get_pod_address(
                    self.pod_manager.get_job_name(task_type=task_type, task_idx=idx))
                for idx in range(count)
            ]

        # There is always exactly one master, at index 0.
        addresses_by_task = {TaskType.MASTER: pod_addresses(TaskType.MASTER, 1)}
        for task_type in (TaskType.WORKER, TaskType.PS):
            addresses_by_task[task_type] = pod_addresses(
                task_type, replica_counts.get(task_type, 0))

        return TensorflowClusterConfig.from_dict(addresses_by_task).to_dict()