示例#1
0
    def setUp(self):
        """Create the symbolic inputs, matching zero-valued arrays and the
        ``ValueGradFunction`` fixture used by the tests.

        The cost is ``extra1 * val1.sum() + val2.sum()``. ``val1`` and
        ``val2`` are handed over in the second (list) argument, while
        ``extra1`` is supplied through the dict argument together with its
        concrete value ``extra1_``.
        """
        # Scalar created with ``at.iscalar``; its companion value is a 0-d zero.
        extra1 = at.iscalar("extra1")
        extra1.dshape, extra1.dsize = tuple(), 1
        extra1_ = np.array(0, dtype=extra1.dtype)

        # Vector input paired with a length-3 array of zeros.
        val1 = at.vector("val1")
        val1.dshape, val1.dsize = (3,), 3
        val1_ = np.zeros(3, dtype=val1.dtype)

        # Matrix input paired with a 2x3 array of zeros.
        val2 = at.matrix("val2")
        val2.dshape, val2.dsize = (2, 3), 6
        val2_ = np.zeros((2, 3), dtype=val2.dtype)

        # Expose both the symbolic variables and their numeric counterparts.
        self.extra1, self.extra1_ = extra1, extra1_
        self.val1, self.val1_ = val1, val1_
        self.val2, self.val2_ = val2, val2_

        cost = extra1 * val1.sum() + val2.sum()
        self.cost = cost

        self.f_grad = ValueGradFunction(
            [cost],
            [val1, val2],
            {extra1: extra1_},
            mode="FAST_COMPILE",
        )
示例#2
0
 def test_invalid_type(self):
     """Building a ``ValueGradFunction`` from an ``at.ivector`` variable must
     raise ``TypeError`` with a message matching ``"Invalid dtype"``."""
     int_vec = at.ivector("a")
     int_vec.tag.test_value = np.zeros(3, dtype=int_vec.dtype)
     int_vec.dshape, int_vec.dsize = (3,), 3

     # ``match`` performs the same regex search on the exception text that
     # calling ``.match()`` on the captured exception info would.
     with pytest.raises(TypeError, match="Invalid dtype"):
         ValueGradFunction([int_vec.sum()], [int_vec], {}, mode="FAST_COMPILE")
示例#3
0
 def test_no_extra(self):
     """When the extras dict is empty, ``_extra_vars`` compares equal to
     an empty list."""
     vec = at.vector("a")
     vec.tag.test_value = np.zeros(3, dtype=vec.dtype)

     cost = vec.sum()
     grad_fn = ValueGradFunction([cost], [vec], {}, mode="FAST_COMPILE")

     assert grad_fn._extra_vars == []