예제 #1
0
파일: test_base.py 프로젝트: Brett-A/spark
    def test_unary_transformer_validate_input_type(self):
        """Check which input types ``validateInputType`` lets through.

        ``DoubleType`` must be accepted without an exception, while
        ``IntegerType`` must be rejected with a ``TypeError``.
        """
        shift = 3
        mock = MockUnaryTransformer(shiftVal=shift)
        mock = mock.setInputCol("input")
        mock = mock.setOutputCol("output")

        # Accepted type: the call completing without raising is the check.
        mock.validateInputType(DoubleType())

        # Rejected type: the call has to raise TypeError.
        with self.assertRaises(TypeError):
            mock.validateInputType(IntegerType())
예제 #2
0
    def test_unary_transformer_transform(self):
        """Check that ``transform`` yields "output" == "input" + shiftVal.

        The frame comes from ``spark.range(0, 10)`` with its "input" column
        cast to double; every collected row is compared individually.
        """
        shift = 3
        mock = MockUnaryTransformer(shiftVal=shift)
        mock = mock.setInputCol("input").setOutputCol("output")

        # The range column is cast to double before transforming.
        source = self.spark.range(0, 10).toDF("input")
        as_double = source.input.cast(dataType="double")
        source = source.withColumn("input", as_double)

        rows = mock.transform(source).select("input", "output").collect()
        for row in rows:
            expected = row.input + shift
            self.assertEqual(expected, row.output)
예제 #3
0
파일: test_base.py 프로젝트: Brett-A/spark
    def test_unary_transformer_transform(self):
        """Verify that the mock transformer shifts each input by ``shiftVal``.

        Uses ``spark.range(0, 10)`` as the data, casts the "input" column to
        double, transforms it, and asserts on every resulting row that
        "output" equals "input" plus the shift.
        """
        offset = 3
        unary = (
            MockUnaryTransformer(shiftVal=offset)
            .setInputCol("input")
            .setOutputCol("output")
        )

        frame = self.spark.range(0, 10).toDF("input")
        double_input = frame.input.cast(dataType="double")
        frame = frame.withColumn("input", double_input)

        shifted = unary.transform(frame)
        collected = shifted.select("input", "output").collect()

        for record in collected:
            self.assertEqual(record.input + offset, record.output)
예제 #4
0
    def test_python_transformer_pipeline_persistence(self):
        """Save/load round trip for Pipeline[MockUnaryTransformer, Binarizer].

        Both the unfitted pipeline and the model returned by ``fit`` are
        saved under a temporary directory, loaded back, and compared with
        the originals through ``_compare_pipelines``.
        """
        scratch_dir = tempfile.mkdtemp()
        pipeline_dir = "/".join((scratch_dir, "pipeline"))
        model_dir = "/".join((scratch_dir, "pipeline-model"))

        try:
            frame = self.spark.range(0, 10).toDF("input")
            shifter = MockUnaryTransformer(shiftVal=2)
            shifter = shifter.setInputCol("input").setOutputCol("shiftedInput")
            binarizer = Binarizer(
                threshold=6, inputCol="shiftedInput", outputCol="binarized")
            pipeline = Pipeline(stages=[shifter, binarizer])
            fitted = pipeline.fit(frame)

            # Round trip of the unfitted pipeline.
            pipeline.save(pipeline_dir)
            self._compare_pipelines(pipeline, Pipeline.load(pipeline_dir))

            # Round trip of the fitted model.
            fitted.save(model_dir)
            self._compare_pipelines(fitted, PipelineModel.load(model_dir))
        finally:
            # Best-effort cleanup: a failed delete must not fail the test.
            try:
                rmtree(scratch_dir)
            except OSError:
                pass
예제 #5
0
파일: test_base.py 프로젝트: zoelin7/spark
    def test_unary_transformer_validate_input_type(self):
        """Exercise ``validateInputType`` with one accepted and one rejected type.

        The test passes when ``DoubleType`` raises nothing and
        ``IntegerType`` raises ``TypeError``.
        """
        offset = 3
        unary = MockUnaryTransformer(shiftVal=offset)
        unary = unary.setInputCol("input").setOutputCol("output")

        # No assertion needed here: an exception would fail the test.
        unary.validateInputType(DoubleType())

        # The wrong input type must be reported as a TypeError.
        with self.assertRaises(TypeError):
            unary.validateInputType(IntegerType())