예제 #1
0
    def test_const_Op_argument(self):
        """A constant NumPy array can be used directly as an ``IfElse`` branch.

        The compiled function is called with condition ``0``, so its output
        must equal the constant array supplied as the else branch.
        """
        vec = vector("x", dtype=self.dtype)
        const_branch = np.array([2.0, 3.0], dtype=self.dtype)
        cond = iscalar("c")
        ifelse_out = IfElse(1)(cond, vec, const_branch)
        ifelse_fn = function([cond, vec], ifelse_out, mode=self.mode)

        vec_value = np.r_[1.0, 2.0].astype(self.dtype)
        result = ifelse_fn(0, vec_value)
        assert np.array_equal(result, const_branch)
예제 #2
0
 def test_wrong_n_outs(self):
     """Applying an ``IfElse`` built with ``0`` outputs must raise ``ValueError``."""
     vec = vector("x", dtype=self.dtype)
     cond = iscalar("c")
     zero_out_op = IfElse(0)
     with pytest.raises(ValueError):
         zero_out_op(cond, vec, vec)
예제 #3
0
 def get_ifelse(self, n):
     """Return an ``IfElse`` op constructed with ``n`` as its first argument.

     In any mode other than ``FAST_COMPILE`` the op is also created with
     ``as_view=True``. Under ``FAST_COMPILE`` only ``n`` is passed.
     """
     if aesara.config.mode != "FAST_COMPILE":
         return IfElse(n, as_view=True)
     return IfElse(n)