Code Example #1
File: predict.py Project: wyb330/nmt
def main(args):
    model_path = args.model_path
    hparams.set_hparam('batch_size', 1)
    hparams.add_hparam('is_training', False)
    check_vocab(args)
    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    datasets = load_dataset(args, src_placeholder)
    iterator = iterator_utils.get_inference_iterator(hparams, datasets)
    src_vocab, tgt_vocab, _, tgt_reverse_vocab, src_vocab_size, tgt_vocab_size = datasets
    hparams.add_hparam('vocab_size_source', src_vocab_size)
    hparams.add_hparam('vocab_size_target', tgt_vocab_size)

    sess, model = load_model(hparams, tf.contrib.learn.ModeKeys.INFER, iterator, src_vocab, tgt_vocab, tgt_reverse_vocab)

    ckpt = tf.train.latest_checkpoint(args.model_path)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    if ckpt:
        saver.restore(sess, ckpt)
    else:
        raise Exception("cannot find checkpoint file")

    src_vocab_file = os.path.join(model_path, 'vocab.src')
    src_reverse_vocab = build_reverse_vocab_table(src_vocab_file, hparams)
    sess.run(tf.tables_initializer())

    index = 1
    inputs = np.array(get_data(args), dtype=str)
    with sess:
        logger.info("starting inference...")
        sess.run(iterator.initializer, feed_dict={src_placeholder: inputs})
        eos = hparams.eos.encode()
        pad = hparams.pad.encode()
        while True:
            try:
                predictions, confidence, source = model.inference(sess)
                source_sent = src_reverse_vocab.lookup(tf.constant(list(source[0]), tf.int64))
                source_sent = sess.run(source_sent)
                print(index, text_utils.format_bpe_text(source_sent, [eos, pad]))
                if hparams.beam_width == 1:
                    print(bytes2sent(list(predictions[0]), [eos, pad]))
                else:
                    print(bytes2sent(list(predictions[0][:, 0]), [eos, pad]))
                if confidence is not None:
                    print(confidence[0])
                print()
                if index > args.max_data_size:
                    break
                index += 1
            except tf.errors.OutOfRangeError:
                logger.info('Done inference')
                break
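
The helper build_reverse_vocab_table used above is defined elsewhere in the project. A minimal sketch of what it is assumed to do in this TF 1.x codebase (build an id-to-token lookup table from the vocabulary file, with out-of-range ids mapping to an unknown symbol):

import tensorflow as tf

def build_reverse_vocab_table(vocab_file, hparams):
    # Hypothetical sketch, not the project's actual implementation:
    # map integer ids back to token strings; ids outside the vocabulary
    # fall back to '<unk>' (the real default presumably comes from hparams).
    return tf.contrib.lookup.index_to_string_table_from_file(
        vocab_file, default_value='<unk>')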
Code Example #2
def train(device,
          model,
          train_data_loader,
          test_data_loader,
          optimizer,
          checkpoint_dir=None,
          checkpoint_interval=None,
          nepochs=None):

    global global_step, global_epoch
    resumed_step = global_step

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            model.train()
            optimizer.zero_grad()

            # Move data to CUDA device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            l1loss = recon_loss(g, gt)

            loss = hparams.syncnet_wt * sync_loss + (
                1 - hparams.syncnet_wt) * l1loss
            loss.backward()
            optimizer.step()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir,
                                global_epoch)

            if global_step == 1 or global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader,
                                                   global_step, device, model,
                                                   checkpoint_dir)

                    if average_sync_loss < .75:
                        hparams.set_hparam(
                            'syncnet_wt', 0.01
                        )  # without image GAN a lesser weight is sufficient

            prog_bar.set_description('L1: {}, Sync Loss: {}'.format(
                running_l1_loss / (step + 1), running_sync_loss / (step + 1)))

        global_epoch += 1
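
The loss helpers recon_loss and get_sync_loss referenced in this loop are defined elsewhere in the Wav2Lip training script. A minimal sketch, assuming an L1 reconstruction term and a SyncNet-style BCE-on-cosine-similarity term computed on the lower (mouth) half of the generated frames; syncnet, syncnet_T, and device are assumed to come from the surrounding script:

import torch
import torch.nn as nn

recon_loss = nn.L1Loss()   # pixel-wise reconstruction loss
logloss = nn.BCELoss()

def cosine_loss(a, v, y):
    # BCE on the cosine similarity between audio and video embeddings
    d = nn.functional.cosine_similarity(a, v)
    return logloss(d.unsqueeze(1), y)

def get_sync_loss(mel, g):
    # keep the lower half of each generated frame and stack the T frames channel-wise
    g = g[:, :, :, g.size(3) // 2:]
    g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)  # (B, 3*T, H/2, W)
    a, v = syncnet(mel, g)  # frozen, pretrained SyncNet (assumed)
    y = torch.ones(g.size(0), 1).float().to(device)
    return cosine_loss(a, v, y)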
Code Example #3
def train(device,
          model,
          disc,
          train_data_loader,
          test_data_loader,
          optimizer,
          disc_optimizer,
          checkpoint_dir=None,
          checkpoint_interval=None,
          nepochs=None):
    global global_step, global_epoch
    resumed_step = global_step

    print('global_epoch: ', global_epoch)

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
        running_disc_real_loss, running_disc_fake_loss = 0., 0.
        #running_disc_real_acc, running_disc_fake_acc = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            disc.train()
            model.train()

            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            ### Train generator now. Remove ALL grads.
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            if hparams.disc_wt > 0.:
                perceptual_loss = disc.perceptual_forward(g)
            else:
                perceptual_loss = 0.

            l1loss = recon_loss(g, gt)

            loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
                                    (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss

            loss.backward()
            optimizer.step()

            ### Remove all gradients before Training disc
            disc_optimizer.zero_grad()

            pred = disc(gt)
            disc_real_loss = F.binary_cross_entropy(
                pred,
                torch.ones((len(pred), 1)).to(device))
            disc_real_loss.backward()
            '''
            pred_label = pred.detach()
            pred_label[pred<=0.5], pred_label[pred>0.5] = 0, 1
            disc_real_acc = torch.sum(pred_label==1) / len(pred) * 100
            '''

            pred = disc(g.detach())
            disc_fake_loss = F.binary_cross_entropy(
                pred,
                torch.zeros((len(pred), 1)).to(device))
            disc_fake_loss.backward()
            '''
            pred_label = pred.detach()
            pred_label[pred<=0.5], pred_label[pred>0.5] = 0, 1
            disc_fake_acc = torch.sum(pred_label==0) / len(pred) * 100
            '''
            disc_optimizer.step()

            running_disc_real_loss += disc_real_loss.item()
            running_disc_fake_loss += disc_fake_loss.item()
            #running_disc_real_acc += disc_real_acc.item()
            #running_disc_fake_acc += disc_fake_acc.item()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            # Logs
            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if hparams.disc_wt > 0.:
                running_perceptual_loss += perceptual_loss.item()
            else:
                running_perceptual_loss += 0.

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir,
                                global_epoch)
                save_checkpoint(disc,
                                disc_optimizer,
                                global_step,
                                checkpoint_dir,
                                global_epoch,
                                prefix='disc_')

            if global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader, device,
                                                   model, disc)
                    print('Average_sync_loss: ', average_sync_loss)
                    print('hparams.disc_wt: ', hparams.disc_wt)
                    print('hparams.syncnet_wt: ', hparams.syncnet_wt)
                    if average_sync_loss < .75:
                        hparams.set_hparam('syncnet_wt', 0.03)

            prog_bar.set_description(
                '[Train] Epoch {} - L1: {}, Sync: {}, Percep: {} | Loss Fake: {}, Real: {}'
                .format(global_epoch, running_l1_loss / (step + 1),
                        running_sync_loss / (step + 1),
                        running_perceptual_loss / (step + 1),
                        running_disc_fake_loss / (step + 1),
                        running_disc_real_loss / (step + 1)))

        global_epoch += 1
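
disc.perceptual_forward(g) above supplies the generator-side adversarial term. A minimal sketch of what it is assumed to compute: score the generated frames with the discriminator and apply a binary cross-entropy loss against the "real" label, pushing the generator to fool the discriminator.

import torch
import torch.nn.functional as F

def perceptual_forward_sketch(disc, fake_frames):
    # Hypothetical simplification of the discriminator method used above.
    pred = disc(fake_frames)  # probability that each sample is real, shape (B, 1)
    return F.binary_cross_entropy(
        pred, torch.ones((len(pred), 1), device=pred.device))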
Code Example #4
        print("Training postnet model")
    else:
        assert False, "must be specified wrong args"

    # Load preset if specified
    if preset is not None:
        with open(preset) as f:
            hparams.parse_json(f.read())
    # Override hyper parameters
    hparams.parse(args["--hparams"])

    # Prevent Windows-specific errors such as MemoryError
    # Also reduces the occurrence of the THAllocator.c 0x05 error in Windows builds of PyTorch
    if platform.system() == "Windows":
        print("Windows Detected - num_workers set to 1")
        hparams.set_hparam('num_workers', 1)

    assert hparams.name == "deepvoice3"
    print(hparams_debug_string())

    _frontend = getattr(frontend, hparams.frontend)

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Input dataset definitions
    X = FileSourceDataset(TextDataSource(data_root, speaker_id))
    Mel = FileSourceDataset(MelSpecDataSource(data_root, speaker_id))
    Y = FileSourceDataset(LinearSpecDataSource(data_root, speaker_id))

    # Prepare sampler
    frame_lengths = Mel.file_data_source.frame_lengths
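
This fragment relies on the deepvoice3_pytorch convention of loading a JSON preset and then overriding individual values with a comma-separated --hparams string. As an illustration (the file path and keys below are examples, not values taken from this snippet), the two override paths might be exercised like this:

from hparams import hparams  # the project's module-level HParams object (assumed import)

# presets/example.json (illustrative contents):
# {"name": "deepvoice3", "batch_size": 16, "num_workers": 2}
with open("presets/example.json") as f:
    hparams.parse_json(f.read())

# Command-line style overrides: comma-separated name=value pairs
hparams.parse("batch_size=8,nepochs=2000")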
Code Example #5
def train(args,
          device,
          model_Wav2Lip,
          train_data_loader_Left,
          test_data_loader_Left,
          optimizer,
          checkpoint_dir=None,
          checkpoint_interval=None,
          nepochs=None):
    n_img, loader = prepare_dataloader(args, 'train')  # CycleGAN
    val_n_img, val_loader = prepare_dataloader(args, 'val')  # CycleGAN
    model_cycle = MonoDepthArchitecture(args)  # Modified
    model_cycle.set_data_loader(loader)  # Modified

    global global_step, global_epoch  # Wav2Lip
    resumed_step = global_step  # Wav2Lip

    if not args.resume:  # CycleGAN
        best_val_loss = float('Inf')  # CycleGAN
        validate_cycle(-1)  # CycleGAN
        pre_validation_update(model_cycle.losses[-1]['val'])  # Modified
    else:
        best_val_loss = min([
            model_cycle.losses[epoch]['val']['G']
            for epoch in model_cycle.losses.keys()
        ])  # Modified

    running_val_loss = 0.0

    while global_epoch < nepochs:
        # Cycle GAN
        c_time = time.time()
        model_cycle.to_train()  # Modified
        model_cycle.set_new_loss_item(global_epoch)  # Modified

        model_cycle.run_epoch(global_epoch, n_img)  # Modified
        validate_cycle(global_epoch)  # Modified
        print_epoch_update(global_epoch,
                           time.time() - c_time,
                           model_cycle.losses)  # Modified

        # Make a checkpoint
        running_val_loss = model_cycle.losses[global_epoch]['val'][
            'G']  # Modified
        is_best = running_val_loss < best_val_loss

        if is_best:
            best_val_loss = running_val_loss

        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader_Left))

        for step, (x, indiv_mels, mel, gt_left) in prog_bar:  # Modified
            model_Wav2Lip.train()
            optimizer.zero_grad()

            # Move data to CUDA device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt_left = gt_left.to(device)  # Modified

            g_left = model_Wav2Lip(indiv_mels, x)  # Modified

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g_left)
            else:
                sync_loss = 0.

            l1loss = recon_loss(g_left, gt_left) + best_val_loss * (recon_loss(
                g_left, gt_left))  # Modified

            loss = hparams.syncnet_wt * sync_loss + (
                1 - hparams.syncnet_wt) * l1loss
            loss.backward()
            optimizer.step()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g_left, gt_left, global_step,
                                   checkpoint_dir)

            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model_Wav2Lip, optimizer, global_step,
                                checkpoint_dir, global_epoch)

            if global_step == 1 or global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader_Left,
                                                   global_step, device,
                                                   model_Wav2Lip,
                                                   checkpoint_dir,
                                                   best_val_loss)

                    if average_sync_loss < .75:
                        hparams.set_hparam(
                            'syncnet_wt', 0.01
                        )  # without image GAN a lesser weight is sufficient

            prog_bar.set_description('L1: {}, Sync Loss: {}'.format(
                running_l1_loss / (step + 1), running_sync_loss / (step + 1)))

        model_cycle.save_checkpoint(global_epoch, is_best, best_val_loss)

        global_epoch += 1

    print('Finished Training. Best validation loss:\t{:.3f}'.format(
        best_val_loss))
    model_cycle.save_networks('final')

    if running_val_loss != best_val_loss:
        model_cycle.save_best_networks()

    model_cycle.save_losses()
Code Example #6
File: eval.py Project: wyb330/nmt
def main(args, max_data_size=0, shuffle=True, display=False):
    hparams.set_hparam('batch_size', 10)
    hparams.add_hparam('is_training', False)
    check_vocab(args)
    datasets, src_data_size = load_dataset(args)
    iterator = iterator_utils.get_eval_iterator(hparams, datasets, hparams.eos, shuffle=shuffle)
    src_vocab, tgt_vocab, src_dataset, tgt_dataset, tgt_reverse_vocab, src_vocab_size, tgt_vocab_size = datasets
    hparams.add_hparam('vocab_size_source', src_vocab_size)
    hparams.add_hparam('vocab_size_target', tgt_vocab_size)

    sess, model = load_model(hparams, tf.contrib.learn.ModeKeys.EVAL, iterator, src_vocab, tgt_vocab, tgt_reverse_vocab)

    if args.restore_step:
        checkpoint_path = os.path.join(args.model_path, 'nmt.ckpt')
        ckpt = '%s-%d' % (checkpoint_path, args.restore_step)
    else:
        ckpt = tf.train.latest_checkpoint(args.model_path)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    if ckpt:
        saver.restore(sess, ckpt)
    else:
        raise Exception("cannot find checkpoint file")

    src_vocab_file = os.path.join(args.model_path, 'vocab.src')
    src_reverse_vocab = build_reverse_vocab_table(src_vocab_file, hparams)
    sess.run(tf.tables_initializer())

    step_count = 1
    with sess:
        logger.info("starting evaluating...")
        sess.run(iterator.initializer)
        eos = hparams.eos.encode()
        references = []
        translations = []
        start_time = time.time()
        while True:
            try:
                if (max_data_size > 0) and (step_count * hparams.batch_size > max_data_size):
                    break
                if step_count % 10 == 0:
                    t = time.time() - start_time
                    logger.info('step={0} total={1} time={2:.3f}'.format(step_count, step_count * hparams.batch_size, t))
                    start_time = time.time()
                predictions, source, target, source_text, confidence = model.eval(sess)
                reference = bpe2sent(target, eos)
                if hparams.beam_width == 1:
                    translation = bytes2sent(list(predictions), eos)
                else:
                    translation = bytes2sent(list(predictions[:, 0]), eos)

                for s, r, t in zip(source, reference, translation):
                    if display:
                        source_sent = src_reverse_vocab.lookup(tf.constant(list(s), tf.int64))
                        source_sent = sess.run(source_sent)
                        source_sent = text_utils.format_bpe_text(source_sent, eos)
                        print('{}\n{}\n{}\n'.format(source_sent, r, t))
                    references.append(r)
                    translations.append(t)

                if step_count % 100 == 0:
                    bleu_score = moses_multi_bleu(references, translations, args.model_path)
                    logger.info('bleu score = {0:.3f}'.format(bleu_score))

                step_count += 1
            except tf.errors.OutOfRangeError:
                logger.info('Done eval data')
                break

        logger.info('compute bleu score...')
        # bleu_score = compute_bleu_score(references, translations)
        bleu_score = moses_multi_bleu(references, translations, args.model_path)
        logger.info('bleu score = {0:.3f}'.format(bleu_score))
Code Example #7
File: hq_wav2lip_train.py Project: Honga1/Wav2Lip
def train(
    device,
    model,
    disc,
    train_data_loader,
    test_data_loader,
    optimizer,
    disc_optimizer,
    checkpoint_dir=None,
    checkpoint_interval=None,
    nepochs=None,
):
    global global_step, global_epoch
    resumed_step = global_step

    while global_epoch < nepochs:
        print("Starting Epoch: {}".format(global_epoch))
        running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = (
            0.0,
            0.0,
            0.0,
            0.0,
        )
        running_disc_real_loss, running_disc_fake_loss = 0.0, 0.0
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            disc.train()
            model.train()

            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            ### Train generator now. Remove ALL grads.
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.0:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.0

            if hparams.disc_wt > 0.0:
                perceptual_loss = disc.perceptual_forward(g)
            else:
                perceptual_loss = 0.0

            l1loss = recon_loss(g, gt)

            loss = (hparams.syncnet_wt * sync_loss +
                    hparams.disc_wt * perceptual_loss +
                    (1.0 - hparams.syncnet_wt - hparams.disc_wt) * l1loss)

            loss.backward()
            optimizer.step()

            ### Remove all gradients before Training disc
            disc_optimizer.zero_grad()

            pred = disc(gt)
            disc_real_loss = F.binary_cross_entropy(
                pred,
                torch.ones((len(pred), 1)).to(device))
            disc_real_loss.backward()

            pred = disc(g.detach())
            disc_fake_loss = F.binary_cross_entropy(
                pred,
                torch.zeros((len(pred), 1)).to(device))
            disc_fake_loss.backward()

            disc_optimizer.step()

            running_disc_real_loss += disc_real_loss.item()
            running_disc_fake_loss += disc_fake_loss.item()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            # Logs
            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.0:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.0

            if hparams.disc_wt > 0.0:
                running_perceptual_loss += perceptual_loss.item()
            else:
                running_perceptual_loss += 0.0

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir,
                                global_epoch)
                save_checkpoint(
                    disc,
                    disc_optimizer,
                    global_step,
                    checkpoint_dir,
                    global_epoch,
                    prefix="disc_",
                )

            if global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader,
                                                   global_step, device, model,
                                                   disc)

                    if average_sync_loss < 0.75:
                        hparams.set_hparam("syncnet_wt", 0.03)

            prog_bar.set_description(
                "L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}".format(
                    running_l1_loss / (step + 1),
                    running_sync_loss / (step + 1),
                    running_perceptual_loss / (step + 1),
                    running_disc_fake_loss / (step + 1),
                    running_disc_real_loss / (step + 1),
                ))

        global_epoch += 1
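
For context, a training entry point driving a loop like this one typically constructs the generator, discriminator, data loaders, and their Adam optimizers before calling train. A minimal sketch; the class names, hparam fields, and paths below are assumptions based on the Wav2Lip project, not taken from this snippet:

import torch
from torch import optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Wav2Lip().to(device)           # generator (assumed class name)
disc = Wav2Lip_disc_qual().to(device)  # visual-quality discriminator (assumed class name)

optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                       lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
                            lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))

train(device, model, disc, train_data_loader, test_data_loader,
      optimizer, disc_optimizer,
      checkpoint_dir="checkpoints/",                    # illustrative path
      checkpoint_interval=hparams.checkpoint_interval,  # assumed hparam name
      nepochs=hparams.nepochs)                          # assumed hparam name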