def get_averaged_grad(
    model: GPT2LMHeadModel,
    trigger_tokens: torch.Tensor,
    target_tokens: torch.Tensor,
    targets_embeddings: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the gradient of the trigger embeddings w.r.t. the loss,
    averaged across all targets.
    """
    num_targets = target_tokens.shape[0]
    trigger_length = trigger_tokens.shape[0]
    # Padded target positions are marked with -1.
    targets_padding_mask = target_tokens.eq(-1)
    model_input_embeddings = model.get_input_embeddings()
    # Detach the trigger embeddings from the embedding matrix so the gradient
    # accumulates on this tensor only.
    trigger_embeddings = (
        model_input_embeddings(trigger_tokens).detach().requires_grad_(True)
    )
    if targets_embeddings is None:
        # Replace padding ids with a valid token id before the embedding
        # lookup; those positions are masked out of attention and loss anyway.
        target_inputs = target_tokens.clone()
        target_inputs[targets_padding_mask] = 1
        targets_embeddings = model_input_embeddings(target_inputs)
    # Prepend the (shared) trigger embeddings to every target sequence.
    lm_input = torch.cat(
        [
            trigger_embeddings.unsqueeze(0).expand(
                num_targets, *trigger_embeddings.shape
            ),
            targets_embeddings,
        ],
        dim=1,
    )
    model.zero_grad()
    # Trigger positions always attend; padded target positions do not.
    attention_mask = torch.cat(
        [
            torch.ones(
                (num_targets, trigger_length),
                device=target_tokens.device,
                dtype=torch.bool,
            ),
            targets_padding_mask.logical_not(),
        ],
        dim=1,
    )
    lm_output = model(inputs_embeds=lm_input, attention_mask=attention_mask)
    logits = lm_output[0]
    # Logits that predict the target tokens: shifted left by one position.
    target_logits = logits[:, trigger_length - 1 : -1, :].reshape(
        num_targets * target_tokens.shape[1], -1
    )
    # Padding labels are -1 (see targets_padding_mask above), so the ignore
    # index must be -1 as well; the original -100 would make cross_entropy
    # fail on the -1 labels.
    loss = torch.nn.functional.cross_entropy(
        target_logits, target_tokens.view(-1), ignore_index=-1
    )
    loss.backward()
    embeddings_average_grad = trigger_embeddings.grad.detach()
    model.zero_grad()
    return embeddings_average_grad
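# --- Hedged example (not part of the original project) ----------------------
# `get_averaged_grad` is the kind of gradient used in a HotFlip-style trigger
# search (cf. Wallace et al., "Universal Adversarial Triggers"): the averaged
# gradient is projected onto the embedding matrix to rank candidate
# replacement tokens per trigger position. This is a minimal sketch of that
# candidate-ranking step only; the surrounding search loop (token swapping,
# re-scoring, beam search) is assumed to live elsewhere.
def hotflip_candidates(
    model: GPT2LMHeadModel,
    trigger_tokens: torch.Tensor,
    target_tokens: torch.Tensor,
    num_candidates: int = 10,
) -> torch.Tensor:
    grad = get_averaged_grad(model, trigger_tokens, target_tokens)
    embedding_matrix = model.get_input_embeddings().weight.detach()
    # First-order estimate of the loss change when replacing each trigger
    # position with every vocabulary token; the per-position constant term
    # cancels out of the ranking, so the dot product alone is enough.
    scores = torch.einsum("ij,vj->iv", grad, embedding_matrix)
    # Lower predicted loss is better, hence top-k of the negated scores.
    return (-scores).topk(num_candidates, dim=1).indices  # (trigger_len, k)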
def train(
    tokenizer: Tokenizer,
    model: GPT2LMHeadModel,
    args: TrainingArguments,
    writer: SummaryWriter,
    logger,
    test_dataset: MyDataset,
):
    train_dataset = MyDataset(get_corpus(args.corpus_path), tokenizer, args.block_size)
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    # num_training_steps = len(train_dataset) // args.train_batch_size * args.num_train_epochs
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer,
    #     num_warmup_steps=args.warmup_steps,
    #     num_training_steps=num_training_steps,
    # )
    i = 0
    # Create the output dir and its "best" subdir; makedirs also covers the
    # case where output_dir already exists but "best" does not.
    os.makedirs(args.output_dir + "/best", exist_ok=True)
    prev_loss = eval(tokenizer, model, test_dataset, args)
    logger.info(f"eval loss: {prev_loss}")
    writer.add_scalar('Loss/eval', prev_loss, i)
    train_loss = 0
    no_save_counter = 0
    for _ in range(args.num_train_epochs):
        iterator = build_data_iterator(
            tokenizer,
            train_dataset,
            args.train_batch_size,
            args.block_size,
            random_sampler=True,
        )
        for ids, attention_mask in tqdm(iterator, desc='train'):
            i += 1
            ids = ids.to(args.device)
            loss = model(
                ids, attention_mask=attention_mask.to(args.device), labels=ids
            )[0]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            # scheduler.step()
            model.zero_grad()
            writer.add_scalar('Loss/train', loss.item(), i)
            train_loss += loss.item()
            if i % args.save_steps == 0:
                model.save_pretrained(args.output_dir)
            if args.evaluate_during_training and i % args.logging_steps == 0:
                logger.info(f"epoch: {i / len(iterator)}")
                logger.info(f"train loss: {train_loss / args.logging_steps}")
                train_loss = 0
                # lr = scheduler.get_last_lr()[0]
                # logger.info(f"lr: {lr}")
                eval_loss = eval(tokenizer, model, test_dataset, args)
                logger.info(f"eval loss: {eval_loss}")
                writer.add_scalar('Loss/eval', eval_loss, i)
                # writer.add_scalar('LR', lr, i)
                if prev_loss > eval_loss:
                    # New best eval loss: keep a separate "best" checkpoint.
                    prev_loss = eval_loss
                    model.save_pretrained(args.output_dir + "/best")
                    no_save_counter = 0
                else:
                    no_save_counter += 1
                    logger.info(
                        f"model has not improved for {no_save_counter} evaluations "
                        f"in a row. best_eval: {prev_loss}"
                    )
    # Final evaluation and checkpoint after the last epoch.
    eval_loss = eval(tokenizer, model, test_dataset, args)
    logger.info(f"eval loss: {eval_loss}")
    model.save_pretrained(args.output_dir)
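# --- Hedged usage sketch (not part of the original project) -----------------
# `Tokenizer`, `MyDataset`, `get_corpus`, `eval`, and `TrainingArguments` are
# project-local, so a SimpleNamespace stands in for the real arguments object
# purely to show which attributes `train` reads. All paths and hyperparameter
# values below are placeholders, not project defaults.
if __name__ == "__main__":
    import logging
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)
    main_logger = logging.getLogger("train")

    run_args = SimpleNamespace(
        corpus_path="data/train.txt",
        output_dir="checkpoints",
        device="cuda" if torch.cuda.is_available() else "cpu",
        block_size=256,
        train_batch_size=8,
        learning_rate=5e-5,
        num_train_epochs=3,
        max_grad_norm=1.0,
        save_steps=1000,
        logging_steps=500,
        evaluate_during_training=True,
    )

    gpt2 = GPT2LMHeadModel.from_pretrained("gpt2").to(run_args.device)
    bpe_tokenizer = ...  # the project-specific Tokenizer, built elsewhere
    held_out = MyDataset(get_corpus("data/test.txt"), bpe_tokenizer, run_args.block_size)

    with SummaryWriter(log_dir=run_args.output_dir) as tb_writer:
        train(bpe_tokenizer, gpt2, run_args, tb_writer, main_logger, held_out)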