Example #1
def generate_real(fps2, args):
    """Run every file in fps2 through the restored G_real generator, 64 at a time."""
    no_samples = len(fps2)
    no_set = no_samples // 64
    no_remain = no_samples % 64

    # NOTE: the loader is built with args.train_batch_size, but the write loop
    # below assumes batches of exactly 64 files.
    for k in range(no_set + 1):
        if k == no_set:
            if no_remain == 0:
                break  # sample count is an exact multiple of 64; nothing left
            print("k is ", k)
            # Re-slice the final 64 files so the last batch stays full.
            s_fps2 = fps2[no_samples - 64:no_samples]
        else:
            s_fps2 = fps2[64 * k:64 * (k + 1)]

        # Rebuild the graph and restore the checkpoint for every batch.
        tf.reset_default_graph()
        saver = tf.train.import_meta_graph('infer.meta')
        graph = tf.get_default_graph()
        sess = tf.InteractiveSession()
        saver.restore(sess, 'model.ckpt')

        with tf.name_scope('samp_x_synthetic'):
            x_synthetic = loader.decode_extract_and_batch(
                s_fps2,
                batch_size=args.train_batch_size,
                slice_len=args.data2_slice_len,
                decode_fs=args.data2_sample_rate,
                decode_num_channels=args.data2_num_channels,
                decode_fast_wav=args.data2_fast_wav,
                decode_parallel_calls=4,
                slice_randomize_offset=False if args.data2_first_slice else True,
                slice_first_only=args.data2_first_slice,
                slice_overlap_ratio=0. if args.data2_first_slice else args.data2_overlap_ratio,
                slice_pad_end=True if args.data2_first_slice else args.data2_pad_end,
                repeat=True,
                shuffle=True,
                shuffle_buffer_size=4096,
                prefetch_size=args.train_batch_size * 4,
                prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0]

        # Evaluate in the restored session (not a fresh tf.Session()).
        _x_synthetic = x_synthetic.eval(session=sess)
        print("input ", len(_x_synthetic), len(_x_synthetic[1]))
        # _z = (np.random.rand(1000, 100) * 2.) - 1
        # Synthesize G(z)
        x_synthetic = graph.get_tensor_by_name('x_synthetic:0')
        x_real = graph.get_tensor_by_name('x_real:0')
        G_real = graph.get_tensor_by_name('G_real_x:0')
        _G_real = sess.run(G_real, {x_synthetic: _x_synthetic, x_real: _x_synthetic})
        print("G_S", len(_G_real), len(_G_real[1]))
        for i in range(64):
            print("i ", i)
            wav = _G_real[i][0:16000]
            name = 'IRs_for_GAN/' + s_fps2[i]
            print("name ", name)
            # librosa.output.write_wav was removed in librosa 0.8; this targets < 0.8.
            librosa.output.write_wav(path=name, y=wav, sr=16000)
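
# --- Hypothetical usage sketch for generate_real() (not part of the original
# snippet). The attribute names mirror what the loader call above reads; the
# values and the input glob are assumptions, and the 'IRs_for_GAN/' output
# tree must already exist.
if __name__ == '__main__':
    import argparse
    import glob

    demo_args = argparse.Namespace(
        train_batch_size=64,
        data2_slice_len=16384,
        data2_sample_rate=16000,
        data2_num_channels=1,
        data2_fast_wav=True,
        data2_first_slice=True,
        data2_overlap_ratio=0.,
        data2_pad_end=True,
        data1_prefetch_gpu_num=0)
    fps2 = sorted(glob.glob('synthetic_IRs/*.wav'))
    generate_real(fps2, demo_args)
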
def moments(fps, args):
    with tf.name_scope('loader'):
        x_wav = loader.decode_extract_and_batch(
            fps,
            batch_size=1,
            slice_len=_SLICE_LEN,
            decode_fs=args.data_sample_rate,
            decode_num_channels=1,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0.
            if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=False,
            shuffle=False,
            shuffle_buffer_size=0,
            prefetch_size=4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[0, :, 0, 0]

    X = tf.contrib.signal.stft(x_wav, 256, 128, pad_end=True)
    X_mag = tf.abs(X)
    X_lmag = tf.log(X_mag + _LOG_EPS)

    _X_lmags = []
    with tf.Session() as sess:
        while True:
            try:
                _X_lmag = sess.run(X_lmag)
            except tf.errors.OutOfRangeError:
                # Input pipeline (repeat=False) is exhausted.
                break

            _X_lmags.append(_X_lmag)

    _X_lmags = np.concatenate(_X_lmags, axis=0)
    mean, std = np.mean(_X_lmags, axis=0), np.std(_X_lmags, axis=0)

    with open(args.data_moments_fp, 'wb') as f:
        pickle.dump((mean, std), f)
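
# --- Sketch of consuming the moments file written above (an assumption, not
# part of the original snippet): standardize log-magnitude frames per
# frequency bin and clip at 3 standard deviations, SpecGAN-style.
def load_and_normalize_lmag(X_lmag, moments_fp, clip_nstd=3.):
    import pickle
    with open(moments_fp, 'rb') as f:
        mean, std = pickle.load(f)
    X_norm = (X_lmag - mean) / std
    return tf.clip_by_value(X_norm, -clip_nstd, clip_nstd) / clip_nstd
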
def train(fps, args):
    with tf.name_scope('loader'):
        x = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=args.data_slice_len,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0.
            if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

    # Make z vector
    def random_c():
        idxs = np.random.randint(args.num_categ, size=args.train_batch_size)
        c = np.zeros((args.train_batch_size, args.num_categ))
        c[np.arange(args.train_batch_size), idxs] = 1
        return c

    def random_z():
        rz = np.zeros([args.train_batch_size, args.wavegan_latent_dim])
        rz[:, :args.num_categ] = random_c()
        rz[:, args.num_categ:] = np.random.uniform(
            -1.,
            1.,
            size=(args.train_batch_size,
                  args.wavegan_latent_dim - args.num_categ))
        return rz

    z = tf.placeholder(tf.float32,
                       (args.train_batch_size, args.wavegan_latent_dim))

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Make Q
    with tf.variable_scope('Q'):
        Q_G_z = WaveGANQ(G_z, **args.wavegan_q_kwargs)
    Q_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Q')

    # Print Q summary
    print('Q vars')
    nparams = 0
    for v in Q_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':

        def q_cost_tf(z, q):
            z_cat = z[:, :args.num_categ]
            q_cat = q[:, :args.num_categ]
            lcat = tf.nn.softmax_cross_entropy_with_logits(labels=z_cat,
                                                           logits=q_cat)
            return tf.reduce_mean(lcat)

        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        Q_loss = q_cost_tf(z, Q_G_z)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates,
                                            **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)
    # Q_loss (and Q_opt / Q_train_op below) exist only for the 'wgan-gp'
    # branch; this snippet assumes that loss is selected.
    tf.summary.scalar('Q_loss', Q_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        Q_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)
    Q_train_op = Q_opt.minimize(Q_loss, var_list=Q_vars + G_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print(
            'Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
            .format(args.train_dir))
        while True:
            # Train discriminator
            for i in xrange(args.wavegan_disc_nupdates):
                sess.run([D_loss, D_train_op], feed_dict={z: random_z()})

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run([G_loss, Q_loss, G_train_op, Q_train_op],
                     feed_dict={z: random_z()})
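
# --- Quick numpy check (not in the original) of the latent layout random_z()
# builds above: the first num_categ dims carry a one-hot InfoGAN code, the
# remaining dims are uniform noise in [-1, 1).
import numpy as np

_batch, _latent_dim, _num_categ = 4, 100, 10
_idxs = np.random.randint(_num_categ, size=_batch)
_z = np.zeros((_batch, _latent_dim))
_z[np.arange(_batch), _idxs] = 1
_z[:, _num_categ:] = np.random.uniform(
    -1., 1., size=(_batch, _latent_dim - _num_categ))
assert (_z[:, :_num_categ].sum(axis=1) == 1).all()
assert np.abs(_z[:, _num_categ:]).max() <= 1.
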
Example #4
def train(fps, args):
    with tf.name_scope('loader'):
        x = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=args.data_slice_len,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0.
            if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates,
                                            **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    # Load adversarial input
    fs, audio = wavread(args.adv_input)
    assert fs == args.data_sample_rate
    assert audio.dtype == np.float32
    assert len(audio.shape) == 1

    # Synthesis
    if audio.shape[0] < args.data_slice_len:
        audio = np.pad(audio, (0, args.data_slice_len - audio.shape[0]),
                       'constant')
    adv_input = tf.constant(
        audio[:args.data_slice_len], dtype=np.float32
    ) + args.adv_magnitude * tf.reshape(G_z,
                                        G_z.get_shape().as_list()[:-1])

    # Calculate MFCCs
    spectrograms = tf.abs(
        tf.signal.stft(adv_input, frame_length=320, frame_step=160))
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        40, spectrograms.shape[-1].value, fs, 20, 4000)
    mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix,
                                    1)
    mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
        linear_to_mel_weight_matrix.shape[-1:]))
    log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)
    mfccs = tf.expand_dims(
        tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms)[
            -1, :99, :40], -1)

    # Load a model for speech command classification
    with tf.gfile.FastGFile(args.adv_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        with tf.variable_scope('Speech'):
            adv_logits, = tf.import_graph_def(graph_def,
                                              input_map={'Mfcc:0': mfccs},
                                              return_elements=['add_2:0'])

    # Load labels for speech command classification
    adv_labels = [line.rstrip() for line in tf.gfile.GFile(args.adv_label)]
    adv_index = adv_labels.index(args.adv_target)

    # Make adversarial loss
    # Came from: https://github.com/carlini/nn_robust_attacks/blob/master/l2_attack.py
    adv_targets = tf.one_hot(
        tf.constant([adv_index] * args.train_batch_size, dtype=tf.int32),
        len(adv_labels))
    adv_target_logit = tf.reduce_sum(adv_targets * adv_logits, 1)
    adv_others_logit = tf.reduce_max(
        (1 - adv_targets) * adv_logits - (adv_targets * 10000), 1)

    adv_loss = tf.reduce_mean(
        tf.maximum(0.0,
                   adv_others_logit - adv_target_logit + args.adv_confidence))

    # Summarize audios
    tf.summary.audio('adv_input',
                     adv_input,
                     fs,
                     max_outputs=args.adv_max_outputs)
    tf.summary.scalar('adv_loss', adv_loss)
    tf.summary.histogram('adv_classes', tf.argmax(adv_logits, axis=1))

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss + args.adv_lambda * adv_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            config=config,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print(
            'Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
            .format(args.train_dir))
        while True:
            # Train discriminator
            for i in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
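
# --- The adversarial train() above calls wavread() without showing its
# import; this scipy-based wrapper is a sketch that satisfies the float32 and
# mono assertions (an assumption about the original helper).
import numpy as np
from scipy.io.wavfile import read as _wavfile_read


def wavread(fp):
    fs, audio = _wavfile_read(fp)
    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.  # scale 16-bit PCM to [-1, 1)
    elif audio.dtype != np.float32:
        audio = audio.astype(np.float32)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)  # mix down to mono
    return fs, audio
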
def train(fps, args):
    with tf.name_scope('loader'):
        x_wav = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=_SLICE_LEN,
            decode_fs=args.data_sample_rate,
            decode_num_channels=1,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0.
            if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

        x = t_to_f(x_wav, args.data_moments_mean, args.data_moments_std)

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.specgan_latent_dim],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = SpecGANGenerator(z, train=True, **args.specgan_g_kwargs)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    x_gl = f_to_t(x, args.data_moments_mean, args.data_moments_std,
                  args.specgan_ngl)
    G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std,
                    args.specgan_ngl)
    tf.summary.audio('x_wav', x_wav, args.data_sample_rate)
    tf.summary.audio('x', x_gl, args.data_sample_rate)
    tf.summary.audio('G_z', G_z_gl, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.image('x', f_to_img(x))
    tf.summary.image('G_z', f_to_img(G_z))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = SpecGANDiscriminator(x, **args.specgan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = SpecGANDiscriminator(G_z, **args.specgan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.specgan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.specgan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.specgan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.specgan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = SpecGANDiscriminator(interpolates,
                                            **args.specgan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.specgan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.specgan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.specgan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.specgan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print(
            'Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
            .format(args.train_dir))
        while True:
            # Train discriminator
            for i in xrange(args.specgan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
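
# --- The SpecGAN train() above relies on t_to_f() to map waveforms to
# normalized log-magnitude spectrograms (f_to_t is its inverse). This
# reconstruction is a sketch that mirrors the moments() featurization
# earlier; the exact clipping and bin trimming in the original may differ.
_LOG_EPS = 1e-6


def t_to_f(x_wav, mean, std, nfft=256, nhop=128, clip_nstd=3.):
    X = tf.contrib.signal.stft(x_wav, nfft, nhop, pad_end=True)
    X_lmag = tf.log(tf.abs(X) + _LOG_EPS)
    X_norm = (X_lmag - mean) / std
    X_norm = tf.clip_by_value(X_norm, -clip_nstd, clip_nstd) / clip_nstd
    X_norm = X_norm[:, :, :nfft // 2]  # drop the Nyquist bin (assumption)
    return tf.expand_dims(X_norm, axis=3)  # NHWC input for the discriminator
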
def train(args):
    from functools import reduce
    from model import WaveGANGenerator, WaveGANDiscriminator
    import glob
    import loader

    # Make train dir
    if not os.path.isdir(args.train_dir):
        os.makedirs(args.train_dir)

    fps = glob.glob(os.path.join(args.data_dir, '*'))

    if len(fps) == 0:
        raise Exception('Did not find any audio files in specified directory')
    print('Found {} audio files in specified directory'.format(len(fps)))

    with tf.name_scope('loader'):
        x = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=32768,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0.,
            slice_pad_end=True,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)
        x = x[:, :, 0]

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        # use first 512 point from real data as y
        y = tf.slice(x, [0, 0, 0], [-1, args.wavegan_smooth_len, -1])
        G_z = WaveGANGenerator(y,
                               z,
                               args.wavegan_kernel_len,
                               args.wavegan_smooth_len,
                               args.wavegan_dim,
                               args.wavegan_batchnorm,
                               train=True)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, args.wavegan_kernel_len,
                                   args.wavegan_dim, args.wavegan_batchnorm,
                                   args.wavegan_disc_phaseshuffle)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        yG_z = tf.concat([y, G_z], 1)
        # print("yG_z shape:")
        # print(yG_z.get_shape())
        D_G_z = WaveGANDiscriminator(yG_z, args.wavegan_kernel_len,
                                     args.wavegan_dim, args.wavegan_batchnorm,
                                     args.wavegan_disc_phaseshuffle)

    # Create loss
    G_loss = -tf.reduce_mean(D_G_z)
    D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

    alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                              minval=0.,
                              maxval=1.)
    differences = yG_z - x
    interpolates = x + (alpha * differences)
    with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
        D_interp = WaveGANDiscriminator(interpolates, args.wavegan_kernel_len,
                                        args.wavegan_dim,
                                        args.wavegan_batchnorm,
                                        args.wavegan_disc_phaseshuffle)
    LAMBDA = 10
    gradients = tf.gradients(D_interp, [interpolates])[0]
    slopes = tf.sqrt(
        tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
    D_loss += LAMBDA * gradient_penalty

    # Create (recommended) optimizer
    G_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
    D_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    tf.summary.audio('yG_z', yG_z, args.data_sample_rate)

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        from six.moves import xrange

        while True:
            # Train discriminator
            for i in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

            # Train generator
            sess.run(G_train_op)
            if args.verbose:
                # These extra eval() calls draw fresh batches, so the printed
                # losses are estimates rather than the exact step values.
                eval_loss_D = D_loss.eval(session=sess)
                eval_loss_G = G_loss.eval(session=sess)
                print(str(eval_loss_D) + "," + str(eval_loss_G))
Example #7
def train(fps1, fps2, args):
    with tf.name_scope('loader'):
        x_real = loader.decode_extract_and_batch(
            fps1,
            batch_size=args.train_batch_size,
            slice_len=args.data1_slice_len,
            decode_fs=args.data1_sample_rate,
            decode_num_channels=args.data1_num_channels,
            decode_fast_wav=args.data1_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data1_first_slice else True,
            slice_first_only=args.data1_first_slice,
            slice_overlap_ratio=0.
            if args.data1_first_slice else args.data1_overlap_ratio,
            slice_pad_end=True
            if args.data1_first_slice else args.data1_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0]

        x_synthetic = loader.decode_extract_and_batch(
            fps2,
            batch_size=args.train_batch_size,
            slice_len=args.data2_slice_len,
            decode_fs=args.data2_sample_rate,
            decode_num_channels=args.data2_num_channels,
            decode_fast_wav=args.data2_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data2_first_slice else True,
            slice_first_only=args.data2_first_slice,
            slice_overlap_ratio=0.
            if args.data2_first_slice else args.data2_overlap_ratio,
            slice_pad_end=True
            if args.data2_first_slice else args.data2_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0]

    # print('length check', len(x_real))
    # Make z vector
    # z = tf.random_uniform([args.train_batch_size, args.TSRIRgan_latent_dim], -1., 1., dtype=tf.float32)

    # Make generator_synthetic
    with tf.variable_scope('G_synthetic'):
        G_synthetic = TSRIRGANGenerator_synthetic(x_real,
                                                  train=True,
                                                  **args.TSRIRgan_g_kwargs)
        if args.TSRIRgan_genr_pp:
            with tf.variable_scope('s_pp_filt'):
                G_synthetic = tf.layers.conv1d(G_synthetic,
                                               1,
                                               args.TSRIRgan_genr_pp_len,
                                               use_bias=False,
                                               padding='same')
    G_synthetic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='G_synthetic')

    # Print G_synthetic summary
    print('-' * 80)
    print('Generator_synthetic vars')
    nparams = 0
    for v in G_synthetic_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x_real', x_real, args.data1_sample_rate)
    tf.summary.audio('G_synthetic', G_synthetic, args.data1_sample_rate)
    G_synthetic_rms = tf.sqrt(
        tf.reduce_mean(tf.square(G_synthetic[:, :, 0]), axis=1))
    x_real_rms = tf.sqrt(tf.reduce_mean(tf.square(x_real[:, :, 0]), axis=1))
    tf.summary.histogram('x_real_rms_batch', x_real_rms)
    tf.summary.histogram('G_synthetic_rms_batch', G_synthetic_rms)
    tf.summary.scalar('x_real_rms', tf.reduce_mean(x_real_rms))
    tf.summary.scalar('G_synthetic_rms', tf.reduce_mean(G_synthetic_rms))

    # Make generator_real
    with tf.variable_scope('G_real'):
        G_real = TSRIRGANGenerator_real(x_synthetic,
                                        train=True,
                                        **args.TSRIRgan_g_kwargs)
        if args.TSRIRgan_genr_pp:
            with tf.variable_scope('r_pp_filt'):
                G_real = tf.layers.conv1d(G_real,
                                          1,
                                          args.TSRIRgan_genr_pp_len,
                                          use_bias=False,
                                          padding='same')
    G_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope='G_real')

    # Print G_real summary
    print('-' * 80)
    print('Generator_real vars')
    nparams = 0
    for v in G_real_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x_synthetic', x_synthetic, args.data1_sample_rate)
    tf.summary.audio('G_real', G_real, args.data1_sample_rate)
    G_real_rms = tf.sqrt(tf.reduce_mean(tf.square(G_real[:, :, 0]), axis=1))
    x_synthetic_rms = tf.sqrt(
        tf.reduce_mean(tf.square(x_synthetic[:, :, 0]), axis=1))
    tf.summary.histogram('x_synthetic_rms_batch', x_synthetic_rms)
    tf.summary.histogram('G_real_rms_batch', G_real_rms)
    tf.summary.scalar('x_synthetic_rms', tf.reduce_mean(x_synthetic_rms))
    tf.summary.scalar('G_real_rms', tf.reduce_mean(G_real_rms))

    # Generate cycled outputs (CycleGAN round trip through both generators)
    with tf.variable_scope('G_synthetic', reuse=True):
        cycle_synthetic = TSRIRGANGenerator_synthetic(G_real,
                                                      train=True,
                                                      **args.TSRIRgan_g_kwargs)
        if args.TSRIRgan_genr_pp:
            with tf.variable_scope('s_pp_filt'):
                cycle_synthetic = tf.layers.conv1d(cycle_synthetic,
                                                   1,
                                                   args.TSRIRgan_genr_pp_len,
                                                   use_bias=False,
                                                   padding='same')
    G_synthetic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='G_synthetic')

    with tf.variable_scope('G_real', reuse=True):
        cycle_real = TSRIRGANGenerator_real(G_synthetic,
                                            train=True,
                                            **args.TSRIRgan_g_kwargs)
        if args.TSRIRgan_genr_pp:
            with tf.variable_scope('r_pp_filt'):
                cycle_real = tf.layers.conv1d(cycle_real,
                                              1,
                                              args.TSRIRgan_genr_pp_len,
                                              use_bias=False,
                                              padding='same')
    G_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope='G_real')

    # Generate identity-mapped outputs (same-domain pass)
    with tf.variable_scope('G_synthetic', reuse=True):
        same_synthetic = TSRIRGANGenerator_synthetic(x_synthetic,
                                                     train=True,
                                                     **args.TSRIRgan_g_kwargs)
        if args.TSRIRgan_genr_pp:
            with tf.variable_scope('s_pp_filt'):
                same_synthetic = tf.layers.conv1d(same_synthetic,
                                                  1,
                                                  args.TSRIRgan_genr_pp_len,
                                                  use_bias=False,
                                                  padding='same')
    G_synthetic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='G_synthetic')

    with tf.variable_scope('G_real', reuse=True):
        same_real = TSRIRGANGenerator_real(x_real,
                                           train=True,
                                           **args.TSRIRgan_g_kwargs)
        if args.TSRIRgan_genr_pp:
            with tf.variable_scope('r_pp_filt'):
                same_real = tf.layers.conv1d(same_real,
                                             1,
                                             args.TSRIRgan_genr_pp_len,
                                             use_bias=False,
                                             padding='same')
    G_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope='G_real')

    #Synthetic
    # Make real discriminator
    with tf.name_scope('D_synthetic_x'), tf.variable_scope('D_synthetic'):
        D_synthetic_x = TSRIRGANDiscriminator_synthetic(
            x_synthetic, **args.TSRIRgan_d_kwargs)
    D_synthetic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='D_synthetic')

    # Print D summary
    print('-' * 80)
    print('Discriminator_synthetic vars')
    nparams = 0
    for v in D_synthetic_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_synthetic'), tf.variable_scope('D_synthetic',
                                                           reuse=True):
        D_G_synthetic = TSRIRGANDiscriminator_synthetic(
            G_synthetic, **args.TSRIRgan_d_kwargs)

    #Real
    # Make real discriminator
    with tf.name_scope('D_real_x'), tf.variable_scope('D_real'):
        D_real_x = TSRIRGANDiscriminator_real(x_real, **args.TSRIRgan_d_kwargs)
    D_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope='D_real')

    # Print D summary
    print('-' * 80)
    print('Discriminator_real vars')
    nparams = 0
    for v in D_real_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_real'), tf.variable_scope('D_real', reuse=True):
        D_G_real = TSRIRGANDiscriminator_real(G_real, **args.TSRIRgan_d_kwargs)


    # Create loss
    D_clip_weights = None
    # This session only serves the commented-out RT_60_loss experiments below;
    # training itself runs in the MonitoredTrainingSession further down.
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    if args.TSRIRgan_loss == 'cycle-gan':
        #Real IR
        gen_real_loss = generator_loss(D_G_real)
        gen_synthetic_loss = generator_loss(D_G_synthetic)

        cycle_loss_real = calc_cycle_loss(x_real, cycle_real)
        cycle_loss_synthetic = calc_cycle_loss(x_synthetic, cycle_synthetic)

        total_cycle_loss = cycle_loss_real + cycle_loss_synthetic

        same_real_loss = identity_loss(x_real, same_real)
        same_synthetic_loss = identity_loss(x_synthetic, same_synthetic)

        # RT60_loss_real = RT_60_loss(x_real,G_real,sess)
        # RT60_loss_synthetic = RT_60_loss(x_synthetic,G_synthetic,sess)

        total_gen_real_loss = gen_real_loss + 25 * total_cycle_loss + 35 * same_real_loss  #+RT60_loss_real
        total_gen_synthetic_loss = gen_synthetic_loss + 25 * total_cycle_loss + 35 * same_synthetic_loss  # +RT60_loss_synthetic

        disc_synthetic_loss = discriminator_loss(D_synthetic_x, D_G_synthetic)
        disc_real_loss = discriminator_loss(D_real_x, D_G_real)

    else:
        raise NotImplementedError()

    # tf.summary.scalar('RT60_loss_real', RT60_loss_real)
    # tf.summary.scalar('RT60_loss_synthetic',RT60_loss_synthetic)
    tf.summary.scalar('G_real_loss', total_gen_real_loss)
    tf.summary.scalar('G_synthetic_loss', total_gen_synthetic_loss)
    tf.summary.scalar('D_real_loss', disc_real_loss)
    tf.summary.scalar('D_synthetic_loss', disc_synthetic_loss)

    tf.summary.scalar('Generator_real_loss', gen_real_loss)
    tf.summary.scalar('Generator_synthetic_loss', gen_synthetic_loss)
    # Note: the scale factors logged here (15/20) do not match the weights
    # applied in the total losses above (25/35).
    tf.summary.scalar('Cycle_loss_real', 15 * cycle_loss_real)
    tf.summary.scalar('Cycle_loss_synthetic', 15 * cycle_loss_synthetic)
    tf.summary.scalar('Same_loss_real', 20 * same_real_loss)
    tf.summary.scalar('Same_loss_synthetic', 20 * same_synthetic_loss)

    # Create (recommended) optimizer
    if args.TSRIRgan_loss == 'cycle-gan':
        # G_real_opt = tf.train.AdamOptimizer(
        #     learning_rate=2e-4,
        #     beta1=0.5)
        # G_synthetic_opt = tf.train.AdamOptimizer(
        #     learning_rate=2e-4,
        #     beta1=0.5)
        # D_real_opt = tf.train.AdamOptimizer(
        #     learning_rate=2e-4,
        #     beta1=0.5)
        # D_synthetic_opt = tf.train.AdamOptimizer(
        #     learning_rate=2e-4,
        #     beta1=0.5)
        G_real_opt = tf.train.RMSPropOptimizer(learning_rate=3e-5)
        G_synthetic_opt = tf.train.RMSPropOptimizer(learning_rate=3e-5)
        D_real_opt = tf.train.RMSPropOptimizer(learning_rate=3e-5)
        D_synthetic_opt = tf.train.RMSPropOptimizer(learning_rate=3e-5)
    else:
        raise NotImplementedError()

    # Create training ops
    G_real_train_op = G_real_opt.minimize(
        total_gen_real_loss,
        var_list=G_real_vars,
        global_step=tf.train.get_or_create_global_step())
    G_synthetic_train_op = G_synthetic_opt.minimize(
        total_gen_synthetic_loss,
        var_list=G_synthetic_vars,
        global_step=tf.train.get_or_create_global_step())
    D_real_train_op = D_real_opt.minimize(disc_real_loss, var_list=D_real_vars)
    D_synthetic_train_op = D_synthetic_opt.minimize(disc_synthetic_loss,
                                                    var_list=D_synthetic_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print(
            'Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
            .format(args.train_dir))
        # RT60_loss_real = RT_60_loss(x_real,G_real,sess)
        # RT60_loss_synthetic = RT_60_loss(x_synthetic,G_synthetic,sess)
        while True:
            # Train discriminator
            for i in xrange(args.TSRIRgan_disc_nupdates):
                sess.run(D_real_train_op)
                sess.run(D_synthetic_train_op)

                # Enforce Lipschitz constraint for WGAN
                # if D_clip_weights is not None:
                #   sess.run(D_clip_weights)

            # Train generator
            sess.run(G_real_train_op)
            sess.run(G_synthetic_train_op)
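
# --- generator_loss / discriminator_loss / calc_cycle_loss / identity_loss
# are called above but never defined in this listing; these sketches follow
# the standard CycleGAN formulation and are assumptions about the originals
# (unweighted, since the trainer applies its own 25/35 factors).
def generator_loss(d_generated):
    # Push the discriminator toward classifying fakes as real.
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_generated, labels=tf.ones_like(d_generated)))


def discriminator_loss(d_real, d_generated):
    real_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_real, labels=tf.ones_like(d_real)))
    fake_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_generated, labels=tf.zeros_like(d_generated)))
    return 0.5 * (real_loss + fake_loss)


def calc_cycle_loss(x, x_cycled):
    # L1 error after a round trip through both generators.
    return tf.reduce_mean(tf.abs(x - x_cycled))


def identity_loss(x, x_same):
    # Penalize a generator for altering inputs already in its target domain.
    return tf.reduce_mean(tf.abs(x - x_same))
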
Example #8
def preview(fps1, fps2, args):
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from scipy.io.wavfile import write as wavwrite
    from scipy.signal import freqz

    preview_dir = os.path.join(args.train_dir, 'preview')
    if not os.path.isdir(preview_dir):
        os.makedirs(preview_dir)

    ####################################################
    s_fps1 = fps1[0:args.preview_n]
    s_fps2 = fps2[0:args.preview_n]
    with tf.name_scope('samp_x_real'):
        x_real = loader.decode_extract_and_batch(
            s_fps1,
            batch_size=args.train_batch_size,
            slice_len=args.data1_slice_len,
            decode_fs=args.data1_sample_rate,
            decode_num_channels=args.data1_num_channels,
            decode_fast_wav=args.data1_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data1_first_slice else True,
            slice_first_only=args.data1_first_slice,
            slice_overlap_ratio=0.
            if args.data1_first_slice else args.data1_overlap_ratio,
            slice_pad_end=True
            if args.data1_first_slice else args.data1_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0]

    with tf.name_scope('samp_x_synthetic'):
        x_synthetic = loader.decode_extract_and_batch(
            s_fps2,
            batch_size=args.train_batch_size,
            slice_len=args.data2_slice_len,
            decode_fs=args.data2_sample_rate,
            decode_num_channels=args.data2_num_channels,
            decode_fast_wav=args.data2_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data2_first_slice else True,
            slice_first_only=args.data2_first_slice,
            slice_overlap_ratio=0.
            if args.data2_first_slice else args.data2_overlap_ratio,
            slice_pad_end=True
            if args.data2_first_slice else args.data2_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0]

    ####################################################
    # Evaluate both sample batches in one session instead of leaking two.
    with tf.Session() as _samp_sess:
        x_synthetic = x_synthetic.eval(session=_samp_sess)
        x_real = x_real.eval(session=_samp_sess)

    # Load graph
    infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
    graph = tf.get_default_graph()
    saver = tf.train.import_meta_graph(infer_metagraph_fp)

    # Set up graph for generating preview images
    feeds = {}
    feeds[graph.get_tensor_by_name('x_synthetic:0')] = x_synthetic
    feeds[graph.get_tensor_by_name('synthetic_flat_pad:0')] = int(
        args.data1_sample_rate / 2)
    feeds[graph.get_tensor_by_name('x_synthetic_flat_pad:0')] = int(
        args.data1_sample_rate / 2)
    feeds[graph.get_tensor_by_name('x_real:0')] = x_real
    feeds[graph.get_tensor_by_name('real_flat_pad:0')] = int(
        args.data1_sample_rate / 2)
    feeds[graph.get_tensor_by_name('x_real_flat_pad:0')] = int(
        args.data1_sample_rate / 2)
    fetches = {}
    fetches['step'] = tf.train.get_or_create_global_step()
    fetches['G_synthetic_x'] = graph.get_tensor_by_name('G_synthetic_x:0')
    fetches['G_synthetic_x_flat_int16'] = graph.get_tensor_by_name(
        'G_synthetic_x_flat_int16:0')
    fetches['x_synthetic_flat_int16'] = graph.get_tensor_by_name(
        'x_synthetic_flat_int16:0')
    fetches['G_real_x'] = graph.get_tensor_by_name('G_real_x:0')
    fetches['G_real_x_flat_int16'] = graph.get_tensor_by_name(
        'G_real_x_flat_int16:0')
    fetches['x_real_flat_int16'] = graph.get_tensor_by_name(
        'x_real_flat_int16:0')
    if args.TSRIRgan_genr_pp:
        fetches['s_pp_filter'] = graph.get_tensor_by_name(
            'G_synthetic_x/s_pp_filt/conv1d/kernel:0')[:, 0, 0]
        fetches['r_pp_filter'] = graph.get_tensor_by_name(
            'G_real_x/r_pp_filt/conv1d/kernel:0')[:, 0, 0]

    # Summarize
    G_synthetic_x = graph.get_tensor_by_name('G_synthetic_x_flat:0')
    s_summaries = [
        tf.summary.audio('preview',
                         tf.expand_dims(G_synthetic_x, axis=0),
                         args.data1_sample_rate,
                         max_outputs=1)
    ]
    fetches['s_summaries'] = tf.summary.merge(s_summaries)
    s_summary_writer = tf.summary.FileWriter(preview_dir)

    G_real_x = graph.get_tensor_by_name('G_real_x_flat:0')
    r_summaries = [
        tf.summary.audio('preview',
                         tf.expand_dims(G_real_x, axis=0),
                         args.data1_sample_rate,
                         max_outputs=1)
    ]
    fetches['r_summaries'] = tf.summary.merge(r_summaries)
    r_summary_writer = tf.summary.FileWriter(preview_dir)

    # PP Summarize
    if args.TSRIRgan_genr_pp:
        s_pp_fp = tf.placeholder(tf.string, [])
        s_pp_bin = tf.read_file(s_pp_fp)
        s_pp_png = tf.image.decode_png(s_pp_bin)
        s_pp_summary = tf.summary.image('s_pp_filt',
                                        tf.expand_dims(s_pp_png, axis=0))

    if args.TSRIRgan_genr_pp:
        r_pp_fp = tf.placeholder(tf.string, [])
        r_pp_bin = tf.read_file(r_pp_fp)
        r_pp_png = tf.image.decode_png(r_pp_bin)
        r_pp_summary = tf.summary.image('r_pp_filt',
                                        tf.expand_dims(r_pp_png, axis=0))

    # Loop, waiting for checkpoints
    ckpt_fp = None
    while True:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
            print('Preview: {}'.format(latest_ckpt_fp))

            with tf.Session() as sess:
                saver.restore(sess, latest_ckpt_fp)

                _fetches = sess.run(fetches, feeds)

                _step = _fetches['step']

            # with tf.Session() as sess:
            #   saver.restore(sess, latest_ckpt_fp)

            #   _r_fetches = sess.run(r_fetches, r_feeds)

            #   _r_step = _r_fetches['step']

            s_preview_fp = os.path.join(
                preview_dir,
                '{}.wav'.format(str(_step).zfill(8) + 'synthetic'))
            wavwrite(s_preview_fp, args.data1_sample_rate,
                     _fetches['G_synthetic_x_flat_int16'])
            s_original_fp = os.path.join(preview_dir,
                                         '{}.wav'.format('synthetic_original'))
            wavwrite(s_original_fp, args.data1_sample_rate,
                     _fetches['x_synthetic_flat_int16'])

            s_summary_writer.add_summary(_fetches['s_summaries'], _step)

            r_preview_fp = os.path.join(
                preview_dir, '{}.wav'.format(str(_step).zfill(8) + 'real'))
            wavwrite(r_preview_fp, args.data1_sample_rate,
                     _fetches['G_real_x_flat_int16'])
            r_original_fp = os.path.join(preview_dir,
                                         '{}.wav'.format('real_original'))
            wavwrite(r_original_fp, args.data1_sample_rate,
                     _fetches['x_real_flat_int16'])

            r_summary_writer.add_summary(_fetches['r_summaries'], _step)

            # TODO: adapt the commented-out post-processing filter plot below
            # to the s_pp_filter / r_pp_filter fetches defined above.
            # if args.TSRIRgan_genr_pp:
            #   s_w, s_h = freqz(_s_fetches['s_pp_filter'])

            #   fig = plt.figure()
            #   plt.title('Digital filter frequncy response')
            #   ax1 = fig.add_subplot(111)

            #   plt.plot(w, 20 * np.log10(abs(h)), 'b')
            #   plt.ylabel('Amplitude [dB]', color='b')
            #   plt.xlabel('Frequency [rad/sample]')

            #   ax2 = ax1.twinx()
            #   angles = np.unwrap(np.angle(h))
            #   plt.plot(w, angles, 'g')
            #   plt.ylabel('Angle (radians)', color='g')
            #   plt.grid()
            #   plt.axis('tight')

            #   _pp_fp = os.path.join(preview_dir, '{}_ppfilt.png'.format(str(_step).zfill(8)))
            #   plt.savefig(_pp_fp)

            #   with tf.Session() as sess:
            #     _summary = sess.run(pp_summary, {pp_fp: _pp_fp})
            #     summary_writer.add_summary(_summary, _step)

            print('Done')

            ckpt_fp = latest_ckpt_fp

        time.sleep(1)