def test_line_markers(self):
        """Check clean_figure() on a line plot that also draws markers.

        The test succeeds when the TikZ code generated after cleaning is
        shorter than the raw code by the expected number of lines.
        """
        xs = np.linspace(1, 100, 20)
        ys = np.linspace(1, 100, 20)

        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.plot(xs, ys, linestyle="-", marker="*")
            ax.set_ylim([20, 80])
            ax.set_xlim([20, 80])
            code_before = get_tikz_code()

            clean_figure(fig)
            code_after = get_tikz_code()

            # The line count of the generated code is a proxy for how many
            # data points clean_figure() removed (20 points are plotted, the
            # axes only show the 20..80 window). The expected reduction is a
            # pinned regression value.
            removed_lines = code_before.count("\n") - code_after.count("\n")
            assert removed_lines == 6
        plt.close("all")
    def test_plot3d(self):
        """clean_figure() should shorten the TikZ output of a 3D helix plot."""
        angle = np.linspace(-4 * np.pi, 4 * np.pi, 100)
        z = np.linspace(-2, 2, 100)
        radius = z**2 + 1
        x = radius * np.sin(angle)
        y = radius * np.cos(angle)

        with plt.rc_context(rc=RC_PARAMS):
            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")
            ax.plot(x, y, z)
            # Same [-2, 2] window on all three axes, applied x, y, z in order.
            for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
                set_limits([-2, 2])
            ax.view_init(30, 30)
            code_before = get_tikz_code(fig)

            clean_figure(fig)
            code_after = get_tikz_code()

            # Line count of the generated code is used as a proxy for the
            # number of points that survived cleaning.
            n_removed = code_before.count("\n") - code_after.count("\n")
            assert n_removed == 14
        plt.close("all")
Beispiel #3
0
def assert_equality(plot,
                    filename,
                    assert_compilation=True,
                    flavor="latex",
                    **extra_get_tikz_code_args):
    """Compare the TikZ code produced for ``plot`` with a stored reference.

    Parameters
    ----------
    plot : callable
        Draws the figure under test onto the current matplotlib figure.
    filename : str
        Reference ``.tex`` file, relative to this module's directory.
    assert_compilation : bool
        If true, the plot is drawn again and exported as a standalone
        document, which must compile successfully.
    flavor : str
        TeX flavor passed to ``get_tikz_code`` and to ``_compile``.
    **extra_get_tikz_code_args
        Forwarded to every ``tikzplotlib.get_tikz_code`` call.
    """
    plot()
    generated = tikzplotlib.get_tikz_code(
        include_disclaimer=False,
        float_format=".8g",
        flavor=flavor,
        **extra_get_tikz_code_args,
    )
    plt.close()

    reference_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), filename)
    with open(reference_path, encoding="utf-8") as ref_file:
        reference = ref_file.read()
    assert reference == generated, filename + "\n" + _unidiff_output(
        reference, generated)

    if not assert_compilation:
        return

    # The first figure was closed above, so draw it again for the
    # standalone compilation check.
    plot()
    standalone = tikzplotlib.get_tikz_code(
        include_disclaimer=False,
        standalone=True,
        flavor=flavor,
        **extra_get_tikz_code_args,
    )
    plt.close()
    assert _compile(standalone, flavor) is not None, standalone
    def test_subplot(self):
        """Check clean_figure() on a 2x2 grid of subplots with mixed styles.

        Related matlab2tikz reference code (Octave):

        ```octave
            addpath ("../matlab2tikz/src")

            x = linspace(1, 100, 20);
            y1 = linspace(1, 100, 20);

            figure
            subplot(2, 2, 1)
            plot(x, y1, "-")
            subplot(2, 2, 2)
            plot(x, y1, "-")
            subplot(2, 2, 3)
            plot(x, y1, "-")
            subplot(2, 2, 4)
            plot(x, y1, "-")
            xlim([20, 80])
            ylim([20, 80])
            set(gcf,'Units','Inches');
            set(gcf,'Position',[2.5 2.5 5 5])
            cleanfigure;
        ```
        """
        xs = np.linspace(1, 100, 20)
        ys = np.linspace(1, 100, 20)

        with plt.rc_context(rc=RC_PARAMS):
            fig, axes = plt.subplots(2, 2, figsize=(5, 5))
            # (linestyle, marker) for each subplot, in row-major order.
            styles = [("-", "o"), ("-", "None"), ("None", "o"), ("--", "x")]
            for ax, (linestyle, marker) in zip(axes.ravel(), styles):
                ax.plot(xs, ys, linestyle=linestyle, marker=marker)
                ax.set_ylim([20, 80])
                ax.set_xlim([20, 80])
            code_before = get_tikz_code()

            clean_figure(fig)
            code_after = get_tikz_code()

            # Each subplot has 20 points of which only the 20..80 window is
            # visible; the expected total reduction in generated lines is a
            # pinned regression value.
            removed = code_before.count("\n") - code_after.count("\n")
            assert removed == 36
        plt.close("all")
def _main():
    """Regenerate ``<module>_reference.tex`` for each test module here.

    Every ``test_*.py`` file next to this script is imported, except for a
    few edge-case modules. Its ``plot()`` function is run, and the resulting
    TikZ code is written to ``<module>_reference.tex``.
    """
    parser = argparse.ArgumentParser(
        description="Refresh all reference TeX files.")
    parser.parse_args()

    this_dir = pathlib.Path(__file__).resolve().parent

    # iterdir() already yields paths rooted at this_dir.
    candidates = [
        path for path in this_dir.iterdir()
        if path.is_file()
        and path.name.startswith("test_")
        and path.name.endswith(".py")
    ]
    module_names = [path.name[:-3] for path in candidates]

    # Edge-case modules that are excluded from regeneration.
    for skipped in ("test_rotated_labels", "test_deterministic_output",
                    "test_cleanfigure", "test_context", "test_readme"):
        module_names.remove(skipped)

    for name in module_names:
        importlib.import_module(name).plot()
        code = tpl.get_tikz_code(include_disclaimer=False, float_format=".8g")
        plt.close()

        out_path = this_dir / (name + "_reference.tex")
        with open(out_path, "w", encoding="utf8") as out_file:
            out_file.write(code)
Beispiel #6
0
def assert_equality(plot, filename, **extra_get_tikz_code_args):
    """Compare the TikZ code of ``plot`` with a reference file and compile it.

    Parameters
    ----------
    plot : callable
        Draws the figure under test onto the current matplotlib figure.
    filename : str
        Reference ``.tex`` file, relative to this module's directory.
    **extra_get_tikz_code_args
        Forwarded to both ``tikzplotlib.get_tikz_code`` calls.
    """
    plot()
    code = tikzplotlib.get_tikz_code(include_disclaimer=False,
                                     **extra_get_tikz_code_args)
    plt.close()

    this_dir = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(this_dir, filename), "r", encoding="utf-8") as f:
        reference = f.read()
    # Diff runs from the expected reference to the actual output.
    assert reference == code, filename + "\n" + _unidiff_output(
        reference, code)

    # The figure was closed above. Without redrawing it, get_tikz_code()
    # would export a brand-new empty figure and the compile check would be
    # meaningless.
    plot()
    code = tikzplotlib.get_tikz_code(include_disclaimer=False,
                                     standalone=True,
                                     **extra_get_tikz_code_args)
    plt.close()
    assert _compile(code) is not None, code
def plot_quorum_enumeration(n_orgs_range, n_quorums, t_elapsed):
    """Plot number of quorums and timings, then print the figure as TikZ.

    Computation time (left axis) and number of quorums (right, twin axis) are
    drawn on log-scaled y axes against the number of organizations. The plot
    uses the first two colours of matplotlib's ggplot style.

    Parameters
    ----------
    n_orgs_range : sequence of int
        x values (number of organizations).
    n_quorums : sequence of positive numbers
        Number of quorums for each x value.
    t_elapsed : sequence of positive numbers
        Computation time in seconds for each x value.

    Notes
    -----
    Changes matplotlib rcParams globally: the "default" style is applied and
    ``text.usetex`` is enabled.
    """
    # Grab the ggplot colour cycle, then return to matplotlib's defaults.
    plt.style.use("ggplot")
    prop_cycle = list(plt.rcParams['axes.prop_cycle'])
    plt.style.use("default")
    # Applying the "default" style resets every rcParam, text.usetex
    # included. LaTeX rendering must therefore be enabled after the reset,
    # or it silently has no effect.
    plt.rc('text', usetex=True)

    def loground(number, direction='up'):
        """Round a positive ``number`` up or down to a power of ten."""
        round_fun = math.ceil if direction == 'up' else math.floor
        return 10**round_fun(math.log10(number))

    _, ax1 = plt.subplots()
    ax1.set_xlabel(r'Number of organizations $n$')
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax1.set_ylabel('Computation time (s)')
    ax1.set_yscale('log')
    ax1.set_ylim(loground(min(t_elapsed), direction='down'),
                 loground(max(t_elapsed), direction='up'))
    ax1.plot(n_orgs_range, t_elapsed, 'o-', lw=2, color=prop_cycle[0]['color'])

    ax2 = ax1.twinx()
    ax2.set_ylabel('Number of quorums')
    ax2.set_yscale('log')
    ax2.set_ylim(loground(min(n_quorums), direction='down'),
                 loground(max(n_quorums), direction='up'))
    ax2.plot(n_orgs_range, n_quorums, 'o:', lw=2, color=prop_cycle[1]['color'])

    print(tikzplotlib.get_tikz_code())
Beispiel #8
0
def _main():
    """Regenerate reference TeX files for the test scripts named on the CLI.

    For each given ``test_*.py`` file (except excluded ones), the module is
    loaded from that path and its ``plot()`` is run. The TikZ code is then
    written to ``<name>_reference.tex`` in this script's directory.
    """
    parser = argparse.ArgumentParser(
        description="Refresh the reference TeX files.")
    parser.add_argument("files", nargs="+", help="Files to refresh")
    args = parser.parse_args()

    this_dir = os.path.dirname(os.path.abspath(__file__))
    exclude_list = ["test_rotated_labels.py", "test_deterministic_output.py"]

    for filename in args.files:
        # Decide on the bare file name. "./test_foo.py" or "tests/test_foo.py"
        # must be treated like "test_foo.py"; otherwise they would be skipped,
        # or would slip past the exclusion list.
        basename = os.path.basename(filename)
        if basename in exclude_list:
            continue
        if not (basename.startswith("test_") and basename.endswith(".py")):
            continue

        spec = importlib.util.spec_from_file_location("plot", filename)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        module.plot()

        code = m2t.get_tikz_code(include_disclaimer=False)
        plt.close()

        tex_filename = basename[:-3] + "_reference.tex"
        # TikZ output may contain non-ASCII text; don't rely on the locale.
        with open(os.path.join(this_dir, tex_filename), "w",
                  encoding="utf-8") as f:
            f.write(code)
    def render(self, latexfile=None):
        r"""Show the figure in a Tk window, or write it to a file as TikZ code.

        ``self._axes_legend(ax)`` is first called for every axis in
        ``self.axes``.

        With ``latexfile=None`` the figure is embedded in a Tk window with a
        navigation toolbar and keyboard shortcuts. The call blocks in
        ``tk.mainloop()`` until the window is closed. If ``HEADLESS_MODE`` is
        set, only a warning is printed instead.

        Otherwise the TikZ code is written to ``latexfile``. The axis width is
        the LaTeX macro ``\figW``, which the including document must define.

        Parameters
        ----------
        latexfile : str or path-like, optional
            Destination of the TikZ code; ``None`` shows the figure instead.
        """
        for ax in self.axes:
            self._axes_legend(ax)
        if latexfile is None:
            if HEADLESS_MODE:
                print("Warning: running in headless mode, won't show anything",
                      file=sys.stdout)
                return
            win = tk.Tk()
            win.title('Figure')

            canvas = FigureCanvasTkAgg(self, master=win)
            canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

            toolbar = NavigationToolbar2Tk(canvas, win)
            toolbar.update()
            # Route key presses to matplotlib's default shortcut handler.
            canvas.mpl_connect(
                'key_press_event',
                lambda ev: key_press_handler(ev, canvas, toolbar))

            canvas.draw()
            tk.mainloop()
        else:
            tikzcode = tikzplotlib.get_tikz_code(
                figure=self,
                extra_axis_parameters=[
                    'scaled ticks=false',
                    'xticklabel style={/pgf/number format/.cd,fixed,precision=2}',
                    'yticklabel style={/pgf/number format/.cd,fixed,precision=2}'
                ],
                axis_width=r'\figW',
                float_format='.5g')
            # Labels and titles may contain non-ASCII characters; write UTF-8
            # explicitly instead of depending on the platform's default.
            with open(latexfile, 'w', encoding='utf-8') as f:
                f.write(tikzcode)
def _main():
    """Rewrite ``<module>_reference.tex`` for every test module in this folder.

    Every ``test_*.py`` file next to this script is imported, except for a
    few edge-case modules. Its ``plot()`` is called, and the TikZ code is
    saved as the module's reference file.
    """
    parser = argparse.ArgumentParser(
        description="Refresh all reference TeX files.")
    parser.parse_args()

    this_dir = os.path.dirname(os.path.abspath(__file__))

    def _is_test_script(name):
        return (os.path.isfile(os.path.join(this_dir, name))
                and name.startswith("test_") and name.endswith(".py"))

    test_modules = [name[:-3] for name in os.listdir(this_dir)
                    if _is_test_script(name)]

    # Edge-case modules that are excluded from regeneration.
    for edge_case in ("test_rotated_labels", "test_deterministic_output",
                      "test_cleanfigure", "test_context"):
        test_modules.remove(edge_case)

    for mod in test_modules:
        importlib.import_module(mod).plot()
        code = tpl.get_tikz_code(include_disclaimer=False, float_format=".8g")
        plt.close()

        target = os.path.join(this_dir, mod + "_reference.tex")
        with open(target, "w", encoding="utf8") as f:
            f.write(code)
    def test_xlog_2(self):
        """clean_figure() on the line y = x drawn with a log-scaled x axis."""
        xs = np.arange(1, 100)
        ys = np.arange(1, 100)
        with plt.rc_context(rc=RC_PARAMS):
            _, ax = plt.subplots(1)
            ax.plot(xs, ys)
            ax.set_xscale("log")
            raw_code = get_tikz_code()
            clean_figure()

            cleaned_code = get_tikz_code()
            raw_lines = raw_code.count("\n")
            cleaned_lines = cleaned_code.count("\n")
            # Pinned regression values for the size of the generated code.
            assert raw_lines - cleaned_lines == 51
            assert cleaned_lines == 71
        plt.close("all")
    def test_xlog(self):
        """clean_figure() on x = exp(y), a straight line on a log x axis."""
        y = np.linspace(0, 3, 100)
        x = np.exp(y)

        with plt.rc_context(rc=RC_PARAMS):
            _, ax = plt.subplots(1)
            ax.plot(x, y)
            ax.set_xscale("log")
            raw_code = get_tikz_code()
            clean_figure()

            cleaned_code = get_tikz_code()
            raw_lines = raw_code.count("\n")
            cleaned_lines = cleaned_code.count("\n")
            # Pinned regression values for the size of the generated code.
            assert raw_lines - cleaned_lines == 98
            assert cleaned_lines == 25
        plt.close("all")
    def test_xlog_3(self):
        """clean_figure() on log-spaced x data, cropped on both axes."""
        x = np.logspace(-3, 3, 20)
        y = np.linspace(1, 100, 20)

        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.plot(x, y)
            ax.set_xscale("log")
            ax.set_xlim([10**(-2), 10**(2)])
            ax.set_ylim([20, 80])
            before = get_tikz_code()

            clean_figure(fig)
            after = get_tikz_code()
            # Pinned regression value for the number of removed lines.
            assert before.count("\n") - after.count("\n") == 18
        plt.close("all")
Beispiel #14
0
def to_tikz(fig):
    """Return TikZ code for ``fig`` after simplifying it with clean_figure().

    ``tikzplotlib.clean_figure`` is called only for its effect on ``fig``
    (its return value is ignored), so the figure passed in is altered.
    """
    tikzplotlib.clean_figure(fig)
    return tikzplotlib.get_tikz_code(
        figure=fig,
        filepath=None,
        axis_width=None,
        axis_height=None,
        textsize=10.0,
        table_row_sep="\n",
    )
Beispiel #15
0
def save_comboplot(cip_cls, filename):
    r"""Write the rate plot and the area group plot of ``cip_cls`` to one file.

    The file ``imgpath + filename + '.tex'`` receives, in order:

    1. the TikZ code of ``cip_cls.plot_rate()``;
    2. ``\vspace{0.1cm}`` and an opening ``\begin{tikzpicture}``;
    3. the unwrapped (``wrap=False``) code of ``cip_cls.plot_area()``.

    ``add_end_content`` is then called on the file with a group title.
    Presumably it closes the open tikzpicture; confirm against its
    implementation.
    """
    filepath = imgpath + filename + '.tex'

    # Rate graph
    cip_cls.plot_rate()
    # To do: figure out why computer science ('11') raises error here.
    # Cleaning is best-effort: on failure the figure is exported uncleaned.
    try:
        tpl.clean_figure()
    except ValueError:
        pass
    code = tpl.get_tikz_code(axis_height='140pt',
                             axis_width='300pt',
                             # axis_width='150pt',
                             # extra_axis_parameters={'x post scale=2',
                             #                        'y post scale=1'}
                             )
    with codecs.open(filepath, 'w') as file_handle:
        file_handle.write(code)
        file_handle.write('\n\\vspace{0.1cm}\n\\begin{tikzpicture}')

    # area graph
    cip_cls.plot_area()
    tpl.clean_figure()
    code = tpl.get_tikz_code(
        wrap=False,
        extra_axis_parameters={"height=90pt, width=160pt",
                               "reverse legend",
                               "legend style={"
                               + "at={(2.02, 0.5)},"
                               + "anchor=west,"
                               + "}"},
        extra_groupstyle_parameters={"horizontal sep=0.8cm",
                                     "group name=my plots"},
    )
    # Plain append: the unwrapped area-graph code must follow the
    # \begin{tikzpicture} written above. In append mode every write goes to
    # the end of the file regardless of seek(), so this is the only
    # placement possible anyway.
    with open(filepath, 'a') as file_handle:
        file_handle.write('\n' + code + '\n')
    group_title = 'Number Bachelor\'s degrees awarded (thousands)'
    add_end_content(filepath, group_title, title_space="0.25cm")
    def test_scatter(self):
        """clean_figure() on a scatter plot whose axes crop part of the data."""
        x = np.linspace(1, 100, 20)
        y = np.linspace(1, 100, 20)
        with plt.rc_context(rc=RC_PARAMS):
            _, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.scatter(x, y)
            ax.set_ylim([20, 80])
            ax.set_xlim([20, 80])
            tikz_raw = get_tikz_code()

            clean_figure()
            tikz_clean = get_tikz_code()

            # Line counts of the generated code are compared. 20 points are
            # plotted but only the 20..80 window is visible. The expected
            # reduction is a pinned regression value.
            delta = tikz_raw.count("\n") - tikz_clean.count("\n")
            assert delta == 6
        plt.close("all")
    def test_sine(self):
        """clean_figure() on a densely sampled sine with markers, zoomed in."""
        x = np.linspace(1, 2 * np.pi, 100)
        y = np.sin(8 * x)

        with plt.rc_context(rc=RC_PARAMS):
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.plot(x, y, linestyle="-", marker="*")
            ax.set_xlim([0.5 * np.pi, 1.5 * np.pi])
            ax.set_ylim([-1, 1])
            tikz_raw = get_tikz_code()

            clean_figure(fig)
            tikz_clean = get_tikz_code()

            # 100 samples are plotted; points outside the x window are
            # dropped by cleaning. The expected reduction in generated lines
            # is a pinned regression value.
            delta = tikz_raw.count("\n") - tikz_clean.count("\n")
            assert delta == 39
        plt.close("all")
Beispiel #18
0
def compare_mpl_tex(plot, flavor="latex"):
    """Render ``plot`` with matplotlib and via TikZ for visual comparison.

    Leaves two files in the current working directory:

    * ``test-0.png``: matplotlib's own rendering;
    * ``test-1.png``: the first page of the compiled TikZ document,
      rasterised at 1000 dpi by the ``pdftoppm`` executable.

    Parameters
    ----------
    plot : callable
        Draws the figure under test onto the current matplotlib figure.
    flavor : str
        TeX flavor used both to generate and to compile the TikZ code.
    """
    plot()
    # Generate the code in the same flavor it is compiled with below.
    code = tikzplotlib.get_tikz_code(standalone=True, flavor=flavor)
    directory = os.getcwd()
    filename = "test-0.png"
    plt.savefig(filename)
    plt.close()

    pdf_file = _compile(code, flavor)
    pdf_dirname = os.path.dirname(pdf_file)

    # Convert PDF to PNG. pdftoppm resolves its output root relative to its
    # working directory. Root the output in the PDF's directory explicitly,
    # because that is where the PNG is looked up below.
    subprocess.check_output(
        ["pdftoppm", "-r", "1000", "-png", pdf_file,
         os.path.join(pdf_dirname, "test")],
        stderr=subprocess.STDOUT)
    png_path = os.path.join(pdf_dirname, "test-1.png")

    os.rename(png_path, os.path.join(directory, "test-1.png"))
def save(filepath: Path,
         comment: Union[str, dict],
         figure="gcf",
         axis_width=None,
         axis_height=None,
         textsize=10.0,
         table_row_sep="\n",
         encoding: str = "utf-8"):
    """Clean ``figure``, export it as TikZ with a comment, and write it out.

    ``tikzplotlib.clean_figure`` is applied to ``figure`` first, which alters
    the figure. The TikZ code is generated with the given layout options and
    passed through ``_add_comment`` together with ``comment``. The result is
    written to ``filepath``.

    Parameters
    ----------
    filepath : Path
        Output file; also passed to ``get_tikz_code`` as ``filepath``.
    comment : str or dict
        Comment content handed to ``_add_comment``.
    figure : matplotlib figure or "gcf"
        Figure to export (defaults to tikzplotlib's "current figure").
    axis_width, axis_height, textsize, table_row_sep
        Forwarded to ``tikzplotlib.get_tikz_code``.
    encoding : str
        Encoding of the output file. It defaults to UTF-8 so that non-ASCII
        labels do not depend on the platform's locale encoding.
    """
    tikzplotlib.clean_figure(fig=figure)
    output = tikzplotlib.get_tikz_code(figure=figure,
                                       filepath=filepath,
                                       axis_width=axis_width,
                                       axis_height=axis_height,
                                       textsize=textsize,
                                       table_row_sep=table_row_sep)
    output = _add_comment(output, comment)
    with open(filepath, "w", encoding=encoding) as tikz_file:
        tikz_file.write(output)
Beispiel #20
0
    mean=0.9981387614088575,
    std=0.07243252473207047)
# Bar chart of mean performance (with std error bars) per optimizer object,
# exported to the thesis as pgfplots code.
object_list = [adam, adapt_only_sgd, adapt_sgd, precon_sgd]

fig = plt.figure(figsize=(3, 3))
plt.bar([o.name for o in object_list], [o.mean for o in object_list],
        alpha=0.7,
        yerr=[o.std for o in object_list],
        capsize=7)
# Dashed reference line at y = 1 spanning the whole bar range.
plt.hlines(1, xmin=-1, xmax=4, linestyles='dashed', alpha=0.7)
# A numeric angle is required: recent matplotlib versions reject numeric
# strings such as '25' as a text rotation.
plt.xticks(rotation=25, horizontalalignment='right', verticalalignment='top')
plt.show()

# mpl.use("pgf")
# mpl.rcParams["pgf.rcfonts"] = False
# plt.gcf().savefig("../../thesis/images/exp_perf_prec.pgf")

# General settings
code = tikzplotlib.get_tikz_code(
    figure=fig,
    extra_axis_parameters=[
        "tick pos=left", "xticklabel style = {anchor = north east}"
    ],
)  # strict = True)
print(code)
# Escape every underscore exactly once for LaTeX: undo any existing "\_"
# escapes, then escape all underscores. Raw strings avoid the invalid "\_"
# escape sequence (a SyntaxWarning on recent Pythons).
code = code.replace(r"\_", "_").replace("_", r"\_")
with codecs.open("../../thesis/images/exp_perf_prec.pgf", "w",
                 'utf-8') as pgf_file:
    pgf_file.write(code)
    ax.set_ylim([-0.05, 1.1])
# NOTE(review): continuation of a script whose earlier part (defining ax, fig,
# gap_data, plot_single_lines, save, figure_dir, experiment_name, method,
# metric, n_networks, tpl) is not visible here.
if gap_data:
    ax.set_ylim([-0.1, 1.15])
ax.legend()
# ax.set_yticks([0.5, 0.8, 0.95, 1.0])
# ax.set_yticklabels([0.5, 0.8, 0.95, 1.0])
if plot_single_lines:
    # Takes precedence over any y-limits set above.
    ax.set_ylim([-1.0, 2.5])

if save:
    # Save a PDF named after experiment/method/metric (spaces -> hyphens) and
    # a TikZ .tex file with the same base name next to it.
    save_path = figure_dir.joinpath(
        f"{experiment_name}_{method}_{metric}.pdf".replace(" ", "-"))
    fig.savefig(save_path, bbox_inches="tight")
    # NOTE(review): no figure= argument is passed, so presumably tikzplotlib
    # exports the *current* figure (its default); this assumes `fig` is
    # still the current figure here.
    tex = tpl.get_tikz_code(
        axis_width="\\figwidth",  # we want LaTeX to take care of the width
        axis_height="\\figheight",  # we want LaTeX to take care of the height
        # we want the plot to look *exactly* like here (e.g. axis limits, axis ticks, etc.)
        # strict=True,
    )
    with save_path.with_suffix(".tex").open("w", encoding="utf-8") as f:
        f.write(tex)

# %%

# Network sizes 1..n_networks (as floats), keeping every size up to 20 and
# only multiples of 5 beyond that.
_net_sizes = np.arange(n_networks) + 1.0
_net_sizes = _net_sizes[np.logical_or(_net_sizes <= 20, _net_sizes % 5 == 0)]

# Figure and settings for the next plot section.
fig, ax = plt.subplots(figsize=(18, 10))
higher_is_better = True
optimum_is_full_ensemble = True
plot_single_lines = False
method = "Ensemble"
Beispiel #22
0
# Net present value per close condition, exported as TikZ for the report.
nDf = NpvDf('OFIPS', 'SIM_ICV_01_STG_EXT')

fig, ax = plt.subplots()
ax.plot(nDf.df['Close Condition'], nDf.df['Npv'], 'o', markersize=4)
ax.tick_params(axis="x", labelrotation=90)
#ax.set_title('Net present value\\\\Production strategies w/ ON/OFF ICV\'s')
ax.set_xlabel('(GOR [$sm^3/sm^3$]; WCUT [$sm^3/sm^3$])')
ax.set_ylabel('Millions of dollars')
ax.set_xlim(0, 99)
ax.set_ylim(1100, 2400)

# Extra pgfplots axis options: a two-line title (LaTeX line break, subtitle
# in \scriptsize) and font sizes for the tick labels and axis labels.
eap = [
    'title={Net present value\\\\\\scriptsize{Production strategies w/ ON/OFF ICV\'s}}',
    'title style={align=left, font=\\normalsize}',
    'tick label style={font=\\scriptsize}',
    'label style={font=\\scriptsize}',
    'legend pos=north east',
]
# NOTE(review): newer tikzplotlib releases name the size options
# axis_width/axis_height; confirm the installed version accepts
# figurewidth/figureheight.
settings = {
    'textsize': 6,
    'figurewidth': '0.70\\textwidth',
    'figureheight': '0.60\\textheight',
    'extra_axis_parameters': eap,
    'extra_tikzpicture_parameters': None,
}
code = utils.filter_xtikcs(tikzplotlib.get_tikz_code(**settings), 3)
with open('./latex/npvf/tikz/graph01.tikz', 'w') as fh:
    fh.write(code)
Beispiel #23
0
def tpl_save(*args, **kwargs):
    r"""Write tikzplotlib code to ``kwargs['filepath']`` with project defaults.

    All arguments are forwarded to ``tpl.get_tikz_code`` except those marked
    "not forwarded" below.

    Keyword arguments handled here
    ------------------------------
    encoding : str or None
        Encoding given to ``codecs.open`` for the output file (not forwarded).
    extra_body_parameters : sequence of str or None
        Lines of TikZ code inserted, in order, right before the first
        ``\end{axis}`` of the generated code (not forwarded).
    extra_tikzpicture_parameters : iterable of str or None
        Merged with ``\providecommand`` definitions of ``\thisXlabelopacity``
        and ``\thisYlabelopacity`` (both 1.0) and ``\pgfplotsset{compat=1.15}``.
    extra_axis_parameters : iterable of str or None
        Merged with styles that set the x/y axis-label opacity from the two
        macros above.
    filepath : str
        Required output path; it is also forwarded to ``get_tikz_code``.

    Raises
    ------
    KeyError
        If ``filepath`` is not passed as a keyword argument.
    ValueError
        If ``extra_body_parameters`` is given but the generated code contains
        no ``\end{axis}``.
    """
    # these are not passed to get_tikz_code
    encoding = kwargs.pop("encoding", None)
    extra_body_parameters = kwargs.pop('extra_body_parameters', None)
    # filepath *is* forwarded to get_tikz_code, so only read it here.
    filepath = kwargs['filepath']

    # always pass certain tikzpicture parameters
    extra_tikzpicture_parameters = kwargs.pop(
        'extra_tikzpicture_parameters', None)
    standard_tikzpicture_parameters = {
        '\\providecommand{\\thisXlabelopacity}{1.0}',
        '\\providecommand{\\thisYlabelopacity}{1.0}',
        '\\pgfplotsset{compat=1.15}',
    }
    # set(...) lets callers pass lists/tuples as well as sets.
    extra_tikzpicture_parameters = (
        standard_tikzpicture_parameters
        if extra_tikzpicture_parameters is None else
        set(extra_tikzpicture_parameters) | standard_tikzpicture_parameters
    )

    # always pass certain axis parameters
    extra_axis_parameters = kwargs.pop('extra_axis_parameters', None)
    standard_axis_parameters = {
        'every axis x label/.append style={opacity=\\thisXlabelopacity}',
        'every axis y label/.append style={opacity=\\thisYlabelopacity}',
    }
    extra_axis_parameters = (
        standard_axis_parameters if extra_axis_parameters is None else
        set(extra_axis_parameters) | standard_axis_parameters
    )

    # get the code (once)
    code = tpl.get_tikz_code(
        *args, **kwargs,
        extra_axis_parameters=extra_axis_parameters,
        extra_tikzpicture_parameters=extra_tikzpicture_parameters,
    )

    # perhaps tack some extra code before the first \end{axis} command
    if extra_body_parameters is not None:
        end_axis = '\\end{axis}'
        # Split at the first \end{axis} only, so that everything after it
        # (e.g. further axes) is preserved.
        head, found, tail = code.partition(end_axis)
        if not found:
            raise ValueError(
                "extra_body_parameters given, but the generated TikZ code "
                "contains no " + end_axis)
        body_lines = list(extra_body_parameters) + [end_axis]
        code = head + '\n'.join(body_lines) + tail

    with codecs.open(filepath, "w", encoding) as file_handle:
        try:
            file_handle.write(code)
        except UnicodeEncodeError:
            # We're probably using Python 2, so treat unicode explicitly
            file_handle.write(six.text_type(code).encode("utf-8"))
Beispiel #24
0
def tpl_save(
        #####
        # non-tpl args
        filepath,
        #####
        *args,
        #####
        # non-tpl kwargs
        encoding=None,
        pre_tikzpicture_lines=None,
        extra_body_parameters=None,
        #####
        #####
        # this kwarg has a non-empty default (set below)
        extra_axis_parameters=None,
        #####
        **kwargs):
    r"""Write tikzplotlib code to ``filepath`` with project-wide additions.

    ``*args`` and ``**kwargs`` go to ``tpl.get_tikz_code``. ``filepath``,
    ``encoding``, ``pre_tikzpicture_lines`` and ``extra_body_parameters``
    are not forwarded.

    Additions made to the generated code:

    * Lines placed before the tikzpicture, each terminated by ``%``:
      ``\providecommand`` definitions of ``\thisXlabelopacity`` /
      ``\thisYlabelopacity`` (both 1.0), ``\pgfplotsset{compat=1.15}``,
      and guards on the LaTeX booleans ``CLEANXAXIS`` / ``CLEANYAXIS``.
      When a boolean is true, the x (y) tick labels and axis label are
      blanked via ``every axis post/.append style``. Presumably the including
      document loads the ``ifthen`` package (``\provideboolean``,
      ``\ifthenelse``); confirm.
    * Axis-label opacity styles driven by the two macros above, merged into
      ``extra_axis_parameters`` through ``augment_params_set``.
    * Optional ``extra_body_parameters`` lines, inserted in order right
      before the first ``\end{axis}``.

    Raises
    ------
    ValueError
        If ``extra_body_parameters`` is given but the generated code contains
        no ``\end{axis}``.
    """
    # Always pass certain pre-tikzpicture lines.
    standard_pre_tikzpicture_lines = {
        '\\providecommand{\\thisXlabelopacity}{1.0}',
        '\\providecommand{\\thisYlabelopacity}{1.0}',
        '\\pgfplotsset{compat=1.15}',
    }

    # To allow the boolean to shut these params off, even if they had been
    #  explicitly set to values, use post/.append style
    for boolean, axis_param in zip(
        ['CLEANXAXIS', 'CLEANXAXIS', 'CLEANYAXIS', 'CLEANYAXIS'],
        ['xticklabels', 'xlabel', 'yticklabels', 'ylabel']):
        standard_pre_tikzpicture_lines |= {
            '\\provideboolean{%s}'
            '\\ifthenelse{\\boolean{%s}}{%%'
            '\n\t\\pgfplotsset{every axis post/.append style={%s = {} }}%%'
            '\n}{}%%' % (boolean, boolean, axis_param)
        }

    # don't forget any pre_tikzpicture_lines that have been passed as args
    pre_tikzpicture_lines = augment_params_set(pre_tikzpicture_lines,
                                               standard_pre_tikzpicture_lines)

    # always pass certain axis parameters
    standard_axis_parameters = {
        'every axis x label/.append style={opacity=\\thisXlabelopacity}',
        'every axis y label/.append style={opacity=\\thisYlabelopacity}',
    }
    extra_axis_parameters = augment_params_set(extra_axis_parameters,
                                               standard_axis_parameters)

    # get the code
    code = tpl.get_tikz_code(
        *args,
        **kwargs,
        extra_axis_parameters=extra_axis_parameters,
    )

    # tack some extra code before anything
    if pre_tikzpicture_lines is not None:
        code = '%\n'.join(pre_tikzpicture_lines) + '%\n' + code

    # perhaps tack some extra code before the first \end{axis} command
    if extra_body_parameters is not None:
        end_axis = '\\end{axis}'
        # Split at the first \end{axis} only, so that everything after it
        # (e.g. further axes) is preserved.
        head, found, tail = code.partition(end_axis)
        if not found:
            raise ValueError(
                "extra_body_parameters given, but the generated TikZ code "
                "contains no " + end_axis)
        body_lines = list(extra_body_parameters) + [end_axis]
        code = head + '\n'.join(body_lines) + tail

    # finally, write out the file
    with codecs.open(filepath, "w", encoding) as file_handle:
        try:
            file_handle.write(code)
        except UnicodeEncodeError:
            # We're probably using Python 2, so treat unicode explicitly
            file_handle.write(six.text_type(code).encode("utf-8"))
Beispiel #25
0
    def plot(self,
             X: Union[Iterable[np.ndarray], np.ndarray],
             y: Union[Iterable[np.ndarray], np.ndarray],
             batched: bool = False,
             uncertainty: str = None,
             filename: str = None,
             tikz: bool = False,
             title_suffix: str = None,
             feature_names: List[str] = None,
             **save_args) -> Union[plt.Figure, str]:
        """
        Reliability diagram to visualize miscalibration. This can be drawn either in the classical way for
        confidences only, or w.r.t. additional properties (like x/y-coordinates of detection boxes, width, height,
        etc.). The additional properties get binned, and the miscalibration is calculated for each bin. The result
        is visualized as 1-D or 2-D plots.

        Parameters
        ----------
        X : iterable of np.ndarray, or np.ndarray of shape=([n_bayes], n_samples, [n_classes/n_box_features])
            NumPy array with confidence values for each prediction on classification with shapes
            1-D for binary classification, 2-D for multi class (softmax).
            If 3-D, interpret first dimension as samples from a Bayesian estimator with multiple data points
            for a single sample (e.g. variational inference or MC dropout samples).
            If this is an iterable over multiple instances of np.ndarray and parameter batched=True,
            interpret this parameter as multiple predictions that should be averaged.
            On detection, this array must have 2 dimensions with number of additional box features in last dim.
        y : iterable of np.ndarray with same length as X or np.ndarray of shape=([n_bayes], n_samples, [n_classes])
            NumPy array with ground truth labels.
            Either as label vector (1-D) or as one-hot encoded ground truth array (2-D).
            If 3-D, interpret first dimension as samples from a Bayesian estimator with multiple data points
            for a single sample (e.g. variational inference or MC dropout samples).
            If iterable over multiple instances of np.ndarray and parameter batched=True,
            interpret this parameter as multiple predictions that should be averaged.
        batched : bool, optional, default: False
            Multiple predictions can be evaluated at once (e.g. cross-validation examinations) using batched-mode.
            All predictions given by X and y are separately evaluated and their results are averaged afterwards
            for visualization.
        uncertainty : str, optional, default: None
            Define uncertainty handling if input X has been sampled e.g. by Monte-Carlo dropout or similar methods
            that output an ensemble of predictions per sample. Choose one of the following options:
            - flatten:  treat everything as a separate prediction - this option yields a slightly better
                        calibration performance but without the visualization of a prediction interval.
            - mean:     compute Monte-Carlo integration to obtain a simple confidence estimate for a sample
                        (mean) with a standard deviation that is visualized.
        filename : str, optional, default: None
            Optional filename to save the plotted figure. If 'tikz' is True, the tikz code is written to this
            file (UTF-8 encoded) and the name is also passed to 'tikzplotlib.get_tikz_code' as 'filepath'.
        tikz : bool, optional, default: False
            If True, use 'tikzplotlib' package to return tikz-code for Latex rather than a Matplotlib figure.
        title_suffix : str, optional, default: None
            Suffix for plot title.
        feature_names : list, optional, default: None
            Names of the additional features that are attached to the axes of a reliability diagram.
        **save_args : args
            Additional arguments passed to 'matplotlib.pyplot.Figure.savefig' function if 'tikz' is False.
            If 'tikz' is True, the arguments are passed to 'tikzplotlib.get_tikz_code' function.

        Returns
        -------
        matplotlib.pyplot.Figure if 'tikz' is False else str with tikz code.

        Raises
        ------
        AttributeError
            - If parameter metric is not string or string is not 'ACE', 'ECE' or 'MCE' (case-insensitive)
            - If parameter 'feature_names' is set but length does not fit to second dim of X
            - If no ground truth samples are provided
            - If length of bins parameter does not match the number of features given by X
            - If more than 3 feature dimensions (including confidence) are provided
        """

        # assign deprecated constructor parameter to title_suffix and feature_names
        if hasattr(self, 'title_suffix') and title_suffix is None:
            title_suffix = self.title_suffix

        if hasattr(self, 'feature_names') and feature_names is None:
            feature_names = self.feature_names

        # check if metric is correct
        if not isinstance(self.metric, str):
            raise AttributeError(
                'Parameter \'metric\' must be string with either \'ece\', \'ace\' or \'mce\'.'
            )

        # check metrics parameter (normalized to lower case on success)
        if self.metric.lower() not in ['ece', 'ace', 'mce']:
            raise AttributeError(
                'Parameter \'metric\' must be string with either \'ece\', \'ace\' or \'mce\'.'
            )
        else:
            self.metric = self.metric.lower()

        # perform checks and prepare input data
        X, matched, sample_uncertainty, bin_bounds, num_features = self._miscalibration.prepare(
            X, y, batched, uncertainty)
        if num_features > 3:
            raise AttributeError(
                "Diagram is not defined for more than 2 additional feature dimensions."
            )

        # bin every batch; column 0 of X holds the confidence
        histograms = []
        for batch_X, batch_matched, batch_uncertainty, bounds in zip(
                X, matched, sample_uncertainty, bin_bounds):
            batch_histograms = self._miscalibration.binning(
                bounds, batch_X, batch_matched, batch_X[:, 0],
                batch_uncertainty[:, 0])
            histograms.append(batch_histograms[:-1])

        # no additional dimensions? compute standard reliability diagram
        if num_features == 1:
            fig = self.__plot_confidence_histogram(X, matched, histograms,
                                                   bin_bounds, title_suffix)

        # one additional feature? compute 1D-plot
        elif num_features == 2:
            fig = self.__plot_1d(histograms, bin_bounds, title_suffix,
                                 feature_names)

        # two additional features? compute 2D plot
        elif num_features == 3:
            fig = self.__plot_2d(histograms, bin_bounds, title_suffix,
                                 feature_names)

        # number of dimensions exceeds 3? quit (defensive; checked above)
        else:
            raise AttributeError(
                "Diagram is not defined for more than 2 additional feature dimensions."
            )

        # if tikz is true, create tikz code from matplotlib figure
        if tikz:

            # get tikz code for our specific figure and also pass filename to store possible bitmaps
            tikz_fig = tikzplotlib.get_tikz_code(fig,
                                                 filepath=filename,
                                                 **save_args)

            # close matplotlib figure when tikz figure is requested to save memory
            plt.close(fig)
            fig = tikz_fig

        # save figure either as matplotlib PNG or as tikz output file
        if filename is not None:
            if tikz:
                # Titles/feature names may contain non-ASCII characters; write
                # UTF-8 explicitly instead of the platform default encoding.
                with open(filename, "w", encoding="utf-8") as open_file:
                    open_file.write(fig)
            else:
                fig.savefig(filename, **save_args)

        return fig
Beispiel #26
0
def get_tikz_strings(model_dict, do_hist=True):
    """Render the training curves stored in ``model_dict`` as TikZ code.

    Three pyplot figures are drawn one after another and converted with
    ``tikzplotlib.get_tikz_code``: the evaluation metrics, the loss and the
    learning rate, each plotted over the training epochs.

    Args:
        model_dict: mapping whose ``'training'`` entry provides ``'epochs'``
            (number of epochs), ``'metrics'`` (names of the tracked metrics)
            and ``'history'`` (per-epoch values keyed by series name, e.g.
            ``'loss'``, ``'val_loss'``, ``'lr'`` and, per metric, ``<name>``
            and ``'val_' + <name>``).
        do_hist: not used by this function; kept for API compatibility.

    Returns:
        dict with keys ``'loss'``, ``'acc'`` and ``'lr'`` holding the TikZ
        code of the corresponding figure.
    """
    training = model_dict['training']
    hist = training['history']
    metrics = training['metrics']
    epochs = list(range(1, training['epochs'] + 1))

    # Shared x range: half an epoch of padding so the end markers are not
    # clipped.  Bound up front so it exists even when no metric is plotted.
    xmin, xmax = epochs[0] - .5, epochs[-1] + .5

    def _setup_epoch_axis():
        """Start a fresh pyplot figure with the common epoch x axis."""
        plt.close()
        plt.figure()
        plt.xlabel('Epoch')
        plt.xticks(epochs)

    def _plot_series(yval, marker, label):
        """Plot one per-epoch series (it may be shorter than ``epochs``)."""
        plt.plot(epochs[:len(yval)],
                 yval,
                 marker,
                 label=label,
                 markersize=5,
                 linewidth=1)

    def _export():
        """Return the TikZ code of the current figure."""
        # Raw string: keeps the LaTeX backslash without relying on the
        # invalid '\s' escape sequence of a normal string literal.
        return tikzplotlib.get_tikz_code(extra_axis_parameters={r'font=\small'},
                                         strict=True)

    # accuracies / evaluation metrics
    _setup_epoch_axis()
    metric_values = []
    ylabel = 'Metric'
    for key in metrics:
        key_disp_stem = key.replace('_', ' ').capitalize()
        if key == 'accuracy':
            # The history stores accuracy under the short 'acc' key.
            key = 'acc'

        for metric in [key, 'val_' + key]:
            if metric.startswith('val'):
                marker, metric_disp = 'o--', 'Test ' + key_disp_stem
            else:
                marker, metric_disp = 'x--', 'Training ' + key_disp_stem
            if len(metrics) == 1:
                # A single metric names the y axis itself; the last series
                # (the 'Test ...' one) determines the label.
                ylabel = metric_disp
            yval = [float(val) for val in hist[metric]]
            metric_values.extend(yval)
            _plot_series(yval, marker, metric_disp)
    plt.ylabel(ylabel)

    # Lower y limit from the smallest value over ALL plotted metric series so
    # that no series is clipped; .9 * min(..., 0.8) keeps it at or below 0.72.
    min_metric = min(metric_values, default=0.8)
    plt.axis([xmin, xmax, .9 * min(round(min_metric, 1), 0.8), 1.025])
    # horizontal line at y=1
    plt.plot([xmin, xmax], [1., 1.], 'k:', linewidth=1)
    plt.title(r'{\bf Progression of Evaluation metrics}')
    plt.legend(loc='lower right')
    tikz_acc = _export()

    # loss
    _setup_epoch_axis()
    plt.ylabel('Loss')
    for metric, metric_disp, marker in (('loss', 'Training Loss', 'x--'),
                                        ('val_loss', 'Test Loss', 'o--')):
        _plot_series([float(val) for val in hist[metric]], marker, metric_disp)
    losses = ([float(x) for x in hist['loss']] +
              [float(x) for x in hist['val_loss']])
    plt.axis([xmin, xmax,
              .6 * min(round(min(losses), 0), 0.5), max(losses) * 1.3])
    plt.legend(loc='upper right')
    plt.title(r'{\bf Progression of loss function}')
    tikz_loss = _export()

    # learning rate
    _setup_epoch_axis()
    plt.ylabel('Learning Rate')
    lr_values = [float(val) for val in hist['lr']]
    plt.axis([xmin, xmax,
              .6 * min(round(min(lr_values), 1), 0.5), max(lr_values) * 1.3])
    # Explicit marker: same dashed-circle style as the test-loss series
    # (previously inherited implicitly from the loss loop's last iteration).
    _plot_series(lr_values, 'o--', 'Learning Rate')
    plt.legend(loc='upper right')
    plt.title(r'{\bf Progression of Learning rate}')
    tikz_lr = _export()

    return {'loss': tikz_loss, 'acc': tikz_acc, 'lr': tikz_lr}
Beispiel #27
0
axess[0][0].set_ylabel("test loss")
axess[1][0].set_ylabel("train loss")
axess[2][0].set_ylabel("test acc")
axess[3][0].set_ylabel("train acc")

# modify the plot: thicken every line in every subplot
for axes in axess:
    for ax in axes:
        for line in ax.get_lines():
            line.set_linewidth(3)

fig.canvas.draw()

# General settings
code = tikzplotlib.get_tikz_code(
    figure=fig,
    figurewidth="\\figurewidth",
    figureheight="5cm",
    extra_axis_parameters=[
        "tick pos=left",
        "legend style={font=\\footnotesize, at={(0 ,0)},xshift=-0.4cm, yshift=-1.5cm,anchor=north,nodes=right}",
    ],
    extra_tikzpicture_parameters=["every axis plot post./append style={line width = 1pt}"],
)  # strict=True left disabled

# Catch missed underscores & save: first unescape any existing "\_" so that
# already-escaped underscores are not escaped twice, then escape every "_".
# Raw strings avoid the invalid "\_" escape sequence of normal literals.
code = code.replace(r"\_", "_").replace("_", r"\_")
# `with` guarantees the file is closed even if the write fails.
with codecs.open("../../thesis/images/exp_tunedalpha.pgf", "w", 'utf-8') as file:
    file.write(code)
Beispiel #28
0
def plot(dirs, exp_name='figure', save=True, show=True, plot_train=False):
    """Plot loss/accuracy curves of logged experiments and optionally export them.

    For every directory in ``dirs``, each ``<name>_train.txt`` file (names
    containing ``'full'`` are skipped) is paired with ``<name>_val.txt``.
    Both are read as space-separated tables with pandas. Four subplots show
    loss and accuracy, each against iterations and against training time.

    Args:
        dirs: directories containing the experiment logs.
        exp_name: base name of the exported files (``<exp_name>.tex`` and
            ``<exp_name>-data.tex``).
        save: if true, export the figure as a standalone LaTeX document whose
            data tables are embedded via ``filecontents``.
        show: if true, call ``plt.show()`` at the end.
        plot_train: if true, also plot the training curves (dotted).
    """
    import os

    import pandas as pd
    import matplotlib.pyplot as plt
    import tikzplotlib

    train_suffix = '_train.txt'

    def _patch_tikzplotlib():
        """Monkey-patch tikzplotlib for compact pgfplots output.

        Lines are emitted as ``\\addplot +[mark=none]`` plus their legend
        entry. Axis options are reduced to those containing 'mod', 'label'
        or 'title'. The patch is global, so it is applied only once;
        repeated calls to ``plot`` do not stack further wrappers.
        """
        if getattr(tikzplotlib._axes.Axes.__init__, '_plot_patched', False):
            return
        original_init = tikzplotlib._axes.Axes.__init__

        def _new_draw_line2d(data, obj):
            content = ["\\addplot +[mark=none] "] + tikzplotlib._line2d._table(
                obj, data)[0]
            legend_text = tikzplotlib._util.get_legend_text(obj)
            if legend_text is not None:
                content.append(f"\\addlegendentry{{{legend_text}}}\n")
            return data, content

        def _new_init(self, data, obj):
            original_init(self, data, obj)
            self.axis_options = [
                x for x in self.axis_options
                if 'mod' in x or 'label' in x or 'title' in x
            ]

        _new_init._plot_patched = True
        tikzplotlib._axes.Axes.__init__ = _new_init
        tikzplotlib._line2d.draw_line2d = _new_draw_line2d

    def _plot(ax, x, y, xlabel):
        """Plot column ``y`` against column ``x`` for every loaded experiment."""
        for exp in exps.values():
            if plot_train:
                ax.plot(exp['train'][x], exp['train'][y], ':')
            # Hide validation points with y >= 4 (presumably diverged values).
            mask = exp['val'][y] < 4
            ax.plot(exp['val'][x][mask], exp['val'][y][mask])
        ax.set(xlabel=xlabel, ylabel=y)

    # Keyed by path so that equally named runs in different directories are
    # all kept. Insertion order matches the order of ``legends``.
    exps = {}
    legends = []
    for dir_name in dirs:
        # Only genuine "<name>_train.txt" logs; sorted for a stable plot order.
        fnames = sorted(
            _name[:-len(train_suffix)] for _name in os.listdir(dir_name)
            if 'full' not in _name and _name.endswith(train_suffix))
        for fname in fnames:
            base = os.path.join(dir_name, fname)
            train = pd.read_csv(base + train_suffix, sep=' ')
            val = pd.read_csv(base + '_val.txt', sep=' ')
            exps[base] = {'train': train, 'val': val}
            label = fname.replace('_', '-')
            if plot_train:
                legends.append(label + '-train')
            legends.append(label + '-val')

    fig, axs = plt.subplots(1, 4)
    _plot(axs[0], 'iterations', 'loss', r'\#iters')
    _plot(axs[1], 'train_time', 'loss', 'Training time / $s$')
    # _plot(axs[1], 'run_time', 'loss', 'Run time / $s$')
    _plot(axs[2], 'iterations', 'accuracy', r'\#iters')
    _plot(axs[3], 'train_time', 'accuracy', 'Training time / $s$')
    # _plot(axs[3], 'run_time', 'accuracy', 'Run time / $s$')
    plt.legend(legends)

    if save:
        _patch_tikzplotlib()
        figure_content = tikzplotlib.get_tikz_code(filepath=exp_name,
                                                   figure=fig,
                                                   externalize_tables=True,
                                                   override_externals=False,
                                                   strict=False)

        # Embed the externalized tables into <exp_name>-data.tex, then delete
        # them. Only files written for this figure are touched. tikzplotlib
        # presumably names them "<stem of exp_name>-NNN.tsv" next to exp_name
        # (confirm for the installed version). Collecting every .tsv in the
        # working directory would embed and remove unrelated files.
        table_dir = os.path.dirname(exp_name) or os.curdir
        table_prefix = os.path.splitext(os.path.basename(exp_name))[0] + '-'
        data_files = sorted(
            x for x in os.listdir(table_dir)
            if x.startswith(table_prefix) and x.endswith('.tsv'))

        with open(exp_name + '-data.tex', 'w') as f:
            for f_name in data_files:
                f_path = os.path.join(table_dir, f_name)
                with open(f_path, 'r') as f_data:
                    f.write(
                        '\\begin{filecontents}{%s}\n%s\\end{filecontents}\n\n\n'
                        % (f_name, f_data.read()))
                os.remove(f_path)

        with open('%s.tex' % exp_name, 'w') as f:
            f.write('''\
\\documentclass{standalone}
\\usepackage[utf8]{inputenc}
\\usepackage{pgfplots}
\\usepgfplotslibrary{groupplots,dateplot}
\\usetikzlibrary{patterns,shapes.arrows}
\\pgfplotsset{compat=newest}
\\begin{document}
\\input{%s-data.tex}
%s
\\end{document}''' % (exp_name, figure_content))

    if show:
        plt.show()
    '''
Beispiel #29
0
# Illustrate how convolving an infection curve with a gamma-shaped delay
# distribution smears it into a curve of symptom-onset reports.
# Source: http://freerangestats.info/blog/2020/07/18/victoria-r-convolution
a = 5  # shape parameter of the gamma delay distribution
t = np.linspace(0, 200)
# logistic ramp from 0 up to 2000
f = 2000 * logistic.cdf((t - 75) / 10)
# zero lead-in, ramp up, plateau, mirrored ramp down
orig = np.r_[np.zeros(50), f, f[-1] * np.ones(50), np.flip(f)]
# gamma pdf sampled between its 0.5% and 99.5% quantiles, normalised to sum 1
pmf = gamma.pdf(np.linspace(gamma.ppf(0.005, a), gamma.ppf(1 - 0.005, a)), a)
pmf /= np.sum(pmf)
obs = convolve(orig, pmf, mode="full")
# rescale so the total of the reported curve equals the total of infections
obs *= np.sum(orig) / np.sum(obs)

plt.plot(obs, color=palette[1], label="symptom onset reports", linewidth=3)
plt.plot(orig, color="black", label="infections", linewidth=3)
plt.xlabel("time")
plt.ylabel("cases")
plt.legend()
print(tikzplotlib.get_tikz_code())

# Show the raw (un-rescaled) convolution on a fresh figure; without this the
# two curves would be drawn onto the axes exported above.
plt.figure()
blur = convolve(orig, pmf, mode="full")
plt.plot(orig)
plt.plot(blur)

plt.show()
Beispiel #30
0
# Restyle the second plotted curve: cross markers, half transparent.
second_curve = axes.get_lines()[1]
second_curve.set_marker("x")
second_curve.set_alpha(0.5)

# Dashed grey horizontal line at y = mean across x in [1e-5, 100]
# (the third legend entry, "Constructed Learning Rate").
axes.plot([1e-5, 100], [mean, mean],
          linewidth=3,
          linestyle="dashed",
          color="grey")

axes.legend(["PreconditionedSGD", "SGD", "Constructed Learning Rate"])
fig.canvas.draw()

code = tikzplotlib.get_tikz_code(
    figure=fig,
    figurewidth="\\figurewidth + 1cm",
    figureheight="5cm",
    extra_axis_parameters=[
        "tick pos=left",
        "legend style={font=\\footnotesize, at={(0.5 ,0)}, yshift=-1.5cm,anchor=north,nodes=right}"
    ],
    extra_tikzpicture_parameters=[],
)  # strict=True left disabled

# Catch missed underscores & save: first unescape any existing "\_" so that
# already-escaped underscores are not escaped twice, then escape every "_".
# Raw strings avoid the invalid "\_" escape sequence of normal literals.
code = code.replace(r"\_", "_").replace("_", r"\_")

# `with` guarantees the file is closed even if the write fails.
with codecs.open("../../thesis/images/exp_lr_sens.pgf", "w", 'utf-8') as file:
    file.write(code)

###### Presentation
code = tikzplotlib.get_tikz_code(
    figure=fig,