示例#1
0
def get_memory(name: str, type: tvm.ir.Type) -> Cell:
    """Build the Calyx cell that holds a value of the given TVM type.

    The cell constructor is selected from `NumDimsToCell` by the number of
    dimensions in `type.concrete_shape`. It is called with the bitwidth of
    the type, followed by the size of every dimension, followed by the
    index width (`bits_needed`) of every dimension.

    NOTE(review): a zero-dimensional shape presumably maps to a register
    rather than a memory -- confirm against `NumDimsToCell`.

    An assertion fails when no cell is registered for the shape's rank.
    """
    shape = type.concrete_shape

    # Constructor arguments, in order: bitwidth, dimension sizes, and the
    # number of bits needed to index each dimension.
    width = get_bitwidth(type)
    sizes = list(shape)
    index_widths = [bits_needed(size) for size in shape]
    cell_args = [width, *sizes, *index_widths]

    rank = len(shape)
    assert rank in NumDimsToCell, f'Memory of size {rank} not supported.'

    cell_var = CompVar(name)
    primitive = NumDimsToCell[rank](*cell_args)
    return Cell(cell_var, primitive)
示例#2
0
def instantiate_memory(top_or_left, idx, size):
    """
    Instantiates:
    - the top or left input memory with index `idx`
    - structure to move data from memory to read registers.

    `top_or_left` selects the naming scheme: "top" uses memory `t{idx}` and
    target register `top_0_{idx}`; "left" uses memory `l{idx}` and target
    register `left_{idx}_0`. `size` is the number of elements in the memory.

    Returns (cells, structure) tuple. Both are the lists produced by
    `instantiate_indexor`, extended with the memory cell and the
    memory-move group respectively.

    Raises ValueError if `top_or_left` is neither "top" nor "left".
    """
    if top_or_left == "top":
        name = f"t{idx}"
        target_reg = f"top_0_{idx}"
    elif top_or_left == "left":
        name = f"l{idx}"
        target_reg = f"left_{idx}_0"
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the message actually reaches the caller.
        raise ValueError(f"Invalid top_or_left: {top_or_left}")

    var_name = ast.CompVar(name)
    idx_name = ast.CompVar(NAME_SCHEME["index name"].format(prefix=name))
    group_name = ast.CompVar(NAME_SCHEME["memory move"].format(prefix=name))
    target_reg = ast.CompVar(target_reg)

    # Group that reads the memory at the current index and latches the value
    # into the target register; the group is done once the register is.
    structure = ast.Group(
        group_name,
        connections=[
            ast.Connect(ast.CompPort(idx_name, "out"), ast.CompPort(var_name, "addr0")),
            ast.Connect(
                ast.CompPort(var_name, "read_data"), ast.CompPort(target_reg, "in")
            ),
            ast.Connect(ast.ConstantPort(1, 1), ast.CompPort(target_reg, "write_en")),
            ast.Connect(
                ast.CompPort(target_reg, "done"), ast.HolePort(group_name, "done")
            ),
        ],
    )

    idx_width = bits_needed(size)
    # Instantiate the indexor
    (idx_cells, idx_structure) = instantiate_indexor(name, idx_width)
    idx_structure.append(structure)
    # Instantiate the memory
    idx_cells.append(
        ast.Cell(
            var_name, ast.Stdlib().mem_d1(BITWIDTH, size, idx_width), is_external=True
        )
    )
    return (idx_cells, idx_structure)
示例#3
0
def create_systolic_array(top_length, top_depth, left_length, left_depth):
    """
    Builds the program for a systolic array with `left_length` rows and
    `top_length` columns of PEs.

    top_length: Number of PEs in each row.
    top_depth: Number of elements processed by each PE in a row.
    left_length: Number of PEs in each column.
    left_depth: Number of elements processed by each PE in a col.

    Returns an `ast.Program` holding a single `main` component that
    contains the input memories, the output memory, the PEs, and the
    data-movement groups, with control from `generate_control`.

    An assertion fails when `top_depth != left_depth`.
    """

    assert top_depth == left_depth, (
        f"Cannot multiply matrices: "
        f"{top_length}x{top_depth} and {left_depth}x{left_length}"
    )

    cells = []
    wires = []

    # Instantiate all the memories: one top memory per column and one left
    # memory per row. The unpacked results get their own names so that the
    # loop index is never overwritten inside its own loop.
    for top_idx in range(top_length):
        (mem_cells, mem_wires) = instantiate_memory("top", top_idx, top_depth)
        cells.extend(mem_cells)
        wires.extend(mem_wires)

    for left_idx in range(left_length):
        (mem_cells, mem_wires) = instantiate_memory("left", left_idx, left_depth)
        cells.extend(mem_cells)
        wires.extend(mem_wires)

    # Instantiate output memory
    out_ridx_size = bits_needed(left_length)
    out_cidx_size = bits_needed(top_length)
    cells.append(
        ast.Cell(
            OUT_MEM,
            ast.Stdlib().mem_d2(
                BITWIDTH, left_length, top_length, out_ridx_size, out_cidx_size
            ),
            is_external=True,
        )
    )

    # Instantiate all the PEs
    for row in range(left_length):
        in_last_row = row == left_length - 1
        for col in range(top_length):
            in_last_col = col == top_length - 1

            # Instantiate the PEs
            pe_cells = instantiate_pe(row, col, in_last_col, in_last_row)
            cells.extend(pe_cells)

            # Instantiate the mover fabric
            move_wires = instantiate_data_move(row, col, in_last_col, in_last_row)
            wires.extend(move_wires)

            # Instantiate output movement structure
            out_move = instantiate_output_move(
                row, col, out_ridx_size, out_cidx_size
            )
            wires.append(out_move)

    main = ast.Component(
        name="main",
        inputs=[],
        outputs=[],
        structs=wires + cells,
        controls=generate_control(top_length, top_depth, left_length, left_depth),
    )

    return ast.Program(imports=[ast.Import("primitives/std.lib")], components=[main])
示例#4
0
def generate_ntt_pipeline(input_bitwidth, n, q):
    """
    Prints a pipeline in FuTIL for the cooley-tukey algorithm
    that uses phis in bit-reversed order.

    The stage tables are first shown with `pp_table`, then a program with a
    single `main` component is built and emitted with `Program.emit()`.
    Nothing is returned.

    `n`:
      Length of the input array. Must be a positive power of 2; an
      assertion fails otherwise.
    `input_bitwidth`:
      Bit width of the values in the input array.
    `q`:
      The modulus value.

    Reference:
    https://www.microsoft.com/en-us/research/wp-content/uploads/2016/05/RLWE-1.pdf
    """
    assert n > 0 and (
        not (n & (n - 1))), f'Input length: {n} must be a power of 2.'
    bitwidth = bits_needed(n)
    num_stages = bitwidth - 1

    operations = get_pipeline_data(n, num_stages)
    multiplies = get_multiply_data(n, num_stages)

    # Used to determine the index of the component
    # for the `sadd` and `ssub` primitives.
    component_counts = {'add': 0, 'sub': 0}

    def fresh_comp_index(op):
        # Produces a new index for the component used in the stage.
        # This allows for N / 2 `sadd` and `ssub` components.
        # NOTE: stateful -- the result depends on how many times this has
        # been called for `op`, so the call order in `wires()` matters.

        saved_count = component_counts[op]
        if component_counts[op] == (n // 2) - 1:
            # Reset for the next stage.
            component_counts[op] = 0
        else:
            component_counts[op] += 1

        return saved_count

    # Memory component variables.
    # Named `input_mem` (not `input`) so the builtin `input()` is not
    # shadowed inside this function and its nested helpers.
    input_mem = CompVar('a')
    phis = CompVar('phis')

    def mul_group(stage, mul_tuple):
        # Group `s{stage}_mul{mul_index}`: multiplies `phi{phi_index}` by
        # `r{k}` in `mult_pipe{mul_index}` and stores the product in
        # `mul{mul_index}`. Done when that register reports done.
        mul_index, k, phi_index = mul_tuple

        group_name = CompVar(f's{stage}_mul{mul_index}')
        mult_pipe = CompVar(f'mult_pipe{mul_index}')
        mul = CompVar(f'mul{mul_index}')
        phi = CompVar(f'phi{phi_index}')
        reg = CompVar(f'r{k}')
        connections = [
            Connect(CompPort(phi, 'out'), CompPort(mult_pipe, 'left')),
            Connect(CompPort(reg, 'out'), CompPort(mult_pipe, 'right')),
            # Keep `go` high only while the pipe has not signalled done.
            Connect(ConstantPort(1, 1), CompPort(mult_pipe, 'go'),
                    Not(Atom(CompPort(mult_pipe, 'done')))),
            Connect(CompPort(mult_pipe, 'done'), CompPort(mul, 'write_en')),
            Connect(CompPort(mult_pipe, 'out'), CompPort(mul, 'in')),
            Connect(CompPort(mul, 'done'), HolePort(group_name, 'done'))
        ]
        return Group(group_name, connections)

    def op_mod_group(stage, row, operations_tuple):
        # Group `s{stage}_r{row}_op_mod`: adds or subtracts `mul{mul_index}`
        # to/from `r{lhs}`, reduces the result modulo `q` in
        # `mod_pipe{row}`, and stores it in `A{row}`.
        lhs, op, mul_index = operations_tuple
        comp = 'add' if op == '+' else 'sub'
        comp_index = fresh_comp_index(comp)

        group_name = CompVar(f's{stage}_r{row}_op_mod')
        op = CompVar(f'{comp}{comp_index}')
        reg = CompVar(f'r{lhs}')
        mul = CompVar(f'mul{mul_index}')
        mod_pipe = CompVar(f'mod_pipe{row}')
        A = CompVar(f'A{row}')
        connections = [
            Connect(CompPort(reg, 'out'), CompPort(op, 'left')),
            Connect(CompPort(mul, 'out'), CompPort(op, 'right')),
            Connect(CompPort(op, 'out'), CompPort(mod_pipe, 'left')),
            Connect(ConstantPort(input_bitwidth, q),
                    CompPort(mod_pipe, 'right')),
            # Keep `go` high only while the pipe has not signalled done.
            Connect(ConstantPort(1, 1), CompPort(mod_pipe, 'go'),
                    Not(Atom(CompPort(mod_pipe, 'done')))),
            Connect(CompPort(mod_pipe, 'done'), CompPort(A, 'write_en')),
            Connect(CompPort(mod_pipe, 'out'), CompPort(A, 'in')),
            Connect(CompPort(A, 'done'), HolePort(group_name, 'done'))
        ]
        return Group(group_name, connections)

    def precursor_group(row):
        # Group `precursor_{row}`: copies `A{row}` into `r{row}`.
        group_name = CompVar(f'precursor_{row}')
        r = CompVar(f'r{row}')
        A = CompVar(f'A{row}')
        connections = [
            Connect(CompPort(A, 'out'), CompPort(r, 'in')),
            Connect(ConstantPort(1, 1), CompPort(r, 'write_en')),
            Connect(CompPort(r, 'done'), HolePort(group_name, 'done'))
        ]
        return Group(group_name, connections)

    def preamble_group(row):
        # Group `preamble_{row}`: loads `a[row]` into `r{row}` and
        # `phis[row]` into `phi{row}`. Done when both registers are done.
        reg = CompVar(f'r{row}')
        phi = CompVar(f'phi{row}')
        group_name = CompVar(f'preamble_{row}')
        connections = [
            Connect(ConstantPort(bitwidth, row), CompPort(input_mem, 'addr0')),
            Connect(ConstantPort(bitwidth, row), CompPort(phis, 'addr0')),
            Connect(ConstantPort(1, 1), CompPort(reg, 'write_en')),
            Connect(CompPort(input_mem, 'read_data'), CompPort(reg, 'in')),
            Connect(ConstantPort(1, 1), CompPort(phi, 'write_en')),
            Connect(CompPort(phis, 'read_data'), CompPort(phi, 'in')),
            Connect(
                ConstantPort(1, 1), HolePort(group_name, 'done'),
                And(Atom(CompPort(reg, 'done')), Atom(CompPort(phi, 'done'))))
        ]
        return Group(group_name, connections)

    def epilogue_group(row):
        # Group `epilogue_{row}`: writes `A{row}` back to `a[row]`.
        group_name = CompVar(f'epilogue_{row}')
        A = CompVar(f'A{row}')
        connections = [
            Connect(ConstantPort(bitwidth, row), CompPort(input_mem, 'addr0')),
            Connect(ConstantPort(1, 1), CompPort(input_mem, 'write_en')),
            Connect(CompPort(A, 'out'), CompPort(input_mem, 'write_data')),
            Connect(CompPort(input_mem, 'done'), HolePort(group_name, 'done'))
        ]
        return Group(group_name, connections)

    def cells():
        # All cells of `main`: the two memories, `n` each of the `r`, `A`
        # and `phi` registers and mod pipes, and `n // 2` each of the `mul`
        # registers, mult pipes, adders and subtractors.
        stdlib = Stdlib()

        memories = [
            Cell(input_mem, stdlib.mem_d1(input_bitwidth, n, bitwidth)),
            Cell(phis, stdlib.mem_d1(input_bitwidth, n, bitwidth))
        ]
        r_regs = [
            Cell(CompVar(f'r{r}'), stdlib.register(input_bitwidth))
            for r in range(n)
        ]
        A_regs = [
            Cell(CompVar(f'A{r}'), stdlib.register(input_bitwidth))
            for r in range(n)
        ]
        mul_regs = [
            Cell(CompVar(f'mul{i}'), stdlib.register(input_bitwidth))
            for i in range(n // 2)
        ]
        phi_regs = [
            Cell(CompVar(f'phi{r}'), stdlib.register(input_bitwidth))
            for r in range(n)
        ]
        mod_pipes = [
            Cell(CompVar(f'mod_pipe{r}'),
                 stdlib.op('mod_pipe', input_bitwidth, signed=True))
            for r in range(n)
        ]
        mult_pipes = [
            Cell(CompVar(f'mult_pipe{i}'),
                 stdlib.op('mult_pipe', input_bitwidth, signed=True))
            for i in range(n // 2)
        ]
        adds = [
            Cell(CompVar(f'add{i}'),
                 stdlib.op('add', input_bitwidth, signed=True))
            for i in range(n // 2)
        ]
        subs = [
            Cell(CompVar(f'sub{i}'),
                 stdlib.op('sub', input_bitwidth, signed=True))
            for i in range(n // 2)
        ]

        return (memories + r_regs + A_regs + mul_regs + phi_regs
                + mod_pipes + mult_pipes + adds + subs)

    def wires():
        # All groups of `main`. `op_mods` is built stage by stage and row by
        # row; `fresh_comp_index` relies on that order.
        preambles = [preamble_group(r) for r in range(n)]
        precursors = [precursor_group(r) for r in range(n)]
        muls = [
            mul_group(s, multiplies[s][i]) for s in range(num_stages)
            for i in range(n // 2)
        ]
        op_mods = [
            op_mod_group(s, r, operations[s][r]) for s in range(num_stages)
            for r in range(n)
        ]
        epilogues = [epilogue_group(r) for r in range(n)]
        return preambles + precursors + muls + op_mods + epilogues

    def control():
        # Sequential schedule: all preambles (in sequence), then every NTT
        # stage, then all epilogues (in sequence). Within a stage the
        # precursors, multiplies and op-mod groups each run in parallel.
        preambles = [SeqComp([Enable(f'preamble_{r}') for r in range(n)])]
        epilogues = [SeqComp([Enable(f'epilogue_{r}') for r in range(n)])]

        ntt_stages = []
        for s in range(num_stages):
            if s != 0:
                # Only append precursors if this is not the first stage.
                ntt_stages.append(
                    ParComp([Enable(f'precursor_{r}') for r in range(n)]))
            # Multiply
            ntt_stages.append(
                ParComp([Enable(f's{s}_mul{i}') for i in range(n // 2)]))
            # Addition or subtraction mod `q`
            ntt_stages.append(
                ParComp([Enable(f's{s}_r{r}_op_mod') for r in range(n)]))
        return SeqComp(preambles + ntt_stages + epilogues)

    pp_table(operations, multiplies, n, num_stages)
    Program(imports=[Import('primitives/std.lib')],
            components=[
                Component('main',
                          inputs=[],
                          outputs=[],
                          structs=cells() + wires(),
                          controls=control())
            ]).emit()