def train():
    """Train a TextCNN classifier, checkpoint the best dev model, then test it.

    Every 100 optimisation steps the accuracy and loss of the current training
    batch are logged. After each epoch the model is evaluated on the dev split
    with ``dev``. Its state dict is saved to ``config.model_path`` whenever dev
    accuracy improves. Finally the best checkpoint is reloaded and evaluated on
    the test split.
    """
    config = TextCNNConfig()
    logger = get_logger(config.log_path, "train_textcnn")

    model = TextCNN(config)
    train_loader = DataLoader(CnnDataSet(config.base_config.train_data_path),
                              batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(CnnDataSet(config.base_config.dev_data_path),
                            batch_size=config.batch_size, shuffle=False)
    model.to(config.device)
    optimizer = Adam(model.parameters(), lr=config.learning_rate)

    # Start below any reachable accuracy so the first dev evaluation always
    # writes a checkpoint. Otherwise a dev accuracy of 0 on every epoch would
    # leave config.model_path unwritten and the torch.load below would fail.
    best_acc = float("-inf")
    for epoch in range(config.num_epochs):
        # Re-enter training mode each epoch: evaluation at the end of the
        # previous epoch switches the model to eval mode (dropout off).
        model.train()
        for i, (texts, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            texts = texts.to(config.device)
            labels = labels.to(config.device)
            logits = model(texts)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                # Accuracy on the current batch only, computed on CPU.
                true_labels = labels.detach().cpu().numpy()
                preds = torch.argmax(logits, dim=1).detach().cpu().numpy()
                acc = np.sum(preds == true_labels) * 1. / len(preds)
                logger.info("TRAIN: epoch: {} step: {} acc: {} loss: {} ".format(
                    epoch + 1, i, acc, loss.item()))

        # Evaluate with dropout disabled.
        model.eval()
        acc, table = dev(model, dev_loader, config)
        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))
        if acc > best_acc:
            torch.save(model.state_dict(), config.model_path)
            best_acc = acc

    test_loader = DataLoader(CnnDataSet(config.base_config.test_data_path),
                             batch_size=config.batch_size, shuffle=False)
    best_model = TextCNN(config)
    # map_location allows a checkpoint written on one device to be restored
    # on the configured device.
    best_model.load_state_dict(torch.load(config.model_path, map_location=config.device))
    # A freshly constructed model starts on CPU. Move it to the device the
    # training run used so evaluation inputs and weights agree.
    best_model.to(config.device)
    best_model.eval()
    acc, table = dev(best_model, test_loader, config)
    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))
def train():
    """Distil a frozen BERT teacher into a TextCNN student.

    The teacher is loaded from ``config.bert_config.model_path``, put in eval
    mode and has all of its parameters frozen. The student is optimised with
    Adam on ``loss_fn_kd(student_logits, labels, teacher_output, config.T,
    config.alpha)``. NOTE(review): presumably T is the softmax temperature and
    alpha the hard/soft loss mixing weight; confirm in loss_fn_kd.

    After each epoch the student is evaluated on the dev split with ``dev`` and
    saved to ``config.model_path`` whenever dev accuracy improves. The best
    checkpoint is then reloaded and evaluated on the test split.
    """
    config = KDConfig()
    logger = get_logger(config.log_path, "train_KD")
    device = config.device

    # Load BERT as the teacher.
    logger.info("load bert .....")
    bert = Bert(config.bert_config)
    bert.load_state_dict(torch.load(config.bert_config.model_path, map_location=device))
    bert.to(device)
    bert.eval()
    # Freeze the teacher: only the student's parameters are optimised.
    for p in bert.parameters():
        p.requires_grad = False

    # Load TextCNN as the student.
    textcnn = TextCNN(config.textcnn_config)
    textcnn.to(device)

    # Load the datasets.
    logger.info("load train/dev data .....")
    train_loader = DataLoader(KDdataset(config.base_config.train_data_path),
                              batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(KDdataset(config.base_config.dev_data_path),
                            batch_size=config.batch_size, shuffle=False)
    optimizer = Adam(textcnn.parameters(), lr=config.lr)

    # Start training.
    logger.info("start training .....")
    # Start below any reachable accuracy so at least one checkpoint is always
    # written before the reload at the end.
    best_acc = float("-inf")
    for epoch in range(config.epochs):
        # Re-enter training mode each epoch: evaluation at the end of the
        # previous epoch switches the student to eval mode (dropout off).
        textcnn.train()
        for i, batch in enumerate(train_loader):
            # The first five batch fields are: student ids, labels, then the
            # teacher's input_ids / token_type_ids / attention_mask.
            cnn_ids, labels, input_ids, token_type_ids, attention_mask = [
                t.to(device) for t in batch[:5]]
            optimizer.zero_grad()
            students_output = textcnn(cnn_ids)
            # The teacher is frozen, so no autograd graph is needed for it.
            with torch.no_grad():
                teacher_output = bert(input_ids, token_type_ids, attention_mask)
            loss = loss_fn_kd(students_output, labels, teacher_output, config.T, config.alpha)
            loss.backward()
            optimizer.step()

            # Log the student's accuracy on the current batch.
            if i % 100 == 0:
                true_labels = labels.detach().cpu().numpy()
                preds = torch.argmax(students_output, dim=1).detach().cpu().numpy()
                acc = np.sum(preds == true_labels) * 1. / len(preds)
                logger.info(
                    "TRAIN: epoch: {} step: {} acc: {} loss: {} ".format(
                        epoch + 1, i, acc, loss.item()))

        # Evaluate the student with dropout disabled.
        textcnn.eval()
        acc, table = dev(textcnn, dev_loader, config)
        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))
        if acc > best_acc:
            torch.save(textcnn.state_dict(), config.model_path)
            best_acc = acc

    logger.info("start testing ......")
    test_loader = DataLoader(KDdataset(config.base_config.test_data_path),
                             batch_size=config.batch_size, shuffle=False)
    best_model = TextCNN(config.textcnn_config)
    best_model.load_state_dict(torch.load(config.model_path, map_location=device))
    # A freshly constructed model starts on CPU. Move it to the training
    # device so evaluation inputs and weights agree.
    best_model.to(device)
    best_model.eval()
    acc, table = dev(best_model, test_loader, config)
    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader

from config import Config
from data_preparation import THCNewsDataSet, batch_iter
from textcnn import TextCNN

# Prefer the first CUDA GPU when available, otherwise fall back to CPU.
device = torch.device(
    "cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Module-level training objects: model on `device`, Adam with default
# hyper-parameters, and a cross-entropy criterion.
model = TextCNN()
model = model.to(device)
opt = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()


def save_model(model, model_name="best_model_sofa.pkl", model_save_dir="./trained_models/"):
    """Save ``model.state_dict()`` to ``os.path.join(model_save_dir, model_name)``.

    The directory is created (including parents) if it does not exist yet.

    Args:
        model: module whose ``state_dict()`` is written with ``torch.save``.
        model_name: file name of the checkpoint.
        model_save_dir: directory in which the checkpoint is written.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_save_dir, exist_ok=True)
    path = os.path.join(model_save_dir, model_name)
    torch.save(model.state_dict(), path)
    print("saved model state dict at :" + path)