def _load_data(args, log):
    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined '
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Will train from all strands\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        (bases_alphabet, collapse_alphabet,
         mod_long_names) = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(
            read_ids, max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    log.write('* Loaded {} reads.\n'.format(len(read_data)))
    alphabet_info = alphabet.AlphabetInfo(
        bases_alphabet, collapse_alphabet, mod_long_names, do_reorder=False)
    log.write('* Using alphabet definition: {}\n'.format(str(alphabet_info)))

    return read_data, alphabet_info

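# Minimal usage sketch for _load_data (illustrative only: the file path and
# argparse Namespace below are hypothetical, not from the source). It assumes
# the three-tuple form of get_alphabet_information() used above, and that
# helpers.Logger accepts None as the logfile (as in the multi-GPU script
# below, where non-lead processes pass logfile=None).
def _example_load_data():
    from argparse import Namespace
    args = Namespace(input='mapped_reads.hdf5',   # hypothetical path
                     input_strand_list=None,      # train from all strands
                     limit=1000)                  # cap the number of reads
    log = helpers.Logger(None, False)
    read_data, alphabet_info = _load_data(args, log)
    return read_data, alphabet_info
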
def test_mod_prepare_remap(self):
    print("Current directory is", os.getcwd())
    print("Taiyaki dir is", self.taiyakidir)
    print("Data dir is ", self.datadir)
    cmd = [self.script, self.read_dir, self.per_read_params,
           self.output_mapped_signal_file, self.remapping_model,
           self.mod_per_read_refs,
           "--mod", "Z", "C", "5mC",
           "--mod", "Y", "A", "6mA",
           "--overwrite"]
    r = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    print("Result of running make command in shell:")
    print("Stdout=", r.stdout.decode('utf-8'))
    print("Stderr=", r.stderr.decode('utf-8'))

    # Open the mapped read file and run checks to see if it complies with
    # the file format. Also get a chunk and check that the mean dwell is
    # within reasonable bounds.
    with mapped_signal_files.HDF5Reader(
            self.output_mapped_signal_file) as f:
        testreport = f.check()
        print("Test report from checking mapped read file:")
        print(testreport)
        self.assertEqual(testreport, "pass")
        read0 = f.get_multiple_reads("all")[0]
        # start_sample is fixed to make the test reproducible - otherwise a
        # randomly located chunk is returned.
        chunk = read0.get_chunk_with_sample_length(1000, start_sample=10)
        chunk_meandwell = len(chunk['current']) / (
            len(chunk['sequence']) + 0.0001)
        print("chunk mean dwell time in samples = ", chunk_meandwell)
        assert 7 < chunk_meandwell < 13, (
            "Chunk mean dwell time outside allowed range 7 to 13")

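# Worked example of the mean-dwell arithmetic checked above (a sketch with
# dummy data, not from the source): dwell is the number of signal samples
# per reference base, so a 1000-sample chunk mapped to 100 bases has a mean
# dwell of ~10 samples/base, inside the 7-13 band the test allows. The
# +0.0001 guards against division by zero for an empty sequence.
def _example_mean_dwell():
    current = np.zeros(1000)      # 1000 signal samples (dummy data)
    sequence = np.zeros(100)      # mapped to 100 reference bases
    mean_dwell = len(current) / (len(sequence) + 0.0001)
    assert 7 < mean_dwell < 13    # ~10.0 samples per base
    return mean_dwell
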
def check_map_sig_alphabet(model_info, ms_fn):
    msf = mapped_signal_files.HDF5Reader(ms_fn)
    tai_alph_info = msf.get_alphabet_information()
    msf.close()
    if model_info.output_alphabet != tai_alph_info.alphabet:
        raise mh.MegaError(
            'Different alphabets specified in model ({}) and mapped '
            'signal file ({})'.format(
                model_info.output_alphabet, tai_alph_info.alphabet))
    if set(model_info.can_alphabet) != set(tai_alph_info.collapse_alphabet):
        raise mh.MegaError(
            'Different canonical alphabets specified in model ({}) and '
            'mapped signal file ({})'.format(
                model_info.can_alphabet, tai_alph_info.collapse_alphabet))
    if model_info.ordered_mod_long_names != tai_alph_info.mod_long_names:
        raise mh.MegaError(
            'Different modified base long names specified in model ({}) '
            'and mapped signal file ({})'.format(
                ', '.join(model_info.ordered_mod_long_names),
                ', '.join(tai_alph_info.mod_long_names)))

def test_check_HDF5_mapped_read_file(self):
    """Check that constructing a read object which doesn't conform
    leads to errors.
    """
    print("Creating flawed Read object from test data")
    read_dict = construct_mapped_read()
    read_dict['Reference'] = "I'm not a numpy array!"  # Wrong type!
    read_object = mapped_signal_files.Read(read_dict)

    print("Checking contents")
    check_text = read_object.check()
    print("Check result on read object: should fail")
    print(check_text)
    self.assertNotEqual(check_text, "pass")

    print("Writing to file")
    alphabet_info = alphabet.AlphabetInfo(DEFAULT_ALPHABET, DEFAULT_ALPHABET)
    with mapped_signal_files.HDF5Writer(self.testfilepath,
                                        alphabet_info) as f:
        f.write_read(read_object)
    print("Current dir = ", os.getcwd())
    print("File written to ", self.testfilepath)

    print("\nOpening file for reading")
    with mapped_signal_files.HDF5Reader(self.testfilepath) as f:
        ids = f.get_read_ids()
        print("Read ids=", ids[0])
        print("Version number = ", f.version)
        self.assertEqual(ids[0], read_dict['read_id'])
        file_test_report = f.check()
        print("Test report (should fail):", file_test_report)
        self.assertNotEqual(file_test_report, "pass")

def count_reads(self, mapped_signal_file, print_readlist=True):
    """Count the number of reads in a mapped signal file."""
    with mapped_signal_files.HDF5Reader(mapped_signal_file) as f:
        read_ids = f.get_read_ids()
    if print_readlist:
        print("Read list:")
        print('\n'.join(read_ids))
    return len(read_ids)

# read filename queue filler
def fill_reads_queue(read_q, read_filler_conn, ms_fn, num_reads_limit,
                     num_proc):
    msf = mapped_signal_files.HDF5Reader(ms_fn)
    num_reads = 0
    for read in msf:
        read_q.put(read)
        num_reads += 1
        if num_reads_limit is not None and num_reads >= num_reads_limit:
            break
    read_filler_conn.send(num_reads)
    msf.close()
    # One None sentinel per worker process, so each consumer knows to stop
    for _ in range(num_proc):
        read_q.put(None)

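# Consumer-side sketch for the queue protocol above (hypothetical worker,
# not from the source): fill_reads_queue puts one None sentinel per worker
# process, so each worker loops until it sees None and then exits.
# process_read is a hypothetical per-read processing function.
def _example_read_worker(read_q):
    while True:
        read = read_q.get()
        if read is None:    # sentinel from fill_reads_queue: no more reads
            break
        process_read(read)
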
def main():
    args = parser.parse_args()
    if args.output is not None:
        plt.figure(figsize=(12, 10))
    reads_sofar = 0
    for nfile, mapped_read_file in enumerate(args.mapped_read_files):
        with mapped_signal_files.HDF5Reader(mapped_read_file) as h5:
            all_read_ids = h5.get_read_ids()
            if len(args.read_ids) > 0:
                read_ids = args.read_ids
            else:
                read_ids = all_read_ids[:args.nreads]
                sys.stderr.write(
                    "Reading first {} read ids in file {}\n".format(
                        args.nreads, mapped_read_file))
            for nread, read_id in enumerate(read_ids):
                r = h5.get_read(read_id)
                mapping = r['Ref_to_signal']
                f = mapping >= 0
                maplen = len(mapping)
                read_info_text = (
                    'file {} read {}:{} reflen:{}, daclen:{}').format(
                        nfile, nread, read_id, maplen - 1, len(r['Dacs']))
                sys.stdout.write(read_info_text + '\n')
                if args.output is not None:
                    label = (read_info_text
                             if reads_sofar <= args.maxlegendsize else None)
                    x, y = np.arange(maplen)[f], mapping[f]
                    if args.xmin is not None:
                        xf = x >= args.xmin
                        x, y = x[xf], y[xf]
                    if args.xmax is not None:
                        xf = x <= args.xmax
                        x, y = x[xf], y[xf]
                    plt.plot(x, y, label=label,
                             linestyle='dashed' if nfile == 1 else 'solid')
                reads_sofar += 1  # was never incremented: used to cap legend

    if args.output is not None:
        plt.grid()
        plt.xlabel('Reference location')
        plt.ylabel('Signal location')
        plt.legend(loc='upper left', framealpha=0.3)
        plt.tight_layout()
        sys.stderr.write("Saving plot to {}\n".format(args.output))
        plt.savefig(args.output)

def main():
    args = parser.parse_args()
    plt.figure(figsize=(12, 10))
    for nfile, mapped_read_file in enumerate(args.mapped_read_files):
        sys.stderr.write("Opening {}\n".format(mapped_read_file))
        with mapped_signal_files.HDF5Reader(mapped_read_file) as h5:
            all_read_ids = h5.get_read_ids()
            sys.stderr.write("First ten read_ids in file:\n")
            for read_id in all_read_ids[:10]:
                sys.stderr.write("    {}\n".format(read_id))
            if len(args.read_ids) > 0:
                read_ids = args.read_ids
            else:
                read_ids = all_read_ids[:args.nreads]
                sys.stderr.write(
                    "Plotting first {} read ids in file\n".format(
                        args.nreads))
            for nread, read_id in enumerate(read_ids):
                sys.stderr.write("Opening read id {}\n".format(read_id))
                r = h5.get_read(read_id)
                mapping = r['Ref_to_signal']
                f = mapping >= 0
                maplen = len(mapping)
                label = 'file {} read {}:{} reflen:{}, daclen:{}'.format(
                    nfile, nread, read_id, maplen - 1, len(r['Dacs']))
                x, y = np.arange(maplen)[f], mapping[f]
                if args.xmin is not None:
                    xf = (x >= args.xmin)
                    x, y = x[xf], y[xf]
                if args.xmax is not None:
                    xf = (x <= args.xmax)
                    x, y = x[xf], y[xf]
                plt.plot(x, y, label=label,
                         linestyle='dashed' if nfile == 1 else 'solid')

    plt.grid()
    plt.xlabel('Reference location')
    plt.ylabel('Signal location')
    if len(read_ids) < 15:
        plt.legend(loc='upper left', framealpha=0.3)
    plt.tight_layout()
    sys.stderr.write("Saving plot to {}\n".format(args.output))
    plt.savefig(args.output)

def main():
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        try:
            torch.cuda.set_device(device)
        except AttributeError:
            sys.stderr.write('ERROR: Torch not compiled with CUDA enabled '
                             'and GPU device set.')
            sys.exit(1)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        sys.stderr.write('Error: Output directory {} exists but --overwrite '
                         'is false\n'.format(args.output))
        exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not '
                         'directory\n'.format(args.output))
        exit(1)

    copyfile(args.model, os.path.join(args.output, 'model.py'))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks,
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined '
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(
            read_ids, max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median mean_dwell, mad mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, sampling_chunk_len,
        args, log, chunk_log=chunk_log)
    medmd, madmd = filter_parameters

    log.write("* Sampled {} chunks: median(mean_dwell)={:.2f}, "
              "mad(mean_dwell)={:.2f}\n".format(
                  args.sample_nreads_before_filtering, medmd, madmd))
    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(alphabet)
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model, e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)

    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    helpers.save_model(network, args.output, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()

        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, batch_size, batch_chunk_len, filter_parameters, args,
            log, chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current, device=device,
                              dtype=torch.float32).unsqueeze(2)
        # Sequence input tensor is just a 1D vector, and so is seqlens
        seqs = torch.tensor(np.concatenate([
            flipflopfings.flipflop_code(d['sequence'], nbase)
            for d in chunk_batch]), device=device, dtype=torch.long)
        seqlens = torch.tensor([len(d['sequence']) for d in chunk_batch],
                               dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens,
                                           args.sharpen)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if
        # we're poisoned. If args.chunk_logging_threshold is set to zero
        # then we log everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            helpers.save_model(network, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (' {:5d} {:5.3f} {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) '
                 'lr={:.2e}')
            log.write(t.format((i + 1) // DOTROWLENGTH,
                               score_smoothed.value, dt,
                               total_samples / 1000.0 / dt,
                               total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    helpers.save_model(network, args.output)

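# Worked example of the batch-geometry arithmetic in the training loop above
# (a sketch with made-up numbers, not from the source): the randomly drawn
# chunk length is rounded down to a multiple of the stride, and the batch
# size is scaled so each batch carries roughly min_batch_size chunks' worth
# of chunk_len_max-length data.
def _example_batch_geometry():
    stride, chunk_len_max, min_batch_size = 5, 4000, 50
    drawn_len = 2783
    batch_chunk_len = (drawn_len // stride) * stride           # -> 2780
    target_batch_size = int(
        min_batch_size * chunk_len_max / batch_chunk_len + 0.5)  # -> 72
    return batch_chunk_len, target_batch_size
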
def main():
    args = parser.parse_args()

    np.random.seed(args.seed)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        sys.stderr.write('Error: Output directory {} exists but --overwrite '
                         'is false\n'.format(args.output))
        exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not '
                         'directory\n'.format(args.output))
        exit(1)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('# Taiyaki version {}\n'.format(__version__))
    log.write('# Command line\n')
    log.write(' '.join(sys.argv) + '\n')

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write('* Will train from a subset of {} strands\n'.format(
            len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        assert len(alphabet) == 4, (
            'Squiggle prediction with modified base training data is '
            'not currently supported.')
        read_data = per_read_file.get_multiple_reads(
            read_ids, max_reads=args.limit)
        # read_data now contains a list of reads (each an instance of the
        # Read class defined in mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks,
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median mean_dwell, mad mean_dwell).
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, args.target_len,
        args, log, chunk_log=chunk_log)
    medmd, madmd = filter_parameters
    log.write("* Sampled {} chunks: median(mean_dwell)={:.2f}, "
              "mad(mean_dwell)={:.2f}\n".format(
                  args.sample_nreads_before_filtering, medmd, madmd))

    conv_net = create_convolution(args.size, args.depth, args.winlen)
    nparam = sum([p.data.detach().numpy().size
                  for p in conv_net.parameters()])
    log.write('# Created network. {} parameters\n'.format(nparam))
    log.write('# Depth {} layers ({} residual layers)\n'.format(
        args.depth + 2, args.depth))
    log.write('# Window width {}\n'.format(args.winlen))
    log.write('# Context +/- {} bases\n'.format(
        (args.depth + 2) * (args.winlen // 2)))

    device = torch.device(args.device)
    conv_net = conv_net.to(device)

    optimizer = torch.optim.Adam(conv_net.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_decay)

    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(lambda: 0)

    t0 = time.time()
    score_smoothed = helpers.WindowedExpSmoother()
    total_chunks = 0

    for i in range(args.niteration):
        lr_scheduler.step()

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, args.batch_size, args.target_len, filter_parameters,
            args, log, chunk_log=log_rejected_chunks,
            chunk_len_means_sequence_len=True)
        total_chunks += len(chunk_batch)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input needs to be seqlen x batchsize x embedding_dimension
        embedded_matrix = [embed_sequence(d['sequence'], alphabet=None)
                           for d in chunk_batch]
        seq_embed = torch.tensor(embedded_matrix).permute(1, 0, 2).to(device)
        # Shape of labels is a flat vector
        batch_signal = torch.tensor(
            np.concatenate([d['current'] for d in chunk_batch])).to(device)
        # Shape of lens is also a flat vector
        batch_siglen = torch.tensor([len(d['current'])
                                     for d in chunk_batch]).to(device)

        optimizer.zero_grad()
        predicted_squiggle = conv_net(seq_embed)
        batch_loss = squiggle_match_loss(predicted_squiggle, batch_signal,
                                         batch_siglen, args.back_prob)
        fval = batch_loss.sum() / float(batch_siglen.sum())
        fval.backward()
        optimizer.step()

        score_smoothed.update(float(fval))

        # Check for poison chunk and save losses and chunk locations if
        # we're poisoned. If args.chunk_logging_threshold is set to zero
        # then we log everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, batch_loss)

        if (i + 1) % args.save_every == 0:
            helpers.save_model(conv_net, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:5.3f} {:5.2f}s'
            log.write(t.format((i + 1) // DOTROWLENGTH,
                               score_smoothed.value, dt))
            t0 = tn
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")

    helpers.save_model(conv_net, args.output)

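# Shape sketch for the squiggle-model inputs above (illustrative numbers
# only, not from the source): embed_sequence is assumed to return a
# (seqlen x embedding_dim) array per chunk; stacking a batch and permuting
# gives the seqlen x batchsize x embedding_dim layout the convolutional
# network expects, while the signal and length tensors stay flat.
def _example_squiggle_shapes():
    batch_size, seqlen, embed_dim = 8, 300, 4
    embedded_matrix = np.zeros((batch_size, seqlen, embed_dim),
                               dtype=np.float32)   # dummy embeddings
    seq_embed = torch.tensor(embedded_matrix).permute(1, 0, 2)
    assert seq_embed.shape == (seqlen, batch_size, embed_dim)
    return seq_embed.shape
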
def main():
    args = parser.parse_args()

    is_multi_gpu = (args.local_rank is not None)
    is_lead_process = (not is_multi_gpu) or args.local_rank == 0

    if is_multi_gpu:
        # Use distributed parallel processing to run one process per GPU
        try:
            torch.distributed.init_process_group(backend='nccl')
        except Exception:
            raise Exception(
                "Unable to start multiprocessing group. "
                "The most likely reason is that the script is running with "
                "local_rank set but without the set-up for distributed "
                "operation. local_rank should be used "
                "only by torch.distributed.launch. See the README.")
        device = helpers.set_torch_device(args.local_rank)
        if args.seed is not None:
            # Make sure processes get different random picks of training data
            np.random.seed(args.seed + args.local_rank)
    else:
        device = helpers.set_torch_device(args.device)
        np.random.seed(args.seed)

    if is_lead_process:
        helpers.prepare_outdir(args.outdir, args.overwrite)
        if args.model.endswith('.py'):
            copyfile(args.model, os.path.join(args.outdir, 'model.py'))
        batchlog = helpers.BatchLog(args.outdir)
        logfile = os.path.join(args.outdir, 'model.log')
    else:
        logfile = None

    log = helpers.Logger(logfile, args.quiet)
    log.write(helpers.formatted_env_info(device))

    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined '
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet_info = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(
            read_ids, max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    log.write('* Using alphabet definition: {}\n'.format(str(alphabet_info)))

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median mean_dwell, mad mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_params = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, sampling_chunk_len,
        args.filter_mean_dwell, args.filter_max_dwell)

    log.write("* Sampled {} chunks".format(
        args.sample_nreads_before_filtering))
    log.write(": median(mean_dwell)={:.2f}".format(
        filter_params.median_meandwell))
    log.write(", mad(mean_dwell)={:.2f}\n".format(
        filter_params.mad_meandwell))

    log.write('* Reading network from {}\n'.format(args.model))
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model, e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'alphabet_info': alphabet_info
    }

    if is_lead_process:
        # Under pytorch's DistributedDataParallel scheme, we need a clone
        # of the start network to use as a template for saving checkpoints.
        # Necessary because DistributedDataParallel makes the class
        # structure different.
        network_save_skeleton = helpers.load_model(args.model, **model_kwargs)
        log.write('* Network has {} parameters.\n'.format(
            sum([p.nelement()
                 for p in network_save_skeleton.parameters()])))
        if not alphabet_info.is_compatible_model(network_save_skeleton):
            sys.stderr.write(
                '* ERROR: Model and mapped signal files contain '
                'incompatible alphabet definitions (including modified '
                'bases).')
            sys.exit(1)
        if is_cat_mod_model(network_save_skeleton):
            log.write('* Loaded categorical modified base model.\n')
            if not alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Modified bases model specified, but mapped '
                    'signal file does not contain modified bases.')
                sys.exit(1)
        else:
            log.write('* Loaded standard (canonical bases-only) model.\n')
            if alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Standard (canonical bases only) model '
                    'specified, but mapped signal file does contain '
                    'modified bases.')
                sys.exit(1)
        log.write('* Dumping initial model\n')
        helpers.save_model(network_save_skeleton, args.outdir, 0)

    if is_multi_gpu:
        # So that processes 1,2,3.. don't try to load before process 0
        # has saved
        torch.distributed.barrier()
        log.write('* MultiGPU process {}'.format(args.local_rank))
        log.write(': loading initial model saved by process 0\n')
        saved_startmodel_path = os.path.join(
            args.outdir, 'model_checkpoint_00000.checkpoint')
        network = helpers.load_model(saved_startmodel_path).to(device)
        # Wrap network for training in the DistributedDataParallel structure
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        network = network_save_skeleton.to(device)
        network_save_skeleton = None

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay,
                                 eps=args.eps)

    if args.lr_warmup is None:
        lr_warmup = args.lr_min
    else:
        lr_warmup = args.lr_warmup

    if args.lr_frac_decay is not None:
        lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_frac_decay,
                                          args.warmup_batches, lr_warmup)
        log.write('* Learning rate schedule lr_max*k/(k+t)')
        log.write(', k={}, t=iterations.\n'.format(args.lr_frac_decay))
    else:
        lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                    args.lr_cosine_iters,
                                                    args.warmup_batches,
                                                    lr_warmup)
        log.write('* Learning rate goes like cosine from lr_max to lr_min ')
        log.write('over {} iterations.\n'.format(args.lr_cosine_iters))
    log.write('* At start, train for {} '.format(args.warmup_batches))
    log.write('batches at warm-up learning rate {:3.2}\n'.format(lr_warmup))

    score_smoothed = helpers.WindowedExpSmoother()

    # Prepare modified base parameter tensors
    network_is_catmod = is_cat_mod_model(network)
    mod_factor_t = torch.tensor(args.mod_factor,
                                dtype=torch.float32).to(device)
    can_mods_offsets = (network.sublayers[-1].can_mods_offsets
                        if network_is_catmod else None)
    # Mod cat inv freq weighting is currently disabled. Compute and set this
    # value to enable mod cat weighting.
    mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)

    # Generate list of batches for standard loss reporting
    reporting_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    reporting_batch_list = list(
        prepare_random_batches(device, read_data, reporting_chunk_len,
                               args.min_sub_batch_size,
                               args.reporting_sub_batches, alphabet_info,
                               filter_params, network, network_is_catmod,
                               log))
    log.write(('* Standard loss report: chunk length = {} & sub-batch size '
               '= {} for {} sub-batches.\n').format(
                   reporting_chunk_len, args.min_sub_batch_size,
                   args.reporting_sub_batches))

    # Set cap at a very large value (before we have any gradient stats).
    gradient_cap = constants.LARGE_VAL
    if args.gradient_cap_fraction is None:
        log.write('* No gradient capping\n')
    else:
        rolling_quantile = maths.RollingQuantile(args.gradient_cap_fraction)
        log.write('* Gradient L2 norm cap will be upper'
                  ' {:3.2f} quantile of the last {} norms.\n'.format(
                      args.gradient_cap_fraction, rolling_quantile.window))

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride

        # We choose the size of a sub-batch so that the size of the data in
        # the sub-batch is about the same as args.min_sub_batch_size chunks
        # of length args.chunk_len_max
        sub_batch_size = int(args.min_sub_batch_size * args.chunk_len_max /
                             batch_chunk_len + 0.5)

        optimizer.zero_grad()

        main_batch_gen = prepare_random_batches(
            device, read_data, batch_chunk_len, sub_batch_size,
            args.sub_batches, alphabet_info, filter_params, network,
            network_is_catmod, log)

        chunk_count, fval, chunk_samples, chunk_bases, batch_rejections = \
            calculate_loss(network, network_is_catmod, main_batch_gen,
                           args.sharpen, can_mods_offsets, mod_cat_weights,
                           mod_factor_t, calc_grads=True)

        gradnorm_uncapped = torch.nn.utils.clip_grad_norm_(
            network.parameters(), gradient_cap)
        if args.gradient_cap_fraction is not None:
            gradient_cap = rolling_quantile.update(gradnorm_uncapped)

        optimizer.step()

        if is_lead_process:
            batchlog.record(
                fval, gradnorm_uncapped,
                None if args.gradient_cap_fraction is None else gradient_cap)

        total_chunks += chunk_count
        total_samples += chunk_samples
        total_bases += chunk_bases

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        score_smoothed.update(fval)

        if (i + 1) % args.save_every == 0 and is_lead_process:
            helpers.save_model(network, args.outdir,
                               (i + 1) // args.save_every,
                               network_save_skeleton)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            _, rloss, _, _, _ = calculate_loss(network, network_is_catmod,
                                               reporting_batch_list,
                                               args.sharpen,
                                               can_mods_offsets,
                                               mod_cat_weights, mod_factor_t)

            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (' {:5d} {:7.5f} {:7.5f} {:5.2f}s ({:.2f} ksample/s {:.2f} '
                 'kbase/s) lr={:.2e}')
            log.write(t.format((i + 1) // DOTROWLENGTH,
                               score_smoothed.value, rloss, dt,
                               total_samples / 1000.0 / dt,
                               total_bases / 1000.0 / dt, learning_rate))

            # Write summary of chunk rejection reasons
            if args.full_filter_status:
                for k, v in rejection_dict.items():
                    log.write(" {}:{} ".format(k, v))
            else:
                n_tot = n_fail = 0
                for k, v in rejection_dict.items():
                    n_tot += v
                    if k != 'pass':
                        n_fail += v
                log.write(" {:.1%} chunks filtered".format(n_fail / n_tot))
            log.write("\n")

            total_bases = 0
            total_samples = 0
            t0 = tn

            # Uncomment the lines below to check synchronisation of models
            # between processes in multi-GPU operation
            # for p in network.parameters():
            #     v = p.data.reshape(-1)[:5].to('cpu')
            #     u = p.data.reshape(-1)[-5:].to('cpu')
            #     break
            # if args.local_rank is not None:
            #     log.write("* GPU{} params:".format(args.local_rank))
            # log.write("{}...{}\n".format(v, u))

        lr_scheduler.step()

    if is_lead_process:
        helpers.save_model(network, args.outdir,
                           model_skeleton=network_save_skeleton)

def main():
    args = parser.parse_args()

    np.random.seed(args.seed)

    helpers.prepare_outdir(args.outdir, args.overwrite)

    device = helpers.set_torch_device(args.device)

    log = helpers.Logger(os.path.join(args.outdir, 'model.log'), args.quiet)
    log.write(helpers.formatted_env_info(device))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write('* Will train from a subset of {} strands\n'.format(
            len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet_info = per_read_file.get_alphabet_information()
        assert alphabet_info.nbase == 4, (
            'Squiggle prediction with modified base training data is '
            'not currently supported.')
        read_data = per_read_file.get_multiple_reads(
            read_ids, max_reads=args.limit)
        # read_data now contains a list of reads (each an instance of the
        # Read class defined in mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median mean_dwell, mad mean_dwell).
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, args.target_len,
        args.filter_mean_dwell, args.filter_max_dwell)
    log.write("* Sampled {} chunks: median(mean_dwell)={:.2f}, "
              "mad(mean_dwell)={:.2f}\n".format(
                  args.sample_nreads_before_filtering,
                  filter_parameters.median_meandwell,
                  filter_parameters.mad_meandwell))

    conv_net = create_convolution(args.size, args.depth, args.winlen)
    nparam = sum([p.data.detach().numpy().size
                  for p in conv_net.parameters()])
    log.write('* Created network. {} parameters\n'.format(nparam))
    log.write('* Depth {} layers ({} residual layers)\n'.format(
        args.depth + 2, args.depth))
    log.write('* Window width {}\n'.format(args.winlen))
    log.write('* Context +/- {} bases\n'.format(
        (args.depth + 2) * (args.winlen // 2)))

    conv_net = conv_net.to(device)

    optimizer = torch.optim.Adam(conv_net.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay,
                                 eps=args.eps)
    lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_decay)

    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(lambda: 0)

    t0 = time.time()
    score_smoothed = helpers.WindowedExpSmoother()
    total_chunks = 0

    for i in range(args.niteration):
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, args.batch_size, args.target_len, filter_parameters,
            chunk_len_means_sequence_len=True)
        if len(chunk_batch) < args.batch_size:
            log.write('* Warning: only {} chunks passed filters.\n'.format(
                len(chunk_batch)))

        total_chunks += len(chunk_batch)
        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input needs to be seqlen x batchsize x embedding_dimension
        embedded_matrix = [embed_sequence(d['sequence'], alphabet=None)
                           for d in chunk_batch]
        seq_embed = torch.tensor(embedded_matrix).permute(1, 0, 2).to(device)
        # Shape of labels is a flat vector
        batch_signal = torch.tensor(
            np.concatenate([d['current'] for d in chunk_batch])).to(device)
        # Shape of lens is also a flat vector
        batch_siglen = torch.tensor([len(d['current'])
                                     for d in chunk_batch]).to(device)

        optimizer.zero_grad()
        predicted_squiggle = conv_net(seq_embed)
        batch_loss = squiggle_match_loss(predicted_squiggle, batch_signal,
                                         batch_siglen, args.back_prob)
        fval = batch_loss.sum() / float(batch_siglen.sum())
        fval.backward()
        optimizer.step()

        score_smoothed.update(float(fval))

        if (i + 1) % args.save_every == 0:
            helpers.save_model(conv_net, args.outdir,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:7.5f} {:5.2f}s'
            log.write(t.format((i + 1) // DOTROWLENGTH,
                               score_smoothed.value, dt))
            t0 = tn
            # Write summary of chunk rejection reasons
            if args.full_filter_status:
                for k, v in rejection_dict.items():
                    log.write(" {}:{} ".format(k, v))
            else:
                n_tot = n_fail = 0
                for k, v in rejection_dict.items():
                    n_tot += v
                    if k != 'pass':
                        n_fail += v
                log.write(" {:.1%} chunks filtered".format(n_fail / n_tot))
            log.write("\n")

        lr_scheduler.step()

    helpers.save_model(conv_net, args.outdir)

def test_HDF5_mapped_read_file(self):
    """Test that we can save a mapped read file, open it again and use
    some methods to get data from it. Plot a picture for diagnostics.
    """
    print("Creating Read object from test data")
    read_dict = construct_mapped_read()
    read_object = mapped_signal_files.Read(read_dict)

    print("Checking contents")
    check_text = read_object.check()
    print("Check result on read object:")
    print(check_text)
    self.assertEqual(check_text, "pass")

    print("Writing to file")
    alphabet_info = alphabet.AlphabetInfo(DEFAULT_ALPHABET, DEFAULT_ALPHABET)
    with mapped_signal_files.HDF5Writer(self.testfilepath,
                                        alphabet_info) as f:
        f.write_read(read_object)
    print("Current dir = ", os.getcwd())
    print("File written to ", self.testfilepath)

    print("\nOpening file for reading")
    with mapped_signal_files.HDF5Reader(self.testfilepath) as f:
        ids = f.get_read_ids()
        print("Read ids=", ids[0])
        print("Version number = ", f.version)
        self.assertEqual(ids[0], read_dict['read_id'])
        file_test_report = f.check()
        print("Test report:", file_test_report)
        self.assertEqual(file_test_report, "pass")
        read_list = f.get_multiple_reads("all")
        recovered_read = read_list[0]
        reflen = len(recovered_read['Reference'])
        siglen = len(recovered_read['Dacs'])

    # Get a chunk - note that chunkstart is relative to the start of the
    # mapped region, not relative to the start of the signal
    chunklen, chunkstart = 5, 3
    chunkdict = recovered_read.get_chunk_with_sample_length(
        chunklen, chunkstart)
    # Check that the extracted chunk is the right length
    self.assertEqual(len(chunkdict['current']), chunklen)
    # Check that the mapping data agrees with what we put in
    self.assertTrue(np.all(
        recovered_read['Ref_to_signal'] == read_dict['Ref_to_signal']))

    # Plot a picture showing ref_to_sig from the read object,
    # and the result of searches to find the inverse
    if False:
        plt.figure()
        plt.xlabel('Signal coord')
        plt.ylabel('Ref coord')
        ix = np.array([0, -1])
        plt.scatter(chunkdict['current'][ix], chunkdict['sequence'][ix],
                    s=50, label='chunk limits', marker='s', color='black')
        plt.scatter(recovered_read['Ref_to_signal'], np.arange(reflen + 1),
                    label='reftosig (source data)', color='none',
                    edgecolor='blue', s=60)
        siglocs = np.arange(siglen, dtype=np.int32)
        sigtoref_fromsearch = recovered_read.get_reference_locations(siglocs)
        plt.scatter(siglocs, sigtoref_fromsearch, label='from search',
                    color='red', marker='x', s=50)
        plt.legend()
        plt.grid()
        plt.savefig(self.plotfilepath)
        print("Saved plot to", self.plotfilepath)
