Example #1
0
def test12_hmax_fwd(m):
    """Forward-mode derivative of a horizontal maximum (``ek.hmax_async``)."""
    values = m.Float(1, 2, 8, 5, 8)
    ek.enable_grad(values)
    peak = ek.hmax_async(values)
    ek.forward(values)

    # The reduction yields a single entry holding the maximum, 8
    assert len(peak) == 1
    assert ek.allclose(peak[0], 8)
    # Approximation: 8 occurs twice in the input; the expected value of 2
    # presumably reflects both tied maxima contributing -- confirm
    assert ek.allclose(ek.grad(peak), [2])
Example #2
0
def test52_scatter_fwd_eager(m):
    """Forward-mode AD through ``ek.scatter`` under ``EagerMode``.

    The gradient seeded on ``x`` is read back from the scatter target without
    an explicit ``ek.forward`` call. After a second scatter writes a value that
    does not depend on ``x`` into slot 0, ``ek.forward(x)`` must report a zero
    gradient for that slot.
    """
    with EagerMode():
        x = m.Float(4.0)
        ek.enable_grad(x)
        ek.set_grad(x, 1)

        # x^2 * [1, 2, 3, 4] written to the even slots 0, 2, 4, 6
        values = x * x * ek.linspace(m.Float, 1, 4, 4)
        even_slots = 2 * ek.arange(m.UInt32, 4)

        target = ek.zero(m.Float, 10)
        ek.scatter(target, values, even_slots)

        assert ek.grad_enabled(target)

        expected = [16.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
        assert ek.allclose(target, expected)

        # d(x^2)/dx = 2x = 8, scaled by the linspace weights
        expected_grad = [8.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
        assert ek.allclose(ek.grad(target), expected_grad)

        # Overwrite slot 0 with a constant that does not depend on x
        ek.scatter(target, m.Float(3), m.UInt32(0))

        expected[0] = 3.0
        assert ek.allclose(target, expected)

        ek.forward(x)

        # The overwritten slot no longer carries a derivative
        expected_grad[0] = 0.0
        assert ek.allclose(ek.grad(target), expected_grad)
Example #3
0
def test_55_diffloop_simple_fwd(m, no_record):
    """Forward AD through an ``m.Loop`` that adds ``fi`` until the total reaches 10."""
    step, total = m.Float(1, 2, 3), m.Float(0, 0, 0)
    ek.enable_grad(step)

    loop = m.Loop("MyLoop", lambda: total)
    while loop(total < 10):
        total += step
    ek.forward(step)

    # Each lane's derivative equals its iteration count: 10, 5 and 4 steps
    assert ek.grad(total) == m.Float(10, 5, 4)
Example #4
0
def test_57_diffloop_masking_rev(m, no_record):
    """Masked ``scatter_reduce`` inside an ``m.Loop``, differentiated with ``ek.forward``.

    Lane 0 starts at index 0 and runs while its index is below 5. Lane 1
    starts at 5 and so never passes the loop condition. Only entries 0..4 of
    the output therefore receive contributions.
    Note: despite the ``_rev`` suffix, the derivative is taken in forward mode.
    """
    out = ek.zero(m.Float, 10)
    src = m.Float(1, 2)
    cursor = m.UInt32(0, 5)
    ek.enable_grad(src)

    loop = m.Loop("MyLoop", lambda: cursor)
    while loop(cursor < 5):
        ek.scatter_reduce(ek.ReduceOp.Add, out, src, cursor)
        cursor += 1
    ek.forward(src)

    expected = m.Float(1, 1, 1, 1, 1, 0, 0, 0, 0, 0)
    assert out == expected
    assert ek.grad(out) == expected
Example #5
0
def test45_diff_loop(m):
    """Forward AD through a custom op whose value and derivative are Monte Carlo loops.

    ``EllipticK`` estimates K(m) = integral over [0, pi/2] of
    (1 - m * sin(x)^2)^(-1/2) dx with a looped Monte Carlo integrator. It
    supplies its own forward derivative by integrating dK/dm the same way.
    """
    def mcint(a, b, f, sample_count=100000):
        """Monte Carlo estimate of the integral of ``f`` over ``[a, b]``.

        Evaluates ``f(lerp(a, b, u))`` for ``sample_count`` values ``u`` drawn
        from a PCG32 generator inside an ``m.Loop``.
        """
        rng = m.PCG32()
        i = m.UInt32(0)
        result = m.Float(0)
        loop = m.Loop(i, rng, result)
        while loop.cond(i < sample_count):
            result += f(ek.lerp(a, b, rng.next_float32()))
            i += 1
        return result * (b - a) / sample_count

    class EllipticK(ek.CustomOp):
        # --- Internally used utility methods ---

        # Integrand of the 'K' function
        def K(self, x, m_):
            return ek.rsqrt(1 - m_ * ek.sqr(ek.sin(x)))

        # Derivative of the above with respect to 'm'
        def dK(self, x, m_):
            m_ = m.Float(m_)  # Convert 'm' to differentiable type
            ek.enable_grad(m_)
            y = self.K(x, m_)
            ek.forward(m_)
            return ek.grad(y)

        # Monte Carlo integral of dK, used in the forward pass
        def eval_grad(self):
            return mcint(a=0, b=ek.Pi / 2, f=lambda x: self.dK(x, self.m_))

        # --- CustomOp interface ---

        def eval(self, m_):
            self.m_ = m_  # Stash 'm' for later
            return mcint(a=0, b=ek.Pi / 2, f=lambda x: self.K(x, self.m_))

        def forward(self):
            self.set_grad_out(self.grad_in('m_') * self.eval_grad())

    def elliptic_k(m_):
        return ek.custom(EllipticK, m_)

    ek.enable_flag(ek.JitFlag.RecordLoops)
    try:
        x = m.Float(0.5)
        ek.enable_grad(x)
        y = elliptic_k(x)
        ek.forward(x)
        assert ek.allclose(y, 1.85407, rtol=5e-4)
        assert ek.allclose(ek.grad(y), 0.847213, rtol=5e-4)
    finally:
        # Restore the JIT flag even when an assertion fails; otherwise it
        # would stay enabled for whatever runs next.
        ek.disable_flag(ek.JitFlag.RecordLoops)
Example #6
0
def test_ad_operations(package):
    """Gradient-tracking operations applied to an ``ENOKI_STRUCT`` aggregate.

    Covers enable/detach propagation to struct fields, forward AD into the
    fields, and ``set_grad`` / ``accum_grad`` broadcast over the struct.
    """
    Float, Array3f = package.Float, package.Array3f
    prepare(package)

    class MyStruct:
        ENOKI_STRUCT = {'a': Array3f, 'b': Float}

        def __init__(self):
            self.a = Array3f()
            self.b = Float()

    foo = ek.zero(MyStruct, 4)
    for part in (foo.a, foo.b, foo):
        assert not ek.grad_enabled(part)

    # Enabling on the struct enables every field
    ek.enable_grad(foo)
    for part in (foo.a, foo.b, foo):
        assert ek.grad_enabled(part)

    # A detached copy tracks nothing
    detached = ek.detach(foo)
    for part in (detached.a, detached.b, detached):
        assert not ek.grad_enabled(part)

    seed = Float(4.0)
    ek.enable_grad(seed)
    foo.a += seed
    foo.b += seed * seed
    ek.forward(seed)
    grads = ek.grad(foo)
    assert grads.a == 1
    assert grads.b == 8  # d(x^2)/dx at x = 4

    ek.set_grad(foo, 5.0)
    grads = ek.grad(foo)
    for part in (grads.a, grads.b):
        assert part == 5.0

    # accum_grad adds onto the gradient set above
    ek.accum_grad(foo, 5.0)
    grads = ek.grad(foo)
    for part in (grads.a, grads.b):
        assert part == 10.0
Example #7
0
def render_gradient(scene, passes, diff_params):
    """Render radiance and gradient image using forward autodiff.

    The scene is rendered ``passes`` times. After each render, forward-mode AD
    is propagated from ``diff_params``, and the first channel of the resulting
    gradient image is accumulated. NaNs in both images are replaced by 0
    before accumulation; infinities are kept.

    Args:
        scene: Scene to render; the film of its first sensor sets the size.
        passes: Number of render passes to average; must be at least 1.
        diff_params: Differentiable parameter(s) to propagate gradients from.

    Returns:
        Tuple ``(img, grad)`` of float32 arrays with shapes
        ``(height, width, 3)`` and ``(height, width, 1)``, averaged over passes.

    Raises:
        ValueError: If ``passes`` is less than 1 (the averages would be 0/0).
    """
    # Reject before doing any work; with zero passes the result is all-NaN.
    if passes < 1:
        raise ValueError(f"passes must be >= 1, got {passes}")

    from mitsuba.python.autodiff import render

    fsize = scene.sensors()[0].film().size()
    width, height = fsize[0], fsize[1]

    img = np.zeros((height, width, 3), dtype=np.float32)
    grad = np.zeros((height, width, 1), dtype=np.float32)
    for i in range(passes):
        img_i = render(scene)
        # NOTE(review): second positional argument is assumed to be enoki's
        # `free_graph` flag, so the graph is only released after the last pass
        ek.forward(diff_params, i == passes - 1)

        grad_i = ek.gradient(img_i).numpy().reshape(height, width, -1)[:, :, [0]]
        img_i = img_i.numpy().reshape(height, width, -1)

        # Remove NaNs
        grad_i[np.isnan(grad_i)] = 0
        img_i[np.isnan(img_i)] = 0

        grad += grad_i
        img += img_i

    return img / passes, grad / passes
Example #8
0
def test22_scatter_fwd_permute(m):
    """Forward AD through two scatters that write disjoint halves of one buffer."""
    x = m.Float(4.0)
    ek.enable_grad(x)

    # Odd coefficients 1, 3, ..., 19 split across the two halves
    first_half = x * ek.linspace(m.Float, 1, 9, 5)
    second_half = x * ek.linspace(m.Float, 11, 19, 5)

    buf = ek.zero(m.Float, 10)

    first_idx = ek.arange(m.UInt32, 5)
    second_idx = ek.arange(m.UInt32, 5) + 5

    ek.scatter(buf, first_half, first_idx, permute=False)
    ek.scatter(buf, second_half, second_idx, permute=False)

    coeffs = [float(k) for k in range(1, 20, 2)]
    assert ek.allclose(buf, [4.0 * c for c in coeffs])

    ek.forward(x)
    # d(x * c)/dx = c for every slot
    assert ek.allclose(ek.grad(buf), coeffs)
def test05_differentiable_surface_interaction_ray_forward(
        variant_gpu_autodiff_rgb):
    """Forward AD of surface-interaction fields on a rectangle w.r.t. the ray."""
    from mitsuba.core import xml, Ray3f, Vector3f, UInt32

    shape = xml.load_dict({'type': 'rectangle'})

    ray = Ray3f(Vector3f(-0.3, -0.3, -10.0), Vector3f(0.0, 0.0, 1.0), 0, [])
    pi = shape.ray_intersect_preliminary(ray)

    ek.set_requires_gradient(ray.o)
    ek.set_requires_gradient(ray.d)

    def check(seed, field, expected):
        # Recompute the interaction, then propagate from one ray component
        si = pi.compute_surface_interaction(ray)
        ek.forward(seed(ray))
        assert ek.allclose(ek.gradient(getattr(si, field)), expected)

    # Shifting the origin along x (resp. y) shifts si.p by the same amount
    check(lambda r: r.o.x, 'p', [1, 0, 0])
    check(lambda r: r.o.y, 'p', [0, 1, 0])
    # Shifting the origin along x moves si.uv.x by half as much
    check(lambda r: r.o.x, 'uv', [0.5, 0])
    # Moving the origin along z toward the surface shortens si.t
    check(lambda r: r.o.z, 't', -1)
    # Tilting the direction along x moves si.p by the hit distance (10)
    check(lambda r: r.d.x, 'p', [10, 0, 0])
Example #10
0
def test47_nan_propagation(m):
    """NaN handling of ``ek.select`` gradients in reverse and forward mode.

    When the mask selects ``x`` everywhere, the derivative is 1. When it
    selects the branch ``x * (0/0)`` everywhere, the derivative is NaN. The
    reverse-mode variant runs first, then the forward-mode one.
    """
    for forward_mode in (False, True):
        for i in range(2):
            x = ek.arange(m.Float, 10)
            ek.enable_grad(x)
            f0 = m.Float(0)
            threshold = 20 if i == 0 else 0
            y = ek.select(x < threshold, x, x * (f0 / f0))

            if forward_mode:
                ek.forward(x)
                g = ek.grad(y)
            else:
                ek.backward(y)
                g = ek.grad(x)

            if i == 0:
                assert g == 1
            else:
                assert ek.all(ek.isnan(g))
Example #11
0
def test02_add_fwd(m):
    """Forward AD through ``c = 2a + b``, via ``ek.forward`` and via manual traversal."""
    # Part 1: high-level ek.forward; gradients accumulate into c
    a, b = m.Float(1), m.Float(2)
    ek.enable_grad(a, b)
    c = 2 * a + b
    ek.forward(a, retain_graph=True)
    assert ek.grad(c) == 2
    ek.set_grad(c, 101)
    ek.forward(b)
    assert ek.grad(c) == 102

    # Part 2: low-level enqueue/traverse, run twice on a fresh graph
    a, b = m.Float(1), m.Float(2)
    ek.enable_grad(a, b)
    c = 2 * a + b

    ek.set_grad(a, 1.0)
    ek.enqueue(a)
    ek.traverse(m.Float, retain_graph=True, reverse=False)
    assert ek.grad(c) == 2
    # The traversal cleared the seed placed on a
    assert ek.grad(a) == 0

    ek.set_grad(a, 1.0)
    ek.enqueue(a)
    ek.traverse(m.Float, retain_graph=False, reverse=False)
    assert ek.grad(c) == 4
def test13_differentiable_surface_interaction_ray_forward(
        variant_gpu_autodiff_rgb):
    """Forward AD of surface-interaction fields on an OBJ rectangle w.r.t. the ray."""
    from mitsuba.core import xml, Ray3f, Vector3f, UInt32

    scene = xml.load_string('''
        <scene version="2.0.0">
            <shape type="obj" id="rect">
                <string name="filename" value="resources/data/common/meshes/rectangle.obj"/>
            </shape>
        </scene>
    ''')

    ray = Ray3f(Vector3f(-0.3, -0.4, -10.0), Vector3f(0.0, 0.0, 1.0), 0, [])
    pi = scene.ray_intersect_preliminary(ray)

    ek.set_requires_gradient(ray.o)
    ek.set_requires_gradient(ray.d)

    def check(seed, field, expected):
        # Recompute the interaction, then propagate from one ray component
        si = pi.compute_surface_interaction(ray)
        ek.forward(seed(ray))
        assert ek.allclose(ek.gradient(getattr(si, field)), expected)

    # Shifting the origin along x shifts si.p by the same amount
    check(lambda r: r.o.x, 'p', [1, 0, 0])
    # ... and moves si.uv.x by half as much
    check(lambda r: r.o.x, 'uv', [0.5, 0])
    # Moving the origin along z toward the surface shortens si.t
    check(lambda r: r.o.z, 't', -1)
    # Tilting the direction along x moves si.p by the hit distance (10)
    check(lambda r: r.d.x, 'p', [10, 0, 0])
Example #13
0
def test08_hsum_1_fwd(m):
    """Forward AD of ``hsum(hsum(x) * x)``, i.e. ``sum(x) ** 2``."""
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)
    total = ek.hsum_async(x)
    y = ek.hsum_async(total * x)
    ek.forward(x)
    # Seeding every entry with 1: 10 entries * 2 * S, where S = sum(x) = 5
    assert ek.allclose(ek.grad(y), 100)
Example #14
0
def test05_differentiable_surface_interaction_ray_forward(
        variant_gpu_autodiff_rgb):
    """Forward AD of surface-interaction fields on a sphere w.r.t. the ray."""
    from mitsuba.core import xml, Ray3f, Vector3f, UInt32

    shape = xml.load_dict({'type': 'sphere'})

    ray = Ray3f(Vector3f(0.0, -10.0, 0.0), Vector3f(0.0, 1.0, 0.0), 0, [])
    pi = shape.ray_intersect_preliminary(ray)

    ek.set_requires_gradient(ray.o)
    ek.set_requires_gradient(ray.d)

    def check(build_si, seed, field, expected):
        # Build a new interaction, then propagate from one ray component
        si = build_si(ray)
        ek.forward(seed(ray))
        assert ek.allclose(ek.gradient(getattr(si, field)), expected)

    def intersect_tracking_origin(r):
        # Re-enable gradient tracking on the origin before a full intersection
        ek.set_requires_gradient(r.o)
        return shape.ray_intersect(r)

    preliminary = pi.compute_surface_interaction

    # Shifting the origin along x (resp. z) shifts si.p by the same amount
    check(preliminary, lambda r: r.o.x, 'p', [1, 0, 0])
    check(preliminary, lambda r: r.o.z, 'p', [0, 0, 1])
    # Moving the origin along y toward the sphere shortens si.t
    check(preliminary, lambda r: r.o.y, 't', -1)
    # Tilting the direction along x moves si.p by the hit distance (9)
    check(preliminary, lambda r: r.d.x, 'p', [9, 0, 0])

    # Shifting the origin tangentially (azimuth) moves si.uv.x by 1 / (2 pi)
    check(intersect_tracking_origin, lambda r: r.o.x, 'uv',
          [1 / (2.0 * ek.pi), 0])
    # Shifting the origin tangentially (inclination) moves si.uv.y by -2 / (2 pi)
    check(intersect_tracking_origin, lambda r: r.o.z, 'uv',
          [0, -2 / (2.0 * ek.pi)])
    # Shifting the origin along x (resp. z) tilts si.n along the same axis
    check(intersect_tracking_origin, lambda r: r.o.x, 'n', [1, 0, 0])
    check(intersect_tracking_origin, lambda r: r.o.z, 'n', [0, 0, 1])
Example #15
0
 def dK(self, x, m_):
     """Return dK/dm at ``x``, via forward-mode AD on a differentiable copy of ``m_``."""
     diff_m = m.Float(m_)  # differentiable copy of the input parameter
     ek.enable_grad(diff_m)
     integrand = self.K(x, diff_m)
     ek.forward(diff_m)
     return ek.grad(integrand)
def test04_differentiable_surface_interaction_ray_forward(
        variant_gpu_autodiff_rgb):
    """Forward AD of surface-interaction fields on a disk w.r.t. the ray."""
    from mitsuba.core import xml, Ray3f, Vector3f, UInt32

    shape = xml.load_dict({'type': 'disk'})

    ray = Ray3f(Vector3f(0.1, -0.2, -10.0), Vector3f(0.0, 0.0, 1.0), 0, [])
    pi = shape.ray_intersect_preliminary(ray)

    ek.set_requires_gradient(ray.o)
    ek.set_requires_gradient(ray.d)

    def check(build_si, seed, field, expected, **tol):
        # Uses whatever `ray` is bound to at call time (it is replaced below)
        si = build_si(ray)
        ek.forward(seed(ray))
        assert ek.allclose(ek.gradient(getattr(si, field)), expected, **tol)

    preliminary = pi.compute_surface_interaction

    # Shifting the origin along x (resp. y) shifts si.p by the same amount
    check(preliminary, lambda r: r.o.x, 'p', [1, 0, 0])
    check(preliminary, lambda r: r.o.y, 'p', [0, 1, 0])
    # Moving the origin along z toward the disk shortens si.t
    check(preliminary, lambda r: r.o.z, 't', -1)
    # Tilting the direction along x moves si.p by the hit distance (10)
    check(preliminary, lambda r: r.d.x, 'p', [10, 0, 0])

    # New ray hitting close to the rim along +x
    ray = Ray3f(Vector3f(0.9999999, 0.0, -10.0), Vector3f(0.0, 0.0, 1.0), 0,
                [])
    ek.set_requires_gradient(ray.o)

    # Shifting the origin radially moves si.uv.x at the same rate
    check(shape.ray_intersect, lambda r: r.o.x, 'uv', [1, 0])
    # Shifting the origin tangentially moves si.uv.y by 1 / (2 pi)
    check(shape.ray_intersect, lambda r: r.o.y, 'uv', [0, 0.5 / ek.pi],
          atol=1e-5)
    # Shifting the origin tangentially also gives si.dp_dv a component in x
    check(shape.ray_intersect, lambda r: r.o.y, 'dp_dv', [-1, 0, 0])
Example #17
0
def test10_hsum_2_fwd(m):
    """Forward AD of ``hsum(hsum(x*x) * hsum(x*x))``, i.e. ``sum(x**2) ** 2``."""
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)
    left = ek.hsum_async(x * x)
    right = ek.hsum_async(x * x)
    y = ek.hsum_async(left * right)
    ek.forward(x)
    # Seeding with ones: 4 * sum(x) * sum(x^2) = 4 * 5 * 95/27
    assert ek.allclose(ek.grad(y), 1900.0 / 27.0)
Example #18
0
def _ad_check_same_length(ad_config, key_a, key_b):
    """Raise ValueError unless ``ad_config[key_a]`` and ``ad_config[key_b]`` match in length."""
    if len(ad_config[key_a]) != len(ad_config[key_b]):
        raise ValueError("%s and %s have different sizes" % (key_a, key_b))


def _prepare_ad_config(sc, ad_config):
    """Validate ``ad_config`` and snapshot the scene data the per-pass update needs.

    Returns:
        Dict of detached original values. It is keyed by mesh ID
        (``vertex_transform``) or BSDF ID (``material_roughness``), and empty
        for the other transform types.

    Raises:
        ValueError: On mismatched list lengths, a missing ``Emitter_ID`` or an
            unknown transform type.
    """
    ad_type = ad_config["type"]
    originals = {}
    if ad_type == "mesh_transform":
        _ad_check_same_length(ad_config, "Mesh_ID", "Mesh_dir")
    elif ad_type == "mesh_rotate":
        _ad_check_same_length(ad_config, "Mesh_ID", "axis")
    elif ad_type == "vertex_transform":
        _ad_check_same_length(ad_config, "Mesh_ID", "Vertex_ID")
        _ad_check_same_length(ad_config, "Mesh_ID", "Vertex_dir")
        for mesh_id in ad_config["Mesh_ID"]:
            mesh_obj = sc.param_map["Mesh[" + str(mesh_id) + "]"]
            originals[mesh_id] = ek.detach(mesh_obj.vertex_positions)
    elif ad_type == "material_roughness":
        for bsdf_id in ad_config["BSDF_ID"]:
            bsdf_obj = sc.param_map["BSDF[" + str(bsdf_id) + "]"]
            originals[bsdf_id] = (ek.detach(bsdf_obj.alpha_u.data),
                                  ek.detach(bsdf_obj.alpha_v.data))
    elif ad_type == "envmap_rotate":
        if "Emitter_ID" not in ad_config:
            raise ValueError("Missing Emitter_ID")
    else:
        raise ValueError("Unknown transform")
    return originals


def _apply_ad_config(sc, ad_config, originals, P):
    """Make the scene parameters selected by ``ad_config`` depend on the scalar ``P``."""
    ad_type = ad_config["type"]
    if ad_type == "mesh_transform":
        for mesh_id, direction in zip(ad_config["Mesh_ID"],
                                      ad_config["Mesh_dir"]):
            mesh_transform(sc, mesh_id, Vector3fD(direction) * P)
    elif ad_type == "mesh_rotate":
        for mesh_id, axis in zip(ad_config["Mesh_ID"], ad_config["axis"]):
            mesh_rotate(sc, mesh_id, Vector3fD(axis), P)
    elif ad_type == "vertex_transform":
        for mesh_id, vertex_id, vertex_dir in zip(ad_config["Mesh_ID"],
                                                  ad_config["Vertex_ID"],
                                                  ad_config["Vertex_dir"]):
            # Original positions are keyed by mesh ID, not by list position
            vertex_transform(sc, mesh_id, vertex_id, vertex_dir,
                             originals[mesh_id], P)
    elif ad_type == "material_roughness":
        for bsdf_id in ad_config["BSDF_ID"]:
            material_roughness(sc, bsdf_id, originals[bsdf_id], P)
    elif ad_type == "envmap_rotate":
        envmap_rotate(sc, ad_config["Emitter_ID"], ad_config["axis"], P)


def run_ad(integrator, sc, fname, args):
    """Render forward-mode AD gradient images for every sensor and write them to disk.

    For each of ``npass`` passes, a fresh differentiable scalar ``P`` is
    created and the parameters selected by ``args["AD"]`` are made to depend
    on it. Each sensor is rendered with ``integrator.renderD``, and the
    forward derivative with respect to ``P`` is accumulated with non-finite
    entries zeroed. The pass-averaged gradient images are written as
    ``<root>_<sensor_id><ext>`` derived from ``fname``.

    Args:
        integrator: Integrator providing ``renderD`` and, when ``"guide"`` is
            configured, ``preprocess_secondary_edges``.
        sc: Scene with ``opts``, ``param_map``, ``num_sensors`` and ``configure()``.
        fname: Output image path; the sensor index is inserted before the extension.
        args: Run configuration; ``args["AD"]`` selects the transform type and
            options, and ``"npass"`` may come from ``args["AD"]`` or ``args``.

    Raises:
        ValueError: On an invalid AD configuration or ``npass < 1``.
    """
    import os

    global time_threshold

    ad_config = args["AD"]

    # Optional sample-count overrides
    for key in ("spp", "sppe", "sppse"):
        if key in ad_config:
            setattr(sc.opts, key, ad_config[key])

    # Disable edge sampling on the listed meshes
    for mesh_id in ad_config.get("no_edge", ()):
        sc.param_map["Mesh[" + str(mesh_id) + "]"].enable_edges = False

    ro = sc.opts
    originals = _prepare_ad_config(sc, ad_config)

    # Pass count: AD-specific setting, then global setting, then 1
    if "npass" in ad_config:
        npass = ad_config["npass"]
    else:
        npass = args.get("npass", 1)
    if npass < 1:
        # Zero passes would leave nothing to average (and no timing to report)
        raise ValueError("npass must be >= 1, got %r" % (npass,))

    num_sensors = sc.num_sensors
    img_ad = [None] * num_sensors

    t0 = time.process_time()
    t1 = t0
    for i in range(npass):
        # Fresh differentiable scalar the scene parameters are made to depend on
        P = FloatD(0.)
        ek.set_requires_gradient(P)
        _apply_ad_config(sc, ad_config, originals, P)
        sc.configure()

        for sensor_id in range(num_sensors):
            if i == 0 and "guide" in ad_config:
                t2 = time.process_time()
                guide_info = ad_config["guide"]
                integrator.preprocess_secondary_edges(
                    sc, sensor_id, np.array(guide_info["reso"]),
                    guide_info["nround"])
                print("guiding done in %.2f seconds." %
                      (time.process_time() - t2))

            img = integrator.renderD(sc, sensor_id)
            ek.forward(P, free_graph=True)

            grad_img = ek.gradient(img).numpy()
            # Zero out NaN/Inf entries before accumulating
            grad_img[np.logical_not(np.isfinite(grad_img))] = 0.
            if i == 0:
                img_ad[sensor_id] = grad_img
            else:
                img_ad[sensor_id] += grad_img
            del img
        del P

        t2 = time.process_time()
        if t2 - t1 > time_threshold:
            print("(%d/%d) done in %.2f seconds." % (i + 1, npass, t2 - t0),
                  end="\r")
            t1 = t2
    print("(%d/%d) Total AD rendering time: %.2f seconds." %
          (npass, npass, t2 - t0))

    root, ext = os.path.splitext(fname)
    for sensor_id in range(num_sensors):
        img = (img_ad[sensor_id] / float(npass)).reshape(
            (ro.height, ro.width, 3))
        output = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(root + "_" + str(sensor_id) + ext, output)
def test15_differentiable_surface_interaction_params_forward(
        variant_gpu_autodiff_rgb):
    """Forward AD of surface-interaction fields w.r.t. mesh vertex transforms.

    The rectangle mesh's vertex buffer is rebuilt from a transform that depends
    on the differentiable ``diff_vector``. Gradients of ``si`` fields are then
    checked for translations and for a rotation about the z-axis.
    """
    from mitsuba.core import (xml, Float, Ray3f, Vector2f, Vector3f, UInt32,
                              Transform4f)

    # Convert flat array into a vector of arrays (will be included in next enoki release)
    def ravel(buf, dim=3):
        idx = dim * UInt32.arange(ek.slices(buf) // dim)
        if dim == 2:
            return Vector2f(ek.gather(buf, idx), ek.gather(buf, idx + 1))
        if dim == 3:
            return Vector3f(ek.gather(buf, idx), ek.gather(buf, idx + 1),
                            ek.gather(buf, idx + 2))
        raise ValueError("ravel(): unsupported dimension %d" % dim)

    # Return contiguous flattened array (will be included in next enoki release)
    def unravel(source, target, dim=3):
        idx = UInt32.arange(ek.slices(source))
        for i in range(dim):
            ek.scatter(target, source[i], dim * idx + i)

    scene = xml.load_string('''
        <scene version="2.0.0">
            <shape type="obj" id="rect">
                <string name="filename" value="resources/data/common/meshes/rectangle.obj"/>
            </shape>
        </scene>
    ''')

    params = traverse(scene)
    shape_param_key = 'rect.vertex_positions_buf'
    positions_buf = params[shape_param_key]
    positions_initial = ravel(positions_buf)

    # Create differential parameter to be optimized
    diff_vector = Vector3f(0.0)
    ek.set_requires_gradient(diff_vector)

    # Apply the transformation to mesh vertex position and update scene
    def apply_transformation(make_transform):
        transform = make_transform(diff_vector)
        new_positions = transform.transform_point(positions_initial)
        unravel(new_positions, params[shape_param_key])
        params.set_dirty(shape_param_key)
        params.update()

    # ---------------------------------------
    # Test translation

    ray = Ray3f(Vector3f(-0.2, -0.3, -10.0), Vector3f(0.0, 0.0, 1.0), 0, [])
    pi = scene.ray_intersect_preliminary(ray)

    # If the vertices are shifted along z-axis, so does si.t
    apply_transformation(lambda v: Transform4f.translate(v))
    si = pi.compute_surface_interaction(ray)
    ek.forward(diff_vector.z)
    assert ek.allclose(ek.gradient(si.t), 1)

    # If the vertices are shifted along z-axis, so does si.p
    apply_transformation(lambda v: Transform4f.translate(v))
    si = pi.compute_surface_interaction(ray)
    ek.forward(diff_vector.z)
    assert ek.allclose(ek.gradient(si.p), [0.0, 0.0, 1.0])

    # If the vertices are shifted along x-axis, so does si.uv (times 0.5)
    apply_transformation(lambda v: Transform4f.translate(v))
    si = pi.compute_surface_interaction(ray)
    ek.forward(diff_vector.x)
    assert ek.allclose(ek.gradient(si.uv), [-0.5, 0.0])

    # If the vertices are shifted along y-axis, so does si.uv (times 0.5)
    apply_transformation(lambda v: Transform4f.translate(v))
    si = pi.compute_surface_interaction(ray)
    ek.forward(diff_vector.y)
    assert ek.allclose(ek.gradient(si.uv), [0.0, -0.5])

    # ---------------------------------------
    # Test rotation

    ray = Ray3f(Vector3f(-0.99999, -0.99999, -10.0), Vector3f(0.0, 0.0, 1.0),
                0, [])
    pi = scene.ray_intersect_preliminary(ray)

    # If the vertices are rotated around the center, so does si.uv (times 0.5)
    apply_transformation(lambda v: Transform4f.rotate([0, 0, 1], v.x))
    si = pi.compute_surface_interaction(ray)
    ek.forward(diff_vector.x)
    # Rotation angle is in degrees, hence the 2*pi/360 factor
    du = 0.5 * ek.sin(2 * ek.pi / 360.0)
    assert ek.allclose(ek.gradient(si.uv), [-du, du], atol=1e-6)