コード例 #1
0
ファイル: udr_test.py プロジェクト: yding5/yukun_dlib
    def test_metric_kl(self):
        """Checks that UDR scores react to filtering latents by the KL mask.

        Both models share one representation function: its first latent is
        taken from the (flattened) data, the remaining ``num_factors - 1``
        latents are random noise, and the returned KL mask flags only the
        first latent. The scores are asserted to lie in [0.0, 0.2] with
        ``filter_low_kl=False`` and in [0.8, 1.0] with ``filter_low_kl=True``.
        """
        ground_truth_data = dummy_data.DummyData()
        random_state = np.random.RandomState(0)
        # Dedicated, seeded generator for the noise latents: keeps the test
        # reproducible instead of depending on NumPy's global random state,
        # and leaves the stream handed to compute_udr_sklearn untouched.
        noise_state = np.random.RandomState(1)
        num_factors = ground_truth_data.num_factors
        batch_size = 10
        num_data_points = 1000

        # Representation where only the first latent carries data; the KL mask
        # returned alongside it marks that latent as the only valid one.
        def rep_fn(data):
            informative = np.reshape(data, (batch_size, -1))[:, :1]
            noise = noise_state.random_sample((batch_size, num_factors - 1))
            rep = np.concatenate([informative, noise], axis=1)
            kl_mask = np.zeros(num_factors)
            kl_mask[0] = 1.0
            return rep, kl_mask

        # Without filtering, the noise latents are part of the comparison.
        scores = udr.compute_udr_sklearn(ground_truth_data, [rep_fn, rep_fn],
                                         random_state,
                                         batch_size,
                                         num_data_points,
                                         filter_low_kl=False)
        self.assertBetween(scores["model_scores"][0], 0.0, 0.2)
        self.assertBetween(scores["model_scores"][1], 0.0, 0.2)

        # With filtering enabled, the KL mask is taken into account.
        scores = udr.compute_udr_sklearn(ground_truth_data, [rep_fn, rep_fn],
                                         random_state,
                                         batch_size,
                                         num_data_points,
                                         filter_low_kl=True)
        self.assertBetween(scores["model_scores"][0], 0.8, 1.0)
        self.assertBetween(scores["model_scores"][1], 0.8, 1.0)
コード例 #2
0
ファイル: udr_test.py プロジェクト: yding5/yukun_dlib
    def test_metric_lasso(self):
        """Checks UDR with the lasso correlation matrix on equivalent models.

        The second model's representation is the first model's representation
        with its columns permuted and some of them negated. Both model scores
        are asserted to lie in [0.8, 1.0].
        """
        ground_truth_data = dummy_data.DummyData()
        random_state = np.random.RandomState(0)
        num_factors = ground_truth_data.num_factors
        batch_size = 10
        num_data_points = 1000

        # Draw the permutation and the sign flips from a dedicated, seeded
        # generator so the test is reproducible and independent of NumPy's
        # global random state. A separate generator is used so the stream
        # handed to compute_udr_sklearn stays the same as before.
        setup_state = np.random.RandomState(1)
        permutation = setup_state.permutation(num_factors)
        # Sampled with replacement; a repeated index is still negated only
        # once because the assignment below reads all values before writing.
        sign_inverse = setup_state.choice(num_factors, num_factors // 2)

        def rep_fn1(data):
            return (np.reshape(data, (batch_size, -1))[:, :num_factors],
                    np.ones(num_factors))

        # Should be invariant to permutation and sign inverse.
        def rep_fn2(data):
            raw_representation = np.reshape(data,
                                            (batch_size, -1))[:, :num_factors]
            # Fancy indexing copies, so the in-place negation below does not
            # touch the input data.
            perm_rep = raw_representation[:, permutation]
            perm_rep[:, sign_inverse] = -1.0 * perm_rep[:, sign_inverse]
            return perm_rep, np.ones(num_factors)

        scores = udr.compute_udr_sklearn(ground_truth_data, [rep_fn1, rep_fn2],
                                         random_state,
                                         batch_size,
                                         num_data_points,
                                         correlation_matrix="lasso")
        self.assertBetween(scores["model_scores"][0], 0.8, 1.0)
        self.assertBetween(scores["model_scores"][1], 0.8, 1.0)
コード例 #3
0
 def test_tfdata(self):
   """Smoke-tests util.tf_data_set_from_ground_truth_data on the dummy data.

   Creates a one-shot iterator over the returned dataset and evaluates its
   next element ten times inside a test session.
   """
   source = dummy_data.DummyData()
   tf_dataset = util.tf_data_set_from_ground_truth_data(source, 0)
   fetch_op = tf_dataset.make_one_shot_iterator().get_next()
   with self.test_session() as session:
     # Only successful evaluation is checked; the values are not inspected.
     for _ in range(10):
       session.run(fetch_op)
コード例 #4
0
 def test_metric(self):
   """Smoke-tests fairness.compute_fairness on the dummy data set.

   Binds the gradient boosting classifier as the gin predictor function and
   runs the computation; the result itself is discarded.
   """
   def extract_features(observations):
     # Convert to float64, then keep index 0 along the last two axes.
     as_float = np.array(observations, dtype=np.float64)
     return as_float[:, :, 0, 0]

   gin.bind_parameter("predictor.predictor_fn",
                      utils.gradient_boosting_classifier)
   dataset = dummy_data.DummyData()
   rng = np.random.RandomState(0)
   fairness.compute_fairness(dataset, extract_features, rng, None, 1000, 1000)
コード例 #5
0
def get_named_ground_truth_data(name):
  """Builds the ground truth data set registered under `name`.

  Args:
    name: String identifying the data set to construct.

  Returns:
    A newly constructed data set object for the given name.

  Raises:
    ValueError: if `name` does not match any known data set.
  """
  # Ordered (name, factory) pairs. The factories are lazy so that only the
  # requested data set is constructed.
  factories = (
      ("dsprites_full", lambda: dsprites.DSprites([1, 2, 3, 4, 5])),
      ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
      ("color_dsprites", lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5])),
      ("noisy_dsprites", lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5])),
      ("scream_dsprites", lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5])),
      ("smallnorb", lambda: norb.SmallNORB()),
      ("cars3d", lambda: cars3d.Cars3D()),
      ("mpi3d_toy", lambda: mpi3d.MPI3D(mode="mpi3d_toy")),
      ("mpi3d_realistic", lambda: mpi3d.MPI3D(mode="mpi3d_realistic")),
      ("mpi3d_real", lambda: mpi3d.MPI3D(mode="mpi3d_real")),
      ("shapes3d", lambda: shapes3d.Shapes3D()),
      ("dummy_data", lambda: dummy_data.DummyData()),
      ("faces", lambda: faces.Faces()),
      ("celeba", lambda: celeba.CelebA(
          celeba_path="/hdd/dsvae/img_align_celeba", num_samples=100000)),
      ("celebaHR", lambda: celebaHR.CelebAHR(
          celeba_path="/hdd/dsvae/img_align_celeba", num_samples=10000)),
      ("chairs", lambda: chairs.Chairs()),
  )

  # Compare with `==` in declaration order rather than hashing the name.
  for known_name, build in factories:
    if name == known_name:
      return build()

  raise ValueError("Invalid data set name.")