def testDistStratRestore(self, strat1, strat2, jit_replica_fn):
  """Tests checkpointing and restoring (to possibly different #replicas)."""
  if strat2 is None:
    strat2 = strat1
  strat1_name = type(strat1).__name__
  strat2_name = type(strat2).__name__
  if "Default" in strat1_name or "Default" in strat2_name:
    self.skipTest(
        "We don't guarantee consistency between strategy and no-strategy.")
  if ("TPU" in strat1_name or "TPU" in strat2_name) and not jit_replica_fn:
    self.skipTest(
        "TPUStrategy requires the replica function (the function passed to "
        "strategy.run) to be decorated with tf.function")
  coord1 = None
  if "ParameterServer" in strat1_name:
    coord1 = coordinator_lib.ClusterCoordinator(strat1)
  coord2 = None
  if "ParameterServer" in strat2_name:
    coord2 = coordinator_lib.ClusterCoordinator(strat2)
  fname = os.path.join(self.get_temp_dir(), "checkpoint")

  def uniform(strat, coord, g):

    def f():
      return g.uniform_full_int([3], dtype=dtypes.int32)

    replica_fn = def_function.function(f) if jit_replica_fn else f
    result = run_on_strategy(replica_fn, strat, coord)
    return strat.experimental_local_results(result)

  with strat1.scope():
    g1 = rng.Generator.from_seed(1)
  with strat2.scope():
    g2 = rng.Generator.from_seed(10)
  cp1 = tracking_util.Checkpoint(g=g1)
  cp2 = tracking_util.Checkpoint(g=g2)

  def write_restore_compare():
    cp1.write(fname)
    r1 = uniform(strat1, coord1, g1)
    cp2.restore(fname)
    r2 = uniform(strat2, coord2, g2)
    # Tests that overlapping replicas are properly restored.
    n1 = get_num_local_replicas(strat1)
    n2 = get_num_local_replicas(strat2)
    n = min(n1, n2)
    self.assertAllEqual(r1[:n], r2[:n])

  # Run multiple times so that cp1.write is called in various RNG states.
  for _ in range(2):
    write_restore_compare()

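# A minimal public-API sketch (assumed TF 2.x) of the write/restore pattern
# exercised above, without a distribution strategy: restoring a checkpoint
# rewinds the generator's state, so the next draw repeats the one taken
# after the write. The checkpoint prefix below is illustrative.
import tensorflow as tf

gen = tf.random.Generator.from_seed(1)
ckpt = tf.train.Checkpoint(g=gen)
path = ckpt.write("/tmp/rng_ckpt")

before = gen.uniform_full_int([3], dtype=tf.int32)  # advances the RNG state
ckpt.restore(path)  # rewinds the state to what was written
after = gen.uniform_full_int([3], dtype=tf.int32)
assert (before.numpy() == after.numpy()).all()
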
def testDistStrat(self, strat, jit_replica_fn, seeded):
  """Tests RNG with distribution strategies."""
  strat_name = type(strat).__name__
  if "TPU" in strat_name and not jit_replica_fn:
    self.skipTest(
        "TPUStrategy requires the replica function (the function passed to "
        "strategy.run) to be decorated with tf.function")
  coord = None
  if "ParameterServer" in strat_name:
    coord = coordinator_lib.ClusterCoordinator(strat)
  creators = {
      True: functools.partial(rng.Generator.from_seed, 1234),
      False: rng.Generator.from_non_deterministic_state,
  }
  shape = [3, 4]
  dtype = dtypes.int32
  creator = creators[seeded]
  with strat.scope():
    gen = creator()

    def f():
      t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
      t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
      t = array_ops.stack([t1, t2])
      return t

    replica_fn = def_function.function(f) if jit_replica_fn else f
    results = run_on_strategy(replica_fn, strat, coord)
    values = strat.experimental_local_results(results)
    n = get_num_local_replicas(strat, values)
    self.assertAllEqual(n, len(values))
    self.assertAllDifferent(values)

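# The two `creators` above correspond to these public constructors (a hedged
# sketch, assumed TF 2.x): a seeded generator reproduces its sequence, while
# a non-deterministically seeded one varies from run to run.
import tensorflow as tf

a = tf.random.Generator.from_seed(1234).uniform_full_int(
    [3, 4], dtype=tf.int32)
b = tf.random.Generator.from_seed(1234).uniform_full_int(
    [3, 4], dtype=tf.int32)
assert (a.numpy() == b.numpy()).all()  # same seed, same sequence

g = tf.random.Generator.from_non_deterministic_state()
print(g.uniform_full_int([3, 4], dtype=tf.int32))  # differs across runs
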
def setUp(self):
  super(FaultToleranceTest, self).setUp()

  # Set the environment variable to prevent hanging upon job failure and
  # restart. Note that it defaults to 'use_caller' at Google, but defaults
  # to False in OSS.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  self._cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=FaultToleranceTest.NUM_WORKERS,
      num_ps=FaultToleranceTest.NUM_PS,
      rpc_layer="grpc")
  self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
  self._cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

  # The strategy's constructor connects to the cluster.
  self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)

  self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)
  self.thread_coord = thread_coordinator.Coordinator(
      clean_stop_exception_types=[])

def testDistVarAsTFFunArg(self, strat, jit_replica_fn):
  """Tests that RNG with dist variables can be used as tf.function's arg."""
  strat_name = type(strat).__name__
  if "CentralStorage" in strat_name:
    self.skipTest(
        "CentralStorageStrategy wraps variable updates in merge_call which "
        "can't be called inside a tf.function that doesn't cover the entire "
        "replica function (the function passed to strategy.run).")
  if "TPU" in strat_name and not jit_replica_fn:
    self.skipTest(
        "TPUStrategy requires the replica function (the function passed to "
        "strategy.run) to be decorated with tf.function")
  coord = None
  if "ParameterServer" in strat_name:
    coord = coordinator_lib.ClusterCoordinator(strat)
  shape = [3, 4]
  dtype = dtypes.int32
  with strat.scope():
    gen = rng.Generator.from_seed(1234)

    @def_function.function
    def f(gen):  # the main focus
      t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
      t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
      t = array_ops.stack([t1, t2])
      return t

    def g():
      return f(gen)

    replica_fn = def_function.function(g) if jit_replica_fn else g
    for _ in range(2):
      results = run_on_strategy(replica_fn, strat, coord)
      values = strat.experimental_local_results(results)
      n = get_num_local_replicas(strat, values)
      self.assertAllEqual(n, len(values))
      self.assertAllDifferent(values)

def testInModelAndCapture(self, source):
  file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

  model = self.Model(source, file_path)
  func_captures = model.use_table.get_concrete_function(
  ).graph.external_captures
  self.assertLen(func_captures, 2)
  self.assertTrue(
      any(model.table.resource_handle is t for t in func_captures))
  deferred_captures = model.use_table.get_concrete_function(
  ).graph.deferred_external_captures
  self.assertEmpty(deferred_captures)

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator = coordinator_lib.ClusterCoordinator(strategy)
  with strategy.scope():
    distributed_model = self.Model("value", file_path)
  func_captures = distributed_model.use_table.get_concrete_function(
  ).graph.external_captures
  # One fewer external capture, since the table handle becomes a closure in
  # deferred_external_captures.
  self.assertLen(func_captures, 1)
  self.assertFalse(
      any(model.table.resource_handle is t for t in func_captures))
  deferred_captures = distributed_model.use_table.get_concrete_function(
  ).graph.deferred_external_captures
  self.assertNotEmpty(deferred_captures)

  self.verifyWorkerLocalInstance(coordinator, distributed_model)

def fn(functions_scheduled_event, test_finished_event):
  # TODO(b/170664373): This is needed for TF2 parameter server training in
  # OSS. Remove this when resolved.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  cluster_resolver = TFConfigClusterResolver()
  if cluster_resolver.task_type != "chief":
    utils.start_server(cluster_resolver, "grpc")
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

  with strategy.scope():
    v = variables.Variable(initial_value=0, dtype=dtypes.int32)

  @def_function.function
  def worker_fn():
    # An ever-running function.
    for _ in math_ops.range(100000):
      v.assign_add(1)

  # Keep the two workers occupied.
  ps_coordinator.schedule(worker_fn)
  ps_coordinator.schedule(worker_fn)
  # Now the main process can terminate.
  functions_scheduled_event.set()

  # Verify that join and schedule indeed raise UnavailableError.
  try:
    if test_join:
      ps_coordinator.join()
    if test_schedule:
      while ps_coordinator.cluster._closure_queue._error is None:
        time.sleep(1)
      ps_coordinator.schedule(worker_fn)
  except errors.UnavailableError:
    # The following verifies that, after the PS fails, executing functions
    # on workers should also fail and indicate a PS failure.
    for worker_id in range(3):
      with ops.device("/job:worker/replica:0/task:{}".format(worker_id)):
        try:
          # Executing a function after the PS fails should result in a PS
          # failure.
          worker_fn()
        except Exception as e:  # pylint: disable=broad-except
          if coordinator_lib._is_ps_failure(e):
            if worker_id < 2:
              continue
            logging.info("_test_translate_ps_failure_error ends properly.")
            # Now we can safely exit the test.
            test_finished_event.set()
            return
          raise RuntimeError("Executing a function after the PS fails "
                             "should result in a PS failure.")

  raise RuntimeError("UnavailableError was supposed to be raised.")

def test_dataset_creator_input_options_with_cluster_coordinator(self):
  dataset_fn = lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 1])
  input_options = distribute_lib.InputOptions(
      experimental_fetch_to_device=True,
      experimental_per_replica_buffer_size=2)
  x = dataset_creator.DatasetCreator(dataset_fn, input_options=input_options)
  strategy = self._get_parameter_server_strategy()
  with strategy.scope():
    model = sequential.Sequential([core_layers.Dense(10)])
    model._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
        strategy)
    data_handler = data_adapter.get_data_handler(
        x, steps_per_epoch=2, model=model)

  iter_rv = iter(data_handler._dataset)._values[0]
  iter_rv._rebuild_on(model._cluster_coordinator._cluster.workers[0])
  distributed_iterator = iter_rv._get_values()

  # Ensuring the resulting `DistributedIterator` has the right options.
  self.assertTrue(distributed_iterator._options.experimental_fetch_to_device)
  self.assertEqual(
      distributed_iterator._options.experimental_per_replica_buffer_size, 2)

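# A hedged public-API sketch of what the test above checks: a
# `DatasetCreator` carries `InputOptions` through to the distributed input
# pipeline. The `repeat`/`batch` calls are illustrative, and the creator is
# only consumed once passed to `Model.fit` under a PS strategy scope.
import tensorflow as tf

def dataset_fn(input_context):
  # `input_context` could be used for per-worker sharding; unused here.
  return tf.data.Dataset.from_tensor_slices([1, 1]).repeat().batch(2)

options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)
dc = tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=options)
# `dc` is then passed as `x` to `model.fit(dc, steps_per_epoch=...)`.
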
def make_coordinator(num_workers, num_ps):
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc")
  return coordinator_lib.ClusterCoordinator(
      parameter_server_strategy_v2.ParameterServerStrategyV2(
          cluster_resolver))

def make_coordinator(num_workers, num_ps):
  # TODO(rchao): Test the internal rpc_layer version.
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
  cluster_def['chief'] = [
      'localhost:%d' % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer='grpc')
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  return coordinator_lib.ClusterCoordinator(strategy)

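# A brief usage sketch for the coordinator these helpers return (public-API
# semantics; assumes the in-process cluster above is running, and `v` and
# `step` are illustrative names).
import tensorflow as tf

coordinator = make_coordinator(num_workers=2, num_ps=1)
strategy = coordinator.strategy

with strategy.scope():
  v = tf.Variable(0)

@tf.function
def step():
  v.assign_add(1)
  return v.read_value()

remote_value = coordinator.schedule(step)  # dispatched to an idle worker
coordinator.join()  # block until all scheduled closures have finished
print(remote_value.fetch())  # fetch the result back to the coordinator
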
def setUpClass(cls):
  super(EvaluationTest, cls).setUpClass()
  cls._cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=3, num_ps=2, rpc_layer="grpc")
  cls._cluster_def = cls._cluster.cluster_resolver.cluster_spec().as_dict()
  cluster_resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(cls._cluster_def), rpc_layer="grpc")
  cls.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  cls.cluster_coord = coordinator_lib.ClusterCoordinator(cls.strategy)

def testCopyDistributedTable(self, source):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator_lib.ClusterCoordinator(strategy=strategy)
  with strategy.scope():
    lookuptable = self.createStaticHashTable(
        init_source=source, vals=[0, 1, 2], default_value=-2)
    new_table = copy.copy(lookuptable)
    # No new coordinator instance or distributed tables are created.
    self.assertDictEqual(lookuptable.__dict__, new_table.__dict__)

def testCreateLookupInDatasetFnUnderScope(self, source):
  # TODO(wxinyi): Warn the user of the inefficiency of this workflow (i.e.
  # creating `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn`
  # to be distributed with `distribute_datasets_from_function` and
  # `create_per_worker_dataset`).
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
  with strategy.scope():

    def dataset_fn(input_context):
      some_out_of_range_tensor = constant_op.constant(10, dtype=dtypes.int64)
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)
      self.assertNotIsInstance(lookuptable, ps_values.DistributedTable)
      generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
      dataset = self.makeDatasetFromTensorWithoutUsingResource(
          input_context, generation_tensor)
      return dataset

    @def_function.function
    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(dataset_fn)

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @def_function.function
    def worker_fn(iterator):
      return math_ops.reduce_sum(next(iterator))

    result = []
    for _ in range(10):
      result.append(
          coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

    for r in result:
      returned_input = r.fetch()
      self.assertAllClose(-48, returned_input)

def testCreateDistributedTableInScope(self, source):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator_lib.ClusterCoordinator(strategy=strategy)
  with strategy.scope():
    lookuptable = self.createStaticHashTable(
        init_source=source, vals=[0, 1, 2], default_value=-2)

  self.assertIsInstance(lookuptable, ps_values.DistributedTable)
  self.assertEqual(self.evaluate(lookuptable.size()), 3)

  # Lookup on the coordinator.
  output = lookuptable.lookup(
      constant_op.constant([0, 1, -1], dtype=dtypes.int64))
  self.assertAllEqual([0, 1, -2], output)
  self.assertEqual(lookuptable.size(), 3)

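# For reference, a plain public-API table equivalent to what
# `createStaticHashTable` presumably builds (the helper's init source varies
# by test parameter; this sketch uses an in-memory initializer).
import tensorflow as tf

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2], dtype=tf.int64),
        values=tf.constant([0, 1, 2], dtype=tf.int64)),
    default_value=-2)
# Keys not in the table map to the default value.
print(table.lookup(tf.constant([0, 1, -1], dtype=tf.int64)))  # [0 1 -2]
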
def testAccessingResourceHandleInDatasetFnWithMapFnDefinedOutside(
    self, source):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
  with strategy.scope():
    lookuptable = self.createStaticHashTable(
        init_source=source, vals=[0, 1, 2], default_value=-2)

  def map_fn(vals):
    return lookuptable.lookup(vals)

  def dataset_fn(input_context):
    generation_tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
    dataset = self.makeDatasetFromTensorWithoutUsingResource(
        input_context, generation_tensor)
    dataset = dataset.map(map_fn)
    return dataset

  @def_function.function
  def per_worker_dataset_fn():
    return strategy.distribute_datasets_from_function(dataset_fn)

  per_worker_dataset = coordinator.create_per_worker_dataset(
      per_worker_dataset_fn)
  per_worker_iterator = iter(per_worker_dataset)

  @def_function.function
  def worker_fn(iterator):
    return math_ops.reduce_sum(next(iterator))

  result = []
  for _ in range(10):
    # batch_size == 24 and each input is [0, 1, -2]
    result.append(
        coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

  for r in result:
    returned_input = r.fetch()
    self.assertAllClose(-24, returned_input)

def testInModelAndCapture(self, source):
  file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

  model = self.Model(source, file_path)
  func_captures = model.use_table.get_concrete_function(
  ).graph.external_captures
  self.assertLen(func_captures, 2)
  self.assertTrue(
      any(model.table.resource_handle is t for t in func_captures))
  deferred_captures = model.use_table.get_concrete_function(
  ).graph.deferred_external_captures
  self.assertEmpty(deferred_captures)

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator = coordinator_lib.ClusterCoordinator(strategy)
  with strategy.scope():
    distributed_model = self.Model("value", file_path)
  func_captures = distributed_model.use_table.get_concrete_function(
  ).graph.external_captures
  # One fewer external capture, since the table handle becomes a closure in
  # deferred_external_captures.
  self.assertLen(func_captures, 1)
  self.assertFalse(
      any(model.table.resource_handle is t for t in func_captures))
  deferred_captures = distributed_model.use_table.get_concrete_function(
  ).graph.deferred_external_captures
  self.assertNotEmpty(deferred_captures)

  # Assert that a worker-local resource is captured on each worker.
  for worker in coordinator._cluster.workers:
    with coordinator_context.with_dispatch_context(worker):
      resource_captures = [
          t for t in distributed_model.use_table.get_concrete_function()
          .captured_inputs if t.dtype == dtypes.resource
      ]
      for capture in resource_captures:
        self.assertEqual(
            capture.device,
            device_util.canonicalize("/CPU:0", default=worker.device_name))

def testAccessingResourceHandleInDatasetFnWithoutMap(self, source):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
  with strategy.scope():
    lookuptable = self.createStaticHashTable(
        init_source=source, vals=[0, 1, 2], default_value=-2)

  def dataset_fn(input_context):
    some_out_of_range_tensor = constant_op.constant(10, dtype=dtypes.int64)
    self.assertIsInstance(lookuptable, ps_values.DistributedTable)
    generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
    dataset = self.makeDatasetFromTensorWithoutUsingResource(
        input_context, generation_tensor)
    return dataset

  @def_function.function
  def per_worker_dataset_fn():
    return strategy.distribute_datasets_from_function(dataset_fn)

  per_worker_dataset = coordinator.create_per_worker_dataset(
      per_worker_dataset_fn)
  per_worker_iterator = iter(per_worker_dataset)

  @def_function.function
  def worker_fn(iterator):
    return math_ops.reduce_sum(next(iterator))

  result = []
  for _ in range(10):
    result.append(
        coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

  for r in result:
    returned_input = r.fetch()
    self.assertAllClose(-48, returned_input)

def testBasicShardedVariableWithAggregation(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
  coordinator = coordinator_lib.ClusterCoordinator(strategy)
  with strategy.scope():
    v = variables.Variable(
        initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
        dtype=dtypes.float32,
        aggregation=variable_scope.VariableAggregation.SUM)

  self.assertIsInstance(v, sharded_variable.ShardedVariable)
  self.assertLen(v.variables, 2)
  if strategy.num_replicas_in_sync > 1:
    self.assertIsInstance(v.variables[0], ps_values.AggregatingVariable)
  else:
    self.assertIsInstance(v.variables[0], variables.Variable)

  @def_function.function
  def worker_fn():

    def replica_fn():
      replica_id = distribution_strategy_context.get_replica_context(
      ).replica_id_in_sync_group
      val = array_ops.reshape(
          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
      v.assign(
          array_ops.concat(
              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

    strategy.run(replica_fn)

  coordinator.schedule(worker_fn)
  coordinator.join()
  expected_result = np.arange(8.) * strategy.num_replicas_in_sync
  for i in range(strategy.num_replicas_in_sync):
    expected_result[0] = expected_result[0] + i + 10
  expected_result = np.array_split(expected_result, 2)
  self.assertAllEqual(expected_result[0], v.variables[0])
  self.assertAllEqual(expected_result[1], v.variables[1])

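# A hedged standalone sketch of the partitioner used above, via its public
# alias: it maps a variable shape to a per-axis shard count (axis 0 split
# into two shards, matching the two entries of `v.variables` asserted
# above).
import tensorflow as tf

partitioner = tf.distribute.experimental.partitioners.FixedShardsPartitioner(
    num_shards=2)
print(partitioner(tf.TensorShape([8]), tf.float32))  # [2]
# Passed as `variable_partitioner` to a ParameterServerStrategy, variables
# created under the strategy's scope become ShardedVariables.
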
def setUp(self, num_workers, num_ps):
  super(BaseFaultToleranceTest, self).setUp()

  self._cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
  self._cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

  # The strategy's constructor connects to the cluster.
  self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)

  self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)
  self.thread_coord = thread_coordinator.Coordinator(
      clean_stop_exception_types=[])
  self.num_workers = num_workers
  self.num_ps = num_ps

def fn(first_fetch_occurred_event, worker_terminated_event):
  os.environ["GRPC_FAIL_FAST"] = "use_caller"
  cluster_resolver = TFConfigClusterResolver()
  if cluster_resolver.task_type != "chief":
    utils.start_server(cluster_resolver, "grpc")
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

  with strategy.scope():
    v = variables.Variable(initial_value=0, dtype=dtypes.int32)

  @def_function.function
  def worker_fn():
    return v + 1, v - 1

  remote_value = ps_coordinator.schedule(worker_fn)
  logging.info("result (1st fetch): %r", remote_value.fetch())
  first_fetch_occurred_event.set()
  worker_terminated_event.wait()
  logging.info("result (2nd fetch): %r", remote_value.fetch())

def fn(functions_scheduled_event):
  # TODO(b/170664373): This is needed for TF2 parameter server training in
  # OSS. Remove this when resolved.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"
  cluster_resolver = TFConfigClusterResolver()
  if cluster_resolver.task_type != "chief":
    utils.start_server(cluster_resolver, "grpc")
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  ps_client = coordinator_lib.ClusterCoordinator(strategy)

  with strategy.scope():
    v = variables.Variable(initial_value=1)

  @def_function.function
  def worker_fn(input_tensor):

    def replica_fn(input_tensor):
      return input_tensor + v

    run_result = strategy.run(replica_fn, args=(input_tensor,))
    check_ops.assert_equal_v2(run_result, 4)
    return run_result

  for i in range(5000):
    if i % 500 == 0:
      logging.info("Scheduling function-{}...".format(i))
    result = ps_client.schedule(worker_fn, args=(constant_op.constant(3),))
  functions_scheduled_event.set()
  logging.info("Joining...")
  ps_client.join()
  logging.info("Finished joining.")
  if result.fetch() != 4:
    raise AssertionError(
        "Unexpected RemoteValue result: {}".format(result.fetch()))
  logging.info("testStrategyRun succeeded")

def testDistributeTableSaveAndLoadUnderStrategy(self, load, source):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator = coordinator_lib.ClusterCoordinator(strategy)
  file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
  with strategy.scope():
    model = self.Model(source, file_path)
  model_dir = self.get_temp_dir()
  tf_save.save(model, model_dir)

  if load == "tf_load":
    load_fn = tf_load.load
  else:
    load_fn = keras_save.load_model

  with strategy.scope():
    loaded = load_fn(model_dir)

  loaded_func_captures = (
      loaded.use_table.get_concrete_function().graph.external_captures)
  loaded_func_deferred_captures = (
      loaded.use_table.get_concrete_function().graph
      .deferred_external_captures)

  # Compared with loading without a strategy, there is one fewer external
  # capture, since the captured table handle has been swapped to a closure
  # in deferred_external_captures.
  self.assertLen(loaded_func_captures, 1)
  self.assertNotEmpty(loaded_func_deferred_captures)

  self.assertIsInstance(loaded.table, ps_values.DistributedTable)

  self.assertLen([
      t for t in loaded.use_table.get_concrete_function().captured_inputs
      if t.dtype == dtypes.resource
  ], 1)

  self.verifyWorkerLocalInstance(coordinator, loaded)

def testClusterCoordinatorSingleInstance(self, strategy):
  model = self._model_fit(strategy)
  strategy = model.distribute_strategy
  self.assertIs(strategy._cluster_coordinator,
                coordinator_lib.ClusterCoordinator(strategy))

def make_coordinator(num_workers, num_ps):
  return coordinator_lib.ClusterCoordinator(
      parameter_server_strategy_v2.ParameterServerStrategyV2(
          make_cluster(num_workers, num_ps)))

def make_coordinator(num_workers, num_ps, variable_partitioner=None):
  return coordinator_lib.ClusterCoordinator(
      parameter_server_strategy_v2.ParameterServerStrategyV2(
          make_cluster(num_workers, num_ps),
          variable_partitioner=variable_partitioner))

def testClusterCoordinatorOnlyInitOnce(self):
  cluster = self.coordinator._cluster
  same_coordinator = coordinator_lib.ClusterCoordinator(self.strategy)
  self.assertIs(self.coordinator, same_coordinator)
  self.assertIs(cluster, same_coordinator._cluster)

def setUpClass(cls):
  super(KPLCreatedInDatasetsFromFunctionTest, cls).setUpClass()
  cls.coordinator = coordinator_lib.ClusterCoordinator(
      parameter_server_strategy_v2.ParameterServerStrategyV2(
          multi_worker_testing_utils.make_parameter_server_cluster(3, 2)))