Example #1
        def set_pyguppy_model_attributes():
            init_client = self.pyguppy_GuppyBasecallerClient(
                self.params.pyguppy.config, host=GUPPY_HOST,
                port=self.params.pyguppy.port,
                timeout=PYGUPPY_PER_TRY_TIMEOUT,
                retries=self.pyguppy_retries)
            try:
                init_client.connect()
                init_read = init_client.basecall(
                    ReadData(np.zeros(init_sig_len, dtype=np.int16), 'a'),
                    state=True, trace=True)
            except (TimeoutError, self.zmqAgainError):
                raise mh.MegaError(
                    'Failed to run test read with guppy. See guppy logs in ' +
                    '--output-directory.')
            init_client.disconnect()
            if init_read.model_type not in COMPAT_GUPPY_MODEL_TYPES:
                raise mh.MegaError((
                    'Megalodon is not compatible with guppy model type: ' +
                    '{}').format(init_read.model_type))

            self.stride = init_read.model_stride
            self.ordered_mod_long_names = init_read.mod_long_names
            self.output_alphabet = init_read.mod_alphabet
            self.output_size = init_read.state_size
            if self.ordered_mod_long_names is None:
                self.ordered_mod_long_names = []
            if self.output_alphabet is None:
                self.output_alphabet = mh.ALPHABET
Example #2
    def iter_overlapping_snps(self, r_ref_pos, edge_buffer):
        """Iterator over SNPs overlapping the read mapped position.

        SNPs within edge buffer of the end of the mapping will be ignored.
        """
        if r_ref_pos.end - r_ref_pos.start <= 2 * edge_buffer:
            raise mh.MegaError('Mapped region too short for SNP calling.')
        try:
            fetch_res = self.variants_idx.fetch(r_ref_pos.chrm,
                                                r_ref_pos.start + edge_buffer,
                                                r_ref_pos.end - edge_buffer)
        except ValueError:
            raise mh.MegaError('Mapped location not valid for variants file.')
        for variant in fetch_res:
            snp_ref_seq = variant.ref
            snp_alt_seqs = variant.alts
            # skip SNPs larger than specified limit
            if self.max_indel_size is not None and max(
                    np.abs(len(snp_ref_seq) - len(snp_alt_seq))
                    for snp_alt_seq in snp_alt_seqs) > self.max_indel_size:
                continue
            # convert to 0-based coordinates
            yield snp_ref_seq, snp_alt_seqs, variant.id, variant.pos - 1

        return
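A quick self-contained check of the indel-size filter above (toy values, invented here for illustration):

import numpy as np

snp_ref_seq, snp_alt_seqs, max_indel_size = 'CAT', ('C', 'CATTAG'), 2
indel_size = max(np.abs(len(snp_ref_seq) - len(snp_alt_seq))
                 for snp_alt_seq in snp_alt_seqs)
# largest ref/alt length difference is 3, so this variant would be skipped
print(indel_size, indel_size > max_indel_size)  # 3 True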
Example #3
def check_matching_attrs(ground_truth_bed,
                         strand_offset,
                         mod_db_fn,
                         target_mod_bases,
                         limit=10000):
    mods_db = mods.ModsDb(mod_db_fn)
    db_strands = (1, -1) if strand_offset is None else (None, )
    db_chrms = set((chrm, strand) for _, chrm, _ in mods_db.iter_chrms()
                   for strand in db_strands)
    cov, mod_cov = mh.parse_bed_methyls(
        [ground_truth_bed], strand_offset, show_prog_bar=False, limit=limit)
    if len(db_chrms.intersection(cov.keys())) == 0:
        LOGGER.error(('Using first {} sites from {}, found zero overlapping ' +
                      'contig/chromosome names with the mod database.').format(
                          limit, ground_truth_bed))
        LOGGER.info('Database contigs/chromosomes: {}'.format(', '.join(
            map(str, db_chrms))))
        LOGGER.info('BED methyl contigs/chromosomes: {}'.format(', '.join(
            map(str, list(cov.keys())))))
        raise mh.MegaError('No overlapping contigs found.')
    db_mods = set(mod_base for mod_base, _ in mods_db.get_mod_long_names())
    for tmb in target_mod_bases:
        if tmb not in db_mods:
            raise mh.MegaError(
                ('Target modified base, {}, not found in mods database ' +
                 '({}).').format(tmb, ', '.join(db_mods)))
    mods_db.check_data_covering_index_exists()
    mods_db.close()
Example #4
def check_map_sig_alphabet(model_info, ms_fn):
    # check model and mapped signal file alphabet compatibility
    msf = mapped_signal_files.HDF5Reader(ms_fn)
    tai_alph_info = msf.get_alphabet_information()
    msf.close()
    if model_info.output_alphabet != tai_alph_info.alphabet:
        raise mh.MegaError(
            (
                "Different alphabets specified in model ({}) and mapped "
                + "signal file ({})"
            ).format(model_info.output_alphabet, tai_alph_info.alphabet)
        )
    if set(model_info.can_alphabet) != set(tai_alph_info.collapse_alphabet):
        raise mh.MegaError(
            (
                "Different canonical alphabets specified in model ({}) and "
                + "mapped signal file ({})"
            ).format(model_info.can_alphabet, tai_alph_info.collapse_alphabet)
        )
    if model_info.ordered_mod_long_names != tai_alph_info.mod_long_names:
        raise mh.MegaError(
            (
                "Different modified base long names specified in model ({}) and "
                + "mapped signal file ({})"
            ).format(
                ", ".join(model_info.ordered_mod_long_names),
                ", ".join(tai_alph_info.mod_long_names),
            )
        )
Example #5
    def run_model(self, raw_sig, n_can_state=None):
        if self.model_type == TAI_NAME:
            if any(arg is None for arg in (self.chunk_size, self.chunk_overlap,
                                           self.max_concur_chunks)):
                raise mh.MegaError(
                    'Must provide chunk_size, chunk_overlap, ' +
                    'max_concur_chunks in order to run the taiyaki ' +
                    'base calling backend.')
            try:
                trans_weights = self.tai_run_model(raw_sig, self.model,
                                                   self.chunk_size,
                                                   self.chunk_overlap,
                                                   self.max_concur_chunks)
            except AttributeError:
                raise mh.MegaError('Out of date or incompatible model')
            except RuntimeError:
                raise mh.MegaError('Likely out of memory error.')
            if self.device != self.torch.device('cpu'):
                self.torch.cuda.empty_cache()
            if n_can_state is not None:
                trans_weights = (
                    np.ascontiguousarray(trans_weights[:, :n_can_state]),
                    np.ascontiguousarray(trans_weights[:, n_can_state:]))
        else:
            raise mh.MegaError('Invalid model type.')

        return trans_weights
Example #6
def map_read(q_seq, read_id, caller_conn, signal_reversed=False):
    """Map read (query) sequence and return:
    1) reference sequence (encoded as int labels)
    2) mapping from reference to read positions (after trimming)
    3) reference mapping position (including read trimming positions)
    4) cigar as produced by mappy
    """
    # send seq to _map_read_worker and receive mapped seq and pos
    if signal_reversed:
        q_seq = q_seq[::-1]
    caller_conn.send((q_seq, read_id))
    r_ref_seq, r_algn = caller_conn.recv()
    if r_ref_seq is None:
        raise mh.MegaError('No alignment')
    chrm, strand, r_st, r_en, q_st, q_en, r_cigar = r_algn
    if signal_reversed:
        q_st, q_en = len(q_seq) - q_en, len(q_seq) - q_st
        r_ref_seq = r_ref_seq[::-1]
        r_cigar = r_cigar[::-1]

    try:
        r_to_q_poss = parse_cigar(r_cigar, strand, r_en - r_st)
    except mh.MegaError as e:
        LOGGER.debug('Read {} ::: '.format(read_id) + str(e))
        raise mh.MegaError('Invalid cigar string encountered.')
    r_pos = MAP_POS(chrm=chrm,
                    strand=strand,
                    start=r_st,
                    end=r_en,
                    q_trim_start=q_st,
                    q_trim_end=q_en)

    return r_ref_seq, r_to_q_poss, r_pos, r_cigar
Example #7
def call_alt_true_indel(indel_size, r_snp_pos, true_ref_seq, r_seq,
                        map_thr_buf, context_bases, r_post, rl_cumsum,
                        all_paths):
    def run_aligner():
        return next(
            mappy.Aligner(seq=false_ref_seq, preset=str('map-ont'),
                          best_n=1).map(str(r_seq), buf=map_thr_buf))

    if indel_size == 0:
        false_base = choice(
            list(set(CAN_BASES).difference(true_ref_seq[r_snp_pos])))
        false_ref_seq = (true_ref_seq[:r_snp_pos] + false_base +
                         true_ref_seq[r_snp_pos + 1:])
        snp_ref_seq = false_base
        snp_alt_seq = true_ref_seq[r_snp_pos]
    elif indel_size > 0:
        # test alt truth reference insertion
        false_ref_seq = (true_ref_seq[:r_snp_pos + 1] +
                         true_ref_seq[r_snp_pos + indel_size + 1:])
        snp_ref_seq = true_ref_seq[r_snp_pos]
        snp_alt_seq = true_ref_seq[r_snp_pos:r_snp_pos + indel_size + 1]
    else:
        # test alt truth reference deletion
        deleted_seq = ''.join(choice(CAN_BASES) for _ in range(-indel_size))
        false_ref_seq = (true_ref_seq[:r_snp_pos + 1] + deleted_seq +
                         true_ref_seq[r_snp_pos + 1:])
        snp_ref_seq = true_ref_seq[r_snp_pos] + deleted_seq
        snp_alt_seq = true_ref_seq[r_snp_pos]

    try:
        r_algn = run_aligner()
    except StopIteration:
        raise mh.MegaError('No alignment')

    r_ref_seq = false_ref_seq[r_algn.r_st:r_algn.r_en]
    if r_algn.strand == -1:
        raise mh.MegaError('Indel mapped read mapped to reverse strand.')

    r_to_q_poss = mapping.parse_cigar(r_algn.cigar, r_algn.strand)
    if (r_algn.r_st > r_snp_pos - context_bases[1]
            or r_algn.r_en < r_snp_pos + context_bases[1]):
        raise mh.MegaError('Indel mapped read clipped snp position.')

    post_mapped_start = rl_cumsum[r_algn.q_st]
    mapped_rl_cumsum = rl_cumsum[
        r_algn.q_st:r_algn.q_en + 1] - post_mapped_start

    score = call_snp(r_post,
                     post_mapped_start,
                     r_snp_pos,
                     mapped_rl_cumsum,
                     r_to_q_poss,
                     snp_ref_seq,
                     snp_alt_seq,
                     context_bases,
                     all_paths,
                     ref_seq=r_ref_seq)

    return score, snp_ref_seq, snp_alt_seq
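The false-reference construction above can be traced standalone; a toy run of the indel_size > 0 branch (sequence and positions invented for illustration):

true_ref_seq, r_snp_pos, indel_size = 'ACGTACGT', 2, 2
false_ref_seq = (true_ref_seq[:r_snp_pos + 1] +
                 true_ref_seq[r_snp_pos + indel_size + 1:])
snp_ref_seq = true_ref_seq[r_snp_pos]
snp_alt_seq = true_ref_seq[r_snp_pos:r_snp_pos + indel_size + 1]
# two bases are removed from the reference, so the true allele becomes an
# insertion relative to the false reference
print(false_ref_seq, snp_ref_seq, snp_alt_seq)  # ACGCGT G GTA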
Example #8
    def __init__(self,
                 variant_fn,
                 max_indel_size,
                 all_paths,
                 write_snps_txt,
                 context_bases,
                 snps_calib_fn=None,
                 call_mode=DIPLOID_MODE,
                 do_pr_ref_snps=False,
                 aligner=None,
                 keep_snp_fp_open=False):
        logger = logging.get_logger('snps')
        self.max_indel_size = max_indel_size
        self.all_paths = all_paths
        self.write_snps_txt = write_snps_txt
        self.snps_calib_fn = snps_calib_fn
        self.calib_table = calibration.SnpCalibrator(self.snps_calib_fn)
        self.context_bases = context_bases
        if len(self.context_bases) != 2:
            raise mh.MegaError(
                'Must provide 2 context bases values (for single base SNPs ' +
                'and indels).')
        self.call_mode = call_mode
        self.do_pr_ref_snps = do_pr_ref_snps
        self.variant_fn = variant_fn
        self.variants_idx = None
        if self.variant_fn is None:
            return

        logger.info('Loading variants.')
        vars_idx = pysam.VariantFile(self.variant_fn)
        try:
            contigs = list(vars_idx.header.contigs.keys())
            vars_idx.fetch(next(iter(contigs)), 0, 0)
        except ValueError:
            logger.warning(
                'Variants file must be indexed. Performing indexing now.')
            vars_idx.close()
            self.variant_fn = index_variants(self.variant_fn)
            vars_idx = pysam.VariantFile(self.variant_fn)
        if keep_snp_fp_open:
            self.variants_idx = vars_idx
        else:
            vars_idx.close()
            self.variants_idx = None

        if aligner is None:
            raise mh.MegaError(
                'Must provide aligner if SNP filename is provided')
        if len(set(aligner.ref_names_and_lens[0]).intersection(contigs)) == 0:
            raise mh.MegaError((
                'Reference and variant files contain no chromosomes/contigs ' +
                'in common.\n\t\tFirst 3 reference contigs:\t{}\n\t\tFirst 3 '
                + 'variant file contigs:\t{}').format(
                    ', '.join(aligner.ref_names_and_lens[0][:3]),
                    ', '.join(contigs[:3])))

        return
Example #9
    def compute_snp_stats(self,
                          snp_loc,
                          het_factors,
                          call_mode=DIPLOID_MODE,
                          valid_read_ids=None):
        assert call_mode in (HAPLIOD_MODE, DIPLOID_MODE), (
            'Invalid SNP aggregation ploidy call mode: {}.'.format(call_mode))

        pr_snp_stats = self.get_per_read_snp_stats(snp_loc)
        alt_seqs = sorted(set(r_stats.alt_seq for r_stats in pr_snp_stats))
        pr_alt_lps = defaultdict(dict)
        for r_stats in pr_snp_stats:
            if (valid_read_ids is not None
                    and r_stats.read_id not in valid_read_ids):
                continue
            pr_alt_lps[r_stats.read_id][r_stats.alt_seq] = r_stats.score
        if len(pr_alt_lps) == 0:
            raise mh.MegaError('No valid reads cover SNP')

        alt_seq_lps = [[] for _ in range(len(alt_seqs))]
        for read_lps in pr_alt_lps.values():
            for i, alt_seq in enumerate(alt_seqs):
                try:
                    alt_seq_lps[i].append(read_lps[alt_seq])
                except KeyError:
                    raise mh.MegaError(
                        'Alternative SNP sequence must exist for all reads.')
        alts_lps = np.stack(alt_seq_lps, axis=0)
        with np.errstate(all='ignore'):
            ref_lps = np.log1p(-np.exp(alts_lps).sum(axis=0))

        r0_stats = pr_snp_stats[0]
        snp_var = Variant(chrom=r0_stats.chrm,
                          pos=r0_stats.pos,
                          ref=r0_stats.ref_seq,
                          alts=alt_seqs,
                          id=r0_stats.snp_id)
        snp_var.add_tag('DP', '{}'.format(ref_lps.shape[0]))
        snp_var.add_sample_field('DP', '{}'.format(ref_lps.shape[0]))

        if self.write_vcf_log_probs:
            snp_var.add_sample_field(
                'LOG_PROBS', ','.join(';'.join('{:.2f}'.format(lp)
                                               for lp in alt_i_lps)
                                      for alt_i_lps in alts_lps))

        if call_mode == DIPLOID_MODE:
            het_factor = (het_factors[0] if len(r0_stats.ref_seq) == 1
                          and len(r0_stats.alt_seq) == 1 else het_factors[1])
            diploid_probs, gts = self.compute_diploid_probs(
                ref_lps, alts_lps, het_factor)
            snp_var.add_diploid_probs(diploid_probs, gts)
        elif call_mode == HAPLIOD_MODE:
            haploid_probs, gts = self.compute_haploid_probs(ref_lps, alts_lps)
            snp_var.add_haploid_probs(haploid_probs, gts)

        return snp_var
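The reference log-probability line above can be sanity-checked in isolation: for each read, P(ref) = 1 - sum over alts of P(alt), computed in log space. A toy example with two alternate alleles and three reads (numbers invented):

import numpy as np

# rows: alt alleles, columns: reads; entries are log P(alt | read)
alts_lps = np.log([[0.1, 0.6, 0.2],
                   [0.2, 0.1, 0.1]])
with np.errstate(all='ignore'):
    ref_lps = np.log1p(-np.exp(alts_lps).sum(axis=0))
print(np.exp(ref_lps))  # [0.7 0.3 0.7]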
Example #10
def map_read(
        caller_conn, called_read, sig_info, mo_q=None, signal_reversed=False,
        rl_cumsum=None):
    """ Map read (query) sequence

    Returns:
        Tuple containing
            1) reference sequence (encoded as int labels)
            2) mapping from reference to read positions (after trimming)
            3) reference mapping position (including read trimming positions)
            4) cigar as produced by mappy
    """
    # send seq to _map_read_worker and receive mapped seq and pos
    q_seq = called_read.seq[::-1] if signal_reversed else called_read.seq
    caller_conn.send((q_seq, sig_info.read_id))
    map_res = caller_conn.recv()
    if map_res is None:
        raise mh.MegaError('No alignment')
    map_res = MAP_RES(*map_res)
    # add signal coordinates to mapping output if run-length cumsum provided
    if rl_cumsum is not None:
        # convert query start and end to signal-anchored locations
        # Note that for signal_reversed reads, the start will be larger than
        # the end
        q_st = (len(map_res.q_seq) - map_res.q_st if signal_reversed
                else map_res.q_st)
        q_en = (len(map_res.q_seq) - map_res.q_en if signal_reversed
                else map_res.q_en)
        map_res = map_res._replace(
            map_sig_start=called_read.trimmed_samples +
            rl_cumsum[q_st] * sig_info.stride,
            map_sig_end=called_read.trimmed_samples +
            rl_cumsum[q_en] * sig_info.stride,
            sig_len=called_read.trimmed_samples +
            rl_cumsum[-1] * sig_info.stride)
    if mo_q is not None:
        mo_q.put(tuple(map_res))
    if signal_reversed:
        # if signal is reversed compared to mapping, reverse coordinates so
        # they are relative to signal/state_data
        map_res = map_res._replace(
            q_st=len(map_res.q_seq) - map_res.q_en,
            q_en=len(map_res.q_seq) - map_res.q_st,
            ref_seq=map_res.ref_seq[::-1],
            cigar=map_res.cigar[::-1])

    try:
        r_to_q_poss = parse_cigar(
            map_res.cigar, map_res.strand, map_res.r_en - map_res.r_st)
    except mh.MegaError as e:
        LOGGER.debug('{} CigarParsingError'.format(sig_info.read_id) + str(e))
        raise mh.MegaError('Invalid cigar string encountered.')
    map_pos = get_map_pos_from_res(map_res)

    return map_res.ref_seq, r_to_q_poss, map_pos, map_res.cigar
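A toy trace of the signal-anchoring arithmetic above (array and constants invented; rl_cumsum maps base index to neural network state index, and stride maps state index to raw samples):

import numpy as np

rl_cumsum = np.array([0, 2, 3, 6, 8])    # 4 called bases
stride, trimmed_samples = 5, 100
q_st, q_en = 1, 3                        # mapped query interval
map_sig_start = trimmed_samples + rl_cumsum[q_st] * stride
map_sig_end = trimmed_samples + rl_cumsum[q_en] * stride
print(map_sig_start, map_sig_end)        # 110 130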
Example #11
    def run_pyguppy_model(
            self, sig_info, return_post_w_mods, return_mod_scores,
            update_sig_info):
        if self.model_type != PYGUPPY_NAME:
            raise mh.MegaError(
                'Attempted to run pyguppy model with non-pyguppy ' +
                'initialization.')

        post_w_mods = mods_scores = None
        try:
            called_read = self.client.basecall(
                self.pyguppy_ReadData(sig_info.dacs, sig_info.read_id),
                state=True, trace=True)
        except TimeoutError:
            raise mh.MegaError(
                'Pyguppy server timeout. See --guppy-timeout option')

        # compute rl_cumsum from move table
        rl_cumsum = np.where(called_read.move)[0]
        rl_cumsum = np.insert(rl_cumsum, rl_cumsum.shape[0],
                              called_read.move.shape[0])

        if self.is_cat_mod:
            # split canonical posteriors and mod transition weights
            can_post = np.ascontiguousarray(
                called_read.state[:, :self.n_can_state])
            if return_mod_scores or return_post_w_mods:
                mods_weights = self._softmax_mod_weights(
                    called_read.state[:, self.n_can_state:])
                if return_post_w_mods:
                    post_w_mods = np.concatenate(
                        [can_post, mods_weights], axis=1)
                if return_mod_scores:
                    # TODO apply np.NAN mask to scores not applicable to
                    # canonical basecalls
                    mods_scores = np.ascontiguousarray(
                        mods_weights[rl_cumsum[:-1]])
        else:
            can_post = called_read.state

        if update_sig_info:
            # add scale_params and trimmed dacs to sig_info
            trimmed_dacs = sig_info.dacs[called_read.trimmed_samples:]
            # guppy does not apply the med norm factor
            scale_params = (
                called_read.scaling['median'],
                called_read.scaling['med_abs_dev'] * mh.MED_NORM_FACTOR)
            sig_info = sig_info._replace(
                raw_len=trimmed_dacs.shape[0], dacs=trimmed_dacs,
                raw_signal=((trimmed_dacs - scale_params[0]) /
                            scale_params[1]).astype(np.float32),
                scale_params=scale_params)

        return (called_read.seq, called_read.qual, rl_cumsum, can_post,
                sig_info, post_w_mods, mods_scores)
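The move-table handling at the top of this method can be checked with a toy array (invented move table; a 1 marks states where a new base was emitted):

import numpy as np

move = np.array([1, 0, 1, 1, 0, 0, 1, 0])
rl_cumsum = np.where(move)[0]
rl_cumsum = np.insert(rl_cumsum, rl_cumsum.shape[0], move.shape[0])
print(rl_cumsum)  # [0 2 3 6 8]: first state of each base call, plus the end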
Example #12
        def start_guppy_server():
            def get_server_port():
                next_line = guppy_out_read_fp.readline()
                # readline returns an empty string when no new output exists
                if not next_line:
                    return None
                try:
                    return int(GUPPY_PORT_PAT.search(next_line).groups()[0])
                except AttributeError:
                    return None

            # set guppy logs output locations
            self.guppy_log = os.path.join(
                self.params.pyguppy.out_dir, GUPPY_LOG_BASE)
            self.guppy_out_fp = open(self.guppy_log + '.out', 'w')
            guppy_out_read_fp = open(self.guppy_log + '.out', 'r')
            self.guppy_err_fp = open(self.guppy_log + '.err', 'w')
            # prepare args to start guppy server
            server_args = [
                self.params.pyguppy.bin_path,
                '-p', str(self.params.pyguppy.port),
                '-l', self.guppy_log,
                '-c', self.params.pyguppy.config,
                '--post_out']
            if self.params.pyguppy.devices is not None and \
               len(self.params.pyguppy.devices) > 0 and \
               self.params.pyguppy.devices[0] != 'cpu':
                devices_str = ' '.join(
                    parse_device(device) for device in
                    self.params.pyguppy.devices)
                server_args.extend(('-x', devices_str))
            if self.params.pyguppy.server_params is not None:
                server_args.extend(self.params.pyguppy.server_params.split())
            try:
                self.guppy_server_proc = subprocess.Popen(
                    server_args, shell=False,
                    stdout=self.guppy_out_fp, stderr=self.guppy_err_fp)
            except FileNotFoundError:
                raise mh.MegaError(
                    'Guppy server executable not found. Please specify path ' +
                    'via `--guppy-server-path` argument.')
            # wait until server is successfully started or fails
            while True:
                used_port = get_server_port()
                if used_port is not None:
                    break
                if self.guppy_server_proc.poll() is not None:
                    raise mh.MegaError(
                        'Guppy server initialization failed. See guppy logs ' +
                        'in --output-directory for more details.')
                sleep(0.01)
            guppy_out_read_fp.close()
            self.params = self.params._replace(
                pyguppy=self.params.pyguppy._replace(port=used_port))
Example #13
def extract_mods(in_mod_db_fns):
    LOGGER.info('Merging mod tables')
    # collect modified base definitions from input databases
    mod_base_to_can = dict()
    for in_mod_db_fn in tqdm(in_mod_db_fns,
                             desc='Databases',
                             unit='DBs',
                             smoothing=0,
                             dynamic_ncols=True):
        mods_db = mods.ModsDb(in_mod_db_fn)
        for mod_base, can_base, mln in mods_db.get_full_mod_data():
            if mod_base in mod_base_to_can and \
               (can_base, mln) != mod_base_to_can[mod_base]:
                raise mh.MegaError(
                    'Modified base associated with multiple canonical bases ' +
                    'or long names in different databases. {} != {}'.format(
                        str((can_base, mln)), str(mod_base_to_can[mod_base])))
            mod_base_to_can[mod_base] = (can_base, mln)
    # check that mod long names are unique
    mlns = [mln for _, mln in mod_base_to_can.values()]
    if len(mlns) != len(set(mlns)):
        raise mh.MegaError(
            'Modified base long name assigned to more than one modified ' +
            'base single letter code.')

    # extract canonical bases associated with modified base
    can_bases = set(can_base for can_base, _ in mod_base_to_can.values())
    # determine first valid canonical alphabet compatible with database
    can_alphabet = None
    for v_alphabet in mh.VALID_ALPHABETS:
        if len(can_bases.difference(v_alphabet)) == 0:
            can_alphabet = v_alphabet
            break
    if can_alphabet is None:
        LOGGER.error(
            'Mods database does not contain valid canonical bases ({})'.format(
                ''.join(sorted(can_bases))))
        raise mh.MegaError('Invalid alphabet.')

    # compute full output alphabet and ordered modified base long names
    can_base_to_mods = dict(
        (can_base, [(mod_base, mln)
                    for mod_base, (mcan_base, mln) in mod_base_to_can.items()
                    if mcan_base == can_base]) for can_base in can_alphabet)
    alphabet = ''
    mod_long_names = []
    for can_base in can_alphabet:
        alphabet += can_base
        for mod_base, mln in can_base_to_mods[can_base]:
            alphabet += mod_base
            mod_long_names.append(mln)

    return alphabet, mod_long_names
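To make the final alphabet assembly concrete, a standalone sketch with invented modified-base definitions (the mod codes and long names here are illustrative):

mod_base_to_can = {'m': ('C', '5mC'), 'a': ('A', '6mA')}
can_alphabet = 'ACGT'
alphabet = ''
mod_long_names = []
for can_base in can_alphabet:
    alphabet += can_base
    for mod_base, (mcan_base, mln) in mod_base_to_can.items():
        if mcan_base == can_base:
            alphabet += mod_base
            mod_long_names.append(mln)
print(alphabet, mod_long_names)  # AaCmGT ['6mA', '5mC']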
Example #14
def get_remapping(sig_fn, dacs, scale_params, ref_seq, stride, read_id,
                  r_to_q_poss, rl_cumsum, r_ref_pos, ref_out_info):
    read = fast5_interface.get_fast5_file(sig_fn, 'r').get_read(read_id)
    channel_info = dict(fast5utils.get_channel_info(read).items())
    rd_factor = channel_info['range'] / channel_info['digitisation']
    shift_from_pA = (scale_params[0] + channel_info['offset']) * rd_factor
    scale_from_pA = scale_params[1] * rd_factor
    read_attrs = dict(fast5utils.get_read_attributes(read).items())

    # prepare taiyaki signal object
    sig = tai_signal.Signal(dacs=dacs)
    sig.channel_info = channel_info
    sig.read_attributes = read_attrs
    sig.offset = channel_info['offset']
    sig.range = channel_info['range']
    sig.digitisation = channel_info['digitisation']

    path = np.full((dacs.shape[0] // stride) + 1, -1)
    # skip last value since this is where the two seqs end
    for ref_pos, q_pos in enumerate(r_to_q_poss[:-1]):
        # if the query position maps to the end of the mapping skip it
        if rl_cumsum[q_pos + r_ref_pos.q_trim_start] >= path.shape[0]:
            continue
        path[rl_cumsum[q_pos + r_ref_pos.q_trim_start]] = ref_pos
    remapping = tai_mapping.Mapping.from_remapping_path(
        sig, path, ref_seq, stride)
    try:
        remapping.add_integer_reference(ref_out_info.alphabet)
    except Exception:
        raise mh.MegaError('Invalid reference sequence encountered')

    return (remapping.get_read_dictionary(shift_from_pA, scale_from_pA,
                                          read_id),
            prepare_mapping_funcs.RemapResult.SUCCESS)
Example #15
    def get_read_id(self, uuid):
        try:
            read_id = self.cur.execute('SELECT read_id FROM read WHERE uuid=?',
                                       (uuid, )).fetchone()[0]
        except TypeError:
            raise mh.MegaError('Read ID not found in mods database.')
        return read_id
Example #16
    def prep_model_worker(self, device):
        """ Load model onto a newly spawned process
        """
        if self.model_type == TAI_NAME:
            # setup for taiyaki model
            self.model = self.load_taiyaki_model(
                self.params.taiyaki.taiyaki_model_fn)
            if device is None or device == 'cpu':
                self.device = self.torch.device('cpu')
            else:
                sleep(np.random.uniform(0, MAX_DEVICE_WAIT))
                try:
                    self.device = self.torch.device(device)
                    self.torch.cuda.set_device(self.device)
                    self.model = self.model.to(self.device)
                except RuntimeError:
                    LOGGER.error('Invalid CUDA device: {}'.format(device))
                    raise mh.MegaError('Error setting CUDA GPU device.')
            self.model = self.model.eval()
        elif self.model_type == PYGUPPY_NAME:
            # open guppy client interface (None indicates using config
            # from server)
            self.client = self.pyguppy_GuppyBasecallerClient(
                self.params.pyguppy.config, host=GUPPY_HOST,
                port=self.params.pyguppy.port, timeout=PYGUPPY_PER_TRY_TIMEOUT,
                retries=self.pyguppy_retries)
            self.client.connect()

        return
Example #17
def main():
    args = get_parser().parse_args()

    vars0_idx = pysam.VariantFile(args.diploid_called_variants)
    vars1_idx = pysam.VariantFile(args.haplotype1_variants)
    vars2_idx = pysam.VariantFile(args.haplotype2_variants)
    try:
        contigs0 = list(vars0_idx.header.contigs.keys())
        vars0_idx.fetch(next(iter(contigs0)), 0, 0)
        contigs1 = list(vars1_idx.header.contigs.keys())
        vars1_idx.fetch(next(iter(contigs1)), 0, 0)
        contigs2 = list(vars2_idx.header.contigs.keys())
        vars2_idx.fetch(next(iter(contigs2)), 0, 0)
    except ValueError:
        raise mh.MegaError(
            'Variants file must be indexed. Use bgzip and tabix.')

    out_vars = open(args.out_vcf, 'w')
    out_vars.write(
        HEADER.format('\n'.join(
            (CONTIG_HEADER_LINE.format(ctg.name, ctg.length)
             for ctg in vars0_idx.header.contigs.values()))))
    for contig in set(contigs0).intersection(contigs1).intersection(contigs2):
        for curr_v0_rec, curr_v1_rec, curr_v2_rec in iter_contig_vars(
                iter(vars0_idx.fetch(contig)), iter(vars1_idx.fetch(contig)),
                iter(vars2_idx.fetch(contig))):
            if curr_v0_rec is None:
                continue
            write_var(curr_v0_rec, curr_v1_rec, curr_v2_rec, out_vars, contig)

    out_vars.close()

    variants.index_variants(args.out_vcf)
Example #18
def extract_llrs(llr_fn, max_indel_len=None):
    snp_ref_llrs, ins_ref_llrs, del_ref_llrs = (defaultdict(list)
                                                for _ in range(3))
    with open(llr_fn) as llr_fp:
        for line in llr_fp:
            is_ref_correct, llr, ref_seq, alt_seq = line.split()
            llr = float(llr)
            if is_ref_correct != 'True':
                continue
            if np.isnan(llr):
                continue
            if max_indel_len is not None and \
               np.abs(len(ref_seq) - len(alt_seq)) > max_indel_len:
                continue
            if len(ref_seq) == 1 and len(alt_seq) == 1:
                snp_ref_llrs[(ref_seq, alt_seq)].append(llr)
            else:
                if len(ref_seq) > len(alt_seq):
                    del_ref_llrs[len(ref_seq) - len(alt_seq)].append(llr)
                else:
                    ins_ref_llrs[len(alt_seq) - len(ref_seq)].append(llr)

    if min(len(snp_ref_llrs), len(ins_ref_llrs), len(del_ref_llrs)) == 0:
        raise mh.MegaError(
            'Variant statistics file does not contain sufficient data for ' +
            'calibration.')

    return snp_ref_llrs, ins_ref_llrs, del_ref_llrs
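A hedged usage sketch, assuming extract_llrs as defined above; it writes a toy statistics file with the four whitespace-separated columns the parser expects (real files would hold many more records per category):

import tempfile

toy_lines = ['True -1.5 A C',    # single-base SNP
             'True 0.8 A ACC',   # 2-base insertion
             'True -0.3 AGG A',  # 2-base deletion
             'False 2.0 A C']    # not reference-correct, skipped
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as fp:
    fp.write('\n'.join(toy_lines))
snp_llrs, ins_llrs, del_llrs = extract_llrs(fp.name)
print(dict(snp_llrs), dict(ins_llrs), dict(del_llrs))
# {('A', 'C'): [-1.5]} {2: [0.8]} {2: [-0.3]}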
Example #19
def parse_cigar(r_cigar, strand, ref_len):
    fill_invalid = -1
    # get each base calls genomic position
    r_to_q_poss = np.full(ref_len + 1, fill_invalid, dtype=np.int32)
    # process cigar ops in read direction
    curr_r_pos, curr_q_pos = 0, 0
    cigar_ops = r_cigar if strand == 1 else r_cigar[::-1]
    for op_len, op in cigar_ops:
        if op == 1:
            # inserted bases into ref
            curr_q_pos += op_len
        elif op in (2, 3):
            # deleted ref bases
            for r_pos in range(curr_r_pos, curr_r_pos + op_len):
                r_to_q_poss[r_pos] = curr_q_pos
            curr_r_pos += op_len
        elif op in (0, 7, 8):
            # aligned bases
            for op_offset in range(op_len):
                r_to_q_poss[curr_r_pos + op_offset] = curr_q_pos + op_offset
            curr_q_pos += op_len
            curr_r_pos += op_len
        elif op == 6:
            # padding (shouldn't happen in mappy)
            pass
    r_to_q_poss[curr_r_pos] = curr_q_pos
    if r_to_q_poss[-1] == fill_invalid:
        raise mh.MegaError(
            ('Invalid cigar string encountered. Reference length: {}  Cigar ' +
             'implied reference length: {}').format(ref_len, curr_r_pos))

    return r_to_q_poss
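A worked toy call of parse_cigar (assuming the function above; the CIGAR is given as (length, op) tuples in mappy's numeric encoding, 0=match, 1=insertion, 2=deletion):

toy_cigar = [(3, 0), (1, 1), (2, 2), (2, 0)]   # 3M 1I 2D 2M
print(parse_cigar(toy_cigar, strand=1, ref_len=7))
# [0 1 2 4 4 4 5 6]: query position for each reference position, plus one
# entry past the end; deleted reference bases point at the next query base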
Example #20
def map_read(caller_conn,
             called_read,
             sig_info,
             mo_q=None,
             signal_reversed=False,
             rl_cumsum=None):
    """ Map read (query) sequence

    Returns:
        Tuple containing
            1) reference sequence (encoded as int labels)
            2) mapping from reference to read positions (after trimming)
            3) reference mapping position (including read trimming positions)
            4) cigar as produced by mappy
    """
    # send seq to _map_read_worker and receive mapped seq and pos
    q_seq = called_read.seq[::-1] if signal_reversed else called_read.seq
    caller_conn.send((q_seq, sig_info.read_id))
    map_ress = caller_conn.recv()
    if map_ress is None:
        raise mh.MegaError('No alignment')

    return [
        process_mapping(map_res, called_read, sig_info, mo_q, signal_reversed,
                        rl_cumsum) for map_res in map_ress
    ]
Example #21
    def _load_fast5_post_out(self):
        def get_model_info_from_fast5(read):
            try:
                stride = fast5_io.get_stride(read)
                mod_long_names, out_alphabet = fast5_io.get_mod_base_info(read)
                out_size = fast5_io.get_posteriors(read).shape[1]
                mod_long_names = mod_long_names.split()
            except KeyError:
                LOGGER.error(
                    'Fast5 read does not contain required attributes.')
                raise mh.MegaError(
                    'Fast5 read does not contain required attributes.')
            return stride, mod_long_names, out_alphabet, out_size

        LOGGER.info('Loading FAST5 basecalling backend.')
        self.model_type = FAST5_NAME
        self.process_devices = [None, ] * self.num_proc

        read_iter = fast5_io.iterate_fast5_reads(self.params.fast5.fast5s_dir)
        nreads = 0
        try:
            fast5_fn, read_id = next(read_iter)
            read = fast5_io.get_read(fast5_fn, read_id)
            (self.stride, self.ordered_mod_long_names, self.output_alphabet,
             self.output_size) = get_model_info_from_fast5(read)
        except StopIteration:
            LOGGER.error('No reads found.')
            raise mh.MegaError('No reads found.')

        for fast5_fn, read_id in read_iter:
            read = fast5_io.get_read(fast5_fn, read_id)
            r_s, r_omln, r_oa, r_os = get_model_info_from_fast5(read)
            if (
                    self.stride != r_s or
                    self.ordered_mod_long_names != r_omln or
                    self.output_alphabet != r_oa or
                    self.output_size != r_os):
                LOGGER.error(
                    'Model information from FAST5 files is inconsistent. ' +
                    'Ensure all reads were called with the same model.')
                raise mh.MegaError(
                    'Model information from FAST5 files is inconsistent.')
            nreads += 1
            if nreads >= self.params.fast5.num_startup_reads:
                break

        self._parse_minimal_alphabet_info()
Example #22
    def get_chrm(self, chrm_id):
        try:
            chrm = self.cur.execute('SELECT chrm FROM chrm WHERE chrm_id=?',
                                    (chrm_id, )).fetchone()[0]
        except TypeError:
            raise mh.MegaError('Reference record (chromosome) not found in ' +
                               'mods database.')
        return chrm
Example #23
def get_mapping_mode(map_fmt):
    if map_fmt == 'bam':
        return 'wb'
    elif map_fmt == 'cram':
        return 'wc'
    elif map_fmt == 'sam':
        return 'w'
    raise mh.MegaError('Invalid mapping output format: {}'.format(map_fmt))
Example #24
def get_mapping_mode(map_fmt):
    if map_fmt == "bam":
        return "wb"
    elif map_fmt == "cram":
        return "wc"
    elif map_fmt == "sam":
        return "w"
    raise mh.MegaError("Invalid mapping output format: {}".format(map_fmt))
Example #25
    def __init__(
            self, conn, size, max_size=mh._MAX_QUEUE_SIZE, name='ConnWithSize',
            full_sleep_time=_FULL_SLEEP_TIME):
        if not isinstance(conn, mp.connection.Connection):
            raise mh.MegaError((
                'ConnWithSize initialized with non-connection object. ' +
                'Object type: {}').format(type(conn)))
        if not (isinstance(size, mp.sharedctypes.Synchronized) and
                isinstance(size.value, int)):
            raise mh.MegaError((
                'ConnWithSize initialized with non-synchronized size ' +
                'object. Object type: {}').format(type(size)))
        self._conn = conn
        self._size = size
        self.max_size = max_size
        self.full_sleep_time = full_sleep_time
        self.name = name
Example #26
    def basecall_read(
            self, sig_info, return_post_w_mods=True, return_mod_scores=False,
            update_sig_info=False):
        if self.model_type not in (TAI_NAME, FAST5_NAME, PYGUPPY_NAME):
            raise mh.MegaError('Invalid model backend')

        # decoding is performed within pyguppy server, so short-circuit return
        # here as other methods require megalodon decoding.
        if self.model_type == PYGUPPY_NAME:
            return self.run_pyguppy_model(
                sig_info, return_post_w_mods, return_mod_scores,
                update_sig_info)

        post_w_mods = mod_weights = None
        if self.model_type == TAI_NAME:
            # run neural network with taiyaki
            if self.is_cat_mod:
                bc_weights, mod_weights = self.run_taiyaki_model(
                    sig_info.raw_signal, self.n_can_state)
            else:
                bc_weights = self.run_taiyaki_model(sig_info.raw_signal)
            # perform forward-backward algorithm on neural net output
            can_post = decode.crf_flipflop_trans_post(bc_weights, log=True)
            if return_post_w_mods and self.is_cat_mod:
                post_w_mods = np.concatenate([can_post, mod_weights], axis=1)
            # set mod_weights to None if mod_scores not requested to
            # avoid extra computation
            if not return_mod_scores:
                mod_weights = None
        else:
            # FAST5 stored posteriors backend
            if self.is_cat_mod:
                # split canonical posteriors and mod transition weights
                # producing desired return arrays
                can_post = np.ascontiguousarray(
                    sig_info.posteriors[:, :self.n_can_state])
                if return_mod_scores or return_post_w_mods:
                    # convert raw neural net mod weights to softmax weights
                    mod_weights = self._softmax_mod_weights(
                        sig_info.posteriors[:, self.n_can_state:])
                    if return_post_w_mods:
                        post_w_mods = np.concatenate(
                            [can_post, mod_weights], axis=1)
                    if not return_mod_scores:
                        mod_weights = None
            else:
                can_post = sig_info.posteriors

        # decode posteriors to sequence and per-base mod scores
        r_seq, _, rl_cumsum, mods_scores = decode.decode_post(
            can_post, self.can_alphabet, mod_weights, self.can_nmods)
        # TODO implement quality extraction for taiyaki and fast5 modes
        r_qual = None

        return (r_seq, r_qual, rl_cumsum, can_post, sig_info, post_w_mods,
                mods_scores)
Example #27
def compute_mod_sites_stats(
        mod_stats, ctrl_stats, balance_classes, mod_base, samp_lab, vs_lab,
        out_fp):
    if balance_classes:
        # randomly downsample sample with more observations
        if mod_stats.shape[0] > ctrl_stats.shape[0]:
            mod_stats = np.random.choice(
                mod_stats, ctrl_stats.shape[0], replace=False)
        elif mod_stats.shape[0] < ctrl_stats.shape[0]:
            ctrl_stats = np.random.choice(
                ctrl_stats, mod_stats.shape[0], replace=False)

    is_can = np.repeat([0, 1], [mod_stats.shape[0], ctrl_stats.shape[0]])
    all_stats = np.concatenate([mod_stats, ctrl_stats])
    nan_idx = np.isnan(all_stats)
    if any(nan_idx):
        LOGGER.warning('Encountered {} NaN modified base scores.'.format(
            sum(nan_idx)))
        all_stats = all_stats[~nan_idx]
        if all_stats.shape[0] == 0:
            raise mh.MegaError('All modified base scores are NaN')
    # clip +/- infinite scores to the most extreme finite values
    posinf_idx = np.isposinf(all_stats)
    if any(posinf_idx):
        all_stats[posinf_idx] = np.max(all_stats[~np.isinf(all_stats)])
    neginf_idx = np.isneginf(all_stats)
    if any(neginf_idx):
        all_stats[neginf_idx] = np.min(all_stats[~np.isinf(all_stats)])
    LOGGER.info(
        'Computing PR/ROC for {} from {} at {}'.format(
            mod_base, samp_lab, vs_lab))
    # compute roc and precision recall
    precision, recall, thresh = precision_recall_curve(is_can, all_stats)
    prec_recall_sum = precision + recall
    valid_idx = np.where(prec_recall_sum > 0)
    all_f1 = (2 * precision[valid_idx] * recall[valid_idx] /
              prec_recall_sum[valid_idx])
    optim_f1_idx = np.argmax(all_f1)
    optim_f1 = all_f1[optim_f1_idx]
    optim_thresh = thresh[optim_f1_idx]
    avg_prcn = average_precision_score(is_can, all_stats)

    fpr, tpr, _ = roc_curve(is_can, all_stats)
    roc_auc = auc(fpr, tpr)

    out_fp.write(
        MOD_VAL_METRICS_TMPLT.format(
            optim_f1, optim_thresh, avg_prcn, roc_auc, mod_stats.shape[0],
            ctrl_stats.shape[0], mod_base, samp_lab, vs_lab))
    pr_data = ('{} at {} mAP={:0.2f}'.format(
        samp_lab, vs_lab, avg_prcn), precision, recall)
    roc_data = ('{} at {} AUC={:0.2f}'.format(
        samp_lab, vs_lab, roc_auc), fpr, tpr)
    kde_data = ('{} from {} at {}'.format(
        mod_base, samp_lab, vs_lab), mod_stats, ctrl_stats)

    return pr_data, roc_data, kde_data
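The optimal-F1 selection above can be exercised on synthetic scores; a minimal sketch with numpy and scikit-learn (toy distributions invented here):

import numpy as np
from sklearn.metrics import precision_recall_curve

rng = np.random.default_rng(42)
# modified reads score low, canonical reads high, mirroring is_can above
all_stats = np.concatenate([rng.normal(-2, 1, 500), rng.normal(2, 1, 500)])
is_can = np.repeat([0, 1], [500, 500])
precision, recall, thresh = precision_recall_curve(is_can, all_stats)
prec_recall_sum = precision + recall
valid_idx = np.where(prec_recall_sum > 0)
all_f1 = (2 * precision[valid_idx] * recall[valid_idx] /
          prec_recall_sum[valid_idx])
optim_f1_idx = np.argmax(all_f1)
# thresh is one element shorter than precision/recall, so clamp the index
print('optimal F1 {:.3f} at threshold {:.2f}'.format(
    all_f1[optim_f1_idx], thresh[min(optim_f1_idx, thresh.shape[0] - 1)]))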
Example #28
def test_open_alignment_out_file(out_dir, map_fmt, ref_names_and_lens, ref_fn):
    try:
        map_fp = open_alignment_out_file(out_dir, map_fmt, ref_names_and_lens,
                                         ref_fn)
    except ValueError:
        raise mh.MegaError(
            'Failed to open alignment file for writing. Check that ' +
            'reference file is compressed with bgzip for CRAM output.')
    map_fp.close()
    os.remove(map_fp.filename)
Example #29
def compute_sig_band(bps, levels, bhw=mh.DEFAULT_CONSTRAINED_HALF_BW):
    """Compute band over which to explore possible paths. Band is represented
    in sequence/level coordinates at each signal position.

    Args:
        bps (np.ndarray): Integer array containing breakpoints
        levels (np.ndarray): float array containing expected signal levels. May
            contain np.NAN values. Band will be constructed to maintain path
            through NAN regions.
        bhw (int): Band half width. If None, full matrix is used.

    Returns:
        int32 np.ndarray with shape (2, sig_len = bps[-1] - bps[0]). The first
        row contains the lower band boundaries in sequence coordinates and the
        second row contains the upper boundaries in sequence coordinates.
    """
    seq_len = levels.shape[0]
    if bps.shape[0] - 1 != seq_len:
        raise mh.MegaError("Breakpoints must be one longer than levels.")
    sig_len = bps[-1] - bps[0]
    seq_indices = np.repeat(np.arange(seq_len), np.diff(bps))

    # Calculate bands
    # The 1st row consists of the start indices (inc) and the 2nd row
    # consists of the end indices (exc) of the valid rows for each col.
    band = np.empty((2, sig_len), dtype=np.int32)
    if bhw is None:
        # specify entire input matrix
        band[0, :] = 0
        band[1, :] = seq_len
    else:
        # use specific band defined by bhw
        band[0, :] = np.maximum(seq_indices - bhw, 0)
        band[1, :] = np.minimum(seq_indices + bhw + 1, seq_len)

    # Modify bands based on invalid levels
    nan_mask = np.isin(seq_indices, np.nonzero(np.isnan(levels)))
    nan_sig_indices = np.where(nan_mask)[0]
    nan_seq_indices = seq_indices[nan_mask]
    band[0, nan_sig_indices] = nan_seq_indices
    band[1, nan_sig_indices] = nan_seq_indices + 1
    # Modify bands close to invalid levels so monotonically increasing
    band[0, :] = np.maximum.accumulate(band[0, :])
    band[1, :] = np.minimum.accumulate(band[1, ::-1])[::-1]

    # expand band around large deletions to ensure valid paths
    invalid_indices = np.where(band[0, 1:] >= band[1, :-1])[0]
    while invalid_indices.shape[0] > 0:
        band[0, invalid_indices + 1] = np.maximum(
            band[0, invalid_indices + 1] - 1, 0)
        band[1, invalid_indices] = np.minimum(band[1, invalid_indices] + 1,
                                              seq_len)
        invalid_indices = np.where(band[0, 1:] >= band[1, :-1])[0]

    return band
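A small self-contained call of compute_sig_band (assuming the function above), showing how a NaN level pinches the band so every path is forced through it:

import numpy as np

bps = np.array([0, 3, 6, 10])           # 3 levels need 4 breakpoints
levels = np.array([0.5, np.nan, 1.2])
print(compute_sig_band(bps, levels, bhw=1))
# [[0 0 0 1 1 1 1 1 1 1]
#  [2 2 2 2 2 2 3 3 3 3]]
# signal positions 3-5 are pinned to sequence position 1 (the NaN level)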
Example #30
def open_alignment_out_file(out_dir, map_fmt, ref_names_and_lens, ref_fn):
    map_fn = mh.get_megalodon_fn(out_dir, mh.MAP_NAME) + '.' + map_fmt
    if map_fmt == 'bam':
        w_mode = 'wb'
    elif map_fmt == 'cram':
        w_mode = 'wc'
    elif map_fmt == 'sam':
        w_mode = 'w'
    else:
        raise mh.MegaError('Invalid mapping output format')
    return pysam.AlignmentFile(map_fn,
                               w_mode,
                               reference_names=ref_names_and_lens[0],
                               reference_lengths=ref_names_and_lens[1],
                               reference_filename=ref_fn)