Code example #1
0
 def testNumberedPath(self):
   """An unnamed dependency is keyed by its position, yielding an "_0" prefix."""
   parent = checkpointable.Checkpointable()
   child = checkpointable.Checkpointable()
   # No name passed to track_checkpointable(), so serialization falls back to
   # the dependency's index.
   parent.track_checkpointable(child)
   child.add_variable(name="v", shape=[])
   named_variables, _ = checkpointable._serialize_object_graph(parent)
   (variable_name,) = named_variables.keys()
   self.assertEqual(r"_0/v", variable_name)
Code example #2
0
 def _get_checkpoint_name(self, name):
   """Creates a variable called `name` and returns its checkpoint key.

   Args:
     name: Name to pass through to add_variable.

   Returns:
     The single checkpoint name that _serialize_object_graph assigns to the
     new variable.
   """
   root = checkpointable.Checkpointable()
   with variable_scope.variable_scope("get_checkpoint_name"):
     # Create the variable inside a variable scope for more relaxed naming
     # rules: variables outside a scope may not start with "_", "/" or "-".
     # The scope portion is not part of the checkpoint name, so this does not
     # change what we return.
     root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64)
   named_variables, _ = checkpointable._serialize_object_graph(root)
   checkpoint_name, = named_variables.keys()
   with ops.name_scope("root/" + checkpoint_name):
     pass  # Make sure we can use this as an op name if we prefix it.
   return checkpoint_name
Code example #3
0
 def testNumberedPath(self):
   """A dependency tracked with an explicit name uses that name as its prefix."""
   base = checkpointable.Checkpointable()
   dependency = checkpointable.Checkpointable()
   base.track_checkpointable(dependency, name="leaf")
   dependency.add_variable(name="v", shape=[])
   serialized, _ = checkpointable._serialize_object_graph(base)
   (only_name,) = serialized.keys()
   self.assertEqual(r"leaf/v", only_name)
Code example #4
0
 def _get_checkpoint_name(self, name):
     """Creates a variable called `name` and returns its checkpoint key.

     Args:
         name: Name to pass through to add_variable.

     Returns:
         The single checkpoint name that _serialize_object_graph assigns to
         the new variable.
     """
     root = checkpointable.Checkpointable()
     with variable_scope.variable_scope("get_checkpoint_name"):
         # Create the variable in a variable scope so that we get more relaxed
         # naming rules (variables outside a scope may not start with "_", "/"
         # or "-"). Since we don't use the scope part of the name, these cases
         # are somewhat annoying.
         root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64)
     named_variables, _ = checkpointable._serialize_object_graph(root)
     checkpoint_name, = named_variables.keys()
     with ops.name_scope("root/" + checkpoint_name):
         pass  # Make sure we can use this as an op name if we prefix it.
     return checkpoint_name
Code example #5
0
    def testAddVariable(self):
        """Exercises add_variable: initializer handling and name collisions."""
        holder = NonLayerCheckpointable()
        # Passing an explicit shape together with a scalar initializer is
        # rejected.
        with self.assertRaisesRegexp(ValueError, "do not specify shape"):
            holder.add_variable(name="shape_specified_twice",
                                shape=[],
                                initializer=1)
        const_var = holder.add_variable(name="constant_initializer",
                                        initializer=1)
        with variable_scope.variable_scope("some_variable_scope"):
            ones_var = holder.add_variable(
                name="ones_initializer",
                shape=[2],
                initializer=init_ops.ones_initializer(dtype=dtypes.float32))
        zeros_var = holder.add_variable(
            name="bare_initializer",
            shape=[2, 2],
            dtype=dtypes.float64,
            initializer=init_ops.zeros_initializer)

        # Even in graph mode, there are no naming conflicts between objects,
        # only naming conflicts within an object.
        unrelated_dup = resource_variable_ops.ResourceVariable(
            name="duplicate", initial_value=1.)
        dup_var = holder.add_variable(name="duplicate", shape=[])
        with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
            holder.add_variable(name="duplicate", shape=[])

        if context.in_graph_mode():
            self.evaluate(variables.global_variables_initializer())
        self.assertEqual("constant_initializer:0", const_var.name)
        self.assertEqual(1, self.evaluate(const_var))
        self.assertEqual("some_variable_scope/ones_initializer:0",
                         ones_var.name)
        self.assertAllEqual([1, 1], self.evaluate(ones_var))
        self.assertAllEqual([[0., 0.], [0., 0.]],
                            self.evaluate(zeros_var))
        self.assertEqual("a_variable:0", holder.a_variable.name)
        self.assertEqual("duplicate:0", unrelated_dup.name)
        if context.in_graph_mode():
            # Graph mode uniquifies the .name attribute globally; the
            # checkpoint name is unaffected (verified below).
            self.assertEqual("duplicate_1:0", dup_var.name)
        else:
            # Eager execution performs no uniquification of variable names, so
            # the .name and checkpoint name agree.
            self.assertEqual("duplicate:0", dup_var.name)
        named_variables, _ = checkpointable._serialize_object_graph(holder)
        expected_checkpoint_names = (
            "a_variable",
            "bare_initializer",
            "constant_initializer",
            "duplicate",
            "ones_initializer",
        )
        six.assertCountEqual(self, expected_checkpoint_names,
                             named_variables.keys())
Code example #6
0
  def testAddVariable(self):
    """Covers add_variable's initializer variants and duplicate-name errors."""
    tracked = NonLayerCheckpointable()
    # Supplying both an explicit shape and a scalar initializer is an error.
    with self.assertRaisesRegexp(ValueError, "do not specify shape"):
      tracked.add_variable(
          name="shape_specified_twice", shape=[], initializer=1)
    from_constant = tracked.add_variable(
        name="constant_initializer", initializer=1)
    with variable_scope.variable_scope("some_variable_scope"):
      from_ones = tracked.add_variable(
          name="ones_initializer",
          shape=[2],
          initializer=init_ops.ones_initializer(dtype=dtypes.float32))
    from_zeros = tracked.add_variable(
        name="bare_initializer",
        shape=[2, 2],
        dtype=dtypes.float64,
        initializer=init_ops.zeros_initializer)

    # Name collisions are scoped to a single object: an unrelated variable may
    # reuse "duplicate", but a second add_variable on the same object may not.
    outside_variable = resource_variable_ops.ResourceVariable(
        name="duplicate", initial_value=1.)
    second_duplicate = tracked.add_variable(name="duplicate", shape=[])
    with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
      tracked.add_variable(name="duplicate", shape=[])

    if context.in_graph_mode():
      self.evaluate(variables.global_variables_initializer())
    self.assertEqual("constant_initializer:0", from_constant.name)
    self.assertEqual(1, self.evaluate(from_constant))
    self.assertEqual("some_variable_scope/ones_initializer:0", from_ones.name)
    self.assertAllEqual([1, 1], self.evaluate(from_ones))
    self.assertAllEqual([[0., 0.], [0., 0.]], self.evaluate(from_zeros))
    self.assertEqual("a_variable:0", tracked.a_variable.name)
    self.assertEqual("duplicate:0", outside_variable.name)
    if context.in_graph_mode():
      # The .name attribute may be globally influenced, but the checkpoint
      # name won't be (tested below).
      self.assertEqual("duplicate_1:0", second_duplicate.name)
    else:
      # When executing eagerly there is no uniquification of variable names,
      # so the checkpoint name matches.
      self.assertEqual("duplicate:0", second_duplicate.name)
    named_variables, _ = checkpointable._serialize_object_graph(tracked)
    # Dict keys are unique, so set equality is equivalent to assertCountEqual.
    self.assertEqual(
        {"a_variable", "bare_initializer", "constant_initializer",
         "duplicate", "ones_initializer"},
        set(named_variables.keys()))
Code example #7
0
 def _get_checkpoint_name(self, name):
   """Builds a variable named `name` and returns its checkpoint key."""
   owner = checkpointable.Checkpointable()
   with variable_scope.variable_scope("get_checkpoint_name"):
     # A variable scope relaxes naming restrictions: outside a scope, names
     # may not begin with "_", "/" or "-". The scope portion does not appear
     # in the checkpoint name, so it only exists to appease that check.
     owner.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64)
   serialized, _ = checkpointable._serialize_object_graph(owner)
   (checkpoint_name,) = serialized.keys()
   # Prefixed checkpoint names must remain usable as op names.
   with ops.name_scope("root/" + checkpoint_name):
     pass
   return checkpoint_name
Code example #8
0
 def _get_checkpoint_name(self, name):
   """Returns the checkpoint key generated for a variable named `name`."""
   obj = checkpointable.Checkpointable()
   # Variables created outside a variable scope may not start with "_", "/"
   # or "-", so create inside a scope; the scope itself is not part of the
   # checkpoint name we return.
   with variable_scope.variable_scope("get_checkpoint_name"):
     obj.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64)
   names, _ = checkpointable._serialize_object_graph(obj)
   (checkpoint_name,) = names.keys()
   with ops.name_scope("root/" + checkpoint_name):
     # Verify the prefixed checkpoint name is a legal op name.
     pass
   return checkpoint_name
Code example #9
0
 def testNamingWithOptimizer(self):
   """Checks checkpoint names for a network plus optimizer object graph.

   Uses MyNetwork, CheckpointableAdam and Root — presumably defined elsewhere
   in this file; confirm their tracking structure there. Runs one training
   step, serializes the object graph, and verifies the checkpoint keys,
   variable mappings, and the generated protocol buffer.
   """
   input_value = constant_op.constant([[3.]])
   network = MyNetwork()
   # A nuisance Network using the same optimizer. Its slot variables should not
   # go in the checkpoint, since it is never depended on.
   other_network = MyNetwork()
   optimizer = CheckpointableAdam(0.001)
   root_checkpointable = Root(optimizer=optimizer, network=network)
   # Minimize with both networks so the optimizer creates slot variables for
   # each; only the depended-on network's slots should be serialized.
   if context.in_eager_mode():
     optimizer.minimize(
         lambda: network(input_value),
         global_step=root_checkpointable.global_step)
     optimizer.minimize(
         lambda: other_network(input_value),
         global_step=root_checkpointable.global_step)
   else:
     train_op = optimizer.minimize(
         network(input_value), global_step=root_checkpointable.global_step)
     optimizer.minimize(
         other_network(input_value),
         global_step=root_checkpointable.global_step)
     self.evaluate(variables.global_variables_initializer())
     self.evaluate(train_op)
   named_variables, serialized_graph = checkpointable._serialize_object_graph(
       root_checkpointable)
   expected_checkpoint_names = (
       # Created in the root node, so no prefix.
       "global_step",
       # No name provided to track_checkpointable(), so the position (1, after
       # the named track_checkpointable() which is 0) is used instead.
       "network/_1/kernel",
       # track_checkpointable() with a name provided, so that's used
       "network/named_dense/kernel",
       "network/named_dense/bias",
       # The optimizer creates two non-slot variables
       "optimizer/beta1_power",
       "optimizer/beta2_power",
       # Slot variables
       "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/m",
       "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/v",
       "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/m",
       "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/v",
       "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m",
       "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/v",
   )
   six.assertCountEqual(self, expected_checkpoint_names,
                        named_variables.keys())
   # Check that we've mapped to the right variable objects (not exhaustive)
   self.assertEqual("global_step:0", named_variables["global_step"].name)
   self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0",
                    named_variables["network/_1/kernel"].name)
   self.assertEqual("my_network/checkpointable_dense_layer/kernel:0",
                    named_variables["network/named_dense/kernel"].name)
   self.assertEqual("beta1_power:0",
                    named_variables["optimizer/beta1_power"].name)
   self.assertEqual("beta2_power:0",
                    named_variables["optimizer/beta2_power"].name)
   # Spot check the generated protocol buffers.
   self.assertEqual(0, serialized_graph.nodes[0].children[0].local_uid)
   self.assertEqual("optimizer",
                    serialized_graph.nodes[0].children[0].local_name)
   # Follow the root node's first child reference to the optimizer's node.
   optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
       0].node_id]
   self.assertEqual("beta1_power", optimizer_node.variables[0].local_name)
   self.assertEqual("beta1_power", optimizer_node.variables[0].full_name)
   self.assertEqual(
       "kernel", optimizer_node.slot_variables[0].original_variable_local_name)
   # Resolve the slot's back-reference to the node owning the slotted variable.
   original_variable_owner = serialized_graph.nodes[
       optimizer_node.slot_variables[0].original_variable_node_id]
   self.assertEqual("kernel", original_variable_owner.variables[0].local_name)
   self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
   # We strip off the :0 suffix, as variable.name-based saving does.
   self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam",
                    optimizer_node.slot_variables[0].full_name)
   self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam:0",
                    optimizer.get_slot(
                        var=named_variables["network/named_dense/kernel"],
                        name="m").name)
Code example #10
0
 def testNamingWithOptimizer(self):
   """Checks checkpoint keys for a Root holding a network and an optimizer.

   Relies on MyNetwork, CheckpointableAdam and Root — presumably defined
   elsewhere in this file; confirm their dependency names there. Verifies
   keys for named dependencies, non-slot and slot optimizer variables, and
   spot-checks the serialized object-graph protocol buffer.
   """
   input_value = constant_op.constant([[3.]])
   network = MyNetwork()
   # A nuisance Network using the same optimizer. Its slot variables should not
   # go in the checkpoint, since it is never depended on.
   other_network = MyNetwork()
   optimizer = CheckpointableAdam(0.001)
   root_checkpointable = Root(optimizer=optimizer, network=network)
   # One minimize per network so slot variables exist for both; only the
   # tracked network's slots should appear in the checkpoint.
   if context.in_eager_mode():
     optimizer.minimize(
         lambda: network(input_value),
         global_step=root_checkpointable.global_step)
     optimizer.minimize(
         lambda: other_network(input_value),
         global_step=root_checkpointable.global_step)
   else:
     train_op = optimizer.minimize(
         network(input_value), global_step=root_checkpointable.global_step)
     optimizer.minimize(
         other_network(input_value),
         global_step=root_checkpointable.global_step)
     self.evaluate(variables.global_variables_initializer())
     self.evaluate(train_op)
   named_variables, serialized_graph = checkpointable._serialize_object_graph(
       root_checkpointable)
   expected_checkpoint_names = (
       # Created in the root node, so no prefix.
       "global_step",
       # NOTE(review): "via_track_layer" looks like an explicit name given
       # when the layer was tracked (the positional "_N" fallback is unused
       # here) — confirm against MyNetwork's definition.
       "network/via_track_layer/kernel",
       # track_checkpointable() with a name provided, so that's used
       "network/_named_dense/kernel",
       "network/_named_dense/bias",
       # non-Layer dependency of the network
       "network/_non_layer/a_variable",
       # The optimizer creates two non-slot variables
       "_optimizer/beta1_power",
       "_optimizer/beta2_power",
       # Slot variables
       "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/m",
       "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/v",
       "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/m",
       "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/v",
       "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m",
       "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/v",
   )
   six.assertCountEqual(self, expected_checkpoint_names,
                        named_variables.keys())
   # Check that we've mapped to the right variable objects (not exhaustive)
   self.assertEqual("global_step:0", named_variables["global_step"].name)
   self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0",
                    named_variables["network/via_track_layer/kernel"].name)
   self.assertEqual("my_network/checkpointable_dense_layer/kernel:0",
                    named_variables["network/_named_dense/kernel"].name)
   self.assertEqual("beta1_power:0",
                    named_variables["_optimizer/beta1_power"].name)
   self.assertEqual("beta2_power:0",
                    named_variables["_optimizer/beta2_power"].name)
   # Spot check the generated protocol buffers.
   self.assertEqual("_optimizer",
                    serialized_graph.nodes[0].children[0].local_name)
   # Follow the root node's first child reference to the optimizer's node.
   optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
       0].node_id]
   self.assertEqual("beta1_power", optimizer_node.variables[0].local_name)
   self.assertEqual("beta1_power", optimizer_node.variables[0].full_name)
   # Variable ordering is arbitrary but deterministic (alphabetized)
   self.assertEqual(
       "bias", optimizer_node.slot_variables[0].original_variable_local_name)
   # Resolve the slot's back-reference to the node owning the slotted variable.
   original_variable_owner = serialized_graph.nodes[
       optimizer_node.slot_variables[0].original_variable_node_id]
   self.assertEqual("network/_named_dense/bias",
                    original_variable_owner.variables[0].checkpoint_key)
   self.assertEqual("bias", original_variable_owner.variables[0].local_name)
   self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
   self.assertEqual("network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m",
                    optimizer_node.slot_variables[0].checkpoint_key)
   # We strip off the :0 suffix, as variable.name-based saving does.
   self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam",
                    optimizer_node.slot_variables[0].full_name)
   self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam:0",
                    optimizer.get_slot(
                        var=named_variables["network/_named_dense/bias"],
                        name="m").name)