def test_relative_position(self):
    # all 9 cases plus major and minor variations of overlapping,
    # abutting and gapped:
    # different_ref_name = 'Samples come from different reference contigs.'
    # forward_overlap = 'The end of s1 overlaps the start of s2.'
    # reverse_overlap = 'The end of s2 overlaps the start of s1.'
    # forward_abutted = 'The end of s1 abuts the start of s2.'
    # reverse_abutted = 'The end of s2 abuts the start of s1.'
    # forward_gapped = 's2 follows s1 with a gap in between.'
    # reverse_gapped = 's1 follows s2 with a gap in between.'
    # s2_within_s1 = 's2 is fully contained within s1.'
    # s1_within_s2 = 's1 is fully contained within s2.'
    slices = [
        slice(0, 3),   # (0,0) -> (1,0) in self.samples[0]
        slice(3, 5),   # (2,0) -> (2,1) in self.samples[0]
        slice(5, 8),   # (2,2) -> (4,0) in self.samples[0]
        slice(8, 11),  # (4,1) -> (4,3) in self.samples[0]
    ]
    sliced = [self.samples[0].slice(sl) for sl in slices]
    sample_dict = self.samples[0]._asdict()
    sample_dict['ref_name'] = 'other'
    sample_other = Sample(**sample_dict)

    samples_expt = [
        ([self.samples[0], sample_other], Relationship.different_ref_name),
        (self.samples[:2], Relationship.forward_overlap),        # overlap of minor positions
        (self.samples[1:], Relationship.forward_overlap),        # overlap of major positions
        (self.samples[:2][::-1], Relationship.reverse_overlap),  # overlap of minor positions
        (self.samples[1:][::-1], Relationship.reverse_overlap),  # overlap of major positions
        (sliced[:2], Relationship.forward_abutted),         # (1,0) -> (2,0)
        (sliced[1:3], Relationship.forward_abutted),        # (2,1) -> (2,2)
        (sliced[::-1][2:], Relationship.reverse_abutted),   # (2,0) -> (1,0)
        (sliced[::-1][1:3], Relationship.reverse_abutted),  # (2,2) -> (2,1)
        ([sliced[0], sliced[2]], Relationship.forward_gapped),  # (1,0) -> (2,2)
        ([self.samples[0].slice(slice(0, 4)), sliced[2]],
         Relationship.forward_gapped),  # (2,0) -> (2,2)
        ([sliced[0], self.samples[0].slice(slice(6, None))],
         Relationship.forward_gapped),  # (1,0) -> (3,0)
        ([sliced[2], self.samples[0].slice(slice(0, 4))],
         Relationship.reverse_gapped),  # (2,2) -> (2,0)
        ([self.samples[0].slice(slice(6, None)), sliced[0]],
         Relationship.reverse_gapped),  # (3,0) -> (1,0)
        ([self.samples[0], sliced[0]], Relationship.s2_within_s1),
        ([self.samples[0], sliced[1]], Relationship.s2_within_s1),
        ([self.samples[0], sliced[2]], Relationship.s2_within_s1),
        ([self.samples[0], sliced[3]], Relationship.s2_within_s1),
        ([sliced[0], self.samples[0]], Relationship.s1_within_s2),
        ([sliced[1], self.samples[0]], Relationship.s1_within_s2),
        ([sliced[2], self.samples[0]], Relationship.s1_within_s2),
        ([sliced[3], self.samples[0]], Relationship.s1_within_s2),
    ]
    for samples, expt in samples_expt:
        self.assertIs(Sample.relative_position(*samples), expt)
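# Illustrative sketch (not part of the test suite): the (major, minor)
# coordinates referenced in the comments above follow the pileup convention
# where `major` indexes a reference column and `minor` > 0 marks inserted
# columns following it. Slicing the fixture's positions (see setUpClass)
# at slice(3, 5) therefore spans (2,0) -> (2,1):
import numpy as np

_pos = np.array([(0, 0), (0, 1), (1, 0), (2, 0), (2, 1)],
                dtype=[('major', int), ('minor', int)])
assert tuple(_pos[3]) == (2, 0) and tuple(_pos[4]) == (2, 1)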
def test_from_samples(self):
    # check we can concat a single sample
    concat1 = Sample.from_samples([self.samples[0]])
    self.assertEqual(concat1, self.samples[0])
    # check we can concat 3 samples
    slices = [slice(0, 5), slice(5, 8), slice(8, 11)]  # these should span entire samples[0]
    sliced = [self.samples[0].slice(sl) for sl in slices]
    concat3 = Sample.from_samples(sliced)
    self.assertEqual(concat3, self.samples[0])
    # also check an exception is raised if samples are not forward abutting
    self.assertRaises(ValueError, Sample.from_samples, sliced[::-1])
def _get_sorted_index(self):
    """Get index of samples indexed by reference and ordered by start position.

    :returns: {ref_name: [sample dicts sorted by start]}
    """
    ref_names = defaultdict(list)
    for key, f in self.samples:
        d = Sample.decode_sample_name(key)
        if d is not None:
            d['key'] = key
            d['filename'] = f
            ref_names[d['ref_name']].append(d)
    # sort dicts so that refs are in order and, within a ref, chunks are in order
    ref_names_ordered = OrderedDict()
    get_major_minor = lambda x: tuple(int(i) for i in x.split('.'))
    # sort by start and negated end so that if two samples share the same
    # start but have different end points, the longest sample comes first
    sorter = lambda x: (get_major_minor(x['start'])
                        + tuple(-i for i in get_major_minor(x['end'])))
    for ref_name in sorted(ref_names.keys()):
        ref_names[ref_name].sort(key=sorter)
        ref_names_ordered[ref_name] = ref_names[ref_name]
    return ref_names_ordered
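# Minimal sketch of the ordering implemented above: starts ascend and, for
# equal starts, the negated end puts the longest chunk first. Plain dicts
# stand in for the decoded sample-name records.
_get_mm = lambda x: tuple(int(i) for i in x.split('.'))
_sorter = lambda x: _get_mm(x['start']) + tuple(-i for i in _get_mm(x['end']))
_chunks = [{'start': '0.0', 'end': '2.0'},
           {'start': '0.0', 'end': '4.3'},
           {'start': '4.1', 'end': '7.0'}]
_chunks.sort(key=_sorter)
assert [c['end'] for c in _chunks] == ['4.3', '2.0', '7.0']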
def test_overlap_indices(self):
    # check we get the right thing for overlapping and abutting major and
    # minor positions, and an exception for any other case.
    slices = [
        slice(0, 3),   # (0,0) -> (1,0) in self.samples[0]
        slice(3, 5),   # (2,0) -> (2,1) in self.samples[0]
        slice(5, 8),   # (2,2) -> (4,0) in self.samples[0]
        slice(8, 11),  # (4,1) -> (4,3) in self.samples[0]
    ]
    sliced = [self.samples[0].slice(sl) for sl in slices]
    samples_expt = [
        (self.samples[:2], (9, 1, False)),  # overlap of minor inds with odd number of overlapping positions
        (self.samples[1:], (6, 2, False)),  # overlap of major inds with even number of overlapping positions
        (sliced[:2], (None, None, False)),   # abuts
        (sliced[1:3], (None, None, False)),  # abuts
        (sliced[2:], (None, None, False)),   # abuts
    ]
    for samples, expt in samples_expt:
        self.assertEqual(Sample.overlap_indices(*samples), expt)
        self.assertRaises(OverlapException, Sample.overlap_indices,
                          samples[1], samples[0])
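# Sketch of the contract exercised above (semantics inferred from the
# expected tuples, not a definitive statement of the API): overlap_indices
# returns (end_1, start_2, heuristic) such that s1[:end_1] + s2[start_2:]
# contains each overlapping column exactly once; abutting samples need no
# trimming, hence (None, None, False). With plain lists of positions:
_s1 = [(4, 0), (4, 1), (4, 2), (4, 3)]
_s2 = [(4, 1), (4, 2), (4, 3), (4, 5)]
_end_1, _start_2 = 2, 1  # hypothetical trim points splitting the overlap
assert _s1[:_end_1] + _s2[_start_2:] == [(4, 0), (4, 1), (4, 2), (4, 3), (4, 5)]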
def test_gapped(self):
    # check we get is_last_in_contig if we have a gap between samples
    trimmed = list(Sample.trim_samples(s for s in self.samples[::2]))
    for expt, (got, is_last_in_contig, heuristic) in zip(
            self.samples[::2], trimmed):
        self.assertEqual(got, expt)
        self.assertTrue(is_last_in_contig)
def test_single_sample(self):
    # test that if we provide a single sample, we get the same sample back
    results = list(Sample.trim_samples(iter([self.samples[0]])))
    self.assertEqual(len(results), 1)
    got, is_last_in_contig, heuristic = results[0]
    self.assertEqual(got, self.samples[0])
    self.assertTrue(is_last_in_contig)
def test_works(self):
    # test simple case of 3 chained samples
    trimmed = list(Sample.trim_samples(s for s in self.samples))
    for i, (expt, (got, is_last_in_contig, heuristic)) in enumerate(
            zip(self.sliced, trimmed)):
        self.assertEqual(got, expt)
        if i == len(self.sliced) - 1:
            self.assertTrue(is_last_in_contig)
        else:
            self.assertFalse(is_last_in_contig)
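# Toy consumer sketch (assumed semantics, inferred from the three tests
# above): trim_samples yields (sample, is_last_in_contig, heuristic)
# triples, so flushing on the flag regroups trimmed chunks by contig.
# `stitch` is an illustrative name, not a medaka function.
def stitch(trimmed_stream):
    buffer = []
    for sample, is_last_in_contig, _heuristic in trimmed_stream:
        buffer.append(sample)
        if is_last_in_contig:
            yield buffer
            buffer = []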
def run_prediction(output, bam, regions, model, model_file, rle_ref,
                   read_fraction, chunk_len, chunk_ovlp, batch_size=200,
                   save_features=False, tag_name=None, tag_value=None,
                   tag_keep_missing=False):
    """Inference worker."""
    logger = get_named_logger('PWorker')

    def sample_gen():
        # chain all samples whilst dispensing with generators when done
        # (they hold the feature vector in memory until they die)
        for region in regions:
            data_gen = SampleGenerator(
                bam, region, model_file, rle_ref, read_fraction,
                chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
                tag_name=tag_name, tag_value=tag_value,
                tag_keep_missing=tag_keep_missing)
            yield from data_gen.samples

    batches = background_generator(grouper(sample_gen(), batch_size), 10)

    total_region_mbases = sum(r.size for r in regions) / 1e6
    logger.info("Running inference for {:.1f}M draft bases.".format(
        total_region_mbases))

    with DataStore(output, 'a') as ds:
        mbases_done = 0
        t0 = now()
        tlast = t0
        for data in batches:
            x_data = np.stack([x.features for x in data])
            class_probs = model.predict_on_batch(x_data)
            mbases_done += sum(x.span for x in data) / 1e6
            # clip to the total to avoid a misleading progress message
            mbases_done = min(mbases_done, total_region_mbases)
            t1 = now()
            if t1 - tlast > 10:
                tlast = t1
                msg = '{:.1%} Done ({:.1f}/{:.1f} Mbases) in {:.1f}s'
                logger.info(msg.format(mbases_done / total_region_mbases,
                                       mbases_done, total_region_mbases,
                                       t1 - t0))
            best = np.argmax(class_probs, -1)
            for sample, prob, pred, feat in zip(data, class_probs, best, x_data):
                # write out positions and predictions for later analysis
                sample_d = sample._asdict()
                sample_d['label_probs'] = prob
                sample_d['features'] = feat if save_features else None
                ds.write_sample(Sample(**sample_d))

    logger.info('All done')
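# For reference, a minimal batching helper in the spirit of the `grouper`
# call above. This is an illustrative stand-in: medaka's own `grouper` may
# differ in signature and padding behaviour.
import itertools

def batched(iterable, n):
    """Yield lists of up to n items; the final batch may be short."""
    it = iter(iterable)
    while True:
        batch = list(itertools.islice(it, n))
        if not batch:
            return
        yield batch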
def setUpClass(cls):
    pos = np.array(
        [(0, 0), (0, 1), (1, 0), (2, 0), (2, 1), (2, 2), (3, 0),
         (4, 0), (4, 1), (4, 2), (4, 3)],
        dtype=[('major', int), ('minor', int)])
    data_dim = 10
    data = np.zeros(shape=(len(pos), data_dim))
    cls.sample = Sample(
        ref_name='contig1', features=data, ref_seq=None, labels=data,
        positions=pos, label_probs=data)
    cls.file = tempfile.NamedTemporaryFile()
    with datastore.DataStore(cls.file.name, 'w') as store:
        store.write_sample(cls.sample)
def test_decode_sample_name(self):
    expected = [
        {'ref_name': 'contig1', 'start': '0.0', 'end': '4.3'},
        {'ref_name': 'contig1', 'start': '4.1', 'end': '7.0'},
    ]
    for expt, sample in zip(expected, self.samples):
        self.assertEqual(expt, Sample.decode_sample_name(sample.name))
def load_sample(self, key):
    """Load `Sample` object from HDF5.

    :param key: str, sample name.

    :returns: `Sample` object.
    """
    s = {}
    for field in Sample._fields:
        pth = '{}/{}/{}'.format(self._sample_path_, key, field)
        if pth in self.fh:
            s[field] = self.fh[pth][()]
            if isinstance(s[field], np.ndarray) and isinstance(s[field][0], bytes):
                s[field] = np.char.decode(s[field])
        else:
            s[field] = None
    return Sample(**s)
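# The path template above implies a layout of one HDF5 dataset per Sample
# field, i.e. {_sample_path_}/{sample_name}/{field} (inferred from the code,
# not a documented spec). Byte-string datasets are restored to unicode with
# np.char.decode:
import numpy as np

_raw = np.array([b'contig1', b'contig2'])
assert np.char.decode(_raw).dtype.kind == 'U'  # back to str on load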
def test_030_bams_to_training_samples_simple(self):
    reads_bam = tempfile.NamedTemporaryFile(suffix='.bam').name
    truth_bam = tempfile.NamedTemporaryFile(suffix='.bam').name
    # we had a bug caused by missing qualities and bad indexing...
    data = copy.deepcopy(simple_data['calls'])
    data[0]['quality'] = None
    create_simple_bam(reads_bam, data)
    create_simple_bam(truth_bam, [simple_data['truth']])
    encoder = medaka.features.CountsFeatureEncoder(normalise='total')
    label_scheme = medaka.labels.HaploidLabelScheme()
    region = Region('ref', 0, 100)
    result = encoder.bams_to_training_samples(
        truth_bam, reads_bam, region, label_scheme, min_length=0)[0]

    expected = Sample(
        ref_name='ref',
        features=np.array([
            [0.5, 0., 0., 0., 0.5, 0., 0., 0., 0., 0.],
            [0., 0.5, 0., 0., 0., 0.5, 0., 0., 0., 0.],
            [0.5, 0., 0., 0., 0.5, 0., 0., 0., 0., 0.],
            [0., 0.25, 0., 0.25, 0., 0., 0., 0.25, 0., 0.25],
            [0.25, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0.5, 0., 0., 0., 0.5, 0., 0., 0.],
            [0.5, 0., 0., 0., 0.5, 0., 0., 0., 0., 0.],
            [0., 0., 0., 0.5, 0., 0., 0., 0.5, 0., 0.],
            [0., 0., 0.5, 0., 0., 0., 0.5, 0., 0., 0.]],
            dtype='float32'),
        # the two insertions with respect to the draft are dropped
        labels=np.array([1, 2, 1, 4, 1, 3, 1, 4, 3]),  # A C A T A G A T C
        ref_seq=None,
        positions=np.array(
            [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (4, 0), (5, 0),
             (6, 0), (7, 0)],
            dtype=[('major', '<i8'), ('minor', '<i8')]),
        label_probs=None)

    np.testing.assert_equal(result.labels, expected.labels)
    np.testing.assert_equal(result.positions, expected.positions)
    np.testing.assert_equal(result.features, expected.features)
def setUpClass(cls):
    pos1 = np.array(
        [(0, 0), (0, 1), (1, 0), (2, 0), (2, 1), (2, 2), (3, 0),
         (4, 0), (4, 1), (4, 2), (4, 3)],
        dtype=[('major', int), ('minor', int)])
    pos2 = np.array(
        [(4, 1), (4, 2), (4, 3), (4, 5), (5, 0), (6, 0), (6, 1), (7, 0)],
        dtype=[('major', int), ('minor', int)])
    pos3 = np.array(
        [(5, 0), (6, 0), (6, 1), (7, 0), (7, 1)],
        dtype=[('major', int), ('minor', int)])
    cls.samples = []
    data_dim = 10
    for pos in pos1, pos2, pos3:
        data = np.random.random_sample(size=data_dim * len(pos)).reshape(
            (len(pos), data_dim))
        cls.samples.append(
            Sample(ref_name='contig1', features=data, ref_seq=None,
                   labels=data, positions=pos, label_probs=data))
def _get_sorted_index(self):
    """Get index of samples indexed by reference and ordered by start position.

    :returns: {ref_name: [sample dicts sorted by start]}
    """
    ref_names = defaultdict(list)
    for key, f in self.samples:
        d = Sample.decode_sample_name(key)
        if d is not None:
            d['key'] = key
            d['filename'] = f
            ref_names[d['ref_name']].append(d)
    # sort dicts so that refs are in order and, within a ref, chunks are in order
    ref_names_ordered = OrderedDict()
    # 'start' is a 'major.minor' string; parse it into an int tuple rather
    # than casting to float, which would mis-order wide minor components
    # (e.g. '4.10' would sort before '4.2')
    sorter = lambda x: tuple(int(i) for i in x['start'].split('.'))
    for ref_name in sorted(ref_names.keys()):
        ref_names[ref_name].sort(key=sorter)
        ref_names_ordered[ref_name] = ref_names[ref_name]
    return ref_names_ordered
def test_is_empty(self):
    self.assertFalse(self.samples[0].is_empty)
    empty_sample = Sample('contig', None, None, None,
                          self.samples[0][0:0], None)
    self.assertTrue(empty_sample.is_empty)
def test_raises(self):
    # test we get an exception if we e.g. provide samples from different
    # refs or samples out of order
    self.assertRaises(
        OverlapException,
        lambda x: list(Sample.trim_samples(x)),
        iter(self.samples[::-1]))
def bam_to_sample_c(self, reads_bam, region):
    """Convert a section of an alignment pileup (as shown by e.g. samtools
    tview) to a base frequency feature array.

    :param reads_bam: (sorted indexed) bam with read alignment to reference.
    :param region: `Region` object with ref_name, start and end attributes.

    :returns: list of `Sample` objects.
    """
    assert self.ref_mode is None
    assert not self.consensus_as_ref
    assert self.max_hp_len == 1
    assert self.log_min is None
    assert self.normalise in ('total', 'fwd_rev', None)
    assert not self.with_depth
    assert not self.is_compressed

    pileup = pileup_counts(
        region, reads_bam, dtype_prefixes=self.dtypes,
        tag_name=self.tag_name, tag_value=self.tag_value,
        keep_missing=self.tag_keep_missing)
    samples = list()
    for counts, positions in pileup:
        if len(counts) == 0:
            msg = ('Pileup-feature is zero-length for {} indicating no '
                   'reads in this region.').format(region)
            self.logger.warning(msg)
            samples.append(Sample(
                ref_name=region.ref_name, features=None, labels=None,
                ref_seq=None, positions=positions, label_probs=None))
            continue

        start, end = positions['major'][0], positions['major'][-1]
        if start != region.start or end + 1 != region.end:
            # TODO investigate off-by-one
            self.logger.warning(
                'Pileup counts do not span requested region, requested {}, '
                'received {}-{}.'.format(region, start, end))

        # find the position index of the parent major position for all
        # minor positions
        minor_inds = np.where(positions['minor'] > 0)
        major_pos_at_minor_inds = positions['major'][minor_inds]
        major_ind_at_minor_inds = np.searchsorted(
            positions['major'], major_pos_at_minor_inds, side='left')

        depth = np.sum(counts, axis=1)
        depth[minor_inds] = depth[major_ind_at_minor_inds]

        if self.sym_indels:
            # make indels at ref and non-ref positions symmetric.
            # major columns otherwise have counts of reads with and without
            # a deletion, whilst minor (inserted) columns only have counts
            # of the reads with an insertion. To make ref and non-ref
            # positions symmetric, fill in counts of reads which don't have
            # insertions, i.e. depth_del = depth_major - depth_ins
            for (dt, is_rev), inds in self.feature_indices.items():
                dt_depth = np.sum(counts[:, inds], axis=1)
                del_ind = self.encoding[(dt, is_rev, None, 1)]
                counts[minor_inds, del_ind] = \
                    dt_depth[major_ind_at_minor_inds] - dt_depth[minor_inds]

        if self.normalise == 'total':
            # normalise counts by total depth at the parent major position;
            # since the counts include deletions this is a count of spanning
            # reads (maximum with 1 just to avoid division by zero)
            feature_array = counts / np.maximum(1, depth).reshape((-1, 1))
        elif self.normalise == 'fwd_rev':
            # normalise forward and reverse counts separately for each dtype
            feature_array = np.empty_like(counts, dtype=float)
            for (dt, is_rev), inds in self.feature_indices.items():
                dt_depth = np.sum(counts[:, inds], axis=1)
                dt_depth[minor_inds] = dt_depth[major_ind_at_minor_inds]
                feature_array[:, inds] = \
                    counts[:, inds] / np.maximum(1, dt_depth).reshape((-1, 1))
        else:
            feature_array = counts

        feature_array = feature_array.astype(self.feature_dtype)
        sample = Sample(
            ref_name=region.ref_name, features=feature_array, labels=None,
            ref_seq=None, positions=positions, label_probs=None)
        samples.append(sample)
        self.logger.info('Processed {} (median depth {})'.format(
            sample.name, np.median(depth)))
    return samples
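# Illustrative numpy sketch of the minor-column depth propagation above:
# each inserted (minor > 0) column inherits the spanning-read depth of its
# parent major column, located with searchsorted on the sorted majors.
import numpy as np

_positions = np.array([(0, 0), (1, 0), (1, 1), (1, 2), (2, 0)],
                      dtype=[('major', int), ('minor', int)])
_depth = np.array([7, 5, 3, 2, 6])
_minor_inds = np.where(_positions['minor'] > 0)
_major_inds = np.searchsorted(
    _positions['major'], _positions['major'][_minor_inds], side='left')
_depth[_minor_inds] = _depth[_major_inds]
assert _depth.tolist() == [7, 5, 5, 5, 6]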
def bam_to_sample(self, reads_bam, region, reference=None,
                  read_fraction=None, force_py=False):
    """Convert a section of an alignment pileup (as shown by e.g. samtools
    tview) to a base frequency feature array.

    :param reads_bam: (sorted indexed) bam with read alignment to reference.
    :param region: `Region` object with ref_name, start and end attributes.
    :param reference: reference `.fasta`, should correspond to `bam`.
        Required only for run length encoded references and reads.
    :param read_fraction: fraction of reads to use, if `None` use all.
    :param force_py: bool, if True, force use of python code (rather than
        the c library).

    :returns: iterable of `Sample` objects.
    """
    ref_rle = self.process_ref_seq(region.ref_name, reference)

    # try to use the fast c function if we can, else fall back on this one
    if not force_py and (ref_rle is None and read_fraction is None):
        try:
            return self.bam_to_sample_c(reads_bam, region)
        except Exception as e:
            self.logger.info(
                'Could not process sample with bam_to_sample_c, using '
                'python code instead.\n({}).'.format(e))

    if self.tag_name is not None:
        raise NotImplementedError(
            "Filtering alignments by tag is not supported in python code.")

    # TODO: The code below will abut discontiguous regions in a pileup, i.e.
    #       where no reads span a reference position, the major position is
    #       dropped from the pileup. The correct behaviour would be to split
    #       apart the sub-regions and return them separately; the C
    #       implementation does this splitting.

    if self.is_compressed:
        aln_to_pairs = partial(yield_compressed_pairs, ref_rle=ref_rle)
    elif self.max_hp_len == 1:
        aln_to_pairs = get_pairs
    else:
        aln_to_pairs = partial(get_pairs_with_hp_len, ref_seq=ref_rle)

    # accumulate data in dicts
    aln_counters = defaultdict(Counter)
    ref_bases = dict()
    with pysam.AlignmentFile(reads_bam, 'rb') as bamfile:
        aln_reads = bamfile.fetch(region.ref_name, region.start, region.end)
        if read_fraction is not None:
            low, high = read_fraction
            np.random.seed((int(now()) * region.start) % 2**32)
            fraction = ((high - low) * np.random.random_sample(1) + low)[0]
            aln_reads = [a for a in aln_reads]
            n_reads = len(aln_reads)
            n_reads_to_keep = max(int(fraction * n_reads), 1)
            replace = n_reads_to_keep > n_reads
            msg = "Resampling (replace {}) from {} to {} ({:.3f}) for {}"
            self.logger.debug(msg.format(
                replace, n_reads, n_reads_to_keep, fraction, region))
            aln_reads = np.random.choice(
                aln_reads, n_reads_to_keep, replace=replace)

        start = region.start
        end = region.end
        if start is None:
            start = 0
        if end is None:
            end = float('Inf')

        for aln in aln_reads:
            # get the dtype from the prefix of the query name
            try:
                dtype = self.dtypes[np.where([
                    aln.query_name.startswith(dt)
                    for dt in self.dtypes])[0][0]]
            except IndexError:
                msg = "Skipping read {} as dtype not in {}"
                self.logger.info(msg.format(aln.query_name, self.dtypes))
                continue

            reverse = aln.is_reverse
            pairs = aln_to_pairs(aln)
            ins_count = 0
            for pair in itertools.dropwhile(
                    lambda x: (x.rpos is None) or (x.rpos < start), pairs):
                if ((pair.rpos == aln.reference_end - 1)
                        or (pair.rpos is not None and pair.rpos >= end)):
                    break
                if pair.rpos is None:
                    ins_count += 1
                else:
                    ins_count = 0
                    current_pos = pair.rpos

                aln_counters[(current_pos, ins_count)][self.encoding[
                    dtype, reverse, pair.qbase,
                    min(pair.qlen, self.max_hp_len)]] += 1
                ref_base = pair.rbase.upper() if pair.rbase is not None else '*'
                ref_bases[(current_pos, ins_count)] = (ref_base, pair.rlen)

    # create feature array
    aln_cols = len(aln_counters)
    feature_len = len(self.encoding)
    feature_array = np.zeros(shape=(aln_cols, feature_len),
                             dtype=self.feature_dtype)
    if self.log_min is not None:
        feature_array.fill(np.nan)
    ref_array = np.empty(
        shape=(aln_cols), dtype=[('base', int), ('run_length', int)])
    positions = np.empty(aln_cols, dtype=[('major', int), ('minor', int)])

    if aln_cols == 0:
        msg = ('Pileup-feature is zero-length for {} indicating no reads '
               'in this region.').format(region)
        self.logger.warning(msg)
        return [Sample(
            ref_name=region.ref_name, features=None, labels=None,
            ref_seq=None, positions=positions, label_probs=None)]

    depth_array = np.empty(shape=(aln_cols), dtype=int)

    # keep track of which features are for fwd/rev reads of each dtype
    inds_by_type = self.feature_indices

    # TODO: refactor so common combinations of options can be handled as in
    #       the C function
    for i, ((pos, counts), (_, (ref_base, ref_len))) in \
            enumerate(zip(sorted(aln_counters.items()),
                          sorted(ref_bases.items()))):
        positions[i] = pos
        ref_array[i] = (encoding[ref_base], ref_len)
        for j in counts.keys():
            feature_array[i, j] = counts[j]

        if self.consensus_as_ref:
            cons_i = np.argmax(feature_array[i])
            cons_is_reverse, cons_base, cons_length = self.decoding[cons_i]
            ref_base = cons_base if cons_base is not None else _gap_
            ref_len = cons_length

        if positions[i]['minor'] == 0:
            major_depth = sum(counts.values())
            # get the depth of each fwd and rev dtype
            major_depths_by_type = {
                t: sum(counts[i] for i in inds_by_type[t])
                for t in inds_by_type}
            assert sum(major_depths_by_type.values()) == major_depth

        if self.sym_indels and positions[i]['minor'] > 0:
            # make indels at ref and non-ref positions symmetric
            # (see comment in bam_to_sample_c).
            for (dtype, is_rev), inds in inds_by_type.items():
                del_ind = self.encoding[(dtype, is_rev, None, 1)]
                assert feature_array[i, del_ind] == 0
                feature_array[i, del_ind] = (
                    major_depths_by_type[(dtype, is_rev)]
                    - feature_array[i, inds].sum())

        if self.normalise is not None:
            if self.normalise == 'total':
                feature_array[i, :] /= max(major_depth, 1)
            elif self.normalise == 'fwd_rev':
                # normalise fwd and rev counts separately for each dtype
                for dt, inds in inds_by_type.items():
                    feature_array[i, inds] /= max(
                        major_depths_by_type[dt], 1)

        depth_array[i] = major_depth

        if self.with_depth:
            feature_array[i, self.encoding['depth']] = depth_array[i]

        if self.log_min is not None:
            # counts/proportions and depth will be normalised. When taking
            # the log of probabilities, make it easier for the network by
            # keeping all log-probabilities positive: add self.log_min to
            # every value. If self.log_min is 10, we can cope with depths
            # up to 10**9.
            feature_array[i, :] = np.log10(
                feature_array[i, :], out=feature_array[i, :])
            feature_array[i, :] += self.log_min
            feature_array[i, :] = np.nan_to_num(
                feature_array[i, :], copy=False)

        if self.ref_mode == 'onehot':
            feature_array[i, self.encoding[
                ('ref', str(ref_base), int(ref_len))]] = 1
        elif self.ref_mode == 'base_length':
            feature_array[i, self.encoding['ref_base']] = \
                self.ref_base_encoding[ref_base]
            feature_array[i, self.encoding['ref_length']] = ref_len
        elif self.ref_mode == 'index':
            # index of the count which the ref would contribute to were it
            # a read
            feature_array[i, self.encoding['ref_index']] = self.encoding[
                (False, min(ref_len, self.max_hp_len), ref_base)]

    sample = Sample(
        ref_name=region.ref_name, features=feature_array, labels=None,
        ref_seq=ref_array, positions=positions, label_probs=None)
    self.logger.info('Processed {} (median depth {})'.format(
        sample.name, np.median(depth_array)))
    return [sample]
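# Sketch of the 'total' normalisation used in both implementations above:
# counts are divided by the spanning-read depth, clamped at 1 so empty
# columns do not cause division by zero.
import numpy as np

_counts = np.array([[4., 6.], [0., 0.]])
_col_depth = _counts.sum(axis=1)
_normed = _counts / np.maximum(1, _col_depth).reshape((-1, 1))
assert _normed.tolist() == [[0.4, 0.6], [0.0, 0.0]]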
def test_messy_overlap(self):
    dtype = [('major', int), ('minor', int)]
    pos = [
        np.array([
            (0, 0), (1, 0), (2, 0), (2, 1), (2, 2), (3, 0), (4, 0),
            (4, 1), (4, 2), (4, 3), (5, 0), (6, 0), (7, 0), (8, 0),
            (8, 1), (9, 0),
        ], dtype=dtype),
        np.array([
            (3, 0), (4, 0), (4, 1), (4, 2),  # (4,3) missing
            (5, 0), (6, 0), (7, 0), (8, 0), (8, 1), (9, 0),
            (10, 0), (10, 1), (10, 2),
        ], dtype=dtype),
        np.array([
            (3, 0), (4, 0), (4, 1), (4, 2), (5, 0), (6, 0), (6, 1),
            (7, 0), (7, 1), (8, 0), (8, 1), (9, 0),
            (10, 0), (10, 1), (10, 2),
        ], dtype=dtype),
        np.array([
            (3, 0), (4, 0), (4, 1), (4, 2),  # (4,3) missing
            (5, 0), (5, 1), (6, 0), (6, 1), (7, 0), (7, 1),
            (8, 0), (8, 1), (9, 0), (10, 0), (10, 1), (10, 2),
        ], dtype=dtype),
        np.array([
            (3, 0), (4, 0), (4, 1), (4, 2),  # (4,3) missing
            (5, 0), (5, 1), (6, 0), (6, 1), (7, 0), (7, 1),
            (8, 0), (9, 0), (10, 0), (10, 1), (10, 2),
        ], dtype=dtype),
    ]
    samples = [
        Sample(ref_name='contig1', features=None, ref_seq=None,
               labels=None, positions=p, label_probs=None)
        for p in pos
    ]
    expected = [
        (12, 6),   # (7, 0) is junction
        (10, 4),   # (5, 0) is junction
        (13, 10),  # (8, 0) is junction
        (15, 11),  # (9, 0) is junction
    ]
    for other, exp in enumerate(expected, 1):
        end, start, heuristic = Sample.overlap_indices(
            samples[0], samples[other])
        self.assertTrue(heuristic)
        self.assertEqual((end, start), exp)
        self.assertEqual(pos[0][exp[0]], pos[other][exp[1]])
def bams_to_training_samples(self, truth_bam, bam, region,
                             reference=None, read_fraction=None):
    """Prepare training data chunks.

    :param truth_bam: .bam file of truth aligned to ref to generate labels.
    :param bam: input .bam file.
    :param region: `Region` obj: the region of the reference to be parsed.
    :param reference: reference `.fasta`, should correspond to `bam`.
    :param read_fraction: fraction of reads to use, if `None` use all.

    :returns: tuple of `Sample` objects.

    .. note:: Chunks might be missing if `truth_bam` is provided and
        regions with multiple mappings were encountered.
    """
    ref_rle = self.process_ref_seq(region.ref_name, reference)
    # filter truth alignments to restrict ourselves to regions of the ref
    # where the truth is unambiguous
    alignments = TruthAlignment.bam_to_alignments(
        truth_bam, region.ref_name, start=region.start, end=region.end)
    filtered_alignments = TruthAlignment.filter_alignments(
        alignments, start=region.start, end=region.end)
    if len(filtered_alignments) == 0:
        self.logger.info(
            "Filtering removed all alignments of truth to ref from {}.".format(
                region))

    samples = []
    for aln in filtered_alignments:
        mock_compr = self.max_hp_len > 1 and not self.is_compressed
        truth_pos, truth_labels = aln.get_positions_and_labels(
            ref_compr_rle=ref_rle, mock_compr=mock_compr,
            is_compressed=self.is_compressed, rle_dtype=True)
        aln_samples = self.bam_to_sample(
            bam, Region(region.ref_name, aln.start, aln.end), ref_rle,
            read_fraction=read_fraction)
        for sample in aln_samples:
            # create labels according to positions in the pileup
            pad = (encoding[_gap_], 1) if len(truth_labels.dtype) > 0 \
                else encoding[_gap_]
            padder = itertools.repeat(pad)
            position_to_label = defaultdict(
                padder.__next__,
                zip([tuple(p) for p in truth_pos],
                    [a for a in truth_labels]))
            padded_labels = np.fromiter(
                (position_to_label[tuple(p)] for p in sample.positions),
                dtype=truth_labels.dtype, count=len(sample.positions))
            sample = sample._asdict()
            sample['labels'] = padded_labels
            samples.append(Sample(**sample))
    return tuple(samples)
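# Sketch of the label-padding trick above: a defaultdict whose factory is
# the __next__ of an endless pad iterator returns the gap label for any
# pileup position absent from the truth alignment. The pad value 0 here is
# a stand-in for encoding[_gap_].
import itertools
from collections import defaultdict

_padder = itertools.repeat(0)
_pos_to_label = defaultdict(_padder.__next__, {(0, 0): 1, (1, 0): 2})
assert [_pos_to_label[p] for p in [(0, 0), (0, 1), (1, 0)]] == [1, 0, 2]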