예제 #1
0
 def filter_extracts(extracts: types.Extracts) -> types.Extracts:  # pylint: disable=invalid-name
     """Restricts extracts to the configured include/exclude key sets."""
     # `include` takes precedence when both filters are configured.
     if include:
         return {key: value for key, value in extracts.items() if key in include}
     if exclude:
         return {key: value for key, value in extracts.items() if key not in exclude}
     # Neither filter configured: pass everything through untouched.
     return extracts
예제 #2
0
 def process(self, element: types.Extracts
            ) -> List[Tuple[SliceKeyType, types.Extracts]]:
   """Pairs each of the element's slice keys with its filtered extracts."""
   keep_key = self._key_filter_fn  # Cache the bound callable locally.
   pruned = {name: value for name, value in element.items() if keep_key(name)}
   pairs = []
   for slice_key in element.get(constants.SLICE_KEY_TYPES_KEY):
     pairs.append((slice_key, pruned))
   self._num_slices_generated_per_instance.update(len(pairs))
   self._post_slice_num_instances.inc(len(pairs))
   return pairs
예제 #3
0
 def merge_lists(target: types.Extracts) -> types.Extracts:
     """Converts target's leaves which are lists to batched np.array's, etc."""
     if isinstance(target, Mapping):
         result = {}
         for key, value in target.items():
             try:
                 result[key] = merge_lists(value)
             except Exception as e:
                 raise RuntimeError(
                     'Failed to convert value for key "{}"'.format(
                         key)) from e
         return {k: merge_lists(v) for k, v in target.items()}
     elif target and (isinstance(target[0], tf.compat.v1.SparseTensorValue)
                      or isinstance(target[0], types.SparseTensorValue)):
         t = tf.sparse.concat(0, [
             tf.sparse.expand_dims(to_tensorflow_tensor(t), 0)
             for t in target
         ])
         return to_tensor_value(t)
     elif target and isinstance(target[0], types.RaggedTensorValue):
         t = tf.concat(
             [tf.expand_dims(to_tensorflow_tensor(t), 0) for t in target],
             0)
         return to_tensor_value(t)
     else:
         arr = np.array(target)
         # Flatten values that were originally single item lists into a single list
         # e.g. [[1], [2], [3]] -> [1, 2, 3]
         if len(arr.shape) == 2 and arr.shape[1] == 1:
             return arr.squeeze(axis=1)
         # Special case for empty slice arrays since numpy treats empty tuples as
         # arrays with dimension 0.
         # e.g. [[()], [()], [()]] -> [(), (), ()]
         elif len(arr.shape
                  ) == 3 and arr.shape[1] == 1 and arr.shape[2] == 0:
             return arr.squeeze(axis=1)
         else:
             return arr
예제 #4
0
 def merge_lists(target: types.Extracts) -> types.Extracts:
     """Converts target's leaves which are lists to batched np.array's, etc."""
     if isinstance(target, Mapping):
         result = {}
         for key, value in target.items():
             try:
                 result[key] = merge_lists(value)
             except Exception as e:
                 raise RuntimeError(
                     'Failed to convert value for key "{}"'.format(
                         key)) from e
         return {k: merge_lists(v) for k, v in target.items()}
     elif target and (isinstance(target[0], tf.compat.v1.SparseTensorValue)
                      or isinstance(target[0], types.SparseTensorValue)):
         t = tf.compat.v1.sparse_concat(0, [
             tf.sparse.expand_dims(to_tensorflow_tensor(t), 0)
             for t in target
         ],
                                        expand_nonconcat_dim=True)
         return to_tensor_value(t)
     elif target and isinstance(target[0], types.RaggedTensorValue):
         t = tf.concat(
             [tf.expand_dims(to_tensorflow_tensor(t), 0) for t in target],
             0)
         return to_tensor_value(t)
     elif (all(isinstance(t, np.ndarray) for t in target)
           and len({t.shape
                    for t in target}) > 1):
         return types.VarLenTensorValue.from_dense_rows(target)
     else:
         arr = np.array(target)
         # Flatten values that were originally single item lists into a single list
         # e.g. [[1], [2], [3]] -> [1, 2, 3]
         if len(arr.shape) == 2 and arr.shape[1] == 1:
             return arr.squeeze(axis=1)
         return arr
예제 #5
0
 def process(
         self, element: types.Extracts
 ) -> List[Tuple[SliceKeyType, types.Extracts]]:
     """Emits one (slice_key, filtered_extracts) pair per distinct slice key."""
     keep_key = self._key_filter_fn  # Cache the bound callable locally.
     pruned = {name: value for name, value in element.items() if keep_key(name)}
     raw_keys = element.get(constants.SLICE_KEY_TYPES_KEY)
     # The query based evaluator will group slices from multiple examples, so we
     # deduplicate to avoid overcounting. Depending on whether the rows within a
     # batch have a variable or fixed length, either a VarLenTensorValue or a 2D
     # np.ndarray will be created.
     if isinstance(raw_keys, types.VarLenTensorValue):
         raw_keys = raw_keys.values
     elif isinstance(raw_keys, np.ndarray) and raw_keys.ndim == 2:
         raw_keys = raw_keys.flatten()
     pairs = [(slice_key, pruned) for slice_key in set(raw_keys)]
     self._num_slices_generated_per_instance.update(len(pairs))
     self._post_slice_num_instances.inc(len(pairs))
     return pairs
예제 #6
0
 def process(
         self, element: types.Extracts
 ) -> List[Tuple[SliceKeyType, types.Extracts]]:
     """Emits one (slice_key, filtered_extracts) pair per unique slice key."""
     keep_key = self._key_filter_fn  # Cache the bound callable locally.
     pruned = {name: value for name, value in element.items() if keep_key(name)}
     slice_keys = element.get(constants.SLICE_KEY_TYPES_KEY)
     # The query based evaluator groups slices into a multi-dimensional array
     # with an extra dimension representing the examples matching the query
     # key, so the slice keys must be flattened and deduplicated.
     if _is_multi_dim_keys(slice_keys):
         arr = np.array(slice_keys)
         deduped = set(arr.flatten())
         if not deduped and arr.shape:
             # Flattening drops the lone empty overall slice; restore it.
             deduped.add(())
         slice_keys = deduped
     pairs = [(slice_key, pruned) for slice_key in slice_keys]
     self._num_slices_generated_per_instance.update(len(pairs))
     self._post_slice_num_instances.inc(len(pairs))
     return pairs
예제 #7
0
 def visit(subtree: types.Extracts, keys: List[str]):
     """Depth-first walk of subtree, reporting each leaf with its key path."""
     for child_key, child in subtree.items():
         path = keys + [child_key]
         if isinstance(child, Mapping):
             visit(child, path)
         else:
             add_to_results(path, child)