def test(opt):
    """Evaluate saved encoder/decoder/post-filter weights on test prompts.

    The run name is assembled from the hyper-parameters in ``opt``; the
    vocabulary is rebuilt from the train + validation prompt lists; the
    three networks are restored from
    ``../../wt/<full_model_name>_{enc,dec,pfnet}.pth`` and ``evaluate`` is
    called twice on the first 10 entries of ``test_list.txt`` (once with
    ``tf=True`` and once with ``tf=False``).  The two losses returned by
    each call are printed.

    NOTE(review): a second ``test(opt)`` is defined later in this file and
    rebinds the name, so this version is not the one reached through
    ``test`` after the module has been imported.

    Args:
        opt: options object.  Reads ``model_name``, ``hs1``, ``hs2``,
            ``pfnet_hs1``, ``r``, ``lr``, ``b1``, ``b2``, ``dp``, ``gcth``,
            ``wtinit_meth``, ``load_wts``, ``embedding_flag``,
            ``residual_flag``, ``synth_folder``, ``plot_folder``,
            ``feats_dir``, ``stats_dir`` and ``pfnet_stats_dir``.
            ``opt.full_model_name`` is set as a side effect.

    Raises:
        ValueError: if ``opt.residual_flag`` is set and ``opt.r`` is not
            one of 2, 3, 4, 5 (no matching decoder class is referenced).
        OSError: if an output directory is missing and cannot be created.
    """

    def _read_lines(path):
        # Context manager guarantees the handle is closed even on error.
        with open(path) as fid:
            return fid.read().splitlines()

    full_model_name = (
        opt.model_name
        + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2)
        + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r)
        + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2)
        + '_dp' + str(opt.dp)
        + '_gc' + str(opt.gcth)
        + '_wtinit' + str(opt.wtinit_meth)
        + '_lw' + str(int(opt.load_wts))
        + '_ef' + str(int(opt.embedding_flag))
        + '_rf' + str(int(opt.residual_flag)))
    print(full_model_name)
    opt.full_model_name = full_model_name

    # Create each output directory independently: with a single try-block
    # an already existing synth folder would skip creating the plot folder.
    for out_dir in (opt.synth_folder + opt.full_model_name,
                    opt.plot_folder + opt.full_model_name):
        try:
            os.makedirs(out_dir)
        except OSError:
            # "Already exists" is fine; anything else is a real failure.
            if not os.path.isdir(out_dir):
                raise

    train_list = _read_lines(opt.feats_dir + 'train_list.txt')
    val_list = _read_lines(opt.feats_dir + 'val_list.txt')

    # The vocabulary is derived from train + validation prompts.
    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    phn2id, id2phn = phn2id2phn(vocab)
    print(vocab)

    # Only the first 10 test utterances are evaluated.
    test_list = _read_lines(opt.feats_dir + 'test_list.txt')[:10]

    # Stats for the decoder targets (original comment: "stats of mfcc").
    mo1 = np.load(opt.stats_dir + 'mo.npy').astype('float32')
    so1 = np.load(opt.stats_dir + 'so.npy').astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Stats for the post-filter targets (original comment: "stats of
    # spectrum").
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy').astype('float32')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy').astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    test_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                        test_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderCBL(vocab_size, opt.hs2, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        else:
            # Previously this left ``decoder`` unbound and failed later.
            raise ValueError(
                'residual decoder requires opt.r in (2, 3, 4, 5), got %r'
                % (opt.r,))
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)
    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet

    criterion = torch.nn.L1Loss(size_average=False)

    # Restore the three state dicts; map_location keeps the tensors on the
    # storage they were deserialized to (CPU) before load_state_dict.
    load_model_name_pfx = '../../wt/' + opt.full_model_name + '_'
    load_model_name_sfx = '.pth'
    for tag, net in (('enc', encoder), ('dec', decoder), ('pfnet', pfnet)):
        state_dict = torch.load(
            load_model_name_pfx + tag + load_model_name_sfx,
            map_location=lambda storage, loc: storage)
        net.load_state_dict(state_dict)

    # tf=True: teacher forcing.
    avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), test_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, True, opt,
        mo1, so1, nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))

    # tf=False: teacher forcing disabled (original comment: "professor
    # forcing").
    avg_val_loss_pf1, avg_val_loss_pf2, decoder_attentions_pf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), test_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, False, opt,
        mo1, so1, nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_pf1, avg_val_loss_pf2))
def train(opt):
    """Train the encoder, attention decoder and post-filter network.

    For every training prompt the encoder/decoder pair is updated with a
    teacher-forced summed L1 loss on the ``/fb/`` targets; the decoder
    outputs are then de-interleaved into single frames and the post-filter
    network is updated with a summed L1 loss on the ``/sp/`` targets.
    Every ``opt.print_every`` samples the running loss is printed, the
    model is evaluated on the validation prompts with and without teacher
    forcing, the result is logged, and the three state dicts are saved to
    ``opt.model_folder`` when the teacher-forced total validation loss
    improves.

    TODO (from the original author): the data could be loaded into a
    dictionary with "train"/"val"/"test" pointers.

    NOTE(review): the indentation of this function was reconstructed.  The
    evaluation/checkpoint section is assumed to live inside the
    ``print_every`` branch and ``gc.collect()`` is assumed to run once per
    training sample -- confirm against the upstream file.

    Args:
        opt: options object.  Reads ``print_every``, ``model_name``,
            ``hs1``, ``hs2``, ``pfnet_hs1``, ``r``, ``lr``, ``b1``, ``b2``,
            ``dp``, ``gcth``, ``wtinit_meth``, ``load_wts``,
            ``embedding_flag``, ``residual_flag``, ``log_folder``,
            ``feats_dir``, ``stats_dir``, ``pfnet_stats_dir``, ``niter``
            and ``model_folder``.  ``opt.lr`` is divided by 10 in place at
            the start of the third epoch.

    Raises:
        ValueError: if ``opt.residual_flag`` is set and ``opt.r`` is not
            one of 2, 3, 4, 5 (no matching decoder class is referenced).
    """

    def _read_lines(path):
        # Context manager guarantees the handle is closed even on error.
        with open(path) as fid:
            return fid.read().splitlines()

    print_every = opt.print_every

    full_model_name = (
        opt.model_name
        + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2)
        + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r)
        + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2)
        + '_dp' + str(opt.dp)
        + '_gc' + str(opt.gcth)
        + '_wtinit' + str(opt.wtinit_meth)
        + '_lw' + str(int(opt.load_wts))
        + '_ef' + str(int(opt.embedding_flag))
        + '_rf' + str(int(opt.residual_flag)))
    print(full_model_name)

    logging.basicConfig(
        filename=opt.log_folder + full_model_name + '.log',
        filemode='w',
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s')

    r = opt.r

    train_list = _read_lines(opt.feats_dir + 'train_list.txt')
    val_list = _read_lines(opt.feats_dir + 'val_list.txt')

    # The vocabulary is derived from train + validation prompts.
    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    print(vocab)

    # Load training data
    train_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                         train_list)
    phn2id, id2phn = phn2id2phn(vocab)
    print(len(train_prompts), len(train_list))

    # Stats for the decoder targets (original comment: "stats of mfcc").
    mo1 = np.load(opt.stats_dir + 'mo.npy').astype('float32')
    so1 = np.load(opt.stats_dir + 'so.npy').astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Stats for the post-filter targets (original comment: "stats of
    # spectrum").
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy').astype('float32')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy').astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderCBL(vocab_size, opt.hs2, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        else:
            # Previously this left ``decoder`` unbound and failed later.
            raise ValueError(
                'residual decoder requires opt.r in (2, 3, 4, 5), got %r'
                % (opt.r,))
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)
    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet

    criterion = torch.nn.L1Loss(size_average=False)

    if opt.load_wts:
        # Warm start from a fixed, hard-coded checkpoint family that only
        # varies with opt.r.
        load_model_name_pfx = (
            '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss_'
            '_hs1250_hs2500_pfnet_hs1250_r'
            + str(opt.r)
            + '_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init'
            '_lw0_ef1_rf1_')
        load_model_name_sfx = '.pth'
        for tag, net in (('enc', encoder), ('dec', decoder),
                         ('pfnet', pfnet)):
            state_dict = torch.load(
                load_model_name_pfx + tag + load_model_name_sfx,
                map_location=lambda storage, loc: storage)
            net.load_state_dict(state_dict)

    def _new_optimizers():
        # One Adam optimizer per network, using the *current* opt.lr.
        return tuple(
            optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
            for net in (encoder, decoder, pfnet))

    encoder_optimizer, decoder_optimizer, pfnet_optimizer = _new_optimizers()

    start = time.time()
    print_loss_total = 0  # Reset every print_every
    best_val_loss = sys.maxsize

    for epoch in range(1, opt.niter + 1):
        if epoch == 3:
            # One-off learning-rate drop; the optimizers are rebuilt, which
            # also discards their accumulated Adam state.
            opt.lr = opt.lr / 10
            encoder_optimizer, decoder_optimizer, pfnet_optimizer = \
                _new_optimizers()

        for j, k in enumerate(train_prompts):
            input_variable, _ = get_x(train_prompts, k, phn2id, use_cuda)
            train_targets, train_seq_len = load_targets(
                opt.feats_dir + '/fb/', [k], '.npy', dtype,
                mo1, so1, nml_vec1)
            target_variable, target_variable2, target_length = get_y(
                train_seq_len, 0, train_targets, use_cuda, r)

            loss = 0
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            encoder_h0 = encoder.initHidden()
            encoder_c0 = encoder.initHidden()
            encoder_output, (encoder_hn, encoder_cn) = encoder(
                input_variable, (encoder_h0, encoder_c0))
            encoder_outputs = encoder_output.squeeze(1)

            decoder_input = Variable(torch.zeros(1, op_dim))  # all-zero frame
            decoder_input = decoder_input.cuda() if use_cuda else decoder_input

            decoder_h1 = decoder.initHidden()
            decoder_c1 = decoder.initHidden()
            decoder_h2 = decoder.initHidden()
            decoder_c2 = decoder.initHidden()
            decoder_h3 = decoder.initHidden()
            decoder_c3 = decoder.initHidden()

            # One row per decoder step, holding r * op_dim values.
            decoder_output_half = Variable(
                torch.zeros(target_length, r * op_dim))
            # One row per frame (r frames per decoder step).
            decoder_output_full = Variable(
                torch.zeros(r * target_length, op_dim))
            if use_cuda:
                decoder_output_half = decoder_output_half.cuda()
                decoder_output_full = decoder_output_full.cuda()

            # Teacher forcing: feed the target as the next input
            for di in range(target_length):
                (decoder_output1, decoder_output2,
                 decoder_h1, decoder_c1,
                 decoder_h2, decoder_c2,
                 decoder_h3, decoder_c3,
                 decoder_attention) = decoder(
                    decoder_input,
                    decoder_h1, decoder_c1,
                    decoder_h2, decoder_c2,
                    decoder_h3, decoder_c3,
                    encoder_outputs)
                loss += criterion(decoder_output1, target_variable[di])
                decoder_input = target_variable2[di].unsqueeze(0)
                decoder_output_half[di] = decoder_output1

            # The graph is retained because the post-filter loss below is
            # back-propagated through the same decoder outputs.
            loss.backward(retain_graph=True)
            encoder_optimizer.step()
            decoder_optimizer.step()
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # Start Post-Filtering Net
            pfnet_optimizer.zero_grad()
            # Spread the r sub-frames of each decoder step over consecutive
            # rows: row (step * r + ix) <- columns [ix*op_dim, (ix+1)*op_dim).
            for ix in range(r):
                decoder_output_full[ix::r, :] = \
                    decoder_output_half[:, ix * op_dim:(ix + 1) * op_dim]

            train_targets_pfnet, train_seq_len_pfnet = load_targets(
                opt.feats_dir + '/sp/', [k], '.npy', dtype,
                mo2, so2, nml_vec2)
            targets_pfnet = Variable(train_targets_pfnet)
            if use_cuda:
                targets_pfnet = targets_pfnet.cuda()
            # Drop trailing frames so the target length is a multiple of r.
            s2 = targets_pfnet.size()[0]
            if (s2 % r) > 0:
                targets_pfnet = targets_pfnet[:-(s2 % r), :]

            pfnet_h0 = pfnet.initHidden()
            pfnet_c0 = pfnet.initHidden()
            pfnet_outputs = pfnet(decoder_output_full, (pfnet_h0, pfnet_c0))

            loss_pfnet = criterion(pfnet_outputs, targets_pfnet)
            loss_pfnet.backward()
            pfnet_optimizer.step()

            loss_total = loss + loss_pfnet
            # Fixed precedence: the original ``/ r * target_length``
            # multiplied by target_length instead of dividing by the number
            # of frames.  Only the printed running loss is affected.
            print_loss_total += loss_total.data[0] / (r * target_length)

            if (j + 1) % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                n_updates = (epoch * len(train_prompts)
                             - len(train_prompts) + j)
                print('%s (%d %d%%) %.4f' % (
                    timeSince(
                        start,
                        n_updates / ((opt.niter + 1) * len(train_prompts))),
                    epoch, epoch / opt.niter * 100, print_loss_avg))

                # tf=True: teacher forcing.
                (avg_total_val_loss_tf, avg_dec_val_loss_tf,
                 decoder_attentions_tf) = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(),
                    val_prompts, phn2id, id2phn, vocab_size, use_cuda,
                    criterion, op_dim, True, opt,
                    mo1, so1, nml_vec1, mo2, so2, nml_vec2)
                print('%d %0.4f %0.4f' % (epoch, avg_total_val_loss_tf,
                                          avg_dec_val_loss_tf))

                # tf=False: teacher forcing disabled (original comment:
                # "always sampling").
                (avg_total_val_loss_as, avg_dec_val_loss_as,
                 decoder_attentions_pf) = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(),
                    val_prompts, phn2id, id2phn, vocab_size, use_cuda,
                    criterion, op_dim, False, opt,
                    mo1, so1, nml_vec1, mo2, so2, nml_vec2)
                print('%d %0.4f %0.4f' % (epoch, avg_total_val_loss_as,
                                          avg_dec_val_loss_as))

                logging.debug(
                    'Epoch: ' + str(epoch)
                    + ' Update: ' + str(n_updates)
                    + ' Avg Total Val Loss TF: ' + str(avg_total_val_loss_tf)
                    + ' Avg Total Val Loss AS: ' + str(avg_total_val_loss_as)
                    + ' Avg Dec Val Loss TF: ' + str(avg_dec_val_loss_tf)
                    + ' Avg Dec Val Loss AS: ' + str(avg_dec_val_loss_as))

                # Checkpoint only when the teacher-forced total validation
                # loss improves.
                if avg_total_val_loss_tf < best_val_loss:
                    best_val_loss = avg_total_val_loss_tf
                    for tag, net in (('enc', encoder), ('dec', decoder),
                                     ('pfnet', pfnet)):
                        torch.save(
                            net.state_dict(),
                            '%s/%s_%s.pth' % (opt.model_folder,
                                              full_model_name, tag))

                # evaluate() received the modules in eval mode; switch back.
                encoder.train()
                decoder.train()
                pfnet.train()

            gc.collect()
def test(opt):
    """Evaluate the model on the first 10 test prompts with teacher forcing.

    The run name is assembled from the hyper-parameters in ``opt`` and
    stored on ``opt.full_model_name``; the vocabulary is rebuilt from the
    train + validation prompt lists.  When ``opt.load_wts`` is true the
    three networks are restored from a fixed, hard-coded checkpoint
    (``..._r3_lr3e-05_..._{enc,dec,pfnet}_epoch_2999_5.pth`` under
    ``../../wt/``); otherwise they are evaluated as constructed.
    ``evaluate`` is called once with ``tf=True`` and the two losses it
    returns are printed.

    Args:
        opt: options object.  Reads ``model_name``, ``hs1``, ``hs2``,
            ``pfnet_hs1``, ``r``, ``lr``, ``b1``, ``b2``, ``dp``, ``gcth``,
            ``wtinit_meth``, ``load_wts``, ``embedding_flag``,
            ``residual_flag``, ``feats_dir``, ``stats_dir`` and
            ``pfnet_stats_dir``.

    Raises:
        ValueError: if ``opt.residual_flag`` is set and ``opt.r`` is not
            one of 2, 3, 4, 5 (no matching decoder class is referenced).
    """

    def _read_lines(path):
        # Context manager guarantees the handle is closed even on error.
        with open(path) as fid:
            return fid.read().splitlines()

    full_model_name = (
        opt.model_name
        + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2)
        + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r)
        + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2)
        + '_dp' + str(opt.dp)
        + '_gc' + str(opt.gcth)
        + '_wtinit' + str(opt.wtinit_meth)
        + '_lw' + str(int(opt.load_wts))
        + '_ef' + str(int(opt.embedding_flag))
        + '_rf' + str(int(opt.residual_flag)))
    print(full_model_name)
    opt.full_model_name = full_model_name

    train_list = _read_lines(opt.feats_dir + 'train_list.txt')
    val_list = _read_lines(opt.feats_dir + 'val_list.txt')

    # The vocabulary is derived from train + validation prompts.
    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    print(vocab)

    # Only the first 10 test utterances are evaluated.
    test_list = _read_lines(opt.feats_dir + 'test_list.txt')[:10]

    # The training prompts are loaded only to report their count.
    train_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                         train_list)
    phn2id, id2phn = phn2id2phn(vocab)
    print(len(train_prompts), len(train_list))

    # NOTE(review): the stats files read below must already exist.  The
    # commented-out calls that used to sit here suggest they are produced
    # by save_stats_suffstats(opt.feats_dir + '/fb/' or '/sp/', file_list,
    # '.npy', dtype, <stats dir>) -- confirm before relying on it.

    # Stats for the decoder targets (original comment: "stats of mfcc").
    mo1 = np.load(opt.stats_dir + 'mo.npy').astype('float32')
    so1 = np.load(opt.stats_dir + 'so.npy').astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Stats for the post-filter targets (original comment: "stats of
    # spectrum").
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy').astype('float32')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy').astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    test_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                        test_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        elif opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        else:
            # Previously this left ``decoder`` unbound and failed later.
            raise ValueError(
                'residual decoder requires opt.r in (2, 3, 4, 5), got %r'
                % (opt.r,))
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)
    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet

    criterion = torch.nn.L1Loss(size_average=False)

    if opt.load_wts:
        # Hard-coded checkpoint; it does not depend on opt.full_model_name.
        load_model_name_pfx = (
            '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss_'
            '_hs1250_hs2500_pfnet_hs1250_r3_lr3e-05_b10.9_b20.99_dp0.5'
            '_gc0.0_wtinitdefault_init_lw1_ef0_rf1_')
        load_model_name_sfx = '_epoch_2999_5.pth'
        for tag, net in (('enc', encoder), ('dec', decoder),
                         ('pfnet', pfnet)):
            state_dict = torch.load(
                load_model_name_pfx + tag + load_model_name_sfx,
                map_location=lambda storage, loc: storage)
            net.load_state_dict(state_dict)

    # tf=True: teacher forcing.
    avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), test_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, True, opt,
        mo1, so1, nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))