Пример #1
0
    def test_line_markers(self):
        """Check the line-count reduction for a clipped line-with-markers plot.

        A 20-point diagonal drawn with ``*`` markers is clipped to the window
        [20, 80] x [20, 80].  The test passes when the TikZ code generated
        after ``clean_figure`` is exactly 6 lines shorter than the code
        generated before it.
        """
        xs = np.linspace(1, 100, 20)
        ys = np.linspace(1, 100, 20)

        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.plot(xs, ys, linestyle="-", marker="*")
            ax.set_ylim([20, 80])
            ax.set_xlim([20, 80])

            # Export once before and once after cleaning the figure.
            tikz_before = get_tikz_code()
            clean_figure(fig)
            tikz_after = get_tikz_code()

            # The number of newlines is used as the measure of what was removed.
            removed_lines = tikz_before.count("\n") - tikz_after.count("\n")
            assert removed_lines == 6
        plt.close("all")
Пример #2
0
    def test_trisurface3D(self):
        """Check that ``clean_figure`` warns for a 3D triangulated surface.

        The surface is ``z = sin(-x * y)`` sampled on a polar grid of 8 radii
        by 36 angles, plus the origin.
        """
        import matplotlib.pyplot as plt
        import numpy as np

        n_radii = 8
        n_angles = 36
        # Make radii and angles spaces (radius r=0 omitted to eliminate duplication).
        radii = np.linspace(0.125, 1.0, n_radii)
        angles = np.linspace(0, 2 * np.pi, n_angles, endpoint=False)

        # Repeat all angles for each radius.
        angles = np.repeat(angles[..., np.newaxis], n_radii, axis=1)

        # Convert polar (radii, angles) coords to cartesian (x, y) coords.
        # (0, 0) is manually added at this stage, so there will be no duplicate
        # points in the (x, y) plane.
        x = np.append(0, (radii * np.cos(angles)).flatten())
        y = np.append(0, (radii * np.sin(angles)).flatten())

        # Compute z to make the pringle surface.
        z = np.sin(-x * y)

        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            # Passing ``projection`` to ``Figure.gca`` is no longer supported by
            # current matplotlib; ``add_subplot`` creates the same 3D axes on a
            # fresh figure.
            ax = fig.add_subplot(111, projection="3d")

            ax.plot_trisurf(x, y, z, linewidth=0.2, antialiased=True)
            with pytest.warns(Warning):
                clean_figure(fig)
        plt.close("all")
Пример #3
0
def plot_neat_training():
    """Plot feed-forward and recurrent NEAT reward curves and export them.

    Both metric files are read from ``base_path``, cropped together with
    ``crop_data``, drawn into one figure, and saved under ``output_path`` as
    PDF, EPS and (after ``tikzplotlib.clean_figure``) a TikZ ``.tex`` file.
    """
    feed_forward_csv = os.path.join(base_path, 'resources/neat-results.csv')
    recurrent_csv = os.path.join(base_path,
                                 'resources/neat-recurrent-results.csv')

    with open(feed_forward_csv, 'r') as handle:
        neat_data = load_metrics(handle)
    with open(recurrent_csv, 'r') as handle:
        neat_rec_data = load_metrics(handle)
    neat_data, neat_rec_data = crop_data(neat_data, neat_rec_data)

    plt.figure()
    plt.plot(*neat_data['reward'], c='b', linewidth=1.0, label='feed-forward')
    plt.plot(*neat_rec_data['reward'], c='r', linewidth=1.0, label='recurrent')
    plt.xlabel('generation')
    plt.ylabel('return')
    plt.legend()
    plt.grid(linestyle='--')

    axes = plt.gca()
    plt.tight_layout()
    format_axes(axes)

    # Raster-free exports first, then the cleaned TikZ version.
    for file_name in ("neat-reward.pdf", "neat-reward.eps"):
        plt.savefig(os.path.join(output_path, file_name))

    tikzplotlib.clean_figure()
    tikzplotlib.save(os.path.join(output_path, "neat-reward.tex"))
Пример #4
0
    def test_surface3D(self):
        """Check that ``clean_figure`` warns for a 3D surface plot with a colorbar.

        The surface is ``Z = sin(sqrt(X**2 + Y**2))`` on a 0.25-spaced grid
        over [-5, 5) in both directions.
        """
        from matplotlib import cm
        from matplotlib.ticker import FormatStrFormatter, LinearLocator

        # Make data.
        X = np.arange(-5, 5, 0.25)
        Y = np.arange(-5, 5, 0.25)
        X, Y = np.meshgrid(X, Y)
        R = np.sqrt(X**2 + Y**2)
        Z = np.sin(R)

        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            # Passing ``projection`` to ``Figure.gca`` is no longer supported by
            # current matplotlib; ``add_subplot`` creates the same 3D axes on a
            # fresh figure.
            ax = fig.add_subplot(111, projection="3d")

            # Plot the surface.
            surf = ax.plot_surface(X,
                                   Y,
                                   Z,
                                   cmap=cm.coolwarm,
                                   linewidth=0,
                                   antialiased=False)

            # Customize the z axis.
            ax.set_zlim(-1.01, 1.01)
            ax.zaxis.set_major_locator(LinearLocator(10))
            ax.zaxis.set_major_formatter(FormatStrFormatter("%.02f"))

            # Add a color bar which maps values to colors.
            fig.colorbar(surf, shrink=0.5, aspect=5)

            with pytest.warns(Warning):
                clean_figure(fig)
        plt.close("all")
Пример #5
0
    def test_quiver3D(self):
        """Check that ``clean_figure`` warns for a 3D quiver (arrow field) plot."""
        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            # Passing ``projection`` to ``Figure.gca`` is no longer supported by
            # current matplotlib; ``add_subplot`` creates the same 3D axes on a
            # fresh figure.
            ax = fig.add_subplot(111, projection="3d")

            # Make the grid
            x, y, z = np.meshgrid(
                np.arange(-0.8, 1, 0.2),
                np.arange(-0.8, 1, 0.2),
                np.arange(-0.8, 1, 0.8),
            )

            # Make the direction data for the arrows
            u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
            v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
            w = (
                np.sqrt(2.0 / 3.0)
                * np.cos(np.pi * x)
                * np.cos(np.pi * y)
                * np.sin(np.pi * z)
            )

            ax.quiver(x, y, z, u, v, w, length=0.1, normalize=True)
            with pytest.warns(Warning):
                clean_figure(fig)
        plt.close("all")
Пример #6
0
    def test_plot3d(self):
        """Check the line-count reduction for a clipped 3D line plot.

        A 100-point spiral is drawn in 3D axes limited to [-2, 2] on every
        axis and viewed from (30, 30).  The test passes when the TikZ code
        after ``clean_figure`` is exactly 14 lines shorter than before.
        """
        angle = np.linspace(-4 * np.pi, 4 * np.pi, 100)
        height = np.linspace(-2, 2, 100)
        radius = height**2 + 1
        xs = radius * np.sin(angle)
        ys = radius * np.cos(angle)

        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")
            ax.plot(xs, ys, height)
            for set_limit in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
                set_limit([-2, 2])
            ax.view_init(30, 30)

            tikz_before = get_tikz_code(fig)
            clean_figure(fig)
            tikz_after = get_tikz_code()

            # The number of newlines is used as the measure of what was removed.
            removed_lines = tikz_before.count("\n") - tikz_after.count("\n")
            assert removed_lines == 14
        plt.close("all")
Пример #7
0
def save_rateplot(filename, source=None, plot_title=None,
                  width=size.w(1.25), height=size.h(1.25)):
    """Clean the current figure and save it as ``imgpath + filename + '.tex'``.

    The TikZ output is written unwrapped with the given axis ``width`` and
    ``height``; the file is then post-processed by ``add_begin_content`` and
    ``add_end_content`` (the latter receives ``source`` and ``plot_title``).
    """
    tpl.clean_figure()
    tex_path = imgpath + filename + '.tex'
    tpl.save(tex_path, axis_width=width, axis_height=height, wrap=False)
    add_begin_content(tex_path)
    add_end_content(tex_path, plot_title=plot_title, source=source)
Пример #8
0
def to_tikz(fig):
    """Clean ``fig`` in place and return its TikZ code as a string.

    The code is generated with no target file path, no explicit axis size,
    a text size of 10.0 and newline-separated table rows.
    """
    tikzplotlib.clean_figure(fig)
    export_options = {
        "figure": fig,
        "filepath": None,
        "axis_width": None,
        "axis_height": None,
        "textsize": 10.0,
        "table_row_sep": "\n",
    }
    return tikzplotlib.get_tikz_code(**export_options)
Пример #9
0
def save_figure_as_tikz_tex_file(fig: plt.Figure,
                                 target_path: Union[str, os.PathLike]) -> None:
    """Clean ``fig`` and save it as a TikZ ``.tex`` file, logging any failure.

    This is a best-effort export: any exception raised while cleaning or
    saving is logged with its traceback and not re-raised.

    Args:
        fig: The matplotlib figure to export.
        target_path: Destination of the ``.tex`` file.  Annotated with
            ``os.PathLike`` (the path protocol type) rather than
            ``os.fspath``, which is a function and not a type.
    """
    try:
        tikzplotlib.clean_figure(fig=fig)
        tikzplotlib.save(figure=fig, filepath=target_path, strict=True)
    except Exception as e:
        # Deliberately broad: a failed export must not abort the caller.
        logging.error(
            f"Exception ({e.__class__.__name__}) occurred in attempt to export plot in tikz raw text format!\n"
            f"The following tikz tex file was not produced.\n\t{target_path}\n"
            f"The following lines show additional information on the {e.__class__.__name__}",
            exc_info=e)
Пример #10
0
 def test_hist(self):
     """Check that ``clean_figure`` emits a warning for a histogram plot."""
     samples = np.linspace(1, 100, 20)
     bin_edges = np.linspace(1, 100, 20)
     with plt.rc_context(rc=RC_PARAMS):
         fig, ax = plt.subplots(1, 1, figsize=(5, 5))
         # The second positional argument of ``hist`` is the bins argument.
         ax.hist(samples, bin_edges)
         ax.set_ylim([20, 80])
         ax.set_xlim([20, 80])
         with pytest.warns(Warning):
             clean_figure(fig)
     plt.close("all")
Пример #11
0
    def test_contour3D(self):
        """Check that ``clean_figure`` emits a warning for a labelled 3D contour plot."""
        from matplotlib import cm
        from mpl_toolkits.mplot3d import axes3d

        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")

            # Built-in sample data shipped with mplot3d.
            grid_x, grid_y, grid_z = axes3d.get_test_data(0.05)
            contours = ax.contour(grid_x, grid_y, grid_z, cmap=cm.coolwarm)
            ax.clabel(contours, fontsize=9, inline=1)

            with pytest.warns(Warning):
                clean_figure(fig)
        plt.close("all")
def xplotsplit(parameter):
    """Plot ``data`` (all of it, or split by ``parameter``) and save it as TikZ.

    With ``parameter == 'all'`` the data is drawn by ``plot_all``; otherwise
    by ``split_and_plot``.  The current figure is then cleaned, written to
    ``'hyperv2_' + parameter + '_sigma_all.tex'`` with a 5cm x 4cm axis, and
    shown.
    """
    fig, axes = plt.subplots(1)
    # The figure size is the same in both modes.
    fig.set_size_inches(width_inches, height_inches)
    if parameter == 'all':
        plot_all(axes, data)
    else:
        split_and_plot(axes, data, parameter)

    tikzplotlib.clean_figure()
    tex_name = 'hyperv2_' + parameter + '_sigma_all.tex'
    tikzplotlib.save(filepath=tex_name, strict=True,
                     axis_height='4cm', axis_width='5cm')
    fig.show()
Пример #13
0
def plottruss():
    """Plot the truss described by the module-level ``model`` and print a summary.

    When ``model.plot_truss == "yes"`` every element is drawn as a blue line
    between its two nodes (looked up through ``model.IEN``, whose entries are
    1-based node numbers).  For ``model.ndof == 1`` all nodes are placed on
    ``y = 0``; for ``model.ndof == 2`` the ``model.y`` coordinates are used.
    The figure is saved as ``truss.pdf`` and, when ``model.plot_tex == "yes"``,
    also exported as ``fe_plot.tex``.

    The model title and the element/node/equation counts are printed in every
    case, whether or not the plot was drawn.

    Raises:
        ValueError: If plotting is enabled and ``model.ndof`` is not 1, 2 or 3.
    """
    if model.plot_truss == "yes":
        if model.ndof in (1, 2):
            for i in range(model.nel):
                first_node = model.IEN[i, 0]
                second_node = model.IEN[i, 1]
                XX = np.array([model.x[first_node - 1],
                               model.x[second_node - 1]])
                if model.ndof == 1:
                    # A 1D truss is drawn along the x axis.
                    YY = np.array([0.0, 0.0])
                else:
                    YY = np.array([model.y[first_node - 1],
                                   model.y[second_node - 1]])
                plt.plot(XX, YY, "blue")

                if model.plot_node == "yes":
                    plt.text(XX[0], YY[0], str(first_node))
                    plt.text(XX[1], YY[1], str(second_node))
        elif model.ndof == 3:
            # insert your code here for 3D
            # ...
            pass # delete or comment this line after your implementation for 3D
        else:
            # Adjacent literals instead of a backslash continuation, which
            # embedded the source indentation in the message.
            raise ValueError("The dimension (ndof = {0}) given for the "
                             "plottruss is invalid".format(model.ndof))

        plt.title("Truss Plot")
        plt.xlabel(r"$x$")
        plt.ylabel(r"$y$")
        plt.savefig("truss.pdf")

        # Convert matplotlib figures into PGFPlots figures stored in a Tikz file,
        # which can be added into your LaTex source code by "\input{fe_plot.tex}"
        if model.plot_tex == "yes":
            import tikzplotlib
            tikzplotlib.clean_figure()
            tikzplotlib.save("fe_plot.tex")

    print("\t2D Truss Params \n")
    print(model.Title + "\n")
    print("No. of Elements  {0}".format(model.nel))
    print("No. of Nodes     {0}".format(model.nnp))
    print("No. of Equations {0}".format(model.neq))
Пример #14
0
    def _plot_training_plt(self,
                           outputpath,
                           metrics=None,
                           splits=None,
                           epochs=None):
        """Write one TikZ plot per metric with its per-epoch means.

        Args:
            outputpath: Directory in which the ``<metric>.tex`` files are
                saved; it is created if missing.
            metrics: List of metric names, or None if all metrics should be
                plotted (forwarded to ``means_per_epoch``).
            splits: Forwarded to ``means_per_epoch``.
            epochs: Forwarded to ``means_per_epoch``.
        """
        color_dict = {}
        query = self.means_per_epoch(metrics, splits, epochs, squeeze=False)

        os.makedirs(outputpath, exist_ok=True)

        # One figure per metric, one trace per split.
        for m, s_dict in query.items():
            f, ax = plt.subplots()
            ax.set_title(m)
            ax.set_xlabel("Epoch")
            axis_scale = 1.0

            # Reset for every metric so a metric without an entry neither hits
            # an unbound name nor inherits the previous metric's settings.
            meta = self.metrics[m] if m in self.metrics.keys() else None
            if meta is not None:
                if "axis_label" in meta.keys():
                    ax.set_ylabel(meta["axis_label"])
                if "axis_limits" in meta.keys():
                    ax.set_ylim(*meta["axis_limits"])
                if "axis_scale" in meta.keys():
                    axis_scale = float(meta["axis_scale"])

            for s, e_dict in s_dict.items():
                # Each split keeps the same color across all metric figures.
                if s not in color_dict:
                    color_dict[s] = COLORS[len(color_dict) % len(COLORS)]
                trace_color = color_dict[s]

                y_values = [y * axis_scale for y in e_dict.values()]

                # The data must be passed positionally: ``Axes.plot`` rejects
                # ``x=``/``y=`` keyword arguments with a TypeError.
                ax.plot(
                    list(e_dict.keys()),
                    y_values,
                    c=trace_color,
                    label=s,
                )
            tikzplotlib.clean_figure()
            tikzplotlib.save(os.path.join(outputpath, "{}.tex".format(m)))
            plt.close()
Пример #15
0
    def test_subplot(self):
        """Check the line-count reduction for a 2x2 grid of clipped plots.

        Each panel draws the same 20-point diagonal with a different
        line/marker style and is clipped to [20, 80] on both axes.  The test
        passes when the TikZ code after ``clean_figure`` is exactly 36 lines
        shorter than before.  Reference Octave code for the scenario:

        ```octave
            addpath ("../matlab2tikz/src")

            x = linspace(1, 100, 20);
            y1 = linspace(1, 100, 20);

            figure
            subplot(2, 2, 1)
            plot(x, y1, "-")
            subplot(2, 2, 2)
            plot(x, y1, "-")
            subplot(2, 2, 3)
            plot(x, y1, "-")
            subplot(2, 2, 4)
            plot(x, y1, "-")
            xlim([20, 80])
            ylim([20, 80])
            set(gcf,'Units','Inches');
            set(gcf,'Position',[2.5 2.5 5 5])
            cleanfigure;
        ```
        """
        xs = np.linspace(1, 100, 20)
        ys = np.linspace(1, 100, 20)

        # (linestyle, marker) for each of the four panels, in ravel order.
        panel_styles = [("-", "o"), ("-", "None"), ("None", "o"), ("--", "x")]

        with plt.rc_context(rc=RC_PARAMS):
            fig, axes = plt.subplots(2, 2, figsize=(5, 5))
            for panel, (line_style, marker_style) in zip(axes.ravel(),
                                                         panel_styles):
                panel.plot(xs, ys, linestyle=line_style, marker=marker_style)
                panel.set_ylim([20, 80])
                panel.set_xlim([20, 80])

            tikz_before = get_tikz_code()
            clean_figure(fig)
            tikz_after = get_tikz_code()

            # The number of newlines is used as the measure of what was removed.
            removed_lines = tikz_before.count("\n") - tikz_after.count("\n")
            assert removed_lines == 36
        plt.close("all")
Пример #16
0
def main():
    """Generate the beta-distribution example figures.

    The gender example is drawn, cleaned and saved as
    ``imgpath + 'beta_example_gender.tex'``; afterwards the "change" examples
    for the slide deck and for the paper are generated by their own helpers.
    """
    # Male / female beta distribution example, exported as TikZ.
    beta_example_gender()
    tikzplotlib.clean_figure()
    gender_tex = imgpath + 'beta_example_gender.tex'
    plot_width = size.w(1.25)
    plot_height = size.h(1.25)
    tikzplotlib.save(gender_tex,
                     axis_width=plot_width,
                     axis_height=plot_height)

    # Examples showing how the distribution changes; this variant is for the
    # slide deck and will be smaller.
    beta_example_change(ab0=(1, 1), history=[1, 0, 1])

    # Variant for the paper.
    beta_example_change_paper()
Пример #17
0
    def test_wireframe3D(self):
        """Check that ``clean_figure`` emits a warning for a 3D wireframe plot."""
        from mpl_toolkits.mplot3d import axes3d

        # Built-in sample data shipped with mplot3d.
        grid_x, grid_y, grid_z = axes3d.get_test_data(0.05)

        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")

            # Basic wireframe, drawing every 10th row and column.
            ax.plot_wireframe(grid_x, grid_y, grid_z, rstride=10, cstride=10)

            with pytest.warns(Warning):
                clean_figure(fig)
        plt.close("all")
Пример #18
0
def save_areaplot(filename, title):
    """Clean the current figure and save it as an unwrapped TikZ area plot.

    The output goes to ``imgpath + filename + '.tex'`` with fixed axis and
    group-style options, and is then post-processed by ``add_begin_content``
    and ``add_end_content``; the latter receives ``title`` followed by a fixed
    suffix.
    """
    tpl.clean_figure()
    tex_path = imgpath + filename + '.tex'

    legend_style = "".join((
        "legend style={",
        "at={(2.02, 0.5)},",
        "anchor=west,",
        "}",
    ))
    axis_options = {"height=180pt, width=150pt",
                    "reverse legend",
                    legend_style}
    group_options = {"horizontal sep=0.8cm",
                     "group name=my plots"}

    tpl.save(tex_path, wrap=False,
             extra_axis_parameters=axis_options,
             extra_groupstyle_parameters=group_options)
    add_begin_content(tex_path)
    add_end_content(
        tex_path,
        title + " - number Bachelor's degrees awarded (thousands)")
Пример #19
0
    def test_xlog_2(self):
        """Check line counts for a linear ramp plotted on a logarithmic x axis.

        The points are the integers 1..99 on both axes.  After cleaning the
        current figure the TikZ code must be 51 lines shorter and exactly 71
        lines long.
        """
        xs = np.arange(1, 100)
        ys = np.arange(1, 100)
        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1)
            ax.plot(xs, ys)
            ax.set_xscale("log")

            tikz_before = get_tikz_code()
            clean_figure()
            tikz_after = get_tikz_code()

            lines_before = tikz_before.count("\n")
            lines_after = tikz_after.count("\n")
            assert lines_before - lines_after == 51
            assert lines_after == 71
        plt.close("all")
Пример #20
0
    def test_xlog(self):
        """Check line counts for ``x = exp(y)`` plotted on a logarithmic x axis.

        ``y`` holds 100 points in [0, 3].  After cleaning the current figure
        the TikZ code must be 98 lines shorter and exactly 25 lines long.
        """
        ys = np.linspace(0, 3, 100)
        xs = np.exp(ys)

        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1)
            ax.plot(xs, ys)
            ax.set_xscale("log")

            tikz_before = get_tikz_code()
            clean_figure()
            tikz_after = get_tikz_code()

            lines_before = tikz_before.count("\n")
            lines_after = tikz_after.count("\n")
            assert lines_before - lines_after == 98
            assert lines_after == 25
        plt.close("all")
Пример #21
0
    def test_xlog_3(self):
        """Check the line-count reduction for a clipped plot on a log x axis.

        20 points with x log-spaced over [1e-3, 1e3] are clipped to
        x in [1e-2, 1e2] and y in [20, 80].  The test passes when the TikZ
        code after ``clean_figure`` is exactly 18 lines shorter than before.
        """
        xs = np.logspace(-3, 3, 20)
        ys = np.linspace(1, 100, 20)

        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.plot(xs, ys)
            ax.set_xscale("log")
            ax.set_xlim([10**(-2), 10**(2)])
            ax.set_ylim([20, 80])

            tikz_before = get_tikz_code()
            clean_figure(fig)
            tikz_after = get_tikz_code()

            removed_lines = tikz_before.count("\n") - tikz_after.count("\n")
            assert removed_lines == 18
        plt.close("all")
def save(filepath: Path,
         comment: Union[str, dict],
         figure="gcf",
         axis_width=None,
         axis_height=None,
         textsize=10.0,
         table_row_sep="\n"):
    """Clean ``figure``, render it to TikZ, add ``comment`` and write the file.

    The figure is cleaned with ``tikzplotlib.clean_figure``, converted with
    ``tikzplotlib.get_tikz_code`` using the given layout options, passed
    through ``_add_comment`` together with ``comment``, and the result is
    written to ``filepath``.
    """
    tikzplotlib.clean_figure(fig=figure)

    export_options = dict(
        figure=figure,
        filepath=filepath,
        axis_width=axis_width,
        axis_height=axis_height,
        textsize=textsize,
        table_row_sep=table_row_sep,
    )
    tikz_source = tikzplotlib.get_tikz_code(**export_options)
    commented_source = _add_comment(tikz_source, comment)

    with open(filepath, "w") as tikz_file:
        tikz_file.write(commented_source)
Пример #23
0
def main():
    """Plot network-delay CCDFs from the given measurements and export them.

    The fping and iperf data frames are loaded only when the corresponding
    command-line argument is set.  The ping distribution and one iperf
    distribution per ``bin_mb`` value in (1, 2, 3, 10) are drawn with
    ``plt.loglog``; the figure is cleaned, saved as ``ccdf.tex`` and shown.
    """
    args = parser.parse_args()
    fping_df = load_df(args.fping) if args.fping else None
    iperf_df = load_df(args.iperf) if args.iperf else None

    # delay histograms
    distribution_type = 'ccdf'
    fit_distribution = False
    plotf = plt.loglog
    plt.figure()

    if fping_df is not None:
        plot_distribution(
            fping_df, label='Ping', plotf=plotf,
            type=distribution_type, fit_distribution=fit_distribution,
        )

    if iperf_df is not None:
        for size_mb in (1, 2, 3, 10):
            subset = iperf_df.loc[iperf_df['bin_mb'] == size_mb]
            print(f'{size_mb} MB: {len(subset)} samples')
            plot_distribution(
                subset, label=f'{size_mb} MB', plotf=plotf,
                type=distribution_type, fit_distribution=fit_distribution,
            )

    plt.xlim(1e-5, 1)
    plt.ylim(1e-4, 1)
    plt.title('Network delay')
    plt.xlabel('Delay t [s]')
    plt.ylabel('Pr(delay > t)')
    plt.grid()
    plt.legend()

    tikzplotlib.clean_figure()
    tikzplotlib.save('ccdf.tex')
    plt.show()
Пример #24
0
def plot_paac_training():
    """Plot the PAAC training reward curve and export it.

    The metrics are read from ``resources/paac-training`` under ``base_path``
    and drawn as a single red line; the figure is saved under ``output_path``
    as PDF, EPS and (after ``tikzplotlib.clean_figure``) a TikZ ``.tex`` file.
    """
    metrics_path = os.path.join(base_path, 'resources/paac-training')
    with open(metrics_path, 'r') as handle:
        paac_data = load_metrics(handle)

    plt.figure()
    plt.plot(*paac_data['reward'], c='r', linewidth=0.7, label='reward')
    plt.xlabel('frames')
    plt.ylabel('return')
    plt.grid(linestyle='--')

    axes = plt.gca()
    plt.tight_layout()
    format_axes(axes)

    for file_name in ("paac-reward.pdf", "paac-reward.eps"):
        plt.savefig(os.path.join(output_path, file_name))

    tikzplotlib.clean_figure()
    tikzplotlib.save(os.path.join(output_path, "paac-reward.tex"))
Пример #25
0
    def test_bar3D(self):
        """Check that ``clean_figure`` emits a warning for bar charts in 3D axes.

        Four sets of 20 random-height bars are drawn at different depths
        along the y direction.
        """
        bar_sets = zip(["r", "g", "b", "y"], [30, 20, 10, 0])

        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")
            for base_color, depth in bar_sets:
                positions = np.arange(20)
                heights = np.random.rand(20)

                # Colors may be given per bar: the first bar of each set is
                # cyan, the rest use the set's own color.
                bar_colors = ["c"] + [base_color] * (len(positions) - 1)
                ax.bar(positions, heights, zs=depth, zdir="y",
                       color=bar_colors, alpha=0.8)

            ax.set_xlabel("X")
            ax.set_ylabel("Y")
            ax.set_zlabel("Z")
            with pytest.warns(Warning):
                clean_figure(fig)
        plt.close("all")
Пример #26
0
def save_comboplot(cip_cls, filename):
    """Write a combined rate + area TikZ figure to ``imgpath + filename + '.tex'``.

    The file receives, in order: the rate plot's TikZ code, a vertical space
    and an opening ``\\begin{tikzpicture}``, and the unwrapped area plot code.
    ``add_end_content`` is then called on the file with the group title.

    Args:
        cip_cls: Object providing ``plot_rate()`` and ``plot_area()``; each is
            called right before the current figure is converted.
        filename: File name without directory or ``.tex`` extension.
    """
    filepath = imgpath + filename + '.tex'

    # Rate graph
    cip_cls.plot_rate()
    # To do: figure out why computer science ('11') raises error here
    try:
        tpl.clean_figure()
    except ValueError:
        # Best effort: fall back to exporting the uncleaned figure.
        pass
    rate_code = tpl.get_tikz_code(axis_height='140pt',
                                  axis_width='300pt')
    # A context manager guarantees the handle is closed even if a write fails.
    with open(filepath, 'w') as file_handle:
        file_handle.write(rate_code)
        file_handle.write('\n\\vspace{0.1cm}\n\\begin{tikzpicture}')

    # area graph
    cip_cls.plot_area()
    tpl.clean_figure()
    area_code = tpl.get_tikz_code(
        wrap=False,
        extra_axis_parameters={"height=90pt, width=160pt",
                               "reverse legend",
                               "legend style={"
                               + "at={(2.02, 0.5)},"
                               + "anchor=west,"
                               + "}"},
        extra_groupstyle_parameters={"horizontal sep=0.8cm",
                                     "group name=my plots"},
    )
    # Plain append: reading from a freshly opened append stream returns
    # nothing and seeking does not move append-mode writes, so the area code
    # always ends up after the existing content.
    with open(filepath, 'a') as file_handle:
        file_handle.write('\n' + area_code + '\n')

    group_title = 'Number Bachelor\'s degrees awarded (thousands)'
    add_end_content(filepath, group_title, title_space="0.25cm")
Пример #27
0
    def test_scatter(self):
        """Check the line-count reduction for a clipped scatter plot.

        20 points on the diagonal are clipped to [20, 80] on both axes.  The
        test passes when the TikZ code after cleaning the current figure is
        exactly 6 lines shorter than before.
        """
        xs = np.linspace(1, 100, 20)
        ys = np.linspace(1, 100, 20)
        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.scatter(xs, ys)
            ax.set_ylim([20, 80])
            ax.set_xlim([20, 80])

            tikz_before = get_tikz_code()
            clean_figure()
            tikz_after = get_tikz_code()

            # The number of newlines is used as the measure of what was removed.
            removed_lines = tikz_before.count("\n") - tikz_after.count("\n")
            assert removed_lines == 6
        plt.close("all")
Пример #28
0
    def test_sine(self):
        """Check the line-count reduction for a clipped sine curve with markers.

        ``sin(8x)`` is sampled at 100 points on [1, 2*pi] and clipped to
        x in [0.5*pi, 1.5*pi], y in [-1, 1].  The test passes when the TikZ
        code after ``clean_figure`` is exactly 39 lines shorter than before.
        """
        xs = np.linspace(1, 2 * np.pi, 100)
        ys = np.sin(8 * xs)

        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.plot(xs, ys, linestyle="-", marker="*")
            ax.set_xlim([0.5 * np.pi, 1.5 * np.pi])
            ax.set_ylim([-1, 1])

            tikz_before = get_tikz_code()
            clean_figure(fig)
            tikz_after = get_tikz_code()

            # The number of newlines is used as the measure of what was removed.
            removed_lines = tikz_before.count("\n") - tikz_after.count("\n")
            assert removed_lines == 39
        plt.close("all")
Пример #29
0
    def test_polygon3D(self):
        """Check that ``clean_figure`` warns for a polygon collection in 3D axes.

        Four random polygons (one per entry of ``zs``) are added to the axes
        as a ``PolyCollection`` stacked along the y direction.
        """
        from matplotlib import colors as mcolors
        from matplotlib.collections import PolyCollection

        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            # Passing ``projection`` to ``Figure.gca`` is no longer supported by
            # current matplotlib; ``add_subplot`` creates the same 3D axes on a
            # fresh figure.
            ax = fig.add_subplot(111, projection="3d")

            def cc(arg):
                """Return the RGBA tuple for color ``arg`` with alpha 0.6.

                :param arg: color specification passed to ``mcolors.to_rgba``
                """
                return mcolors.to_rgba(arg, alpha=0.6)

            xs = np.arange(0, 10, 0.4)
            verts = []
            zs = [0.0, 1.0, 2.0, 3.0]
            for z in zs:
                ys = np.random.rand(len(xs))
                # Pin both ends to zero so each outline starts and ends on
                # the baseline.
                ys[0], ys[-1] = 0, 0
                verts.append(list(zip(xs, ys)))

            poly = PolyCollection(
                verts, facecolors=[cc("r"), cc("g"),
                                   cc("b"), cc("y")])
            poly.set_alpha(0.7)
            ax.add_collection3d(poly, zs=zs, zdir="y")

            ax.set_xlabel("X")
            ax.set_xlim3d(0, 10)
            ax.set_ylabel("Y")
            ax.set_ylim3d(-1, 4)
            ax.set_zlabel("Z")
            ax.set_zlim3d(0, 1)
            with pytest.warns(Warning):
                clean_figure(fig)
        plt.close("all")
Пример #30
0
def plot_flow_linear(sol, X, T, h, k, c, U_0, number=5):
    """Plot solution rows against a shifted copy of the initial profile.

    ``number`` sample times (capped at 12, one color each) are taken evenly
    from ``[0, T - 1]``.  For every sample time ``t`` the row
    ``sol[floor(t / k)]`` is drawn as a solid line, and the first row of
    ``U_0`` shifted by ``floor(c * t / h)`` cells (padded with its first and
    last values) is drawn dotted in the same color.  The figure is cleaned,
    saved as ``test.tex`` and shown.
    """
    number = min(number, 12)

    palette = ['midnightblue', 'royalblue', 'teal', 'mediumturquoise', 'mediumseagreen', 'forestgreen', 'yellowgreen',
               'goldenrod', 'darkorange', 'orangered', 'red', 'firebrick']

    plt.figure(figsize=(15, 5))
    ax = plt.subplot()

    sample_times = np.linspace(0, T - 1, number)
    for line_color, sample_time in zip(palette, sample_times):
        row = int(np.floor(sample_time / k))
        approximation = sol[row, :]

        width = len(U_0[0, :])
        shift = int(np.floor(c * sample_time / h))
        # Clamp the shifted window to the valid index range of the profile.
        lower = min(max(shift, 0), width)
        upper = max(min(width + shift, width), 0)

        shifted = np.copy(U_0)
        shifted[0, :lower] = U_0[0, 0]
        shifted[0, upper:] = U_0[0, width - 1]
        shifted[0, lower:upper] = U_0[0, lower - shift:upper - shift]

        x_mesh = np.arange(0, X, h)
        ax.plot(x_mesh, approximation, color=line_color)
        ax.plot(x_mesh, shifted[0], ':', color=line_color)

    # Empty black lines exist only to provide the two legend entries.
    for line_format, legend_text in (('-', 'aproximation'),
                                     (':', 'true solution')):
        ax.plot(np.array([]), np.array([]), line_format, color='black',
                label=legend_text)
    plt.legend()

    plt.tight_layout()
    tikzplotlib.clean_figure()
    tikzplotlib.save("test.tex")
    plt.show()