예제 #1
0
def main(unused_argv):
    """Entry point: write a generator bundle or generate drum tracks.

    If `get_bundle()` returns a bundle, the model config is the entry of
    `drums_rnn_model.default_configs` keyed by the bundle's generator id, with
    `FLAGS.hparams` parsed on top of it. Otherwise the config is built from
    command-line flags. The resulting generator is either written to
    `FLAGS.bundle_file` (when `FLAGS.save_generator_bundle` is set) or passed
    to `run_with_flags`.

    Args:
      unused_argv: Leftover command-line arguments; ignored.
    """
    tf.logging.set_verbosity(FLAGS.log)

    bundle = get_bundle()
    if not bundle:
        config = drums_rnn_config_flags.config_from_flags()
    else:
        # NOTE(review): `config` is the object stored in default_configs, so the
        # hparams changes below modify that shared entry in place.
        config = drums_rnn_model.default_configs[bundle.generator_details.id]
        config.hparams.parse(FLAGS.hparams)

    # Cap the batch size at beam_size * branch_factor: anything larger only
    # slows generation down without being used.
    hparams = config.hparams
    hparams.batch_size = min(hparams.batch_size,
                             FLAGS.beam_size * FLAGS.branch_factor)

    drums_model = drums_rnn_model.DrumsRnnModel(config)
    generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
        model=drums_model,
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=bundle)

    if not FLAGS.save_generator_bundle:
        run_with_flags(generator)
        return

    bundle_filename = os.path.expanduser(FLAGS.bundle_file)
    description = FLAGS.bundle_description
    if description is None:
        tf.logging.warning('No bundle description provided.')
    tf.logging.info('Saving generator bundle to %s', bundle_filename)
    generator.create_bundle_file(bundle_filename, description)
예제 #2
0
def main(unused_argv):
  """Builds a drums RNN generator, then saves it as a bundle or runs it.

  The config comes from the loaded bundle's generator id (plus --hparams
  overrides) when a bundle is available, otherwise from command-line flags.

  Args:
    unused_argv: Unused positional command-line arguments.
  """
  tf.logging.set_verbosity(FLAGS.log)

  bundle = get_bundle()
  if bundle:
    # Look up the config registered under the bundle's generator id and
    # layer the --hparams string on top of it.
    config = drums_rnn_model.default_configs[bundle.generator_details.id]
    config.hparams.parse(FLAGS.hparams)
  else:
    config = drums_rnn_config_flags.config_from_flags()

  # A batch larger than beam_size * branch_factor only slows generation down.
  hparams = config.hparams
  beam_capacity = FLAGS.beam_size * FLAGS.branch_factor
  hparams.batch_size = min(hparams.batch_size, beam_capacity)

  generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
      model=drums_rnn_model.DrumsRnnModel(config),
      details=config.details,
      steps_per_quarter=config.steps_per_quarter,
      checkpoint=get_checkpoint(),
      bundle=bundle)

  if not FLAGS.save_generator_bundle:
    run_with_flags(generator)
    return

  output_path = os.path.expanduser(FLAGS.bundle_file)
  if FLAGS.bundle_description is None:
    tf.logging.warning('No bundle description provided.')
  tf.logging.info('Saving generator bundle to %s', output_path)
  generator.create_bundle_file(output_path, FLAGS.bundle_description)
예제 #3
0
def main(unused_argv):
  """Trains or evaluates a drums RNN model on a single SequenceExample file.

  Both --run_dir and --sequence_example_file must be set. If either is
  missing, a fatal message is logged and the function returns without doing
  anything else. The training directory is `<run_dir>/train`. With --eval,
  `<run_dir>/eval` is used as the evaluation directory.

  Args:
    unused_argv: Unused positional command-line arguments.
  """
  tf.logging.set_verbosity(FLAGS.log)

  if not FLAGS.run_dir:
    tf.logging.fatal('--run_dir required')
    return
  if not FLAGS.sequence_example_file:
    tf.logging.fatal('--sequence_example_file required')
    return

  examples_path = os.path.expanduser(FLAGS.sequence_example_file)
  root_dir = os.path.expanduser(FLAGS.run_dir)

  drums_config = drums_rnn_config_flags.config_from_flags()
  evaluating = FLAGS.eval
  graph = events_rnn_graph.build_graph(
      'eval' if evaluating else 'train', drums_config, examples_path)

  train_dir = os.path.join(root_dir, 'train')
  if not os.path.exists(train_dir):
    tf.gfile.MakeDirs(train_dir)
  tf.logging.info('Train dir: %s', train_dir)

  if not evaluating:
    events_rnn_train.run_training(graph, train_dir, FLAGS.num_training_steps,
                                  FLAGS.summary_frequency)
    return

  eval_dir = os.path.join(root_dir, 'eval')
  if not os.path.exists(eval_dir):
    tf.gfile.MakeDirs(eval_dir)
  tf.logging.info('Eval dir: %s', eval_dir)
  events_rnn_train.run_eval(graph, train_dir, eval_dir,
                            FLAGS.num_training_steps, FLAGS.summary_frequency)
def main(unused_argv):
    """Runs the drums RNN dataset pipeline over a TFRecord file.

    Records are read from --input (as `pipeline_instance.input_type`) and fed
    through the pipeline returned by `get_pipeline(config, FLAGS.eval_ratio)`.
    The outputs are written under --output_dir. Both path flags are
    user-expanded (`~`) and written back into FLAGS.

    Args:
      unused_argv: Unused positional command-line arguments.
    """
    tf.logging.set_verbosity(FLAGS.log)

    # Validate required path flags up front, like the training entry points
    # do. Otherwise an unset flag reaches os.path.expanduser() and fails with
    # an opaque TypeError.
    if not FLAGS.input:
        tf.logging.fatal('--input required')
        return
    if not FLAGS.output_dir:
        tf.logging.fatal('--output_dir required')
        return

    config = drums_rnn_config_flags.config_from_flags()
    pipeline_instance = get_pipeline(config, FLAGS.eval_ratio)

    FLAGS.input = os.path.expanduser(FLAGS.input)
    FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
    pipeline.run_pipeline_serial(
        pipeline_instance,
        pipeline.tf_record_iterator(FLAGS.input, pipeline_instance.input_type),
        FLAGS.output_dir)
예제 #5
0
def main(unused_argv):
  """Runs the drums RNN dataset pipeline over a TFRecord file.

  Records are read from --input (as `pipeline_instance.input_type`) and fed
  through the pipeline returned by `get_pipeline(config, FLAGS.eval_ratio)`.
  The outputs are written under --output_dir. Both path flags are
  user-expanded (`~`) and written back into FLAGS.

  Args:
    unused_argv: Unused positional command-line arguments.
  """
  tf.logging.set_verbosity(FLAGS.log)

  # Fail early with a clear message instead of letting os.path.expanduser()
  # raise a TypeError on an unset flag.
  if not FLAGS.input:
    tf.logging.fatal('--input required')
    return
  if not FLAGS.output_dir:
    tf.logging.fatal('--output_dir required')
    return

  config = drums_rnn_config_flags.config_from_flags()
  pipeline_instance = get_pipeline(config, FLAGS.eval_ratio)

  FLAGS.input = os.path.expanduser(FLAGS.input)
  FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
  pipeline.run_pipeline_serial(
      pipeline_instance,
      pipeline.tf_record_iterator(FLAGS.input, pipeline_instance.input_type),
      FLAGS.output_dir)
예제 #6
0
def main(unused_argv):
    """Trains or evaluates a drums RNN model from SequenceExample TFRecords.

    --sequence_example_file is expanded with `tf.gfile.Glob`, so it may be a
    pattern matching several files. Training uses `<run_dir>/train`. With
    --eval, `<run_dir>/eval` is used. The number of eval batches is
    --num_eval_examples (or, if that is falsy, the record count of the
    matched files) divided by the batch size.

    Args:
      unused_argv: Unused positional command-line arguments.
    """
    tf.logging.set_verbosity(FLAGS.log)

    if not FLAGS.run_dir:
        tf.logging.fatal('--run_dir required')
        return
    if not FLAGS.sequence_example_file:
        tf.logging.fatal('--sequence_example_file required')
        return

    sequence_example_file_paths = tf.gfile.Glob(
        os.path.expanduser(FLAGS.sequence_example_file))
    # A mistyped path or pattern yields an empty list. Stop here rather than
    # building an input pipeline with no files.
    if not sequence_example_file_paths:
        tf.logging.fatal('No files matched --sequence_example_file: %s',
                         FLAGS.sequence_example_file)
        return
    run_dir = os.path.expanduser(FLAGS.run_dir)

    config = drums_rnn_config_flags.config_from_flags()

    mode = 'eval' if FLAGS.eval else 'train'
    build_graph_fn = events_rnn_graph.get_build_graph_fn(
        mode, config, sequence_example_file_paths)

    train_dir = os.path.join(run_dir, 'train')
    if not os.path.exists(train_dir):
        tf.gfile.MakeDirs(train_dir)
    tf.logging.info('Train dir: %s', train_dir)

    if FLAGS.eval:
        eval_dir = os.path.join(run_dir, 'eval')
        if not os.path.exists(eval_dir):
            tf.gfile.MakeDirs(eval_dir)
        tf.logging.info('Eval dir: %s', eval_dir)
        num_eval_examples = (
            FLAGS.num_eval_examples
            or magenta.common.count_records(sequence_example_file_paths))
        num_batches = num_eval_examples // config.hparams.batch_size
        # Floor division: fewer examples than one batch would mean running
        # evaluation over zero batches.
        if num_batches < 1:
            tf.logging.fatal(
                'Eval batch count is zero: %d examples is fewer than '
                'batch_size %d', num_eval_examples, config.hparams.batch_size)
            return
        events_rnn_train.run_eval(build_graph_fn, train_dir, eval_dir,
                                  num_batches)

    else:
        events_rnn_train.run_training(
            build_graph_fn,
            train_dir,
            FLAGS.num_training_steps,
            FLAGS.summary_frequency,
            checkpoints_to_keep=FLAGS.num_checkpoints)
예제 #7
0
def main(unused_argv):
  """Saves a generator bundle or generates drum tracks, depending on flags.

  The model config always comes from command-line flags. The generator's
  steps-per-quarter value is taken from --steps_per_quarter.

  Args:
    unused_argv: Unused positional command-line arguments.
  """
  tf.logging.set_verbosity(FLAGS.log)

  drums_config = drums_rnn_config_flags.config_from_flags()
  drums_model = drums_rnn_model.DrumsRnnModel(drums_config)
  generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
      model=drums_model,
      details=drums_config.details,
      steps_per_quarter=FLAGS.steps_per_quarter,
      checkpoint=get_checkpoint(),
      bundle=get_bundle())

  if not FLAGS.save_generator_bundle:
    run_with_flags(generator)
    return

  bundle_path = os.path.expanduser(FLAGS.bundle_file)
  if FLAGS.bundle_description is None:
    tf.logging.warning('No bundle description provided.')
  tf.logging.info('Saving generator bundle to %s', bundle_path)
  generator.create_bundle_file(bundle_path, FLAGS.bundle_description)
def main(unused_argv):
    """Builds a flag-configured drums generator, then saves or runs it.

    With --save_generator_bundle the generator is written to --bundle_file.
    Otherwise it is handed to `run_with_flags`.

    Args:
      unused_argv: Unused positional command-line arguments.
    """
    tf.logging.set_verbosity(FLAGS.log)

    config = drums_rnn_config_flags.config_from_flags()
    # Keyword values are evaluated left to right, so model construction,
    # checkpoint lookup and bundle loading happen in the same order as a
    # direct constructor call would.
    generator_kwargs = dict(
        model=drums_rnn_model.DrumsRnnModel(config),
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=get_bundle())
    generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
        **generator_kwargs)

    save_bundle = FLAGS.save_generator_bundle
    if save_bundle:
        out_path = os.path.expanduser(FLAGS.bundle_file)
        description = FLAGS.bundle_description
        if description is None:
            tf.logging.warning('No bundle description provided.')
        tf.logging.info('Saving generator bundle to %s', out_path)
        generator.create_bundle_file(out_path, description)
    else:
        run_with_flags(generator)
예제 #9
0
def main(unused_argv):
  """Trains or evaluates a drums RNN model from SequenceExample TFRecords.

  --sequence_example_file is expanded with `tf.gfile.Glob`, so it may be a
  pattern matching several files. Training uses `<run_dir>/train`. With
  --eval, `<run_dir>/eval` is used. The number of eval batches is
  --num_eval_examples (or, if that is falsy, the record count of the matched
  files) divided by the batch size.

  Args:
    unused_argv: Unused positional command-line arguments.
  """
  tf.logging.set_verbosity(FLAGS.log)

  if not FLAGS.run_dir:
    tf.logging.fatal('--run_dir required')
    return
  if not FLAGS.sequence_example_file:
    tf.logging.fatal('--sequence_example_file required')
    return

  sequence_example_file_paths = tf.gfile.Glob(
      os.path.expanduser(FLAGS.sequence_example_file))
  # A mistyped path or pattern yields an empty list. Stop here rather than
  # building an input pipeline with no files.
  if not sequence_example_file_paths:
    tf.logging.fatal('No files matched --sequence_example_file: %s',
                     FLAGS.sequence_example_file)
    return
  run_dir = os.path.expanduser(FLAGS.run_dir)

  config = drums_rnn_config_flags.config_from_flags()

  mode = 'eval' if FLAGS.eval else 'train'
  build_graph_fn = events_rnn_graph.get_build_graph_fn(
      mode, config, sequence_example_file_paths)

  train_dir = os.path.join(run_dir, 'train')
  if not os.path.exists(train_dir):
    tf.gfile.MakeDirs(train_dir)
  tf.logging.info('Train dir: %s', train_dir)

  if FLAGS.eval:
    eval_dir = os.path.join(run_dir, 'eval')
    if not os.path.exists(eval_dir):
      tf.gfile.MakeDirs(eval_dir)
    tf.logging.info('Eval dir: %s', eval_dir)
    num_eval_examples = (
        FLAGS.num_eval_examples if FLAGS.num_eval_examples else
        magenta.common.count_records(sequence_example_file_paths))
    num_batches = num_eval_examples // config.hparams.batch_size
    # Floor division: fewer examples than one batch would mean running
    # evaluation over zero batches.
    if num_batches < 1:
      tf.logging.fatal(
          'Eval batch count is zero: %d examples is fewer than batch_size %d',
          num_eval_examples, config.hparams.batch_size)
      return
    events_rnn_train.run_eval(build_graph_fn, train_dir, eval_dir, num_batches)

  else:
    events_rnn_train.run_training(build_graph_fn, train_dir,
                                  FLAGS.num_training_steps,
                                  FLAGS.summary_frequency,
                                  checkpoints_to_keep=FLAGS.num_checkpoints)
예제 #10
0
def main(unused_argv):
  """Trains or evaluates a drums RNN model from SequenceExample TFRecords.

  --sequence_example_file is expanded with `tf.gfile.Glob`, so it may be a
  pattern matching several files. The training directory is
  `<run_dir>/train`. With --eval, `<run_dir>/eval` is used as the evaluation
  directory.

  Args:
    unused_argv: Unused positional command-line arguments.
  """
  tf.logging.set_verbosity(FLAGS.log)

  if not FLAGS.run_dir:
    tf.logging.fatal('--run_dir required')
    return
  if not FLAGS.sequence_example_file:
    tf.logging.fatal('--sequence_example_file required')
    return

  sequence_example_file_paths = tf.gfile.Glob(
      os.path.expanduser(FLAGS.sequence_example_file))
  # A mistyped path or pattern yields an empty list. Stop here rather than
  # building a graph whose input pipeline has no files.
  if not sequence_example_file_paths:
    tf.logging.fatal('No files matched --sequence_example_file: %s',
                     FLAGS.sequence_example_file)
    return
  run_dir = os.path.expanduser(FLAGS.run_dir)

  config = drums_rnn_config_flags.config_from_flags()

  mode = 'eval' if FLAGS.eval else 'train'
  graph = events_rnn_graph.build_graph(
      mode, config, sequence_example_file_paths)

  train_dir = os.path.join(run_dir, 'train')
  if not os.path.exists(train_dir):
    tf.gfile.MakeDirs(train_dir)
  tf.logging.info('Train dir: %s', train_dir)

  if FLAGS.eval:
    eval_dir = os.path.join(run_dir, 'eval')
    if not os.path.exists(eval_dir):
      tf.gfile.MakeDirs(eval_dir)
    tf.logging.info('Eval dir: %s', eval_dir)
    events_rnn_train.run_eval(graph, train_dir, eval_dir,
                              FLAGS.num_training_steps, FLAGS.summary_frequency)

  else:
    events_rnn_train.run_training(graph, train_dir, FLAGS.num_training_steps,
                                  FLAGS.summary_frequency)