def _dataset_for_stmt_with_extra_test(ds, extra_test, body, init_state):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop state is threaded through `ds` with a scan. Each step first
  evaluates `extra_test` on the current state and only runs `body` while that
  condition holds. The stream is truncated at the first element whose
  condition is false, and the last surviving state is returned.

  Args:
    ds: The dataset being iterated.
    extra_test: Callable taking the unpacked state and returning a boolean
      tensor; iteration continues while it is true.
    body: Callable `body(iterate, *state)` returning the new state.
    init_state: The initial loop state (a tuple, since it is unpacked).

  Returns:
    The loop state after the last iteration that ran `body`, or `init_state`
    when no iteration ran.
  """

  def scan_body(state, iterate):
    # The stop condition is checked on the state *before* this iteration.
    keep_going = extra_test(*state)
    next_state = control_flow_ops.cond(
        keep_going,
        lambda: body(iterate, *state),
        lambda: state)
    # `next_state` is the scan's carried state; the emitted element repeats it
    # next to the condition so take_while can inspect the condition.
    return next_state, (next_state, keep_going)

  def take_while_predicate(unused_state, keep_going):
    return keep_going

  def reduce_body(unused_previous_state, scan_element):
    # Each element carries a full state, so the accumulator is simply replaced.
    latest_state, _ = scan_element
    return latest_state

  scanned = ds.apply(scan_ops.scan(init_state, scan_body))
  truncated = scanned.apply(take_while_ops.take_while(take_while_predicate))
  return truncated.reduce(init_state, reduce_body)
def CounterV2(start=0, step=1, dtype=dtypes.int64):
  """Creates a `Dataset` that counts from `start` in steps of size `step`.

  Unlike `tf.data.Dataset.range`, which stops at an end value, the counter
  produces elements indefinitely (it is built on an unbounded `repeat`).

  For example:

  ```python
  tf.data.experimental.Counter() == [0, 1, 2, ...)
  tf.data.experimental.Counter(2) == [2, 3, ...)
  tf.data.experimental.Counter(2, 5) == [2, 7, 12, ...)
  tf.data.experimental.Counter(0, -1) == [0, -1, -2, ...)
  tf.data.experimental.Counter(10, -1) == [10, 9, ...)
  ```

  Args:
    start: (Optional.) The starting value for the counter. Defaults to 0.
    step: (Optional.) The step size for the counter. Defaults to 1.
    dtype: (Optional.) The data type for counter elements. Defaults to
      `tf.int64`.

  Returns:
    A `Dataset` of scalar `dtype` elements.
  """
  with ops.name_scope("counter"):
    start = ops.convert_to_tensor(start, dtype=dtype, name="start")
    step = ops.convert_to_tensor(step, dtype=dtype, name="step")
    # The scan state holds the next value to emit: emit it, then advance it.
    return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
        scan_ops.scan(start, lambda state, _: (state + step, state)))
def testChangingStateShape(self):
  # Exercise the fixed-point shape invariant computation: the initial state
  # has fully known shapes, but the scan function grows the state on every
  # element, so the inferred state shapes must be relaxed.
  def _scan_fn(state, input_value):
    vector_part, scalar_part = state[0], state[1]
    # Statically known rank, but dynamic length.
    doubled_vector = array_ops.concat([vector_part, vector_part], 0)
    # Statically unknown rank.
    promoted_rank = array_ops.expand_dims(scalar_part, 0)
    return (doubled_vector, promoted_rank), (state, input_value)

  dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
      scan_ops.scan(([0], 1), _scan_fn))
  shapes = dataset.output_shapes
  self.assertEqual([None], shapes[0][0].as_list())
  self.assertIs(None, shapes[0][1].ndims)
  self.assertEqual([], shapes[1].as_list())

  next_element = dataset.make_one_shot_iterator().get_next()

  with self.cached_session() as sess:
    for step in range(5):
      (vector_val, rank_val), _ = self.evaluate(next_element)
      self.assertAllEqual([0] * (2**step), vector_val)
      self.assertAllEqual(np.array(1, ndmin=step), rank_val)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
  """Capturing an outer TensorArray inside the scan function is rejected."""
  if control_flow_v2_toggles.control_flow_v2_enabled():
    self.skipTest("v1 only test")
  empty_ta = tensor_array_ops.TensorArray(
      size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

  def scan_fn(ta, x):
    updated = ta.write(ta.size(), x)
    # Here, capture empty_ta from outside the function. However, it may be
    # either a TF1-style TensorArray or an Eager-style TensorArray.
    next_iter = control_flow_ops.cond(
        math_ops.equal(x % 3, 0), lambda: empty_ta, lambda: updated)
    return (next_iter, updated.stack())

  start = empty_ta
  start = start.write(0, -1)

  # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
  with self.assertRaisesRegex(
      NotImplementedError,
      r"construct a new TensorArray inside the function"):
    dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))
def Counter(start=0, step=1, dtype=dtypes.int64):
  """Creates a `Dataset` that counts from `start` in steps of size `step`.

  Unlike `tf.data.Dataset.range`, which stops at an end value, the counter
  produces elements indefinitely (it is built on an unbounded `repeat`).

  For example:

  ```python
  tf.data.experimental.Counter() == [0, 1, 2, ...)
  tf.data.experimental.Counter(2) == [2, 3, ...)
  tf.data.experimental.Counter(2, 5) == [2, 7, 12, ...)
  tf.data.experimental.Counter(0, -1) == [0, -1, -2, ...)
  tf.data.experimental.Counter(10, -1) == [10, 9, ...)
  ```

  Args:
    start: (Optional.) The starting value for the counter. Defaults to 0.
    step: (Optional.) The step size for the counter. Defaults to 1.
    dtype: (Optional.) The data type for counter elements. Defaults to
      `tf.int64`.

  Returns:
    A `Dataset` of scalar `dtype` elements.
  """
  with ops.name_scope("counter"):
    start = ops.convert_to_tensor(start, dtype=dtype, name="start")
    step = ops.convert_to_tensor(step, dtype=dtype, name="step")
    # The scan state holds the next value to emit: emit it, then advance it.
    return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
        scan_ops.scan(start, lambda state, _: (state + step, state)))
def testTensorArrayWithCondReset(self):
  """A TensorArray state can be reset by building a fresh one in the fn."""

  def empty():
    return tensor_array_ops.TensorArray(
        size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

  def scan_fn(ta, x):
    appended = ta.write(ta.size(), x)
    # When x % 3 == 0 the carried state restarts from an empty TensorArray
    # constructed inside the function; the output still includes x.
    reset_now = math_ops.equal(x % 3, 0)
    carried = control_flow_ops.cond(reset_now, empty, lambda: appended)
    return (carried, appended.stack())

  initial_ta = empty().write(0, -1)
  ds = dataset_ops.Dataset.range(6).apply(scan_ops.scan(initial_ta, scan_fn))
  expected = [[-1, 0], [1], [1, 2], [1, 2, 3], [4], [4, 5]]
  self.assertDatasetProduces(
      ds,
      expected_output=expected,
      requires_initialization=True,
      num_test_iterations=2)
def testTensorArraySimple(self):
  """A TensorArray can be used as scan state and grows by one per element."""

  def scan_fn(ta, x):
    grown = ta.write(ta.size(), x)
    # Emit the contents accumulated *before* x was appended.
    return (grown, ta.stack())

  initial_ta = tensor_array_ops.TensorArray(
      size=0, element_shape=[], dtype=dtypes.int64,
      dynamic_size=True).write(0, -1)
  ds = dataset_ops.Dataset.range(5).apply(scan_ops.scan(initial_ta, scan_fn))
  expected = [[-1] + list(range(count)) for count in range(5)]
  self.assertDatasetProduces(
      ds,
      expected_output=expected,
      requires_initialization=True,
      num_test_iterations=2)
def testTensorArraySimple(self):
  """A TensorArray can be used as scan state and grows by one per element."""

  def scan_fn(ta, x):
    grown = ta.write(ta.size(), x)
    # Emit the contents accumulated *before* x was appended.
    return (grown, ta.stack())

  initial_ta = tensor_array_ops.TensorArray(
      size=0, element_shape=[], dtype=dtypes.int64,
      dynamic_size=True).write(0, -1)
  ds = dataset_ops.Dataset.range(5).apply(scan_ops.scan(initial_ta, scan_fn))
  expected = [[-1] + list(range(count)) for count in range(5)]
  self.assertDatasetProduces(
      ds,
      expected_output=expected,
      requires_initialization=True,
      num_test_iterations=2)
def testTensorArrayWithCondReset(self):
  """A TensorArray state can be reset by building a fresh one in the fn."""

  def empty():
    return tensor_array_ops.TensorArray(
        size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

  def scan_fn(ta, x):
    appended = ta.write(ta.size(), x)
    # When x % 3 == 0 the carried state restarts from an empty TensorArray
    # constructed inside the function; the output still includes x.
    reset_now = math_ops.equal(x % 3, 0)
    carried = control_flow_ops.cond(reset_now, empty, lambda: appended)
    return (carried, appended.stack())

  initial_ta = empty().write(0, -1)
  ds = dataset_ops.Dataset.range(6).apply(scan_ops.scan(initial_ta, scan_fn))
  expected = [[-1, 0], [1], [1, 2], [1, 2, 3], [4], [4, 5]]
  self.assertDatasetProduces(
      ds,
      expected_output=expected,
      requires_initialization=True,
      num_test_iterations=2)
def testChangingStateShape(self):
  # Exercise the fixed-point shape invariant computation: the initial state
  # has fully known shapes, but the scan function grows the state on every
  # element, so the inferred state shapes must be relaxed.
  def _scan_fn(state, input_value):
    vector_part, scalar_part = state[0], state[1]
    # Statically known rank, but dynamic length.
    doubled_vector = array_ops.concat([vector_part, vector_part], 0)
    # Statically unknown rank.
    promoted_rank = array_ops.expand_dims(scalar_part, 0)
    return (doubled_vector, promoted_rank), (state, input_value)

  dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
      scan_ops.scan(([0], 1), _scan_fn))
  legacy_shapes = dataset_ops.get_legacy_output_shapes(dataset)
  self.assertEqual([None], legacy_shapes[0][0].as_list())
  self.assertIs(None, legacy_shapes[0][1].ndims)
  self.assertEqual([], legacy_shapes[1].as_list())

  get_next = self.getNext(dataset)
  for step in range(5):
    (vector_val, rank_val), _ = self.evaluate(get_next())
    self.assertAllEqual([0] * (2**step), vector_val)
    self.assertAllEqual(np.array(1, ndmin=step), rank_val)
  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(get_next())
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, init_state):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop state is threaded through `ds` with a scan. Each step first
  evaluates `extra_test` on the current state and only runs `body` while that
  condition holds. The stream is truncated at the first element whose
  condition is false, and the last surviving state is returned.

  Args:
    ds: The dataset being iterated.
    extra_test: Callable taking the unpacked state and returning a boolean
      tensor; iteration continues while it is true.
    body: Callable `body(iterate, *state)` returning the new state.
    init_state: The initial loop state (a tuple, since it is unpacked).

  Returns:
    The loop state after the last iteration that ran `body`, or `init_state`
    when no iteration ran.
  """

  def scan_body(state, iterate):
    # The stop condition is checked on the state *before* this iteration.
    keep_going = extra_test(*state)
    next_state = control_flow_ops.cond(
        keep_going,
        lambda: body(iterate, *state),
        lambda: state)
    # `next_state` is the scan's carried state; the emitted element repeats it
    # next to the condition so take_while can inspect the condition.
    return next_state, (next_state, keep_going)

  def take_while_predicate(unused_state, keep_going):
    return keep_going

  def reduce_body(unused_previous_state, scan_element):
    # Each element carries a full state, so the accumulator is simply replaced.
    latest_state, _ = scan_element
    return latest_state

  scanned = ds.apply(scan_ops.scan(init_state, scan_body))
  truncated = scanned.apply(take_while_ops.take_while(take_while_predicate))
  return truncated.reduce(init_state, reduce_body)
def testStatefulScanNotCheckpointable(self):
  """A scan whose function uses stateful ops cannot be checkpointed."""

  def stateful_scan(state, element):
    # `_statefulBoolFunc` is defined elsewhere on this test class; presumably
    # it introduces a stateful op into the scan function.
    return state, self._statefulBoolFunc(element)

  scanned = dataset_ops.Dataset.range(10).apply(
      scan_ops.scan(0, stateful_scan))
  self._assertNotCheckpointable(scanned)
def testUnsupportedTransformError(self, drop_remainder):
  """Rebatching a dataset that scans over batches is expected to fail."""
  batched = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=drop_remainder)
  dataset = batched.apply(scan_ops.scan([0], lambda _, a: ([0], a)))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched = batching._RebatchDataset(dataset, num_workers=4)
    get_next = self.getNext(rebatched)
    self.evaluate(get_next())
def testUnsupportedTransformError(self, drop_remainder):
  """Rebatching a dataset that scans over batches is expected to fail."""
  batched = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=drop_remainder)
  dataset = batched.apply(scan_ops.scan([0], lambda _, a: ([0], a)))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched = batching._RebatchDataset(dataset, num_workers=4)
    get_next = self.getNext(rebatched)
    self.evaluate(get_next())
def testScanAfterBatch(self):
  """A scan applied after batching can be rebatched across replicas."""

  def multiply_by_state(state, value):
    # The state (2) never changes; each batch is scaled by it.
    return (state, value * state)

  dataset = dataset_ops.Dataset.range(40).batch(10).apply(
      scan_ops.scan(np.int64(2), multiply_by_state))
  dataset = distribute._RebatchDataset(dataset, num_replicas=2)
  self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])

  # Ten-element batches split into two five-element halves: 8 outputs total.
  expected_output = [
      [2 * i for i in range(5 * j, 5 * (j + 1))] for j in range(8)
  ]
  self.assertDatasetProduces(dataset, expected_output)
def testFibonacci(self):
  """Scan with a pair state generates the Fibonacci sequence."""
  # State is the pair (F(k), F(k+1)); each step emits F(k+1).
  fib = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
      scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
  next_element = self.getNext(fib)
  for expected in (1, 1, 2, 3, 5, 8):
    self.assertEqual(expected, self.evaluate(next_element()))
def testFibonacci(self):
  """Scan with a pair state generates the Fibonacci sequence."""
  # State is the pair (F(k), F(k+1)); each step emits F(k+1).
  fib = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
      scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
  next_element = self.getNext(fib)
  for expected in (1, 1, 2, 3, 5, 8):
    self.assertEqual(expected, self.evaluate(next_element()))
def testIncorrectStateType(self):
  """A scan function whose new-state dtype differs from the initial fails."""

  def _scan_fn(state, _):
    # Returns an int64 new state although the initial state below is int32.
    return constant_op.constant(1, dtype=dtypes.int64), state

  dataset = dataset_ops.Dataset.range(10)
  # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
  with self.assertRaisesRegex(
      TypeError,
      "The element types for the new state must match the initial state."):
    dataset.apply(
        scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
def testIncorrectStateType(self):
  """A scan function whose new-state dtype differs from the initial fails."""

  def _widening_scan_fn(old_state, _):
    # The new state is int64, while the initial state below is int32.
    wider_state = constant_op.constant(1, dtype=dtypes.int64)
    return wider_state, old_state

  source = dataset_ops.Dataset.range(10)
  with self.assertRaisesRegex(
      TypeError,
      "The element types for the new state must match the initial state."):
    source.apply(
        scan_ops.scan(
            constant_op.constant(1, dtype=dtypes.int32), _widening_scan_fn))
def testIncorrectReturnType(self):
  """A scan function that does not return a (state, output) pair fails."""

  def _scan_fn(unused_state, unused_input_value):
    # Returns a single tensor instead of a `(new_state, output)` pair.
    return constant_op.constant(1, dtype=dtypes.int64)

  dataset = dataset_ops.Dataset.range(10)
  # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
  with self.assertRaisesRegex(
      TypeError,
      "The scan function must return a pair comprising the new state and the "
      "output value."):
    dataset.apply(
        scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
def testIncorrectReturnType(self):
  """A scan function that does not return a (state, output) pair fails."""

  def _single_value_scan_fn(unused_state, unused_input_value):
    # Only one value is returned; scan requires a (new_state, output) pair.
    lone_value = constant_op.constant(1, dtype=dtypes.int64)
    return lone_value

  source = dataset_ops.Dataset.range(10)
  with self.assertRaisesRegex(
      TypeError,
      "The scan function must return a pair comprising the new state and the "
      "output value."):
    source.apply(
        scan_ops.scan(
            constant_op.constant(1, dtype=dtypes.int32),
            _single_value_scan_fn))
def testPreserveCardinality(self):
  """A StopIteration inside the scan function surfaces as an error."""

  def scan_fn(state, val):

    def raise_stop(_):
      raise StopIteration()

    # The test expects InvalidArgumentError rather than a silent end of
    # sequence, which presumably would change the dataset's cardinality.
    return state, script_ops.py_func(raise_stop, [val], dtypes.int64)

  single = dataset_ops.Dataset.from_tensors(0)
  dataset = single.apply(scan_ops.scan(constant_op.constant(1), scan_fn))
  next_element = self.getNext(dataset)
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(next_element())
def testPreserveCardinality(self):
  """A StopIteration inside the scan function surfaces as an error."""

  def scan_fn(state, val):

    def raise_stop(_):
      raise StopIteration()

    # The test expects InvalidArgumentError rather than a silent end of
    # sequence, which presumably would change the dataset's cardinality.
    return state, script_ops.py_func(raise_stop, [val], dtypes.int64)

  single = dataset_ops.Dataset.from_tensors(0)
  dataset = single.apply(scan_ops.scan(constant_op.constant(1), scan_fn))
  next_element = self.getNext(dataset)
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(next_element())
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, composite_symbol_names): """Overload of _dataset_for_stmt with early stopping. See for_stmt.""" # TODO(mdan): Simplify this - following it is extremely difficult. def scan_body(aug_vars, iterate): """The main loop body wrapper. Only calculates the stop condition.""" loop_vars, state = aug_vars def true_fn(): set_state(state) outputs = body(iterate, *loop_vars) _verify_tf_loop_vars(loop_vars + state, outputs + state, basic_symbol_names, composite_symbol_names, include_shapes=False) return outputs, get_state() extra_cond = extra_test(*loop_vars) new_vars, new_state = control_flow_ops.cond(extra_cond, true_fn, lambda: (loop_vars, state)) scan_outputs = new_vars, new_state, extra_cond # Note: new_aug_vars is the actual state of scan; scan_outputs is its output # (hence the redundancy). # get_state will pull any mutations that body may have made. new_aug_vars = new_vars, new_state return new_aug_vars, scan_outputs def take_while_predicate(unused_loop_vars, unused_state, extra_cond): return extra_cond def reduce_body(unused_aug_vars, scan_outputs): output_aug_vars, output_state, extra_cond = scan_outputs del extra_cond return output_aug_vars, output_state init_state = get_state() aug_vars = init_vars, init_state ds = ds.apply(scan_ops.scan(aug_vars, scan_body)) ds = ds.apply(take_while_ops.take_while(take_while_predicate)) final_aug_vars = ds.reduce(aug_vars, reduce_body) final_vars, final_state = final_aug_vars set_state(final_state) return final_vars
def testFibonacci(self):
  """Scan with a pair state generates the Fibonacci sequence."""
  fib = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
      scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
  iterator = dataset_ops.make_one_shot_iterator(fib)
  if context.executing_eagerly():
    next_element = iterator.get_next
  else:
    # In graph mode, build the get_next op once and re-evaluate it each time.
    get_next = iterator.get_next()
    next_element = lambda: get_next
  for expected in (1, 1, 2, 3, 5, 8):
    self.assertEqual(expected, self.evaluate(next_element()))
def testFibonacci(self):
  """Scan with a pair state generates the Fibonacci sequence."""
  fib = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
      scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
  iterator = fib.make_one_shot_iterator()
  if context.executing_eagerly():
    next_element = iterator.get_next
  else:
    # In graph mode, build the get_next op once and re-evaluate it each time.
    get_next = iterator.get_next()
    next_element = lambda: get_next
  for expected in (1, 1, 2, 3, 5, 8):
    self.assertEqual(expected, self.evaluate(next_element()))
def _estimate_initial_dist_ds(target_dist_t,
                              class_values_ds,
                              dist_estimation_batch_size=32,
                              smoothing_constant=10):
  """Builds a dataset of running class-distribution estimates.

  Per-class counts start at `smoothing_constant` and are updated one batch at
  a time by `_estimate_data_distribution`, which is defined elsewhere. It is
  assumed to return `(updated_counts, distribution)`. Each batch's estimate is
  repeated once per batch row and the result is unbatched.

  Args:
    target_dist_t: Target distribution tensor; only its leading dimension (the
      number of classes) is used here.
    class_values_ds: Dataset of class values.
    dist_estimation_batch_size: Number of elements per estimation batch.
    smoothing_constant: Initial per-class count.

  Returns:
    A dataset with one distribution estimate per unbatched row.
  """
  # Prefer the static class count and fall back to the dynamic shape.
  num_classes = target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0]
  initial_counts = array_ops.fill([num_classes], np.int64(smoothing_constant))

  def update_estimate_and_tile(num_examples_per_class_seen, c):
    new_counts, dist = _estimate_data_distribution(
        c, num_examples_per_class_seen)
    # NOTE(review): the final batch may hold fewer than
    # `dist_estimation_batch_size` elements, but the estimate is still tiled to
    # the full batch size, so unbatching can yield extra rows. This is
    # presumably harmless if callers zip the result with the class values;
    # confirm.
    tiled_dist = array_ops.tile(
        array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
    return new_counts, tiled_dist

  batched = class_values_ds.batch(dist_estimation_batch_size)
  scanned = batched.apply(
      scan_ops.scan(initial_counts, update_estimate_and_tile))
  return scanned.unbatch()
def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
  """Capturing an outer TensorArray inside the scan function is rejected."""
  empty_ta = tensor_array_ops.TensorArray(
      size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

  def scan_fn(ta, x):
    updated = ta.write(ta.size(), x)
    # Here, capture empty_ta from outside the function. However, it may be
    # either a TF1-style TensorArray or an Eager-style TensorArray.
    next_iter = control_flow_ops.cond(
        math_ops.equal(x % 3, 0), lambda: empty_ta, lambda: updated)
    return (next_iter, updated.stack())

  start = empty_ta
  start = start.write(0, -1)

  # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
  with self.assertRaisesRegex(
      NotImplementedError,
      r"construct a new TensorArray inside the function"):
    dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))
def scan(initial_state, scan_func):
  """A transformation that scans a function across an input dataset.

  This transformation is a stateful relative of `tf.data.Dataset.map`.
  In addition to mapping `scan_func` across the elements of the input dataset,
  `scan()` accumulates one or more state tensors, whose initial values are
  `initial_state`.

  Args:
    initial_state: A nested structure of tensors, representing the initial state
      of the accumulator.
    scan_func: A function that maps `(old_state, input_element)` to
      `(new_state, output_element)`. It must take two arguments and return a
      pair of nested structures of tensors. The `new_state` must match the
      structure of `initial_state`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  # Thin wrapper: all of the work is delegated to `scan_ops.scan`.
  return scan_ops.scan(initial_state, scan_func)
def _estimate_initial_dist_ds(
    target_dist_t, class_values_ds, dist_estimation_batch_size=32,
    smoothing_constant=10):
  """Builds a dataset of running class-distribution estimates.

  Per-class counts start at `smoothing_constant` and are updated one batch at
  a time by `_estimate_data_distribution`, which is defined elsewhere. It is
  assumed to return `(updated_counts, distribution)`. Each batch's estimate is
  repeated once per batch row and the result is unbatched.

  Args:
    target_dist_t: Target distribution tensor; only its leading dimension (the
      number of classes) is used here.
    class_values_ds: Dataset of class values.
    dist_estimation_batch_size: Number of elements per estimation batch.
    smoothing_constant: Initial per-class count.

  Returns:
    A dataset with one distribution estimate per unbatched row.
  """
  # Prefer the static class count and fall back to the dynamic shape.
  static_num_classes = target_dist_t.shape[0].value
  num_classes = static_num_classes or array_ops.shape(target_dist_t)[0]
  initial_counts = array_ops.fill([num_classes], np.int64(smoothing_constant))

  def update_estimate_and_tile(num_examples_per_class_seen, c):
    new_counts, dist = _estimate_data_distribution(
        c, num_examples_per_class_seen)
    # Repeat the batch's estimate once per row so unbatching pairs each row
    # with an estimate.
    tiled_dist = array_ops.tile(
        array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
    return new_counts, tiled_dist

  batched = class_values_ds.batch(dist_estimation_batch_size)
  scanned = batched.apply(
      scan_ops.scan(initial_counts, update_estimate_and_tile))
  return scanned.apply(batching.unbatch())
def scan(initial_state, scan_func):
  """A transformation that scans a function across an input dataset.

  This transformation is a stateful relative of `tf.data.Dataset.map`.
  In addition to mapping `scan_func` across the elements of the input dataset,
  `scan()` accumulates one or more state tensors, whose initial values are
  `initial_state`.

  Args:
    initial_state: A nested structure of tensors, representing the initial state
      of the accumulator.
    scan_func: A function that maps `(old_state, input_element)` to
      `(new_state, output_element)`. It must take two arguments and return a
      pair of nested structures of tensors. The `new_state` must match the
      structure of `initial_state`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  # Thin wrapper: all of the work is delegated to `scan_ops.scan`.
  return scan_ops.scan(initial_state, scan_func)
def CounterV2(start=0, step=1, dtype=dtypes.int64):
  """Creates a `Dataset` that counts from `start` in steps of size `step`.

  Unlike `tf.data.Dataset.range` which will stop at some ending number,
  `Counter` will produce elements indefinitely.

  >>> dataset = tf.data.experimental.Counter().take(5)
  >>> list(dataset.as_numpy_iterator())
  [0, 1, 2, 3, 4]
  >>> dataset.element_spec
  TensorSpec(shape=(), dtype=tf.int64, name=None)
  >>> dataset = tf.data.experimental.Counter(dtype=tf.int32)
  >>> dataset.element_spec
  TensorSpec(shape=(), dtype=tf.int32, name=None)
  >>> dataset = tf.data.experimental.Counter(start=2).take(5)
  >>> list(dataset.as_numpy_iterator())
  [2, 3, 4, 5, 6]
  >>> dataset = tf.data.experimental.Counter(start=2, step=5).take(5)
  >>> list(dataset.as_numpy_iterator())
  [2, 7, 12, 17, 22]
  >>> dataset = tf.data.experimental.Counter(start=10, step=-1).take(5)
  >>> list(dataset.as_numpy_iterator())
  [10, 9, 8, 7, 6]

  Args:
    start: (Optional.) The starting value for the counter. Defaults to 0.
    step: (Optional.) The step size for the counter. Defaults to 1.
    dtype: (Optional.) The data type for counter elements. Defaults to
      `tf.int64`.

  Returns:
    A `Dataset` of scalar `dtype` elements.
  """
  with ops.name_scope("counter"):
    start_t = ops.convert_to_tensor(start, dtype=dtype, name="start")
    step_t = ops.convert_to_tensor(step, dtype=dtype, name="step")

    def advance(current, _):
      # The scan state is the next value to emit: emit it, then add the step.
      return current + step_t, current

    endless = dataset_ops.Dataset.from_tensors(0).repeat(None)
    return endless.apply(scan_ops.scan(start_t, advance))
def _counting_dataset(self, start, scan_fn):
  """Returns an unbounded dataset of `scan_fn` outputs seeded with `start`."""
  endless_zeros = dataset_ops.Dataset.from_tensors(0).repeat()
  return endless_zeros.apply(scan_ops.scan(start, scan_fn))
def _build_dataset(self, num_elements):
  """Returns a dataset of the first `num_elements` Fibonacci numbers (1, 1, 2, ...)."""

  def next_fib(pair, _):
    # State is (F(k), F(k+1)); advance it and emit F(k+1).
    return [pair[1], pair[0] + pair[1]], pair[1]

  ones = dataset_ops.Dataset.from_tensors(1).repeat(num_elements)
  return ones.apply(scan_ops.scan([0, 1], next_fib))
def _build_dataset(self, num_elements):
  """Returns a dataset of the first `num_elements` Fibonacci numbers (1, 1, 2, ...)."""

  def next_fib(pair, _):
    # State is (F(k), F(k+1)); advance it and emit F(k+1).
    return [pair[1], pair[0] + pair[1]], pair[1]

  ones = dataset_ops.Dataset.from_tensors(1).repeat(num_elements)
  return ones.apply(scan_ops.scan([0, 1], next_fib))
def make_scan_dataset(var):
  """Returns a one-element dataset scanned with a counter state plus `var`."""

  def scan_fn(old_state, elem):
    # Increment the counter state; emit element + previous state + var.
    return old_state + 1, elem + old_state + var

  single = dataset_ops.Dataset.from_tensors(0)
  return single.apply(scan_ops.scan(0, scan_fn))
def make_scan_dataset(var):
  """Returns a one-element dataset scanned with a counter state plus `var`."""

  def scan_fn(old_state, elem):
    # Increment the counter state; emit element + previous state + var.
    return old_state + 1, elem + old_state + var

  single = dataset_ops.Dataset.from_tensors(0)
  return single.apply(scan_ops.scan(0, scan_fn))
def _counting_dataset(self, start, scan_fn):
  """Returns an unbounded dataset of `scan_fn` outputs seeded with `start`."""
  endless_zeros = dataset_ops.Dataset.from_tensors(0).repeat()
  return endless_zeros.apply(scan_ops.scan(start, scan_fn))