def _ParseProcessor(processor):
  """Parses python callable `processor` into a TF concrete function."""
  output_tmpl = py_utils.NestedMap()

  @tf.function(autograph=False)
  def _FlatOutputProcessor(source_id, record):
    """Returns a flattened list of 'processor(inputs)'."""
    processor_spec = tf_inspect.getargspec(processor)
    tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
    processor_args = set(processor_spec.args) - set(['self'])
    if len(processor_args) == 1:
      output, bucketing_key = processor(record)
    elif processor_args == set(['source_id', 'record']):
      output, bucketing_key = processor(source_id=source_id, record=record)
    else:
      raise ValueError(
          'GenericInput: processor should take either a single arg '
          'or two args named as "source_id" and "record". '
          'Actual: %s' % processor_args)
    if isinstance(output, list):
      assert output
      assert all(isinstance(x, tf.Tensor) for x in output), '{}'.format(output)
    else:
      assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
      assert output
      assert all(
          isinstance(x, tf.Tensor) for x in output.Flatten()), '{}'.format(
              output.DebugString())
    bucketing_key = tf.cast(bucketing_key, tf.int32)
    tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                     bucketing_key)
    output_tmpl.out_values = output
    flat_output_tmpl = output_tmpl.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     py_utils.GetExtraInputs(), py_utils.GetExtraArgs(),
                     py_utils.GetExtraVars())
    assert not py_utils.GetExtraArgs(), (
        'fns {} is not pure: extra_args={}'.format(processor,
                                                   py_utils.GetExtraArgs()))
    return flat_output_tmpl + [bucketing_key]

  with py_utils.GlobalStepContext(None):
    # Hide global_step tensor from being captured by _FlatOutputProcessor.
    proc_fn = _FlatOutputProcessor.get_concrete_function(
        tf.TensorSpec([], tf.int32), tf.TensorSpec([], tf.string))

  out_types = [
      tf.DType(a.type) for a in proc_fn.function_def.signature.output_arg
  ]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  return proc_fn, out_types, output_tmpl
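
# --- Example (illustrative sketch, not library code): how _ParseProcessor
# might be exercised. Assumes this file's ambient imports (py_utils and the
# lingvo tf compat module); `MyProcessor` is a hypothetical name.
def MyProcessor(record):
  # Wrap the raw tf.string record into a one-field example and bucket all
  # records together with a constant key.
  return py_utils.NestedMap(raw=record), 1

# proc_fn, out_types, _ = _ParseProcessor(MyProcessor)
# proc_fn is a concrete function taking scalar (source_id, record) tensors;
# out_types is [tf.string, tf.int32], where the trailing tf.int32 is the
# bucketing key.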
def GenericInput(processor, *args, **kwargs):
  """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  Args:
    processor: a function that takes a string record as input and returns a
      tuple (output, bucketing_key). `output` must be a NestedMap or a list of
      tensors representing one example. The `bucketing_key` must be a scalar
      convertible to a tf.int32 tensor that represents the bucketing key
      (e.g., sequence length for sequence inputs).
    *args: additional args for x_ops.generic_input.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension.
    - bucket_keys: a tf.int32 vector.
  """
  output_tmpl = py_utils.NestedMap()

  def _FlatOutputProcessor(inputs):
    """Returns a flattened list of 'processor(inputs)'."""
    output, bucketing_key = processor(inputs)
    if isinstance(output, list):
      assert output
      assert all(isinstance(x, tf.Tensor) for x in output), '{}'.format(output)
    else:
      assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
      assert output
      assert all(
          isinstance(x, tf.Tensor) for x in output.Flatten()), '{}'.format(
              output.DebugString())
    bucketing_key = tf.to_int32(bucketing_key)
    tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                     bucketing_key)
    output_tmpl.values = output
    flat_output_tmpl = output_tmpl.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     function.get_extra_inputs(), function.get_extra_args(),
                     function.get_extra_vars())
    assert not function.get_extra_args(), (
        'fns {} is not pure: extra_args={}'.format(processor,
                                                   function.get_extra_args()))
    return flat_output_tmpl + [bucketing_key]

  proc_fn = function.Defun(tf.string)(_FlatOutputProcessor)

  out_types = [
      tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
  ]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
      processor=proc_fn, out_types=out_types[:-1], *args, **kwargs)
  tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
  # Pack flat_outputs to outputs.
  outputs = output_tmpl.Pack(flat_outputs).values
  tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
  return outputs, bucket_keys
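
# --- Example (illustrative sketch, not library code): as documented above,
# `processor` may return a plain list of tensors instead of a NestedMap.
# `ParseRecordAsList` is a hypothetical name; the GenericInput kwargs are
# elided as in the docstring.
def ParseRecordAsList(record):
  # One int32 scalar feature per example; a constant bucketing key puts all
  # examples in a single bucket.
  length = tf.strings.length(record)
  return [length], 1

# outputs, bucket_keys = GenericInput(ParseRecordAsList, file_pattern=..., ...)
# outputs is a list with one tensor of shape [batch]; bucket_keys is a
# tf.int32 vector.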
def GenericInput(processor, **kwargs):
  """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  ParseRecord can also take both 'source_id' and 'record' as inputs (the arg
  names must be exactly 'source_id' and 'record'):

    def ParseRecord(source_id, record):
      # Given a tf.int32 source_id and a tf.string record, return a
      # (NestedMap, bucketing key) pair.
      example = py_utils.NestedMap(source_id=source_id, ...)
      ...
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)

  Args:
    processor: a function that takes either a tf.string record or a
      (source_id: tf.int32, record: tf.string) pair as input and returns a
      tuple (output, bucketing_key). `output` must be a NestedMap or a list of
      tensors representing an example. `bucketing_key` must be a scalar
      convertible to a tf.int32 tensor that represents the bucketing key
      (e.g., sequence length for sequence inputs). If `bucketing_key` is a
      negative number, the record is dropped.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension.
    - bucket_keys: a tf.int32 vector.
  """
  output_tmpl = py_utils.NestedMap()

  def _FlatOutputProcessor(source_id, record):
    """Returns a flattened list of 'processor(inputs)'."""
    processor_spec = tf_inspect.getargspec(processor)
    tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
    processor_args = set(processor_spec.args) - set(['self'])
    if len(processor_args) == 1:
      output, bucketing_key = processor(record)
    elif processor_args == set(['source_id', 'record']):
      output, bucketing_key = processor(source_id=source_id, record=record)
    else:
      raise ValueError(
          'GenericInput: processor should take either a single arg '
          'or two args named as "source_id" and "record". '
          'Actual: %s' % processor_args)
    if isinstance(output, list):
      assert output
      assert all(isinstance(x, tf.Tensor) for x in output), '{}'.format(output)
    else:
      assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
      assert output
      assert all(
          isinstance(x, tf.Tensor) for x in output.Flatten()), '{}'.format(
              output.DebugString())
    bucketing_key = tf.to_int32(bucketing_key)
    tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                     bucketing_key)
    output_tmpl.out_values = output
    flat_output_tmpl = output_tmpl.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     function.get_extra_inputs(), function.get_extra_args(),
                     function.get_extra_vars())
    assert not function.get_extra_args(), (
        'fns {} is not pure: extra_args={}'.format(processor,
                                                   function.get_extra_args()))
    return flat_output_tmpl + [bucketing_key]

  proc_fn = tf.Defun(tf.int32, tf.string)(_FlatOutputProcessor)

  out_types = [
      tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
  ]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
      processor=proc_fn, out_types=out_types[:-1], **kwargs)
  tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
  # Pack flat_outputs to outputs.
  outputs = output_tmpl.Pack(flat_outputs).out_values
  tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
  return outputs, bucket_keys
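
# --- Example (illustrative sketch, not library code): a hypothetical two-arg
# processor that records which source each example came from and drops empty
# records via a negative bucketing key, as described in the docstring above.
def ParseWithSource(source_id, record):
  example = py_utils.NestedMap(source_id=source_id, raw=record)
  length = tf.strings.length(record)
  # Records with length 0 get bucketing key -1 and are dropped.
  bucketing_key = tf.where(length > 0, length, tf.constant(-1))
  return example, bucketing_key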
def GenericInput(processor, *args, **kwargs):
  """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is an int scalar tensor.
      # Use 1 if all examples are of the same size.
      bucketing_key = tf.to_int32(1)
      return example, bucketing_key

    input_batch = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  Args:
    processor: a function that takes a string record as input and returns a
      list of tensors or NestedMaps representing one example. The last return
      value of processor must be an int32 scalar tensor that represents the
      bucketing key (e.g., sequence length for sequence inputs).
    *args: additional args for x_ops.generic_input.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A list of tensors or NestedMaps, similar to `processor`'s return, except:

    * The bucket key is not included in the output.
    * Every tensor will have an additional dimension 0 that represents the
      batch dimension.
  """
  output_tmpl = py_utils.NestedMap()

  def _FlatOutputProcessor(inputs):
    """Returns a flattened list of 'processor(inputs)'."""
    outputs = processor(inputs)
    tf.logging.debug('Processor outputs=%s', outputs)
    assert len(outputs) > 1, outputs
    # Add 'outputs' as a list so that each element will be flattened.
    output_tmpl.values = list(outputs)
    flat_outputs = output_tmpl.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_outputs)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     function.get_extra_inputs(), function.get_extra_args(),
                     function.get_extra_vars())
    assert not function.get_extra_args(), (
        'fns {} is not pure: extra_args={}'.format(processor,
                                                   function.get_extra_args()))
    return flat_outputs

  proc_fn = function.Defun(tf.string)(_FlatOutputProcessor)

  out_types = [
      tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
  ]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  flat_outputs = ops.gen_x_ops.generic_input(
      processor=proc_fn, out_types=out_types[:-1], *args, **kwargs)
  tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
  if not output_tmpl:
    return flat_outputs
  # Pack flat_outputs to outputs.
  output_tmpl.values.pop(-1)
  outputs = output_tmpl.Pack(flat_outputs).values
  tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
  return outputs
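
# --- Example (illustrative sketch, not library code): under this older
# calling convention the bucketing key is the *last* element of the returned
# list and is stripped from the batched output. `ParseRecordV1` is a
# hypothetical name.
def ParseRecordV1(record):
  length = tf.strings.length(record)
  # [feature..., bucketing_key]: only `length` shows up in the output batch.
  return [length, tf.to_int32(1)]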