def train(*, epochs=100, batch_size=64, eval_every=90, lr=0.001, seq_length=600):
    """Train a TextRNN classifier, checkpointing the best model on the validation set.

    Training and validation data come from ``process_file`` using the
    module-level ``train_dir`` / ``val_dir`` paths and the ``word_to_id`` /
    ``cat_to_id`` mappings. The model is optimised with Adam against
    ``nn.MultiLabelSoftMarginLoss``. Every ``eval_every`` batches it is scored
    with ``evaluate``. Whenever validation accuracy improves, the model's
    ``state_dict`` is written to ``model_params.pkl``.

    Args:
        epochs: Number of passes over the training data.
        batch_size: Mini-batch size handed to ``batch_iter``.
        eval_every: Validate after every this many batches within an epoch.
        lr: Adam learning rate.
        seq_length: Length argument passed to ``process_file``
            (NOTE(review): presumably the pad/truncate length — confirm).
    """
    # Per-sample word ids and label encodings (the original comment calls the
    # labels one-hot — confirm against process_file).
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, seq_length)

    # An RNN (TextRNN) or a CNN (TextCNN) model can be used; RNN is selected.
    model = TextRNN()

    # Loss function; nn.BCELoss() and nn.MSELoss() were also tried.
    loss_fn = nn.MultiLabelSoftMarginLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = 0
    for epoch in range(epochs):
        print('epoch:{}'.format(epoch))
        model.train()
        batches = batch_iter(x_train, y_train, batch_size)
        for step, (x_batch, y_batch) in enumerate(batches, start=1):
            # as_tensor converts whatever integer dtype numpy picked; the legacy
            # torch.LongTensor(ndarray) constructor rejects non-int64 arrays.
            x = torch.as_tensor(np.array(x_batch), dtype=torch.long)
            y = torch.as_tensor(np.array(y_batch), dtype=torch.float32)

            out = model(x)
            loss = loss_fn(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Validate the model periodically.
            if step % eval_every == 0:
                # Score in eval mode without building autograd graphs, then
                # explicitly return to train mode so the remaining batches are
                # not trained with dropout etc. disabled.
                model.eval()
                with torch.no_grad():
                    # No optimizer is needed for evaluation.
                    val_loss, val_acc = evaluate(model, loss_fn, x_val, y_val)
                model.train()
                print('loss:{},accracy:{}'.format(val_loss, val_acc))
                if val_acc > best_val_acc:
                    torch.save(model.state_dict(), 'model_params.pkl')
                    best_val_acc = val_acc
def __init__(self):
    """Load the label/vocabulary mappings and the trained TextRNN weights.

    Reads the categories via ``read_category()`` and the vocabulary from the
    module-level ``vocab_dir`` via ``read_vocab``. It then builds a
    ``TextRNN`` and restores its parameters from ``model_params.pkl``, the
    file ``train()`` writes its best checkpoint to.
    """
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    self.model = TextRNN()
    # map_location='cpu' lets a checkpoint saved on a GPU machine load on a
    # CPU-only host; load_state_dict then copies it into the model's tensors.
    self.model.load_state_dict(torch.load('model_params.pkl', map_location='cpu'))
    # The freshly loaded model is used for prediction (NOTE(review): this class
    # appears to be the inference wrapper — confirm), so switch off
    # training-only behaviour such as dropout, if TextRNN has any.
    self.model.eval()