def instantiate_memory(top_or_left, idx, size):
    """
    Instantiates:
    - a top or left input memory
    - structure to move data from memory to read registers.

    Args:
        top_or_left: Either "top" or "left"; selects the memory name
            (`t{idx}` / `l{idx}`) and the register the data is moved into
            (`top_0_{idx}` / `left_{idx}_0`).
        idx: Index of the memory along its edge.
        size: Number of elements in the memory; also determines the width
            of the index via `bits_needed(size)`.

    Returns (cells, structure) tuple: the cells and structure produced by
    `instantiate_indexor`, with the memory cell and the memory-move group
    appended respectively.

    Raises:
        ValueError: if `top_or_left` is neither "top" nor "left".
    """
    if top_or_left == "top":
        name = f"t{idx}"
        target_reg_name = f"top_0_{idx}"
    elif top_or_left == "left":
        name = f"l{idx}"
        target_reg_name = f"left_{idx}_0"
    else:
        # A bare string cannot be raised in Python 3 (it produces an
        # unrelated TypeError), so wrap the message in a real exception.
        raise ValueError(f"Invalid top_or_left: {top_or_left}")

    var_name = py_ast.CompVar(name)
    idx_name = py_ast.CompVar(NAME_SCHEME["index name"].format(prefix=name))
    group_name = py_ast.CompVar(NAME_SCHEME["memory move"].format(prefix=name))
    target_reg = py_ast.CompVar(target_reg_name)

    # Group that reads the memory at the current index and latches the value
    # into the target register; the group is done when the register is done.
    structure = py_ast.Group(
        group_name,
        connections=[
            py_ast.Connect(
                py_ast.CompPort(idx_name, "out"),
                py_ast.CompPort(var_name, "addr0"),
            ),
            py_ast.Connect(
                py_ast.CompPort(var_name, "read_data"),
                py_ast.CompPort(target_reg, "in"),
            ),
            py_ast.Connect(
                py_ast.ConstantPort(1, 1),
                py_ast.CompPort(target_reg, "write_en"),
            ),
            py_ast.Connect(
                py_ast.CompPort(target_reg, "done"),
                py_ast.HolePort(group_name, "done"),
            ),
        ],
    )

    idx_width = bits_needed(size)

    # Instantiate the indexor
    (idx_cells, idx_structure) = instantiate_indexor(name, idx_width)
    idx_structure.append(structure)

    # Instantiate the memory
    idx_cells.append(
        py_ast.Cell(
            var_name,
            py_ast.Stdlib().mem_d1(BITWIDTH, size, idx_width),
            is_external=True,
        ))
    return (idx_cells, idx_structure)
def get_memory(name: str, type: tvm.ir.Type) -> Cell:
    """Build an external Calyx cell able to hold a value of the TVM `type`.

    The cell constructor is looked up in `NumDimsToCell` using the number of
    dimensions in `type.concrete_shape`. It is called with the bitwidth of
    the type, then every dimension size, then the index width
    (`bits_needed`) of every dimension, in that order.

    NOTE(review): which primitive each dimension count maps to (e.g. a
    register for zero dimensions) is decided by `NumDimsToCell`, which is
    defined elsewhere -- confirm there.

    An unsupported number of dimensions fails the assertion below.
    """
    dims = type.concrete_shape

    # Constructor arguments: bitwidth, sizes, then per-dimension index widths.
    bitwidth = get_bitwidth(type)
    sizes = list(dims)
    index_widths = [bits_needed(extent) for extent in dims]

    num_dims = len(dims)
    assert num_dims in NumDimsToCell, f"Memory of size {num_dims} not supported."

    make_primitive = NumDimsToCell[num_dims]
    return Cell(
        CompVar(name),
        make_primitive(bitwidth, *sizes, *index_widths),
        is_external=True,
    )
def create_systolic_array(top_length, top_depth, left_length, left_depth, gen_metadata=False):
    """
    Build the Calyx program for a systolic array.

    top_length: Number of PEs in each row.
    top_depth: Number of elements processed by each PE in a row.
    left_length: Number of PEs in each column.
    left_depth: Number of elements processed by each PE in a col.
    gen_metadata: Forwarded unchanged to `generate_control`.

    Returns a `(program, source_map)` tuple, where `program` contains a single
    `main` component and `source_map` is whatever `generate_control` returned
    alongside the control program.
    """
    assert top_depth == left_depth, (
        f"Cannot multiply matrices: "
        f"{top_length}x{top_depth} and {left_depth}x{left_length}")

    cells = []
    wires = []

    # Input memories: every "top" memory first, then every "left" memory.
    memory_specs = [("top", i, top_depth) for i in range(top_length)]
    memory_specs += [("left", i, left_depth) for i in range(left_length)]
    for side, index, depth in memory_specs:
        mem_cells, mem_wires = instantiate_memory(side, index, depth)
        cells.extend(mem_cells)
        wires.extend(mem_wires)

    # Single flat output memory with one slot per PE.
    total_size = left_length * top_length
    out_idx_size = bits_needed(total_size)
    out_mem = py_ast.Cell(
        OUT_MEM,
        py_ast.Stdlib().mem_d1(BITWIDTH, total_size, out_idx_size),
        is_external=True,
    )
    cells.append(out_mem)

    # One PE per (row, col), plus its data movers and its output mover.
    last_row = left_length - 1
    last_col = top_length - 1
    for row in range(left_length):
        on_last_row = row == last_row
        for col in range(top_length):
            on_last_col = col == last_col

            pe_cells = instantiate_pe(row, col, on_last_col, on_last_row)
            cells.extend(pe_cells)

            mover_groups = instantiate_data_move(row, col, on_last_col, on_last_row)
            wires.extend(mover_groups)

            output_group = instantiate_output_move(row, col, top_length, out_idx_size)
            wires.append(output_group)

    control, source_map = generate_control(
        top_length,
        top_depth,
        left_length,
        left_depth,
        gen_metadata=gen_metadata,
    )

    main = py_ast.Component(
        name="main",
        inputs=[],
        outputs=[],
        structs=wires + cells,
        controls=control,
    )

    program = py_ast.Program(
        imports=[
            py_ast.Import("primitives/core.futil"),
            py_ast.Import("primitives/binary_operators.futil"),
        ],
        components=[main],
    )
    return (program, source_map)
def generate_ntt_pipeline(input_bitwidth: int, n: int, q: int):
    """
    Builds and returns a Calyx `Program` implementing a pipeline for the
    cooley-tukey algorithm that uses phis in bit-reversed order.

    `n`: Length of the input array. Must be a power of 2.
    `input_bitwidth`: Bit width of the values in the input array.
    `q`: The modulus value.

    Before the program is built, `pp_table` is called with the per-stage
    operation and multiply tables.

    Reference:
    https://www.microsoft.com/en-us/research/wp-content/uploads/2016/05/RLWE-1.pdf
    """
    assert n > 0 and (n & (n - 1)) == 0, f"Input length: {n} must be a power of 2."

    bitwidth = bits_needed(n)
    num_stages = bitwidth - 1
    half = n // 2

    operations = get_pipeline_data(n, num_stages)
    multiplies = get_multiply_data(n, num_stages)

    # Rolling per-operator counters used to pick which `add{i}` / `sub{i}`
    # component a group uses. They wrap after N / 2 uses, so the same
    # N / 2 components are reused by every stage.
    component_counts = {"add": 0, "sub": 0}

    def fresh_comp_index(op):
        # Hand out the current index, then advance (or wrap) the counter.
        current = component_counts[op]
        component_counts[op] = 0 if current == half - 1 else current + 1
        return current

    # Memory component variables.
    input_mem = CompVar("a")
    phis_mem = CompVar("phis")

    def high():
        # A fresh 1-bit constant `1` for each use.
        return ConstantPort(1, 1)

    def mul_group(stage, mul_tuple):
        # phi * r[k] on the multiplier pipe selected by the tuple.
        mul_index, k, phi_index = mul_tuple
        group = CompVar(f"s{stage}_mul{mul_index}")
        pipe = CompVar(f"mult_pipe{mul_index}")
        phi_reg = CompVar(f"phi{phi_index}")
        src_reg = CompVar(f"r{k}")
        return Group(
            group,
            [
                Connect(CompPort(phi_reg, "out"), CompPort(pipe, "left")),
                Connect(CompPort(src_reg, "out"), CompPort(pipe, "right")),
                Connect(high(), CompPort(pipe, "go")),
                Connect(CompPort(pipe, "done"), HolePort(group, "done")),
            ],
        )

    def op_mod_group(stage, row, operations_tuple):
        # (r[lhs] +/- product) mod q, with the remainder written into A[row].
        lhs, symbol, mul_index = operations_tuple
        kind = "add" if symbol == "+" else "sub"
        kind_index = fresh_comp_index(kind)

        group = CompVar(f"s{stage}_r{row}_op_mod")
        arith = CompVar(f"{kind}{kind_index}")
        src_reg = CompVar(f"r{lhs}")
        product = CompVar(f"mult_pipe{mul_index}")
        mod_pipe = CompVar(f"mod_pipe{row}")
        dest = CompVar(f"A{row}")

        return Group(
            group,
            [
                Connect(CompPort(src_reg, "out"), CompPort(arith, "left")),
                Connect(CompPort(product, "out"), CompPort(arith, "right")),
                Connect(CompPort(arith, "out"), CompPort(mod_pipe, "left")),
                Connect(ConstantPort(input_bitwidth, q), CompPort(mod_pipe, "right")),
                # `go` is only driven while the pipe has not signalled done.
                Connect(
                    high(),
                    CompPort(mod_pipe, "go"),
                    Not(Atom(CompPort(mod_pipe, "done"))),
                ),
                Connect(CompPort(mod_pipe, "done"), CompPort(dest, "write_en")),
                Connect(CompPort(mod_pipe, "out_remainder"), CompPort(dest, "in")),
                Connect(CompPort(dest, "done"), HolePort(group, "done")),
            ],
        )

    def precursor_group(row):
        # Copy A[row] back into r[row] ahead of the next stage.
        group = CompVar(f"precursor_{row}")
        dest = CompVar(f"r{row}")
        src = CompVar(f"A{row}")
        return Group(
            group,
            [
                Connect(CompPort(src, "out"), CompPort(dest, "in")),
                Connect(high(), CompPort(dest, "write_en")),
                Connect(CompPort(dest, "done"), HolePort(group, "done")),
            ],
        )

    def preamble_group(row):
        # Load a[row] into r[row] and phis[row] into phi[row]; the group is
        # done once both registers report done.
        data_reg = CompVar(f"r{row}")
        phi_reg = CompVar(f"phi{row}")
        group = CompVar(f"preamble_{row}")
        return Group(
            group,
            [
                Connect(ConstantPort(bitwidth, row), CompPort(input_mem, "addr0")),
                Connect(ConstantPort(bitwidth, row), CompPort(phis_mem, "addr0")),
                Connect(high(), CompPort(data_reg, "write_en")),
                Connect(CompPort(input_mem, "read_data"), CompPort(data_reg, "in")),
                Connect(high(), CompPort(phi_reg, "write_en")),
                Connect(CompPort(phis_mem, "read_data"), CompPort(phi_reg, "in")),
                Connect(
                    high(),
                    HolePort(group, "done"),
                    And(
                        Atom(CompPort(data_reg, "done")),
                        Atom(CompPort(phi_reg, "done")),
                    ),
                ),
            ],
        )

    def epilogue_group(row):
        # Write A[row] back into a[row].
        group = CompVar(f"epilogue_{row}")
        src = CompVar(f"A{row}")
        return Group(
            group,
            [
                Connect(ConstantPort(bitwidth, row), CompPort(input_mem, "addr0")),
                Connect(high(), CompPort(input_mem, "write_en")),
                Connect(CompPort(src, "out"), CompPort(input_mem, "write_data")),
                Connect(CompPort(input_mem, "done"), HolePort(group, "done")),
            ],
        )

    def build_cells():
        stdlib = Stdlib()

        def register_bank(prefix, count):
            return [
                Cell(CompVar(f"{prefix}{i}"), stdlib.register(input_bitwidth))
                for i in range(count)
            ]

        def signed_op_bank(prefix, primitive, count):
            return [
                Cell(
                    CompVar(f"{prefix}{i}"),
                    stdlib.op(primitive, input_bitwidth, signed=True),
                )
                for i in range(count)
            ]

        components = [
            Cell(input_mem, stdlib.mem_d1(input_bitwidth, n, bitwidth), is_external=True),
            Cell(phis_mem, stdlib.mem_d1(input_bitwidth, n, bitwidth), is_external=True),
        ]
        components += register_bank("r", n)
        components += register_bank("A", n)
        # The `mul{i}` registers are declared here; none of the groups built
        # in this function reference them.
        components += register_bank("mul", half)
        components += register_bank("phi", n)
        components += signed_op_bank("mod_pipe", "div_pipe", n)
        components += signed_op_bank("mult_pipe", "mult_pipe", half)
        components += signed_op_bank("add", "add", half)
        components += signed_op_bank("sub", "sub", half)
        return components

    def build_wires():
        groups = [preamble_group(r) for r in range(n)]
        groups += [precursor_group(r) for r in range(n)]
        groups += [
            mul_group(s, multiplies[s][i])
            for s in range(num_stages)
            for i in range(half)
        ]
        # Stage-major, row-minor order matters: `fresh_comp_index` hands out
        # adder / subtractor indices in the order these groups are created.
        groups += [
            op_mod_group(s, r, operations[s][r])
            for s in range(num_stages)
            for r in range(n)
        ]
        groups += [epilogue_group(r) for r in range(n)]
        return groups

    def build_control():
        schedule = [SeqComp([Enable(f"preamble_{r}") for r in range(n)])]
        for s in range(num_stages):
            if s > 0:
                # Only append precursors if this is not the first stage.
                schedule.append(ParComp([Enable(f"precursor_{r}") for r in range(n)]))
            # Multiply
            schedule.append(ParComp([Enable(f"s{s}_mul{i}") for i in range(half)]))
            # Addition or subtraction mod `q`
            schedule.append(ParComp([Enable(f"s{s}_r{r}_op_mod") for r in range(n)]))
        schedule.append(SeqComp([Enable(f"epilogue_{r}") for r in range(n)]))
        return SeqComp(schedule)

    pp_table(operations, multiplies, n, num_stages)

    program_imports = [
        Import("primitives/core.futil"),
        Import("primitives/binary_operators.futil"),
    ]
    main = Component(
        "main",
        inputs=[],
        outputs=[],
        structs=build_cells() + build_wires(),
        controls=build_control(),
    )
    return Program(imports=program_imports, components=[main])