Example #1
0
def main(unused_argv):
    """Entry point: write a generator bundle, or generate sequences.

    The path taken is chosen by ``FLAGS.save_generator_bundle``. When a bundle
    is available, the model configuration is looked up from the bundle's
    generator id (with ``FLAGS.hparams`` overrides applied); otherwise it is
    built from the config flags.
    """
    tf.logging.set_verbosity(FLAGS.log)

    bundle = get_bundle()
    if not bundle:
        config = improv_rnn_config_flags.config_from_flags()
    else:
        # Reuse the configuration named by the bundle, then layer any
        # command-line hyperparameter overrides on top of it.
        config = improv_rnn_model.default_configs[bundle.generator_details.id]
        config.hparams.parse(FLAGS.hparams)

    # Having too large of a batch size will slow generation down unnecessarily,
    # so cap it at beam_size * branch_factor.
    current_batch = config.hparams.batch_size
    batch_cap = FLAGS.beam_size * FLAGS.branch_factor
    config.hparams.batch_size = min(current_batch, batch_cap)

    seq_generator = improv_rnn_sequence_generator.ImprovRnnSequenceGenerator(
        model=improv_rnn_model.ImprovRnnModel(config),
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=bundle)

    if not FLAGS.save_generator_bundle:
        run_with_flags(seq_generator)
        return

    target_path = os.path.expanduser(FLAGS.bundle_file)
    if FLAGS.bundle_description is None:
        tf.logging.warning('No bundle description provided.')
    tf.logging.info('Saving generator bundle to %s', target_path)
    seq_generator.create_bundle_file(target_path, FLAGS.bundle_description)
Example #2
0
def get_generator_map():
  """Returns a map from the generator ID to a SequenceGenerator factory.

  Each value is a `functools.partial` of `ImprovRnnSequenceGenerator`. The
  partial has its first two positional arguments already bound: an
  `ImprovRnnModel` built from the matching default config, and that config's
  `details`. Callers supply only the remaining constructor arguments, which is
  intended to match the BaseSequenceGenerator constructor.

  Note: every `ImprovRnnModel` is constructed eagerly when this function runs.
  Each factory reuses its single bound model instance on every call.

  Returns:
    Dict mapping each key of `improv_rnn_model.default_configs` to a partial of
    `ImprovRnnSequenceGenerator` with the model and details arguments bound.
  """
  return {
      key: partial(ImprovRnnSequenceGenerator,
                   improv_rnn_model.ImprovRnnModel(config),
                   config.details)
      for key, config in improv_rnn_model.default_configs.items()
  }
def main(unused_argv):
  """Saves bundle or runs generator based on flags.

  When a generator bundle is loaded, the model configuration is looked up by
  the bundle's generator id (with `FLAGS.hparams` overrides applied). The model
  is therefore built with the configuration the bundle names, rather than
  whatever the config flags happen to select. Without a bundle, the config is
  built from the command-line config flags.
  """
  tf.logging.set_verbosity(FLAGS.log)

  # Load the bundle once: it drives both the config choice and the generator.
  bundle = get_bundle()
  if bundle:
    config = improv_rnn_model.default_configs[bundle.generator_details.id]
    config.hparams.parse(FLAGS.hparams)
  else:
    config = improv_rnn_config_flags.config_from_flags()

  # Having too large of a batch size will slow generation down unnecessarily.
  config.hparams.batch_size = min(config.hparams.batch_size,
                                  FLAGS.beam_size * FLAGS.branch_factor)

  generator = improv_rnn_sequence_generator.ImprovRnnSequenceGenerator(
      model=improv_rnn_model.ImprovRnnModel(config),
      details=config.details,
      steps_per_quarter=config.steps_per_quarter,
      checkpoint=get_checkpoint(),
      bundle=bundle)

  if FLAGS.save_generator_bundle:
    bundle_filename = os.path.expanduser(FLAGS.bundle_file)
    if FLAGS.bundle_description is None:
      tf.logging.warning('No bundle description provided.')
    tf.logging.info('Saving generator bundle to %s', bundle_filename)
    generator.create_bundle_file(bundle_filename, FLAGS.bundle_description)
  else:
    run_with_flags(generator)
 def create_sequence_generator(config, **kwargs):
     """Build an ImprovRnnSequenceGenerator for `config`.

     Args:
       config: Config object; its `details` and `steps_per_quarter` attributes
         are forwarded to the generator, and the config itself is used to
         build the `ImprovRnnModel`.
       **kwargs: Extra keyword arguments passed through unchanged to the
         `ImprovRnnSequenceGenerator` constructor.

     Returns:
       The constructed `ImprovRnnSequenceGenerator`.
     """
     rnn_model = improv_rnn_model.ImprovRnnModel(config)
     return ImprovRnnSequenceGenerator(
         rnn_model, config.details,
         steps_per_quarter=config.steps_per_quarter, **kwargs)
Example #5
0
 def create_sequence_generator(config, **kwargs):
   """Build an ImprovRnnSequenceGenerator for `config`.

   The model is constructed from `config` and passed positionally along with
   `config.details`; any keyword arguments are forwarded unchanged.
   """
   rnn_model = improv_rnn_model.ImprovRnnModel(config)
   generator_details = config.details
   return ImprovRnnSequenceGenerator(rnn_model, generator_details, **kwargs)