def adjoint(self, red_op, bin_op, root, targets):
    """
    Run a reverse (adjoint) pass over the recorded tape, starting at ``root``.

    Entries are popped off ``self.tape`` one at a time as
    ``(output, fn, inputs)`` triples. Entries popped before the one whose
    ``output`` is ``root`` (compared by identity, ``is``) are discarded.
    From ``root`` onward, each entry's accumulated adjoint is pushed back to
    its inputs via :func:`adjoint_ops`. Contributions that reach the same
    funsor are combined with ``bin_op``.

    This consumes ``self.tape``: the loop runs until the tape is empty.
    If ``root`` is never found, every requested adjoint stays at the unit
    value described below.

    :param red_op: forwarded unchanged to :func:`adjoint_ops`; presumably the
        reduction (sum-like) op of the semiring -- confirm against callers.
    :param bin_op: binary op used to combine adjoint contributions. Funsors
        that receive no contribution (including ``root`` itself) have adjoint
        ``to_funsor(ops.UNITS[bin_op])``.
    :param root: the tape output from which backpropagation starts.
    :param targets: iterable of funsors whose adjoints are requested.
    :return: dict mapping each target to its accumulated adjoint. For targets
        that are not :class:`Variable` s, the adjoint is additionally combined
        with the target itself via ``bin_op``.
    :rtype: dict
    """
    # Any funsor not yet seen defaults to the unit of bin_op.
    bin_unit = to_funsor(ops.UNITS[bin_op])
    adjoint_values = defaultdict(lambda: bin_unit)

    reached_root = False
    while self.tape:
        output, fn, inputs = self.tape.pop()
        # Drop every entry popped before root; start propagating at root.
        if not reached_root:
            if output is root:
                reached_root = True
            else:
                continue

        # reverse the effects of alpha-renaming
        with interpretation(reflect):
            # Map each input named "<prefix>__BOUND..." back to a funsor built
            # from "<prefix>" with the same domain.
            other_subs = tuple((name, to_funsor(name.split("__BOUND")[0], domain))
                               for name, domain in output.inputs.items() if "__BOUND" in name)
            # Rebuild the term lazily (reflect), undo the renaming, then recover
            # its arguments. NOTE(review): assumes _alpha_unmangle returns the
            # term's argument tuple -- confirm against its definition.
            inputs = _alpha_unmangle(substitute(fn(*inputs), other_subs))
            output = type(output)(*_alpha_unmangle(substitute(output, other_subs)))

        # Push this output's accumulated adjoint back to each of its inputs.
        in_adjs = adjoint_ops(fn, red_op, bin_op, adjoint_values[output], *inputs)
        for v, adjv in in_adjs.items():
            adjoint_values[v] = bin_op(adjoint_values[v], adjv)

    target_adjs = {}
    for v in targets:
        target_adjs[v] = adjoint_values[v]
        # Non-Variable targets are additionally combined with themselves.
        if not isinstance(v, Variable):
            target_adjs[v] = bin_op(target_adjs[v], v)
    return target_adjs
def _alpha_convert(self, alpha_subs):
    """
    Apply the renaming ``alpha_subs`` to this term's bound variables.

    :param dict alpha_subs: maps old bound-variable names (all of which must
        be in ``self.bound``) to their replacements (presumably new names).
    :return: a ``(log_measure, integrand, reduced_vars)`` tuple in which the
        substitution has been applied to both funsors and to the names in
        ``reduced_vars``.
    """
    assert self.bound.issuperset(alpha_subs)
    renamed_reduced = frozenset(alpha_subs.get(var, var) for var in self.reduced_vars)

    def _domain_of(old_name):
        # Prefer the integrand's domain; otherwise use the log-measure's.
        return self.integrand.inputs.get(old_name, self.log_measure.inputs.get(old_name))

    funsor_subs = {}
    for old_name, new_name in alpha_subs.items():
        funsor_subs[old_name] = to_funsor(new_name, _domain_of(old_name))

    new_log_measure = substitute(self.log_measure, funsor_subs)
    new_integrand = substitute(self.integrand, funsor_subs)
    return new_log_measure, new_integrand, renamed_reduced
def _alpha_convert(self, alpha_subs):
    """
    Apply the renaming ``alpha_subs`` to this term's bound variables.

    :param dict alpha_subs: maps old bound-variable names (all of which must
        be in ``self.bound``) to their replacements (presumably new names).
    :return: a ``(sum_op, prod_op, trans, time, step, step_names)`` tuple.
        The renaming is applied to ``time``, to both sides of every ``step``
        pair, to the keys of ``step_names``, and to those inputs of ``trans``
        that it names. ``step`` and ``step_names`` are returned as frozensets
        of ``(key, value)`` pairs.
    """
    assert self.bound.issuperset(alpha_subs)

    def rename(name):
        return alpha_subs.get(name, name)

    new_time = Variable(rename(self.time.name), self.time.output)
    new_step = frozenset((rename(lhs), rename(rhs)) for lhs, rhs in self.step.items())
    # Only the keys of step_names are renamed; its values are kept as-is.
    new_step_names = frozenset((rename(lhs), rhs) for lhs, rhs in self.step_names.items())

    trans_inputs = self.trans.inputs
    # Only substitute names that trans actually has as inputs.
    trans_subs = {old: to_funsor(new, trans_inputs[old])
                  for old, new in alpha_subs.items()
                  if old in trans_inputs}
    new_trans = substitute(self.trans, trans_subs)
    return self.sum_op, self.prod_op, new_trans, new_time, new_step, new_step_names
def materialize(x):
    """
    Attempt to convert a Funsor to a :class:`~funsor.terms.Number` or
    :class:`Tensor` by substituting :func:`arange` s into its free variables.

    This is best-effort. Only free variables whose domain is a scalar integer
    (integer dtype, empty shape) are substituted. All other free variables,
    including integer-dtype variables with a non-empty shape, are left free.

    :param Funsor x: the funsor to materialize.
    :return: ``x`` unchanged if it is already a Number or Tensor, otherwise
        ``x`` with ``arange(name, size)`` substituted for each scalar integer
        free variable.
    :rtype: Funsor
    """
    assert isinstance(x, Funsor)
    if isinstance(x, (Number, Tensor)):
        return x
    # arange(name, size) yields a scalar-valued funsor indexed by ``name``, so
    # it is only a valid stand-in for scalar integer inputs. Substituting it
    # for a shaped integer input would give a value of the wrong shape.
    subs = tuple((name, arange(name, domain.dtype))
                 for name, domain in x.inputs.items()
                 if isinstance(domain.dtype, integer_types) and not domain.shape)
    return substitute(x, subs)
def materialize(x):
    """
    Attempt to convert a Funsor to a :class:`~funsor.terms.Number` or
    :class:`Array` by substituting :func:`arange` s into its free variables.

    :param Funsor x: the funsor to materialize. Every free variable must have
        a scalar integer domain (integer dtype, empty shape).
    :return: ``x`` unchanged if it is already a Number or Array, otherwise
        ``x`` with ``arange(name, size)`` substituted for each free variable.
    :rtype: Funsor
    :raises ValueError: if a free variable has a non-integer dtype or a
        non-empty shape.
    """
    assert isinstance(x, Funsor)
    if isinstance(x, (Number, Array)):
        return x
    subs = []
    for name, domain in x.inputs.items():
        if not isinstance(domain.dtype, integer_types):
            raise ValueError(
                'materialize() requires integer free variables but found '
                '"{}" of domain {}'.format(name, domain))
        # This is an explicit check rather than an ``assert``, so it still
        # fires under ``python -O``. arange() only produces scalar-valued
        # funsors, so a shaped input cannot be materialized this way.
        if domain.shape:
            raise ValueError(
                'materialize() requires scalar free variables but found '
                '"{}" of domain {}'.format(name, domain))
        subs.append((name, arange(name, domain.dtype)))
    return substitute(x, tuple(subs))