Code Example #1
def test_cupy_equals_torch_make_trans(self):
    trans_torch = decode.flipflop_make_trans(
        torch.tensor(self.scores, device=0), _never_use_cupy=True)
    trans_cupy = decode.flipflop_make_trans(
        torch.tensor(self.scores, device=0))
    self.assertArrayEqual(trans_torch.cpu().numpy(),
                          trans_cupy.cpu().numpy())
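This test pins down that the cupy and pure-torch paths through decode.flipflop_make_trans agree; passing _never_use_cupy=True forces the torch implementation even on a CUDA tensor. As a minimal sketch, the routine can also be called directly on CPU; the (time, batch, n_can_state) shape below is a hypothetical stand-in for self.scores:

import torch
from taiyaki import decode

# Hypothetical network outputs: (time, batch, n_can_state); 40 is the
# number of canonical flip-flop transition scores for ACGT.
scores = torch.randn(100, 1, 40)
trans = decode.flipflop_make_trans(scores)  # transition scores for Viterbi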
Code Example #2
File: basecall.py Project: xiaoying201355/taiyaki
def process_read(read_filename, read_id, model, chunk_size, overlap,
                 read_params, n_can_state, stride, alphabet, is_cat_mod,
                 mods_fp, max_concurrent_chunks):
    signal = get_signal(read_filename, read_id)
    if signal is None:
        return None, 0

    # Normalise the raw signal, either by median/MAD normalisation or with
    # the per-read shift/scale parameters supplied by the caller.
    if read_params is None:
        normed_signal = med_mad_norm(signal)
    else:
        normed_signal = (signal - read_params['shift']) / read_params['scale']

    chunks, chunk_starts, chunk_ends = basecall_helpers.chunk_read(
        normed_signal, chunk_size, overlap)

    with torch.no_grad():
        device = next(model.parameters()).device
        chunks = torch.tensor(chunks, device=device)
        # Basecall at most max_concurrent_chunks chunks at a time to bound
        # memory use, then concatenate the outputs along the batch axis.
        out = []
        for some_chunks in torch.split(chunks, max_concurrent_chunks, 1):
            out.append(model(some_chunks))
        out = torch.cat(out, 1)

        if STITCH_BEFORE_VITERBI:
            out = basecall_helpers.stitch_chunks(out, chunk_starts, chunk_ends,
                                                 stride)
            trans = flipflop_make_trans(out.unsqueeze(1)[:, :, :n_can_state])
            _, _, best_path = flipflop_viterbi(trans)
        else:
            trans = flipflop_make_trans(out[:, :, :n_can_state])
            _, _, chunk_best_paths = flipflop_viterbi(trans)
            best_path = basecall_helpers.stitch_chunks(
                chunk_best_paths,
                chunk_starts,
                chunk_ends,
                stride,
                path_stitching=is_cat_mod)

        if is_cat_mod and mods_fp is not None:
            # output modified base weights for each base call
            if STITCH_BEFORE_VITERBI:
                mod_weights = out[:, n_can_state:]
            else:
                mod_weights = basecall_helpers.stitch_chunks(
                    out[:, :, n_can_state:], chunk_starts, chunk_ends, stride)
            mods_scores = extract_mod_weights(
                mod_weights.detach().cpu().numpy(),
                best_path.detach().cpu().numpy(),
                model.sublayers[-1].can_nmods)
            mods_fp.create_dataset('Reads/' + read_id,
                                   data=mods_scores,
                                   compression="gzip")

        basecall = path_to_str(best_path.cpu().numpy(), alphabet=alphabet)

    return basecall, len(signal)
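To make the calling convention concrete, here is a hedged sketch of invoking this version of process_read; the filename, read id, model object, and numeric parameters are all hypothetical:

# Hypothetical invocation; `model` is assumed to be a loaded flip-flop
# network and the chunking parameters are illustrative only.
basecall, n_samples = process_read(
    'reads/read_001.fast5', 'read_001', model,
    chunk_size=4000, overlap=200, read_params=None, n_can_state=40,
    stride=2, alphabet='ACGT', is_cat_mod=False, mods_fp=None,
    max_concurrent_chunks=128)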
Code Example #3
File: test_decode.py Project: udishadc/taiyaki
def test_cupy_equals_torch_make_trans(self):
    """ Test that cupy and torch routines to calculate transition scores
    agree.
    """
    trans_torch = decode.flipflop_make_trans(
        torch.tensor(self.scores, device=0), _never_use_cupy=True)
    trans_cupy = decode.flipflop_make_trans(
        torch.tensor(self.scores, device=0))
    self.assertArrayEqual(trans_torch.cpu().numpy(),
                          trans_cupy.cpu().numpy())
Code Example #4
File: basecall.py Project: nodrogluap/taiyaki
def process_read(read_filename,
                 read_id,
                 model,
                 chunk_size,
                 overlap,
                 read_params,
                 n_can_state,
                 stride,
                 alphabet,
                 is_cat_mod,
                 mods_fp,
                 max_concurrent_chunks,
                 fastq=False,
                 qscore_scale=1.0,
                 qscore_offset=0.0):
    """Basecall a read, dividing the samples into chunks before applying the
    basecalling network and then stitching them back together.

    :param read_filename: filename to load data from
    :param read_id: id used in comment line in fasta or fastq output
    :param model: pytorch basecalling network
    :param chunk_size: chunk size, measured in samples
    :param overlap: overlap between chunks, measured in samples
    :param read_params: dict of read params including 'shift' and 'scale'
    :param n_can_state: number of canonical flip-flop transitions (40 for ACGT)
    :param stride: stride of basecalling network (measured in samples)
    :param alphabet: python str containing alphabet (e.g. 'ACGT')
    :param is_cat_mod: bool. True for multi-level categorical mod-base model.
    :param mods_fp: h5py handle to hdf5 file prepared to accept mod base output
                    (not used unless is_cat_mod)
    :param max_concurrent_chunks: max number of chunks to basecall at same time
              (having this limit prevents running out of memory for long reads)
    :param fastq: generate fastq file with q scores if this is True. Otherwise
                  generate fasta.
    :param qscore_scale: qscore <-- qscore * qscore_scale + qscore_offset
                         before coding as fastq
    :param qscore_offset: see qscore_scale above
    :returns: tuple (basecall, qstring, len(signal)), where basecall and
              qstring are python strings; when fastq is False, qstring is
              None.

    :note: fastq output is implemented only for the case is_cat_mod=False
    """
    if is_cat_mod and fastq:
        raise Exception("fastq output not implemented for mod bases")

    signal = get_signal(read_filename, read_id)
    if signal is None:
        return None, None, 0

    if read_params is None:
        normed_signal = med_mad_norm(signal)
    else:
        normed_signal = (signal - read_params['shift']) / read_params['scale']

    chunks, chunk_starts, chunk_ends = basecall_helpers.chunk_read(
        normed_signal, chunk_size, overlap)

    qstring = None
    with torch.no_grad():
        device = next(model.parameters()).device
        chunks = torch.tensor(chunks, device=device)
        out = []
        for some_chunks in torch.split(chunks, max_concurrent_chunks, 1):
            out.append(model(some_chunks))
        out = torch.cat(out, 1)

        if STITCH_BEFORE_VITERBI:
            out = basecall_helpers.stitch_chunks(out, chunk_starts, chunk_ends,
                                                 stride)
            trans = flipflop_make_trans(out.unsqueeze(1)[:, :, :n_can_state])
            _, _, best_path = flipflop_viterbi(trans)
        else:
            trans = flipflop_make_trans(out[:, :, :n_can_state])
            _, _, chunk_best_paths = flipflop_viterbi(trans)
            best_path = basecall_helpers.stitch_chunks(
                chunk_best_paths,
                chunk_starts,
                chunk_ends,
                stride,
                path_stitching=is_cat_mod)
            if fastq:
                chunk_errprobs = qscores.errprobs_from_trans(
                    trans, chunk_best_paths)
                errprobs = basecall_helpers.stitch_chunks(
                    chunk_errprobs,
                    chunk_starts,
                    chunk_ends,
                    stride,
                    path_stitching=is_cat_mod)
                qstring = qscores.path_errprobs_to_qstring(
                    errprobs, best_path, qscore_scale, qscore_offset)

        if is_cat_mod and mods_fp is not None:
            # output modified base weights for each base call
            if STITCH_BEFORE_VITERBI:
                mod_weights = out[:, n_can_state:]
            else:
                mod_weights = basecall_helpers.stitch_chunks(
                    out[:, :, n_can_state:], chunk_starts, chunk_ends, stride)
            mods_scores = extract_mod_weights(
                mod_weights.detach().cpu().numpy(),
                best_path.detach().cpu().numpy(),
                model.sublayers[-1].can_nmods)
            mods_fp.create_dataset('Reads/' + read_id,
                                   data=mods_scores,
                                   compression="gzip")

        # Don't include the first source state from the path in the basecall.
        # This makes our basecalls agree with Guppy's, and removes the
        # problem that there is no entry transition for the first path
        # element, so we don't know what the q score is.
        basecall = path_to_str(best_path.cpu().numpy(),
                               alphabet=alphabet,
                               include_first_source=False)

    return basecall, qstring, len(signal)
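The calibration described in the docstring (qscore <-- qscore * qscore_scale + qscore_offset) composes with the standard Phred definition and Sanger fastq coding. A minimal sketch of that arithmetic, assuming the usual conventions (taiyaki's own qscores module may differ in detail):

import numpy as np

def errprob_to_qchar(errprob, qscore_scale=1.0, qscore_offset=0.0):
    # Phred score from an error probability, then the linear calibration
    # from the docstring, then Sanger fastq coding (ASCII offset 33).
    q = -10.0 * np.log10(errprob)
    q = q * qscore_scale + qscore_offset
    return chr(int(round(q)) + 33)

# e.g. errprob_to_qchar(0.001) == '?' (Q30)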
Code Example #5
File: test_decode.py Project: udishadc/taiyaki
def test_cpu_make_trans_with_grad_non_leaf_no_grad(self):
    """ Test making transition scores from a non-leaf tensor that requires
    gradients, evaluated inside a torch.no_grad() block.
    """
    scores = torch.tensor(self.scores, requires_grad=True)
    with torch.no_grad():
        decode.flipflop_make_trans(1.0 * scores)
Code Example #6
File: test_decode.py Project: udishadc/taiyaki
def test_cpu_make_trans_with_grad_non_leaf(self):
    """ Test making transition scores when the input is a non-leaf tensor
    that requires gradients.
    """
    scores = torch.tensor(self.scores, requires_grad=True)
    decode.flipflop_make_trans(1.0 * scores)
Code Example #7
File: test_decode.py Project: udishadc/taiyaki
def test_cpu_make_trans_no_grad(self):
    """ Test making transition scores when the input does not require
    gradients.
    """
    scores = torch.tensor(self.scores, requires_grad=False)
    decode.flipflop_make_trans(scores)
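The three tests above differ only in the autograd state of the input: multiplying by 1.0 yields a non-leaf tensor that requires gradients, while torch.no_grad() suppresses graph recording. A standalone illustration of those states (the tensor shape is hypothetical):

import torch

scores = torch.randn(10, 1, 40, requires_grad=True)  # leaf tensor
non_leaf = 1.0 * scores
print(non_leaf.is_leaf, non_leaf.requires_grad)       # False True
with torch.no_grad():
    no_grad_result = 1.0 * scores
print(no_grad_result.is_leaf, no_grad_result.requires_grad)  # True False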
Code Example #8
File: basecall.py Project: udishadc/taiyaki
def process_read(
        read_filename, read_id, model, chunk_size, overlap, read_params,
        n_can_state, stride, alphabet, max_concurrent_chunks,
        fastq=False, qscore_scale=1.0, qscore_offset=0.0, beam=None,
        posterior=True, temperature=1.0):
    """Basecall a read, dividing the samples into chunks before applying the
    basecalling network and then stitching them back together.

    Args:
        read_filename (str): filename to load data from.
        read_id (str): id used in comment line in fasta or fastq output.
        model (:class:`nn.Module`): Taiyaki network.
        chunk_size (int): chunk size, measured in samples.
        overlap (int): overlap between chunks, measured in samples.
        read_params (dict str -> T): read-specific scaling parameters,
            including 'shift' and 'scale'.
        n_can_state (int): number of canonical flip-flop transitions (40 for
            ACGT).
        stride (int): stride of basecalling network (measured in samples).
        alphabet (str): Alphabet (e.g. 'ACGT').
        max_concurrent_chunks (int): max number of chunks to basecall at same
            time (having this limit prevents running out of memory for long
            reads).
        fastq (bool): generate fastq file with q scores if this is True,
            otherwise generate fasta.
        qscore_scale (float): Scaling factor for Q score calibration.
        qscore_offset (float): Offset for Q score calibration.
        beam (None or NamedTuple): Use beam search decoding when not None;
            the `width` and `guided` fields are read.
        posterior (bool): Decode using posterior probabilities of
            transitions.
        temperature (float): Multiplier applied to the network output.

    Returns:
        tuple of str and str and int: strings containing the called bases and
            their associated Phred-encoded quality scores, and the number of
            samples in the read (before chunking).

        When `fastq` is False, `None` is returned instead of a quality string.
    """
    signal = get_signal(read_filename, read_id)
    if signal is None:
        return None, None, 0
    if model.metadata['reverse']:
        signal = signal[::-1]

    if read_params is None:
        normed_signal = med_mad_norm(signal)
    else:
        normed_signal = (signal - read_params['shift']) / read_params['scale']

    chunks, chunk_starts, chunk_ends = basecall_helpers.chunk_read(
        normed_signal, chunk_size, overlap)

    qstring = None
    with torch.no_grad():
        device = next(model.parameters()).device
        chunks = torch.tensor(chunks, device=device)
        trans = []
        for some_chunks in torch.split(chunks, max_concurrent_chunks, 1):
            trans.append(model(some_chunks)[:, :, :n_can_state])
        trans = torch.cat(trans, 1) * temperature

        if posterior:
            # Convert the temperature-scaled scores into log posterior
            # transition probabilities; the epsilon guards against log(0).
            trans = (flipflop_make_trans(trans) + 1e-8).log()

        if beam is not None:
            trans = basecall_helpers.stitch_chunks(trans, chunk_starts,
                                                   chunk_ends, stride)
            best_path, score = decodeutil.beamsearch(trans.cpu().numpy(),
                                                     beam_width=beam.width,
                                                     guided=beam.guided)
        else:
            _, _, chunk_best_paths = flipflop_viterbi(trans)
            best_path = basecall_helpers.stitch_chunks(
                chunk_best_paths, chunk_starts, chunk_ends,
                stride).cpu().numpy()

            if fastq:
                # Q scores are only computed on the Viterbi branch: beam
                # search does not return the per-chunk paths that
                # errprobs_from_trans needs.
                chunk_errprobs = qscores.errprobs_from_trans(trans,
                                                             chunk_best_paths)
                errprobs = basecall_helpers.stitch_chunks(
                    chunk_errprobs, chunk_starts, chunk_ends, stride)
                qstring = qscores.path_errprobs_to_qstring(errprobs, best_path,
                                                           qscore_scale,
                                                           qscore_offset)

    # Don't include the first source state from the path in the basecall.
    # This makes our basecalls agree with Guppy's, and removes the
    # problem that there is no entry transition for the first path
    # element, so we don't know what the q score is.
    basecall = path_to_str(best_path, alphabet=alphabet,
                           include_first_source=False)

    return basecall, qstring, len(signal)
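This variant reads beam.width and beam.guided from the beam argument. A hedged sketch of constructing a matching parameter object and calling this version; the container and all values are hypothetical:

from collections import namedtuple

# Hypothetical container exposing the `width` and `guided` fields that
# process_read reads from `beam`; the numbers are illustrative only.
Beam = namedtuple('Beam', ('width', 'guided'))
basecall, qstring, n_samples = process_read(
    'reads/read_001.fast5', 'read_001', model,
    chunk_size=4000, overlap=200, read_params=None, n_can_state=40,
    stride=2, alphabet='ACGT', max_concurrent_chunks=128,
    fastq=False, beam=Beam(width=5, guided=True))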