def _trainable_params(module):
    """Return the parameters of ``module`` whose ``requires_grad`` flag is set."""
    return [p for p in module.parameters() if p.requires_grad]


def _backbone_optimizer(cfg, module):
    """Build the optimizer used for the shared feature/encoder modules.

    'AC-RNN' uses SGD with ``cfg.actor_step_size``; every other model type
    uses Adam with ``cfg.learning_rate``.
    """
    if cfg.model_type == 'AC-RNN':
        return optim.SGD(_trainable_params(module), lr=cfg.actor_step_size)
    return optim.Adam(_trainable_params(module), lr=cfg.learning_rate)


def run_model(mode, path, in_file, o_file):
    """Build the model selected by the configuration, then train or test it.

    The constructed modules and their optimizers are published as module
    level globals (``feature``, ``encoder``, ``indp``, ``crf``, ``mldecoder``,
    ``rltrain`` and the matching ``*_opt`` names).
    NOTE(review): presumably ``run_epoch`` and ``predict`` read those
    globals -- confirm against their definitions.

    Args:
        mode: 'train' or 'test'. Any other value only constructs the model.
        path: String prefix for checkpoint files; files are named
            ``path + model_type + suffix`` with suffix one of '_feature',
            '_encoder', '_predictor', '_critic'.
        in_file: Stored in ``cfg.test_raw`` when ``mode == 'test'``;
            otherwise unused.
        o_file: Passed to ``predict`` as the output file in test mode. In
            train mode it is replaced by './temp.predicted_' + model_type.

    Returns:
        None.
    """
    global feature, encoder, indp, crf, mldecoder, rltrain
    global f_opt, e_opt, i_opt, c_opt, m_opt, r_opt

    cfg = Configuration()

    # General mode has two values: 'train' or 'test'.
    cfg.mode = mode

    # Seed every random number generator in use.
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if hasCuda:
        torch.cuda.manual_seed_all(cfg.seed)

    # Load embeddings.
    load_embeddings(cfg)

    # Only for testing.
    if mode == 'test':
        cfg.test_raw = in_file

    # Construct the modules shared by every model type. The optimizer is
    # created before the optional move to the GPU, as in the original code.
    feature = Feature(cfg)
    f_opt = _backbone_optimizer(cfg, feature)
    if hasCuda:
        feature.cuda()

    encoder = Encoder(cfg)
    e_opt = _backbone_optimizer(cfg, encoder)
    if hasCuda:
        encoder.cuda()

    # (module, file suffix) pairs in the order they are saved and restored.
    checkpoint_parts = [(feature, '_feature'), (encoder, '_encoder')]

    # Construct the model-specific predictor (and critic for 'AC-RNN').
    if cfg.model_type == 'INDP':
        indp = INDP(cfg)
        i_opt = optim.Adam(_trainable_params(indp), lr=cfg.learning_rate)
        if hasCuda:
            indp.cuda()
        checkpoint_parts.append((indp, '_predictor'))
    elif cfg.model_type == 'CRF':
        crf = CRF(cfg)
        c_opt = optim.Adam(_trainable_params(crf), lr=cfg.learning_rate)
        if hasCuda:
            crf.cuda()
        checkpoint_parts.append((crf, '_predictor'))
    elif cfg.model_type == 'TF-RNN':
        mldecoder = MLDecoder(cfg)
        m_opt = optim.Adam(_trainable_params(mldecoder),
                           lr=cfg.learning_rate)
        if hasCuda:
            mldecoder.cuda()
        cfg.mldecoder_type = 'TF'
        checkpoint_parts.append((mldecoder, '_predictor'))
    elif cfg.model_type == 'SS-RNN':
        mldecoder = MLDecoder(cfg)
        m_opt = optim.Adam(_trainable_params(mldecoder),
                           lr=cfg.learning_rate)
        if hasCuda:
            mldecoder.cuda()
        cfg.mldecoder_type = 'SS'
        checkpoint_parts.append((mldecoder, '_predictor'))
    elif cfg.model_type == 'AC-RNN':
        mldecoder = MLDecoder(cfg)
        m_opt = optim.SGD(_trainable_params(mldecoder),
                          lr=cfg.actor_step_size)
        if hasCuda:
            mldecoder.cuda()
        cfg.mldecoder_type = 'TF'

        rltrain = RLTrain(cfg)
        r_opt = optim.Adam(_trainable_params(rltrain),
                           lr=cfg.learning_rate, weight_decay=0.001)
        if hasCuda:
            rltrain.cuda()
        cfg.rltrain_type = 'AC'

        checkpoint_parts.append((mldecoder, '_predictor'))
        checkpoint_parts.append((rltrain, '_critic'))

        # For RL, the network should be pre-trained with the teacher forced
        # ML decoder. In test mode these weights would be overwritten by the
        # AC-RNN checkpoint right away, so the TF-RNN files are not read.
        if mode != 'test':
            pretrained = path + 'TF-RNN'
            feature.load_state_dict(torch.load(pretrained + '_feature'))
            encoder.load_state_dict(torch.load(pretrained + '_encoder'))
            mldecoder.load_state_dict(torch.load(pretrained + '_predictor'))

    if mode == 'train':
        o_file = './temp.predicted_' + cfg.model_type
        best_val_cost = float('inf')
        best_val_epoch = 0
        first_start = time.time()
        epoch = 0
        while epoch < cfg.max_epochs:
            print('')
            print('Model:{} | Epoch:{}'.format(cfg.model_type, epoch))

            if cfg.model_type == 'SS-RNN':
                # Decaying schedule for the sampling probability
                # (inverse sigmoid schedule).
                cfg.sampling_p = float(cfg.k) / float(
                    cfg.k + np.exp(float(epoch) / cfg.k))

            start = time.time()
            run_epoch(cfg)

            print('\nValidation:')
            predict(cfg, o_file)
            # Cost is 100 minus the evaluation score, so lower is better.
            val_cost = 100 - evaluate(cfg, cfg.dev_ref, o_file)
            print('Validation score:{}'.format(100 - val_cost))

            if val_cost < best_val_cost:
                best_val_cost = val_cost
                best_val_epoch = epoch
                # Keep only the best-scoring weights on disk.
                for module, suffix in checkpoint_parts:
                    torch.save(module.state_dict(),
                               path + cfg.model_type + suffix)

            # Early stopping: give up once the best epoch is too far behind.
            # As in the original, the epoch time is not printed on this exit.
            if epoch - best_val_epoch > cfg.early_stopping:
                break

            print('Epoch training time:{} seconds'.format(
                time.time() - start))
            epoch += 1

        print('Total training time:{} seconds'.format(
            time.time() - first_start))

    elif mode == 'test':
        cfg.batch_size = 256
        for module, suffix in checkpoint_parts:
            module.load_state_dict(
                torch.load(path + cfg.model_type + suffix))

        print('')
        print('Model:{} Predicting'.format(cfg.model_type))
        start = time.time()
        predict(cfg, o_file)
        print('Total prediction time:{} seconds'.format(
            time.time() - start))
def train(args):
    """Train the encoder / attention decoder / post-net stack.

    Each loop iteration fetches one batch with ``ds.next_batch`` and runs one
    ``train_batch`` step, so ``hp.n_epochs`` counts optimisation steps. The
    average loss is printed every 100 steps and a checkpoint is handed to
    ``save_checkpoint`` every 1000 steps.

    Args:
        args: Object providing ``data_size`` (passed to ``tiny_words`` as
            ``max_dataset_size``) and ``multi_gpus`` (wrap the models in
            ``nn.DataParallel`` over all visible CUDA devices when truthy).

    Returns:
        None.
    """
    # Initialize dataset.
    with Timed('Loading dataset'):
        ds = tiny_words(max_text_length=hp.max_text_length,
                        max_audio_length=hp.max_audio_length,
                        max_dataset_size=args.data_size)

    # Initialize model.
    # NOTE(review): the original indentation was lost, so whether the
    # DataParallel/CUDA steps sat inside this timed block is an assumption.
    with Timed('Initializing model.'):
        encoder = Encoder(ds.lang.num_chars, hp.embedding_dim,
                          hp.encoder_bank_k, hp.encoder_bank_ck,
                          hp.encoder_proj_dims, hp.encoder_highway_layers,
                          hp.encoder_highway_units, hp.encoder_gru_units,
                          dropout=hp.dropout, use_cuda=hp.use_cuda)
        decoder = AttnDecoder(hp.max_text_length, hp.attn_gru_hidden_size,
                              hp.n_mels, hp.rf, hp.decoder_gru_hidden_size,
                              hp.decoder_gru_layers, dropout=hp.dropout,
                              use_cuda=hp.use_cuda)
        postnet = PostNet(hp.n_mels, 1 + hp.n_fft // 2, hp.post_bank_k,
                          hp.post_bank_ck, hp.post_proj_dims,
                          hp.post_highway_layers, hp.post_highway_units,
                          hp.post_gru_units, use_cuda=hp.use_cuda)

        if args.multi_gpus:
            device_ids = list(range(torch.cuda.device_count()))
            encoder = nn.DataParallel(encoder, device_ids=device_ids)
            decoder = nn.DataParallel(decoder, device_ids=device_ids)
            postnet = nn.DataParallel(postnet, device_ids=device_ids)

        if hp.use_cuda:
            encoder.cuda()
            decoder.cuda()
            postnet.cuda()

    # Initialize optimizer and criterion: one Adam instance over all three
    # modules.
    all_parameters = (list(encoder.parameters()) +
                      list(decoder.parameters()) +
                      list(postnet.parameters()))
    optimizer = optim.Adam(all_parameters, lr=hp.lr)
    criterion = nn.L1Loss()

    # Configure training.
    print_every = 100
    save_every = 1000

    # Keep track of time elapsed and running averages.
    start = time.time()
    print_loss_total = 0  # Reset every print_every

    for epoch in range(1, hp.n_epochs + 1):
        # Get training data for this cycle.
        mels, mags, indexed_texts = ds.next_batch(hp.batch_size)

        mels_v = Variable(torch.from_numpy(mels).float())
        mags_v = Variable(torch.from_numpy(mags).float())
        texts_v = Variable(torch.from_numpy(indexed_texts))
        if hp.use_cuda:
            mels_v = mels_v.cuda()
            mags_v = mags_v.cuda()
            texts_v = texts_v.cuda()

        loss = train_batch(mels_v, mags_v, texts_v, encoder, decoder,
                           postnet, optimizer, criterion,
                           multi_gpus=args.multi_gpus)

        # Keep track of loss.
        print_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            # float() keeps the ratio fractional even under integer-division
            # semantics, where int / int would truncate to 0.
            progress = float(epoch) / hp.n_epochs
            print_summary = '%s (%d %d%%) %.4f' % \
                (time_since(start, progress), epoch, progress * 100,
                 print_loss_avg)
            print(print_summary)

        if epoch % save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'postnet': postnet.state_dict(),
                'optimizer': optimizer.state_dict(),
            })
def _strip_data_parallel_prefix(state_dict):
    """Drop the 'module.' key prefix that ``nn.DataParallel`` adds.

    The mapping is returned unchanged unless every key carries the prefix.
    """
    from collections import OrderedDict

    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return OrderedDict(
            (key[len(prefix):], value) for key, value in state_dict.items())
    return state_dict


def inference(checkpoint_file, text, output_path="demo.wav"):
    """Synthesize ``text`` with a saved checkpoint and write a wav file.

    The text is indexed, terminated with ``EOT_token`` and padded to
    ``hp.max_text_length``. The decoder is then run for
    ``int(hp.max_audio_length / hp.rf)`` steps, each step feeding back the
    last frame of its previous output, and the post-net output is converted
    with ``spectrogram2wav``.

    Args:
        checkpoint_file: File given to ``torch.load``; must hold 'encoder',
            'decoder' and 'postnet' state dicts.
        text: Input text to synthesize.
        output_path: Destination passed to ``write``. Defaults to
            "demo.wav", the previously hard-coded name.

    Returns:
        None.
    """
    # NOTE(review): `args` is not a parameter of this function; this relies
    # on a module-level `args` being defined -- confirm against the caller.
    ds = tiny_words(max_text_length=hp.max_text_length,
                    max_audio_length=hp.max_audio_length,
                    max_dataset_size=args.data_size)
    print(ds.texts)

    # Prepare input.
    indexes = indexes_from_text(ds.lang, text)
    indexes.append(EOT_token)
    padded_indexes = pad_indexes(indexes, hp.max_text_length, PAD_token)
    texts_v = Variable(torch.from_numpy(padded_indexes))
    texts_v = texts_v.unsqueeze(0)  # add a leading dimension of size 1
    if hp.use_cuda:
        texts_v = texts_v.cuda()

    encoder = Encoder(ds.lang.num_chars, hp.embedding_dim,
                      hp.encoder_bank_k, hp.encoder_bank_ck,
                      hp.encoder_proj_dims, hp.encoder_highway_layers,
                      hp.encoder_highway_units, hp.encoder_gru_units,
                      dropout=hp.dropout, use_cuda=hp.use_cuda)
    decoder = AttnDecoder(hp.max_text_length, hp.attn_gru_hidden_size,
                          hp.n_mels, hp.rf, hp.decoder_gru_hidden_size,
                          hp.decoder_gru_layers, dropout=hp.dropout,
                          use_cuda=hp.use_cuda)
    postnet = PostNet(hp.n_mels, 1 + hp.n_fft // 2, hp.post_bank_k,
                      hp.post_bank_ck, hp.post_proj_dims,
                      hp.post_highway_layers, hp.post_highway_units,
                      hp.post_gru_units, use_cuda=hp.use_cuda)

    encoder.eval()
    decoder.eval()
    postnet.eval()

    if hp.use_cuda:
        encoder.cuda()
        decoder.cuda()
        postnet.cuda()

    # Load model. Without CUDA, remap every stored tensor onto the CPU so a
    # checkpoint saved from GPU tensors can still be read.
    if hp.use_cuda:
        checkpoint = torch.load(checkpoint_file)
    else:
        checkpoint = torch.load(
            checkpoint_file, map_location=lambda storage, loc: storage)

    # Checkpoints saved from nn.DataParallel-wrapped models prefix every key
    # with 'module.'; the plain modules built above need it removed.
    encoder.load_state_dict(
        _strip_data_parallel_prefix(checkpoint['encoder']))
    decoder.load_state_dict(
        _strip_data_parallel_prefix(checkpoint['decoder']))
    postnet.load_state_dict(
        _strip_data_parallel_prefix(checkpoint['postnet']))

    encoder_out = encoder(texts_v)

    # Prepare input and output variables: the first decoder input is an
    # all-zero frame.
    GO_frame = np.zeros((1, hp.n_mels))
    decoder_in = Variable(torch.from_numpy(GO_frame).float())
    if hp.use_cuda:
        decoder_in = decoder_in.cuda()
    h, hs = decoder.init_hiddens(1)

    decoder_outs = []
    for _ in range(int(hp.max_audio_length / hp.rf)):
        decoder_out, h, hs, _attn = decoder(decoder_in, h, hs, encoder_out)
        decoder_outs.append(decoder_out)
        # Use the prediction: feed the last frame back as the next input.
        decoder_in = decoder_out[:, -1, :].contiguous()

    # (batch_size, T, n_mels)
    decoder_outs = torch.cat(decoder_outs, 1)

    # Postnet.
    post_out = postnet(decoder_outs)
    s = post_out[0].cpu().data.numpy()

    print("Recontructing wav...")
    # Replace negative values with zero before raising to hp.power.
    s = np.where(s < 0, 0, s)
    wav = spectrogram2wav(s**hp.power)
    # wav = griffinlim(s**hp.power)
    write(output_path, hp.sr, wav)