# LM-guided ("natural") collision generation for an extractive-summarization
# target: beam search over next tokens proposed by the LM, re-scored by the
# target model at every step.
def gen_natural_collision(ex, model, tokenizer, device, lm_model, eval_lm_model=None):
    src, segs, clss, src_sent_labels, src_txt, tgt_txt = ex
    word_embedding = model.bert.model.get_input_embeddings().weight.detach()
    collision_init = tokenizer.convert_tokens_to_ids([BOS_TOKEN])
    start_idx = 1
    num_beams = args.num_beams
    repetition_penalty = 5.0
    curr_len = len(collision_init)

    # scores for each sentence in the beam
    beam_scores = torch.zeros((num_beams,), dtype=torch.float, device=device)
    beam_scores[1:] = -1e9

    output_so_far = torch.tensor([collision_init] * num_beams, device=device)
    past = None
    vocab_size = tokenizer.vocab_size
    topk = args.topk

    src_ids = torch.tensor(src, device=device)
    src_embeds = word_embedding[src_ids]
    sub_mask = get_sub_masks(tokenizer, device)
    filter_ids = [tokenizer.vocab[w] for w in tokenizer.vocab if not w.isalnum()]
    first_mask = torch.zeros_like(sub_mask)
    first_mask[filter_ids] = -1e9

    input_mask = torch.zeros(vocab_size, device=device)
    src_tokens = [w for w in tokenizer.convert_ids_to_tokens(src)
                  if w.isalpha() and w not in STOPWORDS]
    input_mask[tokenizer.convert_tokens_to_ids(src_tokens)] = -1e9
    input_mask[tokenizer.convert_tokens_to_ids(['.', '@', '='])] = -1e9
    unk_ids = tokenizer.encode('<unk>', add_special_tokens=False)
    input_mask[unk_ids] = -1e9

    sep_tensor = torch.tensor([tokenizer.sep_token_id] * topk, device=device)
    cls_tensor = torch.tensor([tokenizer.cls_token_id] * topk, device=device)
    is_first = True
    batch_sep_emb = word_embedding[sep_tensor].unsqueeze(1)
    batch_cls_emb = word_embedding[cls_tensor].unsqueeze(1)

    label = int(len(clss) * args.insert_pos)
    labels = torch.tensor([label] * num_beams, device=device)
    loss_fn = torch.nn.CrossEntropyLoss()

    def classifier_loss(p, context, pre_emb, src_emb, type_token_ids, new_clss, mask):
        context = torch.nn.functional.one_hot(context, len(word_embedding))
        one_hot = torch.cat([context.float(), p.unsqueeze(1)], 1)
        x = torch.einsum('blv,vh->blh', one_hot, word_embedding)
        # add embeddings for CLS and SEP
        x = torch.cat([batch_cls_emb[:num_beams], x, batch_sep_emb[:num_beams]], 1)
        inputs_embeds = torch.cat([pre_emb, x, src_emb], 1)
        scores = model(None, type_token_ids, new_clss, None, mask,
                       inputs_embeds, output_logits=True)
        loss = loss_fn(scores, labels)
        loss += torch.mean(torch.max(scores, 1)[0] - scores[:, label])
        return loss

    best_collision = None
    best_score = -1e9
    best_rank = -1

    while curr_len < args.seq_len:
        seq_len = curr_len - start_idx + 1
        batch_prefix_ids, batch_prefix_emb, batch_src_ids, batch_src_emb, mask_cls, batch_segs, batch_new_clss = \
            get_input_constant(label, seq_len, src_ids, src_embeds, segs, clss, device)

        model_inputs = lm_model.prepare_inputs_for_generation(output_so_far, past=past)
        outputs = lm_model(**model_inputs)
        present = outputs[1]
        # (batch_size * num_beams, vocab_size)
        next_token_logits = outputs[0][:, -1, :]
        lm_scores = torch.log_softmax(next_token_logits, dim=-1)
        next_lm_scores = lm_scores + beam_scores[:, None].expand_as(lm_scores)

        if args.perturb_iter > 0:
            # perturb internal states of the LM toward the classifier objective
            def target_model_wrapper(p):
                return classifier_loss(p, output_so_far.detach()[:, start_idx:],
                                       batch_prefix_emb[:num_beams],
                                       batch_src_emb[:num_beams],
                                       batch_segs[:num_beams],
                                       batch_new_clss[:num_beams], mask_cls)

            next_token_logits = perturb_logits(
                next_token_logits,
                args.lr,
                target_model_wrapper,
                num_iterations=args.perturb_iter,
                kl_scale=args.kl_scale,
                temperature=args.stemp,
                device=device,
                verbose=args.verbose,
                logit_mask=input_mask,
            )

        if repetition_penalty > 1.0:
            lm_model.enforce_repetition_penalty_(next_token_logits, 1, num_beams,
                                                 output_so_far, repetition_penalty)
        next_token_logits = next_token_logits / args.stemp

        # (batch_size * num_beams, vocab_size)
        _, topk_tokens = torch.topk(next_token_logits, topk)

        # get target model score here
        next_clf_scores = []
        for i in range(num_beams):
            next_beam_scores = torch.zeros(tokenizer.vocab_size, device=device) - 1e9
            if output_so_far.shape[1] > start_idx:
                curr_beam_topk = output_so_far[i, start_idx:].unsqueeze(0).expand(
                    topk, output_so_far.shape[1] - start_idx)
                # (topk, curr_len + next_token + sep)
                curr_beam_topk = torch.cat([
                    cls_tensor.unsqueeze(1), curr_beam_topk,
                    topk_tokens[i].unsqueeze(1), sep_tensor.unsqueeze(1)
                ], 1)
            else:
                curr_beam_topk = torch.cat([
                    cls_tensor.unsqueeze(1), topk_tokens[i].unsqueeze(1),
                    sep_tensor.unsqueeze(1)
                ], 1)
            concat_input_ids = torch.cat(
                [batch_prefix_ids, curr_beam_topk, batch_src_ids], 1)
            scores = model(concat_input_ids, batch_segs, batch_new_clss, None,
                           mask_cls, None)
            clf_scores = torch.log_softmax(scores, -1)[:, label].detach()
            next_beam_scores.scatter_(0, topk_tokens[i], clf_scores)
            next_clf_scores.append(next_beam_scores.unsqueeze(0))
        next_clf_scores = torch.cat(next_clf_scores, 0)

        if is_first:
            next_clf_scores += beam_scores[:, None].expand_as(lm_scores)
            next_clf_scores += first_mask
            is_first = False

        next_scores = (1 - args.beta) * next_clf_scores + args.beta * next_lm_scores
        next_scores += input_mask

        # re-organize to group the beams together
        # (we are keeping top hypotheses across beams)
        next_scores = next_scores.view(num_beams * vocab_size)
        next_lm_scores = next_lm_scores.view(num_beams * vocab_size)
        next_scores, next_tokens = torch.topk(next_scores, num_beams,
                                              largest=True, sorted=True)
        next_lm_scores = next_lm_scores[next_tokens]

        # next batch beam content
        next_sent_beam = []
        for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                zip(next_tokens, next_lm_scores)):
            # get beam and token IDs
            beam_id = beam_token_id // vocab_size
            token_id = beam_token_id % vocab_size
            next_sent_beam.append((beam_token_score, token_id, beam_id))
        next_batch_beam = next_sent_beam

        # sanity check / prepare next batch
        assert len(next_batch_beam) == num_beams
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_tokens = output_so_far.new([x[1] for x in next_batch_beam])
        beam_idx = output_so_far.new([x[2] for x in next_batch_beam])

        # re-order batch
        output_so_far = output_so_far[beam_idx, :]
        output_so_far = torch.cat([output_so_far, beam_tokens.unsqueeze(1)], dim=-1)

        # re-score current beams with the full target model
        pad_output_so_far = torch.cat([
            cls_tensor[:num_beams].unsqueeze(1),
            output_so_far[:, start_idx:],
            sep_tensor[:num_beams].unsqueeze(1)
        ], 1)
        concat_input_ids = torch.cat([
            batch_prefix_ids[:num_beams], pad_output_so_far,
            batch_src_ids[:num_beams]
        ], 1)
        actual_scores = model.forward(concat_input_ids, batch_segs[:num_beams],
                                      batch_new_clss[:num_beams], None,
                                      mask_cls, None)
        top_scores, top_labels = torch.topk(actual_scores, actual_scores.shape[-1])
        actual_clf_scores = actual_scores[:, label].detach()
        sorter = torch.argsort(actual_clf_scores, -1, descending=True)
        if args.verbose:
            decoded = [
                f'{actual_clf_scores[i].item():.4f}, '
                f'{tokenizer.decode(output_so_far[i, start_idx:].cpu().tolist())}'
                for i in sorter
            ]
            log(f'Margin={top_scores[:, 2].max().item()} | ' + ' | '.join(decoded))

        # re-order internal states
        past = lm_model._reorder_cache(present, beam_idx)

        # update current length
        curr_len = curr_len + 1

        if curr_len > args.min_len:
            valid_idx = sorter[0]
            valid = False
            for idx in sorter:
                valid, _ = valid_tokenization(output_so_far[idx, start_idx:], tokenizer)
                if valid:
                    valid_idx = idx
                    break

            curr_score = actual_clf_scores[valid_idx].item()
            curr_collision = tokenizer.decode(
                output_so_far[valid_idx, start_idx:].cpu().tolist())
            curr_rank = (top_labels[valid_idx] == label).nonzero().squeeze().item()
            if valid and curr_score > best_score:
                best_score = curr_score
                best_collision = curr_collision
                best_rank = curr_rank

            if args.verbose:
                lm_perp = eval_lm_model.perplexity(curr_collision)
                log(f'LM perp={lm_perp.item()}')

    return best_collision, best_score, best_rank
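# A minimal, self-contained sketch (toy sizes, hypothetical names) of the beam
# bookkeeping used above: per-beam scores are flattened into one
# (num_beams * vocab_size,) vector, the global top-k is taken, and each flat
# index decomposes back into (beam_id, token_id) via // and %.
def _sketch_beam_reorganization():
    import torch
    num_beams, vocab_size = 2, 5
    scores = torch.randn(num_beams, vocab_size)        # per-beam next-token scores
    flat = scores.view(num_beams * vocab_size)         # group all beams together
    top_scores, top_idx = torch.topk(flat, num_beams)  # best hypotheses across beams
    beam_ids = top_idx // vocab_size                   # which beam each winner extends
    token_ids = top_idx % vocab_size                   # which token extends it
    return top_scores, beam_ids, token_ids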
# "Natural" collision generation for pair-input targets scored with a
# next-sentence-prediction head.
def gen_natural_collision(inputs_a, inputs_b, model, tokenizer, device, lm_model,
                          margin=None, eval_lm_model=None):
    input_mask = torch.zeros(tokenizer.vocab_size, device=device)
    filters = find_filters(inputs_a, model, tokenizer, device, k=args.num_filters)
    best_ids = get_inputs_filter_ids(inputs_b, tokenizer)
    input_mask[best_ids] = -1e9
    num_filters_ids = tokenizer.convert_tokens_to_ids(filters)
    input_mask[num_filters_ids] = -1e9
    remove_tokens = add_single_plural(inputs_a, tokenizer)
    if args.verbose:
        log(','.join(remove_tokens))
    remove_ids = tokenizer.convert_tokens_to_ids(remove_tokens)
    input_mask[remove_ids] = -1e9
    input_mask[tokenizer.convert_tokens_to_ids(['.', '@', '='])] = -1e9
    unk_ids = tokenizer.encode('<unk>', add_special_tokens=False)
    input_mask[unk_ids] = -1e9
    filter_ids = [tokenizer.vocab[w] for w in tokenizer.vocab if not w.isalnum()]
    first_mask = torch.zeros_like(input_mask)
    first_mask[filter_ids] = -1e9

    collision_init = tokenizer.convert_tokens_to_ids([BOS_TOKEN])
    start_idx = 1
    num_beams = args.num_beams
    repetition_penalty = 5.0
    curr_len = len(collision_init)

    # scores for each sentence in the beam
    beam_scores = torch.zeros((num_beams,), dtype=torch.float, device=device)
    beam_scores[1:] = -1e9

    output_so_far = torch.tensor([collision_init] * num_beams, device=device)
    past = None
    vocab_size = tokenizer.vocab_size
    topk = args.topk
    input_ids = tokenizer.encode(inputs_a)
    input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
    batch_input_ids = torch.cat([input_ids] * topk, 0)
    sep_tensor = torch.tensor([tokenizer.sep_token_id] * topk, device=device)
    is_first = True

    word_embedding = model.get_input_embeddings().weight.detach()
    batch_sep_embeds = word_embedding[sep_tensor].unsqueeze(1)
    batch_labels = torch.ones((num_beams,), dtype=torch.long, device=device)

    def classifier_loss(p, context):
        context = torch.nn.functional.one_hot(context, len(word_embedding))
        one_hot = torch.cat([context.float(), p.unsqueeze(1)], 1)
        x = torch.einsum('blv,vh->blh', one_hot, word_embedding)
        # add embeddings for SEP
        x = torch.cat([x, batch_sep_embeds[:num_beams]], 1)
        cls_loss = model(batch_input_ids[:num_beams], inputs_embeds=x,
                         next_sentence_label=batch_labels)[0]
        return cls_loss

    best_score = -1e9
    best_collision = None
    collision_cands = []

    while (curr_len - start_idx) < args.seq_len:
        model_inputs = lm_model.prepare_inputs_for_generation(output_so_far, past=past)
        outputs = lm_model(**model_inputs)
        present = outputs[1]
        # (batch_size * num_beams, vocab_size)
        next_token_logits = outputs[0][:, -1, :]
        lm_scores = torch.log_softmax(next_token_logits, dim=-1)

        if args.perturb_iter > 0:
            # perturb internal states of the LM toward the classifier objective
            def target_model_wrapper(p):
                return classifier_loss(p, output_so_far.detach()[:, start_idx:])

            next_token_logits = perturb_logits(
                next_token_logits,
                args.lr,
                target_model_wrapper,
                num_iterations=args.perturb_iter,
                kl_scale=args.kl_scale,
                temperature=args.stemp,
                device=device,
                verbose=args.verbose,
                logit_mask=input_mask,
            )

        if repetition_penalty > 1.0:
            lm_model.enforce_repetition_penalty_(next_token_logits, 1, num_beams,
                                                 output_so_far, repetition_penalty)
        next_token_logits = next_token_logits / args.stemp

        # (batch_size * num_beams, vocab_size)
        next_lm_scores = lm_scores + beam_scores[:, None].expand_as(lm_scores)
        _, topk_tokens = torch.topk(next_token_logits, topk)

        # get target model score here
        next_clf_scores = []
        for i in range(num_beams):
            next_beam_scores = torch.zeros(tokenizer.vocab_size, device=device) - 1e9
            if output_so_far.shape[1] > start_idx:
                curr_beam_topk = output_so_far[i, start_idx:].unsqueeze(0).expand(
                    topk, output_so_far.shape[1] - start_idx)
                # (topk, curr_len + next_token + sep)
                curr_beam_topk = torch.cat([
                    curr_beam_topk, topk_tokens[i].unsqueeze(1),
                    sep_tensor.unsqueeze(1)
                ], 1)
            else:
                curr_beam_topk = torch.cat(
                    [topk_tokens[i].unsqueeze(1), sep_tensor.unsqueeze(1)], 1)
            concat_input_ids = torch.cat([batch_input_ids, curr_beam_topk], 1)
            token_type_ids = torch.cat([
                torch.zeros_like(batch_input_ids),
                torch.ones_like(curr_beam_topk),
            ], 1)
            clf_logits = model(input_ids=concat_input_ids,
                               token_type_ids=token_type_ids)[0]
            clf_scores = torch.log_softmax(clf_logits, -1)[:, 1].detach()
            next_beam_scores.scatter_(0, topk_tokens[i], clf_scores.float())
            next_clf_scores.append(next_beam_scores.unsqueeze(0))
        next_clf_scores = torch.cat(next_clf_scores, 0)

        if is_first:
            next_clf_scores += beam_scores[:, None].expand_as(lm_scores)
            next_clf_scores += first_mask
            is_first = False

        next_scores = (1 - args.beta) * next_clf_scores + args.beta * next_lm_scores
        next_scores += input_mask

        # re-organize to group the beams together
        # (we are keeping top hypotheses across beams)
        next_scores = next_scores.view(num_beams * vocab_size)
        next_lm_scores = next_lm_scores.view(num_beams * vocab_size)
        next_scores, next_tokens = torch.topk(next_scores, num_beams,
                                              largest=True, sorted=True)
        next_lm_scores = next_lm_scores[next_tokens]

        # next batch beam content
        next_sent_beam = []
        for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                zip(next_tokens, next_lm_scores)):
            # get beam and token IDs
            beam_id = beam_token_id // vocab_size
            token_id = beam_token_id % vocab_size
            next_sent_beam.append((beam_token_score, token_id, beam_id))
        next_batch_beam = next_sent_beam

        # sanity check / prepare next batch
        assert len(next_batch_beam) == num_beams
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_tokens = output_so_far.new([x[1] for x in next_batch_beam])
        beam_idx = output_so_far.new([x[2] for x in next_batch_beam])

        # re-order batch
        output_so_far = output_so_far[beam_idx, :]
        output_so_far = torch.cat([output_so_far, beam_tokens.unsqueeze(1)], dim=-1)

        # re-score current beams with the full target model
        pad_output_so_far = torch.cat([
            output_so_far[:, start_idx:], sep_tensor[:num_beams].unsqueeze(1)
        ], 1)
        concat_input_ids = torch.cat(
            [batch_input_ids[:num_beams], pad_output_so_far], 1)
        token_type_ids = torch.cat([
            torch.zeros_like(batch_input_ids[:num_beams]),
            torch.ones_like(pad_output_so_far)
        ], 1)
        clf_logits = model(input_ids=concat_input_ids,
                           token_type_ids=token_type_ids)[0]
        actual_clf_scores = clf_logits[:, 1]
        sorter = torch.argsort(actual_clf_scores, -1, descending=True)
        if args.verbose:
            decoded = [
                f'{actual_clf_scores[i].item():.4f}, '
                f'{tokenizer.decode(output_so_far[i, start_idx:].cpu().tolist())}'
                for i in sorter
            ]
            log(f'Margin={margin if margin else 0:.4f}, query={inputs_a} | ' +
                ' | '.join(decoded))

        if curr_len > args.min_len:
            valid_idx = sorter[0]
            valid = False
            for idx in sorter:
                valid, _ = valid_tokenization(output_so_far[idx, start_idx:], tokenizer)
                if valid:
                    valid_idx = idx
                    break

            curr_score = actual_clf_scores[valid_idx].item()
            curr_collision = tokenizer.decode(
                output_so_far[valid_idx, start_idx:].cpu().tolist())
            collision_cands.append((curr_score, curr_collision))
            if valid and curr_score > best_score:
                best_score = curr_score
                best_collision = curr_collision

            if args.verbose:
                lm_perp = eval_lm_model.perplexity(curr_collision)
                log(f'LM perp={lm_perp.item()}')

        # re-order internal states
        past = lm_model._reorder_cache(present, beam_idx)

        # update current length
        curr_len = curr_len + 1

    return best_collision, best_score, collision_cands
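# Sketch (toy dimensions, not the project's real models) of the continuous
# relaxation inside classifier_loss: hard context tokens become one-hot rows,
# the perturbed next-token distribution p is appended, and one einsum against
# the embedding matrix yields differentiable "soft" input embeddings.
def _sketch_soft_embeddings():
    import torch
    vocab_size, hidden, ctx_len = 50, 8, 3
    word_embedding = torch.randn(vocab_size, hidden)
    context = torch.randint(vocab_size, (1, ctx_len))      # hard token ids
    p = torch.softmax(torch.randn(1, vocab_size), -1)      # relaxed next token
    one_hot = torch.cat(
        [torch.nn.functional.one_hot(context, vocab_size).float(),
         p.unsqueeze(1)], 1)                               # (1, ctx_len + 1, V)
    # gradients can flow through p into the soft embeddings below
    return torch.einsum('blv,vh->blh', one_hot, word_embedding)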
# Gradient-based ("aggressive") collision generation for the summarization
# target: optimize a relaxed token distribution z_i, then discretize it with a
# left-to-right beam search.
def gen_aggressive_collision(ex, model, tokenizer, device, lm_model=None):
    src, segs, clss, src_sent_labels, src_txt, tgt_txt = ex
    word_embedding = model.bert.model.get_input_embeddings().weight.detach()
    if lm_model is not None:
        lm_word_embedding = lm_model.get_input_embeddings().weight.detach()

    vocab_size = word_embedding.size(0)
    src_ids = torch.tensor(src, device=device)
    src_embeds = word_embedding[src_ids]
    sub_mask = get_sub_masks(tokenizer, device)
    input_mask = torch.zeros(vocab_size, device=device)
    src_tokens = [w for w in tokenizer.convert_ids_to_tokens(src)
                  if w.isalpha() and w not in STOPWORDS]
    input_mask[tokenizer.convert_tokens_to_ids(src_tokens)] = -1e9
    seq_len = args.seq_len
    stopwords_mask = create_constraints(seq_len, tokenizer, device)

    def relaxed_to_word_embs(x):
        # convert relaxed inputs to word embeddings by softmax attention
        masked_x = x + input_mask + sub_mask
        if args.regularize:
            masked_x += stopwords_mask
        p = torch.softmax(masked_x / args.stemp, -1)
        x = torch.mm(p, word_embedding)
        # add embeddings for CLS and SEP
        x = torch.cat([
            word_embedding[tokenizer.cls_token_id].unsqueeze(0), x,
            word_embedding[tokenizer.sep_token_id].unsqueeze(0)
        ])
        return p, x.unsqueeze(0)

    def get_lm_loss(p):
        x = torch.mm(p.detach(), lm_word_embedding).unsqueeze(0)
        return lm_model(inputs_embeds=x, one_hot_labels=p.unsqueeze(0))[0]

    # some constants
    sep_tensor = torch.tensor([tokenizer.sep_token_id] * args.topk, device=device)
    batch_sep_emb = word_embedding[sep_tensor].unsqueeze(1)
    cls_tensor = torch.tensor([tokenizer.cls_token_id] * args.topk, device=device)
    batch_cls_emb = word_embedding[cls_tensor].unsqueeze(1)

    label = int(len(clss) * args.insert_pos)
    labels = torch.tensor([label], device=device)
    batch_prefix_ids, batch_prefix_emb, batch_src_ids, batch_src_emb, mask_cls, batch_segs, batch_new_clss = \
        get_input_constant(label, seq_len, src_ids, src_embeds, segs, clss, device)
    prefix_embeds = batch_prefix_emb[0]
    src_embeds = batch_src_emb[0]
    type_token_ids = batch_segs[0]
    new_clss = batch_new_clss[0]
    loss_fn = torch.nn.CrossEntropyLoss()

    best_collision = None
    best_score = -1e9
    best_rank = -1
    prev_score = -1e9

    var_size = (seq_len, vocab_size)
    z_i = torch.zeros(*var_size, requires_grad=True, device=device)
    for it in range(args.max_iter):
        optimizer = torch.optim.Adam([z_i], lr=args.lr)
        for j in range(args.perturb_iter):
            optimizer.zero_grad()
            # relaxation
            p_inputs, inputs_embeds = relaxed_to_word_embs(z_i)
            # forward to BERT with relaxed inputs
            inputs_embeds = torch.cat([
                prefix_embeds.unsqueeze(0), inputs_embeds, src_embeds.unsqueeze(0)
            ], 1)
            scores = model(None, type_token_ids, new_clss, None, mask_cls,
                           inputs_embeds, output_logits=True)
            loss = loss_fn(scores, labels)
            scores = scores.squeeze()
            loss += torch.max(scores) - scores[label]
            if args.beta > 0.:
                lm_loss = get_lm_loss(p_inputs)
                loss = args.beta * lm_loss + (1 - args.beta) * loss
            loss.backward()
            optimizer.step()
            if args.verbose and (j + 1) % 10 == 0:
                log(f'It{it}-{j + 1}, loss={loss.item()}')

        # detach to free GPU memory
        z_i = z_i.detach()

        _, topk_tokens = torch.topk(z_i, args.topk)
        probs_i = torch.softmax(z_i / args.stemp, -1).unsqueeze(0).expand(
            args.topk, seq_len, vocab_size)

        output_so_far = None
        # beam search left to right
        for t in range(seq_len):
            t_topk_tokens = topk_tokens[t]
            t_topk_onehot = torch.nn.functional.one_hot(
                t_topk_tokens, vocab_size).float()
            next_clf_scores = []
            for j in range(args.num_beams):
                next_beam_scores = torch.zeros(tokenizer.vocab_size,
                                               device=device) - 1e9
                if output_so_far is None:
                    context = probs_i.clone()
                else:
                    output_len = output_so_far.shape[1]
                    beam_topk_output = output_so_far[j].unsqueeze(0).expand(
                        args.topk, output_len)
                    beam_topk_output = torch.nn.functional.one_hot(
                        beam_topk_output, vocab_size)
                    context = torch.cat([
                        beam_topk_output.float(), probs_i[:, output_len:].clone()
                    ], 1)

                context[:, t] = t_topk_onehot
                context_emb = torch.einsum('blv,vh->blh', context, word_embedding)
                context_emb = torch.cat(
                    [batch_cls_emb, context_emb, batch_sep_emb], 1)
                inputs_emb = torch.cat(
                    [batch_prefix_emb, context_emb, batch_src_emb], 1)
                scores = model(None, batch_segs, batch_new_clss, None, mask_cls,
                               inputs_emb, output_logits=True)
                clf_scores = scores[:, label].detach().float()
                next_beam_scores.scatter_(0, t_topk_tokens, clf_scores)
                next_clf_scores.append(next_beam_scores.unsqueeze(0))

            next_clf_scores = torch.cat(next_clf_scores, 0)
            next_scores = next_clf_scores + input_mask + sub_mask
            if args.regularize:
                next_scores += stopwords_mask[t]

            if output_so_far is None:
                next_scores[1:] = -1e9

            # re-organize to group the beams together
            # (we are keeping top hypotheses across beams)
            next_scores = next_scores.view(1, args.num_beams * vocab_size)
            # (batch_size, num_beams * vocab_size)
            next_scores, next_tokens = torch.topk(next_scores, args.num_beams,
                                                  dim=1, largest=True, sorted=True)

            # next batch beam content
            next_sent_beam = []
            for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[0], next_scores[0])):
                # get beam and token IDs
                beam_id = beam_token_id // vocab_size
                token_id = beam_token_id % vocab_size
                next_sent_beam.append((beam_token_score, token_id, beam_id))
            next_batch_beam = next_sent_beam

            # sanity check / prepare next batch
            assert len(next_batch_beam) == args.num_beams
            beam_tokens = torch.tensor([x[1] for x in next_batch_beam],
                                       device=device)
            beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=device)

            # re-order batch
            if output_so_far is None:
                output_so_far = beam_tokens.unsqueeze(1)
            else:
                output_so_far = output_so_far[beam_idx, :]
                output_so_far = torch.cat(
                    [output_so_far, beam_tokens.unsqueeze(1)], dim=-1)

        # score the discretized sequences with the full target model
        pad_output_so_far = torch.cat([
            cls_tensor[:args.num_beams].unsqueeze(1), output_so_far,
            sep_tensor[:args.num_beams].unsqueeze(1)
        ], 1)
        concat_input_ids = torch.cat([
            batch_prefix_ids[:args.num_beams], pad_output_so_far,
            batch_src_ids[:args.num_beams]
        ], 1)
        actual_scores = model.forward(concat_input_ids,
                                      batch_segs[:args.num_beams],
                                      batch_new_clss[:args.num_beams], None,
                                      mask_cls, None).squeeze()
        actual_clf_scores = actual_scores[:, label].detach()
        top_scores, top_labels = torch.topk(actual_scores, actual_scores.shape[-1])
        sorter = torch.argsort(actual_clf_scores, -1, descending=True)
        if args.verbose:
            decoded = [
                f'{actual_clf_scores[i].item():.4f}, '
                f'{tokenizer.decode(output_so_far[i].cpu().tolist())}'
                for i in sorter
            ]
            log(f'It={it}, margin={top_scores[:, 2].max().item()} | ' +
                ' | '.join(decoded))

        valid_idx = sorter[0]
        valid = False
        for idx in sorter:
            valid, _ = valid_tokenization(output_so_far[idx], tokenizer)
            if valid:
                valid_idx = idx
                break

        # re-initialize z_i from the best beam
        curr_best = output_so_far[valid_idx]
        next_z_i = torch.nn.functional.one_hot(curr_best, vocab_size).float()
        eps = 0.1
        next_z_i = (next_z_i * (1 - eps)) + (1 - next_z_i) * eps / (vocab_size - 1)
        z_i = torch.nn.Parameter(torch.log(next_z_i), True)

        curr_score = actual_clf_scores[valid_idx].item()
        curr_collision = tokenizer.decode(curr_best.cpu().tolist())
        curr_rank = (top_labels[valid_idx] == label).nonzero().squeeze().item()
        if valid and curr_score > best_score:
            best_score = curr_score
            best_collision = curr_collision
            best_rank = curr_rank
        if prev_score == curr_score:
            break
        prev_score = curr_score

    return best_collision, best_score, best_rank
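# Sketch of the z_i re-initialization above: the best decoded sequence is
# smoothed into a near-one-hot distribution (1 - eps on the chosen token, eps
# spread uniformly over the rest), and its log becomes the new optimization
# variable, so the next round starts from the current discrete solution
# instead of from scratch. Toy sizes only.
def _sketch_z_reinit():
    import torch
    vocab_size, eps = 10, 0.1
    curr_best = torch.tensor([3, 7])  # chosen token ids, one per position
    one_hot = torch.nn.functional.one_hot(curr_best, vocab_size).float()
    smoothed = one_hot * (1 - eps) + (1 - one_hot) * eps / (vocab_size - 1)
    z_i = torch.log(smoothed)
    # softmax recovers the smoothed distribution exactly (each row sums to 1)
    assert torch.allclose(torch.softmax(z_i, -1), smoothed)
    return z_i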
# "Aggressive" collision generation for pair-input targets scored with a
# next-sentence-prediction head.
def gen_aggressive_collision(inputs_a, inputs_b, model, tokenizer, device,
                             margin=None, lm_model=None):
    word_embedding = model.get_input_embeddings().weight.detach()
    if lm_model is not None:
        lm_word_embedding = lm_model.get_input_embeddings().weight.detach()

    vocab_size = word_embedding.size(0)
    input_mask = torch.zeros(vocab_size, device=device)
    filters = find_filters(inputs_a, model, tokenizer, device, k=args.num_filters)
    best_ids = get_inputs_filter_ids(inputs_b, tokenizer)
    input_mask[best_ids] = -1e9
    remove_tokens = add_single_plural(inputs_a, tokenizer)
    if args.verbose:
        log(','.join(remove_tokens))

    remove_ids = tokenizer.convert_tokens_to_ids(remove_tokens)
    remove_ids.append(tokenizer.vocab['.'])
    input_mask[remove_ids] = -1e9
    # prevent outputting the num_filters nearest-neighbor words
    num_filters_ids = tokenizer.convert_tokens_to_ids(filters)
    input_mask[num_filters_ids] = -1e9
    sub_mask = get_sub_masks(tokenizer, device)

    input_ids = tokenizer.encode(inputs_a)
    input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
    seq_len = args.seq_len
    batch_input_ids = torch.cat([input_ids] * args.topk, 0)
    stopwords_mask = create_constraints(seq_len, tokenizer, device)

    def relaxed_to_word_embs(x):
        # convert relaxed inputs to word embeddings by softmax attention
        masked_x = x + input_mask + sub_mask
        if args.regularize:
            masked_x += stopwords_mask
        p = torch.softmax(masked_x / args.stemp, -1)
        x = torch.mm(p, word_embedding)
        # add embedding for SEP
        x = torch.cat([x, word_embedding[tokenizer.sep_token_id].unsqueeze(0)])
        return p, x.unsqueeze(0)

    def get_lm_loss(p):
        x = torch.mm(p.detach(), lm_word_embedding).unsqueeze(0)
        return lm_model(inputs_embeds=x, one_hot_labels=p.unsqueeze(0))[0]

    # some constants
    sep_tensor = torch.tensor([tokenizer.sep_token_id] * args.topk, device=device)
    batch_sep_embeds = word_embedding[sep_tensor].unsqueeze(1)
    labels = torch.ones((1,), dtype=torch.long, device=device)

    repetition_penalty = 1.0
    best_collision = None
    best_score = -1e9
    prev_score = -1e9
    collision_cands = []

    var_size = (seq_len, vocab_size)
    z_i = torch.zeros(*var_size, requires_grad=True, device=device)
    for it in range(args.max_iter):
        optimizer = torch.optim.Adam([z_i], lr=args.lr)
        for j in range(args.perturb_iter):
            optimizer.zero_grad()
            # relaxation
            p_inputs, inputs_embeds = relaxed_to_word_embs(z_i)
            # forward to BERT with relaxed inputs
            loss, cls_logits, _ = model(input_ids, inputs_embeds=inputs_embeds,
                                        next_sentence_label=labels)
            if margin is not None:
                loss += torch.sum(torch.relu(margin - cls_logits[:, 1]))
            if args.beta > 0.:
                lm_loss = get_lm_loss(p_inputs)
                loss = args.beta * lm_loss + (1 - args.beta) * loss
            loss.backward()
            optimizer.step()
            if args.verbose and (j + 1) % 10 == 0:
                log(f'It{it}-{j + 1}, loss={loss.item()}')

        # detach to free GPU memory
        z_i = z_i.detach()

        _, topk_tokens = torch.topk(z_i, args.topk)
        probs_i = torch.softmax(z_i / args.stemp, -1).unsqueeze(0).expand(
            args.topk, seq_len, vocab_size)

        output_so_far = None
        # beam search left to right
        for t in range(seq_len):
            t_topk_tokens = topk_tokens[t]
            t_topk_onehot = torch.nn.functional.one_hot(
                t_topk_tokens, vocab_size).float()
            next_clf_scores = []
            for j in range(args.num_beams):
                next_beam_scores = torch.zeros(tokenizer.vocab_size,
                                               device=device) - 1e9
                if output_so_far is None:
                    context = probs_i.clone()
                else:
                    output_len = output_so_far.shape[1]
                    beam_topk_output = output_so_far[j].unsqueeze(0).expand(
                        args.topk, output_len)
                    beam_topk_output = torch.nn.functional.one_hot(
                        beam_topk_output, vocab_size)
                    context = torch.cat([
                        beam_topk_output.float(), probs_i[:, output_len:].clone()
                    ], 1)

                context[:, t] = t_topk_onehot
                context_embeds = torch.einsum('blv,vh->blh', context,
                                              word_embedding)
                context_embeds = torch.cat([context_embeds, batch_sep_embeds], 1)
                clf_logits = model(input_ids=batch_input_ids,
                                   inputs_embeds=context_embeds)[0]
                clf_scores = clf_logits[:, 1].detach().float()
                next_beam_scores.scatter_(0, t_topk_tokens, clf_scores)
                next_clf_scores.append(next_beam_scores.unsqueeze(0))

            next_clf_scores = torch.cat(next_clf_scores, 0)
            next_scores = next_clf_scores + input_mask + sub_mask
            if args.regularize:
                next_scores += stopwords_mask[t]

            if output_so_far is None:
                next_scores[1:] = -1e9

            if output_so_far is not None and repetition_penalty > 1.0:
                lm_model.enforce_repetition_penalty_(next_scores, 1,
                                                     args.num_beams,
                                                     output_so_far,
                                                     repetition_penalty)

            # re-organize to group the beams together
            # (we are keeping top hypotheses across beams)
            next_scores = next_scores.view(1, args.num_beams * vocab_size)
            # (batch_size, num_beams * vocab_size)
            next_scores, next_tokens = torch.topk(next_scores, args.num_beams,
                                                  dim=1, largest=True, sorted=True)

            # next batch beam content
            next_sent_beam = []
            for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[0], next_scores[0])):
                # get beam and token IDs
                beam_id = beam_token_id // vocab_size
                token_id = beam_token_id % vocab_size
                next_sent_beam.append((beam_token_score, token_id, beam_id))
            next_batch_beam = next_sent_beam

            # sanity check / prepare next batch
            assert len(next_batch_beam) == args.num_beams
            beam_tokens = torch.tensor([x[1] for x in next_batch_beam],
                                       device=device)
            beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=device)

            # re-order batch
            if output_so_far is None:
                output_so_far = beam_tokens.unsqueeze(1)
            else:
                output_so_far = output_so_far[beam_idx, :]
                output_so_far = torch.cat(
                    [output_so_far, beam_tokens.unsqueeze(1)], dim=-1)

        # score the discretized sequences with the full target model
        pad_output_so_far = torch.cat(
            [output_so_far, sep_tensor[:args.num_beams].unsqueeze(1)], 1)
        concat_input_ids = torch.cat(
            [batch_input_ids[:args.num_beams], pad_output_so_far], 1)
        token_type_ids = torch.cat([
            torch.zeros_like(batch_input_ids[:args.num_beams]),
            torch.ones_like(pad_output_so_far)
        ], 1)
        clf_logits = model(input_ids=concat_input_ids,
                           token_type_ids=token_type_ids)[0]
        actual_clf_scores = clf_logits[:, 1]
        sorter = torch.argsort(actual_clf_scores, -1, descending=True)
        if args.verbose:
            decoded = [
                f'{actual_clf_scores[i].item():.4f}, '
                f'{tokenizer.decode(output_so_far[i].cpu().tolist())}'
                for i in sorter
            ]
            log(f'It={it}, margin={margin if margin is not None else 0:.4f}, '
                f'query={inputs_a} | ' + ' | '.join(decoded))

        valid_idx = sorter[0]
        valid = False
        for idx in sorter:
            valid, _ = valid_tokenization(output_so_far[idx], tokenizer)
            if valid:
                valid_idx = idx
                break

        # re-initialize z_i from the best beam
        curr_best = output_so_far[valid_idx]
        next_z_i = torch.nn.functional.one_hot(curr_best, vocab_size).float()
        eps = 0.1
        next_z_i = (next_z_i * (1 - eps)) + (1 - next_z_i) * eps / (vocab_size - 1)
        z_i = torch.nn.Parameter(torch.log(next_z_i), True)

        curr_score = actual_clf_scores[valid_idx].item()
        if valid and curr_score > best_score:
            best_score = curr_score
            best_collision = tokenizer.decode(curr_best.cpu().tolist())
        if curr_score <= prev_score:
            break
        prev_score = curr_score

    return best_collision, best_score, collision_cands
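# Hedged sketch of what lm_model.enforce_repetition_penalty_ does in the older
# transformers releases this code targets (following the CTRL repetition
# penalty): logits of already-generated tokens are divided by the penalty when
# positive and multiplied by it when negative, discouraging repeats either
# way. Hypothetical reimplementation, not the library's actual code.
def _sketch_repetition_penalty(logits, prev_tokens, penalty=5.0):
    for i, prev in enumerate(prev_tokens):       # one row of logits per beam
        for token_id in set(prev.tolist()):      # each token generated so far
            if logits[i, token_id] < 0:
                logits[i, token_id] *= penalty   # make negatives more negative
            else:
                logits[i, token_id] /= penalty   # shrink positives
    return logits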