コード例 #1
0
    def testMixedModeNonOverlappingKey(self):
        """Merges a list job with a sparse dict job whose indices are disjoint.

        The list entries keep indices 0 and 1, and the sparse entries keep
        their own keys (3, 6, 7).
        """
        list_spec = server_lib.ClusterSpec(
            {"worker": ["worker4:2222", "worker5:2222"]})
        sparse_tasks = {3: "worker0:2222", 6: "worker1:2222", 7: "worker2:2222"}
        sparse_spec = server_lib.ClusterSpec({"worker": sparse_tasks})

        resolvers = [SimpleClusterResolver(spec)
                     for spec in (list_spec, sparse_spec)]
        merged_spec = UnionClusterResolver(*resolvers).cluster_spec()

        expected_proto = """
    job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
                         tasks { key: 1 value: 'worker5:2222' }
                         tasks { key: 3 value: 'worker0:2222' }
                         tasks { key: 6 value: 'worker1:2222' }
                         tasks { key: 7 value: 'worker2:2222' }}
    """
        self._verifyClusterSpecEquality(merged_spec, expected_proto)
コード例 #2
0
    def testSimpleOverrideMaster(self):
        """master(task_type, task_id) without rpc_layer yields a bare address."""
        jobs = {
            "ps": ["ps0:2222", "ps1:2222"],
            "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
        }
        resolver = SimpleClusterResolver(server_lib.ClusterSpec(jobs))

        self.assertEqual(resolver.master("worker", 2), "worker2:2222")
コード例 #3
0
    def testSimpleOverrideMasterWithTaskIndexZero(self):
        """Task index 0 is resolved and prefixed with the requested rpc_layer."""
        jobs = {
            "ps": ["ps0:2222", "ps1:2222"],
            "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
        }
        resolver = SimpleClusterResolver(server_lib.ClusterSpec(jobs))

        # Index 0 is falsy; the resolver must still treat it as a real index.
        self.assertEqual(resolver.master("worker", 0, rpc_layer="grpc"),
                         "grpc://worker0:2222")
コード例 #4
0
    def testOverlappingDictAndListThrowError(self):
        """Merging a list job and a dict job with colliding indices raises."""
        dense_spec = server_lib.ClusterSpec(
            {"worker": ["worker4:2222", "worker5:2222"]})
        sparse_spec = server_lib.ClusterSpec({
            "worker": {1: "worker0:2222", 2: "worker1:2222", 3: "worker2:2222"}
        })
        union_cluster = UnionClusterResolver(
            *(SimpleClusterResolver(spec) for spec in (dense_spec, sparse_spec)))

        # The two-entry list job occupies indices 0 and 1, so index 1 collides
        # with the dict job's first key.
        with self.assertRaises(KeyError):
            union_cluster.cluster_spec()
コード例 #5
0
    def testInitSimpleClusterResolver(self):
        """Constructor keyword arguments are reported back by the resolver."""
        spec = server_lib.ClusterSpec({
            "ps": ["ps0:2222", "ps1:2222"],
            "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
        })
        resolver = SimpleClusterResolver(
            spec,
            task_type="ps",
            task_id=1,
            environment="cloud",
            num_accelerators={"GPU": 8},
            rpc_layer="grpc",
        )

        # num_accelerators is exposed as a method; the rest are attributes.
        self.assertEqual(resolver.task_type, "ps")
        self.assertEqual(resolver.task_id, 1)
        self.assertEqual(resolver.environment, "cloud")
        self.assertEqual(resolver.num_accelerators(), {"GPU": 8})
        self.assertEqual(resolver.rpc_layer, "grpc")
コード例 #6
0
    def testOverlappingSparseJobMergedClusterResolverThrowError(self):
        """Merging two sparse jobs that share a task index raises KeyError."""
        left_spec = server_lib.ClusterSpec(
            {"worker": {7: "worker4:2222", 9: "worker5:2222"}})
        right_spec = server_lib.ClusterSpec({
            "worker": {3: "worker0:2222", 6: "worker1:2222", 7: "worker2:2222"}
        })
        union_cluster = UnionClusterResolver(
            *(SimpleClusterResolver(spec) for spec in (left_spec, right_spec)))

        # Both sparse jobs define task index 7.
        with self.assertRaises(KeyError):
            union_cluster.cluster_spec()
コード例 #7
0
    def testMergedClusterResolverMaster(self):
        """master() on a union of a ps-only and a worker-only resolver.

        With no arguments the result is empty. With a task it is that task's
        address, optionally prefixed by rpc_layer.
        """
        ps_resolver = SimpleClusterResolver(
            server_lib.ClusterSpec({"ps": ["ps0:2222", "ps1:2222"]}))
        worker_resolver = SimpleClusterResolver(server_lib.ClusterSpec(
            {"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}))
        union_cluster = UnionClusterResolver(ps_resolver, worker_resolver)

        self.assertEqual(union_cluster.master(), "")
        self.assertEqual(union_cluster.master("worker", 1), "worker1:2222")
        self.assertEqual(union_cluster.master("worker", 1, rpc_layer="grpc"),
                         "grpc://worker1:2222")
コード例 #8
0
  def setUp(self):
    """Creates a local cluster (2 workers, 2 ps) and a resolver for it."""
    super(MultiJobsTest, self).setUp()

    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)

    def _addresses(servers):
      # Cluster entries are the server targets with _GRPC_PREFIX removed.
      return [_strip_prefix(server.target, _GRPC_PREFIX) for server in servers]

    self._cluster = server_lib.ClusterSpec({
        'my_worker': _addresses(workers),
        'my_ps': _addresses(ps),
    })
    # Unlike the cluster entries, the master keeps the full (unstripped)
    # target of the first ps server.
    self._cluster_resolver = SimpleClusterResolver(
        cluster_spec=self._cluster, master=ps[0].target)
コード例 #9
0
    def testOverlappingJobMergedClusterResolver(self):
        """Two list jobs with the same name are concatenated.

        The second resolver's tasks are renumbered to follow the first's,
        giving task indices 0 through 4.
        """
        first_workers = ["worker4:2222", "worker5:2222"]
        second_workers = ["worker0:2222", "worker1:2222", "worker2:2222"]
        resolvers = [
            SimpleClusterResolver(server_lib.ClusterSpec({"worker": workers}))
            for workers in (first_workers, second_workers)
        ]
        merged_spec = UnionClusterResolver(*resolvers).cluster_spec()

        expected_proto = """
    job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
                         tasks { key: 1 value: 'worker5:2222' }
                         tasks { key: 2 value: 'worker0:2222' }
                         tasks { key: 3 value: 'worker1:2222' }
                         tasks { key: 4 value: 'worker2:2222' } }
    """
        self._verifyClusterSpecEquality(merged_spec, expected_proto)
コード例 #10
0
    def testUnionClusterResolverGetProperties(self):
        """UnionClusterResolver reports the first resolver's properties.

        task_type, task_id and rpc_layer can afterwards be overridden by
        assigning to the union resolver directly.
        """
        primary_spec = server_lib.ClusterSpec({
            "ps": ["ps0:2222", "ps1:2222"],
            "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
        })
        primary = SimpleClusterResolver(
            primary_spec,
            task_type="ps",
            task_id=1,
            environment="cloud",
            num_accelerators={"GPU": 8},
            rpc_layer="grpc",
        )

        secondary_spec = server_lib.ClusterSpec({
            "ps": ["ps2:2222", "ps3:2222"],
            "worker": ["worker3:2222", "worker4:2222", "worker5:2222"]
        })
        secondary = SimpleClusterResolver(
            secondary_spec,
            task_type="worker",
            task_id=2,
            environment="local",
            num_accelerators={"GPU": 16},
            rpc_layer="http",
        )

        union_resolver = UnionClusterResolver(primary, secondary)

        # Before any override, every value matches the primary resolver.
        self.assertEqual(union_resolver.task_type, "ps")
        self.assertEqual(union_resolver.task_id, 1)
        self.assertEqual(union_resolver.environment, "cloud")
        self.assertEqual(union_resolver.num_accelerators(), {"GPU": 8})
        self.assertEqual(union_resolver.rpc_layer, "grpc")

        # Assigning on the union resolver replaces the reported values.
        overrides = (("task_type", "worker"), ("task_id", 2),
                     ("rpc_layer", "http"))
        for attr_name, new_value in overrides:
            setattr(union_resolver, attr_name, new_value)
        for attr_name, new_value in overrides:
            self.assertEqual(getattr(union_resolver, attr_name), new_value)
コード例 #11
0
    def testSingleClusterResolver(self):
        """A union over a single resolver yields that resolver's spec as-is."""
        jobs = {
            "ps": ["ps0:2222", "ps1:2222"],
            "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
        }
        union_resolver = UnionClusterResolver(
            SimpleClusterResolver(server_lib.ClusterSpec(jobs)))
        actual_cluster_spec = union_resolver.cluster_spec()

        expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """
        self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
コード例 #12
0
    def testRetainSparseJobWithNoMerging(self):
        """A lone sparse job keeps its original task indices (1, 3, 5)."""
        sparse_tasks = {1: "worker0:2222", 3: "worker1:2222", 5: "worker2:2222"}
        union_cluster = UnionClusterResolver(
            SimpleClusterResolver(
                server_lib.ClusterSpec({"worker": sparse_tasks})))

        expected_proto = """
    job { name: 'worker' tasks { key: 1 value: 'worker0:2222' }
                         tasks { key: 3 value: 'worker1:2222' }
                         tasks { key: 5 value: 'worker2:2222' } }
    """
        self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                        expected_proto)
コード例 #13
0
ファイル: remote_test.py プロジェクト: zyx5256/tensorflow
 def testConnectToClusterWithLocalMaster(self):
     """Connects to a resolver with an empty ClusterSpec and master='local'."""
     empty_spec = ClusterSpec({})
     local_resolver = SimpleClusterResolver(cluster_spec=empty_spec,
                                            master='local')
     remote.connect_to_cluster(local_resolver)