Code Example #1
An end-to-end test that trains a TPUEstimator for one step and then exports a SavedModel, exercising both the V1 and V2 export API versions.
  def test_export_tpu_savedmodel_e2e(self, export_tpu_tensor, export_cpu_tensor,
                                     use_export_mode_v2):
    tmpdir = tempfile.mkdtemp()

    def _input_fn(params):
      return dummy_input_fn(params['batch_size'])

    model_fn = get_model_fn(export_tpu_tensor, export_cpu_tensor)
    run_config = create_run_config(iterations_per_loop=4)
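    # In V2 export mode, wrap model_fn so that PREDICT requests on TPU are
    # served through model_fn_inference_on_tpu with a batching configuration.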
    if use_export_mode_v2:
      export_api_version = tpu_estimator.ExportSavedModelApiVersion.V2

      batch_config = tpu_estimator.BatchConfig(
          num_batch_threads=1,
          max_batch_size=1,
          batch_timeout_micros=100,
          allowed_batch_sizes=[1])

      def tpu_model_fn(features, labels, mode, params):
        if mode == _PREDICT and params['use_tpu']:
          return tpu_estimator.model_fn_inference_on_tpu(
              model_fn, features, labels, mode, params, batch_config)
        else:
          return model_fn(features, labels, mode, params)

      est_model_fn = tpu_model_fn
    else:
      export_api_version = tpu_estimator.ExportSavedModelApiVersion.V1
      est_model_fn = model_fn
    est = tpu_estimator.TPUEstimator(
        model_fn=est_model_fn,
        config=run_config,
        train_batch_size=16,
        export_to_tpu=True,
        export_saved_model_api_version=export_api_version)
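    # Train for one step so a checkpoint exists for export_saved_model.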
    est.train(_input_fn, steps=1)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_saved_model(export_dir_base,
                                        self._serving_input_receiver_fn)

    self._validate_export(export_dir_base, export_dir, export_tpu_tensor,
                          export_cpu_tensor)

    # Clean up.
    gfile.DeleteRecursively(tmpdir)
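
The test above passes a self._serving_input_receiver_fn that the snippet does not define. Below is a minimal sketch of what such a function can look like; the feature name 'x' and its shape are illustrative assumptions, not taken from the test:

import tensorflow.compat.v1 as tf

def serving_input_receiver_fn():
  # Placeholders the exported SavedModel will accept at serving time.
  receiver_tensors = {
      'x': tf.placeholder(dtype=tf.float32, shape=[None, 3], name='x'),
  }
  return tf.estimator.export.ServingInputReceiver(receiver_tensors,
                                                  receiver_tensors)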
Code Example #2
Builds the model_fn used at export time: it asserts PREDICT mode and routes inference through model_fn_inference_on_tpu with a BatchConfig.
    def model_fn_for_export(features, labels, mode, params, config):
      """The model_fn to use during export for TPU."""

      assert mode == tf.estimator.ModeKeys.PREDICT
      batch_config = tpu_estimator.BatchConfig(
          # Set num_batch_threads to the number of TPU cores on Servomatic.
          num_batch_threads=2,
          max_batch_size=self._predict_batch_size,
          # TODO: Magic number. Investigate whether there is a better
          # way to set this, or have the user pass it in.
          batch_timeout_micros=60 * 1000,
          allowed_batch_sizes=[self._predict_batch_size])
      return tpu_estimator.model_fn_inference_on_tpu(
          functools.partial(self._adanet_model_fn, hooks=hooks),
          features=features,
          labels=labels,
          config=config,
          params=params,
          batch_config=batch_config)
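
A model_fn like model_fn_for_export is meant to be used only when building an estimator for export. A hedged sketch of the wiring, following the pattern in Code Example #1; the import path varies across TensorFlow versions, and run_config, export_dir_base, and serving_input_receiver_fn are placeholders:

from tensorflow_estimator.python.estimator.tpu import tpu_estimator

est = tpu_estimator.TPUEstimator(
    model_fn=model_fn_for_export,
    config=run_config,
    train_batch_size=8,  # required by the constructor even if never training
    export_to_tpu=True,
    export_saved_model_api_version=(
        tpu_estimator.ExportSavedModelApiVersion.V2))
export_dir = est.export_saved_model(export_dir_base,
                                    serving_input_receiver_fn)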
Code Example #3
A variant of Example #2 that wraps adanet_model_fn only when exporting for TPU prediction, and otherwise falls through to the plain model_fn.
File: tpu_estimator.py  Project: xuyuewei/adanet
        def _model_fn(features, labels, mode, params, config):
            """The model_fn to return which supports exporting on TPU."""

            if (is_export and params["use_tpu"]
                    and mode == tf.estimator.ModeKeys.PREDICT):
                batch_config = tpu_estimator.BatchConfig(
                    # Set num_batch_threads to the number of TPU cores on Servomatic.
                    num_batch_threads=2,
                    max_batch_size=self._predict_batch_size,
                    # TODO: Magic number. Investigate whether there is a better
                    # way to set this, or have the user pass it in.
                    batch_timeout_micros=60 * 1000,
                    allowed_batch_sizes=[self._predict_batch_size])
                return tpu_estimator.model_fn_inference_on_tpu(
                    adanet_model_fn,
                    features=features,
                    labels=labels,
                    config=config,
                    params=params,
                    batch_config=batch_config)

            return adanet_model_fn(features, labels, mode, params, config)
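
Both adanet examples use the same BatchConfig values. As a rough guide to the fields, here is an annotated sketch; the semantics follow TensorFlow's underlying batching ops, and the sizes shown are illustrative:

batch_config = tpu_estimator.BatchConfig(
    num_batch_threads=2,             # threads that form and run batches
    max_batch_size=8,                # largest batch the server will form
    batch_timeout_micros=60 * 1000,  # wait at most 60 ms to fill a batch
    # Incoming requests are padded up to one of these sizes; entries must
    # increase, and the last one must equal max_batch_size.
    allowed_batch_sizes=[1, 2, 4, 8])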