示例#1
0
def dispatched_interpretation(fn):
    """
    Decorator to create a dispatched interpretation function.

    Attaches a fresh ``KeyedRegistry`` (constructed with a ``default`` that
    returns ``None`` for any arguments) to ``fn`` and exposes two hooks:

    * ``fn.register(*args)`` -- returns a decorator that registers an
      implementation in that registry. If ``_DEBUG`` is truthy when
      ``dispatched_interpretation`` runs, each implementation is wrapped
      with ``debug_logged`` before it is registered.
    * ``fn.dispatch`` -- the registry's ``dispatch`` method.

    ``fn`` itself is returned; only the two attributes above are added.
    """
    registry = KeyedRegistry(default=lambda *args: None)

    if not _DEBUG:
        fn.register = registry.register
    else:
        def register_logged(*args):
            # Same calling shape as ``registry.register(*args)(impl)``, but the
            # registry call is deferred until an implementation is supplied,
            # and that implementation is wrapped with ``debug_logged`` first.
            def decorator(impl):
                return registry.register(*args)(debug_logged(impl))
            return decorator

        fn.register = register_logged

    fn.dispatch = registry.dispatch
    return fn
示例#2
0
文件: delta.py 项目: MillerJJY/funsor
    :param Funsor expr: An expression with a free variable.
    :param Funsor value: A target value.
    :return: A tuple ``(name, point, log_abs_det_jacobian)``
    :rtype: tuple
    :raises: ValueError
    """
    assert isinstance(expr, Funsor)
    assert isinstance(value, Funsor)
    result = solve.dispatch(type(expr), *(expr._ast_values + (value, )))
    if result is None:
        raise ValueError("Cannot substitute into a Delta: {}".format(value))
    return result


# Registry backing ``solve``. Its fallback (first argument) returns None, and
# ``solve`` raises ValueError whenever it gets None back.
_solve = KeyedRegistry(lambda *args: None)
# NOTE: ``solve.dispatch`` is bound to ``_solve.__call__`` rather than to a
# registry ``dispatch`` method. ``solve`` uses the value returned by
# ``solve.dispatch(...)`` directly as its result tuple. Presumably ``__call__``
# invokes the matched implementation instead of merely looking it up.
solve.dispatch, solve.register = _solve.__call__, _solve.register


@solve.register(Variable, str, Domain, Funsor)
@debug_logged
def solve_variable(name, output, y):
    """
    ``solve`` rule for a bare ``Variable(name, output)`` expression.

    The target value ``y`` is itself the solution point, and the third
    component is the constant ``Number(0)``.

    :param str name: Name of the free variable.
    :param Domain output: Output domain of the variable; must equal ``y.output``.
    :param Funsor y: Target value.
    :return: ``(name, y, Number(0))``.
    """
    assert y.output == output
    return name, y, Number(0)


@solve.register(Unary, TransformOp, Funsor, Funsor)
@debug_logged
示例#3
0
            adjoint_values[v] = adjoint_values[v] + adjv  # product in logspace

    target_adjs = {}
    for v in targets:
        target_adjs[v] = adjoint_values[v] / multiplicities[v]
        if not isinstance(v, Variable):
            target_adjs[v] = target_adjs[v] + v
    return target_adjs


# logaddexp/add
def _fail_default(*args):
    raise NotImplementedError("Should not be here! {}".format(args))


# Registry of adjoint rules. Its ``default`` is ``_fail_default``, which raises
# NotImplementedError (presumably reached when no registered rule matches).
adjoint_ops = KeyedRegistry(default=_fail_default)


@adjoint_ops.register(Tensor, Funsor, Funsor, torch.Tensor, tuple, object)
def adjoint_tensor(out_adj, out, data, inputs, dtype):
    """
    Adjoint rule for a ``Tensor`` node.

    For each input ``(name, domain)`` of the tensor, the adjoint of
    ``Variable(name, domain)`` is ``out_adj + out`` reduced with
    ``ops.logaddexp`` over every *other* input variable.

    :param Funsor out_adj: Adjoint flowing into this node.
    :param Funsor out: The node's own value.
    :param data: Underlying tensor data (unused; part of the registered signature).
    :param tuple inputs: ``(name, domain)`` pairs of the tensor's inputs.
    :param dtype: Output dtype (unused; part of the registered signature).
    :return: dict mapping ``Variable(name, domain)`` to its adjoint.
    """
    if not inputs:
        # Nothing to propagate; skip building ``out_adj + out`` entirely.
        return {}
    all_vars = frozenset(name for name, _ in inputs)
    # Loop-invariant: build the combined term once instead of once per input.
    joint = out_adj + out
    return {
        Variable(name, domain): joint.reduce(ops.logaddexp, all_vars - {name})
        for name, domain in inputs
    }


@adjoint_ops.register(Binary, Funsor, Funsor, AssociativeOp, Funsor, Funsor)
def adjoint_binary(out_adj, out, op, lhs, rhs):
    assert op is ops.add
示例#4
0
 def __init__(cls, name, bases, dct):
     """
     Initialize a newly created class and give it its own dispatch registry.

     Judging by the ``(cls, name, bases, dct)`` signature this is a metaclass
     ``__init__``. After the base initializer runs, the class gets a fresh
     ``KeyedRegistry`` (whose ``default`` returns ``None``) as
     ``cls.registry``, and ``cls.dispatch`` is bound to that registry's
     ``dispatch`` method.
     """
     super().__init__(name, bases, dct)
     registry = KeyedRegistry(default=lambda *args: None)
     cls.registry = registry
     cls.dispatch = registry.dispatch
示例#5
0
        target_adjs = {}
        for v in targets:
            target_adjs[v] = adjoint_values[v]
            if not isinstance(v, Variable):
                target_adjs[v] = bin_op(target_adjs[v], v)

        return target_adjs


# logaddexp/add
def _fail_default(*args):
    raise NotImplementedError("Should not be here! {}".format(args))


# Registry of adjoint rules. Its ``default`` is ``_fail_default``, which raises
# NotImplementedError (presumably reached when no registered rule matches).
adjoint_ops = KeyedRegistry(default=_fail_default)
if interpreter._DEBUG:
    # Debug mode (flag read once, when this module-level code runs):
    # capture the original ``register`` BEFORE overwriting it. The replacement
    # keeps the ``register(*args)(fn)`` calling shape, but wraps each rule with
    # ``interpreter.debug_logged`` before handing it to the original
    # ``register``. The inner lambda's ``fn`` is the rule being decorated.
    adjoint_ops_register = adjoint_ops.register
    adjoint_ops.register = lambda *args: lambda fn: adjoint_ops_register(*args)(interpreter.debug_logged(fn))


@adjoint_ops.register(Tensor, AssociativeOp, AssociativeOp, Funsor, torch.Tensor, tuple, object)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
    """
    Adjoint rule for a ``Tensor`` node: contributes no adjoints.

    Every argument is ignored. A new, empty mapping is returned on each call.
    """
    no_adjoints = dict()
    return no_adjoints


@adjoint_ops.register(Binary, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, Funsor)
def adjoint_binary(adj_redop, adj_binop, out_adj, op, lhs, rhs):
    assert (adj_redop, op) in ops.DISTRIBUTIVE_OPS

    lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)