Beispiel #1
0
    def test_plot_with_kwargs(self):
        '''1.3 with kwargs'''
        # Overlay a fake attention map on the SciPy "face" sample image twice:
        # (a) with _overlay_attention's own defaults, and (b) with cmap/alpha
        # forwarded as extra keyword arguments through tfplot.plot. Each
        # rendered image is compared against a known hash.

        # scipy.misc.face was deprecated in SciPy 1.10 and removed in 1.12;
        # scipy.datasets.face returns the same image (it needs the optional
        # `pooch` package). Prefer the old location so older environments
        # behave exactly as before.
        try:
            load_face = scipy.misc.face
        except AttributeError:
            from scipy import datasets
            load_face = datasets.face

        attention_tensor = fake_attention()
        image_tensor = tf.constant(load_face())

        # (a) default execution
        plot_op = tfplot.plot(_overlay_attention, [attention_tensor, image_tensor])
        r = self._execute_plot_op(plot_op, print_image=True)
        self.assertEqual(test_util.hash_image(r), 'c2d64dedd4aa54218e6df95bfeb03bbc17bd17fa')

        # (b) override cmap and alpha
        plot_op = tfplot.plot(_overlay_attention, [attention_tensor, image_tensor],
                              cmap='gray', alpha=0.8)
        r = self._execute_plot_op(plot_op, print_image=True)
        self.assertEqual(test_util.hash_image(r), '31c8029aed7bbafe37bb8c451a3220d573d2d0e0')
Beispiel #2
0
 def test_autowrap_call_extrakwargs(self):
     '''when calling autowrap to wrap a seaborn plot function,
     additional (non-tfplot) kwargs such as cmap/cbar/xticklabels should be
     bound as default arguments of the wrapped plot function's invocation.'''
     # figsize/tight_layout are figure options; the remaining kwargs are
     # expected to be forwarded to sns.heatmap itself.
     tf_heatmap = tfplot.autowrap(sns.heatmap,
                                  figsize=(2, 2),
                                  tight_layout=True,
                                  cmap='jet',
                                  cbar=False,
                                  xticklabels=False,
                                  yticklabels=False)
     op = tf_heatmap(tf.constant(np.eye(5)))
     r = self._execute_plot_op(op)
     # assertEquals is a deprecated alias removed in Python 3.12; use assertEqual.
     self.assertEqual(test_util.hash_image(r),
                      '528047f739fe6dc4ba4ec1738b3a44b5bc95ecff')
Beispiel #3
0
    def test_summary_plot(self):
        '''tests tfplot.summary.plot'''
        # Build an image summary from a 3x2-inch figure with centered text,
        # then check its tag, pixel size and the hash of the encoded image.

        def test_figure(text):
            fig, ax = tfplot.subplots(figsize=(3, 2))
            ax.text(0.5, 0.5, text, ha='center')
            return fig

        summary = self._execute_summary_op(
            tfplot.summary.plot("text/hello", test_figure, ["Hello Summary"]))

        # pylint: disable=no-member
        first_value = summary.value[0]
        self.assertTrue(first_value.tag.startswith('text/hello'))
        image = first_value.image
        # 3x2 inches at the default dpi of 100 -> 300x200 pixels
        self.assertEqual((image.width, image.height), (300, 200))
        encoded = image.encoded_image_string
        # pylint: enable=no-member

        # Preview the image inline on macOS terminals only.
        if sys.platform == 'darwin':
            imgcat(encoded)
        self.assertEqual(test_util.hash_image(encoded), 'dbb47a3281626678894084fa58066f69a2570df4')