def default_dqn_hparams():
  """Returns the default hyperparameters for the RLTuner DQN model."""
  dqn_defaults = {
      'random_action_probability': 0.1,
      'store_every_nth': 1,
      'train_every_nth': 5,
      'minibatch_size': 32,
      'discount_rate': 0.95,
      'max_experience': 100000,
      'target_network_update_rate': 0.01,
  }
  return tf_lib.HParams(**dqn_defaults)
def default_hparams():
  """Returns the default model and training hyperparameters."""
  hparam_values = {
      'batch_size': 128,
      'rnn_layer_sizes': [128, 128],
      'dropout_keep_prob': 0.5,
      'skip_first_n_losses': 0,
      'clip_norm': 5,
      'initial_learning_rate': 0.01,
      'decay_steps': 1000,
      'decay_rate': 0.85,
  }
  return tf_lib.HParams(**hparam_values)
def setUp(self):
  """Builds the encoder/decoder and hyperparameters shared by the tests."""
  self.encoder_decoder = melodies_lib.OneHotEncoderDecoder(0, 12, 0)
  test_hparams = {
      'batch_size': 128,
      'rnn_layer_sizes': [128, 128],
      'dropout_keep_prob': 0.5,
      'skip_first_n_losses': 0,
      'clip_norm': 5,
      'initial_learning_rate': 0.01,
      'decay_steps': 1000,
      'decay_rate': 0.85,
  }
  self.hparams = tf_lib.HParams(**test_hparams)
def main(_):
  """Trains an RLTuner model and saves its outputs.

  Selects note-RNN hyperparameters based on the flags, builds the DQN
  hyperparameters, constructs and trains the RLTuner, then writes reward
  plots, a generated melody, model checkpoints, and music-theory metric
  statistics under the configured output directory.
  """
  # Pick note-RNN hparams matching the checkpoint type being restored.
  if FLAGS.note_rnn_type == 'basic_rnn':
    rnn_hparams = rl_tuner_ops.basic_rnn_hparams()
  else:
    rnn_hparams = rl_tuner_ops.default_hparams()

  dqn_config = tf_lib.HParams(
      random_action_probability=0.1,
      store_every_nth=1,
      train_every_nth=5,
      minibatch_size=32,
      discount_rate=0.5,
      max_experience=100000,
      target_network_update_rate=0.01)

  run_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
  checkpoint_name = FLAGS.algorithm + '.ckpt'
  note_rnn_ckpt_path = os.path.join(FLAGS.note_rnn_checkpoint_dir,
                                    FLAGS.note_rnn_checkpoint_name)

  rlt = rl_tuner.RLTuner(
      run_dir,
      midi_primer=FLAGS.midi_primer,
      dqn_hparams=dqn_config,
      reward_scaler=FLAGS.reward_scaler,
      save_name=checkpoint_name,
      output_every_nth=FLAGS.output_every_nth,
      note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
      note_rnn_checkpoint_file=note_rnn_ckpt_path,
      note_rnn_type=FLAGS.note_rnn_type,
      note_rnn_hparams=rnn_hparams,
      num_notes_in_melody=FLAGS.num_notes_in_melody,
      exploration_mode=FLAGS.exploration_mode,
      algorithm=FLAGS.algorithm)

  tf.logging.info('Saving images and melodies to: %s', rlt.output_dir)

  tf.logging.info('Training...')
  rlt.train(num_steps=FLAGS.training_steps,
            exploration_period=FLAGS.exploration_steps)

  tf.logging.info(
      'Finished training. Saving output figures and composition.')
  rlt.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')
  rlt.generate_music_sequence(visualize_probs=True,
                              title=FLAGS.algorithm,
                              prob_image_name=FLAGS.algorithm + '.png')
  rlt.save_model_and_figs(FLAGS.algorithm)

  tf.logging.info('Calculating music theory metric stats for 1000 '
                  'compositions.')
  rlt.evaluate_music_theory_metrics(num_compositions=1000)
def basic_rnn_hparams():
  """Generates the hparams used to train a basic_rnn.

  These are the hparams used in the .mag file found at
  https://github.com/tensorflow/magenta/tree/master/magenta/models/
  melody_rnn#pre-trained

  Returns:
    Hyperparameters of the downloadable basic_rnn pre-trained model.
  """
  # TODO(natashajaques): ability to restore basic_rnn from any .mag file.
  basic_rnn_values = {
      'batch_size': 128,
      'rnn_layer_sizes': [512, 512],
      'one_hot_length': NUM_CLASSES,
  }
  return tf_lib.HParams(**basic_rnn_values)
def default_hparams():
  """Generates the hparams used to train note rnn used in paper."""
  paper_hparams = {
      'use_dynamic_rnn': True,
      'batch_size': BATCH_SIZE,
      'lr': 0.0002,
      'l2_reg': 2.5e-5,
      'clip_norm': 5,
      'initial_learning_rate': 0.5,
      'decay_steps': 1000,
      'decay_rate': 0.85,
      'rnn_layer_sizes': [100],
      'skip_first_n_losses': 32,
      'one_hot_length': NUM_CLASSES,
      'exponentially_decay_learning_rate': True,
  }
  return tf_lib.HParams(**paper_hparams)
def basic_rnn_hparams():
  """Generates the hparams used to train a basic_rnn.

  These are the hparams used in the .mag file found at
  https://github.com/tensorflow/magenta/tree/master/magenta/models/
  melody_rnn#pre-trained

  Returns:
    Hyperparameters of the downloadable basic_rnn pre-trained model.
  """
  # TODO(natashajaques): ability to restore basic_rnn from any .mag
  # file.
  pretrained_values = {
      'batch_size': 128,
      'dropout_keep_prob': 0.5,
      'clip_norm': 5,
      'initial_learning_rate': 0.01,
      'decay_steps': 1000,
      'decay_rate': 0.85,
      'rnn_layer_sizes': [512, 512],
      'skip_first_n_losses': 0,
      'one_hot_length': NUM_CLASSES,
      'exponentially_decay_learning_rate': True,
  }
  return tf_lib.HParams(**pretrained_values)