Code Example #1
 def _get_testbench(self, task=None):
     # TODO Have a proper mock testbench
     # Create test bench
     testbench_identifier = "test-testbench"
     testbench = TestBench(
         identifier=testbench_identifier,
         task=task,
         slices=[
             SliceDataPanel(
                 dataset=DataPanel.from_huggingface("snli", split="train[:128]"),
                 identifier="snli_1",
             ).filter(lambda example: example["label"] != -1),
             SliceDataPanel(
                 dataset=DataPanel.from_huggingface(
                     "snli", split="validation[:128]"
                 ),
                 identifier="snli_2",
             ).filter(lambda example: example["label"] != -1),
             SliceDataPanel(
                 dataset=DataPanel.from_huggingface("snli", split="test[:128]"),
                 identifier="snli_3",
             ).filter(lambda example: example["label"] != -1),
         ],
         dataset_id="snli",
     )
     return testbench
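In the HuggingFace `snli` dataset, examples whose annotators did not agree on a gold label carry the label -1, which is why every slice above filters on `label != -1`. A standalone sketch of the same filter, assuming a hypothetical DataPanel `dp` loaded from "snli":

    labeled = dp.filter(lambda example: example["label"] != -1)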
Code Example #2
    def prepare_dataset(
        self,
        dp: DataPanel,
        columns: List[str],
        batch_size: int = 32,
        *args,
        **kwargs,
    ) -> None:
        """Apply a preparation function to the data. Use this to update
        attributes of `self`.

        Args:
            dp: DataPanel
            columns: list of columns
            batch_size: batch size for preparation
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments
        """
        # Set the data format
        with dp.format(
                columns +
                self._filter_prerequisite_columns(columns, dp.column_names)):
            # Batch the dataset, and prepare each batch
            for batch in dp.batch(batch_size):
                try:
                    # Check if the `prepare_batch` function has been implemented
                    self.prepare_batch(
                        batch=batch,
                        columns=columns,
                        *args,
                        **kwargs,
                    )
                except NotImplementedError:
                    break
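A hypothetical usage sketch of `prepare_dataset`, assuming `op` is an instance of a subclass that implements `prepare_batch` (the names `op`, `dp`, and "text" are illustrative, not from the source):

    # Run the preparation pass over the "text" column in batches of 32,
    # updating any attributes of `op` that the core operation needs
    op.prepare_dataset(dp, columns=["text"], batch_size=32)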
Code Example #3
    def process_dataset(
        self,
        dp: DataPanel,
        columns: List[str],
        batch_size: int = 32,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[List[DataPanel], np.ndarray]:

        # Create slices using the dataset
        all_slice_memberships = []

        # Batch the dataset, and process each batch
        for batch in dp.batch(batch_size):
            # Process the batch
            _, slice_memberships = self.process_batch(
                dp=batch,
                columns=columns,
                *args,
                **kwargs,
            )

            # Keep track of the slice memberships
            all_slice_memberships.append(slice_memberships)

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        slices = []
        for i in range(len(self.identifiers)):
            # Create a view of the original DataPanel
            sl = dp.view()

            # Only keep the filtered rows visible
            for column in sl._data.values():
                column.visible_rows = np.where(slice_membership[:, i])[0]

            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Append to the lineage
            sl.add_to_lineage(
                category=str(self.category.capitalize()),
                identifier=self.identifiers[i],
                columns=strings_as_json(columns),
            )

            #
            # sl.identifier = ...

            slices.append(sl)

        # for i, sl in enumerate(slices):
        #     # Set the visible rows for each slice
        #     sl.set_visible_rows(np.where(slice_membership[:, i])[0])

        return slices, slice_membership
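A minimal numpy sketch of how the membership matrix drives row visibility: each column of `slice_membership` marks which examples belong to one slice, and `np.where` turns that column into the row indices that stay visible.

    import numpy as np

    # Toy matrix: 4 examples x 2 slices
    slice_membership = np.array([[1, 0],
                                 [1, 1],
                                 [0, 1],
                                 [0, 0]])
    np.where(slice_membership[:, 0])[0]  # -> array([0, 1]): rows in slice 0
    np.where(slice_membership[:, 1])[0]  # -> array([1, 2]): rows in slice 1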
Code Example #4
    def process_batch(
        self,
        dp: DataPanel,
        columns: List[str],
        *args,
        **kwargs,
    ) -> Tuple[List[DataPanel], Optional[np.ndarray]]:

        # Determine the size of the batch
        batch_size = len(dp[list(dp.keys())[0]])

        # Construct the matrix of slice labels: (batch_size x num_slices)
        slice_membership = np.ones((batch_size, self.num_slices),
                                   dtype=np.int32)

        # Uncache the batch to construct the skeleton for transformed batches
        skeleton_batches = [
            DataPanel.uncached_batch(dp) for _ in range(self.num_slices)
        ]

        # Set the index for the skeleton batches
        for j, skeleton_batch in enumerate(skeleton_batches):
            # skeleton_batch.update(
            #     lambda x: {'index': f"{x['index']}-{self.identifiers[j]}"}
            # )
            skeleton_batch["index"] = [
                f"{idx}-{self.identifiers[j]}"
                for idx in skeleton_batch["index"]
            ]

        # Apply the SliceBuilder's core functionality: use positional args
        try:
            transformed_batches, slice_membership = self.apply(
                dp,
                columns,
                skeleton_batches,
                slice_membership,
                *args,
                **kwargs,
            )
        except TypeError:
            # Fallback for `apply` implementations with the older signature
            # that takes only (dp, columns); capture the return value so
            # `transformed_batches` is defined below (assumes the same
            # (transformed_batches, slice_membership) return shape)
            transformed_batches, slice_membership = self.apply(
                dp, columns, *args, **kwargs)

        # Remove transformed examples where slice_membership[i, :] = 0 before returning
        transformed_batches = [
            self.filter_batch_by_slice_membership(
                batch=transformed_batch,
                slice_membership=slice_membership[:, j:j + 1],
            )[0] for j, transformed_batch in enumerate(transformed_batches)
        ]

        return transformed_batches, slice_membership
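Note the `slice_membership[:, j:j + 1]` slice when filtering each transformed batch: slicing with a range keeps the column as a 2-D (batch_size x 1) matrix, whereas plain integer indexing would flatten it to 1-D. A quick numpy illustration:

    import numpy as np

    m = np.ones((4, 3), dtype=np.int32)
    m[:, 1:2].shape  # -> (4, 1): stays a column matrix
    m[:, 1].shape    # -> (4,): integer indexing drops a dimension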
Code Example #5
    def process(
        self,
        dp: DataPanel,
        columns: List[str],
        batch_size: int = 32,
        *args,
        **kwargs,
    ) -> DataPanel:
        """Apply the Operation to a DataPanel.

        Args:
            dp (DataPanel): DataPanel
            columns (list): list of columns
            batch_size (int): batch size for `dp.update(...)`
            *args: optional positional arguments
            **kwargs: optional keyword arguments
        """

        return dp.update(
            tuple_to_dict(keys=[
                str(ident(columns=columns))
                for ident in self.output_identifiers
            ])(partial(self.process_batch, columns=columns, *args, **kwargs)),
            batch_size=batch_size,
            is_batched_fn=True,
            *args,
            **kwargs,
        )
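The `tuple_to_dict(keys=...)` helper is not shown in this listing; from the call site it appears to wrap a tuple-returning function so that it returns a dict keyed by the output identifiers instead, which is the shape `dp.update(...)` expects. A hypothetical sketch of such a wrapper (an assumption, not the library's implementation):

    def tuple_to_dict(keys):
        # Wrap `fn` so its tuple return value becomes {key: value, ...}
        def decorator(fn):
            def wrapper(*args, **kwargs):
                values = fn(*args, **kwargs)
                return dict(zip(keys, values))
            return wrapper
        return decorator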
Code Example #6
    def __init__(self):
        # Create a fake dataset
        self.dataset = DataPanel.from_batch(
            {
                "text_a": [
                    "Before the actor slept, the senator ran.",
                    "The lawyer knew that the judges shouted.",
                    "If the actor slept, the judge saw the artist.",
                    "The lawyers resigned, or the artist slept.",
                ],
                "text_b": [
                    "The actor slept.",
                    "The judges shouted.",
                    "The actor slept.",
                    "The artist slept.",
                ],
                "label": [0, 0, 1, 1],
                "z": [1, 0, 1, 0],
                "fast": [False, True, True, False],
            },
            identifier=Identifier(_name="MockDataset", version="2.0"),
        )

        # Keep a copy of the original
        self.original_dataset = deepcopy(self.dataset)

        assert len(self.dataset) == 4
Code Example #7
 def test_from_dataset(self):
     # Create a slice
     sl = SliceDataPanel(self.testbed.dataset)
     # Compare the slice identifier
     self.assertEqual(str(sl),
                      "RGSlice[num_rows: 6](MockDataset(version=1.0))")
     # Length of the slice
     self.assertEqual(len(sl), 6)
     # Lineage of the slice
     self.assertEqual(sl.lineage, [("Dataset", "MockDataset(version=1.0)")])
Code Example #8
    def exists(cls, dp: DataPanel) -> bool:
        """Check if the outputs of the Operation are in `dp`.

        Args:
            dp: DataPanel

        Returns:
            bool: True if `dp` contains a column produced by `Operation`,
                False otherwise
        """
        # TODO: update this to use `Operation.outputs`
        return any(key.startswith(cls.__name__) for key in dp.keys())
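A toy illustration of the prefix check, assuming a hypothetical Operation subclass `BowOp` that added a column whose name starts with the class name:

    keys = ["text", "label", "BowOp(columns=['text'])"]
    any(key.startswith("BowOp") for key in keys)  # -> True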
Code Example #9
    def prepare(
        self,
        dp: DataPanel,
        columns: List[str],
        batch_size: int = 32,
        *args,
        **kwargs,
    ) -> None:
        """Preparation that is applied before the Operation is applied.

        Many Operations require a full pass over the DataPanel to precompute
        some variables before the core operation can be applied. For example,
        creating a Bag-of-Words representation first requires constructing a
        vocabulary of the tokens that appear frequently across the DataPanel.

        Args:
            dp (DataPanel): DataPanel
            columns (list): list of columns
            batch_size (int): batch size for `dp.map(...)`
            *args: optional positional arguments
            **kwargs: optional keyword arguments
        """

        try:
            dp.map(
                function=partial(self.prepare_batch,
                                 columns=columns,
                                 *args,
                                 **kwargs),
                input_columns=columns,
                is_batched_fn=True,
                batch_size=batch_size,
                *args,
                **kwargs,
            )
        except NotImplementedError:
            return
Code Example #10
    def load(cls, path: str) -> DevBench:
        """Load a devbench from disk.

        Args:
            path (str): path to the devbench directory. The devbench directory must
                have the `.devbench` extension.

        Returns:
            a DevBench
        """

        # Path to the save directory
        savedir = pathlib.Path(path)

        # Load all the slices
        slices = []
        for sl_path in tqdm(list((savedir / "slices").glob("*"))):
            try:
                slices.append(DataPanel.read(str(sl_path)))
            except FileNotFoundError:
                continue

        # Load metrics
        with open(str(savedir / "metrics.dill"), "rb") as f:
            metrics = dill.load(f)

        # Load aggregators
        with open(str(savedir / "aggregators.dill"), "rb") as f:
            aggregators = dill.load(f)

        # Create the devbench
        devbench = cls()

        # Set previously stored metrics
        devbench.metrics = metrics

        # Set previously stored aggregators
        devbench.aggregators = aggregators

        # Set the slices
        devbench.add_slices(slices)

        # Load version info
        with open(str(savedir / "version.dill"), "rb") as f:
            devbench._loads_version(f.read())

        return devbench
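A hypothetical usage sketch (the path is illustrative):

    devbench = DevBench.load("path/to/saved.devbench")
    print(devbench.metrics)  # metrics restored from metrics.dill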
Code Example #11
    def __init__(self):
        # Create a fake batch of data
        self.batch = {
            "text": [
                "The man is walking.",
                "The man is running.",
                "The woman is sprinting.",
                "The woman is resting.",
                "The hobbit is flying.",
                "The hobbit is swimming.",
            ],
            "label": [0, 0, 1, 1, 0, 0],
            "z": [1, 0, 1, 0, 1, 0],
            "fast": [False, True, True, False, False, False],
            "metadata": [
                {
                    "source": "real"
                },
                {
                    "source": "real"
                },
                {
                    "source": "real"
                },
                {
                    "source": "real"
                },
                {
                    "source": "fictional"
                },
                {
                    "source": "fictional"
                },
            ],
        }
        # Create a fake dataset
        self.dataset = DataPanel.from_batch(
            self.batch,
            identifier=Identifier(_name="MockDataset", version="1.0"),
        )

        # Keep a copy of the original
        self.original_dataset = deepcopy(self.dataset)

        assert len(self.dataset) == 6
Code Example #12
    def load(cls, path: str) -> TestBench:
        """Load a testbench from disk.

        Args:
            path: string path to the testbench directory

        Returns:
            a TestBench
        """

        # Path to the save directory
        savedir = pathlib.Path(path)

        # Load all the slices
        slices = []
        for sl_path in tqdm(list((savedir / "slices").glob("*"))):
            try:
                slices.append(DataPanel.load_from_disk(str(sl_path)))
            except FileNotFoundError:
                continue

        # Load metrics
        with open(str(savedir / "metrics.dill"), "rb") as f:
            metrics = dill.load(f)

        # Load metadata
        with open(str(savedir / "metadata.dill"), "rb") as f:
            metadata = dill.load(f)

        # Create the testbench
        testbench = cls(
            identifier=metadata["identifier"],
            task=metadata["task"],
            slices=slices,
        )

        # Set previously stored metrics
        testbench.metrics = metrics

        # Load version info
        with open(str(savedir / "version.dill"), "rb") as f:
            testbench._loads_version(f.read())

        return testbench
Code Example #13
File: task.py  Project: robustness-gym/robustness-gym
    def remap_schema(self, dp: DataPanel):
        # Ground the schema to the dp
        input_grounding, reversed_input_grounding = self.input_schema.ground(
            dp.features)
        output_grounding, reversed_output_grounding = self.output_schema.ground(
            dp.features)

        for col in self.input_schema.columns:
            # Grab the column
            values = dp[input_grounding[col]]
            # Remove it from the dp
            dp.remove_column(input_grounding[col])
            # Add again with the right column name
            dp.add_column(col, values)

        for col in self.output_schema.columns:
            # Grab the column
            values = dp[output_grounding[col]]
            # Remove it from the dp
            dp.remove_column(output_grounding[col])
            # Add again with the right column name
            dp.add_column(col, values)

        return dp
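Each loop simply renames a dataset column to its canonical schema name. A toy walk-through of one iteration, assuming `ground()` mapped the schema column "premise" to the dataset column "text_a" (illustrative names):

    input_grounding = {"premise": "text_a"}
    col = "premise"
    values = dp[input_grounding[col]]        # grab the existing column
    dp.remove_column(input_grounding[col])   # drop the old name
    dp.add_column(col, values)               # re-add under the schema name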
Code Example #14
    def evaluate(
        self,
        dataset: DataPanel,
        input_columns: List[str],
        output_columns: List[str],
        batch_size: int = 32,
        metrics: List[str] = None,
        coerce_fn: Callable = None,
    ):

        # TODO(karan): generalize to TF2

        # Reset the dataset format
        dataset.reset_format()
        dataset.set_format(columns=input_columns + output_columns)

        # TODO(karan): check that the DataPanel conforms to the task definition
        # TODO(karan): figure out how the output_columns will be used by the metrics

        predictions = []
        targets = []

        # Loop and apply the prediction function
        # TODO(karan): not using .map() here in order to get more fine-grained
        #  control over devices
        for idx in range(0, len(dataset), batch_size):
            # Create the batch
            batch = dataset[idx:idx + batch_size]

            # Predict on the batch
            prediction_dict = self.predict_batch(batch=batch,
                                                 input_columns=input_columns)

            # Coerce the predictions
            if coerce_fn:
                prediction_dict = coerce_fn(prediction_dict)

            # Grab the raw target key/values
            target_dict = tz.keyfilter(lambda k: k in output_columns, batch)

            # TODO(karan): general version for non-classification problems
            # TODO(karan): move this to the right device
            if self.is_classifier:
                target_dict = tz.valmap(lambda v: torch.tensor(v), target_dict)

            # TODO(karan): incremental metric computation here
            # Append the predictions and targets
            predictions.append(prediction_dict)
            targets.append(target_dict)

        # Consolidate the predictions and targets
        if self.is_classifier:
            # TODO(karan): Need to store predictions and outputs from the model
            predictions = tz.merge_with(lambda v: torch.cat(v).to("cpu"),
                                        *predictions)
            targets = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *targets)
        else:
            predictions = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *predictions)
            targets = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *targets)

        # Compute the metrics
        # TODO(karan): generalize this code to support metric computation for any task

        # Assumes classification, so the output_columns contains a single key for the
        # label
        if self.is_classifier:
            assert len(output_columns) == 1, "Only supports classification."
            num_classes = self.task.output_schema.features[list(
                self.task.output_schema.columns)[0]].num_classes
        else:
            # Guard: `num_classes` is referenced when computing metrics below
            num_classes = None

        labels = targets[list(targets.keys())[0]]

        if metrics is None:
            if self.task is None:
                raise ValueError(
                    "Must specify metrics if model not associated with task")
            metrics = self.task.metrics

        pred = predictions["pred"].to(self.device)
        target = labels.to(self.device)

        evaluation_dict = {
            metric: compute_metric(metric, pred, target, num_classes)
            for metric in metrics
        }

        # Reset the data format
        dataset.reset_format()

        return evaluation_dict
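A hypothetical usage sketch, assuming `model` exposes this `evaluate` method and the task is NLI-style classification (column names are illustrative):

    evaluation_dict = model.evaluate(
        dataset,
        input_columns=["premise", "hypothesis"],
        output_columns=["label"],
        batch_size=32,
    )
    # -> e.g. {"accuracy": ...}, depending on the task's metrics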
Code Example #15
    def __call__(
        self,
        dp: DataPanel,
        columns: List[str],
        num_proc: int = None,
        *args,
        **kwargs,
    ):

        if not num_proc or num_proc == 1:
            slices = []
            slice_membership = []
            # Apply each slicebuilder in sequence
            for i, slicebuilder in tqdm(enumerate(self.subpopulations)):
                # Apply the slicebuilder
                slices_i, slice_membership_i = slicebuilder(
                    dp=dp,
                    columns=columns,
                    *args,
                    **kwargs,
                )

                # Add in the slices and slice membership
                slices.extend(slices_i)
                slice_membership.append(slice_membership_i)

        else:
            # TODO(karan): cleanup, make mp.Pool support simpler across the library
            with Pool(num_proc) as pool:
                slices, slice_membership = zip(
                    *pool.map(
                        lambda sb: sb(
                            dp=dp,
                            columns=columns,
                            *args,
                            **kwargs,
                        ),
                        [slicebuilder for slicebuilder in self.subpopulations],
                    )
                )

                # Combine all the slices
                slices = list(tz.concat(slices))

            def _store_updates(batch, indices):

                # Each Subpopulation will generate slices
                for i, subpopulation in enumerate(self.subpopulations):
                    updates = subpopulation.construct_updates(
                        slice_membership=slice_membership[i][indices],
                        columns=columns,
                    )

                    batch = subpopulation.store(
                        batch=batch,
                        updates=updates,
                    )

                return batch

            if isinstance(dp, DataPanel):
                dp = dp.map(
                    _store_updates,
                    with_indices=True,
                    batched=True,
                )

                for subpopulation in self.subpopulations:
                    # Update the DataPanel's history
                    dp.update_tape(
                        path=[SLICEBUILDERS, subpopulation.category],
                        identifiers=subpopulation.identifiers,
                        columns=columns,
                    )

        # Combine all the slice membership matrices
        slice_membership = np.concatenate(slice_membership, axis=1)

        return slices, slice_membership
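Each SliceBuilder contributes a (num_examples x num_slices_i) membership matrix, so stacking along axis=1 yields a single matrix covering all slices. A quick numpy check:

    import numpy as np

    a = np.ones((4, 2), dtype=np.int32)   # builder 1: 2 slices
    b = np.zeros((4, 3), dtype=np.int32)  # builder 2: 3 slices
    np.concatenate([a, b], axis=1).shape  # -> (4, 5)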
Code Example #16
    def process_dataset(
        self,
        dp: DataPanel,
        columns: List[str],
        batch_size: int = 32,
        num_proc: int = None,
        *args,
        **kwargs,
    ) -> Tuple[List[DataPanel], np.ndarray]:
        """Apply a SliceBuilder to a dataset.

        Args:
            dp: DataPanel
            columns: list of columns
            batch_size: integer batch size
            num_proc: num processes for multiprocessing
            *args: optional additional arguments
            **kwargs: optional additional keyword arguments

        Returns:
            tuple of (list of Slices, matrix of (example, slice) membership)
        """
        # Create slices
        slices = [[DataPanel()] for _ in range(len(self.identifiers))]
        all_slice_memberships = []

        # Batch the dataset, and process each batch
        for batch in dp.batch(batch_size):
            # Process the batch
            sliced_batches, slice_memberships = self.process_batch(
                dp=batch,
                columns=columns,
                *args,
                **kwargs,
            )

            # Incrementally build the slices
            for sl, sl_batch in zip(slices, sliced_batches):
                sl.append(DataPanel(sl_batch))

            # Keep track of the slice memberships
            all_slice_memberships.append(slice_memberships)

        # Create a single slice label matrix
        slice_membership = np.concatenate(all_slice_memberships, axis=0)

        # Create a single DataPanel for each slice; `e[1:]` skips the empty
        # seed DataPanel that each slice list was initialized with
        slices = [
            meerkat.concat(e[1:], axis=0) if len(e) > 1 else e[0]
            for e in slices
        ]

        # TODO(karan): DataPanel doesn't support this
        for i, sl in enumerate(slices):
            # Set the Slice category using the SliceBuilder's category
            sl.category = self.category

            # Append to the lineage
            sl.add_to_lineage(
                category=str(self.category.capitalize()),
                identifier=self.identifiers[i],
                columns=strings_as_json(columns),
            )

        return slices, slice_membership