Example #1
def test_gaussian_distribution(event_inputs, batch_inputs):
    num_samples = 100000
    sample_inputs = OrderedDict(particle=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p = random_gaussian(be_inputs)

    rng_key = None if get_backend() == "torch" else np.array([0, 0],
                                                             dtype=np.uint32)
    q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
    p_vars = sampled_vars
    q_vars = sampled_vars | frozenset(['particle'])
    # Check zeroth moment.
    assert_close(q.reduce(ops.logaddexp, q_vars),
                 p.reduce(ops.logaddexp, p_vars),
                 atol=1e-6)
    for k1, d1 in event_inputs.items():
        x = Variable(k1, d1)
        # Check first moments.
        assert_close(Integrate(q, x, q_vars),
                     Integrate(p, x, p_vars),
                     atol=0.5,
                     rtol=0.2)
        for k2, d2 in event_inputs.items():
            y = Variable(k2, d2)
            # Check second moments.
            continue  # FIXME: Quadratic integration is not supported:
            assert_close(Integrate(q, x * y, q_vars),
                         Integrate(p, x * y, p_vars),
                         atol=1e-2)
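
Note (an informal restatement, not part of the original test): Integrate(p, f, V) computes the integral of exp(p(x)) * f(x) over the variables in V, and the sampled funsor q = p.sample(V, sample_inputs) is an importance-weighted approximation, so

    Integrate(q, f, V | {particle}) ~= Integrate(p, f, V).

The zeroth-moment check (the logaddexp reduction, i.e. f == 1) is asserted with a tight tolerance because sampling is expected to preserve total mass, while first and second moments carry Monte Carlo error and get the looser atol/rtol.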
Example #2
def test_mc_plate_gaussian():
    log_measure = Gaussian(torch.tensor([0.]), torch.tensor([[1.]]),
                           (('loc', reals()),)) + torch.tensor(-0.9189)
    integrand = Gaussian(torch.randn((100, 1)) + 3., torch.ones((100, 1, 1)),
                         (('data', bint(100)), ('loc', reals())))

    res = Integrate(log_measure.sample(frozenset({'loc'})), integrand, frozenset({'loc'}))
    res = res.reduce(ops.mul, frozenset({'data'}))
    assert not torch.isinf(res).any()
Example #3
def test_mc_plate_gaussian():
    log_measure = Gaussian(numeric_array([0.]), numeric_array([[1.]]),
                           (('loc', Real),)) + numeric_array(-0.9189)
    integrand = Gaussian(randn((100, 1)) + 3., ones((100, 1, 1)),
                         (('data', Bint[100]), ('loc', Real)))

    rng_key = None if get_backend() != 'jax' else np.array([0, 0], dtype=np.uint32)
    res = Integrate(log_measure.sample('loc', rng_key=rng_key), integrand, 'loc')
    res = res.reduce(ops.mul, 'data')
    assert not ((res == float('inf')) | (res == float('-inf'))).any()
Example #4
def _assert_conjugate_density_ok(latent, conditional, obs, lazy_latent=None,
                                 num_samples=10000, prec=1e-2):
    sample_inputs = OrderedDict(n=Bint[num_samples])
    lazy_latent = lazy_latent if lazy_latent is not None else latent
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
    latent_samples = lazy_latent.sample(frozenset(["prior"]), sample_inputs, rng_key=rng_key)
    expected = Integrate(latent_samples, conditional(value=obs).exp(), frozenset(['prior']))
    expected = expected.reduce(ops.add, frozenset(sample_inputs))
    actual = (latent + conditional).reduce(ops.logaddexp, set(["prior"]))(value=obs).exp()
    assert_close(actual, expected, atol=prec, rtol=None)
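
In words (a restatement of the check above, not new behavior): `expected` is a Monte Carlo estimate of the marginal likelihood,

    expected ~= integral over prior of exp(latent(prior)) * exp(conditional(prior, value=obs)),

obtained by drawing num_samples prior samples and summing the per-sample contributions over "n", while `actual` computes the same marginal analytically by logaddexp-reducing the joint log-density over "prior", substituting value=obs, and exponentiating. Conjugacy is what makes the analytic reduction available here.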
Example #5
def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples,
                   statistic, with_lazy, params):
    params = [Tensor(p, inputs) for p in params]
    if isinstance(with_lazy, bool):
        with interpretation(lazy if with_lazy else eager):
            funsor_dist = funsor_dist_class(*params)
    else:
        funsor_dist = funsor_dist_class(*params)

    rng_key = None if get_backend() == "torch" else np.array([0, 0],
                                                             dtype=np.uint32)
    sample_value = funsor_dist.sample(frozenset(['value']),
                                      sample_inputs,
                                      rng_key=rng_key)
    expected_inputs = OrderedDict(
        tuple(sample_inputs.items()) + tuple(inputs.items()) +
        (('value', funsor_dist.inputs['value']), ))
    check_funsor(sample_value, expected_inputs, reals())

    if sample_inputs:

        actual_mean = Integrate(
            sample_value,
            Variable('value', funsor_dist.inputs['value']),
            frozenset(['value'])
        ).reduce(ops.add, frozenset(sample_inputs))

        inputs, tensors = align_tensors(
            *list(funsor_dist.params.values())[:-1])
        raw_dist = funsor_dist.dist_class(
            **dict(zip(funsor_dist._ast_fields[:-1], tensors)))
        expected_mean = Tensor(raw_dist.mean, inputs)

        if statistic == "mean":
            actual_stat, expected_stat = actual_mean, expected_mean
        elif statistic == "variance":
            actual_stat = Integrate(
                sample_value, (Variable('value', funsor_dist.inputs['value']) -
                               actual_mean)**2,
                frozenset(['value'])).reduce(ops.add, frozenset(sample_inputs))
            expected_stat = Tensor(raw_dist.variance, inputs)
        elif statistic == "entropy":
            actual_stat = -Integrate(sample_value, funsor_dist,
                                     frozenset(['value'])).reduce(
                                         ops.add, frozenset(sample_inputs))
            expected_stat = Tensor(raw_dist.entropy(), inputs)
        else:
            raise ValueError("invalid test statistic")

        diff = actual_stat.reduce(ops.add).data - expected_stat.reduce(
            ops.add).data
        return diff.sum(), diff
Example #6
def monte_carlo_integrate(log_measure, integrand, reduced_vars):
    sample = log_measure.sample(reduced_vars, monte_carlo.sample_inputs)
    if sample is log_measure:
        return None  # cannot progress
    reduced_vars |= frozenset(monte_carlo.sample_inputs).intersection(
        sample.inputs)
    return Integrate(sample, integrand, reduced_vars)
Example #7
File: joint.py Project: fehiepsi/funsor
def monte_carlo_integrate(log_measure, integrand, reduced_vars):
    sampled_log_measure = log_measure.sample(reduced_vars,
                                             monte_carlo.sample_inputs)
    if sampled_log_measure is not log_measure:
        reduced_vars = reduced_vars | frozenset(monte_carlo.sample_inputs)
        return Integrate(sampled_log_measure, integrand, reduced_vars)

    return None  # defer to default implementation
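
For orientation (a sketch, not from the original listing): both handlers above implement the Monte Carlo rewrite of Integrate and differ only in how the fresh sample dimensions are folded into reduced_vars (Example #6 keeps only those that actually appear in the sample's inputs; Example #7 adds them all). They are triggered by evaluating Integrate under the Monte Carlo interpretation, as in Example #8. A minimal end-to-end sketch, with import paths assumed from the older funsor API used above (newer releases spell these Bint/Real and MonteCarlo):

from collections import OrderedDict

import funsor
from funsor.domains import bint, reals
from funsor.montecarlo import monte_carlo_interpretation
from funsor.testing import random_gaussian

inputs = OrderedDict(x=reals())
log_measure = random_gaussian(inputs)  # an unnormalized log-density over 'x'
integrand = random_gaussian(inputs)

# Under the Monte Carlo interpretation, Integrate dispatches to a handler like
# monte_carlo_integrate above, which samples 'x' with a new 'particle' input.
with monte_carlo_interpretation(particle=bint(10000)):
    approx = funsor.Integrate(log_measure, integrand, frozenset(["x"]))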
Example #8
def test_integrate_gaussian(int_inputs, real_inputs):
    int_inputs = OrderedDict(sorted(int_inputs.items()))
    real_inputs = OrderedDict(sorted(real_inputs.items()))
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    log_measure = random_gaussian(inputs)
    integrand = random_gaussian(inputs)
    reduced_vars = frozenset(real_inputs)

    with monte_carlo_interpretation(particle=bint(10000)):
        approx = Integrate(log_measure, integrand, reduced_vars)
        assert isinstance(approx, Tensor)

    exact = Integrate(log_measure, integrand, reduced_vars)
    assert isinstance(exact, Tensor)
    assert_close(approx, exact, atol=0.1, rtol=0.1)
Example #9
def test_integrate_gaussian(int_inputs, real_inputs):
    int_inputs = OrderedDict(sorted(int_inputs.items()))
    real_inputs = OrderedDict(sorted(real_inputs.items()))
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    log_measure = random_gaussian(inputs)
    integrand = random_gaussian(inputs)
    reduced_vars = frozenset(real_inputs)

    sampled_log_measure = log_measure.sample(reduced_vars, OrderedDict(particle=bint(10000)))
    approx = Integrate(sampled_log_measure, integrand, reduced_vars | {'particle'})
    assert isinstance(approx, Tensor)

    exact = Integrate(log_measure, integrand, reduced_vars)
    assert isinstance(exact, Tensor)
    assert_close(approx, exact, atol=0.1, rtol=0.1)
Example #10
def test_integrate_gaussian(int_inputs, real_inputs):
    int_inputs = OrderedDict(sorted(int_inputs.items()))
    real_inputs = OrderedDict(sorted(real_inputs.items()))
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    log_measure = random_gaussian(inputs)
    integrand = random_gaussian(inputs)
    reduced_vars = frozenset(real_inputs)

    rng_key = None if get_backend() != 'jax' else np.array([0, 0], dtype=np.uint32)
    sampled_log_measure = log_measure.sample(reduced_vars, OrderedDict(particle=Bint[100000]), rng_key=rng_key)
    approx = Integrate(sampled_log_measure, integrand, reduced_vars | {'particle'})
    assert isinstance(approx, Tensor)

    exact = Integrate(log_measure, integrand, reduced_vars)
    assert isinstance(exact, Tensor)
    assert_close(approx, exact, atol=0.1, rtol=0.1)
Example #11
File: joint.py Project: fehiepsi/funsor
def eager_integrate(log_measure, integrand, reduced_vars):
    name = integrand.name
    assert reduced_vars == frozenset([name])
    if any(d.name == name for d in log_measure.deltas):
        deltas = tuple(d for d in log_measure.deltas if d.name != name)
        log_norm = Joint(deltas, log_measure.discrete, log_measure.gaussian)
        for d in log_measure.deltas:
            if d.name == name:
                mean = d.point
                break
        return mean * log_norm.exp()
    elif name in log_measure.discrete.inputs:
        integrand = arange(name, integrand.inputs[name].dtype)
        return Integrate(log_measure, integrand, reduced_vars)
    else:
        assert name in log_measure.gaussian.inputs
        gaussian = Integrate(log_measure.gaussian, integrand, reduced_vars)
        return Joint(log_measure.deltas, log_measure.discrete).exp() * gaussian
Example #12
File: joint.py Project: fehiepsi/funsor
def _simplify_integrate(fn, joint, integrand, reduced_vars):
    if any(d.name in reduced_vars for d in joint.deltas):
        subs = tuple(
            (d.name, d.point) for d in joint.deltas if d.name in reduced_vars)
        deltas = tuple(d for d in joint.deltas if d.name not in reduced_vars)
        log_measure = Joint(deltas, joint.discrete, joint.gaussian)
        integrand = Subs(integrand, subs)
        reduced_vars = reduced_vars - frozenset(name for name, point in subs)
        return Integrate(log_measure, integrand, reduced_vars)

    return fn(joint, integrand, reduced_vars)
Example #13
def monte_carlo_integrate(state, log_measure, integrand, reduced_vars):
    sample_options = {}
    if state.rng_key is not None and get_backend() == "jax":
        import jax

        sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key)

    sample = log_measure.sample(reduced_vars, state.sample_inputs, **sample_options)
    if sample is log_measure:
        return None  # cannot progress
    reduced_vars |= frozenset(state.sample_inputs).intersection(sample.inputs)
    return Integrate(sample, integrand, reduced_vars)
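
Background for the rng_key handling above (standard JAX behavior, not specific to funsor): JAX has no global random state, so samplers take an explicit PRNG key, and independent draws require splitting the key rather than reusing it. That is why this handler splits state.rng_key before each sample call, and why the jax-backend examples above pass rng_key explicitly. In plain JAX:

import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)        # keep `key` for later draws
x = jax.random.normal(subkey, shape=(3,))  # consume `subkey` exactly once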
Example #14
def test_lognormal_distribution(moment):
    num_samples = 100000
    inputs = OrderedDict(batch=bint(10))
    loc = random_tensor(inputs)
    scale = random_tensor(inputs).exp()

    log_measure = dist.LogNormal(loc, scale)(value='x')
    probe = Variable('x', reals())**moment
    with monte_carlo_interpretation(particle=bint(num_samples)):
        with xfail_if_not_implemented():
            actual = Integrate(log_measure, probe, frozenset(['x']))

    samples = backend_dist.LogNormal(loc, scale).sample((num_samples, ))
    expected = (samples**moment).mean(0)
    assert_close(actual.data, expected, atol=1e-2, rtol=1e-2)
Example #15
def test_syntactic_sugar():
    i = Variable("i", bint(3))
    log_measure = random_tensor(OrderedDict(i=bint(3)))
    integrand = random_tensor(OrderedDict(i=bint(3)))
    expected = (log_measure.exp() * integrand).reduce(ops.add, "i")
    assert_close(Integrate(log_measure, integrand, "i"), expected)
    assert_close(Integrate(log_measure, integrand, {"i"}), expected)
    assert_close(Integrate(log_measure, integrand, frozenset(["i"])), expected)
    assert_close(Integrate(log_measure, integrand, i), expected)
    assert_close(Integrate(log_measure, integrand, {i}), expected)
    assert_close(Integrate(log_measure, integrand, frozenset([i])), expected)
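
Spelled out (a restatement of the expected value above, not new behavior), Integrate over a discrete variable is

    Integrate(log_measure, integrand, {"i"}) = sum_i exp(log_measure(i)) * integrand(i)
                                             = (log_measure.exp() * integrand).reduce(ops.add, "i"),

and the remaining assertions only check that a string, a set, a frozenset, a Variable, and sets of Variables are all accepted as reduced_vars.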
Example #16
def test_lognormal_distribution(moment):
    num_samples = 100000
    inputs = OrderedDict(batch=Bint[10])
    loc = random_tensor(inputs)
    scale = random_tensor(inputs).exp()

    log_measure = dist.LogNormal(loc, scale)(value='x')
    probe = Variable('x', Real)**moment
    with interpretation(MonteCarlo(particle=Bint[num_samples])):
        with xfail_if_not_implemented():
            actual = Integrate(log_measure, probe, frozenset(['x']))

    _, (loc_data, scale_data) = align_tensors(loc, scale)
    samples = backend_dist.LogNormal(loc_data, scale_data).sample(
        (num_samples, ))
    expected = (samples**moment).mean(0)
    assert_close(actual.data, expected, atol=1e-2, rtol=1e-2)
Example #17
def test_reduce_moment_matching_moments():
    x = Variable('x', reals(2))
    gaussian = random_gaussian(
        OrderedDict([('i', bint(2)), ('j', bint(3)), ('x', reals(2))]))
    with interpretation(moment_matching):
        approx = gaussian.reduce(ops.logaddexp, 'j')
    with monte_carlo_interpretation(s=bint(100000)):
        actual = Integrate(approx, Number(1.), 'x')
        expected = Integrate(gaussian, Number(1.), {'j', 'x'})
        assert_close(actual, expected, atol=1e-3, rtol=1e-3)

        actual = Integrate(approx, x, 'x')
        expected = Integrate(gaussian, x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)

        actual = Integrate(approx, x * x, 'x')
        expected = Integrate(gaussian, x * x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)
Example #18
def test_reduce_moment_matching_moments():
    x = Variable('x', Reals[2])
    gaussian = random_gaussian(
        OrderedDict([('i', Bint[2]), ('j', Bint[3]), ('x', Reals[2])]))
    with interpretation(moment_matching):
        approx = gaussian.reduce(ops.logaddexp, 'j')
    with interpretation(MonteCarlo(s=Bint[100000])):
        actual = Integrate(approx, Number(1.), 'x')
        expected = Integrate(gaussian, Number(1.), {'j', 'x'})
        assert_close(actual, expected, atol=1e-3, rtol=1e-3)

        actual = Integrate(approx, x, 'x')
        expected = Integrate(gaussian, x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)

        actual = Integrate(approx, x * x, 'x')
        expected = Integrate(gaussian, x * x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)
Example #19
def _check_sample(funsor_dist,
                  sample_inputs,
                  inputs,
                  atol=1e-2,
                  rtol=None,
                  num_samples=100000,
                  statistic="mean",
                  skip_grad=False):
    """utility that compares a Monte Carlo estimate of a distribution mean with the true mean"""
    samples_per_dim = int(num_samples**(1. / max(1, len(sample_inputs))))
    sample_inputs = OrderedDict(
        (k, bint(samples_per_dim)) for k in sample_inputs)

    for tensor in list(funsor_dist.params.values())[:-1]:
        tensor.data.requires_grad_()

    sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs)
    expected_inputs = OrderedDict(
        tuple(sample_inputs.items()) + tuple(inputs.items()) +
        (('value', funsor_dist.inputs['value']), ))
    check_funsor(sample_value, expected_inputs, reals())

    if sample_inputs:

        actual_mean = Integrate(
            sample_value,
            Variable('value', funsor_dist.inputs['value']),
            frozenset(['value'])
        ).reduce(ops.add, frozenset(sample_inputs))

        inputs, tensors = align_tensors(
            *list(funsor_dist.params.values())[:-1])
        raw_dist = funsor_dist.dist_class(
            **dict(zip(funsor_dist._ast_fields[:-1], tensors)))
        expected_mean = Tensor(raw_dist.mean, inputs)

        check_funsor(actual_mean, expected_mean.inputs, expected_mean.output)
        assert_close(actual_mean, expected_mean, atol=atol, rtol=rtol)

    if sample_inputs and not skip_grad:
        if statistic == "mean":
            actual_stat, expected_stat = actual_mean, expected_mean
        elif statistic == "variance":
            actual_stat = Integrate(
                sample_value, (Variable('value', funsor_dist.inputs['value']) -
                               actual_mean)**2,
                frozenset(['value'])).reduce(ops.add, frozenset(sample_inputs))
            expected_stat = Tensor(raw_dist.variance, inputs)
        elif statistic == "entropy":
            actual_stat = -Integrate(sample_value, funsor_dist,
                                     frozenset(['value'])).reduce(
                                         ops.add, frozenset(sample_inputs))
            expected_stat = Tensor(raw_dist.entropy(), inputs)
        else:
            raise ValueError("invalid test statistic")

        grad_targets = [v.data for v in list(funsor_dist.params.values())[:-1]]
        actual_grads = torch.autograd.grad(
            actual_stat.reduce(ops.add).sum().data,
            grad_targets,
            allow_unused=True)
        expected_grads = torch.autograd.grad(
            expected_stat.reduce(ops.add).sum().data,
            grad_targets,
            allow_unused=True)

        assert_close(actual_stat, expected_stat, atol=atol, rtol=rtol)

        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            if expected_grad is not None:
                assert_close(actual_grad, expected_grad, atol=atol, rtol=rtol)
            else:
                assert_close(actual_grad,
                             torch.zeros_like(actual_grad),
                             atol=atol,
                             rtol=rtol)
Example #20
def test_bart(analytic_kl):
    global call_count
    call_count = 0

    with interpretation(reflect):
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ],
                                 [
                                     -0.6987162828445435, -1.0875964164733887,
                                     -0.7337473630905151, -0.4713417589664459,
                                     -0.6674002408981323, -1.2478348016738892,
                                     -0.8939017057418823, -0.5238542556762695
                                 ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [-0.20241987705230713],
                                 [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ],
                                         dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[0.7535632252693176]],
                                 [[1.4946647882461548]], [[2.642792224884033]],
                                 [[1.7301604747772217]], [[0.5203893780708313]],
                                 [[1.055436372756958]], [[2.8370864391326904]]]
                            ],
                                         dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736,
                                             dtype=torch.float32), (), 'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0],
                                             dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836,
                                    -0.0
                                ],
                                              [
                                                  0.0, 98.01002502441406, -0.0,
                                                  -99.0000228881836
                                              ],
                                              [
                                                  -99.0000228881836, -0.0,
                                                  100.0000228881836, 0.0
                                              ],
                                              [
                                                  -0.0, -99.0000228881836, 0.0,
                                                  100.0000228881836
                                              ]],
                                             dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824,
                                                0.07356668263673782,
                                                0.19946961104869843,
                                                0.5386509299278259,
                                                -0.708323061466217,
                                                0.24411526322364807,
                                                -0.20855577290058136,
                                                -0.2421337217092514
                                            ],
                                             [
                                                 0.41762110590934753,
                                                 0.5272183418273926,
                                                 -0.49835553765296936,
                                                 -0.0363837406039238,
                                                 -0.0005282597267068923,
                                                 0.2704298794269562,
                                                 -0.155222088098526,
                                                 -0.44802337884902954
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495,
                                                -0.2848514914512634,
                                                0.037103548645973206,
                                                0.12648648023605347,
                                                -0.18501518666744232,
                                                -0.20899859070777893,
                                                0.04121830314397812,
                                                0.0054807960987091064
                                            ],
                                             [
                                                 0.0021788496524095535,
                                                 -0.18700894713401794,
                                                 0.08187370002269745,
                                                 0.13554862141609192,
                                                 -0.10477752983570099,
                                                 -0.20848378539085388,
                                                 -0.01393645629286766,
                                                 0.011670656502246857
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664,
                                                0.864071786403656,
                                                1.0236268043518066,
                                                0.7147538065910339,
                                                0.7423890233039856,
                                                0.9462157487869263,
                                                1.2132389545440674,
                                                1.0596832036972046
                                            ],
                                             [
                                                 0.5787821412086487,
                                                 0.9178534150123596,
                                                 0.9074794054031372,
                                                 0.6600189208984375,
                                                 0.8473222255706787,
                                                 0.8426999449729919,
                                                 1.194266438484192,
                                                 1.0471148490905762
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                         dtype=torch.float32),
                            (), 'real'), Variable('value_b5', reals(2, ))), ((
                                'value_b5',
                                Variable('state_b10', reals(2, )),
                            ), )),
            ))
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0, reals(2, 2, 2),
                                            (Variable('gate_rate_b12',
                                                      reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t', reals(2,
                                                                         8)),
                                                Variable('time_b17', bint(2))),
                                        ), )), Variable('origin_b15',
                                                        bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2), (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )), ))

    if analytic_kl:
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
Example #21
File: delta.py Project: fehiepsi/funsor
def eager_integrate(delta, integrand, reduced_vars):
    assert delta.name in reduced_vars
    integrand = Subs(integrand, ((delta.name, delta.point), ))
    log_measure = delta.log_density
    reduced_vars -= frozenset([delta.name])
    return Integrate(log_measure, integrand, reduced_vars)
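
This rule encodes the sifting property of a point mass (a restatement of what the code does, assuming a Delta funsor with fields name, point, and log_density):

    integral of exp(log_density) * delta_point(x) * integrand(x) dx = exp(log_density) * integrand(point),

so the integrand is evaluated at delta.point via Subs, delta.name is dropped from reduced_vars, and the remaining weight exp(log_density) is carried by the recursive Integrate call.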
Example #22
def optimize_reduce_binary_exp(op, arg, reduced_vars):
    if op is not ops.add or arg.op is not ops.mul or \
            not isinstance(arg.lhs, Unary) or arg.lhs.op is not ops.exp:
        return None
    return Integrate(arg.lhs.arg, arg.rhs, reduced_vars)
Example #23
def test_integrate(interp):
    log_measure = random_tensor(OrderedDict([('i', bint(2)), ('j', bint(3))]))
    integrand = random_tensor(OrderedDict([('j', bint(3)), ('k', bint(4))]))
    with interpretation(interp):
        Integrate(log_measure, integrand, {'i', 'j', 'k'})
Example #24
def optimize_contract_exp_funsor(sum_op, prod_op, lhs, rhs, reduced_vars):
    if lhs.op is ops.exp and isinstance(lhs.arg, (Gaussian, Tensor, Delta, Joint)) and \
            sum_op is ops.add and prod_op is ops.mul:
        return Integrate(lhs.arg, rhs, reduced_vars)
    return None
Example #25
def test_integrate(interp):
    log_measure = random_tensor(OrderedDict([('i', Bint[2]), ('j', Bint[3])]))
    integrand = random_tensor(OrderedDict([('j', Bint[3]), ('k', Bint[4])]))
    with interpretation(interp):
        Integrate(log_measure, integrand, {'i', 'j', 'k'})