def run_training_one_step(
        self, bound, dataset_type, data_dimension, dataset_filename,
        dir_prefix, resampling_type, model, batch_size=2, num_samples=3,
        create_dataset_and_model_fn=(runners.create_dataset_and_model)):
    """Run `runners.run_train` for a single step and return the config used.

    Starts from `self.default_config()`, overrides the fields listed below,
    and hands the result to `runners.run_train`.

    Args:
      bound: Stored as `config.bound`; also part of the log directory name.
      dataset_type: Stored as `config.dataset_type`; also part of the log
        directory name.
      data_dimension: Stored as `config.data_dimension`.
      dataset_filename: File name looked up inside the `test_data` directory
        that sits next to this source file.
      dir_prefix: First component of the log directory name.
      resampling_type: Stored as `config.resampling_type`.
      model: Stored as `config.model`; also part of the log directory name.
      batch_size: Stored as `config.batch_size`.
      num_samples: Stored as `config.num_samples`.
      create_dataset_and_model_fn: Forwarded unchanged to
        `runners.run_train` under the same keyword.

    Returns:
      The config object that was passed to `runners.run_train`.
    """
    config = self.default_config()

    # Test fixtures are resolved relative to this file, not the working dir.
    this_dir = os.path.dirname(os.path.realpath(__file__))
    dataset_path = os.path.join(this_dir, "test_data", dataset_filename)

    # One log directory per (prefix, bound, dataset type, model) combination.
    run_name = "%s-%s-%s-%s" % (dir_prefix, bound, dataset_type, model)
    logdir = os.path.join(tf.test.get_temp_dir(), run_name)

    # Applied in order; max_steps is pinned to 1 so only one step is run.
    overrides = (
        ("model", model),
        ("resampling_type", resampling_type),
        ("relaxed_resampling_temperature", 0.5),
        ("bound", bound),
        ("split", "train"),
        ("dataset_type", dataset_type),
        ("dataset_path", dataset_path),
        ("max_steps", 1),
        ("batch_size", batch_size),
        ("num_samples", num_samples),
        ("latent_size", 4),
        ("data_dimension", data_dimension),
        ("logdir", logdir),
    )
    for field_name, field_value in overrides:
        setattr(config, field_name, field_value)

    runners.run_train(
        config, create_dataset_and_model_fn=create_dataset_and_model_fn)
    return config
def main(unused_argv):
    """Dispatch to the runner selected by `FLAGS.model` and `FLAGS.mode`.

    Sets TensorFlow logging verbosity to INFO, then:

    * model "ghmm": mode "train" or "eval" calls the matching function in
      `ghmm_runners`.
    * model "vrnn" or "srnn": if `FLAGS.data_dimension` is None it is filled
      in from the pianoroll or speech default (based on
      `FLAGS.dataset_type`), then mode "train", "eval" or "sample" calls the
      matching function in `runners`.

    Any other model/mode combination returns without running anything.

    Args:
      unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    chosen_model = FLAGS.model

    if chosen_model == "ghmm":
        ghmm_mode = FLAGS.mode
        if ghmm_mode == "train":
            ghmm_runners.run_train(FLAGS)
        elif ghmm_mode == "eval":
            ghmm_runners.run_eval(FLAGS)
        return

    if chosen_model not in ["vrnn", "srnn"]:
        return

    # Only fill in a default when no data dimension was given.
    if FLAGS.data_dimension is None:
        dataset_kind = FLAGS.dataset_type
        if dataset_kind == "pianoroll":
            FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
        elif dataset_kind == "speech":
            FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION

    run_mode = FLAGS.mode
    if run_mode == "train":
        runners.run_train(FLAGS)
    elif run_mode == "eval":
        runners.run_eval(FLAGS)
    elif run_mode == "sample":
        runners.run_sample(FLAGS)