def maybe_cached_fun(x):
    """Evaluate the wrapped function at ``x``, tracing it only on the first call.

    First call: pop the function from ``_fun`` (enclosing scope, not shown),
    trace it from a fresh ``ConstGraphNode`` root, and store the traced nodes
    in ``graph``. The order is ``toposort(end_node)`` reversed. Then return
    the traced output value. Because the function is popped, it can only be
    traced once.

    Later calls: replay the cached graph. ``x`` is fed into the first cached
    node, and each later node's ``partial_fun`` is applied to its parents'
    values.

    Raises:
        Exception: on the first call, if the output does not depend on ``x``.
    """
    if graph:
        ordered_nodes = graph[0]
        vals = {ordered_nodes[0]: x}
        for node in ordered_nodes[1:]:
            vals[node] = node.partial_fun([vals[p] for p in node.parents])
        # The output is the final entry of the cached order. Index it directly
        # instead of reusing the loop variable: with a single-node graph the
        # loop body never runs and ``node`` would be unbound.
        return vals[ordered_nodes[-1]]
    start_node = ConstGraphNode.new_root()
    end_value, end_node = trace(start_node, _fun.pop(), x)
    if end_node is None:
        raise Exception("Output is independent of input")
    graph.append(list(toposort(end_node))[::-1])
    return end_value
def maybe_cached_fun(x):
    """Return the wrapped function's value at ``x``, reusing a cached graph.

    On the first call the function is popped from ``_fun`` (enclosing scope,
    not shown) and traced from a new ``ConstGraphNode`` root. Its nodes are
    cached in ``graph`` in reversed ``toposort`` order, and the traced value
    is returned. Subsequent calls skip tracing: they seed the first cached
    node with ``x`` and evaluate each remaining node's ``partial_fun`` on the
    values of its ``parents``.

    Raises:
        Exception: on the first call, if the output does not depend on ``x``.
    """
    if graph:
        cached_order = graph[0]
        vals = {cached_order[0]: x}
        for node in cached_order[1:]:
            vals[node] = node.partial_fun([vals[p] for p in node.parents])
        # Use the final cached node rather than the leftover loop variable,
        # which is unbound when the graph has only one node.
        return vals[cached_order[-1]]
    start_node = ConstGraphNode.new_root()
    end_value, end_node = trace(start_node, _fun.pop(), x)
    if end_node is None:
        raise Exception("Output is independent of input")
    graph.append(list(toposort(end_node))[::-1])
    return end_value
def time_fan_out_fan_in_forward_pass():
    """Run one forward pass through ``fan_out_fan_in``.

    When ``MASTER_BRANCH`` is set this uses ``forward_pass`` on ``(2.0,)``.
    Otherwise it traces from a fresh ``VJPNode`` root at the module-level
    ``x``. ``MASTER_BRANCH`` presumably selects between two tracing-API
    versions (TODO confirm).
    """
    if not MASTER_BRANCH:
        root = VJPNode.new_root()
        trace(root, fan_out_fan_in, x)
        return
    forward_pass(fan_out_fan_in, (2.0,), {})
def time_long_forward_pass():
    """Run one forward pass through ``f_long``.

    When ``MASTER_BRANCH`` is set this uses ``forward_pass`` on ``(2.0,)``.
    Otherwise it traces from a fresh ``VJPNode`` root at the module-level
    ``x``.
    """
    if not MASTER_BRANCH:
        root = VJPNode.new_root()
        trace(root, f_long, x)
        return
    forward_pass(f_long, (2.0,), {})
def time_exp_call():
    """Baseline: call ``onp.exp`` directly on a Python float."""
    onp.exp(2.0)


def time_exp_primitive_call_unboxed():
    """Call ``np.exp`` on a plain (unboxed) float."""
    np.exp(2.0)


def time_exp_primitive_call_boxed():
    """Call ``np.exp`` on a boxed value built by the module-level setup below.

    ``progenitor`` is used when ``MASTER_BRANCH`` is set, ``start_box`` otherwise.
    """
    np.exp(progenitor if MASTER_BRANCH else start_box)


def time_no_autograd_control():
    # Control benchmark: detects whether the benchmarking machine itself is
    # running slowly, independent of autograd.
    mat = np.random.randn(200, 200)
    np.dot(mat, mat)


# Shared fixtures for the benchmarks in this module. ``MASTER_BRANCH``
# presumably selects between two versions of the tracing API (TODO confirm).
if not MASTER_BRANCH:
    x = 2.0
    start_node = VJPNode.new_root()
    start_box = new_box(x, 0, start_node)
    _, short_end_node = trace(VJPNode.new_root(), f_short, x)
    _, long_end_node = trace(VJPNode.new_root(), f_long, x)
    _, fan_end_node = trace(VJPNode.new_root(), fan_out_fan_in, x)
else:
    short_start_node, short_end_node = forward_pass(f_short, (2.0,), {})
    long_start_node, long_end_node = forward_pass(f_long, (2.0,), {})
    fan_start_node, fan_end_node = forward_pass(fan_out_fan_in, (2.0,), {})
    progenitor = new_progenitor(2.0)
def print_trace(f, x):
    """Trace ``f`` at ``x`` from a ``PrintNode`` root, then end the output line.

    Any printing during the trace is presumably done by ``PrintNode`` (not
    visible here). This function only adds the trailing newline.
    """
    start_node = PrintNode.new_root(x)
    trace(start_node, f, x)
    # A bare ``print`` is just a name lookup in Python 3 and outputs nothing.
    # It must be called to emit the newline.
    print()
def full_graph(fun, *args, **kwargs):
    """Trace ``fun(*args, **kwargs)`` from a ``FullGraphNode`` root.

    ``fun`` is wrapped as a one-argument callable so that ``trace`` receives
    the positional-args tuple as its single input. Keyword arguments are
    closed over. Returns the end node reported by ``trace``; the traced
    value is discarded.
    """
    def call_with_packed_args(packed):
        return fun(*packed, **kwargs)

    root = FullGraphNode.new_root()
    _, end_node = trace(root, call_with_packed_args, args)
    return end_node
def tf_fun(*tf_xs):
    """Trace ``fun`` with ``TFNode`` start nodes built from ``tf_xs``.

    ``fun`` and ``xs`` come from the enclosing scope (not visible here).
    NOTE(review): ``trace`` receives ``xs``, not ``tf_xs``. This looks
    deliberate (trace at values ``xs``, build nodes from ``tf_xs``) but
    should be confirmed against the enclosing function.

    Returns the outputs mapped through ``container_fmap``. An output with no
    end node is returned as its plain value; otherwise its node's ``tensor``
    is returned.
    """
    def _to_output(node, value):
        if node is None:
            return value
        return node.tensor

    mapper = container_fmap
    start_nodes = mapper(TFNode, tf_xs)
    end_values, end_nodes = trace(fun, xs, start_nodes, mapper, mapper)
    return mapper(_to_output, end_nodes, end_values)
def print_trace(f, x):
    """Trace ``f`` at ``x`` from a ``PrintNode`` root, then print a newline.

    Output during the trace presumably comes from ``PrintNode`` (not visible
    here).
    """
    root = PrintNode.new_root(x)
    trace(root, f, x)
    print()
def trace_graph(f, x):
    """Trace ``f`` at ``x`` from a ``GraphNode`` root and return the end node.

    The traced output value is discarded.
    """
    root = GraphNode.new_root(x)
    traced = trace(root, f, x)
    _, end_node = traced
    return end_node
def time_fan_out_fan_in_forward_pass():
    """Run one forward pass through ``fan_out_fan_in``.

    When ``MASTER_BRANCH`` is set this uses ``forward_pass`` on ``(2.0,)``.
    Otherwise it traces from a fresh ``VJPNode`` root initialised with the
    module-level ``x``.
    """
    if not MASTER_BRANCH:
        root = VJPNode.new_root(x)
        trace(root, fan_out_fan_in, x)
        return
    forward_pass(fan_out_fan_in, (2.0,), {})
def time_long_forward_pass():
    """Run one forward pass through ``f_long``.

    When ``MASTER_BRANCH`` is set this uses ``forward_pass`` on ``(2.0,)``.
    Otherwise it traces from a fresh ``VJPNode`` root initialised with the
    module-level ``x``.
    """
    if not MASTER_BRANCH:
        root = VJPNode.new_root(x)
        trace(root, f_long, x)
        return
    forward_pass(f_long, (2.0,), {})
def time_exp_call():
    """Baseline: call ``onp.exp`` directly on a Python float."""
    onp.exp(2.0)


def time_exp_primitive_call_unboxed():
    """Call ``np.exp`` on a plain (unboxed) float."""
    np.exp(2.0)


def time_exp_primitive_call_boxed():
    """Call ``np.exp`` on a boxed value built by the module-level setup below.

    ``progenitor`` is used when ``MASTER_BRANCH`` is set, ``start_box`` otherwise.
    """
    np.exp(progenitor if MASTER_BRANCH else start_box)


def time_no_autograd_control():
    # Control benchmark: detects whether the benchmarking machine itself is
    # running slowly, independent of autograd.
    mat = np.random.randn(200, 200)
    np.dot(mat, mat)


# Shared fixtures for the benchmarks in this module. ``MASTER_BRANCH``
# presumably selects between two versions of the tracing API (TODO confirm).
if not MASTER_BRANCH:
    x = 2.0
    start_node = VJPNode.new_root(x)
    start_box = new_box(x, 0, start_node)
    _, short_end_node = trace(VJPNode.new_root(x), f_short, x)
    _, long_end_node = trace(VJPNode.new_root(x), f_long, x)
    _, fan_end_node = trace(VJPNode.new_root(x), fan_out_fan_in, x)
else:
    short_start_node, short_end_node = forward_pass(f_short, (2.0,), {})
    long_start_node, long_end_node = forward_pass(f_long, (2.0,), {})
    fan_start_node, fan_end_node = forward_pass(fan_out_fan_in, (2.0,), {})
    progenitor = new_progenitor(2.0)