Example #1
# Imports assumed from the surrounding module (TensorFlow 1.x / Magenta NSynth):
import tensorflow as tf
from magenta.models.nsynth import reader
from magenta.models.nsynth import utils

slim = tf.contrib.slim


def main(unused_argv):
  tf.logging.set_verbosity(FLAGS.log)

  if not tf.gfile.Exists(FLAGS.logdir):
    tf.gfile.MakeDirs(FLAGS.logdir)

  with tf.Graph().as_default():

    model = utils.get_module("baseline.models.%s" % FLAGS.model)
    hparams = model.get_hparams(FLAGS.config)

    # Run the Reader on the CPU
    cpu_device = ("/job:worker/cpu:0" if FLAGS.ps_tasks else
                  "/job:localhost/replica:0/task:0/cpu:0")

    with tf.device(cpu_device):
      with tf.name_scope("Reader"):
        batch = reader.NSynthDataset(
            FLAGS.train_path, is_training=True).get_baseline_batch(hparams)

    # If ps_tasks is 0, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks)):
      train_op = model.train_op(batch, hparams, FLAGS.config)

      # Run training
      slim.learning.train(
          train_op=train_op,
          logdir=FLAGS.logdir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=hparams.max_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
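
The training loop above reads a number of command-line flags (FLAGS.model, FLAGS.config, FLAGS.logdir, FLAGS.master, FLAGS.task, FLAGS.ps_tasks, FLAGS.train_path, FLAGS.log, FLAGS.save_summaries_secs, FLAGS.save_interval_secs) that are defined elsewhere in the module. A minimal sketch of how those definitions and the entry point typically look with the TF 1.x flags API is shown below; only the flag names come from the code above, while the defaults and help strings are assumptions.

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("master", "", "Address of the TensorFlow master to use.")
tf.app.flags.DEFINE_string("train_path", "", "Path to the training TFRecord data.")
tf.app.flags.DEFINE_string("logdir", "/tmp/nsynth", "Directory for checkpoints and summaries.")
tf.app.flags.DEFINE_string("model", "", "Which model under baseline.models to train.")
tf.app.flags.DEFINE_string("config", "", "Model configuration name.")
tf.app.flags.DEFINE_string("log", "INFO",
                           "Verbosity for tf.logging (DEBUG, INFO, WARN, ERROR, FATAL).")
tf.app.flags.DEFINE_integer("task", 0, "Task id of this replica; the chief has task 0.")
tf.app.flags.DEFINE_integer("ps_tasks", 0,
                            "Number of parameter server tasks (0 means local training).")
tf.app.flags.DEFINE_integer("save_summaries_secs", 15, "Seconds between summary saves.")
tf.app.flags.DEFINE_integer("save_interval_secs", 300, "Seconds between checkpoint saves.")


if __name__ == "__main__":
  tf.app.run()  # dispatches to main(unused_argv)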
Example #2
    def get_batch(self, batch_size):
        """Return one training batch of NSynth data for the WaveNet model.

        Parameters
        ----------
        batch_size : int
            Number of examples in the returned batch.

        Returns
        -------
        dict
            Batch tensors produced by ``NSynthDataset.get_wavenet_batch``,
            with each audio example limited to ``length=6144`` samples.
        """
        assert self.train_path is not None
        data_train = reader.NSynthDataset(self.train_path, is_training=True)
        return data_train.get_wavenet_batch(batch_size, length=6144)
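
get_wavenet_batch builds a TF 1.x queue-based input pipeline, so the tensors returned by get_batch have to be evaluated inside a session with queue runners running, much like Example #4 does for the baseline reader. Below is a small sketch of such a consumer; the Config wrapper class, the batch size, and the TFRecord path are illustrative assumptions, and the batch is assumed to be a dictionary of tensors.

import tensorflow as tf

from magenta.models.nsynth import reader  # assumed import, as in Example #1


class Config(object):
  """Hypothetical wrapper exposing the get_batch() method shown above."""

  def __init__(self, train_path):
    self.train_path = train_path

  def get_batch(self, batch_size):
    assert self.train_path is not None
    data_train = reader.NSynthDataset(self.train_path, is_training=True)
    return data_train.get_wavenet_batch(batch_size, length=6144)


with tf.Graph().as_default():
  batch = Config("/path/to/nsynth-train.tfrecord").get_batch(batch_size=8)

  with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    # The reader uses input queues, so the queue runners must be started
    # before the batch tensors can be evaluated.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      batch_np = sess.run(batch)  # NumPy values for one training batch
    finally:
      coord.request_stop()
      coord.join(threads)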
Example #3
    def get_batch(self, batch_size):
        assert self.train_path is not None
        data_train = reader.NSynthDataset(self.train_path, is_training=True)
        return data_train.get_wavenet_batch(batch_size, length=6144)
Example #4
# Assumes the same module-level imports as Example #1, plus `import sys`
# and a `save_arrays` helper defined elsewhere in the script.
def main(unused_argv):
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.checkpoint_path:
    checkpoint_path = FLAGS.checkpoint_path
  else:
    expdir = FLAGS.expdir
    tf.logging.info("Will load latest checkpoint from %s.", expdir)
    if not tf.gfile.Exists(expdir):
      tf.logging.fatal("\tExperiment save dir '%s' does not exist!", expdir)
      sys.exit(1)

    try:
      checkpoint_path = tf.train.latest_checkpoint(expdir)
    except tf.errors.NotFoundError:
      tf.logging.fatal("There was a problem determining the latest checkpoint.")
      sys.exit(1)

  if not tf.train.checkpoint_exists(checkpoint_path):
    tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
    sys.exit(1)

  savedir = FLAGS.savedir
  if not tf.gfile.Exists(savedir):
    tf.gfile.MakeDirs(savedir)

  # Make the graph
  with tf.Graph().as_default():
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
      model = utils.get_module("baseline.models.%s" % FLAGS.model)
      hparams = model.get_hparams(FLAGS.config)

      # Load the trained model with is_training=False
      with tf.name_scope("Reader"):
        batch = reader.NSynthDataset(
            FLAGS.tfrecord_path,
            is_training=False).get_baseline_batch(hparams)

      # The model's train_op() is expected to add the embedding tensor to the
      # "z" collection, which is read back below.
      _ = model.train_op(batch, hparams, FLAGS.config)
      z = tf.get_collection("z")[0]

      init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
      sess.run(init_op)

      # Add ops to save and restore all the variables.
      # Restore variables from disk.
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)
      tf.logging.info("Model restored.")

      # Start up some threads
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      i = 0
      z_val = []
      try:
        while True:
          if coord.should_stop():
            break
          res_val = sess.run([z])
          z_val.append(res_val[0])
          tf.logging.info("Iter: %d" % i)
          tf.logging.info("Z:{}".format(res_val[0].shape))
          i += 1
          # Periodically save intermediate results (every iteration here,
          # since the save interval is 1).
          if (i + 1) % 1 == 0:
            save_arrays(savedir, hparams, z_val)
      # Report all exceptions to the coordinator, pylint: disable=broad-except
      except Exception as e:
        coord.request_stop(e)
      # pylint: enable=broad-except
      finally:
        save_arrays(savedir, hparams, z_val)
        # Terminate as usual.  It is innocuous to request stop twice.
        coord.request_stop()
        coord.join(threads)
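
Example #4 relies on a save_arrays helper that is defined elsewhere in the same script. A minimal sketch of what such a helper could look like is given below, assuming the collected z batches are concatenated and written out as a single NumPy file; the output filename and the unused hparams argument are assumptions, not the script's actual behavior.

import os

import numpy as np
import tensorflow as tf


def save_arrays(savedir, hparams, z_val):
  """Concatenate the collected embedding batches and write them to disk."""
  del hparams  # unused in this sketch; could be folded into the filename
  if not z_val:
    return
  z_all = np.concatenate(z_val, axis=0)
  # Illustrative filename; the real script may use a different naming scheme.
  save_path = os.path.join(savedir, "z_embeddings.npy")
  with tf.gfile.Open(save_path, "wb") as f:
    np.save(f, z_all)
  tf.logging.info("Saved %d embeddings to %s", z_all.shape[0], save_path)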