def infer(args):
    """Build and export the SpecGAN inference graph.

    Constructs a standalone graph with named placeholders/tensors
    (``samp_z_n``, ``samp_z``, ``z``, ``ngl``, ``flat_pad``, ``G_z``, ...),
    writes ``infer.pbtxt`` and ``infer.meta`` under ``<train_dir>/infer``,
    and finally resets the default graph so training can follow.

    Args:
        args: parsed CLI namespace; uses train_dir, specgan_latent_dim,
            specgan_g_kwargs, data_moments_mean and data_moments_std.
    """
    infer_dir = os.path.join(args.train_dir, 'infer')
    if not os.path.isdir(infer_dir):
        os.makedirs(infer_dir)

    # Subgraph that samples n latent vectors uniformly from [-1, 1).
    samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
    samp_z = tf.random_uniform([samp_z_n, args.specgan_latent_dim],
                               -1.0,
                               1.0,
                               dtype=tf.float32,
                               name='samp_z')

    # Input z (latent code), plus inference-time knobs:
    # ngl — presumably the number of Griffin-Lim iterations passed to
    # f_to_t (TODO confirm against f_to_t's definition);
    # flat_pad — per-example padding applied before flattening the batch.
    z = tf.placeholder(tf.float32, [None, args.specgan_latent_dim], name='z')
    ngl = tf.placeholder(tf.int32, [], name='ngl')
    flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')

    # Execute generator.
    # NOTE(review): indentation was reconstructed from a collapsed source;
    # the identity/f_to_t ops are placed OUTSIDE the 'G' variable scope so
    # the exported tensor names stay flat ('G_z_norm', 'G_z') — confirm
    # against the names consumers of infer.meta actually fetch.
    with tf.variable_scope('G'):
        G_z_norm = SpecGANGenerator(z, train=False, **args.specgan_g_kwargs)
    G_z_norm = tf.identity(G_z_norm, name='G_z_norm')
    # Denormalized feature -> time-domain audio.
    G_z = f_to_t(G_z_norm, args.data_moments_mean, args.data_moments_std, ngl)
    G_z = tf.identity(G_z, name='G_z')

    # uint8 image view of the normalized spectrogram (for previews).
    G_z_norm_uint8 = f_to_img(G_z_norm)
    G_z_norm_uint8 = tf.identity(G_z_norm_uint8, name='G_z_norm_uint8')

    # Flatten batch: pad each example by flat_pad samples, then concatenate
    # all examples into one long [total_samples, nch] tensor.
    nch = int(G_z.get_shape()[-1])
    G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
    G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')

    # Encode float audio in [-1, 1] to int16 PCM.
    def float_to_int16(x, name=None):
        x_int16 = x * 32767.
        # Clip to the symmetric int16 range to avoid cast overflow.
        x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
        x_int16 = tf.cast(x_int16, tf.int16, name=name)
        return x_int16

    G_z_int16 = float_to_int16(G_z, name='G_z_int16')
    G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')

    # Create saver over the generator's variables plus the global step,
    # so inference checkpoints restore cleanly from training checkpoints.
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    global_step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(G_vars + [global_step])

    # Export graph (text protobuf, for inspection).
    tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

    # Export MetaGraph (the artifact actually loaded for inference).
    infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
    tf.train.export_meta_graph(filename=infer_metagraph_fp,
                               clear_devices=True,
                               saver_def=saver.as_saver_def())

    # Reset graph (in case training afterwards).
    tf.reset_default_graph()
def train(fps, args):
    """Train SpecGAN on labeled audio data with optional class exclusion.

    Builds the full TF1 training graph (loader, generator, discriminator,
    one of four GAN losses) and runs a MonitoredTrainingSession until
    ``args.stop_at_global_step`` (0 means run forever).

    Side effects: mutates the module-level global ``train_dataset_size``.

    Args:
        fps: list of input file paths (tfrecords — TODO confirm format).
        args: parsed CLI namespace (batch size, loss name, moments, dirs...).
    """
    global train_dataset_size
    with tf.name_scope('loader'):
        # NOTE: an earlier scheme used loader.split_files_test_val /
        # a fixed gan_train_data_size here; it was replaced by the
        # percentage-based truncation below (see VCS history).
        logging.info("Full training datasize = " + str(find_data_size(fps, None)))
        # Keep only the first train_data_percentage% of the file list.
        # train_data_percentage and find_data_size are module-level names
        # not visible in this chunk — confirm their semantics there.
        length = len(fps)
        fps = fps[:(int(train_data_percentage / 100.0 * length))]
        logging.info("GAN training datasize (before exclude) = " +
                     str(find_data_size(fps, None)))
        if args.exclude_class is None:
            # No exclusion requested; train_dataset_size keeps its prior value.
            pass
        elif args.exclude_class != -1:
            # Exclude one class id; recompute the effective dataset size.
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info("GAN training datasize (after exclude) = " +
                         str(train_dataset_size))
        elif args.exclude_class == -1:
            # -1 => instead of dropping a class, randomly keep 90% of files.
            fps, _ = loader.split_files_test_val(fps, 0.9, 0)
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info(
                "GAN training datasize (after exclude - random sampling) = " +
                str(train_dataset_size))
        else:
            # Unreachable in practice: the branches above cover every value.
            # LOL :P
            raise ValueError(
                "args.exclude_class should be either [0, num_class), None, or -1 for random sampling 90%"
            )
        training_iterator = loader.get_batch(fps,
                                             args.train_batch_size,
                                             _WINDOW_LEN,
                                             args.data_first_window,
                                             repeat=True,
                                             initializable=True,
                                             labels=True,
                                             exclude_class=args.exclude_class)
        x_wav, _ = training_iterator.get_next()  # Important: ignore the labels
        print("x_wav.shape = %s" % str(x_wav.shape))
        # Time-domain audio -> normalized feature (spectrogram) space.
        x = t_to_f(x_wav, args.data_moments_mean, args.data_moments_std)
        print("x.shape = %s" % str(x.shape))
    logging.info("train_dataset_size = " + str(train_dataset_size))

    # Make z vector: one latent code per batch element, uniform in [-1, 1).
    z = tf.random_uniform([args.train_batch_size, _D_Z], -1., 1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = SpecGANGenerator(z, train=True, **args.specgan_g_kwargs)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary (per-variable shapes and total parameter count).
    logging.info('-' * 80)
    logging.info('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)  # product of dims
        nparams += v_n
        logging.info('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    # 4 bytes per float32 parameter.
    logging.info('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))

    # Summarize: audio (via inverse feature transform), RMS stats, images.
    x_gl = f_to_t(x, args.data_moments_mean, args.data_moments_std,
                  args.specgan_ngl)
    print("x_gl.shape = %s" % str(x_gl.shape))
    G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std,
                    args.specgan_ngl)
    tf.summary.audio('x_wav', x_wav, _FS)
    tf.summary.audio('x', x_gl, _FS)
    tf.summary.audio('G_z', G_z_gl, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.image('x', f_to_img(x))
    tf.summary.image('G_z', f_to_img(G_z))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = SpecGANDiscriminator(x, **args.specgan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    logging.info('-' * 80)
    logging.info('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        logging.info('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    logging.info('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))
    logging.info('-' * 80)

    # Make fake discriminator (reuse=True shares weights with D_x).
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = SpecGANDiscriminator(G_z, **args.specgan_d_kwargs)

    # Create loss (one of: dcgan, lsgan, wgan, wgan-gp).
    D_clip_weights = None
    if args.specgan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)
        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))
        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))
        D_loss /= 2.
    elif args.specgan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.specgan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        # Weight clipping op, run after each D update (Lipschitz constraint).
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.specgan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        # Gradient penalty on random interpolates between real and fake.
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = SpecGANDiscriminator(interpolates,
                                            **args.specgan_d_kwargs)
        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer for each loss type.
    if args.specgan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.specgan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.specgan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.specgan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops; only the G op advances the global step.
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    current_step = -1  # leftover from a removed epoch-counting loop; unused
    scaffold = tf.train.Scaffold(local_init_op=tf.group(
        tf.local_variables_initializer(), training_iterator.initializer),
                                 saver=tf.train.Saver(max_to_keep=3))
    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=0.5)
    with tf.train.MonitoredTrainingSession(
            hooks=[SaveAtEnd(os.path.join(args.train_dir, 'model'))],
            config=tf.ConfigProto(gpu_options=gpu_options),
            scaffold=scaffold,
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs,
    ) as sess:
        while True:
            global_step = sess.run(tf.train.get_or_create_global_step())
            logging.info("Global step: " + str(global_step))
            if args.stop_at_global_step != 0 and global_step >= args.stop_at_global_step:
                logging.info(
                    "Stopping because args.stop_at_global_step is set to " +
                    str(args.stop_at_global_step))
                break
            # NOTE: an earlier variant bounded D updates to a percentage of
            # the dataset per pass and re-initialized the iterator manually;
            # that commented-out logic was removed (see VCS history).

            # Train discriminator (specgan_disc_nupdates updates per G update).
            try:
                for i in range(args.specgan_disc_nupdates):
                    sess.run(D_train_op)
                    # Enforce Lipschitz constraint for WGAN
                    if D_clip_weights is not None:
                        sess.run(D_clip_weights)
            except tf.errors.OutOfRangeError:
                # Iterator exhausted: restart it and continue training.
                sess.run(training_iterator.initializer)

            # Train generator
            sess.run(G_train_op)
def train(fps, args):
    """Train SpecGAN (unconditional, unlabeled variant).

    Builds the TF1 training graph — data loader, generator, weight-shared
    real/fake discriminators, one of four GAN losses — and trains forever
    inside a MonitoredTrainingSession (checkpointing/summaries handled by
    the session).

    Args:
        fps: list of audio file paths fed to the loader.
        args: parsed CLI namespace (batch size, loss name, data moments,
            slicing options, train_dir, save/summary intervals, ...).
    """
    with tf.name_scope('loader'):
        x_wav = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=_SLICE_LEN,
            decode_fs=args.data_sample_rate,
            decode_num_channels=1,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0. if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]
        # Time-domain audio -> normalized feature (spectrogram) space.
        x = t_to_f(x_wav, args.data_moments_mean, args.data_moments_std)

    # Make z vector: one latent code per batch element, uniform in [-1, 1).
    z = tf.random_uniform([args.train_batch_size, args.specgan_latent_dim],
                          -1., 1., dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = SpecGANGenerator(z, train=True, **args.specgan_g_kwargs)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary (per-variable shapes and total parameter count).
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)  # product of dims
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    # 4 bytes per float32 parameter.
    print('Total params: {} ({:.2f} MB)'.format(nparams,
                                                (float(nparams) * 4) / (1024 * 1024)))

    # Summarize: audio (via inverse feature transform), RMS stats, images.
    x_gl = f_to_t(x, args.data_moments_mean, args.data_moments_std,
                  args.specgan_ngl)
    G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std,
                    args.specgan_ngl)
    tf.summary.audio('x_wav', x_wav, args.data_sample_rate)
    tf.summary.audio('x', x_gl, args.data_sample_rate)
    tf.summary.audio('G_z', G_z_gl, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.image('x', f_to_img(x))
    tf.summary.image('G_z', f_to_img(G_z))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = SpecGANDiscriminator(x, **args.specgan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams,
                                                (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator (reuse=True shares weights with D_x).
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = SpecGANDiscriminator(G_z, **args.specgan_d_kwargs)

    # Create loss (one of: dcgan, lsgan, wgan, wgan-gp).
    D_clip_weights = None
    if args.specgan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)
        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))
        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))
        D_loss /= 2.
    elif args.specgan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.specgan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        # Weight clipping op, run after each D update (Lipschitz constraint).
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.specgan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        # Gradient penalty on random interpolates between real and fake.
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = SpecGANDiscriminator(interpolates,
                                            **args.specgan_d_kwargs)
        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer for each loss type.
    if args.specgan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.specgan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.specgan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.specgan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops; only the G op advances the global step.
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print('Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
              .format(args.train_dir))
        while True:
            # Train discriminator (specgan_disc_nupdates updates per G update).
            # FIX: `xrange` is Python-2-only; use `range` (consistent with the
            # other train() in this file).
            for i in range(args.specgan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
def train(fps, args):
    """Train a conditional SpecGAN variant (encoder + pitch condition).

    Draws two independent real batches ("right" and matched-negative
    "wrong"), encodes them, conditions the generator on the right encoding
    plus a pitch vector, and trains a three-way discriminator
    (real-right / real-wrong / fake-right) under the dcgan loss.

    FIXES in this revision (the original raised NameError at graph build):
      * `static_condition` -> `right_static_condition` in the generator input;
      * summary section used undefined `x` / `x_wav` -> `right_x` / `right_x_wav`;
      * lsgan/wgan/wgan-gp branches used undefined `D_G_z` / `D_x` / `x`
        -> mapped to this variant's `fake_logits` / `real_logits` / `right_x`;
      * Python-2-only `xrange` -> `range`.

    Args:
        fps: list of input file paths fed to the loader.
        args: parsed CLI namespace (batch size, loss name, data moments, ...).
    """
    with tf.name_scope('loader'):
        # Two independent batches: "right" is the conditioning/real batch,
        # "wrong" provides mismatched negatives for the discriminator.
        right_x_wav = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN,
                                       args.data_first_window)
        right_x = t_to_f(right_x_wav, args.data_moments_mean,
                         args.data_moments_std)
        wrong_x_wav = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN,
                                       args.data_first_window)
        wrong_x = t_to_f(wrong_x_wav, args.data_moments_mean,
                         args.data_moments_std)

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, _D_Z], -1., 1.,
                          dtype=tf.float32)

    # static_condition means pitch
    right_static_condition = tf.random_uniform(
        [args.train_batch_size, _STATIC_PITCH_DIM], -1., 1., dtype=tf.float32)
    wrong_static_condition = tf.random_uniform(
        [args.train_batch_size, _STATIC_PITCH_DIM], -1., 1., dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        # Encode the spectrum into a vector.
        En_right_x = SpecGANEncoder(right_x)
        En_wrong_x = SpecGANEncoder(wrong_x)
        # FIX: was the undefined name `static_condition`; the generator is
        # conditioned on the *right* batch's pitch condition.
        Condition_z = tf.concat([En_right_x, z, right_static_condition], 1)
        G_z, G_z_static = SpecGANGenerator(Condition_z,
                                           train=True,
                                           **args.specgan_g_kwargs)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary (per-variable shapes and total parameter count).
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)  # product of dims
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams,
                                                (float(nparams) * 4) / (1024 * 1024)))

    # Summarize.  FIX: was undefined `x` / `x_wav`; the real batch in this
    # variant is `right_x` / `right_x_wav`.
    x_gl = f_to_t(right_x, args.data_moments_mean, args.data_moments_std,
                  args.specgan_ngl)
    G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std,
                    args.specgan_ngl)
    tf.summary.audio('x_wav', right_x_wav, _FS)
    tf.summary.audio('x', x_gl, _FS)
    tf.summary.audio('G_z', G_z_gl, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.image('x', f_to_img(right_x))
    tf.summary.image('G_z', f_to_img(G_z))

    # Real input to discriminator.
    # NOTE(review): dynamic_x / static_x are random-uniform stand-ins, not
    # real data — presumably placeholders for articulatory targets; confirm.
    dynamic_x = tf.random_uniform([args.train_batch_size, 128, 128, 1], -1.,
                                  1., dtype=tf.float32)
    static_x = tf.random_uniform([args.train_batch_size, _STATIC_TRACT_DIM],
                                 -1., 1., dtype=tf.float32)

    # Make real-right discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        real_logits = SpecGANDiscriminator(dynamic_x, static_x, En_right_x,
                                           right_static_condition,
                                           **args.specgan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams,
                                                (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    # Make real-wrong discriminator (real input, mismatched condition).
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        wrong_logits = SpecGANDiscriminator(dynamic_x, static_x, En_wrong_x,
                                            wrong_static_condition,
                                            **args.specgan_d_kwargs)

    # Make fake-right discriminator (generated input, matching condition).
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        fake_logits = SpecGANDiscriminator(G_z, G_z_static, En_right_x,
                                           right_static_condition,
                                           **args.specgan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.specgan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
                                                    labels=real))
        real_D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits,
                                                    labels=real))
        wrong_D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=wrong_logits,
                                                    labels=fake))
        fake_D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
                                                    labels=fake))
        # Negatives (mismatched + generated) weighted half each.
        D_loss = real_D_loss + (wrong_D_loss + fake_D_loss) / 2.
    elif args.specgan_loss == 'lsgan':
        # FIX: was undefined `D_G_z` / `D_x`; this variant's logits are
        # `fake_logits` / `real_logits` (the wrong-pair term is unused here
        # — NOTE(review): confirm that is intended for lsgan).
        G_loss = tf.reduce_mean((fake_logits - 1.)**2)
        D_loss = tf.reduce_mean((real_logits - 1.)**2)
        D_loss += tf.reduce_mean(fake_logits**2)
        D_loss /= 2.
    elif args.specgan_loss == 'wgan':
        G_loss = -tf.reduce_mean(fake_logits)
        D_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

        # Weight clipping op, run after each D update (Lipschitz constraint).
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.specgan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(fake_logits)
        D_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - right_x
        interpolates = right_x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            # NOTE(review): this call passes a single input, but the
            # conditional discriminator above takes four; this branch has
            # not been ported to the conditional signature and will fail —
            # only 'dcgan' is consistent with this variant. Left flagged
            # rather than guessed at.
            D_interp = SpecGANDiscriminator(interpolates,
                                            **args.specgan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer for each loss type.
    if args.specgan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.specgan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.specgan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.specgan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops; only the G op advances the global step.
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        while True:
            # Train discriminator.  FIX: `xrange` is Python-2-only.
            for i in range(args.specgan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)