def testWithinBatchMixing(self, use_tf_func):
  # Generate a couple of test files.
  def generate_test_data(tag, cnt):
    tmp = os.path.join(tf.test.get_temp_dir(), tag)
    with tf.python_io.TFRecordWriter(tmp) as w:
      for i in range(cnt):
        w.write(('%s:%08d' % (tag, i)).encode('utf-8'))
    return tmp

  path1 = generate_test_data('input1', 100)
  path2 = generate_test_data('input2', 200)
  path3 = generate_test_data('input3', 10)

  # A record processor written in TF graph.
  def _process(source_id, record):
    return py_utils.NestedMap(source_id=source_id, record=record), 1

  # Samples random records from the data files and processes them
  # to generate batches.
  resource, out_types, output_tmpl = generic_input.GenericInputV2Create(
      file_pattern=','.join(
          ['tfrecord:' + path1, 'tfrecord:' + path2, 'tfrecord:' + path3]),
      input_source_weights=[0.2, 0.3, 0.5],
      file_random_seed=0,
      file_buffer_size=32,
      file_parallelism=4,
      bucket_batch_limit=[8],
      bucket_upper_bound=[1],
      processor=_process)

  def _get_batch():
    return generic_input.GenericInputV2GetNext(resource, out_types,
                                               output_tmpl)

  if use_tf_func:
    _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

  source_id_count = collections.defaultdict(int)
  tags_count = collections.defaultdict(int)
  total_count = 10000
  for _ in range(total_count):
    ans_input_batch, ans_buckets = _get_batch()
    for s in ans_input_batch.source_id:
      # We use `numpy()` to get the Tensor's value.
      source_id_count[s.numpy()] += 1
    for s in ans_input_batch.record:
      # We use `numpy()` to get the Tensor's value.
      tags_count[s.numpy().split(b':')[0]] += 1
    self.assertEqual(ans_input_batch.source_id.shape, (8,))
    self.assertEqual(ans_input_batch.record.shape, (8,))
    self.assertAllEqual(ans_buckets, [1] * 8)

  self.assertEqual(sum(source_id_count.values()), total_count * 8)
  self.assertEqual(sum(tags_count.values()), total_count * 8)

  num_records = 8. * total_count
  self.assertAlmostEqual(tags_count[b'input1'] / num_records, 0.2, delta=0.01)
  self.assertAlmostEqual(tags_count[b'input2'] / num_records, 0.3, delta=0.01)
  self.assertAlmostEqual(tags_count[b'input3'] / num_records, 0.5, delta=0.01)
  self.assertAlmostEqual(source_id_count[0] / num_records, 0.2, delta=0.01)
  self.assertAlmostEqual(source_id_count[1] / num_records, 0.3, delta=0.01)
  self.assertAlmostEqual(source_id_count[2] / num_records, 0.5, delta=0.01)
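
# The block below is a sketch (not part of the original test) showing how the
# repeated assertAlmostEqual checks above could be factored into a reusable
# helper. The name `assert_mixing_fractions` and its signature are assumptions
# for illustration only, not Lingvo API.
def assert_mixing_fractions(test_case, counts, expected, num_records,
                            delta=0.01):
  """Checks that observed per-key counts match the expected mixing fractions.

  Args:
    test_case: a tf.test.TestCase instance providing assertAlmostEqual.
    counts: dict mapping a key (e.g. source id or tag) to its observed count.
    expected: dict mapping the same keys to their expected fractions.
    num_records: total number of records observed.
    delta: allowed absolute deviation from the expected fraction.
  """
  for key, fraction in expected.items():
    test_case.assertAlmostEqual(counts[key] / num_records, fraction,
                                delta=delta)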
def get_test_input(path, bucket_batch_limit=8, **kwargs):
  return generic_input.GenericInputV2Create(
      file_pattern='tfrecord:' + path,
      file_random_seed=0,
      file_buffer_size=32,
      file_parallelism=4,
      bucket_batch_limit=[bucket_batch_limit],
      **kwargs)
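
# A minimal usage sketch (assumed, not taken from the original tests): tests
# typically pair the resource created by `get_test_input` with
# `GenericInputV2GetNext` to pull batches eagerly. The wrapper name
# `_fetch_one_batch` and its arguments are illustrative only.
def _fetch_one_batch(path, processor):
  resource, out_types, output_tmpl = get_test_input(
      path, bucket_upper_bound=[1], processor=processor)
  # Returns (batch NestedMap, bucket keys), as in the tests above.
  return generic_input.GenericInputV2GetNext(resource, out_types, output_tmpl)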
def testFatalErrors(self):
  tmp = os.path.join(tf.test.get_temp_dir(), 'fatal')
  with tf.python_io.TFRecordWriter(tmp) as w:
    for i in range(50):
      w.write(str((i % 2) * 2**33))

  def _parse_record(record):
    # tf.strings.to_number raises an error on overflow.
    i = tf.strings.to_number(record, tf.int32)
    example = py_utils.NestedMap(record=i)
    bucketing_key = 1
    return example, bucketing_key

  # Without specifying fatal_errors, all records that are not 0 are skipped.
  resource, out_types, output_tmpl = generic_input.GenericInputV2Create(
      _parse_record,
      file_pattern=f'tfrecord:{tmp}',
      bucket_upper_bound=[1],
      bucket_batch_limit=[1])

  for i in range(25):
    ans_input_batch, _ = generic_input.GenericInputV2GetNext(
        resource, out_types, output_tmpl)
    self.assertEqual(ans_input_batch.record[0], 0)

  # With fatal_errors, it dies instead.
  resource, out_types, output_tmpl = generic_input.GenericInputV2Create(
      _parse_record,
      file_pattern=f'tfrecord:{tmp}',
      bucket_upper_bound=[1],
      bucket_batch_limit=[1],
      fatal_errors=['StringToNumberOp could not correctly convert string:'])

  # NOTE: There is no way to catch LOG(FATAL) from the Python side, so running
  # this test will cause a crash.
  for i in range(10):
    ans_input_batch, _ = generic_input.GenericInputV2GetNext(
        resource, out_types, output_tmpl)
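
# Since LOG(FATAL) cannot be caught from Python, one way to exercise the
# fatal_errors path without crashing the main test process is to run the
# crashing snippet in a subprocess and assert on its exit status. The helper
# below is a minimal sketch under that assumption; its name `_assert_dies`
# and the script passed to it are hypothetical, not part of the original test.
def _assert_dies(test_case, script):
  import subprocess
  import sys
  result = subprocess.run(
      [sys.executable, '-c', script],
      capture_output=True,
      check=False)
  # A LOG(FATAL) aborts the interpreter, so the exit code is non-zero.
  test_case.assertNotEqual(result.returncode, 0)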
def testNestedGenericInput(self, inner_batch_limit, outer_batch_limit):
  # Generate records using an inner GenericInput, and post-process them using
  # an outer one. Test that the generated records are complete and contain no
  # duplicates.
  resource_inner, out_types_inner, output_tmpl_inner = setup_basic(
      use_nested_map=True, bucket_batch_limit=inner_batch_limit)

  def _process(record):
    del record
    # Construct the inner GenericInput.
    batch = run_basic(
        resource_inner,
        use_nested_map=True,
        out_types=out_types_inner,
        output_tmpl=output_tmpl_inner)
    batch.num += 1
    return batch, 1

  # The `AssertionError` will be raised from `GenericInputV2Create()`.
  (resource_outer, out_types_outer,
   output_tmpl_outer) = generic_input.GenericInputV2Create(
       file_pattern='iota:',
       processor=_process,
       bucket_upper_bound=[1],
       bucket_batch_limit=[outer_batch_limit])

  global_batch = inner_batch_limit * outer_batch_limit
  record_seen = set()
  # Iterate the inputs for exactly one epoch.
  for i in range(100 // global_batch):
    ans_input_batch, _ = generic_input.GenericInputV2GetNext(
        resource_outer, out_types_outer, output_tmpl_outer)
    for record_array in ans_input_batch.record:
      for s in record_array:
        # There should not be duplicates since GenericInput is stateful.
        assert s.numpy() not in record_seen
        record_seen.add(s.numpy())
    self.assertEqual(ans_input_batch.source_id.shape,
                     (outer_batch_limit, inner_batch_limit))
    self.assertEqual(ans_input_batch.record.shape,
                     (outer_batch_limit, inner_batch_limit))
    self.assertEqual(ans_input_batch.num.shape,
                     (outer_batch_limit, inner_batch_limit, 2))
    ans_vals = ans_input_batch.num
    # `_process` added 1 to both columns of `num`; undoing that, the second
    # column must be the square of the first.
    self.assertAllEqual(
        np.square(ans_vals[:, :, 0] - 1), ans_vals[:, :, 1] - 1)
  for i in range(100):
    self.assertIn(('%08d' % i).encode('utf-8'), record_seen)
def testExtraArgs(self):

  def _parse_record_stateful(record):
    del record
    extra = tf.Variable(0)
    example = py_utils.NestedMap(t=extra.value())
    bucketing_key = 1
    return example, bucketing_key

  with self.assertRaisesRegex(AssertionError, 'is not pure: extra_args='):
    generic_input.GenericInputV2Create(
        _parse_record_stateful,
        file_pattern='',
        bucket_upper_bound=[1],
        bucket_batch_limit=[1])
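
# For contrast, a sketch (not part of the original tests) of a processor that
# passes the purity check: it creates no stateful ops such as `tf.Variable`
# and only transforms its inputs. The name `_parse_record_pure` is
# illustrative only.
def _parse_record_pure(record):
  example = py_utils.NestedMap(record=record)
  bucketing_key = 1
  return example, bucketing_key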