from typing import Optional

import torch
from transformers import GPT2LMHeadModel


def get_averaged_grad(
    model: GPT2LMHeadModel,
    trigger_tokens: torch.Tensor,
    target_tokens: torch.Tensor,
    targets_embeddings: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the gradient of the loss, averaged over all targets, with
    respect to the trigger token embeddings.
    """
    num_targets = target_tokens.shape[0]
    trigger_length = trigger_tokens.shape[0]
    target_length = target_tokens.shape[1]
    # Target sequences are padded with -1; mark the padded positions.
    targets_padding_mask = target_tokens.eq(-1)

    model_input_embeddings = model.get_input_embeddings()
    # Detach the trigger embeddings from the embedding matrix and track their gradient.
    trigger_embeddings = (
        model_input_embeddings(trigger_tokens).detach().requires_grad_(True)
    )
    if targets_embeddings is None:
        # Replace the -1 padding with a valid token id before the embedding lookup;
        # the padded positions are masked out of the attention and the loss anyway.
        target_inputs = target_tokens.clone()
        target_inputs[targets_padding_mask] = 1
        targets_embeddings = model_input_embeddings(target_inputs)
    # Prepend the shared trigger embeddings to every target sequence.
    lm_input = torch.cat(
        [
            trigger_embeddings.unsqueeze(0).expand(
                num_targets, *trigger_embeddings.shape
            ),
            targets_embeddings,
        ],
        dim=1,
    )
    model.zero_grad()
    # Attend to every trigger position; mask out the padded target positions.
    attention_mask = torch.cat(
        [
            torch.ones(
                (num_targets, trigger_length),
                device=target_tokens.device,
                dtype=torch.bool,
            ),
            targets_padding_mask.logical_not(),
        ],
        dim=1,
    )
    lm_output = model(inputs_embeds=lm_input, attention_mask=attention_mask)
    logits = lm_output[0]
    # The logit at position i predicts the token at position i + 1, so the
    # predictions for the target tokens sit in the slice [trigger_length - 1, -1).
    target_logits = logits[:, trigger_length - 1 : -1, :].reshape(
        num_targets * target_length, -1
    )
    # Padded target positions still hold -1 and are skipped by the loss.
    loss = torch.nn.functional.cross_entropy(
        target_logits, target_tokens.view(-1), ignore_index=-1
    )
    loss.backward()
    embeddings_average_grad = trigger_embeddings.grad.detach()
    model.zero_grad()
    return embeddings_average_grad
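
A minimal sketch of how get_averaged_grad might be invoked, assuming a stock GPT-2 model and tokenizer; the target strings, the -1 padding of the targets, and the three-token trigger initialisation are illustrative assumptions, not part of the original code:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Hypothetical target strings, padded to a common length with -1 as expected above.
targets = ["this movie was terrible", "worst film ever made"]
encoded = [tokenizer.encode(t) for t in targets]
max_len = max(len(e) for e in encoded)
target_tokens = torch.tensor([e + [-1] * (max_len - len(e)) for e in encoded])

# Start from a trigger made of three copies of an arbitrary placeholder token.
trigger_tokens = torch.tensor([tokenizer.encode("the")[0]] * 3)

grad = get_averaged_grad(model, trigger_tokens, target_tokens)
print(grad.shape)  # (trigger_length, embedding_dim), e.g. torch.Size([3, 768])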
Example #2
import os

import torch
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import GPT2LMHeadModel

# MyDataset, Tokenizer, TrainingArguments, get_corpus, build_data_iterator and
# eval are project-specific helpers assumed to be defined alongside this code.


def train(tokenizer: Tokenizer, model: GPT2LMHeadModel,
          args: TrainingArguments, writer: SummaryWriter, logger,
          test_dataset: MyDataset):
    train_dataset = MyDataset(get_corpus(args.corpus_path), tokenizer,
                              args.block_size)
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    # num_training_steps = len(train_dataset) // args.train_batch_size * args.num_train_epochs
    # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
    #                                             num_training_steps=num_training_steps)
    i = 0
    # Create the output directories up front; makedirs with exist_ok ensures the
    # "best" sub-directory is created even when output_dir already exists.
    os.makedirs(args.output_dir + "/best", exist_ok=True)
    prev_loss = eval(tokenizer, model, test_dataset, args)
    logger.info(f"eval loss: {prev_loss}")
    writer.add_scalar('Loss/eval', prev_loss, i)
    train_loss = 0
    no_save_counter = 0
    for _ in range(args.num_train_epochs):
        iterator = build_data_iterator(tokenizer,
                                       train_dataset,
                                       args.train_batch_size,
                                       args.block_size,
                                       random_sampler=True)
        for ids, attention_mask in tqdm(iterator, desc='train'):
            i += 1
            ids = ids.to(args.device)
            loss = model(ids,
                         attention_mask=attention_mask.to(args.device),
                         labels=ids)[0]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            optimizer.step()
            # scheduler.step()
            model.zero_grad()
            writer.add_scalar('Loss/train', loss.item(), i)
            train_loss += loss.item()
            if i % args.save_steps == 0:
                model.save_pretrained(args.output_dir)
            if args.evaluate_during_training and i % args.logging_steps == 0:
                logger.info(f"epoch: {i / len(iterator)}")
                logger.info(f"train loss: {train_loss / args.logging_steps}")
                train_loss = 0
                # lr = scheduler.get_last_lr()[0]
                # logger.info(f"lr: {lr}")
                eval_loss = eval(tokenizer, model, test_dataset, args)
                logger.info(f"eval loss: {eval_loss}")
                writer.add_scalar('Loss/eval', eval_loss, i)
                # writer.add_scalar('LR', lr, i)
                if prev_loss > eval_loss:
                    prev_loss = eval_loss
                    model.save_pretrained(args.output_dir + "/best")
                    no_save_counter = 0
                else:
                    no_save_counter += 1
                    logger.info(
                        f"model has not improved for {no_save_counter} evaluations in a row. best_eval: {prev_loss}"
                    )
    eval_loss = eval(tokenizer, model, test_dataset, args)
    logger.info(f"eval loss: {eval_loss}")
    model.save_pretrained(args.output_dir)
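
A rough sketch of how train might be wired up, under the assumption that TrainingArguments, MyDataset, get_corpus and the eval helper are the project's own definitions; every concrete value below (paths, batch size, learning rate, and so on) is hypothetical, and the attribute names simply mirror the fields that train reads:

import logging

from torch.utils.tensorboard import SummaryWriter
from transformers import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # assumes a GPT-2 style tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2")

# TrainingArguments and MyDataset are project-specific; the keyword arguments
# below are illustrative and correspond to the attributes used inside train().
args = TrainingArguments(
    corpus_path="train_corpus.txt",
    output_dir="checkpoints",
    block_size=256,
    train_batch_size=4,
    learning_rate=5e-5,
    num_train_epochs=3,
    max_grad_norm=1.0,
    save_steps=500,
    logging_steps=100,
    evaluate_during_training=True,
    device="cuda",
)
model.to(args.device)

test_dataset = MyDataset(get_corpus("test_corpus.txt"), tokenizer, args.block_size)
writer = SummaryWriter()
train(tokenizer, model, args, writer, logger, test_dataset)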