def dispatched_interpretation(fn):
    """
    Decorator to create a dispatched interpretation function.

    Creates a private ``KeyedRegistry`` whose default is a function that
    returns ``None``. It then attaches two attributes to ``fn`` and returns
    ``fn`` itself:

    * ``fn.register``: the registry's ``register``. When ``_DEBUG`` is true,
      it is a wrapper that passes each handler through ``debug_logged``
      before registering it.
    * ``fn.dispatch``: the registry's ``dispatch``.

    ``_DEBUG`` is read once, when the decorator is applied, not on each
    registration.
    """
    registry = KeyedRegistry(default=lambda *args: None)

    if not _DEBUG:
        fn.register = registry.register
    else:
        def register_logged(*args):
            # Same calling shape as registry.register(*args)(handler), but the
            # handler is wrapped with debug_logged before it is stored.
            def decorator(handler):
                return registry.register(*args)(debug_logged(handler))
            return decorator

        fn.register = register_logged

    fn.dispatch = registry.dispatch
    return fn
target_adjs[v] = adjoint_values[v] if not isinstance(v, Variable): target_adjs[v] = bin_op(target_adjs[v], v) return target_adjs # logaddexp/add def _fail_default(*args): raise NotImplementedError("Should not be here! {}".format(args)) adjoint_ops = KeyedRegistry(default=_fail_default) if interpreter._DEBUG: adjoint_ops_register = adjoint_ops.register adjoint_ops.register = lambda *args: lambda fn: adjoint_ops_register(*args)(interpreter.debug_logged(fn)) @adjoint_ops.register(Tensor, AssociativeOp, AssociativeOp, Funsor, torch.Tensor, tuple, object) def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype): return {} @adjoint_ops.register(Binary, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, Funsor) def adjoint_binary(adj_redop, adj_binop, out_adj, op, lhs, rhs): assert (adj_redop, op) in ops.DISTRIBUTIVE_OPS lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs) lhs_adj = op(out_adj, rhs).reduce(adj_redop, lhs_reduced_vars) rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)