def testTranspose_InferredMatchesActualShapeShape(self):
    """Checks the inferred shape of `tf.transpose` against its runtime shape.

    Transposes a (1, 2) zeros tensor annotated as `Tensor2[A1, A2]` and
    asserts that the shape `utils.pytype_infer_shapes` reports for `y` equals
    the runtime `y.shape`.
    """
    # NOTE(review): the method name repeats "Shape" ("...ActualShapeShape");
    # left unchanged because renaming alters the test id used by runners and
    # test filters.
    # NOTE(review): `code_saver.code` is presumably the source text of the
    # `with` body, so the statements inside it are kept verbatim (and no
    # comments are placed there) - confirm against utils.SaveCodeAsString.
    with utils.SaveCodeAsString() as code_saver:
        x: Tensor2[A1, A2] = tf.zeros((1, 2))
        y = tf.transpose(x)

    # Read only after the `with` block has exited, presumably when
    # SaveCodeAsString fills in `.code`.
    inferred = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

    self.assertEqual(inferred.y, y.shape)
def testSum_InferredMatchesActualShape(self):
    """Compares inferred and runtime shapes of `tf.reduce_sum` per axis.

    A (1, 2) zeros tensor annotated as `Tensor2[A1, A2]` is reduced along
    axis 0 and axis 1. For each result, the shape reported by
    `utils.pytype_infer_shapes` must equal the tensor's runtime shape.
    """
    # NOTE(review): keep the `with` body verbatim; `code_saver.code` is
    # presumably its source text, which is what gets shape-inferred.
    with utils.SaveCodeAsString() as code_saver:
        x: Tensor2[A1, A2] = tf.zeros((1, 2))
        y1 = tf.reduce_sum(x, axis=0)
        y2 = tf.reduce_sum(x, axis=1)

    inferred = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

    # Variable name in the inferred result -> tensor produced at runtime.
    # Checked in insertion order; the first mismatch fails the test.
    runtime_results = {'y1': y1, 'y2': y2}
    for var_name, tensor in runtime_results.items():
        self.assertEqual(getattr(inferred, var_name), tensor.shape)
def testBinaryOpWithScalar_InferredMatchesActualShape(self):
    """Checks inferred shapes of tensor-scalar arithmetic against runtime.

    Applies `+`, `-`, `/` and `*` between a (1, 2) zeros tensor annotated as
    `Tensor2[A1, A2]` and the float scalar 1.0. For each result, the shape
    reported by `utils.pytype_infer_shapes` must equal its runtime shape.
    """
    # NOTE(review): keep the `with` body verbatim; `code_saver.code` is
    # presumably its source text, which is what gets shape-inferred.
    with utils.SaveCodeAsString() as code_saver:
        x: Tensor2[A1, A2] = tf.zeros((1, 2))
        y1 = x + 1.0
        y2 = x - 1.0
        y3 = x / 1.0
        y4 = x * 1.0

    inferred = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

    # Arguments are (inferred, actual), the same order the other shape tests
    # use, so assertion failure messages read consistently.
    self.assertEqual(inferred.y1, y1.shape)
    self.assertEqual(inferred.y2, y2.shape)
    self.assertEqual(inferred.y3, y3.shape)
    self.assertEqual(inferred.y4, y4.shape)
def testSum_InferredMatchesActualShape(self):
    """Compares inferred and runtime shapes of `jnp.sum` reductions.

    A (1, 2) zeros array annotated as `Array2[A1, A2]` is summed along axis 0,
    along axis 1, along both axes as a tuple, and with no axis argument. For
    each result, the shape reported by `utils.pytype_infer_shapes` must equal
    the array's runtime shape.
    """
    # NOTE(review): keep the `with` body verbatim; `code_saver.code` is
    # presumably its source text, which is what gets shape-inferred.
    with utils.SaveCodeAsString() as code_saver:
        x: Array2[A1, A2] = jnp.zeros((1, 2))
        y1 = jnp.sum(x, axis=0)
        y2 = jnp.sum(x, axis=1)
        y3 = jnp.sum(x, axis=(0, 1))
        y4 = jnp.sum(x)

    inferred = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

    # (inferred attribute name, runtime array) pairs, checked in order; the
    # first mismatch fails the test.
    named_results = (
        ('y1', y1),
        ('y2', y2),
        ('y3', y3),
        ('y4', y4),
    )
    for attr, array in named_results:
        self.assertEqual(getattr(inferred, attr), array.shape)
def testBinaryOpWithSameShape_InferredMatchesActualShape(self):
    """Checks inferred shapes of same-shape tensor arithmetic against runtime.

    Applies `+`, `-`, `/` and `*` between two (1, 2) zeros tensors, both
    annotated as `Tensor2[A1, A2]`. For each result, the shape reported by
    `utils.pytype_infer_shapes` must equal its runtime shape.
    """
    # NOTE(review): keep the `with` body verbatim; `code_saver.code` is
    # presumably its source text, which is what gets shape-inferred.
    with utils.SaveCodeAsString() as code_saver:
        a: Tensor2[A1, A2] = tf.zeros((1, 2))
        b: Tensor2[A1, A2] = tf.zeros((1, 2))
        y1 = a + b
        y2 = a - b
        y3 = a / b
        y4 = a * b

    inferred = utils.pytype_infer_shapes(_PREAMBLE + code_saver.code)

    # Arguments are (inferred, actual), the same order the other shape tests
    # use, so assertion failure messages read consistently.
    self.assertEqual(inferred.y1, y1.shape)
    self.assertEqual(inferred.y2, y2.shape)
    self.assertEqual(inferred.y3, y3.shape)
    self.assertEqual(inferred.y4, y4.shape)