def store(batch: Dict[str, List], updates: List[Dict]) -> Dict[str, List]:
    """Attach slice information to each example in a batch.

    Args:
        batch: a batch of examples, keyed by column name
        updates: one update dictionary per example in the batch

    Returns:
        the batch, with its "slices" column recursively merged with the
        updates (sequence values are concatenated)
    """
    # Lazily create one empty slice dictionary per example on first use
    if "slices" not in batch:
        batch["slices"] = [{} for _ in batch["index"]]

    # Fold each update into the corresponding example's existing slice
    # dictionary; merge_sequences=True concatenates list values
    merged = []
    for current, update in zip(batch["slices"], updates):
        merged.append(recmerge(current, update, merge_sequences=True))
    batch["slices"] = merged

    return batch
def _merge_batch_and_output(cls, batch: Batch, output: Batch):
    """Merge an output produced during .map() into the input batch.

    Columns present only in `output` are added to the batch wholesale.
    For columns present in both, dict-valued columns are merged
    example-wise with `recmerge`; any other value type is overwritten
    by the output column.

    Note: `batch` is mutated in place and also returned.

    Args:
        batch: the original batch of examples
        output: the batch of outputs to fold in

    Returns:
        the merged batch
    """
    combined = batch
    for k, out_col in output.items():
        if k not in batch:
            # New column: take it directly from the output
            combined[k] = out_col
            continue
        batch_col = batch[k]
        # Guard the element inspection: an empty column has no first
        # element to test, so fall through to plain replacement
        # (previously this raised IndexError on empty columns)
        if (
            batch_col
            and out_col
            and isinstance(batch_col[0], dict)
            and isinstance(out_col[0], dict)
        ):
            # Dict-valued columns are merged example-wise
            combined[k] = [
                recmerge(b_i, o_i) for b_i, o_i in zip(batch_col, out_col)
            ]
        else:
            # Non-dict columns: the output overwrites the original
            combined[k] = out_col
    return combined
def store(batch: Batch, updates: List[Dict]) -> Batch:
    """Updates the cache of preprocessed information stored with each
    example in a batch.

    Args:
        batch: a batch of data
        updates: a list of dictionaries, one per example

    Returns:
        updated batch
    """
    # Create an empty cache dictionary per example if none exists yet
    if "cache" not in batch:
        batch["cache"] = [{} for _ in batch["index"]]

    # Recursively fold each per-example update into that example's cache
    batch["cache"] = [
        recmerge(existing, update)
        for existing, update in zip(batch["cache"], updates)
    ]
    return batch
def test_recmerge(self):
    """Check recursive dictionary merging, with and without sequence
    concatenation."""
    # Without merge_sequences, later values simply replace earlier ones;
    # nested dicts are merged key-by-key. Fresh literals are built for
    # each call in case recmerge mutates its arguments.
    merged = recmerge(
        {"a": 2, "b": 3, "d": {"e": [1, 2, 3], "f": [3, 4, 5]}, "g": 17},
        {"b": 12, "d": {"e": [1, 2, 3], "f": [3, 4]}},
        {"a": 4, "d": {"f": [3]}},
    )
    expected = {"a": 4, "b": 12, "d": {"e": [1, 2, 3], "f": [3]}, "g": 17}
    self.assertEqual(merged, expected)

    # With merge_sequences=True, lists at the same key are concatenated
    # left-to-right instead of replaced
    merged = recmerge(
        {"a": 2, "b": 3, "d": {"e": [1, 2, 3], "f": [3, 4, 5]}, "g": 17},
        {"b": 12, "d": {"e": [1, 2, 3], "f": [3, 4]}},
        {"a": 4, "d": {"f": [3]}},
        merge_sequences=True,
    )
    expected = {
        "a": 4,
        "b": 12,
        "d": {"e": [1, 2, 3, 1, 2, 3], "f": [3, 4, 5, 3, 4, 3]},
        "g": 17,
    }
    self.assertEqual(merged, expected)
def __call__(
    self,
    batch_or_dataset: Union[Batch, Dataset],
    columns: List[str],
    mask: List[int] = None,
    store_compressed: bool = None,
    store: bool = None,
    num_proc: int = None,
    *args,
    **kwargs,
):
    """Apply every subpopulation slicebuilder in the collection.

    Args:
        batch_or_dataset: a batch of examples or a Dataset
        columns: the columns to slice on
        mask: not supported yet; must be falsy
        store_compressed: passed through to each slicebuilder
        store: passed through to each slicebuilder
        num_proc: if > 1, run the slicebuilders in a process pool

    Returns:
        a (batch_or_dataset, slices, slice_membership) tuple, where
        slice_membership is the per-subpopulation membership matrices
        concatenated along axis 1

    Raises:
        NotImplementedError: if a truthy mask is passed
    """
    if mask:
        raise NotImplementedError(
            "Mask not supported for SubpopulationCollection yet.")

    # Only bound by the multiprocessing path; used below to decide
    # whether per-worker results still need to be merged together.
    batches_or_datasets = None

    if not num_proc or num_proc == 1:
        slices = []
        slice_membership = []
        # Apply each slicebuilder in sequence, threading the (possibly
        # updated) batch_or_dataset through each call
        for slicebuilder in tqdm(self.subpopulations):
            batch_or_dataset, slices_i, slice_membership_i = slicebuilder(
                batch_or_dataset=batch_or_dataset,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            )
            # Add in the slices and slice membership
            slices.extend(slices_i)
            slice_membership.append(slice_membership_i)
    else:
        # TODO(karan): cleanup, make mp.Pool support simpler across the library
        # NOTE(review): the stdlib multiprocessing.Pool cannot pickle
        # lambdas -- this presumably relies on a pathos-style Pool that
        # can; confirm which Pool is imported here.
        with Pool(num_proc) as pool:
            batches_or_datasets, slices, slice_membership = zip(*pool.map(
                lambda sb: sb(
                    batch_or_dataset=batch_or_dataset,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ),
                [slicebuilder for slicebuilder in self.subpopulations],
            ))
        # Combine all the slices
        slices = list(tz.concat(slices))

    def _store_updates(batch, indices):
        # Each Subpopulation will generate slices
        for i, subpopulation in enumerate(self.subpopulations):
            updates = subpopulation.construct_updates(
                slice_membership=slice_membership[i][indices],
                columns=columns,
                mask=mask,
                # TODO(karan): this option should be set correctly
                compress=True,
            )
            batch = subpopulation.store(
                batch=batch,
                updates=updates,
            )
        return batch

    if isinstance(batch_or_dataset, Dataset):
        batch_or_dataset = batch_or_dataset.map(
            _store_updates,
            with_indices=True,
            batched=True,
        )
        for subpopulation in self.subpopulations:
            # Update the Dataset's history
            batch_or_dataset.update_tape(
                path=[SLICEBUILDERS, subpopulation.category],
                identifiers=subpopulation.identifiers,
                columns=columns,
            )
    elif batches_or_datasets is not None:
        # Multiprocessing path: each worker saw its own copy of the
        # batch, so merge the per-worker results back together.
        # Fix: previously this ran unconditionally and raised NameError
        # on the sequential path, where batch_or_dataset has already
        # been updated in place by each slicebuilder.
        batch_or_dataset = recmerge(*batches_or_datasets, merge_sequences=True)

    # Combine all the slice membership matrices
    slice_membership = np.concatenate(slice_membership, axis=1)

    return batch_or_dataset, slices, slice_membership