Example #1
0
 def testPartiallySpecifiedTestSet(self):
     """Verifies the model rejects a test set missing its features.

     Only `test_features` is removed from an otherwise complete dataset; the
     constructor is expected to raise a ValueError whose message mentions
     'both specified'.
     """
     n_features, n_test = 5, 5
     dataset_kwargs = _test_dataset(n_features, n_test)
     # Drop only the features so the test split is half-specified.
     dataset_kwargs.pop('test_features')
     with self.assertRaisesRegex(ValueError, 'both specified'):
         sparse_logistic_regression.SparseLogisticRegression(**dataset_kwargs)
Example #2
0
    def testBasic(self, num_test_points):
        """Validates log-prob and sample transformations on unconstrained draws.

        Ensures that `unnormalized_log_prob` and every sample transformation
        produce finite values, and that each transformation output has the
        expected shape.

        Args:
          num_test_points: Number of test points.
        """
        num_features = 5
        dataset = _test_dataset(num_features, num_test_points)
        model = sparse_logistic_regression.SparseLogisticRegression(**dataset)

        # The per-feature parameters carry one entry beyond `num_features`
        # (presumably an intercept/bias term -- TODO confirm against the model).
        # Separate list objects are used per key to avoid accidental aliasing.
        expected_shapes = {
            'identity': {
                'global_scale': [],
                'local_scales': [num_features + 1],
                'unscaled_weights': [num_features + 1],
            },
            'test_nll': [],
            'per_example_test_nll': [num_test_points],
        }
        self.validate_log_prob_and_transforms(
            model, sample_transformation_shapes=expected_shapes)