def samples_to_batch(samples, prep_func, name, batch, epoch):
    """Prepare a group of samples and stack them into batch arrays.

    :param samples: iterable of samples, each is passed to `prep_func`.
    :param prep_func: callable returning an (x, y) pair for a single sample.
    :param name: name of the logger which receives the timing message.
    :param batch: batch identifier (used only in the log message).
    :param epoch: epoch identifier (used only in the log message).

    :returns: tuple (x, y) of `np.stack`-ed inputs and targets.
    """
    started = now()
    prepared = (prep_func(sample) for sample in samples)
    # transpose [(x0, y0), (x1, y1), ...] into ((x0, x1, ...), (y0, y1, ...))
    inputs, targets = zip(*prepared)
    x = np.stack(inputs)
    y = np.stack(targets)
    logger = get_named_logger(name)
    elapsed = now() - started
    logger.debug(
        "Took {:5.3}s to load batch {} (epoch {})".format(
            elapsed, batch, epoch))
    return x, y
def bam_to_alignments(truth_bam, ref_name, start=None, end=None):
    """Create list of TruthAlignment objects from a bam of Truth aligned to ref.

    :param truth_bam: (sorted indexed) bam with true sequence aligned to
        reference
    :param ref_name: name of reference to process
    :param start: starting position within reference
    :param end: ending position within reference
        (all alignments with any overlap with the interval start:end will be
        retrieved)

    :returns: list of `TruthAlignment` objects ordered by their `start`
        attribute; unmapped and secondary records are left out.
    """
    alignments = []
    with pysam.AlignmentFile(truth_bam, 'rb') as bamfile:
        for read in bamfile.fetch(reference=ref_name, start=start, end=end):
            if read.is_unmapped or read.is_secondary:
                continue
            alignments.append(TruthAlignment(read))
    alignments.sort(key=attrgetter('start'))
    logger = get_named_logger("TruthAlign")
    logger.info("Retrieved {} alignments.".format(len(alignments)))
    return alignments
def process_labels(label_counts, max_label_len=10):
    """Create map from full labels to (encoded) truncated labels.

    :param label_counts: `Counter` obj of label counts. Keys are either
        label strings, or tuples which are expanded to strings via the
        module-level `decoding` table as `l[1] * decoding[l[0]].upper()`.
    :param max_label_len: int, maximum label length, longer labels will be
        truncated. `None` or `np.inf` disables truncation.

    :returns: ({label: int encoding}, [label decodings], `Counter` of
        truncated counts). The returned counts are keyed by the integer
        encoding. Empty input gives empty outputs.
    """
    logger = get_named_logger('Labelling')

    old_labels = list(label_counts)
    if old_labels and isinstance(old_labels[0], tuple):
        # Must be a list, not a generator: the labels are traversed twice
        # below (once by zip() and once by set()).
        new_labels = [l[1] * decoding[l[0]].upper() for l in old_labels]
    else:
        new_labels = list(old_labels)

    if max_label_len is not None and max_label_len < np.inf:
        new_labels = [l[:max_label_len] for l in new_labels]

    old_to_new = dict(zip(old_labels, new_labels))
    label_decoding = sorted(set(new_labels))
    # constant-time lookup of the position of each truncated label
    new_to_code = {label: code for code, label in enumerate(label_decoding)}
    label_encoding = {l: new_to_code[old_to_new[l]] for l in old_labels}
    logger.info("Label encoding dict is:\n{}".format('\n'.join(
        '{}: {}'.format(k, v) for k, v in label_encoding.items()
    )))

    # labels which collapse onto the same truncated label pool their counts
    new_counts = Counter()
    for l in old_labels:
        new_counts[label_encoding[l]] += label_counts[l]
    logger.info("New label counts {}".format(new_counts))

    return label_encoding, label_decoding, new_counts
def __init__(self, filename, cache=True):
    """Basic VCF parser.

    Only the leading meta (`##`) lines and the header (`#`) line are read
    here; reading stops at the header line.

    :param filename: .vcf file.
    :param cache: if True, all parsed variants are stored in memory for
        faster subsequent access.
    """
    self.filename = filename
    self.cache = cache
    self._indexed = False
    self._tree = None
    self._parse_lock = Lock()
    self.logger = get_named_logger('VCFReader')

    # Read both metadata and header
    self.meta = []
    self.header = None
    with open(filename, encoding='utf-8') as handle:
        for raw in handle:
            text = raw.rstrip('\n')
            if text.startswith('##'):
                # meta line, stored without the leading '##'
                self.meta.append(text[2:])
                continue
            if text.startswith('#'):
                # column header, stored as a list of column names
                self.header = text[1:].split('\t')
                break
def __init__(self, filenames, threads=4):
    """Index the samples stored in a collection of feature files.

    Meta data is taken from the first file. The `medaka_samples` entry is
    removed from it and, when present, the `medaka_label_counts` entry is
    replaced by a `Counter` accumulated over all files.

    :param filenames: list of str, feature files; the first one provides
        the shared meta data.
    :param threads: int, number of worker processes used to load the
        per-file meta data.
    """
    self.logger = get_named_logger('DataIndex')
    self.filenames = filenames

    with DataStore(filenames[0]) as ds:
        self.logger.debug('Loading meta from {}'.format(filenames[0]))
        self.meta = ds.meta

    c_grp = 'medaka_label_counts'
    if c_grp in self.meta:
        self.meta[c_grp] = Counter()
    del self.meta['medaka_samples']

    self.samples = []
    with ProcessPoolExecutor(threads) as executor:
        future_to_f = {
            executor.submit(DataIndex._load_meta, f): f for f in filenames}
        for i, future in enumerate(as_completed(future_to_f), 1):
            f = future_to_f[future]
            try:
                meta = future.result()
                file_samples = [(s, f) for s in meta['medaka_samples']]
            except Exception as exc:
                # Best effort: a broken file is skipped, but say why.
                self.logger.info(
                    'Could not load meta from {}: {!r}'.format(f, exc))
                continue
            self.samples.extend(file_samples)
            # Label counts are optional. Previously a missing entry raised
            # after the samples had been added, so the file was reported
            # as failed although its samples were kept.
            if c_grp in self.meta and c_grp in meta:
                self.meta[c_grp].update(meta[c_grp])
            self.logger.info(
                'Loaded sample-index from {}/{} ({:.2%}) of feature files.'.format(
                    i, len(filenames), i / len(filenames)))

    # make order of samples independent of order in which tasks complete
    self.samples.sort()
    self._index = None
def __init__(self, fname, chrom, start, end, reference_fasta, label_decoding):
    """Set up a VCF writer for one region and open the reference.

    :param fname: output file, passed to `vcf.VCFWriter`.
    :param chrom: reference name used in the region string.
    :param start: region start used in the region string.
    :param end: region end used in the region string.
    :param reference_fasta: reference file opened with `pysam.FastaFile`.
    :param label_decoding: label decoding, stored on the instance.
    """
    # NOTE(review): the original author queried whether this region string
    # is correct ("#is this correct?") - left unchanged here.
    region = '{}:{}-{}'.format(chrom, start, end)
    self.label_decoding = label_decoding
    self.logger = get_named_logger('VCFWriter')
    self.logger.info("Writing variants for {}".format(region))
    region_meta = 'region={}'.format(region)
    self.writer = vcf.VCFWriter(fname, meta_info=[region_meta])
    self.ref_fasta = pysam.FastaFile(reference_fasta)
def __init__(self, alignment):
    """Create a TruthAlignment object from a
    `pysam.libcalignedsegment.AlignedSegment` object.

    :param alignment: `pysam.libcalignedsegment.AlignedSegment` object.
    """
    # keep the alignment so we can get positions and labels later
    self.aln = alignment
    self.logger = get_named_logger('TruthAlign')
    self.is_kept = True
    # initialise start and end (which might be moved)
    self.start = alignment.reference_start  # zero-based
    self.end = alignment.reference_end
def __init__(self, filename, mode='r'):
    """Record the file name and mode; no file handle is opened here.

    :param filename: file backing the store.
    :param mode: file mode, stored as given (default 'r').
    """
    self.logger = get_named_logger('DataStore')
    self.filename, self.mode = filename, mode
    # state filled in later: file handle, cached meta, known sample keys
    self.fh = None
    self._meta = None
    self._sample_keys = set()
def __init__(self, filename, mode='r', verify_on_close=True):
    """Record the file name, mode and options; no file handle is opened here.

    :param filename: file backing the store.
    :param mode: file mode, stored as given (default 'r').
    :param verify_on_close: flag stored on the instance (default True).
    """
    self.logger = get_named_logger('DataStore')
    self.filename, self.mode = filename, mode
    self.verify_on_close = verify_on_close
    # state filled in later: file handle, cached meta, known sample keys
    self.fh = None
    self._meta = None
    self._sample_keys = set()
def __init__(self, samples, prep_func, batch_size, executor, seed=None,
             name='Train', maxsize=100):
    """Load and queue training samples into batches from `.hdf` files.

    The background thread running `self._fill_queue_batch` is started
    before this method returns.

    :param samples: tuples of (filename, hdf sample key).
    :param prep_func: function to transform a sample to x,y data.
    :param batch_size: group samples by this number.
    :param executor: `ThreadPoolExecutor` instance.
    :param seed: seed for shuffling (seeds the global numpy generator).
    :param name: str, name for logger.
    :param maxsize: int, maximum queue size.

    :raises ValueError: if there are too few samples for a single batch.

    Once initialized batches can be retrieved using batch_q._queue.get().
    """
    self.samples = samples
    self.prep_func = prep_func
    self.batch_size = batch_size
    if seed is not None:
        np.random.seed(seed)

    self.name = name
    logger_name = '{}Batcher'.format(name.capitalize())
    self.logger = get_named_logger(logger_name)

    self.maxsize = maxsize
    self._queue = queue.Queue(maxsize=self.maxsize)
    self.executor = executor
    self.stopped = threading.Event()
    self.qthread = threading.Thread(
        target=self._fill_queue_batch, daemon=True)

    # drop trailing samples which do not fill a whole batch
    original_size = len(self.samples)
    self.n_batches = original_size // self.batch_size
    n_used = self.n_batches * self.batch_size
    self.samples = self.samples[:n_used]
    self.logger.info(
        '{} batches of {} samples ({}), from {} original.'.format(
            self.n_batches, self.batch_size, len(self.samples),
            original_size))
    if self.n_batches == 0:
        raise ValueError("Number of batches is zero.")

    self.qthread.start()
    time.sleep(2)
    self.logger.info(
        "Started reading samples from files with queue size {}".format(
            maxsize))
def _labelled_samples_worker(args, region):
    """Generate all labelled samples for one region.

    :param args: namespace providing `bam`, `model`, `rle_ref`, `truth`,
        `read_fraction`, `chunk_len` and `chunk_ovlp`.
    :param region: region handed to `SampleGenerator`.

    :returns: tuple (list of samples, region, copy of the feature encoder
        arguments, copy of the feature encoder decoding).
    """
    logger = get_named_logger('PrepWork')
    logger.info("Processing region {}.".format(region))
    generator = SampleGenerator(
        args.bam, region, args.model, args.rle_ref,
        truth_bam=args.truth,
        read_fraction=args.read_fraction,
        chunk_len=args.chunk_len,
        chunk_overlap=args.chunk_ovlp)
    samples = list(generator.samples)
    encoder_args = deepcopy(generator.fencoder_args)
    encoder_decoding = deepcopy(generator.fencoder.decoding)
    return samples, region, encoder_args, encoder_decoding
def train(args):
    """Training program.

    Builds a `TrainBatcher` from the command line arguments, optionally
    derives balanced class weights, logs label statistics and hands over
    to `run_training` on the requested device.

    :param args: namespace of program arguments; `args.validation` is set
        here from `validation_features` or, failing that,
        `validation_split`.
    """
    train_name = args.train_name
    mkdir_p(train_name, info='Results will be overwritten.')

    logger = get_named_logger('Training')
    logger.debug("Loading datasets:\n{}".format('\n'.join(args.features)))

    sparse_labels = not args.balanced_weights

    if args.validation_features is not None:
        args.validation = args.validation_features
    else:
        args.validation = args.validation_split

    batcher = TrainBatcher(
        args.features, args.max_label_len, args.validation, args.seed,
        sparse_labels, args.batch_size, threads=args.threads_io)

    class_weight = None
    if args.balanced_weights:
        n_labels = sum(batcher.label_counts.values())
        n_classes = len(batcher.label_counts)
        weight_of = {
            k: float(n_labels) / (n_classes * count)
            for (k, count) in batcher.label_counts.items()
        }
        class_weight = np.array([weight_of[c] for c in sorted(weight_of)])

    def _shown_weight(index):
        # without balancing every class is reported with weight 1
        return class_weight[index] if class_weight is not None else 1

    stats = '\n'.join(
        '{} ({}) {} (w. {:9.6f})'.format(
            i, l, batcher.label_counts[i], _shown_weight(i))
        for i, l in enumerate(batcher.label_decoding))
    logger.info("Label statistics are:\n{}".format(stats))

    import tensorflow as tf
    with tf.device('/gpu:{}'.format(args.device)):
        run_training(
            train_name, batcher, model_fp=args.model, epochs=args.epochs,
            class_weight=class_weight, n_mini_epochs=args.mini_epochs,
            threads_io=args.threads_io)

    # stop batching threads
    logger.info("Training finished.")
def __init__(self, bam, region, model, rle_ref=None, truth_bam=None,
             read_fraction=None, chunk_len=1000, chunk_overlap=200,
             tag_name=None, tag_value=None, tag_keep_missing=False,
             enable_chunking=True):
    """Generate chunked inference (or training) samples.

    :param bam: `.bam` containing alignments from which to generate samples.
    :param region: a `Region` for which to generate samples.
    :param model: a medaka model, opened with `DataStore` to read the
        'medaka_features_kwargs' meta entry.
    :param rle_ref: stored on the instance.
    :param truth_bam: a `.bam` containing alignment of truth sequence to
        `reference` sequence. Required only for creating training chunks;
        when given, the sampler type is "training", otherwise "consensus".
    :param read_fraction: stored on the instance.
    :param chunk_len: stored on the instance.
    :param chunk_overlap: stored on the instance.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param tag_keep_missing: whether to keep reads when tag is missing.
    :param enable_chunking: when yielding samples, do so in chunks.
    """
    self.logger = get_named_logger("Sampler")
    if truth_bam is None:
        self.sample_type = "consensus"
    else:
        self.sample_type = "training"
    self.logger.info("Initializing sampler for {} of region {}.".format(
        self.sample_type, region))

    # the feature encoder is configured from options stored with the model
    with DataStore(model) as ds:
        self.fencoder_args = ds.meta['medaka_features_kwargs']
    self.fencoder = FeatureEncoder(
        tag_name=tag_name, tag_value=tag_value,
        tag_keep_missing=tag_keep_missing, **self.fencoder_args)

    self.bam, self.region, self.model = bam, region, model
    self.rle_ref = rle_ref
    self.truth_bam = truth_bam
    self.read_fraction = read_fraction
    self.chunk_len, self.chunk_overlap = chunk_len, chunk_overlap
    self.enable_chunking = enable_chunking
    self._source = None  # the base data to be chunked
    self._quarantined = []  # samples which are shorter than chunk size
def run_prediction(output, bam, regions, model, model_file, rle_ref,
                   read_fraction, chunk_len, chunk_ovlp, batch_size=200,
                   save_features=False, tag_name=None, tag_value=None,
                   tag_keep_missing=False):
    """Inference worker.

    :param output: file opened with `DataStore` in append mode, to which
        one sample per input sample is written.
    :param bam: alignments, passed to `SampleGenerator`.
    :param regions: regions to process; each must have a `size` attribute.
    :param model: object providing `predict_on_batch`.
    :param model_file: model file, passed to `SampleGenerator`.
    :param rle_ref: passed to `SampleGenerator`.
    :param read_fraction: passed to `SampleGenerator`.
    :param chunk_len: chunk length, passed to `SampleGenerator`.
    :param chunk_ovlp: chunk overlap, passed to `SampleGenerator`.
    :param batch_size: number of samples per prediction batch.
    :param save_features: whether to store the features with each written
        sample (otherwise `None` is stored).
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param tag_keep_missing: whether to keep reads when tag is missing.

    :returns: None.
    """
    logger = get_named_logger('PWorker')

    def sample_gen():
        # chain all samples whilst dispensing with generators when done
        # (they hold the feature vector in memory until they die)
        for region in regions:
            # read_fraction must be given by keyword: the fifth positional
            # parameter of SampleGenerator is `truth_bam`.
            data_gen = SampleGenerator(
                bam, region, model_file, rle_ref,
                read_fraction=read_fraction,
                chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
                tag_name=tag_name, tag_value=tag_value,
                tag_keep_missing=tag_keep_missing)
            yield from data_gen.samples

    batches = background_generator(
        grouper(sample_gen(), batch_size), 10
    )

    total_region_mbases = sum(r.size for r in regions) / 1e6
    logger.info("Running inference for {:.1f}M draft bases.".format(total_region_mbases))

    with DataStore(output, 'a') as ds:
        mbases_done = 0
        t0 = now()
        tlast = t0
        for data in batches:
            x_data = np.stack([x.features for x in data])
            class_probs = model.predict_on_batch(x_data)

            mbases_done += sum(x.span for x in data) / 1e6
            # just to avoid funny log msg
            mbases_done = min(mbases_done, total_region_mbases)
            t1 = now()
            if t1 - tlast > 10:
                # report progress at most once every ten time units of now()
                tlast = t1
                msg = '{:.1%} Done ({:.1f}/{:.1f} Mbases) in {:.1f}s'
                logger.info(msg.format(
                    mbases_done / total_region_mbases, mbases_done,
                    total_region_mbases, t1 - t0))

            for sample, prob, feat in zip(data, class_probs, x_data):
                # write out positions and predictions for later analysis
                sample_d = sample._asdict()
                sample_d['label_probs'] = prob
                sample_d['features'] = feat if save_features else None
                ds.write_sample(Sample(**sample_d))

    logger.info('All done')
    return None
def create_labelled_samples(args):
    """Write labelled (training) samples for all requested regions.

    Feature meta data is copied from the model file to the output, regions
    are split into pieces of at most 1e6 and processed by
    `_labelled_samples_worker` in a process pool. The output file is
    deleted again if no samples were written.

    :param args: namespace providing `bam`, `regions`, `output`, `model`
        and `threads` (plus whatever `_labelled_samples_worker` reads).
    """
    logger = get_named_logger('Prepare')
    regions = get_regions(args.bam, args.regions)
    reg_str = '\n'.join('\t\t\t{}'.format(r) for r in regions)
    logger.info('Got regions:\n{}'.format(reg_str))

    meta_keys = ('medaka_features_kwargs', 'medaka_feature_decoding')
    no_data = False
    with DataStore(args.output, 'w') as ds:
        # write feature options to file
        logger.info("Writing meta data to file.")
        with DataStore(args.model) as model:
            meta = {key: model.meta[key] for key in meta_keys}
        ds.update_meta(meta)
        # TODO: this parallelism would be better in
        # `SampleGenerator.bams_to_training_samples` since training
        # alignments are usually chunked.
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.threads) as executor:
            # break up overly long chunks
            MAX_SIZE = int(1e6)
            regions = itertools.chain(*(r.split(MAX_SIZE) for r in regions))
            futures = [
                executor.submit(_labelled_samples_worker, args, reg)
                for reg in regions]
            for fut in concurrent.futures.as_completed(futures):
                error = fut.exception()
                if error is not None:
                    logger.info(error)
                else:
                    # the encoder details returned by the worker are not
                    # used here
                    samples, region, _, _ = fut.result()
                    logger.info("Writing {} samples for region {}".format(
                        len(samples), region))
                    for sample in samples:
                        ds.write_sample(sample)
                fut._result = None  # python issue 27144
        no_data = ds.n_samples == 0

    if no_data:
        logger.critical(
            "Warning: No training data was written to file, deleting output.")
        os.remove(args.output)
def __init__(self, filename, mode='w',
             header=('CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER',
                     'INFO', 'FORMAT', 'SAMPLE'),
             meta_info=None, version='4.3'):
    """Set up a VCF writer; no file is opened here.

    :param filename: output file.
    :param mode: file mode, stored as given (default 'w').
    :param header: column names, stored as given.
    :param meta_info: iterable of str, meta lines stored after the
        'fileformat' entry. `None` (the default) means no extra lines.
    :param version: VCF version, must be one of `self.version_options`.

    :raises ValueError: if `version` is not in `self.version_options`.
    """
    self.filename = filename
    self.mode = mode
    self.header = header
    if version not in self.version_options:
        raise ValueError('version must be one of {}'.format(
            self.version_options))
    self.version = version
    # None replaces the former shared mutable default; list() additionally
    # lets callers pass any iterable (e.g. a tuple), not only a list.
    extra_meta = [] if meta_info is None else list(meta_info)
    self.meta = ['fileformat=VCFv{}'.format(self.version)] + extra_meta
    self.logger = get_named_logger('VCFWriter')
def alphabet_filter(sample_gen, alphabet=None, filter_labels=True,
                    filter_ref_seq=True):
    """Skip chunks in which labels and/or ref_seq contain bases not in
    `alphabet`.

    :param sample_gen: generator of `Sample` named tuples.
    :param alphabet: set of str of allowed bases. If None, the characters
        of `_alphabet_ + _gap_` are used.
    :param filter_labels: bool, whether to filter on labels.
    :param filter_ref_seq: bool, whether to filter on ref_seq.

    :yields: `Sample` named tuples.
    """
    if alphabet is None:
        alphabet = set(_alphabet_ + _gap_)
    logger = get_named_logger('AlphaFilter')
    logger.debug("alphabet: {}".format(alphabet))
    # compare in encoded space, as the sample fields hold encoded bases
    allowed = {encoding[c] for c in alphabet}

    def _has_bad_bases(sample, field):
        # True (after logging) if `field` holds a base outside `allowed`
        observed = set(np.unique(getattr(sample, field)['base']))
        if observed.issubset(allowed):
            return False
        diff = [decoding[i] for i in observed - allowed]
        msg = "Skipping {}:{}-{} ({} bases) due to {} {}"
        pos = sample.positions
        logger.info(
            msg.format(sample.ref_name, pos['major'][0], pos['major'][-1],
                       len(pos), field, diff))
        return True

    # fields to inspect, in the order in which they are checked
    switches = (('labels', filter_labels), ('ref_seq', filter_ref_seq))
    fields = [name for name, enabled in switches if enabled]

    for sample in sample_gen:
        rejected = any(
            getattr(sample, field) is not None
            and _has_bad_bases(sample, field)
            for field in fields)
        if not rejected:
            yield sample
def __init__(self, batcher, dataset='train', mini_epochs=1, seed=None):
    """Interface for keras to a `TrainBatcher` for training and validation
    batches.

    :param batcher: a `medaka.inference.TrainBatcher` instance.
    :param dataset: one of 'train' or 'validation'.
    :param mini_epochs: factor by which to rescale the number of batches in
        an epoch (useful to output checkpoints more frequently).
    :param seed: random seed for shuffling data (seeds the global numpy
        generator).

    :raises ValueError: for an unknown `dataset`, or when `mini_epochs` is
        not 1 for validation data.
    """
    self.batcher = batcher
    self.dataset = dataset
    self.mini_epochs = mini_epochs
    self.batch_size = self.batcher.batch_size
    if seed is not None:
        np.random.seed(seed)
    self.epoch = 1

    if dataset not in ('train', 'validation'):
        raise ValueError("'dataset' should be 'train' or 'validation'.")
    if dataset == 'train':
        self.data = batcher.train_samples
    else:
        self.data = batcher.valid_samples
        if mini_epochs != 1:
            raise ValueError(
                "'mini_epochs' must be equal to 1 for validation data.")

    # keep whole batches only; the slice is a copy, so shuffling it leaves
    # the batcher's own list untouched
    original_size = len(self.data)
    self.n_batches = original_size // self.batch_size
    n_used = self.n_batches * self.batch_size
    self.data = self.data[:n_used]
    np.random.shuffle(self.data)

    logger_name = '{}Batcher'.format(dataset.capitalize())
    self.logger = get_named_logger(logger_name)
    self.logger.info(
        '{} batches of {} samples ({}), from {} original.'.format(
            self.n_batches, self.batch_size, len(self.data),
            original_size))
def predict(args):
    """Inference program.

    Regions longer than 1Mb are split, everything is run through
    `run_prediction` with chunking, and whatever `run_prediction` returns
    as remainder is re-run one region at a time without chunking.

    :param args: namespace of program arguments; `args.regions` is
        replaced by the result of `get_regions`.
    """
    if logging.getLogger(__package__).level > logging.DEBUG:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    region_names = ' '.join(str(r) for r in args.regions)
    logger.info('Processing region(s): {}'.format(region_names))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w', verify_on_close=False) as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    session_config = K.tf.ConfigProto(
        intra_op_parallelism_threads=args.threads,
        inter_op_parallelism_threads=args.threads)
    K.set_session(K.tf.Session(config=session_config))

    def _predict(target_regions, model, **extra):
        # single place for the arguments shared by all prediction runs
        return run_prediction(
            args.output, args.bam, target_regions, model, args.model,
            args.rle_ref, args.read_fraction, args.chunk_len,
            args.chunk_ovlp,
            batch_size=args.batch_size,
            save_features=args.save_features,
            tag_name=args.tag_name,
            tag_value=args.tag_value,
            tag_keep_missing=args.tag_keep_missing,
            **extra)

    # Split overly long regions to maximum size so as to not create
    # massive feature matrices
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    regions = []
    for region in args.regions:
        too_long = region.size > MAX_REGION_SIZE
        regions.extend(
            region.split(MAX_REGION_SIZE, args.chunk_ovlp)
            if too_long else [region])

    logger.info("Processing {} long region(s) with batching.".format(
        len(regions)))
    model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
    # the returned regions are those where the pileup width is smaller than
    # chunk_len
    remainder_regions = _predict(regions, model)

    # short/remainder regions: just do things without chunking. We can do
    # this here because we now have the size of all pileups (and know they
    # are small).
    # TODO: can we avoid calculating pileups twice whilst controlling memory?
    if len(remainder_regions) > 0:
        logger.info("Processing {} short region(s).".format(
            len(remainder_regions)))
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in remainder_regions:
            # chunk_len / chunk_ovlp are passed but won't be used
            new_remainders = _predict(
                [region[0]], model, enable_chunking=False)
            if len(new_remainders) > 0:
                # shouldn't get here
                ignored = [x[0] for x in new_remainders]
                logger.warning("{} regions were not processed: {}.".format(
                    len(ignored), ignored))

    logger.info("Finished processing all regions.")

    if args.check_output:
        logger.info("Validating and finalising output data.")
        # opening and closing the store is the whole operation here
        with DataStore(args.output, 'a') as ds:
            pass
def __init__(self, features, max_label_len, validation=0.2, seed=0,
             sparse_labels=True, batch_size=500, threads=1):
    """Class to serve up batches of training / validation data.

    :param features: iterable of str, training feature files.
    :param max_label_len: int, maximum label length, longer labels will be
        truncated.
    :param validation: float, fraction of batches to use for validation, or
        iterable of str, validation feature files.
    :param seed: int, random seed for separation of batches into
        training/validation (seeds the global numpy generator).
    :param sparse_labels: bool, create sparse labels.
    :param batch_size: int, number of samples per batch.
    :param threads: int, passed to `DataIndex` when indexing `features`.
    """
    self.logger = get_named_logger('TrainBatcher')

    self.features = features
    self.max_label_len = max_label_len
    self.validation = validation
    self.seed = seed
    self.sparse_labels = sparse_labels
    self.batch_size = batch_size

    di = DataIndex(self.features, threads=threads)
    self.samples = di.samples.copy()
    self.meta = di.meta.copy()
    self.label_counts = self.meta['medaka_label_counts']

    # check sample size using first batch
    test_sample, test_fname = self.samples[0]
    with DataStore(test_fname) as ds:
        self.feature_shape = ds.load_sample(test_sample).features.shape
    self.logger.info("Sample features have shape {}".format(
        self.feature_shape))

    if isinstance(self.validation, float):
        # hold out a random fraction of the training samples
        np.random.seed(self.seed)
        np.random.shuffle(self.samples)
        n_sample_train = int((1 - self.validation) * len(self.samples))
        self.train_samples = self.samples[:n_sample_train]
        self.valid_samples = self.samples[n_sample_train:]
        msg = 'Randomly selected {} ({:3.2%}) of features for validation (seed {})'
        self.logger.info(
            msg.format(len(self.valid_samples), self.validation, self.seed))
    else:
        # validation samples come from their own feature files
        self.train_samples = self.samples
        self.valid_samples = DataIndex(self.validation).samples.copy()
        msg = 'Found {} validation samples equivalent to {:3.2%} of all the data'
        n_valid = len(self.valid_samples)
        # the parentheses matter: without them this was 1 + number of
        # training samples rather than a fraction
        fraction = n_valid / (n_valid + len(self.train_samples))
        self.logger.info(msg.format(n_valid, fraction))

    msg = 'Got {} samples in {} batches ({} labels) for {}'
    self.logger.info(
        msg.format(len(self.train_samples),
                   len(self.train_samples) // batch_size,
                   len(self.train_samples) * self.feature_shape[0],
                   'training'))
    self.logger.info(
        msg.format(len(self.valid_samples),
                   len(self.valid_samples) // batch_size,
                   len(self.valid_samples) * self.feature_shape[0],
                   'validation'))

    # NOTE(review): n_classes is taken before labels are truncated below,
    # so it may differ from len(self.label_counts) afterwards - confirm
    # this is intended.
    self.n_classes = len(self.label_counts)

    # get label encoding, given max_label_len
    self.logger.info("Max label length: {}".format(
        self.max_label_len if self.max_label_len is not None else 'inf'))
    self.label_encoding, self.label_decoding, self.label_counts = process_labels(
        self.label_counts, max_label_len=self.max_label_len)
def pileup_counts(region, bam, dtype_prefixes=None, region_split=100000,
                  workers=4, tag_name=None, tag_value=None,
                  keep_missing=False):
    """Create pileup counts feature array for region.

    :param region: `Region` object
    :param bam: .bam file with alignments.
    :param dtype_prefixes: prefixes for query names which to separate
        counts. If `None` (or of length 1), counts are not split.
    :param region_split: int, `region` is split into pieces of this size
        which are processed separately.
    :param workers: int, number of threads used to process the pieces.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param keep_missing: whether to keep reads when tag is missing.

    :returns: list of (counts array, positions array) tuples, one per
        contiguous stretch of pileup. The positions array has 'major' and
        'minor' fields.

    :raises ValueError: if `tag_name` is longer than two characters.
    """
    ffi, lib = libmedaka.ffi, libmedaka.lib
    logger = get_named_logger('PileUp')

    num_dtypes, dtypes = 1, ffi.NULL
    if isinstance(dtype_prefixes, str):
        dtype_prefixes = [dtype_prefixes]
    if dtype_prefixes is not None and len(dtype_prefixes) > 1:
        num_dtypes = len(dtype_prefixes)
        # NOTE(review): `_dtypes` is presumably kept in a local so the
        # individual C strings stay alive while `dtypes` points at them -
        # confirm against the cffi ownership rules before refactoring.
        _dtypes = [ffi.new("char[]", d.encode()) for d in dtype_prefixes]
        dtypes = ffi.new("char *[]", _dtypes)

    if tag_name is None:
        tag_name = ffi.new("char[2]", "".encode())
        tag_value = 0
        keep_missing = False
    else:
        if len(tag_name) > 2:
            # it is tag_name (not tag_value) which is being validated
            raise ValueError("'tag_name' must be a length-2 string.")
        tag_name = ffi.new("char[2]", tag_name.encode())

    featlen = lib.featlen

    def _process_region(reg):
        # htslib start is 1-based, Region object is 0-based
        region_str = '{}:{}-{}'.format(reg.ref_name, reg.start + 1, reg.end)
        counts = lib.calculate_pileup(
            region_str.encode(), bam.encode(), num_dtypes, dtypes,
            tag_name, tag_value, keep_missing)

        size_sizet = np.dtype(np.uintp).itemsize
        # copy out of the C buffers before they are destroyed below
        np_counts = np.frombuffer(
            ffi.buffer(counts.counts,
                       size_sizet * counts.n_cols * featlen * num_dtypes),
            dtype=np.uintp).reshape(
                counts.n_cols, featlen * num_dtypes).copy()

        positions = np.empty(
            counts.n_cols, dtype=[('major', int), ('minor', int)])
        np.copyto(
            positions['major'],
            np.frombuffer(
                ffi.buffer(counts.major, size_sizet * counts.n_cols),
                dtype=np.uintp))
        np.copyto(
            positions['minor'],
            np.frombuffer(
                ffi.buffer(counts.minor, size_sizet * counts.n_cols),
                dtype=np.uintp))

        lib.destroy_plp_data(counts)
        return np_counts, positions

    # split large regions for performance
    regions = region.split(region_split)
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=workers) as executor:
        results = list(executor.map(_process_region, regions))

    # First pass: need to check for discontinuities within chunks,
    # these show up as >2 changes in the major coordinate
    _results = list()
    for counts, positions in results:
        move = np.ediff1d(positions['major'])
        gaps = np.where(move > 2)[0] + 1
        if len(gaps) == 0:
            _results.append((counts, positions))
        else:
            logger.info("Splitting discontiguous pileup region.")
            start = 0
            for i in gaps:
                _results.append((counts[start:i], positions[start:i]))
                start = i
            _results.append((counts[start:], positions[start:]))
    results = _results

    # Second pass: stitch abutting chunks together, anything not
    # neighbouring is kept separate whether it came from the same chunk
    # originally or not
    def _finalize_chunk(c_buf, p_buf):
        chunk_counts = np.concatenate(c_buf)
        chunk_positions = np.concatenate(p_buf)
        # get rid of 'first' counts row for each datatype (counts of
        # alternative bases)
        mask = np.ones(chunk_counts.shape[1], dtype=bool)
        mask[[x * featlen for x in range(0, num_dtypes)]] = False
        chunk_counts = chunk_counts[:, mask]
        return chunk_counts, chunk_positions

    counts_buffer, positions_buffer = list(), list()
    chunk_results = list()
    last = None
    for counts, positions in results:
        if len(positions) == 0:
            continue
        first = positions['major'][0]
        if len(counts_buffer) == 0 or first - last == 1:
            # new or contiguous
            counts_buffer.append(counts)
            positions_buffer.append(positions)
            last = positions['major'][-1]
        else:
            # discontinuity
            chunk_results.append(
                _finalize_chunk(counts_buffer, positions_buffer))
            counts_buffer = [counts]
            positions_buffer = [positions]
            last = positions['major'][-1]
    if len(counts_buffer) != 0:
        chunk_results.append(
            _finalize_chunk(counts_buffer, positions_buffer))

    return chunk_results
def __init__(
    self,
    ref_mode: str = None,
    max_hp_len: int = 10,
    log_min: int = None,
    normalise: str = 'total',
    with_depth: bool = False,
    consensus_as_ref: bool = False,
    is_compressed: bool = True,
    dtypes=('', ),
    tag_name=None,
    tag_value=None,
    tag_keep_missing=False,
    sym_indels=False,
):
    """Class to support multiple feature encodings

    :param ref_mode: str, how to represent the reference.
    :param max_hp_len: int, longest homopolymer run which can be
        represented, longer runs will be truncated.
    :param log_min: int, take log10 of counts/fractions and set zeros to
        10**log_min.
    :param normalise: str, how to normalise the data.
    :param with_depth: bool, whether to include a feature describing the
        total depth.
    :param consensus_as_ref: bool, whether to use a naive max-count
        consensus instead of the reference.
    :param is_compressed: bool, whether to use HP compression. If false,
        treat as uncompressed.
    :param dtypes: iterable of str, read id prefixes of distinct data types
        that should be counted separately.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param tag_keep_missing: whether to keep reads when tag is missing.
    :param sym_indels: bool, whether to count a lack of an insertion as a
        deletion.

    :raises ValueError: if `ref_mode` is not in `self._ref_modes_` or
        `normalise` is not in `self._norm_modes_`.
    """
    self.ref_mode = ref_mode
    self.consensus_as_ref = consensus_as_ref
    self.max_hp_len = max_hp_len
    self.log_min = log_min
    self.normalise = normalise
    # normalised or log-scaled features are floats, raw counts are integers
    self.feature_dtype = np.float32 if (
        self.normalise is not None or self.log_min is not None) else np.uint64
    self.with_depth = with_depth
    self.is_compressed = is_compressed
    self.logger = get_named_logger('Feature')
    self.dtypes = dtypes
    self.tag_name = tag_name
    self.tag_value = tag_value
    self.tag_keep_missing = tag_keep_missing
    self.sym_indels = sym_indels

    if self.ref_mode not in self._ref_modes_:
        raise ValueError('ref_mode={} is not one of {}'.format(
            self.ref_mode, self._ref_modes_))
    if self.normalise not in self._norm_modes_:
        raise ValueError('normalise={} is not one of {}'.format(
            self.normalise, self._norm_modes_))

    # record the constructor options as actually stored on the instance
    opts = inspect.signature(FeatureEncoder.__init__).parameters.keys()
    opts = {k: getattr(self, k) for k in opts if k != 'self'}

    read_decoding = []
    for dtype in self.dtypes:
        # set up one-hot encoding of read run lengths for each dtype
        read_decoding += [
            (dtype, ) + k for k in itertools.product(
                (True, False), _alphabet_, range(1, max_hp_len + 1))]
        # forward and reverse gaps
        read_decoding += [(dtype, True, None, 1), (dtype, False, None, 1)]

    if self.ref_mode == 'onehot':
        # uses `_alphabet_` like the rest of this method; the previous
        # bare `alphabet` is not a name used anywhere else here
        ref_decoding = [
            ('ref', b, l) for b, l in itertools.product(
                _alphabet_, range(1, max_hp_len + 1))]
        ref_decoding.append(('ref', _gap_, 1))  # gaps
    elif self.ref_mode == 'base_length':
        ref_decoding = ['ref_base', 'ref_length']
        self.ref_base_encoding = {
            b: i for i, b in enumerate(_alphabet_ + _gap_)
        }
    elif self.ref_mode == 'index':
        ref_decoding = ['ref_index']
    else:
        ref_decoding = []

    self.decoding = tuple(read_decoding + ref_decoding)
    if self.with_depth:
        self.decoding = self.decoding + ('depth', )
    # inverse of `decoding`: feature description -> column index
    self.encoding = OrderedDict(
        ((a, i) for i, a in enumerate(self.decoding)))

    self.logger.debug("Creating features with: {}".format(opts))
    self.logger.debug("Label decoding is:\n{}".format('\n'.join(
        '{}: {}'.format(i, x) for i, x in enumerate(self.decoding))))
def predict(args):
    """Inference program.

    Regions longer than 1Mb are split; the resulting regions are divided
    into "long" ones (larger than `args.chunk_len`), which are predicted
    together, and "short" ones, which are predicted one at a time with a
    chunk length derived from their own size.

    :param args: namespace of program arguments; `args.regions` is
        replaced by the result of `get_regions`.
    """
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    region_names = ' '.join(str(r) for r in args.regions)
    logger.info('Processing region(s): {}'.format(region_names))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w') as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    session_config = K.tf.ConfigProto(
        intra_op_parallelism_threads=args.threads,
        inter_op_parallelism_threads=args.threads)
    K.set_session(K.tf.Session(config=session_config))

    def _predict(target_regions, model, chunk_len, chunk_ovlp):
        # single place for the arguments shared by all prediction runs
        run_prediction(
            args.output, args.bam, target_regions, model, args.model,
            args.rle_ref, args.read_fraction, chunk_len, chunk_ovlp,
            batch_size=args.batch_size,
            save_features=args.save_features,
            tag_name=args.tag_name,
            tag_value=args.tag_value,
            tag_keep_missing=args.tag_keep_missing)

    # Split overly long regions to maximum size so as to not create
    # massive feature matrices, then segregate those which cannot be
    # batched.
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    long_regions = []
    short_regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            pieces = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            pieces = [region]
        for piece in pieces:
            bucket = (
                long_regions if piece.size > args.chunk_len
                else short_regions)
            bucket.append(piece)
    logger.info("Found {} long and {} short regions.".format(
        len(long_regions), len(short_regions)))

    if len(long_regions) > 0:
        logger.info("Processing long regions.")
        model = medaka.models.load_model(
            args.model, time_steps=args.chunk_len)
        _predict(long_regions, model, args.chunk_len, args.chunk_ovlp)

    # short regions must be done individually since they have differing
    # lengths
    # TODO: consider masking (it appears slow to apply wholesale), maybe
    #       step down args.chunk_len by a constant factor until nothing
    #       remains.
    if len(short_regions) > 0:
        logger.info("Processing short regions")
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in short_regions:
            chunk_len = region.size // 2
            # still need overlap as features will be longer
            chunk_ovlp = chunk_len // 10
            _predict([region], model, chunk_len, chunk_ovlp)

    logger.info("Finished processing all regions.")
def run_training(train_name, batcher, model_fp=None, epochs=5000,
                 class_weight=None, n_mini_epochs=1, threads_io=1):
    """Run training.

    :param train_name: directory into which checkpoints and `training.log`
        are written (paths are built with `os.path.join(train_name, ...)`).
    :param batcher: object supplying the training data and its description;
        the attributes read are `label_counts`, `feature_shape`, `meta`,
        `label_decoding`, `max_label_len`, `label_encoding`,
        `sparse_labels`, `n_classes`, `train_samples`, `valid_samples`,
        `batch_size`, `seed` and the method `sample_to_x_y_bq_worker`.
    :param model_fp: optional filepath of an existing model; when given the
        model name and build kwargs are read from its metadata and its
        weights are loaded (a failure to load weights is logged, not
        raised). When `None` the default model builder is used with its
        default keyword arguments.
    :param epochs: maximum number of epochs passed to `fit_generator`.
    :param class_weight: when not `None`, used to build a
        `weighted_categorical_crossentropy` loss and also forwarded to
        `fit_generator`.
    :param n_mini_epochs: an epoch covers `1/n_mini_epochs` of the training
        batches.
    :param threads_io: number of worker processes preparing batches.
    """
    from keras.callbacks import CSVLogger, TensorBoard, EarlyStopping, ReduceLROnPlateau
    from medaka.keras_ext import ModelMetaCheckpoint, SequenceBatcher, BatchQueue

    logger = get_named_logger('RunTraining')

    if model_fp is None:
        model_name = medaka.models.default_model
        # Record the builder's default keyword arguments so they are
        # stored alongside the trained model.
        model_kwargs = {
            k: v.default for (k, v) in inspect.signature(
                medaka.models.model_builders[model_name]).parameters.items()
            if v.default is not inspect.Parameter.empty
        }
    else:
        with DataStore(model_fp) as ds:
            model_name = ds.meta['medaka_model_name']
            model_kwargs = ds.meta['medaka_model_kwargs']

    opt_str = '\n'.join(
        ['{}: {}'.format(k, v) for k, v in model_kwargs.items()])
    logger.info('Building {} model with: \n{}'.format(model_name, opt_str))
    num_classes = len(batcher.label_counts)
    timesteps, feat_dim = batcher.feature_shape
    model = medaka.models.model_builders[model_name](
        timesteps, feat_dim, num_classes, **model_kwargs)

    if model_fp is not None:
        # Best effort: an incompatible weights file is reported but does not
        # abort training. Only ordinary exceptions are caught so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            model.load_weights(model_fp)
        except Exception:
            logger.info("Could not load weights from {}".format(model_fp))
        else:
            logger.info("Loading weights from {}".format(model_fp))

    msg = "feat_dim: {}, timesteps: {}, num_classes: {}"
    logger.info(msg.format(feat_dim, timesteps, num_classes))
    model.summary()

    # Metadata saved with every checkpoint.
    model_details = batcher.meta.copy()
    model_details['medaka_model_name'] = model_name
    model_details['medaka_model_kwargs'] = model_kwargs
    model_details['medaka_label_decoding'] = batcher.label_decoding

    opts = dict(verbose=1, save_best_only=True, mode='max')

    callbacks = [
        # Best model according to training set accuracy
        ModelMetaCheckpoint(
            model_details,
            os.path.join(train_name, 'model.best.hdf5'),
            monitor='cat_acc', **opts),
        # Best model according to validation set accuracy
        ModelMetaCheckpoint(
            model_details,
            os.path.join(train_name, 'model.best.val.hdf5'),
            monitor='val_cat_acc', **opts),
        # Best model according to validation set qscore
        ModelMetaCheckpoint(
            model_details,
            os.path.join(train_name, 'model.best.val.qscore.hdf5'),
            monitor='val_qscore', **opts),
        # Checkpoints when training set accuracy improves
        ModelMetaCheckpoint(
            model_details,
            os.path.join(
                train_name,
                'model-acc-improvement-{epoch:02d}-{cat_acc:.2f}.hdf5'),
            monitor='cat_acc', **opts),
        ModelMetaCheckpoint(
            model_details,
            os.path.join(
                train_name,
                'model-val_acc-improvement-{epoch:02d}-{val_cat_acc:.2f}.hdf5'
            ),
            monitor='val_cat_acc', **opts),
        ## Reduce learning rate when no improvement
        #ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5,
        #                  verbose=1, min_delta=0.0001, cooldown=0, min_lr=0),
        # Stop when no improvement
        EarlyStopping(monitor='val_loss', patience=20),
        # Log of epoch stats
        CSVLogger(os.path.join(train_name, 'training.log')),
        # Allow us to run tensorboard to see how things are going. Some
        # features require validation data, not clear why.
        #TensorBoard(log_dir=os.path.join(train_name, 'logs'),
        #            histogram_freq=5, batch_size=100, write_graph=True,
        #            write_grads=True, write_images=True)
    ]

    if class_weight is not None:
        loss = weighted_categorical_crossentropy(class_weight)
        logger.info("Using weighted_categorical_crossentropy loss function")
    else:
        loss = 'sparse_categorical_crossentropy'
        logger.info("Using {} loss function".format(loss))

    model.compile(
        loss=loss,
        optimizer='nadam',
        metrics=[cat_acc, qscore],
    )

    if n_mini_epochs == 1:
        logger.info(
            "Not using mini_epochs, an epoch is a full traversal of the training data"
        )
    else:
        logger.info(
            "Using mini_epochs, an epoch is a traversal of 1/{} of the training data"
            .format(n_mini_epochs))

    with ProcessPoolExecutor(threads_io) as executor:
        logger.info("Starting data queues.")
        prep_function = functools.partial(
            batcher.sample_to_x_y_bq_worker,
            max_label_len=batcher.max_label_len,
            label_encoding=batcher.label_encoding,
            sparse_labels=batcher.sparse_labels,
            n_classes=batcher.n_classes)

        # Queues are tracked as they are created so that whichever exist
        # are stopped even if construction or training raises; otherwise
        # they would be left running while the executor shuts down.
        queues = []
        try:
            # TODO: should take mini_epochs into account here
            train_queue = BatchQueue(
                batcher.train_samples, prep_function, batcher.batch_size,
                executor, seed=batcher.seed, name='Train', maxsize=100)
            queues.append(train_queue)
            valid_queue = BatchQueue(
                batcher.valid_samples, prep_function, batcher.batch_size,
                executor, seed=batcher.seed, name='Valid', maxsize=100)
            queues.append(valid_queue)

            # run training
            logger.info("Starting training.")
            model.fit_generator(
                generator=train_queue.yield_batches(),
                steps_per_epoch=train_queue.n_batches // n_mini_epochs,
                validation_data=valid_queue.yield_batches(),
                validation_steps=valid_queue.n_batches,
                max_queue_size=2 * threads_io,
                workers=1,
                use_multiprocessing=False,
                epochs=epochs,
                callbacks=callbacks,
                class_weight=class_weight,
            )
            logger.info("Training finished.")
        finally:
            for queue in queues:
                queue.stop()
def stitch_from_probs(h5_fp, regions=None):
    """Join overlapping label probabilities from HDF5 files.

    Network outputs from multiple samples stored within a file are spliced
    together into a logically contiguous array and decoded to generate
    contiguous sequence(s). Whenever two consecutive samples do not overlap
    the sequence accumulated so far is emitted as its own segment and a new
    segment is started.

    :param h5_fp: iterable of HDF5 filepaths
    :param regions: iterable of regions to process; a single
        `medaka.common.Region` is also accepted and is wrapped in a list.

    :returns: list of (segment name, region string, sequence) tuples, where
        the segment name is '<ref_name>_segment<N>' and the region string is
        '<ref_name>:<start>-<end>'.

    :raises common.OverlapException: re-raised when the overlap between two
        consecutive samples cannot be resolved.
    """
    # NOTE(review): `regions` defaults to None but is iterated
    # unconditionally below, so calling without regions looks like it will
    # raise TypeError -- confirm intended behaviour with callers.
    logger = common.get_named_logger('Stitch')
    if isinstance(regions, medaka.common.Region):
        regions = [regions]
    logger.info("Stitching regions: {}".format([str(r) for r in regions]))

    index = medaka.datastore.DataIndex(h5_fp)
    label_scheme = index.metadata['label_scheme']
    logger.debug("Label decoding is:\n{}".format(
        '\n'.join('{}: {}'.format(k, v)
                  for k, v in label_scheme._decoding.items())))

    def get_pos(sample, i):
        # Format the i-th position of a sample as '<major + 1>.<minor>'.
        return '{}.{}'.format(
            sample.positions[i]['major'] + 1, sample.positions[i]['minor'])

    ref_assemblies = []
    for reg in regions:
        logger.info("Processing {}.".format(reg))
        data_gen = index.yield_from_feature_files(regions=[reg])
        seq_parts = list()
        cur_ref_name = ''
        cur_segment = None
        # first sample
        # NOTE(review): if the generator yields nothing for this region,
        # next() raises StopIteration here -- confirm whether that can occur.
        s1 = next(data_gen)
        start = get_pos(s1, 0)
        # start_1 is the index in s1 from which to take labels; it is carried
        # over from the previous pair's start_2 at the end of each iteration.
        start_1 = None
        start_2 = None
        heuristic_use = 0

        # A trailing None sentinel lets the final sample be flushed by the
        # same loop body as every other sample.
        for s2 in itertools.chain(data_gen, (None,)):
            s1_name = 'Unknown' if s1 is None else s1.name
            s2_name = 'Unknown' if s2 is None else s2.name

            # s1 is last chunk
            if s2 is None:
                end_1 = None
            else:
                # s2 ends before s1
                if s2.last_pos <= s1.last_pos:
                    logger.info('{} ends before {}, skipping.'.format(
                        s2_name, s1_name
                    ))
                    # s1 and start_1 are left untouched; try the next sample.
                    continue
                # s1 and s2 overlap by only one position
                # or there is no overlap between s1 and s2
                elif s2.first_pos >= s1.last_pos:
                    # trigger a break
                    end_1, start_2 = None, None
                else:
                    try:
                        end_1, start_2, heuristic = \
                            common.Sample.overlap_indices(s1, s2)
                        if heuristic:
                            logger.debug(
                                "Used heuristic to stitch {} and {}.".format(
                                    s1.name, s2.name))
                            heuristic_use += 1
                    except common.OverlapException as e:
                        logger.info(
                            "Unhandled overlap type whilst stitching chunks.")
                        raise(e)

            # Decode the part of s1 that is kept for the current segment.
            new_seq = label_scheme.decode_consensus(
                s1.slice(slice(start_1, end_1)))

            seq_parts.append(new_seq)

            # end_1 of None marks the end of a contiguous segment: either s1
            # was the final sample or there was no usable overlap with s2.
            if end_1 is None:
                # Segment numbering restarts whenever the reference changes.
                if s1.ref_name != cur_ref_name:
                    cur_ref_name = s1.ref_name
                    cur_segment = 0
                else:
                    cur_segment += 1
                ref_assemblies.append((
                    '{}_segment{}'.format(cur_ref_name, cur_segment),
                    '{}:{}-{}'.format(cur_ref_name, start, get_pos(s1, -1)),
                    ''.join(seq_parts)))
                seq_parts = list()

                if s2 is not None and start_2 is None:
                    msg = 'There is no overlap betwen {} and {}'
                    logger.info(msg.format(s1_name, s2_name))
                    # The next segment starts at the first position of s2.
                    start = get_pos(s2, 0)

            # Advance: s2 becomes the sample to emit on the next iteration.
            s1 = s2
            start_1 = start_2
        logger.info("Used heuristic {} times for {}.".format(
            heuristic_use, reg))
    return ref_assemblies
from medaka.datastore import DataStore, DataIndex
from medaka.common import get_named_logger

logger = get_named_logger('ModelLoad')


def load_model(fname, time_steps=None):
    """Build a model described by an .hdf file and load its weights.

    The feature count, class count, builder name and builder keyword
    arguments are all read from the metadata stored in the file.

    :param fname: .hdf file containing model.
    :param time_steps: number of time points in RNN, `None` for dynamic.

    :returns: the constructed model with weights loaded from `fname`.

    ..note:: keras' `load_model` cannot handle CuDNNGRU layers, hence this
        function builds the model then loads the weights.
    """
    with DataStore(fname) as store:
        metadata = store.meta
        n_features = len(metadata['medaka_feature_decoding'])
        n_classes = len(metadata['medaka_label_decoding'])
        builder = model_builders[metadata['medaka_model_name']]

    build_msg = "Building model (steps, features, classes): ({}, {}, {})"
    logger.info(build_msg.format(time_steps, n_features, n_classes))
    builder_kwargs = metadata['medaka_model_kwargs']
    model = builder(time_steps, n_features, n_classes, **builder_kwargs)

    logger.info("Loading weights from {}".format(fname))
    model.load_weights(fname)
    return model