예제 #1
0
 def tpu_model_fn(features, labels, mode, params):
     """Route TPU prediction through ``model_fn_inference_on_tpu``.

     Wrapper ``model_fn`` for tpu.rewrite / TPUPartitionedCall inference.
     When ``mode`` is PREDICT and ``params["use_tpu"]`` is truthy, the call
     goes to ``tpu_estimator.model_fn_inference_on_tpu``. That call runs
     ``maybe_use_guarantee_const_getter_model_fn`` with no run config and
     no batching config. Every other case is forwarded unchanged to
     ``model_fn``.

     Args:
       features: Forwarded to whichever model function is selected.
       labels: Forwarded to whichever model function is selected.
       mode: Estimator mode key. Only PREDICT can take the TPU path.
       params: Parameter mapping. ``params["use_tpu"]`` is read only when
         ``mode`` is PREDICT; a missing key then raises ``KeyError``.

     Returns:
       Whatever the selected model function returns (presumably an
       ``EstimatorSpec``/``TPUEstimatorSpec`` -- confirm against callers).
     """
     # `and` short-circuits, so "use_tpu" is only looked up for PREDICT.
     on_tpu_predict = (
         mode == tf.estimator.ModeKeys.PREDICT and params["use_tpu"])
     if not on_tpu_predict:
         return model_fn(features, labels, mode, params)

     # batch_config=None: no request batching is wrapped around the TPU call.
     return tpu_estimator.model_fn_inference_on_tpu(
         maybe_use_guarantee_const_getter_model_fn,
         features=features,
         labels=labels,
         config=None,
         params=params,
         batch_config=None)
예제 #2
0
 def tpu_model_fn(features, labels, mode, params):
     """Route TPU prediction through batched ``model_fn_inference_on_tpu``.

     Wrapper ``model_fn`` for tpu.rewrite / TPUPartitionedCall inference.
     When ``mode`` is PREDICT and ``params["use_tpu"]`` is truthy, the call
     goes to ``tpu_estimator.model_fn_inference_on_tpu`` with a
     ``BatchConfig`` built around ``predict_batch_size``. That call runs
     ``maybe_use_guarantee_const_getter_model_fn``. Every other case is
     forwarded unchanged to ``model_fn``.

     Args:
       features: Forwarded to whichever model function is selected.
       labels: Forwarded to whichever model function is selected.
       mode: Estimator mode key. Only PREDICT can take the TPU path.
       params: Parameter mapping. ``params["use_tpu"]`` is read only when
         ``mode`` is PREDICT; a missing key then raises ``KeyError``.

     Returns:
       Whatever the selected model function returns (presumably an
       ``EstimatorSpec``/``TPUEstimatorSpec`` -- confirm against callers).
     """
     # Guard clause: anything other than TPU prediction takes the plain path.
     # `and` short-circuits, so "use_tpu" is only looked up for PREDICT.
     if not (mode == tf.estimator.ModeKeys.PREDICT and params["use_tpu"]):
         return model_fn(features, labels, mode, params)

     # Wait at most 60 ms (60,000 us) before emitting an incomplete batch.
     timeout_micros = 60 * 1000
     # A single allowed size equal to the maximum means the batching layer
     # pads every batch to predict_batch_size. Per the batch_function
     # semantics, this gives the TPU one fixed batch shape.
     inference_batching = tpu_estimator.BatchConfig(
         num_batch_threads=2,
         max_batch_size=predict_batch_size,
         batch_timeout_micros=timeout_micros,
         allowed_batch_sizes=[predict_batch_size])
     return tpu_estimator.model_fn_inference_on_tpu(
         maybe_use_guarantee_const_getter_model_fn,
         features=features,
         labels=labels,
         config=None,
         params=params,
         batch_config=inference_batching)