Пример #1
0
 def testBadSampleTransformation(self):
     """Flattening a mixed-dtype sample transformation raises TypeError.

     The same model is first wrapped without flattening, which is expected
     to succeed; only the `flatten_sample_transformations=True` path must
     reject the transformation named 'bad'.
     """
     # Leaving the transformations structured must not trigger the check.
     targets.VectorModel(TestBadSampleTransformationModel())

     expected_message = (
         'Sample transformation \'bad\' must have only one Tensor dtype')
     with self.assertRaisesRegex(TypeError, expected_message):
         targets.VectorModel(
             TestBadSampleTransformationModel(),
             flatten_sample_transformations=True,
         )
Пример #2
0
    def testFlattenSampleTransformations(self, batch_shape):
        """Flattened 'first_moment' outputs and ground truth are vector-shaped.

        Args:
          batch_shape: Leading batch dimensions; concatenated with
            `model.event_shape` via `+` to build the input shape.
        """
        model = targets.VectorModel(
            TestStructuredModel(), flatten_sample_transformations=True)
        event_shape = model.event_shape
        batched_event_shape = batch_shape + event_shape

        x_init = tf.zeros(batched_event_shape, model.dtype)
        transform = model.sample_transformations['first_moment']

        # Applying the transformation keeps the batch dimensions in front of
        # the event shape.
        first_moment = transform(x_init)
        self.assertAllEqual(batched_event_shape, first_moment.shape)

        # Every ground-truth statistic must carry the model's event shape.
        ground_truth_attrs = (
            'ground_truth_mean',
            'ground_truth_standard_deviation',
            'ground_truth_mean_standard_error',
            'ground_truth_standard_deviation_standard_error',
        )
        for attr_name in ground_truth_attrs:
            ground_truth = getattr(transform, attr_name)
            self.assertAllEqual(event_shape, ground_truth.shape)
Пример #3
0
    def testBasic(self, model_class, vec_event_size, batch_shape):
        """Checks that a `VectorModel` agrees with the model it wraps.

        Compares name, string form, dtype, event shape, the structured<->vector
        event conversions, the unnormalized log-prob and the 'first_moment' /
        'second_moment' sample transformations.

        Args:
          model_class: Zero-argument callable returning the base model.
          vec_event_size: Expected size of the single vector event dimension.
          batch_shape: Leading batch dimensions. It is concatenated with lists
            via `+`, so presumably a Python list -- confirm against the test
            parameters.
        """
        base_model = model_class()
        vec_model = targets.VectorModel(base_model)

        # Only a single random scalar is used to fill every tensor: varying the
        # elements independently would require knowing how the
        # reshaping/flattening is laid out, which this test does not cover.
        fill_value = tf.constant(np.random.randn(), tf.float32)

        self.assertEqual('vector_' + base_model.name, vec_model.name)
        self.assertEqual(str(base_model), str(vec_model))
        self.assertEqual(tf.float32, vec_model.dtype)
        self.assertEqual([vec_event_size], list(vec_model.event_shape))

        def _filled_unconstrained(event_shape, dtype, bijector):
            # One structured part in unconstrained space, filled uniformly.
            part_shape = batch_shape + list(
                bijector.inverse_event_shape(event_shape))
            return tf.fill(part_shape, tf.cast(fill_value, dtype))

        def _constrain(bijector, unconstrained):
            # Forward pass: unconstrained -> constrained.
            return bijector(unconstrained)

        # z lives in unconstrained space, x in constrained space.
        structured_z = tf.nest.map_structure(
            _filled_unconstrained,
            base_model.event_shape,
            base_model.dtype,
            base_model.default_event_space_bijector)
        structured_x = tf.nest.map_structure(
            _constrain, base_model.default_event_space_bijector, structured_z)

        # Same construction for the vector model.
        vec_unconstrained_shape = (
            vec_model.default_event_space_bijector.inverse_event_shape(
                vec_model.event_shape))
        vec_z = tf.fill(
            batch_shape + list(vec_unconstrained_shape),
            tf.cast(fill_value, vec_model.dtype))
        vec_x = vec_model.default_event_space_bijector(vec_z)

        # The conversion utilities must round-trip between representations.
        self.assertAllEqualNested(
            structured_x, vec_model.vector_event_to_structured(vec_x))
        self.assertAllEqual(
            vec_x, vec_model.structured_event_to_vector(structured_x))

        # Both models must agree on the log-prob and the moment transforms.
        self.assertAllEqual(
            base_model.unnormalized_log_prob(structured_x),
            vec_model.unnormalized_log_prob(vec_x))
        for transform_name in ('first_moment', 'second_moment'):
            self.assertAllEqualNested(
                base_model.sample_transformations[transform_name](structured_x),
                vec_model.sample_transformations[transform_name](vec_x))

        # Values built directly in the vector event space must work as well.
        zeros_x = tf.zeros(batch_shape + vec_model.event_shape, vec_model.dtype)
        self.assertAllEqual(
            batch_shape, vec_model.unnormalized_log_prob(zeros_x).shape)
Пример #4
0
    def testExample(self):
        """Worked example: wrapping `SyntheticItemResponseTheory`.

        The base model exposes a dict of float32 parts; the vector model is
        expected to expose a single float32 dtype and a one-dimensional event.
        """
        base_model = targets.SyntheticItemResponseTheory()
        vec_model = targets.VectorModel(base_model)

        expected_dtypes = {
            'mean_student_ability': tf.float32,
            'student_ability': tf.float32,
            'question_difficulty': tf.float32,
        }
        self.assertAllAssertsNested(
            self.assertEqual, expected_dtypes, base_model.dtype)
        self.assertEqual(tf.float32, vec_model.dtype)

        expected_event_shapes = {
            'mean_student_ability': [],
            'student_ability': [400],
            'question_difficulty': [100],
        }
        # `shallow` is given the dtype structure, as in the original call, so
        # the shape lists are compared as leaves.
        self.assertAllAssertsNested(
            self.assertEqual,
            expected_event_shapes,
            base_model.event_shape,
            shallow=base_model.dtype)

        # 501 = 1 (scalar) + 400 + 100 elements across the structured parts.
        self.assertEqual([501], list(vec_model.event_shape))
Пример #5
0
 def testBadModel(self):
     """Wrapping `TestBadModel` in a `VectorModel` must raise TypeError."""
     expected_message = 'Model must have only one Tensor dtype'
     with self.assertRaisesRegex(TypeError, expected_message):
         targets.VectorModel(TestBadModel())