def eval(fps, args):
  """Continuously evaluate the newest checkpoint found in ``args.train_dir``.

  Builds a ``WaveAE`` in EVAL mode on top of a single, non-shuffled,
  non-repeating ``waveform_decoder`` pass over ``fps``. It then polls
  ``args.train_dir`` once per second. Whenever ``tf.train.latest_checkpoint``
  reports a checkpoint different from the last one evaluated, the model's
  ``eval_ckpt`` runs on it in a fresh session. Output goes to
  ``<train_dir>/eval_valid``.

  NOTE: the name shadows the ``eval`` builtin; it is kept for caller
  compatibility.

  Args:
    fps: file paths handed to ``waveform_decoder``.
    args: parsed options; reads ``train_dir``, ``model_overrides`` and
      ``data_fastwav``.

  The polling loop has no exit condition and runs until the process is
  interrupted. On the way out the summary writer is closed so that buffered
  events are flushed to disk.
  """
  eval_dir = os.path.join(args.train_dir, 'eval_valid')
  if not os.path.isdir(eval_dir):
    os.makedirs(eval_dir)

  model = WaveAE(Modes.EVAL)
  model, summary = override_model_attrs(model, args.model_overrides)
  print('-' * 80)
  print(summary)
  print('-' * 80)

  # Load data: one deterministic pass (no repeat, no shuffle, no random
  # offsets) so every checkpoint is scored on the same inputs.
  with tf.name_scope('loader'):
    clean, x = waveform_decoder(
        fps=fps,
        batch_size=model.eval_batch_size,
        subseq_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=True,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=1,
        repeat=False,
        shuffle=False,
        shuffle_buffer_size=None,
        subseq_randomize_offset=False,
        subseq_overlap_ratio=0.,
        subseq_pad_end=True,
        prefetch_size=None,
        gpu_num=None)

  model.build_denoiser(clean, x)

  saver = tf.train.Saver(var_list=model.restore_vars, max_to_keep=1)
  summary_writer = tf.summary.FileWriter(eval_dir)

  try:
    ckpt_fp = None
    while True:
      latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
      # latest_checkpoint returns None when no valid checkpoint is found;
      # never hand that to eval_ckpt, even after a previous evaluation.
      if latest_ckpt_fp is not None and latest_ckpt_fp != ckpt_fp:
        ckpt_fp = latest_ckpt_fp
        print('Evaluating {}'.format(ckpt_fp))
        with tf.Session() as sess:
          model.eval_ckpt(ckpt_fp, sess, summary_writer, saver, eval_dir)
        print('Done!')
      time.sleep(1)
  finally:
    # Flush and release the event file when the loop is interrupted
    # (e.g. KeyboardInterrupt); otherwise buffered summaries can be lost.
    summary_writer.close()
def train(fps, args):
  """Train a WaveAE-family model until the monitored session asks to stop.

  ``args.ae_model == "supervised"`` selects ``supervisedWaveAE``; any other
  value selects ``WaveAE``. Either model is built in TRAIN mode, patched
  with ``args.model_overrides``, and fed from a repeating, shuffled
  ``waveform_decoder`` pipeline over ``fps``.
  ``tf.train.MonitoredTrainingSession`` writes checkpoints and summaries to
  ``args.train_dir`` at the configured intervals.

  Args:
    fps: file paths handed to ``waveform_decoder``.
    args: parsed options; reads ``ae_model``, ``model_overrides``,
      ``data_fastwav``, ``data_randomize_offset``, ``data_overlap_ratio``,
      ``train_dir``, ``train_ckpt_every_nsecs`` and
      ``train_summary_every_nsecs``.
  """
  # Choose the architecture requested on the command line.
  is_supervised = args.ae_model == "supervised"
  if is_supervised:
    print("supervised Model")
  model_cls = supervisedWaveAE if is_supervised else WaveAE
  model, summary = override_model_attrs(
      model_cls(Modes.TRAIN), args.model_overrides)

  rule = '-' * 80
  for line in (rule, summary, rule):
    print(line)

  # Input pipeline: repeating, shuffled stream with parallel decoding and
  # prefetching, placed on GPU 0.
  with tf.name_scope('loader'):
    decoder_opts = dict(
        fps=fps,
        batch_size=model.train_batch_size,
        subseq_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=True,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=4,
        repeat=True,
        shuffle=True,
        shuffle_buffer_size=4096,
        subseq_randomize_offset=args.data_randomize_offset,
        subseq_overlap_ratio=args.data_overlap_ratio,
        subseq_pad_end=True,
        prefetch_size=64,
        gpu_num=0)
    clean_batch, x_batch = waveform_decoder(**decoder_opts)

  # Build the training graph on top of the loader outputs.
  model(clean_batch, x_batch)

  session_opts = dict(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_ckpt_every_nsecs,
      save_summaries_secs=args.train_summary_every_nsecs)
  with tf.train.MonitoredTrainingSession(**session_opts) as sess:
    while not sess.should_stop():
      model.train_loop(sess)
def infer(fps, args):
  """Run inference with the checkpoint named by ``args.infer_ckpt_fp``.

  Builds a ``WaveAE`` in INFER mode on top of a single, non-shuffled,
  non-repeating ``waveform_decoder`` pass over ``fps``. It then restores only
  ``model.G_vars`` and calls ``model.infer`` once in a fresh session. Output
  goes to ``<train_dir>/infer_valid``.

  Args:
    fps: file paths handed to ``waveform_decoder``.
    args: parsed options; reads ``train_dir``, ``model_overrides``,
      ``data_fastwav`` and ``infer_ckpt_fp``.

  NOTE(review): ``args.infer_ckpt_fp`` is passed through unchanged. Whether
  ``model.infer`` tolerates ``None`` cannot be told from here; confirm
  before relying on it.
  """
  infer_dir = os.path.join(args.train_dir, 'infer_valid')
  if not os.path.isdir(infer_dir):
    os.makedirs(infer_dir)
  model = WaveAE(Modes.INFER)
  model, summary = override_model_attrs(model, args.model_overrides)
  print('-' * 80)
  print(summary)
  print('-' * 80)

  # Load data: one deterministic pass (no repeat, no shuffle, no random
  # offsets).
  with tf.name_scope('loader'):
    clean, x = waveform_decoder(
        fps=fps,
        batch_size=model.eval_batch_size,
        subseq_len=model.subseq_len,
        audio_fs=model.audio_fs,
        audio_mono=True,
        audio_normalize=True,
        decode_fastwav=args.data_fastwav,
        decode_parallel_calls=1,
        repeat=False,
        shuffle=False,
        shuffle_buffer_size=None,
        subseq_randomize_offset=False,
        subseq_overlap_ratio=0.,
        subseq_pad_end=True,
        prefetch_size=None,
        gpu_num=None)

  model.build_inference(clean, x)
  # Only model.G_vars are restored (presumably the generator/denoiser
  # weights -- confirm against the model definition).
  saver = tf.train.Saver(var_list=model.G_vars, max_to_keep=1)
  summary_writer = tf.summary.FileWriter(infer_dir)

  try:
    ckpt_fp = args.infer_ckpt_fp
    with tf.Session() as sess:
      model.infer(ckpt_fp, sess, summary_writer, saver, infer_dir)
  finally:
    # Close (and thereby flush) the event file. Otherwise summaries written
    # by model.infer may still be buffered when the process exits.
    summary_writer.close()
  print("Done")
# Example #4
# 0
import tensorflow as tf
from AudioModel.util import override_model_attrs
import EncDecModel
import supervisedModel
from AudioModel.model import Model, Modes

# Apparently a smoke test: build a supervised WaveAE training graph on
# placeholder inputs using the override configuration below.
BATCH_SIZE = 32
SUBSEQ_LEN = 16384

# Both model inputs share one placeholder shape: (BATCH_SIZE, SUBSEQ_LEN, 1, 1).
a = tf.placeholder('float32', (BATCH_SIZE, SUBSEQ_LEN, 1, 1))
b = tf.placeholder('float32', (BATCH_SIZE, SUBSEQ_LEN, 1, 1))
model = supervisedModel.WaveAE(mode=Modes.TRAIN)

# train_batch_size and subseq_len come from the constants above, so the
# override string and the placeholder shapes cannot drift apart.
overrides = (
    'objective=l1,batchnorm=False,train_batch_size={},alpha=100.0,'
    'enc_length=16,stride=4,kernel_len=25,subseq_len={}'
).format(BATCH_SIZE, SUBSEQ_LEN)
model, summary = override_model_attrs(model, overrides)

model(a, b)