def test_const_Op_argument(self):
    """Check that IfElse accepts a NumPy array (constant) as a branch input.

    ``y`` is handed straight to the Op instead of being one of the compiled
    function's inputs, so the graph has to treat it as a constant.  Both
    branches are exercised: a zero condition must yield the constant ``y``,
    and a nonzero condition must yield the symbolic input ``x``.
    """
    x = vector("x", dtype=self.dtype)
    y = np.array([2.0, 3.0], dtype=self.dtype)
    c = iscalar("c")
    f = function([c, x], IfElse(1)(c, x, y), mode=self.mode)
    x_val = np.r_[1.0, 2.0].astype(self.dtype)

    # Condition false: the constant branch must be selected.
    assert np.array_equal(f(0, x_val), y)
    # Condition true: the symbolic branch must be selected even though the
    # other branch is a constant.
    assert np.array_equal(f(1, x_val), x_val)
def test_wrong_n_outs(self):
    """An IfElse declared with zero outputs must reject two branch inputs.

    Constructing and applying the Op both happen inside the ``raises``
    context, so a ``ValueError`` from either step satisfies the test.
    """
    vec = vector("x", dtype=self.dtype)
    cond = iscalar("c")
    with pytest.raises(ValueError):
        bad_op = IfElse(0)
        bad_op(cond, vec, vec)
def get_ifelse(self, n):
    """Build an ``IfElse`` Op with ``n`` outputs for the current config mode.

    Outside ``FAST_COMPILE`` mode the Op is created with ``as_view=True``.
    In ``FAST_COMPILE`` mode it is created with the constructor's default
    ``as_view`` setting.
    """
    if aesara.config.mode != "FAST_COMPILE":
        return IfElse(n, as_view=True)
    return IfElse(n)