def is_range_tensor(self):
  """Checks which values `tensors.is_range_tensor` classifies as ranges.

  Tensors produced by `math_ops.range` (called with one, two or three
  arguments) must be recognized. `None` and a constant tensor built from the
  same Python range must not be.
  """
  for range_args in ((1,), (1, 2), (1, 2, 3)):
    self.assertTrue(tensors.is_range_tensor(math_ops.range(*range_args)))

  # Neither a non-tensor nor a constant with range-like contents qualifies.
  self.assertFalse(tensors.is_range_tensor(None))
  const_from_range = constant_op.constant(range(1))
  self.assertFalse(tensors.is_range_tensor(const_from_range))
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names,
             opts):
  """Functional form of a for statement.

  The loop operates on a state made of every symbol that varies across loop
  iterations. Variables local to the loop body are not part of it.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  the state is represented by the variables geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  The inputs and outputs of the callables representing the loop blocks are
  not explicit. These functions must use nonlocal/global for side effects,
  and the set_state/get_state functions control which values flow in and out.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type. An additional loop
      condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such
      as the values of composite symbols). This is only useful when staging
      the loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
  loop_args = (iter_, extra_test, body, get_state, set_state, symbol_names,
               opts)

  # Pick the staged implementation matching the iterated type. The checks run
  # in a fixed order, and the first match wins.
  if tensor_util.is_tensor(iter_):
    if tensors.is_range_tensor(iter_):
      staged_loop = _tf_range_for_stmt
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
      staged_loop = _tf_ragged_for_stmt
    else:
      staged_loop = _known_len_tf_for_stmt
  elif isinstance(iter_, dataset_ops.DatasetV2):
    staged_loop = _tf_dataset_for_stmt
  elif isinstance(iter_, iterator_ops.OwnedIterator):
    staged_loop = _tf_iterator_for_stmt
  elif isinstance(iter_, ragged_tensor.RaggedTensor):
    staged_loop = _tf_ragged_for_stmt
  elif isinstance(iter_, distribute.Iterator):
    staged_loop = _tf_iterator_for_stmt
  elif isinstance(iter_, distribute.Iterable):
    # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
    staged_loop = _tf_distributed_iterable_for_stmt
  else:
    # Anything else runs as an ordinary Python loop. get_state/set_state are
    # only useful when staging, so they are not forwarded.
    _py_for_stmt(iter_, extra_test, body, None, None)
    return

  staged_loop(*loop_args)
def for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
             basic_symbol_names, composite_symbol_names, opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the iterate as well as the
  variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a

  The state is represented by the variables geo_mean and arith_mean. The
  argument for init_vars may contain the tuple (1, 0). The body will
  include the arguments geo_mean and arith_mean and will return a tuple
  representing the new values for geo_mean and arith_mean, respectively.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return
      type. An additional loop condition.
    body: Callable with the iterate and the state as arguments, and state
      as return type. The actual loop body.
    get_state: Additional callable which can capture additional state (such
      as the values of composite symbols). This is only useful when staging
      the loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.

  Raises:
    NotImplementedError: if `iter_` is a distributed iterator. Iterate over
      the distributed dataset directly instead.
  """
  if tensor_util.is_tensor(iter_):
    if tensors.is_range_tensor(iter_):
      return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)
    # If `tensor_util.is_tensor` accepts RaggedTensor values, they would
    # otherwise fall through to the generic known-length loop, and the
    # ragged branch below would be unreachable. Dispatch them to the
    # ragged-specific implementation here.
    if isinstance(iter_, ragged_tensor.RaggedTensor):
      return _tf_ragged_for_stmt(iter_, extra_test, body, get_state,
                                 set_state, init_vars, basic_symbol_names,
                                 composite_symbol_names, opts)
    return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                  set_state, init_vars, basic_symbol_names,
                                  composite_symbol_names, opts)

  if isinstance(iter_, dataset_ops.DatasetV2):
    return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)

  if isinstance(iter_, iterator_ops.OwnedIterator):
    return _tf_iterator_for_stmt(iter_, extra_test, body, get_state,
                                 set_state, init_vars, basic_symbol_names,
                                 composite_symbol_names, opts)

  if isinstance(iter_, ragged_tensor.RaggedTensor):
    return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                               init_vars, basic_symbol_names,
                               composite_symbol_names, opts)

  if isinstance(iter_, input_lib.DistributedIterator):
    raise NotImplementedError(
        'distributed iterators not supported yet, use the distributed dataset'
        ' directly')

  if isinstance(iter_, input_lib.DistributedDataset):
    return _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_vars)

  # Fallback: run as a regular Python loop.
  return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
def for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
             basic_symbol_names, composite_symbol_names, opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the iterate as well as the
  variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a

  The state is represented by the variables geo_mean and arith_mean. The
  argument for init_vars may contain the tuple (1, 0). The body will
  include the arguments geo_mean and arith_mean and will return a tuple
  representing the new values for geo_mean and arith_mean, respectively.

  Objects that define an `_autograph_for_loop(extra_test, body, init_vars)`
  attribute (an experimental hook) handle the loop themselves.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return
      type. An additional loop condition.
    body: Callable with the iterate and the state as arguments, and state
      as return type. The actual loop body.
    get_state: Additional callable which can capture additional state (such
      as the values of composite symbols). This is only useful when staging
      the loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  if tensor_util.is_tensor(iter_):
    if tensors.is_range_tensor(iter_):
      return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)
    # If `tensor_util.is_tensor` accepts RaggedTensor values, they would
    # otherwise fall through to the generic known-length loop, and the
    # ragged branch below would be unreachable. Dispatch them to the
    # ragged-specific implementation here.
    if isinstance(iter_, ragged_tensor.RaggedTensor):
      return _tf_ragged_for_stmt(iter_, extra_test, body, get_state,
                                 set_state, init_vars, basic_symbol_names,
                                 composite_symbol_names, opts)
    return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                  set_state, init_vars, basic_symbol_names,
                                  composite_symbol_names, opts)

  if isinstance(iter_, dataset_ops.DatasetV2):
    return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)

  if isinstance(iter_, iterator_ops.OwnedIterator):
    return _tf_iterator_for_stmt(iter_, extra_test, body, get_state,
                                 set_state, init_vars, basic_symbol_names,
                                 composite_symbol_names, opts)

  if isinstance(iter_, ragged_tensor.RaggedTensor):
    return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                               init_vars, basic_symbol_names,
                               composite_symbol_names, opts)

  # Note: This experimental interface is subject to change.
  custom_handler = getattr(iter_, '_autograph_for_loop', None)
  if custom_handler is not None:
    # TODO(mdan): TensorFlow-specific verification - handlers should perform it.
    _disallow_undefs_into_loop(*init_vars)
    # TODO(mdan): Enable get_state/set_state separately.
    return custom_handler(extra_test, body, init_vars)

  # Fallback: run as a regular Python loop.
  return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names,
             opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  The state is represented by the variables geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  The inputs and outputs of the callables representing the loop blocks are
  not explicit. These functions must use nonlocal/global for side effects,
  and the set_state/get_state functions control which values flow in and out.
  Accordingly, this function returns nothing: results are visible only
  through those side effects.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type. An additional loop
      condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such
      as the values of composite symbols). This is only useful when staging
      the loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.

  Raises:
    NotImplementedError: if `iter_` is a distributed iterator. Iterate over
      the distributed dataset directly instead.
  """
  if tensor_util.is_tensor(iter_):
    if tensors.is_range_tensor(iter_):
      _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                         symbol_names, opts)
    # If `tensor_util.is_tensor` accepts RaggedTensor values, they would
    # otherwise fall through to the generic known-length loop, and the
    # ragged branch below would be unreachable. Dispatch them to the
    # ragged-specific implementation here.
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
      _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                          symbol_names, opts)
    else:
      _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
                             symbol_names, opts)

  elif isinstance(iter_, dataset_ops.DatasetV2):
    _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                         symbol_names, opts)

  elif isinstance(iter_, iterator_ops.OwnedIterator):
    _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                          symbol_names, opts)

  elif isinstance(iter_, ragged_tensor.RaggedTensor):
    _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                        symbol_names, opts)

  elif isinstance(iter_, input_lib.DistributedIterator):
    raise NotImplementedError(
        'distributed iterators not supported yet, use the distributed dataset'
        ' directly')

  # TODO(mdan): Resolve the private access issue.
  elif isinstance(iter_, input_lib._IterableInput):  # pylint:disable=protected-access
    _tf_distributed_iterable_for_stmt(iter_, extra_test, body, get_state,
                                      set_state, symbol_names, opts)

  else:
    # Plain Python loop. get_state/set_state are only useful when staging the
    # loop, so they are not forwarded.
    _py_for_stmt(iter_, extra_test, body, None, None)