def test_exclude_by_regex(self):
    """Check that the `exclude` pattern removes variables from the count."""
    # Trainable variables: 6 elements outside any scope, 10 inside 'foo',
    # and 2 inside 'bar', which is nested within 'foo'.
    tf.Variable(tf.zeros((3, 2)), trainable=True)
    with tf.variable_scope('foo'):
        tf.Variable(tf.zeros((5, 2)), trainable=True)
        with tf.variable_scope('bar'):
            tf.Variable(tf.zeros((1, 2)), trainable=True)
    # Each entry pairs an exclusion pattern with the count expected to remain.
    cases = (
        (r'.*', 0),
        (r'(^|/)foo/.*', 6),
        (r'.*/bar/.*', 16),
    )
    for pattern, remaining in cases:
        self.assertEqual(remaining, count_weights(exclude=pattern))
def test_non_default_graph(self):
    """Check that weights are counted in an explicitly passed graph.

    One trainable (5, 3) variable and one non-trainable (8, 2) variable are
    created in a fresh graph; 15 weights are expected when that graph is
    handed to `count_weights`.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.Variable(tf.zeros((5, 3)), trainable=True)
        tf.Variable(tf.zeros((8, 2)), trainable=False)
    # Sanity check that the fresh graph is not the current default graph.
    # get_default_graph must be *called*: comparing against the function
    # object itself would make this assertion pass unconditionally.
    self.assertNotEqual(graph, tf.get_default_graph())
    self.assertEqual(15, count_weights(graph=graph))
def test_restrict_invalid_scope(self):
    """Check that restricting to a scope with no top-level match yields 0."""
    tf.Variable(tf.zeros((3, 2)), trainable=True)
    with tf.variable_scope('foo'):
        tf.Variable(tf.zeros((5, 2)), trainable=True)
        # 'bar' only exists nested inside 'foo', never at the top level.
        with tf.variable_scope('bar'):
            tf.Variable(tf.zeros((1, 2)), trainable=True)
    expected = 0
    actual = count_weights('bar')
    self.assertEqual(expected, actual)
def test_trainable_and_non_trainable(self):
    """Check that only the trainable variables of a mixed set are counted."""
    # (shape, trainable) for each variable, in creation order.
    variable_specs = (
        ((5, 3), True),
        ((8, 2), False),
        ((1, 1), True),
        ((5,), True),
        ((3, 1), False),
    )
    for shape, is_trainable in variable_specs:
        tf.Variable(tf.zeros(shape), trainable=is_trainable)
    # Element counts of the three trainable variables.
    expected = 15 + 1 + 5
    self.assertEqual(expected, count_weights())
def test_include_scopes(self):
    """Check that variables inside a variable scope are included by default."""
    tf.Variable(tf.zeros((3, 2)), trainable=True)
    with tf.variable_scope('foo'):
        tf.Variable(tf.zeros((5, 2)), trainable=True)
    unscoped_weights = 6
    scoped_weights = 10
    self.assertEqual(unscoped_weights + scoped_weights, count_weights())
def test_ignore_non_trainable(self):
    """Check that non-trainable variables do not contribute to the count."""
    for shape in ((5, 3), (1, 1), (5,)):
        tf.Variable(tf.zeros(shape), trainable=False)
    total = count_weights()
    self.assertEqual(0, total)
def test_count_trainable(self):
    """Check that the elements of all trainable variables are summed."""
    for shape in ((5, 3), (1, 1), (5,)):
        tf.Variable(tf.zeros(shape), trainable=True)
    # One term per variable created above.
    expected = 15 + 1 + 5
    self.assertEqual(expected, count_weights())