Ejemplo n.º 1
0
    def hack(self, instance):
        """Run the iterative word-level attack on one sentence.

        The unmodified sentence is scored first and registered in the picker
        as step 0. ``single_hack`` is then called once per step (at most
        ``hkw_steps - 1`` times) until it reports code 404 or until the
        attacked UAS falls more than ``hkw_eps`` below the raw UAS.

        Args:
            instance: Tuple ``(words, tags, chars, arcs, rels)`` describing one
                sentence; ``words.size(1)`` is used as the sentence length.

        Returns:
            dict with the keys ``raw_metric``, ``attack_metric``, ``iters``
            (index of the last executed step, ``0`` when no attack step ran),
            ``best_iter``, ``num_changed``, ``time`` (seconds spent in the
            attack loop) and ``logtable``.
        """
        words, tags, chars, arcs, rels = instance
        use_mst = self.config.hkw_mst == 'on'
        _, raw_metric = self.task.evaluate([(words, tags, chars, arcs, rels)],
                                           mst=use_mst)
        _, raw_arcs, _ = self.task.predict([(words, tags, None)], mst=use_mst)

        # Set up some state before attacking a sentence.
        # WARNING: the variables whose names end in "__" are handed to
        #  single_hack on every step; presumably it updates them in place
        #  (change_positions__ is read back below), so never try to keep a
        #  snapshot of their internal state.
        forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
        change_positions__ = set()
        orphans__ = set()

        # hkw_max_change > 0.9999 is an absolute number of changes; anything
        # smaller is a fraction of the sentence length (at least one change).
        if self.config.hkw_max_change > 0.9999:
            max_change_num = int(self.config.hkw_max_change)
        else:
            max_change_num = max(
                1, int(self.config.hkw_max_change * words.size(1)))
        iter_change_num = min(max_change_num, self.config.hkw_iter_change)

        var_words = words.clone()
        raw_words = words.clone()

        # HIGHLIGHT: ITERATION
        t0 = time.time()
        picker = CherryPicker(lower_is_better=True,
                              compare_fn=lambda m1, m2: m1.uas - m2.uas)
        # iter 0 -> raw
        picker.add(raw_metric, {
            "num_changed": 0,
            "logtable": 'No modification'
        })
        # Bind iter_id up front: when hkw_steps <= 1 the loop body never runs
        # and the return statement would otherwise hit an unbound local.
        iter_id = 0
        for iter_id in range(1, self.config.hkw_steps):
            result = self.single_hack(var_words,
                                      tags,
                                      arcs,
                                      rels,
                                      raw_words=raw_words,
                                      raw_metric=raw_metric,
                                      raw_arcs=raw_arcs,
                                      verbose=False,
                                      max_change_num=max_change_num,
                                      iter_change_num=iter_change_num,
                                      iter_id=iter_id,
                                      forbidden_idxs__=forbidden_idxs__,
                                      change_positions__=change_positions__,
                                      orphans__=orphans__)
            # Fail
            if result['code'] == 404:
                log('Stop in step {}, info: {}'.format(iter_id,
                                                       result['info']))
                break
            # Success
            if result['code'] == 200:
                picker.add(
                    result['attack_metric'],
                    {
                        "logtable": result['logtable'],
                        "num_changed": len(change_positions__)
                    })
                if result[
                        'attack_metric'].uas < raw_metric.uas - self.config.hkw_eps:
                    log('Succeed in step {}'.format(iter_id))
                    break
            # Carry the (possibly modified) sentence over to the next step.
            var_words = result['words']
        t1 = time.time()

        best_iter, best_attack_metric, best_info = picker.select_best_point()

        return {
            "raw_metric": raw_metric,
            "attack_metric": best_attack_metric,
            "iters": iter_id,
            "best_iter": best_iter,
            "num_changed": best_info['num_changed'],
            "time": t1 - t0,
            "logtable": best_info['logtable']
        }
Ejemplo n.º 2
0
    def meta_hack(self, instance, sentence):
        """Try span-level attacks on one sentence until one of them succeeds.

        Depending on ``self.config.hks_color``:

        * ``"black"``: every filtered span is tried as the source span and
          ``random_hack`` is run against all spans that pass ``check_gap``.
        * ``"white"`` / ``"grey"``: the first ``hks_topk_pair`` pairs returned
          by ``select_span_pair`` are attacked with ``white_hack`` or
          ``random_hack`` respectively.

        Args:
            instance: Tuple ``(words, tags, chars, arcs, rels)`` for one sentence.
            sentence: Sentence object handed to ``gen_spans`` /
                ``select_span_pair`` and to the individual attack methods.

        Returns:
            The result of the last attack that was run, extended with
            ``meta_time`` (and, for white/grey, the ``meta_*_pair`` counters),
            or ``None`` when there was no span pair to attack.

        Raises:
            Exception: if ``hks_color`` is not one of black / white / grey.
        """
        t0 = time.time()
        ram_reset("hk")

        words, tags, chars, arcs, rels = instance
        words_text = self.vocab.id2word(words[0])
        color = self.config.hks_color

        succ = False

        if color == "black":
            spans = filter_spans(gen_spans(sentence), self.config.hks_min_span_len,
                                 self.config.hks_max_span_len, True)
            hack_result = None
            for src_span in spans:
                log('[Chosen source] ', src_span, ' '.join(words_text[src_span[0]:src_span[1] + 1]))
                tgt_span_lst = [
                    ele for ele in spans if check_gap(ele, src_span, self.config.hks_span_gap)
                ]
                if len(tgt_span_lst) == 0:
                    continue
                hack_result = self.random_hack(instance, sentence, tgt_span_lst, src_span)
                if hack_result['succ'] == 1:
                    succ = True
                    log("Succ on source span {}.\n".format(src_span))
                    break
                log("Fail on source span {}.\n".format(src_span))
            if hack_result is None:
                log("Not enough span pairs")
                return None
        elif color in ('white', 'grey'):
            pairs = self.select_span_pair(instance, sentence)
            if pairs is None:
                return None
            top_pairs = pairs[:self.config.hks_topk_pair]
            # Without at least one pair the loop below never binds hack_result
            # or pair_id; bail out the same way the black-box branch does.
            if len(top_pairs) == 0:
                log("Not enough span pairs")
                return None
            for pair_id, (tgt_span, src_span) in enumerate(top_pairs):
                log('[Chosen span] ', 'tgt-', tgt_span,
                    ' '.join(words_text[tgt_span[0]:tgt_span[1] + 1]), 'src-', src_span,
                    ' '.join(words_text[src_span[0]:src_span[1] + 1]))

                if color == "white":
                    hack_result = self.white_hack(instance, sentence, tgt_span, src_span)
                else:  # grey
                    hack_result = self.random_hack(instance, sentence, [tgt_span], src_span)
                if hack_result['succ'] == 1:
                    succ = True
                    log("Succ on the span pair {}/{}.\n".format(pair_id, len(pairs)))
                    break
                log("Fail on the span pair {}/{}.\n".format(pair_id, len(pairs)))

            # pair_id is the index of the last pair that was attacked.
            hack_result['meta_trial_pair'] = pair_id + 1
            hack_result['meta_total_pair'] = len(pairs)
            hack_result['meta_succ_trial_pair'] = pair_id + 1 if succ else 0
            hack_result['meta_succ_total_pair'] = len(pairs) if succ else 0
        else:
            # Fail loudly instead of reaching the code below with an unbound
            # hack_result.
            raise Exception("Unknown hks_color: {!r}".format(color))

        t1 = time.time()
        log('Sentence cost {:.1f}s'.format(t1 - t0))
        hack_result['meta_time'] = t1 - t0

        return hack_result
Ejemplo n.º 3
0
    def random_hack(self, instance, sentence, tgt_span_lst, src_span):
        """Attack the target spans by randomly substituting words in ``src_span``.

        ``hks_cand_num`` candidate sentences are generated by substituting
        words at positions inside ``src_span``. All candidates are parsed in
        one batch; every candidate whose predicted arcs differ from the raw
        prediction is then re-evaluated on the target positions until one
        lowers the UAS.

        Args:
            instance: Tuple ``(words, tags, chars, arcs, rels)`` for one sentence.
            sentence: Not used by this method.
            tgt_span_lst: Target spans. ``span[0]`` / ``span[1]`` are inclusive
                bounds; position ``span[2]`` is left out of the evaluation
                (presumably the span root - confirm against ``gen_spans``).
            src_span: Span whose words are substituted (inclusive bounds in
                ``src_span[0]`` / ``src_span[1]``).

        Returns:
            ``defaultdict`` (missing keys yield ``-1``) with ``succ`` (1/0),
            ``att_id`` (index of the successful candidate, NaN on failure),
            ``num_changed``, ``time`` and ``logtable``.

        Raises:
            Exception: if ``hks_blk_repl_tag`` matches none of the supported
                replacement modes.
        """
        t0 = time.time()
        words, tags, chars, arcs, rels = instance
        sent_len = words.size(1)

        raw_words_lst = cast_list(words)
        idxs = gen_idxs_to_substitute(list(range(src_span[0], src_span[1] + 1)),
                                      int(self.config.hks_max_change), self.config.hks_cand_num)
        if self.config.hks_blk_repl_tag == 'any':
            cand_words_lst = subsitute_by_idxs(raw_words_lst, idxs, self.blackbox_sub_idxs)
        elif re.match("[njvri]+", self.config.hks_blk_repl_tag):
            cand_words_lst = subsitute_by_idxs(raw_words_lst, idxs,
                                               self.njvri_subidxs(self.config.hks_blk_repl_tag))
        elif re.match("keep", self.config.hks_blk_repl_tag):
            # One substitution vocabulary per source position, selected from
            # the universal tag that the position's PTB tag maps to.
            tag_texts = cast_list(tags[0])
            vocab_idxs_lst = []
            for i in range(src_span[0], src_span[1] + 1):
                vocab_idxs_lst.append(self.univ_subidxs(HACK_TAGS.ptb2uni(self.vocab.tags[tag_texts[i]])))
            cand_words_lst = subsitute_by_idxs_2(raw_words_lst, idxs, src_span[0], vocab_idxs_lst)
        else:
            raise Exception("Unsupported hks_blk_repl_tag: {!r}".format(
                self.config.hks_blk_repl_tag))

        # First index is raw words
        all_words = torch.tensor([raw_words_lst] + cand_words_lst, device=words.device)
        raw_words = all_words[0:1]
        cand_words = all_words[1:]

        # Only the target positions are evaluated; everything else is masked.
        tgt_idxs = []
        for tgt_span in tgt_span_lst:
            tgt_idxs.extend([i for i in range(tgt_span[0], tgt_span[1] + 1) if i != tgt_span[2]])
        ex_tgt_idxs = [i for i in range(sent_len) if i not in tgt_idxs]

        pred_arcs, _, _, _ = self.task.partial_evaluate(
            (all_words, tags.expand_as(all_words), None, arcs.expand_as(all_words),
             rels.expand_as(all_words)),
            mask_idxs=ex_tgt_idxs,
            mst=False,
            return_metric=False)

        raw_metric = self.task.partial_evaluate((raw_words, tags, None, arcs, rels),
                                                ex_tgt_idxs,
                                                mst=True)

        succ = False
        # Per-candidate amount of change in the predicted arcs w.r.t. row 0
        # (the raw sentence).
        arc_delta = (pred_arcs[1:] - pred_arcs[0]).abs().sum(1)
        if arc_delta.sum() == 0:
            # direct forward to the last attacked sentence
            att_id = self.config.hks_cand_num - 1
        else:
            for att_id in range(0, self.config.hks_cand_num):
                if arc_delta[att_id] != 0:
                    att_metric = self.task.partial_evaluate(
                        (cand_words[att_id:att_id + 1], tags, None, arcs, rels),
                        ex_tgt_idxs,
                        mst=True)

                    if att_metric.uas < raw_metric.uas - 0.0001:
                        succ = True
                        break

        # On failure att_id ends up as hks_cand_num - 1 (the last candidate).
        t1 = time.time()

        _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)], mst=True)
        _, att_arcs, _ = self.task.predict([(cand_words[att_id:att_id + 1], tags, None)],
                                           mst=True)

        if not succ:
            info = 'Nothing'
        else:
            # Pick the first target span in which the attack lowered the
            # number of correct arcs; if none does, the last span of the list
            # stays bound to tgt_span and is the one reported.
            for tgt_span in tgt_span_lst:
                if tgt_span == src_span:
                    continue
                raw_span_corr = 0
                att_span_corr = 0
                for i in range(tgt_span[0], tgt_span[1] + 1):
                    if i == tgt_span[2]:
                        continue
                    if raw_arcs[0][i - 1] == arcs[0][i - 1]:
                        raw_span_corr += 1
                    if att_arcs[0][i - 1] == arcs[0][i - 1]:
                        att_span_corr += 1
                if att_span_corr < raw_span_corr:
                    break

            info = tabulate(self._gen_log_table(words, cand_words[att_id:att_id + 1], tags, arcs,
                                                rels, raw_arcs, att_arcs, src_span, tgt_span),
                            floatfmt='.6f')

        return defaultdict(
            lambda: -1, {
                "succ": 1 if succ else 0,
                # Keyed on succ rather than on the index value: a success on
                # the very last candidate must not be reported as NaN.
                "att_id": att_id if succ else np.nan,
                "num_changed": self.config.hks_max_change,
                "time": t1 - t0,
                "logtable": info
            })
Ejemplo n.º 4
0
    def white_hack(self, instance, sentence, tgt_span, src_span):
        """Iteratively attack ``tgt_span`` by modifying words via ``single_hack``.

        The raw sentence is evaluated on the target span only (every index
        returned by ``ex_span_idx`` is masked) and registered in the picker.
        ``single_hack`` is then called for up to ``hks_steps`` steps; the loop
        stops on code 404 or as soon as the attacked UAS drops below the raw
        UAS.

        Args:
            instance: Tuple ``(words, tags, chars, arcs, rels)`` for one sentence.
            sentence: Not used by this method.
            tgt_span: Span whose parse is being attacked.
            src_span: Span forwarded to ``single_hack`` as the source span.

        Returns:
            ``defaultdict`` (missing keys yield ``-1``) with ``succ`` (1/0),
            ``raw_metric``, ``attack_metric``, ``iters`` (index of the last
            executed step, ``-1`` when no step ran), ``best_iter``,
            ``num_changed``, ``time`` and ``logtable``.
        """
        words, tags, chars, arcs, rels = instance
        sent_len = words.size(1)

        raw_words = words.clone()
        var_words = words.clone()

        raw_metric = self.task.partial_evaluate(instance=(raw_words, tags, None, arcs, rels),
                                                mask_idxs=ex_span_idx(tgt_span, sent_len),
                                                mst=self.config.hks_mst == 'on')
        # NOTE(review): unlike raw_metric above, this prediction does not pass
        # the mst flag - confirm that relying on predict()'s default is intended.
        _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)])

        # Names ending in "__" are passed to single_hack on every step;
        # change_positions__ is read back after each call.
        forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
        change_positions__ = set()
        # hks_max_change > 0.9999 is an absolute number of changes; anything
        # smaller is a fraction of the sentence length (at least one change).
        if self.config.hks_max_change > 0.9999:
            max_change_num = int(self.config.hks_max_change)
        else:
            max_change_num = max(1, int(self.config.hks_max_change * words.size(1)))
        iter_change_num = min(max_change_num, self.config.hks_iter_change)

        picker = CherryPicker(lower_is_better=True)
        t0 = time.time()
        picker.add(raw_metric, {"num_changed": 0, "logtable": 'No modification'})
        log('iter -1, uas {:.4f}'.format(raw_metric.uas))
        succ = False
        # Bind iter_id up front (matching the "iter -1" raw entry logged
        # above): with hks_steps == 0 the loop never runs and the return
        # statement would otherwise hit an unbound local.
        iter_id = -1
        for iter_id in range(self.config.hks_steps):
            result = self.single_hack(instance=(var_words, tags, None, arcs, rels),
                                      raw_words=raw_words,
                                      raw_metric=raw_metric,
                                      raw_arcs=raw_arcs,
                                      src_span=src_span,
                                      tgt_span=tgt_span,
                                      iter_id=iter_id,
                                      forbidden_idxs__=forbidden_idxs__,
                                      change_positions__=change_positions__,
                                      max_change_num=max_change_num,
                                      iter_change_num=iter_change_num)
            if result['code'] == 200:
                # Only a successful step replaces the working sentence.
                var_words = result['words']
                picker.add(result['attack_metric'], {
                    'logtable': result['logtable'],
                    "num_changed": len(change_positions__)
                })
                if result['attack_metric'].uas < raw_metric.uas - 0.00001:
                    succ = True
                    log('Succeed in step {}'.format(iter_id))
                    break
            elif result['code'] == 404:
                log('FAILED')
                break
        t1 = time.time()
        best_iter, best_attack_metric, best_info = picker.select_best_point()

        return defaultdict(
            lambda: -1, {
                "succ": 1 if succ else 0,
                "raw_metric": raw_metric,
                "attack_metric": best_attack_metric,
                "iters": iter_id,
                "best_iter": best_iter,
                "num_changed": best_info['num_changed'],
                "time": t1 - t0,
                "logtable": best_info['logtable']
            })
Ejemplo n.º 5
0
    def hack(self, instance):
        """Run the iterative character-level attack on one sentence.

        The unmodified sentence is scored first and registered in the picker
        as step 0. ``single_hack`` is then called once per step (at most
        ``hkc_steps - 1`` times) with the current character tensor until it
        reports code 404 or the attacked UAS falls far enough below the raw
        UAS.

        Args:
            instance: Tuple ``(words, tags, chars, arcs, rels)`` describing one
                sentence; ``words.size(1)`` is used as the sentence length.

        Returns:
            dict with the keys ``raw_metric``, ``attack_metric``, ``iters``
            (index of the last executed step, ``0`` when no attack step ran),
            ``best_iter``, ``num_changed``, ``time`` (seconds spent in the
            attack loop) and ``logtable``.
        """
        words, tags, chars, arcs, rels = instance
        use_mst = self.config.hkc_mst == 'on'
        _, raw_metric = self.task.evaluate([(words, tags, chars, arcs, rels)],
                                           mst=use_mst)
        _, raw_arcs, _ = self.task.predict([(words, tags, chars)], mst=use_mst)

        # Names ending in "__" are passed to single_hack on every step;
        # change_positions__ is read back after each call.
        forbidden_idxs__ = defaultdict(
            lambda: deque(maxlen=5))    # word_sid -> deque()
        change_positions__ = dict()  # word_sid -> char_wid
        # hkc_max_change > 0.9999 is an absolute number of changes; anything
        # smaller is a fraction of the sentence length (at least one change).
        if self.config.hkc_max_change > 0.9999:
            max_change_num = int(self.config.hkc_max_change)
        else:
            max_change_num = max(
                1, int(self.config.hkc_max_change * words.size(1)))
        iter_change_num = min(max_change_num, self.config.hkc_iter_change)

        raw_chars = chars.clone()
        var_chars = chars.clone()

        # HIGHLIGHT: ITERATION
        t0 = time.time()
        picker = CherryPicker(lower_is_better=True,
                              compare_fn=lambda m1, m2: m1.uas - m2.uas)
        # iter 0 -> raw
        picker.add(raw_metric, {
            "num_changed": 0,
            "logtable": 'No modification'
        })
        # Bind iter_id up front: when hkc_steps <= 1 the loop body never runs
        # and the return statement would otherwise hit an unbound local.
        iter_id = 0
        for iter_id in range(1, self.config.hkc_steps):
            result = self.single_hack(
                words, tags, var_chars, arcs, rels,
                raw_chars=raw_chars, raw_metric=raw_metric, raw_arcs=raw_arcs,
                verbose=False,
                max_change_num=max_change_num,
                iter_change_num=iter_change_num,
                iter_id=iter_id,
                forbidden_idxs__=forbidden_idxs__,
                change_positions__=change_positions__,
            )
            # Fail
            if result['code'] == 404:
                log('Stop in step {}, info: {}'.format(
                    iter_id, result['info']))
                break
            # Success
            if result['code'] == 200:
                picker.add(result['attack_metric'], {
                    "logtable": result['logtable'],
                    "num_changed": len(change_positions__)
                })
                # NOTE(review): every other setting read here uses the hkc_
                # prefix, but the success threshold comes from hkw_eps -
                # confirm whether a dedicated hkc_eps option was intended.
                if result['attack_metric'].uas < raw_metric.uas - self.config.hkw_eps:
                    log('Succeed in step {}'.format(iter_id))
                    break
            # Carry the (possibly modified) characters over to the next step.
            var_chars = result['chars']
        t1 = time.time()

        best_iter, best_attack_metric, best_info = picker.select_best_point()

        return {
            "raw_metric": raw_metric,
            "attack_metric": best_attack_metric,
            "iters": iter_id,
            "best_iter": best_iter,
            "num_changed": best_info['num_changed'],
            "time": t1 - t0,
            "logtable": best_info['logtable']
        }
Ejemplo n.º 6
0
    def hack(self, instance, sentence):
        """Attack a sentence while a randomly chosen span is masked out of the metric.

        A span with ``5 <= span[1] - span[0] <= 8`` is drawn at random from
        ``gen_spans(sentence)``; its positions are passed as ``mask_idxs`` to
        the evaluation and to ``single_hack``, which is then called for up to
        ``hko_steps`` steps.

        Args:
            instance: Tuple ``(words, tags, chars, arcs, rels)`` for one sentence.
            sentence: Sentence object handed to ``gen_spans``.

        Returns:
            ``None`` when there is no valid span or the chosen span covers
            more than half of the sentence; otherwise a dict with
            ``raw_metric``, ``attack_metric``, ``iters`` (index of the last
            executed step, ``-1`` when no step ran), ``best_iter``,
            ``num_changed``, ``time`` and ``logtable``.

        Raises:
            Exception: if ``hko_max_change`` is neither an int nor a float.
        """
        words, tags, chars, arcs, rels = instance
        sent_len = words.size(1)
        valid_spans = [
            span for span in gen_spans(sentence) if 5 <= span[1] - span[0] <= 8
        ]
        if len(valid_spans) == 0:
            log("Attack error, no valid spans")
            return None
        chosen_span = random.choice(valid_spans)
        # Give up when the span would mask more than half of the sentence.
        if chosen_span[1] - chosen_span[0] + 1 > 0.5 * sent_len:
            return None
        words_text = self.vocab.id2word(words[0])
        log('chosen span: ', chosen_span,
            ' '.join(words_text[chosen_span[0]:chosen_span[1] + 1]))

        raw_words = words.clone()
        var_words = words.clone()

        mask_idxs = list(range(chosen_span[0], chosen_span[1] + 1))
        raw_metric = self.task.partial_evaluate(instance=(raw_words, tags,
                                                          None, arcs, rels),
                                                mask_idxs=mask_idxs)
        _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)])

        # Names ending in "__" are passed to single_hack on every step;
        # change_positions__ is read back after each call.
        forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
        change_positions__ = set()
        # Read the option once so both type checks look at the same value (the
        # float branch used to read a differently named attribute,
        # hk_max_change). An int is an absolute count, a float is a fraction
        # of the sentence length.
        max_change = self.config.hko_max_change
        if isinstance(max_change, int):
            max_change_num = max_change
        elif isinstance(max_change, float):
            max_change_num = int(max_change * words.size(1))
        else:
            raise Exception("hko_max_change must be a float or an int")

        picker = CherryPicker(lower_is_better=True)
        t0 = time.time()
        # Bind iter_id up front: with hko_steps == 0 the loop never runs and
        # the return statement would otherwise hit an unbound local.
        iter_id = -1
        for iter_id in range(self.config.hko_steps):
            result = self.single_hack(instance=(var_words, tags, None, arcs,
                                                rels),
                                      raw_words=raw_words,
                                      raw_metric=raw_metric,
                                      raw_arcs=raw_arcs,
                                      mask_idxs=mask_idxs,
                                      iter_id=iter_id,
                                      forbidden_idxs__=forbidden_idxs__,
                                      change_positions__=change_positions__,
                                      max_change_num=max_change_num)
            if result['code'] == 200:
                # Only a successful step replaces the working sentence.
                var_words = result['words']
                picker.add(
                    result['attack_metric'], {
                        'logtable': result['logtable'],
                        "num_changed": len(change_positions__)
                    })
            elif result['code'] == 404:
                print('FAILED')
                break
        t1 = time.time()
        # NOTE(review): nothing is added to the picker unless a step returns
        # code 200 - confirm select_best_point() copes with an empty picker.
        best_iter, best_attack_metric, best_info = picker.select_best_point()

        return {
            "raw_metric": raw_metric,
            "attack_metric": best_attack_metric,
            "iters": iter_id,
            "best_iter": best_iter,
            "num_changed": best_info['num_changed'],
            "time": t1 - t0,
            "logtable": best_info['logtable']
        }