def worker_fn(input_tensor):
  # `strategy` and `v` are assumed to be defined by the enclosing test: a
  # distribution strategy and a variable created under its scope.

  def replica_fn(input_tensor):
    return input_tensor + v

  run_result = strategy.run(replica_fn, args=(input_tensor,))
  check_ops.assert_equal_v2(run_result, 4)
  return run_result
      def worker_fn(input_tensor):

        def replica_fn(input_tensor):
          # Within `replica_fn`, execution has to be in a replica context.
          self.assertFalse(
              distribution_strategy_context.in_cross_replica_context())
          return input_tensor + v, input_tensor - v

        run_result = self.strategy.run(replica_fn, args=(input_tensor,))
        reduced_result = self.strategy.reduce('SUM', run_result, axis=None)
        check_ops.assert_equal_v2(reduced_result, expected_result)
        return reduced_result
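
For context, a minimal, self-contained sketch of how fragments like the two `worker_fn`s above are driven. This is an illustration, not the original test harness; it assumes a `MirroredStrategy` and a variable `v` created under its scope:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.0)

@tf.function
def worker_fn(input_tensor):

  def replica_fn(t):
    # Runs once per replica; `v` resolves to the local replica's component.
    return t + v

  per_replica = strategy.run(replica_fn, args=(input_tensor,))
  # Combine the per-replica results into a single tensor.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

print(worker_fn(tf.constant(3.0)))  # 4.0 with a single replica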
    def testSaveNonDistributed(self, distribution, synchronization,
                               aggregation):
        # This test verifies that a DistributedVariable behaves like the primary
        # variable when saving a non-distributed version of the model (the
        # default). The test asserts that the function traced under SaveContext
        # has no device annotations and only references the primary component of
        # the variable. Note: avoid capturing other eager tensors in this test to
        # keep the assertion simple. A standalone sketch of the default saving
        # behavior follows this test.

        if isinstance(
                distribution.extended,
                parameter_server_strategy.ParameterServerStrategyExtended):
            self.skipTest("b/148689177: AggregatingVariable doesn't "
                          "conform to Variable interface well")

        # tf.function requires return values to be Tensors, which is not always
        # the case for properties and methods of Variable, so we simply discard
        # the return values.
        def _discard_return(f):
            f()
            return

        def _test(f, v):
            # This verifies that the function under SaveContext:
            #   - contains no device annotations.
            #   - only references the primary component of the variable.
            g = def_function.function(lambda: _discard_return(f))
            options = save_options.SaveOptions(
                experimental_variable_policy=save_options.VariablePolicy.NONE)
            with save_context.save_context(options):
                # The graph should contain no device.
                graph = g.get_concrete_function().graph
            for op in graph.get_operations():
                self.assertEqual(op.device, "", msg=str(op))
            # The function should only capture the primary variable. Note that it
            # may have no captures at all, e.g. for v.aggregation.
            captures = list(graph.captures)
            self.assertLessEqual(len(captures), 1)
            if graph.captures:
                self.assertIs(captures[0][0], v._primary.handle)

        def _assert(cond):
            return control_flow_ops.Assert(cond, [cond])

        with distribution.scope():
            # We use three variables for convenience reasons. They have no special
            # meaning.
            # - v is used whenever possible.
            # - w is used for scatter and gather, which require the variable to be
            # non-scalar.
            # - y is used when the dtype needs to be integer. Note that aggregation
            # cannot be MEAN for integers.
            v = variables_lib.Variable(0.,
                                       synchronization=synchronization,
                                       aggregation=aggregation,
                                       trainable=True)
            w = variables_lib.Variable([0., 0., 0.],
                                       synchronization=synchronization,
                                       aggregation=aggregation,
                                       trainable=True)
            if aggregation != variables_lib.VariableAggregation.MEAN:
                y = variables_lib.Variable(0,
                                           synchronization=synchronization,
                                           aggregation=aggregation)

        # pylint: disable=g-long-lambda

        # tf.Variable properties.
        _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
        _test(lambda: self.assertIs(v.constraint, None), v)
        # TODO(crccw): should we raise an error instead?
        _test(lambda: self.assertEqual(v.device, v._primary.device), v)
        _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
        if not context.executing_eagerly():
            _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
        if not context.executing_eagerly():
            _test(lambda: _assert(v.initial_value == 0), v)
        _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
        _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
        if not context.executing_eagerly():
            _test(lambda: self.assertIs(v.op, v._primary.op), v)
        _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())),
              v)
        _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
        _test(lambda: self.assertEqual(v.trainable, True), v)

        # tf.Variable methods.
        _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
        _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
        _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
        # TODO(b/148689177): Implement batch_scatter_update.
        # count_up_to() is skipped since it's deprecated.
        # eval() is skipped since it shouldn't be called in a tf.function.
        # experimental_ref() is skipped since it's deprecated.
        # from_proto() is skipped since it shouldn't be called in a tf.function.
        # TODO(b/148689177): Implement gather_nd.
        _test(
            lambda: check_ops.assert_equal_v2(v.get_shape(),
                                              tensor_shape.TensorShape(())), v)
        # initialized_value() is skipped since it shouldn't be called in a tf.function.
        # load() is skipped since it shouldn't be called in a tf.function.
        _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
        # ref() is skipped since it shouldn't be called in a tf.function.
        _test(
            lambda: check_ops.assert_equal_v2(
                w.scatter_add(
                    _make_index_slices(values=[1., 2.], indices=[0, 2])),
                [1., 0., 2.]), w)
        _test(
            lambda: check_ops.assert_equal_v2(
                w.scatter_div(
                    _make_index_slices(values=[4., 2.], indices=[0, 2])),
                [0.25, 0., 1.]), w)
        _test(
            lambda: check_ops.assert_equal_v2(
                w.scatter_max(
                    _make_index_slices(values=[1., 0.5], indices=[1, 2])),
                [0.25, 1., 1.]), w)
        _test(
            lambda: check_ops.assert_equal_v2(
                w.scatter_min(
                    _make_index_slices(values=[1., 0.5], indices=[0, 1])),
                [0.25, 0.5, 1.]), w)
        _test(
            lambda: check_ops.assert_equal_v2(
                w.scatter_mul(
                    _make_index_slices(values=[2., 0.5], indices=[0, 1])),
                [0.5, 0.25, 1.]), w)
        # TODO(b/148689177): Implement scatter_nd_*
        _test(
            lambda: check_ops.assert_equal_v2(
                w.scatter_sub(
                    _make_index_slices(values=[2., 0.5], indices=[0, 1])),
                [-1.5, -0.25, 1.]), w)
        _test(
            lambda: check_ops.assert_equal_v2(
                w.scatter_update(
                    _make_index_slices(values=[2., 0.5], indices=[0, 1])),
                [2., 0.5, 1.]), w)
        # set_shape() is skipped since ResourceVariable doesn't implement it.
        # to_proto() is skipped since it shouldn't be called in a tf.function.
        _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

        # DistributedVariable should be treated as ResourceVariable, so it needs to
        # conform to ResourceVariable interface as well.
        _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

        # Convert to tensor.
        _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.),
              v)

        # Control dependency.
        def _with_control_dep():
            with ops.control_dependencies([v.assign(1.)]):
                return array_ops.identity(1)

        _test(_with_control_dep, v)

        # Operator overloads.
        _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
        _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
        _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
        _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
        _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
        _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
        _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
        _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
        _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
        _test(
            lambda: check_ops.assert_equal_v2(
                math_ops.cast(v / 2., dtypes.float32), 3.5), v)
        _test(
            lambda: check_ops.assert_equal_v2(
                math_ops.cast(14. / v, dtypes.float32), 2.), v)
        _test(lambda: _assert(v < 12.), v)
        _test(lambda: _assert(v <= 12.), v)
        _test(lambda: _assert(not v > 12.), v)
        _test(lambda: _assert(not v >= 12.), v)
        _test(lambda: _assert(not 12. < v), v)
        _test(lambda: _assert(not 12. <= v), v)
        _test(lambda: _assert(12. > v), v)
        _test(lambda: _assert(12. >= v), v)
        _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
        _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
        _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

        # Operator overloads that only work for integers.
        if aggregation != variables_lib.VariableAggregation.MEAN:
            _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
            _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
            _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
            _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
            _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
            _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
            _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
            _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
            _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
            _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
            _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
            _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
            _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

        # Index.
        if isinstance(distribution.extended, tpu_strategy.TPUExtended):
            # TODO(b/161572567): slice assignment doesn't work for TPU.
            _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
        else:
            _test(
                lambda: check_ops.assert_equal_v2(w[0].assign(1.),
                                                  [1., 0.5, 1.]), w)
            _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)
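
A minimal, standalone sketch of the behavior guarded by the test above: variables created under a distribution strategy are exported as ordinary (non-distributed) variables when saving with the default SaveOptions, so the traced functions reference only the primary component. The module and export path below are illustrative, not part of the original test:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

class ToyModule(tf.Module):

  def __init__(self):
    with strategy.scope():
      self.v = tf.Variable(1.0)

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def __call__(self, x):
    return x + self.v

m = ToyModule()
# Default SaveOptions: the exported function captures only the primary
# component of the distributed variable, with no per-replica device
# annotations.
tf.saved_model.save(m, "/tmp/toy_non_distributed")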
def split(value: ragged_tensor.Ragged,
          num_or_size_splits,
          axis=0,
          num=None,
          name=None):
  """Splits a RaggedTensor `value` into a list of sub RaggedTensors.

  If `num_or_size_splits` is an `int`,  then it splits `value` along the
  dimension `axis` into `num_or_size_splits` smaller RaggedTensors. This
  requires that `value.shape[axis]` is divisible by `num_or_size_splits`.

  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
  `len(num_or_size_splits)` elements. The shape of the `i`-th element has the
  same size as the `value` except along dimension `axis` where the size is
  `num_or_size_splits[i]`.

  Splitting along a ragged dimension is not allowed.

  For example:

  >>> rt = tf.RaggedTensor.from_row_lengths(
  ...      np.arange(6 * 3).reshape(6, 3), row_lengths=[1, 2, 2, 1])
  >>> rt.shape
  TensorShape([4, None, 3])
  >>>
  >>> rt1, rt2 = tf.split(rt, 2)  # uniform splits
  >>> rt1.shape
  TensorShape([2, None, 3])
  >>> rt2.shape
  TensorShape([2, None, 3])
  >>>
  >>> rt3, rt4, rt5 = tf.split(rt, [1, 2, 1])  # ragged splits
  >>> rt3.shape
  TensorShape([1, None, 3])
  >>> rt4.shape
  TensorShape([2, None, 3])
  >>> rt5.shape
  TensorShape([1, None, 3])
  >>>
  >>> rt6, rt7 = tf.split(rt, [1, 2], axis=2)  # splits along axis 2
  >>> rt6.shape
  TensorShape([4, None, 1])
  >>> rt7.shape
  TensorShape([4, None, 2])

  Args:
    value: The `RaggedTensor` to split.
    num_or_size_splits: Either an `int` indicating the number of splits
      along `axis` or a 1-D integer `Tensor` or Python list containing the sizes
      of each output tensor along `axis`. If a Python int, then it must evenly
      divide `value.shape[axis]`; otherwise the sum of sizes along the split
      axis must match that of the `value`.
    axis: An `int` or scalar `int32` `Tensor`. The dimension along which
      to split. Must be in the range `[-rank(value), rank(value))`. Defaults to
      0.
    num: An `int` used to specify the number of outputs when
      `num_or_size_splits` is a 1-D list or `Tensor` and its length cannot be
      inferred statically, e.g., when `tf.TensorSpec(None)` is used in the
      `input_signature` argument of `tf.function` (optional).
    name: A name for the operation (optional).

  Returns:
    If `num_or_size_splits` is an `int`, returns a list of `num_or_size_splits`
    `RaggedTensor` objects; if `num_or_size_splits` is a 1-D `Tensor`, returns
    `num_or_size_splits.get_shape()[0]` `RaggedTensor` objects resulting from
    splitting `value`.

  Raises:
    ValueError: If the dimension `axis` of `value` is a ragged dimension.
    ValueError: If `num` is unspecified and cannot be inferred.
    ValueError: If `num` is specified but doesn't match the length of
      `num_or_size_splits`.
    ValueError: If `num_or_size_splits` is an `int` and less than 1.
    TypeError: If `num_or_size_splits` is not an `int` or 1-D
      list or 1-D `Tensor`.
    InvalidArgumentError: If the `axis` dimension of `value` cannot be evenly
      split by `num_or_size_splits`.
    InvalidArgumentError: If `num_or_size_splits` contains negative integers.
    InvalidArgumentError: If `num_or_size_splits`'s static shape is unknown and
      its dynamic shape is inconsistent with `num`.
    InvalidArgumentError: If `num_or_size_splits`'s static rank is unknown and
      `axis` is a negative integer.
  """
  with ops.name_scope(name, 'RaggedSplit'):
    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        value, name='value')
    if isinstance(num_or_size_splits, int) and num_or_size_splits == 1:
      return [value]

    # static assert
    check_ops.assert_integer_v2(
        num_or_size_splits,
        message=('`num_or_size_splits` must be an `int` or 1-D list or '
                 '`Tensor` of integers.'))
    value_shape = ragged_shape.RaggedShape.from_tensor(value)
    axis = array_ops.get_positive_axis(axis, value_shape.rank)
    try:
      dim_size = value_shape[axis]
    except ValueError:
      raise ValueError('Cannot split a ragged dimension. Got `value` with '
                       f'shape {value_shape} and `axis` {axis}.')
    if isinstance(num_or_size_splits, int):
      # Uniform split
      num_splits = num_or_size_splits
      if num_splits < 1:
        raise ValueError('`num_or_size_splits` must be >=1 if it is an `int`. '
                         f'Received {num_or_size_splits}.')
      split_length = math_ops.floordiv(dim_size, num_splits)
      split_lengths = array_ops.repeat(split_length, num_splits)
    else:
      # Ragged split
      num_splits = None
      split_lengths = ops.convert_to_tensor(num_or_size_splits)
      if split_lengths.shape.ndims is not None:
        if split_lengths.shape.ndims != 1:
          raise TypeError('`num_or_size_splits` must be an `int` or 1-D list '
                          f'or `Tensor`. Received {num_or_size_splits}.')
        num_splits = tensor_shape.dimension_value(split_lengths.shape[0])

      if num_splits is None:
        if num is None:
          raise ValueError('`num` must be specified as an `int` when the '
                           'size of `num_or_size_splits` is statically '
                           f'unknown. Received `num`: {num} and '
                           f'`num_or_size_splits`: {num_or_size_splits}.')
        num_splits = num
      else:
        if num is not None and num != num_splits:
          raise ValueError('`num` does not match the size of '
                           f'`num_or_size_splits`. Received `num`: {num} and '
                           f'size of `num_or_size_splits`: {num_splits}.')

    splits = array_ops.concat([[0], math_ops.cumsum(split_lengths)], axis=0)
    checks = []
    checks.append(
        check_ops.assert_non_negative_v2(
            num_or_size_splits,
            message='`num_or_size_splits` must be non-negative.'))
    checks.append(
        check_ops.assert_equal_v2(
            num_splits,
            array_ops.shape(split_lengths)[0],
            message='`num` is inconsistent with `num_or_size_splits.shape[0]`.'))
    checks.append(
        check_ops.assert_equal_v2(
            math_ops.cast(dim_size, splits.dtype),
            splits[-1],
            message=('Cannot exactly split the `axis` dimension of `value` '
                     'with the given `num_or_size_splits`.')))
    splits = control_flow_ops.with_dependencies(checks, splits)
    split_rts = []
    slices = [slice(None)] * (axis + 1)
    for i in range(num_splits):
      slices[-1] = slice(splits[i], splits[i + 1])
      split_rts.append(value[tuple(slices)])
    return split_rts
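
For reference, the slicing above reduces to a prefix sum over the requested split lengths: with `num_or_size_splits=[1, 2, 1]`, `splits` becomes `[0, 1, 3, 4]` and the outputs are `value[0:1]`, `value[1:3]`, and `value[3:4]` along `axis`. A small eager-mode sketch of that boundary arithmetic (the helper name is illustrative, not part of the TensorFlow API):

import tensorflow as tf

def split_boundaries(size_splits):
  # Prefix-sum the per-split lengths: [0, s0, s0 + s1, ...].
  bounds = [0]
  for s in size_splits:
    bounds.append(bounds[-1] + s)
  return bounds

rt = tf.RaggedTensor.from_row_lengths(
    tf.reshape(tf.range(18), [6, 3]), row_lengths=[1, 2, 2, 1])
bounds = split_boundaries([1, 2, 1])  # [0, 1, 3, 4]
parts = [rt[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
print([p.shape.as_list() for p in parts])
# [[1, None, 3], [2, None, 3], [1, None, 3]]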