Example #1
    def get_cache_hash(self, columns: Optional[List[str]] = None):
        """Construct a hash that will be used to identify the application of a
        cached operation to the columns of a dataset."""

        val = hash(self)
        if columns:
            for key in columns:
                val ^= persistent_hash(key)
        return val
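
These examples rely on a persistent_hash helper that, unlike the built-in hash(), must return the same value across interpreter runs, since it is used to name cache files and save directories. A minimal sketch under that assumption (the real implementation may differ):

import hashlib

def persistent_hash(s: str) -> int:
    # Hypothetical stand-in: derive a stable 64-bit integer from the string,
    # unlike the built-in hash(), which is salted per interpreter run.
    return int.from_bytes(hashlib.sha256(s.encode("utf-8")).digest()[:8], "little")
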
    def save(self, path: str) -> None:
        """Save the current testbench to disk. This will save all slices in the
        testbench to disk, as well as metrics and other metadata associated
        with this testbench.

        Args:
            path: string path to the save directory

        Returns: None

        >>> testbench = TestBench(identifier='my-testbench',
        ...                       task=TernaryNaturalLanguageInference())
        >>> # Save to the current directory
        >>> testbench.save('.')
        >>> # Load back the testbench
        >>> testbench = TestBench.load('my-testbench')
        """

        # Path to the save directory
        savedir = pathlib.Path(path) / f"{self.identifier}"

        # Create a directory inside savedir for the slices
        (savedir / "slices").mkdir(parents=True, exist_ok=True)

        # Save all the slices
        pbar = tqdm(self.slices)
        for sl in pbar:
            pbar.set_description(f"Saving slice {str(sl.identifier)[:100]}...")
            sl.save_to_disk(
                str(savedir / "slices" / str(persistent_hash(str(sl.identifier))))
            )

        # Save metrics
        dill.dump(self.metrics, open(str(savedir / "metrics.dill"), "wb"))

        # Save metadata
        dill.dump(
            {
                "task": self.task,
                "identifier": self.identifier,
                "dataset_id": self.dataset_id,
            },
            open(str(savedir / "metadata.dill"), "wb"),
        )

        # Save version info
        with open(str(savedir / "version.dill"), "wb") as f:
            f.write(self._dumps_version())
    def save(self, path: str) -> None:
        """Save the current testbench to disk. This will save all slices in the
        testbench to disk, as well as metrics and other metadata associated
        with this testbench.

        Args:
            path: string path to the save directory

        Returns: None
        """

        # Path to the save directory
        savedir = pathlib.Path(path) / f"{self.identifier}"

        # Create a directory inside savedir for the slices
        (savedir / "slices").mkdir(parents=True, exist_ok=True)

        # Save all the slices
        pbar = tqdm(self.slices)
        for sl in pbar:
            pbar.set_description(f"Saving slice {str(sl.identifier)[:100]}...")
            sl.save_to_disk(
                str(savedir / "slices" / str(persistent_hash(str(sl.identifier))))
            )

        # Save metrics
        dill.dump(self.metrics, open(str(savedir / "metrics.dill"), "wb"))

        # Save metadata
        dill.dump(
            {
                "task": self.task,
                "identifier": self.identifier,
                "dataset_id": self.dataset_id,
                "slices": [
                    (sl.identifier, sl.category, sl.lineage, len(sl))
                    for sl in self.slices
                ],
            },
            open(str(savedir / "metadata.dill"), "wb"),
        )

        # Save version info
        with open(str(savedir / "version.dill"), "wb") as f:
            f.write(self._dumps_version())
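
For completeness, a hedged sketch of reading the saved metadata back with dill; it simply mirrors the dump above, whereas the real TestBench.load referenced in the first docstring presumably reconstructs the whole testbench:

import pathlib
import dill

savedir = pathlib.Path(".") / "my-testbench"
with open(savedir / "metadata.dill", "rb") as f:
    metadata = dill.load(f)
# In this variant, metadata holds "task", "identifier", "dataset_id" and a
# "slices" list of (identifier, category, lineage, length) tuples.
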
Example #4
    def save(self, path: str) -> None:
        """Save the devbench to disk.

        This will save all slices in the devbench to disk, as well as
        metrics and aggregators associated with this devbench.

        Args:
            path: string path to the save directory e.g. "./my_analysis".
                A `.devbench` extension is added, so the devbench will be stored at
                "./my_analysis.devbench".

                To load the devbench back in, use `DevBench.load("./my_analysis")` or
                `DevBench.load("./my_analysis.devbench")`.
        """
        # Path to the save directory
        savedir = pathlib.Path(path)
        savedir = savedir.with_suffix(".devbench")

        # Create a directory inside savedir for the slices
        (savedir / "slices").mkdir(parents=True, exist_ok=True)

        # Save all the slices
        pbar = tqdm(self.slices)
        for sl in pbar:
            pbar.set_description(f"Saving slice {str(sl.identifier)[:100]}...")
            sl.write(str(savedir / "slices" / str(persistent_hash(str(sl.identifier)))))

        # Save metrics
        dill.dump(self.metrics, open(str(savedir / "metrics.dill"), "wb"))

        # Save aggregators
        dill.dump(self.aggregators, open(str(savedir / "aggregators.dill"), "wb"))

        # Save version info
        with open(str(savedir / "version.dill"), "wb") as f:
            f.write(self._dumps_version())
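
A short usage sketch of the round trip described in the docstring above, assuming an already constructed devbench object:

devbench.save("./my_analysis")             # written to "./my_analysis.devbench"
devbench = DevBench.load("./my_analysis")  # the ".devbench" suffix is optional
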
Example #5
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[Dataset, List[Slice], np.ndarray]:
        """Apply a SliceBuilder to a dataset.

        Args:
            dataset: Dataset to process
            columns: list of column names to apply the SliceBuilder to
            batch_size: integer batch size
            mask: boolean or integer mask array; mask[i] = True means that the
                i-th slice will be masked out
            store_compressed: whether to store in a compressed format
            store: whether to store the results along with the examples in Dataset
            num_proc: number of processes to use for parallel slice construction
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments

        Returns:
            tuple of (Dataset, list of Slices, matrix of (example, slice)
            membership)
        """
        # Default to an all-zeros mask (no slices masked out) if none was given
        if not mask:
            mask = [0] * len(self.identifiers)

        # Prepare the dataset
        dataset = self.prepare_dataset(
            dataset=dataset,
            columns=columns,
            batch_size=batch_size,
            mask=mask,
            store_compressed=store_compressed,
            store=store,
            *args,
            **kwargs,
        )

        # Compute a hash
        val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
        for i, identifier in enumerate(self.identifiers):
            if not mask[i]:
                val ^= persistent_hash(str(identifier) + str(strings_as_json(columns)))

        try:
            # Map the SliceBuilder over the dataset
            all_sliced_batches = []
            all_slice_memberships = []

            def _map_fn(batch):
                """Map function for processing batches.

                Note that using this map_fn in a stateful way is
                dangerous, since every invocation of this function
                appends to the all_slice_batches list. The .map()
                function will invoke this once for testing before
                performing the map, so we discard the first entry
                inserted into all_sliced_batches.
                """
                batch, sliced_batches, slice_membership = self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                )
                all_sliced_batches.append(sliced_batches)
                all_slice_memberships.append(slice_membership)
                return batch

            dataset = dataset.map(
                _map_fn,
                batched=True,
                batch_size=batch_size,
                # FIXME(karan): enable this by adding logic for generating
                #  all_sliced_batches and all_slice_memberships
                #  when loading from cache file
                load_from_cache_file=False,
                # The cache file name is a XOR of the interaction history and the
                # current operation
                cache_file_name=str(dataset.logdir /
                                    ("cache-" + str(abs(val)) + ".arrow")),
            )

            # Remove the first entry (see _map_fn)
            all_sliced_batches = all_sliced_batches[1:]
            all_slice_memberships = all_slice_memberships[1:]

        except Exception:  # e.g. if _map_fn cannot be pickled for caching  # noqa
            # Batch the dataset, and process each batch
            all_batches, all_sliced_batches, all_slice_memberships = zip(*[
                self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ) for batch in dataset.batch(batch_size)
            ])

            # Update the dataset efficiently by reusing all_batches
            dataset = dataset.map(
                lambda examples, indices: all_batches[indices[0] // batch_size],
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                load_from_cache_file=False,
                # The cache file name is a XOR of the interaction history and the
                # current operation
                cache_file_name=str(dataset.logdir /
                                    ("cache-" + str(abs(val)) + ".arrow")),
            )

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        slice_cache_hashes = []
        for identifier in self.identifiers:
            slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))

        if not num_proc or num_proc == 1:
            # Construct slices
            slices = []
            for i, slice_batches in enumerate(zip(*all_sliced_batches)):
                slices.append(
                    create_slice((
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )))
        else:
            # Parallelized slice construction
            with Pool(num_proc) as pool:
                slices = pool.map(
                    create_slice,
                    [
                        (
                            dataset,
                            slice_membership,
                            slice_batches,
                            i,
                            batch_size,
                            slice_cache_hashes[i],
                        )
                        for i, slice_batches in enumerate(zip(*all_sliced_batches))
                    ],
                )

        # TODO(karan): make this more systematic
        # TODO(karan): fix bug when slicing a Slice
        for i, sl in enumerate(slices):
            # # Set the Slice features
            # sl.info.features = dataset.features

            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Create the lineage
            sl.lineage = [
                (str(Dataset.__name__), dataset.identifier),
                (
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                ),
            ]
            if isinstance(dataset, Slice):
                # Prepend the Slice's lineage instead, if the dataset was a slice
                sl.lineage = dataset.lineage + [(
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                )]

        return dataset, slices, slice_membership
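
The third return value is the dense (example, slice) membership matrix assembled from all_slice_memberships. A small standalone illustration of how a caller might use it (NumPy only; the examples-by-slices shape convention comes from the docstring, the data here is made up):

import numpy as np

# Rows are examples, columns are slices; a nonzero entry means membership.
slice_membership = np.array([
    [1, 0],
    [1, 1],
    [0, 1],
])

# Indices of the examples that belong to slice 0 -> array([0, 1])
in_slice_0 = np.nonzero(slice_membership[:, 0])[0]

# Size of each slice -> array([2, 2])
slice_sizes = slice_membership.sum(axis=0)
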
Example #6
    def prepare_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        *args,
        **kwargs,
    ) -> Dataset:
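        """Prepare a dataset by mapping this SliceBuilder's `prepare_batch`
        over its batches."""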

        # Default to an all-zeros mask (no identifiers masked out) if none was given
        if not mask:
            mask = [0] * len(self.identifiers)

        # Compute the hash for this operation
        # FIXME(karan): this is repeated inside process_dataset
        val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
        for i, identifier in enumerate(self.identifiers):
            if not mask[i]:
                val ^= persistent_hash(str(identifier) + str(strings_as_json(columns)))

        try:
            return dataset.map(
                partial(
                    self.prepare_batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ),
                batched=True,
                batch_size=batch_size,
                load_from_cache_file=False,
                cache_file_name=str(
                    dataset.logdir /
                    ("cache-" + str(abs(val)) + "-prep.arrow")),
            )
        except Exception:  # e.g. TypeError, PicklingError or AttributeError  # noqa
            # Batch the dataset, and process each batch
            all_batches = [
                self.prepare_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ) for batch in dataset.batch(batch_size)
            ]

            # Update the dataset efficiently by reusing all_batches
            return dataset.map(
                lambda examples, indices: all_batches[indices[0] // batch_size],
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                load_from_cache_file=False,
                cache_file_name=str(
                    dataset.logdir /
                    ("cache-" + str(abs(val)) + "-prep.arrow")),
            )
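
Both fallback branches above reuse precomputed batches via dataset.map(lambda examples, indices: all_batches[indices[0] // batch_size], ...). Because .map() is invoked with the same batch_size, the first index of each incoming batch identifies which precomputed batch to return. A tiny standalone illustration of that indexing (plain Python, no datasets dependency):

batch_size = 32
all_batches = ["precomputed-batch-0", "precomputed-batch-1", "precomputed-batch-2"]

# The k-th batch starts at index k * batch_size, so integer division by
# batch_size recovers k from the first index of the batch.
for first_index in (0, 32, 64):
    print(first_index, "->", all_batches[first_index // batch_size])
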
Example #7
    def __hash__(self):
        val = 0
        for (identifier, json_columns) in self.history:
            val ^= persistent_hash(str(identifier) + str(json_columns))
        return val

    def __hash__(self):
        """Compute a hash value for the cached operation object."""
        val = 0
        for identifier in self.identifiers:
            val ^= persistent_hash(str(identifier))
        return val

    def __hash__(self):
        return persistent_hash(str(self))
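
All three __hash__ variants, like get_cache_hash in the first example, fold per-identifier hashes together with XOR, so the result does not depend on iteration order and further keys can be mixed in later. A quick check using the hypothetical persistent_hash sketched after the first example:

a = persistent_hash("identifier-1")
b = persistent_hash("identifier-2")
assert a ^ b == b ^ a  # order does not matter

# Fold in one more key, as get_cache_hash does per column
combined = (a ^ b) ^ persistent_hash("some-column")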