def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
    for i, v in enumerate(terms):
        if not isinstance(v, Contraction):
            continue

        if v.red_op is nullop and (v.bin_op, bin_op) in DISTRIBUTIVE_OPS:
            # a * e * (b + c + d) -> (a * e * b) + (a * e * c) + (a * e * d)
            new_terms = tuple(
                Contraction(v.red_op, bin_op, v.reduced_vars,
                            *(terms[:i] + (vt,) + terms[i + 1:]))
                for vt in v.terms)
            return Contraction(red_op, v.bin_op, reduced_vars, *new_terms)

        if red_op in (v.red_op, nullop) and (v.red_op, bin_op) in DISTRIBUTIVE_OPS:
            new_terms = terms[:i] + \
                (Contraction(v.red_op, v.bin_op, frozenset(), *v.terms),) + \
                terms[i + 1:]
            return Contraction(v.red_op, bin_op, v.reduced_vars,
                               *new_terms).reduce(red_op, reduced_vars)

        if v.red_op in (red_op, nullop) and bin_op in (v.bin_op, nullop):
            red_op = v.red_op if red_op is nullop else red_op
            bin_op = v.bin_op if bin_op is nullop else bin_op
            new_terms = terms[:i] + v.terms + terms[i + 1:]
            return Contraction(red_op, bin_op,
                               reduced_vars | v.reduced_vars, *new_terms)

    return None
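
# A standalone sketch (plain Python, not the funsor API) of the first rewrite
# above: distributing the outer product op over an inner sum-like Contraction
# turns a * e * (b + c + d) into (a * e * b) + (a * e * c) + (a * e * d), after
# which each flat product can be handled independently. The helper name
# `distribute` is hypothetical.
import operator
from functools import reduce

def distribute(mul_op, add_op, other_terms, sum_terms):
    # other * (t0 + t1 + ...)  ->  other*t0 + other*t1 + ...
    products = (reduce(mul_op, other_terms + (t,)) for t in sum_terms)
    return reduce(add_op, products)

assert distribute(operator.mul, operator.add, (5, 7), (2, 3, 4)) == 5 * 7 * (2 + 3 + 4)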
def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms):
    if red_op is nullop or bin_op is nullop or \
            (red_op, bin_op) not in DISTRIBUTIVE_OPS:
        return None

    # build opt_einsum optimizer IR
    inputs = [frozenset(term.inputs) for term in terms]
    size_dict = {
        k: ((REAL_SIZE * v.num_elements) if v.dtype == 'real' else v.dtype)
        for term in terms
        for k, v in term.inputs.items()
    }
    outputs = frozenset().union(*inputs) - reduced_vars

    # optimize path with greedy opt_einsum optimizer
    # TODO switch to new 'auto' strategy
    path = greedy(inputs, outputs, size_dict)

    # first prepare a reduce_dim counter to avoid early reduction
    reduce_dim_counter = collections.Counter()
    for input in inputs:
        reduce_dim_counter.update({d: 1 for d in input})

    operands = list(terms)
    for (a, b) in path:
        b, a = tuple(sorted((a, b), reverse=True))
        tb = operands.pop(b)
        ta = operands.pop(a)

        # don't reduce a dimension too early - keep a collections.Counter
        # and only reduce when the dimension is removed from all lhs terms in path
        reduce_dim_counter.subtract(
            {d: 1 for d in reduced_vars & frozenset(ta.inputs.keys())})
        reduce_dim_counter.subtract(
            {d: 1 for d in reduced_vars & frozenset(tb.inputs.keys())})

        # reduce variables that don't appear in other terms
        both_vars = frozenset(ta.inputs.keys()) | frozenset(tb.inputs.keys())
        path_end_reduced_vars = frozenset(d for d in reduced_vars & both_vars
                                          if reduce_dim_counter[d] == 0)

        # count new appearance of variables that aren't reduced
        reduce_dim_counter.update(
            {d: 1 for d in reduced_vars & (both_vars - path_end_reduced_vars)})

        path_end = Contraction(red_op if path_end_reduced_vars else nullop,
                               bin_op, path_end_reduced_vars, ta, tb)
        operands.append(path_end)

    # reduce any remaining dims, if necessary
    final_reduced_vars = frozenset(
        d for (d, count) in reduce_dim_counter.items() if count > 0) & reduced_vars
    if final_reduced_vars:
        path_end = path_end.reduce(red_op, final_reduced_vars)
    return path_end
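
# A standalone numpy replay of the loop above, using opt_einsum's greedy
# planner (the same `greedy` imported by this module). The arrays, index names,
# and the `keep` heuristic are illustrative, not funsor internals: we pop two
# operands per path step, contract them, and keep any index still needed by the
# output or by a remaining operand, mirroring the reduce_dim_counter logic.
import numpy as np
from opt_einsum.paths import greedy

a = np.random.rand(4, 5)   # indices "ij"
b = np.random.rand(5, 6)   # indices "jk"
c = np.random.rand(6, 3)   # indices "kl"
operands = [("ij", a), ("jk", b), ("kl", c)]
output = frozenset("il")   # "j" and "k" play the role of reduced_vars
size_dict = {"i": 4, "j": 5, "k": 6, "l": 3}

path = greedy([frozenset(s) for s, _ in operands], output, size_dict)

for i, j in path:
    j, i = sorted((i, j), reverse=True)
    (sj, tj), (si, ti) = operands.pop(j), operands.pop(i)
    # keep an index iff the output or some remaining operand still needs it
    needed = output.union(*(frozenset(s) for s, _ in operands))
    keep = "".join(sorted(set(si + sj) & needed))
    operands.append((keep, np.einsum(f"{si},{sj}->{keep}", ti, tj)))

assert np.allclose(operands[0][1], np.einsum("ij,jk,kl->il", a, b, c))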
def adjoint_contract(adj_redop, adj_binop, out_adj, sum_op, prod_op,
                     reduced_vars, lhs, rhs):
    assert sum_op is nullop or (sum_op, prod_op) in ops.DISTRIBUTIVE_OPS

    lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
    lhs_adj = Contraction(sum_op if sum_op is not nullop else adj_redop,
                          prod_op, lhs_reduced_vars, out_adj, rhs)

    rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
    rhs_adj = Contraction(sum_op if sum_op is not nullop else adj_redop,
                          prod_op, rhs_reduced_vars, out_adj, lhs)

    return {lhs: lhs_adj, rhs: rhs_adj}
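
# A standalone torch check of the transpose rule encoded above, for the
# ordinary (add, mul) semiring: each operand's adjoint contracts the output
# adjoint with the *other* operand over the variables the operand lacks.
# Shapes and names are illustrative.
import torch

lhs = torch.randn(4, 5, requires_grad=True)   # inputs {"i", "j"}
rhs = torch.randn(5, requires_grad=True)      # inputs {"j"}
out_adj = torch.randn(4)                      # adjoint of out, inputs {"i"}

out = lhs @ rhs      # ~ Contraction(add, mul, {"j"}, lhs, rhs)
out.backward(out_adj)

lhs_adj = torch.einsum("i,j->ij", out_adj, rhs.detach())  # nothing to reduce
rhs_adj = torch.einsum("i,ij->j", out_adj, lhs.detach())  # reduce "i"
assert torch.allclose(lhs.grad, lhs_adj)
assert torch.allclose(rhs.grad, rhs_adj)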
def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gaussian):
    # Split the reduced variables into those that can only be reduced
    # approximately (discrete mixture components of the Gaussian) and those
    # that can be reduced exactly.
    approx_vars = frozenset(k for k in reduced_vars
                            if k in gaussian.inputs
                            and gaussian.inputs[k].dtype != 'real')
    exact_vars = reduced_vars - approx_vars

    if exact_vars and approx_vars:
        return Contraction(red_op, bin_op, exact_vars,
                           discrete, gaussian).reduce(red_op, approx_vars)

    if approx_vars and not exact_vars:
        discrete += gaussian.log_normalizer
        new_discrete = discrete.reduce(
            ops.logaddexp, approx_vars.intersection(discrete.inputs))
        num_elements = reduce(ops.mul, [
            gaussian.inputs[k].num_elements
            for k in approx_vars.difference(discrete.inputs)], 1)
        if num_elements != 1:
            new_discrete -= math.log(num_elements)

        int_inputs = OrderedDict(
            (k, d) for k, d in gaussian.inputs.items() if d.dtype != 'real')
        probs = (discrete - new_discrete.clamp_finite()).exp()

        # Match first and second moments of the mixture of Gaussians.
        old_loc = Tensor(
            gaussian.info_vec.unsqueeze(-1).cholesky_solve(
                gaussian._precision_chol).squeeze(-1),
            int_inputs)
        new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
        old_cov = Tensor(cholesky_inverse(gaussian._precision_chol), int_inputs)
        diff = old_loc - new_loc
        outers = Tensor(diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2),
                        diff.inputs)
        new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
                   (probs * outers).reduce(ops.add, approx_vars))

        # Numerically stabilize by adding bogus precision to empty components.
        total = probs.reduce(ops.add, approx_vars)
        mask = (total.data == 0).to(total.data.dtype).unsqueeze(-1).unsqueeze(-1)
        new_cov.data += mask * torch.eye(new_cov.data.size(-1))

        new_precision = Tensor(
            cholesky_inverse(cholesky(new_cov.data)), new_cov.inputs)
        new_info_vec = new_precision.data.matmul(
            new_loc.data.unsqueeze(-1)).squeeze(-1)
        new_inputs = new_loc.inputs.copy()
        new_inputs.update(
            (k, d) for k, d in gaussian.inputs.items() if d.dtype == 'real')
        new_gaussian = Gaussian(new_info_vec, new_precision.data, new_inputs)
        new_discrete -= new_gaussian.log_normalizer
        return new_discrete + new_gaussian

    return None
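
# A small numpy sketch of the moment-matching branch above: a discrete mixture
# of Gaussians over one integer variable is collapsed to a single Gaussian
# matching the mixture's mean and covariance (weighted average of component
# covariances plus the spread of component means). All values are illustrative.
import numpy as np

log_w = np.log(np.array([0.2, 0.5, 0.3]))                  # discrete log-weights
locs = np.array([[0.0, 1.0], [2.0, -1.0], [1.0, 0.5]])     # component means
covs = np.stack([np.eye(2) * s for s in (0.5, 1.0, 2.0)])  # component covariances

probs = np.exp(log_w - np.logaddexp.reduce(log_w))  # like (discrete - new_discrete).exp()
new_loc = np.einsum("k,ki->i", probs, locs)
diff = locs - new_loc
new_cov = (np.einsum("k,kij->ij", probs, covs)
           + np.einsum("k,ki,kj->ij", probs, diff, diff))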
def sequential_sum_product(sum_op, prod_op, trans, time, step):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
            .reduce(sum_op, "drop")

    but does so efficiently in parallel in O(log(time)).

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    """
    assert isinstance(sum_op, AssociativeOp)
    assert isinstance(prod_op, AssociativeOp)
    assert isinstance(trans, Funsor)
    assert isinstance(time, Variable)
    assert isinstance(step, dict)
    assert all(isinstance(k, str) for k in step.keys())
    assert all(isinstance(v, str) for v in step.values())
    if time.name in trans.inputs:
        assert time.output == trans.inputs[time.name]

    step = OrderedDict(sorted(step.items()))
    drop = tuple("_drop_{}".format(i) for i in range(len(step)))
    prev_to_drop = dict(zip(step.keys(), drop))
    curr_to_drop = dict(zip(step.values(), drop))
    drop = frozenset(drop)

    time, duration = time.name, time.output.size
    while duration > 1:
        even_duration = duration // 2 * 2
        x = trans(**{time: Slice(time, 0, even_duration, 2, duration)},
                  **curr_to_drop)
        y = trans(**{time: Slice(time, 1, even_duration, 2, duration)},
                  **prev_to_drop)
        contracted = Contraction(sum_op, prod_op, drop, x, y)

        if duration > even_duration:
            extra = trans(**{time: Slice(time, duration - 1, duration)})
            contracted = Cat(time, (contracted, extra))

        trans = contracted
        duration = (duration + 1) // 2
    return trans(**{time: 0})
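
# A standalone torch sketch of the same halving scheme under the ordinary
# (add, mul) semiring, where each pairwise Contraction over the shared state
# variable is just a batched matmul. `parallel_matmul_scan` is a hypothetical
# helper; shapes are illustrative.
import torch

def parallel_matmul_scan(trans):   # trans: (T, S, S)
    duration = trans.size(0)
    while duration > 1:
        even_duration = duration // 2 * 2
        # pair adjacent steps and contract their shared state index in parallel
        contracted = trans[0:even_duration:2] @ trans[1:even_duration:2]
        if duration > even_duration:        # odd leftover step, like Cat above
            contracted = torch.cat([contracted, trans[-1:]])
        trans = contracted
        duration = (duration + 1) // 2
    return trans[0]

trans = torch.randn(7, 3, 3, dtype=torch.float64)
expected = torch.linalg.multi_dot(list(trans.unbind(0)))
assert torch.allclose(parallel_matmul_scan(trans), expected)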
def test_eager_contract_tensor_tensor(red_op, bin_op, x_inputs, x_shape,
                                      y_inputs, y_shape):
    backend = get_backend()
    inputs = OrderedDict([("i", bint(4)), ("j", bint(5)), ("k", bint(6))])
    x_inputs = OrderedDict((k, v) for k, v in inputs.items() if k in x_inputs)
    y_inputs = OrderedDict((k, v) for k, v in inputs.items() if k in y_inputs)
    x = random_tensor(x_inputs, reals(*x_shape))
    y = random_tensor(y_inputs, reals(*y_shape))

    xy = bin_op(x, y)
    all_vars = frozenset(x.inputs).union(y.inputs)
    for n in range(len(all_vars)):
        for reduced_vars in map(frozenset, itertools.combinations(all_vars, n)):
            print(f"reduced_vars = {reduced_vars}")
            expected = xy.reduce(red_op, reduced_vars)
            actual = Contraction(red_op, bin_op, reduced_vars, (x, y))
            assert_close(actual, expected,
                         atol=1e-4, rtol=5e-4 if backend == "jax" else 1e-4)
def test_affine_subs():
    # This was recorded from test_pyro_convert.
    x = Subs(
        Gaussian(
            torch.tensor([1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246], dtype=torch.float32),  # noqa
            torch.tensor([[1.0199567079544067, 0.9840421676635742, -0.473368763923645, 0.34206756949424744, -0.7562517523765564],
                          [0.9840421676635742, 1.511502742767334, -1.7593903541564941, 0.6647964119911194, -0.5119513273239136],
                          [-0.4733688533306122, -1.7593903541564941, 3.2386727333068848, -0.9345928430557251, -0.1534711718559265],
                          [0.34206756949424744, 0.6647964119911194, -0.9345928430557251, 0.3141004145145416, -0.12399007380008698],
                          [-0.7562517523765564, -0.5119513273239136, -0.1534711718559265, -0.12399007380008698, 0.6450173854827881]],
                         dtype=torch.float32),  # noqa
            (('state_1_b6', reals(3)),
             ('obs_b2', reals(2)))),
        (('obs_b2',
          Contraction(ops.nullop, ops.add, frozenset(),
                      (Variable('bias_b5', reals(2)),
                       Tensor(
                           torch.tensor([-2.1787893772125244, 0.5684312582015991], dtype=torch.float32),  # noqa
                           (), 'real')))),))
    assert isinstance(x, (Gaussian, Contraction)), x.pretty()
def naive_contract_einsum(eqn, *terms, **kwargs):
    """
    Use for testing Contract against einsum
    """
    assert "plates" not in kwargs

    backend = kwargs.pop('backend', 'torch')
    if backend in BACKEND_OPS:
        sum_op, prod_op = BACKEND_OPS[backend]
    else:
        raise ValueError("{} backend not implemented".format(backend))

    assert isinstance(eqn, str)
    assert all(isinstance(term, Funsor) for term in terms)
    inputs, output = eqn.split('->')
    inputs = inputs.split(',')
    assert len(inputs) == len(terms)
    assert len(output.split(',')) == 1
    input_dims = frozenset(d for inp in inputs for d in inp)
    output_dims = frozenset(output)
    reduced_vars = input_dims - output_dims
    return Contraction(sum_op, prod_op, reduced_vars, *terms)
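
# The identity this helper exercises, shown directly in numpy: an einsum is a
# sum-product Contraction whose reduced_vars are exactly the indices missing
# from the output. Arrays are illustrative.
import numpy as np

a, b = np.random.rand(4, 5), np.random.rand(5, 6)
by_einsum = np.einsum("ij,jk->ik", a, b)
# Contraction(add, mul, {"j"}): broadcasted multiply, then sum out "j"
by_contraction = (a[:, :, None] * b[None, :, :]).sum(axis=1)
assert np.allclose(by_einsum, by_contraction)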
def test_bart(analytic_kl):
    global call_count
    call_count = 0

    with interpretation(reflect):
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[-0.6077086925506592, -1.1546266078948975, -0.7021151781082153, -0.5303535461425781, -0.6365622282028198, -1.2423288822174072, -0.9941254258155823, -0.6287292242050171],
                                 [-0.6987162828445435, -1.0875964164733887, -0.7337473630905151, -0.4713417589664459, -0.6674002408981323, -1.2478348016738892, -0.8939017057418823, -0.5238542556762695]],
                                dtype=torch.float32),  # noqa
                            (('time_b4', bint(2)), ('_event_1_b2', bint(8))),
                            'real'),
                        Gaussian(
                            torch.tensor(
                                [[[-0.3536059558391571], [-0.21779225766658783], [0.2840439975261688], [0.4531521499156952], [-0.1220812276005745], [-0.05519985035061836], [0.10932210087776184], [0.6656699776649475]],
                                 [[-0.39107921719551086], [-0.20241987705230713], [0.2170514464378357], [0.4500560462474823], [0.27945515513420105], [-0.0490039587020874], [-0.06399798393249512], [0.846565842628479]]],
                                dtype=torch.float32),  # noqa
                            torch.tensor(
                                [[[[1.984686255455017]], [[0.6699360013008118]], [[1.6215802431106567]], [[2.372016668319702]], [[1.77385413646698]], [[0.526767373085022]], [[0.8722561597824097]], [[2.1879124641418457]]],
                                 [[[1.6996612548828125]], [[0.7535632252693176]], [[1.4946647882461548]], [[2.642792224884033]], [[1.7301604747772217]], [[0.5203893780708313]], [[1.055436372756958]], [[2.8370864391326904]]]],
                                dtype=torch.float32),  # noqa
                            (('time_b4', bint(2)), ('_event_1_b2', bint(8)),
                             ('value_b1', reals()))),
                    )),
                'gate_rate_b3', '_event_1_b2', 'value_b1'),
            'gate_rate_t', 'time_b4', 'gate_rate_b3')

        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736, dtype=torch.float32),
                                (), 'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0], dtype=torch.float32),
                                torch.tensor(
                                    [[98.01002502441406, 0.0, -99.0000228881836, -0.0],
                                     [0.0, 98.01002502441406, -0.0, -99.0000228881836],
                                     [-99.0000228881836, -0.0, 100.0000228881836, 0.0],
                                     [-0.0, -99.0000228881836, 0.0, 100.0000228881836]],
                                    dtype=torch.float32),  # noqa
                                (('state_b7', reals(2)),
                                 ('state(time=1)_b8', reals(2)))),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[0.03488487750291824, 0.07356668263673782, 0.19946961104869843, 0.5386509299278259, -0.708323061466217, 0.24411526322364807, -0.20855577290058136, -0.2421337217092514],
                                             [0.41762110590934753, 0.5272183418273926, -0.49835553765296936, -0.0363837406039238, -0.0005282597267068923, 0.2704298794269562, -0.155222088098526, -0.44802337884902954]],
                                            dtype=torch.float32),  # noqa
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[-0.003566693514585495, -0.2848514914512634, 0.037103548645973206, 0.12648648023605347, -0.18501518666744232, -0.20899859070777893, 0.04121830314397812, 0.0054807960987091064],
                                             [0.0021788496524095535, -0.18700894713401794, 0.08187370002269745, 0.13554862141609192, -0.10477752983570099, -0.20848378539085388, -0.01393645629286766, 0.011670656502246857]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2)),), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[0.5974780917167664, 0.864071786403656, 1.0236268043518066, 0.7147538065910339, 0.7423890233039856, 0.9462157487869263, 1.2132389545440674, 1.0596832036972046],
                                             [0.5787821412086487, 0.9178534150123596, 0.9074794054031372, 0.6600189208984375, 0.8473222255706787, 0.8426999449729919, 1.194266438484192, 1.0471148490905762]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2)),), 'real'),
                                    Variable('state(time=1)_b8', reals(2)),
                                    Variable('gate_rate_b6', reals(8))),
                                (('gate_rate_b6',
                                  Binary(
                                      ops.GetitemOp(0),
                                      Variable('gate_rate_t', reals(2, 8)),
                                      Variable('time_b9', bint(2)))),)),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                            dtype=torch.float32),
                               (), 'real'),
                        Variable('value_b5', reals(2))),
                    (('value_b5', Variable('state_b10', reals(2))),)),
            ))

        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0,
                                            reals(2, 2, 2),
                                            (Variable('gate_rate_b12', reals(8)),)),
                                        (('gate_rate_b12',
                                          Binary(
                                              ops.GetitemOp(0),
                                              Variable('gate_rate_t', reals(2, 8)),
                                              Variable('time_b17', bint(2)))),)),
                                    Variable('origin_b15', bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2),
                                                    (Variable('gate_rate_b13', reals(8)),)),
                                                (('gate_rate_b13',
                                                  Binary(
                                                      ops.GetitemOp(0),
                                                      Variable('gate_rate_t', reals(2, 8)),
                                                      Variable('time_b17', bint(2)))),)),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b17', bint(2)),
                                         ('origin_b15', bint(2)),
                                         ('destin_b16', bint(2))),
                                        'real')),
                                dist.Delta(
                                    Tensor(torch.tensor(0.0, dtype=torch.float32),
                                           (), 'real'),
                                    Tensor(torch.tensor(0.0, dtype=torch.float32),
                                           (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b17', bint(2)),
                                         ('origin_b15', bint(2)),
                                         ('destin_b16', bint(2))),
                                        'real')),
                            )),
                    )),
            ))

    if analytic_kl:
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
def normalize_integrate(log_measure, integrand, reduced_vars):
    return Contraction(ops.add, ops.mul, reduced_vars,
                       log_measure.exp(), integrand)
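
# What this rewrite means concretely, in plain numpy: integrating an integrand
# f against a log-measure log_p over a discrete variable is the sum-product
# contraction of exp(log_p) with f. Values are illustrative.
import numpy as np

log_p = np.log(np.array([0.1, 0.6, 0.3]))   # log_measure over one bint input
f = np.array([1.0, 2.0, 3.0])               # integrand over the same input
integral = np.sum(np.exp(log_p) * f)        # Contraction(add, mul, {input})
assert np.isclose(integral, 0.1 * 1.0 + 0.6 * 2.0 + 0.3 * 3.0)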