Example 1
    def store(batch: Dict[str, List], updates: List[Dict]) -> Dict[str, List]:
        """Update a batch of examples with slice information.

        Args:
            batch: a batch of data; assumed to contain an "index" column
            updates: one update dictionary per example in the batch

        Returns: the batch, with its "slices" column merged with `updates`
        """
        # Lazily create one empty slice dict per example on first use.
        if "slices" not in batch:
            batch["slices"] = [{} for _ in batch["index"]]

        # Recursively fold each example's update into its existing slice
        # dictionary, concatenating sequence-valued entries.
        merged = []
        for existing_dict, update_dict in zip(batch["slices"], updates):
            merged.append(recmerge(existing_dict, update_dict, merge_sequences=True))
        batch["slices"] = merged

        return batch
Example 2
 def _merge_batch_and_output(cls, batch: Batch, output: Batch):
     """Merge an output during .map() into a batch.

     Columns produced by the mapped function are folded back into the
     batch: new columns are added; existing columns are overwritten,
     except that columns whose elements are dicts are merged recursively,
     example by example.

     Args:
         batch: the original batch of data (mutated in place)
         output: the output produced by mapping over `batch`

     Returns: the combined batch (the same object as `batch`)
     """
     combined = batch
     for k in output:
         if k not in batch:
             # New column: take it from the output wholesale.
             combined[k] = output[k]
         elif (
             batch[k]
             and output[k]
             and isinstance(batch[k][0], dict)
             and isinstance(output[k][0], dict)
         ):
             # Both columns hold dicts: recursively merge per example.
             # The non-empty guards avoid an IndexError on `[0]` when a
             # column is an empty list (the original code crashed here).
             combined[k] = [
                 recmerge(b_i, o_i) for b_i, o_i in zip(batch[k], output[k])
             ]
         else:
             # Existing non-dict (or empty) column: output overwrites it.
             combined[k] = output[k]
     return combined
Example 3
    def store(batch: Batch, updates: List[Dict]) -> Batch:
        """Updates the cache of preprocessed information stored with each
        example in a batch.

        Args:
            batch: a batch of data
            updates: a list of dictionaries, one per example

        Returns: updated batch
        """
        # Seed an empty cache dictionary for every example when absent.
        if "cache" not in batch:
            batch["cache"] = [{} for _ in batch["index"]]

        # Recursively fold each example's update into its existing cache.
        merged_caches = []
        for current_cache, update_dict in zip(batch["cache"], updates):
            merged_caches.append(recmerge(current_cache, update_dict))
        batch["cache"] = merged_caches

        return batch
Example 4
    def test_recmerge(self):
        """Test recursive dictionary merging via `recmerge`.

        Exercises the two modes demonstrated by the expected values below:
        the default mode, where later arguments win on conflicting keys
        and sequence values are replaced, and `merge_sequences=True`,
        where sequence values are concatenated in argument order instead.
        """
        # Default mode: nested dicts merge recursively; scalars and lists
        # from later arguments replace those from earlier ones ("a" ends
        # up 4, "b" 12, "d.f" [3]); keys absent from later arguments are
        # kept ("g" stays 17).
        output = recmerge(
            {
                "a": 2,
                "b": 3,
                "d": {
                    "e": [1, 2, 3],
                    "f": [3, 4, 5]
                },
                "g": 17
            },
            {
                "b": 12,
                "d": {
                    "e": [1, 2, 3],
                    "f": [3, 4]
                }
            },
            {
                "a": 4,
                "d": {
                    "f": [3]
                }
            },
        )
        self.assertEqual(output, {
            "a": 4,
            "b": 12,
            "d": {
                "e": [1, 2, 3],
                "f": [3]
            },
            "g": 17
        })

        # merge_sequences=True: identical inputs, but list values are now
        # concatenated across the arguments instead of replaced (e.g.
        # "d.f" becomes [3, 4, 5] + [3, 4] + [3]); non-sequence conflicts
        # still resolve in favor of the later argument.
        output = recmerge(
            {
                "a": 2,
                "b": 3,
                "d": {
                    "e": [1, 2, 3],
                    "f": [3, 4, 5]
                },
                "g": 17
            },
            {
                "b": 12,
                "d": {
                    "e": [1, 2, 3],
                    "f": [3, 4]
                }
            },
            {
                "a": 4,
                "d": {
                    "f": [3]
                }
            },
            merge_sequences=True,
        )
        self.assertEqual(
            output,
            {
                "a": 4,
                "b": 12,
                "d": {
                    "e": [1, 2, 3, 1, 2, 3],
                    "f": [3, 4, 5, 3, 4, 3]
                },
                "g": 17,
            },
        )
Example 5
    def __call__(
        self,
        batch_or_dataset: Union[Batch, Dataset],
        columns: List[str],
        mask: List[int] = None,
        store_compressed: bool = None,
        store: bool = None,
        num_proc: int = None,
        *args,
        **kwargs,
    ):
        """Apply every subpopulation slicebuilder in this collection.

        Args:
            batch_or_dataset: the batch or Dataset to slice
            columns: the columns the slicebuilders operate on
            mask: not supported here; any truthy value raises
            store_compressed: forwarded to each slicebuilder
            store: forwarded to each slicebuilder
            num_proc: number of worker processes; None or 1 runs serially

        Returns: a (batch_or_dataset, slices, slice_membership) tuple,
            where slice_membership concatenates each slicebuilder's
            membership matrix column-wise.
        """
        if mask:
            raise NotImplementedError(
                "Mask not supported for SubpopulationCollection yet.")

        if not num_proc or num_proc == 1:
            # Serial path: thread the (possibly updated) batch_or_dataset
            # through each slicebuilder in turn.
            slices = []
            slice_membership = []
            # Apply each slicebuilder in sequence
            # NOTE(review): tqdm wraps a bare enumerate() here, so the bar
            # has no total and cannot show progress percentage — confirm
            # whether tqdm(self.subpopulations) was intended.
            for i, slicebuilder in tqdm(enumerate(self.subpopulations)):
                # Apply the slicebuilder
                batch_or_dataset, slices_i, slice_membership_i = slicebuilder(
                    batch_or_dataset=batch_or_dataset,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                )

                # Add in the slices and slice membership
                slices.extend(slices_i)
                slice_membership.append(slice_membership_i)

        else:
            # Parallel path: run every slicebuilder against the *original*
            # batch_or_dataset in a worker pool, then reconcile the
            # per-worker copies afterwards.
            # TODO(karan): cleanup, make mp.Pool support simpler across the library
            # NOTE(review): a stdlib multiprocessing.Pool cannot pickle the
            # lambda below — presumably `Pool` is imported from a fork such
            # as multiprocess/pathos; confirm against the file's imports.
            with Pool(num_proc) as pool:
                batches_or_datasets, slices, slice_membership = zip(*pool.map(
                    lambda sb: sb(
                        batch_or_dataset=batch_or_dataset,
                        columns=columns,
                        mask=mask,
                        store_compressed=store_compressed,
                        store=store,
                        *args,
                        **kwargs,
                    ),
                    [slicebuilder for slicebuilder in self.subpopulations],
                ))

                # Combine all the slices
                slices = list(tz.concat(slices))

            def _store_updates(batch, indices):
                # Re-apply each subpopulation's cache updates to this batch:
                # the workers stored updates into their own copies, not into
                # the dataset being mapped here.

                # Each Subpopulation will generate slices
                for i, subpopulation in enumerate(self.subpopulations):
                    updates = subpopulation.construct_updates(
                        # NOTE(review): assumes slice_membership[i] supports
                        # fancy indexing by a list of row indices (e.g. a
                        # numpy array) — confirm its type.
                        slice_membership=slice_membership[i][indices],
                        columns=columns,
                        mask=mask,
                        # TODO(karan): this option should be set correctly
                        compress=True,
                    )

                    batch = subpopulation.store(
                        batch=batch,
                        updates=updates,
                    )

                return batch

            if isinstance(batch_or_dataset, Dataset):
                # Dataset case: write the updates back via a batched map,
                # then record each slicebuilder in the dataset's history.
                batch_or_dataset = batch_or_dataset.map(
                    _store_updates,
                    with_indices=True,
                    batched=True,
                )

                for subpopulation in self.subpopulations:
                    # Update the Dataset's history
                    batch_or_dataset.update_tape(
                        path=[SLICEBUILDERS, subpopulation.category],
                        identifiers=subpopulation.identifiers,
                        columns=columns,
                    )

            else:
                # Plain-batch case: recursively merge the per-worker copies,
                # concatenating sequence-valued entries.
                batch_or_dataset = recmerge(*batches_or_datasets,
                                            merge_sequences=True)

        # Combine all the slice membership matrices
        # (columns are stacked per slicebuilder; rows are examples).
        slice_membership = np.concatenate(slice_membership, axis=1)

        return batch_or_dataset, slices, slice_membership