def testExtraArgs(self):

  def _parse_record(record):
    del record
    example = py_utils.NestedMap(t=tf.convert_to_tensor(0))
    bucketing_key = 1
    return example, bucketing_key

  def _parse_record_stateful(record):
    del record
    extra = tf.Variable(0)
    example = py_utils.NestedMap(t=extra.value())
    bucketing_key = 1
    return example, bucketing_key

  generic_input.GenericInput(
      _parse_record,
      file_pattern='',
      bucket_upper_bound=[1],
      bucket_batch_limit=[1])

  with self.assertRaisesRegex(AssertionError, 'is not pure: extra_args='):
    generic_input.GenericInput(
        _parse_record_stateful,
        file_pattern='',
        bucket_upper_bound=[1],
        bucket_batch_limit=[1])
def _DataSourceFromFilePattern(self, file_pattern): def Proc(record): """Parses a serialized tf.Example record.""" # There we go! string, string, float32. I hope frames is allowed # to be a waveform directly... features = [ ('uttid', tf.io.VarLenFeature(tf.int64)), # Would like to change this to tf.int16 in the future, if that is possible (would have to read from ('frames', tf.io.VarLenFeature(tf.float32)), ] example = tf.io.parse_single_example(record, dict(features)) fval = {k: v.values for k, v in example.items()} # Reshape the flattened vector into its original time-major # representation. fval['frames'] = tf.reshape(fval['frames'], shape=[-1, self.params.frame_size]) # Input duration determines the bucket. bucket_key = tf.cast(tf.shape(fval['frames'])[0], tf.int32) if self.params.append_eos_frame: bucket_key += 1 src_paddings = tf.zeros([tf.shape(fval['frames'])[0]], dtype=tf.float32) return [fval['uttid'], fval['frames'], src_paddings], bucket_key return generic_input.GenericInput(file_pattern=file_pattern, processor=Proc, dynamic_padding_dimensions=[0] * 3, dynamic_padding_constants=[0] * 2 + [1], **self.CommonInputOpArgs())
def _DataSourceFromFilePattern(self, file_pattern): """Create the input processing op. Args: file_pattern: The file pattern to use as input. Returns: an operation that when executed, calls `_ProcessLine` on a line read from `file_pattern`. """ ret = py_utils.NestedMap() (src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights), ret.bucket_keys = generic_input.GenericInput( file_pattern=file_pattern, processor=self._ProcessLine, # Pad dimension 0 to the same length. dynamic_padding_dimensions=[0] * 6, # The constant values to use for padding each of the outputs. dynamic_padding_constants=[0, 1, 0, 1, 0, 0], **self.CommonInputOpArgs()) ret.src = py_utils.NestedMap() ret.src.ids = tf.cast(src_ids, dtype=tf.int32) ret.src.paddings = src_paddings ret.tgt = py_utils.NestedMap() ret.tgt.ids = tgt_ids ret.tgt.labels = tf.cast(tgt_labels, dtype=tf.int32) ret.tgt.weights = tgt_weights ret.tgt.paddings = tgt_paddings return ret
def _DataSourceFromFilePattern(self, file_pattern): def Proc(record): """Parses a serialized tf.Example record.""" features = [ ('uttid', tf.io.VarLenFeature(tf.string)), ('transcript', tf.io.VarLenFeature(tf.string)), ('frames', tf.io.VarLenFeature(tf.float32)), ] example = tf.io.parse_single_example(record, dict(features)) fval = {k: v.values for k, v in six.iteritems(example)} # Reshape the flattened vector into its original time-major # representation. fval['frames'] = tf.reshape(fval['frames'], shape=[-1, self.params.frame_size]) # Input duration determines the bucket. bucket_key = tf.cast(tf.shape(fval['frames'])[0], tf.int32) if self.params.append_eos_frame: bucket_key += 1 tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds( fval['transcript']) src_paddings = tf.zeros([tf.shape(fval['frames'])[0]], dtype=tf.float32) return [ fval['uttid'], tgt_ids, tgt_labels, tgt_paddings, fval['frames'], src_paddings ], bucket_key return generic_input.GenericInput(file_pattern=file_pattern, processor=Proc, dynamic_padding_dimensions=[0] * 6, dynamic_padding_constants=[0] * 5 + [1], **self.CommonInputOpArgs())
def _DataSourceFromFilePattern(self, file_pattern): def Proc(record): """Parses a serialized tf.Example record.""" outputs = [ ('source_id', tf.VarLenFeature(tf.int64)), ('source_padding', tf.VarLenFeature(tf.float32)), ('target_id', tf.VarLenFeature(tf.int64)), ('target_padding', tf.VarLenFeature(tf.float32)), ('target_label', tf.VarLenFeature(tf.int64)), ('target_weight', tf.VarLenFeature(tf.float32)), ] features = tf.parse_single_example(record, dict(outputs)) for k, v in six.iteritems(features): features[k] = v.values bucket_key = tf.cast( tf.maximum( tf.reduce_sum(1.0 - features['source_padding']), tf.reduce_sum(1.0 - features['target_padding'])), tf.int32) return [features[k] for k, _ in outputs], bucket_key return generic_input.GenericInput( file_pattern=file_pattern, processor=Proc, dynamic_padding_dimensions=[0] * 6, dynamic_padding_constants=[0, 1, 0, 1, 0, 0], **self.CommonInputOpArgs())
def get_test_input(self, path, **kwargs):
  return generic_input.GenericInput(
      file_pattern='tfrecord:' + path,
      file_random_seed=0,
      file_buffer_size=32,
      file_parallelism=4,
      bucket_batch_limit=[8],
      **kwargs)
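# A hypothetical call site for the helper above (sketch only): the
# remaining GenericInput arguments -- at minimum a `processor` and a
# `bucket_upper_bound` -- are expected to arrive via **kwargs, mirroring
# how the tests in this file pass them to GenericInput directly.
#
#   input_batch, buckets = self.get_test_input(
#       tmp_path, bucket_upper_bound=[1], processor=_process)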
def _DataSourceFromFilePattern(self, file_pattern):

  def ReadInput(line):
    word_count = tf.size(tf.strings.split([line]))
    strlen = tf.size(tf.strings.split([line], ''))
    return [line, word_count], strlen

  return generic_input.GenericInput(
      file_pattern=file_pattern,
      processor=ReadInput,
      **self.CommonInputOpArgs())
def testWithinBatchMixing(self):
  # Generate a couple of test files.
  def generate_test_data(tag, cnt):
    tmp = os.path.join(tf.test.get_temp_dir(), tag)
    with tf.python_io.TFRecordWriter(tmp) as w:
      for i in range(cnt):
        w.write(('%s:%08d' % (tag, i)).encode('utf-8'))
    return tmp

  path1 = generate_test_data('input1', 100)
  path2 = generate_test_data('input2', 200)
  path3 = generate_test_data('input3', 10)

  g = tf.Graph()
  with g.as_default():
    # A record processor written in TF graph.
    def _process(source_id, record):
      return py_utils.NestedMap(source_id=source_id, record=record), 1

    # Samples random records from the data files and processes them
    # to generate batches.
    input_batch, buckets = generic_input.GenericInput(
        file_pattern=','.join(
            ['tfrecord:' + path1, 'tfrecord:' + path2, 'tfrecord:' + path3]),
        input_source_weights=[0.2, 0.3, 0.5],
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[8],
        bucket_upper_bound=[1],
        processor=_process)

  with self.session(graph=g):
    source_id_count = collections.defaultdict(int)
    tags_count = collections.defaultdict(int)
    total_count = 10000
    for _ in range(total_count):
      ans_input_batch, ans_buckets = self.evaluate([input_batch, buckets])
      for s in ans_input_batch.source_id:
        source_id_count[s] += 1
      for s in ans_input_batch.record:
        tags_count[s.split(b':')[0]] += 1
      self.assertEqual(ans_input_batch.source_id.shape, (8,))
      self.assertEqual(ans_input_batch.record.shape, (8,))
      self.assertAllEqual(ans_buckets, [1] * 8)
    self.assertEqual(sum(source_id_count.values()), total_count * 8)
    self.assertEqual(sum(tags_count.values()), total_count * 8)

    num_records = 8. * total_count
    self.assertAlmostEqual(
        tags_count[b'input1'] / num_records, 0.2, delta=0.01)
    self.assertAlmostEqual(
        tags_count[b'input2'] / num_records, 0.3, delta=0.01)
    self.assertAlmostEqual(
        tags_count[b'input3'] / num_records, 0.5, delta=0.01)
    self.assertAlmostEqual(source_id_count[0] / num_records, 0.2, delta=0.01)
    self.assertAlmostEqual(source_id_count[1] / num_records, 0.3, delta=0.01)
    self.assertAlmostEqual(source_id_count[2] / num_records, 0.5, delta=0.01)
def get_test_input(self, path, **kwargs):
  return generic_input.GenericInput(
      file_pattern=','.join(['tfrecord:' + path, 'tfrecord:' + path]),
      input_source_weights=[0.3, 0.7],
      file_random_seed=0,
      file_buffer_size=32,
      file_parallelism=4,
      bucket_batch_limit=[8],
      **kwargs)
def _DataSourceFromFilePattern(self, file_pattern): def Proc(record): """Parses a serialized tf.Example record.""" bucket, outputs = self.ExtractUsingExtractors(record) return outputs.Flatten(), bucket # Ensure buckets above BUCKET_UPPER_BOUND are dropped. args = self.CommonInputOpArgs() args['bucket_upper_bound'] = [BUCKET_UPPER_BOUND - 1] return generic_input.GenericInput(processor=Proc, file_pattern=file_pattern, **args)
def testMix(self):
  # Generate a couple of test files.
  def generate_test_data(tag, cnt):
    tmp = os.path.join(tf.test.get_temp_dir(), tag)
    with tf.python_io.TFRecordWriter(tmp) as w:
      for i in range(cnt):
        w.write(('%s:%08d' % (tag, i)).encode('utf-8'))
    return tmp

  path1 = generate_test_data('input1', 100)
  path2 = generate_test_data('input2', 200)
  path3 = generate_test_data('input3', 10)

  g = tf.Graph()
  with g.as_default():
    # A record processor written in TF graph.
    def _process(record):
      return [record, record], 1

    # Samples random records from the data files and processes them
    # to generate batches.
    (strs, vals), buckets = generic_input.GenericInput(
        file_pattern=','.join(
            ['tfrecord:' + path1, 'tfrecord:' + path2, 'tfrecord:' + path3]),
        input_source_weights=[0.2, 0.3, 0.5],
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[8],
        bucket_upper_bound=[1],
        processor=_process)

  with self.session(graph=g) as sess:
    tags_count = collections.defaultdict(int)
    total_count = 10000
    for _ in range(total_count):
      ans_strs, ans_vals, ans_buckets = sess.run([strs, vals, buckets])
      for s in ans_strs:
        tags_count[s.split(b':')[0]] += 1
      self.assertEqual(ans_strs.shape, (8,))
      self.assertEqual(ans_vals.shape, (8,))
      self.assertAllEqual(ans_buckets, [1] * 8)
    self.assertEqual(sum(tags_count.values()), total_count * 8)

    mix_ratios = {}
    for k, v in six.iteritems(tags_count):
      mix_ratios[k] = float(v) / total_count / 8
    self.assertAlmostEqual(mix_ratios[b'input1'], 0.2, delta=0.01)
    self.assertAlmostEqual(mix_ratios[b'input2'], 0.3, delta=0.01)
    self.assertAlmostEqual(mix_ratios[b'input3'], 0.5, delta=0.01)
def testFatalErrors(self):
  tmp = os.path.join(tf.test.get_temp_dir(), 'fatal')
  with tf.python_io.TFRecordWriter(tmp) as w:
    for i in range(50):
      # Every other record overflows tf.int32.
      w.write(str((i % 2) * 2**33).encode('utf-8'))

  def _parse_record(record):
    # tf.strings.to_number raises an error on overflow.
    i = tf.strings.to_number(record, tf.int32)
    example = py_utils.NestedMap(record=i)
    bucketing_key = 1
    return example, bucketing_key

  with self.session():
    # Without specifying fatal_errors, all records that are not 0 fail to
    # parse and are silently skipped.
    input_batch, _ = generic_input.GenericInput(
        _parse_record,
        file_pattern=f'tfrecord:{tmp}',
        bucket_upper_bound=[1],
        bucket_batch_limit=[1])
    for _ in range(25):
      ans_input_batch = self.evaluate(input_batch)
      self.assertEqual(ans_input_batch.record[0], 0)

    # With fatal_errors, it dies instead.
    input_batch, _ = generic_input.GenericInput(
        _parse_record,
        file_pattern=f'tfrecord:{tmp}',
        bucket_upper_bound=[1],
        bucket_batch_limit=[1],
        fatal_errors=['StringToNumberOp could not correctly convert string:'])
    # NOTE: There is no way to catch LOG(FATAL) from the Python side, so
    # running this test will cause a crash.
    for _ in range(10):
      self.evaluate(input_batch)
def _DataSourceFromFilePattern(self, file_pattern):

  def ReadInput(line):
    word_count = tf.size(tf.strings.split([line]))
    strlen = tf.size(tf.strings.split([line], ''))
    return [line, word_count], strlen

  features, bucket_keys = generic_input.GenericInput(
      file_pattern=file_pattern,
      processor=ReadInput,
      **self.CommonInputOpArgs())
  return self.BuildInputBatch(
      batch_size=self.InfeedBatchSize(),
      features_list=features,
      bucket_keys=bucket_keys)
def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None): def Processor(source_id, record): """Parses a record, which is a line of text.""" if self.params.input_file_type == 'tsv': def _ApplyMass(source_id): if self.params.file_pattern_task_ids: file_task_ids = tf.constant( self.params.file_pattern_task_ids, dtype=tf.int32) task_id = tf.gather(file_task_ids, source_id) else: task_id = source_id mass_task_ids = tf.constant(self.params.mass_task_ids, dtype=tf.int32) return tf.reduce_any(tf.equal(task_id, mass_task_ids)) def _MASSInput(): src, filtered = self._ReadRecordTsvSingleColumn(record) return self._ProcessMASSInput(source_id, src), filtered def _SingleInput(): src, tgt, filtered = self._ReadRecordTsv(record) return self._ProcessSingleInput(source_id, src, tgt), filtered if self.params.single_column_input: if self.params.mass_task_ids is not None: cond = _ApplyMass(source_id) features, filtered = tf.cond(cond, _MASSInput, _SingleInput) else: features, filtered = _MASSInput() else: features, filtered = _SingleInput() else: src, tgt = self._ReadRecordSentencePairProto(record) filtered = tf.constant(False, dtype=tf.bool) features = self._ProcessSingleInput(source_id, src, tgt) return features, self._GetBucketKey(features, filtered) return generic_input.GenericInput( processor=Processor, file_pattern=file_pattern, input_source_weights=input_source_weights, **self.CommonInputOpArgs())
def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None): def Processor(source_id, record): """Parses a record, which is a line of text.""" task_id = self._GetTaskIds(source_id) if self.params.input_file_type == 'tsv': def _ApplyMass(task_id): mass_task_ids = tf.constant(self.params.mass_task_ids, dtype=tf.int32) return tf.reduce_any(tf.equal(task_id, mass_task_ids)) def _MASSInput(): src, filtered = self._ReadRecordTsvSingleColumn(record) return self._ProcessMASSInput(source_id, src), filtered def _SingleInput(): src, tgt, filtered = self._ReadRecordTsv(record) return self._ProcessSingleInput(source_id, src, tgt), filtered if self.params.single_column_input: # For monolingual input, MASS is applied by default. # If mass_task_ids is specified, only apply MASS to specified tasks. if self.params.mass_task_ids is not None: cond = _ApplyMass(task_id) features, filtered = tf.cond(cond, _MASSInput, _SingleInput) else: features, filtered = _MASSInput() else: features, filtered = _SingleInput() else: src, tgt = self._ReadRecordSentencePairProto(record) filtered = tf.constant(False, dtype=tf.bool) features = self._ProcessSingleInput(source_id, src, tgt) return features, self._GetBucketKey(features, filtered) batch, _ = generic_input.GenericInput( processor=Processor, file_pattern=file_pattern, input_source_weights=input_source_weights, **self.CommonInputOpArgs()) return self._Pack(batch)
def _DataSourceFromFilePattern(self, file_pattern): """Create the input processing op. Args: file_pattern: The file pattern to use as input. Returns: an operation that when executed, calls `_ProcessLine` on a line read from `file_pattern`. """ return generic_input.GenericInput( file_pattern=file_pattern, processor=self._ProcessLine, # Pad dimension 0 to the same length. dynamic_padding_dimensions=[0] * 6, # The constant values to use for padding each of the outputs. dynamic_padding_constants=[0, 1, 0, 1, 0, 0], **self.CommonInputOpArgs())
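# For context, a minimal sketch of a `_ProcessLine` that satisfies the
# contract assumed above -- six tensors plus a bucket key, in the same
# (src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights)
# order unpacked elsewhere in this file. The tab-separated line format and
# the use of `StringsToIds` are illustrative assumptions; the real
# `_ProcessLine` may tokenize differently.
def _ProcessLine(self, line):
  src, tgt = tf.unstack(tf.strings.split([line], sep='\t').values, num=2)
  src_ids, _, src_paddings = self.StringsToIds([src], is_source=True)
  tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds([tgt])
  tgt_weights = 1.0 - tgt_paddings
  # Bucket by the longer of the two sequences.
  bucket_key = tf.cast(
      tf.maximum(
          tf.reduce_sum(1.0 - src_paddings),
          tf.reduce_sum(1.0 - tgt_paddings)), tf.int32)
  return [
      src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights
  ], bucket_key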
def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None): def Proc(record): """Parses a serialized tf.Example record.""" bucket, outputs = self.ExtractUsingExtractors(record) return outputs.Flatten(), bucket # Ensure buckets [BUCKET_UPPER_BOUND, inf) are dropped. args = self.CommonInputOpArgs() args['bucket_upper_bound'] = [BUCKET_UPPER_BOUND - 1] batched_outputs, bucket_keys = generic_input.GenericInput( processor=Proc, file_pattern=file_pattern, input_source_weights=input_source_weights, **args) ret = self._NestedMapFromBatchedOutputs(batched_outputs) ret.bucket_keys = bucket_keys return ret
def _DataSourceFromFilePattern(self, file_pattern): def Proc(record): """Parses a serialized tf.Example record.""" # There we go! string, string, float32. I hope frames is allowed # to be a waveform directly... features = [ ('int64_uttid', tf.io.VarLenFeature(tf.int64)), ('int64_audio_document_id', tf.io.VarLenFeature(tf.int64)), ('num_utterances_in_audio_document', tf.io.VarLenFeature(tf.int64)), ('transcript', tf.io.VarLenFeature(tf.string)), ('frames', tf.io.FixedLenFeature((), tf.string)), ] example = tf.io.parse_single_example(record, dict(features)) fval = {} for k, v in example.items(): if k == 'frames': fval[k] = tf.cast(tf.io.decode_raw(v, tf.int16), tf.float32) else: assert isinstance(v, tf.SparseTensor) fval[k] = v.values # Reshape the flattened vector into its original time-major # representation. fval['frames'] = tf.reshape( fval['frames'], shape=[-1, self.params.frame_size]) # Input duration determines the bucket. bucket_key = tf.cast(tf.shape(fval['frames'])[0], tf.int32) if self.params.append_eos_frame: bucket_key += 1 tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(fval['transcript']) src_paddings = tf.zeros([tf.shape(fval['frames'])[0]], dtype=tf.float32) return [ fval['int64_uttid'], fval['int64_audio_document_id'], fval['num_utterances_in_audio_document'], tgt_ids, tgt_labels, tgt_paddings, fval['frames'], src_paddings ], bucket_key return generic_input.GenericInput( file_pattern=file_pattern, processor=Proc, dynamic_padding_dimensions=[0] * 8, dynamic_padding_constants=[0] * 7 + [1], **self.CommonInputOpArgs())
def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None): """Read and return input batch from a string file_pattern.""" del input_source_weights # Unused. def Process(source_id, record): del source_id # Unused. [num] = tf.py_func(int, [record], [tf.int64]) return py_utils.NestedMap(data=num), 1 # Samples random records from the data files and processes them # to generate batches. inputs, _ = generic_input.GenericInput( processor=Process, file_pattern=file_pattern, file_random_seed=123, file_buffer_size=1, file_parallelism=1, bucket_batch_limit=[1], bucket_upper_bound=[1]) return inputs
def testNestedGenericInput(self, inner_batch_limit, outer_batch_limit):
  # Generate records using an inner GenericInput, and post-process them
  # using an outer one. Verify that the generated records are complete
  # and contain no duplicates.

  def _process(record):
    del record
    # Construct the inner GenericInput.
    batch = run_basic_graph(
        use_nested_map=True, bucket_batch_limit=inner_batch_limit)
    batch.num += 1
    return batch, 1

  input_batch, _ = generic_input.GenericInput(
      file_pattern='iota:',
      processor=_process,
      bucket_upper_bound=[1],
      bucket_batch_limit=[outer_batch_limit])

  with self.session():
    global_batch = inner_batch_limit * outer_batch_limit
    record_seen = set()
    # Iterate the inputs for exactly one epoch.
    for _ in range(100 // global_batch):
      ans_input_batch = self.evaluate(input_batch)
      for record_array in ans_input_batch.record:
        for s in record_array:
          # There should not be duplicates, since GenericInput is stateful.
          assert s not in record_seen
          record_seen.add(s)
      self.assertEqual(ans_input_batch.source_id.shape,
                       (outer_batch_limit, inner_batch_limit))
      self.assertEqual(ans_input_batch.record.shape,
                       (outer_batch_limit, inner_batch_limit))
      self.assertEqual(ans_input_batch.num.shape,
                       (outer_batch_limit, inner_batch_limit, 2))
      ans_vals = ans_input_batch.num
      self.assertAllEqual(
          np.square(ans_vals[:, :, 0] - 1), ans_vals[:, :, 1] - 1)
    for i in range(100):
      self.assertIn(('%08d' % i).encode('utf-8'), record_seen)
def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None): def Processor(source_id, record): """Parses a record, which is a line of text.""" if self.params.input_file_type == 'tsv': if self.params.single_column_input: src, filtered = self._ReadRecordTsvSingleColumn(record) features = self._ProcessMASSInput(source_id, src) else: src, tgt, filtered = self._ReadRecordTsv(record) features = self._ProcessSingleInput(source_id, src, tgt) else: src, tgt = self._ReadRecordSentencePairProto(record) filtered = tf.constant(False, dtype=tf.bool) features = self._ProcessSingleInput(source_id, src, tgt) return features, self._GetBucketKey(features, filtered) return generic_input.GenericInput( processor=Processor, file_pattern=file_pattern, input_source_weights=input_source_weights, **self.CommonInputOpArgs())
def testV2OpsErrorRaised(self, use_tf_func, set_allow_eager):
  # Generate a test file w/ 100 records.
  tmp = os.path.join(tf.test.get_temp_dir(), 'basic')
  with tf.python_io.TFRecordWriter(tmp) as w:
    for i in range(100):
      w.write(('%08d' % i).encode('utf-8'))

  # A simple string parsing routine: just convert a string to a number.
  def str_to_num(s):
    return np.array(float(s), dtype=np.float32)

  bucket_fn = lambda x: 1

  # A record processor written in TF graph.
  def _process(source_id, record):
    num, = tf.py_func(str_to_num, [record], [tf.float32])
    num = tf.stack([num, tf.square(num)])
    return py_utils.NestedMap(
        source_id=source_id, record=record, num=num), bucket_fn(num)

  if set_allow_eager:
    # Test that a unique key must be provided to distinguish
    # GenericInputV2 ops.
    generic_input.SetAllowGenericInputV2InEager(True)
    err_regex = 'op requires a unique key'
  else:
    # Test that flags must be set to enable GenericInputV2 ops in
    # Eager mode.
    generic_input.SetAllowGenericInputV2InEager(False)
    err_regex = 'please add keyword arg'

  with self.assertRaisesRegex(RuntimeError, err_regex):
    _ = generic_input.GenericInput(
        file_pattern='tfrecord:' + tmp,
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[8],
        bucket_upper_bound=[1],
        processor=_process)
def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None):
  assert not tf.compat.v1.executing_eagerly()
  assert tf.compat.v1.executing_eagerly_outside_functions()

  def _process(source_id, record):
    del source_id
    num = tf.strings.to_number(record, tf.int32)
    if not tf_py_utils.use_tpu():
      num = num * num
    return py_utils.NestedMap(num=num), 1

  inputs, _ = generic_input.GenericInput(
      processor=_process,
      file_pattern=file_pattern,
      file_random_seed=0,
      require_sequential_order=True,
      repeat_count=1,
      file_buffer_size=32,
      file_parallelism=1,
      bucket_upper_bound=[10],
      bucket_batch_limit=[2])
  return inputs
def _DataSourceFromFilePattern(self, file_pattern): p = self._params def Proc(record): """Parses a serialized tf.Example record.""" outputs = [ ('inputs', tf.VarLenFeature(tf.int64)), ('targets', tf.VarLenFeature(tf.int64)), ] features = tf.parse_single_example(record, dict(outputs)) for k, v in six.iteritems(features): features[k] = v.values src_ids = features['inputs'] tgt_labels = features['targets'] # Derive src_paddings, tgt_ids, tgt_paddings. # tgt_ids is tgt_labels shifted right by one, with a SOS ID prepended. tgt_ids = tf.concat([[p.sos_id], tgt_labels[:-1]], axis=0) src_paddings = tf.zeros(tf.shape(src_ids), dtype=tf.float32) tgt_paddings = tf.zeros(tf.shape(tgt_ids), dtype=tf.float32) tgt_weights = tf.ones(tf.shape(tgt_ids), dtype=tf.float32) bucket_key = tf.cast( tf.maximum( tf.reduce_sum(1.0 - src_paddings), tf.reduce_sum(1.0 - tgt_paddings)), tf.int32) return [ src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights ], bucket_key return generic_input.GenericInput( file_pattern=file_pattern, processor=Proc, dynamic_padding_dimensions=[0] * 6, dynamic_padding_constants=[0, 1, 0, 1, 0, 0], **self.CommonInputOpArgs())
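# A worked example of the SOS shift above (illustrative values only): with
# p.sos_id == 1 and tgt_labels == [7, 8, 9], the concat produces
# tgt_ids == [1, 7, 8]; position i of tgt_ids is what the decoder consumes
# when predicting tgt_labels[i].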
def testV2OpsGetCalledInEager(self, use_tf_func, mock_op_key):
  # Generate a test file w/ 100 records.
  tmp = os.path.join(tf.test.get_temp_dir(), 'basic')
  with tf.python_io.TFRecordWriter(tmp) as w:
    for i in range(100):
      w.write(('%08d' % i).encode('utf-8'))

  # A simple string parsing routine: just convert a string to a number.
  def str_to_num(s):
    return np.array(float(s), dtype=np.float32)

  bucket_fn = lambda x: 1

  # A record processor written in TF graph.
  def _process(source_id, record):
    num, = tf.py_func(str_to_num, [record], [tf.float32])
    num = tf.stack([num, tf.square(num)])
    return py_utils.NestedMap(
        source_id=source_id, record=record, num=num), bucket_fn(num)

  # pylint: disable=protected-access
  len_before = len(generic_input._GENERIC_CACHE_V2)
  _ = generic_input.GenericInput(
      file_pattern='tfrecord:' + tmp,
      file_random_seed=0,
      file_buffer_size=32,
      file_parallelism=4,
      bucket_batch_limit=[8],
      bucket_upper_bound=[1],
      processor=_process,
      generic_input_v2_key=mock_op_key)
  len_after = len(generic_input._GENERIC_CACHE_V2)
  # pylint: enable=protected-access
  self.assertEqual(len_after, len_before + 1)
def _DataSourceFromFilePattern(self, file_pattern):
  p = self._params

  def _DerivePaddingsAndIds(src_ids, tgt_labels):
    """tgt_ids is tgt_labels shifted right by one, with a SOS ID prepended."""
    tgt_ids = tf.concat([[p.sos_id], tgt_labels[:-1]], axis=0)
    src_paddings = tf.zeros(tf.shape(src_ids), dtype=tf.float32)
    tgt_paddings = tf.zeros(tf.shape(tgt_ids), dtype=tf.float32)
    tgt_weights = tf.ones(tf.shape(tgt_ids), dtype=tf.float32)
    bucket_key = tf.cast(
        tf.maximum(
            tf.reduce_sum(1.0 - src_paddings),
            tf.reduce_sum(1.0 - tgt_paddings)), tf.int32)
    return src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key

  def _ProcPacked(record):
    """TFExample -> Tensors for PackedInput."""
    outputs = [
        ('inputs', tf.VarLenFeature(tf.int64)),
        ('targets', tf.VarLenFeature(tf.int64)),
        ('inputs_segmentation', tf.VarLenFeature(tf.int64)),
        ('inputs_position', tf.VarLenFeature(tf.int64)),
        ('targets_segmentation', tf.VarLenFeature(tf.int64)),
        ('targets_position', tf.VarLenFeature(tf.int64)),
    ]
    features = tf.parse_single_example(record, dict(outputs))
    for k, v in six.iteritems(features):
      features[k] = v.values
    src_ids = features['inputs']
    tgt_labels = features['targets']
    src_pos = features['inputs_position']
    src_seg = features['inputs_segmentation']
    tgt_pos = features['targets_position']
    tgt_seg = features['targets_segmentation']
    src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key = (
        _DerivePaddingsAndIds(src_ids, tgt_labels))
    return [
        src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights,
        src_pos, src_seg, tgt_pos, tgt_seg
    ], bucket_key

  def _Proc(record):
    """Parses a serialized tf.Example record."""
    outputs = [
        ('inputs', tf.VarLenFeature(tf.int64)),
        ('targets', tf.VarLenFeature(tf.int64)),
    ]
    features = tf.parse_single_example(record, dict(outputs))
    for k, v in six.iteritems(features):
      features[k] = v.values
    src_ids = features['inputs']
    tgt_labels = features['targets']
    # Derive trivial segmentation for unpacked input.
    src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key = (
        _DerivePaddingsAndIds(src_ids, tgt_labels))
    src_len = tf.shape(src_ids)[0]
    tgt_len = tf.shape(tgt_ids)[0]
    src_pos = tf.range(src_len, dtype=tf.int32)
    src_seg = tf.zeros_like(src_paddings)
    tgt_pos = tf.range(tgt_len, dtype=tf.int32)
    tgt_seg = tf.zeros_like(tgt_paddings)
    return [
        src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights,
        src_pos, src_seg, tgt_pos, tgt_seg
    ], bucket_key

  processor_fn = _ProcPacked if p.packed_input else _Proc

  return generic_input.GenericInput(
      file_pattern=file_pattern,
      processor=processor_fn,
      dynamic_padding_dimensions=[0] * 10,
      dynamic_padding_constants=[0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
      **self.CommonInputOpArgs())