def test_train_val_split_ratio(self):
    """Splitting with a float `validation` should surface that float as the ratio.

    Only the reported validation ratio is asserted: random splitting cannot be
    relied on to produce exactly 4 training and 1 validation rows.
    """
    with spark_session('test_train_val_split_ratio') as spark:
        rows = [[1.0] for _ in range(5)]
        row_schema = StructType([StructField('data', FloatType())])
        frame = create_test_data_from_schema(spark, rows, row_schema)
        val_fraction = 0.2
        _train, _val, reported_ratio = util._train_val_split(frame, val_fraction)
        assert reported_ratio == val_fraction
def test_train_val_split_col_boolean(self):
    """Splitting on a boolean column should yield 4 train rows and 1 validation row.

    Only the row counts are asserted, since the returned validation ratio comes
    from an approximate calculation and cannot be guaranteed exactly.
    """
    with spark_session('test_train_val_split_col_boolean') as spark:
        rows = [[1.0, flag] for flag in (False, False, False, False, True)]
        row_schema = StructType([
            StructField('data', FloatType()),
            StructField('val', BooleanType()),
        ])
        frame = create_test_data_from_schema(spark, rows, row_schema)
        split_col = 'val'
        train_part, val_part, _ratio = util._train_val_split(frame, split_col)
        assert train_part.count() == 4
        assert val_part.count() == 1