def run_model(args):
    """Run a HotFlip-style gradient-guided search for trigger tokens.

    Iteratively accumulates gradients w.r.t. the trigger-token embeddings,
    proposes candidate replacement tokens for one randomly chosen trigger
    position per iteration, keeps the best-scoring trigger on the dev set,
    and optionally prints the best trigger as a LAMA-format JSON template.

    Args:
        args: Parsed CLI namespace. Expected attributes include seed,
            model_name, template, label_map, label_field, tokenize_labels,
            use_ctx, initial_trigger, perturbed, train, dev, bsz, eval_size,
            limit, filter, iters, accumulation_steps, num_cand, print_lama.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = load_pretrained(args.model_name)
    model.to(device)
    embeddings = get_embeddings(model, config)
    embedding_gradient = GradientStorage(embeddings)
    predictor = PredictWrapper(model)

    if args.label_map is not None:
        label_map = json.loads(args.label_map)
        logger.info(f"Label map: {label_map}")
    else:
        label_map = None
    templatizer = utils.TriggerTemplatizer(
        args.template,
        config,
        tokenizer,
        label_map=label_map,
        label_field=args.label_field,
        tokenize_labels=args.tokenize_labels,
        add_special_tokens=False,
        use_ctx=args.use_ctx
    )

    # Obtain the initial trigger tokens and label mapping.
    if args.initial_trigger:
        trigger_ids = tokenizer.convert_tokens_to_ids(args.initial_trigger)
        logger.debug(f'Initial trigger: {args.initial_trigger}')
        logger.debug(f'Trigger ids: {trigger_ids}')
        assert len(trigger_ids) == templatizer.num_trigger_tokens
    else:
        trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)
    best_trigger_ids = trigger_ids.clone()

    # NOTE: Accuracy can only be computed if a fixed pool of labels is given, which currently
    # requires the label map to be specified. Since producing a label map may be cumbersome (e.g.,
    # for link prediction tasks), we just use (negative) loss as the evaluation metric in these cases.
    if label_map:
        evaluation_fn = AccuracyFn(tokenizer, label_map, device)
    else:
        # FIX: named function instead of a lambda assignment (PEP 8 E731).
        def evaluation_fn(predict_logits, labels):
            return -get_loss(predict_logits, labels)

    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)

    if args.perturbed:
        train_dataset = utils.load_augmented_trigger_dataset(args.train, templatizer, limit=args.limit)
    else:
        train_dataset = utils.load_trigger_dataset(args.train, templatizer, use_ctx=args.use_ctx, limit=args.limit)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    if args.perturbed:
        dev_dataset = utils.load_augmented_trigger_dataset(args.dev, templatizer)
    else:
        dev_dataset = utils.load_trigger_dataset(args.dev, templatizer, use_ctx=args.use_ctx)
    dev_loader = DataLoader(dev_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)

    # To "filter" unwanted trigger tokens, we subtract a huge number from their logits.
    # FIX: renamed from `filter`, which shadowed the builtin.
    token_filter = torch.zeros(tokenizer.vocab_size, dtype=torch.float32, device=device)
    if args.filter:
        logger.info('Filtering label tokens.')
        if label_map:
            for label_tokens in label_map.values():
                label_ids = utils.encode_label(tokenizer, label_tokens).unsqueeze(0)
                token_filter[label_ids] = -1e32
        else:
            for _, label_ids in train_dataset:
                token_filter[label_ids] = -1e32
    logger.info('Filtering special tokens and capitalized words.')
    for word, idx in tokenizer.get_vocab().items():
        if len(word) == 1 or idx >= tokenizer.vocab_size:
            continue
        # Filter special tokens.
        if idx in tokenizer.all_special_ids:
            logger.debug('Filtered: %s', word)
            token_filter[idx] = -1e32
        # Filter capitalized words (lazy way to remove proper nouns).
        if isupper(idx, tokenizer):
            logger.debug('Filtered: %s', word)
            token_filter[idx] = -1e32

    def evaluate_dev():
        # FIX: this loop was duplicated verbatim (initial eval + per-iteration
        # eval); extracted here so both call sites stay in sync.
        numerator = 0
        denominator = 0
        for model_inputs, labels in tqdm(dev_loader):
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            with torch.no_grad():
                predict_logits = predictor(model_inputs, trigger_ids)
            numerator += evaluation_fn(predict_logits, labels).sum().item()
            denominator += labels.size(0)
        # Small epsilon guards against an empty dev loader.
        return numerator / (denominator + 1e-13)

    logger.info('Evaluating')
    dev_metric = evaluate_dev()
    logger.info(f'Dev metric: {dev_metric}')

    best_dev_metric = -float('inf')
    # Measure elapsed time of trigger search.
    start = time.time()

    for i in range(args.iters):
        logger.info(f'Iteration: {i}')
        logger.info('Accumulating Gradient')
        model.zero_grad()

        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)
        averaged_grad = None

        # Accumulate gradients over `accumulation_steps` batches.
        for step in pbar:
            # Shuttle inputs to GPU.
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:  # FIX: was a bare `except:`
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.'
                )
                break
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            predict_logits = predictor(model_inputs, trigger_ids)
            loss = get_loss(predict_logits, labels).mean()
            loss.backward()

            grad = embedding_gradient.get()
            bsz, _, emb_dim = grad.size()
            selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
            grad = torch.masked_select(grad, selection_mask)
            grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)

            if averaged_grad is None:
                averaged_grad = grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += grad.sum(dim=0) / args.accumulation_steps

        logger.info('Evaluating Candidates')
        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)

        token_to_flip = random.randrange(templatizer.num_trigger_tokens)
        candidates = hotflip_attack(averaged_grad[token_to_flip],
                                    embeddings.weight,
                                    increase_loss=False,
                                    num_candidates=args.num_cand,
                                    filter=token_filter)

        current_score = 0
        candidate_scores = torch.zeros(args.num_cand, device=device)
        denom = 0
        for step in pbar:
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:  # FIX: was a bare `except:`
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.'
                )
                break
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            with torch.no_grad():
                predict_logits = predictor(model_inputs, trigger_ids)
                eval_metric = evaluation_fn(predict_logits, labels)

            # Update current score.
            current_score += eval_metric.sum()
            denom += labels.size(0)

            # NOTE: Instead of iterating over tokens to flip we randomly change just one each
            # time so the gradients don't get stale.
            # FIX: loop variable renamed (was `i`, shadowing the iteration counter).
            for cand_idx, candidate in enumerate(candidates):
                temp_trigger = trigger_ids.clone()
                temp_trigger[:, token_to_flip] = candidate
                with torch.no_grad():
                    predict_logits = predictor(model_inputs, temp_trigger)
                    eval_metric = evaluation_fn(predict_logits, labels)
                candidate_scores[cand_idx] += eval_metric.sum()

        # TODO: Something cleaner. LAMA templates can't have mask tokens, so if
        # there are still mask tokens in the trigger then set the current score
        # to -inf.
        if args.print_lama:
            if trigger_ids.eq(tokenizer.mask_token_id).any():
                current_score = float('-inf')

        if (candidate_scores > current_score).any():
            logger.info('Better trigger detected.')
            best_candidate_score = candidate_scores.max()
            best_candidate_idx = candidate_scores.argmax()
            trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
            logger.info(f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')
        else:
            logger.info('No improvement detected. Skipping evaluation.')
            continue

        logger.info('Evaluating')
        dev_metric = evaluate_dev()
        logger.info(f'Trigger tokens: {tokenizer.convert_ids_to_tokens(trigger_ids.squeeze(0))}')
        logger.info(f'Dev metric: {dev_metric}')

        # TODO: Something cleaner. LAMA templates can't have mask tokens, so if
        # there are still mask tokens in the trigger then set the current score
        # to -inf.
        if args.print_lama:
            if best_trigger_ids.eq(tokenizer.mask_token_id).any():
                best_dev_metric = float('-inf')
        if dev_metric > best_dev_metric:
            logger.info('Best performance so far')
            best_trigger_ids = trigger_ids.clone()
            best_dev_metric = dev_metric

    # FIX: `start` was recorded but the elapsed time was never reported.
    logger.info('Trigger search took %0.2f seconds.', time.time() - start)

    best_trigger_tokens = tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0))
    logger.info(f'Best tokens: {best_trigger_tokens}')
    logger.info(f'Best dev metric: {best_dev_metric}')

    if args.print_lama:
        # Templatize with [X] and [Y].
        if args.use_ctx:
            model_inputs, label_ids = templatizer({
                'sub_label': '[X]',
                'obj_label': tokenizer.lama_y,
                'context': ''
            })
        else:
            model_inputs, label_ids = templatizer({
                'sub_label': '[X]',
                'obj_label': tokenizer.lama_y,
            })
        lama_template = model_inputs['input_ids']
        # Instantiate trigger tokens.
        lama_template.masked_scatter_(
            mask=model_inputs['trigger_mask'],
            source=best_trigger_ids.cpu())
        # Instantiate label token.
        lama_template.masked_scatter_(
            mask=model_inputs['predict_mask'],
            source=label_ids)
        # Print LAMA JSON template.
        # The following block of code is a bit hacky but whatever, it gets the job done.
        # FIX: removed unused duplicate assignment `relation = args.train.parent.stem`.
        if args.use_ctx:
            template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[SEP] ', '').replace('</s> ', '').replace('[ X ]', '[X]')
        else:
            template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[ X ]', '[X]')
        out = {
            'relation': args.train.parent.stem,
            'template': template
        }
        print(json.dumps(out))
def main(args):
    """Train a linear projection over [MASK] embeddings to find label words.

    Runs the model (frozen, under ``no_grad``) on templatized training data,
    trains only a ``Linear(hidden_size, num_labels)`` projection on the
    predict-token embeddings, and logs the top-k vocabulary words whose
    embeddings best align with each class row — before and after training.

    Args:
        args: Parsed CLI namespace. Expected attributes include seed,
            model_name, template, label_map, label_field, initial_trigger,
            train, bsz, lr, iters, k.
    """
    ct.set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = load_pretrained(args.model_name)
    model.to(device)
    final_embeddings = get_final_embeddings(model)
    embedding_storage = utils.OutputStorage(final_embeddings)
    word_embeddings = get_word_embeddings(model)

    label_map = json.loads(args.label_map)
    reverse_label_map = {y: x for x, y in label_map.items()}
    templatizer = utils.TriggerTemplatizer(
        args.template,
        tokenizer,
        label_map=label_map,
        label_field=args.label_field,
        add_special_tokens=False
    )

    # The weights of this projection will help identify the best label words.
    projection = torch.nn.Linear(config.hidden_size, len(label_map))
    projection.to(device)

    # Obtain the initial trigger tokens and label mapping.
    if args.initial_trigger:
        trigger_ids = tokenizer.encode(
            args.initial_trigger,
            add_special_tokens=False,
            add_prefix_space=True
        )
        assert len(trigger_ids) == templatizer.num_trigger_tokens
    else:
        trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)

    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset = utils.load_trigger_dataset(args.train, templatizer)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    optimizer = torch.optim.Adam(projection.parameters(), lr=args.lr)

    def log_top_label_words():
        # FIX: this block was duplicated verbatim before and after training;
        # extracted so the two call sites cannot drift apart.
        scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
        scores = F.softmax(scores, dim=0)
        for class_idx, row in enumerate(scores):
            _, top = row.topk(args.k)
            decoded = tokenizer.convert_ids_to_tokens(top)
            logger.info(f"Top k for class {reverse_label_map[class_idx]}: {', '.join(decoded)}")

    # Log the (random-projection) baseline before training.
    log_top_label_words()

    logger.info('Training')
    # FIX: epoch counter renamed from `i`, which was shadowed by the
    # enumerate loop in the (now extracted) top-k logging block.
    for _epoch in range(args.iters):
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            optimizer.zero_grad()
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            trigger_mask = model_inputs.pop('trigger_mask')
            predict_mask = model_inputs.pop('predict_mask')
            model_inputs = ct.replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
            # The backbone is frozen; only the projection receives gradients.
            with torch.no_grad():
                model(**model_inputs)
            embeddings = embedding_storage.get()
            predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            loss = F.cross_entropy(logits, labels.squeeze(-1))
            loss.backward()
            optimizer.step()
            pbar.set_description(f'loss: {loss : 0.4f}')

    # Log the learned label words after training.
    log_top_label_words()