def create_dataset_input_fn(self, mode):
    """Build the estimator-compatible input_fn for training/evaluation.

    Thin wrapper over the existing tfdata implementation so all input
    generators expose the same interface.

    Args:
      mode: (ModeKeys) Whether this is training, evaluation or prediction.

    Returns:
      An input_fn usable with the estimator API.
    """
    self._assert_specs_initialized()
    logging.info('Creating InputGenerator %s with file patterns:\n%s',
                 self._label, self._file_patterns)
    # Fall back to the dataset map when no explicit file patterns are set.
    data_source = self._file_patterns or self._dataset_map
    return tfdata.get_input_fn(
        file_patterns=data_source,
        batch_size=self.batch_size,
        feature_spec=self._feature_spec,
        label_spec=self._label_spec,
        mode=mode,
        preprocess_fn=self._preprocess_fn)
# --- Example 2 ---
    def create_dataset(self, mode, params=None):
        """Construct and invoke the underlying input_fn.

        This is potentially wrapped in create_dataset_input_fn.

        Args:
          mode: (ModeKeys) Whether this is training, evaluation or prediction.
          params: Unused here but accepted for caller compatibility. An
            optional dict of hyper parameters passed to input_fn/model_fn;
            keys are parameter names, values are basic Python types.
            TPUEstimator reserves some keys, including 'batch_size'.

        Returns:
          A valid input_fn for the estimator api.
        """
        # Prefer explicit file patterns; otherwise use the dataset map.
        data_source = self._file_patterns or self._dataset_map
        make_dataset = tfdata.get_input_fn(
            file_patterns=data_source,
            batch_size=self.batch_size,
            feature_spec=self._feature_spec,
            label_spec=self._label_spec,
            mode=mode,
            preprocess_fn=self._preprocess_fn)
        # Materialize the dataset immediately rather than returning the fn.
        return make_dataset(params)
# --- Example 3 ---
    def create_dataset_input_fn(self, mode):
        """Build the dataset input_fn used for train and eval.

        Wraps the existing tfdata implementation so every input generator
        shares a consistent interface. When a guzzler server address is
        configured, the returned callable routes the dataset through a
        DataGuzzlerDataset, optionally compressing examples before the
        server and decompressing them afterwards.

        Args:
          mode: (ModeKeys) Whether this is training, evaluation or prediction.

        Returns:
          A valid input_fn for the estimator api.
        """
        self._assert_specs_initialized()
        logging.info('Creating InputGenerator %s with file patterns:\n%s',
                     self._label, self._file_patterns)
        input_fn = tfdata.get_input_fn(
            file_patterns=self._file_patterns or self._dataset_map,
            batch_size=self.batch_size,
            feature_spec=self._feature_spec,
            label_spec=self._label_spec,
            mode=mode,
            preprocess_fn=self._preprocess_fn)

        # No guzzler configured: hand back the plain input_fn unchanged.
        if not self._guzzler_server_address:
            return input_fn

        def dataguzzler_dataset_fn(params=None):
            def make_compressed(params):
                # Compress each element so less data crosses the guzzler
                # server boundary.
                raw = input_fn(params=params)
                compress_fn = tfdata.create_compress_fn(
                    feature_spec=self._out_feature_spec,
                    label_spec=self._out_label_spec,
                    quality=self._guzzler_compression_quality)
                return raw.map(compress_fn, num_parallel_calls=2)

            # Both branches build the same guzzler dataset; only the
            # source dataset_fn differs.
            if self._guzzler_use_compression:
                tf.logging.info('Use compressed dataset.')
                source_fn = lambda: make_compressed(params=params)
            else:
                tf.logging.info('Use uncompressed dataset.')
                source_fn = lambda: input_fn(params=params)
            dataset = guzzler_dataset.DataGuzzlerDataset(
                dataset_fn=source_fn,
                guzzler_server_address=self._guzzler_server_address,
                guzzler_timeout_ms=self._guzzler_timeout_ms,
                guzzler_graph_key='')
            dataset.PublishGraphToModelDir(self._guzzler_output_dir)
            if self._guzzler_use_compression:
                tf.logging.info('Use decompression.')
                # Undo the pre-server compression on the consumer side.
                decompress_fn = tfdata.create_decompress_fn(
                    feature_spec=self._out_feature_spec,
                    label_spec=self._out_label_spec)
                dataset = dataset.map(
                    decompress_fn,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
            else:
                tf.logging.info('Use no decompression.')
            return dataset.prefetch(2)

        return dataguzzler_dataset_fn