def mixed_sequential_sum_product(sum_op, prod_op, trans, time, step, num_segments=None):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
           .reduce(sum_op, "drop")

    by mixing parallel and serial scan algorithms over ``num_segments`` segments.

    The time axis is cut into ``num_segments`` equal-length segments, stacked
    along an auxiliary ``<time>__SEGMENTED`` input. They are first eliminated
    with :func:`naive_sequential_sum_product` along the original time axis,
    and the per-segment results are then combined with
    :func:`sequential_sum_product` along the segment axis.

    When the duration is not divisible by ``num_segments``, the trailing
    ``duration % num_segments`` steps are split off. The divisible prefix is
    eliminated recursively, and its result is chained with the trailing steps
    by :func:`naive_sequential_sum_product`.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    :param int num_segments: number of segments for the first stage.
        Defaults to the full duration.
    """
    time_var = time
    name, duration = time_var.name, time_var.output.size
    if num_segments is None:
        num_segments = duration
    assert num_segments > 0 and duration > 0

    leftover = duration % num_segments
    divisible = duration - leftover

    # Uneven split: eliminate the evenly divisible prefix recursively, then
    # chain its (single-step) result with the leftover tail serially.
    if leftover and divisible > 0:
        tail = trans(**{name: Slice(name, divisible, duration, 1, duration)})
        head = trans(**{name: Slice(name, 0, divisible, 1, duration)})
        head_eliminated = mixed_sequential_sum_product(
            sum_op, prod_op, head, Variable(name, bint(divisible)), step,
            num_segments=num_segments)
        joined = Cat(name, (Stack(name, (head_eliminated,)), tail))
        return naive_sequential_sum_product(
            sum_op, prod_op, joined, Variable(name, bint(1 + leftover)), step)

    # Degenerate cases that collapse to a single stage.
    if num_segments == 1:
        return naive_sequential_sum_product(sum_op, prod_op, trans, time_var, step)
    if num_segments >= duration:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, step)

    # Two-stage scan over equal-length segments.
    seg_len = duration // num_segments
    seg_name = name + "__SEGMENTED"
    pieces = tuple(
        trans(**{name: Slice(name, k * seg_len, (k + 1) * seg_len, 1, duration)})
        for k in range(num_segments))
    per_segment = naive_sequential_sum_product(
        sum_op, prod_op, Stack(seg_name, pieces),
        Variable(name, bint(seg_len)), step)
    return sequential_sum_product(
        sum_op, prod_op, per_segment,
        Variable(seg_name, bint(num_segments)), step)
def sequential_sum_product(sum_op, prod_op, trans, time, step):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
           .reduce(sum_op, "drop")

    but does so efficiently in parallel in O(log(time)).

    Each round pairs the factors at even time positions with their right
    neighbours. It renames the paired ``curr``/``prev`` variables to shared
    temporary names and contracts over them, which halves the length of the
    time axis. An unpaired final factor is carried into the next round
    unchanged.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    """
    assert isinstance(sum_op, AssociativeOp)
    assert isinstance(prod_op, AssociativeOp)
    assert isinstance(trans, Funsor)
    assert isinstance(time, Variable)
    assert isinstance(step, dict)
    assert all(isinstance(key, str) for key in step.keys())
    assert all(isinstance(val, str) for val in step.values())
    if time.name in trans.inputs:
        assert time.output == trans.inputs[time.name]

    # Deterministic ordering so prev/curr pairs line up with the same temp name.
    ordered = OrderedDict(sorted(step.items()))
    tmp_names = tuple("_drop_{}".format(i) for i in range(len(ordered)))
    prev_renames = dict(zip(ordered.keys(), tmp_names))
    curr_renames = dict(zip(ordered.values(), tmp_names))
    reduced_vars = frozenset(tmp_names)

    name, length = time.name, time.output.size
    while length > 1:
        paired = length // 2 * 2
        left = trans(**{name: Slice(name, 0, paired, 2, length)}, **curr_renames)
        right = trans(**{name: Slice(name, 1, paired, 2, length)}, **prev_renames)
        merged = Contraction(sum_op, prod_op, reduced_vars, left, right)
        if length > paired:
            # Odd length: the last factor has no partner this round.
            last = trans(**{name: Slice(name, length - 1, length)})
            merged = Cat(name, (merged, last))
        trans = merged
        length = (length + 1) // 2
    return trans(**{name: 0})
def test_stack_slice(start, stop, step):
    """Substituting a Slice into a Stack gives the Stack of the sliced parts."""
    xs = tuple(map(Number, range(10)))
    actual = Stack('i', xs)(i=Slice('j', start, stop, step, dtype=10))
    expected = Stack('j', xs[start:stop:step])
    # Exact type identity: the result must itself be a Stack, not merely an
    # equivalent term of some other type.
    assert type(actual) is type(expected)
    assert actual.name == expected.name
    assert actual.parts == expected.parts
def test_cons_hash():
    """Structurally identical terms are interned to the very same object.

    NOTE(review): another ``test_cons_hash`` (written against ``Bint``/``Real``)
    is defined later in this file. It rebinds this module-level name, so pytest
    only collects that later definition and this one never runs. Rename or
    remove one of them.
    """
    identical = [
        (Variable('x', bint(3)), Variable('x', bint(3))),
        (Variable('x', reals()), Variable('x', reals())),
        (Number(0, 3), Number(0, 3)),
        (Number(0.), Number(0.)),
        (Slice('x', 10), Slice('x', 10)),
        # Slice arguments left at their defaults intern to the explicit form.
        (Slice('x', 10), Slice('x', 0, 10)),
        (Slice('x', 2, 10, 1), Slice('x', 2, 10)),
    ]
    distinct = [
        (Variable('x', reals()), Variable('x', bint(3))),
        (Number(0.), Number(0, 3)),
        (Slice('x', 10, 10), Slice('x', 0, 10)),
    ]
    for lhs, rhs in identical:
        assert lhs is rhs
    for lhs, rhs in distinct:
        assert lhs is not rhs
def test_cons_hash():
    """Structurally identical terms are interned to the very same object."""
    identical = [
        (Variable('x', Bint[3]), Variable('x', Bint[3])),
        (Variable('x', Real), Variable('x', Real)),
        (Number(0, 3), Number(0, 3)),
        (Number(0.), Number(0.)),
        (Slice('x', 10), Slice('x', 10)),
        # Slice arguments left at their defaults intern to the explicit form.
        (Slice('x', 10), Slice('x', 0, 10)),
        (Slice('x', 2, 10, 1), Slice('x', 2, 10)),
    ]
    distinct = [
        (Variable('x', Real), Variable('x', Bint[3])),
        (Number(0.), Number(0, 3)),
        (Slice('x', 10, 10), Slice('x', 0, 10)),
    ]
    for lhs, rhs in identical:
        assert lhs is rhs
    for lhs, rhs in distinct:
        assert lhs is not rhs
def adjoint_cat(adj_redop, adj_binop, out_adj, name, parts, part_name):
    """
    Split the adjoint of ``Cat(name, parts, part_name)`` into one adjoint per part.

    :param adj_redop: adjoint reduction op (not used by this rule).
    :param adj_binop: adjoint product op. Used only when ``out_adj`` does not
        depend on the concatenated input.
    :param out_adj: adjoint of the concatenated funsor, indexed by ``name``.
    :param str name: name of the concatenated input on the ``Cat`` output.
    :param tuple parts: the funsors that were concatenated.
    :param str part_name: name of the concatenated input within each part.
    :return: dict mapping each part to its adjoint.
    """
    in_adjs = {}
    size = sum(part.inputs[part_name].dtype for part in parts)
    # out_adj is indexed by the Cat output's input ``name`` (the parts carry
    # that axis as ``part_name``). Dependence on the axis must therefore be
    # checked on ``name``. Checking ``part_name`` is only correct when the two
    # names coincide.
    depends_on_axis = name in out_adj.inputs
    start = 0
    for part in parts:
        part_size = part.inputs[part_name].dtype
        if depends_on_axis:
            # Select this part's window of out_adj and re-expose it under
            # part_name, so the adjoint's inputs line up with the part's own.
            window = Slice(part_name, start, start + part_size, 1, size)
            in_adjs[part] = out_adj(**{name: window})
        else:
            in_adjs[part] = adj_binop(out_adj, Binary(ops.PRODUCT_INVERSES[adj_binop], part, part))
        start += part_size
    return in_adjs
def test_slice():
    """Renaming, integer substitution and self-substitution on a Slice."""
    t_slice = Slice("t", 10)

    # Renaming the input keeps the slice parameters and round-trips exactly.
    s_slice = t_slice(t="s")
    assert isinstance(s_slice, Slice)
    assert s_slice.slice == t_slice.slice
    assert s_slice(s="t") is t_slice

    # Substituting an integer index evaluates to the matching Number.
    for index in range(3):
        assert t_slice(t=index) is Number(index, 10)

    # Substituting the slice into itself leaves it unchanged.
    assert t_slice(t=t_slice) is t_slice
def test_slice_simple():
    """Substituting a full-range Slice for an input leaves the funsor unchanged."""
    data = randn((3, 4, 5))
    f = Tensor(data)["i", "j"]
    full_i = Slice("i", 3)
    full_j = Slice("j", 4)
    substitutions = [
        {"i": full_i},
        {"j": full_j},
        {"i": full_i, "j": full_j},
        {"i": full_i, "j": "j"},
        {"i": "i", "j": full_j},
    ]
    for subs in substitutions:
        assert_close(f, f(**subs))
def sarkka_bilmes_product(sum_op, prod_op, trans, time_var, global_vars=frozenset(), num_periods=1):
    """
    Eliminate the time input of a transition funsor that may carry several lags.

    A variable lagged by ``k`` steps is named with ``k`` leading ``"P"``
    characters (e.g. ``x``, ``Px``, ``PPx``). Time steps are grouped into
    blocks of ``period`` steps, where ``period`` is the least common multiple
    of the lags present. This makes the blocked funsor a single-step
    transition, which is then eliminated with
    :func:`mixed_sequential_sum_product`.

    :param sum_op: semiring sum operation.
    :param prod_op: semiring product operation.
    :param trans: transition funsor with an input named ``time_var.name``.
    :param Variable time_var: the time variable.
    :param frozenset global_vars: names of inputs shared by all time steps.
        These are never shifted, never count as lagged variables (even if their
        names start with ``"P"``), and are never renamed.
    :param int num_periods: roughly how many blocks each segment of the mixed
        scan should span.
    :raises NotImplementedError: if the duration is not a multiple of ``period``.
    :return: the eliminated funsor. Names carrying a ``period``-long ``P``
        prefix are renamed to a single ``P`` prefix.
    """
    assert isinstance(global_vars, frozenset)

    time = time_var.name

    def get_shift(name):
        return len(re.search("^P*", name).group(0))

    def shift_name(name, t):
        return t * "P" + name

    def shift_funsor(f, t):
        if t == 0:
            return f
        return f(**{name: shift_name(name, t) for name in f.inputs if name != time and name not in global_vars})

    # Only local variables carry lag information: globals are never shifted, so
    # a global whose name happens to start with "P" must not inflate the period.
    lags = {get_shift(name) for name in trans.inputs
            if name != time and name not in global_vars}
    lags.discard(0)
    if not lags:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, {})

    period = int(np.lcm.reduce(list(lags)))
    original_names = frozenset(name for name in trans.inputs
                               if name != time and name not in global_vars
                               and not name.startswith("P"))
    renamed_factors = []
    duration = trans.inputs[time].size
    if duration % period:
        raise NotImplementedError("TODO handle partial windows")

    # Factor t of each block covers time steps t, t + period, t + 2*period, ...
    # and its local variables are shifted so that the block reads as one step.
    for t in range(period):
        slice_t = Slice(time, t, duration - period + t + 1, period, duration)
        factor = shift_funsor(trans, period - t - 1)
        factor = factor(**{time: slice_t})
        renamed_factors.append(factor)

    block_trans = reduce(prod_op, renamed_factors)
    block_step = {shift_name(name, period): name for name in block_trans.inputs
                  if name != time and name not in global_vars and get_shift(name) < period}
    block_time_var = Variable(time_var.name, bint(duration // period))
    final_chunk = mixed_sequential_sum_product(
        sum_op, prod_op, block_trans, block_time_var, block_step,
        num_segments=max(1, duration // (period * num_periods)))
    final_sum_vars = frozenset(
        shift_name(name, t) for name in original_names for t in range(1, period))
    result = final_chunk.reduce(sum_op, final_sum_vars)
    # Collapse the period-long lag prefix back to a single "P". Global
    # variables were never shifted, so they keep their names.
    result = result(**{name: name.replace("P" * period, "P")
                       for name in result.inputs if name not in global_vars})
    return result
def test_quote(interp):
    """Every term built under ``interp`` passes ``check_quote``."""
    with interpretation(interp):
        x = Variable('x', bint(8))
        y = Variable('y', reals(8, 3, 3))
        z = Stack('i', (Number(0), Variable('z', reals())))
        terms = (
            x,
            y,
            y[x],
            z,
            z(i=0),
            z(i=Slice('i', 0, 1, 1, 2)),
            z.reduce(ops.add, 'i'),
            Cat('i', (z, z, z)),
            Lambda(Variable('i', bint(2)), z),
        )
        for term in terms:
            check_quote(term)
def test_cat_slice_tensor(start, stop, step):
    """Slicing a Cat of tensors: eager result matches lazy build + reinterpret."""
    part_sizes = [2, 1, 3, 4, 1, 3]
    terms = tuple(random_tensor(OrderedDict(t=bint(n), a=bint(2)))
                  for n in part_sizes)
    total = sum(term.inputs['t'].dtype for term in terms)
    sub = Slice('t', start, stop, step, total)

    # eager
    expected = Cat('t', terms)(t=sub)

    # lazy - exercise Cat.eager_subs
    with interpretation(lazy):
        lazy_result = Cat('t', terms)(t=sub)
    actual = reinterpret(lazy_result)

    assert_close(actual, expected)
def test_slice_2(start, stop, step):
    """A strided Slice substitution matches ``data[start:stop:step]``."""
    data = randn((10, 2))
    strided = Slice("j", start, stop, step, dtype=10)
    actual = Tensor(data)["i"](i=strided)
    expected = Tensor(data[start:stop:step])["j"]
    assert_close(actual, expected)
def test_lazy_subs_type_clash():
    """Substituting a Slice that reuses the input name ``t`` and then reducing
    must not raise under ``reflect``.

    The two Slices presumably give ``t`` different sizes, hence the "type clash".
    """
    with interpretation(reflect):
        outer = Slice('t', 3)
        inner = Slice('t', 2, dtype=3)
        outer(t=inner).reduce(ops.add)
def test_slice_1(stop):
    """A prefix Slice substitution matches ``data[:stop]``."""
    # Use the backend-agnostic ``randn`` helper, as the sibling slice tests do,
    # instead of hard-coding ``torch.randn``.
    t = randn((10, 2))
    actual = Tensor(t)["i"](i=Slice("j", stop, dtype=10))
    expected = Tensor(t[:stop])["j"]
    assert_close(actual, expected)