Пример #1
0
class RnnModel:
    """Inference wrapper around a trained ``TextRNN`` text classifier.

    On construction the category list, the vocabulary and the weights stored
    in ``model_params.pkl`` are loaded once; ``predict`` then maps a message
    to one of the loaded category labels.
    """

    def __init__(self):
        """Load categories, vocabulary and the trained model weights."""
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextRNN()
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be loaded here.
        self.model.load_state_dict(torch.load('model_params.pkl'))
        # Switch to inference mode: layers such as dropout or batch-norm
        # otherwise keep their training behaviour and make predictions
        # non-deterministic.
        self.model.eval()

    def predict(self, message):
        """Return the predicted category for ``message``.

        Each character of the message is looked up in ``self.word_to_id``;
        characters that are not in the vocabulary are skipped. The id
        sequence is passed through ``pad_sequences`` with a length of 600
        before being fed to the model.

        Args:
            message: The text to classify.

        Returns:
            The entry of ``self.categories`` whose index has the highest
            model score for the message.
        """
        # Lets models trained under Python 2 or Python 3 run in either
        # environment.
        # NOTE(review): `unicode` is not a Python 3 builtin; presumably it is
        # aliased to `str` elsewhere in this file -- confirm.
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]
        data = kr.preprocessing.sequence.pad_sequences([data], 600)
        data = torch.LongTensor(data)
        # No gradients are needed for inference; skip building the graph.
        with torch.no_grad():
            y_pred_cls = self.model(data)
        # The batch holds a single message, so only row 0 is inspected.
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]
Пример #2
0
def train(num_epochs=100, batch_size=64, seq_length=600, learning_rate=0.001,
          eval_every=90, save_path='model_params.pkl'):
    """Train a ``TextRNN`` and checkpoint the best-validating weights.

    Training and validation data are read with ``process_file`` from the
    module-level ``train_dir`` / ``val_dir``. Every ``eval_every`` batches
    (counted from the start of each epoch) the model is scored with
    ``evaluate``; whenever the validation accuracy beats the best value seen
    so far, the model's ``state_dict`` is written to ``save_path``.

    All parameters default to the values that were previously hard-coded, so
    calling ``train()`` without arguments behaves as before.

    Args:
        num_epochs: Number of passes over the training data.
        batch_size: Batch size handed to ``batch_iter``.
        seq_length: Sequence length handed to ``process_file``.
        learning_rate: Learning rate for the Adam optimizer.
        eval_every: Validate after this many batches; must be positive.
        save_path: File the best ``state_dict`` is saved to.
    """
    # Per the original note: ids of every character of the training data and
    # the matching labels in one-hot form.
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id,
                                    seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, seq_length)

    # The LSTM model is used here; the original notes TextCNN() as the
    # alternative.
    model = TextRNN()
    # Loss function; the original notes nn.BCELoss() and nn.MSELoss() as
    # alternatives.
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_acc = 0
    for epoch in range(num_epochs):
        print('epoch:{}'.format(epoch))
        batch_train = batch_iter(x_train, y_train, batch_size)
        # The step counter starts at 1 and restarts every epoch.
        for step, (x_batch, y_batch) in enumerate(batch_train, 1):
            # NOTE(review): evaluate() is not visible here and may leave the
            # model in eval mode; re-assert training mode before each step.
            model.train()
            x = torch.LongTensor(np.array(x_batch))
            y = torch.Tensor(np.array(y_batch))

            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Validate the model periodically and keep the best weights.
            if step % eval_every == 0:
                val_loss, val_acc = evaluate(model, criterion, optimizer,
                                             x_val, y_val)
                print('loss:{},accracy:{}'.format(val_loss, val_acc))
                if val_acc > best_val_acc:
                    torch.save(model.state_dict(), save_path)
                    best_val_acc = val_acc
Пример #3
0
 def __init__(self):
     """Load categories, vocabulary and the trained ``TextRNN`` weights."""
     self.categories, self.cat_to_id = read_category()
     self.words, self.word_to_id = read_vocab(vocab_dir)
     self.model = TextRNN()
     # NOTE(review): torch.load unpickles the file, so only checkpoints from
     # a trusted source should be loaded here.
     self.model.load_state_dict(torch.load('model_params.pkl'))
     # The weights are loaded for prediction: switch to inference mode so
     # layers such as dropout or batch-norm stop behaving as in training.
     self.model.eval()