def sum_canon(expr, args):
    """Canonicalize a ``sum`` atom by expanding it into explicit additions.

    When ``expr.axis`` is None, every entry of the operand is summed into
    a single expression. Otherwise the operand is reduced one row at a time
    (after transposing it when ``expr.axis == 0``). The per-row results are
    joined with ``hstack``.

    Args:
        expr: The original sum expression. Only ``expr.axis`` and
            ``expr.shape`` are read.
        args: Already-canonicalized arguments; only ``args[0]`` is used.

    Returns:
        A pair ``(canonical_expr, [])``. ``canonical_expr`` is reshaped to
        ``expr.shape``, and no extra constraints are produced.
    """
    operand = args[0]

    def _reduce(term):
        # Expand ``term`` into an explicit sum, then hand it to the add
        # canonicalizer (presumably log-sum-exp in the DGP->DCP setting;
        # confirm against add_canon). Its second return value is discarded.
        expanded = explicit_sum(term)
        reduced, _extra = add_canon(expanded, expanded.args)
        return reduced

    if expr.axis is None:
        return reshape(_reduce(operand), expr.shape), []

    # Transpose for axis 0 so the loop below always reduces along rows.
    # NOTE(review): only axis == 0 is transposed; a negative axis such as -2
    # is not. Confirm axes are normalized before reaching this point.
    matrix = operand.T if expr.axis == 0 else operand
    per_row = [_reduce(matrix[row]) for row in range(matrix.shape[0])]
    return reshape(hstack(per_row), expr.shape), []
def quad_over_lin_canon(expr, args):
    """Canonicalize a ``quad_over_lin`` atom.

    Builds an explicit sum over ``2 * args[0]`` and canonicalizes it with
    ``add_canon``. It then subtracts ``args[1]``. Presumably the arguments
    are log-space variables, so this corresponds to
    ``log(sum(x**2) / y)``; confirm against the reduction's conventions.

    Args:
        expr: The original expression (unused here).
        args: Canonicalized arguments ``[x, y]``.

    Returns:
        A pair ``(canonical_expr, [])`` with no extra constraints.
    """
    # Doubling corresponds to squaring when args are logs (assumed; see above).
    doubled = 2 * args[0]
    terms = explicit_sum(doubled)
    log_numerator, _extra = add_canon(terms, terms.args)
    return log_numerator - args[1], []
def trace_canon(expr, args):
    """Canonicalize a ``trace`` atom as a sum over the diagonal.

    Extracts the diagonal of ``args[0]`` with ``diag`` and expands it into
    an explicit sum. It then delegates to ``add_canon``.

    Args:
        expr: The original expression (unused here).
        args: Canonicalized arguments; only ``args[0]`` is used.

    Returns:
        Whatever ``add_canon`` returns for the diagonal sum, passed through
        unchanged.
    """
    diagonal = diag(args[0])
    total = explicit_sum(diagonal)
    return add_canon(total, total.args)