Exemple #1
0
 def testModelVariables(self):
     """Check the op names of the model variables created by `vgg.vgg_a`.

     Builds `vgg_a` on a random-uniform input of shape (5, 224, 224, 3)
     with 1000 classes, then compares the op names of everything returned
     by `variables_lib.get_model_variables()` against the expected
     weights/biases names.
     """
     # Compared as sets below, so ordering and duplicates are ignored.
     wanted_names = {
         'vgg_a/conv1/conv1_1/weights',
         'vgg_a/conv1/conv1_1/biases',
         'vgg_a/conv2/conv2_1/weights',
         'vgg_a/conv2/conv2_1/biases',
         'vgg_a/conv3/conv3_1/weights',
         'vgg_a/conv3/conv3_1/biases',
         'vgg_a/conv3/conv3_2/weights',
         'vgg_a/conv3/conv3_2/biases',
         'vgg_a/conv4/conv4_1/weights',
         'vgg_a/conv4/conv4_1/biases',
         'vgg_a/conv4/conv4_2/weights',
         'vgg_a/conv4/conv4_2/biases',
         'vgg_a/conv5/conv5_1/weights',
         'vgg_a/conv5/conv5_1/biases',
         'vgg_a/conv5/conv5_2/weights',
         'vgg_a/conv5/conv5_2/biases',
         'vgg_a/fc6/weights',
         'vgg_a/fc6/biases',
         'vgg_a/fc7/weights',
         'vgg_a/fc7/biases',
         'vgg_a/fc8/weights',
         'vgg_a/fc8/biases',
     }
     image_shape = (5, 224, 224, 3)
     class_count = 1000
     with self.cached_session():
         images = random_ops.random_uniform(image_shape)
         # Only the variables created as a side effect matter here; the
         # value returned by vgg_a is intentionally discarded.
         vgg.vgg_a(images, class_count)
         actual_names = set()
         for variable in variables_lib.get_model_variables():
             actual_names.add(variable.op.name)
         self.assertSetEqual(actual_names, wanted_names)
Exemple #2
0
 def testModelHasExpectedNumberOfParameters(self):
   """Check the parameter total reported for `inception_v2_base`.

   Builds `inception_v2_base` on a random-uniform input of shape
   (5, 224, 224, 3) inside `inception_v2_arg_scope`, then asserts that the
   first value returned by `model_analyzer.analyze_vars` over the model
   variables is 10173112.
   """
   image_shape = (5, 224, 224, 3)
   images = random_ops.random_uniform(image_shape)
   scope = inception_v2.inception_v2_arg_scope()
   with arg_scope(scope):
     inception_v2.inception_v2_base(images)
   model_vars = variables_lib.get_model_variables()
   # Only the first element of the returned pair is checked.
   param_count, _ = model_analyzer.analyze_vars(model_vars)
   self.assertAlmostEqual(10173112, param_count)