def heatmap_all_to_latex(self, experiment, sec, mean_metrics, x_label, y_label,
                         caption, file_path, file_name):
    """Draw one accuracy heatmap per language pair, laid out as a 2x2 grid.

    The first language (in sorted order) becomes the figure registered in
    ``sec``; it carries the caption and the label.  Every following language
    is drawn as a non-float plot, positioned relative to an earlier plot and
    appended to the first plot's tikzpicture.  Only the bottom row (plots 2
    and 3) gets the x axis label and x tick labels.  Only the left column
    (plots 0 and 2) gets the y axis label and y tick labels.  Each cell is
    annotated with its value.

    Args:
        experiment: Object whose ``CHANGING_PARAMS[x_label]`` and
            ``CHANGING_PARAMS[y_label]`` hold the swept parameter values.
        sec: Document section; ``sec.new(...)`` registers the main plot.
        mean_metrics: Mapping whose ``'accuracies'`` entry maps each language
            code to ``{(str(y_value), str(x_value)): accuracy}``.
        x_label: Key of the parameter whose values label the x ticks.
        y_label: Key of the parameter whose values label the y ticks.
        caption: Caption of the whole figure.
        file_path: Passed to every ``Plot`` as ``plot_path``.
        file_name: Base plot name (suffixed with ``_en_<language>`` per plot)
            and label of the figure.

    Note:
        The displayed axis titles are fixed to ``$p_{factor}$`` (x) and
        ``$p_0$`` (y) regardless of ``x_label``/``y_label``.  Positions are
        defined for at most four languages; a fifth raises ``KeyError``.
        Language codes without a known title are shown as ``English-<code>``.
    """
    plot = sec.new(
        Plot(plot_name=file_name, plot_path=file_path, grid=False, lines=False,
             enlargelimits='false', width=r'.42\textwidth',
             height=r'.42\textwidth', position='th!', label=file_name,
             name='plot0'))
    plot.caption = caption

    # Placement of plots 1-3 relative to plot0: plot1 to the right of plot0,
    # plot2 below plot0, plot3 to the right of plot2 (a 2x2 grid).
    kwargs_per_plot = {
        1: {
            'as_float_env': False,
            'at': '(plot0.south east)',
            'anchor': 'south west',
            'xshift': r'.12\textwidth',
            'name': 'plot1'
        },
        2: {
            'as_float_env': False,
            'at': '(plot0.south west)',
            'anchor': 'north west',
            'yshift': r'-0.07\textwidth',
            'name': 'plot2'
        },
        3: {
            'as_float_env': False,
            'at': '(plot2.south east)',
            'anchor': 'south west',
            'xshift': r'.12\textwidth',
            'name': 'plot3'
        }
    }
    titles = {
        'de': 'English-German',
        'it': 'English-Italian',
        'fi': 'English-Finnish',
        'es': 'English-Spanish'
    }

    languages = sorted(mean_metrics['accuracies'])
    if not languages:
        # Nothing to draw; the (empty) figure has already been added to sec.
        return

    # The parameter grid is identical for every language, so compute it once.
    # Parameter values are only used as tick labels and as lookup keys; the
    # matrix itself is drawn on integer cell positions.
    x_params = sorted(experiment.CHANGING_PARAMS[x_label])
    y_params = sorted(experiment.CHANGING_PARAMS[y_label])
    n_x, n_y = len(x_params), len(y_params)

    for i, language in enumerate(languages):
        if i == 0:
            current_plot = plot
        else:
            current_plot = Plot(plot_name=file_name, plot_path=file_path,
                                grid=False, lines=False,
                                enlargelimits='false',
                                width=r'.42\textwidth',
                                height=r'.42\textwidth', position='th!',
                                **kwargs_per_plot[i])
            # These plots are appended into plot0's tikzpicture below, so
            # their own head/tail (presumably the tikzpicture begin/end
            # wrappers) are blanked out.
            current_plot.tikzpicture.head = ''
            current_plot.tikzpicture.tail = ''

        # Rows follow x_params and columns follow y_params.  Note the
        # accuracy dict is keyed (y, x), as strings.
        accuracies = mean_metrics['accuracies'][language]
        z = np.zeros((n_x, n_y), dtype=float)
        for x_idx, x_value in enumerate(x_params):
            for y_idx, y_value in enumerate(y_params):
                z[x_idx, y_idx] = float(
                    accuracies[(str(y_value), str(x_value))])
        z = np.around(z, 2)

        if i >= 2:  # bottom row of the grid
            current_plot.x_label = r'$p_{factor}$'
            current_plot.x_ticks_labels = [
                '{:.1f}'.format(x) for x in x_params
            ]
        else:
            current_plot.x_ticks_labels = [r'\empty']
        if i % 2 == 0:  # left column of the grid
            current_plot.y_label = r'$p_0$'
            current_plot.y_ticks_labels = [
                '{:.2f}'.format(y) for y in y_params
            ]
        else:
            current_plot.y_ticks_labels = [r'\empty']

        x_ticks = list(range(n_x))
        y_ticks = list(range(n_y))
        current_plot.x_ticks = x_ticks
        current_plot.y_ticks = y_ticks

        # Pad the colour range by half the data spread on each side.
        z_min, z_max = z.min(), z.max()
        delta = z_max - z_min
        if delta == 0:
            # A constant matrix would give point meta min == point meta max,
            # i.e. an empty colour-map range; fall back to a unit margin.
            delta = 1.0
        current_plot.add_matrix_plot(x_ticks, y_ticks, z,
                                     point_meta_min=z_min - delta / 2,
                                     point_meta_max=z_max + delta / 2)
        current_plot.axis.options += (
            r'nodes near coords={\pgfmathprintnumber\pgfplotspointmeta}',
            r'every node near coord/.append style={xshift=0pt,yshift=-7pt, black, font=\footnotesize}',
        )
        current_plot.axis.kwoptions[
            'colorbar_style'] = '{{/pgf/number format/fixed zerofill, /pgf/number format/precision=1}}'
        current_plot.title = titles.get(language,
                                        'English-{}'.format(language))
        current_plot.plot_name += '_en_{}'.format(language)
        if i > 0:
            plot.tikzpicture += current_plot
def plot_all_to_latex(self, sec, mean_metrics, std_metrics, caption, file_path,
                      file_name):
    """Draw one accuracy curve per language pair, side by side in one figure.

    Each curve is drawn with a shaded band between ``mean - 1.96 * std`` and
    ``mean + 1.96 * std``.  The first language (in sorted order) becomes the
    figure registered in ``sec``; it carries the caption and the label.  The
    following languages (at most three more) are drawn as non-float plots,
    each anchored to the right of the previous one.  They are appended to the
    first plot's tikzpicture.

    Args:
        sec: Document section; ``sec.new(...)`` registers the main plot.
        mean_metrics: Mapping whose ``'accuracies'`` entry maps each language
            code to ``{x_key: mean_accuracy}``.  The keys must be convertible
            to ``int`` with numpy.
        std_metrics: Same structure as ``mean_metrics``, holding the standard
            deviations.  They are matched to the means by key.
        caption: Caption of the whole figure.
        file_path: Passed to every ``Plot`` as ``plot_path``.
        file_name: Base plot name (suffixed with ``_en_<language>`` per plot)
            and label of the figure.

    Note:
        Positions are defined for at most four languages; a fifth raises
        ``KeyError``.  Language codes without a known title are shown as
        ``English-<code>``.
    """
    plot = sec.new(
        Plot(plot_name=file_name, plot_path=file_path, position='th!',
             width=r'.27\textwidth', height=r'.25\textwidth',
             label=file_name, name='plot0', xshift=r'-.115\textwidth'))
    plot.caption = caption

    # Plots 1-3 are chained left to right, each anchored to the previous one.
    kwargs_per_plot = {
        1: {
            'as_float_env': False,
            'at': '(plot0.south east)',
            'anchor': 'south west',
            'xshift': r'-.04\textwidth',
            'name': 'plot1'
        },
        2: {
            'as_float_env': False,
            'at': '(plot1.south east)',
            'anchor': 'south west',
            'xshift': r'0.035\textwidth',
            'name': 'plot2'
        },
        3: {
            'as_float_env': False,
            'at': '(plot2.south east)',
            'anchor': 'south west',
            'xshift': r'.11\textwidth',
            'name': 'plot3'
        }
    }
    titles = {
        'de': 'English-German',
        'it': 'English-Italian',
        'fi': 'English-Finnish',
        'es': 'English-Spanish'
    }

    mean_accuracies = mean_metrics['accuracies']
    for i, language in enumerate(sorted(mean_accuracies)):
        if i == 0:
            current_plot = plot
        else:
            current_plot = Plot(plot_name=file_name, plot_path=file_path,
                                width=r'.27\textwidth',
                                height=r'.25\textwidth',
                                **kwargs_per_plot[i])
            # These plots are appended into plot0's tikzpicture below, so
            # their own head/tail (presumably the tikzpicture begin/end
            # wrappers) are blanked out.
            current_plot.tikzpicture.head = ''
            current_plot.tikzpicture.tail = ''

        means = mean_accuracies[language]
        stds = std_metrics['accuracies'][language]
        keys = list(means)
        x = np.array(keys).astype(int)
        y = np.array([means[key] for key in keys]).astype(float)
        # Look each std up by the same key as its mean so the two stay
        # aligned even if the std dict was filled in a different order.
        std = np.array([stds[key] for key in keys]).astype(float)
        order = x.argsort()
        x, y, std = x[order], y[order], std[order]

        upper = y + 1.96 * std
        lower = y - 1.96 * std
        current_plot.add_plot(x, y, 'blue', 'ylabel near ticks', mark='*',
                              line_width='1.2pt', mark_size='.9pt')
        # Invisible boundary curves, named so the fill below can reference
        # them.
        current_plot.add_plot(x, upper, name_path='upper', draw='none')
        current_plot.add_plot(x, lower, name_path='lower', draw='none')
        current_plot.axis.append(
            '\\addplot[fill=blue!10] fill between[of=upper and lower];')
        current_plot.axis.kwoptions[
            'y tick label style'] = '{/pgf/number format/fixed zerofill, /pgf/number format/precision=1}'

        x_min, x_max = np.floor(x.min()), np.ceil(x.max())
        if x_min == x_max:
            # A single abscissa would give an empty x range; widen it.
            x_min, x_max = x_min - 1, x_max + 1
        current_plot.x_min = x_min
        current_plot.x_max = x_max

        # Cover the whole band, plus half the spread of the means on each
        # side.
        delta = y.max() - y.min()
        y_min = lower.min() - delta / 2
        y_max = upper.max() + delta / 2
        if y_min == y_max:
            # Constant means with zero std would give an empty y range.
            y_min, y_max = y_min - 0.5, y_max + 0.5
        current_plot.y_min = y_min
        current_plot.y_max = y_max

        current_plot.title = titles.get(language,
                                        'English-{}'.format(language))
        current_plot.plot_name += '_en_{}'.format(language)
        if i > 0:
            plot.tikzpicture += current_plot