Example #1
def train_model(model,
                train_loader, dev_loader,
                optimizer, criterion,
                num_classes, target_classes,
                label_encoder,
                device):

    # create two Meter instances to track the model's performance during training and evaluation
    train_meter = Meter(target_classes)
    dev_meter = Meter(target_classes)

    best_f1 = -1

    # epoch loop
    for epoch in range(args.epochs):
        train_tqdm = tqdm(train_loader)
        dev_tqdm = tqdm(dev_loader)

        model.train()

        # train loop
        for i, (train_x, train_y, mask, crf_mask) in enumerate(train_tqdm):
            # get the logits and update the gradients
            optimizer.zero_grad()

            logits = model(train_x, mask)

            if args.no_crf:
                loss = criterion(logits.reshape(-1, num_classes).to(device), train_y.reshape(-1).to(device))
            else:
                # the CRF criterion returns a log-likelihood, so negate it to obtain a loss
                loss = -criterion(logits.to(device), train_y, reduction="token_mean", mask=crf_mask)

            loss.backward()
            optimizer.step()

            # get the current metrics (averaged over the epoch so far)
            loss, _, _, micro_f1, _, _, macro_f1 = train_meter.update_params(loss.item(), logits, train_y)

            # print the metrics
            train_tqdm.set_description("Epoch: {}/{}, Train Loss: {:.4f}, Train Micro F1: {:.4f}, Train Macro F1: {:.4f}".
                                       format(epoch + 1, args.epochs, loss, micro_f1, macro_f1))
            train_tqdm.refresh()

        # reset the metrics to 0
        train_meter.reset()

        model.eval()

        # evaluation loop -> mostly the same as the training loop, but without updating the parameters
        with torch.no_grad():
            for i, (dev_x, dev_y, mask, crf_mask) in enumerate(dev_tqdm):
                logits = model(dev_x, mask)

                if args.no_crf:
                    loss = criterion(logits.reshape(-1, num_classes).to(device), dev_y.reshape(-1).to(device))
                else:
                    loss = -criterion(logits.to(device), dev_y, reduction="token_mean", mask=crf_mask)

                loss, _, _, micro_f1, _, _, macro_f1 = dev_meter.update_params(loss.item(), logits, dev_y)

                dev_tqdm.set_description("Dev Loss: {:.4f}, Dev Micro F1: {:.4f}, Dev Macro F1: {:.4f}".
                                         format(loss, micro_f1, macro_f1))
                dev_tqdm.refresh()

        dev_meter.reset()

        # if the current macro F1 score is the best one -> save the model
        if macro_f1 > best_f1:
            if not os.path.exists(args.save_path):
                os.makedirs(args.save_path)

            print("Macro F1 score improved from {:.4f} -> {:.4f}. Saving model...".format(best_f1, macro_f1))

            best_f1 = macro_f1
            torch.save(model, os.path.join(args.save_path, "model.pt"))
            with open(os.path.join(args.save_path, "label_encoder.pk"), "wb") as file:
                pickle.dump(label_encoder, file)
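
Example #1 switches between a plain cross-entropy loss and a CRF layer: the negated return value together with the reduction="token_mean" and mask arguments matches the log-likelihood interface of the pytorch-crf package. Below is a minimal sketch of how criterion might be constructed for the two branches, assuming pytorch-crf is the CRF implementation in use; everything not shown in the example above (argument parsing, device placement) is an assumption.

import torch.nn as nn
from torchcrf import CRF  # pytorch-crf (assumption: this is the CRF implementation used)

# assumption: args.no_crf, num_classes and device come from the surrounding script
if args.no_crf:
    # token-level cross-entropy over the flattened logits, as in the loss call above
    criterion = nn.CrossEntropyLoss()
else:
    # CRF(...)(emissions, tags, mask=..., reduction="token_mean") returns a log-likelihood,
    # hence the negation in the training and evaluation loops above
    criterion = CRF(num_classes, batch_first=True).to(device)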
Example #2
def train_model(model, train_loader, dev_loader, optimizer, criterion,
                num_classes, target_classes, it, label_encoder, device):

    # create two Meter instances to track the model's performance during training and evaluation
    train_meter = Meter(target_classes)
    dev_meter = Meter(target_classes)

    best_f1 = 0
    loss, macro_f1 = 0, 0

    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # Default value in run_glue.py
        num_training_steps=total_steps)

    curr_patience = 0

    # epoch loop
    for epoch in range(args.epochs):
        train_tqdm = tqdm(train_loader, leave=False)

        model.train()

        # train loop
        for i, (train_x, train_y, mask) in enumerate(train_tqdm):
            train_tqdm.set_description(
                "    Training - Epoch: {}/{}, Loss: {:.4f}, F1: {:.4f}, Best F1: {:.4f}"
                .format(epoch + 1, args.epochs, loss, macro_f1, best_f1))
            train_tqdm.refresh()

            # get the logits and update the gradients
            optimizer.zero_grad()

            logits = model(train_x, mask)

            loss = criterion(
                logits.reshape(-1, num_classes).to(device),
                train_y.reshape(-1).to(device))
            loss.backward()
            optimizer.step()

            if args.fine_tune:
                scheduler.step()

            # get the current metrics (averaged over the epoch so far)
            loss, _, _, _, _, _, macro_f1 = train_meter.update_params(
                loss.item(), logits, train_y)

        # reset the metrics to 0
        train_meter.reset()

        dev_tqdm = tqdm(dev_loader, leave=False)
        model.eval()
        loss, macro_f1 = 0, 0

        # evaluation loop -> mostly the same as the training loop, but without updating the parameters
        with torch.no_grad():
            for i, (dev_x, dev_y, mask) in enumerate(dev_tqdm):
                dev_tqdm.set_description(
                    "    Evaluating - Epoch: {}/{}, Loss: {:.4f}, F1: {:.4f}, Best F1: {:.4f}"
                    .format(epoch + 1, args.epochs, loss, macro_f1, best_f1))
                dev_tqdm.refresh()

                logits = model(dev_x, mask)
                loss = criterion(
                    logits.reshape(-1, num_classes).to(device),
                    dev_y.reshape(-1).to(device))

                loss, _, _, micro_f1, _, _, macro_f1 = dev_meter.update_params(
                    loss.item(), logits, dev_y)

        dev_meter.reset()

        # if the current macro F1 score is the best one -> save the model
        if macro_f1 > best_f1:
            curr_patience = 0
            best_f1 = macro_f1
            torch.save(
                model,
                os.path.join(args.save_path, "model_{}.pt".format(it + 1)))
            with open(os.path.join(args.save_path, "label_encoder.pk"),
                      "wb") as file:
                pickle.dump(label_encoder, file)
        else:
            curr_patience += 1

        if curr_patience > args.patience:
            break

    return best_f1
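
Example #2 adds a linear warm-up schedule (get_linear_schedule_with_warmup from the transformers library), early stopping controlled by args.patience, and returns the best dev macro F1. A minimal sketch of how the function might be driven from the surrounding script, assuming token-level cross-entropy and DataLoaders that yield (x, y, mask) batches; the optimizer choice, learning rate and all variable names outside the function signature are assumptions.

import torch
import torch.nn as nn

# assumption: model, train_loader, dev_loader, num_classes, target_classes and
# label_encoder are produced by the preprocessing step of the surrounding script
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# assumption: AdamW with a typical fine-tuning learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

best_f1 = train_model(model, train_loader, dev_loader, optimizer, criterion,
                      num_classes, target_classes, it=0,
                      label_encoder=label_encoder, device=device)
print("Best dev macro F1: {:.4f}".format(best_f1))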