コード例 #1
0
ファイル: test_mask.py プロジェクト: theLongLab/seqmodel
 def test_select(self):
     """Check mask_select on two disjoint rectangular regions of a mask.

     A random split point (row, col) is drawn. The region [row:, col:] is
     labelled 1 and the region [:row, :col] is labelled 2; the two never
     overlap. For each label, selecting with ``mask == label`` must give the
     flattened region, for both the raw tensor and its one-hot encoding.
     """
     # presumably an all-zero mask (probability 0.) -- TODO confirm generate_mask
     mask = generate_mask(self.x, [0.])
     row = torch.randint(self.x.shape[0], [1]).item()
     col = torch.randint(self.x.shape[1], [1]).item()
     labelled_regions = (
         (1, (slice(row, None), slice(col, None))),
         (2, (slice(None, row), slice(None, col))),
     )
     for label, region in labelled_regions:
         mask[region] = label
         npt.assert_array_equal(mask_select(self.x, mask == label),
                                self.x[region].flatten())
         npt.assert_array_equal(mask_select(one_hot(self.x), mask == label),
                                one_hot(self.x[region].flatten()))
コード例 #2
0
ファイル: test_mask.py プロジェクト: theLongLab/seqmodel
    def test_mask_fill(self):
        """Check that mask_fill writes ``fill_value`` exactly where the mask holds.

        Cases covered:
        * a mask selecting every position fills the whole tensor;
        * the complement of that mask fills nothing;
        * a mask selecting only position (0, 0) changes x[0, :, 0] and
          leaves every other entry of the one-hot tensor untouched.
        """
        # presumably every mask entry is 1 with probability 1. -- the first
        # two assertions rely on that.
        mask = generate_mask(self.x, [1.])
        value = 2.5
        x = mask_fill(torch.zeros_like(one_hot(self.x)),
                      mask == 1,
                      fill_value=value)
        npt.assert_array_equal(x, torch.ones_like(x) * value)
        x = mask_fill(torch.zeros_like(one_hot(self.x)),
                      mask != 1,
                      fill_value=value)
        npt.assert_array_equal(x, torch.zeros_like(x))

        # Select only position (0, 0) and fill there.
        mask[0, 0] = 0
        x = mask_fill(one_hot(self.x), mask == 0, fill_value=value)
        npt.assert_array_equal(x[0, :, 0], [value, value, value, value])
        # Compare the whole tensor, not just x[1:, :, 1:], so that the rest of
        # row 0 (x[0, :, 1:]) and the rest of column 0 (x[1:, :, 0]) are also
        # verified to be unchanged.
        expected = one_hot(self.x).clone()
        expected[0, :, 0] = value
        npt.assert_array_equal(x, expected)
コード例 #3
0
ファイル: test_log.py プロジェクト: theLongLab/seqmodel
    def test_summarize(self):
        """Check the text rendering produced by ``summarize``.

        Covers one-hot score tensors (one and two column groups, different
        ``max_len``), boolean tensors, index tensors, NaN input, normalized
        histograms with column labels, and several tensors summarized together.
        """
        # Large one-hot scores render as '@' in the row of the hot index; each
        # rendered line is max_len characters in these cases.
        x = one_hot(torch.zeros(1, 7)).permute(1, 0, 2) * self.large_number
        self.assertEqual(summarize(x, max_len=10),
                         'A|@@@@@@@|\nG|       |\nC|       |\nT|       |')
        self.assertEqual(summarize(x, max_len=6),
                         'A|@@@|\nG|   |\nC|   |\nT|   |')
        diagonal = one_hot(torch.arange(4).view(1, 4)).permute(1, 0, 2)
        x = diagonal * self.large_number
        self.assertEqual(summarize(x, max_len=10),
                         'A|@   |\nG| @  |\nC|  @ |\nT|   @|')
        x = one_hot(torch.zeros(2, 7)).permute(1, 0, 2) * self.large_number
        self.assertEqual(summarize(x, max_len=10),
                         'A|@@@|@@@|\nG|   |   |\nC|   |   |\nT|   |   |')
        x = one_hot(self.batch).permute(1, 0, 2)
        output = summarize(self.batch, x, max_len=10).split('\n')
        self.assertEqual(len(output), 5)
        for substr in output:
            self.assertEqual(len(substr), 10)

        # Boolean and index tensors render as a single line with a blank label.
        x = torch.ones(1, 4, dtype=torch.bool)
        self.assertEqual(summarize(x), ' |....|')
        x = torch.zeros(1, 4, dtype=torch.bool)
        self.assertEqual(summarize(x), ' |!!!!|')
        x = torch.randint(4, [1, 12])
        self.assertEqual(summarize(x),
                         ' |' + index_to_bioseq(x.flatten()) + '|')

        # A (4, 1, 3) all-NaN tensor renders as four labelled rows of 'X'.
        x = torch.tensor([float('nan')] * 12).reshape(4, 1, 3)
        self.assertEqual(summarize(x), 'A|XXX|\nG|XXX|\nC|XXX|\nT|XXX|')

        # Normalized histograms with column labels. Each line is expected to
        # hold 4 columns (presumably one per base) of n_bins + 1 characters
        # plus 2 extra characters. n_bins is named once so the histogram and
        # the width check cannot drift apart.
        n_bins = 5
        x = torch.randn_like(one_hot(self.batch), dtype=torch.float)
        hist = normalize_histogram(
            prediction_histograms(x, self.batch, n_bins=n_bins))
        output = summarize(hist, col_labels=INDEX_TO_BASE,
                           normalize_fn=None).split('\n')
        self.assertEqual(len(output), 3)
        for substr in output:
            self.assertEqual(len(substr), 4 * (n_bins + 1) + 2)

        # Several tensors summarized together. The expected line width is the
        # largest whole number of (self.batch.shape[1] + 1)-wide columns that
        # fits in max_len - 2, plus 2.
        max_len = 90
        output = summarize(one_hot_to_index(x) == self.batch,
                           self.batch,
                           x.permute(1, 0, 2),
                           max_len=max_len)
        output = output.split('\n')
        self.assertEqual(len(output), 6)
        col_len = self.batch.shape[1] + 1
        n_cols = (max_len - 2) // col_len
        for substr in output:
            self.assertEqual(len(substr), n_cols * col_len + 2)
コード例 #4
0
ファイル: test_log.py プロジェクト: theLongLab/seqmodel
    def test_correct(self):
        """Check correct / n_correct / accuracy on all-or-nothing score tensors.

        In every scenario either all positions of ``self.batch`` count as
        correct or none do, so the expected mask, count and accuracy follow
        directly from a single flag.
        """

        def check_metrics(scores, all_correct, **options):
            # ``options`` is forwarded unchanged (e.g. threshold_score=...).
            make_mask = torch.ones if all_correct else torch.zeros
            npt.assert_array_equal(correct(scores, self.batch, **options),
                                   make_mask(self.batch.shape,
                                             dtype=torch.bool))
            expected_count = self.batch.nelement() if all_correct else 0
            self.assertEqual(n_correct(scores, self.batch, **options),
                             expected_count)
            expected_accuracy = 1. if all_correct else 0.
            self.assertEqual(accuracy(scores, self.batch, **options),
                             expected_accuracy)

        exact_scores = one_hot(self.batch)
        # One-hot scores of the targets themselves: everything is correct.
        check_metrics(exact_scores, True)
        # threshold_score=2. is expected to reject every one-hot prediction.
        check_metrics(exact_scores, False, threshold_score=2.)
        # Scores for the shifted index (batch + 1) never match the targets.
        check_metrics(one_hot(self.batch + 1), False, threshold_score=0.5)
コード例 #5
0
ファイル: test_log.py プロジェクト: theLongLab/seqmodel
    def test_prediction_histograms(self):
        """Check prediction_histograms, normalize_histogram and accuracy_per_class.

        Three score tensors are used: large one-hot scores (all counts in the
        last bin), all-zero scores (all counts in the first bin), and one-hot
        scores that predict index 1 everywhere (per-class accuracy checks,
        for the whole batch and for a single sequence).
        """
        n_bins = 7

        def check_normalized(hist, column):
            # Normalized values must lie in [0, 1] up to float tolerance, and
            # the populated bin column must sum to 4. The sum is a float
            # reduction, so compare it with a tolerance consistent with the
            # 1e-5 per-element bounds instead of exact equality.
            normalized = normalize_histogram(hist)
            npt.assert_array_less(normalized, 1 + 1e-5)
            npt.assert_array_less(0 - 1e-5, normalized)
            self.assertAlmostEqual(
                torch.sum(normalized[:, :, column]).item(), 4, delta=1e-4)

        def check_class_accuracy(hist):
            # Only class 1 is predicted; it passes threshold 0. but not 0.5.
            npt.assert_array_equal(accuracy_per_class(hist, threshold_prob=0.),
                                   [0., 1., 0., 0.])
            npt.assert_array_equal(
                accuracy_per_class(hist, threshold_prob=0.5),
                [0., 0., 0., 0.])

        # Large one-hot scores: hist[0] stays empty and hist[1] only fills
        # its last bin.
        x = one_hot(self.batch) * self.large_number
        hist = prediction_histograms(x, self.batch, n_bins=n_bins)
        self.assertEqual(hist.shape, (2, 4, n_bins))
        npt.assert_array_equal(hist[0, :, :], torch.zeros(4, n_bins))
        npt.assert_array_equal(hist[1, :, :-1], torch.zeros(4, n_bins - 1))
        self.assertEqual(
            torch.sum(hist[:, :, -1]).item(), self.batch.nelement())
        check_normalized(hist, -1)

        # All-zero scores: every count lands in the first bin.
        x = torch.zeros(x.shape)
        hist = prediction_histograms(x, self.batch, n_bins=n_bins)
        npt.assert_array_equal(hist[:, :, 1:], torch.zeros(2, 4, n_bins - 1))
        self.assertEqual(
            torch.sum(hist[:, :, 0]).item(), self.batch.nelement())
        check_normalized(hist, 0)

        # One-hot scores for index 1 everywhere: whole batch, then one sequence.
        x = one_hot(torch.ones(self.batch.shape))
        check_class_accuracy(
            prediction_histograms(x, self.batch, n_bins=n_bins))
        check_class_accuracy(
            prediction_histograms(x[1], self.batch[1], n_bins=n_bins))
コード例 #6
0
ファイル: test_invariant.py プロジェクト: theLongLab/seqmodel
 def setUp(self):
     """Build the one-hot encoded test batch shared by the tests."""
     self.batch_size = 5
     self.seq_len = 37
     indices = create_test_batch(self.batch_size, self.seq_len)
     self.x = one_hot(indices)
コード例 #7
0
    def fn1(x):
        """Concatenate x, its reverse complement, reverse and complement on dim 0."""
        variants = (x, reverse_complement(x), reverse(x), complement(x))
        return torch.cat(variants, dim=0)

    def fn2(x):
        """View x as (x.shape[0], self.out_channels * 4, -1).

        ``self`` is a free variable captured from the enclosing scope, which
        is not visible here. ``view`` (not ``reshape``) is kept, so x must be
        viewable without a copy.
        """
        leading = x.shape[0]
        return x.view(leading, self.out_channels * 4, -1)

    test(
        [convolve],
        [
            one_hot(create_test_batch(5, 100)).to(dev),
            one_hot(create_test_batch(50, 1000)).to(dev),
        ],
        [
            # RCIConv1d(4, 40, 4, do_reverse=False, do_complement=False).to(dev),
            nn.Conv1d(4, 40, 4).to(dev),
            GroupConv1d(4, 40, 4).to(dev),
            v2GroupConv1d(4, 40, 4).to(dev),
            # # RCIConv1d(4, 40, 4, do_reverse=True, do_complement=True).to(dev),
            # # RCIConv1d(4, 160, 4, do_reverse=True, do_complement=True).to(dev),
            # nn.Conv1d(4, 400, 40).to(dev),
            # RCIConv1d(4, 400, 40, do_reverse=False, do_complement=False).to(dev),
            nn.Conv1d(4, 40, 40).to(dev),
            GroupConv1d(4, 40, 40).to(dev),
            v2GroupConv1d(4, 40, 40).to(dev),
            nn.Conv1d(4, 400, 4).to(dev),