Code example #1
0
def test_one_hot():
    """Check one-hot encoding, the pad/trim/fixed_len helpers and input validation."""
    seq = "ACGTTTATNT"
    assert len(seq) == 10

    assert one_hot_dna(seq).shape == (10, 4)
    assert one_hot(seq).shape == (10, 4)

    # The DNA-specific encoder must agree with the generic encoder.
    assert np.all(one_hot_dna(seq) == one_hot(seq))

    assert one_hot(pad(seq, 20)).shape == (20, 4)

    assert one_hot(fixed_len(seq, 20)).shape == (20, 4)
    assert one_hot(fixed_len(seq, 5)).shape == (5, 4)
    assert trim(seq, 5) == 'TTTAT'
    assert trim(seq, 5, 'start') == 'ACGTT'
    assert trim(seq, 5, 'end') == 'TATNT'

    # Padding to a length shorter than the sequence must fail. Only call pad()
    # inside the block: wrapping it in `assert` would let a plain AssertionError
    # (pad() returning some wrong value instead of raising) satisfy
    # pytest.raises(Exception) and hide the bug.
    with pytest.raises(Exception):
        pad(seq, 5, 'end')

    # Row-by-row check of the encoding; columns follow the order A, C, G, T.
    encoded = one_hot(seq)
    assert np.all(encoded[0] == np.array([1, 0, 0, 0]))  # A
    assert np.all(encoded[1] == np.array([0, 1, 0, 0]))  # C
    assert np.all(encoded[2] == np.array([0, 0, 1, 0]))  # G
    assert np.all(encoded[3] == np.array([0, 0, 0, 1]))  # T
    assert np.all(encoded[4] == np.array([0, 0, 0, 1]))  # T
    assert np.all(encoded[-1] == np.array([0, 0, 0, 1]))  # T
    # 'N' (second-to-last base) is encoded as a uniform 0.25 row.
    assert np.all(encoded[-2] == np.array([0.25, 0.25, 0.25, 0.25]))

    # A list of characters (rather than a string) must be rejected.
    with pytest.raises(ValueError):
        one_hot(['A', 'C'])

    with pytest.raises(ValueError):
        one_hot_dna(['A', 'C'])
Code example #2
0
def test_seq_dataset(intervals_file, fasta_file):
    """The first SeqIntervalDl sample's inputs equal one_hot_dna("GT") as a (2, 4) ndarray."""
    first_sample = SeqIntervalDl(intervals_file, fasta_file)[0]
    encoded = first_sample['inputs']

    assert np.all(encoded == one_hot_dna("GT"))
    assert isinstance(encoded, np.ndarray)
    assert encoded.shape == (2, 4)
Code example #3
0
 def __call__(self, seq):
     """One-hot encode ``seq`` using this transform's alphabet settings.

     When the settings match DNA with ['N'] as the neutral alphabet and a
     neutral value of 0.25, ``F.one_hot_dna`` is used; otherwise the generic
     ``F.one_hot`` is called with every setting passed explicitly.
     """
     uses_default_dna = (self.alphabet == DNA
                         and self.neutral_alphabet == ['N']
                         and self.neutral_value == 0.25)
     if not uses_default_dna:
         return F.one_hot(seq,
                          alphabet=self.alphabet,
                          neutral_alphabet=self.neutral_alphabet,
                          neutral_value=self.neutral_value,
                          dtype=self.dtype)
     return F.one_hot_dna(seq, self.dtype)
Code example #4
0
File: dataloader.py    Project: kipoi/kipoi
    def __getitem__(self, idx):
        """Build one sample for the interval at position ``idx`` of ``self.bt``.

        Returns:
            dict with keys:
              - "inputs": {"seq": one-hot encoded sequence,
                           "dist_polya_st": spline-transformed distance features}
              - "targets": ``np.array([self.target_dataset[idx]])``; only present
                when ``self.target_dataset`` is not None
              - "metadata": {"ranges": {"chr", "start", "end", "id", "strand"}}
        """
        # Extractors are created lazily on first access and cached on the
        # instance (presumably so no file handles exist before workers fork --
        # verify). Reuse them: re-creating them per item re-opens the FASTA and
        # re-reads the GTF for every sample.
        if self.input_data_extractors is None:
            self.input_data_extractors = {
                "seq": FastaStringExtractor(self.fasta_file),
                "dist_polya_st": DistToClosestLandmarkExtractor(
                    gtf_file=self.gtf, landmarks=["polya"]),
            }
        extractors = self.input_data_extractors

        interval = self.bt[idx]

        seq = one_hot_dna(extractors["seq"].extract(interval))

        # The distance extractor takes a list of intervals; drop the leading
        # axis added for our single interval.
        dist_polya_st = np.squeeze(extractors["dist_polya_st"]([interval]),
                                   axis=0)
        # use trained spline transformation to transform it (the transformer
        # is called on a batch of one, so add and then remove the leading axis)
        dist_polya_st = np.squeeze(
            self.transformer.transform(dist_polya_st[np.newaxis], warn=False),
            axis=0)

        out = {"inputs": {"seq": seq, "dist_polya_st": dist_polya_st}}

        if self.target_dataset is not None:
            out["targets"] = np.array([self.target_dataset[idx]])

        # get metadata
        out['metadata'] = {
            'ranges': {
                'chr': interval.chrom,
                'start': interval.start,
                'end': interval.stop,
                'id': interval.name,
                'strand': interval.strand,
            }
        }

        return out
Code example #5
0
    def __getitem__(self, idx):
        """Return the one-hot sequence, targets and metadata for interval ``idx``.

        Returns:
            dict with "inputs" (float32 one-hot encoding of the interval's
            sequence), "targets" (row ``idx`` of ``self.targets`` as values, or
            ``{}`` when there are no targets) and "metadata" ({"ranges": ...}).

        Raises:
            AssertionError: if the interval is not exactly 1000 bp wide.
        """
        if self.fasta_extractor is None:
            # Created lazily on first access and cached on the instance.
            self.fasta_extractor = FastaStringExtractor(self.fasta_file)

        interval = self.bt[idx]

        # Intervals need to be 1000bp wide. This is checked explicitly rather
        # than with a bare `assert`, which is stripped under `python -O`.
        # AssertionError is kept so existing callers see the same exception type.
        width = interval.stop - interval.start
        if width != 1000:
            raise AssertionError(
                "Intervals need to be 1000bp wide, got {} bp".format(width))

        if self.targets is not None:
            y = self.targets.iloc[idx].values
        else:
            y = {}

        # Run the fasta extractor
        seq = one_hot_dna(self.fasta_extractor.extract(interval), dtype=np.float32)  # TODO: Remove additional dtype after kipoiseq gets a new release
        return {
            "inputs": seq,
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }