Example #1
def infer(args):
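    # Standalone sampling subgraph: feed an integer into 'samp_z_n' and fetch
    # 'samp_z' to draw that many latent vectors at inference time.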
    zgen_n = tf.placeholder(tf.int32, [], name='samp_z_n')
    zgen = tf.random.normal([zgen_n, Z_DIM], dtype=tf.float32, name='samp_z')

    z = tf.placeholder(tf.float32, [None, Z_DIM], name='z')
    with tf.variable_scope('G'):
        G = MelspecGANGenerator()
        G_z = G(z, training=False)
    G_z = feats_denorm(G_z)
    G_z = tf.identity(G_z, name='G_z')
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    step = tf.train.get_or_create_global_step()
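    # Only the generator variables (plus the global step) are needed for
    # inference, so the exported Saver omits the discriminator entirely.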
    saver = tf.train.Saver(var_list=G_vars + [step])

    tf.train.write_graph(tf.get_default_graph(), args.train_dir, 'infer.pbtxt')

    tf.train.export_meta_graph(filename=os.path.join(args.train_dir,
                                                     'infer.meta'),
                               clear_devices=True,
                               saver_def=saver.as_saver_def())

    tf.reset_default_graph()
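The exported infer.meta can later be reloaded to sample from the trained generator by tensor name. A minimal sketch, assuming TensorFlow 1.x and a hypothetical train_dir of 'train' that already contains a checkpoint saved during training:

import tensorflow as tf

# Hypothetical paths; point these at your own train_dir and checkpoint.
infer_metagraph_fp = 'train/infer.meta'
ckpt_fp = tf.train.latest_checkpoint('train')

tf.reset_default_graph()
saver = tf.train.import_meta_graph(infer_metagraph_fp)
graph = tf.get_default_graph()

with tf.Session() as sess:
    saver.restore(sess, ckpt_fp)

    # Draw 4 latent vectors with the sampling subgraph defined in infer().
    _z = sess.run(graph.get_tensor_by_name('samp_z:0'),
                  {graph.get_tensor_by_name('samp_z_n:0'): 4})

    # Generate denormalized mel spectrograms from those latents.
    _G_z = sess.run(graph.get_tensor_by_name('G_z:0'),
                    {graph.get_tensor_by_name('z:0'): _z})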
Example #2
def incept(args):
    incept_dir = os.path.join(args.train_dir, 'incept')
    if not os.path.isdir(incept_dir):
        os.makedirs(incept_dir)

    # Create GAN graph
    z = tf.placeholder(tf.float32, [None, Z_DIM])
    with tf.variable_scope('G'):
        G = MelspecGANGenerator()
        G_z = G(z, training=False)
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    step = tf.train.get_or_create_global_step()
    gan_saver = tf.train.Saver(var_list=G_vars + [step], max_to_keep=1)

    # Load or generate latents
    z_fp = os.path.join(incept_dir, 'z.pkl')
    if os.path.exists(z_fp):
        with open(z_fp, 'rb') as f:
            _zs = pickle.load(f)
    else:
        zs = tf.random.normal([args.incept_n, Z_DIM], dtype=tf.float32)
        with tf.Session() as sess:
            _zs = sess.run(zs)
        with open(z_fp, 'wb') as f:
            pickle.dump(_zs, f)

    # Load classifier graph
    incept_graph = tf.Graph()
    with incept_graph.as_default():
        incept_saver = tf.train.import_meta_graph(args.incept_metagraph_fp)
    incept_x = incept_graph.get_tensor_by_name('x:0')
    incept_preds = incept_graph.get_tensor_by_name('scores:0')
    incept_sess = tf.Session(graph=incept_graph)
    incept_saver.restore(incept_sess, args.incept_ckpt_fp)

    # Create summaries
    summary_graph = tf.Graph()
    with summary_graph.as_default():
        incept_mean = tf.placeholder(tf.float32, [])
        incept_std = tf.placeholder(tf.float32, [])
        summaries = [
            tf.summary.scalar('incept_mean', incept_mean),
            tf.summary.scalar('incept_std', incept_std)
        ]
        summaries = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter(incept_dir)

    # Loop, waiting for checkpoints
    ckpt_fp = None
    _best_score = 0.
    while True:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
            print('Incept: {}'.format(latest_ckpt_fp))

            sess = tf.Session()

            gan_saver.restore(sess, latest_ckpt_fp)

            _step = sess.run(step)

            _G_z_feats = []
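            # Run the generator in chunks of 100 latents to bound memory use.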
            for i in range(0, args.incept_n, 100):
                _G_z_feats.append(sess.run(G_z, {z: _zs[i:i + 100]}))
            _G_z_feats = np.concatenate(_G_z_feats, axis=0)
            _G_zs = []
            for i, _G_z in enumerate(_G_z_feats):
                _G_z = feats_denorm(_G_z).astype(np.float64)
                _audio = r9y9_melspec_to_waveform(_G_z,
                                                  fs=args.data_sample_rate,
                                                  waveform_len=16384)
                if i == 0:
                    out_fp = os.path.join(incept_dir,
                                          '{}.wav'.format(str(_step).zfill(9)))
                    save_as_wav(out_fp, args.data_sample_rate, _audio)
                _G_zs.append(_audio[:, 0, 0])

            _preds = []
            for i in range(0, args.incept_n, 100):
                _preds.append(
                    incept_sess.run(incept_preds,
                                    {incept_x: _G_zs[i:i + 100]}))
            _preds = np.concatenate(_preds, axis=0)

            # Split into k groups
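            # Inception score per split: exp(E_x[KL(p(y|x) || p(y))]), with
            # p(y) approximated by the mean prediction within the split.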
            _incept_scores = []
            split_size = args.incept_n // args.incept_k
            for i in range(args.incept_k):
                _split = _preds[i * split_size:(i + 1) * split_size]
                _kl = _split * (np.log(_split) -
                                np.log(np.expand_dims(np.mean(_split, 0), 0)))
                _kl = np.mean(np.sum(_kl, 1))
                _incept_scores.append(np.exp(_kl))

            _incept_mean = np.mean(_incept_scores)
            _incept_std = np.std(_incept_scores)

            # Summarize
            with tf.Session(graph=summary_graph) as summary_sess:
                _summaries = summary_sess.run(summaries, {
                    incept_mean: _incept_mean,
                    incept_std: _incept_std
                })
            summary_writer.add_summary(_summaries, _step)

            # Save
            if _incept_mean > _best_score:
                gan_saver.save(sess, os.path.join(incept_dir, 'best_score'),
                               _step)
                _best_score = _incept_mean

            sess.close()

            print('Done')

            ckpt_fp = latest_ckpt_fp

        time.sleep(1)

    # Not reached in practice: the polling loop above runs indefinitely.
    incept_sess.close()
Example #3
def train(fps, args):
    # Load data
    with tf.name_scope('loader'):
        x, x_audio = decode_extract_and_batch(
            fps=fps,
            batch_size=TRAIN_BATCH_SIZE,
            slice_len=64,
            audio_fs=args.data_sample_rate,
            audio_mono=True,
            audio_normalize=args.data_normalize,
            decode_fastwav=args.data_fastwav,
            decode_parallel_calls=8,
            extract_type='melspec',
            extract_nfft=1024,
            extract_nhop=256,
            extract_parallel_calls=8,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=512,
            slice_first_only=args.data_slice_first_only,
            slice_randomize_offset=args.data_slice_randomize_offset,
            slice_overlap_ratio=args.data_slice_overlap_ratio,
            slice_pad_end=args.data_slice_pad_end,
            prefetch_size=TRAIN_BATCH_SIZE * 8,
            prefetch_gpu_num=args.data_prefetch_gpu_num)
        x = feats_norm(x)

    # Data summaries
    tf.summary.audio('x_audio', x_audio[:, :, 0], args.data_sample_rate)
    tf.summary.image('x', feats_to_uint8_img(feats_denorm(x)))
    tf.summary.audio(
        'x_inv_audio',
        feats_to_approx_audio(feats_denorm(x),
                              args.data_sample_rate,
                              16384,
                              n=3)[:, :, 0], args.data_sample_rate)

    # Make z vector
    z = tf.random.normal([TRAIN_BATCH_SIZE, Z_DIM], dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G = MelspecGANGenerator()
        G_z = G(z, training=True)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Summarize G_z
    tf.summary.image('G_z', feats_to_uint8_img(feats_denorm(G_z)))
    tf.summary.audio(
        'G_z_inv_audio',
        feats_to_approx_audio(feats_denorm(G_z),
                              args.data_sample_rate,
                              16384,
                              n=3)[:, :, 0], args.data_sample_rate)

    # Make real discriminator
    D = MelspecGANDiscriminator()
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = D(x, training=True)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = D(G_z, training=True)

    # Create loss
    num_disc_updates_per_genr = 1
    if TRAIN_LOSS == 'dcgan':
        fake = tf.zeros([TRAIN_BATCH_SIZE], dtype=tf.float32)
        real = tf.ones([TRAIN_BATCH_SIZE], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif TRAIN_LOSS == 'wgangp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[TRAIN_BATCH_SIZE, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = D(interpolates, training=True)

        LAMBDA = 10
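        # Gradient penalty (WGAN-GP): encourage the critic's gradient norm
        # along random real/fake interpolates to stay close to 1.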
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty

        num_disc_updates_per_genr = 5
    else:
        raise ValueError('Unknown TRAIN_LOSS: {}'.format(TRAIN_LOSS))

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create opt
    if TRAIN_LOSS == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif TRAIN_LOSS == 'wgangp':
        # TODO: some igul code uses beta1=0.
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise ValueError('Unknown TRAIN_LOSS: {}'.format(TRAIN_LOSS))

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Train
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_ckpt_every_nsecs,
            save_summaries_secs=args.train_summary_every_nsecs) as sess:
        while not sess.should_stop():
            for i in range(num_disc_updates_per_genr):
                sess.run(D_train_op)

            sess.run(G_train_op)
Example #4
import numpy as np
import tensorflow as tf

from conv2d import MelspecGANGenerator, MelspecGANDiscriminator

from util import feats_denorm

Z_DIM = 100
OUT_FP = 'infer.meta'

zgen_n = tf.placeholder(tf.int32, [], name='samp_z_n')
zgen = tf.random.normal([zgen_n, Z_DIM], dtype=tf.float32, name='samp_z')

z = tf.placeholder(tf.float32, [None, Z_DIM], name='z')
with tf.variable_scope('G'):
    G = MelspecGANGenerator()
    G_z = G(z, training=False)
G_z = feats_denorm(G_z)
G_z = tf.identity(G_z, name='G_z')
print(G_z)
G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
print(G_vars)
step = tf.train.get_or_create_global_step()
saver = tf.train.Saver(var_list=G_vars + [step])

tf.train.write_graph(tf.get_default_graph(), './', 'infer.pbtxt')

tf.train.export_meta_graph(filename=OUT_FP,
                           clear_devices=True,
                           saver_def=saver.as_saver_def())