Beispiel #1
0
    def test_title(self):
        """Check that a single ``titles`` string is applied to every figure.

        Each figure returned by ``plot_history(..., titles=title)`` must render
        identically to ``plot_metric`` called with that same title, and must
        differ from ``plot_metric`` called without a title.
        """
        # Figures are created with close=False / plt.subplots(); make sure they
        # are released when the test ends, even if an assertion fails.
        self.addCleanup(plt.close, 'all')

        title = 'My Title'
        plot_history_figs, _ = plot_history(PlotHistoryTest.HISTORY,
                                            titles=title,
                                            close=False,
                                            show=False)
        # zip() below stops at its shortest input, so without this check an
        # empty or short result would make every comparison vacuous.
        self.assertEqual(len(plot_history_figs), len(PlotHistoryTest.METRICS))

        plot_metrics_figs_no_title = []
        for metric in PlotHistoryTest.METRICS:
            fig, ax = plt.subplots()
            plot_metric(PlotHistoryTest.HISTORY, metric, ax=ax)
            plot_metrics_figs_no_title.append(fig)

        plot_metrics_figs_with_title = []
        for metric in PlotHistoryTest.METRICS:
            fig, ax = plt.subplots()
            plot_metric(PlotHistoryTest.HISTORY, metric, title=title, ax=ax)
            plot_metrics_figs_with_title.append(fig)

        for (
                plot_history_fig,
                plot_metrics_fig_no_title,
                plot_metrics_fig_with_title,
        ) in zip(plot_history_figs, plot_metrics_figs_no_title,
                 plot_metrics_figs_with_title):
            self.assertEqual(self._to_image(plot_history_fig),
                             self._to_image(plot_metrics_fig_with_title))
            self.assertNotEqual(self._to_image(plot_history_fig),
                                self._to_image(plot_metrics_fig_no_title))
Beispiel #2
0
 def test_basic(self):
     """Check the default call returns one Figure and one Axes per metric plot.

     Only the history is passed to ``plot_history``; both returned sequences
     must have ``NUM_METRIC_PLOTS`` entries of the expected matplotlib types.
     """
     expected_count = PlotHistoryTest.NUM_METRIC_PLOTS
     figures, axes_list = plot_history(PlotHistoryTest.HISTORY)

     self.assertEqual(len(figures), expected_count)
     self.assertEqual(len(axes_list), expected_count)

     paired = list(zip(figures, axes_list))
     for figure, axis in paired:
         self.assertIsInstance(figure, Figure)
         self.assertIsInstance(axis, Axes)
Beispiel #3
0
    def test_save(self):
        """Check that plot_history saves one PNG and one PDF per metric.

        Files are written into a temporary directory using
        ``save_filename_template``.  Each saved PNG must equal the image of the
        corresponding returned figure.  The temporary directory and every
        opened image are closed deterministically.
        """
        self.addCleanup(plt.close, 'all')
        filename_template = 'test_{metric}'
        # The context manager removes the directory even if an assertion
        # fails, instead of relying on garbage collection.
        with TemporaryDirectory() as path:
            base_names = [
                os.path.join(path, filename_template.format(metric=metric))
                for metric in PlotHistoryTest.METRICS
            ]
            final_png_filenames = [f'{name}.png' for name in base_names]
            final_pdf_filenames = [f'{name}.pdf' for name in base_names]

            figs, _ = plot_history(
                PlotHistoryTest.HISTORY,
                close=False,
                show=False,
                save=True,
                save_directory=path,
                save_filename_template=filename_template,
                save_extensions=('png', 'pdf'),
            )
            # zip() truncates silently; ensure every expected file is compared.
            self.assertEqual(len(figs), len(final_png_filenames))

            for png_filename, pdf_filename in zip(final_png_filenames,
                                                  final_pdf_filenames):
                self.assertTrue(os.path.isfile(png_filename))
                self.assertTrue(os.path.isfile(pdf_filename))

            # Image.open is lazy and keeps the file open until closed; the
            # comparison happens inside the ``with`` so pixels are loaded
            # before the handle is released.
            for png_filename, fig in zip(final_png_filenames, figs):
                with Image.open(png_filename) as saved_image:
                    self.assertEqual(saved_image, self._to_image(fig))
Beispiel #4
0
    def test_labels(self):
        """Check that per-metric ``labels`` are forwarded to each plot.

        Each figure from ``plot_history(..., labels=labels)`` must render
        identically to ``plot_metric`` called with the matching ``label``, and
        must differ from ``plot_metric`` called without a label.
        """
        # Release every figure created below once the test finishes.
        self.addCleanup(plt.close, 'all')

        labels = ['Time', 'Loss', 'Accuracy']
        plot_history_figs, _ = plot_history(PlotHistoryTest.HISTORY,
                                            labels=labels,
                                            close=False,
                                            show=False)
        # Guard against a vacuous pass: zip() below would silently compare
        # fewer (or zero) figures if plot_history returned too few.
        self.assertEqual(len(plot_history_figs), len(PlotHistoryTest.METRICS))

        plot_metrics_figs_no_labels = []
        for metric in PlotHistoryTest.METRICS:
            fig, ax = plt.subplots()
            plot_metric(PlotHistoryTest.HISTORY, metric, ax=ax)
            plot_metrics_figs_no_labels.append(fig)

        plot_metrics_figs_with_labels = []
        for metric, label in zip(PlotHistoryTest.METRICS, labels):
            fig, ax = plt.subplots()
            plot_metric(PlotHistoryTest.HISTORY, metric, label=label, ax=ax)
            plot_metrics_figs_with_labels.append(fig)

        for (
                plot_history_fig,
                plot_metrics_fig_no_labels,
                plot_metrics_fig_with_labels,
        ) in zip(plot_history_figs, plot_metrics_figs_no_labels,
                 plot_metrics_figs_with_labels):
            self.assertEqual(self._to_image(plot_history_fig),
                             self._to_image(plot_metrics_fig_with_labels))
            self.assertNotEqual(self._to_image(plot_history_fig),
                                self._to_image(plot_metrics_fig_no_labels))
Beispiel #5
0
    def test_different_titles(self):
        """Check that a list of ``titles`` assigns one title per metric.

        Each figure from ``plot_history(..., titles=titles)`` must render
        identically to ``plot_metric`` called with the matching title, and
        must differ from ``plot_metric`` called without a title.
        """
        # Release every figure created below once the test finishes.
        self.addCleanup(plt.close, 'all')

        titles = ['Time', 'Loss', 'Accuracy']
        plot_history_figs, _ = plot_history(PlotHistoryTest.HISTORY,
                                            titles=titles,
                                            close=False,
                                            show=False)
        # Guard against a vacuous pass: zip() below would silently compare
        # fewer (or zero) figures if plot_history returned too few.
        self.assertEqual(len(plot_history_figs), len(PlotHistoryTest.METRICS))

        plot_metrics_figs_no_title = []
        for metric in PlotHistoryTest.METRICS:
            fig, ax = plt.subplots()
            plot_metric(PlotHistoryTest.HISTORY, metric, ax=ax)
            plot_metrics_figs_no_title.append(fig)

        plot_metrics_figs_with_title = []
        for metric, title in zip(PlotHistoryTest.METRICS, titles):
            fig, ax = plt.subplots()
            plot_metric(PlotHistoryTest.HISTORY, metric, title=title, ax=ax)
            plot_metrics_figs_with_title.append(fig)

        for (
                plot_history_fig,
                plot_metrics_fig_no_title,
                plot_metrics_fig_with_title,
        ) in zip(plot_history_figs, plot_metrics_figs_no_title,
                 plot_metrics_figs_with_title):
            self.assertEqual(self._to_image(plot_history_fig),
                             self._to_image(plot_metrics_fig_with_title))
            self.assertNotEqual(self._to_image(plot_history_fig),
                                self._to_image(plot_metrics_fig_no_title))
Beispiel #6
0
 def test_all_different(self):
     """Check that no two figures returned by plot_history render identically."""
     # Figures are created with close=False; release them when the test ends.
     self.addCleanup(plt.close, 'all')

     plot_history_figs, _ = plot_history(PlotHistoryTest.HISTORY,
                                         close=False,
                                         show=False)
     # With fewer than two figures the pairwise loop below makes no
     # comparisons at all, so pin the expected count first.
     self.assertEqual(len(plot_history_figs), PlotHistoryTest.NUM_METRIC_PLOTS)

     images = [self._to_image(fig) for fig in plot_history_figs]
     # Compare each unordered pair exactly once.
     for i, image in enumerate(images):
         for other_image in images[i + 1:]:
             self.assertNotEqual(image, other_image)
Beispiel #7
0
    def test_with_provided_axes(self):
        """Check that plot_history draws onto caller-supplied axes.

        When ``axes`` is passed, the returned axes must equal the provided
        ones and no figures may be returned.  The caller's figures must render
        identically to those plot_history creates when no axes are given.
        """
        # Release both the provided and the newly created figures afterwards.
        self.addCleanup(plt.close, 'all')

        provided_figs, provided_axes = zip(
            *(plt.subplots() for _ in range(PlotHistoryTest.NUM_METRIC_PLOTS)))
        ret_figs, ret_axes = plot_history(PlotHistoryTest.HISTORY,
                                          close=False,
                                          show=False,
                                          axes=provided_axes)

        new_figs, _ = plot_history(PlotHistoryTest.HISTORY,
                                   close=False,
                                   show=False)

        self.assertEqual(ret_axes, provided_axes)
        self.assertEqual(len(ret_figs), 0)
        # zip() below truncates silently; require one new figure per provided one.
        self.assertEqual(len(new_figs), len(provided_figs))

        for provided_fig, new_fig in zip(provided_figs, new_figs):
            self.assertEqual(self._to_image(provided_fig),
                             self._to_image(new_fig))
Beispiel #8
0
    def test_compare_plot_history_plot_metric(self):
        """Check that plot_history renders each metric exactly like plot_metric.

        With default options, every figure from ``plot_history`` must equal
        the figure produced by ``plot_metric`` for the same metric, in
        ``METRICS`` order.
        """
        # Release every figure created below once the test finishes.
        self.addCleanup(plt.close, 'all')

        plot_history_figs, _ = plot_history(PlotHistoryTest.HISTORY,
                                            close=False,
                                            show=False)
        # Guard against a vacuous pass: zip() below would silently compare
        # fewer (or zero) figures if plot_history returned too few.
        self.assertEqual(len(plot_history_figs), len(PlotHistoryTest.METRICS))

        plot_metrics_figs = []
        for metric in PlotHistoryTest.METRICS:
            fig, ax = plt.subplots()
            plot_metric(PlotHistoryTest.HISTORY, metric, ax=ax)
            plot_metrics_figs.append(fig)

        for plot_history_fig, plot_metrics_fig in zip(plot_history_figs,
                                                      plot_metrics_figs):
            self.assertEqual(self._to_image(plot_history_fig),
                             self._to_image(plot_metrics_fig))