def test_metric(self):
    """An identity representation must be near-perfectly disentangled and complete."""
    data = dummy_data.IdentityObservationsData()

    def represent(x):
        # The code is the ground-truth factors themselves, cast to float.
        return np.array(x, dtype=np.float64)

    scores = dci.compute_dci(
        data, represent, np.random.RandomState(0), None, 1000, 1000)
    # Both DCI components should be close to their maximum of 1.
    for key in ("disentanglement", "completeness"):
        self.assertBetween(scores[key], 0.9, 1.0)
def test_duplicated_latent_space(self):
    """Duplicating every code dimension lowers completeness but not disentanglement."""
    data = dummy_data.IdentityObservationsData()

    def represent(x):
        factors = np.array(x, dtype=np.float64)
        # Concatenate two identical copies of the factors side by side.
        return np.hstack([factors, factors])

    scores = dci.compute_dci(
        data, represent, np.random.RandomState(0), None, 1000, 1000)
    self.assertBetween(scores["disentanglement"], 0.9, 1.0)
    # With each factor encoded twice, completeness drops to 1 - log(2)/log(10).
    expected = 1. - np.log(2) / np.log(10)
    self.assertBetween(scores["completeness"], expected - .1, expected + .1)
def test_bad_metric(self):
    """A factor-shuffling representation should score low on both DCI components."""
    data = dummy_data.IdentityObservationsData()
    permute_state = np.random.RandomState(0)

    def represent(x):
        # Randomly permute the factors of every sample: each code dimension
        # then carries equal, non-zero importance for all factors, which
        # should drive disentanglement and completeness toward zero.
        code = np.array(x, dtype=np.float64)
        for row in range(code.shape[0]):
            code[row, :] = permute_state.permutation(code[row, :])
        return code

    scores = dci.compute_dci(
        data, represent, np.random.RandomState(0), None, 1000, 1000)
    self.assertBetween(scores["disentanglement"], 0.0, 0.2)
    self.assertBetween(scores["completeness"], 0.0, 0.2)
def evaluate(net, dataset=None, beta_VAE_score=False, dci_score=False,
             factor_VAE_score=False, MIG=False, print_txt=False,
             txt_name="metric.json"):
    """Run the selected disentanglement metrics for `net` on `dataset`.

    For each enabled metric the gin config is unlocked and populated, the
    metric is computed with a fixed seed (RandomState(0)), the result is
    printed and persisted via `write_text` under `net.model_name + txt_name`,
    and finally `gin.clear_config()` resets gin so metrics do not leak
    parameters into one another.

    NOTE(review): assumes `dataset` observations are numpy NHWC batches and
    that a CUDA device is available — confirm against callers.
    """

    def _representation(x):
        # numpy NHWC batch -> torch NCHW tensor on GPU -> flat codes (numpy).
        batch = torch.from_numpy(x).float().cuda()
        batch = batch.permute(0, 3, 1, 2)
        codes = net(batch.contiguous()).squeeze()
        return codes.detach().cpu().numpy()

    if beta_VAE_score:
        with gin.unlock_config():
            from evaluation.metrics.beta_vae import compute_beta_vae_sklearn
            gin.bind_parameter("beta_vae_sklearn.batch_size", 64)
            gin.bind_parameter("beta_vae_sklearn.num_train", 10000)
            gin.bind_parameter("beta_vae_sklearn.num_eval", 5000)
        metrics = compute_beta_vae_sklearn(
            dataset, _representation,
            random_state=np.random.RandomState(0), artifact_dir=None)
        print("beta VAE score:" + str(metrics))
        write_text("beta_VAE_score", metrics, print_txt,
                   net.model_name + txt_name)
        gin.clear_config()

    if dci_score:
        from evaluation.metrics.dci import compute_dci
        with gin.unlock_config():
            gin.bind_parameter("dci.num_train", 10000)
            gin.bind_parameter("dci.num_test", 5000)
        metrics = compute_dci(
            dataset, _representation,
            random_state=np.random.RandomState(0), artifact_dir=None)
        print("dci score:" + str(metrics))
        write_text("dci_score", metrics, print_txt,
                   net.model_name + txt_name)
        gin.clear_config()

    if factor_VAE_score:
        with gin.unlock_config():
            from evaluation.metrics.factor_vae import compute_factor_vae
            gin.bind_parameter("factor_vae_score.num_variance_estimate", 10000)
            gin.bind_parameter("factor_vae_score.num_train", 10000)
            gin.bind_parameter("factor_vae_score.num_eval", 5000)
            gin.bind_parameter("factor_vae_score.batch_size", 64)
            gin.bind_parameter("prune_dims.threshold", 0.05)
        metrics = compute_factor_vae(
            dataset, _representation,
            random_state=np.random.RandomState(0), artifact_dir=None)
        print("factor VAE score:" + str(metrics))
        write_text("factor_VAE_score", metrics, print_txt,
                   net.model_name + txt_name)
        gin.clear_config()

    if MIG:
        with gin.unlock_config():
            from evaluation.metrics.mig import compute_mig
            from evaluation.metrics.utils import _histogram_discretize
            gin.bind_parameter("mig.num_train", 10000)
            gin.bind_parameter("discretizer.discretizer_fn",
                               _histogram_discretize)
            gin.bind_parameter("discretizer.num_bins", 20)
        metrics = compute_mig(
            dataset, _representation,
            random_state=np.random.RandomState(0), artifact_dir=None)
        print("MIG score:" + str(metrics))
        write_text("MIG", metrics, print_txt, net.model_name + txt_name)
        gin.clear_config()