def select_by_deltalogit(self, instance, sentence):
    minl = self.config.hks_min_span_len
    maxl = self.config.hks_max_span_len
    # Compute the "vulnerability" value of each word
    words, tags, chars, arcs, rels = instance
    # sent_len = words.size(1)
    # The returned margins do not contain <ROOT>
    margins = self.compute_margin(instance)
    log(margins)
    # Count the vulnerable words in each span, then
    # select the most vulnerable span as the target.
    vul_scores = [0]
    for ele in margins:
        if ele > 0:
            vul_scores.append(10 - ele)
        else:
            vul_scores.append(0)
    spans = filter_spans(gen_spans(sentence), minl, maxl, True)
    span_vuls = list()
    # span_ratios = list()
    for span in spans:
        span_vuls.append(sum(vul_scores[span[0]:span[1] + 1]))
        # span_ratios.append(span_vuls[-1] / (span[1] + 1 - span[0]))
    pairs = []
    for tid, t in enumerate(spans):
        for s in spans:
            if check_gap(s, t, self.config.hks_span_gap):
                pairs.append((span_vuls[tid], (t, s)))
    if len(pairs) == 0:
        return None
    spairs = sorted(pairs, key=lambda x: x[0], reverse=True)
    return list(zip(*spairs))[1]
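# `check_gap` and `get_gap` are imported helpers that are not defined in this
# section. A minimal sketch of their assumed semantics, inferred from the call
# sites in this file (spans are inclusive (start, end) index pairs; the gap is
# the number of words separating two spans):

def get_gap(span_a, span_b):
    """Number of words strictly between two spans (0 if they touch or overlap)."""
    if span_a[1] < span_b[0]:
        return span_b[0] - span_a[1] - 1
    if span_b[1] < span_a[0]:
        return span_a[0] - span_b[1] - 1
    return 0


def check_gap(span_a, span_b, min_gap):
    """True iff the two spans are separated by at least `min_gap` words."""
    return get_gap(span_a, span_b) >= min_gap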
def select_tgt_span(self, instance, sentence, mode):
    minl = self.config.hks_min_span_len
    maxl = self.config.hks_max_span_len
    if mode == 'vul':
        # Compute the "vulnerability" value of each word
        words, tags, chars, arcs, rels = instance
        # sent_len = words.size(1)
        # The returned margins do not contain <ROOT>
        margins = self.compute_margin(instance)
        log(margins)
        # Count the vulnerable words in each span, then
        # select the most vulnerable span as the target.
        vul_margins = [0]
        for ele in margins:
            if 0 < ele < 1:
                vul_margins.append(100)
            elif 1 <= ele < 2:
                vul_margins.append(1)
            else:
                vul_margins.append(-1)
        spans = filter_spans(gen_spans(sentence), minl, maxl, True)
        span_vuls = list()
        # span_ratios = list()
        for span in spans:
            span_vuls.append(sum(vul_margins[span[0]:span[1] + 1]))
            # span_ratios.append(span_vuls[-1] / (span[1] + 1 - span[0]))
        pairs = []
        for tid, t in enumerate(spans):
            for s in spans:
                if check_gap(s, t, self.config.hks_span_gap):
                    pairs.append((span_vuls[tid], (t, s)))
        if len(pairs) == 0:
            return None
        spairs = sorted(pairs, key=lambda x: x[0], reverse=True)
        return list(zip(*spairs))[1]
        # NOTE: the picker-based selection below is unreachable,
        # superseded by the pair-ranking return above.
        # tgt_picker = CherryPicker(lower_is_better=False)
        # for span_i, span in enumerate(spans):
        #     tgt_picker.add(span_vuls[span_i], span)
        # if tgt_picker.size == 0:
        #     log('Target span not found')
        #     return None
        # _, _, tgt_span = tgt_picker.select_best_point()
        # return tgt_span
    elif mode == "rdm":
        spans = filter_spans(gen_spans(sentence), minl, maxl, True)
        if len(spans) == 0:
            return None
        return random.choice(spans)
    else:
        raise Exception
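# `compute_margin` is referenced above but defined elsewhere. A plausible
# sketch, assuming the parser returns arc scores of shape
# 1 x sent_len x sent_len and that the margin of a word is the score gap
# between its gold head and the best competing head (the "do not contain
# <ROOT>" comment above suggests position 0 is dropped):

def compute_margin(self, instance):
    words, tags, chars, arcs, rels = instance
    s_arc, _ = self.parser(words, tags)        # 1 x sent_len x sent_len (assumed)
    s_arc = s_arc.squeeze(0)[1:]               # drop <ROOT> as a dependent
    gold = arcs[0][1:].unsqueeze(1)
    gold_score = s_arc.gather(1, gold).squeeze(1)
    rival = s_arc.clone()
    rival.scatter_(1, gold, float('-inf'))     # mask out the gold head
    best_rival = rival.max(dim=1)[0]
    return (gold_score - best_rival).tolist()  # one margin per non-root word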
def select_src_span(self, instance, sentence, tgt_span, mode):
    minl = self.config.hks_min_span_len
    maxl = self.config.hks_max_span_len
    gap = self.config.hks_span_gap
    if mode == "cls":
        spans = filter_spans(gen_spans(sentence), minl, maxl, True)
        src_span = None
        src_picker = CherryPicker(lower_is_better=True)
        for span in spans:
            stgap = get_gap(span, tgt_span)
            if stgap >= gap:
                src_picker.add(stgap, span)
        if src_picker.size == 0:
            log('Source span not found')
            return None
        _, _, src_span = src_picker.select_best_point()
        return src_span
    elif mode == 'rdm':
        spans = filter_spans(gen_spans(sentence), minl, maxl, True)
        src_spans = []
        for span in spans:
            if check_gap(span, tgt_span, self.config.hks_span_gap):
                src_spans.append(span)
        if len(src_spans) == 0:
            return None
        return random.choice(src_spans)
    elif mode == 'nom':
        sent_len = instance[0].size(1)
        spans = filter_spans(gen_spans(sentence), minl, maxl, True)
        embed_grad = self.backward_loss(instance, ex_span_idx(tgt_span, sent_len))
        grad_norm = embed_grad.norm(dim=2)  # 1 x sent_len, <root> included
        src_picker = CherryPicker(lower_is_better=False)
        for span in spans:
            if check_gap(span, tgt_span, self.config.hks_span_gap):
                src_picker.add(grad_norm[0][span[0]:span[1] + 1].sum(), span)
        if src_picker.size == 0:
            return None
        _, _, src_span = src_picker.select_best_point()
        return src_span
    else:
        raise Exception
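# `ex_span_idx` appears to collect the indices *outside* the given span, which
# are then passed to `backward_loss`/`partial_evaluate` as `mask_idxs`. A
# one-line sketch of that assumed behaviour (<ROOT> at index 0 excluded):

def ex_span_idx(span, sent_len):
    return [i for i in range(1, sent_len) if not span[0] <= i <= span[1]]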
def init_logger(self, config):
    if config.logf == 'on':
        if config.hk_use_worker == 'on':
            worker_info = "-{}@{}".format(config.hk_num_worker, config.hk_worker_id)
        else:
            worker_info = ""
        # Include the worker info in the log file name
        log_config('{}{}'.format(config.mode, worker_info),
                   log_path=config.workspace,
                   default_target='cf')
        from dpattack.libs.luna import log
    else:
        log = print
    log('[General Settings]')
    log(config)
    log('[Hack Settings]')
    for arg in config.kwargs:
        if arg.startswith('hk'):
            log(arg, '\t', config.kwargs[arg])
    log('------------------')
def select_span_pair(self, instance, sentence):
    if self.config.hks_span_selection in \
            ["vulnom", "rdmnom", "vulcls", "rdmcls", "rdmrdm"]:
        tgt_mode = self.config.hks_span_selection[:3]
        src_mode = self.config.hks_span_selection[3:]
        tgt_span = self.select_tgt_span(instance, sentence, tgt_mode)
        if tgt_span is None:
            log('Target span not found')
            return
        src_span = self.select_src_span(instance, sentence, tgt_span, src_mode)
        if src_span is None:
            log('Source span not found')
            return
        ret = [(tgt_span, src_span)]
    elif self.config.hks_span_selection in ['jacobian1', 'jacobian2']:
        ret = self.select_by_jacobian(instance, sentence,
                                      self.config.hks_span_selection)
    elif self.config.hks_span_selection in ['random']:
        ret = self.select_by_random(instance, sentence)
    elif self.config.hks_span_selection in ['deltalogit']:
        ret = self.select_by_deltalogit(instance, sentence)
    else:
        raise Exception
    if ret is None:
        log('Span pair not found')
        return
    return ret
def meta_hack(self, instance, sentence):
    t0 = time.time()
    ram_reset("hk")
    words, tags, chars, arcs, rels = instance
    sent_len = words.size(1)
    words_text = self.vocab.id2word(words[0])
    succ = False
    if self.config.hks_color == "black":
        spans = filter_spans(gen_spans(sentence), self.config.hks_min_span_len,
                             self.config.hks_max_span_len, True)
        # In black mode the candidate pairs are the spans themselves
        total_pair = len(spans)
        hack_result = None
        for pair_id, src_span in enumerate(spans):
            log('[Chosen source] ', src_span,
                ' '.join(words_text[src_span[0]:src_span[1] + 1]))
            tgt_span_lst = [
                ele for ele in spans
                if check_gap(ele, src_span, self.config.hks_span_gap)
            ]
            if len(tgt_span_lst) == 0:
                continue
            hack_result = self.random_hack(instance, sentence, tgt_span_lst, src_span)
            if hack_result['succ'] == 1:
                succ = True
                log("Succ on source span {}.\n".format(src_span))
                break
            log("Fail on source span {}.\n".format(src_span))
        if hack_result is None:
            log("Not enough span pairs")
            return None
    elif self.config.hks_color in ['white', 'grey']:
        pairs = self.select_span_pair(instance, sentence)
        if pairs is None:
            return None
        total_pair = len(pairs)
        for pair_id, (tgt_span, src_span) in enumerate(pairs[:self.config.hks_topk_pair]):
            log('[Chosen span] ', 'tgt-', tgt_span,
                ' '.join(words_text[tgt_span[0]:tgt_span[1] + 1]),
                'src-', src_span,
                ' '.join(words_text[src_span[0]:src_span[1] + 1]))
            if self.config.hks_color == "white":
                hack_result = self.white_hack(instance, sentence, tgt_span, src_span)
            elif self.config.hks_color == "grey":
                hack_result = self.random_hack(instance, sentence, [tgt_span], src_span)
            else:
                raise Exception
            if hack_result['succ'] == 1:
                succ = True
                log("Succ on the span pair {}/{}.\n".format(pair_id, total_pair))
                break
            log("Fail on the span pair {}/{}.\n".format(pair_id, total_pair))
    hack_result['meta_trial_pair'] = pair_id + 1
    hack_result['meta_total_pair'] = total_pair
    hack_result['meta_succ_trial_pair'] = pair_id + 1 if succ else 0
    hack_result['meta_succ_total_pair'] = total_pair if succ else 0
    t1 = time.time()
    log('Sentence cost {:.1f}s'.format(t1 - t0))
    hack_result['meta_time'] = t1 - t0
    return hack_result
def __call__(self, config):
    self.init_logger(config)
    self.setup(config)
    self.blackbox_sub_idxs = [
        ele for ele in range(self.vocab.n_words) if ele not in
        [self.vocab.pad_index, self.vocab.unk_index, self.vocab.word_dict['<root>']]
    ]
    agg = Aggregator()
    for sid, (words, tags, chars, arcs, rels) in enumerate(self.loader):
        # if sid > 100:
        #     continue
        # zjh = [
        #     20, 41, 46, 117, 137, 143, 183, 198, 258, 295, 310, 350, 410, 421, 464, 485, 512,
        #     528, 544, 600, 601, 681, 702, 728, 735, 738, 762, 783, 794, 803, 805, 821, 844, 845,
        #     921, 931, 937, 939, 948, 962, 968, 975, 1019, 1044, 1068, 1069, 1096, 1104, 1121,
        #     1122, 1138, 1142, 1155, 1163, 1180, 1197, 1224, 1228, 1270, 1272, 1292, 1306, 1315,
        #     1317, 1342, 1345, 1393, 1400, 1431, 1478, 1503, 1522, 1524, 1526, 1608, 1677, 1729,
        #     1759, 1775, 1795, 1811, 1831, 1925, 1929, 1983, 1984, 2026, 2031, 2176, 2234, 2291,
        #     2296, 2318, 2327, 2330, 2342, 2343, 2355, 2360
        # ]
        # if sid + 1 not in zjh:
        #     continue
        # if sid > 600 - 1:
        #     continue
        words_text = self.vocab.id2word(words[0])
        tags_text = self.vocab.id2tag(tags[0])
        log('****** {}: \n{}\n{}'.format(sid, " ".join(words_text),
                                         " ".join(tags_text)))
        result = self.meta_hack(instance=(words, tags, chars, arcs, rels),
                                sentence=self.corpus[sid])
        if result is None:
            continue
        # yapf: disable
        agg.aggregate(
            ("iters", result['iters']),
            ("time", result['time']),
            ("succ", result['succ']),
            ('best_iter', result['best_iter']),
            ("changed", result['num_changed']),
            ("att_id", result['att_id']),
            ("meta_time", result['meta_time']),
            ("meta_trial_pair", result['meta_trial_pair']),
            ("meta_total_pair", result['meta_total_pair']),
            ("meta_succ_trial_pair", result['meta_succ_trial_pair']),
            ("meta_succ_total_pair", result['meta_succ_total_pair']),
        )
        # WARNING: SOME SENTENCES ARE NOT SHOWN!
        if result:
            log('Show result from iter {}:'.format(result['best_iter']))
            log(result['logtable'])
        log('Aggregated result: '
            'iters(avg) {:.1f}, time(avg) {:.1f}s, meta_time(avg) {:.1f}s, '
            'succ rate {:.2f}% ({}/{}), best_iter(avg) {:.1f}, best_iter(std) {:.1f}, '
            'changed(avg) {:.1f}, '
            'total pair {}, trial pair {}, '
            'succ total pair {}, succ trial pair {}, '
            'succ att id {}'.format(
                agg.mean('iters'), agg.mean('time'), agg.mean('meta_time'),
                agg.mean('succ') * 100, agg.sum('succ'), agg.size,
                agg.mean('best_iter'), agg.std('best_iter'),
                agg.mean('changed'),
                agg.sum('meta_total_pair'), agg.sum('meta_trial_pair'),
                agg.sum('meta_succ_total_pair'), agg.sum('meta_succ_trial_pair'),
                agg.aggregated(key='att_id', reduce=np.nanmean)
            ))
        log()
def white_hack(self, instance, sentence, tgt_span, src_span):
    words, tags, chars, arcs, rels = instance
    sent_len = words.size(1)
    raw_words = words.clone()
    var_words = words.clone()
    raw_metric = self.task.partial_evaluate(instance=(raw_words, tags, None, arcs, rels),
                                            mask_idxs=ex_span_idx(tgt_span, sent_len),
                                            mst=self.config.hks_mst == 'on')
    _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)])

    forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
    change_positions__ = set()
    if self.config.hks_max_change > 0.9999:
        max_change_num = int(self.config.hks_max_change)
    else:
        max_change_num = max(1, int(self.config.hks_max_change * words.size(1)))
    iter_change_num = min(max_change_num, self.config.hks_iter_change)

    picker = CherryPicker(lower_is_better=True)
    t0 = time.time()
    picker.add(raw_metric, {"num_changed": 0, "logtable": 'No modification'})
    log('iter -1, uas {:.4f}'.format(raw_metric.uas))
    succ = False
    for iter_id in range(self.config.hks_steps):
        result = self.single_hack(instance=(var_words, tags, None, arcs, rels),
                                  raw_words=raw_words,
                                  raw_metric=raw_metric,
                                  raw_arcs=raw_arcs,
                                  src_span=src_span,
                                  tgt_span=tgt_span,
                                  iter_id=iter_id,
                                  forbidden_idxs__=forbidden_idxs__,
                                  change_positions__=change_positions__,
                                  max_change_num=max_change_num,
                                  iter_change_num=iter_change_num)
        if result['code'] == 200:
            var_words = result['words']
            picker.add(result['attack_metric'], {
                'logtable': result['logtable'],
                "num_changed": len(change_positions__)
            })
            if result['attack_metric'].uas < raw_metric.uas - 0.00001:
                succ = True
                log('Succeed in step {}'.format(iter_id))
                break
        elif result['code'] == 404:
            log('FAILED')
            break
    t1 = time.time()

    best_iter, best_attack_metric, best_info = picker.select_best_point()
    return defaultdict(
        lambda: -1, {
            "succ": 1 if succ else 0,
            "raw_metric": raw_metric,
            "attack_metric": best_attack_metric,
            "iters": iter_id,
            "best_iter": best_iter,
            "num_changed": best_info['num_changed'],
            "time": t1 - t0,
            "logtable": best_info['logtable']
        })
def single_hack(self,
                instance,
                mask_idxs,
                iter_id,
                raw_words,
                raw_metric,
                raw_arcs,
                forbidden_idxs__: list,
                change_positions__: set,
                max_change_num,
                verbose=False):
    words, tags, chars, arcs, rels = instance
    sent_len = words.size(1)

    # Backward loss
    embed_grad = self.backward_loss(instance=instance, mask_idxs=mask_idxs)
    grad_norm = embed_grad.norm(dim=2)

    position_mask = [False for _ in range(words.size(1))]
    # Mask some positions
    for i in range(sent_len):
        if i not in mask_idxs or rels[0][i].item() == self.vocab.rel_dict['punct']:
            position_mask[i] = True
    # Check if the number of changed words exceeds the max value
    if len(change_positions__) >= max_change_num:
        for i in range(sent_len):
            if i not in change_positions__:
                position_mask[i] = True
    for i in range(sent_len):
        if position_mask[i]:
            grad_norm[0][i] = -(grad_norm[0][i] + 1000)
    # print(grad_norm)

    # Select a word and forbid itself
    word_sids = []  # type: list[torch.Tensor]
    word_vids = []  # type: list[torch.Tensor]
    new_word_vids = []  # type: list[torch.Tensor]

    _, topk_idxs = grad_norm[0].topk(max_change_num)
    for ele in topk_idxs:
        word_sids.append(ele)

    for word_sid in word_sids:
        word_grad = embed_grad[0][word_sid]
        word_vid = words[0][word_sid]
        emb_to_rpl = self.parser.embed.weight[word_vid]

        # Find a word to change
        delta = word_grad / torch.norm(word_grad) * self.config.hko_step_size
        changed = emb_to_rpl - delta

        # must_tag = self.vocab.tags[tags[0][word_sid].item()]
        must_tag = None
        # must_tag = HACK_TAGS['njr']
        # must_tag = 'CD'
        forbidden_idxs__.append(word_vid)
        change_positions__.add(word_sid.item())
        new_word_vid, repl_info = self.find_replacement(
            changed,
            must_tag,
            dist_measure=self.config.hko_dist_measure,
            forbidden_idxs__=forbidden_idxs__,
            repl_method='tagdict',
            words=words,
            word_sid=word_sid)

        word_vids.append(word_vid)
        new_word_vids.append(new_word_vid)

    new_words = words.clone()
    for i in range(len(word_vids)):
        if new_word_vids[i] is not None:
            new_words[0][word_sids[i]] = new_word_vids[i]

    """
    Evaluating the result
    """
    # print('START EVALUATING')
    # print([self.vocab.words[ele] for ele in self.forbidden_idxs__])
    metric = self.task.partial_evaluate(instance=(new_words, tags, None, arcs, rels),
                                        mask_idxs=mask_idxs)

    def _gen_log_table():
        new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
        raw_words_text = [self.vocab.words[i.item()] for i in raw_words[0]]
        tags_text = [self.vocab.tags[i.item()] for i in tags[0]]
        att_tags, att_arcs, att_rels = self.task.predict([(new_words, tags, None)])
        table = []
        for i in range(sent_len):
            gold_arc = int(arcs[0][i])
            raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
            att_arc = 0 if i == 0 else att_arcs[0][i - 1]
            mask_symbol = '&' if att_arc in mask_idxs or i in mask_idxs else ""
            table.append([
                "{}{}".format(i, "@" if i in mask_idxs else ""),
                raw_words_text[i],
                '>{}'.format(new_words_text[i])
                if raw_words_text[i] != new_words_text[i] else "*",
                tags_text[i],
                gold_arc,
                raw_arc,
                '>{}{}'.format(att_arc, mask_symbol) if att_arc != raw_arc else '*',
                grad_norm[0][i].item()
            ])
        return table

    if verbose:
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        print('Iter {}'.format(iter_id))
        print(tabulate(_gen_log_table(), floatfmt=('.6f')))
        print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

    rpl_detail = ""
    for i in range(len(word_sids)):
        if new_word_vids[i] is not None:
            rpl_detail += "{}:{}->{} ".format(
                self.vocab.words[raw_words[0][word_sids[i]].item()],
                self.vocab.words[word_vids[i].item()],
                self.vocab.words[new_word_vids[i].item()])
    log(
        "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
        "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'], repl_info['avgd'])
        if 'mind' in repl_info else '', rpl_detail)

    if metric.uas >= raw_metric.uas - .00001:
        info = 'Nothing'
    else:
        info = tabulate(_gen_log_table(), floatfmt='.6f')
    return {
        'code': 200,
        'words': new_words,
        'attack_metric': metric,
        'logtable': info,
    }
def single_hack(self,
                words,
                tags,
                arcs,
                rels,
                raw_words,
                raw_metric,
                raw_arcs,
                forbidden_idxs__,
                change_positions__,
                orphans__,
                verbose=False,
                max_change_num=1,
                iter_change_num=1,
                iter_id=-1):
    sent_len = words.size(1)

    """
    Loss back-propagation
    """
    embed_grad = self.backward_loss(words, tags, arcs, rels,
                                    loss_based_on=self.config.hkw_loss_based_on)
    # Sometimes the loss/grad will be zero,
    # especially when applying pgd_freq > 1 to small sentences:
    # e.g., the uas of the projected version may be 83.33%,
    # while for the unprojected version the loss is 0.
    if torch.sum(embed_grad) == 0.0:
        return {"code": 404, "info": "Loss is zero."}

    """
    Select and change a word
    """
    grad_norm = embed_grad.norm(dim=2)

    position_mask = [False for _ in range(words.size(1))]
    # Mask some positions by POS & <UNK>
    for i in range(sent_len):
        if self.vocab.tags[tags[0][i]] not in HACK_TAGS[self.config.hkw_tag_type]:
            position_mask[i] = True
    # Mask the orphans
    for i in orphans__:
        position_mask[i] = True
    # Check if the number of changed words exceeds the max value
    if len(change_positions__) >= max_change_num:
        for i in range(sent_len):
            if i not in change_positions__:
                position_mask[i] = True

    if all(position_mask):
        return {
            "code": 404,
            "info": "Constrained by tags, no valid word to replace."
        }

    for i in range(sent_len):
        if position_mask[i]:
            grad_norm[0][i] = -(grad_norm[0][i] + 1000)

    # Select a word and forbid itself
    word_sids = []  # type: list[torch.Tensor]
    word_vids = []  # type: list[torch.Tensor]
    new_word_vids = []  # type: list[torch.Tensor]

    _, topk_idxs = grad_norm[0].sort(descending=True)
    selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                  num_to_select=iter_change_num,
                                  selected=change_positions__,
                                  max_num=max_change_num)
    # The position mask ensures that at least one word is legal,
    # but the following ones may not be allowed
    selected_words = [ele for ele in selected_words if not position_mask[ele]]
    word_sids = torch.tensor(selected_words)

    for word_sid in word_sids:
        word_grad = embed_grad[0][word_sid]
        word_vid = words[0][word_sid]
        emb_to_rpl = self.parser.embed.weight[word_vid]
        forbidden_idxs__.append(word_vid.item())
        change_positions__.add(word_sid.item())
        # print(self.change_positions__)

        # Find a word to change with a dynamic step size.
        # Note that it is possible that no found word is valid, e.g.,
        # all neighbours have different tags.
        delta = word_grad / torch.norm(word_grad) * self.config.hkw_step_size
        changed = emb_to_rpl - delta

        must_tag = self.vocab.tags[tags[0][word_sid].item()]
        new_word_vid, repl_info = self.find_replacement(
            changed,
            must_tag,
            dist_measure=self.config.hkw_dist_measure,
            forbidden_idxs__=forbidden_idxs__,
            repl_method=self.config.hkw_repl_method,
            words=words,
            word_sid=word_sid,
            raw_words=raw_words)

        word_vids.append(word_vid)
        new_word_vids.append(new_word_vid)

    new_words = words.clone()
    exist_change = False
    for i in range(len(word_vids)):
        if new_word_vids[i] is not None:
            new_words[0][word_sids[i]] = new_word_vids[i]
            exist_change = True

    if not exist_change:
        # if self.config.hkw_selection == 'orphan':
        #     orphans__.add(word_sids[0])
        #     log('iter {}, Add word {}\'s location to orphans.'.format(
        #         iter_id,
        #         self.vocab.words[raw_words[0][word_sid].item()]))
        #     return {
        #         'code': '200',
        #         'words': words,
        #         'attack_metric': 100.,
        #         'logtable': 'This will be never selected'
        #     }
        log('Attack failed.')
        return {
            'code': 404,
            'info': 'Neighbours of all selected words have different tags.'
        }

    # if new_word_vid is None:
    #     log('Attack failed.')
    #     return {'code': 404,
    #             'info': 'Neighbours of the selected words have different tags.'}
    # new_words = words.clone()
    # new_words[0][word_sid] = new_word_vid

    """
    Evaluating the result
    """
    # print('START EVALUATING')
    # print([self.vocab.words[ele] for ele in self.forbidden_idxs__])
    new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
    # print(new_words_text)
    loss, metric = self.task.evaluate([(new_words, tags, None, arcs, rels)],
                                      mst=self.config.hkw_mst == 'on')

    def _gen_log_table():
        new_words_text = [self.vocab.words[i.item()] for i in new_words[0]]
        raw_words_text = [self.vocab.words[i.item()] for i in raw_words[0]]
        tags_text = [self.vocab.tags[i.item()] for i in tags[0]]
        _, att_arcs, _ = self.task.predict([(new_words, tags, None)],
                                           mst=self.config.hkw_mst == 'on')
        table = []
        for i in range(sent_len):
            gold_arc = int(arcs[0][i])
            raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
            att_arc = 0 if i == 0 else att_arcs[0][i - 1]
            relevant_mask = '&' if \
                raw_words[0][att_arc] != new_words[0][att_arc] or \
                raw_words_text[i] != new_words_text[i] else ""
            table.append([
                i,
                raw_words_text[i],
                '>{}'.format(new_words_text[i])
                if raw_words_text[i] != new_words_text[i] else "*",
                tags_text[i],
                gold_arc,
                raw_arc,
                '>{}{}'.format(att_arc, relevant_mask) if att_arc != raw_arc else '*',
                grad_norm[0][i].item()
            ])
        return table

    if verbose:
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        print('Iter {}'.format(iter_id))
        print(tabulate(_gen_log_table(), floatfmt=('.6f')))
        print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

    rpl_detail = ""
    for i in range(len(word_sids)):
        if new_word_vids[i] is not None:
            rpl_detail += "{}:{}->{} ".format(
                self.vocab.words[raw_words[0][word_sids[i]].item()],
                self.vocab.words[word_vids[i].item()],
                self.vocab.words[new_word_vids[i].item()])
    log(
        "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
        "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'], repl_info['avgd'])
        if 'mind' in repl_info else '', rpl_detail)

    if metric.uas >= raw_metric.uas - .00001:
        logtable = 'Nothing'
    else:
        logtable = tabulate(_gen_log_table(), floatfmt='.6f')
    return {
        'code': 200,
        'words': new_words,
        'attack_metric': metric,
        'logtable': logtable,
        # "forbidden_idxs__": forbidden_idxs__,
        # "change_positions__": change_positions__,
    }
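# `elder_select` and `young_select` are selection policies imported from
# elsewhere. Sketches of their assumed contracts, inferred from the call
# sites: `elder_select` keeps re-picking already-changed positions ("elders")
# and only admits new positions while the distinct total stays within
# `max_num`; `young_select` instead takes the current top positions and also
# reports previously changed positions that fell out ("ex_words") so the
# caller can revert them.

def elder_select(ordered_idxs, num_to_select, selected, max_num):
    picked, new_cnt = [], 0
    for idx in ordered_idxs:
        if len(picked) == num_to_select:
            break
        if idx in selected:
            picked.append(idx)
        elif len(selected) + new_cnt < max_num:
            picked.append(idx)
            new_cnt += 1
    return picked


def young_select(ordered_idxs, num_to_select, selected, max_num):
    picked = ordered_idxs[:min(num_to_select, max_num)]
    ex_words = [idx for idx in selected if idx not in picked]
    return picked, ex_words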
def single_hack(self,
                words,
                tags,
                chars,
                arcs,
                rels,
                raw_chars,
                raw_metric,
                raw_arcs,
                forbidden_idxs__,
                change_positions__,
                verbose=False,
                max_change_num=1,
                iter_change_num=1,
                iter_id=-1):
    sent_len = words.size(1)

    """
    Loss back-propagation
    """
    char_grads, word_grads = self.backward_loss(words, chars, arcs, rels)

    """
    Select and change a word
    """
    word_grad_norm = word_grads.norm(dim=-1)  # 1 x length
    # 1 x length x max_char_length
    char_grad_norm = char_grads.norm(dim=-1)
    woca_grad_norm = char_grad_norm.max(dim=-1)[0]  # 1 x length

    position_mask = [False for _ in range(words.size(1))]
    # Mask some positions by POS & <UNK>
    for i in range(sent_len):
        if rels[0][i].item() == self.vocab.rel_dict['punct']:
            position_mask[i] = True
    # Check if the number of changed words exceeds the max value
    if len(change_positions__) >= max_change_num:
        for i in range(sent_len):
            if i not in change_positions__:
                position_mask[i] = True
    if all(position_mask):
        return {"code": 404,
                "info": "Constrained by tags, no valid word to replace."}
    for i in range(sent_len):
        if position_mask[i]:
            word_grad_norm[0][i] = -(word_grad_norm[0][i] + 1000)
            char_grad_norm[0][i] = -(char_grad_norm[0][i] + 1000)

    # Select a word and forbid itself
    word_sids = []
    char_wids = []
    char_vids = []
    new_char_vids = []

    if self.config.hkc_selection == 'elder':
        _, topk_idxs = word_grad_norm[0].sort(descending=True)
        selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                      num_to_select=iter_change_num,
                                      selected=change_positions__,
                                      max_num=max_change_num)
        selected_chars = []
        for ele in selected_words:
            if ele in change_positions__:
                selected_chars.append(change_positions__[ele])
            else:
                wcid = char_grad_norm[0][ele].argmax()
                selected_chars.append(wcid)
        word_sids = torch.tensor(selected_words)
        char_wids = torch.tensor(selected_chars)
    elif self.config.hkc_selection == 'young':
        _, topk_idxs = word_grad_norm[0].sort(descending=True)
        selected_words, ex_words = young_select(ordered_idxs=cast_list(topk_idxs),
                                                num_to_select=iter_change_num,
                                                selected=change_positions__,
                                                max_num=max_change_num)
        for ele in ex_words:
            change_positions__.pop(ele)
            forbidden_idxs__.pop(ele)
            chars[0][ele] = raw_chars[0][ele]
            log('Drop elder replacement', self.vocab.words[words[0][ele].item()])
        selected_chars = []
        for ele in selected_words:
            if ele in change_positions__:
                selected_chars.append(change_positions__[ele])
            else:
                wcid = char_grad_norm[0][ele].argmax()
                selected_chars.append(wcid)
        word_sids = torch.tensor(selected_words)
        char_wids = torch.tensor(selected_chars)
    else:
        raise Exception

    for word_sid, char_wid in zip(word_sids, char_wids):
        char_grad = char_grads[0][word_sid][char_wid]
        char_vid = chars[0][word_sid][char_wid]
        emb_to_rpl = self.parser.char_lstm.embed.weight[char_vid]
        forbidden_idxs__[word_sid.item()].append(char_vid.item())
        change_positions__[word_sid.item()] = char_wid.item()

        # Find a char to change with a dynamic step size.
        # Note that it is possible that no valid candidate is found, e.g.,
        # all neighbours are forbidden.
        delta = char_grad / torch.norm(char_grad) * self.config.hkc_step_size
        changed = emb_to_rpl - delta

        dist = {'euc': euc_dist, 'cos': cos_dist}[self.config.hkc_dist_measure](
            changed, self.parser.char_lstm.embed.weight)
        vals, idxs = dist.sort()
        new_char_vid = None
        for ele in idxs:
            if ele.item() not in forbidden_idxs__[word_sid.item()] and \
                    ele.item() not in [self.vocab.pad_index, self.vocab.unk_index]:
                new_char_vid = ele
                break

        char_vids.append(char_vid)
        new_char_vids.append(new_char_vid)

    new_chars = chars.clone()
    for i in range(len(word_sids)):
        if new_char_vids[i] is not None:
            new_chars[0][word_sids[i]][char_wids[i]] = new_char_vids[i]

    # log(dict(forbidden_idxs__))
    # log(change_positions__)

    """
    Evaluating the result
    """
    # print('START EVALUATING')
    # print([self.vocab.words[ele] for ele in self.forbidden_idxs__])
    new_chars_text = [self.vocab.id2char(ele) for ele in new_chars[0]]
    # print(new_chars_text)
    loss, metric = self.task.evaluate([(words, tags, new_chars, arcs, rels)],
                                      mst=self.config.hkc_mst == 'on')

    def _gen_log_table():
        new_words_text = [self.vocab.id2char(ele) for ele in new_chars[0]]
        raw_words_text = [self.vocab.id2char(ele) for ele in raw_chars[0]]
        tags_text = [self.vocab.tags[ele.item()] for ele in tags[0]]
        _, att_arcs, _ = self.task.predict([(words, tags, new_chars)],
                                           mst=self.config.hkc_mst == 'on')
        table = []
        for i in range(sent_len):
            gold_arc = int(arcs[0][i])
            raw_arc = 0 if i == 0 else raw_arcs[0][i - 1]
            att_arc = 0 if i == 0 else att_arcs[0][i - 1]
            relevant_mask = '&' if \
                raw_words_text[att_arc] != new_words_text[att_arc] \
                or raw_words_text[i] != new_words_text[i] \
                else ""
            table.append([
                i,
                raw_words_text[i],
                '>{}'.format(new_words_text[i])
                if raw_words_text[i] != new_words_text[i] else "*",
                tags_text[i],
                gold_arc,
                raw_arc,
                '>{}{}'.format(att_arc, relevant_mask) if att_arc != raw_arc else '*',
                word_grad_norm[0][i].item()
            ])
        return table

    if verbose:
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        print('Iter {}'.format(iter_id))
        print(tabulate(_gen_log_table(), floatfmt=('.6f')))
        print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

    rpl_detail = ""
    for i in range(len(word_sids)):
        rpl_detail += "{}:{}->{} ".format(
            self.vocab.id2char(raw_chars[0, word_sids[i]]),
            self.vocab.id2char(chars[0, word_sids[i]]),
            self.vocab.id2char(new_chars[0, word_sids[i]]))
    log("iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
        # "mind {:6.3f}, avgd {:6.3f}, ".format(
        #     repl_info['mind'], repl_info['avgd']) if 'mind' in repl_info else '',
        rpl_detail)

    if metric.uas >= raw_metric.uas - .00001:
        logtable = 'Nothing'
    else:
        logtable = tabulate(_gen_log_table(), floatfmt='.6f')
    return {
        'code': 200,
        'chars': new_chars,
        'attack_metric': metric,
        'logtable': logtable,
        # "forbidden_idxs__": forbidden_idxs__,
        # "change_positions__": change_positions__,
    }
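# `euc_dist` and `cos_dist` are the distance helpers used above. A minimal
# sketch consistent with how they are called (a vector against every row of
# an embedding matrix; lower means closer, so the cosine variant is written
# as 1 - similarity):
import torch


def euc_dist(x, ys):
    """Squared Euclidean distance from x (dim,) to each row of ys (n, dim)."""
    return (ys - x).pow(2).sum(dim=-1)


def cos_dist(x, ys):
    """Cosine distance from x (dim,) to each row of ys (n, dim)."""
    return 1 - torch.nn.functional.cosine_similarity(
        ys, x.unsqueeze(0).expand_as(ys), dim=-1)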
def hack(self, instance):
    words, tags, chars, arcs, rels = instance
    _, raw_metric = self.task.evaluate([(words, tags, chars, arcs, rels)],
                                       mst=self.config.hkc_mst == 'on')
    _, raw_arcs, _ = self.task.predict([(words, tags, chars)],
                                       mst=self.config.hkc_mst == 'on')
    # char_grads, word_grad = self.backward_loss(words, chars, arcs, rels)

    forbidden_idxs__ = defaultdict(lambda: deque(maxlen=5))  # word_sid -> deque()
    change_positions__ = dict()  # word_sid -> char_wid
    if self.config.hkc_max_change > 0.9999:
        max_change_num = int(self.config.hkc_max_change)
    else:
        max_change_num = max(1, int(self.config.hkc_max_change * words.size(1)))
    iter_change_num = min(max_change_num, self.config.hkc_iter_change)

    raw_chars = chars.clone()
    var_chars = chars.clone()

    # HIGHLIGHT: ITERATION
    t0 = time.time()
    picker = CherryPicker(lower_is_better=True,
                          compare_fn=lambda m1, m2: m1.uas - m2.uas)
    # iter 0 -> raw
    picker.add(raw_metric, {"num_changed": 0, "logtable": 'No modification'})
    for iter_id in range(1, self.config.hkc_steps):
        result = self.single_hack(
            words, tags, var_chars, arcs, rels,
            raw_chars=raw_chars,
            raw_metric=raw_metric,
            raw_arcs=raw_arcs,
            verbose=False,
            max_change_num=max_change_num,
            iter_change_num=iter_change_num,
            iter_id=iter_id,
            forbidden_idxs__=forbidden_idxs__,
            change_positions__=change_positions__,
        )
        # Fail
        if result['code'] == 404:
            log('Stop in step {}, info: {}'.format(iter_id, result['info']))
            break
        # Success
        if result['code'] == 200:
            picker.add(result['attack_metric'], {
                # "words": result['new_words'],
                "logtable": result['logtable'],
                "num_changed": len(change_positions__)
            })
            if result['attack_metric'].uas < raw_metric.uas - self.config.hkw_eps:
                log('Succeed in step {}'.format(iter_id))
                break
            var_chars = result['chars']
    t1 = time.time()

    best_iter, best_attack_metric, best_info = picker.select_best_point()
    return {
        "raw_metric": raw_metric,
        "attack_metric": best_attack_metric,
        "iters": iter_id,
        "best_iter": best_iter,
        "num_changed": best_info['num_changed'],
        "time": t1 - t0,
        "logtable": best_info['logtable']
    }
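# `CherryPicker` comes from the luna library. A minimal sketch of the API
# used throughout this file (record a metric per iteration, then recover the
# best one; `compare_fn(m1, m2)` is assumed to behave like a comparator that
# is negative when m1 is "smaller"):

class CherryPicker:
    def __init__(self, lower_is_better, compare_fn=None):
        self.lower_is_better = lower_is_better
        self.compare_fn = compare_fn or (lambda m1, m2: m1 - m2)
        self._metrics, self._infos = [], []

    @property
    def size(self):
        return len(self._metrics)

    def add(self, metric, info):
        self._metrics.append(metric)
        self._infos.append(info)

    def select_best_point(self):
        assert self.size > 0, 'nothing to pick from'
        best = 0
        for i in range(1, self.size):
            diff = self.compare_fn(self._metrics[i], self._metrics[best])
            if (diff < 0) if self.lower_is_better else (diff > 0):
                best = i
        return best, self._metrics[best], self._infos[best]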
def __call__(self, config):
    _, loader = self.pre_attack(config)
    self.parser.register_backward_hook(self.extract_embed_grad)

    # log_config('whitelog.txt',
    #            log_path=config.workspace,
    #            default_target='cf')

    train_corpus = Corpus.load(config.ftrain)
    self.tag_filter = generate_tag_filter(train_corpus, self.vocab)
    # corpus = Corpus.load(config.fdata)
    # dataset = TextDataset(self.vocab.numericalize(corpus, True))
    # # set the data loader
    # loader = DataLoader(dataset=dataset,
    #                     collate_fn=collate_fn)

    # def embed_hook(module, grad_in, grad_out):
    #     self.vals["embed_grad"] = grad_out[0]
    # dpattack.pretrained.register_backward_hook(embed_hook)

    raw_metrics = Metric()
    attack_metrics = Metric()

    log('dist measure', config.hk_dist_measure)

    # batch size == 1
    for sid, (words, tags, chars, arcs, rels) in enumerate(loader):
        # if sid > 10:
        #     break

        raw_words = words.clone()
        words_text = self.get_seqs_name(words)
        tags_text = self.get_tags_name(tags)

        log('****** {}: \n\t{}\n\t{}'.format(sid, " ".join(words_text),
                                             " ".join(tags_text)))

        self.vals['forbidden'] = [self.vocab.unk_index, self.vocab.pad_index]

        for pgdid in range(100):
            result = self.single_hack(words, tags, arcs, rels,
                                      dist_measure=config.hk_dist_measure,
                                      raw_words=raw_words)
            if result['code'] == 200:
                raw_metrics += result['raw_metric']
                attack_metrics += result['attack_metric']
                log('attack successfully at step {}'.format(pgdid))
                break
            elif result['code'] == 404:
                raw_metrics += result['raw_metric']
                attack_metrics += result['raw_metric']
                log('attack failed at step {}'.format(pgdid))
                break
            elif result['code'] == 300:
                if pgdid == 99:
                    raw_metrics += result['raw_metric']
                    attack_metrics += result['raw_metric']
                    log('attack failed at step {}'.format(pgdid))
                else:
                    words = result['words']
        log()

    log('Aggregated result: {} --> {}'.format(raw_metrics, attack_metrics),
        target='cf')
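# `generate_tag_filter` is defined elsewhere; `single_hack` below indexes its
# result as tag -> LongTensor of word ids suitable for `index_fill_`. A sketch
# under the assumption that the corpus exposes parallel `words`/`tags`
# sequences (the field names are guesses):

def generate_tag_filter(corpus, vocab):
    from collections import defaultdict
    tag_to_ids = defaultdict(set)
    for seq_words, seq_tags in zip(corpus.words, corpus.tags):
        for word, tag in zip(seq_words, seq_tags):
            wid = vocab.word_dict.get(word.lower())
            if wid is not None:
                tag_to_ids[tag].add(wid)
    return {tag: torch.tensor(sorted(ids)) for tag, ids in tag_to_ids.items()}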
def single_hack(self,
                words,
                tags,
                arcs,
                rels,
                raw_words,
                target_tags=['NN', 'JJ'],
                dist_measure='euc',
                verbose=False):
    vocab = self.vocab
    parser = self.parser
    loss_fn = self.task.criterion
    embed = parser.embed.weight

    raw_loss, raw_metric = self.task.evaluate([(raw_words, tags, arcs, rels)])

    # backward the loss
    parser.zero_grad()
    mask = words.ne(vocab.pad_index)
    mask[:, 0] = 0
    s_arc, s_rel = parser(words, tags)
    s_arc, s_rel = s_arc[mask], s_rel[mask]
    gold_arcs, gold_rels = arcs[mask], rels[mask]
    loss = loss_fn(s_arc, gold_arcs)
    loss.backward()

    grad_norm = self.vals['embed_grad'].norm(dim=2)

    # Select a word to attack by its POS and norm
    exist_target_tag = False
    for i in range(tags.size(1)):
        if vocab.tags[tags[0][i]] not in target_tags:
            grad_norm[0][i] -= 99999.
        else:
            exist_target_tag = True
    if not exist_target_tag:
        return {"code": 404, 'raw_metric': raw_metric}

    word_sid = grad_norm[0].argmax()
    word_vid = words[0][word_sid]
    max_grad = self.vals['embed_grad'][0][word_sid]

    # Forbid the word itself to be selected
    self.vals['forbidden'].append(word_vid.item())

    # Find a word to change
    changed = embed[word_vid] - max_grad * 1
    dist = {
        'euc': lambda: (changed - embed).pow(2).sum(dim=1),
        # Use 1 - similarity so that argmin selects the closest word
        # under both measures.
        'cos': lambda: 1 - torch.nn.functional.cosine_similarity(
            embed, changed.repeat(embed.size(0), 1), dim=1)
    }[dist_measure]()

    must_tag = vocab.tags[tags[0][word_sid].item()]
    legal_tag_index = self.tag_filter[must_tag].to(dist.device)
    legal_tag_mask = dist.new_zeros(dist.size()) \
        .index_fill_(0, legal_tag_index, 1.).bool()
    dist.masked_fill_(~legal_tag_mask, 99999.)
    for ele in self.vals['forbidden']:
        dist[ele] = 99999.
    word_vid_to_rpl = dist.argmin()

    # A forbidden word may be chosen if all words are forbidden (99999.)
    if word_vid_to_rpl.item() in self.vals['forbidden']:
        log('Attack failed.')
        return {'code': 404, 'raw_metric': raw_metric}

    # =====================
    # Evaluating the result
    # =====================
    repl_words = words.clone()
    repl_words[0][word_sid] = word_vid_to_rpl

    repl_words_text = [vocab.words[i.item()] for i in repl_words[0]]
    raw_words_text = [vocab.words[i.item()] for i in raw_words[0]]
    tags_text = [vocab.tags[i.item()] for i in tags[0]]

    if verbose:
        print('After Attacking: \n\t{}\n\t{}'.format(" ".join(repl_words_text),
                                                     " ".join(tags_text)))
    pred_tags, pred_arcs, pred_rels = self.task.predict([(repl_words, tags)])
    loss, metric = self.task.evaluate([(repl_words, tags, arcs, rels)])

    table = []
    for i in range(words.size(1)):
        gold_arc = int(arcs[0][i])
        pred_arc = 0 if i == 0 else pred_arcs[0][i - 1]
        table.append([
            i,
            repl_words_text[i],
            raw_words_text[i] if raw_words_text[i] != repl_words_text[i] else "*",
            tags_text[i],
            gold_arc,
            pred_arc if pred_arc != gold_arc else '*',
            grad_norm[0][i].item()
        ])
    if verbose:
        print('{} --> {}'.format(vocab.words[word_vid.item()],
                                 vocab.words[word_vid_to_rpl.item()]))
        print(tabulate(table, floatfmt=('.6f')))
        print(metric)
        print('**************************')

    if metric.uas > raw_metric.uas - 0.1:
        return {
            'code': 300,
            'words': repl_words,
            'raw_metric': raw_metric,
            'attack_metric': metric
        }
    else:
        log(tabulate(table, floatfmt=('.6f')))
        log('Result {} --> {}'.format(raw_metric.uas, metric.uas), target='cf')
        return {
            'code': 200,
            'words': repl_words,
            'raw_metric': raw_metric,
            'attack_metric': metric
        }
def __call__(self, config):
    if config.logf == 'on':
        log_config('hackoutside', log_path=config.workspace, default_target='cf')
        from dpattack.libs.luna import log
    else:
        log = print

    log('[General Settings]')
    log(config)
    log('[Hack Settings]')
    for arg in config.kwargs:
        if arg.startswith('hks'):
            log(arg, '\t', config.kwargs[arg])
    log('------------------')

    self.setup(config)

    raw_metrics = ParserMetric()
    attack_metrics = ParserMetric()
    agg = Aggregator()
    for sid, (words, tags, chars, arcs, rels) in enumerate(self.loader):
        # if sid > 100:
        #     continue
        words_text = self.vocab.id2word(words[0])
        tags_text = self.vocab.id2tag(tags[0])
        log('****** {}: \n{}\n{}'.format(sid, " ".join(words_text),
                                         " ".join(tags_text)))

        result = self.hack(instance=(words, tags, chars, arcs, rels),
                           sentence=self.corpus[sid])
        if result is None:
            continue
        else:
            raw_metrics += result['raw_metric']
            attack_metrics += result['attack_metric']
            agg.aggregate(
                ("iters", result['iters']),
                ("time", result['time']),
                ("fail", abs(result['attack_metric'].uas -
                             result['raw_metric'].uas) < 1e-4),
                ('best_iter', result['best_iter']),
                ("changed", result['num_changed']))

        # WARNING: SOME SENTENCES ARE NOT SHOWN!
        if result:
            log('Show result from iter {}:'.format(result['best_iter']))
            log(result['logtable'])

        log('Aggregated result: {} --> {}, '
            'iters(avg) {:.1f}, time(avg) {:.1f}s, '
            'fail rate {:.2f}, best_iter(avg) {:.1f}, best_iter(std) {:.1f}, '
            'changed(avg) {:.1f}'.format(raw_metrics, attack_metrics,
                                         agg.mean('iters'), agg.mean('time'),
                                         agg.mean('fail'), agg.mean('best_iter'),
                                         agg.std('best_iter'), agg.mean('changed')))
        log()
def single_hack(self,
                instance,
                src_span,
                tgt_span,
                iter_id,
                raw_words,
                raw_metric,
                raw_arcs,
                forbidden_idxs__: list,
                change_positions__: set,
                max_change_num,
                iter_change_num,
                verbose=False):
    # yapf: enable
    words, tags, chars, arcs, rels = instance
    sent_len = words.size(1)

    # Backward loss
    embed_grad = self.backward_loss(instance=instance,
                                    mask_idxs=ex_span_idx(tgt_span, sent_len),
                                    verbose=True)
    grad_norm = embed_grad.norm(dim=2)
    if self.config.hks_word_random == "on":
        grad_norm = torch.rand(grad_norm.size(), device=grad_norm.device)

    position_mask = [False for _ in range(words.size(1))]
    # Mask some positions
    for i in range(sent_len):
        if rels[0][i].item() == self.vocab.rel_dict['punct'] \
                or not src_span[0] <= i <= src_span[1]:
            position_mask[i] = True
    # Check if the number of changed words exceeds the max value
    if len(change_positions__) >= max_change_num:
        for i in range(sent_len):
            if i not in change_positions__:
                position_mask[i] = True
    for i in range(sent_len):
        if position_mask[i]:
            grad_norm[0][i] = -(grad_norm[0][i] + 1000)
    # print(grad_norm)

    # Select a word and forbid itself
    word_sids = []  # type: list[torch.Tensor]
    word_vids = []  # type: list[torch.Tensor]
    new_word_vids = []  # type: list[torch.Tensor]

    # _, topk_idxs = grad_norm[0].topk(min(max_change_num, len(src_idxs)))
    # for ele in topk_idxs:
    #     word_sids.append(ele)
    _, topk_idxs = grad_norm[0].sort(descending=True)
    selected_words = elder_select(ordered_idxs=cast_list(topk_idxs),
                                  num_to_select=iter_change_num,
                                  selected=change_positions__,
                                  max_num=max_change_num)
    # The position mask ensures that at least one word is legal,
    # but the following ones may not be allowed
    selected_words = [ele for ele in selected_words if not position_mask[ele]]
    word_sids = torch.tensor(selected_words)

    for word_sid in word_sids:
        word_vid = words[0][word_sid]
        emb_to_rpl = self.parser.embed.weight[word_vid]
        if self.config.hks_step_size > 0:
            word_grad = embed_grad[0][word_sid]
            delta = word_grad / torch.norm(word_grad) * self.config.hks_step_size
            changed = emb_to_rpl - delta

            tag_type = self.config.hks_constraint
            if tag_type == 'any':
                must_tag = None
            elif tag_type == 'same':
                must_tag = self.vocab.tags[tags[0][word_sid].item()]
            elif re.match("[njvri]+", tag_type):
                must_tag = HACK_TAGS[tag_type]
            else:
                raise Exception
            forbidden_idxs__.append(word_vid)
            change_positions__.add(word_sid.item())
            new_word_vid, repl_info = self.find_replacement(
                changed,
                must_tag,
                dist_measure=self.config.hks_dist_measure,
                forbidden_idxs__=forbidden_idxs__,
                repl_method='tagdict',
                words=words,
                word_sid=word_sid)
        else:
            # randint is inclusive on both ends, so draw from [0, n_words - 1]
            new_word_vid = random.randint(0, self.vocab.n_words - 1)
            while new_word_vid in [self.vocab.pad_index,
                                   self.vocab.word_dict["<root>"]]:
                new_word_vid = random.randint(0, self.vocab.n_words - 1)
            new_word_vid = torch.tensor(new_word_vid, device=words.device)
            repl_info = {}
            change_positions__.add(word_sid.item())

        word_vids.append(word_vid)
        new_word_vids.append(new_word_vid)

    # log(delta @ (self.parser.embed.weight[new_word_vid] - self.parser.embed.weight[word_vid]))
    # log()
    new_words = words.clone()
    for i in range(len(word_vids)):
        if new_word_vids[i] is not None:
            new_words[0][word_sids[i]] = new_word_vids[i]

    """
    Evaluating the result
    """
    # print('START EVALUATING')
    # print([self.vocab.words[ele] for ele in self.forbidden_idxs__])
    metric = self.task.partial_evaluate(instance=(new_words, tags, None, arcs, rels),
                                        mask_idxs=ex_span_idx(tgt_span, sent_len),
                                        mst=self.config.hks_mst == 'on')
    att_tags, att_arcs, att_rels = self.task.predict([(new_words, tags, None)],
                                                     mst=self.config.hks_mst == 'on')

    # if verbose:
    #     print('$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    #     print('Iter {}'.format(iter_id))
    #     print(tabulate(_gen_log_table(), floatfmt=('.6f')))
    #     print('^^^^^^^^^^^^^^^^^^^^^^^^^^^')

    rpl_detail = ""
    for i in range(len(word_sids)):
        if new_word_vids[i] is not None:
            rpl_detail += "{}:{}->{} ".format(
                self.vocab.words[raw_words[0][word_sids[i]].item()],
                self.vocab.words[word_vids[i].item()],
                self.vocab.words[new_word_vids[i].item()])
    log(
        "iter {}, uas {:.4f}, ".format(iter_id, metric.uas),
        "mind {:6.3f}, avgd {:6.3f}, ".format(repl_info['mind'], repl_info['avgd'])
        if 'mind' in repl_info else '', rpl_detail)

    if metric.uas >= raw_metric.uas - .00001:
        info = 'Nothing'
    else:
        info = tabulate(self._gen_log_table(raw_words, new_words, tags, arcs, rels,
                                            raw_arcs, att_arcs, src_span, tgt_span),
                        floatfmt='.6f')
    return {
        'code': 200,
        'words': new_words,
        'attack_metric': metric,
        'logtable': info,
    }
def hack(self, instance):
    words, tags, chars, arcs, rels = instance
    _, raw_metric = self.task.evaluate([(words, tags, chars, arcs, rels)],
                                       mst=self.config.hkw_mst == 'on')
    _, raw_arcs, _ = self.task.predict([(words, tags, None)],
                                       mst=self.config.hkw_mst == 'on')

    # Setup some states before attacking a sentence.
    # WARNING: Operations on the state variables (named with a trailing "__")
    # passed to a function are in-place! Never try to save an internal state
    # of such a variable.
    forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
    change_positions__ = set()
    orphans__ = set()
    if self.config.hkw_max_change > 0.9999:
        max_change_num = int(self.config.hkw_max_change)
    else:
        max_change_num = max(1, int(self.config.hkw_max_change * words.size(1)))
    iter_change_num = min(max_change_num, self.config.hkw_iter_change)

    var_words = words.clone()
    raw_words = words.clone()

    # HIGHLIGHT: ITERATION
    t0 = time.time()
    picker = CherryPicker(lower_is_better=True,
                          compare_fn=lambda m1, m2: m1.uas - m2.uas)
    # iter 0 -> raw
    picker.add(raw_metric, {"num_changed": 0, "logtable": 'No modification'})
    for iter_id in range(1, self.config.hkw_steps):
        result = self.single_hack(var_words, tags, arcs, rels,
                                  raw_words=raw_words,
                                  raw_metric=raw_metric,
                                  raw_arcs=raw_arcs,
                                  verbose=False,
                                  max_change_num=max_change_num,
                                  iter_change_num=iter_change_num,
                                  iter_id=iter_id,
                                  forbidden_idxs__=forbidden_idxs__,
                                  change_positions__=change_positions__,
                                  orphans__=orphans__)
        # Fail
        if result['code'] == 404:
            log('Stop in step {}, info: {}'.format(iter_id, result['info']))
            break
        # Success
        if result['code'] == 200:
            picker.add(result['attack_metric'], {
                # "words": result['new_words'],
                "logtable": result['logtable'],
                "num_changed": len(change_positions__)
            })
            if result['attack_metric'].uas < raw_metric.uas - self.config.hkw_eps:
                log('Succeed in step {}'.format(iter_id))
                break
            var_words = result['words']
            # forbidden_idxs__ = result['forbidden_idxs__']
            # change_positions__ = result['change_positions__']
    t1 = time.time()

    best_iter, best_attack_metric, best_info = picker.select_best_point()
    return {
        "raw_metric": raw_metric,
        "attack_metric": best_attack_metric,
        "iters": iter_id,
        "best_iter": best_iter,
        "num_changed": best_info['num_changed'],
        "time": t1 - t0,
        "logtable": best_info['logtable']
    }
def select_by_jacobian(self, instance, sentence, mode):
    sent_len = instance[0].size(1)
    spans = filter_spans(gen_spans(sentence), self.config.hks_min_span_len,
                         self.config.hks_max_span_len, True)
    if mode == 'jacobian1':
        table = []
        pairs = []
        for tgt_span in spans:
            embed_grad = self.backward_loss(instance, ex_span_idx(tgt_span, sent_len))
            grad_norm = embed_grad.norm(dim=-1)  # 1 x (sent_len - 1)
            row = [tgt_span]
            for src_span in spans:
                if check_gap(src_span, tgt_span, self.config.hks_span_gap):
                    # <root> not included
                    norm = grad_norm[0][src_span[0] - 1:src_span[1]].sum().item()
                    pairs.append((norm, (tgt_span, src_span)))
                    row.append("{:4.2f}".format(norm))
                else:
                    row.append("-")
            table.append(row)
        log(tabulate(table, headers=["t↓"] + spans))
        if len(pairs) == 0:
            return None
        spairs = sorted(pairs, key=lambda x: x[0], reverse=True)
        return list(zip(*spairs))[1]
    elif mode == 'jacobian2':
        # log('valid', spans)
        # margins = self.compute_margin(instance)
        # log(margins)
        norms = []
        # WARNING: This is rather slow!
        idxs_to_backward = []
        for span in spans:
            idxs_to_backward.extend(list(range(span[0], span[1] + 1)))
        for i in range(1, sent_len):
            if i in idxs_to_backward:
                embed_grad = self.backward_loss(
                    instance, [_ for _ in range(1, sent_len) if _ != i])
                grad_norm = embed_grad.norm(dim=-1)[:, 1:]
                norms.append(grad_norm)
            else:
                norms.append(torch.zeros_like(instance[0]).float()[:, 1:])
        norms = torch.cat(norms)  # size: (sent_len - 1) x (sent_len - 1)
        # log("t↓", flt2str(list(range(sent_len)), cat=" ", fmt=":3"))
        # for i, norm in enumerate(norms):
        #     log("{:2}".format(i + 1),
        #         flt2str(cast_list(norm), cat=" ", fmt=":3.1f"),
        #         flt2str(sum(norm), fmt=":3.1f"))
        table = []
        for t in spans:
            row = [t]
            for s in spans:
                if check_gap(s, t, self.config.hks_span_gap):
                    row.append("{:4.2f}".format(
                        norms[t[0] - 1:t[1], s[0] - 1:s[1]].sum().item()))
                else:
                    row.append('-')
            table.append(row)
        log(tabulate(table, headers=["t↓"] + spans))
        pairs = []
        for t in spans:
            for s in spans:
                if check_gap(s, t, self.config.hks_span_gap):
                    pairs.append((norms[t[0] - 1:t[1], s[0] - 1:s[1]].sum().item(),
                                  (t, s)))
        if len(pairs) == 0:
            return None
        spairs = sorted(pairs, key=lambda x: x[0], reverse=True)
        return list(zip(*spairs))[1]
def __call__(self, config):
    self.init_logger(config)
    self.setup(config)
    if self.config.hk_use_worker == 'on':
        start_sid, end_sid = locate_chunk(len(self.loader),
                                          self.config.hk_num_worker,
                                          self.config.hk_worker_id)
        log('Run code on a chunk [{}, {})'.format(start_sid, end_sid))
    raw_metrics = ParserMetric()
    attack_metrics = ParserMetric()
    agg = Aggregator()
    for sid, (words, tags, chars, arcs, rels) in enumerate(self.loader):
        # if sid in [0, 1, 2, 3, 4]:
        #     continue
        # if sid < 1434:
        #     continue
        if self.config.hk_use_worker == 'on':
            if sid < start_sid or sid >= end_sid:
                continue
        if self.config.hk_training_set == 'on' and words.size(1) > 50:
            log('Skip sentence {} whose length is {}(>50).'.format(
                sid, words.size(1)))
            continue
        if words.size(1) < 5:
            log('Skip sentence {} whose length is {}(<5).'.format(
                sid, words.size(1)))
            continue

        words_text = self.vocab.id2word(words[0])
        tags_text = self.vocab.id2tag(tags[0])
        log('****** {}: \n{}\n{}'.format(sid, " ".join(words_text),
                                         " ".join(tags_text)))

        # hack it!
        result = self.hack(instance=(words, tags, chars, arcs, rels))

        # aggregate information
        raw_metrics += result['raw_metric']
        attack_metrics += result['attack_metric']
        agg.aggregate(
            ("iters", result['iters']),
            ("time", result['time']),
            ("fail", abs(result['attack_metric'].uas -
                         result['raw_metric'].uas) < 1e-4),
            ('best_iter', result['best_iter']),
            ("changed", result['num_changed']))

        # log some information
        log('Show result from iter {}, changed num {}:'.format(
            result['best_iter'], result['num_changed']))
        log(result['logtable'])

        log('Aggregated result: {} --> {}, '
            'iters(avg) {:.1f}, time(avg) {:.1f}s, '
            'fail rate {:.2f}, best_iter(avg) {:.1f}, best_iter(std) {:.1f}, '
            'changed(avg) {:.1f}'.format(raw_metrics, attack_metrics,
                                         agg.mean('iters'), agg.mean('time'),
                                         agg.mean('fail'), agg.mean('best_iter'),
                                         agg.std('best_iter'), agg.mean('changed')))
        log()
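# `locate_chunk` splits the dataset across workers. A minimal sketch of the
# assumed behaviour (contiguous, half-open chunks):

def locate_chunk(num_instances, num_worker, worker_id):
    size = (num_instances + num_worker - 1) // num_worker
    start = worker_id * size
    end = min(num_instances, start + size)
    return start, end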
def hack(self, instance, sentence):
    words, tags, chars, arcs, rels = instance
    sent_len = words.size(1)
    spans = gen_spans(sentence)
    valid_spans = list(filter(lambda ele: 5 <= ele[1] - ele[0] <= 8, spans))
    if len(valid_spans) == 0:
        log("Attack error, no valid spans")
        return None
    chosen_span = random.choice(valid_spans)
    # chosen_span = (18, 24)
    if chosen_span[1] - chosen_span[0] + 1 > 0.5 * sent_len:
        return None

    # sent_print(sentence, 'tablev')
    # print('spans', spans)
    # print('valid', valid_spans)
    log('chosen span: ', chosen_span,
        ' '.join(self.vocab.id2word(words[0])[chosen_span[0]:chosen_span[1] + 1]))

    raw_words = words.clone()
    var_words = words.clone()
    words_text = self.vocab.id2word(words[0])
    tags_text = self.vocab.id2tag(tags[0])

    # eval_idxs = [eval_idx for eval_idx in range(sent_len)
    #              if not chosen_span[0] <= eval_idx <= chosen_span[1]]
    mask_idxs = list(range(chosen_span[0], chosen_span[1] + 1))
    raw_metric = self.task.partial_evaluate(instance=(raw_words, tags, None, arcs, rels),
                                            mask_idxs=mask_idxs)
    _, raw_arcs, _ = self.task.predict([(raw_words, tags, None)])

    forbidden_idxs__ = [self.vocab.unk_index, self.vocab.pad_index]
    change_positions__ = set()
    if isinstance(self.config.hko_max_change, int):
        max_change_num = self.config.hko_max_change
    elif isinstance(self.config.hko_max_change, float):
        max_change_num = int(self.config.hko_max_change * words.size(1))
    else:
        raise Exception("hko_max_change must be a float or an int")

    picker = CherryPicker(lower_is_better=True)
    t0 = time.time()
    for iter_id in range(self.config.hko_steps):
        result = self.single_hack(instance=(var_words, tags, None, arcs, rels),
                                  raw_words=raw_words,
                                  raw_metric=raw_metric,
                                  raw_arcs=raw_arcs,
                                  mask_idxs=mask_idxs,
                                  iter_id=iter_id,
                                  forbidden_idxs__=forbidden_idxs__,
                                  change_positions__=change_positions__,
                                  max_change_num=max_change_num)
        if result['code'] == 200:
            var_words = result['words']
            picker.add(result['attack_metric'], {
                'logtable': result['logtable'],
                "num_changed": len(change_positions__)
            })
        elif result['code'] == 404:
            print('FAILED')
            break
    t1 = time.time()

    best_iter, best_attack_metric, best_info = picker.select_best_point()
    return {
        "raw_metric": raw_metric,
        "attack_metric": best_attack_metric,
        "iters": iter_id,
        "best_iter": best_iter,
        "num_changed": best_info['num_changed'],
        "time": t1 - t0,
        "logtable": best_info['logtable']
    }
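# `gen_spans`/`filter_spans` come from the span utilities. `gen_spans` is
# assumed to enumerate constituent-like spans of the sentence's parse tree;
# below is a sketch of `filter_spans` consistent with its call sites (the
# final flag is a guess, taken here to mean "drop spans nested inside
# another kept span"):

def filter_spans(spans, min_len, max_len, drop_nested=True):
    kept = [s for s in spans if min_len <= s[1] - s[0] + 1 <= max_len]
    if drop_nested:
        kept = [s for s in kept
                if not any(o != s and o[0] <= s[0] and s[1] <= o[1] for o in kept)]
    return kept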