コード例 #1
0
 def setUp(self):
     """Creates the MelodyRnnConfig shared by the tests in this case.

     The config has no generator details and an HParams built with no
     arguments. Its one-hot melody encoding and min_note/max_note bounds
     both cover 0..127, and transpose_to_key is 0.
     """
     one_hot_encoding = magenta.music.MelodyOneHotEncoding(0, 127)
     encoder_decoder = magenta.music.OneHotEventSequenceEncoderDecoder(
         one_hot_encoding)
     hparams = contrib_training.HParams()
     self.config = melody_rnn_model.MelodyRnnConfig(
         None, encoder_decoder, hparams,
         min_note=0, max_note=127, transpose_to_key=0)
コード例 #2
0
def config_from_flags():
    """Parses flags and returns the appropriate MelodyRnnConfig.

    If `--config` is supplied, returns the matching default MelodyRnnConfig
    after updating its hyperparameters based on `--hparams`. It also
    overrides the generator id and description when `--generator_id` /
    `--generator_description` are supplied. Note that the returned object is
    the shared entry in `melody_rnn_model.default_configs` and is modified in
    place.

    If `--melody_encoder_decoder` is supplied, returns a new MelodyRnnConfig
    with three parts:
      * the matching EventSequenceEncoderDecoder, constructed with
        `melody_rnn_model.DEFAULT_MIN_NOTE` and `DEFAULT_MAX_NOTE`;
      * generator details built from `--generator_id` and
        `--generator_description`, or no details when `--generator_id` is
        not supplied;
      * hyperparameters based on `--hparams`.

    In either case, when `--learn_initial_state` is not None, the returned
    config's `learn_initial_state` attribute is set to the flag's truth value.

    Returns:
      The appropriate MelodyRnnConfig based on the supplied flags.

    Raises:
      MelodyRnnConfigFlagsException: When not exactly one of `--config` or
          `--melody_encoder_decoder` is supplied, or when the supplied name
          is not a known encoder/decoder or default config.
    """
    # Exactly one of the two selector flags may be left as None.
    if (FLAGS.melody_encoder_decoder, FLAGS.config).count(None) != 1:
        raise MelodyRnnConfigFlagsException(
            'Exactly one of `--config` or `--melody_encoder_decoder` must be '
            'supplied.')

    if FLAGS.melody_encoder_decoder is not None:
        if FLAGS.melody_encoder_decoder not in melody_encoder_decoders:
            raise MelodyRnnConfigFlagsException(
                '`--melody_encoder_decoder` must be one of %s. Got %s.' %
                (melody_encoder_decoders.keys(), FLAGS.melody_encoder_decoder))

        # A description alone is not enough to build generator details; it is
        # only attached when an id was also supplied.
        generator_details = None
        if FLAGS.generator_id is not None:
            generator_details = magenta.protobuf.generator_pb2.GeneratorDetails(
                id=FLAGS.generator_id)
            if FLAGS.generator_description is not None:
                generator_details.description = FLAGS.generator_description

        encoder_decoder = melody_encoder_decoders[
            FLAGS.melody_encoder_decoder](melody_rnn_model.DEFAULT_MIN_NOTE,
                                          melody_rnn_model.DEFAULT_MAX_NOTE)
        hparams = magenta.common.HParams()
        hparams.parse(FLAGS.hparams)
        config = melody_rnn_model.MelodyRnnConfig(generator_details,
                                                  encoder_decoder, hparams)
    else:
        if FLAGS.config not in melody_rnn_model.default_configs:
            raise MelodyRnnConfigFlagsException(
                '`--config` must be one of %s. Got %s.' %
                (melody_rnn_model.default_configs.keys(), FLAGS.config))
        config = melody_rnn_model.default_configs[FLAGS.config]
        config.hparams.parse(FLAGS.hparams)
        if FLAGS.generator_id is not None:
            config.details.id = FLAGS.generator_id
        if FLAGS.generator_description is not None:
            config.details.description = FLAGS.generator_description

    # Honour the flag's actual value: previously any non-None value (even an
    # explicit False) forced True, and the flag was ignored entirely on the
    # `--melody_encoder_decoder` path.
    if FLAGS.learn_initial_state is not None:
        config.learn_initial_state = bool(FLAGS.learn_initial_state)
    return config
コード例 #3
0
 def setUp(self):
     """Creates the small-range MelodyRnnConfig shared by these tests.

     The config has no generator details. Its one-hot melody encoding and
     min_note/max_note bounds both cover 0..12, transpose_to_key is 0, and
     the HParams below are set explicitly.
     """
     one_hot_encoding = magenta.music.MelodyOneHotEncoding(0, 12)
     encoder_decoder = magenta.music.OneHotEventSequenceEncoderDecoder(
         one_hot_encoding)
     hparams = magenta.common.HParams(
         batch_size=128,
         rnn_layer_sizes=[128, 128],
         dropout_keep_prob=0.5,
         skip_first_n_losses=0,
         clip_norm=5,
         initial_learning_rate=0.01,
         decay_steps=1000,
         decay_rate=0.85,
     )
     self.config = melody_rnn_model.MelodyRnnConfig(
         None, encoder_decoder, hparams,
         min_note=0, max_note=12, transpose_to_key=0)