import os
import time

import tensorflow as tf

# Project-local symbols used below (WaveAE, supervisedWaveAE, Modes,
# override_model_attrs, waveform_decoder) are assumed to be imported from the
# repo's own modules; the test snippet at the bottom of this file shows the
# import paths it uses for some of them.


def eval(fps, args):
  eval_dir = os.path.join(args.train_dir, 'eval_valid')
  if not os.path.isdir(eval_dir):
    os.makedirs(eval_dir)

  # Build the evaluation model and apply any command-line attribute overrides.
  model = WaveAE(Modes.EVAL)
  model, summary = override_model_attrs(model, args.model_overrides)
  print('-' * 80)
  print(summary)
  print('-' * 80)

  # Load data (no shuffling or repetition: one deterministic pass).
  with tf.name_scope('loader'):
    clean, x = waveform_decoder(
        fps=fps,
        batch_size=model.eval_batch_size,
        subseq_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=True,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=1,
        repeat=False,
        shuffle=False,
        shuffle_buffer_size=None,
        subseq_randomize_offset=False,
        subseq_overlap_ratio=0.,
        subseq_pad_end=True,
        prefetch_size=None,
        gpu_num=None)

  model.build_denoiser(clean, x)

  saver = tf.train.Saver(var_list=model.restore_vars, max_to_keep=1)
  summary_writer = tf.summary.FileWriter(eval_dir)

  # Poll the training directory and evaluate every new checkpoint that appears.
  ckpt_fp = None
  while True:
    latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
    if latest_ckpt_fp != ckpt_fp:
      ckpt_fp = latest_ckpt_fp
      print('Evaluating {}'.format(ckpt_fp))
      with tf.Session() as sess:
        model.eval_ckpt(ckpt_fp, sess, summary_writer, saver, eval_dir)
      print('Done!')
    time.sleep(1)
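# The `model_overrides` string is a comma-separated list of `name=value` pairs
# (the test snippet at the bottom of this file shows a full example). The
# helper below is an *assumed* simplification of what
# `AudioModel.util.override_model_attrs` does, kept here only to document the
# string format; it is not the repo's implementation and is deliberately given
# a different, hypothetical name.
def _override_attrs_sketch(model, overrides):
  """Apply 'a=1,b=True'-style overrides to matching model attributes."""
  for kv in filter(None, overrides.split(',')):
    name, value = kv.split('=')
    old = getattr(model, name)
    # Cast the string to the type of the existing attribute (check bool before
    # numeric types, since bool is a subclass of int).
    if isinstance(old, bool):
      setattr(model, name, value == 'True')
    elif old is not None:
      setattr(model, name, type(old)(value))
    else:
      setattr(model, name, value)
  return model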
def train(fps, args):
  # Initialize model; the supervised variant is selected via args.ae_model.
  if args.ae_model == "supervised":
    print("supervised Model")
    model = supervisedWaveAE(Modes.TRAIN)
  else:
    model = WaveAE(Modes.TRAIN)
  model, summary = override_model_attrs(model, args.model_overrides)
  print('-' * 80)
  print(summary)
  print('-' * 80)

  # Load data (shuffled and repeated indefinitely for training).
  with tf.name_scope('loader'):
    clean, x = waveform_decoder(
        fps=fps,
        batch_size=model.train_batch_size,
        subseq_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=True,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=4,
        repeat=True,
        shuffle=True,
        shuffle_buffer_size=4096,
        subseq_randomize_offset=args.data_randomize_offset,
        subseq_overlap_ratio=args.data_overlap_ratio,
        subseq_pad_end=True,
        prefetch_size=64,
        gpu_num=0)

  # Create model
  model(clean, x)

  # Train; MonitoredTrainingSession handles checkpointing, summaries, and
  # resuming from the latest checkpoint in train_dir.
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_ckpt_every_nsecs,
      save_summaries_secs=args.train_summary_every_nsecs) as sess:
    while not sess.should_stop():
      model.train_loop(sess)
def infer(fps, args):
  infer_dir = os.path.join(args.train_dir, 'infer_valid')
  if not os.path.isdir(infer_dir):
    os.makedirs(infer_dir)

  model = WaveAE(Modes.INFER)
  model, summary = override_model_attrs(model, args.model_overrides)
  print('-' * 80)
  print(summary)
  print('-' * 80)

  with tf.name_scope('loader'):
    clean, x = waveform_decoder(
        fps=fps,
        batch_size=model.eval_batch_size,
        subseq_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=True,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=1,
        repeat=False,
        shuffle=False,
        shuffle_buffer_size=None,
        subseq_randomize_offset=False,
        subseq_overlap_ratio=0.,
        subseq_pad_end=True,
        prefetch_size=None,
        gpu_num=None)

  model.build_inference(clean, x)

  saver = tf.train.Saver(var_list=model.G_vars, max_to_keep=1)
  summary_writer = tf.summary.FileWriter(infer_dir)

  ckpt_fp = args.infer_ckpt_fp
  with tf.Session() as sess:
    model.infer(ckpt_fp, sess, summary_writer, saver, infer_dir)
  print("Done")
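# Hedged entry-point sketch: the repo's actual CLI may differ, but the flag
# names below mirror the `args.*` attributes used by train/eval/infer above
# (train_dir, model_overrides, ae_model, data_fastwav, data_randomize_offset,
# data_overlap_ratio, train_ckpt_every_nsecs, train_summary_every_nsecs,
# infer_ckpt_fp). `--data_dir`, the glob pattern, and all defaults are
# assumptions made only for illustration.
if __name__ == '__main__':
  import argparse
  import glob

  parser = argparse.ArgumentParser()
  parser.add_argument('mode', choices=['train', 'eval', 'infer'])
  parser.add_argument('--data_dir', required=True)
  parser.add_argument('--train_dir', required=True)
  parser.add_argument('--model_overrides', type=str, default='')
  parser.add_argument('--ae_model', type=str, default='unsupervised')
  parser.add_argument('--data_fastwav', action='store_true')
  parser.add_argument('--data_randomize_offset', action='store_true')
  parser.add_argument('--data_overlap_ratio', type=float, default=0.)
  parser.add_argument('--train_ckpt_every_nsecs', type=int, default=600)
  parser.add_argument('--train_summary_every_nsecs', type=int, default=120)
  parser.add_argument('--infer_ckpt_fp', type=str, default=None)
  args = parser.parse_args()

  fps = sorted(glob.glob(os.path.join(args.data_dir, '*.wav')))
  {'train': train, 'eval': eval, 'infer': infer}[args.mode](fps, args)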
# Standalone test snippet (apparently separate from the train/eval/infer driver
# above): builds a supervised WaveAE graph on placeholder audio to check that
# attribute overrides and graph construction work.
import tensorflow as tf

from AudioModel.util import override_model_attrs
import EncDecModel
import supervisedModel
from AudioModel.model import Model, Modes

ckpt = 16384

a = tf.placeholder('float32', (32, 16384, 1, 1))
b = tf.placeholder('float32', (32, 16384, 1, 1))

model = supervisedModel.WaveAE(mode=Modes.TRAIN)
overrides = "objective=l1,batchnorm=False,train_batch_size=32,alpha=100.0,enc_length=16,stride=4,kernel_len=25,subseq_len=16384"
model, summary = override_model_attrs(model, overrides)
model(a, b)

# Leftover experiment: restore just the encoder ('AE/E') variables from a
# trained checkpoint and attach a small dense head to the flattened encoding.
# with tf.variable_scope('AE'):
#   with tf.variable_scope('E'):
#     enc = waveAE.WaveEncoderFactor256(batchnorm=False)
#     E_x = enc(a, training=False)
# e_flat = tf.reshape(E_x, [32, -1])
# rd_tensor = tf.layers.dense(e_flat, 5)
# model_dir_path = "/data2/paarth/TrainDir/WaveAE/WaveAEsc09_l1batchnormFalse/eval_sc09_valid"
# e_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AE/E')
# print(e_vars)
# saver = tf.train.Saver(var_list=e_vars)
# sess = tf.InteractiveSession()
# saver.restore(sess, '{}/best_valid_l2-{}'.format(model_dir_path, ckpt))
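# Hedged debugging sketch: after `model(a, b)` has (presumably) added the
# autoencoder ops to the default graph, initialise the variables and list the
# trainable parameters as a quick shape/size sanity check. This only relies on
# standard TF1 collections, not on any of the model's internal op names.
import numpy as np

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  n_params = 0
  for v in tf.trainable_variables():
    print(v.name, v.shape)
    n_params += int(np.prod(v.shape.as_list()))
  print('Total trainable parameters: {}'.format(n_params))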