enlargelimits='false',
    y_dir='reverse',
    axis_x='top',
    x_tick_label_style='{font=\\footnotesize}',
    y_tick_label_style='{font=\\footnotesize}',
    colorbar_style=
    '{ylabel=Attention, ylabel style={rotate=180}, yticklabels={0.00,0.25,0.50,0.75,1.00}, ytick={0,0.25,0.5,0.75,1}}'
)

plot.add_matrix_plot(X, Y, Z)

# Print the value of every strong attention weight (above 0.6) in small white
# text at the (column, row) coordinate of its cell; weaker cells get no label.
# NOTE(review): assumes X and Y are the integer positions 0..n-1, as the tick
# setup below suggests.
for col, weights in enumerate(Z):
    for row, weight in enumerate(weights):
        if not weight > .6:
            continue
        plot.axis += (f'\\node[white] at ({col},{row}) '
                      f'{{\\scriptsize {weight:.2f} }};')

# Register the custom colormap definition in the LaTeX preamble.
plot.add_to_preamble(cmap_def)

# One tick per token: sentences[0] labels the x axis, sentences[1] the y axis.
x_tokens, y_tokens = sentences[0], sentences[1]
plot.x_ticks, plot.x_ticks_labels = range(len(x_tokens)), x_tokens
plot.y_ticks, plot.y_ticks_labels = range(len(y_tokens)), y_tokens

# Standalone document containing only the heatmap.
doc = Document('heatmap_example', filepath='./examples/', doc_type='standalone')
doc.add_package('xcolor')
doc += plot

# Build the document, discarding the 'log' and 'aux' by-products.
doc.build(delete_files=['log', 'aux'])
Example #2
0
    def heatmap_all_to_latex(self, experiment, sec, mean_metrics, x_label,
                             y_label, caption, file_path, file_name):
        """Render one accuracy heatmap per language as a 2x2 grid of plots.

        The first plot is created through ``sec.new`` and carries the float
        options, caption and label. The following plots (at most three) are
        appended to the first plot's tikzpicture. They are positioned relative
        to each other through the ``at``/``anchor``/shift options in
        ``kwargs_per_plot``:
        - plot1 sits right of plot0;
        - plot2 sits below plot0;
        - plot3 sits right of plot2.

        Args:
            experiment: Object whose ``CHANGING_PARAMS`` maps a parameter name
                to the (sortable) values explored for it.
            sec: Section-like object; the main plot is created via ``sec.new``.
            mean_metrics: Mapping with an ``'accuracies'`` entry of the form
                ``{language: {(str(y_value), str(x_value)): accuracy}}``.
            x_label: Key in ``CHANGING_PARAMS`` for the x-axis parameter.
            y_label: Key in ``CHANGING_PARAMS`` for the y-axis parameter.
            caption: Caption of the whole figure.
            file_path: Passed to every ``Plot`` as ``plot_path``.
            file_name: Base ``plot_name`` of every plot, and the LaTeX label.

        Raises:
            KeyError: If there are more than four languages, if a language has
                no entry in ``titles``, or if an accuracy is missing for some
                (y, x) parameter combination.
        """
        # Plot options shared by every heatmap of the grid.
        common_kwargs = dict(plot_name=file_name,
                             plot_path=file_path,
                             grid=False,
                             lines=False,
                             enlargelimits='false',
                             width=r'.42\textwidth',
                             height=r'.42\textwidth',
                             position='th!')
        plot = sec.new(Plot(label=file_name, name='plot0', **common_kwargs))
        plot.caption = caption

        # Placement of the 2nd to 4th heatmaps relative to already drawn ones.
        kwargs_per_plot = {
            1: {
                'as_float_env': False,
                'at': '(plot0.south east)',
                'anchor': 'south west',
                'xshift': r'.12\textwidth',
                'name': 'plot1'
            },
            2: {
                'as_float_env': False,
                'at': '(plot0.south west)',
                'anchor': 'north west',
                'yshift': r'-0.07\textwidth',
                'name': 'plot2'
            },
            3: {
                'as_float_env': False,
                'at': '(plot2.south east)',
                'anchor': 'south west',
                'xshift': r'.12\textwidth',
                'name': 'plot3'
            }
        }
        titles = {
            'de': 'English-German',
            'it': 'English-Italian',
            'fi': 'English-Finnish',
            'es': 'English-Spanish'
        }

        # The explored parameter grid is the same for every language.
        x_params = sorted(experiment.CHANGING_PARAMS[x_label])
        y_params = sorted(experiment.CHANGING_PARAMS[y_label])

        for i, language in enumerate(sorted(mean_metrics['accuracies'])):
            if i == 0:
                current_plot = plot
            else:
                current_plot = Plot(**common_kwargs, **kwargs_per_plot[i])
                # This plot is drawn inside the first plot's tikzpicture, so
                # its own head/tail (presumably the tikzpicture begin/end) are
                # emptied.
                current_plot.tikzpicture.head = ''
                current_plot.tikzpicture.tail = ''

            # Accuracies are keyed by the (y, x) parameter values as strings.
            accuracies = mean_metrics['accuracies'][language]
            z = np.zeros((len(x_params), len(y_params)), dtype=float)
            for x_idx, x_value in enumerate(x_params):
                for y_idx, y_value in enumerate(y_params):
                    z[x_idx, y_idx] = float(
                        accuracies[(str(y_value), str(x_value))])

            z = np.around(z, 2)

            # Tick labels only on the outer edges of the 2x2 grid: x labels on
            # the bottom row (i >= 2), y labels on the left column (even i).
            if i >= 2:
                current_plot.x_label = r'$p_{factor}$'
                current_plot.x_ticks_labels = [
                    '{:.1f}'.format(x) for x in x_params
                ]
            else:
                current_plot.x_ticks_labels = [r'\empty']

            if i % 2 == 0:
                current_plot.y_label = r'$p_0$'
                current_plot.y_ticks_labels = [
                    '{:.2f}'.format(y) for y in y_params
                ]
            else:
                current_plot.y_ticks_labels = [r'\empty']

            # Cells are placed at integer positions; the actual parameter
            # values appear only through the tick labels set above.
            x_values = list(range(len(x_params)))
            y_values = list(range(len(y_params)))

            current_plot.x_ticks = x_values
            current_plot.y_ticks = y_values

            # Extend the colour range by half the data spread on each side, so
            # the data covers only the middle half of the colormap. When every
            # accuracy is identical the spread is 0, which would make the meta
            # range empty; fall back to a unit spread in that case.
            z_min, z_max = z.min(), z.max()
            delta = z_max - z_min
            if delta == 0:
                delta = 1.0
            point_min = z_min - delta / 2
            point_max = z_max + delta / 2

            current_plot.add_matrix_plot(x_values,
                                         y_values,
                                         z,
                                         point_meta_min=point_min,
                                         point_meta_max=point_max)
            # Print each cell's value below its centre, in black.
            current_plot.axis.options += (
                r'nodes near coords={\pgfmathprintnumber\pgfplotspointmeta}',
                r'every node near coord/.append style={xshift=0pt,yshift=-7pt, black, font=\footnotesize}',
            )

            current_plot.axis.kwoptions[
                'colorbar_style'] = '{{/pgf/number format/fixed zerofill, /pgf/number format/precision=1}}'

            current_plot.title = titles[language]
            current_plot.plot_name += '_en_{}'.format(language)

            if i > 0:
                plot.tikzpicture += current_plot