示例#1
1
def comparison_plot(f, u, Omega, filename='tmp.pdf',
                    plot_title='', ymin=None, ymax=None,
                    u_legend='approximation'):
    """
    Compare f(x) and u(x) for x in Omega in a plot.

    f and u are symbolic expressions in the symbol x. Both are turned
    into vectorized numerical functions by sm.lambdify, evaluated in
    401 uniformly spaced points on Omega, drawn in the same figure and
    saved to filename.

    Omega must hold exactly two entries (left and right end point),
    otherwise ValueError is raised. End points that are not int or
    float are converted through float(value.evalf()). Omega itself is
    left untouched, so a tuple works as well as a list.

    The y axis is fixed to [ymin, ymax] only when both are given.
    u_legend is the legend text of the u curve; the f curve is
    labelled 'exact'.
    """
    x = sm.Symbol('x')
    # Parenthesized single-string form prints the same text under
    # both Python 2 and Python 3
    print('f: %s' % (f,))

    f = sm.lambdify([x], f, modules="numpy")
    u = sm.lambdify([x], u, modules="numpy")
    if len(Omega) != 2:
        raise ValueError('Omega=%s must be an interval (2-list)' % str(Omega))
    # When doing symbolics, Omega can easily contain symbolic expressions,
    # assume .evalf() will work in that case to obtain numerical
    # expressions, which then must be converted to float before calling
    # linspace below. Work on local copies so that the caller's Omega
    # is not modified.
    x_left, x_right = [
        bound if isinstance(bound, (int, float)) else float(bound.evalf())
        for bound in Omega]

    resolution = 401  # no of points in plot
    xcoor = linspace(x_left, x_right, resolution)
    # Vectorized evaluation does not work with lambdify'ed functions
    # without modules="numpy"
    exact = f(xcoor)
    approx = u(xcoor)
    plot(xcoor, approx, '-')
    hold('on')
    plot(xcoor, exact, '-')
    legend([u_legend, 'exact'])
    title(plot_title)
    xlabel('x')
    if ymin is not None and ymax is not None:
        axis([xcoor[0], xcoor[-1], ymin, ymax])
    savefig(filename)
示例#2
0
def comparison_plot(u, Omega, u_e=None, filename='tmp.eps',
                    plot_title='', ymin=None, ymax=None):
    """
    Plot the approximation u(x) on Omega, optionally with the exact u_e.

    u is a symbolic expression in the symbol x; it is turned into a
    vectorized numerical function by sp.lambdify and evaluated in 401
    uniformly spaced points on Omega. u_e, if given, is a callable
    that is applied directly to the array of plot coordinates and
    drawn in the same figure. The figure is saved to filename.

    Omega must hold exactly two entries (left and right end point),
    otherwise ValueError is raised. End points that are not int or
    float are converted through float(value.evalf()). Omega itself is
    left untouched, so a tuple works as well as a list.

    The y axis is fixed to [ymin, ymax] only when both are given.
    """
    x = sp.Symbol('x')
    u = sp.lambdify([x], u, modules="numpy")
    if len(Omega) != 2:
        raise ValueError('Omega=%s must be an interval (2-list)' % str(Omega))
    # When doing symbolics, Omega can easily contain symbolic expressions,
    # assume .evalf() will work in that case to obtain numerical
    # expressions, which then must be converted to float before calling
    # linspace below. Work on local copies so that the caller's Omega
    # is not modified.
    x_left, x_right = [
        bound if isinstance(bound, (int, float)) else float(bound.evalf())
        for bound in Omega]

    resolution = 401  # no of points in plot
    xcoor = linspace(x_left, x_right, resolution)
    # Vectorized evaluation does not work with lambdify'ed functions
    # without modules="numpy"
    approx = u(xcoor)
    plot(xcoor, approx)
    legends = ['approximation']
    if u_e is not None:
        exact = u_e(xcoor)
        hold('on')
        plot(xcoor, exact)
        # Append, do not replace: the first curve is the approximation
        # and must keep its own legend entry
        legends.append('exact')
    legend(legends)
    title(plot_title)
    xlabel('x')
    if ymin is not None and ymax is not None:
        axis([xcoor[0], xcoor[-1], ymin, ymax])
    savefig(filename)
示例#3
0
def demo1(f=f1, nx=4, ny=3, viewx=58, viewy=345, plain_gnuplot=True):
    """
    Draw a mesh on the unit square together with a surface over it.

    The mesh comes from mesh(nx, ny, ...) with diagonal='right'.
    If f is the string 'basis', the surface values are zero at all
    vertices except one, which is set to 1; otherwise the values are
    computed by fill(f, vertices).

    viewx and viewy are passed on to plt.view. With plain_gnuplot=True
    a Gnuplotter object does the plotting and the figure is saved as
    'tmp'; otherwise scitools.std is used and the figure is saved as
    tmp.pdf, tmp.eps and tmp.png.
    """
    vertices, cells = mesh(nx, ny, x=[0, 1], y=[0, 1], diagonal='right')
    if f == 'basis':
        # basis function: value 1 at a single vertex, 0 elsewhere
        zvalues = np.zeros(vertices.shape[0])
        zvalues[int(round(len(zvalues) / 2.)) + int(round(nx / 2.))] = 1
    else:
        # Use the function passed in by the caller (not always f1)
        zvalues = fill(f, vertices)
    if plain_gnuplot:
        plt = Gnuplotter()
    else:
        import scitools.std as plt
    # Disabled: remove border and tics when the backend is gnuplot.
    # if plt.backend == 'gnuplot':
    #     gpl = plt.get_backend()
    #     gpl('unset border; unset xtics; unset ytics; unset ztics')
    #     #gpl('replot')
    draw_mesh(vertices, cells, plt)
    draw_surface(zvalues, vertices, cells, plt)
    plt.axis([0, 1, 0, 1, 0, zvalues.max()])
    plt.view(viewx, viewy)
    if plain_gnuplot:
        plt.savefig('tmp')
    else:
        plt.savefig('tmp.pdf')
        plt.savefig('tmp.eps')
        plt.savefig('tmp.png')
    plt.show()
示例#4
0
def u_P1():
    """
    Illustrate a piecewise linear (P1) finite element function.

    On a mesh with five nodes, one hat-shaped basis function per
    interior node is drawn together with a function u given by its
    nodal values. The figure is saved first as is, and then once
    more after dashed vertical lines have been added at the
    interior nodes.
    """
    import matplotlib.pyplot as plt
    nodes = [0, 1.5, 2.5, 3.5, 4]
    nodal_values = [0, 8, 5, 4, 0]

    # One basis function per interior node: 1 at that node, 0 elsewhere
    hats = []
    for node_no in range(1, len(nodes) - 1):
        hat = np.zeros(len(nodes))
        hat[node_no] = 1
        hats.append(hat)

    for number, hat in enumerate(hats):
        plt.plot(nodes, hat, 'r-')
        # Put the name of the basis function above its peak
        plt.text(nodes[number + 1], 1.2, r'$\varphi_%d$' % number)
    plt.plot(nodes, nodal_values, 'b-', label='$u$')
    plt.legend(loc='upper left')
    plt.axis([0, nodes[-1], 0, 9])
    for figure_file in ('tmp_u_P1.png', 'tmp_u_P1.pdf'):
        plt.savefig(figure_file)

    # Second version of the figure: mark the element boundaries
    for boundary in nodes[1:-1]:
        plt.plot([boundary, boundary], [0, 9], 'm--')
    for figure_file in ('tmp_u_P1_welms.png', 'tmp_u_P1_welms.pdf'):
        plt.savefig(figure_file)
    plt.show()
示例#5
0
def u_P1():
    """
    Illustrate a piecewise linear (P1) finite element function.

    Hat-shaped basis functions for the interior nodes of a five-node
    mesh are plotted along with a function u defined by its nodal
    values. Two versions of the figure are saved: one plain, and one
    where dashed vertical lines mark the interior nodes.
    """
    import matplotlib.pyplot as plt
    mesh_points = [0, 1.5, 2.5, 3.5, 4]
    u_values = [0, 8, 5, 4, 0]
    num_points = len(mesh_points)

    # Basis function k equals 1 at mesh point k+1 and 0 at the others
    basis = [np.zeros(num_points) for _ in range(num_points - 2)]
    for k, hat in enumerate(basis):
        hat[k + 1] = 1

    for k, hat in enumerate(basis):
        plt.plot(mesh_points, hat, 'r-')
        plt.text(mesh_points[k + 1], 1.2, r'$\varphi_%d$' % k)
    plt.plot(mesh_points, u_values, 'b-', label='$u$')
    plt.legend(loc='upper left')
    plt.axis([0, mesh_points[-1], 0, 9])
    plt.savefig('u_example_P1.png')
    plt.savefig('u_example_P1.pdf')

    # Add element boundaries and save the figure under a new name
    interior_points = mesh_points[1:-1]
    for point in interior_points:
        plt.plot([point, point], [0, 9], 'm--')
    plt.savefig('u_example_P1_welms.png')
    plt.savefig('u_example_P1_welms.pdf')
    plt.show()
示例#6
0
def demo1(f=f1, nx=4, ny=3, viewx=58, viewy=345, plain_gnuplot=True):
    """
    Draw a mesh on the unit square together with a surface over it.

    The mesh comes from mesh(nx, ny, ...) with diagonal='right'.
    If f is the string 'basis', the surface values are zero at all
    vertices except one, which is set to 1; otherwise the values are
    computed by fill(f, vertices).

    viewx and viewy are passed on to plt.view. With plain_gnuplot=True
    a Gnuplotter object does the plotting and the figure is saved as
    'tmp'; otherwise scitools.std is used and the figure is saved as
    tmp.pdf, tmp.eps and tmp.png.
    """
    vertices, cells = mesh(nx, ny, x=[0,1], y=[0,1], diagonal='right')
    if f == 'basis':
        # basis function: value 1 at a single vertex, 0 elsewhere
        zvalues = np.zeros(vertices.shape[0])
        zvalues[int(round(len(zvalues)/2.)) + int(round(nx/2.))] = 1
    else:
        # Use the function passed in by the caller (not always f1)
        zvalues = fill(f, vertices)
    if plain_gnuplot:
        plt = Gnuplotter()
    else:
        import scitools.std as plt

    # Disabled: remove border and tics when the backend is gnuplot.
    # if plt.backend == 'gnuplot':
    #     gpl = plt.get_backend()
    #     gpl('unset border; unset xtics; unset ytics; unset ztics')
    #     #gpl('replot')
    draw_mesh(vertices, cells, plt)
    draw_surface(zvalues, vertices, cells, plt)
    plt.axis([0, 1, 0, 1, 0, zvalues.max()])
    plt.view(viewx, viewy)
    if plain_gnuplot:
        plt.savefig('tmp')
    else:
        plt.savefig('tmp.pdf')
        plt.savefig('tmp.eps')
        plt.savefig('tmp.png')
    plt.show()
示例#7
0
def visualize(u, t, title='', filename='tmp'):
    """
    Plot u versus t and save the figure as filename.png and filename.pdf.

    u must provide min() and max() methods (as numpy arrays do).
    The y axis is padded by 20 percent beyond the smallest and
    largest value of u. The figure title is `title`; when title is
    empty, the spacing dt = t[1] - t[0] is shown instead.
    """
    plt.plot(t, u, 'b-')
    plt.xlabel('t')
    plt.ylabel('u')
    # Always pad away from the data: scaling by 1.2 only widens the
    # range when the minimum is negative and the maximum positive,
    # otherwise it would cut into the curve.
    u_low = u.min()
    u_high = u.max()
    umin = 1.2*u_low if u_low < 0 else 0.8*u_low
    umax = 1.2*u_high if u_high > 0 else 0.8*u_high
    plt.axis([t[0], t[-1], umin, umax])
    if title:
        plt.title(title)
    else:
        # No title supplied: fall back to reporting the spacing in t
        dt = t[1] - t[0]
        plt.title('dt=%g' % dt)
    plt.savefig(filename + '.png')
    plt.savefig(filename + '.pdf')
    plt.show()
示例#8
0
def visualize(u, t, title='', filename='tmp'):
    """
    Plot u versus t and save the figure as filename.png and filename.pdf.

    u must provide min() and max() methods (as numpy arrays do).
    The y axis is padded by 20 percent beyond the smallest and
    largest value of u. The figure title is `title`; when title is
    empty, the spacing dt = t[1] - t[0] is shown instead.
    """
    plt.plot(t, u, 'b-')
    plt.xlabel('t')
    plt.ylabel('u')
    # Always pad away from the data: scaling by 1.2 only widens the
    # range when the minimum is negative and the maximum positive,
    # otherwise it would cut into the curve.
    u_low = u.min()
    u_high = u.max()
    umin = 1.2*u_low if u_low < 0 else 0.8*u_low
    umax = 1.2*u_high if u_high > 0 else 0.8*u_high
    plt.axis([t[0], t[-1], umin, umax])
    if title:
        plt.title(title)
    else:
        # No title supplied: fall back to reporting the spacing in t
        dt = t[1] - t[0]
        plt.title('dt=%g' % dt)
    plt.savefig(filename + '.png')
    plt.savefig(filename + '.pdf')
    plt.show()
示例#9
0
def fe_basis_function_figure(d,
                             target_elm=[1],
                             N_e=3,
                             derivative=0,
                             filename='tmp.pdf',
                             labels=False):
    """
    Plot the basis functions of degree d (or their derivatives) that
    belong to the element(s) in target_elm, on top of a uniform mesh
    with N_e elements, and save the figure to filename.

    target_elm may be a single element number or a list of numbers.
    A basis function shared by several target elements is drawn once.
    With labels=True each curve gets a legend whose text depends on
    the plotting backend. If phi_glob returns (None, None) for a
    basis function, the function returns at once without saving.
    """
    nodes, elements = mesh_uniform(N_e, d)
    if isinstance(target_elm, int):
        target_elm = [target_elm]  # a single element number was given

    # Legend text per backend: (function itself, its derivative)
    legend_formats = {
        'gnuplot': (r'basis func. %d',
                    r'derivative of basis func. %d'),
        'matplotlib': (r'\varphi_%d',
                       r"\varphi_%d'(x)"),
    }

    drawn = []  # global numbers of the basis functions plotted so far
    y_low = y_high = 0
    for elm in target_elm:
        for phi_no in elements[elm]:
            if phi_no in drawn:
                continue
            x, y = phi_glob(phi_no, elements, nodes, derivative=derivative)
            if x is None and y is None:
                return  # abort
            y_high = max(y_high, max(y))
            y_low = min(y_low, min(y))
            plt.plot(x, y, '-')
            plt.hold('on')
            if labels:
                formats = legend_formats.get(plt.backend)
                if formats is not None:
                    if derivative == 0:
                        plt.legend(formats[0] % phi_no)
                    else:
                        plt.legend(formats[1] % phi_no)
            drawn.append(phi_no)

    # Leave some air below and above the curves
    bottom = y_low - 0.1
    top = y_high + 0.1
    plt.axis([nodes[0], nodes[-1], bottom, top])
    plot_fe_mesh(nodes, elements, element_marker=[bottom, top])
    plt.hold('off')
    plt.savefig(filename)
示例#10
0
def visualize(u, t, title="", filename="tmp"):
    """
    Plot u versus t and save the figure as filename.png and filename.pdf.

    u must provide min() and max() methods (as numpy arrays do).
    The y axis is padded by 20 percent beyond the smallest and
    largest value of u. The figure title is `title`; when title is
    empty, the spacing dt = t[1] - t[0] is shown instead.
    """
    plt.plot(t, u, "b-")
    plt.xlabel("t")
    plt.ylabel("u")
    # Always pad away from the data: scaling by 1.2 only widens the
    # range when the minimum is negative and the maximum positive,
    # otherwise it would cut into the curve.
    u_low = u.min()
    u_high = u.max()
    umin = 1.2 * u_low if u_low < 0 else 0.8 * u_low
    umax = 1.2 * u_high if u_high > 0 else 0.8 * u_high
    plt.axis([t[0], t[-1], umin, umax])
    if title:
        plt.title(title)
    else:
        # No title supplied: fall back to reporting the spacing in t
        dt = t[1] - t[0]
        plt.title("dt=%g" % dt)
    plt.savefig(filename + ".png")
    plt.savefig(filename + ".pdf")
    plt.show()
示例#11
0
def fe_basis_function_figure(d, target_elm=[1], n_e=3,
                             derivative=0, filename='tmp.pdf',
                             labels=False):
    """
    Plot the basis functions of degree d (or their derivatives) that
    belong to the element(s) in target_elm, on top of a uniform mesh
    with n_e elements, and save the figure to filename.

    target_elm may be a single element number or a list of numbers.
    A basis function shared by several target elements is drawn once.
    With labels=True each curve gets a legend whose text depends on
    the plotting backend. If phi_glob returns (None, None) for a
    basis function, the function returns at once without saving.
    """
    nodes, elements = mesh_uniform(n_e, d)
    if isinstance(target_elm, int):
        target_elm = [target_elm]  # a single element number was given

    # Legend text per backend: (function itself, its derivative)
    legend_formats = {
        'gnuplot': (r'basis function no. %d',
                    r'derivative of basis function no. %d'),
        'matplotlib': (r'\varphi_%d',
                       r"\varphi_%d'(x)"),
    }

    drawn = []  # global numbers of the basis functions plotted so far
    y_low = y_high = 0
    for elm in target_elm:
        for phi_no in elements[elm]:
            if phi_no in drawn:
                continue
            x, y = phi_glob(phi_no, elements, nodes,
                            derivative=derivative)
            if x is None and y is None:
                return  # abort
            y_high = max(y_high, max(y))
            y_low = min(y_low, min(y))
            plt.plot(x, y, '-')
            plt.hold('on')
            if labels:
                formats = legend_formats.get(plt.backend)
                if formats is not None:
                    if derivative == 0:
                        plt.legend(formats[0] % phi_no)
                    else:
                        plt.legend(formats[1] % phi_no)
            drawn.append(phi_no)

    # Leave some air below and above the curves
    bottom = y_low-0.1
    top = y_high+0.1
    plt.axis([nodes[0], nodes[-1], bottom, top])
    plot_fe_mesh(nodes, elements, element_marker=[bottom, top])
    plt.hold('off')
    plt.savefig(filename)
示例#12
0
文件: vib.py 项目: soranhm/IN5270
def plot_empirical_freq_and_amplitude(u, t):
    """
    Plot frequencies and amplitudes estimated from the extreme points
    of u(t) in a new figure.

    The extreme points come from minmax(t, u); periods and amplitudes
    are computed from them by periods() and amplitudes(), and each
    period p is converted to the frequency 2*pi/p. Both series are
    plotted against their index.

    Returns the number of maxima found.
    """
    from math import pi

    minima, maxima = minmax(t, u)
    period_values = periods(maxima)
    amplitude_values = amplitudes(minima, maxima)

    plt.figure()
    frequencies = 2 * pi / period_values
    plt.plot(range(len(period_values)), frequencies, 'r-')
    plt.hold('on')
    plt.plot(range(len(amplitude_values)), amplitude_values, 'b-')

    # Let the axes extend 10 percent beyond the extreme values
    highest = max(frequencies.max(), amplitude_values.max())
    lowest = min(frequencies.min(), amplitude_values.min())
    num_points = max(len(period_values), len(amplitude_values))
    plt.axis([0, num_points, 0.9 * lowest, 1.1 * highest])
    plt.legend(['estimated frequency', 'estimated amplitude'],
               loc='upper right')
    return len(maxima)
示例#13
0
def plot_empirical_freq_and_amplitude(u, t):
    """
    Plot frequencies and amplitudes estimated from the extreme points
    of u(t) in a new figure.

    The extreme points come from minmax(t, u); periods and amplitudes
    are computed from them by periods() and amplitudes(), and each
    period p is converted to the frequency 2*pi/p. Both series are
    plotted against their index.

    Returns the number of maxima found.
    """
    from math import pi

    minima, maxima = minmax(t, u)
    period_values = periods(maxima)
    amplitude_values = amplitudes(minima, maxima)

    plt.figure()
    frequencies = 2*pi/period_values
    plt.plot(range(len(period_values)), frequencies, 'r-')
    plt.hold('on')
    plt.plot(range(len(amplitude_values)), amplitude_values, 'b-')

    # Let the axes extend 10 percent beyond the extreme values
    highest = max(frequencies.max(), amplitude_values.max())
    lowest = min(frequencies.min(), amplitude_values.min())
    num_points = max(len(period_values), len(amplitude_values))
    plt.axis([0, num_points, 0.9*lowest, 1.1*highest])
    plt.legend(['estimated frequency', 'estimated amplitude'],
               loc='upper right')
    return len(maxima)
示例#14
0
def draw_sparsity_pattern(elements, num_nodes):
    """
    Illustrate the matrix sparsity pattern.

    elements is a sequence where each item lists the node numbers of
    one element. Every pair of nodes in the same element gives a
    matrix entry (row, column), which is drawn as a dot. A square
    frame with corners at -1 and num_nodes is drawn around the dots.
    The figure is saved as tmp_sparsity_pattern.pdf and
    tmp_sparsity_pattern.png and then shown.

    Row 0 is placed at the top, as when a matrix is written out.
    NOTE(review): this assumes the nodes are numbered 0, ...,
    num_nodes-1, which is what the frame coordinates suggest.
    """
    import matplotlib.pyplot as plt
    entries = sorted(set(
        (row, col) for e in elements for row in e for col in e))
    x = [col for row, col in entries]
    # Flip each row number individually. Reversing the list of rows
    # would pair rows with the wrong columns unless the pattern
    # happens to be symmetric about the matrix center.
    y = [num_nodes - 1 - row for row, col in entries]
    plt.plot(x, y, 'bo')
    ax = plt.gca()
    ax.set_aspect('equal')
    plt.axis('off')
    plt.plot([-1, num_nodes, num_nodes, -1, -1],
             [-1, -1, num_nodes, num_nodes, -1], 'k-')
    plt.savefig('tmp_sparsity_pattern.pdf')
    plt.savefig('tmp_sparsity_pattern.png')
    plt.show()
示例#15
0
def draw_sparsity_pattern(elements, num_nodes):
    """
    Illustrate the matrix sparsity pattern.

    elements is a sequence where each item lists the node numbers of
    one element. Every pair of nodes in the same element gives a
    matrix entry (row, column), which is drawn as a dot. A square
    frame with corners at -1 and num_nodes is drawn around the dots.
    The figure is saved as tmp_sparsity_pattern.pdf and
    tmp_sparsity_pattern.png and then shown.

    Row 0 is placed at the top, as when a matrix is written out.
    NOTE(review): this assumes the nodes are numbered 0, ...,
    num_nodes-1, which is what the frame coordinates suggest.
    """
    import matplotlib.pyplot as plt
    entries = sorted(set(
        (row, col) for e in elements for row in e for col in e))
    x = [col for row, col in entries]
    # Flip each row number individually. Reversing the list of rows
    # would pair rows with the wrong columns unless the pattern
    # happens to be symmetric about the matrix center.
    y = [num_nodes - 1 - row for row, col in entries]
    plt.plot(x, y, 'bo')
    ax = plt.gca()
    ax.set_aspect('equal')
    plt.axis('off')
    plt.plot([-1, num_nodes, num_nodes, -1, -1],
             [-1, -1, num_nodes, num_nodes, -1], 'k-')
    plt.savefig('tmp_sparsity_pattern.pdf')
    plt.savefig('tmp_sparsity_pattern.png')
    plt.show()
示例#16
0
def plot_boundaries(outer_boundary, inner_boundaries=[], marked_points=None):
    """
    Plot the outer boundary and any inner boundaries of a domain.

    Each boundary object must have attributes x and y holding its
    coordinates; these must provide min() and max() methods.
    inner_boundaries may be a single boundary object or a tuple/list
    of such objects (the default list is never modified).
    marked_points, if given, is a sequence of (point, name) pairs;
    name is written as text at (point[0], point[1]).
    """
    if not isinstance(inner_boundaries, (tuple, list)):
        inner_boundaries = [inner_boundaries]
    boundaries = [outer_boundary]
    boundaries.extend(inner_boundaries)

    # Find max/min of plotting area
    plot_area = [
        min([b.x.min() for b in boundaries]),
        max([b.x.max() for b in boundaries]),
        min([b.y.min() for b in boundaries]),
        max([b.y.max() for b in boundaries]),
    ]

    # Convert to float so that integer coordinates do not give a
    # truncated ratio under Python 2
    width = float(plot_area[1] - plot_area[0])
    height = float(plot_area[3] - plot_area[2])
    # A domain without extent in x has no meaningful ratio; use 1
    aspect = height / width if width != 0 else 1.0
    for b in boundaries:
        plot(b.x, b.y, daspect=[aspect, 1, 1], daspectratio="manual")
        hold("on")
    axis(plot_area)
    title("Specification of domain with %d boundaries" % len(boundaries))
    if marked_points:
        for pt, name in marked_points:
            text(pt[0], pt[1], name)
示例#17
0
def plot_boundaries(outer_boundary, inner_boundaries=[], marked_points=None):
    """
    Plot the outer boundary and any inner boundaries of a domain.

    Each boundary object must have attributes x and y holding its
    coordinates; these must provide min() and max() methods.
    inner_boundaries may be a single boundary object or a tuple/list
    of such objects (the default list is never modified).
    marked_points, if given, is a sequence of (point, name) pairs;
    name is written as text at (point[0], point[1]).
    """
    if not isinstance(inner_boundaries, (tuple, list)):
        inner_boundaries = [inner_boundaries]
    boundaries = [outer_boundary]
    boundaries.extend(inner_boundaries)

    # Find max/min of plotting area
    plot_area = [
        min([b.x.min() for b in boundaries]),
        max([b.x.max() for b in boundaries]),
        min([b.y.min() for b in boundaries]),
        max([b.y.max() for b in boundaries])
    ]

    # Convert to float so that integer coordinates do not give a
    # truncated ratio under Python 2
    width = float(plot_area[1] - plot_area[0])
    height = float(plot_area[3] - plot_area[2])
    # A domain without extent in x has no meaningful ratio; use 1
    aspect = height/width if width != 0 else 1.0
    for b in boundaries:
        plot(b.x, b.y, daspect=[aspect, 1, 1], daspectratio='manual')
        hold('on')
    axis(plot_area)
    title('Specification of domain with %d boundaries' % len(boundaries))
    if marked_points:
        for pt, name in marked_points:
            text(pt[0], pt[1], name)
示例#18
0
 def amplification_factor(names):
     """
     Plot the exact amplification factor together with the factors
     of the schemes listed in names, for 99 values of p in [0, 3].

     names is a sequence of scheme names; each must be one of 'FE',
     'BE' or 'CN' (KeyError otherwise). Any sequence is accepted,
     not only a list. The exact curve is computed by A_exact(p) and
     the curve of a scheme by A(p, theta) with theta = 0, 1 and 0.5
     for FE, BE and CN. The figure is saved as A_growth.png and
     A_growth.pdf.
     """
     # Use SciTools since it adds markers to colored lines
     from scitools.std import (
         plot, title, xlabel, ylabel, hold, savefig,
         axis, legend, grid, figure)
     names = list(names)  # the legend below needs list concatenation
     figure()
     curves = {}
     p = linspace(0, 3, 99)
     curves['exact'] = A_exact(p)
     plot(p, curves['exact'])
     hold('on')
     name2theta = dict(FE=0, BE=1, CN=0.5)
     for name in names:
         curves[name] = A(p, name2theta[name])
         plot(p, curves[name])
     # The axis limits do not depend on the curve: set them once
     axis([p[0], p[-1], -20, 20])
     plot([p[0], p[-1]], [0, 0], '--')  # A=0 line
     title('Amplification factors')
     grid('on')
     legend(['exact'] + names, loc='lower left', fancybox=True)
     xlabel(r'$p=-a\cdot dt$')
     ylabel('Amplification factor')
     savefig('A_growth.png')
     savefig('A_growth.pdf')
示例#19
0
 def amplification_factor(names):
     """
     Plot the exact amplification factor together with the factors
     of the schemes listed in names, for 99 values of p in [0, 3].

     names is a sequence of scheme names; each must be one of 'FE',
     'BE' or 'CN' (KeyError otherwise). Any sequence is accepted,
     not only a list. The exact curve is computed by A_exact(p) and
     the curve of a scheme by A(p, theta) with theta = 0, 1 and 0.5
     for FE, BE and CN. The figure is saved as A_growth.png and
     A_growth.pdf.
     """
     # Use SciTools since it adds markers to colored lines
     from scitools.std import (plot, title, xlabel, ylabel, hold, savefig,
                               axis, legend, grid, figure)
     names = list(names)  # the legend below needs list concatenation
     figure()
     curves = {}
     p = linspace(0, 3, 99)
     curves['exact'] = A_exact(p)
     plot(p, curves['exact'])
     hold('on')
     name2theta = dict(FE=0, BE=1, CN=0.5)
     for name in names:
         curves[name] = A(p, name2theta[name])
         plot(p, curves[name])
     # The axis limits do not depend on the curve: set them once
     axis([p[0], p[-1], -20, 20])
     plot([p[0], p[-1]], [0, 0], '--')  # A=0 line
     title('Amplification factors')
     grid('on')
     legend(['exact'] + names, loc='lower left', fancybox=True)
     xlabel(r'$p=-a\cdot dt$')
     ylabel('Amplification factor')
     savefig('A_growth.png')
     savefig('A_growth.pdf')
示例#20
0
def estimate(truncation_error, T, N_0, m, makeplot=True):
    """
    Compute the truncation error in a problem with one independent
    variable, using m meshes, and estimate the convergence
    rate of the truncation error.

    The user-supplied function truncation_error(dt, N) computes
    the truncation error on a uniform mesh with N intervals of
    length dt::

      R, t, R_a = truncation_error(dt, N)

    where R holds the truncation error at points in the array t,
    and R_a are the corresponding theoretical truncation error
    values (None if not available).

    The truncation_error function is run on a series of meshes
    with 2**i*N_0 intervals, i=0,1,...,m-1, where dt = T/N on each
    mesh. The values of R and R_a are restricted to the coarsest
    mesh, and based on these data, the convergence rate of R
    (pointwise) and time-integrated R can be estimated empirically.

    The rates of the time-integrated R are printed. With
    makeplot=True, R on all meshes is plotted and saved as R_series,
    the difference R_a - R (for meshes where R_a is available) as
    R_error, and the pointwise rate as R_rate_series, each as .png
    and .pdf. Nothing is returned.
    """
    N = [2**i*N_0 for i in range(m)]

    R_I = np.zeros(m)  # time-integrated R values on various meshes
    R = [None]*m       # time series of R restricted to coarsest mesh
    R_a = [None]*m     # time series of R_a restricted to coarsest mesh
    dt = np.zeros(m)
    legends_R = []     # legends of the R curves
    legends_R_a = []   # legends of the R_a - R curves

    for i in range(m):
        dt[i] = T/float(N[i])
        R[i], t, R_a[i] = truncation_error(dt[i], N[i])

        # Discrete L2 norm in time, computed on the full mesh i
        R_I[i] = np.sqrt(dt[i]*np.sum(R[i]**2))

        if i == 0:
            t_coarse = t           # the coarsest mesh

        # N[i]/N_0 equals 2**i; the integer form is a valid slice
        # step also when / is true division
        stride = 2**i
        R[i] = R[i][::stride]      # restrict to coarsest mesh
        if R_a[i] is not None:     # theoretical values are optional
            R_a[i] = R_a[i][::stride]

        if makeplot:
            plt.figure(1)
            plt.plot(t_coarse, R[i], log='y')
            legends_R.append('N=%d' % N[i])
            plt.hold('on')

            if R_a[i] is not None:
                plt.figure(2)
                plt.plot(t_coarse, R_a[i] - R[i], log='y')
                plt.hold('on')
                legends_R_a.append('N=%d' % N[i])

    if makeplot:
        plt.figure(1)
        plt.xlabel('time')
        plt.ylabel('pointwise truncation error')
        plt.legend(legends_R)
        plt.savefig('R_series.png')
        plt.savefig('R_series.pdf')
        if legends_R_a:
            # Only produced when at least one mesh came with R_a
            plt.figure(2)
            plt.xlabel('time')
            plt.ylabel('pointwise error in estimated truncation error')
            plt.legend(legends_R_a)
            plt.savefig('R_error.png')
            plt.savefig('R_error.pdf')

    # Convergence rates
    r_R_I = convergence_rates(dt, R_I)
    # Single-string form prints the same line under Python 2 and 3
    print('R integrated in time; r: ' +
          ' '.join(['%.1f' % r for r in r_R_I]))
    R = np.array(R)  # two-dim. numpy array
    # Rate at each coarse mesh point, taken from the two finest meshes
    r_R = [convergence_rates(dt, R[:,n])[-1]
           for n in range(len(t_coarse))]

    # Plot convergence rates
    if makeplot:
        plt.figure()
        plt.plot(t_coarse, r_R)
        plt.xlabel('time')
        plt.ylabel('r')
        plt.axis([t_coarse[0], t_coarse[-1], 0, 2.5])
        # Raw string: same text, but no invalid escape sequences
        plt.title(r'Pointwise rate $r$ in truncation error $\sim\Delta t^r$')
        plt.savefig('R_rate_series.png')
        plt.savefig('R_rate_series.pdf')