def compute_gradients(x, values=None):
    """Compute gradients using automatic differentiation.

    x:      The expression to compute gradients of.
    values: As returned by compute_values(x); contains all values to
            compute gradients with respect to. When None, it is obtained
            by calling compute_values(x).

    Returns a dict from subexpressions to gradient values; x itself is
    seeded with 1.

    Raises TypeError if values is not a collections.OrderedDict and
    ValueError if x is not a key of values.
    """
    if values is None:
        values = compute_values(x)

    # The backward pass relies on iteration order, hence the strict type.
    if not isinstance(values, collections.OrderedDict):
        raise TypeError("values must be a collections.OrderedDict, not %s"
                        % type(values).__name__)
    if x not in values:
        raise ValueError("values does not contain the expression x")

    # Seed d<x>/d<x> = 1, then give a zero entry to every parameter and to
    # every subexpression that has at least one argument already present
    # in gradients.
    gradients = {x: 1.}
    for subx in values:
        if isinstance(subx, expr.parameter) or any(arg in gradients for arg in subx.args):
            if subx not in gradients:
                gradients[subx] = 0.

    for subx in reversed(values):  # assume values is in topological order
        if subx not in gradients:
            continue
        if logging.trace:
            sys.stdout.write("d<%s>/d<%s> = %s" % (x.serial, subx.serial, format_value(gradients[subx])))
        try:
            # Presumably backward() updates gradients[arg] for each of
            # subx's arguments; the trace below reads those entries.
            subx.backward(values, gradients)
            if logging.trace:
                for arg in subx.args:
                    sys.stdout.write(" d<%s>/d<%s> := %s" % (x.serial, arg.serial, format_value(gradients[arg], indent=8)))
        except:
            # Deliberately broad: the error is always re-raised, this only
            # adds the stack recorded on the failing expression.
            if logging.debug:
                sys.stderr.write("Expression traceback (most recent call last):\n" + "".join(logging.format_list(subx.stack)))
            raise

    return gradients
def compute_values(x, initvalues=None):
    """Evaluate an expression and all its subexpressions.

    x:          The expression to evaluate.
    initvalues: Optional dictionary from subexpressions to precomputed
                values; can be used to continue a computation when the
                expression grows. None (the default) means there are no
                precomputed values.

    Returns a collections.OrderedDict from subexpressions to their
    values, in the order yielded by expr.topological(x).
    """
    # Use a None sentinel instead of a shared mutable default dict.
    if initvalues is None:
        initvalues = {}

    values = collections.OrderedDict()

    for subx in expr.topological(x):
        if subx in initvalues:
            # Reuse the precomputed value instead of re-evaluating.
            values[subx] = initvalues[subx]
        else:
            try:
                # Presumably forward() stores its result in values[subx];
                # nothing is assigned here and the trace below reads it.
                subx.forward(values)
            except:
                # Deliberately broad: the error is always re-raised, this
                # only adds the stack recorded on the failing expression.
                if logging.debug:
                    sys.stderr.write("Expression traceback (most recent call last):\n" + "".join(logging.format_list(subx.stack)))
                raise
        if logging.trace:
            sys.stdout.write("<%s> = %s = %s" % (subx.serial, subx, format_value(values[subx])))

    return values