コード例 #1
0
def _ParseProcessor(processor):
    """Traces the Python callable `processor` into a TF concrete function.

    The traced function takes `(source_id, record)`, a scalar tf.int32 and a
    scalar tf.string, and returns the flattened outputs of `processor`
    followed by a tf.int32 bucketing key.

    Args:
      processor: a Python callable that takes either a single argument (the
        record) or exactly two arguments named `source_id` and `record`. It
        must return an `(output, bucketing_key)` pair, where `output` is a
        non-empty list of tf.Tensors or a non-empty py_utils.NestedMap of
        tf.Tensors.

    Returns:
      A tuple `(proc_fn, out_types, output_tmpl)`:
        - proc_fn: the concrete function produced by tracing.
        - out_types: the list of tf.DTypes of proc_fn's outputs. The last
          entry is tf.int32, the bucketing key.
        - output_tmpl: a NestedMap whose `out_values` field holds the
          structure `processor` returned while being traced.

    Raises:
      ValueError: (during tracing) if `processor`'s argument names match
        neither supported form.
      AssertionError: if the output is malformed, if tracing captured extra
        function arguments, or if the last output dtype is not tf.int32.
    """
    template = py_utils.NestedMap()

    def _Invoke(source_id, record):
        """Calls `processor` using whichever argument form it declares."""
        argspec = tf_inspect.getargspec(processor)
        tf.logging.debug('GenericInput.processor.argspec=%s', argspec)
        # Bound methods list `self`, which is not a pipeline input.
        arg_names = set(argspec.args) - {'self'}
        if len(arg_names) == 1:
            return processor(record)
        if arg_names == {'source_id', 'record'}:
            return processor(source_id=source_id, record=record)
        raise ValueError(
            'GenericInput: processor should take either a single arg '
            'or two args named as "source_id" and "record". '
            'Actual: %s' % arg_names)

    def _Validate(output):
        """Asserts `output` is a non-empty list or NestedMap of tf.Tensors."""
        if isinstance(output, list):
            assert output
            assert all(isinstance(t, tf.Tensor)
                       for t in output), '{}'.format(output)
            return
        assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
        assert output
        assert all(isinstance(t, tf.Tensor)
                   for t in output.Flatten()), '{}'.format(
                       output.DebugString())

    @tf.function(autograph=False)
    def _FlatOutputProcessor(source_id, record):
        """Returns a flattened list of 'processor(inputs)'."""
        output, bucketing_key = _Invoke(source_id, record)
        _Validate(output)
        bucketing_key = tf.cast(bucketing_key, tf.int32)
        tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                         bucketing_key)
        # Record the output structure as a side effect of tracing, so callers
        # can later Pack() the flat op outputs back into the same shape.
        template.out_values = output
        flat_outputs = template.Flatten()
        tf.logging.debug('Processor flat outputs=%s', flat_outputs)
        tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                         py_utils.GetExtraInputs(), py_utils.GetExtraArgs(),
                         py_utils.GetExtraVars())
        # The traced function must depend only on (source_id, record).
        assert not py_utils.GetExtraArgs(), (
            'fns {} is not pure: extra_args={}'.format(
                processor, py_utils.GetExtraArgs()))
        return flat_outputs + [bucketing_key]

    # Hide global_step tensor from being captured by _FlatOutputProcessor.
    with py_utils.GlobalStepContext(None):
        proc_fn = _FlatOutputProcessor.get_concrete_function(
            tf.TensorSpec([], tf.int32), tf.TensorSpec([], tf.string))

    signature_outputs = proc_fn.function_def.signature.output_arg
    out_types = [tf.DType(arg.type) for arg in signature_outputs]
    assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
    return proc_fn, out_types, template
コード例 #2
0
class FakeTF2ImageModule(tf.train.Checkpoint):
    """Fake TF2 image module that returns all-ones features.

    It exposes the "reusable" SavedModel attributes (`variables`,
    `trainable_variables`, `regularization_losses`) that KerasLayer expects.
    It also counts how many times it is called with training=True.
    """

    def __init__(self, output_dim=768):
        # Number of calls made with training=True.
        self.counter = tf.Variable(
            initial_value=0,
            dtype=tf.int32,
            name='counter',
            use_resource=True,
        )
        # Width of the feature vector returned per image.
        self.output_dim = output_dim

        # "Reusable" SavedModel metadata read by KerasLayer.
        self.variables = [self.counter]
        self.trainable_variables = []
        self.regularization_losses = []

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, None, None, None], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.bool)
    ])
    def __call__(self, images, training=False):
        """Returns a [batch, output_dim] all-ones tensor for `images`.

        When `training` is true, `self.counter` is also incremented by one.
        """
        if training:
            self.counter.assign_add(1)
        batch_size = tf.shape(images)[0]
        return tf.ones([batch_size, self.output_dim])
コード例 #3
0
ファイル: dataset_spec_test.py プロジェクト: vcj-huy/lingvo
  def testPopulatesFeaturesInMetadata(self):
    """Schema-derived feature specs appear in metadata and the dataset."""
    # Back the split with a real file so Dataset.list_files() won't complain
    # if the test is ever switched to eager.
    train_path = self.create_tempfile().full_path
    schema = {
        'quizybuck': tf.io.FixedLenFeature([42], tf.int64),
        'x': tf.io.FixedLenSequenceFeature(
            [3], tf.float32, allow_missing=True),
    }
    spec = dataset_spec.TFRecordDatasetSpec(
        split_paths={'train': train_path}, schema=schema, label_fn=None)

    expected_features = {
        'quizybuck': tf.TensorSpec([42], tf.int64),
        'x': tf.TensorSpec([None, 3], tf.float32),
    }
    self.assertSameStructure(spec.meta.features, expected_features)

    # The metadata must also agree with the element_spec of the actual
    # `tf.data.Dataset` read for the split.
    train_element_spec = spec.Read('train').element_spec
    self.assertSameStructure(train_element_spec, expected_features)
コード例 #4
0
def GenericInput(processor, **kwargs):
    """Builds a generic input pipeline.

    Example usage::

      def ParseRecord(record):
        # Given a tf.string record, return a (NestedMap, bucketing key) pair.
        feature_map = ...
        features = tf.io.parse_single_example(record, feature_map)
        # Each example is represented by a NestedMap of tensors (without a
        # batch dimension).
        example = py_utils.NestedMap(field1=..., field2=...)
        # bucketing_key is a scalar convertible to tf.int32.
        # Use 1 if all examples are of the same size.
        bucketing_key = 1
        return example, bucketing_key

      input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
      # input_batch is a NestedMap of tensors, where dim 0 of each tensor
      # represents the batch dimension.
      input_batch.field1 = ...

    ParseRecord can also take both 'source_id' and 'record' as inputs (the arg
    names must be exactly 'source_id' and 'record'):

      def ParseRecord(source_id, record):
        # Given a tf.int32 source_id and a tf.string record, return a
        # (NestedMap, bucketing key) pair.
        example = py_utils.NestedMap(source_id=source_id, ...)
        ...
        return example, bucketing_key

      input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)

    Args:
      processor: a function that takes either a tf.string record or a
        (source_id: tf.int32, record: tf.string) pair as input and returns a
        tuple (output, bucketing_key). `output` must be a NestedMap or a list
        of tensors representing an example. `bucketing_key` must be a scalar
        convertible to a tf.int32 tensor that represents the bucketing key
        (e.g., sequence length for sequence inputs). If `bucketing_key` is a
        negative number, the record is dropped.
      **kwargs: additional keyword args for x_ops.generic_input.

    Returns:
      A tuple of (outputs, bucket_keys):

      - outputs: a NestedMap or a list of tensors, similar to `processor`'s
        return, except every tensor will have an additional dimension 0 that
        represents the batch dimension.
      - bucket_keys: a tf.int32 vector.

    Raises:
      ValueError: if `processor`'s argument names match neither supported
        form.
      AssertionError: if `processor`'s output is not a non-empty list or
        NestedMap of tf.Tensors, or if tracing it captures extra function
        arguments (the processor is not pure).
    """
    # Trace `processor` with the shared helper instead of duplicating its
    # logic here. The helper also hides the global_step tensor while tracing,
    # so the traced function does not capture it. A captured global step would
    # violate the purity check and make the function unusable by the
    # generic_input op, which feeds it only (source_id, record).
    proc_fn, out_types, output_tmpl = _ParseProcessor(processor)
    # The last traced output is the bucketing key, which the op returns
    # separately, so only the example dtypes are passed as out_types.
    flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
        processor=proc_fn, out_types=out_types[:-1], **kwargs)
    tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
    # Pack flat_outputs back into the structure `processor` returned.
    outputs = output_tmpl.Pack(flat_outputs).out_values
    tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
    return outputs, bucket_keys
コード例 #5
0
def dump_spellings():
    """Writes letter and token-id spellings for every word in a word list.

    Inputs and outputs:
      - Reads words from FLAGS.in_words_txt, lower-cased, one per line.
      - Reads the unit vocabulary from FLAGS.in_units_txt, one unit per line.
      - Tokenizes all words in a single tf.function run under a tf.Session.
      - Writes "<word> <letters>" lines to FLAGS.out_spelling_txt.
      - Writes "<word> <token ids>" lines to FLAGS.out_spelling_numbers_txt.

    Both output files end with an "<unk>" entry; its id is UNK_NUMBER.
    Intermediate values are printed to stdout with a "GALV:" prefix.
    """
    with open(FLAGS.in_words_txt, "r") as words_file:
        words = words_file.read().lower().splitlines()
    # Leave room for the <s> and (optional) </s> tokens.
    max_spelling_len = max(map(len, words)) + 2

    print("GALV:", max_spelling_len)

    with open(FLAGS.in_units_txt, "r") as units_file:
        units = [line.rstrip("\n") for line in units_file]

    print("GALV:", units)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[len(words)], dtype=tf.string)])
    def tokenize_words(words_t):
        """Maps the word batch to ragged unit ids and ragged unit strings."""
        padded_ids, _, paddings = str_to_vocab_tokens(
            labels=words_t,
            maxlen=max_spelling_len,
            append_eos=True,
            pad_to_maxlen=True,
            vocab_filepath=FLAGS.in_units_txt,
            load_token_ids_from_vocab=False,
            delimiter="",
        )
        # NOTE(review): the original author suspected that either these
        # lengths or the paddings are incorrect; this is unresolved.
        lengths = py_utils.LengthsFromPaddings(paddings)
        ragged_ids = tf.RaggedTensor.from_tensor(padded_ids, lengths=lengths)
        # Drop the start-of-sentence token from each row.
        ragged_ids = ragged_ids[:, 1:]
        lengths -= 1
        unit_strings = vocab_id_to_token(
            id=ragged_ids.flat_values,
            vocab=units,
            load_token_ids_from_vocab=False,
        )
        # NOTE(review): the original author wondered whether capitalization
        # causes a problem here; this is unverified.
        ragged_strings = tf.RaggedTensor.from_row_lengths(unit_strings, lengths)
        return ragged_ids, ragged_strings

    with tf.Session() as session:
        ids_value, letters_value = session.run(tokenize_words(words))
    spelling_numbers = ids_value.to_list()
    spelling_letters = letters_value.to_list()

    with open(FLAGS.out_spelling_txt, "w") as letters_out:
        with open(FLAGS.out_spelling_numbers_txt, "w") as numbers_out:
            rows = zip(words, spelling_numbers, spelling_letters)
            for word, numbers, letters in rows:
                # NOTE(review): for list-valued `letters`, the characters of
                # `word` are written rather than the decoded tokens. Confirm
                # this is intended.
                letters_str = " ".join(word) if isinstance(letters, list) \
                    else letters
                numbers_str = " ".join(map(str, numbers))
                letters_out.write(f"{word} {letters_str}\n")
                numbers_out.write(f"{word} {numbers_str}\n")
            letters_out.write("<unk> <unk>\n")
            numbers_out.write(f"<unk> {UNK_NUMBER}\n")