def test_optimize_traversal(self, feature_spec, preprocessing_fn,
                            dataset_input_cache_dict, expected_dot_graph_str):
  span_0_key, span_1_key = 'span-0', 'span-1'
  if dataset_input_cache_dict is not None:
    cache = {span_0_key: dataset_input_cache_dict}
  else:
    cache = {}

  with tf.compat.v1.name_scope('inputs'):
    input_signature = impl_helper.feature_spec_as_batched_placeholders(
        feature_spec)
  output_signature = preprocessing_fn(input_signature)
  transform_fn_future, cache_output_dict = analysis_graph_builder.build(
      tf.compat.v1.get_default_graph(), input_signature, output_signature,
      {span_0_key, span_1_key}, cache)

  leaf_nodes = [transform_fn_future] + sorted(
      cache_output_dict.values(), key=str)
  dot_string = nodes.get_dot_graph(leaf_nodes).to_string()
  self.WriteRenderedDotFile(dot_string)

  self.assertSameElements(
      dot_string.split('\n'),
      expected_dot_graph_str.split('\n'),
      msg='Result dot graph is:\n{}'.format(dot_string))
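
# For orientation, a minimal, hypothetical parameterization that a test like
# the one above could receive. The feature name and analyzer choice below are
# illustrative assumptions, not the actual test parameters.
import tensorflow as tf
import tensorflow_transform as tft

feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}

def preprocessing_fn(inputs):
  # tft.mean is a full-pass analyzer, so it contributes analyzer nodes to the
  # traversal graph whose dot rendering the test asserts on.
  return {'x_centered': inputs['x'] - tft.mean(inputs['x'])}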
def test_perform_combiner_packing_optimization(
    self, feature_spec, preprocessing_fn, num_phases,
    expected_dot_graph_str_before_packing,
    expected_dot_graph_str_after_packing):
  graph, structured_inputs, structured_outputs = (
      impl_helper.trace_preprocessing_function(
          preprocessing_fn, feature_spec, use_tf_compat_v1=True))

  def _side_effect_fn(saved_model_future, cache_value_nodes,
                      unused_num_phases):
    return (saved_model_future, cache_value_nodes)

  # Mock out the packing optimization inside build() so that it is a no-op;
  # it is then applied explicitly below, letting the test compare the graphs
  # before and after packing.
  with mock.patch.object(
      combiner_packing_util,
      'perform_combiner_packing_optimization',
      side_effect=_side_effect_fn):
    transform_fn_future_before, unused_cache = analysis_graph_builder.build(
        graph, structured_inputs, structured_outputs)
  transform_fn_future_after, unused_cache = (
      combiner_packing_util.perform_combiner_packing_optimization(
          transform_fn_future_before, unused_cache, num_phases))

  dot_string_before = nodes.get_dot_graph(
      [transform_fn_future_before]).to_string()
  self.assertMultiLineEqual(
      msg='Result dot graph is:\n{}'.format(dot_string_before),
      first=dot_string_before,
      second=expected_dot_graph_str_before_packing)

  dot_string_after = nodes.get_dot_graph(
      [transform_fn_future_after]).to_string()
  self.WriteRenderedDotFile(dot_string_after)
  self.assertMultiLineEqual(
      msg='Result dot graph is:\n{}'.format(dot_string_after),
      first=dot_string_after,
      second=expected_dot_graph_str_after_packing)
def test_perform_combiner_packing_optimization(
    self, feature_spec, preprocessing_fn, num_phases,
    expected_dot_graph_str_before_packing,
    expected_dot_graph_str_after_packing):
  with tf.compat.v1.Graph().as_default() as graph:
    with tf.compat.v1.name_scope('inputs'):
      input_signature = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
    output_signature = preprocessing_fn(input_signature)

  def _side_effect_fn(saved_model_future, cache_value_nodes,
                      unused_num_phases):
    return (saved_model_future, cache_value_nodes)

  with mock.patch.object(
      combiner_packing_util,
      'perform_combiner_packing_optimization',
      side_effect=_side_effect_fn):
    transform_fn_future_before, unused_cache = analysis_graph_builder.build(
        graph, input_signature, output_signature)
  transform_fn_future_after, unused_cache = (
      combiner_packing_util.perform_combiner_packing_optimization(
          transform_fn_future_before, unused_cache, num_phases))

  dot_string_before = nodes.get_dot_graph(
      [transform_fn_future_before]).to_string()
  self.assertMultiLineEqual(
      msg='Result dot graph is:\n{}'.format(dot_string_before),
      first=dot_string_before,
      second=expected_dot_graph_str_before_packing)

  dot_string_after = nodes.get_dot_graph(
      [transform_fn_future_after]).to_string()
  self.WriteRenderedDotFile(dot_string_after)
  self.assertMultiLineEqual(
      msg='Result dot graph is:\n{}'.format(dot_string_after),
      first=dot_string_after,
      second=expected_dot_graph_str_after_packing)
def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str):
  graph, structured_inputs, structured_outputs = (
      impl_helper.trace_preprocessing_function(
          preprocessing_fn, feature_spec, use_tf_compat_v1=True))
  transform_fn_future, unused_cache = analysis_graph_builder.build(
      graph, structured_inputs, structured_outputs)
  dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
  self.WriteRenderedDotFile(dot_string)
  self.assertMultiLineEqual(
      msg='Result dot graph is:\n{}'.format(dot_string),
      first=dot_string,
      second=expected_dot_graph_str)
def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str):
  with tf.name_scope('inputs'):
    input_signature = impl_helper.feature_spec_as_batched_placeholders(
        feature_spec)
  output_signature = preprocessing_fn(input_signature)
  transform_fn_future = analysis_graph_builder.build(
      tf.get_default_graph(), input_signature, output_signature)
  dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
  self.WriteRenderedDotFile(dot_string)
  self.assertMultiLineEqual(
      msg='Result dot graph is:\n{}'.format(dot_string),
      first=dot_string,
      second=expected_dot_graph_str)
def test_optimize_traversal(self, feature_spec, preprocessing_fn,
                            write_cache_fn, expected_dot_graph_str):
  cache_location = self._make_cache_location()
  span_0_key, span_1_key = 'span-0', 'span-1'
  if write_cache_fn is not None:
    write_cache_fn(cache_location.input_cache_dir, [span_0_key, span_1_key])

  with tf.name_scope('inputs'):
    input_signature = impl_helper.feature_spec_as_batched_placeholders(
        feature_spec)
  output_signature = preprocessing_fn(input_signature)
  transform_fn_future = analysis_graph_builder.build(
      tf.get_default_graph(), input_signature, output_signature,
      {span_0_key, span_1_key}, cache_location)

  dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
  self.WriteRenderedDotFile(dot_string)

  self.assertSameElements(
      dot_string.split('\n'),
      expected_dot_graph_str.split('\n'),
      msg='Result dot graph is:\n{}'.format(dot_string))
def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str,
               expected_dot_graph_str_tf2, use_tf_compat_v1):
  if not use_tf_compat_v1:
    test_case.skip_if_not_tf2('Tensorflow 2.x required')
  specs = (
      feature_spec if use_tf_compat_v1 else
      impl_helper.get_type_specs_from_feature_specs(feature_spec))
  graph, structured_inputs, structured_outputs = (
      impl_helper.trace_preprocessing_function(
          preprocessing_fn,
          specs,
          use_tf_compat_v1=use_tf_compat_v1,
          base_temp_dir=os.path.join(self.get_temp_dir(),
                                     self._testMethodName)))
  transform_fn_future, unused_cache = analysis_graph_builder.build(
      graph, structured_inputs, structured_outputs)
  dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
  self.WriteRenderedDotFile(dot_string)
  self.assertMultiLineEqual(
      msg='Result dot graph is:\n{}'.format(dot_string),
      first=dot_string,
      second=(expected_dot_graph_str
              if use_tf_compat_v1 else expected_dot_graph_str_tf2))
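
# A sketch of the two kinds of `specs` the expression above selects between,
# assuming the usual batched-tensor convention. The mapping shown is only
# conceptual; the real conversion is performed by
# impl_helper.get_type_specs_from_feature_specs, and the feature 'x' is a
# hypothetical example.
import tensorflow as tf

# compat-v1 path: feature specs are passed through unchanged.
feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}

# native TF2 path: tracing wants type specs instead; a scalar FixedLenFeature
# corresponds to a rank-1 TensorSpec with an unknown batch dimension.
type_specs = {'x': tf.TensorSpec(shape=[None], dtype=tf.float32)}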
def expand(self, dataset):
  """Analyze the dataset.

  Args:
    dataset: A dataset.

  Returns:
    A TransformFn containing the deferred transform function.

  Raises:
    ValueError: If preprocessing_fn has no outputs.
  """
  flattened_pcoll, input_values_pcoll_dict, input_metadata = dataset
  input_schema = input_metadata.schema
  input_values_pcoll_dict = input_values_pcoll_dict or dict()
  analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())

  with tf.Graph().as_default() as graph:

    with tf.name_scope('inputs'):
      feature_spec = input_schema.as_feature_spec()
      input_signature = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
      # In order to avoid a bug where import_graph_def fails when the
      # input_map and return_elements of an imported graph are the same
      # (b/34288791), we avoid using the placeholder of an input column as an
      # output of a graph. We do this by applying tf.identity to all inputs of
      # the preprocessing_fn. Note this applies at the level of raw tensors.
      # TODO(b/34288791): Remove this workaround and use a shallow copy of
      # inputs instead. A shallow copy is needed in case
      # self._preprocessing_fn mutates its input.
      copied_inputs = impl_helper.copy_tensors(input_signature)

    output_signature = self._preprocessing_fn(copied_inputs)

  # At this point we check that the preprocessing_fn has at least one
  # output. This is because if we allowed the output of preprocessing_fn to
  # be empty, we wouldn't be able to determine how many instances to
  # "unbatch" the output into.
  if not output_signature:
    raise ValueError('The preprocessing function returned an empty dict')

  if graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    raise ValueError(
        'The preprocessing function contained trainable variables '
        '{}'.format(
            graph.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)))

  pipeline = flattened_pcoll.pipeline
  serialized_tf_config = common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
      pipeline.runner)
  extra_args = common.ConstructBeamPipelineVisitor.ExtraArgs(
      base_temp_dir=Context.create_base_temp_dir(),
      serialized_tf_config=serialized_tf_config,
      pipeline=pipeline,
      flat_pcollection=flattened_pcoll,
      pcollection_dict=input_values_pcoll_dict,
      graph=graph,
      input_signature=input_signature,
      input_schema=input_schema,
      cache_location=self._cache_location)

  transform_fn_future = analysis_graph_builder.build(
      graph, input_signature, output_signature,
      input_values_pcoll_dict.keys(), self._cache_location)
  transform_fn_pcoll = nodes.Traverser(
      common.ConstructBeamPipelineVisitor(extra_args)).visit_value_node(
          transform_fn_future)

  # Infer metadata. We take the inferred metadata and apply overrides that
  # refer to values of tensors in the graph. The override tensors must
  # be "constant" in that they don't depend on input data. The tensors can
  # depend on analyzer outputs though. This allows us to set metadata that
  # depends on analyzer outputs. _infer_metadata_from_saved_model will use
  # the analyzer outputs stored in `transform_fn` to compute the metadata in
  # a deferred manner, once the analyzer outputs are known.
  metadata = dataset_metadata.DatasetMetadata(
      schema=schema_inference.infer_feature_schema(output_signature, graph))

  deferred_metadata = (
      transform_fn_pcoll
      | 'ComputeDeferredMetadata' >> beam.Map(
          _infer_metadata_from_saved_model))

  full_metadata = beam_metadata_io.BeamDatasetMetadata(
      metadata, deferred_metadata)

  _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

  return transform_fn_pcoll, full_metadata
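
# The copy_tensors workaround described in the comment above boils down to
# routing every input through tf.identity so that no placeholder is itself
# used as a graph output. A minimal sketch of the idea; the real
# impl_helper.copy_tensors also handles composite tensors such as
# tf.SparseTensor, which this sketch omits.
import tensorflow as tf

def copy_tensors_sketch(tensors):
  """Returns a new dict with each dense tensor routed through tf.identity."""
  return {name: tf.identity(tensor) for name, tensor in tensors.items()}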
def expand(self, dataset):
  """Analyze the dataset.

  Args:
    dataset: A dataset.

  Returns:
    A TransformFn containing the deferred transform function.

  Raises:
    ValueError: If preprocessing_fn has no outputs.
  """
  (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
   input_metadata) = dataset
  if self._use_tfxio:
    input_schema = None
    input_tensor_adapter_config = input_metadata
  else:
    input_schema = input_metadata.schema
    input_tensor_adapter_config = None
  input_values_pcoll_dict = input_values_pcoll_dict or dict()

  with tf.compat.v1.Graph().as_default() as graph:

    with tf.compat.v1.name_scope('inputs'):
      if self._use_tfxio:
        specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs()
      else:
        specs = schema_utils.schema_as_feature_spec(input_schema).feature_spec
      input_signature = impl_helper.batched_placeholders_from_specs(specs)
      # In order to avoid a bug where import_graph_def fails when the
      # input_map and return_elements of an imported graph are the same
      # (b/34288791), we avoid using the placeholder of an input column as an
      # output of a graph. We do this by applying tf.identity to all inputs of
      # the preprocessing_fn. Note this applies at the level of raw tensors.
      # TODO(b/34288791): Remove this workaround and use a shallow copy of
      # inputs instead. A shallow copy is needed in case
      # self._preprocessing_fn mutates its input.
      copied_inputs = impl_helper.copy_tensors(input_signature)

    output_signature = self._preprocessing_fn(copied_inputs)

  # At this point we check that the preprocessing_fn has at least one
  # output. This is because if we allowed the output of preprocessing_fn to
  # be empty, we wouldn't be able to determine how many instances to
  # "unbatch" the output into.
  if not output_signature:
    raise ValueError('The preprocessing function returned an empty dict')

  if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
    raise ValueError(
        'The preprocessing function contained trainable variables '
        '{}'.format(
            graph.get_collection_ref(
                tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

  pipeline = self.pipeline or (flattened_pcoll or next(
      v for v in input_values_pcoll_dict.values() if v is not None)).pipeline

  # Add a stage that inspects graph collections for API use counts and logs
  # them as a beam metric.
  _ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI(graph))

  tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
      type(pipeline.runner))
  extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs(
      base_temp_dir=Context.create_base_temp_dir(),
      tf_config=tf_config,
      pipeline=pipeline,
      flat_pcollection=flattened_pcoll,
      pcollection_dict=input_values_pcoll_dict,
      graph=graph,
      input_signature=input_signature,
      input_schema=input_schema,
      input_tensor_adapter_config=input_tensor_adapter_config,
      use_tfxio=self._use_tfxio,
      cache_pcoll_dict=dataset_cache_dict)

  transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
      graph,
      input_signature,
      output_signature,
      input_values_pcoll_dict.keys(),
      cache_dict=dataset_cache_dict)
  traverser = nodes.Traverser(
      beam_common.ConstructBeamPipelineVisitor(extra_args))
  transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

  if cache_value_nodes is not None:
    output_cache_pcoll_dict = {}
    for (dataset_key,
         cache_key), value_node in six.iteritems(cache_value_nodes):
      if dataset_key not in output_cache_pcoll_dict:
        output_cache_pcoll_dict[dataset_key] = {}
      output_cache_pcoll_dict[dataset_key][cache_key] = (
          traverser.visit_value_node(value_node))
  else:
    output_cache_pcoll_dict = None

  # Infer metadata. We take the inferred metadata and apply overrides that
  # refer to values of tensors in the graph. The override tensors must
  # be "constant" in that they don't depend on input data. The tensors can
  # depend on analyzer outputs though. This allows us to set metadata that
  # depends on analyzer outputs. _infer_metadata_from_saved_model will use
  # the analyzer outputs stored in `transform_fn` to compute the metadata in
  # a deferred manner, once the analyzer outputs are known.
  metadata = dataset_metadata.DatasetMetadata(
      schema=schema_inference.infer_feature_schema(output_signature, graph))

  deferred_metadata = (
      transform_fn_pcoll
      | 'ComputeDeferredMetadata' >> beam.Map(
          _infer_metadata_from_saved_model))

  full_metadata = beam_metadata_io.BeamDatasetMetadata(
      metadata, deferred_metadata)

  _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)
  return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict
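
# Shape of the output_cache_pcoll_dict assembled in the cache loop above, for
# reference. The span and cache keys here are hypothetical examples; the
# values are Beam PCollections of encoded cache entries.
#
# output_cache_pcoll_dict = {
#     'span-0': {'<cache_key_a>': <PCollection>},
#     'span-1': {'<cache_key_b>': <PCollection>},
# }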