def main(_):
    """Train an RL Tuner, save its outputs and print music theory statistics.

    Note RNN hyperparameters are chosen from ``FLAGS.note_rnn_type``; every
    other setting handed to ``RLTuner`` comes from command-line flags or from
    the DQN hyperparameters fixed below.

    Args:
        _: Unused.
    """
    # Begin with the default hyperparameters; only the RNN types listed here
    # replace them with a dedicated set.
    note_rnn_hparams = rl_tuner_ops.default_hparams()
    rnn_type = FLAGS.note_rnn_type
    if rnn_type == 'basic_rnn':
        note_rnn_hparams = rl_tuner_ops.basic_rnn_hparams()
    elif rnn_type == 'attention_rnn':
        note_rnn_hparams = rl_tuner_ops.attention_rnn_hparams()

    dqn_hparams = tf.contrib.training.HParams(
        random_action_probability=0.1,
        store_every_nth=1,
        train_every_nth=5,
        minibatch_size=32,
        discount_rate=0.5,
        max_experience=100000,
        target_network_update_rate=0.01,
    )

    # Outputs are grouped in a sub-directory named after the algorithm, and
    # the saved checkpoint carries the algorithm name as well.
    results_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
    checkpoint_name = FLAGS.algorithm + '.ckpt'
    note_rnn_checkpoint = os.path.join(
        FLAGS.note_rnn_checkpoint_dir, FLAGS.note_rnn_checkpoint_name)

    tuner_options = dict(
        midi_primer=FLAGS.midi_primer,
        dqn_hparams=dqn_hparams,
        reward_scaler=FLAGS.reward_scaler,
        save_name=checkpoint_name,
        output_every_nth=FLAGS.output_every_nth,
        note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
        note_rnn_checkpoint_file=note_rnn_checkpoint,
        note_rnn_type=FLAGS.note_rnn_type,
        note_rnn_hparams=note_rnn_hparams,
        num_notes_in_melody=FLAGS.num_notes_in_melody,
        exploration_mode=FLAGS.exploration_mode,
        algorithm=FLAGS.algorithm,
    )
    tuner = rl_tuner.RLTuner(results_dir, **tuner_options)

    tf.logging.info('Saving images and melodies to: %s', tuner.output_dir)

    tf.logging.info('Training...')
    tuner.train(
        num_steps=FLAGS.training_steps,
        exploration_period=FLAGS.exploration_steps,
    )

    tf.logging.info(
        'Finished training. Saving output figures and composition.')
    tuner.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')
    tuner.generate_music_sequence(
        visualize_probs=True,
        title=FLAGS.algorithm,
        prob_image_name=FLAGS.algorithm + '.png',
    )
    tuner.save_model_and_figs(FLAGS.algorithm)

    tf.logging.info(
        'Calculating music theory metric stats for 1000 compositions.')
    stats = tuner.evaluate_music_theory_metrics(num_compositions=1000)
    print("music theory evaluation statistic: {0}".format(stats))
def main(_):
    """Train an RL Tuner from command-line flags and save its outputs.

    Selects the Note RNN hyperparameters, builds an ``RLTuner``, trains it,
    then calls the tuner's plotting, sequence-generation, saving and music
    theory evaluation methods in that order.

    Args:
        _: Unused.
    """
    # 'basic_rnn' has its own hyperparameter set; any other type falls back
    # to the defaults.
    note_rnn_hparams = (
        rl_tuner_ops.basic_rnn_hparams()
        if FLAGS.note_rnn_type == 'basic_rnn'
        else rl_tuner_ops.default_hparams()
    )

    dqn_hparams = tf.contrib.training.HParams(
        random_action_probability=0.1,
        store_every_nth=1,
        train_every_nth=5,
        minibatch_size=32,
        discount_rate=0.5,
        max_experience=100000,
        target_network_update_rate=0.01,
    )

    # Results go to a per-algorithm sub-directory of the output directory.
    results_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
    checkpoint_name = FLAGS.algorithm + '.ckpt'
    note_rnn_checkpoint = os.path.join(
        FLAGS.note_rnn_checkpoint_dir, FLAGS.note_rnn_checkpoint_name)

    tuner = rl_tuner.RLTuner(
        results_dir,
        midi_primer=FLAGS.midi_primer,
        dqn_hparams=dqn_hparams,
        reward_scaler=FLAGS.reward_scaler,
        save_name=checkpoint_name,
        output_every_nth=FLAGS.output_every_nth,
        note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
        note_rnn_checkpoint_file=note_rnn_checkpoint,
        note_rnn_type=FLAGS.note_rnn_type,
        note_rnn_hparams=note_rnn_hparams,
        num_notes_in_melody=FLAGS.num_notes_in_melody,
        exploration_mode=FLAGS.exploration_mode,
        algorithm=FLAGS.algorithm,
    )
    tf.logging.info('Saving images and melodies to: %s', tuner.output_dir)

    tf.logging.info('Training...')
    tuner.train(
        num_steps=FLAGS.training_steps,
        exploration_period=FLAGS.exploration_steps,
    )

    tf.logging.info(
        'Finished training. Saving output figures and composition.')
    tuner.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')
    tuner.generate_music_sequence(
        visualize_probs=True,
        title=FLAGS.algorithm,
        prob_image_name=FLAGS.algorithm + '.png',
    )
    tuner.save_model_and_figs(FLAGS.algorithm)

    tf.logging.info(
        'Calculating music theory metric stats for 1000 compositions.')
    # NOTE(review): the return value is not used here; confirm whether the
    # evaluation reports its results itself.
    tuner.evaluate_music_theory_metrics(num_compositions=1000)
def main(_):
    """Run the configured algorithm and, in comparison mode, 'pure_rl' too.

    The algorithm named by ``FLAGS.algorithm`` is always run. When
    ``FLAGS.running_mode`` is 'comparison', the 'pure_rl' algorithm is run
    with the same hyperparameters and both results are passed to
    ``plot_comparison``.

    Args:
        _: Unused.
    """
    # Only 'basic_rnn' has a dedicated hyperparameter set; everything else
    # uses the defaults.
    if FLAGS.note_rnn_type != 'basic_rnn':
        note_rnn_hparams = rl_tuner_ops.default_hparams()
    else:
        note_rnn_hparams = rl_tuner_ops.basic_rnn_hparams()

    dqn_hparams = tf.contrib.training.HParams(
        random_action_probability=0.1,
        store_every_nth=1,
        train_every_nth=5,
        minibatch_size=32,
        discount_rate=0.5,
        max_experience=100000,
        target_network_update_rate=0.01,
    )

    num_compositions = 100000
    if FLAGS.running_mode == 'comparison':
        tf.logging.info("Running for comparison")
        # NOTE(review): a disabled override of num_compositions to 1000 used
        # to sit here; it stays disabled, so both runs use the value above.

    primary_tuner = run_algorithm(
        FLAGS.algorithm, dqn_hparams, note_rnn_hparams, num_compositions)

    if FLAGS.running_mode == 'comparison':
        pure_rl_tuner = run_algorithm(
            'pure_rl', dqn_hparams, note_rnn_hparams, num_compositions)
        plot_comparison(primary_tuner, pure_rl_tuner)