Example #1
0
    def attack_for_each_sentence(self, config, seq, seq_idx, tag, tag_idx, chars, arcs, rels, mask):
        '''
        Attack one sentence and measure the parser before and after the attack.

        :param config: run configuration; only ``config.mst`` (decoding mode) is read
        :param seq: list of word forms, position 0 is the ROOT token
        :param seq_idx: word-index tensor for the sentence
        :param tag: list of POS tag names
        :param tag_idx: POS-tag index tensor
        :param chars: character-index tensor (used by char-based parsers)
        :param arcs: gold head indices
        :param rels: gold relation indices
        :param mask: token mask for the sentence
        :return: (raw_metric, attack_metric, attack_seq, attack_arc, attack_rel, revised_number)
        '''
        # No gradients are needed: this variant only queries the parser.
        with torch.no_grad():
            # Baseline metric on the unmodified sentence; also passed to the
            # attack generator so it can judge candidate substitutions.
            raw_loss, raw_metric = self.task.evaluate([(seq_idx, tag_idx, chars, arcs, rels)],mst=config.mst)

            # Build the adversarial sentence; ROOT (seq[0]) is excluded from the text.
            joined_seq = ' '.join(seq[1:])
            (attack_seq, attack_tag_idx, attack_mask, attack_gold_arc,
             attack_gold_rel, revised_number) = self.attack_seq_generator.generate_attack_seq(
                joined_seq, seq_idx, tag, tag_idx, chars, arcs, rels, mask, raw_metric)

            # Index the adversarial sentence and move it to GPU when available.
            attack_seq_idx = self.vocab.word2id(attack_seq).unsqueeze(0)
            if torch.cuda.is_available():
                attack_seq_idx = attack_seq_idx.cuda()

            if is_chars_judger(self.parser):
                # Char-based parser: rebuild character indices for the new words.
                attack_chars = self.get_chars_idx_by_seq(attack_seq)
                attack_loss, attack_metric = self.task.evaluate(
                    [(attack_seq_idx, None, attack_chars, arcs, rels)], mst=config.mst)
                _, attack_arc, attack_rel = self.task.predict(
                    [(attack_seq_idx, attack_tag_idx, attack_chars)], mst=config.mst)
            else:
                # Tag-based parser: evaluate against the (possibly revised) gold trees.
                attack_loss, attack_metric = self.task.evaluate(
                    [(attack_seq_idx, attack_tag_idx, None, attack_gold_arc, attack_gold_rel)], mst=config.mst)
                _, attack_arc, attack_rel = self.task.predict(
                    [(attack_seq_idx, attack_tag_idx, None)], mst=config.mst)

            # predict returns batched lists; this method handles one sentence.
            return raw_metric, attack_metric, attack_seq, attack_arc[0], attack_rel[0], revised_number
Example #2
0
    def evaluate(self, loader, punct=False, tagger=None, mst=False):
        """Score the parser over *loader*; return (mean loss, ParserMetric).

        :param loader: yields (words, tags, chars, arcs, rels) batches
        :param punct: when False, punctuation tokens are excluded from scoring
        :param tagger: optional external tagger forwarded to ``get_tags``
        :param mst: decode with MST instead of greedy arg-max
        """
        self.model.eval()

        total_loss = 0
        metric = ParserMetric()

        for words, tags, chars, arcs, rels in loader:
            mask = words.ne(self.vocab.pad_index)
            # The first token of every sentence is ROOT; never score it.
            mask[:, 0] = 0

            tags = self.get_tags(words, tags, mask, tagger)

            s_arc, s_rel = self.model(words,
                                      is_chars_judger(self.model, tags, chars))

            # Loss is computed before any punctuation filtering.
            total_loss += self.get_loss(s_arc[mask], s_rel[mask], arcs[mask],
                                        rels[mask])
            pred_arcs, pred_rels = self.decode(s_arc, s_rel, mask, mst)

            # Unless requested, drop punctuation tokens from the metric.
            if not punct:
                puncts = words.new_tensor(self.vocab.puncts)
                mask &= words.unsqueeze(-1).ne(puncts).all(-1)

            metric(pred_arcs[mask], pred_rels[mask], arcs[mask], rels[mask])

        total_loss /= len(loader)

        return total_loss, metric
Example #3
0
    def predict(self, loader, tagger=None, mst=False):
        """Predict POS tags, heads, and relations for every sentence in *loader*.

        :param loader: yields (words, tags, chars) batches
        :param tagger: optional external tagger forwarded to ``get_tags``
        :param mst: decode with MST instead of greedy arg-max
        :return: (all_tags, all_arcs, all_rels) as per-sentence Python lists
        """
        self.model.eval()

        all_tags, all_arcs, all_rels = [], [], []
        for words, tags, chars in loader:
            mask = words.ne(self.vocab.pad_index)
            # ROOT at position 0 is never predicted.
            mask[:, 0] = 0
            # Per-sentence lengths, used to re-split flat predictions below.
            lens = mask.sum(dim=1).tolist()

            tags = self.get_tags(words, tags, mask, tagger)
            s_arc, s_rel = self.model(words,
                                      is_chars_judger(self.model, tags, chars))

            pred_arcs, pred_rels = self.decode(s_arc, s_rel, mask, mst)
            tags = tags[mask]
            pred_arcs = pred_arcs[mask]
            pred_rels = pred_rels[mask]

            all_tags.extend(torch.split(tags, lens))
            all_arcs.extend(torch.split(pred_arcs, lens))
            all_rels.extend(torch.split(pred_rels, lens))

        # Map index tensors back to human-readable tags / relations.
        all_tags = [self.vocab.id2tag(sent) for sent in all_tags]
        all_arcs = [sent.tolist() for sent in all_arcs]
        all_rels = [self.vocab.id2rel(sent) for sent in all_rels]

        return all_tags, all_arcs, all_rels
Example #4
0
    def partial_evaluate(self,
                         instance: tuple,
                         mask_idxs: List[int],
                         punct=False,
                         tagger=None,
                         mst=False,
                         return_metric=True):
        """Evaluate one batched instance while excluding given token indices from scoring.

        Decoding runs with the full pre-filter mask (``decode_mask``) so the
        tree decoder still sees every real token; the metric is then computed
        on a narrower mask that drops punctuation (unless ``punct``) and every
        position listed in ``mask_idxs``.

        :param instance: (words, tags, chars, arcs, rels) tensors for one batch
        :param mask_idxs: token positions to exclude from scoring
        :param punct: include punctuation in scoring when True
        :param tagger: optional external tagger forwarded to ``get_tags``
        :param mst: decode with MST instead of greedy arg-max
        :param return_metric: when True return a ParserMetric; otherwise return
            (pred_arcs, pred_rels, gold_arcs, gold_rels) reshaped per batch row
        """
        self.model.eval()

        # NOTE(review): `loss` is initialized but never updated here (the
        # get_loss call below is commented out); the method returns no loss.
        loss, metric = 0, ParserMetric()

        words, tags, chars, arcs, rels = instance

        mask = words.ne(self.vocab.pad_index)
        # ignore the first token of each sentence
        mask[:, 0] = 0
        # Snapshot before punct/idx filtering: decoding must keep all real tokens.
        decode_mask = mask.clone()

        tags = self.get_tags(words, tags, mask, tagger)
        # ignore all punctuation if not specified
        if not punct:
            puncts = words.new_tensor(self.vocab.puncts)
            mask &= words.unsqueeze(-1).ne(puncts).all(-1)
        s_arc, s_rel = self.model(words,
                                  is_chars_judger(self.model, tags, chars))

        # mask given indices
        for idx in mask_idxs:
            mask[:, idx] = 0

        # Decode with the unfiltered mask so the predicted tree is complete.
        pred_arcs, pred_rels = self.decode(s_arc, s_rel, decode_mask, mst)

        # punct is ignored !!!
        pred_arcs, pred_rels = pred_arcs[mask], pred_rels[mask]
        gold_arcs, gold_rels = arcs[mask], rels[mask]

        # exmask = torch.ones_like(gold_arcs, dtype=torch.uint8)

        # for i, ele in enumerate(cast_list(gold_arcs)):
        #     if ele in mask_idxs:
        #         exmask[i] = 0
        # for i, ele in enumerate(cast_list(pred_arcs)):
        #     if ele in mask_idxs:
        #         exmask[i] = 0
        # gold_arcs = gold_arcs[exmask]
        # pred_arcs = pred_arcs[exmask]
        # gold_rels = gold_rels[exmask]
        # pred_rels = pred_rels[exmask]

        # loss += self.get_loss(s_arc, s_rel, gold_arcs, gold_rels)
        metric(pred_arcs, pred_rels, gold_arcs, gold_rels)

        if return_metric:
            return metric
        else:
            # NOTE(review): view(batch, -1) assumes every row keeps the same
            # number of scored tokens after masking — confirm for batch > 1.
            return pred_arcs.view(words.size(0), -1), pred_rels.view(words.size(0), -1), \
                   gold_arcs.view(words.size(0), -1), gold_rels.view(words.size(0), -1)
Example #5
0
    def train(self, loader):
        """Run one training epoch: optimizer and scheduler step on every batch.

        :param loader: yields (words, tags, chars, arcs, rels) batches
        """
        self.model.train()

        for words, tags, chars, arcs, rels in loader:
            self.optimizer.zero_grad()

            mask = words.ne(self.vocab.pad_index)
            # ROOT (first token) never contributes to the loss.
            mask[:, 0] = 0
            s_arc, s_rel = self.model(words,
                                      is_chars_judger(self.model, tags, chars))

            loss = self.get_loss(s_arc[mask], s_rel[mask],
                                 arcs[mask], rels[mask])
            loss.backward()
            # Clip gradients to keep training stable.
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            self.optimizer.step()
            self.scheduler.step()
Example #6
0
    def attack_for_each_sentence(self, config, seq, seq_idx, tag, tag_idx, chars, arcs, rels, mask, number):
        '''
        White-box, gradient-guided character-level attack on one sentence.

        For up to ``self.attack_epochs`` iterations: run the parser, take the
        gradient of the loss w.r.t. word/char embeddings, pick the ``number``
        words with the largest word-gradient norm (first iteration only), and
        replace one character of each chosen word with a gradient-derived
        neighbor. Stops early once the attacked UAS drops below the baseline.

        :param config: run configuration; only ``config.mst`` is read
        :param seq: word forms (position 0 is ROOT)
        :param seq_idx: word-index tensor
        :param tag: POS tag names
        :param tag_idx: POS-tag index tensor
        :param chars: char-index tensor — assumed shape (1, seq_len, max_word_len); TODO confirm
        :param arcs: gold head indices
        :param rels: gold relation indices
        :param mask: token mask used for the white-box loss
        :param number: how many words to modify
        :return: (raw_metric, attack_metric, attack_seq, revised_number)
        '''
        # seq length: ignore the first token (ROOT) of each sentence
        # for metric before attacking
        # loss, raw_metric = self.task.evaluate([seq_idx,tag_idx, chars,arcs, rels])
        self.parser.eval()
        # Baseline metric on the unmodified sentence.
        _, raw_metric = self.task.evaluate([(seq_idx, tag_idx, chars, arcs, rels)],mst=config.mst)
        # score_arc_before_attack, score_rel_before_attack = self.parser.forward(seq_idx, is_chars_judger(self.parser, tag_idx, chars))
        # raw_metric = self.get_metric(score_arc_before_attack[mask], score_rel_before_attack[mask], arcs[mask], rels[mask])

        # pre-process word mask: exclude punctuation tokens and ROOT (index 0)
        # from gradient-norm ranking.
        word_index_grad_neednot_consider = cast_list(rels.eq(self.punct_rel_idx).squeeze().nonzero())
        word_index_grad_neednot_consider.append(0)

        # pre-process char mask: True where a real character exists.
        char_mask = chars.gt(0)[0]
        # Sort by word length (descending), as done before packing; the inverse
        # permutation maps gradients back to original word order below.
        sorted_lens, indices = torch.sort(char_mask.sum(dim=1), descending=True)
        inverse_indices = indices.argsort()

        # Truncate the mask to the longest word in this sentence.
        char_mask_max = torch.max(char_mask.sum(dim=1))
        char_mask = char_mask[:,:char_mask_max]
        # delete the root token
        char_mask[0, :] = False
        # delete punct
        punct_idx_list = cast_list(rels.eq(self.punct_rel_idx).nonzero())
        char_mask[punct_idx_list, :] = False
        # the index in origin char
        # char_indexes = cast_list(char_mask.nonzero())

        # Work on a copy so the caller's `chars` is left untouched.
        attack_chars = chars.clone()

        # forbidden_idx[word_idx] = set of char ids not allowed as replacements.
        forbidden_idx = defaultdict(lambda: set())
        # char_idx[word_idx] = position (within the word) of the char to flip.
        char_idx = dict()
        revised_number = 0
        for i in range(self.attack_epochs):
            self.parser.zero_grad()
            s_arc, s_rel = self.parser.forward(seq_idx, is_chars_judger(self.parser, tag_idx, attack_chars))
            loss = self.get_whitebox_loss(arcs[mask], s_arc[mask])
            loss.backward()

            # Char-embedding gradients arrive in length-sorted order (presumably
            # captured by a backward hook into self.embed_grad — TODO confirm);
            # un-sort them back to original word order.
            charembed_grad = self.embed_grad['char_embed_grad'][inverse_indices]
            wordembed_grad = self.parser.word_embed_grad[0]
            word_grad_norm = wordembed_grad.norm(dim=1)
            # Push excluded positions (punct, ROOT) out of the top-k.
            word_grad_norm[word_index_grad_neednot_consider] = -10000.0

            if i == 0:
                # Choose the attacked words once, on the first iteration, by
                # largest word-gradient norm; within each word pick the char
                # with the largest char-gradient norm.
                current_norm_indexes = cast_list(word_grad_norm.topk(number)[1])
                for index in current_norm_indexes:
                    forbidden_idx[index].update(self.punct_idx.copy())
                    revised_number += 1
                    char_grad = charembed_grad[index][char_mask[index]]
                    char_grad_norm = char_grad.norm(dim=1)
                    char_index = char_grad_norm.topk(1)[1].item()
                    char_idx[index] = char_index
            # if number == 1:
            #     if len(forbidden_idx.keys()) == 1:
            #         current_norm_indexes = list(forbidden_idx.keys())
            #     else:
            #         current_norm_indexes = cast_list(word_grad_norm.topk(1)[1])
            #         for index in current_norm_indexes:
            #             revised_number += 1
            #             forbidden_idx[index].update(self.punct_idx.copy())
            #             char_grad = charembed_grad[index][char_mask[index]]
            #             char_grad_norm = char_grad.norm(dim=1)
            #             char_index = char_grad_norm.topk(1)[1].item()
            #             char_idx[index] = char_index
            # elif number > 1:
            #     current_norm_indexes = cast_list(word_grad_norm.topk(2)[1])
            #     for count, index in enumerate(current_norm_indexes):
            #         if index in forbidden_idx:
            #             continue
            #         else:
            #             if len(forbidden_idx) < number:
            #                 revised_number += 1
            #                 forbidden_idx[index].update(self.punct_idx.copy())
            #                 char_grad = charembed_grad[index][char_mask[index]]
            #                 char_grad_norm = char_grad.norm(dim=1)
            #                 char_index = char_grad_norm.topk(1)[1].item()
            #                 char_idx[index] = char_index
            #             else:
            #                 current_norm_indexes[count] = np.random.choice(list(forbidden_idx.keys()))
            #     while current_norm_indexes[0] == current_norm_indexes[1]:
            #         current_norm_indexes[1] = np.random.choice(list(forbidden_idx.keys()))
            for index in forbidden_idx.keys():
                raw_index = attack_chars[0, index, char_idx[index]].item()
                # add raw index to be forbidden
                # if raw_index_char is a alpha, including its lower and upper letter
                self.add_raw_index_to_be_forbidden(forbidden_idx, index, raw_index)
                # Replace the char with its nearest allowed embedding neighbor
                # along the gradient direction.
                replace_index = self.find_neighbors(raw_index, charembed_grad[index, char_idx[index]],list(forbidden_idx[index]))
                attack_chars[0, index, char_idx[index]] = replace_index

            self.parser.eval()

            _, attack_metric = self.task.evaluate([(seq_idx, tag_idx, attack_chars, arcs, rels)], mst=config.mst)

            # Early exit as soon as the attack degrades unlabeled attachment.
            if attack_metric.uas < raw_metric.uas:
                self.succeed_number += 1
                print("Succeed", end=" ")
                break
        # NOTE(review): the comprehension variable shadows the `chars` parameter
        # (harmless here since `chars` is not used afterwards, but confusing).
        # NOTE(review): if self.attack_epochs is 0, `attack_metric` is unbound.
        attack_seq = [Corpus.ROOT] + [self.vocab.id2char(chars) for chars in attack_chars[0,1:]]

        return raw_metric, attack_metric, attack_seq, revised_number
Example #7
0
    def attack(self, config, loader, corpus, attack_corpus):
        """Span-substitution attack over a whole corpus.

        For each sentence: find pairs of valid (source, target) spans, substitute
        candidate words inside the source span, and count the attack a success
        when the parser makes strictly more head errors inside the target span
        than it did on the original sentence. Prints per-sentence progress and
        aggregate metrics; optionally appends successful adversarial sentences
        to ``attack_corpus``.

        :param config: run configuration; reads ``config.mst`` and
            ``config.save_result_to_file``
        :param loader: yields (seq_idx, tag_idx, chars, arcs, rels) per sentence
        :param corpus: original corpus, indexed in step with ``loader``
        :param attack_corpus: output corpus collecting successful attacks
        """
        success = 0
        all_number = 0
        raw_metric_all = Metric()
        attack_metric_all = Metric()
        # NOTE(review): span_number is never incremented, so the final
        # "Average" print always reports 0 — confirm intent.
        span_number = 0

        for index, (seq_idx, tag_idx, chars, arcs, rels) in enumerate(loader):
            print("Sentence {}".format(index + 1))
            start_time = time.time()
            mask = self.get_mask(seq_idx,
                                 self.vocab.pad_index,
                                 punct_list=self.vocab.puncts)
            seqs = self.get_seqs_name(seq_idx)
            tags = self.get_tags_name(tag_idx)

            sent = corpus[index]
            sentence_length = len(sent.FORM)
            spans = self.gen_spans(sent)
            # Tokens whose head is ROOT (arc == 0).
            roots = (arcs == 0).squeeze().nonzero().squeeze().tolist()

            valid_spans = self.get_valid_spans(spans, roots, sentence_length)
            # Need at least two spans, and the last span must start after the
            # first one ends (i.e. disjoint source/target candidates exist).
            if len(valid_spans
                   ) >= 2 and valid_spans[-1][0] > valid_spans[0][1]:
                filter_valid_spans = self.filter_spans(valid_spans)

                spans_list_to_attack = self.get_span_to_attack(
                    filter_valid_spans)
                if len(spans_list_to_attack) == 0:
                    print(
                        "Sentence {} doesn't has enough valid spans. Time: {:.2f}"
                        .format(index + 1,
                                time.time() - start_time))
                    continue
                all_number += 1
                succeed_flag = False

                # Baseline predictions on the unmodified sentence; decoding uses
                # an all-True mask except the ROOT position.
                raw_s_rac, raw_s_rel = self.parser.forward(
                    seq_idx, is_chars_judger(self.parser, tag_idx, chars))
                raw_mask = torch.ones_like(mask, dtype=mask.dtype)
                raw_mask[0, 0] = False
                raw_pred_arc, raw_pred_rel = self.task.decode(raw_s_rac,
                                                              raw_s_rel,
                                                              raw_mask,
                                                              mst=config.mst)

                for span_pair_index, spans_to_attack in enumerate(
                        spans_list_to_attack):
                    # Words are substituted inside src_span; damage is measured
                    # inside tgt_span.
                    src_span = spans_to_attack[0]
                    tgt_span = spans_to_attack[1]

                    # Mask restricted to the target span (plus its head index
                    # tgt_span[2]); everything outside is zeroed.
                    mask_idxes = list(range(
                        0, tgt_span[0])) + [tgt_span[2]] + list(
                            range(tgt_span[1] + 1, sentence_length + 1))
                    new_mask = self.update_mask(mask, mask_idxes)

                    # for batch compare , no used task.evaluate
                    # Baseline head-error count inside the target span.
                    raw_non_equal_number = torch.sum(
                        torch.ne(raw_pred_arc[new_mask],
                                 arcs[new_mask])).item()

                    # Generate `candidates` substitution variants of the source span.
                    indexes = self.get_attack_index(
                        self.vocab.id2rel(rels),
                        list(range(src_span[0], src_span[1] + 1)),
                        self.reivsed, self.candidates)
                    attack_seqs = [
                        self.attack_seq_generator.substitute(
                            seqs, tags, index) for index in indexes
                    ]
                    # Batch all candidate sentences into one forward pass.
                    attack_seq_idx = torch.cat([
                        self.vocab.word2id(attack_seq).unsqueeze(0)
                        for attack_seq in attack_seqs
                    ],
                                               dim=0)
                    if torch.cuda.is_available():
                        attack_seq_idx = attack_seq_idx.cuda()
                    attack_mask = torch.ones_like(attack_seq_idx,
                                                  dtype=mask.dtype)
                    attack_mask[:, 0] = 0
                    if is_chars_judger(self.parser):
                        attack_chars_idx = torch.cat([
                            self.get_chars_idx_by_seq(attack_seq)
                            for attack_seq in attack_seqs
                        ],
                                                     dim=0)
                        attack_s_arc, attack_s_rel = self.parser.forward(
                            attack_seq_idx, attack_chars_idx)
                    else:
                        attack_tags_idx = tag_idx.repeat(self.candidates, 1)
                        attack_s_arc, attack_s_rel = self.parser.forward(
                            attack_seq_idx, attack_tags_idx)

                    attack_pred_arc, attack_pred_rel = self.task.decode(
                        attack_s_arc,
                        attack_s_rel,
                        attack_mask,
                        mst=config.mst)
                    # Slice out each candidate's predictions over the target span.
                    attack_pred_arc_tgt = torch.split(
                        attack_pred_arc[new_mask.repeat(self.candidates, 1)],
                        [torch.sum(new_mask)] * self.candidates)
                    # Candidates that cause strictly more head errors than the baseline.
                    attack_non_equal_number_index = [
                        count for count, pred in enumerate(attack_pred_arc_tgt)
                        if torch.sum(torch.ne(pred, arcs[new_mask])).item() >
                        raw_non_equal_number
                    ]

                    if len(attack_non_equal_number_index) != 0:
                        success += 1
                        # Pick the candidate with the most head errors in the span.
                        non_equal_numbers = [
                            torch.sum(
                                torch.
                                ne(attack_pred_arc_tgt[non_equal_number_index],
                                   arcs[new_mask])).item()
                            for non_equal_number_index in
                            attack_non_equal_number_index
                        ]
                        attack_succeed_index = sorted(
                            range(len(non_equal_numbers)),
                            key=lambda k: non_equal_numbers[k],
                            reverse=True)[0]
                        attack_succeed_index = attack_non_equal_number_index[
                            attack_succeed_index]

                        success_index = ' '.join([
                            str(random_index)
                            for random_index in indexes[attack_succeed_index]
                        ])
                        success_candidate = ' '.join([
                            attack_seqs[attack_succeed_index][random_index]
                            for random_index in indexes[attack_succeed_index]
                        ])
                        # NOTE(review): `success_candidate` is passed as a 7th
                        # format arg but the string has only 6 placeholders —
                        # it is silently ignored.
                        print(
                            'Pair {}/{}, src span:({},{}), tgt span:({},{}) succeed'
                            .format(span_pair_index + 1,
                                    len(spans_list_to_attack), src_span[0],
                                    src_span[1], tgt_span[0], tgt_span[1],
                                    success_candidate))
                        print('indexes: {} candidates: {}'.format(
                            success_index, success_candidate))
                        print(
                            "Sentence {} attacked succeeded! Time: {:.2f}s , Success Rate:{:.2f}%"
                            .format(index + 1,
                                    time.time() - start_time,
                                    success / all_number * 100))

                        # Whole-sentence metric for the winning candidate.
                        attack_metric = Metric()
                        attack_metric(
                            attack_pred_arc[attack_succeed_index].unsqueeze(
                                0)[mask],
                            attack_pred_rel[attack_succeed_index].unsqueeze(
                                0)[mask], arcs[mask], rels[mask])
                        if config.save_result_to_file:
                            # Mark source-span words with '@' and target-span
                            # words with '#' in the saved sentence.
                            attack_seq = attack_seqs[attack_succeed_index]
                            for span in range(src_span[0], src_span[1] + 1):
                                attack_seq[span] = "@" + attack_seq[span]
                            for span in range(tgt_span[0], tgt_span[1] + 1):
                                attack_seq[span] = "#" + attack_seq[span]
                            attack_corpus.append(
                                init_sentence(
                                    seqs[1:], attack_seq[1:], tags[1:],
                                    cast_list(arcs)[1:],
                                    self.vocab.id2rel(rels)[1:],
                                    cast_list(
                                        attack_pred_arc[attack_succeed_index])
                                    [1:],
                                    self.vocab.id2rel(
                                        attack_pred_rel[attack_succeed_index])
                                    [1:]))
                            print('\n'.join(
                                '\t'.join(map(str, i))
                                for i in zip(*(f for f in attack_corpus[-1]
                                               if f))))
                        succeed_flag = True
                        # Stop after the first successful span pair.
                        break
                    print(
                        'Pair {}/{}, src span:({},{}), tgt span:({},{}) failed! '
                        .format(span_pair_index + 1, len(spans_list_to_attack),
                                src_span[0], src_span[1], tgt_span[0],
                                tgt_span[1]))

                raw_metric = Metric()
                raw_metric(raw_pred_arc[mask], raw_pred_rel[mask], arcs[mask],
                           rels[mask])
                if not succeed_flag:
                    print(
                        "Sentence {} attacked failed! Time: {:.2f}s, Success Rate:{:.2f}%"
                        .format(index + 1,
                                time.time() - start_time,
                                success / all_number * 100))
                    # A failed attack scores the same as the baseline.
                    attack_metric = raw_metric
                raw_metric_all += raw_metric
                attack_metric_all += attack_metric
            else:
                print(
                    "Sentence {} doesn't has enough valid spans. Time: {:.2f}".
                    format(index + 1,
                           time.time() - start_time))
        print("Before: {} After:{}".format(raw_metric_all, attack_metric_all))
        # NOTE(review): both divisions below raise ZeroDivisionError when no
        # sentence had attackable spans (all_number == 0) — confirm acceptable.
        print("All: {}, Success: {}, Success Rate:{:.2f}%".format(
            all_number, success, success / all_number * 100))
        print("Average: {}".format(span_number / all_number))