예제 #1
0
 def testPartiallySpecifiedTestSet(self):
     """Rejects a test set that is only partially provided.

     Starts from the dataset returned by `_test_dataset` (assumed to carry a
     complete set of test fields) and removes only `test_student_ids`. The
     model constructor must then raise a `ValueError` whose message matches
     'all be specified'.
     """
     partial_dataset = _test_dataset(5)
     # Drop a single test-set field so the remaining ones are incomplete.
     del partial_dataset['test_student_ids']
     with self.assertRaisesRegex(ValueError, 'all be specified'):
         item_response_theory.ItemResponseTheory(**partial_dataset)
예제 #2
0
 def testCreateDataset(self, num_test_points):
     """Round-trips a sampled dataset through the model constructor.

     A model built from `_test_dataset` draws a synthetic dataset with its
     private `_sample_dataset` helper. A second model is then constructed
     from that sample and run through the standard log-prob/transform
     validation.

     Args:
       num_test_points: Number of test points.
     """
     # `_sample_dataset` is private, but no other test exercises it, so it is
     # covered here.
     if not tf.executing_eagerly():
         self.skipTest(
             'This is Eager only for now due to _sample_dataset being '
             'Eager-only.')

     source_model = item_response_theory.ItemResponseTheory(
         **_test_dataset(num_test_points))
     sampled_dataset = source_model._sample_dataset(tfp_test_util.test_seed())
     resampled_model = item_response_theory.ItemResponseTheory(
         **sampled_dataset)

     # The 20 students / 10 questions presumably mirror the sizes produced
     # by `_test_dataset`; TODO confirm if that fixture changes.
     expected_shapes = {
         'identity': {
             'mean_student_ability': [],
             'student_ability': [20],
             'question_difficulty': [10],
         },
         'test_nll': [],
         'per_example_test_nll': [num_test_points],
     }
     self.validate_log_prob_and_transforms(
         resampled_model, sample_transformation_shapes=expected_shapes)
예제 #3
0
    def testBasic(self, num_test_points):
        """Validates a model built directly from `_test_dataset`.

        The intent is to obtain finite values from unconstrained samples. This
        covers both `unnormalized_log_prob` and every sample transformation,
        whose expected output shapes are listed below.

        Args:
          num_test_points: Number of test points.
        """
        # Per-transformation output shapes; the per-example NLL has one entry
        # per test point.
        expected_shapes = {
            'identity': {
                'mean_student_ability': [],
                'student_ability': [20],
                'question_difficulty': [10],
            },
            'test_nll': [],
            'per_example_test_nll': [num_test_points],
        }
        model = item_response_theory.ItemResponseTheory(
            **_test_dataset(num_test_points))
        self.validate_log_prob_and_transforms(
            model, sample_transformation_shapes=expected_shapes)