def train():
    """Train a TextCNN classifier, checkpointing the best model by dev accuracy,
    then evaluate that checkpoint on the test set.

    Side effects: writes logs to ``config.log_path`` and the best weights to
    ``config.model_path``. Uses module-level helpers ``get_logger``,
    ``CnnDataSet`` and ``dev`` defined elsewhere in this project.
    """
    config = TextCNNConfig()
    logger = get_logger(config.log_path, "train_textcnn")
    model = TextCNN(config)
    train_loader = DataLoader(CnnDataSet(config.base_config.train_data_path),
                              batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(CnnDataSet(config.base_config.dev_data_path),
                            batch_size=config.batch_size, shuffle=False)
    model.to(config.device)
    optimizer = Adam(model.parameters(), lr=config.learning_rate)
    best_acc = 0.
    for epoch in range(config.num_epochs):
        # Re-assert training mode every epoch: dev() may switch the model to
        # eval mode and not restore it (no-op if it does restore).
        model.train()
        for i, (texts, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            texts = texts.to(config.device)
            labels = labels.to(config.device)
            logits = model(texts)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                # Periodic training-accuracy log on the current batch.
                labels = labels.data.cpu().numpy()
                preds = torch.argmax(logits, dim=1).data.cpu().numpy()
                acc = np.sum(preds == labels) * 1. / len(preds)
                logger.info("TRAIN: epoch: {} step: {} acc: {} loss: {} ".format(
                    epoch + 1, i, acc, loss.item()))
        # Per-epoch validation; keep only the best-scoring checkpoint.
        acc, table = dev(model, dev_loader, config)
        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))
        if acc > best_acc:
            torch.save(model.state_dict(), config.model_path)
            best_acc = acc
    # Final test pass with the best checkpoint.
    test_loader = DataLoader(CnnDataSet(config.base_config.test_data_path),
                             batch_size=config.batch_size, shuffle=False)
    best_model = TextCNN(config)
    best_model.load_state_dict(torch.load(config.model_path))
    # Bug fix: the reloaded model must live on the same device as the batches
    # dev() moves to config.device; previously it stayed on CPU.
    best_model.to(config.device)
    acc, table = dev(best_model, test_loader, config)
    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))
def train():
    """Knowledge-distillation training: a frozen BERT teacher guides a TextCNN
    student through ``loss_fn_kd``; the best student (by dev accuracy) is saved
    to ``config.model_path`` and finally evaluated on the test set.

    Side effects: writes logs to ``config.log_path`` and student weights to
    ``config.model_path``.
    """
    config = KDConfig()
    logger = get_logger(config.log_path, "train_KD")
    device = config.device

    # Load BERT as the teacher.
    logger.info("load bert .....")
    bert = Bert(config.bert_config)
    bert.load_state_dict(torch.load(config.bert_config.model_path))
    bert.to(device)
    bert.eval()
    # Freeze every teacher parameter so only the student is optimized
    # (also keeps autograd from building a graph for the teacher forward).
    for name, p in bert.named_parameters():
        p.requires_grad = False

    # TextCNN is the student.
    textcnn = TextCNN(config.textcnn_config)
    textcnn.to(device)

    logger.info("load train/dev data .....")
    train_loader = DataLoader(KDdataset(config.base_config.train_data_path),
                              batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(KDdataset(config.base_config.dev_data_path),
                            batch_size=config.batch_size, shuffle=False)
    optimizer = Adam(textcnn.parameters(), lr=config.lr)

    logger.info("start training .....")
    best_acc = 0.
    for epoch in range(config.epochs):
        # Re-assert training mode every epoch: dev() may leave the student in
        # eval mode (no-op if it restores it itself).
        textcnn.train()
        for i, batch in enumerate(train_loader):
            # Batch layout: (cnn_ids, labels, input_ids, token_type_ids,
            # attention_mask) — the first two feed the student, the rest BERT.
            cnn_ids = batch[0].to(device)
            labels = batch[1].to(device)
            input_ids = batch[2].to(device)
            token_type_ids = batch[3].to(device)
            attention_mask = batch[4].to(device)
            optimizer.zero_grad()
            students_output = textcnn(cnn_ids)
            teacher_output = bert(input_ids, token_type_ids, attention_mask)
            # Distillation loss: soft teacher targets (temperature T) blended
            # with the hard-label loss via alpha.
            loss = loss_fn_kd(students_output, labels, teacher_output,
                              config.T, config.alpha)
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                # Periodic training-accuracy log on the current batch.
                labels = labels.data.cpu().numpy()
                preds = torch.argmax(students_output, dim=1).data.cpu().numpy()
                acc = np.sum(preds == labels) * 1. / len(preds)
                logger.info("TRAIN: epoch: {} step: {} acc: {} loss: {} ".format(
                    epoch + 1, i, acc, loss.item()))
        # Per-epoch validation; keep only the best-scoring student checkpoint.
        acc, table = dev(textcnn, dev_loader, config)
        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))
        if acc > best_acc:
            torch.save(textcnn.state_dict(), config.model_path)
            best_acc = acc

    logger.info("start testing ......")
    test_loader = DataLoader(KDdataset(config.base_config.test_data_path),
                             batch_size=config.batch_size, shuffle=False)
    best_model = TextCNN(config.textcnn_config)
    best_model.load_state_dict(torch.load(config.model_path))
    # Bug fix: the reloaded student must be on the same device as the batches
    # dev() moves to config.device; previously it stayed on CPU.
    best_model.to(device)
    acc, table = dev(best_model, test_loader, config)
    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))
def main():
    """Run 10-fold cross-validation training of TextCNN on the MR dataset.

    For each fold: build train/valid/test loaders, train with early stopping
    on validation loss, reload the per-fold checkpoint written by
    ``EarlyStopping`` and report test accuracy. Side effects: prints progress
    and reads/writes checkpoints under ``config.checkpoint_path``.
    """
    for i in range(10):
        config = Config()
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
        early_stopping = EarlyStopping(patience=10, verbose=True, cv_index=i)
        kwargs = {'num_workers': 2, 'pin_memory': True}
        # NOTE(review): the train loader is not shuffled — presumably the
        # fold split already randomizes order; confirm this is intentional.
        dataset_train = MR_dataset(config=config, state="train", k=i,
                                   embedding_state=True)
        train_data_batch = DataLoader(dataset_train, batch_size=config.batch_size,
                                      shuffle=False, drop_last=False, **kwargs)
        dataset_valid = MR_dataset(config=config, state="valid", k=i,
                                   embedding_state=False)
        valid_data_batch = DataLoader(dataset_valid, batch_size=config.batch_size,
                                      shuffle=False, drop_last=False, **kwargs)
        dataset_test = MR_dataset(config=config, state="test", k=i,
                                  embedding_state=False)
        test_data_batch = DataLoader(dataset_test, batch_size=config.batch_size,
                                     shuffle=False, drop_last=False, **kwargs)
        print(len(dataset_train), len(dataset_valid), len(dataset_test))
        if config.use_pretrained_embed:
            config.embedding_pretrained = torch.from_numpy(
                dataset_train.weight).float().cuda()
            print("load pretrained models.")
        else:
            config.embedding_pretrained = None
        config.vocab_size = dataset_train.vocab_size

        model = TextCNN(config)
        print(model)
        # Hoisted: evaluate the CUDA decision once per fold.
        use_cuda = config.use_cuda and torch.cuda.is_available()
        if use_cuda:
            model.cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

        count = 0
        loss_sum = 0.0
        for epoch in range(config.epoch):
            model.train()
            for data, label in train_data_batch:
                if use_cuda:
                    data = data.to(torch.int64).cuda()
                    label = label.cuda()
                else:
                    # Bug fix: Tensor.to() is not in-place; the original
                    # discarded the int64 conversion on the CPU path.
                    data = data.to(torch.int64)
                out = model(data)
                # L2 penalty on parameters()[1] — fragile, depends on the
                # parameter registration order inside TextCNN; presumably
                # the embedding weight. TODO confirm against the model.
                l2_loss = config.l2_weight * torch.sum(
                    torch.pow(list(model.parameters())[1], 2))
                # autograd.Variable is deprecated and a no-op wrapper on
                # modern PyTorch tensors; call criterion directly.
                loss = criterion(out, label.long()) + l2_loss
                loss_sum += loss.data.item()
                count += 1
                if count % 100 == 0:
                    print("epoch", epoch, end=' ')
                    print("The loss is: %.5f" % (loss_sum / 100))
                    loss_sum = 0
                    count = 0
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # End of epoch: validate and let EarlyStopping checkpoint/stop.
            valid_loss, valid_acc = get_test_result(model, valid_data_batch,
                                                    dataset_valid, config,
                                                    criterion)
            early_stopping(valid_loss, model, config)
            print("The valid acc is: %.5f" % valid_acc)
            if early_stopping.early_stop:
                print("Early stopping")
                break
        # Reload the best checkpoint EarlyStopping saved for this fold.
        model.load_state_dict(torch.load(os.path.abspath(
            os.path.join(config.checkpoint_path, 'checkpoint%d.pt' % i))))
        test_loss, test_acc = get_test_result(model, test_data_batch,
                                              dataset_test, config, criterion)
        print("The test acc is: %.5f" % test_acc)