def testInitFromCheckpoint(self):
  """Restores variables by full name, by scope prefix and by Variable key.

  The test makes two `checkpoints.init_from_checkpoint` calls against a
  checkpoint written by `_create_checkpoints`:
    * "some_scope/my1" -> "var1", plus the nested scope
      "some_scope/some_other_scope/other_useful_scope/" -> "useful_scope/";
    * "some_scope/some_other_scope/my2" -> "var2", plus the Variable object
      `my3` itself used as the key -> "var3".
  Each restored variable must equal its saved value. The serialized graph
  must stay under a size bound, i.e. checkpoint tensors are not stored in it
  explicitly.

  NOTE(review): this test name is defined again later in this SOURCE. If both
  definitions share one class body, only the last one is collected by the
  test runner. Consider renaming or deleting one of them.
  """
  ckpt_dir = self.get_temp_dir()
  with self.test_session() as writer_sess:
    v1, v2, v3, v4 = _create_checkpoints(writer_sess, ckpt_dir)

  # Restore side: a brand-new graph and session.
  with tf.Graph().as_default() as graph:
    with self.test_session(graph=graph) as sess:
      with tf.variable_scope("some_scope"):
        my1 = tf.get_variable("my1", [1, 10])
        with tf.variable_scope("some_other_scope"):
          my2 = tf.get_variable("my2", [10, 10])
          with tf.variable_scope("other_useful_scope"):
            my4 = tf.get_variable("var4", [9, 9])
      my3 = tf.get_variable("my3", [100, 100])

      by_name_and_scope = {
          "some_scope/my1": "var1",
          "some_scope/some_other_scope/other_useful_scope/": "useful_scope/",
      }
      checkpoints.init_from_checkpoint(ckpt_dir, by_name_and_scope)

      # Mixes a string key with a Variable-object key.
      by_name_and_variable = {
          "some_scope/some_other_scope/my2": "var2",
          my3: "var3",
      }
      checkpoints.init_from_checkpoint(ckpt_dir, by_name_and_variable)

      sess.run(tf.initialize_all_variables())
      for restored, expected in ((my1, v1), (my2, v2), (my3, v3), (my4, v4)):
        self.assertAllEqual(restored.eval(sess), expected)

      # Checkpoint tensors must not be embedded in the graph explicitly;
      # bounding the GraphDef's text size checks that.
      graph_def_text = str(sess.graph.as_graph_def())
      self.assertLess(len(graph_def_text), 26000)
def testInitFromPartitionVar(self):
  """Restores a partitioned variable from a checkpoint via its scoped name.

  A checkpoint is written with `_create_partition_checkpoints`. A fresh graph
  then declares "some_scope/my1" as a [100, 100] partitioned variable
  (slicing=[5, 1]) and maps it to the checkpoint entry "var1". After
  initialization, the fetched partition values must equal the saved value.

  NOTE(review): this test name is defined again later in this SOURCE. If both
  definitions share one class body, only the last one is collected.
  """
  ckpt_dir = self.get_temp_dir()
  with self.test_session() as writer_sess:
    saved = _create_partition_checkpoints(writer_sess, ckpt_dir)

  # Restore side: a brand-new graph and session.
  with tf.Graph().as_default() as graph:
    with self.test_session(graph=graph) as sess:
      with tf.variable_scope("some_scope"):
        # TODO(ipolosukhin): Switch to
        #   tf.get_variable("my1", [100, 100],
        #                   partitioner=tf.variable_axis_size_partitioner(
        #                       axis=0, max_shard_bytes=100))
        # once get_variable with a partitioner returns a Variable. It
        # currently returns a concat op instead.
        parts = tf.create_partitioned_variables(
            shape=[100, 100], slicing=[5, 1], name="my1",
            initializer=tf.truncated_normal_initializer(0.5))

      restore_map = {"some_scope/my1": "var1"}
      checkpoints.init_from_checkpoint(ckpt_dir, restore_map)

      sess.run(tf.initialize_all_variables())
      fetched = sess.run(parts)
      self.assertAllEqual(fetched, saved)
def testInitFromCheckpoint(self):
  """Restores variables by full name and by scope prefix from a checkpoint.

  The test makes two `checkpoints.init_from_checkpoint` calls against a
  checkpoint written by `_create_checkpoints`:
    * "some_scope/my1" -> "var1", plus the nested scope
      "some_scope/some_other_scope/other_useful_scope/" -> "useful_scope/";
    * "some_scope/some_other_scope/my2" -> "var2", plus the root-level
      "my3" -> "var3".
  Each restored variable must equal its saved value. The serialized graph
  must stay under a size bound, i.e. checkpoint tensors are not stored in it
  explicitly.
  """
  ckpt_dir = self.get_temp_dir()
  with self.test_session() as writer_sess:
    v1, v2, v3, v4 = _create_checkpoints(writer_sess, ckpt_dir)

  # Restore side: a brand-new graph and session.
  with tf.Graph().as_default() as graph:
    with self.test_session(graph=graph) as sess:
      with tf.variable_scope("some_scope"):
        my1 = tf.get_variable("my1", [1, 10])
        with tf.variable_scope("some_other_scope"):
          my2 = tf.get_variable("my2", [10, 10])
          with tf.variable_scope("other_useful_scope"):
            my4 = tf.get_variable("var4", [9, 9])
      # Declared outside every variable scope so that "my3" below names it.
      my3 = tf.get_variable("my3", [100, 100])

      first_map = {
          "some_scope/my1": "var1",
          "some_scope/some_other_scope/other_useful_scope/": "useful_scope/",
      }
      second_map = {
          "some_scope/some_other_scope/my2": "var2",
          "my3": "var3",
      }
      for assignment_map in (first_map, second_map):
        checkpoints.init_from_checkpoint(ckpt_dir, assignment_map)

      sess.run(tf.initialize_all_variables())
      pairs = [(my1, v1), (my2, v2), (my3, v3), (my4, v4)]
      for var, want in pairs:
        self.assertAllEqual(var.eval(sess), want)

      # Checkpoint tensors must not be embedded in the graph explicitly.
      self.assertLess(len(str(sess.graph.as_graph_def())), 22000)
def testInitFromPartitionVar(self):
  """Restores a partitioned variable from a checkpoint via its scoped name.

  "some_scope/my1" is declared as a [100, 100] partitioned variable
  (slicing=[5, 1]) in a fresh graph and mapped to the checkpoint entry
  "var1" written by `_create_partition_checkpoints`. After initialization,
  the fetched partition values must equal the saved value.
  """
  temp_dir = self.get_temp_dir()
  with self.test_session() as save_session:
    expected = _create_partition_checkpoints(save_session, temp_dir)

  new_graph = tf.Graph()
  with new_graph.as_default() as g:
    with self.test_session(graph=g) as restore_session:
      with tf.variable_scope("some_scope"):
        # TODO(ipolosukhin): Use tf.get_variable with
        # partitioner=tf.variable_axis_size_partitioner(axis=0,
        # max_shard_bytes=100) once it returns a Variable rather than a
        # concat op.
        my1_parts = tf.create_partitioned_variables(
            shape=[100, 100],
            slicing=[5, 1],
            name="my1",
            initializer=tf.truncated_normal_initializer(0.5))

      checkpoints.init_from_checkpoint(temp_dir, {"some_scope/my1": "var1"})

      restore_session.run(tf.initialize_all_variables())
      self.assertAllEqual(restore_session.run(my1_parts), expected)
def testInitFromRootCheckpoint(self):
  """Maps a whole graph scope onto the root of a checkpoint.

  With the assignment map {"some_scope/": "/"}, each variable declared under
  "some_scope" (var1, var2, var3 and useful_scope/var4) must end up equal to
  the value `_create_checkpoints` saved under the same relative name.

  NOTE(review): this test name is defined again later in this SOURCE. If both
  definitions share one class body, only the last one is collected.
  """
  ckpt_dir = self.get_temp_dir()
  with self.test_session() as writer_sess:
    v1, v2, v3, v4 = _create_checkpoints(writer_sess, ckpt_dir)

  # Restore side: a brand-new graph and session.
  with tf.Graph().as_default() as graph:
    with self.test_session(graph=graph) as sess:
      with tf.variable_scope("some_scope"):
        restored = [
            tf.get_variable("var1", [1, 10]),
            tf.get_variable("var2", [10, 10]),
            tf.get_variable("var3", [100, 100]),
        ]
        with tf.variable_scope("useful_scope"):
          restored.append(tf.get_variable("var4", [9, 9]))

      root_map = {"some_scope/": "/"}
      checkpoints.init_from_checkpoint(ckpt_dir, root_map)

      sess.run(tf.initialize_all_variables())
      for variable, value in zip(restored, (v1, v2, v3, v4)):
        self.assertAllEqual(variable.eval(sess), value)
def testInitFromRootCheckpoint(self):
  """Maps a whole graph scope onto the root of a checkpoint.

  The single assignment-map entry {"some_scope/": "/"} is expected to restore
  some_scope/var1..var3 and some_scope/useful_scope/var4 from the checkpoint
  entries of the same relative names.
  """
  temp_dir = self.get_temp_dir()
  with self.test_session() as save_session:
    v1, v2, v3, v4 = _create_checkpoints(save_session, temp_dir)

  new_graph = tf.Graph()
  with new_graph.as_default() as g:
    with self.test_session(graph=g) as restore_session:
      with tf.variable_scope("some_scope"):
        var1 = tf.get_variable("var1", [1, 10])
        var2 = tf.get_variable("var2", [10, 10])
        var3 = tf.get_variable("var3", [100, 100])
        with tf.variable_scope("useful_scope"):
          var4 = tf.get_variable("var4", [9, 9])

      checkpoints.init_from_checkpoint(temp_dir, {"some_scope/": "/"})

      restore_session.run(tf.initialize_all_variables())
      expectations = [(var1, v1), (var2, v2), (var3, v3), (var4, v4)]
      for var, value in expectations:
        self.assertAllEqual(var.eval(restore_session), value)
def testInitFromCheckpointMissing(self):
  """Checks that invalid assignment maps are rejected by init_from_checkpoint.

  Cases, in order:
    * a checkpoint directory that does not exist -> tf.errors.OpError;
    * a checkpoint name that is absent from the checkpoint -> ValueError;
    * a graph name that is absent from the graph -> ValueError;
    * a shape mismatch between graph variable and checkpoint entry
      -> ValueError;
    * a dtype mismatch between graph variable and checkpoint entry
      -> ValueError;
    * a scope mapping whose checkpoint scope lacks the graph's variables
      -> ValueError;
    * a scope key mapped to a name that is not a scope (no trailing "/")
      -> ValueError.

  NOTE(review): this test name is defined again later in this SOURCE. If both
  definitions share one class body, only the last one is collected.
  """
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    _, _, _, _ = _create_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with tf.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      with tf.variable_scope("some_scope"):
        # [10, 10] deliberately differs in shape from the checkpoint's "var1".
        _ = tf.get_variable("my1", [10, 10])
        # int64 with a [1, 10] shape: mapped to "var1" below, so only the
        # dtype should disagree.
        _ = tf.get_variable("my2", [1, 10], dtype=tf.int64,
                            initializer=tf.zeros_initializer)

      # No directory.
      with self.assertRaises(tf.errors.OpError):
        checkpoints.init_from_checkpoint(
            "no_dir", {"some_scope/my1": "var1"})

      # No variable in checkpoint.
      with self.assertRaises(ValueError):
        checkpoints.init_from_checkpoint(
            checkpoint_dir, {"some_scope/my1": "no_var"})

      # No variable in the graph.
      with self.assertRaises(ValueError):
        checkpoints.init_from_checkpoint(
            checkpoint_dir, {"some_scope/no_var": "var3"})

      # Shape mismatch.
      with self.assertRaises(ValueError):
        checkpoints.init_from_checkpoint(
            checkpoint_dir, {"some_scope/my1": "var1"})

      # DType mismatch: the int64 "my2" has var1's shape but not its dtype.
      with self.assertRaises(ValueError):
        checkpoints.init_from_checkpoint(
            checkpoint_dir, {"some_scope/my2": "var1"})

      # Variable 'my1' and 'my2' are missing in given checkpoint scope.
      with self.assertRaises(ValueError):
        checkpoints.init_from_checkpoint(
            checkpoint_dir, {"some_scope/": "useful_scope/"})

      # Mapping is not to scope name.
      with self.assertRaises(ValueError):
        checkpoints.init_from_checkpoint(
            checkpoint_dir, {"some_scope/": "useful_scope"})
def testInitFromCheckpointMissing(self):
  """Checks that invalid assignment maps are rejected by init_from_checkpoint.

  Each row of `bad_requests` is (expected exception, checkpoint location,
  assignment map). The rows are exercised in order, each inside its own
  `assertRaises` context.
  """
  ckpt_dir = self.get_temp_dir()
  with self.test_session() as writer_sess:
    _, _, _, _ = _create_checkpoints(writer_sess, ckpt_dir)

  # Restore side: a brand-new graph and session.
  with tf.Graph().as_default() as graph:
    with self.test_session(graph=graph):
      with tf.variable_scope("some_scope"):
        tf.get_variable("my1", [10, 10])
        tf.get_variable("my2", [1, 10], dtype=tf.int64,
                        initializer=tf.zeros_initializer)

      bad_requests = [
          # No directory.
          (tf.errors.OpError, "no_dir", {"some_scope/my1": "var1"}),
          # No variable in checkpoint.
          (ValueError, ckpt_dir, {"some_scope/my1": "no_var"}),
          # No variable in the graph.
          (ValueError, ckpt_dir, {"some_scope/no_var": "var3"}),
          # Shape mismatch.
          (ValueError, ckpt_dir, {"some_scope/my1": "var1"}),
          # DType mismatch.
          (ValueError, ckpt_dir, {"some_scope/my2": "var1"}),
          # Variables 'my1' and 'my2' are missing in the checkpoint scope.
          (ValueError, ckpt_dir, {"some_scope/": "useful_scope/"}),
          # Mapping target is not a scope name.
          (ValueError, ckpt_dir, {"some_scope/": "useful_scope"}),
      ]
      for expected_error, location, assignment_map in bad_requests:
        with self.assertRaises(expected_error):
          checkpoints.init_from_checkpoint(location, assignment_map)