Esempio n. 1
0
    def testValidateOptionallyCalled(self):
        """Structure validation only runs when explicitly requested."""
        # A dict whose values are themselves dicts is not a valid structure.
        nested_struct = {"a": "a", "b": {"d": "b"}, "c": "c"}
        flatten = structure_utils._flatten_min_diff_structure

        # Without opting in, flattening the invalid structure does not raise.
        flatten(nested_struct)

        # Opting in to validation surfaces the nesting error.
        with self.assertRaisesRegex(ValueError, "dict.*must be unnested"):
            flatten(nested_struct, run_validation=True)
    def _map_fn(sensitive_batch, nonsensitive_batch):
        """Combines sensitive and nonsensitive batches element by element.

        Both inputs are flattened with `_flatten_min_diff_structure`. Matching
        elements are merged with `_build_single_batch`, and the merged results
        are packed back into the structure of `sensitive_batch`.
        """
        flatten = structure_utils._flatten_min_diff_structure
        # Elements are paired positionally via zip; if the two flattened
        # sequences differ in length, the extra trailing elements are dropped.
        pairs = zip(flatten(sensitive_batch), flatten(nonsensitive_batch))
        merged_batches = [
            _build_single_batch(sensitive, nonsensitive)
            for sensitive, nonsensitive in pairs
        ]
        # Restore the caller-facing structure, using the sensitive input as
        # the template.
        return structure_utils._pack_min_diff_sequence_as(
            sensitive_batch, merged_batches)
    def _map_fn(original_batch, min_diff_batch):
        """Packs an original batch together with the matching MinDiff batch.

        Args:
          original_batch: A batch from the original dataset. It is split with
            `tf.keras.utils.unpack_x_y_sample_weight` into x, y and
            sample_weight.
          min_diff_batch: The batch (or MinDiff structure of batches) from the
            `min_diff_dataset`(s). Every flattened element is split the same
            way and only its x component is inspected.

        Returns:
          Whatever `_pack_as_original` returns when it is given:
          `original_batch`; a `MinDiffPackedInputs` with
          `original_inputs=original_x` and `min_diff_data=min_diff_batch`; the
          original y; and the original sample_weight.

        Raises:
          ValueError, TypeError: Re-raised with the same type and an extended
            message when the x component of any MinDiff batch does not have
            the same nested structure as the original x. The original
            exception is chained as `__cause__`.
        """
        # Unpack original batch.
        original_x, original_y, original_sample_weight = (
            tf.keras.utils.unpack_x_y_sample_weight(original_batch))

        # Check that every min_diff x has the same structure as original_x.
        # TODO: Consider also asserting that Tensor shapes match (apart from
        # the number of examples).
        min_diff_xs = [
            # unpack_x_y_sample_weight returns (x, y, sample_weight); keep x.
            tf.keras.utils.unpack_x_y_sample_weight(batch)[0]
            for batch in structure_utils._flatten_min_diff_structure(
                min_diff_batch)
        ]
        for min_diff_x in min_diff_xs:
            try:
                tf.nest.assert_same_structure(original_x, min_diff_x)
            except (ValueError, TypeError) as e:
                # Keep the exception type callers may catch and add context.
                # Chain the original so its traceback is not lost.
                raise type(e)(
                    "The x component structure of (one of) the "
                    "`min_diff_dataset`(s) does not match that of the original "
                    "x structure (original shown first): {}".format(e)) from e

        # Pack min_diff_batch alongside original_x.
        return _pack_as_original(
            original_batch,
            MinDiffPackedInputs(original_inputs=original_x,
                                min_diff_data=min_diff_batch),
            original_y, original_sample_weight)
Esempio n. 4
0
 def testFlattenDictWithTuples(self):
     """A dict of tuples flattens to its values, leaving each tuple intact."""
     nested = {"a": ("a", ), "b": ("b", "b"), "c": ("c", )}
     expected = [("a", ), ("b", "b"), ("c", )]
     # Values are listed in key order; the tuples themselves are not split.
     self.assertAllEqual(
         structure_utils._flatten_min_diff_structure(nested), expected)
Esempio n. 5
0
 def testFlattenDict(self):
     """A flat dict flattens to a list of its values in key order."""
     mapping = {"a": "a", "b": "b", "c": "c"}
     expected = ["a", "b", "c"]
     self.assertAllEqual(
         structure_utils._flatten_min_diff_structure(mapping), expected)
Esempio n. 6
0
 def testFlattenSingleElement(self):
     """A lone element is wrapped in a one-item list when flattened."""
     element = "element"
     self.assertAllEqual(
         structure_utils._flatten_min_diff_structure(element), [element])
Esempio n. 7
0
    def compute_min_diff_loss(self, min_diff_data, training=None, mask=None):
        # pyformat: disable
        """Returns the `min_diff_loss` value(s) for a batch of `min_diff_data`.

        `min_diff_data` must either be a single element or share the structure
        of the `loss` given at initialization. Every element (together with
        its matching `loss`) represents one application of MinDiff.

        As with the inputs to `tf.keras.Model.fit`, each element is a tuple of
        length 2 or 3. It is split with
        `tf.keras.utils.unpack_x_y_sample_weight`:

        ```
        min_diff_data_elem = ...  # One element of a min_diff_data batch.

        min_diff_x, min_diff_membership, min_diff_sample_weight = (
            tf.keras.utils.unpack_x_y_sample_weight(min_diff_data_elem))
        ```

        where

        - `min_diff_x` is fed to `original_model` to obtain MinDiff
          predictions.
        - `min_diff_membership` is a numerical [batch_size, 1] `Tensor` that
          marks the group of each example with `0.0` or `1.0`.
        - `min_diff_sample_weight` is an optional weight `Tensor` applied to
          the examples when computing `min_diff_loss`.

        For every MinDiff application, the loss is computed from predictions
        obtained as follows:

        ```
        ...  # Inside compute_min_diff_loss.

        min_diff_x = ...  # A single batch of MinDiff examples.

        # Predictions for the MinDiff examples.
        min_diff_predictions = self.original_model(min_diff_x, training=training)
        # Optional transform of the predictions (identity by default).
        min_diff_predictions = self.predictions_transform(min_diff_predictions)
        ```

        Args:
          min_diff_data: Tuple of data, or a valid MinDiff structure of such
            tuples, as described above.
          training: Whether to run in training or inference mode; see
            `tf.keras.Model.call`.
          mask: Mask or list of masks (see `tf.keras.Model.call`) applied when
            calling `original_model`.

        Returns:
          A single scalar when there is exactly one MinDiff application;
          otherwise a list of `min_diff_loss` values, one per application.

        Raises:
          ValueError: If `min_diff_data` does not share the structure of the
            `loss` passed in at initialization.
          ValueError: If the transformed `min_diff_predictions` is not a
            `tf.Tensor`.
        """
        # pyformat: enable

        # Fail fast if the data does not line up with the configured losses.
        structure_utils._assert_same_min_diff_structure(
            min_diff_data, self._loss)

        flatten = structure_utils._flatten_min_diff_structure
        # One (data, loss, weight, metric) tuple per MinDiff application.
        applications = zip(
            flatten(min_diff_data),
            flatten(self._loss),
            flatten(self._loss_weight),
            flatten(self._min_diff_loss_metric),
        )
        losses = [
            self._compute_single_min_diff_loss(
                data, single_loss, weight, metric, training, mask)
            for data, single_loss, weight, metric in applications
        ]

        # Unwrap the single-application case so callers get a bare scalar.
        return losses[0] if len(losses) == 1 else losses