def testBuildCudnnGraph(self):
    """Building the 'train' graph with hparams.use_cudnn set must not raise."""
    hparams = self.config.hparams
    hparams.use_cudnn = True
    with tf.Graph().as_default():
        build_fn = events_rnn_graph.get_build_graph_fn(
            'train', self.config,
            sequence_example_file_paths=[self._sequence_file.name])
        build_fn()
Esempio n. 2
0
 def testBuildCudnnGraphWithResidualConnections(self):
     """The 'train' graph builds with use_cudnn and residual_connections on.

     Passes as long as constructing and invoking the build function does
     not raise.
     """
     hparams = self.config.hparams
     hparams.use_cudnn = True
     hparams.residual_connections = True
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn(
             'train', self.config,
             sequence_example_file_paths=[self._sequence_file.name])
         build_fn()
 def testBuildGraphWithAttention(self):
     """Building the 'train' graph with hparams.attn_length = 10 must not raise."""
     hparams = self.config.hparams
     hparams.attn_length = 10
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn(
             'train', self.config,
             sequence_example_file_paths=[self._sequence_file.name])
         build_fn()
Esempio n. 4
0
def main(unused_argv):
    """Train or evaluate a melody RNN using locations from the `tgt` module.

    Training reads ``training_melodies.tfrecord`` and evaluation reads
    ``eval_melodies.tfrecord``, both from ``tgt.OUTPUT_DIR``; run output goes
    under ``tgt.MODEL_DIR/logdir``. ``mt.FLAGS.eval`` selects the mode.

    Args:
        unused_argv: Leftover command-line arguments; ignored.
    """
    tf.logging.set_verbosity(mt.FLAGS.log)
    data_dir = tgt.OUTPUT_DIR
    train_dir = os.path.join(tgt.MODEL_DIR, "logdir/train")
    if not os.path.exists(train_dir):
        tf.gfile.MakeDirs(train_dir)

    config = mt.melody_rnn_config_flags.config_from_flags()

    if not mt.FLAGS.eval:
        train_file = tf.gfile.Glob(
            os.path.join(data_dir, "training_melodies.tfrecord"))
        # Glob returns [] when the file is missing; stop here instead of
        # building an input pipeline over no files.
        if not train_file:
            tf.logging.fatal("No training data found in %s", data_dir)
            return
        tf.logging.info("Train dir: %s", train_dir)
        # Record which config name and hparams string this run used.
        with tf.gfile.Open(os.path.join(train_dir, "hparams"), mode="w") as f:
            f.write("\t".join([mt.FLAGS.config, mt.FLAGS.hparams]))

        graph = events_rnn_graph.get_build_graph_fn("train", config,
                                                    train_file)
        events_rnn_train.run_training(
            graph,
            train_dir,
            mt.FLAGS.num_training_steps,
            mt.FLAGS.summary_frequency,
            checkpoints_to_keep=mt.FLAGS.num_checkpoints)

    else:
        eval_file = tf.gfile.Glob(
            os.path.join(data_dir, "eval_melodies.tfrecord"))
        if not eval_file:
            tf.logging.fatal("No eval data found in %s", data_dir)
            return
        eval_dir = os.path.join(tgt.MODEL_DIR, "logdir/eval")
        if not os.path.exists(eval_dir):
            tf.gfile.MakeDirs(eval_dir)
        tf.logging.info("Eval dir: %s", eval_dir)

        # An explicit --num_eval_examples wins; otherwise count the records.
        examples = (mt.FLAGS.num_eval_examples or
                    magenta.common.count_records(eval_file))
        # Without this guard a zero count would set batch_size to 0 below.
        if examples < 1:
            tf.logging.fatal("No eval examples available in %s", data_dir)
            return

        if examples >= config.hparams.batch_size:
            num_batches = examples // config.hparams.batch_size
        else:
            # Fewer examples than one batch: shrink the batch to the example
            # count so that a single batch covers them.
            config.hparams.batch_size = examples
            num_batches = 1

        graph = events_rnn_graph.get_build_graph_fn("eval", config, eval_file)
        # run_eval is handed train_dir, presumably to read checkpoints from.
        events_rnn_train.run_eval(graph, train_dir, eval_dir, num_batches)
Esempio n. 5
0
def main(unused_argv):
    """Train or evaluate a drums RNN configured from command-line flags.

    Requires --run_dir and --sequence_example_file. Training output goes to
    ``<run_dir>/train``; with --eval, evaluation output goes to
    ``<run_dir>/eval``.

    Args:
        unused_argv: Leftover command-line arguments; ignored.
    """
    tf.logging.set_verbosity(FLAGS.log)

    # tf.logging.fatal only logs, so each failure path returns explicitly.
    if not FLAGS.run_dir:
        tf.logging.fatal('--run_dir required')
        return
    if not FLAGS.sequence_example_file:
        tf.logging.fatal('--sequence_example_file required')
        return

    sequence_example_file_paths = tf.gfile.Glob(
        os.path.expanduser(FLAGS.sequence_example_file))
    # An unmatched pattern yields []; stop instead of building an input
    # pipeline over no files.
    if not sequence_example_file_paths:
        tf.logging.fatal('No files match --sequence_example_file: %s',
                         FLAGS.sequence_example_file)
        return
    run_dir = os.path.expanduser(FLAGS.run_dir)

    config = drums_rnn_config_flags.config_from_flags()

    mode = 'eval' if FLAGS.eval else 'train'
    build_graph_fn = events_rnn_graph.get_build_graph_fn(
        mode, config, sequence_example_file_paths)

    train_dir = os.path.join(run_dir, 'train')
    if not os.path.exists(train_dir):
        tf.gfile.MakeDirs(train_dir)
    tf.logging.info('Train dir: %s', train_dir)

    if FLAGS.eval:
        eval_dir = os.path.join(run_dir, 'eval')
        if not os.path.exists(eval_dir):
            tf.gfile.MakeDirs(eval_dir)
        tf.logging.info('Eval dir: %s', eval_dir)
        num_examples = (
            FLAGS.num_eval_examples
            or magenta.common.count_records(sequence_example_file_paths))
        num_batches = num_examples // config.hparams.batch_size
        # Integer division gives 0 when there are fewer examples than one
        # batch, which would request zero eval batches.
        if num_batches < 1:
            tf.logging.fatal(
                'Only %d eval examples, fewer than batch_size %d',
                num_examples, config.hparams.batch_size)
            return
        events_rnn_train.run_eval(build_graph_fn, train_dir, eval_dir,
                                  num_batches)

    else:
        events_rnn_train.run_training(
            build_graph_fn,
            train_dir,
            FLAGS.num_training_steps,
            FLAGS.summary_frequency,
            checkpoints_to_keep=FLAGS.num_checkpoints)
Esempio n. 6
0
def main(unused_argv):
  """Train or evaluate a melody RNN configured from command-line flags.

  Requires --run_dir and --sequence_example_file. Training output goes to
  ``<run_dir>/train``; with --eval, evaluation output goes to
  ``<run_dir>/eval``.

  Args:
    unused_argv: Leftover command-line arguments; ignored.
  """
  tf.logging.set_verbosity(FLAGS.log)

  # tf.logging.fatal only logs, so each failure path returns explicitly.
  if not FLAGS.run_dir:
    tf.logging.fatal('--run_dir required')
    return
  if not FLAGS.sequence_example_file:
    tf.logging.fatal('--sequence_example_file required')
    return

  sequence_example_file_paths = tf.gfile.Glob(
      os.path.expanduser(FLAGS.sequence_example_file))
  # An unmatched pattern yields []; stop instead of building an input
  # pipeline over no files.
  if not sequence_example_file_paths:
    tf.logging.fatal('No files match --sequence_example_file: %s',
                     FLAGS.sequence_example_file)
    return
  run_dir = os.path.expanduser(FLAGS.run_dir)

  config = melody_rnn_config_flags.config_from_flags()

  mode = 'eval' if FLAGS.eval else 'train'
  build_graph_fn = events_rnn_graph.get_build_graph_fn(
      mode, config, sequence_example_file_paths)

  train_dir = os.path.join(run_dir, 'train')
  if not os.path.exists(train_dir):
    tf.gfile.MakeDirs(train_dir)
  tf.logging.info('Train dir: %s', train_dir)

  if FLAGS.eval:
    eval_dir = os.path.join(run_dir, 'eval')
    if not os.path.exists(eval_dir):
      tf.gfile.MakeDirs(eval_dir)
    tf.logging.info('Eval dir: %s', eval_dir)
    num_examples = (
        FLAGS.num_eval_examples
        or magenta.common.count_records(sequence_example_file_paths))
    num_batches = num_examples // config.hparams.batch_size
    # Integer division gives 0 when there are fewer examples than one
    # batch, which would request zero eval batches.
    if num_batches < 1:
      tf.logging.fatal('Only %d eval examples, fewer than batch_size %d',
                       num_examples, config.hparams.batch_size)
      return
    events_rnn_train.run_eval(build_graph_fn, train_dir, eval_dir, num_batches)

  else:
    events_rnn_train.run_training(build_graph_fn, train_dir,
                                  FLAGS.num_training_steps,
                                  FLAGS.summary_frequency,
                                  checkpoints_to_keep=FLAGS.num_checkpoints)
 def testBuildEvalGraph(self):
     """Building the 'eval' graph from the test sequence file must not raise."""
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn(
             'eval', self.config,
             sequence_example_file_paths=[self._sequence_file.name])
         build_fn()
Esempio n. 8
0
 def _build_graph_for_generation(self):
     """Create and immediately invoke the 'generate'-mode builder for self._config.

     No tf.Graph context is entered here; the caller controls which graph
     is active.
     """
     build_fn = events_rnn_graph.get_build_graph_fn('generate', self._config)
     build_fn()
 def testBuildCudnnGenerateGraphWithResidualConnections(self):
     """The 'generate' graph builds with use_cudnn and residual_connections on."""
     hparams = self.config.hparams
     hparams.use_cudnn = True
     hparams.residual_connections = True
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn('generate', self.config)
         build_fn()
 def testBuildCudnnGenerateGraph(self):
     """Building the 'generate' graph with hparams.use_cudnn set must not raise."""
     hparams = self.config.hparams
     hparams.use_cudnn = True
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn('generate', self.config)
         build_fn()
 def testBuildEvalGraph(self):
     """Building the 'eval' graph from the test sequence file must not raise."""
     file_paths = [self._sequence_file.name]
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn(
             'eval', self.config, sequence_example_file_paths=file_paths)
         build_fn()
Esempio n. 12
0
# Train an 'attention_rnn' melody model on <work_dir>/data/training_melodies.tfrecord.
# NOTE(review): `work_dir` must be defined earlier in this script/notebook.
run_dir = 'logdir'
print(run_dir)

tf.logging.set_verbosity('INFO')

sequence_example_file = os.path.join('data', 'training_melodies.tfrecord')
sequence_example_file_paths = tf.io.gfile.glob(
    os.path.join(work_dir, sequence_example_file))
# glob returns [] for a missing file; fail early with a clear message.
if not sequence_example_file_paths:
    raise ValueError('No training data found at %s'
                     % os.path.join(work_dir, sequence_example_file))

config = melody_rnn_model.default_configs['attention_rnn']
config.hparams.batch_size = 64
config.hparams.rnn_layer_sizes = [64, 64]

mode = 'train'
build_graph_fn = events_rnn_graph.get_build_graph_fn(
    mode, config, sequence_example_file_paths)

train_dir = os.path.join(run_dir, 'train')
eval_dir = os.path.join(run_dir, 'eval')

# print() does not apply %-formatting to extra arguments; format explicitly.
print('Train directory: %s' % train_dir)
print('Evaluate directory: %s' % eval_dir)

print(config.hparams.batch_size)
print(magenta.common.count_records(sequence_example_file_paths))

# Pass checkpoints_to_keep by keyword, as the other training entry points do.
# Positionally, the 5th argument of run_training is save_checkpoint_secs, so
# the previous bare `3` would have meant "checkpoint every 3 seconds".
events_rnn_train.run_training(
    build_graph_fn, train_dir, 20000, 10, checkpoints_to_keep=3)
Esempio n. 13
0
 def testBuildCudnnGenerateGraphWithResidualConnections(self):
     """Building the 'generate' graph with cuDNN and residual flags set must not raise."""
     hparams = self.config.hparams
     hparams.use_cudnn = True
     hparams.residual_connections = True
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn('generate', self.config)
         build_fn()
Esempio n. 14
0
 def testBuildCudnnGenerateGraph(self):
     """The 'generate' graph builds when hparams.use_cudnn is True."""
     hparams = self.config.hparams
     hparams.use_cudnn = True
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn('generate', self.config)
         build_fn()
 def testBuildGenerateGraph(self):
     """Building the 'generate' graph with the default test config must not raise."""
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn('generate', self.config)
         build_fn()
 def testBuildGenerateGraph(self):
     """The 'generate' graph builds from the unmodified test config."""
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn('generate', self.config)
         build_fn()
 def testBuildGraphWithAttention(self):
     """The 'train' graph builds with an attention length of 10."""
     hparams = self.config.hparams
     hparams.attn_length = 10
     with tf.Graph().as_default():
         build_fn = events_rnn_graph.get_build_graph_fn(
             'train', self.config,
             sequence_example_file_paths=[self._sequence_file.name])
         build_fn()
Esempio n. 18
0
 def _build_graph_for_generation(self):
     """Build the 'generate'-mode graph for self._config.

     The builder is created and called immediately; no tf.Graph context is
     entered here, so the caller decides which graph is active.
     """
     build_fn = events_rnn_graph.get_build_graph_fn('generate', self._config)
     build_fn()