def testSimpleOverrideMasterWithTaskIndexZero(self):
    """master() for task index 0 is prefixed with the requested rpc_layer."""
    jobs = {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
    }
    resolver = SimpleClusterResolver(server_lib.ClusterSpec(jobs))
    # Index 0 is exercised explicitly, presumably to guard against a falsy
    # index being treated as "unset" -- NOTE(review): confirm intent.
    master_address = resolver.master("worker", 0, rpc_layer="grpc")
    self.assertEqual(master_address, "grpc://worker0:2222")
  def testSimpleOverrideMaster(self):
    """Without an rpc_layer, master() yields the task's bare address."""
    jobs = {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
    }
    resolver = SimpleClusterResolver(server_lib.ClusterSpec(jobs))
    # Third worker task (index 2); no rpc_layer argument is passed.
    master_address = resolver.master("worker", 2)
    self.assertEqual(master_address, "worker2:2222")
  def testInitSimpleClusterResolver(self):
    """Constructor keyword arguments are readable back from the resolver."""
    jobs = {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
    }
    resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(jobs),
        task_type="ps",
        task_index=1,
        environment="cloud",
        num_accelerators=8,
        rpc_layer="grpc",
    )

    # Attributes are exposed as properties; accelerator count via a method.
    self.assertEqual(resolver.task_type, "ps")
    self.assertEqual(resolver.task_index, 1)
    self.assertEqual(resolver.environment, "cloud")
    self.assertEqual(resolver.num_accelerators(), 8)
    self.assertEqual(resolver.rpc_layer, "grpc")
  def testOverrideSimpleClusterResolver(self):
    """Assigning task_type, task_index and rpc_layer replaces constructor values.

    The resolver is built with the same keyword arguments as
    testInitSimpleClusterResolver. The three settable attributes are then
    reassigned, and the test checks that the new values are returned.
    """
    base_cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })

    # Use ``num_accelerators`` (not ``num_accelerators_per_worker``). This is
    # the keyword testInitSimpleClusterResolver passes to the same constructor
    # and reads back via num_accelerators(); a keyword the constructor does
    # not accept would raise TypeError before any override is exercised.
    simple_resolver = SimpleClusterResolver(base_cluster_spec, task_type="ps",
                                            task_index=1, environment="cloud",
                                            num_accelerators=8,
                                            rpc_layer="grpc")

    simple_resolver.task_type = "worker"
    simple_resolver.task_index = 2
    simple_resolver.rpc_layer = "http"

    self.assertEqual(simple_resolver.task_type, "worker")
    self.assertEqual(simple_resolver.task_index, 2)
    self.assertEqual(simple_resolver.rpc_layer, "http")
Example #5
0
 def setUp(self):
   """Attach a SimpleClusterResolver built from get_cluster_def(2 workers, 3 ps)."""
   super().setUp()
   spec = ClusterSpec(
       get_cluster_def(test_cluster_params, num_workers=2, num_ps=3))
   self.cluster_resolver = SimpleClusterResolver(spec)
Example #6
0
 def testConnectToClusterWithLocalMaster(self):
     """Smoke test: connect_to_cluster with an empty spec and master='local'.

     Nothing is asserted; the test passes as long as no exception is raised.
     """
     empty_spec = ClusterSpec({})
     remote.connect_to_cluster(
         SimpleClusterResolver(empty_spec, master='local'))
Example #7
0
 @classmethod
 def setUpClass(cls):
     """Create a cluster (2 workers, 3 ps) once and store a resolver on the class.

     unittest invokes ``setUpClass()`` on the class with no arguments, so the
     method must be a classmethod for ``cls`` to be bound.
     """
     super(ParameterServerStrategyV2Test, cls).setUpClass()
     cluster_def = multi_worker_test_base.create_in_process_cluster(
         num_workers=2, num_ps=3)
     # Class attribute: visible to every test method of this class.
     cls.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))
Example #8
0
 @classmethod
 def setUpClass(cls):
     """Create a cluster (2 workers, 2 ps) once and store a resolver on the class.

     unittest invokes ``setUpClass()`` on the class with no arguments, so the
     method must be a classmethod for ``cls`` to be bound.
     """
     super(VariablePartitioningTest, cls).setUpClass()
     cluster_def = multi_worker_test_base.create_in_process_cluster(
         num_workers=2, num_ps=2)
     # Class attribute: visible to every test method of this class.
     cls.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))
Example #9
0
def make_parameter_server_cluster(num_workers, num_ps):
  """Create a cluster and return a SimpleClusterResolver for it.

  Args:
    num_workers: number of worker tasks to request.
    num_ps: number of parameter-server tasks to request.

  Returns:
    A SimpleClusterResolver over the created cluster. Both the cluster
    creation and the resolver use the "grpc" rpc layer.
  """
  rpc_layer = "grpc"
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer=rpc_layer)
  spec = ClusterSpec(cluster_def)
  return SimpleClusterResolver(spec, rpc_layer=rpc_layer)
Example #10
0
 def _get_parameter_server_strategy(self):
   """Return a ParameterServerStrategyV2 over a new 2-worker, 1-ps cluster.

   The cluster and the resolver both use the "grpc" rpc layer.
   """
   cluster_def = multi_worker_test_base.create_in_process_cluster(
       num_workers=2, num_ps=1, rpc_layer="grpc")
   resolver = SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc")
   return parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)
Example #11
0
 def _get_parameter_server_strategy(self):
     """Return a tf.distribute ParameterServerStrategy over a 2-worker, 1-ps cluster.

     The cluster and the resolver both use the "grpc" rpc layer.
     """
     cluster_def = multi_worker_testing_utils.create_in_process_cluster(
         num_workers=2, num_ps=1, rpc_layer="grpc")
     resolver = SimpleClusterResolver(ClusterSpec(cluster_def),
                                      rpc_layer="grpc")
     return tf.distribute.experimental.ParameterServerStrategy(resolver)