Code example #1
File: transform.py  Project: lei1993/HanLP
 def __call__(self, sample: dict) -> dict:
     src = sample[self.src]
     if isinstance(src, str):
         dst = not ispunct(src)
     else:
         dst = [not ispunct(x) for x in src]
     sample[self.dst] = dst
     return sample
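The transform flags non-punctuation tokens: it reads sample[self.src] (a single token or a list of tokens) and writes a boolean, or list of booleans, to sample[self.dst]. Below is a minimal self-contained sketch of the same idea, using a stand-in ispunct based on Unicode categories and a hypothetical wrapper class; the names are illustrative, not HanLP's API.

import unicodedata

def ispunct(token: str) -> bool:
    # Stand-in for HanLP's ispunct: True if every character is Unicode punctuation.
    return all(unicodedata.category(ch).startswith('P') for ch in token)

class NonPunctTransform:
    # Hypothetical self-contained version of the __call__ above:
    # src/dst name the sample keys to read from and write to.
    def __init__(self, src='token', dst='non_punct'):
        self.src, self.dst = src, dst

    def __call__(self, sample: dict) -> dict:
        src = sample[self.src]
        if isinstance(src, str):
            dst = not ispunct(src)
        else:
            dst = [not ispunct(x) for x in src]
        sample[self.dst] = dst
        return sample

sample = NonPunctTransform()({'token': ['HanLP', '，', 'NLP', '。']})
print(sample['non_punct'])  # [True, False, True, False]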
Code example #2
File: eos.py  Project: lei1993/HanLP
    def load_file(self, filepath: str):
        """Load eos corpus.

        Args:
            filepath: Path to the corpus.

        .. highlight:: bash
        .. code-block:: bash

            $ head -n 2 ctb8.txt
            中国经济简讯
            新华社北京十月二十九日电中国经济简讯

        """
        f = TimingFileIterator(filepath)
        sents = []
        eos_offsets = []
        offset = 0
        for line in f:
            if not line.strip():
                continue
            line = line.rstrip('\n')
            eos_offsets.append(offset + len(line.rstrip()) - 1)
            offset += len(line)
            if self.append_after_sentence:
                line += self.append_after_sentence
                offset += len(self.append_after_sentence)
            f.log(line)
            sents.append(line)
        f.erase()
        corpus = list(itertools.chain.from_iterable(sents))

        if self.eos_chars:
            if not isinstance(self.eos_chars, set):
                self.eos_chars = set(self.eos_chars)
        else:
            eos_chars = Counter()
            for i in eos_offsets:
                eos_chars[corpus[i]] += 1
            self.eos_chars = set(k for (k, v) in eos_chars.most_common()
                                 if v >= self.eos_char_min_freq and (
                                     not self.eos_char_is_punct or ispunct(k)))
            cprint(f'eos_chars = [yellow]{self.eos_chars}[/yellow]')

        eos_index = 0
        eos_offsets = [i for i in eos_offsets if corpus[i] in self.eos_chars]
        window_size = self.window_size
        for i, c in enumerate(corpus):
            if c in self.eos_chars:
                window = corpus[i - window_size:i + window_size + 1]
                label_id = 1. if eos_offsets[eos_index] == i else 0.
                if label_id > 0:
                    eos_index += 1
                yield {'char': window, 'label_id': label_id}
        assert eos_index == len(
            eos_offsets), f'{eos_index} != {len(eos_offsets)}'
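load_file is a generator: it concatenates the corpus into one character list, and for every occurrence of a candidate end-of-sentence character it yields a window of 2 * window_size + 1 characters plus a binary label saying whether that occurrence really ends a sentence. A toy walk-through of just that labelling step, on made-up data (plain Python, no HanLP dependencies):

# Toy data: the first '。' is not a real sentence end (think of a stray '。'
# inside "3。14"), so it is a candidate that gets label 0.
corpus = list('3。14是圆周率。好。')
eos_offsets = [8, 10]   # offsets that genuinely end a sentence
eos_chars = {'。'}
window_size = 2

eos_index = 0
for i, c in enumerate(corpus):
    if c in eos_chars:
        # Clamp the left edge at 0 for this toy; the loader above slices directly.
        window = corpus[max(i - window_size, 0):i + window_size + 1]
        label_id = 1. if eos_index < len(eos_offsets) and eos_offsets[eos_index] == i else 0.
        if label_id > 0:
            eos_index += 1
        print(''.join(window), label_id)
# 3。14 0.0
# 周率。好。 1.0
# 。好。 1.0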
Code example #3
File: chunking_dataset.py  Project: lei1993/HanLP
 def __init__(self,
              data: Union[str, List],
              transform: Union[Callable, List] = None,
              cache=None,
              generate_idx=None,
              max_seq_len=None,
              sent_delimiter=None) -> None:
     if not sent_delimiter:
         sent_delimiter = lambda x: ispunct(x)
     elif isinstance(sent_delimiter, str):
         # Bind the character set to its own name: a lambda closing over the
         # rebound `sent_delimiter` would test membership in the lambda itself.
         delimiter_set = set(sent_delimiter)
         sent_delimiter = lambda x: x in delimiter_set
     self.sent_delimiter = sent_delimiter
     self.max_seq_len = max_seq_len
     super().__init__(data, transform, cache, generate_idx)
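The constructor normalizes sent_delimiter into a predicate: None falls back to a punctuation test, a string becomes a membership test over its characters, and a callable is kept as-is. A standalone sketch of that normalization, again with a stand-in ispunct (illustrative names only):

import unicodedata

def ispunct(token: str) -> bool:
    # Stand-in for HanLP's ispunct.
    return all(unicodedata.category(ch).startswith('P') for ch in token)

def normalize_sent_delimiter(sent_delimiter=None):
    # Mirrors the constructor above: None -> punctuation test,
    # str -> membership in its character set, callable -> unchanged.
    if not sent_delimiter:
        return ispunct
    if isinstance(sent_delimiter, str):
        delimiter_set = set(sent_delimiter)
        return lambda x: x in delimiter_set
    return sent_delimiter

is_delim = normalize_sent_delimiter('。！？')
print(is_delim('。'), is_delim('中'))  # True False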
Code example #4
File: conll.py  Project: zuoqy/HanLP
 def lock_vocabs(self):
     super().lock_vocabs()
     self.puncts = tf.constant(
         [i for s, i in self.form_vocab.token_to_idx.items() if ispunct(s)],
         dtype=tf.int64)
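lock_vocabs gathers the vocabulary ids of all punctuation tokens into a tf.constant, presumably so punctuation can be masked out later (for example when evaluating dependency arcs). A toy reconstruction with a hand-made vocabulary and a stand-in ispunct:

import unicodedata
import tensorflow as tf

def ispunct(token: str) -> bool:
    # Stand-in for HanLP's ispunct.
    return all(unicodedata.category(ch).startswith('P') for ch in token)

token_to_idx = {'<pad>': 0, '<unk>': 1, '我': 2, '，': 3, '。': 4}  # toy vocab
puncts = tf.constant(
    [i for s, i in token_to_idx.items() if ispunct(s)],
    dtype=tf.int64)
print(puncts)  # tf.Tensor([3 4], shape=(2,), dtype=int64)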
Code example #5
def evalb(gold_trees, predicted_trees, ref_gold_path=None, evalb_dir=None):
    if not evalb_dir:
        evalb_dir = get_evalb_dir()
    assert os.path.exists(evalb_dir)
    evalb_program_path = os.path.join(evalb_dir, "evalb")
    evalb_spmrl_program_path = os.path.join(evalb_dir, "evalb_spmrl")
    assert os.path.exists(evalb_program_path) or os.path.exists(
        evalb_spmrl_program_path)

    if os.path.exists(evalb_program_path):
        # evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
        evalb_param_path = os.path.join(evalb_dir, "nk.prm")
    else:
        evalb_program_path = evalb_spmrl_program_path
        evalb_param_path = os.path.join(evalb_dir, "spmrl.prm")

    assert os.path.exists(evalb_program_path)
    assert os.path.exists(evalb_param_path)

    assert len(gold_trees) == len(predicted_trees)
    for gold_tree, predicted_tree in zip(gold_trees, predicted_trees):
        assert isinstance(gold_tree, trees.TreebankNode)
        assert isinstance(predicted_tree, trees.TreebankNode)
        gold_leaves = list(gold_tree.leaves())
        predicted_leaves = list(predicted_tree.leaves())
        assert len(gold_leaves) == len(predicted_leaves)
        for gold_leaf, predicted_leaf in zip(gold_leaves, predicted_leaves):
            if gold_leaf.word != predicted_leaf.word:
                # Maybe -LRB- => (
                if ispunct(predicted_leaf.word):
                    gold_leaf.word = predicted_leaf.word
                else:
                    print(
                        f'Predicted word {predicted_leaf.word} does not match gold word {gold_leaf.word}'
                    )
        # assert all(
        #     gold_leaf.word == predicted_leaf.word
        #     for gold_leaf, predicted_leaf in zip(gold_leaves, predicted_leaves))

    temp_dir = tempfile.TemporaryDirectory(prefix="evalb-")
    gold_path = os.path.join(temp_dir.name, "gold.txt")
    predicted_path = os.path.join(temp_dir.name, "predicted.txt")
    output_path = os.path.join(temp_dir.name, "output.txt")

    # DELETE
    # predicted_path = 'tmp_predictions.txt'
    # output_path = 'tmp_output.txt'
    # gold_path = 'tmp_gold.txt'

    with open(gold_path, "w") as outfile:
        if ref_gold_path is None:
            for tree in gold_trees:
                outfile.write("{}\n".format(tree.linearize()))
        else:
            with open(ref_gold_path) as goldfile:
                outfile.write(goldfile.read())

    with open(predicted_path, "w") as outfile:
        for tree in predicted_trees:
            outfile.write("{}\n".format(tree.linearize()))

    command = "{} -p {} {} {} > {}".format(
        evalb_program_path,
        evalb_param_path,
        gold_path,
        predicted_path,
        output_path,
    )
    # print(command)
    subprocess.run(command, shell=True)

    fscore = FScore(math.nan, math.nan, math.nan)
    with open(output_path) as infile:
        for line in infile:
            match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.recall = float(match.group(1)) / 100
            match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.precision = float(match.group(1)) / 100
            match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.fscore = float(match.group(1)) / 100
                break

    success = (not math.isnan(fscore.fscore) or fscore.recall == 0.0
               or fscore.precision == 0.0)

    if success:
        temp_dir.cleanup()
    else:
        # print("Error reading EVALB results.")
        # print("Gold path: {}".format(gold_path))
        # print("Predicted path: {}".format(predicted_path))
        # print("Output path: {}".format(output_path))
        pass

    return fscore
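The wrapper writes both tree sets to temporary files, shells out to EVALB, and recovers precision, recall and F1 by regex-matching the report it produced. A minimal sketch of just that parsing step, run on a fabricated report fragment (the numbers are invented for illustration):

import math
import re

# Fabricated EVALB report fragment, only to exercise the regexes above.
report = """Bracketing Recall       =  88.80
Bracketing Precision    =  90.90
Bracketing FMeasure     =  89.84"""

recall = precision = f1 = math.nan
for line in report.splitlines():
    match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
    if match:
        recall = float(match.group(1)) / 100
    match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
    if match:
        precision = float(match.group(1)) / 100
    match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
    if match:
        f1 = float(match.group(1)) / 100
        break

print(recall, precision, f1)  # approximately 0.888 0.909 0.8984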