Exemple #1
0
    def __init__(self, borrow=None, force_floatX=False, context=None):
        """
        Attach a tracing Context to this object, creating one if needed.

        Arguments
        ---------

        borrow : tuple of objects
            If an object in this tuple is encountered while tracing the
            function, then its symbolic representation will alias that object's
            memory location. This means that *inplace* operations on the Python
            (likely NumPy) object will affect the symbolic function.
            Forwarded (via `utils.as_seq`) as the `borrowable` argument of a
            newly created Context; not consulted when `context` is supplied.

        force_floatX : bool
            If True, floats and float NumPy ndarrays will be cast to the dtype
            specified at theano.config.floatX when forming symbolic shared
            variables, if they do not have it already. Objects in `borrow`
            are never cast. Not consulted when `context` is supplied.

        context : Context or None
            An existing Context to reuse as-is. When None, a new Context is
            built from `borrow` and `force_floatX`.

        Raises
        ------

        TypeError
            If `context` is neither None nor a Context instance.

        """
        # Reject anything that is neither "build one for me" nor a Context.
        if context is not None and not isinstance(context, Context):
            raise TypeError(
                'Received unrecognized Context: {0}'.format(context))

        if context is None:
            borrowable = utils.as_seq(borrow, tuple)
            context = Context(borrowable=borrowable,
                              force_floatX=force_floatX)

        self.context = context
Exemple #2
0
    def __init__(self,
                 pyfn,
                 context=None,
                 force_floatX=False,
                 borrowable=None,
                 ignore=None,
                 infer_updates=False,
                 escape_on_error=False):
        """
        Store `pyfn` and compile it through a tracing Context.

        Arguments
        ---------

        pyfn : object
            The object to compile. If it is a Symbolic instance, its `pyfn`
            attribute is unwrapped and used instead.

        context : Context or None
            An existing Context to reuse. When None, a new Context is built
            from the remaining keyword arguments; when supplied, those
            arguments are not consulted.

        force_floatX : bool
            Forwarded unchanged to the new Context.

        borrowable : tuple of objects
            If an object in this tuple is encountered while tracing the
            function, then its symbolic representation will alias that object's
            memory location. This means that *inplace* operations on the Python
            (likely NumPy) object will affect the symbolic function.
            Normalized with `utils.as_seq` before being passed to the Context.

        ignore : tuple of objects
            Normalized with `utils.as_seq` and forwarded to the new Context.

        infer_updates : bool
            Forwarded unchanged to the new Context.

        escape_on_error : bool
            Forwarded unchanged to the new Context.

        Raises
        ------

        TypeError
            If `context` is neither None nor a Context instance.

        """

        if context is None:
            context = Context(borrowable=utils.as_seq(borrowable, tuple),
                              ignore=utils.as_seq(ignore, tuple),
                              force_floatX=force_floatX,
                              infer_updates=infer_updates,
                              escape_on_error=escape_on_error)
        elif not isinstance(context, Context):
            # An explicit raise (rather than `assert`) keeps this validation
            # active when Python runs with -O, which strips assert statements.
            raise TypeError(
                'Received unrecognized Context: {0}'.format(context))
        self.context = context

        # Unwrap an already-wrapped function so the raw object is stored.
        if isinstance(pyfn, Symbolic):
            pyfn = pyfn.pyfn
        self._pyfn = pyfn

        # NOTE(review): reads `self.pyfn`, not the local; presumably an
        # accessor for `_pyfn` defined elsewhere on the class -- confirm.
        self._symfn = self.context.recompile(self.pyfn)
Exemple #3
0
def test_recalculate():
    """Trace `compute_stuff` and check the function rebuilt from the trace.

    The rebuilt function must give 12 for the traced input and something
    other than 12 for that input shifted by one.
    """
    zeros = np.zeros(3)
    ctx = Context()

    # Check the plain Python function before tracing it.
    assert compute_stuff(zeros) == 12

    traced = ctx.call(compute_stuff, (zeros, ))
    assert traced == 12

    recompute = recalculate_fn(ctx, traced, zeros)
    assert recompute(zeros) == 12
    assert recompute(zeros + 1) != 12
Exemple #4
0
def test_low_integer_constants():
    """Check that a small int literal can be used both as data and as `axis`.

    The traced lambda uses the constant `1` as an addend and again as the
    `axis` argument of `sum`.
    """
    one = 2 - 1
    # CPython re-uses ids of low integer constants
    # which kind of plays hell with the id-tracking done in the Context object.
    # Compare ids explicitly instead of writing `one is 1`: an identity test
    # against a literal emits SyntaxWarning on Python 3.8+.
    assert id(one) == id(1)
    # the current implementation crashes here because the addition creates a
    # shadow for the constant `1`, which then gets picked up by the axis
    # argument, and causes Theano to barf because axis can't be a symbolic
    # variable.
    r = Context().call(lambda x: (1 + x).sum(axis=1), (np.ones((2, 3)), ))
    assert np.allclose(r, [6, 6])
Exemple #5
0
def test_loop():
    """Check that non-data-dependent loops are unrolled properly."""
    start = np.zeros(3)
    ctx = Context()
    traced = ctx.call(repeat_double, (start, 4))

    # Rebuild the computation from the trace and run it on a shifted input.
    recompute = recalculate_fn(ctx, traced, start)
    shifted = recompute(start + 1)

    assert np.all(traced == 0)
    # presumably 1 doubled four times; verify against repeat_double
    assert np.all(shifted == 16)
Exemple #6
0
def test_grad():
    """Check the gradient function built from a traced `compute_stuff` call.

    Both the input and the traced output must be registered in the context's
    `svars` by id, and the gradient must equal 8, 16 and 24 everywhere for
    the input shifted by 0, 1 and 2 respectively.
    """
    point = np.zeros(3)

    ctx = Context()
    traced = ctx.call(compute_stuff, (point, ))

    # The context tracks objects by id; both ends of the trace must be known.
    for tracked in (point, traced):
        assert id(tracked) in ctx.svars

    gradient = grad_fn(ctx, traced, point)

    for offset, expected in ((0, 8), (1, 16), (2, 24)):
        assert np.all(gradient(point + offset) == expected)