def create_dataset_input_fn(self, mode):
  """Build the estimator-compatible input_fn for this generator.

  Thin wrapper over the tfdata implementation so every input generator
  exposes the same interface. Validates that the feature/label specs have
  been initialized before constructing the pipeline.

  Args:
    mode: (ModeKeys) Specifies if this is training, evaluation or
      prediction.

  Returns:
    A valid input_fn for the estimator api.
  """
  self._assert_specs_initialized()
  logging.info('Creating InputGenerator %s with file patterns:\n%s',
               self._label, self._file_patterns)
  # File patterns take precedence; fall back to the dataset map when the
  # patterns are empty/None.
  return tfdata.get_input_fn(
      file_patterns=self._file_patterns or self._dataset_map,
      batch_size=self.batch_size,
      feature_spec=self._feature_spec,
      label_spec=self._label_spec,
      mode=mode,
      preprocess_fn=self._preprocess_fn)
def create_dataset(self, mode, params=None):
  """Create the actual dataset for the given mode.

  This is potentially wrapped in create_dataset_input_fn. Unlike
  create_dataset_input_fn, this returns the materialized dataset rather
  than the input_fn callable.

  Args:
    mode: (ModeKeys) Specifies if this is training, evaluation or
      prediction.
    params: Not used for this implementation but expected by callers. An
      optional dict of hyper parameters that will be passed into input_fn
      and model_fn. Keys are names of parameters, values are basic Python
      types. There are reserved keys for TPUEstimator, including
      'batch_size'.

  Returns:
    A tf.data dataset produced by invoking the input_fn.
  """
  # Fix: the sibling input_fn factories assert the specs are initialized
  # before reading self._feature_spec / self._label_spec; this path was
  # missing the same precondition check.
  self._assert_specs_initialized()
  input_fn = tfdata.get_input_fn(
      file_patterns=self._file_patterns or self._dataset_map,
      batch_size=self.batch_size,
      feature_spec=self._feature_spec,
      label_spec=self._label_spec,
      mode=mode,
      preprocess_fn=self._preprocess_fn)
  return input_fn(params)
def create_dataset_input_fn(self, mode):
  """Create the dataset input_fn used for train and eval.

  We simply wrap the existing tfdata implementation such that we have a
  consistent input generator interface.

  Args:
    mode: (ModeKeys) Specifies if this is training, evaluation or
      prediction.

  Returns:
    A valid input_fn for the estimator api.
  """
  self._assert_specs_initialized()
  logging.info('Creating InputGenerator %s with file patterns:\n%s',
               self._label, self._file_patterns)
  input_fn = tfdata.get_input_fn(
      file_patterns=self._file_patterns or self._dataset_map,
      batch_size=self.batch_size,
      feature_spec=self._feature_spec,
      label_spec=self._label_spec,
      mode=mode,
      preprocess_fn=self._preprocess_fn)
  if self._guzzler_server_address:
    # A guzzler server is configured: wrap the base input_fn in a
    # DataGuzzlerDataset, optionally compressing elements before they are
    # sent to the server and decompressing them afterwards.
    def dataguzzler_dataset_fn(params=None):
      # Builds the base dataset with per-element compression applied, using
      # the *output* specs since preprocessing has already run.
      def compressed_dataset(params):
        dataset = input_fn(params=params)
        compress_fn = tfdata.create_compress_fn(
            feature_spec=self._out_feature_spec,
            label_spec=self._out_label_spec,
            quality=self._guzzler_compression_quality)
        dataset = dataset.map(compress_fn, num_parallel_calls=2)
        return dataset
      if self._guzzler_use_compression:
        tf.logging.info('Use compressed dataset.')
        dataset = guzzler_dataset.DataGuzzlerDataset(
            dataset_fn=lambda: compressed_dataset(params=params),
            guzzler_server_address=self._guzzler_server_address,
            guzzler_timeout_ms=self._guzzler_timeout_ms,
            guzzler_graph_key='')
      else:
        tf.logging.info('Use uncompressed dataset.')
        dataset = guzzler_dataset.DataGuzzlerDataset(
            dataset_fn=lambda: input_fn(params=params),
            guzzler_server_address=self._guzzler_server_address,
            guzzler_timeout_ms=self._guzzler_timeout_ms,
            guzzler_graph_key='')
      # NOTE(review): presumably publishes the dataset graph so the guzzler
      # server can serve it from the model dir — confirm against
      # guzzler_dataset docs. Must run after the dataset is constructed.
      dataset.PublishGraphToModelDir(self._guzzler_output_dir)
      if self._guzzler_use_compression:
        # Mirror of the compression branch above: undo the per-element
        # compression on the client side.
        tf.logging.info('Use decompression.')
        decompress_fn = tfdata.create_decompress_fn(
            feature_spec=self._out_feature_spec,
            label_spec=self._out_label_spec)
        dataset = dataset.map(
            decompress_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      else:
        tf.logging.info('Use no decompression.')
      return dataset.prefetch(2)
    return dataguzzler_dataset_fn
  # No guzzler server configured: return the plain tfdata input_fn.
  return input_fn