def process(self, element: types.Extracts) -> Sequence[types.Extracts]:
  """Processes one batched extract, falling back to per-row processing.

  First attempts to push the entire record batch through
  _batch_reducible_process. If that raises (ValueError or
  tf.errors.InvalidArgumentError), each row is sliced out of the record
  batch and processed individually, which is much slower but avoids
  failing the whole batch.

  Args:
    element: Batched extract keyed by extract name; must contain the Arrow
      record batch under constants.ARROW_RECORD_BATCH_KEY.

  Returns:
    The processed extracts, one sequence covering the whole batch.
  """
  record_batch = element[constants.ARROW_RECORD_BATCH_KEY]
  num_rows = record_batch.num_rows
  try:
    output = self._batch_reducible_process(element)
    self._batch_size.update(num_rows)
    self._num_instances.inc(num_rows)
    return output
  except (ValueError, tf.errors.InvalidArgumentError) as e:
    logging.warning(
        'Large batch_size %s failed with error %s. '
        'Attempting to run batch through serially. Note that this will '
        'significantly affect the performance.', num_rows, e)
    self._batch_size_failed.update(num_rows)
    output = []
    for row_index in range(num_rows):
      self._batch_size.update(1)
      # Build a batch-of-one element: slice the record batch for this row
      # and wrap every other column's row value in a singleton list.
      single_row_element = {
          key: (record_batch.slice(row_index, 1)
                if key == constants.ARROW_RECORD_BATCH_KEY
                else [value[row_index]])
          for key, value in element.items()
      }
      output.extend(self._batch_reducible_process(single_row_element))
    self._num_instances.inc(len(output))
    return output
def _ExtractUnbatchedInputs(
    batched_extract: types.Extracts) -> Sequence[types.Extracts]:
  """Extract features, predictions, labels and weights from batched extract.

  Args:
    batched_extract: Batched extract: a dict of batch-aligned column values,
      possibly including the Arrow record batch itself.

  Returns:
    One extract (dict) per row of the batch, with the Arrow record batch
    entry dropped.
  """
  keys_to_retain = set(batched_extract.keys())
  # Use discard rather than remove so an extract produced without a record
  # batch does not raise KeyError; this matches the guarded removal in
  # _extract_unbatched_inputs.
  keys_to_retain.discard(constants.ARROW_RECORD_BATCH_KEY)
  dataframe = pd.DataFrame()
  for key in keys_to_retain:
    dataframe[key] = batched_extract[key]
  # One dict per dataframe row, i.e. one unbatched extract per batch row.
  return dataframe.to_dict(orient='records')
def _extract_unbatched_inputs(  # pylint: disable=invalid-name
    mixed_legacy_batched_extract: types.Extracts) -> Sequence[types.Extracts]:
  """Extract features, predictions, labels and weights from batched extract."""
  # TODO(mdreves): Remove record batch
  modern_columns = {}
  legacy_frame = pd.DataFrame()
  for key, value in mixed_legacy_batched_extract.items():
    if key == constants.ARROW_RECORD_BATCH_KEY:
      continue
    # Previously a batch of transformed features were stored as a list of
    # dicts instead of a dict of np.arrays with batch dimensions. These
    # legacy conversions are done using dataframes instead.
    if not isinstance(value, list):
      modern_columns[key] = value
      continue
    try:
      legacy_frame[key] = value
    except Exception as e:
      raise RuntimeError(
          f'Exception encountered while adding key {key} with '
          f'batched length {len(mixed_legacy_batched_extract[key])}'
      ) from e
  unbatched_extracts = util.split_extracts(modern_columns)
  legacy_unbatched_extracts = legacy_frame.to_dict(orient='records')
  if unbatched_extracts and legacy_unbatched_extracts:
    if len(unbatched_extracts) != len(legacy_unbatched_extracts):
      raise ValueError(
          f'Batch sizes have differing values: {len(unbatched_extracts)} != '
          f'{len(legacy_unbatched_extracts)}, '
          f'unbatched_extracts={unbatched_extracts}, '
          f'legacy_unbatched_extracts={legacy_unbatched_extracts}')
    # Merge row-wise, with modern values taking precedence on key collision.
    merged = []
    for modern_row, legacy_row in zip(unbatched_extracts,
                                      legacy_unbatched_extracts):
      legacy_row.update(modern_row)
      merged.append(legacy_row)
    return merged
  if legacy_unbatched_extracts:
    return legacy_unbatched_extracts
  return unbatched_extracts