示例#1
0
 def test_exclude_by_regex(self):
     """``exclude=`` drops every variable whose name matches the regex.

     Three trainable variables are created: a top-level (3, 2) one (6 weights),
     one inside scope ``foo`` (5, 2) (10 weights) and one inside ``foo/bar``
     (1, 2) (2 weights). Each pattern below is paired with the count expected
     once the matching variables are excluded.
     """
     tf.Variable(tf.zeros((3, 2)), trainable=True)
     with tf.variable_scope('foo'):
         tf.Variable(tf.zeros((5, 2)), trainable=True)
         with tf.variable_scope('bar'):
             tf.Variable(tf.zeros((1, 2)), trainable=True)
     cases = (
         (r'.*', 0),            # matches every name: nothing left
         (r'(^|/)foo/.*', 6),   # foo and foo/bar dropped: top-level remains
         (r'.*/bar/.*', 16),    # only the foo/bar variable dropped
     )
     for pattern, expected in cases:
         self.assertEqual(expected, count_weights(exclude=pattern))
示例#2
0
 def test_non_default_graph(self):
     """``count_weights(graph=...)`` counts trainable variables of that graph.

     The variables are created inside a fresh ``tf.Graph``. After the ``with``
     block exits, that graph is verified not to be the current default graph.
     As a result, the expected count can only come from honouring ``graph``.
     """
     graph = tf.Graph()
     with graph.as_default():
         tf.Variable(tf.zeros((5, 3)), trainable=True)   # 15 weights, counted
         tf.Variable(tf.zeros((8, 2)), trainable=False)  # non-trainable, not counted
     # get_default_graph must be *called*: comparing against the bare function
     # object is always unequal, so such an assertion would check nothing.
     self.assertIsNot(graph, tf.get_default_graph())
     self.assertEqual(15, count_weights(graph=graph))
示例#3
0
 def test_restrict_invalid_scope(self):
     """Restricting the count to a scope that matches nothing yields zero.

     ``bar`` only exists nested under ``foo``, so it is not a top-level scope.
     The positional argument to ``count_weights`` is presumably a scope
     filter; this is assumed from the test name (TODO confirm against
     ``count_weights``).
     """
     tf.Variable(tf.zeros((3, 2)), trainable=True)
     with tf.variable_scope('foo'):
         tf.Variable(tf.zeros((5, 2)), trainable=True)
         with tf.variable_scope('bar'):
             tf.Variable(tf.zeros((1, 2)), trainable=True)
     unmatched_scope = 'bar'
     self.assertEqual(0, count_weights(unmatched_scope))
示例#4
0
 def test_trainable_and_non_trainable(self):
     """With a mix of variables, only the trainable ones are counted."""
     specs = [
         ((5, 3), True),
         ((8, 2), False),
         ((1, 1), True),
         ((5, ), True),
         ((3, 1), False),
     ]
     for shape, is_trainable in specs:
         tf.Variable(tf.zeros(shape), trainable=is_trainable)
     # Trainable sizes: 5*3 + 1*1 + 5 = 21.
     self.assertEqual(21, count_weights())
示例#5
0
 def test_include_scopes(self):
     """Variables inside a variable scope count alongside top-level ones."""
     tf.Variable(tf.zeros((3, 2)), trainable=True)      # 6 weights, top level
     with tf.variable_scope('foo'):
         tf.Variable(tf.zeros((5, 2)), trainable=True)  # 10 weights, under foo
     expected_total = 6 + 10
     self.assertEqual(expected_total, count_weights())
示例#6
0
 def test_ignore_non_trainable(self):
     """When every variable is non-trainable, the weight count is zero."""
     for shape in ((5, 3), (1, 1), (5, )):
         tf.Variable(tf.zeros(shape), trainable=False)
     self.assertEqual(0, count_weights())
示例#7
0
 def test_count_trainable(self):
     """The count equals the total element count of all trainable variables."""
     for shape in ((5, 3), (1, 1), (5, )):
         tf.Variable(tf.zeros(shape), trainable=True)
     # Element counts: 15 + 1 + 5.
     self.assertEqual(21, count_weights())