def setUp(self):
    """Start a gRPC multi-process cluster and attach a coordinator to it.

    The cluster is created with ``FaultToleranceTest.NUM_WORKERS`` workers
    and ``FaultToleranceTest.NUM_PS`` parameter servers. A ``chief`` entry
    is then added to the cluster definition before the strategy, the
    cluster coordinator and the thread coordinator are built from it.
    """
    super(FaultToleranceTest, self).setUp()

    # Keep gRPC from hanging when a job fails and restarts. The value is
    # 'use_caller' by default at Google but False in OSS, so set it
    # explicitly here.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=FaultToleranceTest.NUM_WORKERS,
        num_ps=FaultToleranceTest.NUM_PS,
        rpc_layer="grpc",
    )
    self._cluster = cluster
    self._cluster_def = cluster.cluster_resolver.cluster_spec().as_dict()

    # Add a chief task on a port obtained from pick_unused_port().
    chief_port = multi_worker_test_base.pick_unused_port()
    self._cluster_def["chief"] = ["localhost:%d" % chief_port]

    cluster_spec = server_lib.ClusterSpec(self._cluster_def)
    resolver = SimpleClusterResolver(cluster_spec, rpc_layer="grpc")

    # Constructing the strategy is what connects to the cluster.
    strategy_cls = parameter_server_strategy_v2.ParameterServerStrategyV2
    self.strategy = strategy_cls(resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    # No exception type is treated as a clean stop.
    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[],
    )
Exemple #2
0
 def setUp(self):
     """Create a 2-worker, 1-PS gRPC cluster with a chief and connect to it.

     After the cluster is created, its cluster spec is passed to
     ``remote.connect_to_cluster`` and ``context.ensure_initialized()`` is
     called.
     """
     super(MultiProcessClusterTest, self).setUp()

     cluster = multi_worker_test_base.create_multi_process_cluster(
         num_workers=2,
         num_ps=1,
         has_chief=True,
         rpc_layer="grpc",
     )
     self._cluster = cluster

     cluster_spec = cluster.cluster_resolver.cluster_spec()
     remote.connect_to_cluster(cluster_spec, protocol="grpc")
     context.ensure_initialized()
Exemple #3
0
  def setUpClass(cls):
    """Create one shared 3-worker, 2-PS gRPC cluster for the test class.

    Stores the cluster, its definition dict, the parameter-server strategy
    and the cluster coordinator as class attributes.

    NOTE(review): this takes ``cls``; the ``@classmethod`` decorator is
    presumably applied just above this excerpt -- confirm.
    """
    super(EvaluationTest, cls).setUpClass()

    cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=3,
        num_ps=2,
        rpc_layer="grpc",
    )
    cls._cluster = cluster
    cls._cluster_def = cluster.cluster_resolver.cluster_spec().as_dict()

    cluster_spec = tf.train.ClusterSpec(cls._cluster_def)
    resolver = SimpleClusterResolver(cluster_spec, rpc_layer="grpc")

    cls.strategy = tf.distribute.experimental.ParameterServerStrategy(
        resolver)
    cls.cluster_coord = (
        tf.distribute.experimental.coordinator.ClusterCoordinator(
            cls.strategy))
    def setUpClass(cls):
        """Create one shared 3-worker, 2-PS gRPC cluster for the test class.

        Stores the cluster, its definition dict, a
        ``ParameterServerStrategyV2`` and a ``ClusterCoordinator`` built on
        that strategy as class attributes.

        NOTE(review): this takes ``cls``; the ``@classmethod`` decorator is
        presumably applied just above this excerpt -- confirm.
        """
        super(EvaluationTest, cls).setUpClass()

        cluster = multi_worker_test_base.create_multi_process_cluster(
            num_workers=3,
            num_ps=2,
            rpc_layer="grpc",
        )
        cls._cluster = cluster
        cls._cluster_def = cluster.cluster_resolver.cluster_spec().as_dict()

        cluster_spec = server_lib.ClusterSpec(cls._cluster_def)
        resolver = SimpleClusterResolver(cluster_spec, rpc_layer="grpc")

        strategy_cls = parameter_server_strategy_v2.ParameterServerStrategyV2
        cls.strategy = strategy_cls(resolver)
        cls.cluster_coord = coordinator_lib.ClusterCoordinator(cls.strategy)
  def setUp(self, num_workers, num_ps):
    """Start a gRPC multi-process cluster of the requested size.

    Args:
      num_workers: Number of workers passed to the cluster factory.
      num_ps: Number of parameter servers passed to the cluster factory.

    A ``chief`` entry is added to the cluster definition, then the
    strategy, cluster coordinator and thread coordinator are built. The two
    sizes are recorded on the instance last.
    """
    super(BaseFaultToleranceTest, self).setUp()

    cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=num_workers,
        num_ps=num_ps,
        rpc_layer="grpc",
    )
    self._cluster = cluster
    self._cluster_def = cluster.cluster_resolver.cluster_spec().as_dict()

    # Add a chief task on a port obtained from pick_unused_port().
    chief_port = multi_worker_test_base.pick_unused_port()
    self._cluster_def["chief"] = ["localhost:%d" % chief_port]

    cluster_spec = server_lib.ClusterSpec(self._cluster_def)
    resolver = SimpleClusterResolver(cluster_spec, rpc_layer="grpc")

    # Constructing the strategy is what connects to the cluster.
    strategy_cls = parameter_server_strategy_v2.ParameterServerStrategyV2
    self.strategy = strategy_cls(resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    # No exception type is treated as a clean stop.
    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[],
    )

    self.num_workers = num_workers
    self.num_ps = num_ps
 def setUpClass(cls):
     """Create a shared 2-worker, 3-PS gRPC cluster for the test class.

     Exposes the cluster and its resolver as ``cls.cluster`` and
     ``cls.cluster_resolver``.

     NOTE(review): this takes ``cls``; the ``@classmethod`` decorator is
     presumably applied just above this excerpt -- confirm.
     """
     super(ParameterServerStrategyV2Test, cls).setUpClass()

     cluster = multi_worker_test_base.create_multi_process_cluster(
         num_workers=2,
         num_ps=3,
         rpc_layer="grpc",
     )
     cls.cluster = cluster
     cls.cluster_resolver = cluster.cluster_resolver
 def setUpClass(cls):
     """Create a shared 2-worker, 2-PS gRPC cluster for the test class.

     Exposes the cluster and its resolver as ``cls.cluster`` and
     ``cls.cluster_resolver``.

     NOTE(review): this takes ``cls``; the ``@classmethod`` decorator is
     presumably applied just above this excerpt -- confirm.
     """
     super(VariablePartitioningTest, cls).setUpClass()

     cluster = multi_worker_test_base.create_multi_process_cluster(
         num_workers=2,
         num_ps=2,
         rpc_layer="grpc",
     )
     cls.cluster = cluster
     cls.cluster_resolver = cluster.cluster_resolver
Exemple #8
0
 def setUpClass(cls):
   """Create a shared 2-worker, 3-PS gRPC cluster for the test class.

   Exposes the cluster and its resolver as ``cls.cluster`` and
   ``cls.cluster_resolver``.

   NOTE(review): this takes ``cls``; the ``@classmethod`` decorator is
   presumably applied just above this excerpt -- confirm.
   """
   super(DistributedTableTest, cls).setUpClass()

   cluster = multi_worker_test_base.create_multi_process_cluster(
       num_workers=2,
       num_ps=3,
       rpc_layer="grpc",
   )
   cls.cluster = cluster
   cls.cluster_resolver = cluster.cluster_resolver