Example #1
0
    def testCreateSklearnInvalidModel(self):
        """Loading a model whose file suffix does not match its format fails.

        A joblib-serialized model is copied under a `.pkl` name into a fresh,
        otherwise-empty directory, so `create_sklearn_model` cannot load it.
        The test expects a `PredictionError` carrying the FAILED_TO_LOAD_MODEL
        code and a "Could not load the model" detail message.
        """
        import tempfile  # Local import keeps this method self-contained.

        # Use a private, empty subdirectory rather than the shared
        # test_tmpdir root: a model file another test leaves in the root could
        # otherwise be picked up, making the result depend on test order. It
        # also keeps the bogus model.pkl from leaking into other tests.
        model_path = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
        self.addCleanup(shutil.rmtree, model_path, ignore_errors=True)

        # Copying a .joblib model with incorrect suffix (.pkl), so that it
        # cannot be loaded.
        shutil.copy2(
            os.path.join(FLAGS.test_srcdir,
                         SKLEARN_JOBLIB_MODEL, "model.joblib"),
            os.path.join(model_path, "model.pkl"))

        with self.assertRaises(mlprediction.PredictionError) as error:
            mlprediction.create_sklearn_model(model_path, None)
        self.assertEqual(
            error.exception.error_code,
            mlprediction.PredictionError.FAILED_TO_LOAD_MODEL.code)
        self.assertIn("Could not load the model", error.exception.error_detail)
Example #2
0
 def testInvalidPredictionWithSklearn(self):
     """Predicting on rows of the wrong width raises FAILED_TO_RUN_MODEL.

     The saved Scikit-Learn classifier expects rows of shape (2,); a
     three-feature row is fed instead and the resulting PredictionError is
     checked for its error code and detail message.
     """
     joblib_dir = os.path.join(FLAGS.test_srcdir, SKLEARN_JOBLIB_MODEL)
     sklearn_model = mlprediction.create_sklearn_model(joblib_dir, None)
     bad_rows = [[10, 20, 30]]  # One feature too many.
     with self.assertRaises(mlprediction.PredictionError) as ctx:
         sklearn_model.predict(bad_rows, stats=None)
     raised = ctx.exception
     expected_code = mlprediction.PredictionError.FAILED_TO_RUN_MODEL.code
     self.assertEqual(raised.error_code, expected_code)
     self.assertIn("Exception during sklearn prediction", raised.error_detail)
Example #3
0
 def testCreateSklearnModelFromPickle(self):
     """A pickled Scikit-Learn model loads and predicts correctly.

     Checks the predicted values and that `predict` returns the original
     inputs unchanged. Also checks that the engine name is recorded in the
     supplied stats object.
     """
     pickle_dir = os.path.join(FLAGS.test_srcdir, SKLEARN_PICKLE_MODEL)
     # The saved model is a Scikit-Learn classifier.
     classifier = mlprediction.create_sklearn_model(pickle_dir, None)
     rows = [[10, 20], [1, 2], [5, 6]]

     run_stats = mlprediction.Stats()
     # Seed a placeholder entry so that a new stats object is not created.
     run_stats["dummy"] = 1

     echoed_rows, outputs = classifier.predict(rows, stats=run_stats)

     self.assertEqual(outputs, [30, 3, 11])
     self.assertEqual(echoed_rows, rows)
     engine_name = run_stats[mlprediction.ENGINE]
     self.assertEqual(engine_name, mlprediction.SCIKIT_LEARN_FRAMEWORK_NAME)
Example #4
0
 def testSklearnModelNotFound(self):
     """Creating a model from a path that does not exist raises an error.

     The raised PredictionError's detail must mention that the model could
     not be found.
     """
     missing_path = os.path.join(FLAGS.test_srcdir, "non_existent_path")
     with self.assertRaises(mlprediction.PredictionError) as ctx:
         mlprediction.create_sklearn_model(missing_path, None)
     detail = ctx.exception.error_detail
     self.assertIn("Could not find ", detail)