示例#1
0
def main(unused_argv):
    """Build the events VRAE graph and run either training or evaluation.

    Settings are read from ``FLAGS``: ``log``, ``run_dir``,
    ``sequence_example_file``, ``eval``, ``num_training_steps`` and
    ``summary_frequency``. When ``--run_dir`` or ``--sequence_example_file``
    is unset, a fatal message is logged and the function returns without
    building anything.

    Args:
        unused_argv: Not used by this function.
    """
    tf.logging.set_verbosity(FLAGS.log)

    if not FLAGS.run_dir:
        tf.logging.fatal('--run_dir required')
        return
    if not FLAGS.sequence_example_file:
        tf.logging.fatal('--sequence_example_file required')
        return

    sequence_example_file_paths = tf.gfile.Glob(
        os.path.expanduser(FLAGS.sequence_example_file))
    run_dir = os.path.expanduser(FLAGS.run_dir)

    config = melody_vrae_config_flags.config_from_flags()

    mode = 'eval' if FLAGS.eval else 'train'
    graph = events_vrae_graph.build_graph(mode, config,
                                          sequence_example_file_paths)

    # The train directory is prepared in both modes; run_eval receives it
    # too. Existence is checked through tf.gfile (not os.path) so the check
    # uses the same filesystem layer as Glob and MakeDirs.
    train_dir = os.path.join(run_dir, 'train')
    if not tf.gfile.Exists(train_dir):
        tf.gfile.MakeDirs(train_dir)
    tf.logging.info('Train dir: %s', train_dir)

    if not FLAGS.eval:
        events_vrae_train.run_training(graph, train_dir,
                                       FLAGS.num_training_steps,
                                       FLAGS.summary_frequency)
        return

    eval_dir = os.path.join(run_dir, 'eval')
    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)
    tf.logging.info('Eval dir: %s', eval_dir)
    events_vrae_train.run_eval(graph, train_dir, eval_dir,
                               FLAGS.num_training_steps,
                               FLAGS.summary_frequency)
示例#2
0
 def _build_graph_for_generation(self):
     """Return what events_vrae_graph.build_graph gives for 'generate' mode.

     The graph is built from ``self._config``; no other arguments are
     passed to ``build_graph``.
     """
     mode = 'generate'
     return events_vrae_graph.build_graph(mode, self._config)
示例#3
0
 def testBuildGenerateGraph(self):
     """Building in 'generate' mode from self.config yields a tf.Graph."""
     g = events_vrae_graph.build_graph('generate', self.config)
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only says "False is not true".
     self.assertIsInstance(g, tf.Graph)
示例#4
0
 def testBuildGraphWithAttention(self):
     """Building in 'train' mode with attn_length set yields a tf.Graph."""
     # Set attn_length on the config's hparams before the graph is built.
     self.config.hparams.attn_length = 10
     g = events_vrae_graph.build_graph('train',
                                       self.config,
                                       sequence_example_file_paths=['test'])
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only says "False is not true".
     self.assertIsInstance(g, tf.Graph)
示例#5
0
 def testBuildEvalGraph(self):
     """Building in 'eval' mode from self.config yields a tf.Graph."""
     g = events_vrae_graph.build_graph('eval',
                                       self.config,
                                       sequence_example_file_paths=['test'])
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only says "False is not true".
     self.assertIsInstance(g, tf.Graph)