def testNumberedPath(self):
  """An unnamed dependency is keyed by its position ("_0") in checkpoints."""
  parent = checkpointable.Checkpointable()
  child = checkpointable.Checkpointable()
  # No name is passed, so the dependency's position becomes its path component.
  parent.track_checkpointable(child)
  child.add_variable(name="v", shape=[])
  keyed_variables, _unused_graph = checkpointable._serialize_object_graph(
      parent)
  (only_key,) = keyed_variables.keys()
  self.assertEqual("_0/v", only_key)
def _get_checkpoint_name(self, name):
  """Returns the checkpoint key produced for a single variable named `name`.

  The variable (shape [1, 2], float64) is the only one owned by a fresh
  Checkpointable, so exactly one key is expected. The key is also required to
  be usable as an op name once prefixed with "root/".
  """
  holder = checkpointable.Checkpointable()
  holder.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64)
  keyed_variables, _unused_graph = checkpointable._serialize_object_graph(
      holder)
  (key,) = keyed_variables.keys()
  # Entering the scope is only a validity check: the prefixed key must be
  # acceptable as an op name.
  with ops.name_scope("root/" + key):
    pass
  return key
def testNumberedPath(self):
  """A dependency tracked with an explicit name is keyed by that name.

  NOTE(review): despite the method name, this exercises the named (not the
  positional/numbered) dependency path.
  """
  parent = checkpointable.Checkpointable()
  child = checkpointable.Checkpointable()
  parent.track_checkpointable(child, name="leaf")
  child.add_variable(name="v", shape=[])
  keyed_variables, _unused_graph = checkpointable._serialize_object_graph(
      parent)
  (only_key,) = keyed_variables.keys()
  self.assertEqual("leaf/v", only_key)
def testAddVariable(self):
  """Checks add_variable initializers, .name values and checkpoint keys."""
  obj = NonLayerCheckpointable()

  # Passing both a shape and a value initializer is rejected.
  with self.assertRaisesRegexp(ValueError, "do not specify shape"):
    obj.add_variable(name="shape_specified_twice", shape=[], initializer=1)

  const_var = obj.add_variable(name="constant_initializer", initializer=1)
  with variable_scope.variable_scope("some_variable_scope"):
    ones_var = obj.add_variable(
        name="ones_initializer",
        shape=[2],
        initializer=init_ops.ones_initializer(dtype=dtypes.float32))
  zeros_var = obj.add_variable(
      name="bare_initializer",
      shape=[2, 2],
      dtype=dtypes.float64,
      initializer=init_ops.zeros_initializer)

  # Names only conflict within one object, not across objects (even in graph
  # mode): a standalone "duplicate" variable does not block obj from adding
  # its own, but a second "duplicate" on obj is an error.
  other_duplicate = resource_variable_ops.ResourceVariable(
      name="duplicate", initial_value=1.)
  duplicate = obj.add_variable(name="duplicate", shape=[])
  with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
    obj.add_variable(name="duplicate", shape=[])

  if context.in_graph_mode():
    self.evaluate(variables.global_variables_initializer())
  self.assertEqual("constant_initializer:0", const_var.name)
  self.assertEqual(1, self.evaluate(const_var))
  self.assertEqual("some_variable_scope/ones_initializer:0", ones_var.name)
  self.assertAllEqual([1, 1], self.evaluate(ones_var))
  self.assertAllEqual([[0., 0.], [0., 0.]], self.evaluate(zeros_var))
  self.assertEqual("a_variable:0", obj.a_variable.name)
  self.assertEqual("duplicate:0", other_duplicate.name)

  # In graph mode .name is uniquified graph-wide; eager execution does not
  # uniquify. The checkpoint key (checked below) is unaffected either way.
  expected_duplicate_name = (
      "duplicate_1:0" if context.in_graph_mode() else "duplicate:0")
  self.assertEqual(expected_duplicate_name, duplicate.name)

  keyed_variables, _unused_graph = checkpointable._serialize_object_graph(obj)
  six.assertCountEqual(
      self,
      ("a_variable", "bare_initializer", "constant_initializer", "duplicate",
       "ones_initializer"),
      keyed_variables.keys())
def testAddVariable(self):
  """Checks add_variable initializers, .name values and checkpoint keys."""
  obj = NonLayerCheckpointable()

  # Passing both a shape and a value initializer is rejected.
  with self.assertRaisesRegexp(ValueError, "do not specify shape"):
    obj.add_variable(name="shape_specified_twice", shape=[], initializer=1)

  const_var = obj.add_variable(name="constant_initializer", initializer=1)
  with variable_scope.variable_scope("some_variable_scope"):
    ones_var = obj.add_variable(
        name="ones_initializer",
        shape=[2],
        initializer=init_ops.ones_initializer(dtype=dtypes.float32))
  zeros_var = obj.add_variable(
      name="bare_initializer",
      shape=[2, 2],
      dtype=dtypes.float64,
      initializer=init_ops.zeros_initializer)

  # Names only conflict within one object, not across objects (even in graph
  # mode): a standalone "duplicate" variable does not block obj from adding
  # its own, but a second "duplicate" on obj is an error.
  other_duplicate = resource_variable_ops.ResourceVariable(
      name="duplicate", initial_value=1.)
  duplicate = obj.add_variable(name="duplicate", shape=[])
  with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
    obj.add_variable(name="duplicate", shape=[])

  if context.in_graph_mode():
    self.evaluate(variables.global_variables_initializer())
  self.assertEqual("constant_initializer:0", const_var.name)
  self.assertEqual(1, self.evaluate(const_var))
  self.assertEqual("some_variable_scope/ones_initializer:0", ones_var.name)
  self.assertAllEqual([1, 1], self.evaluate(ones_var))
  self.assertAllEqual([[0., 0.], [0., 0.]], self.evaluate(zeros_var))
  self.assertEqual("a_variable:0", obj.a_variable.name)
  self.assertEqual("duplicate:0", other_duplicate.name)

  # In graph mode .name is uniquified graph-wide; eager execution does not
  # uniquify. The checkpoint key (checked below) is unaffected either way.
  expected_duplicate_name = (
      "duplicate_1:0" if context.in_graph_mode() else "duplicate:0")
  self.assertEqual(expected_duplicate_name, duplicate.name)

  keyed_variables, _unused_graph = checkpointable._serialize_object_graph(obj)
  six.assertCountEqual(
      self,
      ("a_variable", "bare_initializer", "constant_initializer", "duplicate",
       "ones_initializer"),
      keyed_variables.keys())
def _get_checkpoint_name(self, name):
  """Returns the checkpoint key produced for a single variable named `name`.

  The variable is created inside a variable scope because that gives more
  relaxed naming rules: variables outside a scope may not start with "_", "/"
  or "-". The scope part of the name is not used for the checkpoint key. The
  returned key must also be usable as an op name once prefixed with "root/".
  """
  holder = checkpointable.Checkpointable()
  with variable_scope.variable_scope("get_checkpoint_name"):
    holder.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64)
  keyed_variables, _unused_graph = checkpointable._serialize_object_graph(
      holder)
  (key,) = keyed_variables.keys()
  # Entering the scope is only a validity check: the prefixed key must be
  # acceptable as an op name.
  with ops.name_scope("root/" + key):
    pass
  return key
def testNamingWithOptimizer(self):
  """Checks checkpoint keys and the serialized graph for a network + Adam."""
  input_value = constant_op.constant([[3.]])
  network = MyNetwork()
  # A nuisance Network sharing the same optimizer. Nothing in the checkpoint
  # depends on it, so its slot variables must not be saved.
  other_network = MyNetwork()
  optimizer = CheckpointableAdam(0.001)
  root_checkpointable = Root(optimizer=optimizer, network=network)
  if context.in_eager_mode():
    optimizer.minimize(
        lambda: network(input_value),
        global_step=root_checkpointable.global_step)
    optimizer.minimize(
        lambda: other_network(input_value),
        global_step=root_checkpointable.global_step)
  else:
    train_op = optimizer.minimize(
        network(input_value), global_step=root_checkpointable.global_step)
    optimizer.minimize(
        other_network(input_value),
        global_step=root_checkpointable.global_step)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(train_op)

  named_variables, serialized_graph = checkpointable._serialize_object_graph(
      root_checkpointable)
  expected_checkpoint_names = (
      # Owned directly by the root object, so there is no path prefix.
      "global_step",
      # Tracked without a name: its position (1, following the named
      # dependency at position 0) is the path component.
      "network/_1/kernel",
      # Tracked with a name: that name is the path component.
      "network/named_dense/kernel",
      "network/named_dense/bias",
      # Non-slot variables created by the optimizer.
      "optimizer/beta1_power",
      "optimizer/beta2_power",
      # Slot variables.
      "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/m",
      "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/v",
      "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/m",
      "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/v",
      "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m",
      "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/v",
  )
  six.assertCountEqual(self, expected_checkpoint_names, named_variables.keys())

  # Non-exhaustive check that keys map to the right variable objects.
  spot_checks = (
      ("global_step:0", "global_step"),
      ("my_network/checkpointable_dense_layer_1/kernel:0",
       "network/_1/kernel"),
      ("my_network/checkpointable_dense_layer/kernel:0",
       "network/named_dense/kernel"),
      ("beta1_power:0", "optimizer/beta1_power"),
      ("beta2_power:0", "optimizer/beta2_power"),
  )
  for expected_name, checkpoint_key in spot_checks:
    self.assertEqual(expected_name, named_variables[checkpoint_key].name)

  # Spot-check the generated protocol buffers.
  first_child = serialized_graph.nodes[0].children[0]
  self.assertEqual(0, first_child.local_uid)
  self.assertEqual("optimizer", first_child.local_name)
  optimizer_node = serialized_graph.nodes[first_child.node_id]
  self.assertEqual("beta1_power", optimizer_node.variables[0].local_name)
  self.assertEqual("beta1_power", optimizer_node.variables[0].full_name)
  first_slot = optimizer_node.slot_variables[0]
  self.assertEqual("kernel", first_slot.original_variable_local_name)
  original_variable_owner = serialized_graph.nodes[
      first_slot.original_variable_node_id]
  self.assertEqual("kernel", original_variable_owner.variables[0].local_name)
  self.assertEqual("m", first_slot.slot_name)
  # The ":0" suffix is stripped, as variable.name-based saving does.
  self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam",
                   first_slot.full_name)
  self.assertEqual(
      "my_network/checkpointable_dense_layer/kernel/Adam:0",
      optimizer.get_slot(
          var=named_variables["network/named_dense/kernel"], name="m").name)
def testNamingWithOptimizer(self):
  """Checks checkpoint keys and the serialized graph for a network + Adam."""
  input_value = constant_op.constant([[3.]])
  network = MyNetwork()
  # A nuisance Network sharing the same optimizer. Nothing in the checkpoint
  # depends on it, so its slot variables must not be saved.
  other_network = MyNetwork()
  optimizer = CheckpointableAdam(0.001)
  root_checkpointable = Root(optimizer=optimizer, network=network)
  if context.in_eager_mode():
    optimizer.minimize(
        lambda: network(input_value),
        global_step=root_checkpointable.global_step)
    optimizer.minimize(
        lambda: other_network(input_value),
        global_step=root_checkpointable.global_step)
  else:
    train_op = optimizer.minimize(
        network(input_value), global_step=root_checkpointable.global_step)
    optimizer.minimize(
        other_network(input_value),
        global_step=root_checkpointable.global_step)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(train_op)

  named_variables, serialized_graph = checkpointable._serialize_object_graph(
      root_checkpointable)
  expected_checkpoint_names = (
      # Owned directly by the root object, so there is no path prefix.
      "global_step",
      # Keyed by the explicit dependency name "via_track_layer", not by a
      # positional index. NOTE(review): presumably registered in MyNetwork via
      # track_layer(..., name="via_track_layer") -- confirm there.
      "network/via_track_layer/kernel",
      # Dependency named "_named_dense": its name is the path component.
      "network/_named_dense/kernel",
      "network/_named_dense/bias",
      # Non-Layer dependency of the network.
      "network/_non_layer/a_variable",
      # Non-slot variables created by the optimizer.
      "_optimizer/beta1_power",
      "_optimizer/beta2_power",
      # Slot variables.
      "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/m",
      "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/v",
      "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/m",
      "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/v",
      "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m",
      "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/v",
  )
  six.assertCountEqual(self, expected_checkpoint_names, named_variables.keys())

  # Non-exhaustive check that keys map to the right variable objects.
  self.assertEqual("global_step:0", named_variables["global_step"].name)
  self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0",
                   named_variables["network/via_track_layer/kernel"].name)
  self.assertEqual("my_network/checkpointable_dense_layer/kernel:0",
                   named_variables["network/_named_dense/kernel"].name)
  self.assertEqual("beta1_power:0",
                   named_variables["_optimizer/beta1_power"].name)
  self.assertEqual("beta2_power:0",
                   named_variables["_optimizer/beta2_power"].name)

  # Spot-check the generated protocol buffers.
  self.assertEqual("_optimizer",
                   serialized_graph.nodes[0].children[0].local_name)
  optimizer_node = serialized_graph.nodes[
      serialized_graph.nodes[0].children[0].node_id]
  self.assertEqual("beta1_power", optimizer_node.variables[0].local_name)
  self.assertEqual("beta1_power", optimizer_node.variables[0].full_name)
  # Variable ordering is arbitrary but deterministic (alphabetized).
  self.assertEqual(
      "bias", optimizer_node.slot_variables[0].original_variable_local_name)
  original_variable_owner = serialized_graph.nodes[
      optimizer_node.slot_variables[0].original_variable_node_id]
  self.assertEqual("network/_named_dense/bias",
                   original_variable_owner.variables[0].checkpoint_key)
  self.assertEqual("bias", original_variable_owner.variables[0].local_name)
  self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
  self.assertEqual("network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m",
                   optimizer_node.slot_variables[0].checkpoint_key)
  # The ":0" suffix is stripped, as variable.name-based saving does.
  self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam",
                   optimizer_node.slot_variables[0].full_name)
  self.assertEqual(
      "my_network/checkpointable_dense_layer/bias/Adam:0",
      optimizer.get_slot(
          var=named_variables["network/_named_dense/bias"], name="m").name)