def compute_dci(dataholder,
                random_state,
                artifact_dir=None,
                num_train=gin.REQUIRED,
                num_test=gin.REQUIRED,
                num_eval=gin.REQUIRED,
                mode=gin.REQUIRED):
    """Computes the DCI scores according to Sec 2.

    Args:
        dataholder: Holds all factors and associated representations.
        random_state: Numpy random state used for randomness.
        artifact_dir: Optional path to directory where artifacts can be saved.
        num_train: Number of points used for training.
        num_test: Number of points used for testing.
        num_eval: Number of points used for evaluation.
        mode: Predictor used to fit factors from codes; one of "RF_class",
            "RF_reg", "LogRegL1" or "Lasso".

    Returns:
        Dictionary with average disentanglement score, completeness and
        informativeness (train, test and eval).
    """
    del artifact_dir
    logging.info("Generating training set.")
    if mode in ("RF_class", "LogRegL1"):
        continuous = False
    elif mode in ("RF_reg", "Lasso"):
        continuous = True
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    # mus_train are of shape [num_codes, num_train], while ys_train are of
    # shape [num_factors, num_train].
    mus_train, ys_train = utils.generate_batch_factor_code(
        dataholder, num_train, random_state, num_train, continuous=continuous)
    assert mus_train.shape[1] == num_train
    assert ys_train.shape[1] == num_train
    mus_test, ys_test = utils.generate_batch_factor_code(
        dataholder, num_test, random_state, num_test, continuous=continuous)
    mus_eval, ys_eval = utils.generate_batch_factor_code(
        dataholder, num_eval, random_state, num_eval, continuous=continuous)
    scores = _compute_dci(random_state, mus_train, ys_train, mus_test,
                          ys_test, mus_eval, ys_eval, mode)
    return scores
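# `_compute_dci` is defined elsewhere in this module. As a rough guide to the
# quantities it returns, the hypothetical sketch below (not part of the
# library API) shows how disentanglement and completeness are commonly
# derived from an importance matrix of shape [num_codes, num_factors],
# following Eastwood & Williams (2018): one minus the entropy of each
# normalized row (or column), weighted by relative importance.
import scipy.stats


def _dci_from_importance_matrix_sketch(importance_matrix):
    """Hedged sketch: disentanglement/completeness from an importance matrix."""
    # Disentanglement: how concentrated each code's importance is on a
    # single factor (entropy taken over factors, per code).
    per_code = 1. - scipy.stats.entropy(
        importance_matrix.T + 1e-11, base=importance_matrix.shape[1])
    code_weights = importance_matrix.sum(axis=1) / importance_matrix.sum()
    disentanglement = np.sum(per_code * code_weights)
    # Completeness: how concentrated each factor's importance is on a
    # single code (entropy taken over codes, per factor).
    per_factor = 1. - scipy.stats.entropy(
        importance_matrix + 1e-11, base=importance_matrix.shape[0])
    factor_weights = importance_matrix.sum(axis=0) / importance_matrix.sum()
    completeness = np.sum(per_factor * factor_weights)
    return disentanglement, completeness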
def compute_mig(dataholder,
                random_state,
                artifact_dir=None,
                num_train=gin.REQUIRED):
    """Computes the mutual information gap.

    Args:
        dataholder: Holds all factors and associated representations.
        random_state: Numpy random state used for randomness.
        artifact_dir: Optional path to directory where artifacts can be saved.
        num_train: Number of points used for training.

    Returns:
        Dict with average mutual information gap.
    """
    del artifact_dir
    logging.info("Generating training set.")
    mus_train, ys_train = utils.generate_batch_factor_code(
        dataholder, num_train, random_state, num_train)
    assert mus_train.shape[1] == num_train
    return _compute_mig(dataholder, mus_train, ys_train)
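# `_compute_mig` is defined elsewhere in this module. As a hedged sketch of
# the standard MIG computation (Chen et al., 2018): discretize the codes,
# build the code-factor mutual information matrix, and average over factors
# the gap between the two largest MI values, normalized by factor entropy.
# `utils.discrete_mutual_info` and `utils.discrete_entropy` are assumed to
# exist with these signatures; adjust to the actual helpers in utils.
def _mig_sketch(mus_discrete, ys_train):
    m = utils.discrete_mutual_info(mus_discrete, ys_train)  # [codes, factors]
    entropy = utils.discrete_entropy(ys_train)               # [factors]
    sorted_m = np.sort(m, axis=0)[::-1]  # MI sorted descending per factor
    return np.mean((sorted_m[0, :] - sorted_m[1, :]) / entropy)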
def compute_sap(dataholder,
                random_state,
                artifact_dir=None,
                num_train=gin.REQUIRED,
                num_test=gin.REQUIRED,
                continuous_factors=gin.REQUIRED):
    """Computes the SAP score.

    Args:
        dataholder: Holds all factors and associated representations.
        random_state: Numpy random state used for randomness.
        artifact_dir: Optional path to directory where artifacts can be saved.
        num_train: Number of points used for training.
        num_test: Number of points used for testing.
        continuous_factors: Whether the factors are continuous (True) or
            discrete (False).

    Returns:
        Dictionary with SAP score.
    """
    del artifact_dir
    logging.info("Generating training set.")
    mus, ys = utils.generate_batch_factor_code(
        dataholder, num_train, random_state, num_train,
        continuous=continuous_factors)
    mus_test, ys_test = utils.generate_batch_factor_code(
        dataholder, num_test, random_state, num_test,
        continuous=continuous_factors)
    logging.info("Computing score matrix.")
    return _compute_sap(random_state, mus, ys, mus_test, ys_test,
                        continuous_factors)
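# `_compute_sap` is defined elsewhere in this module. SAP (Kumar et al.,
# 2018) averages, over factors, the gap between the two most predictive
# codes. A hedged sketch, assuming a hypothetical score matrix S[i, j] that
# holds e.g. the test accuracy of a linear SVM predicting factor j from
# code i (or an R^2 score for continuous factors):
def _sap_from_score_matrix_sketch(score_matrix):
    sorted_scores = np.sort(score_matrix, axis=0)  # ascending over codes
    # Gap between the best and second-best code for each factor, averaged.
    return np.mean(sorted_scores[-1, :] - sorted_scores[-2, :])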
def unsupervised_metrics(ground_truth_data,
                         representation_function,
                         random_state,
                         artifact_dir=None,
                         num_train=gin.REQUIRED,
                         batch_size=16):
    """Computes unsupervised scores based on covariance and mutual information.

    Args:
        ground_truth_data: GroundTruthData to be sampled from.
        representation_function: Function that takes observations as input and
            outputs a dim_representation sized representation for each
            observation.
        random_state: Numpy random state used for randomness.
        artifact_dir: Optional path to directory where artifacts can be saved.
        num_train: Number of points used for training.
        batch_size: Batch size for sampling.

    Returns:
        Dictionary with scores.
    """
    del artifact_dir
    scores = {}
    logging.info("Generating training set.")
    mus_train, _ = utils.generate_batch_factor_code(
        ground_truth_data, representation_function, num_train, random_state,
        batch_size)
    num_codes = mus_train.shape[0]
    cov_mus = np.cov(mus_train)
    assert num_codes == cov_mus.shape[0]

    # Gaussian total correlation.
    scores["gaussian_total_correlation"] = gaussian_total_correlation(cov_mus)

    # Gaussian Wasserstein correlation.
    scores["gaussian_wasserstein_correlation"] = (
        gaussian_wasserstein_correlation(cov_mus))
    scores["gaussian_wasserstein_correlation_norm"] = (
        scores["gaussian_wasserstein_correlation"] / np.sum(np.diag(cov_mus)))

    # Compute average mutual information between pairs of codes.
    mus_discrete, _ = utils.make_discretizer(mus_train)
    mutual_info_matrix = utils.discrete_mutual_info(mus_discrete, mus_discrete)
    np.fill_diagonal(mutual_info_matrix, 0)
    mutual_info_score = np.sum(mutual_info_matrix) / (num_codes**2 - num_codes)
    scores["mutual_info_score"] = mutual_info_score
    return scores
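# `gaussian_total_correlation` is defined elsewhere in this module. For a
# Gaussian fit N(0, cov), the total correlation is the KL divergence to the
# product of its marginals and has a closed form; a minimal sketch:
def _gaussian_total_correlation_sketch(cov):
    # TC = 0.5 * (sum_i log cov_ii - log det cov).
    return 0.5 * (np.sum(np.log(np.diag(cov))) - np.linalg.slogdet(cov)[1])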
def compute_irs(dataholder,
                random_state,
                artifact_dir=None,
                diff_quantile=0.99,
                num_train=gin.REQUIRED,
                batch_size=gin.REQUIRED):
    """Computes the Interventional Robustness Score.

    Args:
        dataholder: Holds all factors and associated representations.
        random_state: Numpy random state used for randomness.
        artifact_dir: Optional path to directory where artifacts can be saved.
        diff_quantile: Float value between 0 and 1 to decide what quantile of
            diffs to select (use 1.0 for the version in the paper).
        num_train: Number of points used for training.
        batch_size: Batch size for sampling.

    Returns:
        Dict with IRS and number of active dimensions.
    """
    del artifact_dir
    logging.info("Generating training set.")
    mus, ys = utils.generate_batch_factor_code(
        dataholder, num_train, random_state, batch_size)
    assert mus.shape[1] == num_train
    ys_discrete, _ = utils.make_discretizer(ys, dataholder.cumulative_dist)
    active_mus = _drop_constant_dims(mus)

    if not active_mus.any():
        irs_score = {"avg_score": 0.0, "disentanglement_scores": None}
    else:
        irs_score = scalable_disentanglement_score(
            ys_discrete.T, active_mus.T, diff_quantile)

    score_dict = {}
    score_dict["IRS"] = irs_score["avg_score"]
    score_dict["IRS_disentanglement_scores"] = irs_score[
        "disentanglement_scores"]
    # Count the non-constant representation dimensions that remain.
    score_dict["num_active_dims"] = active_mus.shape[0]
    return score_dict
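# `_drop_constant_dims` is defined elsewhere in this module; a minimal sketch
# of the intended behavior (drop representation rows with zero variance):
def _drop_constant_dims_sketch(mus):
    mus = np.asarray(mus)
    if mus.ndim != 2:
        raise ValueError("Expecting a matrix of shape [num_codes, num_points].")
    variances = mus.var(axis=1)
    return mus[variances > 0, :]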
def compute_dcimig(dataholder,
                   random_state,
                   artifact_dir=None,
                   num_train=gin.REQUIRED):
    """Computes the DCIMIG score.

    Args:
        dataholder: Holds all factors and associated representations.
        random_state: Numpy random state used for randomness.
        artifact_dir: Optional path to directory where artifacts can be saved.
        num_train: Number of points used for training.

    Returns:
        Dict with the DCIMIG score.
    """
    del artifact_dir
    logging.info("Generating training set.")
    mus_train, ys_train = utils.generate_batch_factor_code(
        dataholder, num_train, random_state, num_train)
    assert mus_train.shape[1] == num_train
    return _compute_dcimig(dataholder, mus_train, ys_train)
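# `_compute_dcimig` is defined elsewhere in this module. As a hedged sketch
# of DCIMIG (Sepliarskaia et al.): each code's MI gap (largest minus
# second-largest MI across factors) is credited to its best factor, only the
# largest gap credited to each factor is kept, and the sum is normalized by
# the total factor entropy. The exact normalization in `_compute_dcimig` may
# differ; helpers from utils are assumed as in the MIG sketch above.
def _dcimig_sketch(mus_discrete, ys):
    m = utils.discrete_mutual_info(mus_discrete, ys)  # [num_codes, num_factors]
    gaps = np.zeros(m.shape[1])
    for i in range(m.shape[0]):
        order = np.argsort(m[i, :])
        gap = m[i, order[-1]] - m[i, order[-2]]
        gaps[order[-1]] = max(gaps[order[-1]], gap)
    return np.sum(gaps) / np.sum(utils.discrete_entropy(ys))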
def compute_modularity_explicitness(dataholder,
                                    random_state,
                                    artifact_dir=None,
                                    num_train=gin.REQUIRED,
                                    num_test=gin.REQUIRED,
                                    batch_size=16):
    """Computes the modularity metric according to Sec 3.

    Args:
        dataholder: Holds all factors and associated representations.
        random_state: Numpy random state used for randomness.
        artifact_dir: Optional path to directory where artifacts can be saved.
        num_train: Number of points used for training.
        num_test: Number of points used for testing.
        batch_size: Batch size for sampling.

    Returns:
        Dictionary with average modularity score and average explicitness
        (train and test).
    """
    del artifact_dir
    scores = {}
    mus_train, ys_train = utils.generate_batch_factor_code(
        dataholder, num_train, random_state, batch_size)
    mus_test, ys_test = utils.generate_batch_factor_code(
        dataholder, num_test, random_state, batch_size)
    all_mus = np.transpose(dataholder.embed_codes)
    all_ys = np.transpose(dataholder.factors)

    # New score: modularity computed on the whole dataset.
    mutual_information = get_MI_matrix(dataholder, all_mus, all_ys)
    # Mutual information should have shape [num_codes, num_factors].
    assert mutual_information.shape[0] == mus_train.shape[0]
    assert mutual_information.shape[1] == ys_train.shape[0]
    scores["MODEX_modularity_score"] = modularity(mutual_information)

    # From the paper: "for modularity, we report the mean across validation
    # splits and embedding dimensions." The old disentanglement-lib
    # implementation used the train set only, so we report results for the
    # whole dataset as well as for the train and test sets.
    # Old score 1: train set.
    mutual_information = get_MI_matrix(dataholder, mus_train, ys_train)
    scores["MODEX_modularity_oldtrain_score"] = modularity(mutual_information)
    # Old score 2: test set.
    mutual_information = get_MI_matrix(dataholder, mus_test, ys_test)
    scores["MODEX_modularity_oldtest_score"] = modularity(mutual_information)

    explicitness_score_train = np.zeros([ys_train.shape[0], 1])
    explicitness_score_test = np.zeros([ys_test.shape[0], 1])
    # Avoid divisions by zero for inactive dimensions.
    std_train = np.std(mus_train, axis=1)
    std_train[std_train == 0] = 1
    mus_train_norm, mean_mus, _ = utils.normalize_data(
        mus_train, stddev=std_train)
    mus_test_norm, _, _ = utils.normalize_data(mus_test, mean_mus, std_train)
    for i in range(ys_train.shape[0]):
        explicitness_score_train[i], explicitness_score_test[i] = (
            explicitness_per_factor(random_state, mus_train_norm,
                                    ys_train[i, :], mus_test_norm,
                                    ys_test[i, :]))
    scores["MODEX_explicitness_score_train"] = np.mean(
        explicitness_score_train)
    scores["MODEX_explicitness_score_test"] = np.mean(explicitness_score_test)
    return scores
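# `modularity` is defined elsewhere in this module. As a hedged sketch of the
# Ridgeway & Mozer (2018) definition: for each code, measure how far its
# squared-MI profile deviates from an ideal template in which all mass sits
# on a single factor, and report the mean over codes.
def _modularity_sketch(mutual_information):
    squared_mi = np.square(mutual_information)
    max_squared_mi = np.max(squared_mi, axis=1)
    numerator = np.sum(squared_mi, axis=1) - max_squared_mi
    denominator = np.maximum(
        max_squared_mi * (squared_mi.shape[1] - 1.), 1e-11)
    modularity_score = 1. - numerator / denominator
    modularity_score[max_squared_mi == 0.] = 0.  # dead codes score zero
    return np.mean(modularity_score)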