Esempio n. 1
0
    def test_create_numpy_strict_false(self):
        """Check which values SharedVariable accepts at creation with strict=False.

        Every case uses a float64 Tensor type built with
        ``broadcastable=[False]``.  An ndarray of floats, a list of floats
        and a list of ints must all be accepted without error; a dict must
        be rejected with ``TypeError``.
        """

        def _create(value):
            # Each case builds its own Tensor type so the cases stay
            # independent of one another.
            return SharedVariable(
                name='u',
                type=Tensor(broadcastable=[False], dtype='float64'),
                value=value,
                strict=False)

        # here the value is perfect, and we're not strict about it,
        # so creation should work
        _create(numpy.asarray([1., 2.]))

        # here the value is castable, and we're not strict about it,
        # so creation should work
        _create([1., 2.])

        # here the value is castable, and we're not strict about it,
        # so creation should work
        _create([1, 2])  # different dtype and not a numpy array

        # here the value is not castable, and we're not strict about it,
        # this is beyond strictness, it must fail
        try:
            _create(dict())  # not an array by any stretch
        except TypeError:
            pass
        else:
            # Explicit raise rather than `assert 0`: it survives `python -O`
            # and tells the reader what went wrong.
            raise AssertionError(
                "SharedVariable accepted a dict value with strict=False; "
                "expected TypeError")
Esempio n. 2
0
    def test_use_numpy_strict_false(self):
        """Check set_value/get_value on a SharedVariable created with strict=False.

        Covers three things: a list of ints assigned to a float64 variable
        comes back as a float64 ndarray, assigning a string raises
        ``ValueError``, and assigning a float64 array with ``borrow=True``
        stores that very object (no copy).
        """

        # here the value is perfect, and we're not strict about it,
        # so creation should work
        u = SharedVariable(name='u',
                           type=Tensor(broadcastable=[False], dtype='float64'),
                           value=numpy.asarray([1., 2.]),
                           strict=False)

        # check that assignments to value are cast properly
        u.set_value([3, 4])
        assert type(u.get_value()) is numpy.ndarray
        assert str(u.get_value(borrow=True).dtype) == 'float64'
        assert numpy.all(u.get_value() == [3, 4])

        # check that assignments of nonsense fail
        try:
            u.set_value('adsf')
        except ValueError:
            pass
        else:
            # Explicit raise rather than `assert 0`: it survives `python -O`
            # and tells the reader what went wrong.
            raise AssertionError(
                "set_value('adsf') succeeded on a float64 SharedVariable; "
                "expected ValueError")

        # check that an assignment of a perfect value results in no copying
        uval = theano._asarray([5, 6, 7, 8], dtype='float64')
        u.set_value(uval, borrow=True)
        assert u.get_value(borrow=True) is uval