Code Example #1
def test_metric(self):
  ground_truth_data = dummy_data.IdentityObservationsData()
  representation_function = lambda x: np.array(x, dtype=np.float64)
  random_state = np.random.RandomState(0)
  scores = dci.compute_dci(ground_truth_data, representation_function,
                           random_state, None, 1000, 1000)
  self.assertBetween(scores["disentanglement"], 0.9, 1.0)
  self.assertBetween(scores["completeness"], 0.9, 1.0)
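Note that the three trailing positional arguments are artifact_dir, num_train, and num_test. Assuming the disentanglement_lib signature for dci.compute_dci (the parameter names are an assumption to verify against your version, though they match the keyword usage in Code Example #4 below), an equivalent keyword form reads:

  # Parameter names assumed from disentanglement_lib's dci.compute_dci;
  # verify against the version you are using.
  scores = dci.compute_dci(ground_truth_data, representation_function,
                           random_state, artifact_dir=None,
                           num_train=1000, num_test=1000)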
Code Example #2
def test_duplicated_latent_space(self):
  ground_truth_data = dummy_data.IdentityObservationsData()
  def representation_function(x):
    x = np.array(x, dtype=np.float64)
    return np.hstack([x, x])
  random_state = np.random.RandomState(0)
  scores = dci.compute_dci(
      ground_truth_data, representation_function, random_state, None, 1000,
      1000)
  self.assertBetween(scores["disentanglement"], 0.9, 1.0)
  target = 1. - np.log(2) / np.log(10)
  self.assertBetween(scores["completeness"], target - .1, target + .1)
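The target value is plain arithmetic over DCI's completeness definition: duplicating the latent space splits each factor's importance evenly over two codes, so the per-factor importance entropy is log 2, which the test normalizes by log 10. A quick check of the arithmetic (a sketch only; the exact normalization constant depends on the dci implementation, and the 0.1 tolerance also covers normalizing by the number of codes instead):

  import numpy as np

  # Entropy of importance split evenly over two codes, normalized by log(10).
  target = 1. - np.log(2) / np.log(10)
  print(round(target, 3))  # 0.699 -> the assertion accepts [0.599, 0.799]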
Code Example #3
def test_bad_metric(self):
  ground_truth_data = dummy_data.IdentityObservationsData()
  random_state_rep = np.random.RandomState(0)
  # A representation that randomly permutes the factors within each sample
  # spreads importance evenly across codes, which should yield low
  # disentanglement and completeness scores.
  def representation_function(x):
    code = np.array(x, dtype=np.float64)
    for i in range(code.shape[0]):
      code[i, :] = random_state_rep.permutation(code[i, :])
    return code
  random_state = np.random.RandomState(0)
  scores = dci.compute_dci(
      ground_truth_data, representation_function, random_state, None, 1000,
      1000)
  self.assertBetween(scores["disentanglement"], 0.0, 0.2)
  self.assertBetween(scores["completeness"], 0.0, 0.2)
Code Example #4
File: evaluate.py  Project: ThomasMrY/DisCo
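# Assumed context (not shown in this excerpt): module-level imports of
# numpy as np, torch, and gin, plus a project helper write_text() that
# records each metric under net.model_name + txt_name.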
def evaluate(net,
             dataset=None,
             beta_VAE_score=False,
             dci_score=False,
             factor_VAE_score=False,
             MIG=False,
             print_txt=False,
             txt_name="metric.json"):
    def _representation(x):
        # Convert an NHWC numpy batch to an NCHW float tensor on the GPU,
        # encode it, and return the codes as a numpy array.
        x = torch.from_numpy(x).float().cuda()
        x = x.permute(0, 3, 1, 2)
        z = net(x.contiguous()).squeeze()
        return z.detach().cpu().numpy()

    if beta_VAE_score:
        with gin.unlock_config():
            from evaluation.metrics.beta_vae import compute_beta_vae_sklearn
            gin.bind_parameter("beta_vae_sklearn.batch_size", 64)
            gin.bind_parameter("beta_vae_sklearn.num_train", 10000)
            gin.bind_parameter("beta_vae_sklearn.num_eval", 5000)
        result_dict = compute_beta_vae_sklearn(
            dataset,
            _representation,
            random_state=np.random.RandomState(0),
            artifact_dir=None)
        print("beta VAE score:" + str(result_dict))
        write_text("beta_VAE_score", result_dict, print_txt,
                   net.model_name + txt_name)
        gin.clear_config()
    if dci_score:
        from evaluation.metrics.dci import compute_dci
        with gin.unlock_config():
            gin.bind_parameter("dci.num_train", 10000)
            gin.bind_parameter("dci.num_test", 5000)
        result_dict = compute_dci(dataset,
                                  _representation,
                                  random_state=np.random.RandomState(0),
                                  artifact_dir=None)
        print("dci score:" + str(result_dict))
        write_text("dci_score", result_dict, print_txt,
                   net.model_name + txt_name)
        gin.clear_config()
    if factor_VAE_score:
        with gin.unlock_config():
            from evaluation.metrics.factor_vae import compute_factor_vae
            gin.bind_parameter("factor_vae_score.num_variance_estimate", 10000)
            gin.bind_parameter("factor_vae_score.num_train", 10000)
            gin.bind_parameter("factor_vae_score.num_eval", 5000)
            gin.bind_parameter("factor_vae_score.batch_size", 64)
            gin.bind_parameter("prune_dims.threshold", 0.05)
        result_dict = compute_factor_vae(dataset,
                                         _representation,
                                         random_state=np.random.RandomState(0),
                                         artifact_dir=None)
        print("factor VAE score:" + str(result_dict))
        write_text("factor_VAE_score", result_dict, print_txt,
                   net.model_name + txt_name)
        gin.clear_config()
    if MIG:
        with gin.unlock_config():
            from evaluation.metrics.mig import compute_mig
            from evaluation.metrics.utils import _histogram_discretize
            gin.bind_parameter("mig.num_train", 10000)
            gin.bind_parameter("discretizer.discretizer_fn",
                               _histogram_discretize)
            gin.bind_parameter("discretizer.num_bins", 20)
        result_dict = compute_mig(dataset,
                                  _representation,
                                  random_state=np.random.RandomState(0),
                                  artifact_dir=None)
        print("MIG score:" + str(result_dict))
        write_text("MIG", result_dict, print_txt, net.model_name + txt_name)
        gin.clear_config()
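A minimal driver sketch for the function above, assuming a trained encoder that exposes model_name and a disentanglement_lib-style ground-truth dataset; load_encoder and get_named_ground_truth_data are illustrative placeholders here, not confirmed DisCo APIs:

    # Hypothetical usage: both helpers below are placeholders.
    net = load_encoder("checkpoints/disco.pt")  # must expose .model_name
    net = net.cuda().eval()
    dataset = get_named_ground_truth_data("dsprites_full")
    evaluate(net, dataset=dataset, dci_score=True, MIG=True, print_txt=True)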