def testOne(self, strategy):
  """Verify a replicated scalar ones() gathers to one 1.0 per replica."""

  @def_function.function
  def ones_fn():
    return array_ops.ones((), dtypes.float32)

  per_replica = strategy.run(ones_fn)
  gathered = test_util.gather(strategy, per_replica)
  self.assertAllEqual(
      self.evaluate(gathered), [1.] * strategy.num_replicas_in_sync)
def testMakeDistributedValueExtractFromArray(self, distribution):
  """Each replica extracts its own element from a sequence of values."""
  if not tf2.enabled():
    self.skipTest("Only V2 is supported.")
  replica_inputs = range(distribution.num_replicas_in_sync)

  def value_fn(ctx):
    # Replica i receives replica_inputs[i].
    return replica_inputs[ctx.replica_id_in_sync_group]

  dist_values = distribution.experimental_distribute_values_from_function(
      value_fn)
  gathered = ds_test_util.gather(distribution, dist_values)
  self.assertAllEqual(gathered, range(distribution.num_replicas_in_sync))
def test_replica_id_in_sync_group(self, strategy):
  """Checks global replica ids are 0..N-1 and local ids repeat per worker."""

  def replica_fn():
    ctx = distribution_strategy_context.get_replica_context()
    # NOTE(review): `_replica_id` is private; this test treats it as the
    # per-worker (local) replica id -- confirm against the strategy impl.
    return ctx.replica_id_in_sync_group, ctx._replica_id

  global_ids, local_ids = test_util.gather(strategy, strategy.run(replica_fn))
  self.assertAllEqual(
      list(range(strategy.extended._num_replicas_in_sync)),
      global_ids.numpy())
  # Local ids run 0..len(worker_devices)-1 on each of the _num_workers workers.
  self.assertAllEqual(
      list(range(len(strategy.extended.worker_devices))) *
      strategy.extended._num_workers,
      local_ids.numpy())
def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
  """A constant numpy array is mirrored unchanged onto every replica."""
  if not tf2.enabled():
    self.skipTest("Only V2 is supported.")
  array_value = np.array([1., 2., 3.])

  def value_fn(ctx):
    del ctx  # Same constant on every replica.
    return array_value

  dist_values = distribution.experimental_distribute_values_from_function(
      value_fn)
  gathered = ds_test_util.gather(distribution, dist_values)
  self.assertAllEqual(
      gathered.numpy(),
      [[1., 2., 3.]] * distribution.num_replicas_in_sync)
def testMakeDistributedValueFromTensor(self, distribution):
  """A constant tensor is mirrored unchanged onto every replica."""
  if not tf2.enabled():
    self.skipTest("Only V2 is supported.")
  single_value = constant_op.constant(1)

  def value_fn(ctx):
    del ctx  # Same tensor on every replica.
    return single_value

  distributed_values = (
      distribution.experimental_distribute_values_from_function(value_fn))
  self.assertAllEqual(
      ds_test_util.gather(distribution, distributed_values),
      # Fix: the original wrote `shape=(n)`, which is just the int `n`, not
      # a tuple. Spell the intended 1-D shape explicitly as `(n,)` instead
      # of relying on TensorShape's handling of a bare integer.
      constant_op.constant(1., shape=(distribution.num_replicas_in_sync,)))
def optimize():
  """Applies one per-replica gradient step to `v`; returns gathered values.

  Closes over `distribution`, `optimizer`, `v`, and
  `experimental_aggregate_gradients` from the enclosing scope.
  """
  # Per-replica gradients: replica i receives row i of this matrix.
  grad_rows = ops.convert_to_tensor([[1., 1.], [2., 2.]])
  per_replica_grads = (
      distribution.experimental_distribute_values_from_function(
          lambda ctx: grad_rows[ctx.replica_id_in_sync_group]))

  def step_fn(grads):
    optimizer.apply_gradients(
        [(grads, v)],
        experimental_aggregate_gradients=experimental_aggregate_gradients)
    return v.read_value()

  return test_util.gather(
      distribution, distribution.run(step_fn, args=(per_replica_grads,)))
def run():
  """Squares each replica's distinct input value and gathers the results.

  Closes over `distribution` from the enclosing scope.
  """
  replica_inputs = range(distribution.num_replicas_in_sync)

  def value_fn(ctx):
    # Replica i receives replica_inputs[i].
    return replica_inputs[ctx.replica_id_in_sync_group]

  distributed_values = (
      distribution.experimental_distribute_values_from_function(value_fn))

  def computation(x):
    return math_ops.square(x)

  return ds_test_util.gather(
      distribution,
      distribution.run(computation, args=(distributed_values,)))
def testMakeDistributedValueTupleConstant(self, distribution):
  """A constant tuple is mirrored elementwise onto every replica."""
  if not tf2.enabled():
    self.skipTest("Only V2 is supported.")
  tuple_value = (1., 2., 3.)

  def value_fn(ctx):
    del ctx  # Same tuple on every replica.
    return tuple_value

  gathered = ds_test_util.gather(
      distribution,
      distribution.experimental_distribute_values_from_function(value_fn))
  # Expected output for 2 replicas:
  # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
  expected = tuple(
      [v] * distribution.num_replicas_in_sync for v in tuple_value)
  self.assertAllEqual(gathered, expected)
def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
  """Each replica builds a tuple of values scaled by its own replica id."""
  if not tf2.enabled():
    self.skipTest("Only V2 is supported.")
  tuple_value = (1., 2., 3.)

  def value_fn(ctx):
    replica_id = ctx.replica_id_in_sync_group
    return tuple(val * replica_id for val in tuple_value)

  gathered = ds_test_util.gather(
      distribution,
      distribution.experimental_distribute_values_from_function(value_fn))
  # Expected output for 2 replicas:
  # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
  expected = tuple(
      [v * i for i in range(distribution.num_replicas_in_sync)]
      for v in tuple_value)
  self.assertAllEqual(gathered, expected)
def testStrategyRun(self, strategy, enable_packed_handle, tf_function):
  """Per-replica assign/read round-trips through strategy.run."""
  if (test_util.is_tpu_strategy(strategy) and
      tf_function is combinations.no_tf_function):
    self.skipTest("tpu doesn't support eager")
  v = self.create_variable(strategy, 0., enable_packed_handle)

  @tf_function
  def update(per_replica):
    v.assign(per_replica)

  @tf_function
  def read():
    return v.read_value()

  per_replica_input = test_util.create_per_replica(strategy, [1., 2.])
  strategy.run(update, args=(per_replica_input,))
  self.assertReplica(v, [1., 2.])
  self.assertAllEqual(
      test_util.gather(strategy, strategy.run(read)), [1., 2.])
def testNest(self, strategy):
  """Nested structures returned from strategy.run gather leaf-by-leaf."""

  @def_function.function
  def nested_fn():
    return {
        'foo': array_ops.ones((), dtypes.float32),
        'bar': [
            array_ops.zeros((), dtypes.float32),
            array_ops.ones((), dtypes.float32),
        ],
    }

  results = test_util.gather(strategy, strategy.run(nested_fn))
  num_replicas = strategy.num_replicas_in_sync
  self.assertAllEqual(self.evaluate(results['foo']), [1.] * num_replicas)
  self.assertAllEqual(self.evaluate(results['bar'][0]), [0.] * num_replicas)
  self.assertAllEqual(self.evaluate(results['bar'][1]), [1.] * num_replicas)