Пример #1
0
    def test_matmul_expression(self):
        """Test ``__matmul__``, the method behind the ``@`` operator.

        Checks the curvature, sign and shape of products built from
        constants, parameters and variables, and that scalar or
        dimension-mismatched operands raise.

        Relies on the fixtures ``self.x``, ``self.A``, ``self.B`` and
        ``self.C``, which are created outside this method.
        """
        # Constant times variable.
        c = Constant([[2], [2]])
        exp = c.__matmul__(self.x)
        self.assertEqual(exp.curvature, s.AFFINE)
        self.assertEqual(exp.sign, s.UNKNOWN)
        # NOTE(review): the name() check below is disabled; confirm the
        # expected name format before re-enabling it.
        # self.assertEqual(exp.name(), c.name() + " .__matmul__( " + self.x.name())
        self.assertEqual(exp.shape, (1, ))

        # A scalar operand is rejected with a message pointing at '*'.
        with self.assertRaises(Exception) as cm:
            self.x.__matmul__(2)
        self.assertEqual(str(cm.exception),
                         "Scalar operands are not allowed, use '*' instead")

        # Incompatible dimensions (ndarray operand of length 3).
        with self.assertRaises(ValueError):
            self.x.__matmul__(np.array([2, 2, 3]))

        # Incompatible dimensions (constant operand with self.C).
        with self.assertRaises(Exception):
            Constant([[2, 1], [2, 2]]).__matmul__(self.C)

        # Affine times affine is okay; warnings raised while building the
        # product are silenced so they do not pollute the test output.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            q = self.A.__matmul__(self.B)
            self.assertTrue(q.is_quadratic())

        # NOTE(review): the check below is disabled; confirm whether the
        # product is still expected to raise before re-enabling it.
        # # Nonaffine times nonconstant raises error
        # with warnings.catch_warnings():
        #     warnings.simplefilter("ignore")
        #     with self.assertRaises(Exception) as cm:
        #         (self.A.__matmul__(self.B).__matmul__(self.A))
        #     self.assertEqual(str(cm.exception), "Cannot multiply UNKNOWN and AFFINE.")

        # Constant expressions
        T = Constant([[1, 2, 3], [3, 5, 5]])
        exp = (T + T).__matmul__(self.B)
        self.assertEqual(exp.curvature, s.AFFINE)
        self.assertEqual(exp.shape, (3, 2))

        # Expression that would break sign multiplication without promotion.
        c = Constant([[2], [2], [-2]])
        exp = [[1], [2]] + c.__matmul__(self.C)
        self.assertEqual(exp.sign, s.UNKNOWN)

        # Testing shape: two length-1 vectors multiply to a scalar.
        a = Parameter((1, ))
        x = Variable(shape=(1, ))
        expr = a.__matmul__(x)
        self.assertEqual(expr.shape, ())

        # Testing shape: (4, 4) times (4, 1) gives (4, 1).
        A = Parameter((4, 4))
        z = Variable((4, 1))
        expr = A.__matmul__(z)
        self.assertEqual(expr.shape, (4, 1))

        # A (1, 1) parameter keeps its shape under transpose and matches a
        # (1, 1) variable. unittest assertions are used instead of a bare
        # ``assert`` so the check survives ``python -O`` and reports the
        # differing values on failure.
        v = Variable((1, 1))
        col_scalar = Parameter((1, 1))
        self.assertEqual(v.shape, col_scalar.shape)
        self.assertEqual(col_scalar.shape, col_scalar.T.shape)