def run(melody_encoder_decoder):
  """Creates training and eval data with the given MelodyEncoderDecoder.

  Checks that the required flags are set, expands `~` in the input and
  output paths, creates any missing output directories, and then passes
  everything to `sequence_to_melodies.run_conversion`.

  If `--input` or `--train_output` is missing, a fatal message is logged and
  the function returns without doing any conversion.

  Args:
    melody_encoder_decoder: A melodies_lib.MelodyEncoderDecoder.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.input:
    tf.logging.fatal('--input required')
    return
  if not FLAGS.train_output:
    tf.logging.fatal('--train_output required')
    return

  FLAGS.input = os.path.expanduser(FLAGS.input)
  FLAGS.train_output = os.path.expanduser(FLAGS.train_output)
  if FLAGS.eval_output:
    FLAGS.eval_output = os.path.expanduser(FLAGS.eval_output)

  # The eval output is optional; only prepare a directory for it when set.
  output_paths = [FLAGS.train_output]
  if FLAGS.eval_output:
    output_paths.append(FLAGS.eval_output)

  for output_path in output_paths:
    output_dir = os.path.dirname(output_path)
    # dirname() is '' for a bare filename (output goes to the current
    # directory). os.makedirs('') raises, so skip directory creation then.
    if output_dir and not os.path.exists(output_dir):
      os.makedirs(output_dir)

  sequence_to_melodies.run_conversion(melody_encoder_decoder,
                                      FLAGS.input,
                                      FLAGS.train_output,
                                      FLAGS.eval_output,
                                      FLAGS.eval_ratio)
def testRunConversionNoEval(self):
  """Runs the conversion without an eval output and checks the results.

  Expects the training file to exist and hold 62 records, and expects no
  file to have been written at the eval output path.
  """
  sequence_to_melodies.run_conversion(
      encoder=sequence_to_melodies.basic_one_hot_encoder,
      sequences_file=self.sequences_file,
      train_output=self.train_output)

  train_file_exists = os.path.isfile(self.train_output)
  self.assertTrue(train_file_exists)

  train_records = list(tf.python_io.tf_record_iterator(self.train_output))
  self.assertEqual(62, len(train_records))

  eval_file_exists = os.path.isfile(self.eval_output)
  self.assertFalse(eval_file_exists)
def main(unused_argv):
  """Sends INFO-level logging to stdout and runs the conversion from flags.

  Args:
    unused_argv: Ignored positional command-line arguments.
  """
  # Build the stdout handler first, then attach it to the root logger.
  stdout_handler = logging.StreamHandler(sys.stdout)
  stdout_handler.setLevel(logging.INFO)

  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  root_logger.addHandler(stdout_handler)

  conversion_kwargs = dict(
      encoder=sequence_to_melodies.basic_one_hot_encoder,
      sequences_file=FLAGS.input,
      train_output=FLAGS.train_output,
      eval_output=FLAGS.eval_output,
      eval_ratio=FLAGS.eval_ratio)
  sequence_to_melodies.run_conversion(**conversion_kwargs)
def testRunConversion(self):
  """Runs the conversion with an eval split and checks both outputs.

  With eval_ratio=0.25, both output files must exist and be non-empty, and
  the training file must hold more records than the eval file.
  """
  def count_records(path):
    # Read the whole TFRecord file and report how many records it has.
    return len(list(tf.python_io.tf_record_iterator(path)))

  sequence_to_melodies.run_conversion(
      encoder=sequence_to_melodies.basic_one_hot_encoder,
      sequences_file=self.sequences_file,
      train_output=self.train_output,
      eval_output=self.eval_output,
      eval_ratio=0.25)

  self.assertTrue(os.path.isfile(self.train_output))
  train_count = count_records(self.train_output)

  self.assertTrue(os.path.isfile(self.eval_output))
  eval_count = count_records(self.eval_output)

  self.assertTrue(train_count > 0)
  self.assertTrue(eval_count > 0)
  self.assertTrue(train_count > eval_count)