# NOTE: in addition to the standard imports below, this file relies on
# project-local helpers defined elsewhere in the repo (Advoc, AdvocSmall,
# Modes, override_model_attrs, decode_extract_and_batch, SpectralUtil, and
# the MelspecGAN/feats_* utilities used by the second train() below).
import os
import time

import numpy as np
import tensorflow as tf


def train(fps, args):
  # Initialize model
  if args.model_type == "regular":
    model = Advoc(Modes.TRAIN)
  elif args.model_type == "small":
    model = AdvocSmall(Modes.TRAIN)
  else:
    raise NotImplementedError()

  model, summary = override_model_attrs(model, args.model_overrides)
  model.audio_fs = args.data_sample_rate

  print('-' * 80)
  print(summary)
  print('-' * 80)

  # Load data
  with tf.name_scope('loader'):
    x_magspec, x_wav = decode_extract_and_batch(
        fps,
        batch_size=model.train_batch_size,
        slice_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=args.data_normalize,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=4,
        extract_type='magspec',
        extract_parallel_calls=8,
        repeat=True,
        shuffle=True,
        shuffle_buffer_size=512,
        slice_first_only=args.data_slice_first_only,
        slice_randomize_offset=args.data_slice_randomize_offset,
        slice_overlap_ratio=args.data_slice_overlap_ratio,
        slice_pad_end=args.data_slice_pad_end,
        prefetch_size=model.train_batch_size * 8,
        prefetch_gpu_num=0)

  # Create model
  spectral = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)
  x_melspec = spectral.mag_to_mel_linear_spec(x_magspec)
  x_inverted_magspec = spectral.mel_linear_to_mag_spec(x_melspec, transform='inverse')

  model(x_inverted_magspec, x_magspec, x_wav, x_melspec)

  # Train
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_ckpt_every_nsecs,
      save_summaries_secs=args.train_summary_every_nsecs) as sess:
    _step = 0
    while not sess.should_stop() and _step < args.max_steps:
      _step = model.train_loop(sess)

  print("Done!")
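# train() above recovers a linear-frequency magnitude estimate from the mel
# spectrogram before handing both to the model. As a rough illustration of
# what SpectralUtil.mel_linear_to_mag_spec(..., transform='inverse') is
# assumed to do, the mel filterbank can be pseudo-inverted and the projection
# clipped to be non-negative. This sketch is illustrative only (hypothetical
# names, NumPy); the actual SpectralUtil implementation may differ.
def _mel_to_mag_sketch(melspec, mel_fb):
  """melspec: [time, n_mels]; mel_fb: [n_mels, n_fft_bins].

  Returns a non-negative [time, n_fft_bins] magnitude estimate, e.g.
  mag = _mel_to_mag_sketch(mel, fb).
  """
  inv_fb = np.linalg.pinv(mel_fb)  # [n_fft_bins, n_mels]
  return np.maximum(0., np.dot(melspec, inv_fb.T))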
def train(fps, args):
  # Load data
  with tf.name_scope('loader'):
    x, x_audio = decode_extract_and_batch(
        fps=fps,
        batch_size=TRAIN_BATCH_SIZE,
        slice_len=64,
        audio_fs=args.data_sample_rate,
        audio_mono=True,
        audio_normalize=args.data_normalize,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=8,
        extract_type='melspec',
        extract_nfft=1024,
        extract_nhop=256,
        extract_parallel_calls=8,
        repeat=True,
        shuffle=True,
        shuffle_buffer_size=512,
        slice_first_only=args.data_slice_first_only,
        slice_randomize_offset=args.data_slice_randomize_offset,
        slice_overlap_ratio=args.data_slice_overlap_ratio,
        slice_pad_end=args.data_slice_pad_end,
        prefetch_size=TRAIN_BATCH_SIZE * 8,
        prefetch_gpu_num=args.data_prefetch_gpu_num)
  x = feats_norm(x)

  # Data summaries
  tf.summary.audio('x_audio', x_audio[:, :, 0], args.data_sample_rate)
  tf.summary.image('x', feats_to_uint8_img(feats_denorm(x)))
  tf.summary.audio(
      'x_inv_audio',
      feats_to_approx_audio(
          feats_denorm(x), args.data_sample_rate, 16384, n=3)[:, :, 0],
      args.data_sample_rate)

  # Make z vector
  z = tf.random.normal([TRAIN_BATCH_SIZE, Z_DIM], dtype=tf.float32)

  # Make generator
  with tf.variable_scope('G'):
    G = MelspecGANGenerator()
    G_z = G(z, training=True)
  G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

  # Summarize G_z
  tf.summary.image('G_z', feats_to_uint8_img(feats_denorm(G_z)))
  tf.summary.audio(
      'G_z_inv_audio',
      feats_to_approx_audio(
          feats_denorm(G_z), args.data_sample_rate, 16384, n=3)[:, :, 0],
      args.data_sample_rate)

  # Make real discriminator
  D = MelspecGANDiscriminator()
  with tf.name_scope('D_x'), tf.variable_scope('D'):
    D_x = D(x, training=True)
  D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

  # Make fake discriminator
  with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
    D_G_z = D(G_z, training=True)

  # Create loss
  num_disc_updates_per_genr = 1
  if TRAIN_LOSS == 'dcgan':
    fake = tf.zeros([TRAIN_BATCH_SIZE], dtype=tf.float32)
    real = tf.ones([TRAIN_BATCH_SIZE], dtype=tf.float32)

    G_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

    D_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
    D_loss += tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))
    D_loss /= 2.
  elif TRAIN_LOSS == 'wgangp':
    G_loss = -tf.reduce_mean(D_G_z)
    D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

    alpha = tf.random_uniform(
        shape=[TRAIN_BATCH_SIZE, 1, 1, 1], minval=0., maxval=1.)
    differences = G_z - x
    interpolates = x + (alpha * differences)
    with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
      D_interp = D(interpolates, training=True)

    LAMBDA = 10
    gradients = tf.gradients(D_interp, [interpolates])[0]
    slopes = tf.sqrt(
        tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
    D_loss += LAMBDA * gradient_penalty

    num_disc_updates_per_genr = 5
  else:
    raise ValueError()

  tf.summary.scalar('G_loss', G_loss)
  tf.summary.scalar('D_loss', D_loss)

  # Create opt
  if TRAIN_LOSS == 'dcgan':
    G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
  elif TRAIN_LOSS == 'wgangp':
    # TODO: some igul code uses beta1=0.
    G_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
    D_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
  else:
    raise ValueError()

  # Create training ops
  G_train_op = G_opt.minimize(
      G_loss,
      var_list=G_vars,
      global_step=tf.train.get_or_create_global_step())
  D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

  # Train
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_ckpt_every_nsecs,
      save_summaries_secs=args.train_summary_every_nsecs) as sess:
    while not sess.should_stop():
      for i in range(num_disc_updates_per_genr):
        sess.run(D_train_op)
      sess.run(G_train_op)
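# For reference, the 'wgangp' branch above implements the WGAN-GP objective
# of Gulrajani et al. (2017):
#
#   L_D = E[D(G(z))] - E[D(x)] + lambda * E[(||grad D(x_hat)||_2 - 1)^2]
#   L_G = -E[D(G(z))]
#
# where x_hat = x + alpha * (G(z) - x) with alpha ~ U[0, 1], lambda = 10, and
# five discriminator updates per generator update, all matching the code.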
def eval(fps, args):
  if args.eval_dataset_name is not None:
    eval_dir = os.path.join(args.train_dir,
        'eval_{}'.format(args.eval_dataset_name))
  else:
    eval_dir = os.path.join(args.train_dir, 'eval_valid')
  if not os.path.isdir(eval_dir):
    os.makedirs(eval_dir)

  if args.model_type == "regular":
    model = Advoc(Modes.EVAL)
  elif args.model_type == "small":
    model = AdvocSmall(Modes.EVAL)
  else:
    raise NotImplementedError()

  model, summary = override_model_attrs(model, args.model_overrides)
  model.audio_fs = args.data_sample_rate

  print('-' * 80)
  print(summary)
  print('-' * 80)

  with tf.name_scope('loader'):
    x_magspec, x_wav = decode_extract_and_batch(
        fps,
        batch_size=model.eval_batch_size,
        slice_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=args.data_normalize,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=4,
        extract_type='magspec',
        extract_parallel_calls=8,
        repeat=False,
        shuffle=False,
        shuffle_buffer_size=None,
        slice_first_only=args.data_slice_first_only,
        slice_randomize_offset=False,
        slice_overlap_ratio=0.,
        slice_pad_end=True,
        prefetch_size=None,
        prefetch_gpu_num=None)

  spectral = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)
  x_melspec = spectral.mag_to_mel_linear_spec(x_magspec)
  x_inverted_magspec = spectral.mel_linear_to_mag_spec(x_melspec, transform='inverse')

  with tf.variable_scope("generator") as vs:
    if model.generator_type == "pix2pix":
      gen_magspec = model.build_generator(x_inverted_magspec)
    elif model.generator_type == "linear":
      gen_magspec = model.build_linear_generator(x_inverted_magspec)
    elif model.generator_type == "linear+pix2pix":
      _temp_spec = model.build_linear_generator(x_melspec)
      gen_magspec = model.build_linear_generator(_temp_spec)
    elif model.generator_type == "interp+pix2pix":
      _temp_spec = tf.image.resize_images(x_melspec, [model.subseq_len, 513])
      gen_magspec = model.build_linear_generator(_temp_spec)
    else:
      raise NotImplementedError()
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)

  gen_loss_L1 = tf.reduce_mean(tf.abs(x_magspec - gen_magspec))
  gan_step = tf.train.get_or_create_global_step()
  gan_saver = tf.train.Saver(var_list=G_vars + [gan_step], max_to_keep=1)

  all_gen_loss_L1 = tf.placeholder(tf.float32, [None])

  summaries = [
      tf.summary.scalar('gen_loss_L1', tf.reduce_mean(all_gen_loss_L1)),
  ]
  summaries = tf.summary.merge(summaries)

  # Create summary writer
  summary_writer = tf.summary.FileWriter(eval_dir)
  ckpt_fp = None
  _best_gen_loss_l1 = np.inf

  while True:
    latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
    if latest_ckpt_fp != ckpt_fp:
      ckpt_fp = latest_ckpt_fp
      print('Evaluating {}'.format(ckpt_fp))

      with tf.Session() as sess:
        gan_saver.restore(sess, latest_ckpt_fp)
        _step = sess.run(gan_step)
        _all_gen_loss_L1 = []

        while True:
          try:
            _gen_loss_L1, _gen_magspec, _x_magspec = sess.run(
                [gen_loss_L1, gen_magspec, x_magspec])
          except tf.errors.OutOfRangeError:
            break
          _all_gen_loss_L1.append(_gen_loss_L1)

        _all_gen_loss_L1 = np.array(_all_gen_loss_L1)

        _summaries = sess.run(summaries, {
            all_gen_loss_L1: _all_gen_loss_L1,
        })
        summary_writer.add_summary(_summaries, _step)

        _gen_loss_L1_np = np.mean(_all_gen_loss_L1)
        if _gen_loss_L1_np < _best_gen_loss_l1:
          gan_saver.save(sess, os.path.join(eval_dir, 'best_gen_loss_l1'), _step)
          # Track the best loss so far; without this update every checkpoint
          # would be saved as the "best" one.
          _best_gen_loss_l1 = _gen_loss_L1_np
          print("Saved best gen loss l1!")

      print('Done!')

    time.sleep(1)
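# eval() above follows the continuous-evaluation pattern: poll args.train_dir
# for a new checkpoint (tf.train.latest_checkpoint), score the full validation
# split, write an averaged summary, and snapshot the best model by L1 loss.
# It is intended to run as a separate process alongside train(); the script
# name and flags below are illustrative, not necessarily the repo's exact CLI:
#
#   python train_evaluate.py train ./runs/advoc --data_dir ./data/train
#   python train_evaluate.py eval ./runs/advoc --data_dir ./data/valid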
def infer(fps, args):
  if args.infer_dataset_name is not None:
    infer_dir = os.path.join(args.train_dir,
        'infer_{}'.format(args.infer_dataset_name))
  else:
    infer_dir = os.path.join(args.train_dir, 'infer_valid')
  if not os.path.isdir(infer_dir):
    os.makedirs(infer_dir)

  if args.model_type == "regular":
    model = Advoc(Modes.INFER)
  elif args.model_type == "small":
    model = AdvocSmall(Modes.INFER)
  else:
    raise NotImplementedError()

  model, summary = override_model_attrs(model, args.model_overrides)
  model.audio_fs = args.data_sample_rate

  print('-' * 80)
  print(summary)
  print('-' * 80)

  with tf.name_scope('loader'):
    x_magspec, x_wav = decode_extract_and_batch(
        fps,
        batch_size=args.infer_batch_size,
        slice_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=args.data_normalize,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=4,
        extract_type='magspec',
        extract_parallel_calls=8,
        repeat=False,
        shuffle=False,
        shuffle_buffer_size=None,
        slice_randomize_offset=False,
        slice_overlap_ratio=0.,
        slice_pad_end=True,
        prefetch_size=None,
        prefetch_gpu_num=None)

  spectral = SpectralUtil(n_mels=model.n_mels, fs=model.audio_fs)
  x_melspec = spectral.mag_to_mel_linear_spec(x_magspec)
  x_inverted_magspec = spectral.mel_linear_to_mag_spec(x_melspec, transform='inverse')

  with tf.variable_scope("generator") as vs:
    if model.generator_type == "pix2pix":
      gen_magspec = model.build_generator(x_inverted_magspec)
    elif model.generator_type == "linear":
      gen_magspec = model.build_linear_generator(x_inverted_magspec)
    elif model.generator_type == "linear+pix2pix":
      _temp_spec = model.build_linear_generator(x_melspec)
      gen_magspec = model.build_linear_generator(_temp_spec)
    elif model.generator_type == "interp+pix2pix":
      _temp_spec = tf.image.resize_images(x_melspec, [model.subseq_len, 513])
      gen_magspec = model.build_linear_generator(_temp_spec)
    else:
      raise NotImplementedError()
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)

  step = tf.train.get_or_create_global_step()
  gan_saver = tf.train.Saver(var_list=G_vars + [step], max_to_keep=1)

  input_audio = tf.py_func(
      spectral.audio_from_mag_spec, [x_inverted_magspec[0]], tf.float32,
      stateful=False)
  target_audio = tf.py_func(
      spectral.audio_from_mag_spec, [x_magspec[0]], tf.float32,
      stateful=False)
  gen_audio = tf.py_func(
      spectral.audio_from_mag_spec, [gen_magspec[0]], tf.float32,
      stateful=False)

  # Don't know why we reshape them this way; just following past convention.
  input_audio = tf.reshape(input_audio, [1, -1, 1, 1])
  target_audio = tf.reshape(target_audio, [1, -1, 1, 1])
  gen_audio = tf.reshape(gen_audio, [1, -1, 1, 1])

  summaries = [
      tf.summary.audio('infer_x_wav', x_wav[:, :, 0, :], model.audio_fs),
      tf.summary.audio('infer_gen_audio', gen_audio[:, :, 0, :], model.audio_fs),
      tf.summary.audio('target_audio', target_audio[:, :, 0, :], model.audio_fs),
      tf.summary.audio('infer_input_audio', input_audio[:, :, 0, :], model.audio_fs)
  ]
  summaries = tf.summary.merge(summaries)

  # Create saver and summary writer
  summary_writer = tf.summary.FileWriter(infer_dir)

  if args.infer_ckpt_path is not None:
    # Inferring from a particular checkpoint
    ckpt_fp = args.infer_ckpt_path
    print('Inferring from {}'.format(ckpt_fp))
    with tf.Session() as sess:
      gan_saver.restore(sess, ckpt_fp)
      _step = sess.run(step)

      # Just one batch at a time
      while True:
        try:
          _summaries, mel_np, est_np, act_np, gen_np = sess.run([
              summaries, x_melspec, x_inverted_magspec, x_magspec, gen_magspec
          ])
          summary_writer.add_summary(_summaries, _step)
        except tf.errors.OutOfRangeError:
          break
      print('Done!')
  else:
    # Continuous inference
    ckpt_fp = None
    while True:
      with tf.Session() as sess:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
          ckpt_fp = latest_ckpt_fp
          print('Inferring from {}'.format(ckpt_fp))
          gan_saver.restore(sess, ckpt_fp)
          _step = sess.run(step)

          while True:
            try:
              _summaries, mel_np, est_np, act_np, gen_np = sess.run([
                  summaries, x_melspec, x_inverted_magspec, x_magspec,
                  gen_magspec
              ])
              summary_writer.add_summary(_summaries, _step)
            except tf.errors.OutOfRangeError:
              break
          print("Done!")
      time.sleep(1)
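# A minimal, illustrative entry point for the functions above. The flag names
# mirror the attributes of `args` referenced in this file, but this block is
# an assumption: the repo's actual CLI (defaults, extra flags, and how `fps`
# is globbed) may differ. Note also that the two train() definitions above
# appear to come from separate scripts (the vocoder and the MelspecGAN), so
# this dispatch assumes the vocoder script; eval() here shadows the builtin.
if __name__ == '__main__':
  import argparse
  import glob

  parser = argparse.ArgumentParser()
  parser.add_argument('mode', choices=['train', 'eval', 'infer'])
  parser.add_argument('train_dir', type=str)
  parser.add_argument('--data_dir', type=str, required=True)
  parser.add_argument('--model_type', type=str,
      choices=['regular', 'small'], default='regular')
  parser.add_argument('--model_overrides', type=str, default=None)
  parser.add_argument('--data_sample_rate', type=int, default=22050)
  parser.add_argument('--data_normalize', action='store_true')
  parser.add_argument('--data_fastwav', action='store_true')
  parser.add_argument('--data_slice_first_only', action='store_true')
  parser.add_argument('--data_slice_randomize_offset', action='store_true')
  parser.add_argument('--data_slice_overlap_ratio', type=float, default=0.)
  parser.add_argument('--data_slice_pad_end', action='store_true')
  parser.add_argument('--train_ckpt_every_nsecs', type=int, default=600)
  parser.add_argument('--train_summary_every_nsecs', type=int, default=300)
  parser.add_argument('--max_steps', type=int, default=int(1e9))
  parser.add_argument('--eval_dataset_name', type=str, default=None)
  parser.add_argument('--infer_dataset_name', type=str, default=None)
  parser.add_argument('--infer_batch_size', type=int, default=1)
  parser.add_argument('--infer_ckpt_path', type=str, default=None)
  args = parser.parse_args()

  fps = glob.glob(os.path.join(args.data_dir, '*.wav'))
  print('Found {} audio files'.format(len(fps)))

  {'train': train, 'eval': eval, 'infer': infer}[args.mode](fps, args)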