Ejemplo n.º 1
0
def mplot_expression(ax, f, mesh, **kwargs):
    """Plot the expression ``f`` on ``ax`` by interpolating it on ``mesh``.

    The expression is interpolated into the space returned by
    ``create_cg1_function_space(mesh, f.value_shape)`` and the resulting
    function is forwarded, together with ``kwargs``, to ``mplot_function``,
    whose return value is returned unchanged.
    """
    # TODO: The intermediate function space could probably be avoided if
    # mplot_function were restructured to accept an Expression directly.
    cg1_space = create_cg1_function_space(mesh, f.value_shape)
    return mplot_function(ax, fem.interpolate(f, cg1_space), **kwargs)
Ejemplo n.º 2
0
def _real_part_for_plot(values):
    """Return ``values``, reduced to its real part (with a warning) if complex128."""
    if values.dtype.type is np.complex128:
        warnings.warn("Plotting real part of complex data")
        return np.real(values)
    return values


def mplot_function(ax, f, **kwargs):
    """Plot the finite element function ``f`` on the matplotlib axes ``ax``.

    The plot type is selected from the data and the mesh dimensions:

    * If the dof vector has exactly one entry per cell (owned + ghost), ``f``
      is treated as a cellwise constant (DG0) function.
    * Otherwise, if the element's ``value_rank`` is 0, ``f`` is a scalar
      function evaluated at the vertices. On a 2D mesh the keyword ``mode``
      selects ``"contourf"`` (default), ``"color"``, ``"warp"``,
      ``"wireframe"`` or ``"contour"``. On a 1D mesh the keywords ``vmin`` /
      ``vmax`` set the y-axis limits.
    * Otherwise, if ``value_rank`` is 1, ``f`` is a vector function. The
      keyword ``mode`` selects ``"glyphs"`` (default, a quiver plot coloured by
      magnitude) or ``"displacement"`` (2D only).

    Complex128 data is reduced to its real part with a warning. Remaining
    ``kwargs`` are forwarded to the underlying matplotlib call.

    Returns:
        Whatever the selected matplotlib call returns. ``None`` is returned
        when the dof vector cannot be extracted, when ``mode`` is not one of
        the recognised values, when the value rank is neither 0 nor 1, or
        when displacement plotting is requested on an unsupported mesh.

    Raises:
        AttributeError: for scalar/DG0 data on an unsupported combination of
            geometric and topological dimension, or for vector data whose
            number of components differs from the geometric dimension.
    """
    mesh = f.function_space.mesh
    gdim = mesh.geometry.dim
    tdim = mesh.topology.dim

    # Extract the function vector in a way that also works for
    # subfunctions
    try:
        fvec = f.vector
    except RuntimeError:
        fspace = f.function_space
        try:
            fspace = fspace.collapse()
        except RuntimeError:
            # Best effort: nothing to plot if the subspace cannot be collapsed
            return
        fvec = fem.interpolate(f, fspace).vector

    map_c = mesh.topology.index_map(tdim)
    num_cells = map_c.size_local + map_c.num_ghosts
    if fvec.getSize() == num_cells:
        # DG0 cellwise function
        C = _real_part_for_plot(fvec.get_local())
        # NB! Assuming here dof ordering matching cell numbering
        if gdim == 2 and tdim == 2:
            return ax.tripcolor(mesh2triang(mesh), C, **kwargs)
        elif gdim == 3 and tdim == 2:  # surface in 3d
            # FIXME: Not tested, probably broken
            xy = mesh.geometry.x
            shade = kwargs.pop("shade", True)
            return ax.plot_trisurf(mesh2triang(mesh),
                                   xy[:, 2],
                                   C,
                                   shade=shade,
                                   **kwargs)
        elif gdim == 1 and tdim == 1:
            x = mesh.geometry.x[:, 0]
            nv = len(x)
            # Insert duplicate points to get piecewise constant plot:
            # every interior vertex appears twice, once as the right end of
            # one cell and once as the left end of the next.
            xp = np.zeros(2 * nv - 2)
            xp[0] = x[0]
            xp[-1] = x[-1]
            xp[1:2 * nv - 3:2] = x[1:-1]
            xp[2:2 * nv - 2:2] = x[1:-1]
            # Each cell value is used for both end points of its cell
            Cp = np.zeros(len(xp))
            Cp[0:len(Cp) - 1:2] = C
            Cp[1:len(Cp):2] = C
            # Keyword arguments must be unpacked with ** so they reach
            # matplotlib as keywords rather than as positional strings.
            return ax.plot(xp, Cp, **kwargs)
        # elif tdim == 1:  # FIXME: Plot embedded line
        else:
            raise AttributeError(
                'Matplotlib plotting backend only supports 2D mesh for scalar functions.'
            )

    elif f.function_space.element.value_rank == 0:
        # Scalar function, interpolated to vertices
        # TODO: Handle DG1?
        C = _real_part_for_plot(f.compute_point_values())

        if gdim == 2 and tdim == 2:
            mode = kwargs.pop("mode", "contourf")
            if mode == "contourf":
                levels = kwargs.pop("levels", 40)
                return ax.tricontourf(mesh2triang(mesh), C[:, 0], levels,
                                      **kwargs)
            elif mode == "color":
                shading = kwargs.pop("shading", "gouraud")
                return ax.tripcolor(mesh2triang(mesh),
                                    C[:, 0],
                                    shading=shading,
                                    **kwargs)
            elif mode == "warp":
                from matplotlib import cm
                cmap = kwargs.pop("cmap", cm.jet)
                linewidths = kwargs.pop("linewidths", 0)
                return ax.plot_trisurf(mesh2triang(mesh),
                                       C[:, 0],
                                       cmap=cmap,
                                       linewidths=linewidths,
                                       **kwargs)
            elif mode == "wireframe":
                return ax.triplot(mesh2triang(mesh), **kwargs)
            elif mode == "contour":
                return ax.tricontour(mesh2triang(mesh), C[:, 0], **kwargs)
            # Any other mode falls through and returns None
        elif gdim == 3 and tdim == 2:  # surface in 3d
            # FIXME: Not tested
            from matplotlib import cm
            cmap = kwargs.pop("cmap", cm.jet)
            return ax.plot_trisurf(mesh2triang(mesh),
                                   C[:, 0],
                                   cmap=cmap,
                                   **kwargs)
        elif gdim == 3 and tdim == 3:
            # Volume
            # TODO: Isosurfaces?
            # Vertex point cloud, coloured by the vertex values
            coords = mesh.geometry.x
            X = [coords[:, i] for i in range(gdim)]
            return ax.scatter(*X, c=C[:, 0], **kwargs)
        elif gdim == 1 and tdim == 1:
            x = mesh.geometry.x[:, 0]
            ax.set_aspect('auto')

            # vmin/vmax are not Line2D properties, so take them out of
            # kwargs before they are forwarded to ax.plot.
            vmin = kwargs.pop("vmin", None)
            vmax = kwargs.pop("vmax", None)

            p = ax.plot(x, C[:, 0], **kwargs)

            # Setting limits for Line2D objects
            # Must be done after generating plot to avoid ignoring function
            # range if no vmin/vmax are supplied
            ax.set_ylim([vmin, vmax])

            return p
        # elif tdim == 1: # FIXME: Plot embedded line
        else:
            raise AttributeError(
                'Matplotlib plotting backend only supports 2D mesh for scalar functions.'
            )

    elif f.function_space.element.value_rank == 1:
        # Vector function, interpolated to vertices
        w0 = _real_part_for_plot(f.compute_point_values())
        if w0.shape[1] != gdim:
            raise AttributeError(
                'Vector length must match geometric dimension.')
        coords = mesh.geometry.x
        X = [coords[:, i] for i in range(gdim)]
        U = list(w0.T)

        # Compute magnitude
        C = np.sqrt(sum(u**2 for u in U))

        mode = kwargs.pop("mode", "glyphs")
        if mode == "glyphs":
            args = X + U + [C]
            if gdim == 3:
                length = kwargs.pop("length", 0.1)
                return ax.quiver(*args, length=length, **kwargs)
            else:
                return ax.quiver(*args, **kwargs)
        elif mode == "displacement":
            Xdef = [X[i] + U[i] for i in range(gdim)]
            import matplotlib.tri as tri
            if gdim == 2 and tdim == 2:
                # FIXME: Not tested
                cells = mesh.geometry.dofmap.array.reshape((-1, tdim + 1))
                triang = tri.Triangulation(Xdef[0], Xdef[1], cells)
                shading = kwargs.pop("shading", "flat")
                return ax.tripcolor(triang, C, shading=shading, **kwargs)
            else:
                # Return gracefully to make regression test pass without vtk
                warnings.warn(
                    "Plotting does not support displacement for {} in {}. Continuing without plot."
                    .format(tdim, gdim))
                return