def main(args):
    model_path = args.model_path
    hparams.set_hparam('batch_size', 1)
    hparams.add_hparam('is_training', False)

    check_vocab(args)
    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    datasets = load_dataset(args, src_placeholder)
    iterator = iterator_utils.get_inference_iterator(hparams, datasets)
    src_vocab, tgt_vocab, _, tgt_reverse_vocab, src_vocab_size, tgt_vocab_size = datasets
    hparams.add_hparam('vocab_size_source', src_vocab_size)
    hparams.add_hparam('vocab_size_target', tgt_vocab_size)

    sess, model = load_model(hparams, tf.contrib.learn.ModeKeys.INFER, iterator,
                             src_vocab, tgt_vocab, tgt_reverse_vocab)

    # Restore the latest checkpoint before running inference.
    ckpt = tf.train.latest_checkpoint(model_path)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    if ckpt:
        saver.restore(sess, ckpt)
    else:
        raise Exception('cannot find checkpoint file')

    src_vocab_file = os.path.join(model_path, 'vocab.src')
    src_reverse_vocab = build_reverse_vocab_table(src_vocab_file, hparams)
    sess.run(tf.tables_initializer())

    index = 1
    inputs = np.array(get_data(args), dtype=str)
    with sess:
        logger.info('starting inference...')
        sess.run(iterator.initializer, feed_dict={src_placeholder: inputs})
        eos = hparams.eos.encode()
        pad = hparams.pad.encode()
        while True:
            try:
                predictions, confidence, source = model.inference(sess)

                # Map source ids back to (BPE) tokens for display.
                source_sent = src_reverse_vocab.lookup(
                    tf.constant(list(source[0]), tf.int64))
                source_sent = sess.run(source_sent)
                print(index, text_utils.format_bpe_text(source_sent, [eos, pad]))

                if hparams.beam_width == 1:
                    print(bytes2sent(list(predictions[0]), [eos, pad]))
                else:
                    # Keep only the top beam hypothesis.
                    print(bytes2sent(list(predictions[0][:, 0]), [eos, pad]))
                if confidence is not None:
                    print(confidence[0])
                print()

                if index > args.max_data_size:
                    break
                index += 1
            except tf.errors.OutOfRangeError:
                logger.info('Done inference')
                break
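# A hypothetical CLI wrapper for the inference entry point above; the flag
# names and defaults are assumptions, not taken from the original script.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='NMT inference')
    parser.add_argument('--model_path', required=True,
                        help='directory containing checkpoints and vocab.src')
    parser.add_argument('--max_data_size', type=int, default=100,
                        help='stop after this many decoded sentences')
    main(parser.parse_args())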
def train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    global global_step, global_epoch
    resumed_step = global_step

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            model.train()
            optimizer.zero_grad()

            # Move data to CUDA device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            l1loss = recon_loss(g, gt)
            loss = hparams.syncnet_wt * sync_loss + \
                (1 - hparams.syncnet_wt) * l1loss
            loss.backward()
            optimizer.step()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir, global_epoch)

            if global_step == 1 or global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader, global_step,
                                                   device, model, checkpoint_dir)
                    if average_sync_loss < .75:
                        # without image GAN a lesser weight is sufficient
                        hparams.set_hparam('syncnet_wt', 0.01)

            prog_bar.set_description('L1: {}, Sync Loss: {}'.format(
                running_l1_loss / (step + 1), running_sync_loss / (step + 1)))

        global_epoch += 1
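# The loop above assumes two loss helpers. A minimal sketch of how they look
# in the Wav2Lip codebase; the globals `syncnet`, `syncnet_T`, and `device`
# are assumptions about surrounding module state:
import torch
import torch.nn as nn

recon_loss = nn.L1Loss()
logloss = nn.BCELoss()

def cosine_loss(a, v, y):
    # BCE on the cosine similarity of audio and video embeddings
    d = nn.functional.cosine_similarity(a, v)
    return logloss(d.unsqueeze(1), y)

def get_sync_loss(mel, g):
    # SyncNet only sees the lower half of each generated face crop
    g = g[:, :, :, g.size(3) // 2:]
    # Fold the time dimension into channels: (B, 3 * T, H/2, W)
    g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
    a, v = syncnet(mel, g)
    y = torch.ones(g.size(0), 1).float().to(device)
    return cosine_loss(a, v, y)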
def train(device, model, disc, train_data_loader, test_data_loader, optimizer,
          disc_optimizer, checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    global global_step, global_epoch
    resumed_step = global_step
    print('global_epoch: ', global_epoch)

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
        running_disc_real_loss, running_disc_fake_loss = 0., 0.
        # running_disc_real_acc, running_disc_fake_acc = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            disc.train()
            model.train()

            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            ### Train generator now. Remove ALL grads.
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            if hparams.disc_wt > 0.:
                perceptual_loss = disc.perceptual_forward(g)
            else:
                perceptual_loss = 0.

            l1loss = recon_loss(g, gt)
            loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
                (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
            loss.backward()
            optimizer.step()

            ### Remove all gradients before training disc
            disc_optimizer.zero_grad()

            pred = disc(gt)
            disc_real_loss = F.binary_cross_entropy(
                pred, torch.ones((len(pred), 1)).to(device))
            disc_real_loss.backward()
            '''
            pred_label = pred.detach()
            pred_label[pred <= 0.5], pred_label[pred > 0.5] = 0, 1
            disc_real_acc = torch.sum(pred_label == 1) / len(pred) * 100
            '''

            pred = disc(g.detach())
            disc_fake_loss = F.binary_cross_entropy(
                pred, torch.zeros((len(pred), 1)).to(device))
            disc_fake_loss.backward()
            '''
            pred_label = pred.detach()
            pred_label[pred <= 0.5], pred_label[pred > 0.5] = 0, 1
            disc_fake_acc = torch.sum(pred_label == 0) / len(pred) * 100
            '''

            disc_optimizer.step()

            running_disc_real_loss += disc_real_loss.item()
            running_disc_fake_loss += disc_fake_loss.item()
            # running_disc_real_acc += disc_real_acc.item()
            # running_disc_fake_acc += disc_fake_acc.item()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            # Logs
            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if hparams.disc_wt > 0.:
                running_perceptual_loss += perceptual_loss.item()
            else:
                running_perceptual_loss += 0.

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir, global_epoch)
                save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir,
                                global_epoch, prefix='disc_')

            if global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader, device, model, disc)
                    print('Average_sync_loss: ', average_sync_loss)
                    print('hparams.disc_wt: ', hparams.disc_wt)
                    print('hparams.syncnet_wt: ', hparams.syncnet_wt)
                    if average_sync_loss < .75:
                        hparams.set_hparam('syncnet_wt', 0.03)

            prog_bar.set_description(
                '[Train] Epoch {} - L1: {}, Sync: {}, Percep: {} | Loss Fake: {}, Real: {}'
                .format(global_epoch, running_l1_loss / (step + 1),
                        running_sync_loss / (step + 1),
                        running_perceptual_loss / (step + 1),
                        running_disc_fake_loss / (step + 1),
                        running_disc_real_loss / (step + 1)))

        global_epoch += 1
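# A minimal driver sketch for the GAN variant above. The model classes,
# hparams fields, dataset objects, and hyperparameter values mirror a
# Wav2Lip-style setup but are assumptions here, not part of the snippet:
import torch
from torch import optim
from torch.utils import data as data_utils

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Wav2Lip().to(device)           # generator (assumed class name)
disc = Wav2Lip_disc_qual().to(device)  # visual-quality discriminator (assumed)

optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                       lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
                            lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))

train_data_loader = data_utils.DataLoader(
    train_dataset, batch_size=hparams.batch_size, shuffle=True,
    num_workers=hparams.num_workers)
test_data_loader = data_utils.DataLoader(
    test_dataset, batch_size=hparams.batch_size,
    num_workers=hparams.num_workers)

train(device, model, disc, train_data_loader, test_data_loader,
      optimizer, disc_optimizer, checkpoint_dir='checkpoints/',
      checkpoint_interval=3000, nepochs=200)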
print("Training postnet model") else: assert False, "must be specified wrong args" # Load preset if specified if preset is not None: with open(preset) as f: hparams.parse_json(f.read()) # Override hyper parameters hparams.parse(args["--hparams"]) # Preventing Windows specific error such as MemoryError # Also reduces the occurrence of THAllocator.c 0x05 error in Widows build of PyTorch if platform.system() == "Windows": print("Windows Detected - num_workers set to 1") hparams.set_hparam('num_workers', 1) assert hparams.name == "deepvoice3" print(hparams_debug_string()) _frontend = getattr(frontend, hparams.frontend) os.makedirs(checkpoint_dir, exist_ok=True) # Input dataset definitions X = FileSourceDataset(TextDataSource(data_root, speaker_id)) Mel = FileSourceDataset(MelSpecDataSource(data_root, speaker_id)) Y = FileSourceDataset(LinearSpecDataSource(data_root, speaker_id)) # Prepare sampler frame_lengths = Mel.file_data_source.frame_lengths
def train(args, device, model_Wav2Lip, train_data_loader_Left, test_data_loader_Left,
          optimizer, checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    n_img, loader = prepare_dataloader(args, 'train')        # CycleGAN
    val_n_img, val_loader = prepare_dataloader(args, 'val')  # CycleGAN

    model_cycle = MonoDepthArchitecture(args)  # Modified
    model_cycle.set_data_loader(loader)        # Modified

    global global_step, global_epoch  # Wav2Lip
    resumed_step = global_step        # Wav2Lip

    if not args.resume:  # CycleGAN
        best_val_loss = float('Inf')  # CycleGAN
        validate_cycle(-1)            # CycleGAN
        pre_validation_update(model_cycle.losses[-1]['val'])  # Modified
    else:
        best_val_loss = min([model_cycle.losses[epoch]['val']['G']
                             for epoch in model_cycle.losses.keys()])  # Modified
    running_val_loss = 0.0

    while global_epoch < nepochs:
        # CycleGAN epoch
        c_time = time.time()
        model_cycle.to_train()                       # Modified
        model_cycle.set_new_loss_item(global_epoch)  # Modified
        model_cycle.run_epoch(global_epoch, n_img)   # Modified
        validate_cycle(global_epoch)                 # Modified
        print_epoch_update(global_epoch, time.time() - c_time,
                           model_cycle.losses)       # Modified

        # Make a checkpoint
        running_val_loss = model_cycle.losses[global_epoch]['val']['G']  # Modified
        is_best = running_val_loss < best_val_loss
        if is_best:
            best_val_loss = running_val_loss

        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader_Left))
        for step, (x, indiv_mels, mel, gt_left) in prog_bar:  # Modified
            model_Wav2Lip.train()
            optimizer.zero_grad()

            # Move data to CUDA device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt_left = gt_left.to(device)  # Modified

            g_left = model_Wav2Lip(indiv_mels, x)  # Modified

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g_left)
            else:
                sync_loss = 0.

            # L1 term scaled by the CycleGAN validation loss: (1 + best_val_loss) * L1
            l1loss = recon_loss(g_left, gt_left) + \
                best_val_loss * recon_loss(g_left, gt_left)  # Modified
            loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss
            loss.backward()
            optimizer.step()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g_left, gt_left, global_step, checkpoint_dir)

            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model_Wav2Lip, optimizer, global_step,
                                checkpoint_dir, global_epoch)

            if global_step == 1 or global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader_Left, global_step,
                                                   device, model_Wav2Lip,
                                                   checkpoint_dir, best_val_loss)
                    if average_sync_loss < .75:
                        # without image GAN a lesser weight is sufficient
                        hparams.set_hparam('syncnet_wt', 0.01)

            prog_bar.set_description('L1: {}, Sync Loss: {}'.format(
                running_l1_loss / (step + 1), running_sync_loss / (step + 1)))

        model_cycle.save_checkpoint(global_epoch, is_best, best_val_loss)
        global_epoch += 1

    print('Finished Training. Best validation loss:\t{:.3f}'.format(best_val_loss))
    model_cycle.save_networks('final')
    if running_val_loss != best_val_loss:
        model_cycle.save_best_networks()
    model_cycle.save_losses()
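# The modified L1 term above is equivalent to scaling the reconstruction loss
# by (1 + best_val_loss) before the usual Wav2Lip blending. An illustrative
# helper (the function name and factoring are hypothetical):
def combined_loss(sync_loss, l1, best_val_loss, syncnet_wt):
    weighted_l1 = (1.0 + best_val_loss) * l1  # CycleGAN val loss scales L1
    return syncnet_wt * sync_loss + (1.0 - syncnet_wt) * weighted_l1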
def main(args, max_data_size=0, shuffle=True, display=False):
    hparams.set_hparam('batch_size', 10)
    hparams.add_hparam('is_training', False)

    check_vocab(args)
    datasets, src_data_size = load_dataset(args)
    iterator = iterator_utils.get_eval_iterator(hparams, datasets, hparams.eos,
                                                shuffle=shuffle)
    src_vocab, tgt_vocab, src_dataset, tgt_dataset, tgt_reverse_vocab, \
        src_vocab_size, tgt_vocab_size = datasets
    hparams.add_hparam('vocab_size_source', src_vocab_size)
    hparams.add_hparam('vocab_size_target', tgt_vocab_size)

    sess, model = load_model(hparams, tf.contrib.learn.ModeKeys.EVAL, iterator,
                             src_vocab, tgt_vocab, tgt_reverse_vocab)

    # Restore either the requested step or the latest checkpoint.
    if args.restore_step:
        checkpoint_path = os.path.join(args.model_path, 'nmt.ckpt')
        ckpt = '%s-%d' % (checkpoint_path, args.restore_step)
    else:
        ckpt = tf.train.latest_checkpoint(args.model_path)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    if ckpt:
        saver.restore(sess, ckpt)
    else:
        raise Exception('cannot find checkpoint file')

    src_vocab_file = os.path.join(args.model_path, 'vocab.src')
    src_reverse_vocab = build_reverse_vocab_table(src_vocab_file, hparams)
    sess.run(tf.tables_initializer())

    step_count = 1
    with sess:
        logger.info('starting evaluation...')
        sess.run(iterator.initializer)
        eos = hparams.eos.encode()
        references = []
        translations = []
        start_time = time.time()
        while True:
            try:
                if (max_data_size > 0) and (step_count * hparams.batch_size > max_data_size):
                    break

                if step_count % 10 == 0:
                    t = time.time() - start_time
                    logger.info('step={0} total={1} time={2:.3f}'.format(
                        step_count, step_count * hparams.batch_size, t))
                    start_time = time.time()

                predictions, source, target, source_text, confidence = model.eval(sess)
                reference = bpe2sent(target, eos)
                if hparams.beam_width == 1:
                    translation = bytes2sent(list(predictions), eos)
                else:
                    # Keep only the top beam hypothesis.
                    translation = bytes2sent(list(predictions[:, 0]), eos)

                for s, r, t in zip(source, reference, translation):
                    if display:
                        source_sent = src_reverse_vocab.lookup(
                            tf.constant(list(s), tf.int64))
                        source_sent = sess.run(source_sent)
                        source_sent = text_utils.format_bpe_text(source_sent, eos)
                        print('{}\n{}\n{}\n'.format(source_sent, r, t))
                    references.append(r)
                    translations.append(t)

                if step_count % 100 == 0:
                    bleu_score = moses_multi_bleu(references, translations, args.model_path)
                    logger.info('bleu score = {0:.3f}'.format(bleu_score))
                step_count += 1
            except tf.errors.OutOfRangeError:
                logger.info('Done eval data')
                break

        logger.info('compute bleu score...')
        # bleu_score = compute_bleu_score(references, translations)
        bleu_score = moses_multi_bleu(references, translations, args.model_path)
        logger.info('bleu score = {0:.3f}'.format(bleu_score))
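# A hypothetical wrapper for the evaluation entry point above; flag names and
# defaults are assumptions:
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='NMT evaluation (BLEU)')
    parser.add_argument('--model_path', required=True)
    parser.add_argument('--restore_step', type=int, default=0,
                        help='0 means use the latest checkpoint')
    args = parser.parse_args()
    # Evaluate at most 1000 sentences, in order, printing each
    # source/reference/translation triple.
    main(args, max_data_size=1000, shuffle=False, display=True)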
def train(device, model, disc, train_data_loader, test_data_loader, optimizer,
          disc_optimizer, checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    global global_step, global_epoch
    resumed_step = global_step

    while global_epoch < nepochs:
        print("Starting Epoch: {}".format(global_epoch))
        running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = (
            0.0, 0.0, 0.0, 0.0)
        running_disc_real_loss, running_disc_fake_loss = 0.0, 0.0
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            disc.train()
            model.train()

            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            ### Train generator now. Remove ALL grads.
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.0:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.0

            if hparams.disc_wt > 0.0:
                perceptual_loss = disc.perceptual_forward(g)
            else:
                perceptual_loss = 0.0

            l1loss = recon_loss(g, gt)
            loss = (hparams.syncnet_wt * sync_loss +
                    hparams.disc_wt * perceptual_loss +
                    (1.0 - hparams.syncnet_wt - hparams.disc_wt) * l1loss)
            loss.backward()
            optimizer.step()

            ### Remove all gradients before training disc
            disc_optimizer.zero_grad()

            pred = disc(gt)
            disc_real_loss = F.binary_cross_entropy(
                pred, torch.ones((len(pred), 1)).to(device))
            disc_real_loss.backward()

            pred = disc(g.detach())
            disc_fake_loss = F.binary_cross_entropy(
                pred, torch.zeros((len(pred), 1)).to(device))
            disc_fake_loss.backward()

            disc_optimizer.step()

            running_disc_real_loss += disc_real_loss.item()
            running_disc_fake_loss += disc_fake_loss.item()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            # Logs
            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.0:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.0

            if hparams.disc_wt > 0.0:
                running_perceptual_loss += perceptual_loss.item()
            else:
                running_perceptual_loss += 0.0

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir, global_epoch)
                save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir,
                                global_epoch, prefix="disc_")

            if global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader, global_step,
                                                   device, model, disc)
                    if average_sync_loss < 0.75:
                        hparams.set_hparam("syncnet_wt", 0.03)

            prog_bar.set_description(
                "L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}".format(
                    running_l1_loss / (step + 1),
                    running_sync_loss / (step + 1),
                    running_perceptual_loss / (step + 1),
                    running_disc_fake_loss / (step + 1),
                    running_disc_real_loss / (step + 1)))

        global_epoch += 1
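# A minimal sketch of the checkpoint helper these loops rely on, in Wav2Lip
# style; the dict field names and the filename pattern are assumptions:
import os
import torch

def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
    checkpoint_path = os.path.join(
        checkpoint_dir, '{}checkpoint_step{:09d}.pth'.format(prefix, step))
    torch.save({
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'global_step': step,
        'global_epoch': epoch,
    }, checkpoint_path)
    print('Saved checkpoint:', checkpoint_path)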