def main(unused_argv):
    """Train or evaluate a Pianoroll RNN-NADE model, as selected by flags.

    Flags read from the module-level ``FLAGS``:

    * ``--log``: TF logging verbosity.
    * ``--run_dir`` (required): root directory. The train directory is
      ``<run_dir>/train``; in eval mode ``<run_dir>/eval`` is also used.
    * ``--sequence_example_file`` (required): glob pattern (``~`` expanded)
      for the input SequenceExample files.
    * ``--config`` / ``--hparams``: key into
      ``pianoroll_rnn_nade_model.default_configs`` and hparam overrides.
    * ``--eval``: evaluate instead of train.
    * ``--num_eval_examples``: number of examples to evaluate; when unset
      or zero, the records in the input files are counted instead.
    * ``--num_training_steps``, ``--summary_frequency``,
      ``--num_checkpoints``: forwarded to ``run_training``.

    Missing required flags, a glob that matches no files, or (in eval
    mode) too few examples for one full batch are reported via
    ``tf.logging.fatal`` and the function returns without running.

    Args:
        unused_argv: Leftover command-line arguments; ignored.
    """
    tf.logging.set_verbosity(FLAGS.log)

    if not FLAGS.run_dir:
        tf.logging.fatal('--run_dir required')
        return
    if not FLAGS.sequence_example_file:
        tf.logging.fatal('--sequence_example_file required')
        return

    sequence_example_file_paths = tf.gfile.Glob(
        os.path.expanduser(FLAGS.sequence_example_file))
    # Fail fast on a pattern that matches nothing; otherwise the problem
    # only shows up later as an empty input pipeline or a zero record count.
    if not sequence_example_file_paths:
        tf.logging.fatal('No files match --sequence_example_file: %s',
                         FLAGS.sequence_example_file)
        return
    run_dir = os.path.expanduser(FLAGS.run_dir)

    # NOTE(review): the config object comes straight out of default_configs,
    # so parse() presumably applies the overrides to that shared instance.
    config = pianoroll_rnn_nade_model.default_configs[FLAGS.config]
    config.hparams.parse(FLAGS.hparams)

    mode = 'eval' if FLAGS.eval else 'train'
    build_graph_fn = pianoroll_rnn_nade_graph.get_build_graph_fn(
        mode, config, sequence_example_file_paths)

    # The train dir is needed in both modes: eval reads checkpoints from it.
    train_dir = os.path.join(run_dir, 'train')
    tf.gfile.MakeDirs(train_dir)
    tf.logging.info('Train dir: %s', train_dir)

    if FLAGS.eval:
        eval_dir = os.path.join(run_dir, 'eval')
        tf.gfile.MakeDirs(eval_dir)
        tf.logging.info('Eval dir: %s', eval_dir)
        num_examples = (
            FLAGS.num_eval_examples or
            magenta.common.count_records(sequence_example_file_paths))
        # Integer division drops any trailing partial batch.
        num_batches = num_examples // config.hparams.batch_size
        if num_batches < 1:
            tf.logging.fatal(
                'Too few eval examples (%d) for one batch of size %d',
                num_examples, config.hparams.batch_size)
            return
        events_rnn_train.run_eval(build_graph_fn, train_dir, eval_dir,
                                  num_batches)
    else:
        events_rnn_train.run_training(
            build_graph_fn,
            train_dir,
            FLAGS.num_training_steps,
            FLAGS.summary_frequency,
            checkpoints_to_keep=FLAGS.num_checkpoints)
def main(unused_argv):
  """Train or evaluate a Pianoroll RNN-NADE model, as selected by flags.

  Flags read from the module-level ``FLAGS``:

  * ``--log``: TF logging verbosity.
  * ``--run_dir`` (required): root directory. The train directory is
    ``<run_dir>/train``; in eval mode ``<run_dir>/eval`` is also used.
  * ``--sequence_example_file`` (required): glob pattern (``~`` expanded)
    for the input SequenceExample files.
  * ``--config`` / ``--hparams``: key into
    ``pianoroll_rnn_nade_model.default_configs`` and hparam overrides.
  * ``--eval``: evaluate instead of train.
  * ``--num_eval_examples``: number of examples to evaluate; when unset or
    zero, the records in the input files are counted instead.
  * ``--num_training_steps``, ``--summary_frequency``,
    ``--num_checkpoints``: forwarded to ``run_training``.

  Missing required flags, a glob that matches no files, or (in eval mode)
  too few examples for one full batch are reported via ``tf.logging.fatal``
  and the function returns without running.

  Args:
    unused_argv: Leftover command-line arguments; ignored.
  """
  tf.logging.set_verbosity(FLAGS.log)

  if not FLAGS.run_dir:
    tf.logging.fatal('--run_dir required')
    return
  if not FLAGS.sequence_example_file:
    tf.logging.fatal('--sequence_example_file required')
    return

  sequence_example_file_paths = tf.gfile.Glob(
      os.path.expanduser(FLAGS.sequence_example_file))
  # Fail fast on a pattern that matches nothing; otherwise the problem only
  # shows up later as an empty input pipeline or a zero record count.
  if not sequence_example_file_paths:
    tf.logging.fatal('No files match --sequence_example_file: %s',
                     FLAGS.sequence_example_file)
    return
  run_dir = os.path.expanduser(FLAGS.run_dir)

  # NOTE(review): the config object comes straight out of default_configs,
  # so parse() presumably applies the overrides to that shared instance.
  config = pianoroll_rnn_nade_model.default_configs[FLAGS.config]
  config.hparams.parse(FLAGS.hparams)

  mode = 'eval' if FLAGS.eval else 'train'
  build_graph_fn = pianoroll_rnn_nade_graph.get_build_graph_fn(
      mode, config, sequence_example_file_paths)

  # The train dir is needed in both modes: eval reads checkpoints from it.
  train_dir = os.path.join(run_dir, 'train')
  tf.gfile.MakeDirs(train_dir)
  tf.logging.info('Train dir: %s', train_dir)

  if FLAGS.eval:
    eval_dir = os.path.join(run_dir, 'eval')
    tf.gfile.MakeDirs(eval_dir)
    tf.logging.info('Eval dir: %s', eval_dir)
    num_examples = (
        FLAGS.num_eval_examples or
        magenta.common.count_records(sequence_example_file_paths))
    # Integer division drops any trailing partial batch.
    num_batches = num_examples // config.hparams.batch_size
    if num_batches < 1:
      tf.logging.fatal('Too few eval examples (%d) for one batch of size %d',
                       num_examples, config.hparams.batch_size)
      return
    events_rnn_train.run_eval(build_graph_fn, train_dir, eval_dir, num_batches)
  else:
    events_rnn_train.run_training(build_graph_fn, train_dir,
                                  FLAGS.num_training_steps,
                                  FLAGS.summary_frequency,
                                  checkpoints_to_keep=FLAGS.num_checkpoints)
 def _build_graph_for_generation(self):
   """Build the 'generate'-mode graph for this model's config.

   Returns:
     Whatever the graph-building function from
     ``pianoroll_rnn_nade_graph.get_build_graph_fn`` returns when called.
   """
   generation_graph_builder = pianoroll_rnn_nade_graph.get_build_graph_fn(
       'generate', self._config)
   return generation_graph_builder()
Beispiel #4
0
 def _build_graph_for_generation(self):
     """Build the 'generate'-mode graph for this model's config.

     Returns:
       Whatever the graph-building function from
       ``pianoroll_rnn_nade_graph.get_build_graph_fn`` returns when called.
     """
     builder = pianoroll_rnn_nade_graph.get_build_graph_fn(
         'generate', self._config)
     return builder()