def testBoolDType(self, use_tf_func):
  tmp = os.path.join(tf.test.get_temp_dir(), 'bool')
  with tf.python_io.TFRecordWriter(tmp) as w:
    for i in range(50):
      w.write(pickle.dumps(True if i % 2 == 0 else False))

  # A record processor written in TF graph.
  def _process(record):
    bucket_key = 1
    num, = tf.py_func(pickle.loads, [record], [tf.bool])
    return [num], bucket_key

  # Samples random records from the data files and processes them
  # to generate batches.
  resource, out_types, output_tmpl = get_test_input(
      tmp, bucket_upper_bound=[1], processor=_process)

  def _get_batch():
    return generic_input.GenericInputV2GetNext(resource, out_types,
                                               output_tmpl)

  if use_tf_func:
    _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

  for _ in range(10):
    inputs_vals, _ = _get_batch()
    self.assertEqual(inputs_vals[0].dtype, bool)
def testBasic(self, use_nested_map, bucket_batch_limit, use_tf_func):
  resource, out_types, output_tmpl = setup_basic(
      use_nested_map=use_nested_map, bucket_batch_limit=bucket_batch_limit)

  def _get_batch():
    return run_basic(
        resource,
        use_nested_map=use_nested_map,
        out_types=out_types,
        output_tmpl=output_tmpl)

  if use_tf_func:
    _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

  record_seen = set()
  # Iterate for 1 epoch.
  for _ in range(100 // bucket_batch_limit + 1):
    ans_input_batch = _get_batch()
    for s in ans_input_batch.record:
      record_seen.add(s.numpy())
    self.assertEqual(ans_input_batch.source_id.shape, (bucket_batch_limit,))
    self.assertEqual(ans_input_batch.record.shape, (bucket_batch_limit,))
    self.assertEqual(ans_input_batch.num.shape, (bucket_batch_limit, 2))
    ans_vals = ans_input_batch.num
    self.assertAllEqual(np.square(ans_vals[:, 0]), ans_vals[:, 1])
  for i in range(100):
    self.assertIn(('%08d' % i).encode('utf-8'), record_seen)
def testWithinBatchMixing(self, use_tf_func):
  # Generate a couple of test files.
  def generate_test_data(tag, cnt):
    tmp = os.path.join(tf.test.get_temp_dir(), tag)
    with tf.python_io.TFRecordWriter(tmp) as w:
      for i in range(cnt):
        w.write(('%s:%08d' % (tag, i)).encode('utf-8'))
    return tmp

  path1 = generate_test_data('input1', 100)
  path2 = generate_test_data('input2', 200)
  path3 = generate_test_data('input3', 10)

  # A record processor written in TF graph.
  def _process(source_id, record):
    return py_utils.NestedMap(source_id=source_id, record=record), 1

  # Samples random records from the data files and processes them
  # to generate batches.
  resource, out_types, output_tmpl = generic_input.GenericInputV2Create(
      file_pattern=','.join(
          ['tfrecord:' + path1, 'tfrecord:' + path2, 'tfrecord:' + path3]),
      input_source_weights=[0.2, 0.3, 0.5],
      file_random_seed=0,
      file_buffer_size=32,
      file_parallelism=4,
      bucket_batch_limit=[8],
      bucket_upper_bound=[1],
      processor=_process)

  def _get_batch():
    return generic_input.GenericInputV2GetNext(resource, out_types,
                                               output_tmpl)

  if use_tf_func:
    _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

  source_id_count = collections.defaultdict(int)
  tags_count = collections.defaultdict(int)
  total_count = 10000
  for _ in range(total_count):
    ans_input_batch, ans_buckets = _get_batch()
    for s in ans_input_batch.source_id:
      # We use `numpy()` to get the Tensor's value.
      source_id_count[s.numpy()] += 1
    for s in ans_input_batch.record:
      # We use `numpy()` to get the Tensor's value.
      tags_count[s.numpy().split(b':')[0]] += 1
    self.assertEqual(ans_input_batch.source_id.shape, (8,))
    self.assertEqual(ans_input_batch.record.shape, (8,))
    self.assertAllEqual(ans_buckets, [1] * 8)

  self.assertEqual(sum(source_id_count.values()), total_count * 8)
  self.assertEqual(sum(tags_count.values()), total_count * 8)

  num_records = 8. * total_count
  self.assertAlmostEqual(tags_count[b'input1'] / num_records, 0.2, delta=0.01)
  self.assertAlmostEqual(tags_count[b'input2'] / num_records, 0.3, delta=0.01)
  self.assertAlmostEqual(tags_count[b'input3'] / num_records, 0.5, delta=0.01)
  self.assertAlmostEqual(source_id_count[0] / num_records, 0.2, delta=0.01)
  self.assertAlmostEqual(source_id_count[1] / num_records, 0.3, delta=0.01)
  self.assertAlmostEqual(source_id_count[2] / num_records, 0.5, delta=0.01)
def testDropRecordIfNegativeBucketKey(self, use_tf_func):

  def bucket_fn(num):
    # Drops the record if num[0] is odd.
    return tf.cond(
        tf.equal(tf.math.floormod(num[0], 2), 0), lambda: 1,
        lambda: -tf.cast(num[0], tf.int32))

  resource, out_types, output_tmpl = setup_basic(
      use_nested_map=False, bucket_fn=bucket_fn)

  def _get_batch():
    return run_basic(
        resource,
        use_nested_map=False,
        out_types=out_types,
        output_tmpl=output_tmpl)

  if use_tf_func:
    _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

  record_seen = set()
  for _ in range(100):
    ans_input_batch = _get_batch()
    for s in ans_input_batch.record:
      record_seen.add(s.numpy())
  for i in range(100):
    if i % 2 == 0:
      self.assertIn(('%08d' % i).encode('utf-8'), record_seen)
    else:
      self.assertNotIn(('%08d' % i).encode('utf-8'), record_seen)
def testPadding(self, use_tf_func):
  # Generate a test file w/ 50 records of different lengths.
  tmp = os.path.join(tf.test.get_temp_dir(), 'basic')
  with tf.python_io.TFRecordWriter(tmp) as w:
    for n in range(1, 50):
      w.write(pickle.dumps(np.full([n, 3, 3], n, np.int32)))

  # A record processor written in TF graph.
  def _process(record):
    num = tf.py_func(pickle.loads, [record], tf.int32)
    bucket_key = tf.shape(num)[0]
    return [num, tf.transpose(num, [1, 0, 2])], bucket_key

  # Samples random records from the data files and processes them
  # to generate batches.
  resource, out_types, output_tmpl = get_test_input(
      tmp,
      bucket_upper_bound=[10],
      processor=_process,
      dynamic_padding_dimensions=[0, 1],
      dynamic_padding_constants=[0] * 2)

  def _get_batch():
    return generic_input.GenericInputV2GetNext(resource, out_types,
                                               output_tmpl)

  if use_tf_func:
    _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

  for _ in range(10):
    (vals, transposed_vals), _ = _get_batch()
    self.assertEqual(vals.shape[0], 8)
    self.assertEqual(vals.shape[2], 3)
    self.assertEqual(vals.shape[3], 3)
    largest = np.amax(vals)
    self.assertLessEqual(largest, 10)
    self.assertEqual(vals.shape[1], largest)
    for j in range(8):
      n = vals[j, 0, 0, 0]
      self.assertTrue(np.all(vals[j, :n] == n))
      self.assertTrue(np.all(vals[j, n:] == 0))
    self.assertAllEqual(vals, np.transpose(transposed_vals, [0, 2, 1, 3]))
def __init__(self, params):
  super().__init__(params)
  p = self.params
  for required_param_name in ('name', 'module_path'):
    if not p.Get(required_param_name):
      raise ValueError(f'Must set {required_param_name} param.')

  with tf.variable_scope(p.name):
    # NB: `trainable` merely controls whether the model *can* be run in
    # training mode.
    self._module = hub.KerasLayer(p.module_path, trainable=True)
  _WrapNonLingvoVars(
      self,
      variables=self._module.variables,
      trainable_variables=self._module.trainable_variables)

  # Functionalize the module's __call__ so train-mode update ops run eagerly.
  self._hub_module_fn = tf.function(
      lambda images, training: self._module(images, training=training))
def testTfData(self, use_tf_func):
  """Checks that GenericInput can be invoked from a tf.data.Dataset."""
  resource, out_types, output_tmpl = setup_basic(use_nested_map=True)

  def _get_batch():
    return run_basic(
        resource,
        use_nested_map=False,
        out_types=out_types,
        output_tmpl=output_tmpl)

  if use_tf_func:
    _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

  # Trick to create a dataset from a tensor coming from a custom op.
  dummy_dataset = tf.data.Dataset.from_tensors(0).repeat()
  dataset = dummy_dataset.map(lambda _: _get_batch())

  it = iter(dataset)
  for _ in range(10):  # Read 10 batches.
    print(it.get_next())