def test_load_file_h5(self):
    """Save ``self.model`` to an ``.h5`` temp file and reload it via ``loading.load_file``.

    The reloaded model's predictions on ``self.X`` must be numerically
    close to ``self.predictions``.
    """
    with tempfile.NamedTemporaryFile(suffix='.h5') as h5_file:
        h5_path = h5_file.name
        # The file is written by path while the temp-file handle is still open.
        tf.keras.models.save_model(self.model, h5_path)
        restored = loading.load_file(h5_path)
        observed = restored.predict(self.X)
        predictions_match = np.allclose(observed, self.predictions)
        self.assertTrue(predictions_match)
def test_load_file_pkl(self):
    """Dump ``self.model`` to a ``.pkl`` temp file with joblib and reload it.

    The model returned by ``loading.load_file`` must predict values on
    ``self.X`` that are numerically close to ``self.predictions``.
    """
    with tempfile.NamedTemporaryFile(suffix='.pkl') as pkl_file:
        pkl_path = pkl_file.name
        # The file is written by path while the temp-file handle is still open.
        joblib.dump(self.model, pkl_path)
        restored = loading.load_file(pkl_path)
        observed = restored.predict(self.X)
        predictions_match = np.allclose(observed, self.predictions)
        self.assertTrue(predictions_match)
def from_file(cls, path, *args, s3_access_key_id=None, s3_secret_access_key=None, **kwargs):
    """Alternate constructor: load an object from ``path`` and wrap it in ``cls``.

    Args:
        path: Location passed to ``load_file``.
        *args: Extra positional arguments forwarded to ``cls`` after the
            loaded object.
        s3_access_key_id: Forwarded to ``load_file`` as its second
            positional argument.
        s3_secret_access_key: Forwarded to ``load_file`` as its third
            positional argument.
        **kwargs: Extra keyword arguments forwarded to ``cls``.

    Returns:
        The result of ``cls(loaded_object, *args, **kwargs)``.
    """
    loaded = load_file(path, s3_access_key_id, s3_secret_access_key)
    instance = cls(loaded, *args, **kwargs)
    return instance
def from_file(cls, path, *args, s3_access_key_id=None, s3_secret_access_key=None, **kwargs):
    """Alternate constructor: build ``cls`` around the object loaded from ``path``.

    Args:
        path: Location passed to ``load_file``.
        *args: Extra positional arguments forwarded to ``cls`` after the
            loaded object.
        s3_access_key_id: Forwarded positionally to ``load_file``.
        s3_secret_access_key: Forwarded positionally to ``load_file``.
        **kwargs: Extra keyword arguments forwarded to ``cls``.

    Returns:
        The result of ``cls(loaded_object, *args, **kwargs)``.
    """
    return cls(
        load_file(path, s3_access_key_id, s3_secret_access_key),
        *args,
        **kwargs
    )
def test_load_file_s3(self):
    """Upload a joblib-dumped model to S3 and reload it via ``loading.load_file``.

    The model is dumped to a ``.pkl`` temp file, pushed to ``self.bucket``
    with ``self.write_to_s3``, and loaded back from its ``s3://`` path.
    The reloaded model's predictions on ``self.X`` must be numerically
    close to those of ``self.model``.
    """
    key = 'data-science/porter/tests/sklearn_model.pkl'
    s3_path = 's3://%s/%s' % (self.bucket, key)
    with tempfile.NamedTemporaryFile(suffix='.pkl') as pkl_file:
        local_path = pkl_file.name
        joblib.dump(self.model, local_path)
        self.write_to_s3(local_path, self.bucket, key)
        credentials = {
            's3_access_key_id': self.s3_access_key_id,
            's3_secret_access_key': self.s3_secret_access_key,
        }
        restored = loading.load_file(s3_path, **credentials)
        observed = restored.predict(self.X)
        reference = self.model.predict(self.X)
        predictions_match = np.allclose(observed, reference)
        self.assertTrue(predictions_match)
def test_load_file_s3_fail_missing_bucket(self):
    """``loading.load_file`` must raise when pointed at ``s3://invalid-bucket/...``.

    The exception message has to match an "An error occurred" message
    carrying either a 403 or a 404 code.
    """
    expected_message = r'An error occurred \(40[34]\)'
    with self.assertRaisesRegex(Exception, expected_message):
        loading.load_file(
            's3://invalid-bucket/this/does/not/exist',
            s3_access_key_id=self.s3_access_key_id,
            s3_secret_access_key=self.s3_secret_access_key,
        )
def test_load_file_s3_fail_missing_key(self):
    """``loading.load_file`` must raise a 404 error for a missing key.

    The bucket name is read from the ``PORTER_S3_BUCKET_TEST`` environment
    variable and stored on ``self.bucket`` before the load is attempted.
    """
    self.bucket = os.environ['PORTER_S3_BUCKET_TEST']
    expected_message = r'An error occurred \(404\)'
    with self.assertRaisesRegex(Exception, expected_message):
        loading.load_file(
            's3://%s/this/does/not/exist' % self.bucket,
            s3_access_key_id=self.s3_access_key_id,
            s3_secret_access_key=self.s3_secret_access_key,
        )