def test(model, device, test_loader, min_coverage=0.5, criterion=None):
    """
    Evaluate `model` over every batch produced by `test_loader`.

    Args:
        model: network exposing `alphabet`, `eval()` and either
            `decode_batch(log_probs)` or `decode(posterior)`.
        device: device the inputs, targets and lengths are moved to.
        test_loader: iterable of `(data, target, lengths)` batches whose
            `dataset.targets` holds the reference encodings.
        min_coverage: forwarded to `accuracy` for every scored pair.
        criterion: loss callable taking `(log_probs, target, lengths)`; it
            may return a value or a dict with a `'ctc_loss'` entry. When
            omitted, `ctc_label_smoothing_loss` is used with per-symbol
            weights built from the alphabet size.

    Returns:
        Tuple `(loss summed over batches / number of batches,
        mean accuracy, median accuracy)`.
    """
    if criterion is None:
        # First symbol is weighted 0.4, the remaining symbols share 0.1.
        n_symbols = len(model.alphabet)
        rest = (0.1 / (n_symbols - 1)) * torch.ones(n_symbols - 1)
        smoothing = torch.cat([torch.tensor([0.4]), rest]).to(device)
        criterion = partial(ctc_label_smoothing_loss, weights=smoothing)

    def score(ref, seq):
        return accuracy(ref, seq, min_coverage=min_coverage)

    called = []
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for n_batches, batch in enumerate(test_loader, start=1):
            data, target, lengths = batch
            log_probs = model(data.to(device))
            loss = criterion(log_probs, target.to(device), lengths.to(device))
            if isinstance(loss, dict):
                loss = loss['ctc_loss']
            total_loss += loss

            if hasattr(model, 'decode_batch'):
                decoded = model.decode_batch(log_probs)
            else:
                decoded = [
                    model.decode(posterior)
                    for posterior in permute(log_probs, 'TNC', 'NTC')
                ]
            called.extend(decoded)

    refs = []
    for encoded in test_loader.dataset.targets:
        refs.append(decode_ref(encoded, model.alphabet))

    # An empty call cannot be scored, so it counts as 0% accuracy.
    accuracies = []
    for ref, seq in zip(refs, called):
        accuracies.append(score(ref, seq) if len(seq) else 0.)

    mean = np.mean(accuracies)
    median = np.median(accuracies)
    return total_loss.item() / n_batches, mean, median
def run(self):
    """
    Consume read-pair jobs from `self.queue` and print a consensus per pair.

    Each job is a `(read_id_1, logits_1, read_id_2, logits_2)` tuple; a
    `None` job is the shutdown sentinel and makes the method return. Pairs
    are skipped when either single-read call is shorter than
    `self.minseqlen` or when the two calls agree less than `self.match`.
    Accepted pairs are written to stdout as a `>id1;id2;` header line
    followed by the consensus wrapped at 100 columns.
    """
    while True:
        job = self.queue.get()
        if job is None:
            return

        read_id_1, logits_1, read_id_2, logits_2 = job

        # revcomp decode the second read
        logits_2 = logits_2[::-1, [0, 4, 3, 2, 1]]

        # fast-ctc-decode expects probs (not logprobs)
        probs_1 = np.exp(logits_1)
        probs_2 = np.exp(logits_2)

        temp_seq, temp_path = beam_search(
            probs_1, self.alphabet, beam_size=16,
            beam_cut_threshold=self.threshold)
        comp_seq, comp_path = beam_search(
            probs_2, self.alphabet, beam_size=16,
            beam_cut_threshold=self.threshold)

        # catch any bad reads before attempt to align (parasail will segfault)
        if len(temp_seq) < self.minseqlen or len(comp_seq) < self.minseqlen:
            continue

        # check template/complement agreement
        if accuracy(temp_seq, comp_seq) < self.match:
            continue

        env = build_envelope(
            probs_1.shape[0], temp_seq, temp_path,
            probs_2.shape[0], comp_seq, comp_path,
            padding=self.padding)

        consensus = beam_search_2d(
            probs_1, probs_2, self.alphabet, envelope=env,
            beam_size=self.beamsize, beam_cut_threshold=self.threshold)

        # stdout is a text stream that already translates "\n" to the
        # platform line ending, so joining with os.linesep would emit
        # "\r\r\n" where the separator is "\r\n".
        record = ">%s;%s;\n%s\n" % (
            read_id_1, read_id_2, "\n".join(wrap(consensus, 100)))
        with self.lock:
            sys.stdout.write(record)
            sys.stdout.flush()
def decode(res, beamsize_1=5, pad_1=40, cut_1=0.01, beamsize_2=5, pad_2=40,
           cut_2=0.01, match=80, alphabet="NACGT"):
    """
    Decode a template/complement pair into simplex and duplex calls.

    Args:
        res: two-element sequence; each element is indexed with `'trans'`
            and `'init'` (the first row of `'init'` is used), both of which
            are cast to float32.
        beamsize_1, pad_1, cut_1: beam size, padding and cut threshold for
            the duplex search seeded from the first read.
        beamsize_2, pad_2, cut_2: the same settings for the duplex search
            seeded from the second read.
        match: minimum `accuracy` between the two simplex calls required to
            attempt duplex decoding.
        alphabet: symbol alphabet handed to `crf_beam_search`.

    Returns:
        `[simplex1, simplex2]` when either simplex call is shorter than 10
        or the calls agree less than `match`; otherwise
        `[duplex1, duplex2, simplex1, simplex2]`.
    """
    temp_probs = res[0]['trans'].astype(np.float32)
    init1 = res[0]['init'][0].astype(np.float32)
    comp_probs = res[1]['trans'].astype(np.float32)
    init2 = res[1]['init'][0].astype(np.float32)

    # The first-pass simplex search keeps its fixed settings.
    simplex1, path1 = crf_beam_search(
        temp_probs, init1, alphabet, beam_size=5, beam_cut_threshold=0.01)
    simplex2, path2 = crf_beam_search(
        comp_probs, init2, alphabet, beam_size=5, beam_cut_threshold=0.01)

    if len(simplex1) < 10 or len(simplex2) < 10:
        return [simplex1, simplex2]

    if accuracy(simplex1, simplex2) < match:
        return [simplex1, simplex2]

    # beamsize_1 / beamsize_2 were previously accepted but ignored (the calls
    # hard-coded 5); forward them like the matching pad_* and cut_* values.
    duplex1 = beam_search_duplex(
        simplex1, path1, temp_probs, init1,
        simplex2, path2, comp_probs, init2,
        pad=pad_1, beamsize=beamsize_1, T=cut_1)
    duplex2 = beam_search_duplex(
        simplex2, path2, comp_probs, init2,
        simplex1, path1, temp_probs, init1,
        pad=pad_2, beamsize=beamsize_2, T=cut_2)
    return [duplex1, duplex2, simplex1, simplex2]
def validate_one_step(self, batch):
    """
    Run the model on one validation batch and score its calls.

    Args:
        batch: `(data, targets, lengths)`; each is moved to `self.device`
            before use.

    Returns:
        Tuple `(seqs, refs, accs, losses)`: decoded calls, decoded
        references, per-read accuracies (0 for an empty call) and the
        criterion output converted with `.item()` (a dict of values when
        the criterion returns a dict).
    """
    data, targets, lengths = batch
    device = self.device
    model = self.model

    scores = model(data.to(device))
    raw_losses = self.criterion(scores, targets.to(device), lengths.to(device))
    if isinstance(raw_losses, dict):
        losses = {name: value.item() for name, value in raw_losses.items()}
    else:
        losses = raw_losses.item()

    if hasattr(model, 'decode_batch'):
        seqs = model.decode_batch(scores)
    else:
        seqs = []
        for posterior in permute(scores, 'TNC', 'NTC'):
            seqs.append(model.decode(posterior))

    refs = []
    for encoded in targets:
        refs.append(decode_ref(encoded, model.alphabet))

    # An empty call cannot be scored, so it counts as 0% accuracy.
    accs = []
    for ref, seq in zip(refs, seqs):
        accs.append(accuracy(ref, seq, min_coverage=0.5) if len(seq) else 0.)

    return seqs, refs, accs, losses
def test(model, device, test_loader, min_coverage=0.5):
    """
    Evaluate `model` with CTC loss and read accuracy over `test_loader`.

    Args:
        model: network exposing `alphabet`, `stride`, `eval()` and
            `decode(posterior)`.
        device: device the input data and targets are moved to.
        test_loader: iterable of `(data, out_lengths, target, lengths)`
            batches whose `dataset.targets` holds the reference encodings.
        min_coverage: forwarded to `accuracy` for every scored pair.

    Returns:
        Tuple `(loss summed over batches / number of batches,
        mean accuracy, median accuracy)`.
    """
    model.eval()
    test_loss = 0
    predictions = []
    prediction_lengths = []

    with torch.no_grad():
        for batch_idx, (data, out_lengths, target, lengths) in enumerate(test_loader, start=1):
            data, target = data.to(device), target.to(device)
            log_probs = model(data)
            # Output lengths are the input lengths divided by the model stride.
            strided_lengths = out_lengths // model.stride
            test_loss += ctc_loss(
                log_probs.transpose(1, 0), target, strided_lengths, lengths)
            predictions.append(torch.exp(log_probs).cpu())
            prediction_lengths.append(strided_lengths)

    predictions = np.concatenate(predictions)
    output_lengths = np.concatenate(prediction_lengths)

    references = [
        decode_ref(target, model.alphabet)
        for target in test_loader.dataset.targets
    ]
    # Trim each posterior to its valid length before decoding.
    sequences = [
        model.decode(post[:n]) for post, n in zip(predictions, output_lengths)
    ]

    # Score every read on its own: an empty call counts as 0% accuracy.
    # Previously a single empty call replaced the whole list with [0].
    accuracies = [
        accuracy(ref, seq, min_coverage=min_coverage) if len(seq) else 0.
        for ref, seq in zip(references, sequences)
    ]

    mean = np.mean(accuracies)
    median = np.median(accuracies)
    return test_loss.item() / batch_idx, mean, median
def main(args):
    """
    Evaluate one or more model checkpoints on the validation chunks.

    Reads from `args`: `seed`, `device`, `chunks`, `shuffle`, `directory`,
    `batchsize`, `min_coverage`, `weights` (comma-separated integers),
    `model_directory` and `poa`. For every checkpoint it prints mean and
    median accuracy, calling time and throughput. With `args.poa` set it
    also builds a consensus across the checkpoints' calls and reports its
    accuracy.
    """
    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(
            limit=args.chunks, shuffle=args.shuffle,
            directory=args.directory, validation=True
        )
    )
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=args.min_coverage)

    for w in [int(i) for i in args.weights.split(',')]:

        seqs = []

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if half_supported():
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)

                log_probs = model(data)

                if hasattr(model, 'decode_batch'):
                    seqs.extend(model.decode_batch(log_probs))
                else:
                    seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

        duration = time.perf_counter() - t0

        refs = [decode_ref(target, model.alphabet) for target in dataloader.dataset.targets]
        # An empty call cannot be scored, so it counts as 0% accuracy.
        accuracies = [accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)]

        # Keep this checkpoint's calls for the consensus step. This used to
        # append the undefined name `sequences`.
        if args.poa:
            poas.append(seqs)

        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % duration)
        # The per-chunk sample count is taken from the last batch seen.
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        # Score with the names defined above; this used to reference the
        # undefined `accuracy_with_coverage_filter` and `references`.
        accuracies = list(starmap(accuracy_with_cov, zip(refs, consensuses)))

        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % duration)
def main(args):
    """
    Evaluate one or more model checkpoints on the validation chunks.

    Reads from `args`: `seed`, `device`, `chunks`, `shuffle`, `directory`,
    `batchsize`, `min_coverage`, `weights` (comma-separated integers),
    `model_directory`, `half`, `beamsize` and `poa`. For every checkpoint
    it prints mean and median accuracy, calling time and throughput. With
    `args.poa` set it also builds a consensus across the checkpoints' calls
    and reports its accuracy.
    """
    def accuracy_with_coverage_filter(ref, seq):
        return accuracy(ref, seq, min_coverage=args.min_coverage)

    def report(accuracies, elapsed):
        # Summary lines shared by the per-model and consensus results.
        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % elapsed)

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    chunks = load_data(
        limit=args.chunks,
        shuffle=args.shuffle,
        directory=args.directory,
        validation=True,
    )
    testdata = ChunkDataSet(*chunks)
    dataloader = DataLoader(testdata, batch_size=args.batchsize)

    checkpoints = [int(i) for i in args.weights.split(',')]
    for w in checkpoints:
        print("* loading model", w)
        model = load_model(
            args.model_directory, args.device, weights=w, half=args.half)

        print("* calling")
        predictions = []
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if args.half:
                    data = data.type(torch.float16)
                data = data.to(args.device)
                posteriors = model(data).exp().cpu().numpy()
                predictions.append(posteriors.astype(np.float32))

        duration = time.perf_counter() - t0

        references = []
        for target in dataloader.dataset.targets:
            references.append(decode_ref(target, model.alphabet))

        sequences = []
        for post in np.concatenate(predictions):
            sequences.append(model.decode(post, beamsize=args.beamsize))

        accuracies = [
            accuracy_with_coverage_filter(ref, seq)
            for ref, seq in zip(references, sequences)
        ]

        if args.poa:
            poas.append(sequences)

        report(accuracies, duration)
        # The per-chunk sample count is taken from the last batch seen.
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if not args.poa:
        return

    print("* doing poa")
    t0 = time.perf_counter()
    # group each sequence prediction per model together
    poas = [list(seq) for seq in zip(*poas)]
    consensuses = poa(poas)
    duration = time.perf_counter() - t0

    accuracies = [
        accuracy_with_coverage_filter(ref, seq)
        for ref, seq in zip(references, consensuses)
    ]
    report(accuracies, duration)