def test_bad_input_type_raises(self):
    """A loader configured with ``model_file_type=None`` must be rejected.

    The expected AssertionError (matching 'Unsupported serialization type')
    may come from either the SklearnModelLoader constructor or from
    ``load_model()``; both calls sit inside the assertion context.
    """
    expected_message = 'Unsupported serialization type'
    # A real (empty) temporary file is used as the model URI -- presumably so
    # the failure is attributable to the unsupported type, not a missing file.
    with tempfile.NamedTemporaryFile() as temp_model:
        with self.assertRaisesRegex(AssertionError, expected_message):
            loader = SklearnModelLoader(
                model_uri=temp_model.name, model_file_type=None)
            loader.load_model()
def test_bad_file_raises(self):
    """RunInference pointed at an unreadable model path fails the pipeline.

    The loader's model URI is '/var/bad_file_name', a path expected not to
    exist; running the pipeline must surface a RuntimeError.
    """
    with self.assertRaises(RuntimeError):
        with TestPipeline() as pipeline:
            inputs = pipeline | 'start' >> beam.Create([numpy.array([0, 0])])
            loader = SklearnModelLoader(model_uri='/var/bad_file_name')
            # TODO(BEAM-14305) Test against the public API.
            _ = inputs | base.RunInference(loader)
            # NOTE(review): Beam's Pipeline context manager typically also
            # runs the pipeline on exit; confirm whether this explicit run()
            # is still needed.
            pipeline.run()
def test_pipeline_pickled(self):
    """Round-trip a pickled sklearn model through RunInference.

    Pickles the model returned by ``build_model()`` into ``self.tmpdir``,
    loads it via SklearnModelLoader inside a TestPipeline, and checks that the
    inputs ``[0, 0]`` and ``[1, 1]`` are predicted as 0 and 1 respectively.
    """
    # os.path.join instead of manual `+ os.sep +` concatenation.
    temp_file_name = os.path.join(self.tmpdir, 'pickled_file')
    with open(temp_file_name, 'wb') as model_file:
        pickle.dump(build_model(), model_file)

    with TestPipeline() as pipeline:
        examples = [numpy.array([0, 0]), numpy.array([1, 1])]
        pcoll = pipeline | 'start' >> beam.Create(examples)
        # TODO(BEAM-14305) Test against the public API.
        actual = pcoll | base.RunInference(
            SklearnModelLoader(model_uri=temp_file_name))
        expected = [
            api.PredictionResult(numpy.array([0, 0]), 0),
            api.PredictionResult(numpy.array([1, 1]), 1),
        ]
        # PredictionResult holds numpy arrays, so a custom comparator is used
        # rather than plain equality.
        assert_that(
            actual, equal_to(expected, equals_fn=_compare_prediction_result))