def get_dataset(self, batch_size, epochs_between_evals):
  """Construct the dataset to be used for training and eval.

  For local training, data is provided through Dataset.from_generator. For
  remote training (TPUs) the data is first serialized to files and then sent
  to the TPU through a StreamingFilesDataset.

  Args:
    batch_size: The per-replica batch size of the dataset.
    epochs_between_evals: How many epochs worth of data to yield.
      (Generator mode only.)
  """
  self.increment_request_epoch()
  if self._stream_files:
    if epochs_between_evals > 1:
      raise ValueError("epochs_between_evals > 1 not supported for file "
                       "based dataset.")
    epoch_data_dir = self._result_queue.get(timeout=300)
    if not self._is_training:
      self._result_queue.put(epoch_data_dir)  # Eval data is reused.

    file_pattern = os.path.join(
        epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))

    # TODO(seemuch): remove this contrib import
    # pylint: disable=line-too-long
    from tensorflow.contrib.tpu.python.tpu.datasets import StreamingFilesDataset
    # pylint: enable=line-too-long

    dataset = StreamingFilesDataset(
        files=file_pattern, worker_job=popen_helper.worker_job(),
        num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
        sloppy=not self._deterministic)
    map_fn = functools.partial(
        self.deserialize, batch_size=batch_size,
        is_training=self._is_training)
    dataset = dataset.map(map_fn, num_parallel_calls=16)

  else:
    types = {movielens.USER_COLUMN: rconst.USER_DTYPE,
             movielens.ITEM_COLUMN: rconst.ITEM_DTYPE}
    shapes = {
        movielens.USER_COLUMN: tf.TensorShape([batch_size, 1]),
        movielens.ITEM_COLUMN: tf.TensorShape([batch_size, 1])
    }

    if self._is_training:
      types[rconst.VALID_POINT_MASK] = np.bool_
      shapes[rconst.VALID_POINT_MASK] = tf.TensorShape([batch_size, 1])

      # Training batches also carry labels, so the generator yields
      # (features, labels) tuples rather than a bare feature dict.
      types = (types, np.bool_)
      shapes = (shapes, tf.TensorShape([batch_size, 1]))

    else:
      types[rconst.DUPLICATE_MASK] = np.bool_
      shapes[rconst.DUPLICATE_MASK] = tf.TensorShape([batch_size, 1])

    data_generator = functools.partial(
        self.data_generator, epochs_between_evals=epochs_between_evals)
    dataset = tf.data.Dataset.from_generator(
        generator=data_generator, output_types=types,
        output_shapes=shapes)

  return dataset.prefetch(16)
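
# Usage sketch (an assumption for illustration, not part of the original
# file): get_dataset is typically consumed through an Estimator-style
# input_fn, where the runtime injects the per-replica batch size via
# `params`. The names `make_input_fn` and `producer` are hypothetical;
# `producer` stands for an instance of the class defining get_dataset above.
def make_input_fn(producer, epochs_between_evals=1):
  """Returns an input_fn that defers dataset construction to `producer`."""
  def input_fn(params):
    # `params["batch_size"]` is the per-replica batch size supplied by the
    # Estimator (or TPUEstimator) at call time.
    return producer.get_dataset(
        batch_size=params["batch_size"],
        epochs_between_evals=epochs_between_evals)
  return input_fn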