def eval_ghmm_notraining(self, bound, proposal_type, expected_bound_avg):
    """Evaluate a GHMM without training and check the average bound.

    Builds a config from ``self.default_config()`` with the given bound and
    proposal type, points its log dir at a per-case temp directory, runs
    ``ghmm_runners.run_eval`` on it (no training step), and asserts that the
    ``"mean"`` value read back from ``out.npz`` matches ``expected_bound_avg``
    to three decimal places.

    Args:
      bound: Value assigned to ``config.bound``; also part of the log dir name.
      proposal_type: Value assigned to ``config.proposal_type``; also part of
        the log dir name.
      expected_bound_avg: Expected ``"mean"`` entry of the evaluation output.
    """
    config = self.default_config()
    config.proposal_type = proposal_type
    config.bound = bound
    config.logdir = os.path.join(
        tf.test.get_temp_dir(), "test-ghmm-%s-%s" % (proposal_type, bound))

    ghmm_runners.run_eval(config)

    # ``.item()`` followed by ``["mean"]`` means out.npz holds a single pickled
    # dict (a 0-d object array, presumably written with ``np.save`` by
    # run_eval -- confirm), not a real ``.npz`` archive. NumPy >= 1.16.3 refuses
    # to unpickle by default, so opt in explicitly. The file comes from the
    # run_eval call just above, in a test-local temp dir, so it is trusted.
    data = np.load(os.path.join(config.logdir, "out.npz"),
                   allow_pickle=True).item()
    self.assertAlmostEqual(expected_bound_avg, data["mean"], places=3)
Ejemplo n.º 2
0
 def train_ghmm_for_one_step_and_eval(self, bound, proposal_type, expected_bound_avg):
   """Train a GHMM for a single step, evaluate it, and check the average bound.

   Builds a config from ``self.default_config()`` with the given bound and
   proposal type and ``max_steps = 1``, points its log dir at a per-case temp
   directory, runs ``ghmm_runners.run_train`` then ``ghmm_runners.run_eval``
   on the same config, and asserts that the ``"mean"`` value read back from
   ``out.npz`` matches ``expected_bound_avg`` to two decimal places.

   Args:
     bound: Value assigned to ``config.bound``; also part of the log dir name.
     proposal_type: Value assigned to ``config.proposal_type``; also part of
       the log dir name.
     expected_bound_avg: Expected ``"mean"`` entry of the evaluation output.
   """
   config = self.default_config()
   config.proposal_type = proposal_type
   config.bound = bound
   config.max_steps = 1
   config.logdir = os.path.join(
       tf.test.get_temp_dir(), "test-ghmm-training-%s-%s" % (proposal_type, bound))
   ghmm_runners.run_train(config)
   ghmm_runners.run_eval(config)
   # ``.item()`` followed by ``["mean"]`` means out.npz holds a single pickled
   # dict (a 0-d object array, presumably written with ``np.save`` by
   # run_eval -- confirm), not a real ``.npz`` archive. NumPy >= 1.16.3 refuses
   # to unpickle by default, so opt in explicitly. The file comes from the
   # run_eval call just above, in a test-local temp dir, so it is trusted.
   data = np.load(os.path.join(config.logdir, "out.npz"),
                  allow_pickle=True).item()
   self.assertAlmostEqual(expected_bound_avg, data["mean"], places=2)
Ejemplo n.º 3
0
def main(unused_argv):
    """Program entry point: dispatch to the runner chosen by FLAGS.

    For the "vrnn"/"srnn" models, fills in ``FLAGS.data_dimension`` from the
    per-dataset default when it is unset, then runs the train, eval or sample
    runner from ``runners``. For the "ghmm" model, runs the train or eval
    runner from ``ghmm_runners``. Any other model or mode does nothing.

    Args:
      unused_argv: Positional command-line arguments; ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.model in ("vrnn", "srnn"):
        if FLAGS.data_dimension is None:
            dimension_defaults = {
                "pianoroll": PIANOROLL_DEFAULT_DATA_DIMENSION,
                "speech": SPEECH_DEFAULT_DATA_DIMENSION,
            }
            # Unknown dataset types leave data_dimension as None.
            if FLAGS.dataset_type in dimension_defaults:
                FLAGS.data_dimension = dimension_defaults[FLAGS.dataset_type]
        mode_table = {
            "train": runners.run_train,
            "eval": runners.run_eval,
            "sample": runners.run_sample,
        }
    elif FLAGS.model == "ghmm":
        mode_table = {
            "train": ghmm_runners.run_train,
            "eval": ghmm_runners.run_eval,
        }
    else:
        return

    selected_runner = mode_table.get(FLAGS.mode)
    if selected_runner is not None:
        selected_runner(FLAGS)
Ejemplo n.º 4
0
def main(unused_argv):
  """Entry point that routes to a model-specific runner based on FLAGS.

  The "ghmm" model supports the "train" and "eval" modes via ghmm_runners.
  The "vrnn" and "srnn" models first default FLAGS.data_dimension from the
  dataset type (pianoroll or speech) when it is None, then support "train",
  "eval" and "sample" via runners. Unrecognised models or modes are no-ops.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  model_name = FLAGS.model

  if model_name == "ghmm":
    ghmm_mode = FLAGS.mode
    if ghmm_mode == "train":
      ghmm_runners.run_train(FLAGS)
    elif ghmm_mode == "eval":
      ghmm_runners.run_eval(FLAGS)
    return

  if model_name not in ("vrnn", "srnn"):
    return

  if FLAGS.data_dimension is None:
    kind = FLAGS.dataset_type
    if kind == "pianoroll":
      FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
    elif kind == "speech":
      FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION

  run_fn = {
      "train": runners.run_train,
      "eval": runners.run_eval,
      "sample": runners.run_sample,
  }.get(FLAGS.mode)
  if run_fn is not None:
    run_fn(FLAGS)