  def save_checkpoint(self):
    """Builds a fresh graph, initializes its variables, and saves a checkpoint."""
    logdir = tempfile.mkdtemp()
    save_path = os.path.join(logdir, 'model.ckpt')

    hparams = lib_hparams.Hyperparameters(**{})  # default hyperparameters, no overrides

    tf.gfile.MakeDirs(logdir)
    config_fpath = os.path.join(logdir, 'config')
    with tf.gfile.Open(config_fpath, 'w') as p:
      hparams.dump(p)

    with tf.Graph().as_default():
      lib_graph.build_graph(is_training=True, hparams=hparams)
      sess = tf.Session()
      sess.run(tf.global_variables_initializer())

      saver = tf.train.Saver()
      saver.save(sess, save_path)

    return logdir
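A minimal sketch of how the returned logdir might be used (this restore pattern is an assumption for illustration, not from the source); it rebuilds the same graph so variable names line up with the checkpoint:

logdir = self.save_checkpoint()
save_path = os.path.join(logdir, 'model.ckpt')

with tf.Graph().as_default():
  # Rebuild the graph with the same hparams so variables match the checkpoint.
  hparams = lib_hparams.Hyperparameters(**{})
  lib_graph.build_graph(is_training=True, hparams=hparams)
  with tf.Session() as sess:
    tf.train.Saver().restore(sess, save_path)  # load the saved weights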
Example #3
  def predict(self, pianorolls, masks):
    """Evalutes the model once and returns predictions."""
    direct_inputs = dict(
        pianorolls=pianorolls, masks=masks,
        lengths=tf.to_float([tf.shape(pianorolls)[1]]))

    model = lib_graph.build_graph(
        is_training=False,
        hparams=self.hparams,
        direct_inputs=direct_inputs,
        use_placeholders=False)
    self.logits = model.logits
    return self.logits
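A hypothetical call, for illustration only (the wrapper object and the tensor shapes below are assumptions; masks share the pianorolls' shape):

with tf.Graph().as_default():
  # Hypothetical shapes: batch of 1, 32 timesteps, 53 pitches, 4 instruments.
  pianorolls = tf.zeros([1, 32, 53, 4])
  masks = tf.ones_like(pianorolls)  # all-ones mask, purely illustrative
  logits = wrapper.predict(pianorolls, masks)  # wrapper holds self.hparams
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    logits_val = sess.run(logits)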
def main(unused_argv):
  """Builds the graph and then runs training and validation."""
  print('TensorFlow version:', tf.__version__)

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.data_dir is None:
    tf.logging.fatal('No input directory was provided.')

  print(FLAGS.maskout_method, 'separate', FLAGS.separate_instruments)

  hparams = _hparams_from_flags()

  # Get data.
  print('dataset:', FLAGS.dataset, FLAGS.data_dir)
  print('current dir:', os.path.curdir)
  train_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'train')
  valid_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'valid')
  print('# of train_data:', train_data.num_examples)
  print('# of valid_data:', valid_data.num_examples)
  if train_data.num_examples < hparams.batch_size:
    print('reducing batch_size to %i' % train_data.num_examples)
    hparams.batch_size = train_data.num_examples

  train_data.update_hparams(hparams)

  # Save hparam configs.
  logdir = os.path.join(FLAGS.logdir, hparams.log_subdir_str)
  tf.gfile.MakeDirs(logdir)
  config_fpath = os.path.join(logdir, 'config')
  tf.logging.info('Writing to %s', config_fpath)
  with tf.gfile.Open(config_fpath, 'w') as p:
    hparams.dump(p)

  # Build the graph, then run it for training and validation.
  with tf.Graph().as_default():
    no_op = tf.no_op()

    # Build placeholders and the training graph, then reuse variables for the validation graph.
    m = lib_graph.build_graph(is_training=True, hparams=hparams)
    tf.get_variable_scope().reuse_variables()
    mvalid = lib_graph.build_graph(is_training=False, hparams=hparams)

    tracker = Tracker(
        label='validation loss',
        patience=FLAGS.patience,
        decay_op=m.decay_op,
        save_path=os.path.join(logdir, 'best_model.ckpt'))

    # The graph will be finalized after the supervisor is instantiated.
    sv = tf.train.Supervisor(
        logdir=logdir,
        saver=tf.train.Supervisor.USE_DEFAULT if FLAGS.log_progress else None,
        summary_op=None,
        save_model_secs=FLAGS.save_model_secs)
    with sv.PrepareSession() as sess:
      epoch_count = 0
      # num_epochs == 0 means train indefinitely (until early stopping).
      while epoch_count < FLAGS.num_epochs or not FLAGS.num_epochs:
        if sv.should_stop():
          break

        # Run training.
        run_epoch(sv, sess, m, train_data, hparams, m.train_op, 'train',
                  epoch_count)

        # Run validation.
        if epoch_count % hparams.eval_freq == 0:
          estimate_popstats(sv, sess, m, train_data, hparams)
          loss = run_epoch(sv, sess, mvalid, valid_data, hparams, no_op,
                           'valid', epoch_count)
          tracker(loss, sess)
          if tracker.should_stop():
            break

        epoch_count += 1

    print('best', tracker.label, tracker.best)
    print('Done.')
    return tracker.best
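For context, here is a minimal sketch of the Tracker interface used above; this is an assumed implementation, not the actual class. It checkpoints the best model, runs decay_op when validation loss fails to improve, and requests a stop once patience is exhausted:

class Tracker(object):
  """Illustrative early-stopping tracker; the real implementation may differ."""

  def __init__(self, label, patience, decay_op, save_path):
    self.label = label
    self.patience = patience
    self.decay_op = decay_op
    self.save_path = save_path
    self.best = float('inf')
    self.bad_epochs = 0
    self.saver = tf.train.Saver()

  def __call__(self, loss, sess):
    if loss < self.best:
      self.best, self.bad_epochs = loss, 0
      self.saver.save(sess, self.save_path)  # checkpoint the best model so far
    else:
      self.bad_epochs += 1
      sess.run(self.decay_op)  # e.g. decay the learning rate on a plateau

  def should_stop(self):
    return self.bad_epochs > self.patience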