import os
import pickle

import torch
from tqdm import tqdm

# NOTE: `args` (the parsed command-line arguments) and the `Meter` metrics
# helper are assumed to be defined elsewhere in this module/repo.


def train_model(model, train_loader, dev_loader, optimizer, criterion, num_classes, target_classes, label_encoder, device):
    # create two Meter instances to track the model's performance during training and evaluation
    train_meter = Meter(target_classes)
    dev_meter = Meter(target_classes)
    best_f1 = -1

    # epoch loop
    for epoch in range(args.epochs):
        train_tqdm = tqdm(train_loader)
        dev_tqdm = tqdm(dev_loader)
        model.train()

        # train loop
        for i, (train_x, train_y, mask, crf_mask) in enumerate(train_tqdm):
            # get the logits and update the gradients
            optimizer.zero_grad()
            logits = model(train_x, mask)
            if args.no_crf:
                # token-level cross-entropy over the flattened logits
                loss = criterion(logits.reshape(-1, num_classes).to(device), train_y.reshape(-1).to(device))
            else:
                # the CRF criterion returns a log-likelihood, so negate it to get a loss
                loss = -criterion(logits.to(device), train_y, reduction="token_mean", mask=crf_mask)
            loss.backward()
            optimizer.step()

            # get the current metrics (averaged over the whole training epoch so far)
            loss, _, _, micro_f1, _, _, macro_f1 = train_meter.update_params(loss.item(), logits, train_y)

            # print the metrics
            train_tqdm.set_description("Epoch: {}/{}, Train Loss: {:.4f}, Train Micro F1: {:.4f}, Train Macro F1: {:.4f}"
                                       .format(epoch + 1, args.epochs, loss, micro_f1, macro_f1))
            train_tqdm.refresh()

        # reset the metrics to 0
        train_meter.reset()
        model.eval()

        # evaluation loop -> mostly the same as the training loop, but without updating the parameters
        with torch.no_grad():
            for i, (dev_x, dev_y, mask, crf_mask) in enumerate(dev_tqdm):
                logits = model(dev_x, mask)
                if args.no_crf:
                    loss = criterion(logits.reshape(-1, num_classes).to(device), dev_y.reshape(-1).to(device))
                else:
                    loss = -criterion(logits.to(device), dev_y, reduction="token_mean", mask=crf_mask)

                loss, _, _, micro_f1, _, _, macro_f1 = dev_meter.update_params(loss.item(), logits, dev_y)
                dev_tqdm.set_description("Dev Loss: {:.4f}, Dev Micro F1: {:.4f}, Dev Macro F1: {:.4f}"
                                         .format(loss, micro_f1, macro_f1))
                dev_tqdm.refresh()

        dev_meter.reset()

        # if the current macro F1 score is the best one -> save the model and the label encoder
        if macro_f1 > best_f1:
            if not os.path.exists(args.save_path):
                os.makedirs(args.save_path)

            print("Macro F1 score improved from {:.4f} -> {:.4f}. Saving model...".format(best_f1, macro_f1))
            best_f1 = macro_f1
            torch.save(model, os.path.join(args.save_path, "model.pt"))
            with open(os.path.join(args.save_path, "label_encoder.pk"), "wb") as file:
                pickle.dump(label_encoder, file)
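

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (assumption, not part of the original script):
# shows how the CRF-capable train_model above could be wired together.
# `pad_label_id`, the dataset objects, and the exact arg names are
# placeholders; the criterion is CrossEntropyLoss when --no_crf is set,
# otherwise a CRF layer whose call signature (tags, reduction=..., mask=...)
# matches the pytorch-crf package.
# ---------------------------------------------------------------------------
def example_training_setup(model, train_dataset, dev_dataset, num_classes,
                           target_classes, label_encoder, pad_label_id, device):
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchcrf import CRF  # assumed CRF implementation (pytorch-crf)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=args.batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # token-level cross-entropy over padded sequences, or a CRF log-likelihood
    criterion = (nn.CrossEntropyLoss(ignore_index=pad_label_id)
                 if args.no_crf
                 else CRF(num_classes, batch_first=True).to(device))

    train_model(model, train_loader, dev_loader, optimizer, criterion,
                num_classes, target_classes, label_encoder, device)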


# scheduler used by the fine-tuning variant below (from Hugging Face transformers)
from transformers import get_linear_schedule_with_warmup


# A second train_model variant: drops the CRF branch, adds a linear warmup
# scheduler (stepped only when args.fine_tune is set), early stopping with
# patience, and per-run checkpoints indexed by `it`.
def train_model(model, train_loader, dev_loader, optimizer, criterion, num_classes, target_classes, it, label_encoder, device):
    # create two Meter instances to track the model's performance during training and evaluation
    train_meter = Meter(target_classes)
    dev_meter = Meter(target_classes)
    best_f1 = 0
    loss, macro_f1 = 0, 0

    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # default value in run_glue.py
        num_training_steps=total_steps)
    curr_patience = 0

    # epoch loop
    for epoch in range(args.epochs):
        train_tqdm = tqdm(train_loader, leave=False)
        model.train()

        # train loop
        for i, (train_x, train_y, mask) in enumerate(train_tqdm):
            train_tqdm.set_description(
                " Training - Epoch: {}/{}, Loss: {:.4f}, F1: {:.4f}, Best F1: {:.4f}"
                .format(epoch + 1, args.epochs, loss, macro_f1, best_f1))
            train_tqdm.refresh()

            # get the logits and update the gradients
            optimizer.zero_grad()
            logits = model(train_x, mask)
            loss = criterion(
                logits.reshape(-1, num_classes).to(device),
                train_y.reshape(-1).to(device))
            loss.backward()
            optimizer.step()
            if args.fine_tune:
                scheduler.step()

            # get the current metrics (averaged over the whole training epoch so far)
            loss, _, _, _, _, _, macro_f1 = train_meter.update_params(
                loss.item(), logits, train_y)

        # reset the metrics to 0
        train_meter.reset()

        dev_tqdm = tqdm(dev_loader, leave=False)
        model.eval()
        loss, macro_f1 = 0, 0

        # evaluation loop -> mostly the same as the training loop, but without updating the parameters
        with torch.no_grad():
            for i, (dev_x, dev_y, mask) in enumerate(dev_tqdm):
                dev_tqdm.set_description(
                    " Evaluating - Epoch: {}/{}, Loss: {:.4f}, F1: {:.4f}, Best F1: {:.4f}"
                    .format(epoch + 1, args.epochs, loss, macro_f1, best_f1))
                dev_tqdm.refresh()

                logits = model(dev_x, mask)
                loss = criterion(
                    logits.reshape(-1, num_classes).to(device),
                    dev_y.reshape(-1).to(device))

                loss, _, _, micro_f1, _, _, macro_f1 = dev_meter.update_params(
                    loss.item(), logits, dev_y)

        dev_meter.reset()

        # if the current macro F1 score is the best one -> save the model and reset the patience counter
        if macro_f1 > best_f1:
            curr_patience = 0
            best_f1 = macro_f1
            torch.save(model, os.path.join(args.save_path, "model_{}.pt".format(it + 1)))
            with open(os.path.join(args.save_path, "label_encoder.pk"), "wb") as file:
                pickle.dump(label_encoder, file)
        else:
            # early stopping: give up once dev macro F1 has not improved for args.patience epochs
            curr_patience += 1
            if curr_patience > args.patience:
                break

    return best_f1
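

# ---------------------------------------------------------------------------
# Hypothetical driver sketch (assumption, not part of the original script):
# the `it` argument and the returned best_f1 suggest the variant above is
# called once per run (or fold). `build_model`, `build_loaders`, `args.runs`,
# and `pad_label_id` are placeholders; AdamW and CrossEntropyLoss are standard
# choices for the fine-tuning path that steps the linear warmup scheduler.
# ---------------------------------------------------------------------------
def example_multi_run_training(datasets, num_classes, target_classes,
                               label_encoder, pad_label_id, device):
    import torch.nn as nn

    f1_scores = []
    for it in range(args.runs):
        model = build_model().to(device)                         # placeholder model factory
        train_loader, dev_loader = build_loaders(datasets, it)   # placeholder loader setup
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
        criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id)

        best_f1 = train_model(model, train_loader, dev_loader, optimizer,
                              criterion, num_classes, target_classes, it,
                              label_encoder, device)
        f1_scores.append(best_f1)

    print("Mean best dev macro F1 over {} runs: {:.4f}".format(
        len(f1_scores), sum(f1_scores) / len(f1_scores)))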