def calc_jacobian_elem(start, end):
    """Run a single reverse pass from ``end``, seeded with ones shaped like ``start``.

    Args:
        start: object whose ``.shape`` sizes the all-ones seed.
        end: traced object whose ``._node`` is the starting point of the
            backward pass.

    Returns:
        The ``._value`` attribute of the object returned by ``backward_pass``
        (presumably the unwrapped value of a box -- confirm).
    """
    with warnings.catch_warnings():
        # Every warning raised while the backward pass runs is discarded.
        warnings.simplefilter("ignore")
        # The seed is wrapped in a fresh box at trace 0 with a new VJP root node.
        seed_box = new_box(anp.ones(start.shape), 0, VJPNode.new_root())
        return backward_pass(seed_box, end._node)._value
def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function.

    Calls ``f`` with ``args`` while a freshly pushed tape is active. It then
    back-propagates from the traced output to the start node returned when
    that tape is popped.

    Args:
      *args: positional arguments forwarded to ``f``.
      **kwds: only ``output_gradients`` is read. It is the seed gradient for
        the backward pass; when absent or ``None``, ones of the output's shape
        and dtype are used. Other keyword arguments are ignored and are NOT
        forwarded to ``f``.

    Returns:
      A ``(value, gradients)`` tuple: the traced output's ``value`` and
      ``_aggregate_grads`` applied to the backward pass's ``gradients``.

    Raises:
      ValueError: if ``f`` did not return a traced node, or if the returned
        node is not derived from the tape's start node.
    """
    tape.push_new_tape()
    try:
        end_node = f(*args)
    finally:
        # Pop the tape and release its progenitor even when ``f`` raises.
        # Without this, a failing call would leave a stale tape pushed for
        # subsequent calls.
        start_node = tape.pop_tape()
        ag_core.active_progenitors.remove(start_node)
    if not ag_core.isnode(end_node):
        raise ValueError(
            "Target not part of a computation being traced. %s" % end_node)
    if start_node not in end_node.progenitors:
        raise ValueError("Target not derived from source. %s %s" %
                         (end_node.progenitors, repr(start_node)))
    output_gradients = kwds.get("output_gradients", None)
    if output_gradients is None:
        # Default seed: ones matching the traced output's shape and dtype.
        output_gradients = _ones(end_node.shape, end_node.dtype)
    grad = ag_core.backward_pass(output_gradients, end_node, start_node)
    return end_node.value, _aggregate_grads(grad.gradients)
def calc_jacobian(start, end):
    """Assemble a Jacobian column by column with one backward pass per column of ``end``.

    ``end`` must be at least 2-D: it is indexed via ``end.shape[1]`` and
    ``seed[:, col]``. For each column ``col``, the seed is zero everywhere
    except for ones in that column. The per-column results of
    ``backward_pass`` are stacked along ``axis=1``.

    Args:
        start: the source quantity. It is only used (via ``.shape``) when
            ``end`` is not a box.
        end: the traced output. Its ``._node`` is where each pass begins.

    Returns:
        ``anp.stack`` of the per-column backward-pass results along axis 1.
        If ``end`` is not a box, returns ``vspace(start.shape).zeros()``.
    """
    # A non-box ``end`` carries no trace, so autograd has nothing to walk back.
    # NOTE(review): this passes the shape tuple to ``vspace`` rather than
    # ``start`` itself -- confirm this is intended.
    if not isbox(end):
        return vspace(start.shape).zeros()

    def _column_pass(col):
        # Seed that selects output column ``col`` in every row.
        seed = anp.zeros(end.shape)
        seed[:, col] = 1
        return backward_pass(new_box(seed, 0, VJPNode.new_root()), end._node)

    columns = [_column_pass(col) for col in range(end.shape[1])]
    return anp.stack(columns, axis=1)
def time_fan_out_fan_in_backward_pass():
    """Time one backward pass over the fan-out/fan-in graph.

    On ``MASTER_BRANCH`` the start node is passed explicitly as well.
    """
    if not MASTER_BRANCH:
        backward_pass(1., fan_end_node)
        return
    backward_pass(1., fan_end_node, fan_start_node)
def time_long_backward_pass():
    """Time one backward pass over the long graph.

    On ``MASTER_BRANCH`` the start node is passed explicitly as well.
    """
    if not MASTER_BRANCH:
        backward_pass(1., long_end_node)
        return
    backward_pass(1., long_end_node, long_start_node)
def time_short_backward_pass():
    """Time one backward pass over the short graph.

    On ``MASTER_BRANCH`` the start node is passed explicitly as well.
    """
    if not MASTER_BRANCH:
        backward_pass(1., short_end_node)
        return
    backward_pass(1., short_end_node, short_start_node)
def time_fan_out_fan_in_backward_pass():
    """Time ``core.backward_pass`` from the fan-out/fan-in end node to its start node."""
    seed_gradient = 1.0
    core.backward_pass(seed_gradient, fan_end_node, fan_start_node)
def time_long_backward_pass():
    """Time ``core.backward_pass`` from the long graph's end node to its start node."""
    seed_gradient = 1.0
    core.backward_pass(seed_gradient, long_end_node, long_start_node)
def time_short_backward_pass():
    """Time ``core.backward_pass`` from the short graph's end node to its start node."""
    seed_gradient = 1.0
    core.backward_pass(seed_gradient, short_end_node, short_start_node)