Beispiel #1
0
    def __init__(self, bsdf_map, mesh_map, bsdf_ad_keys, mesh_ad_keys,\
        lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8):
        """Set up the optimizer and enable gradients on the selected parameters.

        Args:
            bsdf_map: Mapping indexed by the entries of ``bsdf_ad_keys``; each
                value exposes a ``reflectance.data`` array.
            mesh_map: Mapping indexed by the entries of ``mesh_ad_keys``; each
                value exposes a ``vertex_positions`` array.
            bsdf_ad_keys: Keys of ``bsdf_map`` whose reflectance data should
                require gradients.
            mesh_ad_keys: Keys of ``mesh_map`` whose vertex positions should
                require gradients.
            lr: Learning rate.
            beta_1: Stored as ``self.beta_1`` (default 0.9).
            beta_2: Stored as ``self.beta_2`` (default 0.999).
            epsilon: Stored as ``self.epsilon`` (default 1e-8).

        For every key, ``self.state[key]`` is initialized to a pair of
        detached, zero-filled arrays with the same type and slice count as the
        parameter (presumably the running first/second moment estimates --
        confirm against the update step). Both key lists write into the same
        ``self.state`` dictionary.
        """
        from enoki.cuda_autodiff import Float32 as Float
        # Ensure that the JIT compiler does not merge 'lr' into the PTX code
        # (this would trigger a recompile every time it is changed)
        self.lr = lr
        self.lr_v = ek.detach(Float(lr, literal=False))

        self.bsdf_map = bsdf_map
        self.mesh_map = mesh_map
        self.bsdf_ad_keys = bsdf_ad_keys
        self.mesh_ad_keys = mesh_ad_keys
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        # Step counter; starts at zero.
        self.t = 0
        # Per-key optimizer state: key -> (zero array, zero array).
        self.state = {}
        for k in bsdf_ad_keys:
            # Track gradients of the reflectance data and allocate matching
            # zero-initialized state of the same array type and size.
            ek.set_requires_gradient(bsdf_map[k].reflectance.data)
            size = ek.slices(bsdf_map[k].reflectance.data)
            self.state[k] = (ek.detach(
                type(bsdf_map[k].reflectance.data).zero(size)),
                             ek.detach(
                                 type(
                                     bsdf_map[k].reflectance.data).zero(size)))
        for k in mesh_ad_keys:
            # Same procedure for the mesh vertex positions.
            ek.set_requires_gradient(mesh_map[k].vertex_positions)
            size = ek.slices(mesh_map[k].vertex_positions)
            self.state[k] = (ek.detach(
                type(mesh_map[k].vertex_positions).zero(size)),
                             ek.detach(
                                 type(
                                     mesh_map[k].vertex_positions).zero(size)))
Beispiel #2
0
def test17_cos(m):
    """Check ``ek.cos`` and its reverse-mode gradient (``-sin``) on 10 samples."""
    arg = ek.linspace(m.Float, 0.01, 10, 10)
    ek.enable_grad(arg)
    result = ek.cos(arg)
    ek.backward(result)

    # Primal value must match a plain evaluation on the detached input.
    expected_value = ek.cos(ek.detach(arg))
    assert ek.allclose(ek.detach(result), expected_value)

    # d/dx cos(x) = -sin(x)
    expected_grad = -ek.sin(ek.detach(arg))
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #3
0
    def step(self):
        """Apply one Adam-style update to every entry of ``self.params``.

        Increments the step counter, derives the bias-corrected step size from
        ``lr``, ``beta_1`` and ``beta_2``, then updates the moment estimates in
        ``self.state`` and replaces each parameter with a new gradient-enabled
        array. Parameters whose gradient has zero slices are skipped; state is
        reset through ``self._reset`` when the gradient size no longer matches
        the stored state.
        """
        self.t += 1

        from mitsuba.core import Float
        numerator = self.lr * ek.sqrt(1 - self.beta_2**self.t)
        denominator = 1 - self.beta_1**self.t
        # Wrapped as a non-literal array and detached from the AD graph.
        lr_t = ek.detach(Float(numerator / denominator, literal=False))

        for key, param in self.params.items():
            grad = ek.gradient(param)
            count = ek.slices(grad)

            if count == 0:
                continue
            if count != ek.slices(self.state[key][0]):
                # Reset state if data size has changed
                self._reset(key)

            m_prev, v_prev = self.state[key]
            m_new = self.beta_1 * m_prev + (1 - self.beta_1) * grad
            v_new = self.beta_2 * v_prev + (1 - self.beta_2) * ek.sqr(grad)
            self.state[key] = (m_new, v_new)

            delta = lr_t * m_new / (ek.sqrt(v_new) + self.epsilon)
            updated = type(param)(ek.detach(param) - delta)
            ek.set_requires_gradient(updated)
            self.params[key] = updated
Beispiel #4
0
def test23_exp(m):
    """Check ``exp(x*x)`` and its reverse-mode gradient ``2*x*exp(x*x)``."""
    arg = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(arg)
    result = ek.exp(arg * arg)
    ek.backward(result)

    reference = ek.exp(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = 2 * ek.detach(arg) * reference
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #5
0
def test03_sub_mul(m):
    """Reverse-mode gradients of ``a * b - c`` with respect to a, b and c."""
    lhs = m.Float(2)
    rhs = m.Float(3)
    offset = m.Float(4)
    ek.enable_grad(lhs, rhs, offset)

    out = lhs * rhs - offset
    ek.backward(out)

    # d(a*b - c)/da = b, d/db = a, d/dc = -1
    assert ek.grad(lhs) == ek.detach(rhs)
    assert ek.grad(rhs) == ek.detach(lhs)
    assert ek.grad(offset) == -1
Beispiel #6
0
def test24_log(m):
    """Check ``log(x*x)`` and its reverse-mode gradient ``2/x``."""
    arg = ek.linspace(m.Float, 0.01, 1, 10)
    ek.enable_grad(arg)
    result = ek.log(arg * arg)
    ek.backward(result)

    reference = ek.log(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = 2 / ek.detach(arg)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #7
0
def test20_scatter_add_rev(m):
    """Reverse-mode differentiation through two ``ek.scatter_add`` calls.

    Runs three configurations (``i`` = 0, 1, 2): gradients are enabled on the
    target buffer when ``i`` is even and on the scattered sources ``x``/``y``
    when ``i // 2 == 0``. After back-propagating ``dot(buf2, buf2)``, the
    gradients are compared against reference values, or against zero for
    variables that were not tracked.
    """
    for i in range(3):
        track_buf = i % 2 == 0
        track_sources = i // 2 == 0

        idx1 = ek.arange(m.UInt, 5)
        idx2 = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if track_buf:
            ek.enable_grad(buf)
        if track_sources:
            ek.enable_grad(x, y)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        target = m.Float(buf)
        ek.scatter_add(target, x, idx1)
        ek.scatter_add(target, y, idx2)

        ref_buf = m.Float(0.0000, 0.2500, 0.5000, 1.7500, 2.3333, 1.6667,
                          2.0000, 0.0000, 0.0000, 0.0000)

        assert ek.allclose(ref_buf, target, atol=1e-4)
        # NOTE(review): `buf` is expected to hold the scattered values too,
        # which suggests `m.Float(buf)` shares storage with `buf` -- confirm.
        assert ek.allclose(ref_buf, buf, atol=1e-4)

        loss = ek.dot_async(target, target)

        print(ek.graphviz_str(loss))

        ek.backward(loss)

        ref_x = m.Float(0.0000, 0.5000, 1.0000, 3.5000, 4.6667)
        ref_y = m.Float(3.5000, 4.6667, 3.3333, 4.0000)

        if track_sources:
            assert ek.allclose(ek.grad(y), ek.detach(ref_y), atol=1e-4)
            assert ek.allclose(ek.grad(x), ek.detach(ref_x), atol=1e-4)
        else:
            assert ek.grad(x) == 0
            assert ek.grad(y) == 0

        if track_buf:
            assert ek.allclose(ek.grad(buf), ek.detach(ref_buf) * 2, atol=1e-4)
        else:
            assert ek.grad(buf) == 0
Beispiel #8
0
def test15_abs(m):
    """Check ``ek.abs`` and its reverse-mode gradient at -2 and 2."""
    arg = m.Float(-2, 2)
    ek.enable_grad(arg)
    result = ek.abs(arg)
    ek.backward(result)

    expected_value = [2, 2]
    expected_grad = [-1, 1]  # sign of the input
    assert ek.allclose(ek.detach(result), expected_value)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #9
0
def test14_rsqrt(m):
    """Check ``ek.rsqrt`` and its reverse-mode gradient at 1, 1/4 and 1/16."""
    arg = m.Float(1, 0.25, 0.0625)
    ek.enable_grad(arg)
    result = ek.rsqrt(arg)
    ek.backward(result)

    expected_value = [1, 2, 4]
    expected_grad = [-0.5, -4, -32]  # -1/2 * x^(-3/2)
    assert ek.allclose(ek.detach(result), expected_value)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #10
0
def test13_sqrt(m):
    """Check ``ek.sqrt`` and its reverse-mode gradient at 1, 4 and 16."""
    arg = m.Float(1, 4, 16)
    ek.enable_grad(arg)
    result = ek.sqrt(arg)
    ek.backward(result)

    expected_value = [1, 2, 4]
    expected_grad = [0.5, 0.25, 0.125]  # 1 / (2 * sqrt(x))
    assert ek.allclose(ek.detach(result), expected_value)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #11
0
 def _reset(self, key):
     """Zero-initialize the optimizer state stored for parameter ``key``.

     Does nothing when ``self.momentum`` is zero. Otherwise ``self.state[key]``
     becomes a detached, zero-filled array with the same type and slice count
     as ``self.params[key]``.
     """
     if self.momentum != 0:
         param = self.params[key]
         count = ek.slices(param)
         zeros = type(param).zero(count)
         self.state[key] = ek.detach(zeros)
Beispiel #12
0
 def set_learning_rate(self, lr):
     """Store a new learning rate.

     ``self.lr`` keeps the value as passed in; ``self.lr_v`` holds the same
     value wrapped in a detached, non-literal ``Float``.
     """
     from mitsuba.core import Float

     self.lr = lr
     # Ensure that the JIT compiler does not merge 'lr' into the PTX code
     # (this would trigger a recompile every time it is changed)
     non_literal_lr = Float(lr, literal=False)
     self.lr_v = ek.detach(non_literal_lr)
Beispiel #13
0
def test51_scatter_reduce_fwd_eager(m):
    """Forward-mode gradients through ``ek.scatter_reduce`` inside ``EagerMode``.

    Runs three configurations (``i`` = 0, 1, 2): the target buffer gets a unit
    input gradient when ``i`` is even, the scattered sources ``x``/``y`` when
    ``i // 2 == 0``. The value and gradient of ``dot(buf2, buf2)`` are then
    compared with reference numbers.
    """
    with EagerMode():
        for i in range(3):
            seed_buf = i % 2 == 0
            seed_sources = i // 2 == 0

            idx1 = ek.arange(m.UInt, 5)
            idx2 = ek.arange(m.UInt, 4) + 3

            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if seed_buf:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)
            if seed_sources:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            x.label = "x"
            y.label = "y"
            buf.label = "buf"

            target = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, target, x, idx1)
            ek.scatter_reduce(ek.ReduceOp.Add, target, y, idx2)

            loss = ek.dot_async(target, target)

            # Verified against Mathematica
            assert ek.allclose(ek.detach(loss), 15.5972)
            assert ek.allclose(ek.grad(loss),
                               (25.1667 if seed_sources else 0) +
                               (17 if seed_buf else 0))
Beispiel #14
0
def test05_hsum_0_rev(m):
    """Reverse-mode gradient of ``hsum_async(x*x)`` over 10 samples in [0, 1]."""
    samples = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(samples)
    total = ek.hsum_async(samples * samples)
    ek.backward(total)

    # The reduction yields a single entry equal to 95/27.
    assert len(total) == 1 and ek.allclose(total, 95.0 / 27.0)
    # d/dx sum(x^2) = 2x
    assert ek.allclose(ek.grad(samples), 2 * ek.detach(samples))
Beispiel #15
0
def test06_hsum_0_fwd(m):
    """Forward-mode gradient of ``hsum_async(x*x)`` over 10 samples in [0, 1]."""
    samples = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(samples)
    total = ek.hsum_async(samples * samples)
    ek.forward(samples)

    # The reduction yields a single entry equal to 95/27.
    assert len(total) == 1 and ek.allclose(ek.detach(total), 95.0 / 27.0)
    # The forward-propagated gradient is also a single entry, expected to be 10.
    assert len(ek.grad(total)) == 1 and ek.allclose(ek.grad(total), 10)
Beispiel #16
0
def test42_suspend_resume(m):
    """Suspending and resuming gradient tracking on two ``Array3f`` values.

    Checks the ``grad_enabled`` / ``grad_suspended`` flags before suspension,
    while suspended, and after resuming, then verifies that a product computed
    after resuming back-propagates the expected gradients.
    """
    p = m.Array3f(1, 2, 3)
    q = m.Array3f(3, 2, 1)
    ek.enable_grad(p, q)
    assert ek.grad_enabled(p) and ek.grad_enabled(q)
    assert not ek.grad_suspended(p) and not ek.grad_suspended(q)

    ek.suspend_grad(p, q)
    assert not ek.grad_enabled(p) and not ek.grad_enabled(q)
    assert ek.grad_suspended(p) and ek.grad_suspended(q)
    # Product formed while tracking is suspended; the result is not used.
    product_while_suspended = p * q

    ek.resume_grad(p, q)
    assert ek.grad_enabled(p) and ek.grad_enabled(q)
    assert not ek.grad_suspended(p) and not ek.grad_suspended(q)

    product = p * q
    ek.backward(product)
    assert ek.grad(p) == ek.detach(q)
    assert ek.grad(q) == ek.detach(p)
    ek.suspend_grad(p, q)  # validate reference counting of suspended variables
Beispiel #17
0
def test31_atan(m):
    """Check ``atan(x*x)`` and its reverse-mode gradient on [-0.8, 0.8]."""
    arg = ek.linspace(m.Float, -0.8, 0.8, 10)
    ek.enable_grad(arg)
    result = ek.atan(arg * arg)
    ek.backward(result)

    reference = ek.atan(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = m.Float(-1.13507, -1.08223, -0.855508, -0.53065,
                            -0.177767, 0.177767, 0.53065, 0.855508, 1.08223,
                            1.13507)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #18
0
def test30_acos(m):
    """Check ``acos(x*x)`` and its reverse-mode gradient on [-0.8, 0.8]."""
    arg = ek.linspace(m.Float, -0.8, 0.8, 10)
    ek.enable_grad(arg)
    result = ek.acos(arg * arg)
    ek.backward(result)

    reference = ek.acos(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = m.Float(2.08232, 1.3497, 0.906755, 0.534687, 0.177783,
                            -0.177783, -0.534687, -0.906755, -1.3497,
                            -2.08232)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #19
0
def test29_asin(m):
    """Check ``asin(x*x)`` and its reverse-mode gradient on [-0.8, 0.8]."""
    arg = ek.linspace(m.Float, -0.8, 0.8, 10)
    ek.enable_grad(arg)
    result = ek.asin(arg * arg)
    ek.backward(result)

    reference = ek.asin(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = m.Float(-2.08232, -1.3497, -0.906755, -0.534687,
                            -0.177783, 0.177783, 0.534687, 0.906755, 1.3497,
                            2.08232)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #20
0
def test28_tan(m):
    """Check ``tan(x*x)`` and its reverse-mode gradient on [0, 1]."""
    arg = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(arg)
    result = ek.tan(arg * arg)
    ek.backward(result)

    reference = ek.tan(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = m.Float(0.0, 0.222256, 0.44553, 0.674965, 0.924494,
                            1.22406, 1.63572, 2.29919, 3.58948, 6.85104)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #21
0
def test25_pow(m):
    """Reverse-mode gradients of ``x**y`` w.r.t. base and exponent (y = 2)."""
    base = ek.linspace(m.Float, 1, 10, 10)
    exponent = ek.full(m.Float, 2.0, 10)
    ek.enable_grad(base, exponent)

    power = base**exponent
    ek.backward(power)

    # d/dx x^2 = 2x
    assert ek.allclose(ek.grad(base), ek.detach(base) * 2)

    expected_exponent_grad = m.Float(0.0, 2.77259, 9.88751, 22.1807, 40.2359,
                                     64.5033, 95.3496, 133.084, 177.975,
                                     230.259)
    assert ek.allclose(ek.grad(exponent), expected_exponent_grad)
Beispiel #22
0
def test27_sec(m):
    """Check ``sec(x*x)`` and its reverse-mode gradient on [1, 2]."""
    arg = ek.linspace(m.Float, 1, 2, 10)
    ek.enable_grad(arg)
    result = ek.sec(arg * arg)
    ek.backward(result)

    reference = ek.sec(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = m.Float(5.76495, 19.2717, 412.208, 61.794, 10.3374,
                            3.64885, 1.35811, -0.0672242, -1.88437, -7.08534)
    assert ek.allclose(ek.grad(arg), expected_grad)
Beispiel #23
0
def test26_csc(m):
    """Check ``csc(x*x)`` and its reverse-mode gradient on [1, 2].

    The gradient comparison uses a relative tolerance of 5e-5.
    """
    arg = ek.linspace(m.Float, 1, 2, 10)
    ek.enable_grad(arg)
    result = ek.csc(arg * arg)
    ek.backward(result)

    reference = ek.csc(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = m.Float(-1.52612, -0.822733, -0.189079, 0.572183, 1.88201,
                            5.34839, 24.6017, 9951.86, 20.1158, 4.56495)
    assert ek.allclose(ek.grad(arg), expected_grad, rtol=5e-5)
Beispiel #24
0
def test28_cot(m):
    """Check ``cot(x*x)`` and its reverse-mode gradient on [1, 2].

    The gradient comparison uses a relative tolerance of 5e-5.
    """
    arg = ek.linspace(m.Float, 1, 2, 10)
    ek.enable_grad(arg)
    result = ek.cot(arg * arg)
    ek.backward(result)

    reference = ek.cot(ek.sqr(ek.detach(arg)))
    assert ek.allclose(result, reference)

    expected_grad = m.Float(-2.82457, -2.49367, -2.45898, -2.78425, -3.81687,
                            -7.12557, -26.3248, -9953.63, -22.0932, -6.98385)
    assert ek.allclose(ek.grad(arg), expected_grad, rtol=5e-5)
Beispiel #25
0
def export_(a, migrate_to_host, version):
    """Describe the memory of array ``a`` as a dictionary.

    The result has the keys ``shape``, ``strides``, ``typestr``, ``data``,
    ``version``, ``device`` and ``owner`` (the layout resembles NumPy's array
    interface protocol -- confirm against the callers).

    Args:
        a: Array to export. Non-JIT arrays are described in place; JIT arrays
            are detached (if differentiable), raveled and evaluated first.
        migrate_to_host: When true and the raveled array is a CUDA array, it
            is migrated to host memory before being described.
        version: Value stored verbatim under the ``version`` key.

    Returns:
        The descriptor dictionary. ``owner`` references the object whose
        ``data_()`` pointer is reported.
    """
    dims = _ek.shape(a)
    ndim = len(dims)
    dims = tuple(reversed(dims))

    is_jit = a.IsJIT

    # Non-JIT arrays use F-style strides (first axis varies fastest), JIT
    # arrays use C-style strides (last axis varies fastest).
    axis_order = reversed(range(ndim)) if is_jit else range(ndim)
    step = a.Type.Size
    strides = [0] * ndim
    for axis in axis_order:
        strides[axis] = step
        step *= dims[axis]

    if not is_jit:
        # Array is already contiguous in memory -- document its structure
        return {
            'shape': dims,
            'strides': tuple(strides),
            'typestr': '<' + a.Type.NumPy,
            'data': (a.data_(), False),
            'version': version,
            'device': -1,
            'owner': a
        }

    # JIT array -- requires extra transformations
    flat = _ek.ravel(_ek.detach(a) if a.IsDiff else a)
    _ek.eval(flat)

    if flat.IsCUDA and migrate_to_host:
        if flat is a:
            # Work on a fresh object so that `a` itself is not the one
            # being migrated.
            flat = type(a)(flat)
        flat = flat.migrate_(_ek.AllocType.Host)
        _ek.sync_thread()
    elif flat.IsLLVM:
        _ek.sync_thread()

    return {
        'shape': dims,
        'strides': tuple(strides),
        'typestr': '<' + a.Type.NumPy,
        'data': (flat.data_(), False),
        'version': version,
        'device': _ek.device(flat),
        'owner': flat
    }
Beispiel #26
0
    def step(self):
        """Apply one gradient-descent update to every entry of ``self.params``.

        Parameters whose gradient has zero slices are skipped. With a non-zero
        ``self.momentum`` the stored state is first decayed and accumulated
        with the gradient (and reset via ``self._reset`` if its size no longer
        matches); otherwise the raw gradient is used. Each parameter is
        replaced by a new gradient-enabled array, and ``self.params.update()``
        is called once at the end.
        """
        for key, param in self.params.items():
            grad = ek.gradient(param)
            count = ek.slices(grad)
            if count == 0:
                continue

            if self.momentum == 0:
                direction = grad
            else:
                if count != ek.slices(self.state[key]):
                    # Reset state if data size has changed
                    self._reset(key)

                self.state[key] = self.momentum * self.state[key] + grad
                direction = self.state[key]

            updated = type(param)(ek.detach(param) - self.lr_v * direction)
            ek.set_requires_gradient(updated)
            self.params[key] = updated
        self.params.update()
Beispiel #27
0
 def backward(ctx, grad_output):
     """Back-propagate ``grad_output`` and return one gradient per input.

     Sets ``grad_output`` (converted to a detached ``Float``) as the gradient
     of ``ctx.output``, runs ``Float.backward()``, and collects the gradient
     of every entry of ``ctx.inputs`` via ``.torch()``; ``None`` inputs yield
     ``None``. ``ctx.output`` and ``ctx.inputs`` are deleted afterwards and
     ``ek.cuda_malloc_trim()`` is called.

     Any exception is printed and then re-raised.
     """
     try:
         seed = ek.detach(Float(grad_output))
         ek.set_gradient(ctx.output, seed)
         Float.backward()
         result = tuple(None if inp is None else ek.gradient(inp).torch()
                        for inp in ctx.inputs)
         # Drop the references held by the context before trimming.
         del ctx.output
         del ctx.inputs
         ek.cuda_malloc_trim()
         return result
     except Exception as exc:
         print("render_torch(): critical exception during "
               "backward pass: %s" % str(exc))
         raise exc
Beispiel #28
0
    def step(self):
        """ Take a gradient step

        Increments the step counter ``self.t``, computes a bias-corrected step
        size from ``self.lr``, ``self.beta_1`` and ``self.beta_2``, and then
        applies an Adam-style update to ``reflectance.data`` of every BSDF in
        ``self.bsdf_ad_keys`` and to ``vertex_positions`` of every mesh in
        ``self.mesh_ad_keys``. Each updated array is re-created with gradient
        tracking enabled and written back to its owner.

        Raises:
            AssertionError: if a gradient's slice count differs from that of
                the stored state for the same key.
        """
        self.t += 1
        from enoki.cuda_autodiff import Float32 as Float
        # Step size with bias correction, wrapped as a non-literal array and
        # detached from the AD graph.
        lr_t = ek.detach(
            Float(self.lr * ek.sqrt(1 - self.beta_2**self.t) /
                  (1 - self.beta_1**self.t),
                  literal=False))

        for k in self.bsdf_ad_keys:
            g_p = ek.gradient(self.bsdf_map[k].reflectance.data)
            size = ek.slices(g_p)
            # The state must have been allocated with a matching size.
            assert (size == ek.slices(self.state[k][0]))
            # Exponential moving averages of the gradient and its square.
            m_tp, v_tp = self.state[k]
            m_t = self.beta_1 * m_tp + (1 - self.beta_1) * g_p
            v_t = self.beta_2 * v_tp + (1 - self.beta_2) * ek.sqr(g_p)
            self.state[k] = (m_t, v_t)
            u = ek.detach(self.bsdf_map[k].reflectance.data) - lr_t * m_t / (
                ek.sqrt(v_t) + self.epsilon)
            # Rebuild the value with the parameter's own type and re-enable
            # gradient tracking before writing it back.
            u = type(self.bsdf_map[k].reflectance.data)(u)
            ek.set_requires_gradient(u)
            self.bsdf_map[k].reflectance.data = u

        for k in self.mesh_ad_keys:
            # Same update rule, applied to the mesh vertex positions.
            g_p = ek.gradient(self.mesh_map[k].vertex_positions)
            size = ek.slices(g_p)
            assert (size == ek.slices(self.state[k][0]))
            m_tp, v_tp = self.state[k]
            m_t = self.beta_1 * m_tp + (1 - self.beta_1) * g_p
            v_t = self.beta_2 * v_tp + (1 - self.beta_2) * ek.sqr(g_p)
            self.state[k] = (m_t, v_t)
            u = ek.detach(self.mesh_map[k].vertex_positions) - lr_t * m_t / (
                ek.sqrt(v_t) + self.epsilon)
            u = type(self.mesh_map[k].vertex_positions)(u)
            ek.set_requires_gradient(u)
            self.mesh_map[k].vertex_positions = u
Beispiel #29
0
def test_ad_operations(package):
    """AD helper functions applied to a custom ``ENOKI_STRUCT`` container.

    Covers ``grad_enabled``, ``enable_grad``, ``detach``, ``forward``,
    ``grad``, ``set_grad`` and ``accum_grad`` on a struct holding an
    ``Array3f`` field ``a`` and a ``Float`` field ``b``.
    """
    Float = package.Float
    Array3f = package.Array3f
    prepare(package)

    class MyStruct:
        ENOKI_STRUCT = {'a': Array3f, 'b': Float}

        def __init__(self):
            self.a = Array3f()
            self.b = Float()

    record = ek.zero(MyStruct, 4)

    # Freshly created: nothing is tracked.
    assert not ek.grad_enabled(record.a)
    assert not ek.grad_enabled(record.b)
    assert not ek.grad_enabled(record)

    # Enabling on the struct enables every field.
    ek.enable_grad(record)
    assert ek.grad_enabled(record.a)
    assert ek.grad_enabled(record.b)
    assert ek.grad_enabled(record)

    # Detaching yields a struct whose fields are untracked.
    detached = ek.detach(record)
    assert not ek.grad_enabled(detached.a)
    assert not ek.grad_enabled(detached.b)
    assert not ek.grad_enabled(detached)

    # Forward propagation from a scalar: d(a + x)/dx = 1, d(b + x*x)/dx = 2x = 8.
    seed = Float(4.0)
    ek.enable_grad(seed)
    record.a += seed
    record.b += seed * seed
    ek.forward(seed)
    grads = ek.grad(record)
    assert grads.a == 1
    assert grads.b == 8

    # set_grad overwrites the stored gradients ...
    ek.set_grad(record, 5.0)
    grads = ek.grad(record)
    assert grads.a == 5.0
    assert grads.b == 5.0

    # ... while accum_grad adds to them.
    ek.accum_grad(record, 5.0)
    grads = ek.grad(record)
    assert grads.a == 10.0
    assert grads.b == 10.0
Beispiel #30
0
def test21_scatter_add_fwd(m):
    """Forward-mode differentiation through two ``ek.scatter_add`` calls.

    Runs three configurations (``i`` = 0, 1, 2): the target buffer gets a unit
    input gradient when ``i`` is even, the scattered sources ``x``/``y`` when
    ``i // 2 == 0``. The seeded variables are enqueued, the graph is traversed
    in forward mode, and the value and gradient of ``dot(buf2, buf2)`` are
    compared with reference numbers.
    """
    for i in range(3):
        seed_buf = i % 2 == 0
        seed_sources = i // 2 == 0

        idx1 = ek.arange(m.UInt, 5)
        idx2 = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if seed_buf:
            ek.enable_grad(buf)
            ek.set_grad(buf, 1)
        if seed_sources:
            ek.enable_grad(x, y)
            ek.set_grad(x, 1)
            ek.set_grad(y, 1)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        target = m.Float(buf)
        ek.scatter_add(target, x, idx1)
        ek.scatter_add(target, y, idx2)

        loss = ek.dot_async(target, target)

        if seed_buf:
            ek.enqueue(buf)
        if seed_sources:
            ek.enqueue(x, y)

        ek.traverse(m.Float, reverse=False)

        # Verified against Mathematica
        assert ek.allclose(ek.detach(loss), 15.5972)
        assert ek.allclose(ek.grad(loss),
                           (25.1667 if seed_sources else 0) +
                           (17 if seed_buf else 0))