def cond_merge_random_op(fgraph, main_node):
    """Fuse two ``IfElse`` nodes feeding ``main_node`` that share a condition.

    Local rewrite: among the distinct ``IfElse`` nodes that produce inputs of
    ``main_node``, find a pair whose condition input (``inputs[0]``) is the
    same variable and where neither node is an ancestor of the other. The pair
    is replaced by one ``IfElse`` whose outputs are the first node's outputs
    followed by the second node's outputs.

    Parameters
    ----------
    fgraph
        The graph being rewritten. It is not used directly by this rewrite.
    main_node
        The apply node whose inputs are inspected.

    Returns
    -------
    list or bool
        The outputs of ``main_node`` rebuilt with ``clone_replace`` on top of
        the merged ``IfElse``, or ``False`` when no merge applies.
    """
    if isinstance(main_node.op, IfElse):
        return False

    # dict.fromkeys de-duplicates owners like a set but keeps input order,
    # so the pair chosen for merging is deterministic across runs.
    owners = dict.fromkeys(inp.owner for inp in main_node.inputs)
    cond_nodes = [
        node for node in owners if node is not None and isinstance(node.op, IfElse)
    ]
    if len(cond_nodes) < 2:
        return False

    # Try every unordered pair, so a mergeable pair is not missed just because
    # the first IfElse found has a different condition.
    for i, merging_node in enumerate(cond_nodes):
        for proposal in cond_nodes[i + 1 :]:
            if not (
                proposal.inputs[0] == merging_node.inputs[0]
                and not is_in_ancestors(proposal, merging_node)
                and not is_in_ancestors(merging_node, proposal)
            ):
                continue

            # Inputs are laid out as [condition, n_outs "true" inputs, "false" inputs].
            mn_n_outs = merging_node.op.n_outs
            pl_n_outs = proposal.op.n_outs
            mn_ts = merging_node.inputs[1 : 1 + mn_n_outs]
            mn_fs = merging_node.inputs[1 + mn_n_outs :]
            pl_ts = proposal.inputs[1 : 1 + pl_n_outs]
            pl_fs = proposal.inputs[1 + pl_n_outs :]
            new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs

            mn_name = merging_node.op.name or "?"
            pl_name = proposal.op.name or "?"
            new_ifelse = IfElse(
                n_outs=len(mn_ts) + len(pl_ts),
                as_view=False,
                gpu=False,
                name=mn_name + "&" + pl_name,
            )
            new_outs = new_ifelse(*new_ins, return_list=True)

            # The old outputs are listed in the same order as the new node's
            # outputs: merging_node first, then proposal.
            old_outs = []
            for node in (merging_node, proposal):
                outs = node.outputs
                old_outs.extend(outs if isinstance(outs, (list, tuple)) else [outs])

            pairs = list(zip(old_outs, new_outs))
            return clone_replace(main_node.outputs, replace=pairs)

    return False
def test_is_in_ancestors():
    """Check ``is_in_ancestors`` in both directions on a two-node chain."""
    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"
    o2 = MyOp(r3, o1)
    o2.name = "o2"
    # o1 is an input of o2, so o1's node lies upstream of o2's node...
    assert is_in_ancestors(o2.owner, o1.owner)
    # ...but o2's node is not upstream of o1's node. The merge rewrites test
    # both directions, so the negative case matters as much as the positive one.
    assert not is_in_ancestors(o1.owner, o2.owner)
def apply(self, fgraph):
    """Fuse ``IfElse`` nodes of ``fgraph`` that share the same condition.

    ``IfElse`` nodes are visited in topological order, and the first one is
    the initial merge target. Each later ``IfElse`` is fused into the current
    target when it meets two conditions:

    - its condition input (``inputs[0]``) is the same variable as the target's;
    - neither node is an ancestor of the other.

    The fusion is done with ``fgraph.replace_all_validate``, and the fused
    node then becomes the target for the remaining candidates.

    Returns ``False`` when the graph contains fewer than two ``IfElse`` nodes.
    """
    cond_nodes = [node for node in fgraph.toposort() if isinstance(node.op, IfElse)]
    if len(cond_nodes) < 2:
        return False

    merging_node = cond_nodes[0]
    for proposal in cond_nodes[1:]:
        if not (
            proposal.inputs[0] == merging_node.inputs[0]
            and not is_in_ancestors(proposal, merging_node)
            and not is_in_ancestors(merging_node, proposal)
        ):
            continue

        # Inputs are laid out as [condition, n_outs "true" inputs, "false" inputs].
        mn_n_outs = merging_node.op.n_outs
        pl_n_outs = proposal.op.n_outs
        mn_ts = merging_node.inputs[1 : 1 + mn_n_outs]
        mn_fs = merging_node.inputs[1 + mn_n_outs :]
        pl_ts = proposal.inputs[1 : 1 + pl_n_outs]
        pl_fs = proposal.inputs[1 + pl_n_outs :]
        new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs

        mn_name = merging_node.op.name or "?"
        pl_name = proposal.op.name or "?"
        new_ifelse = IfElse(
            n_outs=len(mn_ts) + len(pl_ts),
            as_view=False,
            gpu=False,
            name=mn_name + "&" + pl_name,
        )
        # Wire the fresh outputs in directly. Cloning each output separately
        # would give every output its own copy of the merged node (and of its
        # upstream graph), which undoes the merge.
        new_outs = new_ifelse(*new_ins, return_list=True)

        # The old outputs are listed in the same order as the new node's
        # outputs: merging_node first, then proposal.
        old_outs = []
        for node in (merging_node, proposal):
            outs = node.outputs
            old_outs.extend(outs if isinstance(outs, (list, tuple)) else [outs])

        fgraph.replace_all_validate(list(zip(old_outs, new_outs)), reason="cond_merge")

        # The previous target has just been replaced. Later candidates must be
        # fused into the new node rather than the stale one.
        if new_outs:
            merging_node = new_outs[0].owner