def test_basics(self):
    """Check MCGSM hyperparameters, parameter shapes and result shapes.

    Builds a small model and draws outputs from it with ``sample``. It then
    evaluates ``loglikelihood`` and ``posterior`` on those outputs and draws
    from ``sample_posterior``. Finally it verifies that the constructor
    arguments round-trip as attributes and that every parameter and returned
    array has the expected dimensions.
    """
    dim_in = 10
    dim_out = 3
    num_components = 7
    num_scales = 5
    num_features = 50
    num_samples = 100

    # create model
    mcgsm = MCGSM(dim_in, dim_out, num_components, num_scales, num_features)

    # generate output (named ``inputs`` to avoid shadowing the ``input`` builtin)
    inputs = randn(dim_in, num_samples)
    output = mcgsm.sample(inputs)
    loglik = mcgsm.loglikelihood(inputs, output)
    post = mcgsm.posterior(inputs, output)
    samples = mcgsm.sample_posterior(inputs, output)

    # check hyperparameters
    self.assertEqual(mcgsm.dim_in, dim_in)
    self.assertEqual(mcgsm.dim_out, dim_out)
    self.assertEqual(mcgsm.num_components, num_components)
    self.assertEqual(mcgsm.num_scales, num_scales)
    self.assertEqual(mcgsm.num_features, num_features)

    # check parameters
    self.assertEqual(mcgsm.priors.shape[0], num_components)
    self.assertEqual(mcgsm.priors.shape[1], num_scales)
    self.assertEqual(mcgsm.scales.shape[0], num_components)
    self.assertEqual(mcgsm.scales.shape[1], num_scales)
    self.assertEqual(mcgsm.weights.shape[0], num_components)
    self.assertEqual(mcgsm.weights.shape[1], num_features)
    self.assertEqual(mcgsm.features.shape[0], dim_in)
    self.assertEqual(mcgsm.features.shape[1], num_features)
    self.assertEqual(len(mcgsm.cholesky_factors), num_components)
    self.assertEqual(len(mcgsm.predictors), num_components)
    self.assertEqual(mcgsm.cholesky_factors[0].shape[0], dim_out)
    self.assertEqual(mcgsm.cholesky_factors[0].shape[1], dim_out)
    self.assertEqual(mcgsm.predictors[0].shape[0], dim_out)
    self.assertEqual(mcgsm.predictors[0].shape[1], dim_in)
    self.assertEqual(mcgsm.linear_features.shape[0], num_components)
    self.assertEqual(mcgsm.linear_features.shape[1], dim_in)
    self.assertEqual(mcgsm.means.shape[0], dim_out)
    self.assertEqual(mcgsm.means.shape[1], num_components)

    # check dimensionality of output
    self.assertEqual(output.shape[0], dim_out)
    self.assertEqual(output.shape[1], num_samples)
    self.assertEqual(loglik.shape[0], 1)
    self.assertEqual(loglik.shape[1], num_samples)
    self.assertEqual(post.shape[0], num_components)
    self.assertEqual(post.shape[1], num_samples)

    # posterior draws form a 1 x num_samples array whose values must lie in
    # [0, num_components); check the shape before reducing over it
    self.assertEqual(samples.shape[0], 1)
    self.assertEqual(samples.shape[1], num_samples)
    # Use the array's own reductions rather than max()/min(). Unless the
    # file shadows them with NumPy's versions, the builtins iterate over the
    # first axis of this 2-D array and return a whole row. Comparing that row
    # gives an array whose truth value is ambiguous, so the assertion raises
    # ValueError instead of testing the bound.
    self.assertLess(samples.max(), mcgsm.num_components)
    self.assertGreaterEqual(samples.min(), 0)