Example #1
0
 def turn_tensor_to_list(self, tensor_or_list):
     """Return *tensor_or_list* as a plain Python list when it is a tensor.

     :param tensor_or_list: a ``torch.Tensor`` or any other object
     :return: ``cast_list(tensor_or_list)`` (project helper) for tensors;
         every other value -- including lists -- is returned unchanged
     """
     if isinstance(tensor_or_list, torch.Tensor):
         return cast_list(tensor_or_list)
     # The original had identical `elif isinstance(..., list)` and `else`
     # branches (both returned the argument unchanged); they are merged here.
     return tensor_or_list
Example #2
0
    def generate_attack_seq(self,
                            seqs,
                            seq_idx,
                            tags,
                            tag_idx,
                            chars,
                            arcs,
                            rels,
                            mask,
                            raw_metric=None):
        """Build an adversarial sentence by deleting attack-selected words.

        The words to remove are chosen by ``self.index.get_attack_index``;
        each deletion also removes the corresponding entry from every
        parallel annotation (tags, mask, arcs, rels) and shifts the gold
        head indices to stay consistent.

        :param seqs: sentence words in string form; presumably WITHOUT the
            <ROOT> token, since words are deleted at ``index`` while the id
            sequences are deleted at ``index + 1`` -- TODO confirm at caller
        :param seq_idx: word-id tensor (appears to include <ROOT> at pos 0)
        :param tags: POS tag names
        :param tag_idx: tag-id tensor
        :param chars: character-id tensor
        :param arcs: gold head-index tensor
        :param rels: gold relation-id tensor
        :param mask: evaluation mask tensor
        :param raw_metric: unused here; kept for interface compatibility
        :return: 6-tuple ``([Corpus.ROOT] + attacked words, tag tensor,
            mask tensor, arc tensor, rel tensor, number of deleted words)``
        """
        gold_arcs = cast_list(arcs)
        # Indices (into the ROOT-less word list) of the words to delete.
        attack_index = self.index.get_attack_index(self.copy_str_to_list(seqs),
                                                   seq_idx, tags, tag_idx,
                                                   chars, gold_arcs, mask)

        # Delete from the end first so earlier indices remain valid.
        attack_index.sort(reverse=True)
        attack_mask = cast_list(mask)
        attack_tags = cast_list(tag_idx)
        attack_arcs = gold_arcs.copy()
        attack_rels = cast_list(rels)
        attack_seqs = self.copy_str_to_list(seqs)

        for index in attack_index:
            del attack_seqs[index]
            # Id/annotation sequences carry <ROOT> at position 0, hence the
            # +1 offset relative to the word list.
            del attack_tags[index + 1]
            del attack_mask[index + 1]
            del attack_arcs[index + 1]
            del attack_rels[index + 1]
            # Shift head indices left past the deleted position.
            # NOTE(review): a head pointing exactly at the deleted word
            # (arc == index + 1) is also shifted; presumably attack indices
            # are chosen to have no dependents -- confirm in get_attack_index.
            attack_arcs = [
                arc - 1 if arc > index else arc for arc in attack_arcs
            ]

        # Rebuild tensors preserving the original dtypes ...
        attack_tags = torch.tensor(attack_tags, dtype=tag_idx.dtype)
        attack_mask = torch.tensor(attack_mask, dtype=mask.dtype)
        attack_arcs = torch.tensor(attack_arcs, dtype=arcs.dtype)
        attack_rels = torch.tensor(attack_rels, dtype=rels.dtype)
        # ... restore a leading batch dimension where it was flattened ...
        attack_tags, attack_mask, attack_arcs, attack_rels = map(
            lambda x: x.unsqueeze(0) if len(x.shape) == 1 else x,
            [attack_tags, attack_mask, attack_arcs, attack_rels])
        # ... and move everything to GPU when one is available.
        attack_tags, attack_mask, attack_arcs, attack_rels = map(
            lambda x: x.cuda() if torch.cuda.is_available() else x,
            [attack_tags, attack_mask, attack_arcs, attack_rels])
        return [
            Corpus.ROOT
        ] + attack_seqs, attack_tags, attack_mask, attack_arcs, attack_rels, len(
            attack_index)
Example #3
0
    def attack_for_each_process(self, config, loader, attack_corpus):
        """Attack every sentence produced by *loader* and aggregate metrics.

        :param config: run configuration (uses ``revised_rate``, ``mst``,
            ``save_result_to_file``)
        :param loader: iterable of ``(seq_idx, tag_idx, chars, arcs, rels)``
        :param attack_corpus: corpus object collecting attacked sentences
        :return: ``(metric_before_attack, metric_after_attack,
            total_revised_words)``
        """
        total_revised = 0

        # Parser-quality accumulators, before vs. after the black-box attack.
        metric_before_attack = Metric()
        metric_after_attack = Metric()

        for sent_no, batch in enumerate(loader):
            seq_idx, tag_idx, chars, arcs, rels = batch
            mask = self.get_mask(seq_idx, self.vocab.pad_index,
                                 punct_list=self.vocab.puncts)
            seqs = self.get_seqs_name(seq_idx)
            tags = self.get_tags_name(tag_idx)

            # How many words may be revised in this sentence (ROOT excluded).
            number = self.get_number(config.revised_rate, len(seqs) - 1)

            # Attack a single sentence.
            (raw_metric, attack_metric, attack_seq,
             revised_number) = self.attack_for_each_sentence(
                 config, seqs, seq_idx, tags, tag_idx, chars, arcs, rels,
                 mask, number)

            metric_before_attack += raw_metric
            metric_after_attack += attack_metric

            if config.save_result_to_file:
                # all result ignores the first token <ROOT>
                attack_seq_idx = self.vocab.word2id(attack_seq).unsqueeze(0)
                if torch.cuda.is_available():
                    attack_seq_idx = attack_seq_idx.cuda()
                attack_chars = self.get_chars_idx_by_seq(attack_seq)
                _, attack_arc, attack_rel = self.task.predict(
                    [(attack_seq_idx, tag_idx, attack_chars)],
                    mst=config.mst)
                attack_corpus.append(
                    init_sentence(seqs[1:], attack_seq[1:], tags[1:],
                                  cast_list(arcs)[1:],
                                  self.vocab.id2rel(rels)[1:],
                                  attack_arc, attack_rel))

            total_revised += revised_number
            print("Sentence: {}, Revised: {} Before: {} After: {} ".format(
                sent_no + 1, revised_number, metric_before_attack,
                metric_after_attack))

        return metric_before_attack, metric_after_attack, total_revised
Example #4
0
    def attack_for_each_process(self, config, loader, attack_corpus):
        """Attack every sentence in *loader*; track success and revisions.

        :param config: run configuration (uses ``mst`` and
            ``save_result_to_file``)
        :param loader: iterable of ``(seq_idx, tag_idx, chars, arcs, rels)``
        :param attack_corpus: corpus object collecting attacked sentences
        :return: ``(raw_metric_all, attack_metric_all, revised_numbers,
            success_numbers)``
        """
        revised_numbers = 0
        success_numbers = 0

        # Aggregate metrics measured before and after the black-box attack.
        raw_metric_all = Metric()
        attack_metric_all = Metric()

        for sent_no, batch in enumerate(loader):
            seq_idx, tag_idx, chars, arcs, rels = batch
            mask = self.get_mask(seq_idx, self.vocab.pad_index,
                                 punct_list=self.vocab.puncts)
            seqs = self.get_seqs_name(seq_idx)
            tags = self.get_tags_name(tag_idx)

            # Attack one sentence.
            (raw_metric, attack_metric, attack_seq, attack_arc, attack_rel,
             revised_number) = self.attack_for_each_sentence(
                 config, seqs, seq_idx, tags, tag_idx, chars, arcs, rels,
                 mask)

            raw_metric_all += raw_metric
            attack_metric_all += attack_metric

            # A drop in UAS means the attack degraded the parse: success.
            if attack_metric.uas < raw_metric.uas:
                success_numbers += 1

            if config.save_result_to_file:
                # all result ignores the first token <ROOT>
                attack_corpus.append(
                    init_sentence(seqs[1:], attack_seq[1:], tags[1:],
                                  cast_list(arcs)[1:],
                                  self.vocab.id2rel(rels)[1:],
                                  attack_arc, attack_rel))

            revised_numbers += revised_number
            print("Sentence: {}, Revised: {} Before: {} After: {} ".format(
                sent_no + 1, revised_number, raw_metric_all,
                attack_metric_all))

        return raw_metric_all, attack_metric_all, revised_numbers, success_numbers
Example #5
0
    def attack_for_each_sentence(self, config, seq, seq_idx, tag, tag_idx, chars, arcs, rels, mask, number):
        """White-box character-level attack on a single sentence.

        Repeatedly (up to ``self.attack_epochs`` times) computes gradients of
        the parser loss w.r.t. word/char embeddings, picks the ``number``
        highest-gradient words (on the first iteration), and substitutes one
        character per chosen word with a gradient-guided neighbor, stopping
        early once the attacked UAS drops below the original.

        :param config: run configuration (uses ``mst``)
        :param seq: sentence words (string form)
        :param seq_idx: word-id tensor
        :param tag: POS tag names
        :param tag_idx: tag-id tensor
        :param chars: character-id tensor; indexing below implies shape
            (1, sent_len, max_word_len) -- TODO confirm
        :param arcs: gold head-index tensor
        :param rels: gold relation-id tensor
        :param mask: evaluation mask tensor
        :param number: number of words to revise
        :return: ``(raw_metric, attack_metric, attack_seq, revised_number)``

        NOTE(review): ``attack_metric`` is only assigned inside the loop, so
        this raises ``UnboundLocalError`` if ``self.attack_epochs == 0``.
        """
        # Metric before attacking (parser in eval mode).
        self.parser.eval()
        _, raw_metric = self.task.evaluate([(seq_idx, tag_idx, chars, arcs, rels)],mst=config.mst)

        # Word positions whose gradients must be ignored: punctuation tokens
        # plus position 0 (the <ROOT> token).
        word_index_grad_neednot_consider = cast_list(rels.eq(self.punct_rel_idx).squeeze().nonzero())
        word_index_grad_neednot_consider.append(0)

        # Character mask: True where a real (non-pad) character id sits.
        char_mask = chars.gt(0)[0]
        # sorted_lens is unused; inverse_indices undoes the length-descending
        # sort that the char-embedding layer presumably applies -- TODO
        # confirm against the parser's char encoder.
        sorted_lens, indices = torch.sort(char_mask.sum(dim=1), descending=True)
        inverse_indices = indices.argsort()

        # Trim the char axis to the longest actual word.
        char_mask_max = torch.max(char_mask.sum(dim=1))
        char_mask = char_mask[:,:char_mask_max]
        # Exclude the root token from attack.
        char_mask[0, :] = False
        # Exclude punctuation tokens from attack.
        punct_idx_list = cast_list(rels.eq(self.punct_rel_idx).nonzero())
        char_mask[punct_idx_list, :] = False

        # Work on a copy so the original char tensor stays intact.
        attack_chars = chars.clone()

        # forbidden_idx[word_pos] = char ids that must not be substituted in
        # (punctuation chars plus chars already tried at that position).
        forbidden_idx = defaultdict(lambda: set())
        # char_idx[word_pos] = the character slot being attacked in that word.
        char_idx = dict()
        revised_number = 0
        for i in range(self.attack_epochs):
            # Forward + backward to obtain embedding gradients.
            self.parser.zero_grad()
            s_arc, s_rel = self.parser.forward(seq_idx, is_chars_judger(self.parser, tag_idx, attack_chars))
            loss = self.get_whitebox_loss(arcs[mask], s_arc[mask])
            loss.backward()

            # Gradients captured by hooks registered elsewhere in the class.
            charembed_grad = self.embed_grad['char_embed_grad'][inverse_indices]
            wordembed_grad = self.parser.word_embed_grad[0]
            word_grad_norm = wordembed_grad.norm(dim=1)
            # Push ignored positions below any real gradient norm.
            word_grad_norm[word_index_grad_neednot_consider] = -10000.0

            if i == 0:
                # First iteration: fix which words (and which character slot
                # within each word) will be attacked for the whole run.
                current_norm_indexes = cast_list(word_grad_norm.topk(number)[1])
                for index in current_norm_indexes:
                    forbidden_idx[index].update(self.punct_idx.copy())
                    revised_number += 1
                    char_grad = charembed_grad[index][char_mask[index]]
                    char_grad_norm = char_grad.norm(dim=1)
                    char_index = char_grad_norm.topk(1)[1].item()
                    char_idx[index] = char_index
            # (A commented-out alternative word-selection strategy that
            # re-picked words each epoch was removed here as dead code.)

            # Substitute one character in every chosen word.
            for index in forbidden_idx.keys():
                raw_index = attack_chars[0, index, char_idx[index]].item()
                # Forbid re-using the current character (and, per the helper,
                # its case variants when it is a letter).
                self.add_raw_index_to_be_forbidden(forbidden_idx, index, raw_index)
                replace_index = self.find_neighbors(raw_index, charembed_grad[index, char_idx[index]],list(forbidden_idx[index]))
                attack_chars[0, index, char_idx[index]] = replace_index

            self.parser.eval()

            # Metric after this round of substitutions.
            _, attack_metric = self.task.evaluate([(seq_idx, tag_idx, attack_chars, arcs, rels)], mst=config.mst)

            # Stop as soon as the attack measurably hurts UAS.
            if attack_metric.uas < raw_metric.uas:
                self.succeed_number += 1
                print("Succeed", end=" ")
                break
        # Decode the attacked char ids back into word strings (skip <ROOT>).
        attack_seq = [Corpus.ROOT] + [self.vocab.id2char(chars) for chars in attack_chars[0,1:]]

        return raw_metric, attack_metric, attack_seq, revised_number
Example #6
0
    def attack(self, config, loader, corpus, attack_corpus):
        """Span-substitution attack over a whole corpus.

        For each sentence, picks pairs of (source, target) constituent spans;
        substitutes words inside the source span with candidate replacements
        and counts the attack successful when any candidate makes the parser
        produce MORE head errors inside the target span than the unattacked
        baseline.

        :param config: run configuration (uses ``mst``,
            ``save_result_to_file``)
        :param loader: iterable of ``(seq_idx, tag_idx, chars, arcs, rels)``
        :param corpus: gold corpus, indexed in step with *loader*
        :param attack_corpus: corpus object collecting attacked sentences
        :return: None (results are printed)
        """
        success = 0
        all_number = 0
        raw_metric_all = Metric()
        attack_metric_all = Metric()
        # NOTE(review): span_number is never incremented in this method, so
        # the final "Average" line always prints 0 -- confirm whether it was
        # meant to count attacked span pairs.
        span_number = 0

        for index, (seq_idx, tag_idx, chars, arcs, rels) in enumerate(loader):
            print("Sentence {}".format(index + 1))
            start_time = time.time()
            mask = self.get_mask(seq_idx,
                                 self.vocab.pad_index,
                                 punct_list=self.vocab.puncts)
            seqs = self.get_seqs_name(seq_idx)
            tags = self.get_tags_name(tag_idx)

            # Gold sentence, its constituent spans, and root positions.
            sent = corpus[index]
            sentence_length = len(sent.FORM)
            spans = self.gen_spans(sent)
            roots = (arcs == 0).squeeze().nonzero().squeeze().tolist()

            valid_spans = self.get_valid_spans(spans, roots, sentence_length)
            # Need at least two spans with the last starting after the first
            # ends, so a non-overlapping source/target pair exists.
            if len(valid_spans
                   ) >= 2 and valid_spans[-1][0] > valid_spans[0][1]:
                filter_valid_spans = self.filter_spans(valid_spans)

                spans_list_to_attack = self.get_span_to_attack(
                    filter_valid_spans)
                if len(spans_list_to_attack) == 0:
                    print(
                        "Sentence {} doesn't has enough valid spans. Time: {:.2f}"
                        .format(index + 1,
                                time.time() - start_time))
                    continue
                all_number += 1
                succeed_flag = False

                # Baseline (unattacked) predictions for comparison.
                raw_s_rac, raw_s_rel = self.parser.forward(
                    seq_idx, is_chars_judger(self.parser, tag_idx, chars))
                raw_mask = torch.ones_like(mask, dtype=mask.dtype)
                # Exclude only the <ROOT> position from decoding.
                raw_mask[0, 0] = False
                raw_pred_arc, raw_pred_rel = self.task.decode(raw_s_rac,
                                                              raw_s_rel,
                                                              raw_mask,
                                                              mst=config.mst)

                for span_pair_index, spans_to_attack in enumerate(
                        spans_list_to_attack):
                    # Words in src_span get substituted; the damage is
                    # measured on the parse inside tgt_span.
                    src_span = spans_to_attack[0]
                    tgt_span = spans_to_attack[1]

                    # Positions OUTSIDE the target span plus tgt_span[2];
                    # presumably (start, end, head) span tuples -- TODO
                    # confirm against gen_spans.
                    mask_idxes = list(range(
                        0, tgt_span[0])) + [tgt_span[2]] + list(
                            range(tgt_span[1] + 1, sentence_length + 1))
                    new_mask = self.update_mask(mask, mask_idxes)

                    # for batch compare , no used task.evaluate
                    raw_non_equal_number = torch.sum(
                        torch.ne(raw_pred_arc[new_mask],
                                 arcs[new_mask])).item()

                    # NOTE(review): `self.reivsed` looks like a typo for
                    # "revised" but must match the attribute defined in
                    # __init__ elsewhere in this class -- verify before
                    # renaming.
                    indexes = self.get_attack_index(
                        self.vocab.id2rel(rels),
                        list(range(src_span[0], src_span[1] + 1)),
                        self.reivsed, self.candidates)
                    attack_seqs = [
                        self.attack_seq_generator.substitute(
                            seqs, tags, index) for index in indexes
                    ]
                    # Batch all candidate sentences into one forward pass.
                    attack_seq_idx = torch.cat([
                        self.vocab.word2id(attack_seq).unsqueeze(0)
                        for attack_seq in attack_seqs
                    ],
                                               dim=0)
                    if torch.cuda.is_available():
                        attack_seq_idx = attack_seq_idx.cuda()
                    attack_mask = torch.ones_like(attack_seq_idx,
                                                  dtype=mask.dtype)
                    attack_mask[:, 0] = 0
                    # Feed chars or tags depending on what the parser uses.
                    if is_chars_judger(self.parser):
                        attack_chars_idx = torch.cat([
                            self.get_chars_idx_by_seq(attack_seq)
                            for attack_seq in attack_seqs
                        ],
                                                     dim=0)
                        attack_s_arc, attack_s_rel = self.parser.forward(
                            attack_seq_idx, attack_chars_idx)
                    else:
                        attack_tags_idx = tag_idx.repeat(self.candidates, 1)
                        attack_s_arc, attack_s_rel = self.parser.forward(
                            attack_seq_idx, attack_tags_idx)

                    attack_pred_arc, attack_pred_rel = self.task.decode(
                        attack_s_arc,
                        attack_s_rel,
                        attack_mask,
                        mst=config.mst)
                    # Split the batched predictions back per candidate,
                    # keeping only target-span positions.
                    attack_pred_arc_tgt = torch.split(
                        attack_pred_arc[new_mask.repeat(self.candidates, 1)],
                        [torch.sum(new_mask)] * self.candidates)
                    # Candidates that induce MORE head errors than baseline.
                    attack_non_equal_number_index = [
                        count for count, pred in enumerate(attack_pred_arc_tgt)
                        if torch.sum(torch.ne(pred, arcs[new_mask])).item() >
                        raw_non_equal_number
                    ]

                    if len(attack_non_equal_number_index) != 0:
                        success += 1
                        non_equal_numbers = [
                            torch.sum(
                                torch.
                                ne(attack_pred_arc_tgt[non_equal_number_index],
                                   arcs[new_mask])).item()
                            for non_equal_number_index in
                            attack_non_equal_number_index
                        ]
                        # Keep the candidate with the most induced errors.
                        attack_succeed_index = sorted(
                            range(len(non_equal_numbers)),
                            key=lambda k: non_equal_numbers[k],
                            reverse=True)[0]
                        attack_succeed_index = attack_non_equal_number_index[
                            attack_succeed_index]

                        success_index = ' '.join([
                            str(random_index)
                            for random_index in indexes[attack_succeed_index]
                        ])
                        success_candidate = ' '.join([
                            attack_seqs[attack_succeed_index][random_index]
                            for random_index in indexes[attack_succeed_index]
                        ])
                        # NOTE(review): this format string has six
                        # placeholders but seven arguments; str.format
                        # silently ignores the trailing success_candidate.
                        print(
                            'Pair {}/{}, src span:({},{}), tgt span:({},{}) succeed'
                            .format(span_pair_index + 1,
                                    len(spans_list_to_attack), src_span[0],
                                    src_span[1], tgt_span[0], tgt_span[1],
                                    success_candidate))
                        print('indexes: {} candidates: {}'.format(
                            success_index, success_candidate))
                        print(
                            "Sentence {} attacked succeeded! Time: {:.2f}s , Success Rate:{:.2f}%"
                            .format(index + 1,
                                    time.time() - start_time,
                                    success / all_number * 100))

                        # Full-sentence metric for the winning candidate.
                        attack_metric = Metric()
                        attack_metric(
                            attack_pred_arc[attack_succeed_index].unsqueeze(
                                0)[mask],
                            attack_pred_rel[attack_succeed_index].unsqueeze(
                                0)[mask], arcs[mask], rels[mask])
                        if config.save_result_to_file:
                            attack_seq = attack_seqs[attack_succeed_index]
                            # Mark attacked words: "@" = source span,
                            # "#" = target span.
                            for span in range(src_span[0], src_span[1] + 1):
                                attack_seq[span] = "@" + attack_seq[span]
                            for span in range(tgt_span[0], tgt_span[1] + 1):
                                attack_seq[span] = "#" + attack_seq[span]
                            attack_corpus.append(
                                init_sentence(
                                    seqs[1:], attack_seq[1:], tags[1:],
                                    cast_list(arcs)[1:],
                                    self.vocab.id2rel(rels)[1:],
                                    cast_list(
                                        attack_pred_arc[attack_succeed_index])
                                    [1:],
                                    self.vocab.id2rel(
                                        attack_pred_rel[attack_succeed_index])
                                    [1:]))
                            print('\n'.join(
                                '\t'.join(map(str, i))
                                for i in zip(*(f for f in attack_corpus[-1]
                                               if f))))
                        succeed_flag = True
                        # One successful pair per sentence is enough.
                        break
                    print(
                        'Pair {}/{}, src span:({},{}), tgt span:({},{}) failed! '
                        .format(span_pair_index + 1, len(spans_list_to_attack),
                                src_span[0], src_span[1], tgt_span[0],
                                tgt_span[1]))

                # Baseline full-sentence metric.
                raw_metric = Metric()
                raw_metric(raw_pred_arc[mask], raw_pred_rel[mask], arcs[mask],
                           rels[mask])
                if not succeed_flag:
                    print(
                        "Sentence {} attacked failed! Time: {:.2f}s, Success Rate:{:.2f}%"
                        .format(index + 1,
                                time.time() - start_time,
                                success / all_number * 100))
                    # No successful attack: after-attack metric == baseline.
                    attack_metric = raw_metric
                raw_metric_all += raw_metric
                attack_metric_all += attack_metric
            else:
                print(
                    "Sentence {} doesn't has enough valid spans. Time: {:.2f}".
                    format(index + 1,
                           time.time() - start_time))
        print("Before: {} After:{}".format(raw_metric_all, attack_metric_all))
        # NOTE(review): the two divisions below raise ZeroDivisionError when
        # no sentence had attackable spans (all_number == 0) -- confirm that
        # callers guarantee at least one.
        print("All: {}, Success: {}, Success Rate:{:.2f}%".format(
            all_number, success, success / all_number * 100))
        print("Average: {}".format(span_number / all_number))