def testGetNormalizedVariableMapScopeContext(self):
    """Check how `context` affects the keys of the normalized variable map.

    Variables live under "prefix1/prefix2". The test asserts that:
      * a context scope that does not prefix the target scope raises
        `ValueError`;
      * with "prefix1" as context the keys start at "prefix2/";
      * with the empty root scope as context the keys keep the full
        "prefix1/prefix2/" path.
    """
    with tf.variable_scope("prefix1") as outer_scope:
        with tf.variable_scope("prefix2") as inner_scope:
            var_a = tf.get_variable("a", shape=[5, 6])
            var_b = tf.get_variable("b", shape=[7])

    with tf.variable_scope("prefix") as unrelated_scope:
        _ = tf.get_variable("c", shape=[8])

    # "prefix" is not an ancestor of "prefix1/prefix2", so this must fail.
    expected_error = r"Scope 'prefix1/prefix2' is not prefixed by 'prefix'."
    with self.assertRaisesRegexp(ValueError, expected_error):
        snt.get_normalized_variable_map(inner_scope, context=unrelated_scope)

    # Keys relative to the enclosing "prefix1" scope.
    variable_map = snt.get_normalized_variable_map(
        inner_scope, context=outer_scope)
    relative_expected = (("prefix2/a:0", var_a), ("prefix2/b:0", var_b))
    self.assertEqual(len(variable_map), 2)
    for key, _ in relative_expected:
        self.assertIn(key, variable_map)
    for key, variable in relative_expected:
        self.assertIs(variable_map[key], variable)

    # Keys relative to the (empty-named) root scope keep the whole path.
    with tf.variable_scope("") as root_scope:
        self.assertEqual(root_scope.name, "")

        variable_map = snt.get_normalized_variable_map(
            inner_scope, context=root_scope)
        absolute_expected = (
            ("prefix1/prefix2/a:0", var_a),
            ("prefix1/prefix2/b:0", var_b),
        )
        self.assertEqual(len(variable_map), 2)
        for key, _ in absolute_expected:
            self.assertIn(key, variable_map)
        for key, variable in absolute_expected:
            self.assertIs(variable_map[key], variable)
def testGetNormalizedVariableMapWithPartitionedVariable(self):
    """Check the variable map of a Conv2D whose weights are partitioned.

    With `group_sliced_variables=True` the map is expected to have two
    entries ("b" and "w"), the "w" entry having length 3. With
    `group_sliced_variables=False` each weight slice is expected under its
    own "w/part_N" key.
    """
    inputs = tf.ones(shape=(1, 16, 16, 3))
    weight_partitioner = tf.variable_axis_size_partitioner(4)
    conv = snt.Conv2D(
        output_channels=3,
        kernel_shape=3,
        stride=1,
        partitioners={"w": weight_partitioner},
    )
    # Connect the module so that its variables get created.
    conv(inputs)

    grouped_map = snt.get_normalized_variable_map(
        conv, group_sliced_variables=True)
    self.assertEqual(len(grouped_map), 2)
    self.assertEqual(grouped_map["b"], conv.b)
    self.assertEqual(len(grouped_map["w"]), 3)

    ungrouped_map = snt.get_normalized_variable_map(
        conv, group_sliced_variables=False)
    self.assertEqual(ungrouped_map["b"], conv.b)
    expected_keys = {"b", "w/part_0", "w/part_1", "w/part_2"}
    self.assertEqual(set(ungrouped_map), expected_keys)
def testGetNormalizedVariableMapWithPartitionedVariable(self):
    """Verify grouped and ungrouped variable maps for partitioned weights.

    The Conv2D weights "w" use a partitioner. The grouped map must hold
    "b" plus a 3-element "w" entry; the ungrouped map must expose the keys
    "b", "w/part_0", "w/part_1" and "w/part_2".
    """
    inputs = tf.ones(shape=(1, 16, 16, 3))
    weight_partitioner = tf.variable_axis_size_partitioner(4)
    conv = snt.Conv2D(
        output_channels=3,
        kernel_shape=3,
        stride=1,
        partitioners={"w": weight_partitioner},
    )
    # Variables only exist once the module has been connected to an input.
    conv(inputs)

    # Slices grouped under a single "w" key.
    grouped_map = snt.get_normalized_variable_map(
        conv, group_sliced_variables=True)
    self.assertEqual(len(grouped_map), 2)
    self.assertEqual(grouped_map["b"], conv.b)
    self.assertEqual(len(grouped_map["w"]), 3)

    # One key per slice.
    per_slice_map = snt.get_normalized_variable_map(
        conv, group_sliced_variables=False)
    self.assertEqual(per_slice_map["b"], conv.b)
    self.assertEqual(
        set(per_slice_map),
        {"b", "w/part_0", "w/part_1", "w/part_2"},
    )
def testGetNormalizedVariableMapScope(self):
    """Check that a variable scope maps to keys without the scope prefix.

    Two variables are created under the "prefix" scope; the normalized map
    is expected to contain exactly them, keyed by "a" and "b".
    """
    with tf.variable_scope("prefix") as scope:
        var_a = tf.get_variable("a", shape=[5, 6])
        var_b = tf.get_variable("b", shape=[7])

    variable_map = snt.get_normalized_variable_map(scope)

    expected = (("a", var_a), ("b", var_b))
    self.assertEqual(len(variable_map), 2)
    for key, _ in expected:
        self.assertIn(key, variable_map)
    for key, variable in expected:
        self.assertIs(variable_map[key], variable)
def testGetNormalizedVariableMapModule(self):
    """Check that a connected module maps to its own "w" and "b" variables."""
    inputs = tf.placeholder(tf.float32, shape=[1, 10, 10, 3])
    conv = snt.Conv2D(output_channels=3, kernel_shape=3)
    # Connect the module so that `conv.w` and `conv.b` exist.
    conv(inputs)

    variable_map = snt.get_normalized_variable_map(conv)

    self.assertEqual(len(variable_map), 2)
    for key in ("w", "b"):
        self.assertIn(key, variable_map)
    for key, variable in (("w", conv.w), ("b", conv.b)):
        self.assertIs(variable_map[key], variable)
def testGetNormalizedVariableMapScope(self):
    """Verify the normalized map built from a plain variable scope.

    The variables "a" and "b" are created inside "prefix"; the resulting
    map must hold those two objects under the bare keys "a" and "b".
    """
    with tf.variable_scope("prefix") as scope:
        first = tf.get_variable("a", shape=[5, 6])
        second = tf.get_variable("b", shape=[7])

    variable_map = snt.get_normalized_variable_map(scope)
    self.assertEqual(len(variable_map), 2)

    # Membership first, identity second.
    by_key = (("a", first), ("b", second))
    for key, _ in by_key:
        self.assertIn(key, variable_map)
    for key, variable in by_key:
        self.assertIs(variable_map[key], variable)
def testGetNormalizedVariableMapModule(self):
    """Verify the normalized map built directly from a Sonnet module.

    After connecting a Conv2D to a placeholder, the map must contain
    exactly two entries, "w" and "b", which are the module's own variables.
    """
    placeholder = tf.placeholder(tf.float32, shape=[1, 10, 10, 3])
    conv = snt.Conv2D(output_channels=3, kernel_shape=3)
    conv(placeholder)

    variable_map = snt.get_normalized_variable_map(conv)
    self.assertEqual(len(variable_map), 2)

    for name in ("w", "b"):
        self.assertIn(name, variable_map)
    self.assertIs(variable_map["w"], conv.w)
    self.assertIs(variable_map["b"], conv.b)
def testVariableMapItems(self):
    """Check that `variable_map_items` yields one pair per variable slice.

    The Conv2D weights are partitioned, and the expected result lists the
    key "w" once for each of the three "part_N" slices, plus one "b" pair.
    """
    inputs = tf.ones(shape=(1, 16, 16, 3))
    weight_partitioner = tf.variable_axis_size_partitioner(4)
    conv = snt.Conv2D(
        output_channels=3,
        kernel_shape=3,
        stride=1,
        partitioners={"w": weight_partitioner},
    )
    conv(inputs)

    variable_map = snt.get_normalized_variable_map(conv)
    flat_items = snt.variable_map_items(variable_map)

    # Compare by op name, sorted so the order of iteration does not matter.
    actual = sorted((key, variable.op.name) for key, variable in flat_items)
    expected = [
        (u"b", u"conv_2d/b"),
        ("w", u"conv_2d/w/part_0"),
        ("w", u"conv_2d/w/part_1"),
        ("w", u"conv_2d/w/part_2"),
    ]
    self.assertEqual(actual, expected)
def testVariableMapItems(self):
    """Verify the (key, variable) pairs produced by `variable_map_items`.

    With partitioned Conv2D weights, the expected pairs are one for "b"
    and three for "w" (op names "conv_2d/w/part_0" to "part_2").
    """
    hidden_input = tf.ones(shape=(1, 16, 16, 3))
    conv = snt.Conv2D(
        output_channels=3,
        kernel_shape=3,
        stride=1,
        partitioners={"w": tf.variable_axis_size_partitioner(4)},
    )
    # Connecting the module creates its variables.
    conv(hidden_input)

    pairs = snt.variable_map_items(snt.get_normalized_variable_map(conv))
    # Sort the (key, op name) tuples so the comparison is order-independent.
    named_pairs = sorted((key, variable.op.name) for key, variable in pairs)

    self.assertEqual(
        named_pairs,
        [
            (u"b", u"conv_2d/b"),
            ("w", u"conv_2d/w/part_0"),
            ("w", u"conv_2d/w/part_1"),
            ("w", u"conv_2d/w/part_2"),
        ],
    )
def get_saver(modules,
              var_collections=(tf.GraphKeys.GLOBAL_VARIABLES,),
              ignore_scope=None,
              **kwargs):
    """Get tf.train.Saver instance for module.

    Args:
        - modules: Sonnet module or list of Sonnet modules from where to
            extract variables.
        - var_collections: Collections from where to take variables.
        - ignore_scope (str): Ignore variables that contain scope in name.
        - kwargs: Keyword arguments to pass to creation of `tf.train.Saver`.

    Returns:
        - saver: tf.train.Saver instance.
    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives in
    # `collections.abc`. Fall back to the old location for Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    # Accept a single module as well as a collection of modules.
    if not isinstance(modules, Iterable):
        modules = [modules]

    variable_map = {}
    for module in modules:
        for collection in var_collections:
            model_variables = snt.get_normalized_variable_map(
                module, collection)

            if ignore_scope:
                total_model_variables = len(model_variables)
                # Substring match on the normalized variable name.
                model_variables = {
                    k: v for k, v in model_variables.items()
                    if ignore_scope not in k
                }
                new_total_model_variables = len(model_variables)
                tf.logging.info(
                    'Not loading/saving {} variables with scope "{}"'.format(
                        total_model_variables - new_total_model_variables,
                        ignore_scope))

            # Later modules/collections overwrite entries with the same key.
            variable_map.update(model_variables)

    return tf.train.Saver(var_list=variable_map, **kwargs)
def get_saver(modules,
              var_collections=(tf.GraphKeys.GLOBAL_VARIABLES,),
              ignore_scope=None,
              **kwargs):
    """Get tf.train.Saver instance for module.

    Args:
        - modules: Sonnet module or list of Sonnet modules from where to
            extract variables.
        - var_collections: Collections from where to take variables.
        - ignore_scope (str): Ignore variables that contain scope in name.
        - kwargs: Keyword arguments to pass to creation of `tf.train.Saver`.

    Returns:
        - saver: tf.train.Saver instance.
    """
    # The `collections.Iterable` alias no longer exists on Python >= 3.10;
    # import the ABC from `collections.abc`, keeping a Python 2 fallback.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    # Normalize a lone module into a one-element list.
    if not isinstance(modules, Iterable):
        modules = [modules]

    variable_map = {}
    for module in modules:
        for collection in var_collections:
            model_variables = snt.get_normalized_variable_map(
                module, collection)

            if ignore_scope:
                total_model_variables = len(model_variables)
                # Drop every variable whose normalized name contains
                # `ignore_scope` anywhere (plain substring test).
                model_variables = {
                    k: v for k, v in model_variables.items()
                    if ignore_scope not in k
                }
                new_total_model_variables = len(model_variables)
                tf.logging.info(
                    'Not loading/saving {} variables with scope "{}"'.format(
                        total_model_variables - new_total_model_variables,
                        ignore_scope))

            # Entries sharing a key are replaced by the most recent one.
            variable_map.update(model_variables)

    return tf.train.Saver(var_list=variable_map, **kwargs)
def testGetNormalizedVariableMapScopeContext(self):
    """Check how `context` affects the keys of the normalized variable map.

    Variables live under "prefix1/prefix2". The test asserts that:
      * a context scope that does not prefix the target scope raises
        `ValueError`;
      * with "prefix1" as context the keys start at "prefix2/";
      * with the empty root scope as context the keys keep the full
        "prefix1/prefix2/" path;
      * passing scope names (strings) instead of scope objects, for the
        target and/or the context, yields an equal map.
    """
    with tf.variable_scope("prefix1") as outer_scope:
        with tf.variable_scope("prefix2") as inner_scope:
            var_a = tf.get_variable("a", shape=[5, 6])
            var_b = tf.get_variable("b", shape=[7])

    with tf.variable_scope("prefix") as unrelated_scope:
        _ = tf.get_variable("c", shape=[8])

    # "prefix" is not an ancestor of "prefix1/prefix2", so this must fail.
    expected_error = r"Scope 'prefix1/prefix2' is not prefixed by 'prefix'."
    with self.assertRaisesRegexp(ValueError, expected_error):
        snt.get_normalized_variable_map(inner_scope, context=unrelated_scope)

    # Keys relative to the enclosing "prefix1" scope.
    variable_map = snt.get_normalized_variable_map(
        inner_scope, context=outer_scope)
    # String names must be interchangeable with the scope objects.
    for context in (outer_scope, outer_scope.name):
        self.assertEqual(
            snt.get_normalized_variable_map(inner_scope.name, context=context),
            variable_map)

    relative_expected = (("prefix2/a", var_a), ("prefix2/b", var_b))
    self.assertEqual(len(variable_map), 2)
    for key, _ in relative_expected:
        self.assertIn(key, variable_map)
    for key, variable in relative_expected:
        self.assertIs(variable_map[key], variable)

    # Keys relative to the (empty-named) root scope keep the whole path.
    with tf.variable_scope("") as root_scope:
        self.assertEqual(root_scope.name, "")

        variable_map = snt.get_normalized_variable_map(
            inner_scope, context=root_scope)
        for context in (root_scope, root_scope.name):
            self.assertEqual(
                snt.get_normalized_variable_map(
                    inner_scope.name, context=context),
                variable_map)

        absolute_expected = (
            ("prefix1/prefix2/a", var_a),
            ("prefix1/prefix2/b", var_b),
        )
        self.assertEqual(len(variable_map), 2)
        for key, _ in absolute_expected:
            self.assertIn(key, variable_map)
        for key, variable in absolute_expected:
            self.assertIs(variable_map[key], variable)