def main(unused_argv):
  tf.logging.set_verbosity(FLAGS.log)

  if not tf.gfile.Exists(FLAGS.logdir):
    tf.gfile.MakeDirs(FLAGS.logdir)

  with tf.Graph().as_default():
    # If ps_tasks is 0, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    model = utils.get_module("baseline.models.%s" % FLAGS.model)
    hparams = model.get_hparams(FLAGS.config)

    # Run the Reader on the CPU
    cpu_device = ("/job:worker/cpu:0" if FLAGS.ps_tasks else
                  "/job:localhost/replica:0/task:0/cpu:0")

    with tf.device(cpu_device):
      with tf.name_scope("Reader"):
        batch = reader.NSynthDataset(
            FLAGS.train_path, is_training=True).get_baseline_batch(hparams)

    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks)):
      train_op = model.train_op(batch, hparams, FLAGS.config)

      # Run training
      slim.learning.train(
          train_op=train_op,
          logdir=FLAGS.logdir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=hparams.max_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
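# A minimal sketch of how this training entry point is typically wired up,
# assuming the usual tf.app.flags / tf.app.run() pattern. The flag names
# mirror the FLAGS referenced in main() above; the default values and help
# strings here are placeholders, not the repository's actual definitions.
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("master", "", "TensorFlow master to use.")
tf.app.flags.DEFINE_string("train_path", "", "Path to the training TFRecord.")
tf.app.flags.DEFINE_string("logdir", "/tmp/baseline/train",
                           "Directory for checkpoints and summaries.")
tf.app.flags.DEFINE_string("model", "", "Which model module to load.")
tf.app.flags.DEFINE_string("config", "", "Which model configuration to use.")
tf.app.flags.DEFINE_integer("ps_tasks", 0,
                            "Number of parameter servers. If 0, variables "
                            "are placed on the local device.")
tf.app.flags.DEFINE_integer("task", 0, "Task id of the training replica.")
tf.app.flags.DEFINE_integer("save_summaries_secs", 15,
                            "Seconds between summary saves.")
tf.app.flags.DEFINE_integer("save_interval_secs", 15,
                            "Seconds between checkpoint saves.")
tf.app.flags.DEFINE_string("log", "INFO",
                           "Threshold for tf.logging output.")

if __name__ == "__main__":
  tf.app.run()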
def get_batch(self, batch_size): """Summary Parameters ---------- batch_size : TYPE Description Returns ------- TYPE Description """ assert self.train_path is not None data_train = reader.NSynthDataset(self.train_path, is_training=True) return data_train.get_wavenet_batch(batch_size, length=6144)
def main(unused_argv):
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.checkpoint_path:
    checkpoint_path = FLAGS.checkpoint_path
  else:
    expdir = FLAGS.expdir
    tf.logging.info("Will load latest checkpoint from %s.", expdir)
    if not tf.gfile.Exists(expdir):
      tf.logging.fatal("\tExperiment save dir '%s' does not exist!", expdir)
      sys.exit(1)
    try:
      checkpoint_path = tf.train.latest_checkpoint(expdir)
    except tf.errors.NotFoundError:
      tf.logging.fatal("There was a problem determining the latest checkpoint.")
      sys.exit(1)

  if not tf.train.checkpoint_exists(checkpoint_path):
    tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
    sys.exit(1)

  savedir = FLAGS.savedir
  if not tf.gfile.Exists(savedir):
    tf.gfile.MakeDirs(savedir)

  # Make the graph
  with tf.Graph().as_default():
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
      model = utils.get_module("baseline.models.%s" % FLAGS.model)
      hparams = model.get_hparams(FLAGS.config)

      # Load the trained model with is_training=False
      with tf.name_scope("Reader"):
        batch = reader.NSynthDataset(
            FLAGS.tfrecord_path, is_training=False).get_baseline_batch(hparams)

      _ = model.train_op(batch, hparams, FLAGS.config)
      z = tf.get_collection("z")[0]

      init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
      sess.run(init_op)

      # Add ops to save and restore all the variables.
      # Restore variables from disk.
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)
      tf.logging.info("Model restored.")

      # Start up some threads
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      i = 0
      z_val = []
      try:
        while True:
          if coord.should_stop():
            break
          res_val = sess.run([z])
          z_val.append(res_val[0])
          tf.logging.info("Iter: %d" % i)
          tf.logging.info("Z:{}".format(res_val[0].shape))
          i += 1
          if (i + 1) % 1 == 0:
            save_arrays(savedir, hparams, z_val)
      # Report all exceptions to the coordinator, pylint: disable=broad-except
      except Exception as e:
        coord.request_stop(e)
      # pylint: enable=broad-except
      finally:
        save_arrays(savedir, hparams, z_val)
        # Terminate as usual. It is innocuous to request stop twice.
        coord.request_stop()
        coord.join(threads)
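# The loop above periodically calls save_arrays, which is defined elsewhere in
# this module and is not shown here. Below is a minimal sketch of what such a
# helper could look like, assuming the accumulated per-batch embeddings are
# concatenated along the batch axis and written to a single .npy file under
# savedir. The filename scheme and the unused hparams argument are assumptions
# for illustration, not the repository's actual implementation.
import os

import numpy as np
import tensorflow as tf


def save_arrays(savedir, hparams, z_val):
  """Concatenate the accumulated embeddings and write them to disk."""
  del hparams  # The real helper likely encodes hparams into the filename.
  z_concat = np.concatenate(z_val, axis=0)
  save_name = os.path.join(savedir, "z_embeddings.npy")  # hypothetical name
  with tf.gfile.Open(save_name, "wb") as f:
    np.save(f, z_concat)
  tf.logging.info("Saved %d embeddings to %s", z_concat.shape[0], save_name)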