def test_make_feed_list_error(self, feature_spec, instances, error_msg,
                              error_type=ValueError):
  """Asserts that make_feed_list rejects `instances` with the expected error.

  A string placeholder is parsed with `feature_spec` to obtain the tensors
  handed to make_feed_list, alongside a schema built from the same spec.

  Args:
    feature_spec: feature spec used both for parsing and for the schema.
    instances: the instances that make_feed_list is expected to reject.
    error_msg: regex the raised error's message must match.
    error_type: exception class expected to be raised.
  """
  serialized_input = tf.placeholder(tf.string, [None])
  parsed = tf.parse_example(serialized_input, feature_spec)
  input_schema = dataset_schema.from_feature_spec(feature_spec)
  with self.assertRaisesRegexp(error_type, error_msg):
    impl_helper.make_feed_list(parsed, input_schema, instances)
def test_make_feed_list_error(self, feature_spec, instances, error_msg,
                              error_type=ValueError):
  """Asserts that make_feed_list rejects `instances` with the expected error.

  Args:
    feature_spec: feature spec used to parse a string placeholder and to
      build the schema passed to make_feed_list.
    instances: the instances that make_feed_list is expected to reject.
    error_msg: regex the raised error's message must match.
    error_type: exception class expected to be raised.
  """
  tensors = tf.io.parse_example(
      serialized=tf.compat.v1.placeholder(tf.string, [None]),
      features=feature_spec)
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  # `assertRaisesRegexp` is a deprecated alias that was removed in Python 3.12;
  # `assertRaisesRegex` is the supported name.
  with self.assertRaisesRegex(error_type, error_msg):
    impl_helper.make_feed_list(tensors, schema, instances)
def test_make_feed_list(self, feature_spec, instances, feed_dict):
  """Checks make_feed_list against `feed_dict`, in feature_spec key order.

  Args:
    feature_spec: feature spec whose keys name the features to feed.
    instances: the instances to convert into a feed list.
    feed_dict: expected fed value per feature name.
  """
  input_schema = dataset_schema.from_feature_spec(feature_spec)
  names = list(feature_spec.keys())
  # The feed list is positional, so expectations follow the same key order.
  expected = [feed_dict[name] for name in names]
  actual = impl_helper.make_feed_list(names, input_schema, instances)
  np.testing.assert_equal(actual, expected)
def _handle_batch(self, batch):
  """Runs the transform callable on one batch and returns its outputs.

  Args:
    batch: a list of instances. They must be dict-like, since `pop(key)` is
      called on them to strip passthrough keys.

  Returns:
    A dict mapping each of `self._graph_state.outputs_tensor_keys` to the
    corresponding fetched value. It also has one entry per passthrough key,
    holding the list of that key's values across the batch. Passthrough
    entries overwrite outputs that have the same key.

  Raises:
    ValueError: if running the transform callable fails. The original
      exception is chained as the cause.
  """
  self._batch_size_distribution.update(len(batch))
  self._num_instances.inc(len(batch))

  # Making a copy of batch because mutating PCollection elements is not
  # allowed.
  if self._passthrough_keys:
    batch = [copy.copy(x) for x in batch]
  # Extract passthrough data.
  passthrough_data = {
      key: [instance.pop(key) for instance in batch]
      for key in self._passthrough_keys
  }

  feed_list = impl_helper.make_feed_list(self._graph_state.inputs_tensor_keys,
                                         self._input_schema, batch)

  try:
    outputs_list = self._graph_state.callable_get_outputs(*feed_list)
  except Exception as e:
    tf.logging.error('%s while applying transform function for tensors %s', e,
                     self._graph_state.outputs_tensor_keys)
    # Chain explicitly so the root cause is attached as __cause__ instead of
    # being only logged (six.raise_from stays valid on py2 as well).
    six.raise_from(ValueError('bad inputs: {}'.format(feed_list)), e)

  assert len(self._graph_state.outputs_tensor_keys) == len(outputs_list)
  result = dict(zip(self._graph_state.outputs_tensor_keys, outputs_list))
  for key, value in six.iteritems(passthrough_data):
    result[key] = value
  return result
def benchmarkRunMetagraphDoFnAtTFLevel(self):
  """Benchmark RunMetaGraphDoFn at the TF level.

  Benchmarks the parts of RunMetaGraphDoFn that involve feeding and
  fetching from the TFT SavedModel. Records the wall time taken.

  Note that this benchmark necessarily duplicates code directly from TFT
  since it's benchmarking the low-level internals of TFT, which are not
  exposed for use in this way.
  """
  common_variables = _get_common_variables(self._dataset)
  tf_config = tft_beam_impl._FIXED_PARALLELISM_TF_CONFIG  # pylint: disable=protected-access
  input_schema = common_variables.transform_input_dataset_metadata.schema

  # This block copied from _GraphState.__init__
  with tf.compat.v1.Graph().as_default() as graph:
    session = tf.compat.v1.Session(graph=graph, config=tf_config)
    with session.as_default():
      # TODO(b/148082271): Revert back to unpacking the result directly once
      # TFX depends on TFT 0.22.
      apply_saved_model_result = (
          saved_transform_io.partially_apply_saved_transform_internal(
              self._dataset.tft_saved_model_path(), {}))
      inputs, outputs = apply_saved_model_result[:2]
      session.run(tf.compat.v1.global_variables_initializer())
      session.run(tf.compat.v1.tables_initializer())
      graph.finalize()
    # We ignore the schema, and assume there are no excluded outputs.
    outputs_tensor_keys = sorted(set(outputs.keys()))
    fetches = [outputs[key] for key in outputs_tensor_keys]
    tensor_inputs = graph_tools.get_dependent_inputs(graph, inputs, fetches)
    input_tensor_keys = sorted(tensor_inputs.keys())
    feed_list = [inputs[key] for key in input_tensor_keys]
    callable_get_outputs = session.make_callable(fetches, feed_list=feed_list)

  # The session was created directly rather than via `with`, so it must be
  # closed explicitly once the callable is no longer needed. Otherwise its
  # resources stay alive for the rest of the benchmark process.
  try:
    batch_size, batched_records = _get_batched_records(self._dataset)

    # This block copied from _RunMetaGraphDoFn._handle_batch
    start = time.time()
    for batch in batched_records:
      feed_list = impl_helper.make_feed_list(input_tensor_keys, input_schema,
                                             batch)
      outputs_list = callable_get_outputs(*feed_list)
      # Mirrors the result-dict construction in _handle_batch; the dict itself
      # is discarded.
      _ = {
          key: value
          for key, value in zip(outputs_tensor_keys, outputs_list)
      }
    end = time.time()
  finally:
    session.close()
  delta = end - start

  self.report_benchmark(
      iters=1,
      wall_time=delta,
      extras={
          "batch_size": batch_size,
          "num_examples": self._dataset.num_examples()
      })
def _make_feed_list(self, batch):
  """Builds the positional feed list for the transform graph's inputs.

  The result holds one entry per name in
  `self._graph_state.inputs_tensor_keys`, in that order.

  Args:
    batch: the batch to convert. When `self._use_tfxio` is set, it is handed
      to the tensor adapter; otherwise it is converted with
      `impl_helper.make_feed_list` using `self._input_schema`.

  Returns:
    A list of values to feed, aligned with `inputs_tensor_keys`.
  """
  input_keys = self._graph_state.inputs_tensor_keys
  if not self._use_tfxio:
    return impl_helper.make_feed_list(input_keys, self._input_schema, batch)
  tensors_by_name = self._tensor_adapter.ToBatchTensors(
      batch, produce_eager_tensors=False)
  return [tensors_by_name[key] for key in input_keys]
def test_make_feed_list(self, feature_spec, instances, feed_dict,
                        produce_eager_tensors):
  """Checks make_feed_list output, optionally producing eager tensors.

  Args:
    feature_spec: feature spec whose keys name the features to feed.
    instances: the instances to convert into a feed list.
    feed_dict: expected fed value per feature name.
    produce_eager_tensors: forwarded to make_feed_list. When true, the test
      requires TF 2.x and unwraps the eager tensors before comparing.
  """
  if produce_eager_tensors:
    test_case.skip_if_not_tf2('Tensorflow 2.x required')
  input_schema = schema_utils.schema_from_feature_spec(feature_spec)
  names = list(feature_spec.keys())
  expected = [feed_dict[name] for name in names]
  actual = impl_helper.make_feed_list(
      names, input_schema, instances,
      produce_eager_tensors=produce_eager_tensors)
  if produce_eager_tensors:
    actual = _get_value_from_eager_tensors(actual)
  np.testing.assert_equal(actual, expected)
def _handle_batch(self, batch):
  """Runs the transform callable on one batch and returns its outputs.

  Args:
    batch: a list of instances. They must be dict-like, since `pop(key)` is
      called on them to strip passthrough keys.

  Returns:
    A dict mapping each of `self._graph_state.outputs_tensor_keys` to the
    corresponding fetched value. It also has one entry per passthrough key,
    holding the list of that key's values across the batch. Passthrough
    entries overwrite outputs that have the same key.

  Raises:
    ValueError: if running the transform callable fails. The original
      exception is chained as the cause.
  """
  self._batch_size_distribution.update(len(batch))
  self._num_instances.inc(len(batch))

  # Making a copy of batch because mutating PCollection elements is not
  # allowed.
  if self._passthrough_keys:
    batch = [copy.copy(x) for x in batch]
  # Extract passthrough data.
  passthrough_data = {
      key: [instance.pop(key) for instance in batch]
      for key in self._passthrough_keys
  }

  feed_list = impl_helper.make_feed_list(
      self._graph_state.inputs_tensor_keys, self._input_schema, batch)

  try:
    outputs_list = self._graph_state.callable_get_outputs(*feed_list)
  except Exception as e:
    # Chain explicitly so the underlying TF error is attached as __cause__.
    six.raise_from(
        ValueError(
            """An error occured while trying to apply the transformation: "{}".
          Batch instances: {},
          Fetching the values for the following Tensor keys: {}.""".format(
              str(e), batch, self._graph_state.outputs_tensor_keys)), e)

  assert len(self._graph_state.outputs_tensor_keys) == len(outputs_list)
  result = dict(zip(self._graph_state.outputs_tensor_keys, outputs_list))
  for key, value in six.iteritems(passthrough_data):
    result[key] = value
  return result