Beispiel #1
0
    def test_infer_shape(self):
        """Verify shape inference of ``searchsorted`` via ``_compile_and_check``.

        Three graphs are checked, each against ``self.op_class``: the default
        arguments, an explicit int32 ``sorter``, and ``side="right"``.
        """

        def check(inputs, output, values):
            # Thin wrapper so each case only states what differs between them.
            self._compile_and_check(inputs, [output], values, self.op_class)

        # Default ``side`` and no ``sorter``; the data fed to ``self.x`` is
        # ``self.a`` reordered by ``self.idx_sorted``.
        check(
            [self.x, self.v],
            searchsorted(self.x, self.v),
            [self.a[self.idx_sorted], self.b],
        )

        # Explicit ``sorter`` given as a symbolic int32 vector.
        int_sorter = vector("sorter", dtype="int32")
        check(
            [self.x, self.v, int_sorter],
            searchsorted(self.x, self.v, sorter=int_sorter),
            [self.a, self.b, self.idx_sorted],
        )

        # ``side="right"`` evaluated on all-ones data cast to ``config.floatX``.
        ones_x = np.ones(10).astype(config.floatX)
        ones_v = np.ones(shape=(1, 2, 3)).astype(config.floatX)
        check(
            [self.x, self.v],
            searchsorted(self.x, self.v, side="right"),
            [ones_x, ones_v],
        )
Beispiel #2
0
 def test_searchsortedOp_on_right_side(self):
     """Compare compiled ``searchsorted(..., side="right")`` with NumPy."""
     right_out = searchsorted(self.x, self.v, side="right")
     fn = aesara.function([self.x, self.v], right_out)
     expected = np.searchsorted(self.a, self.b, side="right")
     assert np.allclose(expected, fn(self.a, self.b))
Beispiel #3
0
 def test_searchsortedOp_on_int_sorter(self, dtype):
     """Check that an integer ``sorter`` of the given ``dtype`` is accepted.

     ``dtype`` is supplied by the caller (presumably a pytest parametrization
     outside this view). The compiled result is compared with
     ``np.searchsorted`` using ``self.idx_sorted`` as the sorter.
     """
     sorter_var = vector("sorter", dtype=dtype)
     sorted_pos = searchsorted(self.x, self.v, sorter=sorter_var)
     # allow_input_downcast permits silently down-casting the
     # ``self.idx_sorted`` values to the (possibly narrower) sorter dtype.
     fn = aesara.function(
         [self.x, self.v, sorter_var], sorted_pos, allow_input_downcast=True
     )
     reference = np.searchsorted(self.a, self.b, sorter=self.idx_sorted)
     assert np.allclose(reference, fn(self.a, self.b, self.idx_sorted))
Beispiel #4
0
 def test_searchsortedOp_on_int_sorter(self):
     """Check that ``searchsorted`` accepts signed integer ``sorter`` vectors.

     For each dtype in ``compatible_types`` a function taking a ``sorter`` of
     that dtype is compiled and its output is compared with
     ``np.searchsorted`` using ``self.idx_sorted`` as the sorter.
     """
     compatible_types = ("int8", "int16", "int32")
     # int64 is only exercised when Python's int is 64 bits wide.
     if aesara.configdefaults.python_int_bitwidth() == 64:
         compatible_types += ("int64",)
     # Unsigned dtypes ('uint8', 'uint16', 'uint32', 'uint64') are not tested.
     # NOTE(review): confirm whether they are meant to be supported.
     for dtype in compatible_types:
         sorter = vector("sorter", dtype=dtype)
         f = aesara.function(
             [self.x, self.v, sorter],
             searchsorted(self.x, self.v, sorter=sorter),
             # Permit silently down-casting the ``self.idx_sorted`` values
             # to the narrower sorter dtypes.
             allow_input_downcast=True,
         )
         assert np.allclose(
             np.searchsorted(self.a, self.b, sorter=self.idx_sorted),
             f(self.a, self.b, self.idx_sorted),
         )
Beispiel #6
0
    def test_searchsortedOp_on_sorted_input(self):
        """Compare compiled ``searchsorted`` against NumPy in three settings.

        1. Pre-sorted data with default arguments (function form).
        2. Unsorted data with an int32 ``sorter`` and ``side="right"``
           (method form on the symbolic variable).
        3. Pre-sorted data with ``side="right"`` (method form).
        """
        sorted_a = self.a[self.idx_sorted]

        # 1. Pre-sorted data, default arguments.
        fn_default = aesara.function([self.x, self.v], searchsorted(self.x, self.v))
        assert np.allclose(
            np.searchsorted(sorted_a, self.b), fn_default(sorted_a, self.b)
        )

        # 2. Unsorted data ordered through an explicit ``sorter``.
        sorter = vector("sorter", dtype="int32")
        fn_sorter = aesara.function(
            [self.x, self.v, sorter],
            self.x.searchsorted(self.v, sorter=sorter, side="right"),
        )
        expected = self.a.searchsorted(self.b, sorter=self.idx_sorted, side="right")
        assert np.allclose(expected, fn_sorter(self.a, self.b, self.idx_sorted))

        # 3. Pre-sorted data, right side, method form.
        fn_right = aesara.function(
            [self.x, self.v], self.x.searchsorted(self.v, side="right")
        )
        assert np.allclose(
            sorted_a.searchsorted(self.b, side="right"), fn_right(sorted_a, self.b)
        )
Beispiel #7
0
 def test_searchsortedOp_on_float_sorter(self):
     """A floating-point ``sorter`` must be rejected with ``TypeError``."""
     float_sorter = vector("sorter", dtype="float32")
     with pytest.raises(TypeError):
         searchsorted(self.x, self.v, sorter=float_sorter)
Beispiel #8
0
 def test_searchsortedOp_on_no_1d_inp(self):
     """A matrix as the sorted array or as ``sorter`` raises ``ValueError``."""
     matrix_input = dmatrix("no_1d")
     invalid_calls = (
         # Matrix in the position of the array being searched.
         lambda: searchsorted(matrix_input, self.v),
         # Matrix passed as the ``sorter`` argument.
         lambda: searchsorted(self.x, self.v, sorter=matrix_input),
     )
     for invalid_call in invalid_calls:
         with pytest.raises(ValueError):
             invalid_call()
Beispiel #9
0
 def test_searchsortedOp_wrong_side_kwd(self):
     """An unrecognised ``side`` keyword value raises ``ValueError``."""
     invalid_side = "asdfa"
     with pytest.raises(ValueError):
         searchsorted(self.x, self.v, side=invalid_side)