def test_delete_variables(self):
    """Deleting slots/weights clears everything registered on the optimizer.

    Two embedding layers are set up, each with "m" and "v" slot tables.
    Their embedding variables and slots are registered on a wrapped Adam
    optimizer. The test then checks that
    ``_delete_slots_and_weights_in_optimizer`` empties both
    ``opt._weights`` and ``opt._slots``.
    """
    params = Parameters()
    embed_layers = ["test_1", "test_2"]
    slot_names = ["m", "v"]
    dim = 8
    for layer in embed_layers:
        params.embedding_params[layer] = EmbeddingTable(layer, dim)
        for slot in slot_names:
            slot_key = get_slot_table_name(layer, slot)
            params.embedding_params[slot_key] = EmbeddingTable(
                slot_key, dim, "0.0", True
            )

    opt = Adam()
    opt_wrapper = OptimizerWrapper(
        opt, None, params.get_embedding_param, params.set_embedding_param
    )
    opt_wrapper._init_thread_local()
    for name in embed_layers:
        # Use explicit, deterministic inputs: ``np.ndarray(shape)`` returns
        # uninitialized memory, which made the ids/values vary per run.
        opt_wrapper._tls._unique_ids_all_layers[name] = np.arange(
            2, dtype=np.int32
        )
        opt_wrapper._create_embedding_variable(
            name, np.zeros([2, dim], dtype=np.float32)
        )
        opt_wrapper._get_slot_and_set_to_optimizer(name)

    # Two layers registered, each expected to carry both "m" and "v" slots.
    self.assertEqual(len(opt._weights), 4)
    self.assertEqual(len(opt._slots), 2)
    for slot_dict in opt._slots.values():
        self.assertEqual(len(slot_dict), 2)

    opt_wrapper._delete_slots_and_weights_in_optimizer()
    self.assertEqual(len(opt._weights), 0)
    self.assertEqual(len(opt._slots), 0)
def test_set_slot_to_optimizer(self):
    """Slot values looked up from the slot tables are set on the optimizer.

    Registers one embedding layer with "m"/"v" slot tables, creates its
    embedding variable, and calls ``_get_slot_and_set_to_optimizer``.
    The test then checks that the optimizer holds exactly one slot entry,
    that the entry has "m" and "v", and that each slot matches the
    expected values.
    """
    embed_name = "test_emb"
    dim = 2
    # Deterministic inputs; ``np.ndarray(shape)`` would be uninitialized.
    indices = np.array([0, 1], dtype=np.int32)
    embed_values = np.zeros([2, dim], dtype=np.float32)
    # The slot tables below are built with the "0.0" initializer and
    # is_slot=True. The expected slot rows are therefore explicit zeros.
    # The previous expectation was uninitialized memory, which made the
    # comparison flaky. (Assumes EmbeddingTable fills unseen ids from that
    # initializer — confirm if the table semantics change.)
    slot_values = {
        "m": np.zeros([2, dim], dtype=np.float32),
        "v": np.zeros([2, dim], dtype=np.float32),
    }
    params = Parameters()
    # Use the same width as the embedding values and slot tables.
    params.embedding_params[embed_name] = EmbeddingTable(embed_name, dim)
    for slot in ["m", "v"]:
        slot_table_name = get_slot_table_name(embed_name, slot)
        params.embedding_params[slot_table_name] = EmbeddingTable(
            slot_table_name, dim, "0.0", True
        )

    opt = Adam()
    opt_wrapper = OptimizerWrapper(opt, None, params.get_embedding_param)
    opt_wrapper._init_thread_local()

    opt_wrapper._tls._unique_ids_all_layers[embed_name] = indices
    opt_wrapper._create_embedding_variable(embed_name, embed_values)
    opt_wrapper._get_slot_and_set_to_optimizer(embed_name)

    self.assertEqual(len(opt._slots), 1)
    [opt_slots] = opt._slots.values()
    self.assertEqual(sorted(opt_slots.keys()), ["m", "v"])
    for slot in ["m", "v"]:
        # Same tolerances as np.allclose, but reports mismatches on failure.
        np.testing.assert_allclose(
            opt_slots[slot].numpy(), slot_values[slot], rtol=1e-5, atol=1e-8
        )