Skip to content

zjms/SentenceClassifcation

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

11 Commits
 
 
 
 
 
 
 
 

Repository files navigation

SentenceClassification Model implemented with PyTorch

项目简介

本项目复现了用于文本分类的经典模型,目前有TextCNN和BERT,后续会添加LSTM、BiLSTM、FastText等经典模型。

运行代码

TextCNN为例,介绍代码使用流程, BertForSequenceClassification同理

下载代码

git clone git@github.com:unikcc/SentenceClassfication.git

进入主项目目录

cd SentenceClassfication

解压数据集

unzip data.zip

下载预训练词向量

打开链接https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit 下载完成后,解压gz文件,并放在主项目目录的data/embeddings下;或者其他你想要的位置,在每个项目的config.yaml文件中,embedding_path变量的值可以修改为相应的位置。

进入TextCNN目录

cd TextCNN

安装依赖

pip install -r requirements.txt

预处理

python preprocess.py

默认为MR数据集,运行下列命令可以处理SST2数据集 python preprocess.py --dataset SST2

运行训练+测试

python main.py 运行完毕之后,即可得到测试集的效果。

默认为MR数据集,运行下列命令可以处理SST2数据集 python main.py --dataset SST2

修改embededding模式

config.yaml文件中,修改train_moderandom, static, fine-tuned即可实现随机初始化、固定词向量和可训练词向量三种模式下的模型训练;

其他

config.yaml中可以修改相应配置,实现不同数据集的预测,目前支持SST2MR,其中MR是十折交叉验证;

复现结果

TextCNN

  • 原始论文

复现结果(括号内为复现结果)

模型选项 SST-2 MR
CNN-rand 82.7 (80.78) 76.1 (77.10)
CNN-static 86.8 (85.83) 81.0 (80.49)
CNN-fine-tuned 87.2 (84.68) 81.5 (79.88)
Bert-base-cased 93.5 (90.57)

About

Sentence Classification Model Implemented with PyTorch

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 99.8%
  • Shell 0.2%