def test_layer_range_value_fail(self, layer_range):
    """Check that an invalid `layer_range` is rejected with ValueError.

    Both `vis_utils.model_to_dot` and `vis_utils.plot_model` are called on an
    EfficientNetB0 model with the parameterized `layer_range`, and each call
    must raise ValueError.

    If either call raises ImportError instead (presumably a missing optional
    plotting dependency such as pydot/graphviz — TODO confirm), the test is
    reported as skipped rather than silently passing, since nothing was
    verified in that case.
    """
    model = efficientnet.EfficientNetB0(weights=None)
    try:
        # An ImportError raised inside assertRaises(ValueError) is not
        # swallowed by the context manager; it propagates to the handler below.
        with self.assertRaises(ValueError):
            vis_utils.model_to_dot(model, layer_range=layer_range)
        with self.assertRaises(ValueError):
            vis_utils.plot_model(model, layer_range=layer_range)
    except ImportError as err:
        self.skipTest('vis_utils plotting dependency unavailable: %s' % err)
def test_dot_layer_range(self, layer_range):
    """Check that `model_to_dot` honours `layer_range`.

    The layer ids the helper `get_layer_ids_from_model` selects for
    `layer_range` must match, as an unordered collection, the layer ids
    recovered from the edges of the dot graph that `vis_utils.model_to_dot`
    builds.

    If `model_to_dot` raises ImportError (presumably a missing pydot/graphviz
    install — TODO confirm), the test is skipped rather than reported as
    passing.
    """
    model = efficientnet.EfficientNetB0(weights=None)
    layer_ids_from_model = get_layer_ids_from_model(model, layer_range)

    # Only the graph construction depends on the optional dependency; keep the
    # try narrow so an ImportError elsewhere is not mistaken for a skip.
    try:
        dot = vis_utils.model_to_dot(model, layer_range=layer_range)
    except ImportError as err:
        self.skipTest('vis_utils plotting dependency unavailable: %s' % err)

    layer_ids_from_dot = get_layer_ids_from_dot(dot.get_edges())
    # Order is irrelevant: compare both id collections sorted.
    self.assertAllEqual(
        sorted(layer_ids_from_model), sorted(layer_ids_from_dot))
def test_plot_layer_range(self, layer_range):
    """Check that `plot_model` with `layer_range` writes an image file.

    The plot is written to `model_effnet.png` inside a fresh temporary
    directory instead of the current working directory. The working directory
    may be read-only, and every parameterized case would otherwise share the
    same output path. The directory and its contents are removed when the
    test finishes, whatever the outcome.

    If `plot_model` raises ImportError (presumably a missing pydot/graphviz
    install — TODO confirm), the test is skipped rather than reported as
    passing.
    """
    import os
    import tempfile

    model = efficientnet.EfficientNetB0(weights=None)
    with tempfile.TemporaryDirectory() as tmp_dir:
        effnet_subplot = os.path.join(tmp_dir, 'model_effnet.png')
        try:
            vis_utils.plot_model(
                model, to_file=effnet_subplot, layer_range=layer_range)
        except ImportError as err:
            # SkipTest propagates out of the `with`, so the temp dir is
            # still cleaned up.
            self.skipTest(
                'vis_utils plotting dependency unavailable: %s' % err)
        self.assertTrue(tf.io.gfile.exists(effnet_subplot))