def f(i):
    """Run one round of ``n_parallel`` unrolled iterations and fold it into memory.

    ``i`` is the round index. ``n_parallel``, ``initializer``, ``reducer``,
    ``loop_body``, ``mem_state`` and ``write_state_to_memory`` all come from
    an enclosing scope that is not visible here.
    """
    acc = tuplify(initializer())
    for offset in range(n_parallel):
        # absolute iteration number handled by this unrolled copy
        index = i * n_parallel + offset
        acc = reducer(tuplify(loop_body(index)), acc)
    # merge this round's partial result with the value already stored
    combined = reducer(mem_state, acc)
    write_state_to_memory(combined)
def _(i):
    """Loop body for a ``while_do`` loop (the decorator is outside this view).

    ``i`` is the index of the first iteration handled in this pass. The Python
    ``while`` below unrolls calls to ``loop_body`` for as long as the current
    basic block stays under ``get_program().budget``, and always unrolls at
    least one. It folds the results with ``reducer``, stores them via
    ``write_state_to_memory``, records how many iterations were unrolled, and
    returns the next loop index ``i + k``.
    """
    state = tuplify(initializer())
    k = 0
    block = get_block()
    # Stop unrolling when any of these holds:
    # - n_loops iterations are done;
    # - the block has reached the budget (ignored for the first iteration);
    # - loop_body has switched to a different basic block.
    while k < n_loops and (len(get_block()) < get_program().budget \
                           or k == 0) \
        and block is get_block():
        j = i + k
        state = reducer(tuplify(loop_body(j)), state)
        k += 1
    r = reducer(mem_state, state)
    write_state_to_memory(r)
    # Publish the unroll factor twice:
    # - as a module global (read at compile time by the enclosing code,
    #   presumably — the consumer is not visible here);
    # - into n_parallel_reg, which the while_do condition likely reads.
    global n_opt_loops
    n_opt_loops = k
    n_parallel_reg.write(k)
    return i + k
def foreach_enumerate(a):
    """Feed ``a`` as public input and return a loop decorator over it.

    Each element of ``a`` is written as one line of space-separated values
    via ``get_program().public_input``. The returned decorator wraps
    ``loop_body`` in ``for_range(len(a))``. Each iteration calls
    ``loop_body(i, *values)``, where ``values`` are read back with
    ``public_input()``. The number of fields per element is taken from
    ``a[0]``; this presumably assumes every element has that same arity
    (not checked here).
    """
    for entry in a:
        line = ' '.join(map(str, tuplify(entry)))
        get_program().public_input(line)

    def decorator(loop_body):
        @for_range(len(a))
        def f(i):
            arity = len(tuplify(a[0]))
            loop_body(i, *[public_input() for _ in range(arity)])
        return f

    return decorator
def map_sum(n_threads, n_parallel, n_loops, n_items, value_types):
    """Build a ``map_reduce`` loop whose reducer is element-wise addition.

    ``value_types`` is either a single type, which is repeated ``n_items``
    times, or exactly ``n_items`` types. The accumulator starts as ``t(0)``
    for each type ``t``.

    Raises:
        CompilerError: if the number of types is neither 1 nor ``n_items``.
    """
    types = tuplify(value_types)
    if len(types) == 1:
        types *= n_items
    elif len(types) != n_items:
        raise CompilerError('Incorrect number of value_types.')

    def initializer():
        return [cls(0) for cls in types]

    def summer(x, y):
        return tuple(left + right for left, right in zip(x, y))

    return map_reduce(n_threads, n_parallel, n_loops, initializer, summer)
def decorator(loop_body):
    """Emit a reduction of ``loop_body(j)`` for ``j`` in ``range(n_loops)``.

    Iterations are grouped into ``loop_rounds`` rounds of ``n_parallel``
    unrolled calls, executed by ``for_range``. Any remaining iterations (from
    ``loop_rounds * n_parallel`` up to ``n_loops``) are then processed:

    - unrolled at compile time when ``n_loops`` is a Python int;
    - by a runtime ``for_range`` otherwise.

    The final state is written to ``mem_state``. ``n_loops``, ``n_parallel``,
    ``initializer``, ``reducer``, ``mem_state`` and ``use_array`` come from an
    enclosing scope that is not visible here (presumably ``map_reduce``'s
    arguments).

    Returns:
        A zero-argument function returning ``untuplify(tuple(state))``.
    """
    if isinstance(n_loops, int):
        # Floor division: this value feeds ``range`` below, which rejects the
        # float that ``/`` produces on Python 3 (identical result on Python 2).
        loop_rounds = n_loops // n_parallel \
            if n_parallel < n_loops else 0
    else:
        # runtime loop count: keep the original operator for register types
        loop_rounds = n_loops / n_parallel

    def write_state_to_memory(r):
        if use_array:
            mem_state.assign(r)
        else:
            # cannot do mem_state = [...] due to scope issue
            for j, x in enumerate(r):
                mem_state[j].write(x)

    # will be optimized out if n_loops <= n_parallel
    @for_range(loop_rounds)
    def f(i):
        state = tuplify(initializer())
        for k in range(n_parallel):
            j = i * n_parallel + k
            state = reducer(tuplify(loop_body(j)), state)
        r = reducer(mem_state, state)
        write_state_to_memory(r)

    # leftover iterations that did not fill a whole round
    if isinstance(n_loops, int):
        state = mem_state
        for j in range(loop_rounds * n_parallel, n_loops):
            state = reducer(tuplify(loop_body(j)), state)
    else:
        @for_range(loop_rounds * n_parallel, n_loops)
        def f(j):
            r = reducer(tuplify(loop_body(j)), mem_state)
            write_state_to_memory(r)
        state = mem_state
    for i, x in enumerate(state):
        if use_array:
            mem_state[i] = x
        else:
            mem_state[i].write(x)

    def returner():
        return untuplify(tuple(state))
    return returner
def decorator(loop_body):
    """Emit a reduction of ``loop_body(j)`` for ``j`` in ``range(n_loops)``.

    Iterations are grouped into ``loop_rounds`` rounds of ``n_parallel``
    unrolled calls, executed by ``for_range``. Any remaining iterations are
    then processed:

    - unrolled at compile time when ``n_loops`` is a Python int;
    - by a runtime ``for_range`` otherwise.

    The final state is written to ``mem_state``. ``n_loops``, ``n_parallel``,
    ``initializer``, ``reducer``, ``mem_state`` and ``use_array`` come from an
    enclosing scope that is not visible here (presumably ``map_reduce``'s
    arguments).

    Returns:
        A zero-argument function returning ``untuplify(tuple(state))``.
    """
    if isinstance(n_loops, int):
        # Floor division: this value feeds ``range`` below, which rejects the
        # float that ``/`` produces on Python 3 (identical result on Python 2).
        loop_rounds = n_loops // n_parallel \
            if n_parallel < n_loops else 0
    else:
        # runtime loop count: keep the original operator for register types
        loop_rounds = n_loops / n_parallel

    def write_state_to_memory(r):
        if use_array:
            mem_state.assign(r)
        else:
            # cannot do mem_state = [...] due to scope issue
            for j, x in enumerate(r):
                mem_state[j].write(x)

    # will be optimized out if n_loops <= n_parallel
    @for_range(loop_rounds)
    def f(i):
        state = tuplify(initializer())
        for k in range(n_parallel):
            j = i * n_parallel + k
            state = reducer(tuplify(loop_body(j)), state)
        r = reducer(mem_state, state)
        write_state_to_memory(r)

    # leftover iterations that did not fill a whole round
    if isinstance(n_loops, int):
        state = mem_state
        for j in range(loop_rounds * n_parallel, n_loops):
            state = reducer(tuplify(loop_body(j)), state)
    else:
        @for_range(loop_rounds * n_parallel, n_loops)
        def f(j):
            r = reducer(tuplify(loop_body(j)), mem_state)
            write_state_to_memory(r)
        state = mem_state
    for i, x in enumerate(state):
        if use_array:
            mem_state[i] = x
        else:
            mem_state[i].write(x)

    def returner():
        return untuplify(tuple(state))
    return returner
def set_slice(self, value):
    """Pack ``value`` and store it shifted into position; return ``self``.

    The value is packed with ``sbits.compose`` into a word of width
    ``sum(self.lengths)``. Then, for each bit ``i`` in ``self.start_bits``,
    the word is conditionally shifted left by ``2**i * width``, which selects
    its offset from the bits of ``start_bits``. ``self.anti_value`` is added
    to the result (presumably the content outside the slice — not visible
    here).
    """
    width = sum(self.lengths)
    packed = sbits.compose(util.tuplify(value), width)
    for bit_index, bit in enumerate(self.start_bits):
        shifted = packed << (2 ** bit_index * width)
        packed = bit.if_else(shifted, packed)
    self.value = packed + self.anti_value
    return self
def f(i):
    """Call ``loop_body(i, ...)`` with one ``public_input()`` per field of ``a[0]``.

    ``a``, ``loop_body``, ``tuplify`` and ``public_input`` come from an
    enclosing scope that is not visible here.
    """
    arity = len(tuplify(a[0]))
    values = [public_input() for _ in range(arity)]
    loop_body(i, *values)
def f(j):
    """Fold iteration ``j`` of ``loop_body`` straight into ``mem_state``.

    ``loop_body``, ``reducer``, ``mem_state`` and ``write_state_to_memory``
    come from an enclosing scope that is not visible here.
    """
    write_state_to_memory(reducer(tuplify(loop_body(j)), mem_state))
def decorator(loop_body):
    """Emit a reduction of ``loop_body(j)`` for ``j`` in ``range(n_loops)``.

    Two ways to group iterations, depending on ``n_parallel``:

    - Not ``None``: iterations run in ``loop_rounds`` rounds of
      ``n_parallel`` unrolled calls under ``for_range``.
    - ``None``: a ``while_do`` loop is used. Its body unrolls as many calls
      as fit in ``get_program().budget`` and publishes that count via the
      global ``n_opt_loops``. If this yields exactly one round, the trailing
      basic blocks are merged back into ``parent_block``. Otherwise a cost
      aggregator scaled by ``loop_rounds`` is attached to the last
      requirement node.

    Remaining iterations (from ``loop_rounds * my_n_parallel`` up to
    ``n_loops``) are then processed:

    - unrolled at compile time when ``n_loops`` is a Python int;
    - by a runtime ``for_range`` otherwise.

    The final state is written to ``mem_state``. ``n_loops``, ``n_parallel``,
    ``initializer``, ``reducer``, ``mem_state`` and ``use_array`` come from an
    enclosing scope that is not visible here (presumably ``map_reduce``'s
    arguments).

    Returns:
        A zero-argument function returning ``untuplify(tuple(state))``.
    """
    my_n_parallel = n_parallel
    if isinstance(n_parallel, int):
        if isinstance(n_loops, int):
            # Floor division: this value feeds ``range`` below, which
            # rejects the float ``/`` yields on Python 3 (same on Python 2).
            loop_rounds = n_loops // n_parallel \
                if n_parallel < n_loops else 0
        else:
            loop_rounds = n_loops / n_parallel
    # NOTE(review): if n_parallel is neither an int nor None, loop_rounds is
    # left unbound and the for_range branch below fails -- confirm callers.

    def write_state_to_memory(r):
        if use_array:
            mem_state.assign(r)
        else:
            # cannot do mem_state = [...] due to scope issue
            for j, x in enumerate(r):
                mem_state[j].write(x)

    if n_parallel is not None:
        # will be optimized out if n_loops <= n_parallel
        @for_range(loop_rounds)
        def f(i):
            state = tuplify(initializer())
            for k in range(n_parallel):
                j = i * n_parallel + k
                state = reducer(tuplify(loop_body(j)), state)
            r = reducer(mem_state, state)
            write_state_to_memory(r)
    else:
        n_parallel_reg = MemValue(regint(0))
        parent_block = get_block()

        @while_do(lambda x: x + n_parallel_reg <= n_loops, regint(0))
        def _(i):
            state = tuplify(initializer())
            k = 0
            block = get_block()
            # unroll until out of iterations, over budget (after the first),
            # or loop_body has opened a new basic block
            while k < n_loops and (len(get_block()) < get_program().budget \
                                   or k == 0) \
                and block is get_block():
                j = i + k
                state = reducer(tuplify(loop_body(j)), state)
                k += 1
            r = reducer(mem_state, state)
            write_state_to_memory(r)
            global n_opt_loops
            n_opt_loops = k
            n_parallel_reg.write(k)
            return i + k

        # unroll factor chosen while compiling the body above
        my_n_parallel = n_opt_loops
        if isinstance(n_loops, int):
            # integer rounds so that range() below and the == 1 check work
            # on Python 3 as they did with Python 2's integer '/'
            loop_rounds = n_loops // my_n_parallel
        else:
            loop_rounds = n_loops / my_n_parallel
        blocks = get_tape().basicblocks
        n_to_merge = 5
        if loop_rounds == 1 and parent_block is blocks[-n_to_merge]:
            # merge blocks started by if and do_while
            def exit_elimination(block):
                if block.exit_condition is not None:
                    for reg in block.exit_condition.get_used():
                        reg.can_eliminate = True
            exit_elimination(parent_block)
            merged = parent_block
            merged.exit_condition = blocks[-1].exit_condition
            merged.exit_block = blocks[-1].exit_block
            assert parent_block is blocks[-n_to_merge]
            assert blocks[-n_to_merge + 1] is \
                get_tape().req_node.children[-1].nodes[0].blocks[0]
            for block in blocks[-n_to_merge + 1:]:
                merged.instructions += block.instructions
                exit_elimination(block)
            del blocks[-n_to_merge + 1:]
            del get_tape().req_node.children[-1]
            merged.children = []
            get_tape().active_basicblock = merged
        else:
            req_node = get_tape().req_node.children[-1].nodes[0]
            req_node.children[0].aggregator = lambda x: loop_rounds * x[0]

    # leftover iterations that did not fill a whole round
    if isinstance(n_loops, int):
        state = mem_state
        for j in range(loop_rounds * my_n_parallel, n_loops):
            state = reducer(tuplify(loop_body(j)), state)
    else:
        @for_range(loop_rounds * my_n_parallel, n_loops)
        def f(j):
            r = reducer(tuplify(loop_body(j)), mem_state)
            write_state_to_memory(r)
        state = mem_state
    for i, x in enumerate(state):
        if use_array:
            mem_state[i] = x
        else:
            mem_state[i].write(x)

    def returner():
        return untuplify(tuple(state))
    return returner