def basecall(model, reads, chunksize=4000, overlap=500, batchsize=32, reverse=False):
    """
    Lazily basecall an iterable of reads.

    The work is expressed as a chain of generators, so nothing is scored
    until the returned generator is consumed.

    Args:
        model: passed to ``compute_scores``; its ``stride`` attribute is
            forwarded to ``stitch``.
        reads: iterable of read objects. Each read must expose a ``signal``
            attribute that ``torch.from_numpy`` accepts after slicing.
        chunksize: forwarded to ``chunk`` and ``stitch``; also multiplied by
            ``batchsize`` to form the second argument of ``split_read``.
        overlap: forwarded to ``chunk`` and ``stitch``.
        batchsize: forwarded to ``batchify``.
        reverse: when truthy, the pieces of each read are visited in reverse
            order; also forwarded to ``compute_scores`` and ``stitch``.

    Returns:
        A generator of ``(read, result)`` pairs, where ``result`` is the
        ``concat`` of every stitched part that was grouped under that read.
    """
    def read_of(item):
        # Grouping key: the read is the first element of each (read, part) pair.
        return item[0]

    # split_read yields a sliceable sequence of (read, start, end) triples;
    # the slice step flips their order when reverse is set.
    pieces = (
        piece
        for whole_read in reads
        for piece in split_read(whole_read, chunksize * batchsize)[::-1 if reverse else 1]
    )

    # Pair every piece with the chunked form of its slice of the signal.
    windows = (
        ((rd, lo, hi), chunk(torch.from_numpy(rd.signal[lo:hi]), chunksize, overlap))
        for (rd, lo, hi) in pieces
    )

    # Score batch by batch; the key produced by batchify is carried through
    # untouched so that unbatchify can use it.
    scored = (
        (key, compute_scores(model, batch, reverse=reverse))
        for key, batch in batchify(windows, batchsize=batchsize)
    )

    # NOTE(review): stitch is called here with five positional arguments plus
    # reverse= -- confirm the stitch that is in scope accepts this signature.
    joined = (
        (rd, stitch(scores, chunksize, overlap, hi - lo, model.stride, reverse=reverse))
        for ((rd, lo, hi), scores) in unbatchify(scored)
    )

    moved = thread_map(transfer, joined, n_thread=1)

    # groupby merges only consecutive items with the same key (assuming this
    # is itertools.groupby), so this relies on the parts of one read being
    # adjacent in the stream.
    return (
        (rd, concat([part for _, part in group]))
        for rd, group in groupby(moved, read_of)
    )
def stitch(chunks, start, end):
    """
    Stitch overlapping chunks into one sequence by trimming their edges.

    Args:
        chunks: either a dict, whose values are stitched recursively with the
            keys kept, or an object that exposes ``shape``, ``squeeze`` and
            two-axis indexing, with the chunks stacked along axis 0.
        start: index on axis 1 where the kept region of every chunk except
            the first begins.
        end: index on axis 1 where the kept region of every chunk except
            the last stops.

    Returns:
        A dict of stitched values for dict input, the lone chunk with axis 0
        squeezed away when there is only one chunk, and otherwise the result
        of ``concat`` over the trimmed chunks.
    """
    # Dict input: stitch each value independently, keeping the same keys.
    if isinstance(chunks, dict):
        stitched = {}
        for name, value in chunks.items():
            stitched[name] = stitch(value, start, end)
        return stitched

    # A single chunk has no neighbours to overlap with, so nothing is trimmed.
    n_chunks = chunks.shape[0]
    if n_chunks == 1:
        return chunks.squeeze(0)

    head = chunks[0, :end]            # first chunk keeps its leading edge
    interior = chunks[1:-1, start:end]  # middle chunks are trimmed on both sides
    tail = chunks[-1, start:]         # last chunk keeps its trailing edge
    return concat([head, *interior, tail])