def __init__(self, graph, scope, checkpoint_dir, checkpoint_file=None,
             midi_primer=None, training_file_list=None, hparams=None,
             note_rnn_type='default', checkpoint_scope='rnn_model'):
  """Builds the model graph and prepares it to restore a prior checkpoint.

  Args:
    graph: A tensorflow graph where the MelodyRNN's graph will be added.
    scope: The tensorflow scope where this network will be saved.
    checkpoint_dir: Path to the directory where the checkpoint file is
      saved.
    checkpoint_file: Path to a checkpoint file to be used if none can be
      found in the checkpoint_dir.
    midi_primer: Path to a single midi file that can be used to prime the
      model.
    training_file_list: List of paths to tfrecord files containing melody
      training data.
    hparams: A tf_lib.HParams object. Must match the hparams used to create
      the checkpoint file.
    note_rnn_type: If 'default', will use the basic LSTM described in the
      research paper. If 'basic_rnn', will assume the checkpoint is from a
      Magenta basic_rnn model.
    checkpoint_scope: The scope in lstm which the model was originally
      defined when it was first trained.
  """
  # Session and saver are created later; only the graph handle is kept now.
  self.graph = graph
  self.session = None
  self.saver = None
  self.scope = scope
  self.batch_size = 1

  # Record caller-supplied configuration verbatim.
  self.midi_primer = midi_primer
  self.checkpoint_scope = checkpoint_scope
  self.note_rnn_type = note_rnn_type
  self.training_file_list = training_file_list
  self.checkpoint_dir = checkpoint_dir
  self.checkpoint_file = checkpoint_file

  # Fall back to library defaults when the caller supplies no hparams; they
  # must match those used to produce the checkpoint being loaded.
  if hparams is None:
    tf.logging.info('Empty hparams string. Using defaults')
    self.hparams = rl_tuner_ops.default_hparams()
  else:
    tf.logging.info('Using custom hparams')
    self.hparams = hparams

  self.build_graph()
  self.state_value = self.get_zero_state()

  # Optionally prime the model from a midi file.
  if midi_primer is not None:
    self.load_primer()

  self.variable_names = rl_tuner_ops.get_variable_names(self.graph,
                                                        self.scope)
  self.transpose_amount = 0
def main(_):
  """Trains an RL Tuner model per FLAGS, then saves outputs and evaluates it.

  Selects note-RNN hparams from FLAGS.note_rnn_type, trains the DQN-based
  tuner, writes reward plots, a generated melody, and model figures under
  FLAGS.output_dir/FLAGS.algorithm, and finally logs music theory metric
  statistics over 1000 generated compositions.

  Args:
    _: Unused; required by tf.app.run's main(argv) calling convention.
  """
  # Note-RNN hparams must match the checkpoint being restored.
  hparams = rl_tuner_ops.default_hparams()
  if FLAGS.note_rnn_type == 'basic_rnn':
    hparams = rl_tuner_ops.basic_rnn_hparams()
  elif FLAGS.note_rnn_type == 'attention_rnn':
    hparams = rl_tuner_ops.attention_rnn_hparams()

  dqn_hparams = tf.contrib.training.HParams(random_action_probability=0.1,
                                            store_every_nth=1,
                                            train_every_nth=5,
                                            minibatch_size=32,
                                            discount_rate=0.5,
                                            max_experience=100000,
                                            target_network_update_rate=0.01)

  output_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
  output_ckpt = FLAGS.algorithm + '.ckpt'
  # Used if no checkpoint is found in the checkpoint directory itself.
  backup_checkpoint_file = os.path.join(FLAGS.note_rnn_checkpoint_dir,
                                        FLAGS.note_rnn_checkpoint_name)

  rlt = rl_tuner.RLTuner(output_dir,
                         midi_primer=FLAGS.midi_primer,
                         dqn_hparams=dqn_hparams,
                         reward_scaler=FLAGS.reward_scaler,
                         save_name=output_ckpt,
                         output_every_nth=FLAGS.output_every_nth,
                         note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
                         note_rnn_checkpoint_file=backup_checkpoint_file,
                         note_rnn_type=FLAGS.note_rnn_type,
                         note_rnn_hparams=hparams,
                         num_notes_in_melody=FLAGS.num_notes_in_melody,
                         exploration_mode=FLAGS.exploration_mode,
                         algorithm=FLAGS.algorithm)

  tf.logging.info('Saving images and melodies to: %s', rlt.output_dir)

  tf.logging.info('Training...')
  rlt.train(num_steps=FLAGS.training_steps,
            exploration_period=FLAGS.exploration_steps)

  tf.logging.info('Finished training. Saving output figures and composition.')
  rlt.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')

  rlt.generate_music_sequence(visualize_probs=True, title=FLAGS.algorithm,
                              prob_image_name=FLAGS.algorithm + '.png')

  rlt.save_model_and_figs(FLAGS.algorithm)

  tf.logging.info('Calculating music theory metric stats for 1000 '
                  'compositions.')
  statistics = rlt.evaluate_music_theory_metrics(num_compositions=1000)
  # Use tf.logging (lazy %-formatting) for consistency with the rest of the
  # function instead of a bare print() call.
  tf.logging.info('music theory evaluation statistic: %s', statistics)
def __init__(self, graph, scope, checkpoint_dir, checkpoint_file=None,
             midi_primer=None, training_file_list=None, hparams=None,
             note_rnn_type='default', checkpoint_scope='rnn_model'):
  """Initialize by building the graph and loading a previous checkpoint.

  Args:
    graph: A tensorflow graph where the MelodyRNN's graph will be added.
    scope: The tensorflow scope where this network will be saved.
    checkpoint_dir: Path to the directory where the checkpoint file is
      saved.
    checkpoint_file: Path to a checkpoint file to be used if none can be
      found in the checkpoint_dir.
    midi_primer: Path to a single midi file that can be used to prime the
      model.
    training_file_list: List of paths to tfrecord files containing melody
      training data.
    hparams: A tf_lib.HParams object. Must match the hparams used to create
      the checkpoint file.
    note_rnn_type: If 'default', will use the basic LSTM described in the
      research paper. If 'basic_rnn', will assume the checkpoint is from a
      Magenta basic_rnn model.
    checkpoint_scope: The scope in lstm which the model was originally
      defined when it was first trained.
  """
  self.graph = graph
  self.session = None
  # Initialize the saver attribute up front (like session) so attribute
  # access never raises AttributeError before a tf.train.Saver is built.
  self.saver = None
  self.scope = scope
  self.batch_size = 1
  self.midi_primer = midi_primer
  self.checkpoint_scope = checkpoint_scope
  self.note_rnn_type = note_rnn_type
  self.training_file_list = training_file_list
  self.checkpoint_dir = checkpoint_dir
  self.checkpoint_file = checkpoint_file

  # Caller-supplied hparams must match those used to create the checkpoint.
  if hparams is not None:
    tf.logging.info('Using custom hparams')
    self.hparams = hparams
  else:
    tf.logging.info('Empty hparams string. Using defaults')
    self.hparams = rl_tuner_ops.default_hparams()

  self.build_graph()
  self.state_value = self.get_zero_state()

  # Optionally prime the model from a midi file.
  if midi_primer is not None:
    self.load_primer()

  self.variable_names = rl_tuner_ops.get_variable_names(self.graph,
                                                        self.scope)
  self.transpose_amount = 0
def main(_):
  """Trains an RL Tuner model per FLAGS and saves its figures and output.

  Args:
    _: Unused; required by tf.app.run's main(argv) calling convention.
  """
  # Note-RNN hparams must correspond to the checkpoint type being restored.
  if FLAGS.note_rnn_type == 'basic_rnn':
    note_rnn_hparams = rl_tuner_ops.basic_rnn_hparams()
  else:
    note_rnn_hparams = rl_tuner_ops.default_hparams()

  dqn_hparams = tf.contrib.training.HParams(
      random_action_probability=0.1,
      store_every_nth=1,
      train_every_nth=5,
      minibatch_size=32,
      discount_rate=0.5,
      max_experience=100000,
      target_network_update_rate=0.01)

  run_output_dir = os.path.join(FLAGS.output_dir, FLAGS.algorithm)
  save_name = FLAGS.algorithm + '.ckpt'
  # Fallback checkpoint path if none is found in the checkpoint directory.
  fallback_ckpt = os.path.join(FLAGS.note_rnn_checkpoint_dir,
                               FLAGS.note_rnn_checkpoint_name)

  rlt = rl_tuner.RLTuner(
      run_output_dir,
      midi_primer=FLAGS.midi_primer,
      dqn_hparams=dqn_hparams,
      reward_scaler=FLAGS.reward_scaler,
      save_name=save_name,
      output_every_nth=FLAGS.output_every_nth,
      note_rnn_checkpoint_dir=FLAGS.note_rnn_checkpoint_dir,
      note_rnn_checkpoint_file=fallback_ckpt,
      note_rnn_type=FLAGS.note_rnn_type,
      note_rnn_hparams=note_rnn_hparams,
      num_notes_in_melody=FLAGS.num_notes_in_melody,
      exploration_mode=FLAGS.exploration_mode,
      algorithm=FLAGS.algorithm)

  tf.logging.info('Saving images and melodies to: %s', rlt.output_dir)

  tf.logging.info('Training...')
  rlt.train(num_steps=FLAGS.training_steps,
            exploration_period=FLAGS.exploration_steps)

  tf.logging.info('Finished training. Saving output figures and composition.')
  rlt.plot_rewards(image_name='Rewards-' + FLAGS.algorithm + '.eps')
  rlt.generate_music_sequence(visualize_probs=True,
                              title=FLAGS.algorithm,
                              prob_image_name=FLAGS.algorithm + '.png')
  rlt.save_model_and_figs(FLAGS.algorithm)

  tf.logging.info('Calculating music theory metric stats for 1000 '
                  'compositions.')
  rlt.evaluate_music_theory_metrics(num_compositions=1000)
def main(_):
  """Runs the configured algorithm and optionally compares against pure RL.

  Trains and evaluates the algorithm named by FLAGS.algorithm; when
  FLAGS.running_mode is 'comparison', additionally runs the 'pure_rl'
  baseline with the same hparams and plots both runs side by side.

  Args:
    _: Unused; required by tf.app.run's main(argv) calling convention.
  """
  # Note-RNN hparams must correspond to the checkpoint type being restored.
  if FLAGS.note_rnn_type == 'basic_rnn':
    hparams = rl_tuner_ops.basic_rnn_hparams()
  else:
    hparams = rl_tuner_ops.default_hparams()

  dqn_hparams = tf.contrib.training.HParams(random_action_probability=0.1,
                                            store_every_nth=1,
                                            train_every_nth=5,
                                            minibatch_size=32,
                                            discount_rate=0.5,
                                            max_experience=100000,
                                            target_network_update_rate=0.01)

  num_compositions = 100000
  # Evaluate the running mode once instead of re-testing the flag below.
  comparison_mode = FLAGS.running_mode == 'comparison'
  if comparison_mode:
    tf.logging.info("Running for comparison")

  # snake_case names per the file's naming convention.
  default_rlt = run_algorithm(FLAGS.algorithm, dqn_hparams, hparams,
                              num_compositions)
  if comparison_mode:
    pure_rlt = run_algorithm('pure_rl', dqn_hparams, hparams,
                             num_compositions)
    plot_comparison(default_rlt, pure_rlt)