Example #1
    def test_pad_or_clip_tensor_to_spec_shape(self, input_data,
                                              expected_output):
        varlen_spec = utils.ExtendedTensorSpec(shape=(3,),
                                               dtype=tf.int64,
                                               name='varlen',
                                               varlen_default_value=3.0)
        # Write the parameterized input data to a temporary TFRecord file and
        # read it back as a single batch.
        tmp_dir = self.create_tempdir().full_path
        file_path_padded_to_size_two = os.path.join(tmp_dir,
                                                    'size_two.tfrecord')
        self._write_test_examples(input_data, file_path_padded_to_size_two)
        dataset = tf.data.TFRecordDataset(
            filenames=tf.constant([file_path_padded_to_size_two]))
        dataset = dataset.batch(len(input_data), drop_remainder=True)

        def parse_fn(example):
            return tf.parse_example(example,
                                    {'varlen': tf.VarLenFeature(tf.int64)})

        dataset = dataset.map(parse_fn)
        sparse_tensors = dataset.make_one_shot_iterator().get_next()['varlen']
        default_value = tf.cast(tf.constant(varlen_spec.varlen_default_value),
                                dtype=varlen_spec.dtype)
        # Densify the sparse parse result, then pad or clip it to the spec's
        # (3,) shape using the spec's varlen default value.
        tensor = utils.pad_or_clip_tensor_to_spec_shape(
            tf.sparse.to_dense(sparse_tensors, default_value), varlen_spec)
        with self.session() as sess:
            np_tensor = sess.run(tensor)
            self.assertAllEqual(np_tensor, np.array(expected_output))
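
The `_write_test_examples` helper used above is not part of this excerpt. A minimal sketch of what it plausibly does, given how the test consumes the file (the name and exact serialization are assumptions, not the library's implementation):

def _write_test_examples(self, input_data, file_path):
    # Hypothetical helper: serialize each row of input_data as a tf.Example
    # with a single int64 'varlen' feature, written to a TFRecord file.
    with tf.python_io.TFRecordWriter(file_path) as writer:
        for row in input_data:
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'varlen': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=row))
                }))
            writer.write(example.SerializeToString())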
Example #2
    def test_pad_image_tensor_to_spec_shape(self):
        varlen_spec = utils.ExtendedTensorSpec(shape=(3, 2, 2, 1),
                                               dtype=tf.uint8,
                                               name='varlen',
                                               data_format='png',
                                               varlen_default_value=3.0)
        # A batch containing two 2x2x1 images; the spec asks for three, so a
        # third image filled with the default value (3) should be appended.
        test_data = [[
            [[[1]] * 2] * 2,
            [[[2]] * 2] * 2,
        ]]
        prepadded_tensor = tf.convert_to_tensor(test_data,
                                                dtype=varlen_spec.dtype)
        tensor = utils.pad_or_clip_tensor_to_spec_shape(
            prepadded_tensor, varlen_spec)
        with self.session() as sess:
            np_tensor = sess.run(tensor)
            self.assertAllEqual(
                np_tensor,
                np.array([[
                    [[[1]] * 2] * 2,
                    [[[2]] * 2] * 2,
                    [[[3]] * 2] * 2,
                ]]))
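
Both tests exercise the same contract: the first non-batch dimension is clipped to the spec's leading shape entry and, if too short, padded with `varlen_default_value` cast to the spec dtype. A standalone illustration of that behavior (this sketch assumes TF 1.x-style APIs and is not the library implementation):

import tensorflow.compat.v1 as tf

def pad_or_clip_axis1(tensor, target_len, default_value):
    # Illustrative only: pad or clip axis 1 of a batched tensor so that it
    # has exactly target_len entries, padding with default_value.
    tensor = tensor[:, :target_len]  # clip if longer than target_len
    missing = target_len - tf.shape(tensor)[1]
    pad_shape = tf.concat(
        [[tf.shape(tensor)[0], missing], tf.shape(tensor)[2:]], axis=0)
    padding = tf.fill(pad_shape, tf.cast(default_value, tensor.dtype))
    return tf.concat([tensor, padding], axis=1)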
Example #3
  def parse_tf_example_fn(*input_values):
    """Maps string tensors (serialized TFExamples) to parsed tensors.

    Args:
      *input_values: A (string tensor,) tuple if mapping from a RecordIODataset
        or TFRecordDataset, or a (key, string tensor) tuple if mapping from a
        SSTableDataset, or (Dict[dataset_key, values],) if mapping from multiple
        datasets.

    Returns:
      features: Collection of tensors conforming to feature_tspec.
      labels: Collection of tensors conforming to label_tspec.

    Raises:
      ValueError: If dtype other than uint8 or uint16 is supplied for image
        specs.
    """
    dict_extracted = _get_sstable_proto_dict(*input_values)

    def parse_wrapper(example, spec_dict):
      """Wrap tf.parse_example to support bfloat16 dtypes.

      This allows models which declare bfloat16 as inputs to not require an
      additional preprocessing step to cast all inputs from float32 to bfloat16.
      Consider this to be analogous to JPEG decoding in the data step.

      Args:
        example: TFExample
        spec_dict: Dictionary of feature name -> tf.FixedLenFeature

      Returns:
        Parsed feature map
      """

      def is_bfloat_feature(value):
        return value.dtype == tf.bfloat16

      def maybe_map_bfloat(value):
        """Maps bfloat16 to float32."""
        if is_bfloat_feature(value):
          if isinstance(value, tf.FixedLenFeature):
            return tf.FixedLenFeature(
                value.shape, tf.float32, default_value=value.default_value)
          elif isinstance(value, tf.VarLenFeature):
            # tf.VarLenFeature only takes a dtype; it has no shape or
            # default_value fields.
            return tf.VarLenFeature(tf.float32)
          else:
            return tf.FixedLenSequenceFeature(
                value.shape, tf.float32, default_value=value.default_value)
        return value

      # Change bfloat features to float32 for parsing.
      new_spec_dict = {
          k: maybe_map_bfloat(v) for k, v in six.iteritems(spec_dict)
      }
      for k, v in six.iteritems(new_spec_dict):
        if v.dtype not in [tf.float32, tf.string, tf.int64]:
          raise ValueError('Feature specification with invalid data type for '
                           'tf.Example parsing: "%s": %s' % (k, v.dtype))

      # Separate new_spec_dict into Context and Sequence features. In the event
      # that there are no SequenceFeatures, the context_features dictionary
      # (containing FixedLenFeatures) is passed to tf.parse_example.
      context_features, sequence_features = {}, {}
      for k, v in six.iteritems(new_spec_dict):
        v = maybe_map_bfloat(v)
        if isinstance(v, tf.FixedLenSequenceFeature):
          sequence_features[k] = v
        elif isinstance(v, tf.FixedLenFeature):
          context_features[k] = v
        elif isinstance(v, tf.VarLenFeature):
          context_features[k] = v
        else:
          raise ValueError(
              'Only FixedLenFeature, FixedLenSequenceFeature, and '
              'VarLenFeature are currently supported.')

      # If there are any sequence features, we use parse_sequence_example.
      if sequence_features:
        # Filter out '_length' context features; don't parse them from records.
        for parse_name in sequence_features:
          # Sometimes, the '_length' context feature doesn't exist.
          if parse_name + '_length' in context_features:
            del context_features[parse_name + '_length']
        result, sequence_result, feature_lengths = tf.io.parse_sequence_example(
            example,
            context_features=context_features,
            sequence_features=sequence_features)
        result.update(sequence_result)
        # Augment the parsed tensors with feature length tensors.
        for parse_name, length_tensor in feature_lengths.items():
          result[parse_name + '_length'] = length_tensor
      else:
        result = tf.parse_example(example, context_features)
      to_convert = [
          k for k, v in six.iteritems(spec_dict) if is_bfloat_feature(v)
      ]

      for c in to_convert:
        result[c] = tf.cast(result[c], tf.bfloat16)

      return result

    prepend_keys = lambda d, pre: {pre + k: v for k, v in list(d.items())}
    # Parse each dataset's tensors. Parsed results from parse_wrapper get
    # dataset_key prepended to ensure uniqueness of keys among datasets.
    parsed_tensors = {}
    # {Prepended parsed key : TensorSpecs} for all datasets. Will contain
    # '_length' TensorSpecs that won't actually get parsed. We filter those out
    # before passing to the parse_sequence_example call.
    tensor_spec_dict = {}
    for dataset_key, example_proto in dict_extracted.items():
      # Parsed key to Feature Specs (retained only for this dataset).
      tensor_dict = {}
      sub_feature_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
          feature_tspec, dataset_key)
      feature_dict, feature_tspec_dict = (
          tensorspec_utils.tensorspec_to_feature_dict(
              sub_feature_tspec, decode_images=decode_images))
      tensor_dict.update(feature_dict)
      tensor_spec_dict.update(prepend_keys(feature_tspec_dict, dataset_key))
      if label_tspec is not None:
        sub_label_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
            label_tspec, dataset_key)
        label_dict, label_tspec_dict = (
            tensorspec_utils.tensorspec_to_feature_dict(
                sub_label_tspec, decode_images=decode_images))
        tensor_dict.update(label_dict)
        tensor_spec_dict.update(prepend_keys(label_tspec_dict, dataset_key))
      for key, parsed in parse_wrapper(example_proto, tensor_dict).items():
        parsed_tensors[dataset_key + key] = parsed

    # At this point, all tensors have been parsed into a single flat map.
    # Interpret encoded images.
    def decode_image(key, raw_bytes):
      """Decodes single or batches of JPEG- or PNG-encoded string tensors.

      Args:
        key: String key specified in feature map.
        raw_bytes: String tensor to decode as JPEG or PNG.

      Returns:
        Decoded image tensor with shape specified by tensor spec.
      Raises:
        ValueError: If dtype other than uint8 or uint16 is supplied for image
        specs.
      """
      img_batch_dims = tf.shape(raw_bytes)
      # The spatial + channel dimensions of a single image, assumed to be the
      # last 3 entries of the image feature's tensor spec.
      if len(tensor_spec_dict[key].shape) < 3:
        raise ValueError(
            'Shape of tensor spec for image feature "%s" must be at least '
            '3-dimensional (..., h, w, c), but is %s' %
            (tensor_spec_dict[key].name, tensor_spec_dict[key].shape))
      single_img_dims = tensor_spec_dict[key].shape[-3:]
      num_channels = single_img_dims[2]
      if num_channels not in [1, 3]:
        raise ValueError(
            'Last dimension of shape of tensor spec for image '
            'feature "%s" must be 1 or 3, but the shape is %s' %
            (tensor_spec_dict[key].name, tensor_spec_dict[key].shape))

      # Collapse (possibly multiple) batch dims to a single batch dim for
      # decoding purposes.
      raw_bytes = tf.reshape(raw_bytes, [-1])
      data_type = tensor_spec_dict[key].dtype
      if data_type not in SUPPORTED_PIXEL_ENCODINGS:
        raise ValueError('Decoding an image requires tensorspec.dtype '
                         'to be uint8 or uint16.')

      def _decode_images(image_bytes):
        """Decode single image."""
        def _zero_image():
          return tf.zeros(single_img_dims, dtype=data_type)

        def _tf_decode_image():
          return tf.image.decode_image(
              image_bytes, channels=num_channels, dtype=data_type)

        image = tf.cond(
            tf.equal(image_bytes, ''), _zero_image, _tf_decode_image)
        image.set_shape(single_img_dims)
        return image

      img = tf.map_fn(
          _decode_images, raw_bytes, dtype=data_type, back_prop=False)
      img.set_shape(raw_bytes.shape.concatenate(single_img_dims))

      # Expand the collapsed batch dim back to the original img_batch_dims.
      img = tf.reshape(img, tf.concat([img_batch_dims, single_img_dims], 0))

      return img

    # Convert all sparse tensors to dense tensors.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensor_spec.varlen_default_value is not None:
        if tensorspec_utils.is_encoded_image_spec(tensor_spec):
          default_value = ''
        else:
          default_value = tf.cast(
              tf.constant(tensor_spec.varlen_default_value),
              dtype=tensor_spec.dtype)
        parsed_tensors[key] = tf.sparse.to_dense(
            val, default_value=default_value)

    # Ensure that all images are properly decoded.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensorspec_utils.is_encoded_image_spec(tensor_spec) and decode_images:
        parsed_tensors[key] = decode_image(key, val)
        if tensor_spec.dtype not in SUPPORTED_PIXEL_ENCODINGS:
          raise ValueError('Encoded images with key {} must be '
                           'specified with uint8 or uint16 dtype.'.format(key))

    # Pad all varlen features to the corresponding spec.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensor_spec.varlen_default_value is not None:
        parsed_tensors[key] = tensorspec_utils.pad_or_clip_tensor_to_spec_shape(
            val, tensor_spec)

    # Ensure that we have a consistent ordered mapping despite the underlying
    # spec structure.
    flat_feature_tspec = tensorspec_utils.TensorSpecStruct(
        sorted(tensorspec_utils.flatten_spec_structure(feature_tspec).items()))
    # The flat spec structure allows mapping the same parsed tensor to
    # multiple features or labels. Note, the spec structure ensures that
    # the corresponding tensor specs are identical in such cases.
    features = tensorspec_utils.TensorSpecStruct([
        (key, parsed_tensors[value.dataset_key + value.name])
        for key, value in flat_feature_tspec.items()
    ])

    features = tensorspec_utils.validate_and_pack(
        flat_feature_tspec, features, ignore_batch=True)
    if label_tspec is not None:
      # Ensure that we have a consistent ordered mapping despite the underlying
      # spec structure.
      flat_label_tspec = tensorspec_utils.TensorSpecStruct(
          sorted(tensorspec_utils.flatten_spec_structure(label_tspec).items()))
      labels = tensorspec_utils.TensorSpecStruct([
          (key, parsed_tensors[value.dataset_key + value.name])
          for key, value in flat_label_tspec.items()
      ])
      labels = tensorspec_utils.validate_and_pack(
          flat_label_tspec, labels, ignore_batch=True)
      return features, labels
    return features
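
`parse_tf_example_fn` closes over `feature_tspec`, `label_tspec`, and `decode_images` from an enclosing builder function (not shown here) and is intended to be mapped over batches of serialized `tf.Example` protos. A hedged sketch of that wiring; `make_parsed_dataset` and its `file_pattern` argument are assumptions for illustration:

def make_parsed_dataset(file_pattern, batch_size):
    # Hypothetical driver: batch the serialized records first so that
    # parse_tf_example_fn parses them in bulk, then map the closure.
    filenames = tf.data.Dataset.list_files(file_pattern)
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # Yields (features, labels) tuples when label_tspec is not None.
    return dataset.map(parse_tf_example_fn,
                       num_parallel_calls=tf.data.experimental.AUTOTUNE)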