def call(self, y_true=None, y_pred=None, arguments=None):
  ''' compute metric '''
  label_path_file = arguments["label_vocab_path"]
  return "\n" + seq_classification_report(
      ids_to_sentences(y_true, label_path_file),
      ids_to_sentences(y_pred, label_path_file),
      digits=4)
def call(self, predictions, log_verbose=False):
  ''' main func entrypoint '''
  preds = predictions["preds"]
  output_index = predictions["output_index"]
  if output_index is None:
    res_file = self.config["solver"]["postproc"].get("res_file", "")
    label_path_file = self.config["data"]["task"]["label_vocab"]
  else:
    res_file = self.config["solver"]["postproc"][output_index].get(
        "res_file", "")
    label_path_file = self.config["data"]["task"]["label_vocab"][output_index]
  if res_file == "":
    logging.info(
        "Infer res not saved. You can check 'res_file' in your config.")
    return
  res_dir = os.path.dirname(res_file)
  if not os.path.exists(res_dir):
    os.makedirs(res_dir)
  logging.info("Save inference result to: {}".format(res_file))
  preds = ids_to_sentences(preds, label_path_file)
  with open(res_file, "w", encoding="utf-8") as in_f:
    for i, pre in enumerate(preds):
      entities = get_entities(pre)  # e.g. [('PER', 0, 1), ('LOC', 3, 3)]
      if not entities:
        in_f.write("Null")
      else:
        new_line = "\t".join(
            [" ".join(map(str, entity)) for entity in entities])
        in_f.write(new_line)
      in_f.write("\n")
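# Usage sketch, not part of the original module: the (type, start, end) tuples
# written above follow seqeval's `get_entities` contract. The import below
# assumes the helper used here is (or behaves like) seqeval's implementation.
from seqeval.metrics.sequence_labeling import get_entities as _get_entities_demo

print(_get_entities_demo(["B-PER", "I-PER", "O", "B-LOC"]))
# -> [('PER', 0, 1), ('LOC', 3, 3)]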
def test_ids_to_sentences(self):
  ''' test ids_to_sentences function '''
  config = utils.load_config(self.config_file)
  ids = [[2, 3, 1]]
  vocab_file_path = config["data"]["task"]["label_vocab"]
  sents = ids_to_sentences(ids, vocab_file_path)
  self.assertAllEqual(sents, [["I-PER", "B-LOC", "B-PER"]])
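# Hypothetical sketch, for illustration only: a minimal stand-in for
# `ids_to_sentences`, assuming the label vocab file stores one "label<TAB>id"
# pair per line. The real vocab format and helper in the project may differ;
# the underscore-prefixed name below is invented here.
def _ids_to_sentences_sketch(ids, vocab_file_path):
  id_to_label = {}
  with open(vocab_file_path, encoding="utf-8") as vocab_f:
    for row in vocab_f:
      label, idx = row.rstrip("\n").split("\t")
      id_to_label[int(idx)] = label
  return [[id_to_label[i] for i in sent] for sent in ids]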
def call(self, predictions, log_verbose=False):
  ''' main func entrypoint '''
  preds = predictions["preds"]
  res_file = self.config["solver"]["postproc"].get("res_file", "")
  if res_file == "":
    logging.info(
        "Infer res not saved. You can check 'res_file' in your config.")
    return
  res_dir = os.path.dirname(res_file)
  if not os.path.exists(res_dir):
    os.makedirs(res_dir)
  logging.info("Save inference result to: {}".format(res_file))
  self.task_config = self.config['data']['task']
  self.label_vocab_file_paths = self.task_config['label_vocab']
  if not isinstance(self.label_vocab_file_paths, list):
    self.label_vocab_file_paths = [self.label_vocab_file_paths]
  self.use_label_vocab = self.task_config['use_label_vocab']
  if self.use_label_vocab:
    label_path_file = self.label_vocab_file_paths[0]
  else:
    label_path_file = self.task_config["text_vocab"]
  preds = ids_to_sentences(preds, label_path_file)
  with open(res_file, "w", encoding="utf-8") as in_f:
    for i, pre in enumerate(preds):
      # drop trailing special tokens before writing the sentence
      while len(pre) > 1 and pre[-1] in ['<unk>', '<pad>', '<eos>']:
        pre.pop()
      pred_abs = ' '.join(pre)
      in_f.write(pred_abs)
      in_f.write("\n")
def call(self, predictions, log_verbose=False):
  ''' main func entrypoint '''
  preds = predictions["preds"]
  paths = self.config["data"]["infer"]["paths"]
  max_seq_len = self.config["data"]["task"]["max_seq_len"]
  text = []
  counter = 0
  for path in paths:
    with open(path, 'r', encoding='utf8') as file_input:
      for line in file_input.readlines():
        line = list(line.strip())
        if line:
          # truncate or pad each character sequence to max_seq_len
          if len(line) >= max_seq_len:
            line = line[:max_seq_len]
          else:
            line.extend(["unk"] * (max_seq_len - len(line)))
          text.append("".join(line))
          counter += 1
    logging.info("Load {} lines from {}.".format(str(counter), path))
  res_file = self.config["solver"]["postproc"].get("res_file", "")
  if res_file == "":
    logging.info(
        "Infer res not saved. You can check 'res_file' in your config.")
    return
  res_dir = os.path.dirname(res_file)
  if not os.path.exists(res_dir):
    os.makedirs(res_dir)
  logging.info("Save inference result to: {}".format(res_file))
  label_path_file = self.config["data"]["task"]["label_vocab"]
  preds = ids_to_sentences(preds, label_path_file)
  assert len(text) == len(preds)
  with open(res_file, "w", encoding="utf-8") as in_f:
    for i, pre in enumerate(preds):
      entity_dict = {}
      entities = get_entities(pre)  # e.g. [('PER', 0, 1), ('LOC', 3, 3)]
      for entity_tuple in entities:
        # recover the surface string from the original characters
        entity = "".join(
            [text[i][j] for j in range(entity_tuple[1], entity_tuple[2] + 1)])
        if entity_tuple[0] in entity_dict:
          entity_dict[entity_tuple[0]].append(entity)
        else:
          entity_dict[entity_tuple[0]] = [entity]
      in_f.write(str(entity_dict))
      in_f.write("\n")
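# Usage sketch, not part of the original module: how tag spans map back onto
# the raw characters, mirroring the entity_dict loop above. The text and tags
# below are made-up examples; get_entities is assumed to behave like seqeval's
# implementation.
from seqeval.metrics.sequence_labeling import get_entities as _get_entities_ner

demo_text = "张三住在北京"
demo_tags = ["B-PER", "I-PER", "O", "O", "B-LOC", "I-LOC"]
demo_entities = {}
for etype, start, end in _get_entities_ner(demo_tags):
  surface = "".join(demo_text[j] for j in range(start, end + 1))
  demo_entities.setdefault(etype, []).append(surface)
print(demo_entities)  # {'PER': ['张三'], 'LOC': ['北京']}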
def call(self, y_true=None, y_pred=None, arguments=None):
  ''' compute metric '''
  label_path_file = self.config["data"]["task"]["label_vocab"]
  return seq_classification_report(
      ids_to_sentences(y_true, label_path_file),
      ids_to_sentences(y_pred, label_path_file))
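# Usage sketch (assumption: `seq_classification_report` wraps seqeval's
# `classification_report`, which matches the call signatures above).
from seqeval.metrics import classification_report

demo_y_true = [["B-PER", "I-PER", "O", "B-LOC"]]
demo_y_pred = [["B-PER", "I-PER", "O", "O"]]
print(classification_report(demo_y_true, demo_y_pred, digits=4))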