def testOne(self, strategy):
        """Each replica returns scalar 1.0; the gathered result has one per replica."""

        @def_function.function
        def ones_scalar():
            return array_ops.ones((), dtypes.float32)

        per_replica = strategy.run(ones_scalar)
        gathered = test_util.gather(strategy, per_replica)
        expected = [1.] * strategy.num_replicas_in_sync
        self.assertAllEqual(self.evaluate(gathered), expected)
Example #2
0
 def testMakeDistributedValueExtractFromArray(self, distribution):
   """Replica i contributes element i of range(num_replicas); gather restores it."""
   if not tf2.enabled():
     self.skipTest("Only V2 is supported.")
   per_replica_source = range(distribution.num_replicas_in_sync)

   def pick_for_replica(ctx):
     # Index the shared range by this replica's id within the sync group.
     return per_replica_source[ctx.replica_id_in_sync_group]

   gathered = ds_test_util.gather(
       distribution,
       distribution.experimental_distribute_values_from_function(
           pick_for_replica))
   self.assertAllEqual(gathered, range(distribution.num_replicas_in_sync))
Example #3
0
    def test_replica_id_in_sync_group(self, strategy):
        """Checks the replica ids reported by each replica's context.

        Each replica returns ``(replica_id_in_sync_group, _replica_id)``. After
        gathering, the first component must equal ``range(num_replicas_in_sync)``.
        The second must equal ``range(len(extended.worker_devices))`` repeated
        ``extended._num_workers`` times.
        """
        def replica_fn():
            replica_ctx = distribution_strategy_context.get_replica_context()
            # `_replica_id` is private; it is read deliberately to check the
            # per-worker id alongside the global sync-group id.
            return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id

        results = test_util.gather(strategy, strategy.run(replica_fn))
        # Use the public replica count and `self.evaluate` (as the other tests
        # do) instead of the private attribute and eager-only `.numpy()`.
        self.assertAllEqual(
            list(range(strategy.num_replicas_in_sync)),
            self.evaluate(results[0]))
        self.assertAllEqual(
            list(range(len(strategy.extended.worker_devices))) *
            strategy.extended._num_workers, self.evaluate(results[1]))
Example #4
0
  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
    """All replicas receive the same NumPy array; gather stacks one copy each."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    array_value = np.array([1., 2., 3.])

    def value_fn(ctx):
      del ctx  # Every replica gets the identical array.
      return array_value

    per_replica = (
        distribution.experimental_distribute_values_from_function(value_fn))
    gathered = ds_test_util.gather(distribution, per_replica)
    expected = [[1., 2., 3.]] * distribution.num_replicas_in_sync
    self.assertAllEqual(gathered.numpy(), expected)
Example #5
0
  def testMakeDistributedValueFromTensor(self, distribution):
    """All replicas receive the same scalar tensor; gather yields a 1-D vector.

    The gathered result is compared against a tensor of shape
    ``(num_replicas_in_sync,)`` filled with 1.
    """
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    single_value = constant_op.constant(1)
    def value_fn(ctx):
      del ctx
      return single_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    # `(n)` is just the int `n`; write the shape as a real 1-tuple `(n,)`.
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values),
        constant_op.constant(1., shape=(distribution.num_replicas_in_sync,)))
Example #6
0
        def optimize():
            """Applies one gradient step to `v` per replica; returns gathered values.

            Replica ``i`` receives row ``i`` of a fixed 2x2 gradient tensor.
            `distribution`, `optimizer`, `v` and
            `experimental_aggregate_gradients` come from the enclosing scope.
            """
            all_grads = ops.convert_to_tensor([[1., 1.], [2., 2.]])
            # Keep the source tensor under its own name. The lambda then never
            # closes over the name that is rebound to the distributed value
            # (avoids a late-binding closure bug).
            grads = distribution.experimental_distribute_values_from_function(
                lambda ctx: all_grads[ctx.replica_id_in_sync_group])

            def step_fn(grads):
                optimizer.apply_gradients(
                    [(grads, v)],
                    experimental_aggregate_gradients=(
                        experimental_aggregate_gradients))
                return v.read_value()

            return test_util.gather(distribution,
                                    distribution.run(step_fn, args=(grads, )))
Example #7
0
    def run():
      """Squares each replica's input and gathers the per-replica results."""
      replica_inputs = range(distribution.num_replicas_in_sync)

      def value_fn(ctx):
        # Replica i is fed element i of `replica_inputs`.
        return replica_inputs[ctx.replica_id_in_sync_group]

      per_replica_inputs = (
          distribution.experimental_distribute_values_from_function(value_fn))

      def computation(x):
        return math_ops.square(x)

      return ds_test_util.gather(
          distribution,
          distribution.run(computation, args=(per_replica_inputs,)))
Example #8
0
  def testMakeDistributedValueTupleConstant(self, distribution):
    """All replicas receive the same tuple; gather collects each element."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      del ctx
      return tuple_value
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    expected = tuple([v] * distribution.num_replicas_in_sync
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)
Example #9
0
  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
    """Replica i gets a tuple with each element scaled by i; gather regroups it."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      replica_id = ctx.replica_id_in_sync_group
      return tuple(val * replica_id for val in tuple_value)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
    expected = tuple([v * i for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)
Example #10
0
    def testStrategyRun(self, strategy, enable_packed_handle, tf_function):
        """Assigns per-replica values to `v` via `strategy.run` and reads them back."""
        is_tpu = test_util.is_tpu_strategy(strategy)
        if is_tpu and tf_function is combinations.no_tf_function:
            self.skipTest("tpu doesn't support eager")
        v = self.create_variable(strategy, 0., enable_packed_handle)

        @tf_function
        def assign_value(per_replica):
            v.assign(per_replica)

        @tf_function
        def read_value():
            return v.read_value()

        new_values = test_util.create_per_replica(strategy, [1., 2.])
        strategy.run(assign_value, args=(new_values, ))
        self.assertReplica(v, [1., 2.])
        gathered = test_util.gather(strategy, strategy.run(read_value))
        self.assertAllEqual(gathered, [1., 2.])
    def testNest(self, strategy):
        """Gathering a nested dict/list per-replica result preserves its structure."""
        @def_function.function
        def nested_fn():
            return {
                'foo': array_ops.ones((), dtypes.float32),
                'bar': [
                    array_ops.zeros((), dtypes.float32),
                    array_ops.ones((), dtypes.float32),
                ],
            }

        results = test_util.gather(strategy, strategy.run(nested_fn))
        # Each leaf gathers to one scalar per replica with a known fill value.
        checks = (
            (results['foo'], 1.),
            (results['bar'][0], 0.),
            (results['bar'][1], 1.),
        )
        for gathered, fill in checks:
            self.assertAllEqual(self.evaluate(gathered),
                                [fill] * strategy.num_replicas_in_sync)