Exemplo n.º 1
0
        feature_extrator_dir = os.path.join(temp_dir,
                                            "mnist_feature_extractor")
        full_model_dir = os.path.join(temp_dir, "full_model")

        self.assertCommandSucceeded("export_mnist_cnn",
                                    fast_test_mode=fast_test_mode,
                                    export_dir=feature_extrator_dir)

        use_kwargs = dict(fast_test_mode=fast_test_mode,
                          input_saved_model_dir=feature_extrator_dir,
                          retrain=retrain_flag_value,
                          use_keras_save_api=use_keras_save_api)
        if full_model_dir is not None:
            use_kwargs["output_saved_model_dir"] = full_model_dir
        if named_strategy:
            use_kwargs["strategy"] = str(named_strategy)
        if regularization_loss_multiplier is not None:
            use_kwargs[
                "regularization_loss_multiplier"] = regularization_loss_multiplier
        self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs)

        if full_model_dir is not None:
            self.assertCommandSucceeded("deploy_mnist_cnn",
                                        fast_test_mode=fast_test_mode,
                                        saved_model_dir=full_model_dir)


if __name__ == "__main__":
    # Entry point. MaybeRunScriptInstead() is called first; judging only by its
    # name it presumably runs one of the helper scripts (e.g. "use_mnist_cnn")
    # when this file is launched to execute a script, rather than to run the
    # tests -- TODO confirm against the `scripts` module.
    scripts.MaybeRunScriptInstead()
    # Run the TensorFlow test suite defined in this file.
    tf.test.main()
Exemplo n.º 2
0
        "use_text_embedding_in_dataset", model_dir=export_dir)

  def test_mnist_cnn(self):
    self.skipIfMissingExtraDeps()
    export_dir = self.get_temp_dir()
    self.assertCommandSucceeded(
        "export_mnist_cnn", export_dir=export_dir, fast_test_mode="true")
    self.assertCommandSucceeded(
        "use_mnist_cnn", export_dir=export_dir, fast_test_mode="true")

  def test_mnist_cnn_with_mirrored_strategy(self):
    self.skipIfMissingExtraDeps()
    self.skipTest(
        "b/129134185 - saved model and distribution strategy integration")
    export_dir = self.get_temp_dir()
    self.assertCommandSucceeded(
        "export_mnist_cnn",
        export_dir=export_dir,
        fast_test_mode="true")
    self.assertCommandSucceeded(
        "use_mnist_cnn",
        export_dir=export_dir,
        fast_test_mode="true",
        use_mirrored_strategy=True,
    )


if __name__ == "__main__":
  # Entry point. MaybeRunScriptInstead() is called first; judging only by its
  # name it presumably runs one of the integration scripts (e.g.
  # "export_mnist_cnn") when this file is launched to execute a script rather
  # than to run the tests -- TODO confirm against `integration_scripts`.
  integration_scripts.MaybeRunScriptInstead()
  # Run the TensorFlow test suite defined in this file.
  tf.test.main()