예제 #1
0
 def run_training_one_step(
     self,
     bound,
     dataset_type,
     data_dimension,
     dataset_filename,
     dir_prefix,
     resampling_type,
     model,
     batch_size=2,
     num_samples=3,
     create_dataset_and_model_fn=(runners.create_dataset_and_model)):
     """Build a test config and hand it to `runners.run_train`.

     The config starts from `self.default_config()` and is then overridden
     with the arguments below plus a few fixed values (`split="train"`,
     `max_steps=1`, `latent_size=4`, `relaxed_resampling_temperature=0.5`).

     Args:
       bound: Stored as `config.bound`; also part of the log directory name.
       dataset_type: Stored as `config.dataset_type`; also part of the log
         directory name.
       data_dimension: Stored as `config.data_dimension`.
       dataset_filename: File name looked up inside the `test_data`
         directory that sits next to this source file.
       dir_prefix: Leading component of the log directory name.
       resampling_type: Stored as `config.resampling_type`.
       model: Stored as `config.model`; also part of the log directory name.
       batch_size: Stored as `config.batch_size`.
       num_samples: Stored as `config.num_samples`.
       create_dataset_and_model_fn: Forwarded unchanged to
         `runners.run_train` under the same keyword.

     Returns:
       The config object that was passed to `runners.run_train`.
     """
     config = self.default_config()

     # Test fixtures are resolved relative to this file, not the working
     # directory.
     this_dir = os.path.dirname(os.path.realpath(__file__))

     # Ordered (attribute, value) pairs; applied in the order listed.
     overrides = (
         ("model", model),
         ("resampling_type", resampling_type),
         ("relaxed_resampling_temperature", 0.5),
         ("bound", bound),
         ("split", "train"),
         ("dataset_type", dataset_type),
         ("dataset_path",
          os.path.join(this_dir, "test_data", dataset_filename)),
         ("max_steps", 1),
         ("batch_size", batch_size),
         ("num_samples", num_samples),
         ("latent_size", 4),
         ("data_dimension", data_dimension),
         # Each parameter combination gets its own directory under the
         # TensorFlow test temp dir.
         ("logdir",
          os.path.join(
              tf.test.get_temp_dir(),
              "%s-%s-%s-%s" % (dir_prefix, bound, dataset_type, model))),
     )
     for attr_name, attr_value in overrides:
         setattr(config, attr_name, attr_value)

     runners.run_train(
         config, create_dataset_and_model_fn=create_dataset_and_model_fn)
     return config
def main(unused_argv):
    """Dispatch to the runner selected by `FLAGS.model` and `FLAGS.mode`.

    Sets TensorFlow logging verbosity to INFO first. For the "vrnn" and
    "srnn" models, a missing `FLAGS.data_dimension` is filled in from the
    per-dataset default ("pianoroll" or "speech") before dispatching to
    `runners`. The "ghmm" model dispatches to `ghmm_runners`, which only
    handles the "train" and "eval" modes here. Any other model or mode is
    a silent no-op.

    Args:
        unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.model not in ("vrnn", "srnn"):
        # Only the ghmm model is handled outside the vrnn/srnn family.
        if FLAGS.model != "ghmm":
            return
        if FLAGS.mode == "train":
            ghmm_runners.run_train(FLAGS)
        elif FLAGS.mode == "eval":
            ghmm_runners.run_eval(FLAGS)
        return

    # vrnn / srnn: pick a dataset-specific default when no dimension was set.
    if FLAGS.data_dimension is None:
        if FLAGS.dataset_type == "pianoroll":
            FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
        elif FLAGS.dataset_type == "speech":
            FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION

    if FLAGS.mode == "train":
        runners.run_train(FLAGS)
        return
    if FLAGS.mode == "eval":
        runners.run_eval(FLAGS)
        return
    if FLAGS.mode == "sample":
        runners.run_sample(FLAGS)
예제 #3
0
def main(unused_argv):
  """Run the train/eval/sample routine chosen by the command-line flags.

  Logging verbosity is set to INFO. When `FLAGS.model` is "vrnn" or "srnn"
  and `FLAGS.data_dimension` is None, the dimension is defaulted from the
  dataset type ("pianoroll" or "speech"), and the mode is served by
  `runners`. When `FLAGS.model` is "ghmm", the mode is served by
  `ghmm_runners` ("train" and "eval" only). Unrecognized models or modes
  do nothing.

  Args:
    unused_argv: Ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Handlers are wrapped in lambdas so a runner attribute is only looked up
  # for the mode that is actually selected.
  if FLAGS.model in ("vrnn", "srnn"):
    if FLAGS.data_dimension is None:
      dataset_kind = FLAGS.dataset_type
      if dataset_kind == "pianoroll":
        FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
      elif dataset_kind == "speech":
        FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
    handlers = {
        "train": lambda: runners.run_train(FLAGS),
        "eval": lambda: runners.run_eval(FLAGS),
        "sample": lambda: runners.run_sample(FLAGS),
    }
  elif FLAGS.model == "ghmm":
    handlers = {
        "train": lambda: ghmm_runners.run_train(FLAGS),
        "eval": lambda: ghmm_runners.run_eval(FLAGS),
    }
  else:
    # Unknown model: nothing to run.
    return

  handler = handlers.get(FLAGS.mode)
  if handler is not None:
    handler()