Example #1
    def serving_input_receiver_fn():
      """Create the ServingInputReceiver to export a saved model.

      Returns:
        An instance of ServingInputReceiver.
      """
      # We assume only one input, a string that contains the serialized proto.
      receiver_tensors = {
          'input_example_tensor':
              tf.placeholder(
                  dtype=tf.string, shape=[None], name='input_example_tensor')
      }
      feature_spec = self._get_input_features_for_receiver_fn()
      # We have to filter our specs since only required tensors are
      # used at inference time.
      flat_feature_spec = tensorspec_utils.flatten_spec_structure(feature_spec)
      required_feature_spec = (
          tensorspec_utils.filter_required_flat_tensor_spec(flat_feature_spec))

      tensor_dict, tensor_spec_dict = (
          tensorspec_utils.tensorspec_to_feature_dict(required_feature_spec))

      parse_tf_example_fn = tfdata.create_parse_tf_example_fn(
          tensor_dict=tensor_dict,
          tensor_spec_dict=tensor_spec_dict,
          feature_tspec=feature_spec)

      features = parse_tf_example_fn(receiver_tensors['input_example_tensor'])

      if not self._export_raw_receivers and self._preprocess_fn is not None:
        features, _ = self._preprocess_fn(features=features, labels=None)

      return tf.estimator.export.ServingInputReceiver(features,
                                                      receiver_tensors)
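
A receiver fn like this is what gets handed to the TF1.x Estimator export API.
A minimal usage sketch, assuming a trained estimator already exists; the
output path is hypothetical:

import tensorflow as tf

# Hypothetical: `estimator` is a trained tf.estimator.Estimator and
# `serving_input_receiver_fn` is the function defined above.
export_dir = estimator.export_saved_model(
    export_dir_base='/tmp/exported_model',  # hypothetical output path
    serving_input_receiver_fn=serving_input_receiver_fn)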
Example #2
        def serving_input_receiver_fn():
            """Create the ServingInputReceiver to export a saved model.

      Returns:
        An instance of ServingInputReceiver.
      """
            # We assume only one input, a string that contains the serialized proto.
            receiver_tensors = {
                'input_example_tensor':
                tf.placeholder(dtype=tf.string,
                               shape=[None],
                               name='input_example_tensor')
            }
            # We have to filter our specs since only required tensors are
            # used at inference time.
            flat_feature_spec = tensorspec_utils.flatten_spec_structure(
                self._feature_spec)

            # We need to freeze the conditioning and inference shapes.
            for key, value in flat_feature_spec.condition.items():
                ref_shape = value.shape.as_list()
                shape = [self._num_condition_samples_per_task] + ref_shape[1:]
                flat_feature_spec.condition[key] = (
                    tensorspec_utils.ExtendedTensorSpec.from_spec(value,
                                                                  shape=shape))

            for key, value in flat_feature_spec.inference.items():
                ref_shape = value.shape.as_list()
                shape = [self._num_inference_samples_per_task] + ref_shape[1:]
                flat_feature_spec.inference[key] = (
                    tensorspec_utils.ExtendedTensorSpec.from_spec(value,
                                                                  shape=shape))

            required_feature_spec = (
                tensorspec_utils.filter_required_flat_tensor_spec(
                    flat_feature_spec))

            tensor_dict, tensor_spec_dict = (
                tensorspec_utils.tensorspec_to_feature_dict(
                    required_feature_spec))

            parse_tf_example_fn = tfdata.create_parse_tf_example_fn(
                tensor_dict=tensor_dict,
                tensor_spec_dict=tensor_spec_dict,
                feature_tspec=self._feature_spec)

            features = parse_tf_example_fn(
                receiver_tensors['input_example_tensor'])

            if self._preprocess_fn is not None:
                features, _ = self._preprocess_fn(
                    features=features,
                    labels=None,
                    mode=tf.estimator.ModeKeys.PREDICT)

            return tf.estimator.export.ServingInputReceiver(
                features, receiver_tensors)
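
The two loops above freeze the unknown leading (batch) dimension of each spec
to a fixed sample count. A minimal, self-contained sketch of the same idea
using plain tf.TensorSpec (ExtendedTensorSpec.from_spec follows the same
pattern; the sample count 4 is made up):

import tensorflow as tf

# Hypothetical spec with an unknown leading dimension.
spec = tf.TensorSpec(shape=[None, 64, 64, 3], dtype=tf.uint8, name='image')
num_samples = 4  # stands in for self._num_condition_samples_per_task
# Freeze the leading dimension, keeping the remaining shape, dtype, and name.
frozen = tf.TensorSpec(
    shape=[num_samples] + spec.shape.as_list()[1:],
    dtype=spec.dtype,
    name=spec.name)
print(frozen.shape)  # (4, 64, 64, 3)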
Example #3
def serialized_to_parsed(dataset,
                         feature_tspec,
                         label_tspec,
                         num_parallel_calls=2):
    """Auto-generating TFExample parsing code from feature and label tensor specs.

  Supports both single-TFExample parsing (default) and batched parsing (e.g.
  when we are pulling batches from Replay Buffer).

  Args:
    dataset: tf.data.Dataset whose outputs are serialized tf.Examples.
    feature_tspec: Collection of TensorSpec designating how to extract features.
    label_tspec: Collection of TensorSpec designating how to extract labels.
    num_parallel_calls: (Optional.) A tf.int32 scalar tf.Tensor, representing
      the number elements to process in parallel. If not specified, elements
      will be processed sequentially.

  Returns:
    tf.data.Dataset whose output is single (features, labels) tuple.
  """
    tensor_dict = {}
    tensor_spec_dict = {}
    feature_dict, feature_tspec_dict = (
        tensorspec_utils.tensorspec_to_feature_dict(feature_tspec))
    tensor_dict.update(feature_dict)
    tensor_spec_dict.update(feature_tspec_dict)
    label_dict, label_tspec_dict = (
        tensorspec_utils.tensorspec_to_feature_dict(label_tspec))
    tensor_dict.update(label_dict)
    tensor_spec_dict.update(label_tspec_dict)

    parse_tf_example_fn = create_parse_tf_example_fn(
        tensor_dict=tensor_dict,
        tensor_spec_dict=tensor_spec_dict,
        feature_tspec=feature_tspec,
        label_tspec=label_tspec)
    dataset = dataset.map(map_func=parse_tf_example_fn,
                          num_parallel_calls=num_parallel_calls)
    return dataset
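
A hedged usage sketch for serialized_to_parsed: the import paths assume the
tensor2robot package layout, and the specs and TFRecord path are hypothetical:

import tensorflow as tf
from tensor2robot.utils import tensorspec_utils  # assumed package layout
from tensor2robot.utils import tfdata  # assumed package layout

# Hypothetical specs; ExtendedTensorSpec and TensorSpecStruct are the
# containers used throughout these examples.
feature_tspec = tensorspec_utils.TensorSpecStruct([
    ('action', tensorspec_utils.ExtendedTensorSpec(
        shape=(2,), dtype=tf.float32, name='action')),
])
label_tspec = tensorspec_utils.TensorSpecStruct([
    ('reward', tensorspec_utils.ExtendedTensorSpec(
        shape=(1,), dtype=tf.float32, name='reward')),
])

dataset = tf.data.TFRecordDataset(['/tmp/examples.tfrecord'])  # hypothetical
dataset = dataset.batch(32)
dataset = tfdata.serialized_to_parsed(dataset, feature_tspec, label_tspec)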
Example #4
 def test_tensorspec_to_feature_dict(self):
     features, tensor_spec_dict = utils.tensorspec_to_feature_dict(
         mock_nested_subset_spec, decode_images=True)
     self.assertDictEqual(tensor_spec_dict, {
         'images': T1,
         'actions': T2,
     })
     self.assertDictEqual(
         features, {
             'images': tf.FixedLenFeature((), tf.string),
             'actions': tf.FixedLenFeature(T2.shape, T2.dtype),
         })
     features, tensor_spec_dict = utils.tensorspec_to_feature_dict(
         mock_nested_subset_spec, decode_images=False)
     self.assertDictEqual(tensor_spec_dict, {
         'images': T1,
         'actions': T2,
     })
     self.assertDictEqual(
         features, {
             'images': tf.FixedLenFeature(T1.shape, T1.dtype),
             'actions': tf.FixedLenFeature(T2.shape, T2.dtype),
         })
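
The fixtures this test relies on are not shown; hypothetical definitions that
would satisfy its assertions might look like the following (the shapes and the
import path are assumptions):

import tensorflow as tf
from tensor2robot.utils import tensorspec_utils as utils  # assumed layout

# Hypothetical fixtures: T1 is a JPEG-encoded image spec (requested as a raw
# string feature when decode_images=True), T2 is a plain float spec.
T1 = utils.ExtendedTensorSpec(
    shape=(64, 64, 3), dtype=tf.uint8, name='images', data_format='jpeg')
T2 = utils.ExtendedTensorSpec(shape=(2,), dtype=tf.float32, name='actions')
mock_nested_subset_spec = utils.TensorSpecStruct([
    ('images', T1),
    ('actions', T2),
])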
Example #5
  def parse_tf_example_fn(*input_values):
    """Maps string tensors (serialized TFExamples) to parsed tensors.

    Args:
      *input_values: A (string tensor,) tuple if mapping from a RecordIODataset
        or TFRecordDataset, or a (key, string tensor) tuple if mapping from an
        SSTableDataset, or (Dict[dataset_key, values],) if mapping from
        multiple datasets.

    Returns:
      features: Collection of tensors conforming to feature_tspec.
      labels: Collection of tensors conforming to label_tspec.

    Raises:
      ValueError: If a dtype other than uint8 or uint16 is supplied for image
        specs.
    """
    dict_extracted = _get_sstable_proto_dict(*input_values)

    def parse_wrapper(example, spec_dict):
      """Wrap tf.parse_example to support bfloat16 dtypes.

      This allows models that declare bfloat16 inputs to avoid an additional
      preprocessing step that casts all inputs from float32 to bfloat16.
      Consider this to be analogous to JPEG decoding in the data step.

      Args:
        example: TFExample
        spec_dict: Dictionary of feature name -> tf.FixedLenFeature

      Returns:
        Parsed feature map
      """

      def is_bfloat_feature(value):
        return value.dtype == tf.bfloat16

      def maybe_map_bfloat(value):
        """Maps bfloat16 to float32."""
        if is_bfloat_feature(value):
          if isinstance(value, tf.FixedLenFeature):
            return tf.FixedLenFeature(
                value.shape, tf.float32, default_value=value.default_value)
          elif isinstance(value, tf.VarLenFeature):
            # tf.VarLenFeature only takes a dtype; it has no shape or
            # default_value argument.
            return tf.VarLenFeature(tf.float32)
          else:
            return tf.FixedLenSequenceFeature(
                value.shape, tf.float32, default_value=value.default_value)
        return value

      # Change bfloat features to float32 for parsing.
      new_spec_dict = {
          k: maybe_map_bfloat(v) for k, v in six.iteritems(spec_dict)
      }
      for k, v in six.iteritems(new_spec_dict):
        if v.dtype not in [tf.float32, tf.string, tf.int64]:
          raise ValueError('Feature specification with invalid data type for '
                           'tf.Example parsing: "%s": %s' % (k, v.dtype))

      # Separate new_spec_dict into Context and Sequence features. In the event
      # that there are no SequenceFeatures, the context_features dictionary
      # (containing FixedLenFeatures) is passed to tf.parse_example.
      context_features, sequence_features = {}, {}
      for k, v in six.iteritems(new_spec_dict):
        v = maybe_map_bfloat(v)
        if isinstance(v, tf.FixedLenSequenceFeature):
          sequence_features[k] = v
        elif isinstance(v, tf.FixedLenFeature):
          context_features[k] = v
        elif isinstance(v, tf.VarLenFeature):
          context_features[k] = v
        else:
          raise ValueError(
              'Only FixedLenFeature, VarLenFeature, and FixedLenSequenceFeature '
              'are currently supported.')

      # If there are any sequence features, we use parse_sequence_example.
      if sequence_features:
        # Filter out '_length' context features; don't parse them from records.
        for parse_name in sequence_features:
          # Sometimes, the '_length' context feature doesn't exist.
          if parse_name + '_length' in context_features:
            del context_features[parse_name + '_length']
        result, sequence_result, feature_lengths = tf.io.parse_sequence_example(
            example,
            context_features=context_features,
            sequence_features=sequence_features)
        result.update(sequence_result)
        # Augment the parsed tensors with feature length tensors.
        for parse_name, length_tensor in feature_lengths.items():
          result[parse_name + '_length'] = length_tensor
      else:
        result = tf.parse_example(example, context_features)
      to_convert = [
          k for k, v in six.iteritems(spec_dict) if is_bfloat_feature(v)
      ]

      for c in to_convert:
        result[c] = tf.cast(result[c], tf.bfloat16)

      return result

    prepend_keys = lambda d, pre: {pre + k: v for k, v in list(d.items())}
    # Parse each dataset's tensors. Parsed results from parse_wrapper get
    # dataset_key prepended to ensure uniqueness of keys among datasets.
    parsed_tensors = {}
    # {Prepended parsed key : TensorSpecs} for all datasets. Will contain
    # '_length' TensorSpecs that won't actually get parsed. We filter those out
    # before passing to the parse_sequence_example call.
    tensor_spec_dict = {}
    for dataset_key, example_proto in dict_extracted.items():
      # Parsed key to Feature Specs (retained only for this dataset).
      tensor_dict = {}
      sub_feature_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
          feature_tspec, dataset_key)
      feature_dict, feature_tspec_dict = (
          tensorspec_utils.tensorspec_to_feature_dict(
              sub_feature_tspec, decode_images=decode_images))
      tensor_dict.update(feature_dict)
      tensor_spec_dict.update(prepend_keys(feature_tspec_dict, dataset_key))
      if label_tspec is not None:
        sub_label_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
            label_tspec, dataset_key)
        label_dict, label_tspec_dict = (
            tensorspec_utils.tensorspec_to_feature_dict(
                sub_label_tspec, decode_images=decode_images))
        tensor_dict.update(label_dict)
        tensor_spec_dict.update(prepend_keys(label_tspec_dict, dataset_key))
      for key, parsed in parse_wrapper(example_proto, tensor_dict).items():
        parsed_tensors[dataset_key + key] = parsed

    # At this point, all tensors have been parsed into a single flat map.
    # Interpret encoded images.
    def decode_image(key, raw_bytes):
      """Decodes single or batches of JPEG- or PNG-encoded string tensors.

      Args:
        key: String key specified in feature map.
        raw_bytes: String tensor to decode as JPEG or PNG.

      Returns:
        Decoded image tensor with shape specified by tensor spec.

      Raises:
        ValueError: If a dtype other than uint8 or uint16 is supplied for
          image specs.
      """
      img_batch_dims = tf.shape(raw_bytes)
      # The spatial + channel dimensions of a single image, assumed to be the
      # last 3 entries of the image feature's tensor spec.
      if len(tensor_spec_dict[key].shape) < 3:
        raise ValueError(
            'Shape of tensor spec for image feature "%s" must be at least '
            '3-dimensional (..., h, w, c), but is %s' %
            (tensor_spec_dict[key].name, tensor_spec_dict[key].shape))
      single_img_dims = tensor_spec_dict[key].shape[-3:]
      num_channels = single_img_dims[2]
      if num_channels not in [1, 3]:
        raise ValueError(
            'Last dimension of shape of tensor spec for image '
            'feature "%s" must be 1 or 3, but the shape is %s' %
            (tensor_spec_dict[key].name, tensor_spec_dict[key].shape))

      # Collapse (possibly multiple) batch dims to a single batch dim for
      # decoding purposes.
      raw_bytes = tf.reshape(raw_bytes, [-1])
      data_type = tensor_spec_dict[key].dtype
      if data_type not in SUPPORTED_PIXEL_ENCODINGS:
        raise ValueError('Decoding an image requires tensorspec.data_type '
                         'to be uint8 or uint16.')

      def _decode_images(image_bytes):
        """Decode single image."""
        def _zero_image():
          return tf.zeros(single_img_dims, dtype=data_type)

        def _tf_decode_image():
          return tf.image.decode_image(
              image_bytes, channels=num_channels, dtype=data_type)

        image = tf.cond(
            tf.equal(image_bytes, ''), _zero_image, _tf_decode_image)
        image.set_shape(single_img_dims)
        return image

      img = tf.map_fn(
          _decode_images, raw_bytes, dtype=data_type, back_prop=False)
      img.set_shape(raw_bytes.shape.concatenate(single_img_dims))

      # Expand the collapsed batch dim back to the original img_batch_dims.
      img = tf.reshape(img, tf.concat([img_batch_dims, single_img_dims], 0))

      return img

    # Convert all sparse tensors to dense tensors.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensor_spec.varlen_default_value is not None:
        if tensorspec_utils.is_encoded_image_spec(tensor_spec):
          default_value = ''
        else:
          default_value = tf.cast(
              tf.constant(tensor_spec.varlen_default_value),
              dtype=tensor_spec.dtype)
        parsed_tensors[key] = tf.sparse.to_dense(
            val, default_value=default_value)

    # Ensure that all images are properly decoded.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensorspec_utils.is_encoded_image_spec(tensor_spec) and decode_images:
        parsed_tensors[key] = decode_image(key, val)
        if tensor_spec.dtype not in SUPPORTED_PIXEL_ENCODINGS:
          raise ValueError('Encoded images with key {} must be '
                           'specified with uint8 or uint16 dtype.'.format(key))

    # Pad all varlen features to the corresponding spec.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensor_spec.varlen_default_value is not None:
        parsed_tensors[key] = tensorspec_utils.pad_or_clip_tensor_to_spec_shape(
            val, tensor_spec)

    # Ensure that we have a consistent ordered mapping despite the underlying
    # spec structure.
    flat_feature_tspec = tensorspec_utils.TensorSpecStruct(
        sorted(tensorspec_utils.flatten_spec_structure(feature_tspec).items()))
    # The flat spec structure lets us map the same parsed tensor to multiple
    # features or labels. Note that the spec structure ensures the
    # corresponding tensor specs are identical in such cases.
    features = tensorspec_utils.TensorSpecStruct([
        (key, parsed_tensors[value.dataset_key + value.name])
        for key, value in flat_feature_tspec.items()
    ])

    features = tensorspec_utils.validate_and_pack(
        flat_feature_tspec, features, ignore_batch=True)
    if label_tspec is not None:
      # Ensure that we have a consistent ordered mapping despite the underlying
      # spec structure.
      flat_label_tspec = tensorspec_utils.TensorSpecStruct(
          sorted(tensorspec_utils.flatten_spec_structure(label_tspec).items()))
      labels = tensorspec_utils.TensorSpecStruct([
          (key, parsed_tensors[value.dataset_key + value.name])
          for key, value in flat_label_tspec.items()
      ])
      labels = tensorspec_utils.validate_and_pack(
          flat_label_tspec, labels, ignore_batch=True)
      return features, labels
    return features
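
The bfloat16 handling in parse_wrapper boils down to a parse-then-cast round
trip, since tf.Example cannot store bfloat16 directly. A minimal,
self-contained sketch of that round trip:

import tensorflow as tf

# A serialized tf.Example holding two float values under key 'x'.
example = tf.train.Example(features=tf.train.Features(feature={
    'x': tf.train.Feature(float_list=tf.train.FloatList(value=[1.5, 2.5])),
})).SerializeToString()

# Parse as float32, then cast back to the dtype the model declared,
# mirroring what parse_wrapper does for bfloat16 features.
parsed = tf.io.parse_single_example(
    example, {'x': tf.io.FixedLenFeature([2], tf.float32)})
x_bf16 = tf.cast(parsed['x'], tf.bfloat16)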
Example #6
    def parse_tf_example_fn(*input_values):
        """Maps string tensors (serialized TFExamples) to parsed tensors.

    Args:
      *input_values: A (string tensor,) tuple if mapping from a
        RecordIODataset or TFRecordDataset, or a (key, string tensor) tuple if
        mapping from a SSTableDataset, or (Dict[dataset_key, values],) if
        mapping from multiple datasets.

    Returns:
      features: Collection of tensors conforming to feature_tspec.
      labels: Collection of tensors conforming to label_tspec.
    Raises:
        ValueError: If dtype other than uint8 is supplied for image specs.
    """
        dict_extracted = {}
        if isinstance(input_values[0], dict):
            for key, serialized_proto in input_values[0].items():
                if isinstance(serialized_proto, tuple):
                    # Assume an SSTable key, value pair.
                    _, dict_extracted[key] = serialized_proto
                else:
                    dict_extracted[key] = serialized_proto
        else:
            if len(input_values) == 2:
                _, dict_extracted[''] = input_values
            else:
                dict_extracted[''], = input_values

        def parse_wrapper(example, spec_dict):
            """Wrap tf.parse_example to support bfloat16 dtypes.

      This allows models which declare bfloat16 as inputs to not require an
      additional preprocessing step to cast all inputs from float32 to bfloat16.
      Consider this to be analogous to JPEG decoding in the data step.

      Args:
        example: TFExample
        spec_dict: Dictionary of feature name -> tf.FixedLenFeature
      Returns:
        Parsed feature map
      """
            def is_bfloat_feature(value):
                return value.dtype == tf.bfloat16

            def maybe_map_bfloat(value):
                if is_bfloat_feature(value):
                    if isinstance(value, tf.FixedLenFeature):
                        return tf.FixedLenFeature(
                            value.shape,
                            tf.float32,
                            default_value=value.default_value)
                    else:
                        return tf.FixedLenSequenceFeature(
                            value.shape,
                            tf.float32,
                            default_value=value.default_value)
                return value

            # Change bfloat features to float32 for parsing.
            new_spec_dict = {
                k: maybe_map_bfloat(v)
                for k, v in six.iteritems(spec_dict)
            }
            for k, v in six.iteritems(new_spec_dict):
                if v.dtype not in [tf.float32, tf.string, tf.int64]:
                    raise ValueError(
                        'Feature specification with invalid data type for '
                        'tf.Example parsing: "%s": %s' % (k, v.dtype))

            # Separate new_spec_dict into Context and Sequence features. In the event
            # that there are no SequenceFeatures, the context_features dictionary
            # (containing FixedLenFeatures) is passed to tf.parse_example.
            context_features, sequence_features = {}, {}
            for k, v in six.iteritems(new_spec_dict):
                v = maybe_map_bfloat(v)
                if isinstance(v, tf.FixedLenSequenceFeature):
                    sequence_features[k] = v
                elif isinstance(v, tf.FixedLenFeature):
                    context_features[k] = v
                else:
                    raise ValueError(
                        'Only FixedLenFeature and FixedLenSequenceFeature are currently '
                        'supported.')

            # If there are any sequence features, we use parse_sequence_example.
            if sequence_features:
                result, sequence_result, feature_lengths = tf.io.parse_sequence_example(
                    example,
                    context_features=context_features,
                    sequence_features=sequence_features)
                del feature_lengths
                result.update(sequence_result)
            else:
                result = tf.parse_example(example, context_features)
            to_convert = [
                k for k, v in six.iteritems(spec_dict) if is_bfloat_feature(v)
            ]

            for c in to_convert:
                result[c] = tf.cast(result[c], tf.bfloat16)

            return result

        prepend_keys = lambda d, pre: {pre + k: v for k, v in list(d.items())}
        # Parse each dataset's tensors. Parsed results from parse_wrapper get
        # dataset_key prepended to ensure uniqueness of keys among datasets.
        parsed_tensors = {}
        # {Prepended parsed key : TensorSpecs} for all datasets.
        tensor_spec_dict = {}
        for dataset_key, example_proto in dict_extracted.items():
            # Parsed key to Feature Specs (retained only for this dataset).
            tensor_dict = {}
            sub_feature_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
                feature_tspec, dataset_key)
            feature_dict, feature_tspec_dict = (
                tensorspec_utils.tensorspec_to_feature_dict(sub_feature_tspec))
            tensor_dict.update(feature_dict)
            tensor_spec_dict.update(
                prepend_keys(feature_tspec_dict, dataset_key))
            if label_tspec is not None:
                sub_label_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
                    label_tspec, dataset_key)
                label_dict, label_tspec_dict = (
                    tensorspec_utils.tensorspec_to_feature_dict(
                        sub_label_tspec))
                tensor_dict.update(label_dict)
                tensor_spec_dict.update(
                    prepend_keys(label_tspec_dict, dataset_key))
            for key, parsed in parse_wrapper(example_proto,
                                             tensor_dict).items():
                parsed_tensors[dataset_key + key] = parsed

        # At this point, all tensors have been parsed into a single flat map.
        # Interpret encoded images.
        def decode_image(key, raw_bytes):
            """Decodes single or batches of JPEG- or PNG-encoded string tensors.

      Args:
        key: String key specified in feature map.
        raw_bytes: String tensor to decode as JPEG or PNG.

      Returns:
        Decoded image tensor with shape specified by tensor spec.
      Raises:
        ValueError: If dtype other than uint8 is supplied for image specs.
      """
            img_batch_dims = tf.shape(raw_bytes)
            # The spatial + channel dimensions of a single image, assumed to be the
            # last 3 entries of the image feature's tensor spec.
            single_img_dims = tensor_spec_dict[key].shape[-3:]

            # Collapse (possibly multiple) batch dims to a single batch dim for
            # decoding purposes.
            raw_bytes = tf.reshape(raw_bytes, [-1])
            img = tf.map_fn(tf.image.decode_image,
                            raw_bytes,
                            dtype=tf.uint8,
                            back_prop=False)
            img.set_shape(raw_bytes.shape.concatenate(single_img_dims))

            # Expand the collapsed batch dim back to the original img_batch_dims.
            img = tf.reshape(img,
                             tf.concat([img_batch_dims, single_img_dims], 0))

            return img

        for key, val in parsed_tensors.items():
            tensor_spec = tensor_spec_dict[key]
            if tensorspec_utils.is_encoded_image_spec(tensor_spec):
                parsed_tensors[key] = decode_image(key, val)
                if tensor_spec.dtype != tf.uint8:
                    raise ValueError('Encoded images with key {} must be '
                                     'specified with uint8 dtype.'.format(key))

        # Ensure that we have a consistent ordered mapping despite the underlying
        # spec structure.
        flat_feature_tspec = tensorspec_utils.TensorSpecStruct(
            sorted(
                tensorspec_utils.flatten_spec_structure(
                    feature_tspec).items()))
        # The flat spec structure lets us map the same parsed tensor to
        # multiple features or labels. Note that the spec structure ensures
        # the corresponding tensor specs are identical in such cases.
        features = tensorspec_utils.TensorSpecStruct([
            (key, parsed_tensors[value.dataset_key + value.name])
            for key, value in flat_feature_tspec.items()
        ])

        features = tensorspec_utils.validate_and_pack(flat_feature_tspec,
                                                      features,
                                                      ignore_batch=True)
        if label_tspec is not None:
            # Ensure that we have a consistent ordered mapping despite the underlying
            # spec structure.
            flat_label_tspec = tensorspec_utils.TensorSpecStruct(
                sorted(
                    tensorspec_utils.flatten_spec_structure(
                        label_tspec).items()))
            labels = tensorspec_utils.TensorSpecStruct([
                (key, parsed_tensors[value.dataset_key + value.name])
                for key, value in flat_label_tspec.items()
            ])
            labels = tensorspec_utils.validate_and_pack(flat_label_tspec,
                                                        labels,
                                                        ignore_batch=True)
            return features, labels
        return features
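
The decode_image helper in this older variant assumes uint8 JPEG or PNG bytes.
A minimal standalone sketch of the underlying decode step, with made-up pixel
data:

import numpy as np
import tensorflow as tf

# Hypothetical image: encode a random uint8 RGB image to PNG bytes, then
# decode it the way decode_image above treats each element of the batch.
pixels = np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8)
png_bytes = tf.io.encode_png(tf.constant(pixels))
decoded = tf.image.decode_image(png_bytes, channels=3)
decoded.set_shape((8, 8, 3))  # decode_image leaves the static shape unknown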