コード例 #1
0
 def test_layer_range_value_fail(self, layer_range):
     """Check that an invalid ``layer_range`` is rejected with ``ValueError``.

     Both ``vis_utils.model_to_dot`` and ``vis_utils.plot_model`` must raise
     ``ValueError`` for the given ``layer_range``. If either call raises
     ``ImportError``, the test is reported as skipped rather than silently
     passed.

     Args:
       layer_range: The ``layer_range`` value under test; presumably supplied
         by a parameterization decorator outside this view (TODO confirm).
     """
     model = efficientnet.EfficientNetB0(weights=None)
     try:
         with self.assertRaises(ValueError):
             vis_utils.model_to_dot(model, layer_range=layer_range)
         with self.assertRaises(ValueError):
             vis_utils.plot_model(model, layer_range=layer_range)
     except ImportError as err:
         # assertRaises(ValueError) lets other exception types propagate, so an
         # ImportError (presumably a missing optional plotting dependency such
         # as pydot -- TODO confirm) ends up here. Report it as a skip: a bare
         # `pass` would make the test look green without checking anything.
         self.skipTest(
             'ImportError while building the model plot: {}'.format(err))
コード例 #2
0
 def test_dot_layer_range(self, layer_range):
     """Check that ``model_to_dot`` keeps exactly the layers in ``layer_range``.

     The layer ids computed by ``get_layer_ids_from_model`` for
     ``layer_range`` must match the ids recovered from the edges of the dot
     graph built by ``vis_utils.model_to_dot``, ignoring order. If building
     the dot graph raises ``ImportError``, the test is reported as skipped
     rather than silently passed.

     Args:
       layer_range: The ``layer_range`` value under test; presumably supplied
         by a parameterization decorator outside this view (TODO confirm).
     """
     model = efficientnet.EfficientNetB0(weights=None)
     layer_ids_from_model = get_layer_ids_from_model(model, layer_range)
     # Only the graph construction is guarded. An ImportError here is
     # presumably a missing optional plotting dependency (TODO confirm) and
     # means the comparison cannot run, so skip instead of passing silently.
     try:
         dot = vis_utils.model_to_dot(model, layer_range=layer_range)
     except ImportError as err:
         self.skipTest(
             'ImportError while building the model plot: {}'.format(err))
     layer_ids_from_dot = get_layer_ids_from_dot(dot.get_edges())
     # Sort both sides so the two helpers' ordering does not matter.
     self.assertAllEqual(sorted(layer_ids_from_model),
                         sorted(layer_ids_from_dot))
コード例 #3
0
 def test_plot_layer_range(self, layer_range):
   """Check that ``plot_model`` writes an image file for ``layer_range``.

   The plot is written into a fresh temporary directory, so parameterized
   runs never share an output path and nothing is left behind in the
   working directory. If ``plot_model`` raises ``ImportError``, the test is
   reported as skipped rather than silently passed.

   Args:
     layer_range: The ``layer_range`` value under test; presumably supplied
       by a parameterization decorator outside this view (TODO confirm).
   """
   import os
   import tempfile

   model = efficientnet.EfficientNetB0(weights=None)
   # TemporaryDirectory removes the directory and the plot on exit, including
   # when the assertion fails or the test is skipped. This replaces manual
   # cleanup of a fixed file name in the current working directory.
   with tempfile.TemporaryDirectory() as tmp_dir:
     effnet_subplot = os.path.join(tmp_dir, 'model_effnet.png')
     try:
       vis_utils.plot_model(
           model, to_file=effnet_subplot, layer_range=layer_range)
     except ImportError as err:
       # Presumably a missing optional plotting dependency (pydot/graphviz --
       # TODO confirm); report a skip instead of a silent pass.
       self.skipTest(
           'ImportError while building the model plot: {}'.format(err))
     self.assertTrue(tf.io.gfile.exists(effnet_subplot))