Exemplo n.º 1
0
 def is_range_tensor(self):
     """Checks `tensors.is_range_tensor` on range and non-range inputs.

     Tensors produced by `math_ops.range` (called with one, two or three
     arguments) must be recognized as range tensors; `None` and a constant
     tensor built from a Python `range` must not.
     """
     self.assertTrue(tensors.is_range_tensor(math_ops.range(1)))
     self.assertTrue(tensors.is_range_tensor(math_ops.range(1, 2)))
     self.assertTrue(tensors.is_range_tensor(math_ops.range(1, 2, 3)))
     self.assertFalse(tensors.is_range_tensor(None))
     self.assertFalse(
         tensors.is_range_tensor(constant_op.constant(range(1))))

 # The method above lacks the `test` prefix, so unittest-style loaders (whose
 # default testMethodPrefix is 'test') never collect it. This alias makes the
 # checks run while keeping the original method name available.
 test_is_range_tensor = is_range_tensor
Exemplo n.º 2
0
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names,
             opts):
    """Functional form of a for statement.

    The loop operates on a state made of every symbol whose value changes
    across iterations, not counting variables that are local to the loop.

    For example, in the loop below, which computes the geometric and
    arithmetic means of some numbers:

    ```
      geo_mean = 1
      arith_mean = 0
      for i in range(n):
        a = numbers[i]
        geo_mean *= a
        arith_mean += a
    ```

    the state consists of `geo_mean` and `arith_mean`. The `extra_test`,
    `body`, `get_state` and `set_state` callables must bind to those original
    symbols through `nonlocal`.

    The loop-block callables have no explicit inputs or outputs; they work
    through nonlocal/global side effects. The values carried by the loop are
    controlled by `get_state`/`set_state`.

    The implementation is selected from the type of `iter_`, and every staged
    implementation receives the same argument list. Objects matching none of
    the TensorFlow types are iterated as a plain Python loop.

    Args:
      iter_: The entity being iterated over.
      extra_test: Callable with boolean return type.
        An additional loop condition.
      body: Callable representing the actual loop body.
      get_state: Additional callable which can capture additional state (such
        as the values of composite symbols). This is only useful when staging
        the loop.
      set_state: Additional callable which saves values captured by get_state
        back into the Python environment. This is only useful when staging the
        loop.
      symbol_names: Tuple containing names of the loop variables returned by
        get_state.
      opts: Optional dict of extra loop parameters.
    """
    if tensor_util.is_tensor(iter_):
        if tensors.is_range_tensor(iter_):
            loop_impl = _tf_range_for_stmt
        elif isinstance(iter_, ragged_tensor.RaggedTensor):
            loop_impl = _tf_ragged_for_stmt
        else:
            loop_impl = _known_len_tf_for_stmt
    elif isinstance(iter_, dataset_ops.DatasetV2):
        loop_impl = _tf_dataset_for_stmt
    elif isinstance(iter_, iterator_ops.OwnedIterator):
        loop_impl = _tf_iterator_for_stmt
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
        loop_impl = _tf_ragged_for_stmt
    elif isinstance(iter_, distribute.Iterator):
        loop_impl = _tf_iterator_for_stmt
    elif isinstance(iter_, distribute.Iterable):
        # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
        loop_impl = _tf_distributed_iterable_for_stmt
    else:
        # Plain Python iteration; the get_state/set_state hooks are not
        # forwarded (None is passed for both).
        _py_for_stmt(iter_, extra_test, body, None, None)
        return

    loop_impl(iter_, extra_test, body, get_state, set_state, symbol_names,
              opts)
def for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
             basic_symbol_names, composite_symbol_names, opts):
    """Functional form of a for statement.

    The loop operates on a state, which includes all symbols that are
    variant across loop iterations, excluding the iterate as well as the
    variables local to the loop.

    For example, given the loop below that calculates the geometric and
    arithmetic means of some numbers:

      geo_mean = 1
      arith_mean = 0
      for i in range(n):
        a = numbers[i]
        geo_mean *= a
        arith_mean += a

    The state is represented by the variables geo_mean and arith_mean. The
    argument `init_vars` may contain the tuple (1, 0), the body will include
    the arguments geo_mean and arith_mean and will return a tuple
    representing the new values for geo_mean and arith_mean, respectively.

    The implementation is chosen from the type of `iter_`:
    - range tensors, ragged tensors, other tensors, datasets, owned iterators
      and distributed datasets are handed to dedicated helpers;
    - distributed iterators are rejected;
    - anything else runs as a plain Python loop.

    Args:
      iter_: The entity being iterated over.
      extra_test: Callable with the state as arguments, and boolean return
        type. An additional loop condition.
      body: Callable with the iterate and the state as arguments, and state as
        return type. The actual loop body.
      get_state: Additional callable which can capture additional state (such
        as the values of composite symbols). This is only useful when staging
        the loop.
      set_state: Additional callable which saves values captured by get_state
        back into the Python environment. This is only useful when staging the
        loop.
      init_vars: Tuple containing the initial state.
      basic_symbol_names: Tuple containing basic loop var names.
      composite_symbol_names: Tuple containing composite loop var names.
      opts: Optional dict of extra loop parameters.

    Returns:
      Tuple containing the final state.

    Raises:
      NotImplementedError: if `iter_` is an `input_lib.DistributedIterator`.
    """
    if tensor_util.is_tensor(iter_):
        if tensors.is_range_tensor(iter_):
            return _tf_range_for_stmt(iter_, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names, opts)
        # `tensor_util.is_tensor` can also accept composite tensors such as
        # RaggedTensor. Route those to the ragged implementation rather than
        # the dense known-length one; otherwise the RaggedTensor branch below
        # would be unreachable for them.
        if isinstance(iter_, ragged_tensor.RaggedTensor):
            return _tf_ragged_for_stmt(iter_, extra_test, body, get_state,
                                       set_state, init_vars,
                                       basic_symbol_names,
                                       composite_symbol_names, opts)
        return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                      set_state, init_vars,
                                      basic_symbol_names,
                                      composite_symbol_names, opts)

    if isinstance(iter_, dataset_ops.DatasetV2):
        return _tf_dataset_for_stmt(iter_, extra_test, body, get_state,
                                    set_state, init_vars, basic_symbol_names,
                                    composite_symbol_names, opts)

    if isinstance(iter_, iterator_ops.OwnedIterator):
        return _tf_iterator_for_stmt(iter_, extra_test, body, get_state,
                                     set_state, init_vars, basic_symbol_names,
                                     composite_symbol_names, opts)

    # Kept for RaggedTensor values that `tensor_util.is_tensor` does not
    # report as tensors.
    if isinstance(iter_, ragged_tensor.RaggedTensor):
        return _tf_ragged_for_stmt(iter_, extra_test, body, get_state,
                                   set_state, init_vars, basic_symbol_names,
                                   composite_symbol_names, opts)

    if isinstance(iter_, input_lib.DistributedIterator):
        raise NotImplementedError(
            'distributed iterators not supported yet, use the distributed dataset'
            ' directly')

    if isinstance(iter_, input_lib.DistributedDataset):
        return _tf_distributed_dataset_for_stmt(iter_, extra_test, body,
                                                init_vars)

    return _py_for_stmt(iter_, extra_test, body, get_state, set_state,
                        init_vars)
Exemplo n.º 4
0
def for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
             basic_symbol_names, composite_symbol_names, opts):
    """Functional form of a for statement.

    The loop operates on a state, which includes all symbols that are
    variant across loop iterations, excluding the iterate as well as the
    variables local to the loop.

    For example, given the loop below that calculates the geometric and
    arithmetic means of some numbers:

      geo_mean = 1
      arith_mean = 0
      for i in range(n):
        a = numbers[i]
        geo_mean *= a
        arith_mean += a

    The state is represented by the variables geo_mean and arith_mean. The
    argument `init_vars` may contain the tuple (1, 0), the body will include
    the arguments geo_mean and arith_mean and will return a tuple
    representing the new values for geo_mean and arith_mean, respectively.

    The implementation is chosen from the type of `iter_`:
    - range tensors, ragged tensors, other tensors, datasets and owned
      iterators are handed to dedicated helpers;
    - objects exposing an `_autograph_for_loop` attribute delegate to it;
    - anything else runs as a plain Python loop.

    Args:
      iter_: The entity being iterated over.
      extra_test: Callable with the state as arguments, and boolean return
        type. An additional loop condition.
      body: Callable with the iterate and the state as arguments, and state as
        return type. The actual loop body.
      get_state: Additional callable which can capture additional state (such
        as the values of composite symbols). This is only useful when staging
        the loop.
      set_state: Additional callable which saves values captured by get_state
        back into the Python environment. This is only useful when staging the
        loop.
      init_vars: Tuple containing the initial state.
      basic_symbol_names: Tuple containing basic loop var names.
      composite_symbol_names: Tuple containing composite loop var names.
      opts: Optional dict of extra loop parameters.

    Returns:
      Tuple containing the final state.
    """
    if tensor_util.is_tensor(iter_):
        if tensors.is_range_tensor(iter_):
            return _tf_range_for_stmt(iter_, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names, opts)
        # `tensor_util.is_tensor` can also accept composite tensors such as
        # RaggedTensor. Route those to the ragged implementation rather than
        # the dense known-length one; otherwise the RaggedTensor branch below
        # would be unreachable for them.
        if isinstance(iter_, ragged_tensor.RaggedTensor):
            return _tf_ragged_for_stmt(iter_, extra_test, body, get_state,
                                       set_state, init_vars,
                                       basic_symbol_names,
                                       composite_symbol_names, opts)
        return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                      set_state, init_vars,
                                      basic_symbol_names,
                                      composite_symbol_names, opts)

    if isinstance(iter_, dataset_ops.DatasetV2):
        return _tf_dataset_for_stmt(iter_, extra_test, body, get_state,
                                    set_state, init_vars, basic_symbol_names,
                                    composite_symbol_names, opts)

    if isinstance(iter_, iterator_ops.OwnedIterator):
        return _tf_iterator_for_stmt(iter_, extra_test, body, get_state,
                                     set_state, init_vars, basic_symbol_names,
                                     composite_symbol_names, opts)

    # Kept for RaggedTensor values that `tensor_util.is_tensor` does not
    # report as tensors.
    if isinstance(iter_, ragged_tensor.RaggedTensor):
        return _tf_ragged_for_stmt(iter_, extra_test, body, get_state,
                                   set_state, init_vars, basic_symbol_names,
                                   composite_symbol_names, opts)

    # Note: This experimental interface is subject to change.
    custom_handler = getattr(iter_, '_autograph_for_loop', None)
    if custom_handler is not None:
        # TODO(mdan): TensorFlow-specific verification - handlers should perform it.
        _disallow_undefs_into_loop(*init_vars)
        # TODO(mdan): Enable get_state/set_state separately.
        return custom_handler(extra_test, body, init_vars)

    return _py_for_stmt(iter_, extra_test, body, get_state, set_state,
                        init_vars)
Exemplo n.º 5
0
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names,
             opts):
    """Functional form of a for statement.

    The loop operates on a state, which includes all symbols that are
    variant across loop iterations, excluding the variables local to the loop.

    For example, given the loop below that calculates the geometric and
    arithmetic means of some numbers:

    ```
      geo_mean = 1
      arith_mean = 0
      for i in range(n):
        a = numbers[i]
        geo_mean *= a
        arith_mean += a
    ```

    The state is represented by the variables geo_mean and arith_mean. The
    `extra_test`, `body`, `get_state` and `set_state` functions must bind to
    the original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

    This function does not return the final state. Loop results reach the
    caller through those nonlocal bindings; the dispatched helpers' return
    values are discarded.

    Args:
      iter_: The entity being iterated over.
      extra_test: Callable with boolean return type.
        An additional loop condition.
      body: Callable representing the actual loop body.
      get_state: Additional callable which can capture additional state (such
        as the values of composite symbols). This is only useful when staging
        the loop.
      set_state: Additional callable which saves values captured by get_state
        back into the Python environment. This is only useful when staging the
        loop.
      symbol_names: Tuple containing names of the loop variables returned by
        get_state.
      opts: Optional dict of extra loop parameters.

    Raises:
      NotImplementedError: if `iter_` is an `input_lib.DistributedIterator`.
    """
    if tensor_util.is_tensor(iter_):
        if tensors.is_range_tensor(iter_):
            _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                               symbol_names, opts)
        # `tensor_util.is_tensor` can also accept composite tensors such as
        # RaggedTensor. Route those to the ragged implementation rather than
        # the dense known-length one; otherwise the RaggedTensor branch below
        # would be unreachable for them.
        elif isinstance(iter_, ragged_tensor.RaggedTensor):
            _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                                symbol_names, opts)
        else:
            _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                   set_state, symbol_names, opts)

    elif isinstance(iter_, dataset_ops.DatasetV2):
        _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                             symbol_names, opts)

    elif isinstance(iter_, iterator_ops.OwnedIterator):
        _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                              symbol_names, opts)

    # Kept for RaggedTensor values that `tensor_util.is_tensor` does not
    # report as tensors.
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
        _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                            symbol_names, opts)

    elif isinstance(iter_, input_lib.DistributedIterator):
        raise NotImplementedError(
            'distributed iterators not supported yet, use the distributed dataset'
            ' directly')

    # TODO(mdan): Resolve the private access issue.
    elif isinstance(iter_, input_lib._IterableInput):  # pylint:disable=protected-access
        _tf_distributed_iterable_for_stmt(iter_, extra_test, body, get_state,
                                          set_state, symbol_names, opts)

    else:
        # Plain Python iteration; get_state/set_state are not forwarded.
        _py_for_stmt(iter_, extra_test, body, None, None)