Example #1
def main(unused_argv):
  tf.logging.set_verbosity(FLAGS.log)

  if not tf.gfile.Exists(FLAGS.logdir):
    tf.gfile.MakeDirs(FLAGS.logdir)

  with tf.Graph().as_default():

    # If ps_tasks is 0, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    model = utils.get_module("baseline.models.%s" % FLAGS.model)
    hparams = model.get_hparams(FLAGS.config)

    # Run the Reader on the CPU
    cpu_device = ("/job:worker/cpu:0" if FLAGS.ps_tasks else
                  "/job:localhost/replica:0/task:0/cpu:0")

    with tf.device(cpu_device):
      with tf.name_scope("Reader"):
        batch = reader.NSynthDataset(
            FLAGS.train_path, is_training=True).get_baseline_batch(hparams)

    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks)):
      train_op = model.train_op(batch, hparams, FLAGS.config)

      # Run training
      slim.learning.train(
          train_op=train_op,
          logdir=FLAGS.logdir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=hparams.max_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
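
This script assumes a set of command-line flags (FLAGS.log, FLAGS.logdir, FLAGS.master, FLAGS.task, FLAGS.ps_tasks, FLAGS.model, FLAGS.config, FLAGS.train_path, FLAGS.save_summaries_secs, FLAGS.save_interval_secs) defined elsewhere in the module. Below is a minimal sketch of how they might be declared with tf.app.flags; the flag names come from the code above, while the defaults and help strings are illustrative assumptions.

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

# Illustrative declarations only; the repository's defaults may differ.
tf.app.flags.DEFINE_string("master", "", "Name of the TensorFlow master to use.")
tf.app.flags.DEFINE_string("train_path", "", "Path to the training TFRecord file.")
tf.app.flags.DEFINE_string("logdir", "/tmp/baseline/train",
                           "Directory for checkpoints and summaries.")
tf.app.flags.DEFINE_string("model", "ae", "Model module to use (baseline.models.<model>).")
tf.app.flags.DEFINE_string("config", "nfft_1024", "Config module to use.")
tf.app.flags.DEFINE_integer("task", 0, "Task id of the replica running the training.")
tf.app.flags.DEFINE_integer("ps_tasks", 0,
                            "Number of parameter server tasks. 0 means local training.")
tf.app.flags.DEFINE_integer("save_summaries_secs", 15, "Seconds between summary saves.")
tf.app.flags.DEFINE_integer("save_interval_secs", 30, "Seconds between checkpoint saves.")
tf.app.flags.DEFINE_string("log", "INFO",
                           "Logging threshold: DEBUG, INFO, WARN, ERROR, or FATAL.")
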
Example #2
def train_op(batch, hparams, config_name):
  """Define a training op, including summaries and optimization.

  Args:
    batch: Dictionary produced by NSynthDataset.
    hparams: Hyperparameters dictionary.
    config_name: Name of config module.

  Returns:
    train_op: A complete iteration of training with summaries.
  """
  config = utils.get_module("baseline.models.ae_configs.%s" % config_name)

  if hparams.raw_audio:
    x = batch["audio"]
    # Add height and channel dims
    x = tf.expand_dims(tf.expand_dims(x, 1), -1)
  else:
    x = batch["spectrogram"]

  # Define the model
  with tf.name_scope("Model"):
    z = config.encode(x, hparams)
    xhat = config.decode(z, batch, hparams)

  # For interpolation
  tf.add_to_collection("x", x)
  tf.add_to_collection("pitch", batch["pitch"])
  tf.add_to_collection("z", z)
  tf.add_to_collection("xhat", xhat)

  # Compute losses
  total_loss = compute_mse_loss(x, xhat, hparams)

  # Apply optimizer
  with tf.name_scope("Optimizer"):
    global_step = tf.get_variable(
        "global_step", [],
        tf.int64,
        initializer=tf.constant_initializer(0),
        trainable=False)
    optimizer = tf.train.AdamOptimizer(hparams.learning_rate, hparams.adam_beta)
    train_step = slim.learning.create_train_op(total_loss,
                                               optimizer,
                                               global_step=global_step)

  return train_step
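
train_op depends on a compute_mse_loss helper that is not reproduced in this listing. A minimal sketch consistent with how it is called above (x, xhat, hparams) follows; it assumes the same module-level tf import as the examples and is not the repository's exact implementation, which may weight phase or frequency components separately (see the phase_loss_coeff and fw_loss_* hyperparameters in Example #3).

def compute_mse_loss(x, xhat, hparams):
  """Illustrative mean-squared-error reconstruction loss (a sketch, not the repo's code)."""
  del hparams  # Accepted for signature compatibility; the real loss may use it.
  with tf.name_scope("Losses"):
    mse = tf.reduce_mean(tf.squared_difference(x, xhat), name="mse")
    tf.summary.scalar("train_mse", mse)
  return mse
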
Example #3
def get_hparams(config_name):
  """Set hyperparameters.

  Args:
    config_name: Name of config module to use.

  Returns:
    A HParams object (magenta) with defaults.
  """
  hparams = HParams(
      # Optimization
      batch_size=16,
      learning_rate=1e-4,
      adam_beta=0.5,
      max_steps=6000 * 50000,
      samples_per_second=16000,
      num_samples=64000,
      # Preprocessing
      n_fft=1024,
      hop_length=256,
      mask=True,
      log_mag=True,
      use_cqt=False,
      re_im=False,
      dphase=True,
      mag_only=False,
      pad=True,
      mu_law_num=0,
      raw_audio=False,
      # Graph
      num_latent=64,  # dimension of z.
      cost_phase_mask=False,
      phase_loss_coeff=1.0,
      fw_loss_coeff=1.0,  # Frequency weighted cost
      fw_loss_cutoff=1000,
  )
  # Set values from a dictionary in the config
  config = utils.get_module("baseline.models.ae_configs.%s" % config_name)
  if hasattr(config, "config_hparams"):
    config_hparams = config.config_hparams
    hparams.update(config_hparams)
  return hparams
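
A config module can override any of these defaults by exposing a config_hparams dictionary, as the last block above shows. The hypothetical snippet below illustrates the mechanism; the module name my_config and the overridden values are assumptions.

# Hypothetical contents of baseline/models/ae_configs/my_config.py:
config_hparams = {
    "batch_size": 8,
    "num_latent": 128,
    "raw_audio": True,
}

# Elsewhere, the merged hyperparameters are picked up by name:
hparams = get_hparams("my_config")
assert hparams.num_latent == 128  # overridden by the config module
assert hparams.n_fft == 1024      # still the default
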
Example #4
def main(unused_argv=None):
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.config is None:
    raise RuntimeError("No config name specified.")

  config = utils.get_module("wavenet." + FLAGS.config).Config(
      FLAGS.train_path)

  logdir = FLAGS.logdir
  tf.logging.info("Saving to %s" % logdir)

  with tf.Graph().as_default():
    total_batch_size = FLAGS.total_batch_size
    assert total_batch_size % FLAGS.worker_replicas == 0
    worker_batch_size = total_batch_size // FLAGS.worker_replicas

    # Run the Reader on the CPU
    cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
    if FLAGS.ps_tasks:
      cpu_device = "/job:worker/cpu:0"

    with tf.device(cpu_device):
      inputs_dict = config.get_batch(worker_batch_size)

    with tf.device(
        tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks,
                                       merge_devices=True)):
      global_step = tf.get_variable(
          "global_step", [],
          tf.int32,
          initializer=tf.constant_initializer(0),
          trainable=False)

      # pylint: disable=cell-var-from-loop
      lr = tf.constant(config.learning_rate_schedule[0])
      for key, value in config.learning_rate_schedule.items():
        lr = tf.cond(
            tf.less(global_step, key), lambda: lr, lambda: tf.constant(value))
      # pylint: enable=cell-var-from-loop
      tf.summary.scalar("learning_rate", lr)

      # build the model graph
      outputs_dict = config.build(inputs_dict, is_training=True)
      loss = outputs_dict["loss"]
      tf.summary.scalar("train_loss", loss)

      worker_replicas = FLAGS.worker_replicas
      ema = tf.train.ExponentialMovingAverage(
          decay=0.9999, num_updates=global_step)
      opt = tf.train.SyncReplicasOptimizer(
          tf.train.AdamOptimizer(lr, epsilon=1e-8),
          worker_replicas,
          total_num_replicas=worker_replicas,
          variable_averages=ema,
          variables_to_average=tf.trainable_variables())

      train_op = opt.minimize(
          loss,
          global_step=global_step,
          name="train",
          colocate_gradients_with_ops=True)

      session_config = tf.ConfigProto(allow_soft_placement=True)

      is_chief = (FLAGS.task == 0)
      local_init_op = opt.chief_init_op if is_chief else opt.local_step_init_op

      slim.learning.train(
          train_op=train_op,
          logdir=logdir,
          is_chief=is_chief,
          master=FLAGS.master,
          number_of_steps=config.num_iters,
          global_step=global_step,
          log_every_n_steps=250,
          local_init_op=local_init_op,
          save_interval_secs=300,
          sync_optimizer=opt,
          session_config=session_config,)
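
The chain of tf.cond ops above implements a piecewise-constant learning-rate schedule: config.learning_rate_schedule is a dict mapping a global-step boundary to the rate used once that step is reached. A pure-Python equivalent, assuming the boundaries are visited in increasing order (which is what the graph construction relies on), is sketched below with an invented schedule.

def lr_at_step(step, schedule):
  """Return the learning rate in effect at `step` (illustrative equivalent of the tf.cond chain)."""
  lr = schedule[0]  # the schedule is assumed to contain a base rate at step 0
  for boundary in sorted(schedule):
    if step >= boundary:
      lr = schedule[boundary]
  return lr

schedule = {0: 2e-4, 100000: 1e-4, 200000: 6e-5}  # assumed values, not the repo's
assert lr_at_step(150000, schedule) == 1e-4
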
Example #5
def main(unused_argv):
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.checkpoint_path:
    checkpoint_path = FLAGS.checkpoint_path
  else:
    expdir = FLAGS.expdir
    tf.logging.info("Will load latest checkpoint from %s.", expdir)
    while not tf.gfile.Exists(expdir):
      tf.logging.fatal("\tExperiment save dir '%s' does not exist!", expdir)
      sys.exit(1)

    try:
      checkpoint_path = tf.train.latest_checkpoint(expdir)
    except tf.errors.NotFoundError:
      tf.logging.fatal("There was a problem determining the latest checkpoint.")
      sys.exit(1)

  if not tf.train.checkpoint_exists(checkpoint_path):
    tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
    sys.exit(1)

  savedir = FLAGS.savedir
  if not tf.gfile.Exists(savedir):
    tf.gfile.MakeDirs(savedir)

  # Make the graph
  with tf.Graph().as_default():
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
      model = utils.get_module("baseline.models.%s" % FLAGS.model)
      hparams = model.get_hparams(FLAGS.config)

      # Load the trained model with is_training=False
      with tf.name_scope("Reader"):
        batch = reader.NSynthDataset(
            FLAGS.tfrecord_path,
            is_training=False).get_baseline_batch(hparams)

      _ = model.train_op(batch, hparams, FLAGS.config)
      z = tf.get_collection("z")[0]

      init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
      sess.run(init_op)

      # Add ops to save and restore all the variables.
      # Restore variables from disk.
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)
      tf.logging.info("Model restored.")

      # Start up some threads
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      i = 0
      z_val = []
      try:
        while True:
          if coord.should_stop():
            break
          res_val = sess.run([z])
          z_val.append(res_val[0])
          tf.logging.info("Iter: %d" % i)
          tf.logging.info("Z:{}".format(res_val[0].shape))
          i += 1
          if (i + 1) % 1 == 0:  # saves every iteration; adjust the modulus to save less often
            save_arrays(savedir, hparams, z_val)
      # Report all exceptions to the coordinator, pylint: disable=broad-except
      except Exception as e:
        coord.request_stop(e)
      # pylint: enable=broad-except
      finally:
        save_arrays(savedir, hparams, z_val)
        # Terminate as usual.  It is innocuous to request stop twice.
        coord.request_stop()
        coord.join(threads)
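
save_arrays is referenced above but not shown. A minimal sketch consistent with how it is called (savedir, hparams, a list of per-batch z arrays) could look like the following; the stacking and file naming are assumptions, not the repository's exact behavior.

import os

import numpy as np
import tensorflow as tf


def save_arrays(savedir, hparams, z_val):
  """Illustrative helper: stack the accumulated latents and write them to disk."""
  if not z_val:
    return
  z_arr = np.concatenate(z_val, axis=0)
  out_path = os.path.join(savedir, "z_num_latent_%d.npy" % hparams.num_latent)
  with tf.gfile.Open(out_path, "wb") as f:
    np.save(f, z_arr)
  tf.logging.info("Saved %d embeddings to %s", z_arr.shape[0], out_path)
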
Example #6
def main(unused_argv=None):
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.config is None:
    raise RuntimeError("No config name specified.")

  config = utils.get_module("wavenet." + FLAGS.config).Config()

  if FLAGS.checkpoint_path:
    ckpt_path = FLAGS.checkpoint_path
  else:
    expdir = FLAGS.expdir
    tf.logging.info("Will load latest checkpoint from %s.", expdir)
    while not tf.gfile.Exists(expdir):
      tf.logging.fatal("\tExperiment save dir '%s' does not exist!", expdir)
      sys.exit(1)

    try:
      ckpt_path = tf.train.latest_checkpoint(expdir)
    except tf.errors.NotFoundError:
      tf.logging.fatal("There was a problem determining the latest checkpoint.")
      sys.exit(1)

  if not tf.train.checkpoint_exists(ckpt_path):
    tf.logging.fatal("Invalid checkpoint path: %s", ckpt_path)
    sys.exit(1)

  tf.logging.info("Will restore from checkpoint: %s", ckpt_path)

  wavdir = FLAGS.wavdir
  tf.logging.info("Will load Wavs from %s." % wavdir)

  savedir = FLAGS.savedir
  tf.logging.info("Will save embeddings to %s." % savedir)
  if not tf.gfile.Exists(savedir):
    tf.logging.info("Creating save directory...")
    tf.gfile.MakeDirs(savedir)

  tf.logging.info("Building graph")
  with tf.Graph().as_default(), tf.device("/gpu:0"):
    sample_length = FLAGS.sample_length
    batch_size = FLAGS.batch_size
    wav_placeholder = tf.placeholder(
        tf.float32, shape=[batch_size, sample_length])
    graph = config.build({"wav": wav_placeholder}, is_training=False)
    graph_encoding = graph["encoding"]

    ema = tf.train.ExponentialMovingAverage(decay=0.9999)
    variables_to_restore = ema.variables_to_restore()

    # Create a saver, which is used to restore the parameters from checkpoints
    saver = tf.train.Saver(variables_to_restore)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    # Set the opt_level to prevent py_funcs from being executed multiple times.
    session_config.graph_options.optimizer_options.opt_level = 2
    sess = tf.Session("", config=session_config)

    tf.logging.info("\tRestoring from checkpoint.")
    saver.restore(sess, ckpt_path)

    def is_wav(f):
      return f.lower().endswith(".wav")

    wavfiles = sorted([
        os.path.join(wavdir, fname) for fname in tf.gfile.ListDirectory(wavdir)
        if is_wav(fname)
    ])

    for start_file in range(0, len(wavfiles), batch_size):
      batch_number = (start_file // batch_size) + 1
      tf.logging.info("On file number %s (batch %d).", start_file, batch_number)
      end_file = start_file + batch_size
      files = wavfiles[start_file:end_file]

      # Ensure that files has batch_size elements.
      batch_filler = batch_size - len(files)
      files.extend(batch_filler * [files[-1]])

      wavdata = np.array([utils.load_wav(f)[:sample_length] for f in files])

      try:
        encoding = sess.run(
            graph_encoding, feed_dict={wav_placeholder: wavdata})
        for num, (wavfile, enc) in enumerate(zip(files, encoding)):
          filename = "%s_embeddings.npy" % os.path.splitext(
              os.path.basename(wavfile))[0]
          with tf.gfile.Open(os.path.join(savedir, filename), "wb") as f:
            np.save(f, enc)

          if num + batch_filler + 1 == batch_size:
            break
      except Exception as e:
        tf.logging.info("Unexpected error happened: %s.", e)
        raise
Example #7
def eval_op(batch, hparams, config_name):
  """Define a evaluation op.

  Args:
    batch: Batch produced by NSynthReader.
    hparams: Hyperparameters.
    config_name: Name of config module.

  Returns:
    eval_op: A complete evaluation op with summaries.
  """
  phase = not (hparams.mag_only or hparams.raw_audio)

  config = utils.get_module("baseline.models.ae_configs.%s" % config_name)
  if hparams.raw_audio:
    x = batch["audio"]
    # Add height and channel dims
    x = tf.expand_dims(tf.expand_dims(x, 1), -1)
  else:
    x = batch["spectrogram"]

  # Define the model
  with tf.name_scope("Model"):
    z = config.encode(x, hparams, is_training=False)
    xhat = config.decode(z, batch, hparams, is_training=False)

  # For interpolation
  tf.add_to_collection("x", x)
  tf.add_to_collection("pitch", batch["pitch"])
  tf.add_to_collection("z", z)
  tf.add_to_collection("xhat", xhat)

  total_loss = compute_mse_loss(x, xhat, hparams)

  # Define the metrics:
  names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
      "Loss": slim.metrics.mean(total_loss),
  })

  # Define the summaries
  for name, value in names_to_values.items():
    slim.summaries.add_scalar_summary(value, name, print_summary=True)

  # Interpolate
  with tf.name_scope("Interpolation"):
    xhat = config.decode(z, batch, hparams, reuse=True, is_training=False)

    # Linear interpolation
    z_shift_one_example = tf.concat([z[1:], z[:1]], 0)
    z_linear_half = (z + z_shift_one_example) / 2.0
    xhat_linear_half = config.decode(z_linear_half, batch, hparams, reuse=True,
                                     is_training=False)

    # Pitch shift

    pitch_plus_2 = tf.clip_by_value(batch["pitch"] + 2, 0, 127)
    pitch_minus_2 = tf.clip_by_value(batch["pitch"] - 2, 0, 127)

    batch["pitch"] = pitch_minus_2
    xhat_pitch_minus_2 = config.decode(z, batch, hparams,
                                       reuse=True, is_training=False)
    batch["pitch"] = pitch_plus_2
    xhat_pitch_plus_2 = config.decode(z, batch, hparams,
                                      reuse=True, is_training=False)

  utils.specgram_summaries(x, "Training Examples", hparams, phase=phase)
  utils.specgram_summaries(xhat, "Reconstructions", hparams, phase=phase)
  utils.specgram_summaries(
      x - xhat, "Difference", hparams, audio=False, phase=phase)
  utils.specgram_summaries(
      xhat_linear_half, "Linear Interp. 0.5", hparams, phase=phase)
  utils.specgram_summaries(xhat_pitch_plus_2, "Pitch +2", hparams, phase=phase)
  utils.specgram_summaries(xhat_pitch_minus_2, "Pitch -2", hparams, phase=phase)

  return names_to_updates.values()
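
eval_op only builds the metric update ops; a caller is expected to hand them to one of the slim evaluation drivers. A hedged sketch of that wiring is below, assuming the same module-level imports (tf, slim, reader, utils, FLAGS) as the examples above and a batch built with the NSynth reader as in Example #1; checkpoint_dir, eval_logdir, and num_batches are assumed flag names.

# Illustrative only: run the metric updates in a continuous evaluation loop.
eval_ops = list(eval_op(batch, hparams, FLAGS.config))
slim.evaluation.evaluation_loop(
    master=FLAGS.master,
    checkpoint_dir=FLAGS.checkpoint_dir,  # assumed flag
    logdir=FLAGS.eval_logdir,             # assumed flag
    num_evals=FLAGS.num_batches,          # assumed flag
    eval_op=eval_ops,
    eval_interval_secs=300)
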
Example #8
File: ae.py Project: yynst2/magenta
def eval_op(batch, hparams, config_name):
    """Define a evaluation op.

  Args:
    batch: Batch produced by NSynthReader.
    hparams: Hyperparameters.
    config_name: Name of config module.

  Returns:
    eval_op: A complete evaluation op with summaries.
  """
    phase = not (hparams.mag_only or hparams.raw_audio)

    config = utils.get_module("baseline.models.ae_configs.%s" % config_name)
    if hparams.raw_audio:
        x = batch["audio"]
        # Add height and channel dims
        x = tf.expand_dims(tf.expand_dims(x, 1), -1)
    else:
        x = batch["spectrogram"]

    # Define the model
    with tf.name_scope("Model"):
        z = config.encode(x, hparams, is_training=False)
        xhat = config.decode(z, batch, hparams, is_training=False)

    # For interpolation
    tf.add_to_collection("x", x)
    tf.add_to_collection("pitch", batch["pitch"])
    tf.add_to_collection("z", z)
    tf.add_to_collection("xhat", xhat)

    total_loss = compute_mse_loss(x, xhat, hparams)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        "Loss":
        slim.metrics.mean(total_loss),
    })

    # Define the summaries
    for name, value in names_to_values.items():
        slim.summaries.add_scalar_summary(value, name, print_summary=True)

    # Interpolate
    with tf.name_scope("Interpolation"):
        xhat = config.decode(z, batch, hparams, reuse=True, is_training=False)

        # Linear interpolation
        z_shift_one_example = tf.concat([z[1:], z[:1]], 0)
        z_linear_half = (z + z_shift_one_example) / 2.0
        xhat_linear_half = config.decode(z_linear_half,
                                         batch,
                                         hparams,
                                         reuse=True,
                                         is_training=False)

        # Pitch shift

        pitch_plus_2 = tf.clip_by_value(batch["pitch"] + 2, 0, 127)
        pitch_minus_2 = tf.clip_by_value(batch["pitch"] - 2, 0, 127)

        batch["pitch"] = pitch_minus_2
        xhat_pitch_minus_2 = config.decode(z,
                                           batch,
                                           hparams,
                                           reuse=True,
                                           is_training=False)
        batch["pitch"] = pitch_plus_2
        xhat_pitch_plus_2 = config.decode(z,
                                          batch,
                                          hparams,
                                          reuse=True,
                                          is_training=False)

    utils.specgram_summaries(x, "Training Examples", hparams, phase=phase)
    utils.specgram_summaries(xhat, "Reconstructions", hparams, phase=phase)
    utils.specgram_summaries(x - xhat,
                             "Difference",
                             hparams,
                             audio=False,
                             phase=phase)
    utils.specgram_summaries(xhat_linear_half,
                             "Linear Interp. 0.5",
                             hparams,
                             phase=phase)
    utils.specgram_summaries(xhat_pitch_plus_2,
                             "Pitch +2",
                             hparams,
                             phase=phase)
    utils.specgram_summaries(xhat_pitch_minus_2,
                             "Pitch -2",
                             hparams,
                             phase=phase)

    return names_to_updates.values()
Example #9
def main(unused_argv=None):
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.config is None:
        raise RuntimeError("No config name specified.")

    config = utils.get_module("wavenet." + FLAGS.config).Config(
        FLAGS.train_path)

    logdir = FLAGS.logdir
    tf.logging.info("Saving to %s" % logdir)

    with tf.Graph().as_default():
        total_batch_size = FLAGS.total_batch_size
        assert total_batch_size % FLAGS.worker_replicas == 0
        worker_batch_size = total_batch_size // FLAGS.worker_replicas

        # Run the Reader on the CPU
        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        if FLAGS.ps_tasks:
            cpu_device = "/job:worker/cpu:0"

        with tf.device(cpu_device):
            inputs_dict = config.get_batch(worker_batch_size)

        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks,
                                               merge_devices=True)):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

            # pylint: disable=cell-var-from-loop
            lr = tf.constant(config.learning_rate_schedule[0])
            for key, value in config.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            # pylint: enable=cell-var-from-loop
            tf.summary.scalar("learning_rate", lr)

            # build the model graph
            outputs_dict = config.build(inputs_dict, is_training=True)
            loss = outputs_dict["loss"]
            tf.summary.scalar("train_loss", loss)

            worker_replicas = FLAGS.worker_replicas
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            opt = tf.train.SyncReplicasOptimizer(
                tf.train.AdamOptimizer(lr, epsilon=1e-8),
                worker_replicas,
                total_num_replicas=worker_replicas,
                variable_averages=ema,
                variables_to_average=tf.trainable_variables())

            train_op = opt.minimize(loss,
                                    global_step=global_step,
                                    name="train",
                                    colocate_gradients_with_ops=True)

            session_config = tf.ConfigProto(allow_soft_placement=True)

            is_chief = (FLAGS.task == 0)
            local_init_op = opt.chief_init_op if is_chief else opt.local_step_init_op

            slim.learning.train(
                train_op=train_op,
                logdir=logdir,
                is_chief=is_chief,
                master=FLAGS.master,
                number_of_steps=config.num_iters,
                global_step=global_step,
                log_every_n_steps=250,
                local_init_op=local_init_op,
                save_interval_secs=300,
                sync_optimizer=opt,
                session_config=session_config,
            )
Example #10
def main(unused_argv):
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.checkpoint_path:
        checkpoint_path = FLAGS.checkpoint_path
    else:
        expdir = FLAGS.expdir
        tf.logging.info("Will load latest checkpoint from %s.", expdir)
        while not tf.gfile.Exists(expdir):
            tf.logging.fatal("\tExperiment save dir '%s' does not exist!",
                             expdir)
            sys.exit(1)

        try:
            checkpoint_path = tf.train.latest_checkpoint(expdir)
        except tf.errors.NotFoundError:
            tf.logging.fatal(
                "There was a problem determining the latest checkpoint.")
            sys.exit(1)

    if not tf.train.checkpoint_exists(checkpoint_path):
        tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
        sys.exit(1)

    savedir = FLAGS.savedir
    if not tf.gfile.Exists(savedir):
        tf.gfile.MakeDirs(savedir)

    # Make the graph
    with tf.Graph().as_default():
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            model = utils.get_module("baseline.models.%s" % FLAGS.model)
            hparams = model.get_hparams(FLAGS.config)

            # Load the trained model with is_training=False
            with tf.name_scope("Reader"):
                batch = reader.NSynthDataset(
                    FLAGS.tfrecord_path,
                    is_training=False).get_baseline_batch(hparams)

            _ = model.train_op(batch, hparams, FLAGS.config)
            z = tf.get_collection("z")[0]

            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)

            # Add ops to save and restore all the variables.
            # Restore variables from disk.
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)
            tf.logging.info("Model restored.")

            # Start up some threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            i = 0
            z_val = []
            try:
                while True:
                    if coord.should_stop():
                        break
                    res_val = sess.run([z])
                    z_val.append(res_val[0])
                    tf.logging.info("Iter: %d" % i)
                    tf.logging.info("Z:{}".format(res_val[0].shape))
                    i += 1
                    if (i + 1) % 1 == 0:  # saves every iteration; adjust the modulus to save less often
                        save_arrays(savedir, hparams, z_val)
            # Report all exceptions to the coordinator, pylint: disable=broad-except
            except Exception as e:
                coord.request_stop(e)
            # pylint: enable=broad-except
            finally:
                save_arrays(savedir, hparams, z_val)
                # Terminate as usual.  It is innocuous to request stop twice.
                coord.request_stop()
                coord.join(threads)
Example #11
def main(unused_argv=None):
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.config is None:
    raise RuntimeError("No config name specified.")

  logdir = FLAGS.logdir
  tf.logging.info("Saving to %s" % logdir)
  train_files = glob.glob(FLAGS.train_path + "/*")
  assert len(train_files) == FLAGS.gpu

  with tf.Graph().as_default():
    total_batch_size = FLAGS.total_batch_size
    assert total_batch_size % FLAGS.gpu == 0
    worker_batch_size = total_batch_size // FLAGS.gpu
    config = utils.get_module("ours." + FLAGS.config).Config(worker_batch_size)

    # Run the Reader on the CPU
    cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
    if FLAGS.ps_tasks:
      cpu_device = "/job:worker/cpu:0"

    with tf.variable_scope('ours_model_var_scope') as var_scope:
      with tf.device(cpu_device):
        global_step = tf.get_variable(
            "global_step", [],
            tf.int32,
            initializer=tf.constant_initializer(0),
            trainable=False)

        # pylint: disable=cell-var-from-loop
        lr = tf.constant(config.learning_rate_schedule[0])
        for key, value in config.learning_rate_schedule.items():
          lr = tf.cond(
              tf.less(global_step, key), lambda: lr, lambda: tf.constant(value))
        # pylint: enable=cell-var-from-loop

        losses = []
        for i in range(FLAGS.gpu):
          inputs_dict = config.get_batch(train_files[i])
          with tf.device('/gpu:%d' % i):
            with tf.name_scope('GPU_NAME_SCOPE_%d' % i):
              # build the model graph
              encode_dict = config.encode(inputs_dict["wav"])
              decode_dict = config.decode(encode_dict["encoding"])
              loss_dict = config.loss(encode_dict["x_quantized"], decode_dict["logits"])
              loss = loss_dict["loss"]
              losses.append(loss)
              var_scope.reuse_variables()

        avg_loss = tf.reduce_mean(losses, 0)

        worker_replicas = FLAGS.worker_replicas
        ema = tf.train.ExponentialMovingAverage(
            decay=0.9999, num_updates=global_step)

    # with tf.variable_scope('ours_model_var_scope') as var_scope ENDS HERE

    opt = tf.train.SyncReplicasOptimizer(
        tf.train.AdamOptimizer(lr, epsilon=1e-8),
        worker_replicas,
        total_num_replicas=worker_replicas,
        variable_averages=ema,
        variables_to_average=tf.trainable_variables())

    train_op = slim.learning.create_train_op(
        avg_loss, opt, global_step=global_step,
        colocate_gradients_with_ops=True)

    session_config = tf.ConfigProto(allow_soft_placement=True)

    is_chief = (FLAGS.task == 0)
    local_init_op = opt.chief_init_op if is_chief else opt.local_step_init_op

    slim.learning.train(
        train_op=train_op,
        logdir=logdir,
        is_chief=is_chief,
        master=FLAGS.master,
        number_of_steps=FLAGS.num_iters,
        global_step=global_step,
        log_every_n_steps=FLAGS.log_period,
        local_init_op=local_init_op,
        save_interval_secs=FLAGS.ckpt_period,
        sync_optimizer=opt,
        session_config=session_config,)
Example #12
def main(unused_argv=None):
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.config is None:
        raise RuntimeError("No config name specified.")

    config = utils.get_module("ours." + FLAGS.config).Config(FLAGS.batch_size)

    if FLAGS.checkpoint_path:
        checkpoint_path = FLAGS.checkpoint_path
    else:
        raise RuntimeError("Must specify checkpoint path")

    if not tf.train.checkpoint_exists(checkpoint_path):
        tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
        sys.exit(1)

    tf.logging.info("Will restore from checkpoint: %s", checkpoint_path)

    wavdir = FLAGS.wavdir
    tf.logging.info("Will load Wavs from %s." % wavdir)

    savedir = FLAGS.savedir
    tf.logging.info("Will save embeddings to %s." % savedir)
    if not tf.gfile.Exists(savedir):
        tf.logging.info("Creating save directory...")
        tf.gfile.MakeDirs(savedir)

    tf.logging.info("Building graph")
    with tf.Graph().as_default(), tf.device("/gpu:0"):
        with tf.variable_scope('ours_model_var_scope') as var_scope:
            sample_length = FLAGS.sample_length
            batch_size = FLAGS.batch_size
            wav_placeholder = tf.placeholder(tf.float32,
                                             shape=[batch_size, sample_length])
            encode_op = config.encode(wav_placeholder)["encoding"]

        ema = tf.train.ExponentialMovingAverage(decay=0.9999)
        variables_to_restore = ema.variables_to_restore()

        # Create a saver, which is used to restore the parameters from checkpoints
        saver = tf.train.Saver(variables_to_restore)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        # Set the opt_level to prevent py_funcs from being executed multiple times.
        session_config.graph_options.optimizer_options.opt_level = 2
        sess = tf.Session("", config=session_config)

        tf.logging.info("\tRestoring from checkpoint.")
        saver.restore(sess, checkpoint_path)

        def is_wav(f):
            return f.lower().endswith(".wav")

        wavfiles = sorted([
            os.path.join(wavdir, fname)
            for fname in tf.gfile.ListDirectory(wavdir) if is_wav(fname)
        ])

        for start_file in range(0, len(wavfiles), batch_size):
            batch_number = (start_file // batch_size) + 1
            tf.logging.info("On file number %s (batch %d).", start_file,
                            batch_number)
            end_file = start_file + batch_size
            files = wavfiles[start_file:end_file]

            # Ensure that files has batch_size elements.
            batch_filler = batch_size - len(files)
            files.extend(batch_filler * [files[-1]])

            wavdata = np.array(
                [utils.load_wav(f)[:sample_length] for f in files])

            try:
                encoding = sess.run(encode_op,
                                    feed_dict={wav_placeholder: wavdata})
                for num, (wavfile, enc) in enumerate(zip(files, encoding)):
                    filename = "%s_embeddings.npy" % os.path.splitext(
                        os.path.basename(wavfile))[0]
                    with tf.gfile.Open(os.path.join(savedir, filename),
                                       "wb") as f:
                        np.save(f, enc)

                    if num + batch_filler + 1 == batch_size:
                        break
            except Exception as e:
                tf.logging.info("Unexpected error happened: %s.", e)
                raise
Example #13
def main(unused_argv=None):
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.config is None:
        raise RuntimeError("No config name specified.")

    if FLAGS.vae:
        config = utils.get_module("wavenet." + FLAGS.config).VAEConfig(
            FLAGS.eval_path,
            sample_length=FLAGS.sample_length,
            problem=FLAGS.problem,
            small=FLAGS.small,
            asymmetric=FLAGS.asymmetric,
            aux=FLAGS.aux_coefficient,
            dropout=FLAGS.input_dropout)
    else:
        config = utils.get_module("wavenet." + FLAGS.config).Config(
            FLAGS.eval_path,
            sample_length=FLAGS.sample_length,
            problem=FLAGS.problem,
            small=FLAGS.small,
            asymmetric=FLAGS.asymmetric)

    logdir = FLAGS.logdir
    tf.logging.info("Saving to %s" % logdir)

    with tf.Graph().as_default():
        total_batch_size = FLAGS.total_batch_size
        assert total_batch_size % FLAGS.worker_replicas == 0
        worker_batch_size = total_batch_size // FLAGS.worker_replicas

        # Run the Reader on the CPU
        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        if FLAGS.ps_tasks:
            cpu_device = "/job:worker/cpu:0"

        with tf.device(cpu_device):
            inputs_dict = config.get_batch(worker_batch_size,
                                           is_training=False)

        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks,
                                               merge_devices=True)):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

            # build the model graph
            outputs_dict = config.build(inputs_dict, is_training=False)

            if FLAGS.vae:
                if FLAGS.kl_annealing:
                    dist = tfp.distributions.Normal(
                        loc=FLAGS.annealing_loc, scale=FLAGS.annealing_scale)
                    annealing_rate = dist.cdf(tf.to_float(
                        global_step))  # how to adjust the annealing
                else:
                    annealing_rate = 0.
                kl = outputs_dict["loss"]["kl"]
                rec = outputs_dict["loss"]["rec"]
                aux = outputs_dict["loss"]["aux"]
                tf.summary.scalar("kl", kl)
                tf.summary.scalar("rec", rec)
                tf.summary.scalar("annealing_rate", annealing_rate)
                if FLAGS.kl_threshold is not None:
                    kl = tf.maximum(
                        tf.cast(FLAGS.kl_threshold, dtype=kl.dtype), kl)
                if FLAGS.aux_coefficient > 0:
                    tf.summary.scalar("aux", aux)
                loss = rec + annealing_rate * kl + tf.cast(
                    FLAGS.aux_coefficient, dtype=tf.float32) * aux
            else:
                loss = outputs_dict["loss"]

            tf.summary.scalar("train_loss", loss)

            labels = inputs_dict["parameters"]
            x_in = inputs_dict["wav"]
            batch_size, _ = x_in.get_shape().as_list()
            predictions = outputs_dict["predictions"]
            _, pred_dim = predictions.get_shape().as_list()
            predictions = tf.reshape(predictions, [batch_size, -1, pred_dim])
            encodings = outputs_dict["encoding"]

            session_config = tf.ConfigProto(allow_soft_placement=True)

            # Define the metrics:
            if FLAGS.vae:
                names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
                    {
                        'eval/kl':
                        slim.metrics.streaming_mean(kl),
                        'eval/rec':
                        slim.metrics.streaming_mean(rec),
                        'eval/loss':
                        slim.metrics.streaming_mean(loss),
                        'eval/predictions':
                        slim.metrics.streaming_concat(predictions),
                        'eval/labels':
                        slim.metrics.streaming_concat(labels),
                        'eval/encodings':
                        slim.metrics.streaming_concat(encodings),
                        'eval/audio':
                        slim.metrics.streaming_concat(x_in)
                    })
            else:
                names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
                    {
                        'eval/loss':
                        slim.metrics.streaming_mean(loss),
                        'eval/predictions':
                        slim.metrics.streaming_concat(predictions),
                        'eval/labels':
                        slim.metrics.streaming_concat(labels),
                        'eval/encodings':
                        slim.metrics.streaming_concat(encodings),
                        'eval/audio':
                        slim.metrics.streaming_concat(x_in)
                    })

            print('Running evaluation Loop...')
            if FLAGS.checkpoint_path is not None:
                checkpoint_path = FLAGS.checkpoint_path
            else:
                checkpoint_path = tf.train.latest_checkpoint(
                    FLAGS.checkpoint_dir)
            metric_values = slim.evaluation.evaluate_once(
                num_evals=FLAGS.num_evals,
                master=FLAGS.master,
                checkpoint_path=checkpoint_path,
                logdir=FLAGS.logdir,
                eval_op=names_to_updates.values(),
                final_op=names_to_values.values(),
                session_config=session_config)

            names_to_values = dict(zip(names_to_values.keys(), metric_values))

            losses = {}
            data_name = FLAGS.eval_path.split('/')[-1].split('.')[0]
            outpath = os.path.join(FLAGS.logdir, data_name)
            for k, v in names_to_values.items():
                name = k.split('/')[-1]
                if name in ['predictions', 'encodings', 'labels', 'audio']:
                    out = outpath + '-{}'.format(name)
                    if name == 'predictions':
                        v = np.argmax(v, axis=-1)
                        v = utils.inv_mu_law_numpy(v - 128)
                    np.save(out, v)
                else:
                    losses[name] = v

            out_loss = outpath + '-losses.pickle'
            with open(out_loss, 'wb') as w:
                pickle.dump(losses, w)
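
The evaluation script writes the concatenated predictions, encodings, labels, and audio as .npy files next to a pickled dict of scalar losses. A small sketch of reading those outputs back follows; the logdir value and data_name are assumptions (data_name is the basename of FLAGS.eval_path without its extension).

import os
import pickle

import numpy as np

logdir = "/tmp/wavenet/eval"  # assumed value of FLAGS.logdir
data_name = "test"            # assumed basename of FLAGS.eval_path

encodings = np.load(os.path.join(logdir, data_name + "-encodings.npy"))
predictions = np.load(os.path.join(logdir, data_name + "-predictions.npy"))
with open(os.path.join(logdir, data_name + "-losses.pickle"), "rb") as f:
    losses = pickle.load(f)
print(losses)
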
Example #14
def main(unused_argv=None):
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.config is None:
        raise RuntimeError("No config name specified.")

    if FLAGS.vae:
        config = utils.get_module("wavenet." + FLAGS.config).VAEConfig(
            FLAGS.train_path,
            sample_length=FLAGS.sample_length,
            problem=FLAGS.problem,
            small=FLAGS.small,
            asymmetric=FLAGS.asymmetric,
            num_iters=FLAGS.num_iters,
            aux=FLAGS.aux_coefficient,
            dropout=FLAGS.input_dropout)
    else:
        config = utils.get_module("wavenet." + FLAGS.config).Config(
            FLAGS.train_path,
            sample_length=FLAGS.sample_length,
            problem=FLAGS.problem,
            small=FLAGS.small,
            asymmetric=FLAGS.asymmetric,
            num_iters=FLAGS.num_iters)

    logdir = FLAGS.logdir
    tf.logging.info("Saving to %s" % logdir)

    with tf.Graph().as_default():
        total_batch_size = FLAGS.total_batch_size
        assert total_batch_size % FLAGS.worker_replicas == 0
        worker_batch_size = total_batch_size // FLAGS.worker_replicas

        # Run the Reader on the CPU
        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        if FLAGS.ps_tasks:
            cpu_device = "/job:worker/cpu:0"

        with tf.device(cpu_device):
            inputs_dict = config.get_batch(worker_batch_size)

        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks,
                                               merge_devices=True)):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

            # pylint: disable=cell-var-from-loop
            lr = tf.constant(config.learning_rate_schedule[0])
            for key, value in config.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            # pylint: enable=cell-var-from-loop
            tf.summary.scalar("learning_rate", lr)

            # build the model graph
            outputs_dict = config.build(inputs_dict, is_training=True)

            if FLAGS.vae:
                if FLAGS.kl_annealing:
                    dist = tfp.distributions.Normal(
                        loc=FLAGS.annealing_loc, scale=FLAGS.annealing_scale)
                    annealing_rate = dist.cdf(tf.to_float(
                        global_step))  # how to adjust the annealing
                else:
                    annealing_rate = 0.
                kl = outputs_dict["loss"]["kl"]
                rec = outputs_dict["loss"]["rec"]
                aux = outputs_dict["loss"]["aux"]
                tf.summary.scalar("kl", kl)
                tf.summary.scalar("rec", rec)
                tf.summary.scalar("annealing_rate", annealing_rate)
                if FLAGS.kl_threshold is not None:
                    kl = tf.maximum(
                        tf.cast(FLAGS.kl_threshold, dtype=kl.dtype), kl)
                if FLAGS.aux_coefficient > 0:
                    tf.summary.scalar("aux", aux)
                loss = rec + annealing_rate * kl + tf.cast(
                    FLAGS.aux_coefficient, dtype=tf.float32) * aux
            else:
                loss = outputs_dict["loss"]

            tf.summary.scalar("train_loss", loss)

            worker_replicas = FLAGS.worker_replicas
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            opt = tf.train.SyncReplicasOptimizer(
                tf.train.AdamOptimizer(lr, epsilon=1e-8),
                worker_replicas,
                total_num_replicas=worker_replicas,
                variable_averages=ema,
                variables_to_average=tf.trainable_variables())

            train_op = slim.learning.create_train_op(
                total_loss=loss,
                optimizer=opt,
                global_step=global_step,
                colocate_gradients_with_ops=True)

            # train_op = opt.minimize(
            #     loss,
            #     global_step=global_step,
            #     name="train",
            #     colocate_gradients_with_ops=True)

            session_config = tf.ConfigProto(allow_soft_placement=True)

            is_chief = (FLAGS.task == 0)
            local_init_op = opt.chief_init_op if is_chief else opt.local_step_init_op

            slim.learning.train(
                train_op=train_op,
                logdir=logdir,
                is_chief=is_chief,
                master=FLAGS.master,
                number_of_steps=config.num_iters,
                global_step=global_step,
                log_every_n_steps=250,
                local_init_op=local_init_op,
                save_interval_secs=300,
                sync_optimizer=opt,
                session_config=session_config,
            )