def load_model(self, path, *, d_model=256, d_ff=1024, h=4, n_encoders=4,
               dropout=0.1, device='cpu'):
        """Rebuild the BERT-style model and load its weights from a checkpoint.

        The architecture hyperparameters must match the ones used when the
        checkpoint was saved. The defaults reproduce the configuration that
        was previously hard-coded here, so ``load_model(path)`` behaves as
        before.

        Args:
            path: Checkpoint file readable by ``torch.load``. The weights are
                read from its ``'model_state_dict'`` entry.
            d_model: Model / embedding dimension.
            d_ff: Inner dimension passed to ``FullyConnectedFeedForward``.
            h: Number of attention heads; must divide ``d_model``.
            n_encoders: Number of encoder layers (``n_layers`` of ``Bert``).
            dropout: Dropout rate given to the attention, positional-encoding
                and encoder modules.
            device: ``map_location`` passed to ``torch.load`` (CPU by default).

        Returns:
            The model with the checkpoint weights loaded, in eval mode.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``h``.
        """
        if d_model % h != 0:
            # d_model // h would truncate, so h * d_head would not equal d_model.
            raise ValueError(
                "d_model ({}) must be divisible by h ({})".format(d_model, h))

        vocab_size = len(self.vocab.char2id)
        d_head = d_model // h

        self_attn = MultiHeadedAttention(h=h,
                                         d_model=d_model,
                                         d_k=d_head,
                                         d_v=d_head,
                                         dropout=dropout)
        feed_forward = FullyConnectedFeedForward(d_model=d_model, d_ff=d_ff)
        position = PositionalEncoding(d_model, dropout=dropout)
        embedding = nn.Sequential(Embeddings(d_model=d_model, vocab=vocab_size),
                                  position)

        encoder = Encoder(self_attn=self_attn,
                          feed_forward=feed_forward,
                          size=d_model,
                          dropout=dropout)
        generator = Generator3(d_model=d_model, vocab_size=vocab_size)
        model = Bert(encoder=encoder,
                     embedding=embedding,
                     generator=generator,
                     n_layers=n_encoders)

        # map_location decides where the stored tensors are materialised,
        # so a checkpoint saved on GPU can be loaded on a CPU-only host.
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model
示例#2
0
def train():
    """Distil a frozen BERT teacher into a TextCNN student, then test it.

    All settings come from ``KDConfig``. The teacher is loaded from
    ``config.bert_config.model_path`` and frozen. The student is optimised
    with ``loss_fn_kd`` (using ``config.T`` and ``config.alpha``). After each
    epoch the student is scored on the dev set with ``dev``, and its weights
    are saved to ``config.model_path`` whenever dev accuracy improves. Finally
    the best checkpoint is reloaded and scored on the test set.
    """
    config = KDConfig()

    logger = get_logger(config.log_path, "train_KD")

    device = config.device

    # Load the BERT model as the teacher.
    logger.info("load bert .....")
    bert = Bert(config.bert_config)
    # map_location lets a checkpoint saved on another device load on `device`.
    bert.load_state_dict(
        torch.load(config.bert_config.model_path, map_location=device))
    bert.to(device)
    bert.eval()

    # Freeze the teacher: only the student's parameters are optimised.
    for p in bert.parameters():
        p.requires_grad = False

    # Build the TextCNN model as the student.
    textcnn = TextCNN(config.textcnn_config)
    textcnn.to(device)

    # Load the datasets.
    logger.info("load train/dev data .....")
    train_loader = DataLoader(KDdataset(config.base_config.train_data_path),
                              batch_size=config.batch_size,
                              shuffle=True)
    dev_loader = DataLoader(KDdataset(config.base_config.dev_data_path),
                            batch_size=config.batch_size,
                            shuffle=False)

    optimizer = Adam(textcnn.parameters(), lr=config.lr)

    # Start training.
    logger.info("start training .....")
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise an all-zero dev accuracy would leave nothing for
    # the test phase to load.
    best_acc = -1.
    for epoch in range(config.epochs):
        # Re-enter train mode every epoch: `dev` presumably switches the model
        # to eval mode (its body is not visible here), which would otherwise
        # disable dropout for all epochs after the first.
        textcnn.train()
        for i, batch in enumerate(train_loader):
            cnn_ids, labels, input_ids, token_type_ids, attention_mask = (
                t.to(device) for t in batch[:5])
            optimizer.zero_grad()
            students_output = textcnn(cnn_ids)
            # The teacher is frozen, so skip building an autograd graph for it.
            with torch.no_grad():
                teacher_output = bert(input_ids, token_type_ids, attention_mask)
            loss = loss_fn_kd(students_output, labels, teacher_output,
                              config.T, config.alpha)
            loss.backward()
            optimizer.step()

            # Log training progress.
            if i % 100 == 0:
                preds = torch.argmax(students_output, dim=1).cpu().numpy()
                acc = np.mean(preds == labels.cpu().numpy())
                logger.info(
                    "TRAIN: epoch: {} step: {} acc: {} loss: {} ".format(
                        epoch + 1, i, acc, loss.item()))

        acc, table = dev(textcnn, dev_loader, config)

        logger.info("DEV: acc: {} ".format(acc))
        logger.info("DEV classification report: \n{}".format(table))

        if acc > best_acc:
            torch.save(textcnn.state_dict(), config.model_path)
            best_acc = acc

    logger.info("start testing ......")
    test_loader = DataLoader(KDdataset(config.base_config.test_data_path),
                             batch_size=config.batch_size,
                             shuffle=False)
    best_model = TextCNN(config.textcnn_config)
    best_model.load_state_dict(torch.load(config.model_path, map_location=device))
    # `dev` was given a model living on `device` during training; the reloaded
    # model must live there too, or a CUDA run hits a device mismatch.
    best_model.to(device)
    best_model.eval()
    acc, table = dev(best_model, test_loader, config)

    logger.info("TEST acc: {}".format(acc))
    logger.info("TEST classification report:\n{}".format(table))