Esempio n. 1
0
def instantiate_memory(top_or_left, idx, size):
    """
    Instantiates:
    - a top or left memory (selected by `top_or_left`)
    - structure to move data from memory to read registers.

    top_or_left: Either "top" or "left". Selects the memory name
        (`t{idx}` or `l{idx}`) and the register the data is moved into
        (`top_0_{idx}` or `left_{idx}_0`).
    idx: Index of the memory along its edge.
    size: Number of elements in the memory; also determines the width of
        the index used to address it.

    Returns (cells, structure) tuple. Both are the lists produced by
    `instantiate_indexor`, extended in place with the memory cell and the
    memory-move group respectively.

    Raises ValueError if `top_or_left` is neither "top" nor "left".
    """
    if top_or_left == "top":
        name = f"t{idx}"
        target_reg = f"top_0_{idx}"
    elif top_or_left == "left":
        name = f"l{idx}"
        target_reg = f"left_{idx}_0"
    else:
        # A bare string cannot be raised in Python 3 (it would surface as an
        # unrelated TypeError), so wrap the message in a real exception.
        raise ValueError(f"Invalid top_or_left: {top_or_left}")

    var_name = py_ast.CompVar(name)
    idx_name = py_ast.CompVar(NAME_SCHEME["index name"].format(prefix=name))
    group_name = py_ast.CompVar(NAME_SCHEME["memory move"].format(prefix=name))
    target_reg = py_ast.CompVar(target_reg)

    # Group that addresses the memory with the index cell's output and
    # writes the value read into the target register; the group is done
    # when the register reports done.
    structure = py_ast.Group(
        group_name,
        connections=[
            py_ast.Connect(py_ast.CompPort(idx_name, "out"),
                           py_ast.CompPort(var_name, "addr0")),
            py_ast.Connect(
                py_ast.CompPort(var_name, "read_data"),
                py_ast.CompPort(target_reg, "in"),
            ),
            py_ast.Connect(py_ast.ConstantPort(1, 1),
                           py_ast.CompPort(target_reg, "write_en")),
            py_ast.Connect(py_ast.CompPort(target_reg, "done"),
                           py_ast.HolePort(group_name, "done")),
        ],
    )

    idx_width = bits_needed(size)
    # Instantiate the indexor
    (idx_cells, idx_structure) = instantiate_indexor(name, idx_width)
    idx_structure.append(structure)
    # Instantiate the memory (declared as an external cell)
    idx_cells.append(
        py_ast.Cell(
            var_name,
            py_ast.Stdlib().mem_d1(BITWIDTH, size, idx_width),
            is_external=True,
        ))
    return (idx_cells, idx_structure)
Esempio n. 2
0
def get_memory(name: str, type: tvm.ir.Type) -> Cell:
    """Builds the external Calyx cell that holds a value of a TVM type.

    The cell constructor is looked up in `NumDimsToCell` by the number of
    dimensions in `type.concrete_shape`. It is called with the bitwidth of
    `type`, followed by the size of every dimension, followed by the number
    of index bits needed for every dimension.

    NOTE(review): the previous docstring stated that non-Tensor types get a
    register; presumably that is the zero-dimension entry of
    `NumDimsToCell` -- confirm against its definition.

    An assertion fails when `NumDimsToCell` has no entry for the number of
    dimensions."""
    dims = type.concrete_shape

    # Constructor arguments: bitwidth first, then sizes, then index widths.
    bitwidth = get_bitwidth(type)
    sizes = list(dims)
    index_widths = [bits_needed(d) for d in dims]
    args = [bitwidth, *sizes, *index_widths]

    num_dims = len(dims)
    assert num_dims in NumDimsToCell, f"Memory of size {num_dims} not supported."

    cell_name = CompVar(name)
    component = NumDimsToCell[num_dims](*args)
    return Cell(cell_name, component, is_external=True)
Esempio n. 3
0
def create_systolic_array(top_length,
                          top_depth,
                          left_length,
                          left_depth,
                          gen_metadata=False):
    """
    Builds the Calyx program for a systolic array.

    top_length: Number of PEs in each row.
    top_depth: Number of elements processed by each PE in a row.
    left_length: Number of PEs in each column.
    left_depth: Number of elements processed by each PE in a col.
    gen_metadata: Forwarded unchanged to `generate_control`.

    Returns a (program, source_map) tuple, where `source_map` is the second
    value produced by `generate_control`.

    An assertion fails when `top_depth` and `left_depth` differ.
    """

    assert top_depth == left_depth, (
        f"Cannot multiply matrices: "
        f"{top_length}x{top_depth} and {left_depth}x{left_length}")

    cells = []
    wires = []

    # Instantiate all the memories: every top memory first, then every
    # left memory.
    memory_banks = (
        ("top", top_length, top_depth),
        ("left", left_length, left_depth),
    )
    for side, count, depth in memory_banks:
        for idx in range(count):
            mem_cells, mem_structure = instantiate_memory(side, idx, depth)
            cells.extend(mem_cells)
            wires.extend(mem_structure)

    # Instantiate output memory, with one slot per PE.
    total_size = left_length * top_length
    out_idx_size = bits_needed(total_size)
    out_mem = py_ast.Cell(
        OUT_MEM,
        py_ast.Stdlib().mem_d1(BITWIDTH, total_size, out_idx_size),
        is_external=True,
    )
    cells.append(out_mem)

    # Instantiate all the PEs together with their data movers.
    last_row = left_length - 1
    last_col = top_length - 1
    for row in range(left_length):
        is_last_row = row == last_row
        for col in range(top_length):
            is_last_col = col == last_col

            # The PE itself
            cells.extend(instantiate_pe(row, col, is_last_col, is_last_row))

            # The mover fabric
            wires.extend(
                instantiate_data_move(row, col, is_last_col, is_last_row))

            # Output movement structure
            wires.append(
                instantiate_output_move(row, col, top_length, out_idx_size))

    control, source_map = generate_control(top_length,
                                           top_depth,
                                           left_length,
                                           left_depth,
                                           gen_metadata=gen_metadata)

    main = py_ast.Component(
        name="main",
        inputs=[],
        outputs=[],
        structs=wires + cells,
        controls=control,
    )

    program = py_ast.Program(
        imports=[
            py_ast.Import("primitives/core.futil"),
            py_ast.Import("primitives/binary_operators.futil"),
        ],
        components=[main],
    )
    return (program, source_map)
Esempio n. 4
0
def generate_ntt_pipeline(input_bitwidth: int, n: int, q: int):
    """
    Builds and returns a Calyx program containing a pipeline for the
    cooley-tukey algorithm that uses phis in bit-reversed order.

    `n`:
      Length of the input array. Must be a power of 2.
    `input_bitwidth`:
      Bit width of the values in the input array.
    `q`:
      The modulus value.

    Before the program is returned, `pp_table` is called with the
    operation and multiply tables (presumably to pretty-print them --
    confirm against its definition).

    An assertion fails when `n` is not a positive power of 2.

    Reference:
    https://www.microsoft.com/en-us/research/wp-content/uploads/2016/05/RLWE-1.pdf
    """
    is_power_of_two = n > 0 and not (n & (n - 1))
    assert is_power_of_two, f"Input length: {n} must be a power of 2."

    bitwidth = bits_needed(n)
    num_stages = bitwidth - 1
    half_n = n // 2

    operations = get_pipeline_data(n, num_stages)
    multiplies = get_multiply_data(n, num_stages)

    # Running index per operator kind, used to pick which `add{i}` /
    # `sub{i}` cell a group uses within a stage.
    component_counts = {"add": 0, "sub": 0}

    def fresh_comp_index(op):
        # Hands out the current index for `op` and advances the counter,
        # wrapping back to 0 after index N / 2 - 1 so that the next stage
        # starts over. This allows for N / 2 adders and subtractors.
        current = component_counts[op]
        wraps_around = current == half_n - 1
        component_counts[op] = 0 if wraps_around else current + 1
        return current

    def high():
        # A 1-bit constant with value 1.
        return ConstantPort(1, 1)

    def done_hole(group_name):
        # The `done` hole of the given group.
        return HolePort(group_name, "done")

    # Memory component variables.
    input_mem = CompVar("a")
    phis_mem = CompVar("phis")

    def mul_group(stage, mul_tuple):
        # Group that feeds a phi register and an r register into a
        # multiplier and finishes when the multiplier is done.
        mul_index, k, phi_index = mul_tuple

        group_name = CompVar(f"s{stage}_mul{mul_index}")
        mult_pipe = CompVar(f"mult_pipe{mul_index}")
        phi = CompVar(f"phi{phi_index}")
        reg = CompVar(f"r{k}")
        return Group(
            group_name,
            [
                Connect(CompPort(phi, "out"), CompPort(mult_pipe, "left")),
                Connect(CompPort(reg, "out"), CompPort(mult_pipe, "right")),
                Connect(high(), CompPort(mult_pipe, "go")),
                Connect(CompPort(mult_pipe, "done"), done_hole(group_name)),
            ],
        )

    def op_mod_group(stage, row, operations_tuple):
        # Group that adds or subtracts a multiplier result to/from an r
        # register, divides by `q`, and stores the remainder in A{row}.
        lhs, op, mul_index = operations_tuple
        if op == "+":
            comp = "add"
        else:
            comp = "sub"
        comp_index = fresh_comp_index(comp)

        group_name = CompVar(f"s{stage}_r{row}_op_mod")
        alu = CompVar(f"{comp}{comp_index}")
        reg = CompVar(f"r{lhs}")
        mul = CompVar(f"mult_pipe{mul_index}")
        mod_pipe = CompVar(f"mod_pipe{row}")
        a_reg = CompVar(f"A{row}")
        return Group(
            group_name,
            [
                Connect(CompPort(reg, "out"), CompPort(alu, "left")),
                Connect(CompPort(mul, "out"), CompPort(alu, "right")),
                Connect(CompPort(alu, "out"), CompPort(mod_pipe, "left")),
                Connect(ConstantPort(input_bitwidth, q),
                        CompPort(mod_pipe, "right")),
                # `go` is only driven while the divider is not done.
                Connect(
                    high(),
                    CompPort(mod_pipe, "go"),
                    Not(Atom(CompPort(mod_pipe, "done"))),
                ),
                # The divider's done signal enables the write of its
                # remainder into the A register.
                Connect(CompPort(mod_pipe, "done"), CompPort(a_reg, "write_en")),
                Connect(CompPort(mod_pipe, "out_remainder"),
                        CompPort(a_reg, "in")),
                Connect(CompPort(a_reg, "done"), done_hole(group_name)),
            ],
        )

    def precursor_group(row):
        # Group that copies A{row} into r{row}.
        group_name = CompVar(f"precursor_{row}")
        reg = CompVar(f"r{row}")
        a_reg = CompVar(f"A{row}")
        return Group(
            group_name,
            [
                Connect(CompPort(a_reg, "out"), CompPort(reg, "in")),
                Connect(high(), CompPort(reg, "write_en")),
                Connect(CompPort(reg, "done"), done_hole(group_name)),
            ],
        )

    def preamble_group(row):
        # Group that loads element `row` of both memories into r{row} and
        # phi{row}; done once both registers report done.
        reg = CompVar(f"r{row}")
        phi = CompVar(f"phi{row}")
        group_name = CompVar(f"preamble_{row}")
        return Group(
            group_name,
            [
                Connect(ConstantPort(bitwidth, row),
                        CompPort(input_mem, "addr0")),
                Connect(ConstantPort(bitwidth, row),
                        CompPort(phis_mem, "addr0")),
                Connect(high(), CompPort(reg, "write_en")),
                Connect(CompPort(input_mem, "read_data"), CompPort(reg, "in")),
                Connect(high(), CompPort(phi, "write_en")),
                Connect(CompPort(phis_mem, "read_data"), CompPort(phi, "in")),
                Connect(
                    high(),
                    done_hole(group_name),
                    And(Atom(CompPort(reg, "done")),
                        Atom(CompPort(phi, "done"))),
                ),
            ],
        )

    def epilogue_group(row):
        # Group that writes A{row} back to element `row` of the input memory.
        group_name = CompVar(f"epilogue_{row}")
        a_reg = CompVar(f"A{row}")
        return Group(
            group_name,
            [
                Connect(ConstantPort(bitwidth, row),
                        CompPort(input_mem, "addr0")),
                Connect(high(), CompPort(input_mem, "write_en")),
                Connect(CompPort(a_reg, "out"),
                        CompPort(input_mem, "write_data")),
                Connect(CompPort(input_mem, "done"), done_hole(group_name)),
            ],
        )

    def cells():
        # All cells of the component: the two external memories followed
        # by the register banks and the arithmetic units.
        stdlib = Stdlib()

        def register(cell_name):
            return Cell(CompVar(cell_name), stdlib.register(input_bitwidth))

        def signed_op(cell_name, op_name):
            return Cell(
                CompVar(cell_name),
                stdlib.op(op_name, input_bitwidth, signed=True),
            )

        memories = [
            Cell(mem, stdlib.mem_d1(input_bitwidth, n, bitwidth),
                 is_external=True)
            for mem in (input_mem, phis_mem)
        ]
        r_regs = [register(f"r{r}") for r in range(n)]
        a_regs = [register(f"A{r}") for r in range(n)]
        mul_regs = [register(f"mul{i}") for i in range(half_n)]
        phi_regs = [register(f"phi{r}") for r in range(n)]
        mod_pipes = [signed_op(f"mod_pipe{r}", "div_pipe") for r in range(n)]
        mult_pipes = [
            signed_op(f"mult_pipe{i}", "mult_pipe") for i in range(half_n)
        ]
        adds = [signed_op(f"add{i}", "add") for i in range(half_n)]
        subs = [signed_op(f"sub{i}", "sub") for i in range(half_n)]

        all_cells = []
        for bank in (memories, r_regs, a_regs, mul_regs, phi_regs,
                     mod_pipes, mult_pipes, adds, subs):
            all_cells += bank
        return all_cells

    def wires():
        # All groups of the component. The op_mod groups are created stage
        # by stage, row by row, which is the order in which
        # `fresh_comp_index` hands out adder/subtractor indices.
        preambles = [preamble_group(r) for r in range(n)]
        precursors = [precursor_group(r) for r in range(n)]
        muls = [
            mul_group(s, multiplies[s][i])
            for s in range(num_stages)
            for i in range(half_n)
        ]
        op_mods = [
            op_mod_group(s, r, operations[s][r])
            for s in range(num_stages)
            for r in range(n)
        ]
        epilogues = [epilogue_group(r) for r in range(n)]

        all_groups = []
        for groups in (preambles, precursors, muls, op_mods, epilogues):
            all_groups += groups
        return all_groups

    def control():
        # Schedule: sequential preambles, then the NTT stages, then
        # sequential epilogues.
        preamble_seq = SeqComp([Enable(f"preamble_{r}") for r in range(n)])
        epilogue_seq = SeqComp([Enable(f"epilogue_{r}") for r in range(n)])

        schedule = [preamble_seq]
        for s in range(num_stages):
            if s != 0:
                # Only the stages after the first one need the precursors.
                schedule.append(
                    ParComp([Enable(f"precursor_{r}") for r in range(n)]))
            # Multiply
            schedule.append(
                ParComp([Enable(f"s{s}_mul{i}") for i in range(half_n)]))
            # Addition or subtraction mod `q`
            schedule.append(
                ParComp([Enable(f"s{s}_r{r}_op_mod") for r in range(n)]))
        schedule.append(epilogue_seq)
        return SeqComp(schedule)

    pp_table(operations, multiplies, n, num_stages)

    imports = [
        Import("primitives/core.futil"),
        Import("primitives/binary_operators.futil"),
    ]
    main = Component(
        "main",
        inputs=[],
        outputs=[],
        structs=cells() + wires(),
        controls=control(),
    )
    return Program(imports=imports, components=[main])