def main(args): logger.info("Dataset: %s" % str(args.train).split("/")[3]) set_seed(args.seed) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') config = AutoConfig.from_pretrained(args.model_name) tokenizer = AutoTokenizer.from_pretrained(args.model_name) model = AutoModelWithLMHead.from_pretrained(args.model_name, config=config) if args.model_name == "bert-base-cased": model.embeds = model.bert.embeddings.word_embeddings eos_idx = 102 if not args.finetune: for param in model.bert.parameters(): param.requires_grad = False elif args.model_name == "roberta-base": model.embeds = model.roberta.embeddings.word_embeddings eos_idx = tokenizer.eos_token_id if not args.finetune: for param in model.roberta.parameters(): param.requires_grad = False if not args.finetune: for param in model.parameters(): param.requires_grad = False model.relation_embeds = torch.nn.Parameter( torch.rand(args.trigger_length, model.embeds.weight.shape[1], requires_grad=True)) model.to(device) collator = utils.Collator(pad_token_id=tokenizer.pad_token_id) train_dataset = utils.load_continuous_trigger_dataset(args.train, tokenizer, args.field_a, args.field_b, args.label_field, limit=args.limit) train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator) dev_dataset = utils.load_continuous_trigger_dataset( args.dev, tokenizer, args.field_a, args.field_b, args.label_field) dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator) test_dataset = utils.load_continuous_trigger_dataset( args.test, tokenizer, args.field_a, args.field_b, args.label_field) test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6) best_accuracy = 0 for epoch in range(args.epochs): logger.info('Training...') model.train() avg_loss = utils.ExponentialMovingAverage() pbar = tqdm(train_loader) for model_inputs, labels in pbar: model_inputs = {k: v.to(device) for k, v in model_inputs.items()} mask_token_idxs = (model_inputs["input_ids"] == eos_idx).nonzero()[:, 1] + args.trigger_length model_inputs = generate_inputs_embeds(model_inputs, model, tokenizer, eos_idx) labels = labels.to(device)[:, 1] optimizer.zero_grad() logits, *_ = model(**model_inputs) mask_logits = logits[ torch.arange(0, logits.shape[0], dtype=torch.long), mask_token_idxs] loss = F.cross_entropy(mask_logits, labels) loss.backward() optimizer.step() avg_loss.update(loss.item()) pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}') logger.info('Evaluating...') model.eval() correct = 0 total = 0 for model_inputs, labels in dev_loader: model_inputs = {k: v.to(device) for k, v in model_inputs.items()} mask_token_idxs = (model_inputs["input_ids"] == eos_idx).nonzero()[:, 1] + args.trigger_length model_inputs = generate_inputs_embeds(model_inputs, model, tokenizer, eos_idx) labels = labels.to(device)[:, 1] logits, *_ = model(**model_inputs) mask_logits = logits[ torch.arange(0, logits.shape[0], dtype=torch.long), mask_token_idxs] preds = torch.topk(mask_logits, 1, dim=1).indices[:, 0] correct += (preds == labels).sum().item() total += labels.size(0) accuracy = correct / (total + 1e-13) logger.info(f'Accuracy: {accuracy : 0.4f}') if accuracy > best_accuracy: logger.info('Best performance so far.') # torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME) # model.config.to_json_file(args.ckpt_dir / CONFIG_NAME) # tokenizer.save_pretrained(args.ckpt_dir) best_accuracy = accuracy 
    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    # TODO: currently testing on the last model, not the best validation model.
    for model_inputs, labels in test_loader:
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        mask_token_idxs = (model_inputs["input_ids"] == eos_idx).nonzero()[:, 1] + args.trigger_length
        model_inputs = generate_inputs_embeds(model_inputs, model, tokenizer, eos_idx)
        labels = labels.to(device)[:, 1]
        logits, *_ = model(**model_inputs)
        mask_logits = logits[
            torch.arange(0, logits.shape[0], dtype=torch.long), mask_token_idxs]
        preds = torch.topk(mask_logits, 1, dim=1).indices[:, 0]
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
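# `generate_inputs_embeds` is defined elsewhere in the repository and is not shown in this
# section. The sketch below is only an illustration of what such a helper could look like,
# under the assumption that the trainable trigger embeddings (`model.relation_embeds`) are
# spliced in right after the leading [CLS]/<s> token; that is consistent with the mask
# positions above being shifted by `args.trigger_length`, but the real implementation may differ.
def generate_inputs_embeds_sketch(model_inputs, model, tokenizer, eos_idx):
    # `tokenizer` and `eos_idx` mirror the call site above; this simplified sketch does not use them.
    input_ids = model_inputs['input_ids']
    attention_mask = model_inputs['attention_mask']
    batch_size = input_ids.size(0)
    trigger_length = model.relation_embeds.size(0)

    # Look up ordinary word embeddings for the tokenized inputs.
    token_embeds = model.embeds(input_ids)

    # Insert the trigger embeddings after the first (special) token of every sequence.
    trigger_embeds = model.relation_embeds.unsqueeze(0).expand(batch_size, -1, -1)
    inputs_embeds = torch.cat(
        [token_embeds[:, :1], trigger_embeds, token_embeds[:, 1:]], dim=1)

    # Extend the attention mask so the inserted trigger positions are attended to.
    trigger_mask = attention_mask.new_ones(batch_size, trigger_length)
    attention_mask = torch.cat(
        [attention_mask[:, :1], trigger_mask, attention_mask[:, 1:]], dim=1)

    return {'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}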
def main(args):
    set_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset, label_map = utils.load_classification_dataset(
        args.train, tokenizer, args.field_a, args.field_b, args.label_field, limit=args.limit)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_dataset, _ = utils.load_classification_dataset(
        args.dev, tokenizer, args.field_a, args.field_b, args.label_field, label_map)
    dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    test_dataset, _ = utils.load_classification_dataset(
        args.test, tokenizer, args.field_a, args.field_b, args.label_field, label_map)
    test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    # Only the classification head's parameters are given to the optimizer, so the
    # pre-trained encoder weights are never updated (a probing baseline).
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=args.lr, weight_decay=1e-6)

    # if not args.ckpt_dir.exists():
    #     logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
    #     args.ckpt_dir.mkdir(parents=True)
    # elif not args.force_overwrite:
    #     raise RuntimeError('Checkpoint directory already exists.')

    best_accuracy = 0
    for epoch in range(args.epochs):
        logger.info('Training...')
        model.train()
        avg_loss = utils.ExponentialMovingAverage()
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            optimizer.zero_grad()
            logits, *_ = model(**model_inputs)
            loss = F.cross_entropy(logits, labels.squeeze(-1))
            loss.backward()
            optimizer.step()
            avg_loss.update(loss.item())
            pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0
        total = 0
        for model_inputs, labels in dev_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            logits, *_ = model(**model_inputs)
            _, preds = logits.max(dim=-1)
            correct += (preds == labels.squeeze(-1)).sum().item()
            total += labels.size(0)
        accuracy = correct / (total + 1e-13)
        logger.info(f'Accuracy: {accuracy : 0.4f}')

        if accuracy > best_accuracy:
            logger.info('Best performance so far.')
            # torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME)
            # model.config.to_json_file(args.ckpt_dir / CONFIG_NAME)
            # tokenizer.save_pretrained(args.ckpt_dir)
            best_accuracy = accuracy

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    for model_inputs, labels in test_loader:
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        labels = labels.to(device)
        logits, *_ = model(**model_inputs)
        _, preds = logits.max(dim=-1)
        correct += (preds == labels.squeeze(-1)).sum().item()
        total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
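# `utils.ExponentialMovingAverage`, used to report the smoothed training loss in every loop
# above, comes from the repository's utils module and is not shown here. A rough sketch of
# such a tracker, assuming a standard exponential moving average with bias correction and a
# fixed decay (the decay value and the bias-correction detail are assumptions; the real
# implementation may differ):
class ExponentialMovingAverageSketch:
    def __init__(self, decay=0.99):
        self._decay = decay
        self._value = 0.0
        self._bias = 0.0  # corrects for the zero initialization of the running average

    def update(self, new_value):
        self._value = self._decay * self._value + (1 - self._decay) * new_value
        self._bias = self._decay * self._bias + (1 - self._decay)

    def get_metric(self):
        return self._value / (self._bias + 1e-13)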
def main(args):
    set_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset, label_map = utils.load_classification_dataset(
        args.train, tokenizer, args.field_a, args.field_b, args.label_field, limit=args.limit)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_dataset, _ = utils.load_classification_dataset(
        args.dev, tokenizer, args.field_a, args.field_b, args.label_field, label_map)
    dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    test_dataset, _ = utils.load_classification_dataset(
        args.test, tokenizer, args.field_a, args.field_b, args.label_field, label_map)
    test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)

    # When bias correction is disabled, the Adam beta coefficients are also zeroed out.
    if args.bias_correction:
        betas = (0.9, 0.999)
    else:
        betas = (0.0, 0.0)
    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=1e-2,
        betas=betas)

    # Use the suggested learning rate scheduler: linear warmup over the first 10% of
    # training steps, then linear decay.
    num_training_steps = len(train_dataset) * args.epochs // args.bsz
    num_warmup_steps = num_training_steps // 10
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

    if not args.ckpt_dir.exists():
        logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
        args.ckpt_dir.mkdir(parents=True)
    elif not args.force_overwrite:
        raise RuntimeError('Checkpoint directory already exists.')

    try:
        best_accuracy = 0
        for epoch in range(args.epochs):
            logger.info('Training...')
            model.train()
            avg_loss = utils.ExponentialMovingAverage()
            pbar = tqdm(train_loader)
            for model_inputs, labels in pbar:
                model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                labels = labels.to(device)
                optimizer.zero_grad()
                logits, *_ = model(**model_inputs)
                loss = F.cross_entropy(logits, labels.squeeze(-1))
                loss.backward()
                optimizer.step()
                scheduler.step()
                avg_loss.update(loss.item())
                pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}, '
                                     f'lr: {optimizer.param_groups[0]["lr"]: .3e}')

            logger.info('Evaluating...')
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for model_inputs, labels in dev_loader:
                    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                    labels = labels.to(device)
                    logits, *_ = model(**model_inputs)
                    _, preds = logits.max(dim=-1)
                    correct += (preds == labels.squeeze(-1)).sum().item()
                    total += labels.size(0)
            accuracy = correct / (total + 1e-13)
            logger.info(f'Accuracy: {accuracy : 0.4f}')

            if accuracy > best_accuracy:
                logger.info('Best performance so far.')
                model.save_pretrained(args.ckpt_dir)
                tokenizer.save_pretrained(args.ckpt_dir)
                best_accuracy = accuracy
    except KeyboardInterrupt:
        logger.info('Interrupted...')

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for model_inputs, labels in test_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            logits, *_ = model(**model_inputs)
            _, preds = logits.max(dim=-1)
            correct += (preds == labels.squeeze(-1)).sum().item()
            total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
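# The argument parsers for these scripts are defined elsewhere and are not reproduced in this
# section. The entry point below is only a sketch of how the fine-tuning `main` above could be
# invoked: the flag names are inferred from the attributes the code accesses, while the types
# and defaults are assumptions and may not match the repository's actual CLI.
if __name__ == '__main__':
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='bert-base-cased')
    parser.add_argument('--train', type=Path, required=True)
    parser.add_argument('--dev', type=Path, required=True)
    parser.add_argument('--test', type=Path, required=True)
    parser.add_argument('--field-a', type=str, required=True)
    parser.add_argument('--field-b', type=str, default=None)
    parser.add_argument('--label-field', type=str, default='label')
    parser.add_argument('--num-labels', type=int, default=2)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--bsz', type=int, default=32)
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--ckpt-dir', type=Path, default=Path('ckpt'))
    parser.add_argument('--force-overwrite', action='store_true')
    parser.add_argument('--bias-correction', action='store_true')
    args = parser.parse_args()
    main(args)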