def show_embedding_info(self):
    """Print diagnostic statistics about the embedding matrix ``self.embed``.

    Three reports are printed:

    1. Mean/std of the raw parameters and of the per-row 2-norms.
    2. For up to 50 randomly sampled rows, mean/std of the euclidean
       distances to their N nearest neighbours, for several N.
    3. For another random sample, after moving each row by a random unit
       direction scaled by several step sizes: the mean distance to the new
       N nearest neighbours (``D-N``) and the percentage of the original N
       neighbours that are still among them (``I-N``).

    Sampled rows whose entries sum to zero are skipped.
    """
    print('*** Statistics of parameters and 2-norm ***')
    show_mean_std(self.embed, 'Param')
    show_mean_std(torch.norm(self.embed, p=2, dim=1), 'Norm')

    print('*** Statistics of distances in a N-nearest neighbourhood ***')
    nbr_num = [5, 10, 20, 50, 100, 200, 500, 10000, 20000]
    dists = {nbr: [] for nbr in nbr_num}
    for ele in cast_list(torch.randint(self.embed.size(0), (50, ))):
        if self.embed[ele].sum() == 0.:
            continue
        _, vals = self.find_neighbours(ele, -1, 'euc', False)
        for nbr in nbr_num:
            # Position 0 is skipped -- presumably the query row itself.
            dists[nbr].append(vals[1:nbr + 1])
    table = []
    for nbr in nbr_num:
        dists[nbr] = torch.cat(dists[nbr])
        table.append(
            [nbr, dists[nbr].mean().item(), dists[nbr].std().item()])
    print(tabulate(table, headers=['N', 'mean', 'std'], floatfmt='.2f'))

    print('*** Statistics of distances in a N-nearest neighbourhood ***\n'
          ' when randomly moving by different step sizes')
    mve_nom = [1, 2, 5, 10, 20, 50]
    nbr_num = [5, 10, 20, 50, 100, 500]
    dists = {mve: {nbr: [] for nbr in nbr_num} for mve in mve_nom}
    cover = {mve: {nbr: [] for nbr in nbr_num} for mve in mve_nom}
    for ele in cast_list(torch.randint(self.embed.size(0), (50, ))):
        if self.embed[ele].sum() == 0.:
            continue
        # NOTE(review): torch.rand_like samples from [0, 1), so every move
        # direction lies in the positive orthant; torch.randn_like would
        # give an isotropic direction. Kept to preserve existing numbers.
        vect = torch.rand_like(self.embed[ele])
        vect = vect / torch.norm(vect)
        ridxs, _ = self.find_neighbours(self.embed[ele], -1, 'euc', False)
        for mve in mve_nom:
            idxs, vals = self.find_neighbours(self.embed[ele] + vect * mve,
                                              500, 'euc', False)
            for nbr in nbr_num:
                dists[mve][nbr].append(vals[1:nbr + 1])
                cover[mve][nbr].append(
                    compare_idxes(idxs[1:nbr + 1], ridxs[1:nbr + 1]))
    table = []
    for mve in mve_nom:
        row = [mve]
        for nbr in nbr_num:
            dist = torch.cat(dists[mve][nbr])
            row.append(dist.mean().item())
            row.append("{:.1f}%".format(
                np.mean(cover[mve][nbr]) / nbr * 100))
        table.append(row)
    # One 'D-N' and one 'I-N' column per neighbourhood size. The header list
    # used to be hard-coded without D-500/I-500; tabulate left-pads a short
    # header list, which shifted every label by two columns.
    headers = ['Step']
    for nbr in nbr_num:
        headers.extend(['D-{}'.format(nbr), 'I-{}'.format(nbr)])
    print(tabulate(table, headers=headers, floatfmt='.2f'))
def njvri_subidxs(self, njvri_tags):
    """Gather the vocabulary indices for a set of coarse tag keys.

    Each element of ``njvri_tags`` (e.g. each character of a string such as
    ``"nv"``) is looked up in ``HACK_TAGS`` to obtain PTB tags. The index
    lists stored in ``self.tag_dict`` for those PTB tags are concatenated;
    PTB tags missing from ``self.tag_dict`` are skipped. Duplicates are kept.
    """
    indices = []
    for coarse_tag in njvri_tags:
        known_ptb_tags = (ptb for ptb in HACK_TAGS[coarse_tag]
                          if ptb in self.tag_dict)
        for ptb in known_ptb_tags:
            indices += cast_list(self.tag_dict[ptb])
    return indices
def single_hack(self,
                words,
                tags,
                arcs,
                rels,
                raw_words,
                raw_metric,
                raw_arcs,
                forbidden_idxs__,
                change_positions__,
                orphans__,
                verbose=False,
                max_change_num=1,
                iter_change_num=1,
                iter_id=-1):
    """Run one iteration of the gradient-guided word-substitution attack.

    The loss gradient w.r.t. the word embeddings is computed. Positions that
    may not be edited are masked, and the remaining positions are ranked by
    gradient norm. Up to ``iter_change_num`` of them are chosen with
    ``elder_select``. Each chosen word's embedding is moved against its
    gradient by ``config.hkw_step_size``, and ``self.find_replacement``
    searches for a substitute carrying the same POS tag. The perturbed
    sentence is then evaluated.

    Args:
        words: word-id tensor, indexed as ``words[0][i]`` -- presumably of
            shape (1, sent_len); TODO confirm.
        tags, arcs, rels: gold tags / heads / relations, indexed like
            ``words``.
        raw_words: word ids of the unattacked sentence (used for logging and
            forwarded to ``find_replacement``).
        raw_metric: metric of the unattacked sentence; only ``.uas`` is read.
        raw_arcs: predicted heads for the unattacked sentence, read as
            ``raw_arcs[0][i - 1]`` (apparently no entry for position 0).
        forbidden_idxs__: list of word ids that may not be used as
            replacements; the id of every replaced word is appended.
        change_positions__: set of positions changed so far; updated in
            place.
        orphans__: iterable of positions that must not be edited.
        verbose: if true, print the per-token log table.
        max_change_num: cap on the number of distinct changed positions.
        iter_change_num: number of positions to try to change this iteration.
        iter_id: iteration number, used only in log output.

    Returns:
        A dict in one of two forms:

        * ``{"code": 404, "info": ...}`` when the gradient sums to zero, when
          no position is editable, or when no replacement was found.
        * ``{'code': 200, 'words': new_words, 'attack_metric': metric,
          'logtable': ...}`` otherwise. ``logtable`` is ``'Nothing'`` unless
          the UAS dropped below ``raw_metric.uas``.
    """
    sent_len = words.size(1)
    """ Loss back-propagation """
    embed_grad = self.backward_loss(
        words, tags, arcs, rels, loss_based_on=self.config.hkw_loss_based_on)
    # Sometimes the loss/grad will be zero.
    # Especially in the case of applying pgd_freq>1 to small sentences:
    # e.g., the uas of projected version may be 83.33%
    # while under the case of the unprojected version, the loss is 0.
    # NOTE(review): a sum of zero does not imply an all-zero gradient
    # (positive and negative entries can cancel); `not embed_grad.any()`
    # would test the stated intent exactly.
    if torch.sum(embed_grad) == 0.0:
        return {"code": 404, "info": "Loss is zero."}
    """ Select and change a word """
    grad_norm = embed_grad.norm(dim=2)
    position_mask = [False for _ in range(words.size(1))]
    # Mask positions whose gold POS tag is not in the configured attack set
    for i in range(sent_len):
        if self.vocab.tags[tags[0][i]] not in HACK_TAGS[
                self.config.hkw_tag_type]:
            position_mask[i] = True
    # Mask some orphans
    for i in orphans__:
        position_mask[i] = True
    # Once the change budget is used up, only already-changed positions may
    # be edited again
    if len(change_positions__) >= max_change_num:
        for i in range(sent_len):
            if i not in change_positions__:
                position_mask[i] = True
    if all(position_mask):
        return {
            "code": 404,
            "info": "Constrained by tags, no valid word to replace."
        }
    # Norms are >= 0, so -(norm + 1000) pushes every masked position below
    # all legal ones in the descending sort below
    for i in range(sent_len):
        if position_mask[i]:
            grad_norm[0][i] = -(grad_norm[0][i] + 1000)
    # Select a word and forbid itself
    word_sids = []  # type: list[torch.Tensor]
    word_vids = []  # type: list[torch.Tensor]
    new_word_vids = []  # type: list[torch.Tensor]
    _, topk_idxs = grad_norm[0].sort(descending=True)
    selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                  num_to_select=iter_change_num,
                                  selected=change_positions__,
                                  max_num=max_change_num)
    # The position mask will ensure that at least one word is legal,
    # but the second one may not be allowed
    selected_words = [
        ele for ele in selected_words if position_mask[ele] is False
    ]
    word_sids = torch.tensor(selected_words)
    for word_sid in word_sids:
        word_grad = embed_grad[0][word_sid]
        word_vid = words[0][word_sid]
        emb_to_rpl = self.parser.embed.weight[word_vid]
        # Never propose the original word again, and remember the position
        forbidden_idxs__.append(word_vid.item())
        change_positions__.add(word_sid.item())
        # Move the embedding one fixed-length step against the gradient and
        # search for the closest allowed word. It is possible that no word
        # qualifies (e.g. all neighbours have different tags); then
        # find_replacement returns None for this position.
        delta = word_grad / \
            torch.norm(word_grad) * self.config.hkw_step_size
        changed = emb_to_rpl - delta
        must_tag = self.vocab.tags[tags[0][word_sid].item()]
        new_word_vid, repl_info = self.find_replacement(
            changed,
            must_tag,
            dist_measure=self.config.hkw_dist_measure,
            forbidden_idxs__=forbidden_idxs__,
            repl_method=self.config.hkw_repl_method,
            words=words,
            word_sid=word_sid,
            raw_words=raw_words)
        word_vids.append(word_vid)
        new_word_vids.append(new_word_vid)
    # Apply every replacement that was found
    new_words = words.clone()
    exist_change = False
    for i in range(len(word_vids)):
        if new_word_vids[i] is not None:
            new_words[0][word_sids[i]] = new_word_vids[i]
            exist_change = True
    if not exist_change:
        # (A disabled variant here used to add the position to orphans__
        # when config.hkw_selection == 'orphan' and return code 200.)
        log('Attack failed.')
        return {
            'code': 404,
            'info': 'Neighbours of all selected words have different tags.'
        }
    """ Evaluating the result """
    # NOTE(review): new_words_text is not used below; _gen_log_table
    # recomputes its own copy.
    new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
    loss, metric = self.task.evaluate(
        [(new_words, tags, None, arcs, rels)],
        mst=self.config.hkw_mst == 'on')

    def _gen_log_table():
        """Build one row per token, comparing the raw and attacked sentences:
        [index, raw word, '>new word' or '*', tag, gold head, raw predicted
        head, '>attacked head' (suffixed '&' when the new head word or the
        token itself was changed) or '*', masked gradient norm]."""
        new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
        raw_words_text = [self.vocab.words[i.item()] for i in raw_words[0]]
        tags_text = [self.vocab.tags[i.item()] for i in tags[0]]
        _, att_arcs, _ = self.task.predict([(new_words, tags, None)],
                                           mst=self.config.hkw_mst == 'on')
        table = []
        for i in range(sent_len):
            gold_arc = int(arcs[0][i])
            raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
            att_arc = 0 if i == 0 else att_arcs[0][i - 1]
            relevant_mask = '&' if \
                raw_words[0][att_arc] != new_words[0][att_arc] or \
                raw_words_text[i] != new_words_text[i] else ""
            table.append([
                i, raw_words_text[i],
                '>{}'.format(new_words_text[i])
                if raw_words_text[i] != new_words_text[i] else "*",
                tags_text[i], gold_arc, raw_arc,
                '>{}{}'.format(att_arc, relevant_mask)
                if att_arc != raw_arc else '*', grad_norm[0][i].item()
            ])
        return table

    if verbose:
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        print('Iter {}'.format(iter_id))
        print(tabulate(_gen_log_table(), floatfmt=('.6f')))
        print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')
    # "raw word at position : word before this iteration -> new word"
    rpl_detail = ""
    for i in range(len(word_sids)):
        if new_word_vids[i] is not None:
            rpl_detail += "{}:{}->{} ".format(
                self.vocab.words[raw_words[0][word_sids[i]].item()],
                self.vocab.words[word_vids[i].item()],
                self.vocab.words[new_word_vids[i].item()])
    # NOTE(review): repl_info comes from the last loop iteration only.
    log(
        "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
        "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'],
                                             repl_info['avgd'])
        if 'mind' in repl_info else '', rpl_detail)
    if metric.uas >= raw_metric.uas - .00001:
        logtable = 'Nothing'
    else:
        logtable = tabulate(_gen_log_table(), floatfmt='.6f')
    return {
        'code': 200,
        'words': new_words,
        'attack_metric': metric,
        'logtable': logtable,
        # "forbidden_idxs__": forbidden_idxs__,
        # "change_positions__": change_positions__,
    }
def compare_idxes(nbr1, nbr2):
    """Count the distinct indices shared by two neighbour collections.

    Both arguments are converted with ``cast_list`` and deduplicated as sets
    before intersecting.
    """
    first = set(cast_list(nbr1))
    second = set(cast_list(nbr2))
    return len(first & second)
def single_hack(self,
                instance,
                src_span,
                tgt_span,
                iter_id,
                raw_words,
                raw_metric,
                raw_arcs,
                forbidden_idxs__: list,
                change_positions__: set,
                max_change_num,
                iter_change_num,
                verbose=False):
    # yapf: enable
    """Run one iteration of the span-constrained word-substitution attack.

    Only positions inside ``src_span`` (inclusive bounds) that are not
    punctuation may be edited. They are ranked by the gradient norm of the
    loss w.r.t. their embeddings, or by random scores when
    ``config.hks_word_random == "on"``. Up to ``iter_change_num`` of them
    are replaced, subject to at most ``max_change_num`` changed positions
    overall. Loss and evaluation use
    ``mask_idxs=ex_span_idx(tgt_span, sent_len)`` -- presumably every index
    outside the target span; confirm against ``ex_span_idx``.

    A replacement is chosen in one of two ways:

    * If ``config.hks_step_size > 0``, the embedding is moved one step
      against its gradient and ``self.find_replacement`` picks a word under
      the ``config.hks_constraint`` tag rule (``'any'``, ``'same'`` or a
      string of coarse tag letters).
    * Otherwise, a random vocabulary id other than padding and ``<root>``
      is drawn.

    Args:
        instance: ``(words, tags, chars, arcs, rels)``; tensors are indexed
            as ``x[0][i]``, presumably batch size 1.
        forbidden_idxs__: list of ids that may not be used as replacements;
            in the gradient branch the replaced word id is appended.
        change_positions__: set of changed positions; updated in place.
        verbose: currently unused.

    Returns:
        ``{'code': 200, 'words': new_words, 'attack_metric': metric,
        'logtable': info}``. ``info`` is ``'Nothing'`` unless the UAS
        dropped below ``raw_metric.uas``.

    Raises:
        ValueError: if ``config.hks_constraint`` is not a known constraint.
    """
    words, tags, chars, arcs, rels = instance
    sent_len = words.size(1)

    # Backward loss
    embed_grad = self.backward_loss(instance=instance,
                                    mask_idxs=ex_span_idx(tgt_span, sent_len),
                                    verbose=True)
    grad_norm = embed_grad.norm(dim=2)
    if self.config.hks_word_random == "on":
        grad_norm = torch.rand(grad_norm.size(), device=grad_norm.device)

    position_mask = [False for _ in range(words.size(1))]
    # Mask punctuation and every position outside the source span
    for i in range(sent_len):
        if rels[0][i].item() == self.vocab.rel_dict['punct'] \
                or not src_span[0] <= i <= src_span[1]:
            position_mask[i] = True
    # Once the change budget is used up, only already-changed positions may
    # be edited again
    if len(change_positions__) >= max_change_num:
        for i in range(sent_len):
            if i not in change_positions__:
                position_mask[i] = True
    # Scores are >= 0 (norms or rand), so this ranks masked positions last
    for i in range(sent_len):
        if position_mask[i]:
            grad_norm[0][i] = -(grad_norm[0][i] + 1000)

    # Select words and forbid the originals
    word_vids = []  # type: list[torch.Tensor]
    new_word_vids = []  # type: list[torch.Tensor]
    _, topk_idxs = grad_norm[0].sort(descending=True)
    selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                  num_to_select=iter_change_num,
                                  selected=change_positions__,
                                  max_num=max_change_num)
    # elder_select may return masked positions; drop them (this can leave
    # the list empty, e.g. when the source span is all punctuation)
    selected_words = [
        ele for ele in selected_words if position_mask[ele] is False
    ]
    word_sids = torch.tensor(selected_words)
    # Defined up front so the log call below is safe when nothing was
    # selected and the loop body never runs.
    repl_info = {}
    for word_sid in word_sids:
        word_vid = words[0][word_sid]
        emb_to_rpl = self.parser.embed.weight[word_vid]
        if self.config.hks_step_size > 0:
            word_grad = embed_grad[0][word_sid]
            delta = word_grad / \
                torch.norm(word_grad) * self.config.hks_step_size
            changed = emb_to_rpl - delta
            tag_type = self.config.hks_constraint
            if tag_type == 'any':
                must_tag = None
            elif tag_type == 'same':
                must_tag = self.vocab.tags[tags[0][word_sid].item()]
            elif re.match("[njvri]+", tag_type):
                must_tag = HACK_TAGS[tag_type]
            else:
                raise ValueError(
                    "Unknown hks_constraint: {!r}".format(tag_type))
            forbidden_idxs__.append(word_vid)
            change_positions__.add(word_sid.item())
            new_word_vid, repl_info = self.find_replacement(
                changed,
                must_tag,
                dist_measure=self.config.hks_dist_measure,
                forbidden_idxs__=forbidden_idxs__,
                repl_method='tagdict',
                words=words,
                word_sid=word_sid)
        else:
            # Uniformly random replacement. random.randint includes BOTH
            # endpoints, so the upper bound must be n_words - 1 (the first
            # draw previously used n_words and could return an
            # out-of-vocabulary id).
            new_word_vid = random.randint(0, self.vocab.n_words - 1)
            while new_word_vid in [
                    self.vocab.pad_index, self.vocab.word_dict["<root>"]
            ]:
                new_word_vid = random.randint(0, self.vocab.n_words - 1)
            new_word_vid = torch.tensor(new_word_vid, device=words.device)
            repl_info = {}
            change_positions__.add(word_sid.item())
        word_vids.append(word_vid)
        new_word_vids.append(new_word_vid)

    # find_replacement may return None when no word satisfies the tag rule
    new_words = words.clone()
    for i in range(len(word_vids)):
        if new_word_vids[i] is not None:
            new_words[0][word_sids[i]] = new_word_vids[i]

    # Evaluate the result on the target span
    metric = self.task.partial_evaluate(
        instance=(new_words, tags, None, arcs, rels),
        mask_idxs=ex_span_idx(tgt_span, sent_len),
        mst=self.config.hks_mst == 'on')
    att_tags, att_arcs, att_rels = self.task.predict(
        [(new_words, tags, None)], mst=self.config.hks_mst == 'on')

    # "raw word at position : word before this iteration -> new word"
    rpl_detail = ""
    for i in range(len(word_sids)):
        if new_word_vids[i] is not None:
            rpl_detail += "{}:{}->{} ".format(
                self.vocab.words[raw_words[0][word_sids[i]].item()],
                self.vocab.words[word_vids[i].item()],
                self.vocab.words[new_word_vids[i].item()])
    log(
        "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
        "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'],
                                             repl_info['avgd'])
        if 'mind' in repl_info else '', rpl_detail)
    if metric.uas >= raw_metric.uas - .00001:
        info = 'Nothing'
    else:
        info = tabulate(self._gen_log_table(raw_words, new_words, tags, arcs,
                                            rels, raw_arcs, att_arcs,
                                            src_span, tgt_span),
                        floatfmt='.6f')
    return {
        'code': 200,
        'words': new_words,
        'attack_metric': metric,
        'logtable': info,
    }
def random_hack(self, instance, sentence, tgt_span_lst, src_span):
    """Black-box attack: try random substitutions inside ``src_span``.

    Candidate sentences are built by substituting up to
    ``config.hks_max_change`` positions of the source span. The
    substitution vocabulary depends on ``config.hks_blk_repl_tag``:

    * ``'any'``: ``self.blackbox_sub_idxs``.
    * A string of coarse tag letters: ``self.njvri_subidxs``.
    * ``'keep...'``: per-position words sharing the universal tag of the
      gold tag.

    All candidates are parsed in one batch. The first candidate whose
    predicted heads differ from the raw sentence and whose UAS on the
    target spans is lower than the raw UAS counts as a success.

    Args:
        instance: ``(words, tags, chars, arcs, rels)``; tensors indexed as
            ``x[0][i]``, presumably batch size 1.
        sentence: unused.
        tgt_span_lst: spans ``(start, end, head)``; tokens in
            ``[start, end]`` except ``head`` are evaluated.
        src_span: ``(start, end)`` inclusive span that may be modified.

    Returns:
        ``defaultdict(lambda: -1, ...)`` with keys ``succ`` (0/1),
        ``att_id`` (index of the successful candidate, NaN on failure),
        ``num_changed``, ``time`` and ``logtable``.

    Raises:
        ValueError: if ``config.hks_blk_repl_tag`` is not recognised.
    """
    t0 = time.time()
    words, tags, chars, arcs, rels = instance
    sent_len = words.size(1)
    raw_words_lst = cast_list(words)
    idxs = gen_idxs_to_substitute(list(range(src_span[0], src_span[1] + 1)),
                                  int(self.config.hks_max_change),
                                  self.config.hks_cand_num)
    repl_tag = self.config.hks_blk_repl_tag
    if repl_tag == 'any':
        cand_words_lst = subsitute_by_idxs(raw_words_lst, idxs,
                                           self.blackbox_sub_idxs)
    elif re.match("[njvri]+", repl_tag):
        cand_words_lst = subsitute_by_idxs(raw_words_lst, idxs,
                                           self.njvri_subidxs(repl_tag))
    elif re.match("keep", repl_tag):
        # One substitution vocabulary per source position: words whose
        # universal tag matches that position's gold PTB tag
        tag_ids = cast_list(tags[0])
        vocab_idxs_lst = []
        for i in range(src_span[0], src_span[1] + 1):
            vocab_idxs_lst.append(
                self.univ_subidxs(
                    HACK_TAGS.ptb2uni(self.vocab.tags[tag_ids[i]])))
        cand_words_lst = subsitute_by_idxs_2(raw_words_lst, idxs,
                                             src_span[0], vocab_idxs_lst)
    else:
        raise ValueError(
            "Unknown hks_blk_repl_tag: {!r}".format(repl_tag))

    # Row 0 is the raw sentence, rows 1.. are the candidates
    all_words = torch.tensor([raw_words_lst] + cand_words_lst,
                             device=words.device)
    raw_words = all_words[0:1]
    cand_words = all_words[1:]

    # Evaluate only the target spans, excluding each span's head token
    tgt_idxs = set()
    for tgt_span in tgt_span_lst:
        tgt_idxs.update(i for i in range(tgt_span[0], tgt_span[1] + 1)
                        if i != tgt_span[2])
    ex_tgt_idxs = [i for i in range(sent_len) if i not in tgt_idxs]
    pred_arcs, _, _, _ = self.task.partial_evaluate(
        (all_words, tags.expand_as(all_words), None,
         arcs.expand_as(all_words), rels.expand_as(all_words)),
        mask_idxs=ex_tgt_idxs,
        mst=False,
        return_metric=False)
    raw_metric = self.task.partial_evaluate(
        (raw_words, tags, None, arcs, rels), ex_tgt_idxs, mst=True)

    succ = False
    # Number of changed heads per candidate w.r.t. the raw sentence; only
    # candidates that change something are re-evaluated with MST decoding
    arc_delta = (pred_arcs[1:] - pred_arcs[0]).abs().sum(1)
    if arc_delta.sum() == 0:
        # Direct forward to the last attacked sentence
        att_id = self.config.hks_cand_num - 1
    else:
        for att_id in range(0, self.config.hks_cand_num):
            if arc_delta[att_id] != 0:
                att_metric = self.task.partial_evaluate(
                    (cand_words[att_id:att_id + 1], tags, None, arcs, rels),
                    ex_tgt_idxs,
                    mst=True)
                if att_metric.uas < raw_metric.uas - 0.0001:
                    succ = True
                    break
    # On failure att_id is the last candidate index
    t1 = time.time()

    _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)], mst=True)
    _, att_arcs, _ = self.task.predict(
        [(cand_words[att_id:att_id + 1], tags, None)], mst=True)

    if not succ:
        info = 'Nothing'
    else:
        # Pick the first other target span whose accuracy dropped (falls
        # back to the last span) for the log table
        for tgt_span in tgt_span_lst:
            if tgt_span == src_span:
                continue
            raw_span_corr = 0
            att_span_corr = 0
            for i in range(tgt_span[0], tgt_span[1] + 1):
                if i == tgt_span[2]:
                    continue
                if raw_arcs[0][i - 1] == arcs[0][i - 1]:
                    raw_span_corr += 1
                if att_arcs[0][i - 1] == arcs[0][i - 1]:
                    att_span_corr += 1
            if att_span_corr < raw_span_corr:
                break
        info = tabulate(self._gen_log_table(words,
                                            cand_words[att_id:att_id + 1],
                                            tags, arcs, rels, raw_arcs,
                                            att_arcs, src_span, tgt_span),
                        floatfmt='.6f')
    return defaultdict(
        lambda: -1, {
            "succ": 1 if succ else 0,
            # Report the index only on success. The previous test
            # `att_id < hks_cand_num - 1` also hid a success found at the
            # last candidate.
            "att_id": att_id if succ else np.nan,
            "num_changed": self.config.hks_max_change,
            "time": t1 - t0,
            "logtable": info
        })
def univ_subidxs(self, univ_tag):
    """Collect vocabulary indices for a universal POS tag.

    ``HACK_TAGS.uni2ptb`` expands ``univ_tag`` into PTB tags. The index
    lists stored in ``self.tag_dict`` for those PTB tags are concatenated;
    PTB tags absent from ``self.tag_dict`` contribute nothing.
    """
    collected = []
    present = filter(lambda ptb: ptb in self.tag_dict,
                     HACK_TAGS.uni2ptb(univ_tag))
    for ptb in present:
        collected += cast_list(self.tag_dict[ptb])
    return collected
def single_hack(self,
                words,
                tags,
                chars,
                arcs,
                rels,
                raw_chars,
                raw_metric,
                raw_arcs,
                forbidden_idxs__,
                change_positions__,
                verbose=False,
                max_change_num=1,
                iter_change_num=1,
                iter_id=-1):
    """Run one iteration of the gradient-guided character-substitution attack.

    Words (punctuation excluded) are ranked by the gradient norm of their
    word representation. Words are selected according to
    ``config.hkc_selection``:

    * ``'elder'``: keeps earlier choices.
    * ``'young'``: may drop earlier replacements and restore their raw
      characters.

    In every selected word, one character is changed: the one edited
    before, or else the one with the largest char-gradient norm. Its
    embedding is moved against its gradient by ``config.hkc_step_size``,
    and the nearest character that is not forbidden for that position is
    substituted.

    Args:
        words, tags, arcs, rels: indexed as ``x[0][i]``, presumably batch
            size 1.
        chars: char-id tensor indexed as ``chars[0][word][char]``; in
            ``'young'`` mode dropped positions are restored from
            ``raw_chars`` in place.
        raw_chars: char ids of the unattacked sentence.
        raw_metric: metric of the unattacked sentence; only ``.uas`` is read.
        raw_arcs: predicted heads of the unattacked sentence.
        forbidden_idxs__: mapping position -> list of char ids that may not
            be used there; updated in place.
        change_positions__: mapping position -> index of the edited char
            within that word; updated in place.
        verbose: if true, print the per-token log table.
        max_change_num: cap on the number of changed words.
        iter_change_num: number of words to try to change this iteration.
        iter_id: iteration number, used only in log output.

    Returns:
        ``{"code": 404, "info": ...}`` if no position may be edited;
        otherwise ``{'code': 200, 'chars': new_chars, 'attack_metric':
        metric, 'logtable': ...}``. ``logtable`` is ``'Nothing'`` unless
        the UAS dropped below ``raw_metric.uas``.

    Raises:
        ValueError: if ``config.hkc_selection`` is neither 'elder' nor
            'young'.
    """
    sent_len = words.size(1)

    # Loss back-propagation
    char_grads, word_grads = self.backward_loss(words, chars, arcs, rels)

    # Rank positions. Shapes per the original annotations:
    word_grad_norm = word_grads.norm(dim=-1)  # 1 x length
    char_grad_norm = char_grads.norm(dim=-1)  # 1 x length x max_char_length

    position_mask = [False for _ in range(words.size(1))]
    # Never edit punctuation
    for i in range(sent_len):
        if rels[0][i].item() == self.vocab.rel_dict['punct']:
            position_mask[i] = True
    # Once the change budget is used up, only already-changed positions may
    # be edited again
    if len(change_positions__) >= max_change_num:
        for i in range(sent_len):
            if i not in change_positions__:
                position_mask[i] = True
    if all(position_mask):
        return {"code": 404,
                "info": "Constrained by tags, no valid word to replace."}
    # Norms are >= 0, so this ranks masked positions below all legal ones
    for i in range(sent_len):
        if position_mask[i]:
            word_grad_norm[0][i] = -(word_grad_norm[0][i] + 1000)
            char_grad_norm[0][i] = -(char_grad_norm[0][i] + 1000)

    # Select the words to change
    _, topk_idxs = word_grad_norm[0].sort(descending=True)
    if self.config.hkc_selection == 'elder':
        selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                      num_to_select=iter_change_num,
                                      selected=change_positions__,
                                      max_num=max_change_num)
    elif self.config.hkc_selection == 'young':
        selected_words, ex_words = young_select(
            ordered_idxs=cast_list(topk_idxs),
            num_to_select=iter_change_num,
            selected=change_positions__,
            max_num=max_change_num)
        # Undo the replacements young_select evicted
        for ele in ex_words:
            change_positions__.pop(ele)
            forbidden_idxs__.pop(ele)
            chars[0][ele] = raw_chars[0][ele]
            # Log the evicted word itself; this used a stale loop index and
            # the wrong tensor rank (`words[i]`), which raised at runtime.
            log('Drop elder replacement',
                self.vocab.words[words[0][ele].item()])
    else:
        raise ValueError("Unknown hkc_selection: {!r}".format(
            self.config.hkc_selection))

    # Within each selected word, keep editing the previously chosen char,
    # otherwise pick the char with the largest gradient norm
    selected_chars = []
    for ele in selected_words:
        if ele in change_positions__:
            selected_chars.append(change_positions__[ele])
        else:
            selected_chars.append(char_grad_norm[0][ele].argmax().item())
    word_sids = torch.tensor(selected_words)
    char_wids = torch.tensor(selected_chars)

    char_vids = []
    new_char_vids = []
    for word_sid, char_wid in zip(word_sids, char_wids):
        char_grad = char_grads[0][word_sid][char_wid]
        char_vid = chars[0][word_sid][char_wid]
        emb_to_rpl = self.parser.char_lstm.embed.weight[char_vid]
        forbidden = forbidden_idxs__[word_sid.item()]
        forbidden.append(char_vid.item())
        change_positions__[word_sid.item()] = char_wid.item()
        # Move one fixed-length step against the gradient, then take the
        # closest character that is neither forbidden for this position nor
        # the pad/unk index.
        delta = char_grad / torch.norm(char_grad) * self.config.hkc_step_size
        changed = emb_to_rpl - delta
        dist = {'euc': euc_dist, 'cos': cos_dist}[
            self.config.hkc_dist_measure](changed,
                                          self.parser.char_lstm.embed.weight)
        _, idxs = dist.sort()
        # None when every candidate is excluded; previously the variable
        # was left unbound (NameError) or silently reused from the
        # preceding word.
        new_char_vid = None
        for cand in idxs:
            if cand.item() not in forbidden and \
                    cand.item() not in [self.vocab.pad_index,
                                        self.vocab.unk_index]:
                new_char_vid = cand
                break
        char_vids.append(char_vid)
        new_char_vids.append(new_char_vid)

    new_chars = chars.clone()
    for word_sid, char_wid, new_char_vid in zip(word_sids, char_wids,
                                                new_char_vids):
        if new_char_vid is not None:
            new_chars[0][word_sid][char_wid] = new_char_vid

    # Evaluate the result
    _, metric = self.task.evaluate(
        [(words, tags, new_chars, arcs, rels)],
        mst=self.config.hkc_mst == 'on')

    def _gen_log_table():
        """Build one row per token comparing raw and attacked spellings:
        [index, raw word, '>new word' or '*', tag, gold head, raw predicted
        head, '>attacked head' (suffixed '&' when the new head word or the
        token itself was changed) or '*', masked word-gradient norm]."""
        new_words_text = [self.vocab.id2char(ele) for ele in new_chars[0]]
        raw_words_text = [self.vocab.id2char(ele) for ele in raw_chars[0]]
        tags_text = [self.vocab.tags[ele.item()] for ele in tags[0]]
        _, att_arcs, _ = self.task.predict([(words, tags, new_chars)],
                                           mst=self.config.hkc_mst == 'on')
        table = []
        for i in range(sent_len):
            gold_arc = int(arcs[0][i])
            raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
            att_arc = 0 if i == 0 else att_arcs[0][i - 1]
            relevant_mask = '&' if \
                raw_words_text[att_arc] != new_words_text[att_arc] \
                or raw_words_text[i] != new_words_text[i] \
                else ""
            table.append([
                i, raw_words_text[i],
                '>{}'.format(new_words_text[i])
                if raw_words_text[i] != new_words_text[i] else "*",
                tags_text[i], gold_arc, raw_arc,
                '>{}{}'.format(att_arc, relevant_mask)
                if att_arc != raw_arc else '*',
                word_grad_norm[0][i].item()
            ])
        return table

    if verbose:
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        print('Iter {}'.format(iter_id))
        print(tabulate(_gen_log_table(), floatfmt=('.6f')))
        print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')
    # "raw spelling : spelling before this iteration -> new spelling"
    rpl_detail = ""
    for i in range(len(word_sids)):
        rpl_detail += "{}:{}->{} ".format(
            self.vocab.id2char(raw_chars[0, word_sids[i]]),
            self.vocab.id2char(chars[0, word_sids[i]]),
            self.vocab.id2char(new_chars[0, word_sids[i]]))
    log("iter {}, uas {:.4f}, ".format(iter_id, metric.uas), rpl_detail)
    if metric.uas >= raw_metric.uas - .00001:
        logtable = 'Nothing'
    else:
        logtable = tabulate(_gen_log_table(), floatfmt='.6f')
    return {
        'code': 200,
        'chars': new_chars,
        'attack_metric': metric,
        'logtable': logtable,
    }
def id2char(self, ids):
    """Spell out a sequence of character ids as one string.

    Id 0 (presumably the padding character) is skipped.
    """
    char_ids = cast_list(ids)
    return ''.join(self.chars[cid] for cid in char_ids if cid != 0)
def id2rel(self, ids):
    """Translate relation ids into their labels via ``self.rels``."""
    rel_ids = cast_list(ids)
    return [self.rels[rel_id] for rel_id in rel_ids]
def id2tag(self, ids):
    """Translate tag ids into their tag strings via ``self.tags``."""
    tag_ids = cast_list(ids)
    labels = []
    for tag_id in tag_ids:
        labels.append(self.tags[tag_id])
    return labels
def id2word(self, ids):
    """Translate word ids into their surface strings via ``self.words``."""
    word_ids = cast_list(ids)
    return [self.words[word_id] for word_id in word_ids]
def find_replacement(
    self,
    changed,
    must_tags,
    dist_measure,
    forbidden_idxs__,
    repl_method='tagdict',
    words=None,
    word_sid=None,
    # Only need when using a tagger
    raw_words=None,
) -> (Optional[torch.Tensor], dict):
    """Find a vocabulary word near ``changed`` that satisfies a tag constraint.

    Args:
        changed: the perturbed embedding vector to search around.
        must_tags: allowed POS tags. ``None`` allows every tag in
            ``self.vocab.tags``; a single string allows one tag; otherwise
            an iterable of tag strings.
        dist_measure: ``'euc'`` or ``'cos'``.
        forbidden_idxs__: word ids that must not be returned.
        repl_method: how the tag constraint is checked:

            * ``'lstm'``: ``self.nn_tagger`` re-tags the sentence with each
              of 64 nearest neighbours substituted at ``word_sid``.
            * ``'3gram'`` / ``'crf'``: a statistical tagger does the same.
            * ``'tagdict'``: a tag mask from ``self._gen_tag_mask`` over the
              whole vocabulary.
            * ``'bertag'``: the tag mask combined with
              ``self._gen_bert_mask``.
        words: current sentence (needed by the tagger-based methods).
        word_sid: 0-dim tensor with the position being replaced.
        raw_words: original sentence, needed only for ``'bertag'``.

    Returns:
        ``(new_word_vid, info)``. ``new_word_vid`` is ``None`` when no
        candidate qualifies. ``info`` holds ``avgd``/``mind`` (mean/min
        neighbour distance) for the tagger-based methods and is empty for
        ``'tagdict'``/``'bertag'``.

    Raises:
        NotImplementedError: for an unknown ``repl_method``.
    """
    if must_tags is None:
        must_tags = tuple(self.vocab.tags)
    if isinstance(must_tags, str):
        must_tags = (must_tags, )

    n_cands = 64
    if repl_method == 'lstm':
        # Pipeline: n_cands nearest neighbours -> NN tagger -> first
        # candidate (in find_neighbours order) with an allowed tag.
        # Assumes `words` is one sentence of shape (1, L), as passed by the
        # single_hack callers. Tile along the batch dim only: the previous
        # `repeat(64, words.size(1))` also tiled the sentence L times along
        # the length dim, feeding the tagger an L*L-token input.
        words = words.repeat(n_cands, 1)
        dists, idxs = self.embed_searcher.find_neighbours(
            changed, n_cands, dist_measure, False)
        for i, ele in enumerate(idxs):
            words[i][word_sid] = ele
        self.nn_tagger.eval()
        s_tags = self.nn_tagger(words)
        pred_tags = s_tags.argmax(-1)[:, word_sid]
        pred_tags = pred_tags.cpu().numpy().tolist()
        new_word_vid = None
        for i, ele in enumerate(pred_tags):
            if self.vocab.tags[ele] in must_tags:
                if idxs[i] not in forbidden_idxs__:
                    new_word_vid = idxs[i]
                    break
        return new_word_vid, {
            "avgd": dists.mean().item(),
            "mind": dists.min().item()
        }
    elif repl_method in ['3gram', 'crf']:
        # Pipeline: n_cands nearest neighbours -> statistical tagger ->
        # first candidate (in find_neighbours order) with an allowed tag.
        tagger = self.trigram_tagger if repl_method == '3gram' else self.crf_tagger
        word_texts = self.vocab.id2word(words)
        word_sid = word_sid.item()
        dists, idxs = self.embed_searcher.find_neighbours(
            changed, n_cands, dist_measure, False)
        cands = []
        for ele in cast_list(idxs):
            cand = word_texts.copy()
            cand[word_sid] = self.vocab.words[ele]
            cands.append(cand)
        pred_tags = tagger.tag_sents(cands)
        s_tags = [ele[word_sid][1] for ele in pred_tags]
        new_word_vid = None
        for i, ele in enumerate(s_tags):
            if ele in must_tags:
                if idxs[i] not in forbidden_idxs__:
                    new_word_vid = idxs[i]
                    break
        return new_word_vid, {
            "avgd": dists.mean().item(),
            "mind": dists.min().item()
        }
    elif repl_method in ['tagdict', 'bertag']:
        # Pipeline: distances to the whole vocabulary -> mask words with a
        # disallowed tag (and, for 'bertag', words rejected by BERT) ->
        # smallest remaining distance.
        dist = {
            'euc': euc_dist,
            'cos': cos_dist
        }[dist_measure](changed, self.parser.embed.weight)
        msk = self._gen_tag_mask(must_tags, dist.device, dist.size())
        if repl_method == 'bertag':
            bert_mask = self._gen_bert_mask(
                " ".join(self.vocab.id2word(raw_words)[1:]),
                word_sid.item() - 1, dist.device, dist.size())
            # Combine the two 0/1 masks (elementwise logical AND).
            msk = msk * bert_mask
        # 1000 acts as "infinite" distance for excluded words
        dist.masked_fill_((1 - msk).bool(), 1000.)
        for ele in forbidden_idxs__:
            dist[ele] = 1000.
        mindist = dist.min()
        if abs(mindist - 1000.) < 0.001:
            new_word_vid = None
        else:
            new_word_vid = dist.argmin()
        return new_word_vid, {}
    else:
        raise NotImplementedError