Exemplo n.º 1
0
 def run_eval(self, model):
     """Train `model` for a single step, then run evaluation on its config.

     Args:
       model: Model name. It is passed through to run_training_one_step and
         appended to "test-eval-" to form that call's fifth positional
         argument (presumably a per-model run/log name -- confirm against
         run_training_one_step).
     """
     train_one_step = self.run_training_one_step
     run_name = "test-eval-" + model
     # NOTE(review): 88 is presumably the pianoroll data dimension -- confirm.
     config = train_one_step("fivo", "pianoroll", 88, "tiny_pianoroll.pkl",
                             run_name, "multinomial", model)
     # Evaluate against the same "train" split used for the training step.
     config.split = "train"
     runners.run_eval(config)
Exemplo n.º 2
0
 def test_eval_with_custom_fn(self):
     """Check that evaluation runs with a caller-supplied dataset/model factory.

     One training step is run with the custom factory. runners.run_eval is
     then called on the resulting config, restricted to the "train" split,
     with the same factory. The test passes if no exception is raised.
     """
     train_one_step = self.run_training_one_step
     custom_fn = self.dummmy_dataset_and_model_fn
     config = train_one_step(
         "fivo", "pianoroll", 1, "tiny_pianoroll.pkl", "test-eval-custom-fn",
         "multinomial", "vrnn", batch_size=1,
         create_dataset_and_model_fn=custom_fn)
     config.split = "train"
     runners.run_eval(config, create_dataset_and_model_fn=custom_fn)
def main(unused_argv):
    """Fill in flag defaults, then dispatch to the runner for --model/--mode.

    For the "vrnn" and "srnn" models, an unset --data_dimension is first
    filled in from the default for --dataset_type ("pianoroll" or "speech").
    Any other dataset type leaves it as None. The call then goes to
    runners.run_train / run_eval / run_sample. The "ghmm" model goes to
    ghmm_runners, which supports only "train" and "eval" here.

    Args:
      unused_argv: Ignored (presumably the argv passed in by the app runner).

    Raises:
      ValueError: If --model is not "vrnn", "srnn" or "ghmm", or if --mode has
        no runner for the selected model (e.g. --model=ghmm --mode=sample).
        These cases used to fall through every branch and return silently.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.model in ("vrnn", "srnn"):
        if FLAGS.data_dimension is None:
            if FLAGS.dataset_type == "pianoroll":
                FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
            elif FLAGS.dataset_type == "speech":
                FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
        if FLAGS.mode == "train":
            runners.run_train(FLAGS)
        elif FLAGS.mode == "eval":
            runners.run_eval(FLAGS)
        elif FLAGS.mode == "sample":
            runners.run_sample(FLAGS)
        else:
            raise ValueError(
                "Unsupported mode %r for model %r; expected one of "
                "'train', 'eval', 'sample'." % (FLAGS.mode, FLAGS.model))
    elif FLAGS.model == "ghmm":
        if FLAGS.mode == "train":
            ghmm_runners.run_train(FLAGS)
        elif FLAGS.mode == "eval":
            ghmm_runners.run_eval(FLAGS)
        else:
            # There is no ghmm sampling runner; fail loudly instead of
            # exiting without doing any work.
            raise ValueError(
                "Unsupported mode %r for model 'ghmm'; expected one of "
                "'train', 'eval'." % (FLAGS.mode,))
    else:
        raise ValueError(
            "Unsupported model %r; expected one of 'vrnn', 'srnn', 'ghmm'."
            % (FLAGS.model,))
Exemplo n.º 4
0
def main(unused_argv):
  """Fill in flag defaults, then dispatch to the runner for --model/--mode.

  For the "vrnn" and "srnn" models, an unset --data_dimension is first filled
  in from the default for --dataset_type ("pianoroll" or "speech"). Any other
  dataset type leaves it as None. The call then goes to runners.run_train /
  run_eval / run_sample. The "ghmm" model goes to ghmm_runners, which
  supports only "train" and "eval" here.

  Args:
    unused_argv: Ignored (presumably the argv passed in by the app runner).

  Raises:
    ValueError: If --model is not "vrnn", "srnn" or "ghmm", or if --mode has
      no runner for the selected model (e.g. --model=ghmm --mode=sample).
      These cases used to fall through every branch and return silently.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.model in ("vrnn", "srnn"):
    if FLAGS.data_dimension is None:
      if FLAGS.dataset_type == "pianoroll":
        FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
      elif FLAGS.dataset_type == "speech":
        FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
    if FLAGS.mode == "train":
      runners.run_train(FLAGS)
    elif FLAGS.mode == "eval":
      runners.run_eval(FLAGS)
    elif FLAGS.mode == "sample":
      runners.run_sample(FLAGS)
    else:
      raise ValueError(
          "Unsupported mode %r for model %r; expected one of "
          "'train', 'eval', 'sample'." % (FLAGS.mode, FLAGS.model))
  elif FLAGS.model == "ghmm":
    if FLAGS.mode == "train":
      ghmm_runners.run_train(FLAGS)
    elif FLAGS.mode == "eval":
      ghmm_runners.run_eval(FLAGS)
    else:
      # There is no ghmm sampling runner; fail loudly instead of exiting
      # without doing any work.
      raise ValueError(
          "Unsupported mode %r for model 'ghmm'; expected one of "
          "'train', 'eval'." % (FLAGS.mode,))
  else:
    raise ValueError(
        "Unsupported model %r; expected one of 'vrnn', 'srnn', 'ghmm'."
        % (FLAGS.model,))