Exemplo n.º 1
0
        def __init__(self, model_dir, skip_preprocessing):
            """Load the model stored in ``model_dir`` and build a prediction model.

            Args:
              model_dir: Directory the model is loaded from; also stored on the
                instance as ``self.model_dir`` and passed to ``create_model``.
              skip_preprocessing: Forwarded unchanged to
                ``mlprediction.create_model``.
            """
            self.model_dir = model_dir

            # load_model yields a (session, signature) pair that the client wraps.
            loaded_session, loaded_signature = mlprediction.load_model(model_dir)
            session_client = mlprediction.SessionClient(
                loaded_session,
                loaded_signature,
            )
            self.model = mlprediction.create_model(
                session_client,
                model_dir,
                skip_preprocessing=skip_preprocessing,
            )
Exemplo n.º 2
0
    def testLoadCustomSignature(self):
        """Load a model exported under a custom signature name and tag set.

        Checks two things:
          * With the full tag set, the signature map returned by
            ``load_model`` contains "mysignature", whose input/output
            tensors have the expected keys and names.
          * With only a subset of the tags, loading raises
            ``mlprediction.PredictionError`` with a specific detail message.
        """
        model_dir = os.path.join(FLAGS.test_tmpdir, "identity_model")
        model_test_util.create_identity_model(model_dir=model_dir,
                                              signature_name="mysignature",
                                              tags=("tag1", "tag2"))
        _, signature_map = mlprediction.load_model(model_dir,
                                                   tags=("tag1", "tag2"))
        signature = signature_map["mysignature"]
        # Iterating inputs/outputs yields their keys ("in" / "out"), which are
        # then used to look up the tensor info entries.
        self.assertEqual(["in"], list(signature.inputs))
        self.assertEqual("Print:0", signature.inputs["in"].name)
        self.assertEqual(["out"], list(signature.outputs))
        self.assertEqual("Print_1:0", signature.outputs["out"].name)

        # A tag set that differs from the one used at export time must fail.
        expected_detail = (
            "Failed to load the model due to bad model data. "
            "tags: ['tag1']\nMetaGraphDef associated with tags 'tag1' "
            "could not be found in SavedModel. To inspect available "
            "tag-sets in the SavedModel, please use the SavedModel "
            "CLI: `saved_model_cli`")
        with self.assertRaises(mlprediction.PredictionError) as error:
            mlprediction.load_model(model_dir, tags=("tag1",))
        self.assertEqual(expected_detail, error.exception.error_detail)
Exemplo n.º 3
0
    def testConfigIsSet(self):
        """The ConfigProto given to ``load_model`` is the one on the session."""
        expected_config = tf.ConfigProto(inter_op_parallelism_threads=3)
        model_path = os.path.join(FLAGS.test_srcdir, INPUT_MODEL)
        serving_tags = (tf.saved_model.tag_constants.SERVING,)

        session, _ = mlprediction.load_model(
            model_path, tags=serving_tags, config=expected_config)

        # NOTE(review): relies on the session's private ``_config`` attribute;
        # confirm there is no public accessor before depending on it further.
        self.assertEqual(session._config, expected_config)