def sample_to_x_y_bq_worker(sample, max_label_len, label_encoding,
                            sparse_labels, n_classes):
    """Convert a `Sample` object into an x,y tuple for training.

    :param sample: (sample key, filename).
    :param max_label_len: int, maximum label length, longer labels will be truncated.
    :param label_encoding: {label: int encoded label}.
    :param sparse_labels: bool, create sparse labels.
    :param n_classes: int, number of label classes.

    :returns: (np.ndarray of inputs, np.ndarray of labels)
    """
    sample_key, sample_file = sample

    with DataStore(sample_file) as ds:
        s = ds.load_sample(sample_key)
    if s.labels is None:
        raise ValueError("Sample {} in {} has no labels.".format(sample_key, sample_file))
    x = s.features

    # labels can either be unicode strings or (base, length) integer tuples
    if isinstance(s.labels[0], np.unicode):
        # TODO: is this ever used now we have dispensed with tview code?
        y = np.fromiter(
            (label_encoding[l[:min(max_label_len, len(l))]] for l in s.labels),
            dtype=int, count=len(s.labels))
    else:
        y = np.fromiter(
            (label_encoding[tuple((l['base'], min(max_label_len, l['run_length'])))]
             for l in s.labels),
            dtype=int, count=len(s.labels))
    y = y.reshape(y.shape + (1,))
    if not sparse_labels:
        from keras.utils.np_utils import to_categorical
        y = to_categorical(y, num_classes=n_classes)
    return x, y
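# Illustrative sketch (not library code) of the worker's shape contract. The
# feature file name, sample key and toy (base, run_length) encoding below are
# hypothetical; real encodings are produced by `process_labels` and the
# arguments are normally bound via functools.partial (see run_training below).
toy_encoding = {label: i for i, label in enumerate(
    [('A', 1), ('C', 1), ('G', 1), ('T', 1)])}
x, y = sample_to_x_y_bq_worker(
    ('contig1:0-10000', 'train_features.hdf5'),  # (sample key, filename)
    max_label_len=1, label_encoding=toy_encoding,
    sparse_labels=True, n_classes=len(toy_encoding))
# x has shape (pileup columns, features); y has shape (pileup columns, 1)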
def load_model(fname, time_steps=None):
    """Load a model from an .hdf file.

    :param fname: .hdf file containing model.
    :param time_steps: number of time points in RNN, `None` for dynamic.

    .. note:: keras' `load_model` cannot handle CuDNNGRU layers, hence this
        function builds the model then loads the weights.
    """
    with DataStore(fname) as ds:
        meta = ds.meta
    num_features = len(meta['medaka_feature_decoding'])
    num_classes = len(meta['medaka_label_decoding'])
    try:
        num_dtypes = len(meta['medaka_features_kwargs']['dtypes'])
    except KeyError:
        num_dtypes = 1
    num_features *= num_dtypes
    build_model = model_builders[meta['medaka_model_name']]

    logger.info(
        "Building model (steps, features, classes): ({}, {}, {})".format(
            time_steps, num_features, num_classes))
    model = build_model(time_steps, num_features, num_classes,
                        **meta['medaka_model_kwargs'])
    logger.info("Loading weights from {}".format(fname))
    model.load_weights(fname)
    return model
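# Illustrative usage sketch (not library code). The model path is
# hypothetical. A fixed time_steps builds a graph for uniformly sized chunks
# (batched inference); time_steps=None builds a dynamic graph for odd-sized
# pileups, mirroring the two calls made in the predict functions below.
batched_model = load_model('consensus_model.hdf5', time_steps=1000)
dynamic_model = load_model('consensus_model.hdf5', time_steps=None)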
def create_labelled_samples(args):
    """Create labelled training samples for the requested regions and write them to file."""
    logger = get_named_logger('Prepare')
    regions = get_regions(args.bam, args.regions)
    reg_str = '\n'.join(['\t\t\t{}'.format(r) for r in regions])
    logger.info('Got regions:\n{}'.format(reg_str))

    labels_counter = Counter()

    no_data = False
    with DataStore(args.output, 'w') as ds:
        # write feature options to file
        logger.info("Writing meta data to file.")
        with DataStore(args.model) as model:
            meta = {k: model.meta[k] for k in (
                'medaka_features_kwargs', 'medaka_feature_decoding')}
        ds.update_meta(meta)

        # TODO: this parallelism would be better in
        # `SampleGenerator.bams_to_training_samples` since training
        # alignments are usually chunked.
        with concurrent.futures.ProcessPoolExecutor(max_workers=args.threads) as executor:
            # break up overly long chunks
            MAX_SIZE = int(1e6)
            regions = itertools.chain(*(r.split(MAX_SIZE) for r in regions))
            futures = [executor.submit(_labelled_samples_worker, args, reg)
                       for reg in regions]
            for fut in concurrent.futures.as_completed(futures):
                if fut.exception() is None:
                    samples, region, fencoder_args, fencoder_decoder = fut.result()
                    logger.info("Writing {} samples for region {}".format(
                        len(samples), region))
                    for sample in samples:
                        ds.write_sample(sample)
                else:
                    logger.info(fut.exception())
                fut._result = None  # python issue 27144
        no_data = ds.n_samples == 0

    if no_data:
        logger.critical("Warning: No training data was written to file, deleting output.")
        os.remove(args.output)
def test_000_load_all_models(self):
    for name, model_file in model_dict.items():
        model = models.load_model(model_file)
        self.assertIsInstance(model, tensorflow.keras.models.Model)
        # Check we can get necessary functions for inference
        with DataStore(model_file) as ds:
            feature_encoder = ds.get_meta('feature_encoder')
            self.assertIsInstance(feature_encoder, BaseFeatureEncoder)
            label_scheme = ds.get_meta('label_scheme')
            self.assertIsInstance(label_scheme, BaseLabelScheme)
def test_999_load_all_models(self):
    for name in medaka.options.allowed_models:
        model_file = models.resolve_model(name)
        model = medaka.models.open_model(model_file).load_model()
        self.assertIsInstance(model, tensorflow.keras.models.Model)
        # Check we can get necessary functions for inference
        with DataStore(model_file) as ds:
            feature_encoder = ds.get_meta('feature_encoder')
            self.assertIsInstance(feature_encoder, BaseFeatureEncoder)
            label_scheme = ds.get_meta('label_scheme')
            self.assertIsInstance(label_scheme, BaseLabelScheme)
def __init__(self, bam, region, model, rle_ref=None, truth_bam=None,
             read_fraction=None, chunk_len=1000, chunk_overlap=200,
             tag_name=None, tag_value=None, tag_keep_missing=False,
             enable_chunking=True):
    """Generate chunked inference (or training) samples.

    :param bam: `.bam` containing alignments from which to generate samples.
    :param region: a `Region` for which to generate samples.
    :param model: a medaka model.
    :param rle_ref: reference `.fasta`, should correspond to `bam`.
    :param truth_bam: a `.bam` containing alignment of truth sequence to
        `reference` sequence. Required only for creating training chunks.
    :param read_fraction: fraction of reads to use.
    :param chunk_len: width (in pileup columns) of sample chunks.
    :param chunk_overlap: overlap (in pileup columns) between consecutive chunks.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param tag_keep_missing: whether to keep reads when tag is missing.
    :param enable_chunking: when yielding samples, do so in chunks.
    """
    self.logger = get_named_logger("Sampler")
    self.sample_type = "training" if truth_bam is not None else "consensus"
    self.logger.info("Initializing sampler for {} of region {}.".format(
        self.sample_type, region))
    with DataStore(model) as ds:
        self.fencoder_args = ds.meta['medaka_features_kwargs']
    self.fencoder = FeatureEncoder(
        tag_name=tag_name, tag_value=tag_value,
        tag_keep_missing=tag_keep_missing, **self.fencoder_args)

    self.bam = bam
    self.region = region
    self.model = model
    self.rle_ref = rle_ref
    self.truth_bam = truth_bam
    self.read_fraction = read_fraction
    self.chunk_len = chunk_len
    self.chunk_overlap = chunk_overlap
    self.enable_chunking = enable_chunking
    self._source = None  # the base data to be chunked
    self._quarantined = list()  # samples which are shorter than chunk size
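# Illustrative sketch (not library code): consensus sample generation for the
# regions of a hypothetical draft-alignment BAM. File names are hypothetical;
# region selection reuses get_regions as in the predict functions below, and
# construction mirrors the call made in run_prediction.
for region in get_regions('calls_to_draft.bam', ['contig1:0-100000']):
    sampler = SampleGenerator(
        'calls_to_draft.bam', region, 'consensus_model.hdf5',
        chunk_len=1000, chunk_overlap=200)
    for sample in sampler.samples:
        print(sample.features.shape)  # one chunked pileup feature matrix per sample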
def run_prediction(output, bam, regions, model, model_file, rle_ref, read_fraction,
                   chunk_len, chunk_ovlp, batch_size=200, save_features=False,
                   tag_name=None, tag_value=None, tag_keep_missing=False):
    """Inference worker."""
    logger = get_named_logger('PWorker')

    def sample_gen():
        # chain all samples whilst dispensing with generators when done
        # (they hold the feature vector in memory until they die)
        for region in regions:
            data_gen = SampleGenerator(
                bam, region, model_file, rle_ref, read_fraction,
                chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
                tag_name=tag_name, tag_value=tag_value,
                tag_keep_missing=tag_keep_missing)
            yield from data_gen.samples
    batches = background_generator(
        grouper(sample_gen(), batch_size), 10
    )

    total_region_mbases = sum(r.size for r in regions) / 1e6
    logger.info("Running inference for {:.1f}M draft bases.".format(total_region_mbases))

    with DataStore(output, 'a') as ds:
        mbases_done = 0

        t0 = now()
        tlast = t0
        for data in batches:
            x_data = np.stack([x.features for x in data])
            class_probs = model.predict_on_batch(x_data)
            mbases_done += sum(x.span for x in data) / 1e6
            mbases_done = min(mbases_done, total_region_mbases)  # just to avoid funny log msg
            t1 = now()
            if t1 - tlast > 10:
                tlast = t1
                msg = '{:.1%} Done ({:.1f}/{:.1f} Mbases) in {:.1f}s'
                logger.info(msg.format(mbases_done / total_region_mbases,
                                       mbases_done, total_region_mbases, t1 - t0))

            best = np.argmax(class_probs, -1)
            for sample, prob, pred, feat in zip(data, class_probs, best, x_data):
                # write out positions and predictions for later analysis
                sample_d = sample._asdict()
                sample_d['label_probs'] = prob
                sample_d['features'] = feat if save_features else None
                ds.write_sample(Sample(**sample_d))

    logger.info('All done')
    return None
def run_training(train_name, batcher, model_fp=None, epochs=5000,
                 class_weight=None, n_mini_epochs=1, threads_io=1):
    """Run training."""
    from keras.callbacks import CSVLogger, TensorBoard, EarlyStopping, ReduceLROnPlateau
    from medaka.keras_ext import ModelMetaCheckpoint, SequenceBatcher, BatchQueue

    logger = get_named_logger('RunTraining')

    if model_fp is None:
        model_name = medaka.models.default_model
        model_kwargs = {
            k: v.default for (k, v) in
            inspect.signature(medaka.models.model_builders[model_name]).parameters.items()
            if v.default is not inspect.Parameter.empty}
    else:
        with DataStore(model_fp) as ds:
            model_name = ds.meta['medaka_model_name']
            model_kwargs = ds.meta['medaka_model_kwargs']

    opt_str = '\n'.join(['{}: {}'.format(k, v) for k, v in model_kwargs.items()])
    logger.info('Building {} model with: \n{}'.format(model_name, opt_str))

    num_classes = len(batcher.label_counts)
    timesteps, feat_dim = batcher.feature_shape
    model = medaka.models.model_builders[model_name](
        timesteps, feat_dim, num_classes, **model_kwargs)

    if model_fp is not None:
        try:
            logger.info("Loading weights from {}".format(model_fp))
            model.load_weights(model_fp)
        except Exception:
            logger.info("Could not load weights from {}".format(model_fp))

    msg = "feat_dim: {}, timesteps: {}, num_classes: {}"
    logger.info(msg.format(feat_dim, timesteps, num_classes))
    model.summary()

    model_details = batcher.meta.copy()
    model_details['medaka_model_name'] = model_name
    model_details['medaka_model_kwargs'] = model_kwargs
    model_details['medaka_label_decoding'] = batcher.label_decoding

    opts = dict(verbose=1, save_best_only=True, mode='max')

    callbacks = [
        # Best model according to training set accuracy
        ModelMetaCheckpoint(
            model_details, os.path.join(train_name, 'model.best.hdf5'),
            monitor='cat_acc', **opts),
        # Best model according to validation set accuracy
        ModelMetaCheckpoint(
            model_details, os.path.join(train_name, 'model.best.val.hdf5'),
            monitor='val_cat_acc', **opts),
        # Best model according to validation set qscore
        ModelMetaCheckpoint(
            model_details, os.path.join(train_name, 'model.best.val.qscore.hdf5'),
            monitor='val_qscore', **opts),
        # Checkpoints when training set accuracy improves
        ModelMetaCheckpoint(
            model_details,
            os.path.join(train_name, 'model-acc-improvement-{epoch:02d}-{cat_acc:.2f}.hdf5'),
            monitor='cat_acc', **opts),
        # Checkpoints when validation set accuracy improves
        ModelMetaCheckpoint(
            model_details,
            os.path.join(train_name, 'model-val_acc-improvement-{epoch:02d}-{val_cat_acc:.2f}.hdf5'),
            monitor='val_cat_acc', **opts),
        ## Reduce learning rate when no improvement
        #ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5,
        #                  verbose=1, min_delta=0.0001, cooldown=0, min_lr=0),
        # Stop when no improvement
        EarlyStopping(monitor='val_loss', patience=20),
        # Log of epoch stats
        CSVLogger(os.path.join(train_name, 'training.log')),
        # Allow us to run tensorboard to see how things are going. Some
        # features require validation data, not clear why.
        #TensorBoard(log_dir=os.path.join(train_name, 'logs'),
        #            histogram_freq=5, batch_size=100, write_graph=True,
        #            write_grads=True, write_images=True)
    ]

    if class_weight is not None:
        loss = weighted_categorical_crossentropy(class_weight)
        logger.info("Using weighted_categorical_crossentropy loss function")
    else:
        loss = 'sparse_categorical_crossentropy'
        logger.info("Using {} loss function".format(loss))

    model.compile(
        loss=loss,
        optimizer='nadam',
        metrics=[cat_acc, qscore],
    )

    if n_mini_epochs == 1:
        logger.info("Not using mini_epochs, an epoch is a full traversal of the training data")
    else:
        logger.info("Using mini_epochs, an epoch is a traversal of 1/{} of the training data".format(
            n_mini_epochs))

    with ProcessPoolExecutor(threads_io) as executor:
        logger.info("Starting data queues.")
        prep_function = functools.partial(
            batcher.sample_to_x_y_bq_worker,
            max_label_len=batcher.max_label_len,
            label_encoding=batcher.label_encoding,
            sparse_labels=batcher.sparse_labels,
            n_classes=batcher.n_classes)
        # TODO: should take mini_epochs into account here
        train_queue = BatchQueue(
            batcher.train_samples, prep_function, batcher.batch_size, executor,
            seed=batcher.seed, name='Train', maxsize=100)
        valid_queue = BatchQueue(
            batcher.valid_samples, prep_function, batcher.batch_size, executor,
            seed=batcher.seed, name='Valid', maxsize=100)

        # run training
        logger.info("Starting training.")
        model.fit_generator(
            generator=train_queue.yield_batches(),
            steps_per_epoch=train_queue.n_batches // n_mini_epochs,
            validation_data=valid_queue.yield_batches(),
            validation_steps=valid_queue.n_batches,
            max_queue_size=2 * threads_io, workers=1, use_multiprocessing=False,
            epochs=epochs,
            callbacks=callbacks,
            class_weight=class_weight,
        )
        logger.info("Training finished.")
        train_queue.stop()
        valid_queue.stop()
def predict(args):
    """Inference program."""
    logger_level = logging.getLogger(__package__).level
    if logger_level > logging.DEBUG:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    logger.info('Processing region(s): {}'.format(
        ' '.join(str(r) for r in args.regions)))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w', verify_on_close=False) as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    K.set_session(K.tf.Session(
        config=K.tf.ConfigProto(
            intra_op_parallelism_threads=args.threads,
            inter_op_parallelism_threads=args.threads)))

    # Split overly long regions to maximum size so as to not create
    # massive feature matrices
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        regions.extend(regs)

    logger.info("Processing {} long region(s) with batching.".format(len(regions)))
    model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
    # the returned regions are those where the pileup width is smaller than chunk_len
    remainder_regions = run_prediction(
        args.output, args.bam, regions, model, args.model, args.rle_ref,
        args.read_fraction, args.chunk_len, args.chunk_ovlp,
        batch_size=args.batch_size, save_features=args.save_features,
        tag_name=args.tag_name, tag_value=args.tag_value,
        tag_keep_missing=args.tag_keep_missing)

    # short/remainder regions: just do things without chunking. We can do this
    # here because we now have the size of all pileups (and know they are small).
    # TODO: can we avoid calculating pileups twice whilst controlling memory?
    if len(remainder_regions) > 0:
        logger.info("Processing {} short region(s).".format(len(remainder_regions)))
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in remainder_regions:
            new_remainders = run_prediction(
                args.output, args.bam, [region[0]], model, args.model, args.rle_ref,
                args.read_fraction, args.chunk_len, args.chunk_ovlp,  # these won't be used
                batch_size=args.batch_size, save_features=args.save_features,
                tag_name=args.tag_name, tag_value=args.tag_value,
                tag_keep_missing=args.tag_keep_missing,
                enable_chunking=False)
            if len(new_remainders) > 0:
                # shouldn't get here
                ignored = [x[0] for x in new_remainders]
                n_ignored = len(ignored)
                logger.warning("{} regions were not processed: {}.".format(
                    n_ignored, ignored))

    logger.info("Finished processing all regions.")

    if args.check_output:
        logger.info("Validating and finalising output data.")
        # simply opening and closing the datastore triggers its on-close verification
        with DataStore(args.output, 'a') as ds:
            pass
def __init__(self, features, max_label_len, validation=0.2, seed=0,
             sparse_labels=True, batch_size=500, threads=1):
    """Class to serve up batches of training / validation data.

    :param features: iterable of str, training feature files.
    :param max_label_len: int, maximum label length, longer labels will be truncated.
    :param validation: float, fraction of batches to use for validation, or
        iterable of str, validation feature files.
    :param seed: int, random seed for separation of batches into training/validation.
    :param sparse_labels: bool, create sparse labels.
    :param batch_size: int, number of samples per batch.
    :param threads: int, number of threads to use when indexing feature files.
    """
    self.logger = get_named_logger('TrainBatcher')

    self.features = features
    self.max_label_len = max_label_len
    self.validation = validation
    self.seed = seed
    self.sparse_labels = sparse_labels
    self.batch_size = batch_size

    di = DataIndex(self.features, threads=threads)
    self.samples = di.samples.copy()
    self.meta = di.meta.copy()
    self.label_counts = self.meta['medaka_label_counts']

    # check sample size using first batch
    test_sample, test_fname = self.samples[0]
    with DataStore(test_fname) as ds:
        self.feature_shape = ds.load_sample(test_sample).features.shape
    self.logger.info("Sample features have shape {}".format(self.feature_shape))

    if isinstance(self.validation, float):
        np.random.seed(self.seed)
        np.random.shuffle(self.samples)
        n_sample_train = int((1 - self.validation) * len(self.samples))
        self.train_samples = self.samples[:n_sample_train]
        self.valid_samples = self.samples[n_sample_train:]
        msg = 'Randomly selected {} ({:3.2%}) of features for validation (seed {})'
        self.logger.info(msg.format(
            len(self.valid_samples), self.validation, self.seed))
    else:
        self.train_samples = self.samples
        self.valid_samples = DataIndex(self.validation).samples.copy()
        msg = 'Found {} validation samples equivalent to {:3.2%} of all the data'
        fraction = len(self.valid_samples) / (
            len(self.valid_samples) + len(self.train_samples))
        self.logger.info(msg.format(len(self.valid_samples), fraction))

    msg = 'Got {} samples in {} batches ({} labels) for {}'
    self.logger.info(msg.format(
        len(self.train_samples),
        len(self.train_samples) // batch_size,
        len(self.train_samples) * self.feature_shape[0],
        'training'))
    self.logger.info(msg.format(
        len(self.valid_samples),
        len(self.valid_samples) // batch_size,
        len(self.valid_samples) * self.feature_shape[0],
        'validation'))

    self.n_classes = len(self.label_counts)

    # get label encoding, given max_label_len
    self.logger.info("Max label length: {}".format(
        self.max_label_len if self.max_label_len is not None else 'inf'))
    self.label_encoding, self.label_decoding, self.label_counts = process_labels(
        self.label_counts, max_label_len=self.max_label_len)
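# Illustrative sketch (not library code) of the two validation modes accepted
# above. The class name TrainBatcher is assumed from the logger name and from
# how run_training uses the batcher; the feature file names are hypothetical.
# Random 80/20 split of the training feature files:
batcher = TrainBatcher(['train_features.hdf5'], max_label_len=1,
                       validation=0.2, seed=0, batch_size=500)
# ...or an explicit set of held-out validation feature files:
batcher = TrainBatcher(['train_features.hdf5'], max_label_len=1,
                       validation=['validation_features.hdf5'])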
def on_epoch_end(self, epoch, logs=None):
    """Write medaka metadata into the checkpoint file saved for this epoch."""
    super(ModelMetaCheckpoint, self).on_epoch_end(epoch, logs)
    filepath = self.filepath.format(epoch=epoch + 1, **logs)
    with DataStore(filepath, 'a') as ds:
        ds.meta.update(self.medaka_meta)
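# Illustrative sketch (not library code): the hook above re-opens the file
# Keras has just written and appends the medaka metadata, so every checkpoint
# is self-describing and can be consumed by load_model. Construction mirrors
# run_training above; the metadata dict here is a hypothetical, truncated example.
checkpoint = ModelMetaCheckpoint(
    {'medaka_model_name': 'example_model'}, 'training/model.best.hdf5',
    monitor='cat_acc', verbose=1, save_best_only=True, mode='max')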
def predict(args):
    """Inference program."""
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    logger.info('Processing region(s): {}'.format(
        ' '.join(str(r) for r in args.regions)))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w') as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    K.set_session(K.tf.Session(
        config=K.tf.ConfigProto(
            intra_op_parallelism_threads=args.threads,
            inter_op_parallelism_threads=args.threads)
    ))

    # Split overly long regions to maximum size so as to not create
    # massive feature matrices, then segregate those which cannot be
    # batched.
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    long_regions = []
    short_regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        for r in regs:
            if r.size > args.chunk_len:
                long_regions.append(r)
            else:
                short_regions.append(r)
    logger.info("Found {} long and {} short regions.".format(
        len(long_regions), len(short_regions)))

    if len(long_regions) > 0:
        logger.info("Processing long regions.")
        model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
        run_prediction(
            args.output, args.bam, long_regions, model, args.model, args.rle_ref,
            args.read_fraction, args.chunk_len, args.chunk_ovlp,
            batch_size=args.batch_size, save_features=args.save_features,
            tag_name=args.tag_name, tag_value=args.tag_value,
            tag_keep_missing=args.tag_keep_missing
        )

    # short regions must be done individually since they have differing lengths
    # TODO: consider masking (it appears slow to apply wholesale), maybe
    #       step down args.chunk_len by a constant factor until nothing remains.
    if len(short_regions) > 0:
        logger.info("Processing short regions")
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in short_regions:
            chunk_len = region.size // 2
            chunk_ovlp = chunk_len // 10  # still need overlap as features will be longer
            run_prediction(
                args.output, args.bam, [region], model, args.model, args.rle_ref,
                args.read_fraction, chunk_len, chunk_ovlp,
                batch_size=args.batch_size, save_features=args.save_features,
                tag_name=args.tag_name, tag_value=args.tag_value,
                tag_keep_missing=args.tag_keep_missing
            )
    logger.info("Finished processing all regions.")
def yaml2hdf(args):
    """Copy metadata from a yaml file into an hdf datastore."""
    with DataStore(args.output, 'a') as ds, open(args.input) as fh:
        ds.update_meta(yaml.unsafe_load(fh))
def hdf2yaml(args):
    """Dump metadata from an hdf datastore to a yaml file."""
    with DataStore(args.input) as ds, open(args.output, 'w') as fh:
        yaml.dump(ds.meta, fh)
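# Illustrative sketch (not library code): the two helpers above are inverses,
# so metadata can be dumped from a feature or model file, inspected or edited,
# and written back. The Namespace arguments and file names are hypothetical.
import argparse
hdf2yaml(argparse.Namespace(input='consensus_model.hdf5', output='meta.yaml'))
yaml2hdf(argparse.Namespace(input='meta.yaml', output='consensus_model.hdf5'))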