示例#1
0
    def __init__(self,
                 graph,
                 scope,
                 checkpoint_dir,
                 checkpoint_file=None,
                 midi_primer=None,
                 training_file_list=None,
                 hparams=None,
                 note_rnn_type='default',
                 checkpoint_scope='rnn_model'):
        """Initialize by building the graph and loading a previous checkpoint.

        Args:
          graph: A tensorflow graph where the MelodyRNN's graph will be added.
          scope: The tensorflow scope where this network will be saved.
          checkpoint_dir: Path to the directory where the checkpoint file is
            saved.
          checkpoint_file: Path to a checkpoint file to be used if none can be
            found in the checkpoint_dir.
          midi_primer: Path to a single midi file that can be used to prime the
            model.
          training_file_list: List of paths to tfrecord files containing melody
            training data.
          hparams: A tf_lib.HParams object. Must match the hparams used to
            create the checkpoint file.
          note_rnn_type: If 'default', will use the basic LSTM described in the
            research paper. If 'basic_rnn', will assume the checkpoint is from
            a Magenta basic_rnn model.
          checkpoint_scope: The scope in lstm which the model was originally
            defined when it was first trained.
        """
        # Record the constructor configuration verbatim.
        self.graph = graph
        self.scope = scope
        self.midi_primer = midi_primer
        self.training_file_list = training_file_list
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_file = checkpoint_file
        self.checkpoint_scope = checkpoint_scope
        self.note_rnn_type = note_rnn_type

        # Session and saver are created later by other helpers.
        self.session = None
        self.saver = None
        self.batch_size = 1

        if hparams is None:
            tf.logging.info('Empty hparams string. Using defaults')
            self.hparams = rl_tuner_ops.default_hparams()
        else:
            tf.logging.info('Using custom hparams')
            self.hparams = hparams

        self.build_graph()
        self.state_value = self.get_zero_state()

        if midi_primer is not None:
            self.load_primer()

        self.variable_names = rl_tuner_ops.get_variable_names(
            self.graph, self.scope)

        self.transpose_amount = 0
示例#2
0
def main(_):
  """Trains an RLTuner, saves its output, and prints music-theory stats.

  Args:
    _: Unused; present for the tf.app.run() main(argv) convention.
  """
  # Start from the default note-RNN hparams and override them when a
  # dedicated configuration exists for the requested RNN type.
  note_hparams = rl_tuner_ops.default_hparams()
  if FLAGS.note_rnn_type == 'basic_rnn':
    note_hparams = rl_tuner_ops.basic_rnn_hparams()
  elif FLAGS.note_rnn_type == 'attention_rnn':
    note_hparams = rl_tuner_ops.attention_rnn_hparams()

  agent_hparams = tf.contrib.training.HParams(
      random_action_probability=0.1,
      store_every_nth=1,
      train_every_nth=5,
      minibatch_size=32,
      discount_rate=0.5,
      max_experience=100000,
      target_network_update_rate=0.01)

  output_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
  save_name = FLAGS.algorithm + '.ckpt'
  backup_checkpoint_file = os.path.join(FLAGS.note_rnn_checkpoint_dir,
                                        FLAGS.note_rnn_checkpoint_name)

  tuner = rl_tuner.RLTuner(
      output_dir,
      midi_primer=FLAGS.midi_primer,
      dqn_hparams=agent_hparams,
      reward_scaler=FLAGS.reward_scaler,
      save_name=save_name,
      output_every_nth=FLAGS.output_every_nth,
      note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
      note_rnn_checkpoint_file=backup_checkpoint_file,
      note_rnn_type=FLAGS.note_rnn_type,
      note_rnn_hparams=note_hparams,
      num_notes_in_melody=FLAGS.num_notes_in_melody,
      exploration_mode=FLAGS.exploration_mode,
      algorithm=FLAGS.algorithm)

  tf.logging.info('Saving images and melodies to: %s', tuner.output_dir)

  tf.logging.info('Training...')
  tuner.train(num_steps=FLAGS.training_steps,
              exploration_period=FLAGS.exploration_steps)

  tf.logging.info('Finished training. Saving output figures and composition.')
  tuner.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')

  tuner.generate_music_sequence(visualize_probs=True, title=FLAGS.algorithm,
                                prob_image_name=FLAGS.algorithm + '.png')

  tuner.save_model_and_figs(FLAGS.algorithm)

  tf.logging.info('Calculating music theory metric stats for 1000 '
                  'compositions.')
  stats = tuner.evaluate_music_theory_metrics(num_compositions=1000)
  print("music theory evaluation statistic: {0}".format(stats))
示例#3
0
  def __init__(self, graph, scope, checkpoint_dir, checkpoint_file=None,
               midi_primer=None, training_file_list=None, hparams=None,
               note_rnn_type='default', checkpoint_scope='rnn_model'):
    """Initialize by building the graph and loading a previous checkpoint.

    Args:
      graph: A tensorflow graph where the MelodyRNN's graph will be added.
      scope: The tensorflow scope where this network will be saved.
      checkpoint_dir: Path to the directory where the checkpoint file is saved.
      checkpoint_file: Path to a checkpoint file to be used if none can be
        found in the checkpoint_dir.
      midi_primer: Path to a single midi file that can be used to prime the
        model.
      training_file_list: List of paths to tfrecord files containing melody
        training data.
      hparams: A tf_lib.HParams object. Must match the hparams used to create
        the checkpoint file.
      note_rnn_type: If 'default', will use the basic LSTM described in the
        research paper. If 'basic_rnn', will assume the checkpoint is from a
        Magenta basic_rnn model.
      checkpoint_scope: The scope in lstm which the model was originally
        defined when it was first trained.
    """
    # Copy constructor configuration onto the instance.
    self.graph = graph
    self.scope = scope
    self.checkpoint_dir = checkpoint_dir
    self.checkpoint_file = checkpoint_file
    self.midi_primer = midi_primer
    self.training_file_list = training_file_list
    self.note_rnn_type = note_rnn_type
    self.checkpoint_scope = checkpoint_scope

    self.session = None  # Created on demand by graph/checkpoint helpers.
    self.batch_size = 1

    if hparams is None:
      tf.logging.info('Empty hparams string. Using defaults')
      self.hparams = rl_tuner_ops.default_hparams()
    else:
      tf.logging.info('Using custom hparams')
      self.hparams = hparams

    self.build_graph()
    self.state_value = self.get_zero_state()

    if midi_primer is not None:
      self.load_primer()

    self.variable_names = rl_tuner_ops.get_variable_names(self.graph,
                                                          self.scope)

    self.transpose_amount = 0
示例#4
0
def main(_):
  """Trains an RLTuner end to end and evaluates music-theory metrics.

  Args:
    _: Unused; present for the tf.app.run() main(argv) convention.
  """
  # 'basic_rnn' checkpoints need their own hparams; everything else
  # uses the defaults.
  note_hparams = (rl_tuner_ops.basic_rnn_hparams()
                  if FLAGS.note_rnn_type == 'basic_rnn'
                  else rl_tuner_ops.default_hparams())

  agent_hparams = tf.contrib.training.HParams(
      random_action_probability=0.1,
      store_every_nth=1,
      train_every_nth=5,
      minibatch_size=32,
      discount_rate=0.5,
      max_experience=100000,
      target_network_update_rate=0.01)

  output_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
  save_name = FLAGS.algorithm + '.ckpt'
  backup_checkpoint_file = os.path.join(FLAGS.note_rnn_checkpoint_dir,
                                        FLAGS.note_rnn_checkpoint_name)

  tuner = rl_tuner.RLTuner(
      output_dir,
      midi_primer=FLAGS.midi_primer,
      dqn_hparams=agent_hparams,
      reward_scaler=FLAGS.reward_scaler,
      save_name=save_name,
      output_every_nth=FLAGS.output_every_nth,
      note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
      note_rnn_checkpoint_file=backup_checkpoint_file,
      note_rnn_type=FLAGS.note_rnn_type,
      note_rnn_hparams=note_hparams,
      num_notes_in_melody=FLAGS.num_notes_in_melody,
      exploration_mode=FLAGS.exploration_mode,
      algorithm=FLAGS.algorithm)

  tf.logging.info('Saving images and melodies to: %s', tuner.output_dir)

  tf.logging.info('Training...')
  tuner.train(num_steps=FLAGS.training_steps,
              exploration_period=FLAGS.exploration_steps)

  tf.logging.info('Finished training. Saving output figures and composition.')
  tuner.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')

  tuner.generate_music_sequence(visualize_probs=True, title=FLAGS.algorithm,
                                prob_image_name=FLAGS.algorithm + '.png')

  tuner.save_model_and_figs(FLAGS.algorithm)

  tf.logging.info('Calculating music theory metric stats for 1000 '
                  'compositions.')
  tuner.evaluate_music_theory_metrics(num_compositions=1000)
示例#5
0
def main(_):
  """Runs the requested tuning algorithm; in comparison mode, also pure RL.

  Selects note-RNN hparams from FLAGS.note_rnn_type, configures the DQN
  agent, and runs FLAGS.algorithm. When FLAGS.running_mode is
  'comparison', additionally runs the 'pure_rl' algorithm with the same
  configuration and plots the two results against each other.

  Args:
    _: Unused; present for the tf.app.run() main(argv) convention.
  """
  if FLAGS.note_rnn_type == 'basic_rnn':
    hparams = rl_tuner_ops.basic_rnn_hparams()
  else:
    hparams = rl_tuner_ops.default_hparams()

  dqn_hparams = tf.contrib.training.HParams(random_action_probability=0.1,
                                            store_every_nth=1,
                                            train_every_nth=5,
                                            minibatch_size=32,
                                            discount_rate=0.5,
                                            max_experience=100000,
                                            target_network_update_rate=0.01)

  num_compositions = 100000
  # NOTE(review): redundant parentheses around the conditions and a garbled
  # whitespace run inside the run_algorithm call were cleaned up; logic is
  # unchanged.
  if FLAGS.running_mode == 'comparison':
    tf.logging.info("Running for comparison")
    # NOTE(review): a smaller num_compositions (1000) was previously
    # sketched here for comparison runs — confirm 100000 is intended.

  default_rlt = run_algorithm(FLAGS.algorithm, dqn_hparams, hparams,
                              num_compositions)
  if FLAGS.running_mode == 'comparison':
    pure_rlt = run_algorithm('pure_rl', dqn_hparams, hparams,
                             num_compositions)
    plot_comparison(default_rlt, pure_rlt)