Exemplo n.º 1
0
    def testGetNormalizedVariableMapScopeContext(self):
        """Check normalized variable maps of a nested scope under contexts.

        Three contexts are tried for the scope "prefix1/prefix2": the
        unrelated scope "prefix" (a ValueError is expected), the enclosing
        scope "prefix1" (keys start at "prefix2/"), and a scope whose name is
        the empty string (keys keep the full "prefix1/prefix2/" path).
        """
        with tf.variable_scope("prefix1") as outer_scope:
            with tf.variable_scope("prefix2") as inner_scope:
                var_a = tf.get_variable("a", shape=[5, 6])
                var_b = tf.get_variable("b", shape=[7])

        with tf.variable_scope("prefix") as other_scope:
            _ = tf.get_variable("c", shape=[8])

        # A context that does not prefix the scope has to be rejected.
        err = r"Scope 'prefix1/prefix2' is not prefixed by 'prefix'."
        with self.assertRaisesRegexp(ValueError, err):
            snt.get_normalized_variable_map(inner_scope, context=other_scope)

        # Relative to the enclosing scope, keys start at "prefix2/".
        relative_map = snt.get_normalized_variable_map(
            inner_scope, context=outer_scope)
        relative_expected = {"prefix2/a:0": var_a, "prefix2/b:0": var_b}

        self.assertEqual(len(relative_map), 2)
        for key in sorted(relative_expected):
            self.assertIn(key, relative_map)
        for key in sorted(relative_expected):
            self.assertIs(relative_map[key], relative_expected[key])

        # Relative to the empty-named scope, keys carry the whole path.
        with tf.variable_scope("") as root_scope:
            self.assertEqual(root_scope.name, "")
            absolute_map = snt.get_normalized_variable_map(
                inner_scope, context=root_scope)
        absolute_expected = {
            "prefix1/prefix2/a:0": var_a,
            "prefix1/prefix2/b:0": var_b,
        }

        self.assertEqual(len(absolute_map), 2)
        for key in sorted(absolute_expected):
            self.assertIn(key, absolute_map)
        for key in sorted(absolute_expected):
            self.assertIs(absolute_map[key], absolute_expected[key])
Exemplo n.º 2
0
    def testGetNormalizedVariableMapWithPartitionedVariable(self):
        """Check the variable map of a Conv2D with a partitioner on `w`.

        With `group_sliced_variables=True` the map is expected to hold two
        entries, "b" and a length-3 "w". With `group_sliced_variables=False`
        the keys are expected to be "b" plus one "w/part_<i>" key per slice.
        """
        inputs = tf.ones(shape=(1, 16, 16, 3))
        w_partitioner = tf.variable_axis_size_partitioner(4)
        conv = snt.Conv2D(
            output_channels=3,
            kernel_shape=3,
            stride=1,
            partitioners={"w": w_partitioner})
        # Connect the module so that its variables get created.
        conv(inputs)

        grouped = snt.get_normalized_variable_map(
            conv, group_sliced_variables=True)
        self.assertEqual(len(grouped), 2)
        self.assertEqual(grouped["b"], conv.b)
        self.assertEqual(len(grouped["w"]), 3)

        ungrouped = snt.get_normalized_variable_map(
            conv, group_sliced_variables=False)
        self.assertEqual(ungrouped["b"], conv.b)
        expected_keys = {"b", "w/part_0", "w/part_1", "w/part_2"}
        self.assertEqual(set(ungrouped), expected_keys)
Exemplo n.º 3
0
  def testGetNormalizedVariableMapWithPartitionedVariable(self):
    """Variable map of a Conv2D with partitioned `w`, grouped and ungrouped.

    Grouping sliced variables should give two entries ("b" and a "w" entry
    of length 3); without grouping every slice of "w" should have its own
    "w/part_<i>" key next to "b".
    """
    inputs = tf.ones(shape=(1, 16, 16, 3))
    w_partitioner = tf.variable_axis_size_partitioner(4)
    module = snt.Conv2D(
        output_channels=3,
        kernel_shape=3,
        stride=1,
        partitioners={"w": w_partitioner})
    # Connect the module so that its variables get created.
    module(inputs)

    grouped_map = snt.get_normalized_variable_map(
        module, group_sliced_variables=True)
    self.assertEqual(len(grouped_map), 2)
    self.assertEqual(grouped_map["b"], module.b)
    self.assertEqual(len(grouped_map["w"]), 3)

    sliced_map = snt.get_normalized_variable_map(
        module, group_sliced_variables=False)
    self.assertEqual(sliced_map["b"], module.b)
    expected_keys = {"b", "w/part_0", "w/part_1", "w/part_2"}
    self.assertEqual(set(sliced_map), expected_keys)
Exemplo n.º 4
0
    def testGetNormalizedVariableMapScope(self):
        """Check the normalized variable map built from a variable scope.

        Two variables are created under the scope "prefix"; the map is
        expected to contain exactly them, keyed "a" and "b".
        """
        with tf.variable_scope("prefix") as scope:
            var_a = tf.get_variable("a", shape=[5, 6])
            var_b = tf.get_variable("b", shape=[7])

        normalized = snt.get_normalized_variable_map(scope)
        expected = {"a": var_a, "b": var_b}

        self.assertEqual(len(normalized), 2)
        for key in sorted(expected):
            self.assertIn(key, normalized)
        for key in sorted(expected):
            self.assertIs(normalized[key], expected[key])
Exemplo n.º 5
0
    def testGetNormalizedVariableMapModule(self):
        """Check the normalized variable map built from a connected module.

        After connecting a Conv2D to a placeholder, the map is expected to
        hold exactly the module's `w` and `b`, keyed "w" and "b".
        """
        inputs = tf.placeholder(tf.float32, shape=[1, 10, 10, 3])
        module = snt.Conv2D(output_channels=3, kernel_shape=3)
        # Connect the module so that its variables get created.
        module(inputs)

        normalized = snt.get_normalized_variable_map(module)

        self.assertEqual(len(normalized), 2)
        for key in ("w", "b"):
            self.assertIn(key, normalized)
        self.assertIs(normalized["w"], module.w)
        self.assertIs(normalized["b"], module.b)
Exemplo n.º 6
0
  def testGetNormalizedVariableMapScope(self):
    """Normalized variable map of a scope should use scope-relative keys.

    Creates "a" and "b" under the scope "prefix" and expects the map to
    contain exactly those two variables under the keys "a" and "b".
    """
    with tf.variable_scope("prefix") as prefix_scope:
      first = tf.get_variable("a", shape=[5, 6])
      second = tf.get_variable("b", shape=[7])

    result = snt.get_normalized_variable_map(prefix_scope)
    expected = {"a": first, "b": second}

    self.assertEqual(len(result), 2)
    for name in sorted(expected):
      self.assertIn(name, result)
    for name in sorted(expected):
      self.assertIs(result[name], expected[name])
Exemplo n.º 7
0
  def testGetNormalizedVariableMapModule(self):
    """Normalized variable map of a connected Conv2D should hold `w` and `b`.

    The module is connected to a placeholder first; the map is then expected
    to contain exactly two entries, "w" and "b", which are the module's own
    variable objects.
    """
    placeholder = tf.placeholder(tf.float32, shape=[1, 10, 10, 3])
    conv_module = snt.Conv2D(output_channels=3, kernel_shape=3)
    # Connect the module so that its variables get created.
    conv_module(placeholder)

    result = snt.get_normalized_variable_map(conv_module)

    self.assertEqual(len(result), 2)
    for name in ("w", "b"):
      self.assertIn(name, result)
    self.assertIs(result["w"], conv_module.w)
    self.assertIs(result["b"], conv_module.b)
Exemplo n.º 8
0
    def testVariableMapItems(self):
        """Check `variable_map_items` on a map with a partitioned variable.

        The sorted (key, op name) pairs are expected to list "b" once and
        "w" once per slice of the partitioned weights.
        """
        inputs = tf.ones(shape=(1, 16, 16, 3))
        w_partitioner = tf.variable_axis_size_partitioner(4)
        conv = snt.Conv2D(
            output_channels=3,
            kernel_shape=3,
            stride=1,
            partitioners={"w": w_partitioner})
        # Connect the module so that its variables get created.
        conv(inputs)

        normalized = snt.get_normalized_variable_map(conv)
        pairs = [(key, var.op.name)
                 for key, var in snt.variable_map_items(normalized)]
        pairs.sort()

        expected = [
            (u"b", u"conv_2d/b"),
            ("w", u"conv_2d/w/part_0"),
            ("w", u"conv_2d/w/part_1"),
            ("w", u"conv_2d/w/part_2"),
        ]
        self.assertEqual(pairs, expected)
Exemplo n.º 9
0
  def testVariableMapItems(self):
    """`variable_map_items` should yield one pair per variable slice.

    For a Conv2D with a partitioner on `w`, the sorted (key, op name) pairs
    are expected to be "b" once followed by "w" for each of three parts.
    """
    inputs = tf.ones(shape=(1, 16, 16, 3))
    w_partitioner = tf.variable_axis_size_partitioner(4)
    module = snt.Conv2D(
        output_channels=3,
        kernel_shape=3,
        stride=1,
        partitioners={"w": w_partitioner})
    # Connect the module so that its variables get created.
    module(inputs)

    var_map = snt.get_normalized_variable_map(module)
    named_pairs = [(key, var.op.name)
                   for key, var in snt.variable_map_items(var_map)]
    named_pairs.sort()

    expected = [
        (u"b", u"conv_2d/b"),
        ("w", u"conv_2d/w/part_0"),
        ("w", u"conv_2d/w/part_1"),
        ("w", u"conv_2d/w/part_2"),
    ]
    self.assertEqual(named_pairs, expected)
Exemplo n.º 10
0
def get_saver(modules,
              var_collections=(tf.GraphKeys.GLOBAL_VARIABLES, ),
              ignore_scope=None,
              **kwargs):
    """Get tf.train.Saver instance for module.

    The normalized variable maps of every module, taken from every requested
    collection, are merged into a single dict that is handed to the saver as
    its `var_list`.

    Args:
        - modules: Sonnet module or list of Sonnet modules from where to
            extract variables. A non-iterable value is wrapped in a list.
        - var_collections: Collections from where to take variables.
        - ignore_scope (str): Ignore variables that contain scope in name.
            The check is a plain substring test on the normalized name.
        - kwargs: Keyword arguments to pass to creation of `tf.train.Saver`.

    Returns:
        - saver: tf.train.Saver instance.
    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives in
    # `collections.abc`. The fallback keeps Python 2 working.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    if not isinstance(modules, Iterable):
        modules = [modules]

    # NOTE(review): `dict.update` replaces entries that share a key, so a
    # later module/collection wins over an earlier one that produced the same
    # normalized name -- confirm callers never rely on both being kept.
    variable_map = {}
    for module in modules:
        for collection in var_collections:
            model_variables = snt.get_normalized_variable_map(
                module, collection)
            total_model_variables = len(model_variables)
            if ignore_scope:
                model_variables = {
                    k: v
                    for k, v in model_variables.items()
                    if ignore_scope not in k
                }
                new_total_model_variables = len(model_variables)
                # Report how many variables the filter dropped.
                tf.logging.info(
                    'Not loading/saving {} variables with scope "{}"'.format(
                        total_model_variables - new_total_model_variables,
                        ignore_scope))

            variable_map.update(model_variables)

    return tf.train.Saver(var_list=variable_map, **kwargs)
Exemplo n.º 11
0
def get_saver(modules, var_collections=(tf.GraphKeys.GLOBAL_VARIABLES,),
              ignore_scope=None, **kwargs):
    """Get tf.train.Saver instance for module.

    The normalized variable maps of every module, taken from every requested
    collection, are merged into a single dict that is handed to the saver as
    its `var_list`.

    Args:
        - modules: Sonnet module or list of Sonnet modules from where to
            extract variables. A non-iterable value is wrapped in a list.
        - var_collections: Collections from where to take variables.
        - ignore_scope (str): Ignore variables that contain scope in name.
            The check is a plain substring test on the normalized name.
        - kwargs: Keyword arguments to pass to creation of `tf.train.Saver`.

    Returns:
        - saver: tf.train.Saver instance.
    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives in
    # `collections.abc`. The fallback keeps Python 2 working.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    if not isinstance(modules, Iterable):
        modules = [modules]

    # NOTE(review): `dict.update` replaces entries that share a key, so a
    # later module/collection wins over an earlier one that produced the same
    # normalized name -- confirm callers never rely on both being kept.
    variable_map = {}
    for module in modules:
        for collection in var_collections:
            model_variables = snt.get_normalized_variable_map(
                module, collection
            )
            total_model_variables = len(model_variables)
            if ignore_scope:
                model_variables = {
                    k: v for k, v in model_variables.items()
                    if ignore_scope not in k
                }
                new_total_model_variables = len(model_variables)
                # Report how many variables the filter dropped.
                tf.logging.info(
                    'Not loading/saving {} variables with scope "{}"'.format(
                        total_model_variables - new_total_model_variables,
                        ignore_scope))

            variable_map.update(model_variables)

    return tf.train.Saver(var_list=variable_map, **kwargs)
Exemplo n.º 12
0
  def testGetNormalizedVariableMapScopeContext(self):
    """Check normalized variable maps of a nested scope under contexts.

    For the scope "prefix1/prefix2" the test expects: a ValueError when the
    context is the unrelated scope "prefix"; keys starting at "prefix2/" when
    the context is "prefix1"; and full-path keys when the context is a scope
    whose name is the empty string. It also expects identical maps whether
    the scope and context are passed as scope objects or as their names.
    """
    with tf.variable_scope("prefix1") as outer_scope:
      with tf.variable_scope("prefix2") as inner_scope:
        var_a = tf.get_variable("a", shape=[5, 6])
        var_b = tf.get_variable("b", shape=[7])

    with tf.variable_scope("prefix") as other_scope:
      _ = tf.get_variable("c", shape=[8])

    # A context that does not prefix the scope has to be rejected.
    err = r"Scope 'prefix1/prefix2' is not prefixed by 'prefix'."
    with self.assertRaisesRegexp(ValueError, err):
      snt.get_normalized_variable_map(inner_scope, context=other_scope)

    # Relative to the enclosing scope, keys start at "prefix2/".
    relative_map = snt.get_normalized_variable_map(
        inner_scope, context=outer_scope)
    for context in (outer_scope, outer_scope.name):
      self.assertEqual(
          snt.get_normalized_variable_map(inner_scope.name, context=context),
          relative_map)

    relative_expected = {"prefix2/a": var_a, "prefix2/b": var_b}
    self.assertEqual(len(relative_map), 2)
    for key in sorted(relative_expected):
      self.assertIn(key, relative_map)
    for key in sorted(relative_expected):
      self.assertIs(relative_map[key], relative_expected[key])

    with tf.variable_scope("") as root_scope:
      self.assertEqual(root_scope.name, "")

    # Relative to the empty-named scope, keys carry the whole path.
    absolute_map = snt.get_normalized_variable_map(
        inner_scope, context=root_scope)
    for context in (root_scope, root_scope.name):
      self.assertEqual(
          snt.get_normalized_variable_map(inner_scope.name, context=context),
          absolute_map)

    absolute_expected = {
        "prefix1/prefix2/a": var_a,
        "prefix1/prefix2/b": var_b,
    }
    self.assertEqual(len(absolute_map), 2)
    for key in sorted(absolute_expected):
      self.assertIn(key, absolute_map)
    for key in sorted(absolute_expected):
      self.assertIs(absolute_map[key], absolute_expected[key])