def main(unused_argv):
  """Builds the melody RNN graph and runs training or evaluation.

  Reads its configuration from command-line flags. If --run_dir or
  --sequence_example_file is empty, logs a fatal message and returns early.
  Otherwise it builds the graph in 'eval' mode when --eval is set and in
  'train' mode when it is not. It ensures that `<run_dir>/train` exists, and
  also `<run_dir>/eval` in eval mode. It then calls run_eval or run_training
  with --num_training_steps and --summary_frequency.

  Args:
    unused_argv: Unused.
  """
  tf.logging.set_verbosity(FLAGS.log)

  if not FLAGS.run_dir:
    tf.logging.fatal('--run_dir required')
    return
  if not FLAGS.sequence_example_file:
    tf.logging.fatal('--sequence_example_file required')
    return

  # Allow '~' in user-supplied paths.
  sequence_example_file = os.path.expanduser(FLAGS.sequence_example_file)
  run_dir = os.path.expanduser(FLAGS.run_dir)

  config = melody_rnn_config.config_from_flags()
  mode = 'eval' if FLAGS.eval else 'train'
  graph = melody_rnn_graph.build_graph(mode, config, sequence_example_file)

  # train_dir is created in both modes and is also handed to run_eval
  # (presumably so evaluation can read training checkpoints — TODO confirm).
  train_dir = os.path.join(run_dir, 'train')
  # Check existence through tf.gfile, the same filesystem layer that creates
  # the directory. os.path.exists only inspects the local filesystem and
  # cannot see non-local run dirs such as gs:// paths.
  if not tf.gfile.Exists(train_dir):
    tf.gfile.MakeDirs(train_dir)
  tf.logging.info('Train dir: %s', train_dir)

  if FLAGS.eval:
    eval_dir = os.path.join(run_dir, 'eval')
    if not tf.gfile.Exists(eval_dir):
      tf.gfile.MakeDirs(eval_dir)
    tf.logging.info('Eval dir: %s', eval_dir)
    run_eval(graph, train_dir, eval_dir, FLAGS.num_training_steps,
             FLAGS.summary_frequency)
  else:
    run_training(graph, train_dir, FLAGS.num_training_steps,
                 FLAGS.summary_frequency)
def run_from_flags():
  """Runs the melody RNN data pipeline serially, configured from flags.

  Builds a pipeline from the flag-derived config's encoder/decoder. It feeds
  that pipeline the records read by pipeline.tf_record_iterator from
  --input, and passes --output_dir to pipeline.run_pipeline_serial.

  Side effect: FLAGS.input and FLAGS.output_dir are overwritten with their
  '~'-expanded forms, so any later reader of those flags sees the expanded
  paths.
  """
  tf.logging.set_verbosity(FLAGS.log)

  encoder_decoder = melody_rnn_config.config_from_flags().encoder_decoder
  data_pipeline = get_pipeline(encoder_decoder)

  # Expansion is written back into FLAGS (not kept in locals) on purpose;
  # it is an observable side effect of this function.
  FLAGS.input = os.path.expanduser(FLAGS.input)
  FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)

  input_records = pipeline.tf_record_iterator(FLAGS.input,
                                              data_pipeline.input_type)
  pipeline.run_pipeline_serial(data_pipeline, input_records, FLAGS.output_dir)
def main(unused_argv):
  """Writes a generator bundle file or runs the generator, per flags.

  A MelodyRnnSequenceGenerator is always constructed first. It is built from
  the flag-derived config, --steps_per_quarter, and the values returned by
  get_checkpoint() and get_bundle(). When --save_generator_bundle is set, the
  generator is serialized to the '~'-expanded --bundle_file path. Otherwise
  the generator is handed to run_with_flags.

  Args:
    unused_argv: Unused.
  """
  generator_class = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator
  generator = generator_class(
      melody_rnn_config.config_from_flags(),
      FLAGS.steps_per_quarter,
      get_checkpoint(),
      get_bundle())

  if not FLAGS.save_generator_bundle:
    run_with_flags(generator)
    return

  bundle_filename = os.path.expanduser(FLAGS.bundle_file)
  tf.logging.info('Saving generator bundle to %s', bundle_filename)
  generator.create_bundle_file(bundle_filename)