def main(unused_argv):
  """Build the pianoroll RNN-NADE generator, then save a bundle or generate.

  When --save_generator_bundle is set, the generator is written out as a
  bundle file; otherwise it is handed to run_with_flags().

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  tf.logging.set_verbosity(FLAGS.log)

  bundle = get_bundle()

  # A loaded bundle decides which model config is used; without one, fall
  # back to the --config flag.
  if bundle:
    config_id = bundle.generator_details.id
  else:
    config_id = FLAGS.config

  config = pianoroll_rnn_nade_model.default_configs[config_id]
  model_hparams = config.hparams
  model_hparams.parse(FLAGS.hparams)
  # Having too large of a batch size will slow generation down unnecessarily,
  # so cap it at beam_size * branch_factor.
  model_hparams.batch_size = min(
      model_hparams.batch_size, FLAGS.beam_size * FLAGS.branch_factor)

  model = pianoroll_rnn_nade_model.PianorollRnnNadeModel(config)
  generator = PianorollRnnNadeSequenceGenerator(
      model=model,
      details=config.details,
      steps_per_quarter=config.steps_per_quarter,
      checkpoint=get_checkpoint(),
      bundle=bundle)

  if not FLAGS.save_generator_bundle:
    run_with_flags(generator)
    return

  bundle_filename = os.path.expanduser(FLAGS.bundle_file)
  description = FLAGS.bundle_description
  if description is None:
    tf.logging.warning('No bundle description provided.')
  tf.logging.info('Saving generator bundle to %s', bundle_filename)
  generator.create_bundle_file(bundle_filename, description)
 def create_sequence_generator(config, **kwargs):
   """Construct a PianorollRnnNadeSequenceGenerator from a model config.

   Args:
     config: Model config; it is used to build the model and also supplies
       ``details`` and ``steps_per_quarter``.
     **kwargs: Extra keyword arguments forwarded unchanged to the generator.

   Returns:
     The newly constructed PianorollRnnNadeSequenceGenerator.
   """
   model = pianoroll_rnn_nade_model.PianorollRnnNadeModel(config)
   details = config.details
   return PianorollRnnNadeSequenceGenerator(
       model, details, steps_per_quarter=config.steps_per_quarter, **kwargs)
# Example #3
# 0
            pass
    message = {
        'status': 400,
        'message': "File not found",
    }
    resp = jsonify(message)
    resp.status_code = 400
    return resp


if __name__ == '__main__':
    # Module-level setting; presumably consumed by get_bundle() -- TODO confirm.
    bundle_file = 'pretrained/pianoroll_rnn_nade.mag'
    with tf.Session():
        tf.logging.set_verbosity(log)

        bundle = get_bundle()
        if not bundle:
            # Unlike main(), there is no config-name fallback here, so fail
            # fast with a clear message instead of an AttributeError on
            # `bundle.generator_details` below.
            raise SystemExit(
                'No generator bundle available (bundle_file=%s)' % bundle_file)

        config_id = bundle.generator_details.id
        config = pianoroll_rnn_nade_model.default_configs[config_id]
        config.hparams.parse(hparams)
        # Having too large of a batch size will slow generation down
        # unnecessarily.
        config.hparams.batch_size = min(config.hparams.batch_size,
                                        beam_size * branch_factor)

        # Bound at module level; presumably used by the Flask request
        # handlers defined elsewhere in this file.
        generator = PianorollRnnNadeSequenceGenerator(
            model=pianoroll_rnn_nade_model.PianorollRnnNadeModel(config),
            details=config.details,
            steps_per_quarter=config.steps_per_quarter,
            checkpoint=None,
            bundle=bundle)
        # Debug mode is deliberately NOT forced on. It enables the Werkzeug
        # interactive debugger, which can execute arbitrary Python code, and
        # host='0.0.0.0' would expose that to every network interface.
        # Leaving `debug` unset keeps it off by default; recent Flask
        # versions let it be opted into via the FLASK_DEBUG env variable.
        app.run(host='0.0.0.0', use_reloader=False)