Пример #1
0
    def sample_to_x_y_bq_worker(sample, max_label_len, label_encoding, sparse_labels, n_classes):
        """Convert a `Sample` object into an x,y tuple for training.

        :param sample: (sample key, filename) tuple identifying the sample.
        :param max_label_len: int, maximum label length, longer labels will be truncated.
        :param label_encoding: {label: int encoded label}.
        :param sparse_labels: bool, create sparse labels.
        :param n_classes: int, number of label classes.

        :returns: (np.ndarray of inputs, np.ndarray of labels)

        :raises ValueError: if the stored sample has no labels.
        """
        sample_key, sample_file = sample

        with DataStore(sample_file) as ds:
            s = ds.load_sample(sample_key)
        if s.labels is None:
            raise ValueError("Sample {} in {} has no labels.".format(sample_key, sample_file))
        x = s.features
        # labels can either be unicode strings or (base, length) integer tuples.
        # `np.unicode` was simply an alias of the builtin `str` and was removed
        # in numpy >= 1.24, so test against `str` directly.
        if isinstance(s.labels[0], str):
            # TODO: is this ever used now we have dispensed with tview code?
            y = np.fromiter((label_encoding[l[:min(max_label_len, len(l))]]
                               for l in s.labels), dtype=int, count=len(s.labels))
        else:
            y = np.fromiter((label_encoding[tuple((l['base'], min(max_label_len, l['run_length'])))]
                             for l in s.labels), dtype=int, count=len(s.labels))
        # add trailing singleton dimension: shape (n_positions, 1)
        y = y.reshape(y.shape + (1,))
        if not sparse_labels:
            from keras.utils.np_utils import to_categorical
            y = to_categorical(y, num_classes=n_classes)
        return x, y
Пример #2
0
def load_model(fname, time_steps=None):
    """Load a model from an .hdf file.

    :param fname: .hdf file containing model.
    :param time_steps: number of time points in RNN, `None` for dynamic.

    ..note:: keras' `load_model` cannot handle CuDNNGRU layers, hence this
        function builds the model then loads the weights.
    """
    # pull model-building options out of the datastore's metadata
    with DataStore(fname) as ds:
        meta = ds.meta
        n_features = len(meta['medaka_feature_decoding'])
        n_classes = len(meta['medaka_label_decoding'])
        try:
            n_dtypes = len(meta['medaka_features_kwargs']['dtypes'])
        except KeyError:
            n_dtypes = 1
        n_features = n_features * n_dtypes
    builder = model_builders[meta['medaka_model_name']]

    logger.info(
        "Building model (steps, features, classes): ({}, {}, {})".format(
            time_steps, n_features, n_classes))
    model = builder(
        time_steps, n_features, n_classes, **meta['medaka_model_kwargs'])
    logger.info("Loading weights from {}".format(fname))
    model.load_weights(fname)
    return model
Пример #3
0
def create_labelled_samples(args):
    """Prepare labelled training samples and write them to an output file."""
    logger = get_named_logger('Prepare')
    regions = get_regions(args.bam, args.regions)
    logger.info('Got regions:\n{}'.format(
        '\n'.join('\t\t\t{}'.format(reg) for reg in regions)))

    labels_counter = Counter()

    no_data = False
    with DataStore(args.output, 'w') as ds:
        # write feature options to file
        logger.info("Writing meta data to file.")
        with DataStore(args.model) as model:
            keys = ('medaka_features_kwargs', 'medaka_feature_decoding')
            meta = {key: model.meta[key] for key in keys}
        ds.update_meta(meta)
        # TODO: this parallelism would be better in `SampleGenerator.bams_to_training_samples`
        #       since training alignments are usually chunked.
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.threads) as executor:
            # break up overly long chunks
            max_size = int(1e6)
            split_regions = itertools.chain.from_iterable(
                reg.split(max_size) for reg in regions)
            futures = [
                executor.submit(_labelled_samples_worker, args, reg)
                for reg in split_regions]
            for fut in concurrent.futures.as_completed(futures):
                exc = fut.exception()
                if exc is not None:
                    logger.info(exc)
                else:
                    samples, region, fencoder_args, fencoder_decoder = \
                        fut.result()
                    logger.info("Writing {} samples for region {}".format(
                        len(samples), region))
                    for sample in samples:
                        ds.write_sample(sample)
                fut._result = None  # python issue 27144
        no_data = ds.n_samples == 0

    if no_data:
        logger.critical(
            "Warning: No training data was written to file, deleting output.")
        os.remove(args.output)
Пример #4
0
 def test_000_load_all_models(self):
     """Every bundled model file loads and carries inference metadata."""
     for model_file in model_dict.values():
         model = models.load_model(model_file)
         self.assertIsInstance(model, tensorflow.keras.models.Model)
         # Check we can get necessary functions for inference
         with DataStore(model_file) as ds:
             self.assertIsInstance(
                 ds.get_meta('feature_encoder'), BaseFeatureEncoder)
             self.assertIsInstance(
                 ds.get_meta('label_scheme'), BaseLabelScheme)
Пример #5
0
 def test_999_load_all_models(self):
     """Every allowed model name resolves, loads, and has inference metadata."""
     for name in medaka.options.allowed_models:
         model_file = models.resolve_model(name)
         loaded = medaka.models.open_model(model_file).load_model()
         self.assertIsInstance(loaded, tensorflow.keras.models.Model)
         # Check we can get necessary functions for inference
         with DataStore(model_file) as ds:
             self.assertIsInstance(
                 ds.get_meta('feature_encoder'), BaseFeatureEncoder)
             self.assertIsInstance(
                 ds.get_meta('label_scheme'), BaseLabelScheme)
Пример #6
0
    def __init__(self,
                 bam,
                 region,
                 model,
                 rle_ref=None,
                 truth_bam=None,
                 read_fraction=None,
                 chunk_len=1000,
                 chunk_overlap=200,
                 tag_name=None,
                 tag_value=None,
                 tag_keep_missing=False,
                 enable_chunking=True):
        """Generate chunked inference (or training) samples.

        :param bam: `.bam` containing alignments from which to generate samples.
        :param region: a `Region` for which to generate samples.
        :param model: model file from which `medaka_features_kwargs` metadata
            is read to configure the feature encoder.
        :param rle_ref: run-length-encoded reference, `None` to disable.
        :param truth_bam: a `.bam` containing alignment of truth sequence to
            `reference` sequence. Required only for creating training chunks.
        :param read_fraction: fraction of reads to use, `None` for all.
        :param chunk_len: int, length of sample chunks.
        :param chunk_overlap: int, overlap between consecutive chunks.
        :param tag_name: two letter tag name by which to filter reads.
        :param tag_value: integer value of tag for reads to keep.
        :param tag_keep_missing: whether to keep reads when tag is missing.
        :param enable_chunking: when yielding samples, do so in chunks.

        """
        self.logger = get_named_logger("Sampler")
        # truth alignments are only needed to label training data
        self.sample_type = "training" if truth_bam is not None else "consensus"
        self.logger.info("Initializing sampler for {} of region {}.".format(
            self.sample_type, region))
        # configure the feature encoder from the model's stored options
        with DataStore(model) as ds:
            self.fencoder_args = ds.meta['medaka_features_kwargs']
        self.fencoder = FeatureEncoder(tag_name=tag_name,
                                       tag_value=tag_value,
                                       tag_keep_missing=tag_keep_missing,
                                       **self.fencoder_args)

        self.bam = bam
        self.region = region
        self.model = model
        self.rle_ref = rle_ref
        self.truth_bam = truth_bam
        self.read_fraction = read_fraction
        self.chunk_len = chunk_len
        self.chunk_overlap = chunk_overlap
        self.enable_chunking = enable_chunking
        self._source = None  # the base data to be chunked
        self._quarantined = list()  # samples which are shorter than chunk size
Пример #7
0
def run_prediction(output, bam, regions, model, model_file, rle_ref, read_fraction, chunk_len, chunk_ovlp,
                   batch_size=200, save_features=False, tag_name=None, tag_value=None, tag_keep_missing=False):
    """Inference worker: predict class probabilities for samples drawn from
    `regions` and write them to the `output` datastore.

    :param output: output `.hdf` datastore path (opened in append mode).
    :param bam: `.bam` with alignments from which to generate features.
    :param regions: iterable of `Region` objects to process.
    :param model: model object exposing `predict_on_batch`.
    :param model_file: model file from which `SampleGenerator` reads feature
        encoder options.
    :param rle_ref: run-length-encoded reference, `None` to disable.
    :param read_fraction: fraction of reads to use, `None` for all.
    :param chunk_len: int, sample chunk length.
    :param chunk_ovlp: int, overlap between consecutive chunks.
    :param batch_size: int, number of samples per inference batch.
    :param save_features: bool, also store input features with predictions.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param tag_keep_missing: whether to keep reads when tag is missing.
    """

    logger = get_named_logger('PWorker')

    def sample_gen():
        # chain all samples whilst dispensing with generators when done
        #   (they hold the feature vector in memory until they die)
        for region in regions:
            data_gen = SampleGenerator(
                bam, region, model_file, rle_ref, read_fraction,
                chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
                tag_name=tag_name, tag_value=tag_value,
                tag_keep_missing=tag_keep_missing)
            yield from data_gen.samples
    # prefetch up to 10 batches in a background thread so feature generation
    # overlaps with inference
    batches = background_generator(
        grouper(sample_gen(), batch_size), 10
    )

    total_region_mbases = sum(r.size for r in regions) / 1e6
    logger.info("Running inference for {:.1f}M draft bases.".format(total_region_mbases))

    with DataStore(output, 'a') as ds:
        mbases_done = 0

        t0 = now()
        tlast = t0
        for data in batches:
            x_data = np.stack([x.features for x in data])
            class_probs = model.predict_on_batch(x_data)
            mbases_done += sum(x.span for x in data) / 1e6
            mbases_done = min(mbases_done, total_region_mbases)  # just to avoid funny log msg
            t1 = now()
            # log progress at most every 10 seconds
            if t1 - tlast > 10:
                tlast = t1
                msg = '{:.1%} Done ({:.1f}/{:.1f} Mbases) in {:.1f}s'
                logger.info(msg.format(mbases_done / total_region_mbases, mbases_done, total_region_mbases, t1 - t0))

            best = np.argmax(class_probs, -1)
            for sample, prob, pred, feat in zip(data, class_probs, best, x_data):
                # write out positions and predictions for later analysis
                sample_d = sample._asdict()
                sample_d['label_probs'] = prob
                sample_d['features'] = feat if save_features else None
                ds.write_sample(Sample(**sample_d))

    logger.info('All done')
    return None
Пример #8
0
def run_training(train_name,
                 batcher,
                 model_fp=None,
                 epochs=5000,
                 class_weight=None,
                 n_mini_epochs=1,
                 threads_io=1):
    """Run training.

    :param train_name: str, directory into which model checkpoints and logs
        are written.
    :param batcher: batcher object providing samples, label counts, feature
        shape, and batch preparation (`sample_to_x_y_bq_worker`).
    :param model_fp: model file from which to read model name/options and
        (best effort) initial weights; `None` to build the default model.
    :param epochs: int, maximum number of training epochs.
    :param class_weight: per-class weights for a weighted loss, `None` to use
        sparse categorical crossentropy.
    :param n_mini_epochs: int, number of mini-epochs per full traversal of
        the training data.
    :param threads_io: int, IO threads for the batch-feeding queues.
    """
    from keras.callbacks import CSVLogger, TensorBoard, EarlyStopping, ReduceLROnPlateau
    from medaka.keras_ext import ModelMetaCheckpoint, SequenceBatcher, BatchQueue

    logger = get_named_logger('RunTraining')

    if model_fp is None:
        # build the default model using its builder's default kwargs
        model_name = medaka.models.default_model
        model_kwargs = {
            k: v.default
            for (k, v) in inspect.signature(
                medaka.models.model_builders[model_name]).parameters.items()
            if v.default is not inspect.Parameter.empty
        }
    else:
        with DataStore(model_fp) as ds:
            model_name = ds.meta['medaka_model_name']
            model_kwargs = ds.meta['medaka_model_kwargs']

    opt_str = '\n'.join(
        ['{}: {}'.format(k, v) for k, v in model_kwargs.items()])
    logger.info('Building {} model with: \n{}'.format(model_name, opt_str))
    num_classes = len(batcher.label_counts)
    timesteps, feat_dim = batcher.feature_shape
    model = medaka.models.model_builders[model_name](timesteps, feat_dim,
                                                     num_classes,
                                                     **model_kwargs)

    if model_fp is not None:
        # weight loading is best-effort (e.g. the architecture may differ);
        # was previously a bare `except:` which also swallowed
        # KeyboardInterrupt and SystemExit.
        try:
            model.load_weights(model_fp)
            logger.info("Loading weights from {}".format(model_fp))
        except Exception:
            logger.info("Could not load weights from {}".format(model_fp))

    msg = "feat_dim: {}, timesteps: {}, num_classes: {}"
    logger.info(msg.format(feat_dim, timesteps, num_classes))
    model.summary()

    model_details = batcher.meta.copy()

    model_details['medaka_model_name'] = model_name
    model_details['medaka_model_kwargs'] = model_kwargs
    model_details['medaka_label_decoding'] = batcher.label_decoding

    opts = dict(verbose=1, save_best_only=True, mode='max')

    callbacks = [
        # Best model according to training set accuracy
        ModelMetaCheckpoint(model_details,
                            os.path.join(train_name, 'model.best.hdf5'),
                            monitor='cat_acc',
                            **opts),
        # Best model according to validation set accuracy
        ModelMetaCheckpoint(model_details,
                            os.path.join(train_name, 'model.best.val.hdf5'),
                            monitor='val_cat_acc',
                            **opts),
        # Best model according to validation set qscore
        ModelMetaCheckpoint(model_details,
                            os.path.join(train_name,
                                         'model.best.val.qscore.hdf5'),
                            monitor='val_qscore',
                            **opts),
        # Checkpoints when training set accuracy improves
        ModelMetaCheckpoint(
            model_details,
            os.path.join(
                train_name,
                'model-acc-improvement-{epoch:02d}-{cat_acc:.2f}.hdf5'),
            monitor='cat_acc',
            **opts),
        ModelMetaCheckpoint(
            model_details,
            os.path.join(
                train_name,
                'model-val_acc-improvement-{epoch:02d}-{val_cat_acc:.2f}.hdf5'
            ),
            monitor='val_cat_acc',
            **opts),
        ## Reduce learning rate when no improvement
        #ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5,
        #    verbose=1, min_delta=0.0001, cooldown=0, min_lr=0),
        # Stop when no improvement
        EarlyStopping(monitor='val_loss', patience=20),
        # Log of epoch stats
        CSVLogger(os.path.join(train_name, 'training.log')),
        # Allow us to run tensorboard to see how things are going. Some
        #   features require validation data, not clear why.
        #TensorBoard(log_dir=os.path.join(train_name, 'logs'),
        #            histogram_freq=5, batch_size=100, write_graph=True,
        #            write_grads=True, write_images=True)
    ]

    if class_weight is not None:
        loss = weighted_categorical_crossentropy(class_weight)
        logger.info("Using weighted_categorical_crossentropy loss function")
    else:
        loss = 'sparse_categorical_crossentropy'
        logger.info("Using {} loss function".format(loss))

    model.compile(
        loss=loss,
        optimizer='nadam',
        metrics=[cat_acc, qscore],
    )

    if n_mini_epochs == 1:
        logger.info(
            "Not using mini_epochs, an epoch is a full traversal of the training data"
        )
    else:
        logger.info(
            "Using mini_epochs, an epoch is a traversal of 1/{} of the training data"
            .format(n_mini_epochs))

    with ProcessPoolExecutor(threads_io) as executor:
        logger.info("Starting data queues.")
        prep_function = functools.partial(
            batcher.sample_to_x_y_bq_worker,
            max_label_len=batcher.max_label_len,
            label_encoding=batcher.label_encoding,
            sparse_labels=batcher.sparse_labels,
            n_classes=batcher.n_classes)
        # TODO: should take mini_epochs into account here
        train_queue = BatchQueue(batcher.train_samples,
                                 prep_function,
                                 batcher.batch_size,
                                 executor,
                                 seed=batcher.seed,
                                 name='Train',
                                 maxsize=100)
        valid_queue = BatchQueue(batcher.valid_samples,
                                 prep_function,
                                 batcher.batch_size,
                                 executor,
                                 seed=batcher.seed,
                                 name='Valid',
                                 maxsize=100)

        # run training
        logger.info("Starting training.")
        model.fit_generator(
            generator=train_queue.yield_batches(),
            steps_per_epoch=train_queue.n_batches // n_mini_epochs,
            validation_data=valid_queue.yield_batches(),
            validation_steps=valid_queue.n_batches,
            max_queue_size=2 * threads_io,
            workers=1,
            use_multiprocessing=False,
            epochs=epochs,
            callbacks=callbacks,
            class_weight=class_weight,
        )
        logger.info("Training finished.")
        train_queue.stop()
        valid_queue.stop()
Пример #9
0
def predict(args):
    """Inference program: batch long regions, then mop up short remainders.

    Writes model metadata and per-sample predictions to `args.output`.
    """
    # quieten tensorflow C++ logging before keras pulls tensorflow in,
    # unless the user asked for debug output
    logger_level = logging.getLogger(__package__).level
    if logger_level > logging.DEBUG:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    logger.info('Processing region(s): {}'.format(' '.join(
        str(r) for r in args.regions)))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w', verify_on_close=False) as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    K.set_session(
        K.tf.Session(config=K.tf.ConfigProto(
            intra_op_parallelism_threads=args.threads,
            inter_op_parallelism_threads=args.threads)))

    # Split overly long regions to maximum size so as to not create
    #   massive feature matrices
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        regions.extend(regs)

    logger.info("Processing {} long region(s) with batching.".format(
        len(regions)))
    # fixed time_steps allows samples from different regions to be batched
    model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
    # the returned regions are those where the pileup width is smaller than chunk_len
    remainder_regions = run_prediction(args.output,
                                       args.bam,
                                       regions,
                                       model,
                                       args.model,
                                       args.rle_ref,
                                       args.read_fraction,
                                       args.chunk_len,
                                       args.chunk_ovlp,
                                       batch_size=args.batch_size,
                                       save_features=args.save_features,
                                       tag_name=args.tag_name,
                                       tag_value=args.tag_value,
                                       tag_keep_missing=args.tag_keep_missing)

    # short/remainder regions: just do things without chunking. We can do this
    # here because we now have the size of all pileups (and know they are small).
    # TODO: can we avoid calculating pileups twice whilst controlling memory?
    if len(remainder_regions) > 0:
        logger.info("Processing {} short region(s).".format(
            len(remainder_regions)))
        # dynamic time_steps: short regions are processed one at a time
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in remainder_regions:
            new_remainders = run_prediction(
                args.output,
                args.bam,
                [region[0]],
                model,
                args.model,
                args.rle_ref,
                args.read_fraction,
                args.chunk_len,
                args.chunk_ovlp,  # these won't be used
                batch_size=args.batch_size,
                save_features=args.save_features,
                tag_name=args.tag_name,
                tag_value=args.tag_value,
                tag_keep_missing=args.tag_keep_missing,
                enable_chunking=False)
            if len(new_remainders) > 0:
                # shouldn't get here
                ignored = [x[0] for x in new_remainders]
                n_ignored = len(ignored)
                logger.warning("{} regions were not processed: {}.".format(
                    n_ignored, ignored))

    logger.info("Finished processing all regions.")

    if args.check_output:
        logger.info("Validating and finalising output data.")
        # NOTE(review): opening in append mode and closing appears to trigger
        # DataStore's verify-on-close behaviour — confirm DataStore semantics.
        with DataStore(args.output, 'a') as ds:
            pass
Пример #10
0
    def __init__(self,
                 features,
                 max_label_len,
                 validation=0.2,
                 seed=0,
                 sparse_labels=True,
                 batch_size=500,
                 threads=1):
        """
        Class to serve up batches of training / validation data.

        :param features: iterable of str, training feature files.
        :param max_label_len: int, maximum label length, longer labels will be truncated.
        :param validation: float, fraction of batches to use for validation, or
                iterable of str, validation feature files.
        :param seed: int, random seed for separation of batches into training/validation.
        :param sparse_labels: bool, create sparse labels.
        :param batch_size: int, number of samples per batch.
        :param threads: int, threads used to index the feature files.

        """
        self.logger = get_named_logger('TrainBatcher')

        self.features = features
        self.max_label_len = max_label_len
        self.validation = validation
        self.seed = seed
        self.sparse_labels = sparse_labels
        self.batch_size = batch_size

        di = DataIndex(self.features, threads=threads)
        self.samples = di.samples.copy()
        self.meta = di.meta.copy()
        self.label_counts = self.meta['medaka_label_counts']

        # check sample size using first batch
        test_sample, test_fname = self.samples[0]
        with DataStore(test_fname) as ds:
            self.feature_shape = ds.load_sample(test_sample).features.shape
        self.logger.info("Sample features have shape {}".format(
            self.feature_shape))

        if isinstance(self.validation, float):
            # hold out a random fraction of samples for validation
            np.random.seed(self.seed)
            np.random.shuffle(self.samples)
            n_sample_train = int((1 - self.validation) * len(self.samples))
            self.train_samples = self.samples[:n_sample_train]
            self.valid_samples = self.samples[n_sample_train:]
            msg = 'Randomly selected {} ({:3.2%}) of features for validation (seed {})'
            self.logger.info(
                msg.format(len(self.valid_samples), self.validation,
                           self.seed))
        else:
            # validation samples come from separate feature files
            self.train_samples = self.samples
            self.valid_samples = DataIndex(self.validation).samples.copy()
            msg = 'Found {} validation samples equivalent to {:3.2%} of all the data'
            # parenthesise the denominator: operator precedence previously
            # computed len(valid)/len(valid) + len(train), not a fraction
            fraction = len(self.valid_samples) / (
                len(self.valid_samples) + len(self.train_samples))
            self.logger.info(msg.format(len(self.valid_samples), fraction))

        msg = 'Got {} samples in {} batches ({} labels) for {}'
        self.logger.info(
            msg.format(len(self.train_samples),
                       len(self.train_samples) // batch_size,
                       len(self.train_samples) * self.feature_shape[0],
                       'training'))
        self.logger.info(
            msg.format(len(self.valid_samples),
                       len(self.valid_samples) // batch_size,
                       len(self.valid_samples) * self.feature_shape[0],
                       'validation'))

        self.n_classes = len(self.label_counts)

        # get label encoding, given max_label_len
        self.logger.info("Max label length: {}".format(
            self.max_label_len if self.max_label_len is not None else 'inf'))
        self.label_encoding, self.label_decoding, self.label_counts = process_labels(
            self.label_counts, max_label_len=self.max_label_len)
Пример #11
0
 def on_epoch_end(self, epoch, logs=None):
     """Add medaka metadata to the checkpoint written by the parent class."""
     super(ModelMetaCheckpoint, self).on_epoch_end(epoch, logs)
     # the checkpoint path may be templated on epoch number and metrics
     current_file = self.filepath.format(epoch=epoch + 1, **logs)
     with DataStore(current_file, 'a') as store:
         store.meta.update(self.medaka_meta)
Пример #12
0
def predict(args):
    """Inference program: split regions into long (batched) and short
    (individually processed) sets and run prediction over both.
    """
    # quieten tensorflow C++ logging before keras pulls tensorflow in
    os.environ["TF_CPP_MIN_LOG_LEVEL"]="2"
    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    logger.info('Processing region(s): {}'.format(' '.join(str(r) for r in args.regions)))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w') as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    K.set_session(K.tf.Session(
        config=K.tf.ConfigProto(
            intra_op_parallelism_threads=args.threads,
            inter_op_parallelism_threads=args.threads)
    ))

    # Split overly long regions to maximum size so as to not create
    #   massive feature matrices, then segregate those which cannot be
    #   batched.
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    long_regions = []
    short_regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        for r in regs:
            if r.size > args.chunk_len:
                long_regions.append(r)
            else:
                short_regions.append(r)
    logger.info("Found {} long and {} short regions.".format(
        len(long_regions), len(short_regions)))

    if len(long_regions) > 0:
        logger.info("Processing long regions.")
        # fixed time_steps allows samples from different regions to be batched
        model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
        run_prediction(
            args.output, args.bam, long_regions, model, args.model, args.rle_ref, args.read_fraction,
            args.chunk_len, args.chunk_ovlp,
            batch_size=args.batch_size, save_features=args.save_features,
            tag_name=args.tag_name, tag_value=args.tag_value, tag_keep_missing=args.tag_keep_missing
        )

    # short regions must be done individually since they have differing lengths
    #   TODO: consider masking (it appears slow to apply wholesale), maybe
    #         step down args.chunk_len by a constant factor until nothing remains.
    if len(short_regions) > 0:
        logger.info("Processing short regions")
        # dynamic time_steps: each short region forms its own batch
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in short_regions:
            chunk_len = region.size // 2
            chunk_ovlp = chunk_len // 10 # still need overlap as features will be longer
            run_prediction(
                args.output, args.bam, [region], model, args.model, args.rle_ref, args.read_fraction,
                chunk_len, chunk_ovlp,
                batch_size=args.batch_size, save_features=args.save_features,
                tag_name=args.tag_name, tag_value=args.tag_value, tag_keep_missing=args.tag_keep_missing
            )
    logger.info("Finished processing all regions.")
Пример #13
0
def yaml2hdf(args):
    """Update metadata of an `.hdf` datastore from a yaml file."""
    # NOTE(review): yaml.unsafe_load can construct arbitrary python objects;
    # only use with trusted input files.
    with DataStore(args.output, 'a') as ds, open(args.input) as fh:
        ds.update_meta(yaml.unsafe_load(fh))
Пример #14
0
def hdf2yaml(args):
    """Dump the metadata of an `.hdf` datastore to a yaml file."""
    with DataStore(args.input) as ds, open(args.output, 'w') as fh:
        yaml.dump(ds.meta, fh)