def test_select(self):
    """Verify ``mask_select`` extracts exactly the entries whose mask label matches.

    A random corner split point ``(row, col)`` is drawn. The bottom-right
    block is labelled 1 and the top-left block is labelled 2. For each label,
    selecting with ``mask == label`` must equal the flattened slice of
    ``self.x``, both for ``self.x`` itself and for ``one_hot(self.x)``.
    """
    mask = generate_mask(self.x, [0.])
    # Draw the split point in the same order as before (rows, then columns)
    # so that the sequence of RNG draws is unchanged.
    row = torch.randint(self.x.shape[0], [1]).item()
    col = torch.randint(self.x.shape[1], [1]).item()
    labelled_regions = (
        (1, (slice(row, None), slice(col, None))),   # self.x[row:, col:]
        (2, (slice(None, row), slice(None, col))),   # self.x[:row, :col]
    )
    for label, region in labelled_regions:
        mask[region] = label
        selector = mask == label
        npt.assert_array_equal(
            mask_select(self.x, selector), self.x[region].flatten())
        # NOTE(review): one_hot(self.x) implies self.x holds class indices
        # here; confirm against this class's setUp.
        npt.assert_array_equal(
            mask_select(one_hot(self.x), selector),
            one_hot(self.x[region].flatten()))
def test_mask_fill(self):
    """Verify ``mask_fill`` writes the fill value only where the selector is true."""
    fill = 2.5
    mask = generate_mask(self.x, [1.])

    # Filling a zero tensor through `mask == 1` is expected to produce `fill`
    # everywhere. This assumes generate_mask(..., [1.]) yields an all-ones
    # mask; TODO confirm.
    filled = mask_fill(torch.zeros_like(one_hot(self.x)), mask == 1,
                       fill_value=fill)
    npt.assert_array_equal(filled, torch.ones_like(filled) * fill)

    # The complementary selector must leave a fresh zero tensor untouched.
    # A new zeros tensor is built here rather than reused, in case mask_fill
    # modifies its input.
    untouched = mask_fill(torch.zeros_like(one_hot(self.x)), mask != 1,
                          fill_value=fill)
    npt.assert_array_equal(untouched, torch.zeros_like(untouched))

    # Relabel a single position and fill only that position.
    mask[0, 0] = 0
    partial = mask_fill(one_hot(self.x), mask == 0, fill)
    # Four values are expected along dim 1 at position (0, 0).
    npt.assert_array_equal(partial[0, :, 0], [fill] * 4)
    npt.assert_array_equal(partial[1:, :, 1:], one_hot(self.x)[1:, :, 1:])
def test_summarize(self):
    """Check the text rendering produced by ``summarize`` for several inputs.

    Covers one-hot score tensors, boolean and index tensors, NaN input,
    a normalized prediction histogram with column labels, and multi-tensor
    summaries wrapped at ``max_len``.
    """
    # NOTE(review): in several expected literals below, the padding on the
    # G/C/T rows is shorter than the A row (e.g. 7 '@' versus a single
    # space). This may be whitespace that was collapsed in this copy of the
    # file. Confirm against the actual output of summarize.

    # A single sequence of 7 zeros (all 'A'), scaled by self.large_number.
    x = one_hot(torch.zeros(1, 7)).permute(1, 0, 2) * self.large_number
    self.assertEqual(summarize(x, max_len=10), 'A|@@@@@@@|\nG| |\nC| |\nT| |')
    # A smaller max_len shortens the rendered row.
    self.assertEqual(summarize(x, max_len=6), 'A|@@@|\nG| |\nC| |\nT| |')
    # Indices 0..3 place one '@' on a different row per position.
    x = one_hot(torch.arange(4).view(1, 4)).permute(1, 0, 2) * self.large_number
    self.assertEqual(summarize(x, max_len=10), 'A|@ |\nG| @ |\nC| @ |\nT| @|')
    # Two sequences of 7 zeros, rendered as two '|'-separated columns.
    x = one_hot(torch.zeros(2, 7)).permute(1, 0, 2) * self.large_number
    self.assertEqual(summarize(x, max_len=10), 'A|@@@|@@@|\nG| | |\nC| | |\nT| | |')
    # Several tensors summarized together: expect 5 lines, each exactly
    # max_len characters long.
    x = one_hot(self.batch).permute(1, 0, 2)
    output = summarize(self.batch, x, max_len=10).split('\n')
    self.assertEqual(len(output), 5)
    for substr in output:
        self.assertEqual(len(substr), 10)
    # Boolean tensors render True as '.' and False as '!'.
    x = torch.ones(1, 4, dtype=torch.bool)
    self.assertEqual(summarize(x), ' |....|')
    x = torch.zeros(1, 4, dtype=torch.bool)
    self.assertEqual(summarize(x), ' |!!!!|')
    # Index tensors render as the sequence string from index_to_bioseq.
    x = torch.randint(4, [1, 12])
    self.assertEqual(summarize(x), ' |' + index_to_bioseq(x.flatten()) + '|')
    # NaN entries render as 'X'.
    x = torch.tensor([float('nan')] * 12).reshape(4, 1, 3)
    self.assertEqual(summarize(x), 'A|XXX|\nG|XXX|\nC|XXX|\nT|XXX|')
    # Normalized histogram with column labels and no normalization: expect
    # 3 lines of width 4 * (n_bins + 1) + 2 with n_bins=5.
    x = torch.randn_like(one_hot(self.batch), dtype=torch.float)
    hist = normalize_histogram(
        prediction_histograms(x, self.batch, n_bins=5))
    output = summarize(hist, col_labels=INDEX_TO_BASE, normalize_fn=None).split('\n')
    self.assertEqual(len(output), 3)
    for substr in output:
        self.assertEqual(len(substr), 4 * (5 + 1) + 2)
    # Correctness mask + indices + scores wrapped at max_len=90: each line
    # holds as many whole (seq_len + 1)-wide columns as fit after the
    # 2-character prefix.
    output = summarize(one_hot_to_index(x) == self.batch, self.batch, x.permute(1, 0, 2), max_len=90)
    output = output.split('\n')
    self.assertEqual(len(output), 6)
    col_len = self.batch.shape[1] + 1
    n_cols = (90 - 2) // col_len
    for substr in output:
        self.assertEqual(len(substr), n_cols * col_len + 2)
def test_correct(self):
    """Check that ``correct``, ``n_correct`` and ``accuracy`` agree on all-right and all-wrong cases."""
    n_total = self.batch.nelement()
    all_true = torch.ones(self.batch.shape, dtype=torch.bool)
    all_false = torch.zeros(self.batch.shape, dtype=torch.bool)

    def _check(scores, expect_all_correct, **threshold):
        """Assert the three metrics report either every position or none as correct."""
        npt.assert_array_equal(
            correct(scores, self.batch, **threshold),
            all_true if expect_all_correct else all_false)
        self.assertEqual(
            n_correct(scores, self.batch, **threshold),
            n_total if expect_all_correct else 0)
        self.assertEqual(
            accuracy(scores, self.batch, **threshold),
            1. if expect_all_correct else 0.)

    # One-hot scores of the targets themselves count as correct everywhere.
    exact = one_hot(self.batch)
    _check(exact, True)
    # A threshold above the maximum one-hot score (1) rejects every prediction.
    _check(exact, False, threshold_score=2.)
    # Shift every target index by one so that no prediction matches.
    # NOTE(review): if self.batch already contains index 3, this produces
    # index 4. That relies on how one_hot handles out-of-range indices; if
    # not intended, (self.batch + 1) % 4 would shift within range. Confirm.
    _check(one_hot(self.batch + 1), False, threshold_score=0.5)
def test_prediction_histograms(self):
    """Check the bin placement of ``prediction_histograms`` and the behaviour of its helpers.

    The helpers covered are ``normalize_histogram`` and
    ``accuracy_per_class``. Per the shape assertion below, the histogram is
    expected to have shape (2, 4, n_bins).
    """
    n_bins = 7
    n_total = self.batch.nelement()

    def _check_normalized(hist, bin_index):
        """Assert the normalized histogram lies in [0, 1] and that bin `bin_index` sums to 4."""
        normalized = normalize_histogram(hist)
        npt.assert_array_less(normalized, 1 + 1e-5)
        npt.assert_array_less(0 - 1e-5, normalized)
        self.assertEqual(torch.sum(normalized[:, :, bin_index]).item(), 4)

    # Confident, correct scores: every count is expected in the last bin of
    # hist[1], and hist[0] is expected to be empty.
    scores = one_hot(self.batch) * self.large_number
    hist = prediction_histograms(scores, self.batch, n_bins=n_bins)
    self.assertEqual(hist.shape, (2, 4, n_bins))
    npt.assert_array_equal(hist[0, :, :], torch.zeros(4, n_bins))
    npt.assert_array_equal(hist[1, :, :-1], torch.zeros(4, n_bins - 1))
    self.assertEqual(torch.sum(hist[:, :, -1]).item(), n_total)
    _check_normalized(hist, -1)

    # All-zero scores: every count is expected in the first bin.
    scores = torch.zeros(scores.shape)
    hist = prediction_histograms(scores, self.batch, n_bins=n_bins)
    npt.assert_array_equal(hist[:, :, 1:], torch.zeros(2, 4, n_bins - 1))
    self.assertEqual(torch.sum(hist[:, :, 0]).item(), n_total)
    _check_normalized(hist, 0)

    # Scores that always predict index 1: per-class accuracy is expected to
    # be 1 only for class 1 at threshold 0, and 0 everywhere at threshold
    # 0.5. The same holds for the whole batch and for its second sequence.
    scores = one_hot(torch.ones(self.batch.shape))
    for sample, target in ((scores, self.batch), (scores[1], self.batch[1])):
        hist = prediction_histograms(sample, target, n_bins=n_bins)
        npt.assert_array_equal(
            accuracy_per_class(hist, threshold_prob=0.), [0., 1., 0., 0.])
        npt.assert_array_equal(
            accuracy_per_class(hist, threshold_prob=0.5), [0., 0., 0., 0.])
def setUp(self):
    """Build the per-test fixture: a one-hot encoded test batch of 5 sequences of length 37."""
    self.batch_size = 5
    self.seq_len = 37
    indices = create_test_batch(self.batch_size, self.seq_len)
    self.x = one_hot(indices)
def fn1(x): return torch.cat([ x, reverse_complement(x), reverse(x), complement(x), ], dim=0) def fn2(x): return x.view(x.shape[0], self.out_channels * 4, -1) test( [convolve], [ one_hot(create_test_batch(5, 100)).to(dev), one_hot(create_test_batch(50, 1000)).to(dev), ], [ # RCIConv1d(4, 40, 4, do_reverse=False, do_complement=False).to(dev), nn.Conv1d(4, 40, 4).to(dev), GroupConv1d(4, 40, 4).to(dev), v2GroupConv1d(4, 40, 4).to(dev), # # RCIConv1d(4, 40, 4, do_reverse=True, do_complement=True).to(dev), # # RCIConv1d(4, 160, 4, do_reverse=True, do_complement=True).to(dev), # nn.Conv1d(4, 400, 40).to(dev), # RCIConv1d(4, 400, 40, do_reverse=False, do_complement=False).to(dev), nn.Conv1d(4, 40, 40).to(dev), GroupConv1d(4, 40, 40).to(dev), v2GroupConv1d(4, 40, 40).to(dev), nn.Conv1d(4, 400, 4).to(dev),