def __call__(self, sample: dict) -> dict:
    """Mark whether the source field is NOT punctuation.

    Reads ``sample[self.src]``; writes a bool (for a single str) or a
    list of bools (for a sequence of strs) to ``sample[self.dst]``.

    Args:
        sample: A sample dict carrying the source field.

    Returns:
        The same dict, with the destination field filled in.
    """
    value = sample[self.src]
    if isinstance(value, str):
        flag = not ispunct(value)
    else:
        flag = [not ispunct(token) for token in value]
    sample[self.dst] = flag
    return sample
def load_file(self, filepath: str):
    """Load an EOS (end-of-sentence) corpus and yield training samples.

    Each yielded sample is a character window centered on a candidate EOS
    character, with ``label_id`` 1.0 if that position truly ends a sentence
    and 0.0 otherwise.

    Args:
        filepath: Path to the corpus, one sentence per line.

    .. highlight:: bash
    .. code-block:: bash

        $ head -n 2 ctb8.txt
        中国经济简讯
        新华社北京十月二十九日电中国经济简讯

    Yields:
        Dicts of ``{'char': window, 'label_id': float}``.
    """
    f = TimingFileIterator(filepath)
    sents = []
    eos_offsets = []
    offset = 0
    for line in f:
        if not line.strip():
            continue
        line = line.rstrip('\n')
        # The gold EOS position is the last non-space character of the line.
        eos_offsets.append(offset + len(line.rstrip()) - 1)
        offset += len(line)
        if self.append_after_sentence:
            line += self.append_after_sentence
            offset += len(self.append_after_sentence)
        f.log(line)
        sents.append(line)
    f.erase()
    # Flatten all sentences into one long character sequence.
    corpus = list(itertools.chain.from_iterable(sents))
    if self.eos_chars:
        if not isinstance(self.eos_chars, set):
            self.eos_chars = set(self.eos_chars)
    else:
        # Induce the candidate EOS character set from gold offsets, keeping
        # characters that are frequent enough (and punctuation, if required).
        eos_chars = Counter()
        for i in eos_offsets:
            eos_chars[corpus[i]] += 1
        self.eos_chars = set(k for (k, v) in eos_chars.most_common() if
                             v >= self.eos_char_min_freq and (
                                     not self.eos_char_is_punct or ispunct(k)))
        cprint(f'eos_chars = [yellow]{self.eos_chars}[/yellow]')
    eos_index = 0
    # Keep only gold offsets whose character survived the candidate filter.
    eos_offsets = [i for i in eos_offsets if corpus[i] in self.eos_chars]
    window_size = self.window_size
    for i, c in enumerate(corpus):
        if c in self.eos_chars:
            # Clamp the left edge to 0: a negative slice start would wrap
            # around to the end of the corpus and yield a bogus window for
            # positions near the beginning.
            window = corpus[max(0, i - window_size):i + window_size + 1]
            # Guard eos_index: candidate characters may still occur after the
            # last gold EOS offset, which would otherwise raise IndexError.
            label_id = 1. if eos_index < len(eos_offsets) and eos_offsets[eos_index] == i else 0.
            if label_id > 0:
                eos_index += 1
            yield {'char': window, 'label_id': label_id}
    assert eos_index == len(
        eos_offsets), f'{eos_index} != {len(eos_offsets)}'
def __init__(self, data: Union[str, List], transform: Union[Callable, List] = None, cache=None,
             generate_idx=None, max_seq_len=None, sent_delimiter=None) -> None:
    """Initialize the dataset, normalizing ``sent_delimiter`` to a predicate.

    Args:
        data: Path to the data or a list of samples.
        transform: Transform(s) applied to each sample.
        cache: Whether/where to cache samples.
        generate_idx: Whether to generate indices for samples.
        max_seq_len: Split overlong sequences at a delimiter when exceeded.
        sent_delimiter: Either a predicate ``char -> bool`` or a string of
            delimiter characters; defaults to :func:`ispunct`.
    """
    if not sent_delimiter:
        sent_delimiter = lambda x: ispunct(x)
    elif isinstance(sent_delimiter, str):
        # Bind the character set to its own name. The original code rebound
        # ``sent_delimiter`` to the lambda, so the late-binding closure made
        # ``x in sent_delimiter`` test membership in the lambda itself and
        # raise TypeError at call time.
        delimiters = set(sent_delimiter)
        sent_delimiter = lambda x: x in delimiters
    self.sent_delimiter = sent_delimiter
    self.max_seq_len = max_seq_len
    super().__init__(data, transform, cache, generate_idx)
def lock_vocabs(self):
    """Lock vocabularies, then cache the ids of punctuation forms.

    After locking, collects every token in the form vocabulary that is
    punctuation and stores their ids as an int64 tensor in ``self.puncts``.
    """
    super().lock_vocabs()
    punct_ids = [
        idx for token, idx in self.form_vocab.token_to_idx.items()
        if ispunct(token)
    ]
    self.puncts = tf.constant(punct_ids, dtype=tf.int64)
def evalb(gold_trees, predicted_trees, ref_gold_path=None, evalb_dir=None):
    """Score predicted trees against gold trees with the external EVALB tool.

    Writes both tree lists to temp files, shells out to the ``evalb`` (or
    SPMRL variant) binary, and scrapes precision/recall/F1 from its output.

    Args:
        gold_trees: Gold ``trees.TreebankNode`` parses.
        predicted_trees: Predicted parses, aligned 1:1 with ``gold_trees``.
        ref_gold_path: Optional path to a pre-linearized gold file used
            verbatim instead of linearizing ``gold_trees``.
        evalb_dir: Directory containing the EVALB binaries and ``.prm``
            parameter files; resolved via ``get_evalb_dir()`` when omitted.

    Returns:
        An ``FScore`` with recall, precision and fscore in [0, 1]
        (NaN fields if EVALB output could not be parsed).
    """
    if not evalb_dir:
        evalb_dir = get_evalb_dir()
    assert os.path.exists(evalb_dir)
    evalb_program_path = os.path.join(evalb_dir, "evalb")
    evalb_spmrl_program_path = os.path.join(evalb_dir, "evalb_spmrl")
    assert os.path.exists(evalb_program_path) or os.path.exists(
        evalb_spmrl_program_path)

    # Prefer the standard EVALB binary; fall back to the SPMRL variant with
    # its own parameter file.
    if os.path.exists(evalb_program_path):
        # evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
        evalb_param_path = os.path.join(evalb_dir, "nk.prm")
    else:
        evalb_program_path = evalb_spmrl_program_path
        evalb_param_path = os.path.join(evalb_dir, "spmrl.prm")

    assert os.path.exists(evalb_program_path)
    assert os.path.exists(evalb_param_path)

    assert len(gold_trees) == len(predicted_trees)
    for gold_tree, predicted_tree in zip(gold_trees, predicted_trees):
        assert isinstance(gold_tree, trees.TreebankNode)
        assert isinstance(predicted_tree, trees.TreebankNode)
        gold_leaves = list(gold_tree.leaves())
        predicted_leaves = list(predicted_tree.leaves())
        assert len(gold_leaves) == len(predicted_leaves)
        for gold_leaf, predicted_leaf in zip(gold_leaves, predicted_leaves):
            if gold_leaf.word != predicted_leaf.word:
                # Token mismatch: tolerate punctuation normalization by
                # overwriting the gold token with the predicted one.
                # Maybe -LRB- => (
                if ispunct(predicted_leaf.word):
                    gold_leaf.word = predicted_leaf.word
                else:
                    print(
                        f'Predicted word {predicted_leaf.word} does not match gold word {gold_leaf.word}'
                    )
        # assert all(
        #     gold_leaf.word == predicted_leaf.word
        #     for gold_leaf, predicted_leaf in zip(gold_leaves, predicted_leaves))

    temp_dir = tempfile.TemporaryDirectory(prefix="evalb-")
    gold_path = os.path.join(temp_dir.name, "gold.txt")
    predicted_path = os.path.join(temp_dir.name, "predicted.txt")
    output_path = os.path.join(temp_dir.name, "output.txt")
    # DELETE
    # predicted_path = 'tmp_predictions.txt'
    # output_path = 'tmp_output.txt'
    # gold_path = 'tmp_gold.txt'
    with open(gold_path, "w") as outfile:
        if ref_gold_path is None:
            for tree in gold_trees:
                outfile.write("{}\n".format(tree.linearize()))
        else:
            # Copy the reference gold file verbatim instead of linearizing.
            with open(ref_gold_path) as goldfile:
                outfile.write(goldfile.read())
    with open(predicted_path, "w") as outfile:
        for tree in predicted_trees:
            outfile.write("{}\n".format(tree.linearize()))

    # NOTE(review): shell=True with interpolated paths is only safe because
    # all paths here are locally constructed (tempdir + evalb_dir); the
    # shell is needed for the `>` output redirection.
    command = "{} -p {} {} {} > {}".format(
        evalb_program_path,
        evalb_param_path,
        gold_path,
        predicted_path,
        output_path,
    )
    # print(command)
    subprocess.run(command, shell=True)

    # Scrape the summary metrics from EVALB's report; percentages are
    # rescaled to [0, 1]. Parsing stops at the first FMeasure line.
    fscore = FScore(math.nan, math.nan, math.nan)
    with open(output_path) as infile:
        for line in infile:
            match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.recall = float(match.group(1)) / 100
            match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.precision = float(match.group(1)) / 100
            match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.fscore = float(match.group(1)) / 100
                break

    # NOTE(review): the `or` here looks suspicious — a NaN fscore combined
    # with recall==0.0 still counts as "success". If the intent was "parsed
    # an fscore, or legitimately scored zero", verify against the upstream
    # parser this was ported from before changing it.
    success = (not math.isnan(fscore.fscore) or fscore.recall == 0.0
               or fscore.precision == 0.0)

    if success:
        # Only clean up on success so the temp files remain for debugging
        # when EVALB's output could not be parsed.
        temp_dir.cleanup()
    else:
        # print("Error reading EVALB results.")
        # print("Gold path: {}".format(gold_path))
        # print("Predicted path: {}".format(predicted_path))
        # print("Output path: {}".format(output_path))
        pass

    return fscore