def control():
    """Build the top-level control schedule for the NTT pipeline.

    The schedule runs, in sequence:
      1. every `preamble_{r}` group (sequentially),
      2. for each stage `s`: the `precursor_{r}` groups in parallel (skipped
         for stage 0), then the `n // 2` `s{s}_mul{i}` groups in parallel,
         then the `n` `s{s}_r{r}_op_mod` groups in parallel,
      3. every `epilogue_{r}` group (sequentially).

    `n` and `num_stages` are free variables taken from the enclosing scope.
    """

    def enables(names):
        return [Enable(name) for name in names]

    schedule = [SeqComp(enables(f"preamble_{r}" for r in range(n)))]
    for stage in range(num_stages):
        # Only append precursors if this is not the first stage.
        if stage > 0:
            schedule.append(ParComp(enables(f"precursor_{r}" for r in range(n))))
        # Multiply.
        schedule.append(ParComp(enables(f"s{stage}_mul{i}" for i in range(n // 2))))
        # Addition or subtraction mod `q`.
        schedule.append(ParComp(enables(f"s{stage}_r{r}_op_mod" for r in range(n))))
    schedule.append(SeqComp(enables(f"epilogue_{r}" for r in range(n))))
    return SeqComp(schedule)
def gen_reduce_impl(stmt, arr_size, s_idx):
    """
    Returns a dictionary containing Calyx cells, wires and control needed to
    implement a `reduce` statement.

    Similar to gen_map_impl, but for a single, unbanked loop. There is one
    `le`/`idx`/`adder_idx`/`adder_op` cell set instead of one set per bank,
    and the body group is generated for bank 0.

    Args:
        stmt: the MrXL `reduce` statement being compiled.
        arr_size: number of elements in the array being reduced.
        s_idx: index of the statement, used to make cell/group names unique.

    Returns:
        A dict with keys "cells", "wires" and "control".
    """
    stdlib = Stdlib()
    # A `mul` body needs a multiplier; every other body op uses an adder.
    op_name = "mult" if stmt.op.body.op == "mul" else "add"
    cells = [
        Cell(CompVar(f"le{s_idx}"), stdlib.op("lt", 32, signed=False)),
        Cell(CompVar(f"idx{s_idx}"), stdlib.register(32)),
        Cell(CompVar(f"adder_idx{s_idx}"), stdlib.op("add", 32, signed=False)),
        Cell(CompVar(f"adder_op{s_idx}"), stdlib.op(op_name, 32, signed=False)),
    ]
    wires = [
        emit_cond_group(s_idx, arr_size),
        emit_idx_group(s_idx),
        emit_eval_body_group(s_idx, stmt, 0),
    ]
    # while le.out with cond { seq { eval_body; incr_idx } }
    control = While(
        port=CompPort(CompVar(f"le{s_idx}"), "out"),
        cond=CompVar(f"cond{s_idx}"),
        body=SeqComp([Enable(f"eval_body{s_idx}"), Enable(f"incr_idx{s_idx}")]),
    )
    return {"cells": cells, "wires": wires, "control": control}
def generate_control(degree: int, is_signed: bool) -> Control:
    """Build the sequential control schedule for the generated component.

    The schedule is, in order:
      - `init`, then (signed only) `negate` guarded by `is_negative`/`lt`,
      - `split_bits`,
      - one `par` invoking `pow1` on (`e`, `int_x`) and every `pow{i}`
        (i in 2..degree) on (`frac_x`, `c{i}`),
      - `par` of `consume_pow{i}`, then `par` of
        `mult_by_reciprocal_factorial{i}` for i in 2..degree,
      - `int(log2(degree))` rounds of `sum_round{r}_{i}`, where round r
        enables `degree >> r` groups,
      - `add_degree_zero`, `final_multiply`, and (signed only)
        `reciprocal` guarded by `is_negative`/`lt`.
    """
    powers = range(2, degree + 1)

    def negative_guard(group):
        # Run `group` only when `lt.out` is high, using `is_negative` as the condition group.
        return If(CompPort(CompVar("lt"), "out"), CompVar("is_negative"), Enable(group))

    first_pow = Invoke(
        CompVar("pow1"),
        [
            ("base", CompPort(CompVar("e"), "out")),
            ("integer_exp", CompPort(CompVar("int_x"), "out")),
        ],
        [],
    )
    other_pows = [
        Invoke(
            CompVar(f"pow{i}"),
            [
                ("base", CompPort(CompVar("frac_x"), "out")),
                ("integer_exp", CompPort(CompVar(f"c{i}"), "out")),
            ],
            [],
        )
        for i in powers
    ]

    schedule = [Enable("init")]
    if is_signed:
        schedule.append(negative_guard("negate"))
    schedule.append(Enable("split_bits"))
    schedule.append(ParComp([first_pow] + other_pows))
    schedule.append(ParComp([Enable(f"consume_pow{i}") for i in powers]))
    schedule.append(
        ParComp([Enable(f"mult_by_reciprocal_factorial{i}") for i in powers])
    )

    # Each summation round enables half as many groups as the previous one.
    enable_count = degree >> 1
    for r in range(1, int(log2(degree) + 1)):
        schedule.append(
            ParComp([Enable(f"sum_round{r}_{i}") for i in range(1, enable_count + 1)])
        )
        enable_count >>= 1

    schedule.extend([Enable("add_degree_zero"), Enable("final_multiply")])
    if is_signed:
        schedule.append(negative_guard("reciprocal"))
    return SeqComp(schedule)
def gen_map_impl(stmt, arr_size, bank_factor, s_idx):
    """
    Returns a dictionary containing Calyx cells, wires and control needed to
    implement a map statement. (See gen_stmt_impl for format of the dictionary.)

    One independent loop is produced per memory bank. For every bank `b` this
    generates:
      - a group that implements the body of the map statement,
      - a group that increments an index to access the map input array,
      - a group that implements the loop condition, checking if the index
        has reached the end of this bank's slice (arr_size // bank_factor).
    All per-bank loops run in parallel.
    """
    stdlib = Stdlib()
    banks = range(bank_factor)

    # Loop bookkeeping cells for every bank come first...
    cells = [
        cell
        for b in banks
        for cell in (
            Cell(CompVar(f"le_b{b}_{s_idx}"), stdlib.op("lt", 32, signed=False)),
            Cell(CompVar(f"idx_b{b}_{s_idx}"), stdlib.register(32)),
            Cell(
                CompVar(f"adder_idx_b{b}_{s_idx}"),
                stdlib.op("add", 32, signed=False),
            ),
        )
    ]
    # ...followed by one body operator per bank (multiplier for `mul`, else adder).
    op_name = "mult" if stmt.op.body.op == "mul" else "add"
    cells += [
        Cell(CompVar(f"adder_op_b{b}_{s_idx}"), stdlib.op(op_name, 32, signed=False))
        for b in banks
    ]

    wires = [
        group
        for b in banks
        for group in (
            emit_cond_group(s_idx, arr_size // bank_factor, b),
            emit_idx_group(s_idx, b),
            emit_eval_body_group(s_idx, stmt, b),
        )
    ]

    def bank_loop(b):
        suffix = f"_b{b}_"
        return While(
            CompPort(CompVar(f"le{suffix}{s_idx}"), "out"),
            CompVar(f"cond{suffix}{s_idx}"),
            SeqComp([
                Enable(f"eval_body{suffix}{s_idx}"),
                Enable(f"incr_idx{suffix}{s_idx}"),
            ]),
        )

    return {
        "cells": cells,
        "wires": wires,
        "control": ParComp([bank_loop(b) for b in banks]),
    }
def reduce_parallel_control_pass(component: Component, N: int, input_size: int):
    """Reduces the amount of fan-out by reducing parallelization in the
    execution flow by a factor of `N`.

    For example, given an input size 4 and reduction factor N = 2:
    Before:
      par { s0_mul0; s0_mul1; }
      par { s0_r0_op_mod; s0_r1_op_mod; s0_r2_op_mod; s0_r3_op_mod; }
      ...
    After:
      par { s0_mul0; s0_mul1; }
      par { s0_r0_op_mod; s0_r1_op_mod; }
      par { s0_r2_op_mod; s0_r3_op_mod; }
      ...

    Top-level statements that are not `par` blocks, and empty `par` blocks,
    are kept unchanged. `component.controls` is replaced in place by a new
    `seq` of the rewritten statements.

    Raises:
        AssertionError: if `N` is not a power of two in (0, input_size).
        ValueError: if a `par` block cannot be split into equally sized groups.
    """
    assert (
        N is not None and 0 < N < input_size and (not (N & (N - 1)))
    ), f"""N: {N} should be a power of two within bounds (0, {input_size})."""

    reduced_controls = []
    for control in component.controls.stmts:
        if not isinstance(control, ParComp):
            reduced_controls.append(control)
            continue
        stmts = list(control.stmts)
        if not stmts:
            # Nothing to split (previously crashed on `None.stmt`).
            reduced_controls.append(control)
            continue

        group_name = stmts[0].stmt
        # Parallelized multiplies are already a factor of 1/2 less. Clamp to 1
        # so N == 1 (no reduction) does not request zero groups.
        factor = max(N // 2, 1) if "mul" in group_name else N
        chunk, remainder = divmod(len(stmts), factor)
        if remainder:
            raise ValueError(
                f"cannot split {len(stmts)} parallel statements "
                f"into {factor} equal groups"
            )
        reduced_controls.extend(
            ParComp(stmts[start:start + chunk])
            for start in range(0, len(stmts), chunk)
        )

    component.controls = SeqComp(reduced_controls)
def emit(prog):
    """
    Compiles `prog`, a MrXL program, to Calyx and emits it via
    `Program.emit()`. This function itself returns None.

    Memories are banked by the `par` factor of the `map` statements that read
    or write them. Memories that no `map` statement touches (e.g. the input
    of a `reduce`) default to a bank factor of 1 instead of raising KeyError.

    Raises:
        NotImplementedError: if a non-`map` statement assigns to a name that
            was not declared (implicit register declarations are unsupported).
    """
    cells, wires, control = [], [], []

    # All arrays must be the same size. The first array we see determines the
    # size that we'll assume for the rest of the program's arrays.
    arr_size = None

    # Collect banking factors.
    name2par = {}
    for stmt in prog.stmts:
        if isinstance(stmt.op, ast.Map):
            name2par[stmt.dest] = stmt.op.par
            for b in stmt.op.bind:
                name2par[b.src] = stmt.op.par

    # Collect memory and register declarations.
    used_names = set()
    stdlib = Stdlib()
    for decl in prog.decls:
        used_names.add(decl.name)
        if decl.type.size:  # A memory
            if arr_size is None:
                arr_size = decl.type.size
            cells.extend(
                emit_mem_decl(decl.name, decl.type.size, name2par.get(decl.name, 1))
            )
        else:  # A register
            cells.append(Cell(CompVar(decl.name), stdlib.register(32)))

    # Collect implicit memory and register declarations.
    for stmt in prog.stmts:
        if stmt.dest in used_names:
            continue
        if not isinstance(stmt.op, ast.Map):
            # TODO: emit an implicit 32-bit register for non-map results.
            raise NotImplementedError("Generating register declarations")
        cells.extend(emit_mem_decl(stmt.dest, arr_size, name2par[stmt.dest]))
        used_names.add(stmt.dest)

    # Generate Calyx.
    for i, stmt in enumerate(prog.stmts):
        stmt_impl = gen_stmt_impl(stmt, arr_size, name2par, i)
        cells.extend(stmt_impl["cells"])
        wires.extend(stmt_impl["wires"])
        control.append(stmt_impl["control"])

    program = Program(
        imports=[
            Import("primitives/core.futil"),
            Import("primitives/binary_operators.futil"),
        ],
        components=[
            Component(
                name="main",
                inputs=[],
                outputs=[],
                structs=cells + wires,
                controls=SeqComp(control),
            )
        ],
    )
    program.emit()
def generate_fp_pow_component(width: int, int_width: int, is_signed: bool) -> Component:
    """Generates a fixed point `pow` component, which computes the value
    x**y, where y must be an integer.

    Structure of the generated `fp_pow` component:
      - `init` sets the `pow` register to fixed-point 1.0 and `count` to 0.
      - While `count < integer_exp` (comb group `cond`), it runs
        `execute_mul` (pow := base * pow) in parallel with `incr_count`
        (count := count + 1).
      - The `pow` register continuously drives the `out` port.

    Args:
        width: bit width of the ports, registers and operators.
        int_width: number of integer bits of the fixed-point format.
        is_signed: whether the arithmetic cells are signed.
    """
    stdlib = Stdlib()
    frac_width = width - int_width

    # Cell handles.
    acc = CompVar("pow")
    counter = CompVar("count")
    multiplier = CompVar("mul")
    less_than = CompVar("lt")
    incrementer = CompVar("incr")

    # Group handles.
    init_id = CompVar("init")
    mul_id = CompVar("execute_mul")
    incr_id = CompVar("incr_count")
    cond_id = CompVar("cond")

    cells = [
        Cell(acc, stdlib.register(width)),
        Cell(counter, stdlib.register(width)),
        Cell(
            multiplier,
            stdlib.fixed_point_op(
                "mult_pipe", width, int_width, frac_width, signed=is_signed
            ),
        ),
        Cell(less_than, stdlib.op("lt", width, signed=is_signed)),
        Cell(incrementer, stdlib.op("add", width, signed=is_signed)),
    ]

    fixed_one = numeric_types.FixedPoint(
        "1.0", width, int_width, is_signed=is_signed
    ).unsigned_integer()

    # pow := 1.0, count := 0; done once both registers report done.
    init_group = Group(
        id=init_id,
        connections=[
            Connect(ConstantPort(width, fixed_one), CompPort(acc, "in")),
            Connect(ConstantPort(1, 1), CompPort(acc, "write_en")),
            Connect(ConstantPort(width, 0), CompPort(counter, "in")),
            Connect(ConstantPort(1, 1), CompPort(counter, "write_en")),
            Connect(
                ConstantPort(1, 1),
                HolePort(init_id, "done"),
                And(Atom(CompPort(acc, "done")), Atom(CompPort(counter, "done"))),
            ),
        ],
    )

    # pow := base * pow; the multiplier is kept running until it signals done.
    mul_group = Group(
        id=mul_id,
        connections=[
            Connect(ThisPort(CompVar("base")), CompPort(multiplier, "left")),
            Connect(CompPort(acc, "out"), CompPort(multiplier, "right")),
            Connect(
                ConstantPort(1, 1),
                CompPort(multiplier, "go"),
                Not(Atom(CompPort(multiplier, "done"))),
            ),
            Connect(CompPort(multiplier, "done"), CompPort(acc, "write_en")),
            Connect(CompPort(multiplier, "out"), CompPort(acc, "in")),
            Connect(CompPort(acc, "done"), HolePort(mul_id, "done")),
        ],
    )

    # count := count + 1.
    incr_group = Group(
        id=incr_id,
        connections=[
            Connect(ConstantPort(width, 1), CompPort(incrementer, "left")),
            Connect(CompPort(counter, "out"), CompPort(incrementer, "right")),
            Connect(CompPort(incrementer, "out"), CompPort(counter, "in")),
            Connect(ConstantPort(1, 1), CompPort(counter, "write_en")),
            Connect(CompPort(counter, "done"), HolePort(incr_id, "done")),
        ],
    )

    # Loop condition: lt compares count against integer_exp.
    cond_group = CombGroup(
        id=cond_id,
        connections=[
            Connect(CompPort(counter, "out"), CompPort(less_than, "left")),
            Connect(ThisPort(CompVar("integer_exp")), CompPort(less_than, "right")),
        ],
    )

    output = Connect(CompPort(CompVar("pow"), "out"), ThisPort(CompVar("out")))
    wires = [init_group, mul_group, incr_group, cond_group, output]

    loop = While(
        CompPort(less_than, "out"),
        cond_id,
        ParComp([Enable("execute_mul"), Enable("incr_count")]),
    )

    return Component(
        "fp_pow",
        inputs=[
            PortDef(CompVar("base"), width),
            PortDef(CompVar("integer_exp"), width),
        ],
        outputs=[PortDef(CompVar("out"), width)],
        structs=cells + wires,
        controls=SeqComp([Enable("init"), loop]),
    )
Connect( ConstantPort(1, 0), CompPort(CompVar("ret"), "addr0"), ), Connect( ConstantPort(1, 1), CompPort(CompVar("ret"), "write_en"), ), Connect( CompPort(CompVar("e"), "out"), CompPort(CompVar("ret"), "write_data"), ), Connect( CompPort(CompVar("ret"), "done"), HolePort(CompVar("write_to_memory"), "done"), ), ], ), ], controls=SeqComp([ Enable("init"), Invoke( id=CompVar("e"), in_connects=[("x", CompPort(CompVar("t"), "out"))], out_connects=[], ), Enable("write_to_memory"), ]), )) program.emit()