コード例 #1
0
 def testPartiallySpecifiedTestSet(self):
   """Ensures the model rejects a test set that is only partly provided.

   `test_features` is removed from the generated dataset (presumably leaving
   `test_labels` in place — confirm against `_test_dataset`). Constructing
   `ProbitRegression` from it must raise a `ValueError` whose message
   matches 'both specified'.
   """
   n_features, n_test = 5, 5
   kwargs = _test_dataset(n_features, n_test)
   kwargs.pop('test_features')
   with self.assertRaisesRegex(ValueError, 'both specified'):
     probit_regression.ProbitRegression(**kwargs)
コード例 #2
0
  def testBasic(self, num_test_points):
    """Verifies the model yields finite values for unconstrained samples.

    Builds a `ProbitRegression` from a synthetic dataset and validates both
    `unnormalized_log_prob` and the shapes produced by each sample
    transformation.

    Args:
      num_test_points: Number of test points in the synthetic dataset.
    """
    feature_count = 5
    dataset = _test_dataset(feature_count, num_test_points)
    model = probit_regression.ProbitRegression(**dataset)
    expected_shapes = {
        # One entry more than the feature count (presumably a bias term —
        # confirm against ProbitRegression).
        'identity': [feature_count + 1],
        'test_nll': [],
        'per_example_test_nll': [num_test_points],
    }
    self.validate_log_prob_and_transforms(
        model, sample_transformation_shapes=expected_shapes)