Example #1
0
    def select_by_deltalogit(self, instance, sentence):
        """Rank (target, source) span pairs by how vulnerable the target span is.

        Each word gets a vulnerability score derived from its margin (as
        returned by ``self.compute_margin``): a positive margin ``m``
        contributes ``10 - m``; a non-positive margin contributes 0. A span's
        score is the sum over the words it covers. Every (target, source) pair
        accepted by ``check_gap`` is then ordered by its target's score,
        highest first (ties keep enumeration order).

        Args:
            instance: model input; must unpack into five fields.
            sentence: sentence passed to ``gen_spans``.

        Returns:
            A tuple of ``(tgt_span, src_span)`` pairs, or ``None`` when no pair
            satisfies the gap constraint.
        """
        cfg = self.config
        min_len = cfg.hks_min_span_len
        max_len = cfg.hks_max_span_len
        # Only checks the 5-field layout; none of the fields is used here.
        _words, _tags, _chars, _arcs, _rels = instance

        # The returned margins do not contain <ROOT>.
        margins = self.compute_margin(instance)
        log(margins)

        # Slot 0 is a placeholder for <ROOT>; presumably span indices count
        # <ROOT> as position 0, so word i lands at index i.
        scores = [0] + [10 - m if m > 0 else 0 for m in margins]

        candidates = filter_spans(gen_spans(sentence), min_len, max_len, True)
        span_scores = [sum(scores[span[0]:span[1] + 1]) for span in candidates]

        scored_pairs = [
            (span_scores[t_idx], (tgt, src))
            for t_idx, tgt in enumerate(candidates)
            for src in candidates
            if check_gap(src, tgt, cfg.hks_span_gap)
        ]
        if not scored_pairs:
            return None
        # list.sort is stable, so equally-scored pairs keep their order.
        scored_pairs.sort(key=lambda item: item[0], reverse=True)
        return tuple(pair for _score, pair in scored_pairs)
Example #2
0
    def select_tgt_span(self, instance, sentence, mode):
        """Select target-span candidates for an attack.

        Args:
            instance: model input, forwarded to ``self.compute_margin`` in
                ``'vul'`` mode.
            sentence: sentence passed to ``gen_spans``.
            mode: ``'vul'`` ranks span pairs by the vulnerability of the target
                span; ``'rdm'`` draws one span uniformly at random.

        Returns:
            * ``'vul'``: a tuple of ``(tgt_span, src_span)`` pairs ordered by the
              target's vulnerability score, highest first; ``None`` if no pair
              satisfies ``check_gap``.
            * ``'rdm'``: a single span, or ``None`` if there are no spans.

        Raises:
            ValueError: for any other ``mode``.
        """
        minl = self.config.hks_min_span_len
        maxl = self.config.hks_max_span_len
        if mode == 'vul':
            # The returned margins do not contain <ROOT>.
            margins = self.compute_margin(instance)
            log(margins)

            # Turn each word's margin into a vulnerability weight: a margin in
            # (0, 1) weighs heavily, one in [1, 2) a little, and anything else
            # is penalised. Index 0 is a placeholder for <ROOT>.
            vul_margins = [0]
            for ele in margins:
                if 0 < ele < 1:
                    vul_margins.append(100)
                elif 1 <= ele < 2:
                    vul_margins.append(1)
                else:
                    vul_margins.append(-1)
            spans = filter_spans(gen_spans(sentence), minl, maxl, True)
            # A span's score is the summed weight of the words it covers.
            span_vuls = [sum(vul_margins[span[0]:span[1] + 1]) for span in spans]

            pairs = []
            for tid, t in enumerate(spans):
                for s in spans:
                    if check_gap(s, t, self.config.hks_span_gap):
                        pairs.append((span_vuls[tid], (t, s)))
            if not pairs:
                return None
            # sorted() is stable, so equally-scored pairs keep enumeration order.
            spairs = sorted(pairs, key=lambda x: x[0], reverse=True)
            # NOTE(review): this mode returns ranked (tgt, src) pairs, whereas
            # 'rdm' returns a single span -- confirm callers handle both shapes.
            return list(zip(*spairs))[1]

        elif mode == "rdm":
            spans = filter_spans(gen_spans(sentence), minl, maxl, True)
            if len(spans) == 0:
                return None
            return random.choice(spans)
        else:
            raise ValueError("Unknown target span selection mode: {!r}".format(mode))
Example #3
0
    def select_src_span(self, instance, sentence, tgt_span, mode):
        """Choose a source span to pair with ``tgt_span``.

        Only spans far enough from ``tgt_span`` (per ``hks_span_gap``) are
        eligible.

        Args:
            instance: model input; ``instance[0].size(1)`` is used as the
                sentence length in ``'nom'`` mode.
            sentence: sentence passed to ``gen_spans``.
            tgt_span: the already-chosen target span.
            mode: ``'cls'`` ranks eligible spans by ``get_gap`` to the target
                using ``CherryPicker(lower_is_better=True)`` (presumably the
                closest admissible span wins); ``'rdm'`` picks uniformly at
                random among spans passing ``check_gap``; ``'nom'`` ranks spans
                by their summed embedding-gradient norm, higher first.

        Returns:
            The selected span, or ``None`` when no span is eligible.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        if mode not in ("cls", "rdm", "nom"):
            raise ValueError("Unknown source span selection mode: {!r}".format(mode))
        minl = self.config.hks_min_span_len
        maxl = self.config.hks_max_span_len
        gap = self.config.hks_span_gap
        spans = filter_spans(gen_spans(sentence), minl, maxl, True)

        if mode == "cls":
            src_picker = CherryPicker(lower_is_better=True)
            for span in spans:
                stgap = get_gap(span, tgt_span)
                if stgap >= gap:
                    src_picker.add(stgap, span)
            if src_picker.size == 0:
                log('Source span not found')
                return None
            _, _, src_span = src_picker.select_best_point()
            return src_span

        if mode == "rdm":
            src_spans = [span for span in spans if check_gap(span, tgt_span, gap)]
            if not src_spans:
                return None
            return random.choice(src_spans)

        # mode == "nom"
        sent_len = instance[0].size(1)
        embed_grad = self.backward_loss(instance, ex_span_idx(tgt_span, sent_len))
        grad_norm = embed_grad.norm(dim=2)  # 1 x sent_len, <root> included
        src_picker = CherryPicker(lower_is_better=False)
        for span in spans:
            if check_gap(span, tgt_span, gap):
                # <root> occupies column 0, so span indices address columns directly.
                src_picker.add(grad_norm[0][span[0]:span[1] + 1].sum(), span)
        if src_picker.size == 0:
            return None
        _, _, src_span = src_picker.select_best_point()
        return src_span
Example #4
0
 def select_by_random(self, instance, sentence):
     """Enumerate every (target, source) span pair that passes ``check_gap``.

     Nothing is shuffled or sampled here: pairs come back in enumeration
     order (outer loop over targets, inner loop over sources). ``instance``
     is not used.

     Returns:
         A list of ``(tgt_span, src_span)`` tuples, or ``None`` when no pair
         passes the gap check.
     """
     candidates = filter_spans(gen_spans(sentence), self.config.hks_min_span_len,
                               self.config.hks_max_span_len, True)
     paired = [
         (tgt, src)
         for tgt in candidates
         for src in candidates
         if check_gap(tgt, src, self.config.hks_span_gap)
     ]
     return paired or None
Example #5
0
    def meta_hack(self, instance, sentence):
        """Attack one sentence by trying span pairs until a hack succeeds.

        The strategy depends on ``self.config.hks_color``:

        * ``"black"``: each candidate span is tried in turn as the source; all
          spans passing ``check_gap`` against it are handed to
          ``self.random_hack`` as targets.
        * ``"white"`` / ``"grey"``: span pairs come from
          ``self.select_span_pair``; at most ``hks_topk_pair`` of them are
          tried with ``self.white_hack`` (white) or ``self.random_hack`` (grey).
          The result gets ``meta_trial_pair``, ``meta_total_pair``,
          ``meta_succ_trial_pair`` and ``meta_succ_total_pair`` entries.

        The search stops at the first result whose ``'succ'`` entry is 1.

        Args:
            instance: model input; ``instance[0]`` holds the word ids.
            sentence: sentence passed to ``gen_spans``.

        Returns:
            The result dict of the last attempted hack with ``meta_time``
            (seconds spent on this sentence) added, or ``None`` when no span
            pair could be tried.

        Raises:
            ValueError: if ``hks_color`` is not ``"black"``, ``"white"`` or
                ``"grey"``.
        """
        t0 = time.time()
        ram_reset("hk")

        words = instance[0]
        words_text = self.vocab.id2word(words[0])
        color = self.config.hks_color

        succ = False
        hack_result = None

        if color == "black":
            spans = filter_spans(gen_spans(sentence), self.config.hks_min_span_len,
                                 self.config.hks_max_span_len, True)
            for src_span in spans:
                log('[Chosen source] ', src_span, ' '.join(words_text[src_span[0]:src_span[1] + 1]))
                tgt_span_lst = [
                    ele for ele in spans if check_gap(ele, src_span, self.config.hks_span_gap)
                ]
                if not tgt_span_lst:
                    continue
                hack_result = self.random_hack(instance, sentence, tgt_span_lst, src_span)
                if hack_result['succ'] == 1:
                    succ = True
                    log("Succ on source span {}.\n".format(src_span))
                    break
                log("Fail on source span {}.\n".format(src_span))
            if hack_result is None:
                log("Not enough span pairs")
                return None
        elif color in ['white', 'grey']:
            pairs = self.select_span_pair(instance, sentence)
            if pairs is None:
                return None
            for pair_id, (tgt_span, src_span) in enumerate(pairs[:self.config.hks_topk_pair]):
                log('[Chosen span] ', 'tgt-', tgt_span,
                    ' '.join(words_text[tgt_span[0]:tgt_span[1] + 1]), 'src-', src_span,
                    ' '.join(words_text[src_span[0]:src_span[1] + 1]))

                if color == "white":
                    hack_result = self.white_hack(instance, sentence, tgt_span, src_span)
                else:
                    hack_result = self.random_hack(instance, sentence, [tgt_span], src_span)
                if hack_result['succ'] == 1:
                    succ = True
                    log("Succ on the span pair {}/{}.\n".format(pair_id, len(pairs)))
                    break
                log("Fail on the span pair {}/{}.\n".format(pair_id, len(pairs)))

            if hack_result is None:
                # Empty ``pairs`` or a non-positive ``hks_topk_pair``: nothing was
                # tried, so there is no result (and no ``pair_id``) to report.
                log("Not enough span pairs")
                return None

            n_tried = pair_id + 1
            hack_result['meta_trial_pair'] = n_tried
            hack_result['meta_total_pair'] = len(pairs)
            hack_result['meta_succ_trial_pair'] = n_tried if succ else 0
            hack_result['meta_succ_total_pair'] = len(pairs) if succ else 0
        else:
            raise ValueError("Unknown hks_color: {!r}".format(color))

        t1 = time.time()
        log('Sentence cost {:.1f}s'.format(t1 - t0))
        hack_result['meta_time'] = t1 - t0

        return hack_result
Example #6
0
    def select_by_jacobian(self, instance, sentence, mode):
        """Rank (target, source) span pairs by embedding-gradient norms.

        Args:
            instance: model input; ``instance[0]`` holds the word ids and its
                ``size(1)`` is used as the sentence length (presumably with
                ``<root>`` at position 0, given the ``range(1, sent_len)``
                loops and ``[:, 1:]`` slices below).
            sentence: sentence passed to ``gen_spans``.
            mode: ``'jacobian1'`` runs one ``backward_loss`` per target span and
                scores a pair by the gradient norm summed over the source span;
                ``'jacobian2'`` runs one ``backward_loss`` per word covered by a
                candidate span (slow) and sums the resulting per-word norms
                over target rows x source columns.

        Returns:
            A tuple of ``(tgt_span, src_span)`` pairs sorted by score, highest
            first, or ``None`` when no pair satisfies the gap constraint or
            ``mode`` is not recognised.
        """
        sent_len = instance[0].size(1)
        gap = self.config.hks_span_gap
        spans = filter_spans(gen_spans(sentence), self.config.hks_min_span_len,
                             self.config.hks_max_span_len, True)
        if mode == 'jacobian1':
            table = []
            pairs = []
            for tgt_span in spans:
                embed_grad = self.backward_loss(instance, ex_span_idx(tgt_span, sent_len))
                # The gradient includes a <root> position (cf. the ``[:, 1:]``
                # slice in 'jacobian2'); drop it so column j holds word j + 1 and
                # ``span[0] - 1:span[1]`` selects exactly the span's words.
                grad_norm = embed_grad.norm(dim=-1)[:, 1:]  # 1 x (sent_len - 1)
                row = [tgt_span]
                for src_span in spans:
                    if check_gap(src_span, tgt_span, gap):
                        norm = grad_norm[0][src_span[0] - 1:src_span[1]].sum().item()
                        pairs.append((norm, (tgt_span, src_span)))
                        row.append("{:4.2f}".format(norm))
                    else:
                        row.append("-")
                table.append(row)
            log(tabulate(table, headers=["t↓"] + spans))
        elif mode == 'jacobian2':
            # One backward pass per word covered by some candidate span; other
            # positions get an all-zero row. WARNING: This is rather slow!
            covered = {i for span in spans for i in range(span[0], span[1] + 1)}
            grad_rows = []
            for i in range(1, sent_len):
                if i in covered:
                    others = [j for j in range(1, sent_len) if j != i]
                    embed_grad = self.backward_loss(instance, others)
                    grad_rows.append(embed_grad.norm(dim=-1)[:, 1:])
                else:
                    grad_rows.append(torch.zeros_like(instance[0]).float()[:, 1:])
            # Row r / column c refer to word r + 1 / c + 1 (<root> dropped).
            norms = torch.cat(grad_rows)  # (sent_len - 1) x (sent_len - 1)

            table = []
            pairs = []
            for t in spans:
                row = [t]
                for s in spans:
                    if check_gap(s, t, gap):
                        score = norms[t[0] - 1:t[1], s[0] - 1:s[1]].sum().item()
                        pairs.append((score, (t, s)))
                        row.append("{:4.2f}".format(score))
                    else:
                        row.append('-')
                table.append(row)
            log(tabulate(table, headers=["t↓"] + spans))
        else:
            # Unknown mode: nothing to rank.
            return None

        if not pairs:
            return None
        # sorted() is stable, so equally-scored pairs keep enumeration order.
        spairs = sorted(pairs, key=lambda x: x[0], reverse=True)
        return tuple(pair for _, pair in spairs)