Пример #1
0
    def select_by_deltalogit(self, instance, sentence):
        """Rank (target, source) span pairs by a delta-logit vulnerability score.

        Each word with a positive margin contributes ``10 - margin`` to the
        vulnerability of every span containing it; pairs are ranked by the
        target span's total vulnerability, highest first.

        Returns a tuple of (tgt_span, src_span) pairs in ranked order, or
        None when no pair satisfies the configured span gap.
        """
        min_len = self.config.hks_min_span_len
        max_len = self.config.hks_max_span_len
        words, tags, chars, arcs, rels = instance

        # The returned margins do not contain <ROOT>.
        margins = self.compute_margin(instance)
        log(margins)

        # Per-word vulnerability; the leading 0 stands in for <ROOT>.
        vul_scores = [0] + [10 - m if m > 0 else 0 for m in margins]

        spans = filter_spans(gen_spans(sentence), min_len, max_len, True)
        # Total vulnerability of each span (inclusive boundaries).
        span_vuls = [sum(vul_scores[s[0]:s[1] + 1]) for s in spans]

        # Enumerate every (target, source) combination far enough apart.
        pairs = [
            (span_vuls[tid], (tgt, src))
            for tid, tgt in enumerate(spans)
            for src in spans
            if check_gap(src, tgt, self.config.hks_span_gap)
        ]
        if not pairs:
            return None
        ranked = sorted(pairs, key=lambda item: item[0], reverse=True)
        return list(zip(*ranked))[1]
Пример #2
0
    def select_tgt_span(self, instance, sentence, mode):
        """Select target span(s) for the attack.

        mode 'vul': score words by margin buckets and return the ranked
            (tgt, src) span pairs (despite the method name, this mode
            returns pairs, not a single span), or None when no pair
            satisfies the configured span gap.
        mode 'rdm': return one random span, or None when none exists.

        Raises:
            Exception: for any unrecognized ``mode``.
        """
        minl = self.config.hks_min_span_len
        maxl = self.config.hks_max_span_len
        if mode == 'vul':
            # Compute the ``vulnerable'' values of each words
            words, tags, chars, arcs, rels = instance

            # The returned margins does not contain <ROOT>
            margins = self.compute_margin(instance)
            log(margins)

            # Bucket each margin into a vulnerability weight:
            # (0,1) -> 100 (highly vulnerable), [1,2) -> 1, else -> -1.
            # The leading 0 stands in for <ROOT>.
            vul_margins = [0]
            for ele in margins:
                if 0 < ele < 1:
                    vul_margins.append(100)
                elif 1 <= ele < 2:
                    vul_margins.append(1)
                else:
                    vul_margins.append(-1)
            spans = filter_spans(gen_spans(sentence), minl, maxl, True)
            span_vuls = list()
            for span in spans:
                span_vuls.append(sum(vul_margins[span[0]:span[1] + 1]))

            # Pair every target span with every source span that is far
            # enough away, ranked by the target span's vulnerability.
            pairs = []
            for tid, t in enumerate(spans):
                for s in spans:
                    if check_gap(s, t, self.config.hks_span_gap):
                        pairs.append((span_vuls[tid], (t, s)))
            if len(pairs) == 0:
                return None
            spairs = sorted(pairs, key=lambda x: x[0], reverse=True)
            return list(zip(*spairs))[1]
            # (Removed: a dead CherryPicker-based selection block that was
            # unreachable after the return above.)

        elif mode == "rdm":
            spans = filter_spans(gen_spans(sentence), minl, maxl, True)
            if len(spans) == 0:
                return None
            return random.choice(spans)
        else:
            raise Exception
Пример #3
0
    def select_src_span(self, instance, sentence, tgt_span, mode):
        """Select a source span to modify, relative to ``tgt_span``.

        mode 'cls': the closest span whose gap to the target is >= ``gap``.
        mode 'rdm': a random span satisfying the gap constraint.
        mode 'nom': the span with the largest summed embedding-gradient
            norm (gradients taken w.r.t. a loss masked outside the target).

        Returns the chosen span, or None when no span qualifies.

        Raises:
            Exception: for any unrecognized ``mode``.
        """
        minl = self.config.hks_min_span_len
        maxl = self.config.hks_max_span_len
        gap = self.config.hks_span_gap
        if mode == "cls":
            spans = filter_spans(gen_spans(sentence), minl, maxl, True)
            # Prefer the candidate with the smallest gap that still
            # respects the minimum distance from the target span.
            src_picker = CherryPicker(lower_is_better=True)
            for span in spans:
                stgap = get_gap(span, tgt_span)
                if stgap >= gap:
                    src_picker.add(stgap, span)
            if src_picker.size == 0:
                log('Source span not found')
                return None
            _, _, src_span = src_picker.select_best_point()
            return src_span
        elif mode == 'rdm':
            spans = filter_spans(gen_spans(sentence), minl, maxl, True)
            src_spans = [span for span in spans if check_gap(span, tgt_span, gap)]
            if len(src_spans) == 0:
                return None
            return random.choice(src_spans)
        elif mode == 'nom':
            sent_len = instance[0].size(1)
            spans = filter_spans(gen_spans(sentence), minl, maxl, True)
            embed_grad = self.backward_loss(instance, ex_span_idx(tgt_span, sent_len))
            grad_norm = embed_grad.norm(dim=2)  # 1 x sent_len, <root> included

            src_picker = CherryPicker(lower_is_better=False)
            for span in spans:
                if check_gap(span, tgt_span, gap):
                    src_picker.add(grad_norm[0][span[0]:span[1] + 1].sum(), span)
            if src_picker.size == 0:
                return None
            _, _, src_span = src_picker.select_best_point()
            return src_span
        else:
            raise Exception
Пример #4
0
    def init_logger(self, config):
        """Bind a local ``log`` callable from ``config`` and dump settings.

        When ``config.logf`` is 'on', configure file+console logging via
        ``log_config`` and import ``log`` from the luna library; otherwise
        fall back to plain ``print``. Afterwards, log the general settings
        and every 'hk*'-prefixed hack setting.
        """
        if config.logf == 'on':
            if config.hk_use_worker == 'on':
                # NOTE(review): worker_info is computed but never used below —
                # presumably it was meant to be appended to the log name;
                # confirm against the original project before removing.
                worker_info = "-{}@{}".format(config.hk_num_worker,
                                              config.hk_worker_id)
            else:
                worker_info = ""
            log_config('{}'.format(config.mode),
                       log_path=config.workspace,
                       default_target='cf')
            # Rebind `log` locally now that log_config has set up targets.
            from dpattack.libs.luna import log
        else:
            # No log file requested: route everything to stdout.
            log = print

        log('[General Settings]')
        log(config)
        log('[Hack Settings]')
        for arg in config.kwargs:
            if arg.startswith('hk'):
                log(arg, '\t', config.kwargs[arg])
        log('------------------')
Пример #5
0
    def select_span_pair(self, instance, sentence):
        """Choose (target, source) span pair(s) per the configured strategy.

        Combined strategies like "vulcls" encode the target mode in the
        first three characters ('vul'/'rdm') and the source mode in the
        rest ('nom'/'cls'/'rdm'). Returns a list of (tgt_span, src_span)
        pairs, or None when no suitable pair can be found.

        Raises:
            Exception: for any unrecognized strategy name.
        """
        # Fixed a duplicated "rdmnom" entry in the original list.
        if self.config.hks_span_selection in \
                ["vulnom", "rdmnom", "vulcls", "rdmcls", "rdmrdm"]:
            tgt_mode = self.config.hks_span_selection[:3]
            src_mode = self.config.hks_span_selection[3:]
            tgt_span = self.select_tgt_span(instance, sentence, tgt_mode)
            if tgt_span is None:
                log('Target span not found')
                return
            src_span = self.select_src_span(instance, sentence, tgt_span, src_mode)
            if src_span is None:
                log('Source span not found')
                return
            ret = [(tgt_span, src_span)]
        elif self.config.hks_span_selection in ['jacobian1', 'jacobian2']:
            ret = self.select_by_jacobian(instance, sentence, self.config.hks_span_selection)
        elif self.config.hks_span_selection in ['random']:
            ret = self.select_by_random(instance, sentence)
        elif self.config.hks_span_selection in ['deltalogit']:
            ret = self.select_by_deltalogit(instance, sentence)
        else:
            raise Exception

        if ret is None:
            log('Span pair not found')
            return
        return ret
Пример #6
0
    def meta_hack(self, instance, sentence):
        """Run one attack on a single sentence, dispatching on attack color.

        'black': try each candidate source span in order with random_hack
        until one succeeds. 'white'/'grey': select span pairs first, then
        run white_hack / random_hack on the top-k pairs.

        Returns the hack-result dict of the last attempt (augmented with
        meta_* bookkeeping), or None when no span pair is available.
        """
        t0 = time.time()
        ram_reset("hk")

        words, tags, chars, arcs, rels = instance
        # NOTE(review): sent_len is unused in this method.
        sent_len = words.size(1)

        words_text = self.vocab.id2word(words[0])

        succ = False

        if self.config.hks_color == "black":
            spans = filter_spans(gen_spans(sentence), self.config.hks_min_span_len,
                                 self.config.hks_max_span_len, True)
            hack_result = None
            for pair_id, src_span in enumerate(spans):
                log('[Chosen source] ', src_span, ' '.join(words_text[src_span[0]:src_span[1] + 1]))
                tgt_span_lst = [
                    ele for ele in spans if check_gap(ele, src_span, self.config.hks_span_gap)
                ]
                if len(tgt_span_lst) == 0:
                    continue
                hack_result = self.random_hack(instance, sentence, tgt_span_lst, src_span)
                if hack_result['succ'] == 1:
                    succ = True
                    log("Succ on source span {}.\n".format(src_span))
                    break
                log("Fail on source span {}.\n".format(src_span))
            if hack_result is None:
                # Every source span lacked a gap-compatible target span.
                log("Not enough span pairs")
                return None
        elif self.config.hks_color in ['white', 'grey']:
            pairs = self.select_span_pair(instance, sentence)
            if pairs is None:
                return None
            for pair_id, (tgt_span, src_span) in enumerate(pairs[:self.config.hks_topk_pair]):
                log('[Chosen span] ', 'tgt-', tgt_span,
                    ' '.join(words_text[tgt_span[0]:tgt_span[1] + 1]), 'src-', src_span,
                    ' '.join(words_text[src_span[0]:src_span[1] + 1]))

                if self.config.hks_color == "white":
                    hack_result = self.white_hack(instance, sentence, tgt_span, src_span)
                elif self.config.hks_color == "grey":
                    # Grey-box: random search restricted to the chosen pair.
                    hack_result = self.random_hack(instance, sentence, [tgt_span], src_span)
                else:
                    raise Exception
                if hack_result['succ'] == 1:
                    succ = True
                    log("Succ on the span pair {}/{}.\n".format(pair_id, len(pairs)))
                    break
                log("Fail on the span pair {}/{}.\n".format(pair_id, len(pairs)))

            # NOTE(review): if `pairs` were a non-None empty list, pair_id and
            # hack_result would be unbound here — presumably select_span_pair
            # never returns an empty list; confirm upstream.
            hack_result['meta_trial_pair'] = pair_id + 1
            hack_result['meta_total_pair'] = len(pairs)
            hack_result['meta_succ_trial_pair'] = pair_id + 1 if succ else 0
            hack_result['meta_succ_total_pair'] = len(pairs) if succ else 0

        # NOTE(review): any hks_color other than black/white/grey reaches this
        # point with hack_result unbound and raises NameError — confirm the
        # config is validated elsewhere.
        t1 = time.time()
        log('Sentence cost {:.1f}s'.format(t1 - t0))
        hack_result['meta_time'] = t1 - t0

        return hack_result
Пример #7
0
    def __call__(self, config):
        """Entry point: attack every sentence in the loader and report stats.

        Sets up logging and models, builds the black-box substitution
        vocabulary (all word ids except <pad>, <unk>, <root>), then runs
        meta_hack per sentence, aggregating per-sentence metrics and
        logging a running summary after each one.
        """
        self.init_logger(config)
        self.setup(config)
        # Candidate word ids for black-box substitution: everything except
        # the special tokens.
        self.blackbox_sub_idxs = [
            ele for ele in range(self.vocab.n_words) if ele not in
            [self.vocab.pad_index, self.vocab.unk_index, self.vocab.word_dict['<root>']]
        ]

        agg = Aggregator()
        for sid, (words, tags, chars, arcs, rels) in enumerate(self.loader):
            # if sid > 100:
            #     continue
            # zjh = [
            #     20, 41, 46, 117, 137, 143, 183, 198, 258, 295, 310, 350, 410, 421, 464, 485, 512,
            #     528, 544, 600, 601, 681, 702, 728, 735, 738, 762, 783, 794, 803, 805, 821, 844, 845,
            #     921, 931, 937, 939, 948, 962, 968, 975, 1019, 1044, 1068, 1069, 1096, 1104, 1121,
            #     1122, 1138, 1142, 1155, 1163, 1180, 1197, 1224, 1228, 1270, 1272, 1292, 1306, 1315,
            #     1317, 1342, 1345, 1393, 1400, 1431, 1478, 1503, 1522, 1524, 1526, 1608, 1677, 1729,
            #     1759, 1775, 1795, 1811, 1831, 1925, 1929, 1983, 1984, 2026, 2031, 2176, 2234, 2291,
            #     2296, 2318, 2327, 2330, 2342, 2343, 2355, 2360
            # ]
            # if sid + 1 not in zjh:
            #     continue
            # if sid > 600 - 1:
            #     continue

            words_text = self.vocab.id2word(words[0])
            tags_text = self.vocab.id2tag(tags[0])
            log('****** {}: \n{}\n{}'.format(sid, " ".join(words_text), " ".join(tags_text)))

            result = self.meta_hack(instance=(words, tags, chars, arcs, rels),
                                    sentence=self.corpus[sid])

            # meta_hack returns None when no span pair exists; skip the
            # sentence entirely (it is not counted in the aggregates).
            if result is None:
                continue

            # yapf: disable
            agg.aggregate(
                ("iters", result['iters']),
                ("time", result['time']),
                ("succ", result['succ']),
                ('best_iter', result['best_iter']),
                ("changed", result['num_changed']),
                ("att_id", result['att_id']),
                ("meta_time", result['meta_time']),
                ("meta_trial_pair", result['meta_trial_pair']),
                ("meta_total_pair", result['meta_total_pair']),
                ("meta_succ_trial_pair", result['meta_succ_trial_pair']),
                ("meta_succ_total_pair", result['meta_succ_total_pair']),
            )

            # # WARNING: SOME SENTENCE NOT SHOWN!
            # NOTE(review): `if result:` is always true here (None was
            # filtered above), so the log table is printed unconditionally.
            if result:
                log('Show result from iter {}:'.format(result['best_iter']))
                log(result['logtable'])

            log('Aggregated result: '
                'iters(avg) {:.1f}, time(avg) {:.1f}s, meta_time(avg) {:.1f}s, '
                'succ rate {:.2f}% ({}/{}), best_iter(avg) {:.1f}, best_iter(std) {:.1f}, '
                'changed(avg) {:.1f}, '
                'total pair {}, trial pair {}, '
                'succ total pair {}, succ trial pair {}, '
                'succ att id {}'.format(
                    agg.mean('iters'), agg.mean('time'), agg.mean('meta_time'),
                    agg.mean('succ') * 100, agg.sum('succ'), agg.size,
                    agg.mean('best_iter'), agg.std('best_iter'),
                    agg.mean('changed'),
                    agg.sum('meta_total_pair'), agg.sum('meta_trial_pair'),
                    agg.sum('meta_succ_total_pair'), agg.sum('meta_succ_trial_pair'),
                    agg.aggregated(key='att_id', reduce=np.nanmean)
                ))
            log()
Пример #8
0
    def white_hack(self, instance, sentence, tgt_span, src_span):
        """White-box attack: iteratively substitute words in ``src_span`` to
        degrade parsing accuracy (UAS) measured on ``tgt_span``.

        Runs up to ``hks_steps`` single_hack iterations, keeping the best
        (lowest-UAS) state in a CherryPicker. Succeeds as soon as the
        attacked UAS drops below the baseline.

        Returns a defaultdict (missing keys default to -1) summarizing the
        attack: succ flag, metrics, iteration counts, changes, time, and a
        log table.
        """
        words, tags, chars, arcs, rels = instance
        sent_len = words.size(1)

        raw_words = words.clone()
        var_words = words.clone()

        # Baseline metric on the unmodified sentence, evaluated only on
        # the target span (everything else is masked out).
        raw_metric = self.task.partial_evaluate(instance=(raw_words, tags, None, arcs, rels),
                                                mask_idxs=ex_span_idx(tgt_span, sent_len),
                                                mst=self.config.hks_mst == 'on')
        _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)])

        forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
        change_positions__ = set()
        # hks_max_change > 1 is an absolute word count; a fraction is
        # relative to sentence length (at least one word).
        if self.config.hks_max_change > 0.9999:
            max_change_num = int(self.config.hks_max_change)
        else:
            max_change_num = max(1, int(self.config.hks_max_change * words.size(1)))
        iter_change_num = min(max_change_num, self.config.hks_iter_change)

        picker = CherryPicker(lower_is_better=True)
        t0 = time.time()
        picker.add(raw_metric, {"num_changed": 0, "logtable": 'No modification'})
        log('iter -1, uas {:.4f}'.format(raw_metric.uas))
        succ = False
        # Fix: initialize iter_id so the summary below does not raise
        # NameError when hks_steps is 0 and the loop body never runs.
        iter_id = -1
        for iter_id in range(self.config.hks_steps):
            result = self.single_hack(instance=(var_words, tags, None, arcs, rels),
                                      raw_words=raw_words,
                                      raw_metric=raw_metric,
                                      raw_arcs=raw_arcs,
                                      src_span=src_span,
                                      tgt_span=tgt_span,
                                      iter_id=iter_id,
                                      forbidden_idxs__=forbidden_idxs__,
                                      change_positions__=change_positions__,
                                      max_change_num=max_change_num,
                                      iter_change_num=iter_change_num)
            if result['code'] == 200:
                var_words = result['words']
                picker.add(result['attack_metric'], {
                    'logtable': result['logtable'],
                    "num_changed": len(change_positions__)
                })
                # Small epsilon guards against float-equality noise.
                if result['attack_metric'].uas < raw_metric.uas - 0.00001:
                    succ = True
                    log('Succeed in step {}'.format(iter_id))
                    break
            elif result['code'] == 404:
                log('FAILED')
                break
        t1 = time.time()
        best_iter, best_attack_metric, best_info = picker.select_best_point()

        return defaultdict(
            lambda: -1, {
                "succ": 1 if succ else 0,
                "raw_metric": raw_metric,
                "attack_metric": best_attack_metric,
                "iters": iter_id,
                "best_iter": best_iter,
                "num_changed": best_info['num_changed'],
                "time": t1 - t0,
                "logtable": best_info['logtable']
            })
Пример #9
0
    def single_hack(self,
                    instance,
                    mask_idxs,
                    iter_id,
                    raw_words,
                    raw_metric,
                    raw_arcs,
                    forbidden_idxs__: list,
                    change_positions__: set,
                    max_change_num,
                    verbose=False):
        """One gradient-guided word-substitution step.

        Picks the ``max_change_num`` positions with the largest embedding
        gradient norm, replaces each word with a nearby embedding-space
        neighbour, and re-evaluates the parser on the modified sentence.
        ``forbidden_idxs__`` and ``change_positions__`` are mutated in
        place so constraints accumulate across iterations.

        Returns a dict: code 200, the new words tensor, attack metric and
        a log table ('Nothing' when UAS did not drop).
        """
        words, tags, chars, arcs, rels = instance
        sent_len = words.size(1)

        # Backward loss
        embed_grad = self.backward_loss(instance=instance, mask_idxs=mask_idxs)

        grad_norm = embed_grad.norm(dim=2)

        position_mask = [False for _ in range(words.size(1))]
        # Mask some positions
        # NOTE(review): only positions *inside* mask_idxs (and non-punct)
        # remain selectable — here mask_idxs seems to denote the editable
        # region; confirm against the caller.
        for i in range(sent_len):
            if i not in mask_idxs or rels[0][i].item(
            ) == self.vocab.rel_dict['punct']:
                position_mask[i] = True
        # Check if the number of changed words exceeds the max value
        if len(change_positions__) >= max_change_num:
            for i in range(sent_len):
                if i not in change_positions__:
                    position_mask[i] = True

        # Push masked positions far below zero so topk never picks them.
        for i in range(sent_len):
            if position_mask[i]:
                grad_norm[0][i] = -(grad_norm[0][i] + 1000)

        # print(grad_norm)

        # Select a word and forbid itself
        word_sids = []  # type: list[torch.Tensor]
        word_vids = []  # type: list[torch.Tensor]
        new_word_vids = []  # type: list[torch.Tensor]

        # NOTE(review): topk raises if max_change_num > sent_len — assumed
        # max_change_num <= sentence length; confirm in white_hack.
        _, topk_idxs = grad_norm[0].topk(max_change_num)
        for ele in topk_idxs:
            word_sids.append(ele)

        for word_sid in word_sids:
            word_grad = embed_grad[0][word_sid]
            word_vid = words[0][word_sid]
            emb_to_rpl = self.parser.embed.weight[word_vid]

            # Find a word to change
            # Step along the negative normalized gradient direction.
            delta = word_grad / \
                torch.norm(word_grad) * self.config.hko_step_size
            changed = emb_to_rpl - delta

            # must_tag = self.vocab.tags[tags[0][word_sid].item()]
            must_tag = None
            # must_tag = HACK_TAGS['njr']
            # must_tag = 'CD'
            forbidden_idxs__.append(word_vid)
            change_positions__.add(word_sid.item())
            new_word_vid, repl_info = self.find_replacement(
                changed,
                must_tag,
                dist_measure=self.config.hko_dist_measure,
                forbidden_idxs__=forbidden_idxs__,
                repl_method='tagdict',
                words=words,
                word_sid=word_sid)
            word_vids.append(word_vid)
            new_word_vids.append(new_word_vid)

        new_words = words.clone()
        for i in range(len(word_vids)):
            if new_word_vids[i] is not None:
                new_words[0][word_sids[i]] = new_word_vids[i]
        """
            Evaluating the result
        """
        # print('START EVALUATING')
        # print([self.vocab.words[ele] for ele in self.forbidden_idxs__])
        metric = self.task.partial_evaluate(instance=(new_words, tags, None,
                                                      arcs, rels),
                                            mask_idxs=mask_idxs)

        def _gen_log_table():
            # Build a per-token comparison table: raw vs attacked words,
            # gold/raw/attacked arcs, and gradient norms.
            new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
            raw_words_text = [self.vocab.words[i.item()] for i in raw_words[0]]
            tags_text = [self.vocab.tags[i.item()] for i in tags[0]]
            att_tags, att_arcs, att_rels = self.task.predict([(new_words, tags,
                                                               None)])

            table = []
            for i in range(sent_len):
                gold_arc = int(arcs[0][i])
                # Predicted arc lists exclude <ROOT>, hence the i-1 shift.
                raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
                att_arc = 0 if i == 0 else att_arcs[0][i - 1]
                mask_symbol = '&' if att_arc in mask_idxs or i in mask_idxs else ""

                table.append([
                    "{}{}".format(i, "@" if i in mask_idxs else ""),
                    raw_words_text[i], '>{}'.format(new_words_text[i])
                    if raw_words_text[i] != new_words_text[i] else "*",
                    tags_text[i], gold_arc, raw_arc, '>{}{}'.format(
                        att_arc, mask_symbol) if att_arc != raw_arc else '*',
                    grad_norm[0][i].item()
                ])
            return table

        if verbose:
            print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
            print('Iter {}'.format(iter_id))
            print(tabulate(_gen_log_table(), floatfmt=('.6f')))
            print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

        rpl_detail = ""
        for i in range(len(word_sids)):
            if new_word_vids[i] is not None:
                rpl_detail += "{}:{}->{}  ".format(
                    self.vocab.words[raw_words[0][word_sids[i]].item()],
                    self.vocab.words[word_vids[i].item()],
                    self.vocab.words[new_word_vids[i].item()])

        # NOTE(review): repl_info holds the values from the *last* loop
        # iteration only.
        log(
            "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
            "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'],
                                                  repl_info['avgd'])
            if 'mind' in repl_info else '', rpl_detail)
        if metric.uas >= raw_metric.uas - .00001:
            info = 'Nothing'
        else:
            info = tabulate(_gen_log_table(), floatfmt='.6f')
        return {
            'code': 200,
            'words': new_words,
            'attack_metric': metric,
            'logtable': info,
        }
Пример #10
0
    def single_hack(self,
                    words,
                    tags,
                    arcs,
                    rels,
                    raw_words,
                    raw_metric,
                    raw_arcs,
                    forbidden_idxs__,
                    change_positions__,
                    orphans__,
                    verbose=False,
                    max_change_num=1,
                    iter_change_num=1,
                    iter_id=-1):
        """One gradient-guided substitution step for the whole-sentence attack.

        Selects up to ``iter_change_num`` new positions (bounded overall by
        ``max_change_num``) by embedding-gradient norm, replaces each word
        by a same-POS neighbour in embedding space, and evaluates the
        modified sentence. ``forbidden_idxs__``, ``change_positions__`` and
        ``orphans__`` are mutated in place across iterations.

        Returns a dict with code 200 (new words, attack metric, log table)
        or code 404 with an ``info`` message when the step cannot proceed.
        """
        sent_len = words.size(1)
        """
            Loss back-propagation
        """
        embed_grad = self.backward_loss(
            words,
            tags,
            arcs,
            rels,
            loss_based_on=self.config.hkw_loss_based_on)
        # Sometimes the loss/grad will be zero.
        # Especially in the case of applying pgd_freq>1 to small sentences:
        # e.g., the uas of projected version may be 83.33%
        # while under the case of the unprojected version, the loss is 0.
        if torch.sum(embed_grad) == 0.0:
            return {"code": 404, "info": "Loss is zero."}
        """
            Select and change a word
        """
        grad_norm = embed_grad.norm(dim=2)

        position_mask = [False for _ in range(words.size(1))]
        # Mask some positions by POS & <UNK>
        for i in range(sent_len):
            if self.vocab.tags[tags[0][i]] not in HACK_TAGS[
                    self.config.hkw_tag_type]:
                position_mask[i] = True
        # Mask some orphans
        for i in orphans__:
            position_mask[i] = True
        # Check if the number of changed words exceeds the max value
        if len(change_positions__) >= max_change_num:
            for i in range(sent_len):
                if i not in change_positions__:
                    position_mask[i] = True
        if all(position_mask):
            return {
                "code": 404,
                "info": "Constrained by tags, no valid word to replace."
            }

        # Push masked positions far below zero so sorting never picks them.
        for i in range(sent_len):
            if position_mask[i]:
                grad_norm[0][i] = -(grad_norm[0][i] + 1000)

        # Select a word and forbid itself
        word_sids = []  # type: list[torch.Tensor]
        word_vids = []  # type: list[torch.Tensor]
        new_word_vids = []  # type: list[torch.Tensor]

        _, topk_idxs = grad_norm[0].sort(descending=True)
        # "elder" selection: previously-changed positions keep priority.
        selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                      num_to_select=iter_change_num,
                                      selected=change_positions__,
                                      max_num=max_change_num)
        # The position mask will ensure that at least one word is legal,
        # but the second one may not be allowed
        selected_words = [
            ele for ele in selected_words if position_mask[ele] is False
        ]
        word_sids = torch.tensor(selected_words)

        for word_sid in word_sids:
            word_grad = embed_grad[0][word_sid]

            word_vid = words[0][word_sid]
            emb_to_rpl = self.parser.embed.weight[word_vid]
            forbidden_idxs__.append(word_vid.item())
            change_positions__.add(word_sid.item())
            # print(self.change_positions__)

            # Find a word to change with dynamically step
            # Note that it is possible that all words found are not as required, e.g.
            #   all neighbours have different tags.
            delta = word_grad / \
                torch.norm(word_grad) * self.config.hkw_step_size
            changed = emb_to_rpl - delta

            # Replacement must keep the original POS tag.
            must_tag = self.vocab.tags[tags[0][word_sid].item()]
            new_word_vid, repl_info = self.find_replacement(
                changed,
                must_tag,
                dist_measure=self.config.hkw_dist_measure,
                forbidden_idxs__=forbidden_idxs__,
                repl_method=self.config.hkw_repl_method,
                words=words,
                word_sid=word_sid,
                raw_words=raw_words)

            word_vids.append(word_vid)
            new_word_vids.append(new_word_vid)

        new_words = words.clone()
        exist_change = False
        for i in range(len(word_vids)):
            if new_word_vids[i] is not None:
                new_words[0][word_sids[i]] = new_word_vids[i]
                exist_change = True

        if not exist_change:
            # if self.config.hkw_selection == 'orphan':
            #     orphans__.add(word_sids[0])
            #     log('iter {}, Add word {}\'s location to orphans.'.format(
            #         iter_id,
            #         self.vocab.words[raw_words[0][word_sid].item()]))
            #     return {
            #         'code': '200',
            #         'words': words,
            #         'atack_metric': 100.,
            #         'logtable': 'This will be never selected'
            #     }
            log('Attack failed.')
            return {
                'code': 404,
                'info': 'Neighbours of all selected words have different tags.'
            }

        # if new_word_vid is None:
        #     log('Attack failed.')
        #     return {'code': 404,
        #             'info': 'Neighbours of the selected words have different tags.'}

        # new_words = words.clone()
        # new_words[0][word_sid] = new_word_vid
        """
            Evaluating the result
        """
        # print('START EVALUATING')
        # print([self.vocab.words[ele] for ele in self.forbidden_idxs__])
        new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
        # print(new_words_txt)
        loss, metric = self.task.evaluate(
            [(new_words, tags, None, arcs, rels)],
            mst=self.config.hkw_mst == 'on')

        def _gen_log_table():
            # Build a per-token comparison table: raw vs attacked words,
            # gold/raw/attacked arcs, and gradient norms.
            new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
            raw_words_text = [self.vocab.words[i.item()] for i in raw_words[0]]
            tags_text = [self.vocab.tags[i.item()] for i in tags[0]]
            _, att_arcs, _ = self.task.predict([(new_words, tags, None)],
                                               mst=self.config.hkw_mst == 'on')

            table = []
            for i in range(sent_len):
                gold_arc = int(arcs[0][i])
                # Predicted arc lists exclude <ROOT>, hence the i-1 shift.
                raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
                att_arc = 0 if i == 0 else att_arcs[0][i - 1]

                relevant_mask = '&' if \
                    raw_words[0][att_arc] != new_words[0][att_arc] or \
                    raw_words_text[i] != new_words_text[i] else ""
                table.append([
                    i, raw_words_text[i], '>{}'.format(new_words_text[i])
                    if raw_words_text[i] != new_words_text[i] else "*",
                    tags_text[i], gold_arc, raw_arc, '>{}{}'.format(
                        att_arc, relevant_mask) if att_arc != raw_arc else '*',
                    grad_norm[0][i].item()
                ])
            return table

        if verbose:
            print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
            print('Iter {}'.format(iter_id))
            print(tabulate(_gen_log_table(), floatfmt=('.6f')))
            print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

        rpl_detail = ""
        for i in range(len(word_sids)):
            if new_word_vids[i] is not None:
                rpl_detail += "{}:{}->{}  ".format(
                    self.vocab.words[raw_words[0][word_sids[i]].item()],
                    self.vocab.words[word_vids[i].item()],
                    self.vocab.words[new_word_vids[i].item()])
        # NOTE(review): repl_info holds the values from the *last* loop
        # iteration only.
        log(
            "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
            "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'],
                                                  repl_info['avgd'])
            if 'mind' in repl_info else '', rpl_detail)
        if metric.uas >= raw_metric.uas - .00001:
            logtable = 'Nothing'
        else:
            logtable = tabulate(_gen_log_table(), floatfmt='.6f')
        return {
            'code': 200,
            'words': new_words,
            'attack_metric': metric,
            'logtable': logtable,
            # "forbidden_idxs__": forbidden_idxs__,
            # "change_positions__": change_positions__,
        }
Пример #11
0
    def single_hack(self,
                    words, tags, chars, arcs, rels,
                    raw_chars, raw_metric, raw_arcs,
                    forbidden_idxs__,
                    change_positions__,
                    verbose=False,
                    max_change_num=1,
                    iter_change_num=1,
                    iter_id=-1):
        """Run one iteration of the character-level (hkc) attack.

        Back-propagates the parsing loss to obtain word- and character-level
        embedding gradients, selects the word positions with the largest
        gradient norms (subject to the change budget), replaces one character
        of each selected word with a near neighbour in char-embedding space,
        and re-evaluates the parser on the modified characters.

        Args:
            words, tags, chars, arcs, rels: instance tensors, batch size 1.
            raw_chars: untouched char ids of the sentence, used for rollback
                and for logging replacement details.
            raw_metric: parser metric on the unmodified sentence.
            raw_arcs: predicted arcs on the unmodified sentence.
            forbidden_idxs__: mapping word_sid -> char vocab ids already tried
                at that position; mutated in place.
            change_positions__: mapping word_sid -> char_wid changed at that
                position; mutated in place.
            verbose: print a per-token diagnostic table when True.
            max_change_num: maximum number of word positions to change.
            iter_change_num: number of positions changed in this iteration.
            iter_id: iteration counter, used only for logging.

        Returns:
            {'code': 404, 'info': ...} when no position may legally change,
            otherwise {'code': 200, 'chars': new_chars,
            'attack_metric': metric, 'logtable': ...}.
        """
        sent_len = words.size(1)

        # --- Loss back-propagation ---
        char_grads, word_grads = self.backward_loss(words, chars, arcs, rels)

        # --- Select and change a word ---
        word_grad_norm = word_grads.norm(dim=-1)   # 1 x length
        # 1 x length x max_char_length
        char_grad_norm = char_grads.norm(dim=-1)

        position_mask = [False for _ in range(words.size(1))]
        # Mask some positions by POS & <UNK>
        for i in range(sent_len):
            if rels[0][i].item() == self.vocab.rel_dict['punct']:
                position_mask[i] = True
        # Check if the number of changed words exceeds the max value
        if len(change_positions__) >= max_change_num:
            for i in range(sent_len):
                if i not in change_positions__:
                    position_mask[i] = True
        if all(position_mask):
            return {"code": 404, "info": "Constrained by tags, no valid word to replace."}

        # Push masked positions to the bottom of the gradient ranking so the
        # sort below never selects them.
        for i in range(sent_len):
            if position_mask[i]:
                word_grad_norm[0][i] = -(word_grad_norm[0][i] + 1000)
                char_grad_norm[0][i] = -(char_grad_norm[0][i] + 1000)

        # Select a word and forbid itself
        word_sids = []
        char_wids = []
        char_vids = []
        new_char_vids = []

        if self.config.hkc_selection == 'elder':
            # 'elder': previously changed positions are kept as-is.
            _, topk_idxs = word_grad_norm[0].sort(descending=True)
            selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                          num_to_select=iter_change_num,
                                          selected=change_positions__,
                                          max_num=max_change_num)
            selected_chars = []
            for ele in selected_words:
                if ele in change_positions__:
                    selected_chars.append(change_positions__[ele])
                else:
                    wcid = char_grad_norm[0][ele].argmax()
                    selected_chars.append(wcid)
            word_sids = torch.tensor(selected_words)
            char_wids = torch.tensor(selected_chars)
        elif self.config.hkc_selection == 'young':
            # 'young': lower-ranked old changes may be rolled back in favour
            # of newly selected positions.
            _, topk_idxs = word_grad_norm[0].sort(descending=True)
            selected_words, ex_words = young_select(ordered_idxs=cast_list(topk_idxs),
                                                    num_to_select=iter_change_num,
                                                    selected=change_positions__,
                                                    max_num=max_change_num)
            for ele in ex_words:
                # Roll back a previously changed position.
                change_positions__.pop(ele)
                forbidden_idxs__.pop(ele)
                chars[0][ele] = raw_chars[0][ele]
                # BUG FIX: this log line used ``words[i]`` with a stale index
                # ``i`` left over from the earlier ``range(sent_len)`` loop;
                # report the word actually being restored instead.
                log('Drop elder replacement',
                    self.vocab.words[words[0][ele].item()])
            selected_chars = []
            for ele in selected_words:
                if ele in change_positions__:
                    selected_chars.append(change_positions__[ele])
                else:
                    wcid = char_grad_norm[0][ele].argmax()
                    selected_chars.append(wcid)
            word_sids = torch.tensor(selected_words)
            char_wids = torch.tensor(selected_chars)
        else:
            raise Exception

        for word_sid, char_wid in zip(word_sids, char_wids):

            char_grad = char_grads[0][word_sid][char_wid]
            char_vid = chars[0][word_sid][char_wid]

            emb_to_rpl = self.parser.char_lstm.embed.weight[char_vid]
            forbidden_idxs__[word_sid.item()].append(char_vid.item())
            change_positions__[word_sid.item()] = char_wid.item()

            # Find a word to change with dynamically step
            # Note that it is possible that all words found are not as required, e.g.
            #   all neighbours have different tags.
            delta = char_grad / \
                torch.norm(char_grad) * self.config.hkc_step_size
            changed = emb_to_rpl - delta

            dist = {'euc': euc_dist, 'cos': cos_dist}[
                self.config.hkc_dist_measure](changed, self.parser.char_lstm.embed.weight)
            vals, idxs = dist.sort()
            # NOTE(review): if every candidate id is forbidden, no ``break``
            # fires and ``new_char_vid`` keeps its value from the previous
            # loop pass (or is unbound on the first) — confirm this cannot
            # occur given deque(maxlen=5) histories.
            for ele in idxs:
                if ele.item() not in forbidden_idxs__[word_sid.item()] and \
                        ele.item() not in [self.vocab.pad_index, self.vocab.unk_index]:
                    new_char_vid = ele
                    break

            char_vids.append(char_vid)
            new_char_vids.append(new_char_vid)

        new_chars = chars.clone()
        for i in range(len(word_sids)):
            new_chars[0][word_sids[i]][char_wids[i]] = new_char_vids[i]

        # --- Evaluating the result ---
        loss, metric = self.task.evaluate(
            [(words, tags, new_chars, arcs, rels)],
            mst=self.config.hkc_mst == 'on')

        def _gen_log_table():
            # Per-token table comparing raw vs. attacked chars and arcs.
            new_words_text = [self.vocab.id2char(ele) for ele in new_chars[0]]
            raw_words_text = [self.vocab.id2char(ele) for ele in raw_chars[0]]
            tags_text = [self.vocab.tags[ele.item()] for ele in tags[0]]
            _, att_arcs, _ = self.task.predict([(words, tags, new_chars)],
                                               mst=self.config.hkc_mst == 'on')

            table = []
            for i in range(sent_len):
                gold_arc = int(arcs[0][i])
                raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
                att_arc = 0 if i == 0 else att_arcs[0][i - 1]

                # '&' marks arcs whose head or dependent token was modified.
                relevant_mask = '&' if \
                    raw_words_text[att_arc] != new_words_text[att_arc] \
                    or raw_words_text[i] != new_words_text[i] \
                    else ""
                table.append([
                    i,
                    raw_words_text[i],
                    '>{}'.format(
                        new_words_text[i]) if raw_words_text[i] != new_words_text[i] else "*",
                    tags_text[i],
                    gold_arc,
                    raw_arc,
                    '>{}{}'.format(
                        att_arc, relevant_mask) if att_arc != raw_arc else '*',
                    word_grad_norm[0][i].item()
                ])
            return table

        if verbose:
            print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
            print('Iter {}'.format(iter_id))
            print(tabulate(_gen_log_table(), floatfmt=('.6f')))
            print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

        rpl_detail = ""
        for i in range(len(word_sids)):
            rpl_detail += "{}:{}->{}  ".format(
                self.vocab.id2char(raw_chars[0, word_sids[i]]),
                self.vocab.id2char(chars[0, word_sids[i]]),
                self.vocab.id2char(new_chars[0, word_sids[i]]))
        log("iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
            rpl_detail
            )
        if metric.uas >= raw_metric.uas - .00001:
            logtable = 'Nothing'
        else:
            logtable = tabulate(_gen_log_table(), floatfmt='.6f')
        return {
            'code': 200,
            'chars': new_chars,
            'attack_metric': metric,
            'logtable': logtable,
        }
Пример #12
0
    def hack(self, instance):
        """Attack one sentence with the character-level (hkc) attack.

        Repeatedly applies ``single_hack`` (up to ``hkc_steps - 1`` rounds),
        keeping the best (lowest-UAS) modification seen so far, and stops
        early on success or when no legal change remains.

        Returns a summary dict with the raw and attacked metrics, iteration
        counts, number of changed positions, wall time and a log table.
        """
        words, tags, chars, arcs, rels = instance
        use_mst = self.config.hkc_mst == 'on'
        _, raw_metric = self.task.evaluate([(words, tags, chars, arcs, rels)],
                                           mst=use_mst)
        _, raw_arcs, _ = self.task.predict([(words, tags, chars)], mst=use_mst)

        # Per-position history of char ids already tried (word_sid -> deque)
        forbidden_idxs__ = defaultdict(lambda: deque(maxlen=5))
        # word_sid -> char_wid of the replacement made at that position
        change_positions__ = dict()

        # hkc_max_change is an absolute count when >= 1, otherwise a ratio
        # of the sentence length.
        budget = self.config.hkc_max_change
        if budget > 0.9999:
            max_change_num = int(budget)
        else:
            max_change_num = max(1, int(budget * words.size(1)))
        iter_change_num = min(max_change_num, self.config.hkc_iter_change)

        raw_chars = chars.clone()
        var_chars = chars.clone()

        # --- Attack iterations ---
        start = time.time()
        picker = CherryPicker(lower_is_better=True,
                              compare_fn=lambda a, b: a.uas - b.uas)
        # Iteration 0 corresponds to the unmodified sentence.
        picker.add(raw_metric, {
            "num_changed": 0,
            "logtable": 'No modification'
        })
        for iter_id in range(1, self.config.hkc_steps):
            result = self.single_hack(
                words, tags, var_chars, arcs, rels,
                raw_chars=raw_chars, raw_metric=raw_metric, raw_arcs=raw_arcs,
                verbose=False,
                max_change_num=max_change_num,
                iter_change_num=iter_change_num,
                iter_id=iter_id,
                forbidden_idxs__=forbidden_idxs__,
                change_positions__=change_positions__,
            )
            if result['code'] == 404:
                # No legal position left to modify.
                log('Stop in step {}, info: {}'.format(iter_id, result['info']))
                break
            if result['code'] == 200:
                picker.add(result['attack_metric'], {
                    "logtable": result['logtable'],
                    "num_changed": len(change_positions__)
                })
                if result['attack_metric'].uas < raw_metric.uas - self.config.hkw_eps:
                    log('Succeed in step {}'.format(iter_id))
                    break
            var_chars = result['chars']
        elapsed = time.time() - start

        best_iter, best_attack_metric, best_info = picker.select_best_point()

        return {
            "raw_metric": raw_metric,
            "attack_metric": best_attack_metric,
            "iters": iter_id,
            "best_iter": best_iter,
            "num_changed": best_info['num_changed'],
            "time": elapsed,
            "logtable": best_info['logtable']
        }
Пример #13
0
    def __call__(self, config):
        """Run the gradient-guided word-substitution attack over a corpus.

        For each sentence in the loader, repeatedly calls ``single_hack``
        (up to 100 PGD-style steps) until it returns code 200 (attack
        succeeded), 404 (no attackable word), or the step budget runs out
        while it keeps returning 300 (changed, but UAS not yet dropped).
        Aggregated raw/attacked metrics are logged after each sentence.
        """
        _, loader = self.pre_attack(config)
        # Hook captures embedding gradients into self.vals['embed_grad'].
        self.parser.register_backward_hook(self.extract_embed_grad)
        # log_config('whitelog.txt',
        #            log_path=config.workspace,
        #            default_target='cf')

        train_corpus = Corpus.load(config.ftrain)
        # POS-tag filter built from the training corpus: restricts candidate
        # replacement words to ids observed with a required tag.
        self.tag_filter = generate_tag_filter(train_corpus, self.vocab)
        # corpus = Corpus.load(config.fdata)
        # dataset = TextDataset(self.vocab.numericalize(corpus, True))
        # # set the data loader
        # loader = DataLoader(dataset=dataset,
        #                     collate_fn=collate_fn)

        # def embed_hook(module, grad_in, grad_out):
        #     self.vals["embed_grad"] = grad_out[0]

        # dpattack.pretrained.register_backward_hook(embed_hook)

        raw_metrics = Metric()
        attack_metrics = Metric()

        log('dist measure', config.hk_dist_measure)

        # batch size == 1
        for sid, (words, tags, chars, arcs, rels) in enumerate(loader):
            # if sid > 10:
            #     break

            raw_words = words.clone()
            words_text = self.get_seqs_name(words)
            tags_text = self.get_tags_name(tags)

            log('****** {}: \n\t{}\n\t{}'.format(sid, " ".join(words_text),
                                                 " ".join(tags_text)))

            # Word ids that must never be chosen as replacements.
            self.vals['forbidden'] = [
                self.vocab.unk_index, self.vocab.pad_index
            ]
            for pgdid in range(100):
                result = self.single_hack(words,
                                          tags,
                                          arcs,
                                          rels,
                                          dist_measure=config.hk_dist_measure,
                                          raw_words=raw_words)
                # code 200: attack succeeded — UAS dropped enough.
                if result['code'] == 200:
                    raw_metrics += result['raw_metric']
                    attack_metrics += result['attack_metric']
                    log('attack successfully at step {}'.format(pgdid))
                    break
                # code 404: no attackable word left — count as failure.
                elif result['code'] == 404:
                    raw_metrics += result['raw_metric']
                    attack_metrics += result['raw_metric']
                    log('attack failed at step {}'.format(pgdid))
                    break
                # code 300: a word was changed but UAS held — keep iterating
                # with the modified sentence (fail on the final step).
                elif result['code'] == 300:
                    if pgdid == 99:
                        raw_metrics += result['raw_metric']
                        attack_metrics += result['raw_metric']
                        log('attack failed at step {}'.format(pgdid))
                    else:
                        words = result['words']
            log()

            log('Aggregated result: {} --> {}'.format(raw_metrics,
                                                      attack_metrics),
                target='cf')
Пример #14
0
    def single_hack(self,
                    words,
                    tags,
                    arcs,
                    rels,
                    raw_words,
                    target_tags=['NN', 'JJ'],
                    dist_measure='enc',
                    verbose=False):
        # NOTE(review): ``target_tags`` is a mutable default argument; it is
        # only read here, but consider a tuple. Also, the default
        # ``dist_measure='enc'`` is not a key of the dist dict below ('euc' /
        # 'cos'), so the default would raise KeyError — callers seem to
        # always pass config.hk_dist_measure; confirm.
        """Perform one gradient step of the word-substitution attack.

        Back-propagates the arc loss, selects the word (restricted to
        ``target_tags`` POS) with the largest embedding-gradient norm,
        and replaces it with the closest same-tag word in embedding space
        (excluding forbidden ids). Returns a result dict:
        404 = no attackable word / all candidates forbidden,
        300 = replaced but UAS did not drop, 200 = attack succeeded.
        """
        vocab = self.vocab
        parser = self.parser
        loss_fn = self.task.criterion
        embed = parser.embed.weight

        raw_loss, raw_metric = self.task.evaluate([(raw_words, tags, arcs,
                                                    rels)])

        # backward the loss; the registered hook stores the embedding
        # gradient in self.vals['embed_grad'].
        parser.zero_grad()
        mask = words.ne(vocab.pad_index)
        mask[:, 0] = 0  # exclude <ROOT> from the loss
        s_arc, s_rel = parser(words, tags)
        s_arc, s_rel = s_arc[mask], s_rel[mask]
        gold_arcs, gold_rels = arcs[mask], rels[mask]
        loss = loss_fn(s_arc, gold_arcs)
        loss.backward()

        grad_norm = self.vals['embed_grad'].norm(dim=2)

        # Select a word to attack by its POS and norm
        exist_target_tag = False
        for i in range(tags.size(1)):
            if vocab.tags[tags[0][i]] not in target_tags:
                # Push non-target-POS positions far below any real norm.
                grad_norm[0][i] -= 99999.
            else:
                exist_target_tag = True
        if not exist_target_tag:
            return {"code": 404, 'raw_metric': raw_metric}
        word_sid = grad_norm[0].argmax()
        word_vid = words[0][word_sid]
        max_grad = self.vals['embed_grad'][0][word_sid]

        # Forbid the word itself to be selected
        self.vals['forbidden'].append(word_vid.item())

        # Find a word to change: step against the gradient in embed space.
        changed = embed[word_vid] - max_grad * 1
        # NOTE(review): 'euc' yields a distance (argmin = nearest word),
        # while 'cos' yields a *similarity* (argmin = least similar word).
        # The two measures therefore select opposite extremes — confirm
        # whether 'cos' should be 1 - similarity.
        dist = {
            'euc':
            lambda: (changed - embed).pow(2).sum(dim=1),
            'cos':
            lambda: torch.nn.functional.cosine_similarity(
                embed, changed.repeat(embed.size(0), 1), dim=1)
        }[dist_measure]()

        # Restrict candidates to words seen with the same POS tag.
        must_tag = vocab.tags[tags[0][word_sid].item()]
        legal_tag_index = self.tag_filter[must_tag].to(dist.device)
        # NOTE(review): byte masks and ``1 - mask`` are deprecated in newer
        # PyTorch (use bool masks and ``~mask``); behavior kept as-is here.
        legal_tag_mask = dist.new_zeros(dist.size()) \
            .index_fill_(0, legal_tag_index, 1.).byte()

        dist.masked_fill_(1 - legal_tag_mask, 99999.)

        for ele in self.vals['forbidden']:
            dist[ele] = 99999.
        word_vid_to_rpl = dist.argmin()
        # A forbidden word maybe chosen if all words are forbidden(99999.)
        if word_vid_to_rpl.item() in self.vals['forbidden']:
            log('Attack failed.')
            return {'code': 404, 'raw_metric': raw_metric}

        # =====================
        # Evaluating the result
        # =====================
        repl_words = words.clone()
        repl_words[0][word_sid] = word_vid_to_rpl

        repl_words_text = [vocab.words[i.item()] for i in repl_words[0]]
        raw_words_text = [vocab.words[i.item()] for i in raw_words[0]]
        tags_text = [vocab.tags[i.item()] for i in tags[0]]
        if verbose:
            print('After Attacking: \n\t{}\n\t{}'.format(
                " ".join(repl_words_text), " ".join(tags_text)))
        pred_tags, pred_arcs, pred_rels = self.task.predict([(repl_words, tags)
                                                             ])
        loss, metric = self.task.evaluate([(repl_words, tags, arcs, rels)])
        # Per-token diff table: replaced word, original word, arcs, grad norm.
        table = []
        for i in range(words.size(1)):
            gold_arc = int(arcs[0][i])
            pred_arc = 0 if i == 0 else pred_arcs[0][i - 1]
            table.append([
                i, repl_words_text[i], raw_words_text[i] if
                raw_words_text[i] != repl_words_text[i] else "*", tags_text[i],
                gold_arc, pred_arc if pred_arc != gold_arc else '*',
                grad_norm[0][i].item()
            ])

        if verbose:
            print('{} --> {}'.format(vocab.words[word_vid.item()],
                                     vocab.words[word_vid_to_rpl.item()]))
            print(tabulate(table, floatfmt=('.6f')))
            print(metric)
            print('**************************')
        # UAS barely moved: report code 300 so the caller keeps iterating.
        if metric.uas > raw_metric.uas - 0.1:
            return {
                'code': 300,
                'words': repl_words,
                'raw_metric': raw_metric,
                'attack_metric': metric
            }
        else:
            log(tabulate(table, floatfmt=('.6f')))
            log('Result {} --> {}'.format(raw_metric.uas, metric.uas),
                target='cf')
            return {
                'code': 200,
                'words': repl_words,
                'raw_metric': raw_metric,
                'attack_metric': metric
            }
Пример #15
0
    def __call__(self, config):
        """Entry point of the span-based (hks) attack.

        Sets up logging and the model, runs ``hack`` on every sentence of
        the loader, and reports aggregated raw/attacked metrics together
        with iteration/time/failure statistics.
        """
        if config.logf == 'on':
            log_config('hackoutside',
                       log_path=config.workspace,
                       default_target='cf')
            from dpattack.libs.luna import log
        else:
            log = print

        log('[General Settings]')
        log(config)
        log('[Hack Settings]')
        for key in config.kwargs:
            if key.startswith('hks'):
                log(key, '\t', config.kwargs[key])
        log('------------------')

        self.setup(config)

        raw_metrics = ParserMetric()
        attack_metrics = ParserMetric()

        agg = Aggregator()
        for sid, (words, tags, chars, arcs, rels) in enumerate(self.loader):
            words_text = self.vocab.id2word(words[0])
            tags_text = self.vocab.id2tag(tags[0])
            log('****** {}: \n{}\n{}'.format(sid, " ".join(words_text),
                                             " ".join(tags_text)))

            result = self.hack(instance=(words, tags, chars, arcs, rels),
                               sentence=self.corpus[sid])

            # A None result means no attackable span pair was found.
            if result is None:
                continue

            raw_metrics += result['raw_metric']
            attack_metrics += result['attack_metric']

            # An attack counts as failed when UAS is essentially unchanged.
            attack_failed = abs(result['attack_metric'].uas -
                                result['raw_metric'].uas) < 1e-4
            agg.aggregate(("iters", result['iters']),
                          ("time", result['time']),
                          ("fail", attack_failed),
                          ('best_iter', result['best_iter']),
                          ("changed", result['num_changed']))

            # WARNING: SOME SENTENCE NOT SHOWN!
            if result:
                log('Show result from iter {}:'.format(result['best_iter']))
                log(result['logtable'])

            log('Aggregated result: {} --> {}, '
                'iters(avg) {:.1f}, time(avg) {:.1f}s, '
                'fail rate {:.2f}, best_iter(avg) {:.1f}, best_iter(std) {:.1f}, '
                'changed(avg) {:.1f}'.format(raw_metrics, attack_metrics,
                                             agg.mean('iters'),
                                             agg.mean('time'),
                                             agg.mean('fail'),
                                             agg.mean('best_iter'),
                                             agg.std('best_iter'),
                                             agg.mean('changed')))
            log()
Пример #16
0
    def single_hack(self, instance,
                    src_span, tgt_span,
                    iter_id,
                    raw_words, raw_metric, raw_arcs,
                    forbidden_idxs__: list,
                    change_positions__: set,
                    max_change_num,
                    iter_change_num,
                    verbose=False):
        """Run one iteration of the span-constrained (hks) word attack.

        Words inside ``src_span`` are replaced, guided by the gradient of
        the loss restricted to ``tgt_span``, with the aim of flipping arcs
        inside the target span; the parser is then re-evaluated on the
        modified sentence.

        Args:
            instance: (words, tags, chars, arcs, rels) tensors, batch 1.
            src_span: (start, end) span whose words may be replaced.
            tgt_span: (start, end) span whose arcs the attack targets.
            iter_id: iteration counter, used only for logging.
            raw_words, raw_metric, raw_arcs: the unmodified sentence, its
                metric, and its predicted arcs.
            forbidden_idxs__: word vocab ids banned as replacements;
                mutated in place.
            change_positions__: sentence positions already changed;
                mutated in place.
            max_change_num: maximum number of positions to change.
            iter_change_num: number of positions changed per iteration.
            verbose: unused here; kept for interface compatibility.

        Returns:
            {'code': 200, 'words': new_words, 'attack_metric': metric,
             'logtable': ...}.
        """
        # yapf: enable
        words, tags, chars, arcs, rels = instance
        sent_len = words.size(1)

        # Backward loss; arcs outside the target span are masked out.
        embed_grad = self.backward_loss(instance=instance,
                                        mask_idxs=ex_span_idx(tgt_span, sent_len),
                                        verbose=True)
        grad_norm = embed_grad.norm(dim=2)
        if self.config.hks_word_random == "on":
            # Ablation mode: replace gradient saliency with random scores.
            grad_norm = torch.rand(grad_norm.size(), device=grad_norm.device)

        position_mask = [False for _ in range(words.size(1))]
        # Mask punctuation and every position outside the source span.
        for i in range(sent_len):
            if rels[0][i].item(
            ) == self.vocab.rel_dict['punct'] or not src_span[0] <= i <= src_span[1]:
                position_mask[i] = True
        # Check if the number of changed words exceeds the max value
        if len(change_positions__) >= max_change_num:
            for i in range(sent_len):
                if i not in change_positions__:
                    position_mask[i] = True

        # Push masked positions to the bottom of the gradient ranking.
        for i in range(sent_len):
            if position_mask[i]:
                grad_norm[0][i] = -(grad_norm[0][i] + 1000)

        # Select a word and forbid itself
        word_sids = []  # type: list[torch.Tensor]
        word_vids = []  # type: list[torch.Tensor]
        new_word_vids = []  # type: list[torch.Tensor]

        _, topk_idxs = grad_norm[0].sort(descending=True)
        selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                      num_to_select=iter_change_num,
                                      selected=change_positions__,
                                      max_num=max_change_num)
        # The position mask will ensure that at least one word is legal,
        # but the second one may not be allowed
        selected_words = [ele for ele in selected_words if position_mask[ele] is False]
        word_sids = torch.tensor(selected_words)

        # BUG FIX: ``repl_info`` is read after the loop (in the log call);
        # give it a safe default so an empty selection no longer raises
        # NameError.
        repl_info = {}
        for word_sid in word_sids:
            word_vid = words[0][word_sid]
            emb_to_rpl = self.parser.embed.weight[word_vid]

            if self.config.hks_step_size > 0:
                # Gradient-guided replacement: step against the gradient in
                # embedding space, then pick the nearest legal word.
                word_grad = embed_grad[0][word_sid]
                delta = word_grad / \
                    torch.norm(word_grad) * self.config.hks_step_size
                changed = emb_to_rpl - delta

                # POS constraint on the candidate replacement.
                tag_type = self.config.hks_constraint
                if tag_type == 'any':
                    must_tag = None
                elif tag_type == 'same':
                    must_tag = self.vocab.tags[tags[0][word_sid].item()]
                elif re.match("[njvri]+", tag_type):
                    must_tag = HACK_TAGS[tag_type]
                else:
                    raise Exception
                forbidden_idxs__.append(word_vid)
                change_positions__.add(word_sid.item())
                new_word_vid, repl_info = self.find_replacement(
                    changed,
                    must_tag,
                    dist_measure=self.config.hks_dist_measure,
                    forbidden_idxs__=forbidden_idxs__,
                    repl_method='tagdict',
                    words=words,
                    word_sid=word_sid)
            else:
                # Random-replacement baseline.
                # BUG FIX: random.randint is inclusive on both ends, so the
                # upper bound must be ``n_words - 1`` (as the retry loop
                # below already uses); the original first draw could yield
                # an out-of-vocabulary id.
                new_word_vid = random.randint(0, self.vocab.n_words - 1)
                while new_word_vid in [self.vocab.pad_index, self.vocab.word_dict["<root>"]]:
                    new_word_vid = random.randint(0, self.vocab.n_words - 1)
                new_word_vid = torch.tensor(new_word_vid, device=words.device)
                repl_info = {}
                change_positions__.add(word_sid.item())

            word_vids.append(word_vid)
            new_word_vids.append(new_word_vid)

        new_words = words.clone()
        for i in range(len(word_vids)):
            if new_word_vids[i] is not None:
                new_words[0][word_sids[i]] = new_word_vids[i]

        # --- Evaluating the result (only arcs inside the target span) ---
        metric = self.task.partial_evaluate(instance=(new_words, tags, None, arcs, rels),
                                            mask_idxs=ex_span_idx(tgt_span, sent_len),
                                            mst=self.config.hks_mst == 'on')
        att_tags, att_arcs, att_rels = self.task.predict([(new_words, tags, None)],
                                                         mst=self.config.hks_mst == 'on')

        rpl_detail = ""
        for i in range(len(word_sids)):
            if new_word_vids[i] is not None:
                rpl_detail += "{}:{}->{}  ".format(
                    self.vocab.words[raw_words[0][word_sids[i]].item()],
                    self.vocab.words[word_vids[i].item()],
                    self.vocab.words[new_word_vids[i].item()])

        log(
            "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
            "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'], repl_info['avgd'])
            if 'mind' in repl_info else '', rpl_detail)
        if metric.uas >= raw_metric.uas - .00001:
            info = 'Nothing'
        else:
            info = tabulate(self._gen_log_table(raw_words, new_words, tags, arcs, rels, raw_arcs,
                                                att_arcs, src_span, tgt_span),
                            floatfmt='.6f')
        return {
            'code': 200,
            'words': new_words,
            'attack_metric': metric,
            'logtable': info,
        }
Пример #17
0
    def hack(self, instance):
        """Attack one sentence with the word-substitution (hkw) attack.

        Iteratively calls ``single_hack`` (up to ``hkw_steps - 1`` rounds),
        tracking the best (lowest-UAS) modification with a CherryPicker,
        and stops early on success (UAS drop > hkw_eps) or failure (404).

        Returns a summary dict with the raw and attacked metrics, iteration
        counts, number of changed positions, wall time, and a log table.
        """
        words, tags, chars, arcs, rels = instance
        _, raw_metric = self.task.evaluate([(words, tags, chars, arcs, rels)],
                                           mst=self.config.hkw_mst == 'on')
        _, raw_arcs, _ = self.task.predict([(words, tags, None)],
                                           mst=self.config.hkw_mst == 'on')

        # Setup some states before attacking a sentence
        # WARNING: Operations on variables n__global_ramambc__" passed to a function
        #  are in-placed! Never try to save an internal state of the variable.
        forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
        change_positions__ = set()
        orphans__ = set()
        # hkw_max_change is an absolute count when >= 1, otherwise a ratio
        # of the sentence length.
        if self.config.hkw_max_change > 0.9999:
            max_change_num = int(self.config.hkw_max_change)
        else:
            max_change_num = max(
                1, int(self.config.hkw_max_change * words.size(1)))
        iter_change_num = min(max_change_num, self.config.hkw_iter_change)

        var_words = words.clone()
        raw_words = words.clone()

        # HIGHLIGHT: ITERATION
        t0 = time.time()
        picker = CherryPicker(lower_is_better=True,
                              compare_fn=lambda m1, m2: m1.uas - m2.uas)
        # iter 0 -> raw
        picker.add(raw_metric, {
            "num_changed": 0,
            "logtable": 'No modification'
        })
        for iter_id in range(1, self.config.hkw_steps):
            result = self.single_hack(var_words,
                                      tags,
                                      arcs,
                                      rels,
                                      raw_words=raw_words,
                                      raw_metric=raw_metric,
                                      raw_arcs=raw_arcs,
                                      verbose=False,
                                      max_change_num=max_change_num,
                                      iter_change_num=iter_change_num,
                                      iter_id=iter_id,
                                      forbidden_idxs__=forbidden_idxs__,
                                      change_positions__=change_positions__,
                                      orphans__=orphans__)
            # Fail
            if result['code'] == 404:
                log('Stop in step {}, info: {}'.format(iter_id,
                                                       result['info']))
                break
            # Success
            if result['code'] == 200:
                picker.add(
                    result['attack_metric'],
                    {
                        # "words": result['new_words'],
                        "logtable": result['logtable'],
                        "num_changed": len(change_positions__)
                    })
                if result[
                        'attack_metric'].uas < raw_metric.uas - self.config.hkw_eps:
                    log('Succeed in step {}'.format(iter_id))
                    break
            var_words = result['words']
            # forbidden_idxs__ = result['forbidden_idxs__']
            # change_positions__ = result['change_positions__']
        t1 = time.time()

        best_iter, best_attack_metric, best_info = picker.select_best_point()

        # NOTE(review): ``iter_id`` is the loop variable; if hkw_steps <= 1
        # the loop never runs and this would raise NameError — confirm
        # hkw_steps is always > 1.
        return {
            "raw_metric": raw_metric,
            "attack_metric": best_attack_metric,
            "iters": iter_id,
            "best_iter": best_iter,
            "num_changed": best_info['num_changed'],
            "time": t1 - t0,
            "logtable": best_info['logtable']
        }
Пример #18
0
    def select_by_jacobian(self, instance, sentence, mode):
        """Rank candidate (target-span, source-span) pairs by gradient saliency.

        mode 'jacobian1': for each target span, back-propagate the loss
        restricted to that span once, then score each compatible source
        span by the summed embedding-gradient norm over its words.
        mode 'jacobian2': back-propagate once per candidate word (slow!)
        to build a word-by-word saliency matrix, then score span pairs by
        the corresponding sub-matrix sum.

        Returns the span pairs sorted by decreasing score, or None when no
        pair satisfies the span-gap constraint.
        """
        sent_len = instance[0].size(1)
        spans = filter_spans(gen_spans(sentence), self.config.hks_min_span_len,
                             self.config.hks_max_span_len, True)
        if mode == 'jacobian1':
            table = []
            pairs = []
            for tgt_span in spans:
                # Gradient of the target-span-restricted loss w.r.t. embeddings.
                embed_grad = self.backward_loss(instance, ex_span_idx(tgt_span, sent_len))
                grad_norm = embed_grad.norm(dim=-1)  # 1 x (sen_len - 1)
                row = [tgt_span]
                for src_span in spans:
                    if check_gap(src_span, tgt_span, self.config.hks_span_gap):
                        # <root> not included
                        # NOTE(review): span indices appear to be 1-based
                        # (root at position 0); the -1 shifts into the
                        # root-stripped gradient tensor — confirm against
                        # gen_spans.
                        norm = grad_norm[0][src_span[0] - 1:src_span[1]].sum().item()
                        pairs.append((norm, (tgt_span, src_span)))
                        row.append("{:4.2f}".format(norm))
                    else:
                        row.append("-")
                table.append(row)
            log(tabulate(table, headers=["t↓"] + spans))
            if len(pairs) == 0:
                return None
            # Highest-saliency pairs first.
            spairs = sorted(pairs, key=lambda x: x[0], reverse=True)

            return list(zip(*spairs))[1]
        elif mode == 'jacobian2':
            # log('valid', spans)
            # margins = self.compute_margin(instance)
            # log(margins)
            norms = []
            # WARNING: This is rather slow!
            idxs_to_backward = []
            for span in spans:
                idxs_to_backward.extend(list(range(span[0], span[1] + 1)))
            for i in range(1, sent_len):
                if i in idxs_to_backward:
                    # Back-propagate the loss of word i alone (all other
                    # positions masked out).
                    embed_grad = self.backward_loss(instance,
                                                    [_ for _ in range(1, sent_len) if _ != i])
                    grad_norm = embed_grad.norm(dim=-1)[:, 1:]
                    norms.append(grad_norm)
                else:
                    # Word not inside any candidate span: zero saliency row.
                    norms.append(torch.zeros_like(instance[0]).float()[:, 1:])
            norms = torch.cat(norms)  # size: sent_len x sent_len
            # log("t↓", flt2str(list(range(sent_len)), cat=" ", fmt=":3"))
            # for i, norm in enumerate(norms):
            #     log("{:2}".format(i + 1),
            #         flt2str(cast_list(norm), cat=" ", fmt=":3.1f"),
            #         flt2str(sum(norm), fmt=":3.1f"))

            # Diagnostic table of span-pair scores (rows: target spans).
            table = []
            for t in spans:
                row = [t]
                for s in spans:
                    if check_gap(s, t, self.config.hks_span_gap):
                        row.append("{:4.2f}".format(norms[t[0] - 1:t[1],
                                                          s[0] - 1:s[1]].sum().item()))
                    else:
                        row.append('-')
                table.append(row)
            log(tabulate(table, headers=["t↓"] + spans))

            pairs = []
            for t in spans:
                for s in spans:
                    if check_gap(s, t, self.config.hks_span_gap):
                        pairs.append((norms[t[0] - 1:t[1], s[0] - 1:s[1]].sum().item(), (t, s)))
            if len(pairs) == 0:
                return None
            spairs = sorted(pairs, key=lambda x: x[0], reverse=True)

            return list(zip(*spairs))[1]
Пример #19
0
    def __call__(self, config):
        """Run the attack over every sentence in the loader.

        Filters sentences by length, attacks each remaining one via
        ``self.hack``, and logs per-sentence results plus running aggregate
        metrics (raw vs. attacked UAS, iterations, time, failure rate).
        """
        self.init_logger(config)
        self.setup(config)

        # Optional multi-worker mode: this process only handles the
        # half-open chunk [start_sid, end_sid) of the loader.
        if self.config.hk_use_worker == 'on':
            start_sid, end_sid = locate_chunk(len(self.loader),
                                              self.config.hk_num_worker,
                                              self.config.hk_worker_id)
            log('Run code on a chunk [{}, {})'.format(start_sid, end_sid))

        raw_metrics = ParserMetric()
        attack_metrics = ParserMetric()

        agg = Aggregator()
        for sid, (words, tags, chars, arcs, rels) in enumerate(self.loader):
            # if sid in [0, 1, 2, 3, 4]:
            #     continue
            # if sid < 1434:
            #     continue
            if self.config.hk_use_worker == 'on':
                if sid < start_sid or sid >= end_sid:
                    continue
            # On the training set, skip long sentences (> 50 tokens).
            if self.config.hk_training_set == 'on' and words.size(1) > 50:
                log('Skip sentence {} whose length is {}(>50).'.format(
                    sid, words.size(1)))
                continue

            # Very short sentences (< 5 tokens) have no usable spans.
            if words.size(1) < 5:
                log('Skip sentence {} whose length is {}(<5).'.format(
                    sid, words.size(1)))
                continue

            words_text = self.vocab.id2word(words[0])
            tags_text = self.vocab.id2tag(tags[0])
            log('****** {}: \n{}\n{}'.format(sid, " ".join(words_text),
                                             " ".join(tags_text)))

            # hack it!
            result = self.hack(instance=(words, tags, chars, arcs, rels))

            # aggregate information
            raw_metrics += result['raw_metric']
            attack_metrics += result['attack_metric']
            # "fail": the attack left UAS essentially unchanged.
            agg.aggregate(
                ("iters", result['iters']), ("time", result['time']),
                ("fail",
                 abs(result['attack_metric'].uas - result['raw_metric'].uas) <
                 1e-4), ('best_iter', result['best_iter']),
                ("changed", result['num_changed']))

            # log some information
            log('Show result from iter {}, changed num {}:'.format(
                result['best_iter'], result['num_changed']))
            log(result['logtable'])

            log('Aggregated result: {} --> {}, '
                'iters(avg) {:.1f}, time(avg) {:.1f}s, '
                'fail rate {:.2f}, best_iter(avg) {:.1f}, best_iter(std) {:.1f}, '
                'changed(avg) {:.1f}'.format(raw_metrics, attack_metrics,
                                             agg.mean('iters'),
                                             agg.mean('time'),
                                             agg.mean('fail'),
                                             agg.mean('best_iter'),
                                             agg.std('best_iter'),
                                             agg.mean('changed')))
            log()
Пример #20
0
    def hack(self, instance, sentence):
        """Attack one sentence by perturbing words inside a random span.

        Picks a random span of 6-9 tokens, then runs up to ``hko_steps``
        rounds of ``single_hack``, keeping the best (lowest) attacked
        metric seen. Returns a result dict, or None when no usable span
        exists or the chosen span covers more than half of the sentence.
        """
        words, tags, chars, arcs, rels = instance
        sent_len = words.size(1)

        # Candidate spans: index width 5..8, i.e. 6 to 9 tokens.
        spans = gen_spans(sentence)
        valid_spans = [span for span in spans
                       if 5 <= span[1] - span[0] <= 8]
        if len(valid_spans) == 0:
            log("Attack error, no valid spans")
            return None
        chosen_span = random.choice(valid_spans)
        # Refuse spans covering more than half of the sentence.
        if chosen_span[1] - chosen_span[0] + 1 > 0.5 * sent_len:
            return None
        log(
            'chosen span: ', chosen_span, ' '.join(
                self.vocab.id2word(words[0])[chosen_span[0]:chosen_span[1] +
                                             1]))

        raw_words = words.clone()
        var_words = words.clone()

        # Evaluate the unattacked parse, masking the chosen span's tokens.
        mask_idxs = list(range(chosen_span[0], chosen_span[1] + 1))
        raw_metric = self.task.partial_evaluate(instance=(raw_words, tags,
                                                          None, arcs, rels),
                                                mask_idxs=mask_idxs)
        _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)])

        forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
        change_positions__ = set()
        # BUG FIX: the float branch previously read ``hk_max_change`` while
        # the int branch read ``hko_max_change``; every other option in this
        # method uses the ``hko_`` prefix (e.g. hko_steps), so use it here.
        max_change = self.config.hko_max_change
        if isinstance(max_change, int):
            max_change_num = max_change
        elif isinstance(max_change, float):
            # A float is interpreted as a fraction of the sentence length.
            max_change_num = int(max_change * words.size(1))
        else:
            raise Exception("hko_max_change must be a float or an int")

        picker = CherryPicker(lower_is_better=True)
        t0 = time.time()
        for iter_id in range(self.config.hko_steps):
            result = self.single_hack(instance=(var_words, tags, None, arcs,
                                                rels),
                                      raw_words=raw_words,
                                      raw_metric=raw_metric,
                                      raw_arcs=raw_arcs,
                                      mask_idxs=mask_idxs,
                                      iter_id=iter_id,
                                      forbidden_idxs__=forbidden_idxs__,
                                      change_positions__=change_positions__,
                                      max_change_num=max_change_num)
            if result['code'] == 200:
                # Successful step: keep the perturbed words and record it.
                var_words = result['words']
                picker.add(
                    result['attack_metric'], {
                        'logtable': result['logtable'],
                        "num_changed": len(change_positions__)
                    })
            elif result['code'] == 404:
                # No further perturbation possible; stop early.
                print('FAILED')
                break
        t1 = time.time()
        # NOTE(review): if the very first step returns 404 the picker is
        # empty and select_best_point may fail — original behavior kept.
        best_iter, best_attack_metric, best_info = picker.select_best_point()

        return {
            "raw_metric": raw_metric,
            "attack_metric": best_attack_metric,
            "iters": iter_id,
            "best_iter": best_iter,
            "num_changed": best_info['num_changed'],
            "time": t1 - t0,
            "logtable": best_info['logtable']
        }