def _ParseExample(extracts: types.Extracts,
                  materialize_columns: bool = True) -> None:
    """Extracts features from the serialized tf.Example held in the extracts.

    The bytes stored under constants.INPUT_KEY are parsed into a
    tf.train.Example. Each feature found is then (optionally) materialized as
    a column keyed by util.compound_key(['features', name]) and, when missing,
    added to the features dictionary of the FeaturesPredictionsLabels entry.

    Parsing is best effort: if the input is missing or cannot be parsed a
    warning is logged and whatever the (possibly empty) example contains is
    processed.

    Args:
      extracts: The Extracts to read the input from. This is mutated in-place.
      materialize_columns: If True, a types.MaterializedColumn is written to
        the extracts for every feature in the example. If False, only features
        absent from the existing features dictionary are processed.
    """
    # Deserialize the example.
    example = tf.train.Example()
    try:
        example.ParseFromString(extracts[constants.INPUT_KEY])
    except Exception:  # pylint: disable=broad-except
        # Deliberately tolerant, but no longer swallows KeyboardInterrupt or
        # SystemExit the way a bare `except:` would.
        logging.warning('Could not parse tf.Example from the input source.')

    # When no FeaturesPredictionsLabels entry exists this dict is local only,
    # so additions made to it below are not stored back into the extracts.
    features = {}
    if constants.FEATURES_PREDICTIONS_LABELS_KEY in extracts:
        features = extracts[constants.FEATURES_PREDICTIONS_LABELS_KEY].features

    for name in example.features.feature:
        if not materialize_columns and name in features:
            continue

        key = util.compound_key(['features', name])
        value = example.features.feature[name]
        if value.HasField('bytes_list'):
            values = list(value.bytes_list.value)
        elif value.HasField('float_list'):
            values = list(value.float_list.value)
        elif value.HasField('int64_list'):
            values = list(value.int64_list.value)
        else:
            # A Feature with no kind set matches none of the checks above.
            # Treat it as empty rather than hitting an unbound `values` (or
            # reusing the values left over from the previous feature).
            values = []

        if materialize_columns:
            extracts[key] = types.MaterializedColumn(name=key, value=values)
        if name not in features:
            features[name] = {encoding.NODE_SUFFIX: np.array([values])}
def _AugmentExtracts(data: Dict[str, Any], prefix: str, excludes: List[bytes],
                     extracts: types.Extracts) -> None:
    """Materializes the entries of a data dictionary as columns in the extracts.

    Args:
      data: Dictionary mapping names to values. A value wrapped in a dict under
        the encoding.NODE_SUFFIX key is unwrapped before being processed.
      prefix: Prefix to use in column naming (e.g. 'features', 'labels', etc).
      excludes: Names of entries to skip. May be None, in which case nothing is
        skipped.
      extracts: The Extracts to be augmented. This is mutated in-place.

    Raises:
      TypeError: If a value (after unwrapping) is neither a SparseTensorValue,
        a numpy array, nor a list.
    """
    for name, raw in data.items():
        skip = excludes is not None and name in excludes
        if skip:
            continue

        # Data that originated from FeaturesPredictionsLabels keeps the actual
        # value under the NODE_SUFFIX key, so unwrap it first.
        val = raw
        if isinstance(raw, dict) and encoding.NODE_SUFFIX in raw:
            val = raw.get(encoding.NODE_SUFFIX)

        if name in (prefix, util.KEY_SEPARATOR + prefix):
            col_name = prefix
        elif prefix in ('features', 'predictions', 'labels'):
            col_name = util.compound_key([prefix, name])
        else:
            # Names used by additional extracts should already be escaped, so
            # join them to the prefix by hand instead of escaping them again.
            col_name = prefix + util.KEY_SEPARATOR + name

        if isinstance(val, tf.compat.v1.SparseTensorValue):
            column_value = val.values
        elif isinstance(val, (np.ndarray, list)):
            # Only the first dimension is supported for now.
            column_value = val[0] if len(val) else []
        else:
            raise TypeError(
                'Dictionary item with key %s, value %s had unexpected type %s' %
                (name, val, type(val)))

        extracts[col_name] = types.MaterializedColumn(
            name=col_name, value=column_value)
def testCompoundKey(self):
    """Checks util.compound_key output for one, two and three key parts."""
    # (key parts, expected compound key)
    cases = (
        (['a_b'], 'a_b'),
        (['a', 'b'], 'a__b'),
        (['a', 'b__c', 'd'], 'a__b____c__d'),
    )
    for parts, expected in cases:
        self.assertEqual(expected, util.compound_key(parts))