def smatch_eval(pred, gold, use_fast=False) -> Union[SmatchScores, F1_]:
    script = get_resource(_FAST_SMATCH_SCRIPT if use_fast else _SMATCH_SCRIPT)
    home = os.path.dirname(script)
    pred = os.path.realpath(pred)
    gold = os.path.realpath(gold)
    with pushd(home):
        flash('Running evaluation script [blink][yellow]...[/yellow][/blink]')
        cmd = f'bash {script} {pred} {gold}'
        text = run_cmd(cmd)
        flash('')
    return format_fast_scores(text) if use_fast else format_official_scores(text)

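# Usage sketch for smatch_eval (file names are hypothetical; assumes the SMATCH
# scripts referenced by _SMATCH_SCRIPT/_FAST_SMATCH_SCRIPT resolve locally):
def _demo_smatch_eval():
    # Both files hold AMR graphs as produced by an AMR parser and its gold annotation.
    scores = smatch_eval('dev.pred.amr', 'dev.gold.amr', use_fast=False)
    print(scores)  # an F1-like score object parsed from the script output
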
def load_transform(self, save_dir) -> Transform:
    """
    Try to load the transform only. This method might fail because it avoids building
    the model. If it does fail, fall back to `load`, which might be too heavy, but
    that's the best we can do.

    :param save_dir: The path to load from.
    """
    save_dir = get_resource(save_dir)
    self.load_config(save_dir)
    self.load_vocabs(save_dir)
    self.transform.build_config()
    self.transform.lock_vocabs()
    return self.transform

def make_gold_conll(ontonotes_path, language):
    ensure_python_points_to_python2()
    ontonotes_path = os.path.abspath(get_resource(ontonotes_path))
    to_conll = get_resource(
        'https://gist.githubusercontent.com/hankcs/46b9137016c769e4b6137104daf43a92/raw/66369de6c24b5ec47696ae307591f0d72c6f3f02/ontonotes_to_conll.sh')
    to_conll = os.path.abspath(to_conll)
    # shutil.rmtree(os.path.join(ontonotes_path, 'conll-2012'), ignore_errors=True)
    with pushd(ontonotes_path):
        try:
            flash(f'Converting [blue]{language}[/blue] to CoNLL format, '
                  f'this might take half an hour [blink][yellow]...[/yellow][/blink]')
            run_cmd(f'bash {to_conll} {ontonotes_path} {language}')
            flash('')
        except RuntimeError as e:
            flash(f'[red]Failed[/red] to convert {language} of {ontonotes_path} to CoNLL. '
                  f'See the exception for details.')
            raise e

def load_weights(self, save_dir, filename='model.pt', **kwargs):
    """Load weights from a directory.

    Args:
        save_dir: The directory to load weights from.
        filename: A file name for weights.
        **kwargs: Not used.
    """
    save_dir = get_resource(save_dir)
    filename = os.path.join(save_dir, filename)
    # flash(f'Loading model: {filename} [blink]...[/blink]')
    self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)

def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
             callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
    input_path = get_resource(input_path)
    file_prefix, ext = os.path.splitext(input_path)
    name = os.path.basename(file_prefix)
    if not name:
        name = 'evaluate'
    if save_dir and not logger:
        logger = init_logger(name=name, root_dir=save_dir,
                             level=logging.INFO if verbose else logging.WARN, mode='w')
    tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
    samples = self.num_samples_in(tst_data)
    num_batches = math.ceil(samples / batch_size)
    if warm_up:
        for x, y in tst_data:
            self.model.predict_on_batch(x)
            break
    if output:
        assert save_dir, 'Must pass save_dir in order to output'
        if isinstance(output, bool):
            output = os.path.join(save_dir, name) + '.predict' + ext
        elif isinstance(output, str):
            pass  # a str output is used as the path as-is
        else:
            raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
    timer = Timer()
    eval_outputs = self.evaluate_dataset(tst_data, callbacks, output, num_batches, **kwargs)
    loss, score, output = eval_outputs[0], eval_outputs[1], eval_outputs[2]
    delta_time = timer.stop()
    speed = samples / delta_time.delta_seconds
    if logger:
        f1: IOBES_F1_TF = None
        for metric in self.model.metrics:
            if isinstance(metric, IOBES_F1_TF):
                f1 = metric
                break
        extra_report = ''
        if f1:
            overall, by_type, extra_report = f1.state.result(full=True, verbose=False)
            extra_report = ' \n' + extra_report
        logger.info('Evaluation results for {} - '
                    'loss: {:.4f} - {} - speed: {:.2f} sample/sec{}'
                    .format(name + ext, loss,
                            format_scores(score) if isinstance(score, dict)
                            else format_metrics(self.model.metrics),
                            speed, extra_report))
    if output:
        logger.info('Saving output to {}'.format(output))
        with open(output, 'w', encoding='utf-8') as out:
            self.evaluate_output(tst_data, out, num_batches, self.model.metrics)
    return loss, score, speed

def __init__(self, filepath: str, src, dst=None, **kwargs) -> None:
    if not dst:
        dst = src + '_fasttext'
    self.filepath = filepath
    flash(f'Loading fasttext model {filepath} [blink][yellow]...[/yellow][/blink]')
    filepath = get_resource(filepath)
    with stdout_redirected(to=os.devnull, stdout=sys.stderr):
        self._model = fasttext.load_model(filepath)
    flash('')
    output_dim = self._model['king'].size
    super().__init__(output_dim, src, dst)

def make_ontonotes_language_jsonlines(conll12_ontonotes_path, output_path=None, language='english'):
    conll12_ontonotes_path = get_resource(conll12_ontonotes_path)
    if output_path is None:
        output_path = os.path.dirname(conll12_ontonotes_path)
    for split in ['train', 'development', 'test']:
        pattern = f'{conll12_ontonotes_path}/data/{split}/data/{language}/annotations/*/*/*/*gold_conll'
        files = sorted(glob.glob(pattern, recursive=True))
        assert files, f'No gold_conll files found in {pattern}'
        version = os.path.basename(files[0]).split('.')[-1].split('_')[0]
        if version.startswith('v'):
            assert all([version in os.path.basename(f) for f in files])
        else:
            version = 'v5'
        lang_dir = f'{output_path}/{language}'
        if split == 'conll-2012-test':
            split = 'test'
        full_file = f'{lang_dir}/{split}.{language}.{version}_gold_conll'
        os.makedirs(lang_dir, exist_ok=True)
        print(f'Merging {len(files)} files to {full_file}')
        merge_files(files, full_file)
        v5_json_file = full_file.replace(f'.{version}_gold_conll', f'.{version}.jsonlines')
        print(f'Converting CoNLL file {full_file} to json file {v5_json_file}')
        labels, stats = convert_to_jsonlines(full_file, v5_json_file, language)
        print('Labels:')
        pprint(labels)
        print('Statistics:')
        pprint(stats)
        conll12_json_file = f'{lang_dir}/{split}.{language}.conll12.jsonlines'
        print(f'Applying CoNLL 12 official splits on {v5_json_file} to {conll12_json_file}')
        id_file = get_resource(f'https://od.hankcs.com/research/emnlp2021/conll.cemantix.org.zip#2012/download/ids/'
                               f'{language}/coref/{split}.id')
        filter_data(v5_json_file, conll12_json_file, id_file)

def __init__(self, mapper: Union[str, dict], src: str, dst: str = None) -> None:
    super().__init__(src, dst)
    self.mapper = mapper
    if isinstance(mapper, str):
        mapper = get_resource(mapper)
    if isinstance(mapper, str):
        self._table = load_json(mapper)
    elif isinstance(mapper, dict):
        self._table = mapper
    else:
        raise ValueError(f'Unrecognized mapper type {mapper}')

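# Usage sketch (the class name FieldMapper is hypothetical; any subclass with the
# __init__ above works the same way): the mapper may be an in-memory dict or a
# path/URL to a JSON table resolved through get_resource.
def _demo_mapper():
    mapper = FieldMapper({'VV': 'VERB', 'NN': 'NOUN'}, src='pos', dst='upos')
    # mapper = FieldMapper('mappings/ctb-to-upos.json', src='pos')  # hypothetical JSON path
    return mapper
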
def read_conll(filepath: Union[str, TimingFileIterator], underline_to_none=False,
               enhanced_collapse_empty_nodes=None):
    # The default is None so that the auto-detection for .conllu files below can
    # kick in when the caller does not specify a preference.
    sent = []
    if isinstance(filepath, str):
        filepath: str = get_resource(filepath)
        if filepath.endswith('.conllu') and enhanced_collapse_empty_nodes is None:
            enhanced_collapse_empty_nodes = True
        src = open(filepath, encoding='utf-8')
    else:
        src = filepath
    for idx, line in enumerate(src):
        if line.startswith('#'):
            continue
        line = line.strip()
        cells = line.split('\t')
        if line and cells:
            if enhanced_collapse_empty_nodes and '.' in cells[0]:
                cells[0] = float(cells[0])
                cells[6] = None
            else:
                if '-' in cells[0] or '.' in cells[0]:
                    # sent[-1][1] += cells[1]
                    continue
                cells[0] = int(cells[0])
                if cells[6] != '_':
                    try:
                        cells[6] = int(cells[6])
                    except ValueError:
                        cells[6] = 0
                        logger.exception(f'Wrong CoNLL format {filepath}:{idx + 1}\n{line}')
            if underline_to_none:
                for i, x in enumerate(cells):
                    if x == '_':
                        cells[i] = None
            sent.append(cells)
        else:
            if enhanced_collapse_empty_nodes:
                sent = collapse_enhanced_empty_nodes(sent)
            yield sent
            sent = []
    if sent:
        if enhanced_collapse_empty_nodes:
            sent = collapse_enhanced_empty_nodes(sent)
        yield sent
    src.close()

def load_from_meta_file(save_dir, meta_filename='meta.json', transform_only=False, load_kwargs=None,
                        **kwargs) -> Component:
    identifier = save_dir
    load_path = save_dir
    save_dir = get_resource(save_dir)
    metapath = os.path.join(save_dir, meta_filename)
    if not os.path.isfile(metapath):
        tips = ''
        if save_dir.isupper():
            from difflib import SequenceMatcher
            # Rank the pretrained keys by their similarity to the identifier.
            similar_keys = sorted(pretrained.ALL.keys(),
                                  key=lambda k: SequenceMatcher(None, save_dir, k).ratio(),
                                  reverse=True)[:5]
            tips = f'Check its spelling based on the available keys:\n' + \
                   f'{sorted(pretrained.ALL.keys())}\n' + \
                   f'Tips: it might be one of {similar_keys}'
        raise FileNotFoundError(
            f'The identifier {save_dir} resolves to a non-existent meta file {metapath}. {tips}')
    meta: dict = load_json(metapath)
    cls = meta.get('class_path', None)
    assert cls, f'{meta_filename} doesn\'t contain a class_path field'
    try:
        obj: Component = object_from_class_path(cls, **kwargs)
        if hasattr(obj, 'load') and os.path.isfile(os.path.join(save_dir, 'config.json')):
            if transform_only:
                # noinspection PyUnresolvedReferences
                obj.load_transform(save_dir)
            else:
                if load_kwargs is None:
                    load_kwargs = {}
                obj.load(save_dir, **load_kwargs)
            obj.meta['load_path'] = load_path
        return obj
    except Exception:
        eprint(f'Failed to load {identifier}. See the stack trace below.')
        traceback.print_exc()
        model_version = meta.get('hanlp_version', 'unknown')
        cur_version = version.__version__
        if model_version != cur_version:
            eprint(f'{identifier} was created with hanlp-{model_version}, '
                   f'while you are running {cur_version}. '
                   f'Try to upgrade hanlp with\n'
                   f'pip install --upgrade hanlp')
        exit(1)

def ctb_pos_to_text_format(path, delimiter='_'):
    """Convert a CTB POS tagging corpus from tsv format to text format, where each word
    is followed by its POS tag.

    Args:
        path: File to be converted.
        delimiter: Delimiter between word and tag.
    """
    path = get_resource(path)
    name, ext = os.path.splitext(path)
    with open(f'{name}.txt', 'w', encoding='utf-8') as out:
        for sent in read_tsv_as_sents(path):
            out.write(' '.join([delimiter.join(x) for x in sent]))
            out.write('\n')

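# A minimal illustration (hypothetical file name and content): a tsv sentence with
# rows like ("上海", "NR") and ("浦东", "NR") becomes the text line "上海_NR 浦东_NR".
def _demo_ctb_pos_to_text_format():
    ctb_pos_to_text_format('ctb5.pos.tsv', delimiter='_')  # writes ctb5.pos.txt next to the input
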
def load_config(self, save_dir, filename='config.json', **kwargs):
    """Load config from a directory.

    Args:
        save_dir: The directory to load config from.
        filename: A file name for config.
        **kwargs: K-V pairs to override config.
    """
    save_dir = get_resource(save_dir)
    self.config.load_json(os.path.join(save_dir, filename))
    self.config.update(kwargs)  # overwrite config loaded from disk
    for k, v in self.config.items():
        if isinstance(v, dict) and 'classpath' in v:
            self.config[k] = Configurable.from_config(v)
    self.on_config_ready(**self.config)

def _list_dir(path, home):
    prefix = home.lstrip('_').replace('_HOME', '')
    path = get_resource(path)
    with open('ud27.py', 'a') as out:
        for f in sorted(glob.glob(path + '/ud-treebanks-v2.7/UD_*')):
            basename = os.path.basename(f)
            name = basename[len('UD_'):]
            name = name.upper().replace('-', '_')
            for split in 'train', 'dev', 'test':
                sp = glob.glob(f + f'/*{split}.conllu')
                if not sp:
                    continue
                sp = os.path.basename(sp[0])
                out.write(f'{prefix}_{name}_{split.upper()} = {home} + "{basename}/{sp}"\n')
                out.write(f'"{prefix} {split} set of {name}."\n')

def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
    self.meta['load_path'] = save_dir
    save_dir = get_resource(save_dir)
    self.load_config(save_dir)
    self.load_vocabs(save_dir)
    self.build(**merge_dict(self.config, training=False, logger=logger, **kwargs,
                            overwrite=True, inplace=True))
    self.load_weights(save_dir)
    self.load_meta(save_dir)

def convert_to_stanford_dependency_330(src, dst):
    cprint(f'Converting {os.path.basename(src)} to {os.path.basename(dst)} using Stanford Parser Version 3.3.0. '
           f'It might take a while [blink][yellow]...[/yellow][/blink]')
    sp_home = 'https://nlp.stanford.edu/software/stanford-parser-full-2013-11-12.zip'
    sp_home = get_resource(sp_home)
    # jar_path = get_resource(f'{sp_home}#stanford-parser.jar')
    code, out, err = get_exitcode_stdout_stderr(
        f'java -cp {sp_home}/* edu.stanford.nlp.trees.international.pennchinese.ChineseGrammaticalStructure '
        f'-basic -keepPunct -conllx '
        f'-treeFile {src}')
    with open(dst, 'w') as f:
        f.write(out)
    if code:
        raise RuntimeError(f'Conversion failed with code {code} for {src}. The error message is:\n{err}\n'
                           f'Do you have Java installed? Do you have enough memory?')

def _generate_chars_tags(filepath, delimiter, max_seq_len):
    filepath = get_resource(filepath)
    with open(filepath, encoding='utf8') as src:
        for text in src:
            chars, tags = bmes_of(text, True)
            if max_seq_len and delimiter and len(chars) > max_seq_len:
                short_chars, short_tags = [], []
                for idx, (char, tag) in enumerate(zip(chars, tags)):
                    short_chars.append(char)
                    short_tags.append(tag)
                    if len(short_chars) >= max_seq_len and delimiter(char):
                        yield short_chars, short_tags
                        short_chars, short_tags = [], []
                if short_chars:
                    yield short_chars, short_tags
            else:
                yield chars, tags

def export_model_for_serving(self, export_dir=None, version=1, overwrite=False, show_hint=False):
    assert self.model, 'You have to fit or load a model before exporting it'
    if not export_dir:
        assert 'load_path' in self.meta, 'When not specifying export_dir, load_path has to be present'
        export_dir = get_resource(self.meta['load_path'])
    model_path = os.path.join(export_dir, str(version))
    if os.path.isdir(model_path) and not overwrite:
        logger.info(f'{model_path} exists, skip since overwrite = {overwrite}')
        return export_dir
    logger.info(f'Exporting to {export_dir} ...')
    tf.saved_model.save(self.model, model_path)
    logger.info(f'Successfully exported model to {export_dir}')
    if show_hint:
        logger.info(f'You can serve it through \n'
                    f'tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} '
                    f'--model_base_path={export_dir} --rest_api_port=8888')
    return export_dir

def __init__(self, filepath: str = None, vocab: Vocab = None, expand_vocab=True, lowercase=True,
             input_dim=None, output_dim=None, unk=None, normalize=False,
             embeddings_initializer='VarianceScaling', embeddings_regularizer=None,
             activity_regularizer=None, embeddings_constraint=None, mask_zero=True,
             input_length=None, name=None, **kwargs):
    filepath = get_resource(filepath)
    word2vec, _output_dim = load_word2vec(filepath)
    if output_dim:
        assert output_dim == _output_dim, f'output_dim = {output_dim} does not match {filepath}'
    output_dim = _output_dim
    # If the `unk` token exists in the pretrained embeddings, replace it with a
    # self-defined one, usually the one in the word vocab.
    if unk and unk in word2vec:
        word2vec[vocab.safe_unk_token] = word2vec.pop(unk)
    if vocab is None:
        vocab = Vocab()
        vocab.update(word2vec.keys())
    if expand_vocab and vocab.mutable:
        for word in word2vec:
            vocab.get_idx(word.lower() if lowercase else word)
    if input_dim:
        assert input_dim == len(vocab), f'input_dim = {input_dim} does not match {filepath}'
    input_dim = len(vocab)
    # init matrix
    self._embeddings_initializer = embeddings_initializer
    embeddings_initializer = tf.keras.initializers.get(embeddings_initializer)
    with tf.device('cpu:0'):
        pret_embs = embeddings_initializer(shape=[input_dim, output_dim]).numpy()
    # insert to pret_embs
    for word, idx in vocab.token_to_idx.items():
        vec = word2vec.get(word, None)
        # Retry lower case
        if vec is None and lowercase:
            vec = word2vec.get(word.lower(), None)
        if vec is not None:
            pret_embs[idx] = vec
    if normalize:
        pret_embs /= np.std(pret_embs)
    if not name:
        name = os.path.splitext(os.path.basename(filepath))[0]
    super().__init__(input_dim, output_dim, tf.keras.initializers.Constant(pret_embs),
                     embeddings_regularizer, activity_regularizer, embeddings_constraint,
                     mask_zero, input_length, name=name, **kwargs)
    self.filepath = filepath
    self.expand_vocab = expand_vocab
    self.lowercase = lowercase

def load_dataset(ds_name: str, save_dir=None, batch_size=128, splits=['train', 'valid', 'test']):
    if isinstance(splits, str):
        splits = [splits]
    # The tokenizer is used *exclusively* to access a class method for loading a dataset.
    helper = load_tokenizer('zh')
    task = find_task_ds(ds_name)
    ds = {}
    lib = 'hanlp.datasets.' + task
    if task == 'classification':
        lib = importlib.import_module(lib + '.sentiment')
    elif task == 'cws':
        if ds_name == 'CTB6_CWS':
            lib = importlib.import_module(lib + '.ctb')
        elif 'MSR' in ds_name:
            lib = importlib.import_module(lib + '.sighan2005.msr')
        else:
            lib = importlib.import_module(lib + '.sighan2005.pku')
    elif task == 'pos':
        lib = importlib.import_module(lib + '.ctb')
    elif task == 'ner':
        if 'CONLL03' in ds_name:
            lib = importlib.import_module(lib + '.conll03')
        else:
            lib = importlib.import_module(lib + '.msra')
    elif task == 'dep':
        if 'CTB' in ds_name:
            lib = importlib.import_module(lib[:-3] + 'parsing.ctb')
        else:
            lib = importlib.import_module(lib[:-3] + 'parsing.semeval2016')
    for split in splits:
        url = getattr(lib, ds_name + '_' + split.upper())
        input_path = get_resource(url)
        if split == 'train' and 'SIGHAN2005' in ds_name:
            from hanlp.datasets.cws.sighan2005 import make
            make(url)
        ds[split] = helper.transform.file_to_dataset(input_path, batch_size=batch_size)
    return ds

def load_domains(ctb_home):
    """Load file ids from a Chinese treebank grouped by domains.

    Args:
        ctb_home: Root path to CTB.

    Returns:
        A dict of sets, each of which represents a domain.
    """
    ctb_home = get_resource(ctb_home)
    ctb_root = join(ctb_home, 'bracketed')
    chtbs = _list_treebank_root(ctb_root)
    domains = defaultdict(set)
    for each in chtbs:
        name, domain = each.split('.')
        _, fid = name.split('_')
        domains[domain].add(fid)
    return domains

def load_file(self, filepath):
    """Load a ``.tsv`` file. A ``.tsv`` file for tagging is defined as a tab-separated text file,
    where non-empty lines have two columns for the token and the tag respectively, and empty
    lines mark the end of sentences.

    Args:
        filepath: Path to a ``.tsv`` tagging file.

    .. highlight:: bash
    .. code-block:: bash

        $ head eng.train.tsv
        -DOCSTART-      O

        EU      S-ORG
        rejects O
        German  S-MISC
        call    O
        to      O
        boycott O
        British S-MISC
        lamb    O
    """
    filepath = get_resource(filepath)
    # idx = 0
    for words, tags in generate_words_tags_from_tsv(filepath, lower=False):
        # idx += 1
        # if idx % 1000 == 0:
        #     print(f'\rRead instances {idx // 1000}k', end='')
        if self.max_seq_len:
            start = 0
            for short_sents in split_long_sentence_into(words, self.max_seq_len, self.sent_delimiter,
                                                        char_level=self.char_level,
                                                        hard_constraint=self.hard_constraint):
                end = start + len(short_sents)
                yield {'token': short_sents, 'tag': tags[start:end]}
                start = end
        else:
            yield {'token': words, 'tag': tags}

def read_conll(filepath):
    sent = []
    filepath = get_resource(filepath)
    with open(filepath, encoding='utf-8') as src:
        for line in src:
            if line.startswith('#'):
                continue
            cells = line.strip().split()
            if cells:
                cells[0] = int(cells[0])
                cells[6] = int(cells[6])
                for i, x in enumerate(cells):
                    if x == '_':
                        cells[i] = None
                sent.append(cells)
            else:
                yield sent
                sent = []
    if sent:
        yield sent

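# Usage sketch for the simple reader above (hypothetical path). Each yielded sentence
# is a list of rows [id, form, lemma, cpos, pos, feats, head, rel, ...] with '_' cells
# mapped to None.
def _demo_read_conll():
    for sent in read_conll('train.conllx'):
        for cells in sent:
            print(cells[1], cells[6], cells[7])  # form, head id, relation
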
def _list_dir(path, home):
    prefix = home.lstrip('_').replace('_HOME', '')
    from hanlp.utils.io_util import get_resource
    import glob
    import os
    path = get_resource(path)
    with open('ud23.py', 'a') as out:
        for f in sorted(glob.glob(path + '/UD_*')):
            basename = os.path.basename(f)
            name = basename[len('UD_'):]
            name = name.upper().replace('-', '_')
            for split in 'train', 'dev', 'test':
                sp = glob.glob(f + f'/*{split}.conllu')
                if not sp:
                    continue
                sp = os.path.basename(sp[0])
                out.write(f'{prefix}_{name}_{split.upper()} = {home} + "#{basename}/{sp}"\n')

def make(train):
    root = get_resource(SIGHAN2005)
    train = os.path.join(root, train.split('#')[-1])
    if not os.path.isfile(train):
        full = train.replace('_90.txt', '.utf8')
        logger.info(f'Splitting {full} into a training set and a valid set with a 9:1 proportion')
        valid = train.replace('90.txt', '10.txt')
        split_file(full, train=0.9, dev=0.1, test=0, names={'train': train, 'dev': valid})
        assert os.path.isfile(train), f'Failed to make {train}'
        assert os.path.isfile(valid), f'Failed to make {valid}'
        logger.info(f'Successfully made {train} {valid}')

def __init__(self, filepath: str, padding=PAD, name=None, **kwargs):
    import fasttext
    self.padding = padding.encode('utf-8')
    self.filepath = filepath
    filepath = get_resource(filepath)
    assert os.path.isfile(filepath), f'Resolved path {filepath} is not a file'
    logger.debug('Loading fasttext model from [{}].'.format(filepath))
    # fasttext prints a blank line here
    with stdout_redirected(to=os.devnull, stdout=sys.stderr):
        self.model = fasttext.load_model(filepath)
    kwargs.pop('input_dim', None)
    kwargs.pop('output_dim', None)
    kwargs.pop('mask_zero', None)
    if not name:
        name = os.path.splitext(os.path.basename(filepath))[0]
    super().__init__(input_dim=len(self.model.words), output_dim=self.model['king'].size,
                     mask_zero=padding is not None, trainable=False, dtype=tf.string,
                     name=name, **kwargs)
    embed_fn = np.frompyfunc(self.embed, 1, 1)
    # vf = np.vectorize(self.embed, otypes=[np.ndarray])
    self._embed_np = embed_fn

def load_word2vec(path, delimiter=' ', cache=True) -> Tuple[Dict[str, np.ndarray], int]:
    realpath = get_resource(path)
    binpath = replace_ext(realpath, '.pkl')
    if cache:
        try:
            flash('Loading word2vec from cache [blink][yellow]...[/yellow][/blink]')
            word2vec, dim = load_pickle(binpath)
            flash('')
            return word2vec, dim
        except IOError:
            pass

    dim = None
    word2vec = dict()
    f = TimingFileIterator(realpath)
    for idx, line in enumerate(f):
        f.log('Loading word2vec from text file [blink][yellow]...[/yellow][/blink]')
        line = line.rstrip().split(delimiter)
        if len(line) > 2:
            if dim is None:
                dim = len(line)
            else:
                if len(line) != dim:
                    logger.warning('{}#{} length mismatches with {}'.format(path, idx + 1, dim))
                    continue
            word, vec = line[0], line[1:]
            word2vec[word] = np.array(vec, dtype=np.float32)
    dim -= 1  # the first column is the word itself, so the vector dimension is one less
    if cache:
        flash('Caching word2vec [blink][yellow]...[/yellow][/blink]')
        save_pickle((word2vec, dim), binpath)
        flash('')
    return word2vec, dim

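# Usage sketch (hypothetical path): returns a word -> float32 vector dict plus the
# vector dimension; the first call caches a .pkl next to the text file.
def _demo_load_word2vec():
    word2vec, dim = load_word2vec('glove.6B.100d.txt')
    vec = next(iter(word2vec.values()))
    assert vec.shape == (dim,)
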
def __init__(self, data: str, batch_size, seq_len, tokenizer='char', eos='\n', strip=True, vocab=None,
             cache=False, transform: Union[Callable, List] = None) -> None:
    self.cache = cache
    self.eos = eos
    self.strip = strip
    super().__init__(transform)
    if isinstance(tokenizer, str):
        available_tokenizers = {
            'char': ToChar('text', 'token'),
            'whitespace': WhitespaceTokenizer('text', 'token')
        }
        assert tokenizer in available_tokenizers, \
            f'{tokenizer} not supported, available options: {available_tokenizers.keys()}'
        self.append_transform(available_tokenizers[tokenizer])
    if vocab is None:
        vocab = Vocab()
        self.training = True
    else:
        self.training = vocab.mutable
    self.append_transform(AppendEOS('token', eos=eos))
    self.append_transform(FieldToIndex('token', vocab))
    self.batch_size = batch_size
    data = get_resource(data)
    self.data = data
    self.num_tokens = None
    self.load_file(data)
    self._fp = None
    if isinstance(seq_len, int):
        self.seq_len = lambda: seq_len
    else:
        self.seq_len = seq_len

def load_data(self, data, generate_idx=False):
    """An intermediate step between the constructor and calling the actual file loading method.

    Args:
        data: If data is a file, this method calls
            :meth:`~hanlp.common.dataset.TransformableDataset.load_file` to load it.
        generate_idx: Create an :const:`~hanlp_common.constants.IDX` field for each sample to
            store its order in the dataset. Useful for prediction when samples are re-ordered
            by a sampler.

    Returns:
        Loaded samples.
    """
    if self.should_load_file(data):
        if isinstance(data, str):
            data = get_resource(data)
        data = list(self.load_file(data))
    if generate_idx:
        for i, each in enumerate(data):
            each[IDX] = i
    # elif isinstance(data, list):
    #     data = self.load_list(data)
    return data

def load_word2vec_as_vocab_tensor(path, delimiter=' ', cache=True) -> Tuple[Dict[str, int], torch.Tensor]:
    realpath = get_resource(path)
    vocab_path = replace_ext(realpath, '.vocab')
    matrix_path = replace_ext(realpath, '.pt')
    if cache:
        try:
            flash('Loading vocab and matrix from cache [blink][yellow]...[/yellow][/blink]')
            vocab = load_pickle(vocab_path)
            matrix = torch.load(matrix_path, map_location='cpu')
            flash('')
            return vocab, matrix
        except IOError:
            pass

    word2vec, dim = load_word2vec(path, delimiter, cache)
    vocab = dict((k, i) for i, k in enumerate(word2vec.keys()))
    matrix = torch.Tensor(list(word2vec.values()))
    if cache:
        flash('Caching vocab and matrix [blink][yellow]...[/yellow][/blink]')
        save_pickle(vocab, vocab_path)
        torch.save(matrix, matrix_path)
        flash('')
    return vocab, matrix

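# Usage sketch (hypothetical path): the vocab maps each word to its row in the
# matrix, which suits initializing an embedding layer directly.
def _demo_load_word2vec_as_vocab_tensor():
    import torch.nn as nn
    vocab, matrix = load_word2vec_as_vocab_tensor('glove.6B.100d.txt')
    embed = nn.Embedding.from_pretrained(matrix, freeze=False)
    return vocab, embed
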
def official_conll_05_evaluate(pred_path, gold_path):
    script_root = get_resource('http://www.lsi.upc.edu/~srlconll/srlconll-1.1.tgz')
    lib_path = f'{script_root}/lib'
    if lib_path not in os.environ.get('PERL5LIB', ''):
        os.environ['PERL5LIB'] = f'{lib_path}:{os.environ.get("PERL5LIB", "")}'
    bin_path = f'{script_root}/bin'
    if bin_path not in os.environ.get('PATH', ''):
        os.environ['PATH'] = f'{bin_path}:{os.environ.get("PATH", "")}'
    eval_info_gold_pred = run_cmd(f'perl {script_root}/bin/srl-eval.pl {gold_path} {pred_path}')
    eval_info_pred_gold = run_cmd(f'perl {script_root}/bin/srl-eval.pl {pred_path} {gold_path}')
    conll_recall = float(eval_info_gold_pred.strip().split('\n')[6].strip().split()[5]) / 100
    conll_precision = float(eval_info_pred_gold.strip().split('\n')[6].strip().split()[5]) / 100
    if conll_recall + conll_precision > 0:
        conll_f1 = 2 * conll_recall * conll_precision / (conll_recall + conll_precision)
    else:
        conll_f1 = 0
    return conll_precision, conll_recall, conll_f1

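# Usage sketch (hypothetical paths; requires perl plus the srlconll-1.1 scripts that
# get_resource downloads above). Both files must be in the official CoNLL-2005 props format.
def _demo_official_conll_05_evaluate():
    p, r, f1 = official_conll_05_evaluate('dev.pred.props', 'dev.gold.props')
    print(f'P={p:.2%} R={r:.2%} F1={f1:.2%}')
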