コード例 #1
0
 def testDistributeDatasetNotUsedWithClusterCoordinator(self):
   """Distributing a dataset from inside a tf.function is a usage error.

   The failure is asserted through `self._assertRaisesUsageError()`.
   """
   strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
       self.cluster_resolver)
   dataset = dataset_ops.DatasetV2.range(3)

   @def_function.function
   def distribute_in_function():
     return strategy.experimental_distribute_dataset(dataset)

   with self._assertRaisesUsageError():
     distribute_in_function()
コード例 #2
0
  def testInModelAndCapture(self, source):
    """Compares table captures of `use_table` with and without the strategy.

    Without a strategy, the table's resource handle is an ordinary external
    capture of `use_table`'s concrete function, and there are no deferred
    captures. Under ParameterServerStrategyV2 the handle is deferred instead.
    That leaves one fewer external capture and a non-empty set of deferred
    captures.

    Args:
      source: initializer source forwarded to `self.Model` for both the local
        and the distributed model.
    """
    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

    model = self.Model(source, file_path)
    func_captures = model.use_table.get_concrete_function(
    ).graph.external_captures
    self.assertLen(func_captures, 2)
    self.assertTrue(
        any(model.table.resource_handle is t for t in func_captures))
    deferred_captures = model.use_table.get_concrete_function(
    ).graph.deferred_external_captures
    self.assertEmpty(deferred_captures)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    coordinator = coordinator_lib.ClusterCoordinator(strategy)
    with strategy.scope():
      # Use the parameterized `source` so the distributed half exercises the
      # same initializer as the local half.
      distributed_model = self.Model(source, file_path)
    func_captures = distributed_model.use_table.get_concrete_function(
    ).graph.external_captures
    # One less external_capture, since the table handle becomes a closure in the
    # deferred_external_capture
    self.assertLen(func_captures, 1)
    # Compare against this model's own table handle; `model.table` belongs to
    # a different table and could never appear in these captures.
    self.assertFalse(
        any(distributed_model.table.resource_handle is t
            for t in func_captures))
    deferred_captures = distributed_model.use_table.get_concrete_function(
    ).graph.deferred_external_captures
    self.assertNotEmpty(deferred_captures)

    self.verifyWorkerLocalInstance(coordinator, distributed_model)
コード例 #3
0
  def testDistributeTableSaveAndServe(self, load, source):
    """Saves a strategy-built table model and serves it without a strategy.

    Args:
      load: "tf_load" reloads with `tf_load.load`; any other value reloads
        with `keras_save.load_model`.
      source: initializer source forwarded to `self.Model`.
    """
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
    with strategy.scope():
      model = self.Model(source, file_path)

    model_dir = self.get_temp_dir()
    tf_save.save(model, model_dir)

    load_fn = tf_load.load if load == "tf_load" else keras_save.load_model
    loaded = load_fn(model_dir)

    # Once loaded without a strategy, the table handle is an ordinary capture.
    loaded_graph = loaded.use_table.get_concrete_function().graph
    plain_captures = loaded_graph.external_captures
    deferred_captures = loaded_graph.deferred_external_captures
    self.assertLen(plain_captures, 2)
    self.assertEmpty(deferred_captures)

    lookup_keys = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
    self.assertAllEqual(loaded.use_table(lookup_keys), [0, 1, -2])
コード例 #4
0
  def _model_compile(self, steps_per_execution=1, run_eagerly=False):
    """Builds and compiles a one-layer Keras model under a PS strategy.

    The strategy is a ParameterServerStrategyV2 over `make_cluster(3, 2)`.

    Args:
      steps_per_execution: forwarded to `model.compile`.
      run_eagerly: forwarded to `model.compile`.

    Returns:
      A `(model, callbacks)` tuple. `callbacks` holds a single callback that
      checks epoch ordering and that a float loss is logged every epoch.
    """

    class ResultAssertingCallback(callbacks_lib.Callback):
      """Raises RuntimeError if epochs go backwards or the loss is missing."""

      def __init__(self):
        # Run the base initializer so the attributes `Callback.__init__` sets
        # exist before Keras starts using this callback.
        super().__init__()
        self._prev_epoch = -1

      def on_epoch_end(self, epoch, logs=None):
        logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
        if epoch <= self._prev_epoch:
          raise RuntimeError("Epoch is supposed to be larger than previous.")
        self._prev_epoch = epoch
        # `logs` defaults to None; treat that the same as a missing loss.
        loss = (logs or {}).get("loss")
        # Accept Python floats as well as NumPy floating scalars.
        if loss is None or not isinstance(loss, (float, np.floating)):
          raise RuntimeError("loss is supposed to be in the logs and float.")

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        make_cluster(3, 2))
    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
    model.compile(
        gradient_descent.SGD(),
        loss="mse",
        steps_per_execution=steps_per_execution,
        run_eagerly=run_eagerly)
    return model, [ResultAssertingCallback()]
コード例 #5
0
    def setUp(self, num_workers, num_ps):
        """Starts a multi-process cluster and connects a PS strategy to it.

        Args:
          num_workers: number of worker tasks in the created cluster.
          num_ps: number of parameter-server tasks in the created cluster.
        """
        super(BaseFaultToleranceTest, self).setUp()

        # Set the environment variable to prevent hanging upon job failure
        # and restart. Note that it defaults to 'use_caller' at Google, but
        # defaults to False in OSS.
        os.environ["GRPC_FAIL_FAST"] = "use_caller"

        self._cluster = multi_worker_test_base.create_multi_process_cluster(
            num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
        cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
        self._cluster_def = cluster_def
        # This process acts as the chief, on an unused local port.
        chief_port = multi_worker_test_base.pick_unused_port()
        cluster_def["chief"] = ["localhost:%d" % chief_port]
        resolver = SimpleClusterResolver(
            server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

        # The strategy's constructor would connect to the cluster.
        self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            resolver)
        self.cluster_coord = cluster_coordinator.ClusterCoordinator(
            self.strategy)

        self.thread_coord = thread_coordinator.Coordinator(
            clean_stop_exception_types=[])
        self.num_workers = num_workers
コード例 #6
0
    def testBasicVariableWithAggregation(self):
        """Per-replica assignments to a SUM-aggregated variable are summed."""
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver)
        # Opt in to calling `strategy.run` without a ClusterCoordinator.
        strategy.extended._allow_run_without_coordinator = True
        with strategy.scope():
            v = variables.Variable(
                initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
                dtype=dtypes.float32,
                aggregation=variable_scope.VariableAggregation.SUM)

        num_replicas = strategy.num_replicas_in_sync
        expected_type = (ps_values.AggregatingVariable
                         if num_replicas > 1 else variables.Variable)
        self.assertIsInstance(v, expected_type)

        def replica_fn():
            ctx = distribution_strategy_context.get_replica_context()
            head = array_ops.reshape(
                math_ops.cast(ctx.replica_id_in_sync_group + 10,
                              dtype=v.dtype), [1])
            tail = constant_op.constant([1., 2., 3., 4., 5., 6., 7.])
            v.assign(array_ops.concat([head, tail], 0))

        strategy.run(replica_fn)

        expected_result = np.arange(8.) * num_replicas
        # Element 0 receives `replica_id + 10` from every replica.
        expected_result[0] += sum(i + 10 for i in range(num_replicas))
        self.assertAllEqual(v, expected_result)
コード例 #7
0
    def testInteractionWithDeviceScope(self):
        """The strategy, not an `ops.device` scope, chooses the PS task.

        The expected placements are ps tasks 0, 1, 2 and then 0 again,
        regardless of the device scope wrapped around each creation.
        """
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver)

        # The strategy scope always wins: a device scope nested inside it is
        # ignored.
        with strategy.scope():
            with ops.device("/job:ps/replica:0/task:1"):
                v0 = variables.Variable(initial_value=0.0)
            self.assertEqual(v0.device, "/job:ps/replica:0/task:0/device:CPU:0")

            with ops.device("/job:ps/replica:0/task:0"):
                v1 = variables.Variable(initial_value=0.0)
            self.assertEqual(v1.device, "/job:ps/replica:0/task:1/device:CPU:0")

        # A device scope wrapped around the strategy scope is ignored too.
        with ops.device("/job:ps/replica:0/task:1"), strategy.scope():
            v2 = variables.Variable(initial_value=0.0)
            self.assertEqual(v2.device, "/job:ps/replica:0/task:2/device:CPU:0")

            v3 = variables.Variable(initial_value=0.0)
            self.assertEqual(v3.device, "/job:ps/replica:0/task:0/device:CPU:0")
コード例 #8
0
    def testDefaultNoPartition(self):
        """Without a partitioner, the strategy yields a `variables.Variable`."""
        ps_strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver)
        with ps_strategy.scope():
            unpartitioned = variables.Variable([0, 1, 2, 3])
        self.assertIsInstance(unpartitioned, variables.Variable)
コード例 #9
0
    def testBasic(self):
        """FixedShardsPartitioner(2) splits each variable into two PS shards."""
        partitioner = sharded_variable.FixedShardsPartitioner(2)
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver, partitioner)
        with strategy.scope():
            init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            v1 = variables.Variable(
                initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
                shape=(5, 2),
                dtype=dtypes.int64)

            init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
            v2 = variables.Variable(
                initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
                shape=(6, 1),
                dtype=dtypes.int64)

        # (variable, expected first shard, expected second shard)
        cases = [
            (v1, [[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9]]),
            (v2, [[0], [1], [2]], [[3], [4], [5]]),
        ]
        for var, first_shard, second_shard in cases:
            self.assertIsInstance(var, sharded_variable.ShardedVariable)
            self.assertLen(var.variables, 2)
            self.assertRegex(var.variables[0].device, "/job:ps/replica:0/task:0")
            self.assertRegex(var.variables[1].device, "/job:ps/replica:0/task:1")
            self.assertAllEqual(var.variables[0], first_shard)
            self.assertAllEqual(var.variables[1], second_shard)
コード例 #10
0
  def testCreateInsideTFFunction(self):
    """Variables created inside a tf.function are sharded by the partitioner."""
    partitioner = sharded_variable.FixedShardsPartitioner(2)
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, partitioner)

    created = []

    @def_function.function
    def create_vars():
      # Create the variables only once; later traces see a non-empty list.
      if not created:
        identity = init_ops_v2.Identity()
        created.extend([
            variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32),
            variables.Variable(lambda: identity((2, 2), dtypes.float32)),
            variables.Variable(
                lambda: identity((2, 2), dtypes.float32),
                dtype=dtypes.float32,
                shape=(2, 2)),
        ])

    with strategy.scope():
      create_vars()
      for var in created:
        self.assertIsInstance(var, sharded_variable.ShardedVariable)
        self.assertLen(var.variables, 2)
        self.assertRegex(var.variables[0].device, "/job:ps/replica:0/task:0")
        self.assertRegex(var.variables[1].device, "/job:ps/replica:0/task:1")
        self.assertAllEqual(var.variables[0], [[1., 0.]])
        self.assertAllEqual(var.variables[1], [[0., 1.]])
コード例 #11
0
 def testArbitraryCurrentTaskType(self):
   """An unrecognized current task_type is rejected with a ValueError."""
   cluster_def = multi_worker_test_base.create_cluster_spec(
       num_workers=1, num_ps=1, has_chief=True)
   cluster_resolver = SimpleClusterResolver(
       ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
   with self.assertRaisesRegex(ValueError, "Unrecognized task_type: foobar"):
     parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
コード例 #12
0
        def fn(functions_scheduled_event, test_finished_event):
            """Per-process body of the PS-failure translation test.

            Builds a ParameterServerStrategyV2 from TF_CONFIG (via
            TFConfigClusterResolver) and schedules `worker_fn` twice. It then
            sets `functions_scheduled_event` and expects `join()` and/or
            `schedule()` to raise `errors.UnavailableError`. Which of the two
            is exercised depends on `test_join` / `test_schedule`, which come
            from an enclosing scope not visible here.

            After the UnavailableError, it runs `worker_fn` under the device
            scopes of worker tasks 0, 1 and 2. Each run must fail with an
            error that `coordinator_lib._is_ps_failure` recognizes. On task 2
            it sets `test_finished_event` and returns.

            NOTE(review): non-chief tasks call `utils.start_server` first;
            whether that call blocks (so only the chief continues) is not
            visible here -- confirm.

            Args:
              functions_scheduled_event: event set once both `worker_fn`
                calls have been scheduled.
              test_finished_event: event set when the PS-failure check on the
                last worker succeeds.

            Raises:
              RuntimeError: if no UnavailableError is raised, or if running
                `worker_fn` after the failure does not raise a recognized PS
                failure.
            """
            # TODO(b/170664373): This is needed for TF2 parameter server training in
            # OSS. Remove this when resolved.
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=0, dtype=dtypes.int32)

            @def_function.function
            def worker_fn():
                # A long-running function (100000 increments) that keeps a
                # worker busy.
                for _ in math_ops.range(100000):
                    v.assign_add(1)

            # Keep the two workers occupied.
            # NOTE(review): the PS-failure loop below iterates worker ids 0-2;
            # confirm the actual number of workers in this cluster.
            ps_coordinator.schedule(worker_fn)
            ps_coordinator.schedule(worker_fn)
            # Now the main process can terminate.
            functions_scheduled_event.set()

            # Verify that join and schedule indeed raise UnavailableError.
            try:
                if test_join:
                    ps_coordinator.join()
                if test_schedule:
                    # Poll the coordinator's private closure queue until it
                    # records an error, then schedule again so it re-raises.
                    while ps_coordinator.cluster._closure_queue._error is None:
                        time.sleep(1)
                    ps_coordinator.schedule(worker_fn)
            except errors.UnavailableError:
                # The following verifies that after PS fails, continue executing
                # functions on workers should fail and indicate it's PS failure.
                for worker_id in range(3):
                    with ops.device(
                            "/job:worker/replica:0/task:{}".format(worker_id)):
                        try:
                            # Executing a function after PS fails should result in a PS
                            # failure.
                            worker_fn()
                        except Exception as e:  # pylint: disable=broad-except
                            if coordinator_lib._is_ps_failure(e):
                                # Tasks 0 and 1 only need to fail correctly;
                                # task 2 finishes the test.
                                if worker_id < 2:
                                    continue
                                logging.info(
                                    "_test_translate_ps_failure_error ends properly."
                                )
                                # Now we can safely exit the test.
                                test_finished_event.set()
                                return
                        # Reached when worker_fn succeeded or failed with a
                        # non-PS error.
                        raise RuntimeError(
                            "Executing a function after PS fails, should "
                            "result in a PS failure.")

            raise RuntimeError("UnavailableError supposed to be raised.")
コード例 #13
0
  def _model_compile(self,
                     steps_per_execution=1,
                     run_eagerly=False,
                     with_normalization_layer=False):
    """Builds and compiles a small Keras model under a sharding PS strategy.

    The strategy is a ParameterServerStrategyV2 over `make_cluster(3, 2)`,
    with variables split by `FixedShardsPartitioner(2)`.

    Args:
      steps_per_execution: forwarded to `model.compile`.
      run_eagerly: forwarded to `model.compile`.
      with_normalization_layer: if True, append a BatchNormalization layer
        after the Dense layer.

    Returns:
      A `(model, callbacks)` tuple. `callbacks` holds a single callback that
      checks epoch ordering and that a float loss is logged every epoch.
    """

    class ResultAssertingCallback(callbacks_lib.Callback):
      """Raises RuntimeError if epochs go backwards or the loss is missing."""

      def __init__(self):
        # Run the base initializer so the attributes `Callback.__init__` sets
        # exist before Keras starts using this callback.
        super().__init__()
        self._prev_epoch = -1

      def on_epoch_end(self, epoch, logs=None):
        logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
        if epoch <= self._prev_epoch:
          raise RuntimeError("Epoch is supposed to be larger than previous.")
        self._prev_epoch = epoch
        # `logs` defaults to None; treat that the same as a missing loss.
        loss = (logs or {}).get("loss")
        # Accept Python floats as well as NumPy floating scalars.
        if loss is None or not isinstance(loss, (float, np.floating)):
          raise RuntimeError("loss is supposed to be in the logs and float.")

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        make_cluster(3, 2),
        variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
      if with_normalization_layer:
        norm = keras.layers.BatchNormalization(
            axis=-1, input_shape=(4, 4, 3), momentum=0.8)
        model.add(norm)

    model.compile(
        gradient_descent.SGD(),
        loss="mse",
        steps_per_execution=steps_per_execution,
        run_eagerly=run_eagerly)
    return model, [ResultAssertingCallback()]
    def testPartitionWhenLackOfInfo(self):
        """Partitioning still applies when shape and/or dtype are not given."""
        partitioner = partitioned_variables.fixed_size_partitioner(2)
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver, partitioner)
        with strategy.scope():
            initializer = init_ops_v2.Constant([0, 1, 2, 3])

            def init_fn():
                return initializer(shape=(4, ), dtype=dtypes.int64)

            # Shape is not explicitly specified.
            v1 = variables.Variable(initial_value=init_fn, dtype=dtypes.int64)
            # Dtype is not explicitly specified.
            v2 = variables.Variable(initial_value=init_fn, shape=(4, ))
            # Neither shape nor dtype is explicitly specified.
            v3 = variables.Variable(initial_value=init_fn)

        for var in (v1, v2, v3):
            self.assertIsInstance(var, sharded_variable.ShardedVariable)
            self.assertLen(var.variables, 2)
            self.assertRegex(var.variables[0].device, "/job:ps/replica:0/task:0")
            self.assertRegex(var.variables[1].device, "/job:ps/replica:0/task:1")
            self.assertAllEqual(var.variables[0], [0, 1])
            self.assertAllEqual(var.variables[1], [2, 3])
コード例 #15
0
  def _model_compile(self,
                     strategy,
                     steps_per_execution=1,
                     run_eagerly=False,
                     with_normalization_layer=False):
    """Builds and compiles a small Keras model under `strategy`.

    Args:
      strategy: a distribution strategy, or the string
        "ParameterServerStrategy". The string builds a
        ParameterServerStrategyV2 over
        `multi_worker_testing_utils.make_parameter_server_cluster(3, 2)`
        with `FixedShardsPartitioner(2)`, and skips the test when more than
        one GPU is present.
      steps_per_execution: forwarded to `model.compile`.
      run_eagerly: forwarded to `model.compile`.
      with_normalization_layer: if True, append a BatchNormalization layer
        after the Dense layer.

    Returns:
      A `(model, callbacks)` tuple. `callbacks` holds a callback that checks:
      epoch ordering; that a float loss is logged; that the epoch-0 loss is
      at most 2 and the epoch-9 loss at most the epoch-0 loss; and that
      training ends after epoch 9.
    """

    class ResultAssertingCallback(callbacks_lib.Callback):
      """Raises RuntimeError when epoch ordering or losses look wrong."""

      def __init__(self):
        # Run the base initializer so the attributes `Callback.__init__` sets
        # exist before Keras starts using this callback.
        super().__init__()
        self._prev_epoch = -1
        self._loss_to_compare_against = 2  # Empirical initial value

      def on_epoch_end(self, epoch, logs=None):
        logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
        if epoch <= self._prev_epoch:
          raise RuntimeError("Epoch is supposed to be larger than previous.")
        self._prev_epoch = epoch
        # `logs` defaults to None; treat that the same as a missing loss.
        loss = (logs or {}).get("loss")
        if loss is None or not isinstance(loss, (float, np.floating)):
          raise RuntimeError("loss is supposed to be in the logs and float.")
        if epoch in (0, 9):
          # Making sure the loss of first epoch is at most 2 (empirical bound),
          # and that of last epoch is not larger than the first epoch's.
          if loss > self._loss_to_compare_against:
            raise RuntimeError(
                "loss at epoch {} is larger than previous.".format(epoch))
          self._loss_to_compare_against = loss

      def on_train_end(self, logs=None):
        # Assumes the caller trains for exactly 10 epochs (0..9).
        if self._prev_epoch != 9:
          raise RuntimeError("Unexpected last epoch: {}".format(
              self._prev_epoch))

    # TODO(b/182193218): Use ParameterServerStrategy as a proper strategy
    # combination.
    if strategy == "ParameterServerStrategy":
      gpu_devices = config.list_physical_devices("GPU")
      if len(gpu_devices) > 1:
        self.skipTest("b/178452835: Multi-GPUs not supported in "
                      "ParameterServerStrategy.")
      strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
          multi_worker_testing_utils.make_parameter_server_cluster(3, 2),
          variable_partitioner=sharded_variable.FixedShardsPartitioner(2))

    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
      if with_normalization_layer:
        norm = keras.layers.BatchNormalization(
            axis=-1, input_shape=(4, 4, 3), momentum=0.8)
        model.add(norm)

    model.compile(
        gradient_descent.SGD(),
        loss="mse",
        steps_per_execution=steps_per_execution,
        run_eagerly=run_eagerly)
    return model, [ResultAssertingCallback()]
コード例 #16
0
    def testSparselyReadForEmbeddingLookup(self):
        """Embedding lookups on PS variables use ResourceGather, not Gather.

        Checks the traced function graph and every function in the saved
        SavedModel's library. In the saved library, every ResourceGather node
        must carry an empty device string.
        """
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver)

        class FakeModel(module.Module):
            def __init__(self):
                self._var0 = variables.Variable([1.0, 2.0, 3.0, 4.0])
                self._var1 = variables.Variable([5.0, 6.0, 7.0, 8.0])

            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(shape=[2],
                                       dtype=dtypes.int32,
                                       name="inputs")
            ])
            def func(self, x):
                return embedding_ops.embedding_lookup([self._var0, self._var1],
                                                      x)

        with strategy.scope():
            model = FakeModel()

        # Assert that ResourceGather op exists instead of Gather in training
        # function.
        graph_def = model.func.get_concrete_function().graph.as_graph_def()
        graph_ops = {node.op for node in graph_def.node}
        self.assertIn("ResourceGather", graph_ops)
        self.assertNotIn("Gather", graph_ops)

        tmp_dir = self.get_temp_dir()
        save.save(model, tmp_dir, signatures=model.func)

        with gfile.Open("%s/saved_model.pb" % tmp_dir, "rb") as f:
            # `FromString` is a class-level constructor; no empty proto needed.
            saved_model_proto = saved_model_pb2.SavedModel.FromString(f.read())

        saved_ops = set()
        resource_gather_devices = []
        library = saved_model_proto.meta_graphs[0].graph_def.library
        for fdef in library.function:
            for node in fdef.node_def:
                saved_ops.add(node.op)
                if node.op == "ResourceGather":
                    resource_gather_devices.append(node.device)

        # Assert that ResourceGather op exists instead of Gather in saved_model.
        self.assertIn("ResourceGather", saved_ops)
        self.assertNotIn("Gather", saved_ops)

        # We also assert that the colocate_with in embedding_ops will not result
        # in a hard-coded device string -- on every ResourceGather node, not
        # just the last one encountered.
        for device in resource_gather_devices:
            self.assertEmpty(device)
コード例 #17
0
 def testLessThanOneWorker(self):
   """A cluster with no workers is rejected with a ValueError."""
   cluster_def = multi_worker_test_base.create_cluster_spec(
       num_workers=0, num_ps=1, has_chief=True)
   cluster_resolver = SimpleClusterResolver(
       ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
   with self.assertRaisesRegex(ValueError,
                               "There must be at least one worker."):
     parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
コード例 #18
0
 def _create_strategy(self, num_shards):
     """Returns a strategy sized for `num_shards`.

     Args:
       num_shards: number of variable shards. Values greater than 1 yield a
         ParameterServerStrategyV2 with a FixedShardsPartitioner of that
         size; anything else yields the default strategy.
     """
     if num_shards > 1:
         partitioner = sharded_variable.FixedShardsPartitioner(num_shards)
         return parameter_server_strategy_v2.ParameterServerStrategyV2(
             self.cluster_resolver, variable_partitioner=partitioner)
     return ds_context._get_default_strategy()
コード例 #19
0
 def testArbitraryCurrentTaskType(self):
   """An unrecognized current task_type is rejected with a ValueError."""
   cluster_def = multi_worker_test_base._create_cluster(
       num_workers=1, num_ps=1)
   cluster_def["chief"] = [
       "localhost:%d" % multi_worker_test_base.pick_unused_port()
   ]
   cluster_resolver = SimpleClusterResolver(
       ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
   with self.assertRaisesRegex(ValueError, "Unrecognized task_type: foobar"):
     parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
コード例 #20
0
 def testArbitraryJobName(self):
   """A cluster containing an unexpected job name is rejected."""
   cluster_def = multi_worker_test_base.create_cluster_spec(
       num_workers=1, num_ps=1, has_chief=True)
   cluster_def["some_arbitrary_name"] = [
       "localhost:%d" % multi_worker_test_base.pick_unused_port()
   ]
   cluster_resolver = SimpleClusterResolver(
       ClusterSpec(cluster_def), rpc_layer="grpc")
   with self.assertRaisesRegex(ValueError, "Disallowed task type found in"):
     parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
コード例 #21
0
ファイル: saved_model_test.py プロジェクト: mrax714/nearme
    def test_load_with_partitioner_raises_error(self):
        """Loading under a strategy that has a partitioner raises ValueError."""
        model = self.Model()
        model_dir = self.get_temp_dir()
        tf.saved_model.save(model, model_dir)

        partitioner = tf1.fixed_size_partitioner(2)
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver, partitioner)
        expected_error = self.assertRaisesRegex(ValueError,
                                                "`variable_partitioner`")
        with expected_error, strategy.scope():
            tf.saved_model.load(model_dir)
コード例 #22
0
  def test_load_with_partitioner_raises_error(self):
    """Loading under a partitioning strategy raises InvalidArgumentError."""
    model = self.Model()
    model_dir = self.get_temp_dir()
    tf.saved_model.save(model, model_dir)

    partitioner = tf1.fixed_size_partitioner(2)
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, partitioner)
    with self.assertRaises(errors_impl.InvalidArgumentError), strategy.scope():
      tf.saved_model.load(model_dir)
コード例 #23
0
  def testDistributeDatasetFromFunctionNotUsedWithClusterCoordinator(self):
    """`distribute_datasets_from_function` inside a tf.function is rejected.

    The failure is asserted through `self._assertRaisesUsageError()`.
    """
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    def dataset_fn(unused_input_context):
      return dataset_ops.DatasetV2.range(3)

    @def_function.function
    def distribute_in_function():
      return strategy.distribute_datasets_from_function(dataset_fn)

    with self._assertRaisesUsageError():
      distribute_in_function()
コード例 #24
0
def make_client(num_workers, num_ps):
    """Creates a `client_lib.Client` over a fresh in-process cluster.

    Args:
      num_workers: number of worker tasks in the in-process cluster.
      num_ps: number of parameter-server tasks in the in-process cluster.

    Returns:
      A `client_lib.Client` wrapping a ParameterServerStrategyV2. The
      strategy's cluster also lists a "chief" task on an unused local port.
    """
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    chief_port = multi_worker_test_base.pick_unused_port()
    cluster_def["chief"] = ["localhost:%d" % chief_port]
    resolver = SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)
    return client_lib.Client(strategy)
コード例 #25
0
  def test_sharded_variable(self):
    """A model trained with sharded variables can be saved and run via v1."""
    partitioner = tf1.fixed_size_partitioner(2)
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, partitioner)
    model_dir = self.get_temp_dir()
    with strategy.scope():
      model = self.Model()
      self.assertIsInstance(model.v1, sharded_variable.ShardedVariable)
    model.train()
    tf.saved_model.save(model, model_dir)

    outputs = self.load_and_run_v1(model_dir, {"x": 1})
    self.assertAllEqual(outputs, [6, 6, 6, 6])
コード例 #26
0
 def testLessThanOneWorker(self):
   """A cluster with no workers is rejected with a ValueError."""
   cluster_def = multi_worker_test_base._create_cluster(
       num_workers=0, num_ps=1)
   cluster_def["chief"] = [
       "localhost:%d" % multi_worker_test_base.pick_unused_port()
   ]
   cluster_resolver = SimpleClusterResolver(
       ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
   with self.assertRaisesRegex(ValueError,
                               "There must be at least one worker."):
     parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
コード例 #27
0
def make_coordinator(num_workers, num_ps):
  """Returns a ClusterCoordinator over a fresh in-process test cluster.

  Args:
    num_workers: number of worker tasks in the in-process cluster.
    num_ps: number of parameter-server tasks in the in-process cluster.
  """
  # TODO(rchao): Test the internal rpc_layer version.
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
  # Register a chief entry on an unused local port.
  chief_port = multi_worker_test_base.pick_unused_port()
  cluster_def['chief'] = ['localhost:%d' % chief_port]
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer='grpc'))
  return coordinator_lib.ClusterCoordinator(strategy)
コード例 #28
0
  def testDistributeDatasetUsedDirectly(self):
    """Iterating a PS-distributed dataset directly raises ValueError.

    Covers both `experimental_distribute_dataset` and
    `distribute_datasets_from_function`.
    """
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    dataset = dataset_ops.DatasetV2.range(3)
    distributed_dataset = strategy.experimental_distribute_dataset(dataset)
    with self.assertRaises(ValueError):
      iter(distributed_dataset)

    # `dataset_fn` is called with an input context, so it must accept one
    # positional argument; a zero-argument lambda would raise TypeError.
    distributed_dataset = strategy.distribute_datasets_from_function(
        lambda _: dataset)
    with self.assertRaises(ValueError):
      iter(distributed_dataset)
コード例 #29
0
  def testRunUsedWithTestOnlyMode(self):
    """`strategy.run` works without a coordinator once explicitly allowed."""
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    # Opt in to calling `strategy.run` without a ClusterCoordinator.
    strategy.extended._allow_run_without_coordinator = True
    dataset = dataset_ops.DatasetV2.range(15)
    with strategy.scope():
      offset = variables.Variable(1, dtype=dtypes.int64)

    def step_fn(iterator):
      return next(iterator) + offset

    dataset_iterator = iter(dataset)
    strategy.run(step_fn, args=(dataset_iterator,))
コード例 #30
0
  def testRunNotUsedWithClusterCoordinator(self):
    """Calling `strategy.run` without a coordinator is flagged.

    The check is `self._assertRaisesUsageWarningWithSchedule()`.
    """
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    dataset = dataset_ops.DatasetV2.range(8)
    with strategy.scope():
      offset = variables.Variable(1, dtype=dtypes.int64)

    def step_fn(iterator):
      return next(iterator) + offset

    with self._assertRaisesUsageWarningWithSchedule():
      dataset_iterator = iter(dataset)
      strategy.run(step_fn, args=(dataset_iterator,))