def set_pyguppy_model_attributes():
    init_client = self.pyguppy_GuppyBasecallerClient(
        self.params.pyguppy.config, host=GUPPY_HOST,
        port=self.params.pyguppy.port,
        timeout=PYGUPPY_PER_TRY_TIMEOUT, retries=self.pyguppy_retries)
    try:
        init_client.connect()
        init_read = init_client.basecall(
            ReadData(np.zeros(init_sig_len, dtype=np.int16), 'a'),
            state=True, trace=True)
    except (TimeoutError, self.zmqAgainError):
        raise mh.MegaError(
            'Failed to run test read with guppy. See guppy logs in ' +
            '--output-directory.')
    init_client.disconnect()
    if init_read.model_type not in COMPAT_GUPPY_MODEL_TYPES:
        raise mh.MegaError((
            'Megalodon is not compatible with guppy model type: ' +
            '{}').format(init_read.model_type))
    self.stride = init_read.model_stride
    self.ordered_mod_long_names = init_read.mod_long_names
    self.output_alphabet = init_read.mod_alphabet
    self.output_size = init_read.state_size
    if self.ordered_mod_long_names is None:
        self.ordered_mod_long_names = []
    if self.output_alphabet is None:
        self.output_alphabet = mh.ALPHABET
def iter_overlapping_snps(self, r_ref_pos, edge_buffer):
    """Iterator over SNPs overlapping the read mapped position.

    SNPs within edge buffer of the end of the mapping will be ignored.
    """
    if r_ref_pos.end - r_ref_pos.start <= 2 * edge_buffer:
        raise mh.MegaError('Mapped region too short for SNP calling.')
    try:
        fetch_res = self.variants_idx.fetch(
            r_ref_pos.chrm, r_ref_pos.start + edge_buffer,
            r_ref_pos.end - edge_buffer)
    except ValueError:
        raise mh.MegaError('Mapped location not valid for variants file.')
    for variant in fetch_res:
        snp_ref_seq = variant.ref
        snp_alt_seqs = variant.alts
        # skip SNPs larger than specified limit
        if self.max_indel_size is not None and max(
                np.abs(len(snp_ref_seq) - len(snp_alt_seq))
                for snp_alt_seq in snp_alt_seqs) > self.max_indel_size:
            continue
        # convert to 0-based coordinates
        yield snp_ref_seq, snp_alt_seqs, variant.id, variant.pos - 1
    return
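# Illustrative usage sketch (not part of the original module): fetch SNPs
# overlapping a mapped read. The MAP_POS fields match the namedtuple used by
# map_read below; the concrete values here are made up for the example.
# for ref_seq, alt_seqs, snp_id, pos0 in snps_data.iter_overlapping_snps(
#         r_ref_pos=MAP_POS(chrm='chr1', strand=1, start=100, end=600,
#                           q_trim_start=0, q_trim_end=500),
#         edge_buffer=5):
#     ...  # pos0 is 0-based; variants within 5 bases of either mapping
#     ...  # end (before 105 or after 595) are never yielded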
def check_matching_attrs(ground_truth_bed, strand_offset, mod_db_fn,
                         target_mod_bases, limit=10000):
    mods_db = mods.ModsDb(mod_db_fn)
    db_strands = (1, -1) if strand_offset is None else (None, )
    db_chrms = set(
        (chrm, strand) for _, chrm, _ in mods_db.iter_chrms()
        for strand in db_strands)
    cov, mod_cov = mh.parse_bed_methyls(
        [ground_truth_bed, ], strand_offset, show_prog_bar=False,
        limit=limit)
    if len(db_chrms.intersection(cov.keys())) == 0:
        LOGGER.error((
            'Using first {} sites from {}, found zero overlapping ' +
            'contig/chromosome names with the mod database.').format(
                limit, ground_truth_bed))
        LOGGER.info('Database contigs/chromosomes: {}'.format(', '.join(
            map(str, db_chrms))))
        LOGGER.info('BED methyl contigs/chromosomes: {}'.format(', '.join(
            map(str, list(cov.keys())))))
        raise mh.MegaError('No overlapping contigs found.')
    db_mods = set(
        mod_base for mod_base, _ in mods_db.get_mod_long_names())
    for tmb in target_mod_bases:
        if tmb not in db_mods:
            raise mh.MegaError(
                ('Target modified base, {}, not found in mods database ' +
                 '({}).').format(tmb, ', '.join(db_mods)))
    mods_db.check_data_covering_index_exists()
    mods_db.close()
def check_map_sig_alphabet(model_info, ms_fn):
    # read filename queue filler
    msf = mapped_signal_files.HDF5Reader(ms_fn)
    tai_alph_info = msf.get_alphabet_information()
    msf.close()
    if model_info.output_alphabet != tai_alph_info.alphabet:
        raise mh.MegaError(
            ("Different alphabets specified in model ({}) and mapped " +
             "signal file ({})").format(
                 model_info.output_alphabet, tai_alph_info.alphabet))
    if set(model_info.can_alphabet) != set(tai_alph_info.collapse_alphabet):
        raise mh.MegaError(
            ("Different canonical alphabets specified in model ({}) and " +
             "mapped signal file ({})").format(
                 model_info.can_alphabet, tai_alph_info.collapse_alphabet))
    if model_info.ordered_mod_long_names != tai_alph_info.mod_long_names:
        raise mh.MegaError(
            ("Different modified base long names specified in model " +
             "({}) and mapped signal file ({})").format(
                 ", ".join(model_info.ordered_mod_long_names),
                 ", ".join(tai_alph_info.mod_long_names)))
def run_model(self, raw_sig, n_can_state=None):
    if self.model_type == TAI_NAME:
        if any(arg is None for arg in (
                self.chunk_size, self.chunk_overlap,
                self.max_concur_chunks)):
            logger = logging.get_logger()
            logger.error(
                'Must provide chunk_size, chunk_overlap, ' +
                'max_concur_chunks in order to run the taiyaki ' +
                'base calling backend.')
            # abort here rather than crash below with None chunk params
            raise mh.MegaError('Invalid taiyaki backend parameters.')
        try:
            trans_weights = self.tai_run_model(
                raw_sig, self.model, self.chunk_size, self.chunk_overlap,
                self.max_concur_chunks)
        except AttributeError:
            raise mh.MegaError('Out of date or incompatible model')
        except RuntimeError:
            raise mh.MegaError('Likely out of memory error.')
        if self.device != self.torch.device('cpu'):
            self.torch.cuda.empty_cache()
        if n_can_state is not None:
            trans_weights = (
                np.ascontiguousarray(trans_weights[:, :n_can_state]),
                np.ascontiguousarray(trans_weights[:, n_can_state:]))
    else:
        raise mh.MegaError('Invalid model type.')
    return trans_weights
def map_read(q_seq, read_id, caller_conn, signal_reversed=False):
    """Map read (query) sequence and return:
    1) reference sequence (encoded as int labels)
    2) mapping from reference to read positions (after trimming)
    3) reference mapping position (including read trimming positions)
    4) cigar as produced by mappy
    """
    # send seq to _map_read_worker and receive mapped seq and pos
    if signal_reversed:
        q_seq = q_seq[::-1]
    caller_conn.send((q_seq, read_id))
    r_ref_seq, r_algn = caller_conn.recv()
    if r_ref_seq is None:
        raise mh.MegaError('No alignment')
    chrm, strand, r_st, r_en, q_st, q_en, r_cigar = r_algn
    if signal_reversed:
        q_st, q_en = len(q_seq) - q_en, len(q_seq) - q_st
        r_ref_seq = r_ref_seq[::-1]
        r_cigar = r_cigar[::-1]
    try:
        r_to_q_poss = parse_cigar(r_cigar, strand, r_en - r_st)
    except mh.MegaError as e:
        LOGGER.debug('Read {} ::: '.format(read_id) + str(e))
        raise mh.MegaError('Invalid cigar string encountered.')
    r_pos = MAP_POS(
        chrm=chrm, strand=strand, start=r_st, end=r_en,
        q_trim_start=q_st, q_trim_end=q_en)
    return r_ref_seq, r_to_q_poss, r_pos, r_cigar
def call_alt_true_indel(indel_size, r_snp_pos, true_ref_seq, r_seq,
                        map_thr_buf, context_bases, r_post, rl_cumsum,
                        all_paths):
    def run_aligner():
        return next(mappy.Aligner(
            seq=false_ref_seq, preset=str('map-ont'), best_n=1).map(
                str(r_seq), buf=map_thr_buf))

    if indel_size == 0:
        false_base = choice(
            list(set(CAN_BASES).difference(true_ref_seq[r_snp_pos])))
        false_ref_seq = (
            true_ref_seq[:r_snp_pos] + false_base +
            true_ref_seq[r_snp_pos + 1:])
        snp_ref_seq = false_base
        snp_alt_seq = true_ref_seq[r_snp_pos]
    elif indel_size > 0:
        # test alt truth reference insertion
        false_ref_seq = (
            true_ref_seq[:r_snp_pos + 1] +
            true_ref_seq[r_snp_pos + indel_size + 1:])
        snp_ref_seq = true_ref_seq[r_snp_pos]
        snp_alt_seq = true_ref_seq[r_snp_pos:r_snp_pos + indel_size + 1]
    else:
        # test alt truth reference deletion
        deleted_seq = ''.join(
            choice(CAN_BASES) for _ in range(-indel_size))
        false_ref_seq = (
            true_ref_seq[:r_snp_pos + 1] + deleted_seq +
            true_ref_seq[r_snp_pos + 1:])
        snp_ref_seq = true_ref_seq[r_snp_pos] + deleted_seq
        snp_alt_seq = true_ref_seq[r_snp_pos]

    try:
        r_algn = run_aligner()
    except StopIteration:
        raise mh.MegaError('No alignment')
    r_ref_seq = false_ref_seq[r_algn.r_st:r_algn.r_en]
    if r_algn.strand == -1:
        raise mh.MegaError('Indel mapped read mapped to reverse strand.')
    r_to_q_poss = mapping.parse_cigar(r_algn.cigar, r_algn.strand)
    if (r_algn.r_st > r_snp_pos - context_bases[1] or
            r_algn.r_en < r_snp_pos + context_bases[1]):
        raise mh.MegaError('Indel mapped read clipped snp position.')

    post_mapped_start = rl_cumsum[r_algn.q_st]
    mapped_rl_cumsum = rl_cumsum[
        r_algn.q_st:r_algn.q_en + 1] - post_mapped_start

    # pass the mapped (trimmed and re-anchored) run-length cumsum computed
    # above; r_to_q_poss is relative to the mapped portion of the query
    score = call_snp(
        r_post, post_mapped_start, r_snp_pos, mapped_rl_cumsum,
        r_to_q_poss, snp_ref_seq, snp_alt_seq, context_bases, all_paths,
        ref_seq=r_ref_seq)

    return score, snp_ref_seq, snp_alt_seq
def __init__(self, variant_fn, max_indel_size, all_paths,
             write_snps_txt, context_bases, snps_calib_fn=None,
             call_mode=DIPLOID_MODE, do_pr_ref_snps=False, aligner=None,
             keep_snp_fp_open=False):
    logger = logging.get_logger('snps')
    self.max_indel_size = max_indel_size
    self.all_paths = all_paths
    self.write_snps_txt = write_snps_txt
    self.snps_calib_fn = snps_calib_fn
    self.calib_table = calibration.SnpCalibrator(self.snps_calib_fn)
    self.context_bases = context_bases
    if len(self.context_bases) != 2:
        raise mh.MegaError(
            'Must provide 2 context bases values (for single base SNPs ' +
            'and indels).')
    self.call_mode = call_mode
    self.do_pr_ref_snps = do_pr_ref_snps
    self.variant_fn = variant_fn
    self.variants_idx = None

    if self.variant_fn is None:
        return

    logger.info('Loading variants.')
    vars_idx = pysam.VariantFile(self.variant_fn)
    try:
        contigs = list(vars_idx.header.contigs.keys())
        vars_idx.fetch(next(iter(contigs)), 0, 0)
    except ValueError:
        logger.warning(
            'Variants file must be indexed. Performing indexing now.')
        vars_idx.close()
        self.variant_fn = index_variants(self.variant_fn)
        vars_idx = pysam.VariantFile(self.variant_fn)
    if keep_snp_fp_open:
        self.variants_idx = vars_idx
    else:
        vars_idx.close()
        self.variants_idx = None

    if aligner is None:
        raise mh.MegaError(
            'Must provide aligner if SNP filename is provided')
    if len(set(aligner.ref_names_and_lens[0]).intersection(
            contigs)) == 0:
        raise mh.MegaError((
            'Reference and variant files contain no chromosomes/contigs ' +
            'in common.\n\t\tFirst 3 reference contigs:\t{}\n\t\tFirst 3 ' +
            'variant file contigs:\t{}').format(
                ', '.join(aligner.ref_names_and_lens[0][:3]),
                ', '.join(contigs[:3])))
    return
def compute_snp_stats(self, snp_loc, het_factors,
                      call_mode=DIPLOID_MODE, valid_read_ids=None):
    assert call_mode in (HAPLIOD_MODE, DIPLOID_MODE), (
        'Invalid SNP aggregation ploidy call mode: {}.'.format(call_mode))

    pr_snp_stats = self.get_per_read_snp_stats(snp_loc)
    alt_seqs = sorted(set(r_stats.alt_seq for r_stats in pr_snp_stats))
    pr_alt_lps = defaultdict(dict)
    for r_stats in pr_snp_stats:
        if (valid_read_ids is not None and
                r_stats.read_id not in valid_read_ids):
            continue
        pr_alt_lps[r_stats.read_id][r_stats.alt_seq] = r_stats.score
    if len(pr_alt_lps) == 0:
        raise mh.MegaError('No valid reads cover SNP')

    alt_seq_lps = [[] for _ in range(len(alt_seqs))]
    for read_lps in pr_alt_lps.values():
        for i, alt_seq in enumerate(alt_seqs):
            try:
                alt_seq_lps[i].append(read_lps[alt_seq])
            except KeyError:
                raise mh.MegaError(
                    'Alternative SNP sequence must exist for all reads.')
    alts_lps = np.stack(alt_seq_lps, axis=0)
    with np.errstate(all='ignore'):
        ref_lps = np.log1p(-np.exp(alts_lps).sum(axis=0))

    r0_stats = pr_snp_stats[0]
    snp_var = Variant(
        chrom=r0_stats.chrm, pos=r0_stats.pos, ref=r0_stats.ref_seq,
        alts=alt_seqs, id=r0_stats.snp_id)
    snp_var.add_tag('DP', '{}'.format(ref_lps.shape[0]))
    snp_var.add_sample_field('DP', '{}'.format(ref_lps.shape[0]))

    if self.write_vcf_log_probs:
        snp_var.add_sample_field('LOG_PROBS', ','.join(
            ';'.join('{:.2f}'.format(lp) for lp in alt_i_lps)
            for alt_i_lps in alts_lps))

    if call_mode == DIPLOID_MODE:
        het_factor = (
            het_factors[0] if len(r0_stats.ref_seq) == 1 and
            len(r0_stats.alt_seq) == 1 else het_factors[1])
        diploid_probs, gts = self.compute_diploid_probs(
            ref_lps, alts_lps, het_factor)
        snp_var.add_diploid_probs(diploid_probs, gts)
    elif call_mode == HAPLIOD_MODE:
        haploid_probs, gts = self.compute_haploid_probs(
            ref_lps, alts_lps)
        snp_var.add_haploid_probs(haploid_probs, gts)

    return snp_var
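# Illustrative standalone sketch (not part of the original module): the
# reference log probability above is derived from per-read alternative
# log probabilities as log(1 - sum_i exp(alt_lp_i)), computed stably with
# np.log1p. Toy numbers below are made up for the example.
def _example_ref_lp_from_alt_lps():
    # one read, two alternative alleles with P(alt1)=0.2 and P(alt2)=0.1
    alts_lps = np.log([[0.2], [0.1]])
    ref_lps = np.log1p(-np.exp(alts_lps).sum(axis=0))
    # P(ref) = 1 - 0.2 - 0.1 = 0.7
    assert np.allclose(np.exp(ref_lps), [0.7])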
def map_read(caller_conn, called_read, sig_info, mo_q=None,
             signal_reversed=False, rl_cumsum=None):
    """Map read (query) sequence

    Returns:
        Tuple containing
            1) reference sequence (encoded as int labels)
            2) mapping from reference to read positions (after trimming)
            3) reference mapping position (including read trimming
               positions)
            4) cigar as produced by mappy
    """
    # send seq to _map_read_worker and receive mapped seq and pos
    q_seq = called_read.seq[::-1] if signal_reversed else called_read.seq
    caller_conn.send((q_seq, sig_info.read_id))
    map_res = caller_conn.recv()
    if map_res is None:
        raise mh.MegaError('No alignment')
    map_res = MAP_RES(*map_res)
    # add signal coordinates to mapping output if run-length cumsum
    # provided
    if rl_cumsum is not None:
        # convert query start and end to signal-anchored locations
        # Note that for signal_reversed reads, the start will be larger
        # than the end
        q_st = (len(map_res.q_seq) - map_res.q_st if signal_reversed
                else map_res.q_st)
        q_en = (len(map_res.q_seq) - map_res.q_en if signal_reversed
                else map_res.q_en)
        map_res = map_res._replace(
            map_sig_start=called_read.trimmed_samples +
            rl_cumsum[q_st] * sig_info.stride,
            map_sig_end=called_read.trimmed_samples +
            rl_cumsum[q_en] * sig_info.stride,
            sig_len=called_read.trimmed_samples +
            rl_cumsum[-1] * sig_info.stride)
    if mo_q is not None:
        mo_q.put(tuple(map_res))
    if signal_reversed:
        # if signal is reversed compared to mapping, reverse coordinates
        # so they are relative to signal/state_data
        map_res = map_res._replace(
            q_st=len(map_res.q_seq) - map_res.q_en,
            q_en=len(map_res.q_seq) - map_res.q_st,
            ref_seq=map_res.ref_seq[::-1],
            cigar=map_res.cigar[::-1])
    try:
        r_to_q_poss = parse_cigar(
            map_res.cigar, map_res.strand, map_res.r_en - map_res.r_st)
    except mh.MegaError as e:
        LOGGER.debug(
            '{} CigarParsingError '.format(sig_info.read_id) + str(e))
        raise mh.MegaError('Invalid cigar string encountered.')
    map_pos = get_map_pos_from_res(map_res)
    return map_res.ref_seq, r_to_q_poss, map_pos, map_res.cigar
def run_pyguppy_model(self, sig_info, return_post_w_mods,
                      return_mod_scores, update_sig_info):
    if self.model_type != PYGUPPY_NAME:
        raise mh.MegaError(
            'Attempted to run pyguppy model with non-pyguppy ' +
            'initialization.')
    post_w_mods = mods_scores = None
    try:
        called_read = self.client.basecall(
            self.pyguppy_ReadData(sig_info.dacs, sig_info.read_id),
            state=True, trace=True)
    except TimeoutError:
        raise mh.MegaError(
            'Pyguppy server timeout. See --guppy-timeout option')

    # compute rl_cumsum from move table
    rl_cumsum = np.where(called_read.move)[0]
    rl_cumsum = np.insert(
        rl_cumsum, rl_cumsum.shape[0], called_read.move.shape[0])

    if self.is_cat_mod:
        # split canonical posteriors and mod transition weights
        can_post = np.ascontiguousarray(
            called_read.state[:, :self.n_can_state])
        if return_mod_scores or return_post_w_mods:
            mods_weights = self._softmax_mod_weights(
                called_read.state[:, self.n_can_state:])
            if return_post_w_mods:
                post_w_mods = np.concatenate(
                    [can_post, mods_weights], axis=1)
            if return_mod_scores:
                # TODO apply np.NAN mask to scores not applicable to
                # canonical basecalls
                mods_scores = np.ascontiguousarray(
                    mods_weights[rl_cumsum[:-1]])
    else:
        can_post = called_read.state

    if update_sig_info:
        # add scale_params and trimmed dacs to sig_info
        trimmed_dacs = sig_info.dacs[called_read.trimmed_samples:]
        # guppy does not apply the med norm factor
        scale_params = (
            called_read.scaling['median'],
            called_read.scaling['med_abs_dev'] * mh.MED_NORM_FACTOR)
        sig_info = sig_info._replace(
            raw_len=trimmed_dacs.shape[0], dacs=trimmed_dacs,
            raw_signal=((trimmed_dacs - scale_params[0]) /
                        scale_params[1]).astype(np.float32),
            scale_params=scale_params)

    return (called_read.seq, called_read.qual, rl_cumsum, can_post,
            sig_info, post_w_mods, mods_scores)
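# Illustrative standalone sketch (not part of the original module): how the
# basecaller move table is converted to per-base block-start indices plus a
# final end sentinel, as in run_pyguppy_model above. The toy move table is
# made up for the example.
def _example_rl_cumsum_from_move():
    move = np.array([1, 0, 0, 1, 0, 1])  # bases start at blocks 0, 3, 5
    rl_cumsum = np.where(move)[0]
    rl_cumsum = np.insert(rl_cumsum, rl_cumsum.shape[0], move.shape[0])
    # one entry per called base plus the total block count at the end
    assert rl_cumsum.tolist() == [0, 3, 5, 6]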
def start_guppy_server():
    def get_server_port():
        next_line = guppy_out_read_fp.readline()
        # readline returns an empty string (not None) when no new output
        # is available
        if not next_line:
            return None
        try:
            return int(GUPPY_PORT_PAT.search(next_line).groups()[0])
        except AttributeError:
            return None

    # set guppy logs output locations
    self.guppy_log = os.path.join(
        self.params.pyguppy.out_dir, GUPPY_LOG_BASE)
    self.guppy_out_fp = open(self.guppy_log + '.out', 'w')
    guppy_out_read_fp = open(self.guppy_log + '.out', 'r')
    self.guppy_err_fp = open(self.guppy_log + '.err', 'w')

    # prepare args to start guppy server
    server_args = [
        self.params.pyguppy.bin_path,
        '-p', str(self.params.pyguppy.port),
        '-l', self.guppy_log,
        '-c', self.params.pyguppy.config,
        '--post_out']
    if (self.params.pyguppy.devices is not None and
            len(self.params.pyguppy.devices) > 0 and
            self.params.pyguppy.devices[0] != 'cpu'):
        devices_str = ' '.join(
            parse_device(device)
            for device in self.params.pyguppy.devices)
        server_args.extend(('-x', devices_str))
    if self.params.pyguppy.server_params is not None:
        server_args.extend(self.params.pyguppy.server_params.split())

    try:
        self.guppy_server_proc = subprocess.Popen(
            server_args, shell=False,
            stdout=self.guppy_out_fp, stderr=self.guppy_err_fp)
    except FileNotFoundError:
        raise mh.MegaError(
            'Guppy server executable not found. Please specify path ' +
            'via `--guppy-server-path` argument.')

    # wait until server is successfully started or fails
    while True:
        used_port = get_server_port()
        if used_port is not None:
            break
        if self.guppy_server_proc.poll() is not None:
            raise mh.MegaError(
                'Guppy server initialization failed. See guppy logs ' +
                'in --output-directory for more details.')
        sleep(0.01)
    guppy_out_read_fp.close()

    self.params = self.params._replace(
        pyguppy=self.params.pyguppy._replace(port=used_port))
def extract_mods(in_mod_db_fns):
    LOGGER.info('Merging mod tables')
    # collect modified base definitions from input databases
    mod_base_to_can = dict()
    for in_mod_db_fn in tqdm(in_mod_db_fns, desc='Databases', unit='DBs',
                             smoothing=0, dynamic_ncols=True):
        mods_db = mods.ModsDb(in_mod_db_fn)
        for mod_base, can_base, mln in mods_db.get_full_mod_data():
            if mod_base in mod_base_to_can and \
                    (can_base, mln) != mod_base_to_can[mod_base]:
                raise mh.MegaError(
                    'Modified base associated with multiple canonical ' +
                    'bases or long names in different databases. ' +
                    '{} != {}'.format(
                        str((can_base, mln)),
                        str(mod_base_to_can[mod_base])))
            mod_base_to_can[mod_base] = (can_base, mln)
    # check that mod long names are unique
    mlns = [mln for _, mln in mod_base_to_can.values()]
    if len(mlns) != len(set(mlns)):
        raise mh.MegaError(
            'Modified base long name assigned to more than one modified ' +
            'base single letter code.')

    # extract canonical bases associated with modified base
    can_bases = set(can_base for can_base, _ in mod_base_to_can.values())
    # determine first valid canonical alphabet compatible with database
    can_alphabet = None
    for v_alphabet in mh.VALID_ALPHABETS:
        if len(can_bases.difference(v_alphabet)) == 0:
            can_alphabet = v_alphabet
            break
    if can_alphabet is None:
        LOGGER.error(
            'Mods database does not contain valid canonical bases ' +
            '({})'.format(''.join(sorted(can_bases))))
        raise mh.MegaError('Invalid alphabet.')

    # compute full output alphabet and ordered modified base long names
    can_base_to_mods = dict(
        (can_base, [
            (mod_base, mln)
            for mod_base, (mcan_base, mln) in mod_base_to_can.items()
            if mcan_base == can_base])
        for can_base in can_alphabet)
    alphabet = ''
    mod_long_names = []
    for can_base in can_alphabet:
        alphabet += can_base
        for mod_base, mln in can_base_to_mods[can_base]:
            alphabet += mod_base
            mod_long_names.append(mln)

    return alphabet, mod_long_names
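# Illustrative sketch (not part of the original module): given a toy
# mod-base mapping, the output alphabet interleaves each modified base
# directly after its canonical base, in canonical-alphabet order.
# mod_base_to_can = {'m': ('C', '5mC'), 'a': ('A', '6mA')}
# with canonical alphabet 'ACGT' this yields
#     alphabet == 'AaCmGT'
#     mod_long_names == ['6mA', '5mC']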
def get_remapping(sig_fn, dacs, scale_params, ref_seq, stride, read_id,
                  r_to_q_poss, rl_cumsum, r_ref_pos, ref_out_info):
    read = fast5_interface.get_fast5_file(sig_fn, 'r').get_read(read_id)
    channel_info = dict(fast5utils.get_channel_info(read).items())
    rd_factor = channel_info['range'] / channel_info['digitisation']
    shift_from_pA = (scale_params[0] + channel_info['offset']) * rd_factor
    scale_from_pA = scale_params[1] * rd_factor
    read_attrs = dict(fast5utils.get_read_attributes(read).items())

    # prepare taiyaki signal object
    sig = tai_signal.Signal(dacs=dacs)
    sig.channel_info = channel_info
    sig.read_attributes = read_attrs
    sig.offset = channel_info['offset']
    sig.range = channel_info['range']
    sig.digitisation = channel_info['digitisation']

    path = np.full((dacs.shape[0] // stride) + 1, -1)
    # skip last value since this is where the two seqs end
    for ref_pos, q_pos in enumerate(r_to_q_poss[:-1]):
        # if the query position maps to the end of the mapping skip it
        if rl_cumsum[q_pos + r_ref_pos.q_trim_start] >= path.shape[0]:
            continue
        path[rl_cumsum[q_pos + r_ref_pos.q_trim_start]] = ref_pos
    remapping = tai_mapping.Mapping.from_remapping_path(
        sig, path, ref_seq, stride)
    try:
        remapping.add_integer_reference(ref_out_info.alphabet)
    except Exception:
        raise mh.MegaError('Invalid reference sequence encountered')

    return (remapping.get_read_dictionary(
        shift_from_pA, scale_from_pA, read_id),
        prepare_mapping_funcs.RemapResult.SUCCESS)
def get_read_id(self, uuid):
    try:
        read_id = self.cur.execute(
            'SELECT read_id FROM read WHERE uuid=?',
            (uuid, )).fetchone()[0]
    except TypeError:
        raise mh.MegaError('Read ID not found in mods database.')
    return read_id
def prep_model_worker(self, device):
    """Load model onto a newly spawned process"""
    if self.model_type == TAI_NAME:
        # setup for taiyaki model
        self.model = self.load_taiyaki_model(
            self.params.taiyaki.taiyaki_model_fn)
        if device is None or device == 'cpu':
            self.device = self.torch.device('cpu')
        else:
            sleep(np.random.uniform(0, MAX_DEVICE_WAIT))
            try:
                self.device = self.torch.device(device)
                self.torch.cuda.set_device(self.device)
                self.model = self.model.to(self.device)
            except RuntimeError:
                LOGGER.error('Invalid CUDA device: {}'.format(device))
                raise mh.MegaError('Error setting CUDA GPU device.')
        self.model = self.model.eval()
    elif self.model_type == PYGUPPY_NAME:
        # open guppy client interface (None indicates using config
        # from server)
        self.client = self.pyguppy_GuppyBasecallerClient(
            self.params.pyguppy.config, host=GUPPY_HOST,
            port=self.params.pyguppy.port,
            timeout=PYGUPPY_PER_TRY_TIMEOUT,
            retries=self.pyguppy_retries)
        self.client.connect()

    return
def main():
    args = get_parser().parse_args()

    vars0_idx = pysam.VariantFile(args.diploid_called_variants)
    vars1_idx = pysam.VariantFile(args.haplotype1_variants)
    vars2_idx = pysam.VariantFile(args.haplotype2_variants)
    try:
        contigs0 = list(vars0_idx.header.contigs.keys())
        vars0_idx.fetch(next(iter(contigs0)), 0, 0)
        contigs1 = list(vars1_idx.header.contigs.keys())
        vars1_idx.fetch(next(iter(contigs1)), 0, 0)
        contigs2 = list(vars2_idx.header.contigs.keys())
        vars2_idx.fetch(next(iter(contigs2)), 0, 0)
    except ValueError:
        raise mh.MegaError(
            'Variants file must be indexed. Use bgzip and tabix.')

    out_vars = open(args.out_vcf, 'w')
    out_vars.write(HEADER.format('\n'.join(
        (CONTIG_HEADER_LINE.format(ctg.name, ctg.length)
         for ctg in vars0_idx.header.contigs.values()))))
    for contig in set(contigs0).intersection(contigs1).intersection(
            contigs2):
        for curr_v0_rec, curr_v1_rec, curr_v2_rec in iter_contig_vars(
                iter(vars0_idx.fetch(contig)),
                iter(vars1_idx.fetch(contig)),
                iter(vars2_idx.fetch(contig))):
            if curr_v0_rec is None:
                continue
            write_var(
                curr_v0_rec, curr_v1_rec, curr_v2_rec, out_vars, contig)

    out_vars.close()
    variants.index_variants(args.out_vcf)
def extract_llrs(llr_fn, max_indel_len=None):
    snp_ref_llrs, ins_ref_llrs, del_ref_llrs = (
        defaultdict(list) for _ in range(3))
    with open(llr_fn) as llr_fp:
        for line in llr_fp:
            is_ref_correct, llr, ref_seq, alt_seq = line.split()
            llr = float(llr)
            if is_ref_correct != 'True':
                continue
            if np.isnan(llr):
                continue
            if max_indel_len is not None and \
                    np.abs(len(ref_seq) - len(alt_seq)) > max_indel_len:
                continue
            if len(ref_seq) == 1 and len(alt_seq) == 1:
                snp_ref_llrs[(ref_seq, alt_seq)].append(llr)
            else:
                if len(ref_seq) > len(alt_seq):
                    del_ref_llrs[len(ref_seq) - len(alt_seq)].append(llr)
                else:
                    ins_ref_llrs[len(alt_seq) - len(ref_seq)].append(llr)
    if min(len(snp_ref_llrs), len(ins_ref_llrs), len(del_ref_llrs)) == 0:
        raise mh.MegaError(
            'Variant statistics file does not contain sufficient data ' +
            'for calibration.')
    return snp_ref_llrs, ins_ref_llrs, del_ref_llrs
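# Illustrative input sketch (assumed format, inferred only from the parsing
# above, not from project documentation): each whitespace-separated line of
# the statistics file holds
#   is_ref_correct  llr    ref_seq  alt_seq
#   True            1.83   A        G        -> single-base SNP LLR
#   True            -0.42  AT       A        -> deletion of length 1
#   True            0.91   C        CTT      -> insertion of length 2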
def parse_cigar(r_cigar, strand, ref_len):
    fill_invalid = -1
    # get each base call's genomic position
    r_to_q_poss = np.full(ref_len + 1, fill_invalid, dtype=np.int32)
    # process cigar ops in read direction
    curr_r_pos, curr_q_pos = 0, 0
    cigar_ops = r_cigar if strand == 1 else r_cigar[::-1]
    for op_len, op in cigar_ops:
        if op == 1:
            # inserted bases into ref
            curr_q_pos += op_len
        elif op in (2, 3):
            # deleted ref bases
            for r_pos in range(curr_r_pos, curr_r_pos + op_len):
                r_to_q_poss[r_pos] = curr_q_pos
            curr_r_pos += op_len
        elif op in (0, 7, 8):
            # aligned bases
            for op_offset in range(op_len):
                r_to_q_poss[curr_r_pos + op_offset] = \
                    curr_q_pos + op_offset
            curr_q_pos += op_len
            curr_r_pos += op_len
        elif op == 6:
            # padding (shouldn't happen in mappy)
            pass
    r_to_q_poss[curr_r_pos] = curr_q_pos
    if r_to_q_poss[-1] == fill_invalid:
        raise mh.MegaError((
            'Invalid cigar string encountered. Reference length: {} ' +
            'Cigar implied reference length: {}').format(
                ref_len, curr_r_pos))

    return r_to_q_poss
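# Illustrative standalone sketch (not part of the original module): a tiny
# worked example of parse_cigar using mappy-style (length, op) pairs, where
# op 0 is an alignment match and op 2 is a deletion from the reference.
def _example_parse_cigar():
    # 2 matched bases, a 2-base deletion from the reference, 2 more matches
    cigar = [(2, 0), (2, 2), (2, 0)]
    r_to_q_poss = parse_cigar(cigar, strand=1, ref_len=6)
    # deleted reference positions 2 and 3 both point at query position 2;
    # the final entry maps the reference end to the query end
    assert r_to_q_poss.tolist() == [0, 1, 2, 2, 2, 3, 4]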
def map_read(caller_conn, called_read, sig_info, mo_q=None,
             signal_reversed=False, rl_cumsum=None):
    """Map read (query) sequence

    Returns:
        List with one tuple per mapping, each containing
            1) reference sequence (encoded as int labels)
            2) mapping from reference to read positions (after trimming)
            3) reference mapping position (including read trimming
               positions)
            4) cigar as produced by mappy
    """
    # send seq to _map_read_worker and receive mapped seq and pos
    q_seq = called_read.seq[::-1] if signal_reversed else called_read.seq
    caller_conn.send((q_seq, sig_info.read_id))
    map_ress = caller_conn.recv()
    if map_ress is None:
        raise mh.MegaError('No alignment')
    return [
        process_mapping(
            map_res, called_read, sig_info, mo_q, signal_reversed,
            rl_cumsum)
        for map_res in map_ress]
def _load_fast5_post_out(self):
    def get_model_info_from_fast5(read):
        try:
            stride = fast5_io.get_stride(read)
            mod_long_names, out_alphabet = fast5_io.get_mod_base_info(
                read)
            out_size = fast5_io.get_posteriors(read).shape[1]
            mod_long_names = mod_long_names.split()
        except KeyError:
            LOGGER.error(
                'Fast5 read does not contain required attributes.')
            raise mh.MegaError(
                'Fast5 read does not contain required attributes.')
        return stride, mod_long_names, out_alphabet, out_size

    LOGGER.info('Loading FAST5 basecalling backend.')
    self.model_type = FAST5_NAME
    self.process_devices = [None, ] * self.num_proc

    read_iter = fast5_io.iterate_fast5_reads(
        self.params.fast5.fast5s_dir)
    nreads = 0
    try:
        fast5_fn, read_id = next(read_iter)
        read = fast5_io.get_read(fast5_fn, read_id)
        (self.stride, self.ordered_mod_long_names, self.output_alphabet,
         self.output_size) = get_model_info_from_fast5(read)
    except StopIteration:
        LOGGER.error('No reads found.')
        raise mh.MegaError('No reads found.')

    for fast5_fn, read_id in read_iter:
        read = fast5_io.get_read(fast5_fn, read_id)
        r_s, r_omln, r_oa, r_os = get_model_info_from_fast5(read)
        if (self.stride != r_s or
                self.ordered_mod_long_names != r_omln or
                self.output_alphabet != r_oa or
                self.output_size != r_os):
            LOGGER.error(
                'Model information from FAST5 files is inconsistent. ' +
                'Ensure all reads were called with the same model.')
            raise mh.MegaError(
                'Model information from FAST5 files is inconsistent.')
        nreads += 1
        if nreads >= self.params.fast5.num_startup_reads:
            break

    self._parse_minimal_alphabet_info()
def get_chrm(self, chrm_id):
    try:
        chrm = self.cur.execute(
            'SELECT chrm FROM chrm WHERE chrm_id=?',
            (chrm_id, )).fetchone()[0]
    except TypeError:
        raise mh.MegaError(
            'Reference record (chromosome) not found in mods database.')
    return chrm
def get_mapping_mode(map_fmt):
    if map_fmt == 'bam':
        return 'wb'
    elif map_fmt == 'cram':
        return 'wc'
    elif map_fmt == 'sam':
        return 'w'
    raise mh.MegaError('Invalid mapping output format: {}'.format(map_fmt))
def __init__(self, conn, size, max_size=mh._MAX_QUEUE_SIZE,
             name='ConnWithSize', full_sleep_time=_FULL_SLEEP_TIME):
    if not isinstance(conn, mp.connection.Connection):
        raise mh.MegaError((
            'ConnWithSize initialized with non-connection object. ' +
            'Object type: {}').format(type(conn)))
    # size must be a synchronized shared counter holding an int; the
    # short-circuit keeps size.value from being touched when size is not
    # a Synchronized object
    if not (isinstance(size, mp.sharedctypes.Synchronized) and
            isinstance(size.value, int)):
        raise mh.MegaError((
            'ConnWithSize initialized with non-synchronized size ' +
            'object. Object type: {}').format(type(size)))
    self._conn = conn
    self._size = size
    self.max_size = max_size
    self.full_sleep_time = full_sleep_time
    self.name = name
def basecall_read(self, sig_info, return_post_w_mods=True,
                  return_mod_scores=False, update_sig_info=False):
    if self.model_type not in (TAI_NAME, FAST5_NAME, PYGUPPY_NAME):
        raise mh.MegaError('Invalid model backend')

    # decoding is performed within the pyguppy server, so short-circuit
    # the return here as other methods require megalodon decoding.
    if self.model_type == PYGUPPY_NAME:
        return self.run_pyguppy_model(
            sig_info, return_post_w_mods, return_mod_scores,
            update_sig_info)

    post_w_mods = mod_weights = None
    if self.model_type == TAI_NAME:
        # run neural network with taiyaki
        if self.is_cat_mod:
            bc_weights, mod_weights = self.run_taiyaki_model(
                sig_info.raw_signal, self.n_can_state)
        else:
            bc_weights = self.run_taiyaki_model(sig_info.raw_signal)
        # perform forward-backward algorithm on neural net output
        can_post = decode.crf_flipflop_trans_post(bc_weights, log=True)
        if return_post_w_mods and self.is_cat_mod:
            post_w_mods = np.concatenate([can_post, mod_weights], axis=1)
        # set mod_weights to None if mod_scores not requested to
        # avoid extra computation
        if not return_mod_scores:
            mod_weights = None
    else:
        # FAST5 stored posteriors backend
        if self.is_cat_mod:
            # split canonical posteriors and mod transition weights
            # producing desired return arrays
            can_post = np.ascontiguousarray(
                sig_info.posteriors[:, :self.n_can_state])
            if return_mod_scores or return_post_w_mods:
                # convert raw neural net mod weights to softmax weights
                mod_weights = self._softmax_mod_weights(
                    sig_info.posteriors[:, self.n_can_state:])
                if return_post_w_mods:
                    post_w_mods = np.concatenate(
                        [can_post, mod_weights], axis=1)
                if not return_mod_scores:
                    mod_weights = None
        else:
            can_post = sig_info.posteriors

    # decode posteriors to sequence and per-base mod scores
    r_seq, _, rl_cumsum, mods_scores = decode.decode_post(
        can_post, self.can_alphabet, mod_weights, self.can_nmods)
    # TODO implement quality extraction for taiyaki and fast5 modes
    r_qual = None

    return (r_seq, r_qual, rl_cumsum, can_post, sig_info, post_w_mods,
            mods_scores)
def compute_mod_sites_stats(mod_stats, ctrl_stats, balance_classes,
                            mod_base, samp_lab, vs_lab, out_fp):
    if balance_classes:
        # randomly downsample sample with more observations
        if mod_stats.shape[0] > ctrl_stats.shape[0]:
            mod_stats = np.random.choice(
                mod_stats, ctrl_stats.shape[0], replace=False)
        elif mod_stats.shape[0] < ctrl_stats.shape[0]:
            ctrl_stats = np.random.choice(
                ctrl_stats, mod_stats.shape[0], replace=False)

    is_can = np.repeat([0, 1], [mod_stats.shape[0], ctrl_stats.shape[0]])
    all_stats = np.concatenate([mod_stats, ctrl_stats])
    nan_stats = np.isnan(all_stats)
    if any(nan_stats):
        LOGGER.warning('Encountered {} NaN modified base scores.'.format(
            sum(nan_stats)))
        # drop NaN scores and their labels together to keep them aligned
        is_can = is_can[~nan_stats]
        all_stats = all_stats[~nan_stats]
        if all_stats.shape[0] == 0:
            raise mh.MegaError('All modified base scores are NaN')
    inf_idx = np.isposinf(all_stats)
    if any(inf_idx):
        all_stats[inf_idx] = np.max(all_stats[~inf_idx])
    neginf_idx = np.isneginf(all_stats)
    if any(neginf_idx):
        all_stats[neginf_idx] = np.min(all_stats[~neginf_idx])

    LOGGER.info('Computing PR/ROC for {} from {} at {}'.format(
        mod_base, samp_lab, vs_lab))
    # compute roc and precision-recall
    precision, recall, thresh = precision_recall_curve(is_can, all_stats)
    prec_recall_sum = precision + recall
    valid_idx = np.where(prec_recall_sum > 0)
    all_f1 = (2 * precision[valid_idx] * recall[valid_idx] /
              prec_recall_sum[valid_idx])
    optim_f1_idx = np.argmax(all_f1)
    optim_f1 = all_f1[optim_f1_idx]
    optim_thresh = thresh[optim_f1_idx]
    avg_prcn = average_precision_score(is_can, all_stats)

    fpr, tpr, _ = roc_curve(is_can, all_stats)
    roc_auc = auc(fpr, tpr)

    out_fp.write(MOD_VAL_METRICS_TMPLT.format(
        optim_f1, optim_thresh, avg_prcn, roc_auc, mod_stats.shape[0],
        ctrl_stats.shape[0], mod_base, samp_lab, vs_lab))
    pr_data = ('{} at {} mAP={:0.2f}'.format(samp_lab, vs_lab, avg_prcn),
               precision, recall)
    roc_data = ('{} at {} AUC={:0.2f}'.format(samp_lab, vs_lab, roc_auc),
                fpr, tpr)
    kde_data = ('{} from {} at {}'.format(mod_base, samp_lab, vs_lab),
                mod_stats, ctrl_stats)

    return pr_data, roc_data, kde_data
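# Illustrative standalone sketch (not part of the original module): picking
# the F1-optimal threshold from sklearn's precision_recall_curve, mirroring
# the computation above. The toy labels and scores are made up.
def _example_optim_f1_threshold():
    from sklearn.metrics import precision_recall_curve
    labels = np.array([0, 0, 1, 1])
    scores = np.array([0.1, 0.4, 0.35, 0.8])
    precision, recall, thresh = precision_recall_curve(labels, scores)
    prec_recall_sum = precision + recall
    valid_idx = np.where(prec_recall_sum > 0)
    all_f1 = (2 * precision[valid_idx] * recall[valid_idx] /
              prec_recall_sum[valid_idx])
    # note thresh has one fewer entry than precision/recall, as above
    return thresh[np.argmax(all_f1)]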
def test_open_alignment_out_file(out_dir, map_fmt, ref_names_and_lens,
                                 ref_fn):
    try:
        map_fp = open_alignment_out_file(
            out_dir, map_fmt, ref_names_and_lens, ref_fn)
    except ValueError:
        raise mh.MegaError(
            'Failed to open alignment file for writing. Check that ' +
            'reference file is compressed with bgzip for CRAM output.')
    map_fp.close()
    os.remove(map_fp.filename)
def compute_sig_band(bps, levels, bhw=mh.DEFAULT_CONSTRAINED_HALF_BW):
    """Compute band over which to explore possible paths. Band is
    represented in sequence/level coordinates at each signal position.

    Args:
        bps (np.ndarray): Integer array containing breakpoints
        levels (np.ndarray): float array containing expected signal
            levels. May contain np.NAN values. Band will be constructed
            to maintain path through NAN regions.
        bhw (int): Band half width. If None, full matrix is used.

    Returns:
        int32 np.ndarray with shape (2, sig_len = bps[-1] - bps[0]). The
        first row contains the lower band boundaries in sequence
        coordinates and the second row contains the upper boundaries in
        sequence coordinates.
    """
    seq_len = levels.shape[0]
    if bps.shape[0] - 1 != seq_len:
        raise mh.MegaError("Breakpoints must be one longer than levels.")
    sig_len = bps[-1] - bps[0]
    seq_indices = np.repeat(np.arange(seq_len), np.diff(bps))

    # Calculate bands
    # The 1st row consists of the start indices (inc) and the 2nd row
    # consists of the end indices (exc) of the valid rows for each col.
    band = np.empty((2, sig_len), dtype=np.int32)
    if bhw is None:
        # specify entire input matrix
        band[0, :] = 0
        band[1, :] = seq_len
    else:
        # use specific band defined by bhw
        band[0, :] = np.maximum(seq_indices - bhw, 0)
        band[1, :] = np.minimum(seq_indices + bhw + 1, seq_len)

    # Modify bands based on invalid levels
    nan_mask = np.isin(seq_indices, np.nonzero(np.isnan(levels)))
    nan_sig_indices = np.where(nan_mask)[0]
    nan_seq_indices = seq_indices[nan_mask]
    band[0, nan_sig_indices] = nan_seq_indices
    band[1, nan_sig_indices] = nan_seq_indices + 1
    # Modify bands close to invalid levels so monotonically increasing
    band[0, :] = np.maximum.accumulate(band[0, :])
    band[1, :] = np.minimum.accumulate(band[1, ::-1])[::-1]

    # expand band around large deletions to ensure valid paths
    invalid_indices = np.where(band[0, 1:] >= band[1, :-1])[0]
    while invalid_indices.shape[0] > 0:
        band[0, invalid_indices + 1] = np.maximum(
            band[0, invalid_indices + 1] - 1, 0)
        band[1, invalid_indices] = np.minimum(
            band[1, invalid_indices] + 1, seq_len)
        invalid_indices = np.where(band[0, 1:] >= band[1, :-1])[0]

    return band
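# Illustrative standalone sketch (not part of the original module): a toy
# call to compute_sig_band showing how a NaN level pins the band through
# its signal positions. The breakpoints and levels are made up.
def _example_compute_sig_band():
    # 3 levels spanning signal positions 0-9; the middle level is unknown
    bps = np.array([0, 3, 6, 10])
    levels = np.array([0.5, np.nan, -0.2])
    band = compute_sig_band(bps, levels, bhw=1)
    assert band.shape == (2, 10)
    # the NaN level forces signal positions 3-5 through sequence index 1
    assert (band[0, 3:6] == 1).all() and (band[1, 3:6] == 2).all()
    return band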
def open_alignment_out_file(out_dir, map_fmt, ref_names_and_lens, ref_fn):
    map_fn = mh.get_megalodon_fn(out_dir, mh.MAP_NAME) + '.' + map_fmt
    if map_fmt == 'bam':
        w_mode = 'wb'
    elif map_fmt == 'cram':
        w_mode = 'wc'
    elif map_fmt == 'sam':
        w_mode = 'w'
    else:
        raise mh.MegaError('Invalid mapping output format')
    return pysam.AlignmentFile(
        map_fn, w_mode, reference_names=ref_names_and_lens[0],
        reference_lengths=ref_names_and_lens[1],
        reference_filename=ref_fn)