Example #1
    def testExtraArgs(self):
        def _parse_record(record):
            del record
            example = py_utils.NestedMap(t=tf.convert_to_tensor(0))
            bucketing_key = 1
            return example, bucketing_key

        def _parse_record_stateful(record):
            del record
            extra = tf.Variable(0)
            example = py_utils.NestedMap(t=extra.value())
            bucketing_key = 1
            return example, bucketing_key

        generic_input.GenericInput(_parse_record,
                                   file_pattern='',
                                   bucket_upper_bound=[1],
                                   bucket_batch_limit=[1])

        with self.assertRaisesRegex(AssertionError,
                                    'is not pure: extra_args='):
            generic_input.GenericInput(_parse_record_stateful,
                                       file_pattern='',
                                       bucket_upper_bound=[1],
                                       bucket_batch_limit=[1])
Example #2
    def _DataSourceFromFilePattern(self, file_pattern):
        def Proc(record):
            """Parses a serialized tf.Example record."""
            # Parse the utterance id and the flattened frame features from the
            # serialized tf.Example.
            features = [
                ('uttid', tf.io.VarLenFeature(tf.int64)),
                # Would like to change this to tf.int16 in the future, if that
                # is possible.
                ('frames', tf.io.VarLenFeature(tf.float32)),
            ]
            example = tf.io.parse_single_example(record, dict(features))
            fval = {k: v.values for k, v in example.items()}
            # Reshape the flattened vector into its original time-major
            # representation.
            fval['frames'] = tf.reshape(fval['frames'],
                                        shape=[-1, self.params.frame_size])
            # Input duration determines the bucket.
            bucket_key = tf.cast(tf.shape(fval['frames'])[0], tf.int32)
            if self.params.append_eos_frame:
                bucket_key += 1
            src_paddings = tf.zeros([tf.shape(fval['frames'])[0]],
                                    dtype=tf.float32)
            return [fval['uttid'], fval['frames'], src_paddings], bucket_key

        return generic_input.GenericInput(file_pattern=file_pattern,
                                          processor=Proc,
                                          dynamic_padding_dimensions=[0] * 3,
                                          dynamic_padding_constants=[0] * 2 +
                                          [1],
                                          **self.CommonInputOpArgs())
Example #3
    def _DataSourceFromFilePattern(self, file_pattern):
        """Create the input processing op.

    Args:
      file_pattern: The file pattern to use as input.

    Returns:
      an operation that when executed, calls `_ProcessLine` on a line read
    from `file_pattern`.
    """
        ret = py_utils.NestedMap()

        (src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels,
         tgt_weights), ret.bucket_keys = generic_input.GenericInput(
             file_pattern=file_pattern,
             processor=self._ProcessLine,
             # Pad dimension 0 to the same length.
             dynamic_padding_dimensions=[0] * 6,
             # The constant values to use for padding each of the outputs.
             dynamic_padding_constants=[0, 1, 0, 1, 0, 0],
             **self.CommonInputOpArgs())

        ret.src = py_utils.NestedMap()
        ret.src.ids = tf.cast(src_ids, dtype=tf.int32)
        ret.src.paddings = src_paddings

        ret.tgt = py_utils.NestedMap()
        ret.tgt.ids = tgt_ids
        ret.tgt.labels = tf.cast(tgt_labels, dtype=tf.int32)
        ret.tgt.weights = tgt_weights
        ret.tgt.paddings = tgt_paddings

        return ret
Example #4
    def _DataSourceFromFilePattern(self, file_pattern):
        def Proc(record):
            """Parses a serialized tf.Example record."""
            features = [
                ('uttid', tf.io.VarLenFeature(tf.string)),
                ('transcript', tf.io.VarLenFeature(tf.string)),
                ('frames', tf.io.VarLenFeature(tf.float32)),
            ]
            example = tf.io.parse_single_example(record, dict(features))
            fval = {k: v.values for k, v in six.iteritems(example)}
            # Reshape the flattened vector into its original time-major
            # representation.
            fval['frames'] = tf.reshape(fval['frames'],
                                        shape=[-1, self.params.frame_size])
            # Input duration determines the bucket.
            bucket_key = tf.cast(tf.shape(fval['frames'])[0], tf.int32)
            if self.params.append_eos_frame:
                bucket_key += 1
            tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
                fval['transcript'])
            src_paddings = tf.zeros([tf.shape(fval['frames'])[0]],
                                    dtype=tf.float32)
            return [
                fval['uttid'], tgt_ids, tgt_labels, tgt_paddings,
                fval['frames'], src_paddings
            ], bucket_key

        return generic_input.GenericInput(file_pattern=file_pattern,
                                          processor=Proc,
                                          dynamic_padding_dimensions=[0] * 6,
                                          dynamic_padding_constants=[0] * 5 +
                                          [1],
                                          **self.CommonInputOpArgs())
Example #5
  def _DataSourceFromFilePattern(self, file_pattern):

    def Proc(record):
      """Parses a serialized tf.Example record."""
      outputs = [
          ('source_id', tf.VarLenFeature(tf.int64)),
          ('source_padding', tf.VarLenFeature(tf.float32)),
          ('target_id', tf.VarLenFeature(tf.int64)),
          ('target_padding', tf.VarLenFeature(tf.float32)),
          ('target_label', tf.VarLenFeature(tf.int64)),
          ('target_weight', tf.VarLenFeature(tf.float32)),
      ]
      features = tf.parse_single_example(record, dict(outputs))
      for k, v in six.iteritems(features):
        features[k] = v.values
      bucket_key = tf.cast(
          tf.maximum(
              tf.reduce_sum(1.0 - features['source_padding']),
              tf.reduce_sum(1.0 - features['target_padding'])), tf.int32)
      return [features[k] for k, _ in outputs], bucket_key

    return generic_input.GenericInput(
        file_pattern=file_pattern,
        processor=Proc,
        dynamic_padding_dimensions=[0] * 6,
        dynamic_padding_constants=[0, 1, 0, 1, 0, 0],
        **self.CommonInputOpArgs())
Example #6
  def get_test_input(self, path, **kwargs):
    return generic_input.GenericInput(
        file_pattern='tfrecord:' + path,
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[8],
        **kwargs)
Example #7
    def _DataSourceFromFilePattern(self, file_pattern):
        def ReadInput(line):
            word_count = tf.size(tf.strings.split([line]))
            strlen = tf.size(tf.strings.split([line], ''))
            return [line, word_count], strlen

        return generic_input.GenericInput(file_pattern=file_pattern,
                                          processor=ReadInput,
                                          **self.CommonInputOpArgs())
Example #8
  def testWithinBatchMixing(self):
    # Generate a couple of test files.
    def generate_test_data(tag, cnt):
      tmp = os.path.join(tf.test.get_temp_dir(), tag)
      with tf.python_io.TFRecordWriter(tmp) as w:
        for i in range(cnt):
          w.write(('%s:%08d' % (tag, i)).encode('utf-8'))
      return tmp

    path1 = generate_test_data('input1', 100)
    path2 = generate_test_data('input2', 200)
    path3 = generate_test_data('input3', 10)

    g = tf.Graph()
    with g.as_default():
      # A record processor written in TF graph.
      def _process(source_id, record):
        return py_utils.NestedMap(source_id=source_id, record=record), 1

      # Samples random records from the data files and processes them
      # to generate batches.
      input_batch, buckets = generic_input.GenericInput(
          file_pattern=','.join(
              ['tfrecord:' + path1, 'tfrecord:' + path2, 'tfrecord:' + path3]),
          input_source_weights=[0.2, 0.3, 0.5],
          file_random_seed=0,
          file_buffer_size=32,
          file_parallelism=4,
          bucket_batch_limit=[8],
          bucket_upper_bound=[1],
          processor=_process)

    with self.session(graph=g):
      source_id_count = collections.defaultdict(int)
      tags_count = collections.defaultdict(int)
      total_count = 10000
      for _ in range(total_count):
        ans_input_batch, ans_buckets = self.evaluate([input_batch, buckets])
        for s in ans_input_batch.source_id:
          source_id_count[s] += 1
        for s in ans_input_batch.record:
          tags_count[s.split(b':')[0]] += 1
        self.assertEqual(ans_input_batch.source_id.shape, (8,))
        self.assertEqual(ans_input_batch.record.shape, (8,))
        self.assertAllEqual(ans_buckets, [1] * 8)
      self.assertEqual(sum(source_id_count.values()), total_count * 8)
      self.assertEqual(sum(tags_count.values()), total_count * 8)
      num_records = 8. * total_count
      self.assertAlmostEqual(
          tags_count[b'input1'] / num_records, 0.2, delta=0.01)
      self.assertAlmostEqual(
          tags_count[b'input2'] / num_records, 0.3, delta=0.01)
      self.assertAlmostEqual(
          tags_count[b'input3'] / num_records, 0.5, delta=0.01)
      self.assertAlmostEqual(source_id_count[0] / num_records, 0.2, delta=0.01)
      self.assertAlmostEqual(source_id_count[1] / num_records, 0.3, delta=0.01)
      self.assertAlmostEqual(source_id_count[2] / num_records, 0.5, delta=0.01)
Example #9
  def get_test_input(self, path, **kwargs):
    return generic_input.GenericInput(
        file_pattern=','.join(['tfrecord:' + path, 'tfrecord:' + path]),
        input_source_weights=[0.3, 0.7],
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[8],
        **kwargs)
Example #10
    def _DataSourceFromFilePattern(self, file_pattern):
        def Proc(record):
            """Parses a serialized tf.Example record."""
            bucket, outputs = self.ExtractUsingExtractors(record)
            return outputs.Flatten(), bucket

        # Ensure buckets [BUCKET_UPPER_BOUND, inf) are dropped.
        args = self.CommonInputOpArgs()
        args['bucket_upper_bound'] = [BUCKET_UPPER_BOUND - 1]
        return generic_input.GenericInput(processor=Proc,
                                          file_pattern=file_pattern,
                                          **args)
Example #11
    def testMix(self):
        # Generate a couple of test files.
        def generate_test_data(tag, cnt):
            tmp = os.path.join(tf.test.get_temp_dir(), tag)
            with tf.python_io.TFRecordWriter(tmp) as w:
                for i in range(cnt):
                    w.write(('%s:%08d' % (tag, i)).encode('utf-8'))
            return tmp

        path1 = generate_test_data('input1', 100)
        path2 = generate_test_data('input2', 200)
        path3 = generate_test_data('input3', 10)

        g = tf.Graph()
        with g.as_default():
            # A record processor written in TF graph.
            def _process(record):
                return [record, record], 1

            # Samples random records from the data files and processes them
            # to generate batches.
            (strs, vals), buckets = generic_input.GenericInput(
                file_pattern=','.join([
                    'tfrecord:' + path1, 'tfrecord:' + path2,
                    'tfrecord:' + path3
                ]),
                input_source_weights=[0.2, 0.3, 0.5],
                file_random_seed=0,
                file_buffer_size=32,
                file_parallelism=4,
                bucket_batch_limit=[8],
                bucket_upper_bound=[1],
                processor=_process)

        with self.session(graph=g) as sess:
            tags_count = collections.defaultdict(int)
            total_count = 10000
            for _ in range(total_count):
                ans_strs, ans_vals, ans_buckets = sess.run(
                    [strs, vals, buckets])
                for s in ans_strs:
                    tags_count[s.split(b':')[0]] += 1
                self.assertEqual(ans_strs.shape, (8, ))
                self.assertEqual(ans_vals.shape, (8, ))
                self.assertAllEqual(ans_buckets, [1] * 8)
            self.assertEqual(sum(tags_count.values()), total_count * 8)
            mix_ratios = {}
            for k, v in six.iteritems(tags_count):
                mix_ratios[k] = float(v) / total_count / 8
            self.assertAlmostEqual(mix_ratios[b'input1'], 0.2, delta=0.01)
            self.assertAlmostEqual(mix_ratios[b'input2'], 0.3, delta=0.01)
            self.assertAlmostEqual(mix_ratios[b'input3'], 0.5, delta=0.01)
Example #12
    def testFatalErrors(self):
        tmp = os.path.join(tf.test.get_temp_dir(), 'fatal')
        with tf.python_io.TFRecordWriter(tmp) as w:
            for i in range(50):
                w.write(str((i % 2) * 2**33).encode('utf-8'))

        def _parse_record(record):
            # tf.strings.to_number raises error on overflow.
            i = tf.strings.to_number(record, tf.int32)
            example = py_utils.NestedMap(record=i)
            bucketing_key = 1
            return example, bucketing_key

        with self.session():
            # Without specifying fatal_errors, all records that are not 0 are
            # skipped (converting 2**33 overflows int32 and raises an error).
            input_batch, _ = generic_input.GenericInput(
                _parse_record,
                file_pattern=f'tfrecord:{tmp}',
                bucket_upper_bound=[1],
                bucket_batch_limit=[1])

            for i in range(25):
                ans_input_batch = self.evaluate(input_batch)
                self.assertEqual(ans_input_batch.record[0], 0)

            # With fatal_errors it dies instead.
            input_batch, _ = generic_input.GenericInput(
                _parse_record,
                file_pattern=f'tfrecord:{tmp}',
                bucket_upper_bound=[1],
                bucket_batch_limit=[1],
                fatal_errors=[
                    'StringToNumberOp could not correctly convert string:'
                ])

            # NOTE: There is no way to catch LOG(FATAL) from the Python side,
            # so running this test will cause a crash.
            for i in range(10):
                self.evaluate(input_batch)
Example #13
    def _DataSourceFromFilePattern(self, file_pattern):
        def ReadInput(line):
            word_count = tf.size(tf.strings.split([line]))
            strlen = tf.size(tf.strings.split([line], ''))
            return [line, word_count], strlen

        features, bucket_keys = generic_input.GenericInput(
            file_pattern=file_pattern,
            processor=ReadInput,
            **self.CommonInputOpArgs())

        return self.BuildInputBatch(batch_size=self.InfeedBatchSize(),
                                    features_list=features,
                                    bucket_keys=bucket_keys)
Example #14
    def _DataSourceFromFilePattern(self,
                                   file_pattern,
                                   input_source_weights=None):
        def Processor(source_id, record):
            """Parses a record, which is a line of text."""

            if self.params.input_file_type == 'tsv':

                def _ApplyMass(source_id):
                    if self.params.file_pattern_task_ids:
                        file_task_ids = tf.constant(
                            self.params.file_pattern_task_ids, dtype=tf.int32)
                        task_id = tf.gather(file_task_ids, source_id)
                    else:
                        task_id = source_id
                    mass_task_ids = tf.constant(self.params.mass_task_ids,
                                                dtype=tf.int32)
                    return tf.reduce_any(tf.equal(task_id, mass_task_ids))

                def _MASSInput():
                    src, filtered = self._ReadRecordTsvSingleColumn(record)
                    return self._ProcessMASSInput(source_id, src), filtered

                def _SingleInput():
                    src, tgt, filtered = self._ReadRecordTsv(record)
                    return self._ProcessSingleInput(source_id, src,
                                                    tgt), filtered

                if self.params.single_column_input:
                    if self.params.mass_task_ids is not None:
                        cond = _ApplyMass(source_id)
                        features, filtered = tf.cond(cond, _MASSInput,
                                                     _SingleInput)
                    else:
                        features, filtered = _MASSInput()
                else:
                    features, filtered = _SingleInput()

            else:
                src, tgt = self._ReadRecordSentencePairProto(record)
                filtered = tf.constant(False, dtype=tf.bool)
                features = self._ProcessSingleInput(source_id, src, tgt)

            return features, self._GetBucketKey(features, filtered)

        return generic_input.GenericInput(
            processor=Processor,
            file_pattern=file_pattern,
            input_source_weights=input_source_weights,
            **self.CommonInputOpArgs())
Example #15
    def _DataSourceFromFilePattern(self,
                                   file_pattern,
                                   input_source_weights=None):
        def Processor(source_id, record):
            """Parses a record, which is a line of text."""

            task_id = self._GetTaskIds(source_id)

            if self.params.input_file_type == 'tsv':

                def _ApplyMass(task_id):
                    mass_task_ids = tf.constant(self.params.mass_task_ids,
                                                dtype=tf.int32)
                    return tf.reduce_any(tf.equal(task_id, mass_task_ids))

                def _MASSInput():
                    src, filtered = self._ReadRecordTsvSingleColumn(record)
                    return self._ProcessMASSInput(source_id, src), filtered

                def _SingleInput():
                    src, tgt, filtered = self._ReadRecordTsv(record)
                    return self._ProcessSingleInput(source_id, src,
                                                    tgt), filtered

                if self.params.single_column_input:
                    # For monolingual input, MASS is applied by default.
                    # If mass_task_ids is specified, only apply MASS to specified tasks.
                    if self.params.mass_task_ids is not None:
                        cond = _ApplyMass(task_id)
                        features, filtered = tf.cond(cond, _MASSInput,
                                                     _SingleInput)
                    else:
                        features, filtered = _MASSInput()
                else:
                    features, filtered = _SingleInput()

            else:
                src, tgt = self._ReadRecordSentencePairProto(record)
                filtered = tf.constant(False, dtype=tf.bool)
                features = self._ProcessSingleInput(source_id, src, tgt)

            return features, self._GetBucketKey(features, filtered)

        batch, _ = generic_input.GenericInput(
            processor=Processor,
            file_pattern=file_pattern,
            input_source_weights=input_source_weights,
            **self.CommonInputOpArgs())
        return self._Pack(batch)
Example #16
    def _DataSourceFromFilePattern(self, file_pattern):
        """Create the input processing op.

    Args:
      file_pattern: The file pattern to use as input.

    Returns:
      an operation that when executed, calls `_ProcessLine` on a line read
    from `file_pattern`.
    """
        return generic_input.GenericInput(
            file_pattern=file_pattern,
            processor=self._ProcessLine,
            # Pad dimension 0 to the same length.
            dynamic_padding_dimensions=[0] * 6,
            # The constant values to use for padding each of the outputs.
            dynamic_padding_constants=[0, 1, 0, 1, 0, 0],
            **self.CommonInputOpArgs())
Example #17
  def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None):

    def Proc(record):
      """Parses a serialized tf.Example record."""
      bucket, outputs = self.ExtractUsingExtractors(record)
      return outputs.Flatten(), bucket

    # Ensure buckets [BUCKET_UPPER_BOUND, inf) are dropped.
    args = self.CommonInputOpArgs()
    args['bucket_upper_bound'] = [BUCKET_UPPER_BOUND - 1]
    batched_outputs, bucket_keys = generic_input.GenericInput(
        processor=Proc,
        file_pattern=file_pattern,
        input_source_weights=input_source_weights,
        **args)
    ret = self._NestedMapFromBatchedOutputs(batched_outputs)
    ret.bucket_keys = bucket_keys
    return ret
Example #18
  def _DataSourceFromFilePattern(self, file_pattern):

    def Proc(record):
      """Parses a serialized tf.Example record."""
      # Parse the utterance metadata, the transcript, and the frames (stored
      # as raw int16 bytes) from the serialized tf.Example.
      features = [
          ('int64_uttid', tf.io.VarLenFeature(tf.int64)),
          ('int64_audio_document_id', tf.io.VarLenFeature(tf.int64)),
          ('num_utterances_in_audio_document', tf.io.VarLenFeature(tf.int64)),
          ('transcript', tf.io.VarLenFeature(tf.string)),
          ('frames', tf.io.FixedLenFeature((), tf.string)),
      ]
      example = tf.io.parse_single_example(record, dict(features))
      fval = {}
      for k, v in example.items():
        if k == 'frames':
          fval[k] = tf.cast(tf.io.decode_raw(v, tf.int16), tf.float32)
        else:
          assert isinstance(v, tf.SparseTensor)
          fval[k] = v.values
      # Reshape the flattened vector into its original time-major
      # representation.
      fval['frames'] = tf.reshape(
          fval['frames'], shape=[-1, self.params.frame_size])
      # Input duration determines the bucket.
      bucket_key = tf.cast(tf.shape(fval['frames'])[0], tf.int32)
      if self.params.append_eos_frame:
        bucket_key += 1
      tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(fval['transcript'])
      src_paddings = tf.zeros([tf.shape(fval['frames'])[0]], dtype=tf.float32)
      return [
          fval['int64_uttid'], fval['int64_audio_document_id'],
          fval['num_utterances_in_audio_document'], tgt_ids,
          tgt_labels, tgt_paddings, fval['frames'],
          src_paddings
      ], bucket_key

    return generic_input.GenericInput(
        file_pattern=file_pattern,
        processor=Proc,
        dynamic_padding_dimensions=[0] * 8,
        dynamic_padding_constants=[0] * 7 + [1],
        **self.CommonInputOpArgs())
Example #19
  def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None):
    """Read and return input batch from a string file_pattern."""
    del input_source_weights  # Unused.

    def Process(source_id, record):
      del source_id  # Unused.
      [num] = tf.py_func(int, [record], [tf.int64])
      return py_utils.NestedMap(data=num), 1

    # Samples random records from the data files and processes them
    # to generate batches.
    inputs, _ = generic_input.GenericInput(
        processor=Process,
        file_pattern=file_pattern,
        file_random_seed=123,
        file_buffer_size=1,
        file_parallelism=1,
        bucket_batch_limit=[1],
        bucket_upper_bound=[1])
    return inputs
Example #20
    def testNestedGenericInput(self, inner_batch_limit, outer_batch_limit):
        # Generate records using an inner GenericInput, and post-process them
        # using an outer one. Test that the generated records are complete and
        # contain no duplicates.

        def _process(record):
            del record
            # Construct the inner GenericInput.
            batch = run_basic_graph(use_nested_map=True,
                                    bucket_batch_limit=inner_batch_limit)
            batch.num += 1
            return batch, 1

        input_batch, _ = generic_input.GenericInput(
            file_pattern='iota:',
            processor=_process,
            bucket_upper_bound=[1],
            bucket_batch_limit=[outer_batch_limit])

        with self.session():
            global_batch = inner_batch_limit * outer_batch_limit
            record_seen = set()
            # Iterate the inputs for exactly one epoch.
            for i in range(100 // global_batch):
                ans_input_batch = self.evaluate(input_batch)
                for record_array in ans_input_batch.record:
                    for s in record_array:
                        # There should not be duplicates since GenericInput is stateful.
                        assert s not in record_seen
                        record_seen.add(s)
                self.assertEqual(ans_input_batch.source_id.shape,
                                 (outer_batch_limit, inner_batch_limit))
                self.assertEqual(ans_input_batch.record.shape,
                                 (outer_batch_limit, inner_batch_limit))
                self.assertEqual(ans_input_batch.num.shape,
                                 (outer_batch_limit, inner_batch_limit, 2))
                ans_vals = ans_input_batch.num
                self.assertAllEqual(np.square(ans_vals[:, :, 0] - 1),
                                    ans_vals[:, :, 1] - 1)
            for i in range(100):
                self.assertIn(('%08d' % i).encode('utf-8'), record_seen)
Example #21
  def _DataSourceFromFilePattern(self, file_pattern, input_source_weights=None):

    def Processor(source_id, record):
      """Parses a record, which is a line of text."""
      if self.params.input_file_type == 'tsv':
        if self.params.single_column_input:
          src, filtered = self._ReadRecordTsvSingleColumn(record)
          features = self._ProcessMASSInput(source_id, src)
        else:
          src, tgt, filtered = self._ReadRecordTsv(record)
          features = self._ProcessSingleInput(source_id, src, tgt)
      else:
        src, tgt = self._ReadRecordSentencePairProto(record)
        filtered = tf.constant(False, dtype=tf.bool)
        features = self._ProcessSingleInput(source_id, src, tgt)
      return features, self._GetBucketKey(features, filtered)

    return generic_input.GenericInput(
        processor=Processor,
        file_pattern=file_pattern,
        input_source_weights=input_source_weights,
        **self.CommonInputOpArgs())
Example #22
  def testV2OpsErrorRaised(self, use_tf_func, set_allow_eager):
    # Generate a test file w/ 100 records.
    tmp = os.path.join(tf.test.get_temp_dir(), 'basic')
    with tf.python_io.TFRecordWriter(tmp) as w:
      for i in range(100):
        w.write(('%08d' % i).encode('utf-8'))

    # A simple string parsing routine. Just convert a string to a
    # number.
    def str_to_num(s):
      return np.array(float(s), dtype=np.float32)

    bucket_fn = lambda x: 1

    # A record processor written in TF graph.
    def _process(source_id, record):
      num, = tf.py_func(str_to_num, [record], [tf.float32])
      num = tf.stack([num, tf.square(num)])
      return py_utils.NestedMap(
          source_id=source_id, record=record, num=num), bucket_fn(num)

    if set_allow_eager:
      # Test that unique keys must be provided to distinguish GenericInputV2
      # ops.
      generic_input.SetAllowGenericInputV2InEager(True)
      err_regex = 'op requires a unique key'
    else:
      # Test that the flag must be set to enable GenericInputV2 ops in Eager
      # mode.
      generic_input.SetAllowGenericInputV2InEager(False)
      err_regex = 'please add keyword arg'

    with self.assertRaisesRegex(RuntimeError, err_regex):
      _ = generic_input.GenericInput(
          file_pattern='tfrecord:' + tmp,
          file_random_seed=0,
          file_buffer_size=32,
          file_parallelism=4,
          bucket_batch_limit=[8],
          bucket_upper_bound=[1],
          processor=_process)
Example #23
    def _DataSourceFromFilePattern(self,
                                   file_pattern,
                                   input_source_weights=None):
        assert not tf.compat.v1.executing_eagerly()
        assert tf.compat.v1.executing_eagerly_outside_functions()

        def _process(source_id, record):
            del source_id
            num = tf.strings.to_number(record, tf.int32)
            if not tf_py_utils.use_tpu():
                num = num * num
            return py_utils.NestedMap(num=num), 1

        inputs, _ = generic_input.GenericInput(processor=_process,
                                               file_pattern=file_pattern,
                                               file_random_seed=0,
                                               require_sequential_order=True,
                                               repeat_count=1,
                                               file_buffer_size=32,
                                               file_parallelism=1,
                                               bucket_upper_bound=[10],
                                               bucket_batch_limit=[2])
        return inputs
Example #24
  def _DataSourceFromFilePattern(self, file_pattern):
    p = self._params

    def Proc(record):
      """Parses a serialized tf.Example record."""
      outputs = [
          ('inputs', tf.VarLenFeature(tf.int64)),
          ('targets', tf.VarLenFeature(tf.int64)),
      ]
      features = tf.parse_single_example(record, dict(outputs))
      for k, v in six.iteritems(features):
        features[k] = v.values

      src_ids = features['inputs']
      tgt_labels = features['targets']

      # Derive src_paddings, tgt_ids, tgt_paddings.
      # tgt_ids is tgt_labels shifted right by one, with a SOS ID prepended.
      tgt_ids = tf.concat([[p.sos_id], tgt_labels[:-1]], axis=0)
      src_paddings = tf.zeros(tf.shape(src_ids), dtype=tf.float32)
      tgt_paddings = tf.zeros(tf.shape(tgt_ids), dtype=tf.float32)
      tgt_weights = tf.ones(tf.shape(tgt_ids), dtype=tf.float32)
      bucket_key = tf.cast(
          tf.maximum(
              tf.reduce_sum(1.0 - src_paddings),
              tf.reduce_sum(1.0 - tgt_paddings)), tf.int32)

      return [
          src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights
      ], bucket_key

    return generic_input.GenericInput(
        file_pattern=file_pattern,
        processor=Proc,
        dynamic_padding_dimensions=[0] * 6,
        dynamic_padding_constants=[0, 1, 0, 1, 0, 0],
        **self.CommonInputOpArgs())
Example #25
  def testV2OpsGetCalledInEager(self, use_tf_func, mock_op_key):
    # Generate a test file w/ 100 records.
    tmp = os.path.join(tf.test.get_temp_dir(), 'basic')
    with tf.python_io.TFRecordWriter(tmp) as w:
      for i in range(100):
        w.write(('%08d' % i).encode('utf-8'))

    # A simple string parsing routine. Just convert a string to a
    # number.
    def str_to_num(s):
      return np.array(float(s), dtype=np.float32)

    bucket_fn = lambda x: 1

    # A record processor written in TF graph.
    def _process(source_id, record):
      num, = tf.py_func(str_to_num, [record], [tf.float32])
      num = tf.stack([num, tf.square(num)])
      return py_utils.NestedMap(
          source_id=source_id, record=record, num=num), bucket_fn(num)

    # pylint: disable=protected-access
    len_before = len(generic_input._GENERIC_CACHE_V2)
    _ = generic_input.GenericInput(
        file_pattern='tfrecord:' + tmp,
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[8],
        bucket_upper_bound=[1],
        processor=_process,
        generic_input_v2_key=mock_op_key)

    # pylint: disable=protected-access
    len_after = len(generic_input._GENERIC_CACHE_V2)
    self.assertEqual(len_after, len_before + 1)
Example #26
  def _DataSourceFromFilePattern(self, file_pattern):
    p = self._params

    def _DerivePaddingsAndIds(src_ids, tgt_labels):
      """tgt_ids is tgt_labels shifted right by one, with a SOS ID prepended."""
      tgt_ids = tf.concat([[p.sos_id], tgt_labels[:-1]], axis=0)
      src_paddings = tf.zeros(tf.shape(src_ids), dtype=tf.float32)
      tgt_paddings = tf.zeros(tf.shape(tgt_ids), dtype=tf.float32)
      tgt_weights = tf.ones(tf.shape(tgt_ids), dtype=tf.float32)

      bucket_key = tf.cast(
          tf.maximum(
              tf.reduce_sum(1.0 - src_paddings),
              tf.reduce_sum(1.0 - tgt_paddings)), tf.int32)

      return src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key

    def _ProcPacked(record):
      """TFExample -> Tensors for PackedInput."""
      outputs = [
          ('inputs', tf.VarLenFeature(tf.int64)),
          ('targets', tf.VarLenFeature(tf.int64)),
          ('inputs_segmentation', tf.VarLenFeature(tf.int64)),
          ('inputs_position', tf.VarLenFeature(tf.int64)),
          ('targets_segmentation', tf.VarLenFeature(tf.int64)),
          ('targets_position', tf.VarLenFeature(tf.int64)),
      ]

      features = tf.parse_single_example(record, dict(outputs))
      for k, v in six.iteritems(features):
        features[k] = v.values

      src_ids = features['inputs']
      tgt_labels = features['targets']

      src_pos = features['inputs_position']
      src_seg = features['inputs_segmentation']

      tgt_pos = features['targets_position']
      tgt_seg = features['targets_segmentation']

      src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key = _DerivePaddingsAndIds(
          src_ids, tgt_labels)
      return [
          src_ids,
          src_paddings,
          tgt_ids,
          tgt_paddings,
          tgt_labels,
          tgt_weights,
          src_pos,
          src_seg,
          tgt_pos,
          tgt_seg,
      ], bucket_key

    def _Proc(record):
      """Parses a serialized tf.Example record."""
      outputs = [
          ('inputs', tf.VarLenFeature(tf.int64)),
          ('targets', tf.VarLenFeature(tf.int64)),
      ]
      features = tf.parse_single_example(record, dict(outputs))
      for k, v in six.iteritems(features):
        features[k] = v.values

      src_ids = features['inputs']
      tgt_labels = features['targets']

      # Derive trivial segmentation for unpacked input.
      src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key = _DerivePaddingsAndIds(
          src_ids, tgt_labels)

      src_len = tf.shape(src_ids)[0]
      tgt_len = tf.shape(tgt_ids)[0]
      src_pos = tf.range(src_len, dtype=tf.int32)
      src_seg = tf.zeros_like(src_paddings)
      tgt_pos = tf.range(tgt_len, dtype=tf.int32)
      tgt_seg = tf.zeros_like(tgt_paddings)

      return [
          src_ids,
          src_paddings,
          tgt_ids,
          tgt_paddings,
          tgt_labels,
          tgt_weights,
          src_pos,
          src_seg,
          tgt_pos,
          tgt_seg,
      ], bucket_key

    if not p.packed_input:
      processor_fn = _Proc
    else:
      processor_fn = _ProcPacked

    return generic_input.GenericInput(
        file_pattern=file_pattern,
        processor=processor_fn,
        dynamic_padding_dimensions=[0] * 10,
        dynamic_padding_constants=[0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        **self.CommonInputOpArgs())