Пример #1
0
def cond_merge_random_op(fgraph, main_node):
    """Fuse two ``IfElse`` nodes feeding ``main_node`` that share a condition.

    The distinct ``IfElse`` apply nodes among the owners of
    ``main_node.inputs`` are gathered in input order.  The first one is the
    merge target; each later candidate is tried in turn, and the first one
    that (a) uses the same condition variable and (b) is not related to the
    target through ``is_in_ancestors`` in either direction is fused with it.
    The fused node is a single ``IfElse`` whose true branches are the
    target's followed by the candidate's, and likewise for the false
    branches.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (not read by this rewrite).
    main_node
        The apply node whose inputs are inspected.

    Returns
    -------
    list or False
        ``clone_replace`` of ``main_node.outputs`` with the two old
        ``IfElse`` nodes' outputs swapped for the fused node's outputs, or
        ``False`` when nothing was merged.
    """
    # The rewrite targets consumers of IfElse outputs, not IfElse nodes.
    if isinstance(main_node.op, IfElse):
        return False

    # ``dict.fromkeys`` de-duplicates owners while preserving input order, so
    # the choice of merge target/partner is deterministic (a ``set`` would
    # iterate in hash order).
    owners = dict.fromkeys(inp.owner for inp in main_node.inputs)
    cond_nodes = [
        node for node in owners if node is not None and isinstance(node.op, IfElse)
    ]

    if len(cond_nodes) < 2:
        return False

    merging_node = cond_nodes[0]
    for proposal in cond_nodes[1:]:
        if (
            proposal.inputs[0] == merging_node.inputs[0]
            and not is_in_ancestors(proposal, merging_node)
            and not is_in_ancestors(merging_node, proposal)
        ):
            # IfElse inputs are laid out as [cond, *true_branches, *false_branches],
            # with ``op.n_outs`` entries in each branch group.
            mn_n = merging_node.op.n_outs
            pl_n = proposal.op.n_outs
            mn_ts = merging_node.inputs[1:][:mn_n]
            mn_fs = merging_node.inputs[1:][mn_n:]
            pl_ts = proposal.inputs[1:][:pl_n]
            pl_fs = proposal.inputs[1:][pl_n:]
            new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs

            mn_name = merging_node.op.name or "?"
            pl_name = proposal.op.name or "?"
            new_ifelse = IfElse(
                n_outs=len(mn_ts) + len(pl_ts),
                as_view=False,
                gpu=False,
                name=mn_name + "&" + pl_name,
            )
            new_outs = new_ifelse(*new_ins, return_list=True)

            # Old outputs in the same order as the fused node's outputs:
            # the target's outputs first, then the candidate's.
            old_outs = []
            for node in (merging_node, proposal):
                outs = node.outputs
                old_outs += outs if isinstance(outs, (list, tuple)) else [outs]

            pairs = list(zip(old_outs, new_outs))
            return clone_replace(main_node.outputs, replace=pairs)

    # No candidate could be merged with the target.
    return False
Пример #2
0
def test_is_in_ancestors():
    """``is_in_ancestors`` must detect the dependency in one direction only."""
    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"
    o2 = MyOp(r3, o1)
    o2.name = "o2"

    # o1 is an input of o2, so o1's node lies upstream of o2's node ...
    assert is_in_ancestors(o2.owner, o1.owner)
    # ... but o2's node is not upstream of o1's node.  Without this check an
    # implementation that always answered True would still pass.
    assert not is_in_ancestors(o1.owner, o2.owner)
Пример #3
0
 def apply(self, fgraph):
     """Fuse ``IfElse`` nodes in ``fgraph`` that share the first one's condition.

     The graph's ``IfElse`` nodes are visited in topological order.  The
     first one is the merge target.  Each later node with the same condition
     variable, for which ``is_in_ancestors(proposal, merging_node)`` is
     false, is fused with the target.  The fused node is a single ``IfElse``
     whose true/false branches concatenate both nodes' branches.  The old
     outputs are then swapped for the new ones with
     ``fgraph.replace_all_validate``.

     Returns ``False`` when the graph has fewer than two ``IfElse`` nodes,
     and ``None`` otherwise.
     """
     nodelist = list(fgraph.toposort())
     cond_nodes = [node for node in nodelist if isinstance(node.op, IfElse)]
     if len(cond_nodes) < 2:
         return False
     merging_node = cond_nodes[0]
     for proposal in cond_nodes[1:]:
         # ``cond_nodes`` follows topological order, so ``proposal`` cannot be
         # upstream of ``merging_node``; only the reverse dependency needs to
         # be ruled out before fusing.
         if proposal.inputs[0] == merging_node.inputs[0] and not is_in_ancestors(
             proposal, merging_node
         ):
             # IfElse inputs are laid out as [cond, *true_branches, *false_branches],
             # with ``op.n_outs`` entries in each branch group.
             mn_n = merging_node.op.n_outs
             pl_n = proposal.op.n_outs
             mn_ts = merging_node.inputs[1:][:mn_n]
             mn_fs = merging_node.inputs[1:][mn_n:]
             pl_ts = proposal.inputs[1:][:pl_n]
             pl_fs = proposal.inputs[1:][pl_n:]
             new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs

             mn_name = merging_node.op.name or "?"
             pl_name = proposal.op.name or "?"
             new_ifelse = IfElse(
                 n_outs=len(mn_ts) + len(pl_ts),
                 as_view=False,
                 gpu=False,
                 name=mn_name + "&" + pl_name,
             )
             new_outs = new_ifelse(*new_ins, return_list=True)
             # NOTE(review): each output is cloned independently here; confirm
             # this does not split the fused node back into several IfElse
             # nodes (one per output) -- TODO confirm why the clone is needed.
             new_outs = [clone_replace(x) for x in new_outs]

             # Old outputs in the same order as the fused node's outputs:
             # the target's outputs first, then the proposal's.
             old_outs = []
             for node in (merging_node, proposal):
                 outs = node.outputs
                 old_outs += outs if isinstance(outs, (list, tuple)) else [outs]

             pairs = list(zip(old_outs, new_outs))
             # NOTE(review): ``merging_node`` is not updated after this
             # replacement, so a later proposal with the same condition is
             # fused with the original (already replaced) node rather than
             # with the merged one -- TODO confirm this is intended.
             fgraph.replace_all_validate(pairs, reason="cond_merge")