def _prepare_features_for_sqss(features, labels, mode,
                               sequence_feature_columns,
                               context_feature_columns):
  """Splits `features` into sequence and context dicts for the SQSS.

  In preparation for batching by the SQSS, this function:
  - Passes `sequence_feature_columns` to
    `feature_column_ops._check_supported_sequence_columns`.
  - Collects the features named by the sequence feature columns and by the
    context feature columns into two separate dicts.
  - Adds the labels to the sequence features dict, unless `mode` is `INFER`.

  Args:
    features: A dict of Python string to an iterable of `Tensor` or
      `SparseTensor` of rank 2, the `features` argument of a TF.Learn model_fn.
    labels: An iterable of `Tensor`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps, or `None` if there are no context features. All items in the set
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    sequence_features: A dict mapping feature names to sequence features. When
      `mode` is not `INFER` it also maps `rnn_common.RNNKeys.LABELS_KEY` to
      `labels`.
    context_features: A dict mapping feature names to context features; empty
      when `context_feature_columns` is `None`.

  Raises:
    ValueError: If `features` does not contain a (non-`None`) value for every
      key in `sequence_feature_columns` or `context_feature_columns`.
  """

  def _require(key, missing_prefix):
    # A key that maps to `None` is treated exactly like an absent key.
    tensor = features.get(key, None)
    if tensor is None:
      raise ValueError(missing_prefix + key)
    return tensor

  # Sequence columns are checked up front, before any feature lookup happens.
  feature_column_ops._check_supported_sequence_columns(sequence_feature_columns)  # pylint: disable=protected-access

  # One sequence column may contribute several feature names.
  sequence_features = {}
  for seq_column in sequence_feature_columns:
    for key in _get_name_or_parent_names(seq_column):
      sequence_features[key] = _require(
          key, 'No key in features for sequence feature: ')

  # Context columns are optional and are looked up by their own name only.
  context_features = {}
  if context_feature_columns is not None:
    for ctx_column in context_feature_columns:
      key = ctx_column.name
      context_features[key] = _require(
          key, 'No key in features for context feature: ')

  # Labels ride along with the sequence features, except at inference time.
  if mode != model_fn.ModeKeys.INFER:
    sequence_features[rnn_common.RNNKeys.LABELS_KEY] = labels

  return sequence_features, context_features
# Example no. 2
# 0
def _prepare_features_for_sqss(features, labels, mode, input_key_column_name,
                               sequence_feature_columns,
                               context_feature_columns):
    """Splits `features` into an input key, sequence and context dicts.

    In preparation for batching by the SQSS, this function:
    - Pops the input key out of the features dict. This mutates `features`
      in place: the `input_key_column_name` entry is removed from it.
    - Passes `sequence_feature_columns` to
      `feature_column_ops._check_supported_sequence_columns`.
    - Collects the features named by the sequence feature columns and by the
      context feature columns into two separate dicts.
    - Adds the labels to the sequence features dict, unless `mode` is `INFER`.

    Args:
      features: A dict of Python string to an iterable of `Tensor` or
        `SparseTensor` of rank 2, the `features` argument of a TF.Learn
        model_fn.
      labels: An iterable of `Tensor`.
      mode: Defines whether this is training, evaluation or prediction.
        See `ModeKeys`.
      input_key_column_name: Python string, the name of the feature column
        containing a string scalar `Tensor` that serves as a unique key to
        identify the input sequence across minibatches.
      sequence_feature_columns: An iterable containing all the feature columns
        describing sequence features. All items in the set should be instances
        of classes derived from `FeatureColumn`.
      context_feature_columns: An iterable containing all the feature columns
        describing context features, i.e., features that apply across all
        time steps, or `None` if there are no context features. All items in
        the set should be instances of classes derived from `FeatureColumn`.

    Returns:
      input_key: The string scalar `Tensor` that serves as a unique key to
        identify the input sequence across minibatches.
      sequence_features: A dict mapping feature names to sequence features.
        When `mode` is not `INFER` it also maps `RNNKeys.LABELS_KEY` to
        `labels`.
      context_features: A dict mapping feature names to context features;
        empty when `context_feature_columns` is `None`.

    Raises:
      ValueError: If `features` does not contain a (non-`None`) value for
        `input_key_column_name`.
      ValueError: If `features` does not contain a (non-`None`) value for
        every key in `sequence_feature_columns` or `context_feature_columns`.
    """

    def _require(key, missing_prefix):
        # A key that maps to `None` is treated exactly like an absent key.
        tensor = features.get(key, None)
        if tensor is None:
            raise ValueError(missing_prefix + key)
        return tensor

    # The input key is removed from `features` first, so it is no longer
    # visible to the sequence/context lookups below.
    input_key = features.pop(input_key_column_name, None)
    if input_key is None:
        raise ValueError('No key in features for input_key_column_name: ' +
                         input_key_column_name)

    # Sequence columns are checked before any sequence feature is looked up.
    feature_column_ops._check_supported_sequence_columns(
        sequence_feature_columns)  # pylint: disable=protected-access

    # One sequence column may contribute several feature names.
    sequence_features = {}
    for seq_column in sequence_feature_columns:
        for key in _get_name_or_parent_names(seq_column):
            sequence_features[key] = _require(
                key, 'No key in features for sequence feature: ')

    # Context columns are optional and are looked up by their own name only.
    context_features = {}
    if context_feature_columns is not None:
        for ctx_column in context_feature_columns:
            key = ctx_column.name
            context_features[key] = _require(
                key, 'No key in features for context feature: ')

    # Labels ride along with the sequence features, except at inference time.
    if mode != model_fn.ModeKeys.INFER:
        sequence_features[RNNKeys.LABELS_KEY] = labels

    return input_key, sequence_features, context_features