Example #1
    def prepare_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        *args,
        **kwargs,
    ) -> None:
        """Apply a preparation function to the dataset. Use this to update
        attributes of `self`.

        Args:
            dataset: dataset
            columns: list of columns
            batch_size: batch size for preparation
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments
        """
        # Set the data format
        with dataset.format(
            columns + self._filter_prerequisite_columns(columns, dataset.column_names)
        ):
            # Batch the dataset, and prepare each batch
            for batch in dataset.batch(batch_size):
                try:
                    # Check if the `prepare_batch` function has been implemented
                    self.prepare_batch(
                        batch=batch,
                        columns=columns,
                        *args,
                        **kwargs,
                    )
                except NotImplementedError:
                    break
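
The pattern above is: batch the dataset, call `prepare_batch` on each batch, and stop as soon as `prepare_batch` raises `NotImplementedError`. Below is a minimal, self-contained sketch of that loop over a toy column-oriented dataset; the names (`VocabPreparer`, `batch_iter`) are illustrative stand-ins, not part of the library.

    from typing import Dict, Iterator, List


    def batch_iter(rows: List[Dict[str, str]], batch_size: int) -> Iterator[Dict[str, List[str]]]:
        """Yield column-oriented batches, mimicking dataset.batch(batch_size)."""
        for start in range(0, len(rows), batch_size):
            chunk = rows[start:start + batch_size]
            yield {key: [row[key] for row in chunk] for key in chunk[0]}


    class VocabPreparer:
        """Toy operation that updates its own state (a vocabulary) batch by batch."""

        def __init__(self) -> None:
            self.vocab: set = set()

        def prepare_batch(self, batch: Dict[str, List[str]], columns: List[str]) -> None:
            for column in columns:
                for text in batch[column]:
                    self.vocab.update(text.split())

        def prepare_dataset(self, rows: List[Dict[str, str]], columns: List[str],
                            batch_size: int = 2) -> None:
            for batch in batch_iter(rows, batch_size):
                try:
                    # Stop cleanly if prepare_batch is not implemented by a subclass.
                    self.prepare_batch(batch=batch, columns=columns)
                except NotImplementedError:
                    break


    preparer = VocabPreparer()
    preparer.prepare_dataset([{"text": "a b"}, {"text": "b c"}, {"text": "c d"}],
                             columns=["text"])
    print(sorted(preparer.vocab))  # ['a', 'b', 'c', 'd']
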
Example #2
    def prepare_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        *args,
        **kwargs,
    ) -> None:
        """Preparation that is applied before the CachedOperation.

        Many CachedOperations require a full pass over the dataset to precompute some
        variables before the core operation can actually be applied e.g. to create a
        Bag-of-Words representation, constructing a dataset vocabulary to keep only
        tokens that are frequently seen across the dataset.

        Args:
            dataset: Dataset
            columns: list of columns
            batch_size: batch size for .map(..)

        Returns: updated Dataset
        """

        # Set the data format
        dataset.set_format(columns)

        # Batch the dataset, and prepare each batch
        for batch in dataset.batch(batch_size):
            try:
                # Check if the `prepare_batch` function has been implemented
                self.prepare_batch(
                    batch=batch,
                    columns=columns,
                    *args,
                    **kwargs,
                )
            except NotImplementedError:
                break

        # Reset the data format
        dataset.reset_format()
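
As the docstring notes, a typical preparation pass precomputes statistics such as token frequencies so that only frequently seen tokens are kept. The sketch below shows one way a `prepare_batch` implementation could accumulate such counts; `FrequencyVocab` and `min_count` are assumed names for illustration, not the library's API.

    from collections import Counter
    from typing import Dict, List


    class FrequencyVocab:
        """Toy preparation state: token counts accumulated across batches."""

        def __init__(self, min_count: int = 2) -> None:
            self.counts: Counter = Counter()
            self.min_count = min_count

        def prepare_batch(self, batch: Dict[str, List[str]], columns: List[str]) -> None:
            # One pass over the batch: tally every token in the requested columns.
            for column in columns:
                for text in batch[column]:
                    self.counts.update(text.split())

        @property
        def vocab(self) -> List[str]:
            # Keep only tokens seen at least `min_count` times across the dataset.
            return sorted(t for t, c in self.counts.items() if c >= self.min_count)


    fv = FrequencyVocab(min_count=2)
    fv.prepare_batch({"text": ["the cat", "the dog"]}, columns=["text"])
    fv.prepare_batch({"text": ["a cat"]}, columns=["text"])
    print(fv.vocab)  # ['cat', 'the']
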
Example #3
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[Dataset, List[Slice], np.ndarray]:
        """Apply a SliceBuilder to a dataset.

        Args:
            dataset: Dataset
            columns: list of columns
            batch_size: integer batch size
            mask: boolean or integer mask array; mask[i] = True means that the
                ith slice will be masked out
            store_compressed: whether to store in a compressed format
            store: whether to store the results along with the example in Dataset
            num_proc: num processes for multiprocessing
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments

        Returns: tuple of (Dataset, list of Slices, matrix of (example,
        slice) membership)
        """
        # Prepare the dataset
        dataset = self.prepare_dataset(
            dataset=dataset,
            columns=columns,
            batch_size=batch_size,
            mask=mask,
            store_compressed=store_compressed,
            store=store,
            *args,
            **kwargs,
        )

        # Compute a hash
        val = persistent_hash(str(
            dataset.identifier)) ^ dataset.hash_interactions()
        for i, identifier in enumerate(self.identifiers):
            if not mask[i]:
                val ^= persistent_hash(
                    str(identifier) + str(strings_as_json(columns)))

        try:
            # Map the SliceBuilder over the dataset
            all_sliced_batches = []
            all_slice_memberships = []

            def _map_fn(batch):
                """Map function for processing batches.

                Note that using this map_fn in a stateful way is
                dangerous, since every invocation of this function
                appends to the all_sliced_batches list. The .map()
                function will invoke this once for testing before
                performing the map, so we discard the first entry
                inserted into all_sliced_batches.
                """
                batch, sliced_batches, slice_membership = self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                )
                all_sliced_batches.append(sliced_batches)
                all_slice_memberships.append(slice_membership)
                return batch

            dataset = dataset.map(
                _map_fn,
                batched=True,
                batch_size=batch_size,
                # FIXME(karan): enable this by adding logic for generating
                #  all_sliced_batches and all_slice_memberships
                #  when loading from cache file
                load_from_cache_file=False,
                # The cache file name is a XOR of the interaction history and the
                # current operation
                cache_file_name=str(dataset.logdir /
                                    ("cache-" + str(abs(val)) + ".arrow")),
            )

            # Remove the first entry (see _map_fn)
            all_sliced_batches = all_sliced_batches[1:]
            all_slice_memberships = all_slice_memberships[1:]

        except:  # noqa
            # Batch the dataset, and process each batch
            all_batches, all_sliced_batches, all_slice_memberships = zip(*[
                self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ) for batch in dataset.batch(batch_size)
            ])

            # Update the dataset efficiently by reusing all_batches
            dataset = dataset.map(
                lambda examples, indices: all_batches[indices[0] // batch_size],
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                load_from_cache_file=False,
                # The cache file name is a XOR of the interaction history and the
                # current operation
                cache_file_name=str(dataset.logdir /
                                    ("cache-" + str(abs(val)) + ".arrow")),
            )

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        slice_cache_hashes = []
        for identifier in self.identifiers:
            slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))

        if not num_proc or num_proc == 1:
            # Construct slices
            slices = []
            for i, slice_batches in enumerate(zip(*all_sliced_batches)):
                slices.append(
                    create_slice((
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )))
        else:
            # Parallelized slice construction
            with Pool(num_proc) as pool:
                slices = pool.map(
                    create_slice,
                    [(
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )
                     for i, slice_batches in enumerate(zip(
                         *all_sliced_batches))],
                )

        # TODO(karan): make this more systematic
        # TODO(karan): fix bug when slicing a Slice
        for i, sl in enumerate(slices):
            # # Set the Slice features
            # sl.info.features = dataset.features

            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Create the lineage
            sl.lineage = [
                (str(Dataset.__name__), dataset.identifier),
                (
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                ),
            ]
            if isinstance(dataset, Slice):
                # Prepend the Slice's lineage instead, if the dataset was a slice
                sl.lineage = dataset.lineage + [(
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                )]

        return dataset, slices, slice_membership
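
The core bookkeeping here is the slice-membership matrix: each processed batch contributes a boolean `(batch_size, num_slices)` block, the blocks are concatenated along axis 0, and each slice is then materialized from the rows where its column is set. A minimal, self-contained sketch of that bookkeeping follows, with a toy `process_batch` and hypothetical slice definitions, not the library's API.

    from typing import Dict, List

    import numpy as np


    def process_batch(batch: Dict[str, List[str]], columns: List[str]) -> np.ndarray:
        """Toy slicer with two slices: 'short text' and 'contains a digit'."""
        texts = batch[columns[0]]
        membership = np.zeros((len(texts), 2), dtype=bool)
        membership[:, 0] = [len(t) < 10 for t in texts]
        membership[:, 1] = [any(ch.isdigit() for ch in t) for t in texts]
        return membership


    rows = [{"text": "hi"}, {"text": "a very long sentence indeed"}, {"text": "call 911"}]
    batch_size = 2
    all_memberships = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        batch = {"text": [r["text"] for r in chunk]}
        all_memberships.append(process_batch(batch, columns=["text"]))

    # Single (num_examples x num_slices) membership matrix, as in the code above.
    slice_membership = np.concatenate(all_memberships, axis=0)

    # Materialize each slice as the list of row indices it contains.
    slices = [np.where(slice_membership[:, i])[0].tolist()
              for i in range(slice_membership.shape[1])]
    print(slices)  # [[0, 2], [2]]
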
Example #4
    def prepare_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        *args,
        **kwargs,
    ) -> Dataset:

        # Compute the hash for this operation
        # FIXME(karan): this is repeated inside process_dataset
        val = persistent_hash(str(
            dataset.identifier)) ^ dataset.hash_interactions()
        for i, identifier in enumerate(self.identifiers):
            if not mask[i]:
                val ^= persistent_hash(
                    str(identifier) + str(strings_as_json(columns)))

        try:
            return dataset.map(
                partial(
                    self.prepare_batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ),
                batched=True,
                batch_size=batch_size,
                load_from_cache_file=False,
                cache_file_name=str(
                    dataset.logdir /
                    ("cache-" + str(abs(val)) + "-prep.arrow")),
            )
        except:  # TypeError or PicklingError or AttributeError: # noqa
            # Batch the dataset, and process each batch
            all_batches = [
                self.prepare_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ) for batch in dataset.batch(batch_size)
            ]

            # Update the dataset efficiently by reusing all_batches
            return dataset.map(
                lambda examples, indices: all_batches[indices[0] // batch_size],
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                load_from_cache_file=False,
                cache_file_name=str(
                    dataset.logdir /
                    ("cache-" + str(abs(val)) + "-prep.arrow")),
            )
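
The cache file name above is derived by XOR-combining a persistent hash of the dataset identifier with the hash of each unmasked `(identifier, columns)` pair. A small sketch of that scheme follows; `stable_hash` is an assumed stand-in for the library's `persistent_hash` (Python's built-in `hash()` is salted per process, so it would not produce stable cache keys).

    import hashlib
    import json
    from typing import List


    def stable_hash(s: str) -> int:
        """Deterministic 64-bit hash of a string (stand-in for persistent_hash)."""
        return int.from_bytes(hashlib.sha256(s.encode("utf-8")).digest()[:8], "big")


    def cache_key(dataset_id: str, identifiers: List[str],
                  columns: List[str], mask: List[int]) -> str:
        val = stable_hash(dataset_id)
        for i, identifier in enumerate(identifiers):
            if not mask[i]:
                # XOR keeps the key independent of the order of unmasked identifiers.
                val ^= stable_hash(identifier + json.dumps(columns))
        return "cache-" + str(abs(val)) + "-prep.arrow"


    print(cache_key("snli-train", ["spacy", "stanza"], ["premise"], mask=[0, 1]))
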
Example #5
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        # mask: List[int] = None,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[List[Slice], np.ndarray]:

        # Create slices using the dataset
        slices = [Slice(dataset) for _ in range(len(self.identifiers))]
        all_slice_memberships = []
        # Batch the dataset, and process each batch
        for batch in dataset.batch(batch_size):
            # Process the batch
            _, slice_memberships = self.process_batch(
                batch=batch,
                columns=columns,
                *args,
                **kwargs,
            )

            # Keep track of the slice memberships
            all_slice_memberships.append(slice_memberships)

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        for i, sl in enumerate(slices):
            # Set the visible rows for each slice
            sl.set_visible_rows(np.where(slice_membership[:, i])[0])

            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Append to the lineage
            sl.add_to_lineage(
                category=str(self.category.capitalize()),
                identifier=self.identifiers[i],
                columns=strings_as_json(columns),
            )

            # # Create the lineage
            # sl.lineage = [
            #     (str(Dataset.__name__), dataset.identifier),
            #     (
            #         str(self.category.capitalize()),
            #         self.identifiers[i],
            #         strings_as_json(columns),
            #     ),
            # ]
            # if isinstance(dataset, Slice):
            #     # Prepend the Slice's lineage instead, if the dataset was a slice
            #     sl.lineage = dataset.lineage + [
            #         (
            #             str(self.category.capitalize()),
            #             self.identifiers[i],
            #             strings_as_json(columns),
            #         )
            #     ]

        return slices, slice_membership
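
Here each `Slice` is a view over the full dataset whose visible rows are the indices where its membership column is true. A minimal sketch of that idea with a hypothetical `SliceView` class (not the library's `Slice`):

    from typing import Dict, List

    import numpy as np


    class SliceView:
        """A slice as a view over the dataset: only 'visible' rows are exposed."""

        def __init__(self, rows: List[Dict[str, str]]) -> None:
            self._rows = rows
            self._visible = np.arange(len(rows))

        def set_visible_rows(self, indices: np.ndarray) -> None:
            self._visible = np.asarray(indices)

        def __len__(self) -> int:
            return len(self._visible)

        def __getitem__(self, i: int) -> Dict[str, str]:
            return self._rows[int(self._visible[i])]


    rows = [{"text": "yes"}, {"text": "no"}, {"text": "maybe"}]
    # (num_examples x num_slices) membership matrix, e.g. from process_batch calls.
    membership = np.array([[1, 0], [0, 1], [1, 1]], dtype=bool)

    slices = [SliceView(rows) for _ in range(membership.shape[1])]
    for i, sl in enumerate(slices):
        sl.set_visible_rows(np.where(membership[:, i])[0])

    print(len(slices[0]), slices[0][1])  # 2 {'text': 'maybe'}
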
Example #6
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        # mask: List[int] = None,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[List[Slice], np.ndarray]:
        """Apply a SliceBuilder to a dataset.

        Args:
            dataset: Dataset
            columns: list of columns
            batch_size: integer batch size
            # mask: boolean or integer mask array, mask[i] = True means that the ith
            # slice will be masked out
            num_proc: num processes for multiprocessing
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments

        Returns: tuple of (list of Slices, matrix of (example, slice)
        membership)
        """

        # # Compute a hash
        # val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
        # for i, identifier in enumerate(self.identifiers):
        #     if not mask[i]:
        #         val ^= persistent_hash(str(identifier)
        #         + str(strings_as_json(columns)))

        # try:
        #     # Map the SliceBuilder over the dataset
        #     all_sliced_batches = []
        #     all_slice_memberships = []
        #
        #     def _map_fn(batch):
        #         """Map function for processing batches.
        #
        #         Note that using this map_fn in a stateful way is
        #         dangerous, since every invocation of this function
        #         appends to the all_slice_batches list. The .map()
        #         function will invoke this once for testing before
        #         performing the map, so we discard the first entry
        #         inserted into all_sliced_batches.
        #         """
        #         batch, sliced_batches, slice_membership = self.process_batch(
        #             batch=batch,
        #             columns=columns,
        #             mask=mask,
        #             store_compressed=store_compressed,
        #             store=store,
        #             *args,
        #             **kwargs,
        #         )
        #         all_sliced_batches.append(sliced_batches)
        #         all_slice_memberships.append(slice_membership)
        #         return batch
        #
        #     dataset = dataset.map(
        #         _map_fn,
        #         batched=True,
        #         batch_size=batch_size,
        #         # FIXME(karan): enable this by adding logic for generating
        #         #  all_sliced_batches and all_slice_memberships
        #         #  when loading from cache file
        #         load_from_cache_file=False,
        #         # The cache file name is a XOR of the interaction history and the
        #         # current operation
        #         cache_file_name=str(
        #             dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
        #         ),
        #     )
        #
        #     # Remove the first entry (see _map_fn)
        #     all_sliced_batches = all_sliced_batches[1:]
        #     all_slice_memberships = all_slice_memberships[1:]
        #
        # except:  # noqa
        # all_batches, all_sliced_batches, all_slice_memberships = zip(
        #     *[
        #         self.process_batch(
        #             batch=batch,
        #             columns=columns,
        #             mask=mask,
        #             store_compressed=store_compressed,
        #             store=store,
        #             *args,
        #             **kwargs,
        #         )
        #         for batch in dataset.batch(batch_size)
        #     ]
        # )
        # # Update the dataset efficiently by reusing all_batches
        # dataset = dataset.map(
        #     lambda examples, indices: all_batches[indices[0] // batch_size],
        #     batched=True,
        #     batch_size=batch_size,
        #     with_indices=True,
        #     load_from_cache_file=False,
        #     # The cache file name is a XOR of the interaction history and the
        #     # current operation
        #     cache_file_name=str(
        #         dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
        #     ),
        # )

        # Create slices
        slices = [Slice() for _ in range(len(self.identifiers))]
        all_slice_memberships = []
        # Batch the dataset, and process each batch
        for batch in dataset.batch(batch_size):
            # Process the batch
            sliced_batches, slice_memberships = self.process_batch(
                batch=batch,
                columns=columns,
                *args,
                **kwargs,
            )

            # Incrementally build the slices
            for sl, sl_batch in zip(slices, sliced_batches):
                sl._dataset.append(sl_batch)

            # Keep track of the slice memberships
            all_slice_memberships.append(slice_memberships)

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        # slice_cache_hashes = []
        # for identifier in self.identifiers:
        #     slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))

        # if not num_proc or num_proc == 1:
        #     # Construct slices
        #     slices = []
        #     for i, slice_batches in enumerate(zip(*all_sliced_batches)):
        #         slices.append(
        #             create_slice(
        #                 (
        #                     dataset,
        #                     slice_membership,
        #                     slice_batches,
        #                     i,
        #                     batch_size,
        #                     slice_cache_hashes[i],
        #                 )
        #             )
        #         )
        # else:
        #     # Parallelized slice construction
        #     with Pool(num_proc) as pool:
        #         slices = pool.map(
        #             create_slice,
        #             [
        #                 (
        #                     dataset,
        #                     slice_membership,
        #                     slice_batches,
        #                     i,
        #                     batch_size,
        #                     slice_cache_hashes[i],
        #                 )
        #                 for i, slice_batches in enumerate(zip(*all_sliced_batches))
        #             ],
        #         )

        for i, sl in enumerate(slices):
            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Append to the lineage
            sl.add_to_lineage(
                category=str(self.category.capitalize()),
                identifier=self.identifiers[i],
                columns=strings_as_json(columns),
            )

            # # Create the lineage
            # sl.lineage = [
            #     (str(Dataset.__name__), dataset.identifier),
            #     (
            #         str(self.category.capitalize()),
            #         self.identifiers[i],
            #         strings_as_json(columns),
            #     ),
            # ]
            # if isinstance(dataset, Slice):
            #     # Prepend the Slice's lineage instead, if the dataset was a slice
            #     sl.lineage = dataset.lineage + [
            #         (
            #             str(self.category.capitalize()),
            #             self.identifiers[i],
            #             strings_as_json(columns),
            #         )
            #     ]

        return slices, slice_membership
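
This variant builds each slice incrementally, appending the sliced batch produced by `process_batch` to the slice's backing data and concatenating the per-batch membership matrices at the end. A self-contained sketch of that accumulation pattern with toy column-batch dictionaries (illustrative only, not the library's objects):

    from typing import Dict, List, Tuple

    import numpy as np

    Batch = Dict[str, List[str]]


    def process_batch(batch: Batch, columns: List[str]) -> Tuple[List[Batch], np.ndarray]:
        """Toy builder with a single slice: rows whose text is all uppercase."""
        texts = batch[columns[0]]
        keep = np.array([t.isupper() for t in texts], dtype=bool)
        sliced = {columns[0]: [t for t, k in zip(texts, keep) if k]}
        return [sliced], keep[:, None]  # one slice -> one membership column


    batches = [{"text": ["OK", "fine"]}, {"text": ["GO", "STOP"]}]
    slice_parts: List[List[Batch]] = [[]]  # one accumulator per slice
    all_memberships = []

    for batch in batches:
        sliced_batches, membership = process_batch(batch, columns=["text"])
        for parts, sl_batch in zip(slice_parts, sliced_batches):
            parts.append(sl_batch)  # mirrors sl._dataset.append(sl_batch)
        all_memberships.append(membership)

    slice_membership = np.concatenate(all_memberships, axis=0)
    # Merge each slice's accumulated batches into one column-oriented table.
    slices = [{"text": sum((p["text"] for p in parts), [])} for parts in slice_parts]
    print(slices, slice_membership[:, 0].tolist())
    # [{'text': ['OK', 'GO', 'STOP']}] [True, False, True, True]
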