def _perform_combiner_packing_optimization(saved_model_future,
                                           cache_value_nodes):
  """Optimizes the graph by packing possible combine nodes."""
  # Inspect the graph to identify all the packable combines.
  inspect_combine_visitor = _InspectCombineVisitor()
  inspect_combine_traverser = nodes.Traverser(inspect_combine_visitor)
  _ = inspect_combine_traverser.visit_value_node(saved_model_future)
  if cache_value_nodes:
    for value_node in cache_value_nodes.values():
      _ = inspect_combine_traverser.visit_value_node(value_node)

  packable_combines = inspect_combine_visitor.packable_combines
  # Do not pack if we have only a single combine in the group.
  packable_combines = {
      label: group for label, group in packable_combines.items()
      if len(group) > 1
  }

  # Do another pass over the graph and pack the grouped combines.
  pack_combine_visitor = _PackCombineVisitor(packable_combines)
  pack_combine_traverser = nodes.Traverser(pack_combine_visitor)
  saved_model_future = pack_combine_traverser.visit_value_node(
      saved_model_future)

  # Replace cache nodes to point to the corresponding new nodes.
  if cache_value_nodes:
    cache_value_nodes = {
        key: pack_combine_traverser.visit_value_node(value_node)
        for key, value_node in cache_value_nodes.items()
    }
  return (saved_model_future, cache_value_nodes)
def get_analysis_dataset_keys(preprocessing_fn, specs, dataset_keys,
                              input_cache):
  """Computes the dataset keys that are required in order to perform analysis.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    specs: A dict of feature name to feature specification or tf.TypeSpecs.
    dataset_keys: A set of strings which are dataset keys, they uniquely
      identify these datasets across analysis runs.
    input_cache: A cache dictionary.

  Returns:
    A pair of:
      - A set of dataset keys that are required for analysis.
      - A boolean indicating whether or not a flattened version of the entire
        dataset is required. See the `flat_data` input to
        `AnalyzeDatasetWithCache`.
  """
  transform_fn_future, _ = _build_analysis_graph_for_inspection(
      preprocessing_fn, specs, dataset_keys, input_cache)

  required_dataset_keys_result = set()
  inspect_visitor = _InspectVisitor(required_dataset_keys_result)
  inspect_traverser = nodes.Traverser(inspect_visitor)
  _ = inspect_traverser.visit_value_node(transform_fn_future)

  # If None is present this means that a flattened version of the entire
  # dataset is required, therefore this will be returning all of the given
  # dataset_keys.
  flat_data_required = None in required_dataset_keys_result
  if flat_data_required:
    required_dataset_keys_result = dataset_keys
  return required_dataset_keys_result, flat_data_required
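# Hedged usage sketch (not part of the library): a minimal call to
# `get_analysis_dataset_keys` above. The feature names, dataset key strings,
# and the empty input cache are made up for illustration; this assumes
# `tensorflow_transform` is importable as `tft`.
import tensorflow as tf
import tensorflow_transform as tft


def _example_preprocessing_fn(inputs):
  # 'x' feeds an analyzer (z-score mean/variance); 'y' is only passed through.
  return {
      'x_scaled': tft.scale_to_z_score(inputs['x']),
      'y_copy': tf.identity(inputs['y']),
  }


_example_specs = {
    'x': tf.TensorSpec([None], tf.float32),
    'y': tf.TensorSpec([None], tf.string),
}

required_keys, flat_data_required = get_analysis_dataset_keys(
    _example_preprocessing_fn, _example_specs,
    dataset_keys={'span-2021-01-01', 'span-2021-01-02'}, input_cache={})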
def get_analyze_input_columns(preprocessing_fn, specs,
                              force_tf_compat_v1=False):
  """Return columns that are required inputs of `AnalyzeDataset`.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
      True, this can also be feature specifications.
    force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1 mode.
      Defaults to `False`.

  Returns:
    A list of columns that are required inputs of analyzers.
  """
  use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
  if not use_tf_compat_v1:
    assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs
  graph, structured_inputs, _ = (
      impl_helper.trace_preprocessing_function(
          preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1))

  tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
  visitor = _SourcedTensorsVisitor()
  for tensor_sink in tensor_sinks:
    nodes.Traverser(visitor).visit_value_node(tensor_sink.future)

  analyze_input_tensors = graph_tools.get_dependent_inputs(
      graph, structured_inputs, visitor.sourced_tensors)
  return list(analyze_input_tensors.keys())
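# A minimal follow-up sketch, reusing `_example_preprocessing_fn` and
# `_example_specs` from the sketch above: only 'x' feeds an analyzer, so one
# would expect the returned column list to contain just 'x'. This assumes TF2
# behavior is enabled, since the default force_tf_compat_v1=False requires
# tf.TypeSpec values in `specs`.
print(get_analyze_input_columns(_example_preprocessing_fn, _example_specs))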
def testTraverserComplexGraphMultipleCalls(self):
  a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
  b = nodes.apply_operation(_Constant, value='b', label='Constant[b]')
  c = nodes.apply_operation(_Constant, value='c', label='Constant[c]')
  b_copy, a_copy = nodes.apply_multi_output_operation(
      _Swap, a, b, label='Swap')
  b_a = nodes.apply_operation(_Concat, b_copy, a_copy, label='Concat[0]')
  b_a_c = nodes.apply_operation(_Concat, b_a, c, label='Concat[1]')

  mock_visitor = mock.MagicMock()
  mock_visitor.visit.side_effect = [('a',), ('b',), ('b', 'a'), ('ba',),
                                    ('c',), ('bac',)]

  traverser = nodes.Traverser(mock_visitor)
  traverser.visit_value_node(b_a)
  traverser.visit_value_node(b_a_c)

  mock_visitor.assert_has_calls([
      mock.call.visit(_Constant('a', 'Constant[a]'), ()),
      mock.call.validate_value('a'),
      mock.call.visit(_Constant('b', 'Constant[b]'), ()),
      mock.call.validate_value('b'),
      mock.call.visit(_Swap('Swap'), ('a', 'b')),
      mock.call.validate_value('b'),
      mock.call.validate_value('a'),
      mock.call.visit(_Concat('Concat[0]'), ('b', 'a')),
      mock.call.validate_value('ba'),
      mock.call.visit(_Constant('c', 'Constant[c]'), ()),
      mock.call.validate_value('c'),
      mock.call.visit(_Concat('Concat[1]'), ('ba', 'c')),
      mock.call.validate_value('bac'),
  ])
def get_analysis_dataset_keys(preprocessing_fn, specs, dataset_keys,
                              input_cache, force_tf_compat_v1):
  """Computes the dataset keys that are required in order to perform analysis.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
      True, this can also be feature specifications.
    dataset_keys: A set of strings which are dataset keys, they uniquely
      identify these datasets across analysis runs.
    input_cache: A cache dictionary.
    force_tf_compat_v1: If `True`, use Tensorflow in compat.v1 mode.

  Returns:
    A set of dataset keys that are required for analysis.
  """
  transform_fn_future, _ = _build_analysis_graph_for_inspection(
      preprocessing_fn, specs, dataset_keys, input_cache, force_tf_compat_v1)

  result = set()
  inspect_visitor = _InspectVisitor(result)
  inspect_traverser = nodes.Traverser(inspect_visitor)
  _ = inspect_traverser.visit_value_node(transform_fn_future)

  # If the flattened dataset key is present, a flattened version of the entire
  # dataset is required, therefore this will be returning all of the given
  # dataset_keys.
  if any(k.is_flattened_dataset_key() for k in result):
    result = dataset_keys
  return result
def testTraverserComplexGraph(self):
  a = nodes.apply_operation(_Constant, value='a')
  b = nodes.apply_operation(_Constant, value='b')
  c = nodes.apply_operation(_Constant, value='c')
  b_copy, a_copy = nodes.apply_multi_output_operation(_Swap, a, b)
  b_a = nodes.apply_operation(_Concat, b_copy, a_copy)
  b_a_c = nodes.apply_operation(_Concat, b_a, c)

  mock_visitor = mock.MagicMock()
  mock_visitor.visit.side_effect = [('a',), ('b',), ('b', 'a'), ('ba',),
                                    ('c',), ('bac',)]

  nodes.Traverser(mock_visitor).visit_value_node(b_a_c)

  mock_visitor.assert_has_calls([
      mock.call.visit(_Constant('a'), ()),
      mock.call.validate_value('a'),
      mock.call.visit(_Constant('b'), ()),
      mock.call.validate_value('b'),
      mock.call.visit(_Swap(), ('a', 'b')),
      mock.call.validate_value('b'),
      mock.call.validate_value('a'),
      mock.call.visit(_Concat(), ('b', 'a')),
      mock.call.validate_value('ba'),
      mock.call.visit(_Constant('c'), ()),
      mock.call.validate_value('c'),
      mock.call.visit(_Concat(), ('ba', 'c')),
      mock.call.validate_value('bac'),
  ])
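# A small concrete visitor sketch, to make explicit the protocol the mock
# visitors in these tests exercise: `Traverser` calls
# `visit(operation_def, input_values)` once per operation in dependency order
# and expects a tuple with one entry per declared output (see the
# "bad num outputs" test below), then calls `validate_value` on each returned
# entry. The `nodes.Visitor` base class and the `label`/`num_outputs`
# attributes of the operation definition are assumptions based on how they are
# used in these tests.
class _TracingVisitor(nodes.Visitor):
  """Records the order in which operations are visited."""

  def __init__(self):
    self.visited_labels = []

  def visit(self, operation_def, input_values):
    self.visited_labels.append(operation_def.label)
    # Return one placeholder value per declared output so the arity check
    # performed by Traverser passes.
    return tuple('{}/{}'.format(operation_def.label, i)
                 for i in range(operation_def.num_outputs))

  def validate_value(self, value):
    assert isinstance(value, str)


# Usage would mirror the tests above, e.g.:
#   visitor = _TracingVisitor()
#   nodes.Traverser(visitor).visit_value_node(some_value_node)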
def testTraverserBadNumOutputs(self):
  a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
  mock_visitor = mock.MagicMock()
  mock_visitor.visit.side_effect = [('a', 'b')]
  with self.assertRaisesRegexp(
      ValueError, 'has 1 outputs but visitor returned 2 values: '):
    nodes.Traverser(mock_visitor).visit_value_node(a)
def testTraverserSimpleGraph(self):
  a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
  mock_visitor = mock.MagicMock()
  mock_visitor.visit.side_effect = [('a',)]
  nodes.Traverser(mock_visitor).visit_value_node(a)
  mock_visitor.assert_has_calls([
      mock.call.visit(_Constant('a', 'Constant[a]'), ()),
      mock.call.validate_value('a'),
  ])
def testTraverserOutputsNotATuple(self):
  a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
  mock_visitor = mock.MagicMock()
  mock_visitor.visit.side_effect = ['not a tuple']
  with self.assertRaisesRegexp(
      ValueError, r'expected visitor to return a tuple, got'):
    nodes.Traverser(mock_visitor).visit_value_node(a)
def _perform_cache_optimization(saved_model_future, dataset_keys, cache_dict):
  """Performs cache optimization on the given graph."""
  cache_output_nodes = {}
  optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_dict,
                                      cache_output_nodes)
  optimize_traverser = nodes.Traverser(optimize_visitor)
  optimized = optimize_traverser.visit_value_node(
      saved_model_future).flattened_view

  if cache_dict is None:
    assert not cache_output_nodes
    cache_output_nodes = None

  return optimized, cache_output_nodes
def testTraverserCycle(self):
  a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
  x_0 = nodes.apply_operation(_Identity, a, label='Identity[0]')
  x_1 = nodes.apply_operation(_Identity, x_0, label='Identity[1]')
  x_2 = nodes.apply_operation(_Identity, x_1, label='Identity[2]')
  x_0.parent_operation._inputs = (x_2,)
  mock_visitor = mock.MagicMock()
  mock_visitor.visit.return_value = ('x',)
  with self.assertRaisesWithLiteralMatch(
      AssertionError,
      'Cycle detected: [Identity[2], Identity[1], Identity[0], Identity[2]]'):
    nodes.Traverser(mock_visitor).visit_value_node(x_2)
def testTraverserCycle(self):
  x = nodes.apply_operation(_Constant, value='x')
  x_0 = nodes.apply_operation(_Identity, x, name='x_0')
  x_1 = nodes.apply_operation(_Identity, x_0, name='x_1')
  x_2 = nodes.apply_operation(_Identity, x_1, name='x_2')
  x_0.parent_operation._inputs = (x_2,)
  mock_visitor = mock.MagicMock()
  mock_visitor.visit.return_value = ('x',)
  with self.assertRaisesWithLiteralMatch(
      AssertionError,
      'Cycle detected: [_Identity(name=\'x_2\'), _Identity(name=\'x_1\'), '
      '_Identity(name=\'x_0\'), _Identity(name=\'x_2\')]'):
    nodes.Traverser(mock_visitor).visit_value_node(x_2)
def _perform_cache_optimization(saved_model_future, dataset_keys,
                                tensor_keys_to_paths, cache_dict, num_phases):
  """Performs cache optimization on the given graph."""
  cache_output_nodes = {}
  optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_dict,
                                      tensor_keys_to_paths, cache_output_nodes,
                                      num_phases)
  optimize_traverser = nodes.Traverser(optimize_visitor)
  optimized = optimize_traverser.visit_value_node(
      saved_model_future).flattened_view

  if cache_dict is None:
    assert not cache_output_nodes
    cache_output_nodes = None

  return (optimized, cache_output_nodes,
          optimize_visitor.get_detached_sideeffect_leafs())
def get_analyze_input_columns(
    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                               Mapping[str, common_types.TensorType]],
    specs: Mapping[str, Union[common_types.FeatureSpecType, tf.TypeSpec]],
    force_tf_compat_v1: bool = False) -> List[str]:
  """Return columns that are required inputs of `AnalyzeDataset`.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
      True, this can also be feature specifications.
    force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1 mode.
      Defaults to `False`.

  Returns:
    A list of columns that are required inputs of analyzers.
  """
  use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
  if not use_tf_compat_v1:
    assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs
  graph, structured_inputs, structured_outputs = (
      impl_helper.trace_preprocessing_function(
          preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1))

  tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
  visitor = graph_tools.SourcedTensorsVisitor()
  for tensor_sink in tensor_sinks:
    nodes.Traverser(visitor).visit_value_node(tensor_sink.future)

  if use_tf_compat_v1:
    control_dependency_ops = []
  else:
    # If traced in TF2 as a tf.function, inputs that end up in control
    # dependencies are required for the function to execute. Return such inputs
    # as required inputs of analyzers as well.
    _, control_dependency_ops = (
        tf2_utils.strip_and_get_tensors_and_control_dependencies(
            tf.nest.flatten(structured_outputs, expand_composites=True)))

  output_tensors = list(
      itertools.chain(visitor.sourced_tensors, control_dependency_ops))
  analyze_input_tensors = graph_tools.get_dependent_inputs(
      graph, structured_inputs, output_tensors)
  return list(analyze_input_tensors.keys())
def get_analysis_dataset_keys(preprocessing_fn, feature_spec, dataset_keys,
                              input_cache):
  """Computes the dataset keys that are required in order to perform analysis.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    feature_spec: A dict of feature name to feature specification.
    dataset_keys: A set of strings which are dataset keys, they uniquely
      identify these datasets across analysis runs.
    input_cache: A cache dictionary.

  Returns:
    A pair of:
      - A set of dataset keys that are required for analysis.
      - A boolean indicating whether or not a flattened version of the entire
        dataset is required. See the `flat_data` input to
        `AnalyzeDatasetWithCache`.
  """
  with tf.Graph().as_default() as graph:
    with tf.compat.v1.name_scope('inputs'):
      input_signature = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
    # TODO(b/34288791): This needs to be exactly the same as in impl.py
    copied_inputs = impl_helper.copy_tensors(input_signature)

    output_signature = preprocessing_fn(copied_inputs)

  transform_fn_future, _ = build(
      graph,
      input_signature,
      output_signature,
      dataset_keys=dataset_keys,
      cache_dict=input_cache)

  required_dataset_keys_result = set()
  inspect_visitor = _InspectVisitor(required_dataset_keys_result)
  inspect_traverser = nodes.Traverser(inspect_visitor)
  _ = inspect_traverser.visit_value_node(transform_fn_future)

  # If None is present this means that a flattened version of the entire
  # dataset is required, therefore this will be returning all of the given
  # dataset_keys.
  flat_data_required = None in required_dataset_keys_result
  if flat_data_required:
    required_dataset_keys_result = dataset_keys
  return required_dataset_keys_result, flat_data_required
def get_analyze_input_columns(preprocessing_fn, feature_spec):
  """Return columns that are required inputs of `AnalyzeDataset`.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    feature_spec: A dict of feature name to feature specification.

  Returns:
    A list of columns that are required inputs of analyzers.
  """
  with tf.compat.v1.Graph().as_default() as graph:
    input_signature = impl_helper.feature_spec_as_batched_placeholders(
        feature_spec)
    _ = preprocessing_fn(input_signature.copy())

    tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
    visitor = _SourcedTensorsVisitor()
    for tensor_sink in tensor_sinks:
      nodes.Traverser(visitor).visit_value_node(tensor_sink.future)

    analyze_input_tensors = graph_tools.get_dependent_inputs(
        graph, input_signature, visitor.sourced_tensors)
    # Return a list (not a dict view) to match the docstring.
    return list(analyze_input_tensors.keys())
def get_analyzers_fingerprint(
    graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType]
) -> Mapping[str, AnalyzersFingerprint]:
  """Computes fingerprints for all analyzers in `graph`.

  Args:
    graph: a TF Graph.
    structured_inputs: a dict from keys to batches of placeholder graph
      tensors.

  Returns:
    A mapping from analyzer name to a set of paths that define its fingerprint.
  """
  result = {}
  tensor_sinks = graph.get_collection(analyzer_nodes.ALL_REPLACEMENTS)
  # The value for the keys in this dictionary are unused and can be arbitrary.
  sink_tensors_ready = {
      tf_utils.hashable_tensor_or_op(tensor_sink.tensor): False
      for tensor_sink in tensor_sinks
  }
  graph_analyzer = InitializableGraphAnalyzer(
      graph, structured_inputs, list(sink_tensors_ready.items()),
      describe_path_as_analyzer_cache_hash)
  for tensor_sink in tensor_sinks:
    # Retrieve tensors that are inputs to the analyzer's value node.
    visitor = SourcedTensorsVisitor()
    nodes.Traverser(visitor).visit_value_node(tensor_sink.future)
    source_keys = _retrieve_source_keys(visitor.sourced_tensors,
                                        structured_inputs)
    paths = set()
    for tensor in visitor.sourced_tensors:
      # Obtain fingerprint for each tensor that is an input to the value node.
      path = graph_analyzer.get_unique_path(tensor)
      if path is not None:
        paths.add(path)
    result[str(tensor_sink.tensor.name)] = AnalyzersFingerprint(
        source_keys, paths)
  return result
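# A hedged usage sketch for `get_analyzers_fingerprint`: trace a
# preprocessing_fn first (the same `impl_helper.trace_preprocessing_function`
# call used by `get_analyze_input_columns` above), reusing
# `_example_preprocessing_fn` and `_example_specs` from the earlier sketch.
# The `source_keys` field access on the returned `AnalyzersFingerprint` is an
# assumption based on how it is constructed above.
graph, structured_inputs, _ = impl_helper.trace_preprocessing_function(
    _example_preprocessing_fn, _example_specs, use_tf_compat_v1=False)
for analyzer_name, fingerprint in get_analyzers_fingerprint(
    graph, structured_inputs).items():
  print(analyzer_name, sorted(fingerprint.source_keys))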
def _perform_cache_optimization(saved_model_future, cache_location,
                                dataset_keys):
  optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_location)
  optimize_traverser = nodes.Traverser(optimize_visitor)
  return optimize_traverser.visit_value_node(saved_model_future).flattened_view
def expand(self, dataset):
  """Analyze the dataset.

  Args:
    dataset: A dataset.

  Returns:
    A TransformFn containing the deferred transform function.

  Raises:
    ValueError: If preprocessing_fn has no outputs.
  """
  flattened_pcoll, input_values_pcoll_dict, input_metadata = dataset
  input_schema = input_metadata.schema
  input_values_pcoll_dict = input_values_pcoll_dict or dict()

  analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())

  with tf.Graph().as_default() as graph:

    with tf.name_scope('inputs'):
      feature_spec = input_schema.as_feature_spec()
      input_signature = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
      # In order to avoid a bug where import_graph_def fails when the
      # input_map and return_elements of an imported graph are the same
      # (b/34288791), we avoid using the placeholder of an input column as an
      # output of a graph. We do this by applying tf.identity to all inputs of
      # the preprocessing_fn. Note this applies at the level of raw tensors.
      # TODO(b/34288791): Remove this workaround and use a shallow copy of
      # inputs instead. A shallow copy is needed in case
      # self._preprocessing_fn mutates its input.
      copied_inputs = impl_helper.copy_tensors(input_signature)

    output_signature = self._preprocessing_fn(copied_inputs)

  # At this point we check that the preprocessing_fn has at least one
  # output. This is because if we allowed the output of preprocessing_fn to
  # be empty, we wouldn't be able to determine how many instances to
  # "unbatch" the output into.
  if not output_signature:
    raise ValueError('The preprocessing function returned an empty dict')

  if graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    raise ValueError(
        'The preprocessing function contained trainable variables '
        '{}'.format(
            graph.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)))

  pipeline = flattened_pcoll.pipeline
  serialized_tf_config = common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
      pipeline.runner)
  extra_args = common.ConstructBeamPipelineVisitor.ExtraArgs(
      base_temp_dir=Context.create_base_temp_dir(),
      serialized_tf_config=serialized_tf_config,
      pipeline=pipeline,
      flat_pcollection=flattened_pcoll,
      pcollection_dict=input_values_pcoll_dict,
      graph=graph,
      input_signature=input_signature,
      input_schema=input_schema,
      cache_location=self._cache_location)

  transform_fn_future = analysis_graph_builder.build(
      graph, input_signature, output_signature,
      input_values_pcoll_dict.keys(), self._cache_location)
  transform_fn_pcoll = nodes.Traverser(
      common.ConstructBeamPipelineVisitor(extra_args)).visit_value_node(
          transform_fn_future)

  # Infer metadata. We take the inferred metadata and apply overrides that
  # refer to values of tensors in the graph. The override tensors must
  # be "constant" in that they don't depend on input data. The tensors can
  # depend on analyzer outputs though. This allows us to set metadata that
  # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
  # analyzer outputs stored in `transform_fn` to compute the metadata in a
  # deferred manner, once the analyzer outputs are known.
  metadata = dataset_metadata.DatasetMetadata(
      schema=schema_inference.infer_feature_schema(output_signature, graph))

  deferred_metadata = (
      transform_fn_pcoll
      | 'ComputeDeferredMetadata' >> beam.Map(_infer_metadata_from_saved_model))

  full_metadata = beam_metadata_io.BeamDatasetMetadata(
      metadata, deferred_metadata)

  _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

  return transform_fn_pcoll, full_metadata
def build(graph,
          input_signature,
          output_signature,
          dataset_keys=None,
          cache_dict=None):
  """Returns a list of `Phase`s describing how to execute the pipeline.

  The default graph is assumed to contain some `Analyzer`s which must be
  executed by doing a full pass over the dataset, and passing the inputs for
  that analyzer into some implementation, then taking the results and replacing
  the `Analyzer`s outputs with constants in the graph containing these results.

  The execution plan is described by a list of `Phase`s. Each phase contains
  a list of `Analyzer`s, which are the `Analyzer`s which are ready to run in
  that phase, together with a list of ops, which are the table initializers
  that are ready to run in that phase.

  An `Analyzer` or op is ready to run when all its dependencies in the graph
  have been computed. Thus if the graph is constructed by

    def preprocessing_fn(inputs):
      x = inputs['x']
      scaled_0 = x - tft.min(x)
      scaled_0_1 = scaled_0 / tft.max(scaled_0)

  then the first phase will contain the analyzer corresponding to the call to
  `min`, because `x` is an input and so is ready to compute in the first
  phase, while the second phase will contain the analyzer corresponding to the
  call to `max` since `scaled_0` depends on the result of the call to
  `tft.min` which is computed in the first phase.

  More generally, we define a level for each op and each `Analyzer` by walking
  the graph, assigning to each operation the max level of its inputs, to each
  `Tensor` the level of its operation, unless it's the output of an `Analyzer`
  in which case we assign the level of its `Analyzer` plus one.

  Args:
    graph: A `tf.Graph`.
    input_signature: A dict whose keys are strings and values are `Tensor`s or
      `SparseTensor`s.
    output_signature: A dict whose keys are strings and values are `Tensor`s
      or `SparseTensor`s.
    dataset_keys: (Optional) A set of strings which are dataset keys, they
      uniquely identify these datasets across analysis runs.
    cache_dict: (Optional) A cache dictionary.

  Returns:
    A tuple of:
      * A `ValueNode` representing the deferred (cache-optimized)
        `CreateSavedModel` operation.
      * A dictionary of output cache `ValueNode`s.
      * Detached side-effect leaf nodes from cache optimization.

  Raises:
    ValueError: if the graph cannot be analyzed.
  """
  tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
  graph.clear_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
  phase = 0
  tensor_bindings = []
  sink_tensors_ready = {
      tf_utils.hashable_tensor_or_op(tensor_sink.tensor): False
      for tensor_sink in tensor_sinks
  }
  translate_visitor = _TranslateVisitor()
  translate_traverser = nodes.Traverser(translate_visitor)
  analyzers_input_signature = {}
  graph_analyzer = None

  extracted_input_node = nodes.apply_operation(
      beam_nodes.ExtractInputForSavedModel,
      dataset_key=analyzer_cache._make_flattened_dataset_key(),  # pylint: disable=protected-access
      label='ExtractInputForSavedModel[FlattenedDataset]')

  while not all(sink_tensors_ready.values()):
    infix = 'Phase{}'.format(phase)
    # Determine which table init ops are ready to run in this phase.
    # Determine which keys of pending_tensor_replacements are ready to run
    # in this phase, based on whether their dependencies are ready.
    graph_analyzer = graph_tools.InitializableGraphAnalyzer(
        graph, input_signature, list(sink_tensors_ready.items()),
        graph_tools.describe_path_as_analyzer_cache_hash)
    ready_traverser = nodes.Traverser(_ReadyVisitor(graph_analyzer))

    # Now create and apply a SavedModel with all tensors in tensor_bindings
    # bound, which outputs all the tensors in the required tensor tuples.
    intermediate_output_signature = collections.OrderedDict()
    saved_model_future = nodes.apply_operation(
        beam_nodes.CreateSavedModel,
        *tensor_bindings,
        table_initializers=tuple(graph_analyzer.ready_table_initializers),
        output_signature=intermediate_output_signature,
        label='CreateSavedModelForAnalyzerInputs[{}]'.format(infix))

    extracted_values_dict = nodes.apply_operation(
        beam_nodes.ApplySavedModel,
        saved_model_future,
        extracted_input_node,
        phase=phase,
        label='ApplySavedModel[{}]'.format(infix))

    translate_visitor.phase = phase
    translate_visitor.intermediate_output_signature = (
        intermediate_output_signature)
    translate_visitor.extracted_values_dict = extracted_values_dict
    for tensor, value_node, is_asset_filepath in tensor_sinks:
      hashable_tensor = tf_utils.hashable_tensor_or_op(tensor)
      # Don't compute a binding/sink/replacement that's already been computed.
      if sink_tensors_ready[hashable_tensor]:
        continue

      if not ready_traverser.visit_value_node(value_node):
        continue

      translated_value_node = translate_traverser.visit_value_node(value_node)

      name = _tensor_name(tensor)
      tensor_bindings.append(
          nodes.apply_operation(
              beam_nodes.CreateTensorBinding,
              translated_value_node,
              tensor_name=str(tensor.name),
              dtype_enum=tensor.dtype.as_datatype_enum,
              is_asset_filepath=is_asset_filepath,
              label=analyzer_nodes.sanitize_label(
                  'CreateTensorBinding[{}]'.format(name))))
      sink_tensors_ready[hashable_tensor] = True

    analyzers_input_signature.update(intermediate_output_signature)
    phase += 1

  # We need to make sure that the representation of this output_signature is
  # deterministic.
  output_signature = collections.OrderedDict(
      sorted(output_signature.items(), key=lambda t: t[0]))

  # TODO(KesterTong): check all table initializers are ready, check all output
  # tensors are ready.
  saved_model_future = nodes.apply_operation(
      beam_nodes.CreateSavedModel,
      *tensor_bindings,
      table_initializers=tuple(
          graph.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)),
      output_signature=output_signature,
      label='CreateSavedModel')

  tensor_keys_to_paths = {
      tensor_key:
      graph_analyzer.get_unique_path(analyzers_input_signature[tensor_key])
      for tensor_key in analyzers_input_signature
  }
  (optimized_saved_model_future, output_cache_value_nodes,
   detached_sideeffect_leafs) = _perform_cache_optimization(
       saved_model_future, dataset_keys, tensor_keys_to_paths, cache_dict,
       phase)

  (optimized_saved_model_future, output_cache_value_nodes) = (
      combiner_packing_util.perform_combiner_packing_optimization(
          optimized_saved_model_future, output_cache_value_nodes, phase))

  global _ANALYSIS_GRAPH
  _ANALYSIS_GRAPH = optimized_saved_model_future
  return (optimized_saved_model_future, output_cache_value_nodes,
          detached_sideeffect_leafs)
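# A self-contained toy sketch of the level-assignment rule described in the
# docstring of `build` above: a tensor gets the max level of its inputs,
# except an analyzer output, which gets its analyzer's level plus one. The
# dict-based graph encoding below is invented purely for illustration; it is
# not the representation that `build` operates on.
def _toy_tensor_levels(graph_spec):
  """graph_spec: {tensor_name: (input_tensor_names, is_analyzer_output)}."""
  levels = {}

  def level_of(name):
    if name not in levels:
      inputs, is_analyzer_output = graph_spec[name]
      max_input_level = max((level_of(i) for i in inputs), default=0)
      levels[name] = (max_input_level + 1
                      if is_analyzer_output else max_input_level)
    return levels[name]

  for name in graph_spec:
    level_of(name)
  return levels


# The min/max example from the docstring: `tft.min`'s output has level 1, so
# the min analyzer runs in the first phase (its input 'x' is at level 0);
# `tft.max`'s output has level 2, so the max analyzer runs in the second phase.
print(_toy_tensor_levels({
    'x': ([], False),
    'min(x)': (['x'], True),
    'scaled_0': (['x', 'min(x)'], False),
    'max(scaled_0)': (['scaled_0'], True),
    'scaled_0_1': (['scaled_0', 'max(scaled_0)'], False),
}))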
def perform_combiner_packing_optimization(saved_model_future,
                                          cache_value_nodes, num_phases):
  """Optimizes the graph by packing possible combine nodes."""
  # Inspect the graph to identify all the packable combines.
  inspect_acc_combine_visitor = _InspectAccumulateCombineVisitor()
  inspect_acc_combine_traverser = nodes.Traverser(inspect_acc_combine_visitor)
  _ = inspect_acc_combine_traverser.visit_value_node(saved_model_future)

  packable_combines = inspect_acc_combine_visitor.packable_combines
  # Do not pack if we have only a single combine in the group.
  packable_combines = {
      label: group for label, group in packable_combines.items()
      if len(group) > 1
  }

  pack_acc_combine_visitor = _PackAccumulateCombineVisitor(packable_combines)
  pack_acc_combine_traverser = nodes.Traverser(pack_acc_combine_visitor)
  saved_model_future = pack_acc_combine_traverser.visit_value_node(
      saved_model_future)

  # Replace cache nodes to point to the corresponding new nodes.
  cache_value_nodes = _update_cache_value_node_references(
      cache_value_nodes, pack_acc_combine_traverser)

  # TODO(b/134414978): Consider also packing the merges even when we have
  # multiple phases.
  if num_phases > 1:
    return (saved_model_future, cache_value_nodes)

  # Identify the merge combines that can be packed together.
  inspect_merge_combine_visitor = _InspectMergeCombineVisitor()
  inspect_merge_combine_traverser = nodes.Traverser(
      inspect_merge_combine_visitor)
  _ = inspect_merge_combine_traverser.visit_value_node(saved_model_future)

  # Only pack if we have more than one merge combine.
  if len(inspect_merge_combine_visitor.packable_combine_extract_outputs) <= 1:
    return (saved_model_future, cache_value_nodes)

  # Add flatten and packed merge nodes.
  pack_merge_combine_visitor = _PackMergeCombineVisitor(
      packable_combine_extract_outputs=(
          inspect_merge_combine_visitor.packable_combine_extract_outputs))
  pack_merge_combine_traverser = nodes.Traverser(pack_merge_combine_visitor)
  saved_model_future = pack_merge_combine_traverser.visit_value_node(
      saved_model_future)
  # Replace cache nodes to point to the corresponding new nodes.
  cache_value_nodes = _update_cache_value_node_references(
      cache_value_nodes, pack_merge_combine_traverser)

  # Remove redundant flatten and packed merge nodes.
  remove_redundant_visitor = _RemoveRedundantPackedMergeCombineVisitor(
      final_packed_merge_combine_label=(
          pack_merge_combine_visitor.final_packed_merge_combine_label))
  remove_redundant_traverser = nodes.Traverser(remove_redundant_visitor)
  saved_model_future = remove_redundant_traverser.visit_value_node(
      saved_model_future)
  # Replace cache nodes to point to the corresponding new nodes.
  cache_value_nodes = _update_cache_value_node_references(
      cache_value_nodes, remove_redundant_traverser)

  return (saved_model_future, cache_value_nodes)
def build(graph, input_signature, output_signature):
  """Returns a list of `Phase`s describing how to execute the pipeline.

  The default graph is assumed to contain some `Analyzer`s which must be
  executed by doing a full pass over the dataset, and passing the inputs for
  that analyzer into some implementation, then taking the results and replacing
  the `Analyzer`s outputs with constants in the graph containing these results.

  The execution plan is described by a list of `Phase`s. Each phase contains
  a list of `Analyzer`s, which are the `Analyzer`s which are ready to run in
  that phase, together with a list of ops, which are the table initializers
  that are ready to run in that phase.

  An `Analyzer` or op is ready to run when all its dependencies in the graph
  have been computed. Thus if the graph is constructed by

    def preprocessing_fn(inputs):
      x = inputs['x']
      scaled_0 = x - tft.min(x)
      scaled_0_1 = scaled_0 / tft.max(scaled_0)

  then the first phase will contain the analyzer corresponding to the call to
  `min`, because `x` is an input and so is ready to compute in the first
  phase, while the second phase will contain the analyzer corresponding to the
  call to `max` since `scaled_0` depends on the result of the call to
  `tft.min` which is computed in the first phase.

  More generally, we define a level for each op and each `Analyzer` by walking
  the graph, assigning to each operation the max level of its inputs, to each
  `Tensor` the level of its operation, unless it's the output of an `Analyzer`
  in which case we assign the level of its `Analyzer` plus one.

  Args:
    graph: A `tf.Graph`.
    input_signature: A dict whose keys are strings and values are `Tensor`s or
      `SparseTensor`s.
    output_signature: A dict whose keys are strings and values are `Tensor`s
      or `SparseTensor`s.

  Returns:
    A `ValueNode` representing the deferred `CreateSavedModel` operation.

  Raises:
    ValueError: if the graph cannot be analyzed.
  """
  tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
  graph.clear_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
  phase = 0
  tensor_bindings = []
  sink_tensors_ready = {
      tensor_sink.tensor: False for tensor_sink in tensor_sinks
  }
  translate_visitor = _TranslateVisitor()
  translate_traverser = nodes.Traverser(translate_visitor)

  while not all(sink_tensors_ready.values()):
    # Determine which table init ops are ready to run in this phase.
    # Determine which keys of pending_tensor_replacements are ready to run
    # in this phase, based on whether their dependencies are ready.
    graph_analyzer = graph_tools.InitializableGraphAnalyzer(
        graph, input_signature.values(), sink_tensors_ready)
    ready_traverser = nodes.Traverser(_ReadyVisitor(graph_analyzer))

    # Now create and apply a SavedModel with all tensors in tensor_bindings
    # bound, which outputs all the tensors in the required tensor tuples.
    intermediate_output_signature = collections.OrderedDict()
    saved_model_future = nodes.apply_operation(
        beam_nodes.CreateSavedModel,
        *tensor_bindings,
        table_initializers=tuple(graph_analyzer.ready_table_initializers),
        output_signature=intermediate_output_signature,
        label='CreateSavedModelForAnalyzerInputs[{}]'.format(phase))
    extracted_values_dict = nodes.apply_operation(
        beam_nodes.ApplySavedModel,
        saved_model_future,
        phase=phase,
        label='ApplySavedModel[{}]'.format(phase))

    translate_visitor.phase = phase
    translate_visitor.intermediate_output_signature = (
        intermediate_output_signature)
    translate_visitor.extracted_values_dict = extracted_values_dict
    for tensor, value_node, is_asset_filepath in tensor_sinks:
      # Don't compute a binding/sink/replacement that's already been computed.
      if sink_tensors_ready[tensor]:
        continue

      if not ready_traverser.visit_value_node(value_node):
        continue

      translated_value_node = translate_traverser.visit_value_node(value_node)

      name = _tensor_name(tensor)
      tensor_bindings.append(
          nodes.apply_operation(
              beam_nodes.CreateTensorBinding,
              translated_value_node,
              tensor=str(tensor.name),
              is_asset_filepath=is_asset_filepath,
              label='CreateTensorBinding[{}]'.format(name)))
      sink_tensors_ready[tensor] = True

    phase += 1

  # We need to make sure that the representation of this output_signature is
  # deterministic.
  output_signature = collections.OrderedDict(
      sorted(output_signature.items(), key=lambda t: t[0]))

  return nodes.apply_operation(
      beam_nodes.CreateSavedModel,
      *tensor_bindings,
      table_initializers=tuple(
          graph.get_collection(tf.GraphKeys.TABLE_INITIALIZERS)),
      output_signature=output_signature,
      label='CreateSavedModel')
def expand(self, dataset):
  """Analyze the dataset.

  Args:
    dataset: A dataset.

  Returns:
    A TransformFn containing the deferred transform function.

  Raises:
    ValueError: If preprocessing_fn has no outputs.
  """
  (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
   input_metadata) = dataset
  if self._use_tfxio:
    input_schema = None
    input_tensor_adapter_config = input_metadata
  else:
    input_schema = input_metadata.schema
    input_tensor_adapter_config = None

  input_values_pcoll_dict = input_values_pcoll_dict or dict()

  with tf.compat.v1.Graph().as_default() as graph:

    with tf.compat.v1.name_scope('inputs'):
      if self._use_tfxio:
        specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs()
      else:
        specs = schema_utils.schema_as_feature_spec(input_schema).feature_spec
      input_signature = impl_helper.batched_placeholders_from_specs(specs)
      # In order to avoid a bug where import_graph_def fails when the
      # input_map and return_elements of an imported graph are the same
      # (b/34288791), we avoid using the placeholder of an input column as an
      # output of a graph. We do this by applying tf.identity to all inputs of
      # the preprocessing_fn. Note this applies at the level of raw tensors.
      # TODO(b/34288791): Remove this workaround and use a shallow copy of
      # inputs instead. A shallow copy is needed in case
      # self._preprocessing_fn mutates its input.
      copied_inputs = impl_helper.copy_tensors(input_signature)

    output_signature = self._preprocessing_fn(copied_inputs)

  # At this point we check that the preprocessing_fn has at least one
  # output. This is because if we allowed the output of preprocessing_fn to
  # be empty, we wouldn't be able to determine how many instances to
  # "unbatch" the output into.
  if not output_signature:
    raise ValueError('The preprocessing function returned an empty dict')

  if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
    raise ValueError(
        'The preprocessing function contained trainable variables '
        '{}'.format(
            graph.get_collection_ref(
                tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

  pipeline = self.pipeline or (flattened_pcoll or next(
      v for v in input_values_pcoll_dict.values() if v is not None)).pipeline

  # Add a stage that inspects graph collections for API use counts and logs
  # them as a beam metric.
  _ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI(graph))

  tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
      type(pipeline.runner))
  extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs(
      base_temp_dir=Context.create_base_temp_dir(),
      tf_config=tf_config,
      pipeline=pipeline,
      flat_pcollection=flattened_pcoll,
      pcollection_dict=input_values_pcoll_dict,
      graph=graph,
      input_signature=input_signature,
      input_schema=input_schema,
      input_tensor_adapter_config=input_tensor_adapter_config,
      use_tfxio=self._use_tfxio,
      cache_pcoll_dict=dataset_cache_dict)

  transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
      graph,
      input_signature,
      output_signature,
      input_values_pcoll_dict.keys(),
      cache_dict=dataset_cache_dict)
  traverser = nodes.Traverser(
      beam_common.ConstructBeamPipelineVisitor(extra_args))
  transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

  if cache_value_nodes is not None:
    output_cache_pcoll_dict = {}
    for (dataset_key,
         cache_key), value_node in six.iteritems(cache_value_nodes):
      if dataset_key not in output_cache_pcoll_dict:
        output_cache_pcoll_dict[dataset_key] = {}
      output_cache_pcoll_dict[dataset_key][cache_key] = (
          traverser.visit_value_node(value_node))
  else:
    output_cache_pcoll_dict = None

  # Infer metadata. We take the inferred metadata and apply overrides that
  # refer to values of tensors in the graph. The override tensors must
  # be "constant" in that they don't depend on input data. The tensors can
  # depend on analyzer outputs though. This allows us to set metadata that
  # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
  # analyzer outputs stored in `transform_fn` to compute the metadata in a
  # deferred manner, once the analyzer outputs are known.
  metadata = dataset_metadata.DatasetMetadata(
      schema=schema_inference.infer_feature_schema(output_signature, graph))

  deferred_metadata = (
      transform_fn_pcoll
      | 'ComputeDeferredMetadata' >> beam.Map(_infer_metadata_from_saved_model))

  full_metadata = beam_metadata_io.BeamDatasetMetadata(
      metadata, deferred_metadata)

  _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

  return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict