Esempio n. 1
0
    def adjoint(self, red_op, bin_op, root, targets):
        """Run a backward pass over the recorded tape and collect adjoints.

        Entries are popped off the end of ``self.tape`` until it is empty, so
        the tape is consumed by this call.  Entries popped before the one whose
        output ``is root`` are discarded unprocessed; if ``root`` is never
        found, nothing is accumulated.  For each processed entry, the adjoint
        contributions returned by ``adjoint_ops`` are folded into a running
        value per funsor with ``bin_op``, starting from ``bin_op``'s unit.

        :param red_op: reduction op, forwarded unchanged to ``adjoint_ops``.
        :param bin_op: binary op used to accumulate contributions; the default
            adjoint of every funsor is ``to_funsor(ops.UNITS[bin_op])``.
        :param root: funsor whose tape entry starts the backward pass
            (matched by identity).
        :param targets: iterable of funsors whose adjoints are requested.
        :return: dict mapping each target to its accumulated adjoint; for a
            target that is not a ``Variable`` the adjoint is further combined
            with the target itself via ``bin_op``.
        """
        unit = to_funsor(ops.UNITS[bin_op])
        adjoints = defaultdict(lambda: unit)

        started = False
        while self.tape:
            out, fn, args = self.tape.pop()
            # Ignore entries until root is seen; afterwards process every entry.
            if not started and out is not root:
                continue
            started = True

            # reverse the effects of alpha-renaming: names containing
            # "__BOUND" are mapped back to the prefix before that marker
            with interpretation(reflect):
                unmangle_subs = tuple(
                    (name, to_funsor(name.split("__BOUND")[0], domain))
                    for name, domain in out.inputs.items()
                    if "__BOUND" in name
                )
                args = _alpha_unmangle(substitute(fn(*args), unmangle_subs))
                out = type(out)(*_alpha_unmangle(substitute(out, unmangle_subs)))

            contributions = adjoint_ops(fn, red_op, bin_op, adjoints[out], *args)
            for node, contribution in contributions.items():
                adjoints[node] = bin_op(adjoints[node], contribution)

        result = {}
        for tgt in targets:
            adj = adjoints[tgt]
            result[tgt] = adj if isinstance(tgt, Variable) else bin_op(adj, tgt)
        return result
Esempio n. 2
0
 def _alpha_convert(self, alpha_subs):
     """Rename bound variables according to ``alpha_subs``.

     :param alpha_subs: dict mapping old bound-variable names to new names;
         every key must belong to ``self.bound``.
     :return: tuple ``(log_measure, integrand, reduced_vars)`` with the
         renaming substituted into both funsors and applied to the set of
         reduced variable names.
     """
     assert self.bound.issuperset(alpha_subs)
     renamed_vars = frozenset(alpha_subs.get(name, name) for name in self.reduced_vars)
     # Build replacement funsors.  Each new name takes its domain from the
     # integrand's inputs, falling back to the log_measure's inputs.
     replacements = {}
     for old_name, new_name in alpha_subs.items():
         domain = self.integrand.inputs.get(old_name, self.log_measure.inputs.get(old_name))
         replacements[old_name] = to_funsor(new_name, domain)
     new_log_measure = substitute(self.log_measure, replacements)
     new_integrand = substitute(self.integrand, replacements)
     return new_log_measure, new_integrand, renamed_vars
Esempio n. 3
0
 def _alpha_convert(self, alpha_subs):
     """Rename bound variables according to ``alpha_subs``.

     :param alpha_subs: dict mapping old bound-variable names to new names;
         every key must belong to ``self.bound``.
     :return: tuple ``(sum_op, prod_op, trans, time, step, step_names)``.
         The renaming is applied to the time variable's name, to both sides
         of every ``step`` pair, and to the keys (only) of ``step_names``.
         It is substituted into ``trans`` for the names that occur in
         ``trans.inputs``.
     """
     assert self.bound.issuperset(alpha_subs)

     def rename(name):
         return alpha_subs.get(name, name)

     time = Variable(rename(self.time.name), self.time.output)
     step = frozenset((rename(left), rename(right))
                      for left, right in self.step.items())
     step_names = frozenset((rename(key), value)
                            for key, value in self.step_names.items())
     # Only names actually present in trans need a substitution there.
     trans_inputs = self.trans.inputs
     trans_subs = {old: to_funsor(new, trans_inputs[old])
                   for old, new in alpha_subs.items()
                   if old in trans_inputs}
     trans = substitute(self.trans, trans_subs)
     return self.sum_op, self.prod_op, trans, time, step, step_names
Esempio n. 4
0
def materialize(x):
    """
    Attempt to convert a Funsor to a :class:`~funsor.terms.Number` or
    :class:`Tensor` by substituting :func:`arange` s into its free variables.

    Only free variables whose ``domain.dtype`` is an integer are replaced.
    Other inputs are left free, in which case the result may not be a
    ``Number`` or ``Tensor``.

    :param Funsor x: the funsor to materialize.
    :return: ``x`` itself if it is already a ``Number`` or ``Tensor``,
        otherwise the result of substituting the aranges into ``x``.
    """
    assert isinstance(x, Funsor)
    if isinstance(x, (Number, Tensor)):
        return x
    subs = tuple((name, arange(name, domain.dtype))
                 for name, domain in x.inputs.items()
                 if isinstance(domain.dtype, integer_types))
    return substitute(x, subs)
Esempio n. 5
0
def materialize(x):
    """
    Attempt to convert a Funsor to a :class:`~funsor.terms.Number` or
    :class:`Array` by substituting :func:`arange` s into its free variables.

    :param Funsor x: the funsor to materialize.
    :return: ``x`` itself if it is already a ``Number`` or ``Array``,
        otherwise the result of substituting, for every free variable, an
        ``arange`` over that variable's integer dtype.
    :raises ValueError: if any free variable's domain does not have an
        integer dtype, or has a non-empty shape.
    """
    assert isinstance(x, Funsor)
    if isinstance(x, (Number, Array)):
        return x
    subs = []
    for name, domain in x.inputs.items():
        if not isinstance(domain.dtype, integer_types):
            raise ValueError(
                'materialize() requires integer free variables but found '
                '"{}" of domain {}'.format(name, domain))
        # arange(name, dtype) is given no shape information, so a shaped
        # domain cannot be enumerated this way.  Raise instead of asserting,
        # so the check is not stripped under ``python -O``.
        if domain.shape:
            raise ValueError(
                'materialize() requires scalar free variables but found '
                '"{}" of domain {}'.format(name, domain))
        subs.append((name, arange(name, domain.dtype)))
    subs = tuple(subs)
    return substitute(x, subs)