Beispiel #1
0
    def test_mul(self):
        """Check the result shape of multiplying two 2D shapes.

        Exercises a compatible pair, an incompatible pair (including the
        exact error message), and a (1, 1) operand on either side.
        """
        product = shape.mul_shapes((5, 9), (9, 2))
        self.assertEqual(product, (5, 2))

        # Inner dimensions differ (3 vs. 9): the call is expected to raise.
        with self.assertRaises(Exception) as ctx:
            shape.mul_shapes((5, 3), (9, 2))
        message = str(ctx.exception)
        self.assertEqual(message, "Incompatible dimensions (5, 3) (9, 2)")

        # Promotion: with a (1, 1) operand the other shape is expected back.
        matrix, unit = (3, 4), (1, 1)
        for lhs, rhs in ((matrix, unit), (unit, matrix)):
            self.assertEqual(shape.mul_shapes(lhs, rhs), matrix)
Beispiel #2
0
    def test_mul(self):
        """Verify mul_shapes on 2D inputs: valid product, error, promotion."""
        lhs, rhs = (5, 9), (9, 2)
        self.assertEqual(shape.mul_shapes(lhs, rhs), (5, 2))

        # Mismatched inner dimensions must raise with this exact message.
        expected_msg = "Incompatible dimensions (5, 3) (9, 2)"
        with self.assertRaises(Exception) as raised:
            shape.mul_shapes((5, 3), (9, 2))
        self.assertEqual(str(raised.exception), expected_msg)

        # Promotion
        full = (3, 4)
        one_by_one = (1, 1)
        right_promoted = shape.mul_shapes(full, one_by_one)
        self.assertEqual(right_promoted, full)
        left_promoted = shape.mul_shapes(one_by_one, full)
        self.assertEqual(left_promoted, full)
Beispiel #3
0
 def test_mul_scalars(self) -> None:
     """Check that an empty-tuple shape on either side raises ValueError.

     The empty tuple is tried as the left operand, the right operand,
     and both operands, in that order.
     """
     scalar = ()
     matrix = (5, 9)
     operand_pairs = (
         (scalar, matrix),
         (matrix, scalar),
         (scalar, scalar),
     )
     for lhs, rhs in operand_pairs:
         with self.assertRaises(ValueError):
             shape.mul_shapes(lhs, rhs)
Beispiel #4
0
    def test_mul_2d(self) -> None:
        """Exercise mul_shapes on shapes that have two or more dimensions.

        Valid pairs must yield the expected product shape; invalid pairs
        must raise with the exact "Incompatible dimensions" message.
        """
        valid_cases = (
            ((5, 9), (9, 2), (5, 2)),
            ((3, 5, 9), (3, 9, 2), (3, 5, 2)),
        )
        for lhs, rhs, expected in valid_cases:
            self.assertEqual(shape.mul_shapes(lhs, rhs), expected)

        # First pair: inner dimensions differ (3 vs. 9).
        # Second pair: leading dimensions differ (3 vs. 4).
        invalid_cases = (
            (
                (5, 3),
                (9, 2),
                "Incompatible dimensions (5, 3) (9, 2)",
            ),
            (
                (3, 5, 9),
                (4, 9, 2),
                "Incompatible dimensions (3, 5, 9) (4, 9, 2)",
            ),
        )
        for lhs, rhs, message in invalid_cases:
            with self.assertRaises(Exception) as ctx:
                shape.mul_shapes(lhs, rhs)
            self.assertEqual(str(ctx.exception), message)