def load_pretrained_uncond(data_type):
    """Build the unconditional sketch-RNN encoder/decoder and load pretrained weights.

    Parameters
    ----------
    data_type : str
        Dataset identifier: 'cat' loads the cat checkpoints; any other value
        falls back to the kanji checkpoints (matching the original behavior).

    Returns
    -------
    tuple
        (encoder, decoder, hidden_enc_dim, latent_dim, max_seq_len,
         cond_gen, bi_mode, device)

    Notes
    -----
    Relies on module-level names defined elsewhere in this file:
    `encoder_skrnn`, `decoder_skrnn`, `get_data`, and `device`.
    """
    ####################################################
    # default parameters, do not change — they must match the architecture
    # the checkpoints in saved_model/ were trained with
    hidden_enc_dim = 256
    hidden_dec_dim = 256
    n_layers = 1
    num_gaussian = 20
    dropout_p = 0.2
    batch_size = 50
    latent_dim = 64
    rnn_dir = 2   # 1 for unidirection, 2 for bi-direction
    bi_mode = 2   # bidirectional mode:- 1 for addition 2 for concatenation
    cond_gen = False
    if not cond_gen:
        # unconditional generation uses a unidirectional encoder
        rnn_dir = 1
        bi_mode = 1
    ####################################################
    encoder = encoder_skrnn(input_size=5, hidden_size=hidden_enc_dim,
                            hidden_dec_size=hidden_dec_dim,
                            dropout_p=dropout_p, n_layers=n_layers,
                            batch_size=batch_size, latent_dim=latent_dim,
                            device=device, cond_gen=cond_gen,
                            bi_mode=bi_mode, rnn_dir=rnn_dir).to(device)

    decoder = decoder_skrnn(input_size=5, hidden_size=hidden_dec_dim,
                            num_gaussian=num_gaussian,
                            dropout_p=dropout_p, n_layers=n_layers,
                            batch_size=batch_size, latent_dim=latent_dim,
                            device=device, cond_gen=cond_gen).to(device)

    # BUGFIX: map_location was hard-coded to 'cuda:0' for the 'cat' branch
    # (breaking CPU-only hosts) and missing entirely for the kanji branch
    # (breaking CPU loads of GPU-saved checkpoints). Always remap onto the
    # module-level `device` instead.
    suffix = 'cat' if data_type == 'cat' else 'kanji'
    encoder.load_state_dict(
        torch.load('saved_model/UncondEnc_' + suffix + '.pt',
                   map_location=device)['model'])
    decoder.load_state_dict(
        torch.load('saved_model/UncondDec_' + suffix + '.pt',
                   map_location=device)['model'])
    data_enc, data_dec, max_seq_len = get_data(data_type=suffix)

    return encoder, decoder, hidden_enc_dim, latent_dim, max_seq_len, \
        cond_gen, bi_mode, device
print_every = batch_size * 200 # print loss after this much iteration, change the multiplier aacording to dataset plot_every = 1 # plot the strokes using current trained model rnn_dir = 2 # 1 for unidirection, 2 for bi-direction bi_mode = 2 # bidirectional mode:- 1 for addition 2 for concatenation cond_gen = False # use either unconditional or conditional generation data_type = 'cat' # 'cat' and 'kanji' if not cond_gen: weight_kl = 0.0 rnn_dir = 1 bi_mode = 1 encoder = encoder_skrnn(input_size = 5, hidden_size = hidden_enc_dim, hidden_dec_size=hidden_dec_dim,\ dropout_p = dropout_p,n_layers = n_layers, batch_size = batch_size, latent_dim = latent_dim,\ device = device, cond_gen= cond_gen, bi_mode= bi_mode, rnn_dir = rnn_dir).to(device) decoder = decoder_skrnn(input_size = 5, hidden_size = hidden_dec_dim, num_gaussian = num_gaussian,\ dropout_p = dropout_p, n_layers = n_layers, batch_size = batch_size,\ latent_dim = latent_dim, device = device, cond_gen= cond_gen).to(device) encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate) decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate) data_enc, data_dec, max_seq_len = get_data(data_type=data_type) num_mini_batch = len(data_dec) - (len(data_dec) % batch_size) for epoch in range(epochs): start = time.time()