def test_bad_metric(self):
    """A constant (all-zero) representation should get a near-zero IRS.

    The representation maps every observation to zeros, so it carries no
    information about the ground-truth factors; the Interventional
    Robustness Score is therefore expected to land in [0.0, 0.1].
    """
    gin.bind_parameter("discretizer.discretizer_fn", _identity_discretizer)
    gin.bind_parameter("discretizer.num_bins", 10)
    ground_truth_data = dummy_data.IdentityObservationsData()

    def representation_function(x):
        # Uninformative code: same shape as the input, all zeros.
        return np.zeros_like(x, dtype=np.float64)

    random_state = np.random.RandomState(0)
    # Tuning arguments are passed by keyword so they stay bound to the
    # intended parameters even if the positional order of compute_irs
    # changes (NOTE(review): some versions add an `artifact_dir` parameter
    # ahead of `diff_quantile` — keywords keep this call valid for both).
    scores = irs.compute_irs(
        ground_truth_data,
        representation_function,
        random_state,
        diff_quantile=0.99,
        num_train=3000,
        batch_size=3000,
    )
    self.assertBetween(scores["IRS"], 0.0, 0.1)
np.random.RandomState(0), 64, 10000, 5000, 10000) print(' factor_vae: %.6f' % scores['eval_accuracy']) scores = dci.compute_dci(ground_truth_data, representation_fn, np.random.RandomState(0), 10000, 5000) print(' dci: %.6f' % scores['disentanglement']) scores = sap_score.compute_sap(ground_truth_data, representation_fn, np.random.RandomState(0), 10000, 5000, continuous_factors=False) print(' sap_score: %.6f' % scores['SAP_score']) import gin.tf gin.bind_parameter("discretizer.discretizer_fn", utils._histogram_discretize) gin.bind_parameter("discretizer.num_bins", 20) scores = mig.compute_mig(ground_truth_data, representation_fn, np.random.RandomState(0), 10000) print(' mig: %.6f' % scores['discrete_mig']) gin.bind_parameter("irs.batch_size", 16) scores = irs.compute_irs(ground_truth_data, representation_fn, np.random.RandomState(0), num_train=10000) print(' irs: %.6f' % scores['IRS'])