Example #1
def _ParseProcessor(processor):
    """Parses python callable `processor` into a TF concrete function."""
    output_tmpl = py_utils.NestedMap()

    @tf.function(autograph=False)
    def _FlatOutputProcessor(source_id, record):
        """Returns a flattened list of 'processor(inputs)'."""
        processor_spec = tf_inspect.getargspec(processor)
        tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
        processor_args = set(processor_spec.args) - set(['self'])
        if len(processor_args) == 1:
            output, bucketing_key = processor(record)
        elif processor_args == set(['source_id', 'record']):
            output, bucketing_key = processor(source_id=source_id,
                                              record=record)
        else:
            raise ValueError(
                'GenericInput: processor should take either a single arg '
                'or two args named as "source_id" and "record". '
                'Actual: %s' % processor_args)
        if isinstance(output, list):
            assert output
            assert all(isinstance(x, tf.Tensor)
                       for x in output), '{}'.format(output)
        else:
            assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
            assert output
            assert all(isinstance(x, tf.Tensor)
                       for x in output.Flatten()), '{}'.format(
                           output.DebugString())
        bucketing_key = tf.cast(bucketing_key, tf.int32)
        tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                         bucketing_key)
        output_tmpl.out_values = output
        flat_output_tmpl = output_tmpl.Flatten()
        tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
        tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                         py_utils.GetExtraInputs(), py_utils.GetExtraArgs(),
                         py_utils.GetExtraVars())
        assert not py_utils.GetExtraArgs(), (
            'fns {} is not pure: extra_args={}'.format(
                processor, py_utils.GetExtraArgs()))
        return flat_output_tmpl + [bucketing_key]

    with py_utils.GlobalStepContext(None):
        # Hide global_step tensor from being captured by _FlatOutputProcessor.
        proc_fn = _FlatOutputProcessor.get_concrete_function(
            tf.TensorSpec([], tf.int32), tf.TensorSpec([], tf.string))

    out_types = [
        tf.DType(a.type) for a in proc_fn.function_def.signature.output_arg
    ]
    assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
    return proc_fn, out_types, output_tmpl
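
For context, a minimal usage sketch for `_ParseProcessor`. The processor below is hypothetical, and the sketch assumes the module's own imports (`tf` and lingvo's `py_utils`) are in scope.

def _SampleProcessor(record):
    """Hypothetical processor: wraps the raw record bytes in a NestedMap."""
    example = py_utils.NestedMap(data=tf.io.decode_raw(record, tf.uint8))
    # All examples share a single bucket.
    return example, tf.constant(1, tf.int32)

proc_fn, out_types, output_tmpl = _ParseProcessor(_SampleProcessor)
# proc_fn is a ConcreteFunction over (source_id, record); out_types lists the
# flattened output dtypes, ending with the int32 bucketing key.
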
Example #2
def GenericInput(processor, *args, **kwargs):
    """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  Args:
    processor: a function that takes a string record as input and returns a
      tuple (output, bucketing_key). `output` must be a NestedMap or a list of
      tensors representing one example. The `bucketing_key` must be a scalar
      convertible to a tf.int32 tensor that represents the bucketing key (e.g.,
      sequence length for sequence inputs).
    *args: additional args for x_ops.generic_input.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):
    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return,  except every tensor will have an additional dimension 0 that
      represents the batch dimension.
    - bucket_keys: a tf.int32 vector.
  """
    output_tmpl = py_utils.NestedMap()

    def _FlatOutputProcessor(inputs):
        """Returns a flattened list of 'processor(inputs)'."""
        output, bucketing_key = processor(inputs)
        if isinstance(output, list):
            assert output
            assert all(isinstance(x, tf.Tensor)
                       for x in output), '{}'.format(output)
        else:
            assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
            assert output
            assert all(isinstance(x, tf.Tensor)
                       for x in output.Flatten()), '{}'.format(
                           output.DebugString())
        bucketing_key = tf.to_int32(bucketing_key)
        tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                         bucketing_key)
        output_tmpl.values = output
        flat_output_tmpl = output_tmpl.Flatten()
        tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
        tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                         function.get_extra_inputs(),
                         function.get_extra_args(), function.get_extra_vars())
        assert not function.get_extra_args(), (
            'fns {} is not pure: extra_args={}'.format(
                processor, function.get_extra_args()))
        return flat_output_tmpl + [bucketing_key]

    proc_fn = function.Defun(tf.string)(_FlatOutputProcessor)

    out_types = [
        tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
    ]
    assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
    flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
        processor=proc_fn, out_types=out_types[:-1], *args, **kwargs)
    tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
    # Pack flat_outputs to outputs.
    outputs = output_tmpl.Pack(flat_outputs).values
    tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
    return outputs, bucket_keys
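
A hedged usage sketch for this variant. The feature spec and file pattern are hypothetical; `bucket_upper_bound` and `bucket_batch_limit` are bucketing kwargs forwarded to x_ops.generic_input, shown here with illustrative values.

def ParseRecord(record):
    features = tf.parse_single_example(
        record, {'x': tf.FixedLenFeature([], tf.float32)})
    example = py_utils.NestedMap(x=features['x'])
    return example, 1  # Single bucket: all examples count as the same size.

input_batch, bucket_keys = GenericInput(
    ParseRecord,
    file_pattern='tfrecord:/tmp/data.tfrecord*',  # hypothetical pattern
    bucket_upper_bound=[1],
    bucket_batch_limit=[8])
# input_batch.x has shape [batch]; bucket_keys is an int32 vector of 1s.
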
Example #3
def GenericInput(processor, **kwargs):
  """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  ParseRecord can also take both 'source_id' and 'record' as inputs (the arg
  names must be exactly 'source_id' and 'record'):

    def ParseRecord(source_id, record):
      # Given a tf.int32 source_id and a tf.string record, return a (NestedMap,
      # bucketing key) pair.
      example = py_utils.NestedMap(source_id=source_id, ...)
      ...
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)

  Args:
    processor: a function that takes either a tf.string record or a
      (source_id: tf.int32, record: tf.string) pair as input and returns a
      tuple (output, bucketing_key).
      `output` must be a NestedMap or a list of tensors representing an example.
      `bucketing_key` must be a scalar convertible to a tf.int32 tensor that
      represents the bucketing key (e.g., sequence length for sequence inputs).
      If `bucketing_key` is a negative number, the record is dropped.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension.
    - bucket_keys: a tf.int32 vector.
  """
  output_tmpl = py_utils.NestedMap()

  def _FlatOutputProcessor(source_id, record):
    """Returns a flattened list of 'processor(inputs)'."""
    processor_spec = tf_inspect.getargspec(processor)
    tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
    processor_args = set(processor_spec.args) - set(['self'])
    if len(processor_args) == 1:
      output, bucketing_key = processor(record)
    elif processor_args == set(['source_id', 'record']):
      output, bucketing_key = processor(source_id=source_id, record=record)
    else:
      raise ValueError(
          'GenericInput: processor should take either a single arg '
          'or two args named as "source_id" and "record". '
          'Actual: %s' % processor_args)
    if isinstance(output, list):
      assert output
      assert all(isinstance(x, tf.Tensor) for x in output), '{}'.format(output)
    else:
      assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
      assert output
      assert all(
          isinstance(x, tf.Tensor) for x in output.Flatten()), '{}'.format(
              output.DebugString())
    bucketing_key = tf.to_int32(bucketing_key)
    tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                     bucketing_key)
    output_tmpl.out_values = output
    flat_output_tmpl = output_tmpl.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     function.get_extra_inputs(), function.get_extra_args(),
                     function.get_extra_vars())
    assert not function.get_extra_args(), (
        'fns {} is not pure: extra_args={}'.format(processor,
                                                   function.get_extra_args()))
    return flat_output_tmpl + [bucketing_key]

  proc_fn = tf.Defun(tf.int32, tf.string)(_FlatOutputProcessor)

  out_types = [
      tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
  ]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
      processor=proc_fn, out_types=out_types[:-1], **kwargs)
  tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
  # Pack flat_outputs to outputs.
  outputs = output_tmpl.Pack(flat_outputs).out_values
  tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
  return outputs, bucket_keys
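
The docstring above calls out two behaviors specific to this variant: the processor may take (source_id, record), and a negative bucketing key drops the record. A hedged sketch combining both; the feature names, length threshold, and bucketing kwargs are hypothetical.

def ParseRecord(source_id, record):
  features = tf.parse_single_example(
      record, {
          'ids': tf.FixedLenFeature([16], tf.int64),   # pre-padded ids
          'length': tf.FixedLenFeature([], tf.int64),  # true sequence length
      })
  example = py_utils.NestedMap(source_id=source_id, ids=features['ids'])
  seq_len = tf.to_int32(features['length'])
  # Returning a negative key drops the record: filter out empty sequences
  # and bucket the rest by their true length.
  bucketing_key = tf.cond(seq_len < 1, lambda: tf.constant(-1),
                          lambda: seq_len)
  return example, bucketing_key

input_batch, bucket_keys = GenericInput(
    ParseRecord,
    file_pattern='tfrecord:/tmp/data.tfrecord*',  # hypothetical pattern
    bucket_upper_bound=[8, 16],
    bucket_batch_limit=[16, 8])
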
Example #4
def GenericInput(processor, *args, **kwargs):
    """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is an int scalar tensor.
      # Use 1 if all examples are of the same size.
      bucketing_key = tf.to_int32(1)
      return example, bucketing_key

    input_batch = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  Args:
    processor: a function that takes a string record as input and returns a list
      of tensors or NestedMaps representing one example. The last return value
      of processor must be an int32 scalar tensor that represents the bucketing
      key (e.g., sequence length for sequence inputs).
    *args: additional args for x_ops.generic_input.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A list of tensors or NestedMaps, similar `processor`'s return, except:
      * The bucket key is not included in the output.
      * Every tensor will have an additional dimension 0 that represents the
        batch dimension.
  """
    output_tmpl = py_utils.NestedMap()

    def _FlatOutputProcessor(inputs):
        """Returns a flattened list of 'processor(inputs)'."""
        outputs = processor(inputs)
        tf.logging.debug('Processor outputs=%s', outputs)
        assert len(outputs) > 1, outputs
        # Add 'outputs' as a list so that each element will be flattened.
        output_tmpl.values = list(outputs)
        flat_outputs = output_tmpl.Flatten()
        tf.logging.debug('Processor flat outputs=%s', flat_outputs)
        tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                         function.get_extra_inputs(),
                         function.get_extra_args(), function.get_extra_vars())
        assert not function.get_extra_args(), (
            'fns {} is not pure: extra_args={}'.format(
                processor, function.get_extra_args()))
        return flat_outputs

    proc_fn = function.Defun(tf.string)(_FlatOutputProcessor)

    out_types = [
        tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
    ]
    assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
    flat_outputs = ops.gen_x_ops.generic_input(processor=proc_fn,
                                               out_types=out_types[:-1],
                                               *args,
                                               **kwargs)
    tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
    if not output_tmpl:
        return flat_outputs
    # Pack flat_outputs to outputs.
    output_tmpl.values.pop(-1)
    outputs = output_tmpl.Pack(flat_outputs).values
    tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
    return outputs
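
In this older variant the processor returns a single list whose last element is the int32 bucketing key, and GenericInput returns only the batched outputs, with the key stripped. A hedged usage sketch; the feature spec, file pattern, and bucketing kwargs are hypothetical.

def ParseRecord(record):
    features = tf.parse_single_example(
        record, {'x': tf.FixedLenFeature([], tf.float32)})
    # The bucketing key must be the last element and an int32 scalar.
    return [features['x'], tf.to_int32(1)]

outputs = GenericInput(
    ParseRecord,
    file_pattern='tfrecord:/tmp/data.tfrecord*',  # hypothetical pattern
    bucket_upper_bound=[1],
    bucket_batch_limit=[8])
# outputs is a list containing one tensor of shape [batch].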