Example #1
0
    def heatmap_all_to_latex(self, experiment, sec, mean_metrics, x_label,
                             y_label, caption, file_path, file_name):
        """Render one annotated accuracy heatmap per language in a single figure.

        The first heatmap (``plot0``) is registered in ``sec`` as a float
        figure. The remaining ones are appended into its tikzpicture and
        positioned via ``kwargs_per_plot``: ``plot1`` to the right of
        ``plot0``, ``plot2`` below it, and ``plot3`` to the right of
        ``plot2``, which gives a 2x2 grid.

        Args:
            experiment: Object whose ``CHANGING_PARAMS[x_label]`` and
                ``CHANGING_PARAMS[y_label]`` hold the swept parameter values.
            sec: Section object; ``sec.new(...)`` adds the figure to it.
            mean_metrics: ``mean_metrics['accuracies'][language]`` must map
                ``(str(y_value), str(x_value))`` to a mean accuracy.
            x_label: Key of the parameter shown on the horizontal axis.
            y_label: Key of the parameter shown on the vertical axis.
            caption: Caption of the whole figure.
            file_path: Output directory passed to every ``Plot``.
            file_name: Base name and label of the figure. Each sub-plot's
                ``plot_name`` gets an ``_en_<language>`` suffix.

        Raises:
            KeyError: If there are more than four languages, or a language
                other than ``de``/``es``/``fi``/``it``. Only four placements
                and four titles are defined.
        """
        plot = sec.new(
            Plot(plot_name=file_name,
                 plot_path=file_path,
                 grid=False,
                 lines=False,
                 enlargelimits='false',
                 width=r'.42\textwidth',
                 height=r'.42\textwidth',
                 position='th!',
                 label=file_name,
                 name='plot0'))
        plot.caption = caption

        # Placement of sub-plots 1..3 relative to the named nodes of the others.
        kwargs_per_plot = {
            1: {
                'as_float_env': False,
                'at': '(plot0.south east)',
                'anchor': 'south west',
                'xshift': r'.12\textwidth',
                'name': 'plot1'
            },
            2: {
                'as_float_env': False,
                'at': '(plot0.south west)',
                'anchor': 'north west',
                'yshift': r'-0.07\textwidth',
                'name': 'plot2'
            },
            3: {
                'as_float_env': False,
                'at': '(plot2.south east)',
                'anchor': 'south west',
                'xshift': r'.12\textwidth',
                'name': 'plot3'
            }
        }
        # NOTE(review): 'English-Deutsch' uses the German name while the other
        # titles use English names ('English-German' would be consistent).
        # Kept unchanged to preserve the rendered output.
        titles = {
            'de': 'English-Deutsch',
            'it': 'English-Italian',
            'fi': 'English-Finnish',
            'es': 'English-Spanish'
        }

        # The parameter grid is identical for every language, so sort it once.
        # These real values feed the tick labels and the metric lookup. The
        # matrix itself is drawn on integer cell positions (built per plot).
        x_params = sorted(experiment.CHANGING_PARAMS[x_label])
        y_params = sorted(experiment.CHANGING_PARAMS[y_label])

        for i, language in enumerate(sorted(mean_metrics['accuracies'])):
            if i == 0:
                current_plot = plot
            else:
                current_plot = Plot(plot_name=file_name,
                                    plot_path=file_path,
                                    grid=False,
                                    lines=False,
                                    enlargelimits='false',
                                    width=r'.42\textwidth',
                                    height=r'.42\textwidth',
                                    position='th!',
                                    **kwargs_per_plot[i])
                # This sub-plot is appended into plot0's tikzpicture at the end
                # of the iteration, so its own tikzpicture head/tail are blanked.
                current_plot.tikzpicture.head = ''
                current_plot.tikzpicture.tail = ''

            accuracies = mean_metrics['accuracies'][language]
            # NOTE(review): z has shape (len(x), len(y)). Confirm this is the
            # orientation add_matrix_plot expects.
            z = np.zeros((len(x_params), len(y_params)), dtype=float)
            for x_idx, x_value in enumerate(x_params):
                for y_idx, y_value in enumerate(y_params):
                    # Metric keys are (y, x) pairs of stringified values.
                    z[x_idx, y_idx] = float(
                        accuracies[(str(y_value), str(x_value))])
            z = np.around(z, 2)

            # Only the bottom row (i >= 2) shows x tick labels and the x label.
            if i >= 2:
                current_plot.x_label = r'$p_{factor}$'
                current_plot.x_ticks_labels = [
                    '{:.1f}'.format(x) for x in x_params
                ]
            else:
                current_plot.x_ticks_labels = [r'\empty']

            # Only the left column (even i) shows y tick labels and the y label.
            if i % 2 == 0:
                current_plot.y_label = r'$p_0$'
                current_plot.y_ticks_labels = [
                    '{:.2f}'.format(y) for y in y_params
                ]
            else:
                current_plot.y_ticks_labels = [r'\empty']

            x_positions = list(range(len(x_params)))
            y_positions = list(range(len(y_params)))
            current_plot.x_ticks = x_positions
            current_plot.y_ticks = y_positions

            # Widen the colour range by half the data spread on each side,
            # presumably so cell colours stay light enough for the black value
            # labels to remain readable.
            delta = z.max() - z.min()
            point_min = z.min() - delta / 2
            point_max = z.max() + delta / 2

            current_plot.add_matrix_plot(x_positions,
                                         y_positions,
                                         z,
                                         point_meta_min=point_min,
                                         point_meta_max=point_max)
            # Print each cell's (rounded) value on top of it.
            current_plot.axis.options += (
                r'nodes near coords={\pgfmathprintnumber\pgfplotspointmeta}',
                r'every node near coord/.append style={xshift=0pt,yshift=-7pt, black, font=\footnotesize}',
            )

            current_plot.axis.kwoptions[
                'colorbar_style'] = '{{/pgf/number format/fixed zerofill, /pgf/number format/precision=1}}'

            current_plot.title = titles[language]
            current_plot.plot_name += '_en_{}'.format(language)

            if i > 0:
                plot.tikzpicture += current_plot
Example #2
0
    def plot_all_to_latex(self, sec, mean_metrics, std_metrics, caption,
                          file_path, file_name):
        """Plot mean accuracy with a ±1.96·std band, one panel per language.

        The first panel (``plot0``) is registered in ``sec`` as a float
        figure. The remaining ones are appended into its tikzpicture and
        placed one after another to the right (``kwargs_per_plot``), forming
        a single row.

        Args:
            sec: Section object; ``sec.new(...)`` adds the figure to it.
            mean_metrics: ``mean_metrics['accuracies'][language]`` maps a
                parameter value (anything ``numpy`` can cast to ``int``,
                e.g. ``'5'``) to a mean accuracy.
            std_metrics: Same layout as ``mean_metrics`` but holding standard
                deviations. It must contain every key present in the means.
            caption: Caption of the whole figure.
            file_path: Output directory passed to every ``Plot``.
            file_name: Base name and label of the figure. Each panel's
                ``plot_name`` gets an ``_en_<language>`` suffix.

        Raises:
            KeyError: If a std entry is missing for some mean key, if there
                are more than four languages, or if a language is not one of
                ``de``/``es``/``fi``/``it``.
        """
        plot = sec.new(
            Plot(plot_name=file_name,
                 plot_path=file_path,
                 position='th!',
                 width=r'.27\textwidth',
                 height=r'.25\textwidth',
                 label=file_name,
                 name='plot0',
                 xshift=r'-.115\textwidth'))
        plot.caption = caption

        # Each panel is anchored to the south-east corner of the previous one.
        kwargs_per_plot = {
            1: {
                'as_float_env': False,
                'at': '(plot0.south east)',
                'anchor': 'south west',
                'xshift': r'-.04\textwidth',
                'name': 'plot1'
            },
            2: {
                'as_float_env': False,
                'at': '(plot1.south east)',
                'anchor': 'south west',
                'xshift': r'0.035\textwidth',
                'name': 'plot2'
            },
            3: {
                'as_float_env': False,
                'at': '(plot2.south east)',
                'anchor': 'south west',
                'xshift': r'.11\textwidth',
                'name': 'plot3'
            }
        }
        # NOTE(review): 'English-Deutsch' uses the German name while the other
        # titles use English names ('English-German' would be consistent).
        # Kept unchanged to preserve the rendered output.
        titles = {
            'de': 'English-Deutsch',
            'it': 'English-Italian',
            'fi': 'English-Finnish',
            'es': 'English-Spanish'
        }

        for i, language in enumerate(sorted(mean_metrics['accuracies'])):
            if i == 0:
                current_plot = plot
            else:
                current_plot = Plot(plot_name=file_name,
                                    plot_path=file_path,
                                    width=r'.27\textwidth',
                                    height=r'.25\textwidth',
                                    **kwargs_per_plot[i])
                # This panel is appended into plot0's tikzpicture at the end of
                # the iteration, so its own tikzpicture head/tail are blanked.
                current_plot.tikzpicture.head = ''
                current_plot.tikzpicture.tail = ''

            means = mean_metrics['accuracies'][language]
            stds = std_metrics['accuracies'][language]
            keys = list(means)
            x = np.array(keys).astype(int)
            y = np.array([means[k] for k in keys]).astype(float)
            # Fetch each std by the same key as its mean. Relying on both dicts
            # sharing an iteration order would silently misalign the band.
            std = np.array([stds[k] for k in keys]).astype(float)
            sorting = x.argsort()
            x, y, std = x[sorting], y[sorting], std[sorting]

            upper = y + 1.96 * std
            lower = y - 1.96 * std

            current_plot.add_plot(x,
                                  y,
                                  'blue',
                                  'ylabel near ticks',
                                  mark='*',
                                  line_width='1.2pt',
                                  mark_size='.9pt')
            # Invisible boundary curves, named so 'fill between' can shade the
            # area between them.
            current_plot.add_plot(x,
                                  upper,
                                  name_path='upper',
                                  draw='none')
            current_plot.add_plot(x,
                                  lower,
                                  name_path='lower',
                                  draw='none')
            current_plot.axis.append(
                '\\addplot[fill=blue!10] fill between[of=upper and lower];')
            current_plot.axis.kwoptions[
                'y tick label style'] = '{/pgf/number format/fixed zerofill, /pgf/number format/precision=1}'

            current_plot.x_min = np.floor(x.min())
            current_plot.x_max = np.ceil(x.max())
            # Pad the y range beyond the band by half the spread of the means.
            delta = y.max() - y.min()
            current_plot.y_min = lower.min() - delta / 2
            current_plot.y_max = upper.max() + delta / 2

            current_plot.title = titles[language]
            current_plot.plot_name += '_en_{}'.format(language)

            if i > 0:
                plot.tikzpicture += current_plot