def compute_gradients(x, values=None):
    """Compute gradients using automatic differentiation.

    x:      The expression to compute gradients of.
    values: As returned by compute_values(x); contains all values to compute
            gradients with respect to. Must be a collections.OrderedDict that
            contains x and is assumed to be in topological order (arguments
            before the expressions that use them). If None, it is computed
            with compute_values(x).

    Returns a dict seeded with {x: 1.}. Every subexpression in values that is
    an expr.parameter, or that has such a subexpression among its arguments
    (transitively), gets an entry starting at 0.; the entries are then filled
    in by calling each subexpression's backward(values, gradients) in reverse
    order. Subexpressions that do not depend on any parameter get no entry.

    Raises TypeError if values is not a collections.OrderedDict, and
    ValueError if x has no entry in values.
    """
    if values is None:
        values = compute_values(x)
    if not isinstance(values, collections.OrderedDict):
        raise TypeError("values must be a collections.OrderedDict, as returned by compute_values()")
    if x not in values:
        raise ValueError("values does not contain the expression being differentiated")

    # Seed the output, then sweep forward to find every subexpression that is
    # a parameter or depends on one. Only these get a gradient slot for
    # backward() to accumulate into; everything else is skipped below.
    gradients = {x: 1.}
    for subx in values:
        if subx in gradients:
            continue
        if isinstance(subx, expr.parameter) or any(arg in gradients for arg in subx.args):
            gradients[subx] = 0.

    # Walking values in reverse visits each expression before its arguments
    # (assuming values is in topological order).
    for subx in reversed(values):
        if subx not in gradients:
            continue
        if logging.trace:
            sys.stdout.write("d<%s>/d<%s> = %s\n" % (x.serial, subx.serial, format_value(gradients[subx])))
        try:
            subx.backward(values, gradients)
        except BaseException:
            # Re-raise unchanged, after showing where the failing expression
            # was created.
            if logging.debug:
                sys.stderr.write("Expression traceback (most recent call last):\n" + "".join(logging.format_list(subx.stack)))
            raise
        if logging.trace:
            for arg in subx.args:
                # Arguments that do not depend on any parameter have no slot;
                # looking them up would raise KeyError.
                if arg in gradients:
                    sys.stdout.write("    d<%s>/d<%s> := %s\n" % (x.serial, arg.serial, format_value(gradients[arg], indent=8)))
    return gradients


def compute_values(x, initvalues=None):
    """Evaluate an expression and all its subexpressions.

    x:          The expression to evaluate.
    initvalues: Optional dictionary from subexpressions to
                precomputed values; can be used to continue a
                computation when the expression grows. It is only
                read, never modified. None (the default) means no
                precomputed values.

    Returns a collections.OrderedDict holding a value for x and for each
    subexpression yielded by expr.topological(x). Subexpressions present in
    initvalues reuse the stored value. All others are evaluated by calling
    their forward(values) method, which is expected to store the result in
    values[subx].
    """
    # A None default avoids sharing one mutable dict across calls.
    if initvalues is None:
        initvalues = {}
    values = collections.OrderedDict()

    for subx in expr.topological(x):
        if subx in initvalues:
            values[subx] = initvalues[subx]
            continue
        try:
            subx.forward(values)
        except BaseException:
            # Re-raise unchanged, after showing where the failing expression
            # was created.
            if logging.debug:
                sys.stderr.write("Expression traceback (most recent call last):\n" + "".join(logging.format_list(subx.stack)))
            raise
        if logging.trace:
            sys.stdout.write("<%s> = %s = %s\n" % (subx.serial, subx, format_value(values[subx])))

    return values