Example #1
def get_fold_pianorolls(fold, hparams):
  dataset = lib_data.get_dataset(FLAGS.data_dir, hparams, fold)
  pianorolls = dataset.get_pianorolls()
  tf.logging.info('Retrieving pianorolls from %s set of %s dataset.',
                  fold, hparams.dataset)
  print_statistics(pianorolls)
  if FLAGS.fold_index is not None:
    pianorolls = [pianorolls[int(FLAGS.fold_index)]]
  return pianorolls
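This helper leans on module-level pieces (`FLAGS`, `lib_data`, `print_statistics`) defined elsewhere in the script. As a rough, self-contained illustration of the kind of summary a statistics helper could print over a list of pianorolls, here is a sketch with dummy data; the (time, pitch, instrument) array shape and the reported numbers are assumptions for illustration, not the behavior of the real `print_statistics`.

import numpy as np

def print_statistics(pianorolls):
  # Illustrative only: report how many pieces there are and how long they run.
  lengths = [len(roll) for roll in pianorolls]
  print('%d pianorolls, length min/mean/max: %d/%.1f/%d'
        % (len(pianorolls), min(lengths), np.mean(lengths), max(lengths)))

# Dummy data: three pieces stored as binary (time, pitch, instrument) arrays.
dummy = [np.zeros((t, 46, 4), dtype=np.float32) for t in (32, 64, 128)]
print_statistics(dummy)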
Example #2
  def _run(self, pianorolls, masks):
    if not np.all(masks):
      raise NotImplementedError()
    print("Loading validation pieces from %s..." % self.wmodel.hparams.dataset)
    dataset = lib_data.get_dataset(self.data_dir, self.wmodel.hparams, "valid")
    bach_pianorolls = dataset.get_pianorolls()
    shape = pianorolls.shape
    pianorolls = np.array(
        [pianoroll[:shape[1]] for pianoroll in bach_pianorolls])[:shape[0]]
    self.logger.log(pianorolls=pianorolls, masks=masks, predictions=pianorolls)
    return pianorolls
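The key line above crops every Bach validation piece down to the requested number of timesteps and keeps only as many pieces as the incoming batch. A self-contained sketch of that slicing with dummy NumPy arrays (all shapes here are illustrative assumptions):

import numpy as np

# Pretend the sampler asked for a batch of 2 pieces, 16 timesteps each.
requested = np.zeros((2, 16, 46, 4))
shape = requested.shape

# Pretend the validation set holds longer pieces of varying length.
bach_pianorolls = [np.ones((t, 46, 4)) for t in (64, 48, 96)]

# Crop every piece to shape[1] timesteps, then keep the first shape[0] pieces.
pianorolls = np.array([roll[:shape[1]] for roll in bach_pianorolls])[:shape[0]]
print(pianorolls.shape)  # (2, 16, 46, 4)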
Example #3
def main(unused_argv):
  """Builds the graph and then runs training and validation."""
  print('TensorFlow version:', tf.__version__)

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.data_dir is None:
    tf.logging.fatal('No input directory was provided.')

  print(FLAGS.maskout_method, 'separate', FLAGS.separate_instruments)

  hparams = _hparams_from_flags()

  # Get data.
  print('dataset:', FLAGS.dataset, FLAGS.data_dir)
  print('current dir:', os.path.curdir)
  train_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'train')
  valid_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'valid')
  print('# of train_data:', train_data.num_examples)
  print('# of valid_data:', valid_data.num_examples)
  if train_data.num_examples < hparams.batch_size:
    print('reducing batch_size to %i' % train_data.num_examples)
    hparams.batch_size = train_data.num_examples

  train_data.update_hparams(hparams)

  # Save hparam configs.
  logdir = os.path.join(FLAGS.logdir, hparams.log_subdir_str)
  tf.gfile.MakeDirs(logdir)
  config_fpath = os.path.join(logdir, 'config')
  tf.logging.info('Writing to %s', config_fpath)
  with tf.gfile.Open(config_fpath, 'w') as p:
    hparams.dump(p)

  # Build the graph, then run it for training and validation.
  with tf.Graph().as_default():
    no_op = tf.no_op()

    # Build placeholders and the training graph, then reuse its variables for
    # the validation graph.
    m = lib_graph.build_graph(is_training=True, hparams=hparams)
    tf.get_variable_scope().reuse_variables()
    mvalid = lib_graph.build_graph(is_training=False, hparams=hparams)

    tracker = Tracker(
        label='validation loss',
        patience=FLAGS.patience,
        decay_op=m.decay_op,
        save_path=os.path.join(FLAGS.logdir, hparams.log_subdir_str,
                               'best_model.ckpt'))

    # Graph will be finalized after instantiating the supervisor.
    sv = tf.train.Supervisor(
        logdir=logdir,
        saver=tf.train.Supervisor.USE_DEFAULT if FLAGS.log_progress else None,
        summary_op=None,
        save_model_secs=FLAGS.save_model_secs)
    with sv.PrepareSession() as sess:
      epoch_count = 0
      while epoch_count < FLAGS.num_epochs or not FLAGS.num_epochs:
        if sv.should_stop():
          break

        # Run training.
        run_epoch(sv, sess, m, train_data, hparams, m.train_op, 'train',
                  epoch_count)

        # Run validation.
        if epoch_count % hparams.eval_freq == 0:
          estimate_popstats(sv, sess, m, train_data, hparams)
          loss = run_epoch(sv, sess, mvalid, valid_data, hparams, no_op,
                           'valid', epoch_count)
          tracker(loss, sess)
          if tracker.should_stop():
            break

        epoch_count += 1

    print('best', tracker.label, tracker.best)
    print('Done.')
    return tracker.best
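The `Tracker` instantiated above is defined elsewhere in the training script; as used here it watches the validation loss, saves the best checkpoint, and eventually signals early stopping. A simplified, self-contained sketch of the patience logic it appears to implement (the class name and details are assumptions for illustration, not the actual implementation):

class SimpleTracker(object):
  """Illustrative early stopper: halt after `patience` epochs without improvement."""

  def __init__(self, patience):
    self.patience = patience
    self.best = float('inf')
    self.bad_epochs = 0

  def __call__(self, loss):
    if loss < self.best:
      self.best = loss
      self.bad_epochs = 0   # improvement: reset the patience counter
    else:
      self.bad_epochs += 1  # no improvement this epoch

  def should_stop(self):
    return self.bad_epochs >= self.patience

tracker = SimpleTracker(patience=3)
for loss in [1.0, 0.8, 0.85, 0.9, 0.95, 0.99]:
  tracker(loss)
  if tracker.should_stop():
    break
print('best loss:', tracker.best)  # 0.8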