Example #1
def main(_):
  hparams = rl_tuner_ops.default_hparams()
  if FLAGS.note_rnn_type == 'basic_rnn':
    hparams = rl_tuner_ops.basic_rnn_hparams()
  elif FLAGS.note_rnn_type == 'attention_rnn':
    hparams = rl_tuner_ops.attention_rnn_hparams()

  dqn_hparams = tf.contrib.training.HParams(random_action_probability=0.1,
                                            store_every_nth=1,
                                            train_every_nth=5,
                                            minibatch_size=32,
                                            discount_rate=0.5,
                                            max_experience=100000,
                                            target_network_update_rate=0.01)

  output_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
  output_ckpt = FLAGS.algorithm + '.ckpt'
  backup_checkpoint_file = os.path.join(FLAGS.note_rnn_checkpoint_dir,
                                        FLAGS.note_rnn_checkpoint_name)

  rlt = rl_tuner.RLTuner(output_dir,
                         midi_primer=FLAGS.midi_primer,
                         dqn_hparams=dqn_hparams,
                         reward_scaler=FLAGS.reward_scaler,
                         save_name=output_ckpt,
                         output_every_nth=FLAGS.output_every_nth,
                         note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
                         note_rnn_checkpoint_file=backup_checkpoint_file,
                         note_rnn_type=FLAGS.note_rnn_type,
                         note_rnn_hparams=hparams,
                         num_notes_in_melody=FLAGS.num_notes_in_melody,
                         exploration_mode=FLAGS.exploration_mode,
                         algorithm=FLAGS.algorithm)

  tf.logging.info('Saving images and melodies to: %s', rlt.output_dir)

  tf.logging.info('Training...')
  rlt.train(num_steps=FLAGS.training_steps,
            exploration_period=FLAGS.exploration_steps)

  tf.logging.info('Finished training. Saving output figures and composition.')
  rlt.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')

  rlt.generate_music_sequence(visualize_probs=True, title=FLAGS.algorithm,
                              prob_image_name=FLAGS.algorithm + '.png')

  rlt.save_model_and_figs(FLAGS.algorithm)

  tf.logging.info('Calculating music theory metric stats for 1000 '
                  'compositions.')
  statistics = rlt.evaluate_music_theory_metrics(num_compositions=1000)
  print("music theory evaluation statistic: {0}".format(statistics))
Example #2
def main(_):
  if FLAGS.note_rnn_type == 'basic_rnn':
    hparams = rl_tuner_ops.basic_rnn_hparams()
  else:
    hparams = rl_tuner_ops.default_hparams()

  dqn_hparams = tf.contrib.training.HParams(random_action_probability=0.1,
                                            store_every_nth=1,
                                            train_every_nth=5,
                                            minibatch_size=32,
                                            discount_rate=0.5,
                                            max_experience=100000,
                                            target_network_update_rate=0.01)

  output_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
  output_ckpt = FLAGS.algorithm + '.ckpt'
  backup_checkpoint_file = os.path.join(FLAGS.note_rnn_checkpoint_dir,
                                        FLAGS.note_rnn_checkpoint_name)

  rlt = rl_tuner.RLTuner(output_dir,
                         midi_primer=FLAGS.midi_primer,
                         dqn_hparams=dqn_hparams,
                         reward_scaler=FLAGS.reward_scaler,
                         save_name=output_ckpt,
                         output_every_nth=FLAGS.output_every_nth,
                         note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
                         note_rnn_checkpoint_file=backup_checkpoint_file,
                         note_rnn_type=FLAGS.note_rnn_type,
                         note_rnn_hparams=hparams,
                         num_notes_in_melody=FLAGS.num_notes_in_melody,
                         exploration_mode=FLAGS.exploration_mode,
                         algorithm=FLAGS.algorithm)

  tf.logging.info('Saving images and melodies to: %s', rlt.output_dir)

  tf.logging.info('Training...')
  rlt.train(num_steps=FLAGS.training_steps,
            exploration_period=FLAGS.exploration_steps)

  tf.logging.info('Finished training. Saving output figures and composition.')
  rlt.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')

  rlt.generate_music_sequence(visualize_probs=True, title=FLAGS.algorithm,
                              prob_image_name=FLAGS.algorithm + '.png')

  rlt.save_model_and_figs(FLAGS.algorithm)

  tf.logging.info('Calculating music theory metric stats for 1000 '
                  'compositions.')
  rlt.evaluate_music_theory_metrics(num_compositions=1000)
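
Examples #1 and #2 differ only in that Example #1 also handles the 'attention_rnn' type and prints the evaluation statistics. In both cases main is typically invoked through TF 1.x's app runner; whether the original script uses exactly this entry point is an assumption.

# Common TF 1.x entry point for a flag-driven training script like these.
if __name__ == '__main__':
  tf.app.run(main)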
Example #3
def main(_):
  if FLAGS.note_rnn_type == 'basic_rnn':
    hparams = rl_tuner_ops.basic_rnn_hparams()
  else:
    hparams = rl_tuner_ops.default_hparams()

  dqn_hparams = tf.contrib.training.HParams(random_action_probability=0.1,
                                            store_every_nth=1,
                                            train_every_nth=5,
                                            minibatch_size=32,
                                            discount_rate=0.5,
                                            max_experience=100000,
                                            target_network_update_rate=0.01)

  num_compositions = 100000
  if FLAGS.running_mode == 'comparison':
    tf.logging.info('Running for comparison')
    # num_compositions = 1000

  defaultRlt = run_algorithm(FLAGS.algorithm, dqn_hparams, hparams,
                             num_compositions)
  if FLAGS.running_mode == 'comparison':
    pureRlt = run_algorithm('pure_rl', dqn_hparams, hparams, num_compositions)
    plot_comparison(defaultRlt, pureRlt)
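
Example #3 delegates the RLTuner setup to a run_algorithm helper and, in comparison mode, trains a second reward-only ('pure_rl') model to compare against. The helper's real implementation is not shown; the following is a hypothetical sketch that simply reuses the setup pattern from Examples #1 and #2 and parameterizes the algorithm name.

# Hypothetical run_algorithm helper, reconstructed from Examples #1 and #2.
def run_algorithm(algorithm, dqn_hparams, hparams, num_compositions):
  output_dir = os.path.join(FLAGS.output_dir, algorithm)
  backup_checkpoint_file = os.path.join(FLAGS.note_rnn_checkpoint_dir,
                                        FLAGS.note_rnn_checkpoint_name)
  rlt = rl_tuner.RLTuner(output_dir,
                         midi_primer=FLAGS.midi_primer,
                         dqn_hparams=dqn_hparams,
                         reward_scaler=FLAGS.reward_scaler,
                         save_name=algorithm + '.ckpt',
                         output_every_nth=FLAGS.output_every_nth,
                         note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
                         note_rnn_checkpoint_file=backup_checkpoint_file,
                         note_rnn_type=FLAGS.note_rnn_type,
                         note_rnn_hparams=hparams,
                         num_notes_in_melody=FLAGS.num_notes_in_melody,
                         exploration_mode=FLAGS.exploration_mode,
                         algorithm=algorithm)
  rlt.train(num_steps=FLAGS.training_steps,
            exploration_period=FLAGS.exploration_steps)
  rlt.save_model_and_figs(algorithm)
  rlt.evaluate_music_theory_metrics(num_compositions=num_compositions)
  return rlt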