def test_get_col_info_error_bad_size(self):
    """_get_col_info must raise ValueError when dense vectors in one column differ in length."""
    with spark_session('test_get_col_info_error_bad_size') as spark:
        # Two rows whose 'data' vectors have mismatched sizes (2 vs 1).
        rows = [[DenseVector([1.0, 1.0])], [DenseVector([1.0])]]
        schema = StructType([StructField('data', VectorUDT())])
        df = create_test_data_from_schema(spark, rows, schema)
        with pytest.raises(ValueError):
            util._get_col_info(df)
def test_get_col_info_error_bad_shape(self):
    """_get_col_info must raise ValueError when sparse vectors in one column differ in shape."""
    with spark_session('test_get_col_info_error_bad_shape') as spark:
        # Two rows whose 'data' sparse vectors have mismatched dimensions (2 vs 1).
        rows = [[SparseVector(2, {0: 1.0})], [SparseVector(1, {0: 1.0})]]
        schema = StructType([StructField('data', VectorUDT())])
        df = create_test_data_from_schema(spark, rows, schema)
        with pytest.raises(ValueError):
            util._get_col_info(df)
def test_get_col_info(self):
    """_get_col_info should report per-column value types, shapes, and max sizes."""
    with spark_session('test_get_col_info') as spark:
        # Two rows exercising scalar, null, array, dense, sparse, and mixed-vector columns.
        rows = [
            [0, 0.0, None, [1, 1], DenseVector([1.0, 1.0]),
             SparseVector(2, {1: 1.0}), DenseVector([1.0, 1.0])],
            [1, None, None, [1, 1], DenseVector([1.0, 1.0]),
             SparseVector(2, {1: 1.0}), SparseVector(2, {1: 1.0})],
        ]
        schema = StructType([
            StructField('int', IntegerType()),
            StructField('float', FloatType()),
            StructField('null', NullType()),
            StructField('array', ArrayType(IntegerType())),
            StructField('dense', VectorUDT()),
            StructField('sparse', VectorUDT()),
            StructField('mixed', VectorUDT()),
        ])
        df = create_test_data_from_schema(spark, rows, schema)

        types_by_col, shapes_by_col, max_sizes_by_col = util._get_col_info(df)

        # (column name, expected value types, expected shape, expected max size)
        expected = [
            ('int', {int}, 1, 1),
            ('float', {float, NullType}, 1, 1),
            ('null', {NullType}, 1, 1),
            ('array', {list}, 2, 2),
            ('dense', {DenseVector}, 2, 2),
            ('sparse', {SparseVector}, 2, 1),
            ('mixed', {DenseVector, SparseVector}, 2, 2),
        ]
        for name, want_types, want_shape, want_size in expected:
            assert types_by_col[name] == want_types, name
            assert shapes_by_col[name] == want_shape, name
            assert max_sizes_by_col[name] == want_size, name