Exemple #1
0
    def control():
        """Build the top-level control program for the NTT pipeline.

        The resulting `seq` is:
          - one `seq` of every `preamble_{row}` group,
          - for each stage: a `par` of `precursor_{row}` groups (omitted for
            stage 0), a `par` of the stage's `n // 2` multiply groups, and a
            `par` of the stage's per-row `op_mod` groups,
          - one `seq` of every `epilogue_{row}` group.

        `n` and `num_stages` are read from the enclosing scope.
        """
        rows = range(n)
        schedule = [SeqComp([Enable(f"preamble_{row}") for row in rows])]

        for stage in range(num_stages):
            # Stage 0 has no precursor step; every later stage begins with one.
            if stage > 0:
                schedule.append(
                    ParComp([Enable(f"precursor_{row}") for row in rows]))
            # Multiply: n // 2 groups for this stage.
            schedule.append(
                ParComp([Enable(f"s{stage}_mul{k}") for k in range(n // 2)]))
            # Addition or subtraction mod `q`, one group per row.
            schedule.append(
                ParComp([Enable(f"s{stage}_r{row}_op_mod") for row in rows]))

        schedule.append(SeqComp([Enable(f"epilogue_{row}") for row in rows]))
        return SeqComp(schedule)
Exemple #2
0
def gen_reduce_impl(stmt, arr_size, s_idx):
    """
    Returns a dictionary containing Calyx cells, wires and
    control needed to implement a reduce statement. Similar
    to gen_map_impl, but with an implementation of the body
    of a `reduce` statement instead of a `map` statement, and
    a single (un-banked) loop over `arr_size` elements.

    Arguments:
      stmt: the MrXL reduce statement. `stmt.op.body.op` selects the
        operator cell: "mul" becomes a `mult` cell, anything else an `add`.
      arr_size: number of elements the loop iterates over.
      s_idx: index of the statement, used to make cell and group names
        unique.

    Returns:
      A dict with keys "cells", "wires" and "control".
    """
    stdlib = Stdlib()
    # Only multiplication gets a dedicated operator; every other body op
    # is lowered to an adder.
    op_name = "mult" if stmt.op.body.op == "mul" else "add"
    cells = [
        Cell(CompVar(f"le{s_idx}"), stdlib.op("lt", 32, signed=False)),
        Cell(CompVar(f"idx{s_idx}"), stdlib.register(32)),
        Cell(CompVar(f"adder_idx{s_idx}"), stdlib.op("add", 32, signed=False)),
        Cell(CompVar(f"adder_op{s_idx}"), stdlib.op(op_name, 32, signed=False)),
    ]
    wires = [
        emit_cond_group(s_idx, arr_size),
        emit_idx_group(s_idx),
        # The body group is emitted for bank 0.
        emit_eval_body_group(s_idx, stmt, 0),
    ]
    # while le.out with cond { seq { eval_body; incr_idx } }
    control = While(
        port=CompPort(CompVar(f"le{s_idx}"), "out"),
        cond=CompVar(f"cond{s_idx}"),
        body=SeqComp([Enable(f"eval_body{s_idx}"),
                      Enable(f"incr_idx{s_idx}")]),
    )

    return {"cells": cells, "wires": wires, "control": control}
Exemple #3
0
def generate_control(degree: int, is_signed: bool) -> Control:
    """Build the control program for the exponential component.

    Everything below runs in a single `seq`:
      1. `init`; then, only when `is_signed`, `negate` guarded by `lt.out`
         (condition group `is_negative`); then `split_bits`.
      2. One `par` of invokes: `pow1` with base `e.out` and exponent
         `int_x.out`, and `pow{i}` with base `frac_x.out` and exponent
         `c{i}.out` for every i in 2..degree.
      3. One `par` of `consume_pow{i}` groups, then one `par` of
         `mult_by_reciprocal_factorial{i}` groups (i in 2..degree).
      4. floor(log2(degree)) rounds of `sum_round{r}_{i}` groups: round 1
         has `degree // 2` enables and each later round has half as many.
      5. `add_degree_zero`, `final_multiply`, and, only when `is_signed`,
         `reciprocal` guarded the same way as `negate`.
    """

    def when_negative(group):
        # Enable `group` only if `lt.out` is high, with `is_negative` as
        # the condition group.
        return If(
            CompPort(CompVar("lt"), "out"),
            CompVar("is_negative"),
            Enable(group),
        )

    exponents = range(2, degree + 1)

    first_invoke = Invoke(
        CompVar("pow1"),
        [
            ("base", CompPort(CompVar("e"), "out")),
            ("integer_exp", CompPort(CompVar("int_x"), "out")),
        ],
        [],
    )
    fractional_invokes = [
        Invoke(
            CompVar(f"pow{k}"),
            [
                ("base", CompPort(CompVar("frac_x"), "out")),
                ("integer_exp", CompPort(CompVar(f"c{k}"), "out")),
            ],
            [],
        )
        for k in exponents
    ]

    schedule = [Enable("init")]
    if is_signed:
        schedule.append(when_negative("negate"))
    schedule.append(Enable("split_bits"))
    schedule.append(ParComp([first_invoke] + fractional_invokes))
    schedule.append(ParComp([Enable(f"consume_pow{k}") for k in exponents]))
    schedule.append(
        ParComp([Enable(f"mult_by_reciprocal_factorial{k}") for k in exponents]))

    # Summation rounds: the number of enables halves every round.
    partial_sums = degree >> 1
    for round_no in range(1, int(log2(degree) + 1)):
        schedule.append(
            ParComp([
                Enable(f"sum_round{round_no}_{k}")
                for k in range(1, partial_sums + 1)
            ]))
        partial_sums >>= 1

    schedule.extend([Enable("add_degree_zero"), Enable("final_multiply")])
    if is_signed:
        schedule.append(when_negative("reciprocal"))
    return SeqComp(schedule)
Exemple #4
0
def gen_map_impl(stmt, arr_size, bank_factor, s_idx):
    """
    Returns a dictionary containing Calyx cells, wires and
    control needed to implement a map statement. (See gen_stmt_impl
    for format of the dictionary.)

    The input is processed as `bank_factor` banks, each with its own loop;
    the loops run in parallel. For every bank `b` this generates:
      - cells: a `lt` comparator, an index register, an index adder, and
        an operator cell for the body (`mult` when the body op is "mul",
        otherwise `add`)
      - a group that implements the body of the map statement
      - a group that increments an index to access the map input array
      - a group that implements the loop condition, checking if the index
        has reached the end of the bank (`arr_size // bank_factor` elements)

    With `bank_factor == 0` no cells or wires are produced and the control
    is an empty `par`.
    """
    stdlib = Stdlib()
    op_name = "mult" if stmt.op.body.op == "mul" else "add"
    banks = range(bank_factor)

    # Loop bookkeeping cells for every bank come first, followed by the
    # body operator cells for every bank.
    cells = []
    for b in banks:
        cells.extend([
            Cell(CompVar(f"le_b{b}_{s_idx}"), stdlib.op("lt", 32,
                                                        signed=False)),
            Cell(CompVar(f"idx_b{b}_{s_idx}"), stdlib.register(32)),
            Cell(
                CompVar(f"adder_idx_b{b}_{s_idx}"),
                stdlib.op("add", 32, signed=False),
            ),
        ])
    for b in banks:
        cells.append(
            Cell(
                CompVar(f"adder_op_b{b}_{s_idx}"),
                stdlib.op(op_name, 32, signed=False),
            ))

    wires = []
    for b in banks:
        wires.extend([
            emit_cond_group(s_idx, arr_size // bank_factor, b),
            emit_idx_group(s_idx, b),
            emit_eval_body_group(s_idx, stmt, b),
        ])

    # One `while` loop per bank, built once after (not inside) the wires
    # loop so it is always defined, even when there are no banks.
    map_loops = []
    for b in banks:
        b_suffix = f"_b{b}_"
        map_loops.append(
            While(
                CompPort(CompVar(f"le{b_suffix}{s_idx}"), "out"),
                CompVar(f"cond{b_suffix}{s_idx}"),
                SeqComp([
                    Enable(f"eval_body{b_suffix}{s_idx}"),
                    Enable(f"incr_idx{b_suffix}{s_idx}"),
                ]),
            ))

    control = ParComp(map_loops)

    return {"cells": cells, "wires": wires, "control": control}
Exemple #5
0
def reduce_parallel_control_pass(component: Component, N: int, input_size: int):
    """Reduces the amount of fan-out by reducing
    parallelization in the execution flow
    by a factor of `N`.

    For example, given an input size 4 and
    reduction factor N = 2:

    Before:
    par { s0_mul0; s0_mul1; }
    par { s0_r0_op_mod; s0_r1_op_mod; s0_r2_op_mod; s0_r3_op_mod; }
    ...

    After:
    par { s0_mul0; s0_mul1; }
    par { s0_r0_op_mod; s0_r1_op_mod; }
    par { s0_r2_op_mod; s0_r3_op_mod; }
    ...

    Each non-empty `par` in `component.controls` is split into `N`
    consecutive `par`s of equal size. `par`s of multiplies (detected by
    "mul" in the first enabled group's name) are split into `N // 2` parts
    instead, but never fewer than one. All other statements, and empty
    `par`s, are kept unchanged. `component.controls` is replaced in place
    by a new `seq` of the result.

    Raises:
      AssertionError: if `N` is not a power of two within (0, input_size).
      ValueError: if a `par` cannot be divided into equal parts.
    """
    assert (
        N is not None and 0 < N < input_size and (not (N & (N - 1)))
    ), f"""N: {N} should be a power of two within bounds (0, {input_size})."""

    reduced_controls = []
    for control in component.controls.stmts:
        if not isinstance(control, ParComp):
            reduced_controls.append(control)
            continue

        stmts = list(control.stmts)
        if not stmts:
            # An empty `par` has no first enable to inspect; keep it as-is.
            reduced_controls.append(control)
            continue

        enable = stmts[0].stmt
        # Parallelized multiplies are already a factor of 1/2 less. Clamp to
        # one part so that N == 1 leaves them unsplit instead of asking for
        # zero parts.
        factor = max(N // 2, 1) if "mul" in enable else N
        if len(stmts) % factor != 0:
            raise ValueError(
                f"Cannot split a par of {len(stmts)} statements "
                f"into {factor} equal parts."
            )

        chunk = len(stmts) // factor
        reduced_controls.extend(
            ParComp(stmts[start:start + chunk])
            for start in range(0, len(stmts), chunk)
        )

    component.controls = SeqComp(reduced_controls)
Exemple #6
0
def emit(prog):
    """
    Compiles `prog`, a MrXL program, into a Calyx program and emits it via
    `Program.emit()` (presumably printing it; confirm against `Program`).
    This function itself returns None.

    Memories that are not the destination or a bound source of any `map`
    statement default to a banking factor of 1.

    Raises:
      NotImplementedError: if a non-`map` statement writes to a name that
        was not declared.
    """
    cells, wires, control = [], [], []

    # All arrays must be the same size. Every memory declaration overwrites
    # `arr_size`, so the last declared memory determines the size assumed for
    # implicitly declared arrays and for statement generation.
    arr_size = None

    # Collect banking factors: a map's destination and every bound source
    # share that map's `par`.
    name2par = dict()
    for stmt in prog.stmts:
        if isinstance(stmt.op, ast.Map):
            name2par[stmt.dest] = stmt.op.par
            for b in stmt.op.bind:
                name2par[b.src] = stmt.op.par

    # Collect memory and register declarations.
    used_names = set()
    stdlib = Stdlib()
    for decl in prog.decls:
        used_names.add(decl.name)
        if decl.type.size:  # A memory
            arr_size = decl.type.size
            # Memories never touched by a `map` (e.g. only read by a
            # `reduce`) have no recorded banking factor; default to 1.
            cells.extend(
                emit_mem_decl(decl.name, decl.type.size,
                              name2par.get(decl.name, 1)))
        else:  # A register
            cells.append(Cell(CompVar(decl.name), stdlib.register(32)))

    # Collect implicit memory and register declarations.
    for stmt in prog.stmts:
        if stmt.dest not in used_names:
            if isinstance(stmt.op, ast.Map):
                cells.extend(
                    emit_mem_decl(stmt.dest, arr_size, name2par[stmt.dest]))
            else:
                raise NotImplementedError("Generating register declarations")
            used_names.add(stmt.dest)

    # Generate Calyx.
    for i, stmt in enumerate(prog.stmts):
        stmt_impl = gen_stmt_impl(stmt, arr_size, name2par, i)
        cells.extend(stmt_impl["cells"])
        wires.extend(stmt_impl["wires"])
        control.append(stmt_impl["control"])

    program = Program(
        imports=[
            Import("primitives/core.futil"),
            Import("primitives/binary_operators.futil"),
        ],
        components=[
            Component(
                name="main",
                inputs=[],
                outputs=[],
                structs=cells + wires,
                controls=SeqComp(control),
            )
        ],
    )
    program.emit()
Exemple #7
0
def generate_fp_pow_component(width: int, int_width: int,
                              is_signed: bool) -> Component:
    """Generates a fixed point `pow` component, which
    computes `base ** integer_exp` by repeated multiplication,
    where `integer_exp` must be an integer.

    Ports:
      base: fixed-point input, `width` bits with `int_width` integer bits.
      integer_exp: exponent, `width` bits. The loop runs while a counter
        starting at 0 is below it (compared with `signed=is_signed`), so an
        exponent of 0 yields 1.0, as does a negative one when `is_signed`.
      out: the current value of the `pow` register.

    The `pow` register starts at fixed-point 1.0. Each loop iteration
    multiplies it by `base` while the counter is incremented in parallel.
    """
    stdlib = Stdlib()
    frac_width = width - int_width

    # `pow_reg` (not `pow`) avoids shadowing the builtin; the emitted cell
    # is still named "pow".
    pow_reg = CompVar("pow")
    count = CompVar("count")
    mul = CompVar("mul")
    lt = CompVar("lt")
    incr = CompVar("incr")

    cells = [
        Cell(pow_reg, stdlib.register(width)),
        Cell(count, stdlib.register(width)),
        Cell(
            mul,
            stdlib.fixed_point_op("mult_pipe",
                                  width,
                                  int_width,
                                  frac_width,
                                  signed=is_signed),
        ),
        Cell(lt, stdlib.op("lt", width, signed=is_signed)),
        Cell(incr, stdlib.op("add", width, signed=is_signed)),
    ]
    wires = [
        # pow := 1.0 (fixed point), count := 0
        Group(
            id=CompVar("init"),
            connections=[
                Connect(
                    ConstantPort(
                        width,
                        numeric_types.FixedPoint(
                            "1.0", width, int_width,
                            is_signed=is_signed).unsigned_integer(),
                    ),
                    CompPort(pow_reg, "in"),
                ),
                Connect(ConstantPort(1, 1), CompPort(pow_reg, "write_en")),
                Connect(ConstantPort(width, 0), CompPort(count, "in")),
                Connect(ConstantPort(1, 1), CompPort(count, "write_en")),
                Connect(
                    ConstantPort(1, 1),
                    HolePort(CompVar("init"), "done"),
                    And(
                        Atom(CompPort(pow_reg, "done")),
                        Atom(CompPort(count, "done")),
                    ),
                ),
            ],
        ),
        # pow := base * pow. `mul.go` is held high until `mul.done`, and
        # the product is written to `pow` when `mul.done` fires.
        Group(
            id=CompVar("execute_mul"),
            connections=[
                Connect(ThisPort(CompVar("base")), CompPort(mul, "left")),
                Connect(CompPort(pow_reg, "out"), CompPort(mul, "right")),
                Connect(
                    ConstantPort(1, 1),
                    CompPort(mul, "go"),
                    Not(Atom(CompPort(mul, "done"))),
                ),
                Connect(CompPort(mul, "done"), CompPort(pow_reg, "write_en")),
                Connect(CompPort(mul, "out"), CompPort(pow_reg, "in")),
                Connect(
                    CompPort(pow_reg, "done"),
                    HolePort(CompVar("execute_mul"), "done"),
                ),
            ],
        ),
        # count := count + 1
        Group(
            id=CompVar("incr_count"),
            connections=[
                Connect(ConstantPort(width, 1), CompPort(incr, "left")),
                Connect(CompPort(count, "out"), CompPort(incr, "right")),
                Connect(CompPort(incr, "out"), CompPort(count, "in")),
                Connect(ConstantPort(1, 1), CompPort(count, "write_en")),
                Connect(
                    CompPort(count, "done"),
                    HolePort(CompVar("incr_count"), "done"),
                ),
            ],
        ),
        # Loop condition: lt.out = count < integer_exp
        CombGroup(
            id=CompVar("cond"),
            connections=[
                Connect(CompPort(count, "out"), CompPort(lt, "left")),
                Connect(ThisPort(CompVar("integer_exp")),
                        CompPort(lt, "right")),
            ],
        ),
        Connect(CompPort(pow_reg, "out"), ThisPort(CompVar("out"))),
    ]
    return Component(
        "fp_pow",
        inputs=[
            PortDef(CompVar("base"), width),
            PortDef(CompVar("integer_exp"), width),
        ],
        outputs=[PortDef(CompVar("out"), width)],
        structs=cells + wires,
        controls=SeqComp([
            Enable("init"),
            While(
                CompPort(lt, "out"),
                CompVar("cond"),
                ParComp([Enable("execute_mul"),
                         Enable("incr_count")]),
            ),
        ]),
    )
Exemple #8
0
                        Connect(
                            ConstantPort(1, 0),
                            CompPort(CompVar("ret"), "addr0"),
                        ),
                        Connect(
                            ConstantPort(1, 1),
                            CompPort(CompVar("ret"), "write_en"),
                        ),
                        Connect(
                            CompPort(CompVar("e"), "out"),
                            CompPort(CompVar("ret"), "write_data"),
                        ),
                        Connect(
                            CompPort(CompVar("ret"), "done"),
                            HolePort(CompVar("write_to_memory"), "done"),
                        ),
                    ],
                ),
            ],
            controls=SeqComp([
                Enable("init"),
                Invoke(
                    id=CompVar("e"),
                    in_connects=[("x", CompPort(CompVar("t"), "out"))],
                    out_connects=[],
                ),
                Enable("write_to_memory"),
            ]),
        ))
    program.emit()