Ejemplo n.º 1
0
    def remap_schema(self, dataset: Dataset):
        """Rewrite the dataset so its columns use this object's schema names.

        Both ``self.input_schema`` and ``self.output_schema`` are grounded
        against ``dataset.features``. Each grounding call returns a pair: a
        lookup that is indexed by schema feature name to obtain the dataset
        column to read, and a reversed lookup whose keys are the dataset
        columns to drop once the remapping is done.

        Args:
            dataset: the dataset whose columns should be remapped.

        Returns:
            Whatever ``dataset.map`` returns for the remapping function, with
            the grounded source columns removed.
        """
        # Ground the input schema first, then the output schema, each against
        # a fresh read of the dataset's features.
        in_lookup, in_reverse = self.input_schema.ground(dataset.features)
        out_lookup, out_reverse = self.output_schema.ground(dataset.features)

        def _remap_example(example):
            # Pull every input-schema feature from its grounded column.
            remapped_inputs = {}
            for name in self.input_schema.features:
                remapped_inputs[name] = example[in_lookup[name]]

            # Same for the output schema.
            remapped_outputs = {}
            for name in self.output_schema.features:
                remapped_outputs[name] = example[out_lookup[name]]

            # Outputs are merged last, so they win on any name collision.
            return tz.merge(remapped_inputs, remapped_outputs)

        # Columns that were grounded are dropped after the remap.
        grounded_columns = [*in_reverse.keys(), *out_reverse.keys()]

        return dataset.map(_remap_example, remove_columns=grounded_columns)
Ejemplo n.º 2
0
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[Dataset, List[Slice], np.ndarray]:
        """Apply a SliceBuilder to a dataset.

        The dataset is first passed through ``prepare_dataset``, then every
        batch is run through ``process_batch``. The per-batch slice
        memberships are concatenated into one matrix and one Slice is built
        per slice index via ``create_slice``.

        Args:
            dataset: Dataset
            columns: list of columns
            batch_size: integer batch size
            mask: boolean or integer mask array, mask[i] = True means that the ith
            slice will be masked out. ``None`` is treated as "nothing is
            masked" when computing the cache hash; it is forwarded unchanged
            to ``prepare_dataset`` and ``process_batch``.
            store_compressed: whether to store in a compressed format
            store: whether to store the results along with the example in Dataset
            num_proc: num processes for multiprocessing; a falsy value or 1
            builds the slices serially in this process
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments

        Returns: tuple of (Dataset, list of Slices, matrix of (example,
        slice) membership)
        """
        # Prepare the dataset
        dataset = self.prepare_dataset(
            dataset=dataset,
            columns=columns,
            batch_size=batch_size,
            mask=mask,
            store_compressed=store_compressed,
            store=store,
            *args,
            **kwargs,
        )

        # Compute a hash: XOR of the dataset identity, its interaction
        # history, and every identifier that is not masked out.
        val = persistent_hash(str(
            dataset.identifier)) ^ dataset.hash_interactions()
        for i, identifier in enumerate(self.identifiers):
            # mask defaults to None; indexing it would raise a TypeError, so
            # a missing mask is treated as "no identifier is masked".
            if mask is None or not mask[i]:
                val ^= persistent_hash(
                    str(identifier) + str(strings_as_json(columns)))

        # The cache file name is a XOR of the interaction history and the
        # current operation. Both the mapped path and the fallback path
        # write to the same file, so the name is built once.
        cache_file_name = str(dataset.logdir /
                              ("cache-" + str(abs(val)) + ".arrow"))

        try:
            # Map the SliceBuilder over the dataset
            all_sliced_batches = []
            all_slice_memberships = []

            def _map_fn(batch):
                """Map function for processing batches.

                Note that using this map_fn in a stateful way is
                dangerous, since every invocation of this function
                appends to the all_slice_batches list. The .map()
                function will invoke this once for testing before
                performing the map, so we discard the first entry
                inserted into all_sliced_batches.
                """
                batch, sliced_batches, slice_membership = self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                )
                all_sliced_batches.append(sliced_batches)
                all_slice_memberships.append(slice_membership)
                return batch

            dataset = dataset.map(
                _map_fn,
                batched=True,
                batch_size=batch_size,
                # FIXME(karan): enable this by adding logic for generating
                #  all_sliced_batches and all_slice_memberships
                #  when loading from cache file
                load_from_cache_file=False,
                cache_file_name=cache_file_name,
            )

            # Remove the first entry (see _map_fn)
            # NOTE(review): this relies on .map() making one extra trial
            # call; confirm that still holds for the installed version.
            all_sliced_batches = all_sliced_batches[1:]
            all_slice_memberships = all_slice_memberships[1:]

        except Exception:  # noqa
            # Deliberately broad best-effort fallback, but no longer a bare
            # ``except`` so KeyboardInterrupt / SystemExit propagate instead
            # of silently triggering a second full pass over the data.
            # Batch the dataset, and process each batch
            all_batches, all_sliced_batches, all_slice_memberships = zip(*[
                self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ) for batch in dataset.batch(batch_size)
            ])

            # Update the dataset efficiently by reusing all_batches: the first
            # index of each mapped batch identifies which precomputed batch
            # to return.
            dataset = dataset.map(
                lambda examples, indices: all_batches[indices[0] // batch_size
                                                      ],
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                load_from_cache_file=False,
                cache_file_name=cache_file_name,
            )

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        # One cache hash per identifier, derived from the operation hash
        slice_cache_hashes = [
            val ^ persistent_hash(str(identifier))
            for identifier in self.identifiers
        ]

        # Transpose the per-batch results so each entry holds all batches
        # of one slice, and bundle the arguments create_slice expects.
        slice_args = [(
            dataset,
            slice_membership,
            slice_batches,
            i,
            batch_size,
            slice_cache_hashes[i],
        ) for i, slice_batches in enumerate(zip(*all_sliced_batches))]

        if not num_proc or num_proc == 1:
            # Construct slices
            slices = [create_slice(slice_arg) for slice_arg in slice_args]
        else:
            # Parallelized slice construction
            with Pool(num_proc) as pool:
                slices = pool.map(create_slice, slice_args)

        # TODO(karan): make this more systematic
        # TODO(karan): fix bug when slicing a Slice
        for i, sl in enumerate(slices):
            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Lineage entry describing this SliceBuilder application
            lineage_step = (
                str(self.category.capitalize()),
                self.identifiers[i],
                strings_as_json(columns),
            )

            # Create the lineage
            if isinstance(dataset, Slice):
                # Prepend the Slice's lineage instead, if the dataset was a slice
                sl.lineage = dataset.lineage + [lineage_step]
            else:
                sl.lineage = [
                    (str(Dataset.__name__), dataset.identifier),
                    lineage_step,
                ]

        return dataset, slices, slice_membership
Ejemplo n.º 3
0
    def prepare_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        *args,
        **kwargs,
    ) -> Dataset:
        """Run ``prepare_batch`` over every batch of the dataset.

        The method first tries to map ``prepare_batch`` over the dataset
        directly. If that raises, it falls back to preparing every batch
        eagerly in this process and then mapping the precomputed batches
        back onto the dataset.

        Args:
            dataset: Dataset
            columns: list of columns
            batch_size: integer batch size
            mask: mask array indexed by identifier position; identifiers
            whose entry is truthy are left out of the cache hash. ``None``
            is treated as "nothing is masked" for the hash and is forwarded
            unchanged to ``prepare_batch``.
            store_compressed: forwarded to ``prepare_batch``
            store: forwarded to ``prepare_batch``
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments

        Returns: the result of mapping the prepared batches over the dataset
        """

        # Compute the hash for this operation
        # FIXME(karan): this is repeated inside process_dataset
        val = persistent_hash(str(
            dataset.identifier)) ^ dataset.hash_interactions()
        for i, identifier in enumerate(self.identifiers):
            # mask defaults to None; indexing it would raise a TypeError, so
            # a missing mask is treated as "no identifier is masked".
            if mask is None or not mask[i]:
                val ^= persistent_hash(
                    str(identifier) + str(strings_as_json(columns)))

        # Both the mapped path and the fallback path write to the same cache
        # file, so the name is built once.
        cache_file_name = str(
            dataset.logdir /
            ("cache-" + str(abs(val)) + "-prep.arrow"))

        try:
            return dataset.map(
                partial(
                    self.prepare_batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ),
                batched=True,
                batch_size=batch_size,
                load_from_cache_file=False,
                cache_file_name=cache_file_name,
            )
        except Exception:  # TypeError or PicklingError or AttributeError: # noqa
            # Deliberately broad best-effort fallback, but no longer a bare
            # ``except`` so KeyboardInterrupt / SystemExit propagate instead
            # of silently triggering a second full pass over the data.
            # Batch the dataset, and process each batch
            all_batches = [
                self.prepare_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ) for batch in dataset.batch(batch_size)
            ]

            # Update the dataset efficiently by reusing all_batches: the first
            # index of each mapped batch identifies which precomputed batch
            # to return.
            return dataset.map(
                lambda examples, indices: all_batches[indices[0] // batch_size
                                                      ],
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                load_from_cache_file=False,
                cache_file_name=cache_file_name,
            )