def testInitFromCheckpoint(self):
    """Restore variables through name, scope and variable-object mappings.

    A checkpoint is written by `_create_checkpoints`; a fresh graph then
    declares differently named/scoped variables, maps checkpoint entries onto
    them, and compares the evaluated variables with what
    `_create_checkpoints` returned.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        v1, v2, v3, v4 = _create_checkpoints(session, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
            my1 = variable_scope.get_variable("my1", [1, 10])
            with variable_scope.variable_scope("some_other_scope"):
                my2 = variable_scope.get_variable("my2", [10, 10])
                with variable_scope.variable_scope("other_useful_scope"):
                    my4 = variable_scope.get_variable("var4", [9, 9])
        my3 = variable_scope.get_variable("my3", [100, 100])

        # First call: a tensor-name mapping plus a scope-to-scope mapping.
        name_and_scope_map = {
            "var1": "some_scope/my1",
            "useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
        }
        checkpoint_utils.init_from_checkpoint(ckpt_dir, name_and_scope_map)

        # Second call: a tensor-name mapping plus a variable-object mapping.
        name_and_object_map = {
            "var2": "some_scope/some_other_scope/my2",
            "var3": my3,
        }
        checkpoint_utils.init_from_checkpoint(ckpt_dir, name_and_object_map)

        session.run(variables.global_variables_initializer())
        restored_and_expected = ((my1, v1), (my2, v2), (my3, v3), (my4, v4))
        for restored, expected in restored_and_expected:
            self.assertAllEqual(restored.eval(session), expected)

        # Check that tensors are not explicitly in the graph.
        serialized_graph = str(session.graph.as_graph_def())
        self.assertLess(len(serialized_graph), 27000)
def begin(self):
    """Set up embedding initialization from `self._init_checkpoint`.

    Returns immediately when `self._init_checkpoint` is None. Otherwise the
    assignment map is built with `create_embedding_map`, logged, and handed
    to `checkpoint_utils.init_from_checkpoint` together with the checkpoint.
    """
    if self._init_checkpoint is None:
        return

    init_map = create_embedding_map(self._init_checkpoint)
    tf.logging.info("embeddings to be initialized from {}: {}".format(
        self._init_checkpoint, init_map))
    checkpoint_utils.init_from_checkpoint(self._init_checkpoint, init_map)
def _maybe_restore_from_checkpoint(checkpoint_path, variable):
    """Initialize `variable` from a checkpoint tensor when a source is given.

    Args:
      checkpoint_path: None (nothing is done), or a pair that unpacks as
        `(path, tensor_name)`.
      variable: A sized, indexable collection. When it holds exactly one
        element that element is passed on its own; otherwise the collection
        is passed unchanged. (Presumably the partitions of one variable --
        confirm against callers.)
    """
    if checkpoint_path is None:
        return

    ckpt_location, ckpt_tensor_name = checkpoint_path
    target = variable[0] if len(variable) == 1 else variable
    checkpoint_utils.init_from_checkpoint(
        ckpt_location, {ckpt_tensor_name: target})
def testInitFromPartitionVar(self):
    """Restore partitioned variables by name, by scope and by variable list.

    The first graph maps one checkpoint tensor to a variable name and one
    checkpoint scope to a graph scope; the second graph maps the checkpoint
    tensor directly onto the list of partition variables. In every case the
    evaluated partitions are compared with what
    `_create_partition_checkpoints` returned.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        v1 = _create_partition_checkpoints(session, checkpoint_dir)

    def make_partitioned(var_name):
        # All variables in this test share shape, initializer and partitioner.
        return variable_scope.get_variable(
            name=var_name,
            shape=[100, 100],
            initializer=init_ops.truncated_normal_initializer(0.5),
            partitioner=partitioned_variables.min_max_variable_partitioner(
                max_partitions=5, axis=0, min_slice_size=8 << 10))

    # New graph and session.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
            my1 = make_partitioned("my1")
            my1_var_list = my1._get_variable_list()
        with variable_scope.variable_scope("some_other_scope"):
            my2 = make_partitioned("var1")
            my2_var_list = my2._get_variable_list()

        name_and_scope_map = {
            "scope/var1": "some_scope/my1",
            "scope/": "some_other_scope/",
        }
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, name_and_scope_map)

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(session.run(my1_var_list), v1)
        self.assertAllEqual(session.run(my2_var_list), v1)

    # New graph and session.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
            my1 = make_partitioned("my1")
            my1_var_list = my1._get_variable_list()

        checkpoint_utils.init_from_checkpoint(
            checkpoint_dir, {"scope/var1": my1_var_list})

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(session.run(my1_var_list), v1)
def testInitFromPartitionVar(self):
    """Restore a partitioned variable by name and by partition list.

    The same partitioned variable is built in two fresh graphs. The first
    maps checkpoint tensor "var1" to the variable's name, the second maps it
    to the list of partition variables; both must evaluate to what
    `_create_partition_checkpoints` returned.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        v1 = _create_partition_checkpoints(session, checkpoint_dir)

    # First pass maps to the variable name, second pass to the partitions.
    for map_to_partition_list in (False, True):
        # New graph and session.
        with ops.Graph().as_default() as g:
            with self.test_session(graph=g) as session:
                with variable_scope.variable_scope("some_scope"):
                    partitioner = (
                        partitioned_variables.min_max_variable_partitioner(
                            max_partitions=5, axis=0, min_slice_size=8 << 10))
                    my1 = variable_scope.get_variable(
                        name="my1",
                        shape=[100, 100],
                        initializer=init_ops.truncated_normal_initializer(0.5),
                        partitioner=partitioner)
                    my1_var_list = my1._get_variable_list()

                if map_to_partition_list:
                    target = my1_var_list
                else:
                    target = "some_scope/my1"
                checkpoint_utils.init_from_checkpoint(
                    checkpoint_dir, {"var1": target})

                session.run(variables.global_variables_initializer())
                my1_values = session.run(my1_var_list)
                self.assertAllEqual(my1_values, v1)
def testInitWithScopeDoesNotCaptureSuffixes(self):
    """A scope mapping must not touch scopes that merely share its prefix.

    "useful_scope/" is mapped onto itself; the variable in "useful_scope"
    must match the checkpoint value, while the variable in "useful_scope_1"
    must keep its own initializer value.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
        _, _, _, v4 = _create_checkpoints(session, checkpoint_dir)

    untouched_init = [[1.0, 2.0], [3.0, 4.0]]
    with ops.Graph().as_default() as g:
        with variable_scope.variable_scope("useful_scope"):
            restored_var = variable_scope.get_variable("var4", [9, 9])
        with variable_scope.variable_scope("useful_scope_1"):
            untouched_var = variable_scope.get_variable(
                "var5", initializer=untouched_init)

        scope_map = {"useful_scope/": "useful_scope/"}
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, scope_map)

        with self.session(graph=g) as session:
            session.run(variables.global_variables_initializer())
            self.assertAllEqual(restored_var.eval(session), v4)
            self.assertAllEqual(untouched_var.eval(session), untouched_init)
def testInitToRootCheckpoint(self):
    """Restore every variable with a single root-to-root ("/": "/") mapping.

    The fresh graph declares variables under the same names that the
    assignments below expect from the checkpoint, then compares the evaluated
    variables with what `_create_checkpoints` returned.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        v1, v2, v3, v4 = _create_checkpoints(session, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        my1 = variable_scope.get_variable("var1", [1, 10])
        my2 = variable_scope.get_variable("var2", [10, 10])
        my3 = variable_scope.get_variable("var3", [100, 100])
        with variable_scope.variable_scope("useful_scope"):
            my4 = variable_scope.get_variable("var4", [9, 9])

        root_map = {"/": "/"}
        checkpoint_utils.init_from_checkpoint(ckpt_dir, root_map)

        session.run(variables.global_variables_initializer())
        for restored, expected in ((my1, v1), (my2, v2), (my3, v3), (my4, v4)):
            self.assertAllEqual(restored.eval(session), expected)
def testInitToRootCheckpoint(self):
    """Check that the mapping {"/": "/"} initializes all declared variables.

    Three top-level variables and one under "useful_scope" are created in a
    new graph; after initialization each is compared with the corresponding
    value returned by `_create_checkpoints`.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
        expected_values = _create_checkpoints(session, checkpoint_dir)
        v1, v2, v3, v4 = expected_values

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.session(graph=g) as session:
            top_level = [
                variable_scope.get_variable("var1", [1, 10]),
                variable_scope.get_variable("var2", [10, 10]),
                variable_scope.get_variable("var3", [100, 100]),
            ]
            with variable_scope.variable_scope("useful_scope"):
                scoped = variable_scope.get_variable("var4", [9, 9])

            checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"/": "/"})

            session.run(variables.global_variables_initializer())
            my1, my2, my3 = top_level
            self.assertAllEqual(my1.eval(session), v1)
            self.assertAllEqual(my2.eval(session), v2)
            self.assertAllEqual(my3.eval(session), v3)
            self.assertAllEqual(scoped.eval(session), v4)
def testInitFromCheckpoint(self):
    """Initialize re-scoped variables from a checkpoint in two mapping calls.

    The mappings cover a tensor name to variable name, a checkpoint scope to
    a nested graph scope, and a tensor name to a variable object. Evaluated
    variables are compared with what `_create_checkpoints` returned, and the
    serialized graph is required to stay below a size limit.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
        v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g, self.session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
            my1 = variable_scope.get_variable("my1", [1, 10])
            with variable_scope.variable_scope("some_other_scope"):
                my2 = variable_scope.get_variable("my2", [10, 10])
                with variable_scope.variable_scope("other_useful_scope"):
                    my4 = variable_scope.get_variable("var4", [9, 9])
        my3 = variable_scope.get_variable("my3", [100, 100])

        nested_scope = "some_scope/some_other_scope/other_useful_scope/"
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "var1": "some_scope/my1",
            "useful_scope/": nested_scope,
        })
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "var2": "some_scope/some_other_scope/my2",
            "var3": my3,
        })

        session.run(variables.global_variables_initializer())
        for variable, expected in zip((my1, my2, my3, my4), (v1, v2, v3, v4)):
            self.assertAllEqual(variable.eval(session), expected)

        # Check that tensors are not explicitly in the graph.
        graph_def_text = str(session.graph.as_graph_def())
        self.assertLess(len(graph_def_text), 27000)
def testInitFromCheckpointMissing(self):
    """Invalid checkpoint locations and assignment maps must raise.

    A nonexistent checkpoint location is expected to raise
    `errors_impl.OpError`; each of the listed invalid assignment maps is
    expected to raise `ValueError`.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
        _, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g, self.session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
            variable_scope.get_variable("my1", [10, 10])
            variable_scope.get_variable(
                "my2", [1, 10],
                dtype=dtypes.int64,
                initializer=init_ops.zeros_initializer())

        # No directory.
        with self.assertRaises(errors_impl.OpError):
            checkpoint_utils.init_from_checkpoint(
                "no_dir", {"var1": "some_scope/my1"})

        invalid_assignment_maps = (
            # No variable in checkpoint.
            {"no_var": "some_scope/my1"},
            # No variable in the graph.
            {"var3": "some_scope/no_var"},
            # Shape mismatch.
            {"var1": "some_scope/my1"},
            # Variable 'my1' and 'my2' are missing in given checkpoint scope.
            {"useful_scope/": "some_scope/"},
            # Mapping is not to scope name.
            {"useful_scope": "some_scope/"},
        )
        for bad_map in invalid_assignment_maps:
            with self.assertRaises(ValueError):
                checkpoint_utils.init_from_checkpoint(checkpoint_dir, bad_map)
def init_from_checkpoint(checkpoint_dir, assignment_map):
    """See `tf.contrib.framework.init_from_checkpoint`.

    Forwards both arguments, unchanged and positionally, to
    `checkpoint_utils.init_from_checkpoint`. The delegate's result is not
    returned.

    Args:
      checkpoint_dir: First argument passed to the delegate.
      assignment_map: Second argument passed to the delegate.
    """
    checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
def testInitFromCheckpointMissing(self):
    """Verify the errors raised for unusable checkpoints and mappings.

    One call uses a checkpoint location that does not exist and must raise
    `errors_impl.OpError`; the remaining calls use the real checkpoint with
    mappings that must each raise `ValueError`.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        _, _, _, _ = _create_checkpoints(session, ckpt_dir)

    restore = checkpoint_utils.init_from_checkpoint

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope("some_scope"):
                variable_scope.get_variable("my1", [10, 10])
                variable_scope.get_variable(
                    "my2", [1, 10],
                    dtype=dtypes.int64,
                    initializer=init_ops.zeros_initializer())

            # No directory.
            self.assertRaises(
                errors_impl.OpError, restore, "no_dir",
                {"var1": "some_scope/my1"})

            # No variable in checkpoint.
            self.assertRaises(
                ValueError, restore, ckpt_dir, {"no_var": "some_scope/my1"})

            # No variable in the graph.
            self.assertRaises(
                ValueError, restore, ckpt_dir, {"var3": "some_scope/no_var"})

            # Shape mismatch.
            self.assertRaises(
                ValueError, restore, ckpt_dir, {"var1": "some_scope/my1"})

            # Variable 'my1' and 'my2' are missing in given checkpoint scope.
            self.assertRaises(
                ValueError, restore, ckpt_dir, {"useful_scope/": "some_scope/"})

            # Mapping is not to scope name.
            self.assertRaises(
                ValueError, restore, ckpt_dir, {"useful_scope": "some_scope/"})
def restore_from_checkpoint(chkpt_dir_or_file,
                            restore_chkpt_scopes=None,
                            chkpt_to_graph_scope_map=None):
    """Initialize shape-compatible graph variables from a checkpoint.

    Every variable listed by `tf.train.list_variables` is translated to a
    graph variable name, looked up with `check_if_variable`, and kept only
    when the graph variable exists and its shape equals the checkpoint shape.
    The surviving `{checkpoint_name: graph_name}` pairs, optionally filtered
    by `restore_chkpt_scopes`, are passed to
    `checkpoint_utils.init_from_checkpoint`. Progress is reported via `print`.

    Args:
      chkpt_dir_or_file: Checkpoint location; forwarded to
        `tf.train.list_variables` and `checkpoint_utils.init_from_checkpoint`.
      restore_chkpt_scopes: Optional sequence of checkpoint-name prefixes.
        When non-empty, only checkpoint variables whose name starts with one
        of them are restored; otherwise all compatible variables are.
      chkpt_to_graph_scope_map: Optional dict mapping a checkpoint scope
        prefix to the graph scope prefix that replaces it. The first prefix
        (in dict iteration order) that a checkpoint name starts with is used;
        names matching no prefix are skipped. When empty or None, checkpoint
        names are used as graph names unchanged.
    """
    if chkpt_to_graph_scope_map:
        print("map graph scopes to checkpoint scopes: {}".format(
            chkpt_to_graph_scope_map))

    print("list variables in checkpoint:")
    init_assignment_map = {}
    for chkpt_var_name, chkpt_shape in tf.train.list_variables(
            chkpt_dir_or_file):
        chkpt_name = str(chkpt_var_name)
        # Names containing '/Adam' are never restored (presumably optimizer
        # slot variables -- confirm against the training setup).
        if '/Adam' in chkpt_name:
            continue

        if not chkpt_to_graph_scope_map:
            graph_var_name = chkpt_var_name
        else:
            graph_var_name = None
            for c_scope, g_scope in chkpt_to_graph_scope_map.items():
                if chkpt_name.startswith(c_scope):
                    # Rewrite only the matched leading scope. str.replace
                    # would also rewrite any later occurrence of the same
                    # text inside the name (and garble it for c_scope == "").
                    graph_var_name = g_scope + chkpt_name[len(c_scope):]
                    break
            if graph_var_name is None:
                continue

        # The lookup result is treated as None (missing), a single variable,
        # or a list of partition variables.
        graph_var = check_if_variable(graph_var_name)
        if graph_var is None:
            print(
                "var not found in graph: name={} checkpoint_name={} checkpoint_shape={}"
                .format(graph_var_name, chkpt_var_name, chkpt_shape))
            continue

        if isinstance(graph_var, list):
            # Partitions are summed along the first dimension to obtain the
            # shape of the full variable.
            graph_shape = graph_var[0].get_shape().as_list()
            for partition in graph_var[1:]:
                graph_shape[0] += partition.get_shape().as_list()[0]
        else:
            graph_shape = graph_var.get_shape().as_list()

        if graph_shape != chkpt_shape:
            print(
                "bad shape: name={} checkpoint_name={} shape={} checkpoint_shape={} var={}"
                .format(graph_var_name, chkpt_var_name, graph_shape,
                        chkpt_shape, graph_var))
            continue

        print(
            "ready to load: name={} checkpoint_name={} shape={} checkpoint_shape={} var={}"
            .format(graph_var_name, chkpt_var_name, graph_shape, chkpt_shape,
                    graph_var))
        init_assignment_map[chkpt_var_name] = graph_var_name

    if restore_chkpt_scopes:
        print("load var in checkpoint scope: {}".format(
            ", ".join(restore_chkpt_scopes)))
        # str.startswith accepts a tuple of alternative prefixes.
        scope_prefixes = tuple(restore_chkpt_scopes)
        assignment_map = {
            chkpt_name: graph_name
            for chkpt_name, graph_name in init_assignment_map.items()
            if chkpt_name.startswith(scope_prefixes)
        }
    else:
        print('load all possible var from checkpoint')
        assignment_map = init_assignment_map

    print("init_from_checkpoint: " + pprint.pformat(assignment_map))
    checkpoint_utils.init_from_checkpoint(chkpt_dir_or_file, assignment_map)