def test_tfdata(self):
    """Smoke-tests the tf.data wrapper around DummyData.

    Builds a dataset with `util.tf_data_set_from_ground_truth_data`, then
    evaluates the iterator's next-element op ten times inside a test session.
    """
    source = dummy_data.DummyData()
    # NOTE(review): the second argument (0) is presumably a random seed;
    # confirm against util.tf_data_set_from_ground_truth_data.
    tf_dataset = util.tf_data_set_from_ground_truth_data(source, 0)
    fetch_op = tf_dataset.make_one_shot_iterator().get_next()
    with self.test_session() as session:
        for _ in range(10):
            session.run(fetch_op)
def test_metric(self):
    """Smoke-tests fairness.compute_fairness on DummyData.

    Binds gin parameter `predictor.predictor_fn` to
    `utils.gradient_boosting_classifier`, then runs the metric with a fixed
    RandomState(0); the result is discarded (only checks that it runs).
    """
    gin.bind_parameter("predictor.predictor_fn",
                       utils.gradient_boosting_classifier)
    data = dummy_data.DummyData()

    def _represent(observations):
        # Casts to float64 and keeps the [:, :, 0, 0] slice, so the input
        # must be at least 4-D.
        as_float = np.array(observations, dtype=np.float64)
        return as_float[:, :, 0, 0]

    rng = np.random.RandomState(0)
    # NOTE(review): positional args after rng are (None, 1000, 1000); their
    # meaning is defined by fairness.compute_fairness — confirm there.
    _ = fairness.compute_fairness(data, _represent, rng, None, 1000, 1000)
def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.

    Returns:
      A freshly constructed ground-truth data object for `name`, e.g. a
      `dsprites.DSprites` instance for "dsprites_full".

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Every entry is a zero-argument lambda so that attribute lookups and
    # constructors run only for the requested data set, exactly as the
    # branches of an if/elif chain would.
    factories = {
        "dsprites_full": lambda: dsprites.DSprites([1, 2, 3, 4, 5]),
        "dsprites_noshape": lambda: dsprites.DSprites([2, 3, 4, 5]),
        "color_dsprites": lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5]),
        "noisy_dsprites": lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5]),
        "scream_dsprites": lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5]),
        "smallnorb": lambda: norb.SmallNORB(),
        "cars3d": lambda: cars3d.Cars3D(),
        "mpi3d_toy": lambda: mpi3d.MPI3D(mode="mpi3d_toy"),
        "mpi3d_realistic": lambda: mpi3d.MPI3D(mode="mpi3d_realistic"),
        "mpi3d_real": lambda: mpi3d.MPI3D(mode="mpi3d_real"),
        "shapes3d": lambda: shapes3d.Shapes3D(),
        "dummy_data": lambda: dummy_data.DummyData(),
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which previously also fell
        # through every comparison to the ValueError below.
        raise ValueError(
            "Invalid data set name: {!r}. Valid names are: {}.".format(
                name, ", ".join(sorted(factories))))
    return factory()