def testPartiallySpecifiedTestSet(self):
  """Constructing the model from an incomplete test set must fail.

  The test features are removed from an otherwise complete dataset, and
  `ProbitRegression` is expected to reject it with a `ValueError` whose
  message matches 'both specified'.
  """
  feature_count = 5
  test_point_count = 5
  incomplete = _test_dataset(feature_count, test_point_count)
  # Remove only the features. Whatever else `_test_dataset` returned
  # (presumably the test labels) stays, so the test set is half-specified.
  del incomplete['test_features']
  with self.assertRaisesRegex(ValueError, 'both specified'):
    probit_regression.ProbitRegression(**incomplete)
def testBasic(self, num_test_points):
  """Unconstrained samples must yield finite model outputs.

  Both `unnormalized_log_prob` and the model's sample transformations are
  checked. The expected shape of each transformation is passed to the
  validator.

  Args:
    num_test_points: Number of test points.
  """
  num_features = 5
  dataset = _test_dataset(num_features, num_test_points)
  model = probit_regression.ProbitRegression(**dataset)
  # Expected output shape of each named sample transformation.
  expected_shapes = {
      'identity': [num_features + 1],
      'test_nll': [],
      'per_example_test_nll': [num_test_points],
  }
  self.validate_log_prob_and_transforms(
      model, sample_transformation_shapes=expected_shapes)