def test_from_object_based_checkpoint(self):
        """Checks warm-starting a devar from an object-based checkpoint.

        A devar is populated and saved through `tracking_util.Checkpoint`;
        a fresh graph then warm-starts a devar of the same scoped name and
        its exported contents are compared with the saved data.
        """
        embedding_dim = 10
        expected_keys = [0, 1, 2, 3]
        expected_values = [[key] * embedding_dim for key in expected_keys]

        # Phase 1: build and populate the devar, then save it.
        with ops.Graph().as_default() as save_graph, \
                self.session(graph=save_graph), \
                variable_scope.variable_scope("outer"):
            saved_var = self._add_and_initialize_devar(
                "prefix/devar", expected_keys, expected_values, embedding_dim)
            # Save object-based checkpoint.
            checkpoint = tracking_util.Checkpoint(v=saved_var)
            checkpoint.save(os.path.join(self.get_temp_dir(), "checkpoint"))

        # Phase 2: a new graph/session warm-starts from the saved checkpoint.
        with ops.Graph().as_default() as restore_graph, \
                self.session(graph=restore_graph), \
                variable_scope.variable_scope("outer"):
            restored_var = self._add_devar("prefix/devar", embedding_dim)
            ws_util.warm_start(self.get_temp_dir(),
                               vars_to_warm_start=["outer/prefix/devar"])
            init_op = deo.dynamic_embedding_variables_initializer()
            self.evaluate(init_op)
            actual_keys, actual_values = \
                self._export_sorted_keys_and_values(restored_var)
            self.assertAllEqual(expected_keys, actual_keys)
            self.assertAllEqual(expected_values, actual_values)
    def test_basic_devars(self):
        """Checks `init_mht_saveable_from_checkpoint` on two renamed devars.

        Two devars are saved under "old_scope"; in a new graph each one is
        recreated under "new_scope", mapped table-by-table onto the saved
        saveable names, initialized, and compared with the saved contents.
        """
        # Save checkpoint from which to warm-start.
        dim1, dim2 = 10, 20
        keys1, keys2 = [0, 1, 2, 3], [4, 5, 6]
        values1 = [[k] * dim1 for k in keys1]
        values2 = [[k] * dim2 for k in keys2]
        with ops.Graph().as_default() as save_graph, \
                self.session(graph=save_graph) as save_sess:
            prev_var1 = self._add_and_initialize_devar(
                "old_scope/var1", keys1, values1, dim1)
            prev_var2 = self._add_and_initialize_devar(
                "old_scope/var2", keys2, values2, dim2)
            _write_checkpoint(self, save_sess)

        # Each case: (new name, dim, saved devar, expected keys,
        # expected values, expected size after warm-starting).
        cases = (
            ("new_scope/var1", dim1, prev_var1, keys1, values1, 4),
            ("new_scope/var2", dim2, prev_var2, keys2, values2, 3),
        )

        # New graph, new session with warm-starting.
        with ops.Graph().as_default() as load_graph, \
                self.session(graph=load_graph):
            for name, dim, prev_var, want_keys, want_values, want_size \
                    in cases:
                new_var = self._add_devar(name, dim)
                # The freshly created devar starts out empty.
                self.assertAllEqual(0, self.evaluate(new_var.size()))

                ckpt_dir = self.get_temp_dir()
                # Map each saved saveable name to the new table's saveable.
                saveable_map = {}
                for prev_table, table in zip(prev_var.tables,
                                             new_var.tables):
                    saveable_map[prev_table.saveable.name] = table.saveable
                checkpoint_utils.init_mht_saveable_from_checkpoint(
                    ckpt_dir, saveable_map)

                self.evaluate(deo.dynamic_embedding_variables_initializer())
                self.assertAllEqual(want_size,
                                    self.evaluate(new_var.size()))
                got_keys, got_values = \
                    self._export_sorted_keys_and_values(new_var)
                self.assertAllEqual(want_keys, got_keys)
                self.assertAllEqual(want_values, got_values)
    def test_use_var_name_to_prev_var_name(self):
        """Checks warm-starting renamed devars via `var_name_to_prev_var_name`.

        Two devars are saved under "old_outer" and recreated under
        "new_outer". The test expects `ws_util.warm_start` to raise
        ValueError when the name mapping is missing or when the mapping
        contains entries for variables that are not being warm-started,
        and to restore the saved contents when the mapping matches.
        """
        # Save checkpoint from which to warm-start.
        dim = 2
        keys = [0, 1]
        values = [[0, 0], [1, 1]]
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
                with variable_scope.variable_scope("old_outer"):
                    prev_v1 = self._add_and_initialize_devar(
                        "v1", keys, values, dim)
                    prev_v2 = self._add_and_initialize_devar(
                        "v2", keys, values, dim)
                    _write_checkpoint(self, sess)

        # New graph, new session with warm-starting.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
                with variable_scope.variable_scope("new_outer"):
                    v1 = self._add_devar("v1", dim)
                    v2 = self._add_devar("v2", dim)
                    self.assertAllEqual(0, self.evaluate(v1.size()))

                    # Map every new table saveable name to the name it was
                    # saved under in the previous graph.
                    ms_name_to_prev_ms_name = {}
                    for v, prev_v in zip([v1, v2], [prev_v1, prev_v2]):
                        for table, prev_table in zip(v.tables, prev_v.tables):
                            ms_name_to_prev_ms_name[table.saveable.full_name] = \
                              prev_table.saveable.full_name

                    # Unfound MutableHashTable._Saveable names raise ValueError.
                    self.assertRaises(
                        ValueError,
                        ws_util.warm_start,
                        self.get_temp_dir(),
                        vars_to_warm_start=["new_outer/v1", "new_outer/v2"])

                    # Unused previous MutableHashTable._Saveable names raise
                    # ValueError.
                    self.assertRaises(
                        ValueError,
                        ws_util.warm_start,
                        self.get_temp_dir(),
                        vars_to_warm_start=["new_outer/v1"],
                        var_name_to_prev_var_name=ms_name_to_prev_ms_name)

                    ws_util.warm_start(
                        self.get_temp_dir(),
                        vars_to_warm_start=["new_outer/v1", "new_outer/v2"],
                        var_name_to_prev_var_name=ms_name_to_prev_ms_name)
                    self.evaluate(
                        deo.dynamic_embedding_variables_initializer())
                    # Verify both devars were warm-started with the saved
                    # contents. The exported data is bound to fresh names so
                    # the expected `keys`/`values` are not shadowed (which
                    # would make the assertions compare a value with itself).
                    for v in [v1, v2]:
                        restored_keys, restored_values = \
                            self._export_sorted_keys_and_values(v)
                        self.assertAllEqual(keys, restored_keys)
                        self.assertAllEqual(values, restored_values)
    def test_warm_start_optimizers(self):
        """Checks that training resumed after warm-start reproduces losses.

        For every configuration from `_next_run_step_config()`: a
        `TestGraph` is optimized for `run_step` steps, checkpointed, then
        optimized for a few extra steps and its losses recorded. A second
        `TestGraph` is warm-started from that checkpoint, optimized for
        the same number of extra steps, and its losses are compared with
        the recorded ones.
        """
        extra_run_step = 2
        for run_config in _next_run_step_config():
            (run_id, num_shards, k_dtype, d_dtype, init_mode, dim,
             run_step) = run_config
            error_msg = "Cond:{},{},{},{},{},{}".format(
                num_shards, k_dtype, d_dtype, init_mode, dim, run_step)

            # First graph: train, checkpoint, then train a bit further.
            with ops.Graph().as_default() as train_graph, \
                    self.session(graph=train_graph,
                                 use_gpu=test_util.is_gpu_available(),
                                 config=default_config) as sess:
                model = TestGraph(k_dtype, d_dtype, dim, num_shards, 'var',
                                  'devar', run_id)
                self.evaluate(variables.global_variables_initializer())
                sess.run([model.devar_init_op])
                (prev_x,) = sess.run([model.x])
                for _ in range(run_step):
                    sess.run([model.var_opt_op, model.devar_opt_op])
                sess.run([model.var_loss, model.devar_loss])
                _write_checkpoint(self, sess)
                for _ in range(extra_run_step):
                    sess.run([model.var_opt_op, model.devar_opt_op])
                prev_var_loss, prev_devar_loss = sess.run(
                    [model.var_loss, model.devar_loss])
                self.assertAllCloseAccordingToType(prev_var_loss,
                                                   prev_devar_loss,
                                                   msg=error_msg)

            # Second graph: warm-start everything, repeat the extra steps.
            with ops.Graph().as_default() as resume_graph, \
                    self.session(graph=resume_graph,
                                 use_gpu=test_util.is_gpu_available(),
                                 config=default_config) as sess:
                model = TestGraph(k_dtype, d_dtype, dim, num_shards, 'var',
                                  'devar', run_id, prev_x)
                ws_util.warm_start(self.get_temp_dir(),
                                   vars_to_warm_start=['.*'])
                self.evaluate(variables.global_variables_initializer())
                self.evaluate(deo.dynamic_embedding_variables_initializer())
                for _ in range(extra_run_step):
                    sess.run([model.var_opt_op, model.devar_opt_op])
                var_loss, devar_loss = sess.run(
                    [model.var_loss, model.devar_loss])
                self.assertAllCloseAccordingToType(var_loss,
                                                   prev_var_loss,
                                                   msg=error_msg)
                self.assertAllCloseAccordingToType(devar_loss,
                                                   prev_devar_loss,
                                                   msg=error_msg)
    def test_both_vars_and_devars(self):
        """Checks warm-starting a mix of a regular variable and a devar.

        A regular variable and two devars are saved. In the new graph only
        the regular variable and the first devar are passed to
        `ws_util.warm_start`; the second devar is expected to stay empty.
        """
        # Save checkpoint from which to warm-start.
        dim1, dim2 = 10, 20
        keys1, keys2 = [0, 1, 2, 3], [4, 5, 6]
        values1 = [[k] * dim1 for k in keys1]
        values2 = [[k] * dim2 for k in keys2]
        with ops.Graph().as_default() as save_graph, \
                self.session(graph=save_graph) as save_sess:
            dense_var = variable_scope.get_variable(
                "v1", shape=[10, 1], initializer=init_ops.ones_initializer())
            self.evaluate(variables.global_variables_initializer())
            prev_int_val = self.evaluate(dense_var)
            self.assertAllEqual(np.ones([10, 1]), prev_int_val)

            saved_devar1 = self._add_and_initialize_devar(
                "devar1", keys1, values1, dim1)
            self.assertAllEqual(4, self.evaluate(saved_devar1.size()))
            saved_devar2 = self._add_and_initialize_devar(
                "devar2", keys2, values2, dim2)
            self.assertAllEqual(3, self.evaluate(saved_devar2.size()))
            _write_checkpoint(self, save_sess)

        # New graph, new session with warm-starting.
        with ops.Graph().as_default() as load_graph, \
                self.session(graph=load_graph):
            # Initialize with zeros.
            dense_var = variable_scope.get_variable(
                "v1", shape=[10, 1], initializer=init_ops.zeros_initializer())
            new_devar1 = self._add_devar("devar1", dim1)
            new_devar2 = self._add_devar("devar2", dim2)
            for fresh_devar in (new_devar1, new_devar2):
                self.assertAllEqual(0, self.evaluate(fresh_devar.size()))

            ws_util.warm_start(self.get_temp_dir(),
                               vars_to_warm_start=[dense_var, new_devar1])
            self.evaluate(variables.global_variables_initializer())
            self.evaluate(deo.dynamic_embedding_variables_initializer())

            # Verify weights were correctly warm-started (init overridden to ones).
            self.assertAllEqual(dense_var.eval(), prev_int_val)
            self.assertAllEqual(4, self.evaluate(new_devar1.size()))
            got_keys, got_values = \
                self._export_sorted_keys_and_values(new_devar1)
            self.assertAllEqual(keys1, got_keys)
            self.assertAllEqual(values1, got_values)
            # devar2 was not in vars_to_warm_start, so it must remain empty.
            self.assertAllEqual(0, self.evaluate(new_devar2.size()))
    def test_list_of_regexes(self):
        """Checks that `vars_to_warm_start` accepts a list of regexes.

        Four devars (v1, v1/Momentum, v2, v2/Momentum) are saved under
        scope "outer". The new graph warm-starts with the patterns
        "outer/v1" and "outer/v2$"; the test expects v1, v1/Momentum and v2
        to receive the saved contents and v2/Momentum to stay empty.
        """
        # Save checkpoint from which to warm-start.
        dim = 2
        keys = [0, 1]
        values = [[0, 0], [1, 1]]
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
                with variable_scope.variable_scope("outer"):
                    self._add_and_initialize_devar("v1", keys, values, dim)
                    self._add_and_initialize_devar("v1/Momentum", keys, values,
                                                   dim)
                    self._add_and_initialize_devar("v2", keys, values, dim)
                    self._add_and_initialize_devar("v2/Momentum", keys, values,
                                                   dim)
                    _write_checkpoint(self, sess)

        # New graph, new session with warm-starting.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
                with variable_scope.variable_scope("outer"):
                    v1 = self._add_devar("v1", dim)
                    v1_momentum = self._add_devar("v1/Momentum", dim)
                    v2 = self._add_devar("v2", dim)
                    v2_momentum = self._add_devar("v2/Momentum", dim)
                    self.assertAllEqual(0, self.evaluate(v1.size()))

                    ws_util.warm_start(
                        self.get_temp_dir(),
                        # This warm-starts both v1 and v1/Momentum, but only
                        # v2 (and not v2/Momentum).
                        vars_to_warm_start=["outer/v1", "outer/v2$"])
                    self.evaluate(
                        deo.dynamic_embedding_variables_initializer())
                    # Verify the selected devars were warm-started with the
                    # saved contents. The exported data is bound to fresh
                    # names so the expected `keys`/`values` are not shadowed
                    # (which would make the assertions compare a value with
                    # itself).
                    for v in [v1, v1_momentum, v2]:
                        restored_keys, restored_values = \
                            self._export_sorted_keys_and_values(v)
                        self.assertAllEqual(keys, restored_keys)
                        self.assertAllEqual(values, restored_values)
                    # v2/Momentum is excluded by the "$" anchor and must
                    # remain empty.
                    self.assertAllEqual(0, self.evaluate(v2_momentum.size()))