Пример #1
0
    def test_searchsorted(self):
        """Check bin indices for points at, inside, and near the end of each bin.

        ``torch.linspace(0, 1, 10)`` yields 10 bin locations delimiting 9
        bins. Each probe set below holds one point per bin, so the expected
        indices are ``0..8`` for every set.
        """
        bin_locations = torch.linspace(0, 1, 10)  # 9 bins == 10 locations
        # Uniform spacing of the linspace above (1/9). Probes are placed
        # relative to it rather than with hard-coded offsets, so the names
        # below describe the points accurately.
        bin_width = bin_locations[1] - bin_locations[0]

        left_boundaries = bin_locations[:-1]
        # The exact right boundary of bin i coincides with the left boundary of
        # bin i + 1, so probe just inside it instead.
        near_right_boundaries = left_boundaries + 0.9 * bin_width
        mid_points = left_boundaries + 0.5 * bin_width

        expected = torch.arange(0, 9)
        probes = [
            ("left_boundaries", left_boundaries),
            ("near_right_boundaries", near_right_boundaries),
            ("mid_points", mid_points),
        ]
        for name, inputs in probes:
            with self.subTest(points=name):
                idx = torchutils.searchsorted(bin_locations[None, :], inputs)
                self.assertEqual(idx, expected)
Пример #2
0
 def test_searchsorted_arbitrary_shape(self):
     """Check that searchsorted handles batched inputs of arbitrary rank.

     Bin locations are repeated across a ``(2, 3, 4)`` batch, and one random
     input is drawn per batch element. The result must have the input's shape
     and hold, per element, the index of the bin that input falls into.
     """
     shape = [2, 3, 4]
     num_locations = 10  # 9 bins == 10 locations
     bin_locations = torch.linspace(0, 1, num_locations).repeat(*shape, 1)
     # Fixed-seed local generator: failures are reproducible and the global
     # RNG state is left untouched.
     generator = torch.Generator().manual_seed(0)
     inputs = torch.rand(*shape, generator=generator)

     # Brute-force reference, computed before the call under test. The bin
     # index is the number of locations <= x, minus one, so a left boundary
     # belongs to its own bin. Inputs lie in [0, 1), so every expected index
     # is in 0..num_locations - 2.
     expected = (inputs[..., None] >= bin_locations).sum(dim=-1) - 1

     idx = torchutils.searchsorted(bin_locations, inputs)
     self.assertEqual(idx.shape, inputs.shape)
     # Element-wise comparison that works with a plain unittest.TestCase and
     # tolerates a different integer dtype in the result.
     self.assertTrue(bool((idx == expected).all()))