import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertPreTrainedModel


def evaluate(model: BertPreTrainedModel, iterator: DataLoader) -> float:
    """Return the average loss over the evaluation set."""
    model.eval()
    total = []
    for batch in tqdm(list(iterator), desc='eval'):
        with torch.no_grad():
            # The model returns the loss as the first element of its output tuple.
            loss = model(**batch)[0]
            total += [loss.item()]
    # Switch back to training mode before returning to the training loop.
    model.train()
    return sum(total) / len(total)
def train_epoch(model: BertPreTrainedModel, optimizer: torch.optim.Optimizer,
                iterator: DataLoader, args: TrainingArguments, num_epoch=0):
    """Run one training epoch, logging and checkpointing every `args.save_steps` steps."""
    model.train()
    train_loss = 0
    for step, batch in enumerate(tqdm(iterator, desc="train")):
        # Forward pass: the loss is the first element of the model's output tuple.
        loss = model(**batch)[0]
        loss.backward()
        # Clip gradients before the optimizer step to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        model.zero_grad()
        train_loss += loss.item()
        if args.writer:
            # Log the per-step loss to TensorBoard with a global step index.
            args.writer.add_scalar('Loss/train', loss.item(), num_epoch * len(iterator) + step)
        if step > 0 and step % args.save_steps == 0:
            # Periodically checkpoint and report the average loss since the last save.
            model.save_pretrained(args.output_dir)
            logger.info(f"epoch: {num_epoch + step / len(iterator)}")
            logger.info(f"train loss: {train_loss / args.save_steps}")
            train_loss = 0
    # Save the final checkpoint at the end of the epoch.
    model.save_pretrained(args.output_dir)
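

# Usage sketch (not part of the original code): one way to wire the two helpers
# into a fine-tuning loop. The model, the tokenized train/eval DataLoaders, the
# `args` object (with max_grad_norm, save_steps, output_dir, writer), and the
# `logger` are assumed to be constructed elsewhere; the names and hyperparameters
# below are illustrative.
def run_training(model, train_loader, eval_loader, args, num_epochs=3, lr=2e-5):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        # Train for one epoch, then measure the average loss on the held-out set.
        train_epoch(model, optimizer, train_loader, args, num_epoch=epoch)
        eval_loss = evaluate(model, eval_loader)
        logger.info(f"epoch {epoch}: eval loss {eval_loss:.4f}")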