Example #1
    def setup(self, config):
        self.config = config

        print("Load the models")
        vocab = torch.load(config.vocab)  # type: Vocab
        parser = load_parser(fetch_best_ckpt_name(config.parser_model))

        # Keep references on self: later code (the hook registration and the
        # EmbeddingSearcher below) reads self.parser and self.vocab.
        self.vocab = vocab
        self.parser = parser
        self.task = ParserTask(vocab, parser)

        print("Load the dataset")

        train_corpus = Corpus.load(config.ftrain)

        if config.hk_training_set == 'on':
            self.corpus = train_corpus
        else:
            self.corpus = Corpus.load(config.fdata)
        dataset = TextDataset(vocab.numericalize(self.corpus, True))
        # set the data loader
        self.loader = DataLoader(dataset=dataset, collate_fn=collate_fn)

        def embed_backward_hook(module, grad_in, grad_out):
            # Stash the gradient w.r.t. the embedding output so it can be
            # read back later (e.g. via ram_read) by the attack code.
            ram_write('embed_grad', grad_out[0])

        self.parser.char_lstm.embed.register_backward_hook(embed_backward_hook)
        # self.parser.embed.register_backward_hook(embed_backward_hook)
        self.parser.eval()

        self.embed_searcher = EmbeddingSearcher(
            embed=self.parser.char_lstm.embed.weight,
            idx2word=lambda x: self.vocab.chars[x],
            word2idx=lambda x: self.vocab.char_dict[x])

        # Fix the seeds so runs are reproducible.
        random.seed(1)
        torch.manual_seed(1)
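
The hook above stashes the embedding gradient through dpattack's global `ram_write` store. A minimal, self-contained sketch of the same pattern, with a plain dict standing in for the store and a toy embedding (illustrative only, not from the source):

import torch
import torch.nn as nn

_ram = {}  # stand-in for the luna ram_write global store

embed = nn.Embedding(10, 4)

def embed_backward_hook(module, grad_in, grad_out):
    # grad_out[0] is the gradient w.r.t. the embedding layer's output.
    _ram['embed_grad'] = grad_out[0]

embed.register_backward_hook(embed_backward_hook)

out = embed(torch.tensor([1, 2, 3]))
out.sum().backward()
print(_ram['embed_grad'].shape)  # torch.Size([3, 4])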
Example #2
    def __call__(self, config):
        print("Load the models")
        vocab = torch.load(config.vocab)
        parser = load_parser(fetch_best_ckpt_name(config.parser_model))
        task = ParserTask(vocab, parser)
        if config.pred_tag:
            tagger = PosTagger.load(fetch_best_ckpt_name(config.tagger_model))
        else:
            tagger = None

        print("Load the dataset")
        corpus = Corpus.load(config.fdata)
        dataset = TextDataset(vocab.numericalize(corpus))
        # set the data loader
        loader = batchify(dataset, config.batch_size, config.buckets)

        print("Evaluate the dataset")
        loss, metric = task.evaluate(loader, config.punct, tagger, True)
        print(f"Loss: {loss:.4f} {metric}")
Example #3
    def __call__(self, config):
        print("Load the models")
        vocab = torch.load(config.vocab)
        parser = load_parser(fetch_best_ckpt_name(config.parser_model))
        task = ParserTask(vocab, parser)
        if config.pred_tag:
            tagger = PosTagger.load(fetch_best_ckpt_name(config.tagger_model))
        else:
            tagger = None

        print("Load the dataset")
        corpus = Corpus.load(config.fdata)
        dataset = TextDataset(vocab.numericalize(corpus, training=False))
        # set the data loader
        loader = batchify(dataset, config.batch_size)

        print("Make predictions on the dataset")
        corpus.tags, corpus.heads, corpus.rels = task.predict(loader, tagger)

        saved_path = f"{config.result_path}/raw_result.conllx"
        print(f"Save the predicted result to {saved_path}")
        corpus.save(saved_path)
Example #4
def compare_idxes(nbr1, nbr2):
    """Return the number of indices shared by two neighbour lists."""
    nbr1 = set(cast_list(nbr1))
    nbr2 = set(cast_list(nbr2))
    return len(nbr1.intersection(nbr2))
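
A quick sanity check for `compare_idxes` (toy inputs; this assumes luna's `cast_list` accepts plain Python lists as well as tensors):

print(compare_idxes([1, 2, 3, 4], [3, 4, 5]))  # -> 2 (shared indices 3 and 4)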


if __name__ == '__main__':
    from dpattack.libs.luna import fetch_best_ckpt_name, cast_list, show_mean_std, time_record
    from dpattack.utils.parser_helper import load_parser
    from dpattack.utils.vocab import Vocab

    vocab = torch.load(
        "/disks/sdb/zjiehang/zhou_data/ptb/vocab")  # type: Vocab
    parser = load_parser(
        fetch_best_ckpt_name(
            "/disks/sdb/zjiehang/zhou_data/saved_models/word_tag/lzynb"))
    # print(type(vocab))
    esglv = EmbeddingSearcher(embed=vocab.embeddings,
                              idx2word=lambda x: vocab.words[x],
                              word2idx=lambda x: vocab.word_dict[x])

    with time_record():
        # Switch the searcher to the faiss backend (positional options as
        # defined by use_faiss_backend), then time ten rounds of fetching
        # the 100 nearest neighbours of word index 0.
        esglv.use_faiss_backend(False, True, 10, 1)
        for _ in range(10):
            esglv.find_neighbours(0, 100)

    # esglv.show_embedding_info()
    # esmdl = EmbeddingSearcher(embed=parser.embed.weight,
    #                           idx2word=lambda x: vocab.words[x],
    #                           word2idx=lambda x: vocab.word_dict[x])
    # esmdl.show_embedding_info()
Example #5
    @property
    def nn_tagger(self) -> PosTagger:
        # Lazily load the tagger checkpoint on first access, then cache it.
        if self.__nn_tagger is None:
            self.__nn_tagger = PosTagger.load(
                fetch_best_ckpt_name(self.config.tagger_model))
        return self.__nn_tagger
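
The getter above implements a hand-rolled lazy cache behind a private field. On Python 3.8+, `functools.cached_property` gives the same one-time loading without the backing attribute; a possible alternative sketch (the host class and its `config` are assumptions, not shown in the source):

from functools import cached_property

class Attacker:  # hypothetical host class
    def __init__(self, config):
        self.config = config

    @cached_property
    def nn_tagger(self) -> PosTagger:
        # Loaded once on first access, then cached on the instance.
        return PosTagger.load(fetch_best_ckpt_name(self.config.tagger_model))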