Ejemplo n.º 1
0
def main():
    """Train a one-layer character-level Seq2Seq model, then decode sample prompts.

    The first 128 pairs produced by ``preprocess_data`` are held out as
    validation data; every remaining pair is used for training.
    """
    batch_size = 128
    (source_ids, target_ids,
     source_char2idx, target_char2idx,
     source_idx2char, target_idx2char) = preprocess_data()

    # Exactly one batch worth of pairs is reserved for validation.
    train_pairs = (source_ids[batch_size:], target_ids[batch_size:])
    holdout_pairs = (source_ids[:batch_size], target_ids[:batch_size])

    seq2seq = Seq2Seq(
        rnn_size=50,
        n_layers=1,
        X_word2idx=source_char2idx,
        encoder_embedding_dim=128,
        Y_word2idx=target_char2idx,
        decoder_embedding_dim=128,
    )
    seq2seq.fit(
        *train_pairs,
        val_data=holdout_pairs,
        batch_size=batch_size,
        n_epoch=10,
        display_step=32,
    )

    for prompt in ('你好', '晚上吃什么', '我没什么意见'):
        seq2seq.infer(prompt, source_idx2char, target_idx2char)
Ejemplo n.º 2
0
def main(args):
    """Train (or restore) a word-level Seq2Seq model and translate two sample sentences.

    Args:
        args: parsed command-line namespace. Attributes read here are
            ``batch_size``, ``input``, ``validation``, ``vocabulary``,
            ``hidden_size``, ``n_layers``, ``model_path``, ``predict`` and
            ``n_epoch``.

    When ``args.predict`` is false, the graph is built and trained on the
    training split and validated on the validation split. Otherwise a
    previously saved graph is restored. In both cases two fixed sentences are
    then translated.
    """
    batch_size = args.batch_size
    # print(...) with a single argument prints the same text on Python 2 and is
    # the only valid form on Python 3 (``print x`` is a SyntaxError there).
    print(args)

    # TODO(review): when predicting, consider validating that a model path was
    # supplied (a disabled check here once referenced ``args.model_dir``).

    x_train, y_train, x_valid, y_valid, x_word2idx, y_word2idx, x_idx2word, y_idx2word = \
        preprocess_data(args.input, args.validation, args.vocabulary)

    # The encoder/decoder embedding size is tied to the RNN hidden size.
    model = Seq2Seq(rnn_size=args.hidden_size,
                    n_layers=args.n_layers,
                    x_word2idx=x_word2idx,
                    encoder_embedding_dim=args.hidden_size,
                    y_word2idx=y_word2idx,
                    decoder_embedding_dim=args.hidden_size,
                    model_path=args.model_path)

    if not args.predict:
        model.build_graph()
        print('Training ...')
        model.fit(x_train,
                  y_train,
                  val_data=(x_valid, y_valid),
                  batch_size=batch_size,
                  n_epoch=args.n_epoch)
    else:
        print('Loading pre-trained model ...')
        model.restore_graph()

    print('Translating ...')
    model.infer_sentence(u'我 的 青蛙 叫 呱呱 !', x_idx2word, y_idx2word)
    model.infer_sentence(u'我 非常 期待 它 带 礼物 回来 !', x_idx2word, y_idx2word)
Ejemplo n.º 3
0
def main():
    """Train a Seq2Seq model configured by ``model_config.Config`` and save it.

    The first ``config.batch_size`` pairs are held out for validation; the
    rest are used for training. The trained model is written to
    ``config.save_path``.
    """
    config = model_config.Config()
    holdout = config.batch_size

    # The inverse vocabularies are not needed for training, but unpacking them
    # still checks the structure returned by load_preprocess.
    (source_ids, target_ids), (source_vocab, target_vocab), (_, _) = \
        data_utils.load_preprocess(config.source_path, config.target_path)

    train_source, train_target = source_ids[holdout:], target_ids[holdout:]
    valid_source, valid_target = source_ids[:holdout], target_ids[:holdout]

    model = Seq2Seq(
        rnn_size=config.num_units,
        n_layers=config.num_layers,
        X_word2idx=source_vocab,
        encoder_embedding_dim=config.encoding_embedding_size,
        Y_word2idx=target_vocab,
        decoder_embedding_dim=config.decoding_embedding_size,
    )

    model.fit(
        train_source,
        train_target,
        val_data=(valid_source, valid_target),
        batch_size=holdout,
        n_epoch=config.num_epochs,
        display_step=config.display_step,
    )

    model.save_model(config.save_path)
Ejemplo n.º 4
0
def main():
    """Train a two-layer Seq2Seq model on all preprocessed pairs, then decode sample words."""
    (src_ids, tgt_ids,
     src_char2idx, tgt_char2idx,
     src_idx2char, tgt_idx2char) = preprocess_data()

    model = Seq2Seq(rnn_size=50, n_layers=2,
                    X_word2idx=src_char2idx, Y_word2idx=tgt_char2idx)
    model.fit(src_ids, tgt_ids)

    for word in ['common', 'apple', 'zhedong']:
        model.infer(word, src_idx2char, tgt_idx2char)
Ejemplo n.º 5
0
def main():
    """Load a trained Seq2Seq model from ``config.save_path`` and decode sample prompts.

    NOTE(review): ``config`` is not defined in this function; it is presumably a
    module-level configuration object providing the paths and sizes read
    here. Confirm.
    """
    # Only the vocabularies are needed for inference; the index sequences are discarded.
    _, (src_char2idx, tgt_char2idx), (src_idx2char, tgt_idx2char) = \
        data_utils.load_preprocess(config.source_path, config.target_path)

    restored = Seq2Seq(
        rnn_size=config.num_units,
        n_layers=config.num_layers,
        X_word2idx=src_char2idx,
        encoder_embedding_dim=config.encoding_embedding_size,
        Y_word2idx=tgt_char2idx,
        decoder_embedding_dim=config.decoding_embedding_size,
        load_path=config.save_path,
    )

    for prompt in ('今朝有酒今朝醉', '雪消狮子瘦', '生员里长,打里长不打生员。'):
        restored.infer(prompt, src_idx2char, tgt_idx2char)
Ejemplo n.º 6
0
def main():
    """Train a two-layer Seq2Seq model whose batch size is fixed at construction time.

    The first 128 pairs are held out for validation and the rest are used for
    training. A few sample words are decoded afterwards.
    """
    holdout_size = 128
    (src_ids, tgt_ids,
     src_char2idx, tgt_char2idx,
     src_idx2char, tgt_idx2char) = preprocess_data()

    train_src, valid_src = src_ids[holdout_size:], src_ids[:holdout_size]
    train_tgt, valid_tgt = tgt_ids[holdout_size:], tgt_ids[:holdout_size]

    model = Seq2Seq(
        rnn_size=50,
        n_layers=2,
        X_word2idx=src_char2idx,
        encoder_embedding_dim=15,
        Y_word2idx=tgt_char2idx,
        decoder_embedding_dim=15,
        batch_size=holdout_size,
    )
    model.fit(train_src, train_tgt, val_data=(valid_src, valid_tgt))

    for word in ('common', 'apple', 'zhedong'):
        model.infer(word, src_idx2char, tgt_idx2char)
Ejemplo n.º 7
0
def main():
    """Train a two-layer Seq2Seq model (batch size given to ``fit``), then decode sample words.

    One batch (the first 128 pairs) is held out for validation; the remaining
    pairs form the training set.
    """
    n_holdout = 128
    (source_seqs, target_seqs,
     source_char2idx, target_char2idx,
     source_idx2char, target_idx2char) = preprocess_data()

    training = (source_seqs[n_holdout:], target_seqs[n_holdout:])
    validation = (source_seqs[:n_holdout], target_seqs[:n_holdout])

    model = Seq2Seq(
        rnn_size=50,
        n_layers=2,
        x_word2idx=source_char2idx,
        encoder_embedding_dim=15,
        y_word2idx=target_char2idx,
        decoder_embedding_dim=15,
    )
    model.fit(*training, val_data=validation, batch_size=n_holdout)

    for word in ['common', 'apple', 'zhedong']:
        model.infer(word, source_idx2char, target_idx2char)
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch

from seq2seq_attn import Seq2Seq
from utils import translate_sentence  #,calculate_bleu
from data import getData

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SRC, TRG, train_data, valid_data, test_data = getData(False)

src_vocab_size = len(SRC)
trg_vocab_size = len(TRG)
SRC_PAD_IDX = SRC.stoi[SRC.pad_token]
# TRG_EOS_TOKEN is the target-side end-of-sequence index, so it must be looked
# up in the target vocabulary. It was previously read from SRC, which is only
# correct if both vocabularies happen to index their EOS token identically.
TRG_EOS_TOKEN = TRG.stoi[TRG.eos_token]

model = Seq2Seq(SRC_PAD_IDX, src_vocab_size, trg_vocab_size, device,
                TRG_EOS_TOKEN).to(device)

# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on a GPU can still be loaded when only the CPU is available.
model.load_state_dict(torch.load('tut4-model.pt', map_location=device))

src = "ein pferd geht unter einer brücke neben einem boot ."
translation, attention = translate_sentence(model, src, SRC, TRG, device)
print(src)
print(translation)


#exit()
def display_attention(sentence, translation, attention):

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)