Example #1
0
    def test_delete_variables(self):
        """Slots/weights registered with the optimizer are removed on delete.

        Two embedding layers, each with "m" and "v" slot tables, are pushed
        through ``OptimizerWrapper`` so that the Adam optimizer holds slot
        entries for them. ``_delete_slots_and_weights_in_optimizer`` must
        then leave the optimizer with no weights and no slots.
        """
        params = Parameters()
        embed_layers = ["test_1", "test_2"]
        slot_names = ["m", "v"]
        dim = 8
        for layer in embed_layers:
            params.embedding_params[layer] = EmbeddingTable(layer, dim)
            for slot in slot_names:
                slot_key = get_slot_table_name(layer, slot)
                params.embedding_params[slot_key] = EmbeddingTable(
                    slot_key, dim, "0.0", True)

        opt = Adam()
        opt_wrapper = OptimizerWrapper(opt, None, params.get_embedding_param,
                                       params.set_embedding_param)

        opt_wrapper._init_thread_local()
        for name in embed_layers:
            # Deterministic inputs: np.ndarray(shape, dtype) returns
            # uninitialized memory, so the "unique" ids could repeat or be
            # arbitrary and the embedding values could be garbage/NaN.
            opt_wrapper._tls._unique_ids_all_layers[name] = np.arange(
                2, dtype=np.int32)
            opt_wrapper._create_embedding_variable(
                name, np.zeros((2, dim), dtype=np.float32))
            opt_wrapper._get_slot_and_set_to_optimizer(name)

        # assertEqual (instead of assertTrue(a == b)) reports both values
        # when the check fails.
        self.assertEqual(len(opt._weights), 4)
        self.assertEqual(len(opt._slots), 2)
        for slot_dict in opt._slots.values():
            self.assertEqual(len(slot_dict), 2)

        opt_wrapper._delete_slots_and_weights_in_optimizer()
        self.assertEqual(len(opt._weights), 0)
        self.assertEqual(len(opt._slots), 0)
Example #2
0
    def test_set_slot_to_optimizer(self):
        """Slot-table values are installed as the optimizer's slot variables.

        The "m" and "v" slot tables are created with initializer "0.0" and
        are not explicitly written by this test. The slot variables that
        ``OptimizerWrapper`` installs in Adam for the looked-up ids are
        therefore expected to be all zeros (presumably lookups of unseen ids
        return the initializer value; confirm against ``EmbeddingTable``).
        """
        embed_name = "test_emb"
        dim = 2
        # Deterministic inputs: np.ndarray(shape, dtype) returns
        # uninitialized memory. The previous expected slot values were
        # arbitrary, so the comparison below only passed when that memory
        # happened to be zero.
        indices = np.arange(2, dtype=np.int32)
        embed_values = np.zeros((2, dim), dtype=np.float32)
        slot_values = {
            "m": np.zeros((2, dim), dtype=np.float32),
            "v": np.zeros((2, dim), dtype=np.float32),
        }
        params = Parameters()
        # Keep the embedding table's dim consistent with embed_values and the
        # slot tables. It was 8 while every vector here has width 2.
        params.embedding_params[embed_name] = EmbeddingTable(embed_name, dim)
        for slot in ["m", "v"]:
            slot_table_name = get_slot_table_name(embed_name, slot)
            params.embedding_params[slot_table_name] = EmbeddingTable(
                slot_table_name, dim, "0.0", True)

        opt = Adam()
        opt_wrapper = OptimizerWrapper(opt, None, params.get_embedding_param)
        opt_wrapper._init_thread_local()

        opt_wrapper._tls._unique_ids_all_layers[embed_name] = indices
        opt_wrapper._create_embedding_variable(embed_name, embed_values)
        opt_wrapper._get_slot_and_set_to_optimizer(embed_name)

        self.assertEqual(len(opt._slots), 1)
        opt_slots = list(opt._slots.values())[0]
        self.assertEqual(sorted(opt_slots), ["m", "v"])
        for name in ["m", "v"]:
            self.assertTrue(
                np.allclose(opt_slots[name].numpy(), slot_values[name]))