def get_fold_pianorolls(fold, hparams):
    """Fetch the pianorolls of one fold, optionally narrowed to one entry.

    Args:
        fold: Fold identifier forwarded to `lib_data.get_dataset`.
        hparams: Hyperparameters forwarded to `lib_data.get_dataset`;
            `hparams.dataset` is also used in the log message.

    Returns:
        Whatever `get_pianorolls()` yields for the fold, or, when
        `FLAGS.fold_index` is set, a one-element list holding only the
        entry at that index.
    """
    fold_rolls = lib_data.get_dataset(
        FLAGS.data_dir, hparams, fold).get_pianorolls()
    tf.logging.info('Retrieving pianorolls from %s set of %s dataset.',
                    fold, hparams.dataset)
    print_statistics(fold_rolls)

    selected = FLAGS.fold_index
    if selected is None:
        return fold_rolls
    # The flag value is coerced with int() before being used as an index.
    return [fold_rolls[int(selected)]]
def _run(self, pianorolls, masks):
    """Replace the input with cropped pieces from the validation fold.

    Args:
        pianorolls: Array whose `shape[0]` and `shape[1]` set how many
            validation pieces are kept and how far along the first axis
            each piece is cropped.
        masks: Mask array; every entry must be truthy.

    Returns:
        Array built from the validation pianorolls, each sliced to
        `pianorolls.shape[1]`, keeping at most `pianorolls.shape[0]` of them.

    Raises:
        NotImplementedError: If any entry of `masks` is falsy.
    """
    # Only the everything-masked case is supported.
    if not np.all(masks):
        raise NotImplementedError()

    model_hparams = self.wmodel.hparams
    print("Loading validation pieces from %s..." % model_hparams.dataset)
    valid_pieces = lib_data.get_dataset(
        self.data_dir, model_hparams, "valid").get_pianorolls()

    target_shape = pianorolls.shape
    cropped = []
    for piece in valid_pieces:
        cropped.append(piece[:target_shape[1]])
    pianorolls = np.array(cropped)[:target_shape[0]]

    # The selected pieces are logged both as the input and as the prediction.
    self.logger.log(pianorolls=pianorolls, masks=masks, predictions=pianorolls)
    return pianorolls
def main(unused_argv):
    """Builds the graph and then runs training and validation.

    Reads its configuration from `FLAGS`, loads the 'train' and 'valid'
    folds, writes the hyperparameter config under the log directory, and
    then alternates training epochs with periodic validation until a stop
    condition is met.

    Args:
        unused_argv: Ignored.

    Returns:
        `tracker.best` of the `Tracker` labelled 'validation loss'.
    """
    print('TensorFlow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.data_dir is None:
        # NOTE(review): there is no early exit after this call; if
        # tf.logging.fatal only logs, execution continues with
        # data_dir=None -- confirm whether an abort is intended here.
        tf.logging.fatal('No input directory was provided.')

    print(FLAGS.maskout_method, 'separate', FLAGS.separate_instruments)

    hparams = _hparams_from_flags()

    # Get data.
    print('dataset:', FLAGS.dataset, FLAGS.data_dir)
    # os.path.curdir is just the constant '.', so report the real working
    # directory instead.
    print('current dir:', os.getcwd())
    train_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'train')
    valid_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'valid')
    print('# of train_data:', train_data.num_examples)
    print('# of valid_data:', valid_data.num_examples)
    # A batch cannot be larger than the training set.
    if train_data.num_examples < hparams.batch_size:
        print('reducing batch_size to %i' % train_data.num_examples)
        hparams.batch_size = train_data.num_examples

    train_data.update_hparams(hparams)

    # Save hparam configs.
    logdir = os.path.join(FLAGS.logdir, hparams.log_subdir_str)
    tf.gfile.MakeDirs(logdir)
    config_fpath = os.path.join(logdir, 'config')
    tf.logging.info('Writing to %s', config_fpath)
    with tf.gfile.Open(config_fpath, 'w') as p:
        hparams.dump(p)

    # Build the graph and subsequently running it for train and validation.
    with tf.Graph().as_default():
        no_op = tf.no_op()

        # Build placeholders and training graph, and validation graph with
        # reuse.
        m = lib_graph.build_graph(is_training=True, hparams=hparams)
        tf.get_variable_scope().reuse_variables()
        mvalid = lib_graph.build_graph(is_training=False, hparams=hparams)

        # The best checkpoint lives next to the config in the same log
        # directory, so reuse the path computed above.
        tracker = Tracker(
            label='validation loss',
            patience=FLAGS.patience,
            decay_op=m.decay_op,
            save_path=os.path.join(logdir, 'best_model.ckpt'))

        # Graph will be finalized after instantiating supervisor.
        sv = tf.train.Supervisor(
            logdir=logdir,
            saver=(tf.train.Supervisor.USE_DEFAULT
                   if FLAGS.log_progress else None),
            summary_op=None,
            save_model_secs=FLAGS.save_model_secs)
        with sv.PrepareSession() as sess:
            epoch_count = 0
            # A falsy num_epochs (e.g. 0 or None) means no epoch limit: the
            # loop then ends only via the supervisor or the tracker.
            while epoch_count < FLAGS.num_epochs or not FLAGS.num_epochs:
                if sv.should_stop():
                    break

                # Run training.
                run_epoch(sv, sess, m, train_data, hparams, m.train_op,
                          'train', epoch_count)

                # Run validation.
                if epoch_count % hparams.eval_freq == 0:
                    estimate_popstats(sv, sess, m, train_data, hparams)
                    loss = run_epoch(sv, sess, mvalid, valid_data, hparams,
                                     no_op, 'valid', epoch_count)
                    tracker(loss, sess)
                    if tracker.should_stop():
                        break

                epoch_count += 1

        print('best', tracker.label, tracker.best)
        print('Done.')
        return tracker.best