# Shared imports for the examples below; the module paths follow the
# tensorflow_model_analysis (TFMA) layout these snippets come from.
from typing import Sequence

import pandas as pd
import tensorflow as tf
from absl import logging
from tensorflow_model_analysis import constants
from tensorflow_model_analysis import types
from tensorflow_model_analysis import util


# Method of a Beam DoFn: self._batch_size, self._batch_size_failed and
# self._num_instances are Beam metrics, and self._batch_reducible_process
# runs the model over a whole batch of extracts.
def process(self, element: types.Extracts) -> Sequence[types.Extracts]:
    batch_size = element[constants.ARROW_RECORD_BATCH_KEY].num_rows
    try:
        # Fast path: run the whole batch through the model at once.
        result = self._batch_reducible_process(element)
        self._batch_size.update(batch_size)
        self._num_instances.inc(batch_size)
        return result
    except (ValueError, tf.errors.InvalidArgumentError) as e:
        logging.warning(
            'Large batch_size %s failed with error %s. '
            'Attempting to run batch through serially. Note that this will '
            'significantly affect the performance.', batch_size, e)
        self._batch_size_failed.update(batch_size)
        result = []
        record_batch = element[constants.ARROW_RECORD_BATCH_KEY]
        # Slow path: retry the failed batch one example at a time.
        for i in range(batch_size):
            self._batch_size.update(1)
            unbatched_element = {}
            for key in element:
                if key == constants.ARROW_RECORD_BATCH_KEY:
                    # Zero-copy single-row view of the Arrow record batch.
                    unbatched_element[key] = record_batch.slice(i, 1)
                else:
                    unbatched_element[key] = [element[key][i]]
            result.extend(self._batch_reducible_process(unbatched_element))
        self._num_instances.inc(len(result))
        return result
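
For context, RecordBatch.slice(i, 1) returns a zero-copy single-row view,
which is what keeps the serial fallback above cheap. A minimal pyarrow
sketch (the column names are illustrative, not taken from TFMA):

import pyarrow as pa

# A toy two-row record batch standing in for the ARROW_RECORD_BATCH_KEY value.
batch = pa.record_batch([pa.array([1, 2]), pa.array(['a', 'b'])],
                        names=['f1', 'f2'])
row = batch.slice(1, 1)  # view containing only the second row; no data copied
assert row.num_rows == 1
assert row.column(0).to_pylist() == [2]
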
def _ExtractUnbatchedInputs(
        batched_extract: types.Extracts) -> Sequence[types.Extracts]:
    """Extract features, predictions, labels and weights from batched extract."""
    # Drop the Arrow record batch; every remaining value is an aligned
    # per-example list.
    keys_to_retain = set(batched_extract.keys())
    keys_to_retain.remove(constants.ARROW_RECORD_BATCH_KEY)
    dataframe = pd.DataFrame()
    for key in keys_to_retain:
        dataframe[key] = batched_extract[key]
    # Pivot the columns into one dict per example.
    return dataframe.to_dict(orient='records')
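
A hedged usage sketch for the function above: every non-RecordBatch value is
an aligned per-example list, and the DataFrame round-trip pivots those lists
into one dict per example (the keys 'labels' and 'predictions' here are
illustrative):

import pyarrow as pa

batched = {
    constants.ARROW_RECORD_BATCH_KEY: pa.record_batch(
        [pa.array([0, 1])], names=['f']),  # dropped by the function
    'labels': [0, 1],
    'predictions': [0.2, 0.9],
}
# Expected result (modulo numpy scalar types from the DataFrame round-trip):
# [{'labels': 0, 'predictions': 0.2}, {'labels': 1, 'predictions': 0.9}]
print(_ExtractUnbatchedInputs(batched))
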
Example #3
def _extract_unbatched_inputs(  # pylint: disable=invalid-name
    mixed_legacy_batched_extract: types.Extracts) -> Sequence[types.Extracts]:
    """Extract features, predictions, labels and weights from batched extract."""
    batched_extract = {}
    # TODO(mdreves): Remove record batch
    keys_to_retain = set(mixed_legacy_batched_extract.keys())
    if constants.ARROW_RECORD_BATCH_KEY in keys_to_retain:
        keys_to_retain.remove(constants.ARROW_RECORD_BATCH_KEY)
    dataframe = pd.DataFrame()
    for key in keys_to_retain:
        # Previously a batch of transformed features was stored as a list of
        # dicts rather than as a dict of np.arrays with a leading batch
        # dimension. Such legacy values are unbatched via a DataFrame instead.
        if isinstance(mixed_legacy_batched_extract[key], list):
            try:
                dataframe[key] = mixed_legacy_batched_extract[key]
            except Exception as e:
                raise RuntimeError(
                    f'Exception encountered while adding key {key} with '
                    f'batched length {len(mixed_legacy_batched_extract[key])}'
                ) from e
        else:
            batched_extract[key] = mixed_legacy_batched_extract[key]
    # Unbatch the already-batched keys along their batch dimension and the
    # legacy keys via the DataFrame, then stitch the two streams together.
    unbatched_extracts = util.split_extracts(batched_extract)
    legacy_unbatched_extracts = dataframe.to_dict(orient='records')
    if unbatched_extracts and legacy_unbatched_extracts:
        if len(unbatched_extracts) != len(legacy_unbatched_extracts):
            raise ValueError(
                f'Batch sizes have differing values: {len(unbatched_extracts)} != '
                f'{len(legacy_unbatched_extracts)}, '
                f'unbatched_extracts={unbatched_extracts}, '
                f'legacy_unbatched_extracts={legacy_unbatched_extracts}')
        result = []
        for unbatched_extract, legacy_unbatched_extract in zip(
                unbatched_extracts, legacy_unbatched_extracts):
            legacy_unbatched_extract.update(unbatched_extract)
            result.append(legacy_unbatched_extract)
        return result
    elif legacy_unbatched_extracts:
        return legacy_unbatched_extracts
    else:
        return unbatched_extracts
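
A sketch of the mixed case this function handles, assuming
util.split_extracts unbatches along the leading dimension as in TFMA: legacy
keys arrive as Python lists and go through the DataFrame path, while
already-batched keys go through util.split_extracts, and the two per-example
streams are zipped back together (key names illustrative):

import numpy as np

mixed = {
    'legacy_features': [{'f': 1}, {'f': 2}],  # list -> DataFrame path
    'predictions': np.array([0.2, 0.9]),      # batched -> split_extracts path
}
# Roughly expected (modulo numpy scalar types):
# [{'legacy_features': {'f': 1}, 'predictions': 0.2},
#  {'legacy_features': {'f': 2}, 'predictions': 0.9}]
print(_extract_unbatched_inputs(mixed))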