Code Example #1
    # Assumes module-level imports: json, plus Adam (e.g. from Keras' optimizer
    # module) and the project helpers load_dictionary / create_embedding_matrix /
    # get_embedding_index.
    @staticmethod
    def get_or_create(config,
                      src_dict_path=None,
                      tgt_dict_path=None,
                      weights_path=None,
                      embedding_file=None,
                      optimizer=Adam(),  # note: default instance is created once, at def time
                      encoding="utf-8"):
        if DLSegmenter.__singleton is None:
            if isinstance(config, str):
                # A string config is treated as a path to a JSON file.
                with open(config, encoding=encoding) as file:
                    config = dict(json.load(file))
            elif isinstance(config, dict):
                pass  # Already a config dict; use it as-is.
            else:
                raise ValueError("Unexpected config type!")

            if src_dict_path is not None:
                # Load the source-side tokenizer from its dictionary file.
                src_tokenizer = load_dictionary(src_dict_path, encoding)
                config['src_tokenizer'] = src_tokenizer
                if embedding_file is not None:
                    # Build a pretrained embedding matrix, sized to the smaller
                    # of vocab_size + 1 and max_num_words.
                    emb_matrix = create_embedding_matrix(
                        get_embedding_index(embedding_file),
                        src_tokenizer.word_index,
                        min(config['vocab_size'] + 1,
                            config['max_num_words']), config['embed_dim'])
                    config['emb_matrix'] = emb_matrix
            if tgt_dict_path is not None:
                # Load the target-side (label) tokenizer.
                config['tgt_tokenizer'] = load_dictionary(
                    tgt_dict_path, encoding)

            config['weights_path'] = weights_path
            config['optimizer'] = optimizer
            # Cache the instance; later calls return it without rebuilding.
            DLSegmenter.__singleton = DLSegmenter(**config)
        return DLSegmenter.__singleton
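
A minimal usage sketch follows (the paths are hypothetical, and it assumes get_or_create is exposed as a static factory on DLSegmenter, as the body above implies):

    # Hypothetical paths; adjust to the actual project layout.
    segmenter = DLSegmenter.get_or_create(
        "config/default-config.json",          # parsed into the config dict
        src_dict_path="config/src_dict.json",
        tgt_dict_path="config/tgt_dict.json",
        weights_path="models/weights.h5")
    # Later calls return the cached singleton; new arguments are ignored.
    assert DLSegmenter.get_or_create("config/default-config.json") is segmenter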
Code Example #2
    # Assumes module-level imports: json, traceback, and Adam, plus the same
    # project helpers as in Code Example #1.
    @staticmethod
    def get_or_create(config,
                      src_dict_path=None,
                      tgt_dict_path=None,
                      weights_path=None,
                      embedding_file=None,
                      optimizer=Adam(),
                      rule_fn=None,
                      encoding="utf-8"):
        # Serialize singleton construction across threads with a class-level lock.
        DLSegmenter.__lock.acquire()
        try:
            if DLSegmenter.__singleton is None:
                if isinstance(config, str):
                    # A string config is treated as a path to a JSON file.
                    with open(config, encoding=encoding) as file:
                        config = dict(json.load(file))
                elif isinstance(config, dict):
                    pass  # Already a config dict; use it as-is.
                else:
                    raise ValueError("Unexpected config type!")

                if src_dict_path is not None:
                    src_tokenizer = load_dictionary(src_dict_path, encoding)
                    config['src_tokenizer'] = src_tokenizer
                    if embedding_file is not None:
                        embedding_index = get_embedding_index(embedding_file)
                        print('Embedding index size: {}'.format(len(embedding_index)))

                        emb_matrix = create_embedding_matrix(
                            embedding_index, src_tokenizer.word_index,
                            min(config['vocab_size'] + 1,
                                config['max_num_words']), config['embed_dim'])
                        config['emb_matrix'] = emb_matrix
                if tgt_dict_path is not None:
                    config['tgt_tokenizer'] = load_dictionary(
                        tgt_dict_path, encoding)

                config['rule_fn'] = rule_fn
                config['weights_path'] = weights_path
                config['optimizer'] = optimizer
                DLSegmenter.__singleton = DLSegmenter(**config)
        except Exception:
            # Log construction errors; note the caller then receives None.
            traceback.print_exc()
        finally:
            DLSegmenter.__lock.release()
        return DLSegmenter.__singleton
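
Since this variant serializes construction behind the class-level lock, concurrent first calls are safe. A hedged sketch (hypothetical paths; rule_fn shown as a no-op placeholder):

    import threading

    instances = []

    def build():
        # All concurrent first calls funnel through the lock and
        # receive the same singleton instance.
        return DLSegmenter.get_or_create(
            "config/default-config.json",      # hypothetical path
            weights_path="models/weights.h5",  # hypothetical path
            rule_fn=lambda tokens: tokens)     # placeholder post-processing hook

    threads = [threading.Thread(target=lambda: instances.append(build()))
               for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert all(inst is instances[0] for inst in instances)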
Code Example #3
    # Assumes the project helper load_dictionary returns a Keras-style tokenizer
    # (one exposing num_words, as used below).
    def __init__(self,
                 src_dict_path,
                 tgt_dict_path,
                 batch_size=64,
                 max_len=999,
                 fix_len=True,
                 word_delimiter=' ',
                 sent_delimiter='\t',
                 encoding="utf-8",
                 sparse_target=False):
        # Load source (text) and target (label) tokenizers from their dictionaries.
        self.src_tokenizer = load_dictionary(src_dict_path, encoding)
        self.tgt_tokenizer = load_dictionary(tgt_dict_path, encoding)
        self.batch_size = batch_size
        self.max_len = max_len
        self.fix_len = fix_len
        self.word_delimiter = word_delimiter
        self.sent_delimiter = sent_delimiter
        # Vocabulary sizes are taken from the loaded tokenizers.
        self.src_vocab_size = self.src_tokenizer.num_words
        self.tgt_vocab_size = self.tgt_tokenizer.num_words
        self.sparse_target = sparse_target
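
A hedged instantiation sketch (the owning class name DataLoader and the dictionary paths are placeholders; only the parameters shown in the constructor above are used):

    # "DataLoader" is a hypothetical name for the class owning this __init__;
    # the dictionary paths are likewise placeholders.
    loader = DataLoader(
        src_dict_path="config/src_dict.json",
        tgt_dict_path="config/tgt_dict.json",
        batch_size=32,
        max_len=150,          # cap sequence length instead of the 999 default
        sent_delimiter='\t',  # sentence and label columns separated by a tab
        sparse_target=True)   # keep labels as sparse indices, not one-hot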