Code example #1
0
 def eager_reduce(self, op, reduced_vars):
     """Eagerly reduce this binary expression over ``reduced_vars``.

     When ``op`` matches this expression's own op, the reduction
     distributes over both operands; otherwise defer to the generic
     implementation (wrapped for debug logging).
     """
     if op is not self.op:
         fallback = interpreter.debug_logged(super(Binary, self).eager_reduce)
         return fallback(op, reduced_vars)
     reduced_lhs = self.lhs.reduce(op, reduced_vars)
     reduced_rhs = self.rhs.reduce(op, reduced_vars)
     return op(reduced_lhs, reduced_rhs)
Code example #2
0
    def sample(self, sampled_vars, sample_inputs=None):
        """
        Create a Monte Carlo approximation to this funsor by replacing
        functions of ``sampled_vars`` with :class:`~funsor.delta.Delta` s.

        The result is a :class:`Funsor` with the same ``.inputs`` and
        ``.output`` as the original funsor (plus ``sample_inputs`` if
        provided), so that self can be replaced by the sample in expectation
        computations::

            y = x.sample(sampled_vars)
            assert y.inputs == x.inputs
            assert y.output == x.output
            exact = (x.exp() * integrand).reduce(ops.add)
            approx = (y.exp() * integrand).reduce(ops.add)

        If ``sample_inputs`` is provided, this creates a batch of scaled
        samples.

        :param frozenset sampled_vars: A set of input variables to sample.
        :param OrderedDict sample_inputs: An optional mapping from variable
            name to :class:`~funsor.domains.Domain` over which samples will
            be batched.
        """
        assert self.output == reals()
        assert isinstance(sampled_vars, frozenset)
        if sample_inputs is None:
            sample_inputs = OrderedDict()
        assert isinstance(sample_inputs, OrderedDict)
        # Nothing to sample if none of the requested vars are inputs.
        if sampled_vars.isdisjoint(self.inputs):
            return self

        result = interpreter.debug_logged(self.unscaled_sample)(sampled_vars,
                                                                sample_inputs)
        # Scale down by the size of each batch dimension that sampling newly
        # introduced, so expectations under the sample match expectations
        # under self.  NOTE: sample_inputs was normalized to an OrderedDict
        # above, so the original ``if sample_inputs is not None`` guard here
        # was always true and has been removed.
        log_scale = 0
        for var, domain in sample_inputs.items():
            if var in result.inputs and var not in self.inputs:
                log_scale -= math.log(domain.dtype)
        if log_scale != 0:
            result += log_scale
        return result
Code example #3
0
def contractor(fn):
    """
    Decorator for contract implementations to simplify inputs.
    """
    # Wrap for debug logging, then pre-bind as the simplifier's callback.
    return functools.partial(_simplify_contract, interpreter.debug_logged(fn))
Code example #4
0
File: integrate.py — Project: fehiepsi/funsor
def integrator(fn):
    """
    Decorator for integration implementations.
    """
    # Wrap for debug logging, then pre-bind as the simplifier's callback.
    return functools.partial(_simplify_integrate, interpreter.debug_logged(fn))
Code example #5
0
File: terms.py — Project: lawrencechen0921/funsor
def moment_matching_reduce(op, arg, reduced_vars):
    """Dispatch to the argument's own moment-matching reduction,
    wrapped for debug logging."""
    method = interpreter.debug_logged(arg.moment_matching_reduce)
    return method(op, reduced_vars)
Code example #6
0
File: terms.py — Project: lawrencechen0921/funsor
def sequential_reduce(op, arg, reduced_vars):
    """Dispatch to the argument's own sequential reduction,
    wrapped for debug logging."""
    method = interpreter.debug_logged(arg.sequential_reduce)
    return method(op, reduced_vars)
Code example #7
0
File: terms.py — Project: lawrencechen0921/funsor
def eager_reduce(op, arg, reduced_vars):
    """Dispatch to the argument's own eager reduction,
    wrapped for debug logging."""
    method = interpreter.debug_logged(arg.eager_reduce)
    return method(op, reduced_vars)
Code example #8
0
File: terms.py — Project: lawrencechen0921/funsor
def eager_unary(op, arg):
    """Apply a unary op eagerly.

    Arguments with an empty (scalar) output shape are returned unchanged;
    otherwise dispatch to the argument's own implementation, wrapped for
    debug logging.
    """
    if arg.output.shape:
        return interpreter.debug_logged(arg.eager_unary)(op)
    return arg
Code example #9
0
File: terms.py — Project: lawrencechen0921/funsor
def eager_unary(op, arg):
    """Dispatch to the argument's own eager unary implementation,
    wrapped for debug logging."""
    logged = interpreter.debug_logged(arg.eager_unary)
    return logged(op)
Code example #10
0
File: terms.py — Project: lawrencechen0921/funsor
 def subs_interpreter(cls, *args):
     """Construct ``cls(*args)`` and then eagerly apply the pending
     substitutions whose names are fresh in the result.

     NOTE(review): ``subs`` is a closure variable from the enclosing
     scope — confirm against the surrounding definition.
     """
     result = cls(*args)
     pending = tuple((name, value) for name, value in subs
                     if name in result.fresh)
     if pending:
         result = interpreter.debug_logged(result.eager_subs)(pending)
     return result
Code example #11
0
File: joint.py — Project: fehiepsi/funsor
def _joint_integrator(fn):
    """
    Decorator for Integrate(Joint(...), ...) patterns.
    """
    # NOTE(review): ``integrator`` itself already wraps its argument with
    # interpreter.debug_logged and functools.partial(_simplify_integrate, ...),
    # so ``fn`` ends up debug-logged and simplified twice here — confirm
    # whether the double wrapping is intentional.
    fn = interpreter.debug_logged(fn)
    return integrator(functools.partial(_simplify_integrate, fn))
Code example #12
0
File: adjoint.py — Project: tessythomas123/funsor
            target_adjs[v] = adjoint_values[v]
            if not isinstance(v, Variable):
                target_adjs[v] = bin_op(target_adjs[v], v)

        return target_adjs


# logaddexp/add
def _fail_default(*args):
    raise NotImplementedError("Should not be here! {}".format(args))


# Registry mapping term/op patterns to adjoint rules; unregistered
# patterns fall through to _fail_default, which raises.
adjoint_ops = KeyedRegistry(default=_fail_default)
# In debug mode, monkey-patch register so every registered rule is
# wrapped in interpreter.debug_logged before registration.
if interpreter._DEBUG:
    adjoint_ops_register = adjoint_ops.register
    adjoint_ops.register = lambda *args: lambda fn: adjoint_ops_register(*args)(interpreter.debug_logged(fn))


@adjoint_ops.register(Tensor, AssociativeOp, AssociativeOp, Funsor, torch.Tensor, tuple, object)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
    # Raw tensor data records no adjoint contributions of its own, so the
    # adjoint mapping is empty.
    return {}


@adjoint_ops.register(Binary, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, Funsor)
def adjoint_binary(adj_redop, adj_binop, out_adj, op, lhs, rhs):
    assert (adj_redop, op) in ops.DISTRIBUTIVE_OPS

    lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
    lhs_adj = op(out_adj, rhs).reduce(adj_redop, lhs_reduced_vars)

    rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
Code example #13
0
            if not isinstance(v, Variable):
                target_adjs[v] = bin_op(target_adjs[v], v)

        return target_adjs


# logaddexp/add
def _fail_default(*args):
    raise NotImplementedError("Should not be here! {}".format(args))


# Registry mapping term/op patterns to adjoint rules; unregistered
# patterns fall through to _fail_default, which raises.
adjoint_ops = KeyedRegistry(default=_fail_default)
# In debug mode, monkey-patch register so every registered rule is
# wrapped in interpreter.debug_logged before registration.
if interpreter._DEBUG:
    adjoint_ops_register = adjoint_ops.register
    adjoint_ops.register = lambda *args: lambda fn: adjoint_ops_register(
        *args)(interpreter.debug_logged(fn))


@adjoint_ops.register(Tensor, AssociativeOp, AssociativeOp, Funsor,
                      (np.ndarray, np.generic), tuple, object)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
    # Raw array data records no adjoint contributions of its own, so the
    # adjoint mapping is empty.
    return {}


@adjoint_ops.register(Binary, AssociativeOp, AssociativeOp, Funsor,
                      AssociativeOp, Funsor, Funsor)
def adjoint_binary(adj_redop, adj_binop, out_adj, op, lhs, rhs):
    assert (adj_redop, op) in ops.DISTRIBUTIVE_OPS

    lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
    lhs_adj = op(out_adj, rhs).reduce(adj_redop, lhs_reduced_vars)