Example #1
    def test_relative_position(self):
        # all 9 cases plus major and minor variations of overlapping, abutting and gapped
        #   different_ref_name = 'Samples come from different reference contigs.'
        #   forward_overlap = 'The end of s1 overlaps the start of s2.'
        #   reverse_overlap = 'The end of s2 overlaps the start of s1.'
        #   forward_abutted = 'The end of s1 abuts the start of s2.'
        #   reverse_abutted = 'The end of s2 abuts the start of s1.'
        #   forward_gapped = 's2 follows s1 with a gap in between.'
        #   reverse_gapped = 's1 follows s2 with a gap in between.'
        #   s2_within_s1 = 's2 is fully contained within s1.'
        #   s1_within_s2 = 's1 is fully contained within s2.'

        slices = [
            slice(0, 3),  # (0,0) -> (1,0) in self.samples[0]
            slice(3, 5),  # (2,0) -> (2,1) in self.samples[0]
            slice(5, 8),  # (2,2) -> (4,0) in self.samples[0]
            slice(8, 11)  # (4,1) -> (4,3) in self.samples[0]
        ]
        sliced = [self.samples[0].slice(sl) for sl in slices]
        sample_dict = self.samples[0]._asdict()
        sample_dict['ref_name'] = 'other'
        sample_other = Sample(**sample_dict)

        samples_expt = [
            ([self.samples[0], sample_other], Relationship.different_ref_name),
            (self.samples[:2],
             Relationship.forward_overlap),  # overlap of minor positions
            (self.samples[1:],
             Relationship.forward_overlap),  # overlap of major positions
            (self.samples[:2][::-1],
             Relationship.reverse_overlap),  # overlap of minor positions
            (self.samples[1:][::-1],
             Relationship.reverse_overlap),  # overlap of major positions
            (sliced[:2], Relationship.forward_abutted),  # (1,0) -> (2,0)
            (sliced[1:3], Relationship.forward_abutted),  # (2,1) -> (2,2)
            (sliced[::-1][2:], Relationship.reverse_abutted),  # (2,0) -> (1,0)
            (sliced[::-1][1:3],
             Relationship.reverse_abutted),  # (2,2) -> (2,1)
            ([sliced[0],
              sliced[2]], Relationship.forward_gapped),  # (1,0) -> (2,2)
            ([self.samples[0].slice(slice(0, 4)),
              sliced[2]], Relationship.forward_gapped),  # (2,0) -> (2,2)
            ([sliced[0], self.samples[0].slice(slice(6, None))],
             Relationship.forward_gapped),  # (1,0) -> (3,0)
            ([sliced[2], self.samples[0].slice(slice(0, 4))],
             Relationship.reverse_gapped),  # (2,2) -> (2,0)
            ([self.samples[0].slice(slice(6, None)),
              sliced[0]], Relationship.reverse_gapped),  # (3,0) -> (1,0)
            ([self.samples[0], sliced[0]], Relationship.s2_within_s1),
            ([self.samples[0], sliced[1]], Relationship.s2_within_s1),
            ([self.samples[0], sliced[2]], Relationship.s2_within_s1),
            ([self.samples[0], sliced[3]], Relationship.s2_within_s1),
            ([sliced[0], self.samples[0]], Relationship.s1_within_s2),
            ([sliced[1], self.samples[0]], Relationship.s1_within_s2),
            ([sliced[2], self.samples[0]], Relationship.s1_within_s2),
            ([sliced[3], self.samples[0]], Relationship.s1_within_s2),
        ]
        for samples, expt in samples_expt:
            self.assertIs(Sample.relative_position(*samples), expt)
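
The classification exercised above can be derived by comparing the first and last (major, minor) positions of each sample. The following is a minimal sketch of that idea, not medaka's actual implementation: the `Relationship` enum values are illustrative, and samples are only assumed to expose `ref_name` and `positions` attributes as in these tests.

import enum

class Relationship(enum.Enum):
    # names taken from the test above; values are illustrative
    different_ref_name = enum.auto()
    forward_overlap = enum.auto()
    reverse_overlap = enum.auto()
    forward_abutted = enum.auto()
    reverse_abutted = enum.auto()
    forward_gapped = enum.auto()
    reverse_gapped = enum.auto()
    s2_within_s1 = enum.auto()
    s1_within_s2 = enum.auto()

def _abutting(pos):
    # positions that would immediately follow pos: next minor or next major
    major, minor = pos
    return {(major, minor + 1), (major + 1, 0)}

def relative_position_sketch(s1, s2):
    """Classify how s2 lies relative to s1 via (major, minor) tuples."""
    if s1.ref_name != s2.ref_name:
        return Relationship.different_ref_name
    s1_first, s1_last = tuple(s1.positions[0]), tuple(s1.positions[-1])
    s2_first, s2_last = tuple(s2.positions[0]), tuple(s2.positions[-1])
    if s1_first <= s2_first and s2_last <= s1_last:
        return Relationship.s2_within_s1
    if s2_first <= s1_first and s1_last <= s2_last:
        return Relationship.s1_within_s2
    if s1_first < s2_first:  # s1 leads, s2 follows
        if s2_first in _abutting(s1_last):
            return Relationship.forward_abutted
        if s2_first <= s1_last:
            return Relationship.forward_overlap
        return Relationship.forward_gapped
    # s2 leads, s1 follows: the mirror of the above
    if s1_first in _abutting(s2_last):
        return Relationship.reverse_abutted
    if s1_first <= s2_last:
        return Relationship.reverse_overlap
    return Relationship.reverse_gapped
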
Example #2
    def test_from_samples(self):
        # check we can concat a single sample
        concat1 = Sample.from_samples([self.samples[0]])
        self.assertEqual(concat1, self.samples[0])

        # check we can concat 3 samples
        slices = [slice(0, 5), slice(5, 8),
                  slice(8, 11)]  # these should span all of samples[0]
        sliced = [self.samples[0].slice(sl) for sl in slices]

        concat3 = Sample.from_samples(sliced)
        self.assertEqual(concat3, self.samples[0])

        # also check raises an exception if not forward abutting
        self.assertRaises(ValueError, Sample.from_samples, sliced[::-1])
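
As a rough sketch of what such a concatenation can look like for a `Sample` namedtuple (the real `from_samples` also classifies the relationship between neighbours; this hypothetical version only checks that positions remain strictly increasing and assumes no array field is None):

import numpy as np

def from_samples_sketch(samples):
    # a minimal sketch, not medaka's implementation
    samples = list(samples)
    if len(samples) == 1:
        return samples[0]
    cat = lambda field: np.concatenate([getattr(s, field) for s in samples])
    pos = cat('positions')
    # the real method raises ValueError when chunks are not forward abutting;
    # here we only check that positions end up strictly increasing
    if not all(tuple(a) < tuple(b) for a, b in zip(pos[:-1], pos[1:])):
        raise ValueError('Samples are not forward abutting.')
    return samples[0]._replace(features=cat('features'),
                               labels=cat('labels'),
                               positions=pos,
                               label_probs=cat('label_probs'))
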
Example #3
    def _get_sorted_index(self):
        """Get index of samples indexed by reference and ordered by start pos.

        :returns: {ref_name: [sample dicts sorted by start]}
        """

        ref_names = defaultdict(list)

        for key, f in self.samples:
            d = Sample.decode_sample_name(key)
            if d is not None:
                d['key'] = key
                d['filename'] = f
                ref_names[d['ref_name']].append(d)

        # sort dicts so that refs are in order and within a ref, chunks are in order
        ref_names_ordered = OrderedDict()

        get_major_minor = lambda x: tuple(int(i) for i in x.split('.'))
        # sort by start and -end so that if we have two samples with the same
        # start but different end points, the longest sample comes first
        sorter = lambda x: (get_major_minor(x['start'])
                            + tuple(-i for i in get_major_minor(x['end'])))
        for ref_name in sorted(ref_names.keys()):
            ref_names[ref_name].sort(key=sorter)
            ref_names_ordered[ref_name] = ref_names[ref_name]

        return ref_names_ordered
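
The negated end components in the sort key are what push the longest chunk first among chunks sharing a start. A quick self-contained check of that behaviour (the coordinates are made up):

get_major_minor = lambda x: tuple(int(i) for i in x.split('.'))
sorter = lambda x: (get_major_minor(x['start'])
                    + tuple(-i for i in get_major_minor(x['end'])))
chunks = [
    {'start': '0.0', 'end': '4.3'},
    {'start': '0.0', 'end': '7.0'},  # same start, longer chunk
    {'start': '4.1', 'end': '7.0'},
]
chunks.sort(key=sorter)
# the longer of the two chunks starting at 0.0 now comes first
assert [c['end'] for c in chunks] == ['7.0', '4.3', '7.0']
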
Example #4
    def test_overlap_indices(self):
        # check we get right thing for overlap and abutting major and minor
        # and that we get an exception if we give it any other case.
        slices = [
            slice(0, 3),  # (0,0) -> (1,0) in self.samples[0]
            slice(3, 5),  # (2,0) -> (2,1) in self.samples[0]
            slice(5, 8),  # (2,2) -> (4,0) in self.samples[0]
            slice(8, 11)  # (4,1) -> (4,3) in self.samples[0]
        ]
        sliced = [self.samples[0].slice(sl) for sl in slices]

        samples_expt = [
            (
                self.samples[:2], (9, 1, False)
            ),  # overlap of minor inds with odd number of overlapping positions
            (
                self.samples[1:], (6, 2, False)
            ),  # overlap of major inds with even number of overlapping positions
            (sliced[:2], (None, None, False)),  # abuts
            (sliced[1:3], (None, None, False)),  # abuts
            (sliced[2:], (None, None, False)),  # abuts
        ]
        for samples, expt in samples_expt:
            self.assertEqual(Sample.overlap_indices(*samples), expt)

        # `samples` is left over from the loop above (the final abutting
        # pair); in reversed order it is not a valid overlap case
        self.assertRaises(OverlapException, Sample.overlap_indices, samples[1],
                          samples[0])
Example #5
    def test_gapped(self):
        # check we get is_last_in_contig if we have a gap between samples
        trimmed = list(Sample.trim_samples(s for s in self.samples[::2]))
        for i, (expt, (got, is_last_in_contig, heuristic)) in enumerate(
                zip(self.samples[::2], trimmed)):
            self.assertEqual(got, expt)
            self.assertTrue(is_last_in_contig)
Example #6
    def test_single_sample(self):
        # test that if we provide a single sample, we get the same sample back

        results = list(Sample.trim_samples(iter([self.samples[0]])))
        self.assertEqual(len(results), 1)
        got, is_last_in_contig, heuristic = results[0]
        self.assertEqual(got, self.samples[0])
        self.assertTrue(is_last_in_contig)
Example #7
    def test_works(self):
        # test simple case of 3 chained samples

        trimmed = list(Sample.trim_samples((s for s in self.samples)))

        for i, (expt, (got, is_last_in_contig,
                       heuristic)) in enumerate(zip(self.sliced, trimmed)):
            self.assertEqual(got, expt)
            if i == len(self.sliced) - 1:
                self.assertTrue(is_last_in_contig)
            else:
                self.assertFalse(is_last_in_contig)
Example #8
def run_prediction(output, bam, regions, model, model_file, rle_ref, read_fraction, chunk_len, chunk_ovlp,
                   batch_size=200, save_features=False, tag_name=None, tag_value=None, tag_keep_missing=False):
    """Inference worker."""

    logger = get_named_logger('PWorker')

    def sample_gen():
        # chain all samples whilst dispensing with generators when done
        #   (they hold the feature vector in memory until they die)
        for region in regions:
            data_gen = SampleGenerator(
                bam, region, model_file, rle_ref, read_fraction,
                chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
                tag_name=tag_name, tag_value=tag_value,
                tag_keep_missing=tag_keep_missing)
            yield from data_gen.samples
    batches = background_generator(
        grouper(sample_gen(), batch_size), 10
    )

    total_region_mbases = sum(r.size for r in regions) / 1e6
    logger.info("Running inference for {:.1f}M draft bases.".format(total_region_mbases))

    with DataStore(output, 'a') as ds:
        mbases_done = 0

        t0 = now()
        tlast = t0
        for data in batches:
            x_data = np.stack([x.features for x in data])
            class_probs = model.predict_on_batch(x_data)
            mbases_done += sum(x.span for x in data) / 1e6
            mbases_done = min(mbases_done, total_region_mbases)  # just to avoid funny log msg
            t1 = now()
            if t1 - tlast > 10:
                tlast = t1
                msg = '{:.1%} Done ({:.1f}/{:.1f} Mbases) in {:.1f}s'
                logger.info(msg.format(mbases_done / total_region_mbases, mbases_done, total_region_mbases, t1 - t0))

            best = np.argmax(class_probs, -1)
            for sample, prob, pred, feat in zip(data, class_probs, best, x_data):
                # write out positions and predictions for later analysis
                sample_d = sample._asdict()
                sample_d['label_probs'] = prob
                sample_d['features'] = feat if save_features else None
                ds.write_sample(Sample(**sample_d))

    logger.info('All done')
    return None
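
`grouper` and `background_generator` are medaka utility helpers not shown in this listing. As an assumption about what the batching helper does (group an iterable into lists of up to `batch_size` items, keeping the final short batch), a minimal `grouper` might look like this:

import itertools

def grouper(iterable, batch_size):
    # a sketch only; medaka's actual helper may differ
    it = iter(iterable)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch
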
Example #9
    @classmethod
    def setUpClass(cls):
        pos = np.array([(0, 0), (0, 1), (1, 0), (2, 0), (2, 1), (2, 2), (3, 0),
                        (4, 0), (4, 1), (4, 2), (4, 3)],
                       dtype=[('major', int), ('minor', int)])
        data_dim = 10
        data = np.zeros(shape=(len(pos), data_dim))
        cls.sample = Sample(ref_name='contig1',
                            features=data,
                            ref_seq=None,
                            labels=data,
                            positions=pos,
                            label_probs=data)
        cls.file = tempfile.NamedTemporaryFile()
        with datastore.DataStore(cls.file.name, 'w') as store:
            store.write_sample(cls.sample)
Example #10
    def test_decode_sample_name(self):
        expected = [
            {
                'ref_name': 'contig1',
                'start': '0.0',
                'end': '4.3'
            },
            {
                'ref_name': 'contig1',
                'start': '4.1',
                'end': '7.0'
            },
        ]

        for expt, sample in zip(expected, self.samples):
            self.assertEqual(expt, Sample.decode_sample_name(sample.name))
Example #11
    def load_sample(self, key):
        """Load `Sample` object from HDF5

        :param key: str, sample name.
        :returns: `Sample` object.
        """
        s = {}
        for field in Sample._fields:
            pth = '{}/{}/{}'.format(self._sample_path_, key, field)
            if pth in self.fh:
                s[field] = self.fh[pth][()]
                if isinstance(s[field], np.ndarray) and isinstance(s[field][0], bytes):
                    s[field] = np.char.decode(s[field])
            else:
                s[field] = None
        return Sample(**s)
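
The `np.char.decode` branch exists because h5py hands string datasets back as arrays of bytes. A minimal round trip illustrating the behaviour (the file name and dataset path here are arbitrary):

import h5py
import numpy as np

with h5py.File('roundtrip.h5', 'w') as fh:
    fh['sample/ref_seq'] = np.array(['A', 'C', 'G'], dtype='S1')

with h5py.File('roundtrip.h5', 'r') as fh:
    raw = fh['sample/ref_seq'][()]
    assert isinstance(raw[0], bytes)  # h5py returns bytes, e.g. b'A'
    decoded = np.char.decode(raw)     # back to unicode strings
    assert decoded[0] == 'A'
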
Example #12
    def test_030_bams_to_training_samples_simple(self):
        reads_bam = tempfile.NamedTemporaryFile(suffix='.bam').name
        truth_bam = tempfile.NamedTemporaryFile(suffix='.bam').name

        # we had a bug caused by missing qualities and bad indexing...
        data = copy.deepcopy(simple_data['calls'])
        data[0]['quality'] = None

        create_simple_bam(reads_bam, data)
        create_simple_bam(truth_bam, [simple_data['truth']])
        encoder = medaka.features.CountsFeatureEncoder(normalise='total')
        label_scheme = medaka.labels.HaploidLabelScheme()
        region = Region('ref', 0, 100)
        result = encoder.bams_to_training_samples(truth_bam,
                                                  reads_bam,
                                                  region,
                                                  label_scheme,
                                                  min_length=0)[0]

        expected = Sample(
            ref_name='ref',
            features=np.array(
                [[0.5, 0., 0., 0., 0.5, 0., 0., 0., 0., 0.],
                 [0., 0.5, 0., 0., 0., 0.5, 0., 0., 0., 0.],
                 [0.5, 0., 0., 0., 0.5, 0., 0., 0., 0., 0.],
                 [0., 0.25, 0., 0.25, 0., 0., 0., 0.25, 0., 0.25],
                 [0.25, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0.5, 0., 0., 0., 0.5, 0., 0., 0.],
                 [0.5, 0., 0., 0., 0.5, 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0.5, 0., 0., 0., 0.5, 0., 0.],
                 [0., 0., 0.5, 0., 0., 0., 0.5, 0., 0., 0.]],
                dtype='float32'),
            # the two insertions with respect to the draft are dropped
            labels=np.array([1, 2, 1, 4, 1, 3, 1, 4, 3]),  # A C A T A G A T C
            ref_seq=None,
            positions=np.array([(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (4, 0),
                                (5, 0), (6, 0), (7, 0)],
                               dtype=[('major', '<i8'), ('minor', '<i8')]),
            label_probs=None)

        np.testing.assert_equal(result.labels, expected.labels)
        np.testing.assert_equal(result.positions, expected.positions)
        np.testing.assert_equal(result.features, expected.features)
Example #13
    @classmethod
    def setUpClass(cls):
        pos1 = np.array([(0, 0), (0, 1), (1, 0), (2, 0), (2, 1), (2, 2),
                         (3, 0), (4, 0), (4, 1), (4, 2), (4, 3)],
                        dtype=[('major', int), ('minor', int)])
        pos2 = np.array([(4, 1), (4, 2), (4, 3), (4, 5), (5, 0), (6, 0),
                         (6, 1), (7, 0)],
                        dtype=[('major', int), ('minor', int)])
        pos3 = np.array([(5, 0), (6, 0), (6, 1), (7, 0), (7, 1)],
                        dtype=[('major', int), ('minor', int)])
        cls.samples = []
        data_dim = 10
        for pos in pos1, pos2, pos3:
            data = np.random.random_sample(size=data_dim * len(pos)).reshape(
                (len(pos), data_dim))
            cls.samples.append(
                Sample(ref_name='contig1',
                       features=data,
                       ref_seq=None,
                       labels=data,
                       positions=pos,
                       label_probs=data))
Example #14
    def _get_sorted_index(self):
        """Get index of samples indexed by reference and ordered by start pos.

        :returns: {ref_name: [sample dicts sorted by start]}
        """

        ref_names = defaultdict(list)

        for key, f in self.samples:
            d = Sample.decode_sample_name(key)
            if d is not None:
                d['key'] = key
                d['filename'] = f
                ref_names[d['ref_name']].append(d)

        # sort dicts so that refs are in order and within a ref, chunks are in order
        ref_names_ordered = OrderedDict()
        for ref_name in sorted(ref_names.keys()):
            sorter = lambda x: float(x['start'])
            ref_names[ref_name].sort(key=sorter)
            ref_names_ordered[ref_name] = ref_names[ref_name]

        return ref_names_ordered
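
Unlike the earlier `_get_sorted_index` listing, this variant sorts by `float(x['start'])`. Treating the 'major.minor' coordinate as a decimal fraction mis-orders minor indices once they reach two digits, which the tuple-of-ints key in the other version avoids:

starts = ['4.9', '4.10']  # major 4, minors 9 and 10
as_float = sorted(starts, key=float)
as_tuple = sorted(starts, key=lambda x: tuple(int(i) for i in x.split('.')))
assert as_float == ['4.10', '4.9']  # wrong: float('4.10') == 4.1 < 4.9
assert as_tuple == ['4.9', '4.10']  # correct major/minor ordering
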
Example #15
    def test_is_empty(self):
        self.assertFalse(self.samples[0].is_empty)
        empty_sample = Sample('contig', None, None, None,
                              self.samples[0][0:0], None)
        self.assertTrue(empty_sample.is_empty)
Example #16
    def test_raises(self):
        # test we get an exception if we e.g. provide samples from different
        # refs or samples out of order
        self.assertRaises(OverlapException,
                          lambda x: list(Sample.trim_samples(x)),
                          iter(self.samples[::-1]))
Example #17
    def bam_to_sample_c(self, reads_bam, region):
        """Converts a section of an alignment pileup (as shown
        by e.g. samtools tview) to a base frequency feature array

        :param reads_bam: (sorted indexed) bam with read alignment to reference
        :param region: `Region` object with ref_name, start and end attributes.
        :param start: starting position within reference
        :param end: ending position within reference
        :returns: `Sample` object
        """
        assert self.ref_mode is None
        assert not self.consensus_as_ref
        assert self.max_hp_len == 1
        assert self.log_min is None
        assert self.normalise in ('total', 'fwd_rev', None)
        assert not self.with_depth
        assert not self.is_compressed

        pileup = pileup_counts(region,
                               reads_bam,
                               dtype_prefixes=self.dtypes,
                               tag_name=self.tag_name,
                               tag_value=self.tag_value,
                               keep_missing=self.tag_keep_missing)
        samples = list()
        for counts, positions in pileup:

            if len(counts) == 0:
                msg = 'Pileup-feature is zero-length for {} indicating no reads in this region.'.format(
                    region)
                self.logger.warning(msg)
                samples.append(
                    Sample(ref_name=region.ref_name,
                           features=None,
                           labels=None,
                           ref_seq=None,
                           positions=positions,
                           label_probs=None))
                continue

            start, end = positions['major'][0], positions['major'][-1]
            if start != region.start or end + 1 != region.end:  # TODO investigate off-by-one
                self.logger.warning(
                    'Pileup counts do not span requested region, requested {}, '
                    'received {}-{}.'.format(region, start, end))

            # find the position index for parent major position of all minor positions
            minor_inds = np.where(positions['minor'] > 0)
            major_pos_at_minor_inds = positions['major'][minor_inds]
            major_ind_at_minor_inds = np.searchsorted(positions['major'],
                                                      major_pos_at_minor_inds,
                                                      side='left')

            depth = np.sum(counts, axis=1)
            depth[minor_inds] = depth[major_ind_at_minor_inds]

            if self.sym_indels:
                # Make indels at ref and non-ref positions symmetric. Major
                # columns otherwise have counts of reads with and without a
                # deletion, whilst minor (inserted) columns only have counts
                # of the reads with an insertion. To make ref and non-ref
                # positions symmetric, fill in the counts of reads which don't
                # have an insertion, i.e. depth_del = depth_major - depth_ins.
                for (dt, is_rev), inds in self.feature_indices.items():
                    dt_depth = np.sum(counts[:, inds], axis=1)
                    del_ind = self.encoding[(dt, is_rev, None, 1)]
                    counts[minor_inds, del_ind] = dt_depth[
                        major_ind_at_minor_inds] - dt_depth[minor_inds]
            if self.normalise == 'total':
                # normalize counts by total depth at the major position; since
                # the counts include deletions, this is a count of spanning reads
                feature_array = counts / np.maximum(1, depth).reshape(
                    (-1, 1))  # max just to avoid div error
            elif self.normalise == 'fwd_rev':
                # normalize forward and reverse separately for each dtype
                feature_array = np.empty_like(counts, dtype=float)
                for (dt, is_rev), inds in self.feature_indices.items():
                    dt_depth = np.sum(counts[:, inds], axis=1)
                    dt_depth[minor_inds] = dt_depth[major_ind_at_minor_inds]
                    feature_array[:, inds] = counts[:, inds] / np.maximum(
                        1, dt_depth).reshape(
                            (-1, 1))  # max just to avoid div error
            else:
                feature_array = counts

            feature_array = feature_array.astype(self.feature_dtype)

            sample = Sample(ref_name=region.ref_name,
                            features=feature_array,
                            labels=None,
                            ref_seq=None,
                            positions=positions,
                            label_probs=None)
            samples.append(sample)
            self.logger.info('Processed {} (median depth {})'.format(
                sample.name, np.median(depth)))
        return samples
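
The `searchsorted` idiom above maps every minor (inserted) column back to the index of its parent major column so that per-major depths can be broadcast across insertions. The same few lines in isolation, with made-up depths:

import numpy as np

positions = np.array([(0, 0), (1, 0), (1, 1), (1, 2), (2, 0)],
                     dtype=[('major', int), ('minor', int)])
depth = np.array([10, 8, 3, 2, 9])  # per-column counts (made up)

minor_inds = np.where(positions['minor'] > 0)
major_pos_at_minor_inds = positions['major'][minor_inds]
major_ind_at_minor_inds = np.searchsorted(positions['major'],
                                          major_pos_at_minor_inds,
                                          side='left')
# minor columns inherit the depth of their parent major column
depth[minor_inds] = depth[major_ind_at_minor_inds]
assert depth.tolist() == [10, 8, 8, 8, 9]
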
Example #18
    def bam_to_sample(self,
                      reads_bam,
                      region,
                      reference=None,
                      read_fraction=None,
                      force_py=False):
        """Converts a section of an alignment pileup (as shown
        by e.g. samtools tview) to a base frequency feature array

        :param reads_bam: (sorted indexed) bam with read alignment to reference
        :param region: `Region` object with ref_name, start and end attributes.
        :param reference: reference `.fasta`, should correspond to `bam`.
            Required only for run length encoded references and reads.
        :param read_fraction: fraction of reads to use, if `None` use all.
        :param force_py: bool, if True, force use of python code (rather than c library).
        :returns: iterable of `Sample` objects
        """

        ref_rle = self.process_ref_seq(region.ref_name, reference)

        # Try to use the fast C function if we can, else fall back on the
        # python implementation below.
        if not force_py and (ref_rle is None and read_fraction is None):
            try:
                return self.bam_to_sample_c(reads_bam, region)
            except Exception as e:
                self.logger.info(
                    'Could not process sample with bam_to_sample_c, '
                    'using python code instead.\n({}).'.format(e))
        if self.tag_name is not None:
            raise NotImplementedError(
                "Filtering alignments by tag is not supported in python code.")

        #TODO: The code below will abut discontiguous regions in a pileup i.e.
        #      where no reads span a reference position the major position
        #      is dropped from the pileup. The correct behaviour would be to
        #      split apart the sub-regions and return them separately.
        #      The C implementation does this splitting.

        if self.is_compressed:
            aln_to_pairs = partial(yield_compressed_pairs, ref_rle=ref_rle)
        elif self.max_hp_len == 1:
            aln_to_pairs = get_pairs
        else:
            aln_to_pairs = partial(get_pairs_with_hp_len, ref_seq=ref_rle)

        # accumulate data in dicts
        aln_counters = defaultdict(Counter)
        ref_bases = dict()
        with pysam.AlignmentFile(reads_bam, 'rb') as bamfile:
            aln_reads = bamfile.fetch(region.ref_name, region.start,
                                      region.end)
            if read_fraction is not None:
                low, high = read_fraction
                np.random.seed((int(now()) * region.start) % 2**32)
                fraction = ((high - low) * np.random.random_sample(1) + low)[0]
                aln_reads = [a for a in aln_reads]
                n_reads = len(aln_reads)
                n_reads_to_keep = max(int(fraction * n_reads), 1)
                replace = n_reads_to_keep > n_reads
                msg = "Resampling (replace {}) from {} to {} ({:.3f}) for {}"
                self.logger.debug(
                    msg.format(replace, n_reads, n_reads_to_keep, fraction,
                               region))
                aln_reads = np.random.choice(aln_reads,
                                             n_reads_to_keep,
                                             replace=replace)

            start = region.start
            end = region.end
            if start is None:
                start = 0
            if end is None:
                end = float('Inf')

            for aln in aln_reads:
                # get the dtype from the prefix of the query name
                try:
                    dtype = self.dtypes[np.where([
                        aln.query_name.startswith(dt) for dt in self.dtypes
                    ])[0][0]]
                except IndexError:
                    msg = "Skipping read {} as dtype not in {}"
                    self.logger.info(msg.format(aln.query_name, self.dtypes))
                    continue

                reverse = aln.is_reverse
                pairs = aln_to_pairs(aln)
                ins_count = 0
                for pair in itertools.dropwhile(
                        lambda x: (x.rpos is None) or (x.rpos < start), pairs):
                    if ((pair.rpos == aln.reference_end - 1)
                            or (pair.rpos is not None and pair.rpos >= end)):
                        break
                    if pair.rpos is None:
                        ins_count += 1
                    else:
                        ins_count = 0
                        current_pos = pair.rpos

                    feature_index = self.encoding[
                        (dtype, reverse, pair.qbase,
                         min(pair.qlen, self.max_hp_len))]
                    aln_counters[(current_pos, ins_count)][feature_index] += 1

                    ref_base = (pair.rbase.upper()
                                if pair.rbase is not None else '*')
                    ref_bases[(current_pos, ins_count)] = (ref_base, pair.rlen)

            # create feature array
            aln_cols = len(aln_counters)
            feature_len = len(self.encoding)
            feature_array = np.zeros(shape=(aln_cols, feature_len),
                                     dtype=self.feature_dtype)
            if self.log_min is not None:
                feature_array.fill(np.nan)
            ref_array = np.empty(shape=(aln_cols),
                                 dtype=[('base', int), ('run_length', int)])
            positions = np.empty(aln_cols,
                                 dtype=[('major', int), ('minor', int)])

            if aln_cols == 0:
                msg = 'Pileup-feature is zero-length for {} indicating no reads in this region.'.format(
                    region)
                self.logger.warning(msg)
                return [
                    Sample(ref_name=region.ref_name,
                           features=None,
                           labels=None,
                           ref_seq=None,
                           positions=positions,
                           label_probs=None)
                ]

            depth_array = np.empty(shape=(aln_cols), dtype=int)

            # keep track of which features are for fwd/rev reads of each dtype
            inds_by_type = self.feature_indices

            #TODO: refactor so common combinations of options can be handled as in C-function
            for i, ((pos, counts), (_, (ref_base, ref_len))) in \
                    enumerate(zip(sorted(aln_counters.items()),
                                sorted(ref_bases.items()))):
                positions[i] = pos
                ref_array[i] = (encoding[ref_base], ref_len)
                for j in counts.keys():
                    feature_array[i, j] = counts[j]

                if self.consensus_as_ref:
                    cons_i = np.argmax(feature_array[i])
                    cons_is_reverse, cons_base, cons_length = self.decoding[
                        cons_i]
                    ref_base = cons_base if cons_base is not None else _gap_
                    ref_len = cons_length

                if positions[i]['minor'] == 0:
                    major_depth = sum(counts.values())
                    # get the depth of each fwd and rev dtype
                    major_depths_by_type = {
                        t: sum((counts[i] for i in inds_by_type[t]))
                        for t in inds_by_type
                    }
                    assert sum(major_depths_by_type.values()) == major_depth

                if self.sym_indels and positions[i]['minor'] > 0:
                    # make indels at ref and non-ref positions symmetric (see comment in bam_to_sample_c).
                    for (dtype, is_rev), inds in inds_by_type.items():
                        del_ind = self.encoding[(dtype, is_rev, None, 1)]
                        assert feature_array[i, del_ind] == 0
                        feature_array[i, del_ind] = major_depths_by_type[
                            (dtype, is_rev)] - feature_array[i, inds].sum()

                if self.normalise is not None:
                    if self.normalise == 'total':
                        feature_array[i, :] /= max(major_depth, 1)
                    elif self.normalise == 'fwd_rev':
                        # normalize fwd and reverse separately for each dtype
                        for dt, inds in inds_by_type.items():
                            feature_array[i, inds] /= max(
                                major_depths_by_type[dt], 1)

                depth_array[i] = major_depth

                if self.with_depth:
                    feature_array[i, self.encoding['depth']] = depth_array[i]

                if self.log_min is not None:
                    # Counts/proportions and depth will be log-transformed. To
                    # make life easier for the network, keep all logs positive
                    # by adding self.log_min; if self.log_min is 10, we can
                    # cope with depth up to 10**9.
                    feature_array[i, :] = np.log10(feature_array[i, :],
                                                   out=feature_array[i, :])
                    feature_array[i, :] += self.log_min
                    feature_array[i, :] = np.nan_to_num(feature_array[i, :],
                                                        copy=False)

                if self.ref_mode == 'onehot':
                    feature_array[i, self.encoding[('ref', str(ref_base),
                                                    int(ref_len))]] = 1
                elif self.ref_mode == 'base_length':
                    feature_array[
                        i, self.encoding['ref_base']] = self.ref_base_encoding[
                            ref_base]
                    feature_array[i, self.encoding['ref_length']] = ref_len
                elif self.ref_mode == 'index':
                    # index of count which ref would contribute to were it a read
                    feature_array[i,
                                  self.encoding['ref_index']] = self.encoding[(
                                      False, min(ref_len,
                                                 self.max_hp_len), ref_base)]

            sample = Sample(ref_name=region.ref_name,
                            features=feature_array,
                            labels=None,
                            ref_seq=ref_array,
                            positions=positions,
                            label_probs=None)
            self.logger.info('Processed {} (median depth {})'.format(
                sample.name, np.median(depth_array)))

            return [sample]
Example #19
    def test_messy_overlap(self):
        dtype = [('major', int), ('minor', int)]
        pos = [
            np.array([(0, 0), (1, 0), (2, 0), (2, 1), (2, 2), (3, 0), (4, 0),
                      (4, 1), (4, 2), (4, 3), (5, 0), (6, 0), (7, 0), (8, 0),
                      (8, 1), (9, 0)], dtype=dtype),
            np.array([(3, 0), (4, 0), (4, 1), (4, 2),  # (4,3) missing
                      (5, 0), (6, 0), (7, 0), (8, 0), (8, 1), (9, 0),
                      (10, 0), (10, 1), (10, 2)], dtype=dtype),
            np.array([(3, 0), (4, 0), (4, 1), (4, 2), (5, 0), (6, 0), (6, 1),
                      (7, 0), (7, 1), (8, 0), (8, 1), (9, 0), (10, 0),
                      (10, 1), (10, 2)], dtype=dtype),
            np.array([(3, 0), (4, 0), (4, 1), (4, 2),  # (4,3) missing
                      (5, 0), (5, 1), (6, 0), (6, 1), (7, 0), (7, 1), (8, 0),
                      (8, 1), (9, 0), (10, 0), (10, 1), (10, 2)], dtype=dtype),
            np.array([(3, 0), (4, 0), (4, 1), (4, 2),  # (4,3) missing
                      (5, 0), (5, 1), (6, 0), (6, 1), (7, 0), (7, 1), (8, 0),
                      (9, 0), (10, 0), (10, 1), (10, 2)], dtype=dtype),
        ]

        sample = [
            Sample(ref_name='contig1',
                   features=None,
                   ref_seq=None,
                   labels=None,
                   positions=p,
                   label_probs=None) for p in pos
        ]

        expected = [
            (12, 6),  # (7, 0) is junction
            (10, 4),  # (5, 0) is junction
            (13, 10),  # (8, 0) is junction
            (15, 11),  # (9, 0) is junction
        ]
        for other, exp in enumerate(expected, 1):
            end, start, heuristic = Sample.overlap_indices(
                sample[0], sample[other])
            self.assertTrue(heuristic)
            self.assertEqual((end, start), exp)
            self.assertEqual(pos[0][exp[0]], pos[other][exp[1]])
Example #20
    def bams_to_training_samples(self,
                                 truth_bam,
                                 bam,
                                 region,
                                 reference=None,
                                 read_fraction=None):
        """Prepare training data chunks.

        :param truth_bam: .bam file of truth aligned to ref to generate labels.
        :param bam: input .bam file.
        :param region: `Region` obj.
            the reference will be parsed.
        :param reference: reference `.fasta`, should correspond to `bam`.

        :returns: tuple of `Sample` objects.

        .. note:: Chunks might be missing if `truth_bam` is provided and
            regions with multiple mappings were encountered.

        """
        ref_rle = self.process_ref_seq(region.ref_name, reference)

        # filter truth alignments to restrict ourselves to regions of the ref
        # where the truth is unambiguous
        alignments = TruthAlignment.bam_to_alignments(truth_bam,
                                                      region.ref_name,
                                                      start=region.start,
                                                      end=region.end)
        filtered_alignments = TruthAlignment.filter_alignments(
            alignments, start=region.start, end=region.end)
        if len(filtered_alignments) == 0:
            self.logger.info(
                "Filtering removed all alignments of truth to ref from {}.".
                format(region))

        samples = []
        for aln in filtered_alignments:
            mock_compr = self.max_hp_len > 1 and not self.is_compressed
            truth_pos, truth_labels = aln.get_positions_and_labels(
                ref_compr_rle=ref_rle,
                mock_compr=mock_compr,
                is_compressed=self.is_compressed,
                rle_dtype=True)
            aln_samples = self.bam_to_sample(bam,
                                             Region(region.ref_name, aln.start,
                                                    aln.end),
                                             ref_rle,
                                             read_fraction=read_fraction)
            for sample in aln_samples:
                # Create labels according to positions in pileup
                pad = (encoding[_gap_],
                       1) if len(truth_labels.dtype) > 0 else encoding[_gap_]
                padder = itertools.repeat(pad)
                position_to_label = defaultdict(
                    padder.__next__,
                    zip([tuple(p) for p in truth_pos],
                        [a for a in truth_labels]))
                padded_labels = np.fromiter(
                    (position_to_label[tuple(p)] for p in sample.positions),
                    dtype=truth_labels.dtype,
                    count=len(sample.positions))

                sample = sample._asdict()
                sample['labels'] = padded_labels
                samples.append(Sample(**sample))
        return tuple(samples)
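
The `defaultdict(padder.__next__, ...)` construction above maps each pileup position to its truth label and silently yields a gap label for any position the truth does not cover. The same trick in isolation (the gap value 0 is an illustrative stand-in for `encoding[_gap_]`):

import itertools
from collections import defaultdict
import numpy as np

GAP = 0  # illustrative stand-in for encoding[_gap_]
truth_pos = [(0, 0), (1, 0), (2, 0)]
truth_labels = np.array([1, 2, 3])

padder = itertools.repeat(GAP)
position_to_label = defaultdict(padder.__next__,
                                zip(truth_pos, truth_labels))

# the pileup contains a minor position, (1, 1), absent from the truth
pileup_positions = [(0, 0), (1, 0), (1, 1), (2, 0)]
padded = np.fromiter((position_to_label[p] for p in pileup_positions),
                     dtype=truth_labels.dtype,
                     count=len(pileup_positions))
assert padded.tolist() == [1, 2, 0, 3]
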