예제 #1
0
def test_mul_dispatcher():
    """Exercise the ``mul`` dispatcher with a custom handler class.

    A ``NewBase`` expression advertises ``NewMul`` as its multiplication
    handler, and ``NewMul`` is registered for the ``(Mul, NewMul)`` pair.
    Products of ordinary operands must still come out as ``Mul``, while a
    product involving a ``NewBase`` operand must come out as ``NewMul``.
    """
    class NewBase(Expr):
        @property
        def _mul_handler(self):
            # Looked up when the property is read, so referring to NewMul
            # (defined just below) is fine.
            return NewMul

    class NewMul(NewBase, Mul):
        pass

    mul.register_handlerclass((Mul, NewMul), NewMul)

    sym, custom = Symbol('a'), NewBase()

    # Only ordinary operands: dispatch falls back to plain Mul.
    assert mul(1, 2) == Mul(1, 2)
    assert mul(sym, sym) == Mul(sym, sym)

    # A NewBase operand is present: the registered NewMul handler wins.
    product = mul(sym, custom, sym)
    assert product == NewMul(sym**2, custom)
예제 #2
0
                    Transpose(i).doit() if i.is_Matrix else i
                    for i in reversed(left_args)
                ])
            else:
                left_rev = Identity(self.shape[0])

            d = self.args[ind]._eval_derivative_matrix_lines(x)
            for i in d:
                i.append_first(left_rev)
                i.append_second(right_mat)
                lines.append(i)

        return lines


# Register MatMul as the handler class ``mul`` uses for the (Mul, MatMul)
# operand pair.
mul.register_handlerclass(
    (Mul, MatMul),
    MatMul,
)


def validate(*matrices):
    """Check that consecutive args of MatMul have compatible shapes.

    Parameters
    ==========

    *matrices
        Matrix-like objects exposing ``rows`` and ``cols``, given in the
        order in which they would be multiplied.

    Raises
    ======

    ShapeError
        If some adjacent pair ``A, B`` has ``A.cols != B.rows``. The first
        such pair (left to right) is reported.
    """
    # Pair each matrix with its right-hand neighbour. Zero or one argument
    # yields no pairs, so nothing is checked in that case.
    for A, B in zip(matrices, matrices[1:]):
        if A.cols != B.rows:
            raise ShapeError("Matrices %s and %s are not aligned" % (A, B))


# Rules


def newmul(*args):
    if args[0] == 1: