def infer(args):
    """Build the inference graph for the label-conditioned generator and export it.

    Creates ``<args.train_dir>/infer`` if needed and writes two files there:
    ``infer.pbtxt`` (the GraphDef) and ``infer.meta`` (a MetaGraph whose saver
    covers only the variables under scope 'G' plus the global step). The
    default graph is reset afterwards so later code starts from a clean graph.

    Named graph entry points: placeholders 'samp_z_n', 'z', 'flat_pad'; ops
    'samp_z', 'G_z', 'G_z_flat', 'G_z_int16', 'G_z_flat_int16'.

    Args:
        args: Parsed options; reads ``train_dir``, ``wavegan_g_kwargs``,
            ``wavegan_genr_pp`` and ``wavegan_genr_pp_len``.
    """
    out_dir = os.path.join(args.train_dir, 'infer')
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # Convenience subgraph: `samp_z_n` uniform noise vectors between -1 and 1.
    # It yields only the _D_Z noise columns, whereas 'z' below expects
    # _D_Z + _D_Y columns -- presumably the caller appends the label part.
    n_samples = tf.placeholder(tf.int32, [], name='samp_z_n')
    tf.random_uniform([n_samples, _D_Z], -1.0, 1.0, dtype=tf.float32, name='samp_z')

    # Generator input (noise + conditioning columns) and per-example padding.
    latent = tf.placeholder(tf.float32, [None, _D_Z + _D_Y], name='z')
    pad_len = tf.placeholder(tf.int32, [], name='flat_pad')

    # Generator and optional post-processing filter, both under scope 'G' so
    # the saver below also picks up the filter weights.
    with tf.variable_scope('G'):
        audio = WaveGANGenerator(latent, train=False, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                audio = tf.layers.conv1d(
                    audio, 1, args.wavegan_genr_pp_len,
                    use_bias=False, padding='same')
    audio = tf.identity(audio, name='G_z')

    # Join the batch into one signal: pad every example with `flat_pad`
    # trailing zeros along axis 1, then merge batch and time into one axis.
    channels = int(audio.get_shape()[-1])
    padded = tf.pad(audio, [[0, 0], [0, pad_len], [0, 0]])
    flat = tf.reshape(padded, [-1, channels], name='G_z_flat')

    def to_int16(signal, name=None):
        # Scale by 32767, clip to the symmetric int16 range, then cast.
        clipped = tf.clip_by_value(signal * 32767., -32767., 32767.)
        return tf.cast(clipped, tf.int16, name=name)

    # Only the named int16 ops matter; the Python handles are not used below.
    to_int16(audio, name='G_z_int16')
    to_int16(flat, name='G_z_flat_int16')

    # Saver restricted to generator weights and the global step.
    gen_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(gen_vars + [step])

    tf.train.write_graph(tf.get_default_graph(), out_dir, 'infer.pbtxt')
    tf.train.export_meta_graph(
        filename=os.path.join(out_dir, 'infer.meta'),
        clear_devices=True,
        saver_def=saver.as_saver_def())

    # Reset graph (in case training afterwards).
    tf.reset_default_graph()
def load_wavegan(self, slice_len=16384, model_size=32):
    """Build a WaveGAN generator and load its weights from ``self.wavegan_path``.

    The checkpoint is read with ``torch.load`` and only its ``'generator'``
    entry (a state dict) is used. The resulting module is stored in
    ``self.generator`` and placed on ``self.device``.

    Args:
        slice_len: Passed through to ``WaveGANGenerator``.
        model_size: Passed through to ``WaveGANGenerator``.
    """
    self.generator = WaveGANGenerator(slice_len=slice_len, model_size=model_size,
                                      use_batch_norm=False, num_channels=1)
    checkpoint = torch.load(self.wavegan_path, map_location=self.device)
    self.generator.load_state_dict(checkpoint['generator'])
    # map_location only decides where the *loaded* tensors live;
    # load_state_dict copies them into the module's existing parameters,
    # so the module itself must be moved to run on self.device.
    self.generator.to(self.device)
def train(fps, args):
    """Train the label-conditioned WaveGAN until the process is stopped.

    Builds the generator/discriminator graphs, the loss selected by
    ``args.wavegan_loss`` ('dcgan', 'lsgan', 'wgan' or 'wgan-gp') and that
    loss's recommended optimizer. It then alternates
    ``args.wavegan_disc_nupdates`` discriminator steps with one generator step
    inside a ``MonitoredTrainingSession`` rooted at ``args.train_dir``. The
    loop never returns on its own.

    Args:
        fps: Training file paths, handed to ``loader.get_batch``.
        args: Parsed options.

    Raises:
        NotImplementedError: for an unknown ``args.wavegan_loss``.
    """

    def _report_params(title, variables):
        # Print shape/size/name per variable, then the total parameter count
        # and its size assuming 4 bytes per parameter.
        print('-' * 80)
        print(title)
        total = 0
        for var in variables:
            shape = var.get_shape().as_list()
            count = reduce(lambda a, b: a * b, shape)
            total += count
            print('{} ({}): {}'.format(shape, count, var.name))
        print('Total params: {} ({:.2f} MB)'.format(
            total, (float(total) * 4) / (1024 * 1024)))

    batch = args.train_batch_size

    with tf.name_scope('loader'):
        x, y = loader.get_batch(fps, batch, _WINDOW_LEN, args.data_first_window,
                                labels=True)

    # Conditioning: the label vector (with a trailing axis added) is appended
    # to the real audio along axis 1, and to the latent noise along axis 1.
    # NOTE(review): G_z is fed to D without a label appended; presumably the
    # generator/discriminator account for this -- confirm the shapes agree.
    label_col = tf.expand_dims(y, axis=2)
    noise = tf.random_uniform([batch, _D_Z], -1., 1., dtype=tf.float32)
    x = tf.concat([x, label_col], 1)
    z = tf.concat([noise, y], 1)

    # Generator (optional post-filter lives inside scope 'G' so it trains too).
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len,
                                       use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
    _report_params('Generator vars', G_vars)

    # Audio and loudness summaries (RMS of channel 0 per example).
    tf.summary.audio('x', x, _FS)
    tf.summary.audio('G_z', G_z, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Discriminator on real data creates the 'D' variables ...
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
    _report_params('Discriminator vars', D_vars)
    print('-' * 80)

    # ... and is reused on generated data.
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    loss_kind = args.wavegan_loss
    D_clip_weights = None
    if loss_kind == 'dcgan':
        zeros = tf.zeros([batch], dtype=tf.float32)
        ones = tf.ones([batch], dtype=tf.float32)
        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=ones))
        D_loss = (tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=zeros))
                  + tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=ones))) / 2.
    elif loss_kind == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
        D_loss = (tf.reduce_mean((D_x - 1.) ** 2) + tf.reduce_mean(D_G_z ** 2)) / 2.
    elif loss_kind == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        # Weight clipping keeps the critic (roughly) Lipschitz.
        with tf.name_scope('D_clip_weights'):
            D_clip_weights = tf.group(*[
                tf.assign(var, tf.clip_by_value(var, -.01, .01))
                for var in D_vars])
    elif loss_kind == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        # Gradient penalty on random interpolates between real and fake.
        alpha = tf.random_uniform(shape=[batch, 1, 1], minval=0., maxval=1.)
        interpolates = x + alpha * (G_z - x)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)
        gp_weight = 10
        grads = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), reduction_indices=[1, 2]))
        D_loss += gp_weight * tf.reduce_mean((slopes - 1.) ** 2.)
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Recommended optimizer per loss; G and D each get their own instance.
    make_opt = {
        'dcgan': lambda: tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5),
        'lsgan': lambda: tf.train.RMSPropOptimizer(learning_rate=1e-4),
        'wgan': lambda: tf.train.RMSPropOptimizer(learning_rate=5e-5),
        'wgan-gp': lambda: tf.train.AdamOptimizer(
            learning_rate=1e-4, beta1=0.5, beta2=0.9),
    }.get(loss_kind)
    if make_opt is None:
        raise NotImplementedError()
    G_opt = make_opt()
    D_opt = make_opt()

    G_train_op = G_opt.minimize(
        G_loss, var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        while True:
            for _ in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)
            sess.run(G_train_op)
def initialize(self, opt):
    """Create the CycleGAN networks and, when training, losses and optimizers.

    Always builds two WaveGAN generators (``netG_A``, ``netG_B``). When
    ``self.isTrain`` is set it also builds two discriminators, two fake-audio
    pools, the GAN/cycle/identity criteria, one Adam optimizer for both
    generators and one per discriminator, plus an LR scheduler per optimizer.
    Finally it logs the network summaries.

    Args:
        opt: Options object; reads model_size, ngpus, latent_dim, alpha,
            post_proc_filt_len, shift_factor, batch_shuffle, pool_size,
            gan_loss, lr, beta1 and beta2.
    """
    BaseModel.initialize(self, opt)

    # load/define networks
    # The naming convention differs from the one used in the paper:
    # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    gen_kwargs = dict(model_size=opt.model_size, ngpus=opt.ngpus,
                      latent_dim=opt.latent_dim, alpha=opt.alpha,
                      post_proc_filt_len=opt.post_proc_filt_len)
    self.netG_A = WaveGANGenerator(**gen_kwargs)
    self.netG_B = WaveGANGenerator(**gen_kwargs)

    if self.isTrain:
        disc_kwargs = dict(model_size=opt.model_size, ngpus=opt.ngpus,
                           shift_factor=opt.shift_factor, alpha=opt.alpha,
                           batch_shuffle=opt.batch_shuffle)
        self.netD_A = WaveGANDiscriminator(**disc_kwargs)
        self.netD_B = WaveGANDiscriminator(**disc_kwargs)

        self.fake_A_pool = AudioPool(opt.pool_size)
        self.fake_B_pool = AudioPool(opt.pool_size)
        # define loss functions
        self.criterionGAN = GANLoss(loss_type=opt.gan_loss, tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers (shared hyper-parameters for all three)
        adam_kwargs = dict(lr=opt.lr, betas=(opt.beta1, opt.beta2))
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
            **adam_kwargs)
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), **adam_kwargs)
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), **adam_kwargs)
        self.optimizers = [self.optimizer_G, self.optimizer_D_A, self.optimizer_D_B]
        self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]

    LOGGER.info('---------- Networks initialized -------------')
    print_network(self.netG_A)
    print_network(self.netG_B)
    if self.isTrain:
        print_network(self.netD_A)
        print_network(self.netD_B)
    LOGGER.info('-----------------------------------------------')
class CycleGANModel(BaseModel):
    """CycleGAN for unpaired audio translation built from WaveGAN networks.

    G_A translates A -> B and G_B translates B -> A. D_A is trained on real_B
    vs. (pooled) fake_B, and D_B on real_A vs. (pooled) fake_A. Generator
    losses are the adversarial loss chosen by ``opt.gan_loss``, L1 cycle
    consistency and an optional L1 identity term. ``opt.gan_loss == 'wgan-wp'``
    adds a gradient penalty to each discriminator update.

    Uses the PyTorch >= 0.4 API (``non_blocking=``, ``torch.no_grad()``,
    ``Tensor.item()``). The older ``async=`` keyword is a SyntaxError on
    Python 3.7+.
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Create networks and, when training, pools, losses, optimizers and schedulers.

        Args:
            opt: Options object; reads model_size, ngpus, latent_dim, alpha,
                post_proc_filt_len, shift_factor, batch_shuffle, pool_size,
                gan_loss, lr, beta1 and beta2.
        """
        BaseModel.initialize(self, opt)

        # load/define networks
        # The naming convention differs from the one used in the paper:
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        gen_kwargs = dict(model_size=opt.model_size, ngpus=opt.ngpus,
                          latent_dim=opt.latent_dim, alpha=opt.alpha,
                          post_proc_filt_len=opt.post_proc_filt_len)
        self.netG_A = WaveGANGenerator(**gen_kwargs)
        self.netG_B = WaveGANGenerator(**gen_kwargs)

        if self.isTrain:
            disc_kwargs = dict(model_size=opt.model_size, ngpus=opt.ngpus,
                               shift_factor=opt.shift_factor, alpha=opt.alpha,
                               batch_shuffle=opt.batch_shuffle)
            self.netD_A = WaveGANDiscriminator(**disc_kwargs)
            self.netD_B = WaveGANDiscriminator(**disc_kwargs)

            self.fake_A_pool = AudioPool(opt.pool_size)
            self.fake_B_pool = AudioPool(opt.pool_size)
            # define loss functions
            self.criterionGAN = GANLoss(loss_type=opt.gan_loss, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers (shared hyper-parameters for all three)
            adam_kwargs = dict(lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                **adam_kwargs)
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), **adam_kwargs)
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), **adam_kwargs)
            self.optimizers = [self.optimizer_G, self.optimizer_D_A, self.optimizer_D_B]
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]

        LOGGER.info('---------- Networks initialized -------------')
        print_network(self.netG_A)
        print_network(self.netG_B)
        if self.isTrain:
            print_network(self.netD_A)
            print_network(self.netD_B)
        LOGGER.info('-----------------------------------------------')

    def set_input(self, input):
        """Store the A/B batch (swapped when which_direction != 'AtoB') and paths."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async` is a reserved word since Python 3.7; PyTorch >= 0.4
            # names this option `non_blocking`.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.audio_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation cycles without building an autograd graph."""
        # `Variable(..., volatile=True)` is a silent no-op on PyTorch >= 0.4;
        # torch.no_grad() is the supported way to disable gradient tracking.
        with torch.no_grad():
            fake_B = self.netG_A(self.input_A)
            self.rec_A = self.netG_B(fake_B).data
            self.fake_B = fake_B.data

            fake_A = self.netG_B(self.input_B)
            self.rec_B = self.netG_A(fake_A).data
            self.fake_A = fake_A.data

    # get audio paths
    def get_audio_paths(self):
        return self.audio_paths

    def backward_D_basic(self, netD, real, fake):
        """Back-propagate the averaged real/fake GAN loss for `netD` and return it."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so no gradient reaches the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_wp(self, netD, real, fake):
        """Back-propagate the WGAN gradient penalty for `netD` and return it."""
        loss_D_wp = calc_gradient_penalty(netD, real, fake, self.opt.batchSize,
                                          self.opt.lambda_wp,
                                          use_cuda=len(self.gpu_ids) > 0)
        loss_D_wp.backward()
        return loss_D_wp

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.item()
        if self.opt.gan_loss == 'wgan-wp':
            loss_D_wp_A = self.backward_D_wp(self.netD_A, self.real_B, fake_B)
            self.loss_D_wp_A = loss_D_wp_A.item()
        else:
            self.loss_D_wp_A = 0

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()
        if self.opt.gan_loss == 'wgan-wp':
            loss_D_wp_B = self.backward_D_wp(self.netD_B, self.real_A, fake_A)
            self.loss_D_wp_B = loss_D_wp_B.item()
        else:
            self.loss_D_wp_B = 0

    def backward_G(self):
        """Back-propagate the combined generator loss and cache outputs/loss values."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B

        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    def optimize_parameters(self):
        """One training step: generators first, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        if self.opt.gan_loss == 'wgan-wp':
            ret_errors['WP_A'] = self.loss_D_wp_A
            ret_errors['WP_B'] = self.loss_D_wp_B
        return ret_errors

    def get_current_audibles(self):
        real_A = tensor2audio(self.input_A)
        fake_B = tensor2audio(self.fake_B)
        rec_A = tensor2audio(self.rec_A)
        real_B = tensor2audio(self.input_B)
        fake_A = tensor2audio(self.fake_A)
        rec_B = tensor2audio(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = tensor2audio(self.idt_A)
            ret_visuals['idt_B'] = tensor2audio(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
def train(fps, args):
    """Train WaveGAN with an extra targeted adversarial loss on a speech classifier.

    The usual GAN objective (selected by ``args.wavegan_loss``) is kept. In
    addition, the generator output, scaled by ``args.adv_magnitude``, is added
    to the fixed clip ``args.adv_input``. MFCCs of the result are fed to the
    frozen classifier graph ``args.adv_model``, and a Carlini-Wagner style
    margin loss pushes its prediction toward ``args.adv_target``. The generator
    minimizes ``G_loss + args.adv_lambda * adv_loss``. Training runs in a
    ``MonitoredTrainingSession`` rooted at ``args.train_dir`` and loops until
    the process is stopped.

    Args:
        fps: Audio file paths passed to ``loader.decode_extract_and_batch``.
        args: Parsed options.

    Raises:
        ValueError: if the adversarial clip's sample rate, dtype or rank is
            unsupported, or ``args.adv_target`` is not listed in
            ``args.adv_label``.
        NotImplementedError: for an unknown ``args.wavegan_loss``.
    """

    def _print_var_summary(title, variables):
        # Shape/size/name per variable, then the total count and its size
        # assuming 4 bytes per parameter.
        print('-' * 80)
        print(title)
        nparams = 0
        for v in variables:
            v_shape = v.get_shape().as_list()
            v_n = reduce(lambda a, b: a * b, v_shape)
            nparams += v_n
            print('{} ({}): {}'.format(v_shape, v_n, v.name))
        print('Total params: {} ({:.2f} MB)'.format(
            nparams, (float(nparams) * 4) / (1024 * 1024)))

    with tf.name_scope('loader'):
        x = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=args.data_slice_len,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=not args.data_first_slice,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0. if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim],
                          -1., 1., dtype=tf.float32)

    # Make generator (post-filter inside scope 'G' so it is trained too)
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len,
                                       use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
    _print_var_summary('Generator vars', G_vars)

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
    _print_var_summary('Discriminator vars', D_vars)
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)
        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))
        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))
        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
        D_loss = tf.reduce_mean((D_x - 1.) ** 2)
        D_loss += tf.reduce_mean(D_G_z ** 2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(tf.assign(
                    var, tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0., maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)
        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    # Load adversarial input. Explicit checks instead of `assert`, which is
    # stripped when Python runs with -O.
    fs, audio = wavread(args.adv_input)
    if fs != args.data_sample_rate:
        raise ValueError('adv_input sample rate {} != data_sample_rate {}'.format(
            fs, args.data_sample_rate))
    if audio.dtype != np.float32:
        raise ValueError('adv_input must be float32, got {}'.format(audio.dtype))
    if len(audio.shape) != 1:
        raise ValueError('adv_input must be 1-D, got shape {}'.format(audio.shape))

    # Synthesis: zero-pad/truncate the carrier clip to one slice and add the
    # scaled generator output (G_z reshaped from [B, T, 1] to [B, T]).
    if audio.shape[0] < args.data_slice_len:
        audio = np.pad(audio, (0, args.data_slice_len - audio.shape[0]), 'constant')
    adv_input = tf.constant(
        audio[:args.data_slice_len], dtype=np.float32
    ) + args.adv_magnitude * tf.reshape(G_z, G_z.get_shape().as_list()[:-1])

    # Calculate MFCCs
    spectrograms = tf.abs(tf.signal.stft(adv_input, frame_length=320, frame_step=160))
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        40, spectrograms.shape[-1].value, fs, 20, 4000)
    mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
    mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
        linear_to_mel_weight_matrix.shape[-1:]))
    log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)
    # NOTE(review): index -1 on the batch axis keeps only the LAST example's
    # MFCCs, so the classifier (and adv_loss) sees one example per step --
    # confirm this is intended.
    mfccs = tf.expand_dims(
        tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms)[-1, :99, :40],
        -1)

    # Load a model for speech command classification
    with tf.gfile.FastGFile(args.adv_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.variable_scope('Speech'):
        adv_logits, = tf.import_graph_def(graph_def, input_map={'Mfcc:0': mfccs},
                                          return_elements=['add_2:0'])

    # Load labels for speech command classification (file closed after reading)
    with tf.gfile.GFile(args.adv_label) as label_file:
        adv_labels = [line.rstrip() for line in label_file]
    adv_index = adv_labels.index(args.adv_target)

    # Make adversarial loss: max(0, max_{j != target} Z_j - Z_target + confidence).
    # Came from: https://github.com/carlini/nn_robust_attacks/blob/master/l2_attack.py
    adv_targets = tf.one_hot(
        tf.constant([adv_index] * args.train_batch_size, dtype=tf.int32),
        len(adv_labels))
    adv_target_logit = tf.reduce_sum(adv_targets * adv_logits, 1)
    # Subtracting 10000 at the target index excludes it from the max.
    adv_others_logit = tf.reduce_max(
        (1 - adv_targets) * adv_logits - (adv_targets * 10000), 1)
    adv_loss = tf.reduce_mean(tf.maximum(
        0.0, adv_others_logit - adv_target_logit + args.adv_confidence))

    # Summarize audios
    tf.summary.audio('adv_input', adv_input, fs, max_outputs=args.adv_max_outputs)
    tf.summary.scalar('adv_loss', adv_loss)
    tf.summary.histogram('adv_classes', tf.argmax(adv_logits, axis=1))

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss + args.adv_lambda * adv_loss, var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            config=config,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print('Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
              .format(args.train_dir))
        while True:
            # Train discriminator
            for _ in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)
                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)
            # Train generator
            sess.run(G_train_op)
print('Saving configurations...') config_path = os.path.join(model_dir, 'config.json') with open(config_path, 'w') as f: json.dump(args, f) # Try on some training data print('Loading audio data...') audio_filepaths = get_all_audio_filepaths(args['audio_dir']) train_gen, valid_data, test_data \ = create_data_split(audio_filepaths, args['valid_ratio'], args['test_ratio'], batch_size, batch_size, batch_size) print('Creating models...') model_gen = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim, post_proc_filt_len=args['post_proc_filt_len'], upsample=True) model_dis = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus, alpha=args['alpha'], shift_factor=args['shift_factor'], batch_shuffle=args['batch_shuffle']) print('Starting training...') model_gen, model_dis, history, final_discr_metrics, samples = train_wgan( model_gen=model_gen, model_dis=model_dis, train_gen=train_gen, valid_data=valid_data, test_data=test_data,
def train(fps, args):
    """Train a WaveGAN variant whose generator is driven through a CuDNN GRU.

    Besides the regular training sample ``G_z``, a longer monitoring-only
    sample ``G_z_long`` is generated with shared weights. When
    ``args.use_sequence`` is set, it gets ``16 * extra_secs`` extra latent
    steps. Each generator step prints the step and both losses. After the
    first step, the generator weights are pickled to
    ``saved_G_vars_iteration-<step>.pkl`` in the working directory. The loop
    runs until the process is stopped.

    Args:
        fps: Training file paths, handed to ``loader.get_batch``.
        args: Parsed options.

    Raises:
        NotImplementedError: for an unknown ``args.wavegan_loss``.
    """

    def _print_var_summary(title, variables):
        # Shape/size/name per variable, then the total count and its size
        # assuming 4 bytes per parameter.
        print('-' * 80)
        print(title)
        nparams = 0
        for v in variables:
            v_shape = v.get_shape().as_list()
            v_n = reduce(lambda a, b: a * b, v_shape)
            nparams += v_n
            print('{} ({}): {}'.format(v_shape, v_n, v.name))
        print('Total params: {} ({:.2f} MB)'.format(
            nparams, (float(nparams) * 4) / (1024 * 1024)))

    with tf.name_scope('loader'):
        x = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window)

    # Make z vector: a 16-step latent sequence when use_sequence is set,
    # otherwise one latent vector per example.
    if args.use_sequence:
        z = tf.random_uniform([args.train_batch_size, 16, args.d_z], -1., 1., dtype=tf.float32)
    else:
        z = tf.random_uniform([args.train_batch_size, args.d_z], -1., 1., dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        gru_layer = tf.keras.layers.CuDNNGRU(args.d_z, return_sequences=True)
        G_z, _ = WaveGANGenerator(z, gru_layer=gru_layer, train=True, return_gru=True,
                                  reuse=False, use_sequence=args.use_sequence,
                                  **args.wavegan_g_kwargs)
        print('G_z.shape:', G_z.get_shape().as_list())
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len,
                                       use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
    G_var_names = [g_var.name for g_var in G_vars]
    _print_var_summary('Generator vars', G_vars)

    # Longer sample, used only for the 'G_z_long' audio summary below.
    extra_secs = 1
    if args.use_sequence:
        added_noise = tf.random_uniform([args.train_batch_size, 16 * extra_secs, args.d_z],
                                        -1., 1., dtype=tf.float32)
        z_feed_long = tf.concat([z, added_noise], axis=1)
    else:
        z_feed_long = z
    with tf.variable_scope('G', reuse=True):
        G_z_long, _ = WaveGANGenerator(z_feed_long, gru_layer=gru_layer, train=False,
                                       length=16 * extra_secs, return_gru=True, reuse=True,
                                       use_sequence=args.use_sequence,
                                       **args.wavegan_g_kwargs)
        print('G_z_long.shape:', G_z_long.get_shape().as_list())
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt', reuse=True):
                G_z_long = tf.layers.conv1d(G_z_long, 1, args.wavegan_genr_pp_len,
                                            use_bias=False, padding='same')

    # Summarize
    tf.summary.audio('x', x, _FS)
    tf.summary.audio('G_z', G_z, _FS)
    tf.summary.audio('G_z_long', G_z_long, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
    print('D_vars:', D_vars)
    _print_var_summary('Discriminator vars', D_vars)
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)
        G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_G_z, labels=real))
        D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_x, labels=real))
        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
        D_loss = tf.reduce_mean((D_x - 1.) ** 2)
        D_loss += tf.reduce_mean(D_G_z ** 2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(tf.assign(
                    var, tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0., maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)
        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        print('gradients:', gradients)
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # The LR schedule needs the global step as a *scalar tensor*.
    # tf.get_collection() returns a Python list (empty until the step has been
    # created), which exponential_decay would turn into a non-scalar learning
    # rate that the Adam update ops reject. Create the step up front and reuse it.
    global_step = tf.train.get_or_create_global_step()

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        # Halve the learning rate every 100k steps (continuous decay).
        my_learning_rate = tf.train.exponential_decay(
            1e-4, global_step, decay_steps=100000, decay_rate=0.5)
        G_opt = tf.train.AdamOptimizer(learning_rate=my_learning_rate, beta1=0.5, beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=my_learning_rate, beta1=0.5, beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(G_loss, var_list=G_vars, global_step=global_step)
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    saver = tf.train.Saver(max_to_keep=10)

    # Run training
    with tf.train.MonitoredTrainingSession(
            scaffold=tf.train.Scaffold(saver=saver),
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        iterator_count = 0
        while True:
            # Train discriminator
            for _ in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)
                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            _, g_losses, d_losses, global_step_ = sess.run(
                [G_train_op, G_loss, D_loss, global_step])
            print('i:', global_step_, 'G_loss:', g_losses, 'D_loss:', d_losses)

            # Snapshot generator weights (name -> value) after the first step.
            if iterator_count == 0:
                G_vars_np = sess.run(G_vars)
                G_var_dict = dict(zip(G_var_names, G_vars_np))
                with open('saved_G_vars_iteration-{}.pkl'.format(global_step_), 'wb') as f:
                    pickle.dump(G_var_dict, f)
            iterator_count += 1
def train(fps, args):
    """Build the WaveGAN training graph and run the training loop.

    Args:
      fps: Audio file paths handed to ``loader.decode_extract_and_batch``.
      args: Parsed options; the ``data_*``, ``train_*`` and ``wavegan_*``
        attributes read below configure the data pipeline, the models, the
        loss/optimizer choice and checkpointing.

    Raises:
      NotImplementedError: If ``args.wavegan_loss`` is not one of 'dcgan',
        'lsgan', 'wgan' or 'wgan-gp'.

    Training runs until the ``MonitoredTrainingSession`` requests a stop or
    the process is interrupted; checkpoints and summaries go to
    ``args.train_dir``.
    """

    def _print_var_summary(title, var_list):
        """Print each variable's shape, element count and name, then the total."""
        print('-' * 80)
        print(title)
        nparams = 0
        for v in var_list:
            v_shape = v.get_shape()
            # num_elements() is 1 for scalar variables, where reduce() over an
            # empty shape list (without an initial value) would raise.
            v_n = v_shape.num_elements()
            nparams += v_n
            print('{} ({}): {}'.format(v_shape.as_list(), v_n, v.name))
        # Size estimate assumes 4 bytes (float32) per parameter.
        print('Total params: {} ({:.2f} MB)'.format(
            nparams, (float(nparams) * 4) / (1024 * 1024)))

    with tf.name_scope('loader'):
        # NOTE(review): the trailing [:, :, 0] drops axis 2 of the loader
        # output; presumably a singleton axis -- confirm against loader.py.
        x = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=args.data_slice_len,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0. if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

    # Make z vector (uniform in [-1, 1])
    z = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim],
                          -1., 1., dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len,
                                       use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    _print_var_summary('Generator vars', G_vars)

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    _print_var_summary('Discriminator vars', D_vars)
    print('-' * 80)

    # Make fake discriminator (shares weights with the real one)
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))
        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Weight clipping op, run after every D update below.
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(var, tf.clip_by_value(var, clip_bounds[0],
                                                    clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Gradient penalty on random interpolates between real and fake.
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0., maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops; only the G step advances the global step.
    G_train_op = G_opt.minimize(
        G_loss, var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print('Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
              .format(args.train_dir))
        counter = 0
        # Honour stop requests from the session's hooks instead of `while True`.
        while not sess.should_stop():
            # Train discriminator
            for _ in range(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
            counter += 1
            print("Iteration: " + str(counter))
    print("DONE TRAINING")
def train(fps, args):
    """Train a WaveGAN on a percentage-limited, optionally class-filtered data set.

    Only the first ``args.train_data_percentage`` percent of ``fps`` is used.
    ``args.exclude_class`` then selects the data:

    * ``None``: use that subset as is;
    * ``-1``: keep a 90% split (via ``loader.split_files_test_val``);
    * ``>= 0``: pass the class index through to ``find_data_size`` /
      ``loader.get_batch`` as ``exclude_class``.

    Args:
      fps: File paths of the training data.
      args: Parsed options (``train_*``, ``data_*``, ``wavegan_*``,
        ``exclude_class``, ``stop_at_global_step``).

    Raises:
      ValueError: If ``args.exclude_class`` is a negative value other than -1.
      NotImplementedError: If ``args.wavegan_loss`` is not one of 'dcgan',
        'lsgan', 'wgan' or 'wgan-gp'.
    """

    def _print_var_summary(title, var_list):
        """Print each variable's shape, element count and name, then the total."""
        print('-' * 80)
        print(title)
        nparams = 0
        for v in var_list:
            v_shape = v.get_shape()
            # num_elements() is 1 for scalar variables, where reduce() over an
            # empty shape list (without an initial value) would raise.
            v_n = v_shape.num_elements()
            nparams += v_n
            print('{} ({}): {}'.format(v_shape.as_list(), v_n, v.name))
        # Size estimate assumes 4 bytes (float32) per parameter.
        print('Total params: {} ({:.2f} MB)'.format(
            nparams, (float(nparams) * 4) / (1024 * 1024)))

    train_data_percentage = args.train_data_percentage
    with tf.name_scope('loader'):
        logging.info("Full training datasize = %s", find_data_size(fps, None))
        length = len(fps)
        fps = fps[:(int(train_data_percentage / 100.0 * length))]

        # Default size when no class is excluded; previously this name was
        # left unbound on the `exclude_class is None` path.
        train_dataset_size = find_data_size(fps, None)
        logging.info("GAN training datasize (before exclude) = %s",
                     train_dataset_size)

        if args.exclude_class is None:
            pass
        elif args.exclude_class == -1:
            fps, _ = loader.split_files_test_val(fps, 0.9, 0)
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info(
                "GAN training datasize (after exclude - random sampling) = %s",
                train_dataset_size)
        elif args.exclude_class >= 0:
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info("GAN training datasize (after exclude) = %s",
                         train_dataset_size)
        else:
            raise ValueError(
                "args.exclude_class should be either [0, num_class), None, or -1 for random sampling 90%"
            )

        training_iterator = loader.get_batch(fps, args.train_batch_size,
                                             _WINDOW_LEN, args.data_first_window,
                                             repeat=True, initializable=True,
                                             labels=True,
                                             exclude_class=args.exclude_class)
        x, _ = training_iterator.get_next()  # Important: ignore the labels
        print("x_wav.shape = %s" % str(x.shape))
        logging.info("train_dataset_size = %s", train_dataset_size)

    # Make z vector (uniform in [-1, 1])
    z = tf.random_uniform([args.train_batch_size, _D_Z], -1., 1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len,
                                       use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    _print_var_summary('Generator vars', G_vars)

    # Summarize
    tf.summary.audio('x', x, _FS)
    tf.summary.audio('G_z', G_z, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    _print_var_summary('Discriminator vars', D_vars)
    print('-' * 80)

    # Make fake discriminator (shares weights with the real one)
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))
        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Weight clipping op, run after every D update below.
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(var, tf.clip_by_value(var, clip_bounds[0],
                                                    clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Gradient penalty on random interpolates between real and fake.
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0., maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops; only the G step advances the global step.
    global_step_tensor = tf.train.get_or_create_global_step()
    G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
                                global_step=global_step_tensor)
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training. The iterator is initializable, so it is initialized as
    # part of the local init op and re-initialized whenever it runs dry.
    scaffold = tf.train.Scaffold(
        local_init_op=tf.group(tf.local_variables_initializer(),
                               training_iterator.initializer),
        saver=tf.train.Saver(max_to_keep=5))
    with tf.train.MonitoredTrainingSession(
            hooks=[SaveAtEnd(os.path.join(args.train_dir, 'model'))],
            scaffold=scaffold,
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        while not sess.should_stop():
            global_step = sess.run(global_step_tensor)
            logging.info("Global step: %s", global_step)
            if args.stop_at_global_step != 0 and global_step >= args.stop_at_global_step:
                logging.info(
                    "Stopping because args.stop_at_global_step is set to %s",
                    args.stop_at_global_step)
                break

            # Train discriminator; on exhaustion restart the data iterator and
            # skip the remaining D updates for this step.
            try:
                for _ in range(args.wavegan_disc_nupdates):
                    sess.run(D_train_op)

                    # Enforce Lipschitz constraint for WGAN
                    if D_clip_weights is not None:
                        sess.run(D_clip_weights)
            except tf.errors.OutOfRangeError:
                sess.run(training_iterator.initializer)

            # Train generator
            sess.run(G_train_op)
args['model_dir'] = model_dir # save samples for every N epochs. epochs_per_sample = args['epochs_per_sample'] # gradient penalty regularization factor. lmbda = args['lmbda'] # Dir audio_dir = args['audio_dir'] output_dir = args['output_dir'] # =============Network=============== netG = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim, upsample=True) netD = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus) if cuda: netG = torch.nn.DataParallel(netG).cuda() netD = torch.nn.DataParallel(netD).cuda() # "Two time-scale update rule"(TTUR) to update netD 4x faster than netG. optimizerG = optim.Adam(netG.parameters(), lr=args['learning_rate'], betas=(args['beta1'], args['beta2'])) optimizerD = optim.Adam(netD.parameters(), lr=args['learning_rate'], betas=(args['beta1'], args['beta2'])) # Sample noise used for generated output. sample_noise = torch.randn(args['sample_size'], latent_dim) if cuda: sample_noise = sample_noise.cuda() sample_noise_Var = autograd.Variable(sample_noise, requires_grad=False)
class Word2Wave(nn.Module):
    """Generate audio from a learnable WaveGAN latent, scored against a text prompt.

    A pretrained WaveGAN generator turns ``self.latents`` (an
    ``nn.Parameter``) into a waveform. Pretrained COALA tag/audio encoders
    embed the prompt and the waveform, and ``coala_loss`` returns
    ``(1 - cosine similarity) / 2`` between the two embeddings.
    """

    def __init__(self, args):
        """Load all pretrained models and initialise the latent vector.

        Args:
          args: Object with ``coala_model_name`` (sub-directory / model name of
            the COALA weights) and ``wavegan_path`` (WaveGAN checkpoint path).
        """
        super(Word2Wave, self).__init__()
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.coala_model_name = args.coala_model_name
        self.wavegan_path = args.wavegan_path
        self.load_wavegan()
        self.load_coala()
        self.init_latents()

    def load_wavegan(self, slice_len=16384, model_size=32):
        """Build the WaveGAN generator and load its weights.

        The checkpoint at ``self.wavegan_path`` must be a mapping holding the
        generator state dict under the ``'generator'`` key.
        """
        self.generator = WaveGANGenerator(slice_len=slice_len,
                                          model_size=model_size,
                                          use_batch_norm=False,
                                          num_channels=1)
        checkpoint = torch.load(self.wavegan_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator'])

    def load_coala(self):
        """Load the COALA tag/audio encoders and the tag vocabulary.

        Weights are read from ``coala/models/<coala_model_name>/``. Any weight
        file missing there is first downloaded from the COALA GitHub repo.
        """
        coala_path = os.path.join("coala/models", self.coala_model_name)
        # Use GitHub's `raw` endpoint: the `blob` URL returns an HTML page,
        # which is what produced corrupted .pt downloads before.
        base_url = ("https://github.com/xavierfav/coala/raw/master/"
                    "saved_models/{}/".format(self.coala_model_name))
        tag_encoder_path = os.path.join(coala_path, "tag_encoder_epoch_200.pt")
        audio_encoder_path = os.path.join(coala_path,
                                          "audio_encoder_epoch_200.pt")

        os.makedirs(coala_path, exist_ok=True)
        for weights_path in (tag_encoder_path, audio_encoder_path):
            if os.path.exists(weights_path):
                continue
            url = base_url + os.path.basename(weights_path)
            logging.info("Downloading COALA model weights from {}".format(url))
            # Download to a temporary name first so an interrupted download
            # is not mistaken for a complete file on the next run.
            partial_path = weights_path + ".part"
            urlretrieve(url, partial_path)
            os.replace(partial_path, weights_path)

        self.tag_encoder = TagEncoder()
        self.tag_encoder.load_state_dict(
            torch.load(tag_encoder_path, map_location=self.device))
        # encode_text/encode_audio place their inputs on self.device, so the
        # encoders must live there too.
        self.tag_encoder.to(self.device).eval()

        self.audio_encoder = AudioEncoder()
        self.audio_encoder.load_state_dict(
            torch.load(audio_encoder_path, map_location=self.device))
        self.audio_encoder.to(self.device).eval()

        with open('coala/id2token_top_1000.json', 'rb') as f:
            id2tag = json.load(f)
        self.tag2id = {tag: tag_id for tag_id, tag in id2tag.items()}

    def init_latents(self, size=1, latent_dim=100):
        """Set ``self.latents`` to a learnable (size, latent_dim) N(0, 1) sample."""
        noise = torch.randn(size, latent_dim, dtype=torch.float32)
        self.latents = torch.nn.Parameter(noise)

    def tokenize_text(self, text_prompt):
        """Split ``text_prompt`` on single spaces and map known words to tag ids.

        Returns:
          (tag ids of known words, known words, unknown words), each in prompt
          order.
        """
        words = text_prompt.split(" ")
        words_in_dict = [word for word in words if word in self.tag2id]
        words_not_in_dict = [word for word in words if word not in self.tag2id]
        tokenized_text = [int(self.tag2id[word]) for word in words_in_dict]
        return tokenized_text, words_in_dict, words_not_in_dict

    def encode_text(self, text_prompt):
        """Embed ``text_prompt`` as the mean COALA tag embedding of its known words.

        Raises:
          ValueError: If no word of the prompt is in the tag vocabulary (the
            mean over an empty batch would otherwise be NaN).
        """
        word_ids, _, unknown_words = self.tokenize_text(text_prompt)
        if not word_ids:
            raise ValueError(
                "None of the words in {!r} are in the COALA tag vocabulary "
                "(unknown words: {})".format(text_prompt, unknown_words))
        # One one-hot row over the 1000-tag vocabulary per known word.
        tag_vector = torch.zeros(len(word_ids), 1000).to(self.device)
        for row, word_id in enumerate(word_ids):
            tag_vector[row, word_id] = 1
        _, embedding_d = self.tag_encoder(tag_vector)
        return embedding_d.mean(dim=0)

    def encode_audio(self, audio):
        """Embed a waveform with the COALA audio encoder.

        The preprocessed features are min-max scaled with the pickled COALA
        scaler, which is loaded once and cached on the instance.
        """
        x = preprocess_audio(audio).to(self.device)
        scaler = getattr(self, '_coala_scaler', None)
        if scaler is None:
            # NOTE: pickle can execute code on load; this must stay a trusted,
            # locally shipped file.
            with open('coala/scaler_top_1000.pkl', 'rb') as f:
                scaler = pickle.load(f)
            self._coala_scaler = scaler
        x *= torch.tensor(scaler.scale_).to(self.device)
        x += torch.tensor(scaler.min_).to(self.device)
        x = torch.clamp(x, scaler.feature_range[0], scaler.feature_range[1])
        _, embedding_d = self.audio_encoder(x.unsqueeze(0).unsqueeze(0))
        return embedding_d

    def latent_space_interpolation(self, latents=None, n_samples=1):
        """Generate audio along the line between two latents.

        Uses ``latents[0]``/``latents[1]`` (or two fresh ``sample_noise``
        latents). ``n_samples`` points are taken for alpha in
        ``linspace(0, 1)``; each point is
        ``alpha * z[0] + (1 - alpha) * z[1]``.
        """
        z_test = sample_noise(2) if latents is None else latents
        interpolates = torch.stack([
            alpha * z_test[0] + ((1 - alpha) * z_test[1])
            for alpha in np.linspace(0, 1, n_samples)
        ])
        return self.generator(interpolates)

    def synthesise_audio(self, noise):
        """Run the generator on ``noise`` and flatten the output to 1-D."""
        return self.generator(noise).view(-1)

    def coala_loss(self, audio, text):
        """Return ``(1 - cos_sim(audio_emb, text_emb)) / 2`` for the COALA embeddings."""
        text_embedding = self.encode_text(text)
        audio_embedding = self.encode_audio(audio)
        text_embedding = text_embedding / text_embedding.norm()
        audio_embedding = audio_embedding / audio_embedding.norm()
        cos_dist = (1 - audio_embedding @ text_embedding.t()) / 2
        return cos_dist

    def forward(self, text):
        """Synthesise audio from ``self.latents`` and score it against ``text``.

        Returns:
          (flattened audio, COALA loss).
        """
        audio = self.generator(self.latents).view(-1)
        loss = self.coala_loss(audio, text)
        return audio, loss
def train(fps, args): with tf.name_scope('loader'): x, cond_text, _ = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window, conditionals=True, name='batch') wrong_audio = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window, conditionals=False, name='wrong_batch') # wrong_cond_text, wrong_cond_text_embed = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window, wavs=False, conditionals=True, name='batch') # Make z vector z = tf.random_normal([args.train_batch_size, _D_Z]) embed = hub.Module('https://tfhub.dev/google/elmo/2', trainable=False, name='embed') cond_text_embed = embed(cond_text) # Add conditioning input to the model args.wavegan_g_kwargs['context_embedding'] = cond_text_embed args.wavegan_d_kwargs['context_embedding'] = args.wavegan_g_kwargs[ 'context_embedding'] lod = tf.placeholder(tf.float32, shape=[]) with tf.variable_scope('G'): # Make generator G_z, c_kl_loss = WaveGANGenerator(z, lod, train=True, **args.wavegan_g_kwargs) if args.wavegan_genr_pp: with tf.variable_scope('pp_filt'): G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same') # Summarize G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1)) x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1)) x_rms_lod_4 = tf.sqrt( tf.reduce_mean(tf.square(avg_downsample(x)[:, :, 0]), axis=1)) x_rms_lod_3 = tf.sqrt( tf.reduce_mean(tf.square(avg_downsample(avg_downsample(x))[:, :, 0]), axis=1)) x_rms_lod_2 = tf.sqrt( tf.reduce_mean(tf.square( avg_downsample(avg_downsample(avg_downsample(x)))[:, :, 0]), axis=1)) x_rms_lod_1 = tf.sqrt( tf.reduce_mean(tf.square( avg_downsample(avg_downsample(avg_downsample( avg_downsample(x))))[:, :, 0]), axis=1)) x_rms_lod_0 = tf.sqrt( tf.reduce_mean(tf.square( avg_downsample( avg_downsample( avg_downsample(avg_downsample(avg_downsample(x)))))[:, :, 0]), axis=1)) tf.summary.histogram('x_rms_batch', x_rms) 
tf.summary.histogram('G_z_rms_batch', G_z_rms) tf.summary.scalar('x_rms', tf.reduce_mean(x_rms)) tf.summary.scalar('x_rms_lod_4', tf.reduce_mean(x_rms_lod_4)) tf.summary.scalar('x_rms_lod_3', tf.reduce_mean(x_rms_lod_3)) tf.summary.scalar('x_rms_lod_2', tf.reduce_mean(x_rms_lod_2)) tf.summary.scalar('x_rms_lod_1', tf.reduce_mean(x_rms_lod_1)) tf.summary.scalar('x_rms_lod_0', tf.reduce_mean(x_rms_lod_0)) tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms)) tf.summary.audio('x', x, _FS, max_outputs=10) tf.summary.audio('G_z', G_z, _FS, max_outputs=10) tf.summary.text('Conditioning Text', cond_text[:10]) # with tf.variable_scope('G'): # # Make history buffer # history_buffer = HistoryBuffer(_WINDOW_LEN, args.train_batch_size * 100, args.train_batch_size) # # Select half of batch from history buffer # g_from_history, r_from_history, embeds_from_history = history_buffer.get_from_history_buffer() # new_fake_batch = tf.concat([G_z[:tf.shape(G_z)[0] - tf.shape(g_from_history)[0]], g_from_history], 0) # Use tf.shape to handle case when g_from_history is empty # new_cond_embeds = tf.concat([cond_text_embed[:tf.shape(cond_text_embed)[0] - tf.shape(embeds_from_history)[0]], embeds_from_history], 0) # new_real_batch = tf.concat([x[:tf.shape(x)[0] - tf.shape(r_from_history)[0]], r_from_history], 0) # with tf.control_dependencies([new_fake_batch, new_real_batch, new_cond_embeds]): # with tf.control_dependencies([history_buffer.add_to_history_buffer(G_z, x, cond_text_embed)]): # G_z = tf.identity(new_fake_batch) # x = tf.identity(new_real_batch) # args.wavegan_g_kwargs['context_embedding'] = tf.identity(new_cond_embeds) # args.wavegan_d_kwargs['context_embedding'] = args.wavegan_g_kwargs['context_embedding'] # G_z.set_shape([args.train_batch_size, _WINDOW_LEN, 1]) # x.set_shape([args.train_batch_size, _WINDOW_LEN, 1]) G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G') # Print G summary print('-' * 80) print('Generator vars') nparams = 0 for v in G_vars: 
v_shape = v.get_shape().as_list() v_n = reduce(lambda x, y: x * y, v_shape) nparams += v_n print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024))) # Summarize # tf.summary.scalar('history_buffer_size', history_buffer.current_size) # tf.summary.scalar('g_from_history_size', tf.shape(g_from_history)[0]) # tf.summary.scalar('r_from_history_size', tf.shape(r_from_history)[0]) # tf.summary.scalar('embeds_from_history_size', tf.shape(embeds_from_history)[0]) # tf.summary.audio('G_z_history', g_from_history, _FS, max_outputs=10) # tf.summary.audio('x_history', r_from_history, _FS, max_outputs=10) tf.summary.audio('wrong_audio', wrong_audio, _FS, max_outputs=10) tf.summary.scalar('Conditional Resample - KL-Loss', c_kl_loss) # tf.summary.scalar('embed_error_cosine', tf.reduce_sum(tf.multiply(cond_text_embed, expected_embed)) / (tf.norm(cond_text_embed) * tf.norm(expected_embed))) # tf.summary.scalar('embed_error_cosine_wrong', tf.reduce_sum(tf.multiply(wrong_cond_text_embed, expected_embed)) / (tf.norm(wrong_cond_text_embed) * tf.norm(expected_embed))) # Make real discriminator with tf.name_scope('D_x'), tf.variable_scope('D'): D_x = WaveGANDiscriminator(x, lod, **args.wavegan_d_kwargs) D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D') # Print D summary print('-' * 80) print('Discriminator vars') nparams = 0 for v in D_vars: v_shape = v.get_shape().as_list() v_n = reduce(lambda x, y: x * y, v_shape) nparams += v_n print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024))) print('-' * 80) # Make fake / wrong discriminator with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True): D_G_z = WaveGANDiscriminator(G_z, lod, **args.wavegan_d_kwargs) with tf.name_scope('D_w'), tf.variable_scope('D', reuse=True): D_w = WaveGANDiscriminator(wrong_audio, lod, 
**args.wavegan_d_kwargs) # Create loss D_clip_weights = None if args.wavegan_loss == 'dcgan': fake = tf.zeros([args.train_batch_size, 1], dtype=tf.float32) real = tf.ones([args.train_batch_size, 1], dtype=tf.float32) # Conditional G Loss G_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[0], labels=real)) G_loss += c_kl_loss # Unconditional G Loss if args.use_extra_uncond_loss: G_loss += tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[1], labels=real)) G_loss /= 2 # Conditional D Losses D_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[0], labels=fake)) D_loss_wrong = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_w[0], labels=fake)) D_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x[0], labels=real)) # Unconditional D Losses if args.use_extra_uncond_loss: D_loss_fake_uncond = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[1], labels=fake)) D_loss_wrong_uncond = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_w[1], labels=real)) D_loss_real_uncond = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x[1], labels=real)) D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \ + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond D_loss /= 2 else: D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) # Warmup Conditional Loss # D_warmup_loss = D_loss_real + D_loss_wrong elif args.wavegan_loss == 'lsgan': # Conditional G Loss G_loss = tf.reduce_mean((D_G_z[0] - 1.)**2) G_loss += c_kl_loss # Unconditional G Loss if args.use_extra_uncond_loss: G_loss += tf.reduce_mean((D_G_z[1] - 1.)**2) G_loss /= 2 # Conditional D Loss D_loss_real = tf.reduce_mean((D_x[0] - 1.)**2) D_loss_wrong = tf.reduce_mean(D_w[0]**2) D_loss_fake = tf.reduce_mean(D_G_z[0]**2) # Unconditional D Loss if args.use_extra_uncond_loss: D_loss_real_uncond = tf.reduce_mean((D_x[1] - 1.)**2) 
D_loss_wrong_uncond = tf.reduce_mean((D_w[1] - 1.)**2) D_loss_fake_uncond = tf.reduce_mean(D_G_z[1]**2) D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \ + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond D_loss /= 2 else: D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) # Warmup Conditional Loss # D_warmup_loss = D_loss_real + D_loss_wrong elif args.wavegan_loss == 'wgan': # Conditional G Loss G_loss = -tf.reduce_mean(D_G_z[0]) G_loss += c_kl_loss # Unconditional G Loss if args.use_extra_uncond_loss: G_loss += -tf.reduce_mean(D_G_z[1]) G_loss /= 2 # Conditional D Loss D_loss_real = -tf.reduce_mean(D_x[0]) D_loss_wrong = tf.reduce_mean(D_w[0]) D_loss_fake = tf.reduce_mean(D_G_z[0]) # Unconditional D Loss if args.use_extra_uncond_loss: D_loss_real_uncond = -tf.reduce_mean(D_x[1]) D_loss_wrong_uncond = -tf.reduce_mean(D_w[1]) D_loss_fake_uncond = tf.reduce_mean(D_G_z[1]) D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \ + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond D_loss /= 2 else: D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) # Warmup Conditional Loss # D_warmup_loss = D_loss_real + D_loss_wrong with tf.name_scope('D_clip_weights'): clip_ops = [] for var in D_vars: clip_bounds = [-.01, .01] clip_ops.append( tf.assign( var, tf.clip_by_value(var, clip_bounds[0], clip_bounds[1]))) D_clip_weights = tf.group(*clip_ops) elif args.wavegan_loss == 'wgan-gp': # Conditional G Loss G_loss = -tf.reduce_mean(D_G_z[0]) G_loss += c_kl_loss # Unconditional G Loss if args.use_extra_uncond_loss: G_loss += -tf.reduce_mean(D_G_z[1]) G_loss /= 2 # Conditional D Loss D_loss_real = -tf.reduce_mean(D_x[0]) D_loss_wrong = tf.reduce_mean(D_w[0]) D_loss_fake = tf.reduce_mean(D_G_z[0]) # Unconditional D Loss if args.use_extra_uncond_loss: D_loss_real_uncond = -tf.reduce_mean(D_x[1]) D_loss_wrong_uncond = -tf.reduce_mean(D_w[1]) D_loss_fake_uncond = tf.reduce_mean(D_G_z[1]) D_loss = D_loss_real + 0.5 * 
(D_loss_wrong + D_loss_fake) \ + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond D_loss /= 2 else: D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) # Warmup Conditional Loss # D_warmup_loss = D_loss_real + D_loss_wrong # Conditional Gradient Penalty alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.) real = x fake = tf.concat([ G_z[:args.train_batch_size // 2], wrong_audio[:args.train_batch_size // 2] ], 0) differences = fake - real interpolates = real + (alpha * differences) with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True): D_interp = WaveGANDiscriminator( interpolates, lod, **args.wavegan_d_kwargs)[0] # Only want conditional output gradients = tf.gradients(D_interp, [interpolates])[0] slopes = tf.sqrt( tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2])) cond_gradient_penalty = tf.reduce_mean((slopes - 1.)**2.) # Unconditional Gradient Penalty alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.) real = tf.concat([ x[:args.train_batch_size // 2], wrong_audio[:args.train_batch_size // 2] ], 0) fake = G_z differences = fake - real interpolates = real + (alpha * differences) with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True): D_interp = WaveGANDiscriminator( interpolates, lod, **args.wavegan_d_kwargs)[1] # Only want unconditional output gradients = tf.gradients(D_interp, [interpolates])[0] slopes = tf.sqrt( tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2])) uncond_gradient_penalty = tf.reduce_mean((slopes - 1.)**2.) # Warmup Gradient Penalty # alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.) 
# real = x # fake = wrong_audio # differences = fake - real # interpolates = real + (alpha * differences) # with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True): # D_interp = WaveGANDiscriminator(interpolates, lod, **args.wavegan_d_kwargs)[0] # Only want conditional output # gradients = tf.gradients(D_interp, [interpolates])[0] # slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2])) # warmup_gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.) gradient_penalty = (cond_gradient_penalty + uncond_gradient_penalty) / 2 LAMBDA = 10 D_loss += LAMBDA * gradient_penalty # D_warmup_loss += LAMBDA * warmup_gradient_penalty else: raise NotImplementedError() tf.summary.scalar('G_loss', G_loss) if (args.wavegan_loss == 'wgan-gp'): tf.summary.scalar('Gradient Penalty', LAMBDA * gradient_penalty) if (args.wavegan_loss == 'wgan' or args.wavegan_loss == 'wgan-gp'): if args.use_extra_uncond_loss: tf.summary.scalar('Critic Score - Real Data - Condition Match', -D_loss_real) tf.summary.scalar('Critic Score - Fake Data - Condition Match', D_loss_fake) tf.summary.scalar('Critic Score - Wrong Data - Condition Match', D_loss_wrong) tf.summary.scalar('Critic Score - Real Data', -D_loss_real_uncond) tf.summary.scalar('Critic Score - Wrong Data', -D_loss_wrong_uncond) tf.summary.scalar('Critic Score - Fake Data', D_loss_fake_uncond) tf.summary.scalar('Wasserstein Distance - No Regularization Term', -((D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \ + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond) / 2)) tf.summary.scalar('Wasserstein Distance - Real-Wrong Only', -(D_loss_real + D_loss_wrong)) tf.summary.scalar('Wasserstein Distance - Real-Fake Only', -((D_loss_real + D_loss_fake \ + D_loss_real_uncond + D_loss_fake_uncond) / 2)) else: tf.summary.scalar('Critic Score - Real Data', -D_loss_real) tf.summary.scalar('Critic Score - Wrong Data', D_loss_wrong) tf.summary.scalar('Critic Score - Fake Data', D_loss_fake) 
tf.summary.scalar( 'Wasserstein Distance - No Regularization Term', -(D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake))) tf.summary.scalar('Wasserstein Distance - With Regularization Term', -D_loss) else: if args.use_extra_uncond_loss: tf.summary.scalar('D_acc_uncond', 0.5 * ((0.5 * (tf.reduce_mean(tf.sigmoid(D_x[1])) + tf.reduce_mean(tf.sigmoid(D_w[1])))) \ + tf.reduce_mean(1 - tf.sigmoid(D_G_z[1])))) tf.summary.scalar('D_acc', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \ + 0.5 * (tf.reduce_mean(1 - tf.sigmoid(D_w[0])) + tf.reduce_mean(1 - tf.sigmoid(D_G_z[0]))))) tf.summary.scalar('D_acc_real_wrong_only', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \ + tf.reduce_mean(1 - tf.sigmoid(D_w[0])))) tf.summary.scalar('D_loss_cond_real', D_loss_real) tf.summary.scalar('D_loss_uncond_real', D_loss_real_uncond) tf.summary.scalar('D_loss_cond_wrong', D_loss_wrong) tf.summary.scalar('D_loss_uncond_wrong', D_loss_wrong_uncond) tf.summary.scalar('D_loss_cond_fake', D_loss_fake) tf.summary.scalar('D_loss_uncond_fake', D_loss_fake_uncond) tf.summary.scalar('D_loss_unregularized', (D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \ + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond) / 2) else: tf.summary.scalar('D_acc', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \ + 0.5 * (tf.reduce_mean(1 - tf.sigmoid(D_w[0])) + tf.reduce_mean(1 - tf.sigmoid(D_G_z[0]))))) tf.summary.scalar('D_loss_real', D_loss_real) tf.summary.scalar('D_loss_wrong', D_loss_wrong) tf.summary.scalar('D_loss_fake', D_loss_fake) tf.summary.scalar('D_loss_unregularized', D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)) tf.summary.scalar('D_loss', D_loss) # Create (recommended) optimizer if args.wavegan_loss == 'dcgan': G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5) D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5) elif args.wavegan_loss == 'lsgan': G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4) D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4) elif 
args.wavegan_loss == 'wgan': G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5) D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5) elif args.wavegan_loss == 'wgan-gp': G_opt = tf.train.AdamOptimizer(learning_rate=4e-4, beta1=0.0, beta2=0.9) D_opt = tf.train.AdamOptimizer(learning_rate=4e-4, beta1=0.0, beta2=0.9) else: raise NotImplementedError() # Optimizer internal state reset ops reset_G_opt_op = tf.variables_initializer(G_opt.variables()) reset_D_opt_op = tf.variables_initializer(D_opt.variables()) # Create training ops G_train_op = G_opt.minimize( G_loss, var_list=G_vars, global_step=tf.train.get_or_create_global_step()) D_train_op = D_opt.minimize(D_loss, var_list=D_vars) def smoothstep(x, mi, mx): return mi + (mx - mi) * (lambda t: np.where( t < 0, 0, np.where(t <= 1, 3 * t**2 - 2 * t**3, 1)))(x) def np_lerp_clip(t, a, b): return a + (b - a) * np.clip(t, 0.0, 1.0) def get_lod_at_step(step): return np.piecewise(float(step), [ step < 10000, 10000 <= step < 20000, 20000 <= step < 30000, 30000 <= step < 40000, 40000 <= step < 50000, 50000 <= step < 60000, 60000 <= step < 70000, 70000 <= step < 80000, 80000 <= step < 90000, 90000 <= step < 100000 ], [ 0, lambda x: np_lerp_clip((x - 10000) / 10000, 0, 1), 1, lambda x: np_lerp_clip( (x - 30000) / 10000, 1, 2), 2, lambda x: np_lerp_clip( (x - 50000) / 10000, 2, 3), 3, lambda x: np_lerp_clip( (x - 70000) / 10000, 3, 4), 4, lambda x: np_lerp_clip( (x - 90000) / 10000, 4, 5), 5 ]) def my_filter_callable(datum, tensor): if (not isinstance(tensor, debug_data.InconvertibleTensorProto)) and ( tensor.dtype == np.float32 or tensor.dtype == np.float64): return np.any([ np.any(np.greater_equal(tensor, 50.0)), np.any(np.less_equal(tensor, -50.0)) ]) else: return False # Create a LocalCLIDebugHook and use it as a monitor # debug_hook = tf_debug.LocalCLIDebugHook(dump_root='C:/d/t/') # debug_hook.add_tensor_filter('large_values', my_filter_callable) # hooks = [debug_hook] # Run training with 
tf.train.MonitoredTrainingSession( checkpoint_dir=args.train_dir, save_checkpoint_secs=args.train_save_secs, save_summaries_secs=args.train_summary_secs) as sess: # Get the summary writer for writing extra summary statistics summary_writer = SummaryWriterCache.get(args.train_dir) cur_lod = 0 while True: # Calculate Maximum LOD to train step = sess.run(tf.train.get_or_create_global_step(), feed_dict={lod: cur_lod}) cur_lod = get_lod_at_step(step) prev_lod = get_lod_at_step(step - 1) # Reset optimizer internal state when new layers are introduced if np.floor(cur_lod) != np.floor(prev_lod) or np.ceil( cur_lod) != np.ceil(prev_lod): print( "Resetting optimizers' internal states at step {}".format( step)) sess.run([reset_G_opt_op, reset_D_opt_op], feed_dict={lod: cur_lod}) # Output current LOD and 'steps at currrent LOD' to tensorboard step = float( sess.run(tf.train.get_or_create_global_step(), feed_dict={lod: cur_lod})) lod_summary = tf.Summary(value=[ tf.Summary.Value(tag="current_lod", simple_value=float(cur_lod)), ]) summary_writer.add_summary(lod_summary, step) # Train discriminator for i in xrange(args.wavegan_disc_nupdates): sess.run(D_train_op, feed_dict={lod: cur_lod}) # Enforce Lipschitz constraint for WGAN if D_clip_weights is not None: sess.run(D_clip_weights, feed_dict={lod: cur_lod}) # Train generator sess.run(G_train_op, feed_dict={lod: cur_lod})
os.makedirs(model_dir) LOGGER.info('Saving configurations...') config_path = os.path.join(model_dir, 'config.json') with open(config_path, 'w') as f: json.dump(args, f) # Try on some training data LOGGER.info('Loading audio data...') audio_filepaths = get_all_audio_filepaths(args['audio_dir']) train_gen, valid_data, test_data \ = create_data_split(audio_filepaths, args['valid_ratio'], args['test_ratio'], batch_size, batch_size, batch_size) LOGGER.info('Creating models...') model_gen = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim, post_proc_filt_len=args['post_proc_filt_len'], upsample=True, num_class=10) model_dis = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus, alpha=args['alpha'], shift_factor=args['shift_factor'], batch_shuffle=args['batch_shuffle'], num_class=10) LOGGER.info('Starting training...') model_gen, model_dis, history, final_discr_metrics, samples = train_wgan( model_gen=model_gen, model_dis=model_dis, train_gen=train_gen, valid_data=valid_data, test_data=test_data, num_epochs=args['num_epochs'], batches_per_epoch=args['batches_per_epoch'], batch_size=batch_size, output_dir=model_dir,
os.makedirs(model_dir) LOGGER.info('Saving configurations...') config_path = os.path.join(model_dir, 'config.json') with open(config_path, 'w') as f: json.dump(args, f) # Try on some training data LOGGER.info('Loading audio data...') audio_filepaths = get_all_audio_filepaths(args['audio_dir']) train_gen, valid_data, test_data \ = create_data_split(audio_filepaths, args['valid_ratio'], args['test_ratio'], batch_size, batch_size, batch_size) LOGGER.info('Creating models...') model_gen = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim, post_proc_filt_len=args['post_proc_filt_len']) model_dis = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus, alpha=args['alpha']) LOGGER.info('Starting training...') model_gen, model_dis, history, final_discr_metrics, samples = train_wgan( model_gen=model_gen, model_dis=model_dis, train_gen=train_gen, valid_data=valid_data, test_data=test_data, num_epochs=args['num_epochs'], batches_per_epoch=args['batches_per_epoch'], batch_size=batch_size, output_dir=model_dir, lmbda=args['lmbda'],