def test_bad_input_type_raises(self):
    """An unsupported ``model_file_type`` must be rejected.

    Points a ``SklearnModelLoader`` at a temporary file with
    ``model_file_type=None`` and expects building the loader or calling
    ``load_model()`` to raise an ``AssertionError`` whose message matches
    'Unsupported serialization type'.
    """
    expected_message = 'Unsupported serialization type'
    with self.assertRaisesRegex(AssertionError, expected_message), \
            tempfile.NamedTemporaryFile() as scratch_file:
        loader = SklearnModelLoader(
            model_uri=scratch_file.name, model_file_type=None)
        loader.load_model()
 def test_bad_file_raises(self):
     """Running inference against the model path below must raise RuntimeError.

     Builds a one-element pipeline, applies ``base.RunInference`` with a
     ``SklearnModelLoader`` whose ``model_uri`` is '/var/bad_file_name', and
     expects a ``RuntimeError`` to surface from the pipeline.
     """
     with self.assertRaises(RuntimeError), TestPipeline() as test_pipeline:
         inputs = [numpy.array([0, 0])]
         source = test_pipeline | 'start' >> beam.Create(inputs)
         # TODO(BEAM-14305) Test against the public API.
         missing_model_loader = SklearnModelLoader(
             model_uri='/var/bad_file_name')
         _ = source | base.RunInference(missing_model_loader)
         # NOTE(review): leaving the TestPipeline context presumably runs the
         # pipeline as well, so this explicit run() may be redundant - confirm.
         test_pipeline.run()
    def test_pipeline_pickled(self):
        """RunInference over a pickled model yields the expected predictions.

        Pickles the object returned by ``build_model()`` into a file named
        'pickled_file' under ``self.tmpdir``, feeds two examples through
        ``base.RunInference`` using a ``SklearnModelLoader`` that points at
        that file, and asserts the output equals the expected
        ``api.PredictionResult`` values (compared with
        ``_compare_prediction_result``).
        """
        # os.path.join instead of manual concatenation with os.sep: it does not
        # produce a doubled separator when the directory already ends in one.
        temp_file_name = os.path.join(self.tmpdir, 'pickled_file')
        with open(temp_file_name, 'wb') as file:
            pickle.dump(build_model(), file)
        with TestPipeline() as pipeline:
            examples = [numpy.array([0, 0]), numpy.array([1, 1])]

            pcoll = pipeline | 'start' >> beam.Create(examples)
            # TODO(BEAM-14305) Test against the public API.
            actual = pcoll | base.RunInference(
                SklearnModelLoader(model_uri=temp_file_name))
            expected = [
                api.PredictionResult(numpy.array([0, 0]), 0),
                api.PredictionResult(numpy.array([1, 1]), 1),
            ]
            assert_that(
                actual,
                equal_to(expected, equals_fn=_compare_prediction_result))