def test_stack_subs():
    x = Variable('x', reals())
    y = Variable('y', reals())
    z = Variable('z', reals())
    j = Variable('j', bint(3))
    f = Stack('i', (Number(0), x, y * z))
    check_funsor(f, {'i': bint(3), 'x': reals(), 'y': reals(), 'z': reals()},
                 reals())

    assert f(i=Number(0, 3)) is Number(0)
    assert f(i=Number(1, 3)) is x
    assert f(i=Number(2, 3)) is y * z
    assert f(i=j) is Stack('j', (Number(0), x, y * z))
    assert f(i='j') is Stack('j', (Number(0), x, y * z))
    assert f.reduce(ops.add, 'i') is Number(0) + x + (y * z)

    assert f(x=0) is Stack('i', (Number(0), Number(0), y * z))
    assert f(y=x) is Stack('i', (Number(0), x, x * z))
    assert f(x=0, y=x) is Stack('i', (Number(0), Number(0), x * z))
    assert f(x=0, y=x, i=Number(2, 3)) is x * z
    assert f(x=0, i=j) is Stack('j', (Number(0), Number(0), y * z))
    assert f(x=0, i='j') is Stack('j', (Number(0), Number(0), y * z))
# The parametrize decorators below are assumed (they were missing from this
# excerpt); the parameter values are illustrative.
@pytest.mark.parametrize('start', [0, 1, 2])
@pytest.mark.parametrize('stop', [2, 5, 10])
@pytest.mark.parametrize('step', [1, 2, 3])
def test_stack_slice(start, stop, step):
    xs = tuple(map(Number, range(10)))
    actual = Stack('i', xs)(i=Slice('j', start, stop, step, dtype=10))
    expected = Stack('j', xs[start:stop:step])
    assert type(actual) == type(expected)
    assert actual.name == expected.name
    assert actual.parts == expected.parts
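# A small supplementary check, not from the original test file, illustrating
# the Slice semantics exercised above: Slice('j', 1, 10, 3, dtype=10) selects
# indices 1, 4, 7 of a size-10 dim and renames the input 'i' to 'j'. The
# function name and concrete values are illustrative.
def example_stack_slice_semantics():
    xs = tuple(map(Number, range(10)))
    sliced = Stack('i', xs)(i=Slice('j', 1, 10, 3, dtype=10))
    # The result is Stack('j', (Number(1), Number(4), Number(7))).
    assert sliced(j=0) is Number(1)
    assert sliced(j=1) is Number(4)
    assert sliced(j=2) is Number(7)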
def mixed_sequential_sum_product(sum_op, prod_op, trans, time, step, num_segments=None):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
           .reduce(sum_op, "drop")

    by mixing parallel and serial scan algorithms over ``num_segments`` segments.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    :param int num_segments: number of segments for the first stage
    """
    time_var, time, duration = time, time.name, time.output.size
    num_segments = duration if num_segments is None else num_segments
    assert num_segments > 0 and duration > 0

    # handle unevenly sized segments by chopping off the final segment
    # and calling mixed_sequential_sum_product again
    if duration % num_segments and duration - duration % num_segments > 0:
        remainder = trans(**{time: Slice(time, duration - duration % num_segments,
                                         duration, 1, duration)})
        initial = trans(**{time: Slice(time, 0, duration - duration % num_segments,
                                       1, duration)})
        initial_eliminated = mixed_sequential_sum_product(
            sum_op, prod_op, initial,
            Variable(time, bint(duration - duration % num_segments)), step,
            num_segments=num_segments)
        final = Cat(time, (Stack(time, (initial_eliminated,)), remainder))
        final_eliminated = naive_sequential_sum_product(
            sum_op, prod_op, final,
            Variable(time, bint(1 + duration % num_segments)), step)
        return final_eliminated

    # handle degenerate cases that reduce to a single stage
    if num_segments == 1:
        return naive_sequential_sum_product(sum_op, prod_op, trans, time_var, step)
    if num_segments >= duration:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, step)

    # break trans into num_segments segments of equal length
    segment_length = duration // num_segments
    segments = [trans(**{time: Slice(time, i * segment_length,
                                     (i + 1) * segment_length, 1, duration)})
                for i in range(num_segments)]

    first_stage_result = naive_sequential_sum_product(
        sum_op, prod_op, Stack(time + "__SEGMENTED", tuple(segments)),
        Variable(time, bint(segment_length)), step)

    second_stage_result = sequential_sum_product(
        sum_op, prod_op, first_stage_result,
        Variable(time + "__SEGMENTED", bint(num_segments)), step)

    return second_stage_result
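# A minimal usage sketch of mixed_sequential_sum_product (illustrative, not
# from the library): with ops.logaddexp/ops.add it marginalizes a log-space
# Markov chain over the time dim. `random_tensor` is funsor's testing helper;
# the input names and sizes below are assumptions.
def example_mixed_sequential_sum_product():
    from collections import OrderedDict

    from funsor import ops
    from funsor.domains import bint
    from funsor.terms import Variable
    from funsor.testing import random_tensor

    # A random transition funsor for 10 steps of a 2-state chain.
    trans = random_tensor(OrderedDict([
        ("time", bint(10)), ("prev", bint(2)), ("curr", bint(2)),
    ]))
    # First stage: serial scan within each of 5 segments of length 2;
    # second stage: parallel scan across the 5 segment results.
    result = mixed_sequential_sum_product(
        ops.logaddexp, ops.add, trans,
        Variable("time", bint(10)), {"prev": "curr"},
        num_segments=5)
    # The time dim is eliminated; only the chain endpoints remain.
    assert set(result.inputs) == {"prev", "curr"}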
def test_cat_simple():
    x = Stack('i', (Number(0), Number(1), Number(2)))
    y = Stack('i', (Number(3), Number(4)))

    assert Cat('i', (x,)) is x
    assert Cat('i', (y,)) is y

    xy = Cat('i', (x, y))
    assert xy.inputs == OrderedDict(i=bint(5))
    assert xy.name == 'i'
    for i in range(5):
        assert xy(i=i) is Number(i)
def test_stack_simple():
    x = Number(0.)
    y = Number(1.)
    z = Number(4.)

    xyz = Stack('i', (x, y, z))
    check_funsor(xyz, {'i': bint(3)}, reals())

    assert xyz(i=Number(0, 3)) is x
    assert xyz(i=Number(1, 3)) is y
    assert xyz(i=Number(2, 3)) is z
    assert xyz.reduce(ops.add, 'i') == 5.
def test_reduce_syntactic_sugar():
    i = Variable("i", bint(3))
    x = Stack("i", (Number(1), Number(2), Number(3)))
    expected = Number(1 + 2 + 3)
    assert x.reduce(ops.add) is expected
    assert x.reduce(ops.add, "i") is expected
    assert x.reduce(ops.add, {"i"}) is expected
    assert x.reduce(ops.add, frozenset(["i"])) is expected
    assert x.reduce(ops.add, i) is expected
    assert x.reduce(ops.add, {i}) is expected
    assert x.reduce(ops.add, frozenset([i])) is expected
# The parametrize decorator below is assumed (it was missing from this
# excerpt); the interpretation values are illustrative.
@pytest.mark.parametrize('interp', [reflect, lazy, eager])
def test_quote(interp):
    with interpretation(interp):
        x = Variable('x', bint(8))
        check_quote(x)

        y = Variable('y', reals(8, 3, 3))
        check_quote(y)
        check_quote(y[x])

        z = Stack('i', (Number(0), Variable('z', reals())))
        check_quote(z)
        check_quote(z(i=0))
        check_quote(z(i=Slice('i', 0, 1, 1, 2)))
        check_quote(z.reduce(ops.add, 'i'))
        check_quote(Cat('i', (z, z, z)))
        check_quote(Lambda(Variable('i', bint(2)), z))
def Uniform(components):
    components = tuple(components)
    size = len(components)
    if size == 1:
        return components[0]
    var = Variable('v', bint(size))
    return (Stack(var.name, components).reduce(ops.logaddexp, var.name) -
            math.log(size))
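# A worked example of Uniform (illustrative, not from the source): stacking
# n equally weighted log-density components, reducing with logaddexp, and
# subtracting log(n) yields a log-mean-exp mixture. With two identical
# components Number(0.), the mixture log-density is
# log((exp(0) + exp(0)) / 2) = 0.
def example_uniform_mixture():
    from funsor.terms import Number

    mixture = Uniform((Number(0.), Number(0.)))
    assert abs(mixture.data) < 1e-8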
# The parametrize decorator below is assumed (it was missing from this
# excerpt); the output values are illustrative.
@pytest.mark.parametrize('output', [reals(), reals(4), reals(2, 3)])
def test_funsor_stack(output):
    x = random_tensor(OrderedDict([
        ('i', bint(2)),
    ]), output)
    y = random_tensor(OrderedDict([
        ('j', bint(3)),
    ]), output)
    z = random_tensor(OrderedDict([
        ('i', bint(2)),
        ('k', bint(4)),
    ]), output)

    xy = Stack('t', (x, y))
    assert isinstance(xy, Tensor)
    assert xy.inputs == OrderedDict([
        ('t', bint(2)),
        ('i', bint(2)),
        ('j', bint(3)),
    ])
    assert xy.output == output
    for j in range(3):
        assert_close(xy(t=0, j=j), x)
    for i in range(2):
        assert_close(xy(t=1, i=i), y)

    xyz = Stack('t', (x, y, z))
    assert isinstance(xyz, Tensor)
    assert xyz.inputs == OrderedDict([
        ('t', bint(3)),
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ])
    assert xyz.output == output
    for j in range(3):
        for k in range(4):
            assert_close(xyz(t=0, j=j, k=k), x)
    for i in range(2):
        for k in range(4):
            assert_close(xyz(t=1, i=i, k=k), y)
    for j in range(3):
        assert_close(xyz(t=2, j=j), z)
# The parametrize decorator below is assumed (it was missing from this excerpt).
@pytest.mark.parametrize('analytic_kl', [False, True])
def test_bart(analytic_kl):
    global call_count
    call_count = 0

    with interpretation(reflect):
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[-0.6077086925506592, -1.1546266078948975, -0.7021151781082153,
                                  -0.5303535461425781, -0.6365622282028198, -1.2423288822174072,
                                  -0.9941254258155823, -0.6287292242050171],
                                 [-0.6987162828445435, -1.0875964164733887, -0.7337473630905151,
                                  -0.4713417589664459, -0.6674002408981323, -1.2478348016738892,
                                  -0.8939017057418823, -0.5238542556762695]],
                                dtype=torch.float32),  # noqa
                            (('time_b4', bint(2)),
                             ('_event_1_b2', bint(8))),
                            'real'),
                        Gaussian(
                            torch.tensor(
                                [[[-0.3536059558391571], [-0.21779225766658783], [0.2840439975261688],
                                  [0.4531521499156952], [-0.1220812276005745], [-0.05519985035061836],
                                  [0.10932210087776184], [0.6656699776649475]],
                                 [[-0.39107921719551086], [-0.20241987705230713], [0.2170514464378357],
                                  [0.4500560462474823], [0.27945515513420105], [-0.0490039587020874],
                                  [-0.06399798393249512], [0.846565842628479]]],
                                dtype=torch.float32),  # noqa
                            torch.tensor(
                                [[[[1.984686255455017]], [[0.6699360013008118]], [[1.6215802431106567]],
                                  [[2.372016668319702]], [[1.77385413646698]], [[0.526767373085022]],
                                  [[0.8722561597824097]], [[2.1879124641418457]]],
                                 [[[1.6996612548828125]], [[0.7535632252693176]], [[1.4946647882461548]],
                                  [[2.642792224884033]], [[1.7301604747772217]], [[0.5203893780708313]],
                                  [[1.055436372756958]], [[2.8370864391326904]]]],
                                dtype=torch.float32),  # noqa
                            (('time_b4', bint(2)),
                             ('_event_1_b2', bint(8)),
                             ('value_b1', reals()))),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736, dtype=torch.float32),
                                (),
                                'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0], dtype=torch.float32),
                                torch.tensor(
                                    [[98.01002502441406, 0.0, -99.0000228881836, -0.0],
                                     [0.0, 98.01002502441406, -0.0, -99.0000228881836],
                                     [-99.0000228881836, -0.0, 100.0000228881836, 0.0],
                                     [-0.0, -99.0000228881836, 0.0, 100.0000228881836]],
                                    dtype=torch.float32),  # noqa
                                (('state_b7', reals(2)),
                                 ('state(time=1)_b8', reals(2)))),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[0.03488487750291824, 0.07356668263673782, 0.19946961104869843,
                                              0.5386509299278259, -0.708323061466217, 0.24411526322364807,
                                              -0.20855577290058136, -0.2421337217092514],
                                             [0.41762110590934753, 0.5272183418273926, -0.49835553765296936,
                                              -0.0363837406039238, -0.0005282597267068923, 0.2704298794269562,
                                              -0.155222088098526, -0.44802337884902954]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[-0.003566693514585495, -0.2848514914512634, 0.037103548645973206,
                                              0.12648648023605347, -0.18501518666744232, -0.20899859070777893,
                                              0.04121830314397812, 0.0054807960987091064],
                                             [0.0021788496524095535, -0.18700894713401794, 0.08187370002269745,
                                              0.13554862141609192, -0.10477752983570099, -0.20848378539085388,
                                              -0.01393645629286766, 0.011670656502246857]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2)),),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[0.5974780917167664, 0.864071786403656, 1.0236268043518066,
                                              0.7147538065910339, 0.7423890233039856, 0.9462157487869263,
                                              1.2132389545440674, 1.0596832036972046],
                                             [0.5787821412086487, 0.9178534150123596, 0.9074794054031372,
                                              0.6600189208984375, 0.8473222255706787, 0.8426999449729919,
                                              1.194266438484192, 1.0471148490905762]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2)),),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2)),
                                    Variable('gate_rate_b6', reals(8))),
                                (('gate_rate_b6',
                                  Binary(
                                      ops.GetitemOp(0),
                                      Variable('gate_rate_t', reals(2, 8)),
                                      Variable('time_b9', bint(2)))),)),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32), (), 'real'),
                        Tensor(torch.tensor([[10.0, 0.0], [0.0, 10.0]], dtype=torch.float32),
                               (), 'real'),
                        Variable('value_b5', reals(2))),
                    (('value_b5', Variable('state_b10', reals(2))),)),
            ))
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0,
                                            reals(2, 2, 2),
                                            (Variable('gate_rate_b12', reals(8)),)),
                                        (('gate_rate_b12',
                                          Binary(
                                              ops.GetitemOp(0),
                                              Variable('gate_rate_t', reals(2, 8)),
                                              Variable('time_b17', bint(2)))),)),
                                    Variable('origin_b15', bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2),
                                                    (Variable('gate_rate_b13', reals(8)),)),
                                                (('gate_rate_b13',
                                                  Binary(
                                                      ops.GetitemOp(0),
                                                      Variable('gate_rate_t', reals(2, 8)),
                                                      Variable('time_b17', bint(2)))),)),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b17', bint(2)),
                                         ('origin_b15', bint(2)),
                                         ('destin_b16', bint(2))),
                                        'real')),
                                dist.Delta(
                                    Tensor(torch.tensor(0.0, dtype=torch.float32), (), 'real'),
                                    Tensor(torch.tensor(0.0, dtype=torch.float32), (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b17', bint(2)),
                                         ('origin_b15', bint(2)),
                                         ('destin_b16', bint(2))),
                                        'real')),
                            )),
                    )),
            ))

    if analytic_kl:
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
# The parametrize decorator below is assumed (it was missing from this
# excerpt); the name values are illustrative.
@pytest.mark.parametrize('name', ['t', 'new_t'])
def test_cat(name):
    with interpretation(reflect):
        x = Stack("t", (Number(1), Number(2)))
        y = Stack("t", (Number(4), Number(8), Number(16)))
        xy = Cat(name, (x, y), "t")
    xy.reduce(ops.add)
def __call__(self):
    # calls pyro.param so that params are exposed and constraints applied
    # should not create any new torch.Tensors after __init__
    self.initialize_params()

    N_state = self.config["sizes"]["state"]

    # initialize gamma to uniform
    gamma = Tensor(
        torch.zeros((N_state, N_state)),
        OrderedDict([("y_prev", bint(N_state))]),
    )

    N_v = self.config["sizes"]["random"]
    N_c = self.config["sizes"]["group"]
    log_prob = []

    plate_g = Tensor(torch.zeros(N_c), OrderedDict([("g", bint(N_c))]))

    # group-level random effects
    if self.config["group"]["random"] == "discrete":
        # group-level discrete effect
        e_g = Variable("e_g", bint(N_v))
        e_g_dist = plate_g + dist.Categorical(**self.params["e_g"])(value=e_g)
        log_prob.append(e_g_dist)

        eps_g = (plate_g + self.params["eps_g"]["theta"])(e_g=e_g)
    elif self.config["group"]["random"] == "continuous":
        eps_g = Variable("eps_g", reals(N_state))
        eps_g_dist = plate_g + dist.Normal(**self.params["eps_g"])(value=eps_g)
        log_prob.append(eps_g_dist)
    else:
        eps_g = to_funsor(0.)

    N_s = self.config["sizes"]["individual"]

    plate_i = Tensor(torch.zeros(N_s), OrderedDict([("i", bint(N_s))]))

    # individual-level random effects
    if self.config["individual"]["random"] == "discrete":
        # individual-level discrete effect
        e_i = Variable("e_i", bint(N_v))
        e_i_dist = plate_g + plate_i + dist.Categorical(
            **self.params["e_i"]
        )(value=e_i) * self.raggedness_masks["individual"](t=0)
        log_prob.append(e_i_dist)

        eps_i = (plate_i + plate_g + self.params["eps_i"]["theta"](e_i=e_i))
    elif self.config["individual"]["random"] == "continuous":
        eps_i = Variable("eps_i", reals(N_state))
        eps_i_dist = plate_g + plate_i + dist.Normal(**self.params["eps_i"])(value=eps_i)
        log_prob.append(eps_i_dist)
    else:
        eps_i = to_funsor(0.)

    # add group-level and individual-level random effects to gamma
    gamma = gamma + eps_g + eps_i

    N_state = self.config["sizes"]["state"]

    # we've accounted for all effects, now actually compute gamma_y
    gamma_y = gamma(y_prev="y(t=1)")

    y = Variable("y", bint(N_state))
    y_dist = plate_g + plate_i + dist.Categorical(
        probs=gamma_y.exp() / gamma_y.exp().sum()
    )(value=y)

    # observation 1: step size
    step_dist = plate_g + plate_i + dist.Gamma(
        **{k: v(y_curr=y) for k, v in self.params["step"].items()}
    )(value=self.observations["step"])

    # step size zero-inflation
    if self.config["zeroinflation"]:
        step_zi = dist.Categorical(
            probs=self.params["zi_step"]["zi_param"](y_curr=y)
        )(value="zi_step")
        step_zi_dist = plate_g + plate_i + dist.Delta(
            self.config["MISSING"], 0.
        )(value=self.observations["step"])
        step_dist = (step_zi + Stack("zi_step", (step_dist, step_zi_dist))).reduce(
            ops.logaddexp, "zi_step")

    # observation 2: step angle
    angle_dist = plate_g + plate_i + dist.VonMises(
        **{k: v(y_curr=y) for k, v in self.params["angle"].items()}
    )(value=self.observations["angle"])

    # observation 3: dive activity
    omega_dist = plate_g + plate_i + dist.Beta(
        **{k: v(y_curr=y) for k, v in self.params["omega"].items()}
    )(value=self.observations["omega"])

    # dive activity zero-inflation
    if self.config["zeroinflation"]:
        omega_zi = dist.Categorical(
            probs=self.params["zi_omega"]["zi_param"](y_curr=y)
        )(value="zi_omega")
        omega_zi_dist = plate_g + plate_i + dist.Delta(
            self.config["MISSING"], 0.
        )(value=self.observations["omega"])
        omega_dist = (omega_zi + Stack("zi_omega", (omega_dist, omega_zi_dist))).reduce(
            ops.logaddexp, "zi_omega")

    # finally, construct the term for parallel scan reduction
    hmm_factor = step_dist + angle_dist + omega_dist
    hmm_factor = hmm_factor * self.raggedness_masks["individual"]
    hmm_factor = hmm_factor * self.raggedness_masks["timestep"]
    # copy masking behavior of pyro.infer.TraceEnum_ELBO._compute_model_factors
    hmm_factor = hmm_factor + y_dist
    log_prob.insert(0, hmm_factor)

    return log_prob