Exemple #1
0
 def setUp(self):
     """Create 2D and 3D fixture fields and ensure the 'test_data_io' output dir exists.

     For each dimension this builds a P1 Lagrange scalar field (projection of an
     expression that is 1.0 strictly inside the unit ball around (1, 1[, 1]) and
     0.0 elsewhere) and a P1 vector field (projection of a constant vector of ones).
     """
     # -- 2D fixtures on the rectangle [-2, 2] x [-2, 2]
     self.nx, self.ny = 40, 20
     rect_mesh = fenics.RectangleMesh(fenics.Point(-2, -2), fenics.Point(2, 2), self.nx, self.ny)
     vec_space_2d = fenics.VectorFunctionSpace(rect_mesh, "Lagrange", 1)
     scal_space_2d = fenics.FunctionSpace(rect_mesh, "Lagrange", 1)
     disc_indicator = fenics.Expression('sqrt(pow(x[0]-x0,2)+pow(x[1]-y0,2)) < 1 ? (1.0) : (0.0)',
                                        degree=1, x0=1, y0=1)
     ones_2d = fenics.Constant((1.0, 1.0))
     self.conc = fenics.project(disc_indicator, scal_space_2d)
     self.disp = fenics.project(ones_2d, vec_space_2d)
     # -- 3D fixtures on the box [-2, 2]^3 with 10 x 20 x 30 cells
     box_mesh = fenics.BoxMesh(fenics.Point(-2, -2, -2), fenics.Point(2, 2, 2), 10, 20, 30)
     vec_space_3d = fenics.VectorFunctionSpace(box_mesh, "Lagrange", 1)
     scal_space_3d = fenics.FunctionSpace(box_mesh, "Lagrange", 1)
     ball_indicator = fenics.Expression('sqrt(pow(x[0]-x0,2)+pow(x[1]-y0,2)+pow(x[2]-z0,2)) < 1 ? (1.0) : (0.0)',
                                        degree=1, x0=1, y0=1, z0=1)
     ones_3d = fenics.Constant((1.0, 1.0, 1.0))
     self.conc3 = fenics.project(ball_indicator, scal_space_3d)
     self.disp3 = fenics.project(ones_3d, vec_space_3d)
     # -- output location for this test case
     self.test_path = os.path.join(config.output_dir_testing, 'test_data_io')
     fu.ensure_dir_exists(self.test_path)
Exemple #2
0
def remove_mesh_subdomain(fenics_mesh,
                          subdomains,
                          lower_thr,
                          upper_thr,
                          temp_dir=config.output_dir_temp):
    """
    Build a new fenics mesh keeping only cells whose subdomain id is in the range lower_thr..upper_thr.

    The mesh makes a round trip fenics -> meshio -> VTK (threshold on the cell
    array 'ElementBlockIds') -> meshio -> fenics. The intermediate mesh is
    written to '<temp_dir>/mesh.vtu'. Whether the bounds are inclusive is
    decided by `vtu.threshold_vtk_data`.

    :param fenics_mesh: input fenics mesh
    :param subdomains: subdomain markers of `fenics_mesh`
    :param lower_thr: lower threshold on the subdomain id
    :param upper_thr: upper threshold on the subdomain id
    :param temp_dir: directory for the temporary vtu file
    :return: fenics mesh and subdomains
    """
    temp_vtu = os.path.join(temp_dir, 'mesh.vtu')
    fu.ensure_dir_exists(temp_vtu)

    # fenics mesh + subdomains -> meshio -> vtu file on disk
    full_mesh_mio = convert_fenics_mesh_to_meshio(fenics_mesh, subdomains=subdomains)
    mio.write(temp_vtu, full_mesh_mio)

    # vtu file -> vtk dataset, then drop cells outside the threshold range
    thresholded_vtk = vtu.threshold_vtk_data(vtu.read_vtk_data(temp_vtu),
                                             'cell',
                                             'ElementBlockIds',
                                             lower_thr=lower_thr,
                                             upper_thr=upper_thr)

    # vtk -> meshio -> fenics
    kept_mesh, kept_subdomains = convert_meshio_to_fenics_mesh(
        convert_vtk_mesh_to_meshio(thresholded_vtk))
    return kept_mesh, kept_subdomains
def register_ants_synquick(fixed_img,
                           moving_img,
                           output_prefix,
                           registration='s',
                           fixed_mask=None,
                           dim=3):
    """
    Run the ``antsRegistrationSyNQuick.sh`` script as a subprocess and wait for it.

    registration:
        - r -> rigid
        - a -> rigid, affine
        - s -> rigid, affine, syn

    :param fixed_img: fixed image path, passed as ``-f``
    :param moving_img: moving image path, passed as ``-m``
    :param output_prefix: output prefix, passed as ``-o``; its directory is
        passed to ``fu.ensure_dir_exists`` before running
    :param registration: transform type, passed as ``-t`` (see above)
    :param fixed_mask: optional mask, passed as ``-x``; omitted when falsy
    :param dim: image dimension, passed as ``-d``
    :return: return code of the ANTs process (0 on success)
    """
    # Flag order matches the original command line; -n/-j/-z values are fixed
    # (see the antsRegistrationSyNQuick.sh usage text for their meaning).
    ants_params = [
        ('d', dim),
        ('f', fixed_img),
        ('m', moving_img),
        ('t', registration),
        ('o', output_prefix),
        ('n', 4),
        ('j', 1),
        ('z', 0),
    ]
    if fixed_mask:
        ants_params.append(('x', fixed_mask))

    # Build argv directly instead of joining into a string and re-splitting it:
    # the old join/shlex.split round trip broke paths containing spaces/quotes.
    args = ['antsRegistrationSyNQuick.sh']
    for flag, value in ants_params:
        args.extend(['-' + flag, str(value)])

    print("ANTS command SYNquick: %s" % ' '.join(shlex.quote(arg) for arg in args))
    fu.ensure_dir_exists(os.path.dirname(output_prefix))
    process = subprocess.Popen(args, env=os.environ.copy())
    process.wait()
    print("ANTS terminated with return code: '%s'" % process.returncode)
    # Returning the code makes this consistent with `register_ants`.
    return process.returncode
 def test_load_from_hdf5(self):
     """Round trip: write two 3-step time series to HDF5, read them back and compare."""
     series_names = ('solution', 'solution2')
     recording_steps = (1, 2, 3)
     path_to_file = os.path.join(config.output_dir_testing, 'timeseries_to_hdf5_for_reading.h5')
     fu.ensure_dir_exists(path_to_file)
     # -- write: each series stores the same field self.U at three recording steps
     writer = TimeSeriesMultiData()
     for series_name in series_names:
         writer.register_time_series(name=series_name, functionspace=self.functionspace)
         for step in recording_steps:
             writer.add_observation(series_name, field=self.U, time=1, time_step=1, recording_step=step)
     writer.save_to_hdf5(path_to_file, replace=True)
     # -- read back into a fresh container that knows the same series names
     reader = TimeSeriesMultiData()
     for series_name in series_names:
         reader.register_time_series(name=series_name, functionspace=self.functionspace)
     reader.load_from_hdf5(path_to_file)
     self.assertEqual(len(reader.get_all_time_series()), 2)
     for series_name in series_names:
         self.assertEqual(len(reader.get_time_series(series_name).get_all_recording_steps()), 3)
     # -- the reloaded 'solution' field must match the field that was written
     reloaded_values = reader.get_solution_function(name='solution').vector().get_local()
     expected_values = self.U.vector().get_local()
     self.assertTrue(np.allclose(reloaded_values, expected_values))
Exemple #5
0
def save_function_mesh(function,
                       path_to_hdf5_function,
                       labelfunction=None,
                       subdomains=None):
    """
    Save a fenics function to HDF5 together with its mesh in a second HDF5 file.

    The mesh file name is derived from the function file by replacing the
    trailing '.h5' with '_mesh.h5' (e.g. 'u.h5' -> 'u_mesh.h5').

    :param function: fenics function to save
    :param path_to_hdf5_function: output path of the function file; must end with '.h5'
    :param labelfunction: optional label function; if given, subdomains are built
        from it with `SubDomains` and stored with the mesh (takes precedence
        over `subdomains`)
    :param subdomains: optional subdomain markers stored with the mesh
    :raises ValueError: if `path_to_hdf5_function` does not end with '.h5'
    """
    if not path_to_hdf5_function.endswith('.h5'):
        # Fail early: continuing would leave the mesh path undefined and
        # crash further down with an unhelpful UnboundLocalError.
        raise ValueError("Provide path to '.h5' file, got '%s'" %
                         path_to_hdf5_function)
    path_to_hdf5_mesh = path_to_hdf5_function[:-len('.h5')] + '_mesh.h5'
    mesh = function.function_space().mesh()
    fu.ensure_dir_exists(path_to_hdf5_mesh)
    if labelfunction is not None:
        from glimslib.simulation_helpers import SubDomains
        # create subdomains from the label function
        subdomain_helper = SubDomains(mesh)
        subdomain_helper.setup_subdomains(label_function=labelfunction)
        # save mesh as hdf5
        save_mesh_hdf5(mesh,
                       path_to_hdf5_mesh,
                       subdomains=subdomain_helper.subdomains)
    elif subdomains is not None:
        save_mesh_hdf5(mesh, path_to_hdf5_mesh, subdomains=subdomains)
    else:
        save_mesh_hdf5(mesh, path_to_hdf5_mesh)
    # save function
    save_functions_hdf5({"function": function},
                        path_to_hdf5_function,
                        time_step=None)
Exemple #6
0
def write_vtk_data(_data, _path_to_file):
    """
    Write a VTK dataset to `_path_to_file` with a ``vtkXMLDataSetWriter``.

    ``fu.ensure_dir_exists`` is called on the target path first. VTK <= 5 uses
    the legacy ``SetInput`` call; newer versions use ``SetInputData``.
    """
    fu.ensure_dir_exists(_path_to_file)
    xml_writer = vtk.vtkXMLDataSetWriter()
    uses_legacy_pipeline = vtk.VTK_MAJOR_VERSION <= 5
    # only the attribute for the active VTK API is looked up
    attach_input = xml_writer.SetInput if uses_legacy_pipeline else xml_writer.SetInputData
    attach_input(_data)
    xml_writer.SetFileName(_path_to_file)
    xml_writer.Update()
    xml_writer.Write()
Exemple #7
0
 def create_path(self, path_pattern_list=None, abs_path=True, create=True, with_ext=True, **kwargs):
     """
     Build a file path from BIDS entities using `self.bids_layout.build_path`.

     :param path_pattern_list: optional path patterns forwarded to `build_path`;
         when falsy, `build_path` is called with the entities only
     :param abs_path: if True, join the path onto `self.data_root`
     :param create: if True, pass the path's parent directory to `fu.ensure_dir_exists`
     :param with_ext: if False, strip the final extension of the file name
         (only the last one, e.g. '.gz' of '.nii.gz')
     :param kwargs: BIDS entities forwarded to `build_path`
     :return: the constructed path
     """
     if path_pattern_list:
         path = self.bids_layout.build_path(kwargs, path_pattern_list)
     else:
         path = self.bids_layout.build_path(kwargs)
     if abs_path:
         path = os.path.join(self.data_root, path)
     if create:
         fu.ensure_dir_exists(os.path.dirname(path))
     if not with_ext:
         # splitext only inspects the last path component: dots in directory
         # names are ignored and an extension-less path is returned unchanged
         # (splitting on '.' returned '' or a truncated directory in those cases).
         path = os.path.splitext(path)[0]
     return path
 def test_save_to_hdf5(self):
     """Write two time series ('solution', 'solution2') with three recording steps each to HDF5."""
     path_to_file = os.path.join(config.output_dir_testing, 'timeseries_to_hdf5.h5')
     fu.ensure_dir_exists(path_to_file)
     multi_data = TimeSeriesMultiData()
     for series_name in ['solution', 'solution2']:
         multi_data.register_time_series(name=series_name, functionspace=self.functionspace)
         # same field at recording steps 1, 2, 3
         for step in range(1, 4):
             multi_data.add_observation(series_name, field=self.U, time=1, time_step=1, recording_step=step)
     multi_data.save_to_hdf5(path_to_file, replace=True)
Exemple #9
0
 def __init__(self, data_root, path_to_bids_config=None):
     """
     :param data_root: directory to which all other dirs are relative; passed to
         `fu.ensure_dir_exists` before the BIDS layout is initialised
     :param path_to_bids_config: JSON file holding the path patterns; when not
         given (or falsy), `config.path_to_bids_config` is used
     """
     self.path_to_bids_config = path_to_bids_config or config.path_to_bids_config
     # -- path patterns are read from the JSON config
     with open(self.path_to_bids_config) as json_file:
         self.bids_config = json.load(json_file)
     self.data_root = data_root
     # -- set up the BIDS layout used for reading
     fu.ensure_dir_exists(data_root)
     self.init_bids_layout()
Exemple #10
0
def merge_vtus_timestep(base_path,
                        timestep,
                        remove=False,
                        reference_file_path=None):
    """
    Merge point-data arrays from several per-field vtu files into one vtu file.

    Simulation outputs such as 'concentration' and 'displacement' are stored in
    separate vtu files per time step. This function copies each available
    field array onto the mesh read from the reference (label map) file and
    writes the result as a single file per time step.

    Inputs are read from <base_path>/<field>/<create_file_name(field, timestep)>
    for the fields 'concentration', 'proliferation', 'growth', 'displacement'.
    Output goes to <base_path>/merged/<create_file_name("all", timestep)>.

    :param base_path: path to directory where simulation results are stored
    :param timestep: current time step (formatted with %d)
    :param remove: if True, each per-field file is removed after its array was copied
    :param reference_file_path: path to file that includes labelmap; defaults to
        <base_path>/label_map/<create_file_name("label_map", 0)>
    """
    print("-- Creating joint vtu for timestep %d" % timestep)
    if reference_file_path is None:
        reference_file_path = os.path.join(base_path, "label_map",
                                           create_file_name("label_map", 0))
    if not os.path.exists(reference_file_path):
        print("   - Could not find reference file '%s'... skipping" %
              (reference_file_path))
        return

    merged_mesh = mio.read(reference_file_path)
    for field_name in ('concentration', 'proliferation', 'growth', 'displacement'):
        field_vtu = os.path.join(base_path, field_name,
                                 create_file_name(field_name, timestep))
        if not os.path.exists(field_vtu):
            print("   - File '%s' not found" % (field_vtu))
            continue
        field_point_data = mio.read(field_vtu).point_data
        if field_name not in field_point_data.keys():
            # file exists but lacks the expected array: leave it in place
            continue
        merged_mesh.point_data[field_name] = field_point_data[field_name]
        if remove:
            remove_vtu(field_vtu)

    # -- write the joint vtu
    merged_vtu = os.path.join(base_path, 'merged',
                              create_file_name("all", timestep))
    print("   - Saving joint file to '%s'" % (merged_vtu))
    fu.ensure_dir_exists(merged_vtu)
    mio.write(merged_vtu, merged_mesh)
def ants_apply_transforms(input_img,
                          reference_img,
                          output_file,
                          transforms,
                          interpolation='Linear',
                          dim=3):
    """
    Apply transforms to an image by running an external ANTs command.

    The command line is produced by `build_ants_apply_transforms_command`
    (presumably an `antsApplyTransforms` call -- see that helper). It runs as a
    subprocess and is waited for. The directory of `output_file` is passed to
    `fu.ensure_dir_exists` beforehand.

    :param input_img: path to the image to be transformed
    :param reference_img: reference image, forwarded to the command builder
    :param output_file: path of the resulting image
    :param transforms: transform specification, forwarded to the command builder
    :param interpolation: interpolation mode, forwarded to the command builder
    :param dim: image dimension, forwarded to the command builder
    :return: return code of the ANTs process (0 on success), as in `register_ants`
    """
    print("  - Starting ANTS Apply Transforms:")
    print("    - INPUT IMG      : %s" % input_img)
    print("    - REFERENCE IMG  : %s" % reference_img)
    print("    - OUTPUT         : %s" % output_file)
    ants_cmd = build_ants_apply_transforms_command(input_img, reference_img,
                                                   output_file, transforms,
                                                   interpolation, dim)
    print("ANTS command: %s" % ants_cmd)
    fu.ensure_dir_exists(os.path.dirname(output_file))
    args = shlex.split(ants_cmd)
    process = subprocess.Popen(args, env=os.environ.copy())
    process.wait()
    print("ANTS apply transforms terminated with return code: '%s'" %
          process.returncode)
    # Previously the code was computed but discarded; returning it lets callers
    # detect failures, consistent with `register_ants`.
    return process.returncode
def register_ants(fixed_img,
                  moving_img,
                  output_prefix,
                  path_to_transform=None,
                  registration_type='Rigid',
                  image_ext='mha',
                  fixed_mask=None,
                  moving_mask=None,
                  verbose=0,
                  dim=3):
    """
    Run an ANTs registration command and optionally move the resulting transform file.

    The command line is produced by `build_ants_registration_command` and run as
    a subprocess. On success (return code 0), if `path_to_transform` is given,
    the transform file ANTs wrote next to `output_prefix` is moved there:
        - 'Rigid' / 'Affine' -> <output_prefix>0GenericAffine.mat
        - 'Syn'              -> <output_prefix>1Warp.nii.gz
    For any other `registration_type`, a message is printed and nothing is moved.

    :param fixed_img: fixed image, forwarded to the command builder
    :param moving_img: moving image, forwarded to the command builder
    :param output_prefix: ANTs output prefix; its directory is passed to `fu.ensure_dir_exists`
    :param path_to_transform: optional destination of the transform file
    :param registration_type: registration type, forwarded to the command builder
    :param image_ext: forwarded to the command builder
    :param fixed_mask: forwarded to the command builder
    :param moving_mask: forwarded to the command builder
    :param verbose: forwarded to the command builder
    :param dim: image dimension, forwarded to the command builder
    :return: return code of the ANTs process (0 on success)
    """
    print("  - Starting ANTS registration:")
    print("    - FIXED IMG : %s" % fixed_img)
    print("    - MOVING IMG: %s" % moving_img)
    print("    - OUTPUT    : %s" % output_prefix)
    ants_cmd = build_ants_registration_command(fixed_img,
                                               moving_img,
                                               output_prefix,
                                               registration_type,
                                               image_ext,
                                               fixed_mask,
                                               moving_mask,
                                               verbose,
                                               dim=dim)
    print("ANTS command: %s" % ants_cmd)
    fu.ensure_dir_exists(os.path.dirname(output_prefix))
    args = shlex.split(ants_cmd)
    process = subprocess.Popen(args)
    process.wait()
    if process.returncode == 0 and path_to_transform is not None:
        #-- rename trafo file: suffix that ANTs appends to `output_prefix`
        transform_suffix = {
            'Rigid': '0GenericAffine.mat',
            'Affine': '0GenericAffine.mat',
            'Syn': '1Warp.nii.gz',
        }.get(registration_type)
        if transform_suffix is None:
            # previously skipped silently; make the missing output visible
            print("  - No known transform file for registration type '%s'; "
                  "'%s' was not written" % (registration_type, path_to_transform))
        else:
            shutil.move(output_prefix + transform_suffix, path_to_transform)
    print("Registration terminated with return code: '%s'" %
          process.returncode)

    return process.returncode
# -- Configure the model parameters of the simulation object `sim`.
#    `sim`, `ivs`, `diffusion`, `coupling`, `prolif`, `youngmod`, `poisson`,
#    `sim_time` and `sim_time_step` are defined earlier in this script
#    (not visible in this excerpt).
sim.setup_model_parameters(iv_expression=ivs,
                           diffusion=diffusion,
                           coupling=coupling,
                           proliferation=prolif,
                           E=youngmod,
                           poisson=poisson,
                           sim_time=sim_time,
                           sim_time_step=sim_time_step)

# ==============================================================================
# Run Simulation
# ==============================================================================
# Results of the forward run go to <test_config.output_path>/.../forward
output_path = os.path.join(
    test_config.output_path,
    'test_case_simulation_tumor_growth_2D_subdomains_adjoint', 'forward')
fu.ensure_dir_exists(output_path)

# Parameter values for the forward ("target") run.
D_target = 0.3
rho_target = 0.1
coupling_target = 0.2

# Forward run with the target parameters, passed in the order
# [D, rho, coupling]; presumably `u_target` serves as the reference solution
# for the optimisation below -- confirm against run_for_adjoint.
u_target = sim.run_for_adjoint([D_target, rho_target, coupling_target],
                               output_dir=output_path)

# ==============================================================================
# OPTIMISATION
# ==============================================================================
# Separate output directory for the adjoint/optimisation stage.
output_path = os.path.join(
    test_config.output_path,
    'test_case_simulation_tumor_growth_2D_subdomains_adjoint', 'adjoint')
fu.ensure_dir_exists(output_path)
Exemple #14
0
def plot_plt(field_var,
             showmesh=True,
             shading='flat',
             contours=None,
             colormap=plt.cm.jet,
             norm=None,
             cmap_ref=None,
             range_f=None,
             alpha_f=1,
             title=None,
             path=None,
             show=True,
             dpi=300):
    """
    Plot a fenics function with matplotlib.

    On 2D meshes, a scalar field is evaluated at the mesh vertices and drawn
    with `tripcolor` on the mesh triangulation, followed by a colorbar and
    optional contour lines. Vector-valued fields on 2D meshes, and all fields
    on 3D meshes, are delegated to `plot(field_var)`.

    :param field_var: fenics function to plot
    :param showmesh: if True, draw cell edges (black, linewidth 0.1) over the scalar field
    :param shading: `tripcolor` shading mode
    :param contours: int -> number of contour levels; sequence -> explicit
        levels; None or empty -> no contours
    :param colormap: matplotlib colormap, colormap name, or None ('gist_earth')
    :param norm: matplotlib norm; if None and `cmap_ref` is given, a
        `vh.MidpointNormalize` centred on `cmap_ref` is used
    :param cmap_ref: midpoint for the default norm
    :param range_f: (min, max) color limits; defaults to the data range (NaNs ignored)
    :param alpha_f: opacity of the scalar plot
    :param title: optional figure title
    :param path: if given, the figure is saved there
    :param show: if True, display via `vh.show_plot()`
    :param dpi: resolution used when saving
    """
    if contours is None:
        contours = []
    mesh = field_var.function_space().mesh()
    n_vertices = mesh.num_vertices()
    dim = mesh.geometry().dim()
    # FEniCS reports 0 sub-spaces for a plain scalar function space, so both 0
    # and 1 are treated as scalar (checking only for 1 left scalar fields
    # unplotted and made plt.colorbar() fail for lack of a mappable).
    is_scalar = field_var.function_space().num_sub_spaces() in (0, 1)

    if dim == 2 and is_scalar:
        # Create triangulation
        mesh_coordinates = mesh.coordinates().reshape((n_vertices, dim))
        triangles = np.asarray([cell.entities(0) for cell in cells(mesh)])
        triangulation = tri.Triangulation(mesh_coordinates[:, 0],
                                          mesh_coordinates[:, 1], triangles)
        plt.figure()
        parameters['allow_extrapolation'] = True
        bbox_tree = mesh.bounding_box_tree()
        # evaluate at the vertices; points not inside the mesh become NaN
        z = np.asarray([
            field_var(point) if bbox_tree.collides_entity(Point(point))
            else np.nan
            for point in mesh_coordinates
        ])
        if range_f is None:
            min_f, max_f = np.nanmin(z), np.nanmax(z)
        else:
            min_f, max_f = range_f[0], range_f[1]

        # plt.get_cmap works across matplotlib versions (plt.cm.get_cmap was removed in 3.9)
        if isinstance(colormap, str):
            cmap = plt.get_cmap(colormap)
        elif colormap is None:
            cmap = plt.get_cmap('gist_earth')
        else:
            cmap = colormap

        if norm is None and cmap_ref is not None:
            norm = vh.MidpointNormalize(midpoint=cmap_ref,
                                        vmin=min_f,
                                        vmax=max_f)

        tripcolor_kwargs = dict(cmap=cmap, norm=norm, shading=shading, alpha=alpha_f)
        if showmesh:
            tripcolor_kwargs.update(edgecolors='k', linewidth=0.1)
        # vmin/vmax are applied via set_clim: newer matplotlib rejects passing
        # them together with a norm instance.
        mappable = plt.tripcolor(triangulation, z, **tripcolor_kwargs)
        mappable.set_clim(min_f, max_f)

        plt.colorbar(mappable)
        if isinstance(contours, int):
            plt.tricontour(triangulation, z, contours, colors='k')
        elif len(contours) > 0:
            plt.tricontour(triangulation, z, levels=contours, colors='k')
    elif dim == 2:
        # vector-valued 2D field: delegate, as for 3D (previously `plot()` was
        # called without arguments, which always raised a TypeError)
        plt.figure()
        plot(field_var)
    elif dim == 3:
        plot(field_var)

    if title is not None:
        plt.suptitle(title, y=1.08)
    if path is not None:
        fu.ensure_dir_exists(os.path.dirname(path))
        plt.savefig(path, dpi=dpi, bbox_inches='tight')
        print("-- saved figure to '%s'" % path)
    if show:
        vh.show_plot()
Exemple #15
0
def plot(plot_object_list,
         dpi=100,
         plot_range=None,
         margin=0.02,
         cbarwidth=0.05,
         save_path=None,
         show=True,
         xlabel='x position [mm]',
         ylabel='y position [mm]',
         show_axes=True,
         show_cbar=True,
         show_title=True,
         show_ticks=True,
         cbar_size='5%',
         cbar_pad=0.05,
         cbar_fontsize=None,
         **kwargs):
    """
    Each element in `plot_object_list` is a dictionary of the form::

        { 'object' : the object to be plotted,
          'param1' : one plot specific parameter,
          'param2' : another plot specific parameter,
          ...
        }

    See :py:meth:`show_img_seg_f()` for examples.
    Unless a 'zorder' argument is specified, elements are plotted in the order of occurrence in `plot_object_list`.

    A non-list argument is treated as a single object; the remaining `kwargs`
    are used as its parameters. For list input, `kwargs` act as defaults that
    each element's own entries override. Supported objects are fenics
    `Function` instances and SimpleITK `Image` instances. An element with
    'cbar_label' (and no truthy 'color') gets a colorbar when `show_cbar` is True.
    The keyword 'title' sets the figure title when `show_title` is True.

    :param margin: unused, kept for backward compatibility
    :param cbarwidth: unused, kept for backward compatibility
    :raises TypeError: if an object is neither a fenics Function nor a SimpleITK Image
    """
    # -- create Figure
    fig, ax = plt.subplots(dpi=dpi)
    ax.set_aspect('equal')

    # -- if an image is provided, the axes will be oriented as in the image,
    #   otherwise use xlim, ylim for fixing display orientation
    # -- for discussion about imshow/extent https://matplotlib.org/tutorials/intermediate/imshow_extent.html
    if plot_range:
        ax.set_xlim(*plot_range[0:2])
        ax.set_ylim(*plot_range[2:4])

    title = kwargs.pop('title', None)
    if title is not None and show_title:
        fig.suptitle(title)

    #-- a single plot_object (not a list) is wrapped into a one-element list of dicts
    if not isinstance(plot_object_list, list):
        single_spec = {'object': plot_object_list}
        single_spec.update(kwargs)
        plot_object_list = [single_spec]

    # -- Iterate through plot objects
    cbar_ax_list = [ax]
    for plot_object_spec in plot_object_list:
        object_params = plot_object_spec.copy()
        plot_object = object_params.pop('object')
        color = object_params.get('color', False)
        cbar = False
        if 'cbar_label' in object_params and not color:
            cbar_label = object_params.pop('cbar_label')
            cbar = True

        #-- global kwargs are the defaults, overwritten by settings specific to this plot_object
        params = kwargs.copy()
        params.update(object_params)
        if isinstance(plot_object, Function):
            mappable = plot_fenics_function(ax, plot_object, **params)
        elif isinstance(plot_object, sitk.Image):
            mappable = plot_sitk_image(ax, plot_object, **params)
        else:
            raise TypeError("The plot_object is of type '%s' -- not supported." %
                            (type(plot_object)))

        if cbar and show_cbar:
            cbax = add_colorbar(fig,
                                cbar_ax_list[0],
                                mappable,
                                cbar_label,
                                size=cbar_size,
                                pad=cbar_pad,
                                fontsize=cbar_fontsize)
            cbar_ax_list.append(cbax)

    if not show_axes:
        for side in ('right', 'top', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

    if show_ticks:
        if xlabel:
            ax.set_xlabel(xlabel)
        if ylabel:
            ax.set_ylabel(ylabel)
    else:
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        # booleans, not "off": a non-empty string is truthy and would enable the ticks
        ax.tick_params(bottom=False, left=False)

    if save_path:
        fu.ensure_dir_exists(os.path.dirname(save_path))
        fig.savefig(save_path, bbox_inches='tight', dpi=dpi)
        print("  - Saved figure to '%s'" % save_path)

    if show:
        vh.show_plot()
                               dirichlet_bcs=dirichlet_bcs,
                               von_neumann_bcs=von_neuman_bcs)

# -- Configure the TumorGrowth simulation `sim_TG` (constructed earlier in the
#    script, not visible in this excerpt) with the shared model parameters.
sim_TG.setup_model_parameters(iv_expression=ivs,
                              diffusion=diffusion,
                              coupling=coupling,
                              proliferation=prolif,
                              E=youngmod,
                              poisson=poisson,
                              sim_time=sim_time,
                              sim_time_step=sim_time_step)

# Output directory for the TumorGrowth run.
output_path_TG = os.path.join(test_config.output_path,
                              'test_case_comparison_2D_atlas_2',
                              'simulation_tumor_growth')
fu.ensure_dir_exists(output_path_TG)
# The forward run of sim_TG is currently disabled:
#sim_TG.run(save_method='xdmf',plot=False, output_dir=output_path_TG, clear_all=True)

# ==============================================================================
# TumorGrowthBrain
# ==============================================================================

# -- Same mesh, labels, boundaries and BCs as above, but with the
#    TumorGrowthBrain model and tissue-specific parameters (nu_*, D_*, rho_*).
sim_TGB = TumorGrowthBrain(mesh)

sim_TGB.setup_global_parameters(label_function=labelfunction,
                                domain_names=tissue_id_name_map,
                                boundaries=boundary_dict,
                                dirichlet_bcs=dirichlet_bcs,
                                von_neumann_bcs=von_neuman_bcs)

sim_TGB.setup_model_parameters(iv_expression=ivs,
                           nu_WM=nu_WM,
                           nu_CSF=nu_CSF,
                           nu_VENT=nu_VENT,
                           D_GM=D_GM,
                           D_WM=D_WM,
                           rho_GM=rho_GM,
                           rho_WM=rho_WM,
                           coupling=coupling)
# ==============================================================================
# Run Simulation
# ==============================================================================
# Output directory for the forward run.
output_path_forward = os.path.join(
    test_config.output_path,
    'test_case_simulation_tumor_growth_brain_2D_atlas_reduced_domain_adjoint_mpi_separated_functional',
    'forward')
fu.ensure_dir_exists(output_path_forward)

# Parameter values for the forward ("target") run.
D_GM_target = 0.02
D_WM_target = 0.05
rho_GM_target = 0.1
rho_WM_target = 0.1
coupling_target = 0.15

# Parameter order passed to run_for_adjoint: [D_WM, D_GM, rho_WM, rho_GM, coupling].
params_target = [
    D_WM_target, D_GM_target, rho_WM_target, rho_GM_target, coupling_target
]
# NOTE(review): the calls below use `sim`, not `sim_TGB` created above; `sim`
# is not defined in the visible part -- confirm which simulation is intended.
u_target = sim.run_for_adjoint(params_target, output_dir=output_path_forward)
sim.run(save_method='xdmf',
        plot=False,
        output_dir=output_path_forward,
        clear_all=True)
Exemple #18
0
"""Provides configuration settings for visualisation module.

It handles:
- Output option, either 'plt.plot()' for non-interactive backends or storage to local temp output dir.
"""

import os
import matplotlib as mpl

from glimslib import config
import glimslib.utils.file_utils as fu


#=== IDENTIFY TYPE OF PLOTTING BACKEND
#-- IMPORTANT if using PyCharm:
#   This only works if SciView is disabled. Settings, Tools, Python Scientific, disable -> Show Plots in Toolwindow
# True unless matplotlib's active backend is one of its non-interactive backends.
backend_interactive = mpl.get_backend() not in mpl.rcsetup.non_interactive_bk

#=== PATH FOR TEMP PLOTS
# Temporary figure directory; passed to fu.ensure_dir_exists at import time.
path_tmp_fig = os.path.join(config.output_dir, 'tmp_fig')
fu.ensure_dir_exists(path_tmp_fig)