def test_export_tpu_savedmodel_e2e(self, export_tpu_tensor, export_cpu_tensor,
                                   use_export_mode_v2):
    """Train a TPUEstimator for one step, export a SavedModel, and validate it.

    Args:
      export_tpu_tensor: Passed to `get_model_fn` and `self._validate_export`.
      export_cpu_tensor: Passed to `get_model_fn` and `self._validate_export`.
      use_export_mode_v2: If truthy, export with
        `ExportSavedModelApiVersion.V2` and route PREDICT calls made with
        `params['use_tpu']` through `tpu_estimator.model_fn_inference_on_tpu`;
        otherwise export with V1 using the plain model_fn.
    """
    tmpdir = tempfile.mkdtemp()
    # try/finally so the temporary directory is removed even when training,
    # export, or validation fails.
    try:
        def _input_fn(params):
            return dummy_input_fn(params['batch_size'])

        model_fn = get_model_fn(export_tpu_tensor, export_cpu_tensor)
        run_config = create_run_config(iterations_per_loop=4)

        if use_export_mode_v2:
            export_api_version = tpu_estimator.ExportSavedModelApiVersion.V2
            batch_config = tpu_estimator.BatchConfig(
                num_batch_threads=1,
                max_batch_size=1,
                batch_timeout_micros=100,
                allowed_batch_sizes=[1])

            def tpu_model_fn(features, labels, mode, params):
                if mode == _PREDICT and params['use_tpu']:
                    # model_fn_inference_on_tpu has no `mode` parameter;
                    # passing it positionally put `mode` into the `config`
                    # slot. Pass the remaining arguments by keyword.
                    return tpu_estimator.model_fn_inference_on_tpu(
                        model_fn,
                        features=features,
                        labels=labels,
                        params=params,
                        batch_config=batch_config)
                return model_fn(features, labels, mode, params)

            est_model_fn = tpu_model_fn
        else:
            export_api_version = tpu_estimator.ExportSavedModelApiVersion.V1
            est_model_fn = model_fn

        est = tpu_estimator.TPUEstimator(
            model_fn=est_model_fn,
            config=run_config,
            train_batch_size=16,
            export_to_tpu=True,
            export_saved_model_api_version=export_api_version)
        est.train(_input_fn, steps=1)

        # Perform the export.
        export_dir_base = os.path.join(
            compat.as_bytes(tmpdir), compat.as_bytes('export'))
        export_dir = est.export_saved_model(export_dir_base,
                                            self._serving_input_receiver_fn)

        self._validate_export(export_dir_base, export_dir, export_tpu_tensor,
                              export_cpu_tensor)
    finally:
        # Clean up.
        gfile.DeleteRecursively(tmpdir)
def model_fn_for_export(features, labels, mode, params, config):
    """The model_fn to use during export for TPU.

    Wraps `self._adanet_model_fn` (with `hooks` bound via `functools.partial`)
    in `tpu_estimator.model_fn_inference_on_tpu`. Requests are batched with a
    `BatchConfig` whose maximum and only allowed batch size is
    `self._predict_batch_size`.

    Args:
      features: Passed to `model_fn_inference_on_tpu` as `features`.
      labels: Passed to `model_fn_inference_on_tpu` as `labels`.
      mode: Estimator mode; must be `tf.estimator.ModeKeys.PREDICT`. It is
        only checked here and is not passed on.
      params: Passed to `model_fn_inference_on_tpu` as `params`.
      config: Passed to `model_fn_inference_on_tpu` as `config`.

    Returns:
      The result of `tpu_estimator.model_fn_inference_on_tpu` (presumably an
      EstimatorSpec — confirm against the tpu_estimator API).

    Raises:
      ValueError: If `mode` is not `tf.estimator.ModeKeys.PREDICT`.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let a non-PREDICT mode reach TPU inference.
    if mode != tf.estimator.ModeKeys.PREDICT:
        raise ValueError(
            "model_fn_for_export only supports mode {!r}; got {!r}.".format(
                tf.estimator.ModeKeys.PREDICT, mode))
    batch_config = tpu_estimator.BatchConfig(
        # Set num_batch_threads to the number of TPU cores on Servomatic.
        num_batch_threads=2,
        max_batch_size=self._predict_batch_size,
        # TODO: Magic number. Investigate whether there is a better
        # way to set this, or have the user pass it in.
        batch_timeout_micros=60 * 1000,  # 60 ms.
        allowed_batch_sizes=[self._predict_batch_size])
    return tpu_estimator.model_fn_inference_on_tpu(
        functools.partial(self._adanet_model_fn, hooks=hooks),
        features=features,
        labels=labels,
        config=config,
        params=params,
        batch_config=batch_config)
def _model_fn(features, labels, mode, params, config):
    """Model function that can also be exported for inference on TPU.

    When `is_export` is truthy, `params["use_tpu"]` is truthy, and `mode` is
    PREDICT, `adanet_model_fn` is wrapped in
    `tpu_estimator.model_fn_inference_on_tpu` with request batching. In every
    other case the call is delegated unchanged to `adanet_model_fn`.

    Args:
      features: Passed through to the underlying model function.
      labels: Passed through to the underlying model function.
      mode: Estimator mode; selects the TPU export path only for PREDICT.
      params: Hyperparameter dict; `params["use_tpu"]` is read on the export
        path.
      config: Passed through to the underlying model function.

    Returns:
      The result of `model_fn_inference_on_tpu` on the TPU export path,
      otherwise the result of `adanet_model_fn`.
    """
    # Same short-circuit order as before: `params` is only indexed when
    # exporting, and `mode` only compared when on TPU.
    exporting_predict_on_tpu = (
        is_export
        and params["use_tpu"]
        and mode == tf.estimator.ModeKeys.PREDICT)
    if not exporting_predict_on_tpu:
        return adanet_model_fn(features, labels, mode, params, config)

    tpu_batching = tpu_estimator.BatchConfig(
        # Set num_batch_threads to the number of TPU cores on Servomatic.
        num_batch_threads=2,
        max_batch_size=self._predict_batch_size,
        # TODO: Magic number. Investigate whether there is a better
        # way to set this, or have the user pass it in.
        batch_timeout_micros=60 * 1000,  # 60 ms.
        allowed_batch_sizes=[self._predict_batch_size])
    return tpu_estimator.model_fn_inference_on_tpu(
        adanet_model_fn,
        features=features,
        labels=labels,
        config=config,
        params=params,
        batch_config=tpu_batching)