def test_from_object_based_checkpoint(self):
  """Warm-starts a devar from an object-based (tracking) checkpoint.

  A devar "prefix/devar" is created and populated under scope "outer", then
  saved with `tracking_util.Checkpoint`. A fresh graph declares an empty devar
  with the same scoped name, warm-starts it from the temp directory, and its
  exported keys/values must equal the saved ones.
  """
  embedding_dim = 10
  saved_keys = [0, 1, 2, 3]
  saved_values = [[key] * embedding_dim for key in saved_keys]

  with ops.Graph().as_default() as save_graph:
    with self.session(graph=save_graph):
      with variable_scope.variable_scope("outer"):
        source_devar = self._add_and_initialize_devar(
            "prefix/devar", saved_keys, saved_values, embedding_dim)
        # Save object-based checkpoint.
        saver = tracking_util.Checkpoint(v=source_devar)
        saver.save(os.path.join(self.get_temp_dir(), "checkpoint"))

  with ops.Graph().as_default() as restore_graph:
    with self.session(graph=restore_graph):
      with variable_scope.variable_scope("outer"):
        target_devar = self._add_devar("prefix/devar", embedding_dim)
        ws_util.warm_start(
            self.get_temp_dir(), vars_to_warm_start=["outer/prefix/devar"])
        self.evaluate(deo.dynamic_embedding_variables_initializer())
        got_keys, got_values = self._export_sorted_keys_and_values(
            target_devar)
        self.assertAllEqual(saved_keys, got_keys)
        self.assertAllEqual(saved_values, got_values)
def test_basic_devars(self):
  """Restores devars one at a time via an explicit saveable-name mapping.

  Two devars are checkpointed under "old_scope". In a new graph each one is
  recreated under "new_scope" and restored with
  `checkpoint_utils.init_mht_saveable_from_checkpoint`, keyed by the previous
  tables' saveable names and pointing at the new tables' saveables.
  """
  # Save checkpoint from which to warm-start.
  first_dim, second_dim = 10, 20
  first_keys, second_keys = [0, 1, 2, 3], [4, 5, 6]
  first_values = [[k] * first_dim for k in first_keys]
  second_values = [[k] * second_dim for k in second_keys]
  with ops.Graph().as_default() as old_graph:
    with self.session(graph=old_graph) as sess:
      old_first = self._add_and_initialize_devar(
          "old_scope/var1", first_keys, first_values, first_dim)
      old_second = self._add_and_initialize_devar(
          "old_scope/var2", second_keys, second_values, second_dim)
      _write_checkpoint(self, sess)

  # Each entry: (new name, dim, checkpointed devar, expected keys,
  # expected values, expected size after restore). Processed in order.
  restore_plan = (
      ("new_scope/var1", first_dim, old_first, first_keys, first_values, 4),
      ("new_scope/var2", second_dim, old_second, second_keys, second_values,
       3),
  )

  # New graph, new session with warm-starting.
  with ops.Graph().as_default() as new_graph:
    with self.session(graph=new_graph):
      for (name, dim, old_devar, want_keys, want_values,
           want_size) in restore_plan:
        new_devar = self._add_devar(name, dim)
        self.assertAllEqual(0, self.evaluate(new_devar.size()))
        checkpoint_utils.init_mht_saveable_from_checkpoint(
            self.get_temp_dir(), {
                old_table.saveable.name: new_table.saveable
                for old_table, new_table in zip(old_devar.tables,
                                                new_devar.tables)
            })
        self.evaluate(deo.dynamic_embedding_variables_initializer())
        self.assertAllEqual(want_size, self.evaluate(new_devar.size()))
        got_keys, got_values = self._export_sorted_keys_and_values(new_devar)
        self.assertAllEqual(want_keys, got_keys)
        self.assertAllEqual(want_values, got_values)
def test_use_var_name_to_prev_var_name(self):
  """Warm-starts devars whose table saveables were renamed between graphs.

  Devars "v1"/"v2" are checkpointed under scope "old_outer" and recreated under
  "new_outer". `ws_util.warm_start` is expected to raise ValueError when the
  new saveable names cannot be found in the checkpoint, and also when
  `var_name_to_prev_var_name` contains entries that go unused. With a complete
  mapping, both devars must be restored with the checkpointed contents.
  """
  # Save checkpoint from which to warm-start.
  dim = 2
  keys = [0, 1]
  values = [[0, 0], [1, 1]]
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      with variable_scope.variable_scope("old_outer"):
        prev_v1 = self._add_and_initialize_devar("v1", keys, values, dim)
        prev_v2 = self._add_and_initialize_devar("v2", keys, values, dim)
        _write_checkpoint(self, sess)

  # New graph, new session with warm-starting.
  with ops.Graph().as_default() as g:
    with self.session(graph=g):
      with variable_scope.variable_scope("new_outer"):
        v1 = self._add_devar("v1", dim)
        v2 = self._add_devar("v2", dim)
        for v in [v1, v2]:
          self.assertAllEqual(0, self.evaluate(v.size()))

        # Map each new table saveable's full name to the matching saveable
        # name stored in the checkpoint (scope "new_outer" vs "old_outer").
        ms_name_to_prev_ms_name = {}
        for v, prev_v in zip([v1, v2], [prev_v1, prev_v2]):
          for table, prev_table in zip(v.tables, prev_v.tables):
            ms_name_to_prev_ms_name[table.saveable.full_name] = (
                prev_table.saveable.full_name)

        # Saveable names that are not found in the checkpoint raise
        # ValueError.
        self.assertRaises(
            ValueError,
            ws_util.warm_start,
            self.get_temp_dir(),
            vars_to_warm_start=["new_outer/v1", "new_outer/v2"])
        # Unused entries in the mapping (v2's tables, since only v1 is
        # selected) raise ValueError.
        self.assertRaises(
            ValueError,
            ws_util.warm_start,
            self.get_temp_dir(),
            vars_to_warm_start=["new_outer/v1"],
            var_name_to_prev_var_name=ms_name_to_prev_ms_name)

        ws_util.warm_start(
            self.get_temp_dir(),
            vars_to_warm_start=["new_outer/v1", "new_outer/v2"],
            var_name_to_prev_var_name=ms_name_to_prev_ms_name)
        self.evaluate(deo.dynamic_embedding_variables_initializer())

        # Both devars must now hold exactly the checkpointed keys/values.
        # Exported results get fresh names so they are compared against the
        # saved data rather than against themselves.
        for v in [v1, v2]:
          got_keys, got_values = self._export_sorted_keys_and_values(v)
          self.assertAllEqual(keys, got_keys)
          self.assertAllEqual(values, got_values)
def test_warm_start_optimizers(self):
  """Checks that a warm-started graph reproduces the original training losses.

  For each configuration yielded by `_next_run_step_config()`:
    1. A TestGraph is trained for `run_step` steps and checkpointed. It is
       then trained for `extra_run_step` more steps to record reference
       losses. The regular-variable and devar losses must agree.
    2. A second TestGraph is built, reusing the `x` value fetched from the
       first graph. Every variable matching `'.*'` is warm-started from the
       checkpoint (presumably including optimizer slot variables; confirm
       against TestGraph). It runs the same `extra_run_step` steps, and its
       losses must match the references from step 1.
  """
  extra_run_step = 2
  for run_id, num_shards, k_dtype, d_dtype, init_mode, dim, run_step \
      in _next_run_step_config():
    # Identifies the failing configuration in assertion messages.
    error_msg = "Cond:{},{},{},{},{},{}".format(
        num_shards, k_dtype, d_dtype, init_mode, dim, run_step)
    # First graph: train, checkpoint, then keep training for reference losses.
    with ops.Graph().as_default() as g:
      with self.session(graph=g,
                        use_gpu=test_util.is_gpu_available(),
                        config=default_config) as sess:
        graph = TestGraph(k_dtype, d_dtype, dim, num_shards, 'var', 'devar',
                          run_id)
        self.evaluate(variables.global_variables_initializer())
        sess.run([graph.devar_init_op])
        # Fetch graph.x so the second TestGraph can be built from this value.
        prev_x = sess.run([graph.x])[0]
        for _ in range(run_step):
          sess.run([graph.var_opt_op, graph.devar_opt_op])
        sess.run([graph.var_loss, graph.devar_loss])
        _write_checkpoint(self, sess)
        # Train past the checkpoint. The warm-started graph below must
        # reproduce these losses after the same number of steps.
        for _ in range(extra_run_step):
          sess.run([graph.var_opt_op, graph.devar_opt_op])
        prev_var_loss, prev_devar_loss = sess.run(
            [graph.var_loss, graph.devar_loss])
        self.assertAllCloseAccordingToType(prev_var_loss,
                                           prev_devar_loss,
                                           msg=error_msg)

    # Second graph: warm-start everything from the checkpoint and repeat the
    # extra steps.
    with ops.Graph().as_default() as g:
      with self.session(graph=g,
                        use_gpu=test_util.is_gpu_available(),
                        config=default_config) as sess:
        graph = TestGraph(k_dtype, d_dtype, dim, num_shards, 'var', 'devar',
                          run_id, prev_x)
        # warm_start is called before the initializers run.
        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=['.*'])
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(
            deo.dynamic_embedding_variables_initializer())
        for _ in range(extra_run_step):
          sess.run([graph.var_opt_op, graph.devar_opt_op])
        var_loss, devar_loss = sess.run(
            [graph.var_loss, graph.devar_loss])
        self.assertAllCloseAccordingToType(var_loss,
                                           prev_var_loss,
                                           msg=error_msg)
        self.assertAllCloseAccordingToType(devar_loss,
                                           prev_devar_loss,
                                           msg=error_msg)
def test_both_vars_and_devars(self):
  """Warm-starts a mix of a regular variable and devars selected by object.

  The checkpoint holds a ones-initialized variable "v1" and two populated
  devars. The new graph recreates all three: the variable is zero-initialized
  and the devars are empty. Only the variable and `devar1` are passed (as
  objects) to `ws_util.warm_start`, so `devar2` must remain empty.
  """
  # Save checkpoint from which to warm-start.
  dim1, dim2 = 10, 20
  keys1, keys2 = [0, 1, 2, 3], [4, 5, 6]
  values1 = [[k] * dim1 for k in keys1]
  values2 = [[k] * dim2 for k in keys2]
  with ops.Graph().as_default() as saved_graph:
    with self.session(graph=saved_graph) as sess:
      saved_var = variable_scope.get_variable(
          "v1", shape=[10, 1], initializer=init_ops.ones_initializer())
      self.evaluate(variables.global_variables_initializer())
      prev_int_val = self.evaluate(saved_var)
      self.assertAllEqual(np.ones([10, 1]), prev_int_val)
      saved_devar1 = self._add_and_initialize_devar(
          "devar1", keys1, values1, dim1)
      self.assertAllEqual(4, self.evaluate(saved_devar1.size()))
      saved_devar2 = self._add_and_initialize_devar(
          "devar2", keys2, values2, dim2)
      self.assertAllEqual(3, self.evaluate(saved_devar2.size()))
      _write_checkpoint(self, sess)

  # New graph, new session with warm-starting.
  with ops.Graph().as_default() as warm_graph:
    with self.session(graph=warm_graph):
      # Initialize with zeros so a successful warm start is observable.
      var = variable_scope.get_variable(
          "v1", shape=[10, 1], initializer=init_ops.zeros_initializer())
      devar1 = self._add_devar("devar1", dim1)
      devar2 = self._add_devar("devar2", dim2)
      for fresh_devar in (devar1, devar2):
        self.assertAllEqual(0, self.evaluate(fresh_devar.size()))

      ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var, devar1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(deo.dynamic_embedding_variables_initializer())

      # The variable now carries the checkpointed ones, not its zero init.
      self.assertAllEqual(var.eval(), prev_int_val)
      self.assertAllEqual(4, self.evaluate(devar1.size()))
      restored_keys, restored_values = self._export_sorted_keys_and_values(
          devar1)
      self.assertAllEqual(keys1, restored_keys)
      self.assertAllEqual(values1, restored_values)
      # devar2 was not selected for warm-starting, so it stays empty.
      self.assertAllEqual(0, self.evaluate(devar2.size()))
def test_list_of_regexes(self):
  """Warm-starts only the devars selected by a list of name regexes.

  Four devars are checkpointed under "outer": v1, v1/Momentum, v2 and
  v2/Momentum. The pattern "outer/v1" selects both v1 and v1/Momentum, while
  the anchored "outer/v2$" selects v2 but not v2/Momentum. The selected devars
  must hold the checkpointed data and v2/Momentum must stay empty.
  """
  # Save checkpoint from which to warm-start.
  dim = 2
  keys = [0, 1]
  values = [[0, 0], [1, 1]]
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      with variable_scope.variable_scope("outer"):
        self._add_and_initialize_devar("v1", keys, values, dim)
        self._add_and_initialize_devar("v1/Momentum", keys, values, dim)
        self._add_and_initialize_devar("v2", keys, values, dim)
        self._add_and_initialize_devar("v2/Momentum", keys, values, dim)
        _write_checkpoint(self, sess)

  # New graph, new session with warm-starting.
  with ops.Graph().as_default() as g:
    with self.session(graph=g):
      with variable_scope.variable_scope("outer"):
        v1 = self._add_devar("v1", dim)
        v1_momentum = self._add_devar("v1/Momentum", dim)
        v2 = self._add_devar("v2", dim)
        v2_momentum = self._add_devar("v2/Momentum", dim)
        for v in [v1, v1_momentum, v2, v2_momentum]:
          self.assertAllEqual(0, self.evaluate(v.size()))

        ws_util.warm_start(
            self.get_temp_dir(),
            # This warm-starts both v1 and v1/Momentum, but only
            # v2 (and not v2/Momentum).
            vars_to_warm_start=["outer/v1", "outer/v2$"])
        self.evaluate(deo.dynamic_embedding_variables_initializer())

        # Every selected devar must now hold the checkpointed keys/values.
        # Exported results get fresh names so they are compared against the
        # saved data rather than against themselves.
        for v in [v1, v1_momentum, v2]:
          got_keys, got_values = self._export_sorted_keys_and_values(v)
          self.assertAllEqual(keys, got_keys)
          self.assertAllEqual(values, got_values)
        # v2/Momentum is excluded by the anchored regex and stays empty.
        self.assertAllEqual(0, self.evaluate(v2_momentum.size()))