Пример #1
0
def train(fps, args):
    """Train an Advoc vocoder model.

    Picks ``Advoc`` or ``AdvocSmall`` according to ``args.model_type``,
    applies ``args.model_overrides``, builds an input pipeline of
    magnitude-spectrogram / waveform batches from ``fps``, wires the model
    graph on top of it and then calls ``model.train_loop`` repeatedly inside
    a ``tf.train.MonitoredTrainingSession``.

    Args:
      fps: Audio file paths handed to ``decode_extract_and_batch``.
      args: Parsed command-line namespace (model, data and training options).

    Raises:
      NotImplementedError: If ``args.model_type`` is neither ``"regular"``
        nor ``"small"``.
    """
    # Select the model class; the for/else raises when no name matches.
    for type_name, model_cls in (("regular", Advoc), ("small", AdvocSmall)):
        if args.model_type == type_name:
            model = model_cls(Modes.TRAIN)
            break
    else:
        raise NotImplementedError()

    model, summary = override_model_attrs(model, args.model_overrides)
    model.audio_fs = args.data_sample_rate

    rule = '-' * 80
    print(rule)
    print(summary)
    print(rule)

    # Input pipeline (repeat=True, shuffle=True) yielding magnitude
    # spectrograms and the matching waveforms.
    with tf.name_scope('loader'):
        x_magspec, x_wav = decode_extract_and_batch(
            fps,
            batch_size=model.train_batch_size,
            slice_len=model.subseq_len,
            audio_fs=model.audio_fs,
            audio_mono=True,
            audio_normalize=args.data_normalize,
            decode_fastwav=args.data_fastwav,
            decode_parallel_calls=4,
            extract_type='magspec',
            extract_parallel_calls=8,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=512,
            slice_first_only=args.data_slice_first_only,
            slice_randomize_offset=args.data_slice_randomize_offset,
            slice_overlap_ratio=args.data_slice_overlap_ratio,
            slice_pad_end=args.data_slice_pad_end,
            prefetch_size=model.train_batch_size * 8,
            prefetch_gpu_num=0)

    # Mel-linear spectrogram plus its 'inverse' projection back to a
    # magnitude spectrogram; both are fed to the model with the targets.
    spec_util = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)
    mel_spec = spec_util.mag_to_mel_linear_spec(x_magspec)
    approx_magspec = spec_util.mel_linear_to_mag_spec(mel_spec,
                                                      transform='inverse')

    model(approx_magspec, x_magspec, x_wav, mel_spec)

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_ckpt_every_nsecs,
            save_summaries_secs=args.train_summary_every_nsecs) as sess:
        step = 0
        # The value returned by train_loop is treated as the step counter
        # and compared against args.max_steps.
        while True:
            if sess.should_stop() or step >= args.max_steps:
                break
            step = model.train_loop(sess)

    print("Done!")
Пример #2
0
def train(fps, args):
    """Train a mel-spectrogram GAN (MelspecGAN generator/discriminator).

    Streams mel-spectrogram slices (passed through ``feats_norm``) and the
    matching audio from ``fps``, builds the generator ``G`` and the
    discriminator ``D``, and optimizes them with the loss selected by the
    module-level ``TRAIN_LOSS``:

    * ``'dcgan'``  -- sigmoid cross-entropy GAN loss, Adam(2e-4, beta1=0.5),
      one discriminator update per generator update.
    * ``'wgangp'`` -- Wasserstein loss with a gradient penalty (LAMBDA=10),
      Adam(1e-4, beta1=0.5, beta2=0.9), five discriminator updates per
      generator update.

    Training runs in a ``tf.train.MonitoredTrainingSession`` until the
    session reports ``should_stop``.

    Args:
      fps: Audio file paths handed to ``decode_extract_and_batch``.
      args: Parsed command-line namespace (data and training options).

    Raises:
      ValueError: If ``TRAIN_LOSS`` is neither ``'dcgan'`` nor ``'wgangp'``.
    """
    # Load data
    with tf.name_scope('loader'):
        x, x_audio = decode_extract_and_batch(
            fps=fps,
            batch_size=TRAIN_BATCH_SIZE,
            slice_len=64,
            audio_fs=args.data_sample_rate,
            audio_mono=True,
            audio_normalize=args.data_normalize,
            decode_fastwav=args.data_fastwav,
            decode_parallel_calls=8,
            extract_type='melspec',
            extract_nfft=1024,
            extract_nhop=256,
            extract_parallel_calls=8,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=512,
            slice_first_only=args.data_slice_first_only,
            slice_randomize_offset=args.data_slice_randomize_offset,
            slice_overlap_ratio=args.data_slice_overlap_ratio,
            slice_pad_end=args.data_slice_pad_end,
            prefetch_size=TRAIN_BATCH_SIZE * 8,
            prefetch_gpu_num=args.data_prefetch_gpu_num)
        x = feats_norm(x)

    # Data summaries
    tf.summary.audio('x_audio', x_audio[:, :, 0], args.data_sample_rate)
    tf.summary.image('x', feats_to_uint8_img(feats_denorm(x)))
    tf.summary.audio(
        'x_inv_audio',
        feats_to_approx_audio(feats_denorm(x),
                              args.data_sample_rate,
                              16384,
                              n=3)[:, :, 0], args.data_sample_rate)

    # Make z vector
    z = tf.random.normal([TRAIN_BATCH_SIZE, Z_DIM], dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G = MelspecGANGenerator()
        G_z = G(z, training=True)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Summarize G_z
    tf.summary.image('G_z', feats_to_uint8_img(feats_denorm(G_z)))
    tf.summary.audio(
        'G_z_inv_audio',
        feats_to_approx_audio(feats_denorm(G_z),
                              args.data_sample_rate,
                              16384,
                              n=3)[:, :, 0], args.data_sample_rate)

    # Make real discriminator
    D = MelspecGANDiscriminator()
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = D(x, training=True)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Make fake discriminator (shares D's variables)
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = D(G_z, training=True)

    # Create loss. The matching Adam hyperparameters are chosen here too, so
    # TRAIN_LOSS is dispatched on in exactly one place.
    num_disc_updates_per_genr = 1
    if TRAIN_LOSS == 'dcgan':
        fake = tf.zeros([TRAIN_BATCH_SIZE], dtype=tf.float32)
        real = tf.ones([TRAIN_BATCH_SIZE], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.

        adam_kwargs = dict(learning_rate=2e-4, beta1=0.5)
    elif TRAIN_LOSS == 'wgangp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Gradient penalty evaluated at random points on the segments
        # between real and generated examples (one alpha per example).
        alpha = tf.random_uniform(shape=[TRAIN_BATCH_SIZE, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = D(interpolates, training=True)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        # Per-example L2 norm of the gradient over all non-batch axes.
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty

        num_disc_updates_per_genr = 5

        # TODO: some igul code uses beta1=0.
        adam_kwargs = dict(learning_rate=1e-4, beta1=0.5, beta2=0.9)
    else:
        raise ValueError('Unknown TRAIN_LOSS: {!r}'.format(TRAIN_LOSS))

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create opt
    G_opt = tf.train.AdamOptimizer(**adam_kwargs)
    D_opt = tf.train.AdamOptimizer(**adam_kwargs)

    # Create training ops; only the generator step advances the global step.
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Train
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_ckpt_every_nsecs,
            save_summaries_secs=args.train_summary_every_nsecs) as sess:
        while not sess.should_stop():
            for _ in range(num_disc_updates_per_genr):
                sess.run(D_train_op)

            sess.run(G_train_op)
Пример #3
0
def eval(fps, args):
    """Continuously evaluate the newest Advoc generator checkpoint.

    Polls ``args.train_dir`` once per second. Whenever a new checkpoint
    appears, it does the following:

    * restores the generator variables and the global step;
    * runs the whole (``repeat=False``) evaluation set through the generator;
    * logs the mean of the per-batch L1 errors between generated and target
      magnitude spectrograms as the ``gen_loss_L1`` summary;
    * saves the checkpoint as ``<eval_dir>/best_gen_loss_l1`` when that loss
      is lower than the best value seen so far by this process.

    This function never returns.

    Args:
      fps: Audio file paths handed to ``decode_extract_and_batch``.
      args: Parsed command-line namespace (model, data and eval options).

    Raises:
      NotImplementedError: If ``args.model_type`` or the model's
        ``generator_type`` is not recognized.
    """
    if args.eval_dataset_name is not None:
        eval_dir = os.path.join(args.train_dir,
                                'eval_{}'.format(args.eval_dataset_name))
    else:
        eval_dir = os.path.join(args.train_dir, 'eval_valid')
    if not os.path.isdir(eval_dir):
        os.makedirs(eval_dir)

    if args.model_type == "regular":
        model = Advoc(Modes.EVAL)
    elif args.model_type == "small":
        model = AdvocSmall(Modes.EVAL)
    else:
        raise NotImplementedError()

    model, summary = override_model_attrs(model, args.model_overrides)
    model.audio_fs = args.data_sample_rate

    print('-' * 80)
    print(summary)
    print('-' * 80)

    with tf.name_scope('loader'):
        x_magspec, x_wav = decode_extract_and_batch(
            fps,
            batch_size=model.eval_batch_size,
            slice_len=model.subseq_len,
            audio_fs=model.audio_fs,
            audio_mono=True,
            audio_normalize=args.data_normalize,
            decode_fastwav=args.data_fastwav,
            decode_parallel_calls=4,
            extract_type='magspec',
            extract_parallel_calls=8,
            repeat=False,
            shuffle=False,
            shuffle_buffer_size=None,
            slice_first_only=args.data_slice_first_only,
            slice_randomize_offset=False,
            slice_overlap_ratio=0.,
            slice_pad_end=True,
            prefetch_size=None,
            prefetch_gpu_num=None)

    spectral = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)
    x_melspec = spectral.mag_to_mel_linear_spec(x_magspec)
    x_inverted_magspec = spectral.mel_linear_to_mag_spec(x_melspec,
                                                         transform='inverse')

    with tf.variable_scope("generator") as vs:
        if model.generator_type == "pix2pix":
            gen_magspec = model.build_generator(x_inverted_magspec)
        elif model.generator_type == "linear":
            gen_magspec = model.build_linear_generator(x_inverted_magspec)
        # NOTE(review): despite the "+pix2pix" names, both branches below
        # call build_linear_generator for the second stage; confirm this
        # matches the graph built at training time, or the restore below
        # will not line up with the checkpoint's variables.
        elif model.generator_type == "linear+pix2pix":
            _temp_spec = model.build_linear_generator(x_melspec)
            gen_magspec = model.build_linear_generator(_temp_spec)
        elif model.generator_type == "interp+pix2pix":
            _temp_spec = tf.image.resize_images(x_melspec,
                                                [model.subseq_len, 513])
            gen_magspec = model.build_linear_generator(_temp_spec)
        else:
            raise NotImplementedError()

        G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope=vs.name)

    gen_loss_L1 = tf.reduce_mean(tf.abs(x_magspec - gen_magspec))
    gan_step = tf.train.get_or_create_global_step()
    # Only generator variables (plus the step) are restored and saved.
    gan_saver = tf.train.Saver(var_list=G_vars + [gan_step], max_to_keep=1)

    # Per-batch losses are collected in Python and averaged in-graph for the
    # summary.
    all_gen_loss_L1 = tf.placeholder(tf.float32, [None])

    summaries = [
        tf.summary.scalar('gen_loss_L1', tf.reduce_mean(all_gen_loss_L1)),
    ]
    summaries = tf.summary.merge(summaries)

    # Create summary writer
    summary_writer = tf.summary.FileWriter(eval_dir)
    ckpt_fp = None
    _best_gen_loss_l1 = np.inf

    while True:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
            ckpt_fp = latest_ckpt_fp
            print('Evaluating {}'.format(ckpt_fp))

            with tf.Session() as sess:
                gan_saver.restore(sess, latest_ckpt_fp)
                _step = sess.run(gan_step)

                # Drain the non-repeating eval set. Only the scalar loss is
                # fetched; the spectrograms are still computed in-graph but
                # are not copied back to Python.
                _all_gen_loss_L1 = []
                while True:
                    try:
                        _gen_loss_L1 = sess.run(gen_loss_L1)
                    except tf.errors.OutOfRangeError:
                        break
                    _all_gen_loss_L1.append(_gen_loss_L1)

                if not _all_gen_loss_L1:
                    # Avoid writing a NaN summary for an empty eval set.
                    print('No evaluation batches for {}; skipping.'.format(
                        ckpt_fp))
                else:
                    _all_gen_loss_L1 = np.array(_all_gen_loss_L1)

                    _summaries = sess.run(summaries, {
                        all_gen_loss_L1: _all_gen_loss_L1,
                    })
                    summary_writer.add_summary(_summaries, _step)
                    # The loop never exits, so flush explicitly to make the
                    # summary visible promptly.
                    summary_writer.flush()
                    _gen_loss_L1_np = np.mean(_all_gen_loss_L1)

                    if _gen_loss_L1_np < _best_gen_loss_l1:
                        # Remember the new best so later checkpoints must
                        # actually improve on it to be saved.
                        _best_gen_loss_l1 = _gen_loss_L1_np
                        gan_saver.save(sess,
                                       os.path.join(eval_dir,
                                                    'best_gen_loss_l1'),
                                       _step)
                        print("Saved best gen loss l1!")
            print('Done!')
        time.sleep(1)
Пример #4
0
def infer(fps, args):
    """Write audio summaries produced by a trained Advoc generator.

    The function runs in one of two modes:

    * If ``args.infer_ckpt_path`` is set, it restores that checkpoint, runs
      every batch of the (``repeat=False``) inference set once, and returns.
    * Otherwise it polls ``args.train_dir`` once per second and does the same
      for every new checkpoint. In this mode it never returns.

    For each batch, the first example is converted to audio with
    ``spectral.audio_from_mag_spec`` for three spectrograms: the
    mel-inverted input, the target, and the generator output. These are
    written as audio summaries, together with the loaded waveform, under
    ``<train_dir>/infer_<name>``.

    Args:
      fps: Audio file paths handed to ``decode_extract_and_batch``.
      args: Parsed command-line namespace (model, data and infer options).

    Raises:
      NotImplementedError: If ``args.model_type`` or the model's
        ``generator_type`` is not recognized.
    """
    if args.infer_dataset_name is not None:
        infer_dir = os.path.join(args.train_dir,
                                 'infer_{}'.format(args.infer_dataset_name))
    else:
        infer_dir = os.path.join(args.train_dir, 'infer_valid')
    if not os.path.isdir(infer_dir):
        os.makedirs(infer_dir)

    if args.model_type == "regular":
        model = Advoc(Modes.INFER)
    elif args.model_type == "small":
        model = AdvocSmall(Modes.INFER)
    else:
        raise NotImplementedError()

    model, summary = override_model_attrs(model, args.model_overrides)
    model.audio_fs = args.data_sample_rate

    print('-' * 80)
    print(summary)
    print('-' * 80)

    # NOTE(review): unlike eval(), slice_first_only is not passed here, so
    # decode_extract_and_batch's default applies — confirm this is intended.
    with tf.name_scope('loader'):
        x_magspec, x_wav = decode_extract_and_batch(
            fps,
            batch_size=args.infer_batch_size,
            slice_len=model.subseq_len,
            audio_fs=model.audio_fs,
            audio_mono=True,
            audio_normalize=args.data_normalize,
            decode_fastwav=args.data_fastwav,
            decode_parallel_calls=4,
            extract_type='magspec',
            extract_parallel_calls=8,
            repeat=False,
            shuffle=False,
            shuffle_buffer_size=None,
            slice_randomize_offset=False,
            slice_overlap_ratio=0.,
            slice_pad_end=True,
            prefetch_size=None,
            prefetch_gpu_num=None)

    spectral = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)
    x_melspec = spectral.mag_to_mel_linear_spec(x_magspec)
    x_inverted_magspec = spectral.mel_linear_to_mag_spec(x_melspec,
                                                         transform='inverse')

    with tf.variable_scope("generator") as vs:
        if model.generator_type == "pix2pix":
            gen_magspec = model.build_generator(x_inverted_magspec)
        elif model.generator_type == "linear":
            gen_magspec = model.build_linear_generator(x_inverted_magspec)
        # NOTE(review): despite the "+pix2pix" names, both branches below
        # call build_linear_generator for the second stage; confirm this
        # matches the graph built at training time.
        elif model.generator_type == "linear+pix2pix":
            _temp_spec = model.build_linear_generator(x_melspec)
            gen_magspec = model.build_linear_generator(_temp_spec)
        elif model.generator_type == "interp+pix2pix":
            _temp_spec = tf.image.resize_images(x_melspec,
                                                [model.subseq_len, 513])
            gen_magspec = model.build_linear_generator(_temp_spec)
        else:
            raise NotImplementedError()
        G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope=vs.name)

    step = tf.train.get_or_create_global_step()
    # Only generator variables (plus the step) are restored.
    gan_saver = tf.train.Saver(var_list=G_vars + [step], max_to_keep=1)

    # Convert the first example of each spectrogram batch back to audio.
    input_audio = tf.py_func(spectral.audio_from_mag_spec,
                             [x_inverted_magspec[0]],
                             tf.float32,
                             stateful=False)
    target_audio = tf.py_func(spectral.audio_from_mag_spec, [x_magspec[0]],
                              tf.float32,
                              stateful=False)
    gen_audio = tf.py_func(spectral.audio_from_mag_spec, [gen_magspec[0]],
                           tf.float32,
                           stateful=False)

    # Reshape each signal to [1, T, 1, 1] so it can be sliced with
    # [:, :, 0, :] in the same way as x_wav below (kept from earlier code).
    input_audio = tf.reshape(input_audio, [1, -1, 1, 1])
    target_audio = tf.reshape(target_audio, [1, -1, 1, 1])
    gen_audio = tf.reshape(gen_audio, [1, -1, 1, 1])

    summaries = [
        tf.summary.audio('infer_x_wav', x_wav[:, :, 0, :], model.audio_fs),
        tf.summary.audio('infer_gen_audio', gen_audio[:, :, 0, :],
                         model.audio_fs),
        tf.summary.audio('target_audio', target_audio[:, :, 0, :],
                         model.audio_fs),
        tf.summary.audio('infer_input_audio', input_audio[:, :, 0, :],
                         model.audio_fs)
    ]

    summaries = tf.summary.merge(summaries)
    # Create summary writer
    summary_writer = tf.summary.FileWriter(infer_dir)

    def _infer_from_ckpt(ckpt_fp):
        # Restore ckpt_fp in a fresh session and write summaries for every
        # batch of the dataset. Only the merged summary is fetched; the
        # spectrograms it depends on are computed in the same run.
        with tf.Session() as sess:
            gan_saver.restore(sess, ckpt_fp)
            _step = sess.run(step)
            while True:
                try:
                    _summaries = sess.run(summaries)
                except tf.errors.OutOfRangeError:
                    break
                summary_writer.add_summary(_summaries, _step)
            summary_writer.flush()
            print('Done!')

    if args.infer_ckpt_path is not None:
        # Inferring from a particular checkpoint, then return normally.
        ckpt_fp = args.infer_ckpt_path
        print('Infereing From {}'.format(ckpt_fp))
        _infer_from_ckpt(ckpt_fp)
        return

    # Continuous inference: open a session only when a new checkpoint
    # appears, rather than once per polling interval.
    ckpt_fp = None
    while True:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
            ckpt_fp = latest_ckpt_fp
            print('Infereing From {}'.format(ckpt_fp))
            _infer_from_ckpt(ckpt_fp)
        time.sleep(1)