def test_load_model_from_file_with_wrong_metadata(self, filename):
    metadata = {'reverse': True, 'standardize': False}
    filepath = os.path.join(MODELS_DIR, filename)
    with self.assertWarns(RuntimeWarning):
        net = helpers.load_model(filepath, model_metadata=metadata)
    self.assertEqual(metadata['reverse'], net.metadata['reverse'])
    self.assertEqual(metadata['standardize'], net.metadata['standardize'])
def main():
    args = get_parser().parse_args()

    worker_kwarg_names = ['back_prob', 'localpen', 'minscore', 'trim']

    model = helpers.load_model(args.model)

    fast5_reads = fast5utils.iterate_fast5_reads(
        args.read_dir, limit=args.limit,
        strand_list=args.input_strand_list, recursive=args.recursive)

    with helpers.open_file_or_stdout(args.output) as fh:
        for res in imap_mp(
                squiggle_match.worker, fast5_reads, threads=args.jobs,
                fix_kwargs=helpers.get_kwargs(args, worker_kwarg_names),
                unordered=True, init=squiggle_match.init_worker,
                initargs=[model, args.references]):
            if res is None:
                continue
            read_id, sig, score, path, squiggle, bases = res
            bases = bases.decode('ascii')
            fh.write('#{} {}\n'.format(read_id, score))
            for i, (s, p) in enumerate(zip(sig, path)):
                fh.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                    read_id, i, s, p, bases[p], squiggle[p, 0],
                    squiggle[p, 1], squiggle[p, 2]))
def main(argv):
    """Main function to process mapping for each read using functions in
    prepare_mapping_funcs"""
    args = parser.parse_args()
    print("Running prepare_mapping using flip-flop remapping")

    if not args.overwrite:
        if os.path.exists(args.output):
            print("Cowardly refusing to overwrite {}".format(args.output))
            sys.exit(1)

    # Make an iterator that yields all the reads we're interested in.
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit,
        strand_list=args.input_strand_list)

    # Set up arguments (kwargs) for the worker function for each read
    kwargs = helpers.get_kwargs(args,
                                ['alphabet', 'collapse_alphabet', 'device'])
    kwargs['per_read_params_dict'] = (
        prepare_mapping_funcs.get_per_read_params_dict_from_tsv(
            args.input_per_read_params))
    kwargs['references'] = helpers.fasta_file_to_dict(args.references)
    kwargs['model'] = helpers.load_model(args.model)

    # remaps a single read using the flip-flop network
    workerFunction = prepare_mapping_funcs.oneread_remap

    results = imap_mp(workerFunction, fast5_reads, threads=args.jobs,
                      fix_kwargs=kwargs, unordered=True)

    # results is an iterable of dicts;
    # each dict is a set of return values from a single read
    prepare_mapping_funcs.generate_output_from_results(results, args)
def main():
    args = parser.parse_args()

    model = load_model(args.model)
    json_out = model.json(args.params)

    with open_file_or_stdout(args.output) as fh:
        json.dump(json_out, fh, indent=4, cls=JsonEncoder)
def main():
    args = get_parser().parse_args()

    model_md5 = file_md5(args.model)
    model = load_model(args.model)
    json_out = model.json()
    json_out['md5sum'] = model_md5

    with open_file_or_stdout(args.output) as fh:
        json.dump(json_out, fh, indent=4, cls=JsonEncoder)
def main():
    args = parser.parse_args()
    # model_md5 = file_md5(args.model)
    model = load_model(args.model)

    print("Previous alphabet", model.sublayers[-1].output_alphabet)
    for attr in ["mod_bases", "mod_labels", "mod_name_conv",
                 "ordered_mod_long_names"]:
        if "mod" in attr:
            print(attr, getattr(model.sublayers[-1], attr))
    # print("Previous alphabet", dir(model.sublayers[-1]))

    model.sublayers[-1].output_alphabet = args.alphabet
    save_model(model, args.output)
def main():
    args = parser.parse_args()

    predict_squiggle = helpers.load_model(args.model)

    with helpers.open_file_or_stdout(args.output) as fh:
        for seq in SeqIO.parse(args.input, 'fasta'):
            seqstr = str(seq.seq)
            embedded_seq_numpy = np.expand_dims(
                squiggle_match.embed_sequence(seqstr), axis=1)
            embedded_seq_torch = torch.tensor(embedded_seq_numpy,
                                              dtype=torch.float32)

            with torch.no_grad():
                squiggle = np.squeeze(
                    predict_squiggle(embedded_seq_torch).cpu().numpy(),
                    axis=1)

            fh.write('base\tcurrent\tsd\tdwell\n')
            for base, (mean, logsd, dwell) in zip(seq.seq, squiggle):
                fh.write('{}\t{}\t{}\t{}\n'.format(
                    base, mean, np.exp(logsd), np.exp(-dwell)))
def worker_init(device, modelname, chunk_size, overlap, read_params, alphabet,
                max_concurrent_chunks, fastq, qscore_scale, qscore_offset,
                beam, posterior, temperature):
    global all_read_params
    global process_read_partial

    all_read_params = read_params
    device = helpers.set_torch_device(device)
    model = load_model(modelname).to(device)
    stride = guess_model_stride(model)
    chunk_size = chunk_size * stride
    overlap = overlap * stride

    n_can_base = len(alphabet)
    n_can_state = nstate_flipflop(n_can_base)

    def process_read_partial(read_filename, read_id, read_params):
        res = process_read(read_filename, read_id, model, chunk_size, overlap,
                           read_params, n_can_state, stride, alphabet,
                           max_concurrent_chunks, fastq, qscore_scale,
                           qscore_offset, beam, posterior, temperature)
        return (read_id, *res)
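# The process_read_partial closure installed above is only useful once it is
# called from the per-read worker that each pool process runs. A minimal
# sketch of that call site, assuming the module-level globals set by
# worker_init and that all_read_params is a dict keyed by read id (as in the
# scaling-parameter loading shown elsewhere in this section); the worker name
# itself is hypothetical, not taken from the source.
def worker(fast5_read_tuple):
    # all_read_params and process_read_partial are globals installed by
    # worker_init, so each pool process loads the model exactly once.
    read_filename, read_id = fast5_read_tuple
    read_params = (all_read_params.get(read_id)
                   if all_read_params else None)
    return process_read_partial(read_filename, read_id, read_params)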
def main():
    """Main function to process mapping for each read using functions in
    prepare_mapping_funcs"""
    args = parser.parse_args()
    print("Running prepare_mapping using flip-flop remapping")

    if not args.overwrite:
        if os.path.exists(args.output):
            print("Cowardly refusing to overwrite {}".format(args.output))
            sys.exit(1)

    # Create alphabet and check for consistency
    modified_bases = [elt[0] for elt in args.mod]
    canonical_bases = [elt[1] for elt in args.mod]
    for b in modified_bases:
        assert len(b) == 1, (
            "Modified bases must be a single character, got {}".format(b))
        assert b not in args.alphabet, (
            "Modified base must not be a canonical base, got {}".format(b))
    for b in canonical_bases:
        assert len(b) == 1, (
            "Canonical coding for modified bases must be a single "
            "character, got {}".format(b))
        assert b in args.alphabet, (
            "Canonical coding for modified base must be a canonical "
            "base, got {}".format(b))
    full_alphabet = args.alphabet + ''.join(modified_bases)
    flat_alphabet = args.alphabet + ''.join(canonical_bases)
    modification_names = [elt[2] for elt in args.mod]

    alphabet_info = alphabet.AlphabetInfo(full_alphabet, flat_alphabet,
                                          modification_names,
                                          do_reorder=True)

    print("Converting references to labels using {}".format(
        str(alphabet_info)))

    # Make an iterator that yields all the reads we're interested in.
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit,
        strand_list=args.input_strand_list, recursive=args.recursive)

    # Set up arguments (kwargs) for the worker function for each read
    kwargs = {}
    kwargs['per_read_params_dict'] = (
        prepare_mapping_funcs.get_per_read_params_dict_from_tsv(
            args.input_per_read_params))
    kwargs['model'] = helpers.load_model(args.model)
    kwargs['alphabet_info'] = alphabet_info
    kwargs['max_read_length'] = args.max_read_length
    kwargs['localpen'] = args.localpen

    # remaps a single read using the flip-flop network
    workerFunction = prepare_mapping_funcs.oneread_remap

    def iter_jobs():
        references = bio.fasta_file_to_dict(args.references,
                                            alphabet=full_alphabet)
        for fn, read_id in fast5_reads:
            yield fn, read_id, references.get(read_id, None)

    if args.limit is not None:
        chunksize = args.limit // (2 * args.jobs)
        chunksize = int(np.clip(chunksize, 1, 50))
    else:
        chunksize = 50

    results = imap_mp(workerFunction, iter_jobs(), threads=args.jobs,
                      fix_kwargs=kwargs, unordered=True, chunksize=chunksize)

    # results is an iterable of dicts;
    # each dict is a set of return values from a single read
    prepare_mapping_funcs.generate_output_from_results(results, args.output,
                                                       alphabet_info)
def test_load_model_from_file_with_metadata(self, filename):
    metadata = {'reverse': False, 'standardize': False}
    filepath = os.path.join(MODELS_DIR, filename)
    net = helpers.load_model(filepath, model_metadata=metadata)
    self.assertEqual(metadata['reverse'], net.metadata['reverse'])
    self.assertEqual(metadata['standardize'], net.metadata['standardize'])
def main():
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        try:
            torch.cuda.set_device(device)
        except AttributeError:
            sys.stderr.write('ERROR: Torch not compiled with CUDA enabled ' +
                             'and GPU device set.')
            sys.exit(1)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        sys.stderr.write(('Error: Output directory {} exists but ' +
                          '--overwrite is false\n').format(args.output))
        exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.output))
        exit(1)

    copyfile(args.model, os.path.join(args.output, 'model.py'))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads
    # Result is a tuple median mean_dwell, mad mean_dwell
    # Choose a chunk length in the middle of the range for this
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, sampling_chunk_len,
        args, log, chunk_log=chunk_log)
    medmd, madmd = filter_parameters

    log.write(("* Sampled {} chunks: median(mean_dwell)={:.2f}, " +
               "mad(mean_dwell)={:.2f}\n").format(
                   args.sample_nreads_before_filtering, medmd, madmd))
    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(alphabet)
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    helpers.save_model(network, args.output, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0

    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()

        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, batch_size, batch_chunk_len, filter_parameters, args,
            log, chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current, device=device,
                              dtype=torch.float32).unsqueeze(2)

        # Sequence input tensor is just a 1D vector, and so is seqlens
        seqs = torch.tensor(np.concatenate([
            flipflopfings.flipflop_code(d['sequence'], nbase)
            for d in chunk_batch]), device=device, dtype=torch.long)
        seqlens = torch.tensor([len(d['sequence']) for d in chunk_batch],
                               dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens,
                                           args.sharpen)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold is set to zero then we
        # log everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            helpers.save_model(network, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (' {:5d} {:5.3f}  {:5.2f}s ({:.2f} ksample/s ' +
                 '{:.2f} kbase/s) lr={:.2e}')
            log.write(t.format((i + 1) // DOTROWLENGTH, score_smoothed.value,
                               dt, total_samples / 1000.0 / dt,
                               total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    helpers.save_model(network, args.output)
def main():
    args = parser.parse_args()

    device = helpers.set_torch_device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.")
        sys.exit()
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model)
    chunk_size = args.chunk_size * stride
    chunk_overlap = args.overlap * stride

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = list(fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit,
        strand_list=args.input_strand_list, recursive=args.recursive))
    sys.stderr.write("* Found {} reads.\n".format(len(fast5_reads)))

    if args.scaling is not None:
        sys.stderr.write(
            "* Loading read scaling parameters from {}.\n".format(
                args.scaling))
        all_read_params = get_per_read_params_dict_from_tsv(args.scaling)
        input_read_ids = frozenset(rec[1] for rec in fast5_reads)
        scaling_read_ids = frozenset(all_read_params.keys())
        sys.stderr.write("* {} / {} reads have scaling information.\n".format(
            len(input_read_ids & scaling_read_ids), len(input_read_ids)))
        fast5_reads = [rec for rec in fast5_reads
                       if rec[1] in scaling_read_ids]
    else:
        all_read_params = {}

    mods_fp = None
    if do_output_mods:
        mods_fp = h5py.File(args.modified_base_output)
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset(
            'mod_long_names', data=np.array(mod_long_names, dtype='S'),
            dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    startcharacter = '@' if args.fastq else '>'
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                read_params = (all_read_params[read_id]
                               if read_id in all_read_params else None)
                basecall, qstring, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size, chunk_overlap,
                    read_params, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp, args.max_concurrent_chunks,
                    args.fastq, args.qscore_scale, args.qscore_offset)
                if basecall is not None:
                    fh.write("{}{}\n{}\n".format(
                        startcharacter, read_id,
                        basecall[::-1] if args.reverse else basecall))
                    nbase += len(basecall)
                    ncalled += 1
                    if args.fastq:
                        fh.write("+\n{}\n".format(
                            qstring[::-1] if args.reverse else qstring))
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {:.2f}s\n".format(
        nread, total_time))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(
        nbase / total_time / 1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(
        nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
pickle_name = os.path.splitext(args.reference)[0] + '.pkl'
with open(pickle_name, 'wb') as fh:
    pickle.dump(seq_dict, fh)
log.write('* Written pickle of processed references to {} for future '
          'use.\n'.format(pickle_name))

log.write('* Reading network from {}\n'.format(args.model))
nbase = len(args.alphabet)
model_kwargs = {
    'size': args.size,
    'stride': args.stride,
    'winlen': args.winlen,
    # Number of input features to model e.g. was >1 for event-based models
    # (level, std, dwell)
    'insize': 1,
    'outsize': flipflopfings.nstate_flipflop(nbase)
}
network = helpers.load_model(args.model, **model_kwargs).to(device)
log.write('* Network has {} parameters.\n'.format(
    sum([p.nelement() for p in network.parameters()])))

optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                             betas=args.adam, eps=args.eps)
lr_scheduler = CosineAnnealingLR(optimizer, args.niteration)
score_smoothed = helpers.WindowedExpSmoother()

log.write('* Dumping initial model\n')
save_model(network, args.outdir, 0)

total_bases = 0
total_samples = 0
total_chunks = 0
              '{}W'.format(name),
              filterW.permute(0, 2, 1).reshape(-1, nf),
              nr=nf2 * winlen - nf2 + nf, nc=nfilter)
    cformatV(sys.stdout, '{}b'.format(name),
             convwrapper.conv.bias.reshape(-1))
    sys.stdout.write("#define {}stride {}\n".format(name, convwrapper.stride))
    sys.stdout.write("""#define {}nfilter {}
#define {}winlen {}
""".format(name, nfilter, name, winlen))


if __name__ == '__main__':
    args = parser.parse_args()
    modelid = args.id + '_'

    network = helpers.load_model(args.model)
    if isinstance(network.sublayers[0], DeltaSample):
        sys.stderr.write('* Removing initial DeltaSample layer\n')
        network.sublayers = network.sublayers[1:]

    sys.stdout.write("""#pragma once
#ifndef FLIPFLOP_{}MODEL_H
#define FLIPFLOP_{}MODEL_H
#include "../util.h"
""".format(modelid.upper(), modelid.upper()))

    """ Convolution layers """
    conv1 = network.sublayers[0]
    print_convolution(conv1, 'conv1_rnnrf_flipflop5_{}'.format(modelid),
def load_network(args, alphabet_info, res_info, log):
    log.write('* Reading network from {}\n'.format(args.model))
    if res_info.is_lead_process:
        # Under pytorch's DistributedDataParallel scheme, we need a clone of
        # the start network to use as a template for saving checkpoints.
        # Necessary because DistributedParallel makes the class structure
        # different.
        model_kwargs = {
            'stride': args.stride,
            'winlen': args.winlen,
            'insize': 1,
            'size': args.size,
            'alphabet_info': alphabet_info
        }
        model_metadata = {
            'reverse': args.reverse,
            'standardize': args.standardize
        }
        net_clone = helpers.load_model(args.model,
                                       model_metadata=model_metadata,
                                       **model_kwargs)
        log.write('* Network has {} parameters.\n'.format(
            sum(p.nelement() for p in net_clone.parameters())))

        if not alphabet_info.is_compatible_model(net_clone):
            sys.stderr.write(
                '* ERROR: Model and mapped signal files contain ' +
                'incompatible alphabet definitions (including modified ' +
                'bases).')
            sys.exit(1)
        if layers.is_cat_mod_model(net_clone):
            log.write('* Loaded categorical modified base model.\n')
            if not alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Modified bases model specified, but mapped ' +
                    'signal file does not contain modified bases.')
                sys.exit(1)
        else:
            log.write('* Loaded standard (canonical bases-only) model.\n')
            if alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Standard (canonical bases only) model ' +
                    'specified, but mapped signal file does contain ' +
                    'modified bases.')
                sys.exit(1)
        if layers.is_delta_model(net_clone) and model_metadata['standardize']:
            log.write('*' * 60 + '\n* WARNING: Delta-scaling models trained ' +
                      'with --standardize are not compatible with Guppy.\n' +
                      '*' * 60)
        log.write('* Dumping initial model\n')
        helpers.save_model(net_clone, args.outdir, 0)
    else:
        net_clone = None

    if res_info.is_multi_gpu:
        # so that processes 1,2,3.. don't try to load before process 0 has
        # saved
        torch.distributed.barrier()
        log.write('* MultiGPU process {}'.format(args.local_rank))
        log.write(': loading initial model saved by process 0\n')
        saved_startmodel_path = os.path.join(
            args.outdir, 'model_checkpoint_00000.checkpoint')
        network = helpers.load_model(saved_startmodel_path).to(
            res_info.device)
        network_metadata = parse_network_metadata(network)
        # Wrap network for training in the DistributedDataParallel structure
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        log.write('* Loading model onto device\n')
        network = net_clone.to(res_info.device)
        network_metadata = parse_network_metadata(network)
        net_clone = None

    log.write('* Estimating filter parameters from training data\n')
    stride = guess_model_stride(network)

    optimiser = torch.optim.AdamW(network.parameters(), lr=args.lr_max,
                                  betas=args.adam,
                                  weight_decay=args.weight_decay,
                                  eps=args.eps)

    lr_warmup = args.lr_min if args.lr_warmup is None else args.lr_warmup
    adam_beta1, _ = args.adam

    if args.warmup_batches >= args.niteration:
        sys.stderr.write('* Error: --warmup_batches must be < --niteration\n')
        sys.exit(1)
    warmup_fraction = args.warmup_batches / args.niteration

    # Pytorch OneCycleLR crashes if pct_start == 1 (i.e. warmup_fraction == 1)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimiser, args.lr_max, total_steps=args.niteration,
        # pct_start is really fractional, not percent
        pct_start=warmup_fraction,
        div_factor=args.lr_max / lr_warmup,
        final_div_factor=lr_warmup / args.lr_min,
        cycle_momentum=(args.min_momentum is not None),
        base_momentum=(adam_beta1 if args.min_momentum is None
                       else args.min_momentum),
        max_momentum=adam_beta1)

    log.write(('* Learning rate increases from {:.2e} to {:.2e} over {} ' +
               'iterations using cosine schedule.\n').format(
                   lr_warmup, args.lr_max, args.warmup_batches))
    log.write(('* Then learning rate decreases from {:.2e} to {:.2e} over ' +
               '{} iterations using cosine schedule.\n').format(
                   args.lr_max, args.lr_min,
                   args.niteration - args.warmup_batches))

    if args.gradient_clip_num_mads is None:
        log.write('* No gradient clipping\n')
        rolling_mads = None
    else:
        nparams = len([p for p in network.parameters() if p.requires_grad])
        if nparams == 0:
            rolling_mads = None
            log.write('* No gradient clipping due to missing parameters\n')
        else:
            rolling_mads = maths.RollingMAD(nparams,
                                            args.gradient_clip_num_mads)
            log.write((
                '* Gradients will be clipped (by value) at {:3.2f} MADs ' +
                'above the median of the last {} gradient maximums.\n').format(
                    rolling_mads.n_mads, rolling_mads.window))

    net_info = NETWORK_INFO(net=network, net_clone=net_clone,
                            metadata=network_metadata, stride=stride)
    optim_info = OPTIM_INFO(optimiser=optimiser, lr_warmup=lr_warmup,
                            lr_scheduler=lr_scheduler,
                            rolling_mads=rolling_mads)

    return net_info, optim_info
def main():
    args = parser.parse_args()

    is_multi_gpu = (args.local_rank is not None)
    is_lead_process = (not is_multi_gpu) or args.local_rank == 0

    if is_multi_gpu:
        # Use distributed parallel processing to run one process per GPU
        try:
            torch.distributed.init_process_group(backend='nccl')
        except Exception:
            raise Exception(
                "Unable to start multiprocessing group. " +
                "The most likely reason is that the script is running with " +
                "local_rank set but without the set-up for distributed " +
                "operation. local_rank should be used " +
                "only by torch.distributed.launch. See the README.")
        device = helpers.set_torch_device(args.local_rank)
        if args.seed is not None:
            # Make sure processes get different random picks of training data
            np.random.seed(args.seed + args.local_rank)
    else:
        device = helpers.set_torch_device(args.device)
        np.random.seed(args.seed)

    if is_lead_process:
        helpers.prepare_outdir(args.outdir, args.overwrite)
        if args.model.endswith('.py'):
            copyfile(args.model, os.path.join(args.outdir, 'model.py'))
        batchlog = helpers.BatchLog(args.outdir)
        logfile = os.path.join(args.outdir, 'model.log')
    else:
        logfile = None

    log = helpers.Logger(logfile, args.quiet)
    log.write(helpers.formatted_env_info(device))

    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet_info = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    log.write('* Using alphabet definition: {}\n'.format(str(alphabet_info)))

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads
    # Result is a tuple median mean_dwell, mad mean_dwell
    # Choose a chunk length in the middle of the range for this
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_params = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, sampling_chunk_len,
        args.filter_mean_dwell, args.filter_max_dwell)

    log.write("* Sampled {} chunks".format(
        args.sample_nreads_before_filtering))
    log.write(": median(mean_dwell)={:.2f}".format(
        filter_params.median_meandwell))
    log.write(", mad(mean_dwell)={:.2f}\n".format(
        filter_params.mad_meandwell))

    log.write('* Reading network from {}\n'.format(args.model))
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'alphabet_info': alphabet_info
    }

    if is_lead_process:
        # Under pytorch's DistributedDataParallel scheme, we need a clone of
        # the start network to use as a template for saving checkpoints.
        # Necessary because DistributedParallel makes the class structure
        # different.
        network_save_skeleton = helpers.load_model(args.model, **model_kwargs)
        log.write('* Network has {} parameters.\n'.format(
            sum([p.nelement()
                 for p in network_save_skeleton.parameters()])))
        if not alphabet_info.is_compatible_model(network_save_skeleton):
            sys.stderr.write(
                '* ERROR: Model and mapped signal files contain ' +
                'incompatible alphabet definitions (including modified ' +
                'bases).')
            sys.exit(1)
        if is_cat_mod_model(network_save_skeleton):
            log.write('* Loaded categorical modified base model.\n')
            if not alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Modified bases model specified, but mapped ' +
                    'signal file does not contain modified bases.')
                sys.exit(1)
        else:
            log.write('* Loaded standard (canonical bases-only) model.\n')
            if alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Standard (canonical bases only) model ' +
                    'specified, but mapped signal file does contain ' +
                    'modified bases.')
                sys.exit(1)
        log.write('* Dumping initial model\n')
        helpers.save_model(network_save_skeleton, args.outdir, 0)

    if is_multi_gpu:
        # so that processes 1,2,3.. don't try to load before process 0 has
        # saved
        torch.distributed.barrier()
        log.write('* MultiGPU process {}'.format(args.local_rank))
        log.write(': loading initial model saved by process 0\n')
        saved_startmodel_path = os.path.join(
            args.outdir, 'model_checkpoint_00000.checkpoint')
        network = helpers.load_model(saved_startmodel_path).to(device)

        # Wrap network for training in the DistributedDataParallel structure
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        network = network_save_skeleton.to(device)
        network_save_skeleton = None

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay,
                                 eps=args.eps)

    if args.lr_warmup is None:
        lr_warmup = args.lr_min
    else:
        lr_warmup = args.lr_warmup

    if args.lr_frac_decay is not None:
        lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_frac_decay,
                                          args.warmup_batches, lr_warmup)
        log.write('* Learning rate schedule lr_max*k/(k+t)')
        log.write(', k={}, t=iterations.\n'.format(args.lr_frac_decay))
    else:
        lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                    args.lr_cosine_iters,
                                                    args.warmup_batches,
                                                    lr_warmup)
        log.write('* Learning rate goes like cosine from lr_max to lr_min ')
        log.write('over {} iterations.\n'.format(args.lr_cosine_iters))
    log.write('* At start, train for {} '.format(args.warmup_batches))
    log.write('batches at warm-up learning rate {:3.2}\n'.format(lr_warmup))

    score_smoothed = helpers.WindowedExpSmoother()

    # prepare modified base parameter tensors
    network_is_catmod = is_cat_mod_model(network)
    mod_factor_t = torch.tensor(args.mod_factor,
                                dtype=torch.float32).to(device)
    can_mods_offsets = (network.sublayers[-1].can_mods_offsets
                        if network_is_catmod else None)
    # mod cat inv freq weighting is currently disabled. Compute and set this
    # value to enable mod cat weighting
    mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)

    # Generating list of batches for standard loss reporting
    reporting_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    reporting_batch_list = list(
        prepare_random_batches(device, read_data, reporting_chunk_len,
                               args.min_sub_batch_size,
                               args.reporting_sub_batches, alphabet_info,
                               filter_params, network, network_is_catmod,
                               log))
    log.write(('* Standard loss report: chunk length = {} & sub-batch size ' +
               '= {} for {} sub-batches.\n').format(
                   reporting_chunk_len, args.min_sub_batch_size,
                   args.reporting_sub_batches))

    # Set cap at very large value (before we have any gradient stats).
    gradient_cap = constants.LARGE_VAL
    if args.gradient_cap_fraction is None:
        log.write('* No gradient capping\n')
    else:
        rolling_quantile = maths.RollingQuantile(args.gradient_cap_fraction)
        log.write('* Gradient L2 norm cap will be upper' +
                  ' {:3.2f} quantile of the last {} norms.\n'.format(
                      args.gradient_cap_fraction, rolling_quantile.window))

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride

        # We choose the size of a sub-batch so that the size of the data in
        # the sub-batch is about the same as args.min_sub_batch_size chunks
        # of length args.chunk_len_max
        sub_batch_size = int(args.min_sub_batch_size * args.chunk_len_max /
                             batch_chunk_len + 0.5)

        optimizer.zero_grad()

        main_batch_gen = prepare_random_batches(
            device, read_data, batch_chunk_len, sub_batch_size,
            args.sub_batches, alphabet_info, filter_params, network,
            network_is_catmod, log)

        chunk_count, fval, chunk_samples, chunk_bases, batch_rejections = \
            calculate_loss(network, network_is_catmod, main_batch_gen,
                           args.sharpen, can_mods_offsets, mod_cat_weights,
                           mod_factor_t, calc_grads=True)

        gradnorm_uncapped = torch.nn.utils.clip_grad_norm_(
            network.parameters(), gradient_cap)
        if args.gradient_cap_fraction is not None:
            gradient_cap = rolling_quantile.update(gradnorm_uncapped)

        optimizer.step()

        if is_lead_process:
            batchlog.record(
                fval, gradnorm_uncapped,
                None if args.gradient_cap_fraction is None else gradient_cap)

        total_chunks += chunk_count
        total_samples += chunk_samples
        total_bases += chunk_bases

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        score_smoothed.update(fval)

        if (i + 1) % args.save_every == 0 and is_lead_process:
            helpers.save_model(network, args.outdir,
                               (i + 1) // args.save_every,
                               network_save_skeleton)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            _, rloss, _, _, _ = calculate_loss(network, network_is_catmod,
                                               reporting_batch_list,
                                               args.sharpen,
                                               can_mods_offsets,
                                               mod_cat_weights, mod_factor_t)

            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (' {:5d} {:7.5f} {:7.5f}  {:5.2f}s ({:.2f} ksample/s ' +
                 '{:.2f} kbase/s) lr={:.2e}')
            log.write(t.format((i + 1) // DOTROWLENGTH, score_smoothed.value,
                               rloss, dt, total_samples / 1000.0 / dt,
                               total_bases / 1000.0 / dt, learning_rate))

            # Write summary of chunk rejection reasons
            if args.full_filter_status:
                for k, v in rejection_dict.items():
                    log.write(" {}:{} ".format(k, v))
            else:
                n_tot = n_fail = 0
                for k, v in rejection_dict.items():
                    n_tot += v
                    if k != 'pass':
                        n_fail += v
                log.write(" {:.1%} chunks filtered".format(n_fail / n_tot))
            log.write("\n")

            total_bases = 0
            total_samples = 0
            t0 = tn

        # Uncomment the lines below to check synchronisation of models
        # between processes in multi-GPU operation
        # for p in network.parameters():
        #     v = p.data.reshape(-1)[:5].to('cpu')
        #     u = p.data.reshape(-1)[-5:].to('cpu')
        #     break
        # if args.local_rank is not None:
        #     log.write("* GPU{} params:".format(args.local_rank))
        # log.write("{}...{}\n".format(v, u))

        lr_scheduler.step()

    if is_lead_process:
        helpers.save_model(network, args.outdir,
                           model_skeleton=network_save_skeleton)
def main():
    args = parser.parse_args()

    assert args.device != 'cpu', (
        "Flipflop basecalling in taiyaki requires a GPU and for cupy to "
        "be installed")
    device = torch.device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.")
        sys.exit()
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model, device=device)
    chunk_size, chunk_overlap = basecall_helpers.round_chunk_values(
        args.chunk_size, args.overlap, stride)

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit,
        strand_list=args.input_strand_list, recursive=args.recursive)

    mods_fp = None
    if do_output_mods:
        mods_fp = h5py.File(args.modified_base_output)
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset(
            'mod_long_names', data=np.array(mod_long_names, dtype='S'),
            dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                basecall, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size, chunk_overlap,
                    device, n_can_states, stride, args.alphabet, is_cat_mod,
                    mods_fp)
                if basecall is not None:
                    fh.write(">{}\n{}\n".format(read_id, basecall))
                    nbase += len(basecall)
                    ncalled += 1
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {}s\n".format(
        nread, int(total_time)))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(
        nbase / total_time / 1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(
        nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
def test_load_model_from_file_with_no_metadata(self, filename):
    helpers.load_model(os.path.join(MODELS_DIR, filename))
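# The test_load_model_* methods above each take a filename argument, which
# suggests they are driven by a parameterisation decorator over the model
# files in MODELS_DIR. A hedged sketch of how such a test class might be
# wired, reusing the test module's existing helpers/MODELS_DIR imports;
# the parameterized decorator and the example filename are assumptions,
# not taken from the source.
import unittest

from parameterized import parameterized

MODEL_FILES = ['mGru_flipflop.py']  # hypothetical model file name


class LoadModelTest(unittest.TestCase):

    @parameterized.expand([(fn,) for fn in MODEL_FILES])
    def test_load_model_from_file_with_no_metadata(self, filename):
        helpers.load_model(os.path.join(MODELS_DIR, filename))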
#!/usr/bin/env python3
import argparse
import json

from taiyaki.cmdargs import AutoBool, FileExists, FileAbsent
from taiyaki.helpers import load_model
from taiyaki.json import JsonEncoder

parser = argparse.ArgumentParser(
    description='Dump JSON representation of model',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--out_file', default=None, action=FileAbsent,
                    help='Output JSON file to this file location')
parser.add_argument('--params', default=True, action=AutoBool,
                    help='Output parameters as well as model structure')
parser.add_argument('model', action=FileExists,
                    help='Model file to read from')


if __name__ == "__main__":
    args = parser.parse_args()

    model = load_model(args.model)
    json_out = model.json(args.params)

    if args.out_file is not None:
        with open(args.out_file, 'w') as f:
            print("Writing to file: ", args.out_file)
            json.dump(json_out, f, indent=4, cls=JsonEncoder)
    else:
        print(json.dumps(json_out, indent=4, cls=JsonEncoder))
log.write('* Reading network from {}\n'.format(args.model))
alphabet_info = alphabet.AlphabetInfo(args.alphabet, args.alphabet)
model_kwargs = {
    'size': args.size,
    'stride': args.stride,
    'winlen': args.winlen,
    # Number of input features to model e.g. was >1 for event-based models
    # (level, std, dwell)
    'insize': 1,
    'alphabet_info': alphabet_info
}
model_metadata = {'reverse': False, 'standardize': True}
network = helpers.load_model(args.model, model_metadata=model_metadata,
                             **model_kwargs).to(device)
log.write('* Network has {} parameters.\n'.format(
    sum([p.nelement() for p in network.parameters()])))

optimizer = torch.optim.AdamW(network.parameters(), lr=args.lr_max,
                              betas=args.adam, eps=args.eps,
                              weight_decay=args.weight_decay)
lr_scheduler = CosineAnnealingLR(optimizer, args.niteration)
score_smoothed = helpers.WindowedExpSmoother()

log.write('* Dumping initial model\n')
save_model(network, args.outdir, 0)
def main():
    args = parser.parse_args()

    log, loss_log, chunk_log, device = _setup_and_logs(args)

    read_data, alphabet_info = _load_data(args, log)
    # Get parameters for filtering by sampling a subset of the reads
    # Result is a tuple median mean_dwell, mad mean_dwell
    # Choose a chunk length in the middle of the range for this
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering,
        (args.chunk_len_min + args.chunk_len_max) // 2, args, log,
        chunk_log=chunk_log)
    log.write(("* Sampled {} chunks: median(mean_dwell)={:.2f}, " +
               "mad(mean_dwell)={:.2f}\n").format(
                   args.sample_nreads_before_filtering, *filter_parameters))

    log.write('* Reading network from {}\n'.format(args.model))
    model_kwargs = {
        'insize': 1,
        'winlen': args.winlen,
        'stride': args.stride,
        'size': args.size,
        'alphabet_info': alphabet_info
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    if not isinstance(network.sublayers[-1],
                      layers.GlobalNormFlipFlopCatMod):
        log.write(
            'ERROR: Model must end with GlobalNormCatModFlipFlop layer, ' +
            'not {}.\n'.format(str(network.sublayers[-1])))
        sys.exit(1)
    can_mods_offsets = network.sublayers[-1].can_mods_offsets
    flipflop_can_labels = network.sublayers[-1].can_labels
    flipflop_mod_labels = network.sublayers[-1].mod_labels
    flipflop_ncan_base = network.sublayers[-1].ncan_base
    log.write('* Loaded categorical modifications flip-flop model.\n')
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    if args.scale_mod_loss:
        try:
            mod_cat_weights = alphabet_info.compute_mod_inv_freq_weights(
                read_data, args.num_inv_freq_reads)
            log.write('* Modified base weights: {}\n'.format(
                str(mod_cat_weights)))
        except NotImplementedError:
            log.write(
                '* WARNING: Some mods not found when computing inverse ' +
                'frequency weights. Consider raising ' +
                '[--num_inv_freq_reads].\n')
            mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)
    else:
        mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
    total_chunks = 0
    total_samples = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)
    score_smoothed = helpers.WindowedExpSmoother()
    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()
        mod_factor_t = torch.tensor(args.mod_factor, dtype=torch.float32)

        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, batch_size, batch_chunk_len, filter_parameters, args,
            log, chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)
        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current, device=device,
                              dtype=torch.float32).unsqueeze(2)

        seqs, mod_cats, seqlens = [], [], []
        for chunk in chunk_batch:
            chunk_labels = chunk['sequence']
            seqlens.append(len(chunk_labels))
            chunk_seq = flipflop_code(
                np.ascontiguousarray(flipflop_can_labels[chunk_labels]),
                flipflop_ncan_base)
            chunk_mod_cats = np.ascontiguousarray(
                flipflop_mod_labels[chunk_labels])
            seqs.append(chunk_seq)
            mod_cats.append(chunk_mod_cats)
        seqs, mod_cats = np.concatenate(seqs), np.concatenate(mod_cats)
        seqs = torch.tensor(seqs, dtype=torch.float32, device=device)
        seqlens = torch.tensor(seqlens, dtype=torch.long, device=device)
        mod_cats = torch.tensor(mod_cats, dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.cat_mod_flipflop_loss(
            outputs, seqs, seqlens, mod_cats, can_mods_offsets,
            mod_cat_weights, mod_factor_t, args.sharpen)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold is set to zero then we
        # log everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        del indata, seqs, mod_cats, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        loss_log.write('{}\t{:.10f}\t{:.10f}\n'.format(
            i, fval, score_smoothed.value))

        if (i + 1) % args.save_every == 0:
            save_model(network, args.outdir, (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % 50 == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            log.write((' {:5d} {:5.3f}  {:5.2f}s ({:.2f} ksample/s ' +
                       '{:.2f} kbase/s) lr={:.2e}').format(
                           (i + 1) // 50, score_smoothed.value, dt,
                           total_samples / 1000 / dt,
                           total_bases / 1000 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    save_model(network, args.outdir)
    return
parser = argparse.ArgumentParser(
    description='Predict squiggle from sequence',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--version', nargs=0, action=display_version_and_exit,
                    metavar=__version__, help='Display version information')
parser.add_argument('model', action=FileExists, help='Model file')
parser.add_argument('input', action=FileExists, help='Fasta file')


if __name__ == '__main__':
    args = parser.parse_args()

    predict_squiggle = helpers.load_model(args.model)

    for seq in SeqIO.parse(args.input, 'fasta'):
        seqstr = str(seq.seq).encode('ascii')
        embedded_seq_numpy = np.expand_dims(
            squiggle_match.embed_sequence(seqstr), axis=1)
        embedded_seq_torch = torch.tensor(embedded_seq_numpy,
                                          dtype=torch.float32)

        with torch.no_grad():
            squiggle = np.squeeze(
                predict_squiggle(embedded_seq_torch).cpu().numpy(), axis=1)

        print('base', 'current', 'sd', 'dwell', sep='\t')
        for base, (mean, logsd, dwell) in zip(seq.seq, squiggle):
            print(base, mean, np.exp(logsd), np.exp(-dwell), sep='\t')
                    nargs=2, type=NonNegative(int),
                    metavar=('beginning', 'end'),
                    help='Number of samples to trim off start and end')
parser.add_argument('model', action=FileExists, help='Model file')
parser.add_argument('references', action=FileExists, help='Fasta file')
parser.add_argument('read_dir', action=FileExists,
                    help='Directory for fast5 reads')


if __name__ == '__main__':
    args = parser.parse_args()
    worker_kwarg_names = ['back_prob', 'localpen', 'minscore', 'trim']

    model = helpers.load_model(args.model)

    fast5_reads = fast5utils.iterate_fast5_reads(
        args.read_dir, limit=args.limit,
        strand_list=args.input_strand_list)

    for res in imap_mp(squiggle_match.worker, fast5_reads, threads=args.jobs,
                       fix_kwargs=helpers.get_kwargs(args,
                                                     worker_kwarg_names),
                       unordered=True, init=squiggle_match.init_worker,
                       initargs=[model, args.references]):
        if res is None:
            continue
        read_id, sig, score, path, squiggle, bases = res
        bases = bases.decode('ascii')