Example #1
  def __init__(self, input_dataset, features, num_parallel_calls):
    """Constructs a dataset that parses serialized `Example` protos.

    Args:
      input_dataset: A dataset whose elements are vectors (1-D string
        tensors) of serialized `Example` protos.
      features: A dict mapping feature keys to `VarLenFeature`,
        `SparseFeature`, `FixedLenFeature`, or `FixedLenSequenceFeature`
        values describing how to parse each feature.
      num_parallel_calls: Number of parallel parsing invocations passed
        through to the underlying dataset op.

    Raises:
      TypeError: If `input_dataset` elements are not vectors of strings.
    """
    self._input_dataset = input_dataset
    # The parse op operates on batches, so each element must be a 1-D
    # string tensor of serialized protos.
    if not structure.are_compatible(
        input_dataset.element_spec,
        tensor_spec.TensorSpec([None], dtypes.string)):
      raise TypeError("Input dataset should be a dataset of vectors of strings")
    self._num_parallel_calls = num_parallel_calls
    # pylint: disable=protected-access
    self._features = parsing_ops._prepend_none_dimension(features)
    # sparse_keys and dense_keys come back sorted here.
    (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     dense_shapes) = parsing_ops._features_to_raw_params(
         self._features, [
             parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
             parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
         ])
    # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
    (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
     dense_shape_as_shape) = parsing_ops._process_raw_parameters(
         None, dense_defaults, sparse_keys, sparse_types, dense_keys,
         dense_types, dense_shapes)
    # pylint: enable=protected-access
    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_shapes = dense_shapes
    self._dense_types = dense_types
    input_dataset_shape = dataset_ops.get_legacy_output_shapes(
        self._input_dataset)
    # Each output keeps the input (batch) dimension in front: dense
    # features append their declared shape, sparse features append a
    # single unknown dimension.
    dense_output_shapes = [input_dataset_shape.concatenate(shape)
                           for shape in dense_shape_as_shape]
    sparse_output_shapes = [input_dataset_shape.concatenate([None])
                            for _ in range(len(sparse_keys))]

    output_shapes = dict(
        zip(self._dense_keys + self._sparse_keys,
            dense_output_shapes + sparse_output_shapes))
    output_types = dict(
        zip(self._dense_keys + self._sparse_keys,
            self._dense_types + self._sparse_types))
    # Dense features yield `Tensor`s, sparse features yield
    # `SparseTensor`s; combine into a legacy structure for element_spec.
    output_classes = dict(
        zip(self._dense_keys + self._sparse_keys,
            [ops.Tensor for _ in range(len(self._dense_defaults))] +
            [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
            ]))
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    # After the 2019-08-03 forward-compatibility window, use the
    # non-experimental op name; otherwise fall back to the experimental op
    # so older runtimes can still execute the serialized graph.
    if compat.forward_compatible(2019, 8, 3):
      variant_tensor = (
          gen_experimental_dataset_ops.parse_example_dataset(
              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
              self._num_parallel_calls,
              self._dense_defaults,
              self._sparse_keys,
              self._dense_keys,
              self._sparse_types,
              self._dense_shapes,
              **self._flat_structure))
    else:
      variant_tensor = (
          gen_experimental_dataset_ops.experimental_parse_example_dataset(
              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
              self._num_parallel_calls,
              self._dense_defaults,
              self._sparse_keys,
              self._dense_keys,
              self._sparse_types,
              self._dense_shapes,
              **self._flat_structure))
    super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
    # NOTE(review): this `def` is nested inside the preceding __init__ and
    # is therefore a dead local function as written. It looks like a newer
    # revision of the same constructor (adds `deterministic` and
    # RaggedFeature support) pasted at the wrong indentation — the two
    # versions should live in separate revisions of the class, not be
    # concatenated. Confirm intent before restructuring.
    def __init__(self, input_dataset, features, num_parallel_calls,
                 deterministic):
        """Constructs a dataset that parses serialized `Example` protos.

        Newer variant supporting `RaggedFeature` and a `deterministic`
        ordering flag.

        Args:
          input_dataset: A dataset whose elements are vectors (1-D string
            tensors) of serialized `Example` protos.
          features: A dict mapping feature keys to `VarLenFeature`,
            `SparseFeature`, `FixedLenFeature`, `FixedLenSequenceFeature`,
            or `RaggedFeature` values.
          num_parallel_calls: Number of parallel parsing invocations.
          deterministic: Tri-state ordering control; `None` means use the
            op's default, `True`/`False` force deterministic or
            nondeterministic output ordering.

        Raises:
          TypeError: If `input_dataset` elements are not vectors of
            strings.
        """
        self._input_dataset = input_dataset
        if not structure.are_compatible(
                input_dataset.element_spec,
                tensor_spec.TensorSpec([None], dtypes.string)):
            raise TypeError(
                "Input dataset should be a dataset of vectors of strings")
        self._num_parallel_calls = num_parallel_calls
        # Map the tri-state `deterministic` argument onto the string enum
        # expected by the dataset op attr.
        if deterministic is None:
            self._deterministic = "default"
        elif deterministic:
            self._deterministic = "true"
        else:
            self._deterministic = "false"
        # pylint: disable=protected-access
        self._features = parsing_ops._prepend_none_dimension(features)
        # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
        params = parsing_ops._ParseOpParams.from_features(
            self._features, [
                parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
                parsing_ops.FixedLenFeature,
                parsing_ops.FixedLenSequenceFeature, parsing_ops.RaggedFeature
            ])
        # pylint: enable=protected-access
        self._sparse_keys = params.sparse_keys
        self._sparse_types = params.sparse_types
        self._ragged_keys = params.ragged_keys
        self._ragged_value_types = params.ragged_value_types
        self._ragged_split_types = params.ragged_split_types
        self._dense_keys = params.dense_keys
        self._dense_defaults = params.dense_defaults_vec
        self._dense_shapes = params.dense_shapes_as_proto
        self._dense_types = params.dense_types
        input_dataset_shape = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)

        # Build the element spec directly from typed specs (rather than the
        # legacy shapes/types/classes triple used by the older op path).
        # Every output keeps the input (batch) dimension in front.
        self._element_spec = {}

        for (key, value_type) in zip(params.sparse_keys, params.sparse_types):
            self._element_spec[key] = sparse_tensor.SparseTensorSpec(
                input_dataset_shape.concatenate([None]), value_type)

        for (key, value_type, dense_shape) in zip(params.dense_keys,
                                                  params.dense_types,
                                                  params.dense_shapes):
            self._element_spec[key] = tensor_spec.TensorSpec(
                input_dataset_shape.concatenate(dense_shape), value_type)

        # Ragged outputs have one ragged dimension (ragged_rank=1) past the
        # batch dimension, split by `splits_type`.
        for (key, value_type, splits_type) in zip(params.ragged_keys,
                                                  params.ragged_value_types,
                                                  params.ragged_split_types):
            self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
                input_dataset_shape.concatenate([None]), value_type, 1,
                splits_type)

        # The v2 op is required to honor `deterministic`; otherwise it is
        # only safe to emit after the 2020-03-06 forward-compatibility
        # window, so older runtimes fall back to the v1 op.
        if deterministic is not None or compat.forward_compatible(2020, 3, 6):
            variant_tensor = (
                gen_experimental_dataset_ops.parse_example_dataset_v2(
                    self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                    self._num_parallel_calls,
                    self._dense_defaults,
                    self._sparse_keys,
                    self._dense_keys,
                    self._sparse_types,
                    self._dense_shapes,
                    deterministic=self._deterministic,
                    ragged_keys=self._ragged_keys,
                    ragged_value_types=self._ragged_value_types,
                    ragged_split_types=self._ragged_split_types,
                    **self._flat_structure))
        else:
            variant_tensor = (
                gen_experimental_dataset_ops.parse_example_dataset(
                    self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                    self._num_parallel_calls,
                    self._dense_defaults,
                    self._sparse_keys,
                    self._dense_keys,
                    self._sparse_types,
                    self._dense_shapes,
                    ragged_keys=self._ragged_keys,
                    ragged_value_types=self._ragged_value_types,
                    ragged_split_types=self._ragged_split_types,
                    **self._flat_structure))
        super(_ParseExampleDataset, self).__init__(input_dataset,
                                                   variant_tensor)