def infer():
    """Run inference with the trained ``AdvSeqLabel`` model and print the predictions.

    Steps:
      1. Read the ``POS_test`` section of ``cfgfile`` into a ``ConfigSection``.
      2. Size the model from the ``word2id.pkl`` and ``label2id.pkl`` pickles
         found under ``pickle_path``.
      3. Load trained weights from ``./save/trained_model.pkl``.
      4. Load the inputs in ``data_infer_path`` (via ``BaseLoader.load_lines``)
         and print the output of ``SeqLabelInfer.predict``.

    ``cfgfile``, ``pickle_path`` and ``data_infer_path`` are not defined in this
    function; presumably they are module-level globals defined elsewhere in the
    file — confirm before moving this function.

    Raises:
        Exception: anything raised by ``ModelLoader.load_pytorch`` is re-raised
            unchanged after ``'cannot load model!'`` is printed.
    """
    # Config Loader
    test_args = ConfigSection()
    ConfigLoader().load_config(cfgfile, {"POS_test": test_args})

    # fetch dictionary size and number of labels from pickle files
    word2index = load_pickle(pickle_path, "word2id.pkl")
    test_args["vocab_size"] = len(word2index)
    # The file is label2id.pkl, so name the mapping accordingly.
    label2index = load_pickle(pickle_path, "label2id.pkl")
    test_args["num_classes"] = len(label2index)

    # Define the same model
    model = AdvSeqLabel(test_args)
    try:
        ModelLoader.load_pytorch(model, "./save/trained_model.pkl")
    except Exception:
        # Tell the user which step failed, then propagate the original error.
        print('cannot load model!')
        raise
    else:
        print('model loaded!')

    # Data Loader
    infer_data = SeqLabelDataSet(load_func=BaseLoader.load_lines)
    infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)
    print('data loaded')

    # Inference interface.
    # Named `inferencer` so it does not shadow this function's own name.
    inferencer = SeqLabelInfer(pickle_path)
    results = inferencer.predict(model, infer_data)
    print(results)
    print("Inference finished!")
def predict():
    """Evaluate the trained ``AdvSeqLabel`` model on the pickled dev set.

    Builds the model configuration from the ``POS_test`` section of
    ``cfgfile``. ``vocab_size`` and ``num_classes`` are taken from the sizes of
    the ``word2id.pkl`` and ``label2id.pkl`` pickles under ``pickle_path``.
    Trained weights are loaded from ``./save/trained_model.pkl``. The model is
    then run through a ``SeqLabelTester`` (with a ``SeqLabelEvaluator``) on
    ``data_dev.pkl``.

    ``cfgfile`` and ``pickle_path`` are free names here; presumably they are
    module-level globals defined elsewhere in the file — confirm.
    """
    config = ConfigSection()
    ConfigLoader().load_config(cfgfile, {"POS_test": config})

    # Model dimensions come from the sizes of the pickled vocab/label maps.
    vocab = load_pickle(pickle_path, "word2id.pkl")
    config["vocab_size"] = len(vocab)
    labels = load_pickle(pickle_path, "label2id.pkl")
    config["num_classes"] = len(labels)

    dev_set = load_pickle(pickle_path, "data_dev.pkl")

    # Rebuild the architecture used in training, then restore its weights.
    net = AdvSeqLabel(config)
    ModelLoader.load_pytorch(net, "./save/trained_model.pkl")
    print("model loaded!")

    # Every config entry (plus the evaluator) is forwarded to the tester as a keyword argument.
    config["evaluator"] = SeqLabelEvaluator()
    SeqLabelTester(**config.data).test(net, dev_set)
import os

import torch
from fastNLP import Trainer
from fastNLP import Tester
from fastNLP import CrossEntropyLoss
from fastNLP import Adam
from fastNLP import AccuracyMetric
from fastNLP.io.config_io import ConfigSection, ConfigLoader
import fastNLP.core.utils as util
from model import myESIM

# Training settings come from the "train" section of the JSON config.
args = ConfigSection()
ConfigLoader().load_config("../data/config.json", {"train": args})

# Load the training and validation datasets and the word embeddings.
# `os` is imported above: without it, os.path.normpath raised NameError.
data_dir = os.path.normpath(args["data_dir"])
# NOTE(review): util.load_pickle presumably unpickles these files; only point
# data_dir at trusted data.
print("\t* Loading train data...")
train_data = util.load_pickle(data_dir, args["train_file"])
print("\t* Loading dev data...")
dev_data = util.load_pickle(data_dir, args["dev_file"])
print("\t* Loading word embeddings...")
embeddings = util.load_pickle(data_dir, args["embeddings_file"])
embeddings = torch.Tensor(embeddings)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): embeddings is presumably (vocab_size, embedding_dim), which would
# make the first two positional args vocab size and embedding width; the
# meaning of 300 is defined by myESIM — confirm there.
model = myESIM(embeddings.shape[0],
               embeddings.shape[1],
               300,
               embeddings=embeddings,
               dropout=0.5,
               num_classes=3,
               device=device).to(device)