Example #1
def main(unused_argv=None):
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.config is None:
        raise RuntimeError("No config name specified.")

    config = utils.get_module("ours." + FLAGS.config).Config(FLAGS.batch_size)

    if FLAGS.checkpoint_path:
        checkpoint_path = FLAGS.checkpoint_path
    else:
        expdir = FLAGS.expdir
        tf.logging.info("Will load latest checkpoint from %s.", expdir)
        if not tf.gfile.Exists(expdir):
            tf.logging.fatal("\tExperiment save dir '%s' does not exist!",
                             expdir)
            sys.exit(1)

        try:
            checkpoint_path = tf.train.latest_checkpoint(expdir)
        except tf.errors.NotFoundError:
            tf.logging.fatal(
                "There was a problem determining the latest checkpoint.")
            sys.exit(1)

    if not tf.train.checkpoint_exists(checkpoint_path):
        tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
        sys.exit(1)

    tf.logging.info("Will restore from checkpoint: %s", checkpoint_path)

    wavdir = FLAGS.wavdir
    tf.logging.info("Will load Wavs from %s." % wavdir)

    ######################
    # restore the model  #
    ######################
    tf.logging.info("Building graph")
    with tf.Graph().as_default(), tf.device("/gpu:0"):
        with tf.variable_scope('ours_model_var_scope') as var_scope:
            sample_length = FLAGS.sample_length
            batch_size = FLAGS.batch_size
            wav_placeholder = tf.placeholder(tf.float32,
                                             shape=[batch_size, sample_length])
            wav_names = tf.placeholder(tf.string, shape=[batch_size])
            encode_op = config.encode(wav_placeholder)["encoding"]
            decode_op = config.decode(encode_op)["logits"]  # or ["predictions"]
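            # `sampled` and `generate` are assumed to be helper functions
            # defined elsewhere in this module; they are not shown here.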
            sample = sampled(decode_op)
            reshaped_sample = tf.reshape(sample, [batch_size, sample_length])
            generate_wav = generate(reshaped_sample)

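        # Restore the exponential-moving-average (shadow) copies of the
        # variables instead of the raw training weights.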
        ema = tf.train.ExponentialMovingAverage(decay=0.9999)
        variables_to_restore = ema.variables_to_restore()

        # Create a saver, which is used to restore the parameters from checkpoints
        saver = tf.train.Saver(variables_to_restore)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        # Set the opt_level to prevent py_funcs from being executed multiple times.
        session_config.graph_options.optimizer_options.opt_level = 2
        sess = tf.Session("", config=session_config)

        tf.logging.info("\tRestoring from checkpoint.")
        saver.restore(sess, checkpoint_path)

        def is_wav(f):
            return f.lower().endswith(".wav")

        wavfiles = sorted([
            os.path.join(wavdir, fname)
            for fname in tf.gfile.ListDirectory(wavdir) if is_wav(fname)
        ])

        def get_fnames(files):
            fnames_list = []
            for f in files:
                fnames_list.append(ntpath.basename(f))
            return fnames_list

        for start_file in range(0, len(wavfiles), batch_size):
            batch_number = (start_file // batch_size) + 1
            tf.logging.info("On batch %d.", batch_number)
            end_file = start_file + batch_size
            files = wavfiles[start_file:end_file]
            wavfile_names = get_fnames(files)

            # Ensure that files has batch_size elements.
            batch_filler = batch_size - len(files)
            files.extend(batch_filler * [files[-1]])

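            # Load each wav and truncate it to the first sample_length samples.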
            wavdata = np.array(
                [utils.load_wav(f)[:sample_length] for f in files])

            try:
                res = sess.run(generate_wav,
                               feed_dict={
                                   wav_placeholder: wavdata,
                                   wav_names: wavfile_names
                               })

            except Exception as e:
                tf.logging.info("Unexpected error happened: %s.", e)
                raise

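            # zip() truncates to the unpadded wavfile_names list, so nothing
            # is written for the duplicate files appended as batch padding.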
            for decoded_wav, filename in zip(res, wavfile_names):
                write_wav(decoded_wav, FLAGS.sample_rate, FLAGS.wav_savedir,
                          filename)
Example #2
def main(unused_argv=None):
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.config is None:
    raise RuntimeError("No config name specified.")

  config = utils.get_module("wavenet." + FLAGS.config).Config()

  if FLAGS.checkpoint_path:
    ckpt_path = FLAGS.checkpoint_path
  else:
    expdir = FLAGS.expdir
    tf.logging.info("Will load latest checkpoint from %s.", expdir)
    if not tf.gfile.Exists(expdir):
      tf.logging.fatal("\tExperiment save dir '%s' does not exist!", expdir)
      sys.exit(1)

    try:
      ckpt_path = tf.train.latest_checkpoint(expdir)
    except tf.errors.NotFoundError:
      tf.logging.fatal("There was a problem determining the latest checkpoint.")
      sys.exit(1)

  if not tf.train.checkpoint_exists(ckpt_path):
    tf.logging.fatal("Invalid checkpoint path: %s", ckpt_path)
    sys.exit(1)

  tf.logging.info("Will restore from checkpoint: %s", ckpt_path)

  wavdir = FLAGS.wavdir
  tf.logging.info("Will load Wavs from %s." % wavdir)

  savedir = FLAGS.savedir
  tf.logging.info("Will save embeddings to %s." % savedir)
  if not tf.gfile.Exists(savedir):
    tf.logging.info("Creating save directory...")
    tf.gfile.MakeDirs(savedir)

  tf.logging.info("Building graph")
  with tf.Graph().as_default(), tf.device("/gpu:0"):
    sample_length = FLAGS.sample_length
    batch_size = FLAGS.batch_size
    wav_placeholder = tf.placeholder(
        tf.float32, shape=[batch_size, sample_length])
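    # Build the model in inference mode; only the encoder output
    # ("encoding") is needed to save embeddings.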
    graph = config.build({"wav": wav_placeholder}, is_training=False)
    graph_encoding = graph["encoding"]

    ema = tf.train.ExponentialMovingAverage(decay=0.9999)
    variables_to_restore = ema.variables_to_restore()

    # Create a saver, which is used to restore the parameters from checkpoints
    saver = tf.train.Saver(variables_to_restore)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    # Set the opt_level to prevent py_funcs from being executed multiple times.
    session_config.graph_options.optimizer_options.opt_level = 2
    sess = tf.Session("", config=session_config)

    tf.logging.info("\tRestoring from checkpoint.")
    saver.restore(sess, ckpt_path)

    def is_wav(f):
      return f.lower().endswith(".wav")

    wavfiles = sorted([
        os.path.join(wavdir, fname) for fname in tf.gfile.ListDirectory(wavdir)
        if is_wav(fname)
    ])

    for start_file in range(0, len(wavfiles), batch_size):
      batch_number = (start_file // batch_size) + 1
      tf.logging.info("On file number %s (batch %d).", start_file, batch_number)
      end_file = start_file + batch_size
      files = wavfiles[start_file:end_file]

      # Ensure that files has batch_size elements.
      batch_filler = batch_size - len(files)
      files.extend(batch_filler * [files[-1]])

      wavdata = np.array([utils.load_wav(f)[:sample_length] for f in files])

      try:
        encoding = sess.run(
            graph_encoding, feed_dict={wav_placeholder: wavdata})
        # Iterate over the current batch (files), not the full wavfiles list.
        for num, (wavfile, enc) in enumerate(zip(files, encoding)):
          basename = os.path.splitext(os.path.basename(wavfile))[0]
          filename = "%s_embeddings.npy" % basename
          with tf.gfile.Open(os.path.join(savedir, filename), "wb") as f:
            np.save(f, enc)

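          # Stop after the last real file so embeddings are not written for
          # the duplicates appended as batch padding.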
          if num + batch_filler + 1 == batch_size:
            break
      except Exception as e:
        tf.logging.info("Unexpected error happened: %s.", e)
        raise
Example #3
def main(unused_argv=None):
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.config is None:
        raise RuntimeError("No config name specified.")

    config = utils.get_module("wavenet." + FLAGS.config).Config()

    if FLAGS.checkpoint_path:
        checkpoint_path = FLAGS.checkpoint_path
    else:
        expdir = FLAGS.expdir
        tf.logging.info("Will load latest checkpoint from %s.", expdir)
        if not tf.gfile.Exists(expdir):
            tf.logging.fatal("\tExperiment save dir '%s' does not exist!",
                             expdir)
            sys.exit(1)

        try:
            checkpoint_path = tf.train.latest_checkpoint(expdir)
        except tf.errors.NotFoundError:
            tf.logging.fatal(
                "There was a problem determining the latest checkpoint.")
            sys.exit(1)

    if not tf.train.checkpoint_exists(checkpoint_path):
        tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
        sys.exit(1)

    tf.logging.info("Will restore from checkpoint: %s", checkpoint_path)

    wavdir = FLAGS.wavdir
    tf.logging.info("Will load Wavs from %s." % wavdir)

    savedir = FLAGS.savedir
    tf.logging.info("Will save embeddings to %s." % savedir)
    if not tf.gfile.Exists(savedir):
        tf.logging.info("Creating save directory...")
        tf.gfile.MakeDirs(savedir)

    tf.logging.info("Building graph")
    with tf.Graph().as_default(), tf.device("/gpu:0"):
        sample_length = FLAGS.sample_length
        batch_size = FLAGS.batch_size
        wav_placeholder = tf.placeholder(tf.float32,
                                         shape=[batch_size, sample_length])
        graph = config.build({"wav": wav_placeholder}, is_training=False)
        graph_encoding = graph["encoding"]

        ema = tf.train.ExponentialMovingAverage(decay=0.9999)
        variables_to_restore = ema.variables_to_restore()

        # Create a saver, which is used to restore the parameters from checkpoints
        saver = tf.train.Saver(variables_to_restore)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        # Set the opt_level to prevent py_funcs from being executed multiple times.
        session_config.graph_options.optimizer_options.opt_level = 2
        sess = tf.Session("", config=session_config)

        tf.logging.info("\tRestoring from checkpoint.")
        saver.restore(sess, checkpoint_path)

        def is_wav(f):
            return f.lower().endswith(".wav")

        wavfiles = sorted([
            os.path.join(wavdir, fname)
            for fname in tf.gfile.ListDirectory(wavdir) if is_wav(fname)
        ])

        for start_file in range(0, len(wavfiles), batch_size):
            batch_number = (start_file // batch_size) + 1
            tf.logging.info("On file number %s (batch %d).", start_file,
                            batch_number)
            end_file = start_file + batch_size
            wavefiles_batch = wavfiles[start_file:end_file]

            # Ensure that files has batch_size elements.
            batch_filler = batch_size - len(wavefiles_batch)
            wavefiles_batch.extend(batch_filler * [wavefiles_batch[-1]])

            wavdata = np.array(
                [utils.load_wav(f)[:sample_length] for f in wavefiles_batch])

            try:
                encoding = sess.run(graph_encoding,
                                    feed_dict={wav_placeholder: wavdata})
                for num, (wavfile,
                          enc) in enumerate(zip(wavefiles_batch, encoding)):
                    filename = "%s_embeddings.npy" % wavfile.split(
                        "/")[-1].strip(".wav")
                    with tf.gfile.Open(os.path.join(savedir, filename),
                                       "w") as f:
                        np.save(f, enc)

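                    # Stop after the last real file so embeddings are not
                    # written for the duplicates appended as batch padding.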
                    if num + batch_filler + 1 == batch_size:
                        break
            except Exception as e:
                tf.logging.info("Unexpected error happened: %s.", e)
                raise
Example #4
  def eval(self):
    FLAGS = self.FLAGS
    sample_length = FLAGS.sample_length
    batch_size = FLAGS.total_batch_size

    if FLAGS.ckpt_id is not None:
      checkpoint_path = os.path.join(FLAGS.train_path, "model.ckpt-%d" % FLAGS.ckpt_id)
    else:
      tf.logging.info("Will load latest checkpoint from %s.", FLAGS.train_path)
      if not tf.gfile.Exists(FLAGS.train_path):
        tf.logging.fatal("\tTrained model save dir '%s' does not exist!", FLAGS.train_path)
        sys.exit(1)

      try:
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_path)
      except tf.errors.NotFoundError:
        tf.logging.fatal("There was a problem determining the latest checkpoint.")
        sys.exit(1)

    if not tf.train.checkpoint_exists(checkpoint_path):
      tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
      sys.exit(1)

    tf.logging.info("Will restore from checkpoint: %s", checkpoint_path)

    wavdir = FLAGS.eval_wav_path
    tf.logging.info("Will load Wavs from %s." % wavdir)

    with tf.Graph().as_default() as graph:
      # build model
      sample_length = FLAGS.sample_length
      wav_placeholder = tf.placeholder(
          tf.float32, shape=[batch_size, sample_length])

      model = self.model.build_eval_model(wav_placeholder)

      with tf.Session(config=self.sess_config) as sess:
        # load trained model
        if checkpoint_path is None:
          raise RuntimeError("No checkpoint is given")
        else:
          variables_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
          restorer = tf.train.Saver(variables_to_restore)
          restorer.restore(sess, checkpoint_path)
          tf.logging.info("Complete restoring parameters from %s" % checkpoint_path)
        # input wavs
        def is_wav(f):
          return f.lower().endswith(".wav")

        wavfiles = sorted([
          os.path.join(wavdir, fname) for fname in tf.gfile.ListDirectory(wavdir)
          if is_wav(fname)
        ])

        def get_fnames(files):
          fnames_list = []
          for f in files:
            fnames_list.append(ntpath.basename(f))
          return fnames_list

        tf.logging.info("wavfiles %d", len(wavfiles))

        for start_file in range(0, len(wavfiles), batch_size):
          batch_number = (start_file // batch_size) + 1
          tf.logging.info("On batch %d.", batch_number)
          end_file = start_file + batch_size
          files = wavfiles[start_file:end_file]
          wavfile_names = get_fnames(files)

          # Ensure that files has batch_size elements.
          batch_filler = batch_size - len(files)
          files.extend(batch_filler * [files[-1]])

          wavdatas = np.array([utils.load_wav(f)[:sample_length] for f in files])

          # transfer music
          decoded_wavs = sess.run(model['decoding'],
                              feed_dict={wav_placeholder: wavdatas})
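          # utils.inv_mu_law presumably builds TF ops (hence .shape.as_list()
          # and .eval() below), so transferred_wav is a Tensor, not a NumPy
          # array.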
          transferred_wav = utils.inv_mu_law(decoded_wavs - 128)

          def write_wav(waveform, sample_rate, pathname, wavfile_name):
            filename = "%s_decode.wav" % wavfile_name.strip(".wav")
            pathname += "/"+filename
            y = np.array(waveform)
            librosa.output.write_wav(pathname, y, sample_rate)
            print('Updated wav file at {}'.format(pathname))

          tf.logging.info("wavdatas %d", len(wavdatas))
          tf.logging.info("wavfile_names %d", len(wavfile_names))
          tf.logging.info("transferred_wav %s", str(transferred_wav.shape.as_list()))

          for wav_file, filename in zip(transferred_wav.eval(), wavfile_names):
            write_wav(wav_file, FLAGS.sample_rate, FLAGS.transferred_save_path, filename)

    return