def _trace_preprocessing_fn_v2(preprocessing_fn, specs, base_temp_dir):
  """Trace TF2 graph for `preprocessing_fn`.

  The transform function is traced with a fresh `tf.Module` as the export
  module and no evaluated analyzer replacements, inside an object-tracker
  scope.

  Args:
    preprocessing_fn: The user preprocessing function to trace.
    specs: Input signature forwarded to `get_traced_transform_fn`.
    base_temp_dir: Temporary directory stored on the `TFGraphContext`.

  Returns:
    A 3-tuple: the concrete function's graph, the structured inputs recovered
    from that graph, and the concrete function's structured outputs.
  """
  context = graph_context.TFGraphContext(
      module_to_export=tf.Module(),
      temp_dir=base_temp_dir,
      evaluated_replacements=None)
  tracker = annotators.ObjectTracker()
  with annotators.object_tracker_scope(tracker):
    traced_fn = get_traced_transform_fn(preprocessing_fn, specs, context)
    concrete_fn = traced_fn.get_concrete_function()
  func_graph = concrete_fn.graph
  structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
      func_graph)
  return func_graph, structured_inputs, concrete_fn.structured_outputs
def _get_schema(self,
                preprocessing_fn,
                use_compat_v1,
                inputs=None,
                input_signature=None,
                create_session=False):
  """Infer the output schema produced by `preprocessing_fn`.

  Args:
    preprocessing_fn: Function under test.
    use_compat_v1: If True, run `preprocessing_fn` in a fresh
      `tf.compat.v1.Graph` on constants built from `inputs`. Otherwise trace it
      with `tf.function` using `input_signature`.
    inputs: Mapping of feature name to value. Only used on the compat-v1 path,
      where each value is converted with the dtype of the matching
      `input_signature` entry. Defaults to `{}`.
    input_signature: Mapping of feature name to spec. Defaults to `{}`.
    create_session: If True, schema overrides/annotations are evaluated. On the
      compat-v1 path this means a session is passed to the inference call. On
      the TF2 path it is forwarded as `evaluate_schema_overrides`.

  Returns:
    The schema returned by `schema_inference.infer_feature_schema` (compat-v1)
    or `schema_inference.infer_feature_schema_v2` (TF2).
  """
  inputs = {} if inputs is None else inputs
  input_signature = {} if input_signature is None else input_signature

  if not use_compat_v1:
    concrete_fn = tf.function(
        preprocessing_fn,
        input_signature=[input_signature]).get_concrete_function()
    output_tensors = tf.nest.pack_sequence_as(
        structure=concrete_fn.structured_outputs,
        flat_sequence=concrete_fn.outputs,
        expand_composites=True)
    structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
        concrete_fn.graph)
    context = graph_context.TFGraphContext(
        module_to_export=tf.Module(),
        temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName),
        evaluated_replacements={})
    metadata_fn = schema_inference.get_traced_metadata_fn(
        preprocessing_fn=preprocessing_fn,
        structured_inputs=structured_inputs,
        tf_graph_context=context,
        evaluate_schema_overrides=create_session)
    return schema_inference.infer_feature_schema_v2(
        output_tensors, metadata_fn, evaluate_schema_overrides=create_session)

  with tf.compat.v1.Graph().as_default() as graph:
    # Eager values are turned into graph constants so the graph-mode
    # preprocessing_fn can consume them.
    graph_inputs = {
        name: tf.constant(value, input_signature[name].dtype)
        for name, value in inputs.items()
    }
    output_tensors = preprocessing_fn(graph_inputs)
    if not create_session:
      return schema_inference.infer_feature_schema(output_tensors, graph)
    # A session lets the annotations actually be evaluated, so the returned
    # schema has them applied.
    with tf.compat.v1.Session(graph=graph) as session:
      return schema_inference.infer_feature_schema(
          output_tensors, graph, session)
def _create_test_saved_model(export_in_tf1,
                             input_specs,
                             preprocessing_fn,
                             export_path_suffix=None,
                             base_dir=None):
  """Export `preprocessing_fn` as a transform SavedModel for tests.

  Args:
    export_in_tf1: If True, build the graph from TF1 placeholders and write the
      model from a `tf.compat.v1.Session`. Otherwise trace a TF2
      `transform_fn` and write it with `write_v2_saved_model`.
    input_specs: Mapping of feature name to a `tf.TensorSpec`,
      `tf.SparseTensorSpec` or `tf.RaggedTensorSpec`.
    preprocessing_fn: Function mapping a dict of inputs to a dict of outputs.
    export_path_suffix: Name of the export directory created inside a fresh
      temporary directory. Falls back to 'export' when None or empty.
    base_dir: Parent directory for the temporary directory (passed as `dir` to
      `tempfile.mkdtemp`).

  Returns:
    The path the SavedModel was written to.

  Raises:
    ValueError: If `export_in_tf1` is True and a spec in `input_specs` is not
      one of the supported TypeSpec classes.
  """
  export_path = os.path.join(
      tempfile.mkdtemp(dir=base_dir), export_path_suffix or 'export')
  if export_in_tf1:
    with tf.compat.v1.Graph().as_default():
      with tf.compat.v1.Session().as_default() as session:
        inputs = {}
        for key, tensor_spec in input_specs.items():
          if isinstance(tensor_spec, tf.TensorSpec):
            inputs[key] = tf.compat.v1.placeholder(
                tensor_spec.dtype, shape=tensor_spec.shape)
          elif isinstance(tensor_spec, tf.SparseTensorSpec):
            inputs[key] = tf.compat.v1.sparse_placeholder(
                tensor_spec.dtype, shape=tensor_spec.shape)
          elif isinstance(tensor_spec, tf.RaggedTensorSpec):
            # Use the public spec properties rather than private attributes.
            # `[]` is passed as the value_shape argument.
            inputs[key] = tf.compat.v1.ragged.placeholder(
                tensor_spec.dtype, tensor_spec.ragged_rank, [])
          else:
            raise ValueError(
                'TypeSpecs specified should be one of `tf.TensorSpec`, '
                '`tf.SparseTensorSpec`, `tf.RaggedTensorSpec`; got {} for '
                'key {!r}'.format(type(tensor_spec).__name__, key))
        outputs = preprocessing_fn(inputs)
        # show that unrelated & unmapped placeholders do not interfere
        tf.compat.v1.placeholder(tf.int64)
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)
  else:
    module = tf.Module()
    tf_graph_context = graph_context.TFGraphContext(
        module_to_export=module, temp_dir=None, evaluated_replacements=None)
    transform_fn = impl_helper.get_traced_transform_fn(
        preprocessing_fn=preprocessing_fn,
        input_signature=input_specs,
        tf_graph_context=tf_graph_context,
        output_keys_to_name_map=None)
    saved_transform_io_v2.write_v2_saved_model(module, transform_fn,
                                               'transform_fn', export_path)
  return export_path
def transform_fn(inputs):
  """Run `preprocessing_fn` on `inputs`, optionally selecting named outputs.

  Relies on `base_temp_dir`, `tensor_replacement_map`, `preprocessing_fn` and
  `output_keys_to_name_map` from the enclosing scope.

  Returns:
    The dict returned by `preprocessing_fn`, unless `output_keys_to_name_map`
    is set and the default graph's `TENSOR_REPLACEMENTS` collection is
    non-empty. In that case, a dict mapping each key of
    `output_keys_to_name_map` to the graph tensor with the corresponding name.
  """
  graph = ops.get_default_graph()
  # Already-evaluated analyzers are handed over through the
  # `graph_context.TFGraphContext`, where they stand in for the analyzer nodes.
  with graph_context.TFGraphContext(
      temp_dir=base_temp_dir, evaluated_replacements=tensor_replacement_map):
    transformed_features = preprocessing_fn(inputs)
  # An empty `TENSOR_REPLACEMENTS` collection means no analyzer is left for
  # Transform to evaluate. Specific outputs are only looked up when they were
  # requested and analyzers remain; otherwise the outputs of `preprocessing_fn`
  # are returned as-is.
  if (output_keys_to_name_map is not None and
      graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)):
    return {
        output_key: graph.get_tensor_by_name(tensor_name)
        for output_key, tensor_name in output_keys_to_name_map.items()
    }
  return transformed_features
def metadata_fn(inputs):
  """Run `preprocessing_fn` on a copy of `inputs` and gather schema overrides.

  Relies on `base_temp_dir`, `tensor_replacement_map`, `preprocessing_fn` and
  `evaluate_schema_overrides` from the enclosing scope.

  Returns:
    A `collections.defaultdict(list)`.
    - If `evaluate_schema_overrides` is False, it holds at most one entry.
      Under `_TF_METADATA_TENSOR_COLLECTION` it lists the feature keys of
      output tensors that appear in that graph collection. The entry is only
      present when at least one tensor matches.
    - Otherwise, it is updated with the `_get_schema_overrides` results for
      tensor min/max ranges, followed by those for extra annotations.
  """
  graph = ops.get_default_graph()
  # A user `preprocessing_fn` may modify its inputs in place, which is not
  # allowed inside a tf.function, so it is given a copy.
  copied_inputs = tf_utils.copy_tensors(inputs)
  with graph_context.TFGraphContext(
      temp_dir=base_temp_dir, evaluated_replacements=tensor_replacement_map):
    outputs = preprocessing_fn(copied_inputs)
  # Maps tensor value names back to the feature keys that produced them.
  tensor_name_to_key = _get_tensor_value_to_key_map(outputs)
  overrides = collections.defaultdict(list)

  if evaluate_schema_overrides:
    # Schema overrides for feature tensor ranges.
    range_overrides = _get_schema_overrides(
        graph, tensor_name_to_key, _TF_METADATA_TENSOR_COLLECTION,
        [_TF_METADATA_TENSOR_MIN_COLLECTION, _TF_METADATA_TENSOR_MAX_COLLECTION])
    overrides.update(range_overrides)
    # Schema overrides for feature protos. An annotation with no feature tensor
    # in `_TF_METADATA_EXTRA_ANNOTATION` is attributed to
    # `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL`, i.e. the global schema.
    annotation_overrides = _get_schema_overrides(
        graph, tensor_name_to_key, _TF_METADATA_EXTRA_ANNOTATION, [
            _TF_METADATA_EXTRA_ANNOTATION_TYPE_URL,
            _TF_METADATA_EXTRA_ANNOTATION_PROTO
        ], _TF_METADATA_EXTRA_ANNOTATION_GLOBAL)
    overrides.update(annotation_overrides)
  else:
    annotated_keys = [
        tensor_name_to_key[tensor.name]
        for tensor in graph.get_collection(_TF_METADATA_TENSOR_COLLECTION)
        if tensor.name in tensor_name_to_key
    ]
    # Only add the entry when at least one output tensor matched.
    if annotated_keys:
      overrides[_TF_METADATA_TENSOR_COLLECTION] = annotated_keys
  return overrides
def _trace_and_write_transform_fn(
    saved_model_dir: str,
    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                               Mapping[str, common_types.TensorType]],
    input_signature: Mapping[str, tf.TypeSpec],
    base_temp_dir: Optional[str],
    tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
    output_keys_to_name_map: Optional[Dict[str, str]]
) -> function.ConcreteFunction:
  """Trace `preprocessing_fn` and serialize to a SavedModel.

  Args:
    saved_model_dir: Directory the SavedModel is written to.
    preprocessing_fn: The user preprocessing function.
    input_signature: Input specs used for tracing.
    base_temp_dir: Temporary directory stored on the graph context.
    tensor_replacement_map: Already-evaluated analyzer replacements, if any.
    output_keys_to_name_map: Optional mapping forwarded to
      `get_traced_transform_fn` to select named outputs.

  Returns:
    Whatever `saved_transform_io_v2.write_v2_saved_model` returns.
  """
  context = graph_context.TFGraphContext(
      module_to_export=tf.Module(),
      temp_dir=base_temp_dir,
      evaluated_replacements=tensor_replacement_map)
  traced_transform_fn = get_traced_transform_fn(
      preprocessing_fn=preprocessing_fn,
      input_signature=input_signature,
      tf_graph_context=context,
      output_keys_to_name_map=output_keys_to_name_map)
  return saved_transform_io_v2.write_v2_saved_model(
      context.module_to_export, traced_transform_fn, 'transform_fn',
      saved_model_dir)
def metadata_fn():
  """Run `preprocessing_fn` on stand-in inputs and gather schema overrides.

  Inputs come from `tf2_utils.supply_missing_inputs(structured_inputs,
  batch_size=1)`. The function relies on `structured_inputs`, `base_temp_dir`,
  `tensor_replacement_map`, `preprocessing_fn` and `evaluate_schema_overrides`
  from the enclosing scope.

  Returns:
    A `collections.defaultdict(list)`.
    - If `evaluate_schema_overrides` is False, it holds at most one entry.
      Under `_TF_METADATA_TENSOR_COLLECTION` it lists the feature keys of
      output tensors that appear in that graph collection. The entry is only
      present when at least one tensor matches.
    - Otherwise, it is updated with the `_get_schema_overrides` results for
      tensor min/max ranges, followed by those for extra annotations.
  """
  graph = ops.get_default_graph()
  stand_in_inputs = tf2_utils.supply_missing_inputs(
      structured_inputs, batch_size=1)
  with graph_context.TFGraphContext(
      temp_dir=base_temp_dir, evaluated_replacements=tensor_replacement_map):
    outputs = preprocessing_fn(stand_in_inputs)
  # Maps tensor value names back to the feature keys that produced them.
  tensor_name_to_key = _get_tensor_value_to_key_map(outputs)
  overrides = collections.defaultdict(list)

  if evaluate_schema_overrides:
    # Schema overrides for feature tensor ranges.
    range_overrides = _get_schema_overrides(
        graph, tensor_name_to_key, _TF_METADATA_TENSOR_COLLECTION,
        [_TF_METADATA_TENSOR_MIN_COLLECTION, _TF_METADATA_TENSOR_MAX_COLLECTION])
    overrides.update(range_overrides)
    # Schema overrides for feature protos. An annotation with no feature tensor
    # in `_TF_METADATA_EXTRA_ANNOTATION` is attributed to
    # `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL`, i.e. the global schema.
    annotation_overrides = _get_schema_overrides(
        graph, tensor_name_to_key, _TF_METADATA_EXTRA_ANNOTATION, [
            _TF_METADATA_EXTRA_ANNOTATION_TYPE_URL,
            _TF_METADATA_EXTRA_ANNOTATION_PROTO
        ], _TF_METADATA_EXTRA_ANNOTATION_GLOBAL)
    overrides.update(annotation_overrides)
  else:
    annotated_keys = [
        tensor_name_to_key[tensor.name]
        for tensor in graph.get_collection(_TF_METADATA_TENSOR_COLLECTION)
        if tensor.name in tensor_name_to_key
    ]
    # Only add the entry when at least one output tensor matched.
    if annotated_keys:
      overrides[_TF_METADATA_TENSOR_COLLECTION] = annotated_keys
  return overrides
def _trace_and_get_metadata(
    concrete_transform_fn: function.ConcreteFunction,
    structured_inputs: Mapping[str, common_types.TensorType],
    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                               Mapping[str, common_types.TensorType]],
    base_temp_dir: Optional[str],
    tensor_replacement_map: Optional[Dict[str, tf.Tensor]]
) -> dataset_metadata.DatasetMetadata:
  """Compute and return metadata for the outputs of `concrete_transform_fn`.

  Args:
    concrete_transform_fn: Traced transform function whose
      `structured_outputs` the schema is inferred for.
    structured_inputs: Inputs used when tracing the metadata function.
    preprocessing_fn: The user preprocessing function.
    base_temp_dir: Temporary directory stored on the graph context.
    tensor_replacement_map: Already-evaluated analyzer replacements, if any.

  Returns:
    A `DatasetMetadata` built from the inferred schema. Schema overrides are
    evaluated (`evaluate_schema_overrides=True`).
  """
  context = graph_context.TFGraphContext(
      module_to_export=tf.Module(),
      temp_dir=base_temp_dir,
      evaluated_replacements=tensor_replacement_map)
  traced_metadata_fn = schema_inference.get_traced_metadata_fn(
      preprocessing_fn=preprocessing_fn,
      structured_inputs=structured_inputs,
      tf_graph_context=context,
      evaluate_schema_overrides=True)
  inferred_schema = schema_inference.infer_feature_schema_v2(
      concrete_transform_fn.structured_outputs,
      traced_metadata_fn,
      evaluate_schema_overrides=True)
  return dataset_metadata.DatasetMetadata(schema=inferred_schema)