Code example #1
File: eval_skrnn.py — Project: mlelarge/DL-Seq2Seq
def load_pretrained_uncond(data_type):
    """Build the sketch-RNN encoder/decoder and load pretrained weights
    for unconditional generation.

    Parameters
    ----------
    data_type : str
        ``'cat'`` loads the cat checkpoints; any other value loads the
        kanji checkpoints (mirrors the original if/else behaviour).

    Returns
    -------
    tuple
        ``(encoder, decoder, hidden_enc_dim, latent_dim, max_seq_len,
        cond_gen, bi_mode, device)`` — ``device`` is the module-level
        global also used to place the models.

    Relies on module-level names: ``encoder_skrnn``, ``decoder_skrnn``,
    ``get_data``, ``torch`` and ``device``.
    """
    ####################################################
    # default parameters, do not change — must match the training setup
    hidden_enc_dim = 256
    hidden_dec_dim = 256
    n_layers = 1
    num_gaussian = 20
    dropout_p = 0.2
    batch_size = 50
    latent_dim = 64

    rnn_dir = 2  # 1 for unidirection, 2 for bi-direction
    bi_mode = 2  # bidirectional mode: 1 for addition, 2 for concatenation
    cond_gen = False  # this loader is for unconditional generation only

    if not cond_gen:
        # always taken here (cond_gen is hard-coded False); kept for
        # parity with the training script's configuration block
        rnn_dir = 1
        bi_mode = 1
    ####################################################

    encoder = encoder_skrnn(input_size=5, hidden_size=hidden_enc_dim,
                            hidden_dec_size=hidden_dec_dim,
                            dropout_p=dropout_p, n_layers=n_layers,
                            batch_size=batch_size, latent_dim=latent_dim,
                            device=device, cond_gen=cond_gen,
                            bi_mode=bi_mode, rnn_dir=rnn_dir).to(device)

    decoder = decoder_skrnn(input_size=5, hidden_size=hidden_dec_dim,
                            num_gaussian=num_gaussian,
                            dropout_p=dropout_p, n_layers=n_layers,
                            batch_size=batch_size, latent_dim=latent_dim,
                            device=device, cond_gen=cond_gen).to(device)

    # BUG FIX: the original loaded the cat checkpoints with a hard-coded
    # map_location='cuda:0' and the kanji checkpoints with no map_location
    # at all — both crash on CPU-only machines (or when cuda:0 is absent).
    # Remap every checkpoint onto the module-level `device` instead.
    name = 'cat' if data_type == 'cat' else 'kanji'
    encoder.load_state_dict(
        torch.load('saved_model/UncondEnc_' + name + '.pt',
                   map_location=device)['model'])
    decoder.load_state_dict(
        torch.load('saved_model/UncondDec_' + name + '.pt',
                   map_location=device)['model'])
    data_enc, data_dec, max_seq_len = get_data(data_type=name)

    return encoder, decoder, hidden_enc_dim, latent_dim, max_seq_len, cond_gen, bi_mode, device
Code example #2
# --- Training-loop setup: logging cadence, model config, optimizers, data ---
# NOTE(review): this chunk relies on globals defined earlier in the file
# (batch_size, hidden_enc_dim, hidden_dec_dim, dropout_p, n_layers,
# latent_dim, num_gaussian, learning_rate, device, epochs) — confirm they
# are all set before this point.
print_every = batch_size * 200  # print loss after this many iterations; change the multiplier according to dataset
plot_every = 1  # plot the strokes using the currently trained model (every epoch)

rnn_dir = 2  # 1 for unidirection,  2 for bi-direction
bi_mode = 2  # bidirectional mode:- 1 for addition 2 for concatenation
cond_gen = False  # use either unconditional or conditional generation
data_type = 'cat'  # 'cat' and 'kanji'

# Unconditional generation uses a unidirectional encoder and no KL term.
# NOTE(review): weight_kl is only assigned on this branch — if cond_gen is
# True it must be defined elsewhere; verify before flipping the flag.
if not cond_gen:
    weight_kl = 0.0
    rnn_dir = 1
    bi_mode = 1

# Encoder/decoder pair; both placed on the module-level `device`.
encoder = encoder_skrnn(input_size = 5, hidden_size = hidden_enc_dim, hidden_dec_size=hidden_dec_dim,\
                    dropout_p = dropout_p,n_layers = n_layers, batch_size = batch_size, latent_dim = latent_dim,\
                    device = device, cond_gen= cond_gen, bi_mode= bi_mode, rnn_dir = rnn_dir).to(device)

decoder = decoder_skrnn(input_size = 5, hidden_size = hidden_dec_dim, num_gaussian = num_gaussian,\
                        dropout_p = dropout_p, n_layers = n_layers, batch_size = batch_size,\
                        latent_dim = latent_dim, device = device, cond_gen= cond_gen).to(device)

# One Adam optimizer per network, sharing the same learning rate.
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

data_enc, data_dec, max_seq_len = get_data(data_type=data_type)

# Number of samples that fit into whole mini-batches (remainder dropped).
num_mini_batch = len(data_dec) - (len(data_dec) % batch_size)
for epoch in range(epochs):
    start = time.time()