Example #1
def test_inferring(infer_enc_model, infer_dec_model, plot=False):
    """ Inferring with trained model """
    rand_test_ids = np.random.randint(0, len(ts_en_text), size=10)
    for rid in rand_test_ids:
        test_en = ts_en_text[rid]
        # logger.info('\tRequest: {}'.format(test_en))
        print('Request: {}'.format(test_en))

        test_en_seq = sents2sequences(en_tokenizer, [test_en],
                                      pad_length=en_timesteps)
        test_fr, attn_weights = infer_nmt(encoder_model=infer_enc_model,
                                          decoder_model=infer_dec_model,
                                          test_en_seq=test_en_seq,
                                          en_vsize=en_vsize,
                                          fr_vsize=fr_vsize)
        print(Fore.GREEN + 'Response: {}'.format(test_fr) + Style.RESET_ALL)
        # print()

        if plot:
            """ Attention plotting """
            plot_attention_weights(test_en_seq,
                                   attn_weights,
                                   en_index2word,
                                   fr_index2word,
                                   base_dir=base_dir,
                                   filename='attention_{}.png'.format(rid))
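
The sents2sequences helper used above is not part of these snippets. A minimal sketch of what it presumably does, assuming a fitted Keras Tokenizer (the signature and padding_type default are guesses):

# Hedged reconstruction of sents2sequences: tokenize, then zero-pad so the
# batch has a fixed length. Signature and defaults are assumptions.
from tensorflow.keras.preprocessing.sequence import pad_sequences

def sents2sequences(tokenizer, sentences, pad_length=None, padding_type='post'):
    encoded = tokenizer.texts_to_sequences(sentences)  # words -> integer ids
    return pad_sequences(encoded, padding=padding_type, maxlen=pad_length)
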
Example #2
    """ Save model """
    if not os.path.exists(config.MODELS_DIR):
        os.mkdir(config.MODELS_DIR)
    full_model.save(os.path.join(config.MODELS_DIR, 'nmt_bidirectional.h5'))
    """ Index2word """
    en_index2word = {v: k for k, v in en_tokenizer.word_index.items()}
    fr_index2word = {v: k for k, v in fr_tokenizer.word_index.items()}
    """ Inferring with trained model """

    np.random.seed(100)
    rand_test_ids = np.random.randint(0, len(ts_en_text), size=10)
    for rid in rand_test_ids:
        test_en = ts_en_text[rid]
        logger.info('\nTranslating: {}'.format(test_en))

        test_en_seq = sents2sequences(en_tokenizer, [test_en],
                                      pad_length=en_timesteps)
        test_fr, attn_weights = infer_nmt(encoder_model=infer_enc_model,
                                          decoder_model=infer_dec_model,
                                          test_en_seq=test_en_seq,
                                          en_vsize=en_vsize,
                                          fr_vsize=fr_vsize)
        logger.info('\tFrench: {}'.format(test_fr))
        """ Attention plotting """
        plot_attention_weights(test_en_seq,
                               attn_weights,
                               en_index2word,
                               fr_index2word,
                               filename='attention_{}.png'.format(rid))
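
infer_nmt itself is also external to these examples. A minimal sketch of the greedy decoding loop it presumably implements, assuming separate inference-time encoder and decoder models fed one-hot inputs; the model input/output layouts, the 'sos'/'eos' tokens, and the 20-step cap are all assumptions:

# Hedged sketch of infer_nmt: encode once, then greedily decode one token
# per step, collecting the attention weights emitted by the decoder.
# Relies on module-level fr_tokenizer / fr_index2word, as the examples do.
import numpy as np
from tensorflow.keras.utils import to_categorical

def infer_nmt(encoder_model, decoder_model, test_en_seq, en_vsize, fr_vsize):
    enc_outs, enc_last_state = encoder_model.predict(
        to_categorical(test_en_seq, num_classes=en_vsize))
    dec_state = enc_last_state

    dec_ind = fr_tokenizer.word_index['sos']  # assumed start-of-sentence token
    attention_weights, fr_words = [], []
    for _ in range(20):  # assumed hard cap on output length
        dec_in = to_categorical([[dec_ind]], num_classes=fr_vsize)  # (1, 1, fr_vsize)
        dec_out, attention, dec_state = decoder_model.predict(
            [enc_outs, dec_state, dec_in])
        dec_ind = int(np.argmax(dec_out, axis=-1)[0, 0])
        if dec_ind == 0 or fr_index2word.get(dec_ind) == 'eos':
            break  # padding index or end-of-sentence token ends decoding
        attention_weights.append((dec_ind, attention))
        fr_words.append(fr_index2word[dec_ind])
    return ' '.join(fr_words), attention_weights
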
Example #3
        en_timesteps=en_timesteps,
        fr_timesteps=fr_timesteps,
        en_vsize=en_vsize,
        fr_vsize=fr_vsize)

    n_epochs = 10 if not debug else 3
    train(full_model, en_seq, fr_seq, batch_size, n_epochs)
    """ Save model """
    if not os.path.exists(os.path.join('..', 'h5.models')):
        os.mkdir(os.path.join('..', 'h5.models'))
    full_model.save(os.path.join('..', 'h5.models', 'nmt.h5'))
    """ Index2word """
    en_index2word = {v: k for k, v in en_tokenizer.word_index.items()}
    fr_index2word = {v: k for k, v in fr_tokenizer.word_index.items()}
    """ Inferring with trained model """
    test_en = ts_en_text[0]
    logger.info('Translating: {}'.format(test_en))

    test_en_seq = sents2sequences(en_tokenizer, [test_en],
                                  pad_length=en_timesteps)
    test_fr, attn_weights = infer_nmt(encoder_model=infer_enc_model,
                                      decoder_model=infer_dec_model,
                                      test_en_seq=test_en_seq,
                                      en_vsize=en_vsize,
                                      fr_vsize=fr_vsize)
    logger.info('\tFrench: {}'.format(test_fr))
    """ Attention plotting """
    plot_attention_weights(test_en_seq, attn_weights, en_index2word,
                           fr_index2word)
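
Since full_model.save(...) writes a plain HDF5 file, reloading it later requires any custom attention layer to be passed through custom_objects. A usage sketch; the AttentionLayer name and its import path are assumptions:

# Reloading the saved model. Keras cannot deserialize a custom layer
# unless it is supplied explicitly via custom_objects.
import os
from tensorflow.keras.models import load_model
from attention import AttentionLayer  # hypothetical import path

model = load_model(os.path.join('..', 'h5.models', 'nmt.h5'),
                   custom_objects={'AttentionLayer': AttentionLayer})
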
Example #4
    fr_vsize=fr_vsize)

    n_epochs = 10 if not debug else 3
    train(full_model, en_seq, fr_seq, batch_size, n_epochs)
    """ Save model """
    if not os.path.exists(os.path.join('..', 'h5.models')):
        os.mkdir(os.path.join('..', 'h5.models'))
    full_model.save(os.path.join('..', 'h5.models', 'nmt.h5'))
    """ Index2word """
    en_index2word = {v: k for k, v in en_tokenizer.word_index.items()}
    fr_index2word = {v: k for k, v in fr_tokenizer.word_index.items()}
    """ Inferring with trained model """
    test_en = ts_en_text[0]
    logger.info('Translating: {}'.format(test_en))

    test_en_seq = sents2sequences(en_tokenizer, [test_en],
                                  pad_length=en_timesteps)
    test_fr, attn_weights = infer_nmt(encoder_model=infer_enc_model,
                                      decoder_model=infer_dec_model,
                                      test_en_seq=test_en_seq,
                                      en_vsize=en_vsize,
                                      fr_vsize=fr_vsize)
    logger.info('\tFrench: {}'.format(test_fr))
    """ Attention plotting """
    plot_attention_weights(test_en_seq,
                           attn_weights,
                           en_index2word,
                           fr_index2word,
                           base_dir=base_dir)
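
plot_attention_weights is likewise not shown. A rough matplotlib sketch of such a function, assuming attn_weights is a list of (decoded_id, weights) pairs with one attention vector per decoding step (this structure matches the infer_nmt sketch above, not necessarily the original code):

# Hedged sketch: render per-step attention vectors as a heatmap, with
# source words on the x-axis and generated words on the y-axis.
import os
import numpy as np
import matplotlib.pyplot as plt

def plot_attention_weights(en_seq, attn_weights, en_id2word, fr_id2word,
                           base_dir='.', filename='attention.png'):
    mats = np.array([w.reshape(-1) for _, w in attn_weights])  # (fr_len, en_len)
    fig, ax = plt.subplots()
    ax.imshow(mats, cmap='viridis')
    ax.set_xticks(range(en_seq.shape[1]))
    ax.set_xticklabels([en_id2word.get(i, '') for i in en_seq.ravel()],
                       rotation=90)
    ax.set_yticks(range(len(attn_weights)))
    ax.set_yticklabels([fr_id2word.get(i, '') for i, _ in attn_weights])
    fig.savefig(os.path.join(base_dir, filename))
    plt.close(fig)
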
Example #5
    fr_index2word = {v: k for k, v in fr_tokenizer.word_index.items()}

    inferred_fr_text = []
    actual_fr_text = []
    counter = 0
    max_save_plots = 20
    for i in range(min(len(ts_en_text), 1000)):
        """ Inferring with trained model """
        test_en = ts_en_text[i]

        test_en_seq = sents2sequences(en_tokenizer, [test_en],
                                      pad_length=en_timesteps)
        test_fr, attn_weights = infer_nmt(encoder_model=infer_enc_model,
                                          decoder_model=infer_dec_model,
                                          test_en_seq=test_en_seq,
                                          en_vsize=en_vsize,
                                          fr_vsize=fr_vsize)
        inferred_fr_text.append(test_fr.split())
        actual_fr_text.append([ts_fr_text[i].split()])
        """ Attention plotting """
        if counter < max_save_plots:
            logger.info('Translating: {}'.format(test_en))
            logger.info('Persian: {}'.format(test_fr))
            plot_attention_weights(test_en_seq, attn_weights, en_index2word,
                                   fr_index2word, base_dir=base_dir,
                                   filename='attention{}.png'.format(i))
            counter += 1

    bleu = corpus_bleu(actual_fr_text, inferred_fr_text)
    print("BLEU score is: {}".format(bleu))