def test_collator(self):
    template = '[T] [T] {arbitrary} [T] {fields} [P]'
    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
    config = AutoConfig.from_pretrained('bert-base-cased')
    utils.add_task_specific_tokens(tokenizer)
    templatizer = utils.TriggerTemplatizer(
        template,
        config,
        tokenizer,
        add_special_tokens=False
    )
    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)

    instances = [
        {'arbitrary': 'a', 'fields': 'the', 'label': 'hot'},
        {'arbitrary': 'a a', 'fields': 'the the', 'label': 'cold'}
    ]
    templatized_instances = [templatizer(x) for x in instances]
    loader = DataLoader(
        templatized_instances,
        batch_size=2,
        shuffle=False,
        collate_fn=collator
    )
    model_inputs, labels = next(iter(loader))

    # Check results match our expectations
    expected_labels = torch.tensor([
        tokenizer.encode('hot', add_special_tokens=False, add_prefix_space=True),
        tokenizer.encode('cold', add_special_tokens=False, add_prefix_space=True),
    ])
    assert torch.equal(expected_labels, labels)

    expected_trigger_mask = torch.tensor([
        [True, True, False, True, False, False, False, False],
        [True, True, False, False, True, False, False, False],
    ])
    assert torch.equal(expected_trigger_mask, model_inputs['trigger_mask'])

    expected_predict_mask = torch.tensor([
        [False, False, False, False, False, True, False, False],
        [False, False, False, False, False, False, False, True],
    ])
    assert torch.equal(expected_predict_mask, model_inputs['predict_mask'])
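# A minimal sketch of the padding behavior this test exercises. This is an
# illustrative assumption, not the actual `utils.Collator` implementation: it
# right-pads every tensor field in `model_inputs` to the longest sequence in
# the batch, using `pad_token_id` for the input ids and 0/False for the masks,
# and assumes each templatized tensor has shape (1, seq_len).
import torch


class PaddingCollatorSketch:
    def __init__(self, pad_token_id=0):
        self._pad_token_id = pad_token_id

    def __call__(self, features):
        # Each feature is a (model_inputs, label) pair produced by the templatizer.
        model_inputs, labels = zip(*features)
        max_len = max(x['input_ids'].size(-1) for x in model_inputs)

        def pad(x, value):
            out = x.new_full((max_len,), value)
            out[:x.size(-1)] = x.squeeze(0)
            return out

        padded = {
            key: torch.stack([
                pad(x[key], self._pad_token_id if key == 'input_ids' else 0)
                for x in model_inputs
            ])
            for key in model_inputs[0]
        }
        return padded, torch.stack([y.squeeze(0) for y in labels])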
def run_model(args):
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = load_pretrained(args.model_name)
    model.to(device)
    embeddings = get_embeddings(model, config)
    embedding_gradient = GradientStorage(embeddings)
    predictor = PredictWrapper(model)

    if args.label_map is not None:
        label_map = json.loads(args.label_map)
        logger.info(f"Label map: {label_map}")
    else:
        label_map = None

    templatizer = utils.TriggerTemplatizer(
        args.template,
        config,
        tokenizer,
        label_map=label_map,
        label_field=args.label_field,
        tokenize_labels=args.tokenize_labels,
        add_special_tokens=False,
        use_ctx=args.use_ctx
    )

    # Obtain the initial trigger tokens and label mapping
    if args.initial_trigger:
        trigger_ids = tokenizer.convert_tokens_to_ids(args.initial_trigger)
        logger.debug(f'Initial trigger: {args.initial_trigger}')
        logger.debug(f'Trigger ids: {trigger_ids}')
        assert len(trigger_ids) == templatizer.num_trigger_tokens
    else:
        trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)
    best_trigger_ids = trigger_ids.clone()

    # NOTE: Accuracy can only be computed if a fixed pool of labels is given, which currently
    # requires the label map to be specified. Since producing a label map may be cumbersome
    # (e.g., for link prediction tasks), we just use (negative) loss as the evaluation metric
    # in those cases.
    if label_map:
        evaluation_fn = AccuracyFn(tokenizer, label_map, device)
    else:
        evaluation_fn = lambda x, y: -get_loss(x, y)

    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)

    if args.perturbed:
        train_dataset = utils.load_augmented_trigger_dataset(args.train, templatizer, limit=args.limit)
    else:
        train_dataset = utils.load_trigger_dataset(args.train, templatizer, use_ctx=args.use_ctx, limit=args.limit)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    if args.perturbed:
        dev_dataset = utils.load_augmented_trigger_dataset(args.dev, templatizer)
    else:
        dev_dataset = utils.load_trigger_dataset(args.dev, templatizer, use_ctx=args.use_ctx)
    dev_loader = DataLoader(dev_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)

    # To "filter" unwanted trigger tokens, we subtract a huge number from their logits.
    token_filter = torch.zeros(tokenizer.vocab_size, dtype=torch.float32, device=device)
    if args.filter:
        logger.info('Filtering label tokens.')
        if label_map:
            for label_tokens in label_map.values():
                label_ids = utils.encode_label(tokenizer, label_tokens).unsqueeze(0)
                token_filter[label_ids] = -1e32
        else:
            for _, label_ids in train_dataset:
                token_filter[label_ids] = -1e32
        logger.info('Filtering special tokens and capitalized words.')
        for word, idx in tokenizer.get_vocab().items():
            if len(word) == 1 or idx >= tokenizer.vocab_size:
                continue
            # Filter special tokens.
            if idx in tokenizer.all_special_ids:
                logger.debug('Filtered: %s', word)
                token_filter[idx] = -1e32
            # Filter capitalized words (lazy way to remove proper nouns).
            if isupper(idx, tokenizer):
                logger.debug('Filtered: %s', word)
                token_filter[idx] = -1e32

    logger.info('Evaluating')
    numerator = 0
    denominator = 0
    for model_inputs, labels in tqdm(dev_loader):
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        labels = labels.to(device)
        with torch.no_grad():
            predict_logits = predictor(model_inputs, trigger_ids)
        numerator += evaluation_fn(predict_logits, labels).sum().item()
        denominator += labels.size(0)
    dev_metric = numerator / (denominator + 1e-13)
    logger.info(f'Dev metric: {dev_metric}')

    best_dev_metric = -float('inf')
    # Measure elapsed time of trigger search
    start = time.time()

    for i in range(args.iters):
        logger.info(f'Iteration: {i}')

        logger.info('Accumulating Gradient')
        model.zero_grad()

        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)
        averaged_grad = None

        # Accumulate
        for step in pbar:
            # Shuttle inputs to GPU
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.'
                )
                break
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            predict_logits = predictor(model_inputs, trigger_ids)
            loss = get_loss(predict_logits, labels).mean()
            loss.backward()

            grad = embedding_gradient.get()
            bsz, _, emb_dim = grad.size()
            selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
            grad = torch.masked_select(grad, selection_mask)
            grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)

            if averaged_grad is None:
                averaged_grad = grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += grad.sum(dim=0) / args.accumulation_steps

        logger.info('Evaluating Candidates')
        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)

        token_to_flip = random.randrange(templatizer.num_trigger_tokens)
        candidates = hotflip_attack(averaged_grad[token_to_flip],
                                    embeddings.weight,
                                    increase_loss=False,
                                    num_candidates=args.num_cand,
                                    filter=token_filter)

        current_score = 0
        candidate_scores = torch.zeros(args.num_cand, device=device)
        denom = 0
        for step in pbar:
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.'
                )
                break
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            with torch.no_grad():
                predict_logits = predictor(model_inputs, trigger_ids)
                eval_metric = evaluation_fn(predict_logits, labels)

            # Update current score
            current_score += eval_metric.sum()
            denom += labels.size(0)

            # NOTE: Instead of iterating over all trigger tokens, we randomly flip
            # just one per iteration so the gradients don't get stale.
            for idx, candidate in enumerate(candidates):
                temp_trigger = trigger_ids.clone()
                temp_trigger[:, token_to_flip] = candidate
                with torch.no_grad():
                    predict_logits = predictor(model_inputs, temp_trigger)
                    eval_metric = evaluation_fn(predict_logits, labels)
                candidate_scores[idx] += eval_metric.sum()

        # TODO: Something cleaner. LAMA templates can't have mask tokens, so if
        # there are still mask tokens in the trigger then set the current score
        # to -inf.
        if args.print_lama:
            if trigger_ids.eq(tokenizer.mask_token_id).any():
                current_score = float('-inf')

        if (candidate_scores > current_score).any():
            logger.info('Better trigger detected.')
            best_candidate_score = candidate_scores.max()
            best_candidate_idx = candidate_scores.argmax()
            trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
            logger.info(f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')
        else:
            logger.info('No improvement detected. Skipping evaluation.')
            continue

        logger.info('Evaluating')
        numerator = 0
        denominator = 0
        for model_inputs, labels in tqdm(dev_loader):
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            with torch.no_grad():
                predict_logits = predictor(model_inputs, trigger_ids)
            numerator += evaluation_fn(predict_logits, labels).sum().item()
            denominator += labels.size(0)
        dev_metric = numerator / (denominator + 1e-13)

        logger.info(f'Trigger tokens: {tokenizer.convert_ids_to_tokens(trigger_ids.squeeze(0))}')
        logger.info(f'Dev metric: {dev_metric}')

        # TODO: Something cleaner. LAMA templates can't have mask tokens, so if
        # there are still mask tokens in the trigger then reset the best metric
        # to -inf (forcing the current trigger to replace it).
        if args.print_lama:
            if best_trigger_ids.eq(tokenizer.mask_token_id).any():
                best_dev_metric = float('-inf')

        if dev_metric > best_dev_metric:
            logger.info('Best performance so far')
            best_trigger_ids = trigger_ids.clone()
            best_dev_metric = dev_metric

    logger.info(f'Trigger search took {time.time() - start: 0.1f} s')
    best_trigger_tokens = tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0))
    logger.info(f'Best tokens: {best_trigger_tokens}')
    logger.info(f'Best dev metric: {best_dev_metric}')

    if args.print_lama:
        # Templatize with [X] and [Y]
        if args.use_ctx:
            model_inputs, label_ids = templatizer({
                'sub_label': '[X]',
                'obj_label': tokenizer.lama_y,
                'context': ''
            })
        else:
            model_inputs, label_ids = templatizer({
                'sub_label': '[X]',
                'obj_label': tokenizer.lama_y,
            })
        lama_template = model_inputs['input_ids']
        # Instantiate trigger tokens
        lama_template.masked_scatter_(
            mask=model_inputs['trigger_mask'],
            source=best_trigger_ids.cpu())
        # Instantiate label token
        lama_template.masked_scatter_(
            mask=model_inputs['predict_mask'],
            source=label_ids)

        # Print LAMA JSON template.
        relation = args.train.parent.stem
        # The following decode-and-replace is a bit hacky, but it gets the job
        # done: strip the leading/trailing special tokens, then undo the
        # tokenization artifacts around the [X] placeholder.
        if args.use_ctx:
            template = (tokenizer.decode(lama_template.squeeze(0)[1:-1])
                        .replace('[SEP] ', '')
                        .replace('</s> ', '')
                        .replace('[ X ]', '[X]'))
        else:
            template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[ X ]', '[X]')
        out = {
            'relation': relation,
            'template': template
        }
        print(json.dumps(out))
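# For context, the candidate generation above uses the first-order (HotFlip-style)
# approximation: the change in loss from swapping trigger token w_i for w is
# approximated by (e_w - e_{w_i})^T grad, so the best candidates are the vocabulary
# items whose embeddings have the largest (negative) dot product with the averaged
# gradient. A minimal sketch consistent with the calling convention used above
# (`filter` holds -1e32 at banned vocabulary ids); this is an illustration, not
# necessarily the codebase's exact implementation:
import torch


def hotflip_attack_sketch(averaged_grad, embedding_matrix,
                          increase_loss=False, num_candidates=1, filter=None):
    """Return the ids of the top-k replacement candidates for one trigger slot."""
    with torch.no_grad():
        # Dot product of every vocabulary embedding with the averaged gradient.
        scores = torch.matmul(embedding_matrix, averaged_grad)
        if filter is not None:
            # Banned ids carry -1e32; subtracting flips them to +1e32, and the
            # sign flip below then sends them to the bottom of the ranking.
            scores -= filter
        if not increase_loss:
            # We want to *decrease* the loss, so prefer the most negative dot products.
            scores *= -1
        _, top_k_ids = scores.topk(num_candidates)
    return top_k_ids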
def run_autoprompt(args, dataset, cache_test):
    if cache_test.is_test:
        raise CacheMiss()
    ct.set_seed(args.seed)
    global_data = GlobalData.from_pretrained(args.model_name)

    templatizer = utils.TriggerTemplatizer(
        args.template,
        global_data.config,
        global_data.tokenizer,
        label_field=args.label_field,
        label_map=dataset.label_map,
        tokenize_labels=args.tokenize_labels,
        add_special_tokens=True,
    )
    evaluation_fn = ct.AccuracyFn(global_data.tokenizer,
                                  dataset.label_map,
                                  global_data.device,
                                  tokenize_labels=args.tokenize_labels)

    # Do not allow for initial trigger specification.
    trigger_ids = [global_data.tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=global_data.device).unsqueeze(0)
    best_trigger_ids = trigger_ids.clone()

    # Load datasets
    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=global_data.tokenizer.pad_token_id)
    try:
        train_dataset = load_trigger_dataset(dataset.train, templatizer)
    except KeyError as e:
        raise RuntimeError(
            'A field in your template is not present in the uploaded dataset. '
            f'Check that there is a column with the name: {e}')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.bsz,
                                               shuffle=True,
                                               collate_fn=collator)

    progress = st.progress(0.0)
    trigger_placeholder = st.empty()

    best_dev_metric = -float('inf')
    for i in range(args.iters):
        logger.info(f'Iteration: {i}')
        progress.progress(float(i) / args.iters)

        current_trigger = ','.join(
            global_data.tokenizer.convert_ids_to_tokens(
                best_trigger_ids.squeeze(0)))
        trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

        global_data.model.zero_grad()
        train_iter = iter(train_loader)
        averaged_grad = None

        # Compute gradient of loss
        for step in range(args.accumulation_steps):
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.')
                break
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            predict_logits = global_data.predictor(model_inputs, trigger_ids)
            loss = ct.get_loss(predict_logits, labels).mean()
            loss.backward()

            grad = global_data.embedding_gradient.get()
            bsz, _, emb_dim = grad.size()
            selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
            grad = torch.masked_select(grad, selection_mask)
            grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)

            if averaged_grad is None:
                averaged_grad = grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += grad.sum(dim=0) / args.accumulation_steps

        logger.info('Evaluating Candidates')
        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)

        # Flip one token per iteration, cycling through the trigger positions.
        token_to_flip = i % templatizer.num_trigger_tokens
        candidates = ct.hotflip_attack(averaged_grad[token_to_flip],
                                       global_data.embeddings.weight,
                                       increase_loss=False,
                                       num_candidates=args.num_cand)

        current_score = 0
        candidate_scores = torch.zeros(args.num_cand, device=global_data.device)
        denom = 0
        for step in pbar:
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.')
                break
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            with torch.no_grad():
                predict_logits = global_data.predictor(model_inputs, trigger_ids)
                eval_metric = evaluation_fn(predict_logits, labels)

            # Update current score
            current_score += eval_metric.sum()
            denom += labels.size(0)

            # NOTE: We flip just one token per iteration (rather than all of
            # them) so the gradients don't get stale.
            for idx, candidate in enumerate(candidates):
                temp_trigger = trigger_ids.clone()
                temp_trigger[:, token_to_flip] = candidate
                with torch.no_grad():
                    predict_logits = global_data.predictor(model_inputs, temp_trigger)
                    eval_metric = evaluation_fn(predict_logits, labels)
                candidate_scores[idx] += eval_metric.sum()

        if (candidate_scores >= current_score).any():
            logger.info('Better trigger detected.')
            best_candidate_score = candidate_scores.max()
            best_candidate_idx = candidate_scores.argmax()
            trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
            logger.info(
                f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')

        # Skip dev evaluation; the latest trigger is kept as the best one.
        best_trigger_ids = trigger_ids.clone()

    progress.progress(1.0)
    current_trigger = ','.join(
        global_data.tokenizer.convert_ids_to_tokens(
            best_trigger_ids.squeeze(0)))
    trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

    best_trigger_tokens = global_data.tokenizer.convert_ids_to_tokens(
        best_trigger_ids.squeeze(0))
    train_output = predict_test(map(lambda x: x['sentence'], dataset.train),
                                dataset.label_map, templatizer,
                                best_trigger_ids, global_data.tokenizer,
                                global_data.predictor, args)

    # Streamlit does not like accessing widgets across functions, which is
    # problematic for this "live updating" widget that we want to keep
    # displaying even when the train output is cached. To get around this, we
    # delete the widget here and replace it with a very similar-looking widget
    # outside the function...no one will ever notice ;)
    trigger_placeholder.empty()

    return (best_trigger_tokens, current_score / denom, dataset.label_map,
            templatizer, best_trigger_ids, global_data.tokenizer,
            global_data.predictor, args, train_output)
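# Both trigger-search loops above read the embedding gradient out of an
# `embedding_gradient` storage object rather than from `.grad` directly. A
# minimal sketch of such a storage, assuming a backward hook registered on the
# word-embedding module (the actual GlobalData/GradientStorage plumbing may
# differ):
import torch


class GradientStorageSketch:
    """Captures the gradient flowing out of a module during backward()."""

    def __init__(self, module: torch.nn.Module):
        self._stored_gradient = None
        module.register_full_backward_hook(self._hook)

    def _hook(self, module, grad_input, grad_output):
        # For an embedding layer, grad_output[0] has shape (bsz, seq_len, emb_dim).
        self._stored_gradient = grad_output[0]

    def get(self):
        return self._stored_gradient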
def main(args):
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset, label_map = utils.load_classification_dataset(
        args.train,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        limit=args.limit)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_dataset, _ = utils.load_classification_dataset(
        args.dev,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        label_map)
    dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    test_dataset, _ = utils.load_classification_dataset(
        args.test,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        label_map)
    test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)

    # Only the classification head is trained; the encoder parameters are never
    # updated since they are not passed to the optimizer.
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=args.lr, weight_decay=1e-6)

    # if not args.ckpt_dir.exists():
    #     logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
    #     args.ckpt_dir.mkdir(parents=True)
    # elif not args.force_overwrite:
    #     raise RuntimeError('Checkpoint directory already exists.')

    best_accuracy = 0
    for epoch in range(args.epochs):
        logger.info('Training...')
        model.train()
        avg_loss = utils.ExponentialMovingAverage()
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            optimizer.zero_grad()
            logits, *_ = model(**model_inputs)
            loss = F.cross_entropy(logits, labels.squeeze(-1))
            loss.backward()
            optimizer.step()
            avg_loss.update(loss.item())
            pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for model_inputs, labels in dev_loader:
                model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                labels = labels.to(device)
                logits, *_ = model(**model_inputs)
                _, preds = logits.max(dim=-1)
                correct += (preds == labels.squeeze(-1)).sum().item()
                total += labels.size(0)
        accuracy = correct / (total + 1e-13)
        logger.info(f'Accuracy: {accuracy : 0.4f}')

        if accuracy > best_accuracy:
            logger.info('Best performance so far.')
            # torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME)
            # model.config.to_json_file(args.ckpt_dir / CONFIG_NAME)
            # tokenizer.save_pretrained(args.ckpt_dir)
            best_accuracy = accuracy

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for model_inputs, labels in test_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            logits, *_ = model(**model_inputs)
            _, preds = logits.max(dim=-1)
            correct += (preds == labels.squeeze(-1)).sum().item()
            total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
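# `utils.ExponentialMovingAverage` is only used here to smooth the loss shown
# in the progress bar. A minimal sketch of such a tracker; the smoothing
# constant and exact update rule are assumptions, not the repo's actual values:
class EMAMetricSketch:
    """Exponentially weighted moving average of a scalar metric."""

    def __init__(self, weight=0.3):
        self._weight = weight
        self._value = 0.0
        self._initialized = False

    def update(self, x):
        if not self._initialized:
            # Seed with the first observation to avoid a cold-start bias.
            self._value = x
            self._initialized = True
        else:
            self._value = self._weight * x + (1 - self._weight) * self._value

    def get_metric(self):
        return self._value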
def main(args):
    logger.info("Dataset: %s", str(args.train).split("/")[3])
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelWithLMHead.from_pretrained(args.model_name, config=config)
    if args.model_name == "bert-base-cased":
        model.embeds = model.bert.embeddings.word_embeddings
        eos_idx = 102  # BERT's [SEP] token id
        if not args.finetune:
            for param in model.bert.parameters():
                param.requires_grad = False
    elif args.model_name == "roberta-base":
        model.embeds = model.roberta.embeddings.word_embeddings
        eos_idx = tokenizer.eos_token_id
        if not args.finetune:
            for param in model.roberta.parameters():
                param.requires_grad = False
    if not args.finetune:
        for param in model.parameters():
            param.requires_grad = False
    # The continuous trigger: one trainable embedding per trigger position.
    # Created after the freeze above, so it remains trainable either way.
    model.relation_embeds = torch.nn.Parameter(
        torch.rand(args.trigger_length, model.embeds.weight.shape[1], requires_grad=True))
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset = utils.load_continuous_trigger_dataset(
        args.train, tokenizer, args.field_a, args.field_b, args.label_field, limit=args.limit)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_dataset = utils.load_continuous_trigger_dataset(
        args.dev, tokenizer, args.field_a, args.field_b, args.label_field)
    dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    test_dataset = utils.load_continuous_trigger_dataset(
        args.test, tokenizer, args.field_a, args.field_b, args.label_field)
    test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)

    best_accuracy = 0
    for epoch in range(args.epochs):
        logger.info('Training...')
        model.train()
        avg_loss = utils.ExponentialMovingAverage()
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            # Prediction position: the EOS/[SEP] location, shifted by the number
            # of trigger embeddings that generate_inputs_embeds inserts.
            mask_token_idxs = (model_inputs["input_ids"] == eos_idx).nonzero()[:, 1] + args.trigger_length
            model_inputs = generate_inputs_embeds(model_inputs, model, tokenizer, eos_idx)
            labels = labels.to(device)[:, 1]
            optimizer.zero_grad()
            logits, *_ = model(**model_inputs)
            mask_logits = logits[torch.arange(0, logits.shape[0], dtype=torch.long), mask_token_idxs]
            loss = F.cross_entropy(mask_logits, labels)
            loss.backward()
            optimizer.step()
            avg_loss.update(loss.item())
            pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for model_inputs, labels in dev_loader:
                model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                mask_token_idxs = (model_inputs["input_ids"] == eos_idx).nonzero()[:, 1] + args.trigger_length
                model_inputs = generate_inputs_embeds(model_inputs, model, tokenizer, eos_idx)
                labels = labels.to(device)[:, 1]
                logits, *_ = model(**model_inputs)
                mask_logits = logits[torch.arange(0, logits.shape[0], dtype=torch.long), mask_token_idxs]
                preds = torch.topk(mask_logits, 1, dim=1).indices[:, 0]
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        accuracy = correct / (total + 1e-13)
        logger.info(f'Accuracy: {accuracy : 0.4f}')

        if accuracy > best_accuracy:
            logger.info('Best performance so far.')
            # torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME)
            # model.config.to_json_file(args.ckpt_dir / CONFIG_NAME)
            # tokenizer.save_pretrained(args.ckpt_dir)
            best_accuracy = accuracy
    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    # TODO: currently testing on the last model, not the best validation model.
    with torch.no_grad():
        for model_inputs, labels in test_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            mask_token_idxs = (model_inputs["input_ids"] == eos_idx).nonzero()[:, 1] + args.trigger_length
            model_inputs = generate_inputs_embeds(model_inputs, model, tokenizer, eos_idx)
            labels = labels.to(device)[:, 1]
            logits, *_ = model(**model_inputs)
            mask_logits = logits[torch.arange(0, logits.shape[0], dtype=torch.long), mask_token_idxs]
            preds = torch.topk(mask_logits, 1, dim=1).indices[:, 0]
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
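# `generate_inputs_embeds` is not shown in this excerpt. A minimal sketch of
# what such a helper could look like, assuming the continuous trigger is
# spliced in directly before each example's EOS/[SEP] position; the exact
# splice point in the real helper may differ, and it must be consistent with
# the `mask_token_idxs` offset computed above. Other fields (e.g.
# token_type_ids) would need the same widening treatment.
import torch


def generate_inputs_embeds_sketch(model_inputs, model, tokenizer, eos_idx):
    input_ids = model_inputs.pop('input_ids')
    attention_mask = model_inputs.pop('attention_mask')
    token_embeds = model.embeds(input_ids)  # (bsz, seq_len, emb_dim)
    trigger_len = model.relation_embeds.size(0)
    embeds_out, mask_out = [], []
    for i in range(input_ids.size(0)):
        # Splice the trainable trigger embeddings in before this example's EOS/[SEP].
        eos_pos = (input_ids[i] == eos_idx).nonzero()[0, 0]
        embeds_out.append(torch.cat([
            token_embeds[i, :eos_pos],
            model.relation_embeds,
            token_embeds[i, eos_pos:],
        ]))
        # Widen the attention mask to cover the inserted trigger positions.
        mask_out.append(torch.cat([
            attention_mask[i, :eos_pos],
            attention_mask.new_ones(trigger_len),
            attention_mask[i, eos_pos:],
        ]))
    model_inputs['inputs_embeds'] = torch.stack(embeds_out)
    model_inputs['attention_mask'] = torch.stack(mask_out)
    return model_inputs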
def run_autoprompt(args, dataset):
    ct.set_seed(args.seed)
    global_data = GlobalData.from_pretrained(args.model_name)

    templatizer = utils.TriggerTemplatizer(
        args.template,
        global_data.config,
        global_data.tokenizer,
        label_field=args.label_field,
        label_map=dataset.label_map,
        tokenize_labels=args.tokenize_labels,
        add_special_tokens=False,
    )
    evaluation_fn = ct.AccuracyFn(global_data.tokenizer, dataset.label_map,
                                  global_data.device)

    # Do not allow for initial trigger specification.
    trigger_ids = [global_data.tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=global_data.device).unsqueeze(0)
    best_trigger_ids = trigger_ids.clone()

    # Load datasets
    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=global_data.tokenizer.pad_token_id)
    train_dataset = load_trigger_dataset(dataset.train, templatizer)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.bsz,
                                               shuffle=True,
                                               collate_fn=collator)
    dev_dataset = load_trigger_dataset(dataset.dev, templatizer)
    dev_loader = torch.utils.data.DataLoader(dev_dataset,
                                             batch_size=args.eval_size,
                                             shuffle=False,
                                             collate_fn=collator)

    progress = st.progress(0.0)
    trigger_placeholder = st.empty()

    best_dev_metric = -float('inf')
    for i in range(args.iters):
        logger.info(f'Iteration: {i}')
        progress.progress(float(i) / args.iters)

        current_trigger = ','.join(
            global_data.tokenizer.convert_ids_to_tokens(
                best_trigger_ids.squeeze(0)))
        trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

        global_data.model.zero_grad()
        train_iter = iter(train_loader)
        averaged_grad = None

        # Compute gradient of loss
        for step in range(args.accumulation_steps):
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.')
                break
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            predict_logits = global_data.predictor(model_inputs, trigger_ids)
            loss = ct.get_loss(predict_logits, labels).mean()
            loss.backward()

            grad = global_data.embedding_gradient.get()
            bsz, _, emb_dim = grad.size()
            selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
            grad = torch.masked_select(grad, selection_mask)
            grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)

            if averaged_grad is None:
                averaged_grad = grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += grad.sum(dim=0) / args.accumulation_steps

        logger.info('Evaluating Candidates')
        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)

        # Flip one token per iteration, cycling through the trigger positions.
        token_to_flip = i % templatizer.num_trigger_tokens
        candidates = ct.hotflip_attack(averaged_grad[token_to_flip],
                                       global_data.embeddings.weight,
                                       increase_loss=False,
                                       num_candidates=args.num_cand)

        current_score = 0
        candidate_scores = torch.zeros(args.num_cand, device=global_data.device)
        denom = 0
        for step in pbar:
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.')
                break
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            with torch.no_grad():
                predict_logits = global_data.predictor(model_inputs, trigger_ids)
                eval_metric = evaluation_fn(predict_logits, labels)

            # Update current score
            current_score += eval_metric.sum()
            denom += labels.size(0)

            # NOTE: We flip just one token per iteration (rather than all of
            # them) so the gradients don't get stale.
            for idx, candidate in enumerate(candidates):
                temp_trigger = trigger_ids.clone()
                temp_trigger[:, token_to_flip] = candidate
                with torch.no_grad():
                    predict_logits = global_data.predictor(model_inputs, temp_trigger)
                    eval_metric = evaluation_fn(predict_logits, labels)
                candidate_scores[idx] += eval_metric.sum()

        if (candidate_scores > current_score).any():
            logger.info('Better trigger detected.')
            best_candidate_score = candidate_scores.max()
            best_candidate_idx = candidate_scores.argmax()
            trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
            logger.info(
                f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')

        logger.info('Evaluating')
        numerator = 0
        denominator = 0
        for model_inputs, labels in tqdm(dev_loader):
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            with torch.no_grad():
                predict_logits = global_data.predictor(model_inputs, trigger_ids)
            numerator += evaluation_fn(predict_logits, labels).sum().item()
            denominator += labels.size(0)
        dev_metric = numerator / (denominator + 1e-13)

        if dev_metric > best_dev_metric:
            logger.info('Best performance so far')
            best_trigger_ids = trigger_ids.clone()
            best_dev_metric = dev_metric

    progress.progress(1.0)
    current_trigger = ','.join(
        global_data.tokenizer.convert_ids_to_tokens(
            best_trigger_ids.squeeze(0)))
    trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

    best_trigger_tokens = global_data.tokenizer.convert_ids_to_tokens(
        best_trigger_ids.squeeze(0))
    dev_output = predict_test(map(lambda x: x['sentence'], dataset.dev),
                              dataset.label_map, templatizer, best_trigger_ids,
                              global_data.tokenizer, global_data.predictor,
                              args)
    st.dataframe(pd.DataFrame(dev_output).style.highlight_min(axis=1))

    return (best_trigger_tokens, best_dev_metric, dataset.label_map,
            templatizer, best_trigger_ids, global_data.tokenizer,
            global_data.predictor, args)
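# `evaluation_fn` above is a ct.AccuracyFn. A rough sketch of the idea,
# assuming single-token label words and predict_logits of shape (bsz, vocab);
# the real class also supports multi-token labels via `tokenize_labels`:
import torch


class SingleTokenAccuracySketch:
    def __init__(self, tokenizer, label_map, device):
        # One vocabulary id per class, in a fixed order.
        self._label_ids = torch.tensor(
            [tokenizer.convert_tokens_to_ids(word) for word in label_map.values()],
            device=device)

    def __call__(self, predict_logits, gold_ids):
        # Restrict the vocabulary distribution to the label words and take the argmax.
        label_scores = predict_logits[:, self._label_ids]     # (bsz, num_classes)
        preds = self._label_ids[label_scores.argmax(dim=-1)]  # predicted token ids
        # Per-example 0/1 scores, so callers can .sum() over the batch.
        return (preds == gold_ids.squeeze(-1)).float()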
def main(args):
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset, label_map = utils.load_classification_dataset(
        args.train,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        limit=args.limit
    )
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_dataset, _ = utils.load_classification_dataset(
        args.dev,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        label_map
    )
    dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    test_dataset, _ = utils.load_classification_dataset(
        args.test,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        label_map
    )
    test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)

    if args.bias_correction:
        betas = (0.9, 0.999)
    else:
        betas = (0.0, 0.0)
    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=1e-2,
        betas=betas
    )

    # Use suggested learning rate scheduler: linear warmup over the first 10%
    # of training steps, then linear decay.
    num_training_steps = len(train_dataset) * args.epochs // args.bsz
    num_warmup_steps = num_training_steps // 10
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

    if not args.ckpt_dir.exists():
        logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
        args.ckpt_dir.mkdir(parents=True)
    elif not args.force_overwrite:
        raise RuntimeError('Checkpoint directory already exists.')

    try:
        best_accuracy = 0
        for epoch in range(args.epochs):
            logger.info('Training...')
            model.train()
            avg_loss = utils.ExponentialMovingAverage()
            pbar = tqdm(train_loader)
            for model_inputs, labels in pbar:
                model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                labels = labels.to(device)
                optimizer.zero_grad()
                logits, *_ = model(**model_inputs)
                loss = F.cross_entropy(logits, labels.squeeze(-1))
                loss.backward()
                optimizer.step()
                scheduler.step()
                avg_loss.update(loss.item())
                pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}, '
                                     f'lr: {optimizer.param_groups[0]["lr"]: .3e}')

            logger.info('Evaluating...')
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for model_inputs, labels in dev_loader:
                    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                    labels = labels.to(device)
                    logits, *_ = model(**model_inputs)
                    _, preds = logits.max(dim=-1)
                    correct += (preds == labels.squeeze(-1)).sum().item()
                    total += labels.size(0)
            accuracy = correct / (total + 1e-13)
            logger.info(f'Accuracy: {accuracy : 0.4f}')

            if accuracy > best_accuracy:
                logger.info('Best performance so far.')
                model.save_pretrained(args.ckpt_dir)
                tokenizer.save_pretrained(args.ckpt_dir)
                best_accuracy = accuracy
    except KeyboardInterrupt:
        logger.info('Interrupted...')

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for model_inputs, labels in test_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            logits, *_ = model(**model_inputs)
            _, preds = logits.max(dim=-1)
            correct += (preds == labels.squeeze(-1)).sum().item()
            total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
def main(args):
    ct.set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = load_pretrained(args.model_name)
    model.to(device)
    final_embeddings = get_final_embeddings(model)
    embedding_storage = utils.OutputStorage(final_embeddings)
    word_embeddings = get_word_embeddings(model)

    label_map = json.loads(args.label_map)
    reverse_label_map = {y: x for x, y in label_map.items()}
    templatizer = utils.TriggerTemplatizer(
        args.template,
        tokenizer,
        label_map=label_map,
        label_field=args.label_field,
        add_special_tokens=False
    )

    # The weights of this projection will help identify the best label words.
    projection = torch.nn.Linear(config.hidden_size, len(label_map))
    projection.to(device)

    # Obtain the initial trigger tokens and label mapping
    if args.initial_trigger:
        trigger_ids = tokenizer.encode(
            args.initial_trigger,
            add_special_tokens=False,
            add_prefix_space=True
        )
        assert len(trigger_ids) == templatizer.num_trigger_tokens
    else:
        trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)

    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset = utils.load_trigger_dataset(args.train, templatizer)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    optimizer = torch.optim.Adam(projection.parameters(), lr=args.lr)

    # Log the initial top-k label words for each class (before training).
    scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
    scores = F.softmax(scores, dim=0)
    for class_idx, row in enumerate(scores):
        _, top = row.topk(args.k)
        decoded = tokenizer.convert_ids_to_tokens(top)
        logger.info(f"Top k for class {reverse_label_map[class_idx]}: {', '.join(decoded)}")

    logger.info('Training')
    for i in range(args.iters):
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            optimizer.zero_grad()
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)

            trigger_mask = model_inputs.pop('trigger_mask')
            predict_mask = model_inputs.pop('predict_mask')
            model_inputs = ct.replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)

            # Only the projection is trained; the encoder runs without gradients.
            with torch.no_grad():
                model(**model_inputs)
            embeddings = embedding_storage.get()
            predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            loss = F.cross_entropy(logits, labels.squeeze(-1))
            loss.backward()
            optimizer.step()
            pbar.set_description(f'loss: {loss.item(): 0.4f}')

        scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
        scores = F.softmax(scores, dim=0)
        for class_idx, row in enumerate(scores):
            _, top = row.topk(args.k)
            decoded = tokenizer.convert_ids_to_tokens(top)
            logger.info(f"Top k for class {reverse_label_map[class_idx]}: {', '.join(decoded)}")
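# `ct.replace_trigger_tokens` splices the current trigger ids into the batch at
# the positions flagged by `trigger_mask`. A minimal sketch, assuming the mask
# marks exactly num_trigger_tokens positions per example and `trigger_ids` has
# shape (1, num_trigger_tokens); this is an illustration, not necessarily the
# codebase's exact implementation:
import torch


def replace_trigger_tokens_sketch(model_inputs, trigger_ids, trigger_mask):
    out = model_inputs.copy()
    input_ids = model_inputs['input_ids'].clone()
    # Tile the trigger across the batch, then scatter it into the masked slots.
    # masked_scatter_ fills mask positions row by row, matching the tiled order.
    filled = trigger_ids.repeat(trigger_mask.size(0), 1)
    input_ids.masked_scatter_(trigger_mask, filled)
    out['input_ids'] = input_ids
    return out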