예제 #1
0
 def test_tf_keras_converter(self):
     """Convert the Keras fixture and check the converted model can be saved.

     Runs ``self._do_convert`` on the Keras model param/meta fixtures, then
     checks the reported target framework, the type of the returned model,
     and that ``save_converted_model`` produces a directory for it.
     """
     target_framework, model = self._do_convert(self.keras_model_param,
                                                self.keras_model_meta)
     # assertEqual / assertIsInstance show the actual values on failure,
     # whereas assertTrue(a == b) only reports "False is not true".
     self.assertEqual(target_framework, "tf_keras")
     self.assertIsInstance(model, tf.keras.Sequential)
     with tempfile.TemporaryDirectory() as d:
         dest = save_converted_model(model, target_framework, d)
         # For tf_keras the saved artifact is expected to be a directory.
         self.assertTrue(os.path.isdir(dest))
예제 #2
0
 def test_pytorch_lightning_converter(self):
     """Convert the PyTorch Lightning fixture and check it saves as a .pth file.

     Runs ``self._do_convert`` on the Lightning model param/meta fixtures,
     checks that the target framework is reported as "pytorch" and the model
     is a ``torch.nn.Sequential``, then verifies ``save_converted_model``
     writes a single file whose name ends in ".pth".
     """
     target_framework, model = self._do_convert(self.pl_model_param,
                                                self.pl_model_meta)
     # assertEqual / assertIsInstance show the actual values on failure,
     # whereas assertTrue(a == b) only reports "False is not true".
     self.assertEqual(target_framework, "pytorch")
     self.assertIsInstance(model, torch.nn.Sequential)
     with tempfile.TemporaryDirectory() as d:
         dest = save_converted_model(model, target_framework, d)
         self.assertTrue(os.path.isfile(dest))
         self.assertTrue(dest.endswith(".pth"))
예제 #3
0
    def test_sklearn_converter(self):
        """Convert a HomoLR model to sklearn and check the converted estimator.

        Calls ``model_convert`` with the HomoLR param/meta fixtures, then checks
        that the result is a ``LogisticRegression`` with the following values
        carried over from the source param/meta:

        * the intercept;
        * a coefficient matrix of shape ``(1, n_features)``;
        * ``tol``.

        Finally it checks that ``save_converted_model`` writes a ".joblib" file.
        """
        model_contents = {
            'HomoLogisticRegressionParam': self.model_param,
            'HomoLogisticRegressionMeta': self.model_meta,
        }
        target_framework, model = model_convert(model_contents=model_contents,
                                                module_name='HomoLR',
                                                framework_name='sklearn')
        # assertEqual / assertIsInstance show the actual values on failure,
        # whereas assertTrue(a == b) only reports "False is not true".
        self.assertEqual(target_framework, 'sklearn')
        self.assertIsInstance(model, LogisticRegression)
        # Exact equality is intentional: the values should be copied, not recomputed.
        self.assertEqual(model.intercept_[0], self.model_param.intercept)
        # One row of coefficients, one column per header (feature) name.
        self.assertEqual(model.coef_.shape, (1, len(self.model_param.header)))
        self.assertEqual(model.tol, self.model_meta.tol)

        with tempfile.TemporaryDirectory() as d:
            dest = save_converted_model(model, target_framework, d)
            self.assertTrue(os.path.isfile(dest))
            self.assertTrue(dest.endswith(".joblib"))