def prepare_random_batches(device, read_data, batch_chunk_len, sub_batch_size,
                           target_sub_batches, alphabet_info, filter_params,
                           network, network_is_catmod, log):
    """Generate ``target_sub_batches`` randomly sampled training sub-batches.

    Each sub-batch is drawn with ``chunk_selection.assemble_batch`` and
    converted into tensors on ``device``.

    Args:
        device: torch device on which all yielded tensors are created.
        read_data: reads to draw chunks from (forwarded to assemble_batch).
        batch_chunk_len: chunk length requested from assemble_batch.
        sub_batch_size: number of chunks requested per sub-batch.
        target_sub_batches: number of sub-batches to yield.
        alphabet_info: alphabet description; its ``ncan_base`` is passed to
            ``flipflopfings.flipflop_code``.
        filter_params: chunk-filter parameters forwarded to assemble_batch.
        network: the model. For catmod networks, the ``mod_labels`` and
            ``can_labels`` lookup tables of its last sublayer are used.
        network_is_catmod: whether per-base modification categories
            (``mod_cats``) should be produced.
        log: logger; receives a warning when fewer chunks than requested
            pass the filters.

    Yields:
        ``(indata, seqs, seqlens, mod_cats, sub_batch_size,
        batch_rejections)``, where:

        - ``indata`` is a float32 tensor of stacked currents, shape
          (timesteps) x (chunks) x 1.
        - ``seqs`` is the concatenated flip-flop codes as a float32 tensor.
        - ``seqlens`` is a long tensor of per-chunk sequence lengths.
        - ``mod_cats`` is a long tensor of concatenated modification
          categories, or ``None`` when ``network_is_catmod`` is false.
        - ``sub_batch_size`` is the requested size, which may exceed the
          number of chunks actually returned.
        - ``batch_rejections`` is the rejection summary returned by
          assemble_batch.

    Raises:
        ValueError: if no chunk at all passes the filters.
        Exception: if any selected chunk has a zero-length sequence.
    """
    total_sub_batches = 0

    while total_sub_batches < target_sub_batches:
        # Chunk_batch is a list of dicts
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, sub_batch_size, batch_chunk_len, filter_params)
        if len(chunk_batch) < sub_batch_size:
            log.write(
                '* Warning: only {} chunks passed filters (asked for {}).\n'.
                format(len(chunk_batch), sub_batch_size))
        if not chunk_batch:
            # np.vstack below would otherwise fail with an opaque
            # "need at least one array" error; report the real cause.
            raise ValueError(
                'Error: no chunks passed filters (asked for {}).'.format(
                    sub_batch_size))
        if not all(len(d['sequence']) > 0 for d in chunk_batch):
            raise Exception('Error: zero length sequence')

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x sub_batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current, device=device,
                              dtype=torch.float32).unsqueeze(2)

        # Prepare seqs, seqlens and (if necessary) mod_cats
        seqs, seqlens = [], []
        mod_cats = [] if network_is_catmod else None
        for chunk in chunk_batch:
            chunk_labels = chunk['sequence']
            seqlens.append(len(chunk_labels))
            if network_is_catmod:
                chunk_mod_cats = np.ascontiguousarray(
                    network.sublayers[-1].mod_labels[chunk_labels])
                mod_cats.append(chunk_mod_cats)
                # convert chunk_labels to canonical base labels
                chunk_labels = np.ascontiguousarray(
                    network.sublayers[-1].can_labels[chunk_labels])
            chunk_seq = flipflopfings.flipflop_code(chunk_labels,
                                                    alphabet_info.ncan_base)
            seqs.append(chunk_seq)

        seqs = torch.tensor(np.concatenate(seqs), dtype=torch.float32,
                            device=device)
        seqlens = torch.tensor(seqlens, dtype=torch.long, device=device)
        if network_is_catmod:
            mod_cats = torch.tensor(np.concatenate(mod_cats),
                                    dtype=torch.long, device=device)

        total_sub_batches += 1

        yield (indata, seqs, seqlens, mod_cats, sub_batch_size,
               batch_rejections)
def main():
    """Train a flip-flop network on chunks drawn from a mapped-signal file.

    Steps:

    1. Parse command-line arguments and prepare the output directory.
    2. Load reads from ``args.input``.
    3. Sample chunk-filter parameters.
    4. Load the model from ``args.model``.
    5. Run ``args.niteration`` Adam steps under a CosineFollowedByFlatLR
       schedule.

    Along the way it logs progress, saves a checkpoint every
    ``args.save_every`` iterations, and writes the final model at the end.
    Exits with status 1 on an unusable device or output directory, or when
    no reads are available.
    """
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        try:
            torch.cuda.set_device(device)
        except AttributeError:
            sys.stderr.write('ERROR: Torch not compiled with CUDA enabled ' +
                             'and GPU device set.\n')
            sys.exit(1)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        # Parenthesised so .format applies to the whole message; otherwise
        # it binds only to the last literal and '{}' is printed verbatim.
        sys.stderr.write(('Error: Output directory {} exists but ' +
                          '--overwrite is false\n').format(args.output))
        sys.exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.output))
        sys.exit(1)

    copyfile(args.model, os.path.join(args.output, 'model.py'))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')

    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median mean_dwell, mad mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, sampling_chunk_len,
        args, log, chunk_log=chunk_log)

    medmd, madmd = filter_parameters

    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering, medmd, madmd))
    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(alphabet)
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum(p.nelement() for p in network.parameters())))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    helpers.save_model(network, args.output, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (np.random.randint(
            args.chunk_len_min, args.chunk_len_max + 1) //
                           args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # Chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, batch_size, batch_chunk_len, filter_parameters, args,
            log, chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current, device=device,
                              dtype=torch.float32).unsqueeze(2)
        # Sequence input tensor is just a 1D vector, and so is seqlens
        seqs = torch.tensor(np.concatenate([
            flipflopfings.flipflop_code(d['sequence'], nbase)
            for d in chunk_batch
        ]), device=device, dtype=torch.long)
        seqlens = torch.tensor([len(d['sequence']) for d in chunk_batch],
                               dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens,
                                           args.sharpen)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the scheduler is stepped after the optimizer;
        # stepping it first skips the schedule's initial learning rate.
        lr_scheduler.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold set to zero then we log
        # everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            helpers.save_model(network, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (' {:5d} {:5.3f} {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) ' +
                 'lr={:.2e}')
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt,
                         total_samples / 1000.0 / dt,
                         total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    helpers.save_model(network, args.output)
def main():
    """Train a convolutional squiggle-prediction network.

    Steps:

    1. Parse command-line arguments and prepare ``args.output``.
    2. Load reads, which must use a 4-letter alphabet.
    3. Sample chunk-filter parameters.
    4. Build the network with ``create_convolution``.
    5. Run ``args.niteration`` Adam steps under a ReciprocalLR schedule,
       minimising ``squiggle_match_loss`` per signal sample.

    Checkpoints are saved every ``args.save_every`` iterations and the final
    model at the end. Exits with status 1 on an unusable output directory or
    when no reads are available.
    """
    args = parser.parse_args()

    np.random.seed(args.seed)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        sys.stderr.write(
            'Error: Output directory {} exists but --overwrite is false\n'.
            format(args.output))
        sys.exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.output))
        sys.exit(1)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('# Taiyaki version {}\n'.format(__version__))
    log.write('# Command line\n')
    log.write(' '.join(sys.argv) + '\n')

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write('* Will train from a subset of {} strands\n'.format(
            len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        assert len(alphabet) == 4, (
            'Squiggle prediction with modified base training data is ' +
            'not currenly supported.')
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median mean_dwell, mad mean_dwell).
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, args.target_len,
        args, log, chunk_log=chunk_log)

    medmd, madmd = filter_parameters
    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering, medmd, madmd))

    conv_net = create_convolution(args.size, args.depth, args.winlen)
    nparam = sum(p.nelement() for p in conv_net.parameters())
    log.write('# Created network. {} parameters\n'.format(nparam))
    log.write('# Depth {} layers ({} residual layers)\n'.format(
        args.depth + 2, args.depth))
    log.write('# Window width {}\n'.format(args.winlen))
    log.write('# Context +/- {} bases\n'.format(
        (args.depth + 2) * (args.winlen // 2)))

    device = torch.device(args.device)
    conv_net = conv_net.to(device)

    optimizer = torch.optim.Adam(conv_net.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)

    lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_decay)

    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)
    t0 = time.time()
    score_smoothed = helpers.WindowedExpSmoother()
    total_chunks = 0

    for i in range(args.niteration):
        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, args.batch_size, args.target_len, filter_parameters,
            args, log, chunk_log=log_rejected_chunks,
            chunk_len_means_sequence_len=True)
        total_chunks += len(chunk_batch)
        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input needs to be seqlen x batchsize x embedding_dimension
        embedded_matrix = [
            embed_sequence(d['sequence'], alphabet=None) for d in chunk_batch
        ]
        seq_embed = torch.tensor(embedded_matrix).permute(1, 0, 2).to(device)
        # Shape of labels is a flat vector
        batch_signal = torch.tensor(
            np.concatenate([d['current'] for d in chunk_batch])).to(device)
        # Shape of lens is also a flat vector
        batch_siglen = torch.tensor([len(d['current'])
                                     for d in chunk_batch]).to(device)

        optimizer.zero_grad()
        predicted_squiggle = conv_net(seq_embed)
        batch_loss = squiggle_match_loss(predicted_squiggle, batch_signal,
                                         batch_siglen, args.back_prob)
        # Mean loss per signal sample in the batch
        fval = batch_loss.sum() / float(batch_siglen.sum())

        fval.backward()
        optimizer.step()
        # Since PyTorch 1.1 the scheduler is stepped after the optimizer;
        # stepping it first skips the schedule's initial learning rate.
        lr_scheduler.step()

        score_smoothed.update(float(fval))

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold set to zero then we log
        # everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, batch_loss)

        if (i + 1) % args.save_every == 0:
            helpers.save_model(conv_net, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:5.3f} {:5.2f}s'
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt))
            t0 = tn
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")

    helpers.save_model(conv_net, args.output)
target_batch_size = int(args.min_batch_size * args.chunk_len_max / batch_chunk_len + 0.5) # ...but it can't be more than the number of reads. batch_size = min(target_batch_size, len(read_data)) # If the logging threshold is 0 then we log all chunks, including those rejected, so pass the log # object into assemble_batch if args.chunk_logging_threshold == 0: log_rejected_chunks = chunk_log else: log_rejected_chunks = None # Chunk_batch is a list of dicts. chunk_batch, batch_rejections = chunk_selection.assemble_batch( read_data, batch_size, batch_chunk_len, filter_parameters, args, log, chunk_log=log_rejected_chunks) total_chunks += len(chunk_batch) # Update counts of reasons for rejection for k, v in batch_rejections.items(): rejection_dict[k] += v # Shape of input tensor must be (timesteps) x (batch size) x (input channels) # in this case batch_chunk_len x batch_size x 1 stacked_current = np.vstack([d['current'] for d in chunk_batch]).T indata = torch.tensor(stacked_current, device=device, dtype=torch.float32).unsqueeze(2)
def main():
    """Train a categorical-modification flip-flop network.

    Steps:

    1. Set up logs and the device with ``_setup_and_logs``.
    2. Load reads with ``_load_data``.
    3. Sample chunk-filter parameters.
    4. Load the model from ``args.model``. Its final layer must be a
       ``layers.GlobalNormFlipFlopCatMod``; otherwise the process exits with
       status 1.
    5. Run ``args.niteration`` Adam steps under a CosineFollowedByFlatLR
       schedule using ``ctc.cat_mod_flipflop_loss``.

    Each iteration's loss is written to ``loss_log``. Checkpoints are saved
    every ``args.save_every`` iterations and the final model at the end.
    """
    args = parser.parse_args()
    log, loss_log, chunk_log, device = _setup_and_logs(args)

    read_data, alphabet_info = _load_data(args, log)
    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median mean_dwell, mad mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering,
        (args.chunk_len_min + args.chunk_len_max) // 2,
        args, log, chunk_log=chunk_log)

    log.write(("* Sampled {} chunks: median(mean_dwell)={:.2f}, " +
               "mad(mean_dwell)={:.2f}\n").format(
                   args.sample_nreads_before_filtering, *filter_parameters))
    log.write('* Reading network from {}\n'.format(args.model))
    model_kwargs = {
        'insize': 1,
        'winlen': args.winlen,
        'stride': args.stride,
        'size': args.size,
        'alphabet_info': alphabet_info
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    final_layer = network.sublayers[-1]
    if not isinstance(final_layer, layers.GlobalNormFlipFlopCatMod):
        # The message names the same class the isinstance check requires.
        log.write(
            ('ERROR: Model must end with GlobalNormFlipFlopCatMod layer, ' +
             'not {}.\n').format(str(final_layer)))
        sys.exit(1)
    can_mods_offsets = final_layer.can_mods_offsets
    flipflop_can_labels = final_layer.can_labels
    flipflop_mod_labels = final_layer.mod_labels
    flipflop_ncan_base = final_layer.ncan_base
    log.write('* Loaded categorical modifications flip-flop model.\n')
    log.write('* Network has {} parameters.\n'.format(
        sum(p.nelement() for p in network.parameters())))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    if args.scale_mod_loss:
        try:
            mod_cat_weights = alphabet_info.compute_mod_inv_freq_weights(
                read_data, args.num_inv_freq_reads)
        except NotImplementedError:
            log.write('* WARNING: Some mods not found when computing inverse ' +
                      'frequency weights. Consider raising ' +
                      '[--num_inv_freq_reads].\n')
            mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)
        else:
            log.write('* Modified base weights: {}\n'.format(
                str(mod_cat_weights)))
    else:
        mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
    total_chunks = 0
    total_samples = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    score_smoothed = helpers.WindowedExpSmoother()
    # args.mod_factor does not change between iterations, so build its
    # tensor once instead of on every iteration.
    mod_factor_t = torch.tensor(args.mod_factor, dtype=torch.float32)
    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (np.random.randint(
            args.chunk_len_min, args.chunk_len_max + 1) //
                           args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, batch_size, batch_chunk_len, filter_parameters, args,
            log, chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)
        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current, device=device,
                              dtype=torch.float32).unsqueeze(2)

        seqs, mod_cats, seqlens = [], [], []
        for chunk in chunk_batch:
            chunk_labels = chunk['sequence']
            seqlens.append(len(chunk_labels))
            # Flip-flop code the canonical-base labels...
            chunk_seq = flipflop_code(
                np.ascontiguousarray(flipflop_can_labels[chunk_labels]),
                flipflop_ncan_base)
            # ...and keep the modification category of each label separately
            chunk_mod_cats = np.ascontiguousarray(
                flipflop_mod_labels[chunk_labels])
            seqs.append(chunk_seq)
            mod_cats.append(chunk_mod_cats)
        seqs, mod_cats = np.concatenate(seqs), np.concatenate(mod_cats)
        seqs = torch.tensor(seqs, dtype=torch.float32, device=device)
        seqlens = torch.tensor(seqlens, dtype=torch.long, device=device)
        mod_cats = torch.tensor(mod_cats, dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.cat_mod_flipflop_loss(
            outputs, seqs, seqlens, mod_cats, can_mods_offsets,
            mod_cat_weights, mod_factor_t, args.sharpen)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the scheduler is stepped after the optimizer;
        # stepping it first skips the schedule's initial learning rate.
        lr_scheduler.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold set to zero then we log
        # everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Free per-iteration tensors before releasing cached CUDA memory.
        del indata, seqs, mod_cats, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        loss_log.write('{}\t{:.10f}\t{:.10f}\n'.format(
            i, fval, score_smoothed.value))
        if (i + 1) % args.save_every == 0:
            save_model(network, args.outdir, (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % 50 == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            log.write((' {:5d} {:5.3f} {:5.2f}s ({:.2f} ksample/s ' +
                       '{:.2f} kbase/s) lr={:.2e}').format(
                           (i + 1) // 50, score_smoothed.value, dt,
                           total_samples / 1000 / dt,
                           total_bases / 1000 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    save_model(network, args.outdir)
def main():
    """Train a convolutional squiggle-prediction network.

    Steps:

    1. Parse command-line arguments and prepare ``args.outdir`` and the
       device with helper functions.
    2. Load reads, which must use a 4-base alphabet.
    3. Sample chunk-filter parameters.
    4. Build the network with ``create_convolution``.
    5. Run ``args.niteration`` Adam steps under a ReciprocalLR schedule,
       minimising ``squiggle_match_loss`` per signal sample.

    Checkpoints are saved every ``args.save_every`` iterations and the final
    model at the end. Exits with status 1 when no reads are available.
    """
    args = parser.parse_args()

    np.random.seed(args.seed)

    helpers.prepare_outdir(args.outdir, args.overwrite)

    device = helpers.set_torch_device(args.device)

    log = helpers.Logger(os.path.join(args.outdir, 'model.log'), args.quiet)
    log.write(helpers.formatted_env_info(device))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write('* Will train from a subset of {} strands\n'.format(
            len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet_info = per_read_file.get_alphabet_information()
        assert alphabet_info.nbase == 4, (
            'Squiggle prediction with modified base training data is ' +
            'not currenly supported.')
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result holds the median and MAD of mean_dwell.
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, args.target_len,
        args.filter_mean_dwell, args.filter_max_dwell)

    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering,
                filter_parameters.median_meandwell,
                filter_parameters.mad_meandwell))

    conv_net = create_convolution(args.size, args.depth, args.winlen)
    nparam = sum(p.nelement() for p in conv_net.parameters())
    log.write('* Created network. {} parameters\n'.format(nparam))
    log.write('* Depth {} layers ({} residual layers)\n'.format(
        args.depth + 2, args.depth))
    log.write('* Window width {}\n'.format(args.winlen))
    log.write('* Context +/- {} bases\n'.format(
        (args.depth + 2) * (args.winlen // 2)))

    conv_net = conv_net.to(device)

    optimizer = torch.optim.Adam(conv_net.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay,
                                 eps=args.eps)

    lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_decay)

    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)
    t0 = time.time()
    score_smoothed = helpers.WindowedExpSmoother()
    total_chunks = 0

    for i in range(args.niteration):
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, args.batch_size, args.target_len, filter_parameters,
            chunk_len_means_sequence_len=True)
        if len(chunk_batch) < args.batch_size:
            log.write('* Warning: only {} chunks passed filters.\n'.format(
                len(chunk_batch)))

        total_chunks += len(chunk_batch)
        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input needs to be seqlen x batchsize x embedding_dimension
        embedded_matrix = [
            embed_sequence(d['sequence'], alphabet=None) for d in chunk_batch
        ]
        seq_embed = torch.tensor(embedded_matrix).permute(1, 0, 2).to(device)
        # Shape of labels is a flat vector
        batch_signal = torch.tensor(
            np.concatenate([d['current'] for d in chunk_batch])).to(device)
        # Shape of lens is also a flat vector
        batch_siglen = torch.tensor([len(d['current'])
                                     for d in chunk_batch]).to(device)

        optimizer.zero_grad()
        predicted_squiggle = conv_net(seq_embed)
        batch_loss = squiggle_match_loss(predicted_squiggle, batch_signal,
                                         batch_siglen, args.back_prob)
        # Mean loss per signal sample in the batch
        fval = batch_loss.sum() / float(batch_siglen.sum())

        fval.backward()
        optimizer.step()

        score_smoothed.update(float(fval))

        if (i + 1) % args.save_every == 0:
            helpers.save_model(conv_net, args.outdir,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:7.5f} {:5.2f}s'
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt))
            t0 = tn
            # Write summary of chunk rejection reasons
            if args.full_filter_status:
                for k, v in rejection_dict.items():
                    log.write(" {}:{} ".format(k, v))
            else:
                n_tot = n_fail = 0
                for k, v in rejection_dict.items():
                    n_tot += v
                    if k != 'pass':
                        n_fail += v
                # Guard against an empty tally (no chunks counted yet),
                # which would otherwise raise ZeroDivisionError.
                fail_frac = n_fail / n_tot if n_tot else 0.0
                log.write(" {:.1%} chunks filtered".format(fail_frac))
            log.write("\n")

        lr_scheduler.step()

    helpers.save_model(conv_net, args.outdir)
rejection_dict = defaultdict(lambda : 0) # To count the numbers of different sorts of chunk rejection t0 = time.time() score_smoothed = helpers.ExponentialSmoother(args.smooth) total_chunks = 0 for i in range(args.niteration): learning_rate = args.adam.rate / (1.0 + (i**1.25) / args.lrdecay) # If the logging threshold is 0 then we log all chunks, including those rejected, so pass the log # object into assemble_batch if args.chunk_logging_threshold == 0: log_rejected_chunks = chunk_log else: log_rejected_chunks = None # chunk_batch is a list of dicts. chunk_batch, batch_rejections = chunk_selection.assemble_batch(read_data, args.batch_size, args.target_len, filter_parameters, args, log, chunk_log=log_rejected_chunks, chunk_len_means_sequence_len=True) total_chunks += len(chunk_batch) # Update counts of reasons for rejection for k, v in batch_rejections.items(): rejection_dict[k] += v # Shape of input needs to be seqlen x batchsize x embedding_dimension embedded_matrix = [embed_sequence(d['sequence'], alphabet=None) for d in chunk_batch] seq_embed = torch.tensor(embedded_matrix).permute(1,0,2).to(device) # Shape of labels is a flat vector batch_signal = torch.tensor(np.concatenate([d['current'] for d in chunk_batch])).to(device) # Shape of lens is also a flat vector batch_siglen = torch.tensor([len(d['current']) for d in chunk_batch]).to(device)