Example #1
 def samples_to_batch(samples, prep_func, name, batch, epoch):
     t0 = now()
     items = [prep_func(s) for s in samples]
     xs, ys = zip(*items)
     x, y = np.stack(xs), np.stack(ys)
     get_named_logger(name).debug(
         "Took {:5.3}s to load batch {} (epoch {})".format(
             now() - t0, batch, epoch))
     return x, y
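A minimal usage sketch, assuming stand-ins for the module helpers `now` and `get_named_logger` that the snippet expects to find in scope:

import logging
import time

import numpy as np

now = time.time                       # assumed alias for the module's timer
get_named_logger = logging.getLogger  # assumed thin wrapper

def prep_func(sample):
    # hypothetical transform: here each sample is already an (x, y) pair
    return sample

samples = [(np.zeros(4), 0), (np.ones(4), 1)]
x, y = samples_to_batch(samples, prep_func, 'Train', batch=0, epoch=1)
print(x.shape, y.shape)  # (2, 4) (2,)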
Example #2
    def bam_to_alignments(truth_bam, ref_name, start=None, end=None):
        """Create list of TruthAlignment objects from a bam of Truth aligned to ref.

        :param truth_bam: (sorted, indexed) bam with true sequence aligned to reference.
        :param ref_name: name of reference to process.
        :param start: starting position within reference.
        :param end: ending position within reference
            (all alignments with any overlap with the interval start:end will be retrieved).

        :returns: list of `TruthAlignment` objects, sorted by start position.
        """
        with pysam.AlignmentFile(truth_bam, 'rb') as bamfile:
            aln_reads = bamfile.fetch(reference=ref_name, start=start, end=end)
            alignments = [
                TruthAlignment(r) for r in aln_reads
                if not (r.is_unmapped or r.is_secondary)
            ]
            alignments.sort(key=attrgetter('start'))
        logger = get_named_logger("TruthAlign")
        logger.info("Retrieved {} alignments.".format(len(alignments)))
        return alignments
Example #3
def process_labels(label_counts, max_label_len=10):
    """Create map from full labels to (encoded) truncated labels.

    :param label_counts: `Counter` of label counts.
    :param max_label_len: int, maximum label length; longer labels will be truncated.

    :returns: ({label: int encoding}, [label decodings], `Counter` of truncated counts).
    """
    logger = get_named_logger('Labelling')

    old_labels = list(label_counts.keys())
    if isinstance(old_labels[0], tuple):
        # (encoded base, run length) tuples: expand to homopolymer strings;
        # use a list (not a generator) as new_labels is consumed twice below
        new_labels = [l[1] * decoding[l[0]].upper() for l in old_labels]
    else:
        new_labels = list(old_labels)

    if max_label_len < np.inf:
        new_labels = [l[:max_label_len] for l in new_labels]

    old_to_new = dict(zip(old_labels, new_labels))
    label_decoding = list(sorted(set(new_labels)))
    label_encoding = {l: label_decoding.index(old_to_new[l]) for l in old_labels}
    logger.info("Label encoding dict is:\n{}".format('\n'.join(
        '{}: {}'.format(k, v) for k, v in label_encoding.items()
    )))

    new_counts = Counter()
    for l in old_labels:
        new_counts[label_encoding[l]] += label_counts[l]
    logger.info("New label counts {}".format(new_counts))

    return label_encoding, label_decoding, new_counts
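A self-contained check of the truncation and encoding behaviour, using plain string labels (so the tuple branch and its module-level `decoding` table are not exercised) and stubbing `get_named_logger`:

import logging
from collections import Counter

import numpy as np

get_named_logger = logging.getLogger  # stand-in for the module helper

counts = Counter({'A': 10, 'T': 7, 'AAAA': 2})
enc, dec, truncated_counts = process_labels(counts, max_label_len=2)
print(enc)               # {'A': 0, 'T': 2, 'AAAA': 1} -- 'AAAA' truncated to 'AA'
print(dec)               # ['A', 'AA', 'T']
print(truncated_counts)  # Counter({0: 10, 2: 7, 1: 2})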
Example #4
    def __init__(self, filename, cache=True):
        """Basic VCF parser.

        :param filename: .vcf file.
        :param cache: if True, all parsed variants are stored in memory for
            faster subsequent access.

        """

        self.filename = filename
        self.cache = cache
        self._indexed = False
        self._tree = None
        self._parse_lock = Lock()
        self.logger = get_named_logger('VCFReader')

        # Read both metadata and header
        self.meta = []
        self.header = None
        with open(filename, encoding='utf-8') as handle:
            for line in handle:
                line = line.rstrip('\n')
                if line.startswith('##'):
                    self.meta.append(line[2:])
                elif line.startswith('#'):
                    line = line[1:]
                    self.header = line.split('\t')
                    break
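A hedged sketch of the header parsing, assuming the `__init__` above belongs to a class named `VCFReader` (as its logger name suggests) and that the module-level imports (`threading.Lock`, `get_named_logger`) are in place:

with open('tiny.vcf', 'w', encoding='utf-8') as fh:
    fh.write('##fileformat=VCFv4.3\n')
    fh.write('##source=demo\n')
    fh.write('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\n')

reader = VCFReader('tiny.vcf')
print(reader.meta)    # ['fileformat=VCFv4.3', 'source=demo']
print(reader.header)  # ['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO']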
Example #5
    def __init__(self, filenames, threads=4):

        self.logger = get_named_logger('DataIndex')

        self.filenames = filenames

        with DataStore(filenames[0]) as ds:
            self.logger.debug('Loading meta from {}'.format(filenames[0]))
            self.meta = ds.meta

        c_grp = 'medaka_label_counts'
        if c_grp in self.meta:
            # reset counts; they are re-accumulated from every file below
            self.meta[c_grp] = Counter()

        # the sample index is rebuilt below, so drop the copy from the first file
        self.meta.pop('medaka_samples', None)

        self.samples = []

        with ProcessPoolExecutor(threads) as executor:
            future_to_f = {executor.submit(DataIndex._load_meta, f): f for f in filenames}
            for i, future in enumerate(as_completed(future_to_f), 1):
                f = future_to_f[future]
                try:
                    meta = future.result()
                    self.samples.extend([(s, f) for s in meta['medaka_samples']])
                    self.meta[c_grp].update(meta[c_grp])
                except Exception as exc:
                    self.logger.info(
                        'Could not load meta from {}: {}'.format(f, exc))
                else:
                    self.logger.info(
                        'Loaded sample-index from {}/{} ({:.2%}) of '
                        'feature files.'.format(
                            i, len(filenames), i / len(filenames)))

        # make order of samples independent of order in which tasks complete
        self.samples.sort()

        self._index = None
Example #6
    def __init__(self, fname, chrom, start, end, reference_fasta, label_decoding):
        vcf_region_str = '{}:{}-{}'.format(chrom, start, end)  # is this correct?
        self.label_decoding = label_decoding
        self.logger = get_named_logger('VCFWriter')
        self.logger.info("Writing variants for {}".format(vcf_region_str))

        vcf_meta = ['region={}'.format(vcf_region_str)]
        self.writer = vcf.VCFWriter(fname, meta_info=vcf_meta)
        self.ref_fasta = pysam.FastaFile(reference_fasta)
Example #7
    def __init__(self, alignment):
        """Create a TruthAlignment oblist from a `pysam.libcalignedsegment.AlignedSegment` object.

        :param alignment: `pysam.libcalignedsegment.AlignedSegment` object.
        """
        self.aln = alignment  # so we can get positions and labels later
        # initialise start and end (which might be moved)
        self.start = self.aln.reference_start  # zero-based
        self.end = self.aln.reference_end
        self.is_kept = True
        self.logger = get_named_logger('TruthAlign')
Example #8
    def __init__(self, filename, mode='r'):

        self.filename = filename
        self.mode = mode

        self._sample_keys = set()
        self.fh = None

        self.logger = get_named_logger('DataStore')

        self._meta = None
Example #9
    def __init__(self, filename, mode='r', verify_on_close=True):

        self.filename = filename
        self.mode = mode
        self.verify_on_close = verify_on_close

        self._sample_keys = set()
        self.fh = None

        self.logger = get_named_logger('DataStore')

        self._meta = None
Example #10
    def __init__(self,
                 samples,
                 prep_func,
                 batch_size,
                 executor,
                 seed=None,
                 name='Train',
                 maxsize=100):
        """Load and queue training samples into batches from `.hdf` files.

        :param samples: tuples of (filename, hdf sample key).
        :param prep_func: function to transform a sample to x,y data.
        :param batch_size: group samples by this number.
        :param executor: `ThreadPoolExecutor` instance.
        :param seed: seed for shuffling.
        :param name: str, name for logger.
        :param maxsize: int, maximum queue size.

        Once initialized, batches can be retrieved using `batch_q._queue.get()`.

        """
        self.samples = samples
        self.prep_func = prep_func
        self.batch_size = batch_size

        if seed is not None:
            np.random.seed(seed)

        self.name = name
        self.logger = get_named_logger('{}Batcher'.format(name.capitalize()))
        self.maxsize = maxsize
        self._queue = queue.Queue(maxsize=self.maxsize)
        self.executor = executor
        self.stopped = threading.Event()
        self.qthread = threading.Thread(target=self._fill_queue_batch)
        self.qthread.daemon = True

        original_size = len(self.samples)
        self.n_batches = len(self.samples) // self.batch_size
        self.samples = self.samples[:self.n_batches * self.batch_size]
        self.logger.info(
            '{} batches of {} samples ({}), from {} original.'.format(
                self.n_batches, self.batch_size, len(self.samples),
                original_size))
        if self.n_batches == 0:
            raise ValueError("Number of batches is zero.")

        self.qthread.start()
        time.sleep(2)  # give the queue a moment to fill before first use
        self.logger.info(
            "Started reading samples from files with queue size {}".format(
                maxsize))
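Example #24 below imports this class as `BatchQueue` from `medaka.keras_ext`; a hedged wiring sketch (the surrounding module's imports of `queue`, `threading`, `time` and `numpy` are assumed to be in place):

from concurrent.futures import ProcessPoolExecutor

import numpy as np

def prep_func(sample):
    return sample  # hypothetical: samples are already (x, y) pairs

samples = [(np.zeros(4), i % 2) for i in range(20)]

with ProcessPoolExecutor(2) as executor:
    bq = BatchQueue(samples, prep_func, batch_size=10, executor=executor,
                    seed=42, maxsize=100)
    x, y = bq._queue.get()  # one prepared batch, per the docstring
    bq.stop()               # Example #24 stops its queues the same way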
Example #11
def _labelled_samples_worker(args, region):
    logger = get_named_logger('PrepWork')
    logger.info("Processing region {}.".format(region))
    data_gen = SampleGenerator(args.bam,
                               region,
                               args.model,
                               args.rle_ref,
                               truth_bam=args.truth,
                               read_fraction=args.read_fraction,
                               chunk_len=args.chunk_len,
                               chunk_overlap=args.chunk_ovlp)
    return list(data_gen.samples), region, deepcopy(
        data_gen.fencoder_args), deepcopy(data_gen.fencoder.decoding)
Example #12
def train(args):
    """Training program."""
    train_name = args.train_name
    mkdir_p(train_name, info='Results will be overwritten.')

    logger = get_named_logger('Training')
    logger.debug("Loading datasets:\n{}".format('\n'.join(args.features)))

    sparse_labels = not args.balanced_weights

    args.validation = (args.validation_features
                       if args.validation_features is not None
                       else args.validation_split)

    batcher = TrainBatcher(args.features,
                           args.max_label_len,
                           args.validation,
                           args.seed,
                           sparse_labels,
                           args.batch_size,
                           threads=args.threads_io)

    if args.balanced_weights:
        n_labels = sum(batcher.label_counts.values())
        n_classes = len(batcher.label_counts)
        class_weight = {
            k: float(n_labels) / (n_classes * count)
            for (k, count) in batcher.label_counts.items()
        }
        class_weight = np.array(
            [class_weight[c] for c in sorted(class_weight.keys())])
    else:
        class_weight = None

    h = lambda d, i: d[i] if d is not None else 1
    logger.info("Label statistics are:\n{}".format('\n'.join(
        '{} ({}) {} (w. {:9.6f})'.format(i, l, batcher.label_counts[i],
                                         h(class_weight, i))
        for i, l in enumerate(batcher.label_decoding))))

    import tensorflow as tf
    with tf.device('/gpu:{}'.format(args.device)):
        run_training(train_name,
                     batcher,
                     model_fp=args.model,
                     epochs=args.epochs,
                     class_weight=class_weight,
                     n_mini_epochs=args.mini_epochs,
                     threads_io=args.threads_io)

    # stop batching threads
    logger.info("Training finished.")
Example #13
    def __init__(self,
                 bam,
                 region,
                 model,
                 rle_ref=None,
                 truth_bam=None,
                 read_fraction=None,
                 chunk_len=1000,
                 chunk_overlap=200,
                 tag_name=None,
                 tag_value=None,
                 tag_keep_missing=False,
                 enable_chunking=True):
        """Generate chunked inference (or training) samples.

        :param bam: `.bam` containing alignments from which to generate samples.
        :param region: a `Region` for which to generate samples.
        :param model: a medaka model.
        :param truth_bam: a `.bam` containing alignment of truth sequence to
            `reference` sequence. Required only for creating training chunks.
        :param rle_ref: reference `.fasta`, should correspond to `bam`.
        :param tag_name: two letter tag name by which to filter reads.
        :param tag_value: integer value of tag for reads to keep.
        :param tag_keep_missing: whether to keep reads when tag is missing.
        :param enable_chunking: when yielding samples, do so in chunks.

        """
        self.logger = get_named_logger("Sampler")
        self.sample_type = "training" if truth_bam is not None else "consensus"
        self.logger.info("Initializing sampler for {} of region {}.".format(
            self.sample_type, region))
        with DataStore(model) as ds:
            self.fencoder_args = ds.meta['medaka_features_kwargs']
        self.fencoder = FeatureEncoder(tag_name=tag_name,
                                       tag_value=tag_value,
                                       tag_keep_missing=tag_keep_missing,
                                       **self.fencoder_args)

        self.bam = bam
        self.region = region
        self.model = model
        self.rle_ref = rle_ref
        self.truth_bam = truth_bam
        self.read_fraction = read_fraction
        self.chunk_len = chunk_len
        self.chunk_overlap = chunk_overlap
        self.enable_chunking = enable_chunking
        self._source = None  # the base data to be chunked
        self._quarantined = list()  # samples which are shorter than chunk size
Example #14
def run_prediction(output, bam, regions, model, model_file, rle_ref,
                   read_fraction, chunk_len, chunk_ovlp, batch_size=200,
                   save_features=False, tag_name=None, tag_value=None,
                   tag_keep_missing=False):
    """Inference worker."""

    logger = get_named_logger('PWorker')

    def sample_gen():
        # chain all samples whilst dispensing with generators when done
        #   (they hold the feature vector in memory until they die)
        for region in regions:
            data_gen = SampleGenerator(
                bam, region, model_file, rle_ref,
                # pass read_fraction by keyword: the fifth positional
                # argument of SampleGenerator is truth_bam (see Example #13)
                read_fraction=read_fraction,
                chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
                tag_name=tag_name, tag_value=tag_value,
                tag_keep_missing=tag_keep_missing)
            yield from data_gen.samples
    batches = background_generator(
        grouper(sample_gen(), batch_size), 10
    )

    total_region_mbases = sum(r.size for r in regions) / 1e6
    logger.info("Running inference for {:.1f}M draft bases.".format(total_region_mbases))

    with DataStore(output, 'a') as ds:
        mbases_done = 0

        t0 = now()
        tlast = t0
        for data in batches:
            x_data = np.stack([x.features for x in data])
            class_probs = model.predict_on_batch(x_data)
            mbases_done += sum(x.span for x in data) / 1e6
            mbases_done = min(mbases_done, total_region_mbases)  # just to avoid funny log msg
            t1 = now()
            if t1 - tlast > 10:
                tlast = t1
                msg = '{:.1%} Done ({:.1f}/{:.1f} Mbases) in {:.1f}s'
                logger.info(msg.format(
                    mbases_done / total_region_mbases, mbases_done,
                    total_region_mbases, t1 - t0))

            best = np.argmax(class_probs, -1)
            for sample, prob, pred, feat in zip(data, class_probs, best, x_data):
                # write out positions and predictions for later analysis
                sample_d = sample._asdict()
                sample_d['label_probs'] = prob
                sample_d['features'] = feat if save_features else None
                ds.write_sample(Sample(**sample_d))

    logger.info('All done')
    return None
Example #15
def create_labelled_samples(args):
    logger = get_named_logger('Prepare')
    regions = get_regions(args.bam, args.regions)
    reg_str = '\n'.join(['\t\t\t{}'.format(r) for r in regions])
    logger.info('Got regions:\n{}'.format(reg_str))

    labels_counter = Counter()

    no_data = False
    with DataStore(args.output, 'w') as ds:
        # write feature options to file
        logger.info("Writing meta data to file.")
        with DataStore(args.model) as model:
            meta = {
                k: model.meta[k]
                for k in ('medaka_features_kwargs', 'medaka_feature_decoding')
            }
        ds.update_meta(meta)
        # TODO: this parallelism would be better in `SampleGenerator.bams_to_training_samples`
        #       since training alignments are usually chunked.
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.threads) as executor:
            # break up overly long chunks
            MAX_SIZE = int(1e6)
            regions = itertools.chain(*(r.split(MAX_SIZE) for r in regions))
            futures = [
                executor.submit(_labelled_samples_worker, args, reg)
                for reg in regions
            ]
            for fut in concurrent.futures.as_completed(futures):
                if fut.exception() is None:
                    samples, region, fencoder_args, fencoder_decoder = \
                        fut.result()
                    logger.info("Writing {} samples for region {}".format(
                        len(samples), region))
                    for sample in samples:
                        ds.write_sample(sample)
                else:
                    logger.info(fut.exception())
                fut._result = None  # python issue 27144
        no_data = ds.n_samples == 0

    if no_data:
        logger.critical(
            "No training data was written to file, deleting output.")
        os.remove(args.output)
Example #16
    def __init__(self,
                 filename,
                 mode='w',
                 header=('CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER',
                         'INFO', 'FORMAT', 'SAMPLE'),
                 meta_info=None,
                 version='4.3'):

        self.filename = filename
        self.mode = mode
        self.header = header
        if version not in self.version_options:
            raise ValueError('version must be one of {}'.format(
                self.version_options))
        self.version = version
        if meta_info is None:  # avoid a mutable default argument
            meta_info = []
        self.meta = ['fileformat=VCFv{}'.format(self.version)] + meta_info
        self.logger = get_named_logger('VCFWriter')
Example #17
def alphabet_filter(sample_gen,
                    alphabet=None,
                    filter_labels=True,
                    filter_ref_seq=True):
    """Skip chunks in which labels and/or ref_seq contain bases not in `alphabet`.

    :param sample_gen: generator of `Sample` named tuples.
    :param alphabet: set of str of allowed bases. If None, automatically generated from decoding.
    :param filter_labels: bool, whether to filter on labels.
    :param filter_ref_seq: bool, whether to filter on ref_seq.

    :yields: `Sample` named tuples.
    """
    if alphabet is None:
        alphabet = set(_alphabet_ + _gap_)
    logger = get_named_logger('AlphaFilter')
    logger.debug("alphabet: {}".format(alphabet))

    alphabet = set(encoding[c] for c in alphabet)

    def _find_bad_bases(s, field, alphabet):
        seq_rle = getattr(s, field)
        bases = set(np.unique(seq_rle['base']))
        if not bases.issubset(alphabet):
            diff = [decoding[i] for i in bases - alphabet]
            msg = "Skipping {}:{}-{} ({} bases) due to {} {}"
            pos = s.positions
            logger.info(
                msg.format(s.ref_name, pos['major'][0], pos['major'][-1],
                           len(pos), field, diff))
            return True

    for s in sample_gen:
        if filter_labels and s.labels is not None and _find_bad_bases(
                s, 'labels', alphabet):
            continue
        if filter_ref_seq and s.ref_seq is not None and _find_bad_bases(
                s, 'ref_seq', alphabet):
            continue
        yield s
Example #18
    def __init__(self, batcher, dataset='train', mini_epochs=1, seed=None):
        """Interface for keras to a `TrainBatcher` for training and validation
        batches.

        :param batcher: a `medaka.inference.TrainBatcher` instance.
        :param dataset: one of 'train' or 'validation'.
        :param mini_epochs: factor by which to rescale the number of batches
            in an epoch (useful to output checkpoints more frequently).
        :param seed: random seed for shuffling data.

        """
        self.batcher = batcher
        self.dataset = dataset
        self.mini_epochs = mini_epochs
        self.batch_size = self.batcher.batch_size
        if seed is not None:
            np.random.seed(seed)
        self.epoch = 1

        if dataset == 'train':
            self.data = batcher.train_samples
        elif dataset == 'validation':
            self.data = batcher.valid_samples
            if mini_epochs != 1:
                raise ValueError(
                    "'mini_epochs' must be equal to 1 for validation data.")
        else:
            raise ValueError("'dataset' should be 'train' or 'validation'.")

        original_size = len(self.data)
        self.n_batches = len(self.data) // self.batch_size
        self.data = self.data[:self.n_batches * self.batch_size]
        np.random.shuffle(self.data)
        self.logger = get_named_logger('{}Batcher'.format(
            dataset.capitalize()))
        self.logger.info(
            '{} batches of {} samples ({}), from {} original.'.format(
                self.n_batches, self.batch_size, len(self.data),
                original_size))
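The `mini_epochs` rescaling is plain integer division of the batch count, mirroring `steps_per_epoch=train_queue.n_batches // n_mini_epochs` in Example #24:

n_batches, mini_epochs = 1000, 4
steps_per_epoch = n_batches // mini_epochs
print(steps_per_epoch)  # 250 -- checkpoints and logs fire four times per traversal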
Example #19
def predict(args):
    """Inference program."""
    logger_level = logging.getLogger(__package__).level
    if logger_level > logging.DEBUG:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    logger.info('Processing region(s): {}'.format(' '.join(
        str(r) for r in args.regions)))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w', verify_on_close=False) as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    K.set_session(
        K.tf.Session(config=K.tf.ConfigProto(
            intra_op_parallelism_threads=args.threads,
            inter_op_parallelism_threads=args.threads)))

    # Split overly long regions to maximum size so as to not create
    #   massive feature matrices
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        regions.extend(regs)

    logger.info("Processing {} long region(s) with batching.".format(
        len(regions)))
    model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
    # the returned regions are those where the pileup width is smaller than chunk_len
    remainder_regions = run_prediction(args.output,
                                       args.bam,
                                       regions,
                                       model,
                                       args.model,
                                       args.rle_ref,
                                       args.read_fraction,
                                       args.chunk_len,
                                       args.chunk_ovlp,
                                       batch_size=args.batch_size,
                                       save_features=args.save_features,
                                       tag_name=args.tag_name,
                                       tag_value=args.tag_value,
                                       tag_keep_missing=args.tag_keep_missing)

    # short/remainder regions: just do things without chunking. We can do this
    # here because we now have the size of all pileups (and know they are small).
    # TODO: can we avoid calculating pileups twice whilst controlling memory?
    if len(remainder_regions) > 0:
        logger.info("Processing {} short region(s).".format(
            len(remainder_regions)))
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in remainder_regions:
            new_remainders = run_prediction(
                args.output,
                args.bam,
                [region[0]],
                model,
                args.model,
                args.rle_ref,
                args.read_fraction,
                args.chunk_len,
                args.chunk_ovlp,  # these won't be used
                batch_size=args.batch_size,
                save_features=args.save_features,
                tag_name=args.tag_name,
                tag_value=args.tag_value,
                tag_keep_missing=args.tag_keep_missing,
                enable_chunking=False)
            if len(new_remainders) > 0:
                # shouldn't get here
                ignored = [x[0] for x in new_remainders]
                n_ignored = len(ignored)
                logger.warning("{} regions were not processed: {}.".format(
                    n_ignored, ignored))

    logger.info("Finished processing all regions.")

    if args.check_output:
        logger.info("Validating and finalising output data.")
        with DataStore(args.output, 'a') as ds:
            # no-op write: verification runs when the DataStore closes
            pass
Example #20
    def __init__(self,
                 features,
                 max_label_len,
                 validation=0.2,
                 seed=0,
                 sparse_labels=True,
                 batch_size=500,
                 threads=1):
        """
        Class to server up batches of training / validation data.

        :param features: iterable of str, training feature files.
        :param max_label_len: int, maximum label length, longer labels will be truncated.
        :param validation: float, fraction of batches to use for validation, or
                iterable of str, validation feature files.
        :param seed: int, random seed for separation of batches into training/validation.
        :param sparse_labels: bool, create sparse labels.

        """
        self.logger = get_named_logger('TrainBatcher')

        self.features = features
        self.max_label_len = max_label_len
        self.validation = validation
        self.seed = seed
        self.sparse_labels = sparse_labels
        self.batch_size = batch_size

        di = DataIndex(self.features, threads=threads)
        self.samples = di.samples.copy()
        self.meta = di.meta.copy()
        self.label_counts = self.meta['medaka_label_counts']

        # check sample size using first batch
        test_sample, test_fname = self.samples[0]
        with DataStore(test_fname) as ds:
            self.feature_shape = ds.load_sample(test_sample).features.shape
        self.logger.info("Sample features have shape {}".format(
            self.feature_shape))

        if isinstance(self.validation, float):
            np.random.seed(self.seed)
            np.random.shuffle(self.samples)
            n_sample_train = int((1 - self.validation) * len(self.samples))
            self.train_samples = self.samples[:n_sample_train]
            self.valid_samples = self.samples[n_sample_train:]
            msg = 'Randomly selected {} ({:3.2%}) of features for validation (seed {})'
            self.logger.info(
                msg.format(len(self.valid_samples), self.validation,
                           self.seed))
        else:
            self.train_samples = self.samples
            self.valid_samples = DataIndex(self.validation).samples.copy()
            msg = 'Found {} validation samples equivalent to {:3.2%} of all the data'
            fraction = len(self.valid_samples) / (
                len(self.valid_samples) + len(self.train_samples))
            self.logger.info(msg.format(len(self.valid_samples), fraction))

        msg = 'Got {} samples in {} batches ({} labels) for {}'
        self.logger.info(
            msg.format(len(self.train_samples),
                       len(self.train_samples) // batch_size,
                       len(self.train_samples) * self.feature_shape[0],
                       'training'))
        self.logger.info(
            msg.format(len(self.valid_samples),
                       len(self.valid_samples) // batch_size,
                       len(self.valid_samples) * self.feature_shape[0],
                       'validation'))

        self.n_classes = len(self.label_counts)

        # get label encoding, given max_label_len
        self.logger.info("Max label length: {}".format(
            self.max_label_len if self.max_label_len is not None else 'inf'))
        self.label_encoding, self.label_decoding, self.label_counts = process_labels(
            self.label_counts, max_label_len=self.max_label_len)
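A quick check of the validation-fraction arithmetic above; the parenthesised denominator is what makes the percentage come out right:

n_valid, n_train = 25, 75
fraction = n_valid / (n_valid + n_train)
print('{:3.2%}'.format(fraction))  # 25.00%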
Example #21
def pileup_counts(region,
                  bam,
                  dtype_prefixes=None,
                  region_split=100000,
                  workers=4,
                  tag_name=None,
                  tag_value=None,
                  keep_missing=False):
    """Create pileup counts feature array for region.

    :param region: `Region` object.
    :param bam: .bam file with alignments.
    :param dtype_prefixes: prefixes of query names by which to separate counts.
        If `None` (or of length 1), counts are not split.
    :param region_split: largest sub-region to process in a single worker.
    :param workers: number of worker threads.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param keep_missing: whether to keep reads when tag is missing.

    :returns: list of (pileup counts array, positions) tuples, one per
        contiguous pileup chunk.
    """
    ffi, lib = libmedaka.ffi, libmedaka.lib
    logger = get_named_logger('PileUp')

    num_dtypes, dtypes = 1, ffi.NULL
    if isinstance(dtype_prefixes, str):
        dtype_prefixes = [dtype_prefixes]
    if dtype_prefixes is not None and len(dtype_prefixes) > 1:
        num_dtypes = len(dtype_prefixes)
        _dtypes = [ffi.new("char[]", d.encode()) for d in dtype_prefixes]
        dtypes = ffi.new("char *[]", _dtypes)
    if tag_name is None:
        tag_name = ffi.new("char[2]", "".encode())
        tag_value = 0
        keep_missing = False
    else:
        if len(tag_name) != 2:
            raise ValueError("'tag_name' must be a length-2 string.")
        tag_name = ffi.new("char[2]", tag_name.encode())

    featlen = lib.featlen

    def _process_region(reg):
        # htslib start is 1-based, Region object is 0-based
        region_str = '{}:{}-{}'.format(reg.ref_name, reg.start + 1, reg.end)

        counts = lib.calculate_pileup(region_str.encode(), bam.encode(),
                                      num_dtypes, dtypes, tag_name, tag_value,
                                      keep_missing)

        size_sizet = np.dtype(np.uintp).itemsize
        np_counts = np.frombuffer(
            ffi.buffer(counts.counts,
                       size_sizet * counts.n_cols * featlen * num_dtypes),
            dtype=np.uintp).reshape(counts.n_cols,
                                    featlen * num_dtypes).copy()

        positions = np.empty(counts.n_cols,
                             dtype=[('major', int), ('minor', int)])
        np.copyto(
            positions['major'],
            np.frombuffer(ffi.buffer(counts.major, size_sizet * counts.n_cols),
                          dtype=np.uintp))
        np.copyto(
            positions['minor'],
            np.frombuffer(ffi.buffer(counts.minor, size_sizet * counts.n_cols),
                          dtype=np.uintp))

        lib.destroy_plp_data(counts)
        return np_counts, positions

    # split large regions for performance
    regions = region.split(region_split)
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=workers) as executor:
        results = list(executor.map(_process_region, regions))

    # First pass: need to check for discontinuities within chunks,
    # these show up as >2 changes in the major coordinate
    _results = list()
    for counts, positions in results:
        move = np.ediff1d(positions['major'])
        gaps = np.where(move > 2)[0] + 1
        if len(gaps) == 0:
            _results.append((counts, positions))
        else:
            logger.info("Splitting discontiguous pileup region.")
            start = 0
            for i in gaps:
                _results.append((counts[start:i], positions[start:i]))
                start = i
            _results.append((counts[start:], positions[start:]))
    results = _results

    # Second pass: stitch abutting chunks together, anything not neighbouring
    # is kept separate whether it came from the same chunk originally or not

    def _finalize_chunk(c_buf, p_buf):
        chunk_counts = np.concatenate(c_buf)
        chunk_positions = np.concatenate(p_buf)
        # get rid of 'first' counts row for each datatype (counts of
        # alternative bases)
        mask = np.ones(chunk_counts.shape[1], dtype=bool)
        mask[[x * featlen for x in range(0, num_dtypes)]] = False
        chunk_counts = chunk_counts[:, mask]
        return chunk_counts, chunk_positions

    counts_buffer, positions_buffer = list(), list()
    chunk_results = list()
    last = None
    for counts, positions in results:
        if len(positions) == 0:
            continue
        first = positions['major'][0]
        if len(counts_buffer) == 0 or first - last == 1:
            # new or contiguous
            counts_buffer.append(counts)
            positions_buffer.append(positions)
            last = positions['major'][-1]
        else:
            # discontinuity
            chunk_results.append(
                _finalize_chunk(counts_buffer, positions_buffer))
            counts_buffer = [counts]
            positions_buffer = [positions]
            last = positions['major'][-1]
    if len(counts_buffer) != 0:
        chunk_results.append(_finalize_chunk(counts_buffer, positions_buffer))

    return chunk_results
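The first-pass discontinuity check above reduces to a short computation on the major coordinates; a self-contained illustration:

import numpy as np

major = np.array([0, 1, 2, 10, 11])  # a jump between positions 2 and 10
move = np.ediff1d(major)             # [1, 1, 8, 1]
gaps = np.where(move > 2)[0] + 1     # split points, here [3]
print(np.split(major, gaps))         # [array([0, 1, 2]), array([10, 11])]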
Example #22
    def __init__(
        self,
        ref_mode: str = None,
        max_hp_len: int = 10,
        log_min: int = None,
        normalise: str = 'total',
        with_depth: bool = False,
        consensus_as_ref: bool = False,
        is_compressed: bool = True,
        dtypes=('', ),
        tag_name=None,
        tag_value=None,
        tag_keep_missing=False,
        sym_indels=False,
    ):
        """Class to support multiple feature encodings

        :param ref_mode: str, how to represent the reference.
        :param max_hp_len: int, longest homopolymer run which can be represented, longer runs will be truncated.
        :param log_min: int, take log10 of counts/fractions and set zeros to 10**log_min.
        :param normalise: str, how to normalise the data.
        :param with_depth: bool, whether to include a feature describing the total depth.
        :param consensus_as_ref: bool, whether to use a naive max-count consensus instead of the reference.
        :param is_compressed: bool, whether to use HP compression. If false, treat as uncompressed.
        :param dtypes: iterable of str, read id prefixes of distinct data types that should be counted separately.
        :param tag_name: two letter tag name by which to filter reads.
        :param tag_value: integer value of tag for reads to keep.
        :param tag_keep_missing: whether to keep reads when tag is missing.
        :param sym_indels: bool, whether to count a lack of an insertion as a deletion.

        """
        self.ref_mode = ref_mode
        self.consensus_as_ref = consensus_as_ref
        self.max_hp_len = max_hp_len
        self.log_min = log_min
        self.normalise = normalise
        self.feature_dtype = np.float32 if (
            self.normalise is not None
            or self.log_min is not None) else np.uint64
        self.with_depth = with_depth
        self.is_compressed = is_compressed
        self.logger = get_named_logger('Feature')
        self.dtypes = dtypes
        self.tag_name = tag_name
        self.tag_value = tag_value
        self.tag_keep_missing = tag_keep_missing
        self.sym_indels = sym_indels

        if self.ref_mode not in self._ref_modes_:
            raise ValueError('ref_mode={} is not one of {}'.format(
                self.ref_mode, self._ref_modes_))
        if self.normalise not in self._norm_modes_:
            raise ValueError('normalise={} is not one of {}'.format(
                self.normalise, self._norm_modes_))

        opts = inspect.signature(FeatureEncoder.__init__).parameters.keys()
        opts = {k: getattr(self, k) for k in opts if k != 'self'}

        read_decoding = []
        for dtype in self.dtypes:
            # set up one-hot encoding of read run lengths for each dtype
            read_decoding += [(dtype, ) + k for k in itertools.product((
                True, False), _alphabet_, range(1, max_hp_len + 1))]

            # forward and reverse gaps
            read_decoding += [(dtype, True, None, 1), (dtype, False, None, 1)]

        if self.ref_mode == 'onehot':
            ref_decoding = [('ref', b, l) for b, l in itertools.product(
                _alphabet_, range(1, max_hp_len + 1))]
            ref_decoding.append(('ref', _gap_, 1))  # gaps
        elif self.ref_mode == 'base_length':
            ref_decoding = ['ref_base', 'ref_length']
            self.ref_base_encoding = {
                b: i
                for i, b in enumerate(_alphabet_ + _gap_)
            }
        elif self.ref_mode == 'index':
            ref_decoding = ['ref_index']
        else:
            ref_decoding = []

        self.decoding = tuple(read_decoding + ref_decoding)
        if self.with_depth:
            self.decoding = self.decoding + ('depth', )
        self.encoding = OrderedDict(
            ((a, i) for i, a in enumerate(self.decoding)))
        self.logger.debug("Creating features with: {}".format(opts))

        self.logger.debug("Label decoding is:\n{}".format('\n'.join(
            '{}: {}'.format(i, x) for i, x in enumerate(self.decoding))))
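The per-dtype read decoding enumerates (strand, base, run length) keys plus two gap entries; a quick count, with `_alphabet_` assumed to be the usual four bases:

import itertools

_alphabet_ = 'ACGT'  # assumed module-level constant
max_hp_len = 10
read_keys = list(itertools.product((True, False), _alphabet_,
                                   range(1, max_hp_len + 1)))
print(len(read_keys) + 2)  # 2 strands * 4 bases * 10 lengths + 2 gaps = 82 per dtype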
Example #23
def predict(args):
    """Inference program."""
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    logger.info('Processing region(s): {}'.format(' '.join(str(r) for r in args.regions)))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w') as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    K.set_session(K.tf.Session(
        config=K.tf.ConfigProto(
            intra_op_parallelism_threads=args.threads,
            inter_op_parallelism_threads=args.threads)
    ))

    # Split overly long regions to maximum size so as to not create
    #   massive feature matrices, then segregate those which cannot be
    #   batched.
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    long_regions = []
    short_regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        for r in regs:
            if r.size > args.chunk_len:
                long_regions.append(r)
            else:
                short_regions.append(r)
    logger.info("Found {} long and {} short regions.".format(
        len(long_regions), len(short_regions)))

    if len(long_regions) > 0:
        logger.info("Processing long regions.")
        model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
        run_prediction(
            args.output, args.bam, long_regions, model, args.model, args.rle_ref, args.read_fraction,
            args.chunk_len, args.chunk_ovlp,
            batch_size=args.batch_size, save_features=args.save_features,
            tag_name=args.tag_name, tag_value=args.tag_value, tag_keep_missing=args.tag_keep_missing
        )

    # short regions must be done individually since they have differing lengths
    #   TODO: consider masking (it appears slow to apply wholesale), maybe
    #         step down args.chunk_len by a constant factor until nothing remains.
    if len(short_regions) > 0:
        logger.info("Processing short regions")
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in short_regions:
            chunk_len = region.size // 2
            chunk_ovlp = chunk_len // 10  # still need overlap as features will be longer
            run_prediction(
                args.output, args.bam, [region], model, args.model, args.rle_ref, args.read_fraction,
                chunk_len, chunk_ovlp,
                batch_size=args.batch_size, save_features=args.save_features,
                tag_name=args.tag_name, tag_value=args.tag_value, tag_keep_missing=args.tag_keep_missing
            )
    logger.info("Finished processing all regions.")
Example #24
def run_training(train_name,
                 batcher,
                 model_fp=None,
                 epochs=5000,
                 class_weight=None,
                 n_mini_epochs=1,
                 threads_io=1):
    """Run training."""
    from keras.callbacks import CSVLogger, TensorBoard, EarlyStopping, ReduceLROnPlateau
    from medaka.keras_ext import ModelMetaCheckpoint, SequenceBatcher, BatchQueue

    logger = get_named_logger('RunTraining')

    if model_fp is None:
        model_name = medaka.models.default_model
        model_kwargs = {
            k: v.default
            for (k, v) in inspect.signature(
                medaka.models.model_builders[model_name]).parameters.items()
            if v.default is not inspect.Parameter.empty
        }
    else:
        with DataStore(model_fp) as ds:
            model_name = ds.meta['medaka_model_name']
            model_kwargs = ds.meta['medaka_model_kwargs']

    opt_str = '\n'.join(
        ['{}: {}'.format(k, v) for k, v in model_kwargs.items()])
    logger.info('Building {} model with: \n{}'.format(model_name, opt_str))
    num_classes = len(batcher.label_counts)
    timesteps, feat_dim = batcher.feature_shape
    model = medaka.models.model_builders[model_name](timesteps, feat_dim,
                                                     num_classes,
                                                     **model_kwargs)

    if model_fp is not None:
        try:
            model.load_weights(model_fp)
            logger.info("Loaded weights from {}".format(model_fp))
        except Exception:
            logger.info("Could not load weights from {}".format(model_fp))

    msg = "feat_dim: {}, timesteps: {}, num_classes: {}"
    logger.info(msg.format(feat_dim, timesteps, num_classes))
    model.summary()

    model_details = batcher.meta.copy()

    model_details['medaka_model_name'] = model_name
    model_details['medaka_model_kwargs'] = model_kwargs
    model_details['medaka_label_decoding'] = batcher.label_decoding

    opts = dict(verbose=1, save_best_only=True, mode='max')

    callbacks = [
        # Best model according to training set accuracy
        ModelMetaCheckpoint(model_details,
                            os.path.join(train_name, 'model.best.hdf5'),
                            monitor='cat_acc',
                            **opts),
        # Best model according to validation set accuracy
        ModelMetaCheckpoint(model_details,
                            os.path.join(train_name, 'model.best.val.hdf5'),
                            monitor='val_cat_acc',
                            **opts),
        # Best model according to validation set qscore
        ModelMetaCheckpoint(model_details,
                            os.path.join(train_name,
                                         'model.best.val.qscore.hdf5'),
                            monitor='val_qscore',
                            **opts),
        # Checkpoints when training set accuracy improves
        ModelMetaCheckpoint(
            model_details,
            os.path.join(
                train_name,
                'model-acc-improvement-{epoch:02d}-{cat_acc:.2f}.hdf5'),
            monitor='cat_acc',
            **opts),
        ModelMetaCheckpoint(
            model_details,
            os.path.join(
                train_name,
                'model-val_acc-improvement-{epoch:02d}-{val_cat_acc:.2f}.hdf5'
            ),
            monitor='val_cat_acc',
            **opts),
        ## Reduce learning rate when no improvement
        #ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5,
        #    verbose=1, min_delta=0.0001, cooldown=0, min_lr=0),
        # Stop when no improvement
        EarlyStopping(monitor='val_loss', patience=20),
        # Log of epoch stats
        CSVLogger(os.path.join(train_name, 'training.log')),
        # Allow us to run tensorboard to see how things are going. Some
        #   features require validation data, not clear why.
        #TensorBoard(log_dir=os.path.join(train_name, 'logs'),
        #            histogram_freq=5, batch_size=100, write_graph=True,
        #            write_grads=True, write_images=True)
    ]

    if class_weight is not None:
        loss = weighted_categorical_crossentropy(class_weight)
        logger.info("Using weighted_categorical_crossentropy loss function")
    else:
        loss = 'sparse_categorical_crossentropy'
        logger.info("Using {} loss function".format(loss))

    model.compile(
        loss=loss,
        optimizer='nadam',
        metrics=[cat_acc, qscore],
    )

    if n_mini_epochs == 1:
        logger.info(
            "Not using mini_epochs, an epoch is a full traversal of the training data"
        )
    else:
        logger.info(
            "Using mini_epochs, an epoch is a traversal of 1/{} of the training data"
            .format(n_mini_epochs))

    with ProcessPoolExecutor(threads_io) as executor:
        logger.info("Starting data queues.")
        prep_function = functools.partial(
            batcher.sample_to_x_y_bq_worker,
            max_label_len=batcher.max_label_len,
            label_encoding=batcher.label_encoding,
            sparse_labels=batcher.sparse_labels,
            n_classes=batcher.n_classes)
        # TODO: should take mini_epochs into account here
        train_queue = BatchQueue(batcher.train_samples,
                                 prep_function,
                                 batcher.batch_size,
                                 executor,
                                 seed=batcher.seed,
                                 name='Train',
                                 maxsize=100)
        valid_queue = BatchQueue(batcher.valid_samples,
                                 prep_function,
                                 batcher.batch_size,
                                 executor,
                                 seed=batcher.seed,
                                 name='Valid',
                                 maxsize=100)

        # run training
        logger.info("Starting training.")
        model.fit_generator(
            generator=train_queue.yield_batches(),
            steps_per_epoch=train_queue.n_batches // n_mini_epochs,
            validation_data=valid_queue.yield_batches(),
            validation_steps=valid_queue.n_batches,
            max_queue_size=2 * threads_io,
            workers=1,
            use_multiprocessing=False,
            epochs=epochs,
            callbacks=callbacks,
            class_weight=class_weight,
        )
        logger.info("Training finished.")
        train_queue.stop()
        valid_queue.stop()
Example #25
def stitch_from_probs(h5_fp, regions=None):
    """Join overlapping label probabilities from HDF5 files.

    Network outputs from multiple samples stored within a file are spliced
    together into a logically contiguous array and decoded to generate
    contiguous sequence(s).

    :param h5_fp: iterable of HDF5 filepaths.
    :param regions: iterable of region strings to process.

    :returns: list of (contig name, region string, sequence) tuples.
    """
    logger = common.get_named_logger('Stitch')
    if isinstance(regions, medaka.common.Region):
        regions = [regions]
    logger.info("Stitching regions: {}".format([str(r) for r in regions]))

    index = medaka.datastore.DataIndex(h5_fp)
    label_scheme = index.metadata['label_scheme']

    logger.debug("Label decoding is:\n{}".format(
        '\n'.join('{}: {}'.format(k, v)
                  for k, v in label_scheme._decoding.items())))

    def get_pos(sample, i):
        return '{}.{}'.format(
            sample.positions[i]['major'] + 1, sample.positions[i]['minor'])

    ref_assemblies = []
    for reg in regions:
        logger.info("Processing {}.".format(reg))
        data_gen = index.yield_from_feature_files(regions=[reg])
        seq_parts = list()
        cur_ref_name = ''
        cur_segment = None
        # first sample
        s1 = next(data_gen)
        start = get_pos(s1, 0)
        start_1 = None
        start_2 = None
        heuristic_use = 0

        for s2 in itertools.chain(data_gen, (None,)):
            s1_name = 'Unknown' if s1 is None else s1.name
            s2_name = 'Unknown' if s2 is None else s2.name

            # s1 is last chunk
            if s2 is None:
                end_1 = None
            else:
                # s2 ends before s1
                if s2.last_pos <= s1.last_pos:
                    logger.info('{} ends before {}, skipping.'.format(
                        s2_name, s1_name
                    ))
                    continue
                # s1 and s2 overlap by only one position
                # or there is no overlap between s1 and s2
                elif s2.first_pos >= s1.last_pos:
                    # trigger a break
                    end_1, start_2 = None, None
                else:
                    try:
                        end_1, start_2, heuristic = \
                            common.Sample.overlap_indices(s1, s2)
                        if heuristic:
                            logger.debug(
                                "Used heuristic to stitch {} and {}.".format(
                                    s1.name, s2.name))
                            heuristic_use += 1
                    except common.OverlapException:
                        logger.info(
                            "Unhandled overlap type whilst stitching chunks.")
                        raise

            new_seq = label_scheme.decode_consensus(
                s1.slice(slice(start_1, end_1)))

            seq_parts.append(new_seq)

            if end_1 is None:
                if s1.ref_name != cur_ref_name:
                    cur_ref_name = s1.ref_name
                    cur_segment = 0
                else:
                    cur_segment += 1
                ref_assemblies.append((
                    '{}_segment{}'.format(cur_ref_name, cur_segment),
                    '{}:{}-{}'.format(cur_ref_name, start, get_pos(s1, -1)),
                    ''.join(seq_parts)))
                seq_parts = list()

                if s2 is not None and start_2 is None:
                    msg = 'There is no overlap between {} and {}'
                    logger.info(msg.format(s1_name, s2_name))
                    start = get_pos(s2, 0)

            s1 = s2
            start_1 = start_2
        logger.info("Used heuristic {} times for {}.".format(
            heuristic_use, reg))
    return ref_assemblies
Example #26
from medaka.datastore import DataStore, DataIndex
from medaka.common import get_named_logger

logger = get_named_logger('ModelLoad')


def load_model(fname, time_steps=None):
    """Load a model from an .hdf file.

    :param fname: .hdf file containing model.
    :param time_steps: number of time points in RNN, `None` for dynamic.

    .. note:: keras' `load_model` cannot handle CuDNNGRU layers, hence this
        function builds the model then loads the weights.
    """
    with DataStore(fname) as ds:
        meta = ds.meta
        num_features = len(meta['medaka_feature_decoding'])
        num_classes = len(meta['medaka_label_decoding'])
    build_model = model_builders[meta['medaka_model_name']]

    logger.info(
        "Building model (steps, features, classes): ({}, {}, {})".format(
            time_steps, num_features, num_classes))
    model = build_model(time_steps, num_features, num_classes,
                        **meta['medaka_model_kwargs'])
    logger.info("Loading weights from {}".format(fname))
    model.load_weights(fname)
    return model
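A hedged usage sketch (the checkpoint filename echoes Example #24 and is hypothetical; `time_steps=None` yields a model with a dynamic sequence dimension):

model = load_model('model.best.val.hdf5', time_steps=1000)  # fixed-length batches
fallback = load_model('model.best.val.hdf5')                # dynamic time steps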