def testTranspose_InferredMatchesActualShapeShape(self):
        """Compare the inferred shape of `tf.transpose(x)` with the runtime shape.

        The statements in the `with` block are executed, and their text
        (`code_saver.code`, prefixed with `_PREAMBLE`) is passed to
        `utils.pytype_infer_shapes`. The test passes when the inferred value
        for `y` equals `y.shape`.

        NOTE(review): the `with` block is presumably captured as source text by
        `SaveCodeAsString`, so it is deliberately left free of comments -
        confirm against `utils`.
        """
        with utils.SaveCodeAsString() as code_saver:
            x: Tensor2[A1, A2] = tf.zeros((1, 2))
            y = tf.transpose(x)

        snippet = _PREAMBLE + code_saver.code
        static_shapes = utils.pytype_infer_shapes(snippet)

        self.assertEqual(static_shapes.y, y.shape)
    def testSum_InferredMatchesActualShape(self):
        """Compare inferred and runtime shapes of `tf.reduce_sum` over each axis.

        The `with` block reduces a `(1, 2)` tensor along axis 0 (`y1`) and
        axis 1 (`y2`). Its text (`code_saver.code`, prefixed with `_PREAMBLE`)
        is passed to `utils.pytype_infer_shapes`, and each inferred result is
        compared with the `.shape` of the tensor computed at runtime.

        NOTE(review): the `with` block is presumably captured as source text by
        `SaveCodeAsString`, so it is deliberately left free of comments -
        confirm against `utils`.
        """
        with utils.SaveCodeAsString() as code_saver:
            x: Tensor2[A1, A2] = tf.zeros((1, 2))
            y1 = tf.reduce_sum(x, axis=0)
            y2 = tf.reduce_sum(x, axis=1)

        static_shapes = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

        # Checked in definition order, so the first mismatch reported is the
        # same one a sequence of individual assertions would report.
        runtime_results = (("y1", y1), ("y2", y2))
        for name, result in runtime_results:
            self.assertEqual(getattr(static_shapes, name), result.shape)
    def testBinaryOpWithScalar_InferredMatchesActualShape(self):
        """Compare inferred and runtime shapes for tensor-with-scalar arithmetic.

        The `with` block applies `+`, `-`, `/` and `*` between a `(1, 2)`
        tensor and the float `1.0`. Its text (`code_saver.code`, prefixed with
        `_PREAMBLE`) is passed to `utils.pytype_infer_shapes`, and the runtime
        `.shape` of each result is compared with the inferred value of the
        same name.

        NOTE(review): the `with` block is presumably captured as source text by
        `SaveCodeAsString`, so it is deliberately left free of comments -
        confirm against `utils`.
        """
        with utils.SaveCodeAsString() as code_saver:
            x: Tensor2[A1, A2] = tf.zeros((1, 2))
            y1 = x + 1.0
            y2 = x - 1.0
            y3 = x / 1.0
            y4 = x * 1.0

        static_shapes = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

        # Runtime shape is the first argument and the inferred one the second;
        # results are checked in definition order.
        runtime_results = (("y1", y1), ("y2", y2), ("y3", y3), ("y4", y4))
        for name, result in runtime_results:
            self.assertEqual(result.shape, getattr(static_shapes, name))
Esempio n. 4
0
    def testSum_InferredMatchesActualShape(self):
        """Compare inferred and runtime shapes of `jnp.sum` for several `axis` forms.

        The `with` block sums a `(1, 2)` array with `axis=0` (`y1`), `axis=1`
        (`y2`), `axis=(0, 1)` (`y3`) and no `axis` argument (`y4`). Its text
        (`code_saver.code`, prefixed with `_PREAMBLE`) is passed to
        `utils.pytype_infer_shapes`, and each inferred result is compared with
        the `.shape` of the array computed at runtime.

        NOTE(review): the `with` block is presumably captured as source text by
        `SaveCodeAsString`, so it is deliberately left free of comments -
        confirm against `utils`.
        """
        with utils.SaveCodeAsString() as code_saver:
            x: Array2[A1, A2] = jnp.zeros((1, 2))
            y1 = jnp.sum(x, axis=0)
            y2 = jnp.sum(x, axis=1)
            y3 = jnp.sum(x, axis=(0, 1))
            y4 = jnp.sum(x)

        static_shapes = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

        # Checked in definition order, so the first mismatch reported is the
        # same one a sequence of individual assertions would report.
        runtime_results = (("y1", y1), ("y2", y2), ("y3", y3), ("y4", y4))
        for name, result in runtime_results:
            self.assertEqual(getattr(static_shapes, name), result.shape)
    def testBinaryOpWithSameShape_InferredMatchesActualShape(self):
        """Compare inferred and runtime shapes for arithmetic on equal-shape tensors.

        The `with` block applies `+`, `-`, `/` and `*` between two `(1, 2)`
        tensors `a` and `b`. Its text (`code_saver.code`, prefixed with
        `_PREAMBLE`) is passed to `utils.pytype_infer_shapes`, and the runtime
        `.shape` of each result is compared with the inferred value of the
        same name.

        NOTE(review): the `with` block is presumably captured as source text by
        `SaveCodeAsString`, so it is deliberately left free of comments -
        confirm against `utils`.
        """
        with utils.SaveCodeAsString() as code_saver:
            a: Tensor2[A1, A2] = tf.zeros((1, 2))
            b: Tensor2[A1, A2] = tf.zeros((1, 2))
            y1 = a + b
            y2 = a - b
            y3 = a / b
            y4 = a * b

        static_shapes = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

        # Runtime shape is the first argument and the inferred one the second;
        # results are checked in definition order.
        runtime_results = (("y1", y1), ("y2", y2), ("y3", y3), ("y4", y4))
        for name, result in runtime_results:
            self.assertEqual(result.shape, getattr(static_shapes, name))