def setup(self, config):
    self.config = config

    print("Load the models")
    vocab = torch.load(config.vocab)  # type: Vocab
    parser = load_parser(fetch_best_ckpt_name(config.parser_model))
    # keep references on self: the code below uses self.parser / self.vocab
    self.vocab = vocab
    self.parser = parser
    self.task = ParserTask(vocab, parser)

    print("Load the dataset")
    train_corpus = Corpus.load(config.ftrain)
    if config.hk_training_set == 'on':
        self.corpus = train_corpus
    else:
        self.corpus = Corpus.load(config.fdata)
    dataset = TextDataset(vocab.numericalize(self.corpus, True))
    # set the data loader
    self.loader = DataLoader(dataset=dataset, collate_fn=collate_fn)

    # Capture the gradient w.r.t. the character-embedding output on every
    # backward pass and stash it for later retrieval.
    def embed_backward_hook(module, grad_in, grad_out):
        ram_write('embed_grad', grad_out[0])

    self.parser.char_lstm.embed.register_backward_hook(embed_backward_hook)
    # self.parser.embed.register_backward_hook(embed_backward_hook)
    self.parser.eval()

    self.embed_searcher = EmbeddingSearcher(
        embed=self.parser.char_lstm.embed.weight,
        idx2word=lambda x: self.vocab.chars[x],
        word2idx=lambda x: self.vocab.char_dict[x])

    random.seed(1)
    torch.manual_seed(1)
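
# ---------------------------------------------------------------------------
# Hedged sketch (not from the original source): how the backward hook wired up
# in setup() captures gradients. A toy nn.Embedding stands in for
# parser.char_lstm.embed, and a plain dict stands in for the ram_write /
# ram_read store in dpattack.libs.luna, whose exact behaviour is assumed here.
def _demo_embed_grad_hook():
    import torch
    import torch.nn as nn

    ram = {}  # stand-in for luna's ram_write/ram_read global store
    embed = nn.Embedding(10, 4)

    def hook(module, grad_in, grad_out):
        # grad_out[0] is the gradient w.r.t. the embedding *output*,
        # exactly what setup() stores under 'embed_grad'.
        ram['embed_grad'] = grad_out[0]

    embed.register_backward_hook(hook)
    embed(torch.tensor([1, 2, 3])).sum().backward()
    return ram['embed_grad']  # shape: (3, 4)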
def __call__(self, config):
    print("Load the models")
    vocab = torch.load(config.vocab)
    parser = load_parser(fetch_best_ckpt_name(config.parser_model))
    task = ParserTask(vocab, parser)
    if config.pred_tag:
        tagger = PosTagger.load(fetch_best_ckpt_name(config.tagger_model))
    else:
        tagger = None

    print("Load the dataset")
    corpus = Corpus.load(config.fdata)
    dataset = TextDataset(vocab.numericalize(corpus))
    # set the data loader
    loader = batchify(dataset, config.batch_size, config.buckets)

    print("Evaluate the dataset")
    loss, metric = task.evaluate(loader, config.punct, tagger, True)
    print(f"Loss: {loss:.4f} {metric}")
def __call__(self, config):
    print("Load the models")
    vocab = torch.load(config.vocab)
    parser = load_parser(fetch_best_ckpt_name(config.parser_model))
    task = ParserTask(vocab, parser)
    if config.pred_tag:
        tagger = PosTagger.load(fetch_best_ckpt_name(config.tagger_model))
    else:
        tagger = None

    print("Load the dataset")
    corpus = Corpus.load(config.fdata)
    dataset = TextDataset(vocab.numericalize(corpus, training=False))
    # set the data loader
    loader = batchify(dataset, config.batch_size)

    print("Make predictions on the dataset")
    corpus.tags, corpus.heads, corpus.rels = task.predict(loader, tagger)
    saved_path = f"{config.result_path}/raw_result.conllx"
    print(f"Save the predicted result to {saved_path}")
    corpus.save(saved_path)
def compare_idxes(nbr1, nbr2):
    nbr1 = set(cast_list(nbr1))
    nbr2 = set(cast_list(nbr2))
    inter = nbr1.intersection(nbr2)
    return len(inter)


if __name__ == '__main__':
    from dpattack.libs.luna import fetch_best_ckpt_name, cast_list, show_mean_std, time_record
    from dpattack.utils.parser_helper import load_parser
    from dpattack.utils.vocab import Vocab

    vocab = torch.load(
        "/disks/sdb/zjiehang/zhou_data/ptb/vocab")  # type: Vocab
    parser = load_parser(
        fetch_best_ckpt_name(
            "/disks/sdb/zjiehang/zhou_data/saved_models/word_tag/lzynb"))
    # print(type(vocab))
    esglv = EmbeddingSearcher(embed=vocab.embeddings,
                              idx2word=lambda x: vocab.words[x],
                              word2idx=lambda x: vocab.word_dict[x])
    with time_record():
        esglv.use_faiss_backend(False, True, 10, 1)
        for _ in range(10):
            esglv.find_neighbours(0, 100)
    # esglv.show_embedding_info()
    # esmdl = EmbeddingSearcher(embed=parser.embed.weight,
    #                           idx2word=lambda x: vocab.words[x],
    #                           word2idx=lambda x: vocab.word_dict[x])
    # esmdl.show_embedding_info()
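
# ---------------------------------------------------------------------------
# Hedged usage sketch (not from the original source): compare_idxes applied to
# two neighbour queries, measuring how much the k-nearest neighbourhoods of
# two words overlap. The return value of find_neighbours is inferred from the
# call in __main__ above, not confirmed by the source; if it returns
# (distances, indices), pass only the indices to compare_idxes.
def _demo_neighbour_overlap(es, vocab, word_a, word_b, k=100):
    nbr_a = es.find_neighbours(vocab.word_dict[word_a], k)
    nbr_b = es.find_neighbours(vocab.word_dict[word_b], k)
    # compare_idxes casts via cast_list, so tensors or lists both work
    return compare_idxes(nbr_a, nbr_b)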
def nn_tagger(self) -> PosTagger:
    if self.__nn_tagger is None:
        self.__nn_tagger = PosTagger.load(
            fetch_best_ckpt_name(self.config.tagger_model))
    return self.__nn_tagger
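
# ---------------------------------------------------------------------------
# Hedged note (not from the original source): nn_tagger loads the tagger
# checkpoint lazily on first access and caches it in self.__nn_tagger, so
# repeated accesses pay the loading cost only once. Minimal sketch of the same
# memoization pattern, with illustrative names:
class _LazyLoadSketch:
    def __init__(self):
        self._cached = None

    def resource(self):
        if self._cached is None:
            self._cached = self._expensive_load()  # loaded once, then reused
        return self._cached

    def _expensive_load(self):
        return object()  # stands in for PosTagger.load(...)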