Example #1
0
 def test_map(self):
     """Map the single key 'a' to its configured label 1.0 and compare the resulting frame."""
     mapper = StringMap(labels={'a': 1.0}, inputCol='key_col', outputCol='value_col')
     actual = mapper.transform(self.input)
     # Expected row: presumably the input row ('a', 'b') followed by the mapped value.
     expected = StringMapTest.spark.createDataFrame([['a', 'b', 1.0]], OUTPUT_SCHEMA)
     assert_df(expected, actual)
Example #2
0
 def test_map_missing_value_error(self):
     """A key absent from ``labels`` surfaces as a JVM-side error when results are collected.

     No ``handleInvalid`` is passed, so this exercises the transformer's default handling.
     """
     with self.assertRaises(Py4JJavaError) as raised:
         mapper = StringMap(labels={'z': 1.0}, inputCol='key_col', outputCol='value_col')
         mapped = mapper.transform(self.input)
         # collect() forces the Spark job to run; the missing-label error is expected here.
         mapped.collect()
     self.assertIn('java.util.NoSuchElementException: Missing label: a', str(raised.exception))
Example #3
0
 def test_map_from_dataframe(self):
     """Build the label mapping from a (key, value) DataFrame rather than a dict."""
     labels_schema = 'key_col: string, value_col: double'
     labels_df = StringMapTest.spark.createDataFrame([['a', 1.0]], labels_schema)
     mapper = StringMap.from_dataframe(labels_df=labels_df, inputCol='key_col', outputCol='value_col')
     actual = mapper.transform(self.input)
     expected = StringMapTest.spark.createDataFrame([['a', 'b', 1.0]], OUTPUT_SCHEMA)
     assert_df(expected, actual)
Example #4
0
 def test_serialize_to_bundle(self):
     """Round-trip a fitted pipeline through a bundle zip file.

     The reloaded pipeline must still map 'a' to 1.0.
     """
     fitted = Pipeline(stages=[StringMap({'a': 1.0}, 'key_col', 'value_col')]).fit(self.input)
     # The bundle is written two directories above this file, under 'target'.
     target_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'target')
     bundle_path = os.path.join(target_dir, 'test_serialize_to_bundle-pipeline.zip')
     _serialize_to_file(bundle_path, self.input, fitted)
     reloaded = _deserialize_from_file(bundle_path)
     actual = reloaded.transform(self.input)
     expected = self.spark.createDataFrame([['a', 'b', 1.0]], OUTPUT_SCHEMA)
     assert_df(expected, actual)
Example #5
0
 def test_map_custom_default_value(self):
     """With handleInvalid='keep', an unmapped key is replaced by the supplied defaultValue."""
     mapper = StringMap(
         labels={'z': 1.0},
         inputCol='key_col',
         outputCol='value_col',
         handleInvalid='keep',
         defaultValue=-1.0,
     )
     actual = mapper.transform(self.input)
     expected = StringMapTest.spark.createDataFrame([['a', 'b', -1.0]], OUTPUT_SCHEMA)
     assert_df(expected, actual)
Example #6
0
    def test_serialize_to_bundle(self):
        """Serialize a fitted pipeline, load it back, and check the copy still maps 'a' to 1.0."""
        fitted = Pipeline(stages=[
            StringMap(labels={'a': 1.0}, inputCol='key_col', outputCol='value_col'),
        ]).fit(self.input)
        # A transformed sample is handed to the serializer alongside the pipeline
        # (presumably so the bundle can record the output schema).
        sample = fitted.transform(self.input)
        bundle_path = _serialize_to_file(fitted, sample)
        reloaded = _deserialize_from_file(bundle_path)
        actual = reloaded.transform(self.input)
        expected = StringMapTest.spark.createDataFrame([['a', 'b', 1.0]], OUTPUT_SCHEMA)
        assert_df(expected, actual)
Example #7
0
 def test_validate_labels_type_fails(self):
     """A non-dict ``labels`` value (here a set) is rejected with AssertionError."""
     # Bind the bad value to ``labels`` by keyword. A positional call such as
     # StringMap(None, set()) leaves labels=None and puts the set into inputCol,
     # so the set itself would never be validated.
     with self.assertRaises(AssertionError):
         StringMap(labels=set(), inputCol='key_col', outputCol='value_col')
Example #8
0
 def test_validate_handleInvalid_bad(self):
     """An unsupported ``handleInvalid`` value is rejected with AssertionError."""
     # Keep every other argument valid (an empty labels dict is accepted), so
     # only handleInvalid can trigger the assertion. With labels=None, the test
     # would pass even without a handleInvalid check.
     with self.assertRaises(AssertionError):
         StringMap(labels={}, inputCol='key_col', outputCol='value_col', handleInvalid='invalid')
Example #9
0
 def test_validate_handleInvalid_ok(self):
     """'error' is an accepted ``handleInvalid`` value: construction must not raise."""
     # ``self`` is required because unittest invokes test methods on a TestCase
     # instance; a zero-argument method fails with TypeError before its body runs.
     StringMap({}, handleInvalid='error')
Example #10
0
 def test_map_default_value(self):
     """With handleInvalid='keep' and no defaultValue, an unmapped key is expected to map to 0.0."""
     mapper = StringMap({'z': 1.0}, 'key_col', 'value_col', handleInvalid='keep')
     actual = mapper.transform(self.input)
     expected = self.spark.createDataFrame([['a', 'b', 0.0]], OUTPUT_SCHEMA)
     assert_df(expected, actual)
Example #11
0
 def test_map(self):
     """Map key 'a' to 1.0 using positional constructor arguments."""
     mapper = StringMap({'a': 1.0}, 'key_col', 'value_col')
     actual = mapper.transform(self.input)
     expected = self.spark.createDataFrame([['a', 'b', 1.0]], OUTPUT_SCHEMA)
     assert_df(expected, actual)
Example #12
0
 def test_validate_labels_type_fails(self):
     """A non-dict ``labels`` value (here a set) is rejected with AssertionError."""
     # The invalid value must go to ``labels``; passing it as inputCol while
     # labels=None lets the None trip the assertion instead of the set.
     with self.assertRaises(AssertionError):
         StringMap(labels=set(), inputCol='key_col', outputCol='value_col')
Example #13
0
 def test_validate_labels_value_fails(self):
     """A ``labels`` dict whose value is not a float is rejected with AssertionError."""
     # The bad dict must be bound to ``labels``: with labels=None the assertion
     # would fire before any per-value check could run.
     with self.assertRaises(AssertionError):
         StringMap(labels={'valid_key_type': 'invalid_value_type'}, inputCol='key_col', outputCol='value_col')
Example #14
0
 def test_validate_labels_key_fails(self):
     """A ``labels`` dict whose key is not a string (here a bool) is rejected with AssertionError."""
     # The bad dict must be bound to ``labels``: with labels=None the assertion
     # would fire before any per-key check could run.
     with self.assertRaises(AssertionError):
         StringMap(labels={False: 0.0}, inputCol='key_col', outputCol='value_col')
Example #15
0
 def test_validate_labels_key_fails(self):
     """A ``labels`` dict whose key is not a string (here a bool) is rejected with AssertionError."""
     # Pass the bad dict as ``labels`` by keyword. A positional call like
     # StringMap(None, {...}) leaves labels=None and puts the dict into inputCol,
     # so the key check would never be exercised.
     with self.assertRaises(AssertionError):
         StringMap(labels={False: 0.0}, inputCol='key_col', outputCol='value_col')
Example #16
0
 def test_validate_labels_value_fails(self):
     """A ``labels`` dict whose value is not a float is rejected with AssertionError."""
     # Pass the bad dict as ``labels`` by keyword. A positional call like
     # StringMap(None, {...}) leaves labels=None and puts the dict into inputCol,
     # so the value check would never be exercised.
     with self.assertRaises(AssertionError):
         StringMap(labels={'valid_key_type': 'invalid_value_type'}, inputCol='key_col', outputCol='value_col')
Example #17
0
 def test_validate_handleInvalid_bad(self):
     """An unsupported ``handleInvalid`` value is rejected with AssertionError."""
     # Keep every other argument valid (an empty labels dict is accepted), so
     # only handleInvalid can trigger the assertion. With labels=None, the test
     # would pass even without a handleInvalid check.
     with self.assertRaises(AssertionError):
         StringMap(labels={}, inputCol='key_col', outputCol='value_col', handleInvalid='invalid')