Example #1
0
def vars_to_rvs(var):
    """Map ancestors of ``var`` to what ``is_random_variable`` reports for them.

    Per the original summary, this computes paths from ``TensorVariable``s to
    their underlying ``RandomVariable`` outputs.

    Every ancestor of ``var`` is passed to ``is_random_variable``. Ancestors
    for which it returns ``None`` are dropped. For the rest, the stored value
    depends on the first element of the result:

    * if that element is the ancestor itself, only the second element is kept;
    * otherwise the whole result is kept.

    NOTE(review): assumes ``is_random_variable`` returns ``None`` or an
    indexable pair; confirm against its definition.
    """
    # Classify every ancestor first, then build the mapping from the results.
    classified = [(node, is_random_variable(node)) for node in ancestors([var])]
    paths = {}
    for node, info in classified:
        if info is None:
            continue
        if info[0] is node:
            paths[node] = info[1]
        else:
            paths[node] = info
    return paths
Example #2
0
def sort_replacements(replace_pairs, get_ancestors=None):
    """
    Return a list of (oldvar, newvar) pairs in dependency order.

    returns: a list of [(old0, new0), (old1, new1), ...] pairs such that
    whenever old_j is an ancestor of old_i, pair j comes before pair i. A
    variable is therefore replaced before anything that is computed from it.
    Pairs with no dependency between them keep their relative input order.

    The purpose of this function is to support a sensible interpretation of
    givens when the various subgraphs they represent are tangled up and
    co-dependent.

    Parameters
    ----------
    replace_pairs : iterable of (old, new) pairs
        Consumed exactly once, so generators are accepted.
    get_ancestors : callable, optional
        Called as ``get_ancestors(variables, blockers=...)``, with the same
        contract as ``graph.ancestors``. Defaults to ``graph.ancestors``.
    """
    import heapq

    if get_ancestors is None:
        get_ancestors = graph.ancestors

    # Materialise once. The pairs are walked several times below, and a
    # generator would otherwise be exhausted after the first pass.
    pairs = list(replace_pairs)
    v_origs_set = {v_orig for v_orig, _ in pairs}

    # Suppose we're replacing vars v1 and v2, but v2 appears in the ancestors
    # of v1. In this case we have to replace v2 first, and then v1.
    # Each original's ancestors are collected with the *other* originals as
    # blockers. Blockers are still yielded but not expanded, so only direct
    # dependencies between originals are recorded here. The topological sort
    # below restores transitivity.
    v_orig_ancestors = {}
    for v_orig in v_origs_set:
        blockers = {v for v in v_origs_set if v is not v_orig}
        v_orig_ancestors[v_orig] = set(get_ancestors([v_orig], blockers=blockers))

    # An edge i -> j means pair i must be applied before pair j.
    # A comparison sort cannot be used here, for two reasons:
    #   * the relation above is a non-transitive partial order, so sort()
    #     can leave a chain v3 -> v2 -> v1 as [v1, v3, v2];
    #   * list.sort(cmp) does not exist on Python 3.
    n = len(pairs)
    successors = [[] for _ in range(n)]
    n_pending = [0] * n
    for i, (old_i, _) in enumerate(pairs):
        for j, (old_j, _) in enumerate(pairs):
            if old_i is not old_j and old_i in v_orig_ancestors[old_j]:
                successors[i].append(j)
                n_pending[j] += 1

    # Kahn's algorithm. The heap always releases the lowest input index first,
    # which keeps unconstrained pairs in their original order.
    ready = [i for i in range(n) if n_pending[i] == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        i = heapq.heappop(ready)
        order.append(i)
        for j in successors[i]:
            n_pending[j] -= 1
            if n_pending[j] == 0:
                heapq.heappush(ready, j)

    # A cycle is not expected in a computation graph. If one shows up anyway,
    # keep the leftover pairs in input order rather than dropping them.
    if len(order) < n:
        placed = set(order)
        order.extend(i for i in range(n) if i not in placed)

    return [pairs[i] for i in order]
Example #3
0
def test_ancestors():
    # Graph under test:  o2 = MyOp(r3, o1),  o1 = MyOp(r1, r2)
    r1 = MyVariable(1)
    r2 = MyVariable(2)
    r3 = MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"
    o2 = MyOp(r3, o1)
    o2.name = "o2"

    # Unblocked traversal visits every variable, depth-first.
    full_walk = list(ancestors([o2], blockers=None))
    assert full_walk == [o2, r3, o1, r1, r2]

    # The result is a lazy iterator: the membership test consumes it up to and
    # including r3, so only the remainder is left afterwards.
    lazy_walk = ancestors([o2], blockers=None)
    assert r3 in lazy_walk
    assert list(lazy_walk) == [o1, r1, r2]

    # A blocker is still yielded, but its inputs are not visited.
    blocked_walk = list(ancestors([o2], blockers=[o1]))
    assert blocked_walk == [o2, r3, o1]
Example #4
0
 def prefer_replace(self, replace_pairs, reason=None):
     """Move as many clients as possible from each r to new_r without creating a cycle.

     Parameters
     ----------
     replace_pairs : iterable of (r, new_r) pairs
         Handled in three steps:

         * put in dependency order with ``sort_replacements``;
         * pairs whose two members are the same object are dropped (this
           check happens before the ``self.equiv`` mapping);
         * each remaining member is mapped through ``self.equiv``, and kept
           as-is when it has no entry.
     reason : optional
         Passed unchanged to ``self.change_input``.

     A client ``(node, i)`` of ``r`` is left alone in two cases:

     * ``node`` is the ``'output'`` marker;
     * any output of ``node`` is an ancestor of ``new_r``, because moving it
       would create a cycle.
     """
     replacements = [
         (self.equiv.get(r, r), self.equiv.get(new_r, new_r))
         for (r, new_r) in sort_replacements(replace_pairs)
         if r is not new_r  # identity check on the un-mapped pair
     ]
     for r, new_r in replacements:
         new_ancestors = set(graph.ancestors([new_r]))
         # Iterate over a copy, since change_input presumably updates r.clients.
         for node, i in list(r.clients):
             if node == 'output':
                 # NOTE(review): graph outputs are never redirected by this
                 # method; confirm that this is intended.
                 continue
             if any(outvar in new_ancestors for outvar in node.outputs):
                 # If a client is in the ancestors of new_r, do not transfer
                 # it. It would create a cycle, and in the case of shape
                 # nodes it is not what we want either.
                 continue
             # Only non-'output' clients reach this point, so checking the
             # node's inputs is sufficient.
             assert node.inputs[i] is r
             self.change_input(node, i, new_r, reason=reason)
 def _ancestors(self, var, func, blockers=None):
     """Get ancestors of a function that are also named PyMC3 variables.

     Collects the variables yielded by ``ancestors([func], blockers=blockers)``
     that are members of ``self.var_list`` and compare unequal (``!=``) to
     ``var``. The result is returned as a set.
     """
     named = set()
     for candidate in ancestors([func], blockers=blockers):
         if candidate in self.var_list and candidate != var:
             named.add(candidate)
     return named
Example #6
0
 def _constant_parents(self, var, func):
     """Return the float values of the constant ancestors of ``func``.

     Every variable yielded by ``ancestors([func])`` for which ``is_constant``
     is true is evaluated with ``.eval()`` and converted to ``float``.
     Duplicate values collapse because the result is a set. ``var`` is not
     used by this method.
     """
     return {
         float(node.eval())
         for node in ancestors([func])
         if is_constant(node)
     }
Example #7
0
 def _ancestors(self, var, func, blockers=None):
     """Get ancestors of a function that are also named PyMC3 variables.

     Collects the variables yielded by ``ancestors([func], blockers=blockers)``
     that are members of ``self.var_list`` and compare unequal (``!=``) to
     ``var``. The result is returned as a set.
     """
     named = set()
     for candidate in ancestors([func], blockers=blockers):
         if candidate in self.var_list and candidate != var:
             named.add(candidate)
     return named
def get_parameter_dependencies(parameter):
    """Return the named variables that ``parameter`` depends on.

    Walks ``ancestors([parameter])`` and keeps every variable whose ``name``
    is not ``None``, in traversal order. ``parameter`` itself is skipped.
    """
    dependencies = []
    for node in ancestors([parameter]):
        # The name is checked before identity, matching the original order.
        if node.name is None or node is parameter:
            continue
        dependencies.append(node)
    return dependencies