Example 1
def process_read(read_filename, read_id, model, chunk_size, overlap,
                 read_params, n_can_state, stride, alphabet, is_cat_mod,
                 mods_fp, max_concurrent_chunks):
    signal = get_signal(read_filename, read_id)
    if signal is None:
        return None, 0

    # Normalise the signal: median/MAD scaling by default, or per-read
    # shift/scale parameters when they are supplied.
    if read_params is None:
        normed_signal = med_mad_norm(signal)
    else:
        normed_signal = (signal - read_params['shift']) / read_params['scale']

    chunks, chunk_starts, chunk_ends = basecall_helpers.chunk_read(
        normed_signal, chunk_size, overlap)

    with torch.no_grad():
        device = next(model.parameters()).device
        chunks = torch.tensor(chunks, device=device)
        # Basecall at most max_concurrent_chunks chunks at a time so that
        # long reads do not exhaust device memory.
        out = []
        for some_chunks in torch.split(chunks, max_concurrent_chunks, 1):
            out.append(model(some_chunks))
        out = torch.cat(out, 1)

        if STITCH_BEFORE_VITERBI:
            # Stitch the raw network outputs first, then decode the whole
            # read with Viterbi in one pass.
            out = basecall_helpers.stitch_chunks(out, chunk_starts, chunk_ends,
                                                 stride)
            trans = flipflop_make_trans(out.unsqueeze(1)[:, :, :n_can_state])
            _, _, best_path = flipflop_viterbi(trans)
        else:
            # Decode each chunk separately, then stitch the per-chunk best
            # paths together.
            trans = flipflop_make_trans(out[:, :, :n_can_state])
            _, _, chunk_best_paths = flipflop_viterbi(trans)
            best_path = basecall_helpers.stitch_chunks(
                chunk_best_paths,
                chunk_starts,
                chunk_ends,
                stride,
                path_stitching=is_cat_mod)

        if is_cat_mod and mods_fp is not None:
            # output modified base weights for each base call
            if STITCH_BEFORE_VITERBI:
                mod_weights = out[:, n_can_state:]
            else:
                mod_weights = basecall_helpers.stitch_chunks(
                    out[:, :, n_can_state:], chunk_starts, chunk_ends, stride)
            mods_scores = extract_mod_weights(
                mod_weights.detach().cpu().numpy(),
                best_path.detach().cpu().numpy(),
                model.sublayers[-1].can_nmods)
            mods_fp.create_dataset('Reads/' + read_id,
                                   data=mods_scores,
                                   compression="gzip")

        basecall = path_to_str(best_path.cpu().numpy(), alphabet=alphabet)

    return basecall, len(signal)
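
The chunk/stitch pattern above is worth spelling out. The toy code below is a minimal, hypothetical stand-in for basecall_helpers.chunk_read and basecall_helpers.stitch_chunks (the names toy_chunk_read and toy_stitch are ours, and the real Taiyaki helpers handle strides and edge cases this sketch ignores): the read is cut into overlapping chunks, and stitching keeps the central part of each chunk so that chunk-boundary artefacts are discarded.

import numpy as np


def toy_chunk_read(signal, chunk_size, overlap):
    """Split `signal` into overlapping chunks stacked as
    (chunk_size, n_chunks), plus the slice of each chunk that the
    stitcher should keep."""
    step = chunk_size - overlap
    starts = list(range(0, len(signal) - chunk_size + 1, step))
    chunks = np.stack([signal[s:s + chunk_size] for s in starts], axis=1)
    # Keep the centre of each chunk: half of the overlap is trimmed
    # from each side of every interior boundary.
    keep_starts = [0] + [overlap // 2] * (len(starts) - 1)
    keep_ends = [chunk_size - overlap // 2] * (len(starts) - 1) + [chunk_size]
    return chunks, keep_starts, keep_ends


def toy_stitch(chunks, keep_starts, keep_ends):
    """Concatenate the kept slice of each chunk back into one array."""
    return np.concatenate([chunks[s:e, i] for i, (s, e)
                           in enumerate(zip(keep_starts, keep_ends))])


signal = np.arange(20.0)
chunks, ks, ke = toy_chunk_read(signal, chunk_size=8, overlap=4)
assert (toy_stitch(chunks, ks, ke) == signal).all()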
Example 2
def process_read(
        read_filename, read_id, model, chunk_size, overlap, device,
        n_can_state, stride, alphabet, is_cat_mod, mods_fp):
    signal = get_signal(read_filename, read_id)
    if signal is None:
        return None, 0
    normed_signal = med_mad_norm(signal)
    chunks, chunk_starts, chunk_ends = basecall_helpers.chunk_read(
        normed_signal, chunk_size, overlap)
    with torch.no_grad():
        out = model(torch.tensor(chunks, device=device))
        if is_cat_mod:
            if STITCH_BEFORE_VITERBI:
                out = basecall_helpers.stitch_chunks(
                    out, chunk_starts, chunk_ends, stride)
                trans, _, _ = flipflop_make_trans(
                    out.unsqueeze(1)[:, :, :n_can_state])
                _, _, best_path = flipflop_viterbi(trans)
            else:
                trans, _, _ = flipflop_make_trans(out[:, :, :n_can_state])
                _, _, chunk_best_paths = flipflop_viterbi(trans)
                best_path = basecall_helpers.stitch_chunks(
                    chunk_best_paths, chunk_starts, chunk_ends, stride,
                    path_stitching=True)
            if mods_fp is not None:
                # output modified base weights for each base call
                if STITCH_BEFORE_VITERBI:
                    mod_weights = out[:, n_can_state:]
                else:
                    mod_weights = basecall_helpers.stitch_chunks(
                        out[:, :, n_can_state:], chunk_starts, chunk_ends,
                        stride)
                mods_scores = extract_mod_weights(
                    mod_weights.detach().cpu().numpy(),
                    best_path.detach().cpu().numpy(),
                    model.sublayers[-1].can_nmods)
                mods_fp.create_dataset(
                    'Reads/' + read_id, data=mods_scores,
                    compression="gzip")
        else:
            if STITCH_BEFORE_VITERBI:
                out = basecall_helpers.stitch_chunks(
                    out, chunk_starts, chunk_ends, stride)
                trans, _, _ = flipflop_make_trans(out.unsqueeze(1))
                _, _, best_path = flipflop_viterbi(trans)
            else:
                trans, _, _ = flipflop_make_trans(out)
                _, _, chunk_best_paths = flipflop_viterbi(trans)
                best_path = basecall_helpers.stitch_chunks(
                    chunk_best_paths, chunk_starts, chunk_ends, stride)

        basecall = path_to_str(
            best_path.cpu().numpy(), alphabet=alphabet)

    return basecall, len(signal)
Example 3
def process_read(read_filename,
                 read_id,
                 model,
                 chunk_size,
                 overlap,
                 read_params,
                 n_can_state,
                 stride,
                 alphabet,
                 is_cat_mod,
                 mods_fp,
                 max_concurrent_chunks,
                 fastq=False,
                 qscore_scale=1.0,
                 qscore_offset=0.0):
    """Basecall a read, dividing the samples into chunks before applying the
    basecalling network and then stitching them back together.

    :param read_filename: filename to load data from
    :param read_id: id used in comment line in fasta or fastq output
    :param model: pytorch basecalling network
    :param chunk_size: chunk size, measured in samples
    :param overlap: overlap between chunks, measured in samples
    :param read_params: dict of read params including 'shift' and 'scale'
    :param n_can_state: number of canonical flip-flop transitions (40 for ACGT)
    :param stride: stride of basecalling network (measured in samples)
    :param alphabet: python str containing alphabet (e.g. 'ACGT')
    :param is_cat_mod: bool. True for multi-level categorical mod-base model.
    :param mods_fp: h5py handle to hdf5 file prepared to accept mod base output
                    (not used unless is_cat_mod)
    :param max_concurrent_chunks: max number of chunks to basecall at same time
              (having this limit prevents running out of memory for long reads)
    :param fastq: generate fastq file with q scores if this is True. Otherwise
                  generate fasta.
    :param qscore_scale: qscore <-- qscore * qscore_scale + qscore_offset
                         before coding as fastq
    :param qscore_offset: see qscore_scale above
    :returns: tuple (basecall, qstring, len(signal))
              where basecall and qstring are python strings; when fastq
              is False, qstring is None.

    :note: fastq output implemented only for the case is_cat_mod=False
    """
    if is_cat_mod and fastq:
        raise NotImplementedError(
            "fastq output not implemented for mod bases")

    signal = get_signal(read_filename, read_id)
    if signal is None:
        return None, None, 0

    if read_params is None:
        normed_signal = med_mad_norm(signal)
    else:
        normed_signal = (signal - read_params['shift']) / read_params['scale']

    chunks, chunk_starts, chunk_ends = basecall_helpers.chunk_read(
        normed_signal, chunk_size, overlap)

    qstring = None
    with torch.no_grad():
        device = next(model.parameters()).device
        chunks = torch.tensor(chunks, device=device)
        out = []
        for some_chunks in torch.split(chunks, max_concurrent_chunks, 1):
            out.append(model(some_chunks))
        out = torch.cat(out, 1)

        if STITCH_BEFORE_VITERBI:
            out = basecall_helpers.stitch_chunks(out, chunk_starts, chunk_ends,
                                                 stride)
            trans = flipflop_make_trans(out.unsqueeze(1)[:, :, :n_can_state])
            _, _, best_path = flipflop_viterbi(trans)
        else:
            trans = flipflop_make_trans(out[:, :, :n_can_state])
            _, _, chunk_best_paths = flipflop_viterbi(trans)
            best_path = basecall_helpers.stitch_chunks(
                chunk_best_paths,
                chunk_starts,
                chunk_ends,
                stride,
                path_stitching=is_cat_mod)
            if fastq:
                chunk_errprobs = qscores.errprobs_from_trans(
                    trans, chunk_best_paths)
                errprobs = basecall_helpers.stitch_chunks(
                    chunk_errprobs,
                    chunk_starts,
                    chunk_ends,
                    stride,
                    path_stitching=is_cat_mod)
                qstring = qscores.path_errprobs_to_qstring(
                    errprobs, best_path, qscore_scale, qscore_offset)

        if is_cat_mod and mods_fp is not None:
            # output modified base weights for each base call
            if STITCH_BEFORE_VITERBI:
                mod_weights = out[:, n_can_state:]
            else:
                mod_weights = basecall_helpers.stitch_chunks(
                    out[:, :, n_can_state:], chunk_starts, chunk_ends, stride)
            mods_scores = extract_mod_weights(
                mod_weights.detach().cpu().numpy(),
                best_path.detach().cpu().numpy(),
                model.sublayers[-1].can_nmods)
            mods_fp.create_dataset('Reads/' + read_id,
                                   data=mods_scores,
                                   compression="gzip")

        # Don't include the first source state from the path in the
        # basecall. This makes our basecalls agree with Guppy's, and
        # removes the problem that there is no entry transition for the
        # first path element, so we don't know what the q score is.
        basecall = path_to_str(best_path.cpu().numpy(),
                               alphabet=alphabet,
                               include_first_source=False)

    return basecall, qstring, len(signal)
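
The q-score calibration described in the docstring (qscore <-- qscore * qscore_scale + qscore_offset) is an affine transform applied before Phred encoding. The stand-alone sketch below shows the arithmetic; toy_qchar is hypothetical, and the real conversion lives in Taiyaki's qscores module.

import math


def toy_qchar(errprob, qscore_scale=1.0, qscore_offset=0.0):
    """Encode one error probability as a Sanger fastq quality character:
    Phred score, then the affine calibration above, then ASCII offset 33."""
    q = -10.0 * math.log10(errprob)
    q = q * qscore_scale + qscore_offset
    return chr(int(round(q)) + 33)


print(toy_qchar(0.001))  # error probability 0.001 -> Q30 -> '?'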
Example 4
def process_read(
        read_filename, read_id, model, chunk_size, overlap, read_params,
        n_can_state, stride, alphabet, max_concurrent_chunks,
        fastq=False, qscore_scale=1.0, qscore_offset=0.0, beam=None,
        posterior=True, temperature=1.0):
    """Basecall a read, dividing the samples into chunks before applying the
    basecalling network and then stitching them back together.

    Args:
        read_filename (str): filename to load data from.
        read_id (str): id used in comment line in fasta or fastq output.
        model (:class:`nn.Module`): Taiyaki network.
        chunk_size (int): chunk size, measured in samples.
        overlap (int): overlap between chunks, measured in samples.
        read_params (dict str -> T): read-specific scaling parameters,
            including 'shift' and 'scale'.
        n_can_state (int): number of canonical flip-flop transitions (40 for
            ACGT).
        stride (int): stride of basecalling network (measured in samples).
        alphabet (str): Alphabet (e.g. 'ACGT').
        max_concurrent_chunks (int): max number of chunks to basecall at same
            time (having this limit prevents running out of memory for long
            reads).
        fastq (bool): generate fastq file with q scores if this is True,
            otherwise generate fasta.
        qscore_scale (float): Scaling factor for Q score calibration.
        qscore_offset (float): Offset for Q score calibration.
        beam (None or NamedTuple): if not None, use beam search decoding;
            the `width` and `guided` fields are read from this tuple.
        posterior (bool): decode using posterior probability of
            transitions.
        temperature (float): multiplier applied to the network output.

    Returns:
        tuple of str and str and int: strings containing the called bases and
            their associated Phred-encoded quality scores, and the number of
            samples in the read (before chunking).

        When `fastq` is False, `None` is returned instead of a quality string.
    """
    signal = get_signal(read_filename, read_id)
    if signal is None:
        return None, None, 0
    if model.metadata['reverse']:
        signal = signal[::-1]

    if read_params is None:
        normed_signal = med_mad_norm(signal)
    else:
        normed_signal = (signal - read_params['shift']) / read_params['scale']

    chunks, chunk_starts, chunk_ends = basecall_helpers.chunk_read(
        normed_signal, chunk_size, overlap)

    qstring = None
    with torch.no_grad():
        device = next(model.parameters()).device
        chunks = torch.tensor(chunks, device=device)
        trans = []
        for some_chunks in torch.split(chunks, max_concurrent_chunks, 1):
            trans.append(model(some_chunks)[:, :, :n_can_state])
        trans = torch.cat(trans, 1) * temperature

        if posterior:
            # Decode posterior transition probabilities; the small
            # constant guards against log(0).
            trans = (flipflop_make_trans(trans) + 1e-8).log()

        if beam is not None:
            trans = basecall_helpers.stitch_chunks(trans, chunk_starts,
                                                   chunk_ends, stride)
            best_path, score = decodeutil.beamsearch(trans.cpu().numpy(),
                                                     beam_width=beam.width,
                                                     guided=beam.guided)
        else:
            _, _, chunk_best_paths = flipflop_viterbi(trans)
            best_path = basecall_helpers.stitch_chunks(
                chunk_best_paths, chunk_starts, chunk_ends,
                stride).cpu().numpy()

        if fastq:
            # Note: q-score output uses the per-chunk Viterbi paths, so it
            # is only defined when beam search is not in use
            # (chunk_best_paths exists only in the branch above).
            chunk_errprobs = qscores.errprobs_from_trans(trans,
                                                         chunk_best_paths)
            errprobs = basecall_helpers.stitch_chunks(
                chunk_errprobs, chunk_starts, chunk_ends, stride)
            qstring = qscores.path_errprobs_to_qstring(errprobs, best_path,
                                                       qscore_scale,
                                                       qscore_offset)

    # Don't include the first source state from the path in the basecall.
    # This makes our basecalls agree with Guppy's, and removes the
    # problem that there is no entry transition for the first path
    # element, so we don't know what the q score is.
    basecall = path_to_str(best_path, alphabet=alphabet,
                           include_first_source=False)

    return basecall, qstring, len(signal)
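
Example 4 adds beam search, posterior decoding, and a temperature parameter that multiplies the raw network output before decoding. The self-contained sketch below (toy values, not Taiyaki code) shows the effect of that multiplication: values below 1.0 flatten the decoded distribution, while 1.0 leaves it unchanged.

import torch

# Toy logits standing in for one column of flip-flop transition weights.
logits = torch.tensor([2.0, 0.5, -1.0])

for temperature in (1.0, 0.5):
    # Mirrors `trans = torch.cat(trans, 1) * temperature` above.
    probs = torch.softmax(logits * temperature, dim=0)
    print(temperature, probs.tolist())
# Lower temperature flattens the distribution, so downstream Viterbi or
# beam search decoding becomes less confident about each transition.

Note also that the beam argument is expected to carry width and guided fields, as shown by the attribute accesses beam.width and beam.guided in the function body.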