Esempio n. 1
0
    def show_embedding_info(self):
        """Print diagnostic statistics about the embedding table ``self.embed``.

        Three reports are written to stdout:

        1. mean/std (via ``show_mean_std``) of the parameters and of the
           per-row 2-norms;
        2. mean/std of the ``'euc'`` neighbour distances of up to 50 randomly
           drawn rows, for several neighbourhood sizes N;
        3. the query is additionally shifted along a random unit vector by
           several step sizes; for each step and N the table shows the mean
           neighbour distance (``D-N``) and the percentage of the unshifted
           N-neighbourhood that is still retrieved (``I-N``).

        Rows whose entries sum to zero are skipped in reports 2 and 3.
        """
        print('*** Statistics of parameters and 2-norm ***')
        show_mean_std(self.embed, 'Param')
        show_mean_std(torch.norm(self.embed, p=2, dim=1), 'Norm')

        print('*** Statistics of distances in a N-nearest neighbourhood ***')
        nbr_num = [5, 10, 20, 50, 100, 200, 500, 10000, 20000]
        dists = {nbr: [] for nbr in nbr_num}
        for ele in cast_list(torch.randint(self.embed.size(0), (50, ))):
            if self.embed[ele].sum() == 0.:
                continue
            # Entry 0 is dropped below -- presumably it is the query row
            # itself; TODO confirm against find_neighbours.
            _, vals = self.find_neighbours(ele, -1, 'euc', False)
            for nbr in nbr_num:
                dists[nbr].append(vals[1:nbr + 1])
        table = []
        for nbr in nbr_num:
            joined = torch.cat(dists[nbr])
            table.append([nbr, joined.mean().item(), joined.std().item()])
        print(tabulate(table, headers=['N', 'mean', 'std'], floatfmt='.2f'))

        print('*** Statistics of distances in a N-nearest neighbourhood ***\n'
              '    when randomly moving by different step sizes')
        mve_nom = [1, 2, 5, 10, 20, 50]
        nbr_num = [5, 10, 20, 50, 100, 500]
        dists = {mve: {nbr: [] for nbr in nbr_num} for mve in mve_nom}
        cover = {mve: {nbr: [] for nbr in nbr_num} for mve in mve_nom}
        for ele in cast_list(torch.randint(self.embed.size(0), (50, ))):
            if self.embed[ele].sum() == 0.:
                continue
            # Random direction, normalised to unit length so that ``mve`` is
            # the actual step size.
            vect = torch.rand_like(self.embed[ele])
            vect = vect / torch.norm(vect)
            # Reference neighbourhood of the unshifted embedding.
            ridxs, _ = self.find_neighbours(self.embed[ele], -1, 'euc',
                                            False)
            for mve in mve_nom:
                idxs, vals = self.find_neighbours(self.embed[ele] + vect * mve,
                                                  500, 'euc', False)
                for nbr in nbr_num:
                    dists[mve][nbr].append(vals[1:nbr + 1])
                    cover[mve][nbr].append(
                        compare_idxes(idxs[1:nbr + 1], ridxs[1:nbr + 1]))
        table = []
        for mve in mve_nom:
            row = [mve]
            for nbr in nbr_num:
                dist = torch.cat(dists[mve][nbr])
                row.append(dist.mean().item())
                row.append("{:.1f}%".format(
                    np.mean(cover[mve][nbr]) / nbr * 100))
            table.append(row)

        # Build the headers from ``nbr_num`` so that every column is labelled.
        # The previous hard-coded list stopped at 100 and therefore had fewer
        # entries than the table has columns.
        headers = ['Step']
        for nbr in nbr_num:
            headers.extend(['D-{}'.format(nbr), 'I-{}'.format(nbr)])

        print(tabulate(table, headers=headers, floatfmt='.2f'))
Esempio n. 2
0
 def njvri_subidxs(self, njvri_tags):
     """Gather the entries of ``self.tag_dict`` for a sequence of tag groups.

     Each element of ``njvri_tags`` is used as a key into ``HACK_TAGS``; every
     PTB tag listed there that is also a key of ``self.tag_dict`` contributes
     its entry (converted with ``cast_list``) to the flat result list, in
     iteration order.
     """
     collected = []
     for group in njvri_tags:
         known = (ptb for ptb in HACK_TAGS[group] if ptb in self.tag_dict)
         for ptb in known:
             collected += cast_list(self.tag_dict[ptb])
     return collected
Esempio n. 3
0
    def single_hack(self,
                    words,
                    tags,
                    arcs,
                    rels,
                    raw_words,
                    raw_metric,
                    raw_arcs,
                    forbidden_idxs__,
                    change_positions__,
                    orphans__,
                    verbose=False,
                    max_change_num=1,
                    iter_change_num=1,
                    iter_id=-1):
        """Perform one gradient-guided word-substitution step.

        The positions with the largest embedding-gradient norm (subject to
        the POS / orphan / change-budget masks) are selected, their embedding
        is moved against the gradient by ``config.hkw_step_size`` and
        ``find_replacement`` picks a vocabulary word for the moved vector.

        Args:
            words, tags, arcs, rels: current sentence tensors; indexed as
                ``x[0][i]`` throughout, i.e. a batch of one sentence.
            raw_words: the unmodified sentence (used for logging and passed
                to ``find_replacement``).
            raw_metric: metric of the unmodified sentence; only ``.uas`` is
                read.
            raw_arcs: predicted arcs of the unmodified sentence (log table).
            forbidden_idxs__: list, MUTATED -- the vocabulary id of every
                selected word is appended.
            change_positions__: set, MUTATED -- every selected position is
                added.
            orphans__: iterable of positions that must not be changed.
            verbose: print the per-token log table.
            max_change_num: maximum number of distinct changed positions.
            iter_change_num: number of positions to change in this step.
            iter_id: iteration number, used for logging only.

        Returns:
            ``{'code': 404, 'info': str}`` when nothing could be changed, or
            ``{'code': 200, 'words', 'attack_metric', 'logtable'}``.
        """
        sent_len = words.size(1)

        # ---- Loss back-propagation ----
        embed_grad = self.backward_loss(
            words,
            tags,
            arcs,
            rels,
            loss_based_on=self.config.hkw_loss_based_on)
        # Sometimes the loss/grad will be zero.
        # Especially in the case of applying pgd_freq>1 to small sentences:
        # e.g., the uas of projected version may be 83.33%
        # while under the case of the unprojected version, the loss is 0.
        # The absolute sum is used so that gradients of opposite sign which
        # happen to cancel are not mistaken for an all-zero gradient.
        if embed_grad.abs().sum() == 0.0:
            return {"code": 404, "info": "Loss is zero."}

        # ---- Select and change a word ----
        grad_norm = embed_grad.norm(dim=2)

        position_mask = [False for _ in range(sent_len)]
        # Mask some positions by POS & <UNK>
        for i in range(sent_len):
            if self.vocab.tags[tags[0][i]] not in HACK_TAGS[
                    self.config.hkw_tag_type]:
                position_mask[i] = True
        # Mask some orphans
        for i in orphans__:
            position_mask[i] = True
        # Once the change budget is used up, only already-changed positions
        # may be modified again.
        if len(change_positions__) >= max_change_num:
            for i in range(sent_len):
                if i not in change_positions__:
                    position_mask[i] = True
        if all(position_mask):
            return {
                "code": 404,
                "info": "Constrained by tags, no valid word to replace."
            }

        # Push masked positions to the bottom of the ranking.
        for i in range(sent_len):
            if position_mask[i]:
                grad_norm[0][i] = -(grad_norm[0][i] + 1000)

        # Select a word and forbid itself
        word_vids = []  # type: list[torch.Tensor]
        new_word_vids = []  # type: list[torch.Tensor]

        _, topk_idxs = grad_norm[0].sort(descending=True)
        selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                      num_to_select=iter_change_num,
                                      selected=change_positions__,
                                      max_num=max_change_num)
        # The position mask will ensure that at least one word is legal,
        # but the second one may not be allowed
        selected_words = [
            ele for ele in selected_words if position_mask[ele] is False
        ]
        word_sids = torch.tensor(selected_words)

        for word_sid in word_sids:
            word_grad = embed_grad[0][word_sid]

            word_vid = words[0][word_sid]
            emb_to_rpl = self.parser.embed.weight[word_vid]
            forbidden_idxs__.append(word_vid.item())
            change_positions__.add(word_sid.item())

            # Find a word to change with dynamically step
            # Note that it is possible that all words found are not as required, e.g.
            #   all neighbours have different tags.
            delta = word_grad / \
                torch.norm(word_grad) * self.config.hkw_step_size
            changed = emb_to_rpl - delta

            must_tag = self.vocab.tags[tags[0][word_sid].item()]
            new_word_vid, repl_info = self.find_replacement(
                changed,
                must_tag,
                dist_measure=self.config.hkw_dist_measure,
                forbidden_idxs__=forbidden_idxs__,
                repl_method=self.config.hkw_repl_method,
                words=words,
                word_sid=word_sid,
                raw_words=raw_words)

            word_vids.append(word_vid)
            new_word_vids.append(new_word_vid)

        new_words = words.clone()
        exist_change = False
        for sid, new_vid in zip(word_sids, new_word_vids):
            if new_vid is not None:
                new_words[0][sid] = new_vid
                exist_change = True

        if not exist_change:
            log('Attack failed.')
            return {
                'code': 404,
                'info': 'Neighbours of all selected words have different tags.'
            }

        # ---- Evaluating the result ----
        loss, metric = self.task.evaluate(
            [(new_words, tags, None, arcs, rels)],
            mst=self.config.hkw_mst == 'on')

        def _gen_log_table():
            """Build one table row per token comparing raw and attacked parse."""
            new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
            raw_words_text = [self.vocab.words[i.item()] for i in raw_words[0]]
            tags_text = [self.vocab.tags[i.item()] for i in tags[0]]
            _, att_arcs, _ = self.task.predict([(new_words, tags, None)],
                                               mst=self.config.hkw_mst == 'on')

            table = []
            for i in range(sent_len):
                gold_arc = int(arcs[0][i])
                # Predicted arcs are indexed with an offset of one relative
                # to the sentence positions; position 0 gets arc 0.
                raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
                att_arc = 0 if i == 0 else att_arcs[0][i - 1]

                # '&' marks a changed arc whose token or new head was replaced.
                relevant_mask = '&' if \
                    raw_words[0][att_arc] != new_words[0][att_arc] or \
                    raw_words_text[i] != new_words_text[i] else ""
                table.append([
                    i, raw_words_text[i], '>{}'.format(new_words_text[i])
                    if raw_words_text[i] != new_words_text[i] else "*",
                    tags_text[i], gold_arc, raw_arc, '>{}{}'.format(
                        att_arc, relevant_mask) if att_arc != raw_arc else '*',
                    grad_norm[0][i].item()
                ])
            return table

        if verbose:
            print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
            print('Iter {}'.format(iter_id))
            print(tabulate(_gen_log_table(), floatfmt=('.6f')))
            print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

        rpl_detail = ""
        for sid, old_vid, new_vid in zip(word_sids, word_vids, new_word_vids):
            if new_vid is not None:
                rpl_detail += "{}:{}->{}  ".format(
                    self.vocab.words[raw_words[0][sid].item()],
                    self.vocab.words[old_vid.item()],
                    self.vocab.words[new_vid.item()])
        # ``repl_info`` is the info dict of the last processed position.
        log(
            "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
            "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'],
                                                  repl_info['avgd'])
            if 'mind' in repl_info else '', rpl_detail)
        # The detailed table is only produced when UAS actually dropped.
        if metric.uas >= raw_metric.uas - .00001:
            logtable = 'Nothing'
        else:
            logtable = tabulate(_gen_log_table(), floatfmt='.6f')
        return {
            'code': 200,
            'words': new_words,
            'attack_metric': metric,
            'logtable': logtable,
        }
Esempio n. 4
0
def compare_idxes(nbr1, nbr2):
    """Return the number of distinct elements shared by ``nbr1`` and ``nbr2``.

    Both arguments are first converted with ``cast_list`` and de-duplicated.
    """
    first, second = set(cast_list(nbr1)), set(cast_list(nbr2))
    return len(first & second)
Esempio n. 5
0
    def single_hack(self, instance,
                    src_span, tgt_span,
                    iter_id,
                    raw_words, raw_metric, raw_arcs,
                    forbidden_idxs__: list,
                    change_positions__: set,
                    max_change_num,
                    iter_change_num,
                    verbose=False):
        # yapf: enable
        """Perform one substitution step inside ``src_span`` to attack ``tgt_span``.

        Positions are ranked by embedding-gradient norm (or randomly when
        ``config.hks_word_random == "on"``). Punctuation and positions
        outside ``src_span`` are masked out. With ``config.hks_step_size > 0``
        the replacement comes from ``find_replacement``; otherwise a random
        vocabulary id is drawn.

        Args:
            instance: tuple ``(words, tags, chars, arcs, rels)``; tensors are
                indexed as ``x[0][i]``, i.e. a batch of one sentence.
            src_span: indexable; ``src_span[0] <= i <= src_span[1]`` are the
                positions that may be modified.
            tgt_span: span handed to ``ex_span_idx`` to build the evaluation
                mask.
            iter_id: iteration number, used for logging only.
            raw_words, raw_metric, raw_arcs: unmodified sentence, its metric
                (``.uas`` is read) and its predicted arcs (log table).
            forbidden_idxs__: MUTATED -- the id of each replaced word is
                appended (gradient mode only).
            change_positions__: MUTATED -- every selected position is added.
            max_change_num: maximum number of distinct changed positions.
            iter_change_num: number of positions to change in this step.
            verbose: currently unused.

        Returns:
            ``{'code': 200, 'words', 'attack_metric', 'logtable'}``.
        """
        words, tags, chars, arcs, rels = instance
        sent_len = words.size(1)

        # Backward loss
        embed_grad = self.backward_loss(instance=instance,
                                        mask_idxs=ex_span_idx(tgt_span, sent_len),
                                        verbose=True)
        grad_norm = embed_grad.norm(dim=2)
        if self.config.hks_word_random == "on":
            grad_norm = torch.rand(grad_norm.size(), device=grad_norm.device)

        position_mask = [False for _ in range(sent_len)]
        # Mask punctuation and everything outside the source span
        for i in range(sent_len):
            if rels[0][i].item(
            ) == self.vocab.rel_dict['punct'] or not src_span[0] <= i <= src_span[1]:
                position_mask[i] = True
        # Check if the number of changed words exceeds the max value
        if len(change_positions__) >= max_change_num:
            for i in range(sent_len):
                if i not in change_positions__:
                    position_mask[i] = True

        # Push masked positions to the bottom of the ranking.
        for i in range(sent_len):
            if position_mask[i]:
                grad_norm[0][i] = -(grad_norm[0][i] + 1000)

        # Select a word and forbid itself
        word_vids = []  # type: list[torch.Tensor]
        new_word_vids = []  # type: list[torch.Tensor]

        _, topk_idxs = grad_norm[0].sort(descending=True)
        selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                      num_to_select=iter_change_num,
                                      selected=change_positions__,
                                      max_num=max_change_num)
        # The position mask will ensure that at least one word is legal,
        # but the second one may not be allowed
        selected_words = [ele for ele in selected_words if position_mask[ele] is False]
        word_sids = torch.tensor(selected_words)

        # Defined up front so the log call below cannot hit an unbound name
        # when no position survives the mask.
        repl_info = {}
        for word_sid in word_sids:
            word_vid = words[0][word_sid]
            emb_to_rpl = self.parser.embed.weight[word_vid]

            if self.config.hks_step_size > 0:
                word_grad = embed_grad[0][word_sid]
                delta = word_grad / \
                    torch.norm(word_grad) * self.config.hks_step_size
                changed = emb_to_rpl - delta

                tag_type = self.config.hks_constraint
                if tag_type == 'any':
                    must_tag = None
                elif tag_type == 'same':
                    must_tag = self.vocab.tags[tags[0][word_sid].item()]
                elif re.match("[njvri]+", tag_type):
                    must_tag = HACK_TAGS[tag_type]
                else:
                    raise Exception
                forbidden_idxs__.append(word_vid)
                change_positions__.add(word_sid.item())
                new_word_vid, repl_info = self.find_replacement(
                    changed,
                    must_tag,
                    dist_measure=self.config.hks_dist_measure,
                    forbidden_idxs__=forbidden_idxs__,
                    repl_method='tagdict',
                    words=words,
                    word_sid=word_sid)
            else:
                # random.randint is inclusive on both ends, so the upper bound
                # must be n_words - 1 to stay inside the vocabulary.
                new_word_vid = random.randint(0, self.vocab.n_words - 1)
                while new_word_vid in [self.vocab.pad_index, self.vocab.word_dict["<root>"]]:
                    new_word_vid = random.randint(0, self.vocab.n_words - 1)
                new_word_vid = torch.tensor(new_word_vid, device=words.device)
                repl_info = {}
                change_positions__.add(word_sid.item())

            word_vids.append(word_vid)
            new_word_vids.append(new_word_vid)

        new_words = words.clone()
        for sid, new_vid in zip(word_sids, new_word_vids):
            if new_vid is not None:
                new_words[0][sid] = new_vid

        # ---- Evaluating the result ----
        metric = self.task.partial_evaluate(instance=(new_words, tags, None, arcs, rels),
                                            mask_idxs=ex_span_idx(tgt_span, sent_len),
                                            mst=self.config.hks_mst == 'on')
        _, att_arcs, _ = self.task.predict([(new_words, tags, None)],
                                           mst=self.config.hks_mst == 'on')

        rpl_detail = ""
        for sid, old_vid, new_vid in zip(word_sids, word_vids, new_word_vids):
            if new_vid is not None:
                rpl_detail += "{}:{}->{}  ".format(
                    self.vocab.words[raw_words[0][sid].item()],
                    self.vocab.words[old_vid.item()],
                    self.vocab.words[new_vid.item()])

        log(
            "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
            "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'], repl_info['avgd'])
            if 'mind' in repl_info else '', rpl_detail)
        # The detailed table is only produced when UAS actually dropped.
        if metric.uas >= raw_metric.uas - .00001:
            info = 'Nothing'
        else:
            info = tabulate(self._gen_log_table(raw_words, new_words, tags, arcs, rels, raw_arcs,
                                                att_arcs, src_span, tgt_span),
                            floatfmt='.6f')
        return {
            'code': 200,
            'words': new_words,
            'attack_metric': metric,
            'logtable': info,
        }
Esempio n. 6
0
    def random_hack(self, instance, sentence, tgt_span_lst, src_span):
        """Black-box attack: try random substitutions inside ``src_span``.

        ``config.hks_cand_num`` candidate sentences are generated by
        substituting up to ``config.hks_max_change`` positions of
        ``src_span``. The candidates are parsed in one batch (without MST)
        and every candidate whose arcs differ from the raw prediction is
        re-evaluated with MST, until one lowers the UAS on the target spans.

        Args:
            instance: tuple ``(words, tags, chars, arcs, rels)``.
            sentence: unused here.
            tgt_span_lst: list of target spans; ``span[0]..span[1]`` are
                the covered positions and ``span[2]`` is excluded from
                evaluation.
            src_span: indexable; ``src_span[0]..src_span[1]`` are the
                positions that may be substituted.

        Returns:
            ``defaultdict`` (missing keys give ``-1``) with ``succ``,
            ``att_id`` (index of the successful candidate, ``nan`` on
            failure), ``num_changed``, ``time`` and ``logtable``.
        """
        t0 = time.time()
        words, tags, _, arcs, rels = instance
        sent_len = words.size(1)

        raw_words_lst = cast_list(words)
        idxs = gen_idxs_to_substitute(list(range(src_span[0], src_span[1] + 1)),
                                      int(self.config.hks_max_change), self.config.hks_cand_num)
        if self.config.hks_blk_repl_tag == 'any':
            cand_words_lst = subsitute_by_idxs(raw_words_lst, idxs, self.blackbox_sub_idxs)
        elif re.match("[njvri]+", self.config.hks_blk_repl_tag):
            cand_words_lst = subsitute_by_idxs(raw_words_lst, idxs,
                                               self.njvri_subidxs(self.config.hks_blk_repl_tag))
        elif re.match("keep", self.config.hks_blk_repl_tag):
            # One candidate pool per source position, derived from that
            # position's own tag.
            tag_texts = cast_list(tags[0])
            vocab_idxs_lst = []
            for i in range(src_span[0], src_span[1] + 1):
                vocab_idxs_lst.append(self.univ_subidxs(HACK_TAGS.ptb2uni(self.vocab.tags[tag_texts[i]])))
            cand_words_lst = subsitute_by_idxs_2(raw_words_lst, idxs, src_span[0], vocab_idxs_lst)
        else:
            raise Exception

        # First index is raw words
        all_words = torch.tensor([raw_words_lst] + cand_words_lst, device=words.device)
        raw_words = all_words[0:1]
        cand_words = all_words[1:]

        tgt_idxs = []
        for tgt_span in tgt_span_lst:
            tgt_idxs.extend([i for i in range(tgt_span[0], tgt_span[1] + 1) if i != tgt_span[2]])
        ex_tgt_idxs = [i for i in range(sent_len) if i not in tgt_idxs]

        # Cheap batched pass over raw + all candidates; only the predicted
        # arcs are needed to spot candidates that change the parse.
        pred_arcs, _, _, _ = self.task.partial_evaluate(
            (all_words, tags.expand_as(all_words), None, arcs.expand_as(all_words),
             rels.expand_as(all_words)),
            mask_idxs=ex_tgt_idxs,
            mst=False,
            return_metric=False)

        raw_metric = self.task.partial_evaluate((raw_words, tags, None, arcs, rels),
                                                ex_tgt_idxs,
                                                mst=True)

        succ = False
        arc_delta = (pred_arcs[1:] - pred_arcs[0]).abs().sum(1)
        if arc_delta.sum() == 0:
            # direct forward to the last attacked sentence
            att_id = self.config.hks_cand_num - 1
        else:
            for att_id in range(0, self.config.hks_cand_num):
                if arc_delta[att_id] != 0:
                    att_metric = self.task.partial_evaluate(
                        (cand_words[att_id:att_id + 1], tags, None, arcs, rels),
                        ex_tgt_idxs,
                        mst=True)

                    if att_metric.uas < raw_metric.uas - 0.0001:
                        succ = True
                        break

        # On failure att_id ends up at the last candidate index.
        t1 = time.time()

        _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)], mst=True)
        _, att_arcs, _ = self.task.predict([(cand_words[att_id:att_id + 1], tags, None)],
                                           mst=True)

        if not succ:
            info = 'Nothing'
        else:
            # Find the first target span (other than the source span) that
            # lost correct arcs; ``tgt_span`` keeps that span after the loop,
            # or the last span when none qualifies.
            for tgt_span in tgt_span_lst:
                if tgt_span == src_span:
                    continue
                raw_span_corr = 0
                att_span_corr = 0
                for i in range(tgt_span[0], tgt_span[1] + 1):
                    if i == tgt_span[2]:
                        continue
                    if raw_arcs[0][i - 1] == arcs[0][i - 1]:
                        raw_span_corr += 1
                    if att_arcs[0][i - 1] == arcs[0][i - 1]:
                        att_span_corr += 1
                if att_span_corr < raw_span_corr:
                    break

            info = tabulate(self._gen_log_table(words, cand_words[att_id:att_id + 1], tags, arcs,
                                                rels, raw_arcs, att_arcs, src_span, tgt_span),
                            floatfmt='.6f')

        return defaultdict(
            lambda: -1, {
                "succ": 1 if succ else 0,
                # Keyed on ``succ`` rather than on the index value, so that a
                # success on the very last candidate is not reported as nan.
                "att_id": att_id if succ else np.nan,
                "num_changed": self.config.hks_max_change,
                "time": t1 - t0,
                "logtable": info
            })
Esempio n. 7
0
 def univ_subidxs(self, univ_tag):
     """Gather the entries of ``self.tag_dict`` for one universal tag.

     ``HACK_TAGS.uni2ptb(univ_tag)`` yields the PTB tags to look up; each one
     that is a key of ``self.tag_dict`` contributes its entry (converted with
     ``cast_list``) to the flat result list, in iteration order.
     """
     collected = []
     known = (ptb for ptb in HACK_TAGS.uni2ptb(univ_tag) if ptb in self.tag_dict)
     for ptb in known:
         collected += cast_list(self.tag_dict[ptb])
     return collected
Esempio n. 8
0
    def single_hack(self,
                    words, tags, chars, arcs, rels,
                    raw_chars, raw_metric, raw_arcs,
                    forbidden_idxs__,
                    change_positions__,
                    verbose=False,
                    max_change_num=1,
                    iter_change_num=1,
                    iter_id=-1):
        """Perform one gradient-guided character-substitution step.

        Words are ranked by word-gradient norm (punctuation and, once the
        budget is exhausted, unchanged words are masked). In every selected
        word one character is moved against its gradient by
        ``config.hkc_step_size`` and replaced by the nearest admissible
        character embedding.

        Args:
            words, tags, chars, arcs, rels: current sentence tensors; indexed
                as ``x[0][i]``, i.e. a batch of one sentence. ``chars`` is
                MUTATED in 'young' mode (dropped positions are restored from
                ``raw_chars``).
            raw_chars: characters of the unmodified sentence.
            raw_metric: metric of the unmodified sentence (``.uas`` is read).
            raw_arcs: predicted arcs of the unmodified sentence (log table).
            forbidden_idxs__: mapping word position -> list of character ids
                already used there, MUTATED.
            change_positions__: mapping word position -> changed character
                position inside that word, MUTATED.
            verbose: print the per-token log table.
            max_change_num: maximum number of distinct changed words.
            iter_change_num: number of words to change in this step.
            iter_id: iteration number, used for logging only.

        Returns:
            ``{'code': 404, 'info': str}`` when every position is masked, or
            ``{'code': 200, 'chars', 'attack_metric', 'logtable'}``.

        Raises:
            Exception: if ``config.hkc_selection`` is neither ``'elder'`` nor
                ``'young'``.
        """
        sent_len = words.size(1)

        # ---- Loss back-propagation ----
        char_grads, word_grads = self.backward_loss(words, chars, arcs, rels)

        # ---- Select and change a word ----
        word_grad_norm = word_grads.norm(dim=-1)   # 1 x length
        # 1 x length x max_char_length
        char_grad_norm = char_grads.norm(dim=-1)

        position_mask = [False for _ in range(sent_len)]
        # Mask punctuation
        for i in range(sent_len):
            if rels[0][i].item() == self.vocab.rel_dict['punct']:
                position_mask[i] = True
        # Check if the number of changed words exceeds the max value
        if len(change_positions__) >= max_change_num:
            for i in range(sent_len):
                if i not in change_positions__:
                    position_mask[i] = True
        if all(position_mask):
            return {"code": 404, "info": "Constrained by tags, no valid word to replace."}

        # Push masked positions to the bottom of the ranking.
        for i in range(sent_len):
            if position_mask[i]:
                word_grad_norm[0][i] = -(word_grad_norm[0][i] + 1000)
                char_grad_norm[0][i] = -(char_grad_norm[0][i] + 1000)

        # Select a word and forbid itself
        char_vids = []
        new_char_vids = []

        selection = self.config.hkc_selection
        if selection not in ('elder', 'young'):
            raise Exception

        _, topk_idxs = word_grad_norm[0].sort(descending=True)
        if selection == 'elder':
            selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                          num_to_select=iter_change_num,
                                          selected=change_positions__,
                                          max_num=max_change_num)
        else:
            selected_words, ex_words = young_select(ordered_idxs=cast_list(topk_idxs),
                                                    num_to_select=iter_change_num,
                                                    selected=change_positions__,
                                                    max_num=max_change_num)
            # Undo the replacements that were evicted from the change set.
            for ele in ex_words:
                change_positions__.pop(ele)
                forbidden_idxs__.pop(ele)
                chars[0][ele] = raw_chars[0][ele]
                # Log the word at the dropped position (the previous code used
                # a stale loop index and indexed the batch dimension).
                log('Drop elder replacement',
                    self.vocab.words[words[0][ele].item()])

        # An already-changed word keeps its character position; otherwise the
        # character with the largest gradient norm is chosen.
        selected_chars = []
        for ele in selected_words:
            if ele in change_positions__:
                selected_chars.append(change_positions__[ele])
            else:
                selected_chars.append(char_grad_norm[0][ele].argmax())
        word_sids = torch.tensor(selected_words)
        char_wids = torch.tensor(selected_chars)

        for word_sid, char_wid in zip(word_sids, char_wids):

            char_grad = char_grads[0][word_sid][char_wid]
            char_vid = chars[0][word_sid][char_wid]

            emb_to_rpl = self.parser.char_lstm.embed.weight[char_vid]
            forbidden_idxs__[word_sid.item()].append(char_vid.item())
            change_positions__[word_sid.item()] = char_wid.item()

            # Move the character embedding against the gradient and look for
            # the nearest admissible character.
            delta = char_grad / \
                torch.norm(char_grad) * self.config.hkc_step_size
            changed = emb_to_rpl - delta

            dist = {'euc': euc_dist, 'cos': cos_dist}[
                self.config.hkc_dist_measure](changed, self.parser.char_lstm.embed.weight)
            vals, idxs = dist.sort()
            # Fall back to the current character when every candidate is
            # forbidden, instead of reusing a stale / unbound value.
            new_char_vid = char_vid
            for ele in idxs:
                if ele.item() not in forbidden_idxs__[word_sid.item()] and \
                        ele.item() not in [self.vocab.pad_index, self.vocab.unk_index]:
                    new_char_vid = ele
                    break

            char_vids.append(char_vid)
            new_char_vids.append(new_char_vid)

        new_chars = chars.clone()
        for word_sid, char_wid, new_vid in zip(word_sids, char_wids, new_char_vids):
            new_chars[0][word_sid][char_wid] = new_vid

        # ---- Evaluating the result ----
        loss, metric = self.task.evaluate(
            [(words, tags, new_chars, arcs, rels)],
            mst=self.config.hkc_mst == 'on')

        def _gen_log_table():
            """Build one table row per token comparing raw and attacked parse."""
            new_words_text = [self.vocab.id2char(ele) for ele in new_chars[0]]
            raw_words_text = [self.vocab.id2char(ele) for ele in raw_chars[0]]
            tags_text = [self.vocab.tags[ele.item()] for ele in tags[0]]
            _, att_arcs, _ = self.task.predict([(words, tags, new_chars)],
                                               mst=self.config.hkc_mst == 'on')

            table = []
            for i in range(sent_len):
                gold_arc = int(arcs[0][i])
                # Predicted arcs are indexed with an offset of one relative
                # to the sentence positions; position 0 gets arc 0.
                raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
                att_arc = 0 if i == 0 else att_arcs[0][i - 1]

                # '&' marks a changed arc whose token or new head was modified.
                relevant_mask = '&' if \
                    raw_words_text[att_arc] != new_words_text[att_arc] \
                    or raw_words_text[i] != new_words_text[i] \
                    else ""
                table.append([
                    i,
                    raw_words_text[i],
                    '>{}'.format(
                        new_words_text[i]) if raw_words_text[i] != new_words_text[i] else "*",
                    tags_text[i],
                    gold_arc,
                    raw_arc,
                    '>{}{}'.format(
                        att_arc, relevant_mask) if att_arc != raw_arc else '*',
                    word_grad_norm[0][i].item()
                ])
            return table

        if verbose:
            print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
            print('Iter {}'.format(iter_id))
            print(tabulate(_gen_log_table(), floatfmt=('.6f')))
            print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

        rpl_detail = ""
        for word_sid in word_sids:
            rpl_detail += "{}:{}->{}  ".format(
                self.vocab.id2char(raw_chars[0, word_sid]),
                self.vocab.id2char(chars[0, word_sid]),
                self.vocab.id2char(new_chars[0, word_sid]))
        log("iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
            rpl_detail
            )
        # The detailed table is only produced when UAS actually dropped.
        if metric.uas >= raw_metric.uas - .00001:
            logtable = 'Nothing'
        else:
            logtable = tabulate(_gen_log_table(), floatfmt='.6f')
        return {
            'code': 200,
            'chars': new_chars,
            'attack_metric': metric,
            'logtable': logtable,
        }
Esempio n. 9
0
 def id2char(self, ids):
     """Join the characters for ``ids`` into one string, skipping id 0.

     ``ids`` is converted with ``cast_list``; each remaining id indexes
     ``self.chars``. Id 0 is presumably the padding index -- TODO confirm.
     """
     pieces = []
     for char_id in cast_list(ids):
         if char_id != 0:
             pieces.append(self.chars[char_id])
     return ''.join(pieces)
Esempio n. 10
0
 def id2rel(self, ids):
     """Map ``ids`` (converted with ``cast_list``) to entries of ``self.rels``."""
     labels = []
     for rel_id in cast_list(ids):
         labels.append(self.rels[rel_id])
     return labels
Esempio n. 11
0
 def id2tag(self, ids):
     """Map ``ids`` (converted with ``cast_list``) to entries of ``self.tags``."""
     labels = []
     for tag_id in cast_list(ids):
         labels.append(self.tags[tag_id])
     return labels
Esempio n. 12
0
 def id2word(self, ids):
     """Map ``ids`` (converted with ``cast_list``) to entries of ``self.words``."""
     tokens = []
     for word_id in cast_list(ids):
         tokens.append(self.words[word_id])
     return tokens
Esempio n. 13
0
    def find_replacement(
        self,
        changed,
        must_tags,
        dist_measure,
        forbidden_idxs__,
        repl_method='tagdict',
        words=None,
        word_sid=None,  # Only need when using a tagger
        raw_words=None,
    ) -> (Optional[torch.Tensor], dict):
        """Pick a vocabulary word to stand in for the embedding ``changed``.

        Args:
            changed: the perturbed embedding vector to search around.
            must_tags: tag(s) the replacement has to carry -- a single tag
                string, a collection of tags, or ``None`` for any tag in
                ``self.vocab.tags``.
            dist_measure: ``'euc'`` or ``'cos'``.
            forbidden_idxs__: vocabulary ids that must not be returned.
            repl_method: ``'lstm'`` (neural tagger filter), ``'3gram'`` /
                ``'crf'`` (statistical tagger filter), ``'tagdict'`` (tag
                dictionary mask) or ``'bertag'`` (tag dictionary and BERT
                mask).
            words: current sentence; required by the tagger-based methods.
            word_sid: position being replaced (tensor); required by the
                tagger-based methods and by ``'bertag'``.
            raw_words: unmodified sentence; required by ``'bertag'``.

        Returns:
            ``(new_word_vid, info)``. ``new_word_vid`` is ``None`` when no
            admissible word exists. ``info`` holds ``'avgd'`` / ``'mind'``
            for the tagger-based methods and is empty otherwise.

        Raises:
            NotImplementedError: for an unknown ``repl_method``.
        """
        if must_tags is None:
            must_tags = tuple(self.vocab.tags)
        if isinstance(must_tags, str):
            must_tags = (must_tags, )

        # NOTE(review): find_neighbours is unpacked as (dists, idxs) below;
        # confirm this order against its definition.
        if repl_method == 'lstm':
            # Pipeline:
            #    64 minimum dists
            # -> Filtered by a NN tagger
            # -> Smallest one
            n_cands = 64
            # One copy of the sentence per candidate. The second repeat
            # factor must be 1: the previous ``words.size(1)`` tiled the
            # sentence along the sequence axis as well.
            words = words.repeat(n_cands, 1)
            dists, idxs = self.embed_searcher.find_neighbours(
                changed, n_cands, dist_measure, False)
            for i, ele in enumerate(idxs):
                words[i][word_sid] = ele
            self.nn_tagger.eval()
            s_tags = self.nn_tagger(words)
            pred_tags = s_tags.argmax(-1)[:, word_sid]
            pred_tags = pred_tags.cpu().numpy().tolist()
            new_word_vid = None
            # First candidate that is tagged acceptably and is not forbidden.
            for i, ele in enumerate(pred_tags):
                if self.vocab.tags[ele] in must_tags:
                    if idxs[i] not in forbidden_idxs__:
                        new_word_vid = idxs[i]
                        break
            return new_word_vid, {
                "avgd": dists.mean().item(),
                "mind": dists.min().item()
            }
        elif repl_method in ['3gram', 'crf']:
            # Pipeline:
            #    64 minimum dists
            # -> Filtered by a Statistical tagger
            # -> Smallest one
            tagger = self.trigram_tagger if repl_method == '3gram' else self.crf_tagger
            word_texts = self.vocab.id2word(words)
            word_sid = word_sid.item()

            dists, idxs = self.embed_searcher.find_neighbours(
                changed, 64, dist_measure, False)

            # One candidate sentence (as text) per neighbour.
            cands = []
            for ele in cast_list(idxs):
                cand = word_texts.copy()
                cand[word_sid] = self.vocab.words[ele]
                cands.append(cand)

            pred_tags = tagger.tag_sents(cands)
            s_tags = [ele[word_sid][1] for ele in pred_tags]

            new_word_vid = None
            # First candidate that is tagged acceptably and is not forbidden.
            for i, ele in enumerate(s_tags):
                if ele in must_tags:
                    if idxs[i] not in forbidden_idxs__:
                        new_word_vid = idxs[i]
                        break
            return new_word_vid, {
                "avgd": dists.mean().item(),
                "mind": dists.min().item()
            }
        elif repl_method in ['tagdict', 'bertag']:
            # Pipeline:
            #    All dists/Bert filtered dists
            # -> Filtered by a tag dict
            # -> Smallest one
            dist = {
                'euc': euc_dist,
                'cos': cos_dist
            }[dist_measure](changed, self.parser.embed.weight)

            # Mask illegal words by its POS
            msk = self._gen_tag_mask(must_tags, dist.device, dist.size())
            if repl_method == 'bertag':
                # Mask illegal words by BERT
                bert_mask = self._gen_bert_mask(
                    " ".join(self.vocab.id2word(raw_words)[1:]),
                    word_sid.item() - 1, dist.device, dist.size())
                msk = msk * bert_mask

            # 1000 acts as the "not allowed" distance sentinel.
            dist.masked_fill_((1 - msk).bool(), 1000.)
            for ele in forbidden_idxs__:
                dist[ele] = 1000.
            mindist = dist.min()
            if abs(mindist - 1000.) < 0.001:
                # Every word is masked or forbidden.
                new_word_vid = None
            else:
                new_word_vid = dist.argmin()
            return new_word_vid, {}
        else:
            raise NotImplementedError