def test_normal_independent():
    """Wrap an indexed Normal in ``Independent`` and check the joint domain.

    The base distribution has a scalar ``z_i`` input; after wrapping, the
    ``z`` input must be ``reals(2)`` and sampling ``z`` must yield a
    ``Contraction`` that still has a ``reals(2)`` input named ``z``.
    """
    mean = random_tensor(OrderedDict(), reals(2))
    std = random_tensor(OrderedDict(), reals(2)).exp()

    base = dist.Normal(mean['i'], std['i'], value='z_i')
    assert base.inputs['z_i'] == reals()

    joint = Independent(base, 'z', 'i', 'z_i')
    assert joint.inputs['z'] == reals(2)

    sampled_vars = frozenset(['z'])
    draw = joint.sample(sampled_vars)
    assert isinstance(draw, Contraction)
    assert draw.inputs['z'] == reals(2)
def test_normal_independent():
    """Wrap an indexed Normal in ``Independent`` and check the joint domain.

    The base distribution has a ``Real`` input ``z_i``; after wrapping, the
    ``z`` input must be ``Reals[2]``.  Sampling passes ``rng_key=None`` on the
    torch backend and a fixed uint32 key ``[0, 0]`` on any other backend.
    """
    mean = random_tensor(OrderedDict(), Reals[2])
    std = ops.exp(random_tensor(OrderedDict(), Reals[2]))

    base = dist.Normal(mean['i'], std['i'], value='z_i')
    assert base.inputs['z_i'] == Real

    joint = Independent(base, 'z', 'i', 'z_i')
    assert joint.inputs['z'] == Reals[2]

    if get_backend() == "torch":
        rng_key = None
    else:
        rng_key = np.array([0, 0], dtype=np.uint32)

    sampled_vars = frozenset(['z'])
    draw = joint.sample(sampled_vars, rng_key=rng_key)
    assert isinstance(draw, Contraction)
    assert draw.inputs['z'] == Reals[2]
def test_independent():
    """Compare ``Independent`` against an explicit index-and-reduce expression.

    Also checks that renaming the joint variable keeps the result an
    ``Independent`` and leaves its value unchanged.
    """
    tol = dict(atol=1e-5, rtol=1e-5)

    base = Variable('x_i', reals(4, 5)) + random_tensor(OrderedDict(i=bint(3)))
    assert base.inputs['x_i'] == reals(4, 5)
    assert base.inputs['i'] == bint(3)

    indep = Independent(base, 'x', 'i', 'x_i')
    assert indep.inputs['x'] == reals(3, 4, 5)
    assert 'i' not in indep.inputs

    joint_var = Variable('x', reals(3, 4, 5))
    reference = base(x_i=joint_var['i']).reduce(ops.add, 'i')
    assert indep.inputs == reference.inputs
    assert indep.output == reference.output

    values = random_tensor(OrderedDict(), joint_var.output)
    assert_close(indep(values), reference(values), **tol)

    as_y = indep(x='y')
    assert isinstance(as_y, Independent)
    assert_close(as_y(y=values), reference(x=values), **tol)

    # The joint variable may be renamed to the same name as the diagonal
    # variable (.reals_var and .diag_var are allowed to coincide).
    as_x_i = indep(x='x_i')
    assert isinstance(as_x_i, Independent)
    assert_close(as_x_i(x_i=values), reference(x=values), **tol)
def _independent_to_funsor(pyro_dist, event_inputs=()):
    """Convert ``pyro_dist`` by converting its base and re-wrapping each event dim.

    One fresh ``"_event_<n>"`` name is generated per reinterpreted batch
    dimension, numbered after the entries already in ``event_inputs``.  The
    base distribution is converted with those names appended, then each name
    is folded back into ``"value"`` via ``Independent``, last name first.
    """
    first_index = len(event_inputs)
    event_names = tuple(
        "_event_{}".format(first_index + offset)
        for offset in range(pyro_dist.reinterpreted_batch_ndims)
    )

    converted = dist_to_funsor(pyro_dist.base_dist, event_inputs + event_names)
    for event_name in event_names[::-1]:
        converted = Independent(converted, "value", event_name, "value")
    return converted
def eager_independent_joint(joint, reals_var, bint_var, diag_var):
    """Push ``Independent`` into the first term of ``joint``.

    Returns ``None`` when ``diag_var`` is not among the fresh variables of the
    first term.  Otherwise the first term is wrapped in ``Independent``, every
    remaining term is sum-reduced over ``bint_var``, and the pieces are
    recombined with ``joint.bin_op``.
    """
    leading = joint.terms[0]
    if diag_var not in leading.fresh:
        return None

    wrapped = Independent(leading, reals_var, bint_var, diag_var)
    reduced_rest = [term.reduce(ops.add, bint_var) for term in joint.terms[1:]]
    return reduce(joint.bin_op, (wrapped, *reduced_rest))
def eager_affine_normal(matrix, loc, scale, value_x, value_y):
    """Evaluate an affine-normal term as an ``Independent`` of scalar Normals.

    ``matrix`` must have a rank-2 output whose two sizes match the outputs of
    ``value_x`` and ``value_y``.  The location is shifted by
    ``value_x @ matrix``, the aligned data is re-wrapped with one extra bint
    input, and the resulting distribution is evaluated at ``value_y``.
    """
    matrix_shape = matrix.output.shape
    assert len(matrix_shape) == 2
    assert value_x.output == reals(matrix_shape[0])
    assert value_y.output == reals(matrix_shape[1])

    loc += value_x @ matrix
    int_inputs, (loc, scale) = align_tensors(loc, scale, expand=True)

    # Fresh names: the bint index, the joint value, and the per-index value.
    index_name = gensym("i")
    joint_name = gensym("y")
    element_name = gensym("y_i")

    int_inputs[index_name] = bint(value_y.output.shape[0])
    loc_tensor = Tensor(loc, int_inputs)
    scale_tensor = Tensor(scale, int_inputs)

    elementwise = Normal(loc_tensor, scale_tensor, element_name)
    joint_dist = Independent(elementwise, joint_name, index_name, element_name)
    return joint_dist(**{joint_name: value_y})
def test_subs_independent():
    """Substitute a funsor that has its own ``i`` input into an ``Independent``.

    The wrapped funsor hides ``i``; after substituting an expression with an
    ``i`` input of size 7, the result must expose that ``i`` and agree with
    the explicit index-and-reduce expression.
    """
    base = Variable('x', reals(4, 5)) + random_tensor(OrderedDict(i=bint(3)))
    indep = Independent(base, 'x', 'i')
    assert 'i' not in indep.inputs

    y_var = Variable('y', reals(3, 4, 5))
    zero_batch = 0. * random_tensor(OrderedDict(i=bint(7)))
    replacement = y_var + zero_batch

    substituted = indep(x=replacement)
    assert substituted.inputs['i'] == bint(7)

    reference = base(x=y_var['i']).reduce(ops.add, 'i')
    values = random_tensor(OrderedDict(i=bint(7)), y_var.output)
    assert_close(substituted(y=values), reference(y=values))
def eager_independent(joint, reals_var, bint_var):
    """Push ``Independent`` into the matching delta of a ``Joint``.

    The first delta whose name equals ``reals_var`` or starts with
    ``reals_var + "__BOUND"`` is wrapped in ``Independent``.  The discrete and
    gaussian parts are sum-reduced over ``bint_var`` when they depend on it.
    Returns ``None`` when no delta matches, deferring to the default
    implementation.
    """
    for position, delta in enumerate(joint.deltas):
        is_target = (delta.name == reals_var
                     or delta.name.startswith(reals_var + "__BOUND"))
        if not is_target:
            continue

        wrapped = Independent(delta, reals_var, bint_var)
        before = joint.deltas[:position]
        after = joint.deltas[position + 1:]
        new_deltas = before + (wrapped, ) + after

        discrete = joint.discrete
        discrete = (discrete.reduce(ops.add, bint_var)
                    if bint_var in discrete.inputs else discrete)

        gaussian = joint.gaussian
        gaussian = (gaussian.reduce(ops.add, bint_var)
                    if bint_var in gaussian.inputs else gaussian)

        return Joint(new_deltas, discrete, gaussian)

    # No matching delta: defer to the default implementation.
    return None
def test_subs_independent():
    """Substitute a funsor that has its own ``i`` input into an ``Independent``.

    The wrapped funsor hides both ``i`` and ``x_i``; after substituting an
    expression with an ``i`` input of size 7, the result must expose that
    ``i`` and agree with the explicit index-and-reduce expression.
    """
    base = Variable('x_i', Reals[4, 5]) + random_tensor(OrderedDict(i=Bint[3]))
    indep = Independent(base, 'x', 'i', 'x_i')
    for hidden in ('i', 'x_i'):
        assert hidden not in indep.inputs

    y_var = Variable('y', Reals[3, 4, 5])
    zero_batch = 0. * random_tensor(OrderedDict(i=Bint[7]))
    replacement = y_var + zero_batch

    substituted = indep(x=replacement)
    assert substituted.inputs['i'] == Bint[7]

    reference = base(x_i=y_var['i']).reduce(ops.add, 'i')
    values = random_tensor(OrderedDict(i=Bint[7]), y_var.output)
    assert_close(substituted(y=values), reference(y=values))
def test_independent():
    """Compare ``Independent`` against an explicit index-and-reduce expression.

    The wrapped funsor must take a ``reals(3, 4, 5)`` input ``x``, drop the
    ``i`` input, and match the reference in inputs, output and value.
    """
    base = Variable('x', reals(4, 5)) + random_tensor(OrderedDict(i=bint(3)))
    assert base.inputs['x'] == reals(4, 5)
    assert base.inputs['i'] == bint(3)

    indep = Independent(base, 'x', 'i')
    assert indep.inputs['x'] == reals(3, 4, 5)
    assert 'i' not in indep.inputs

    joint_var = Variable('x', reals(3, 4, 5))
    reference = base(x=joint_var['i']).reduce(ops.add, 'i')
    assert indep.inputs == reference.inputs
    assert indep.output == reference.output

    values = random_tensor(OrderedDict(), joint_var.output)
    tol = dict(atol=1e-5, rtol=1e-5)
    assert_close(indep(values), reference(values), **tol)
def test_bart(analytic_kl):
    """Integrate a hard-coded guide ``q`` against a hard-coded model over ``gate_rate_t``.

    Three funsor expressions (``q``, ``p_prior``, ``p_likelihood``) are built
    from literal tensors.  When ``analytic_kl`` is truthy, the prior part
    ``Integrate(q, p_prior - q)`` is evaluated outside the ``monte_carlo``
    interpretation and only the likelihood part is evaluated under it;
    otherwise the whole ``Integrate(q, p - q)`` runs under ``monte_carlo``.
    The result must be a ``Tensor`` and the module-level ``call_count`` must
    equal 1 afterwards.

    NOTE(review): the source was flattened; the extent of the
    ``with interpretation(reflect)`` block was reconstructed as covering the
    three expression definitions -- confirm against the original layout.
    """
    # Reset the module-level counter that is checked at the end of the test.
    # Presumably it is incremented by the unpack_gate_rate_* helpers defined
    # elsewhere in the file -- TODO confirm.
    global call_count
    call_count = 0

    with interpretation(reflect):
        # Guide: two nested Independent wrappers over a Tensor + Gaussian
        # contraction, ending in a single 'gate_rate_t' variable.
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975, -0.7021151781082153, -0.5303535461425781, -0.6365622282028198, -1.2423288822174072, -0.9941254258155823, -0.6287292242050171
                                ], [
                                    -0.6987162828445435, -1.0875964164733887, -0.7337473630905151, -0.4713417589664459, -0.6674002408981323, -1.2478348016738892, -0.8939017057418823, -0.5238542556762695
                                ]], dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783], [0.2840439975261688], [0.4531521499156952], [-0.1220812276005745], [-0.05519985035061836], [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823], [0.27945515513420105], [-0.0490039587020874], [-0.06399798393249512], [0.846565842628479]]
                            ], dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]], [[1.6215802431106567]], [[2.372016668319702]], [[1.77385413646698]], [[0.526767373085022]], [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]], [[2.642792224884033]], [[1.7301604747772217]], [[0.5203893780708313]], [[1.055436372756958]], [[2.8370864391326904]]]
                            ], dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        # Prior: a MarkovProduct over 'time_b9' combined with a
        # MultivariateNormal on the initial state, reducing the two state
        # variables named in the frozenset.
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736, dtype=torch.float32),
                                (),
                                'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0], dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836, -0.0
                                ], [
                                    0.0, 98.01002502441406, -0.0, -99.0000228881836
                                ], [
                                    -99.0000228881836, -0.0, 100.0000228881836, 0.0
                                ], [
                                    -0.0, -99.0000228881836, 0.0, 100.0000228881836
                                ]], dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824, 0.07356668263673782, 0.19946961104869843, 0.5386509299278259, -0.708323061466217, 0.24411526322364807, -0.20855577290058136, -0.2421337217092514
                                            ], [
                                                0.41762110590934753, 0.5272183418273926, -0.49835553765296936, -0.0363837406039238, -0.0005282597267068923, 0.2704298794269562, -0.155222088098526, -0.44802337884902954
                                            ]], dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495, -0.2848514914512634, 0.037103548645973206, 0.12648648023605347, -0.18501518666744232, -0.20899859070777893, 0.04121830314397812, 0.0054807960987091064
                                            ], [
                                                0.0021788496524095535, -0.18700894713401794, 0.08187370002269745, 0.13554862141609192, -0.10477752983570099, -0.20848378539085388, -0.01393645629286766, 0.011670656502246857
                                            ]], dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664, 0.864071786403656, 1.0236268043518066, 0.7147538065910339, 0.7423890233039856, 0.9462157487869263, 1.2132389545440674, 1.0596832036972046
                                            ], [
                                                0.5787821412086487, 0.9178534150123596, 0.9074794054031372, 0.6600189208984375, 0.8473222255706787, 0.8426999449729919, 1.194266438484192, 1.0471148490905762
                                            ]], dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'), ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32), (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]], dtype=torch.float32),
                            (),
                            'real'),
                        Variable('value_b5', reals(2, ))),
                    ((
                        'value_b5',
                        Variable('state_b10', reals(2, )),
                    ), )),
            ))
        # Likelihood: a Categorical over 'gated_b14' combined with a Stack of
        # a Poisson and a Delta, reduced over time, origin and destination.
        # Both rate expressions index 'gate_rate_t' by 'time_b17' and pass the
        # row through an unpack_gate_rate_* Function.
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0,
                                            reals(2, 2, 2),
                                            (Variable('gate_rate_b12', reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t', reals(2, 8)),
                                                Variable('time_b17', bint(2))),
                                        ), )),
                                    Variable('origin_b15', bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2),
                                                    (Variable(
                                                        'gate_rate_b13', reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t', reals(2, 8)),
                                                        Variable(
                                                            'time_b17', bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]], [[0.0, 6.0], [19.0, 3.0]]], dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]], [[0.0, 6.0], [19.0, 3.0]]], dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )),
            ))

    if analytic_kl:
        # Prior/entropy part is integrated outside monte_carlo; only the
        # likelihood part is integrated under it.
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        # Whole integrand is integrated under monte_carlo.
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    # On failure, the assertion message is the pretty-printed funsor.
    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
def test_sample_independent():
    """Check that sampling an ``Independent`` funsor gives truthy results.

    Exercises ``.sample`` with a single variable name, and again with a
    variable name plus a dict argument.
    """
    elementwise = Variable('x_i', reals(4, 5))
    base = elementwise + random_tensor(OrderedDict(i=bint(3)))
    indep = Independent(base, 'x', 'i', 'x_i')

    assert indep.sample('i')
    assert indep.sample('j', {'i': 2})
def test_sample_independent():
    """Check that sampling an ``Independent`` funsor gives truthy results.

    Exercises ``.sample`` with a single variable name, and again with a
    variable name plus a dict argument.
    """
    elementwise = Variable('x_i', Reals[4, 5])
    base = elementwise + random_tensor(OrderedDict(i=Bint[3]))
    indep = Independent(base, 'x', 'i', 'x_i')

    assert indep.sample('i')
    assert indep.sample('j', {'i': 2})