Example 1
    def testDistStratRestore(self, strat1, strat2, jit_replica_fn):
        """Tests checkpointing and restoring (to possibly different #replicas)."""
        if strat2 is None:
            strat2 = strat1
        strat1_name = type(strat1).__name__
        strat2_name = type(strat2).__name__
        if "Default" in strat1_name or "Default" in strat2_name:
            self.skipTest(
                "We don't guarantee consistency between strategy and no-strategy."
            )
        if ("TPU" in strat1_name
                or "TPU" in strat2_name) and not jit_replica_fn:
            self.skipTest(
                "TPUStrategy requires the replica function (the function passed to "
                "strategy.run) to be decorated with tf.function")
        coord1 = None
        if "ParameterServer" in strat1_name:
            coord1 = coordinator_lib.ClusterCoordinator(strat1)
        coord2 = None
        if "ParameterServer" in strat2_name:
            coord2 = coordinator_lib.ClusterCoordinator(strat2)
        fname = os.path.join(self.get_temp_dir(), "checkpoint")

        def uniform(strat, coord, g):
            def f():
                return g.uniform_full_int([3], dtype=dtypes.int32)

            replica_fn = def_function.function(f) if jit_replica_fn else f
            result = run_on_strategy(replica_fn, strat, coord)
            return strat.experimental_local_results(result)

        with strat1.scope():
            g1 = rng.Generator.from_seed(1)
        with strat2.scope():
            g2 = rng.Generator.from_seed(10)
        cp1 = tracking_util.Checkpoint(g=g1)
        cp2 = tracking_util.Checkpoint(g=g2)

        def write_restore_compare():
            cp1.write(fname)
            r1 = uniform(strat1, coord1, g1)
            cp2.restore(fname)
            r2 = uniform(strat2, coord2, g2)
            # Tests that overlapping replicas are properly restored.
            n1 = get_num_local_replicas(strat1)
            n2 = get_num_local_replicas(strat2)
            n = min(n1, n2)
            self.assertAllEqual(r1[:n], r2[:n])

        # Run multiple times so that cp1.write is called in various RNG states
        for _ in range(2):
            write_restore_compare()
Example 2
 def testDistStrat(self, strat, jit_replica_fn, seeded):
   """Tests RNG with distribution strategies."""
   strat_name = type(strat).__name__
   if "TPU" in strat_name and not jit_replica_fn:
     self.skipTest(
         "TPUStrategy requires the replica function (the function passed to "
         "strategy.run) to be decorated with tf.function")
   coord = None
   if "ParameterServer" in strat_name:
     coord = coordinator_lib.ClusterCoordinator(strat)
   creators = {
       True: functools.partial(rng.Generator.from_seed, 1234),
       False: rng.Generator.from_non_deterministic_state,
   }
   shape = [3, 4]
   dtype = dtypes.int32
   creator = creators[seeded]
   with strat.scope():
     gen = creator()
     def f():
       t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
       t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
       t = array_ops.stack([t1, t2])
       return t
     replica_fn = def_function.function(f) if jit_replica_fn else f
     results = run_on_strategy(replica_fn, strat, coord)
     values = strat.experimental_local_results(results)
     n = get_num_local_replicas(strat, values)
     self.assertAllEqual(n, len(values))
     self.assertAllDifferent(values)
Example 3
  def setUp(self):
    super(FaultToleranceTest, self).setUp()

    # Set the environment variable to prevent hanging upon job failure and
    # restart. Note that it defaults to 'use_caller' at Google, but defaults
    # to False in OSS.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=FaultToleranceTest.NUM_WORKERS,
        num_ps=FaultToleranceTest.NUM_PS,
        rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

    # The strategy's constructor would connect to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
Example 4
 def testDistVarAsTFFunArg(self, strat, jit_replica_fn):
   """Tests that RNG with dist variables can be used as tf.function's arg."""
   strat_name = type(strat).__name__
   if "CentralStorage" in strat_name:
     self.skipTest(
         "CentralStorageStrategy wraps variable updates in merge_call which "
         "can't be called inside a tf.function that doesn't cover the entire "
         "replica function (the function passed to strategy.run).")
   if "TPU" in strat_name and not jit_replica_fn:
     self.skipTest(
         "TPUStrategy requires the replica function (the function passed to "
         "strategy.run) to be decorated with tf.function")
   coord = None
   if "ParameterServer" in strat_name:
     coord = coordinator_lib.ClusterCoordinator(strat)
   shape = [3, 4]
   dtype = dtypes.int32
   with strat.scope():
     gen = rng.Generator.from_seed(1234)
     @def_function.function
     def f(gen):  # the main focus
       t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
       t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
       t = array_ops.stack([t1, t2])
       return t
     def g():
       return f(gen)
     replica_fn = def_function.function(g) if jit_replica_fn else g
     for _ in range(2):
       results = run_on_strategy(replica_fn, strat, coord)
       values = strat.experimental_local_results(results)
       n = get_num_local_replicas(strat, values)
       self.assertAllEqual(n, len(values))
       self.assertAllDifferent(values)
Example 5
  def testInModelAndCapture(self, source):

    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

    model = self.Model(source, file_path)
    func_captures = model.use_table.get_concrete_function(
    ).graph.external_captures
    self.assertLen(func_captures, 2)
    self.assertTrue(
        any(model.table.resource_handle is t for t in func_captures))
    deferred_captures = model.use_table.get_concrete_function(
    ).graph.deferred_external_captures
    self.assertEmpty(deferred_captures)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    coordinator = coordinator_lib.ClusterCoordinator(strategy)
    with strategy.scope():
      distributed_model = self.Model("value", file_path)
    func_captures = distributed_model.use_table.get_concrete_function(
    ).graph.external_captures
    # One less external_capture, since the table handle becomes a closure in the
    # deferred_external_capture
    self.assertLen(func_captures, 1)
    self.assertFalse(
        any(model.table.resource_handle is t for t in func_captures))
    deferred_captures = distributed_model.use_table.get_concrete_function(
    ).graph.deferred_external_captures
    self.assertNotEmpty(deferred_captures)

    self.verifyWorkerLocalInstance(coordinator, distributed_model)
Example 6
        def fn(functions_scheduled_event, test_finished_event):
            # TODO(b/170664373): This is needed for TF2 parameter server training in
            # OSS. Remove this when resolved.
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=0, dtype=dtypes.int32)

            @def_function.function
            def worker_fn():
                # An ever-running function.
                for _ in math_ops.range(100000):
                    v.assign_add(1)

            # Keep the two workers occupied.
            ps_coordinator.schedule(worker_fn)
            ps_coordinator.schedule(worker_fn)
            # Now the main process can terminate.
            functions_scheduled_event.set()

            # Verify that join and schedule indeed raise UnavailableError.
            try:
                if test_join:
                    ps_coordinator.join()
                if test_schedule:
                    while ps_coordinator.cluster._closure_queue._error is None:
                        time.sleep(1)
                    ps_coordinator.schedule(worker_fn)
            except errors.UnavailableError:
                # The following verifies that, after the PS fails, continuing
                # to execute functions on workers should fail and indicate a
                # PS failure.
                for worker_id in range(3):
                    with ops.device(
                            "/job:worker/replica:0/task:{}".format(worker_id)):
                        try:
                            # Executing a function after PS fails should result in a PS
                            # failure.
                            worker_fn()
                        except Exception as e:  # pylint: disable=broad-except
                            if coordinator_lib._is_ps_failure(e):
                                if worker_id < 2:
                                    continue
                                logging.info(
                                    "_test_translate_ps_failure_error ends properly."
                                )
                                # Now we can safely exit the test.
                                test_finished_event.set()
                                return
                        raise RuntimeError(
                            "Executing a function after PS fails, should "
                            "result in a PS failure.")

            raise RuntimeError("UnavailableError supposed to be raised.")
Example 7
    def test_dataset_creator_input_options_with_cluster_coordinator(self):
        dataset_fn = lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 1])
        input_options = distribute_lib.InputOptions(
            experimental_fetch_to_device=True,
            experimental_per_replica_buffer_size=2)
        x = dataset_creator.DatasetCreator(dataset_fn,
                                           input_options=input_options)
        strategy = self._get_parameter_server_strategy()
        with strategy.scope():
            model = sequential.Sequential([core_layers.Dense(10)])
            model._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
                strategy)
            data_handler = data_adapter.get_data_handler(x,
                                                         steps_per_epoch=2,
                                                         model=model)

        iter_rv = iter(data_handler._dataset)._values[0]
        iter_rv._rebuild_on(model._cluster_coordinator._cluster.workers[0])
        distributed_iterator = iter_rv._get_values()

        # Ensuring the resulting `DistributedIterator` has the right options.
        self.assertTrue(
            distributed_iterator._options.experimental_fetch_to_device)
        self.assertEqual(
            distributed_iterator._options.experimental_per_replica_buffer_size,
            2)
Example 8
def make_coordinator(num_workers, num_ps):
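    # Create an in-process test cluster plus a separate chief task, wrap it in
    # a ParameterServerStrategyV2, and return a ClusterCoordinator for it.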
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def),
                                             rpc_layer="grpc")
    return coordinator_lib.ClusterCoordinator(
        parameter_server_strategy_v2.ParameterServerStrategyV2(
            cluster_resolver))
Example 9
def make_coordinator(num_workers, num_ps):
  # TODO(rchao): Test the internal rpc_layer version.
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
  cluster_def['chief'] = [
      'localhost:%d' % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer='grpc')
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  return coordinator_lib.ClusterCoordinator(strategy)
Example 10
    def setUpClass(cls):
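        # Start one multi-process cluster (3 workers, 2 parameter servers) and
        # one coordinator that are shared by every test case in this class.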
        super(EvaluationTest, cls).setUpClass()
        cls._cluster = multi_worker_test_base.create_multi_process_cluster(
            num_workers=3, num_ps=2, rpc_layer="grpc")
        cls._cluster_def = cls._cluster.cluster_resolver.cluster_spec(
        ).as_dict()
        cluster_resolver = SimpleClusterResolver(
            server_lib.ClusterSpec(cls._cluster_def), rpc_layer="grpc")

        cls.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            cluster_resolver)
        cls.cluster_coord = coordinator_lib.ClusterCoordinator(cls.strategy)
Example 11
  def testCopyDistributedTable(self, source):
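    """Tests that copying a distributed table creates no new resources."""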
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    new_table = copy.copy(lookuptable)
    # No new coordinator instance or distributed tables are created.
    self.assertDictEqual(lookuptable.__dict__, new_table.__dict__)
Example 12
    def testCreateLookupInDatasetFnUnderScope(self, source):
        # TODO(wxinyi): Warn the user about the inefficiency of this workflow
        # (i.e. creating a `StaticHashTable` inside a `@tf.function`-wrapped
        # `dataset_fn` to be distributed with `distribute_datasets_from_function`
        # and `create_per_worker_dataset`).
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver)

        coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

        with strategy.scope():

            def dataset_fn(input_context):
                some_out_of_range_tensor = constant_op.constant(
                    10, dtype=dtypes.int64)
                lookuptable = self.createStaticHashTable(init_source=source,
                                                         vals=[0, 1, 2],
                                                         default_value=-2)

                self.assertNotIsInstance(lookuptable,
                                         ps_values.DistributedTable)

                generation_tensor = lookuptable.lookup(
                    some_out_of_range_tensor)
                dataset = self.makeDatasetFromTensorWithoutUsingResource(
                    input_context, generation_tensor)
                return dataset

            @def_function.function
            def per_worker_dataset_fn():
                return strategy.distribute_datasets_from_function(dataset_fn)

            per_worker_dataset = coordinator.create_per_worker_dataset(
                per_worker_dataset_fn)
            per_worker_iterator = iter(per_worker_dataset)

            @def_function.function
            def worker_fn(iterator):
                return math_ops.reduce_sum(next(iterator))

            result = []
            for _ in range(10):
                result.append(
                    coordinator.schedule(worker_fn,
                                         args=(per_worker_iterator, )))

            for r in result:
                returned_input = r.fetch()
                self.assertAllClose(-48, returned_input)
Example 13
  def testCreateDistributedTableInScope(self, source):
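    """Tests that a table created in strategy scope is a DistributedTable."""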
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    self.assertIsInstance(lookuptable, ps_values.DistributedTable)
    self.assertEqual(self.evaluate(lookuptable.size()), 3)

    # Lookup on the coordinator.
    output = lookuptable.lookup(
        constant_op.constant([0, 1, -1], dtype=dtypes.int64))
    self.assertAllEqual([0, 1, -2], output)
    self.assertEqual(lookuptable.size(), 3)
Example 14
    def testAccessingResourceHandleInDatasetFnWithMapFnDefinedOutside(
            self, source):
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver)

        coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

        with strategy.scope():
            lookuptable = self.createStaticHashTable(init_source=source,
                                                     vals=[0, 1, 2],
                                                     default_value=-2)

        def map_fn(vals):
            return lookuptable.lookup(vals)

        def dataset_fn(input_context):
            generation_tensor = constant_op.constant([0, 1, 3],
                                                     dtype=dtypes.int64)
            dataset = self.makeDatasetFromTensorWithoutUsingResource(
                input_context, generation_tensor)
            dataset = dataset.map(map_fn)
            return dataset

        @def_function.function
        def per_worker_dataset_fn():
            return strategy.distribute_datasets_from_function(dataset_fn)

        per_worker_dataset = coordinator.create_per_worker_dataset(
            per_worker_dataset_fn)
        per_worker_iterator = iter(per_worker_dataset)

        @def_function.function
        def worker_fn(iterator):
            return math_ops.reduce_sum(next(iterator))

        result = []
        for _ in range(10):
            # batch_size == 24 and each input is [0, 1, -2]
            result.append(
                coordinator.schedule(worker_fn, args=(per_worker_iterator, )))

        for r in result:
            returned_input = r.fetch()
            self.assertAllClose(-24, returned_input)
Example 15
    def testInModelAndCapture(self, source):

        file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

        model = self.Model(source, file_path)
        func_captures = model.use_table.get_concrete_function(
        ).graph.external_captures
        self.assertLen(func_captures, 2)
        self.assertTrue(
            any(model.table.resource_handle is t for t in func_captures))
        deferred_captures = model.use_table.get_concrete_function(
        ).graph.deferred_external_captures
        self.assertEmpty(deferred_captures)

        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver)
        coordinator = coordinator_lib.ClusterCoordinator(strategy)
        with strategy.scope():
            distributed_model = self.Model("value", file_path)
        func_captures = distributed_model.use_table.get_concrete_function(
        ).graph.external_captures
        # One less external_capture, since the table handle becomes a closure in the
        # deferred_external_capture
        self.assertLen(func_captures, 1)
        self.assertFalse(
            any(model.table.resource_handle is t for t in func_captures))
        deferred_captures = distributed_model.use_table.get_concrete_function(
        ).graph.deferred_external_captures
        self.assertNotEmpty(deferred_captures)

        # Assert that a worker-local resource handle is captured on each worker.
        for worker in coordinator._cluster.workers:
            with coordinator_context.with_dispatch_context(worker):
                for capture in [
                        t for t in distributed_model.use_table.
                        get_concrete_function().captured_inputs
                        if t.dtype == dtypes.resource
                ]:
                    if capture.dtype == dtypes.resource:
                        self.assertEqual(
                            capture.device,
                            device_util.canonicalize(
                                "/CPU:0", default=worker.device_name))
Example 16
    def testAccessingResourceHandleInDatasetFnWithoutMap(self, source):

        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver)

        coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

        with strategy.scope():
            lookuptable = self.createStaticHashTable(init_source=source,
                                                     vals=[0, 1, 2],
                                                     default_value=-2)

        def dataset_fn(input_context):
            some_out_of_range_tensor = constant_op.constant(10,
                                                            dtype=dtypes.int64)

            self.assertIsInstance(lookuptable, ps_values.DistributedTable)

            generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
            dataset = self.makeDatasetFromTensorWithoutUsingResource(
                input_context, generation_tensor)
            return dataset

        @def_function.function
        def per_worker_dataset_fn():
            return strategy.distribute_datasets_from_function(dataset_fn)

        per_worker_dataset = coordinator.create_per_worker_dataset(
            per_worker_dataset_fn)
        per_worker_iterator = iter(per_worker_dataset)

        @def_function.function
        def worker_fn(iterator):
            return math_ops.reduce_sum(next(iterator))

        result = []
        for _ in range(10):
            result.append(
                coordinator.schedule(worker_fn, args=(per_worker_iterator, )))

        for r in result:
            returned_input = r.fetch()
            self.assertAllClose(-48, returned_input)
Example 17
    def testBasicShardedVariableWithAggregation(self):
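        """Tests a 2-shard variable with SUM aggregation under PS strategy."""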
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
        coordinator = coordinator_lib.ClusterCoordinator(strategy)
        with strategy.scope():
            v = variables.Variable(
                initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
                dtype=dtypes.float32,
                aggregation=variable_scope.VariableAggregation.SUM)

        self.assertIsInstance(v, sharded_variable.ShardedVariable)
        self.assertLen(v.variables, 2)
        if strategy.num_replicas_in_sync > 1:
            self.assertIsInstance(v.variables[0],
                                  ps_values.AggregatingVariable)
        else:
            self.assertIsInstance(v.variables[0], variables.Variable)

        @def_function.function
        def worker_fn():
            def replica_fn():
                replica_id = distribution_strategy_context.get_replica_context(
                ).replica_id_in_sync_group
                val = array_ops.reshape(
                    math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
                v.assign(
                    array_ops.concat([
                        val,
                        constant_op.constant([1., 2., 3., 4., 5., 6., 7.])
                    ], 0))

            strategy.run(replica_fn)

        coordinator.schedule(worker_fn)
        coordinator.join()
        expected_result = np.arange(8.) * strategy.num_replicas_in_sync
        for i in range(strategy.num_replicas_in_sync):
            expected_result[0] = expected_result[0] + i + 10
        expected_result = np.array_split(expected_result, 2)
        self.assertAllEqual(expected_result[0], v.variables[0])
        self.assertAllEqual(expected_result[1], v.variables[1])
Example 18
  def setUp(self, num_workers, num_ps):
    super(BaseFaultToleranceTest, self).setUp()

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

    # The strategy's constructor would connect to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.num_workers = num_workers
    self.num_ps = num_ps
Example 19
        def fn(first_fetch_occurred_event, worker_terminated_event):
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=0, dtype=dtypes.int32)

            @def_function.function
            def worker_fn():
                return v + 1, v - 1

            remote_value = ps_coordinator.schedule(worker_fn)
            logging.info("result (1st fetch): %r", remote_value.fetch())
            first_fetch_occurred_event.set()
            worker_terminated_event.wait()
            logging.info("result (2nd fetch): %r", remote_value.fetch())
Example 20
        def fn(functions_scheduled_event):
            # TODO(b/170664373): This is needed for TF2 parameter server training in
            # OSS. Remove this when resolved.
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_client = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=1)

                @def_function.function
                def worker_fn(input_tensor):
                    def replica_fn(input_tensor):
                        return input_tensor + v

                    run_result = strategy.run(replica_fn,
                                              args=(input_tensor, ))
                    check_ops.assert_equal_v2(run_result, 4)
                    return run_result

            for i in range(5000):
                if i % 500 == 0:
                    logging.info("Scheduling function-{}...".format(i))
                result = ps_client.schedule(worker_fn,
                                            args=(constant_op.constant(3), ))
            functions_scheduled_event.set()
            logging.info("Joining...")
            ps_client.join()
            logging.info("Finished joining.")
            if result.fetch() != 4:
                raise AssertionError(
                    "Unexpected RemoteValue result: {}".format(result.fetch()))
            logging.info("testStrategyRun succeeded")
Example 21
  def testDistributeTableSaveAndLoadUnderStrategy(self, load, source):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    coordinator = coordinator_lib.ClusterCoordinator(strategy)
    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
    with strategy.scope():
      model = self.Model(source, file_path)
    model_dir = self.get_temp_dir()
    tf_save.save(model, model_dir)

    if load == "tf_load":
      load_fn = tf_load.load
    else:
      load_fn = keras_save.load_model

    with strategy.scope():
      loaded = load_fn(model_dir)

    loaded_func_captures = (
        loaded.use_table.get_concrete_function().graph.external_captures)
    loaded_func_deferred_captures = (
        loaded.use_table.get_concrete_function().graph
        .deferred_external_captures)
    # Compared with loading without strategy, there is one less
    # external_capture, since the captured table handle has been swapped to a
    # closure in the deferred_external_capture
    self.assertLen(loaded_func_captures, 1)
    self.assertNotEmpty(loaded_func_deferred_captures)

    self.assertIsInstance(loaded.table, ps_values.DistributedTable)

    self.assertLen([
        t for t in loaded.use_table.get_concrete_function().captured_inputs
        if t.dtype == dtypes.resource
    ], 1)

    self.verifyWorkerLocalInstance(coordinator, loaded)
Example 22
 def testClusterCoordinatorSingleInstance(self, strategy):
     model = self._model_fit(strategy)
     strategy = model.distribute_strategy
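     # model.fit already created a coordinator for this strategy, so building
     # another ClusterCoordinator returns that same instance.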
     self.assertIs(strategy._cluster_coordinator,
                   coordinator_lib.ClusterCoordinator(strategy))
Example 23
def make_coordinator(num_workers, num_ps):
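    # `make_cluster` is assumed to be a test helper (defined elsewhere) that
    # returns a cluster resolver for the requested workers and PS tasks.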
    return coordinator_lib.ClusterCoordinator(
        parameter_server_strategy_v2.ParameterServerStrategyV2(
            make_cluster(num_workers, num_ps)))
Example 24
def make_coordinator(num_workers, num_ps, variable_partitioner=None):
  return coordinator_lib.ClusterCoordinator(
      parameter_server_strategy_v2.ParameterServerStrategyV2(
          make_cluster(num_workers, num_ps),
          variable_partitioner=variable_partitioner))
Example 25
 def testClusterCoordinatorOnlyInitOnce(self):
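   """Tests that ClusterCoordinator is created only once per strategy."""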
   cluster = self.coordinator._cluster
   same_coordinator = coordinator_lib.ClusterCoordinator(self.strategy)
   self.assertIs(self.coordinator, same_coordinator)
   self.assertIs(cluster, same_coordinator._cluster)
Example 26
 def setUpClass(cls):
     super(KPLCreatedInDatasetsFromFunctionTest, cls).setUpClass()
     cls.coordinator = coordinator_lib.ClusterCoordinator(
         parameter_server_strategy_v2.ParameterServerStrategyV2(
             multi_worker_testing_utils.make_parameter_server_cluster(3, 2)))