示例#1
0
 def test_make_feed_list_error(self, feature_spec, instances, error_msg,
                               error_type=ValueError):
     """Checks that `impl_helper.make_feed_list` raises the expected error.

     Args:
       feature_spec: feature spec used both to parse a placeholder of
         serialized examples and to build the schema.
       instances: the instances argument forwarded to `make_feed_list`.
       error_msg: regular expression that the raised error must match.
       error_type: exception class expected to be raised.
     """
     serialized_examples = tf.placeholder(tf.string, [None])
     parsed_tensors = tf.parse_example(serialized_examples, feature_spec)
     input_schema = dataset_schema.from_feature_spec(feature_spec)
     with self.assertRaisesRegexp(error_type, error_msg):
         impl_helper.make_feed_list(parsed_tensors, input_schema, instances)
示例#2
0
 def test_make_feed_list_error(self, feature_spec, instances, error_msg,
                               error_type=ValueError):
     """Checks that `impl_helper.make_feed_list` raises the expected error.

     Args:
       feature_spec: feature spec used both to parse a placeholder of
         serialized examples and to build the schema.
       instances: the instances argument forwarded to `make_feed_list`.
       error_msg: regular expression that the raised error must match.
       error_type: exception class expected to be raised.
     """
     serialized_examples = tf.compat.v1.placeholder(tf.string, [None])
     parsed_tensors = tf.io.parse_example(
         serialized=serialized_examples, features=feature_spec)
     input_schema = schema_utils.schema_from_feature_spec(feature_spec)
     with self.assertRaisesRegexp(error_type, error_msg):
         impl_helper.make_feed_list(parsed_tensors, input_schema, instances)
 def test_make_feed_list(self, feature_spec, instances, feed_dict):
     """Checks `impl_helper.make_feed_list` against the expected feed values.

     Args:
       feature_spec: feature spec whose keys give the feature names (and
         their order) and from which the schema is built.
       instances: the instances argument forwarded to `make_feed_list`.
       feed_dict: mapping from feature name to the expected feed value.
     """
     input_schema = dataset_schema.from_feature_spec(feature_spec)
     names = list(feature_spec.keys())
     # The expected list follows the same key order that is handed to
     # make_feed_list.
     expected = [feed_dict[name] for name in names]
     actual = impl_helper.make_feed_list(names, input_schema, instances)
     np.testing.assert_equal(actual, expected)
示例#4
0
  def _handle_batch(self, batch):
    """Runs the graph callable on one batch and returns its outputs by key.

    Updates the batch-size distribution and instance counter, strips the
    passthrough keys from (copies of) the instances, converts the remaining
    data into a feed list and invokes the graph callable with it.

    Args:
      batch: a sized iterable of instances. When passthrough keys are
        configured, every instance must support `pop(key)` for each of them.

    Returns:
      A dict mapping each output tensor key to the value fetched for it, plus
      one entry per passthrough key holding the list of values popped from
      the instances of the batch.

    Raises:
      ValueError: if invoking the graph callable fails. The original error is
        logged and, on Python 3, attached as the cause of the ValueError.
    """
    batch_size = len(batch)
    self._batch_size_distribution.update(batch_size)
    self._num_instances.inc(batch_size)

    # Making a copy of batch because mutating PCollection elements is not
    # allowed.
    if self._passthrough_keys:
      batch = [copy.copy(x) for x in batch]
    # Extract passthrough data.
    passthrough_data = {
        key: [instance.pop(key) for instance in batch]
        for key in self._passthrough_keys
    }

    feed_list = impl_helper.make_feed_list(self._graph_state.inputs_tensor_keys,
                                           self._input_schema, batch)

    try:
      outputs_list = self._graph_state.callable_get_outputs(*feed_list)
    except Exception as e:  # pylint: disable=broad-except
      tf.logging.error('%s while applying transform function for tensors %s',
                       e, self._graph_state.outputs_tensor_keys)
      # Chain the original exception explicitly so the root cause stays
      # attached to the ValueError; six.raise_from degrades to a plain raise
      # on Python 2.
      six.raise_from(ValueError('bad inputs: {}'.format(feed_list)), e)

    output_keys = self._graph_state.outputs_tensor_keys
    assert len(output_keys) == len(outputs_list)
    result = dict(zip(output_keys, outputs_list))
    # Passthrough values are added last, so they win over an output tensor
    # with the same key.
    result.update(passthrough_data)

    return result
示例#5
0
    def benchmarkRunMetagraphDoFnAtTFLevel(self):
        """Benchmark RunMetaGraphDoFn at the TF level.

        Benchmarks the parts of RunMetaGraphDoFn that involve feeding and
        fetching from the TFT SavedModel. Records the wall time taken.

        Note that this benchmark necessarily duplicates code directly from TFT
        since it's benchmarking the low-level internals of TFT, which are not
        exposed for use in this way.

        Only the loop over the batched records is timed; loading the saved
        transform, initializing the session and building the callable happen
        before the timer starts. The result is reported through
        `report_benchmark` as a single iteration, with the batch size and the
        number of examples attached as extras.
        """
        common_variables = _get_common_variables(self._dataset)
        tf_config = tft_beam_impl._FIXED_PARALLELISM_TF_CONFIG  # pylint: disable=protected-access
        input_schema = common_variables.transform_input_dataset_metadata.schema

        # This block copied from _GraphState.__init__
        with tf.compat.v1.Graph().as_default() as graph:
            session = tf.compat.v1.Session(graph=graph, config=tf_config)
            with session.as_default():
                # TODO(b/148082271): Revert back to unpacking the result directly once
                # TFX depends on TFT 0.22.
                apply_saved_model_result = (
                    saved_transform_io.
                    partially_apply_saved_transform_internal(
                        self._dataset.tft_saved_model_path(), {}))
                # Only the first two elements of the result are used here.
                inputs, outputs = apply_saved_model_result[:2]
                session.run(tf.compat.v1.global_variables_initializer())
                session.run(tf.compat.v1.tables_initializer())
                graph.finalize()
            # We ignore the schema, and assume there are no excluded outputs.
            outputs_tensor_keys = sorted(set(outputs.keys()))
            fetches = [outputs[key] for key in outputs_tensor_keys]
            # NOTE(review): presumably restricts `inputs` to those the fetches
            # depend on -- confirm against graph_tools.get_dependent_inputs.
            tensor_inputs = graph_tools.get_dependent_inputs(
                graph, inputs, fetches)
            # Sorted keys fix the positional order of the feeds passed to the
            # callable below.
            input_tensor_keys = sorted(tensor_inputs.keys())
            feed_list = [inputs[key] for key in input_tensor_keys]
            callable_get_outputs = session.make_callable(fetches,
                                                         feed_list=feed_list)

        batch_size, batched_records = _get_batched_records(self._dataset)

        # This block copied from _RunMetaGraphDoFn._handle_batch
        start = time.time()
        for batch in batched_records:
            feed_list = impl_helper.make_feed_list(input_tensor_keys,
                                                   input_schema, batch)
            outputs_list = callable_get_outputs(*feed_list)
            # The dict is discarded; it is built only so that the timed work
            # includes the same output-assembly step as the copied code.
            _ = {
                key: value
                for key, value in zip(outputs_tensor_keys, outputs_list)
            }
        end = time.time()
        delta = end - start

        self.report_benchmark(iters=1,
                              wall_time=delta,
                              extras={
                                  "batch_size": batch_size,
                                  "num_examples": self._dataset.num_examples()
                              })
示例#6
0
  def _make_feed_list(self, batch):
    """Builds the list of feed values for the graph's input tensor keys.

    Args:
      batch: the batch to convert. With TFXIO enabled it is handed to the
        tensor adapter; otherwise it is handed to
        `impl_helper.make_feed_list` together with the input schema.

    Returns:
      A list with one feed value per entry of
      `self._graph_state.inputs_tensor_keys`, in that order, when TFXIO is
      enabled; otherwise whatever `impl_helper.make_feed_list` returns.
    """
    if not self._use_tfxio:
      return impl_helper.make_feed_list(
          self._graph_state.inputs_tensor_keys, self._input_schema, batch)

    # TFXIO path: the adapter produces the tensors keyed by name, which are
    # then ordered to match the graph's input keys.
    tensors_by_name = self._tensor_adapter.ToBatchTensors(
        batch, produce_eager_tensors=False)
    ordered_feeds = [
        tensors_by_name[key] for key in self._graph_state.inputs_tensor_keys
    ]
    return ordered_feeds
示例#7
0
 def test_make_feed_list(self, feature_spec, instances, feed_dict,
                         produce_eager_tensors):
     """Checks `impl_helper.make_feed_list` against the expected feed values.

     Args:
       feature_spec: feature spec whose keys give the feature names (and
         their order) and from which the schema is built.
       instances: the instances argument forwarded to `make_feed_list`.
       feed_dict: mapping from feature name to the expected feed value.
       produce_eager_tensors: forwarded to `make_feed_list`; when true the
         test is skipped unless TF2 is available and the result is converted
         with `_get_value_from_eager_tensors` before comparison.
     """
     if produce_eager_tensors:
         test_case.skip_if_not_tf2('Tensorflow 2.x required')
     input_schema = schema_utils.schema_from_feature_spec(feature_spec)
     names = list(feature_spec.keys())
     expected = [feed_dict[name] for name in names]
     actual = impl_helper.make_feed_list(
         names,
         input_schema,
         instances,
         produce_eager_tensors=produce_eager_tensors)
     if produce_eager_tensors:
         actual = _get_value_from_eager_tensors(actual)
     np.testing.assert_equal(actual, expected)
示例#8
0
    def _handle_batch(self, batch):
        """Runs the graph callable on one batch and returns its outputs by key.

        Args:
          batch: a sized iterable of instances. When passthrough keys are
            configured, every instance must support `pop(key)` for each of
            them.

        Returns:
          A dict mapping each output tensor key to the value fetched for it,
          plus one entry per passthrough key holding the list of values
          popped from the instances of the batch.

        Raises:
          ValueError: if invoking the graph callable fails; the message
            contains the original error, the batch and the output tensor
            keys.
        """
        num_instances = len(batch)
        self._batch_size_distribution.update(num_instances)
        self._num_instances.inc(num_instances)

        # PCollection elements must not be mutated, so the passthrough values
        # are popped from shallow copies of the instances instead.
        if self._passthrough_keys:
            batch = [copy.copy(instance) for instance in batch]
        passthrough_data = {}
        for key in self._passthrough_keys:
            passthrough_data[key] = [instance.pop(key) for instance in batch]

        feed_list = impl_helper.make_feed_list(
            self._graph_state.inputs_tensor_keys, self._input_schema, batch)

        try:
            outputs_list = self._graph_state.callable_get_outputs(*feed_list)
        except Exception as e:
            message = (
                """An error occured while trying to apply the transformation: "{}".
          Batch instances: {},
          Fetching the values for the following Tensor keys: {}.""".format(
                    str(e), batch, self._graph_state.outputs_tensor_keys))
            raise ValueError(message)

        assert len(self._graph_state.outputs_tensor_keys) == len(outputs_list)
        result = dict(
            zip(self._graph_state.outputs_tensor_keys, outputs_list))
        # Passthrough entries are merged last, so they override an output
        # tensor that shares the same key.
        result.update(passthrough_data)
        return result