def testCheckpoint(self, delayed, restore_shards):

        if test_util.is_xla_enabled() and not delayed and restore_shards == 4:
            self.skipTest(
                "TODO(b/202760274): Would raise an error that is to be "
                "investigated.")

        def make_variable(name, shape, dtype, initializer):
            initial_value = functools.partial(initializer, shape, dtype=dtype)
            return variables.Variable(name=name,
                                      initial_value=initial_value,
                                      shape=shape,
                                      dtype=dtype)

        class Model(tracking.AutoTrackable):
            def build(self):
                self.w = self._add_variable_with_custom_getter(
                    "w",
                    shape=(4, ),
                    initializer=init_ops_v2.Ones(),
                    getter=make_variable)

        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
        ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint")

        with strategy.scope():
            model1 = Model()
            model1.build()
            self.assertIsInstance(model1.w, sharded_variable.ShardedVariable)
            self.assertLen(model1.w.variables, 2)
            model1.w.assign([1., 2., 3., 4.])

            cp1 = tracking_util.Checkpoint(model=model1)
            cp1.write(ckpt_dir)

        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver,
            sharded_variable.FixedShardsPartitioner(restore_shards))

        with strategy.scope():
            model2 = Model()
            cp2 = tracking_util.Checkpoint(model=model2)
            if delayed:
                cp2.restore(ckpt_dir)
                model2.build()
            else:
                model2.build()
                cp2.restore(ckpt_dir)
            self.assertIsInstance(model2.w, sharded_variable.ShardedVariable)
            self.assertLen(model2.w.variables, restore_shards)
            if restore_shards == 2:
                self.assertAllEqual(model2.w.variables[0], [1., 2.])
                self.assertAllEqual(model2.w.variables[1], [3., 4.])
            elif restore_shards == 4:
                self.assertAllEqual(model2.w.variables[0], [1.])
                self.assertAllEqual(model2.w.variables[1], [2.])
                self.assertAllEqual(model2.w.variables[2], [3.])
                self.assertAllEqual(model2.w.variables[3], [4.])
Example #2
  def testBasic(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
      v1 = variables.Variable(
          initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
          shape=(5, 2),
          dtype=dtypes.int64)

      init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
      v2 = variables.Variable(
          initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
          shape=(6, 1),
          dtype=dtypes.int64)

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
    self.assertLen(v1.variables, 2)
    self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v1.variables[0], [[0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(v1.variables[1], [[6, 7], [8, 9]])

    self.assertIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertLen(v2.variables, 2)
    self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v2.variables[0], [[0], [1], [2]])
    self.assertAllEqual(v2.variables[1], [[3], [4], [5]])
Example #3
  def testCreateInsideTFFunction(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))

    collection = []

    @def_function.function
    def create_vars():
      if not collection:
        identity = init_ops_v2.Identity()
        v1 = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
        v2 = variables.Variable(lambda: identity((2, 2), dtypes.float32))
        v3 = variables.Variable(
            lambda: identity((2, 2), dtypes.float32),
            dtype=dtypes.float32,
            shape=(2, 2))
        collection.extend([v1, v2, v3])

    with strategy.scope():
      create_vars()
      for v in collection:
        self.assertIsInstance(v, sharded_variable.ShardedVariable)
        self.assertLen(v.variables, 2)
        self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
        self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
        self.assertAllEqual(v.variables[0], [[1., 0.]])
        self.assertAllEqual(v.variables[1], [[0., 1.]])
Example #4
  def _model_compile(self,
                     steps_per_execution=1,
                     run_eagerly=False,
                     with_normalization_layer=False):

    class ResultAssertingCallback(callbacks_lib.Callback):

      def __init__(self):
        self._prev_epoch = -1

      def on_epoch_end(self, epoch, logs=None):
        logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
        if epoch <= self._prev_epoch:
          raise RuntimeError("Epoch is supposed to be larger than previous.")
        self._prev_epoch = epoch
        if (logs.get("loss", None) is None or
            not isinstance(logs["loss"], np.floating)):
          raise RuntimeError("loss is supposed to be in the logs and float.")

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        make_cluster(3, 2),
        variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
      if with_normalization_layer:
        norm = keras.layers.BatchNormalization(
            axis=-1, input_shape=(4, 4, 3), momentum=0.8)
        model.add(norm)

    model.compile(
        gradient_descent.SGD(),
        loss="mse",
        steps_per_execution=steps_per_execution,
        run_eagerly=run_eagerly)
    return model, [ResultAssertingCallback()]
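
  # A minimal sketch, not part of the original test, of how the compile helper
  # above might be driven. The helper name `_model_fit_sketch`, the use of
  # `dataset_creator.DatasetCreator` / `dataset_ops.DatasetV2`, and the exact
  # fit arguments are assumptions made for illustration only.
  def _model_fit_sketch(self, steps_per_execution=1):
    model, callbacks = self._model_compile(steps_per_execution)

    def dataset_fn(input_context):
      del input_context
      x = random_ops.random_uniform((10, 10))
      y = random_ops.random_uniform((10,))
      return dataset_ops.DatasetV2.from_tensor_slices(
          (x, y)).shuffle(10).repeat().batch(2)

    model.fit(
        dataset_creator.DatasetCreator(dataset_fn),
        epochs=10,
        steps_per_epoch=10,
        callbacks=callbacks)
    return model, callbacks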
Example #5
  def testBasicShardedVariableWithAggregation(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    strategy.extended._allow_run_without_coordinator = True
    with strategy.scope():
      v = variables.Variable(
          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
          dtype=dtypes.float32,
          aggregation=variable_scope.VariableAggregation.SUM)

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    if strategy.num_replicas_in_sync > 1:
      self.assertIsInstance(v.variables[0], ps_values.AggregatingVariable)
    else:
      self.assertIsInstance(v.variables[0], variables.Variable)

    def replica_fn():
      replica_id = distribution_strategy_context.get_replica_context(
      ).replica_id_in_sync_group
      val = array_ops.reshape(
          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
      v.assign(
          array_ops.concat(
              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

    strategy.run(replica_fn)

    # With SUM aggregation, element k >= 1 ends up as k summed over replicas
    # (k * num_replicas_in_sync), while element 0 accumulates replica_id + 10
    # from every replica.
    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
    for i in range(strategy.num_replicas_in_sync):
      expected_result[0] = expected_result[0] + i + 10
    expected_result = np.array_split(expected_result, 2)
    self.assertAllEqual(expected_result[0], v.variables[0])
    self.assertAllEqual(expected_result[1], v.variables[1])
Example #6
  def _model_compile(self,
                     strategy,
                     steps_per_execution=1,
                     run_eagerly=False,
                     with_normalization_layer=False):

    class ResultAssertingCallback(callbacks_lib.Callback):

      def __init__(self):
        self._prev_epoch = -1
        self._loss_to_compare_against = 2  # Empirical initial value

      def on_epoch_end(self, epoch, logs=None):
        logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
        if epoch <= self._prev_epoch:
          raise RuntimeError("Epoch is supposed to be larger than previous.")
        self._prev_epoch = epoch
        is_loss_float = (
            logs.get("loss", None) is not None and
            isinstance(logs["loss"], (float, np.floating)))
        if not is_loss_float:
          raise RuntimeError("loss is supposed to be in the logs and float.")
        if epoch == 0 or epoch == 9:
          # Make sure the loss of the first epoch is below the empirical
          # initial value (2), and that the loss of the last epoch is smaller
          # than the loss of the first epoch.
          if logs["loss"] > self._loss_to_compare_against:
            raise RuntimeError(
                "loss at epoch {} is larger than previous.".format(epoch))
          self._loss_to_compare_against = logs["loss"]

      def on_train_end(self, logs=None):
        if self._prev_epoch != 9:
          raise RuntimeError("Unexpected last epoch: {}".format(
              self._prev_epoch))

    # TODO(b/182193218): Use ParameterServerStrategy as a proper strategy
    # combination.
    if strategy == "ParameterServerStrategy":
      gpu_devices = config.list_physical_devices("GPU")
      if len(gpu_devices) > 1:
        self.skipTest("b/178452835: Multi-GPUs not supported in "
                      "ParameterServerStrategy.")
      strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
          multi_worker_testing_utils.make_parameter_server_cluster(3, 2),
          variable_partitioner=sharded_variable.FixedShardsPartitioner(2))

    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
      if with_normalization_layer:
        norm = keras.layers.BatchNormalization(
            axis=-1, input_shape=(4, 4, 3), momentum=0.8)
        model.add(norm)

    model.compile(
        gradient_descent.SGD(),
        loss="mse",
        steps_per_execution=steps_per_execution,
        run_eagerly=run_eagerly)
    return model, [ResultAssertingCallback()]
Example #7
def _get_ps_strategy_creator(
    num_workers,
    num_ps,
    required_gpus=0,
    variable_partitioner=sharded_variable.FixedShardsPartitioner(2)):
    def _create_ps_strategy(resolver, variable_partitioner):
        return parameter_server_strategy_v2.ParameterServerStrategyV2(
            resolver, variable_partitioner=variable_partitioner)

    def _create_parameter_server():
        if framework_test_util.is_xla_enabled():
            # To address test failures resulting in XLA with MultiProcessRunner,
            # continue to use in-process cluster for XLA tests.
            cluster_def = multi_worker_test_base.create_in_process_cluster(
                num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
            resolver = cluster_resolver.SimpleClusterResolver(
                server_lib.ClusterSpec(cluster_def),
                num_accelerators={"GPU": required_gpus},
                rpc_layer="grpc")
            return _create_ps_strategy(resolver, variable_partitioner)
        else:
            tf_config = cluster_resolver.TFConfigClusterResolver()
            cluster_def = tf_config.cluster_spec().as_dict()
            if not cluster_def:
                # When MultiProcessRunner cluster is used, the cluster is not created
                # initially when the decorator is called. When the test runs, initially
                # this method is invoked via decorator before setting up the
                # MultiProcessRunner with worker and ps in the combinations.py. After
                # setup is done, the subprocess invokes this method again to get
                # strategy object. We return None strategy when the main thread invokes
                # this method before setting up cluster.
                # Returning None is fine here, since this thread will proceed to create
                # MultiProcessRunner and invoke tests with decorator inside
                # subprocesses.
                return None
            # MultiProcessRunner is already setup and this method is invoked from a
            # subprocess running the actual test.
            resolver = cluster_resolver.SimpleClusterResolver(
                server_lib.ClusterSpec(cluster_def),
                num_accelerators={"GPU": required_gpus},
                task_type=tf_config.task_type,
                task_id=tf_config.task_id,
                environment=tf_config.environment,
                rpc_layer=tf_config.rpc_layer or "grpc")
            if tf_config.task_type in ("worker", "ps"):
                worker_config = config_pb2.ConfigProto()
                worker_config.inter_op_parallelism_threads = 4  # max num_workers + 1
                server = server_lib.Server(cluster_def,
                                           job_name=tf_config.task_type,
                                           task_index=tf_config.task_id,
                                           protocol="grpc",
                                           config=worker_config)

                # Blocking the process that starts a server from exiting.
                server.join()

            return _create_ps_strategy(resolver, variable_partitioner)

    return _create_parameter_server
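
# For reference only: a hedged illustration (addresses and port numbers are
# made up) of the TF_CONFIG environment variable that the
# TFConfigClusterResolver branch above parses once MultiProcessRunner has set
# up the cluster, e.g. for one chief, two workers and two parameter servers:
#
#   os.environ["TF_CONFIG"] = json.dumps({
#       "cluster": {
#           "chief": ["localhost:21000"],
#           "worker": ["localhost:21001", "localhost:21002"],
#           "ps": ["localhost:21003", "localhost:21004"],
#       },
#       "task": {"type": "worker", "index": 0},
#   })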
Example #8
  def _create_strategy(self, num_shards):
    if num_shards > 1:
      strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
          self.cluster_resolver,
          variable_partitioner=sharded_variable.FixedShardsPartitioner(
              num_shards))
    else:
      strategy = ds_context._get_default_strategy()
    return strategy
Example #9
    def _create_parameter_server():

        cluster_def = multi_worker_test_base.create_in_process_cluster(
            num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
        resolver = cluster_resolver.SimpleClusterResolver(
            ClusterSpec(cluster_def),
            num_accelerators={"GPU": required_gpus},
            rpc_layer="grpc")
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            resolver,
            variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
        return strategy
Example #10
def parameter_server_strategy_fn(
    name, num_workers, num_ps, required_gpus=0,
    variable_partitioner=sharded_variable.FixedShardsPartitioner(2)):
  return combinations.NamedDistribution(
      name,
      _get_ps_strategy_creator(
          num_workers=num_workers, num_ps=num_ps, required_gpus=required_gpus,
          variable_partitioner=variable_partitioner),
      required_gpus=required_gpus,
      num_workers=num_workers,
      has_chief=True,
      num_ps=num_ps)
Example #11
  def testNumPartitionsLargerThanSize(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
    with strategy.scope():
      v = variables.Variable([0, 1, 2])

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 3)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
    self.assertAllEqual(v.variables[0], [0])
    self.assertAllEqual(v.variables[1], [1])
    self.assertAllEqual(v.variables[2], [2])
Example #12
  def testColocateWith(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      v1 = variables.Variable([0, 1, 2, 3])

      with strategy.extended.colocate_vars_with(v1.variables[0]):
        v2 = variables.Variable([4, 5])

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)

    self.assertIsInstance(v2, variables.Variable)
    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertEqual(v2.device, v1.variables[0].device)
    self.assertAllEqual(v2, [4, 5])
Example #13
  def testCustomPartitionAwareInitializer(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      initializer = PartitionAwareIdentity()
      initial_value = functools.partial(
          initializer, shape=(4, 4), dtype=dtypes.int64)
      v = variables.Variable(
          initial_value=initial_value, shape=(4, 4), dtype=dtypes.int64)

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [[1, 0, 0, 0], [0, 1, 0, 0]])
    self.assertAllEqual(v.variables[1], [[0, 0, 1, 0], [0, 0, 0, 1]])
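
# The test above refers to a `PartitionAwareIdentity` initializer that is not
# shown in this snippet. Below is a minimal sketch of what such a
# partition-aware initializer could look like; it is an illustration only,
# assuming the `partition_shape`/`partition_offset` keyword arguments that the
# strategy passes to initializers of partitioned variables, and assuming
# `linalg_ops_impl` refers to tensorflow.python.ops.linalg_ops_impl.
class PartitionAwareIdentity(object):

  def __call__(self, shape, dtype, **kwargs):
    # Build the full identity matrix, then return only this shard's slice when
    # partition information is provided.
    value = linalg_ops_impl.eye(*shape, dtype=dtype)
    if "partition_shape" in kwargs and "partition_offset" in kwargs:
      return array_ops.slice(value, kwargs["partition_offset"],
                             kwargs["partition_shape"])
    return value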
Example #14
  def testNonCallableInitialValue(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
    with strategy.scope():
      v = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 4)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[3].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [0, 1, 2])
    self.assertAllEqual(v.variables[1], [3, 4, 5])
    self.assertAllEqual(v.variables[2], [6, 7])
    self.assertAllEqual(v.variables[3], [8, 9])
Example #15
  def testPartitionWhenLackOfInfo(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      initializer = init_ops_v2.Constant([0, 1, 2, 3])
      # Shape is not explicitly specified.
      v1 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
          dtype=dtypes.int64)
      # Dtype is not explicitly specified.
      v2 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
          shape=(4,))
      # Neither shape nor dtype is explicitly specified.
      v3 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64))

    for v in [v1, v2, v3]:
      self.assertIsInstance(v, sharded_variable.ShardedVariable)
      self.assertLen(v.variables, 2)
      self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
      self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
      self.assertAllEqual(v.variables[0], [0, 1])
      self.assertAllEqual(v.variables[1], [2, 3])
Example #16
  def test_fixed_shards_partitioner(self):
    partitioner = sharded_variable.FixedShardsPartitioner(num_shards=2)
    # The partitioner returns the number of partitions per axis: axis 0 is
    # split into 2 shards, axis 1 is left unpartitioned.
    got = partitioner(tensor_shape.TensorShape([10, 3]), dtypes.float32)
    self.assertAllEqual(got, [2, 1])
Example #17
class GeneratorTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(GeneratorTest, self).setUp()
    v2_compat.enable_v2_behavior()

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    values = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    values = array_ops.concat(values, axis=0)
    values = self.evaluate(values)
    values = values.tolist()
    self.assertAllEqual(len(values), len(set(values)))

  @test_util.run_v2_only
  def testCreateOutsideMirroredStrat(self):
    """Tests RNG/MirrorStrategy interaction #1.

    If an RNG is created outside a DS scope, all replicas will access the
    same RNG object, and accesses are serialized.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gen = rng.Generator.from_seed(1234)
    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
    with strat.scope():

      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      results = strat.extended.call_for_each_replica(fn=f)
      values = results.values
      self.assertAllEqual(2, len(values))
      self.assertAllDifferent(values)

  @test_util.run_v2_only
  def testMirroredStratParaAsync(self):
    """Tests RNG/MirrorStrategy interaction #2.

    The user can create n independent RNGs outside strategy.scope(), where n
    is the number of replicas, and give one to each replica. The replicas can
    thus get different random-number streams.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gens = rng.get_global_generator().split(count=2)
    devices = ["cpu:0", "cpu:1"]
    strat = MirroredStrategy(devices=devices)
    # Use `PerReplica` to specify which `gen` is sent to which replica
    gens = dist_values.PerReplica([[g] for g in gens])
    with strat.scope():

      def f(gen):
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      results = strat.extended.call_for_each_replica(fn=f, args=gens)
      local_results = strat.experimental_local_results(results)
      self.assertAllEqual(2, len(local_results))
      self.assertAllDifferent(local_results)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"]))
  def testCrossReplica(self, strat):
    """Tests that RNG can be properly advanced in cross-replica context."""
    def read_values(dv):
      return [v.read_value() for v in strat.experimental_local_results(dv)]
    with strat.scope():
      g = rng.Generator.from_seed(1)
      s1 = read_values(g.state)
      g.normal([3])
      g.skip(4)
      s2 = read_values(g.state)
    self.assertNotAllEqual(s1[0], s2[0])
    self.assertEqual(len(s1), len(s2))
    for i in range(1, len(s1)):
      self.assertAllEqual(s1[0], s1[i])
      self.assertAllEqual(s2[0], s2[i])

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          jit_replica_fn=[False, True],
          seeded=[True, False],))
  def testDistStrat(self, strat, jit_replica_fn, seeded):
    """Tests RNG with distribution strategies."""
    strat_name = type(strat).__name__
    if "TPU" in strat_name and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord = None
    if "ParameterServer" in strat_name:
      coord = coordinator_lib.ClusterCoordinator(strat)
    creators = {
        True: functools.partial(rng.Generator.from_seed, 1234),
        False: rng.Generator.from_non_deterministic_state,
    }
    shape = [3, 4]
    dtype = dtypes.int32
    creator = creators[seeded]
    with strat.scope():
      gen = creator()
      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      replica_fn = def_function.function(f) if jit_replica_fn else f
      results = run_on_strategy(replica_fn, strat, coord)
      values = strat.experimental_local_results(results)
      n = get_num_local_replicas(strat, values)
      self.assertAllEqual(n, len(values))
      self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat=[
              strategy_combinations.parameter_server_strategy_fn(
                  "ParameterServer1Worker2PSCPUFixedShards",
                  num_workers=1, num_ps=2,
                  variable_partitioner=(
                      sharded_variable.FixedShardsPartitioner(2)))
          ],
          mode=["eager"]))
  def testShardedError(self, strat):
    """Tests error about sharding is raised."""
    with strat.scope():
      with self.assertRaisesRegex(
          ValueError, "state is sharded, which is not allowed"):
        rng.Generator.from_seed(1234)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          jit_replica_fn=[False, True]))
  def testDistVarAsTFFunArg(self, strat, jit_replica_fn):
    """Tests that RNG with dist variables can be used as tf.function's arg."""
    strat_name = type(strat).__name__
    if "CentralStorage" in strat_name:
      self.skipTest(
          "CentralStorageStrategy wraps variable updates in merge_call which "
          "can't be called inside a tf.function that doesn't cover the entire "
          "replica function (the function passed to strategy.run).")
    if "TPU" in strat_name and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord = None
    if "ParameterServer" in strat_name:
      coord = coordinator_lib.ClusterCoordinator(strat)
    shape = [3, 4]
    dtype = dtypes.int32
    with strat.scope():
      gen = rng.Generator.from_seed(1234)
      @def_function.function
      def f(gen):  # the main focus
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      def g():
        return f(gen)
      replica_fn = def_function.function(g) if jit_replica_fn else g
      for _ in range(2):
        results = run_on_strategy(replica_fn, strat, coord)
        values = strat.experimental_local_results(results)
        n = get_num_local_replicas(strat, values)
        self.assertAllEqual(n, len(values))
        self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat1=strategy_combinations.all_strategies,
          strat2=strategy_combinations.all_strategies,
          jit_replica_fn=[False, True],
          mode=["eager"]) +
      combinations.combine(
          strat1=strategy_combinations.multiworker_strategies + ps_strategies,
          strat2=[None],
          jit_replica_fn=[False, True],
          mode=["eager"]))
  def testDistStratRestore(self, strat1, strat2, jit_replica_fn):
    """Tests checkpointing and restoring (to possibly different #replicas)."""
    if strat2 is None:
      strat2 = strat1
    strat1_name = type(strat1).__name__
    strat2_name = type(strat2).__name__
    if "Default" in strat1_name or "Default" in strat2_name:
      self.skipTest(
          "We don't guarantee consistency between strategy and no-strategy.")
    if ("TPU" in strat1_name or "TPU" in strat2_name) and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord1 = None
    if "ParameterServer" in strat1_name:
      coord1 = coordinator_lib.ClusterCoordinator(strat1)
    coord2 = None
    if "ParameterServer" in strat2_name:
      coord2 = coordinator_lib.ClusterCoordinator(strat2)
    fname = os.path.join(self.get_temp_dir(), "checkpoint")
    def uniform(strat, coord, g):
      def f():
        return g.uniform_full_int([3], dtype=dtypes.int32)
      replica_fn = def_function.function(f) if jit_replica_fn else f
      result = run_on_strategy(replica_fn, strat, coord)
      return strat.experimental_local_results(result)
    with strat1.scope():
      g1 = rng.Generator.from_seed(1)
    with strat2.scope():
      g2 = rng.Generator.from_seed(10)
    cp1 = tracking_util.Checkpoint(g=g1)
    cp2 = tracking_util.Checkpoint(g=g2)
    def write_restore_compare():
      cp1.write(fname)
      r1 = uniform(strat1, coord1, g1)
      cp2.restore(fname)
      r2 = uniform(strat2, coord2, g2)
      # Tests that overlapping replicas are properly restored.
      n1 = get_num_local_replicas(strat1)
      n2 = get_num_local_replicas(strat2)
      n = min(n1, n2)
      self.assertAllEqual(r1[:n], r2[:n])
    # Run multiple times so that cp1.write is called in various RNG states
    for _ in range(2):
      write_restore_compare()

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          is_save_in_scope=[True, False]))
  def testSavedModel(self, strat, is_save_in_scope):

    class CustomModule(module.Module):

      def __init__(self):
        super(CustomModule, self).__init__()
        self.g = rng.Generator.from_seed(0)

      @def_function.function
      def __call__(self):
        return self.g.state

      @def_function.function
      def mutate(self):
        self.g.normal([])

    with strat.scope():
      m = CustomModule()
      m.mutate()
      state_before = m()
      path = os.path.join(self.get_temp_dir(), "saved_model")
    if is_save_in_scope:
      with strat.scope():
        save.save(m, path)
    else:
      save.save(m, path)
    with strat.scope():
      m.mutate()
      state_before_2 = m()

    imported = load.load(path)
    state_after = imported()
    self.assertAllEqual(state_before, state_after)
    imported.mutate()
    state_after_2 = imported()
    self.assertAllEqual(state_before_2, state_after_2)
Example #18
  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    cls.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        make_cluster(3, 2),
        variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
Example #19
def _create_ps_strategy(resolver):
  return parameter_server_strategy_v2.ParameterServerStrategyV2(
      resolver,
      variable_partitioner=sharded_variable.FixedShardsPartitioner(2))