Example 1
    def evaluate(self, loader, punct=False, tagger=None, mst=False):
        self.model.eval()

        loss, metric = 0, ParserMetric()

        for words, tags, chars, arcs, rels in loader:
            mask = words.ne(self.vocab.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0

            tags = self.get_tags(words, tags, mask, tagger)

            s_arc, s_rel = self.model(words,
                                      is_chars_judger(self.model, tags, chars))

            loss += self.get_loss(s_arc[mask], s_rel[mask], arcs[mask],
                                  rels[mask])
            pred_arcs, pred_rels = self.decode(s_arc, s_rel, mask, mst)

            # exclude punctuation from the evaluation unless `punct` is True
            if not punct:
                puncts = words.new_tensor(self.vocab.puncts)
                mask &= words.unsqueeze(-1).ne(puncts).all(-1)
            pred_arcs, pred_rels = pred_arcs[mask], pred_rels[mask]
            gold_arcs, gold_rels = arcs[mask], rels[mask]

            metric(pred_arcs, pred_rels, gold_arcs, gold_rels)
        loss /= len(loader)

        return loss, metric
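
A minimal usage sketch for this method, assuming `vocab`, `model` and a batchified `dev_loader` are built as in Example 4 below:

import torch

# Sketch only: ParserTask, vocab, model and dev_loader come from the
# surrounding project and are assumed to be set up as in Example 4.
task = ParserTask(vocab, model)
with torch.no_grad():  # evaluation needs no gradients
    loss, dev_metric = task.evaluate(dev_loader, punct=False)
print(f"dev loss: {loss:.4f} {dev_metric}")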
Example 2
    def partial_evaluate(self,
                         instance: tuple,
                         mask_idxs: List[int],
                         punct=False,
                         tagger=None,
                         mst=False,
                         return_metric=True):
        self.model.eval()

        loss, metric = 0, ParserMetric()

        words, tags, chars, arcs, rels = instance

        mask = words.ne(self.vocab.pad_index)
        # ignore the first token of each sentence
        mask[:, 0] = 0
        decode_mask = mask.clone()

        tags = self.get_tags(words, tags, mask, tagger)
        # exclude punctuation from the evaluation unless `punct` is True
        if not punct:
            puncts = words.new_tensor(self.vocab.puncts)
            mask &= words.unsqueeze(-1).ne(puncts).all(-1)
        s_arc, s_rel = self.model(words,
                                  is_chars_judger(self.model, tags, chars))

        # exclude the given token indices from the evaluation mask
        for idx in mask_idxs:
            mask[:, idx] = 0

        pred_arcs, pred_rels = self.decode(s_arc, s_rel, decode_mask, mst)

        # note: punctuation has already been excluded from `mask` above
        pred_arcs, pred_rels = pred_arcs[mask], pred_rels[mask]
        gold_arcs, gold_rels = arcs[mask], rels[mask]

        # exmask = torch.ones_like(gold_arcs, dtype=torch.uint8)

        # for i, ele in enumerate(cast_list(gold_arcs)):
        #     if ele in mask_idxs:
        #         exmask[i] = 0
        # for i, ele in enumerate(cast_list(pred_arcs)):
        #     if ele in mask_idxs:
        #         exmask[i] = 0
        # gold_arcs = gold_arcs[exmask]
        # pred_arcs = pred_arcs[exmask]
        # gold_rels = gold_rels[exmask]
        # pred_rels = pred_rels[exmask]

        # loss += self.get_loss(s_arc, s_rel, gold_arcs, gold_rels)
        metric(pred_arcs, pred_rels, gold_arcs, gold_rels)

        if return_metric:
            return metric
        else:
            return pred_arcs.view(words.size(0), -1), pred_rels.view(words.size(0), -1), \
                   gold_arcs.view(words.size(0), -1), gold_rels.view(words.size(0), -1)
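
Both evaluate and partial_evaluate feed predictions into a ParserMetric, which is also accumulated with `+=` and compared for model selection in the later examples. A minimal sketch of an object with that interface, assuming the selection score is the labeled attachment score (the project's actual ParserMetric may differ):

# Assumed sketch of the metric interface used above; not the project's
# actual ParserMetric implementation.
class SimpleParserMetric:
    def __init__(self):
        self.total = self.correct_arcs = self.correct_rels = 0

    def __call__(self, pred_arcs, pred_rels, gold_arcs, gold_rels):
        arc_hits = pred_arcs.eq(gold_arcs)
        rel_hits = pred_rels.eq(gold_rels) & arc_hits
        self.total += len(gold_arcs)
        self.correct_arcs += arc_hits.sum().item()
        self.correct_rels += rel_hits.sum().item()

    @property
    def uas(self):
        # unlabeled attachment score: correct heads / scored tokens
        return self.correct_arcs / max(self.total, 1)

    @property
    def score(self):
        # labeled attachment score, used here for model selection
        return self.correct_rels / max(self.total, 1)

    def __gt__(self, other):
        return self.score > other.score

    def __iadd__(self, other):
        # accumulate counts across sentences / batches
        self.total += other.total
        self.correct_arcs += other.correct_arcs
        self.correct_rels += other.correct_rels
        return self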
Example 3
    def __call__(self, config):
        self.init_logger(config)
        self.setup(config)

        if self.config.hk_use_worker == 'on':
            start_sid, end_sid = locate_chunk(len(self.loader),
                                              self.config.hk_num_worker,
                                              self.config.hk_worker_id)
            log('Run code on a chunk [{}, {})'.format(start_sid, end_sid))

        raw_metrics = ParserMetric()
        attack_metrics = ParserMetric()

        agg = Aggregator()
        for sid, (words, tags, chars, arcs, rels) in enumerate(self.loader):
            # if sid in [0, 1, 2, 3, 4]:
            #     continue
            # if sid < 1434:
            #     continue
            if self.config.hk_use_worker == 'on':
                if sid < start_sid or sid >= end_sid:
                    continue
            if self.config.hk_training_set == 'on' and words.size(1) > 50:
                log('Skip sentence {} whose length is {} (> 50).'.format(
                    sid, words.size(1)))
                continue

            if words.size(1) < 5:
                log('Skip sentence {} whose length is {} (< 5).'.format(
                    sid, words.size(1)))
                continue

            words_text = self.vocab.id2word(words[0])
            tags_text = self.vocab.id2tag(tags[0])
            log('****** {}: \n{}\n{}'.format(sid, " ".join(words_text),
                                             " ".join(tags_text)))

            # hack it!
            result = self.hack(instance=(words, tags, chars, arcs, rels))

            # aggregate information
            raw_metrics += result['raw_metric']
            attack_metrics += result['attack_metric']
            agg.aggregate(
                ("iters", result['iters']), ("time", result['time']),
                ("fail",
                 abs(result['attack_metric'].uas - result['raw_metric'].uas) <
                 1e-4), ('best_iter', result['best_iter']),
                ("changed", result['num_changed']))

            # log some information
            log('Show result from iter {}, changed num {}:'.format(
                result['best_iter'], result['num_changed']))
            log(result['logtable'])

            log('Aggregated result: {} --> {}, '
                'iters(avg) {:.1f}, time(avg) {:.1f}s, '
                'fail rate {:.2f}, best_iter(avg) {:.1f}, best_iter(std) {:.1f}, '
                'changed(avg) {:.1f}'.format(raw_metrics, attack_metrics,
                                             agg.mean('iters'),
                                             agg.mean('time'),
                                             agg.mean('fail'),
                                             agg.mean('best_iter'),
                                             agg.std('best_iter'),
                                             agg.mean('changed')))
            log()
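
The worker mode above splits the sentence range with locate_chunk. A minimal sketch of such a helper, assuming an even ceiling-sized split (the project's actual implementation may differ):

# Assumed sketch of a chunking helper with the same call signature as
# locate_chunk(n, num_workers, worker_id); not the project's implementation.
def locate_chunk(n: int, num_workers: int, worker_id: int):
    size = (n + num_workers - 1) // num_workers   # ceil(n / num_workers)
    start = worker_id * size
    end = min(n, start + size)
    return start, end

# e.g. 10 sentences over 3 workers -> [0, 4), [4, 8), [8, 10)
assert [locate_chunk(10, 3, i) for i in range(3)] == [(0, 4), (4, 8), (8, 10)]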
Example 4
    def __call__(self, config):
        print("Preprocess the data")
        train = Corpus.load(config.ftrain)
        dev = Corpus.load(config.fdev)
        test = Corpus.load(config.ftest)
        if os.path.exists(config.vocab):
            vocab = torch.load(config.vocab)
        else:
            vocab = Vocab.from_corpus(corpus=train, min_freq=2)
            vocab.read_embeddings(Pretrained.load(config.fembed, config.unk))
            torch.save(vocab, config.vocab)
        config.update({
            'n_words': vocab.n_train_words,
            'n_tags': vocab.n_tags,
            'n_rels': vocab.n_rels,
            'n_chars': vocab.n_chars,
            'pad_index': vocab.pad_index,
            'unk_index': vocab.unk_index
        })
        print(vocab)

        print("Load the dataset")
        trainset = TextDataset(vocab.numericalize(train))
        devset = TextDataset(vocab.numericalize(dev))
        testset = TextDataset(vocab.numericalize(test))
        # set the data loaders
        train_loader = batchify(dataset=trainset,
                                batch_size=config.batch_size,
                                n_buckets=config.buckets,
                                shuffle=True)
        dev_loader = batchify(dataset=devset,
                              batch_size=config.batch_size,
                              n_buckets=config.buckets)
        test_loader = batchify(dataset=testset,
                               batch_size=config.batch_size,
                               n_buckets=config.buckets)
        print(f"{'train:':6} {len(trainset):5} sentences in total, "
              f"{len(train_loader):3} batches provided")
        print(f"{'dev:':6} {len(devset):5} sentences in total, "
              f"{len(dev_loader):3} batches provided")
        print(f"{'test:':6} {len(testset):5} sentences in total, "
              f"{len(test_loader):3} batches provided")

        print("Create the models")
        assert config.train_task in ['parser', 'tagger']
        is_training_parser = config.train_task == 'parser'

        if config.augmentation_training:
            aug_test = Corpus.load(config.augmentation_test_file)
            aug_testset = TextDataset(vocab.numericalize(aug_test))
            aug_test_loader = batchify(dataset=aug_testset,
                                       batch_size=config.batch_size,
                                       n_buckets=config.buckets)
            print(f"{'test:':6} {len(aug_testset):5} sentences in total, "
                  f"{len(aug_test_loader):3} batches provided")

        if is_training_parser:
            model = init_parser(config, vocab.embeddings)
            task = ParserTask(vocab, model)
            best_e, best_metric = 1, ParserMetric()
        else:
            model = PosTagger(config, vocab.embeddings)
            task = TaggerTask(vocab, model)
            best_e, best_metric = 1, TaggerMetric()

        if torch.cuda.is_available():
            model = model.cuda()
        print(f"{model}\n")
        total_time = timedelta()
        # best_e, best_metric = 1, TaggerMetric()
        task.optimizer = Adam(task.model.parameters(), config.lr,
                              (config.beta_1, config.beta_2), config.epsilon)
        task.scheduler = ExponentialLR(task.optimizer,
                                       config.decay**(1 / config.steps))
        for epoch in range(1, config.epochs + 1):
            start = datetime.now()
            # train one epoch and update the parameters
            task.train(train_loader)

            print(f"Epoch {epoch} / {config.epochs}:")
            loss, train_metric = task.evaluate(train_loader, config.punct)
            print(f"{'train:':6} Loss: {loss:.4f} {train_metric}")
            loss, dev_metric = task.evaluate(dev_loader, config.punct)
            print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
            loss, test_metric = task.evaluate(test_loader, config.punct)
            print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")
            if config.augmentation_training:
                loss, aug_test_metric = task.evaluate(aug_test_loader,
                                                      config.punct)
                print(f"{'test:':6} Loss: {loss:.4f} {aug_test_metric}")

            t = datetime.now() - start

            if dev_metric > best_metric and epoch > config.patience:
                best_e, best_metric = epoch, dev_metric
                if is_training_parser:
                    task.model.save(config.parser_model + f".{best_e}")
                else:
                    task.model.save(config.tagger_model + f".{best_e}")
                print(f"{t}s elapsed (saved)\n")
            else:
                print(f"{t}s elapsed\n")
            sys.stdout.flush()
            total_time += t
            if epoch - best_e >= config.patience:
                break

        if is_training_parser:
            copyfile(config.parser_model + f'.{best_e}',
                     config.parser_model + '.best')
            task.model = load_parser(config.parser_model + f".{best_e}")
        else:
            copyfile(config.tagger_model + f'.{best_e}',
                     config.tagger_model + '.best')
            task.model = PosTagger.load(config.tagger_model + f".{best_e}")
        loss, metric = task.evaluate(test_loader, config.punct)

        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        print(f"the score of test at epoch {best_e} is {metric.score:.2%}")

        if config.augmentation_training:
            loss, metric = task.evaluate(aug_test_loader, config.punct)
            print(
                f"the score of aug test at epoch {best_e} is {metric.score:.2%}"
            )

        print(f"average time of each epoch is {total_time / epoch}s")
        print(f"{total_time}s elapsed")
Example 5
    def __call__(self, config):
        if config.logf == 'on':
            log_config('hackoutside',
                       log_path=config.workspace,
                       default_target='cf')
            from dpattack.libs.luna import log
        else:
            log = print

        log('[General Settings]')
        log(config)
        log('[Hack Settings]')
        for arg in config.kwargs:
            if arg.startswith('hks'):
                log(arg, '\t', config.kwargs[arg])
        log('------------------')

        self.setup(config)

        raw_metrics = ParserMetric()
        attack_metrics = ParserMetric()

        agg = Aggregator()
        for sid, (words, tags, chars, arcs, rels) in enumerate(self.loader):
            # if sid > 100:
            #     continue

            words_text = self.vocab.id2word(words[0])
            tags_text = self.vocab.id2tag(tags[0])
            log('****** {}: \n{}\n{}'.format(sid, " ".join(words_text),
                                             " ".join(tags_text)))

            result = self.hack(instance=(words, tags, chars, arcs, rels),
                               sentence=self.corpus[sid])

            if result is None:
                continue
            else:
                raw_metrics += result['raw_metric']
                attack_metrics += result['attack_metric']

            agg.aggregate(
                ("iters", result['iters']), ("time", result['time']),
                ("fail",
                 abs(result['attack_metric'].uas - result['raw_metric'].uas) <
                 1e-4), ('best_iter', result['best_iter']),
                ("changed", result['num_changed']))

            # WARNING: some sentences may not be shown!
            if result:
                log('Show result from iter {}:'.format(result['best_iter']))
                log(result['logtable'])

            log('Aggregated result: {} --> {}, '
                'iters(avg) {:.1f}, time(avg) {:.1f}s, '
                'fail rate {:.2f}, best_iter(avg) {:.1f}, best_iter(std) {:.1f}, '
                'changed(avg) {:.1f}'.format(raw_metrics, attack_metrics,
                                             agg.mean('iters'),
                                             agg.mean('time'),
                                             agg.mean('fail'),
                                             agg.mean('best_iter'),
                                             agg.std('best_iter'),
                                             agg.mean('changed')))
            log()
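
The running averages and the fail rate in the final log line come from Aggregator. A minimal sketch with the same aggregate/mean/std interface, assuming it simply keeps per-key lists of values (the project's Aggregator may do more):

# Assumed re-implementation of the aggregate/mean/std interface used above;
# the real Aggregator is provided by the surrounding project.
from collections import defaultdict
import statistics

class SimpleAggregator:
    def __init__(self):
        self._values = defaultdict(list)

    def aggregate(self, *pairs):
        # each pair is a (key, value) tuple, e.g. ("iters", 12) or ("fail", True);
        # booleans are stored as 0/1 so mean("fail") yields the fail rate
        for key, value in pairs:
            self._values[key].append(float(value))

    def mean(self, key):
        return statistics.fmean(self._values[key])

    def std(self, key):
        return statistics.pstdev(self._values[key])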