Code example #1
    def model_fn_for_export(features, labels, mode, params, config):
      """The model_fn to use during export for TPU."""

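      # Excerpted from inside an estimator class method: `self`, `hooks`,
      # `functools`, `tf`, and `tpu_estimator` all come from the enclosing
      # scope and the module's imports.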
      assert mode == tf.estimator.ModeKeys.PREDICT
      batch_config = tpu_estimator.BatchConfig(
          # Set num_batch_threads to the number of TPU cores on Servomatic.
          num_batch_threads=2,
          max_batch_size=self._predict_batch_size,
          # TODO: Magic number. Investigate whether there is a better
          # way to set this, or have the user pass it in.
          batch_timeout_micros=60 * 1000,
          allowed_batch_sizes=[self._predict_batch_size])
      return tpu_estimator.model_fn_inference_on_tpu(
          functools.partial(self._adanet_model_fn, hooks=hooks),
          features=features,
          labels=labels,
          config=config,
          params=params,
          batch_config=batch_config)
Code example #2
File: tpu_estimator.py  Project: xuyuewei/adanet
        def _model_fn(features, labels, mode, params, config):
            """The model_fn to return which supports exporting on TPU."""

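            # `is_export`, `adanet_model_fn`, and `self._predict_batch_size`
            # are provided by the enclosing scope in the source file;
            # non-PREDICT (and non-TPU) calls fall through to the plain
            # model_fn below.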
            if (is_export and params["use_tpu"]
                    and mode == tf.estimator.ModeKeys.PREDICT):
                batch_config = tpu_estimator.BatchConfig(
                    # Set num_batch_threads to the number of TPU cores on Servomatic.
                    num_batch_threads=2,
                    max_batch_size=self._predict_batch_size,
                    # TODO: Magic number. Investigate whether there is a better
                    # way to set this, or have the user pass it in.
                    batch_timeout_micros=60 * 1000,
                    allowed_batch_sizes=[self._predict_batch_size])
                return tpu_estimator.model_fn_inference_on_tpu(
                    adanet_model_fn,
                    features=features,
                    labels=labels,
                    config=config,
                    params=params,
                    batch_config=batch_config)

            return adanet_model_fn(features, labels, mode, params, config)
Code example #3
def tpu_model_fn(features, labels, mode, params):
    # `_PREDICT`, `model_fn`, and `batch_config` are defined elsewhere in the
    # source project; `_PREDICT` presumably aliases tf.estimator.ModeKeys.PREDICT.
    if mode == _PREDICT and params['use_tpu']:
        # Pass arguments by keyword, as in examples #1 and #2 above.
        return tpu_estimator.model_fn_inference_on_tpu(
            model_fn,
            features=features,
            labels=labels,
            params=params,
            batch_config=batch_config)
    return model_fn(features, labels, mode, params)