def test_title(self):
    """Compare ``plot_history`` given a single title string against ``plot_metric``.

    For each metric, the figure returned by ``plot_history(..., titles=title)``
    must render (via ``self._to_image``) to the same image as a figure drawn
    with ``plot_metric(..., title=title)``, and to a different image than a
    figure drawn with ``plot_metric`` and no title.
    """
    title = 'My Title'
    history = PlotHistoryTest.HISTORY

    def draw_metric(metric, **extra):
        # One fresh figure per metric, drawn directly with plot_metric.
        figure, axis = plt.subplots()
        plot_metric(history, metric, ax=axis, **extra)
        return figure

    history_figs, _ = plot_history(history, titles=title, close=False, show=False)

    # Reference figures: first the untitled ones, then the titled ones.
    untitled_figs = [draw_metric(metric) for metric in PlotHistoryTest.METRICS]
    titled_figs = [draw_metric(metric, title=title) for metric in PlotHistoryTest.METRICS]

    for history_fig, untitled_fig, titled_fig in zip(history_figs, untitled_figs, titled_figs):
        # Same title -> identical rendering.
        self.assertEqual(self._to_image(history_fig), self._to_image(titled_fig))
        # Missing title -> rendering must differ.
        self.assertNotEqual(self._to_image(history_fig), self._to_image(untitled_fig))
def test_basic(self):
    """Check the count and types of the figures and axes returned by ``plot_history``.

    Both returned sequences must have ``PlotHistoryTest.NUM_METRIC_PLOTS``
    entries, and each paired entry must be a ``Figure`` and an ``Axes``.
    """
    returned_figs, returned_axes = plot_history(PlotHistoryTest.HISTORY)
    expected_count = PlotHistoryTest.NUM_METRIC_PLOTS

    self.assertEqual(len(returned_figs), expected_count)
    self.assertEqual(len(returned_axes), expected_count)

    for figure, axis in zip(returned_figs, returned_axes):
        self.assertIsInstance(figure, Figure)
        self.assertIsInstance(axis, Axes)
def test_save(self):
    """Check that ``plot_history(save=True, ...)`` writes one PNG and one PDF per metric.

    The figures are saved into a temporary directory using the filename
    template ``'test_{metric}'`` and the extensions ``('png', 'pdf')``. The
    test asserts that every expected file exists and that each saved PNG
    equals the image obtained from the corresponding returned figure via
    ``self._to_image``.

    The temporary directory and every opened image file are managed with
    context managers so they are released even when an assertion fails.
    """
    filename = 'test_{metric}'

    # The directory is removed on exit from the ``with`` block instead of
    # being left to the TemporaryDirectory finalizer.
    with TemporaryDirectory() as path:
        final_png_filenames = [
            os.path.join(path, f'{filename.format(metric=metric)}.png')
            for metric in PlotHistoryTest.METRICS
        ]
        final_pdf_filenames = [
            os.path.join(path, f'{filename.format(metric=metric)}.pdf')
            for metric in PlotHistoryTest.METRICS
        ]

        figs, _ = plot_history(
            PlotHistoryTest.HISTORY,
            close=False,
            show=False,
            save=True,
            save_directory=path,
            save_filename_template=filename,
            save_extensions=('png', 'pdf'),
        )

        for png_filename, pdf_filename in zip(final_png_filenames, final_pdf_filenames):
            self.assertTrue(os.path.isfile(png_filename))
            self.assertTrue(os.path.isfile(pdf_filename))

        for png_filename, fig in zip(final_png_filenames, figs):
            # Close each saved image before the directory is deleted; an
            # open handle can prevent removal on some platforms.
            with Image.open(png_filename) as saved_image:
                self.assertEqual(saved_image, self._to_image(fig))
def test_labels(self):
    """Compare ``plot_history`` given per-metric labels against ``plot_metric``.

    For each metric, the figure returned by ``plot_history(..., labels=labels)``
    must render (via ``self._to_image``) to the same image as a figure drawn
    with ``plot_metric(..., label=label)`` using the matching label, and to a
    different image than a figure drawn with ``plot_metric`` and no label.
    """
    labels = ['Time', 'Loss', 'Accuracy']
    history = PlotHistoryTest.HISTORY

    def draw_metric(metric, **extra):
        # One fresh figure per metric, drawn directly with plot_metric.
        figure, axis = plt.subplots()
        plot_metric(history, metric, ax=axis, **extra)
        return figure

    history_figs, _ = plot_history(history, labels=labels, close=False, show=False)

    # Reference figures: first the unlabelled ones, then the labelled ones.
    unlabelled_figs = [draw_metric(metric) for metric in PlotHistoryTest.METRICS]
    labelled_figs = [
        draw_metric(metric, label=label)
        for metric, label in zip(PlotHistoryTest.METRICS, labels)
    ]

    for history_fig, unlabelled_fig, labelled_fig in zip(history_figs, unlabelled_figs, labelled_figs):
        # Same label -> identical rendering.
        self.assertEqual(self._to_image(history_fig), self._to_image(labelled_fig))
        # Missing label -> rendering must differ.
        self.assertNotEqual(self._to_image(history_fig), self._to_image(unlabelled_fig))
def test_different_titles(self):
    """Compare ``plot_history`` given one title per metric against ``plot_metric``.

    For each metric, the figure returned by ``plot_history(..., titles=titles)``
    must render (via ``self._to_image``) to the same image as a figure drawn
    with ``plot_metric(..., title=title)`` using the matching title, and to a
    different image than a figure drawn with ``plot_metric`` and no title.
    """
    titles = ['Time', 'Loss', 'Accuracy']
    history = PlotHistoryTest.HISTORY

    def draw_metric(metric, **extra):
        # One fresh figure per metric, drawn directly with plot_metric.
        figure, axis = plt.subplots()
        plot_metric(history, metric, ax=axis, **extra)
        return figure

    history_figs, _ = plot_history(history, titles=titles, close=False, show=False)

    # Reference figures: first the untitled ones, then the titled ones.
    untitled_figs = [draw_metric(metric) for metric in PlotHistoryTest.METRICS]
    titled_figs = [
        draw_metric(metric, title=metric_title)
        for metric, metric_title in zip(PlotHistoryTest.METRICS, titles)
    ]

    for history_fig, untitled_fig, titled_fig in zip(history_figs, untitled_figs, titled_figs):
        # Same title -> identical rendering.
        self.assertEqual(self._to_image(history_fig), self._to_image(titled_fig))
        # Missing title -> rendering must differ.
        self.assertNotEqual(self._to_image(history_fig), self._to_image(untitled_fig))
def test_all_different(self):
    """Check that no two figures returned by ``plot_history`` render to equal images.

    Every figure is converted with ``self._to_image`` and each unordered pair
    of the resulting images is asserted to be unequal.
    """
    history_figs, _ = plot_history(PlotHistoryTest.HISTORY, close=False, show=False)
    rendered = [self._to_image(figure) for figure in history_figs]

    for position, current in enumerate(rendered):
        # Compare only against later images so each pair is checked once.
        for later in rendered[position + 1:]:
            self.assertNotEqual(current, later)
def test_with_provided_axes(self):
    """Check ``plot_history`` when the caller supplies the axes to draw on.

    With ``axes=provided_axes`` the call must return those same axes and an
    empty collection of figures, and each caller-created figure must render
    (via ``self._to_image``) to the same image as the corresponding figure
    from a ``plot_history`` call that creates its own figures.
    """
    subplot_pairs = [plt.subplots() for _ in range(PlotHistoryTest.NUM_METRIC_PLOTS)]
    # Transpose [(fig, ax), ...] into a tuple of figures and a tuple of axes.
    provided_figs, provided_axes = zip(*subplot_pairs)

    ret_figs, ret_axes = plot_history(
        PlotHistoryTest.HISTORY,
        close=False,
        show=False,
        axes=provided_axes,
    )
    fresh_figs, _ = plot_history(PlotHistoryTest.HISTORY, close=False, show=False)

    self.assertEqual(ret_axes, provided_axes)
    self.assertEqual(len(ret_figs), 0)

    for supplied_fig, fresh_fig in zip(provided_figs, fresh_figs):
        self.assertEqual(self._to_image(supplied_fig), self._to_image(fresh_fig))
def test_compare_plot_history_plot_metric(self):
    """Check that ``plot_history`` and per-metric ``plot_metric`` calls render alike.

    Each figure returned by ``plot_history`` must render (via
    ``self._to_image``) to the same image as a figure drawn by calling
    ``plot_metric`` for the corresponding entry of ``PlotHistoryTest.METRICS``.
    """
    history = PlotHistoryTest.HISTORY

    def draw_metric(metric):
        # One fresh figure per metric, drawn directly with plot_metric.
        figure, axis = plt.subplots()
        plot_metric(history, metric, ax=axis)
        return figure

    history_figs, _ = plot_history(history, close=False, show=False)
    metric_figs = [draw_metric(metric) for metric in PlotHistoryTest.METRICS]

    for history_fig, metric_fig in zip(history_figs, metric_figs):
        self.assertEqual(self._to_image(history_fig), self._to_image(metric_fig))