def save_checkpoint(self):
    """Write a freshly initialized model checkpoint to a new temp directory.

    Creates a temporary directory, dumps default hyperparameters to a
    'config' file inside it, builds the training graph, initializes all
    global variables, and saves them to 'model.ckpt' in that directory.

    Returns:
      Path of the temporary directory holding the config and checkpoint.
      The directory is not removed here; cleanup is left to the caller.
    """
    logdir = tempfile.mkdtemp()
    save_path = os.path.join(logdir, 'model.ckpt')

    hparams = lib_hparams.Hyperparameters()
    tf.gfile.MakeDirs(logdir)
    config_fpath = os.path.join(logdir, 'config')
    with tf.gfile.Open(config_fpath, 'w') as config_file:
        hparams.dump(config_file)

    with tf.Graph().as_default():
        lib_graph.build_graph(is_training=True, hparams=hparams)
        # Use the session as a context manager so it is closed once the
        # checkpoint has been written, even if saving raises.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.save(sess, save_path)

    return logdir
def predict(self, pianorolls, masks):
    """Builds the non-training graph on the given inputs and returns logits.

    The graph is fed `pianorolls` and `masks` directly (no placeholders),
    together with a one-element float `lengths` tensor taken from dimension 1
    of the shape of `pianorolls`. The resulting logits are also stored on
    `self.logits`.

    Args:
      pianorolls: Input passed to the graph as 'pianorolls'.
      masks: Input passed to the graph as 'masks'.

    Returns:
      The `logits` attribute of the built model.
    """
    lengths = tf.to_float([tf.shape(pianorolls)[1]])
    graph_inputs = {
        'pianorolls': pianorolls,
        'masks': masks,
        'lengths': lengths,
    }
    built_model = lib_graph.build_graph(
        is_training=False,
        hparams=self.hparams,
        direct_inputs=graph_inputs,
        use_placeholders=False)
    self.logits = built_model.logits
    return self.logits
def predict(self, pianorolls, masks):
    """Evaluates the model once and returns its predictions (logits).

    Constructs the graph in non-training mode with `pianorolls`, `masks` and
    a derived `lengths` tensor wired in as direct inputs rather than
    placeholders, then caches the model's logits on `self.logits`.

    Args:
      pianorolls: Tensor supplied to the graph under the 'pianorolls' key.
      masks: Tensor supplied to the graph under the 'masks' key.

    Returns:
      The logits of the built model (the same object as `self.logits`).
    """
    # `lengths` is a single float: the size of dimension 1 of `pianorolls`.
    dim_one_size = tf.shape(pianorolls)[1]
    inputs = dict(
        pianorolls=pianorolls,
        masks=masks,
        lengths=tf.to_float([dim_one_size]))

    self.logits = lib_graph.build_graph(
        is_training=False,
        hparams=self.hparams,
        direct_inputs=inputs,
        use_placeholders=False).logits
    return self.logits
def main(unused_argv):
    """Builds the graph, then alternates training and validation epochs.

    Loads the 'train' and 'valid' datasets from `FLAGS.data_dir`, writes the
    hyperparameter config into the run's log directory, builds a training
    graph and a validation graph that reuses its variables, and loops over
    epochs until `FLAGS.num_epochs` is reached (a falsy value means no epoch
    limit), the supervisor asks to stop, or the tracker asks to stop.

    Args:
      unused_argv: Ignored.

    Returns:
      `tracker.best` of the 'validation loss' tracker.
    """
    print('TensorFlow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.data_dir is None:
        tf.logging.fatal('No input directory was provided.')

    print(FLAGS.maskout_method, 'separate', FLAGS.separate_instruments)

    hparams = _hparams_from_flags()

    # Load both dataset splits.
    print('dataset:', FLAGS.dataset, FLAGS.data_dir)
    print('current dir:', os.path.curdir)
    train_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'train')
    valid_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'valid')
    print('# of train_data:', train_data.num_examples)
    print('# of valid_data:', valid_data.num_examples)

    # The batch size is capped at the number of training examples.
    if hparams.batch_size > train_data.num_examples:
        print('reducing batch_size to %i' % train_data.num_examples)
        hparams.batch_size = train_data.num_examples

    train_data.update_hparams(hparams)

    # Persist the hyperparameter config next to the run's logs.
    logdir = os.path.join(FLAGS.logdir, hparams.log_subdir_str)
    tf.gfile.MakeDirs(logdir)
    config_fpath = os.path.join(logdir, 'config')
    tf.logging.info('Writing to %s', config_fpath)
    with tf.gfile.Open(config_fpath, 'w') as config_file:
        hparams.dump(config_file)

    with tf.Graph().as_default():
        no_op = tf.no_op()

        # Training graph first; the validation graph then reuses its variables.
        train_model = lib_graph.build_graph(is_training=True, hparams=hparams)
        tf.get_variable_scope().reuse_variables()
        valid_model = lib_graph.build_graph(is_training=False, hparams=hparams)

        best_model_path = os.path.join(
            FLAGS.logdir, hparams.log_subdir_str, 'best_model.ckpt')
        tracker = Tracker(
            label='validation loss',
            patience=FLAGS.patience,
            decay_op=train_model.decay_op,
            save_path=best_model_path)

        if FLAGS.log_progress:
            saver = tf.train.Supervisor.USE_DEFAULT
        else:
            saver = None

        # Graph will be finalized after instantiating supervisor.
        supervisor = tf.train.Supervisor(
            logdir=logdir,
            saver=saver,
            summary_op=None,
            save_model_secs=FLAGS.save_model_secs)

        with supervisor.PrepareSession() as sess:
            epoch = 0
            while epoch < FLAGS.num_epochs or not FLAGS.num_epochs:
                if supervisor.should_stop():
                    break

                # Training pass.
                run_epoch(supervisor, sess, train_model, train_data, hparams,
                          train_model.train_op, 'train', epoch)

                # Validation pass, only every `eval_freq` epochs.
                if epoch % hparams.eval_freq == 0:
                    estimate_popstats(
                        supervisor, sess, train_model, train_data, hparams)
                    valid_loss = run_epoch(
                        supervisor, sess, valid_model, valid_data, hparams,
                        no_op, 'valid', epoch)
                    tracker(valid_loss, sess)
                    if tracker.should_stop():
                        break

                epoch += 1

        print('best', tracker.label, tracker.best)
        print('Done.')
        return tracker.best