def test_switch_as_a_column_cases_as_kwargs(self):
    """switch_case accepts a Column operand with cases supplied as kwargs.

    The row 'hi' matches no case and must fall through to ``default``.
    """
    frame = self.spark.createDataFrame(
        data=[('one',), ('two',), ('three',), ('hi',)],
        schema=T.StructType([T.StructField('name', T.StringType())]),
    )
    frame = frame.withColumn(
        'value',
        SF.switch_case(F.col('name'), one=1, two=2, three=3, default=0),
    )
    # Build the expected rows from (name, value) pairs to keep them compact.
    expected = [
        {'name': name, 'value': value}
        for name, value in (('one', 1), ('two', 2), ('three', 3), ('hi', 0))
    ]
    self.assertDataFrameEqual(frame, expected)
def test_default_as_a_column(self):
    """With no cases given, every row takes the ``default`` value.

    Here the default is a Column, so 'value' mirrors 'name' for all rows.
    """
    frame = self.spark.createDataFrame(
        data=[('one',), ('two',), ('three',), ('hi',)],
        schema=T.StructType([T.StructField('name', T.StringType())]),
    )
    frame = frame.withColumn(
        'value',
        SF.switch_case('name', default=F.col('name')),
    )
    # Each row's value equals its name, since the default echoes the column.
    expected = [
        {'name': name, 'value': name}
        for name in ('one', 'two', 'three', 'hi')
    ]
    self.assertDataFrameEqual(frame, expected)
def test_switch_case_with_custom_operand_lt(self):
    """A custom ``operand`` (strict less-than) picks the first matching case.

    An OrderedDict keeps the cases in a guaranteed order, so first-match
    semantics are deterministic: e.g. value 0 satisfies ``0 < 1`` first
    and maps to 'worst'.
    """
    frame = self.spark.createDataFrame(
        data=[(1,), (2,), (3,), (0,)],
        schema=T.StructType([T.StructField('value', T.IntegerType())]),
    )
    cases = OrderedDict([
        (1, 'worst'),
        (2, 'bad'),
        (3, 'good'),
        (4, 'best'),
    ])
    frame = frame.withColumn(
        'value_2',
        SF.switch_case('value', cases, operand=operator.lt),
    )
    expected = [
        {'value': value, 'value_2': label}
        for value, label in ((1, 'bad'), (2, 'good'), (3, 'best'), (0, 'worst'))
    ]
    self.assertDataFrameEqual(frame, expected)
def test_switch_case_with_custom_operand_between(self):
    """A custom ``operand`` can map range-tuples to values via ``between``.

    Case keys are inclusive (low, high) bounds; the row with value 0 falls
    in no range and therefore gets NULL (no default supplied).
    """
    frame = self.spark.createDataFrame(
        data=[(1,), (2,), (3,), (0,)],
        schema=T.StructType([T.StructField('value', T.IntegerType())]),
    )

    # Unpack the (low, high) tuple key into Column.between.
    def in_range(column, bounds):
        return column.between(*bounds)

    frame = frame.withColumn(
        'value_2',
        SF.switch_case(
            'value',
            {
                (1, 1): 'aloha',
                (2, 3): 'hi',
            },
            operand=in_range,
        ),
    )
    expected = [
        {'value': value, 'value_2': label}
        for value, label in ((1, 'aloha'), (2, 'hi'), (3, 'hi'), (0, None))
    ]
    self.assertDataFrameEqual(frame, expected)
def test_cases_values_as_a_column(self):
    """Case *values* may themselves be Column expressions.

    Value 1 maps to ``11 * value`` (11), value 2 to ``value * value`` (4);
    the 'hi' case never matches an integer column, and the remaining rows
    fall through to the default, which echoes the original column.
    """
    frame = self.spark.createDataFrame(
        data=[(1,), (2,), (3,), (0,)],
        schema=T.StructType([T.StructField('value', T.IntegerType())]),
    )
    frame = frame.withColumn(
        'value_2',
        SF.switch_case(
            'value',
            {
                1: 11 * F.col('value'),
                2: F.col('value') * F.col('value'),
                'hi': 5,
            },
            default=F.col('value'),
        ),
    )
    expected = [
        {'value': value, 'value_2': result}
        for value, result in ((1, 11), (2, 4), (3, 3), (0, 0))
    ]
    self.assertDataFrameEqual(frame, expected)
def test_cases_condition_constant_as_an_arbitrary_value(self):
    """Plain constants work as both case keys and case/default values.

    Value 0 matches no case and falls back to the 'hi' default.
    """
    frame = self.spark.createDataFrame(
        data=[(1,), (2,), (3,), (0,)],
        schema=T.StructType([T.StructField('value', T.IntegerType())]),
    )
    frame = frame.withColumn(
        'name',
        SF.switch_case('value', {1: 'one', 2: 'two', 3: 'three'}, default='hi'),
    )
    expected = [
        {'name': name, 'value': value}
        for name, value in (('one', 1), ('two', 2), ('three', 3), ('hi', 0))
    ]
    self.assertDataFrameEqual(frame, expected)