def setUp(self):
    """Build shared symbolic inputs, concrete values and a compiled ValueGradFunction.

    Gradient variables: ``val1`` (a float vector, value of length 3) and
    ``val2`` (a float matrix, value of shape (2, 3)). ``extra1`` is an integer
    scalar passed as an extra input through the mapping argument.
    The cost is ``extra1 * sum(val1) + sum(val2)``.
    """
    # Extra input: integer scalar whose concrete value is 0.
    self.extra1 = at.iscalar("extra1")
    self.extra1_ = np.array(0, dtype=self.extra1.dtype)
    # NOTE(review): dshape/dsize are attached as ad-hoc attributes; presumably
    # consumed by ValueGradFunction or other tests in this class — confirm.
    self.extra1.dshape = ()
    self.extra1.dsize = 1

    # First gradient variable: vector with a zero value of length 3.
    self.val1 = at.vector("val1")
    self.val1_ = np.zeros(3, dtype=self.val1.dtype)
    self.val1.dshape = (3,)
    self.val1.dsize = 3

    # Second gradient variable: matrix with a zero value of shape (2, 3).
    self.val2 = at.matrix("val2")
    self.val2_ = np.zeros((2, 3), dtype=self.val2.dtype)
    self.val2.dshape = (2, 3)
    self.val2.dsize = 6

    self.cost = self.extra1 * self.val1.sum() + self.val2.sum()
    self.f_grad = ValueGradFunction(
        [self.cost],
        [self.val1, self.val2],
        {self.extra1: self.extra1_},
        mode="FAST_COMPILE",
    )
def test_invalid_type(self):
    """Reject an integer-typed gradient variable.

    ValueGradFunction must raise a TypeError whose message contains
    "Invalid dtype" when given one.
    """
    int_vec = at.ivector("a")
    int_vec.tag.test_value = np.zeros(3, dtype=int_vec.dtype)
    int_vec.dshape = (3,)
    int_vec.dsize = 3

    # `match=` performs the same regex search on the exception message as
    # calling ExceptionInfo.match() after the block.
    with pytest.raises(TypeError, match="Invalid dtype"):
        ValueGradFunction([int_vec.sum()], [int_vec], {}, mode="FAST_COMPILE")
def test_no_extra(self):
    """Record no extra variables when the extras mapping is empty.

    A ValueGradFunction built with an empty extras mapping should have an
    empty ``_extra_vars`` list.
    """
    vec = at.vector("a")
    vec.tag.test_value = np.zeros(3, dtype=vec.dtype)
    cost = vec.sum()

    func = ValueGradFunction([cost], [vec], {}, mode="FAST_COMPILE")

    assert func._extra_vars == []