コード例 #1
0
def main(unused_argv):
    """Train or evaluate the melody RNN according to command-line flags.

    Requires ``FLAGS.run_dir`` and ``FLAGS.sequence_example_file``; if either
    is unset a fatal-level message is logged and the function returns without
    doing any work. Otherwise the graph is built in ``'eval'`` mode when
    ``FLAGS.eval`` is set and in ``'train'`` mode when it is not, and handed
    to ``run_eval`` or ``run_training`` respectively.

    Args:
        unused_argv: Ignored.
    """
    tf.logging.set_verbosity(FLAGS.log)

    # Both flags are mandatory; report the missing one and bail out early.
    if not FLAGS.run_dir:
        tf.logging.fatal('--run_dir required')
        return
    if not FLAGS.sequence_example_file:
        tf.logging.fatal('--sequence_example_file required')
        return

    sequence_example_file = os.path.expanduser(FLAGS.sequence_example_file)
    run_dir = os.path.expanduser(FLAGS.run_dir)

    config = melody_rnn_config.config_from_flags()

    mode = 'eval' if FLAGS.eval else 'train'
    graph = melody_rnn_graph.build_graph(mode, config, sequence_example_file)

    # tf.gfile.MakeDirs succeeds when the directory already exists, so it is
    # called unconditionally. A separate os.path.exists() pre-check would go
    # through a different filesystem layer than tf.gfile and leave a window
    # between the check and the creation.
    train_dir = os.path.join(run_dir, 'train')
    tf.gfile.MakeDirs(train_dir)
    tf.logging.info('Train dir: %s', train_dir)

    if FLAGS.eval:
        # The train dir is passed to run_eval alongside a dedicated eval dir.
        eval_dir = os.path.join(run_dir, 'eval')
        tf.gfile.MakeDirs(eval_dir)
        tf.logging.info('Eval dir: %s', eval_dir)
        run_eval(graph, train_dir, eval_dir, FLAGS.num_training_steps,
                 FLAGS.summary_frequency)

    else:
        run_training(graph, train_dir, FLAGS.num_training_steps,
                     FLAGS.summary_frequency)
コード例 #2
0
def run_from_flags():
  """Run the pipeline serially over the records named by the flags.

  Builds a pipeline from the configured encoder/decoder, expands ``~`` in
  ``FLAGS.input`` and ``FLAGS.output_dir`` (storing the expanded values back
  on ``FLAGS``), and then runs the pipeline over the input records, passing
  the output directory to ``pipeline.run_pipeline_serial``.
  """
  tf.logging.set_verbosity(FLAGS.log)

  flag_config = melody_rnn_config.config_from_flags()
  melody_pipeline = get_pipeline(flag_config.encoder_decoder)

  # The expanded paths are written back to FLAGS, not kept as locals.
  FLAGS.input = os.path.expanduser(FLAGS.input)
  FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)

  records = pipeline.tf_record_iterator(
      FLAGS.input, melody_pipeline.input_type)
  pipeline.run_pipeline_serial(melody_pipeline, records, FLAGS.output_dir)
コード例 #3
0
def main(unused_argv):
    """Save a generator bundle or run the generator, as selected by flags.

    A ``MelodyRnnSequenceGenerator`` is always constructed first. When
    ``FLAGS.save_generator_bundle`` is set, the generator is written to the
    (user-expanded) ``FLAGS.bundle_file`` path; otherwise it is handed to
    ``run_with_flags``.

    Args:
        unused_argv: Ignored.
    """
    generator_cls = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator
    flag_config = melody_rnn_config.config_from_flags()
    steps_per_quarter = FLAGS.steps_per_quarter
    checkpoint = get_checkpoint()
    bundle = get_bundle()
    generator = generator_cls(
        flag_config, steps_per_quarter, checkpoint, bundle)

    # Default path: no bundle requested, so just run the generator.
    if not FLAGS.save_generator_bundle:
        run_with_flags(generator)
        return

    bundle_path = os.path.expanduser(FLAGS.bundle_file)
    tf.logging.info('Saving generator bundle to %s', bundle_path)
    generator.create_bundle_file(bundle_path)