def eager_reduce(self, op, reduced_vars):
    """
    Eagerly reduce this binary expression over ``reduced_vars``.

    When ``op`` is the very same operator that combines this expression's
    two operands, each operand is reduced on its own with ``op`` and the
    two partial results are recombined with ``op``. For any other operator
    the parent-class implementation is used, wrapped in
    ``interpreter.debug_logged``.

    NOTE(review): if one operand does not depend on some of
    ``reduced_vars``, splitting the reduction like this only matches
    reducing the combined expression when that operand's ``reduce``
    accounts for the missing variables (or ``op`` is idempotent, e.g.
    max/min) -- confirm.
    """
    same_op = op is self.op
    if not same_op:
        return interpreter.debug_logged(super(Binary, self).eager_reduce)(op, reduced_vars)
    # Reduce the left operand first, then the right, as the original did.
    reduced_lhs = self.lhs.reduce(op, reduced_vars)
    reduced_rhs = self.rhs.reduce(op, reduced_vars)
    return op(reduced_lhs, reduced_rhs)
def sample(self, sampled_vars, sample_inputs=None):
    """
    Create a Monte Carlo approximation to this funsor by replacing
    functions of ``sampled_vars`` with :class:`~funsor.delta.Delta` s.

    The result is a :class:`Funsor` with the same ``.inputs`` and
    ``.output`` as the original funsor (plus ``sample_inputs`` if
    provided), so that self can be replaced by the sample in expectation
    computations::

        y = x.sample(sampled_vars)
        assert y.inputs == x.inputs
        assert y.output == x.output
        exact = (x.exp() * integrand).reduce(ops.add)
        approx = (y.exp() * integrand).reduce(ops.add)

    If ``sample_inputs`` is provided, this creates a batch of scaled
    samples. Each batch variable that sampling newly introduces (present in
    the result but not in ``self``) adds ``-log(domain.dtype)`` to the
    result.

    :param frozenset sampled_vars: A set of input variables to sample.
    :param OrderedDict sample_inputs: An optional mapping from variable name
        to :class:`~funsor.domains.Domain` over which samples will be batched.
    :return: ``self`` unchanged if none of ``sampled_vars`` is an input of
        this funsor; otherwise the result of ``self.unscaled_sample``,
        shifted by the accumulated log scale when that scale is non-zero.
    """
    assert self.output == reals()
    assert isinstance(sampled_vars, frozenset)
    if sample_inputs is None:
        sample_inputs = OrderedDict()
    assert isinstance(sample_inputs, OrderedDict)
    if sampled_vars.isdisjoint(self.inputs):
        return self

    result = interpreter.debug_logged(self.unscaled_sample)(sampled_vars, sample_inputs)

    # sample_inputs is guaranteed to be an OrderedDict at this point, so the
    # former ``is not None`` guard around this loop was dead code.
    # NOTE(review): assumes ``domain.dtype`` is the integer size of a bounded
    # domain, so that subtracting its log divides each sample's weight by the
    # batch size -- confirm.
    log_scale = 0
    for var, domain in sample_inputs.items():
        if var in result.inputs and var not in self.inputs:
            log_scale -= math.log(domain.dtype)
    if log_scale != 0:
        result += log_scale
    return result
def contractor(fn):
    """
    Decorator for contract implementations to simplify inputs.

    The implementation ``fn`` is first wrapped with
    ``interpreter.debug_logged``. It is then bound as the leading positional
    argument of ``_simplify_contract``, and that partial application is
    returned.
    """
    logged_impl = interpreter.debug_logged(fn)
    return functools.partial(_simplify_contract, logged_impl)
def integrator(fn):
    """
    Decorator for integration implementations.

    The implementation ``fn`` is first wrapped with
    ``interpreter.debug_logged``. It is then bound as the leading positional
    argument of ``_simplify_integrate``, and that partial application is
    returned.
    """
    logged_impl = interpreter.debug_logged(fn)
    return functools.partial(_simplify_integrate, logged_impl)
def moment_matching_reduce(op, arg, reduced_vars):
    """
    Delegate a reduction to ``arg.moment_matching_reduce``.

    The bound method is wrapped with ``interpreter.debug_logged`` and then
    called with ``(op, reduced_vars)``. Its result is returned unchanged.
    """
    logged_method = interpreter.debug_logged(arg.moment_matching_reduce)
    return logged_method(op, reduced_vars)
def sequential_reduce(op, arg, reduced_vars):
    """
    Delegate a reduction to ``arg.sequential_reduce``.

    The bound method is wrapped with ``interpreter.debug_logged`` and then
    called with ``(op, reduced_vars)``. Its result is returned unchanged.
    """
    logged_method = interpreter.debug_logged(arg.sequential_reduce)
    return logged_method(op, reduced_vars)
def eager_reduce(op, arg, reduced_vars):
    """
    Delegate a reduction to ``arg.eager_reduce``.

    The bound method is wrapped with ``interpreter.debug_logged`` and then
    called with ``(op, reduced_vars)``. Its result is returned unchanged.
    """
    logged_method = interpreter.debug_logged(arg.eager_reduce)
    return logged_method(op, reduced_vars)
def eager_unary(op, arg):
    """
    Apply the unary ``op`` to ``arg`` eagerly.

    If ``arg.output.shape`` is empty (a scalar output), ``arg`` is returned
    as-is without applying ``op``. Otherwise ``arg.eager_unary`` is wrapped
    with ``interpreter.debug_logged`` and called with ``op``.

    NOTE(review): the scalar short-circuit presumably exists because ``op``
    acts over the output shape and is a no-op on scalars -- confirm against
    the op this is registered for.
    """
    if arg.output.shape:
        return interpreter.debug_logged(arg.eager_unary)(op)
    return arg
def eager_unary(op, arg):
    """
    Delegate the unary ``op`` to ``arg.eager_unary``.

    The bound method is wrapped with ``interpreter.debug_logged`` before it
    is called. Its result is returned unchanged.
    """
    logged_method = interpreter.debug_logged(arg.eager_unary)
    return logged_method(op)
def subs_interpreter(cls, *args):
    """
    Construct ``cls(*args)`` and eagerly apply any pending substitutions
    that target the new expression's fresh variables.

    ``subs`` is not a parameter. It is read from an enclosing scope that is
    not visible here (presumably an iterable of ``(name, value)`` pairs --
    confirm). Only pairs whose name is in ``expr.fresh`` are applied, through
    ``interpreter.debug_logged(expr.eager_subs)``. If none match, the freshly
    built expression is returned as-is.
    """
    expr = cls(*args)
    matching = []
    for name, value in subs:
        if name in expr.fresh:
            matching.append((name, value))
    fresh_subs = tuple(matching)
    if not fresh_subs:
        return expr
    return interpreter.debug_logged(expr.eager_subs)(fresh_subs)
def _joint_integrator(fn):
    """
    Decorator for Integrate(Joint(...), ...) patterns.

    The steps are:

    1. Wrap ``fn`` with ``interpreter.debug_logged``.
    2. Bind it as the leading argument of ``_simplify_integrate``.
    3. Pass that partial application through ``integrator`` and return
       whatever ``integrator`` produces.

    NOTE(review): if ``integrator`` is the decorator that itself wraps its
    argument in ``_simplify_integrate``, the simplification layer is applied
    twice. Confirm this is intended, or that the two names refer to
    different modules' helpers.
    """
    logged_impl = interpreter.debug_logged(fn)
    simplified = functools.partial(_simplify_integrate, logged_impl)
    return integrator(simplified)
target_adjs[v] = adjoint_values[v] if not isinstance(v, Variable): target_adjs[v] = bin_op(target_adjs[v], v) return target_adjs # logaddexp/add def _fail_default(*args): raise NotImplementedError("Should not be here! {}".format(args)) adjoint_ops = KeyedRegistry(default=_fail_default) if interpreter._DEBUG: adjoint_ops_register = adjoint_ops.register adjoint_ops.register = lambda *args: lambda fn: adjoint_ops_register(*args)(interpreter.debug_logged(fn)) @adjoint_ops.register(Tensor, AssociativeOp, AssociativeOp, Funsor, torch.Tensor, tuple, object) def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype): return {} @adjoint_ops.register(Binary, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, Funsor) def adjoint_binary(adj_redop, adj_binop, out_adj, op, lhs, rhs): assert (adj_redop, op) in ops.DISTRIBUTIVE_OPS lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs) lhs_adj = op(out_adj, rhs).reduce(adj_redop, lhs_reduced_vars) rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
if not isinstance(v, Variable): target_adjs[v] = bin_op(target_adjs[v], v) return target_adjs # logaddexp/add def _fail_default(*args): raise NotImplementedError("Should not be here! {}".format(args)) adjoint_ops = KeyedRegistry(default=_fail_default) if interpreter._DEBUG: adjoint_ops_register = adjoint_ops.register adjoint_ops.register = lambda *args: lambda fn: adjoint_ops_register( *args)(interpreter.debug_logged(fn)) @adjoint_ops.register(Tensor, AssociativeOp, AssociativeOp, Funsor, (np.ndarray, np.generic), tuple, object) def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype): return {} @adjoint_ops.register(Binary, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, Funsor) def adjoint_binary(adj_redop, adj_binop, out_adj, op, lhs, rhs): assert (adj_redop, op) in ops.DISTRIBUTIVE_OPS lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs) lhs_adj = op(out_adj, rhs).reduce(adj_redop, lhs_reduced_vars)