Example #1
def test(model, device, test_loader, min_coverage=0.5, criterion=None):

    if criterion is None:
        # default criterion: CTC with label smoothing (0.4 on the first
        # symbol, typically the blank, with the remaining 0.1 spread
        # evenly over the other symbols)
        C = len(model.alphabet)
        weights = torch.cat([torch.tensor([0.4]), (0.1 / (C - 1)) * torch.ones(C - 1)]).to(device)
        criterion = partial(ctc_label_smoothing_loss, weights=weights)

    seqs = []
    model.eval()
    test_loss = 0
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=min_coverage)

    with torch.no_grad():
        for batch_idx, (data, target, lengths) in enumerate(test_loader, start=1):
            log_probs = model(data.to(device))
            loss = criterion(log_probs, target.to(device), lengths.to(device))
            test_loss += loss['ctc_loss'] if isinstance(loss, dict) else loss
            if hasattr(model, 'decode_batch'):
                seqs.extend(model.decode_batch(log_probs))
            else:
                seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

    refs = [
        decode_ref(target, model.alphabet) for target in test_loader.dataset.targets
    ]
    accuracies = [
        accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)
    ]

    mean = np.mean(accuracies)
    median = np.median(accuracies)
    return test_loss.item() / batch_idx, mean, median
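
A minimal usage sketch for test() above; load_model, ChunkDataSet and load_data are the project helpers that appear in examples #6 and #7, and the argument values here are illustrative, not defaults taken from the source:

import torch
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model(model_directory, device, weights=1)   # assumed helper, see example #6
test_loader = DataLoader(ChunkDataSet(*load_data(validation=True)), batch_size=64)
loss, mean_acc, median_acc = test(model, device, test_loader, min_coverage=0.5)
print("loss %.4f mean %.2f%% median %.2f%%" % (loss, mean_acc, median_acc))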
Example #2
    def run(self):
        while True:

            job = self.queue.get()
            if job is None: return

            read_id_1, logits_1, read_id_2, logits_2 = job

            # reverse-complement the second read: reverse time and swap
            # A<->T, C<->G under the NACGT alphabet
            logits_2 = logits_2[::-1, [0, 4, 3, 2, 1]]

            # fast-ctc-decode expects probs (not logprobs)
            probs_1 = np.exp(logits_1)
            probs_2 = np.exp(logits_2)

            temp_seq, temp_path = beam_search(
                probs_1,
                self.alphabet,
                beam_size=16,
                beam_cut_threshold=self.threshold)
            comp_seq, comp_path = beam_search(
                probs_2,
                self.alphabet,
                beam_size=16,
                beam_cut_threshold=self.threshold)

            # catch any bad reads before attempting to align (parasail will segfault)
            if len(temp_seq) < self.minseqlen or len(comp_seq) < self.minseqlen:
                continue

            # check template/complement agreement
            if accuracy(temp_seq, comp_seq) < self.match:
                continue

            env = build_envelope(probs_1.shape[0],
                                 temp_seq,
                                 temp_path,
                                 probs_2.shape[0],
                                 comp_seq,
                                 comp_path,
                                 padding=self.padding)

            consensus = beam_search_2d(probs_1,
                                       probs_2,
                                       self.alphabet,
                                       envelope=env,
                                       beam_size=self.beamsize,
                                       beam_cut_threshold=self.threshold)

            with self.lock:
                sys.stdout.write(">%s;%s;\n" % (read_id_1, read_id_2))
                sys.stdout.write("%s\n" %
                                 os.linesep.join(wrap(consensus, 100)))
                sys.stdout.flush()
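
A hedged sketch of how run() might be driven; DecoderWorker is a hypothetical name for the enclosing class, assumed here to subclass threading.Thread, with its other constructor arguments (alphabet, thresholds, lock) elided:

from queue import Queue

queue = Queue()
workers = [DecoderWorker(queue) for _ in range(4)]   # hypothetical constructor
for w in workers:
    w.start()

for job in read_pairs:   # read_pairs: iterable of (id_1, logits_1, id_2, logits_2), assumed
    queue.put(job)

for _ in workers:
    queue.put(None)      # None is the shutdown sentinel run() checks
for w in workers:
    w.join()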
Example #3
def decode(res,
           beamsize_1=5,
           pad_1=40,
           cut_1=0.01,
           beamsize_2=5,
           pad_2=40,
           cut_2=0.01,
           match=80,
           alphabet="NACGT"):

    temp_probs = res[0]['trans'].astype(np.float32)
    init1 = res[0]['init'][0].astype(np.float32)
    comp_probs = res[1]['trans'].astype(np.float32)
    init2 = res[1]['init'][0].astype(np.float32)

    simplex1, path1 = crf_beam_search(temp_probs,
                                      init1,
                                      alphabet,
                                      beam_size=5,
                                      beam_cut_threshold=0.01)
    simplex2, path2 = crf_beam_search(comp_probs,
                                      init2,
                                      alphabet,
                                      beam_size=5,
                                      beam_cut_threshold=0.01)

    if len(simplex1) < 10 or len(simplex2) < 10:
        return [simplex1, simplex2]

    if accuracy(simplex1, simplex2) < match:
        return [simplex1, simplex2]

    duplex1 = beam_search_duplex(simplex1,
                                 path1,
                                 temp_probs,
                                 init1,
                                 simplex2,
                                 path2,
                                 comp_probs,
                                 init2,
                                 pad=pad_1,
                                 beamsize=beamsize_1,
                                 T=cut_1)
    duplex2 = beam_search_duplex(simplex2,
                                 path2,
                                 comp_probs,
                                 init2,
                                 simplex1,
                                 path1,
                                 temp_probs,
                                 init1,
                                 pad=pad_2,
                                 beamsize=beamsize_2,
                                 T=cut_2)
    return [duplex1, duplex2, simplex1, simplex2]
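
A hedged call sketch; the layout of res is inferred from the indexing above (res[i]['trans'] holding the CRF transition posteriors and res[i]['init'][0] the initial-state scores):

seqs = decode(res, beamsize_1=5, cut_1=0.01, beamsize_2=5, cut_2=0.01, match=80)
if len(seqs) == 4:
    duplex1, duplex2, simplex1, simplex2 = seqs   # duplex consensus succeeded
else:
    simplex1, simplex2 = seqs                     # fell back to the simplex calls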
Example #4
    def validate_one_step(self, batch):
        data, targets, lengths = batch

        scores = self.model(data.to(self.device))
        losses = self.criterion(scores, targets.to(self.device),
                                lengths.to(self.device))
        if isinstance(losses, dict):
            losses = {k: v.item() for k, v in losses.items()}
        else:
            losses = losses.item()
        if hasattr(self.model, 'decode_batch'):
            seqs = self.model.decode_batch(scores)
        else:
            seqs = [
                self.model.decode(x) for x in permute(scores, 'TNC', 'NTC')
            ]
        refs = [decode_ref(target, self.model.alphabet) for target in targets]
        accs = [
            accuracy(ref, seq, min_coverage=0.5) if len(seq) else 0.
            for ref, seq in zip(refs, seqs)
        ]
        return seqs, refs, accs, losses
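
A hedged aggregation sketch over a validation loader; trainer stands for the object this method lives on, with model and criterion assumed to be attached elsewhere:

import numpy as np

all_accs, losses = [], []
for batch in valid_loader:   # valid_loader: yields (data, targets, lengths), assumed
    seqs, refs, accs, loss = trainer.validate_one_step(batch)
    all_accs.extend(accs)
    losses.append(loss['ctc_loss'] if isinstance(loss, dict) else loss)
print("mean %.2f%% median %.2f%%" % (np.mean(all_accs), np.median(all_accs)))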
Example #5
def test(model, device, test_loader, min_coverage=0.5):

    model.eval()
    test_loss = 0
    predictions = []
    prediction_lengths = []
    accuracy_with_coverage_filter = lambda ref, seq: accuracy(
        ref, seq, min_coverage=min_coverage)

    with torch.no_grad():
        for batch_idx, (data, out_lengths, target,
                        lengths) in enumerate(test_loader, start=1):
            data, target = data.to(device), target.to(device)
            log_probs = model(data)
            test_loss += ctc_loss(log_probs.transpose(1, 0), target,
                                  out_lengths // model.stride, lengths)
            predictions.append(torch.exp(log_probs).cpu())
            prediction_lengths.append(out_lengths // model.stride)

    predictions = np.concatenate(predictions)
    lengths = np.concatenate(prediction_lengths)

    references = [
        decode_ref(target, model.alphabet)
        for target in test_loader.dataset.targets
    ]
    sequences = [
        model.decode(post[:n]) for post, n in zip(predictions, lengths)
    ]

    accuracies = [
        accuracy_with_coverage_filter(ref, seq) if len(seq) else 0.
        for ref, seq in zip(references, sequences)
    ]

    mean = np.mean(accuracies)
    median = np.median(accuracies)
    return test_loss.item() / batch_idx, mean, median
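
A hedged note on the ctc_loss used above: the transpose(1, 0) implies the model emits (N, T, C) log-probabilities while the criterion consumes (T, N, C), which matches torch.nn.CTCLoss semantics. A plausible definition, with the blank index an assumption not confirmed by the snippet:

import torch

ctc_loss = torch.nn.CTCLoss(blank=0)   # blank index 0 is an assumption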
Example #6
def main(args):

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(
            limit=args.chunks, shuffle=args.shuffle,
            directory=args.directory, validation=True
        )
    )
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=args.min_coverage)

    for w in [int(i) for i in args.weights.split(',')]:

        seqs = []

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if half_supported():
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)

                log_probs = model(data)

                if hasattr(model, 'decode_batch'):
                    seqs.extend(model.decode_batch(log_probs))
                else:
                    seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

        duration = time.perf_counter() - t0

        refs = [decode_ref(target, model.alphabet) for target in dataloader.dataset.targets]
        accuracies = [accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)]

        if args.poa: poas.append(seqs)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        accuracies = [accuracy_with_cov(ref, seq) for ref, seq in zip(refs, consensuses)]

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
Example #7
def main(args):

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(*load_data(limit=args.chunks,
                                       shuffle=args.shuffle,
                                       directory=args.directory,
                                       validation=True))
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    accuracy_with_coverage_filter = lambda ref, seq: accuracy(
        ref, seq, min_coverage=args.min_coverage)

    for w in [int(i) for i in args.weights.split(',')]:

        print("* loading model", w)
        model = load_model(args.model_directory,
                           args.device,
                           weights=w,
                           half=args.half)

        print("* calling")
        predictions = []
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if args.half:
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)
                log_probs = model(data)
                predictions.append(log_probs.exp().cpu().numpy().astype(np.float32))

        duration = time.perf_counter() - t0

        references = [
            decode_ref(target, model.alphabet)
            for target in dataloader.dataset.targets
        ]
        sequences = [
            model.decode(post, beamsize=args.beamsize)
            for post in np.concatenate(predictions)
        ]
        accuracies = list(
            starmap(accuracy_with_coverage_filter, zip(references, sequences)))

        if args.poa: poas.append(sequences)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        accuracies = list(
            starmap(accuracy_with_coverage_filter, zip(references,
                                                       consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)