def aggregation_dci(matrix, ys):
    """Collect the DCI scores of an importance matrix into one flat dict.

    Args:
        matrix: Importance matrix handed to the `dci` helpers. The per-code
            scores must have `matrix.shape[0]` entries and the per-factor
            scores `matrix.shape[1]` entries.
        ys: Unused; discarded immediately.

    Returns:
        A dict holding the global scores ("dci_disentanglement",
        "dci_completeness", "dci") followed by one
        "dci_disentanglement.code_<i>" entry per code and one
        "dci_completeness.code_<i>" entry per factor.

    Raises:
        AssertionError: If a per-code / per-factor score vector does not match
            the corresponding matrix dimension.
    """
    del ys

    scores = {
        "dci_disentanglement": dci.disentanglement(matrix),
        "dci_completeness": dci.completeness(matrix),
        # "dci" is an alias for the disentanglement score.
        "dci": dci.disentanglement(matrix),
    }

    per_code = dci.disentanglement_per_code(matrix)
    per_factor = dci.completeness_per_factor(matrix)
    assert len(per_code) == matrix.shape[0], "Wrong length."
    assert len(per_factor) == matrix.shape[1], "Wrong length."

    for index, value in enumerate(per_code):
        scores["dci_disentanglement.code_{}".format(index)] = value
    # The per-factor completeness keys also carry the "code_" suffix; the key
    # names are part of the returned interface and are kept unchanged.
    for index, value in enumerate(per_factor):
        scores["dci_completeness.code_{}".format(index)] = value

    return scores
def test_diagonal_empty_codes(self):
    """A 2x2 identity with an extra all-zero row must score 1.0."""
    importances = np.array([
        [1., 0.],
        [0., 1.],
        [0., 0.],
    ])
    expected = 1.0

    score = dci.disentanglement(importances)

    np.testing.assert_allclose(score, expected)
def test_one_code_two_factors(self):
    """Two side-by-side diagonal blocks must score 1 - log(2)/log(10).

    Every row of the 5x10 matrix has exactly two equal non-zero entries.
    """
    diagonal_block = 5.0 * np.eye(5)
    importances = np.hstack([diagonal_block, diagonal_block])
    expected = 1. - np.log(2) / np.log(10)

    score = dci.disentanglement(importances)

    np.testing.assert_allclose(score, expected)
def test_missed_factors(self):
    """Keeping only the first two rows of a diagonal matrix must score 1.0.

    The sliced 2x5 matrix leaves three columns without any non-zero entry.
    """
    full_diagonal = 5.0 * np.eye(5)
    first_two_rows = full_diagonal[:2, :]
    expected = 1.0

    score = dci.disentanglement(first_two_rows)

    np.testing.assert_allclose(score, expected)
def test_redundant_codes(self):
    """A diagonal matrix stacked on top of itself must still score 1.0."""
    diagonal_block = 5.0 * np.eye(5)
    importances = np.vstack([diagonal_block, diagonal_block])
    expected = 1.

    score = dci.disentanglement(importances)

    np.testing.assert_allclose(score, expected)
def test_zero(self):
    """An all-zero 10x10 importance matrix must score 0.0."""
    importances = np.zeros(shape=(10, 10), dtype=np.float64)
    expected = 0.0

    score = dci.disentanglement(importances)

    np.testing.assert_allclose(score, expected)
def test_diagonal(self):
    """A purely diagonal 5x5 importance matrix must score 1.0."""
    importances = 5.0 * np.eye(5)
    expected = 1.0

    score = dci.disentanglement(importances)

    np.testing.assert_allclose(score, expected)
def main(argv, model_dir=None):
    """Compute DCI scores for stored latent codes and optionally save/plot them.

    Loads the latent representation from ``<out_dir>/z_mean.npy`` and the
    ground-truth factors (location depends on ``FLAGS.data_type_dci`` and
    ``FLAGS.c_path``), flattens both to 2-D, drops factors that are constant
    or not listed in ``FLAGS.score_factors``, splits 80/20 into train/test and
    calls ``dci.compute_importance_gbt``. Disentanglement (D), completeness (C)
    and test informativeness (I) are printed. For 'hirid' and 'physionet' the
    importance matrix is additionally multiplied with the matrix stored at
    ``FLAGS.assign_mat_path`` and scored again. Results are written under
    ``out_dir`` when ``FLAGS.save_score`` / ``FLAGS.visualize_score`` are set.

    Args:
        argv: Unused; deleted immediately.
        model_dir: Directory holding the model outputs. When ``None``,
            ``FLAGS.model_name`` is used as the output directory instead.

    Raises:
        ValueError: If the data type is 'physionet' and ``FLAGS.rescaling`` is
            neither 'linear' nor 'standard'.
    """
    del argv # Unused

    if model_dir is None:
        out_dir = FLAGS.model_name
    else:
        out_dir = model_dir

    z_path = '{}/z_mean.npy'.format(out_dir)
    # Fall back to the default /data/<type>/ location when no explicit
    # ground-truth path is given; 'hirid' uses a different file name.
    if FLAGS.c_path == '':
        if FLAGS.data_type_dci != 'hirid':
            c_path = os.path.join(F'/data/{FLAGS.data_type_dci}',
                                  F'factors_{FLAGS.data_type_dci}.npz')
        else:
            c_path = os.path.join(F'/data/{FLAGS.data_type_dci}',
                                  F'{FLAGS.data_type_dci}.npz')
    else:
        c_path = FLAGS.c_path

    if FLAGS.data_type_dci == "physionet":
        # Use imputed values as ground truth for physionet data
        # (c_path is not read in this branch).
        c, z = load_z_c('{}/imputed.npy'.format(out_dir), z_path)
        c = np.transpose(c, (0,2,1))
    elif FLAGS.data_type_dci == "hirid":
        c = np.load(c_path)['x_test_miss']
        c = np.transpose(c, (0, 2, 1))
        c = c.astype(int)
        z = np.load(z_path)
    else:
        c, z = load_z_c(c_path, z_path)

    z_shape = z.shape
    c_shape = c.shape

    # Flatten 3-D arrays to 2-D: axes 0 and 2 are merged into rows, axis 1
    # becomes the columns (one column per latent dim / factor).
    z_reshape = np.reshape(np.transpose(z, (0,2,1)),(z_shape[0]*z_shape[2],z_shape[1]))
    c_reshape = np.reshape(np.transpose(c, (0,2,1)),(c_shape[0]*c_shape[2],c_shape[1]))
    # Truncate the factors to at most as many rows as there are latent samples.
    c_reshape = c_reshape[:z_reshape.shape[0], ...]

    # Experimental physionet rescaling
    if FLAGS.data_type_dci == 'physionet':
        if FLAGS.rescaling == 'linear':
            # linear rescaling
            c_rescale = 10 * c_reshape
            c_reshape = c_rescale.astype(int)
        elif FLAGS.rescaling == 'standard':
            # standardizing
            scaler = StandardScaler()
            c_rescale = scaler.fit_transform(c_reshape)
            c_reshape = (10*c_rescale).astype(int)
        else:
            raise ValueError("Rescaling must be 'linear' or 'standard'")

    # Include all factors in score calculation, if not specified otherwise
    # NOTE: this overwrites the flag value itself with an array of index strings.
    if not FLAGS.score_factors:
        FLAGS.score_factors = np.arange(c_shape[1]).astype(str)

    # Check if ground truth factor doesn't change and remove if is the case
    # A factor is also removed when its index is not listed in score_factors.
    # Only the first 8000 rows are inspected for changes.
    mask = np.ones(c_reshape.shape[1], dtype=bool)
    for i in range(c_reshape.shape[1]):
        c_change = np.sum(abs(np.diff(c_reshape[:8000,i])))
        if (not c_change) or (F"{i}" not in FLAGS.score_factors):
            mask[i] = False
    c_reshape = c_reshape[:,mask]
    print(F'C shape: {c_reshape.shape}')
    print(F'Z shape: {z_reshape.shape}')
    print(F'Shuffle: {FLAGS.shuffle}')

    c_train, c_test, z_train, z_test = train_test_split(c_reshape, z_reshape,
                                                        test_size=0.2,
                                                        shuffle=FLAGS.shuffle,
                                                        random_state=FLAGS.dci_seed)

    # Cap the number of samples passed to the importance computation.
    if FLAGS.data_type_dci == "hirid":
        n_train = 20000
        n_test = 5000
    else:
        n_train = 8000
        n_test = 2000

    # Inputs are transposed so that samples lie along the last axis;
    # factors are cast to int before being passed on.
    importance_matrix, i_train, i_test = dci.compute_importance_gbt(
        z_train[:n_train, :].transpose(),
        c_train[:n_train, :].transpose().astype(int),
        z_test[:n_test, :].transpose(),
        c_test[:n_test, :].transpose().astype(int))

    # Calculate scores
    d = dci.disentanglement(importance_matrix)
    c = dci.completeness(importance_matrix)  # rebinds c: the factor array is no longer needed
    print(F'D: {d}')
    print(F'C: {c}')
    print(F'I: {i_test}')

    if FLAGS.data_type_dci in ['hirid', 'physionet']:
        # Re-insert an all-zero column for every factor removed by the mask,
        # so the column count matches the unmasked factor count again.
        miss_idxs = np.nonzero(np.invert(mask))[0]
        for idx in miss_idxs:
            importance_matrix = np.insert(importance_matrix,
                                          idx,
                                          0, axis=1)
        assign_mat = np.load(FLAGS.assign_mat_path)
        impt_mat_assign = np.matmul(importance_matrix, assign_mat)
        # Normalize each column by its sum; nan_to_num replaces NaNs
        # (presumably from zero-sum columns) with 0.
        impt_mat_assign_norm = np.nan_to_num(
            impt_mat_assign / np.sum(impt_mat_assign, axis=0))
        d_assign = dci.disentanglement(impt_mat_assign_norm)
        c_assign = dci.completeness(impt_mat_assign_norm)
        print(F'D assign: {d_assign}')
        print(F'C assign: {c_assign}')

    if FLAGS.save_score:
        if FLAGS.data_type_dci in ['hirid', 'physionet']:
            np.savez(F'{out_dir}/dci_assign_2_{FLAGS.dci_seed}',
                     informativeness_train=i_train,
                     informativeness_test=i_test,
                     disentanglement=d, completeness=c,
                     disentanglement_assign=d_assign,
                     completeness_assign=c_assign)
        else:
            np.savez(F'{out_dir}/dci_{FLAGS.dci_seed}',
                     informativeness_train=i_train,
                     informativeness_test=i_test,
                     disentanglement=d, completeness=c)

    # Visualization
    if FLAGS.visualize_score:
        if FLAGS.data_type_dci == 'hirid':
            # Visualize
            visualize_scores.heat_square(np.transpose(importance_matrix), out_dir,
                                         F"dci_matrix_{FLAGS.dci_seed}",
                                         "feature", "latent dim")
            visualize_scores.heat_square(np.transpose(impt_mat_assign_norm), out_dir,
                                         F"dci_matrix_assign_{FLAGS.dci_seed}",
                                         "feature", "latent_dim")

            # Save importance matrices
            if FLAGS.save_score:
                np.save(F"{out_dir}/impt_matrix_{FLAGS.dci_seed}", importance_matrix)
                np.save(F"{out_dir}/impt_matrix_assign_{FLAGS.dci_seed}", impt_mat_assign_norm)
        else:
            # Visualize
            visualize_scores.heat_square(importance_matrix, out_dir,
                                         F"dci_matrix_{FLAGS.dci_seed}",
                                         "x_axis", "y_axis")
            # Save importance matrices
            # (unlike the 'hirid' branch, this save is not gated on save_score)
            np.save(F"{out_dir}/impt_matrix_{FLAGS.dci_seed}", importance_matrix)

    print("Evaluation finished")