def model_fn_for_export(features, labels, mode, params, config):
    """The model_fn to use during export for TPU.

    Wraps ``self._adanet_model_fn`` (with ``hooks`` bound) in
    ``tpu_estimator.model_fn_inference_on_tpu`` together with a
    ``tpu_estimator.BatchConfig`` sized by ``self._predict_batch_size``.

    Args:
      features: Feature inputs; passed through to
        ``tpu_estimator.model_fn_inference_on_tpu``.
      labels: Label inputs; passed through unchanged.
      mode: Estimator mode. Only ``tf.estimator.ModeKeys.PREDICT`` is
        supported.
      params: Hyperparameters; passed through unchanged.
      config: Estimator config; passed through unchanged.

    Returns:
      The result of ``tpu_estimator.model_fn_inference_on_tpu``.

    Raises:
      ValueError: If ``mode`` is not ``tf.estimator.ModeKeys.PREDICT``.
    """
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let a non-PREDICT mode slip through silently.
    if mode != tf.estimator.ModeKeys.PREDICT:
        raise ValueError(
            "model_fn_for_export only supports mode {!r}, got {!r}.".format(
                tf.estimator.ModeKeys.PREDICT, mode))

    batch_config = tpu_estimator.BatchConfig(
        # Set num_batch_threads to the number of TPU cores on Servomatic.
        num_batch_threads=2,
        max_batch_size=self._predict_batch_size,
        # TODO: Magic number (60 * 1000 microseconds = 60 ms). Investigate
        # whether there is a better way to set this, or have the user pass
        # it in.
        batch_timeout_micros=60 * 1000,
        # Only the single configured batch size is allowed.
        allowed_batch_sizes=[self._predict_batch_size])
    return tpu_estimator.model_fn_inference_on_tpu(
        functools.partial(self._adanet_model_fn, hooks=hooks),
        features=features,
        labels=labels,
        config=config,
        params=params,
        batch_config=batch_config)
def _model_fn(features, labels, mode, params, config):
    """The model_fn to return which supports exporting on TPU.

    When ``is_export`` is set, ``params["use_tpu"]`` is truthy and ``mode`` is
    ``tf.estimator.ModeKeys.PREDICT``, ``adanet_model_fn`` is wrapped in
    ``tpu_estimator.model_fn_inference_on_tpu`` with a ``BatchConfig``.
    Otherwise ``adanet_model_fn`` is called directly with the same arguments.
    """
    exporting_tpu_predictions = (
        is_export
        and params["use_tpu"]
        and mode == tf.estimator.ModeKeys.PREDICT)
    if not exporting_tpu_predictions:
        return adanet_model_fn(features, labels, mode, params, config)

    batch_size = self._predict_batch_size
    # num_batch_threads is set to the number of TPU cores on Servomatic.
    # TODO: batch_timeout_micros (60 * 1000 us = 60 ms) is a magic number.
    # Investigate whether there is a better way to set this, or have the user
    # pass it in.
    batch_config = tpu_estimator.BatchConfig(
        num_batch_threads=2,
        max_batch_size=batch_size,
        batch_timeout_micros=60 * 1000,
        allowed_batch_sizes=[batch_size])
    return tpu_estimator.model_fn_inference_on_tpu(
        adanet_model_fn,
        features=features,
        labels=labels,
        config=config,
        params=params,
        batch_config=batch_config)
def tpu_model_fn(features, labels, mode, params):
    """Route PREDICT-on-TPU calls through batched TPU inference.

    If ``mode == _PREDICT`` and ``params['use_tpu']`` is truthy, ``model_fn``
    is wrapped with ``tpu_estimator.model_fn_inference_on_tpu`` using the
    enclosing ``batch_config``. Otherwise ``model_fn`` is called directly.

    Args:
      features: Feature inputs for ``model_fn``.
      labels: Label inputs for ``model_fn``.
      mode: Estimator mode key.
      params: Hyperparameter dict; must contain ``'use_tpu'`` when ``mode``
        is ``_PREDICT``.

    Returns:
      The result of the TPU inference wrapper or of ``model_fn``.
    """
    if mode == _PREDICT and params['use_tpu']:
        # Pass by keyword. The wrapper's parameters are
        # (model_fn, features, labels, config, params, batch_config), so the
        # former positional call put `mode` into the `config` slot.
        # The wrapper takes no `mode` argument, and this branch is only
        # reached for _PREDICT anyway.
        return tpu_estimator.model_fn_inference_on_tpu(
            model_fn,
            features=features,
            labels=labels,
            params=params,
            batch_config=batch_config)
    return model_fn(features, labels, mode, params)