def testArbitraryCurrentTaskType(self):
    """Constructing the strategy fails when the resolver's task_type is unknown.

    The cluster is built with one worker, one ps and one chief, while the
    resolver reports the current task as `task_type="foobar"`. The strategy
    constructor is expected to raise a ValueError naming that task type.
    """
    cluster_def = multi_worker_test_base._create_cluster(
        num_workers=1, num_ps=1)
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12;
    # `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(ValueError, "Unrecognized task_type: foobar"):
        parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def testArbitraryPsName(self):
    """Constructing the strategy fails when the cluster has an unexpected job name.

    The cluster is built with `ps_name="some_arbitrary_name"` (presumably this
    renames the ps job; verify against `_create_cluster`) plus a single chief.
    The strategy constructor is expected to raise a ValueError whose message
    starts with "Disallowed task type found in".
    """
    cluster_def = multi_worker_test_base._create_cluster(
        num_workers=1, num_ps=1, ps_name="some_arbitrary_name")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc")
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "Disallowed task type found in"):
        parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def testLessThanOneWorker(self):
    """Constructing the strategy fails when the cluster has zero workers.

    The cluster is requested with `num_workers=0`, one ps and one chief. The
    resolver identifies the current task as ps task 0. The strategy
    constructor is expected to raise a ValueError about needing a worker.
    """
    cluster_def = multi_worker_test_base._create_cluster(
        num_workers=0, num_ps=1)
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                "There must be at least one worker."):
        parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def testMoreThanOneChief(self):
    """Constructing the strategy fails when the "chief" job has several tasks.

    Three chief addresses are registered on freshly picked ports, and the
    resolver identifies the current task as chief task 1. The strategy
    constructor is expected to raise a ValueError about allowing at most one
    chief.
    """
    cluster_def = multi_worker_test_base._create_cluster(
        num_workers=1, num_ps=1)
    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
    # Use %d like the sibling tests: pick_unused_port() yields an int port.
    cluster_def["chief"] = ["localhost:%d" % port for port in chief_ports]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="chief",
        task_id=1)
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                "There must be at most one 'chief' job."):
        parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)