def test_exploded_array_is_added(self, input_df, default_params):
    """Transforming adds exactly one new column: the exploded element."""
    transformer = Exploder(**default_params)
    wanted_columns = set(input_df.columns)
    wanted_columns.add(default_params["exploded_elem_name"])
    result_columns = set(transformer.transform(input_df).columns)
    assert wanted_columns == result_columns
def test_records_with_empty_arrays_are_dropped_by_default(
        self, spark_session):
    """Rows with empty arrays vanish; only id=2's three exploded elements remain."""
    source_rows = [
        Row(id=1, array_to_explode=[]),
        Row(id=2, array_to_explode=[
            Row(elem_id="a"),
            Row(elem_id="b"),
            Row(elem_id="c"),
        ]),
        Row(id=3, array_to_explode=[]),
    ]
    input_df = spark_session.createDataFrame(source_rows)
    exploder = Exploder(
        path_to_array="array_to_explode",
        exploded_elem_name="elem",
    )
    assert exploder.transform(input_df).count() == 3
def test_records_with_empty_arrays_are_kept_via_setting(
        self, spark_session):
    """With drop_rows_with_empty_array=False the empty-array rows survive (3 + 1 + 1)."""
    source_rows = [
        Row(id=1, array_to_explode=[]),
        Row(id=2, array_to_explode=[
            Row(elem_id="a"),
            Row(elem_id="b"),
            Row(elem_id="c"),
        ]),
        Row(id=3, array_to_explode=[]),
    ]
    input_df = spark_session.createDataFrame(source_rows)
    exploder = Exploder(
        path_to_array="array_to_explode",
        exploded_elem_name="elem",
        drop_rows_with_empty_array=False,
    )
    assert exploder.transform(input_df).count() == 5
def _get_preliminary_mapping(self, input_df, json_schema, mapping, current_path, exploded_arrays):
    """Recursively walk ``json_schema`` and build a preliminary column mapping.

    Atomic fields are appended to ``mapping``; structs are descended into;
    the first not-yet-exploded array triggers an ``Exploder`` transformation
    and a full restart of the walk on the exploded DataFrame's schema.

    Args:
        input_df: DataFrame whose schema is being walked (may be replaced
            by an exploded version during the walk).
        json_schema: ``schema.jsonValue()``-style dict with a ``fields`` list.
        mapping: list of mapping entries accumulated so far.
        current_path: list of field names from the schema root to here.
        exploded_arrays: list of dotted array paths already exploded
            (mutated in place when a new array is exploded).

    Returns:
        Tuple of ``(input_df, mapping)`` — the (possibly exploded)
        DataFrame and the completed preliminary mapping.
    """
    for field in json_schema["fields"]:
        self.logger.debug(json.dumps(field, indent=2))
        if self._field_is_original_columns_struct(field) and self.keep_original_columns:
            # The struct that preserves the untouched original columns is
            # intentionally excluded from the mapping.
            continue
        elif self._field_is_atomic(field):
            self.logger.debug(f"Atomic Field found: {field['name']}")
            mapping = self._add_field_to_mapping(mapping, current_path, field)
        elif self._field_is_struct(field):
            self.logger.debug(f"Struct Field found: {field['name']}")
            struct_name = field["name"]
            new_path = current_path + [struct_name]
            input_df, mapping = self._get_preliminary_mapping(input_df=input_df,
                                                              json_schema=field["type"],
                                                              mapping=mapping,
                                                              current_path=new_path,
                                                              exploded_arrays=exploded_arrays)
        elif self._field_is_array(field):
            self.logger.debug(f"Array Field found: {field['name']}")
            pretty_field_name = field["name"]
            field_name = "_".join(current_path + [pretty_field_name])
            array_path = ".".join(current_path + [pretty_field_name])
            if array_path in exploded_arrays:
                self.logger.debug(f"Skipping explosion of {field_name}, as it was already exploded")
                continue
            else:
                if self.pretty_names:
                    # EAFP probe: accessing the column raises AnalysisException
                    # when it does NOT exist yet, i.e. the pretty name is free.
                    try:
                        input_df[f"{pretty_field_name}_exploded"]
                        # No exception -> the name is already taken; fall back
                        # to the full-path name.
                        exploded_elem_name = f"{field_name}_exploded"
                    except AnalysisException:
                        exploded_elem_name = f"{pretty_field_name}_exploded"
                else:
                    exploded_elem_name = f"{field_name}_exploded"
                self.logger.debug(f"Exploding {array_path} into {exploded_elem_name}")
                self._script_add_explode_transformation(path_to_array=array_path,
                                                        exploded_elem_name=exploded_elem_name)
                # Fix: the original code built a first Exploder result (with
                # the default drop_rows_with_empty_array=True) and immediately
                # discarded it by reassigning exploded_df — only this call is
                # actually used.
                exploded_df = Exploder(path_to_array=array_path,
                                       exploded_elem_name=exploded_elem_name,
                                       drop_rows_with_empty_array=False).transform(input_df)
                exploded_arrays.append(array_path)
                # Restart the walk from the root so the newly exploded
                # column's schema is visited as well.
                return self._get_preliminary_mapping(input_df=exploded_df,
                                                     json_schema=exploded_df.schema.jsonValue(),
                                                     mapping=[],
                                                     current_path=[],
                                                     exploded_arrays=exploded_arrays)
    return (input_df, mapping)
def test_array_is_converted_to_struct(self, input_df, default_params):
    """Exploding converts a list-of-structs column into single struct (dict) values.

    Fix: the inner helper previously declared a mutable default argument
    (``path=["attributes"]``) that was never used — both call sites pass
    ``path`` explicitly — so the risky default is removed.
    """

    def get_data_type_of_column(df, path):
        # Drill into the first row (as nested dicts) along `path`.
        record = df.first().asDict(recursive=True)
        for p in path:
            record = record[p]
        return type(record)

    current_data_type_friend = get_data_type_of_column(
        input_df, path=["attributes", "friends"])
    assert issubclass(current_data_type_friend, list)

    transformed_df = Exploder(**default_params).transform(input_df)
    transformed_data_type = get_data_type_of_column(transformed_df, path=["friend"])
    assert issubclass(transformed_data_type, dict)
def test_count(self, input_df, default_params):
    """The transformer yields one row per array element, like a plain explode()."""
    array_column = input_df[default_params["path_to_array"]]
    expected_count = input_df.select(sql_funcs.explode(array_column)).count()
    actual_count = Exploder(**default_params).transform(input_df).count()
    assert actual_count == expected_count
def test_str_representation_is_correct(self):
    """str() renders the standard transformer description."""
    expected = "Transformer Object of Class Exploder"
    assert str(Exploder()) == expected
def test_name_is_set(self):
    """The transformer exposes its class name via the ``name`` attribute."""
    transformer = Exploder()
    assert transformer.name == "Exploder"
def test_logger_should_be_accessible(self):
    """Every transformer instance carries a ``logger`` attribute."""
    transformer = Exploder()
    assert hasattr(transformer, "logger")