Example #1
import os
import sys
from textwrap import wrap

# decode is assumed to come from the surrounding project

def run(self):
    # writer process: consume (read_id, predictions) jobs until a None sentinel
    while True:
        job = self.queue.get()
        if job is None:
            return
        read_id, predictions = job
        sequence = decode(predictions, self.alphabet, self.beamsize)
        # emit the read as FASTA, wrapped to self.wrap columns per line
        sys.stdout.write(">%s\n" % read_id)
        sys.stdout.write("%s\n" %
                         os.linesep.join(wrap(sequence, self.wrap)))
        sys.stdout.flush()
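
For context, a minimal driver sketch for the writer above, assuming run() lives on a multiprocessing.Process subclass (Writer here is a hypothetical name, as is the `results` producer); the None sentinel is what makes run() return:

from multiprocessing import Queue

queue = Queue()
writer = Writer(queue)                  # hypothetical Process subclass using run()
writer.start()
for read_id, predictions in results:    # `results` is a hypothetical producer
    queue.put((read_id, predictions))
queue.put(None)                         # sentinel: run() returns on None
writer.join()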
Example #2
from itertools import starmap

import numpy as np
import torch

# decode, decode_ref, accuracy and the module-level CTC loss `criterion`
# are assumed to come from the surrounding project

def test(model, device, test_loader):

    model.eval()
    test_loss = 0
    predictions = []
    prediction_lengths = []

    with torch.no_grad():
        for batch_idx, (data, out_lengths, target,
                        lengths) in enumerate(test_loader, start=1):
            data, target = data.to(device), target.to(device)
            log_probs = model(data)
            # CTC loss expects (T, N, C) inputs; the network downsamples by
            # model.stride, so scale the output lengths with integer division
            test_loss += criterion(log_probs.transpose(1, 0), target,
                                   out_lengths // model.stride, lengths)
            predictions.append(torch.exp(log_probs).cpu())
            prediction_lengths.append(out_lengths // model.stride)

    predictions = np.concatenate(predictions)
    lengths = np.concatenate(prediction_lengths)

    # decode the ground-truth targets and the network posteriors
    references = [
        decode_ref(target, model.alphabet)
        for target in test_loader.dataset.targets
    ]
    sequences = [
        decode(post[:n], model.alphabet)
        for post, n in zip(predictions, lengths)
    ]

    # score every call against its reference; guard against empty calls
    if all(map(len, sequences)):
        accuracies = list(starmap(accuracy, zip(references, sequences)))
    else:
        accuracies = [0]

    mean = np.mean(accuracies)
    median = np.median(accuracies)
    return test_loss.item() / batch_idx, mean, median
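
The reference/call scoring above is just starmap over paired lists; a toy sketch with a stand-in scorer (the project's own accuracy function is assumed to be alignment-based and to return a percentage):

from itertools import starmap

def accuracy(ref, seq):
    # stand-in scorer: exact positional matches as a percentage
    matches = sum(a == b for a, b in zip(ref, seq))
    return 100 * matches / max(len(ref), len(seq))

references = ["ACGT", "GGTA"]
sequences = ["ACGA", "GGTA"]
print(list(starmap(accuracy, zip(references, sequences))))  # [75.0, 100.0]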
Example #3
import time
from itertools import starmap

import numpy as np
import torch
from torch.utils.data import DataLoader

# init, load_data, ChunkDataSet, load_model, decode, decode_ref,
# accuracy and poa are assumed to come from the surrounding project

def main(args):

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(*load_data(
        limit=args.chunks, shuffle=args.shuffle, directory=args.directory))
    dataloader = DataLoader(testdata, batch_size=args.batchsize)

    for w in [int(i) for i in args.weights.split(',')]:

        print("* loading model", w)
        model = load_model(args.model_directory,
                           args.device,
                           weights=w,
                           half=args.half)

        print("* calling")
        predictions = []
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
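                # move the batch to the device, as float16 when running half precision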
                if args.half:
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)
                log_probs = model(data)
                predictions.append(log_probs.exp().cpu().numpy().astype(
                    np.float32))

        duration = time.perf_counter() - t0

        references = [
            decode_ref(target, model.alphabet)
            for target in dataloader.dataset.targets
        ]
        sequences = [
            decode(post, model.alphabet, args.beamsize)
            for post in np.concatenate(predictions)
        ]
        accuracies = list(starmap(accuracy, zip(references, sequences)))

        if args.poa:
            poas.append(sequences)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group the calls for each read across models: [model][read] -> [read][model]
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        accuracies = list(starmap(accuracy, zip(references, consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)