Example #1
def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

  Args:
    name: String with the name of the dataset.

  Raises:
    ValueError: if an invalid data set name is provided.
  """

    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    elif name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    elif name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    elif name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    elif name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    elif name == "smallnorb":
        return norb.SmallNORB()
    elif name == "cars3d":
        return cars3d.Cars3D()
    elif name == "shapes3d":
        return shapes3d.Shapes3D()
    elif name == "dummy_data":
        return dummy_data.DummyData()
    else:
        raise ValueError("Invalid data set name.")
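
A quick usage sketch (not part of the original example): any of the names handled above can be resolved and then sampled through the GroundTruthData interface (num_factors, sample_factors) that the surrounding tests rely on.

import numpy as np

# Hypothetical caller; "dummy_data" is one of the supported names above.
data = get_named_ground_truth_data("dummy_data")
random_state = np.random.RandomState(0)
# One row per sample, one column per ground-truth factor.
factors = data.sample_factors(100, random_state)
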
    def test_metric_lasso(self):
        ground_truth_data = dummy_data.DummyData()
        random_state = np.random.RandomState(0)
        num_factors = ground_truth_data.num_factors
        batch_size = 10
        num_data_points = 1000

        permutation = np.random.permutation(num_factors)
        sign_inverse = np.random.choice(num_factors, int(num_factors / 2))

        def rep_fn1(data):
            return (np.reshape(data, (batch_size, -1))[:, :num_factors],
                    np.ones(num_factors))

        # Should be invariant to permutation and sign inverse.
        def rep_fn2(data):
            raw_representation = np.reshape(data,
                                            (batch_size, -1))[:, :num_factors]
            perm_rep = raw_representation[:, permutation]
            perm_rep[:, sign_inverse] = -1.0 * perm_rep[:, sign_inverse]
            return perm_rep, np.ones(num_factors)

        scores = udr.compute_udr_sklearn(ground_truth_data, [rep_fn1, rep_fn2],
                                         random_state,
                                         batch_size,
                                         num_data_points,
                                         correlation_matrix="lasso")
        self.assertBetween(scores["model_scores"][0], 0.8, 1.0)
        self.assertBetween(scores["model_scores"][1], 0.8, 1.0)

    def test_metric_kl(self):
        ground_truth_data = dummy_data.DummyData()
        random_state = np.random.RandomState(0)
        num_factors = ground_truth_data.num_factors
        batch_size = 10
        num_data_points = 1000

        # Representation without KL Mask where only first latent is valid.
        def rep_fn(data):
            rep = np.concatenate([
                np.reshape(data, (batch_size, -1))[:, :1],
                np.random.random_sample((batch_size, num_factors - 1))
            ],
                                 axis=1)
            kl_mask = np.zeros(num_factors)
            kl_mask[0] = 1.0
            return rep, kl_mask

        scores = udr.compute_udr_sklearn(ground_truth_data, [rep_fn, rep_fn],
                                         random_state,
                                         batch_size,
                                         num_data_points,
                                         filter_low_kl=False)
        self.assertBetween(scores["model_scores"][0], 0.0, 0.2)
        self.assertBetween(scores["model_scores"][1], 0.0, 0.2)

        scores = udr.compute_udr_sklearn(ground_truth_data, [rep_fn, rep_fn],
                                         random_state,
                                         batch_size,
                                         num_data_points,
                                         filter_low_kl=True)
        self.assertBetween(scores["model_scores"][0], 0.8, 1.0)
        self.assertBetween(scores["model_scores"][1], 0.8, 1.0)
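
For reference, a standalone call sketch outside the test harness (my own sketch, reusing the positional argument order from the tests above; my_rep_fn is a hypothetical representation function that follows the same (codes, KL-mask) return contract as rep_fn1):

def my_rep_fn(data, batch_size=10, num_factors=10):
    # Flatten each observation, keep the first num_factors dimensions, and
    # return an all-ones KL mask so every dimension is kept.
    return np.reshape(data, (batch_size, -1))[:, :num_factors], np.ones(num_factors)

scores = udr.compute_udr_sklearn(
    dummy_data.DummyData(), [my_rep_fn, my_rep_fn],
    np.random.RandomState(0), 10, 1000)
print(scores["model_scores"])  # one UDR score per representation function
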
Example #4
    def test_perfect_labeller(self, labels, target):

        ground_truth_data = dummy_data.DummyData()
        processed_labels, _ = semi_supervised_utils.perfect_labeller(
            labels, ground_truth_data, np.random.RandomState(0))
        test_value = np.sum(np.abs(processed_labels - labels))
        self.assertEqual(test_value, target)

    def test_tfdata(self):
        ground_truth_data = dummy_data.DummyData()
        dataset = util.tf_data_set_from_ground_truth_data(ground_truth_data, 0)
        one_shot_iterator = dataset.make_one_shot_iterator()
        next_element = one_shot_iterator.get_next()
        with self.test_session() as sess:
            for _ in range(10):
                sess.run(next_element)
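
Note that dataset.make_one_shot_iterator() is the TF1-style call; the test_weak_data example further down uses the compatibility form, which is the safer choice when TF2 behaviour may be enabled (a sketch assuming the same dataset object):

# Equivalent iterator creation via the compat API, as in test_weak_data below.
one_shot_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
next_element = one_shot_iterator.get_next()
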
    def test_metric(self):
        gin.bind_parameter("predictor.predictor_fn",
                           utils.gradient_boosting_classifier)
        ground_truth_data = dummy_data.DummyData()

        def representation_function(x):
            return np.array(x, dtype=np.float64)[:, :, 0, 0]

        random_state = np.random.RandomState(0)
        _ = fairness.compute_fairness(ground_truth_data, representation_function,
                                      random_state, None, 1000, 1000)
Example #7
    def test_noisy_labeller(self, labels, target_low, target_high):

        ground_truth_data = dummy_data.DummyData()
        old_labels = labels.copy()
        processed_labels, _ = semi_supervised_utils.noisy_labeller(
            labels, ground_truth_data, np.random.RandomState(0), 0.1)
        index_equal = (processed_labels - old_labels).flatten()
        test_value = np.count_nonzero(index_equal)
        self.assertBetween(test_value, target_low, target_high)

def get_named_ground_truth_data(name,
                                corr_type='plane',
                                corr_indices=[3, 4],
                                col=None):
    """Returns ground truth data set based on name.

  Args:
    name: String with the name of the dataset.
    corr_type: Type of correlation imposed by the correlated dataset
      variants (defaults to 'plane').
    corr_indices: Indices of the factors to correlate in the correlated
      variants.
    col: Optional background color used by the background-color variants.

  Raises:
    ValueError: if an invalid data set name is provided.
  """

    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    elif name == "correlated_dsprites_full":
        return dsprites.CorrelatedDSprites([1, 2, 3, 4, 5], corr_indices,
                                           corr_type)
    elif name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    elif name == "correlated_dsprites_noshape":
        return dsprites.CorrelatedDSprites([2, 3, 4, 5], corr_indices,
                                           corr_type)
    elif name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    elif name == "backgroundcolor_dsprites":
        return dsprites.BackgroundColorDSprites([1, 2, 3, 4, 5, 6], col)
    elif name == "correlated_backgroundcolor_dsprites":
        return dsprites.CorrelatedBackgroundColorDSprites([1, 2, 3, 4, 5, 6],
                                                          corr_indices,
                                                          corr_type)
    elif name == "correlated_color_dsprites":
        return dsprites.CorrelatedColorDSprites([1, 2, 3, 4, 5], corr_indices,
                                                corr_type)
    elif name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    elif name == "correlated_noisy_dsprites":
        return dsprites.CorrelatedNoisyDSprites([1, 2, 3, 4, 5], corr_indices,
                                                corr_type)
    elif name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    elif name == "correlated_scream_dsprites":
        return dsprites.CorrelatedScreamDSprites([1, 2, 3, 4, 5], corr_indices,
                                                 corr_type)
    elif name == "smallnorb":
        return norb.SmallNORB()
    elif name == "cars3d":
        return cars3d.Cars3D()
    elif name == "shapes3d":
        return shapes3d.Shapes3D()
    elif name == "dummy_data":
        return dummy_data.DummyData()
    else:
        raise ValueError("Invalid data set name: " + name + ".")
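
A usage sketch for the correlated variants (a hypothetical call based only on the signature above): the correlation settings are simply forwarded to the Correlated* dataset classes.

# Correlate factors 3 and 4 of dSprites using the default 'plane' correlation type.
data = get_named_ground_truth_data(
    "correlated_dsprites_full", corr_type="plane", corr_indices=[3, 4])
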
Example #9
    def test_bin_labeller(self, labels, target, num_bins):
        labels = labels.reshape((1, 10))
        target = target.reshape((1, 10))
        ground_truth_data = dummy_data.DummyData()
        processed_labels, _ = semi_supervised_utils.bin_labeller(
            labels,
            ground_truth_data,
            np.random.RandomState(0),
            num_bins=num_bins)
        test_value = np.all(processed_labels == target)
        self.assertEqual(test_value, True)

    def test_weak_data(self):
        ground_truth_data = dummy_data.DummyData()
        binding = ["dynamics.k = 1"]
        gin.parse_config_files_and_bindings([], binding)
        dataset = train_weak_lib.weak_dataset_from_ground_truth_data(
            ground_truth_data, 0)
        one_shot_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        next_element = one_shot_iterator.get_next()
        with self.test_session() as sess:
            elem = sess.run(next_element)
            self.assertEqual(elem[0].shape, (128, 64, 1))
            self.assertEqual(elem[1].shape, (1,))
Example #11
    def test_task(self):
        gin.bind_parameter("predictor.predictor_fn",
                           utils.gradient_boosting_classifier)
        gin.bind_parameter("strong_downstream_task.num_train", [1000])
        gin.bind_parameter("strong_downstream_task.num_test", 1000)
        gin.bind_parameter("strong_downstream_task.n_experiment", 2)
        ground_truth_data = dummy_data.DummyData()

        def representation_function(x):
            return np.array(x, dtype=np.float64)[:, :, 0, 0]

        random_state = np.random.RandomState(0)
        scores = strong_downstream_task.compute_strong_downstream_task(
            ground_truth_data,
            representation_function,
            random_state,
            artifact_dir=None)
        self.assertBetween(scores["1000:mean_strong_test_accuracy"], 0.0, 0.3)
Example #12
    def test_intervene(self):
        ground_truth_data = dummy_data.DummyData()

        random_state = np.random.RandomState(0)

        ys_train = ground_truth_data.sample_factors(1000, random_state)
        ys_test = ground_truth_data.sample_factors(1000, random_state)
        num_factors = ys_train.shape[1]
        for i in range(num_factors):
            (y_train_int, y_test_int, interv_factor,
             factor_interv_train) = strong_downstream_task.intervene(
                 ys_train.copy(), ys_test.copy(), i, num_factors,
                 ground_truth_data)
            assert interv_factor != i, "Wrong factor intervened on."
            assert (y_train_int[:, interv_factor] == factor_interv_train
                    ).all(), "Training set not intervened on."
            assert (y_test_int[:, interv_factor] != factor_interv_train
                    ).all(), "Test set not intervened on."

def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

  Args:
    name: String with the name of the dataset.

  Raises:
    ValueError: if an invalid data set name is provided.
  """

    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    elif name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    elif name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    elif name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    elif name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    elif name == "smallnorb":
        return norb.SmallNORB()
    elif name == "cars3d":
        return cars3d.Cars3D()
    elif name == "mpi3d_toy":
        return mpi3d.MPI3D(mode="mpi3d_toy")
    elif name == "mpi3d_realistic":
        return mpi3d.MPI3D(mode="mpi3d_realistic")
    elif name == "mpi3d_real":
        return mpi3d.MPI3D(mode="mpi3d_real")
    elif name == "shapes3d":
        return shapes3d.Shapes3D()
    elif name == "dummy_data":
        return dummy_data.DummyData()
    elif name == "modelnet":
        return DisLibGroundTruthData(**MODELNET_PARAMS)
    elif name == "arrow":
        return DisLibGroundTruthData(**ARROW_PARAMS)
    elif name == "pixel4":
        return DisLibGroundTruthData(**WRAPPED_PIXEL4_PARAMS)
    elif name == "pixel8":
        return DisLibGroundTruthData(**WRAPPED_PIXEL8_PARAMS)
    else:
        raise ValueError("Invalid data set name.")

def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

  Args:
    name: String with the name of the dataset.

  Raises:
    ValueError: if an invalid data set name is provided.
  """

    if name == "threeDotsCache":
        # a large random sample from ThreeDots
        return threeDots.ThreeDotsTrainingCache()
    elif name == "threeDots":
        return threeDots.ThreeDots()
    elif name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    elif name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    elif name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    elif name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    elif name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    elif name == "smallnorb":
        return norb.SmallNORB()
    elif name == "cars3d":
        return cars3d.Cars3D()
    elif name == "mpi3d_toy":
        return mpi3d.MPI3D(mode="mpi3d_toy")
    elif name == "mpi3d_realistic":
        return mpi3d.MPI3D(mode="mpi3d_realistic")
    elif name == "mpi3d_real":
        return mpi3d.MPI3D(mode="mpi3d_real")
    elif name == "mpi3d_multi_real":
        return mpi3d_multi.MPI3DMulti(mode="mpi3d_real")
    elif name == "shapes3d":
        return shapes3d.Shapes3D()
    elif name == "dummy_data":
        return dummy_data.DummyData()
    else:
        raise ValueError("Invalid data set name.")
Example #15
def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

  Args:
    name: String with the name of the dataset.

  Raises:
    ValueError: if an invalid data set name is provided.
  """
    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    elif name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    elif name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    elif name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    elif name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    elif name == "smallnorb":
        return norb.SmallNORB()
    elif name == "cars3d":
        return cars3d.Cars3D()
    elif name == "mpi3d_toy":
        return mpi3d.MPI3D(mode="mpi3d_toy")
    elif name == "mpi3d_realistic":
        return mpi3d.MPI3D(mode="mpi3d_realistic")
    elif name == "mpi3d_real":
        return mpi3d.MPI3D(mode="mpi3d_real")
    elif name == "3dshapes":
        return shapes3d.Shapes3D()
    elif name == "3dshapes_holdout" or name == "3dshapes_pca_holdout_s5000":
        return shapes3d_partial.Shapes3DPartial(name)
    elif name == "3dshapes_model_all":
        return shapes3d_partial.Shapes3DPartial(name), None
    elif name[:8] == "3dshapes":
        return shapes3d_partial.Shapes3DPartial(
            name + '_train'), shapes3d_partial.Shapes3DPartial(name + '_valid')
    elif name == "dummy_data":
        return dummy_data.DummyData()
    else:
        raise ValueError("Invalid data set name.")
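
Note that this variant is not uniform in what it returns: most names (including "3dshapes_holdout") return a single dataset, "3dshapes_model_all" returns a (dataset, None) pair, and the remaining "3dshapes*" names return a (train, valid) pair. A defensive unpacking sketch (my own suggestion, not from the source):

result = get_named_ground_truth_data(name)
if isinstance(result, tuple):
    train_data, valid_data = result  # e.g. the "3dshapes_*" holdout splits
else:
    train_data, valid_data = result, None
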
Example #16
    def test_semi_supervised_data(self):
        num_labels = 1000
        gin.clear_config()
        gin_bindings = ["labeller.labeller_fn = @perfect_labeller"]
        gin.parse_config_files_and_bindings([], gin_bindings)
        ground_truth_data = dummy_data.DummyData()
        (sampled_observations, sampled_factors,
         _) = semi_supervised_utils.sample_supervised_data(
             0, ground_truth_data, num_labels)
        dataset = train_s2_lib.semi_supervised_dataset_from_ground_truth_data(
            ground_truth_data, num_labels, 0, sampled_observations,
            sampled_factors)
        one_shot_iterator = dataset.make_one_shot_iterator()
        next_element = one_shot_iterator.get_next()
        with self.test_session() as sess:
            for _ in range(1):
                elem = sess.run(next_element)
                self.assertEqual(elem[0].shape, (64, 64, 1))
                self.assertEqual(elem[1][0].shape, (64, 64, 1))
                self.assertLen(elem[1][1], 10)