    def pdf(self, ctx, si, wo, active):
        """Return the cosine-hemisphere sampling density of direction ``wo``.

        The density is zero when the diffuse reflection component is not
        enabled in ``ctx``, or when ``si.wi`` or ``wo`` does not have a
        strictly positive cosine with respect to the local frame. ``active``
        is accepted for interface compatibility and is not read here.
        """
        if not ctx.is_enabled(BSDFFlags.DiffuseReflection):
            # A density is a scalar, so report a scalar zero -- the same
            # fallback value used by the select() at the end of this method.
            return 0.0

        cos_theta_i = Frame3f.cos_theta(si.wi)
        cos_theta_o = Frame3f.cos_theta(wo)

        pdf = warp.square_to_cosine_hemisphere_pdf(wo)

        return ek.select((cos_theta_i > 0.0) & (cos_theta_o > 0.0), pdf, 0.0)
def test03_frame_equality(variant_scalar_rgb):
    """Check ``Frame3f`` equality between explicit and normal-derived frames."""
    from mitsuba.core import Frame3f

    explicit = Frame3f([1, 0, 0], [0, 1, 0], [0, 0, 1])
    from_normal = Frame3f([0, 0, 1])
    swapped = Frame3f([0, 0, 1], [0, 1, 0], [1, 0, 0])

    # Equality must hold regardless of operand order.
    assert explicit == from_normal
    assert from_normal == explicit

    # A frame with its first and last vectors exchanged is a different frame.
    assert not explicit == swapped
    assert not from_normal == swapped
    def eval(self, ctx, si, wo, active):
        """Evaluate the diffuse BSDF for the direction pair (``si.wi``, ``wo``).

        The value is ``reflectance * InvPi * cos_theta_o``. A zero vector is
        returned when the diffuse reflection component is not enabled in
        ``ctx``, or when either direction does not have a strictly positive
        cosine with respect to the local frame.
        """
        if not ctx.is_enabled(BSDFFlags.DiffuseReflection):
            return Vector3f(0)

        cos_theta_i = Frame3f.cos_theta(si.wi)
        cos_theta_o = Frame3f.cos_theta(wo)

        value = self.m_reflectance.eval(si, active) * math.InvPi * cos_theta_o

        # Same zero fallback as the disabled-component branch above.
        return ek.select((cos_theta_i > 0.0) & (cos_theta_o > 0.0), value,
                         Vector3f(0))
    def eval(self, ctx, si, wo, active):
        """Emitter sampling: evaluate the BTF for the pair (``si.wi``, ``wo``).

        The value is ``self.get_btf(si.wi, wo, si.uv) * InvPi``. A zero vector
        is returned when the diffuse reflection component is not enabled in
        ``ctx``, or when either direction does not have a strictly positive
        cosine with respect to the local frame. ``active`` is not read here.
        """
        if not ctx.is_enabled(BSDFFlags.DiffuseReflection):
            return Vector3f(0)

        cos_theta_i = Frame3f.cos_theta(si.wi)
        cos_theta_o = Frame3f.cos_theta(wo)

        value = self.get_btf(si.wi, wo, si.uv) * math.InvPi

        return ek.select((cos_theta_i > 0.0) & (cos_theta_o > 0.0), value, Vector3f(0))
def test03_mueller_to_world_to_local(variant_scalar_mono_polarized):
    """Round-trip a Mueller matrix between world and local coordinates.

    At a few places, coordinate changes between local BSDF reference frame and
    world coordinates need to take place. This change also needs to be applied
    to Mueller matrices used in computations involving polarization state.

    In practice, this is always a simple rotation of reference Stokes vectors
    (for incident & outgoing directions) of the Mueller matrix.

    To test this behavior we take any Mueller matrix (e.g. linear polarizer)
    for some arbitrary incident/outgoing directions in world coordinates and
    compute the round trip going to local frame and back again.
    """
    from mitsuba.core import Frame3f, UnpolarizedSpectrum
    from mitsuba.render import SurfaceInteraction3f
    from mitsuba.render.mueller import linear_polarizer
    import numpy as np

    si = SurfaceInteraction3f()
    si.sh_frame = Frame3f(ek.normalize([1.0, 1.0, 1.0]))

    M = linear_polarizer(UnpolarizedSpectrum(1.0))

    # Arbitrary incident and outgoing directions
    wi_world = ek.normalize([0.2, 0.0, 1.0])
    wo_world = ek.normalize([0.0, -0.8, 1.0])

    wi_local = si.to_local(wi_world)
    wo_local = si.to_local(wo_world)

    # World -> local -> world must reproduce the original matrix.
    M_local = si.to_local_mueller(M, wi_world, wo_world)
    M_world = si.to_world_mueller(M_local, wi_local, wo_local)

    assert ek.allclose(M, M_world, atol=1e-5)
    def eval(self, si, active):
        """Evaluate the emitted radiance for the surface interaction ``si``.

        The radiance spectrum is multiplied by ``self.fallof`` applied to the
        dot product of ``si.wi`` and ``si.n``, and is selected only where
        ``Frame3f.cos_theta(si.wi)`` is positive.
        """
        # NOTE(review): the select() call below has no third argument and no
        # closing parenthesis before `return` -- it looks truncated. The value
        # used where the condition is false is not visible here; restore it
        # from the original file.
        cosTheta = ek.dot(si.wi, si.n)

        tmp2 = ek.select(Frame3f.cos_theta(si.wi) > 0, \
                self.m_radiance.eval(si, active)*self.fallof(cosTheta), \
        return tmp2
def test03_sample_ray(variant_packet_spectral, spectrum_key):
    # Check the correctness of the sample_ray() method
    #
    # The ray returned by the emitter is compared against sampling the
    # spectrum and the shape position separately with the same inputs.
    #
    # NOTE(review): two statements in this function look truncated -- the
    # `emitter.sample_ray(` call is never closed (`dir_sample` is defined but
    # not passed anywhere visible), and the final `assert ek.allclose(` has no
    # arguments. Restore both from the original file.

    from mitsuba.core import warp, Frame3f, sample_shifted
    from mitsuba.render import SurfaceInteraction3f

    shape, spectrum = create_emitter_and_spectrum(spectrum_key)
    emitter = shape.emitter()

    time = 0.5
    wavelength_sample = [0.5, 0.33, 0.1]
    pos_sample = [[0.2, 0.1, 0.2], [0.6, 0.9, 0.2]]
    dir_sample = [[0.4, 0.5, 0.3], [0.1, 0.4, 0.9]]

    # Sample a ray (position, direction, wavelengths) on the emitter
    ray, res = emitter.sample_ray(time, wavelength_sample, pos_sample,

    # Sample wavelengths on the spectrum
    it = SurfaceInteraction3f.zero(3)
    wav, spec = spectrum.sample_spectrum(it, sample_shifted(wavelength_sample))

    # Sample a position on the shape
    ps = shape.sample_position(time, pos_sample)

    assert ek.allclose(res, spec * shape.surface_area() * ek.pi)
    assert ek.allclose(ray.time, time)
    assert ek.allclose(ray.wavelengths, wav)
    assert ek.allclose(ray.o, ps.p)
    assert ek.allclose(
def test02_eval_all(variant_scalar_rgb):
    """Evaluate a ``blendbsdf`` of two diffuse lobes with all components on.

    The nested BSDFs have reflectance 0 and 1 and are mixed by ``weight``. The
    result of ``bsdf.eval`` for ``wi = wo = +Z`` is compared with the blend
    computed by hand.
    """
    from mitsuba.core import Frame3f
    from mitsuba.render import BSDFFlags, BSDFContext, SurfaceInteraction3f
    from mitsuba.core.xml import load_string
    from mitsuba.core.math import InvPi

    weight = 0.2

    # The scene description must be well-formed XML: every <bsdf> element is
    # closed, and the "{}" placeholder is filled in with the blend weight.
    bsdf = load_string("""<bsdf version="2.0.0" type="blendbsdf">
        <bsdf type="diffuse">
            <spectrum name="reflectance" value="0.0"/>
        </bsdf>
        <bsdf type="diffuse">
            <spectrum name="reflectance" value="1.0"/>
        </bsdf>
        <spectrum name="weight" value="{}"/>
    </bsdf>""".format(weight))

    si = SurfaceInteraction3f()
    si.t = 0.1
    si.p = [0, 0, 0]
    si.n = [0, 0, 1]
    si.sh_frame = Frame3f(si.n)
    si.wi = [0, 0, 1]

    wo = [0, 0, 1]
    ctx = BSDFContext()

    # Evaluate the blend of both components
    expected = (1 - weight) * 0.0 * InvPi + weight * 1.0 * InvPi
    value    = bsdf.eval(ctx, si, wo)
    assert ek.allclose(value, expected)
def test02_sample_quarter_wave_local(variant_scalar_mono_polarized):
    """Check the ``retarder`` BSDF with delta = 90 degrees (quarter-wave plate).

    The Mueller matrix returned by ``bsdf.sample`` at normal incidence is
    applied to four reference Stokes vectors and compared with the expected
    output states listed in the comments below.
    """
    from mitsuba.core import Frame3f, Transform4f, Spectrum
    from mitsuba.core.xml import load_string
    from mitsuba.render import BSDFContext, TransportMode, SurfaceInteraction3f

    def spectrum_from_stokes(v):
        """Store the four Stokes components of ``v`` in column 0 of a Spectrum."""
        res = Spectrum(0.0)
        for i in range(4):
            res[i, 0] = v[i]
        return res

    # Test polarized implementation. Special case of delta = 90˚, also known
    # as a quarter-wave plate. (In local BSDF coordinates.)
    # Following "Polarized Light - Fundamentals and Applications" by Edward Collett
    # Chapter 5.3, equation (30) & (31):
    # Case 1) Linearly polarized +45˚ light (Stokes vector [1, 0, 1, 0]) yields
    #         right circularly polarized light (Stokes vector [1, 0, 0, 1]).
    # Case 2) Linearly polarized -45˚ light (Stokes vector [1, 0, -1, 0]) yields
    #         left circularly polarized light (Stokes vector [1, 0, 0, -1]).
    # Case 3) Right circularly polarized light (Stokes vector [1, 0, 0, 1]) yields
    #         linearly polarized -45˚ light (Stokes vector [1, 0, -1, 0]).
    # Case 4) Left circularly polarized light (Stokes vector [1, 0, 0, -1]) yields
    #         linearly polarized +45˚ light (Stokes vector [1, 0, 1, 0]).

    linear_pos = spectrum_from_stokes([1, 0, +1, 0])
    linear_neg = spectrum_from_stokes([1, 0, -1, 0])
    circular_right = spectrum_from_stokes([1, 0, 0, +1])
    circular_left = spectrum_from_stokes([1, 0, 0, -1])

    # The XML description must be a closed element inside a terminated literal.
    bsdf = load_string("""<bsdf version='2.0.0' type='retarder'>
                          <spectrum name="theta" value="0"/>
                          <spectrum name="delta" value="90.0"/>
                      </bsdf>""")

    # Incident direction
    wi = [0, 0, 1]

    ctx = BSDFContext()
    ctx.mode = TransportMode.Importance
    si = SurfaceInteraction3f()
    si.p = [0, 0, 0]
    si.wi = wi
    n = [0, 0, 1]
    si.n = n
    si.sh_frame = Frame3f(si.n)

    bs, M = bsdf.sample(ctx, si, 0.0, [0.0, 0.0])

    # Case 1)
    assert ek.allclose(M @ linear_pos, circular_right, atol=1e-3)
    # Case 2)
    assert ek.allclose(M @ linear_neg, circular_left, atol=1e-3)
    # Case 3)
    assert ek.allclose(M @ circular_right, linear_neg, atol=1e-3)
    # Case 4)
    assert ek.allclose(M @ circular_left, linear_pos, atol=1e-3)
 def make_context(n):
     """Return a zero-initialized medium interaction of size ``n`` and a context.

     The interaction's ``wi`` is taken from the enclosing scope's ``wi`` and
     its shading frame is built around ``-wi``.
     """
     interaction = MediumInteraction3f.zero(n)
     interaction.wi = wi
     ek.set_slices(interaction.wi, n)
     interaction.sh_frame = Frame3f(-interaction.wi)
     interaction.wavelengths = []
     return interaction, PhaseFunctionContext(None)
    def sample(self, ctx, si, sample1, sample2, active):
        """BSDF sampling: draw a cosine-weighted outgoing direction.

        ``sample2`` is warped onto the cosine hemisphere to obtain ``bs.wo``
        and its density. The returned weight is
        ``self.get_btf(si.wi, bs.wo, si.uv) / cos_theta(bs.wo)``, or a zero
        vector where the lane is inactive, ``si.wi`` does not have a positive
        cosine, or the density is not positive. ``ctx`` and ``sample1`` are
        not read here.

        Returns a ``(BSDFSample3f, weight)`` tuple.
        """
        cos_theta_i = Frame3f.cos_theta(si.wi)

        active &= cos_theta_i > 0

        bs = BSDFSample3f()
        bs.wo  = warp.square_to_cosine_hemisphere(sample2)
        bs.pdf = warp.square_to_cosine_hemisphere_pdf(bs.wo)
        bs.eta = 1.0
        bs.sampled_type = +BSDFFlags.DiffuseReflection
        bs.sampled_component = 0

        value = self.get_btf(si.wi, bs.wo, si.uv) / Frame3f.cos_theta(bs.wo)

        return (bs, ek.select(active & (bs.pdf > 0.0), value, Vector3f(0)))
def interaction():
    """Build a surface interaction at the origin whose frame is built from +Z."""
    from mitsuba.core import Frame3f
    from mitsuba.render import SurfaceInteraction3f

    record = SurfaceInteraction3f()
    record.t = 0.1
    record.p = [0, 0, 0]
    record.n = [0, 0, 1]
    # The shading frame is derived from the normal that was just assigned.
    record.sh_frame = Frame3f(record.n)
    return record
def test01_construction(variant_scalar_rgb):
    """Exercise the ``Frame3f`` constructors: default, (s, t, n), normal, copy."""
    from mitsuba.core import Frame3f

    # Default construction yields an uninitialized frame; it only has to work.
    _ = Frame3f()

    # Explicit (s, t, n) triple: the vectors must be stored without
    # normalization.
    s_in = [0.005, 50, -6]
    t_in = [0.01, -13.37, 1]
    n_in = [0.5, 0, -6.2]
    explicit = Frame3f(s_in, t_in, n_in)
    assert ek.allclose(explicit.s, s_in)
    assert ek.allclose(explicit.t, t_in)
    assert ek.allclose(explicit.n, n_in)

    # Construction from the normal alone: for +Z the canonical basis results.
    from_normal = Frame3f([0, 0, 1])
    assert ek.allclose(from_normal.s, [1, 0, 0])
    assert ek.allclose(from_normal.t, [0, 1, 0])
    assert ek.allclose(from_normal.n, [0, 0, 1])

    # A copy compares equal to its source.
    duplicate = Frame3f(from_normal)
    assert from_normal == duplicate
def test03_sample_half_wave_local(variant_scalar_mono_polarized):
    """Check the ``retarder`` BSDF with delta = 180 degrees (half-wave plate).

    The Mueller matrix returned by ``bsdf.sample`` at normal incidence is
    applied to four reference Stokes vectors; each is expected to be mapped to
    its opposite state.
    """
    from mitsuba.core import Frame3f, Transform4f, Spectrum
    from mitsuba.core.xml import load_string
    from mitsuba.render import BSDFContext, TransportMode, SurfaceInteraction3f

    def spectrum_from_stokes(v):
        """Store the four Stokes components of ``v`` in column 0 of a Spectrum."""
        res = Spectrum(0.0)
        for i in range(4):
            res[i, 0] = v[i]
        return res

    # Test polarized implementation. Special case of delta = 180˚, also known
    # as a half-wave plate. (In local BSDF coordinates.)
    # Following "Polarized Light - Fundamentals and Applications" by Edward Collett
    # Chapter 5.3:
    # Case 1 & 2) Switch between diagonal linear polarization states (-45˚ & + 45˚)
    # Case 3 & 4) Switch circular polarization direction

    linear_pos = spectrum_from_stokes([1, 0, +1, 0])
    linear_neg = spectrum_from_stokes([1, 0, -1, 0])
    circular_right = spectrum_from_stokes([1, 0, 0, +1])
    circular_left = spectrum_from_stokes([1, 0, 0, -1])

    # The XML description must be a closed element inside a terminated literal.
    bsdf = load_string("""<bsdf version='2.0.0' type='retarder'>
                          <spectrum name="theta" value="0"/>
                          <spectrum name="delta" value="180.0"/>
                      </bsdf>""")

    # Incident direction
    wi = [0, 0, 1]

    ctx = BSDFContext()
    ctx.mode = TransportMode.Importance
    si = SurfaceInteraction3f()
    si.p = [0, 0, 0]
    si.wi = wi
    n = [0, 0, 1]
    si.n = n
    si.sh_frame = Frame3f(si.n)

    bs, M = bsdf.sample(ctx, si, 0.0, [0.0, 0.0])

    # Case 1)
    assert ek.allclose(M @ linear_pos, linear_neg, atol=1e-3)
    # Case 2)
    assert ek.allclose(M @ linear_neg, linear_pos, atol=1e-3)
    # Case 3)
    assert ek.allclose(M @ circular_right, circular_left, atol=1e-3)
    # Case 4)
    assert ek.allclose(M @ circular_left, circular_right, atol=1e-3)
def test01_intersection_construction(variant_scalar_rgb):
    # Populate every field of a SurfaceInteraction3f with distinct values and
    # check that the shading frame round-trips and that repr() matches a
    # reference string.
    #
    # NOTE(review): the expected repr() string at the end of this function is
    # an unterminated triple-quoted literal -- its contents look truncated and
    # cannot be recovered from here; restore them from the original file.
    from mitsuba.core import Frame3f
    from mitsuba.render import SurfaceInteraction3f

    si = SurfaceInteraction3f()
    si.shape = None
    si.t = 1
    si.time = 2
    si.wavelengths = []
    si.p = [1, 2, 3]
    si.n = [4, 5, 6]
    si.uv = [7, 8]
    si.sh_frame = Frame3f([9, 10, 11], [12, 13, 14], [15, 16, 17])
    si.dp_du = [18, 19, 20]
    si.dp_dv = [21, 22, 23]
    si.duv_dx = [24, 25]
    si.duv_dy = [26, 27]
    si.wi = [31, 32, 33]
    si.prim_index = 34
    si.instance = None
    assert si.sh_frame == Frame3f([9, 10, 11], [12, 13, 14], [15, 16, 17])
    assert repr(si) == """SurfaceInteraction[
def test05_sample_components(variant_scalar_rgb):
    """Sample each component of a ``blendbsdf`` in isolation.

    ``ctx.component`` is set to 0 and then 1, and for each setting the BSDF
    is sampled with two different values of ``sample1``. The returned weight
    must match the selected component regardless of ``sample1``.
    """
    from mitsuba.core import Frame3f
    from mitsuba.render import BSDFFlags, BSDFContext, SurfaceInteraction3f
    from mitsuba.core.xml import load_string
    from mitsuba.core.math import InvPi

    weight = 0.2

    # The scene description must be well-formed XML: every <bsdf> element is
    # closed, and the "{}" placeholder is filled in with the blend weight.
    bsdf = load_string("""<bsdf version="2.0.0" type="blendbsdf">
        <bsdf type="diffuse">
            <spectrum name="reflectance" value="0.0"/>
        </bsdf>
        <bsdf type="diffuse">
            <spectrum name="reflectance" value="1.0"/>
        </bsdf>
        <spectrum name="weight" value="{}"/>
    </bsdf>""".format(weight))

    si = SurfaceInteraction3f()
    si.t = 0.1
    si.p = [0, 0, 0]
    si.n = [0, 0, 1]
    si.sh_frame = Frame3f(si.n)
    si.wi = [0, 0, 1]

    ctx = BSDFContext()

    # Sample specific components separately using two different values of 'sample1'
    # and make sure the desired component is chosen always.

    ctx.component = 0

    expected_a = (1-weight)*0.0    # InvPi will cancel out with sampling pdf, but still need to apply weight
    bs_a, weight_a = bsdf.sample(ctx, si, 0.1, [0.5, 0.5])
    assert ek.allclose(weight_a, expected_a)

    expected_b = (1-weight)*0.0    # InvPi will cancel out with sampling pdf, but still need to apply weight
    bs_b, weight_b = bsdf.sample(ctx, si, 0.3, [0.5, 0.5])
    assert ek.allclose(weight_b, expected_b)

    ctx.component = 1

    expected_a = weight*1.0    # InvPi will cancel out with sampling pdf, but still need to apply weight
    bs_a, weight_a = bsdf.sample(ctx, si, 0.1, [0.5, 0.5])
    assert ek.allclose(weight_a, expected_a)

    expected_b = weight*1.0    # InvPi will cancel out with sampling pdf, but still need to apply weight
    bs_b, weight_b = bsdf.sample(ctx, si, 0.3, [0.5, 0.5])
    assert ek.allclose(weight_b, expected_b)
    def sample(self, ctx, si, sample1, sample2, active):
        """Sample the single delta-transmission lobe of this BSDF.

        The outgoing direction is ``mr.retro_transmit(si.wi)`` with a pdf of
        1. The returned weight is the retro-transmittance evaluated at ``si``,
        or 0 for lanes that are inactive or whose ``si.wi`` does not have a
        positive cosine. ``ctx``, ``sample1`` and ``sample2`` are not read.

        Returns a ``(BSDFSample3f, weight)`` tuple.
        """
        # Restrict the mask to directions with a positive cosine.
        active &= Frame3f.cos_theta(si.wi) > 0

        bs = BSDFSample3f()
        bs.wo = mr.retro_transmit(si.wi)
        bs.pdf = 1
        bs.eta = 1
        bs.sampled_type = +BSDFFlags.DeltaTransmission
        bs.sampled_component = 0

        transmittance = self.m_retro_transmittance.eval(si, active)
        weight = ek.select(active, transmittance, 0)

        return (bs, weight)
    # NOTE(review): the parameter list below is never closed and does not
    # declare `self`, `time` or `active`, all of which the body uses -- the
    # signature looks truncated; restore it from the original file.
    def sample_ray(
            sample1,  # wavelength
            sample2,  # pos
            sample3,  # dir

        # Pick a position on the emitter's shape and a cosine-weighted
        # direction expressed in local coordinates.
        ps = self.m_shape.sample_position(time, sample2, active)
        local = warp.square_to_cosine_hemisphere(sample3)

        # Sample wavelengths and their weight from the radiance spectrum.
        # NOTE(review): presumably ek.arange(sample1) derives the per-wavelength
        # samples from `sample1` -- confirm this is the intended call.
        si = SurfaceInteraction3f(ps, 0)
        wavelengths, spec_weight = self.m_radiance.sample(
            si, ek.arange(sample1), active)

        # The local direction is rotated into the frame built around ps.n.
        ray = Ray3f(ps.p, Frame3f(ps.n).to_world(local), time, wavelengths)
        return (ray, spec_weight * self.m_area_times_pi)
def test02_eval_pdf(variant_scalar_rgb):
    """Compare ``pdf`` and ``eval`` of a default diffuse BSDF with closed forms.

    For 20 outgoing directions in the x-z plane the pdf must equal
    ``cos(theta) / pi`` and the first channel of the BSDF value must equal
    ``0.5 * cos(theta) / pi``.
    """
    from mitsuba.core import Frame3f
    from mitsuba.render import BSDFContext, BSDFFlags, SurfaceInteraction3f
    from mitsuba.core.xml import load_string

    bsdf = load_string("<bsdf version='2.0.0' type='diffuse'></bsdf>")

    # Interaction at the origin, seen from straight above.
    si = SurfaceInteraction3f()
    si.p = [0, 0, 0]
    si.n = [0, 0, 1]
    si.wi = [0, 0, 1]
    si.sh_frame = Frame3f(si.n)

    ctx = BSDFContext()

    for step in range(20):
        # Elevation angle sweeps from 0 to pi/2 inclusive.
        theta = step / 19.0 * (ek.pi / 2)
        wo = [ek.sin(theta), 0, ek.cos(theta)]
        cos_o = wo[2]

        density = bsdf.pdf(ctx, si, wo=wo)
        first_channel = bsdf.eval(ctx, si, wo=wo)[0]

        assert ek.allclose(density, cos_o / ek.pi)
        assert ek.allclose(first_channel, 0.5 * cos_o / ek.pi)
def test04_sample_all(variant_scalar_rgb):
    """Sample a ``blendbsdf`` with all components enabled.

    Two values of ``sample1`` on either side of the blend weight are used; the
    returned sample weight is expected to be 1.0 for the first and 0.0 for the
    second.
    """
    from mitsuba.core import Frame3f
    from mitsuba.render import BSDFFlags, BSDFContext, SurfaceInteraction3f
    from mitsuba.core.xml import load_string
    from mitsuba.core.math import InvPi

    weight = 0.2

    # The scene description must be well-formed XML: every <bsdf> element is
    # closed, and the "{}" placeholder is filled in with the blend weight.
    bsdf = load_string("""<bsdf version="2.0.0" type="blendbsdf">
        <bsdf type="diffuse">
            <spectrum name="reflectance" value="0.0"/>
        </bsdf>
        <bsdf type="diffuse">
            <spectrum name="reflectance" value="1.0"/>
        </bsdf>
        <spectrum name="weight" value="{}"/>
    </bsdf>""".format(weight))

    si = SurfaceInteraction3f()
    si.t = 0.1
    si.p = [0, 0, 0]
    si.n = [0, 0, 1]
    si.sh_frame = Frame3f(si.n)
    si.wi = [0, 0, 1]

    ctx = BSDFContext()

    # Sample using two different values of 'sample1' and make sure correct
    # components are chosen.

    expected_a = 1.0    # InvPi & weight will cancel out with sampling pdf
    bs_a, weight_a = bsdf.sample(ctx, si, 0.1, [0.5, 0.5])
    assert ek.allclose(weight_a, expected_a)

    expected_b = 0.0    # InvPi & weight will cancel out with sampling pdf
    bs_b, weight_b = bsdf.sample(ctx, si, 0.3, [0.5, 0.5])
    assert ek.allclose(weight_b, expected_b)
def test01_roughplastic(variant_scalar_rgb):
    # Compare a "roughplastic" reference BSDF against three other
    # representations of it: a layer built with mitsuba.layer, the BSDFStorage
    # produced from that layer, and the "fourier" BSDF plugin.
    #
    # NOTE(review): several statements in this function look truncated and
    # must be restored from the original file before it can run:
    #   - the XML literals passed to load_string() are never terminated (and
    #     the "{}" placeholder of the fourier one is never filled in);
    #   - both `wo = [` lists have only two components and no closing bracket;
    #   - the `l_eval = ... * np.abs(` expression is never closed;
    #   - `if not os.path.exists(base_path):` has no body.
    from mitsuba.core.xml import load_string
    from mitsuba.render import BSDF, BSDFContext, SurfaceInteraction3f
    from mitsuba.core import Frame3f

    thetas = np.linspace(0, np.pi / 2, 20)
    phi = np.pi

    values_ref = []

    # Create plastic reference BSDF
    bsdf = load_string("""<bsdf version="2.0.0" type="roughplastic">
                              <spectrum name="diffuse_reflectance" value="0.5"/>
                              <float name="alpha" value="0.3"/>
                              <string name="distribution" value="beckmann"/>
                              <float name="int_ior" value="1.5"/>
                              <float name="ext_ior" value="1.0"/>
                              <boolean name="nonlinear" value="true"/>

    theta_i = np.radians(30.0)
    si = SurfaceInteraction3f()
    si.p = [0, 0, 0]
    si.n = [0, 0, 1]
    si.wi = [np.sin(theta_i), 0, np.cos(theta_i)]
    si.sh_frame = Frame3f(si.n)
    ctx = BSDFContext()

    for theta in thetas:
        wo = [
            np.sin(theta) * np.cos(phi),
            np.sin(theta) * np.sin(phi),
        values_ref.append(bsdf.eval(ctx, si, wo=wo)[0])

    # Create same BSDF as layer representation by applying adding equations
    n, ms, md = mitsuba.layer.microfacet_parameter_heuristic(0.3, 0.3, 1.5)
    mu, w = mitsuba.core.quad.gauss_lobatto(n)

    coating = mitsuba.layer.Layer(mu, w, ms, md)
    coating.set_microfacet(1.5, 0.3, 0.3)
    base = mitsuba.layer.Layer(mu, w, ms, md)

    layer = mitsuba.layer.Layer.add(coating, base)

    for i, theta in enumerate(thetas):
        l_eval = layer.eval(-np.cos(theta), np.cos(theta_i)) * np.abs(
        # Values should be close (except if they are insignificantly small).
        # We have less precision at grazing angles because of Fourier representation.
        assert values_ref[i] < 1e-5 or np.allclose(
            values_ref[i], l_eval, rtol=0.05 / (np.abs(np.cos(theta))))

    # Convert into BSDF storage representation
    base_path = os.path.dirname(os.path.realpath(__file__)) + "/data/"
    if not os.path.exists(base_path):
    path = base_path + "roughplastic.bsdf"
    storage = mitsuba.layer.BSDFStorage.from_layer(path, layer, 1e-5)

    for i, theta in enumerate(thetas):
        s_eval = storage.eval(np.cos(theta_i), -np.cos(theta))[0]
        # Values should be close (except if they are insignificantly small).
        # We have less precision at grazing angles because of Fourier representation.
        assert values_ref[i] < 1e-5 or np.allclose(
            values_ref[i], s_eval, rtol=0.05 / (np.abs(np.cos(theta))))

    # And load via the "fourier" BSDF plugin
    fourier = load_string("""<bsdf version="2.0.0" type="fourier">
                                 <string name="filename" value="{}"/>

    for i, theta in enumerate(thetas):
        wo = [
            np.sin(theta) * np.cos(phi),
            np.sin(theta) * np.sin(phi),
        f_eval = fourier.eval(ctx, si, wo=wo)[0]
        assert values_ref[i] < 1e-5 or np.allclose(
            values_ref[i], f_eval, rtol=0.05 / (np.abs(np.cos(theta))))
    del fourier
def test02_unit_frame(variant_scalar_rgb):
    """Check ``Frame3f`` round trips and its static trigonometric helpers.

    A unit vector is built from spherical angles (theta, phi). It must survive
    a ``to_local``/``to_world`` round trip through an arbitrary frame, and
    the static ``Frame3f`` helpers must recover the sines and cosines of the
    angles it was built from.
    """
    from mitsuba.core import Frame3f, Vector2f, Vector3f

    pi = mitsuba.core.math.Pi

    # The azimuth is the same for every elevation, so compute it once.
    phi = 73 * pi / 180
    sin_phi, cos_phi = ek.sin(phi), ek.cos(phi)

    for theta in [30 * pi / 180, 95 * pi / 180]:
        sin_theta, cos_theta = ek.sin(theta), ek.cos(theta)

        # Unit vector in spherical coordinates; the z component is cos(theta),
        # which is what Frame3f.cos_theta(v) is asserted to return below.
        v = Vector3f(
            cos_phi * sin_theta,
            sin_phi * sin_theta,
            cos_theta)
        f = Frame3f(Vector3f(1.0, 2.0, 3.0) / ek.sqrt(14))

        v2 = f.to_local(v)
        v3 = f.to_world(v2)

        assert ek.allclose(v3, v)

        assert ek.allclose(Frame3f.cos_theta(v), cos_theta)
        assert ek.allclose(Frame3f.sin_theta(v), sin_theta)
        assert ek.allclose(Frame3f.cos_phi(v), cos_phi)
        assert ek.allclose(Frame3f.sin_phi(v), sin_phi)
        assert ek.allclose(Frame3f.cos_theta_2(v), cos_theta * cos_theta)
        assert ek.allclose(Frame3f.sin_theta_2(v), sin_theta * sin_theta)
        assert ek.allclose(Frame3f.cos_phi_2(v), cos_phi * cos_phi)
        assert ek.allclose(Frame3f.sin_phi_2(v), sin_phi * sin_phi)
        assert ek.allclose(Vector2f(Frame3f.sincos_phi(v)), [sin_phi, cos_phi])
        assert ek.allclose(Vector2f(Frame3f.sincos_phi_2(v)), [sin_phi * sin_phi, cos_phi * cos_phi])
def test02_sample_local(variant_scalar_mono_polarized):
    """Check the ``polarizer`` BSDF in local coordinates for several rotations.

    For each polarizer angle, unpolarized light is sent through the BSDF at
    perpendicular incidence and at two tilted incidences, and the outgoing
    Stokes vector is compared with the expected linearly polarized state.
    """
    from mitsuba.core import Frame3f, Transform4f, Spectrum
    from mitsuba.core.xml import load_string
    from mitsuba.render import BSDFContext, TransportMode, SurfaceInteraction3f

    def spectrum_from_stokes(v):
        """Store the four Stokes components of ``v`` in column 0 of a Spectrum."""
        res = Spectrum(0.0)
        for i in range(4):
            res[i, 0] = v[i]
        return res

    def rotate_vector(v, axis, angle):
        """Rotate ``v`` by ``angle`` around ``axis`` via ``Transform4f.rotate``."""
        return Transform4f.rotate(axis, angle).transform_vector(v)

    # Test polarized implementation, version in local BSDF coordinate system
    # (surface normal aligned with "z").
    # The polarizer is rotated to different angles and hit with fully
    # unpolarized light (Stokes vector [1, 0, 0, 0]).
    # We then test if the outgoing Stokes vector corresponds to the expected
    # rotation of linearly polarized light (Case 1).
    # Additionally, a perfect linear polarizer should be invariant to "tilting",
    # i.e. rotations around "x" or "y" in this local frame (Case 2, 3).

    # Incident direction
    wi = [0, 0, 1]
    stokes_in = spectrum_from_stokes([1, 0, 0, 0])

    ctx = BSDFContext()
    ctx.mode = TransportMode.Importance
    si = SurfaceInteraction3f()
    si.p = [0, 0, 0]
    si.wi = wi
    n = [0, 0, 1]
    si.n = n
    si.sh_frame = Frame3f(si.n)

    # Polarizer rotation angles
    angles = [0, 90, +45, -45]
    # Expected outgoing Stokes vector
    expected_states = [spectrum_from_stokes([0.5,  0.5,  0,   0]),
                       spectrum_from_stokes([0.5, -0.5,  0,   0]),
                       spectrum_from_stokes([0.5,  0,   +0.5, 0]),
                       spectrum_from_stokes([0.5,  0,   -0.5, 0])]

    for angle, expected in zip(angles, expected_states):
        # The XML description must be a closed element; the "{}" placeholder
        # receives the rotation angle of this iteration.
        bsdf = load_string("""<bsdf version='2.0.0' type='polarizer'>
                                  <spectrum name="theta" value="{}"/>
                              </bsdf>""".format(angle))

        # Case 1: Perpendicular incidence.
        # Cases 2 and 3 overwrite si.wi, so restore it first; otherwise every
        # iteration after the first would start from a tilted direction.
        si.wi = wi
        bs, M = bsdf.sample(ctx, si, 0.0, [0.0, 0.0])

        stokes_out = M @ stokes_in
        assert ek.allclose(expected, stokes_out, atol=1e-3)

        # Case 2: Tilt polarizer around "x". Should not change anything.
        # (Note: to stay with local coordinates, we rotate the incident direction instead.)
        si.wi = rotate_vector(wi, [1, 0, 0], angle=30.0)
        bs, M = bsdf.sample(ctx, si, 0.0, [0.0, 0.0])
        stokes_out = M @ stokes_in
        assert ek.allclose(expected, stokes_out, atol=1e-3)

        # Case 3: Tilt polarizer around "y". Should not change anything.
        # (Note: to stay with local coordinates, we rotate the incident direction instead.)
        si.wi = rotate_vector(wi, [0, 1, 0], angle=30.0)
        bs, M = bsdf.sample(ctx, si, 0.0, [0.0, 0.0])
        stokes_out = M @ stokes_in
        assert ek.allclose(expected, stokes_out, atol=1e-3)
def test01_diffuse(variant_scalar_rgb):
    # Compare a "diffuse" reference BSDF against three other representations
    # of it: a layer built with mitsuba.layer, the BSDFStorage produced from
    # that layer, and the "fourier" BSDF plugin.
    #
    # NOTE(review): several statements in this function look truncated and
    # must be restored from the original file before it can run:
    #   - the XML literals passed to load_string() are never terminated (and
    #     the "{}" placeholder of the fourier one is never filled in);
    #   - both `wo = [` lists have only two components and no closing bracket;
    #   - the `l_eval = ... * np.abs(` expression is never closed;
    #   - `if not os.path.exists(base_path):` has no body.
    from mitsuba.core.xml import load_string
    from mitsuba.render import BSDF, BSDFContext, SurfaceInteraction3f
    from mitsuba.core import Frame3f

    thetas = np.linspace(0, np.pi / 2, 20)
    phi = np.pi

    values_ref = []

    # Create diffuse reference BSDF
    bsdf = load_string("""<bsdf version="2.0.0" type="diffuse">
                              <spectrum name="reflectance" value="0.5"/>

    theta_i = np.radians(30.0)
    si = SurfaceInteraction3f()
    si.p = [0, 0, 0]
    si.n = [0, 0, 1]
    si.wi = [np.sin(theta_i), 0, np.cos(theta_i)]
    si.sh_frame = Frame3f(si.n)
    ctx = BSDFContext()

    for theta in thetas:
        wo = [
            np.sin(theta) * np.cos(phi),
            np.sin(theta) * np.sin(phi),
        values_ref.append(bsdf.eval(ctx, si, wo=wo)[0])

    # Create same BSDF as layer representation
    n = 100
    ms = 1
    md = 1
    mu, w = mitsuba.core.quad.gauss_lobatto(n)
    layer = mitsuba.layer.Layer(mu, w, ms, md)

    for i, theta in enumerate(thetas):
        l_eval = layer.eval(-np.cos(theta), np.cos(theta_i)) * np.abs(
        # Values should be close (except if they are insignificantly small).
        # We have less precision at grazing angles because of Fourier representation.
        assert np.allclose(values_ref[i], l_eval, rtol=0.01)

    # Convert into BSDF storage representation
    base_path = os.path.dirname(os.path.realpath(__file__)) + "/data/"
    if not os.path.exists(base_path):
    path = base_path + "diffuse.bsdf"
    storage = mitsuba.layer.BSDFStorage.from_layer(path, layer, 1e-5)

    for i, theta in enumerate(thetas):
        s_eval = storage.eval(np.cos(theta_i), -np.cos(theta))[0]
        # Values should be close (except if they are insignificantly small).
        # We have less precision at grazing angles because of Fourier representation.
        assert np.allclose(values_ref[i], s_eval, rtol=0.01)

    # And load via the "fourier" BSDF plugin
    fourier = load_string("""<bsdf version="2.0.0" type="fourier">
                                 <string name="filename" value="{}"/>

    for i, theta in enumerate(thetas):
        wo = [
            np.sin(theta) * np.cos(phi),
            np.sin(theta) * np.sin(phi),
        f_eval = fourier.eval(ctx, si, wo=wo)[0]
        assert np.allclose(values_ref[i], f_eval, rtol=0.02)
    del fourier
# Compute, for every incident direction (phi_i[i], theta_i[j]), a grid of
# outgoing directions obtained by reflecting the incident direction about
# vectors sampled from a MarginalContinuous2D2 distribution. Returns
# [theta_o, phi_o, active, invalid], each of shape
# (phi_i.size, theta_i.size, n_phi, n_theta).
#
# NOTE(review): the parameter list below is never closed and does not declare
# `n_theta`, `phi_i`, `theta_i`, `Dvis_sampler`, `isotropic` or `all`, which
# the body reads -- the signature looks truncated; restore it from the
# original file. `u2theta`, `u2phi`, `spherical2cartesian`,
# `cartesian2spherical` and `EPSILON` are presumably module-level names.
def outgoing_direction(n_phi,
                       theta_max=np.pi / 2,
    from mitsuba.core import Vector2f, Frame3f
    from mitsuba.core import MarginalContinuous2D2

    print("Max theta angle is %f deg." % np.degrees(theta_max))

    phi_o = np.zeros((phi_i.size, theta_i.size, n_phi, n_theta))
    theta_o = np.zeros((phi_i.size, theta_i.size, n_phi, n_theta))
    invalid = np.ones((phi_i.size, theta_i.size, n_phi, n_theta), dtype='bool')
    active = np.ones((phi_i.size, theta_i.size, n_phi, n_theta), dtype='bool')

    # Create uniform samples
    u_0 = np.linspace(0, 1, n_theta)
    u_1 = np.linspace(0, 1, n_phi)
    samples = Vector2f(np.tile(u_0, n_phi), np.repeat(u_1, n_theta))

    # Construct projected surface area interpolant data structure
    params = [phi_i.tolist(), theta_i.tolist()]
    m_vndf = MarginalContinuous2D2(Dvis_sampler, params, normalize=True)

    for i in range(phi_i.size):
        for j in range(theta_i.size):
            # Warp uniform samples by VNDF distribution (G1 mapping)
            u_m, ndf_pdf = m_vndf.sample(samples, [phi_i[i], theta_i[j]])
            # Convert samples to radians (G2 mapping)
            theta_m = u2theta(u_m.x)  # [0, 1] -> [0, pi]
            phi_m = u2phi(u_m.y)  # [0, 1] -> [0, 2pi]
            if isotropic:
                phi_m += phi_i[i]
            # Phase vector
            m = spherical2cartesian(theta_m, phi_m)
            # Incident direction
            wi = spherical2cartesian(theta_i[j], phi_i[i])
            # Outgoing direction (reflection over phase vector)
            wo = ek.fmsub(m, 2.0 * ek.dot(m, wi), wi)
            tmp1, tmp2 = cartesian2spherical(wo)
            # Remove invalid directions
            act = u_m.y > 0  # covered twice [-pi = pi]
            inv = Frame3f.cos_theta(wo) < 0  # below surface plane
            act &= np.invert(inv)  # above surface plane
            act &= tmp1 <= (theta_max + EPSILON)  # further angular restriction
            if isotropic:
                act &= tmp2 >= 0
            # NOTE(review): `all` is presumably a parameter from the truncated
            # signature; if it resolved to the builtin, this branch would
            # never run -- verify.
            if not all:
                tmp1[~act] = 0
                tmp2[~act] = 0
                tmp1[inv] = 0
                tmp2[inv] = 0

            # Fit to datashape
            act = np.reshape(act, (n_phi, n_theta))
            inv = np.reshape(inv, (n_phi, n_theta))
            tmp1 = np.reshape(tmp1, (n_phi, n_theta))
            tmp2 = np.reshape(tmp2, (n_phi, n_theta))

            # Append
            active[i, j] = act
            invalid[i, j] = inv
            theta_o[i, j] = tmp1
            phi_o[i, j] = tmp2
    return [theta_o, phi_o, active, invalid]
def test02_sample_pol_world(variant_scalar_mono_polarized):
    """Reflect several polarization states off a conductor in world space.

    A ``conductor`` BSDF is sampled for a 45-degree reflection; the resulting
    Mueller matrix is moved to world coordinates, re-expressed w.r.t. fixed
    Stokes bases and applied to unpolarized, linear and circular input states.
    """
    from mitsuba.core import Frame3f, Spectrum, UnpolarizedSpectrum
    from mitsuba.core.xml import load_string
    from mitsuba.render import BSDFContext, TransportMode, SurfaceInteraction3f, fresnel_conductor
    from mitsuba.render.mueller import stokes_basis, rotate_mueller_basis

    def spectrum_from_stokes(v):
        """Store the four Stokes components of ``v`` in column 0 of a Spectrum."""
        stokes = Spectrum(0.0)
        for row in range(4):
            stokes[row, 0] = v[row]
        return stokes

    # Geometry of the experiment (all in world coordinates, hence the extra
    # coordinate system rotations further down):
    # - A Silver mirror sits at [0, 0, 0] with its surface normal along
    #   [1, 1, 0].
    # - The incident ray w1 = [-1, 0, 0] hits the mirror at 45˚ and leaves
    #   along w2 = [0, 1, 0].
    # - The Stokes bases used for the two rays are b1 = [0, 1, 0] and
    #   b2 = [1, 0, 0].

    # IoR values of Ag at 635.816284nm
    eta = 0.136125 + 4.010625j
    bsdf = load_string("""<bsdf version='2.0.0' type='conductor'>
                              <spectrum name="eta" value="{}"/>
                              <spectrum name="k" value="{}"/>
                          </bsdf>""".format(eta.real, eta.imag))

    ctx = BSDFContext()
    ctx.mode = TransportMode.Importance

    si = SurfaceInteraction3f()
    si.p = [0, 0, 0]
    si.n = ek.normalize([1.0, 1.0, 0.0])
    si.sh_frame = Frame3f(si.n)

    # Ray directions together with their Stokes bases
    w1 = ek.scalar.Vector3f([-1, 0, 0])
    b1 = [0, 1, 0]
    w2 = ek.scalar.Vector3f([0, 1, 0])
    b2 = [1, 0, 0]

    # Carry out the reflection in local coordinates, then lift the Mueller
    # matrix to world coordinates.
    si.wi = si.to_local(-w1)
    bs, M_local = bsdf.sample(ctx, si, 0.0, [0.0, 0.0])
    M_world = si.to_world_mueller(M_local, -si.wi, bs.wo)

    # The sampled direction has to coincide with the predicted one.
    assert ek.allclose(si.to_world(bs.wo), w2, atol=1e-5)

    # Express polarization with respect to b1 & b2.
    M_world = rotate_mueller_basis(M_world, w1, stokes_basis(w1), b1, w2,
                                   stokes_basis(w2), b2)

    # Unpolarized input: the intensity must match the plain Fresnel term.
    unpolarized_out = M_world @ spectrum_from_stokes([1, 0, 0, 0])
    F = fresnel_conductor(si.wi[2], ek.scalar.Complex2f(eta.real, eta.imag))
    intensity = unpolarized_out[0, 0]
    assert ek.allclose(intensity[0], F, atol=1e-3)

    # Horizontally polarized input (linear at 0˚). The output must be
    # 1) still fully polarized (though with reduced intensity)
    # 2) still horizontally polarized
    horizontal_out = M_world @ spectrum_from_stokes([1, +1, 0, 0])
    assert ek.allclose(horizontal_out[0, 0], ek.abs(horizontal_out[1, 0]),
                       atol=1e-3)  # 1)
    horizontal_out /= horizontal_out[0, 0]
    assert ek.allclose(horizontal_out, spectrum_from_stokes([1, 1, 0, 0]),
                       atol=1e-3)  # 2)

    # Vertically polarized input (linear at 90˚). The output must be
    # 1) still fully polarized (though with reduced intensity)
    # 2) still vertically polarized
    vertical_out = M_world @ spectrum_from_stokes([1, -1, 0, 0])
    assert ek.allclose(vertical_out[0, 0], ek.abs(vertical_out[1, 0]),
                       atol=1e-3)  # 1)
    vertical_out /= vertical_out[0, 0]
    assert ek.allclose(vertical_out, spectrum_from_stokes([1, -1, 0, 0]),
                       atol=1e-3)  # 2)

    zero = UnpolarizedSpectrum(0.0)

    # (Positive) diagonally polarized input (linear at +45˚):
    # 1) the polarization is flipped to -45˚
    # 2) there is now also some (left) circular polarization
    diag_pos_out = M_world @ spectrum_from_stokes([1, 0, +1, 0])
    assert ek.all(diag_pos_out[2, 0] < zero)
    assert ek.all(diag_pos_out[3, 0] < zero)

    # (Negative) diagonally polarized input (linear at -45˚):
    # 1) the polarization is flipped to +45˚
    # 2) there is now also some (right) circular polarization
    diag_neg_out = M_world @ spectrum_from_stokes([1, 0, -1, 0])
    assert ek.all(diag_neg_out[2, 0] > zero)
    assert ek.all(diag_neg_out[3, 0] > zero)

    # Right circular input is flipped to left circular.
    right_out = M_world @ spectrum_from_stokes([1, 0, 0, +1])
    assert ek.all(right_out[3, 0] < zero)

    # Left circular input is flipped to right circular.
    left_out = M_world @ spectrum_from_stokes([1, 0, 0, -1])
    assert ek.all(left_out[3, 0] > zero)
def test02_roughconductor(variant_scalar_rgb):
    """Compare a rough conductor across four representations.

    For several (alpha_u, alpha_v) roughness pairs, the ``roughconductor``
    plugin (Beckmann distribution, eta = 0, k = 1) is evaluated at a fixed
    30 degree incidence for 20 outgoing elevations at azimuth ``phi = pi``.
    Those reference values are then compared against:

    1. a ``mitsuba.layer.Layer`` configured with ``set_microfacet``,
    2. a ``BSDFStorage`` file written from that layer,
    3. the ``fourier`` BSDF plugin loading that file.

    Reference values below ``1e-5`` are treated as insignificant and skipped.
    The relative tolerance is ``0.05 / |cos(theta)|``, i.e. it loosens towards
    grazing angles where the Fourier representation is less precise.
    """
    from mitsuba.core.xml import load_string
    from mitsuba.render import BSDF, BSDFContext, SurfaceInteraction3f
    from mitsuba.core import Frame3f

    def direction(theta, phi):
        # Unit vector for the given elevation/azimuth, built the same way as
        # `si.wi` below.
        # NOTE(review): the z component was missing from the truncated
        # source and has been restored as cos(theta) — confirm upstream.
        return [
            np.sin(theta) * np.cos(phi),
            np.sin(theta) * np.sin(phi),
            np.cos(theta),
        ]

    def assert_close(reference, value, theta):
        # Values should be close (except if they are insignificantly small).
        # We have less precision at grazing angles because of the Fourier
        # representation, hence the 1 / |cos(theta)| scaling of `rtol`.
        assert reference < 1e-5 or np.allclose(
            reference, value, rtol=0.05 / (np.abs(np.cos(theta))))

    # The sampled directions do not depend on the roughness pair.
    thetas = np.linspace(0, np.pi / 2, 20)
    phi = np.pi
    theta_i = np.radians(30.0)

    for alpha in [(0.3, 0.3), (0.3 + 1e-5, 0.3 - 1e-5), (0.2, 0.4)]:
        alpha_u = alpha[0]
        alpha_v = alpha[1]

        # Create conductor reference BSDF
        bsdf = load_string("""<bsdf version="2.0.0" type="roughconductor">
                                  <float name="alpha_u" value="{}"/>
                                  <float name="alpha_v" value="{}"/>
                                  <string name="distribution" value="beckmann"/>
                                  <spectrum name="eta" value="0.0"/>
                                  <spectrum name="k" value="1.0"/>
                              </bsdf>""".format(alpha_u, alpha_v))

        si = SurfaceInteraction3f()
        si.p = [0, 0, 0]
        si.n = [0, 0, 1]
        si.wi = [np.sin(theta_i), 0, np.cos(theta_i)]
        si.sh_frame = Frame3f(si.n)
        ctx = BSDFContext()

        values_ref = [
            bsdf.eval(ctx, si, wo=direction(theta, phi))[0]
            for theta in thetas
        ]

        # Create same BSDF as layer representation
        n, ms, md = mitsuba.layer.microfacet_parameter_heuristic(
            alpha_u, alpha_v, 0 + 1j)
        mu, w = mitsuba.core.quad.gauss_lobatto(n)
        layer = mitsuba.layer.Layer(mu, w, ms, md)
        layer.set_microfacet(0 + 1j, alpha_u, alpha_v)

        for i, theta in enumerate(thetas):
            # NOTE(review): the operand of np.abs(...) was cut off in the
            # source; cos(theta) is restored here to match the cosine used
            # in the tolerance — confirm upstream.
            l_eval = layer.eval(-np.cos(theta), np.cos(theta_i)) * np.abs(
                np.cos(theta))
            print(values_ref[i], l_eval)
            assert_close(values_ref[i], l_eval, theta)

        # Convert into BSDF storage representation
        base_path = os.path.dirname(os.path.realpath(__file__)) + "/data/"
        # The source had an empty `if not os.path.exists(base_path):`; the
        # directory must exist before the storage file is written into it.
        os.makedirs(base_path, exist_ok=True)
        path = base_path + "roughconductor.bsdf"
        storage = mitsuba.layer.BSDFStorage.from_layer(path, layer, 1e-8)

        for i, theta in enumerate(thetas):
            s_eval = storage.eval(np.cos(theta_i), -np.cos(theta))[0]
            assert_close(values_ref[i], s_eval, theta)

        # And load via the "fourier" BSDF plugin
        # NOTE(review): the closing tag and `.format(path)` were missing from
        # the source; restored so the plugin reads the file written above.
        fourier = load_string("""<bsdf version="2.0.0" type="fourier">
                                     <string name="filename" value="{}"/>
                                 </bsdf>""".format(path))

        for i, theta in enumerate(thetas):
            f_eval = fourier.eval(ctx, si, wo=direction(theta, phi))[0]
            assert_close(values_ref[i], f_eval, theta)

        # Presumably drops the plugin's hold on the file before the next
        # iteration rewrites it — TODO confirm.
        del fourier
def render_sample(scene, sampler, rays, bdata, heightmap_pybind, bssrdf=None):
    # NOTE(review): the prose lines that follow read like this function's
    # docstring, but the enclosing triple quotes (and apparently the headers
    # of the parameter and return lists) are missing, so they do not parse as
    # Python in their current form. Restore the delimiters.
    Sample RTE
    TODO: Support multi channel sampling

        scene: Target scene object
        sampler: Sampler object for random number
        rays: Given rays for sampling
        bdata: BSSRDF Data object
        heightmap_pybind: Object for getting height map around incident position.
                          Refer src/librender/python/heightmap.cpp

        result: Sampling RTE result
        valid_rays: Mask data whether rays are valid or not
        scatter: Scatter components of Sampling RTE result
        non_scatter: Non scatter components of Sampling RTE result
        invalid_sample: Sampling RTE result with invalid sampled data by VAEBSSRDF

    # State carried across bounces:
    #   eta             - running product of `bs.eta`, used by russian roulette
    #   emission_weight - MIS weight applied to emitters hit by the previous
    #                     BSDF/BSSRDF sample
    #   throughput      - accumulated path weight
    eta = Float(1.0)
    emission_weight = Float(1.0)
    throughput = Spectrum(1.0)
    # Output accumulators. All four receive the same radiance contributions;
    # `scatter` / `non_scatter` are additionally masked by `sss` / `~sss`.
    result = Spectrum(0.0)
    scatter = Spectrum(0.0)
    non_scatter = Spectrum(0.0)
    invalid_sample = Spectrum(0.0)
    active = True
    is_bssrdf = False

    ##### First interaction #####
    si = scene.ray_intersect(rays, active)
    active = si.is_valid() & active
    # Returned to the caller: which primary rays hit anything at all.
    valid_rays = si.is_valid()

    emitter = si.emitter(scene, active)

    depth = 0

    # Set channel
    # At and after evaluating the BSSRDF, a ray considers only this one channel.
    # The index is drawn from the sampler and clamped to n_channels - 1.
    n_channels = 3
    channel = UInt32(
        ek.min(sampler.next_1d(active) * n_channels, n_channels - 1))

    # Outgoing direction and its pdf produced by the BSSRDF step of the
    # previous iteration; they replace the BSDF sample where `is_bssrdf` holds.
    d_out_local = Vector3f().zero()
    d_out_pdf = Float(0)

    # Mask selecting which contributions go to `scatter` vs `non_scatter`.
    sss = Mask(False)

    while (True):
        depth += 1
        # With AOVs enabled, `sss` is latched from `is_bssrdf` once, on the
        # second iteration, and then kept for the rest of the path.
        if config.aovs and depth == 2:
            sss = is_bssrdf

        ##### Interaction with emitters #####
        emission_val = emission_weight * throughput * Emitter.eval_vec(
            emitter, si, active)

        result += ek.select(active, emission_val, Spectrum(0.0))
        invalid_sample += ek.select(active, emission_val, Spectrum(0.0))
        scatter += ek.select(active & sss, emission_val, Spectrum(0.0))
        non_scatter += ek.select(active & ~sss, emission_val, Spectrum(0.0))

        active = active & si.is_valid()

        # Process russian roulette: past `config.rr_depth`, a lane survives
        # with probability q and its throughput is rescaled by 1 / q.
        if depth > config.rr_depth:
            q = ek.min(ek.hmax(throughput) * ek.sqr(eta), 0.95)
            active = active & (sampler.next_1d(active) < q)
            throughput *= ek.rcp(q)

        # Stop if the number of bounces exceeds the given bounce limit, or
        # all rays are invalid. The latter check is done only when the bounce
        # limit is infinite.
        # NOTE(review): the `if` below has no body in this file; the comment
        # above implies the loop should terminate here (presumably `break`).
        # As written the function does not parse and the loop has no exit.
        if depth >= config.max_depth:

        ##### Emitter sampling #####
        bsdf = si.bsdf(rays)
        ctx = BSDFContext()

        # Emitter sampling is only attempted on lanes whose BSDF reports the
        # Smooth flag, and only where the sampled direction has non-zero pdf.
        active_e = active & has_flag(BSDF.flags_vec(bsdf), BSDFFlags.Smooth)
        ds, emitter_val = scene.sample_emitter_direction(
            si, sampler.next_2d(active_e), True, active_e)
        active_e &= ek.neq(ds.pdf, 0.0)

        # Query the BSDF for that emitter-sampled direction
        wo = si.to_local(ds.d)
        bsdf_val = BSDF.eval_vec(bsdf, ctx, si, wo, active_e)
        # Determine density of sampling that same direction using BSDF sampling
        bsdf_pdf = BSDF.pdf_vec(bsdf, ctx, si, wo, active_e)

        # No MIS for delta emitters: their weight is fixed to 1.
        mis = ek.select(ds.delta, Float(1), mis_weight(ds.pdf, bsdf_pdf))

        emission_val = mis * throughput * bsdf_val * emitter_val

        # NOTE(review): these accumulations are masked by `active`, not
        # `active_e`; presumably the masked-off lanes evaluate to zero above —
        # verify.
        result += ek.select(active, emission_val, Spectrum(0.0))
        invalid_sample += ek.select(active, emission_val, Spectrum(0.0))
        scatter += ek.select(active & sss, emission_val, Spectrum(0.0))
        non_scatter += ek.select(active & ~sss, emission_val, Spectrum(0.0))

        ##### BSDF sampling #####
        bs, bsdf_val = BSDF.sample_vec(bsdf, ctx, si, sampler.next_1d(active),
                                       sampler.next_2d(active), active)

        ##### BSSRDF replacing #####
        if (config.enable_bssrdf):
            # Replace bsdf samples by ones of BSSRDF
            bs.wo = ek.select(is_bssrdf, d_out_local, bs.wo)
            bs.pdf = ek.select(is_bssrdf, d_out_pdf, bs.pdf)
            # NOTE(review): both assignments below are cut off mid-call — the
            # remaining ek.select arguments and closing parentheses are
            # missing from this file and must be restored.
            bs.sampled_component = ek.select(is_bssrdf, UInt32(1),
            bs.sampled_type = ek.select(is_bssrdf,

        # Lanes flagged as BSSRDF keep their throughput; the others are
        # multiplied by the BSDF sample weight.
        throughput *= ek.select(is_bssrdf, Float(1.0), bsdf_val)
        active &= ek.any(ek.neq(throughput, 0))

        eta *= bs.eta

        # Intersect the BSDF ray against the scene geometry
        rays = RayDifferential3f(si.spawn_ray(si.to_world(bs.wo)))
        si_bsdf = scene.ray_intersect(rays, active)

        ##### Checking BSSRDF #####
        if (config.enable_bssrdf):
            # Is this a BSSRDF interaction? Requires the BSSRDF flag on the
            # BSDF, a sampled direction below the surface (cos_theta(wo) < 0)
            # and an incident direction above it (cos_theta(wi) > 0).
            is_bssrdf = (active
                         & has_flag(BSDF.flags_vec(bsdf), BSDFFlags.BSSRDF)
                         & (Frame3f.cos_theta(bs.wo) < Float(0.0))
                         & (Frame3f.cos_theta(si.wi) > Float(0.0)))

            # Decide whether we should use 0-scattering or multiple scattering
            is_zero_scatter = utils_render.check_zero_scatter(
                sampler, si_bsdf, bs, channel, is_bssrdf)
            is_bssrdf = is_bssrdf & ~is_zero_scatter

            throughput *= ek.select(is_bssrdf, ek.sqr(bs.eta), Float(1.0))

        ###### Process for BSSRDF ######
        if (config.enable_bssrdf and not ek.none(is_bssrdf)):
            # Get projected samples from BSSRDF
            # NOTE(review): this call is cut off after `channel,` — the
            # trailing arguments and closing parenthesis are missing. Also,
            # `bssrdf` defaults to None and is dereferenced here unguarded.
            projected_si, project_suc, abs_prob = bssrdf.sample_bssrdf(
                scene, bsdf, bs, si, bdata, heightmap_pybind, channel,

            # Optional debug path for the first bounce: lanes whose projection
            # failed are deactivated and marked in `invalid_sample`.
            if config.visualize_invalid_sample and (depth <= 1):
                active = active & (~is_bssrdf | project_suc)
                # NOTE(review): this ek.select is cut off after its second
                # argument — the fallback value and closing parentheses are
                # missing.
                invalid_sample += ek.select((is_bssrdf & (~project_suc)),
                                            Spectrum([100, 0, 0]),

            # Sample outgoing direction from projected position
            d_out_local, d_out_pdf = utils_render.resample_wo(
                sampler, is_bssrdf)
            # Apply absorption probability
            throughput *= ek.select(is_bssrdf,
                                    Spectrum(1) - abs_prob, Spectrum(1))
            # Replace interactions by sampled ones from BSSRDF
            # NOTE(review): this call is cut off after `projected_si,` — the
            # remaining argument(s) and closing parenthesis are missing.
            si_bsdf = SurfaceInteraction3f().masked_si(si_bsdf, projected_si,

        # Determine probability of having sampled that same
        # direction using emitter sampling
        emitter = si_bsdf.emitter(scene, active)
        ds = DirectionSample3f(si_bsdf, si)
        ds.object = emitter

        # Delta-sampled directions get an emitter pdf of zero.
        delta = has_flag(bs.sampled_type, BSDFFlags.Delta)
        emitter_pdf = ek.select(delta, Float(0.0),
                                scene.pdf_emitter_direction(si, ds))
        # Weight applied to emitter hits at the top of the next iteration.
        emission_weight = mis_weight(bs.pdf, emitter_pdf)

        # Advance to the next path vertex.
        si = si_bsdf

    return result, valid_rays, scatter, non_scatter, invalid_sample
def test02_sample_pol_local(variant_scalar_mono_polarized):
    """Reflect known polarization states off a silver conductor at 45 degrees.

    Everything happens in local BSDF coordinates. The Mueller matrix returned
    by ``bsdf.sample`` is rotated into the s/p basis of the reflection and
    then applied to unpolarized, linear (0, 90, +45, -45 degrees) and circular
    Stokes vectors; the results are checked against what a mirror-like
    specular reflection is expected to produce.
    """
    from mitsuba.core import (Frame3f, Transform4f, Spectrum,
                              UnpolarizedSpectrum, Vector3f)
    from mitsuba.core.xml import load_string
    from mitsuba.render import (BSDFContext, TransportMode,
                                SurfaceInteraction3f, fresnel_conductor)
    from mitsuba.render.mueller import stokes_basis, rotate_mueller_basis

    def spectrum_from_stokes(stokes):
        # Store the four Stokes components in the first column of a Spectrum.
        out = Spectrum(0.0)
        for row, component in enumerate(stokes):
            out[row, 0] = component
        return out

    def all_negative(entry):
        return ek.all(entry < UnpolarizedSpectrum(0.0))

    def all_positive(entry):
        return ek.all(entry > UnpolarizedSpectrum(0.0))

    # Silver conductor; IoR values of Ag at 635.816284nm.
    ior = 0.136125 + 4.010625j
    conductor = load_string("""<bsdf version='2.0.0' type='conductor'>
                              <spectrum name="eta" value="{}"/>
                              <spectrum name="k" value="{}"/>
                          </bsdf>""".format(ior.real, ior.imag))

    # Incident direction at 45 degrees from the surface normal.
    theta_i = 45 * ek.pi / 180
    wi = Vector3f([-ek.sin(theta_i), 0, ek.cos(theta_i)])
    normal = [0, 0, 1]

    ctx = BSDFContext()
    ctx.mode = TransportMode.Importance

    si = SurfaceInteraction3f()
    si.p = [0, 0, 0]
    si.wi = wi
    si.sh_frame = Frame3f(normal)

    bs, sampled_mueller = conductor.sample(ctx, si, 0.0, [0.0, 0.0])
    wo = bs.wo

    # Rotate into standard coordinates for specular reflection.
    incoming = -wi
    bi_s = ek.normalize(ek.cross(normal, incoming))
    bi_p = ek.normalize(ek.cross(incoming, bi_s))
    bo_s = ek.normalize(ek.cross(normal, wo))
    bo_p = ek.normalize(ek.cross(wo, bi_s))

    M_local = rotate_mueller_basis(
        sampled_mueller,
        incoming, stokes_basis(incoming), bi_p,
        wo, stokes_basis(wo), bo_p)

    def reflect(stokes):
        # Apply the rotated Mueller matrix to the given Stokes vector.
        return M_local @ spectrum_from_stokes(stokes)

    # Unpolarized light: the reflected intensity must equal plain Fresnel.
    unpolarized = reflect([1, 0, 0, 0])
    fresnel = fresnel_conductor(ek.cos(theta_i),
                                ek.scalar.Complex2f(ior.real, ior.imag))
    intensity = unpolarized[0, 0]
    assert ek.allclose(intensity[0], fresnel, atol=1e-3)

    # Horizontally polarized light (linear at 0 degrees) must stay
    # 1) fully polarized (though with reduced intensity) and
    # 2) horizontally polarized.
    horizontal = reflect([1, +1, 0, 0])
    assert ek.allclose(horizontal[0, 0], ek.abs(horizontal[1, 0]),
                       atol=1e-3)  # 1)
    horizontal /= horizontal[0, 0]
    assert ek.allclose(horizontal, spectrum_from_stokes([1, 1, 0, 0]),
                       atol=1e-3)  # 2)

    # Vertically polarized light (linear at 90 degrees) must stay
    # 1) fully polarized (though with reduced intensity) and
    # 2) vertically polarized.
    vertical = reflect([1, -1, 0, 0])
    assert ek.allclose(vertical[0, 0], ek.abs(vertical[1, 0]),
                       atol=1e-3)  # 1)
    vertical /= vertical[0, 0]
    assert ek.allclose(vertical, spectrum_from_stokes([1, -1, 0, 0]),
                       atol=1e-3)  # 2)

    # (Positive) diagonally polarized light, linear at +45 degrees:
    # 1) the polarization is flipped to -45 degrees
    # 2) there is now also some (left) circular polarization
    diag_pos = reflect([1, 0, +1, 0])
    assert all_negative(diag_pos[2, 0])
    assert all_negative(diag_pos[3, 0])

    # (Negative) diagonally polarized light, linear at -45 degrees:
    # 1) the polarization is flipped to +45 degrees
    # 2) there is now also some (right) circular polarization
    diag_neg = reflect([1, 0, -1, 0])
    assert all_positive(diag_neg[2, 0])
    assert all_positive(diag_neg[3, 0])

    # Right circularly polarized light is flipped to left circular.
    circ_right = reflect([1, 0, 0, +1])
    assert all_negative(circ_right[3, 0])

    # Left circularly polarized light is flipped to right circular.
    circ_left = reflect([1, 0, 0, -1])
    assert all_positive(circ_left[3, 0])