Example #1
    def testDenseUpdate(self, strategy, tf_function, update_fn):
        if strategy_test_lib.is_tpu_strategy(strategy) and (not tf_function):
            self.skipTest('Skip TPUStrategy + eager combination.')
        with strategy.scope():
            distributed_variable1 = variables.Variable(5.0)

        def replica_fn():
            value = array_ops.constant(2.)
            python_literal = 1.
            replica_context = ds_context.get_replica_context()
            fn_sets = {
                'assign': lambda var, value: var.assign(value),
                'assign_add': lambda var, value: var.assign_add(value),
                'assign_sub': lambda var, value: var.assign_sub(value),
            }
            replica_context._update(distributed_variable1,
                                    fn_sets[update_fn],
                                    args=(value, ))
            replica_context._update(distributed_variable1,
                                    fn_sets[update_fn],
                                    args=(python_literal, ))

        if tf_function:
            replica_fn = def_function.function(replica_fn)
        strategy.run(replica_fn)

        expected_result = {'assign': 1., 'assign_add': 8., 'assign_sub': 2.}
        self.assertAllEqual(
            strategy.experimental_local_results(distributed_variable1),
            [expected_result[update_fn]] *
            _get_num_replicas_per_client(strategy))
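
The example above exercises the private `ReplicaContext._update` helper with each of `assign`, `assign_add` and `assign_sub`, passing both a tensor and a Python literal, then verifies every local copy of the mirrored variable. `_update` has no direct public counterpart, but a plain assignment on a mirrored variable that carries an aggregation mode behaves similarly; the sketch below uses only the public `tf.distribute` API and assumes two GPUs are available (the device list and values are illustrative, not part of the original test).

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
    # An aggregation mode is required to assign a mirrored variable
    # from replica context.
    v = tf.Variable(5.0, aggregation=tf.VariableAggregation.MEAN)

@tf.function
def replica_fn():
    v.assign_sub(2.0)  # per-replica deltas are aggregated, then applied everywhere

strategy.run(replica_fn)
print(strategy.experimental_local_results(v))  # expected: (3.0, 3.0)
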
Example #2
    def testSparse(self, strategy, tf_function):
        if tf_function is combinations.no_tf_function:
            self.skipTest('Skip IndexedSlices + eager combination.')

        @tf_function
        def fn():
            def replica_fn():
                value = indexed_slices.IndexedSlices(
                    values=array_ops.identity([[1.0]]),
                    indices=array_ops.identity([0]),
                    dense_shape=array_ops.identity([5, 1]))
                rep_ctx = ds_context.get_replica_context()
                reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
                return reduced

            return strategy.experimental_local_results(
                strategy.run(replica_fn))

        got = fn()[0]

        if not strategy_test_lib.is_tpu_strategy(strategy):
            self.assertIsInstance(got, indexed_slices.IndexedSlices)
        expect = indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        self.assertAllEqual(ops.convert_to_tensor(got),
                            ops.convert_to_tensor(expect))
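
This example all-reduces a `tf.IndexedSlices` value with `ReduceOp.MEAN` inside a `tf.function` (the eager combination is skipped). Every replica contributes identical slices, so the mean equals the per-replica value; the `IndexedSlices` type assertion is only made on non-TPU strategies, and the final comparison converts both sides to dense tensors.
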
Example #3
    def testReplicaContextEager(self, distribution, use_function):
        if not use_function and strategy_test_lib.is_tpu_strategy(
                distribution):
            self.skipTest("TPUStrategy doesn't support pure eager execution.")
        if isinstance(
                distribution,
                collective_all_reduce_strategy.CollectiveAllReduceStrategy):
            self.skipTest(
                "b/160194267: Cannot do variable.assign([0.5]) in replica "
                "context with MultiWorkerMirroredStrategy.")
        with distribution.scope():
            w = variables.Variable(
                [1.0],
                name="w",
                aggregation=variables.VariableAggregation.MEAN)
            ema = moving_averages.ExponentialMovingAverage(0.8)

            def fn():
                def _ema_replica_fn_eager():
                    ema.apply([w])
                    w.assign_sub([0.5])
                    ema.apply([w])
                    return ema.average(w)

                return distribution.run(_ema_replica_fn_eager)

            if use_function:
                fn = def_function.function(fn)
            ema_w = fn()
        self.assertAllClose(
            self.evaluate(distribution.experimental_local_results(ema_w))[0],
            [0.89999998])
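
This example applies `ExponentialMovingAverage(0.8)` inside `distribution.run`, optionally wrapped in `def_function.function` (TPU only supports the wrapped path, and MultiWorkerMirroredStrategy is skipped because of b/160194267). The expected value follows from the moving-average update `shadow -= (1 - decay) * (shadow - value)`: the shadow starts at 1.0, the first apply leaves it unchanged, and after `w` drops to 0.5 the second apply gives 1.0 - 0.2 * (1.0 - 0.5) = 0.9, matching the asserted 0.89999998.
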
Example #4
    def testSyncOnReadVariableInput(self, strategy, tf_function):
        if (not strategy_test_lib.is_mirrored_strategy(strategy)
                and not strategy_test_lib.is_multi_worker_mirrored_strategy(
                    strategy)
                and not strategy_test_lib.is_tpu_strategy(strategy)):
            self.skipTest('Skip strategies not using SyncOnReadVariables.')
        if (strategy_test_lib.is_tpu_strategy(strategy)
                and tf_function is combinations.no_tf_function):
            self.skipTest('Skip TPUStrategy + eager combination.')
        if (strategy_test_lib.is_multi_worker_mirrored_strategy(strategy)
                and tf_function is combinations.tf_function):
            self.skipTest(
                'Skip MWMS + graph combination until b/228512201 is fixed.')

        with strategy.scope():
            var = variables.Variable(
                0.0,
                synchronization=variables.VariableSynchronization.ON_READ,
                aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

        @tf_function
        def replica_fn():
            replica_context = ds_context.get_replica_context()
            replica_id = replica_context.replica_id_in_sync_group
            var.assign(math_ops.cast(replica_id, dtype=float) * 3.0)

            return replica_context.all_reduce(reduce_util.ReduceOp.SUM, var)

        if strategy_test_lib.is_multi_worker_mirrored_strategy(strategy):
            client_local_replica_num = strategy.extended._num_devices_per_worker
        else:
            client_local_replica_num = strategy.num_replicas_in_sync

        workers_num = strategy.num_replicas_in_sync
        expected_sum = sum(range(workers_num)) * 3.0

        # Expand the values on each replica if multiple devices are used;
        # otherwise simply read the value of the Tensor.
        result = strategy.run(replica_fn)
        if hasattr(result, 'values'):
            result = result.values
        result = nest.flatten(result)

        # Iterate through all replicas and verify the reduce sum result.
        for i in range(client_local_replica_num):
            self.assertEqual(result[i].numpy(), expected_sum)
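
The test relies on a sync-on-read variable: each replica assigns `replica_id * 3.0` locally, and the all-reduce sums those per-replica values, so every local result equals `sum(range(num_replicas_in_sync)) * 3.0`. A rough public-API sketch of the same pattern, assuming two GPUs are available (device list illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
    var = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

@tf.function
def replica_fn():
    ctx = tf.distribute.get_replica_context()
    replica_id = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
    var.assign(replica_id * 3.0)  # sync-on-read: assigned independently per replica
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, var)

print(strategy.experimental_local_results(strategy.run(replica_fn)))
# With two replicas, each local result is 0.0 + 3.0 = 3.0.
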
Example #5
    def testDense(self, strategy, tf_function):
        if (strategy_test_lib.is_tpu_strategy(strategy)
                and tf_function is combinations.no_tf_function):
            self.skipTest('Skip TPUStrategy + eager combination.')

        @tf_function
        def fn():
            def replica_fn():
                value = array_ops.identity(1.0)
                rep_ctx = ds_context.get_replica_context()
                reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
                return reduced

            return strategy.experimental_local_results(
                strategy.run(replica_fn))

        got = fn()[0]
        self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)
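
This is the simplest reduction: each replica contributes the scalar 1.0, so `all_reduce(SUM, ...)` returns `num_replicas_in_sync`. A minimal public-API sketch of the same pattern, assuming two GPUs are available (device list illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

@tf.function
def replica_fn():
    ctx = tf.distribute.get_replica_context()
    # Every replica contributes 1.0; SUM therefore yields the replica count.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(1.0))

print(strategy.experimental_local_results(strategy.run(replica_fn)))
# Expected with two replicas: (2.0, 2.0)
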
Example #6
def _make_mirrored(distribution=None):
  v = []
  if distribution:
    devices = distribution.extended.worker_devices
  else:
    devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(
          variable_scope.get_variable(
              name=n, initializer=init, use_resource=True))

  if distribution is not None and strategy_test_lib.is_tpu_strategy(
      distribution):
    var_cls = tpu_values.TPUMirroredVariable
  else:
    var_cls = values_lib.MirroredVariable
  mirrored = var_cls(distribution, v, variable_scope.VariableAggregation.SUM)
  return mirrored
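
`_make_mirrored` is a helper: it creates one component variable per device (the strategy's worker devices if a distribution is given, otherwise a fixed GPU/CPU pair) and wraps them in a `MirroredVariable`, or a `TPUMirroredVariable` when the distribution is a TPU strategy, using SUM aggregation.
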
Example #7
    def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
        with ops.Graph().as_default(), distribution.scope():
            iterator = distribution.make_input_fn_iterator(
                lambda _: dataset_fn())
            if strategy_test_lib.is_tpu_strategy(distribution):

                def step_fn(ctx, inputs):
                    value, update = distribution.extended.call_for_each_replica(
                        metric_fn, args=(inputs, ))
                    ctx.set_non_tensor_output(name="value", output=value)
                    return distribution.group(update)

                ctx = distribution.extended.experimental_run_steps_on_iterator(
                    step_fn,
                    iterator,
                    iterations=distribution.extended.steps_per_run)
                update = ctx.run_op
                value = ctx.non_tensor_outputs["value"]
                # In each run we execute multiple steps, and each step consumes
                # as many batches as there are replicas.
                batches_per_update = (distribution.num_replicas_in_sync *
                                      distribution.extended.steps_per_run)
            else:
                value, update = distribution.extended.call_for_each_replica(
                    metric_fn, args=(iterator.get_next(), ))
                update = distribution.group(update)
                # TODO(josh11b): Once we switch to using a global batch size for input,
                # replace "distribution.num_replicas_in_sync" with "1".
                batches_per_update = distribution.num_replicas_in_sync

            self.evaluate(iterator.initializer)
            self.evaluate(variables.local_variables_initializer())

            batches_consumed = 0
            for i in range(4):
                self.evaluate(update)
                batches_consumed += batches_per_update
                self.assertAllClose(expected_fn(batches_consumed),
                                    self.evaluate(value),
                                    0.001,
                                    msg="After update #" + str(i + 1))
                if batches_consumed >= 4:  # Consume 4 input batches in total.
                    break
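
`_test_metric` drives a metric update in graph mode. On TPU it uses `experimental_run_steps_on_iterator`, so each evaluation of the update op consumes `num_replicas_in_sync * steps_per_run` batches; on other strategies a single `call_for_each_replica` step consumes `num_replicas_in_sync` batches. The loop keeps evaluating the update until four input batches have been consumed, comparing the metric value against `expected_fn` after each update.
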
Example #8
    def testSparseTuple(self, strategy, tf_function):
        if tf_function is combinations.no_tf_function:
            self.skipTest('Skip IndexedSlices + eager combination.')

        @tf_function
        def fn():
            def replica_fn():
                value1 = indexed_slices.IndexedSlices(
                    values=array_ops.identity([[1.0]]),
                    indices=array_ops.identity([0]),
                    dense_shape=array_ops.identity([5, 1]))
                value2 = indexed_slices.IndexedSlices(
                    values=array_ops.identity([[2.0]]),
                    indices=array_ops.identity([0]),
                    dense_shape=array_ops.identity([5, 1]))
                rep_ctx = ds_context.get_replica_context()
                reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                             [value1, value2])
                return reduced

            return strategy.experimental_local_results(
                strategy.run(replica_fn))

        got = fn()[0]

        if not strategy_test_lib.is_tpu_strategy(strategy):
            for g in got:
                self.assertIsInstance(g, indexed_slices.IndexedSlices)
        expect = [
            indexed_slices.IndexedSlices(
                values=array_ops.identity(
                    [[1.0 * strategy.num_replicas_in_sync]]),
                indices=array_ops.identity([0]),
                dense_shape=array_ops.identity([5, 1])),
            indexed_slices.IndexedSlices(
                values=array_ops.identity(
                    [[2.0 * strategy.num_replicas_in_sync]]),
                indices=array_ops.identity([0]),
                dense_shape=array_ops.identity([5, 1]))
        ]
        self.assertAllEqual(nest.map_structure(ops.convert_to_tensor, got),
                            nest.map_structure(ops.convert_to_tensor, expect))
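
This variant reduces a list of two `IndexedSlices` in a single `all_reduce(SUM, ...)` call, checking that the structure is preserved and that each component's values are scaled by `num_replicas_in_sync`; as in Example #2, the `IndexedSlices` type check is skipped on TPU and the final comparison converts everything to dense tensors.
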
Example #9
    def testClusterResolverProperty(self, strategy):
        # CollectiveAllReduceStrategy and TPUStrategy must have a cluster
        # resolver; other strategies return `None`.
        resolver = strategy.cluster_resolver
        if (not isinstance(strategy, CollectiveAllReduceStrategy)
                and not strategy_test_lib.is_tpu_strategy(strategy)):
            self.assertIsNone(resolver)
            return

        with strategy.scope():
            self.assertIs(strategy.cluster_resolver, resolver)

        self.assertTrue(hasattr(resolver, 'cluster_spec'))
        self.assertTrue(hasattr(resolver, 'master'))
        self.assertTrue(hasattr(resolver, 'num_accelerators'))
        self.assertTrue(hasattr(resolver, 'task_id'))
        self.assertTrue(hasattr(resolver, 'task_type'))
        if isinstance(strategy, CollectiveAllReduceStrategy):
            self.assertEqual(resolver.task_id, 0)
            self.assertAllInSet(resolver.task_type, ['chief', 'worker'])
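
Only `CollectiveAllReduceStrategy` and TPU strategies are expected to expose a cluster resolver with `cluster_spec`, `master`, `num_accelerators`, `task_id` and `task_type`; strategies that are neither return `None` here. A quick illustration with the public API, assuming a single-machine setup:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
print(strategy.cluster_resolver)  # None: MirroredStrategy has no cluster behind it
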