def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
    """Initializes `var` from the `prev_tensor_name` tensor stored in `prev_ckpt`.

    Args:
      var: Variable of the current graph that should be warm-started. One of:
        (i) `Variable`
        (ii) `ResourceVariable`
        (iii) list of `Variable`: the list must contain slices of the same
          larger variable.
        (iv) `PartitionedVariable`
      prev_ckpt: String naming a directory with checkpoint file(s) or a path to
        a checkpoint. It must contain a tensor called `prev_tensor_name` (when
        given) or one named like `var`.
      prev_tensor_name: Name of the tensor to look up in `prev_ckpt`. When
        falsy, the name inferred from `var` is used.

    Raises:
      TypeError: If `var` is none of the accepted types.
    """
    # pylint: disable=protected-access
    if checkpoint_utils._is_variable(var):
        current_var_name = _infer_var_name([var])
        restore_target = var
    elif isinstance(var, list) and all(
            checkpoint_utils._is_variable(v) for v in var):
        current_var_name = _infer_var_name(var)
        restore_target = var
    elif isinstance(var, variables_lib.PartitionedVariable):
        current_var_name = _infer_var_name([var])
        # init_from_checkpoint is handed the list of partitions.
        restore_target = var._get_variable_list()
    else:
        raise TypeError(
            "var MUST be one of the following: a Variable, list of Variable or "
            "PartitionedVariable, but is {}".format(type(var)))
    # pylint: enable=protected-access

    # Without an explicit name, assume the tensor name did not change.
    source_name = prev_tensor_name or current_var_name
    checkpoint_utils.init_from_checkpoint(prev_ckpt,
                                          {source_name: restore_target})
def testNoAdditionalReadOpsForResourceVariables(self):
    """Restoring a ResourceVariable must not add unexpected ops to the graph.

    The restore is issued inside the "init_from_checkpoint" name scope; after
    checking the restored value, the test asserts that every op under that
    scope is either below "init_from_checkpoint/checkpoint_initializer" or of
    type `AssignVariableOp` / `Identity`.
    """
    checkpoint_dir = self.get_temp_dir()
    # cached_session() replaces the deprecated test_session(); the rest of this
    # test already uses the newer self.session(graph=...) API.
    with self.cached_session() as session:
        v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.session(graph=g) as session:
            my1 = resource_variable_ops.ResourceVariable(
                [[0.0] * 10], name="my1")

            with ops.name_scope("init_from_checkpoint"):
                checkpoint_utils.init_from_checkpoint(
                    checkpoint_dir, {"var1": my1})

            # Basic sanity checks:
            session.run(variables.global_variables_initializer())
            self.assertAllEqual(session.run(my1), v1)

    allowed_op_types = ("AssignVariableOp", "Identity")
    ops_in_init_from_checkpoint_scope = [
        op for op in g.get_operations()
        if op.name.startswith("init_from_checkpoint/")
        and not op.name.startswith(
            "init_from_checkpoint/checkpoint_initializer")
        and op.type not in allowed_op_types
    ]
    self.assertEqual(ops_in_init_from_checkpoint_scope, [])
def testInitialValueComesFromCheckpoint(self):
    """Checks `initialized_value()` before and after `init_from_checkpoint`.

    `before` is taken while `my1` still has its zeros initializer, `after`
    once the checkpoint mapping has been registered. Both are evaluated prior
    to and following the global initializer.
    """
    checkpoint_dir = self.get_temp_dir()
    # cached_session() replaces the deprecated test_session(); the rest of this
    # test already uses the newer self.session(graph=...) API.
    with self.cached_session() as session:
        v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.session(graph=g) as session:
            with variable_scope.variable_scope(
                    "some_scope", initializer=init_ops.zeros_initializer()):
                my1 = variable_scope.get_variable("my1", [1, 10])

            before = my1.initialized_value()
            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, {"var1": my1})
            after = my1.initialized_value()

            # Prior to running the initializer.
            self.assertAllEqual(session.run(before), [[0.0] * 10])
            self.assertAllEqual(session.run(after), v1)

            session.run(variables.global_variables_initializer())

            # After running the initializer.
            self.assertAllEqual(session.run(my1), v1)
            self.assertAllEqual(session.run(my1.initialized_value()), v1)
            self.assertAllClose(session.run(before), v1)
            self.assertAllClose(session.run(after), v1)
            # Guard: the checkpointed value must differ from the zeros default.
            with self.assertRaises(AssertionError):
                self.assertAllClose(v1, [[0.0] * 10])
def testInitialValueComesFromCheckpoint(self):
    """Checks `initialized_value()` before and after `init_from_checkpoint`.

    Two `initialized_value()` tensors of the same variable are created, one
    before and one after the checkpoint mapping is registered, and evaluated
    on both sides of the global initializer.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        expected, _, _, _ = _create_checkpoints(session, ckpt_dir)

    zeros = [[0.0] * 10]

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope(
                    "some_scope", initializer=init_ops.zeros_initializer()):
                my1 = variable_scope.get_variable("my1", [1, 10])

            before = my1.initialized_value()
            checkpoint_utils.init_from_checkpoint(ckpt_dir, {"var1": my1})
            after = my1.initialized_value()

            # Prior to running the initializer.
            self.assertAllEqual(session.run(before), [[0.0] * 10])
            self.assertAllEqual(session.run(after), expected)

            session.run(variables.global_variables_initializer())

            # After running the initializer.
            self.assertAllEqual(session.run(my1), expected)
            self.assertAllEqual(
                session.run(my1.initialized_value()), expected)
            self.assertAllClose(session.run(before), expected)
            self.assertAllClose(session.run(after), expected)
            # Guard: the checkpointed value must differ from the zeros default.
            with self.assertRaises(AssertionError):
                self.assertAllClose(expected, zeros)
def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
    """Initializes `var` from the `prev_tensor_name` tensor stored in `prev_ckpt`.

    Args:
      var: Variable of the current graph that should be warm-started. One of:
        (i) `Variable`
        (ii) `ResourceVariable`
        (iii) list of `Variable`: the list must contain slices of the same
          larger variable.
        (iv) `PartitionedVariable`
      prev_ckpt: String naming a directory with checkpoint file(s) or a path to
        a checkpoint. It must contain a tensor called `prev_tensor_name` (when
        given) or one named like `var`.
      prev_tensor_name: Name of the tensor to look up in `prev_ckpt`. When
        falsy, the name inferred from `var` is used.

    Raises:
      TypeError: If `var` is none of the accepted types.
    """
    if _is_variable(var):
        current_var_name = _infer_var_name([var])
        restore_target = var
    elif isinstance(var, list) and all(_is_variable(v) for v in var):
        current_var_name = _infer_var_name(var)
        restore_target = var
    elif isinstance(var, variables_lib.PartitionedVariable):
        current_var_name = _infer_var_name([var])
        # init_from_checkpoint is handed the list of partitions.
        restore_target = var._get_variable_list()  # pylint: disable=protected-access
    else:
        raise TypeError(
            "var MUST be one of the following: a Variable, list of Variable or "
            "PartitionedVariable, but is {}".format(type(var)))

    # Without an explicit name, assume the tensor name did not change.
    source_name = prev_tensor_name or current_var_name
    checkpoint_utils.init_from_checkpoint(prev_ckpt,
                                          {source_name: restore_target})
def testInitFromCheckpoint(self):
    """Restores four variables through name, scope and object mappings."""
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope("some_scope"):
                my1 = variable_scope.get_variable("my1", [1, 10])
                with variable_scope.variable_scope("some_other_scope"):
                    my2 = variable_scope.get_variable("my2", [10, 10])
                    with variable_scope.variable_scope("other_useful_scope"):
                        my4 = variable_scope.get_variable("var4", [9, 9])
            my3 = variable_scope.get_variable("my3", [100, 100])

            # First call: a tensor-name mapping plus a scope mapping.
            name_and_scope_map = {
                "var1": "some_scope/my1",
                "useful_scope/":
                    "some_scope/some_other_scope/other_useful_scope/",
            }
            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, name_and_scope_map)

            # Second call: a tensor-name mapping plus a variable object.
            name_and_object_map = {
                "var2": "some_scope/some_other_scope/my2",
                "var3": my3,
            }
            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, name_and_object_map)

            session.run(variables.global_variables_initializer())
            for restored, expected in (
                    (my1, v1), (my2, v2), (my3, v3), (my4, v4)):
                self.assertAllEqual(restored.eval(session), expected)

            # Check that tensors are not explicitly in the graph.
            graph_def_text = str(session.graph.as_graph_def())
            self.assertLess(len(graph_def_text), 29000)
def testInitFromCheckpoint(self):
    """Restores four variables through name, scope and object mappings."""
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
        v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.session(graph=g) as session:
            with variable_scope.variable_scope("some_scope"):
                my1 = variable_scope.get_variable("my1", [1, 10])
                with variable_scope.variable_scope("some_other_scope"):
                    my2 = variable_scope.get_variable("my2", [10, 10])
                    with variable_scope.variable_scope("other_useful_scope"):
                        my4 = variable_scope.get_variable("var4", [9, 9])
            my3 = variable_scope.get_variable("my3", [100, 100])

            # First call: a tensor-name mapping plus a scope mapping.
            name_and_scope_map = {
                "var1": "some_scope/my1",
                "useful_scope/":
                    "some_scope/some_other_scope/other_useful_scope/",
            }
            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, name_and_scope_map)

            # Second call: a tensor-name mapping plus a variable object.
            name_and_object_map = {
                "var2": "some_scope/some_other_scope/my2",
                "var3": my3,
            }
            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, name_and_object_map)

            session.run(variables.global_variables_initializer())
            for restored, expected in (
                    (my1, v1), (my2, v2), (my3, v3), (my4, v4)):
                self.assertAllEqual(restored.eval(session), expected)

            # Check that tensors are not explicitly in the graph.
            graph_def_text = str(session.graph.as_graph_def())
            self.assertLess(len(graph_def_text), 29000)
def testInitialValueComesFromCheckpoint(self):
    """Checks that `initialized_value()` follows the checkpoint initializer.

    Two helper variables snapshot `my1.initialized_value()`: one created
    before `init_from_checkpoint` and one after it. Once everything is
    initialized, the first is expected to hold zeros and the second the
    checkpointed value.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        expected, _, _, _ = _create_checkpoints(session, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope(
                    "some_scope", initializer=init_ops.zeros_initializer()):
                my1 = variable_scope.get_variable("my1", [1, 10])

            # At this point, my1.initialized_value() will add ops that
            # reference the zeros initializer of my1.
            before = variables.Variable(
                my1.initialized_value(), name="before")

            checkpoint_utils.init_from_checkpoint(ckpt_dir, {"var1": my1})

            # At this point, my1.initialized_value() will add ops that
            # reference the newly set initializer of my1.
            after = variables.Variable(my1.initialized_value(), name="after")

            session.run(variables.global_variables_initializer())
            self.assertAllEqual(session.run(my1), expected)
            self.assertAllEqual(
                session.run(my1.initialized_value()), expected)
            self.assertAllClose(session.run(before), [[0.0] * 10])
            self.assertAllClose(session.run(after), expected)
            # The two snapshots must not agree.
            with self.assertRaises(AssertionError):
                self.assertAllClose(session.run(before), session.run(after))
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    """Private method that follows the signature of _get_dense_tensor.

    Fetches the sparse ids/weights of the wrapped categorical column, builds
    the candidate dense tensors and the embedding weights, optionally restores
    the weights from `self.ckpt_to_load_from`, and returns the result of
    `attention_safe_embedding_lookup_sparse`.
    """
    # pylint: disable=protected-access
    # Get sparse IDs and weights.
    sparse_tensors = self.categorical_column._get_sparse_tensors(
        inputs, weight_collections=weight_collections, trainable=trainable)
    candidate_dense_tensors = self._get_candidate_dense_tensor(
        inputs, weight_collections, trainable)

    embedding_weights = self.layer_creator(
        weight_collections=weight_collections,
        scope=variable_scope.get_variable_scope())

    if self.ckpt_to_load_from is not None:
        # A partitioned variable is restored through its list of partitions.
        if isinstance(embedding_weights, variables.PartitionedVariable):
            restore_target = embedding_weights._get_variable_list()
        else:
            restore_target = embedding_weights
        checkpoint_utils.init_from_checkpoint(
            self.ckpt_to_load_from,
            {self.tensor_name_in_ckpt: restore_target})
    # pylint: enable=protected-access

    # Return embedding lookup result.
    return attention_safe_embedding_lookup_sparse(
        embedding_weights=embedding_weights,
        sparse_ids=sparse_tensors.id_tensor,
        sparse_weights=sparse_tensors.weight_tensor,
        candidate_dense_tensors=candidate_dense_tensors,
        combiner=self.combiner,
        name='%s_weights' % self.name,
        max_norm=self.max_norm)
def testInitialValueComesFromCheckpoint(self):
    """Checks that `initialized_value()` follows the checkpoint initializer.

    `before` and `after` are variables initialized from
    `my1.initialized_value()`, built respectively ahead of and following the
    `init_from_checkpoint` call. After global initialization `before` is
    expected to be all zeros and `after` to equal the checkpointed tensor.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        ckpt_value, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    zeros = [[0.0] * 10]

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope(
                    "some_scope", initializer=init_ops.zeros_initializer()):
                my1 = variable_scope.get_variable("my1", [1, 10])

            # At this point, my1.initialized_value() will add ops that
            # reference the zeros initializer of my1.
            before = variables.Variable(
                my1.initialized_value(), name="before")

            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, {"var1": my1})

            # At this point, my1.initialized_value() will add ops that
            # reference the newly set initializer of my1.
            after = variables.Variable(my1.initialized_value(), name="after")

            session.run(variables.global_variables_initializer())
            self.assertAllEqual(session.run(my1), ckpt_value)
            self.assertAllEqual(
                session.run(my1.initialized_value()), ckpt_value)
            self.assertAllClose(session.run(before), zeros)
            self.assertAllClose(session.run(after), ckpt_value)
            # The two snapshots must not agree.
            with self.assertRaises(AssertionError):
                self.assertAllClose(session.run(before), session.run(after))
def testNoAdditionalReadOpsForResourceVariables(self):
    """Restoring a ResourceVariable must not add unexpected ops to the graph.

    The restore is issued inside the "init_from_checkpoint" name scope. Ops in
    that scope are accepted only when they sit below
    "init_from_checkpoint/checkpoint_initializer" or have type
    `AssignVariableOp` / `Identity`; anything else fails the test.
    """
    checkpoint_dir = self.get_temp_dir()
    # cached_session() replaces the deprecated test_session(); the rest of this
    # test already uses the newer self.session(graph=...) API.
    with self.cached_session() as session:
        v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.session(graph=g) as session:
            my1 = resource_variable_ops.ResourceVariable(
                [[0.0] * 10], name="my1")

            with ops.name_scope("init_from_checkpoint"):
                checkpoint_utils.init_from_checkpoint(
                    checkpoint_dir, {"var1": my1})

            # Basic sanity checks:
            session.run(variables.global_variables_initializer())
            self.assertAllEqual(session.run(my1), v1)

    scope_prefix = "init_from_checkpoint/"
    initializer_prefix = "init_from_checkpoint/checkpoint_initializer"
    allowed_op_types = ("AssignVariableOp", "Identity")
    ops_in_init_from_checkpoint_scope = [
        op for op in g.get_operations()
        if op.name.startswith(scope_prefix)
        and not op.name.startswith(initializer_prefix)
        and op.type not in allowed_op_types
    ]
    self.assertEqual(ops_in_init_from_checkpoint_scope, [])
def init_and_verify(g):
    """Maps checkpoint tensor "var1" onto "new_var1" and checks its value.

    Relies on `self`, `checkpoint_dir` and `v1_value` from the enclosing scope.
    """
    restored = variable_scope.get_variable("new_var1", [1, 10])
    assignment_map = {"var1": "new_var1"}
    checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
    with self.test_session(graph=g) as session:
        session.run(variables.global_variables_initializer())
        self.assertAllEqual(v1_value, self.evaluate(restored))
def init_lm_checkpoints(lm_dirs):
    """Registers checkpoint initializers for the language model scopes.

    Args:
      lm_dirs: Mapping with the keys 'forward' and 'reverse'. Each value is a
        directory containing a 'ckpt' entry; 'reverse' may be None, in which
        case only the forward checkpoint is used.
    """
    forward_ckpt = os.path.join(lm_dirs['forward'], 'ckpt')
    init_from_checkpoint(
        forward_ckpt, assignment_map={'LanguageModel/': 'LanguageModel/'})

    reverse_dir = lm_dirs['reverse']
    if reverse_dir is None:
        return

    # The reverse checkpoint's 'LanguageModel/' scope is loaded into the
    # 'LanguageModelReverse/' scope of the current graph.
    reverse_ckpt = os.path.join(reverse_dir, 'ckpt')
    init_from_checkpoint(
        reverse_ckpt,
        assignment_map={'LanguageModel/': 'LanguageModelReverse/'})
def testRestoreRunsOnSameDevice(self):
    """Calls `init_from_checkpoint` for a variable created under "/job:ps".

    No explicit assertion is made; the test passes when the call completes
    without raising.
    """
    ckpt_dir = self.get_temp_dir()
    with self.cached_session() as session:
        _create_checkpoints(session, ckpt_dir)

    scope_map = {"useful_scope/": "useful_scope/"}
    with ops.Graph().as_default():
        with ops.device("/job:ps"):
            with variable_scope.variable_scope("useful_scope"):
                variable_scope.get_variable("var4", [9, 9])

        checkpoint_utils.init_from_checkpoint(ckpt_dir, scope_map)
def init_and_verify(g):
    """Maps checkpoint tensor "var1" onto "new_var1" and checks its value.

    Relies on `self`, `checkpoint_dir` and `v1_value` from the enclosing scope.
    """
    restored = variable_scope.get_variable("new_var1", [1, 10])

    # Use string add to create new object in each replica. The two halves are
    # kept in variables so the concatenation happens at call time; do not
    # collapse this into a single literal.
    prefix = "new_"
    suffix = "var1"
    new_var1 = prefix + suffix

    assignment_map = {"var1": new_var1}
    checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
    with self.test_session(graph=g) as session:
        session.run(variables.global_variables_initializer())
        self.assertAllEqual(v1_value, self.evaluate(restored))
def init_and_verify(g):
    """Restores "new_var1" from checkpoint tensor "var1" and verifies it.

    `self`, `checkpoint_dir` and `v1_value` come from the enclosing scope.
    """
    target_var = variable_scope.get_variable("new_var1", [1, 10])

    # Use string add to create new object in each replica. Keep the pieces in
    # separate variables so the name is assembled when the function runs
    # rather than being written as one literal.
    prefix = "new_"
    suffix = "var1"
    new_var1 = prefix + suffix

    checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": new_var1})
    with self.test_session(graph=g) as session:
        init_op = variables.global_variables_initializer()
        session.run(init_op)
        self.assertAllEqual(v1_value, self.evaluate(target_var))
def testRestoreRunsOnSameDevice(self):
    """Calls `init_from_checkpoint` for a variable created under "/job:ps".

    No explicit assertion is made; the test passes when the call completes
    without raising.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
        _create_checkpoints(session, checkpoint_dir)

    with ops.Graph().as_default():
        with ops.device("/job:ps"):
            with variable_scope.variable_scope("useful_scope"):
                # The variable only needs to exist in the graph; it is never
                # read here, so no local name is bound to it.
                variable_scope.get_variable("var4", [9, 9])

        checkpoint_utils.init_from_checkpoint(
            checkpoint_dir, {"useful_scope/": "useful_scope/"})
def init_and_verify(g):
    """Restores "new_var1" and "new_var2" from the checkpoint and checks them.

    "new_var2" is created with ON_READ synchronization and MEAN aggregation.
    `self`, `checkpoint_dir`, `v1_value` and `v2_value` come from the
    enclosing scope.
    """
    v1 = variable_scope.get_variable("new_var1", [1, 10])
    v2 = variable_scope.get_variable(
        "new_var2", [10, 10],
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        aggregation=variable_scope.VariableAggregation.MEAN)

    assignment_map = {"var1": "new_var1", "var2": "new_var2"}
    checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)

    with self.session(graph=g) as session:
        session.run(variables.global_variables_initializer())
        for expected, restored in ((v1_value, v1), (v2_value, v2)):
            self.assertAllEqual(expected, self.evaluate(restored))
def testRestoreRunsOnSameDevice(self):
    """Checks the device of the tensor feeding the variable's initializer op.

    The variable is created under "/job:ps"; after `init_from_checkpoint`, the
    second input of its initializer op is expected to be placed on "/job:ps".
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        _create_checkpoints(session, ckpt_dir)

    with ops.Graph().as_default():
        with ops.device("/job:ps"):
            with variable_scope.variable_scope("useful_scope"):
                my4 = variable_scope.get_variable("var4", [9, 9])

        checkpoint_utils.init_from_checkpoint(
            ckpt_dir, {"useful_scope/": "useful_scope/"})

        # pylint: disable=protected-access
        restored_value = my4._initializer_op.op.inputs[1]
        # pylint: enable=protected-access
        self.assertEqual(restored_value.device, "/job:ps")
def init_and_verify(g):
    """Maps checkpoint tensors "var1"/"var2" onto two new variables.

    The second variable uses ON_READ synchronization with MEAN aggregation.
    Depends on `self`, `checkpoint_dir`, `v1_value` and `v2_value` from the
    enclosing scope.
    """
    first = variable_scope.get_variable("new_var1", [1, 10])
    second = variable_scope.get_variable(
        "new_var2", [10, 10],
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        aggregation=variable_scope.VariableAggregation.MEAN)

    checkpoint_utils.init_from_checkpoint(
        checkpoint_dir, {"var1": "new_var1", "var2": "new_var2"})

    with self.session(graph=g) as session:
        init_op = variables.global_variables_initializer()
        session.run(init_op)
        self.assertAllEqual(v1_value, self.evaluate(first))
        self.assertAllEqual(v2_value, self.evaluate(second))
def testRestoreRunsOnSameDevice(self):
    """Checks the device of the tensor feeding the variable's initializer op.

    The variable is created under "/job:ps"; after `init_from_checkpoint`, the
    second input of its initializer op is expected on "/job:ps/device:CPU:0".
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        _create_checkpoints(session, ckpt_dir)

    with ops.Graph().as_default():
        with ops.device("/job:ps"):
            with variable_scope.variable_scope("useful_scope"):
                my4 = variable_scope.get_variable("var4", [9, 9])

        checkpoint_utils.init_from_checkpoint(
            ckpt_dir, {"useful_scope/": "useful_scope/"})

        # initializer runs on the same task but always on CPU.
        # pylint: disable=protected-access
        restored_value = my4._initializer_op.op.inputs[1]
        # pylint: enable=protected-access
        self.assertEqual(restored_value.device, "/job:ps/device:CPU:0")
def testInitWithScopeDoesNotCaptureSuffixes(self):
    """A scope mapping must not touch scopes that merely share its prefix.

    "useful_scope/" is restored from the checkpoint, while the variable in
    "useful_scope_1" is expected to keep its own initial value.
    """
    checkpoint_dir = self.get_temp_dir()
    # cached_session() replaces the deprecated test_session(); the rest of this
    # test already uses the newer self.session(graph=...) API.
    with self.cached_session() as session:
        _, _, _, v4 = _create_checkpoints(session, checkpoint_dir)

    with ops.Graph().as_default() as g:
        with variable_scope.variable_scope("useful_scope"):
            my4 = variable_scope.get_variable("var4", [9, 9])
        with variable_scope.variable_scope("useful_scope_1"):
            my5_init = [[1.0, 2.0], [3.0, 4.0]]
            my5 = variable_scope.get_variable("var5", initializer=my5_init)

        checkpoint_utils.init_from_checkpoint(
            checkpoint_dir, {"useful_scope/": "useful_scope/"})

        with self.session(graph=g) as session:
            session.run(variables.global_variables_initializer())
            # Restored from the checkpoint.
            self.assertAllEqual(my4.eval(session), v4)
            # Left at its own initializer.
            self.assertAllEqual(my5.eval(session), my5_init)
def testInitWithScopeDoesNotCaptureSuffixes(self):
    """A scope mapping must not touch scopes that merely share its prefix.

    "useful_scope/" is restored from the checkpoint, while the variable in
    "useful_scope_1" is expected to keep its own initial value.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        _, _, _, expected_var4 = _create_checkpoints(session, ckpt_dir)

    with ops.Graph().as_default() as g:
        with variable_scope.variable_scope("useful_scope"):
            my4 = variable_scope.get_variable("var4", [9, 9])
        with variable_scope.variable_scope("useful_scope_1"):
            my5_init = [[1.0, 2.0], [3.0, 4.0]]
            my5 = variable_scope.get_variable("var5", initializer=my5_init)

        scope_map = {"useful_scope/": "useful_scope/"}
        checkpoint_utils.init_from_checkpoint(ckpt_dir, scope_map)

        with self.test_session(graph=g) as session:
            session.run(variables.global_variables_initializer())
            # Restored from the checkpoint.
            self.assertAllEqual(my4.eval(session), expected_var4)
            # Left at its own initializer.
            self.assertAllEqual(my5.eval(session), my5_init)
def testInitToRootCheckpoint(self):
    """Restores every variable through the root mapping {"/": "/"}."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as session:
        v1, v2, v3, v4 = _create_checkpoints(session, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            my1 = variable_scope.get_variable("var1", [1, 10])
            my2 = variable_scope.get_variable("var2", [10, 10])
            my3 = variable_scope.get_variable("var3", [100, 100])
            with variable_scope.variable_scope("useful_scope"):
                my4 = variable_scope.get_variable("var4", [9, 9])

            root_map = {"/": "/"}
            checkpoint_utils.init_from_checkpoint(ckpt_dir, root_map)

            session.run(variables.global_variables_initializer())
            for restored, expected in (
                    (my1, v1), (my2, v2), (my3, v3), (my4, v4)):
                self.assertAllEqual(restored.eval(session), expected)
def testInitToRootCheckpoint(self):
    """Restores every variable through the root mapping {"/": "/"}."""
    ckpt_dir = self.get_temp_dir()
    with self.cached_session() as session:
        v1, v2, v3, v4 = _create_checkpoints(session, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.session(graph=g) as session:
            my1 = variable_scope.get_variable("var1", [1, 10])
            my2 = variable_scope.get_variable("var2", [10, 10])
            my3 = variable_scope.get_variable("var3", [100, 100])
            with variable_scope.variable_scope("useful_scope"):
                my4 = variable_scope.get_variable("var4", [9, 9])

            root_map = {"/": "/"}
            checkpoint_utils.init_from_checkpoint(ckpt_dir, root_map)

            session.run(variables.global_variables_initializer())
            for restored, expected in (
                    (my1, v1), (my2, v2), (my3, v3), (my4, v4)):
                self.assertAllEqual(restored.eval(session), expected)
def testInitFromCheckpointMissing(self):
    """Checks the errors raised for invalid `init_from_checkpoint` requests."""
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        _, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # Assignment maps that are each expected to raise ValueError, in the order
    # in which they are tried.
    invalid_maps = [
        # No variable in checkpoint.
        {"no_var": "some_scope/my1"},
        # No variable in the graph.
        {"var3": "some_scope/no_var"},
        # Shape mismatch.
        {"var1": "some_scope/my1"},
        # Variable 'my1' and 'my2' are missing in given checkpoint scope.
        {"useful_scope/": "some_scope/"},
        # Mapping is not to scope name.
        {"useful_scope": "some_scope/"},
    ]

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope("some_scope"):
                _ = variable_scope.get_variable("my1", [10, 10])
                _ = variable_scope.get_variable(
                    "my2", [1, 10],
                    dtype=dtypes.int64,
                    initializer=init_ops.zeros_initializer())

            # No directory.
            with self.assertRaises(errors_impl.OpError):
                checkpoint_utils.init_from_checkpoint(
                    "no_dir", {"var1": "some_scope/my1"})

            for assignment_map in invalid_maps:
                with self.assertRaises(ValueError):
                    checkpoint_utils.init_from_checkpoint(
                        checkpoint_dir, assignment_map)
def testInitFromPartitionVar(self):
    """Restores partitioned variables by name, by scope and by partition list.

    The first graph restores one variable by tensor name and a second,
    differently partitioned one through a scope mapping. The second graph
    restores from an explicit list of partitions.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        v1 = _create_partition_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope("some_scope"):
                my1 = variable_scope.get_variable(
                    name="my1",
                    shape=[100, 100],
                    initializer=init_ops.zeros_initializer(),
                    partitioner=(
                        partitioned_variables.min_max_variable_partitioner(
                            max_partitions=5, axis=0,
                            min_slice_size=8 << 10)))
                my1_var_list = my1._get_variable_list()  # pylint: disable=protected-access

            # Create another variable with different partitions than the
            # variable in the checkpoint.
            with variable_scope.variable_scope("some_other_scope"):
                my2 = variable_scope.get_variable(
                    name="var1",
                    shape=[100, 100],
                    initializer=init_ops.zeros_initializer(),
                    partitioner=(
                        partitioned_variables.min_max_variable_partitioner(
                            max_partitions=5, axis=0,
                            min_slice_size=16 << 10)))
                my2_var_list = my2._get_variable_list()  # pylint: disable=protected-access

            checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
                "scope/var1": "some_scope/my1",
                "scope/": "some_other_scope/",
            })

            session.run(variables.global_variables_initializer())
            my1_values = session.run(my1_var_list)
            self.assertAllEqual(my1_values, v1)

            my2_values = session.run(my2_var_list)
            # Verify we created different number of partitions.
            # assertNotEqual is used because the assertNotEquals alias is
            # deprecated and no longer exists in Python 3.12+.
            self.assertNotEqual(len(my2_values), len(v1))
            # Verify the values were correctly initialized in spite of
            # different partitions.
            full_my2_values = np.concatenate(my2_values, axis=0)
            full_v1_values = np.concatenate(v1, axis=0)
            self.assertAllEqual(full_my2_values, full_v1_values)

    # New graph and session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope("some_scope"):
                my1 = variable_scope.get_variable(
                    name="my1",
                    shape=[100, 100],
                    initializer=init_ops.truncated_normal_initializer(0.5),
                    partitioner=(
                        partitioned_variables.min_max_variable_partitioner(
                            max_partitions=5, axis=0,
                            min_slice_size=8 << 10)))
                my1_var_list = my1._get_variable_list()  # pylint: disable=protected-access

            # Restore from an explicit list of partitions.
            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, {"scope/var1": my1_var_list})

            session.run(variables.global_variables_initializer())
            my1_values = session.run(my1_var_list)
            self.assertAllEqual(my1_values, v1)
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    """Embeds the column's ids and runs them through an RNN.

    The sparse ids of the wrapped categorical column are densified to
    `[batch_size, self.max_sequence_length]` (missing entries become id 0),
    looked up in an 'embedding_weights' variable (optionally restored from
    `self.ckpt_to_load_from`) and fed to either a bidirectional or a
    unidirectional dynamic RNN.

    Returns:
      For the bidirectional case, the final states concatenated along axis 1;
      otherwise the `h` attribute of the final state.
    """
    # pylint: disable=protected-access
    # Get sparse IDs and weights.
    sparse_tensors = self.categorical_column._get_sparse_tensors(
        inputs, weight_collections=weight_collections, trainable=trainable)
    sparse_ids = sparse_tensors.id_tensor

    dense_shape = [sparse_ids.dense_shape[0], self.max_sequence_length]
    dense_tensor_ids = sparse_ops.sparse_to_dense(
        sparse_ids.indices,
        dense_shape,
        sparse_ids.values,
        default_value=0)

    # Create embedding weight, and restore from checkpoint if necessary.
    embedding_shape = (self.categorical_column._num_buckets,
                       self.embedding_dimension)
    embedding_weights = variable_scope.get_variable(
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        initializer=self.initializer,
        trainable=self.trainable and trainable,
        collections=weight_collections)
    if self.ckpt_to_load_from is not None:
        # A partitioned variable is restored through its list of partitions.
        if isinstance(embedding_weights, variables.PartitionedVariable):
            restore_target = embedding_weights._get_variable_list()
        else:
            restore_target = embedding_weights
        checkpoint_utils.init_from_checkpoint(
            self.ckpt_to_load_from,
            {self.tensor_name_in_ckpt: restore_target})
    # pylint: enable=protected-access

    embedding_inputs = embedding_lookup(
        embedding_weights, dense_tensor_ids, max_norm=self.max_norm)

    # Dropout is only applied while training.
    if self.mode == model_fn_lib.ModeKeys.TRAIN:
        dropout = self.dropout_keep_probabilities
    else:
        dropout = None
    sequence_lengths = self._sequence_lengths(sparse_ids)

    # NOTE(review): the returns are kept inside the 'RNN' name scope; confirm
    # this matches the intended op naming.
    if self.bidirectional_rnn:
        cell_fw = rnn_common.construct_rnn_cell(
            self.num_units, self.cell_type, dropout)
        cell_bw = rnn_common.construct_rnn_cell(
            self.num_units, self.cell_type, dropout)
        with ops.name_scope('RNN'):
            _, final_states = rnn.bidirectional_dynamic_rnn(
                cell_fw,
                cell_bw,
                embedding_inputs,
                sequence_length=sequence_lengths,
                dtype=dtypes.float32)
            return array_ops.concat(final_states, 1)

    cell = rnn_common.construct_rnn_cell(
        self.num_units, self.cell_type, dropout)
    with ops.name_scope('RNN'):
        _, final_state = rnn.dynamic_rnn(
            cell,
            embedding_inputs,
            sequence_length=sequence_lengths,
            dtype=dtypes.float32)
        return final_state.h
def _warmstart_var(var, prev_ckpt, prev_tensor_name=None):
    """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

    Args:
      var: Current graph's variable that needs to be warm-started (initialized).
        Can be either of the following:
        (i) `Variable`
        (ii) `ResourceVariable`
        (iii) `PartitionedVariable`
        (iv) list of `Variable` and/or `PartitionedVariable`: The list may
          contain one or more variables that have been sharded. For example:
          [Variable('a/part_0'), Variable('b/part_0'), Variable('a/part_1'),
           PartitionedVariable([Variable('c/part_0'), Variable('c/part_1')])]
          where we have three whole Variables represented ('a', 'b', and 'c').
      prev_ckpt: A string specifying the directory with checkpoint file(s) or
        path to checkpoint. The given checkpoint must have tensor with name
        `prev_tensor_name` (if not None) or tensor with name same as given
        `var`.
      prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`.
        If None, we lookup tensor with same name as given `var`.

    Raises:
      ValueError: If prev_tensor_name is not None, but the given var represents
        more than one Variable.
      TypeError: If var is not one of the allowed types.
    """
    # pylint: disable=protected-access
    if _is_variable(var):
        current_var_name = _infer_var_name([var])
    elif isinstance(var, variables.PartitionedVariable):
        current_var_name = _infer_var_name([var])
        var = var._get_variable_list()
    elif isinstance(var, list) and all(
            _is_variable(v) or isinstance(v, variables.PartitionedVariable)
            for v in var):
        if len(var) != 1:
            # With several elements we cannot assume they all belong to the
            # same Variable, so group them by name first.
            name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(
                var, convert_variable_to_tensor=False)
            if prev_tensor_name:
                # Providing a prev_tensor_name is only viable if var
                # represents a single Variable.
                if len(name_to_var_dict) > 1:
                    raise ValueError(
                        "var represented more than one Variable, but "
                        "prev_tensor_name was provided.")
                assignment_map = {prev_tensor_name: var}
            else:
                # OpListToDict gives us roughly what we need, but its values
                # may be PartitionedVariables, which init_from_checkpoint does
                # not expect; those are converted to lists of partitions.
                assignment_map = {
                    name: (entry._get_variable_list()
                           if isinstance(entry, variables.PartitionedVariable)
                           else entry)
                    for name, entry in six.iteritems(name_to_var_dict)
                }
            checkpoint_utils.init_from_checkpoint(prev_ckpt, assignment_map)
            return
        # Convert length-1 lists of vars to single tf.Variables. This ensures
        # that checkpoint_utils.init_from_checkpoint() doesn't incorrectly
        # assume slice info is present.
        current_var_name = _infer_var_name(var)
        var = var[0]
    else:
        raise TypeError(
            "var MUST be one of the following: a Variable, PartitionedVariable, or "
            "list of Variable's and/or PartitionedVariable's, but is {}".format(
                type(var)))
    # pylint: enable=protected-access

    # Without an explicit name, assume the tensor name remains the same.
    source_name = prev_tensor_name or current_var_name
    checkpoint_utils.init_from_checkpoint(prev_ckpt, {source_name: var})
def _warm_start_variables(ckpt_to_initialize_from, vars_to_warm_start,
                          var_name_to_vocab_info, var_name_to_prev_var_name):
    """Warm-starts the grouped variables selected by `vars_to_warm_start`.

    Variables with an entry in `var_name_to_vocab_info` are warm-started one
    by one via `_warm_start_var_with_vocab`. All remaining variables are
    collected and restored with a single `init_from_checkpoint` call.

    Args:
      ckpt_to_initialize_from: Checkpoint to warm-start from; passed through
        to the checkpoint helpers.
      vars_to_warm_start: Selection passed to `_get_grouped_variables`. When
        falsy, only variables with vocabulary info are warm-started.
      var_name_to_vocab_info: Dict of variable name to vocab info, or None.
      var_name_to_prev_var_name: Dict of variable name to previous variable
        name. When falsy, it is derived via `_get_object_checkpoint_renames`.

    Returns:
      A tuple `(prev_var_name_used, all_var_names)`: the set of variable names
      for which a previous name was found, and the set of all grouped
      variable names.

    Raises:
      ValueError: If `var_name_to_vocab_info` has entries that matched no
        variable.
    """
    grouped_variables = _get_grouped_variables(vars_to_warm_start)

    if var_name_to_vocab_info is None:
        var_name_to_vocab_info = {}

    if not var_name_to_prev_var_name:
        # Detect whether the checkpoint is object-based, in which case the
        # var_name_to_prev_var_name dictionary should map variable names to
        # checkpoint keys. A user-specified mapping is never overridden.
        var_name_to_prev_var_name = _get_object_checkpoint_renames(
            ckpt_to_initialize_from, grouped_variables.keys())

    warmstarted_count = 0

    # Keep track of which var_names in var_name_to_prev_var_name and
    # var_name_to_vocab_info have been used. It is easy to misname a variable
    # in this configuration; without this bookkeeping a typo would make
    # warm-starting fail silently.
    prev_var_name_used = set()
    vocab_info_used = set()

    # Group the vocabless vars into one call to init_from_checkpoint.
    vocabless_vars = {}
    for var_name, variable in six.iteritems(grouped_variables):
        prev_var_name = var_name_to_prev_var_name.get(var_name)
        if prev_var_name:
            prev_var_name_used.add(var_name)

        vocab_info = var_name_to_vocab_info.get(var_name)
        if vocab_info:
            vocab_info_used.add(var_name)
            warmstarted_count += 1
            if vocab_info.old_vocab_size > 0:
                old_vocab_size_repr = vocab_info.old_vocab_size
            else:
                old_vocab_size_repr = "All"
            logging.debug(
                "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
                " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
                " initializer: {}".format(
                    var_name,
                    vocab_info.new_vocab,
                    vocab_info.new_vocab_size,
                    vocab_info.old_vocab,
                    old_vocab_size_repr,
                    vocab_info.num_oov_buckets,
                    prev_var_name or "Unchanged",
                    vocab_info.backup_initializer or "zero-initialized"))
            _warm_start_var_with_vocab(
                variable,
                current_vocab_path=vocab_info.new_vocab,
                current_vocab_size=vocab_info.new_vocab_size,
                prev_ckpt=ckpt_to_initialize_from,
                prev_vocab_path=vocab_info.old_vocab,
                previous_vocab_size=vocab_info.old_vocab_size,
                current_oov_buckets=vocab_info.num_oov_buckets,
                prev_tensor_name=prev_var_name,
                initializer=vocab_info.backup_initializer,
                axis=vocab_info.axis)
            continue

        # For the special value of vars_to_warm_start = None, we only
        # warm-start variables with explicitly specified vocabularies.
        if not vars_to_warm_start:
            continue

        warmstarted_count += 1
        logging.debug("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in
        # order for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
            variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        vocabless_vars[prev_tensor_name] = var

    checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                          vocabless_vars)

    vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

    logging.info("Warm-started %d variables.", warmstarted_count)

    if vocab_info_not_used:
        raise ValueError(
            "You provided the following variables in "
            "var_name_to_vocab_info that were not used: {0}. "
            " Perhaps you misspelled them?  Here is the list of viable variable "
            "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))

    return prev_var_name_used, set(grouped_variables.keys())
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection).  This expression will
        only consider variables in the TRAINABLE_VARIABLES collection -- if you
        need to warm-start non_TRAINABLE vars (such as optimizer accumulators
        or batch norm statistics), please use the below option.
      - A list of Variables to warm-start.  If you do not have access to the
        `Variable` objects at the call site, please use the below option.
      - A list of strings, each a regex scope provided to
        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
        tf.compat.v1.get_collection).  For backwards compatibility reasons,
        this is separate from the single-string argument type.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions.  If not explicitly provided, the
      variable is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model.  Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).  Several current variables may be mapped to the same
      previous tensor; each of them is warm-started from it.

  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}
  if var_name_to_prev_var_name is None:
    var_name_to_prev_var_name = {}
  # Pass the path itself as the %s argument; wrapping it in a 1-tuple makes the
  # log line show the tuple repr ("('path',)") instead of the path.
  logging.info("Warm-starting from: %s", ckpt_to_initialize_from)
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
  # exception if any are unused by the end of the loop.  It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into as few calls to init_from_checkpoint as
  # possible.  The dict is keyed by the tensor name in the checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.debug(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name, vocab_info.new_vocab, vocab_info.new_vocab_size,
              vocab_info.old_vocab, (vocab_info.old_vocab_size if
                                     vocab_info.old_vocab_size > 0 else "All"),
              vocab_info.num_oov_buckets, prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        logging.debug(
            "Warm-starting variable: {}; prev_var_name: {}".format(
                var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        if prev_tensor_name in vocabless_vars:
          # The batch is keyed by checkpoint tensor name, so one call cannot
          # initialize two variables from the same tensor.  Flush what has been
          # queued so far; otherwise the assignment below would overwrite the
          # earlier variable and it would silently not be warm-started.
          logging.debug(
              "Checkpoint tensor {} is requested by more than one variable; "
              "flushing the pending warm-start batch.".format(
                  prev_tensor_name))
          checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                                vocabless_vars)
          # Start a fresh dict instead of clearing in place, in case the
          # callee kept a reference to the one it was given.
          vocabless_vars = {}
        vocabless_vars[prev_tensor_name] = var

  # Intentionally unconditional (also when nothing was queued), so the
  # checkpoint argument is always handed to init_from_checkpoint.
  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}.  Perhaps you misspelled them?  Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them?  Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
def _warmstart_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) `PartitionedVariable`
      (iv) list of `Variable` and/or `PartitionedVariable`: The list may
        contain one or more variables that has been sharded.  For example:
        [Variable('a/part_0'), Variable('b/part_0'), Variable('a/part_1'),
         PartitionedVariable([Variable('c/part_0'), Variable('c/part_1')])]
        where we have three whole Variables represented ('a', 'b', and 'c').
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Raises:
    ValueError: If prev_tensor_name is not None, but the given var represents
      more than one Variable.
    TypeError: If var is not one of the allowed types.
  """
  if _is_variable(var):
    current_var_name = _infer_var_name([var])
  elif isinstance(var, variables.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  elif (isinstance(var, list) and all(
      _is_variable(v) or isinstance(v, variables.PartitionedVariable)
      for v in var)):
    if len(var) == 1:
      # Convert length-1 lists of vars to single tf.Variables.  This ensures
      # that checkpoint_utils.init_from_checkpoint() doesn't incorrectly assume
      # slice info is present.
      current_var_name = _infer_var_name(var)
      var = var[0]
      if isinstance(var, variables.PartitionedVariable):
        # The lone element may itself be a PartitionedVariable, which
        # init_from_checkpoint does not expect; hand it the list of slices,
        # exactly as the bare-PartitionedVariable branch above does.
        var = var._get_variable_list()  # pylint: disable=protected-access
    else:
      # If we have multiple elements in var, we cannot assume they all
      # represent the same Variable.
      name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(
          var, convert_variable_to_tensor=False)
      if prev_tensor_name:
        # Providing a prev_tensor_name is only viable if var represents a
        # single Variable.
        if len(name_to_var_dict) > 1:
          raise ValueError(
              "var represented more than one Variable, but "
              "prev_tensor_name was provided.")
        checkpoint_utils.init_from_checkpoint(prev_ckpt,
                                              {prev_tensor_name: var})
      else:
        # OpListToDict gives us roughly what we need, but
        # the values in the dict may be PartitionedVariables (which
        # init_from_checkpoint does not expect) that we need to convert to
        # lists.
        name_to_var_dict_fixed = {}
        # A distinct loop name keeps the `var` argument from being rebound.
        for name, named_var in six.iteritems(name_to_var_dict):
          if isinstance(named_var, variables.PartitionedVariable):
            name_to_var_dict_fixed[name] = named_var._get_variable_list()  # pylint: disable=protected-access
          else:
            name_to_var_dict_fixed[name] = named_var
        checkpoint_utils.init_from_checkpoint(prev_ckpt, name_to_var_dict_fixed)
      # The multi-element case has been fully handled above.
      return
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, PartitionedVariable, or "
        "list of Variable's and/or PartitionedVariable's, but is {}".
        format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.get_collection).  This expression will only consider
        variables in the TRAINABLE_VARIABLES collection -- if you need to
        warm-start non_TRAINABLE vars (such as optimizer accumulators or batch
        norm statistics), please use the below option.
      - A list of Variables to warm-start.  If you do not have access to the
        `Variable` objects at the call site, please use the below option.
      - A list of strings, each a regex scope provided to tf.get_collection with
        GLOBAL_VARIABLES (please see tf.get_collection).  For backwards
        compatibility reasons, this is separate from the single-string argument
        type.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions.  If not explicitly provided, the
      variable is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model.  Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).  Several current variables may be mapped to the same
      previous tensor; each of them is warm-started from it.

  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}
  if var_name_to_prev_var_name is None:
    var_name_to_prev_var_name = {}
  # Pass the path itself as the %s argument; wrapping it in a 1-tuple makes the
  # log line show the tuple repr ("('path',)") instead of the path.
  logging.info("Warm-starting from: %s", ckpt_to_initialize_from)
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
  # exception if any are unused by the end of the loop.  It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into as few calls to init_from_checkpoint as
  # possible.  The dict is keyed by the tensor name in the checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.info(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name, vocab_info.new_vocab, vocab_info.new_vocab_size,
              vocab_info.old_vocab, (vocab_info.old_vocab_size if
                                     vocab_info.old_vocab_size > 0 else "All"),
              vocab_info.num_oov_buckets, prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        if prev_tensor_name in vocabless_vars:
          # The batch is keyed by checkpoint tensor name, so one call cannot
          # initialize two variables from the same tensor.  Flush what has been
          # queued so far; otherwise the assignment below would overwrite the
          # earlier variable and it would silently not be warm-started.
          logging.info(
              "Checkpoint tensor {} is requested by more than one variable; "
              "flushing the pending warm-start batch.".format(
                  prev_tensor_name))
          checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                                vocabless_vars)
          # Start a fresh dict instead of clearing in place, in case the
          # callee kept a reference to the one it was given.
          vocabless_vars = {}
        vocabless_vars[prev_tensor_name] = var

  # Intentionally unconditional (also when nothing was queued), so the
  # checkpoint argument is always handed to init_from_checkpoint.
  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}.  Perhaps you misspelled them?  Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them?  Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
def testInitFromPartitionVar(self):
  """Checks init_from_checkpoint on partitioned variables.

  Covers three ways of naming the target: an explicit variable name, a scope
  mapping onto a variable partitioned differently from the checkpointed one,
  and a list of the variable's partitions.
  """
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1 = _create_partition_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as session:
      with variable_scope.variable_scope("some_scope"):
        my1 = variable_scope.get_variable(
            name="my1",
            shape=[100, 100],
            initializer=init_ops.zeros_initializer(),
            partitioner=partitioned_variables.min_max_variable_partitioner(
                max_partitions=5, axis=0, min_slice_size=8 << 10))
        my1_var_list = my1._get_variable_list()
      # Create another variable with different partitions than the variable in
      # the checkpoint.
      with variable_scope.variable_scope("some_other_scope"):
        my2 = variable_scope.get_variable(
            name="var1",
            shape=[100, 100],
            initializer=init_ops.zeros_initializer(),
            partitioner=partitioned_variables.min_max_variable_partitioner(
                max_partitions=5, axis=0, min_slice_size=16 << 10))
        my2_var_list = my2._get_variable_list()

      checkpoint_utils.init_from_checkpoint(
          checkpoint_dir, {
              "scope/var1": "some_scope/my1",
              "scope/": "some_other_scope/"
          })

      session.run(variables.global_variables_initializer())
      my1_values = session.run(my1_var_list)
      self.assertAllEqual(my1_values, v1)
      my2_values = session.run(my2_var_list)
      # Verify we created different number of partitions.
      # assertNotEqual is the supported spelling; the assertNotEquals alias is
      # deprecated and no longer exists in Python 3.12.
      self.assertNotEqual(len(my2_values), len(v1))
      # Verify the values were correctly initialized in spite of different
      # partitions.
      full_my2_values = np.concatenate(my2_values, axis=0)
      full_v1_values = np.concatenate(v1, axis=0)
      self.assertAllEqual(full_my2_values, full_v1_values)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as session:
      with variable_scope.variable_scope("some_scope"):
        my1 = variable_scope.get_variable(
            name="my1",
            shape=[100, 100],
            initializer=init_ops.truncated_normal_initializer(0.5),
            partitioner=partitioned_variables.min_max_variable_partitioner(
                max_partitions=5, axis=0, min_slice_size=8 << 10))
        my1_var_list = my1._get_variable_list()

      # This time the target is given as the list of partition variables.
      checkpoint_utils.init_from_checkpoint(
          checkpoint_dir, {
              "scope/var1": my1_var_list,
          })

      session.run(variables.global_variables_initializer())
      my1_values = session.run(my1_var_list)
      self.assertAllEqual(my1_values, v1)
def testInitFromCheckpointMissing(self):
  """Checks that init_from_checkpoint rejects unusable inputs.

  A checkpoint location of "no_dir" must raise an OpError; each of the bad
  assignment maps listed below must raise a ValueError.
  """
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    _, _, _, _ = _create_checkpoints(session, checkpoint_dir)

  # Assignment maps that must be rejected with ValueError, checked in this
  # order against the real checkpoint directory.
  rejected_assignment_maps = (
      # No variable in checkpoint.
      {"no_var": "some_scope/my1"},
      # No variable in the graph.
      {"var3": "some_scope/no_var"},
      # Shape mismatch.
      {"var1": "some_scope/my1"},
      # Variable 'my1' and 'my2' are missing in given checkpoint scope.
      {"useful_scope/": "some_scope/"},
      # Mapping is not to scope name.
      {"useful_scope": "some_scope/"},
  )

  # New graph and session.
  with ops.Graph().as_default() as g, self.session(graph=g):
    with variable_scope.variable_scope("some_scope"):
      variable_scope.get_variable("my1", [10, 10])
      variable_scope.get_variable(
          "my2", [1, 10],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer())

    # No directory.
    with self.assertRaises(errors_impl.OpError):
      checkpoint_utils.init_from_checkpoint(
          "no_dir", {"var1": "some_scope/my1"})

    for assignment_map in rejected_assignment_maps:
      with self.assertRaises(ValueError):
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)