예제 #1
0
 def test_diagonal_empty_codes(self):
     """Completeness stays at 1 when an extra code row is all zeros."""
     # Identity block for the two informative codes, then one empty code row.
     informative_rows = np.eye(2)
     empty_row = np.zeros((1, 2))
     importances = np.vstack([informative_rows, empty_row])
     completeness = dci.completeness(importances)
     np.testing.assert_allclose(completeness, 1.0)
def aggregation_dci(matrix, ys):
  """Collects the aggregate and per-index DCI scores of an importance matrix.

  Args:
    matrix: Importance matrix passed to the `dci` scoring functions. Its first
      axis must have one entry per per-code disentanglement score and its
      second axis one entry per per-factor completeness score.
    ys: Unused; discarded immediately.

  Returns:
    Dict holding "dci_disentanglement", "dci_completeness" and "dci" (the
    disentanglement score again), followed by one entry per code and one entry
    per factor.
  """
  del ys
  score = {
      "dci_disentanglement": dci.disentanglement(matrix),
      "dci_completeness": dci.completeness(matrix),
      "dci": dci.disentanglement(matrix),
  }
  per_code = dci.disentanglement_per_code(matrix)
  per_factor = dci.completeness_per_factor(matrix)
  # Sanity checks: the per-index results must line up with the matrix axes.
  assert len(per_code) == matrix.shape[0], "Wrong length."
  assert len(per_factor) == matrix.shape[1], "Wrong length."
  score.update(
      ("dci_disentanglement.code_{}".format(idx), value)
      for idx, value in enumerate(per_code))
  # NOTE: these keys use the "code_" infix although they index factors; the
  # naming is kept because consumers may look the keys up verbatim.
  score.update(
      ("dci_completeness.code_{}".format(idx), value)
      for idx, value in enumerate(per_factor))
  return score
예제 #3
0
 def test_one_code_two_factors(self):
   """Completeness is 1 when every code row is repeated across two factors."""
   diagonal_block = 5. * np.eye(5)
   # Duplicate the block column-wise so each row has two non-zero entries.
   importances = np.tile(diagonal_block, (1, 2))
   completeness = dci.completeness(importances)
   np.testing.assert_allclose(completeness, 1.)
예제 #4
0
 def test_missed_factors(self):
   """Completeness is 1 when only the first two rows of a diagonal are kept."""
   full_diagonal = np.diag(np.full(5, 5.))
   # Keeping two rows leaves the last three columns entirely zero.
   truncated = full_diagonal[:2, :]
   completeness = dci.completeness(truncated)
   np.testing.assert_allclose(completeness, 1.0)
예제 #5
0
 def test_redundant_codes(self):
   """Stacking the diagonal twice lowers completeness by log(2)/log(10)."""
   diagonal_block = 5. * np.eye(5)
   # Two identical row blocks: every column now has two equal entries.
   importances = np.tile(diagonal_block, (2, 1))
   expected = 1. - np.log(2)/np.log(10)
   completeness = dci.completeness(importances)
   np.testing.assert_allclose(completeness, expected)
예제 #6
0
 def test_zero(self):
   """An all-zero 10x10 importance matrix scores (approximately) zero."""
   all_zero = np.zeros((10, 10), dtype=np.float64)
   completeness = dci.completeness(all_zero)
   # Absolute tolerance is required because the expected value is exactly 0.
   np.testing.assert_allclose(completeness, .0, atol=1e-7)
예제 #7
0
 def test_diagonal(self):
   """A purely diagonal importance matrix has completeness 1."""
   importances = 5. * np.eye(5)
   completeness = dci.completeness(importances)
   np.testing.assert_allclose(completeness, 1.0)
예제 #8
0
def main(argv, model_dir=None):
    """Evaluate DCI scores of saved latent codes against ground-truth factors.

    Loads the latents (``z_mean.npy``) and the ground truth selected by
    ``FLAGS.data_type_dci``, flattens both into 2-D sample matrices, drops
    factors that are constant or not requested, computes the importance matrix
    with ``dci.compute_importance_gbt`` and prints disentanglement,
    completeness and test informativeness. Depending on the flags, scores and
    importance matrices are additionally saved to and visualized in the output
    directory.

    Args:
        argv: Unused command-line arguments.
        model_dir: Directory containing the model outputs. When ``None``,
            ``FLAGS.model_name`` is used instead.

    Raises:
        ValueError: If ``FLAGS.data_type_dci`` is ``'physionet'`` and
            ``FLAGS.rescaling`` is neither ``'linear'`` nor ``'standard'``.
    """
    del argv # Unused

    if model_dir is None:
        out_dir = FLAGS.model_name
    else:
        out_dir = model_dir

    z_path = '{}/z_mean.npy'.format(out_dir)

    # Resolve the ground-truth file: explicit flag wins, otherwise fall back to
    # the per-dataset default location under /data.
    if FLAGS.c_path == '':
        if FLAGS.data_type_dci != 'hirid':
            c_path = os.path.join(F'/data/{FLAGS.data_type_dci}', F'factors_{FLAGS.data_type_dci}.npz')
        else:
            c_path = os.path.join(F'/data/{FLAGS.data_type_dci}', F'{FLAGS.data_type_dci}.npz')
    else:
        c_path = FLAGS.c_path

    if FLAGS.data_type_dci == "physionet":
        # Use imputed values as ground truth for physionet data
        c, z = load_z_c('{}/imputed.npy'.format(out_dir), z_path)
        c = np.transpose(c, (0,2,1))
    elif FLAGS.data_type_dci == "hirid":
        c = np.load(c_path)['x_test_miss']
        c = np.transpose(c, (0, 2, 1))
        c = c.astype(int)
        z = np.load(z_path)
    else:
        c, z = load_z_c(c_path, z_path)

    z_shape = z.shape
    c_shape = c.shape

    # Merge the first and last axes into one sample axis; axis 1 (latent
    # dimensions / factors) becomes the column axis.
    z_reshape = np.reshape(np.transpose(z, (0,2,1)),(z_shape[0]*z_shape[2],z_shape[1]))
    c_reshape = np.reshape(np.transpose(c, (0,2,1)),(c_shape[0]*c_shape[2],c_shape[1]))
    # Trim the ground truth to the number of latent samples.
    c_reshape = c_reshape[:z_reshape.shape[0], ...]

    # Experimental physionet rescaling
    if FLAGS.data_type_dci == 'physionet':
        if FLAGS.rescaling == 'linear':
            # linear rescaling
            c_rescale = 10 * c_reshape
            c_reshape = c_rescale.astype(int)
        elif FLAGS.rescaling == 'standard':
            # standardizing
            scaler = StandardScaler()
            c_rescale = scaler.fit_transform(c_reshape)
            c_reshape = (10*c_rescale).astype(int)
        else:
            raise ValueError("Rescaling must be 'linear' or 'standard'")

    # Include all factors in score calculation, if not specified otherwise.
    # The selection is resolved into a local instead of being written back to
    # FLAGS: storing an ndarray in the flag made `not FLAGS.score_factors`
    # raise on a second call to main() and carried one run's factor list over
    # to the next.
    score_factors = FLAGS.score_factors
    if score_factors is None or len(score_factors) == 0:
        score_factors = [str(i) for i in range(c_shape[1])]

    # Check if ground truth factor doesn't change and remove if is the case
    # NOTE(review): constancy is judged on the first 8000 rows only, regardless
    # of the training size chosen further down - confirm this is intended.
    mask = np.ones(c_reshape.shape[1], dtype=bool)
    for i in range(c_reshape.shape[1]):
        c_change = np.sum(abs(np.diff(c_reshape[:8000,i])))
        if (not c_change) or (F"{i}" not in score_factors):
            mask[i] = False
    c_reshape = c_reshape[:,mask]
    print(F'C shape: {c_reshape.shape}')
    print(F'Z shape: {z_reshape.shape}')
    print(F'Shuffle: {FLAGS.shuffle}')

    c_train, c_test, z_train, z_test = train_test_split(c_reshape, z_reshape, test_size=0.2, shuffle=FLAGS.shuffle, random_state=FLAGS.dci_seed)

    # Cap the number of samples handed to the importance computation.
    if FLAGS.data_type_dci == "hirid":
        n_train = 20000
        n_test = 5000
    else:
        n_train = 8000
        n_test = 2000

    importance_matrix, i_train, i_test = dci.compute_importance_gbt(
        z_train[:n_train, :].transpose(),
        c_train[:n_train, :].transpose().astype(int),
        z_test[:n_test, :].transpose(), c_test[:n_test, :].transpose().astype(int))
    # Calculate scores (named so they do not shadow the ground-truth array `c`)
    disentanglement_score = dci.disentanglement(importance_matrix)
    completeness_score = dci.completeness(importance_matrix)
    print(F'D: {disentanglement_score}')
    print(F'C: {completeness_score}')
    print(F'I: {i_test}')

    if FLAGS.data_type_dci in ['hirid', 'physionet']:
        # Re-insert an all-zero column for every factor removed by `mask`, in
        # ascending index order, so the columns line up with the full factor
        # list again before applying the assignment matrix.
        miss_idxs = np.nonzero(np.invert(mask))[0]
        for idx in miss_idxs:
            importance_matrix = np.insert(importance_matrix,
                                          idx,
                                          0, axis=1)
        assign_mat = np.load(FLAGS.assign_mat_path)
        impt_mat_assign = np.matmul(importance_matrix, assign_mat)
        # Normalize each column by its sum; NaNs from all-zero columns (0/0)
        # are mapped to 0 by nan_to_num.
        impt_mat_assign_norm = np.nan_to_num(
            impt_mat_assign / np.sum(impt_mat_assign, axis=0))
        d_assign = dci.disentanglement(impt_mat_assign_norm)
        c_assign = dci.completeness(impt_mat_assign_norm)
        print(F'D assign: {d_assign}')
        print(F'C assign: {c_assign}')

    if FLAGS.save_score:
        if FLAGS.data_type_dci in ['hirid', 'physionet']:
            np.savez(F'{out_dir}/dci_assign_2_{FLAGS.dci_seed}', informativeness_train=i_train, informativeness_test=i_test,
                     disentanglement=disentanglement_score, completeness=completeness_score,
                     disentanglement_assign=d_assign, completeness_assign=c_assign)
        else:
            np.savez(F'{out_dir}/dci_{FLAGS.dci_seed}', informativeness_train=i_train, informativeness_test=i_test,
                     disentanglement=disentanglement_score, completeness=completeness_score)

    # Visualization
    if FLAGS.visualize_score:
        if FLAGS.data_type_dci == 'hirid':
            # Visualize
            visualize_scores.heat_square(np.transpose(importance_matrix), out_dir,
                                         F"dci_matrix_{FLAGS.dci_seed}",
                                         "feature", "latent dim")
            visualize_scores.heat_square(np.transpose(impt_mat_assign_norm), out_dir,
                                         F"dci_matrix_assign_{FLAGS.dci_seed}",
                                         "feature", "latent_dim")

            # Save importance matrices
            if FLAGS.save_score:
                np.save(F"{out_dir}/impt_matrix_{FLAGS.dci_seed}", importance_matrix)
                np.save(F"{out_dir}/impt_matrix_assign_{FLAGS.dci_seed}", impt_mat_assign_norm)

        else:
            # Visualize
            visualize_scores.heat_square(importance_matrix, out_dir,
                                         F"dci_matrix_{FLAGS.dci_seed}",
                                         "x_axis", "y_axis")
            # Save importance matrices
            # NOTE(review): unlike the hirid branch this save is not guarded by
            # FLAGS.save_score - confirm whether that is intentional.
            np.save(F"{out_dir}/impt_matrix_{FLAGS.dci_seed}", importance_matrix)

    print("Evaluation finished")