def test_searchsorted(self):
    """Check the bin index returned by ``torchutils.searchsorted`` for 1-D probes.

    Ten evenly spaced locations on [0, 1] delimit nine bins. Three sets of
    probe points are tried, one point per bin: the left edge of each bin, a
    point offset by 0.1 from the left edge, and a point offset by 0.05. For
    every set the expected result is ``[0, 1, ..., 8]``.
    """
    edges = torch.linspace(0, 1, steps=10)  # 10 locations delimit 9 bins
    lower_edges = edges[:-1]
    expected_idx = torch.arange(0, 9)

    # The bin width is 1/9 (~0.111), so an offset of 0.1 lands just inside
    # the right edge of each bin, and 0.05 lands near the middle.
    probe_sets = (
        lower_edges,
        lower_edges + 0.1,
        lower_edges + 0.05,
    )

    for inputs in probe_sets:
        with self.subTest(inputs=inputs):
            # The bin locations get a leading axis of size 1 before the call.
            idx = torchutils.searchsorted(edges[None, :], inputs)
            self.assertEqual(idx, expected_idx)
def test_searchsorted_arbitrary_shape(self):
    """Check that ``torchutils.searchsorted`` preserves the shape of its inputs.

    The same ten bin locations are tiled across leading dimensions
    (2, 3, 4), and random inputs of shape (2, 3, 4) are looked up. Only the
    shape of the result is asserted, not its values.
    """
    leading_dims = (2, 3, 4)
    edges = torch.linspace(0, 1, steps=10)
    # Tile the 1-D edges so the result has shape (2, 3, 4, 10).
    tiled_edges = edges.repeat(*leading_dims, 1)
    samples = torch.rand(*leading_dims)  # uniform draws in [0, 1)

    result = torchutils.searchsorted(tiled_edges, samples)

    self.assertEqual(result.shape, samples.shape)