コード例 #1
0
    def __init__(self, lang_code):
        """
        Prepare a transliterator for a single language.

        Parameters:

        lang_code (str): Name or code of the language. (Currently supported: hindi/hi)

        Language names are first translated to codes via ``lang_code_mapping``.
        Model files are cached in ``~/.DeepTranslit_<code>``; any that are
        missing are downloaded before the model is built. For an unsupported
        language a message is printed and the method returns early, so the
        instance ends up without ``model``/``params`` attributes.
        """
        # Normalise a language name (e.g. "hindi") to its code.
        code = lang_code_mapping[lang_code] if lang_code in lang_code_mapping else lang_code

        # NOTE(review): printing and returning leaves a half-initialised
        # object; raising would be clearer, but callers may rely on this.
        if code not in model_links:
            print("DeepTranslit doesn't support '" + code + "' yet.")
            print(
                "Please raise a issue at https://github.com/bedapudi6788/deeptranslit to add this language into future checklist."
            )
            return None

        cache_dir = os.path.join(os.path.expanduser("~"), '.DeepTranslit_' + code)
        checkpoint_file = os.path.join(cache_dir, 'checkpoint')
        params_file = os.path.join(cache_dir, 'params')

        if not os.path.exists(cache_dir):
            os.mkdir(cache_dir)

        links = model_links[code]
        # Each entry: (text shown in the log, key into links, local destination).
        wanted = (
            ('checkpoint', 'checkpoint', checkpoint_file),
            ('model params', 'params', params_file),
        )
        for label, key, target in wanted:
            if os.path.exists(target):
                continue
            url = links[key]
            print('Downloading ' + label, url, 'to', target)
            pydload.dload(url=url, save_to_path=target, max_time=None)

        self.model, self.params = build_model(params_path=params_file,
                                              enc_lstm_units=64,
                                              use_gru=True,
                                              display_summary=False)
        self.model.load_weights(checkpoint_file)
コード例 #2
0
    def __init__(self, lang_code, rank='auto'):
        """
        Initialize deeptranslit

        Parameters:

        lang_code (str): Name or code of the language. (Currently supported: hindi/hi)

        rank (str): Mode of ranking. In default mode ('auto') kenlm will be used if available. (simple|kenlm|auto are the supported options)

        Model artifacts (checkpoint, params, words, lm) are cached in
        ``~/.DeepTranslit_<code>``; missing ones are downloaded first.
        The loaded objects are stored on the ``DeepTranslit`` class, not on
        the instance, so they are shared by every instance. ``DeepTranslit.lm``
        and ``DeepTranslit.rank`` are set only when KenLM is actually loaded.
        For an unsupported language a message is printed and nothing is loaded.
        """

        if lang_code in lang_code_mapping:
            lang_code = lang_code_mapping[lang_code]

        if lang_code not in model_links:
            print("DeepTranslit doesn't support '" + lang_code + "' yet.")
            print(
                "Please raise a issue at https://github.com/bedapudi6788/deeptranslit to add this language into future checklist."
            )
            return None

        # loading the model
        home = os.path.expanduser("~")
        lang_path = os.path.join(home, '.DeepTranslit_' + lang_code)
        checkpoint_path = os.path.join(lang_path, 'checkpoint')
        params_path = os.path.join(lang_path, 'params')
        words_path = os.path.join(lang_path, 'words')
        lm_path = os.path.join(lang_path, 'lm')

        if not os.path.exists(lang_path):
            os.mkdir(lang_path)

        # Each entry: (text shown in the log, key into model_links[lang_code], local path).
        artifacts = (
            ('checkpoint', 'checkpoint', checkpoint_path),
            ('model params', 'params', params_path),
            ('words', 'words', words_path),
            ('lm', 'lm', lm_path),
        )
        for label, key, path in artifacts:
            if os.path.exists(path):
                continue
            url = model_links[lang_code][key]
            print('Downloading ' + label, url, 'to', path)
            pydload.dload(url=url, save_to_path=path, max_time=None)

        DeepTranslit.model, DeepTranslit.params = build_model(
            params_path=params_path,
            enc_lstm_units=64,
            use_gru=True,
            display_summary=False)
        DeepTranslit.model.load_weights(checkpoint_path)

        # NOTE(review): pickle.load can execute arbitrary code; this file is
        # downloaded from model_links, so it is only safe if that source is trusted.
        # The `with` block closes the handle (the previous bare open() leaked it).
        with open(words_path, 'rb') as words_file:
            DeepTranslit.words = pickle.load(words_file)

        if kenlm_available and rank in {'auto', 'kenlm'}:
            logging.warning('Loading KenLM.')
            DeepTranslit.lm = kenlm.Model(lm_path)
            DeepTranslit.rank = rank
        elif rank == 'kenlm':
            # Explicitly requested but unusable: say so instead of ignoring it.
            logging.warning("rank='kenlm' requested but kenlm is not available; "
                            "the language model was not loaded.")
コード例 #3
0
  # Replace every character of train_text[j] that is not in `chars` with a
  # space, in a single pass over the string (each character is tested once).
  train_text[j] = ''.join(c if c in chars else ' ' for c in train_text[j])

# Show one cleaned sample as a sanity check. A bare expression only displays
# in a notebook, so print it to make it visible when run as a script.
print(train_text[3])

from txt2txt import build_params, build_model, convert_training_data, infer
from keras.callbacks import ModelCheckpoint

# Input and target are the same texts: the model learns to reproduce them.
input_data = train_text
output_data = train_text

# NOTE(review): `max_lenghts` (sic) appears to be the keyword spelling txt2txt
# itself expects; do not "fix" it here without checking the library.
build_params(input_data=input_data, output_data=output_data,
             params_path='params', max_lenghts=(10, 10))

model, params = build_model(params_path='params')

input_data, output_data = convert_training_data(input_data, output_data, params)

# Save weights only when the monitored validation metric improves.
# NOTE(review): newer Keras versions log accuracy as 'val_accuracy'; with
# 'val_acc' the callback may never save. Confirm against the installed Keras.
checkpoint = ModelCheckpoint('checkpoint', monitor='val_acc', verbose=1,
                             save_best_only=True, mode='max')
callbacks_list = [checkpoint]

# NOTE(review): validation uses the training data itself, so the "best"
# checkpoint reflects training fit, not generalization.
model.fit(input_data, output_data,
          validation_data=(input_data, output_data),
          batch_size=2, epochs=2, callbacks=callbacks_list)

# Rebuild the model and restore the saved weights before inference.
model, params = build_model(params_path='params')
model.load_weights('checkpoint')

print(infer('nmae', model, params))
コード例 #4
0
 def __init__(self, params_path, checkpoint_path):
     """Build the DeepCorrect model and load its trained weights.

     The value returned by ``build_model(params_path)`` is stored on the
     DeepCorrect class itself (so it is shared by all instances). Weights
     from ``checkpoint_path`` are then loaded into its first element.
     """
     built = build_model(params_path)
     DeepCorrect.deepcorrect_model = built
     network = built[0]
     network.load_weights(checkpoint_path)
コード例 #5
0
ファイル: correct.py プロジェクト: ntcp/deepcorrect
from txt2txt import build_model, infer

# Load the trained model: configuration from 'params', weights from 'checkpoint'.
model, params = build_model(params_path='params', enc_lstm_units=256)
model.load_weights('checkpoint')

# Read-correct-print loop. EOF (Ctrl-D / end of piped input) or Ctrl-C while
# waiting for input ends the session cleanly instead of with a traceback.
while True:
    print('Enter input')
    try:
        text = input()
    except (EOFError, KeyboardInterrupt):
        break
    print(infer(text, model, params))