import torch
from torch import optim


def main(hps):
    # pre_train
    if args.pre_train:
        word2idx, idx2word, verb2idx, idx2verb = make_vocab(hps)
        # word2idx, idx2word, verb2idx, idx2verb = load_vocab(hps)
        # mapping, vectors = load_glove(hps)
        # weights_matrix = make_pre_trained_word_embedding(mapping, vectors, word2idx.keys(), hps)
        hps = hps._replace(vocab_size=len(word2idx))
        hps = hps._replace(verb_vocab_size=len(verb2idx))
        hps = hps._replace(pre_train=True)
        print('parameters:')
        print(hps)

        train_loader, valid_loader, char_weights, action_weights = load_train_data(
            word2idx, verb2idx, hps)
        model = Model(char_weights, action_weights, hps)
        # model.load_state_dict(torch.load(hps.test_path + 'models/pre_train.model'))
        model = model.cuda()
        optimizer = optim.Adam(model.parameters(), lr=hps.lr,
                               weight_decay=hps.weight_decay)
        print('pre_training', flush=True)
        pre_train(model, optimizer, train_loader, valid_loader, idx2word, hps)

    # train
    elif args.train:
        # word2idx, idx2word = make_vocab(hps)
        word2idx, idx2word = load_vocab(hps)
        hps = hps._replace(vocab_size=len(word2idx))
        print('parameters:')
        print(hps)

        train_loader, valid_loader, char_weights, action_weights = load_train_data(
            word2idx, hps)
        model = Model(char_weights, action_weights, hps)
        # model.load_state_dict(torch.load(hps.test_path + 'models/best.model'))
        if args.reload:
            model.load_state_dict(
                torch.load(hps.test_path + hps.save_path.format(args.reload_epoch)))
        model.cuda()
        # model = nn.DataParallel(model, device_ids=[0])
        optimizer = optim.Adam(model.parameters(), lr=hps.lr,
                               weight_decay=hps.weight_decay)
        print('training', flush=True)
        train(model, optimizer, train_loader, valid_loader, idx2word, hps)

    # test
    if args.test:
        print('testing', flush=True)
        word2idx, idx2word = load_vocab(hps)
        hps = hps._replace(vocab_size=len(word2idx))
        hps = hps._replace(test=True)
        # dummy weights; the real parameters are restored from the checkpoint below
        model = Model([0] * hps.max_num_char, [0] * hps.vocab_size, hps)
        model.load_state_dict(torch.load(hps.test_path + hps.save_path))
        model.cuda()
        test_loader, anony2names = load_test_data(word2idx, hps)
        test(model, test_loader, idx2word, anony2names, hps)
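The repeated `_replace` calls above imply that `hps` is an immutable `collections.namedtuple` of hyperparameters, which is why the result must be reassigned each time. A minimal sketch of that pattern; the field names and values here are illustrative guesses drawn from the usages above, not the project's actual definition:

from collections import namedtuple

# Hypothetical hyperparameter container; `_replace` is the standard
# namedtuple method that returns a modified *copy* of the tuple.
HParams = namedtuple('HParams', [
    'vocab_size', 'verb_vocab_size', 'lr', 'weight_decay',
    'pre_train', 'test', 'test_path', 'save_path', 'max_num_char',
])

hps = HParams(vocab_size=0, verb_vocab_size=0, lr=1e-3, weight_decay=1e-5,
              pre_train=False, test=False, test_path='out/',
              save_path='models/epoch_{}.model', max_num_char=10)
hps = hps._replace(vocab_size=30000)  # returns a new tuple; hps itself is immutable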
def test_make_vocab(self):
    tokens = [
        'Rock', 'n', 'Roll', 'is', 'a', 'risk', '.',
        'You', 'risk', 'being', 'ridiculed', '.',
    ]
    token_to_index, index_to_token = make_vocab(tokens, 1, 10)
    self.assertEqual(token_to_index['<pad>'], 0)
    self.assertEqual(token_to_index['<unk>'], 1)
    self.assertEqual(token_to_index['<s>'], 2)
    self.assertEqual(token_to_index['</s>'], 3)
    self.assertEqual(len(token_to_index), 10)
    self.assertEqual(len(index_to_token), 10)
    self.assertEqual(index_to_token[0], '<pad>')
    self.assertEqual(index_to_token[1], '<unk>')
    self.assertEqual(index_to_token[2], '<s>')
    self.assertEqual(index_to_token[3], '</s>')
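For reference, here is a minimal `make_vocab(tokens, min_count, max_size)` that satisfies the assertions above. It is a sketch under the assumptions the test encodes (special tokens at indices 0-3, remaining slots filled by frequency, vocabulary truncated to `max_size`), not necessarily the project's implementation:

from collections import Counter

def make_vocab(tokens, min_count, max_size):
    # special tokens occupy the first four indices, as the test asserts
    specials = ['<pad>', '<unk>', '<s>', '</s>']
    counts = Counter(tokens)
    # keep tokens seen at least min_count times, most frequent first
    words = [w for w, c in counts.most_common() if c >= min_count]
    vocab = (specials + words)[:max_size]
    token_to_index = {w: i for i, w in enumerate(vocab)}
    index_to_token = {i: w for i, w in enumerate(vocab)}
    return token_to_index, index_to_token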
import os
import sys
import json
import pickle as pkl
from os.path import join, exists
from collections import defaultdict

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


def main(args):
    if not exists(join(__SAVE_PATH, args.dir)):
        os.makedirs(join(__SAVE_PATH, args.dir))
        os.makedirs(join(__SAVE_PATH, '{}/ckpt'.format(args.dir)))

    word2id, id2word = make_vocab(args.vsize)
    with open(join(__SAVE_PATH, join(args.dir, 'vocab.pkl')), 'wb') as f:
        pkl.dump((word2id, id2word), f, pkl.HIGHEST_PROTOCOL)
    # unseen words fall back to the UNK index
    word2id = defaultdict(lambda: UNK, word2id)

    train_loader = get_coco_train_loader(word2id, args.max_len,
                                         args.batch_size, cuda=args.cuda)
    val_loader = get_coco_val_loader(word2id, args.max_len,
                                     args.batch_size, cuda=args.cuda)

    model = AttnImCap(len(id2word), args.emb_dim, args.n_cell, args.n_layer)
    if args.emb:
        emb, oovs = load_embedding_from_bin(args.emb, id2word)
        model.set_embedding(emb, oovs=oovs)
    if args.cuda:
        model.cuda()

    if args.opt == 'adam':
        opt_cls = optim.Adam
    else:
        raise ValueError('unsupported optimizer: {}'.format(args.opt))
    opt_kwargs = {'lr': args.lr}  # TODO
    optimizer = opt_cls(model.parameters(), **opt_kwargs)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=0,
                                  factor=0.5, verbose=True)

    meta = vars(args)
    with open(join(__SAVE_PATH, '{}/meta.json'.format(args.dir)), 'w') as f:
        json.dump(meta, f)
    configure(join(__SAVE_PATH, args.dir))

    step = 0
    running = None
    best_val = None
    patience = 0
    for img, input_, target in train_loader:
        loss, grad_norm = train_step(model, img, input_, target,
                                     optimizer, args.clip_grad)
        step += 1
        # exponential moving average of the training loss
        running = loss if running is None else 0.99 * running + 0.01 * loss
        log_value('loss', loss, step)
        log_value('grad', grad_norm, step)
        print('step: {}, running loss: {:.4f}\r'.format(step, running), end='')
        sys.stdout.flush()
        if step % args.ckpt_freq == 0:
            print('\nstart validation...')
            val_loss = validate(model, val_loader)
            log_value('val_loss', val_loss, step)
            save_ckpt(model, val_loss, step, args.dir)
            scheduler.step(val_loss)
            if best_val is None or val_loss < best_val:
                best_val = val_loss
                patience = 0
            else:
                print('val loss does not decrease')
                patience += 1
                if patience > args.patience:
                    break

    print('training finished, run test set')
    test_loader = get_coco_test_iter(args.max_len, args.batch_size,
                                     cuda=args.cuda)
    model.load_state_dict(torch.load(get_best_ckpt(args.dir)))
    result = test(model, test_loader, id2word, args.max_len)
    with open(join(__SAVE_PATH, '{}/result.json'.format(args.dir)), 'w') as f:
        json.dump(result, f)
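The `defaultdict(lambda: UNK, word2id)` line gives every out-of-vocabulary word the `<unk>` index on lookup. A small demo of the behavior; the value of the repo's `UNK` constant is assumed to be 1 here, matching the vocabulary layout in the test above:

from collections import defaultdict

UNK = 1  # assumed <unk> index
word2id = defaultdict(lambda: UNK, {'<pad>': 0, '<unk>': 1, 'cat': 4})
print(word2id['cat'])        # 4
print(word2id['xylograph'])  # 1 -- unseen words map to <unk>
# caveat: a missing-key lookup also *inserts* the key, so len(word2id) grows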
import gc
import sys
import time
import logging

import numpy as np
import torch
from torch import optim
from torch.autograd import Variable


def train(opt):
    '''
    data could be loaded to a dictionary with "train"/"val"/"test" pointers
    (need to improve the part below)
    '''
    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
        + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) \
        + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
        + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
        + '_dp' + str(opt.dp) \
        + '_gc' + str(opt.gcth) \
        + '_wtinit' + str(opt.wtinit_meth) \
        + '_lw' + str(int(opt.load_wts)) \
        + '_ef' + str(int(opt.embedding_flag)) \
        + '_rf' + str(int(opt.residual_flag))
    print(full_model_name)
    logging.basicConfig(filename=opt.log_folder + full_model_name + '.log',
                        filemode='w', level=logging.DEBUG,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    r = opt.r

    with open(opt.feats_dir + 'train_list.txt') as fid:
        train_list = fid.read().splitlines()
    with open(opt.feats_dir + 'val_list.txt') as fid:
        val_list = fid.read().splitlines()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    print(vocab)

    # Load training data
    train_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                         train_list)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    print(len(file_list), len(train_list))

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy').astype('float32')
    so1 = np.load(opt.stats_dir + 'so.npy').astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy').astype('float32')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy').astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderCBL(vocab_size, opt.hs2, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)
    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet

    criterion = torch.nn.L1Loss(reduction='sum')

    if opt.load_wts:
        load_model_name_pfx = '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_r' + str(opt.r) + '_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_lw0_ef1_rf1_'
        load_model_name_sfx = '.pth'
        # load model
        enc_state_dict = torch.load(load_model_name_pfx + 'enc' + load_model_name_sfx,
                                    map_location=lambda storage, loc: storage)
        encoder.load_state_dict(enc_state_dict)
        dec_state_dict = torch.load(load_model_name_pfx + 'dec' + load_model_name_sfx,
                                    map_location=lambda storage, loc: storage)
        decoder.load_state_dict(dec_state_dict)
        pfnet_state_dict = torch.load(load_model_name_pfx + 'pfnet' + load_model_name_sfx,
                                      map_location=lambda storage, loc: storage)
        pfnet.load_state_dict(pfnet_state_dict)

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    pfnet_optimizer = optim.Adam(pfnet.parameters(), lr=opt.lr,
                                 betas=(opt.b1, opt.b2))

    start = time.time()
    print_loss_total = 0  # reset every print_every
    best_val_loss = sys.maxsize

    for iter in range(1, opt.niter + 1):
        if iter == 3:
            # decay the learning rate once, at the start of epoch 3
            opt.lr = opt.lr / 10
            encoder_optimizer = optim.Adam(encoder.parameters(), lr=opt.lr,
                                           betas=(opt.b1, opt.b2))
            decoder_optimizer = optim.Adam(decoder.parameters(), lr=opt.lr,
                                           betas=(opt.b1, opt.b2))
            pfnet_optimizer = optim.Adam(pfnet.parameters(), lr=opt.lr,
                                         betas=(opt.b1, opt.b2))

        for j, k in enumerate(train_prompts):
            [input_variable, input_length] = get_x(train_prompts, k, phn2id,
                                                   use_cuda)
            train_targets, train_seq_len = load_targets(
                opt.feats_dir + '/fb/', [k], '.npy', dtype, mo1, so1, nml_vec1)
            [target_variable, target_variable2, target_length] = get_y(
                train_seq_len, 0, train_targets, use_cuda, r)

            loss = 0
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            encoder_h0 = encoder.initHidden()
            encoder_c0 = encoder.initHidden()
            encoder_outputs = Variable(torch.zeros(input_length,
                                                   encoder.hidden_size2))
            encoder_outputs = encoder_outputs.cuda() if use_cuda else encoder_outputs
            encoder_output, (encoder_hn, encoder_cn) = encoder(
                input_variable, (encoder_h0, encoder_c0))
            encoder_outputs = encoder_output.squeeze(1)

            decoder_input = Variable(torch.zeros(1, op_dim))  # all-zero frame
            decoder_input = decoder_input.cuda() if use_cuda else decoder_input
            decoder_h1 = decoder.initHidden()
            decoder_c1 = decoder.initHidden()
            decoder_h2 = decoder.initHidden()
            decoder_c2 = decoder.initHidden()
            decoder_h3 = decoder.initHidden()
            decoder_c3 = decoder.initHidden()

            decoder_output_half = Variable(torch.zeros(target_length, r * op_dim))
            decoder_output_half = decoder_output_half.cuda() if use_cuda else decoder_output_half
            decoder_output_full = Variable(torch.zeros(r * target_length, op_dim))
            decoder_output_full = decoder_output_full.cuda() if use_cuda else decoder_output_full

            # Teacher forcing: feed the target as the next input
            for di in range(target_length):
                (decoder_output1, decoder_output2,
                 decoder_h1, decoder_c1,
                 decoder_h2, decoder_c2,
                 decoder_h3, decoder_c3,
                 decoder_attention) = decoder(decoder_input,
                                              decoder_h1, decoder_c1,
                                              decoder_h2, decoder_c2,
                                              decoder_h3, decoder_c3,
                                              encoder_outputs)
                loss += criterion(decoder_output1, target_variable[di])
                decoder_input = target_variable2[di].unsqueeze(0)  # teacher forcing
                decoder_output_half[di] = decoder_output1

            # retain the graph: the post-filter loss below reuses the decoder outputs
            loss.backward(retain_graph=True)
            encoder_optimizer.step()
            decoder_optimizer.step()
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # Start post-filtering net
            pfnet_optimizer.zero_grad()
            # unroll the r frames predicted per decoder step into a frame sequence
            for ix in range(r):
                decoder_output_full[ix::r, :] = \
                    decoder_output_half[:, ix * op_dim:(ix + 1) * op_dim]
            s1 = r * target_length
            train_targets_pfnet, train_seq_len_pfnet = load_targets(
                opt.feats_dir + '/sp/', [k], '.npy', dtype, mo2, so2, nml_vec2)
            targets_pfnet = Variable(train_targets_pfnet)
            targets_pfnet = targets_pfnet.cuda() if use_cuda else targets_pfnet
            s2 = targets_pfnet.size()[0]
            if (s2 % r) > 0:
                # truncate the targets to a whole multiple of r frames
                targets_pfnet = targets_pfnet[:-(s2 % r), :]
            pfnet_h0 = pfnet.initHidden()
            pfnet_c0 = pfnet.initHidden()
            pfnet_outputs = Variable(torch.zeros(targets_pfnet.size()[0],
                                                 pfnet.output_size))
            pfnet_outputs = pfnet_outputs.cuda() if use_cuda else pfnet_outputs
            pfnet_output = pfnet(decoder_output_full, (pfnet_h0, pfnet_c0))
            pfnet_outputs = pfnet_output
            loss_pfnet = criterion(pfnet_outputs, targets_pfnet)
            # gradients flowing back into the encoder/decoder here are
            # discarded by the zero_grad() calls at the start of the next step
            loss_pfnet.backward()
            pfnet_optimizer.step()

            loss_total = loss + loss_pfnet
            # average per output frame: the loss covers target_length steps of r frames
            print_loss_total += loss_total.item() / (r * target_length)

            if (j + 1) % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' % (
                    timeSince(start,
                              (iter * len(train_prompts) - len(train_prompts) + j)
                              / ((opt.niter + 1) * len(train_prompts))),
                    iter, iter / opt.niter * 100, print_loss_avg))

                tf = True  # teacher forcing
                avg_total_val_loss_tf, avg_dec_val_loss_tf, decoder_attentions_tf = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
                    phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim,
                    tf, opt, mo1, so1, nml_vec1, mo2, so2, nml_vec2)
                print('%d %0.4f %0.4f' % (iter, avg_total_val_loss_tf,
                                          avg_dec_val_loss_tf))

                tf = False  # always sampling (no teacher forcing)
                avg_total_val_loss_as, avg_dec_val_loss_as, decoder_attentions_pf = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
                    phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim,
                    tf, opt, mo1, so1, nml_vec1, mo2, so2, nml_vec2)
                print('%d %0.4f %0.4f' % (iter, avg_total_val_loss_as,
                                          avg_dec_val_loss_as))

                logging.debug(
                    'Epoch: ' + str(iter)
                    + ' Update: ' + str(iter * len(train_prompts) - len(train_prompts) + j)
                    + ' Avg Total Val Loss TF: ' + str(avg_total_val_loss_tf)
                    + ' Avg Total Val Loss AS: ' + str(avg_total_val_loss_as)
                    + ' Avg Dec Val Loss TF: ' + str(avg_dec_val_loss_tf)
                    + ' Avg Dec Val Loss AS: ' + str(avg_dec_val_loss_as))

                if avg_total_val_loss_tf < best_val_loss:
                    best_val_loss = avg_total_val_loss_tf
                    torch.save(encoder.state_dict(),
                               '%s/%s_enc.pth' % (opt.model_folder, full_model_name))
                    torch.save(decoder.state_dict(),
                               '%s/%s_dec.pth' % (opt.model_folder, full_model_name))
                    torch.save(pfnet.state_dict(),
                               '%s/%s_pfnet.pth' % (opt.model_folder, full_model_name))
                encoder.train()
                decoder.train()
                pfnet.train()

            # if (j + 1) % showatt_every == 0:
            #     plt.figure(1, figsize=(12, 12))
            #     plt.imshow(decoder_attentions_tf.numpy())
            #     plt.colorbar()
            #     pylab.savefig(opt.plot_folder + full_model_name + '_' + str(j)
            #                   + '_' + str(iter) + '.png', bbox_inches='tight')
            #     plt.close()
            #     plt.figure(1, figsize=(12, 12))
            #     plt.imshow(decoder_attentions_pf.numpy())
            #     plt.colorbar()
            #     pylab.savefig(opt.plot_folder + full_model_name + '_' + str(j)
            #                   + '_' + str(iter) + '.png', bbox_inches='tight')
            #     plt.close()
            # if (j + 1) % plot_every == 0:
            #     plot_loss_avg = plot_loss_total / plot_every
            #     plot_losses.append(plot_loss_avg)
            #     plot_loss_total = 0

            gc.collect()
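The inner `decoder_output_full[ix::r, :] = ...` loop undoes the reduction factor: each decoder step predicts `r` frames packed side by side in a single `r * op_dim` vector, and the strided writes interleave them back into temporal order. A tiny self-contained demo with `r = 2` frames of dimension `d = 3`:

import torch

r, d, steps = 2, 3, 2
half = torch.tensor([[1., 1., 1., 2., 2., 2.],   # step 0 -> frames 0, 1
                     [3., 3., 3., 4., 4., 4.]])  # step 1 -> frames 2, 3
full = torch.zeros(r * steps, d)
for ix in range(r):
    full[ix::r, :] = half[:, ix * d:(ix + 1) * d]
print(full)
# tensor([[1., 1., 1.],
#         [2., 2., 2.],
#         [3., 3., 3.],
#         [4., 4., 4.]])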
import os

import numpy as np
import torch


def test(opt):
    full_model_name = opt.model_name \
        + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) \
        + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
        + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
        + '_dp' + str(opt.dp) \
        + '_gc' + str(opt.gcth) \
        + '_wtinit' + str(opt.wtinit_meth) \
        + '_lw' + str(int(opt.load_wts)) \
        + '_ef' + str(int(opt.embedding_flag)) \
        + '_rf' + str(int(opt.residual_flag))
    print(full_model_name)
    opt.full_model_name = full_model_name

    try:
        os.makedirs(opt.synth_folder + opt.full_model_name)
        os.makedirs(opt.plot_folder + opt.full_model_name)
    except OSError:
        pass

    with open(opt.feats_dir + 'train_list.txt') as fid:
        train_list = fid.read().splitlines()
    with open(opt.feats_dir + 'val_list.txt') as fid:
        val_list = fid.read().splitlines()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    phn2id, id2phn = phn2id2phn(vocab)
    print(vocab)

    # evaluate on the first 10 test utterances
    with open(opt.feats_dir + 'test_list.txt') as fid:
        val_list = fid.read().splitlines()
    val_list = val_list[:10]

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy').astype('float32')
    so1 = np.load(opt.stats_dir + 'so.npy').astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy').astype('float32')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy').astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    # Load test data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderCBL(vocab_size, opt.hs2, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)
    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet

    criterion = torch.nn.L1Loss(reduction='sum')

    load_model_name_pfx = '../../wt/' + opt.full_model_name + '_'
    load_model_name_sfx = '.pth'
    # load model
    enc_state_dict = torch.load(load_model_name_pfx + 'enc' + load_model_name_sfx,
                                map_location=lambda storage, loc: storage)
    encoder.load_state_dict(enc_state_dict)
    dec_state_dict = torch.load(load_model_name_pfx + 'dec' + load_model_name_sfx,
                                map_location=lambda storage, loc: storage)
    decoder.load_state_dict(dec_state_dict)
    pfnet_state_dict = torch.load(load_model_name_pfx + 'pfnet' + load_model_name_sfx,
                                  map_location=lambda storage, loc: storage)
    pfnet.load_state_dict(pfnet_state_dict)

    tf = True  # teacher forcing
    avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
        phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt,
        mo1, so1, nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))

    tf = False  # free-running (no teacher forcing)
    avg_val_loss_pf1, avg_val_loss_pf2, decoder_attentions_pf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
        phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt,
        mo1, so1, nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_pf1, avg_val_loss_pf2))
import numpy as np
import torch


def test(opt):
    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
        + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) \
        + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
        + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
        + '_dp' + str(opt.dp) \
        + '_gc' + str(opt.gcth) \
        + '_wtinit' + str(opt.wtinit_meth) \
        + '_lw' + str(int(opt.load_wts)) \
        + '_ef' + str(int(opt.embedding_flag)) \
        + '_rf' + str(int(opt.residual_flag))
    print(full_model_name)
    opt.full_model_name = full_model_name

    r = opt.r

    with open(opt.feats_dir + 'train_list.txt') as fid:
        train_list = fid.read().splitlines()
    with open(opt.feats_dir + 'val_list.txt') as fid:
        val_list = fid.read().splitlines()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    print(vocab)

    # evaluate on the first 10 test utterances
    with open(opt.feats_dir + 'test_list.txt') as fid:
        val_list = fid.read().splitlines()
    val_list = val_list[:10]

    # Load training data
    train_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                         train_list)
    # vocab = make_vocab(train_prompts)
    # print(vocab)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    print(len(file_list), len(train_list))

    # save_stats_suffstats(opt.feats_dir + '/fb/', file_list, '.npy', dtype, opt.stats_dir)
    # save_stats_suffstats(opt.feats_dir + '/sp/', file_list, '.npy', dtype, opt.pfnet_stats_dir)
    # save_stats(opt.feats_dir + '/fb/', file_list, '.npy', dtype, opt.stats_dir)
    # save_stats(opt.feats_dir + phase + '/log_mag_spec/',
    #            file_list, opt.pfnet_audio_feats_ext, dtype, opt.pfnet_stats_dir)
    # exit()

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy').astype('float32')
    so1 = np.load(opt.stats_dir + 'so.npy').astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy').astype('float32')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy').astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    # train_targets, train_seq_len = load_targets(opt.feats_dir + phase + '/audio_feats/',
    #                                             file_list, opt.audio_feats_ext,
    #                                             dtype, mo1, so1, nml_vec1)

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       val_list)
    # file_list = val_prompts.keys()
    # val_targets, val_seq_len = load_targets(opt.feats_dir + '/fb/', val_list,
    #                                         '.npy', dtype, mo1, so1, nml_vec1)
    # print(val_seq_len)
    # print(val_targets.shape)
    # print(val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(op_dim, opt.hs2,
                                                           op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)
    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet

    criterion = torch.nn.L1Loss(reduction='sum')

    if opt.load_wts:
        load_model_name_pfx = '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_r3_lr3e-05_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_lw1_ef0_rf1_'
        load_model_name_sfx = '_epoch_2999_5.pth'
        # load model
        enc_state_dict = torch.load(load_model_name_pfx + 'enc' + load_model_name_sfx,
                                    map_location=lambda storage, loc: storage)
        encoder.load_state_dict(enc_state_dict)
        dec_state_dict = torch.load(load_model_name_pfx + 'dec' + load_model_name_sfx,
                                    map_location=lambda storage, loc: storage)
        decoder.load_state_dict(dec_state_dict)
        pfnet_state_dict = torch.load(load_model_name_pfx + 'pfnet' + load_model_name_sfx,
                                      map_location=lambda storage, loc: storage)
        pfnet.load_state_dict(pfnet_state_dict)

    tf = True  # teacher forcing
    avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
        phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt,
        mo1, so1, nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))
import sys
import time
import logging

import pylab
import matplotlib.pyplot as plt
import torch
from torch import optim
from torch.autograd import Variable


def train(opt):
    '''
    data could be loaded to a dictionary with "train"/"val"/"test" pointers
    (need to improve the part below)
    '''
    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
        + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) \
        + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
        + '_dp' + str(opt.dp) \
        + '_gc' + str(opt.gcth) \
        + '_wtinit' + str(opt.wtinit_meth) \
        + '_ef' + str(int(opt.embedding_flag)) \
        + '_rf' + str(int(opt.residual_flag))
    print(full_model_name)
    logging.basicConfig(filename=opt.log_folder + full_model_name + '.log',
                        filemode='w', level=logging.DEBUG,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    # Load training data
    phase = 'train'
    train_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    vocab = make_vocab(train_prompts)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    save_stats(opt.feats_dir + phase + '/audio_feats/', file_list,
               opt.audio_feats_ext, dtype, '../stats/')
    train_targets, train_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Load validation data
    phase = 'val'
    val_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    file_list = val_prompts.keys()
    val_targets, val_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2,
                                                       op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder

    criterion = torch.nn.L1Loss(reduction='sum')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    start = time.time()
    plot_losses = []
    print_loss_total = 0  # reset every print_every
    plot_loss_total = 0   # reset every plot_every
    best_val_loss = float('inf')

    for iter in range(1, opt.niter + 1):
        for j, k in enumerate(train_prompts):
            [input_variable, input_length] = get_x_1hot(train_prompts, k,
                                                        phn2id, vocab_size,
                                                        use_cuda)
            [target_variable, target_variable2, target_length] = get_y(
                train_seq_len, j, train_targets, use_cuda)

            loss = 0
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            encoder_h0 = encoder.initHidden()
            encoder_c0 = encoder.initHidden()
            encoder_outputs = Variable(torch.zeros(input_length,
                                                   encoder.hidden_size))
            encoder_outputs = encoder_outputs.cuda() if use_cuda else encoder_outputs
            encoder_output, (encoder_hn, encoder_cn) = encoder(
                input_variable, (encoder_h0, encoder_c0))
            encoder_outputs = encoder_output.squeeze(1)

            decoder_input = Variable(torch.zeros(1, op_dim))  # all-zero frame
            decoder_input = decoder_input.cuda() if use_cuda else decoder_input
            decoder_h1 = decoder.initHidden()
            decoder_c1 = decoder.initHidden()
            decoder_h2 = decoder.initHidden()
            decoder_c2 = decoder.initHidden()
            decoder_h3 = decoder.initHidden()
            decoder_c3 = decoder.initHidden()

            # Teacher forcing: feed the target as the next input
            for di in range(target_length):
                (decoder_output1, decoder_output2,
                 decoder_h1, decoder_c1,
                 decoder_h2, decoder_c2,
                 decoder_h3, decoder_c3,
                 decoder_attention) = decoder(decoder_input,
                                              decoder_h1, decoder_c1,
                                              decoder_h2, decoder_c2,
                                              decoder_h3, decoder_c3,
                                              encoder_outputs)
                loss += criterion(decoder_output1, target_variable[di])
                decoder_input = target_variable2[di].unsqueeze(0)  # teacher forcing

            loss.backward()
            # torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1)
            # torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1)
            encoder_optimizer.step()
            decoder_optimizer.step()

            print_loss_total += loss.item() / target_length
            plot_loss_total += loss.item() / target_length

            if (j + 1) % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' % (
                    timeSince(start,
                              (iter * len(train_prompts) - len(train_prompts) + j)
                              / ((opt.niter + 1) * len(train_prompts))),
                    iter, iter / opt.niter * 100, print_loss_avg))

                tf = True  # teacher forcing
                avg_val_loss_tf, decoder_attentions_tf = evaluate(
                    encoder.eval(), decoder.eval(), val_prompts, val_targets,
                    val_seq_len, phn2id, id2phn, vocab_size, use_cuda,
                    criterion, op_dim, tf)
                print('%d %0.4f' % (iter, avg_val_loss_tf))

                tf = False  # free-running (no teacher forcing)
                avg_val_loss_pf, decoder_attentions_pf = evaluate(
                    encoder.eval(), decoder.eval(), val_prompts, val_targets,
                    val_seq_len, phn2id, id2phn, vocab_size, use_cuda,
                    criterion, op_dim, tf)
                print('%d %0.4f' % (iter, avg_val_loss_pf))

                logging.debug('Epoch: ' + str(iter)
                              + ' Update: ' + str(iter * len(train_prompts)
                                                  - len(train_prompts) + j)
                              + ' Avg Val Loss TF: ' + str(avg_val_loss_tf)
                              + ' Avg Val Loss PF: ' + str(avg_val_loss_pf))

                if avg_val_loss_tf < best_val_loss:
                    best_val_loss = avg_val_loss_tf
                    torch.save(encoder.state_dict(),
                               '%s/%s_enc_epoch_%d_%d.pth'
                               % (opt.model_folder, full_model_name, j, iter))
                    torch.save(decoder.state_dict(),
                               '%s/%s_dec_epoch_%d_%d.pth'
                               % (opt.model_folder, full_model_name, j, iter))
                encoder.train()
                decoder.train()

            if (j + 1) % showatt_every == 0:
                # distinct suffixes so the two plots do not overwrite each other
                plt.figure(1, figsize=(12, 12))
                plt.imshow(decoder_attentions_tf.numpy())
                plt.colorbar()
                pylab.savefig(opt.plot_folder + full_model_name + '_tf_'
                              + str(j) + '_' + str(iter) + '.png',
                              bbox_inches='tight')
                plt.close()
                plt.figure(1, figsize=(12, 12))
                plt.imshow(decoder_attentions_pf.numpy())
                plt.colorbar()
                pylab.savefig(opt.plot_folder + full_model_name + '_pf_'
                              + str(j) + '_' + str(iter) + '.png',
                              bbox_inches='tight')
                plt.close()

            if (j + 1) % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0

    showPlot(plot_losses)
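The commented-out `clip_grad_norm_` calls above and the `gcth` threshold baked into `full_model_name` suggest gradient clipping was tried between `backward()` and the optimizer steps. A minimal self-contained sketch of that pattern; the toy model and the choice `gcth = 1.0` are illustrative, not the project's values:

import torch
from torch import nn, optim

model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 10)).pow(2).sum()
loss.backward()
# rescale all gradients in place so their global L2 norm is at most gcth;
# returns the pre-clipping norm, which is useful to log
gcth = 1.0
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gcth)
optimizer.step()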
import torch


def test(opt):
    full_model_name = opt.model_name \
        + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) \
        + '_pfnet_hs1' + str(opt.pfnet_hs1) \
        + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
        + '_dp' + str(opt.dp) \
        + '_gc' + str(opt.gcth) \
        + '_wtinit' + str(opt.wtinit_meth) \
        + '_ef' + str(int(opt.embedding_flag)) \
        + '_rf' + str(int(opt.residual_flag))
    print(full_model_name)
    opt.full_model_name = full_model_name

    # Load training data (needed only to rebuild the vocabulary)
    phase = 'train'
    train_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    vocab = make_vocab(train_prompts)
    phn2id, id2phn = phn2id2phn(vocab)

    # Load test data (first 5 utterances)
    phase = 'test'
    val_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    file_list = list(val_prompts.keys())[:5]
    val_targets, val_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2,
                                                       op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)
    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet

    criterion = torch.nn.L1Loss(reduction='sum')

    # load model
    enc_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_enc_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    encoder.load_state_dict(enc_state_dict)
    dec_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_dec_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    decoder.load_state_dict(dec_state_dict)
    pfnet_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_pfnet_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    pfnet.load_state_dict(pfnet_state_dict)

    # tf = True  # teacher forcing
    # avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
    #     encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
    #     val_targets, val_seq_len, phn2id, id2phn, vocab_size, use_cuda,
    #     criterion, op_dim, tf, opt)
    # print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))

    tf = False  # free-running (no teacher forcing)
    avg_val_loss_pf1, avg_val_loss_pf2, decoder_attentions_pf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
        val_targets, val_seq_len, phn2id, id2phn, vocab_size, use_cuda,
        criterion, op_dim, tf, opt)
    print('%0.4f %0.4f' % (avg_val_loss_pf1, avg_val_loss_pf2))