Example #1
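    # Note: this is an excerpt from a class; it assumes module-level imports
    # elsewhere in the file, e.g. os, torch, shutil.copyfile,
    # datetime/timedelta, torch.optim.Adam and
    # torch.optim.lr_scheduler.ExponentialLR.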
    def __call__(self, config):
        print("Preprocess the data")
        train = Corpus.load(config.ftrain)
        dev = Corpus.load(config.fdev)
        test = Corpus.load(config.ftest)
        if os.path.exists(config.vocab):
            vocab = torch.load(config.vocab)
        else:
            vocab = Vocab.from_corpus(corpus=train, min_freq=2)
            vocab.read_embeddings(Pretrained.load(config.fembed, config.unk))
            torch.save(vocab, config.vocab)
        config.update({
            'n_words': vocab.n_train_words,
            'n_tags': vocab.n_tags,
            'n_rels': vocab.n_rels,
            'n_chars': vocab.n_chars,
            'pad_index': vocab.pad_index,
            'unk_index': vocab.unk_index
        })
        print(vocab)

        print("Load the dataset")
        trainset = TextDataset(vocab.numericalize(train))
        devset = TextDataset(vocab.numericalize(dev))
        testset = TextDataset(vocab.numericalize(test))
        # set the data loaders: bucket sentences by length, then batch;
        # only the training loader is shuffled
        train_loader = batchify(dataset=trainset,
                                batch_size=config.batch_size,
                                n_buckets=config.buckets,
                                shuffle=True)
        dev_loader = batchify(dataset=devset,
                              batch_size=config.batch_size,
                              n_buckets=config.buckets)
        test_loader = batchify(dataset=testset,
                               batch_size=config.batch_size,
                               n_buckets=config.buckets)
        print(f"{'train:':6} {len(trainset):5} sentences in total, "
              f"{len(train_loader):3} batches provided")
        print(f"{'dev:':6} {len(devset):5} sentences in total, "
              f"{len(dev_loader):3} batches provided")
        print(f"{'test:':6} {len(testset):5} sentences in total, "
              f"{len(test_loader):3} batches provided")

        print("Create the models")
        assert config.train_task in ['parser', 'tagger']
        is_training_parser = config.train_task == 'parser'

        if config.augmentation_training:
            aug_test = Corpus.load(config.augmentation_test_file)
            aug_testset = TextDataset(vocab.numericalize(aug_test))
            aug_test_loader = batchify(dataset=aug_testset,
                                       batch_size=config.batch_size,
                                       n_buckets=config.buckets)
            print(f"{'test:':6} {len(aug_testset):5} sentences in total, "
                  f"{len(aug_test_loader):3} batches provided")

        if is_training_parser:
            model = init_parser(config, vocab.embeddings)
            task = ParserTask(vocab, model)
            best_e, best_metric = 1, ParserMetric()
        else:
            model = PosTagger(config, vocab.embeddings)
            task = TaggerTask(vocab, model)
            best_e, best_metric = 1, TaggerMetric()

        if torch.cuda.is_available():
            model = model.cuda()
        print(f"{model}\n")

        total_time = timedelta()
        task.optimizer = Adam(task.model.parameters(), config.lr,
                              (config.beta_1, config.beta_2), config.epsilon)
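        # gamma = decay**(1/steps): the learning rate shrinks by a factor of
        # `decay` every `steps` scheduler steps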
        task.scheduler = ExponentialLR(task.optimizer,
                                       config.decay**(1 / config.steps))

        for epoch in range(1, config.epochs + 1):
            start = datetime.now()
            # train one epoch and update the parameters
            task.train(train_loader)

            print(f"Epoch {epoch} / {config.epochs}:")
            loss, train_metric = task.evaluate(train_loader, config.punct)
            print(f"{'train:':6} Loss: {loss:.4f} {train_metric}")
            loss, dev_metric = task.evaluate(dev_loader, config.punct)
            print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
            loss, test_metric = task.evaluate(test_loader, config.punct)
            print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")
            if config.augmentation_training:
                loss, aug_test_metric = task.evaluate(aug_test_loader,
                                                      config.punct)
                print(f"{'test:':6} Loss: {loss:.4f} {aug_test_metric}")

            t = datetime.now() - start

            # save only when dev improves, and only after `patience` warm-up
            # epochs (per the original guard)
            if dev_metric > best_metric and epoch > config.patience:
                best_e, best_metric = epoch, dev_metric
                if is_training_parser:
                    task.model.save(config.parser_model + f".{best_e}")
                else:
                    task.model.save(config.tagger_model + f".{best_e}")
                print(f"{t}s elapsed (saved)\n")
            else:
                print(f"{t}s elapsed\n")
            total_time += t
            # early stopping: no dev improvement for `patience` epochs
            if epoch - best_e >= config.patience:
                break

        if is_training_parser:
            copyfile(config.parser_model + f'.{best_e}',
                     config.parser_model + '.best')
            task.model = load_parser(config.parser_model + f".{best_e}")
        else:
            copyfile(config.tagger_model + f'.{best_e}',
                     config.tagger_model + '.best')
            task.model = PosTagger.load(config.tagger_model + f".{best_e}")
        loss, metric = task.evaluate(test_loader, config.punct)

        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        print(f"the score of test at epoch {best_e} is {metric.score:.2%}")

        if config.augmentation_training:
            loss, metric = task.evaluate(aug_test_loader, config.punct)
            print(
                f"the score of aug test at epoch {best_e} is {metric.score:.2%}"
            )

        print(f"average time of each epoch is {total_time / epoch}s")
        print(f"{total_time}s elapsed")
Example #2
from collections import Counter


def subtree_distribution(corpus):
    # reconstructed header: the original snippet is truncated above this line
    # tally, per sentence, how many spans have a length in [4, 12]
    dist = Counter()
    for sent in corpus:
        spans = gen_spans(sent)
        dist[len(filter_spans(spans, 4, 12))] += 1
    total = sum(dist.values())
    for k in range(10):
        print(k, '->', dist[k] / total * 100)


if __name__ == "__main__":
    corpus = Corpus.load(
        "/disks/sdb/zjiehang/zhou_data_new/ptb/ptb_test_3.3.0.sd")
    subtree_distribution(corpus)
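filter_spans is not defined in this excerpt; given the call
filter_spans(spans, 4, 12), a plausible reading (an assumption, not the
original code) is a span-length filter:

def filter_spans(spans, min_len, max_len):
    # keep (begin, end) spans whose length lies within [min_len, max_len]
    return [(b, e) for b, e in spans if min_len <= e - b <= max_len]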
Example #3
def train_crf_tagger(corpus, model_file="crftagger.model"):
    # reconstructed header (the snippet is truncated above this line);
    # assumes `from nltk.tag import CRFTagger`, matching the commented
    # usage below; the default model path is a placeholder
    train_sents = gen_tagged_sents(corpus)
    crf_tagger = CRFTagger()
    crf_tagger.train(train_sents, model_file)
    return crf_tagger


def gen_tagged_sents(corpus: Corpus):
    """Convert a Corpus into nltk-style sentences of (word, tag) pairs."""
    all_sents = []
    for sentence in corpus:
        sent = []
        for i in range(len(sentence.ID)):
            sent.append((sentence.FORM[i], sentence.POS[i]))
        all_sents.append(sent)
    return all_sents


if __name__ == "__main__":
    train_corpus = Corpus.load(
        "/disks/sdb/zjiehang/zhou_data/ptb/ptb_train_3.3.0.sd")
    test_corpus = Corpus.load(
        "/disks/sdb/zjiehang/zhou_data/ptb/ptb_test_3.3.0.sd")

    # vocab = torch.load("/disks/sdb/zjiehang/zhou_data/ptb/vocab")
    # gen_tag_dict(train_corpus, vocab)

    """
        Below is performance benchmark of different taggers.
    """
    # tagger = train_gram_tagger(train_corpus, 1)   # 0.9279514501446616
    # tagger = train_gram_tagger(train_corpus, 2)   # 0.947216145649566
    # tagger = train_gram_tagger(train_corpus, 3)   # 0.9476395455507727
    # tagger = train_crf_tagger(train_corpus)       # 0.9715263566438501  -> long training time
    # tagger = CRFTagger()
    # tagger.set_model_file("/disks/sdb/zjiehang/zhou_data/saved_models/crftagger")
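train_gram_tagger is not shown in this excerpt; a plausible sketch built on
nltk's backoff n-gram taggers (the function body is an assumption, not the
original code):

from nltk.tag import UnigramTagger, BigramTagger, TrigramTagger

def train_gram_tagger(corpus, n):
    # back off from higher-order to lower-order n-gram taggers
    train_sents = gen_tagged_sents(corpus)
    tagger = UnigramTagger(train_sents)
    if n >= 2:
        tagger = BigramTagger(train_sents, backoff=tagger)
    if n >= 3:
        tagger = TrigramTagger(train_sents, backoff=tagger)
    return tagger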
Example #4
    def __call__(self, config):
        loader = self.pre_attack(config)
        self.parser.register_backward_hook(self.extract_embed_grad)

        train_corpus = Corpus.load(config.ftrain)
        self.tag_filter = generate_tag_filter(train_corpus, self.vocab)

        raw_metrics = Metric()
        attack_metrics = Metric()

        log('dist measure', config.hk_dist_measure)

        # the loader yields one sentence per batch (batch size == 1)
        for sid, (words, tags, arcs, rels) in enumerate(loader):
            raw_words = words.clone()
            words_text = self.get_seqs_name(words)
            tags_text = self.get_tags_name(tags)

            log('****** {}: \n\t{}\n\t{}'.format(sid, " ".join(words_text),
                                                 " ".join(tags_text)))

            self.vals['forbidden'] = [
                self.vocab.unk_index, self.vocab.pad_index
            ]
            # up to 100 PGD steps; result codes: 200 = attack succeeded,
            # 404 = attack failed, 300 = keep perturbing the words
            for pgdid in range(100):
                result = self.single_hack(words,
                                          tags,
                                          arcs,
                                          rels,
                                          dist_measure=config.hk_dist_measure,
                                          raw_words=raw_words)
                if result['code'] == 200:
                    raw_metrics += result['raw_metric']
                    attack_metrics += result['attack_metric']
                    log('attack successfully at step {}'.format(pgdid))
                    break
                elif result['code'] == 404:
                    raw_metrics += result['raw_metric']
                    attack_metrics += result['raw_metric']
                    log('attack failed at step {}'.format(pgdid))
                    break
                elif result['code'] == 300:
                    if pgdid == 99:  # out of steps: count as a failed attack
                        raw_metrics += result['raw_metric']
                        attack_metrics += result['raw_metric']
                        log('attack failed at step {}'.format(pgdid))
                    else:
                        words = result['words']
            log()

            log('Aggregated result: {} --> {}'.format(raw_metrics,
                                                      attack_metrics),
                target='cf')
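extract_embed_grad is not shown above; a minimal sketch of such a method,
assuming it caches the embedding gradient that the PGD loop perturbs
(PyTorch backward hooks are called as hook(module, grad_input, grad_output)):

    def extract_embed_grad(self, module, grad_input, grad_output):
        self.vals['embed_grad'] = grad_output[0]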