def setUp(self):
  """Builds a MelodyRnnConfig fixture spanning notes 0 through 127.

  The config has no generator details, default (empty) hyperparameters, a
  one-hot melody encoder/decoder over the same note bounds, and
  transpose_to_key=0.
  """
  lowest_note, highest_note = 0, 127
  one_hot_encoding = magenta.music.MelodyOneHotEncoding(
      lowest_note, highest_note)
  encoder_decoder = magenta.music.OneHotEventSequenceEncoderDecoder(
      one_hot_encoding)
  self.config = melody_rnn_model.MelodyRnnConfig(
      None,
      encoder_decoder,
      contrib_training.HParams(),
      min_note=lowest_note,
      max_note=highest_note,
      transpose_to_key=0)
def config_from_flags():
  """Parses flags and returns the appropriate MelodyRnnConfig.

  If `--config` is supplied, returns the matching default MelodyRnnConfig
  after updating the hyperparameters based on `--hparams`, overriding the
  generator id/description when `--generator_id` / `--generator_description`
  are supplied, and setting `learn_initial_state` from `--learn_initial_state`
  when that flag is supplied.

  If `--melody_encoder_decoder` is supplied, returns a new MelodyRnnConfig
  using the matching EventSequenceEncoderDecoder (constructed with
  `melody_rnn_model.DEFAULT_MIN_NOTE` and `melody_rnn_model.DEFAULT_MAX_NOTE`),
  generator details supplied by `--generator_id` and
  `--generator_description`, and hyperparameters based on `--hparams`.
  `--learn_initial_state` is not consulted on this path.

  Returns:
    The appropriate MelodyRnnConfig based on the supplied flags.

  Raises:
    MelodyRnnConfigFlagsException: When not exactly one of `--config` or
        `--melody_encoder_decoder` is supplied, or when the supplied name is
        not one of the known encoder/decoders or default configs.
  """
  # The two selectors are mutually exclusive and one of them is required.
  if (FLAGS.melody_encoder_decoder, FLAGS.config).count(None) != 1:
    raise MelodyRnnConfigFlagsException(
        'Exactly one of `--config` or `--melody_encoder_decoder` must be '
        'supplied.')

  if FLAGS.melody_encoder_decoder is not None:
    if FLAGS.melody_encoder_decoder not in melody_encoder_decoders:
      raise MelodyRnnConfigFlagsException(
          '`--melody_encoder_decoder` must be one of %s. Got %s.' % (
              melody_encoder_decoders.keys(), FLAGS.melody_encoder_decoder))
    # Generator details are only built when an id is given; on this path a
    # description supplied without an id is ignored.
    if FLAGS.generator_id is not None:
      generator_details = magenta.protobuf.generator_pb2.GeneratorDetails(
          id=FLAGS.generator_id)
      if FLAGS.generator_description is not None:
        generator_details.description = FLAGS.generator_description
    else:
      generator_details = None
    encoder_decoder = melody_encoder_decoders[FLAGS.melody_encoder_decoder](
        melody_rnn_model.DEFAULT_MIN_NOTE, melody_rnn_model.DEFAULT_MAX_NOTE)
    hparams = magenta.common.HParams()
    hparams.parse(FLAGS.hparams)
    return melody_rnn_model.MelodyRnnConfig(
        generator_details, encoder_decoder, hparams)

  if FLAGS.config not in melody_rnn_model.default_configs:
    raise MelodyRnnConfigFlagsException(
        '`--config` must be one of %s. Got %s.' % (
            melody_rnn_model.default_configs.keys(), FLAGS.config))
  # NOTE(review): this is the shared entry in `default_configs`, and the
  # assignments below mutate it in place, so repeated calls within one
  # process accumulate overrides. Copy it first if that ever matters.
  config = melody_rnn_model.default_configs[FLAGS.config]
  config.hparams.parse(FLAGS.hparams)
  if FLAGS.generator_id is not None:
    config.details.id = FLAGS.generator_id
  if FLAGS.generator_description is not None:
    config.details.description = FLAGS.generator_description
  if FLAGS.learn_initial_state is not None:
    # Honor the supplied value instead of forcing True; previously an
    # explicit false setting was silently turned into True.
    config.learn_initial_state = bool(FLAGS.learn_initial_state)
  return config
def setUp(self):
  """Builds a small MelodyRnnConfig fixture spanning notes 0 through 12.

  The config has no generator details, a one-hot melody encoder/decoder over
  the same note bounds, explicit training hyperparameters, and
  transpose_to_key=0.
  """
  low_note, high_note = 0, 12
  encoder_decoder = magenta.music.OneHotEventSequenceEncoderDecoder(
      magenta.music.MelodyOneHotEncoding(low_note, high_note))
  hparams = magenta.common.HParams(
      batch_size=128,
      rnn_layer_sizes=[128, 128],
      dropout_keep_prob=0.5,
      skip_first_n_losses=0,
      clip_norm=5,
      initial_learning_rate=0.01,
      decay_steps=1000,
      decay_rate=0.85)
  self.config = melody_rnn_model.MelodyRnnConfig(
      None, encoder_decoder, hparams,
      min_note=low_note, max_note=high_note, transpose_to_key=0)