def testAccessFeaturesDict(self):
  """Checks SliceAccessor lookups on NODE_SUFFIX-wrapped fetched values.

  Covers a sparse value, a 1-D dense value, a [[x]] value that needs
  squeezing, a 2-D dense value that accessor.get is expected to reject with
  ValueError, and a missing key that is expected to raise KeyError.
  """
  sparse = tf.SparseTensor(
      indices=[[0, 0], [1, 1]],
      values=['apple', 'banana'],
      dense_shape=[2, 2])
  dense = tf.constant([1.0, 2.0])
  bad_dense = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  squeeze_needed = tf.constant([[2.0]])

  # Use the session as a context manager so it is closed (and its resources
  # released) as soon as the values are fetched; previously it was never
  # closed.
  with tf.Session() as sess:
    sparse_value, dense_value, bad_dense_value, squeeze_needed_value = (
        sess.run(fetches=[sparse, dense, bad_dense, squeeze_needed]))

  features_dict = {
      'sparse': {
          encoding.NODE_SUFFIX: sparse_value
      },
      'dense': {
          encoding.NODE_SUFFIX: dense_value
      },
      'squeeze_needed': {
          encoding.NODE_SUFFIX: squeeze_needed_value
      },
      'bad_dense': {
          encoding.NODE_SUFFIX: bad_dense_value
      },
  }
  accessor = slice_accessor.SliceAccessor(features_dict)
  self.assertEqual(['apple', 'banana'], list(accessor.get('sparse')))
  self.assertEqual([1.0, 2.0], list(accessor.get('dense')))
  self.assertEqual([2.0], list(accessor.get('squeeze_needed')))
  # A multi-dimensional dense value is expected to be rejected.
  with self.assertRaises(ValueError):
    accessor.get('bad_dense')
  with self.assertRaises(KeyError):
    accessor.get('no_such_key')
def testAccessFeaturesDict(self, feature_value, slice_value):
  """Checks that 'feature' resolves to slice_value in several dict layouts.

  Args:
    feature_value: Value stored under the feature key in the input dicts.
    slice_value: Expected result of accessor.get('feature').
  """
  # Each entry is (features dicts, extra SliceAccessor kwargs). Accessors are
  # constructed one at a time inside the loop, right before their assertion.
  scenarios = [
      # A single features dict that contains the key.
      ([{'feature': feature_value}], {}),
      # Multiple dicts carrying duplicate values for the same key.
      ([{'feature': feature_value}, {'feature': feature_value}], {}),
      # Key absent from the dicts; only the default features dict has it.
      ([{'unmatched_feature': feature_value}],
       {'default_features_dict': {'feature': feature_value}}),
  ]
  for dicts, extra_kwargs in scenarios:
    accessor = slice_accessor.SliceAccessor(dicts, **extra_kwargs)
    self.assertEqual(slice_value, accessor.get('feature'))
def get_slices_for_features_dict(features_dict, slice_spec):
  """Lazily produces the slice keys that slice_spec yields for a features dict.

  Args:
    features_dict: Features dictionary to slice.
    slice_spec: Sequence of single slice specs; each one's generate_slices is
      called with a SliceAccessor over features_dict.

  Yields:
    Slice keys, spec by spec, in the order each spec generates them.
  """
  accessor = slice_accessor.SliceAccessor(features_dict)
  keys = (key
          for spec in slice_spec
          for key in spec.generate_slices(accessor))
  for key in keys:
    yield key
def get_slices_for_features_dict(
    features_dict: types.DictOfFetchedTensorValues,
    slice_spec: List[SingleSliceSpec]) -> Iterable[SliceKeyType]:
  """Lazily produces the slice keys that slice_spec yields for a features dict.

  Args:
    features_dict: Fetched features dictionary to slice.
    slice_spec: Single slice specs; each one's generate_slices is called with
      a SliceAccessor over features_dict.

  Yields:
    Slice keys, spec by spec, in the order each spec generates them.
  """
  accessor = slice_accessor.SliceAccessor(features_dict)
  yield from (key
              for spec in slice_spec
              for key in spec.generate_slices(accessor))
def get_slices_for_features_dicts(
    features_dicts: Iterable[Union[types.DictOfTensorValue,
                                   types.DictOfFetchedTensorValues]],
    default_features_dict: Union[types.DictOfTensorValue,
                                 types.DictOfFetchedTensorValues],
    slice_spec: List[SingleSliceSpec]) -> Iterable[SliceKeyType]:
  """Lazily produces slice keys for several features dictionaries at once.

  Args:
    features_dicts: Features dictionaries to search, e.g. a list of
      transformed features dictionaries.
    default_features_dict: Fallback dict consulted when a feature is not
      found in features_dicts, e.g. the raw features.
    slice_spec: Single slice specs; each one's generate_slices is called with
      one shared SliceAccessor built from the two arguments above.

  Yields:
    Slice keys, spec by spec, in the order each spec generates them.
  """
  accessor = slice_accessor.SliceAccessor(features_dicts,
                                          default_features_dict)
  for spec in slice_spec:
    yield from spec.generate_slices(accessor)
def testLegacyAccessFeaturesDict(self):
  """Checks SliceAccessor flattening of legacy NODE_SUFFIX-wrapped values.

  Sparse values come back as bytes, and single, multi-dimensional and
  [[x]]-shaped dense values all come back as flat lists.
  """
  with tf.compat.v1.Session() as sess:
    # Tensors are created in the same order as before so graph node naming
    # is unchanged.
    tensors = {
        'sparse': tf.SparseTensor(indices=[[0, 0], [1, 1]],
                                  values=['apple', 'banana'],
                                  dense_shape=[2, 2]),
        'dense': tf.constant([1.0, 2.0]),
        'dense_single': tf.constant([7.0]),
        'dense_multidim': tf.constant([[1.0, 2.0], [3.0, 4.0]]),
        'squeeze_needed': tf.constant([[2.0]]),
    }
    fetch_order = ('sparse', 'dense', 'dense_single', 'dense_multidim',
                   'squeeze_needed')
    fetched = dict(
        zip(fetch_order,
            sess.run(fetches=[tensors[name] for name in fetch_order])))

    # Wrap each fetched value under NODE_SUFFIX, as the legacy format expects.
    features_dict = {
        name: {encoding.NODE_SUFFIX: fetched[name]}
        for name in ('sparse', 'dense', 'dense_single', 'squeeze_needed',
                     'dense_multidim')
    }
    accessor = slice_accessor.SliceAccessor(features_dict)

    expectations = (
        ('sparse', [b'apple', b'banana']),
        ('dense', [1.0, 2.0]),
        ('dense_single', [7.0]),
        ('dense_multidim', [1.0, 2.0, 3.0, 4.0]),
        ('squeeze_needed', [2.0]),
    )
    for name, expected in expectations:
      self.assertEqual(expected, accessor.get(name))
def testAccessFeaturesDict(self, feature_value, slice_value):
  """Checks that a one-key features dict resolves 'feature' to slice_value."""
  features = {'feature': feature_value}
  self.assertEqual(slice_value,
                   slice_accessor.SliceAccessor(features).get('feature'))
def testRaisesKeyError(self):
  """Checks that looking up an absent key on an empty accessor raises KeyError."""
  empty_accessor = slice_accessor.SliceAccessor({})
  self.assertRaises(KeyError, empty_accessor.get, 'no_such_key')