def run_model(
        normed_signal, model, chunk_size=_DEFAULT_CHUNK_SIZE,
        overlap=_DEFAULT_OVERLAP, max_concur_chunks=None, return_numpy=True,
        return_tensor_on_device=True):
    """Hook for megalodon to run network via taiyaki.

    The signal is cut into overlapping chunks, the model is applied to the
    chunks and the per-chunk outputs are stitched back together.

    Args:
        normed_signal: Signal to be chunked (passed to `chunk_read`).
        model: Network to apply; the device of its first parameter is the
            device on which the chunks are evaluated.
        chunk_size: Chunk length, multiplied by the model stride before use.
        overlap: Chunk overlap, multiplied by the model stride before use.
        max_concur_chunks: If not None, the chunks are split along
            dimension 1 into groups of at most this many and the model is
            run once per group.
        return_numpy: Return the stitched output as a numpy array.
        return_tensor_on_device: When `return_numpy` is False, move the
            stitched tensor back to the model's device before returning.

    Returns:
        The stitched model output: a numpy array, a tensor on the model's
        device or a tensor in host memory, depending on the two flags.
    """
    model_device = next(model.parameters()).device
    stride = guess_model_stride(model)
    samples_per_chunk = chunk_size * stride
    samples_overlap = overlap * stride
    chunks, chunk_starts, chunk_ends = chunk_read(
        normed_signal, samples_per_chunk, samples_overlap)
    chunk_tensor = torch.tensor(chunks)

    with torch.no_grad():
        if max_concur_chunks is None:
            out = model(chunk_tensor.to(model_device)).cpu()
        else:
            # Bound peak device memory: evaluate a group of chunks at a
            # time and collect every result in host memory.
            group_outputs = [
                model(group.to(model_device)).cpu()
                for group in torch.split(chunk_tensor, max_concur_chunks, 1)
            ]
            out = torch.cat(group_outputs, 1)

    stitched = stitch_chunks(out, chunk_starts, chunk_ends, stride)
    if return_numpy:
        return stitched.numpy()
    return stitched.to(model_device) if return_tensor_on_device else stitched
def run_model(normed_signal, model, chunk_size=_DEFAULT_CHUNK_SIZE,
              overlap=_DEFAULT_OVERLAP, max_concur_chunks=None,
              return_numpy=True):
    """Hook for megalodon to run network via taiyaki.

    Args:
        normed_signal: Signal to be chunked (passed to `chunk_read`).
        model: Network to apply; chunks are created directly on the device
            of its first parameter.
        chunk_size: Requested chunk size, adjusted together with `overlap`
            by `round_chunk_values` using the model stride.
        overlap: Requested overlap between chunks.
        max_concur_chunks: If not None, chunks are processed in consecutive
            slices of at most this many along axis 1.
        return_numpy: Copy the stitched output to host memory and return it
            as a numpy array; otherwise the stitched tensor is returned
            as-is.

    Returns:
        Stitched model output as a numpy array or a tensor.
    """
    model_device = next(model.parameters()).device
    stride = guess_model_stride(model)
    chunk_size, overlap = round_chunk_values(chunk_size, overlap, stride)
    chunks, chunk_starts, chunk_ends = chunk_read(
        normed_signal, chunk_size, overlap)

    with torch.no_grad():
        if max_concur_chunks is None:
            out = model(torch.tensor(chunks, device=model_device))
        else:
            n_batches = int(np.ceil(chunks.shape[1] / max_concur_chunks))
            batch_outputs = []
            for batch_i in range(n_batches):
                first = batch_i * max_concur_chunks
                last = first + max_concur_chunks
                # The slice is copied into a contiguous array before it is
                # handed to torch.
                batch = np.ascontiguousarray(chunks[:, first:last])
                batch_outputs.append(
                    model(torch.tensor(batch, device=model_device)))
            out = torch.cat(batch_outputs, 1)

    stitched = stitch_chunks(out, chunk_starts, chunk_ends, stride)
    if not return_numpy:
        return stitched
    return stitched.cpu().numpy()
def worker_init(device, modelname, chunk_size, overlap, read_params, alphabet,
                max_concurrent_chunks, fastq, qscore_scale, qscore_offset,
                beam, posterior, temperature):
    """Set up module-level state used by a basecalling worker.

    Loads the model onto the requested device and publishes two globals:
    `all_read_params` (the `read_params` argument, unchanged) and
    `process_read_partial`, a closure around `process_read` with the model
    and all basecalling options already bound.

    Args:
        device: Device specification handed to `helpers.set_torch_device`.
        modelname: Model location handed to `load_model`.
        chunk_size: Chunk size, multiplied by the model stride before use.
        overlap: Chunk overlap, multiplied by the model stride before use.
        read_params: Stored as the global `all_read_params`.
        alphabet: Alphabet; its length determines the number of flip-flop
            states via `nstate_flipflop`.
        max_concurrent_chunks, fastq, qscore_scale, qscore_offset, beam,
        posterior, temperature: Forwarded unchanged to `process_read`.
    """
    global all_read_params
    global process_read_partial

    all_read_params = read_params

    torch_device = helpers.set_torch_device(device)
    model = load_model(modelname).to(torch_device)
    stride = guess_model_stride(model)
    chunk_samples = chunk_size * stride
    overlap_samples = overlap * stride
    n_can_state = nstate_flipflop(len(alphabet))

    def process_read_partial(read_filename, read_id, read_params):
        """Call `process_read` for one read; prefix the result with its id."""
        result = process_read(
            read_filename, read_id, model, chunk_samples, overlap_samples,
            read_params, n_can_state, stride, alphabet,
            max_concurrent_chunks, fastq, qscore_scale, qscore_offset, beam,
            posterior, temperature)
        return (read_id, *result)
def main():
    """Basecall all reads found under the input folder.

    Loads the model, collects the fast5 reads, optionally restricts them to
    reads with scaling information, and writes one FASTA or FASTQ record per
    successfully called read. If a modified-base output file is requested it
    is created as an HDF5 file holding a 'Reads' group and a
    'mod_long_names' dataset. Summary statistics are written to stderr.

    Exits with status 1 if modified-base output is requested for a model
    whose last sublayer is not a `GlobalNormFlipFlopCatMod`.
    """
    args = parser.parse_args()
    device = helpers.set_torch_device(args.device)

    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(
        model.sublayers[-1], layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.\n")
        # Non-zero status so that calling scripts can detect the failure.
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model)
    # Command-line chunk size and overlap are given in units of the stride.
    chunk_size = args.chunk_size * stride
    chunk_overlap = args.overlap * stride

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = list(
        fast5utils.iterate_fast5_reads(
            args.input_folder, limit=args.limit,
            strand_list=args.input_strand_list, recursive=args.recursive))
    sys.stderr.write("* Found {} reads.\n".format(len(fast5_reads)))

    if args.scaling is not None:
        sys.stderr.write(
            "* Loading read scaling parameters from {}.\n".format(
                args.scaling))
        all_read_params = get_per_read_params_dict_from_tsv(args.scaling)
        input_read_ids = frozenset(rec[1] for rec in fast5_reads)
        scaling_read_ids = frozenset(all_read_params.keys())
        sys.stderr.write(
            "* {} / {} reads have scaling information.\n".format(
                len(input_read_ids & scaling_read_ids), len(input_read_ids)))
        # Only reads with scaling information are basecalled.
        fast5_reads = [
            rec for rec in fast5_reads if rec[1] in scaling_read_ids]
    else:
        all_read_params = {}

    mods_fp = None
    try:
        # The output file is opened inside the try so that it is closed even
        # if preparing its contents fails.
        if do_output_mods:
            # Explicit write mode: h5py >= 3 opens read-only by default,
            # which would make the create_* calls below fail.
            mods_fp = h5py.File(args.modified_base_output, 'w')
            mods_fp.create_group('Reads')
            mod_long_names = model.sublayers[-1].ordered_mod_long_names
            sys.stderr.write(
                "* Preparing modified base output: {}.\n".format(
                    ', '.join(map(str, mod_long_names))))
            mods_fp.create_dataset(
                'mod_long_names', data=np.array(mod_long_names, dtype='S'),
                dtype=h5py.special_dtype(vlen=str))

        sys.stderr.write("* Calling reads.\n")
        nbase, ncalled, nread, nsample = 0, 0, 0, 0
        t0 = time.time()
        progress = Progress(quiet=args.quiet)
        startcharacter = '@' if args.fastq else '>'
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                read_params = all_read_params[
                    read_id] if read_id in all_read_params else None
                basecall, qstring, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size,
                    chunk_overlap, read_params, n_can_states, stride,
                    args.alphabet, is_cat_mod, mods_fp,
                    args.max_concurrent_chunks, args.fastq,
                    args.qscore_scale, args.qscore_offset)
                if basecall is not None:
                    fh.write("{}{}\n{}\n".format(
                        startcharacter, read_id,
                        basecall[::-1] if args.reverse else basecall))
                    nbase += len(basecall)
                    ncalled += 1
                    if args.fastq:
                        fh.write("+\n{}\n".format(
                            qstring[::-1] if args.reverse else qstring))
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()

    total_time = time.time() - t0
    # Report the real elapsed time; truncating to whole seconds first made
    # the two decimal places of the format meaningless.
    sys.stderr.write("* Called {} reads in {:.2f}s\n".format(
        nread, total_time))
    sys.stderr.write(
        "* {:7.2f} kbase / s\n".format(nbase / total_time / 1000.0))
    sys.stderr.write(
        "* {:7.2f} ksample / s\n".format(nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
def oneread_remap(read_tuple, references, model, device,
                  per_read_params_dict, alphabet_info):
    """Worker function for remapping reads using flip-flop model on raw signal

    :param read_tuple: read, identified by a tuple (filepath, read_id)
    :param references: mapping from read_id to reference sequence
    :param model: pytorch model (the torch data structure, not a filename)
    :param device: integer specifying which GPU to use for remapping, or
        'cpu' to use CPU
    :param per_read_params_dict: dictionary where keys are read ids and
        values are dicts containing keys trim_start, trim_end, shift, scale
    :param alphabet_info: AlphabetInfo object for basecalling

    :returns: tuple of (read dictionary, status text). The dictionary is
        None and the text names the failure when the read could not be
        loaded, has no reference, has no parameters, or the network could
        not be applied.
    """
    filename, read_id = read_tuple

    try:
        with fast5_interface.get_fast5_file(filename, 'r') as f5file:
            sig = signal.Signal(f5file.get_read(read_id))
    except Exception:
        return None, READ_ID_INFO_NOT_FOUND_ERR_TEXT

    if read_id not in references:
        return None, NO_REF_FOUND_ERR_TEXT
    read_ref = references[read_id]

    try:
        read_params_dict = per_read_params_dict[read_id]
    except KeyError:
        return None, NO_PARAMS_ERR_TEXT

    sig.set_trim_absolute(read_params_dict['trim_start'],
                          read_params_dict['trim_end'])

    try:
        # Prevents torch doing its own parallelisation on top of our
        # imap_map
        torch.set_num_threads(1)
        # Standardise the signal with the per-read shift and scale
        shifted = sig.current - read_params_dict['shift']
        standardised = shifted / read_params_dict['scale']
        # Two trailing axes of length one are added and the result is
        # created on the requested device
        net_input = torch.tensor(
            standardised[:, None, None].astype(taiyaki_dtype),
            device=device)
        # The model must live on the same device as its input
        net = model.to(device)
        # Apply the network and bring the output back as a numpy array
        with torch.no_grad():
            transweights = net(net_input).cpu().numpy()
    except Exception:
        return None, REMAP_ERR_TEXT

    can_read_ref = alphabet_info.collapse_sequence(read_ref)
    # np.squeeze removes the length-one axes introduced above.
    remappingscore, path = flipflop_remap.flipflop_remap(
        np.squeeze(transweights), can_read_ref,
        alphabet=alphabet_info.can_bases, localpen=0.0)

    # flipflop_remap() relates network outputs to the reference; the model
    # stride is needed to turn that into a signal-to-reference mapping.
    model_stride = helpers.guess_model_stride(model)
    remapping = mapping.Mapping.from_remapping_path(
        sig, path, read_ref, model_stride)
    remapping.add_integer_reference(alphabet_info.alphabet)

    read_dict = remapping.get_read_dictionary(
        read_params_dict['shift'], read_params_dict['scale'], read_id)
    return read_dict, REMAP_SUCCESS_TEXT
def load_network(args, alphabet_info, res_info, log):
    """Load the network to train and build its optimiser and LR schedule.

    On the lead process the model is loaded from `args.model`, checked
    against `alphabet_info` and saved to `args.outdir` as checkpoint 0. In
    multi-GPU mode every process then waits at a barrier, reloads that
    checkpoint and wraps it in `DistributedDataParallel`; otherwise the
    model loaded by the lead process is moved to `res_info.device`.

    Args:
        args: Parsed command-line arguments (model, stride, winlen, size,
            reverse, standardize, outdir, local_rank, lr_*, adam,
            weight_decay, eps, warmup_batches, niteration, min_momentum,
            gradient_clip_num_mads are read).
        alphabet_info: Alphabet description used for compatibility checks.
        res_info: Resource information (is_lead_process, is_multi_gpu and
            device are read).
        log: Object with a `write` method used for progress messages.

    Returns:
        tuple: `(NETWORK_INFO, OPTIM_INFO)`.

    Exits with status 1 when the model and alphabet are incompatible or
    when `--warmup_batches` is not smaller than `--niteration`.
    """
    log.write('* Reading network from {}\n'.format(args.model))
    if res_info.is_lead_process:
        # Under pytorch's DistributedDataParallel scheme, we need a clone of
        # the start network to use as a template for saving checkpoints.
        # Necessary because DistributedParallel makes the class structure
        # different.
        model_kwargs = {
            'stride': args.stride,
            'winlen': args.winlen,
            'insize': 1,
            'size': args.size,
            'alphabet_info': alphabet_info
        }
        model_metadata = {
            'reverse': args.reverse,
            'standardize': args.standardize
        }
        net_clone = helpers.load_model(
            args.model, model_metadata=model_metadata, **model_kwargs)
        log.write('* Network has {} parameters.\n'.format(
            sum(p.nelement() for p in net_clone.parameters())))

        if not alphabet_info.is_compatible_model(net_clone):
            sys.stderr.write(
                '* ERROR: Model and mapped signal files contain ' +
                'incompatible alphabet definitions (including modified ' +
                'bases).')
            sys.exit(1)
        if layers.is_cat_mod_model(net_clone):
            log.write('* Loaded categorical modified base model.\n')
            if not alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Modified bases model specified, but mapped ' +
                    'signal file does not contain modified bases.')
                sys.exit(1)
        else:
            log.write('* Loaded standard (canonical bases-only) model.\n')
            if alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Standard (canonical bases only) model ' +
                    'specified, but mapped signal file does contains ' +
                    'modified bases.')
                sys.exit(1)
        # model_metadata is a plain dict, so it must be indexed; attribute
        # access would raise AttributeError for every delta model.
        if layers.is_delta_model(net_clone) and \
                model_metadata['standardize']:
            # Trailing newline keeps the banner separate from the next
            # log message.
            log.write('*' * 60 + '\n* WARNING: Delta-scaling models trained ' +
                      'with --standardize are not compatible with Guppy.\n' +
                      '*' * 60 + '\n')
        log.write('* Dumping initial model\n')
        helpers.save_model(net_clone, args.outdir, 0)
    else:
        net_clone = None

    if res_info.is_multi_gpu:
        # so that processes 1,2,3.. don't try to load before process 0 has
        # saved
        torch.distributed.barrier()
        log.write('* MultiGPU process {}'.format(args.local_rank))
        log.write(': loading initial model saved by process 0\n')
        saved_startmodel_path = os.path.join(
            args.outdir, 'model_checkpoint_00000.checkpoint')
        network = helpers.load_model(saved_startmodel_path).to(
            res_info.device)
        network_metadata = parse_network_metadata(network)
        # Wrap network for training in the DistributedDataParallel structure
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        log.write('* Loading model onto device\n')
        network = net_clone.to(res_info.device)
        network_metadata = parse_network_metadata(network)
        net_clone = None

    log.write('* Estimating filter parameters from training data\n')
    stride = guess_model_stride(network)

    optimiser = torch.optim.AdamW(
        network.parameters(), lr=args.lr_max, betas=args.adam,
        weight_decay=args.weight_decay, eps=args.eps)

    lr_warmup = args.lr_min if args.lr_warmup is None else args.lr_warmup
    adam_beta1, _ = args.adam

    if args.warmup_batches >= args.niteration:
        sys.stderr.write('* Error: --warmup_batches must be < --niteration\n')
        sys.exit(1)

    warmup_fraction = args.warmup_batches / args.niteration
    # Pytorch OneCycleLR crashes if pct_start==1 (i.e. warmup_fraction==1)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimiser,
        args.lr_max,
        total_steps=args.niteration,
        # pct_start is really fractional, not percent
        pct_start=warmup_fraction,
        div_factor=args.lr_max / lr_warmup,
        final_div_factor=lr_warmup / args.lr_min,
        cycle_momentum=(args.min_momentum is not None),
        base_momentum=(adam_beta1 if args.min_momentum is None
                       else args.min_momentum),
        max_momentum=adam_beta1
    )
    log.write(
        ('* Learning rate increases from {:.2e} to {:.2e} over {} ' +
         'iterations using cosine schedule.\n').format(
             lr_warmup, args.lr_max, args.warmup_batches))
    log.write(
        ('* Then learning rate decreases from {:.2e} to {:.2e} over ' +
         '{} iterations using cosine schedule.\n').format(
             args.lr_max, args.lr_min,
             args.niteration - args.warmup_batches))

    if args.gradient_clip_num_mads is None:
        log.write('* No gradient clipping\n')
        rolling_mads = None
    else:
        nparams = sum(1 for p in network.parameters() if p.requires_grad)
        if nparams == 0:
            rolling_mads = None
            log.write('* No gradient clipping due to missing parameters\n')
        else:
            rolling_mads = maths.RollingMAD(
                nparams, args.gradient_clip_num_mads)
            log.write((
                '* Gradients will be clipped (by value) at {:3.2f} MADs ' +
                'above the median of the last {} gradient maximums.\n'
            ).format(rolling_mads.n_mads, rolling_mads.window))

    net_info = NETWORK_INFO(
        net=network, net_clone=net_clone, metadata=network_metadata,
        stride=stride)
    optim_info = OPTIM_INFO(
        optimiser=optimiser, lr_warmup=lr_warmup, lr_scheduler=lr_scheduler,
        rolling_mads=rolling_mads)
    return net_info, optim_info
def oneread_remap(read_tuple, references, model, device,
                  per_read_params_dict, alphabet, collapse_alphabet):
    """Worker function for remapping reads using flip-flop model on raw signal

    Any single failure is reported on stderr and turned into a None return
    so that other reads in the batch are not disrupted.

    :param read_tuple: read, identified by a tuple (filepath, read_id)
    :param references: mapping from read_id to reference sequence stored as
        bytes (decoded here as UTF-8)
    :param model: pytorch model (the torch data structure, not a filename)
    :param device: integer specifying which GPU to use for remapping, or
        'cpu' to use CPU
    :param per_read_params_dict: dictionary where keys are read ids and
        values are dicts containing keys trim_start, trim_end, shift, scale
    :param alphabet: alphabet for basecalling (passed on to mapped-read
        file)
    :param collapse_alphabet: collapsed alphabet for basecalling (passed on
        to mapped-read file)

    :returns: dictionary as specified in mapped_signal_files.Read class, or
        None if the read could not be remapped
    """
    filename, read_id = read_tuple
    try:
        with fast5_interface.get_fast5_file(filename, 'r') as f5file:
            read = f5file.get_read(read_id)
            sig = signal.Signal(read)
    except Exception as e:
        # We want any single failure in the batch of reads to not disrupt
        # other reads being processed.
        sys.stderr.write(
            'No read information on read {} found in file {}.\n{}\n'.format(
                read_id, filename, repr(e)))
        return None

    if read_id in references:
        read_ref = references[read_id].decode("utf-8")
    else:
        sys.stderr.write('No fasta reference found for {}.\n'.format(read_id))
        return None

    if read_id in per_read_params_dict:
        read_params_dict = per_read_params_dict[read_id]
    else:
        # Report the skipped read like every other failure path instead of
        # dropping it silently.
        sys.stderr.write(
            'No per-read parameters found for {}.\n'.format(read_id))
        return None

    sig.set_trim_absolute(read_params_dict['trim_start'],
                          read_params_dict['trim_end'])

    try:
        # Prevents torch doing its own parallelisation on top of our
        # imap_map
        torch.set_num_threads(1)
        # Standardise the signal with the per-read shift and scale
        signalArray = (
            sig.current - read_params_dict['shift']
        ) / read_params_dict['scale']
        # Two trailing axes of length one are added and the tensor is
        # created on the appropriate device (GPU number or CPU)
        signalTensor = torch.tensor(
            signalArray[:, np.newaxis, np.newaxis].astype(taiyaki_dtype),
            device=device)
        # The model must live on the same device
        modelOnDevice = model.to(device)
        # Apply the network to the signal, generating transition weight
        # matrix, and put it back into a numpy array
        with torch.no_grad():
            transweights = modelOnDevice(signalTensor).cpu().numpy()
    except Exception as e:
        sys.stderr.write(
            "Failure applying basecall network to remap read {}.\n{}\n".format(
                sig.read_id, repr(e)))
        return None

    # Extra dimensions introduced by np.newaxis above removed by np.squeeze
    # localpen=0.0 does local alignment
    remappingscore, path = flipflop_remap.flipflop_remap(
        np.squeeze(transweights), read_ref, localpen=0.0)

    # flipflop_remap() establishes a mapping between the network outputs and
    # the reference. What we need is a mapping between the signal and the
    # reference. To resolve this we need to know the stride of the model
    # (how many samples for each network output)
    model_stride = helpers.guess_model_stride(model, device=device)
    remapping = mapping.Mapping.from_remapping_path(
        sig, path, read_ref, model_stride)

    return remapping.get_read_dictionary(
        read_params_dict['shift'], read_params_dict['scale'], read_id,
        alphabet=alphabet, collapse_alphabet=collapse_alphabet)
def main():
    """Basecall all reads found under the input folder on a GPU.

    Loads the model onto the requested device, iterates over the fast5
    reads and writes one FASTA record per successfully called read. If a
    modified-base output file is requested it is created as an HDF5 file
    holding a 'Reads' group and a 'mod_long_names' dataset. Summary
    statistics are written to stderr.

    Exits with status 1 if the device is 'cpu', or if modified-base output
    is requested for a model whose last sublayer is not a
    `GlobalNormFlipFlopCatMod`.
    """
    args = parser.parse_args()
    # Explicit check rather than `assert`: assertions are removed when
    # Python runs with -O, which would let an unsupported device through.
    if args.device == 'cpu':
        sys.stderr.write(
            "Flipflop basecalling in taiyaki requires a GPU and for cupy "
            "to be installed\n")
        sys.exit(1)
    device = torch.device(args.device)

    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(
        model.sublayers[-1], layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.\n")
        # Non-zero status so that calling scripts can detect the failure.
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model, device=device)
    chunk_size, chunk_overlap = basecall_helpers.round_chunk_values(
        args.chunk_size, args.overlap, stride)

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit,
        strand_list=args.input_strand_list, recursive=args.recursive)

    mods_fp = None
    try:
        # The output file is opened inside the try so that it is closed even
        # if preparing its contents fails.
        if do_output_mods:
            # Explicit write mode: h5py >= 3 opens read-only by default,
            # which would make the create_* calls below fail.
            mods_fp = h5py.File(args.modified_base_output, 'w')
            mods_fp.create_group('Reads')
            mod_long_names = model.sublayers[-1].ordered_mod_long_names
            sys.stderr.write(
                "* Preparing modified base output: {}.\n".format(
                    ', '.join(map(str, mod_long_names))))
            mods_fp.create_dataset(
                'mod_long_names', data=np.array(mod_long_names, dtype='S'),
                dtype=h5py.special_dtype(vlen=str))

        sys.stderr.write("* Calling reads.\n")
        nbase, ncalled, nread, nsample = 0, 0, 0, 0
        t0 = time.time()
        progress = Progress(quiet=args.quiet)
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                basecall, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size,
                    chunk_overlap, device, n_can_states, stride,
                    args.alphabet, is_cat_mod, mods_fp)
                if basecall is not None:
                    fh.write(">{}\n{}\n".format(read_id, basecall))
                    nbase += len(basecall)
                    ncalled += 1
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()

    total_time = time.time() - t0
    sys.stderr.write(
        "* Called {} reads in {}s\n".format(nread, int(total_time)))
    sys.stderr.write(
        "* {:7.2f} kbase / s\n".format(nbase / total_time / 1000.0))
    sys.stderr.write(
        "* {:7.2f} ksample / s\n".format(nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
def run_model(normed_signal, model, chunk_size=_DEFAULT_CHUNK_SIZE,
              overlap=_DEFAULT_OVERLAP, max_concur_chunks=None,
              return_numpy=True, return_tensor_on_device=True):
    """Hook for megalodon to run network via taiyaki.

    Note:
        The `chunk_size` and `overlap` parameters are given as multiples of
        the stride of `model` rather than as a number of samples. This
        behaviour is consistent with the parameterisation in Guppy.

    Args:
        normed_signal (:class:`ndarray`): Signal of read, which will be
            chunked and the result of calling and stitching returned.
        model (:class:`layers.Serial`): A Taiyaki model, implicitly assumed
            to have a :class:`layers.Serial` as its outermost layer and for
            the first wrapped layer to have parameters.
        chunk_size (int, optional): Length of the chunks into which
            `normed_signal` will be split.
        overlap (int, optional): Overlap between one chunk and the next.
        max_concur_chunks (int, optional): Calculate chunks in batches of
            size at most `max_concur_chunks`; if None, then all chunks are
            calculated at once.
        return_numpy (bool, optional): Return value should be converted to
            a :class:`ndarray` (default).
        return_tensor_on_device (bool, optional): Return value should be
            moved back onto the same device as the model (default).
            Overridden by `return_numpy`.

    Returns:
        :class:`Tensor`: Output of basecalling chunks, stitched together.
        If `return_tensor_on_device` is True, the stitched chunks are
        transferred back to the model's device; otherwise they are returned
        in host memory ("cpu" device). If `return_numpy` is True, the
        return value is instead a :class:`ndarray` in host memory.
    """
    model_device = get_model_device(model)
    stride = guess_model_stride(model)
    samples_per_chunk = chunk_size * stride
    samples_overlap = overlap * stride
    chunks, chunk_starts, chunk_ends = chunk_read(
        normed_signal, samples_per_chunk, samples_overlap)
    chunk_tensor = torch.tensor(chunks)

    with torch.no_grad():
        if max_concur_chunks is None:
            out = model(chunk_tensor.to(model_device)).cpu()
        else:
            # Evaluate a batch of chunks at a time, moving each result to
            # host memory before the next batch is sent to the device.
            batch_outputs = [
                model(batch.to(model_device)).cpu()
                for batch in torch.split(chunk_tensor, max_concur_chunks, 1)
            ]
            out = torch.cat(batch_outputs, 1)

    stitched = stitch_chunks(out, chunk_starts, chunk_ends, stride)
    if return_numpy:
        return stitched.numpy()
    if not return_tensor_on_device:
        return stitched
    return stitched.to(model_device)
def _load_taiyaki_model(self):
    """Import the taiyaki backend and record model properties on `self`.

    Assigns one device per process, imports taiyaki and pytorch (exiting
    with status 1 if either import fails), loads the model named in
    `self.params.taiyaki.taiyaki_model_fn` and stores its stride, output
    size and alphabet-related attributes.

    Raises:
        NotImplementedError: If the model is not a categorical
            modified-base model and its output size does not correspond to
            4 bases.
    """
    LOGGER.info('Loading taiyaki basecalling backend.')
    self.model_type = TAI_NAME

    devices = self.params.taiyaki.devices
    if devices is None:
        devices = ['cpu', ]
    # Repeat the device indices often enough to cover every process, then
    # truncate, so that processes are assigned to devices in rotation.
    n_repeats = (self.num_proc // len(devices)) + 1
    device_indices = np.tile(np.arange(len(devices)), n_repeats)
    self.process_devices = [
        parse_device(devices[dev_i]) for dev_i in device_indices
    ][:self.num_proc]

    # import modules
    try:
        from taiyaki.helpers import (
            load_model as load_taiyaki_model, guess_model_stride)
        from taiyaki.basecall_helpers import run_model as tai_run_model
        from taiyaki.layers import GlobalNormFlipFlopCatMod
    except ImportError:
        LOGGER.error(
            'Failed to import taiyaki. Ensure working ' +
            'installations to run megalodon')
        sys.exit(1)
    try:
        import torch
    except ImportError:
        LOGGER.error(
            'Failed to import pytorch. Ensure working ' +
            'installations to run megalodon')
        sys.exit(1)

    # store modules in object
    self.load_taiyaki_model = load_taiyaki_model
    self.tai_run_model = tai_run_model
    self.torch = torch

    # This model instance is only inspected for its properties here.
    probe_model = self.load_taiyaki_model(
        self.params.taiyaki.taiyaki_model_fn)
    last_layer = probe_model.sublayers[-1]
    self.is_cat_mod = (
        GlobalNormFlipFlopCatMod is not None and
        isinstance(last_layer, GlobalNormFlipFlopCatMod))
    self.stride = guess_model_stride(probe_model)
    self.output_size = last_layer.size

    if self.is_cat_mod:
        # Modified base model is defined by 3 fixed fields in taiyaki
        # can_nmods, output_alphabet and modified_base_long_names
        self.output_alphabet = last_layer.output_alphabet
        self.can_nmods = last_layer.can_nmods
        self.ordered_mod_long_names = last_layer.ordered_mod_long_names
        self.compute_mod_alphabet_attrs()
        return

    if mh.nstate_to_nbase(self.output_size) != 4:
        raise NotImplementedError(
            'Naive modified base flip-flop models are not ' +
            'supported.')
    self.output_alphabet = mh.ALPHABET
    self.can_alphabet = mh.ALPHABET
    self.mod_long_names = []
    self.str_to_int_mod_labels = {}
    self.can_nmods = None
    self.n_mods = len(self.mod_long_names)
def oneread_remap(read_tuple, model, per_read_params_dict, alphabet_info,
                  max_read_length, device='cpu', localpen=0.0):
    """Worker function for remapping reads using flip-flop model on raw signal

    Args:
        read_tuple (tuple): read, identified by a tuple
            (filepath, read_id, read reference)
        model (pytorch Module): pytorch model
        per_read_params_dict (dict): dictionary where keys are read ids and
            values are the per-read parameters handed to `signal.Signal`
        alphabet_info (AlphabetInfo object): for basecalling
        max_read_length (int): Don't attempt to remap reads with references
            longer than this; None disables the check
        device (int or str): integer specifying which GPU to use for
            remapping, or 'cpu' to use CPU
        localpen (float): Penalty for local mapping

    Returns:
        tuple: (dict, result) containing
            1. dictionary as specified in
               signal_mapping.SignalMapping.get_read_dictionary, or None on
               failure
            2. a `RemapResult` value, or the text of a
               `TaiyakiSigMapError` if building the dictionary failed
    """
    filename, read_id, read_ref = read_tuple

    if read_ref is None:
        return None, RemapResult.NO_REF_FOUND
    too_long = max_read_length is not None and \
        len(read_ref) > max_read_length
    if too_long:
        return None, RemapResult.REF_TOO_LONG

    try:
        read_params_dict = per_read_params_dict[read_id]
    except KeyError:
        return None, RemapResult.NO_PARAMS

    try:
        with fast5_interface.get_fast5_file(filename, 'r') as f5file:
            sig = signal.Signal(
                f5file.get_read(read_id), read_params=read_params_dict)
    except Exception:
        return None, RemapResult.READ_ID_INFO_NOT_FOUND

    try:
        # Prevents torch doing its own parallelisation on top of our
        # imap_map
        torch.set_num_threads(1)
        # Two trailing axes of length one are added to the standardised
        # signal and the tensor is created on the requested device
        net_input = torch.tensor(
            sig.standardized_current[:, None, None].astype(np.float32),
            device=device)
        # The model must live on the same device as its input
        net = model.to(device)
        # Apply the network and bring the output back as a numpy array
        with torch.no_grad():
            transweights = net(net_input).cpu().numpy()
    except Exception:
        return None, RemapResult.NETWORK_ERROR

    can_read_ref = alphabet_info.collapse_sequence(read_ref)
    # np.squeeze removes the length-one axes introduced above.
    remappingscore, path = flipflop_remap.flipflop_remap(
        np.squeeze(transweights), can_read_ref,
        alphabet=alphabet_info.can_bases, localpen=localpen)

    # flipflop_remap() relates network outputs to the reference; the model
    # stride (samples per network output) is needed to turn that into a
    # signal-to-reference mapping.
    model_stride = helpers.guess_model_stride(model)
    int_ref = signal_mapping.SignalMapping.get_integer_reference(
        read_ref, alphabet_info.alphabet)
    sig_mapping = signal_mapping.SignalMapping.from_remapping_path(
        path, int_ref, model_stride, sig)

    try:
        sig_mapping_dict = sig_mapping.get_read_dictionary()
    except signal_mapping.TaiyakiSigMapError as err:
        return None, str(err)
    return sig_mapping_dict, RemapResult.SUCCESS