# Example 1
def main(beam_size):
    """Run beam-search decoding over the test set for every emotion tag.

    Loads the shared-embedding seq2seq checkpoint, builds the test
    DataLoader once, then decodes the whole test set once per emotion id.

    Args:
        beam_size: beam width handed to ``BeamSearchDecoder``.
    """
    num_emotions = 9
    word2id, id2word = build_dict()
    word2id, id2word = add_emo_token(word2id, id2word, num_emotions)
    pad_len = 30       # maximum sequence length (tokens)
    batch_size = 2000
    emb_dim = 300
    dim = 600
    vocab_size = len(word2id)

    # Attention-based seq2seq with one embedding table shared between
    # encoder and decoder; hyperparameters must match the checkpoint below.
    model = Seq2SeqAttentionSharedEmbedding(
        emb_dim=emb_dim,
        vocab_size=vocab_size,
        src_hidden_dim=dim,
        trg_hidden_dim=dim,
        ctx_hidden_dim=dim,
        attention_mode='dot',
        batch_size=batch_size,
        bidirectional=False,
        pad_token_src=word2id['<pad>'],
        pad_token_trg=word2id['<pad>'],
        nlayers=2,
        nlayers_trg=2,
        dropout=0.,
    )

    model.cuda()
    model_path = 'checkpoint/new_simple_start_epoch_22.model'
    model.load_state_dict(torch.load(model_path))

    df = pd.read_csv('data_6_remove_dup_test.csv')
    X, y, tag = df['source'], df['target'], df['tag']
    # The loader is built once with emotion id 0; the actual emotion is
    # chosen per pass via translate() in the loop below.
    test_set = EmotionDataLoaderStart(X, y, 0, pad_len, word2id)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    # Fix: iterate num_emotions rather than a second hard-coded 9, so the
    # loop stays in sync if the emotion count ever changes.
    for emo in range(num_emotions):
        decoder = BeamSearchDecoder(model, test_loader, pad_len, beam_size, word2id, id2word)
        decoder.translate(emo)
# Example 2
    # Wrap the train split in the project's dataset class and a DataLoader.
    # NOTE(review): X_train/y_train/tag_train, X_dev/y_dev/tag_dev and the
    # hyperparameters (pad_len, batch_size, emb_dim, dim, vocab_size,
    # word2id, id2word) are defined earlier in the enclosing function,
    # outside this view — confirm against the full file.
    training_set = EmotionDataLoaderFoo(X_train, y_train, tag_train, pad_len, word2id)
    train_loader = DataLoader(training_set, batch_size=batch_size)

    # Dev split wrapped the same way; used here as the evaluation loader.
    test_set = EmotionDataLoaderFoo(X_dev, y_dev, tag_dev, pad_len, word2id)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    # loader = iter(train_loader)
    # next(loader)

    # Attention-based seq2seq with one embedding table shared between
    # encoder and decoder ('<pad>' is used for padding on both sides).
    model = Seq2SeqAttentionSharedEmbedding(
        emb_dim=emb_dim,
        vocab_size=vocab_size,
        src_hidden_dim=dim,
        trg_hidden_dim=dim,
        ctx_hidden_dim=dim,
        attention_mode='dot',
        batch_size=batch_size,
        bidirectional=False,
        pad_token_src=word2id['<pad>'],
        pad_token_trg=word2id['<pad>'],
        nlayers=2,
        nlayers_trg=2,
        dropout=0.,
    )
    # Initialize the embedding layer from vectors keyed by id2word,
    # then move the model to GPU.
    model.load_word_embedding(id2word)
    model.cuda()

    # Commented-out warm-start path: resume from an earlier checkpoint.
    # model_path = 'checkpoint/new_simple_foo_epoch_9.model'
    # model.load_state_dict(torch.load(
    #     model_path
    # ))