def setUp(self):
  """Build the per-test cluster, strategy and coordinators.

  Populates `self._cluster`, `self._cluster_def`, `self.strategy`,
  `self.cluster_coord` and `self.thread_coord`.
  """
  super(FaultToleranceTest, self).setUp()

  # Set the environment variable to prevent hanging upon job failure and
  # restart. Note that it defaults to 'use_caller' at Google, but defaults
  # to False in OSS.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  # Worker / PS counts come from the class-level constants.
  cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=FaultToleranceTest.NUM_WORKERS,
      num_ps=FaultToleranceTest.NUM_PS,
      rpc_layer="grpc")
  self._cluster = cluster

  # Start from the cluster's own spec and add a "chief" entry on a freshly
  # picked local port.
  cluster_def = cluster.cluster_resolver.cluster_spec().as_dict()
  self._cluster_def = cluster_def
  chief_address = "localhost:%d" % multi_worker_test_base.pick_unused_port()
  cluster_def["chief"] = [chief_address]

  resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(cluster_def), rpc_layer="grpc")

  # The strategy's constructor would connect to the cluster.
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)
  self.strategy = strategy
  self.cluster_coord = cluster_coordinator.ClusterCoordinator(strategy)

  # No exception type is registered as a "clean stop" for this coordinator.
  self.thread_coord = thread_coordinator.Coordinator(
      clean_stop_exception_types=[])
def setUp(self):
  """Create a 2-worker / 1-PS cluster with a chief and connect to it.

  Stores the cluster on `self._cluster`, connects to its cluster spec over
  gRPC, then calls `context.ensure_initialized()`.
  """
  super(MultiProcessClusterTest, self).setUp()

  cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=2, num_ps=1, has_chief=True, rpc_layer="grpc")
  self._cluster = cluster

  cluster_spec = cluster.cluster_resolver.cluster_spec()
  remote.connect_to_cluster(cluster_spec, protocol="grpc")

  context.ensure_initialized()
def setUpClass(cls):
  """Create the class-wide cluster, strategy and cluster coordinator.

  Populates `cls._cluster`, `cls._cluster_def`, `cls.strategy` and
  `cls.cluster_coord` using the public `tf.*` API.
  """
  super(EvaluationTest, cls).setUpClass()

  cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=3, num_ps=2, rpc_layer="grpc")
  cls._cluster = cluster

  cluster_def = cluster.cluster_resolver.cluster_spec().as_dict()
  cls._cluster_def = cluster_def

  resolver = SimpleClusterResolver(
      tf.train.ClusterSpec(cluster_def), rpc_layer="grpc")

  strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)
  cls.strategy = strategy

  coordinator_cls = tf.distribute.experimental.coordinator.ClusterCoordinator
  cls.cluster_coord = coordinator_cls(strategy)
def setUpClass(cls):
  """Create the class-wide cluster, strategy and cluster coordinator.

  Populates `cls._cluster`, `cls._cluster_def`, `cls.strategy` and
  `cls.cluster_coord`.
  """
  super(EvaluationTest, cls).setUpClass()

  cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=3, num_ps=2, rpc_layer="grpc")
  cls._cluster = cluster

  cluster_def = cluster.cluster_resolver.cluster_spec().as_dict()
  cls._cluster_def = cluster_def

  cluster_spec = server_lib.ClusterSpec(cluster_def)
  resolver = SimpleClusterResolver(cluster_spec, rpc_layer="grpc")

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)
  cls.strategy = strategy
  cls.cluster_coord = coordinator_lib.ClusterCoordinator(strategy)
def setUp(self, num_workers, num_ps):
  """Build a cluster of the requested size plus strategy and coordinators.

  Args:
    num_workers: Number of workers passed to the cluster factory.
    num_ps: Number of parameter servers passed to the cluster factory.

  Populates `self._cluster`, `self._cluster_def`, `self.strategy`,
  `self.cluster_coord`, `self.thread_coord`, `self.num_workers` and
  `self.num_ps`.
  """
  super(BaseFaultToleranceTest, self).setUp()

  cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  self._cluster = cluster

  # Start from the cluster's own spec and add a "chief" entry on a freshly
  # picked local port.
  cluster_def = cluster.cluster_resolver.cluster_spec().as_dict()
  self._cluster_def = cluster_def
  chief_address = "localhost:%d" % multi_worker_test_base.pick_unused_port()
  cluster_def["chief"] = [chief_address]

  resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(cluster_def), rpc_layer="grpc")

  # The strategy's constructor would connect to the cluster.
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)
  self.strategy = strategy
  self.cluster_coord = cluster_coordinator.ClusterCoordinator(strategy)

  # No exception type is registered as a "clean stop" for this coordinator.
  self.thread_coord = thread_coordinator.Coordinator(
      clean_stop_exception_types=[])

  # Record the requested cluster size on the instance.
  self.num_workers = num_workers
  self.num_ps = num_ps
def setUpClass(cls):
  """Create a shared 2-worker / 3-PS gRPC cluster for the test class.

  Stores the cluster on `cls.cluster` and its resolver on
  `cls.cluster_resolver`.
  """
  super(ParameterServerStrategyV2Test, cls).setUpClass()

  cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=2, num_ps=3, rpc_layer="grpc")
  cls.cluster = cluster
  cls.cluster_resolver = cluster.cluster_resolver
def setUpClass(cls):
  """Create a shared 2-worker / 2-PS gRPC cluster for the test class.

  Stores the cluster on `cls.cluster` and its resolver on
  `cls.cluster_resolver`.
  """
  super(VariablePartitioningTest, cls).setUpClass()

  cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=2, num_ps=2, rpc_layer="grpc")
  cls.cluster = cluster
  cls.cluster_resolver = cluster.cluster_resolver
def setUpClass(cls):
  """Create a shared 2-worker / 3-PS gRPC cluster for the test class.

  Stores the cluster on `cls.cluster` and its resolver on
  `cls.cluster_resolver`.
  """
  super(DistributedTableTest, cls).setUpClass()

  cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=2, num_ps=3, rpc_layer="grpc")
  cls.cluster = cluster
  cls.cluster_resolver = cluster.cluster_resolver