Example #1
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

# TextCNNConfig, TextCNN, CnnDataSet, get_logger and dev are assumed to be
# provided by the surrounding project.

def train():
    config = TextCNNConfig()
    logger = get_logger(config.log_path, "train_textcnn")
    model = TextCNN(config)
    train_loader = DataLoader(CnnDataSet(config.base_config.train_data_path), batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(CnnDataSet(config.base_config.dev_data_path), batch_size=config.batch_size, shuffle=False)
    model.train()
    model.to(config.device)

    optimizer = Adam(model.parameters(), lr=config.learning_rate)
    best_acc = 0.

    for epoch in range(config.num_epochs):
        for i, (texts, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            texts = texts.to(config.device)
            labels = labels.to(config.device)
            logits = model(texts)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                labels = labels.data.cpu().numpy()
                preds = torch.argmax(logits, dim=1)
                preds = preds.data.cpu().numpy()
                acc = np.sum(preds == labels) * 1. / len(preds)
                logger.info("TRAIN: epoch: {} step: {} acc: {} loss: {} ".format(epoch + 1, i, acc, loss.item()))

        acc, table = dev(model, dev_loader, config)
        model.train()  # dev() likely switches the model to eval mode; restore training mode

        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))

        if acc > best_acc:
            torch.save(model.state_dict(), config.model_path)
            best_acc = acc

    test_loader = DataLoader(CnnDataSet(config.base_config.test_data_path), batch_size=config.batch_size, shuffle=False)
    best_model = TextCNN(config)
    best_model.load_state_dict(torch.load(config.model_path))
    best_model.to(config.device)
    acc, table = dev(best_model, test_loader, config)

    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))
Example #2
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

# KDConfig, Bert, TextCNN, KDdataset, get_logger, dev and loss_fn_kd are
# assumed to be provided by the surrounding project.

def train():
    config = KDConfig()

    logger = get_logger(config.log_path, "train_KD")

    device = config.device

    # Load the BERT model as the teacher
    logger.info("load bert .....")
    bert = Bert(config.bert_config)
    bert.load_state_dict(torch.load(config.bert_config.model_path))
    bert.to(device)
    bert.eval()

    # Freeze the BERT parameters
    for p in bert.parameters():
        p.requires_grad = False

    # Load the TextCNN model as the student
    textcnn = TextCNN(config.textcnn_config)
    textcnn.to(device)
    textcnn.train()

    # Load the train/dev datasets
    logger.info("load train/dev data .....")
    train_loader = DataLoader(KDdataset(config.base_config.train_data_path),
                              batch_size=config.batch_size,
                              shuffle=True)
    dev_loader = DataLoader(KDdataset(config.base_config.dev_data_path),
                            batch_size=config.batch_size,
                            shuffle=False)

    optimizer = Adam(textcnn.parameters(), lr=config.lr)

    # Start training
    logger.info("start training .....")
    best_acc = 0.
    for epoch in range(config.epochs):
        for i, batch in enumerate(train_loader):
            cnn_ids, labels, input_ids, token_type_ids, attention_mask = (
                t.to(device) for t in batch)
            optimizer.zero_grad()
            students_output = textcnn(cnn_ids)
            teacher_output = bert(input_ids, token_type_ids, attention_mask)
            loss = loss_fn_kd(students_output, labels, teacher_output,
                              config.T, config.alpha)
            loss.backward()
            optimizer.step()

            # Log training metrics every 100 steps
            if i % 100 == 0:
                labels = labels.data.cpu().numpy()
                preds = torch.argmax(students_output, dim=1)
                preds = preds.data.cpu().numpy()
                acc = np.sum(preds == labels) * 1. / len(preds)
                logger.info(
                    "TRAIN: epoch: {} step: {} acc: {} loss: {} ".format(
                        epoch + 1, i, acc, loss.item()))

        acc, table = dev(textcnn, dev_loader, config)
        textcnn.train()  # dev() likely switches the model to eval mode; restore training mode

        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))

        if acc > best_acc:
            torch.save(textcnn.state_dict(), config.model_path)
            best_acc = acc

    logger.info("start testing ......")
    test_loader = DataLoader(KDdataset(config.base_config.test_data_path),
                             batch_size=config.batch_size,
                             shuffle=False)
    best_model = TextCNN(config.textcnn_config)
    best_model.load_state_dict(torch.load(config.model_path))
    best_model.to(device)
    acc, table = dev(best_model, test_loader, config)

    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))
Example #3
import os

from data_preparation import THCNewsDataSet, batch_iter
import torch
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
from config import Config
from textcnn import TextCNN
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


model = TextCNN()

model = model.to(device)

opt = optim.Adam(model.parameters())

criterion = nn.CrossEntropyLoss()


def save_model(model, model_name="best_model_sofa.pkl", model_save_dir="./trained_models/"):
    # Create the save directory if it does not exist yet.
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    path = os.path.join(model_save_dir, model_name)

    torch.save(model.state_dict(), path)
    print("saved model state dict at: " + path)