def test_block_variants_and_samples(spark):
    """The Python helper must match the raw transformer's output and report
    which sample IDs fall into each sample block."""
    vcf_path = "test-data/combined.chr20_18210071_18210093.g.vcf"
    genotype_values = expr("genotype_states(genotypes)")
    variant_df = (spark.read.format("vcf")
                  .load(vcf_path)
                  .withColumn("values", genotype_values))
    sample_ids = ["HG00096", "HG00268", "NA19625"]
    blocking = {"variants_per_block": 10, "sample_block_count": 2}

    block_gt, index_map = functions.block_variants_and_samples(
        variant_df, sample_ids, **blocking)
    expected_block_gt = glow.transform(
        "block_variants_and_samples", variant_df, **blocking)

    assert block_gt.collect() == expected_block_gt.collect()
    assert index_map == {"1": ["HG00096", "HG00268"], "2": ["NA19625"]}
def block_variants_and_samples(variant_df: DataFrame, sample_ids: List[str],
                               variants_per_block: int,
                               sample_block_count: int) -> (DataFrame, Dict[str, List[str]]):
    """
    Creates a blocked GT matrix and index mapping from sample blocks to a list of corresponding sample IDs. Uses the
    same sample-blocking logic as the blocked GT matrix transformer.

    Requires that:
        - Each variant row has the same number of values
        - The number of values per row matches the number of sample IDs

    Args:
        variant_df : The variant DataFrame (must have a ``values`` array column)
        sample_ids : The list of sample ID strings
        variants_per_block : The number of variants per block
        sample_block_count : The number of sample blocks

    Returns:
        tuple of (blocked GT matrix, index mapping)

    Raises:
        Exception: if ``variant_df`` has no rows, if rows disagree on the number of values, or if the number
            of values per row differs from ``len(sample_ids)``. ``__validate_sample_ids`` (defined elsewhere
            in this module) may raise as well.
    """
    assert check_argument_types()

    # Two distinct sizes are enough to tell "no rows" / "one consistent size" / "inconsistent sizes" apart,
    # so collect at most two in a single Spark job rather than running count() and head() separately
    # (each of which would re-evaluate the distinct aggregation over the whole DataFrame).
    distinct_num_values = variant_df.selectExpr(
        "size(values) as numValues").distinct().limit(2).collect()
    if not distinct_num_values:
        raise Exception("DataFrame has no values.")
    if len(distinct_num_values) > 1:
        raise Exception("Each row must have the same number of values.")

    num_values = distinct_num_values[0].numValues
    if num_values != len(sample_ids):
        raise Exception(
            "Number of values does not match between DataFrame and sample ID list."
        )
    __validate_sample_ids(sample_ids)

    blocked_gt = glow.transform("block_variants_and_samples",
                                variant_df,
                                variants_per_block=variants_per_block,
                                sample_block_count=sample_block_count)
    index_map = __get_index_map(sample_ids, sample_block_count, variant_df.sql_ctx)

    output = blocked_gt, index_map
    assert check_return_type(output)
    return output