def test_hail_matrix_table_subset_wrong_sample_id_correct(self):
    # Subsetting with a sample ID that is not in the matrix table must raise
    # MatrixTableSampleSetError, and the error must report the offending ID.
    mt = hl.import_vcf(TEST_DATA_MT_1KG)
    # NOTE(review): the trailing `True` presumably makes the helper add the
    # 'wrong_sample' ID to the subset file - confirm against the helper.
    subset_path = self._create_temp_sample_subset_file(mt, 1, True)

    # Only the call under test lives inside the `with` block, so nothing
    # after the raising statement is silently skipped.
    with self.assertRaises(MatrixTableSampleSetError) as raised:
        HailMatrixTableTask.subset_samples_and_variants(mt, subset_path)

    # The caught exception is exposed on the context manager's `.exception`
    # attribute; the context manager itself has no `missing_samples`.
    self.assertEqual(raised.exception.missing_samples, ['wrong_sample'])
def test_mt_sample_type_stats_threshold(self):
    # Builds a matrix table from samples of the two GRCh37 validation tables
    # and checks which category sample_type_stats reports as matched at the
    # given threshold.
    threshold = 0.5
    self._set_validation_configs()

    # Coding variants are sampled at a fraction below the threshold.
    coding_ht = hl.read_table(GlobalConfig().validation_37_coding_ht)
    coding_sampled = coding_ht.sample(threshold - 0.2, 0)

    # Non-coding variants are sampled at a fraction above the threshold.
    noncoding_ht = hl.read_table(GlobalConfig().validation_37_noncoding_ht)
    noncoding_sampled = noncoding_ht.sample(threshold + 0.2, 0)

    merged_rows = coding_sampled.union(noncoding_sampled).distinct()
    combined_mt = hl.MatrixTable.from_rows_table(merged_rows)

    # Expect a match for noncoding (over threshold) but not for coding
    # (under threshold).
    expected = {
        'noncoding': {
            'matched_count': 1545,
            'total_count': 2243,
            'match': True,
        },
        'coding': {
            'matched_count': 118,
            'total_count': 359,
            'match': False,
        },
    }
    actual = HailMatrixTableTask.sample_type_stats(combined_mt, '37', threshold)
    self.assertEqual(actual, expected)
def read_vcf_write_mt(self, schema_cls=SeqrVariantsAndGenotypesSchema):
    """
    Import the source VCF, annotate it and write the result as a matrix table.

    The steps run in a fixed order: import, annotate_old_and_split_multi_hts,
    optional validation, optional sample-id remapping, optional subsetting,
    add_37_coordinates (only for genome version '38'), VEP, schema
    annotation, global annotations, and finally the write to the task output.

    :param schema_cls: class instantiated with the matrix table and the
        reference/clinvar/hgmd tables; it must provide
        ``annotate_all(overwrite=True)`` followed by ``select_annotated_mt()``.
    """
    logger.info("Args:")
    pprint.pprint(self.__dict__)

    imported = self.import_vcf()
    mt = self.annotate_old_and_split_multi_hts(imported)

    # Validation is on by default; `dont_validate` switches it off.
    if not self.dont_validate:
        self.validate_mt(mt, self.genome_version, self.sample_type)
    if self.remap_path:
        mt = self.remap_sample_ids(mt, self.remap_path)
    if self.subset_path:
        mt = self.subset_samples_and_variants(mt, self.subset_path)
    if self.genome_version == '38':
        mt = self.add_37_coordinates(mt)

    mt = HailMatrixTableTask.run_vep(
        mt,
        self.genome_version,
        self.vep_runner,
        vep_config_json_path=self.vep_config_json_path,
    )

    ref_data = hl.read_table(self.reference_ht_path)
    clinvar = hl.read_table(self.clinvar_ht_path)
    # hgmd is optional: without a configured path the schema receives None.
    if self.hgmd_ht_path:
        hgmd = hl.read_table(self.hgmd_ht_path)
    else:
        hgmd = None

    schema = schema_cls(mt, ref_data=ref_data, clinvar_data=clinvar, hgmd_data=hgmd)
    mt = schema.annotate_all(overwrite=True).select_annotated_mt()

    # Record provenance of the run as globals on the matrix table.
    joined_sources = ','.join(self.source_paths)
    genome_version = self.genome_version
    sample_type = self.sample_type
    hail_version = pkg_resources.get_distribution('hail').version
    mt = mt.annotate_globals(
        sourceFilePath=joined_sources,
        genomeVersion=genome_version,
        sampleType=sample_type,
        hail_version=hail_version,
    )

    mt.describe()
    mt.write(self.output().path, stage_locally=True, overwrite=True)
def validate_mt(mt, genome_version, sample_type):
    """
    Validate the mt by checking against a list of common coding and non-coding variants given its genome
    version. This validates genome_version, variants, and the reported sample type.

    Decision table, based on the 'match' flags from ``HailMatrixTableTask.sample_type_stats``:

    * neither coding nor noncoding matched -> genome version error
    * noncoding only                       -> missing coding variants error
    * coding only                          -> data looks like WES; error unless sample_type is 'WES'
    * coding and noncoding                 -> data looks like WGS; error unless sample_type is 'WGS'

    :param mt: mt to validate
    :param genome_version: reference genome version
    :param sample_type: WGS or WES
    :return: True when validation passes
    :raises SeqrValidationError: when the genome version or sample type does not fit the data
    """
    sample_type_stats = HailMatrixTableTask.sample_type_stats(mt, genome_version)

    for name, stat in sample_type_stats.items():
        # Lazy %-args: the message is only formatted if the record is emitted.
        logger.info('Table contains %i out of %i common %s variants.',
                    stat['matched_count'], stat['total_count'], name)

    has_coding = sample_type_stats['coding']['match']
    has_noncoding = sample_type_stats['noncoding']['match']

    if not has_coding and not has_noncoding:
        # No common variants detected.
        raise SeqrValidationError(
            'Genome version validation error: dataset specified as GRCh{genome_version} but doesn\'t contain '
            'the expected number of common GRCh{genome_version} variants'.format(genome_version=genome_version))
    elif has_noncoding and not has_coding:
        # Non coding only.
        raise SeqrValidationError(
            'Sample type validation error: Dataset contains noncoding variants but is missing common coding '
            'variants for GRCh{}. Please verify that the dataset contains coding variants.'.format(genome_version))
    elif has_coding and not has_noncoding:
        # Only coding should be WES. The message names WES as the detected
        # type (it previously said WGS, contradicting this branch).
        if sample_type != 'WES':
            raise SeqrValidationError(
                'Sample type validation error: dataset sample-type is specified as {} but appears to be '
                'WES because it contains many common coding variants'.format(sample_type))
    elif has_noncoding and has_coding:
        # Both should be WGS. The message names WGS as the detected type
        # (it previously said WES, contradicting this branch).
        if sample_type != 'WGS':
            raise SeqrValidationError(
                'Sample type validation error: dataset sample-type is specified as {} but appears to be '
                'WGS because it contains many common non-coding variants'.format(sample_type))

    return True
def run(self):
    """
    Import the VCF, split multi-allelic sites, optionally validate, run VEP,
    annotate with SeqrVariantSchema and write the result to the task output.
    """
    imported = self.import_vcf()
    mt = hl.split_multi_hts(imported)

    if self.validate:
        self.validate_mt(mt, self.genome_version, self.sample_type)

    mt = HailMatrixTableTask.run_vep(mt, self.genome_version, self.vep_runner)

    # Reference tables that the schema joins onto the variants.
    ref_data = hl.read_table(self.reference_ht_path)
    clinvar = hl.read_table(self.clinvar_ht_path)
    hgmd = hl.read_table(self.hgmd_ht_path)

    schema = SeqrVariantSchema(
        mt,
        ref_data=ref_data,
        clinvar_data=clinvar,
        hgmd_data=hgmd,
    )
    annotated_mt = schema.annotate_all(overwrite=True).select_annotated_mt()

    annotated_mt.write(self.output().path, stage_locally=True)
def test_mt_sample_type_stats_1kg_30(self):
    # sample_type_stats on the 1KG test VCF, called with '37' and 0.3:
    # neither the noncoding nor the coding category is expected to match.
    self._set_validation_configs()
    mt = hl.import_vcf(TEST_DATA_MT_1KG)

    expected_noncoding = {
        'matched_count': 1,
        'total_count': 2243,
        'match': False,
    }
    expected_coding = {
        'matched_count': 4,
        'total_count': 359,
        'match': False,
    }

    actual = HailMatrixTableTask.sample_type_stats(mt, '37', 0.3)
    self.assertEqual(
        actual,
        {'noncoding': expected_noncoding, 'coding': expected_coding},
    )
def _hail_matrix_table_task(self):
    """
    Build a HailMatrixTableTask that reads the 1KG test data, writes to a
    fresh temp destination and uses genome version '37'.
    """
    destination = self._temp_dest_path()
    task_kwargs = {
        'source_paths': [TEST_DATA_MT_1KG],
        'dest_path': destination,
        'genome_version': '37',
    }
    return HailMatrixTableTask(**task_kwargs)
def test_hail_matrix_table_subset_14(self):
    # subset_samples_and_variants with a 14-sample subset file should leave
    # 29 variants, i.e. count() == (29 rows, 14 columns).
    mt = hl.import_vcf(TEST_DATA_MT_1KG)
    subset_file = self._create_temp_sample_subset_file(mt, 14)

    subset_mt = HailMatrixTableTask.subset_samples_and_variants(mt, subset_file)

    expected_shape = (29, 14)
    self.assertEqual(subset_mt.count(), expected_shape)
def test_hail_matrix_table_remap_1(self):
    # remap_sample_ids with a remap file for a single sample: exactly one
    # column of the remapped table should have no counterpart in the
    # original table's columns.
    mt = hl.import_vcf(TEST_DATA_MT_1KG)
    remap_file = self._create_temp_sample_remap_file(mt, 1)

    remap_mt = HailMatrixTableTask.remap_sample_ids(mt, remap_file)

    unmatched_cols = remap_mt.anti_join_cols(mt.cols())
    self.assertEqual(unmatched_cols.count_cols(), 1)