Exemplo n.º 1
0
    def testValidateDictWithTypeCheck(self):
        """Checks dict validation when an `element_type` is enforced.

        A dict whose values are all `str` must validate without error. A dict
        in which one value, or every value, has the wrong type is expected to
        raise a `ValueError` whose message matches the regex below.
        """
        elem = "element"
        struct = {"a": elem, "b": elem, "c": elem}
        structure_utils.validate_min_diff_structure(struct, element_type=str)

        # Assert raises an error if one or more elements are the wrong type.
        # Both failing cases below are checked against the same message regex.
        expected_regex = (
            "name.* is a dict.*must be unnested and contain only "
            "elements of type.*str.*{.*2.*}")

        # A single element is the wrong type.
        struct = {"a": elem, "b": 2, "c": elem}
        with self.assertRaisesRegex(ValueError, expected_regex):
            structure_utils.validate_min_diff_structure(struct,
                                                        struct_name="name",
                                                        element_type=str)

        # All elements are the (same) wrong type. Every value is an int here
        # so that this case is actually distinct from the previous one.
        struct = {"a": 2, "b": 2, "c": 2}
        with self.assertRaisesRegex(ValueError, expected_regex):
            structure_utils.validate_min_diff_structure(struct,
                                                        struct_name="name",
                                                        element_type=str)
Exemplo n.º 2
0
    def testValidateSingleElement(self):
        """Validates lone values when no `element_type` is requested."""
        # Each value is passed on its own; the test fails if any call raises.
        lone_values = (
            "element",  # String
            3,  # int
            ("a", "b", "c"),  # tuple
        )
        for value in lone_values:
            structure_utils.validate_min_diff_structure(value)
Exemplo n.º 3
0
    def testValidateSingleElementWithTypeCheck(self):
        """Checks single-value validation against matching/mismatching types.

        For each value, validation with its own type must pass and validation
        with the other type is expected to raise a `TypeError`.
        """
        message_head = (
            "not a recognized MinDiff structure.*should have a type of")
        # (value, type that matches, type that does not, expected message)
        cases = (
            ("element", str, int,
             message_head + ".*single unnested element.*of type.*int.*str"),
            (3, int, str,
             message_head + ".*single unnested element.*of type.*str.*.*int"),
        )

        for value, matching_type, mismatched_type, expected_regex in cases:
            # The matching type validates cleanly.
            structure_utils.validate_min_diff_structure(
                value, element_type=matching_type)

            # The mismatched type is rejected.
            with self.assertRaisesRegex(TypeError, expected_regex):
                structure_utils.validate_min_diff_structure(
                    value, element_type=mismatched_type)
Exemplo n.º 4
0
    def testValidateDict(self):
        """Checks dict validation without an `element_type` constraint.

        A flat dict with string keys must validate. A dict containing a
        nested dict, or a non-string key, is expected to raise `ValueError`.
        """
        elem = "element"

        # A flat, string-keyed dict is accepted.
        structure_utils.validate_min_diff_structure(
            {"a": elem, "b": elem, "c": elem})

        invalid_cases = (
            # Dict is not simple: one of its values is itself a dict.
            ({"a": elem, "b": {"d": elem}, "c": elem},
             "name.* is a dict.*must be unnested.*{.*{.*}.*}"),
            # Dict has a non string key.
            ({"a": elem, 3: elem, "c": elem},
             r"name.*must contain only string keys.*\['a', 3, 'c'\]"),
        )
        for bad_struct, expected_regex in invalid_cases:
            with self.assertRaisesRegex(ValueError, expected_regex):
                structure_utils.validate_min_diff_structure(
                    bad_struct, struct_name="name")
Exemplo n.º 5
0
def pack_min_diff_data(original_dataset: tf.data.Dataset,
                       sensitive_group_dataset=None,
                       nonsensitive_group_dataset=None,
                       min_diff_dataset=None) -> tf.data.Dataset:
    # pyformat: disable
    """Packs `min_diff_data` with the `x` component of the original dataset.

    This function should be used to create the dataset that will be passed to
    `min_diff.keras.MinDiffModel` during training and, optionally, during
    evaluation.

    The inputs should either have both `sensitive_group_dataset` and
    `nonsensitive_group_dataset` passed in and `min_diff_dataset` left unset
    or vice versa. In the case of the former, `min_diff_data` will be built
    using `utils.build_min_diff_dataset`.

    Warning: All input datasets should be batched **before** being passed in.

    Each input dataset must output a tuple in the format used in
    `tf.keras.Model.fit`. Specifically the output must be a tuple of
    length 1, 2 or 3 in the form `(x, y, sample_weight)`.

    This output will be parsed internally in the following way:

    ```
    batch = ...  # Batch from any one of the input datasets.
    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
    ```

    Every batch from the returned `tf.data.Dataset` will contain one batch
    from each of the input datasets. Each returned batch will be a tuple of
    `(packed_inputs, original_y, original_sample_weight)` matching the length
    of `original_dataset` batches where:

    - `packed_inputs`: is an instance of `utils.MinDiffPackedInputs`
      containing:

      - `original_inputs`: `x` component taken directly from the
        `original_dataset` batch.
      - `min_diff_data`: batch of data formed from `sensitive_group_dataset`
        and `nonsensitive_group_dataset` (as described in
        `utils.build_min_diff_dataset`) or taken directly from
        `min_diff_dataset`.

    - `original_y`: is the `y` component taken directly from the
      `original_dataset` batch.
    - `original_sample_weight`: is the `sample_weight` component taken
      directly from the `original_dataset` batch.

    `min_diff_data` will be used in `min_diff.keras.MinDiffModel` when
    calculating the `min_diff_loss`. It is a tuple or structure (matching the
    structure of the inputs) of
    `(min_diff_x, min_diff_membership, min_diff_sample_weight)`.

    Caution: If you are passing in `min_diff_dataset` make sure that each
    `min_diff_data` batch contains about the same number of sensitive and
    nonsensitive examples as indicated by `min_diff_membership` (when passing
    in `sensitive_group_dataset` and `nonsensitive_group_dataset` this is
    determined by their batch sizes).

    Args:
      original_dataset: `tf.data.Dataset` that was used before applying min
        diff. The output should conform to the format used in
        `tf.keras.Model.fit`.
      sensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
        (unnested dict) of `tf.data.Dataset`s containing only examples that
        belong to the sensitive group.

        This must be passed in if `nonsensitive_group_dataset` is passed in.
        Furthermore, the `x` component for every batch should have the same
        structure as that of the `original_dataset` batches' `x` components.
      nonsensitive_group_dataset: `tf.data.Dataset` or valid MinDiff
        structure (unnested dict) of `tf.data.Dataset`s containing only
        examples that do **not** belong to the sensitive group.

        This must be passed in if `sensitive_group_dataset` is passed in.
        Furthermore, the `x` component for every batch should have the same
        structure as that of the `original_dataset` batches' `x` components.
      min_diff_dataset: `tf.data.Dataset` or valid MinDiff structure
        (unnested dict) of `tf.data.Dataset`s containing only examples to be
        used to calculate the `min_diff_loss`.

        This should only be set if neither `sensitive_group_dataset` nor
        `nonsensitive_group_dataset` is passed in.
        Furthermore, the `x` component for every batch should have the same
        structure as that of the `original_dataset` batches' `x` components.

    Returns:
      A `tf.data.Dataset` whose output is a tuple of (`packed_inputs`,
        `original_y`, `original_sample_weight`) matching the output length
        of `original_dataset`.

    Raises:
      ValueError: If the combination of `sensitive_group_dataset`,
        `nonsensitive_group_dataset` and `min_diff_dataset` that was passed
        in is not one of the two supported combinations described above.
    """
    # pyformat: enable
    # Either sensitive_group_dataset and nonsensitive_group_dataset are both set
    # and min_diff_dataset is not or vice versa.
    min_diff_dataset_present = min_diff_dataset is not None
    sensitive_dataset_present = sensitive_group_dataset is not None
    nonsensitive_dataset_present = nonsensitive_group_dataset is not None
    # Case where min_diff_dataset is set and the others are not.
    set_to_use_min_diff_dataset = (
        min_diff_dataset_present
        and not (sensitive_dataset_present or nonsensitive_dataset_present))
    # Case where sensitive_group_dataset and nonsensitive_group_dataset are both
    # set and min_diff_dataset is not.
    set_to_construct_min_diff_dataset = ((sensitive_dataset_present
                                          and nonsensitive_dataset_present)
                                         and not min_diff_dataset_present)
    if not (set_to_use_min_diff_dataset or set_to_construct_min_diff_dataset):
        raise ValueError(
            "Invalid arguments: You must either pass in only the `min_diff_dataset`"
            " (and leave `sensitive_group_dataset` and `nonsensitive_group_dataset`"
            " as None) or set both `sensitive_group_dataset` and "
            "`nonsensitive_group_dataset` (and leave `min_diff_dataset` as None), "
            "given: \n"
            "\n`sensitive_group_dataset`: {}"
            "\n`nonsensitive_group_dataset`: {}"
            "\n`min_diff_dataset`: {}".format(sensitive_group_dataset,
                                              nonsensitive_group_dataset,
                                              min_diff_dataset))

    # First construct the min_diff_dataset if need be.
    if set_to_construct_min_diff_dataset:
        min_diff_dataset = build_min_diff_dataset(sensitive_group_dataset,
                                                  nonsensitive_group_dataset)
    else:
        # validate min_diff_dataset since it was passed in.
        structure_utils.validate_min_diff_structure(
            min_diff_dataset,
            struct_name="min_diff_dataset",
            element_type=tf.data.Dataset)

    dataset = tf.data.Dataset.zip((original_dataset, min_diff_dataset))

    def _map_fn(original_batch, min_diff_batch):
        """Packs one original batch together with its MinDiff batch."""
        # Unpack original batch.
        original_x, original_y, original_sample_weight = (
            tf.keras.utils.unpack_x_y_sample_weight(original_batch))

        # Assert that all min_diff_xs have the same structure as original_x.
        # TODO: Should we assert that Tensor shapes are the same (other
        #                    than number of examples).

        min_diff_xs = [
            tf.keras.utils.unpack_x_y_sample_weight(batch)[
                0]  # First element is x.
            for batch in structure_utils._flatten_min_diff_structure(
                min_diff_batch)
        ]
        for min_diff_x in min_diff_xs:
            try:
                tf.nest.assert_same_structure(original_x, min_diff_x)
            except Exception as e:
                # Re-raise with the same exception type but a more helpful
                # message, keeping the original error as the explicit cause.
                raise type(e)(
                    "The x component structure of (one of) the `min_diff_dataset`(s) "
                    "does not match that of the original x structure (original shown "
                    "first): {}".format(e)) from e

        # pack min_diff_batch with original_x
        return _pack_as_original(
            original_batch,
            MinDiffPackedInputs(original_inputs=original_x,
                                min_diff_data=min_diff_batch), original_y,
            original_sample_weight)

    # Reshape dataset output.
    return dataset.map(_map_fn)
Exemplo n.º 6
0
def build_min_diff_dataset(sensitive_group_dataset,
                           nonsensitive_group_dataset) -> tf.data.Dataset:
    # pyformat: disable
    """Build MinDiff dataset from sensitive and nonsensitive datasets.

    This function builds a `tf.data.Dataset` containing examples that are
    meant to only be used when calculating a `min_diff_loss`. This resulting
    dataset will need to be packed with the original dataset used for the
    original task of the model which can be done by calling
    `utils.pack_min_diff_data`.

    Warning: All input datasets should be batched **before** being passed in.

    Each input dataset must output a tuple in the format used in
    `tf.keras.Model.fit`. Specifically the output must be a tuple of
    length 1, 2 or 3 in the form `(x, y, sample_weight)`.

    This output will be parsed internally in the following way:

    ```
    batch = ...  # Batch from any of the input datasets.
    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
    ```

    Note: the `y` component of input datasets will be ignored completely so
    it can be set to `None` or any other arbitrary value. If `sample_weight`
    is not included, it can be left out entirely.

    Every batch from the returned `tf.data.Dataset` will contain one batch
    from each of the input datasets. Each returned batch will be a tuple or
    structure (matching the structure of the inputs) of
    `(min_diff_x, min_diff_membership, min_diff_sample_weight)` where, for
    each pair of input datasets:

    - `min_diff_x`: is formed by concatenating the `x` components of the
      paired datasets. The structure of these must match. If they don't the
      dataset will raise an error at the first batch.
    - `min_diff_membership`: is a tensor of size `[min_diff_batch_size, 1]`
      indicating which dataset each example comes from (`1.0` for
      `sensitive_group_dataset` and `0.0` for `nonsensitive_group_dataset`).
    - `min_diff_sample_weight`: is formed by concatenating the
      `sample_weight` components of the paired datasets. If both are `None`,
      then this will be set to `None`. If only one is `None`, it is replaced
      with a `Tensor` of ones of the appropriate shape.

    Args:
      sensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
        (unnested dict) of `tf.data.Dataset`s containing only examples that
        belong to the sensitive group.
      nonsensitive_group_dataset: `tf.data.Dataset` or valid MinDiff
        structure (unnested dict) of `tf.data.Dataset`s containing only
        examples that do **not** belong to the sensitive group.

    Returns:
      A `tf.data.Dataset` whose output is a tuple or structure (matching the
        structure of the inputs) of `(min_diff_x, min_diff_membership,
        min_diff_sample_weight)`.

    Raises:
      ValueError: If either `sensitive_group_dataset` or
        `nonsensitive_group_dataset` is not a valid MinDiff structure
        (unnested dict).
      ValueError: If `sensitive_group_dataset` and
        `nonsensitive_group_dataset` do not have the same structure.
    """
    # pyformat: enable
    # validate structures.
    structure_utils.validate_min_diff_structure(
        sensitive_group_dataset,
        struct_name="sensitive_group_dataset",
        element_type=tf.data.Dataset)
    structure_utils.validate_min_diff_structure(
        nonsensitive_group_dataset,
        struct_name="nonsensitive_group_dataset",
        element_type=tf.data.Dataset)
    try:
        structure_utils._assert_same_min_diff_structure(
            sensitive_group_dataset, nonsensitive_group_dataset)
    except Exception as e:
        # Same exception type, clearer message, original kept as the cause.
        raise type(e)(
            "`sensitive_group_dataset` and `nonsensitive_group_dataset` "
            "do not have the same structure:\n{}".format(e)) from e

    # Repeat every input dataset before zipping them together.
    sensitive_group_dataset = tf.nest.map_structure(
        lambda dataset: dataset.repeat(), sensitive_group_dataset)
    nonsensitive_group_dataset = tf.nest.map_structure(
        lambda dataset: dataset.repeat(), nonsensitive_group_dataset)

    dataset = tf.data.Dataset.zip(
        (sensitive_group_dataset, nonsensitive_group_dataset))

    def _build_single_batch(single_sensitive_batch, single_nonsensitive_batch):
        """Merges one sensitive and one nonsensitive batch into MinDiff data."""
        # Unpack both batches.
        sensitive_x, _, sensitive_sample_weight = (
            tf.keras.utils.unpack_x_y_sample_weight(single_sensitive_batch))
        nonsensitive_x, _, nonsensitive_sample_weight = (
            tf.keras.utils.unpack_x_y_sample_weight(single_nonsensitive_batch))

        # sensitive_x and nonsensitive_x must have the same structure.
        try:
            tf.nest.assert_same_structure(sensitive_x, nonsensitive_x)
        except Exception as e:
            # Same exception type, clearer message, original kept as the cause.
            raise type(e)(
                "The x component structure of (one of) the "
                "`sensitive_group_dataset`(s) does not match that of the "
                "(corresponding) `nonsensitive_group_dataset` x structure "
                "(sensitive shown first): {}".format(e)) from e

        # Create min_diff_data.
        # Merge sensitive_x and nonsensitive_x to form min_diff_x.
        flat_sensitive_x = tf.nest.flatten(sensitive_x)
        flat_nonsensitive_x = tf.nest.flatten(nonsensitive_x)
        flat_min_diff_x = [
            tf.concat([t1, t2], axis=0)
            for t1, t2 in zip(flat_sensitive_x, flat_nonsensitive_x)
        ]
        min_diff_x = tf.nest.pack_sequence_as(sensitive_x, flat_min_diff_x)

        # min_diff_membership indicates which dataset each example comes from.
        # The leading dimension of the first flattened x tensor is used as the
        # number of examples in each batch.
        sensitive_shape = [tf.shape(flat_sensitive_x[0])[0], 1]
        nonsensitive_shape = [tf.shape(flat_nonsensitive_x[0])[0], 1]
        min_diff_membership = tf.concat(axis=0,
                                        values=[
                                            tf.ones(sensitive_shape,
                                                    dtype=tf.float32),
                                            tf.zeros(nonsensitive_shape,
                                                     dtype=tf.float32)
                                        ])
        # min_diff_sample_weight is the concatenation of both sample_weights.
        min_diff_sample_weight = None  # Default if both sample_weights are None.
        if (sensitive_sample_weight is not None
                or nonsensitive_sample_weight is not None):
            # At most one of the two is None here; substitute ones for it.
            if sensitive_sample_weight is None:
                sensitive_sample_weight = tf.ones(sensitive_shape,
                                                  dtype=tf.float32)
            elif nonsensitive_sample_weight is None:
                nonsensitive_sample_weight = tf.ones(nonsensitive_shape,
                                                     dtype=tf.float32)
            min_diff_sample_weight = tf.concat(
                [sensitive_sample_weight, nonsensitive_sample_weight], axis=0)

        # Pack the three components and return them
        return tf.keras.utils.pack_x_y_sample_weight(min_diff_x,
                                                     min_diff_membership,
                                                     min_diff_sample_weight)

    def _map_fn(sensitive_batch, nonsensitive_batch):
        """Builds MinDiff data for every pair of corresponding batches."""
        flat_sensitive_batch = structure_utils._flatten_min_diff_structure(
            sensitive_batch)
        flat_nonsensitive_batch = structure_utils._flatten_min_diff_structure(
            nonsensitive_batch)

        flat_min_diff_data = [
            _build_single_batch(single_sensitive_batch,
                                single_nonsensitive_batch)
            for single_sensitive_batch, single_nonsensitive_batch in zip(
                flat_sensitive_batch, flat_nonsensitive_batch)
        ]

        # Restore the structure of the inputs around the merged batches.
        return structure_utils._pack_min_diff_sequence_as(
            sensitive_batch, flat_min_diff_data)

    # Reshape dataset output.
    return dataset.map(_map_fn)
Exemplo n.º 7
0
 def testValidateRaisesErrorWithBadElementType(self):
     """Expects a `TypeError` when `element_type` is given a str instance."""
     bad_element_type = "a"  # An object instance, not a type.
     expected_regex = (
         "element_type.*expected type.*object instance was given.*a")
     with self.assertRaisesRegex(TypeError, expected_regex):
         structure_utils.validate_min_diff_structure(
             "elem", element_type=bad_element_type)
Exemplo n.º 8
0
def _conform_weights_to_losses(loss, loss_weight, default_value):
    """Conforms weights to match structure of losses.

    Shape weights to match the structure of `loss` if possible. If
    `loss_weight` is a single value, it will be broadcast for all losses. If
    `loss_weight` is `None` or has missing entries, `default_value` will be
    used.

    Args:
      loss: loss (possible nested) that weights will be conformed to.
      loss_weight: weight that will be conformed to loss structure. If only a
        single value, it will be broadcast for all losses. If `None`, it will
        be replaced by `default_value`.
      default_value: Value used if `loss_weight` is `None` or if some weights
        are missing for certain losses.

    Returns:
      Weight corresponding to `loss` structure.

    Raises:
      ValueError: If `loss_weight` is nested but `loss` is not, if either is
        nested without being a dict, if `loss_weight` has keys that are not
        in `loss`, or if the structures still differ after defaults have
        been filled in.
    """
    # Validate loss (loss_weights will be implicitly validated)
    structure_utils.validate_min_diff_structure(loss, struct_name="loss")

    # If loss_weight is unnested, then broadcast to all values of loss.
    if not tf.nest.is_nested(loss_weight):
        if loss_weight is None:
            loss_weight = default_value
        return tf.nest.map_structure(lambda _: loss_weight, loss)

    # If execution reaches here, then loss_weight is nested (a dict).

    # If loss is not nested, then raise an error (since loss_weight is a nested).
    if not tf.nest.is_nested(loss):
        try:
            tf.nest.assert_same_structure(loss, loss_weight)
        except Exception as e:
            raise ValueError("`loss` and `loss_weight` do not have matching "
                             "structures: \n{}".format(e)) from e

    # At this point, we should be guaranteed that the two structures are dicts if
    # they are valid MinDiff structures. However, in case they are not, we assert
    # that they are both dicts (this also helps be future proof since it will
    # catch the broken assumption immediately if the validity definition changes).
    # Note: As is, it should be impossible to get to this point. The only way it
    #       would is if this function is called without validating or if the
    #       definition of a valid MinDiff structure has changed.
    if not (isinstance(loss, dict) and isinstance(loss_weight, dict)):
        raise ValueError(
            "One of `loss` and `loss_weight` is neither a single element nor a "
            "dict. This should never happen if they are valid MinDiff structures. "
            "If you think this is a valid use case (e.g. if the definition has "
            "changed but this piece of code is out of sync), please file an issue "
            "so we can look at it and make the appropriate fix.")

    # Save copy to not alter the original dict.
    loss_weight = loss_weight.copy()

    # First, we make sure to set defaults for any losses that do not have
    # corresponding weights. Raise an error if there are weights with keys that
    # don't correspond to losses.
    # Dict key views support subset comparison directly.
    if not loss_weight.keys() <= loss.keys():
        raise ValueError(
            "`loss_weight` contains keys that do not correspond to losses:"
            "\n\nloss: {}\n\nloss_weight: {}".format(loss, loss_weight))

    # Provide defaults for any missing weights (existing ones are untouched).
    for key in loss:
        loss_weight.setdefault(key, default_value)

    # At this point, we should be guaranteed that the two structures match if they
    # are valid MinDiff structures. However, in case they are not we assert that
    # they match.
    try:
        tf.nest.assert_same_structure(loss, loss_weight)
    except Exception as e:
        raise ValueError(
            "`loss` and `loss_weight` (potentially with default weights added) "
            "do not have matching structures: \n{}".format(e)) from e

    return loss_weight
Exemplo n.º 9
0
    def __init__(self,
                 original_model: tf.keras.Model,
                 loss,
                 loss_weight=1.0,
                 predictions_transform=None,
                 **kwargs):
        """Initializes a MinDiffModel instance.

        Args:
          original_model: The `tf.keras.Model` being wrapped; stored on the
            instance.
          loss: A valid MinDiff structure of losses. Each element is converted
            with `loss_utils._get_loss`.
          loss_weight: A valid MinDiff structure of weights that is conformed
            to the structure of `loss`; weights missing for a loss default to
            `1.0`. Defaults to `1.0`.
          predictions_transform: Optional callable stored on the instance.
            Must be callable if it is not `None`.
          **kwargs: Forwarded to the base class `__init__` call(s).

        Raises:
          ValueError: If `predictions_transform` is passed in but not
            callable.
          Exception: Any error raised while initializing the base class is
            re-raised with the same type and an explanatory message.
        """
        # Roundabout way of accessing the Functional class.
        functional_class = tf.keras.Sequential.__bases__[0]
        # We need to handle a special case where a custom MinDiffModel class is
        # created that is also a subclass of the Functional class. In this case, we
        # need to make sure that args match what the Functional.__init__ requires
        # (i.e. `inputs` and `outputs` args) and that the rest of the
        # Functional.__init__ method is skipped (supported by passing in
        # `skip_init=True`).
        # This requires any __init__ methods to not do input validation and to
        # pass through `skip_init`.
        if (isinstance(self, functional_class)
                and not isinstance(self, tf.keras.Sequential)):
            try:
                super(MinDiffModel, self).__init__(inputs=None,
                                                   outputs=None,
                                                   skip_init=True,
                                                   **kwargs)
                tf.keras.Model.__init__(self, **kwargs)
            except Exception as e:
                # Keep the exception type but explain the likely causes; the
                # original error is preserved as the explicit cause.
                raise type(e)(
                    "There was a problem initializing the MinDiffModel subclass "
                    "instance. This was likely caused by:\n"
                    "  - The kwargs that were passed in were not valid according to "
                    "tf.keras.Model or a base of your custom Model.\n"
                    "  - Some args validation or requirement in your custom Model "
                    "__init__ method is too strict.\n"
                    "  - Your Model subclass is not passing through **kwargs (in "
                    "particular `skip_init`) to the super().__init__ invocation.\n"
                    "To fix this, either fix the args, loosen the requirements, or "
                    "make sure to pass **kwargs to calls with super. If this is not "
                    "possible, you may need to integrate MinDiff without using "
                    "MinDiffModel.\n"
                    "Error raised: {}".format(e)) from e
        else:
            try:
                super(MinDiffModel, self).__init__(**kwargs)
            except Exception as e:
                # Keep the exception type; preserve the original as the cause.
                raise type(e)(
                    "There was a problem initializing the MinDiffModel instance. "
                    "This was likely caused by the kwargs that were passed in not "
                    "being valid according to tf.keras.Model.\n"
                    "Error raised: {}".format(e)) from e

        # Set _auto_track_sub_layers to true to ensure we track the
        # original_model and MinDiff layers.
        self._auto_track_sub_layers = True  # Track sub layers.
        self.built = True  # This Model is built, original_model may or may not be.
        # Masking, if any, is taken care of by original_model.
        self._supports_masking = False
        # Clear input_spec in case there is one. We cannot make any strong
        # assertions because `min_diff_data` may or may not be included and can
        # have different shapes since weight is optional.
        self.input_spec = None

        self._original_model = original_model
        # Validate the raw structures before converting / conforming them.
        structure_utils.validate_min_diff_structure(loss, struct_name="loss")
        self._loss = tf.nest.map_structure(loss_utils._get_loss, loss)
        structure_utils.validate_min_diff_structure(loss_weight,
                                                    struct_name="loss_weight")
        self._loss_weight = _conform_weights_to_losses(self._loss,
                                                       loss_weight,
                                                       default_value=1.0)
        self._min_diff_loss_metric = _create_unique_metrics(
            self._loss, self.metrics)

        if (predictions_transform is not None
                and not callable(predictions_transform)):
            raise ValueError(
                "`predictions_transform` must be callable if passed "
                "in, given: {}".format(predictions_transform))
        self._predictions_transform = predictions_transform