if args.snapshot is None:
    # No snapshot given: build a fresh model from the first model-type flag
    # that is set on ``args``. Flags are tested in this fixed order, so if
    # several are set, the earliest one in the chain wins.
    if args.CNN:
        print("loading CNN model.....")
        model = model_CNN.CNN_Text(args)
    elif args.DEEP_CNN:
        print("loading DEEP_CNN model......")
        model = model_DeepCNN.DEEP_CNN(args)
    elif args.LSTM:
        print("loading LSTM model......")
        model = model_LSTM.LSTM(args)
    elif args.GRU:
        print("loading GRU model......")
        model = model_GRU.GRU(args)
    elif args.BiLSTM:
        print("loading BiLSTM model......")
        model = model_BiLSTM.BiLSTM(args)
    elif args.BiLSTM_1:
        print("loading BiLSTM_1 model......")
        model = model_BiLSTM_1.BiLSTM_1(args)
    elif args.CNN_LSTM:
        print("loading CNN_LSTM model......")
        model = model_CNN_LSTM.CNN_LSTM(args)
    elif args.CLSTM:
        print("loading CLSTM model......")
        model = model_CLSTM.CLSTM(args)
    elif args.CBiLSTM:
        print("loading CBiLSTM model......")
        model = model_CBiLSTM.CBiLSTM(args)
    elif args.CGRU:
        print("loading CGRU model......")
        model = model_CGRU.CGRU(args)
    else:
        # Without this branch ``model`` would stay unbound, and any later use
        # would fail with a NameError far from the real cause. Fail here with
        # an actionable message instead.
        raise ValueError(
            "args.snapshot is None and no model type flag is set; enable one of: "
            "CNN, DEEP_CNN, LSTM, GRU, BiLSTM, BiLSTM_1, CNN_LSTM, CLSTM, "
            "CBiLSTM, CGRU"
        )
# Wrap the index lists built earlier in the script into LongTensor Variables.
# NOTE(review): the author's old debug notes recorded train shapes of about
# 2414x6 (topic), 2414x35 (text) and 2414x1 (label) for one dataset; this
# presumably means (examples, tokens) -- confirm before relying on it.
train_topic_var = Variable(torch.LongTensor(topic_index))
train_text_var = Variable(torch.LongTensor(text_index))
train_label_var = Variable(torch.LongTensor(label_index))

test_topic_var = Variable(torch.LongTensor(test_topic_index))
test_text_var = Variable(torch.LongTensor(test_text_index))
test_label_var = Variable(torch.LongTensor(test_label_index))

# Only the test split is pre-batched here (using params.batch_size); the
# training tensors are passed to train2.train as-is. A dev split used to be
# prepared the same way but was disabled.
test_iter = dataProcessing.create_batches(
    test_topic_var, test_text_var, test_label_var, params.batch_size
)

# BiLSTM is the only model this script builds. The identity checks
# (``is True``) are kept deliberately: only the bool True enables an option.
if params.use_lstm is True:
    model = model_BiLSTM.BiLSTM(params)
else:
    # Previously ``model`` stayed unbound here, and the next line failed with
    # a confusing NameError.
    raise ValueError(
        "params.use_lstm must be True: BiLSTM is the only model this script builds"
    )

if params.cuda_use is True:
    model = model.cuda()

train2.train(
    train_topic_var,
    train_text_var,
    train_label_var,
    model,
    label2id,
    id2label,
    params,
    test_iter,
)