Esempio n. 1
0
if args.snapshot is None:
    # No snapshot was given, so construct a fresh model from the first
    # truthy architecture flag on ``args``.  Each entry is
    # (flag attribute, banner to print, thunk returning the model class).
    # The thunks keep the module/class lookup lazy, so only the selected
    # architecture's module is touched.  Order matters: the first truthy
    # flag wins and later flags are never read.  If no flag is set,
    # ``model`` is left unbound by this section.
    _model_choices = (
        ("CNN", "loading CNN model.....",
         lambda: model_CNN.CNN_Text),
        ("DEEP_CNN", "loading DEEP_CNN model......",
         lambda: model_DeepCNN.DEEP_CNN),
        ("LSTM", "loading LSTM model......",
         lambda: model_LSTM.LSTM),
        ("GRU", "loading GRU model......",
         lambda: model_GRU.GRU),
        ("BiLSTM", "loading BiLSTM model......",
         lambda: model_BiLSTM.BiLSTM),
        ("BiLSTM_1", "loading BiLSTM_1 model......",
         lambda: model_BiLSTM_1.BiLSTM_1),
        ("CNN_LSTM", "loading CNN_LSTM model......",
         lambda: model_CNN_LSTM.CNN_LSTM),
        ("CLSTM", "loading CLSTM model......",
         lambda: model_CLSTM.CLSTM),
        ("CBiLSTM", "loading CBiLSTM model......",
         lambda: model_CBiLSTM.CBiLSTM),
        ("CGRU", "loading CGRU model......",
         lambda: model_CGRU.CGRU),
    )
    for _flag_name, _banner, _resolve_class in _model_choices:
        if getattr(args, _flag_name):
            print(_banner)
            model = _resolve_class()(args)
            break
    # Wrap the training index data as LongTensor Variables.
    # NOTE(review): the original author recorded shapes of 2414x6 (topic),
    # 2414x35 (text) and 2414x1 (label) for these three -- presumably
    # specific to their dataset; confirm before relying on them.
    train_topic_var, train_text_var, train_label_var = (
        Variable(torch.LongTensor(indices))
        for indices in (topic_index, text_index, label_index)
    )

    # Same conversion for the test split.  A dev split (dev_*_index) used to
    # be converted and batched in the same way but is currently disabled.
    test_topic_var, test_text_var, test_label_var = (
        Variable(torch.LongTensor(indices))
        for indices in (test_topic_index, test_text_index, test_label_index)
    )

    # Batch the test tensors; the resulting iterator is handed to the trainer.
    test_iter = dataProcessing.create_batches(
        test_topic_var,
        test_text_var,
        test_label_var,
        params.batch_size,
    )

    # Training only runs when ``params.use_lstm`` is exactly ``True``.
    if params.use_lstm is True:
        # Rebinds ``model`` to a BiLSTM built from ``params``.
        model = model_BiLSTM.BiLSTM(params)
        # Move the model to the GPU only when explicitly requested.
        if params.cuda_use is True:
            model = model.cuda()
        train2.train(
            train_topic_var,
            train_text_var,
            train_label_var,
            model,
            label2id,
            id2label,
            params,
            test_iter,
        )