def testValidateDictWithTypeCheck(self):
  """Dict elements are checked against `element_type` when it is given."""
  elem = "element"
  struct = {"a": elem, "b": elem, "c": elem}
  structure_utils.validate_min_diff_structure(struct, element_type=str)

  # Assert raises an error if one or more elements are the wrong type.

  # A single element is the wrong type.
  struct = {"a": elem, "b": 2, "c": elem}
  with self.assertRaisesRegex(
      ValueError, "name.* is a dict.*must be unnested and contain only "
      "elements of type.*str.*{.*2.*}"):
    structure_utils.validate_min_diff_structure(
        struct, struct_name="name", element_type=str)

  # All elements are the (same) wrong type. Every value is an int here, unlike
  # the single-wrong-element case above.
  struct = {"a": 1, "b": 2, "c": 3}
  with self.assertRaisesRegex(
      ValueError, "name.* is a dict.*must be unnested and contain only "
      "elements of type.*str.*{.*2.*}"):
    structure_utils.validate_min_diff_structure(
        struct, struct_name="name", element_type=str)
def testValidateSingleElement(self):
  """A single unnested value (str, int or tuple) is accepted as a structure."""
  for single_element in ("element", 3, ("a", "b", "c")):
    structure_utils.validate_min_diff_structure(single_element)
def testValidateSingleElementWithTypeCheck(self):
  """A single element passes with a matching type and fails otherwise."""
  # Each case: (element, matching type, mismatching type, expected regex).
  cases = [
      ("element", str, int,
       "not a recognized MinDiff structure.*should have a type of"
       ".*single unnested element.*of type.*int.*str"),
      (3, int, str,
       "not a recognized MinDiff structure.*should have a type of"
       ".*single unnested element.*of type.*str.*.*int"),
  ]
  for value, good_type, bad_type, error_pattern in cases:
    structure_utils.validate_min_diff_structure(value, element_type=good_type)
    with self.assertRaisesRegex(TypeError, error_pattern):
      structure_utils.validate_min_diff_structure(value, element_type=bad_type)
def testValidateDict(self):
  """Flat string-keyed dicts pass; nested dicts and non-string keys fail."""
  value = "element"
  flat_dict = {"a": value, "b": value, "c": value}
  structure_utils.validate_min_diff_structure(flat_dict)

  # A dict nested inside the dict is rejected.
  nested_dict = {"a": value, "b": {"d": value}, "c": value}
  with self.assertRaisesRegex(
      ValueError, "name.* is a dict.*must be unnested.*{.*{.*}.*}"):
    structure_utils.validate_min_diff_structure(nested_dict, struct_name="name")

  # A non-string key is rejected.
  mixed_keys_dict = {"a": value, 3: value, "c": value}
  with self.assertRaisesRegex(
      ValueError, r"name.*must contain only string keys.*\['a', 3, 'c'\]"):
    structure_utils.validate_min_diff_structure(
        mixed_keys_dict, struct_name="name")
def pack_min_diff_data(original_dataset: tf.data.Dataset,
                       sensitive_group_dataset=None,
                       nonsensitive_group_dataset=None,
                       min_diff_dataset=None) -> tf.data.Dataset:
  # pyformat: disable
  """Packs `min_diff_data` with the `x` component of the original dataset.

  Args:
    original_dataset: `tf.data.Dataset` that was used before applying min
      diff. The output should conform to the format used in
      `tf.keras.Model.fit`.
    sensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that
      belong to the sensitive group.

      This must be passed in if `nonsensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.
    nonsensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that do
      **not** belong to the sensitive group.

      This must be passed in if `sensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.
    min_diff_dataset: `tf.data.Dataset` or valid MinDiff structure (unnested
      dict) of `tf.data.Dataset`s containing only examples to be used to
      calculate the `min_diff_loss`.

      This should only be set if neither `sensitive_group_dataset` nor
      `nonsensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.

  This function should be used to create the dataset that will be passed to
  `min_diff.keras.MinDiffModel` during training and, optionally, during
  evaluation.

  The inputs should either have both `sensitive_group_dataset` and
  `nonsensitive_group_dataset` passed in and `min_diff_dataset` left unset, or
  vice versa. In the case of the former, `min_diff_data` will be built using
  `utils.build_min_diff_dataset`.

  Warning: All input datasets should be batched **before** being passed in.

  Each input dataset must output a tuple in the format used in
  `tf.keras.Model.fit`. Specifically the output must be a tuple of
  length 1, 2 or 3 in the form `(x, y, sample_weight)`.

  This output will be parsed internally in the following way:

  ```
  batch = ...  # Batch from any one of the input datasets.
  x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
  ```

  Every batch from the returned `tf.data.Dataset` will contain one batch from
  each of the input datasets. Each returned batch will be a tuple of
  `(packed_inputs, original_y, original_sample_weight)` matching the length of
  `original_dataset` batches where:

  - `packed_inputs`: is an instance of `utils.MinDiffPackedInputs` containing:

    - `original_inputs`: `x` component taken directly from the
      `original_dataset` batch.
    - `min_diff_data`: batch of data formed from `sensitive_group_dataset` and
      `nonsensitive_group_dataset` (as described in
      `utils.build_min_diff_dataset`) or taken directly from
      `min_diff_dataset`.

  - `original_y`: is the `y` component taken directly from the
    `original_dataset` batch.
  - `original_sample_weight`: is the `sample_weight` component taken directly
    from the `original_dataset` batch.

  `min_diff_data` will be used in `min_diff.keras.MinDiffModel` when
  calculating the `min_diff_loss`. It is a tuple or structure (matching the
  structure of the inputs) of
  `(min_diff_x, min_diff_membership, min_diff_sample_weight)`.

  Caution: If you are passing in `min_diff_dataset` make sure that each
  `min_diff_data` batch contains about the same number of sensitive and
  nonsensitive examples as indicated by `min_diff_membership` (when passing in
  `sensitive_group_dataset` and `nonsensitive_group_dataset` this is
  determined by their batch sizes).

  Note: `original_dataset` and the MinDiff data are combined with
  `tf.data.Dataset.zip`, which stops at the shortest input. A
  `min_diff_dataset` passed in directly is used as-is (it is not repeated), so
  if it yields fewer batches than `original_dataset` the result is truncated.

  Returns:
    A `tf.data.Dataset` whose output is a tuple of (`packed_inputs`,
      `original_y`, `original_sample_weight`) matching the output length
      of `original_dataset`.

  Raises:
    ValueError: If the arguments are not exactly one of the two supported
      combinations (`min_diff_dataset` alone, or both
      `sensitive_group_dataset` and `nonsensitive_group_dataset`).
    TypeError, ValueError: If the passed-in `min_diff_dataset` is not a valid
      MinDiff structure of `tf.data.Dataset`s.
  """
  # pyformat: enable
  # Either sensitive_group_dataset and nonsensitive_group_dataset are both set
  # and min_diff_dataset is not, or vice versa.
  min_diff_dataset_present = min_diff_dataset is not None
  sensitive_dataset_present = sensitive_group_dataset is not None
  nonsensitive_dataset_present = nonsensitive_group_dataset is not None

  # Case where min_diff_dataset is set and the others are not.
  set_to_use_min_diff_dataset = (
      min_diff_dataset_present and
      not (sensitive_dataset_present or nonsensitive_dataset_present))

  # Case where sensitive_group_dataset and nonsensitive_group_dataset are both
  # set and min_diff_dataset is not.
  set_to_construct_min_diff_dataset = (
      (sensitive_dataset_present and nonsensitive_dataset_present) and
      not min_diff_dataset_present)

  if not (set_to_use_min_diff_dataset or set_to_construct_min_diff_dataset):
    raise ValueError(
        "Invalid arguments: You must either pass in only the `min_diff_dataset`"
        " (and leave `sensitive_group_dataset` and `nonsensitive_group_dataset`"
        " as None) or set both `sensitive_group_dataset` and "
        "`nonsensitive_group_dataset` (and leave `min_diff_dataset` as None), "
        "given: \n"
        "\n`sensitive_group_dataset`: {}"
        "\n`nonsensitive_group_dataset`: {}"
        "\n`min_diff_dataset`: {}".format(sensitive_group_dataset,
                                          nonsensitive_group_dataset,
                                          min_diff_dataset))

  # First construct the min_diff_dataset if need be.
  if set_to_construct_min_diff_dataset:
    min_diff_dataset = build_min_diff_dataset(sensitive_group_dataset,
                                              nonsensitive_group_dataset)
  else:
    # Validate min_diff_dataset since it was passed in directly.
    structure_utils.validate_min_diff_structure(
        min_diff_dataset,
        struct_name="min_diff_dataset",
        element_type=tf.data.Dataset)

  dataset = tf.data.Dataset.zip((original_dataset, min_diff_dataset))

  def _map_fn(original_batch, min_diff_batch):
    # Unpack original batch.
    original_x, original_y, original_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(original_batch))

    # Assert that all min_diff_xs have the same structure as original_x.
    # TODO: Should we assert that Tensor shapes are the same (other
    # than number of examples).
    min_diff_xs = [
        tf.keras.utils.unpack_x_y_sample_weight(batch)[0]  # First element is x.
        for batch in structure_utils._flatten_min_diff_structure(min_diff_batch)
    ]
    for min_diff_x in min_diff_xs:
      try:
        tf.nest.assert_same_structure(original_x, min_diff_x)
      except Exception as e:
        # Keep the original exception type but add context; chain explicitly
        # so the underlying structure error stays visible as the cause.
        raise type(e)(
            "The x component structure of (one of) the `min_diff_dataset`(s) "
            "does not match that of the original x structure (original shown "
            "first): {}".format(e)) from e

    # Pack min_diff_batch with original_x.
    return _pack_as_original(
        original_batch,
        MinDiffPackedInputs(
            original_inputs=original_x, min_diff_data=min_diff_batch),
        original_y, original_sample_weight)

  # Reshape dataset output.
  return dataset.map(_map_fn)
def build_min_diff_dataset(sensitive_group_dataset,
                           nonsensitive_group_dataset) -> tf.data.Dataset:
  # pyformat: disable
  """Build MinDiff dataset from sensitive and nonsensitive datasets.

  Args:
    sensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that
      belong to the sensitive group.
    nonsensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that do
      **not** belong to the sensitive group.

  This function builds a `tf.data.Dataset` containing examples that are meant
  to only be used when calculating a `min_diff_loss`. The resulting dataset
  will need to be packed with the original dataset used for the original task
  of the model which can be done by calling `utils.pack_min_diff_data`.

  Every input dataset is repeated indefinitely (`tf.data.Dataset.repeat()`),
  so the returned dataset is infinite.

  Warning: All input datasets should be batched **before** being passed in.

  Each input dataset must output a tuple in the format used in
  `tf.keras.Model.fit`. Specifically the output must be a tuple of
  length 1, 2 or 3 in the form `(x, y, sample_weight)`.

  This output will be parsed internally in the following way:

  ```
  batch = ...  # Batch from any of the input datasets.
  x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
  ```

  Note: the `y` component of input datasets will be ignored completely so it
  can be set to `None` or any other arbitrary value. If `sample_weight` is not
  included, it can be left out entirely.

  Every batch from the returned `tf.data.Dataset` will contain one batch from
  each of the input datasets. Each returned batch will be a tuple or structure
  (matching the structure of the inputs) of
  `(min_diff_x, min_diff_membership, min_diff_sample_weight)` where, for each
  pair of input datasets:

  - `min_diff_x`: is formed by concatenating the `x` components of the paired
    datasets. The structure of these must match. If they don't, an error is
    raised.
  - `min_diff_membership`: is a tensor of size `[min_diff_batch_size, 1]`
    indicating which dataset each example comes from (`1.0` for
    `sensitive_group_dataset` and `0.0` for `nonsensitive_group_dataset`).
  - `min_diff_sample_weight`: is formed by concatenating the `sample_weight`
    components of the paired datasets. If both are `None`, then this will be
    set to `None`. If only one is `None`, it is replaced with a `Tensor` of
    ones with the matching number of examples and the same trailing shape and
    dtype as the other (non-`None`) `sample_weight`.

  Returns:
    A `tf.data.Dataset` whose output is a tuple or structure (matching the
      structure of the inputs) of
      `(min_diff_x, min_diff_membership, min_diff_sample_weight)`.

  Raises:
    ValueError: If either `sensitive_group_dataset` or
      `nonsensitive_group_dataset` is not a valid MinDiff structure (unnested
      dict).
    TypeError: If either argument is a single element that is not a
      `tf.data.Dataset`.
    ValueError: If `sensitive_group_dataset` and `nonsensitive_group_dataset`
      do not have the same structure.
  """
  # pyformat: enable
  # Validate structures.
  structure_utils.validate_min_diff_structure(
      sensitive_group_dataset,
      struct_name="sensitive_group_dataset",
      element_type=tf.data.Dataset)
  structure_utils.validate_min_diff_structure(
      nonsensitive_group_dataset,
      struct_name="nonsensitive_group_dataset",
      element_type=tf.data.Dataset)
  try:
    structure_utils._assert_same_min_diff_structure(sensitive_group_dataset,
                                                    nonsensitive_group_dataset)
  except Exception as e:
    raise type(e)(
        "`sensitive_group_dataset` and `nonsensitive_group_dataset` "
        "do not have the same structure:\n{}".format(e)) from e

  # Repeat every dataset indefinitely so each group keeps producing batches.
  sensitive_group_dataset = tf.nest.map_structure(
      lambda dataset: dataset.repeat(), sensitive_group_dataset)
  nonsensitive_group_dataset = tf.nest.map_structure(
      lambda dataset: dataset.repeat(), nonsensitive_group_dataset)

  dataset = tf.data.Dataset.zip(
      (sensitive_group_dataset, nonsensitive_group_dataset))

  def _ones_like_weight(num_examples, reference_weight):
    """Returns ones for `num_examples` rows shaped and typed like the reference.

    The trailing dimensions and dtype are copied from `reference_weight` so
    that the result can always be concatenated with it along axis 0 (e.g. both
    `[n]` or both `[n, 1]`, both float32 or both float64).
    """
    reference_weight = tf.convert_to_tensor(reference_weight)
    shape = tf.concat(
        [tf.reshape(num_examples, [1]), tf.shape(reference_weight)[1:]],
        axis=0)
    return tf.ones(shape, dtype=reference_weight.dtype)

  def _build_single_batch(single_sensitive_batch, single_nonsensitive_batch):
    # Unpack both batches.
    sensitive_x, _, sensitive_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(single_sensitive_batch))
    nonsensitive_x, _, nonsensitive_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(single_nonsensitive_batch))

    # sensitive_x and nonsensitive_x must have the same structure.
    try:
      tf.nest.assert_same_structure(sensitive_x, nonsensitive_x)
    except Exception as e:
      raise type(e)(
          "The x component structure of (one of) the "
          "`sensitive_group_dataset`(s) does not match that of the "
          "(corresponding) `nonsensitive_group_dataset` x structure "
          "(sensitive shown first): {}".format(e)) from e

    # Create min_diff_data.
    # Merge sensitive_x and nonsensitive_x to form min_diff_x.
    flat_sensitive_x = tf.nest.flatten(sensitive_x)
    flat_nonsensitive_x = tf.nest.flatten(nonsensitive_x)
    flat_min_diff_x = [
        tf.concat([t1, t2], axis=0)
        for t1, t2 in zip(flat_sensitive_x, flat_nonsensitive_x)
    ]
    min_diff_x = tf.nest.pack_sequence_as(sensitive_x, flat_min_diff_x)

    # min_diff_membership indicates which dataset each example comes from.
    # The number of examples is read from the first leaf of each x component.
    sensitive_size = tf.shape(flat_sensitive_x[0])[0]
    nonsensitive_size = tf.shape(flat_nonsensitive_x[0])[0]
    min_diff_membership = tf.concat(
        axis=0,
        values=[
            tf.ones([sensitive_size, 1], dtype=tf.float32),
            tf.zeros([nonsensitive_size, 1], dtype=tf.float32)
        ])

    # min_diff_sample_weight is the concatenation of both sample_weights.
    min_diff_sample_weight = None  # Default if both sample_weights are None.
    if (sensitive_sample_weight is not None or
        nonsensitive_sample_weight is not None):
      # Fill in the missing side with ones that match the provided weight's
      # rank and dtype; a fixed `[n, 1]` float32 filler would fail to
      # concatenate with rank-1 or non-float32 weights.
      if sensitive_sample_weight is None:
        sensitive_sample_weight = _ones_like_weight(sensitive_size,
                                                    nonsensitive_sample_weight)
      elif nonsensitive_sample_weight is None:
        nonsensitive_sample_weight = _ones_like_weight(nonsensitive_size,
                                                       sensitive_sample_weight)
      min_diff_sample_weight = tf.concat(
          [sensitive_sample_weight, nonsensitive_sample_weight], axis=0)

    # Pack the three components and return them.
    return tf.keras.utils.pack_x_y_sample_weight(min_diff_x,
                                                 min_diff_membership,
                                                 min_diff_sample_weight)

  def _map_fn(sensitive_batch, nonsensitive_batch):
    flat_sensitive_batch = structure_utils._flatten_min_diff_structure(
        sensitive_batch)
    flat_nonsensitive_batch = structure_utils._flatten_min_diff_structure(
        nonsensitive_batch)

    flat_min_diff_data = [
        _build_single_batch(single_sensitive_batch, single_nonsensitive_batch)
        for single_sensitive_batch, single_nonsensitive_batch in zip(
            flat_sensitive_batch, flat_nonsensitive_batch)
    ]

    return structure_utils._pack_min_diff_sequence_as(sensitive_batch,
                                                      flat_min_diff_data)

  # Reshape dataset output.
  return dataset.map(_map_fn)
def testValidateRaisesErrorWithBadElementType(self):
  """A non-type `element_type` (here the string "a") raises a TypeError."""
  expected_pattern = "element_type.*expected type.*object instance was given.*a"
  with self.assertRaisesRegex(TypeError, expected_pattern):
    structure_utils.validate_min_diff_structure("elem", element_type="a")
def _conform_weights_to_losses(loss, loss_weight, default_value):
  """Conforms weights to match structure of losses.

  Shape weights to match the structure of `loss` if possible. If `loss_weight`
  is a single value, it will be broadcast for all losses. If `loss_weight` is
  `None` or has missing entries, `default_value` will be used.

  Args:
    loss: loss (possibly nested) that weights will be conformed to. Must be a
      valid MinDiff structure (a single element or an unnested dict).
    loss_weight: weight that will be conformed to loss structure. If only a
      single value, it will be broadcast for all losses. If `None`, it will be
      replaced by `default_value`.
    default_value: Value used if `loss_weight` is `None` or if some weights are
      missing for certain losses.

  Returns:
    Weight corresponding to `loss` structure. When `loss_weight` is a dict, a
    new dict is returned and the caller's dict is left unmodified.

  Raises:
    ValueError: If `loss_weight` is a dict but `loss` is a single element, if
      `loss_weight` contains keys that do not correspond to any loss, or if the
      two structures otherwise cannot be matched.
    TypeError, ValueError: If `loss` is not a valid MinDiff structure (raised
      by `structure_utils.validate_min_diff_structure`).
  """
  # Validate loss (loss_weights will be implicitly validated).
  structure_utils.validate_min_diff_structure(loss, struct_name="loss")

  # If loss_weight is unnested, then broadcast to all values of loss.
  if not tf.nest.is_nested(loss_weight):
    if loss_weight is None:
      loss_weight = default_value
    return tf.nest.map_structure(lambda _: loss_weight, loss)

  # If execution reaches here, then loss_weight is nested (a dict).
  # If loss is not nested, then raise an error (since loss_weight is nested).
  if not tf.nest.is_nested(loss):
    try:
      tf.nest.assert_same_structure(loss, loss_weight)
    except Exception as e:
      raise ValueError("`loss` and `loss_weight` do not have matching "
                       "structures: \n{}".format(e)) from e

  # At this point, we should be guaranteed that the two structures are dicts if
  # they are valid MinDiff structures. However, in case they are not, we assert
  # that they are both dicts (this also helps be future proof since it will
  # catch the broken assumption immediately if the validity definition changes).
  # Note: As is, it should be impossible to get to this point. The only way it
  # would is if this function is called without validating or if the
  # definition of a valid MinDiff structure has changed.
  if not (isinstance(loss, dict) and isinstance(loss_weight, dict)):
    raise ValueError(
        "One of `loss` and `loss_weight` is neither a single element nor a "
        "dict. This should never happen if they are valid MinDiff structures. "
        "If you think this is a valid use case (e.g. if the definition has "
        "changed but this piece of code is out of sync), please file an issue "
        "so we can look at it and make the appropriate fix.")

  # Save copy to not alter the original dict.
  loss_weight = loss_weight.copy()

  # Raise an error if there are weights with keys that don't correspond to
  # losses (dict key views support set comparison directly).
  if not loss_weight.keys() <= loss.keys():
    raise ValueError(
        "`loss_weight` contains keys that do not correspond to losses:"
        "\n\nloss: {}\n\nloss_weight: {}".format(loss, loss_weight))

  # Provide defaults for any losses that do not have corresponding weights.
  for key in loss:
    loss_weight.setdefault(key, default_value)

  # At this point, we should be guaranteed that the two structures match if they
  # are valid MinDiff structures. However, in case they are not we assert that
  # they match.
  try:
    tf.nest.assert_same_structure(loss, loss_weight)
  except Exception as e:
    raise ValueError(
        "`loss` and `loss_weight` (potentially with default weights added) "
        "do not have matching structures: \n{}".format(e)) from e

  return loss_weight
def __init__(self,
             original_model: tf.keras.Model,
             loss,
             loss_weight=1.0,
             predictions_transform=None,
             **kwargs):
  """Initializes a MinDiffModel instance.

  Args:
    original_model: Model being wrapped; stored as `self._original_model`.
    loss: Single loss or valid MinDiff structure (unnested dict) of losses.
      The structure is validated and every element is converted with
      `loss_utils._get_loss`.
    loss_weight: Single weight or valid MinDiff structure of weights. It is
      validated and then conformed to the structure of `loss` by
      `_conform_weights_to_losses`, with `1.0` used for any missing weight.
    predictions_transform: Optional callable, stored as
      `self._predictions_transform`. Must be callable if not `None`.
    **kwargs: Forwarded to the superclass `__init__` (and, when `self` is a
      non-Sequential Functional subclass, also to `tf.keras.Model.__init__`).

  Raises:
    ValueError: If `predictions_transform` is passed in but not callable.
    Exception: Any error raised while initializing the superclass is re-raised
      with the same exception type and a message that lists likely causes and
      includes the original error.
  """
  # Roundabout way of accessing the Functional class.
  functional_class = tf.keras.Sequential.__bases__[0]
  # We need to handle a special case where a custom MinDiffModel class is
  # created that is also a subclass of the Functional class. In this case, we
  # need to make sure that args match what the Functional.__init__ requires
  # (i.e. `inputs` and `outputs` args) and that the rest of the
  # Functional.__init__ method is skipped (supported by passing in
  # `skip_init=True`).
  # This requires any __init__ methods to not do input validation and to
  # pass through `skip_init`.
  if (isinstance(self, functional_class) and
      not isinstance(self, tf.keras.Sequential)):
    try:
      super(MinDiffModel, self).__init__(
          inputs=None, outputs=None, skip_init=True, **kwargs)
      # NOTE(review): presumably performs the plain tf.keras.Model setup that
      # was bypassed by `skip_init=True` in the Functional path -- confirm.
      tf.keras.Model.__init__(self, **kwargs)
    except Exception as e:
      # Re-raise with the same exception type, listing likely causes/fixes.
      raise type(e)(
          "There was a problem initializing the MinDiffModel subclass "
          "instance. This was likely caused by:\n"
          " - The kwargs that were passed in were not valid according to "
          "tf.keras.Model or a base of your custom Model.\n"
          " - Some args validation or requirement in your custom Model "
          "__init__ method is too strict.\n"
          " - Your Model subclass is not passing through **kwargs (in "
          "particular `skip_init`) to the super().__init__ invocation.\n"
          "To fix this, either fix the args, loosen the requirements, or "
          "make sure to pass **kwargs to calls with super. If this is not "
          "possible, you may need to integrate MinDiff without using "
          "MinDiffModel.\n"
          "Error raised: {}".format(e))
  else:
    try:
      super(MinDiffModel, self).__init__(**kwargs)
    except Exception as e:
      # Re-raise with the same exception type and a hint about kwargs.
      raise type(e)(
          "There was a problem initializing the MinDiffModel instance. "
          "This was likely caused by the kwargs that were passed in not "
          "being valid according to tf.keras.Model.\n"
          "Error raised: {}".format(e))

  # Set _auto_track_sub_layers to true to ensure we track the
  # original_model and MinDiff layers.
  self._auto_track_sub_layers = True  # Track sub layers.
  self.built = True  # This Model is built, original_model may or may not be.
  # Masking, if any, is taken care of by original_model.
  self._supports_masking = False
  # Clear input_spec in case there is one. We cannot make any strong
  # assertions because `min_diff_data` may or may not be included and can
  # have different shapes since weight is optional.
  self.input_spec = None

  self._original_model = original_model
  # Validate the loss structure, then convert each element with
  # `loss_utils._get_loss` (presumably resolving it to a loss object).
  structure_utils.validate_min_diff_structure(loss, struct_name="loss")
  self._loss = tf.nest.map_structure(loss_utils._get_loss, loss)
  # Validate the weights, then conform them to the (converted) loss structure,
  # using 1.0 for any loss without a weight.
  structure_utils.validate_min_diff_structure(
      loss_weight, struct_name="loss_weight")
  self._loss_weight = _conform_weights_to_losses(
      self._loss, loss_weight, default_value=1.0)
  # NOTE(review): built from `self._loss` and the existing `self.metrics`;
  # presumably creates MinDiff loss metrics whose names don't collide with
  # existing ones -- confirm against `_create_unique_metrics`.
  self._min_diff_loss_metric = _create_unique_metrics(self._loss, self.metrics)

  if (predictions_transform is not None and
      not callable(predictions_transform)):
    raise ValueError("`predictions_transform` must be callable if passed "
                     "in, given: {}".format(predictions_transform))
  self._predictions_transform = predictions_transform