Example #1
  def call(self, y_true=None, y_pred=None, arguments=None):
    ''' compute metric '''

    label_path_file = arguments["label_vocab_path"]
    return "\n" + seq_classification_report(
        ids_to_sentences(y_true, label_path_file),
        ids_to_sentences(y_pred, label_path_file),
        digits=4)
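All of these examples lean on ids_to_sentences to turn integer label ids back into tag strings via a vocab file. The real helper ships with the framework; as a rough, hypothetical sketch of what it does, assuming a plain-text vocab with one label per line whose zero-based line index is the id:

# Hypothetical sketch only, not the framework's actual implementation.
# Assumes the vocab file holds one label per line and that the zero-based
# line index serves as the label id.
def ids_to_sentences_sketch(ids, vocab_file_path):
  with open(vocab_file_path, encoding="utf-8") as vocab_f:
    id_to_label = [line.strip() for line in vocab_f if line.strip()]
  return [[id_to_label[i] for i in sent] for sent in ids]

# With a vocab file whose lines are O, B-PER, I-PER, B-LOC (in that order),
# ids_to_sentences_sketch([[2, 3, 1]], path) -> [['I-PER', 'B-LOC', 'B-PER']]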
Example #2
    def call(self, predictions, log_verbose=False):
        ''' main func entrypoint'''
        preds = predictions["preds"]
        output_index = predictions["output_index"]
        if output_index is None:
            res_file = self.config["solver"]["postproc"].get("res_file", "")
            label_path_file = self.config["data"]["task"]["label_vocab"]
        else:
            res_file = self.config["solver"]["postproc"][output_index].get(
                "res_file", "")
            label_path_file = self.config["data"]["task"]["label_vocab"][
                output_index]

        if res_file == "":
            logging.info(
                "Infer res not saved. You can check 'res_file' in your config."
            )
            return
        res_dir = os.path.dirname(res_file)
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        logging.info("Save inference result to: {}".format(res_file))

        preds = ids_to_sentences(preds, label_path_file)

        with open(res_file, "w", encoding="utf-8") as out_f:
            for pre in preds:
                entities = get_entities(pre)  # [('PER', 0, 1), ('LOC', 3, 3)]
                if not entities:
                    out_f.write("Null")
                else:
                    new_line = "\t".join(
                        [" ".join(map(str, entity)) for entity in entities])
                    out_f.write(new_line)
                out_f.write("\n")
Example #3
 def test_ids_to_sentences(self):
     ''' test ids_to_sentences function '''
     config = utils.load_config(self.config_file)
     ids = [[2, 3, 1]]
     vocab_file_path = config["data"]["task"]["label_vocab"]
     sents = ids_to_sentences(ids, vocab_file_path)
     self.assertAllEqual(sents, [["I-PER", "B-LOC", "B-PER"]])
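For the assertion to come out as [["I-PER", "B-LOC", "B-PER"]], the label vocab referenced by the test config must map id 1 to B-PER, 2 to I-PER and 3 to B-LOC. Under the hypothetical one-label-per-line format assumed in the sketch after Example #1, such a vocab file would simply read:

O
B-PER
I-PER
B-LOC

The real vocab file and its exact on-disk format come from the framework's test data; this listing is only meant to make the id-to-label mapping concrete.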
Example #4
  def call(self, predictions, log_verbose=False):
    ''' main func entrypoint'''
    preds = predictions["preds"]

    res_file = self.config["solver"]["postproc"].get("res_file", "")
    if res_file == "":
      logging.info("Infer res not saved. You can check 'res_file' in your config.")
      return
    res_dir = os.path.dirname(res_file)
    if not os.path.exists(res_dir):
      os.makedirs(res_dir)
    logging.info("Save inference result to: {}".format(res_file))
    self.task_config = self.config['data']['task']
    self.label_vocab_file_paths = self.task_config['label_vocab']
    if not isinstance(self.label_vocab_file_paths, list):
      self.label_vocab_file_paths = [self.label_vocab_file_paths]
    self.use_label_vocab = self.task_config['use_label_vocab']
    if self.use_label_vocab:
      label_path_file = self.label_vocab_file_paths[0]
    else:
      label_path_file = self.task_config["text_vocab"]
    preds = ids_to_sentences(preds, label_path_file)
    with open(res_file, "w", encoding="utf-8") as out_f:
      for pre in preds:
        # strip trailing special tokens before writing the prediction
        while len(pre) > 1 and pre[-1] in ['<unk>', '<pad>', '<eos>']:
          pre.pop()
        pred_abs = ' '.join(pre)
        out_f.write(pred_abs)
        out_f.write("\n")
Example #5
  def call(self, predictions, log_verbose=False):
    ''' main func entrypoint'''
    preds = predictions["preds"]
    paths = self.config["data"]["infer"]["paths"]
    max_seq_len = self.config["data"]["task"]["max_seq_len"]

    text = []
    counter = 0
    for path in paths:
      with open(path, 'r', encoding='utf8') as file_input:
        for line in file_input.readlines():
          line = list(line.strip())
          if line:
            if len(line) >= max_seq_len:
              line = line[:max_seq_len]
            else:
              line.extend(["unk"]*(max_seq_len-len(line)))
            text.append("".join(line))
            counter += 1
      logging.info("Load {} lines from {}.".format(str(counter), path))

    res_file = self.config["solver"]["postproc"].get("res_file", "")
    if res_file == "":
      logging.info("Infer res not saved. You can check 'res_file' in your config.")
      return
    res_dir = os.path.dirname(res_file)
    if not os.path.exists(res_dir):
      os.makedirs(res_dir)
    logging.info("Save inference result to: {}".format(res_file))
    label_path_file = self.config["data"]["task"]["label_vocab"]
    preds = ids_to_sentences(preds, label_path_file)

    assert len(text) == len(preds)

    with open(res_file, "w", encoding="utf-8") as out_f:
      for i, pre in enumerate(preds):
        entity_dict = {}
        entities = get_entities(pre)  # e.g. [('PER', 0, 1), ('LOC', 3, 3)]

        # group the extracted surface strings by entity type
        for entity_tuple in entities:
          entity = "".join(
              [text[i][j] for j in range(entity_tuple[1], entity_tuple[2] + 1)])
          if entity_tuple[0] in entity_dict:
            entity_dict[entity_tuple[0]].append(entity)
          else:
            entity_dict[entity_tuple[0]] = [entity]
        out_f.write(str(entity_dict))
        out_f.write("\n")
Example #6
 def call(self, y_true=None, y_pred=None, arguments=None):
   ''' compute metric '''
   label_path_file = self.config["data"]["task"]["label_vocab"]
   return seq_classification_report(
       ids_to_sentences(y_true, label_path_file),
       ids_to_sentences(y_pred, label_path_file))