Example #1
def mixed_sequential_sum_product(sum_op, prod_op, trans, time, step, num_segments=None):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
           .reduce(sum_op, "drop")

    by mixing parallel and serial scan algorithms over ``num_segments`` segments.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    :param int num_segments: number of segments for the first stage
    """
    time_var, time, duration = time, time.name, time.output.size
    num_segments = duration if num_segments is None else num_segments
    assert num_segments > 0 and duration > 0

    # handle unevenly sized segments by chopping off the final segment and calling mixed_sequential_sum_product again
    if duration % num_segments and duration - duration % num_segments > 0:
        remainder = trans(**{time: Slice(time, duration - duration % num_segments, duration, 1, duration)})
        initial = trans(**{time: Slice(time, 0, duration - duration % num_segments, 1, duration)})
        initial_eliminated = mixed_sequential_sum_product(
            sum_op, prod_op, initial, Variable(time, bint(duration - duration % num_segments)), step,
            num_segments=num_segments)
        final = Cat(time, (Stack(time, (initial_eliminated,)), remainder))
        final_eliminated = naive_sequential_sum_product(
            sum_op, prod_op, final, Variable(time, bint(1 + duration % num_segments)), step)
        return final_eliminated

    # handle degenerate cases that reduce to a single stage
    if num_segments == 1:
        return naive_sequential_sum_product(sum_op, prod_op, trans, time_var, step)
    if num_segments >= duration:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, step)

    # break trans into num_segments segments of equal length
    segment_length = duration // num_segments
    segments = [trans(**{time: Slice(time, i * segment_length, (i + 1) * segment_length, 1, duration)})
                for i in range(num_segments)]

    first_stage_result = naive_sequential_sum_product(
        sum_op, prod_op, Stack(time + "__SEGMENTED", tuple(segments)),
        Variable(time, bint(segment_length)), step)

    second_stage_result = sequential_sum_product(
        sum_op, prod_op, first_stage_result,
        Variable(time + "__SEGMENTED", bint(num_segments)), step)

    return second_stage_result
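
A minimal usage sketch of the function above, assuming it is the `mixed_sequential_sum_product` exported from `funsor.sum_product` and that the PyTorch backend is installed; the shapes, variable names, and `num_segments` value are illustrative only, and exact import names such as `Bint` and `set_backend` vary between funsor releases:

# Illustrative only: eliminate a chain of 12 random log-space transition
# factors, serially within 3 segments and then in parallel across segments.
import torch

import funsor
import funsor.ops as ops
from funsor import Bint, Tensor, Variable
from funsor.sum_product import mixed_sequential_sum_product

funsor.set_backend("torch")

trans = Tensor(torch.randn(12, 4, 4))["time", "prev", "curr"]
time = Variable("time", Bint[12])
result = mixed_sequential_sum_product(
    ops.logaddexp, ops.add, trans, time, {"prev": "curr"}, num_segments=3)
assert set(result.inputs) == {"prev", "curr"}  # only boundary variables remain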
Example #2
def sequential_sum_product(sum_op, prod_op, trans, time, step):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
           .reduce(sum_op, "drop")

    but does so efficiently in parallel in O(log(time)).

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    """
    assert isinstance(sum_op, AssociativeOp)
    assert isinstance(prod_op, AssociativeOp)
    assert isinstance(trans, Funsor)
    assert isinstance(time, Variable)
    assert isinstance(step, dict)
    assert all(isinstance(k, str) for k in step.keys())
    assert all(isinstance(v, str) for v in step.values())
    if time.name in trans.inputs:
        assert time.output == trans.inputs[time.name]

    step = OrderedDict(sorted(step.items()))
    drop = tuple("_drop_{}".format(i) for i in range(len(step)))
    prev_to_drop = dict(zip(step.keys(), drop))
    curr_to_drop = dict(zip(step.values(), drop))
    drop = frozenset(drop)

    time, duration = time.name, time.output.size
    while duration > 1:
        even_duration = duration // 2 * 2
        x = trans(**{time: Slice(time, 0, even_duration, 2, duration)},
                  **curr_to_drop)
        y = trans(**{time: Slice(time, 1, even_duration, 2, duration)},
                  **prev_to_drop)
        contracted = Contraction(sum_op, prod_op, drop, x, y)

        if duration > even_duration:
            extra = trans(**{time: Slice(time, duration - 1, duration)})
            contracted = Cat(time, (contracted, extra))
        trans = contracted
        duration = (duration + 1) // 2
    return trans(**{time: 0})
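
A hypothetical sanity check for the parallel scan above, assuming the `naive_sequential_sum_product` reference implementation and the `funsor.testing.assert_close` helper exist under these names in the installed funsor release; tensor shapes are arbitrary:

# Illustrative only: the O(log T) parallel scan should agree with the
# naive O(T) serial scan on a random log-space transition chain.
import torch

import funsor
import funsor.ops as ops
from funsor import Bint, Tensor, Variable
from funsor.sum_product import naive_sequential_sum_product, sequential_sum_product
from funsor.testing import assert_close

funsor.set_backend("torch")

trans = Tensor(torch.randn(8, 3, 3))["time", "prev", "curr"]
time = Variable("time", Bint[8])
step = {"prev": "curr"}
parallel = sequential_sum_product(ops.logaddexp, ops.add, trans, time, step)
serial = naive_sequential_sum_product(ops.logaddexp, ops.add, trans, time, step)
assert_close(parallel, serial, atol=1e-4, rtol=1e-4)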
Example #3
def test_stack_slice(start, stop, step):
    xs = tuple(map(Number, range(10)))
    actual = Stack('i', xs)(i=Slice('j', start, stop, step, dtype=10))
    expected = Stack('j', xs[start:stop:step])
    assert type(actual) == type(expected)
    assert actual.name == expected.name
    assert actual.parts == expected.parts
Example #4
def test_cons_hash():
    assert Variable('x', bint(3)) is Variable('x', bint(3))
    assert Variable('x', reals()) is Variable('x', reals())
    assert Variable('x', reals()) is not Variable('x', bint(3))
    assert Number(0, 3) is Number(0, 3)
    assert Number(0.) is Number(0.)
    assert Number(0.) is not Number(0, 3)
    assert Slice('x', 10) is Slice('x', 10)
    assert Slice('x', 10) is Slice('x', 0, 10)
    assert Slice('x', 10, 10) is not Slice('x', 0, 10)
    assert Slice('x', 2, 10, 1) is Slice('x', 2, 10)
Example #5
def test_cons_hash():
    assert Variable('x', Bint[3]) is Variable('x', Bint[3])
    assert Variable('x', Real) is Variable('x', Real)
    assert Variable('x', Real) is not Variable('x', Bint[3])
    assert Number(0, 3) is Number(0, 3)
    assert Number(0.) is Number(0.)
    assert Number(0.) is not Number(0, 3)
    assert Slice('x', 10) is Slice('x', 10)
    assert Slice('x', 10) is Slice('x', 0, 10)
    assert Slice('x', 10, 10) is not Slice('x', 0, 10)
    assert Slice('x', 2, 10, 1) is Slice('x', 2, 10)
Example #6
def adjoint_cat(adj_redop, adj_binop, out_adj, name, parts, part_name):
    """
    Adjoint (backward) rule for ``Cat``: when the concatenation variable is
    present in ``out_adj``, each part's adjoint is the corresponding slice of
    ``out_adj`` along ``name``; otherwise ``out_adj`` is combined with the
    part's product-inverse against itself so the result has the part's inputs.
    """
    in_adjs = {}
    start = 0
    size = sum(part.inputs[part_name].dtype for part in parts)
    for i, part in enumerate(parts):
        if part_name in out_adj.inputs:
            in_adjs[part] = out_adj(**{name: Slice(name, start, start + part.inputs[part_name].dtype, 1, size)})
            start += part.inputs[part_name].dtype
        else:
            # out_adj does not depend on the concatenation variable; combine it
            # with part (x) part^-1 under the product op to match part's inputs.
            in_adjs[part] = adj_binop(out_adj, Binary(ops.PRODUCT_INVERSES[adj_binop], part, part))
    return in_adjs
Example #7
def test_slice():
    t_slice = Slice("t", 10)

    s_slice = t_slice(t="s")
    assert isinstance(s_slice, Slice)
    assert s_slice.slice == t_slice.slice
    assert s_slice(s="t") is t_slice

    assert t_slice(t=0) is Number(0, 10)
    assert t_slice(t=1) is Number(1, 10)
    assert t_slice(t=2) is Number(2, 10)
    assert t_slice(t=t_slice) is t_slice
Example #8
def test_slice_simple():
    t = randn((3, 4, 5))
    f = Tensor(t)["i", "j"]
    assert_close(f, f(i=Slice("i", 3)))
    assert_close(f, f(j=Slice("j", 4)))
    assert_close(f, f(i=Slice("i", 3), j=Slice("j", 4)))
    assert_close(f, f(i=Slice("i", 3), j="j"))
    assert_close(f, f(i="i", j=Slice("j", 4)))
Example #9
def sarkka_bilmes_product(sum_op, prod_op, trans, time_var, global_vars=frozenset(), num_periods=1):
    """
    Like ``sequential_sum_product``, but for transition funsors whose
    longer-range dependencies are encoded by the ``"P"`` name-prefix
    convention (``"Px"`` is ``"x"`` one step earlier, ``"PPx"`` two steps
    earlier, and so on). Time is blocked into chunks of length
    ``lcm(lags)``, the shifted factors within each chunk are multiplied
    together, and the blocked chain is eliminated with
    ``mixed_sequential_sum_product``.
    """
    assert isinstance(global_vars, frozenset)

    time = time_var.name

    def get_shift(name):
        return len(re.search("^P*", name).group(0))

    def shift_name(name, t):
        return t * "P" + name

    def shift_funsor(f, t):
        if t == 0:
            return f
        return f(**{name: shift_name(name, t) for name in f.inputs
                    if name != time and name not in global_vars})

    lags = {get_shift(name) for name in trans.inputs if name != time}
    lags.discard(0)
    if not lags:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, {})

    period = int(np.lcm.reduce(list(lags)))
    original_names = frozenset(name for name in trans.inputs
                               if name != time and name not in global_vars
                               and not name.startswith("P"))
    renamed_factors = []
    duration = trans.inputs[time].size
    if duration % period:
        raise NotImplementedError("TODO handle partial windows")

    for t in range(period):
        slice_t = Slice(time, t, duration - period + t + 1, period, duration)
        factor = shift_funsor(trans, period - t - 1)
        factor = factor(**{time: slice_t})
        renamed_factors.append(factor)

    block_trans = reduce(prod_op, renamed_factors)
    block_step = {shift_name(name, period): name for name in block_trans.inputs
                  if name != time and name not in global_vars and get_shift(name) < period}
    block_time_var = Variable(time_var.name, bint(duration // period))
    final_chunk = mixed_sequential_sum_product(
        sum_op, prod_op, block_trans, block_time_var, block_step,
        num_segments=max(1, duration // (period * num_periods)))
    final_sum_vars = frozenset(
        shift_name(name, t) for name in original_names for t in range(1, period))
    result = final_chunk.reduce(sum_op, final_sum_vars)
    result = result(**{name: name.replace("P" * period, "P") for name in result.inputs})
    return result
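
A hedged usage sketch for the function above: a second-order chain in which "x" depends on its values one and two steps back, encoded with the "P"/"PP" prefix convention the code relies on. The duration (6) is chosen to be a multiple of the lag period (lcm(1, 2) = 2), since partial windows raise NotImplementedError; names and shapes are illustrative only:

# Illustrative only: "Px" and "PPx" denote "x" one and two steps earlier.
import torch

import funsor
import funsor.ops as ops
from funsor import Bint, Tensor, Variable
from funsor.sum_product import sarkka_bilmes_product

funsor.set_backend("torch")

trans = Tensor(torch.randn(6, 2, 2, 2))["time", "PPx", "Px", "x"]
time = Variable("time", Bint[6])
result = sarkka_bilmes_product(ops.logaddexp, ops.add, trans, time)
# The remaining free inputs of ``result`` are the boundary values of "x"
# (the lagged values before the first step and the value at the last step).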
Example #10
def test_quote(interp):
    with interpretation(interp):
        x = Variable('x', bint(8))
        check_quote(x)

        y = Variable('y', reals(8, 3, 3))
        check_quote(y)
        check_quote(y[x])

        z = Stack('i', (Number(0), Variable('z', reals())))
        check_quote(z)
        check_quote(z(i=0))
        check_quote(z(i=Slice('i', 0, 1, 1, 2)))
        check_quote(z.reduce(ops.add, 'i'))
        check_quote(Cat('i', (z, z, z)))
        check_quote(Lambda(Variable('i', bint(2)), z))
Example #11
def test_cat_slice_tensor(start, stop, step):

    terms = tuple(
        random_tensor(OrderedDict(t=bint(t), a=bint(2)))
        for t in [2, 1, 3, 4, 1, 3])
    dtype = sum(term.inputs['t'].dtype for term in terms)
    sub = Slice('t', start, stop, step, dtype)

    # eager
    expected = Cat('t', terms)(t=sub)

    # lazy - exercise Cat.eager_subs
    with interpretation(lazy):
        actual = Cat('t', terms)(t=sub)
    actual = reinterpret(actual)

    assert_close(actual, expected)
Example #12
def test_slice_2(start, stop, step):
    t = randn((10, 2))
    actual = Tensor(t)["i"](i=Slice("j", start, stop, step, dtype=10))
    expected = Tensor(t[start:stop:step])["j"]
    assert_close(actual, expected)
Example #13
def test_lazy_subs_type_clash():
    with interpretation(reflect):
        Slice('t', 3)(t=Slice('t', 2, dtype=3)).reduce(ops.add)
Example #14
def test_slice_1(stop):
    t = torch.randn(10, 2)
    actual = Tensor(t)["i"](i=Slice("j", stop, dtype=10))
    expected = Tensor(t[:stop])["j"]
    assert_close(actual, expected)