Example #1
    def prepare_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        *args,
        **kwargs,
    ) -> None:
        """Apply a preparation function to the dataset. Use this to update
        attributes of `self`.

        Args:
            dataset: dataset
            columns: list of columns
            batch_size: batch size for preparation
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments
        """
        # Set the data format
        with dataset.format(
            columns + self._filter_prerequisite_columns(columns, dataset.column_names)
        ):
            # Batch the dataset, and prepare each batch
            for batch in dataset.batch(batch_size):
                try:
                    # Check if the `prepare_batch` function has been implemented
                    self.prepare_batch(
                        batch=batch,
                        columns=columns,
                        *args,
                        **kwargs,
                    )
                except NotImplementedError:
                    break
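
The try/except above relies on an assumed contract: the base implementation of `prepare_batch` raises `NotImplementedError`, so `prepare_dataset` stops after the first batch when no preparation is needed, while subclasses that do need a full pass override it to accumulate state on `self`. A minimal sketch of that assumed contract (class names are illustrative, not the library's actual base class):

# Minimal sketch of the assumed contract behind the try/except above.
class NoPrepOperation:
    def prepare_batch(self, batch, columns, *args, **kwargs) -> None:
        # Default: no preparation, so prepare_dataset breaks after one batch
        raise NotImplementedError


class CountingOperation:
    def __init__(self):
        self.num_examples = 0

    def prepare_batch(self, batch, columns, *args, **kwargs) -> None:
        # Accumulate a statistic on `self` across the full pass
        self.num_examples += len(batch[columns[0]])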
Example #2
    def __init__(self):
        # Create a fake dataset
        self.dataset = Dataset.from_batch(
            {
                "text_a": [
                    "Before the actor slept, the senator ran.",
                    "The lawyer knew that the judges shouted.",
                    "If the actor slept, the judge saw the artist.",
                    "The lawyers resigned, or the artist slept.",
                ],
                "text_b": [
                    "The actor slept.",
                    "The judges shouted.",
                    "The actor slept.",
                    "The artist slept.",
                ],
                "label": [0, 0, 1, 1],
                "z": [1, 0, 1, 0],
                "fast": [False, True, True, False],
            },
            identifier=Identifier(_name="MockDataset", version="2.0"),
        )

        # Keep a copy of the original
        self.original_dataset = deepcopy(self.dataset)

        assert len(self.dataset) == 4
Example #3
    def __init__(self):
        # Create a fake batch of data
        self.batch = {
            "text": [
                "The man is walking.",
                "The man is running.",
                "The woman is sprinting.",
                "The woman is resting.",
                "The hobbit is flying.",
                "The hobbit is swimming.",
            ],
            "label": [0, 0, 1, 1, 0, 0],
            "z": [1, 0, 1, 0, 1, 0],
            "fast": [False, True, True, False, False, False],
            "metadata": [
                {"source": "real"},
                {"source": "real"},
                {"source": "real"},
                {"source": "real"},
                {"source": "fictional"},
                {"source": "fictional"},
            ],
        }
        # Create a fake dataset
        self.dataset = Dataset.from_batch(
            self.batch,
            identifier=Identifier(_name="MockDataset", version="1.0"),
        )

        # Keep a copy of the original
        self.original_dataset = deepcopy(self.dataset)

        assert len(self.dataset) == 6
Example #4
    def test_load_dataset(self):
        # Load the first 20 examples of the boolq dataset
        dataset = Dataset.load_dataset("boolq", split="train[:20]")

        # Check that we got 20 examples
        self.assertTrue(isinstance(dataset, Dataset))
        self.assertEqual(len(dataset), 20)
Example #5
    def test_from_jsonl(self):
        # Create a temporary directory
        os.makedirs("tmp", exist_ok=True)

        # Create a json file with data
        with jsonlines.open("tmp/data.jsonl", "w") as writer:
            writer.write_all(
                transpose_batch({
                    "a": [1, 2, 3],
                    "b": [True, False, True],
                    "c": ["x", "y", "z"],
                    "d": [{
                        "e": 2
                    }, {
                        "e": 3
                    }, {
                        "e": 4
                    }],
                }))

        # Load the dataset
        dataset = Dataset.from_jsonl(
            json_path="tmp/data.jsonl",
            identifier=Identifier(_name="MockJSONDataset"),
        )

        self.assertEqual(set(dataset.column_names),
                         {"a", "b", "c", "d", "index"})
        self.assertEqual(len(dataset), 3)

        # Remove the temporary directory
        shutil.rmtree("tmp")
Example #6
    def test_load_dataset(self):
        # Load the first 20 examples of the boolq dataset
        dataset = Dataset.load_dataset("boolq",
                                       split="train[:20]",
                                       dataset_fmt="in_memory")

        # Check that we got 20 examples
        self.assertTrue(isinstance(dataset, Dataset))
        self.assertEqual(len(dataset), 20)
        self.assertTrue(
            isinstance(dataset.identifier.parameters["version"], str))
Example #7
    def prepare_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        *args,
        **kwargs,
    ) -> None:
        """Preparation that is applied before the CachedOperation.

        Many CachedOperations require a full pass over the dataset to precompute some
        variables before the core operation can actually be applied e.g. to create a
        Bag-of-Words representation, constructing a dataset vocabulary to keep only
        tokens that are frequently seen across the dataset.

        Args:
            dataset: Dataset
            columns: list of columns
            batch_size: batch size for .map(..)

        Returns: updated Dataset
        """

        # Set the data format
        dataset.set_format(columns)

        # Batch the dataset, and prepare each batch
        for batch in dataset.batch(batch_size):
            try:
                # Check if the `prepare_batch` function has been implemented
                self.prepare_batch(
                    batch=batch,
                    columns=columns,
                    *args,
                    **kwargs,
                )
            except NotImplementedError:
                break

        # Reset the data format
        dataset.reset_format()
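
As a concrete illustration of the docstring above, a hypothetical CachedOperation subclass could implement `prepare_batch` to build a vocabulary during this pass. A sketch under the assumption that `prepare_batch` receives the same `batch` and `columns` arguments used here (the class is illustrative, not part of the library):

from collections import Counter


class BagOfWordsOperation:  # hypothetical subclass, for illustration only
    def __init__(self, min_count: int = 5):
        self.min_count = min_count
        self.counts = Counter()

    def prepare_batch(self, batch, columns, *args, **kwargs) -> None:
        # Accumulate token counts over every batch in the dataset
        for text in batch[columns[0]]:
            self.counts.update(text.split())

    @property
    def vocab(self) -> set:
        # Keep only tokens that are frequently seen across the dataset
        return {t for t, c in self.counts.items() if c >= self.min_count}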
Example #8
    def test_interleave(self):
        # Interleave the dataset with itself
        dataset = Dataset.interleave(
            [self.testbed.dataset, self.testbed.dataset],
            identifier=Identifier(_name="MockInterleavedDataset"),
        )

        # Check that the elements match up
        for i, x in enumerate(dataset):
            self.assertEqual(x, self.testbed.dataset[i // 2])

        self.assertEqual(len(dataset), len(self.testbed.dataset) * 2)
Example #9
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
    ) -> Dataset:
        """Apply the cached operation to a dataset."""

        return dataset.update(
            partial(self.process_batch, columns=columns),
            batched=True,
            batch_size=batch_size,
        )
Example #10
    def test_chain(self):
        # Chain the dataset with itself
        dataset = Dataset.chain(
            [self.testbed.dataset, self.testbed.dataset],
            identifier=Identifier(_name="MockChainedDataset"),
        )

        # Check that the elements match up
        for i, x in enumerate(dataset):
            self.assertEqual(
                x, self.testbed.dataset[i % len(self.testbed.dataset)])

        self.assertEqual(len(dataset), len(self.testbed.dataset) * 2)
Example #11
    def test_save_load(self):
        # Create a temporary directory
        os.makedirs("tmp", exist_ok=True)

        # Save the dataset to disk
        self.testbed.dataset.save_to_disk(path="tmp")

        # Load the dataset from disk
        dataset = Dataset.load_from_disk(path="tmp")

        # Remove the temporary directory
        shutil.rmtree("tmp")

        self.assertEqual(dataset.features, self.testbed.dataset.features)
Example #12
    def test_save_load(self):
        # Create a temporary directory
        os.mkdir("tmp")

        # Save the dataset to disk
        self.testbed.dataset.save(path="tmp")

        # Load the dataset from disk
        dataset = Dataset.load(path="tmp")

        # Remove the temporary directory
        shutil.rmtree("tmp")

        self.assertEqual(dataset.features, self.testbed.dataset.features)
Example #13
    def load(cls, path: str) -> DevBench:
        """Load a devbench from disk.

        Args:
            path: string path to the devbench directory

        Returns:
            the loaded DevBench instance
        """

        # Path to the save directory
        savedir = pathlib.Path(path)

        # Load all the slices
        slices = []
        for sl_path in tqdm(list((savedir / "slices").glob("*"))):
            try:
                slices.append(Slice.load_from_disk(str(sl_path)))
            except FileNotFoundError:
                continue

        # Load dataset
        dataset = Dataset.load_from_disk(str(savedir / "dataset"))

        # Load metrics
        metrics = dill.load(open(str(savedir / "metrics.dill"), "rb"))

        # Load aggregators
        aggregators = dill.load(open(str(savedir / "aggregators.dill"), "rb"))

        # Load metadata
        _ = dill.load(open(str(savedir / "metadata.dill"), "rb"))

        # Create the devbench
        devbench = cls(dataset=dataset)

        # Set previously stored metrics
        devbench.metrics = metrics

        # Set previously stored aggregators
        devbench.aggregators = aggregators

        # Set the slices
        devbench.add_slices(slices)

        # Load version info
        with open(str(savedir / "version.dill"), "rb") as f:
            devbench._loads_version(f.read())

        return devbench
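
The loader implies a particular on-disk layout; the sketch below records that layout as inferred from the paths above (an inference, not a documented spec), together with a usage call on a hypothetical path.

# Assumed layout of a saved devbench directory, inferred from the loader:
#   <path>/slices/<slice_id>/   one saved Slice per subdirectory
#   <path>/dataset/             the Dataset saved with save_to_disk
#   <path>/metrics.dill, aggregators.dill, metadata.dill, version.dill
devbench = DevBench.load("path/to/devbench")  # hypothetical path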
Example #14
    def remap_schema(self, dataset: Dataset):
        # Ground the schema to the dataset
        input_grounding, reversed_input_grounding = self.input_schema.ground(
            dataset.features
        )
        output_grounding, reversed_output_grounding = self.output_schema.ground(
            dataset.features
        )

        # Construct a map_fn that remaps the dataset schema
        def map_fn(example):
            return tz.merge(
                {k: example[input_grounding[k]] for k in self.input_schema.features},
                {k: example[output_grounding[k]] for k in self.output_schema.features},
            )

        return dataset.map(
            map_fn,
            remove_columns=list(reversed_input_grounding.keys())
            + list(reversed_output_grounding.keys()),
        )
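
To make the grounding step concrete, here is a hedged worked example of what `map_fn` computes, assuming `ground` returns a mapping from schema feature names to dataset column names (the feature and column names are illustrative):

# Hypothetical groundings: the schema declares "premise"/"hypothesis"/"label",
# and the dataset stores them as "text_a"/"text_b"/"label".
input_grounding = {"premise": "text_a", "hypothesis": "text_b"}
output_grounding = {"label": "label"}

example = {"text_a": "The actor slept.", "text_b": "The judges shouted.", "label": 0}

remapped = {
    **{k: example[input_grounding[k]] for k in input_grounding},
    **{k: example[output_grounding[k]] for k in output_grounding},
}
# remapped == {"premise": "The actor slept.",
#              "hypothesis": "The judges shouted.", "label": 0}
# dataset.map(map_fn, remove_columns=[...]) then drops "text_a" and "text_b".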
Example #15
    def test_from_batches(self):
        # Build a dataset from multiple batches
        dataset = Dataset.from_batches(
            [{
                "a": [1, 2, 3],
                "b": [True, False, True],
                "c": ["x", "y", "z"],
                "d": [{
                    "e": 2
                }, {
                    "e": 3
                }, {
                    "e": 4
                }],
            }] * 3,
            identifier=Identifier(_name="MyDataset"),
        )

        self.assertEqual(set(dataset.column_names),
                         {"a", "b", "c", "d", "index"})
        self.assertEqual(len(dataset), 9)
Example #16
    def evaluate(
        self,
        dataset: Dataset,
        input_columns: List[str],
        output_columns: List[str],
        batch_size: int = 32,
        metrics: List[str] = None,
        coerce_fn: Callable = None,
    ):

        # TODO(karan): generalize to TF2

        # Reset the dataset format
        dataset.reset_format()
        dataset.set_format(columns=input_columns + output_columns)

        # TODO(karan): check that the Dataset conforms to the task definition
        # TODO(karan): figure out how the output_columns will be used by the metrics
        pass

        predictions = []
        targets = []

        # Loop and apply the prediction function
        # TODO(karan): not using .map() here in order to get more fine-grained
        #  control over devices
        for idx in range(0, len(dataset), batch_size):
            # Create the batch
            batch = dataset[idx:idx + batch_size]

            # Predict on the batch
            prediction_dict = self.predict_batch(batch=batch,
                                                 input_columns=input_columns)

            # Coerce the predictions
            if coerce_fn:
                prediction_dict = coerce_fn(prediction_dict)

            # Grab the raw target key/values
            target_dict = tz.keyfilter(lambda k: k in output_columns, batch)

            # TODO(karan): general version for non-classification problems
            # TODO(karan): move this to the right device
            if self.task.classification():
                target_dict = tz.valmap(lambda v: torch.tensor(v), target_dict)

            # TODO(karan): incremental metric computation here
            # Append the predictions and targets
            predictions.append(prediction_dict)
            targets.append(target_dict)

        # Consolidate the predictions and targets
        if self.task.classification():
            # TODO(karan): Need to store predictions and outputs from the model
            predictions = tz.merge_with(lambda v: torch.cat(v).to("cpu"),
                                        *predictions)
            targets = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *targets)
        else:
            predictions = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *predictions)
            targets = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *targets)

        # Compute the metrics
        # TODO(karan): generalize this code to support metric computation for any task

        # Assumes classification, so the output_columns contains a single key for the
        # label
        if self.task.classification():
            assert len(output_columns) == 1, "Only supports classification."
            num_classes = self.task.output_schema.features[
                list(self.task.output_schema.keys())[0]
            ].num_classes

        labels = targets[list(targets.keys())[0]]

        if metrics is None:
            if self.task is None:
                raise ValueError(
                    "Must specify metrics if model not associated with task")
            metrics = self.task.metrics

        pred = predictions["pred"].to(self.device)
        target = labels.to(self.device)

        evaluation_dict = {
            metric: compute_metric(metric, pred, target, num_classes)
            for metric in metrics
        }

        # Reset the data format
        dataset.reset_format()

        return evaluation_dict
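
A hedged usage sketch for the method above (the model instance, column names, and metric are illustrative, not prescribed by the snippet):

# Hypothetical call, assuming `model` is an instance of the class above and the
# dataset has "text" and "label" columns for a classification task.
evaluation_dict = model.evaluate(
    dataset,
    input_columns=["text"],
    output_columns=["label"],
    batch_size=32,
    metrics=["accuracy"],  # defaults to self.task.metrics when omitted
)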
Example #17
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        # mask: List[int] = None,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[List[Slice], np.ndarray]:

        # Create slices using the dataset
        slices = [Slice(dataset) for _ in range(len(self.identifiers))]
        all_slice_memberships = []
        # Batch the dataset, and process each batch
        for batch in dataset.batch(batch_size):
            # Process the batch
            _, slice_memberships = self.process_batch(
                batch=batch,
                columns=columns,
                *args,
                **kwargs,
            )

            # Keep track of the slice memberships
            all_slice_memberships.append(slice_memberships)

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        for i, sl in enumerate(slices):
            # Set the visible rows for each slice
            sl.set_visible_rows(np.where(slice_membership[:, i])[0])

            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Append to the lineage
            sl.add_to_lineage(
                category=str(self.category.capitalize()),
                identifier=self.identifiers[i],
                columns=strings_as_json(columns),
            )

            # # Create the lineage
            # sl.lineage = [
            #     (str(Dataset.__name__), dataset.identifier),
            #     (
            #         str(self.category.capitalize()),
            #         self.identifiers[i],
            #         strings_as_json(columns),
            #     ),
            # ]
            # if isinstance(dataset, Slice):
            #     # Prepend the Slice's lineage instead, if the dataset was a slice
            #     sl.lineage = dataset.lineage + [
            #         (
            #             str(self.category.capitalize()),
            #             self.identifiers[i],
            #             strings_as_json(columns),
            #         )
            #     ]

        return slices, slice_membership
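
A small worked example of the membership matrix used above (the values are made up): rows are examples, columns are slices, and `np.where` pulls out the visible row indices for each slice.

import numpy as np

# 4 examples x 2 slices; 1 means the example belongs to that slice
slice_membership = np.array([
    [1, 0],
    [1, 1],
    [0, 0],
    [0, 1],
])

for i in range(slice_membership.shape[1]):
    print(i, np.where(slice_membership[:, i])[0])
# 0 [0 1]
# 1 [1 3]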
Example #18
def stow(
    dataset: Dataset,
    cached_ops: Dict[CachedOperation, List[List[str]]],
    batch_size: int = 32,
    load_from_cache_file: bool = True,
):
    """Apply a list of cached operations in sequence."""

    # Check the InteractionTape to remove CachedOperations that have already been stowed
    for cached_op, list_of_columns in list(cached_ops.items()):
        indices_to_remove = []
        for i, columns in enumerate(list(list_of_columns)):
            if dataset.check_tape(
                    path=[CACHEDOPS],
                    identifiers=cached_op.identifier,
                    columns=columns,
            ):
                # Remove the columns at index i
                indices_to_remove.append(i)

        # Remove the columns that are already cached
        for index in sorted(indices_to_remove, reverse=True):
            columns = cached_ops[cached_op].pop(index)
            print(f"skipped: {cached_op.identifier} -> {columns}", flush=True)

        # Check if list_of_columns is now empty
        if not cached_ops[cached_op]:
            # Remove the op entirely
            cached_ops.pop(cached_op)

    for cached_op, list_of_columns in cached_ops.items():
        for columns in list_of_columns:
            dataset = cached_op(dataset,
                                columns=columns,
                                batch_size=batch_size)

    # def _map_fn(batch: Batch):
    #     """
    #     Consolidate the application of the CachedOperations passed to stow into a
    #     single mappable function.
    #     """
    #     for cached_op, list_of_columns in cached_ops.items():
    #         for columns in list_of_columns:
    #             batch = cached_op(batch, columns=columns)
    #
    #     return batch
    #
    # # Compute the hash value
    # val = 0
    # for cached_op, list_of_columns in cached_ops.items():
    #     for columns in list_of_columns:
    #         val ^= cached_op.get_cache_hash(columns=columns)
    #
    # # Combine with the hash for the dataset on which the cached ops are applied
    # val ^= persistent_hash(
    #     # TODO(karan): move this to Dataset
    #     "-".join(
    #         "-".join(str(k) + "-" + str(v) for k, v in f.items()) for f in
    #         dataset._data_files
    #     )
    # )
    #
    # # Map the cached operations over the dataset
    # try:
    #     dataset = dataset.map(
    #         _map_fn,
    #         batched=True,
    #         batch_size=32,
    #         cache_file_name='cache-' + str(abs(val)) + '.arrow',
    #         load_from_cache_file=load_from_cache_file
    #     )
    # except TypeError:
    #     # Batch the dataset, and process each batch
    #     all_batches = [_map_fn(batch=batch) for batch in dataset.batch(batch_size)]
    #
    #     # Update the dataset efficiently by reusing all_batches
    #     dataset = dataset.map(
    #         lambda examples, indices: all_batches[indices[0] // batch_size],
    #         batched=True,
    #         batch_size=batch_size,
    #         with_indices=True,
    #     )

    # Update the Dataset history
    for cached_op, list_of_columns in cached_ops.items():
        for columns in list_of_columns:
            dataset.update_tape(path=[CACHEDOPS],
                                identifiers=cached_op.identifier,
                                columns=columns)

    return dataset
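
A hedged usage sketch for `stow` (the operation instances and column names are hypothetical): the dict maps each CachedOperation to the list of column groups it should be applied to.

# Hypothetical cached operations and columns, for illustration only.
cached_ops = {
    spacy_op: [["text_a"], ["text_b"]],  # apply once per column
    length_op: [["text_a", "text_b"]],   # apply once to the column pair
}

dataset = stow(dataset, cached_ops, batch_size=32)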
Example #19
    def process_batch(
        self,
        batch: Batch,
        columns: List[str],
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        *args,
        **kwargs,
    ) -> Tuple[Batch, List[Batch], Optional[np.ndarray]]:
        # Determine the size of the batch
        batch_size = len(batch[list(batch.keys())[0]])

        # Construct the matrix of slice labels: (batch_size x num_slices)
        slice_membership = np.ones((batch_size, self.num_slices),
                                   dtype=np.int32)

        # Uncache the batch to construct the skeleton for transformed batches
        skeleton_batches = [
            Dataset.uncached_batch(batch) for _ in range(self.num_slices)
        ]

        # Set the index for the skeleton batches
        for j, skeleton_batch in enumerate(skeleton_batches):
            skeleton_batch["index"] = [
                f"{idx}-{self.identifiers[j]}"
                for idx in skeleton_batch["index"]
            ]

        # Apply the SliceBuilder's core functionality
        transformed_batches, slice_membership = self.apply(
            skeleton_batches=skeleton_batches,
            slice_membership=slice_membership,
            batch=batch,
            columns=columns,
            *args,
            **kwargs,
        )

        # Store the transformed examples
        updates = self.construct_updates(
            transformed_batches=transformed_batches,
            slice_membership=slice_membership,
            batch_size=batch_size,
            columns=columns,
            mask=mask,
            compress=store_compressed,
        )

        # Remove transformed examples where slice_membership[i, :] = 0 before returning
        transformed_batches = [
            self.filter_batch_by_slice_membership(
                batch=transformed_batch,
                slice_membership=slice_membership[:, j:j + 1])[0]
            for j, transformed_batch in enumerate(transformed_batches)
        ]

        if store:
            batch = self.store(
                batch=batch,
                updates=updates,
            )

        return batch, transformed_batches, slice_membership
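
For orientation, a minimal sketch of the `apply` hook that `process_batch` calls above, under the assumption that subclasses fill in the skeleton batches and update the membership matrix (the lowercasing transformation is purely illustrative):

class LowercaseSliceBuilder:  # hypothetical subclass, for illustration only
    num_slices = 1

    def apply(self, skeleton_batches, slice_membership, batch, columns,
              *args, **kwargs):
        for col in columns:
            for i, text in enumerate(batch[col]):
                if isinstance(text, str):
                    # Write the transformed example into the skeleton batch
                    skeleton_batches[0][col][i] = text.lower()
                else:
                    # Mark untransformable examples as non-members of the slice
                    slice_membership[i, 0] = 0
        return skeleton_batches, slice_membership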
Example #20
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[Dataset, List[Slice], np.ndarray]:
        """Apply a SliceBuilder to a dataset.

        Args:
            dataset: Dataset
            columns: list of columns
            batch_size: integer batch size
            mask: boolean or integer mask array; mask[i] = True means that the
                ith slice will be masked out
            store_compressed: whether to store in a compressed format
            store: whether to store the results along with the example in Dataset
            num_proc: num processes for multiprocessing
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments

        Returns: tuple of (Dataset, list of Slices, matrix of (example,
        slice) membership)
        """
        # Prepare the dataset
        dataset = self.prepare_dataset(
            dataset=dataset,
            columns=columns,
            batch_size=batch_size,
            mask=mask,
            store_compressed=store_compressed,
            store=store,
            *args,
            **kwargs,
        )

        # Compute a hash
        val = persistent_hash(str(
            dataset.identifier)) ^ dataset.hash_interactions()
        for i, identifier in enumerate(self.identifiers):
            if not mask[i]:
                val ^= persistent_hash(
                    str(identifier) + str(strings_as_json(columns)))

        try:
            # Map the SliceBuilder over the dataset
            all_sliced_batches = []
            all_slice_memberships = []

            def _map_fn(batch):
                """Map function for processing batches.

                Note that using this map_fn in a stateful way is
                dangerous, since every invocation of this function
                appends to the all_sliced_batches list. The .map()
                function will invoke this once for testing before
                performing the map, so we discard the first entry
                inserted into all_sliced_batches.
                """
                batch, sliced_batches, slice_membership = self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                )
                all_sliced_batches.append(sliced_batches)
                all_slice_memberships.append(slice_membership)
                return batch

            dataset = dataset.map(
                _map_fn,
                batched=True,
                batch_size=batch_size,
                # FIXME(karan): enable this by adding logic for generating
                #  all_sliced_batches and all_slice_memberships
                #  when loading from cache file
                load_from_cache_file=False,
                # The cache file name is a XOR of the interaction history and the
                # current operation
                cache_file_name=str(dataset.logdir /
                                    ("cache-" + str(abs(val)) + ".arrow")),
            )

            # Remove the first entry (see _map_fn)
            all_sliced_batches = all_sliced_batches[1:]
            all_slice_memberships = all_slice_memberships[1:]

        except:  # noqa
            # Batch the dataset, and process each batch
            all_batches, all_sliced_batches, all_slice_memberships = zip(*[
                self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ) for batch in dataset.batch(batch_size)
            ])

            # Update the dataset efficiently by reusing all_batches
            dataset = dataset.map(
                lambda examples, indices: all_batches[indices[0] // batch_size],
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                load_from_cache_file=False,
                # The cache file name is a XOR of the interaction history and the
                # current operation
                cache_file_name=str(dataset.logdir /
                                    ("cache-" + str(abs(val)) + ".arrow")),
            )

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        slice_cache_hashes = []
        for identifier in self.identifiers:
            slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))

        if not num_proc or num_proc == 1:
            # Construct slices
            slices = []
            for i, slice_batches in enumerate(zip(*all_sliced_batches)):
                slices.append(
                    create_slice((
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )))
        else:
            # Parallelized slice construction
            with Pool(num_proc) as pool:
                slices = pool.map(
                    create_slice,
                    [(
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )
                     for i, slice_batches in enumerate(zip(
                         *all_sliced_batches))],
                )

        # TODO(karan): make this more systematic
        # TODO(karan): fix bug when slicing a Slice
        for i, sl in enumerate(slices):
            # # Set the Slice features
            # sl.info.features = dataset.features

            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Create the lineage
            sl.lineage = [
                (str(Dataset.__name__), dataset.identifier),
                (
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                ),
            ]
            if isinstance(dataset, Slice):
                # Prepend the Slice's lineage instead, if the dataset was a slice
                sl.lineage = dataset.lineage + [(
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                )]

        return dataset, slices, slice_membership
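
A hedged usage sketch for the method above (`builder`, the column name, and the mask are illustrative):

# Hypothetical call, assuming `builder` is a SliceBuilder with two identifiers.
dataset, slices, slice_membership = builder.process_dataset(
    dataset,
    columns=["text"],
    batch_size=32,
    mask=[0, 1],  # per the docstring, mask[i] = True masks out the ith slice
)
# slice_membership has one row per example and one column per slice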
Example #21
    def prepare_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        mask: List[int] = None,
        store_compressed: bool = True,
        store: bool = True,
        *args,
        **kwargs,
    ) -> Dataset:

        # Compute the hash for this operation
        # FIXME(karan): this is repeated inside process_dataset
        val = persistent_hash(str(
            dataset.identifier)) ^ dataset.hash_interactions()
        for i, identifier in enumerate(self.identifiers):
            if not mask[i]:
                val ^= persistent_hash(
                    str(identifier) + str(strings_as_json(columns)))

        try:
            return dataset.map(
                partial(
                    self.prepare_batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ),
                batched=True,
                batch_size=batch_size,
                load_from_cache_file=False,
                cache_file_name=str(
                    dataset.logdir /
                    ("cache-" + str(abs(val)) + "-prep.arrow")),
            )
        except:  # TypeError or PicklingError or AttributeError: # noqa
            # Batch the dataset, and process each batch
            all_batches = [
                self.prepare_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                ) for batch in dataset.batch(batch_size)
            ]

            # Update the dataset efficiently by reusing all_batches
            return dataset.map(
                lambda examples, indices: all_batches[indices[0] // batch_size],
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                load_from_cache_file=False,
                cache_file_name=str(
                    dataset.logdir /
                    ("cache-" + str(abs(val)) + "-prep.arrow")),
            )
Example #22
    def process_dataset(
        self,
        dataset: Dataset,
        columns: List[str],
        batch_size: int = 32,
        # mask: List[int] = None,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[List[Slice], np.ndarray]:
        """Apply a SliceBuilder to a dataset.

        Args:
            dataset: Dataset
            columns: list of columns
            batch_size: integer batch size
            # mask: boolean or integer mask array, mask[i] = True means that the ith
            # slice will be masked out
            num_proc: num processes for multiprocessing
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments

        Returns: tuple of (list of Slices, matrix of (example, slice)
        membership)
        """

        # # Compute a hash
        # val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
        # for i, identifier in enumerate(self.identifiers):
        #     if not mask[i]:
        #         val ^= persistent_hash(str(identifier)
        #         + str(strings_as_json(columns)))

        # try:
        #     # Map the SliceBuilder over the dataset
        #     all_sliced_batches = []
        #     all_slice_memberships = []
        #
        #     def _map_fn(batch):
        #         """Map function for processing batches.
        #
        #         Note that using this map_fn in a stateful way is
        #         dangerous, since every invocation of this function
        #         appends to the all_slice_batches list. The .map()
        #         function will invoke this once for testing before
        #         performing the map, so we discard the first entry
        #         inserted into all_sliced_batches.
        #         """
        #         batch, sliced_batches, slice_membership = self.process_batch(
        #             batch=batch,
        #             columns=columns,
        #             mask=mask,
        #             store_compressed=store_compressed,
        #             store=store,
        #             *args,
        #             **kwargs,
        #         )
        #         all_sliced_batches.append(sliced_batches)
        #         all_slice_memberships.append(slice_membership)
        #         return batch
        #
        #     dataset = dataset.map(
        #         _map_fn,
        #         batched=True,
        #         batch_size=batch_size,
        #         # FIXME(karan): enable this by adding logic for generating
        #         #  all_sliced_batches and all_slice_memberships
        #         #  when loading from cache file
        #         load_from_cache_file=False,
        #         # The cache file name is a XOR of the interaction history and the
        #         # current operation
        #         cache_file_name=str(
        #             dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
        #         ),
        #     )
        #
        #     # Remove the first entry (see _map_fn)
        #     all_sliced_batches = all_sliced_batches[1:]
        #     all_slice_memberships = all_slice_memberships[1:]
        #
        # except:  # noqa
        # all_batches, all_sliced_batches, all_slice_memberships = zip(
        #     *[
        #         self.process_batch(
        #             batch=batch,
        #             columns=columns,
        #             mask=mask,
        #             store_compressed=store_compressed,
        #             store=store,
        #             *args,
        #             **kwargs,
        #         )
        #         for batch in dataset.batch(batch_size)
        #     ]
        # )
        # # Update the dataset efficiently by reusing all_batches
        # dataset = dataset.map(
        #     lambda examples, indices: all_batches[indices[0] // batch_size],
        #     batched=True,
        #     batch_size=batch_size,
        #     with_indices=True,
        #     load_from_cache_file=False,
        #     # The cache file name is a XOR of the interaction history and the
        #     # current operation
        #     cache_file_name=str(
        #         dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
        #     ),
        # )

        # Create slices
        slices = [Slice() for _ in range(len(self.identifiers))]
        all_slice_memberships = []
        # Batch the dataset, and process each batch
        for batch in dataset.batch(batch_size):
            # Process the batch
            sliced_batches, slice_memberships = self.process_batch(
                batch=batch,
                columns=columns,
                *args,
                **kwargs,
            )

            # Incrementally build the slices
            for sl, sl_batch in zip(slices, sliced_batches):
                sl._dataset.append(sl_batch)

            # Keep track of the slice memberships
            all_slice_memberships.append(slice_memberships)

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        # slice_cache_hashes = []
        # for identifier in self.identifiers:
        #     slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))

        # if not num_proc or num_proc == 1:
        #     # Construct slices
        #     slices = []
        #     for i, slice_batches in enumerate(zip(*all_sliced_batches)):
        #         slices.append(
        #             create_slice(
        #                 (
        #                     dataset,
        #                     slice_membership,
        #                     slice_batches,
        #                     i,
        #                     batch_size,
        #                     slice_cache_hashes[i],
        #                 )
        #             )
        #         )
        # else:
        #     # Parallelized slice construction
        #     with Pool(num_proc) as pool:
        #         slices = pool.map(
        #             create_slice,
        #             [
        #                 (
        #                     dataset,
        #                     slice_membership,
        #                     slice_batches,
        #                     i,
        #                     batch_size,
        #                     slice_cache_hashes[i],
        #                 )
        #                 for i, slice_batches in enumerate(zip(*all_sliced_batches))
        #             ],
        #         )

        for i, sl in enumerate(slices):
            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Append to the lineage
            sl.add_to_lineage(
                category=str(self.category.capitalize()),
                identifier=self.identifiers[i],
                columns=strings_as_json(columns),
            )

            # # Create the lineage
            # sl.lineage = [
            #     (str(Dataset.__name__), dataset.identifier),
            #     (
            #         str(self.category.capitalize()),
            #         self.identifiers[i],
            #         strings_as_json(columns),
            #     ),
            # ]
            # if isinstance(dataset, Slice):
            #     # Prepend the Slice's lineage instead, if the dataset was a slice
            #     sl.lineage = dataset.lineage + [
            #         (
            #             str(self.category.capitalize()),
            #             self.identifiers[i],
            #             strings_as_json(columns),
            #         )
            #     ]

        return slices, slice_membership