def train(epochs=100, batch_size=64, lr=0.001, eval_every=90,
          seq_length=600, save_path='model_params.pkl'):
    """Train the text classifier and checkpoint the best validation weights.

    Reads ``train_dir``, ``val_dir``, ``word_to_id`` and ``cat_to_id`` from
    module scope, fits a ``TextRNN`` with Adam, and validates periodically.
    Whenever validation accuracy improves, the model's ``state_dict`` is
    written to ``save_path``.

    Args:
        epochs: Number of passes over the training data.
        batch_size: Batch size handed to ``batch_iter``.
        lr: Learning rate for the Adam optimizer.
        eval_every: Run validation every this many batches; the batch
            counter restarts at the beginning of each epoch.
        seq_length: Sequence length passed to ``process_file``.
        save_path: Destination file for the best model parameters.

    Returns:
        None. Progress is printed; the best weights are saved to disk.
    """
    # Per-character ids of the training data and the one-hot labels.
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id,
                                    seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, seq_length)

    # Use the LSTM or the CNN model (TextCNN() can be swapped in here).
    model = TextRNN()
    model.train()

    # Loss function (nn.BCELoss / nn.MSELoss were the alternatives considered).
    Loss = nn.MultiLabelSoftMarginLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_val_acc = 0
    for epoch in range(epochs):
        print('epoch:{}'.format(epoch))
        batches = batch_iter(x_train, y_train, batch_size)
        # Count batches from 1 so validation first fires after `eval_every`
        # optimizer steps rather than on the very first batch.
        for step, (x_batch, y_batch) in enumerate(batches, start=1):
            x = torch.LongTensor(np.array(x_batch))
            y = torch.Tensor(np.array(y_batch))
            out = model(x)
            loss = Loss(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Validate the model.
            if step % eval_every == 0:
                # No optimizer is needed for evaluation.
                val_loss, val_acc = evaluate(model, Loss, x_val, y_val)
                # `evaluate` is defined elsewhere; if it switches the model
                # to eval mode, training would silently continue in that mode.
                # Re-entering training mode is a no-op otherwise.
                model.train()
                print('loss:{},accracy:{}'.format(val_loss, val_acc))
                if val_acc > best_val_acc:
                    torch.save(model.state_dict(), save_path)
                    best_val_acc = val_acc
 def __init__(self):
     """Load the category and vocabulary lookups and restore saved weights.

     Stores the pairs returned by ``read_category()`` and
     ``read_vocab(vocab_dir)`` on the instance, then builds a ``TextRNN``
     and fills it with the parameters saved in ``model_params.pkl``.
     """
     category_info = read_category()
     self.categories, self.cat_to_id = category_info

     vocab_info = read_vocab(vocab_dir)
     self.words, self.word_to_id = vocab_info

     # Build the network first, then overwrite its freshly initialised
     # parameters with the saved state dict.
     network = TextRNN()
     self.model = network
     saved_state = torch.load('model_params.pkl')
     network.load_state_dict(saved_state)