Example #1
def test_stack_subs():
    x = Variable('x', reals())
    y = Variable('y', reals())
    z = Variable('z', reals())
    j = Variable('j', bint(3))

    f = Stack('i', (Number(0), x, y * z))
    check_funsor(f, {
        'i': bint(3),
        'x': reals(),
        'y': reals(),
        'z': reals()
    }, reals())

    assert f(i=Number(0, 3)) is Number(0)
    assert f(i=Number(1, 3)) is x
    assert f(i=Number(2, 3)) is y * z
    assert f(i=j) is Stack('j', (Number(0), x, y * z))
    assert f(i='j') is Stack('j', (Number(0), x, y * z))
    assert f.reduce(ops.add, 'i') is Number(0) + x + (y * z)

    assert f(x=0) is Stack('i', (Number(0), Number(0), y * z))
    assert f(y=x) is Stack('i', (Number(0), x, x * z))
    assert f(x=0, y=x) is Stack('i', (Number(0), Number(0), x * z))
    assert f(x=0, y=x, i=Number(2, 3)) is x * z
    assert f(x=0, i=j) is Stack('j', (Number(0), Number(0), y * z))
    assert f(x=0, i='j') is Stack('j', (Number(0), Number(0), y * z))
Example #2
def test_stack_slice(start, stop, step):
    xs = tuple(map(Number, range(10)))
    actual = Stack('i', xs)(i=Slice('j', start, stop, step, dtype=10))
    expected = Stack('j', xs[start:stop:step])
    assert type(actual) == type(expected)
    assert actual.name == expected.name
    assert actual.parts == expected.parts
Example #3
def mixed_sequential_sum_product(sum_op, prod_op, trans, time, step, num_segments=None):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
           .reduce(sum_op, "drop")

    by mixing parallel and serial scan algorithms over ``num_segments`` segments.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    :param int num_segments: Number of segments for the first stage.
    """
    time_var, time, duration = time, time.name, time.output.size
    num_segments = duration if num_segments is None else num_segments
    assert num_segments > 0 and duration > 0

    # handle unevenly sized segments: eliminate the evenly divisible prefix recursively,
    # then fold the ragged final segment in with a serial scan
    if duration % num_segments and duration - duration % num_segments > 0:
        remainder = trans(**{time: Slice(time, duration - duration % num_segments, duration, 1, duration)})
        initial = trans(**{time: Slice(time, 0, duration - duration % num_segments, 1, duration)})
        initial_eliminated = mixed_sequential_sum_product(
            sum_op, prod_op, initial, Variable(time, bint(duration - duration % num_segments)), step,
            num_segments=num_segments)
        final = Cat(time, (Stack(time, (initial_eliminated,)), remainder))
        final_eliminated = naive_sequential_sum_product(
            sum_op, prod_op, final, Variable(time, bint(1 + duration % num_segments)), step)
        return final_eliminated

    # handle degenerate cases that reduce to a single stage
    if num_segments == 1:
        return naive_sequential_sum_product(sum_op, prod_op, trans, time_var, step)
    if num_segments >= duration:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, step)

    # break trans into num_segments segments of equal length
    segment_length = duration // num_segments
    segments = [trans(**{time: Slice(time, i * segment_length, (i + 1) * segment_length, 1, duration)})
                for i in range(num_segments)]

    first_stage_result = naive_sequential_sum_product(
        sum_op, prod_op, Stack(time + "__SEGMENTED", tuple(segments)),
        Variable(time, bint(segment_length)), step)

    second_stage_result = sequential_sum_product(
        sum_op, prod_op, first_stage_result,
        Variable(time + "__SEGMENTED", bint(num_segments)), step)

    return second_stage_result
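A minimal usage sketch for mixed_sequential_sum_product, assuming the older funsor API used above (bint/reals) and that the scan helpers live in funsor.sum_product; the chain length 6, state size 2, and num_segments=3 are illustrative:

from collections import OrderedDict

from funsor import ops
from funsor.domains import bint
from funsor.sum_product import (mixed_sequential_sum_product,
                                naive_sequential_sum_product)
from funsor.terms import Variable
from funsor.testing import assert_close, random_tensor

# a hypothetical transition funsor over 6 time steps of a 2-state chain
trans = random_tensor(OrderedDict(time=bint(6), prev=bint(2), curr=bint(2)))
time = Variable("time", bint(6))
step = {"prev": "curr"}

# the mixed parallel/serial scan should agree with the fully serial scan
expected = naive_sequential_sum_product(ops.logaddexp, ops.add, trans, time, step)
actual = mixed_sequential_sum_product(ops.logaddexp, ops.add, trans, time, step,
                                      num_segments=3)
assert_close(actual, expected, rtol=1e-4)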
Example #4
def test_cat_simple():
    x = Stack('i', (Number(0), Number(1), Number(2)))
    y = Stack('i', (Number(3), Number(4)))

    assert Cat('i', (x, )) is x
    assert Cat('i', (y, )) is y

    xy = Cat('i', (x, y))
    assert xy.inputs == OrderedDict(i=bint(5))
    assert xy.name == 'i'
    for i in range(5):
        assert xy(i=i) is Number(i)
Example #5
def test_stack_simple():
    x = Number(0.)
    y = Number(1.)
    z = Number(4.)

    xyz = Stack('i', (x, y, z))
    check_funsor(xyz, {'i': bint(3)}, reals())

    assert xyz(i=Number(0, 3)) is x
    assert xyz(i=Number(1, 3)) is y
    assert xyz(i=Number(2, 3)) is z
    assert xyz.reduce(ops.add, 'i') == 5.
Example #6
def test_reduce_syntactic_sugar():
    i = Variable("i", bint(3))
    x = Stack("i", (Number(1), Number(2), Number(3)))
    expected = Number(1 + 2 + 3)
    assert x.reduce(ops.add) is expected
    assert x.reduce(ops.add, "i") is expected
    assert x.reduce(ops.add, {"i"}) is expected
    assert x.reduce(ops.add, frozenset(["i"])) is expected
    assert x.reduce(ops.add, i) is expected
    assert x.reduce(ops.add, {i}) is expected
    assert x.reduce(ops.add, frozenset([i])) is expected
Example #7
def test_quote(interp):
    with interpretation(interp):
        x = Variable('x', bint(8))
        check_quote(x)

        y = Variable('y', reals(8, 3, 3))
        check_quote(y)
        check_quote(y[x])

        z = Stack('i', (Number(0), Variable('z', reals())))
        check_quote(z)
        check_quote(z(i=0))
        check_quote(z(i=Slice('i', 0, 1, 1, 2)))
        check_quote(z.reduce(ops.add, 'i'))
        check_quote(Cat('i', (z, z, z)))
        check_quote(Lambda(Variable('i', bint(2)), z))
Example #8
def Uniform(components):
    components = tuple(components)
    size = len(components)
    if size == 1:
        return components[0]
    var = Variable('v', bint(size))
    return (Stack(var.name, components).reduce(ops.logaddexp, var.name) -
            math.log(size))
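A brief usage sketch of Uniform as defined above: it returns the logsumexp of its components minus log(size), i.e. the log-density of an even mixture. The Number components below are hypothetical:

import math

from funsor.terms import Number

# hypothetical component log-densities log(0.5) and log(0.25) at the same point
mix = Uniform((Number(math.log(0.5)), Number(math.log(0.25))))
# this reduces eagerly to a Number close to log((0.5 + 0.25) / 2)
assert math.isclose(float(mix.data), math.log(0.375), rel_tol=1e-6)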
Example #9
def test_funsor_stack(output):
    x = random_tensor(OrderedDict([
        ('i', bint(2)),
    ]), output)
    y = random_tensor(OrderedDict([
        ('j', bint(3)),
    ]), output)
    z = random_tensor(OrderedDict([
        ('i', bint(2)),
        ('k', bint(4)),
    ]), output)

    xy = Stack('t', (x, y))
    assert isinstance(xy, Tensor)
    assert xy.inputs == OrderedDict([
        ('t', bint(2)),
        ('i', bint(2)),
        ('j', bint(3)),
    ])
    assert xy.output == output
    for j in range(3):
        assert_close(xy(t=0, j=j), x)
    for i in range(2):
        assert_close(xy(t=1, i=i), y)

    xyz = Stack('t', (x, y, z))
    assert isinstance(xyz, Tensor)
    assert xyz.inputs == OrderedDict([
        ('t', bint(3)),
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ])
    assert xyz.output == output
    for j in range(3):
        for k in range(4):
            assert_close(xyz(t=0, j=j, k=k), x)
    for i in range(2):
        for k in range(4):
            assert_close(xyz(t=1, i=i, k=k), y)
    for j in range(3):
        assert_close(xyz(t=2, j=j), z)
Example #10
def test_stack_subs():
    x = Variable('x', Real)
    y = Variable('y', Real)
    z = Variable('z', Real)
    j = Variable('j', Bint[3])

    f = Stack('i', (Number(0), x, y * z))
    check_funsor(f, {'i': Bint[3], 'x': Real, 'y': Real, 'z': Real}, Real)

    assert f(i=Number(0, 3)) is Number(0)
    assert f(i=Number(1, 3)) is x
    assert f(i=Number(2, 3)) is y * z
    assert f(i=j) is Stack('j', (Number(0), x, y * z))
    assert f(i='j') is Stack('j', (Number(0), x, y * z))
    assert f.reduce(ops.add, 'i') is Number(0) + x + (y * z)

    assert f(x=0) is Stack('i', (Number(0), Number(0), y * z))
    assert f(y=x) is Stack('i', (Number(0), x, x * z))
    assert f(x=0, y=x) is Stack('i', (Number(0), Number(0), x * z))
    assert f(x=0, y=x, i=Number(2, 3)) is x * z
    assert f(x=0, i=j) is Stack('j', (Number(0), Number(0), y * z))
    assert f(x=0, i='j') is Stack('j', (Number(0), Number(0), y * z))
Example #11
def test_bart(analytic_kl):
    global call_count
    call_count = 0

    with interpretation(reflect):
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ],
                                 [
                                     -0.6987162828445435, -1.0875964164733887,
                                     -0.7337473630905151, -0.4713417589664459,
                                     -0.6674002408981323, -1.2478348016738892,
                                     -0.8939017057418823, -0.5238542556762695
                                 ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ],
                                         dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]],
                                 [[2.642792224884033]], [[1.7301604747772217]],
                                 [[0.5203893780708313]], [[1.055436372756958]],
                                 [[2.8370864391326904]]]
                            ],
                                         dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736,
                                             dtype=torch.float32), (), 'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0],
                                             dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836,
                                    -0.0
                                ],
                                              [
                                                  0.0, 98.01002502441406, -0.0,
                                                  -99.0000228881836
                                              ],
                                              [
                                                  -99.0000228881836, -0.0,
                                                  100.0000228881836, 0.0
                                              ],
                                              [
                                                  -0.0, -99.0000228881836, 0.0,
                                                  100.0000228881836
                                              ]],
                                             dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824,
                                                0.07356668263673782,
                                                0.19946961104869843,
                                                0.5386509299278259,
                                                -0.708323061466217,
                                                0.24411526322364807,
                                                -0.20855577290058136,
                                                -0.2421337217092514
                                            ],
                                             [
                                                 0.41762110590934753,
                                                 0.5272183418273926,
                                                 -0.49835553765296936,
                                                 -0.0363837406039238,
                                                 -0.0005282597267068923,
                                                 0.2704298794269562,
                                                 -0.155222088098526,
                                                 -0.44802337884902954
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495,
                                                -0.2848514914512634,
                                                0.037103548645973206,
                                                0.12648648023605347,
                                                -0.18501518666744232,
                                                -0.20899859070777893,
                                                0.04121830314397812,
                                                0.0054807960987091064
                                            ],
                                             [
                                                 0.0021788496524095535,
                                                 -0.18700894713401794,
                                                 0.08187370002269745,
                                                 0.13554862141609192,
                                                 -0.10477752983570099,
                                                 -0.20848378539085388,
                                                 -0.01393645629286766,
                                                 0.011670656502246857
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664,
                                                0.864071786403656,
                                                1.0236268043518066,
                                                0.7147538065910339,
                                                0.7423890233039856,
                                                0.9462157487869263,
                                                1.2132389545440674,
                                                1.0596832036972046
                                            ],
                                             [
                                                 0.5787821412086487,
                                                 0.9178534150123596,
                                                 0.9074794054031372,
                                                 0.6600189208984375,
                                                 0.8473222255706787,
                                                 0.8426999449729919,
                                                 1.194266438484192,
                                                 1.0471148490905762
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                         dtype=torch.float32),
                            (), 'real'), Variable('value_b5', reals(2, ))), ((
                                'value_b5',
                                Variable('state_b10', reals(2, )),
                            ), )),
            ))
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0, reals(2, 2, 2),
                                            (Variable('gate_rate_b12',
                                                      reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t', reals(2,
                                                                         8)),
                                                Variable('time_b17', bint(2))),
                                        ), )), Variable('origin_b15',
                                                        bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2), (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )), ))

    if analytic_kl:
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
Example #12
def test_cat(name):
    with interpretation(reflect):
        x = Stack("t", (Number(1), Number(2)))
        y = Stack("t", (Number(4), Number(8), Number(16)))
        xy = Cat(name, (x, y), "t")
        xy.reduce(ops.add)
Example #13
    def __call__(self):

        # calls pyro.param so that params are exposed and constraints applied
        # should not create any new torch.Tensors after __init__
        self.initialize_params()

        N_state = self.config["sizes"]["state"]

        # initialize gamma to uniform
        gamma = Tensor(
            torch.zeros((N_state, N_state)),
            OrderedDict([("y_prev", bint(N_state))]),
        )

        N_v = self.config["sizes"]["random"]
        N_c = self.config["sizes"]["group"]
        log_prob = []

        plate_g = Tensor(torch.zeros(N_c), OrderedDict([("g", bint(N_c))]))

        # group-level random effects
        if self.config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = Variable("e_g", bint(N_v))
            e_g_dist = plate_g + dist.Categorical(**self.params["e_g"])(value=e_g)

            log_prob.append(e_g_dist)

            eps_g = (plate_g + self.params["eps_g"]["theta"])(e_g=e_g)

        elif self.config["group"]["random"] == "continuous":
            eps_g = Variable("eps_g", reals(N_state))
            eps_g_dist = plate_g + dist.Normal(**self.params["eps_g"])(value=eps_g)

            log_prob.append(eps_g_dist)
        else:
            eps_g = to_funsor(0.)

        N_s = self.config["sizes"]["individual"]

        plate_i = Tensor(torch.zeros(N_s), OrderedDict([("i", bint(N_s))]))
        # individual-level random effects
        if self.config["individual"]["random"] == "discrete":
            # individual-level discrete effect
            e_i = Variable("e_i", bint(N_v))
            e_i_dist = plate_g + plate_i + dist.Categorical(
                **self.params["e_i"]
            )(value=e_i) * self.raggedness_masks["individual"](t=0)

            log_prob.append(e_i_dist)

            eps_i = (plate_i + plate_g + self.params["eps_i"]["theta"](e_i=e_i))

        elif self.config["individual"]["random"] == "continuous":
            eps_i = Variable("eps_i", reals(N_state))
            eps_i_dist = plate_g + plate_i + dist.Normal(**self.params["eps_i"])(value=eps_i)

            log_prob.append(eps_i_dist)
        else:
            eps_i = to_funsor(0.)

        # add group-level and individual-level random effects to gamma
        gamma = gamma + eps_g + eps_i

        N_state = self.config["sizes"]["state"]

        # we've accounted for all effects, now actually compute gamma_y
        gamma_y = gamma(y_prev="y(t=1)")

        y = Variable("y", bint(N_state))
        y_dist = plate_g + plate_i + dist.Categorical(
            probs=gamma_y.exp() / gamma_y.exp().sum()
        )(value=y)

        # observation 1: step size
        step_dist = plate_g + plate_i + dist.Gamma(
            **{k: v(y_curr=y) for k, v in self.params["step"].items()}
        )(value=self.observations["step"])

        # step size zero-inflation
        if self.config["zeroinflation"]:
            step_zi = dist.Categorical(probs=self.params["zi_step"]["zi_param"](y_curr=y))(
                value="zi_step")
            step_zi_dist = plate_g + plate_i + dist.Delta(self.config["MISSING"], 0.)(
                value=self.observations["step"])
            step_dist = (step_zi + Stack("zi_step", (step_dist, step_zi_dist))).reduce(ops.logaddexp, "zi_step")

        # observation 2: step angle
        angle_dist = plate_g + plate_i + dist.VonMises(
            **{k: v(y_curr=y) for k, v in self.params["angle"].items()}
        )(value=self.observations["angle"])

        # observation 3: dive activity
        omega_dist = plate_g + plate_i + dist.Beta(
            **{k: v(y_curr=y) for k, v in self.params["omega"].items()}
        )(value=self.observations["omega"])

        # dive activity zero-inflation
        if self.config["zeroinflation"]:
            omega_zi = dist.Categorical(probs=self.params["zi_omega"]["zi_param"](y_curr=y))(
                value="zi_omega")
            omega_zi_dist = plate_g + plate_i + dist.Delta(self.config["MISSING"], 0.)(
                value=self.observations["omega"])
            omega_dist = (omega_zi + Stack("zi_omega", (omega_dist, omega_zi_dist))).reduce(ops.logaddexp, "zi_omega")

        # finally, construct the term for parallel scan reduction
        hmm_factor = step_dist + angle_dist + omega_dist
        hmm_factor = hmm_factor * self.raggedness_masks["individual"]
        hmm_factor = hmm_factor * self.raggedness_masks["timestep"]
        # copy masking behavior of pyro.infer.TraceEnum_ELBO._compute_model_factors
        hmm_factor = hmm_factor + y_dist
        log_prob.insert(0, hmm_factor)

        return log_prob
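The zero-inflation factors above follow a general mixture pattern: stack the two alternative log-density funsors along a fresh discrete variable, add the log mixing weights, and eliminate that variable with reduce(ops.logaddexp, ...). A standalone sketch of the pattern with hypothetical tensors, using the same older bint/reals API:

from collections import OrderedDict

from funsor import ops
from funsor.domains import bint
from funsor.terms import Stack
from funsor.testing import random_tensor

# two hypothetical log-density factors over a plate of 3 data points
log_p_obs = random_tensor(OrderedDict(i=bint(3)))
log_p_zero = random_tensor(OrderedDict(i=bint(3)))
# hypothetical log mixing weights for the two cases
log_w = random_tensor(OrderedDict(z=bint(2)))

# stack the two cases along the fresh variable "z", then marginalize it out
mixed = (log_w + Stack("z", (log_p_obs, log_p_zero))).reduce(ops.logaddexp, "z")
assert set(mixed.inputs) == {"i"}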