def testWithBothMinDiffWeightsNone(self):
  """Packed min_diff weights are None when both group datasets omit weights."""
  to_dataset = tf.data.Dataset.from_tensor_slices
  unpack = tf.keras.utils.unpack_x_y_sample_weight

  # Every dataset here carries features only (no labels, no weights).
  original_ds = to_dataset((self.original_x, None, None)).batch(5)
  sensitive_ds = to_dataset((self.sensitive_x, None, None)).batch(3)
  nonsensitive_ds = to_dataset((self.nonsensitive_x, None, None)).batch(2)

  packed_ds = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                             nonsensitive_ds)

  for element in packed_ds:
    packed_x, _, _ = unpack(element)
    self.assertIsInstance(packed_x, input_utils.MinDiffPackedInputs)

    # Only the weight slot of the min_diff data is of interest in this test.
    _, _, min_diff_weights = unpack(packed_x.min_diff_data)
    self.assertIsNone(min_diff_weights)
def testWithOriginaleWeightsNone(self):
  """Original sample weight stays None when the original dataset omits it."""
  # NOTE(review): the method name contains a typo ("Originale") and this test
  # makes the same weight check as testWithOriginalWeightsNone, which also
  # asserts the element length -- consider merging the two.
  to_dataset = tf.data.Dataset.from_tensor_slices

  original_ds = to_dataset(
      (self.original_x, self.original_y, None)).batch(5)
  sensitive_ds = to_dataset(
      (self.sensitive_x, None, self.sensitive_w)).batch(3)
  nonsensitive_ds = to_dataset(
      (self.nonsensitive_x, None, self.nonsensitive_w)).batch(1)

  packed_ds = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                             nonsensitive_ds)

  for element in packed_ds:
    # Of the original batch, only the weight component is checked.
    _, _, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(element)
    self.assertIsNone(sample_weight)
def testWithOnlyNonsensitiveWeightsNone(self):
  """Checks min_diff weights when only the nonsensitive dataset lacks them.

  The expected weights are built from the real sensitive weights and a
  `[nonsensitive_batch_size, 1]` tensor of 1.0 standing in for the missing
  nonsensitive weights.
  """
  to_dataset = tf.data.Dataset.from_tensor_slices
  unpack = tf.keras.utils.unpack_x_y_sample_weight
  sensitive_batch_size = 3
  nonsensitive_batch_size = 2

  original_ds = to_dataset((self.original_x, None, None)).batch(5)
  sensitive_ds = to_dataset(
      (self.sensitive_x, None, self.sensitive_w)).batch(sensitive_batch_size)
  nonsensitive_ds = to_dataset(
      (self.nonsensitive_x, None, None)).batch(nonsensitive_batch_size)

  packed_ds = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                             nonsensitive_ds)

  # Stand-in for the absent nonsensitive weights when building expectations.
  unit_weights = tf.fill([nonsensitive_batch_size, 1], 1.0)

  for step, element in enumerate(packed_ds):
    packed_x, _, _ = unpack(element)
    self.assertIsInstance(packed_x, input_utils.MinDiffPackedInputs)

    # Only the weight slot of the min_diff data is validated here.
    _, _, min_diff_weights = unpack(packed_x.min_diff_data)
    expected_weights = _get_min_diff_batch(self.sensitive_w, unit_weights,
                                           sensitive_batch_size,
                                           nonsensitive_batch_size, step)
    self.assertAllClose(min_diff_weights, expected_weights)
def testWithOriginalWeightsNone(self):
  """Elements keep three components even when the original weight is None."""
  to_dataset = tf.data.Dataset.from_tensor_slices

  original_ds = to_dataset(
      (self.original_x, self.original_y, None)).batch(5)
  sensitive_ds = to_dataset(
      (self.sensitive_x, None, self.sensitive_w)).batch(3)
  nonsensitive_ds = to_dataset(
      (self.nonsensitive_x, None, self.nonsensitive_w)).batch(1)

  packed_ds = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                             nonsensitive_ds)

  for element in packed_ds:
    # A missing sample_weight must not shrink the element: it still has
    # three entries, the last of which unpacks to None.
    self.assertLen(element, 3)
    _, _, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(element)
    self.assertIsNone(sample_weight)
def testDifferentMinDiffAndOriginalStructuresRaisesError(self):
  """Dict-valued group datasets whose x differs from the original x raise.

  The original dataset uses only the `f1` feature as x, while every group
  dataset uses the full feature structure.
  """
  # NOTE(review): a later method in this file has exactly the same name. If
  # both live in the same class, the later definition replaces this one and
  # this test never runs -- confirm and rename one of them if so.
  to_dataset = tf.data.Dataset.from_tensor_slices
  group_keys = ["k1", "k2"]

  original_ds = to_dataset((self.original_x["f1"], None, None)).batch(5)

  sensitive_ds = {}
  for key, size in zip(group_keys, [3, 5]):
    sensitive_ds[key] = to_dataset(
        (self.sensitive_x, None, None)).batch(size)

  nonsensitive_ds = {}
  for key, size in zip(group_keys, [1, 2]):
    nonsensitive_ds[key] = to_dataset(
        (self.nonsensitive_x, None, None)).batch(size)

  expected_message = (
      "x component structure.*min_diff_dataset.*does not match.*"
      "original x structure(.|\n)*don't have the same nested structure")
  with self.assertRaisesRegex(ValueError, expected_message):
    _ = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                       nonsensitive_ds)
def testBadDatasetCombinationRaisesError(self):
  """Rejects every unsupported combination of the optional datasets.

  Two combinations are accepted: `min_diff_dataset` alone, or both group
  datasets together. For every other combination a ValueError is expected
  whose message names each argument followed by `Dataset` or `None`.
  """
  # The contents are irrelevant; only which arguments are supplied matters.
  ds = tf.data.Dataset.from_tensor_slices(
      (self.original_x, self.original_y)).batch(5)
  pack = input_utils.pack_min_diff_data

  # The two supported combinations must not raise.
  _ = pack(ds, min_diff_dataset=ds)
  _ = pack(ds, sensitive_group_dataset=ds, nonsensitive_group_dataset=ds)

  # Nothing besides the original dataset.
  nothing_given = ("You must either.*or.*\n\n.*sensitive_group_dataset.*"
                   "None.*\n.*nonsensitive_group_dataset.*None.*\n"
                   ".*min_diff_dataset.*None")
  with self.assertRaisesRegex(ValueError, nothing_given):
    _ = pack(ds)

  # All three optional datasets at once (passed positionally).
  everything_given = ("You must either.*or.*\n\n.*sensitive_group_dataset.*"
                      "Dataset.*\n.*nonsensitive_group_dataset.*Dataset.*\n"
                      ".*min_diff_dataset.*Dataset")
  with self.assertRaisesRegex(ValueError, everything_given):
    _ = pack(ds, ds, ds, ds)

  # Sensitive group only.
  sensitive_only = ("You must either.*or.*\n\n.*sensitive_group_dataset.*"
                    "Dataset.*\n.*nonsensitive_group_dataset.*None.*\n"
                    ".*min_diff_dataset.*None")
  with self.assertRaisesRegex(ValueError, sensitive_only):
    _ = pack(ds, sensitive_group_dataset=ds)

  # Nonsensitive group only.
  nonsensitive_only = ("You must either.*or.*\n\n.*sensitive_group_dataset.*"
                       "None.*\n.*nonsensitive_group_dataset.*Dataset.*\n"
                       ".*min_diff_dataset.*None")
  with self.assertRaisesRegex(ValueError, nonsensitive_only):
    _ = pack(ds, nonsensitive_group_dataset=ds)

  # Sensitive group plus min_diff_dataset (nonsensitive group missing).
  nonsensitive_missing = (
      "You must either.*or.*\n\n.*sensitive_group_dataset.*"
      "Dataset.*\n.*nonsensitive_group_dataset.*None.*\n"
      ".*min_diff_dataset.*Dataset")
  with self.assertRaisesRegex(ValueError, nonsensitive_missing):
    _ = pack(ds, sensitive_group_dataset=ds, min_diff_dataset=ds)

  # Nonsensitive group plus min_diff_dataset (sensitive group missing).
  sensitive_missing = (
      "You must either.*or.*\n\n.*sensitive_group_dataset.*"
      "None.*\n.*nonsensitive_group_dataset.*Dataset.*\n"
      ".*min_diff_dataset.*Dataset")
  with self.assertRaisesRegex(ValueError, sensitive_missing):
    _ = pack(ds, nonsensitive_group_dataset=ds, min_diff_dataset=ds)
def testWithXAsTensor(self):
  """Packs datasets whose x is the single `f1` feature and checks each batch.

  Verifies that the original (x, y, w) batch passes through unchanged and
  that the min_diff x, membership and weights match the values expected from
  the sensitive and nonsensitive fixtures.
  """
  to_dataset = tf.data.Dataset.from_tensor_slices
  original_batch_size = 5
  sensitive_batch_size = 3
  nonsensitive_batch_size = 2

  # Use one feature on its own instead of the full feature structure.
  original_f1 = self.original_x["f1"]
  sensitive_f1 = self.sensitive_x["f1"]
  nonsensitive_f1 = self.nonsensitive_x["f1"]

  original_ds = to_dataset(
      (original_f1, self.original_y, self.original_w)
  ).batch(original_batch_size)
  sensitive_ds = to_dataset(
      (sensitive_f1, None, self.sensitive_w)).batch(sensitive_batch_size)
  nonsensitive_ds = to_dataset(
      (nonsensitive_f1, None, self.nonsensitive_w)
  ).batch(nonsensitive_batch_size)

  packed_ds = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                             nonsensitive_ds)

  for step, element in enumerate(packed_ds):
    packed_x, labels, weights = element
    self.assertIsInstance(packed_x, input_utils.MinDiffPackedInputs)

    # The original batch must come through untouched.
    self.assertAllClose(packed_x.original_inputs,
                        _get_batch(original_f1, original_batch_size, step))
    self.assertAllClose(
        labels, _get_batch(self.original_y, original_batch_size, step))
    self.assertAllClose(
        weights, _get_batch(self.original_w, original_batch_size, step))

    # The min_diff batch is checked against the two group fixtures.
    md_x, md_membership, md_weights = packed_x.min_diff_data
    expected_x = _get_min_diff_batch(sensitive_f1, nonsensitive_f1,
                                     sensitive_batch_size,
                                     nonsensitive_batch_size, step)
    expected_membership = _get_min_diff_membership_batch(
        sensitive_batch_size, nonsensitive_batch_size)
    expected_weights = _get_min_diff_batch(self.sensitive_w,
                                           self.nonsensitive_w,
                                           sensitive_batch_size,
                                           nonsensitive_batch_size, step)
    self.assertAllClose(md_x, expected_x)
    self.assertAllClose(md_membership, expected_membership)
    self.assertAllClose(md_weights, expected_weights)
def testInvalidStructureRaisesError(self):
  """Rejects badly nested or mismatched dataset structures.

  A flat dict of datasets is accepted for every optional argument. A dict
  containing a list of datasets, or group dicts with different key sets, is
  expected to raise a ValueError.
  """
  # The contents are irrelevant; only the container structure matters.
  ds = tf.data.Dataset.from_tensor_slices(
      (self.original_x, self.original_y)).batch(5)
  pack = input_utils.pack_min_diff_data

  valid_nesting = {"a": ds, "b": ds}
  invalid_nesting = {"a": ds, "b": [ds]}
  other_keys_nesting = {"a": ds, "c": ds}

  # Well-formed nested structures must not raise.
  _ = pack(ds, min_diff_dataset=valid_nesting)
  _ = pack(
      ds,
      sensitive_group_dataset=valid_nesting,
      nonsensitive_group_dataset=valid_nesting)

  # Bad nesting in min_diff_dataset.
  bad_min_diff = ("min_diff_dataset.*unnested.*only "
                  "elements of type.*Dataset.*Given")
  with self.assertRaisesRegex(ValueError, bad_min_diff):
    _ = pack(ds, min_diff_dataset=invalid_nesting)

  # Bad nesting in sensitive_group_dataset.
  bad_sensitive = ("sensitive_group_dataset.*unnested"
                   ".*only elements of type.*Dataset.*Given")
  with self.assertRaisesRegex(ValueError, bad_sensitive):
    _ = pack(
        ds,
        sensitive_group_dataset=invalid_nesting,
        nonsensitive_group_dataset=valid_nesting)

  # Bad nesting in nonsensitive_group_dataset.
  bad_nonsensitive = (
      "nonsensitive_group_dataset.*unnested.*only elements of "
      "type.*Dataset.*Given")
  with self.assertRaisesRegex(ValueError, bad_nonsensitive):
    _ = pack(
        ds,
        sensitive_group_dataset=valid_nesting,
        nonsensitive_group_dataset=invalid_nesting)

  # The two group structures are each valid but have different keys.
  mismatched_groups = (
      "sensitive_group_dataset.*"
      "nonsensitive_group_dataset.*do "
      "not have the same structure(.|\n)*don't have the same set of keys")
  with self.assertRaisesRegex(ValueError, mismatched_groups):
    _ = pack(
        ds,
        sensitive_group_dataset=valid_nesting,
        nonsensitive_group_dataset=other_keys_nesting)
def testDifferentMinDiffAndOriginalStructuresRaisesError(self):
  """Raises when the original x structure differs from the group x structure.

  The original dataset uses only the `f1` feature as x, while both group
  datasets use the full feature structure.
  """
  # NOTE(review): an earlier method in this file has exactly the same name.
  # If both live in the same class, this definition replaces the earlier one
  # and that test never runs -- confirm and rename one of them if so.
  to_dataset = tf.data.Dataset.from_tensor_slices

  original_ds = to_dataset((self.original_x["f1"], None, None)).batch(5)
  sensitive_ds = to_dataset((self.sensitive_x, None, None)).batch(3)
  nonsensitive_ds = to_dataset((self.nonsensitive_x, None, None)).batch(2)

  expected_message = "don't have the same nested structure"
  with self.assertRaisesRegex(ValueError, expected_message):
    _ = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                       nonsensitive_ds)
def testDifferentSensitiveAndNonsensitivetructuresRaisesError(self):
  """Raises when the two group datasets disagree on the structure of x.

  The sensitive dataset uses the full feature structure while the
  nonsensitive dataset uses only the `f1` feature.
  """
  to_dataset = tf.data.Dataset.from_tensor_slices

  original_ds = to_dataset((self.original_x, None, None)).batch(5)
  sensitive_ds = to_dataset((self.sensitive_x, None, None)).batch(3)
  nonsensitive_ds = to_dataset(
      (self.nonsensitive_x["f1"], None, None)).batch(2)

  expected_message = (
      "x component structure.*sensitive_group_dataset.*does not "
      "match.*nonsensitive_group_dataset(.|\n)*don't have the same nested "
      "structure")
  with self.assertRaisesRegex(ValueError, expected_message):
    _ = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                       nonsensitive_ds)
def testDifferentWeightsShapeRaisesError(self):
  """Raises when the sensitive weights carry an extra axis."""
  to_dataset = tf.data.Dataset.from_tensor_slices

  # Insert a new axis so the sensitive weights no longer share a rank with
  # the nonsensitive weights.
  reshaped_sensitive_w = self.sensitive_w[:, tf.newaxis, :]

  original_ds = to_dataset((self.original_x, None, None)).batch(5)
  sensitive_ds = to_dataset(
      (self.sensitive_x, None, reshaped_sensitive_w)).batch(3)
  nonsensitive_ds = to_dataset(
      (self.nonsensitive_x, None, self.nonsensitive_w)).batch(2)

  expected_message = "must be rank.*but is rank"
  with self.assertRaisesRegex(ValueError, expected_message):
    _ = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                       nonsensitive_ds)
def testWithoutOriginalLabels(self):
  """An x-only original dataset yields bare MinDiffPackedInputs elements."""
  to_dataset = tf.data.Dataset.from_tensor_slices

  # The original dataset holds features only: not even a (x, y, w) tuple.
  original_ds = to_dataset(self.original_x).batch(5)
  sensitive_ds = to_dataset((self.sensitive_x, None, None)).batch(3)
  nonsensitive_ds = to_dataset((self.nonsensitive_x, None, None)).batch(1)

  packed_ds = input_utils.pack_min_diff_data(original_ds, sensitive_ds,
                                             nonsensitive_ds)

  for element in packed_ds:
    # Each element must be the packed inputs object itself, not a tuple
    # wrapping it, and must unpack to a None sample weight.
    self.assertIsInstance(element, input_utils.MinDiffPackedInputs)
    _, _, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(element)
    self.assertIsNone(sample_weight)
def testPackDictsOfDatasets(self):
  """Packs a min_diff dataset built from dicts of group datasets.

  The group datasets are dicts keyed `k1` and `k2`, each with its own batch
  size. They are combined with `build_min_diff_dataset` and handed to
  `pack_min_diff_data` as `min_diff_dataset`. Every packed batch is checked
  for an unchanged original batch and a correctly formed min_diff batch per
  key.
  """
  to_dataset = tf.data.Dataset.from_tensor_slices
  original_batch_size = 5
  group_keys = ["k1", "k2"]
  sensitive_batch_sizes = [3, 5]
  nonsensitive_batch_sizes = [1, 2]

  original_ds = to_dataset(
      (self.original_x, self.original_y, self.original_w)
  ).batch(original_batch_size)

  sensitive_ds = {}
  for key, size in zip(group_keys, sensitive_batch_sizes):
    sensitive_ds[key] = to_dataset(
        (self.sensitive_x, None, self.sensitive_w)).batch(size)

  nonsensitive_ds = {}
  for key, size in zip(group_keys, nonsensitive_batch_sizes):
    nonsensitive_ds[key] = to_dataset(
        (self.nonsensitive_x, None, self.nonsensitive_w)).batch(size)

  min_diff_ds = input_utils.build_min_diff_dataset(sensitive_ds,
                                                   nonsensitive_ds)
  packed_ds = input_utils.pack_min_diff_data(
      original_ds, min_diff_dataset=min_diff_ds)

  for step, element in enumerate(packed_ds):
    packed_x, labels, weights = element
    self.assertIsInstance(packed_x, input_utils.MinDiffPackedInputs)

    # The original batch must come through untouched.
    self.assertAllClose(
        packed_x.original_inputs,
        _get_batch(self.original_x, original_batch_size, step))
    self.assertAllClose(
        labels, _get_batch(self.original_y, original_batch_size, step))
    self.assertAllClose(
        weights, _get_batch(self.original_w, original_batch_size, step))

    # The min_diff data must expose exactly the expected set of keys.
    found_keys = sorted(packed_x.min_diff_data.keys())
    self.assertAllEqual(found_keys, ["k1", "k2"])

    # Walk the keys in sorted order alongside their batch sizes.
    per_key = zip(found_keys, sensitive_batch_sizes,
                  nonsensitive_batch_sizes)
    for key, sensitive_size, nonsensitive_size in per_key:
      md_x, md_membership, md_weights = packed_x.min_diff_data[key]

      expected_x = _get_min_diff_batch(self.sensitive_x,
                                       self.nonsensitive_x, sensitive_size,
                                       nonsensitive_size, step)
      expected_membership = _get_min_diff_membership_batch(
          sensitive_size, nonsensitive_size)
      expected_weights = _get_min_diff_batch(self.sensitive_w,
                                             self.nonsensitive_w,
                                             sensitive_size,
                                             nonsensitive_size, step)
      self.assertAllClose(md_x, expected_x)
      self.assertAllClose(md_membership, expected_membership)
      self.assertAllClose(md_weights, expected_weights)