def test_build_tf_config_with_one_host(trainer):
    """With a single host, the cluster holds only a master and this host is it.

    The expected config has no 'ps' or 'worker' entries. The sole host
    'algo-1' is the master at index 0.
    """
    want = {
        'environment': 'cloud',
        'cluster': {'master': ['algo-1:2222']},
        'task': {'index': 0, 'type': 'master'},
    }

    only_host = 'algo-1'
    trainer.hosts = [only_host]
    trainer.current_host = only_host

    # Kept as its own statement so the call still runs even if asserts are stripped.
    got = trainer.build_tf_config()

    assert got == want
    assert trainer.task_type == 'master'
def test_build_tf_config_with_multiple_hosts(trainer):
    """With four hosts, a non-first host is configured as a worker.

    The expected layout is as follows:
    - 'algo-1' is the master on port 2222.
    - Every host, including the master, is listed under 'ps' on port 2223.
    - The remaining hosts are listed as workers on port 2222.

    The current host 'algo-3' is the second worker, so its task index is 1.
    """
    want = {
        'environment': 'cloud',
        'cluster': {
            'master': ['algo-1:2222'],
            'ps': ['algo-1:2223', 'algo-2:2223', 'algo-3:2223', 'algo-4:2223'],
            'worker': ['algo-2:2222', 'algo-3:2222', 'algo-4:2222'],
        },
        'task': {'index': 1, 'type': 'worker'},
    }

    all_hosts = ['algo-1', 'algo-2', 'algo-3', 'algo-4']
    trainer.hosts = all_hosts
    trainer.current_host = 'algo-3'

    # Kept as its own statement so the call still runs even if asserts are stripped.
    got = trainer.build_tf_config()

    assert got == want
    assert trainer.task_type == 'worker'