Code example #1
File: smatch_eval.py  Project: lei1993/HanLP
def smatch_eval(pred, gold, use_fast=False) -> Union[SmatchScores, F1_]:
    script = get_resource(_FAST_SMATCH_SCRIPT if use_fast else _SMATCH_SCRIPT)
    home = os.path.dirname(script)
    pred = os.path.realpath(pred)
    gold = os.path.realpath(gold)
    with pushd(home):
        flash('Running evaluation script [blink][yellow]...[/yellow][/blink]')
        cmd = f'bash {script} {pred} {gold}'
        text = run_cmd(cmd)
        flash('')
    return format_fast_scores(text) if use_fast else format_official_scores(
        text)
Code example #2
File: component.py  Project: zgying/HanLP
 def load_transform(self, save_dir) -> Transform:
     """
     Try to load the transform only. This method might fail because it avoids building the model.
     If it does fail, then you have to use `load`, which might be too heavy, but that's the best we can do.
     :param save_dir: The path to load.
     """
     save_dir = get_resource(save_dir)
     self.load_config(save_dir)
     self.load_vocabs(save_dir)
     self.transform.build_config()
     self.transform.lock_vocabs()
     return self.transform
Code example #3
def make_gold_conll(ontonotes_path, language):
    ensure_python_points_to_python2()
    ontonotes_path = os.path.abspath(get_resource(ontonotes_path))
    to_conll = get_resource(
        'https://gist.githubusercontent.com/hankcs/46b9137016c769e4b6137104daf43a92/raw/66369de6c24b5ec47696ae307591f0d72c6f3f02/ontonotes_to_conll.sh'
    )
    to_conll = os.path.abspath(to_conll)
    # shutil.rmtree(os.path.join(ontonotes_path, 'conll-2012'), ignore_errors=True)
    with pushd(ontonotes_path):
        try:
            flash(
                f'Converting [blue]{language}[/blue] to CoNLL format, '
                f'this might take half an hour [blink][yellow]...[/yellow][/blink]'
            )
            run_cmd(f'bash {to_conll} {ontonotes_path} {language}')
            flash('')
        except RuntimeError as e:
            flash(
                f'[red]Failed[/red] to convert {language} of {ontonotes_path} to CoNLL. See the exception for details'
            )
            raise e
Code example #4
    def load_weights(self, save_dir, filename='model.pt', **kwargs):
        """Load weights from a directory.

        Args:
            save_dir: The directory to load weights from.
            filename: A file name for weights.
            **kwargs: Not used.
        """
        save_dir = get_resource(save_dir)
        filename = os.path.join(save_dir, filename)
        # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]')
        self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)
Code example #5
    def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
                 callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
        input_path = get_resource(input_path)
        file_prefix, ext = os.path.splitext(input_path)
        name = os.path.basename(file_prefix)
        if not name:
            name = 'evaluate'
        if save_dir and not logger:
            logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO if verbose else logging.WARN,
                                 mode='w')
        tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
        samples = self.num_samples_in(tst_data)
        num_batches = math.ceil(samples / batch_size)
        if warm_up:
            for x, y in tst_data:
                self.model.predict_on_batch(x)
                break
        if output:
            assert save_dir, 'Must pass save_dir in order to output'
            if isinstance(output, bool):
                output = os.path.join(save_dir, name) + '.predict' + ext
            elif isinstance(output, str):
                output = output
            else:
                raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
        timer = Timer()
        eval_outputs = self.evaluate_dataset(tst_data, callbacks, output, num_batches, **kwargs)
        loss, score, output = eval_outputs[0], eval_outputs[1], eval_outputs[2]
        delta_time = timer.stop()
        speed = samples / delta_time.delta_seconds

        if logger:
            f1: IOBES_F1_TF = None
            for metric in self.model.metrics:
                if isinstance(metric, IOBES_F1_TF):
                    f1 = metric
                    break
            extra_report = ''
            if f1:
                overall, by_type, extra_report = f1.state.result(full=True, verbose=False)
                extra_report = ' \n' + extra_report
            logger.info('Evaluation results for {} - '
                        'loss: {:.4f} - {} - speed: {:.2f} sample/sec{}'
                        .format(name + ext, loss,
                                format_scores(score) if isinstance(score, dict) else format_metrics(self.model.metrics),
                                speed, extra_report))
        if output:
            logger.info('Saving output to {}'.format(output))
            with open(output, 'w', encoding='utf-8') as out:
                self.evaluate_output(tst_data, out, num_batches, self.model.metrics)

        return loss, score, speed
Code example #6
File: fast_text.py  Project: lei1993/HanLP
 def __init__(self, filepath: str, src, dst=None, **kwargs) -> None:
     if not dst:
         dst = src + '_fasttext'
     self.filepath = filepath
     flash(
         f'Loading fasttext model {filepath} [blink][yellow]...[/yellow][/blink]'
     )
     filepath = get_resource(filepath)
     with stdout_redirected(to=os.devnull, stdout=sys.stderr):
         self._model = fasttext.load_model(filepath)
     flash('')
     output_dim = self._model['king'].size
     super().__init__(output_dim, src, dst)
Code example #7
def make_ontonotes_language_jsonlines(conll12_ontonotes_path,
                                      output_path=None,
                                      language='english'):
    conll12_ontonotes_path = get_resource(conll12_ontonotes_path)
    if output_path is None:
        output_path = os.path.dirname(conll12_ontonotes_path)
    for split in ['train', 'development', 'test']:
        pattern = f'{conll12_ontonotes_path}/data/{split}/data/{language}/annotations/*/*/*/*gold_conll'
        files = sorted(glob.glob(pattern, recursive=True))
        assert files, f'No gold_conll files found in {pattern}'
        version = os.path.basename(files[0]).split('.')[-1].split('_')[0]
        if version.startswith('v'):
            assert all([version in os.path.basename(f) for f in files])
        else:
            version = 'v5'
        lang_dir = f'{output_path}/{language}'
        if split == 'conll-2012-test':
            split = 'test'
        full_file = f'{lang_dir}/{split}.{language}.{version}_gold_conll'
        os.makedirs(lang_dir, exist_ok=True)
        print(f'Merging {len(files)} files to {full_file}')
        merge_files(files, full_file)
        v5_json_file = full_file.replace(f'.{version}_gold_conll',
                                         f'.{version}.jsonlines')
        print(f'Converting CoNLL file {full_file} to json file {v5_json_file}')
        labels, stats = convert_to_jsonlines(full_file, v5_json_file, language)
        print('Labels:')
        pprint(labels)
        print('Statistics:')
        pprint(stats)
        conll12_json_file = f'{lang_dir}/{split}.{language}.conll12.jsonlines'
        print(
            f'Applying CoNLL 12 official splits on {v5_json_file} to {conll12_json_file}'
        )
        id_file = get_resource(
            f'https://od.hankcs.com/research/emnlp2021/conll.cemantix.org.zip#2012/download/ids/'
            f'{language}/coref/{split}.id')
        filter_data(v5_json_file, conll12_json_file, id_file)
Code example #8
File: transform.py  Project: zmjm4/HanLP
 def __init__(self,
              mapper: Union[str, dict],
              src: str,
              dst: str = None) -> None:
     super().__init__(src, dst)
     self.mapper = mapper
     if isinstance(mapper, str):
         mapper = get_resource(mapper)
     if isinstance(mapper, str):
         self._table = load_json(mapper)
     elif isinstance(mapper, dict):
         self._table = mapper
     else:
         raise ValueError(f'Unrecognized mapper type {mapper}')
Code example #9
File: conll.py  Project: lei1993/HanLP
def read_conll(filepath: Union[str, TimingFileIterator],
               underline_to_none=False,
               enhanced_collapse_empty_nodes=False):
    sent = []
    if isinstance(filepath, str):
        filepath: str = get_resource(filepath)
        if filepath.endswith(
                '.conllu') and enhanced_collapse_empty_nodes is None:
            enhanced_collapse_empty_nodes = True
        src = open(filepath, encoding='utf-8')
    else:
        src = filepath
    for idx, line in enumerate(src):
        if line.startswith('#'):
            continue
        line = line.strip()
        cells = line.split('\t')
        if line and cells:
            if enhanced_collapse_empty_nodes and '.' in cells[0]:
                cells[0] = float(cells[0])
                cells[6] = None
            else:
                if '-' in cells[0] or '.' in cells[0]:
                    # sent[-1][1] += cells[1]
                    continue
                cells[0] = int(cells[0])
                if cells[6] != '_':
                    try:
                        cells[6] = int(cells[6])
                    except ValueError:
                        cells[6] = 0
                        logger.exception(
                            f'Wrong CoNLL format {filepath}:{idx + 1}\n{line}')
            if underline_to_none:
                for i, x in enumerate(cells):
                    if x == '_':
                        cells[i] = None
            sent.append(cells)
        else:
            if enhanced_collapse_empty_nodes:
                sent = collapse_enhanced_empty_nodes(sent)
            yield sent
            sent = []

    if sent:
        if enhanced_collapse_empty_nodes:
            sent = collapse_enhanced_empty_nodes(sent)
        yield sent

    src.close()
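
A minimal usage sketch (not from the repository above): iterate the sentences yielded by read_conll, using a placeholder file name dev.conllu. The indices follow the code above: cells[0] is the token id and cells[6] the head index.

# Hypothetical usage of the read_conll generator defined above.
for sent in read_conll('dev.conllu'):      # 'dev.conllu' is a placeholder path
    for cells in sent:
        token_id, form, head = cells[0], cells[1], cells[6]
        print(token_id, form, head)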
Code example #10
File: component_util.py  Project: zuoqy/HanLP
def load_from_meta_file(save_dir,
                        meta_filename='meta.json',
                        transform_only=False,
                        load_kwargs=None,
                        **kwargs) -> Component:
    identifier = save_dir
    load_path = save_dir
    save_dir = get_resource(save_dir)
    metapath = os.path.join(save_dir, meta_filename)
    if not os.path.isfile(metapath):
        tips = ''
        if save_dir.isupper():
            from difflib import SequenceMatcher
            similar_keys = sorted(pretrained.ALL.keys(),
                                  key=lambda k: SequenceMatcher(
                                      None, save_dir, k).ratio(),
                                  reverse=True)[:5]
            tips = f'Check its spelling based on the available keys:\n' + \
                   f'{sorted(pretrained.ALL.keys())}\n' + \
                   f'Tips: it might be one of {similar_keys}'
        raise FileNotFoundError(
            f'The identifier {save_dir} resolves to a non-existent meta file {metapath}. {tips}'
        )
    meta: dict = load_json(metapath)
    cls = meta.get('class_path', None)
    assert cls, f'{meta_filename} doesn\'t contain class_path field'
    try:
        obj: Component = object_from_class_path(cls, **kwargs)
        if hasattr(obj, 'load') and os.path.isfile(
                os.path.join(save_dir, 'config.json')):
            if transform_only:
                # noinspection PyUnresolvedReferences
                obj.load_transform(save_dir)
            else:
                if load_kwargs is None:
                    load_kwargs = {}
                obj.load(save_dir, **load_kwargs)
            obj.meta['load_path'] = load_path
        return obj
    except Exception as e:
        eprint(f'Failed to load {identifier}. See stack trace below')
        traceback.print_exc()
        model_version = meta.get("hanlp_version", "unknown")
        cur_version = version.__version__
        if model_version != cur_version:
            eprint(
                f'{identifier} was created with hanlp-{model_version}, while you are running {cur_version}. '
                f'Try to upgrade hanlp with\n'
                f'pip install --upgrade hanlp')
        exit(1)
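
The spelling-suggestion branch above ranks pretrained identifiers by string similarity. Here is a standalone sketch of the same idea using only the standard library; the candidate keys are made up and stand in for pretrained.ALL.keys().

from difflib import SequenceMatcher

candidates = ['CTB9_TOK_ELECTRA_SMALL', 'MSRA_NER_BERT_BASE_ZH', 'SIGHAN2005_PKU_CONVSEG']  # hypothetical keys
query = 'CTB9_TOK_ELECTRA_SMAL'  # a misspelled identifier

# Rank candidates by similarity to the query, most similar first.
similar = sorted(candidates,
                 key=lambda k: SequenceMatcher(None, query, k).ratio(),
                 reverse=True)[:5]
print(similar)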
Code example #11
def ctb_pos_to_text_format(path, delimiter='_'):
    """
    Convert a CTB POS tagging corpus from TSV format to text format, where each word is followed by
    its POS tag.
    Args:
        path: File to be converted.
        delimiter: Delimiter between word and tag.
    """
    path = get_resource(path)
    name, ext = os.path.splitext(path)
    with open(f'{name}.txt', 'w', encoding='utf-8') as out:
        for sent in read_tsv_as_sents(path):
            out.write(' '.join([delimiter.join(x) for x in sent]))
            out.write('\n')
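
To make the conversion concrete, a tiny illustration with made-up data: each TSV sentence of (word, tag) rows becomes one space-separated line of word_TAG tokens.

# Illustrative only: one hypothetical TSV sentence and the line ctb_pos_to_text_format would write for it.
sent = [('我', 'PN'), ('爱', 'VV'), ('北京', 'NR')]   # rows of one .tsv sentence
line = ' '.join('_'.join(x) for x in sent)            # delimiter='_' as in the default
print(line)                                           # 我_PN 爱_VV 北京_NR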
Code example #12
File: torch_component.py  Project: zhoumo99133/HanLP
    def load_config(self, save_dir, filename='config.json', **kwargs):
        """Load config from a directory.

        Args:
            save_dir: The directory to load config from.
            filename: A file name for config.
            **kwargs: K-V pairs to override config.
        """
        save_dir = get_resource(save_dir)
        self.config.load_json(os.path.join(save_dir, filename))
        self.config.update(kwargs)  # overwrite config loaded from disk
        for k, v in self.config.items():
            if isinstance(v, dict) and 'classpath' in v:
                self.config[k] = Configurable.from_config(v)
        self.on_config_ready(**self.config)
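
For reference, a hypothetical config fragment for the `classpath` branch above: any value that is a dict carrying a 'classpath' key is re-instantiated through Configurable.from_config. The class path and keys below are placeholders, not an actual HanLP config.

config_fragment = {
    'encoder': {
        'classpath': 'mypackage.encoders.MyEncoder',  # placeholder class path
        'hidden_size': 256,                           # extra kwargs passed to that class
    }
}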
Code example #13
def _list_dir(path, home):
    prefix = home.lstrip('_').replace('_HOME', '')

    path = get_resource(path)
    with open('ud27.py', 'a') as out:
        for f in sorted(glob.glob(path + '/ud-treebanks-v2.7/UD_*')):
            basename = os.path.basename(f)
            name = basename[len('UD_'):]
            name = name.upper().replace('-', '_')
            for split in 'train', 'dev', 'test':
                sp = glob.glob(f + f'/*{split}.conllu')
                if not sp:
                    continue
                sp = os.path.basename(sp[0])
                out.write(f'{prefix}_{name}_{split.upper()} = {home} + "{basename}/{sp}"\n')
                out.write(f'"{prefix} {split} set of {name}."\n')
Code example #14
File: component.py  Project: zhouxinfei/HanLP
 def load(self,
          save_dir: str,
          logger=hanlp.utils.log_util.logger,
          **kwargs):
     self.meta['load_path'] = save_dir
     save_dir = get_resource(save_dir)
     self.load_config(save_dir)
     self.load_vocabs(save_dir)
     self.build(**merge_dict(self.config,
                             training=False,
                             logger=logger,
                             **kwargs,
                             overwrite=True,
                             inplace=True))
     self.load_weights(save_dir)
     self.load_meta(save_dir)
Code example #15
def convert_to_stanford_dependency_330(src, dst):
    cprint(
        f'Converting {os.path.basename(src)} to {os.path.basename(dst)} using Stanford Parser Version 3.3.0. '
        f'It might take a while [blink][yellow]...[/yellow][/blink]')
    sp_home = 'https://nlp.stanford.edu/software/stanford-parser-full-2013-11-12.zip'
    sp_home = get_resource(sp_home)
    # jar_path = get_resource(f'{sp_home}#stanford-parser.jar')
    code, out, err = get_exitcode_stdout_stderr(
        f'java -cp {sp_home}/* edu.stanford.nlp.trees.international.pennchinese.ChineseGrammaticalStructure '
        f'-basic -keepPunct -conllx '
        f'-treeFile {src}')
    with open(dst, 'w') as f:
        f.write(out)
    if code:
        raise RuntimeError(
            f'Conversion failed with code {code} for {src}. The err message is:\n {err}\n'
            f'Do you have java installed? Do you have enough memory?')
Code example #16
File: chunking_dataset.py  Project: lei1993/HanLP
 def _generate_chars_tags(filepath, delimiter, max_seq_len):
     filepath = get_resource(filepath)
     with open(filepath, encoding='utf8') as src:
         for text in src:
             chars, tags = bmes_of(text, True)
             if max_seq_len and delimiter and len(chars) > max_seq_len:
                 short_chars, short_tags = [], []
                 for idx, (char, tag) in enumerate(zip(chars, tags)):
                     short_chars.append(char)
                     short_tags.append(tag)
                     if len(short_chars) >= max_seq_len and delimiter(char):
                         yield short_chars, short_tags
                         short_chars, short_tags = [], []
                 if short_chars:
                     yield short_chars, short_tags
             else:
                 yield chars, tags
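
A hedged usage sketch: `delimiter` is used as a predicate over single characters, so a caller can pass a function like the one below to cut overlong sequences at sentence-final punctuation. The file name and length limit are placeholders, and in the original the helper lives inside a dataset class in chunking_dataset.py.

def is_delimiter(ch):
    # Treat Chinese sentence-final punctuation as safe split points.
    return ch in '。！？'

for chars, tags in _generate_chars_tags('train.txt', delimiter=is_delimiter, max_seq_len=126):
    assert len(chars) == len(tags)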
Code example #17
 def export_model_for_serving(self, export_dir=None, version=1, overwrite=False, show_hint=False):
     assert self.model, 'You have to fit or load a model before exporting it'
     if not export_dir:
         assert 'load_path' in self.meta, 'When not specifying export_dir, load_path has to be present'
         export_dir = get_resource(self.meta['load_path'])
     model_path = os.path.join(export_dir, str(version))
     if os.path.isdir(model_path) and not overwrite:
         logger.info(f'{model_path} exists, skip since overwrite = {overwrite}')
         return export_dir
     logger.info(f'Exporting to {export_dir} ...')
     tf.saved_model.save(self.model, model_path)
     logger.info(f'Successfully exported model to {export_dir}')
     if show_hint:
         logger.info(f'You can serve it through \n'
                     f'tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} '
                     f'--model_base_path={export_dir} --rest_api_port=8888')
     return export_dir
Code example #18
 def __init__(self, filepath: str = None, vocab: Vocab = None, expand_vocab=True, lowercase=True,
              input_dim=None, output_dim=None, unk=None, normalize=False,
              embeddings_initializer='VarianceScaling',
              embeddings_regularizer=None,
              activity_regularizer=None, embeddings_constraint=None, mask_zero=True, input_length=None,
              name=None, **kwargs):
     filepath = get_resource(filepath)
     word2vec, _output_dim = load_word2vec(filepath)
     if output_dim:
         assert output_dim == _output_dim, f'output_dim = {output_dim} does not match {filepath}'
     output_dim = _output_dim
     # if the `unk` token exists in the pretrained,
     # then replace it with a self-defined one, usually the one in word vocab
     if unk and unk in word2vec:
         word2vec[vocab.safe_unk_token] = word2vec.pop(unk)
     if vocab is None:
         vocab = Vocab()
         vocab.update(word2vec.keys())
     if expand_vocab and vocab.mutable:
         for word in word2vec:
             vocab.get_idx(word.lower() if lowercase else word)
     if input_dim:
         assert input_dim == len(vocab), f'input_dim = {input_dim} does not match {filepath}'
     input_dim = len(vocab)
     # init matrix
     self._embeddings_initializer = embeddings_initializer
     embeddings_initializer = tf.keras.initializers.get(embeddings_initializer)
     with tf.device('cpu:0'):
         pret_embs = embeddings_initializer(shape=[input_dim, output_dim]).numpy()
     # insert to pret_embs
     for word, idx in vocab.token_to_idx.items():
         vec = word2vec.get(word, None)
         # Retry lower case
         if vec is None and lowercase:
             vec = word2vec.get(word.lower(), None)
         if vec is not None:
             pret_embs[idx] = vec
     if normalize:
         pret_embs /= np.std(pret_embs)
     if not name:
         name = os.path.splitext(os.path.basename(filepath))[0]
     super().__init__(input_dim, output_dim, tf.keras.initializers.Constant(pret_embs), embeddings_regularizer,
                      activity_regularizer, embeddings_constraint, mask_zero, input_length, name=name, **kwargs)
     self.filepath = filepath
     self.expand_vocab = expand_vocab
     self.lowercase = lowercase
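
The lowercase fallback in the loop above is easy to miss; here is a standalone sketch of that retry logic with made-up vectors.

# Illustrative only: look a token up in a hypothetical word2vec dict, retrying in lower case.
word2vec = {'apple': [0.1, 0.2], 'paris': [0.3, 0.4]}

def lookup(word, lowercase=True):
    vec = word2vec.get(word)
    if vec is None and lowercase:
        vec = word2vec.get(word.lower())
    return vec

print(lookup('Apple'))   # falls back to 'apple' -> [0.1, 0.2]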
Code example #19
File: hanlp_wrapper.py  Project: ragatti/sotaai
def load_dataset(ds_name: str,
                 save_dir=None,
                 batch_size=128,
                 splits=['train', 'valid', 'test']):
    if isinstance(splits, str):
        splits = [splits]
    # The tokenizer is used exclusively to access a class function for loading a dataset.
    helper = load_tokenizer('zh')
    task = find_task_ds(ds_name)
    ds = {}
    lib = 'hanlp.datasets.' + task
    if task == 'classification':
        lib = importlib.import_module(lib + '.sentiment')
    elif task == 'cws':
        if ds_name == 'CTB6_CWS':
            lib = importlib.import_module(lib + '.ctb')
        elif 'MSR' in ds_name:
            lib = importlib.import_module(lib + '.sighan2005.msr')
        else:
            lib = importlib.import_module(lib + '.sighan2005.pku')
    elif task == 'pos':
        lib = importlib.import_module(lib + '.ctb')
    elif task == 'ner':
        if 'CONLL03' in ds_name:
            lib = importlib.import_module(lib + '.conll03')
        else:
            lib = importlib.import_module(lib + '.msra')
    elif task == 'dep':
        if 'CTB' in ds_name:
            lib = importlib.import_module(lib[:-3] + 'parsing.ctb')
        else:
            lib = importlib.import_module(lib[:-3] + 'parsing.semeval2016')

    for split in splits:
        url = getattr(lib, ds_name + '_' + split.upper())
        input_path = get_resource(url)
        if split == 'train' and 'SIGHAN2005' in ds_name:
            from hanlp.datasets.cws.sighan2005 import make
            make(url)

        ds[split] = helper.transform.file_to_dataset(input_path,
                                                     batch_size=batch_size)
    return ds
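
A hedged usage sketch for the wrapper above; 'CTB6_CWS' is one of the dataset names the function itself branches on, and passing a single split as a string is handled by the isinstance check at the top.

# Hypothetical call: load only the test split of the CTB6 word segmentation corpus.
ds = load_dataset('CTB6_CWS', batch_size=32, splits='test')
test_data = ds['test']   # whatever helper.transform.file_to_dataset returns for that split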
Code example #20
def load_domains(ctb_home):
    """
    Load file IDs from a Chinese treebank, grouped by domain.

    Args:
        ctb_home: Root path to CTB.

    Returns:
        A dict of sets, each representing a domain.
    """
    ctb_home = get_resource(ctb_home)
    ctb_root = join(ctb_home, 'bracketed')
    chtbs = _list_treebank_root(ctb_root)
    domains = defaultdict(set)
    for each in chtbs:
        name, domain = each.split('.')
        _, fid = name.split('_')
        domains[domain].add(fid)
    return domains
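
A small usage sketch with a placeholder CTB path: the returned mapping lets you count files per domain.

domains = load_domains('/path/to/ctb9')   # placeholder path
for domain, fids in sorted(domains.items()):
    print(domain, len(fids))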
Code example #21
File: tsv.py  Project: lei1993/HanLP
    def load_file(self, filepath):
        """Load a ``.tsv`` file. A ``.tsv`` file for tagging is defined as a tab separated text file, where non-empty
        lines have two columns for token and tag respectively, empty lines mark the end of sentences.

        Args:
            filepath: Path to a ``.tsv`` tagging file.

        .. highlight:: bash
        .. code-block:: bash

            $ head eng.train.tsv
            -DOCSTART-      O

            EU      S-ORG
            rejects O
            German  S-MISC
            call    O
            to      O
            boycott O
            British S-MISC
            lamb    O

        """
        filepath = get_resource(filepath)
        # idx = 0
        for words, tags in generate_words_tags_from_tsv(filepath, lower=False):
            # idx += 1
            # if idx % 1000 == 0:
            #     print(f'\rRead instances {idx // 1000}k', end='')
            if self.max_seq_len:
                start = 0
                for short_sents in split_long_sentence_into(
                        words,
                        self.max_seq_len,
                        self.sent_delimiter,
                        char_level=self.char_level,
                        hard_constraint=self.hard_constraint):
                    end = start + len(short_sents)
                    yield {'token': short_sents, 'tag': tags[start:end]}
                    start = end
            else:
                yield {'token': words, 'tag': tags}
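
A hedged sketch of what iterating the generator yields; `tagging_dataset` stands for an instance of the class that defines load_file above, and the sample keys 'token' and 'tag' come straight from the code. The example values are taken from the eng.train.tsv excerpt in the docstring.

for sample in tagging_dataset.load_file('eng.train.tsv'):
    tokens, tags = sample['token'], sample['tag']
    # e.g. the "EU rejects German call ..." sentence gives
    # tokens = ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb']
    # tags   = ['S-ORG', 'O', 'S-MISC', 'O', 'O', 'O', 'S-MISC', 'O']
    assert len(tokens) == len(tags)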
Code example #22
File: conll.py  Project: zuoqy/HanLP
def read_conll(filepath):
    sent = []
    filepath = get_resource(filepath)
    with open(filepath, encoding='utf-8') as src:
        for line in src:
            if line.startswith('#'):
                continue
            cells = line.strip().split()
            if cells:
                cells[0] = int(cells[0])
                cells[6] = int(cells[6])
                for i, x in enumerate(cells):
                    if x == '_':
                        cells[i] = None
                sent.append(cells)
            else:
                yield sent
                sent = []
    if sent:
        yield sent
Code example #23
File: ud23.py  Project: lei1993/HanLP
def _list_dir(path, home):
    prefix = home.lstrip('_').replace('_HOME', '')

    from hanlp.utils.io_util import get_resource
    import glob
    import os
    path = get_resource(path)
    with open('ud23.py', 'a') as out:
        for f in sorted(glob.glob(path + '/UD_*')):
            basename = os.path.basename(f)
            name = basename[len('UD_'):]
            name = name.upper().replace('-', '_')
            for split in 'train', 'dev', 'test':
                sp = glob.glob(f + f'/*{split}.conllu')
                if not sp:
                    continue
                sp = os.path.basename(sp[0])
                out.write(
                    f'{prefix}_{name}_{split.upper()} = {home} + "#{basename}/{sp}"\n'
                )
Code example #24
File: __init__.py  Project: lei1993/HanLP
def make(train):
    root = get_resource(SIGHAN2005)
    train = os.path.join(root, train.split('#')[-1])
    if not os.path.isfile(train):
        full = train.replace('_90.txt', '.utf8')
        logger.info(
            f'Splitting {full} into training set and valid set with 9:1 proportion'
        )
        valid = train.replace('90.txt', '10.txt')
        split_file(full,
                   train=0.9,
                   dev=0.1,
                   test=0,
                   names={
                       'train': train,
                       'dev': valid
                   })
        assert os.path.isfile(train), f'Failed to make {train}'
        assert os.path.isfile(valid), f'Failed to make {valid}'
        logger.info(f'Successfully made {train} {valid}')
Code example #25
File: fast_text_tf.py  Project: lei1993/HanLP
 def __init__(self, filepath: str, padding=PAD, name=None, **kwargs):
     import fasttext
     self.padding = padding.encode('utf-8')
     self.filepath = filepath
     filepath = get_resource(filepath)
     assert os.path.isfile(filepath), f'Resolved path {filepath} is not a file'
     logger.debug('Loading fasttext model from [{}].'.format(filepath))
     # fasttext prints a blank line here
     with stdout_redirected(to=os.devnull, stdout=sys.stderr):
         self.model = fasttext.load_model(filepath)
     kwargs.pop('input_dim', None)
     kwargs.pop('output_dim', None)
     kwargs.pop('mask_zero', None)
     if not name:
         name = os.path.splitext(os.path.basename(filepath))[0]
     super().__init__(input_dim=len(self.model.words), output_dim=self.model['king'].size,
                      mask_zero=padding is not None, trainable=False, dtype=tf.string, name=name, **kwargs)
     embed_fn = np.frompyfunc(self.embed, 1, 1)
     # vf = np.vectorize(self.embed, otypes=[np.ndarray])
     self._embed_np = embed_fn
Code example #26
File: torch_util.py  Project: wainshine/HanLP
def load_word2vec(path,
                  delimiter=' ',
                  cache=True) -> Tuple[Dict[str, np.ndarray], int]:
    realpath = get_resource(path)
    binpath = replace_ext(realpath, '.pkl')
    if cache:
        try:
            flash(
                'Loading word2vec from cache [blink][yellow]...[/yellow][/blink]'
            )
            word2vec, dim = load_pickle(binpath)
            flash('')
            return word2vec, dim
        except IOError:
            pass

    dim = None
    word2vec = dict()
    f = TimingFileIterator(realpath)
    for idx, line in enumerate(f):
        f.log(
            'Loading word2vec from text file [blink][yellow]...[/yellow][/blink]'
        )
        line = line.rstrip().split(delimiter)
        if len(line) > 2:
            if dim is None:
                dim = len(line)
            else:
                if len(line) != dim:
                    logger.warning('{}#{} length mismatches with {}'.format(
                        path, idx + 1, dim))
                    continue
            word, vec = line[0], line[1:]
            word2vec[word] = np.array(vec, dtype=np.float32)
    dim -= 1
    if cache:
        flash('Caching word2vec [blink][yellow]...[/yellow][/blink]')
        save_pickle((word2vec, dim), binpath)
        flash('')
    return word2vec, dim
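
A hedged round-trip sketch for the loader above: each line is a word followed by its vector components separated by `delimiter`, and a word2vec-style header line with only two fields is skipped by the len(line) > 2 check. The file name is a placeholder, and get_resource is assumed to pass an existing local path through unchanged.

with open('tiny.vec', 'w', encoding='utf-8') as f:   # placeholder embedding file
    f.write('king 0.1 0.2 0.3\n')
    f.write('queen 0.4 0.5 0.6\n')
word2vec, dim = load_word2vec('tiny.vec', cache=False)
# dim == 3 and word2vec['king'] is a float32 numpy array of length 3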
Code example #27
File: lm_dataset.py  Project: cfy42584125/HanLP-1
    def __init__(self,
                 data: str,
                 batch_size,
                 seq_len,
                 tokenizer='char',
                 eos='\n',
                 strip=True,
                 vocab=None,
                 cache=False,
                 transform: Union[Callable, List] = None) -> None:
        self.cache = cache
        self.eos = eos
        self.strip = strip
        super().__init__(transform)
        if isinstance(tokenizer, str):
            available_tokenizers = {
                'char': ToChar('text', 'token'),
                'whitespace': WhitespaceTokenizer('text', 'token')
            }
            assert tokenizer in available_tokenizers, f'{tokenizer} not supported, available options: {available_tokenizers.keys()} '
            self.append_transform(available_tokenizers[tokenizer])

        if vocab is None:
            vocab = Vocab()
            self.training = True
        else:
            self.training = vocab.mutable
        self.append_transform(AppendEOS('token', eos=eos))
        self.append_transform(FieldToIndex('token', vocab))
        self.batch_size = batch_size
        data = get_resource(data)
        self.data = data
        self.num_tokens = None
        self.load_file(data)
        self._fp = None
        if isinstance(seq_len, int):
            self.seq_len = lambda: seq_len
        else:
            self.seq_len = seq_len
Code example #28
File: dataset.py  Project: zhoumo99133/HanLP
    def load_data(self, data, generate_idx=False):
        """A intermediate step between constructor and calling the actual file loading method.

        Args:
            data: If data is a file, this method calls :meth:`~hanlp.common.dataset.TransformableDataset.load_file`
                to load it.
            generate_idx: Create a :const:`~hanlp_common.constants.IDX` field for each sample to store its order in dataset. Useful for prediction when
                samples are re-ordered by a sampler.

        Returns: Loaded samples.

        """
        if self.should_load_file(data):
            if isinstance(data, str):
                data = get_resource(data)
            data = list(self.load_file(data))
        if generate_idx:
            for i, each in enumerate(data):
                each[IDX] = i
        # elif isinstance(data, list):
        #     data = self.load_list(data)
        return data
Code example #29
def load_word2vec_as_vocab_tensor(path, delimiter=' ', cache=True) -> Tuple[Dict[str, int], torch.Tensor]:
    realpath = get_resource(path)
    vocab_path = replace_ext(realpath, '.vocab')
    matrix_path = replace_ext(realpath, '.pt')
    if cache:
        try:
            flash('Loading vocab and matrix from cache [blink][yellow]...[/yellow][/blink]')
            vocab = load_pickle(vocab_path)
            matrix = torch.load(matrix_path, map_location='cpu')
            flash('')
            return vocab, matrix
        except IOError:
            pass

    word2vec, dim = load_word2vec(path, delimiter, cache)
    vocab = dict((k, i) for i, k in enumerate(word2vec.keys()))
    matrix = torch.Tensor(list(word2vec.values()))
    if cache:
        flash('Caching vocab and matrix [blink][yellow]...[/yellow][/blink]')
        save_pickle(vocab, vocab_path)
        torch.save(matrix, matrix_path)
        flash('')
    return vocab, matrix
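
A hedged follow-up sketch: the (vocab, matrix) pair returned above maps directly onto a frozen torch.nn.Embedding. The path below is a placeholder.

import torch

vocab, matrix = load_word2vec_as_vocab_tensor('/path/to/embeddings.txt', cache=False)
emb = torch.nn.Embedding.from_pretrained(matrix, freeze=True)
king_id = vocab.get('king', 0)                # 0 as a fallback index for this sketch
king_vec = emb(torch.tensor([king_id]))       # shape: (1, dim)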
Code example #30
def official_conll_05_evaluate(pred_path, gold_path):
    script_root = get_resource(
        'http://www.lsi.upc.edu/~srlconll/srlconll-1.1.tgz')
    lib_path = f'{script_root}/lib'
    if lib_path not in os.environ.get("PERL5LIB", ""):
        os.environ['PERL5LIB'] = f'{lib_path}:{os.environ.get("PERL5LIB", "")}'
    bin_path = f'{script_root}/bin'
    if bin_path not in os.environ.get('PATH', ''):
        os.environ['PATH'] = f'{bin_path}:{os.environ.get("PATH", "")}'
    eval_info_gold_pred = run_cmd(
        f'perl {script_root}/bin/srl-eval.pl {gold_path} {pred_path}')
    eval_info_pred_gold = run_cmd(
        f'perl {script_root}/bin/srl-eval.pl {pred_path} {gold_path}')
    conll_recall = float(
        eval_info_gold_pred.strip().split("\n")[6].strip().split()[5]) / 100
    conll_precision = float(
        eval_info_pred_gold.strip().split("\n")[6].strip().split()[5]) / 100
    if conll_recall + conll_precision > 0:
        conll_f1 = 2 * conll_recall * conll_precision / (conll_recall +
                                                         conll_precision)
    else:
        conll_f1 = 0
    return conll_precision, conll_recall, conll_f1
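
A hedged usage sketch with placeholder CoNLL-2005 props files; it assumes perl is installed and the srlconll-1.1 scripts have been fetched by get_resource as in the code above.

p, r, f1 = official_conll_05_evaluate('pred.props', 'gold.props')   # placeholder paths
print(f'P={p:.4f} R={r:.4f} F1={f1:.4f}')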