def test_train(self):
    """Train a BiLSTM sentiment classifier on a small slice of the IMDB test set."""
    sep = '-' * 15
    print('{} test_train {}'.format(sep, sep))
    # Load the pickled dataset; keep only a slice so a CPU-only machine can cope.
    data = load_serialize_obj('./data/aclImdb/test_data.pkl')
    data = data[:100]
    # Tokenize and build the vocabulary from the reviews.
    tokenized = get_tokenized_imdb(imdb_data=data)
    vocab = get_tokenized_vocab(tokenized)
    num_words = len(vocab)
    print('vocab len:{}'.format(num_words))  # vocab len:45098
    data_iter = get_imdb_data_iter(data, vocab, batch_size=8, shuffle=True)
    print('test_iter len:{}'.format(len(data_iter)))  # test_iter len:3218
    # Assemble the model, optimizer and loss.
    net = BiLSTM(vocab_size=num_words, labels_size=2)
    # 参数量 = parameter count
    print('参数量:{}'.format(get_parameter_number(net)))  # total:436.002 Thousand, trainable:436.002 Thousand
    print(net)
    trainable_params = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = torch.optim.Adam(trainable_params)
    loss_func = nn.CrossEntropyLoss()
    # Train; the same iterator doubles as the dev set in this smoke test.
    train_net(net,
              train_iter=data_iter,
              dev_iter=data_iter,
              max_epoch=5,
              optimizer=optimizer,
              loss_func=loss_func)
def test_train_use_pretrained_embedding(self):
    """Train a TextCNN classifier whose embedding layer is initialised from GloVe and frozen."""
    sep = '-' * 15
    print('{} test_train_use_pretrained_embedding {}'.format(sep, sep))
    # Load the pickled dataset; the full set is too big for CPU training, so slice it.
    data = load_serialize_obj('./data/aclImdb/test_data.pkl')
    data = data[:1000]
    tokenized = get_tokenized_imdb(imdb_data=data)
    vocab = get_tokenized_vocab(tokenized)
    num_words = len(vocab)
    print('vocab len:{}'.format(num_words))  # vocab len:4345
    data_iter = get_imdb_data_iter(data, vocab, batch_size=8, shuffle=True)
    print('test_iter len:{}'.format(len(data_iter)))  # test_iter len:125
    # Build the model.
    net = TextCNN(vocab_size=num_words, labels_size=2)
    # 参数量 = parameter count
    print('参数量:{}'.format(get_parameter_number(net)))  # total:263.152 Thousand, trainable:263.152 Thousand
    # Initialise the embedding layer from pretrained GloVe vectors.
    glove_embedding = torchtext.vocab.GloVe(name='6B',
                                            dim=50,
                                            cache='./data/torchtext')
    # "一共包含%d个词" = contains %d words in total (400000 for glove.6B)
    print("glove_embedding 一共包含%d个词。" % len(glove_embedding.stoi))
    embed = load_pretrained_embedding(words=vocab.itos,
                                      pretrained_vocab=glove_embedding)  # There are 1004 oov words.
    net.embedding.weight.data.copy_(embed)
    # The pretrained vectors are kept fixed, so exclude them from gradient updates.
    net.embedding.weight.requires_grad = False
    print('参数量:{}'.format(get_parameter_number(net)))  # total:263.152 Thousand, trainable:45.902 Thousand
    print(net)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, net.parameters()))
    loss_func = nn.CrossEntropyLoss()
    # Train; train and dev share the same iterator in this smoke test.
    train_net(net,
              train_iter=data_iter,
              dev_iter=data_iter,
              max_epoch=2,
              optimizer=optimizer,
              loss_func=loss_func)
    # Persist the trained model plus the vocabulary artefacts.
    init_file_path('./data/save/text_classify/textcnn/')
    torch.save(net, f='./data/save/text_classify/textcnn/model.pkl')
    save_vocab_words(
        vocab,
        file_name='./data/save/text_classify/textcnn/vocab_words.txt')
    save_serialize_obj(
        vocab,
        filename='./data/save/text_classify/textcnn/vocab.pkl')
def test_tokenized_vocab(self):
    """Build the vocabulary from the pre-tokenized dataset and probe two word ids."""
    sep = '-' * 15
    print('{} test_tokenized_vocab {}'.format(sep, sep))
    tokenized = load_serialize_obj('./data/aclImdb/test_data_tokenized.pkl')
    vocab = get_tokenized_vocab(tokenized)
    print('vocab len:{}'.format(len(vocab)))  # vocab len:45098
    # Spot-check the string-to-index mapping with two known words.
    print('overcome id:{}'.format(
        vocab.stoi.get('overcome', None)))  # overcome id:3753
    print('father id:{}'.format(
        vocab.stoi.get('father', None)))  # father id:475
def test_get_imdb_data_iter(self):
    """Preprocess the dataset into a DataLoader and inspect the shape of one batch."""
    sep = '-' * 15
    print('{} test_get_imdb_data_iter {}'.format(sep, sep))
    data = load_serialize_obj('./data/aclImdb/test_data.pkl')
    tokenized = get_tokenized_imdb(imdb_data=data)
    vocab = get_tokenized_vocab(tokenized)
    print('vocab len:{}'.format(len(vocab)))  # vocab len:45098
    data_iter = get_imdb_data_iter(data, vocab, batch_size=8, shuffle=True)
    print('test_iter len:{}'.format(len(data_iter)))  # test_iter len:3218
    # Peek at the first batch only, to confirm the tensor shapes.
    for X, y in data_iter:
        print('X', X.shape, 'y', y.shape)  # X torch.Size([8, 500]) y torch.Size([8])
        break