コード例 #1
0
    def _get_ancestors(self, var, func) -> MutableSet[RV]:
        """Collect the model variables, other than ``var``, that ``func`` depends on.

        The graph is searched backwards from ``func`` by following
        ``node.owner.inputs``. When a model variable is reached it is recorded
        and the expansion callback yields nothing further for it, so the
        inputs behind that variable are not explored from there.

        Raises ``KeyError`` if ``var`` is not in ``self.var_list``.
        """

        # Every variable in the model EXCEPT ``var`` itself.
        other_vars = set(self.var_list)
        other_vars.remove(var)

        # Nodes already handled; they are never expanded a second time.
        seen = set()
        found = set()

        def _expand(node) -> Optional[Iterator[Tensor]]:
            if node in seen:
                return None
            if node in other_vars:
                # A model variable: record it and stop expanding here.
                seen.add(node)
                found.add(node)
                return None
            if not node.owner:
                # Nothing produced this node, so there is nothing to expand.
                return None
            seen.add(node)
            return reversed(node.owner.inputs)

        # The traversal is run only for the side effects of ``_expand``;
        # its return value is intentionally ignored.
        stack_search(start=deque([func]), expand=_expand, mode='bfs')
        return found
コード例 #2
0
ファイル: test_graph.py プロジェクト: michaelosthege/aesara
def test_stack_search():
    """Check the visiting order of ``stack_search`` on a small two-op graph.

    ``o2`` is built from ``r3`` and ``o1``, and ``o1`` from ``r1`` and ``r2``.
    The search is run breadth-first, depth-first, and breadth-first with
    ``return_children=True``.
    """
    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"
    o2 = MyOp(r3, o1)
    o2.name = "o2"

    def expand(r):
        # Variables without an owner have no inputs to expand.
        return r.owner.inputs if r.owner else None

    bfs_order = list(stack_search([o2], expand, bfs=True, return_children=False))
    assert bfs_order == [o2, r3, o1, r1, r2]

    dfs_order = list(stack_search([o2], expand, bfs=False, return_children=False))
    assert dfs_order == [o2, o1, r2, r1, r3]

    # With ``return_children=True`` each item is a (node, children) pair.
    bfs_with_children = list(
        stack_search([o2], expand, bfs=True, return_children=True)
    )
    expected_pairs = [
        (o2, [r3, o1]),
        (r3, None),
        (o1, [r1, r2]),
        (r1, None),
        (r2, None),
    ]
    assert bfs_with_children == expected_pairs
コード例 #3
0
ファイル: for_theano.py プロジェクト: gwtaylor/MonteTheano
def ancestors(variable_list, blockers=None):
    """Search backward from ``variable_list`` through the owners' inputs.

    :type variable_list: list of `Variable` instances
    :param variable_list:
        output `Variable` instances from which to search backward through owners
    :param blockers:
        optional container of variables whose owner inputs are not expanded;
        when empty or None, nothing is blocked
    :returns:
        whatever ``graph.stack_search`` produces for a ``'dfs'`` search
        started at ``variable_list`` (presumably the visited variables in
        depth-first order -- confirm against ``graph.stack_search``).

    """
    def expand(r):
        if not r.owner:
            return None
        if blockers and r in blockers:
            return None
        # Hand the inputs over in reversed order.
        return list(r.owner.inputs)[::-1]

    return graph.stack_search(graph.deque(variable_list), expand, 'dfs')
コード例 #4
0
def ancestors(variable_list, blockers=None):
    """Search backward from ``variable_list`` through the owners' inputs.

    :type variable_list: list of `Variable` instances
    :param variable_list:
        output `Variable` instances from which to search backward through owners
    :param blockers:
        optional container of variables that are not expanded further;
        an empty container or None blocks nothing
    :returns:
        the result of ``graph.stack_search`` run in ``'dfs'`` mode from
        ``variable_list`` (presumably the visited variables in depth-first
        order -- confirm against ``graph.stack_search``).

    """
    def expand(r):
        producer = r.owner
        # Expand only variables that have an owner and are not blocked.
        if producer and not (blockers and r in blockers):
            reversed_inputs = list(producer.inputs)
            reversed_inputs.reverse()
            return reversed_inputs
        return None

    start = graph.deque(variable_list)
    return graph.stack_search(start, expand, 'dfs')