def get_memory(name: str, type: tvm.ir.Type) -> Cell:
    """Return a Calyx memory cell named `name` sized for the TVM type `type`.

    The primitive is selected from `NumDimsToCell` by the number of
    dimensions in `type.concrete_shape`. Non-Tensor types (an empty shape)
    are expected to map to a register; Tensor types map to the memory of the
    matching dimensionality, if Calyx provides one.

    Raises:
        AssertionError: If `NumDimsToCell` has no entry for the rank.
    """
    shape = type.concrete_shape
    # Primitive arguments, in order: element bitwidth, every dimension size,
    # then the index width needed to address each dimension.
    primitive_args = [get_bitwidth(type), *shape]
    primitive_args.extend(bits_needed(extent) for extent in shape)
    rank = len(shape)
    assert rank in NumDimsToCell, f'Memory of size {rank} not supported.'
    return Cell(CompVar(name), NumDimsToCell[rank](*primitive_args))
def instantiate_memory(top_or_left, idx, size):
    """Instantiate one external input memory and the group that reads it.

    Builds, for the `top` or `left` edge of the systolic array:
      - an index for the memory, via `instantiate_indexor`;
      - an external 1-D memory named `t{idx}` / `l{idx}` holding `size`
        elements of `BITWIDTH` bits;
      - a group that reads the memory at the index register's current value
        and writes it into the edge register `top_0_{idx}` / `left_{idx}_0`.

    Args:
        top_or_left: Either "top" or "left"; selects which edge is fed.
        idx: Position of the memory along that edge.
        size: Number of elements in the memory.

    Returns:
        A `(cells, structure)` tuple: the lists returned by
        `instantiate_indexor`, extended with the memory cell and move group.

    Raises:
        ValueError: If `top_or_left` is neither "top" nor "left".
    """
    if top_or_left == "top":
        name = f"t{idx}"
        target_reg = f"top_0_{idx}"
    elif top_or_left == "left":
        name = f"l{idx}"
        target_reg = f"left_{idx}_0"
    else:
        # Must raise an exception instance: raising a bare str is itself a
        # TypeError in Python 3 and would hide this message.
        raise ValueError(f"Invalid top_or_left: {top_or_left}")

    var_name = ast.CompVar(name)
    idx_name = ast.CompVar(NAME_SCHEME["index name"].format(prefix=name))
    group_name = ast.CompVar(NAME_SCHEME["memory move"].format(prefix=name))
    target_reg = ast.CompVar(target_reg)

    # Read memory[idx] and latch it into the edge register; the group is done
    # once the register has been written.
    structure = ast.Group(
        group_name,
        connections=[
            ast.Connect(ast.CompPort(idx_name, "out"), ast.CompPort(var_name, "addr0")),
            ast.Connect(
                ast.CompPort(var_name, "read_data"), ast.CompPort(target_reg, "in")
            ),
            ast.Connect(ast.ConstantPort(1, 1), ast.CompPort(target_reg, "write_en")),
            ast.Connect(
                ast.CompPort(target_reg, "done"), ast.HolePort(group_name, "done")
            ),
        ],
    )

    idx_width = bits_needed(size)
    # Instantiate the indexor
    (idx_cells, idx_structure) = instantiate_indexor(name, idx_width)
    idx_structure.append(structure)
    # Instantiate the memory
    idx_cells.append(
        ast.Cell(
            var_name, ast.Stdlib().mem_d1(BITWIDTH, size, idx_width), is_external=True
        )
    )
    return (idx_cells, idx_structure)
def create_systolic_array(top_length, top_depth, left_length, left_depth):
    """Build the Calyx program for a systolic-array matrix multiply.

    Args:
        top_length: Number of PEs in each row.
        top_depth: Number of elements processed by each PE in a row.
        left_length: Number of PEs in each column.
        left_depth: Number of elements processed by each PE in a column.

    Returns:
        An `ast.Program` importing "primitives/std.lib" whose single `main`
        component holds the input/output memories, the PEs, their data
        movement groups, and the control from `generate_control`.

    Raises:
        AssertionError: If `top_depth != left_depth` (incompatible shapes).
    """
    assert top_depth == left_depth, (
        f"Cannot multiply matrices: "
        f"{top_length}x{top_depth} and {left_depth}x{left_length}"
    )

    cells, wires = [], []

    # Input memories: one per top-edge PE first, then one per left-edge PE.
    memory_specs = [("top", i, top_depth) for i in range(top_length)]
    memory_specs += [("left", i, left_depth) for i in range(left_length)]
    for side, position, depth in memory_specs:
        mem_cells, mem_wires = instantiate_memory(side, position, depth)
        cells.extend(mem_cells)
        wires.extend(mem_wires)

    # External 2-D output memory: left_length rows by top_length columns.
    out_ridx_size = bits_needed(left_length)
    out_cidx_size = bits_needed(top_length)
    out_memory = ast.Stdlib().mem_d2(
        BITWIDTH, left_length, top_length, out_ridx_size, out_cidx_size
    )
    cells.append(ast.Cell(OUT_MEM, out_memory, is_external=True))

    # Each PE, the fabric moving data between PEs, and its output move group.
    for row in range(left_length):
        is_last_row = row == left_length - 1
        for col in range(top_length):
            is_last_col = col == top_length - 1
            cells.extend(instantiate_pe(row, col, is_last_col, is_last_row))
            wires.extend(instantiate_data_move(row, col, is_last_col, is_last_row))
            wires.append(
                instantiate_output_move(row, col, out_ridx_size, out_cidx_size)
            )

    main = ast.Component(
        name="main",
        inputs=[],
        outputs=[],
        structs=wires + cells,
        controls=generate_control(top_length, top_depth, left_length, left_depth),
    )
    return ast.Program(imports=[ast.Import("primitives/std.lib")], components=[main])
def generate_ntt_pipeline(input_bitwidth, n, q):
    """Emit a FuTIL (Calyx) pipeline for a Cooley-Tukey NTT.

    The twiddle factors (phis) are used in bit-reversed order. The generated
    `main` component loads each element of memory `a` into a register
    `r{i}` and each element of memory `phis` into `phi{i}`. It then runs
    `bits_needed(n) - 1` stages of multiply followed by add/sub mod `q`,
    and finally writes the `A{i}` results back into `a`.

    Before emitting the program, `pp_table` is called with the per-stage
    operation and multiply data (presumably to print a summary table).

    Args:
        input_bitwidth: Bit width of the values in the input array.
        n: Length of the input array; must be a positive power of 2.
        q: The modulus value.

    Raises:
        AssertionError: If `n` is not a positive power of 2.

    Reference:
        https://www.microsoft.com/en-us/research/wp-content/uploads/2016/05/RLWE-1.pdf
    """
    assert n > 0 and (n & (n - 1)) == 0, f'Input length: {n} must be a power of 2.'

    index_width = bits_needed(n)
    num_stages = index_width - 1
    operations = get_pipeline_data(n, num_stages)
    multiplies = get_multiply_data(n, num_stages)
    half = n // 2

    # Round-robin counters selecting which `add{i}` / `sub{i}` cell the next
    # op-mod group uses. Each counter wraps after `n // 2` uses, so every
    # stage shares the same `n // 2` adders and subtractors.
    next_index = {'add': 0, 'sub': 0}

    def fresh_comp_index(kind):
        current = next_index[kind]
        next_index[kind] = 0 if current == half - 1 else current + 1
        return current

    # Memory component variables.
    input_mem = CompVar('a')
    phi_mem = CompVar('phis')

    def pipe_handshake(pipe, dest, group_name):
        # Hold `go` high until the pipe finishes, latch its output into
        # `dest`, and finish the group once `dest` has been written.
        return [
            Connect(ConstantPort(1, 1), CompPort(pipe, 'go'),
                    Not(Atom(CompPort(pipe, 'done')))),
            Connect(CompPort(pipe, 'done'), CompPort(dest, 'write_en')),
            Connect(CompPort(pipe, 'out'), CompPort(dest, 'in')),
            Connect(CompPort(dest, 'done'), HolePort(group_name, 'done')),
        ]

    def mul_group(stage, mul_tuple):
        # mul{i} <- phi{phi_index} * r{k}
        mul_index, k, phi_index = mul_tuple
        group_name = CompVar(f's{stage}_mul{mul_index}')
        pipe = CompVar(f'mult_pipe{mul_index}')
        product = CompVar(f'mul{mul_index}')
        twiddle = CompVar(f'phi{phi_index}')
        operand = CompVar(f'r{k}')
        return Group(group_name, [
            Connect(CompPort(twiddle, 'out'), CompPort(pipe, 'left')),
            Connect(CompPort(operand, 'out'), CompPort(pipe, 'right')),
            *pipe_handshake(pipe, product, group_name),
        ])

    def op_mod_group(stage, row, operations_tuple):
        # A{row} <- (r{lhs} +/- mul{mul_index}) mod q
        lhs, symbol, mul_index = operations_tuple
        kind = 'add' if symbol == '+' else 'sub'
        comp_index = fresh_comp_index(kind)
        group_name = CompVar(f's{stage}_r{row}_op_mod')
        alu = CompVar(f'{kind}{comp_index}')
        operand = CompVar(f'r{lhs}')
        product = CompVar(f'mul{mul_index}')
        mod_pipe = CompVar(f'mod_pipe{row}')
        result = CompVar(f'A{row}')
        return Group(group_name, [
            Connect(CompPort(operand, 'out'), CompPort(alu, 'left')),
            Connect(CompPort(product, 'out'), CompPort(alu, 'right')),
            Connect(CompPort(alu, 'out'), CompPort(mod_pipe, 'left')),
            Connect(ConstantPort(input_bitwidth, q), CompPort(mod_pipe, 'right')),
            *pipe_handshake(mod_pipe, result, group_name),
        ])

    def precursor_group(row):
        # r{row} <- A{row}: feed the previous stage's result into the next.
        group_name = CompVar(f'precursor_{row}')
        dest = CompVar(f'r{row}')
        src = CompVar(f'A{row}')
        return Group(group_name, [
            Connect(CompPort(src, 'out'), CompPort(dest, 'in')),
            Connect(ConstantPort(1, 1), CompPort(dest, 'write_en')),
            Connect(CompPort(dest, 'done'), HolePort(group_name, 'done')),
        ])

    def preamble_group(row):
        # r{row} <- a[row] and phi{row} <- phis[row], done when both written.
        reg = CompVar(f'r{row}')
        phi = CompVar(f'phi{row}')
        group_name = CompVar(f'preamble_{row}')
        return Group(group_name, [
            Connect(ConstantPort(index_width, row), CompPort(input_mem, 'addr0')),
            Connect(ConstantPort(index_width, row), CompPort(phi_mem, 'addr0')),
            Connect(ConstantPort(1, 1), CompPort(reg, 'write_en')),
            Connect(CompPort(input_mem, 'read_data'), CompPort(reg, 'in')),
            Connect(ConstantPort(1, 1), CompPort(phi, 'write_en')),
            Connect(CompPort(phi_mem, 'read_data'), CompPort(phi, 'in')),
            Connect(
                ConstantPort(1, 1), HolePort(group_name, 'done'),
                And(Atom(CompPort(reg, 'done')), Atom(CompPort(phi, 'done')))),
        ])

    def epilogue_group(row):
        # a[row] <- A{row}
        group_name = CompVar(f'epilogue_{row}')
        result = CompVar(f'A{row}')
        return Group(group_name, [
            Connect(ConstantPort(index_width, row), CompPort(input_mem, 'addr0')),
            Connect(ConstantPort(1, 1), CompPort(input_mem, 'write_en')),
            Connect(CompPort(result, 'out'), CompPort(input_mem, 'write_data')),
            Connect(CompPort(input_mem, 'done'), HolePort(group_name, 'done')),
        ])

    def cells():
        stdlib = Stdlib()

        def registers(prefix, count):
            return [
                Cell(CompVar(f'{prefix}{i}'), stdlib.register(input_bitwidth))
                for i in range(count)
            ]

        def signed_ops(kind, count):
            return [
                Cell(CompVar(f'{kind}{i}'),
                     stdlib.op(kind, input_bitwidth, signed=True))
                for i in range(count)
            ]

        memories = [
            Cell(input_mem, stdlib.mem_d1(input_bitwidth, n, index_width)),
            Cell(phi_mem, stdlib.mem_d1(input_bitwidth, n, index_width)),
        ]
        return (memories
                + registers('r', n)
                + registers('A', n)
                + registers('mul', half)
                + registers('phi', n)
                + signed_ops('mod_pipe', n)
                + signed_ops('mult_pipe', half)
                + signed_ops('add', half)
                + signed_ops('sub', half))

    def wires():
        # Order matters: op-mod groups draw add/sub indices from the
        # round-robin counters stage by stage, row by row.
        groups = [preamble_group(r) for r in range(n)]
        groups += [precursor_group(r) for r in range(n)]
        groups += [
            mul_group(s, multiplies[s][i])
            for s in range(num_stages) for i in range(half)
        ]
        groups += [
            op_mod_group(s, r, operations[s][r])
            for s in range(num_stages) for r in range(n)
        ]
        groups += [epilogue_group(r) for r in range(n)]
        return groups

    def control():
        schedule = [SeqComp([Enable(f'preamble_{r}') for r in range(n)])]
        for s in range(num_stages):
            if s > 0:
                # The first stage reads r{i} straight from the preamble.
                schedule.append(
                    ParComp([Enable(f'precursor_{r}') for r in range(n)]))
            # Multiply
            schedule.append(
                ParComp([Enable(f's{s}_mul{i}') for i in range(half)]))
            # Addition or subtraction mod `q`
            schedule.append(
                ParComp([Enable(f's{s}_r{r}_op_mod') for r in range(n)]))
        schedule.append(SeqComp([Enable(f'epilogue_{r}') for r in range(n)]))
        return SeqComp(schedule)

    pp_table(operations, multiplies, n, num_stages)
    Program(
        imports=[Import('primitives/std.lib')],
        components=[
            Component('main', inputs=[], outputs=[],
                      structs=cells() + wires(), controls=control())
        ],
    ).emit()