コード例 #1
0
ファイル: pair.py プロジェクト: sirelkhatim/bonito
    def run(self):
        """
        Decode paired reads pulled from ``self.queue`` and print consensus records.

        Each job is a ``(read_id_1, logits_1, read_id_2, logits_2)`` tuple; a
        ``None`` job ends the loop. A pair is skipped when either single-read
        decode is shorter than ``self.minseqlen`` or when
        ``accuracy(temp_seq, comp_seq)`` is below ``self.match``. Surviving
        pairs are decoded jointly with ``beam_search_2d`` and written to stdout
        (while holding ``self.lock``) as a ``>id1;id2;`` header line followed by
        the consensus wrapped at 100 columns.
        """
        while True:

            job = self.queue.get()
            if job is None:
                return

            read_id_1, logits_1, read_id_2, logits_2 = job

            # revcomp decode the second read: reverse the first axis and swap
            # columns 1<->4 and 2<->3, leaving column 0 in place
            logits_2 = logits_2[::-1, [0, 4, 3, 2, 1]]

            # fast-ctc-decode expects probs (not logprobs)
            probs_1 = np.exp(logits_1)
            probs_2 = np.exp(logits_2)

            temp_seq, temp_path = beam_search(
                probs_1,
                self.alphabet,
                beam_size=16,
                beam_cut_threshold=self.threshold)
            comp_seq, comp_path = beam_search(
                probs_2,
                self.alphabet,
                beam_size=16,
                beam_cut_threshold=self.threshold)

            # catch any bad reads before attempt to align (parasail will segfault)
            if len(temp_seq) < self.minseqlen or len(
                    comp_seq) < self.minseqlen:
                continue

            # check template/complement agreement
            if accuracy(temp_seq, comp_seq) < self.match:
                continue

            env = build_envelope(probs_1.shape[0],
                                 temp_seq,
                                 temp_path,
                                 probs_2.shape[0],
                                 comp_seq,
                                 comp_path,
                                 padding=self.padding)

            consensus = beam_search_2d(probs_1,
                                       probs_2,
                                       self.alphabet,
                                       envelope=env,
                                       beam_size=self.beamsize,
                                       beam_cut_threshold=self.threshold)

            with self.lock:
                sys.stdout.write(">%s;%s;\n" % (read_id_1, read_id_2))
                # stdout is a text stream, so "\n" is already translated to the
                # platform line ending; joining with os.linesep would emit
                # "\r\r\n" on Windows.
                sys.stdout.write("%s\n" % "\n".join(wrap(consensus, 100)))
                sys.stdout.flush()
コード例 #2
0
 def decode(self, x, beamsize=5, threshold=1e-3, qscores=False, return_path=False):
     """
     Decode ``x`` into a sequence, optionally returning the path as well.

     ``viterbi_search`` is used when ``beamsize`` is 1 or ``qscores`` is
     truthy (it receives ``self.qscale`` and ``self.qbias``); otherwise
     ``beam_search`` is used with ``beamsize`` and ``threshold``.

     Returns ``(seq, path)`` when ``return_path`` is truthy, else ``seq``.
     """
     if beamsize == 1 or qscores:
         decoded = viterbi_search(x, self.alphabet, qscores, self.qscale, self.qbias)
     else:
         decoded = beam_search(x, self.alphabet, beamsize, threshold)
     seq, path = decoded
     return (seq, path) if return_path else seq
コード例 #3
0
def decode_revised(predictions,
                   alphabet,
                   signal_data,
                   kmer_length=5,
                   beam_size=5,
                   threshold=0.1):
    """
    Decode model posteriors to a sequence plus per-k-mer signal levels.

    Args:
        predictions: array passed to the decoder (cast to float32 for
            ``beam_search``).
        alphabet: iterable of label strings; joined into one string.
        signal_data: sequence that is sliced with the decoded path indices.
        kmer_length: number of consecutive path boundaries spanned by each
            averaging window.
        beam_size: beam width; ``1`` short-circuits to ``greedy_ctc_decode``.
        threshold: beam cut threshold forwarded to ``beam_search``.

    Returns:
        When ``beam_size == 1``: whatever ``greedy_ctc_decode`` returns.
        Otherwise ``(seq, means)`` where ``means`` is
        * an empty array if the decoded path is empty,
        * ``array([0])`` if ``len(seq) <= kmer_length``,
        * else a ``uint8`` array of window means min-max scaled to 0..255
          (all zeros when every window mean is identical).
    """
    alphabet = ''.join(alphabet)
    if beam_size == 1:
        return greedy_ctc_decode(predictions, alphabet)
    seq, path = beam_search(predictions.astype(np.float32), alphabet,
                            beam_size, threshold)
    means = []
    if len(path) > 0:
        # Window boundaries: the decoded path, anchored at both ends of the
        # signal.
        bounds = list(path)
        if bounds[0] != 0:
            bounds.insert(0, 0)
        # The original compared ``path[:-1]`` (a list) with an int, which is
        # always unequal; the last element is what has to be checked.
        if bounds[-1] != len(signal_data):
            bounds.append(len(signal_data))
        if kmer_length < len(seq):
            # Never index past the boundaries that are actually available.
            n_windows = min(len(seq) - kmer_length + 1,
                            len(bounds) - kmer_length)
            levels = np.asarray([
                np.mean(signal_data[bounds[i]:bounds[i + kmer_length]])
                for i in range(n_windows)
            ])
            low, high = np.min(levels), np.max(levels)
            span = high - low
            if span > 0:
                scaled = (levels - low) / span * 255
            else:
                # Flat signal: avoid 0/0 producing NaN before the uint8 cast.
                scaled = np.zeros_like(levels)
            means = scaled.astype('uint8')
        else:
            means.append(0)
    return seq, np.asarray(means)
コード例 #4
0
ファイル: util.py プロジェクト: iiSeymour/fast-bonito
def decode(x, beamsize=5, threshold=1e-3, qscores=False, return_path=False):
    """Decode the output of the CNN to an ACGT sequence.

    Args:
        x: network output handed unchanged to the decoder.
        beamsize: beam width; ``1`` selects ``viterbi_search``.
        threshold: beam cut threshold, used by ``beam_search`` only.
        qscores: forwarded to ``viterbi_search``; a truthy value also forces
            the Viterbi branch.
        return_path: when truthy, return the decoded path as well.

    Returns:
        ``(seq, path)`` if ``return_path`` is truthy, otherwise ``seq``.
    """
    alphabet = ["N", "A", "C", "G", "T"]
    if beamsize == 1 or qscores:
        # Scale/bias constants handed to viterbi_search together with qscores.
        qbias = 2.0
        qscale = 0.7
        seq, path = viterbi_search(x, alphabet, qscores, qscale, qbias)
    else:
        seq, path = beam_search(x, alphabet, beamsize, threshold)
    if return_path:
        return seq, path
    return seq
コード例 #5
0
 def decode(self, x, beamsize=5, threshold=1e-3, qscores=False, return_path=False):
     """
     Exponentiate ``x``, move it to a float32 NumPy array and decode it.

     ``viterbi_search`` is used when ``beamsize`` is 1 or ``qscores`` is
     truthy (with ``self.qscale`` / ``self.qbias``); otherwise ``beam_search``
     is used with ``beamsize`` and ``threshold``.

     Returns ``(seq, path)`` when ``return_path`` is truthy, else ``seq``.
     """
     probs = x.exp().cpu().numpy().astype(np.float32)
     if beamsize == 1 or qscores:
         decoded = viterbi_search(probs, self.alphabet, qscores, self.qscale, self.qbias)
     else:
         decoded = beam_search(probs, self.alphabet, beamsize, threshold)
     seq, path = decoded
     return (seq, path) if return_path else seq
コード例 #6
0
 def test_beam_search_short_alphabet(self):
     """Beam search with the short alphabet "NAG" on fresh random data."""
     self.alphabet = "NAG"
     self.probs = self.get_random_data()
     decoded, trace = beam_search(
         self.probs,
         self.alphabet,
         self.beam_size,
         self.beam_cut_threshold,
     )
     distinct_symbols = set(decoded)
     # one path entry per decoded symbol
     self.assertEqual(len(decoded), len(trace))
     # distinct decoded symbols == alphabet size minus one
     self.assertEqual(len(distinct_symbols), len(self.alphabet) - 1)
コード例 #7
0
 def test_beam_search_named_args(self):
     """Beam search called with every argument passed by keyword."""
     named_args = {
         "network_output": self.probs,
         "alphabet": self.alphabet,
         "beam_size": self.beam_size,
         "beam_cut_threshold": self.beam_cut_threshold,
     }
     decoded, trace = beam_search(**named_args)
     distinct_symbols = set(decoded)
     self.assertEqual(len(decoded), len(trace))
     self.assertEqual(len(distinct_symbols), len(self.alphabet) - 1)
コード例 #8
0
def decode(predictions, alphabet, beam_size=5, threshold=0.1):
    """
    Decode model posteriors to sequence.

    ``alphabet`` is joined into a single string. A ``beam_size`` of 1 uses
    ``greedy_ctc_decode``; any other value runs ``beam_search`` on the
    float32 cast of ``predictions`` and returns its result unchanged.
    """
    symbols = ''.join(alphabet)
    if beam_size != 1:
        posteriors = predictions.astype(np.float32)
        return beam_search(posteriors, symbols, beam_size, threshold)
    return greedy_ctc_decode(predictions, symbols)
コード例 #9
0
 def test_beam_search_long_alphabet(self):
     """Beam search with the 12-symbol alphabet "NABCDEFGHIJK" and no beam cut."""
     self.alphabet = "NABCDEFGHIJK"
     self.probs = self.get_random_data(10000)
     decoded, trace = beam_search(
         self.probs, self.alphabet, self.beam_size, beam_cut_threshold=0.0)
     distinct_symbols = set(decoded)
     self.assertEqual(len(decoded), len(trace))
     self.assertEqual(len(distinct_symbols), len(self.alphabet) - 1)
コード例 #10
0
 def process(self, raw, identifiers, frame_meta):
     """
     Turn raw network outputs into one ``DNASequencePrediction`` per identifier.

     Raises ``ConfigError`` when ``self.label_map`` is empty or missing.
     The selected output blob is exponentiated, then each item is decoded
     with ``viterbi_search`` (when ``self.beam_size == 1``) or ``beam_search``.
     """
     if not self.label_map:
         raise ConfigError('Beam Search Decoder requires dataset label map for correct decoding.')
     alphabet = list(self.label_map.values())
     raw_outputs = self._extract_predictions(raw, frame_meta)
     self.select_output_blob(raw_outputs)
     probabilities = np.exp(raw_outputs[self.output_blob])

     def _decode_one(out):
         # Same branch as before: exact Viterbi for a beam of one, beam search otherwise.
         if self.beam_size == 1:
             seq, _ = viterbi_search(np.squeeze(out), alphabet, False, 1, 0)
         else:
             seq, _ = beam_search(np.squeeze(out.astype(np.float32)), alphabet, self.beam_size, self.threshold)
         return seq

     return [
         DNASequencePrediction(identifier, _decode_one(out))
         for identifier, out in zip(identifiers, probabilities)
     ]
コード例 #11
0
    def test_beam_search_path(self):
        """Beam search returns the emission positions as its path."""
        n_frames = 5000
        n_labels = len(self.alphabet)
        x = np.zeros((n_frames, n_labels), np.float32)
        x[:, 0] = 0.5  # set stay prob

        # emit a base evenly spaced along the frames, cycling through
        # columns 1..4
        emit = np.arange(0, n_frames, n_labels - 1)
        x[emit, np.arange(len(emit)) % 4 + 1] = 1.0

        decoded, trace = beam_search(
            x, self.alphabet, self.beam_size, self.beam_cut_threshold)
        np.testing.assert_array_equal(emit, trace)
        self.assertEqual(len(decoded), len(trace))
コード例 #12
0
    def test_repeat_sequence_path_with_spread_probs(self):
        """Repeated-base path test where each emission spans two adjacent frames."""
        n_frames = 20
        x = np.zeros((n_frames, len(self.alphabet)), np.float32)
        x[:, 0] = 0.5  # set stay prob

        expected_path = [6, 13, 18]
        for frame in expected_path:
            # spread the emission over the frame and the one before it
            window = slice(frame - 1, frame + 1)
            x[window, 0] = 0.0
            x[window, 1] = 1.0

        decoded, trace = beam_search(
            x, self.alphabet, self.beam_size, self.beam_cut_threshold)

        self.assertEqual(decoded, 'AAA')
        self.assertEqual(len(decoded), len(trace))
        self.assertEqual(trace, expected_path)
コード例 #13
0
    def test_repeat_sequence_path_with_multi_char_alpha(self):
        """Path test with an alphabet whose labels are multi-character strings."""
        n_frames = 20
        self.alphabet = ["N", "AAA", "CCC", "GGG", "TTTT"]
        x = np.zeros((n_frames, len(self.alphabet)), np.float32)
        x[:, 0] = 0.5  # set stay prob

        expected_path = [6, 13, 18]
        # the k-th emission (1-based) fires label column k at its frame
        for label_idx, frame in enumerate(expected_path, start=1):
            x[frame, 0] = 0.0
            x[frame, label_idx] = 1.0

        decoded, trace = beam_search(
            x, self.alphabet, self.beam_size, self.beam_cut_threshold)

        self.assertEqual(decoded, 'AAACCCGGG')
        self.assertEqual(trace, expected_path)
コード例 #14
0
def ctcdecoder(logits,
               label,
               blank=False,
               beam_size=5,
               alphabet="NACGT",
               pre=None):
    """
    Collapse per-row label indices into index rows and decoded strings.

    Args:
        logits: 2-D array indexed as ``logits[i, j]`` holding label indices
            (despite the name, values are compared with 0 and used as keys
            into the module-level ``vocab``).
        label: array whose shape sizes the output
            (``label.shape[0]`` rows, ``label.shape[1] + 50`` columns).
        blank: when falsy, consecutive duplicates are merged and zeros are
            dropped; when truthy, values are copied until the first zero.
        beam_size: beam width forwarded to ``beam_search`` when ``pre`` is given.
        alphabet: alphabet forwarded to ``beam_search`` when ``pre`` is given.
        pre: optional array indexed as ``pre[:, i, :]``; when given, the string
            for row ``i`` comes from ``beam_search`` over its softmax instead
            of from ``logits``.

    Returns:
        ``(ret, retstr)``: the zero-padded index array and one string per row.
    """
    ret = np.zeros((label.shape[0], label.shape[1] + 50))
    retstr = []
    for i in range(logits.shape[0]):
        if pre is not None:
            row_probs = torch.softmax(torch.tensor(pre[:, i, :]), dim=-1)
            beamcur = beam_search(row_probs.cpu().detach().numpy(),
                                  alphabet=alphabet,
                                  beam_size=beam_size)[0]
        prev = None
        cur = []
        pos = 0
        for j in range(logits.shape[1]):
            if not blank:
                if logits[i, j] != prev:
                    prev = logits[i, j]
                    try:
                        if prev != 0:
                            ret[i, pos] = prev
                            pos += 1
                            cur.append(vocab[prev])
                    except Exception:
                        # Best-effort: report and keep going. write() takes a
                        # single string, so the message is formatted first
                        # (the old multi-argument call raised TypeError).
                        sys.stderr.write(
                            "ctcdecoder: fail on i: %s pos: %s\n" % (i, pos))
            else:
                if logits[i, j] == 0:
                    break
                ret[i, pos] = logits[i, j]
                # pos == j on this branch (both advance once per iteration)
                cur.append(vocab[logits[i, j]])
                pos += 1
        if pre is not None:
            retstr.append(beamcur)
        else:
            retstr.append("".join(cur))
    return ret, retstr
コード例 #15
0
    def getResult(self, img, mode='path'):
        """
        Run ``self.net`` on a single image and return the decoded word.

        The network is put in eval mode for the forward pass and is always
        switched back to train mode before returning.

        Args:
            img: tensor for one image; a leading dimension is added with
                ``unsqueeze(0)`` and the result is moved to CUDA.
            mode: ``'beam'`` decodes with ``beam_search`` over ``classes``;
                any other value takes the per-step argmax, merges consecutive
                duplicates, maps indices through ``icdict`` and then strips
                ``'_'`` and replaces ``"v"`` with ``"u"``.

        Returns:
            The decoded string, or ``'error'`` if anything in the forward pass
            or decoding raised an ``Exception``.
        """
        self.net.eval()

        img = img.unsqueeze(0)  #.transpose(1, 3).transpose(2, 3)
        try:
            with torch.no_grad():
                tst_o = self.net(Variable(img.cuda()))

            if mode == 'beam':
                estimated_word, _ = beam_search(
                    tst_o.softmax(2).cpu().numpy().squeeze(),
                    classes,
                    beam_size=5,
                    beam_cut_threshold=0.012195121)
            else:
                tdec = tst_o.log_softmax(2).argmax(2).cpu().numpy().squeeze()
                t_beam = beam_search_decoder(
                    tst_o.softmax(2).cpu().numpy().squeeze(), 5)

                # todo: create a better way than to just ignore output with size [1, 1, 80] (first 1 has to be >1
                tt = [
                    v for j, v in enumerate(tdec) if j == 0 or v != tdec[j - 1]
                ]
                tb = [
                    v for j, v in enumerate(t_beam)
                    if j == 0 or v != t_beam[j - 1]
                ]
                # diagnostic only: flag when the two decoders disagree
                if tb != tt:
                    print('Unequal')

                estimated_word = ''.join([icdict[t] for t in tt
                                          ]).replace('_',
                                                     '').replace("v", "u")
        except Exception:
            # Deliberate best-effort, but no longer swallows
            # KeyboardInterrupt / SystemExit like the bare ``except:`` did.
            estimated_word = 'error'
        finally:
            # Restore training mode even when an interrupt propagates.
            self.net.train()

        return estimated_word
コード例 #16
0
 def test_beam_search_mismatched_alphabet_short(self):
     """Beam search rejects an alphabet with too few characters."""
     short_alphabet = "NAGC"
     self.assertRaises(
         ValueError,
         beam_search,
         self.probs,
         short_alphabet,
         self.beam_size,
         self.beam_cut_threshold,
     )
コード例 #17
0
 def test_high_beam_cut_threshold(self):
     """Beam search rejects a very high beam cut threshold (1.1)."""
     too_high = 1.1
     self.assertRaises(
         ValueError, beam_search,
         self.probs, self.alphabet, self.beam_size, too_high)
コード例 #18
0
 def test_beam_cut_threshold_boundary(self):
     """Beam search rejects a beam cut threshold of exactly 1/len(alphabet)."""
     boundary = 1.0 / len(self.alphabet)
     self.assertRaises(
         ValueError, beam_search,
         self.probs, self.alphabet, self.beam_size, boundary)
コード例 #19
0
 def test_negative_beam_cut_threshold(self):
     """Beam search rejects a beam cut threshold below 0.0."""
     negative = -0.1
     self.assertRaises(
         ValueError, beam_search,
         self.probs, self.alphabet, self.beam_size, negative)
コード例 #20
0
 def test_zero_beam_cut_threshold(self):
     """Beam search accepts a beam cut threshold of exactly 0.0."""
     decoded, trace = beam_search(
         self.probs, self.alphabet, self.beam_size, 0.0)
     distinct_symbols = set(decoded)
     self.assertEqual(len(decoded), len(trace))
     self.assertEqual(len(distinct_symbols), len(self.alphabet) - 1)
コード例 #21
0
 def test_zero_beam_size(self):
     """Beam search rejects a beam size of zero."""
     self.assertRaises(
         ValueError, beam_search,
         self.probs, self.alphabet, 0, self.beam_cut_threshold)
コード例 #22
0
 def test_beam_search_alphabet(self):
     """Beam search with a different alphabet ("NRUST") on the shared fixture data."""
     other_alphabet = "NRUST"
     decoded, trace = beam_search(
         self.probs, other_alphabet, self.beam_size, self.beam_cut_threshold)
     distinct_symbols = set(decoded)
     self.assertEqual(len(decoded), len(trace))
     # compared against the fixture alphabet's length, as before
     self.assertEqual(len(distinct_symbols), len(self.alphabet) - 1)
コード例 #23
0
 def test_beam_search_defaults(self):
     """Beam search using the default beam size and cut threshold."""
     result = beam_search(self.probs, self.alphabet)
     decoded, trace = result
     distinct_symbols = set(decoded)
     self.assertEqual(len(decoded), len(trace))
     self.assertEqual(len(distinct_symbols), len(self.alphabet) - 1)
コード例 #24
0
 def test_beam_search_tuple(self):
     """Beam search with the canonical alphabet supplied as a tuple."""
     alphabet_as_tuple = tuple(self.alphabet)
     decoded, trace = beam_search(
         self.probs,
         alphabet_as_tuple,
         self.beam_size,
         self.beam_cut_threshold,
     )
     distinct_symbols = set(decoded)
     self.assertEqual(len(decoded), len(trace))
     self.assertEqual(len(distinct_symbols), len(self.alphabet) - 1)
コード例 #25
0
 def test_nans(self):
     """beam_search is passed NaN values and must raise RuntimeError."""
     # np.NaN was removed in NumPy 2.0; np.nan is the spelling available in
     # every NumPy version.
     self.probs.fill(np.nan)
     # assertRaisesRegexp was a deprecated alias that no longer exists in
     # Python 3.12; assertRaisesRegex is the supported name.
     with self.assertRaisesRegex(RuntimeError, "Failed to compare values"):
         beam_search(self.probs, self.alphabet)
コード例 #26
0
 def test_beam_search_not_enough_args(self):
     """Beam search raises TypeError when the alphabet argument is missing."""
     self.assertRaises(TypeError, beam_search, self.probs)
コード例 #27
0
 def decode_beamsearch(self, preds, length) :
     seq, path = beam_search(preds, self.alphabet, beam_size=5, beam_cut_threshold=0.1)