Example #1
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

# TextCNNConfig, TextCNN, CnnDataSet, get_logger and dev are assumed to be
# defined elsewhere in the surrounding project.


def train():
    config = TextCNNConfig()
    logger = get_logger(config.log_path, "train_textcnn")
    model = TextCNN(config)
    train_loader = DataLoader(CnnDataSet(config.base_config.train_data_path), batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(CnnDataSet(config.base_config.dev_data_path), batch_size=config.batch_size, shuffle=False)
    model.train()
    model.to(config.device)

    optimizer = Adam(model.parameters(), lr=config.learning_rate)
    best_acc = 0.

    for epoch in range(config.num_epochs):
        for i, (texts, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            texts = texts.to(config.device)
            labels = labels.to(config.device)
            logits = model(texts)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                labels = labels.cpu().numpy()
                preds = torch.argmax(logits, dim=1).cpu().numpy()
                acc = np.mean(preds == labels)
                logger.info("TRAIN: epoch: {} step: {} acc: {} loss: {}".format(epoch + 1, i, acc, loss.item()))

        acc, table = dev(model, dev_loader, config)

        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))

        if acc > best_acc:
            torch.save(model.state_dict(), config.model_path)
            best_acc = acc

    test_loader = DataLoader(CnnDataSet(config.base_config.test_data_path), batch_size=config.batch_size, shuffle=False)
    best_model = TextCNN(config)
    best_model.load_state_dict(torch.load(config.model_path))
    best_model.to(config.device)  # move the restored model to the device before evaluating
    acc, table = dev(best_model, test_loader, config)

    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))
Example #2
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

# KDConfig, Bert, TextCNN, KDdataset, get_logger, dev and loss_fn_kd are
# assumed to be defined elsewhere in the surrounding project.


def train():
    config = KDConfig()

    logger = get_logger(config.log_path, "train_KD")

    device = config.device

    # Load the BERT model to act as the teacher
    logger.info("load bert .....")
    bert = Bert(config.bert_config)
    bert.load_state_dict(torch.load(config.bert_config.model_path))
    bert.to(device)
    bert.eval()

    # Freeze the teacher's parameters; only the student is updated
    for p in bert.parameters():
        p.requires_grad = False

    # Load the TextCNN model to act as the student
    textcnn = TextCNN(config.textcnn_config)
    textcnn.to(device)
    textcnn.train()

    # Load the train/dev datasets
    logger.info("load train/dev data .....")
    train_loader = DataLoader(KDdataset(config.base_config.train_data_path),
                              batch_size=config.batch_size,
                              shuffle=True)
    dev_loader = DataLoader(KDdataset(config.base_config.dev_data_path),
                            batch_size=config.batch_size,
                            shuffle=False)

    optimizer = Adam(textcnn.parameters(), lr=config.lr)

    # Start training
    logger.info("start training .....")
    best_acc = 0.
    for epoch in range(config.epochs):
        for i, batch in enumerate(train_loader):
            cnn_ids, labels, input_ids, token_type_ids, attention_mask = [
                t.to(device) for t in batch]
            optimizer.zero_grad()
            students_output = textcnn(cnn_ids)
            teacher_output = bert(input_ids, token_type_ids, attention_mask)
            # Distillation loss: blend the teacher's softened outputs
            # (temperature T) with the hard labels, weighted by alpha
            loss = loss_fn_kd(students_output, labels, teacher_output,
                              config.T, config.alpha)
            loss.backward()
            optimizer.step()

            # Log training metrics every 100 steps
            if i % 100 == 0:
                labels = labels.cpu().numpy()
                preds = torch.argmax(students_output, dim=1).cpu().numpy()
                acc = np.mean(preds == labels)
                logger.info(
                    "TRAIN: epoch: {} step: {} acc: {} loss: {}".format(
                        epoch + 1, i, acc, loss.item()))

        acc, table = dev(textcnn, dev_loader, config)

        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))

        if acc > best_acc:
            torch.save(textcnn.state_dict(), config.model_path)
            best_acc = acc

    logger.info("start testing ......")
    test_loader = DataLoader(KDdataset(config.base_config.test_data_path),
                             batch_size=config.batch_size,
                             shuffle=False)
    best_model = TextCNN(config.textcnn_config)
    best_model.load_state_dict(torch.load(config.model_path))
    best_model.to(device)  # move the restored student to the device before evaluating
    acc, table = dev(best_model, test_loader, config)

    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))
Example #3
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Config, MR_dataset, TextCNN, EarlyStopping and get_test_result are assumed
# to be defined elsewhere in the surrounding project.


def main():
    # 10-fold cross-validation: train and evaluate one model per fold
    for i in range(10):
        # Load the configuration
        config = Config()
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
        early_stopping = EarlyStopping(patience=10, verbose=True, cv_index=i)
        # Load the train/valid/test splits for fold i
        kwargs = {'num_workers': 2, 'pin_memory': True}
        dataset_train = MR_dataset(config=config,
                                   state="train",
                                   k=i,
                                   embedding_state=True)
        train_data_batch = DataLoader(dataset_train,
                                      batch_size=config.batch_size,
                                      shuffle=False,
                                      drop_last=False,
                                      **kwargs)
        dataset_valid = MR_dataset(config=config,
                                   state="valid",
                                   k=i,
                                   embedding_state=False)
        valid_data_batch = DataLoader(dataset_valid,
                                      batch_size=config.batch_size,
                                      shuffle=False,
                                      drop_last=False,
                                      **kwargs)
        dataset_test = MR_dataset(config=config,
                                  state="test",
                                  k=i,
                                  embedding_state=False)
        test_data_batch = DataLoader(dataset_test,
                                     batch_size=config.batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     **kwargs)
        print(len(dataset_train), len(dataset_valid), len(dataset_test))

        if config.use_pretrained_embed:
            config.embedding_pretrained = torch.from_numpy(
                dataset_train.weight).float()
            if config.use_cuda and torch.cuda.is_available():
                config.embedding_pretrained = config.embedding_pretrained.cuda()
            print("load pretrained embeddings.")
        else:
            config.embedding_pretrained = None

        config.vocab_size = dataset_train.vocab_size

        model = TextCNN(config)
        print(model)

        if config.use_cuda and torch.cuda.is_available():
            model.cuda()

        criterion = nn.CrossEntropyLoss()  # cross-entropy classification loss
        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
        count = 0
        loss_sum = 0.0
        for epoch in range(config.epoch):
            # Training phase
            model.train()
            for data, label in train_data_batch:
                # .to() is not in-place, so the result must be assigned back
                data = data.to(torch.int64)
                if config.use_cuda and torch.cuda.is_available():
                    data = data.cuda()
                    label = label.cuda()
                out = model(data)
                # Manual L2 penalty on the second parameter tensor
                # (fragile: this relies on the model's parameter ordering)
                l2_loss = config.l2_weight * torch.sum(
                    torch.pow(list(model.parameters())[1], 2))
                loss = criterion(out, label.long()) + l2_loss
                loss_sum += loss.data.item()
                count += 1
                if count % 100 == 0:
                    print("epoch", epoch, end='  ')
                    print("The loss is: %.5f" % (loss_sum / 100))
                    loss_sum = 0
                    count = 0
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # After each epoch, evaluate on the validation set
            valid_loss, valid_acc = get_test_result(model, valid_data_batch,
                                                    dataset_valid, config,
                                                    criterion)
            early_stopping(valid_loss, model, config)
            print("The valid acc is: %.5f" % valid_acc)
            if early_stopping.early_stop:
                print("Early stopping")
                break
        # Reload this fold's best checkpoint and evaluate on the test set
        model.load_state_dict(
            torch.load(
                os.path.abspath(
                    os.path.join(config.checkpoint_path,
                                 'checkpoint%d.pt' % i))))
        test_loss, test_acc = get_test_result(model, test_data_batch,
                                              dataset_test, config, criterion)
        print("The test acc is: %.5f" % test_acc)