# NOTE(review): the line below holds many statements with no line breaks or
# indentation as seen here, and it opens part-way through a test body whose
# `def` header is not visible. It is therefore kept verbatim, not reflowed.
#
# What the visible statements do, in order:
#   1. Build two paths under `temp_dir`: "mnist_feature_extractor" and
#      "full_model".
#   2. Call self.assertCommandSucceeded("export_mnist_cnn", ...) with the
#      feature-extractor path as `export_dir`.
#   3. Collect keyword arguments for "use_mnist_cnn": fast_test_mode,
#      input_saved_model_dir, retrain and use_keras_save_api are always set;
#      "strategy" is added (as str) only when `named_strategy` is truthy, and
#      "regularization_loss_multiplier" only when it is not None.
#   4. Call self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs).
#   5. Call self.assertCommandSucceeded("deploy_mnist_cnn", ...) with the
#      full-model path as `saved_model_dir`.
#
# Both `full_model_dir is not None` checks always pass here, because the
# variable is assigned from os.path.join() a few statements earlier.
# `feature_extrator_dir` is presumably a typo for "feature_extractor_dir"; it
# is spelled the same way at every use, so it is a naming nit only.
#
# The trailing `if __name__ == "__main__":` guard calls
# scripts.MaybeRunScriptInstead() and then tf.test.main().
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor") full_model_dir = os.path.join(temp_dir, "full_model") self.assertCommandSucceeded("export_mnist_cnn", fast_test_mode=fast_test_mode, export_dir=feature_extrator_dir) use_kwargs = dict(fast_test_mode=fast_test_mode, input_saved_model_dir=feature_extrator_dir, retrain=retrain_flag_value, use_keras_save_api=use_keras_save_api) if full_model_dir is not None: use_kwargs["output_saved_model_dir"] = full_model_dir if named_strategy: use_kwargs["strategy"] = str(named_strategy) if regularization_loss_multiplier is not None: use_kwargs[ "regularization_loss_multiplier"] = regularization_loss_multiplier self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs) if full_model_dir is not None: self.assertCommandSucceeded("deploy_mnist_cnn", fast_test_mode=fast_test_mode, saved_model_dir=full_model_dir) if __name__ == "__main__": scripts.MaybeRunScriptInstead() tf.test.main()
"use_text_embedding_in_dataset", model_dir=export_dir) def test_mnist_cnn(self): self.skipIfMissingExtraDeps() export_dir = self.get_temp_dir() self.assertCommandSucceeded( "export_mnist_cnn", export_dir=export_dir, fast_test_mode="true") self.assertCommandSucceeded( "use_mnist_cnn", export_dir=export_dir, fast_test_mode="true") def test_mnist_cnn_with_mirrored_strategy(self): self.skipIfMissingExtraDeps() self.skipTest( "b/129134185 - saved model and distribution strategy integration") export_dir = self.get_temp_dir() self.assertCommandSucceeded( "export_mnist_cnn", export_dir=export_dir, fast_test_mode="true") self.assertCommandSucceeded( "use_mnist_cnn", export_dir=export_dir, fast_test_mode="true", use_mirrored_strategy=True, ) if __name__ == "__main__": integration_scripts.MaybeRunScriptInstead() tf.test.main()