def compute_unified_score_on_fixed_data(
        observations, labels, representation_function,
        train_percentage=gin.REQUIRED, matrix_fns=gin.REQUIRED,
        batch_size=100):
    """Evaluates the unified scores for a fixed set of observations and labels.

    The observations are encoded through `utils.obtain_representation`, the
    resulting codes and the labels are each divided into a train and a test
    part with `utils.split_train_test`, and the four parts are handed to
    `unified_scores` together with `matrix_fns`.

    Args:
      observations: Observations on which to compute the score. Observations
        have shape (num_observations, 64, 64, num_channels).
      labels: Observed factors of variations. Its second axis must have
        `observations.shape[0]` entries.
      representation_function: Function that takes observations as input and
        outputs a dim_representation sized representation for each
        observation.
      train_percentage: Percentage of observations used for training.
      matrix_fns: List of functions to relate factors of variations and codes.
      batch_size: Batch size used to compute the representation.

    Returns:
      Unified scores.
    """
    codes = utils.obtain_representation(
        observations, representation_function, batch_size)

    # Labels and codes must both hold one column per observation.
    assert labels.shape[1] == observations.shape[0], "Wrong labels shape."
    assert codes.shape[1] == observations.shape[0], "Wrong representation shape."

    # The same percentage is used for both arrays.
    train_codes, test_codes = utils.split_train_test(codes, train_percentage)
    train_factors, test_factors = utils.split_train_test(
        labels, train_percentage)

    scores = unified_scores(
        train_codes, train_factors, test_codes, test_factors, matrix_fns)
    return scores
def compute_sap_on_fixed_data(
        observations, labels, representation_function,
        train_percentage=gin.REQUIRED, continuous_factors=gin.REQUIRED,
        batch_size=100):
    """Evaluates the SAP score for a fixed set of observations and labels.

    The observations are encoded through `utils.obtain_representation`, the
    resulting codes and the labels are each divided into a train and a test
    part with `utils.split_train_test`, and the four parts are handed to
    `_compute_sap` together with `continuous_factors`.

    Args:
      observations: Observations on which to compute the score. Observations
        have shape (num_observations, 64, 64, num_channels).
      labels: Observed factors of variations. Its second axis must have
        `observations.shape[0]` entries.
      representation_function: Function that takes observations as input and
        outputs a dim_representation sized representation for each
        observation.
      train_percentage: Percentage of observations used for training.
      continuous_factors: Whether factors should be considered continuous or
        discrete.
      batch_size: Batch size used to compute the representation.

    Returns:
      SAP computed on the provided observations and labels.
    """
    codes = utils.obtain_representation(
        observations, representation_function, batch_size)

    # Labels and codes must both hold one column per observation.
    assert labels.shape[1] == observations.shape[0], "Wrong labels shape."
    assert codes.shape[1] == observations.shape[0], "Wrong representation shape."

    # The same percentage is used for both arrays.
    train_codes, test_codes = utils.split_train_test(codes, train_percentage)
    train_factors, test_factors = utils.split_train_test(
        labels, train_percentage)

    sap_result = _compute_sap(
        train_codes, train_factors, test_codes, test_factors,
        continuous_factors)
    return sap_result
def test_split_train_test(self):
    """Checks the two parts returned by `utils.split_train_test`.

    A (10, 100) array of zeros is split with a ratio of 0.9 and the parts are
    compared against zero arrays of shape (10, 90) and (10, 10). Since all
    values are zero, the comparison mainly exercises the resulting shapes.
    """
    data = np.zeros((10, 100))
    train_part, test_part = utils.split_train_test(data, 0.9)

    comparisons = (
        (train_part, np.zeros((10, 90))),
        (test_part, np.zeros((10, 10))),
    )
    for actual, expected in comparisons:
        np.testing.assert_allclose(actual, expected)