コード例 #1
0
def test_ambiguous_mask2(tmpdir):
    """Rows whose labels are all ambiguous are dropped by ``ambiguous_mask``.

    Only chr3 is fully ambiguous (-1, -1); chr1 and chr2 each carry real
    labels. With ``ambiguous_mask=-1`` the dataset is expected to keep just
    the two informative rows.
    """
    rows = [
        'chr1\t1\t2\t1\t0',
        'chr2\t1\t3\t0\t1',
        'chr3\t1\t3\t-1\t-1',
    ]
    dataset = BedDataset(write_tmp('\n'.join(rows), tmpdir), ambiguous_mask=-1)
    assert len(dataset) == 2
    # Each surviving row must still have at least one non-ambiguous label.
    row_max = dataset.get_targets().max(axis=1)
    assert np.all(row_max >= 0)
コード例 #2
0
def test_ambiguous_mask(tmpdir):
    """A partially ambiguous row (0, -1) is kept with or without masking."""
    contents = '\n'.join([
        'chr1\t1\t2\t1\t0',
        'chr2\t1\t3\t0\t1',
        'chr3\t1\t3\t0\t-1',
    ])
    bed_file = write_tmp(contents, tmpdir)
    expected_last = np.array([0, -1])

    # Default construction (no ambiguous_mask argument).
    default_ds = BedDataset(bed_file)
    assert len(default_ds) == 3
    assert np.all(default_ds[2][1] == expected_last)

    # Masking -1 must not drop a row that still has a real label.
    masked_ds = BedDataset(bed_file, ambiguous_mask=-1)
    assert len(masked_ds) == 3
    assert np.all(masked_ds[2][1] == expected_last)
    assert np.all(masked_ds.get_targets().max(axis=1) >= 0)
コード例 #3
0
def test_bed3_labels(tmpdir):
    """Columns after chrom/start/end are parsed into per-row label vectors."""
    contents = '\n'.join(['chr1\t1\t2\t1\t0', 'chr1\t1\t3\t0\t1'])
    dataset = BedDataset(write_tmp(contents, tmpdir))

    expected_targets = np.array([[1, 0], [0, 1]])
    assert np.all(dataset.get_targets() == expected_targets)
    assert len(dataset) == 2
    assert dataset.n_tasks == 2
    assert np.all(dataset.df[0] == 'chr1')

    # Each item is an (interval, labels) pair.
    expected_items = [
        (Interval("chr1", 1, 2), [1, 0]),
        (Interval("chr1", 1, 3), [0, 1]),
    ]
    for position, (interval, labels) in enumerate(expected_items):
        assert dataset[position][0] == interval
        assert np.all(dataset[position][1] == np.array(labels))
    assert len(dataset) == 2
コード例 #4
0
def test_label_dtype(tmpdir):
    """``label_dtype`` controls the dtype of item labels and of get_targets()."""
    dataset = BedDataset(
        write_tmp('\n'.join(['chr1\t1\t2\t1\t0', 'chr2\t1\t3\t0\t1']), tmpdir),
        label_dtype=bool,
    )
    assert len(dataset) == 2
    first_labels = dataset[0][1]
    assert first_labels.dtype == bool
    assert dataset.get_targets().dtype == bool
コード例 #5
0
class ActivityDataset(Dataset):
    """Pair a 1000 bp DNA sequence with summed bigwig signal per task.

    Each item comes from one row of a bed4 intervals file. The interval is
    resized to 1000 bp for the sequence input. A resized copy of width
    ``track_width`` is used to read the bigwig tracks that form the targets.

    Args:
        intervals_file: bed4 file containing chrom  start  end  name
        fasta_file: file path; Genome sequence
        bigwigs: dict mapping each task (target) name to a list of bigwig
          file paths; the signal of all files in one list is added together.
        track_width: width in bp of the window the bigwig signal is read
          from (default 2000).
        incl_chromosomes: forwarded to BedDataset (chromosome include
          filter; None means no restriction - presumably, confirm there).
        excl_chromosomes: forwarded to BedDataset (chromosome exclude
          filter).
        num_chr_fasta: if True, the tsv-loader will make sure that the chromosomes
          don't start with chr
    """
    def __init__(self,
                 intervals_file,
                 fasta_file,
                 bigwigs,
                 track_width=2000,
                 incl_chromosomes=None,
                 excl_chromosomes=None,
                 num_chr_fasta=False):
        self.num_chr_fasta = num_chr_fasta
        self.intervals_file = intervals_file
        self.fasta_file = fasta_file
        self.bigwigs = bigwigs
        self.incl_chromosomes = incl_chromosomes
        self.excl_chromosomes = excl_chromosomes
        self.track_width = track_width

        # Targets come from the bigwigs, so label columns are ignored here.
        self.tsv = BedDataset(
            self.intervals_file,
            num_chr=self.num_chr_fasta,
            bed_columns=4,
            ignore_targets=True,
            incl_chromosomes=incl_chromosomes,
            excl_chromosomes=excl_chromosomes,
        )
        # Opened lazily on first __getitem__ (presumably so each worker
        # process opens its own file handles - NOTE(review): confirm).
        self.fasta_extractor = None
        self.bigwig_extractors = None

    def __len__(self):
        return len(self.tsv)

    def _ensure_extractors(self):
        """Create the fasta and per-task bigwig extractors on first use."""
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)
            self.bigwig_extractors = {
                task: [BigwigExtractor(path) for path in self.bigwigs[task]]
                for task in self.bigwigs
            }

    @staticmethod
    def _total_signal(extractors, interval):
        """Return the summed signal of ``interval`` over all ``extractors``.

        The tracks are added element-wise, then reduced to one total. An
        empty extractor list yields 0.0. Without this check, ``sum([])``
        returns the int 0 and calling ``.sum()`` on it raises
        AttributeError.
        """
        tracks = [extractor([interval])[0] for extractor in extractors]
        if not tracks:
            return 0.0
        return sum(tracks).sum()

    def __getitem__(self, idx):
        """Return the inputs/targets/metadata dict for row ``idx``."""
        self._ensure_extractors()

        # Labels are ignored (ignore_targets=True); only the interval is used.
        interval, _ = self.tsv[idx]
        # Intervals need to be 1000bp wide
        interval = resize_interval(interval, 1000)
        assert interval.stop - interval.start == 1000

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]))

        # deepcopy keeps `interval` intact in case resize_interval modifies
        # its argument in place (NOTE(review): confirm whether it does).
        interval_wide = resize_interval(deepcopy(interval), self.track_width)

        return {
            "inputs": {
                "seq": seq
            },
            "targets": {
                task: self._total_signal(extractors, interval_wide)
                for task, extractors in self.bigwig_extractors.items()
            },
            "metadata": {
                "ranges":
                GenomicRanges(interval.chrom, interval.start, interval.stop,
                              str(idx)),
                "ranges_wide":
                GenomicRanges.from_interval(interval_wide),
                "name":
                interval.name
            }
        }

    def get_targets(self):
        """Delegate to the underlying BedDataset's ``get_targets()``.

        NOTE(review): the tsv is loaded with ignore_targets=True. This
        therefore does not return the bigwig-derived targets produced by
        __getitem__. Confirm this is intended.
        """
        return self.tsv.get_targets()