def create_dataset_and_model(config, split, shuffle, repeat):
    """Creates the dataset and model for a given config.

    Args:
      config: A configuration object with config values accessible as
        properties. Most likely a FLAGS object. This function expects the
        properties batch_size, dataset_path, dataset_type, and latent_size
        to be defined; the "speech" dataset type additionally reads
        data_dimension.
      split: The dataset split to load. Only forwarded to the loader for the
        "pianoroll" dataset type.
      shuffle: If true, shuffle the dataset randomly. Only forwarded for the
        "pianoroll" dataset type; the "speech" loader is always called with
        shuffle=False.
      repeat: If true, repeat the dataset endlessly. Only forwarded for the
        "pianoroll" dataset type; the "speech" loader is always called with
        repeat=False.

    Returns:
      inputs: A batch of input sequences represented as a dense Tensor of
        shape [time, batch_size, data_dimension].
      targets: A batch of target sequences represented as a dense Tensor of
        shape [time, batch_size, data_dimension].
      lengths: An int Tensor of shape [batch_size] representing the lengths
        of each sequence in the batch.
      model: A vrnn.VRNNCell model object.

    Raises:
      ValueError: If config.dataset_type is neither "pianoroll" nor "speech".
    """
    # NOTE(review): a second function with this exact name is defined later
    # in this file; at import time the later definition rebinds the name.
    # Confirm which of the two is meant to be live.
    if config.dataset_type == "pianoroll":
        inputs, targets, lengths, mean = datasets.create_pianoroll_dataset(
            config.dataset_path,
            split,
            config.batch_size,
            shuffle=shuffle,
            repeat=repeat)
        # Convert the mean of the training set to logit space so it can be
        # used to initialize the bias of the generative distribution.
        # -log(1/p - 1) == log(p / (1 - p)); clipping keeps the result finite
        # when a mean entry is exactly 0 or 1.
        generative_bias_init = -tf.log(
            1. / tf.clip_by_value(mean, 0.0001, 0.9999) - 1)
        generative_distribution_class = vrnn.ConditionalBernoulliDistribution
    elif config.dataset_type == "speech":
        # NOTE(review): split, shuffle and repeat are not forwarded to the
        # speech loader (shuffle/repeat are hard-coded to False). Left as-is;
        # confirm this is intentional.
        inputs, targets, lengths = datasets.create_speech_dataset(
            config.dataset_path,
            config.batch_size,
            samples_per_timestep=config.data_dimension,
            prefetch_buffer_size=1,
            shuffle=False,
            repeat=False)
        generative_bias_init = None
        generative_distribution_class = vrnn.ConditionalNormalDistribution
    else:
        # Fail early with a clear message instead of an UnboundLocalError on
        # `inputs` further down.
        raise ValueError(
            "Unsupported dataset_type %r; expected 'pianoroll' or 'speech'."
            % (config.dataset_type,))

    # The third static dimension of `inputs` is the per-timestep data size.
    model = vrnn.create_vrnn(
        inputs.get_shape().as_list()[2],
        config.latent_size,
        generative_distribution_class,
        generative_bias_init=generative_bias_init,
        raw_sigma_bias=0.5)
    return inputs, targets, lengths, model
def create_dataset_and_model(config, split, shuffle, repeat):
    """Creates the AIS dataset and the VRNN model for a given config.

    Args:
      config: A configuration object with config values accessible as
        properties. This function reads dataset_path, batch_size, data_dim,
        lat_bins, lon_bins, sog_bins, cog_bins and latent_size.
      split: The dataset split to load; forwarded to
        datasets.create_AIS_dataset.
      shuffle: Forwarded to datasets.create_AIS_dataset as `shuffle`.
      repeat: Forwarded to datasets.create_AIS_dataset as `repeat`.

    Returns:
      A tuple (inputs, targets, mmsis, lengths, model). Note that mmsis
      comes before lengths, which differs from the order in which
      datasets.create_AIS_dataset returns them.
        inputs: Input batch Tensor as returned by the dataset loader. It must
          have a static size at index 2, which is used as the model's data
          size (presumably [time, batch_size, data_dim] -- TODO confirm).
        targets: Target batch Tensor as returned by the dataset loader.
        mmsis: Per-sequence identifiers as returned by the dataset loader
          (presumably vessel MMSI numbers -- TODO confirm).
        lengths: Sequence lengths as returned by the dataset loader.
        model: The object returned by vrnn.create_vrnn.
    """
    # Honour the caller-supplied `split`; previously this argument was
    # silently ignored in favour of config.split.
    inputs, targets, lengths, mmsis, mean = datasets.create_AIS_dataset(
        config.dataset_path,
        split,
        config.batch_size,
        config.data_dim,
        config.lat_bins,
        config.lon_bins,
        config.sog_bins,
        config.cog_bins,
        shuffle=shuffle,
        repeat=repeat)

    # Convert the mean of the training set to logit space so it can be used
    # to initialize the bias of the generative distribution.
    # -log(1/p - 1) == log(p / (1 - p)); clipping keeps the result finite
    # when a mean entry is exactly 0 or 1.
    clipped_mean = tf.clip_by_value(mean, 0.0001, 0.9999)
    generative_bias_init = -tf.log(1. / clipped_mean - 1)
    generative_distribution_class = vrnn.ConditionalBernoulliDistribution

    model = vrnn.create_vrnn(
        inputs.get_shape().as_list()[2],
        config.latent_size,
        generative_distribution_class,
        generative_bias_init=generative_bias_init,
        raw_sigma_bias=0.5)
    return inputs, targets, mmsis, lengths, model