def infer():
    """Run sequence-labeling inference with an ``AdvSeqLabel`` model.

    Fills a config section from the ``POS_test`` entry of ``cfgfile``, sizes
    the model from the pickled ``word2id.pkl`` / ``label2id.pkl`` mappings,
    loads ``./save/trained_model.pkl`` into it, runs prediction over the data
    at ``data_infer_path`` and prints the results.

    ``cfgfile``, ``pickle_path`` and ``data_infer_path`` are read from names
    defined outside this function.

    :raises Exception: whatever ``ModelLoader.load_pytorch`` raises is
        re-raised after a short message is printed.
    """
    # The loader fills the section object in place.
    model_args = ConfigSection()
    ConfigLoader().load_config(cfgfile, {"POS_test": model_args})

    # Model dimensions are derived from the sizes of the pickled mappings.
    word_vocab = load_pickle(pickle_path, "word2id.pkl")
    model_args["vocab_size"] = len(word_vocab)
    label_vocab = load_pickle(pickle_path, "label2id.pkl")
    model_args["num_classes"] = len(label_vocab)

    # Build the network from the same arguments, then load the saved file.
    tagger = AdvSeqLabel(model_args)
    try:
        ModelLoader.load_pytorch(tagger, "./save/trained_model.pkl")
        print('model loaded!')
    except Exception:
        print('cannot load model!')
        raise

    # Only the word vocabulary is handed to the data set in inference mode.
    infer_data = SeqLabelDataSet(load_func=BaseLoader.load_lines)
    infer_data.load(
        data_infer_path,
        vocabs={"word_vocab": word_vocab},
        infer=True,
    )
    print('data loaded')

    # Named so it does not shadow this function's own name.
    predictor = SeqLabelInfer(pickle_path)
    predictions = predictor.predict(tagger, infer_data)
    print(predictions)
    print("Inference finished!")
def infer():
    """Run sequence-labeling inference with a ``SeqLabeling`` model.

    Fills a config section from the ``POS_infer`` entry (loader constructed
    with ``"config.cfg"``, path ``config_dir``), sizes the model from the
    pickled ``word2id.pkl`` / ``id2class.pkl`` objects, loads the saved model
    file ``model_name`` from ``pickle_path``, predicts over the lines read
    from ``data_infer_path`` and prints each result on its own line.

    ``config_dir``, ``pickle_path``, ``model_name`` and ``data_infer_path``
    are read from names defined outside this function.
    """
    # The loader fills the section object in place.
    infer_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {"POS_infer": infer_args})

    # Model dimensions are derived from the sizes of the pickled objects.
    word_vocab = load_pickle(pickle_path, "word2id.pkl")
    infer_args["vocab_size"] = len(word_vocab)
    class_vocab = load_pickle(pickle_path, "id2class.pkl")
    infer_args["num_classes"] = len(class_vocab)

    # Build the network, then load the saved parameters into it.
    tagger = SeqLabeling(infer_args)
    model_file = os.path.join(pickle_path, model_name)
    ModelLoader.load_pytorch(tagger, model_file)
    print("model loaded!")

    # Raw input is whatever the base loader's line reader returns.
    infer_data = BaseLoader(data_infer_path).load_lines()

    # Named so it does not shadow this function's own name.
    predictor = SeqLabelInfer(pickle_path)
    for prediction in predictor.predict(tagger, infer_data):
        print(prediction)
    print("Inference finished!")
def infer():
    """Run sequence-labeling inference with a ``SeqLabeling`` model.

    Fills a config section from the ``POS_infer`` entry of ``config_path``,
    sizes the model from the pickled ``word2id.pkl`` / ``label2id.pkl``
    mappings, loads ``./save/saved_model.pkl`` into it, predicts over the
    data at ``data_infer_path`` and prints the results.

    ``config_path``, ``pickle_path`` and ``data_infer_path`` are read from
    names defined outside this function.
    """
    # The loader fills the section object in place.
    infer_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": infer_args})

    # Model dimensions are derived from the sizes of the pickled mappings.
    word_vocab = load_pickle(pickle_path, "word2id.pkl")
    infer_args["vocab_size"] = len(word_vocab)
    label_vocab = load_pickle(pickle_path, "label2id.pkl")
    infer_args["num_classes"] = len(label_vocab)

    # Build the network, then load the saved parameters into it.
    tagger = SeqLabeling(infer_args)
    ModelLoader.load_pytorch(tagger, "./save/saved_model.pkl")
    print("model loaded!")

    # Only the word vocabulary is handed to the data set in inference mode.
    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
    infer_data.load(
        data_infer_path,
        vocabs={"word_vocab": word_vocab},
        infer=True,
    )

    # Named so it does not shadow this function's own name.
    predictor = SeqLabelInfer(pickle_path)
    predictions = predictor.predict(tagger, infer_data)
    print(predictions)
def _create_inference(self, model_dir):
    """Build the inference object that matches ``self.infer_type``.

    :param model_dir: forwarded unchanged to the chosen inference class.
    :return: a ``SeqLabelInfer`` for ``"seq_label"``, a
        ``ClassificationInfer`` for ``"text_class"``.
    :raises ValueError: if ``self.infer_type`` is neither of the two.
    """
    task = self.infer_type
    if task == "seq_label":
        return SeqLabelInfer(model_dir)
    if task == "text_class":
        return ClassificationInfer(model_dir)
    # Any other task name is unsupported.
    raise ValueError("fail to create inference instance")
def _create_inference(self, model_dir):
    """Select and construct the inference object for the configured task.

    :param model_dir: forwarded unchanged to the chosen inference class.
    :return: a ``SeqLabelInfer`` when ``self.infer_type`` is ``"seq_label"``,
        a ``ClassificationInfer`` when it is ``"text_class"``.
    :raises ValueError: if ``self.infer_type`` is neither of the two.
    """
    task = self.infer_type
    if task == "seq_label":
        return SeqLabelInfer(model_dir)
    if task == "text_class":
        return ClassificationInfer(model_dir)
    # Any other task name is unsupported.
    raise ValueError("fail to create inference instance")
def infer():
    """Run sequence-labeling inference with a ``SeqLabeling`` model.

    Fills a config section from the ``POS_infer`` entry of ``config_dir``,
    sizes the model from the pickled ``word2id.pkl`` / ``label2id.pkl``
    vocabularies, loads the saved model file ``model_name`` from
    ``pickle_path``, predicts over the data at ``data_infer_path`` and prints
    each result on its own line. Progress messages are printed between steps.

    ``config_dir``, ``pickle_path``, ``model_name`` and ``data_infer_path``
    are read from names defined outside this function.
    """
    # The loader fills the section object in place.
    infer_args = ConfigSection()
    ConfigLoader().load_config(config_dir, {"POS_infer": infer_args})

    # Both vocabularies are loaded first; their sizes fix the model dimensions.
    words = load_pickle(pickle_path, "word2id.pkl")
    labels = load_pickle(pickle_path, "label2id.pkl")
    infer_args["vocab_size"] = len(words)
    infer_args["num_classes"] = len(labels)
    print("vocabularies loaded")

    tagger = SeqLabeling(infer_args)
    print("model defined")

    # Load the saved parameters into the freshly built network.
    model_file = os.path.join(pickle_path, model_name)
    ModelLoader.load_pytorch(tagger, model_file)
    print("model loaded!")

    # Both vocabularies are handed to the data set in inference mode.
    vocabularies = {"word_vocab": words, "label_vocab": labels}
    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
    infer_data.load(data_infer_path, vocabs=vocabularies, infer=True)
    print("data set prepared")

    # Named so it does not shadow this function's own name.
    predictor = SeqLabelInfer(pickle_path)
    for prediction in predictor.predict(tagger, infer_data):
        print(prediction)
    print("Inference finished!")