def f(local_opt):
    """Register ``local_opt`` in ``optdb`` as an in-place rewrite and return it.

    The registration name is the ``"name"`` entry of the enclosing ``kwargs``
    when that entry is present and truthy; otherwise it is
    ``local_opt.__name__``.  The entry is popped, so it is consumed by the
    first call.

    ``local_opt`` is wrapped in a ``TopoOptimizer`` whose failure callback is
    ``TopoOptimizer.warn_inplace`` and registered with ``60`` followed by the
    tags ``"fast_run"``, ``"inplace"``, ``"gpuarray"`` and every entry of the
    enclosing ``tags``.

    Returns
    -------
    The ``local_opt`` object that was passed in, unchanged, so this function
    can be used as a decorator.
    """
    # `pop` with a default: the previous `kwargs and kwargs.pop("name")` form
    # raised KeyError whenever `kwargs` held other keys but no "name", which
    # made the `__name__` fallback unreachable in that case.
    name = kwargs.pop("name", None) or local_opt.__name__
    optdb.register(
        name,
        TopoOptimizer(local_opt, failure_callback=TopoOptimizer.warn_inplace),
        60,
        "fast_run",
        "inplace",
        "gpuarray",
        *tags,
    )
    return local_opt
from aesara.compile import optdb
from aesara.graph.opt import TopoOptimizer, local_optimizer
from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse


@local_optimizer([Append, Extend, Insert, Reverse, Remove], inplace=True)
def typed_list_inplace_opt(fgraph, node):
    """Swap a non-inplace typed-list op for the same op built with ``inplace=True``.

    Applies when ``node.op`` is an ``Append``, ``Extend``, ``Insert``,
    ``Reverse`` or ``Remove`` instance whose ``inplace`` attribute is falsy.

    Returns a one-element list holding the result of calling the new op on
    ``node.inputs``, or ``False`` when the node is left untouched.
    """
    op = node.op
    handled = isinstance(op, (Append, Extend, Insert, Reverse, Remove))
    # Guard clause: wrong op type, or the op is already in-place.
    if not handled or op.inplace:
        return False

    inplace_op = op.__class__(inplace=True)
    return [inplace_op(*node.inputs)]


# Register the rewrite wrapped in a `TopoOptimizer` that reports failures
# through `TopoOptimizer.warn_inplace`.
_typed_list_inplace_rewriter = TopoOptimizer(
    typed_list_inplace_opt, failure_callback=TopoOptimizer.warn_inplace
)
optdb.register(
    "typed_list_inplace_opt",
    _typed_list_inplace_rewriter,
    60,
    "fast_run",
    "inplace",
)
def TopoPatternOptimizer(p1, p2, ign=True):
    """Return a ``TopoOptimizer`` wrapping ``PatternSub(p1, p2)``.

    ``p1`` and ``p2`` are passed positionally to ``PatternSub``; ``ign`` is
    forwarded to ``TopoOptimizer`` as its ``ignore_newtrees`` argument.
    """
    pattern_rewrite = PatternSub(p1, p2)
    return TopoOptimizer(pattern_rewrite, ignore_newtrees=ign)
@local_optimizer([SparseBlockGemv], inplace=True)
def local_inplace_sparse_block_gemv(fgraph, node):
    """
    SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True)

    When ``node.op`` is a ``SparseBlockGemv`` whose ``inplace`` attribute is
    falsy, call ``sparse_block_gemv_inplace`` on the same inputs and return
    the result in a one-element list.  Otherwise return ``False``.
    """
    if isinstance(node.op, SparseBlockGemv) and not node.op.inplace:
        new_node = sparse_block_gemv_inplace(*node.inputs)
        # Called with the original first output and its replacement;
        # presumably carries the stack-trace information over -- confirm
        # against `copy_stack_trace`.
        copy_stack_trace(node.outputs[0], new_node)
        return [new_node]
    return False


# Register the rewrite under its own name, wrapped in a `TopoOptimizer` whose
# failure callback is `TopoOptimizer.warn_inplace`, followed by 60 and the
# "fast_run" / "inplace" tags.
compile.optdb.register(
    "local_inplace_sparse_block_gemv",
    TopoOptimizer(
        local_inplace_sparse_block_gemv, failure_callback=TopoOptimizer.warn_inplace
    ),
    60,
    "fast_run",
    "inplace",
)  # DEBUG


@local_optimizer([SparseBlockOuter], inplace=True)
def local_inplace_sparse_block_outer(fgraph, node):
    """
    SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True)

    When ``node.op`` is a ``SparseBlockOuter`` whose ``inplace`` attribute is
    falsy, call ``sparse_block_outer_inplace`` on the same inputs.
    """
    if isinstance(node.op, SparseBlockOuter) and not node.op.inplace:
        new_node = sparse_block_outer_inplace(*node.inputs)
        copy_stack_trace(node.outputs[0], new_node)
        # NOTE(review): no `return` statement is visible after this point in
        # the reviewed excerpt, whereas `local_inplace_sparse_block_gemv`
        # returns `[new_node]` here and `False` otherwise -- confirm that the
        # remainder of this function does the same.
def map_variables(replacer, graphs, additional_inputs=None):
    """Construct new graphs based on 'graphs' with some variables replaced
    according to 'replacer'.

    The replacement happens in two passes.  First the graph inputs (plus
    `additional_inputs`) are passed through `replacer` and substituted with
    `clone_replace`.  Then a `FunctionGraph` is built over the result and a
    local rewrite is applied to every node with a `TopoOptimizer` in
    ``"out_to_in"`` order.  For `Scan` and `OpFromGraph` nodes the inner graph
    is processed by `_map_variables_inner` and the op is rebuilt; for every
    other node each output is passed through `replacer`.

    `replacer` is never called on a variable that an earlier `replacer` call
    returned.

    :param replacer: function that takes a variable and returns its
         replacement.
    :param graphs: an iterable of graphs in which to replace variables
    :param additional_inputs: an iterable of graph inputs not used in any
         of 'graphs' but possibly used in the graphs returned by `replacer`
    :return: the new graphs, in the same order as 'graphs'

    Example:

    .. code-block:: python

        tag = "replaceme"

        a = aesara.tensor.type.scalar("a")
        b = aesara.tensor.type.scalar("b")
        c = aesara.tensor.type.scalar("c")

        ab = a + b
        ab.tag.replacement = a * b

        u = ab + c
        v, = map_variables(
            lambda graph: getattr(graph.tag, "replacement", graph),
            [u])

        # v is now equal to a * b + c
    """
    if additional_inputs is None:
        additional_inputs = []

    # Wrap `replacer` to avoid replacing things we just put there:
    # `graphs_seen` collects every variable `replacer` has returned so far.
    graphs_seen = set()

    def wrapped_replacer(graph):
        # A variable produced by an earlier replacement is returned as-is.
        if graph in graphs_seen:
            return graph
        else:
            new_graph = replacer(graph)
            graphs_seen.add(new_graph)
            return new_graph

    # Materialize `graphs`; it is used several times below.
    graphs = list(graphs)
    inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))

    # Perform any desired replacement of input variables.  These
    # aren't replaced by the local optimizer approach because they are
    # not outputs of any Apply node.
    new_inputs = [wrapped_replacer(i) for i in inputs_]
    # Keep only the inputs for which the replacer returned a different object.
    replacements = [(input_, new_input) for input_, new_input in zip(inputs_, new_inputs) if new_input is not input_]
    graphs = clone_replace(graphs, share_inputs=True, replace=replacements)
    # Recompute the inputs from the rewritten graphs.
    inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))

    fg = FunctionGraph(inputs_, graphs, clone=False)

    # Nodes already handled by `local_transform`, including the replacement
    # nodes it creates, so that none is processed twice.
    nodes_seen = set()

    @local_optimizer(None)
    def local_transform(fgraph, node):
        if node in nodes_seen:
            return False

        # importing Scan into module scope would be circular
        from aesara.compile.builders import OpFromGraph
        from aesara.scan.op import Scan

        if isinstance(node.op, (Scan, OpFromGraph)):
            # recurse on the inner graph
            (
                new_inner_inputs,
                new_outer_inputs,
                new_inner_outputs,
            ) = _map_variables_inner(
                wrapped_replacer,
                inner_inputs=node.op.inputs,
                outer_inputs=node.inputs,
                inner_outputs=node.op.outputs,
                containing_op=node.op,
            )
            # Reinstantiate the op.  The enclosing `isinstance` check
            # guarantees that one of the two branches below binds `new_op`.
            if isinstance(node.op, Scan):
                new_op = Scan(
                    new_inner_inputs,
                    new_inner_outputs,
                    node.op.info,
                    node.op.mode,
                    # FIXME: infer this someday?
                    typeConstructor=None,
                )
            elif isinstance(node.op, OpFromGraph):
                new_op = OpFromGraph(new_inner_inputs, new_inner_outputs, **node.op.kwargs)
            # make a new node to replace the old one
            new_node = new_op.make_node(*new_outer_inputs)
            nodes_seen.add(new_node)
            return new_node.outputs
        else:
            nodes_seen.add(node)
            # This inner name is local to `local_transform`; it does not
            # rebind the `replacements` list of the enclosing function.
            replacements = [wrapped_replacer(o) for o in node.outputs]

            # Add inputs to replacement graphs as inputs to this `fgraph`
            for i in graph_inputs(replacements):
                fgraph.add_input(i)

            return replacements

    topo_transform = TopoOptimizer(local_transform, "out_to_in")
    topo_transform.optimize(fg)

    # Read the outputs before calling `disown()` on the graph.
    new_graphs = fg.outputs
    fg.disown()
    return new_graphs
# NOTE(review): the four statements below are the tail of a definition whose
# header lies outside the reviewed excerpt; they are reproduced unchanged.
# Allocate zeros of shape (Ns, Ts + 2 * Tpad, Nf, Tf, Hout, Wout), write
# `out_tmp` into the slice Tpad:(Ts + Tpad) of axis 1, then apply
# `diagonal_subtensor(..., 1, 3)` and sum the result over axis 3.
out_tmp_padded = at.zeros(dtype=out_tmp.dtype, shape=(Ns, Ts + 2 * Tpad, Nf, Tf, Hout, Wout))
out_tmp_padded = aesara.tensor.subtensor.set_subtensor(
    out_tmp_padded[:, Tpad:(Ts + Tpad), :, :, :, :], out_tmp)
out_5d = diagonal_subtensor(out_tmp_padded, 1, 3).sum(axis=3)
return out_5d


# NOTE(review): the decorator is called without `inplace=True` even though
# the rewrite builds an op with `inplace=True` -- confirm this is intended.
@local_optimizer([DiagonalSubtensor, IncDiagonalSubtensor])
def local_inplace_DiagonalSubtensor(fgraph, node):
    """Also work for IncDiagonalSubtensor.

    When ``node.op`` is a ``DiagonalSubtensor`` or ``IncDiagonalSubtensor``
    whose ``inplace`` attribute is falsy, build an op of the same class with
    ``inplace=True``, apply it to ``node.inputs`` and return the result in a
    one-element list.  Otherwise return ``False``.
    """
    if (isinstance(node.op, (DiagonalSubtensor, IncDiagonalSubtensor)) and
            not node.op.inplace):
        # Same op class as the original node, built with inplace=True.
        new_op = node.op.__class__(inplace=True)
        new_node = new_op(*node.inputs)
        # Called with the original first output and its replacement;
        # presumably carries the stack-trace information over -- confirm
        # against `copy_stack_trace`.
        copy_stack_trace(node.outputs[0], new_node)
        return [new_node]
    return False


# Register the rewrite wrapped in a `TopoOptimizer` whose failure callback is
# `TopoOptimizer.warn_inplace`, with the "fast_run" / "inplace" tags and
# `position=60` passed by keyword.
aesara.compile.optdb.register(
    "local_inplace_DiagonalSubtensor",
    TopoOptimizer(local_inplace_DiagonalSubtensor,
                  failure_callback=TopoOptimizer.warn_inplace),
    "fast_run",
    "inplace",
    position=60,
)
def OpSubOptimizer(op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True):
    """Return a ``TopoOptimizer`` wrapping ``OpSub(op1, op2)``.

    ``op1`` and ``op2`` are passed positionally to ``OpSub``.  ``fail`` is
    forwarded as ``failure_callback`` (default
    ``NavigatorOptimizer.warn_ignore``) and ``ign`` as ``ignore_newtrees``.
    """
    op_substitution = OpSub(op1, op2)
    return TopoOptimizer(
        op_substitution,
        failure_callback=fail,
        ignore_newtrees=ign,
    )