Beispiel #1
0
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, init_state):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  Folds `body` over the elements of `ds`, starting from `init_state`, for as
  long as `extra_test` (called on the unpacked state) returns true.

  Returns:
    The state produced by the last iteration for which `extra_test` held,
    or `init_state` if it never held.
  """

  def scan_body(loop_state, element):
    keep_going = extra_test(*loop_state)

    def _run_body():
      return body(element, *loop_state)

    def _keep_state():
      return loop_state

    next_state = control_flow_ops.cond(keep_going, _run_body, _keep_state)
    # scan carries `next_state` forward; it also emits it, paired with the
    # stop flag, so take_while and reduce can see both.
    return next_state, (next_state, keep_going)

  def take_while_predicate(unused_state, keep_going):
    return keep_going

  def reduce_body(unused_accumulator, scan_output):
    # The flag is no longer needed; keep only the emitted state.
    return scan_output[0]

  scanned = ds.apply(scan_ops.scan(init_state, scan_body))
  truncated = scanned.apply(take_while_ops.take_while(take_while_predicate))
  return truncated.reduce(init_state, reduce_body)
Beispiel #2
0
def CounterV2(start=0, step=1, dtype=dtypes.int64):
    """Creates a `Dataset` that counts from `start` in steps of size `step`.

    The resulting dataset never ends. Each element is the current counter
    value; the counter is then advanced by `step`.

    For example (half-open notation for an unbounded sequence):

    ```python
    CounterV2() == [0, 1, 2, ...)
    CounterV2(2) == [2, 3, ...)
    CounterV2(2, 5) == [2, 7, 12, ...)
    CounterV2(0, -1) == [0, -1, -2, ...)
    CounterV2(10, -1) == [10, 9, ...)
    ```

    Args:
      start: (Optional.) The starting value for the counter. Defaults to 0.
      step: (Optional.) The step size for the counter. Defaults to 1.
      dtype: (Optional.) The data type for counter elements. Defaults to
        `tf.int64`.

    Returns:
      A `Dataset` of scalar `dtype` elements.
    """
    with ops.name_scope("counter"):
        start_t = ops.convert_to_tensor(start, dtype=dtype, name="start")
        step_t = ops.convert_to_tensor(step, dtype=dtype, name="step")

        def emit_then_advance(current, _unused_element):
            # New scan state first, then the value emitted for this element.
            return current + step_t, current

        endless = dataset_ops.Dataset.from_tensors(0).repeat(None)
        return endless.apply(scan_ops.scan(start_t, emit_then_advance))
Beispiel #3
0
  def testChangingStateShape(self):
    """Scan state whose shape changes every step gets relaxed static shapes.

    The first state component doubles in length each step (static rank,
    dynamic length). The second gains a leading dimension each step (static
    rank lost). Together they exercise the fixed-point shape-invariant
    calculation.
    """

    def _grow(state, value):
      vector, scalar_like = state[0], state[1]
      doubled = array_ops.concat([vector, vector], 0)
      promoted = array_ops.expand_dims(scalar_like, 0)
      return (doubled, promoted), (state, value)

    dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
        scan_ops.scan(([0], 1), _grow))
    shapes = dataset.output_shapes
    self.assertEqual([None], shapes[0][0].as_list())
    self.assertIs(None, shapes[0][1].ndims)
    self.assertEqual([], shapes[1].as_list())

    next_element = dataset.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for step in range(5):
        (vector_val, rank_val), _ = self.evaluate(next_element)
        self.assertAllEqual([0] * (2**step), vector_val)
        self.assertAllEqual(np.array(1, ndmin=step), rank_val)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
Beispiel #4
0
    def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
        """Resetting scan state to a captured outer TensorArray is rejected.

        With control flow v1, returning a TensorArray captured from outside
        the scan function must raise NotImplementedError. The message asks the
        user to construct a new TensorArray inside the function.
        """
        if control_flow_v2_toggles.control_flow_v2_enabled():
            self.skipTest("v1 only test")

        empty_ta = tensor_array_ops.TensorArray(size=0,
                                                element_shape=[],
                                                dtype=dtypes.int64,
                                                dynamic_size=True)

        def scan_fn(ta, x):
            updated = ta.write(ta.size(), x)
            # Here, capture empty_ta from outside the function.  However, it may be
            # either a TF1-style TensorArray or an Eager-style TensorArray.
            next_iter = control_flow_ops.cond(math_ops.equal(x % 3, 0),
                                              lambda: empty_ta,
                                              lambda: updated)
            return (next_iter, updated.stack())

        start = empty_ta
        start = start.write(0, -1)

        # Use assertRaisesRegex: the assertRaisesRegexp alias was deprecated
        # and removed in Python 3.12.
        with self.assertRaisesRegex(
                NotImplementedError,
                r"construct a new TensorArray inside the function"):
            dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))
Beispiel #5
0
def Counter(start=0, step=1, dtype=dtypes.int64):
  """Creates a `Dataset` that counts from `start` in steps of size `step`.

  The dataset is unbounded. Each element is the current counter value; the
  counter then advances by `step`.

  For example (half-open notation for an unbounded sequence):

  ```python
  Counter() == [0, 1, 2, ...)
  Counter(2) == [2, 3, ...)
  Counter(2, 5) == [2, 7, 12, ...)
  Counter(0, -1) == [0, -1, -2, ...)
  Counter(10, -1) == [10, 9, ...)
  ```

  Args:
    start: (Optional.) The starting value for the counter. Defaults to 0.
    step: (Optional.) The step size for the counter. Defaults to 1.
    dtype: (Optional.) The data type for counter elements. Defaults to
      `tf.int64`.

  Returns:
    A `Dataset` of scalar `dtype` elements.
  """
  with ops.name_scope("counter"):
    first = ops.convert_to_tensor(start, dtype=dtype, name="start")
    delta = ops.convert_to_tensor(step, dtype=dtype, name="step")

    def _advance(current, _):
      return current + delta, current

    endless = dataset_ops.Dataset.from_tensors(0).repeat(None)
    return endless.apply(scan_ops.scan(first, _advance))
Beispiel #6
0
  def testTensorArrayWithCondReset(self):
    """A TensorArray scan state can be swapped for a fresh one via cond."""

    def fresh_array():
      return tensor_array_ops.TensorArray(
          size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

    def scan_fn(acc, x):
      appended = acc.write(acc.size(), x)
      # After writing a multiple of 3, accumulation restarts from empty.
      carried = control_flow_ops.cond(
          math_ops.equal(x % 3, 0), fresh_array, lambda: appended)
      return carried, appended.stack()

    initial = fresh_array().write(0, -1)
    ds = dataset_ops.Dataset.range(6).apply(scan_ops.scan(initial, scan_fn))

    expected = [[-1, 0], [1], [1, 2], [1, 2, 3], [4], [4, 5]]
    self.assertDatasetProduces(
        ds,
        expected_output=expected,
        requires_initialization=True,
        num_test_iterations=2)
Beispiel #7
0
  def testTensorArraySimple(self):
    """A TensorArray can be threaded through scan as accumulating state."""

    def append_and_emit(acc, value):
      # The emitted stack reflects the contents *before* `value` is appended.
      return acc.write(acc.size(), value), acc.stack()

    seed = tensor_array_ops.TensorArray(
        size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)
    seed = seed.write(0, -1)

    ds = dataset_ops.Dataset.range(5).apply(
        scan_ops.scan(seed, append_and_emit))

    # Element n is [-1, 0, ..., n - 1].
    expected = [[-1] + list(range(n)) for n in range(5)]
    self.assertDatasetProduces(
        ds,
        expected_output=expected,
        requires_initialization=True,
        num_test_iterations=2)
Beispiel #8
0
  def testTensorArraySimple(self):
    """Scan emits the growing contents of a TensorArray state."""

    def scan_fn(array, elem):
      grown = array.write(array.size(), elem)
      snapshot = array.stack()
      return grown, snapshot

    initial = tensor_array_ops.TensorArray(
        size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)
    initial = initial.write(0, -1)

    source = dataset_ops.Dataset.range(5)
    ds = source.apply(scan_ops.scan(initial, scan_fn))

    self.assertDatasetProduces(
        ds,
        expected_output=[
            [-1],
            [-1, 0],
            [-1, 0, 1],
            [-1, 0, 1, 2],
            [-1, 0, 1, 2, 3],
        ],
        requires_initialization=True,
        num_test_iterations=2)
Beispiel #9
0
  def testTensorArrayWithCondReset(self):
    """cond may return a newly constructed TensorArray as the next state."""

    def make_empty():
      return tensor_array_ops.TensorArray(
          size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

    def scan_fn(array, elem):
      grown = array.write(array.size(), elem)
      is_multiple_of_three = math_ops.equal(elem % 3, 0)
      next_array = control_flow_ops.cond(
          is_multiple_of_three, make_empty, lambda: grown)
      return next_array, grown.stack()

    initial = make_empty()
    initial = initial.write(0, -1)

    source = dataset_ops.Dataset.range(6)
    ds = source.apply(scan_ops.scan(initial, scan_fn))

    self.assertDatasetProduces(
        ds,
        expected_output=[
            [-1, 0],
            [1],
            [1, 2],
            [1, 2, 3],
            [4],
            [4, 5],
        ],
        requires_initialization=True,
        num_test_iterations=2)
Beispiel #10
0
  def testChangingStateShape(self):
    """Scan state with a changing shape gets relaxed static output shapes.

    Each step doubles the length of the first state component (static rank,
    dynamic length) and adds a leading dimension to the second (static rank
    lost), exercising the fixed-point shape-invariant calculation.
    """

    def _scan_fn(state, input_value):
      vector, scalar_like = state[0], state[1]
      longer = array_ops.concat([vector, vector], 0)
      higher_rank = array_ops.expand_dims(scalar_like, 0)
      return (longer, higher_rank), (state, input_value)

    dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
        scan_ops.scan(([0], 1), _scan_fn))
    legacy_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    self.assertEqual([None], legacy_shapes[0][0].as_list())
    self.assertIs(None, legacy_shapes[0][1].ndims)
    self.assertEqual([], legacy_shapes[1].as_list())

    next_element = self.getNext(dataset)

    for step in range(5):
      (vector_val, rank_val), _ = self.evaluate(next_element())
      self.assertAllEqual([0] * (2**step), vector_val)
      self.assertAllEqual(np.array(1, ndmin=step), rank_val)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
Beispiel #11
0
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, init_state):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  Runs `body(element, *state)` over `ds`, beginning at `init_state`, and stops
  at the first element for which `extra_test(*state)` is false.

  Returns:
    The final state, i.e. the state after the last iteration whose
    `extra_test` held (or `init_state` if none did).
  """

  def step(current, element):
    should_run = extra_test(*current)
    updated = control_flow_ops.cond(
        should_run,
        lambda: body(element, *current),
        lambda: current)
    # `updated` is scan's carried state; the emitted element repeats it
    # together with the flag so downstream stages can inspect both.
    emitted = (updated, should_run)
    return updated, emitted

  def keep_while_true(unused_state, should_run):
    return should_run

  def take_latest(unused_previous, emitted):
    latest, unused_flag = emitted
    return latest

  pipeline = ds.apply(scan_ops.scan(init_state, step))
  pipeline = pipeline.apply(take_while_ops.take_while(keep_while_true))
  return pipeline.reduce(init_state, take_latest)
Beispiel #12
0
  def testStatefulScanNotCheckpointable(self):
    """A scan whose output comes from `_statefulBoolFunc` is not checkpointable."""

    def stateful_scan(state, element):
      # State passes through unchanged; only the output uses the helper.
      return state, self._statefulBoolFunc(element)

    dataset = dataset_ops.Dataset.range(10).apply(
        scan_ops.scan(0, stateful_scan))
    self._assertNotCheckpointable(dataset)
Beispiel #13
0
 def testUnsupportedTransformError(self, drop_remainder):
   """Rebatching a pipeline that ends in `scan` raises InvalidArgumentError."""
   batched = dataset_ops.Dataset.range(1024).batch(
       32, drop_remainder=drop_remainder)
   scanned = batched.apply(scan_ops.scan([0], lambda _, a: ([0], a)))
   with self.assertRaises(errors.InvalidArgumentError):
     rebatched_dataset = batching._RebatchDataset(scanned, num_workers=4)
     next_element = self.getNext(rebatched_dataset)
     self.evaluate(next_element())
 def testUnsupportedTransformError(self, drop_remainder):
   """Expects `_RebatchDataset` over a scan to fail with InvalidArgumentError."""

   def reset_state(unused_state, batch):
     return [0], batch

   dataset = dataset_ops.Dataset.range(1024)
   dataset = dataset.batch(32, drop_remainder=drop_remainder)
   dataset = dataset.apply(scan_ops.scan([0], reset_state))
   with self.assertRaises(errors.InvalidArgumentError):
     rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
     get_next = self.getNext(rebatched_dataset)
     self.evaluate(get_next())
  def testScanAfterBatch(self):
    """Rebatching splits each scanned batch of 10 into two batches of 5."""

    def _scale(state, value):
      return state, value * state

    dataset = dataset_ops.Dataset.range(40).batch(10).apply(
        scan_ops.scan(np.int64(2), _scale))
    dataset = distribute._RebatchDataset(dataset, num_replicas=2)

    flat = [shape.as_list() for shape in _flat_shapes(dataset)]
    self.assertEqual([[None]], flat)
    expected_output = [
        [2 * i for i in range(5 * j, 5 * (j + 1))] for j in range(8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(dataset, expected_output)
Beispiel #16
0
  def testFibonacci(self):
    """Scan with a two-element state produces the Fibonacci sequence."""

    def _fib_step(pair, _):
      prev, curr = pair[0], pair[1]
      return [curr, prev + curr], curr

    data = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
        scan_ops.scan([0, 1], _fib_step))
    next_element = self.getNext(data)

    for expected in (1, 1, 2, 3, 5, 8):
      self.assertEqual(expected, self.evaluate(next_element()))
Beispiel #17
0
  def testFibonacci(self):
    """The first six outputs of a Fibonacci scan are 1, 1, 2, 3, 5, 8."""
    endless_ones = dataset_ops.Dataset.from_tensors(1).repeat(None)
    data = endless_ones.apply(
        scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
    next_element = self.getNext(data)

    expected_prefix = [1, 1, 2, 3, 5, 8]
    for want in expected_prefix:
      got = self.evaluate(next_element())
      self.assertEqual(want, got)
Beispiel #18
0
  def testIncorrectStateType(self):
    """scan raises TypeError when the new state's dtype differs from initial.

    The scan function returns an int64 state while the initial state is
    int32.
    """

    def _scan_fn(state, _):
      return constant_op.constant(1, dtype=dtypes.int64), state

    dataset = dataset_ops.Dataset.range(10)
    # assertRaisesRegexp is a deprecated alias, removed in Python 3.12.
    with self.assertRaisesRegex(
        TypeError,
        "The element types for the new state must match the initial state."):
      dataset.apply(
          scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
Beispiel #19
0
  def testIncorrectStateType(self):
    """An int64 new state for an int32 initial state is a TypeError."""

    def _scan_fn(state, _):
      new_state = constant_op.constant(1, dtype=dtypes.int64)
      return new_state, state

    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(
        TypeError,
        "The element types for the new state must match the initial state."):
      initial_state = constant_op.constant(1, dtype=dtypes.int32)
      dataset.apply(scan_ops.scan(initial_state, _scan_fn))
Beispiel #20
0
  def testIncorrectReturnType(self):
    """scan raises TypeError when the function does not return a pair."""

    def _scan_fn(unused_state, unused_input_value):
      return constant_op.constant(1, dtype=dtypes.int64)

    dataset = dataset_ops.Dataset.range(10)
    # assertRaisesRegexp is a deprecated alias, removed in Python 3.12.
    with self.assertRaisesRegex(
        TypeError,
        "The scan function must return a pair comprising the new state and the "
        "output value."):
      dataset.apply(
          scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
Beispiel #21
0
  def testIncorrectReturnType(self):
    """A scan function returning a single tensor is rejected with TypeError."""

    def _scan_fn(unused_state, unused_input_value):
      lone_value = constant_op.constant(1, dtype=dtypes.int64)
      return lone_value

    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(
        TypeError,
        "The scan function must return a pair comprising the new state and the "
        "output value."):
      initial_state = constant_op.constant(1, dtype=dtypes.int32)
      dataset.apply(scan_ops.scan(initial_state, _scan_fn))
Beispiel #22
0
    def testPreserveCardinality(self):
        """StopIteration inside the scan output surfaces as an error.

        The test expects evaluating the first element to raise
        InvalidArgumentError, not to end the dataset.
        """

        def scan_fn(state, val):
            def raise_stop(_):
                raise StopIteration()

            output = script_ops.py_func(raise_stop, [val], dtypes.int64)
            return state, output

        source = dataset_ops.Dataset.from_tensors(0)
        dataset = source.apply(scan_ops.scan(constant_op.constant(1), scan_fn))
        get_next = self.getNext(dataset)
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(get_next())
Beispiel #23
0
  def testPreserveCardinality(self):
    """A py_func raising StopIteration in scan yields InvalidArgumentError."""

    def scan_fn(state, val):

      def always_stop(_):
        raise StopIteration()

      emitted = script_ops.py_func(always_stop, [val], dtypes.int64)
      return state, emitted

    initial_state = constant_op.constant(1)
    dataset = dataset_ops.Dataset.from_tensors(0).apply(
        scan_ops.scan(initial_state, scan_fn))
    get_next = self.getNext(dataset)
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())
Beispiel #24
0
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names):
    """Overload of _dataset_for_stmt with early stopping. See for_stmt.

    Iterates `body` over the elements of `ds` while `extra_test` holds. Both
    the loop variables and the external state (read via `get_state`, written
    via `set_state`) are threaded through a scan -> take_while -> reduce
    pipeline.

    Args:
      ds: the dataset being iterated.
      extra_test: callable taking the loop variables (unpacked). Iteration
        continues while it returns true.
      body: callable taking the current element followed by the loop
        variables (unpacked). Returns the new loop variables.
      get_state: zero-argument callable returning the current external state.
      set_state: callable that installs a given external state.
      init_vars: initial loop variables.
      basic_symbol_names: forwarded to `_verify_tf_loop_vars`.
      composite_symbol_names: forwarded to `_verify_tf_loop_vars`.

    Returns:
      The loop variables after the last iteration for which `extra_test` held,
      or `init_vars` if it never held. As a side effect, `set_state` is called
      with the matching final external state.
    """

    # TODO(mdan): Simplify this - following it is extremely difficult.

    def scan_body(aug_vars, iterate):
        """The main loop body wrapper. Only calculates the stop condition."""
        loop_vars, state = aug_vars

        def true_fn():
            # Install the carried state before running the body, so the body
            # sees the state from the previous iteration.
            set_state(state)
            outputs = body(iterate, *loop_vars)
            # Checks the body outputs against the loop vars (shapes excluded).
            _verify_tf_loop_vars(loop_vars + state,
                                 outputs + state,
                                 basic_symbol_names,
                                 composite_symbol_names,
                                 include_shapes=False)
            return outputs, get_state()

        extra_cond = extra_test(*loop_vars)
        # When the condition is false, vars and state pass through unchanged.
        new_vars, new_state = control_flow_ops.cond(extra_cond, true_fn,
                                                    lambda: (loop_vars, state))

        scan_outputs = new_vars, new_state, extra_cond
        # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
        # (hence the redundancy).
        # get_state will pull any mutations that body may have made.
        new_aug_vars = new_vars, new_state
        return new_aug_vars, scan_outputs

    def take_while_predicate(unused_loop_vars, unused_state, extra_cond):
        # Stop at the first element whose condition was false.
        return extra_cond

    def reduce_body(unused_aug_vars, scan_outputs):
        # Keep only the most recent (vars, state); the flag is not needed.
        output_aug_vars, output_state, extra_cond = scan_outputs
        del extra_cond
        return output_aug_vars, output_state

    init_state = get_state()
    aug_vars = init_vars, init_state
    ds = ds.apply(scan_ops.scan(aug_vars, scan_body))
    ds = ds.apply(take_while_ops.take_while(take_while_predicate))
    final_aug_vars = ds.reduce(aug_vars, reduce_body)
    final_vars, final_state = final_aug_vars
    # Publish the final external state back to the caller.
    set_state(final_state)
    return final_vars
Beispiel #25
0
  def testFibonacci(self):
    """A Fibonacci scan read through a one-shot iterator, eager or graph."""
    fib = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
        scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
    iterator = dataset_ops.make_one_shot_iterator(fib)

    if not context.executing_eagerly():
      # Graph mode: build the get_next op once and evaluate it repeatedly.
      get_next = iterator.get_next()
      next_element = lambda: get_next
    else:
      next_element = iterator.get_next

    for expected in [1, 1, 2, 3, 5, 8]:
      self.assertEqual(expected, self.evaluate(next_element()))
Beispiel #26
0
  def testFibonacci(self):
    """The scan emits Fibonacci numbers in both eager and graph modes."""

    def fib_step(pair, _):
      return [pair[1], pair[0] + pair[1]], pair[1]

    source = dataset_ops.Dataset.from_tensors(1).repeat(None)
    iterator = source.apply(
        scan_ops.scan([0, 1], fib_step)).make_one_shot_iterator()

    if context.executing_eagerly():
      next_element = iterator.get_next
    else:
      # Graph mode: reuse one get_next op for every evaluation.
      get_next = iterator.get_next()
      next_element = lambda: get_next

    for expected in (1, 1, 2, 3, 5, 8):
      self.assertEqual(expected, self.evaluate(next_element()))
Beispiel #27
0
def _estimate_initial_dist_ds(target_dist_t,
                              class_values_ds,
                              dist_estimation_batch_size=32,
                              smoothing_constant=10):
    """Returns a dataset of running class-distribution estimates.

    `class_values_ds` is batched into groups of `dist_estimation_batch_size`.
    A scan carries a per-class count vector, initialised to
    `smoothing_constant` for every class, and updates it with
    `_estimate_data_distribution` once per batch. The distribution that call
    returns is repeated `dist_estimation_batch_size` times and the result is
    unbatched.

    NOTE(review): tiling always uses the full batch size, so a short final
    batch still yields `dist_estimation_batch_size` elements. Confirm callers
    (e.g. a zip against the class values) tolerate that.

    Args:
      target_dist_t: tensor whose first dimension is the number of classes.
      class_values_ds: dataset of class values to estimate from.
      dist_estimation_batch_size: class values consumed per estimation step.
      smoothing_constant: initial count assigned to every class.

    Returns:
      A `Dataset` of per-element distribution estimates.
    """
    # Prefer the static class count; fall back to the dynamic shape.
    static_num_classes = target_dist_t.shape[0]
    num_classes = static_num_classes or array_ops.shape(target_dist_t)[0]
    seed_counts = array_ops.fill([num_classes], np.int64(smoothing_constant))

    def _scan_step(counts_so_far, batch_classes):
        new_counts, dist = _estimate_data_distribution(batch_classes,
                                                       counts_so_far)
        row = array_ops.expand_dims(dist, 0)
        return new_counts, array_ops.tile(row, [dist_estimation_batch_size, 1])

    batched = class_values_ds.batch(dist_estimation_batch_size)
    estimates = batched.apply(scan_ops.scan(seed_counts, _scan_step))
    return estimates.unbatch()
Beispiel #28
0
  def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
    """Resetting scan state to a captured outer TensorArray is rejected.

    The scan function returns `empty_ta`, which is captured from outside.
    Applying scan must raise NotImplementedError asking for a new TensorArray
    to be constructed inside the function.
    """

    empty_ta = tensor_array_ops.TensorArray(
        size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

    def scan_fn(ta, x):
      updated = ta.write(ta.size(), x)
      # Here, capture empty_ta from outside the function.  However, it may be
      # either a TF1-style TensorArray or an Eager-style TensorArray.
      next_iter = control_flow_ops.cond(
          math_ops.equal(x % 3, 0), lambda: empty_ta, lambda: updated)
      return (next_iter, updated.stack())

    start = empty_ta
    start = start.write(0, -1)

    # assertRaisesRegexp is a deprecated alias, removed in Python 3.12.
    with self.assertRaisesRegex(
        NotImplementedError,
        r"construct a new TensorArray inside the function"):
      dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))
Beispiel #29
0
def scan(initial_state, scan_func):
  """Builds a transformation that scans `scan_func` across an input dataset.

  A stateful counterpart of `tf.data.Dataset.map`. Besides mapping
  `scan_func` over the input elements, the transformation threads one or more
  state tensors from element to element, starting from `initial_state`.

  Args:
    initial_state: A nested structure of tensors, the accumulator's starting
      value.
    scan_func: A function mapping `(old_state, input_element)` to
      `(new_state, output_element)`. It takes two arguments and returns a pair
      of nested structures of tensors, where `new_state` has the same
      structure as `initial_state`.

  Returns:
    A `Dataset` transformation function, suitable for passing to
    `tf.data.Dataset.apply`.
  """
  transformation = scan_ops.scan(initial_state, scan_func)
  return transformation
Beispiel #30
0
def _estimate_initial_dist_ds(
    target_dist_t, class_values_ds, dist_estimation_batch_size=32,
    smoothing_constant=10):
  """Returns a dataset of running class-distribution estimates.

  Class values are grouped into batches of `dist_estimation_batch_size`. A
  scan keeps per-class counts, starting at `smoothing_constant` for every
  class, and refreshes them with `_estimate_data_distribution` each batch.
  The returned distribution is tiled once per batch row and then unbatched.

  NOTE(review): `shape[0].value` is the TF1 `Dimension` API. Under TF2 shape
  behaviour `shape[0]` is already an int or None, so confirm the TF version.

  Args:
    target_dist_t: tensor whose first dimension is the number of classes.
    class_values_ds: dataset of class values to estimate from.
    dist_estimation_batch_size: class values consumed per estimation step.
    smoothing_constant: initial count assigned to every class.

  Returns:
    A `Dataset` of per-element distribution estimates.
  """
  static_num_classes = target_dist_t.shape[0].value
  if static_num_classes:
    num_classes = static_num_classes
  else:
    num_classes = array_ops.shape(target_dist_t)[0]
  seed_counts = array_ops.fill([num_classes], np.int64(smoothing_constant))

  def _scan_step(counts_so_far, batch_classes):
    new_counts, dist = _estimate_data_distribution(
        batch_classes, counts_so_far)
    repeated = array_ops.tile(
        array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
    return new_counts, repeated

  batched = class_values_ds.batch(dist_estimation_batch_size)
  estimates = batched.apply(scan_ops.scan(seed_counts, _scan_step))
  return estimates.apply(batching.unbatch())
Beispiel #31
0
def scan(initial_state, scan_func):
    """Builds a transformation that scans `scan_func` across an input dataset.

    A stateful counterpart of `tf.data.Dataset.map`. Besides mapping
    `scan_func` over the input elements, the transformation threads one or
    more state tensors from element to element, starting from
    `initial_state`.

    Args:
      initial_state: A nested structure of tensors, the accumulator's
        starting value.
      scan_func: A function mapping `(old_state, input_element)` to
        `(new_state, output_element)`. It takes two arguments and returns a
        pair of nested structures of tensors, where `new_state` has the same
        structure as `initial_state`.

    Returns:
      A `Dataset` transformation function, suitable for passing to
      `tf.data.Dataset.apply`.
    """
    transformation = scan_ops.scan(initial_state, scan_func)
    return transformation
Beispiel #32
0
def CounterV2(start=0, step=1, dtype=dtypes.int64):
    """Creates a `Dataset` that counts from `start` in steps of size `step`.

    Unlike `tf.data.Dataset.range`, which stops at an end value, `Counter`
    produces elements indefinitely.

    >>> dataset = tf.data.experimental.Counter().take(5)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2, 3, 4]
    >>> dataset.element_spec
    TensorSpec(shape=(), dtype=tf.int64, name=None)
    >>> dataset = tf.data.experimental.Counter(dtype=tf.int32)
    >>> dataset.element_spec
    TensorSpec(shape=(), dtype=tf.int32, name=None)
    >>> dataset = tf.data.experimental.Counter(start=2).take(5)
    >>> list(dataset.as_numpy_iterator())
    [2, 3, 4, 5, 6]
    >>> dataset = tf.data.experimental.Counter(start=2, step=5).take(5)
    >>> list(dataset.as_numpy_iterator())
    [2, 7, 12, 17, 22]
    >>> dataset = tf.data.experimental.Counter(start=10, step=-1).take(5)
    >>> list(dataset.as_numpy_iterator())
    [10, 9, 8, 7, 6]

    Args:
      start: (Optional.) The starting value for the counter. Defaults to 0.
      step: (Optional.) The step size for the counter. Defaults to 1.
      dtype: (Optional.) The data type for counter elements. Defaults to
        `tf.int64`.

    Returns:
      A `Dataset` of scalar `dtype` elements.
    """
    with ops.name_scope("counter"):
        first_value = ops.convert_to_tensor(start, dtype=dtype, name="start")
        increment = ops.convert_to_tensor(step, dtype=dtype, name="step")

        def _tick(current, _):
            # Next state is advanced by `increment`; emit the current value.
            return current + increment, current

        endless = dataset_ops.Dataset.from_tensors(0).repeat(None)
        return endless.apply(scan_ops.scan(first_value, _tick))
Beispiel #33
0
 def _counting_dataset(self, start, scan_fn):
   """Returns an unbounded stream of zeros scanned by `scan_fn` from `start`."""
   zeros = dataset_ops.Dataset.from_tensors(0).repeat()
   return zeros.apply(scan_ops.scan(start, scan_fn))
Beispiel #34
0
 def _build_dataset(self, num_elements):
   """Returns a dataset of the first `num_elements` Fibonacci numbers."""

   def fib_step(pair, _):
     return [pair[1], pair[0] + pair[1]], pair[1]

   ones = dataset_ops.Dataset.from_tensors(1).repeat(num_elements)
   return ones.apply(scan_ops.scan([0, 1], fib_step))
 def _build_dataset(self, num_elements):
   """Builds a `num_elements`-long Fibonacci dataset via scan."""
   fibonacci = scan_ops.scan(
       [0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
   source = dataset_ops.Dataset.from_tensors(1).repeat(num_elements)
   return source.apply(fibonacci)
 def make_scan_dataset(var):
   """One-element dataset whose scan output is `elem + count + var`."""

   def step(count, elem):
     return count + 1, elem + count + var

   return dataset_ops.Dataset.from_tensors(0).apply(scan_ops.scan(0, step))
 def make_scan_dataset(var):
   """Scan over a single zero that counts steps and offsets output by `var`."""
   single = dataset_ops.Dataset.from_tensors(0)
   transform = scan_ops.scan(
       0, lambda old_state, elem: (old_state + 1, elem + old_state + var))
   return single.apply(transform)
Beispiel #38
0
 def _counting_dataset(self, start, scan_fn):
   """Applies `scan_fn` from `start` to an endlessly repeated zero."""
   counting = scan_ops.scan(start, scan_fn)
   return dataset_ops.Dataset.from_tensors(0).repeat().apply(counting)