コード例 #1
0
    def test_complex2rgb(self):
        # Four unit-amplitude samples at phases 0, +120, -120 and 180 degrees,
        # and the RGB colours they are expected to map to.
        phasors = np.array(
            [[1, np.exp(2j * np.pi / 3)], [np.exp(-2j * np.pi / 3), -1]],
            dtype=np.complex128)
        expected_rgb = np.array(
            [[[0, 1, 1], [1, 0, 1]], [[1, 1, 0], [1, 0, 0]]],
            dtype=np.float64)

        # Unit amplitude yields the reference colours.
        npt.assert_almost_equal(utils.complex2rgb(phasors), expected_rgb)

        # Amplitudes well above one must saturate to the same colours.
        npt.assert_almost_equal(utils.complex2rgb(10.0 * phasors),
                                expected_rgb,
                                err_msg='Saturated values are not as expected.')

        # Below saturation, brightness must be proportional to amplitude.
        half_amplitude_rgb = utils.complex2rgb(0.5 * phasors)
        npt.assert_almost_equal(
            half_amplitude_rgb,
            0.5 * expected_rgb,
            err_msg='The intensity does not scale linearly with the amplitude.'
        )

        # Make one sample weaker than the others and check that only its
        # colour dims.
        phasors[1, 1] /= 2.0
        expected_rgb[1, 1, :] /= 2.0
        npt.assert_almost_equal(
            utils.complex2rgb(0.5 * phasors),
            0.5 * expected_rgb,
            err_msg='Non-uniform amplitudes are not represented correctly.')
コード例 #2
0
    def display(s):
        """Refresh the field-component panels for the solver state ``s``.

        Each of the ``s.E.shape[0]`` field components is converted with
        ``utils.complex2rgb`` into ``images[dim_idx]`` and titled with the
        iteration number and rms error, after which the figure is redrawn.

        The grain outlines are added to a panel only the first time that panel
        is refreshed. NOTE(review): this assumes ``add_circles_to_axes`` adds
        new artists on every call (confirm). If so, re-adding them at every
        refresh would stack ever more circles on the same axes, darkening the
        outlines and slowing each redraw.
        """
        log.info('Displaying iteration %d: error %0.1f%%' %
                 (s.iteration, 100 * s.residue))
        # Panel indices that already carry the grain outlines. Stored on the
        # function object so the set persists between calls.
        outlined_dims = display.__dict__.setdefault('outlined_dims', set())
        nb_dims = s.E.shape[0]
        for dim_idx in range(nb_dims):
            images[dim_idx].set_data(
                utils.complex2rgb(s.E[dim_idx], 1, inverted=True))
            figure_title = '$E_' + 'xyz'[
                dim_idx] + "$ it %d: rms error %0.1f%% " % (s.iteration,
                                                            100 * s.residue)
            if dim_idx not in outlined_dims:
                add_circles_to_axes(axs[dim_idx][0])
                outlined_dims.add(dim_idx)
            axs[dim_idx][0].set_title(figure_title)

        plt.draw()
        plt.pause(0.001)
コード例 #3
0
    def display(s):
        """Refresh the field panels and the Poynting-vector quiver for ``s``.

        Each field component ``s.E[dim_idx]`` is drawn into ``images[dim_idx]``
        and titled with the iteration number and rms error. The Poynting vector
        ``s.S`` is scaled so that its largest magnitude equals
        ``sample_pitch[0] * arrow_sep[0]``. It is then subsampled every
        ``arrow_sep`` pixels, multiplied by 1e6 and passed to
        ``quiver.set_UVC``, with ``S[1]`` as the horizontal arrow component
        and ``S[0]`` as the vertical one.
        """
        log.info("Displaying iteration %d: error %0.1f%%" %
                 (s.iteration, 100 * s.residue))
        nb_dims = s.E.shape[0]
        for dim_idx in range(nb_dims):
            images[dim_idx].set_data(utils.complex2rgb(s.E[dim_idx], 1))
            figure_title = '$E_' + 'xyz'[
                dim_idx] + "$ it %d: rms error %0.1f%% " % (s.iteration,
                                                            100 * s.residue)
            axs[dim_idx][0].set_title(figure_title)

        # Normalise into a new array rather than in place, so s.S itself is
        # left untouched (it may be state held by the solution object). A
        # field without energy flow is left unscaled instead of turning every
        # arrow into NaN through 0 / 0.
        S = s.S
        peak_magnitude = np.sqrt(np.max(np.sum(np.abs(S)**2, axis=0)))
        if peak_magnitude > 0:
            S = S / (peak_magnitude / (sample_pitch[0] * arrow_sep[0]))
        U = S[0, ...]
        V = S[1, ...]
        quiver.set_UVC(V[::arrow_sep[0], ::arrow_sep[1]] * 1e6,
                       U[::arrow_sep[0], ::arrow_sep[1]] * 1e6)

        plt.draw()
        plt.pause(0.001)
コード例 #4
0
ファイル: air_glass_air_2D.py プロジェクト: tgparton/MacroMax
    def display(s):
        """Refresh the field panels and the normalised |E| panel for ``s``.

        Each field component is drawn into ``images[dim_idx]`` on panel
        ``axs.flatten()[dim_idx]`` and titled with the iteration number and
        rms error. The last image shows ``np.linalg.norm(s.E, axis=0)``
        divided by its maximum, as a grey-scale RGB image titled 'I'.

        The plate rectangle is added to each panel only the first time that
        panel is refreshed. NOTE(review): this assumes
        ``add_rectangle_to_axes`` adds a new filled, translucent artist on
        every call (confirm). If so, re-adding it at each refresh would stack
        the rectangles until they hide the field.
        """
        log.info('Displaying iteration %d: error %0.1f%%' %
                 (s.iteration, 100 * s.residue))
        # Panels that already show the plate outline. Stored on the function
        # object so the set persists between calls.
        outlined_panels = display.__dict__.setdefault('outlined_panels', set())
        panels = axs.flatten()
        nb_dims = s.E.shape[0]
        for dim_idx in range(nb_dims):
            images[dim_idx].set_data(
                utils.complex2rgb(s.E[dim_idx], 1, inverted=True))
            figure_title = '$E_' + 'xyz'[
                dim_idx] + "$ it %d: rms error %0.1f%% " % (s.iteration,
                                                            100 * s.residue)
            if dim_idx not in outlined_panels:
                add_rectangle_to_axes(panels[dim_idx])
                outlined_panels.add(dim_idx)
            panels[dim_idx].set_title(figure_title)

        intensity = np.linalg.norm(s.E, axis=0)
        peak = np.max(intensity)
        if peak > 0:  # avoid 0 / 0 -> NaN for an all-zero field
            intensity /= peak
        # Same value in all three colour channels -> grey-scale image.
        intensity_rgb = np.repeat(intensity[:, :, np.newaxis], 3, axis=2)
        images[-1].set_data(intensity_rgb)
        if 'intensity' not in outlined_panels:
            add_rectangle_to_axes(panels[-1])
            outlined_panels.add('intensity')
        panels[3].set_title('I')

        plt.draw()
        plt.pause(0.001)
コード例 #5
0
def show_birefringence():
    """Simulate a diagonally polarised beam passing through a birefringent slab.

    A Gaussian beam, polarised along (x + z) / sqrt(2) and propagating along
    y, enters a slab filling the central 5/8 of the y-axis. The slab's
    permittivity tensor is diag(1.486, 1.658, 1.658)**2 with its principal
    axes rotated by 45 degrees about z. The calculation volume is enclosed by
    an absorbing boundary of thickness ``boundary_thickness``.

    While ``macromax.solve`` iterates, every 10th iteration shows:
      - the three field components,
      - the (0, 0) permittivity element,
      - the x-component of the source,
      - a quiver plot of the Poynting vector.
    The final solution is shown in a blocking window.
    """
    #
    # Medium settings
    #
    data_shape = np.array([128, 256]) * 2
    wavelength = 500e-9
    boundary_thickness = 3e-6
    beam_diameter = 2.5e-6
    k0 = 2 * np.pi / wavelength
    angular_frequency = const.c * k0
    source_amplitude = 1j * angular_frequency * const.mu_0
    sample_pitch = np.array([1, 1]) * wavelength / 8
    ranges = utils.calc_ranges(data_shape, sample_pitch)
    incident_angle = 0 * np.pi / 180

    def rot_Z(a):
        """Return the 3x3 matrix that rotates vectors by ``a`` radians about z."""
        return np.array([[np.cos(a), -np.sin(a), 0],
                         [np.sin(a), np.cos(a), 0],
                         [0, 0, 1]])

    incident_k = rot_Z(incident_angle) * k0 @ np.array([0, 1, 0])
    p_source = rot_Z(incident_angle) @ np.array([1, 0, 1]) / np.sqrt(
        2)  # diagonally polarized beam
    source = -source_amplitude * np.exp(
        1j * (incident_k[0] * ranges[0][:, np.newaxis] +
              incident_k[1] * ranges[1][np.newaxis, :]))
    # Aperture the incoming beam. Along y it becomes a Gaussian sheet (width
    # ~ one wavelength) placed boundary_thickness past the start of the axis.
    # Along x it gets a Gaussian profile centred a quarter of the way along.
    source = source * np.exp(
        -0.5 * (np.abs(ranges[1][np.newaxis, :] -
                       (ranges[1][0] + boundary_thickness)) / wavelength)**2)
    source = source * np.exp(-0.5 * (
        (ranges[0][:, np.newaxis] - ranges[0][int(len(ranges[0]) * 1 / 4)]) /
        (beam_diameter / 2))**2)
    source = p_source[:, np.newaxis, np.newaxis] * source[np.newaxis, ...]

    # Start from vacuum: the identity tensor at every grid point.
    permittivity = np.tile(
        np.eye(3, dtype=np.complex128)[:, :, np.newaxis, np.newaxis],
        (1, 1, *data_shape))
    # Add prism: the birefringent slab over the central 5/8 of the y-axis.
    epsilon_crystal = rot_Z(-np.pi / 4) @ np.diag(
        (1.486, 1.658, 1.658))**2 @ rot_Z(np.pi / 4)
    slab_width = int(data_shape[1] * 5 / 8)
    slab_start = int(data_shape[1] * (1 - 5 / 8) / 2)
    permittivity[:, :, :, slab_start:slab_start + slab_width] = \
        epsilon_crystal[:, :, np.newaxis, np.newaxis]

    # Add boundary: an imaginary permittivity that grows linearly with depth.
    dist_in_boundary = np.maximum(
        np.maximum(
            0.0, -(ranges[0][:, np.newaxis] -
                   (ranges[0][0] + boundary_thickness))) + np.maximum(
                       0.0, ranges[0][:, np.newaxis] -
                       (ranges[0][-1] - boundary_thickness)),
        np.maximum(
            0.0, -(ranges[1][np.newaxis, :] -
                   (ranges[1][0] + boundary_thickness))) + np.maximum(
                       0.0, ranges[1][np.newaxis, :] -
                       (ranges[1][-1] - boundary_thickness)))
    weight_boundary = dist_in_boundary / boundary_thickness
    for dim_idx in range(3):
        permittivity[dim_idx, dim_idx, :, :] += -1.0 + (
            1.0 + 0.2j * weight_boundary)  # boundary

    # Prepare the display
    _, axs = plt.subplots(3, 2, frameon=False, figsize=(12, 9))

    for ax in axs.ravel():
        # Raw strings: '\m' and '\c' are not valid Python escape sequences.
        ax.set_xlabel(r'y [$\mu$m]')
        ax.set_ylabel(r'x [$\mu$m]')
        ax.set_aspect('equal')

    # Image extent in micrometres: y horizontally, x vertically.
    extent_um = np.array([*ranges[1][[0, -1]], *ranges[0][[0, -1]]]) * 1e6
    images = [
        axs[dim_idx][0].imshow(
            utils.complex2rgb(np.zeros(data_shape), 1),
            extent=extent_um,
            origin='lower') for dim_idx in range(3)
    ]
    axs[0][1].imshow(
        utils.complex2rgb(permittivity[0, 0], 1),
        extent=extent_um,
        origin='lower')
    axs[2][1].imshow(
        utils.complex2rgb(source[0], 1),
        extent=extent_um,
        origin='lower')
    axs[0][1].set_title(r'$\chi$')
    axs[1][1].axis('off')
    axs[2][1].set_title('source and S')
    mesh_ranges = [0, 1]
    for dim_idx in range(len(ranges)):
        mesh_ranges[dim_idx] = utils.to_dim(ranges[dim_idx].flatten(),
                                            len(ranges),
                                            axis=dim_idx)
    X, Y = np.meshgrid(mesh_ranges[1], mesh_ranges[0])
    arrow_sep = np.array([1, 1], dtype=int) * 30
    quiver = axs[2][1].quiver(X[::arrow_sep[0], ::arrow_sep[1]] * 1e6,
                              Y[::arrow_sep[0], ::arrow_sep[1]] * 1e6,
                              X[::arrow_sep[0], ::arrow_sep[1]] * 0,
                              Y[::arrow_sep[0], ::arrow_sep[1]] * 0,
                              pivot='mid',
                              scale=1.0,
                              scale_units='x',
                              units='x',
                              color=np.array([1, 0, 1, 0.5]))

    for ax in axs.ravel():
        ax.autoscale(False, tight=True)

    #
    # Display the current solution
    #
    def display(s):
        """Show the field components of ``s`` and its normalised Poynting vector."""
        log.info("Displaying iteration %d: error %0.1f%%" %
                 (s.iteration, 100 * s.residue))
        nb_dims = s.E.shape[0]
        for dim_idx in range(nb_dims):
            images[dim_idx].set_data(utils.complex2rgb(s.E[dim_idx], 1))
            figure_title = '$E_' + 'xyz'[
                dim_idx] + "$ it %d: rms error %0.1f%% " % (s.iteration,
                                                            100 * s.residue)
            axs[dim_idx][0].set_title(figure_title)

        # Scale S so that the longest arrow spans one arrow spacing. Work on
        # a new array so s.S itself is not modified, and leave an all-zero S
        # unscaled instead of producing NaN arrows (0 / 0).
        S = s.S
        peak_magnitude = np.sqrt(np.max(np.sum(np.abs(S)**2, axis=0)))
        if peak_magnitude > 0:
            S = S / (peak_magnitude / (sample_pitch[0] * arrow_sep[0]))
        U = S[0, ...]
        V = S[1, ...]
        # y is plotted horizontally, so S_y is the horizontal arrow component.
        quiver.set_UVC(V[::arrow_sep[0], ::arrow_sep[1]] * 1e6,
                       U[::arrow_sep[0], ::arrow_sep[1]] * 1e6)

        plt.draw()
        plt.pause(0.001)

    #
    # Display the (intermediate) result
    #
    def update_function(s):
        """Log and display every 10th iteration; return whether to continue."""
        if np.mod(s.iteration, 10) == 0:
            log.info("Iteration %0.0f: rms error %0.1f%%" %
                     (s.iteration, 100 * s.residue))
            display(s)

        return s.residue > 1e-3 and s.iteration < 1e4

    # The actual work is done here:
    start_time = time.time()
    solution = macromax.solve(ranges,
                              vacuum_wavelength=wavelength,
                              source_distribution=source,
                              epsilon=permittivity,
                              callback=update_function)
    log.info("Calculation time: %0.3fs." % (time.time() - start_time))

    # Show final result
    log.info('Displaying final result.')
    display(solution)
    plt.show(block=True)
コード例 #6
0
ファイル: air_glass_air_2D.py プロジェクト: tgparton/MacroMax
def show_scatterer(vectorial=True):
    """Simulate a beam incident at 30 degrees on a glass plate (n = 1.5) in air.

    The plate occupies ``|x| < plate_thickness / 2`` across the whole y-axis.
    The source is a single grid row just inside the high-x boundary, with a
    Gaussian aperture along y and a phase ramp for a wave vector at 30 degrees
    to the x-axis. An absorbing boundary surrounds the calculation volume.
    The field components and the normalised |E| are displayed while
    ``macromax.solve`` iterates. Afterwards the figure is shown without
    blocking, and ``air_glass_air_2D_E?.png`` images plus
    ``air_glass_air_2D.svgz`` are written to the current working directory.

    :param vectorial: when True, the source is multiplied by a circular
        polarisation vector and all three field components are solved for;
        when False, a scalar source is used.
    :return: (times, residues). ``times[i]`` is the number of seconds between
        the start of ``macromax.solve`` and the i-th solver callback;
        ``residues[i]`` is the residue reported at that callback.
    """
    output_name = 'air_glass_air_2D'

    #
    # Medium settings
    #
    scale = 2
    data_shape = np.array([256, 256]) * scale
    wavelength = 500e-9
    medium_refractive_index = 1.0  # 1.4758, 2.7114
    boundary_thickness = 2e-6
    beam_diameter = 1.0e-6 * scale
    plate_thickness = 2.5e-6 * scale

    k0 = 2 * np.pi / wavelength
    angular_frequency = const.c * k0
    source_amplitude = 1j * angular_frequency * const.mu_0
    sample_pitch = np.array([1, 1]) * wavelength / 15
    ranges = utils.calc_ranges(data_shape, sample_pitch)
    incident_angle = 30 * np.pi / 180
    # incident_angle = np.arctan(1.5)  # Brewster's angle

    log.info('Calculating fields over a %0.1fum x %0.1fum area...' %
             tuple(data_shape * sample_pitch * 1e6))

    def rot_Z(a):
        """Return the 3x3 matrix that rotates vectors by ``a`` radians about z."""
        return np.array([[np.cos(a), -np.sin(a), 0],
                         [np.sin(a), np.cos(a), 0],
                         [0, 0, 1]])

    incident_k = rot_Z(incident_angle) * k0 @ np.array([1, 0, 0])
    p_source = rot_Z(incident_angle) @ np.array([0, 1, 1j]) / np.sqrt(2)
    source = -source_amplitude * np.exp(
        1j * (incident_k[0] * ranges[0][:, np.newaxis] +
              incident_k[1] * ranges[1][np.newaxis, :]))
    # Aperture the incoming beam
    # source = source * np.exp(-0.5*(np.abs(ranges[1][np.newaxis, :] - (ranges[1][0]+boundary_thickness))
    #                                * medium_refractive_index / wavelength)**2)  # source position
    # Keep only one source row, just inside the high-x boundary.
    source_pixel = data_shape[0] - int(boundary_thickness / sample_pitch[0])
    source[:source_pixel, :] = 0
    source[source_pixel + 1:, :] = 0
    # Gaussian aperture along y, centred a quarter of the way along the y-axis
    # (index the y-axis with its own length so non-square grids also work).
    source = source * np.exp(-0.5 * (
        (ranges[1][np.newaxis, :] - ranges[1][int(len(ranges[1]) * 1 / 4)]) /
        (beam_diameter / 2))**2)  # beam aperture
    source = source[np.newaxis, ...]
    if vectorial:
        source = source * p_source[:, np.newaxis, np.newaxis]

    # Define the glass plate: n = 1.5 for |x| < plate_thickness / 2, else 1.0.
    refractive_index = 1.0 + 0.5 * np.ones(len(ranges[1]))[np.newaxis, :] * (
        np.abs(ranges[0]) < plate_thickness / 2)[:, np.newaxis]
    permittivity = np.array(refractive_index**2, dtype=np.complex128)
    permittivity = permittivity[np.newaxis, np.newaxis, :, :]

    # Add boundary: an imaginary permittivity that grows linearly with depth.
    dist_in_boundary = np.maximum(
        np.maximum(
            0.0, -(ranges[0][:, np.newaxis] -
                   (ranges[0][0] + boundary_thickness))) + np.maximum(
                       0.0, ranges[0][:, np.newaxis] -
                       (ranges[0][-1] - boundary_thickness)),
        np.maximum(
            0.0, -(ranges[1][np.newaxis, :] -
                   (ranges[1][0] + boundary_thickness))) + np.maximum(
                       0.0, ranges[1][np.newaxis, :] -
                       (ranges[1][-1] - boundary_thickness)))
    weight_boundary = dist_in_boundary / boundary_thickness
    for dim_idx in range(permittivity.shape[0]):
        permittivity[dim_idx, dim_idx, :, :] += -1.0 + (
            1.0 + 0.5j * weight_boundary)  # boundary

    # Prepare the display
    def add_rectangle_to_axes(axes):
        """Add a new translucent rectangle marking the glass plate (in um)."""
        rectangle = plt.Rectangle(np.array(
            (ranges[1][0], -plate_thickness / 2)) * 1e6,
                                  (data_shape[1] * sample_pitch[1]) * 1e6,
                                  plate_thickness * 1e6,
                                  edgecolor=np.array((0, 1, 1, 0.25)),
                                  linewidth=1,
                                  fill=True,
                                  facecolor=np.array((0, 1, 1, 0.05)))
        axes.add_artist(rectangle)

    fig, axs = plt.subplots(2,
                            2,
                            frameon=False,
                            figsize=(12, 12),
                            sharex=True,
                            sharey=True)
    for ax in axs.ravel():
        # Raw strings: '\m' is not a valid Python escape sequence.
        ax.set_xlabel(r'y [$\mu$m]')
        ax.set_ylabel(r'x [$\mu$m]')
        ax.set_aspect('equal')

    images = [
        axs.flatten()[idx].imshow(utils.complex2rgb(np.zeros(data_shape),
                                                    1,
                                                    inverted=True),
                                  extent=utils.ranges2extent(*ranges) * 1e6)
        for idx in range(4)
    ]

    axs[0][1].set_title('$||E||^2$')

    # Display the medium without the boundaries
    for idx in range(4):
        axs.flatten()[idx].set_xlim(
            (ranges[1].flatten()[0] + boundary_thickness) * 1e6,
            (ranges[1].flatten()[-1] - boundary_thickness) * 1e6)
        axs.flatten()[idx].set_ylim(
            (ranges[0].flatten()[0] + boundary_thickness) * 1e6,
            (ranges[0].flatten()[-1] - boundary_thickness) * 1e6)
        axs.flatten()[idx].autoscale(False)

    # display() runs repeatedly, but each panel must get the plate rectangle
    # only once; otherwise the translucent fills stack up until they hide the
    # field.
    intensity_panel = axs.size - 1
    outlined_panels = set()

    def outline_plate_once(panel_idx):
        """Add the plate rectangle to panel ``panel_idx`` if it has none yet."""
        if panel_idx not in outlined_panels:
            add_rectangle_to_axes(axs.flatten()[panel_idx])
            outlined_panels.add(panel_idx)

    #
    # Display the current solution
    #
    def display(s):
        """Show each field component of ``s`` and its normalised |E|."""
        log.info('Displaying iteration %d: error %0.1f%%' %
                 (s.iteration, 100 * s.residue))
        panels = axs.flatten()
        nb_dims = s.E.shape[0]
        for dim_idx in range(nb_dims):
            images[dim_idx].set_data(
                utils.complex2rgb(s.E[dim_idx], 1, inverted=True))
            figure_title = '$E_' + 'xyz'[
                dim_idx] + "$ it %d: rms error %0.1f%% " % (s.iteration,
                                                            100 * s.residue)
            outline_plate_once(dim_idx)
            panels[dim_idx].set_title(figure_title)
        intensity = np.linalg.norm(s.E, axis=0)
        peak = np.max(intensity)
        if peak > 0:  # avoid 0 / 0 -> NaN for an all-zero field
            intensity /= peak
        # Same value in all three colour channels -> grey-scale image.
        intensity_rgb = np.repeat(intensity[:, :, np.newaxis], 3, axis=2)
        images[-1].set_data(intensity_rgb)
        outline_plate_once(intensity_panel)
        panels[intensity_panel].set_title('I')

        plt.draw()
        plt.pause(0.001)

    #
    # Display progress and the (intermediate) result
    #
    residues = []
    times = []

    def update_function(s):
        """Record progress, log/display periodically; return whether to continue."""
        # Log progress
        times.append(time.time())
        residues.append(s.residue)

        if np.mod(s.iteration, 10) == 0:
            log.info("Iteration %0.0f: rms error %0.3f%%" %
                     (s.iteration, 100 * s.residue))
        if np.mod(s.iteration, 10) == 1:
            display(s)

        return s.residue > 1e-4 and s.iteration < 1e4

    # The actual work is done here:
    start_time = time.time()
    solution = macromax.solve(ranges,
                              vacuum_wavelength=wavelength,
                              source_distribution=source,
                              epsilon=permittivity,
                              callback=update_function)
    log.info("Calculation time: %0.3fs." % (time.time() - start_time))

    # Convert the callback timestamps to seconds since the solver started.
    times = np.array(times) - start_time
    log.info("Calculation time: %0.3fs." % times[-1])

    # Show final result
    log.info('Displaying final result.')
    display(solution)
    plt.show(block=False)
    # Save the individual images
    log.info('Saving results to folder %s...' % os.getcwd())
    for dim_idx in range(solution.E.shape[0]):
        plt.imsave(output_name + '_E%s.png' % chr(ord('x') + dim_idx),
                   utils.complex2rgb(solution.E[dim_idx], 1, inverted=True),
                   vmin=0.0,
                   vmax=1.0,
                   cmap=None,
                   format='png',
                   origin=None,
                   dpi=600)
    # Save the figure
    plt.ioff()
    fig.savefig(output_name + '.svgz', bbox_inches='tight', format='svgz')
    plt.ion()

    return times, residues
コード例 #7
0
ファイル: air_glass_air_2D.py プロジェクト: tgparton/MacroMax
    return np.sum(np.conj(a) * b, axis=-1, keepdims=True)


if __name__ == "__main__":
    script_start = time.time()
    # Scalar variant: times, residues = show_scatterer(vectorial=False)
    times, residues = show_scatterer()
    total_seconds = time.time() - script_start
    log.info("Total time: %0.3fs." % total_seconds)

    # Left panel: convergence history. Right panel: the colour legend that
    # complex2rgb uses (phase vertically, amplitude horizontally).
    _, (ax_convergence, ax_legend) = plt.subplots(1,
                                                  2,
                                                  frameon=False,
                                                  figsize=(18, 9))
    ax_convergence.semilogy(times, residues)
    ax_convergence.scatter(times[::100], residues[::100])
    ax_convergence.set_xlabel('t [s]')
    ax_convergence.set_ylabel(r'$||\Delta E|| / ||E||$')

    phases = -(np.arange(256) / 256 * 2 * np.pi - np.pi)
    amplitudes = np.linspace(0, 1, 256)
    legend_field = amplitudes[np.newaxis, :] * np.exp(1j * phases[:, np.newaxis])
    ax_legend.imshow(utils.complex2rgb(legend_field, inverted=True),
                     extent=utils.ranges2extent(phases, amplitudes))

    plt.show(block=True)
コード例 #8
0
def show_scatterer(vectorial=True, anisotropic=True, scattering_layer=True):
    """Simulate a beam passing through a layer of randomly oriented grains.

    A Gaussian beam propagating along y goes through a layer of grains
    produced by ``generate_random_layer``: randomly oriented TiO2 particles,
    birefringent when ``anisotropic`` is True. An absorbing boundary encloses
    the calculation volume. The field components, grain orientation,
    permittivity and source are displayed while ``macromax.solve`` iterates.
    Afterwards:
      - the y-component of the Poynting vector computed from the x-averaged
        E and H fields is plotted and logged;
      - the figures are shown without blocking;
      - ``rutile_*.png`` images and ``rutile.svgz`` are written to the
        current working directory.

    :param vectorial: when True, the source is polarised as (x + i z) / sqrt(2)
        and all three field components are solved for; when False, a scalar
        source is used and ``anisotropic`` is forced to False.
    :param anisotropic: passed to ``generate_random_layer`` as
        ``birefringent``; when False only the (0, 0) permittivity element is
        kept.
    :param scattering_layer: passed on to ``generate_random_layer``.
    :return: (times, residues, forward_poynting_vector).
        - ``times``: seconds since the solver started, at each callback.
        - ``residues``: the residue reported at each callback.
        - ``forward_poynting_vector``: the Poynting-vector component described
          above, as a function of y.
    """
    if not vectorial:
        anisotropic = False

    #
    # Medium settings
    #
    scale = 2
    data_shape = np.array([128, 256]) * scale
    wavelength = 500e-9
    medium_refractive_index = 1.0  # 1.4758, 2.7114
    boundary_thickness = 2e-6
    beam_diameter = 1.0e-6 * scale
    layer_thickness = 2.5e-6 * scale

    k0 = 2 * np.pi / wavelength
    angular_frequency = const.c * k0
    source_amplitude = 1j * angular_frequency * const.mu_0
    sample_pitch = np.array([1, 1]) * wavelength / 15
    ranges = utils.calc_ranges(data_shape, sample_pitch)
    incident_angle = 0 * np.pi / 180

    log.info('Calculating fields over a %0.1fum x %0.1fum area...' %
             tuple(data_shape * sample_pitch * 1e6))

    def rot_Z(a):
        """Return the 3x3 matrix that rotates vectors by ``a`` radians about z."""
        return np.array([[np.cos(a), -np.sin(a), 0],
                         [np.sin(a), np.cos(a), 0],
                         [0, 0, 1]])

    incident_k = rot_Z(incident_angle) * k0 @ np.array([0, 1, 0])
    p_source = rot_Z(incident_angle) @ np.array([1, 0, 1j]) / np.sqrt(2)
    source = -source_amplitude * np.exp(
        1j * (incident_k[0] * ranges[0][:, np.newaxis] +
              incident_k[1] * ranges[1][np.newaxis, :]))
    # Aperture the incoming beam: a Gaussian sheet along y placed
    # boundary_thickness past the start of the axis, and a Gaussian profile
    # along x centred halfway.
    source = source * np.exp(
        -0.5 * (np.abs(ranges[1][np.newaxis, :] -
                       (ranges[1][0] + boundary_thickness)) *
                medium_refractive_index / wavelength)**2)  # source position
    source = source * np.exp(-0.5 * (
        (ranges[0][:, np.newaxis] - ranges[0][int(len(ranges[0]) * 2 / 4)]) /
        (beam_diameter / 2))**2)  # beam aperture
    source = source[np.newaxis, ...]
    if vectorial:
        source = source * p_source[:, np.newaxis, np.newaxis]

    # Place randomly oriented TiO2 particles
    permittivity, orientation, grain_pos, grain_rad, grain_dir = \
        generate_random_layer(data_shape, sample_pitch, layer_thickness=layer_thickness, grain_mean=1e-6,
                              grain_std=0.2e-6, normal_dim=1,
                              birefringent=anisotropic, medium_refractive_index=medium_refractive_index,
                              scattering_layer=scattering_layer)

    if not anisotropic:
        permittivity = permittivity[:1, :1, ...]
    log.info('Sample ready.')

    # Add boundary: an imaginary permittivity that grows linearly with depth.
    dist_in_boundary = np.maximum(
        np.maximum(
            0.0, -(ranges[0][:, np.newaxis] -
                   (ranges[0][0] + boundary_thickness))) + np.maximum(
                       0.0, ranges[0][:, np.newaxis] -
                       (ranges[0][-1] - boundary_thickness)),
        np.maximum(
            0.0, -(ranges[1][np.newaxis, :] -
                   (ranges[1][0] + boundary_thickness))) + np.maximum(
                       0.0, ranges[1][np.newaxis, :] -
                       (ranges[1][-1] - boundary_thickness)))
    weight_boundary = dist_in_boundary / boundary_thickness
    for dim_idx in range(permittivity.shape[0]):
        permittivity[dim_idx, dim_idx, :, :] += -1.0 + (
            1.0 + 0.5j * weight_boundary)  # boundary

    # Prepare the display
    def add_circles_to_axes(axes):
        """Add a new outline circle (in um) for every grain to ``axes``."""
        for r, pos in zip(grain_rad, grain_pos):
            circle = plt.Circle(pos[::-1] * 1e6,
                                r * 1e6,
                                edgecolor=np.array((1, 1, 1)) * 0.0,
                                facecolor=None,
                                alpha=0.25,
                                fill=False,
                                linewidth=1)
            axes.add_artist(circle)

    fig, axs = plt.subplots(3,
                            2,
                            frameon=False,
                            figsize=(12, 9),
                            sharex=True,
                            sharey=True)
    for ax in axs.ravel():
        # Raw strings: '\m' is not a valid Python escape sequence.
        ax.set_xlabel(r'y [$\mu$m]')
        ax.set_ylabel(r'x [$\mu$m]')
        ax.set_aspect('equal')

    images = [
        axs[dim_idx][0].imshow(utils.complex2rgb(np.zeros(data_shape),
                                                 1,
                                                 inverted=True),
                               extent=utils.ranges2extent(*ranges) * 1e6)
        for dim_idx in range(3)
    ]

    # Brightness: |epsilon_xx| - 1; hue: the grain orientation angle.
    # Computed once and reused for the on-screen panel and the saved PNG.
    epsilon_abs = np.abs(permittivity[0, 0]) - 1
    orientation_rgb = utils.complex2rgb(epsilon_abs * np.exp(1j * orientation),
                                        normalization=True,
                                        inverted=True)
    axs[0][1].imshow(orientation_rgb,
                     zorder=0,
                     extent=utils.ranges2extent(*ranges) * 1e6)
    add_circles_to_axes(axs[0][1])
    axs[1][1].imshow(utils.complex2rgb(permittivity[0, 0], 1, inverted=True),
                     extent=utils.ranges2extent(*ranges) * 1e6)
    axs[2][1].imshow(utils.complex2rgb(source[0], 1, inverted=True),
                     extent=utils.ranges2extent(*ranges) * 1e6)
    axs[0][1].set_title('crystal axis orientation')
    axs[1][1].set_title(r'$\chi$')
    axs[2][1].set_title('source')

    # Display the medium without the boundaries
    for ax in axs.ravel():
        ax.set_xlim((ranges[1].flatten()[0] + boundary_thickness) * 1e6,
                    (ranges[1].flatten()[-1] - boundary_thickness) * 1e6)
        ax.set_ylim((ranges[0].flatten()[0] + boundary_thickness) * 1e6,
                    (ranges[0].flatten()[-1] - boundary_thickness) * 1e6)
        ax.autoscale(False)

    # Field panels that already carry the grain outlines. display() runs
    # every 10 iterations; re-adding the circles each time would stack
    # thousands of artists, darkening the outlines and slowing every redraw.
    outlined_dims = set()

    #
    # Display the current solution
    #
    def display(s):
        """Show each field component of ``s`` in the left-hand column."""
        log.info('Displaying iteration %d: error %0.1f%%' %
                 (s.iteration, 100 * s.residue))
        nb_dims = s.E.shape[0]
        for dim_idx in range(nb_dims):
            images[dim_idx].set_data(
                utils.complex2rgb(s.E[dim_idx], 1, inverted=True))
            figure_title = '$E_' + 'xyz'[
                dim_idx] + "$ it %d: rms error %0.1f%% " % (s.iteration,
                                                            100 * s.residue)
            if dim_idx not in outlined_dims:
                add_circles_to_axes(axs[dim_idx][0])
                outlined_dims.add(dim_idx)
            axs[dim_idx][0].set_title(figure_title)

        plt.draw()
        plt.pause(0.001)

    #
    # Display progress and the (intermediate) result
    #
    residues = []
    times = []

    def update_function(s):
        """Record progress, log/display periodically; return whether to continue."""
        # Log progress
        times.append(time.time())
        residues.append(s.residue)

        if np.mod(s.iteration, 10) == 0:
            log.info("Iteration %0.0f: rms error %0.3f%%" %
                     (s.iteration, 100 * s.residue))
        if np.mod(s.iteration, 10) == 1:
            display(s)

        return s.residue > 1e-5 and s.iteration < 1e4

    # The actual work is done here:
    start_time = time.time()
    solution = macromax.solve(ranges,
                              vacuum_wavelength=wavelength,
                              source_distribution=source,
                              epsilon=permittivity,
                              callback=update_function)
    log.info("Calculation time: %0.3fs." % (time.time() - start_time))

    # Convert the callback timestamps to seconds since the solver started.
    times = np.array(times) - start_time
    log.info("Calculation time: %0.3fs." % times[-1])

    # Energy flow in the propagation direction, from the x-averaged fields.
    forward_E = np.mean(solution.E, axis=1)  # average over dimension x
    forward_H = np.mean(solution.H, axis=1)  # average over dimension x
    forward_poynting_vector = (0.5 / const.mu_0) * ParallelOperations.cross(
        forward_E, np.conj(forward_H)).real
    forward_poynting_vector = forward_poynting_vector[1, :]
    # Sample it midway between the layer and the far boundary.
    forward_poynting_vector_after_layer =\
        forward_poynting_vector[(ranges[1] > layer_thickness / 2) & (ranges[1] < ranges[1][-1] - boundary_thickness)]
    forward_poynting_vector_after_layer = forward_poynting_vector_after_layer[
        int(len(forward_poynting_vector_after_layer) / 2)]
    log.info('Forward Poynting vector: %g' %
             forward_poynting_vector_after_layer)
    fig_S = plt.figure(frameon=False, figsize=(12, 9))
    ax_S = fig_S.add_subplot(111)
    ax_S.plot(ranges[1] * 1e6, forward_poynting_vector)
    ax_S.set_xlabel(r'$z [\mu m]$')
    ax_S.set_ylabel(r'$S_z$')

    # Show final result
    log.info('Displaying final result.')
    display(solution)
    plt.show(block=False)
    # Save the individual images
    log.info('Saving results to folder %s...' % os.getcwd())
    plt.imsave('rutile_orientation.png',
               orientation_rgb,
               vmin=0.0,
               vmax=1.0,
               cmap=None,
               format='png',
               origin=None,
               dpi=600)
    for dim_idx in range(solution.E.shape[0]):
        plt.imsave('rutile_E%s.png' % chr(ord('x') + dim_idx),
                   utils.complex2rgb(solution.E[dim_idx], 1, inverted=True),
                   vmin=0.0,
                   vmax=1.0,
                   cmap=None,
                   format='png',
                   origin=None,
                   dpi=600)
    # Save the figure
    plt.ioff()
    fig.savefig('rutile.svgz', bbox_inches='tight', format='svgz')
    plt.ion()

    return times, residues, forward_poynting_vector