def infer(args):
    """Build the generator-only inference graph and export it to disk.

    The graph is written as ``infer.pbtxt`` (GraphDef) and ``infer.meta``
    (MetaGraph carrying a SaverDef for the ``G`` variables plus the global
    step) inside ``<args.train_dir>/infer``. The default graph is reset
    afterwards so training can build a fresh graph in the same process.

    Tensors given explicit names for consumers of the exported graph:
        samp_z_n        int32 scalar placeholder, number of latents to draw.
        samp_z          [samp_z_n, _D_Z] uniform draws in [-1, 1).
        z               float32 placeholder, [None, _D_Z + _D_Y].
        flat_pad        int32 scalar placeholder, zeros appended per example.
        G_z             generator output.
        G_z_flat        G_z padded by flat_pad along axis 1, reshaped to
                        [-1, nch].
        G_z_int16, G_z_flat_int16
                        the two outputs above scaled by 32767, clipped and
                        cast to int16.

    NOTE(review): ``samp_z`` is _D_Z wide while ``z`` expects _D_Z + _D_Y
    columns; presumably callers append _D_Y conditioning values before
    feeding ``z`` -- confirm.

    Args:
        args: namespace providing ``train_dir``, ``wavegan_g_kwargs``,
            ``wavegan_genr_pp`` and ``wavegan_genr_pp_len``.
    """
    export_dir = os.path.join(args.train_dir, 'infer')
    if not os.path.isdir(export_dir):
        os.makedirs(export_dir)

    # Latent sampler subgraph.
    num_latents = tf.placeholder(tf.int32, [], name='samp_z_n')
    sampled_latents = tf.random_uniform(
        [num_latents, _D_Z], -1.0, 1.0, dtype=tf.float32, name='samp_z')

    # Generator inputs: latent batch and per-example flattening pad length.
    latent_in = tf.placeholder(tf.float32, [None, _D_Z + _D_Y], name='z')
    pad_len = tf.placeholder(tf.int32, [], name='flat_pad')

    # Generator (plus optional post-processing filter) under scope 'G'.
    with tf.variable_scope('G'):
        waveform = WaveGANGenerator(
            latent_in, train=False, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                waveform = tf.layers.conv1d(
                    waveform, 1, args.wavegan_genr_pp_len,
                    use_bias=False, padding='same')
    waveform = tf.identity(waveform, name='G_z')

    # Append pad_len zeros to every example along axis 1, then collapse the
    # batch into a single [-1, nch] stream.
    num_channels = int(waveform.get_shape()[-1])
    padded = tf.pad(waveform, [[0, 0], [0, pad_len], [0, 0]])
    flat = tf.reshape(padded, [-1, num_channels], name='G_z_flat')

    def to_int16(tensor, name=None):
        # Scale by 32767 and clip symmetrically (-32768 is never produced)
        # before casting; presumably the input lies roughly in [-1, 1].
        scaled = tensor * 32767.
        clipped = tf.clip_by_value(scaled, -32767., 32767.)
        return tf.cast(clipped, tf.int16, name=name)

    waveform_int16 = to_int16(waveform, name='G_z_int16')
    flat_int16 = to_int16(flat, name='G_z_flat_int16')

    # Saver restricted to generator variables and the global step.
    generator_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(generator_vars + [step])

    graph = tf.get_default_graph()
    tf.train.write_graph(graph, export_dir, 'infer.pbtxt')

    meta_path = os.path.join(export_dir, 'infer.meta')
    tf.train.export_meta_graph(
        filename=meta_path,
        clear_devices=True,
        saver_def=saver.as_saver_def())

    # Leave a clean default graph for any subsequent training in-process.
    tf.reset_default_graph()
def train(fps, args):
    """Train WaveGAN with an extra adversarial loss against a speech classifier.

    The generator output, scaled by ``args.adv_magnitude``, is added to a fixed
    clip loaded from ``args.adv_input``. The sum is converted to MFCCs and fed
    to a frozen classifier GraphDef loaded from ``args.adv_model``. A margin
    loss pushes the classifier towards ``args.adv_target``. The margin loss is
    adapted from carlini/nn_robust_attacks ``l2_attack.py``.

    That loss is weighted by ``args.adv_lambda`` and added to the GAN
    generator loss selected by ``args.wavegan_loss``. Training runs forever
    inside a MonitoredTrainingSession that checkpoints and writes summaries
    to ``args.train_dir``.

    Args:
        fps: audio file paths handed to ``loader.decode_extract_and_batch``.
        args: parsed command-line namespace.

    Raises:
        ValueError: if the adversarial input WAV does not have sample rate
            ``args.data_sample_rate``, is not float32, or is not 1-D; or if
            ``args.adv_target`` is not one of the labels in ``args.adv_label``.
        NotImplementedError: if ``args.wavegan_loss`` is not one of
            'dcgan', 'lsgan', 'wgan', 'wgan-gp'.
    """

    def _print_var_summary(title, variables):
        # Print each variable's shape, element count and name, then the total
        # parameter count and its size in MB at 4 bytes per parameter.
        print('-' * 80)
        print(title)
        nparams = 0
        for v in variables:
            v_shape = v.get_shape().as_list()
            # Initial value 1 so a scalar (shape []) variable counts as one
            # element instead of making reduce() raise on an empty sequence.
            v_n = reduce(lambda a, b: a * b, v_shape, 1)
            nparams += v_n
            print('{} ({}): {}'.format(v_shape, v_n, v.name))
        print('Total params: {} ({:.2f} MB)'.format(
            nparams, (float(nparams) * 4) / (1024 * 1024)))

    with tf.name_scope('loader'):
        x = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=args.data_slice_len,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0. if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim],
                          -1., 1., dtype=tf.float32)

    # Make generator (optional post-processing filter lives under G/pp_filt so
    # it is included in G_vars).
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len,
                                       use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    _print_var_summary('Generator vars', G_vars)

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    _print_var_summary('Discriminator vars', D_vars)
    print('-' * 80)

    # Make fake discriminator (shares D's variables)
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
        D_loss = tf.reduce_mean((D_x - 1.) ** 2)
        D_loss += tf.reduce_mean(D_G_z ** 2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Weight clipping op; run after every D update in the loop below.
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Gradient penalty on random interpolations between real and fake.
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0., maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    # Load adversarial input. Explicit checks rather than `assert`, which is
    # stripped when Python runs with -O.
    fs, audio = wavread(args.adv_input)
    if fs != args.data_sample_rate:
        raise ValueError(
            'adv_input sample rate {} does not match data_sample_rate {}'.format(
                fs, args.data_sample_rate))
    if audio.dtype != np.float32:
        raise ValueError(
            'adv_input must be float32 audio, got {}'.format(audio.dtype))
    if len(audio.shape) != 1:
        raise ValueError(
            'adv_input must be mono (1-D), got shape {}'.format(audio.shape))

    # Synthesis: zero-pad/truncate the clip to data_slice_len samples and add
    # the scaled generator output (trailing channel axis dropped by reshape).
    if audio.shape[0] < args.data_slice_len:
        audio = np.pad(audio, (0, args.data_slice_len - audio.shape[0]),
                       'constant')
    adv_input = tf.constant(
        audio[:args.data_slice_len], dtype=np.float32
    ) + args.adv_magnitude * tf.reshape(G_z, G_z.get_shape().as_list()[:-1])

    # Calculate MFCCs: |STFT| (frame 320, hop 160) -> 40 mel bins covering
    # 20-4000 Hz -> log -> MFCCs, cropped to 99 frames x 40 coefficients.
    spectrograms = tf.abs(
        tf.signal.stft(adv_input, frame_length=320, frame_step=160))
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        40, spectrograms.shape[-1].value, fs, 20, 4000)
    mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
    mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
        linear_to_mel_weight_matrix.shape[-1:]))
    log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)
    # NOTE(review): index -1 keeps only the LAST example of the batch, so the
    # classifier logits (and adv_loss) depend on a single generated sample per
    # step, presumably broadcast against the batch-sized one-hot targets below.
    # Confirm whether per-example MFCCs were intended (that would also require
    # the imported classifier graph to accept a batch).
    mfccs = tf.expand_dims(
        tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms)[
            -1, :99, :40], -1)

    # Load a model for speech command classification; our MFCCs replace the
    # graph's 'Mfcc:0' tensor and 'add_2:0' is taken as its logits
    # (tensor names presumably from the frozen classifier -- confirm).
    with tf.gfile.FastGFile(args.adv_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.variable_scope('Speech'):
        adv_logits, = tf.import_graph_def(graph_def,
                                          input_map={'Mfcc:0': mfccs},
                                          return_elements=['add_2:0'])

    # Load labels for speech command classification (one label per line).
    # The file is closed deterministically instead of being left to the GC.
    with tf.gfile.GFile(args.adv_label) as label_file:
        adv_labels = [line.rstrip() for line in label_file]
    adv_index = adv_labels.index(args.adv_target)

    # Make adversarial loss
    # Came from: https://github.com/carlini/nn_robust_attacks/blob/master/l2_attack.py
    adv_targets = tf.one_hot(
        tf.constant([adv_index] * args.train_batch_size, dtype=tf.int32),
        len(adv_labels))
    adv_target_logit = tf.reduce_sum(adv_targets * adv_logits, 1)
    # Largest non-target logit; the target column is pushed down by 10000.
    adv_others_logit = tf.reduce_max(
        (1 - adv_targets) * adv_logits - (adv_targets * 10000), 1)
    adv_loss = tf.reduce_mean(
        tf.maximum(0.0,
                   adv_others_logit - adv_target_logit + args.adv_confidence))

    # Summarize audios
    tf.summary.audio('adv_input', adv_input, fs, max_outputs=args.adv_max_outputs)
    tf.summary.scalar('adv_loss', adv_loss)
    tf.summary.histogram('adv_classes', tf.argmax(adv_logits, axis=1))

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops; only the generator step advances the global step.
    G_train_op = G_opt.minimize(
        G_loss + args.adv_lambda * adv_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Run training (no stop condition: loops until the process is killed).
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            config=config,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print(
            'Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
            .format(args.train_dir))
        while True:
            # Train discriminator
            for _ in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
def train(fps, args):
    """Train the sequence (GRU) WaveGAN variant.

    Builds:
      * the generator ``G``, with one CuDNNGRU layer object shared between
        the training branch ``G_z`` and a longer, ``train=False`` branch
        ``G_z_long`` that is only summarized;
      * the discriminator ``D``;
      * the loss selected by ``args.wavegan_loss`` and its recommended
        optimizers.

    Then trains forever inside a MonitoredTrainingSession that checkpoints
    to ``args.train_dir`` (Saver with ``max_to_keep=10``).

    After the first generator step, the trainable ``G`` variables are pickled
    to ``saved_G_vars_iteration-<step>.pkl`` in the current working directory.

    Args:
        fps: audio file paths handed to ``loader.get_batch``.
        args: parsed command-line namespace.

    Raises:
        NotImplementedError: if ``args.wavegan_loss`` is not one of
            'dcgan', 'lsgan', 'wgan', 'wgan-gp'.
    """

    def _print_var_summary(title, variables):
        # Print each variable's shape, element count and name, then the total
        # parameter count and its size in MB at 4 bytes per parameter.
        print('-' * 80)
        print(title)
        nparams = 0
        for v in variables:
            v_shape = v.get_shape().as_list()
            # Initial value 1 so a scalar (shape []) variable counts as one
            # element instead of making reduce() raise on an empty sequence.
            v_n = reduce(lambda a, b: a * b, v_shape, 1)
            nparams += v_n
            print('{} ({}): {}'.format(v_shape, v_n, v.name))
        print('Total params: {} ({:.2f} MB)'.format(
            nparams, (float(nparams) * 4) / (1024 * 1024)))

    with tf.name_scope('loader'):
        x = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN,
                             args.data_first_window)

    # Make z vector: a [batch, 16, d_z] sequence when args.use_sequence is set,
    # otherwise a single [batch, d_z] latent per example.
    if args.use_sequence:
        z = tf.random_uniform([args.train_batch_size, 16, args.d_z], -1., 1.,
                              dtype=tf.float32)
    else:
        z = tf.random_uniform([args.train_batch_size, args.d_z], -1., 1.,
                              dtype=tf.float32)

    # Make generator. The GRU layer object is created once and handed to both
    # generator calls so the long branch below uses the same layer.
    with tf.variable_scope('G'):
        gru_layer = tf.keras.layers.CuDNNGRU(args.d_z, return_sequences=True)
        G_z, _ = WaveGANGenerator(z, gru_layer=gru_layer, train=True,
                                  return_gru=True, reuse=False,
                                  use_sequence=args.use_sequence,
                                  **args.wavegan_g_kwargs)
        print('G_z.shape:', G_z.get_shape().as_list())
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len,
                                       use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
    G_var_names = [g_var.name for g_var in G_vars]

    _print_var_summary('Generator vars', G_vars)

    # Longer generation branch (summary only). In sequence mode it is fed z
    # plus 16 * extra_secs additional steps of fresh uniform noise.
    extra_secs = 1
    if args.use_sequence:
        added_noise = tf.random_uniform(
            [args.train_batch_size, 16 * extra_secs, args.d_z], -1., 1.,
            dtype=tf.float32)
        z_feed_long = tf.concat([z, added_noise], axis=1)
    else:
        z_feed_long = z
    with tf.variable_scope('G', reuse=True):
        G_z_long, _ = WaveGANGenerator(z_feed_long, gru_layer=gru_layer,
                                       train=False, length=16 * extra_secs,
                                       return_gru=True, reuse=True,
                                       use_sequence=args.use_sequence,
                                       **args.wavegan_g_kwargs)
        print('G_z_long.shape:', G_z_long.get_shape().as_list())
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt', reuse=True):
                G_z_long = tf.layers.conv1d(G_z_long, 1,
                                            args.wavegan_genr_pp_len,
                                            use_bias=False, padding='same')

    # Summarize
    tf.summary.audio('x', x, _FS)
    tf.summary.audio('G_z', G_z, _FS)
    tf.summary.audio('G_z_long', G_z_long, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
    print('D_vars:', D_vars)

    _print_var_summary('Discriminator vars', D_vars)
    print('-' * 80)

    # Make fake discriminator (shares D's variables)
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_G_z,
            labels=real
        ))

        D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_G_z,
            labels=fake
        ))
        D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_x,
            labels=real
        ))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
        D_loss = tf.reduce_mean((D_x - 1.) ** 2)
        D_loss += tf.reduce_mean(D_G_z ** 2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Weight clipping op; run after every D update in the loop below.
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
                    )
                )
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Gradient penalty on random interpolations between real and fake.
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0., maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        print('gradients:', gradients)
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create the global step tensor before the optimizers. The wgan-gp
    # learning-rate schedule needs the scalar step tensor. Passing
    # tf.get_collection(GLOBAL_STEP), a Python list that was still empty
    # here, produced a non-scalar learning rate that Adam rejects.
    global_step = tf.train.get_or_create_global_step()

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        # Learning rate halves every 100000 global steps (continuous decay).
        my_learning_rate = tf.train.exponential_decay(
            1e-4, global_step, decay_steps=100000, decay_rate=0.5)
        G_opt = tf.train.AdamOptimizer(
            learning_rate=my_learning_rate, beta1=0.5, beta2=0.9)
        D_opt = tf.train.AdamOptimizer(
            learning_rate=my_learning_rate, beta1=0.5, beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops; only the generator step advances the global step.
    G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
                                global_step=global_step)
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    saver = tf.train.Saver(max_to_keep=10)

    # Run training (no stop condition: loops until the process is killed).
    with tf.train.MonitoredTrainingSession(
            scaffold=tf.train.Scaffold(saver=saver),
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        G_vars_dumped = False
        while True:
            # Train discriminator
            for _ in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            _, g_losses, d_losses, step = sess.run(
                [G_train_op, G_loss, D_loss, global_step])
            print('i:', step, 'G_loss:', g_losses, 'D_loss:', d_losses)

            # One-off snapshot of the generator weights after the first step.
            if not G_vars_dumped:
                G_var_dict = dict(zip(G_var_names, sess.run(G_vars)))
                with open('saved_G_vars_iteration-{}.pkl'.format(step),
                          'wb') as f:
                    pickle.dump(G_var_dict, f)
                G_vars_dumped = True