def testWithinBatchMixing(self, use_tf_func):
    """Checks that batches mix the input sources by `input_source_weights`.

    Three TFRecord files of different sizes are sampled with weights
    0.2 / 0.3 / 0.5. Over many batches, two distributions must match those
    weights to within 0.01:
      - the tag embedded in each record;
      - the `source_id` handed to the processor.

    Args:
      use_tf_func: If true, the GetNext call is wrapped in `tf.function`.
    """
    batch_size = 8
    tags = ('input1', 'input2', 'input3')
    sizes = (100, 200, 10)
    weights = (0.2, 0.3, 0.5)

    def write_records(tag, num_records):
      # Writes records of the form b'<tag>:<8-digit index>' and returns the
      # file path.
      filename = os.path.join(tf.test.get_temp_dir(), tag)
      with tf.python_io.TFRecordWriter(filename) as writer:
        for idx in range(num_records):
          writer.write(('%s:%08d' % (tag, idx)).encode('utf-8'))
      return filename

    paths = [write_records(tag, size) for tag, size in zip(tags, sizes)]

    def _process(source_id, record):
      # Passes each record through together with its source id; every record
      # uses bucketing key 1.
      return py_utils.NestedMap(source_id=source_id, record=record), 1

    # Sample records from all three files according to `weights` and batch
    # them into a single bucket of size `batch_size`.
    resource, out_types, output_tmpl = generic_input.GenericInputV2Create(
        file_pattern=','.join('tfrecord:' + p for p in paths),
        input_source_weights=list(weights),
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[batch_size],
        bucket_upper_bound=[1],
        processor=_process)

    def get_batch():
      return generic_input.GenericInputV2GetNext(resource, out_types,
                                                 output_tmpl)

    if use_tf_func:
      get_batch = tf.function(autograph=False)(get_batch)

    source_id_count = collections.Counter()
    tags_count = collections.Counter()
    total_count = 10000
    for _ in range(total_count):
      batch, buckets = get_batch()
      # `numpy()` extracts the Python value from each eager tensor element.
      for sid in batch.source_id:
        source_id_count[sid.numpy()] += 1
      for rec in batch.record:
        tag = rec.numpy().split(b':')[0]
        tags_count[tag] += 1
      self.assertEqual(batch.source_id.shape, (batch_size,))
      self.assertEqual(batch.record.shape, (batch_size,))
      self.assertAllEqual(buckets, [1] * batch_size)

    self.assertEqual(sum(source_id_count.values()), total_count * batch_size)
    self.assertEqual(sum(tags_count.values()), total_count * batch_size)
    num_records = float(batch_size * total_count)
    for tag, weight in zip(tags, weights):
      self.assertAlmostEqual(
          tags_count[tag.encode('utf-8')] / num_records, weight, delta=0.01)
    for source_id, weight in enumerate(weights):
      self.assertAlmostEqual(
          source_id_count[source_id] / num_records, weight, delta=0.01)
def get_test_input(path, bucket_batch_limit=8, **kwargs):
  """Creates a GenericInputV2 resource that reads a single TFRecord file.

  File reading uses fixed test settings: seed 0, buffer size 32 and
  parallelism 4.

  Args:
    path: Path of the TFRecord file to read.
    bucket_batch_limit: Batch size of the single bucket.
    **kwargs: Forwarded unchanged to `generic_input.GenericInputV2Create`,
      e.g. `processor` or `bucket_upper_bound`. Passing one of the keywords
      set here raises TypeError (duplicate keyword argument).

  Returns:
    The result of `generic_input.GenericInputV2Create`. Elsewhere in this file
    it is unpacked as `(resource, out_types, output_tmpl)`.
  """
  file_pattern = 'tfrecord:' + path
  batch_limits = [bucket_batch_limit]
  return generic_input.GenericInputV2Create(
      file_pattern=file_pattern,
      file_random_seed=0,
      file_buffer_size=32,
      file_parallelism=4,
      bucket_batch_limit=batch_limits,
      **kwargs)
  def testFatalErrors(self):
    """Checks how processor errors are handled with and without fatal_errors.

    Odd records hold 2**33, which overflows int32 in `tf.strings.to_number`.
    Without `fatal_errors`, those records are skipped. When the error message
    is listed in `fatal_errors`, the input op dies instead (LOG(FATAL)), so
    the second half of this test crashes the process.
    """
    tmp = os.path.join(tf.test.get_temp_dir(), 'fatal')
    with tf.python_io.TFRecordWriter(tmp) as w:
      for i in range(50):
        # Even records are b'0'; odd records are b'8589934592' (2**33).
        w.write(str((i % 2) * 2**33).encode('utf-8'))

    def _parse_record(record):
      # tf.strings.to_number raises an error on int32 overflow.
      i = tf.strings.to_number(record, tf.int32)
      example = py_utils.NestedMap(record=i)
      bucketing_key = 1
      return example, bucketing_key

    def _create_input(**extra_kwargs):
      # Both scenarios below share this configuration; they differ only in
      # the optional `fatal_errors` argument.
      return generic_input.GenericInputV2Create(
          _parse_record,
          file_pattern=f'tfrecord:{tmp}',
          bucket_upper_bound=[1],
          bucket_batch_limit=[1],
          **extra_kwargs)

    # Without specifying fatal_errors, all records that are not 0 are skipped.
    # The 25 zero records therefore arrive one per batch.
    resource, out_types, output_tmpl = _create_input()
    for _ in range(25):
      ans_input_batch, unused_buckets = generic_input.GenericInputV2GetNext(
          resource, out_types, output_tmpl)
      self.assertEqual(ans_input_batch.record[0], 0)

    # With fatal_errors it dies instead.
    resource, out_types, output_tmpl = _create_input(
        fatal_errors=['StringToNumberOp could not correctly convert string:'])

    # NOTE: There is no way to catch LOG(FATAL) from python side, so running
    # this test will cause a crash.
    for _ in range(10):
      generic_input.GenericInputV2GetNext(resource, out_types, output_tmpl)
  def testNestedGenericInput(self, inner_batch_limit, outer_batch_limit):
    """Post-processes batches of an inner GenericInput with an outer one.

    Each call of the outer processor pulls one whole batch from the inner
    input. The test checks two things:
      - completeness: every record b'00000000'..b'00000099' is seen;
      - no duplicates, since GenericInput is stateful.

    Args:
      inner_batch_limit: Batch size of the inner GenericInput.
      outer_batch_limit: Batch size of the outer GenericInput.
    """
    resource_inner, out_types_inner, output_tmpl_inner = setup_basic(
        use_nested_map=True, bucket_batch_limit=inner_batch_limit)

    def _process(record):
      del record  # The outer `iota:` record only triggers the call.
      # Run (not construct) the inner GenericInput to fetch one batch.
      batch = run_basic(
          resource_inner,
          use_nested_map=True,
          out_types=out_types_inner,
          output_tmpl=output_tmpl_inner)
      batch.num += 1
      return batch, 1

    # NOTE(review): the original comment here said "The `AssertionError` error
    # will be raised from `GenericInputV2Create()`". The likely reason is that
    # `_process` captures `resource_inner`. However, nothing below expects
    # that error, and the rest of the test assumes creation succeeds. Confirm
    # which behavior this test is meant to pin.
    (resource_outer, out_types_outer,
     output_tmpl_outer) = generic_input.GenericInputV2Create(
         file_pattern='iota:',
         processor=_process,
         bucket_upper_bound=[1],
         bucket_batch_limit=[outer_batch_limit])

    global_batch = inner_batch_limit * outer_batch_limit
    record_seen = set()
    # Iterate the inputs for one epoch. This covers all 100 records only when
    # `global_batch` divides 100; otherwise the completeness check below
    # cannot pass.
    for _ in range(100 // global_batch):
      ans_input_batch, unused_buckets = generic_input.GenericInputV2GetNext(
          resource_outer, out_types_outer, output_tmpl_outer)
      for record_array in ans_input_batch.record:
        for s in record_array:
          value = s.numpy()
          # There should not be duplicates since GenericInput is stateful.
          # A bare `assert` is stripped under `python -O` and reports nothing
          # on failure, so use the TestCase assertion instead.
          self.assertNotIn(value, record_seen)
          record_seen.add(value)
      self.assertEqual(ans_input_batch.source_id.shape,
                       (outer_batch_limit, inner_batch_limit))
      self.assertEqual(ans_input_batch.record.shape,
                       (outer_batch_limit, inner_batch_limit))
      self.assertEqual(ans_input_batch.num.shape,
                       (outer_batch_limit, inner_batch_limit, 2))
      ans_vals = ans_input_batch.num
      self.assertAllEqual(
          np.square(ans_vals[:, :, 0] - 1), ans_vals[:, :, 1] - 1)
    for i in range(100):
      self.assertIn(('%08d' % i).encode('utf-8'), record_seen)
  def testExtraArgs(self):
    """Checks that GenericInputV2Create rejects a stateful processor.

    The processor creates and reads a `tf.Variable`. Creation is expected to
    fail with an AssertionError matching 'is not pure: extra_args='.
    """

    def _parse_record_stateful(record):
      del record  # The output depends only on the variable, not the record.
      state = tf.Variable(0)
      return py_utils.NestedMap(t=state.value()), 1

    expected_message = 'is not pure: extra_args='
    with self.assertRaisesRegex(AssertionError, expected_message):
      generic_input.GenericInputV2Create(
          _parse_record_stateful,
          file_pattern='',
          bucket_upper_bound=[1],
          bucket_batch_limit=[1])