def train():
    """Fine-tune the BERT classifier, keep the best dev checkpoint, then evaluate it on test.

    All settings come from ``BertConfig()``: device, batch size, learning rate,
    number of epochs, log/model paths and the train/dev/test data paths.
    After every epoch the model is evaluated on the dev set with ``dev``. The
    state dict is written to ``config.model_path`` whenever dev accuracy improves.
    At the end, that checkpoint is reloaded and evaluated on the test set.
    """
    config = BertConfig()
    logger = get_logger(config.log_path)
    device = config.device

    model = Bert(config)
    # Move to the target device before building the optimizer so it sees the final parameters.
    model.to(device)

    train_dataset = BertDataSet(config.base_config.train_data_path)
    dev_dataset = BertDataSet(config.base_config.dev_data_path)
    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    dev_dataloader = DataLoader(dev_dataset, batch_size=config.batch_size, shuffle=False)

    optimizer = AdamW(model.parameters(), lr=config.lr)
    criterion = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first epoch always writes a checkpoint;
    # otherwise the test phase could load a missing or stale file.
    best_acc = float("-inf")
    for epoch in range(config.epochs):
        # dev() is called at the end of each epoch and presumably puts the model in
        # eval mode (TODO confirm); restore training mode (dropout etc.) every epoch.
        model.train()
        for i, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            # Batch layout: (input_ids, token_type_ids, attention_mask, labels).
            input_ids, token_type_ids, attention_mask, labels = (t.to(device) for t in batch[:4])
            logits = model(input_ids, token_type_ids, attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                # Accuracy on the current training batch only (cheap progress signal).
                preds = torch.argmax(logits, dim=1)
                acc = (preds == labels).float().mean().item()
                logger.info("TRAIN: epoch: {} step: {} acc: {}, loss: {}".format(epoch, i, acc, loss.item()))

        acc, cls_report = dev(model, dev_dataloader, config)
        logger.info("DEV: epoch: {} acc: {}".format(epoch, acc))
        logger.info("DEV classification report:\n{}".format(cls_report))
        if acc > best_acc:
            torch.save(model.state_dict(), config.model_path)
            best_acc = acc

    test_dataset = BertDataSet(config.base_config.test_data_path)
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
    best_model = Bert(config)
    best_model.load_state_dict(torch.load(config.model_path, map_location=device))
    # The freshly constructed model lives on CPU; put it where the batches will be.
    best_model.to(device)
    acc, cls_report = dev(best_model, test_dataloader, config)
    logger.info("TEST: ACC:{}".format(acc))
    logger.info("TEST classification report:\n{}".format(cls_report))
def load_model(self, path, *, d_model=64, d_ff=256, h=4, n_encoders=4, dropout=0.1, device="cpu"):
    """Rebuild the character-level Bert architecture and load weights from a checkpoint.

    The architecture hyper-parameters must match those used when the checkpoint
    was saved. The defaults reproduce the configuration that was previously
    hard-coded here, so existing ``load_model(path)`` calls behave as before.

    Args:
        path: checkpoint file readable by ``torch.load``. It must contain a
            ``'model_state_dict'`` entry. NOTE: ``torch.load`` unpickles the
            file, so only load checkpoints from trusted sources.
        d_model: embedding / hidden size.
        d_ff: inner size of the position-wise feed-forward layer.
        h: number of attention heads; must divide ``d_model``.
        n_encoders: number of encoder layers passed to ``Bert`` as ``n_layers``.
        dropout: dropout rate used for attention, positional encoding and encoder.
        device: device (string or ``torch.device``) to build and load the model on.

    Returns:
        The ``Bert`` model on ``device`` with the checkpoint weights loaded.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``h``.
    """
    if d_model % h:
        raise ValueError("d_model ({}) must be divisible by h ({})".format(d_model, h))
    device = torch.device(device)
    vocab_size = len(self.vocab.char2id)
    d_head = d_model // h

    self_attn = MultiHeadedAttention(h=h, d_model=d_model, d_k=d_head, d_v=d_head, dropout=dropout)
    feed_forward = FullyConnectedFeedForward(d_model=d_model, d_ff=d_ff)
    position = PositionalEncoding(d_model, dropout=dropout)
    embedding = nn.Sequential(Embeddings(d_model=d_model, vocab=vocab_size), position)
    encoder = Encoder(self_attn=self_attn, feed_forward=feed_forward, size=d_model, dropout=dropout)
    generator = Generator(d_model=d_model, vocab_size=vocab_size)
    model = Bert(encoder=encoder, embedding=embedding, generator=generator, n_layers=n_encoders)
    model = model.to(device)

    # map_location lets a checkpoint saved on GPU be loaded on the requested device.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
def train():
    """Distil a frozen BERT teacher into a TextCNN student, then test the best student.

    All settings come from ``KDConfig()``: device, batch size, learning rate,
    epochs, distillation temperature ``T`` and weight ``alpha``, and the
    model/log/data paths. The BERT weights are loaded from
    ``config.bert_config.model_path`` and frozen. Only the TextCNN is optimised,
    using ``loss_fn_kd`` on the student output, gold labels and teacher output.
    After each epoch the student is evaluated on dev. Its state dict is saved to
    ``config.model_path`` when dev accuracy improves, and at the end that
    checkpoint is reloaded and evaluated on the test set.
    """
    config = KDConfig()
    logger = get_logger(config.log_path, "train_KD")
    device = config.device

    # Load the BERT model as the teacher.
    logger.info("load bert .....")
    bert = Bert(config.bert_config)
    bert.load_state_dict(torch.load(config.bert_config.model_path, map_location=device))
    bert.to(device)
    bert.eval()
    # Freeze the teacher's parameters: only the student is trained.
    for p in bert.parameters():
        p.requires_grad = False

    # Load the TextCNN model as the student.
    textcnn = TextCNN(config.textcnn_config)
    textcnn.to(device)

    # Load the datasets.
    logger.info("load train/dev data .....")
    train_loader = DataLoader(KDdataset(config.base_config.train_data_path),
                              batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(KDdataset(config.base_config.dev_data_path),
                            batch_size=config.batch_size, shuffle=False)

    optimizer = Adam(textcnn.parameters(), lr=config.lr)

    # Start training.
    logger.info("start training .....")
    # Start below any reachable accuracy so the first epoch always writes a checkpoint;
    # otherwise the test phase could load a missing or stale file.
    best_acc = float("-inf")
    for epoch in range(config.epochs):
        # dev() runs at the end of each epoch and presumably switches the student to
        # eval mode (TODO confirm); restore training mode (dropout etc.) every epoch.
        textcnn.train()
        for i, batch in enumerate(train_loader):
            # Batch layout: (cnn_ids, labels, input_ids, token_type_ids, attention_mask).
            cnn_ids, labels, input_ids, token_type_ids, attention_mask = (t.to(device) for t in batch[:5])

            optimizer.zero_grad()
            students_output = textcnn(cnn_ids)
            # The teacher's output is only a target; no autograd graph is needed for it.
            with torch.no_grad():
                teacher_output = bert(input_ids, token_type_ids, attention_mask)
            loss = loss_fn_kd(students_output, labels, teacher_output, config.T, config.alpha)
            loss.backward()
            optimizer.step()

            # Log progress.
            if i % 100 == 0:
                # Accuracy of the student on the current training batch only.
                preds = torch.argmax(students_output, dim=1)
                acc = (preds == labels).float().mean().item()
                logger.info(
                    "TRAIN: epoch: {} step: {} acc: {} loss: {} ".format(
                        epoch + 1, i, acc, loss.item()))

        acc, table = dev(textcnn, dev_loader, config)
        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))
        if acc > best_acc:
            torch.save(textcnn.state_dict(), config.model_path)
            best_acc = acc

    logger.info("start testing ......")
    test_loader = DataLoader(KDdataset(config.base_config.test_data_path),
                             batch_size=config.batch_size, shuffle=False)
    best_model = TextCNN(config.textcnn_config)
    best_model.load_state_dict(torch.load(config.model_path, map_location=device))
    # The freshly constructed model lives on CPU; put it where the batches will be.
    best_model.to(device)
    acc, table = dev(best_model, test_loader, config)
    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))