def __init__(self, bsdf_map, mesh_map, bsdf_ad_keys, mesh_ad_keys,
             lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8):
    from enoki.cuda_autodiff import Float32 as Float
    # Ensure that the JIT compiler does not merge 'lr' into the PTX code
    # (this would trigger a recompile every time it is changed)
    self.lr = lr
    self.lr_v = ek.detach(Float(lr, literal=False))

    self.bsdf_map = bsdf_map
    self.mesh_map = mesh_map
    self.bsdf_ad_keys = bsdf_ad_keys
    self.mesh_ad_keys = mesh_ad_keys
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.t = 0

    self.state = {}
    for k in bsdf_ad_keys:
        ek.set_requires_gradient(bsdf_map[k].reflectance.data)
        size = ek.slices(bsdf_map[k].reflectance.data)
        self.state[k] = (
            ek.detach(type(bsdf_map[k].reflectance.data).zero(size)),
            ek.detach(type(bsdf_map[k].reflectance.data).zero(size)))
    for k in mesh_ad_keys:
        ek.set_requires_gradient(mesh_map[k].vertex_positions)
        size = ek.slices(mesh_map[k].vertex_positions)
        self.state[k] = (
            ek.detach(type(mesh_map[k].vertex_positions).zero(size)),
            ek.detach(type(mesh_map[k].vertex_positions).zero(size)))
def test17_cos(m):
    x = ek.linspace(m.Float, 0.01, 10, 10)
    ek.enable_grad(x)
    y = ek.cos(x)
    ek.backward(y)
    assert ek.allclose(ek.detach(y), ek.cos(ek.detach(x)))
    assert ek.allclose(ek.grad(x), -ek.sin(ek.detach(x)))
def step(self):
    """ Take a gradient step """
    self.t += 1
    from mitsuba.core import Float
    lr_t = ek.detach(Float(self.lr * ek.sqrt(1 - self.beta_2**self.t) /
                           (1 - self.beta_1**self.t), literal=False))

    for k, p in self.params.items():
        g_p = ek.gradient(p)
        size = ek.slices(g_p)
        if size == 0:
            continue
        elif size != ek.slices(self.state[k][0]):
            # Reset state if data size has changed
            self._reset(k)

        m_tp, v_tp = self.state[k]
        m_t = self.beta_1 * m_tp + (1 - self.beta_1) * g_p
        v_t = self.beta_2 * v_tp + (1 - self.beta_2) * ek.sqr(g_p)
        self.state[k] = (m_t, v_t)

        u = ek.detach(p) - lr_t * m_t / (ek.sqrt(v_t) + self.epsilon)
        u = type(p)(u)
        ek.set_requires_gradient(u)
        self.params[k] = u
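# A minimal usage sketch for the step() above, assuming a Mitsuba 2 style
# setup: `Adam` is taken to be the optimizer class that owns step(), `params`
# a differentiable parameter map, and `render`/`image_ref` a hypothetical
# differentiable rendering pass and reference image. These names are
# assumptions for illustration only, not part of the surrounding code.
import enoki as ek

def optimize(params, render, image_ref, iterations=100):
    opt = Adam(params, lr=0.02)            # assumed constructor signature
    for _ in range(iterations):
        image = render()                    # differentiable forward pass
        loss = ek.hsum(ek.sqr(image - image_ref)) / len(image)
        ek.backward(loss)                   # accumulate parameter gradients
        opt.step()                          # apply the Adam update above
    return params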
def test23_exp(m):
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)
    y = ek.exp(x * x)
    ek.backward(y)
    exp_x = ek.exp(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, exp_x)
    assert ek.allclose(ek.grad(x), 2 * ek.detach(x) * exp_x)
def test03_sub_mul(m):
    a, b, c = m.Float(2), m.Float(3), m.Float(4)
    ek.enable_grad(a, b, c)
    d = a * b - c
    ek.backward(d)
    assert ek.grad(a) == ek.detach(b)
    assert ek.grad(b) == ek.detach(a)
    assert ek.grad(c) == -1
def test24_log(m):
    x = ek.linspace(m.Float, 0.01, 1, 10)
    ek.enable_grad(x)
    y = ek.log(x * x)
    ek.backward(y)
    log_x = ek.log(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, log_x)
    assert ek.allclose(ek.grad(x), 2 / ek.detach(x))
def test20_scatter_add_rev(m):
    for i in range(3):
        idx1 = ek.arange(m.UInt, 5)
        idx2 = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if i % 2 == 0:
            ek.enable_grad(buf)
        if i // 2 == 0:
            ek.enable_grad(x, y)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        buf2 = m.Float(buf)
        ek.scatter_add(buf2, x, idx1)
        ek.scatter_add(buf2, y, idx2)

        ref_buf = m.Float(0.0000, 0.2500, 0.5000, 1.7500, 2.3333,
                          1.6667, 2.0000, 0.0000, 0.0000, 0.0000)

        assert ek.allclose(ref_buf, buf2, atol=1e-4)
        assert ek.allclose(ref_buf, buf, atol=1e-4)

        s = ek.dot_async(buf2, buf2)
        print(ek.graphviz_str(s))

        ek.backward(s)

        ref_x = m.Float(0.0000, 0.5000, 1.0000, 3.5000, 4.6667)
        ref_y = m.Float(3.5000, 4.6667, 3.3333, 4.0000)

        if i // 2 == 0:
            assert ek.allclose(ek.grad(y), ek.detach(ref_y), atol=1e-4)
            assert ek.allclose(ek.grad(x), ek.detach(ref_x), atol=1e-4)
        else:
            assert ek.grad(x) == 0
            assert ek.grad(y) == 0

        if i % 2 == 0:
            assert ek.allclose(ek.grad(buf), ek.detach(ref_buf) * 2, atol=1e-4)
        else:
            assert ek.grad(buf) == 0
def test15_abs(m):
    x = m.Float(-2, 2)
    ek.enable_grad(x)
    y = ek.abs(x)
    ek.backward(y)
    assert ek.allclose(ek.detach(y), [2, 2])
    assert ek.allclose(ek.grad(x), [-1, 1])
def test14_rsqrt(m):
    x = m.Float(1, .25, 0.0625)
    ek.enable_grad(x)
    y = ek.rsqrt(x)
    ek.backward(y)
    assert ek.allclose(ek.detach(y), [1, 2, 4])
    assert ek.allclose(ek.grad(x), [-.5, -4, -32])
def test13_sqrt(m):
    x = m.Float(1, 4, 16)
    ek.enable_grad(x)
    y = ek.sqrt(x)
    ek.backward(y)
    assert ek.allclose(ek.detach(y), [1, 2, 4])
    assert ek.allclose(ek.grad(x), [.5, .25, .125])
def _reset(self, key):
    """ Zero-initializes the internal state associated with a parameter """
    if self.momentum == 0:
        return
    p = self.params[key]
    size = ek.slices(p)
    self.state[key] = ek.detach(type(p).zero(size))
def set_learning_rate(self, lr):
    """Set the learning rate."""
    from mitsuba.core import Float
    # Ensure that the JIT compiler does not merge 'lr' into the PTX code
    # (this would trigger a recompile every time it is changed)
    self.lr = lr
    self.lr_v = ek.detach(Float(lr, literal=False))
def test51_scatter_reduce_fwd_eager(m):
    with EagerMode():
        for i in range(3):
            idx1 = ek.arange(m.UInt, 5)
            idx2 = ek.arange(m.UInt, 4) + 3

            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if i % 2 == 0:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)
            if i // 2 == 0:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            x.label = "x"
            y.label = "y"
            buf.label = "buf"

            buf2 = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, buf2, x, idx1)
            ek.scatter_reduce(ek.ReduceOp.Add, buf2, y, idx2)

            s = ek.dot_async(buf2, buf2)

            # Verified against Mathematica
            assert ek.allclose(ek.detach(s), 15.5972)
            assert ek.allclose(ek.grad(s),
                               (25.1667 if i // 2 == 0 else 0) +
                               (17 if i % 2 == 0 else 0))
def test05_hsum_0_rev(m):
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)
    y = ek.hsum_async(x * x)
    ek.backward(y)
    assert len(y) == 1 and ek.allclose(y, 95.0 / 27.0)
    assert ek.allclose(ek.grad(x), 2 * ek.detach(x))
def test06_hsum_0_fwd(m):
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)
    y = ek.hsum_async(x * x)
    ek.forward(x)
    assert len(y) == 1 and ek.allclose(ek.detach(y), 95.0 / 27.0)
    assert len(ek.grad(y)) == 1 and ek.allclose(ek.grad(y), 10)
def test42_suspend_resume(m):
    x = m.Array3f(1, 2, 3)
    y = m.Array3f(3, 2, 1)
    ek.enable_grad(x, y)
    assert ek.grad_enabled(x) and ek.grad_enabled(y)
    assert not ek.grad_suspended(x) and not ek.grad_suspended(y)

    ek.suspend_grad(x, y)
    assert not ek.grad_enabled(x) and not ek.grad_enabled(y)
    assert ek.grad_suspended(x) and ek.grad_suspended(y)
    b = x * y

    ek.resume_grad(x, y)
    assert ek.grad_enabled(x) and ek.grad_enabled(y)
    assert not ek.grad_suspended(x) and not ek.grad_suspended(y)
    c = x * y

    ek.backward(c)
    assert ek.grad(x) == ek.detach(y)
    assert ek.grad(y) == ek.detach(x)

    ek.suspend_grad(x, y)  # validate reference counting of suspended variables
def test31_atan(m):
    x = ek.linspace(m.Float, -.8, .8, 10)
    ek.enable_grad(x)
    y = ek.atan(x * x)
    ek.backward(y)
    atan_x = ek.atan(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, atan_x)
    assert ek.allclose(ek.grad(x),
                       m.Float(-1.13507, -1.08223, -0.855508, -0.53065,
                               -0.177767, 0.177767, 0.53065, 0.855508,
                               1.08223, 1.13507))
def test30_acos(m):
    x = ek.linspace(m.Float, -.8, .8, 10)
    ek.enable_grad(x)
    y = ek.acos(x * x)
    ek.backward(y)
    acos_x = ek.acos(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, acos_x)
    assert ek.allclose(ek.grad(x),
                       m.Float(2.08232, 1.3497, 0.906755, 0.534687,
                               0.177783, -0.177783, -0.534687, -0.906755,
                               -1.3497, -2.08232))
def test29_asin(m):
    x = ek.linspace(m.Float, -.8, .8, 10)
    ek.enable_grad(x)
    y = ek.asin(x * x)
    ek.backward(y)
    asin_x = ek.asin(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, asin_x)
    assert ek.allclose(ek.grad(x),
                       m.Float(-2.08232, -1.3497, -0.906755, -0.534687,
                               -0.177783, 0.177783, 0.534687, 0.906755,
                               1.3497, 2.08232))
def test28_tan(m):
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)
    y = ek.tan(x * x)
    ek.backward(y)
    tan_x = ek.tan(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, tan_x)
    assert ek.allclose(ek.grad(x),
                       m.Float(0., 0.222256, 0.44553, 0.674965, 0.924494,
                               1.22406, 1.63572, 2.29919, 3.58948, 6.85104))
def test25_pow(m):
    x = ek.linspace(m.Float, 1, 10, 10)
    y = ek.full(m.Float, 2.0, 10)
    ek.enable_grad(x, y)
    z = x**y
    ek.backward(z)
    assert ek.allclose(ek.grad(x), ek.detach(x) * 2)
    assert ek.allclose(ek.grad(y),
                       m.Float(0., 2.77259, 9.88751, 22.1807, 40.2359,
                               64.5033, 95.3496, 133.084, 177.975, 230.259))
def test27_sec(m):
    x = ek.linspace(m.Float, 1, 2, 10)
    ek.enable_grad(x)
    y = ek.sec(x * x)
    ek.backward(y)
    sec_x = ek.sec(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, sec_x)
    assert ek.allclose(ek.grad(x),
                       m.Float(5.76495, 19.2717, 412.208, 61.794, 10.3374,
                               3.64885, 1.35811, -0.0672242, -1.88437,
                               -7.08534))
def test26_csc(m):
    x = ek.linspace(m.Float, 1, 2, 10)
    ek.enable_grad(x)
    y = ek.csc(x * x)
    ek.backward(y)
    csc_x = ek.csc(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, csc_x)
    assert ek.allclose(ek.grad(x),
                       m.Float(-1.52612, -0.822733, -0.189079, 0.572183,
                               1.88201, 5.34839, 24.6017, 9951.86, 20.1158,
                               4.56495), rtol=5e-5)
def test28_cot(m):
    x = ek.linspace(m.Float, 1, 2, 10)
    ek.enable_grad(x)
    y = ek.cot(x * x)
    ek.backward(y)
    cot_x = ek.cot(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, cot_x)
    assert ek.allclose(ek.grad(x),
                       m.Float(-2.82457, -2.49367, -2.45898, -2.78425,
                               -3.81687, -7.12557, -26.3248, -9953.63,
                               -22.0932, -6.98385), rtol=5e-5)
def export_(a, migrate_to_host, version):
    shape = _ek.shape(a)
    ndim = len(shape)
    shape = tuple(reversed(shape))

    if not a.IsJIT:
        # F-style strides
        temp, strides = a.Type.Size, [0] * ndim
        for i in range(ndim):
            strides[i] = temp
            temp *= shape[i]

        # Array is already contiguous in memory -- document its structure
        return {
            'shape': shape,
            'strides': tuple(strides),
            'typestr': '<' + a.Type.NumPy,
            'data': (a.data_(), False),
            'version': version,
            'device': -1,
            'owner': a
        }
    else:
        # C-style strides
        temp, strides = a.Type.Size, [0] * ndim
        for i in reversed(range(ndim)):
            strides[i] = temp
            temp *= shape[i]

        # JIT array -- requires extra transformations
        b = _ek.ravel(_ek.detach(a) if a.IsDiff else a)
        _ek.eval(b)

        if b.IsCUDA and migrate_to_host:
            if b is a:
                b = type(a)(b)
            b = b.migrate_(_ek.AllocType.Host)
            _ek.sync_thread()
        elif b.IsLLVM:
            _ek.sync_thread()

        record = {
            'shape': shape,
            'strides': tuple(strides),
            'typestr': '<' + a.Type.NumPy,
            'data': (b.data_(), False),
            'version': version,
            'device': _ek.device(b),
            'owner': b
        }
        return record
def step(self):
    """ Take a gradient step """
    for k, p in self.params.items():
        g_p = ek.gradient(p)
        size = ek.slices(g_p)
        if size == 0:
            continue

        if self.momentum != 0:
            if size != ek.slices(self.state[k]):
                # Reset state if data size has changed
                self._reset(k)
            self.state[k] = self.momentum * self.state[k] + g_p
            value = ek.detach(p) - self.lr_v * self.state[k]
        else:
            value = ek.detach(p) - self.lr_v * g_p

        value = type(p)(value)
        ek.set_requires_gradient(value)
        self.params[k] = value

    self.params.update()
def backward(ctx, grad_output):
    try:
        ek.set_gradient(ctx.output, ek.detach(Float(grad_output)))
        Float.backward()
        result = tuple(ek.gradient(i).torch() if i is not None else None
                       for i in ctx.inputs)
        del ctx.output
        del ctx.inputs
        ek.cuda_malloc_trim()
        return result
    except Exception as e:
        print("render_torch(): critical exception during "
              "backward pass: %s" % str(e))
        raise e
def step(self):
    """ Take a gradient step """
    self.t += 1
    from enoki.cuda_autodiff import Float32 as Float
    lr_t = ek.detach(
        Float(self.lr * ek.sqrt(1 - self.beta_2**self.t) /
              (1 - self.beta_1**self.t), literal=False))

    for k in self.bsdf_ad_keys:
        g_p = ek.gradient(self.bsdf_map[k].reflectance.data)
        size = ek.slices(g_p)
        assert size == ek.slices(self.state[k][0])

        m_tp, v_tp = self.state[k]
        m_t = self.beta_1 * m_tp + (1 - self.beta_1) * g_p
        v_t = self.beta_2 * v_tp + (1 - self.beta_2) * ek.sqr(g_p)
        self.state[k] = (m_t, v_t)

        u = ek.detach(self.bsdf_map[k].reflectance.data) - \
            lr_t * m_t / (ek.sqrt(v_t) + self.epsilon)
        u = type(self.bsdf_map[k].reflectance.data)(u)
        ek.set_requires_gradient(u)
        self.bsdf_map[k].reflectance.data = u

    for k in self.mesh_ad_keys:
        g_p = ek.gradient(self.mesh_map[k].vertex_positions)
        size = ek.slices(g_p)
        assert size == ek.slices(self.state[k][0])

        m_tp, v_tp = self.state[k]
        m_t = self.beta_1 * m_tp + (1 - self.beta_1) * g_p
        v_t = self.beta_2 * v_tp + (1 - self.beta_2) * ek.sqr(g_p)
        self.state[k] = (m_t, v_t)

        u = ek.detach(self.mesh_map[k].vertex_positions) - \
            lr_t * m_t / (ek.sqrt(v_t) + self.epsilon)
        u = type(self.mesh_map[k].vertex_positions)(u)
        ek.set_requires_gradient(u)
        self.mesh_map[k].vertex_positions = u
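# Hypothetical driver loop for the bsdf/mesh variant of the optimizer whose
# __init__ and step() appear above. The class name `Adam` and the
# `render_loss` callable are assumptions for illustration; only the
# constructor arguments mirror the __init__ signature shown earlier.
import enoki as ek

def run(bsdf_map, mesh_map, bsdf_ad_keys, mesh_ad_keys, render_loss,
        iterations=50):
    opt = Adam(bsdf_map, mesh_map, bsdf_ad_keys, mesh_ad_keys, lr=0.01)
    for _ in range(iterations):
        loss = render_loss()    # differentiable scalar objective
        ek.backward(loss)       # gradients land on reflectance/vertex data
        opt.step()              # per-key Adam updates as implemented above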
def test_ad_operations(package):
    Float, Array3f = package.Float, package.Array3f
    prepare(package)

    class MyStruct:
        ENOKI_STRUCT = {'a': Array3f, 'b': Float}

        def __init__(self):
            self.a = Array3f()
            self.b = Float()

    foo = ek.zero(MyStruct, 4)
    assert not ek.grad_enabled(foo.a)
    assert not ek.grad_enabled(foo.b)
    assert not ek.grad_enabled(foo)

    ek.enable_grad(foo)
    assert ek.grad_enabled(foo.a)
    assert ek.grad_enabled(foo.b)
    assert ek.grad_enabled(foo)

    foo_detached = ek.detach(foo)
    assert not ek.grad_enabled(foo_detached.a)
    assert not ek.grad_enabled(foo_detached.b)
    assert not ek.grad_enabled(foo_detached)

    x = Float(4.0)
    ek.enable_grad(x)
    foo.a += x
    foo.b += x * x
    ek.forward(x)
    foo_grad = ek.grad(foo)
    assert foo_grad.a == 1
    assert foo_grad.b == 8

    ek.set_grad(foo, 5.0)
    foo_grad = ek.grad(foo)
    assert foo_grad.a == 5.0
    assert foo_grad.b == 5.0

    ek.accum_grad(foo, 5.0)
    foo_grad = ek.grad(foo)
    assert foo_grad.a == 10.0
    assert foo_grad.b == 10.0
def test21_scatter_add_fwd(m):
    for i in range(3):
        idx1 = ek.arange(m.UInt, 5)
        idx2 = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if i % 2 == 0:
            ek.enable_grad(buf)
            ek.set_grad(buf, 1)
        if i // 2 == 0:
            ek.enable_grad(x, y)
            ek.set_grad(x, 1)
            ek.set_grad(y, 1)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        buf2 = m.Float(buf)
        ek.scatter_add(buf2, x, idx1)
        ek.scatter_add(buf2, y, idx2)

        s = ek.dot_async(buf2, buf2)

        if i % 2 == 0:
            ek.enqueue(buf)
        if i // 2 == 0:
            ek.enqueue(x, y)

        ek.traverse(m.Float, reverse=False)

        # Verified against Mathematica
        assert ek.allclose(ek.detach(s), 15.5972)
        assert ek.allclose(ek.grad(s),
                           (25.1667 if i // 2 == 0 else 0) +
                           (17 if i % 2 == 0 else 0))