def compute_unified_score_on_fixed_data(
    observations, labels, representation_function,
    train_percentage=gin.REQUIRED, matrix_fns=gin.REQUIRED, batch_size=100):
  """Computes the unified scores on the fixed set of observations and labels.

  Args:
    observations: Observations on which to compute the score. Observations have
      shape (num_observations, 64, 64, num_channels).
    labels: Observed factors of variations. The second axis must index the
      observations, i.e. labels.shape[1] == num_observations.
    representation_function: Function that takes observations as input and
      outputs a dim_representation sized representation for each observation.
    train_percentage: Percentage of observations used for training.
    matrix_fns: List of functions to relate factors of variations and codes.
    batch_size: Batch size used to compute the representation.

  Returns:
    Unified scores, as returned by unified_scores().

  Raises:
    AssertionError: If labels, or the representation computed from the
      observations, do not have exactly one column per observation.
  """
  num_observations = observations.shape[0]
  # Explicit checks rather than `assert` statements so the validation still
  # runs under `python -O`. AssertionError (and the original messages) are kept
  # so callers that relied on the old asserts see the same exception.
  # Labels are checked before computing the representation so that bad input
  # fails before the potentially expensive forward pass.
  if labels.shape[1] != num_observations:
    raise AssertionError("Wrong labels shape.")
  mus = utils.obtain_representation(observations, representation_function,
                                    batch_size)
  if mus.shape[1] != num_observations:
    raise AssertionError("Wrong representation shape.")

  # Codes and labels are split with the same train_percentage so the train/test
  # partitions line up; this assumes utils.split_train_test is deterministic
  # (not visible here -- confirm).
  mus_train, mus_test = utils.split_train_test(mus, train_percentage)
  ys_train, ys_test = utils.split_train_test(labels, train_percentage)
  return unified_scores(mus_train, ys_train, mus_test, ys_test, matrix_fns)
def compute_sap_on_fixed_data(observations, labels, representation_function,
                              train_percentage=gin.REQUIRED,
                              continuous_factors=gin.REQUIRED,
                              batch_size=100):
  """Computes the SAP score on the fixed set of observations and labels.

  Args:
    observations: Observations on which to compute the score. Observations have
      shape (num_observations, 64, 64, num_channels).
    labels: Observed factors of variations. The second axis must index the
      observations, i.e. labels.shape[1] == num_observations.
    representation_function: Function that takes observations as input and
      outputs a dim_representation sized representation for each observation.
    train_percentage: Percentage of observations used for training.
    continuous_factors: Whether factors should be considered continuous or
      discrete.
    batch_size: Batch size used to compute the representation.

  Returns:
    SAP computed on the provided observations and labels, as returned by
    _compute_sap().

  Raises:
    AssertionError: If labels, or the representation computed from the
      observations, do not have exactly one column per observation.
  """
  num_observations = observations.shape[0]
  # Explicit checks rather than `assert` statements so the validation still
  # runs under `python -O`. AssertionError (and the original messages) are kept
  # so callers that relied on the old asserts see the same exception.
  # Labels are checked before computing the representation so that bad input
  # fails before the potentially expensive forward pass.
  if labels.shape[1] != num_observations:
    raise AssertionError("Wrong labels shape.")
  mus = utils.obtain_representation(observations, representation_function,
                                    batch_size)
  if mus.shape[1] != num_observations:
    raise AssertionError("Wrong representation shape.")

  # Codes and labels are split with the same train_percentage so the train/test
  # partitions line up; this assumes utils.split_train_test is deterministic
  # (not visible here -- confirm).
  mus_train, mus_test = utils.split_train_test(mus, train_percentage)
  ys_train, ys_test = utils.split_train_test(labels, train_percentage)
  return _compute_sap(mus_train, ys_train, mus_test, ys_test,
                      continuous_factors)
# Example #3
 def test_split_train_test(self):
     """Checks that a 0.9 split of 100 columns yields 90 train / 10 test."""
     num_rows = 10
     data = np.zeros([num_rows, 100])
     train_part, test_part = utils.split_train_test(data, 0.9)
     # assert_allclose also fails on a shape mismatch, so these comparisons
     # pin the split sizes as well as the (all-zero) contents.
     np.testing.assert_allclose(train_part, np.zeros([num_rows, 90]))
     np.testing.assert_allclose(test_part, np.zeros([num_rows, 10]))