    def _test_multi_record_input_generator(self,
                                           input_generator,
                                           is_dataset=False):
        feature_spec = tensorspec_utils.TensorSpecStruct()
        feature_spec.state = tensorspec_utils.ExtendedTensorSpec(
            shape=(64, 64, 3),
            dtype=tf.uint8,
            name='state/image',
            data_format='jpeg',
            dataset_key='d1')
        feature_spec.action = tensorspec_utils.ExtendedTensorSpec(
            shape=(2,), dtype=tf.float32, name='pose', dataset_key='d1')
        label_spec = tensorspec_utils.TensorSpecStruct()
        label_spec.reward = tensorspec_utils.ExtendedTensorSpec(
            shape=(), dtype=tf.float32, name='reward', dataset_key='d1')
        label_spec.reward_2 = tensorspec_utils.ExtendedTensorSpec(
            shape=(), dtype=tf.float32, name='reward', dataset_key='d2')
        input_generator.set_feature_specifications(feature_spec, feature_spec)
        input_generator.set_label_specifications(label_spec, label_spec)

        np_features, np_labels = input_generator.create_dataset_input_fn(
            mode=tf.estimator.ModeKeys.TRAIN)().make_one_shot_iterator(
            ).get_next()

        np_features = tensorspec_utils.validate_and_pack(feature_spec,
                                                         np_features,
                                                         ignore_batch=True)
        np_labels = tensorspec_utils.validate_and_pack(label_spec,
                                                       np_labels,
                                                       ignore_batch=True)
        self.assertAllEqual([2, 64, 64, 3], np_features.state.shape)
        self.assertAllEqual([2, 2], np_features.action.shape)
        self.assertAllEqual((2,), np_labels.reward.shape)
        self.assertAllEqual((2,), np_labels.reward_2.shape)
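A note on the spec above: reward and reward_2 share the on-disk tensor name
'reward' but are pulled from different datasets ('d1' vs. 'd2'). During
parsing, dataset_key and name are concatenated to form the lookup key (see the
parse function in Example No. 6), which is what keeps the two entries from
colliding. A toy, pure-Python sketch of that disambiguation (illustrative
only, not the T2R implementation):

# The combination dataset_key + name is unique, so two spec entries may
# share name='reward' as long as their dataset_key differs.
parsed_tensors = {'d1reward': 0.5, 'd2reward': 1.0}  # toy parsed values
flat_spec = {'reward': ('d1', 'reward'), 'reward_2': ('d2', 'reward')}
features = {key: parsed_tensors[dataset_key + name]
            for key, (dataset_key, name) in flat_spec.items()}
assert features == {'reward': 0.5, 'reward_2': 1.0}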
Example No. 2
  def preprocess(self, features, labels, mode):
    """The function which preprocesses the features and labels per example.

    Note, this function performs the boilerplate packing, flattening, and
    verification of the features and labels according to our spec. The actual
    preprocessing is performed by _preprocess_fn.

    Args:
      features: The features of a single example.
      labels: (Optional None) The labels of a single example.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.

    Returns:
      features_preprocessed: The preprocessed and flattened features
        verified to fulfill our output specs.
      labels_preprocessed: (Optional None) The preprocessed and flattened labels
        verified to fulfill our output specs.
    """
    # First, we verify that the input features and labels fulfill our spec.
    # We further pack the flattened features and labels to our (hierarchical)
    # specification:
    features = tensorspec_utils.validate_and_pack(
        expected_spec=self.get_in_feature_specification(mode),
        actual_tensors_or_spec=features,
        ignore_batch=True)
    if labels is not None:
      labels = tensorspec_utils.validate_and_pack(
          expected_spec=self.get_in_label_specification(mode),
          actual_tensors_or_spec=labels,
          ignore_batch=True)

    features_preprocessed, labels_preprocessed = self._preprocess_fn(
        features=features, labels=labels, mode=mode)

    features_preprocessed = tensorspec_utils.validate_and_flatten(
        expected_spec=self.get_out_feature_specification(mode),
        actual_tensors_or_spec=features_preprocessed,
        ignore_batch=True)
    if labels_preprocessed:
      labels_preprocessed = tensorspec_utils.validate_and_flatten(
          expected_spec=self.get_out_label_specification(mode),
          actual_tensors_or_spec=labels_preprocessed,
          ignore_batch=True)
    return features_preprocessed, labels_preprocessed
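Since preprocess handles the validate/pack boilerplate, a concrete
preprocessor only has to override _preprocess_fn. A minimal sketch, assuming
a base class exposing the interface above; the class name and the cast are
made up for illustration:

import tensorflow as tf


class NormalizeImagePreprocessor(object):
  """Sketch of a concrete preprocessor; in T2R this would subclass the
  abstract preprocessor that defines the preprocess() method shown above."""

  def _preprocess_fn(self, features, labels, mode):
    """Casts uint8 images to float32 in [0, 1]; labels pass through."""
    del mode  # Unused in this sketch.
    features.state = tf.image.convert_image_dtype(features.state, tf.float32)
    return features, labels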
Example No. 3
  def test_validate_flatten_and_pack(self):
    # An example data pipeline.
    # Some input generator creates input features according to some spec.
    input_features = utils.make_placeholders(mock_nested_spec)
    # Assume a preprocessor has altered these input_features and we want
    # to pass the data on to the next stage, then we simply assure that
    # our output is according to our spec and flatten.
    flat_input_features = utils.validate_and_flatten(
        mock_nested_optional_spec, input_features, ignore_batch=True)
    utils.assert_required(
        mock_nested_optional_spec, input_features, ignore_batch=True)

    # Then e.g. the model_fn receives the flat_input_features and validates
    # that they are according to its requirements and packs them back into the
    # spec structure.
    output_features = utils.validate_and_pack(
        mock_nested_subset_spec, flat_input_features, ignore_batch=True)
    utils.assert_required(
        mock_nested_subset_spec, output_features, ignore_batch=True)
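For intuition, flattening turns a nested spec structure into a single-level
mapping with path-joined keys, and packing inverts the transformation. A toy
analogue with plain dicts (not the T2R implementation, which also validates
dtypes and shapes against the spec):

def flatten(nested, prefix=''):
  """Toy analogue of flatten_spec_structure using '/'-joined keys."""
  flat = {}
  for key, value in nested.items():
    if isinstance(value, dict):
      flat.update(flatten(value, prefix + key + '/'))
    else:
      flat[prefix + key] = value
  return flat


def pack(flat):
  """Toy inverse: rebuilds the nested structure from '/'-joined keys."""
  nested = {}
  for path, value in flat.items():
    *parents, leaf = path.split('/')
    node = nested
    for parent in parents:
      node = node.setdefault(parent, {})
    node[leaf] = value
  return nested


example = {'state': {'image': 1, 'pose': 2}, 'reward': 3}
assert pack(flatten(example)) == example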
Example No. 4
  def model_fn(self,
               features,
               labels,
               mode,
               config=None,
               params=None):
    """Estimator model_fn.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the default
        config (tf.estimator.RunConfig). Allows updating things in your model_fn
        based on configuration such as num_ps_replicas or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL].

    Returns:
      An EstimatorSpec.
    """
    features = tensorspec_utils.validate_and_pack(
        expected_spec=self.get_feature_specification(mode),
        actual_tensors_or_spec=features,
        ignore_batch=True)
    if labels:
      labels = tensorspec_utils.validate_and_pack(
          expected_spec=self.get_label_specification(mode),
          actual_tensors_or_spec=labels,
          ignore_batch=True)
    inference_outputs = self.inference_network_fn(features, labels, mode,
                                                  config, params)
    update_ops = None
    if isinstance(inference_outputs, tuple):
      if len(inference_outputs) != 2:
        raise ValueError('Unknown output of inference_network_fn: '
                         'tuple of length %d' % len(inference_outputs))
      outputs = inference_outputs[0]
      update_ops = inference_outputs[1]
      inference_outputs = outputs

    if mode == tf.estimator.ModeKeys.PREDICT:
      model_fn_results = self.create_export_outputs_fn(features,
                                                       inference_outputs, mode,
                                                       config, params)
      export_outputs = None
      if isinstance(model_fn_results, tuple):
        predictions = model_fn_results[0]
        export_outputs = model_fn_results[1]
      elif isinstance(model_fn_results, dict):
        export_outputs = {}
        if len(model_fn_results) == 1:
          name, output = list(model_fn_results.items())[0]
          export_outputs[name] = tf.estimator.export.RegressionOutput(output)
        export_outputs[tf.saved_model.signature_constants
                       .DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                           tf.estimator.export.PredictOutput(model_fn_results))
        predictions = model_fn_results
      else:
        raise ValueError('The create_export_outputs_fn should return a '
                         'tuple(predictions, export_outputs) or predictions.')

      return tf.estimator.EstimatorSpec(
          mode=mode, predictions=predictions, export_outputs=export_outputs)

    train_fn_result = self.model_train_fn(features, labels, inference_outputs,
                                          mode, config, params)
    if isinstance(train_fn_result, tf.Tensor):
      train_loss = train_fn_result
      train_outputs = {}
    elif isinstance(train_fn_result, tuple):
      train_loss = train_fn_result[0]
      train_outputs = train_fn_result[1]
    else:
      raise ValueError('The model_train_fn should return a '
                       'tuple(loss, train_outputs) or loss.')

    if mode == tf.estimator.ModeKeys.TRAIN:
      # Create the tf.train.Optimizer.
      optimizer = self.create_optimizer()

      train_op = self.create_train_op(train_loss, optimizer, update_ops,
                                      train_outputs)

      self.add_summaries(features, labels, inference_outputs, train_loss,
                         train_outputs, mode, config, params)

      # Now that the optimizer has been created, the checkpoint can be
      # initialized. No new variables are allowed to be added after this
      # point, otherwise we would not initialize these variables.
      # Note, this feature is only available for train to bootstrap a model
      # (partially) from a different model. As soon as this checkpoint is
      # written all other modes will use the local checkpoint within model_dir.
      self.maybe_init_from_checkpoint()
      training_hooks = []

      # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does not,
      # so we have to use training_hooks here and check is_chief.
      if config and config.is_chief:  # pytype: disable=attribute-error
        training_hooks.append(
            gin_utils.GinConfigSaverHook(
                config.model_dir, summarize_config=True))
        if hasattr(self, 'writer_init_ops'):
          training_hooks.append(V2SummaryInitHook(self.writer_init_ops[mode]))

      # `SyncReplicasOptimizer` needs to attach a training hook.
      if self._sync_replicas_optimizer:
        training_hooks.append(
            self._sync_replicas_optimizer.make_session_run_hook(
                config.is_chief))  # pytype: disable=attribute-error

      # Return the value of the property first since it might be changed.
      scaffold_fn = self.scaffold_fn
      scaffold = scaffold_fn()

      # In order to export asynchronously, the saver has to be registered
      # in the graph collection. The scaffold function might already register
      # a saver, which is why it is checked here and a saver is only added
      # if none has been added yet.
      if not tf.get_collection(tf.GraphKeys.SAVERS):
        # TODO(T2R_CONTRIBUTORS): Switch to using gin config for all saver params.
        keep_checkpoint_every_n_hours = None
        max_to_keep = None
        if config is not None:
          keep_checkpoint_every_n_hours = config.keep_checkpoint_every_n_hours
          max_to_keep = config.keep_checkpoint_max
        saver = gin_configurable_saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            max_to_keep=max_to_keep,
        )
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      return tf.estimator.EstimatorSpec(
          mode=mode,
          loss=train_loss,
          train_op=train_op,
          training_hooks=training_hooks,
          scaffold=scaffold)

    if mode == tf.estimator.ModeKeys.EVAL:
      self.add_summaries(features, labels, inference_outputs, train_loss,
                         train_outputs, mode, config, params)

      eval_metrics = self.model_eval_fn(features, labels, inference_outputs,
                                        train_loss, train_outputs, mode, config,
                                        params)
      evaluation_hooks = self.get_eval_hooks(config, params)
      if config and config.is_chief:  # pytype: disable=attribute-error
        eval_name = params.get('eval_name', 'eval')  # pytype: disable=attribute-error
        evaluation_hooks.append(
            gin_utils.GinConfigSaverHook(
                os.path.join(config.model_dir, eval_name),
                summarize_config=True))
        if hasattr(self, 'writer_init_ops'):
          evaluation_hooks.append(V2SummaryInitHook(self.writer_init_ops[mode]))
      return tf.estimator.EstimatorSpec(
          mode=mode,
          loss=train_loss,
          eval_metric_ops=eval_metrics,
          evaluation_hooks=evaluation_hooks)

    raise ValueError('The mode {} is not supported yet.'.format(mode))
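The tuple handling at the top of this model_fn implies a small contract for
inference_network_fn: return either an outputs dict, or an
(outputs, update_ops) tuple. A hedged sketch of a conforming implementation
(the layer sizes and the output name are invented for illustration):

import tensorflow as tf  # TF1-style API, matching the snippets above.


def inference_network_fn(self, features, labels, mode, config=None,
                         params=None):
  """Hypothetical network body satisfying the model_fn contract above."""
  del labels, mode, config, params  # Unused in this sketch.
  net = tf.layers.flatten(tf.cast(features.state, tf.float32))
  net = tf.layers.dense(net, 32, activation=tf.nn.relu)
  outputs = {'predicted_reward': tf.layers.dense(net, 1)}
  # Returning (outputs, update_ops) lets model_fn thread e.g. batch norm
  # update ops into create_train_op; returning just outputs is also valid.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  return outputs, update_ops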
Example No. 5
    def model_fn(self, features, labels, mode, config=None, params=None):
        """Estimator model_fn.

    Note, this function overwrites the model_fn of the wrapped t2r_model since
    it replaces specifications with their TPU-corresponding calls and
    introduces additional casting after the specification has been verified.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the default
        config (tf.estimator.RunConfig). Allows updating things in your model_fn
        based on configuration such as num_ps_replicas or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL].

    Returns:
      A TPUEstimatorSpec.
    """

        features = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_feature_specification(mode),
            actual_tensors_or_spec=features,
            ignore_batch=True)
        if labels:
            labels = tensorspec_utils.validate_and_pack(
                expected_spec=self.get_label_specification(mode),
                actual_tensors_or_spec=labels,
                ignore_batch=True)

        # In order to support both TPU and CPU for inference, tensors
        # with dtype=bfloat16 will be cast to float32.
        # Note, despite the casting, the benefits of bfloat16 are still
        # maintained for TPUs since this operation is a no-op on that
        # platform. See http://shortn/_TTg3ZyATRo for rationale.
        features = tensorspec_utils.cast_bfloat16_to_float32(features)
        if labels is not None:
            labels = tensorspec_utils.cast_bfloat16_to_float32(labels)

        inference_outputs = self._t2r_model.inference_network_fn(
            features, labels, mode, config, params)

        if mode == tf.estimator.ModeKeys.PREDICT:
            model_fn_results = self._t2r_model.create_export_outputs_fn(
                features, inference_outputs, mode, config, params)
            export_outputs = None
            if isinstance(model_fn_results, tuple):
                predictions = model_fn_results[0]
                export_outputs = model_fn_results[1]
            elif isinstance(model_fn_results, dict):
                export_outputs = {}
                if len(model_fn_results) == 1:
                    name, output = list(model_fn_results.items())[0]
                    export_outputs[
                        name] = tf.estimator.export.RegressionOutput(output)
                export_outputs[tf.saved_model.signature_constants.
                               DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                                   tf.estimator.export.PredictOutput(
                                       model_fn_results))
                predictions = model_fn_results
            else:
                raise ValueError(
                    'The create_export_outputs_fn should return a '
                    'tuple(predictions, export_outputs) or predictions.')

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)

        train_fn_result = self._t2r_model.model_train_fn(
            features, labels, inference_outputs, mode, config, params)
        if isinstance(train_fn_result, tf.Tensor):
            train_loss = train_fn_result
            train_outputs = {}
        elif isinstance(train_fn_result, tuple):
            train_loss = train_fn_result[0]
            train_outputs = train_fn_result[1]
        else:
            raise ValueError('The model_train_fn should return a '
                             'tuple(loss, train_outputs) or loss.')

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Create the tf.train.Optimizer.
            optimizer = get_cross_shard_optimizer(
                self._t2r_model.create_optimizer())

            # Required for batch norm usage.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = self._t2r_model.create_train_op(
                    train_loss, optimizer)

            self._t2r_model.add_summaries(features, labels, inference_outputs,
                                          train_loss, train_outputs, mode,
                                          config, params)

            # For TPUs, the init has to happen inside a scaffold function.
            # Since the model already contains an internal implementation,
            # this call is simply wrapped.
            # No new variables are allowed to be added, otherwise
            # we would not initialize these variables.
            # Note, this feature is only available for train to bootstrap a model
            # (partially) from a different model. As soon as this checkpoint is
            # written all other modes will use the local checkpoint within
            # model_dir.

            def create_scaffold_fn():
                """Creates a scaffold instance."""
                self._t2r_model.maybe_init_from_checkpoint()
                # Return the value of the property first since it might be changed.
                scaffold_fn = self._t2r_model.scaffold_fn
                scaffold = scaffold_fn()
                # In order to export asynchronously, the saver has to be
                # registered in the graph collection. The scaffold function
                # might already register a saver, which is why it is checked
                # here and a saver is only added if none has been added yet.
                if not tf.get_collection(tf.GraphKeys.SAVERS):
                    # TODO(T2R_CONTRIBUTORS): Switch to using gin config for all saver params.
                    keep_checkpoint_every_n_hours = None
                    max_to_keep = None
                    if config is not None:
                        keep_checkpoint_every_n_hours = config.keep_checkpoint_every_n_hours
                        max_to_keep = config.keep_checkpoint_max
                    saver = abstract_model.gin_configurable_saver(
                        keep_checkpoint_every_n_hours=
                        keep_checkpoint_every_n_hours,
                        max_to_keep=max_to_keep,
                    )
                    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                return scaffold

            training_hooks = []

            # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does not,
            # so we have to use training_hooks here and check is_chief.
            if config and config.is_chief:  # pytype: disable=attribute-error
                training_hooks.append(
                    gin_utils.GinConfigSaverHook(config.model_dir,
                                                 summarize_config=True))

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=train_loss,
                train_op=train_op,
                training_hooks=training_hooks,
                scaffold_fn=create_scaffold_fn)

        if mode == tf.estimator.ModeKeys.EVAL:
            self._t2r_model.add_summaries(features, labels, inference_outputs,
                                          train_loss, train_outputs, mode,
                                          config, params)
            eval_metrics = self._t2r_model.model_eval_fn(
                features, labels, inference_outputs, train_loss, train_outputs,
                mode, config, params)
            evaluation_hooks = self._t2r_model.get_eval_hooks(config, params)
            if config and config.is_chief:  # pytype: disable=attribute-error
                eval_name = params.get('eval_name', 'eval')  # pytype: disable=attribute-error
                evaluation_hooks.append(
                    gin_utils.GinConfigSaverHook(os.path.join(
                        config.model_dir, eval_name),
                                                 summarize_config=True))

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=train_loss,
                eval_metrics=eval_metrics,
                evaluation_hooks=evaluation_hooks)

        raise ValueError('The mode {} is not supported yet.'.format(mode))
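The bfloat16-to-float32 cast above is what keeps a single exported model
usable on both TPU and CPU: on TPU the cast is a no-op, on CPU it avoids an
unsupported dtype. A conceptual analogue of the helper (the real
tensorspec_utils version also walks nested spec structures; this sketch
assumes a flat dict of tensors):

import tensorflow as tf


def cast_bfloat16_to_float32(tensors):
  """Sketch: cast any bfloat16 entries of a flat tensor dict to float32."""
  return {
      key: tf.cast(value, tf.float32) if value.dtype == tf.bfloat16 else value
      for key, value in tensors.items()
  }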
Example No. 6
  def parse_tf_example_fn(*input_values):
    """Maps string tensors (serialized TFExamples) to parsed tensors.

    Args:
      *input_values: A (string tensor,) tuple if mapping from a RecordIODataset
        or TFRecordDataset, or a (key, string tensor) tuple if mapping from a
        SSTableDataset, or (Dict[dataset_key, values],) if mapping from multiple
        datasets.

    Returns:
      features: Collection of tensors conforming to feature_tspec.
      labels: Collection of tensors conforming to label_tspec.
    Raises:
        ValueError: If dtype other than uint8 or uint16 is supplied for image
        specs.
    """
    dict_extracted = _get_sstable_proto_dict(*input_values)

    def parse_wrapper(example, spec_dict):
      """Wrap tf.parse_example to support bfloat16 dtypes.

      This allows models which declare bfloat16 as inputs to not require an
      additional preprocessing step to cast all inputs from float32 to bfloat16.
      Consider this to be analogous to JPEG decoding in the data step.

      Args:
        example: TFExample
        spec_dict: Dictionary of feature name -> tf.FixedLenFeature

      Returns:
        Parsed feature map
      """

      def is_bfloat_feature(value):
        return value.dtype == tf.bfloat16

      def maybe_map_bfloat(value):
        """Maps bfloat16 to float32."""
        if is_bfloat_feature(value):
          if isinstance(value, tf.FixedLenFeature):
            return tf.FixedLenFeature(
                value.shape, tf.float32, default_value=value.default_value)
          elif isinstance(value, tf.VarLenFeature):
            # VarLenFeature only carries a dtype; it has no shape or
            # default_value, so only the dtype is remapped here.
            return tf.VarLenFeature(tf.float32)
          else:
            return tf.FixedLenSequenceFeature(
                value.shape, tf.float32, default_value=value.default_value)
        return value

      # Change bfloat features to float32 for parsing.
      new_spec_dict = {
          k: maybe_map_bfloat(v) for k, v in six.iteritems(spec_dict)
      }
      for k, v in six.iteritems(new_spec_dict):
        if v.dtype not in [tf.float32, tf.string, tf.int64]:
          raise ValueError('Feature specification with invalid data type for '
                           'tf.Example parsing: "%s": %s' % (k, v.dtype))

      # Separate new_spec_dict into Context and Sequence features. In the event
      # that there are no SequenceFeatures, the context_features dictionary
      # (containing FixedLenFeatures) is passed to tf.parse_examples.
      context_features, sequence_features = {}, {}
      for k, v in six.iteritems(new_spec_dict):
        v = maybe_map_bfloat(v)
        if isinstance(v, tf.FixedLenSequenceFeature):
          sequence_features[k] = v
        elif isinstance(v, tf.FixedLenFeature):
          context_features[k] = v
        elif isinstance(v, tf.VarLenFeature):
          context_features[k] = v
        else:
          raise ValueError(
              'Only FixedLenFeature and FixedLenSequenceFeature are currently '
              'supported.')

      # If there are any sequence features, we use parse_sequence_example.
      if sequence_features:
        # Filter out '_length' context features; don't parse them from records.
        for parse_name in sequence_features:
          # Sometimes, the '_length' context feature doesn't exist.
          if parse_name + '_length' in context_features:
            del context_features[parse_name + '_length']
        result, sequence_result, feature_lengths = tf.io.parse_sequence_example(
            example,
            context_features=context_features,
            sequence_features=sequence_features)
        result.update(sequence_result)
        # Augment the parsed tensors with feature length tensors.
        for parse_name, length_tensor in feature_lengths.items():
          result[parse_name + '_length'] = length_tensor
      else:
        result = tf.parse_example(example, context_features)
      to_convert = [
          k for k, v in six.iteritems(spec_dict) if is_bfloat_feature(v)
      ]

      for c in to_convert:
        result[c] = tf.cast(result[c], tf.bfloat16)

      return result

    prepend_keys = lambda d, pre: {pre + k: v for k, v in list(d.items())}
    # Parse each dataset's tensors. Parsed results from parse_wrapper get
    # dataset_key prepended to ensure uniqueness of keys among datasets.
    parsed_tensors = {}
    # {Prepended parsed key : TensorSpecs} for all datasets. Will contain
    # '_length' TensorSpecs that won't actually get parsed. We filter those out
    # before passing to the parse_sequence_example call.
    tensor_spec_dict = {}
    for dataset_key, example_proto in dict_extracted.items():
      # Parsed key to Feature Specs (retained only for this dataset).
      tensor_dict = {}
      sub_feature_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
          feature_tspec, dataset_key)
      feature_dict, feature_tspec_dict = (
          tensorspec_utils.tensorspec_to_feature_dict(
              sub_feature_tspec, decode_images=decode_images))
      tensor_dict.update(feature_dict)
      tensor_spec_dict.update(prepend_keys(feature_tspec_dict, dataset_key))
      if label_tspec is not None:
        sub_label_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
            label_tspec, dataset_key)
        label_dict, label_tspec_dict = (
            tensorspec_utils.tensorspec_to_feature_dict(
                sub_label_tspec, decode_images=decode_images))
        tensor_dict.update(label_dict)
        tensor_spec_dict.update(prepend_keys(label_tspec_dict, dataset_key))
      for key, parsed in parse_wrapper(example_proto, tensor_dict).items():
        parsed_tensors[dataset_key + key] = parsed

    # At this point, all tensors have been parsed into a single flat map.
    # Interpret encoded images.
    def decode_image(key, raw_bytes):
      """Decodes single or batches of JPEG- or PNG-encoded string tensors.

      Args:
        key: String key specified in feature map.
        raw_bytes: String tensor to decode as JPEG or PNG.

      Returns:
        Decoded image tensor with shape specified by tensor spec.
      Raises:
        ValueError: If dtype other than uint8 or uint16 is supplied for image
        specs.
      """
      img_batch_dims = tf.shape(raw_bytes)
      # The spatial + channel dimensions of a single image, assumed to be the
      # last 3 entries of the image feature's tensor spec.
      if len(tensor_spec_dict[key].shape) < 3:
        raise ValueError(
            'Shape of tensor spec for image feature "%s" must '
            'be at least 3 dimensional (h, w, c), but is %s' %
            (tensor_spec_dict[key].name, tensor_spec_dict[key].shape))
      single_img_dims = tensor_spec_dict[key].shape[-3:]
      num_channels = single_img_dims[2]
      if num_channels not in [1, 3]:
        raise ValueError(
            'Last dimension of shape of tensor spec for image '
            'feature "%s" must 1 or 3, but the shape is %s' %
            (tensor_spec_dict[key].name, tensor_spec_dict[key].shape))

      # Collapse (possibly multiple) batch dims to a single batch dim for
      # decoding purposes.
      raw_bytes = tf.reshape(raw_bytes, [-1])
      data_type = tensor_spec_dict[key].dtype
      if data_type not in SUPPORTED_PIXEL_ENCODINGS:
        raise ValueError('Decoding an image requires tensorspec.data_type '
                         'to be uint8 or uint16.')

      def _decode_images(image_bytes):
        """Decode single image."""
        def _zero_image():
          return tf.zeros(single_img_dims, dtype=data_type)

        def _tf_decode_image():
          return tf.image.decode_image(
              image_bytes, channels=num_channels, dtype=data_type)

        image = tf.cond(
            tf.equal(image_bytes, ''), _zero_image, _tf_decode_image)
        image.set_shape(single_img_dims)
        return image

      img = tf.map_fn(
          _decode_images, raw_bytes, dtype=data_type, back_prop=False)
      img.set_shape(raw_bytes.shape.concatenate(single_img_dims))

      # Expand the collapsed batch dim back to the original img_batch_dims.
      img = tf.reshape(img, tf.concat([img_batch_dims, single_img_dims], 0))

      return img

    # Convert all sparse tensors to dense tensors.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensor_spec.varlen_default_value is not None:
        if tensorspec_utils.is_encoded_image_spec(tensor_spec):
          default_value = ''
        else:
          default_value = tf.cast(
              tf.constant(tensor_spec.varlen_default_value),
              dtype=tensor_spec.dtype)
        parsed_tensors[key] = tf.sparse.to_dense(
            val, default_value=default_value)

    # Ensure that all images are properly decoded.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensorspec_utils.is_encoded_image_spec(tensor_spec) and decode_images:
        parsed_tensors[key] = decode_image(key, val)
        if tensor_spec.dtype not in SUPPORTED_PIXEL_ENCODINGS:
          raise ValueError('Encoded images with key {} must be '
                           'specified with uint8 or uint16 dtype.'.format(key))

    # Pad all varlen features to the corresponding spec.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensor_spec.varlen_default_value is not None:
        parsed_tensors[key] = tensorspec_utils.pad_or_clip_tensor_to_spec_shape(
            val, tensor_spec)

    # Ensure that we have a consistent ordered mapping despite the underlying
    # spec structure.
    flat_feature_tspec = tensorspec_utils.TensorSpecStruct(
        sorted(tensorspec_utils.flatten_spec_structure(feature_tspec).items()))
    # Using the flat spec structure we allow mapping the same parsed_tensor
    # to multiple features or labels. Note, the spec structure ensures that
    # the corresponding tensorspecs are identical in such cases.
    features = tensorspec_utils.TensorSpecStruct([
        (key, parsed_tensors[value.dataset_key + value.name])
        for key, value in flat_feature_tspec.items()
    ])

    features = tensorspec_utils.validate_and_pack(
        flat_feature_tspec, features, ignore_batch=True)
    if label_tspec is not None:
      # Ensure that we have a consistent ordered mapping despite the underlying
      # spec structure.
      flat_label_tspec = tensorspec_utils.TensorSpecStruct(
          sorted(tensorspec_utils.flatten_spec_structure(label_tspec).items()))
      labels = tensorspec_utils.TensorSpecStruct([
          (key, parsed_tensors[value.dataset_key + value.name])
          for key, value in flat_label_tspec.items()
      ])
      labels = tensorspec_utils.validate_and_pack(
          flat_label_tspec, labels, ignore_batch=True)
      return features, labels
    return features
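A hedged sketch of how a parse function like this is typically attached to a
tf.data pipeline; note the batch-before-map ordering, which is why the
decode_image helper above handles (possibly batched) string tensors. The file
path and batch size are placeholders:

import tensorflow as tf

dataset = tf.data.TFRecordDataset(['/tmp/examples.tfrecord'])  # placeholder
dataset = dataset.batch(2)
# Parsing after batching lets tf.parse_example / parse_sequence_example
# operate on a whole batch of serialized protos at once.
dataset = dataset.map(parse_tf_example_fn, num_parallel_calls=4)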
Example No. 7
    def model_fn(self, features, labels, mode, config=None, params=None):
        """Estimator model_fn.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the default
        config (tf.estimator.RunConfig). Allows updating things in your model_fn
        based on configuration such as num_ps_replicas or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL].

    Returns:
      An EstimatorSpec.
    """

        features = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_feature_specification(mode),
            actual_tensors_or_spec=features,
            ignore_batch=True)
        if labels:
            labels = tensorspec_utils.validate_and_pack(
                expected_spec=self.get_label_specification(mode),
                actual_tensors_or_spec=labels,
                ignore_batch=True)
        inference_outputs = self.inference_network_fn(features, labels, mode,
                                                      config, params)

        # After inference_fn no new variables are allowed to be added, otherwise
        # we would not initialize these variables.
        self.maybe_init_from_checkpoint()

        if mode == tf.estimator.ModeKeys.PREDICT:
            model_fn_results = self.create_export_outputs_fn(
                features, inference_outputs, mode, config, params)
            export_outputs = None
            if isinstance(model_fn_results, tuple):
                predictions = model_fn_results[0]
                export_outputs = model_fn_results[1]
            elif isinstance(model_fn_results, dict):
                export_outputs = {}
                if len(model_fn_results) == 1:
                    name, output = list(model_fn_results.items())[0]
                    export_outputs[
                        name] = tf.estimator.export.RegressionOutput(output)
                export_outputs[tf.saved_model.signature_constants.
                               DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                                   tf.estimator.export.PredictOutput(
                                       model_fn_results))
                predictions = model_fn_results
            else:
                raise ValueError(
                    'The create_export_outputs_fn should return a '
                    'tuple(predictions, export_outputs) or predictions.')

            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=predictions,
                                              export_outputs=export_outputs)

        train_fn_result = self.model_train_fn(features, labels,
                                              inference_outputs, mode, config,
                                              params)
        if isinstance(train_fn_result, tf.Tensor):
            train_loss = train_fn_result
            train_outputs = {}
        elif isinstance(train_fn_result, tuple):
            train_loss = train_fn_result[0]
            train_outputs = train_fn_result[1]
        else:
            raise ValueError('The model_train_fn should return a '
                             'tuple(loss, train_outputs) or loss.')

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Create the tf.train.Optimizer.
            optimizer = self.create_optimizer()

            # Required for batch norm usage.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = self.create_train_op(train_loss, optimizer)

            self.add_summaries(features, labels, inference_outputs, train_loss,
                               train_outputs, mode, config, params)

            training_hooks = []

            # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does not,
            # so we have to use training_hooks here and check is_chief.
            if config and config.is_chief:  # pytype: disable=attribute-error
                training_hooks.append(
                    gin_utils.GinConfigSaverHook(config.model_dir,
                                                 summarize_config=True))

            # `SyncReplicasOptimizer` needs to attach a training hook.
            if self._sync_replicas_optimizer:
                training_hooks.append(
                    self._sync_replicas_optimizer.make_session_run_hook(
                        config.is_chief))  # pytype: disable=attribute-error

            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=train_loss,
                                              train_op=train_op,
                                              training_hooks=training_hooks,
                                              scaffold=self._scaffold_fn())

        if mode == tf.estimator.ModeKeys.EVAL:
            self.add_summaries(features, labels, inference_outputs, train_loss,
                               train_outputs, mode, config, params)
            eval_metrics = self.model_eval_fn(features, labels,
                                              inference_outputs, train_loss,
                                              train_outputs, mode, config,
                                              params)
            evaluation_hooks = self.get_eval_hooks(config, params)
            if config and config.is_chief:  # pytype: disable=attribute-error
                eval_name = params.get('eval_name', 'eval')  # pytype: disable=attribute-error
                evaluation_hooks.append(
                    gin_utils.GinConfigSaverHook(os.path.join(
                        config.model_dir, eval_name),
                                                 summarize_config=True))
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=train_loss,
                eval_metric_ops=eval_metrics,
                evaluation_hooks=evaluation_hooks)

        raise ValueError('The mode {} is not supported yet.'.format(mode))
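Similarly, the tuple handling after inference implies a contract for
model_train_fn: return either a scalar loss tensor or a (loss, train_outputs)
tuple. A hedged sketch (the loss choice and the 'predicted_reward' output
name continue the hypothetical network from the sketch after Example No. 4):

import tensorflow as tf


def model_train_fn(self, features, labels, inference_outputs, mode,
                   config=None, params=None):
  """Hypothetical loss satisfying the model_train_fn contract above."""
  del features, mode, config, params  # Unused in this sketch.
  loss = tf.losses.mean_squared_error(
      labels=labels.reward,
      predictions=tf.squeeze(inference_outputs['predicted_reward'], axis=-1))
  # Returning (loss, train_outputs) also works; train_outputs is a dict of
  # tensors that model_eval_fn and add_summaries may consume.
  return loss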
Example No. 8
    def parse_tf_example_fn(*input_values):
        """Maps string tensor to a parsed TFExample.

    Args:
      *input_values: A string tensor if mapping from a RecordIODataset or
        TFRecordDataset, or a (key, string tensor) tuple if mapping from a
        SSTableDataset.

    Returns:
      features: Collection of tensors conforming to feature_tspec.
      labels: Collection of tensors conforming to label_tspec.
    Raises:
        ValueError: If dtype other than uint8 is supplied for image specs.
    """
        if len(input_values) == 2:
            # Assume an SSTable key, value pair.
            _, example_proto = input_values
        else:
            example_proto, = input_values

        def parse_wrapper(example, spec_dict):
            """Wrap tf.parse_example to support bfloat16 dtypes.

      This allows models which declare bfloat16 as inputs to not require an
      additional preprocessing step to cast all inputs from float32 to bfloat16.
      Consider this to be analogous to JPEG decoding in the data step.

      Args:
        example: TFExample
        spec_dict: Dictionary of feature name -> tf.FixedLenFeature
      Returns:
        Parsed feature map
      """
            def is_bfloat_feature(value):
                return value.dtype == tf.bfloat16

            def maybe_map_bfloat(value):
                if is_bfloat_feature(value):
                    if isinstance(value, tf.FixedLenFeature):
                        return tf.FixedLenFeature(
                            value.shape,
                            tf.float32,
                            default_value=value.default_value)
                    else:
                        return tf.FixedLenSequenceFeature(
                            value.shape,
                            tf.float32,
                            default_value=value.default_value)
                return value

            # Change bfloat features to float32 for parsing.
            new_spec_dict = {
                k: maybe_map_bfloat(v)
                for k, v in six.iteritems(spec_dict)
            }
            for k, v in six.iteritems(new_spec_dict):
                if v.dtype not in [tf.float32, tf.string, tf.int64]:
                    raise ValueError(
                        'Feature specification with invalid data type for '
                        'tf.Example parsing: "%s": %s' % (k, v.dtype))

            # Separate new_spec_dict into Context and Sequence features. In the event
            # that there are no SequenceFeatures, the context_features dictionary
            # (containing FixedLenFeatures) is passed to tf.parse_examples.
            context_features, sequence_features = {}, {}
            for k, v in six.iteritems(new_spec_dict):
                v = maybe_map_bfloat(v)
                if isinstance(v, tf.FixedLenSequenceFeature):
                    sequence_features[k] = v
                elif isinstance(v, tf.FixedLenFeature):
                    context_features[k] = v
                else:
                    raise ValueError(
                        'Only FixedLenFeature and FixedLenSequenceFeature are currently '
                        'supported.')

            # If there are any sequence features, we use parse_sequence_example.
            if sequence_features:
                result, sequence_result, feature_lengths = tf.io.parse_sequence_example(
                    example,
                    context_features=context_features,
                    sequence_features=sequence_features)
                del feature_lengths
                result.update(sequence_result)
            else:
                result = tf.parse_example(example, context_features)
            to_convert = [
                k for k, v in six.iteritems(spec_dict) if is_bfloat_feature(v)
            ]

            for c in to_convert:
                result[c] = tf.cast(result[c], tf.bfloat16)

            return result

        parsed_tensors = parse_wrapper(example_proto, tensor_dict)

        # Interpret encoded images.
        def decode_image(key, raw_bytes):
            """Decodes single or batches of JPEG- or PNG-encoded string tensors.

      Args:
        key: String key specified in feature map.
        raw_bytes: String tensor to decode as JPEG or PNG.

      Returns:
        Decoded image tensor with shape specified by tensor spec.
      Raises:
        ValueError: If dtype other than uint8 is supplied for image specs.
      """
            img_batch_dims = tf.shape(raw_bytes)
            # The spatial + channel dimensions of a single image, assumed to be the
            # last 3 entries of the image feature's tensor spec.
            single_img_dims = tensor_spec_dict[key].shape[-3:]

            # Collapse (possibly multiple) batch dims to a single batch dim for
            # decoding purposes.
            raw_bytes = tf.reshape(raw_bytes, [-1])
            img = tf.map_fn(tf.image.decode_image,
                            raw_bytes,
                            dtype=tf.uint8,
                            back_prop=False)
            img.set_shape(raw_bytes.shape.concatenate(single_img_dims))

            # Expand the collapsed batch dim back to the original img_batch_dims.
            img = tf.reshape(img,
                             tf.concat([img_batch_dims, single_img_dims], 0))

            return img

        for key, val in parsed_tensors.items():
            tensor_spec = tensor_spec_dict[key]
            if tensorspec_utils.is_encoded_image_spec(tensor_spec):
                parsed_tensors[key] = decode_image(key, val)
                if tensor_spec.dtype != tf.uint8:
                    raise ValueError('Encoded images with key {} must be '
                                     'specified with uint8 dtype.'.format(key))

        # Ensure that we have a consistent ordered mapping despite the underlying
        # spec structure.
        flat_feature_tspec = tensorspec_utils.TensorSpecStruct(
            sorted(
                tensorspec_utils.flatten_spec_structure(
                    feature_tspec).items()))
        # Using the flat spec structure we allow mapping the same parsed_tensor
        # to multiple features or labels. Note, the spec structure ensures that
        # the corresponding tensorspecs are identical in such cases.
        features = tensorspec_utils.TensorSpecStruct([
            (key, parsed_tensors[value.name])
            for key, value in flat_feature_tspec.items()
        ])

        features = tensorspec_utils.validate_and_pack(flat_feature_tspec,
                                                      features,
                                                      ignore_batch=True)
        if label_tspec is not None:
            # Ensure that we have a consistent ordered mapping despite the underlying
            # spec structure.
            flat_label_tspec = tensorspec_utils.TensorSpecStruct(
                sorted(
                    tensorspec_utils.flatten_spec_structure(
                        label_tspec).items()))
            labels = tensorspec_utils.TensorSpecStruct([
                (key, parsed_tensors[value.name])
                for key, value in flat_label_tspec.items()
            ])
            labels = tensorspec_utils.validate_and_pack(flat_label_tspec,
                                                        labels,
                                                        ignore_batch=True)
            return features, labels
        return features