def testCheckpoint(self, delayed, restore_shards):
  if test_util.is_xla_enabled() and not delayed and restore_shards == 4:
    self.skipTest("TODO(b/202760274): Would raise an error that is to be "
                  "investigated.")

  def make_variable(name, shape, dtype, initializer):
    initial_value = functools.partial(initializer, shape, dtype=dtype)
    return variables.Variable(
        name=name, initial_value=initial_value, shape=shape, dtype=dtype)

  class Model(tracking.AutoTrackable):

    def build(self):
      self.w = self._add_variable_with_custom_getter(
          "w",
          shape=(4,),
          initializer=init_ops_v2.Ones(),
          getter=make_variable)

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
  ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint")

  with strategy.scope():
    model1 = Model()
    model1.build()
    self.assertIsInstance(model1.w, sharded_variable.ShardedVariable)
    self.assertLen(model1.w.variables, 2)
    model1.w.assign([1., 2., 3., 4.])

    cp1 = tracking_util.Checkpoint(model=model1)
    cp1.write(ckpt_dir)

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver,
      sharded_variable.FixedShardsPartitioner(restore_shards))

  with strategy.scope():
    model2 = Model()
    cp2 = tracking_util.Checkpoint(model=model2)
    if delayed:
      cp2.restore(ckpt_dir)
      model2.build()
    else:
      model2.build()
      cp2.restore(ckpt_dir)

    self.assertIsInstance(model2.w, sharded_variable.ShardedVariable)
    self.assertLen(model2.w.variables, restore_shards)
    if restore_shards == 2:
      self.assertAllEqual(model2.w.variables[0], [1., 2.])
      self.assertAllEqual(model2.w.variables[1], [3., 4.])
    elif restore_shards == 4:
      self.assertAllEqual(model2.w.variables[0], [1.])
      self.assertAllEqual(model2.w.variables[1], [2.])
      self.assertAllEqual(model2.w.variables[2], [3.])
      self.assertAllEqual(model2.w.variables[3], [4.])

def testBasic(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
  with strategy.scope():
    init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    v1 = variables.Variable(
        initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
        shape=(5, 2),
        dtype=dtypes.int64)

    init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
    v2 = variables.Variable(
        initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
        shape=(6, 1),
        dtype=dtypes.int64)

  self.assertIsInstance(v1, sharded_variable.ShardedVariable)
  self.assertLen(v1.variables, 2)
  self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
  self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
  self.assertAllEqual(v1.variables[0], [[0, 1], [2, 3], [4, 5]])
  self.assertAllEqual(v1.variables[1], [[6, 7], [8, 9]])

  self.assertIsInstance(v2, sharded_variable.ShardedVariable)
  self.assertLen(v2.variables, 2)
  self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
  self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
  self.assertAllEqual(v2.variables[0], [[0], [1], [2]])
  self.assertAllEqual(v2.variables[1], [[3], [4], [5]])

def testCreateInsideTFFunction(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))

  collection = []

  @def_function.function
  def create_vars():
    if not collection:
      identity = init_ops_v2.Identity()
      v1 = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
      v2 = variables.Variable(lambda: identity((2, 2), dtypes.float32))
      v3 = variables.Variable(
          lambda: identity((2, 2), dtypes.float32),
          dtype=dtypes.float32,
          shape=(2, 2))
      collection.extend([v1, v2, v3])

  with strategy.scope():
    create_vars()
    for v in collection:
      self.assertIsInstance(v, sharded_variable.ShardedVariable)
      self.assertLen(v.variables, 2)
      self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
      self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
      self.assertAllEqual(v.variables[0], [[1., 0.]])
      self.assertAllEqual(v.variables[1], [[0., 1.]])

def _model_compile(self,
                   steps_per_execution=1,
                   run_eagerly=False,
                   with_normalization_layer=False):

  class ResultAssertingCallback(callbacks_lib.Callback):

    def __init__(self):
      self._prev_epoch = -1

    def on_epoch_end(self, epoch, logs=None):
      logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
      if epoch <= self._prev_epoch:
        raise RuntimeError("Epoch is supposed to be larger than previous.")
      self._prev_epoch = epoch
      if (logs.get("loss", None) is None or
          not isinstance(logs["loss"], np.floating)):
        raise RuntimeError("loss is supposed to be in the logs and float.")

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      make_cluster(3, 2),
      variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
  with strategy.scope():
    model = sequential.Sequential([core_layers.Dense(10)])
    if with_normalization_layer:
      norm = keras.layers.BatchNormalization(
          axis=-1, input_shape=(4, 4, 3), momentum=0.8)
      model.add(norm)
  model.compile(
      gradient_descent.SGD(),
      loss="mse",
      steps_per_execution=steps_per_execution,
      run_eagerly=run_eagerly)
  return model, [ResultAssertingCallback()]

def testBasicShardedVariableWithAggregation(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
  strategy.extended._allow_run_without_coordinator = True
  with strategy.scope():
    v = variables.Variable(
        initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
        dtype=dtypes.float32,
        aggregation=variable_scope.VariableAggregation.SUM)

  self.assertIsInstance(v, sharded_variable.ShardedVariable)
  self.assertLen(v.variables, 2)
  if strategy.num_replicas_in_sync > 1:
    self.assertIsInstance(v.variables[0], ps_values.AggregatingVariable)
  else:
    self.assertIsInstance(v.variables[0], variables.Variable)

  def replica_fn():
    replica_id = distribution_strategy_context.get_replica_context(
    ).replica_id_in_sync_group
    val = array_ops.reshape(
        math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
    v.assign(
        array_ops.concat(
            [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

  strategy.run(replica_fn)

  expected_result = np.arange(8.) * strategy.num_replicas_in_sync
  for i in range(strategy.num_replicas_in_sync):
    expected_result[0] = expected_result[0] + i + 10
  expected_result = np.array_split(expected_result, 2)
  self.assertAllEqual(expected_result[0], v.variables[0])
  self.assertAllEqual(expected_result[1], v.variables[1])

def _model_compile(self,
                   strategy,
                   steps_per_execution=1,
                   run_eagerly=False,
                   with_normalization_layer=False):

  class ResultAssertingCallback(callbacks_lib.Callback):

    def __init__(self):
      self._prev_epoch = -1
      self._loss_to_compare_against = 2  # Empirical initial value

    def on_epoch_end(self, epoch, logs=None):
      logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
      if epoch <= self._prev_epoch:
        raise RuntimeError("Epoch is supposed to be larger than previous.")
      self._prev_epoch = epoch
      is_loss_float = (
          logs.get("loss", None) is not None and
          isinstance(logs["loss"], (float, np.floating)))
      if not is_loss_float:
        raise RuntimeError("loss is supposed to be in the logs and float.")
      if epoch == 0 or epoch == 9:
        # Making sure the loss of first epoch is below 1, and that of last
        # epoch is smaller than the first epoch.
        if logs["loss"] > self._loss_to_compare_against:
          raise RuntimeError(
              "loss at epoch {} is larger than previous.".format(epoch))
        self._loss_to_compare_against = logs["loss"]

    def on_train_end(self, logs=None):
      if self._prev_epoch != 9:
        raise RuntimeError("Unexpected last epoch: {}".format(
            self._prev_epoch))

  # TODO(b/182193218): Use ParameterServerStrategy as a proper strategy
  # combination.
  if strategy == "ParameterServerStrategy":
    gpu_devices = config.list_physical_devices("GPU")
    if len(gpu_devices) > 1:
      self.skipTest("b/178452835: Multi-GPUs not supported in "
                    "ParameterServerStrategy.")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        multi_worker_testing_utils.make_parameter_server_cluster(3, 2),
        variable_partitioner=sharded_variable.FixedShardsPartitioner(2))

  with strategy.scope():
    model = sequential.Sequential([core_layers.Dense(10)])
    if with_normalization_layer:
      norm = keras.layers.BatchNormalization(
          axis=-1, input_shape=(4, 4, 3), momentum=0.8)
      model.add(norm)
  model.compile(
      gradient_descent.SGD(),
      loss="mse",
      steps_per_execution=steps_per_execution,
      run_eagerly=run_eagerly)
  return model, [ResultAssertingCallback()]

def _get_ps_strategy_creator(
    num_workers,
    num_ps,
    required_gpus=0,
    variable_partitioner=sharded_variable.FixedShardsPartitioner(2)):

  def _create_ps_strategy(resolver, variable_partitioner):
    return parameter_server_strategy_v2.ParameterServerStrategyV2(
        resolver, variable_partitioner=variable_partitioner)

  def _create_parameter_server():
    if framework_test_util.is_xla_enabled():
      # To address test failures resulting in XLA with MultiProcessRunner,
      # continue to use in-process cluster for XLA tests.
      cluster_def = multi_worker_test_base.create_in_process_cluster(
          num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
      resolver = cluster_resolver.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_def),
          num_accelerators={"GPU": required_gpus},
          rpc_layer="grpc")
      return _create_ps_strategy(resolver, variable_partitioner)
    else:
      tf_config = cluster_resolver.TFConfigClusterResolver()
      cluster_def = tf_config.cluster_spec().as_dict()
      if not cluster_def:
        # When MultiProcessRunner cluster is used, the cluster is not created
        # initially when the decorator is called. When the test runs, initially
        # this method is invoked via decorator before setting up the
        # MultiProcessRunner with worker and ps in the combinations.py. After
        # setup is done, the subprocess invokes this method again to get
        # strategy object. We return None strategy when the main thread invokes
        # this method before setting up cluster.
        # Returning None is fine here, since this thread will proceed to create
        # MultiProcessRunner and invoke tests with decorator inside
        # subprocesses.
        return None
      # MultiProcessRunner is already setup and this method is invoked from a
      # subprocess running the actual test.
      resolver = cluster_resolver.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_def),
          num_accelerators={"GPU": required_gpus},
          task_type=tf_config.task_type,
          task_id=tf_config.task_id,
          environment=tf_config.environment,
          rpc_layer=tf_config.rpc_layer or "grpc")
      if tf_config.task_type in ("worker", "ps"):
        worker_config = config_pb2.ConfigProto()
        worker_config.inter_op_parallelism_threads = 4  # max num_workers + 1
        server = server_lib.Server(
            cluster_def,
            job_name=tf_config.task_type,
            task_index=tf_config.task_id,
            protocol="grpc",
            config=worker_config)
        # Blocking the process that starts a server from exiting.
        server.join()
      return _create_ps_strategy(resolver, variable_partitioner)

  return _create_parameter_server

def _create_strategy(self, num_shards):
  if num_shards > 1:
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver,
        variable_partitioner=sharded_variable.FixedShardsPartitioner(
            num_shards))
  else:
    strategy = ds_context._get_default_strategy()
  return strategy

def _create_parameter_server():
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  resolver = cluster_resolver.SimpleClusterResolver(
      ClusterSpec(cluster_def),
      num_accelerators={"GPU": required_gpus},
      rpc_layer="grpc")
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      resolver,
      variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
  return strategy

def parameter_server_strategy_fn(
    name,
    num_workers,
    num_ps,
    required_gpus=0,
    variable_partitioner=sharded_variable.FixedShardsPartitioner(2)):
  return combinations.NamedDistribution(
      name,
      _get_ps_strategy_creator(
          num_workers=num_workers,
          num_ps=num_ps,
          required_gpus=required_gpus,
          variable_partitioner=variable_partitioner),
      required_gpus=required_gpus,
      num_workers=num_workers,
      has_chief=True,
      num_ps=num_ps)

def testNumPartitionsLargerThanSize(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
  with strategy.scope():
    v = variables.Variable([0, 1, 2])

  self.assertIsInstance(v, sharded_variable.ShardedVariable)
  self.assertLen(v.variables, 3)
  self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
  self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
  self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
  self.assertAllEqual(v.variables[0], [0])
  self.assertAllEqual(v.variables[1], [1])
  self.assertAllEqual(v.variables[2], [2])

def testColocateWith(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
  with strategy.scope():
    v1 = variables.Variable([0, 1, 2, 3])

    with strategy.extended.colocate_vars_with(v1.variables[0]):
      v2 = variables.Variable([4, 5])

  self.assertIsInstance(v1, sharded_variable.ShardedVariable)
  self.assertIsInstance(v2, variables.Variable)
  self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
  self.assertEqual(v2.device, v1.variables[0].device)
  self.assertAllEqual(v2, [4, 5])

def testCustomPartitionAwareInitializer(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
  with strategy.scope():
    initializer = PartitionAwareIdentity()
    initial_value = functools.partial(
        initializer, shape=(4, 4), dtype=dtypes.int64)
    v = variables.Variable(
        initial_value=initial_value, shape=(4, 4), dtype=dtypes.int64)

  self.assertIsInstance(v, sharded_variable.ShardedVariable)
  self.assertLen(v.variables, 2)
  self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
  self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
  self.assertAllEqual(v.variables[0], [[1, 0, 0, 0], [0, 1, 0, 0]])
  self.assertAllEqual(v.variables[1], [[0, 0, 1, 0], [0, 0, 0, 1]])

def testNonCallableInitialValue(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
  with strategy.scope():
    v = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

  self.assertIsInstance(v, sharded_variable.ShardedVariable)
  self.assertLen(v.variables, 4)
  self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
  self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
  self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
  self.assertRegex(v.variables[3].device, "/job:ps/replica:0/task:1")
  self.assertAllEqual(v.variables[0], [0, 1, 2])
  self.assertAllEqual(v.variables[1], [3, 4, 5])
  self.assertAllEqual(v.variables[2], [6, 7])
  self.assertAllEqual(v.variables[3], [8, 9])

def testPartitionWhenLackOfInfo(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
  with strategy.scope():
    initializer = init_ops_v2.Constant([0, 1, 2, 3])
    # Shape is not explicitly specified.
    v1 = variables.Variable(
        initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
        dtype=dtypes.int64)
    # Dtype is not explicitly specified.
    v2 = variables.Variable(
        initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
        shape=(4,))
    # Neither shape nor dtype is explicitly specified.
    v3 = variables.Variable(
        initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64))

  for v in [v1, v2, v3]:
    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [0, 1])
    self.assertAllEqual(v.variables[1], [2, 3])

def test_fixed_shards_partitioner(self):
  partitioner = sharded_variable.FixedShardsPartitioner(num_shards=2)
  got = partitioner(tensor_shape.TensorShape([10, 3]), dtypes.float32)
  self.assertAllEqual(got, [2, 1])

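# Illustrative sketch (not the TensorFlow implementation) of the behavior the
# test above relies on: FixedShardsPartitioner returns one partition count per
# axis, splitting only the first axis into `num_shards` pieces, hence [2, 1]
# for a (10, 3) shape with num_shards=2. The helper name below is made up for
# illustration only.
def _fixed_shards_partition_sketch(shape, num_shards):
  # One entry per axis; only axis 0 is partitioned.
  return [num_shards] + [1] * (len(shape) - 1)

assert _fixed_shards_partition_sketch([10, 3], num_shards=2) == [2, 1]
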
class GeneratorTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(GeneratorTest, self).setUp()
    v2_compat.enable_v2_behavior()

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    values = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    values = array_ops.concat(values, axis=0)
    values = self.evaluate(values)
    values = values.tolist()
    self.assertAllEqual(len(values), len(set(values)))

  @test_util.run_v2_only
  def testCreateOutsideMirroredStrat(self):
    """Tests RNG/MirrorStrategy interaction #1.

    If an RNG is created outside a DS scope, all replicas will access the
    same RNG object, and accesses are serialized.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gen = rng.Generator.from_seed(1234)
    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
    with strat.scope():

      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      results = strat.extended.call_for_each_replica(fn=f)
      values = results.values
      self.assertAllEqual(2, len(values))
      self.assertAllDifferent(values)

  @test_util.run_v2_only
  def testMirroredStratParaAsync(self):
    """Tests RNG/MirrorStrategy interaction #2.

    The user can create n independent RNGs outside strategy.scope(), where n
    is the number of replicas, and give one to each replica. The replicas can
    thus get different random-number streams.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gens = rng.get_global_generator().split(count=2)
    devices = ["cpu:0", "cpu:1"]
    strat = MirroredStrategy(devices=devices)
    # Use `PerReplica` to specify which `gen` is sent to which replica
    gens = dist_values.PerReplica([[g] for g in gens])
    with strat.scope():

      def f(gen):
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      results = strat.extended.call_for_each_replica(fn=f, args=gens)
      local_results = strat.experimental_local_results(results)
      self.assertAllEqual(2, len(local_results))
      self.assertAllDifferent(local_results)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"]))
  def testCrossReplica(self, strat):
    """Tests that RNG can be properly advanced in cross-replica context."""

    def read_values(dv):
      return [v.read_value() for v in strat.experimental_local_results(dv)]

    with strat.scope():
      g = rng.Generator.from_seed(1)
      s1 = read_values(g.state)
      g.normal([3])
      g.skip(4)
      s2 = read_values(g.state)

    self.assertNotAllEqual(s1[0], s2[0])
    self.assertEqual(len(s1), len(s2))
    for i in range(1, len(s1)):
      self.assertAllEqual(s1[0], s1[i])
      self.assertAllEqual(s2[0], s2[i])

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          jit_replica_fn=[False, True],
          seeded=[True, False]))
  def testDistStrat(self, strat, jit_replica_fn, seeded):
    """Tests RNG with distribution strategies."""
    strat_name = type(strat).__name__
    if "TPU" in strat_name and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord = None
    if "ParameterServer" in strat_name:
      coord = coordinator_lib.ClusterCoordinator(strat)
    creators = {
        True: functools.partial(rng.Generator.from_seed, 1234),
        False: rng.Generator.from_non_deterministic_state,
    }
    shape = [3, 4]
    dtype = dtypes.int32
    creator = creators[seeded]
    with strat.scope():
      gen = creator()

      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      replica_fn = def_function.function(f) if jit_replica_fn else f
      results = run_on_strategy(replica_fn, strat, coord)
      values = strat.experimental_local_results(results)
      n = get_num_local_replicas(strat, values)
      self.assertAllEqual(n, len(values))
      self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat=[
              strategy_combinations.parameter_server_strategy_fn(
                  "ParameterServer1Worker2PSCPUFixedShards",
                  num_workers=1,
                  num_ps=2,
                  variable_partitioner=(
                      sharded_variable.FixedShardsPartitioner(2)))
          ],
          mode=["eager"]))
  def testShardedError(self, strat):
    """Tests error about sharding is raised."""
    with strat.scope():
      with self.assertRaisesRegex(
          ValueError, "state is sharded, which is not allowed"):
        rng.Generator.from_seed(1234)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          jit_replica_fn=[False, True]))
  def testDistVarAsTFFunArg(self, strat, jit_replica_fn):
    """Tests that RNG with dist variables can be used as tf.function's arg."""
    strat_name = type(strat).__name__
    if "CentralStorage" in strat_name:
      self.skipTest(
          "CentralStorageStrategy wraps variable updates in merge_call which "
          "can't be called inside a tf.function that doesn't cover the entire "
          "replica function (the function passed to strategy.run).")
    if "TPU" in strat_name and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord = None
    if "ParameterServer" in strat_name:
      coord = coordinator_lib.ClusterCoordinator(strat)
    shape = [3, 4]
    dtype = dtypes.int32
    with strat.scope():
      gen = rng.Generator.from_seed(1234)

      @def_function.function
      def f(gen):  # the main focus
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      def g():
        return f(gen)

      replica_fn = def_function.function(g) if jit_replica_fn else g
      for _ in range(2):
        results = run_on_strategy(replica_fn, strat, coord)
        values = strat.experimental_local_results(results)
        n = get_num_local_replicas(strat, values)
        self.assertAllEqual(n, len(values))
        self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat1=strategy_combinations.all_strategies,
          strat2=strategy_combinations.all_strategies,
          jit_replica_fn=[False, True],
          mode=["eager"]) +
      combinations.combine(
          strat1=strategy_combinations.multiworker_strategies + ps_strategies,
          strat2=[None],
          jit_replica_fn=[False, True],
          mode=["eager"]))
  def testDistStratRestore(self, strat1, strat2, jit_replica_fn):
    """Tests checkpointing and restoring (to possibly different #replicas)."""
    if strat2 is None:
      strat2 = strat1
    strat1_name = type(strat1).__name__
    strat2_name = type(strat2).__name__
    if "Default" in strat1_name or "Default" in strat2_name:
      self.skipTest(
          "We don't guarantee consistency between strategy and no-strategy.")
    if ("TPU" in strat1_name or "TPU" in strat2_name) and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord1 = None
    if "ParameterServer" in strat1_name:
      coord1 = coordinator_lib.ClusterCoordinator(strat1)
    coord2 = None
    if "ParameterServer" in strat2_name:
      coord2 = coordinator_lib.ClusterCoordinator(strat2)
    fname = os.path.join(self.get_temp_dir(), "checkpoint")

    def uniform(strat, coord, g):

      def f():
        return g.uniform_full_int([3], dtype=dtypes.int32)

      replica_fn = def_function.function(f) if jit_replica_fn else f
      result = run_on_strategy(replica_fn, strat, coord)
      return strat.experimental_local_results(result)

    with strat1.scope():
      g1 = rng.Generator.from_seed(1)
    with strat2.scope():
      g2 = rng.Generator.from_seed(10)
    cp1 = tracking_util.Checkpoint(g=g1)
    cp2 = tracking_util.Checkpoint(g=g2)

    def write_restore_compare():
      cp1.write(fname)
      r1 = uniform(strat1, coord1, g1)
      cp2.restore(fname)
      r2 = uniform(strat2, coord2, g2)
      # Tests that overlapping replicas are properly restored.
      n1 = get_num_local_replicas(strat1)
      n2 = get_num_local_replicas(strat2)
      n = min(n1, n2)
      self.assertAllEqual(r1[:n], r2[:n])

    # Run multiple times so that cp1.write is called in various RNG states
    for _ in range(2):
      write_restore_compare()

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          is_save_in_scope=[True, False]))
  def testSavedModel(self, strat, is_save_in_scope):

    class CustomModule(module.Module):

      def __init__(self):
        super(CustomModule, self).__init__()
        self.g = rng.Generator.from_seed(0)

      @def_function.function
      def __call__(self):
        return self.g.state

      @def_function.function
      def mutate(self):
        self.g.normal([])

    with strat.scope():
      m = CustomModule()
      m.mutate()
      state_before = m()
      path = os.path.join(self.get_temp_dir(), "saved_model")

    if is_save_in_scope:
      with strat.scope():
        save.save(m, path)
    else:
      save.save(m, path)
    with strat.scope():
      m.mutate()
      state_before_2 = m()

    imported = load.load(path)
    state_after = imported()
    self.assertAllEqual(state_before, state_after)
    imported.mutate()
    state_after_2 = imported()
    self.assertAllEqual(state_before_2, state_after_2)

def setUpClass(cls):
  super().setUpClass()
  cls.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      make_cluster(3, 2),
      variable_partitioner=sharded_variable.FixedShardsPartitioner(2))

def _create_ps_strategy(resolver):
  return parameter_server_strategy_v2.ParameterServerStrategyV2(
      resolver,
      variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
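
# Usage sketch, assuming `resolver` is a SimpleClusterResolver for a cluster
# with at least two "ps" tasks (as the helpers above construct). Variables
# created under the resulting strategy's scope come back as ShardedVariable
# objects with shards placed round-robin on the ps tasks, which is what the
# tests in this section assert:
#
#   strategy = _create_ps_strategy(resolver)
#   with strategy.scope():
#     w = variables.Variable([0., 1., 2., 3.])  # ShardedVariable with 2 shards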