def mplot_expression(ax, f, mesh, **kwargs):
    """Plot the expression ``f`` on ``mesh`` by way of an interpolant.

    A function space is obtained from ``create_cg1_function_space`` for
    ``mesh`` and ``f.value_shape``; ``f`` is interpolated into it with
    ``fem.interpolate`` and the resulting function is drawn by
    ``mplot_function`` on ``ax``. All keyword arguments are forwarded to
    ``mplot_function`` unchanged.

    Returns:
        Whatever ``mplot_function`` returns for the interpolated function.
    """
    # TODO: Building a function space here could be avoided if
    # mplot_function were restructured to accept an Expression directly.
    cg1_space = create_cg1_function_space(mesh, f.value_shape)
    interpolant = fem.interpolate(f, cg1_space)
    return mplot_function(ax, interpolant, **kwargs)
def _real_values(values):
    """Return ``values``, or its real part (with a warning) if it is complex.

    ``np.iscomplexobj`` recognises every complex dtype (complex64 as well as
    complex128), so single-precision complex data is also reduced to real.
    """
    if np.iscomplexobj(values):
        warnings.warn("Plotting real part of complex data")
        return np.real(values)
    return values


def _mplot_cellwise(ax, mesh, C, **kwargs):
    """Plot a cellwise-constant (one value per cell) scalar field ``C``."""
    gdim = mesh.geometry.dim
    tdim = mesh.topology.dim
    # NB! Assumes the dof ordering matches the cell numbering.
    if gdim == 2 and tdim == 2:
        return ax.tripcolor(mesh2triang(mesh), C, **kwargs)
    if gdim == 3 and tdim == 2:
        # Surface embedded in 3D.
        # FIXME: Not tested, probably broken
        xy = mesh.geometry.x
        shade = kwargs.pop("shade", True)
        return ax.plot_trisurf(mesh2triang(mesh), xy[:, 2], C, shade=shade,
                               **kwargs)
    if gdim == 1 and tdim == 1:
        x = mesh.geometry.x[:, 0]
        # Duplicate every interior vertex and every cell value so each cell
        # is drawn as a flat segment: x0, x1, x1, x2, x2, ..., x_last.
        # NOTE(review): like the dof/cell assumption above, this presumes the
        # geometry nodes are ordered along the line — TODO confirm.
        xp = np.repeat(x, 2)[1:-1]
        Cp = np.repeat(C, 2)
        return ax.plot(xp, Cp, **kwargs)
    # FIXME: Plot embedded lines (tdim == 1 with gdim > 1)
    raise AttributeError(
        'Matplotlib plotting backend only supports 2D mesh for scalar functions.'
    )


def _mplot_vertex_scalar(ax, mesh, C, **kwargs):
    """Plot a scalar field ``C`` sampled at the mesh vertices.

    ``C`` is indexed as ``C[:, 0]``, i.e. it is expected to be a 2D array
    whose first column holds the scalar values.
    """
    gdim = mesh.geometry.dim
    tdim = mesh.topology.dim
    if gdim == 2 and tdim == 2:
        mode = kwargs.pop("mode", "contourf")
        if mode == "contourf":
            levels = kwargs.pop("levels", 40)
            return ax.tricontourf(mesh2triang(mesh), C[:, 0], levels, **kwargs)
        if mode == "color":
            shading = kwargs.pop("shading", "gouraud")
            return ax.tripcolor(mesh2triang(mesh), C[:, 0], shading=shading,
                                **kwargs)
        if mode == "warp":
            from matplotlib import cm
            cmap = kwargs.pop("cmap", cm.jet)
            linewidths = kwargs.pop("linewidths", 0)
            return ax.plot_trisurf(mesh2triang(mesh), C[:, 0], cmap=cmap,
                                   linewidths=linewidths, **kwargs)
        if mode == "wireframe":
            return ax.triplot(mesh2triang(mesh), **kwargs)
        if mode == "contour":
            return ax.tricontour(mesh2triang(mesh), C[:, 0], **kwargs)
        warnings.warn(
            "Unsupported plot mode {!r} for scalar functions; nothing plotted."
            .format(mode))
        return None
    if gdim == 3 and tdim == 2:
        # Surface embedded in 3D.
        # FIXME: Not tested
        from matplotlib import cm
        cmap = kwargs.pop("cmap", cm.jet)
        return ax.plot_trisurf(mesh2triang(mesh), C[:, 0], cmap=cmap, **kwargs)
    if gdim == 3 and tdim == 3:
        # Volume: draw the vertex values as a coloured point cloud.
        # TODO: Isosurfaces?
        X = [mesh.geometry.x[:, i] for i in range(gdim)]
        return ax.scatter(*X, c=C[:, 0], **kwargs)
    if gdim == 1 and tdim == 1:
        x = mesh.geometry.x[:, 0]
        ax.set_aspect('auto')
        # vmin/vmax are not Line2D properties, so remove them from kwargs
        # before ax.plot. They are applied as y-limits only after plotting,
        # so an unset (None) bound keeps the range taken from the data.
        vmin = kwargs.pop("vmin", None)
        vmax = kwargs.pop("vmax", None)
        p = ax.plot(x, C[:, 0], **kwargs)
        ax.set_ylim([vmin, vmax])
        return p
    # FIXME: Plot embedded lines (tdim == 1 with gdim > 1)
    raise AttributeError(
        'Matplotlib plotting backend only supports 2D mesh for scalar functions.'
    )


def _mplot_vertex_vector(ax, mesh, w0, **kwargs):
    """Plot a vector field ``w0`` sampled at the mesh vertices.

    ``w0`` is expected to be 2D with one column per geometric component;
    this is checked against the mesh's geometric dimension.
    """
    gdim = mesh.geometry.dim
    tdim = mesh.topology.dim
    if w0.shape[1] != gdim:
        raise AttributeError(
            'Vector length must match geometric dimension.')

    X = [mesh.geometry.x[:, i] for i in range(gdim)]
    U = list(w0.T)
    # Euclidean magnitude per vertex, used to colour glyphs / cells.
    C = np.sqrt(np.sum(w0 ** 2, axis=1))

    mode = kwargs.pop("mode", "glyphs")
    if mode == "glyphs":
        args = X + U + [C]
        if gdim == 3:
            length = kwargs.pop("length", 0.1)
            return ax.quiver(*args, length=length, **kwargs)
        return ax.quiver(*args, **kwargs)
    if mode == "displacement":
        if gdim == 2 and tdim == 2:
            # FIXME: Not tested
            import matplotlib.tri as tri
            Xdef = [X[i] + U[i] for i in range(gdim)]
            cells = mesh.geometry.dofmap.array.reshape((-1, tdim + 1))
            triang = tri.Triangulation(Xdef[0], Xdef[1], cells)
            shading = kwargs.pop("shading", "flat")
            return ax.tripcolor(triang, C, shading=shading, **kwargs)
        # Return gracefully to make regression test pass without vtk
        warnings.warn(
            "Plotting does not support displacement for {} in {}. Continuing without plot."
            .format(tdim, gdim))
        return None
    warnings.warn(
        "Unsupported plot mode {!r} for vector functions; nothing plotted."
        .format(mode))
    return None


def mplot_function(ax, f, **kwargs):
    """Plot the finite element function ``f`` on the matplotlib axes ``ax``.

    Dispatch:

    * If the function's vector has as many entries as the mesh has local
      plus ghost cells, it is treated as cellwise constant (DG0).
    * Otherwise a scalar function (element value rank 0) is sampled at the
      vertices via ``compute_point_values`` and drawn according to the
      ``mode`` keyword (``contourf`` by default on 2D meshes; also
      ``color``, ``warp``, ``wireframe``, ``contour``).
    * A vector function (value rank 1) is drawn as ``glyphs`` (default) or
      as a ``displacement`` of the mesh, coloured by magnitude.

    Complex data is reduced to its real part with a warning.

    Args:
        ax: matplotlib axes to draw on. Surface, volume, ``warp`` and 3D
            glyph plots call 3D-only methods (``plot_trisurf``, ``quiver``
            with ``length``), so presumably a 3D axes is required there —
            TODO confirm.
        f: function object exposing ``function_space``, ``vector`` and
            ``compute_point_values``.
        **kwargs: plotting options. Depending on the branch, ``mode``,
            ``levels``, ``shading``, ``cmap``, ``linewidths``, ``shade``,
            ``length``, ``vmin`` and ``vmax`` are consumed here; everything
            else is forwarded to the matplotlib call.

    Returns:
        The object returned by the underlying matplotlib call, or ``None``
        (with a warning) when nothing is plotted.

    Raises:
        AttributeError: for unsupported mesh dimensions, or when a vector's
            length does not match the geometric dimension.
    """
    mesh = f.function_space.mesh
    tdim = mesh.topology.dim

    # Extract the function vector in a way that also works for
    # subfunctions, whose vector must come from a collapsed space.
    try:
        fvec = f.vector
    except RuntimeError:
        fspace = f.function_space
        try:
            fspace = fspace.collapse()
        except RuntimeError:
            warnings.warn(
                "Could not collapse the function space of a subfunction; "
                "nothing plotted.")
            return None
        fvec = fem.interpolate(f, fspace).vector

    map_c = mesh.topology.index_map(tdim)
    num_cells = map_c.size_local + map_c.num_ghosts
    if fvec.getSize() == num_cells:
        # DG0 cellwise function
        return _mplot_cellwise(ax, mesh, _real_values(fvec.get_local()),
                               **kwargs)

    value_rank = f.function_space.element.value_rank
    if value_rank == 0:
        # Scalar function, interpolated to vertices
        # TODO: Handle DG1?
        return _mplot_vertex_scalar(
            ax, mesh, _real_values(f.compute_point_values()), **kwargs)
    if value_rank == 1:
        # Vector function, interpolated to vertices
        return _mplot_vertex_vector(
            ax, mesh, _real_values(f.compute_point_values()), **kwargs)

    warnings.warn(
        "Plotting functions of value rank {} is not supported; nothing plotted."
        .format(value_rank))
    return None