def setUpClass(self):
    """Build a small flip-flop CTC test fixture.

    Constructs a normalised log-weight network output matrix
    (``self.outputs``) in which a few hand-chosen flip-flop paths carry
    all of the probability mass, together with the flip-flop-coded
    sequences (``self.sequences``) and the exact path probabilities
    (``self.path_probabilities``) the tests compare against.

    NOTE(review): unittest invokes setUpClass as a classmethod, so the
    parameter named ``self`` actually receives the class object — confirm
    a @classmethod decorator exists at the definition site.
    """
    self.nbases = 4
    self.batchsize = 1
    self.seqlen = 3
    self.nblocks = 4
    self.sharpen = 1.0
    self.nflipflop_transitions = flipflopfings.nstate_flipflop(
        self.nbases)  # 40 for ACGT
    self.dx_size = 0.001  # Size of small changes for gradient check
    self.grad_dp = 5  # Number of decimal places for gradient check
    # Sequence ACC. Flip-flop coded ACc or 015
    # 'sequences' rather than 'sequence' because
    # we could have more than one sequence packed together
    self.sequences = {
        '015': torch.tensor([0, 1, 5]),
        '237': torch.tensor([2, 3, 7]),
        '510': torch.tensor([5, 1, 0])  # prob 0 according to outputs
    }
    self.seqlens = torch.tensor([self.seqlen])

    # Network outputs are weights for flip-flop transitions.
    # Define some example paths and assign weights to them.
    # These will be used to define the example output matrix.
    #
    paths = {}
    weights = {}
    paths['015'] = [0, 0, 1, 5, 5]
    weights['015'] = [1.0, 1.0, 0.5, 1.0]
    paths['237'] = [2, 2, 3, 7, 7]
    weights['237'] = [1.0, 0.5, 1.0, 1.0]
    weights['510'] = [0.0]  # No weight for this sequence/path

    # A path's (unnormalised) probability is the product of its
    # per-block transition weights.
    self.path_probabilities = {k: np.prod(v) for k, v in weights.items()}
    # Normalise path probabilities so they sum to one.
    psum = sum(self.path_probabilities.values())
    self.path_probabilities = {
        k: v / psum for k, v in self.path_probabilities.items()
    }

    # Make output (transition weight) matrix with these path probs;
    # shape is (nblocks, batchsize, nflipflop_transitions).
    self.outputs = torch.zeros(self.nblocks, self.batchsize,
                               self.nflipflop_transitions,
                               dtype=torch.float)
    for k in paths.keys():
        for block in range(self.nblocks):
            transcode = flipflop_transitioncode(paths[k][block],
                                                paths[k][block + 1],
                                                self.nbases)
            self.outputs[block, 0, transcode] = weights[k][block]
    # Log and normalise output (transition weight) matrix.
    # SMALL_VAL keeps log() finite for zero-weight transitions.
    self.outputs = torch.log(self.outputs + SMALL_VAL)
    self.outputs = layers.global_norm_flipflop(self.outputs)
def __init__(self, insize, nbase, has_bias=True, _never_use_cupy=False):
    """Construct a flip-flop output layer.

    :param insize: number of input features to the layer
    :param nbase: number of canonical bases
    :param has_bias: whether the linear projection has a bias term
    :param _never_use_cupy: internal hook disabling the cupy code path
    """
    super().__init__()
    # Stash constructor arguments on the instance for later introspection.
    self.insize, self.nbase, self.has_bias = insize, nbase, has_bias
    # One output weight per flip-flop transition.
    self.size = flipflopfings.nstate_flipflop(nbase)
    self.linear = nn.Linear(insize, self.size, bias=has_bias)
    self.reset_parameters()
    self._never_use_cupy = _never_use_cupy
def worker_init(device, modelname, chunk_size, overlap, read_params, alphabet,
                max_concurrent_chunks, fastq, qscore_scale, qscore_offset,
                beam, posterior, temperature):
    """Initialise per-worker state for multiprocess basecalling.

    Loads the model onto the requested device and publishes two module
    globals: ``all_read_params`` (per-read scaling info) and
    ``process_read_partial`` (a closure binding all worker-wide settings
    around ``process_read``).
    """
    global all_read_params
    global process_read_partial

    all_read_params = read_params

    torch_device = helpers.set_torch_device(device)
    net = load_model(modelname).to(torch_device)
    model_stride = guess_model_stride(net)
    # Chunk geometry arrives in model-output blocks; convert to samples.
    samples_per_chunk = chunk_size * model_stride
    samples_overlap = overlap * model_stride
    n_can_state = nstate_flipflop(len(alphabet))

    def process_read_partial(read_filename, read_id, read_params):
        # Closure: only the per-read arguments vary between calls.
        res = process_read(read_filename, read_id, net, samples_per_chunk,
                           samples_overlap, read_params, n_can_state,
                           model_stride, alphabet, max_concurrent_chunks,
                           fastq, qscore_scale, qscore_offset, beam,
                           posterior, temperature)
        return (read_id, *res)
def main():
    """Entry point: basecall fast5 reads with a flip-flop model.

    Loads the model named on the command line, scans the input folder
    for reads, basecalls each one and writes fasta/fastq records to the
    output file (or stdout).  Optionally writes modified-base scores to
    an HDF5 file.
    """
    args = parser.parse_args()

    device = helpers.set_torch_device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.")
        # Exit non-zero so callers can detect the failure (a bare
        # sys.exit() reports success).
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model)
    # Chunk geometry is given in model-output blocks; convert to samples.
    chunk_size = args.chunk_size * stride
    chunk_overlap = args.overlap * stride

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = list(
        fast5utils.iterate_fast5_reads(args.input_folder,
                                       limit=args.limit,
                                       strand_list=args.input_strand_list,
                                       recursive=args.recursive))
    sys.stderr.write("* Found {} reads.\n".format(len(fast5_reads)))

    if args.scaling is not None:
        sys.stderr.write("* Loading read scaling parameters from {}.\n".format(
            args.scaling))
        all_read_params = get_per_read_params_dict_from_tsv(args.scaling)
        input_read_ids = frozenset(rec[1] for rec in fast5_reads)
        scaling_read_ids = frozenset(all_read_params.keys())
        sys.stderr.write("* {} / {} reads have scaling information.\n".format(
            len(input_read_ids & scaling_read_ids), len(input_read_ids)))
        # Only basecall reads for which scaling information is available.
        fast5_reads = [
            rec for rec in fast5_reads if rec[1] in scaling_read_ids
        ]
    else:
        all_read_params = {}

    mods_fp = None
    if do_output_mods:
        # Open with an explicit mode: h5py's default mode changed across
        # versions, and this file is created fresh for writing.
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset('mod_long_names',
                               data=np.array(mod_long_names, dtype='S'),
                               dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    startcharacter = '@' if args.fastq else '>'
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                # Per-read scaling parameters, when supplied via --scaling.
                read_params = all_read_params.get(read_id)
                basecall, qstring, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size, chunk_overlap,
                    read_params, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp, args.max_concurrent_chunks,
                    args.fastq, args.qscore_scale, args.qscore_offset)
                if basecall is not None:
                    fh.write("{}{}\n{}\n".format(
                        startcharacter, read_id,
                        basecall[::-1] if args.reverse else basecall))
                    nbase += len(basecall)
                    ncalled += 1
                    if args.fastq:
                        fh.write("+\n{}\n".format(
                            qstring[::-1] if args.reverse else qstring))
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        # Close the mods file even if basecalling raised part-way through.
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    # Report the true elapsed time; previously it was truncated to whole
    # seconds before being formatted with two decimal places.
    sys.stderr.write("* Called {} reads in {:.2f}s\n".format(
        nread, total_time))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time /
                                                    1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time /
                                                      1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
def main():
    """Entry point: train a flip-flop model on mapped signal data.

    Prepares the output directory, loads mapped reads, builds the
    network and optimiser, then runs the training loop, periodically
    saving checkpoints and logging progress.
    """
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        try:
            torch.cuda.set_device(device)
        except AttributeError:
            sys.stderr.write('ERROR: Torch not compiled with CUDA enabled ' +
                             'and GPU device set.')
            sys.exit(1)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        # Parenthesise the concatenation so .format applies to the whole
        # message; previously it bound to the second literal only and the
        # '{}' placeholder was printed verbatim.
        sys.stderr.write(('Error: Output directory {} exists but --overwrite '
                          + 'is false\n').format(args.output))
        sys.exit(1)

    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.output))
        sys.exit(1)

    copyfile(args.model, os.path.join(args.output, 'model.py'))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median mean_dwell, mad mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, sampling_chunk_len,
        args, log, chunk_log=chunk_log)
    medmd, madmd = filter_parameters

    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering, medmd, madmd))

    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(alphabet)
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    helpers.save_model(network, args.output, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0

    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # Chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data, batch_size, batch_chunk_len, filter_parameters, args,
            log, chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        # (timesteps) x (batch size) x (input channels)
        # in this case:
        # batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current, device=device,
                              dtype=torch.float32).unsqueeze(2)
        # Sequence input tensor is just a 1D vector, and so is seqlens
        seqs = torch.tensor(np.concatenate([
            flipflopfings.flipflop_code(d['sequence'], nbase)
            for d in chunk_batch
        ]), device=device, dtype=torch.long)
        seqlens = torch.tensor([len(d['sequence']) for d in chunk_batch],
                               dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens,
                                           args.sharpen)
        # Mean loss per non-empty chunk.
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold set to zero then we log
        # everything
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            helpers.save_model(network, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (' {:5d} {:5.3f} {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) ' +
                 'lr={:.2e}')
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt,
                         total_samples / 1000.0 / dt,
                         total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    helpers.save_model(network, args.output)
log.write('* Loaded references from {}.\n'.format(args.reference)) # Write pickle for future pickle_name = os.path.splitext(args.reference)[0] + '.pkl' with open(pickle_name, 'wb') as fh: pickle.dump(seq_dict, fh) log.write('* Written pickle of processed references to {} for future use.\n'.format(pickle_name)) log.write('* Reading network from {}\n'.format(args.model)) nbase = len(args.alphabet) model_kwargs = { 'size' : args.size, 'stride': args.stride, 'winlen': args.winlen, 'insize': 1, # Number of input features to model e.g. was >1 for event-based models (level, std, dwell) 'outsize': flipflopfings.nstate_flipflop(nbase) } network = helpers.load_model(args.model, **model_kwargs).to(device) log.write('* Network has {} parameters.\n'.format(sum([p.nelement() for p in network.parameters()]))) optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max, betas=args.adam, eps=args.eps) lr_scheduler = CosineAnnealingLR(optimizer, args.niteration) score_smoothed = helpers.WindowedExpSmoother() log.write('* Dumping initial model\n') save_model(network, args.outdir, 0) total_bases = 0
def main():
    """Entry point: basecall fast5 reads on GPU with a flip-flop model.

    Requires a CUDA device (the flip-flop decoding path needs cupy).
    Writes fasta records to the output file (or stdout) and, optionally,
    modified-base scores to an HDF5 file.
    """
    args = parser.parse_args()

    # Validate with an explicit check rather than `assert`: asserts are
    # stripped when Python runs with -O, which would let a CPU device
    # through to a GPU-only code path.
    if args.device == 'cpu':
        sys.stderr.write("Flipflop basecalling in taiyaki requires a GPU "
                         "and for cupy to be installed\n")
        sys.exit(1)
    device = torch.device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.")
        # Exit non-zero so callers can detect the failure (a bare
        # sys.exit() reports success).
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model, device=device)
    chunk_size, chunk_overlap = basecall_helpers.round_chunk_values(
        args.chunk_size, args.overlap, stride)

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit,
        strand_list=args.input_strand_list, recursive=args.recursive)

    mods_fp = None
    if do_output_mods:
        # Open with an explicit mode: h5py's default mode changed across
        # versions, and this file is created fresh for writing.
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset(
            'mod_long_names', data=np.array(mod_long_names, dtype='S'),
            dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                basecall, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size, chunk_overlap,
                    device, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp)
                if basecall is not None:
                    fh.write(">{}\n{}\n".format(read_id, basecall))
                    nbase += len(basecall)
                    ncalled += 1
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        # Close the mods file even if basecalling raised part-way through.
        if mods_fp is not None:
            mods_fp.close()

    total_time = time.time() - t0
    sys.stderr.write("* Called {} reads in {}s\n".format(
        nread, int(total_time)))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time /
                                                    1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time /
                                                      1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return