def epilogue_group(row):
    """Build the `epilogue_{row}` group, which writes `A{row}` out.

    The group drives `row` onto the `addr0` port of the `input` cell,
    asserts that cell's `write_en`, feeds its `write_data` from the `out`
    port of cell `A{row}`, and completes when `input` raises `done`.

    `bitwidth` and `input` are not parameters: they are read from the
    enclosing scope.
    """
    name = CompVar(f"epilogue_{row}")
    value_cell = CompVar(f"A{row}")

    select_address = Connect(
        ConstantPort(bitwidth, row),
        CompPort(input, "addr0"),
    )
    enable_write = Connect(
        ConstantPort(1, 1),
        CompPort(input, "write_en"),
    )
    drive_data = Connect(
        CompPort(value_cell, "out"),
        CompPort(input, "write_data"),
    )
    signal_done = Connect(
        CompPort(input, "done"),
        HolePort(name, "done"),
    )
    return Group(name, [select_address, enable_write, drive_data, signal_done])
def precursor_group(row):
    """Build the `precursor_{row}` group, which copies `A{row}` into `r{row}`.

    The `out` port of cell `A{row}` is connected to the `in` port of cell
    `r{row}` with `write_en` held high; the group is done when `r{row}`
    raises `done`.
    """
    name = CompVar(f"precursor_{row}")
    target = CompVar(f"r{row}")
    source = CompVar(f"A{row}")

    copy_value = Connect(CompPort(source, "out"), CompPort(target, "in"))
    enable_write = Connect(ConstantPort(1, 1), CompPort(target, "write_en"))
    signal_done = Connect(CompPort(target, "done"), HolePort(name, "done"))

    return Group(name, [copy_value, enable_write, signal_done])
def mul_group(stage, mul_tuple):
    """Build the `s{stage}_mul{mul_index}` group for one multiplication.

    `mul_tuple` unpacks to `(mul_index, k, phi_index)`. The group feeds
    `phi{phi_index}.out` and `r{k}.out` into the `left` and `right` ports
    of `mult_pipe{mul_index}`, holds its `go` port high, and is done when
    the multiplier raises `done`.
    """
    mul_index, reg_index, phi_index = mul_tuple

    name = CompVar(f"s{stage}_mul{mul_index}")
    multiplier = CompVar(f"mult_pipe{mul_index}")
    phi_cell = CompVar(f"phi{phi_index}")
    operand = CompVar(f"r{reg_index}")

    wiring = []
    wiring.append(Connect(CompPort(phi_cell, "out"), CompPort(multiplier, "left")))
    wiring.append(Connect(CompPort(operand, "out"), CompPort(multiplier, "right")))
    wiring.append(Connect(ConstantPort(1, 1), CompPort(multiplier, "go")))
    wiring.append(Connect(CompPort(multiplier, "done"), HolePort(name, "done")))
    return Group(name, wiring)
def op_mod_group(stage, row, operations_tuple):
    """Build the `s{stage}_r{row}_op_mod` group: add/sub, then modulo.

    `operations_tuple` unpacks to `(lhs, op, mul_index)`. An `add` cell is
    used when `op` is "+", otherwise a `sub` cell; its numeric suffix comes
    from `fresh_comp_index`. The cell combines `r{lhs}.out` with
    `mult_pipe{mul_index}.out`, and the result goes to the `left` port of
    `mod_pipe{row}` with the constant `q` on its `right` port. The
    `out_remainder` port is written to `A{row}`, and the group is done when
    `A{row}` raises `done`.

    `input_bitwidth`, `q` and `fresh_comp_index` come from the enclosing
    scope.
    """
    lhs_index, operator, mul_index = operations_tuple

    if operator == "+":
        kind = "add"
    else:
        kind = "sub"
    # Reserve the cell index before building anything else, as the
    # original did; the helper presumably hands out a new index per call.
    kind_index = fresh_comp_index(kind)

    name = CompVar(f"s{stage}_r{row}_op_mod")
    alu = CompVar(f"{kind}{kind_index}")
    lhs_reg = CompVar(f"r{lhs_index}")
    product = CompVar(f"mult_pipe{mul_index}")
    modulo = CompVar(f"mod_pipe{row}")
    result = CompVar(f"A{row}")

    arithmetic = [
        Connect(CompPort(lhs_reg, "out"), CompPort(alu, "left")),
        Connect(CompPort(product, "out"), CompPort(alu, "right")),
    ]
    reduction = [
        Connect(CompPort(alu, "out"), CompPort(modulo, "left")),
        Connect(ConstantPort(input_bitwidth, q), CompPort(modulo, "right")),
        # `go` is only driven while the modulo unit has not signalled `done`.
        Connect(
            ConstantPort(1, 1),
            CompPort(modulo, "go"),
            Not(Atom(CompPort(modulo, "done"))),
        ),
    ]
    write_back = [
        Connect(CompPort(modulo, "done"), CompPort(result, "write_en")),
        Connect(CompPort(modulo, "out_remainder"), CompPort(result, "in")),
        Connect(CompPort(result, "done"), HolePort(name, "done")),
    ]
    return Group(name, arithmetic + reduction + write_back)
def consume_pow(i: int) -> Group:
    """Build the `consume_pow{i}` group: store `pow{i}.out` in register `p{i}`.

    The group's `done` hole is driven with constant 1 under the condition
    `p{i}.done`. The group is constructed with a third positional argument
    of 1 (presumably its static delay -- confirm against `Group`).
    """
    name = CompVar(f"consume_pow{i}")
    target = CompVar(f"p{i}")
    power = CompVar(f"pow{i}")

    enable_write = Connect(ConstantPort(1, 1), CompPort(target, "write_en"))
    store_value = Connect(CompPort(power, "out"), CompPort(target, "in"))
    signal_done = Connect(
        ConstantPort(1, 1),
        HolePort(name, "done"),
        CompPort(target, "done"),
    )
    return Group(name, [enable_write, store_value, signal_done], 1)
def emit_eval_body_group(s_idx, stmt, b=None):
    """
    Build the `eval_body` group implementing the body of `stmt`, a `map`
    or `reduce` statement.

    `s_idx` is appended to the group and cell names so that they do not
    collide with the groups generated for other `map` or `reduce`
    statements.

    For a `map` statement, `b` is the banking factor of the input array
    and is folded into the names as `_b{b}`; otherwise `b` is None and no
    bank suffix is added.
    """
    is_map = isinstance(stmt.op, ast.Map)
    bank_suffix = "" if b is None else "_b" + str(b)

    # A `map` binding is keyed by its first destination name, a `reduce`
    # binding by its second.
    dest_slot = 0 if is_map else 1

    name2arr = {}
    address_wires = []
    for binding in stmt.op.bind:
        name2arr[binding.dest[dest_slot]] = binding.src
        memory = CompVar(f"{binding.src}{bank_suffix}")
        index = CompVar(f"idx{bank_suffix}_{s_idx}")
        address_wires.append(
            Connect(CompPort(index, "out"), CompPort(memory, "addr0"))
        )

    if is_map:
        # A map also addresses its destination with the same index.
        memory = CompVar(f"{stmt.dest}{bank_suffix}")
        index = CompVar(f"idx{bank_suffix}_{s_idx}")
        address_wires.append(
            Connect(CompPort(index, "out"), CompPort(memory, "addr0"))
        )

    body = stmt.op.body
    left_operand = emit_compute_op(
        body.lhs, stmt.op, stmt.dest, name2arr, s_idx, bank_suffix
    )
    right_operand = emit_compute_op(
        body.rhs, stmt.op, stmt.dest, name2arr, s_idx, bank_suffix
    )

    if is_map:
        store_result = Connect(
            CompPort(CompVar(f"adder_op{bank_suffix}_{s_idx}"), "out"),
            CompPort(CompVar(f"{stmt.dest}{bank_suffix}"), "write_data"),
        )
    else:
        # NOTE(review): this branch reads the result from
        # `adder_op{s_idx}`, while the operands below are wired into
        # `adder_op{bank_suffix}_{s_idx}` (i.e. `adder_op_{s_idx}` when
        # `b` is None). The two names differ -- confirm which one matches
        # the declared cell.
        store_result = Connect(
            CompPort(CompVar(f"adder_op{s_idx}"), "out"),
            CompPort(CompVar(f"{stmt.dest}"), "in"),
        )

    group_id = CompVar(f"eval_body{bank_suffix}_{s_idx}")
    operator = CompVar(f"adder_op{bank_suffix}_{s_idx}")
    target = CompVar(f"{stmt.dest}{bank_suffix}")

    compute_wires = [
        Connect(ConstantPort(1, 1), CompPort(target, "write_en")),
        Connect(left_operand, CompPort(operator, "left")),
        Connect(right_operand, CompPort(operator, "right")),
        store_result,
        Connect(CompPort(target, "done"), HolePort(group_id, "done")),
    ]
    return Group(id=group_id, connections=compute_wires + address_wires)
def emit_cond_group(suffix, arr_size, b=None):
    """
    Build the combinational `cond` group that compares an index against
    `arr_size`.

    The index cell's `out` port drives the `left` port of the `le` cell,
    and `arr_size` (as a 32-bit constant) drives its `right` port.

    If the bank number `b` is not None, `_b{b}_` is inserted into every
    name before `suffix`. `suffix` is appended to each name to
    disambiguate from other `map` or `reduce` implementations.
    """
    infix = "" if b is None else f"_b{b}_"
    tag = f"{infix}{suffix}"

    comparator = CompVar(f"le{tag}")
    counter = CompVar(f"idx{tag}")

    wires = [
        Connect(CompPort(counter, "out"), CompPort(comparator, "left")),
        Connect(ConstantPort(32, arr_size), CompPort(comparator, "right")),
    ]
    return CombGroup(id=CompVar(f"cond{tag}"), connections=wires)
def emit_idx_group(s_idx, b=None):
    """
    Build the `incr_idx` group that adds one to an index register.

    The register's `out` port and a 32-bit constant 1 feed the adder, whose
    output is written back into the register; the group is done when the
    register raises `done`.

    If the bank number `b` is not None, `_b{b}_` is inserted into every
    name before `s_idx`.
    """
    infix = "" if b is None else "_b" + str(b) + "_"
    tag = f"{infix}{s_idx}"

    group_id = CompVar(f"incr_idx{tag}")
    incrementer = CompVar(f"adder_idx{tag}")
    counter = CompVar(f"idx{tag}")

    compute_next = [
        Connect(CompPort(counter, "out"), CompPort(incrementer, "left")),
        Connect(ConstantPort(32, 1), CompPort(incrementer, "right")),
    ]
    write_back = [
        Connect(ConstantPort(1, 1), CompPort(counter, "write_en")),
        Connect(CompPort(incrementer, "out"), CompPort(counter, "in")),
        Connect(CompPort(counter, "done"), HolePort(group_id, "done")),
    ]
    return Group(id=group_id, connections=compute_next + write_back)
def final_multiply(register_id: CompVar) -> List[Group]:
    """Build the `final_multiply` group.

    Multiplies `pow1.out` by `sum1.out` on `mult_pipe1` -- per the original
    comment, e^{fractional_value} * e^{integer_value} -- and writes the
    product to `register_id`.

    Args:
        register_id: The register that receives the product. The previous
            implementation ignored this argument and always wrote to a
            register named `m`; passing `CompVar("m")` reproduces that.

    Returns:
        A single-element list containing the group.
    """
    group_name = CompVar("final_multiply")
    mult_pipe = CompVar("mult_pipe1")
    # Honour the caller's choice of destination instead of hard-coding it.
    reg = register_id

    connections = [
        Connect(CompPort(CompVar("pow1"), "out"), CompPort(mult_pipe, "left")),
        Connect(CompPort(CompVar("sum1"), "out"), CompPort(mult_pipe, "right")),
        # `go` is only driven while the multiplier has not signalled `done`.
        Connect(
            ConstantPort(1, 1),
            CompPort(mult_pipe, "go"),
            Not(Atom(CompPort(mult_pipe, "done"))),
        ),
        Connect(CompPort(mult_pipe, "done"), CompPort(reg, "write_en")),
        Connect(CompPort(mult_pipe, "out"), CompPort(reg, "in")),
        Connect(CompPort(reg, "done"), HolePort(group_name, "done")),
    ]
    return [Group(id=group_name, connections=connections)]
def preamble_group(row):
    """Build the `preamble_{row}` group, which loads `r{row}` and `phi{row}`.

    Address `row` is driven onto the `addr0` ports of both the `input` and
    `phis` cells; their `read_data` ports are written into `r{row}` and
    `phi{row}` respectively. The group's `done` hole is driven with 1 once
    both registers report `done`.

    `bitwidth`, `input` and `phis` come from the enclosing scope.
    """
    name = CompVar(f"preamble_{row}")
    value_reg = CompVar(f"r{row}")
    phi_reg = CompVar(f"phi{row}")

    addressing = [
        Connect(ConstantPort(bitwidth, row), CompPort(input, "addr0")),
        Connect(ConstantPort(bitwidth, row), CompPort(phis, "addr0")),
    ]
    load_value = [
        Connect(ConstantPort(1, 1), CompPort(value_reg, "write_en")),
        Connect(CompPort(input, "read_data"), CompPort(value_reg, "in")),
    ]
    load_phi = [
        Connect(ConstantPort(1, 1), CompPort(phi_reg, "write_en")),
        Connect(CompPort(phis, "read_data"), CompPort(phi_reg, "in")),
    ]
    both_written = And(
        Atom(CompPort(value_reg, "done")),
        Atom(CompPort(phi_reg, "done")),
    )
    finish = [Connect(ConstantPort(1, 1), HolePort(name, "done"), both_written)]

    return Group(name, addressing + load_value + load_phi + finish)
def multiply_by_reciprocal_factorial(i: int) -> Group:
    """Build the `mult_by_reciprocal_factorial{i}` group.

    Multiplies register `p{i}` by `reciprocal_factorial{i}` on
    `mult_pipe{i}` and stores the result in `product{i}`. The group is done
    when `product{i}` raises `done`.
    """
    name = CompVar(f"mult_by_reciprocal_factorial{i}")
    multiplier = CompVar(f"mult_pipe{i}")
    power_reg = CompVar(f"p{i}")
    result_reg = CompVar(f"product{i}")
    factor = CompVar(f"reciprocal_factorial{i}")

    operands = [
        Connect(CompPort(power_reg, "out"), CompPort(multiplier, "left")),
        Connect(CompPort(factor, "out"), CompPort(multiplier, "right")),
    ]
    # `go` is only driven while the multiplier has not signalled `done`.
    start = Connect(
        ConstantPort(1, 1),
        CompPort(multiplier, "go"),
        Not(Atom(CompPort(multiplier, "done"))),
    )
    write_back = [
        Connect(CompPort(multiplier, "done"), CompPort(result_reg, "write_en")),
        Connect(CompPort(multiplier, "out"), CompPort(result_reg, "in")),
        Connect(CompPort(result_reg, "done"), HolePort(name, "done")),
    ]
    return Group(name, operands + [start] + write_back)
def generate_groups(degree: int, width: int, int_width: int,
                    is_signed: bool) -> List[Structure]:
    """Generate the groups and output wire for the `exp` computation.

    Args:
        degree: Highest power handled; `consume_pow` and
            `mult_by_reciprocal_factorial` groups are generated for
            2..degree, and `degree` is forwarded to
            `divide_and_conquer_sums`.
        width: Total bit width of the value.
        int_width: Number of integer bits; the fractional width is
            `width - int_width`.
        is_signed: When true, the `negate`, `is_negative` and `reciprocal`
            groups are emitted as well.

    Returns:
        The structures in emission order: `init`, `split_bits`, the
        optional signed-only groups, the per-degree groups, the summation
        groups, `final_multiply`, and finally the connection from register
        `m` to the component's `out` port.
    """
    frac_width = width - int_width

    exponent = CompVar("exponent_value")
    # Register that holds the final result and drives the `out` port.
    output_register = CompVar("m")

    init = Group(
        id=CompVar("init"),
        connections=[
            Connect(ConstantPort(1, 1), CompPort(exponent, "write_en")),
            Connect(ThisPort(CompVar("x")), CompPort(exponent, "in")),
            Connect(CompPort(exponent, "done"),
                    HolePort(CompVar("init"), "done")),
        ],
        static_delay=1,
    )

    if is_signed:
        # Multiply the stored exponent by `negative_one` in place.
        mult_pipe = CompVar("mult_pipe1")
        negate = Group(
            id=CompVar("negate"),
            connections=[
                Connect(CompPort(exponent, "out"),
                        CompPort(mult_pipe, "left")),
                Connect(
                    CompPort(CompVar("negative_one"), "out"),
                    CompPort(mult_pipe, "right"),
                ),
                Connect(
                    ConstantPort(1, 1),
                    CompPort(mult_pipe, "go"),
                    Not(Atom(CompPort(mult_pipe, "done"))),
                ),
                Connect(CompPort(mult_pipe, "done"),
                        CompPort(exponent, "write_en")),
                Connect(CompPort(mult_pipe, "out"),
                        CompPort(exponent, "in")),
                Connect(CompPort(exponent, "done"),
                        HolePort(CompVar("negate"), "done")),
            ],
        )

    # Initialization: split up the value `x` into its integer and
    # fractional values.
    and0 = CompVar("and0")
    and1 = CompVar("and1")
    rsh = CompVar("rsh")
    int_x = CompVar("int_x")
    frac_x = CompVar("frac_x")
    split_bits = Group(
        id=CompVar("split_bits"),
        connections=[
            # Mask off the integer bits, then shift them down.
            Connect(CompPort(exponent, "out"), CompPort(and0, "left")),
            Connect(
                ConstantPort(width, 2**width - 2**frac_width),
                CompPort(and0, "right"),
            ),
            Connect(CompPort(and0, "out"), CompPort(rsh, "left")),
            Connect(ConstantPort(width, frac_width), CompPort(rsh, "right")),
            # Mask off the fractional bits.
            Connect(CompPort(exponent, "out"), CompPort(and1, "left")),
            Connect(
                ConstantPort(width, (2**frac_width) - 1),
                CompPort(and1, "right"),
            ),
            Connect(ConstantPort(1, 1), CompPort(int_x, "write_en")),
            Connect(ConstantPort(1, 1), CompPort(frac_x, "write_en")),
            Connect(CompPort(rsh, "out"), CompPort(int_x, "in")),
            Connect(CompPort(and1, "out"), CompPort(frac_x, "in")),
            Connect(
                ConstantPort(1, 1),
                HolePort(CompVar("split_bits"), "done"),
                And(
                    Atom(CompPort(int_x, "done")),
                    Atom(CompPort(frac_x, "done")),
                ),
            ),
        ],
    )

    def consume_pow(i: int) -> Group:
        # Write the output of pow{i} to register p{i}.
        reg = CompVar(f"p{i}")
        group_name = CompVar(f"consume_pow{i}")
        connections = [
            Connect(ConstantPort(1, 1), CompPort(reg, "write_en")),
            Connect(CompPort(CompVar(f"pow{i}"), "out"), CompPort(reg, "in")),
            Connect(
                ConstantPort(1, 1),
                HolePort(group_name, "done"),
                CompPort(reg, "done"),
            ),
        ]
        return Group(group_name, connections, 1)

    def multiply_by_reciprocal_factorial(i: int) -> Group:
        # Multiply register p{i} with the reciprocal factorial.
        group_name = CompVar(f"mult_by_reciprocal_factorial{i}")
        mult_pipe = CompVar(f"mult_pipe{i}")
        reg = CompVar(f"p{i}")
        product = CompVar(f"product{i}")
        reciprocal = CompVar(f"reciprocal_factorial{i}")
        connections = [
            Connect(CompPort(reg, "out"), CompPort(mult_pipe, "left")),
            Connect(CompPort(reciprocal, "out"), CompPort(mult_pipe, "right")),
            Connect(
                ConstantPort(1, 1),
                CompPort(mult_pipe, "go"),
                Not(Atom(CompPort(mult_pipe, "done"))),
            ),
            Connect(CompPort(mult_pipe, "done"),
                    CompPort(product, "write_en")),
            Connect(CompPort(mult_pipe, "out"), CompPort(product, "in")),
            Connect(CompPort(product, "done"), HolePort(group_name, "done")),
        ]
        return Group(group_name, connections)

    def final_multiply(register_id: CompVar) -> List[Group]:
        # Multiply e^{fractional_value} * e^{integer_value},
        # and write it to `register_id`.
        group_name = CompVar("final_multiply")
        mult_pipe = CompVar("mult_pipe1")
        # Use the register the caller asked for rather than a fixed name.
        reg = register_id
        return [
            Group(
                id=group_name,
                connections=[
                    Connect(
                        CompPort(CompVar("pow1"), "out"),
                        CompPort(mult_pipe, "left"),
                    ),
                    Connect(
                        CompPort(CompVar("sum1"), "out"),
                        CompPort(mult_pipe, "right"),
                    ),
                    Connect(
                        ConstantPort(1, 1),
                        CompPort(mult_pipe, "go"),
                        Not(Atom(CompPort(mult_pipe, "done"))),
                    ),
                    Connect(CompPort(mult_pipe, "done"),
                            CompPort(reg, "write_en")),
                    Connect(CompPort(mult_pipe, "out"), CompPort(reg, "in")),
                    Connect(CompPort(reg, "done"),
                            HolePort(group_name, "done")),
                ],
            )
        ]

    if is_signed:
        # Take the reciprocal, since the initial value was -x.
        div_pipe = CompVar("div_pipe")
        reciprocal = Group(
            id=CompVar("reciprocal"),
            connections=[
                Connect(CompPort(CompVar("one"), "out"),
                        CompPort(div_pipe, "left")),
                Connect(CompPort(output_register, "out"),
                        CompPort(div_pipe, "right")),
                Connect(
                    ConstantPort(1, 1),
                    CompPort(div_pipe, "go"),
                    Not(Atom(CompPort(div_pipe, "done"))),
                ),
                Connect(CompPort(div_pipe, "done"),
                        CompPort(output_register, "write_en")),
                Connect(CompPort(div_pipe, "out_quotient"),
                        CompPort(output_register, "in")),
                Connect(CompPort(output_register, "done"),
                        HolePort(CompVar("reciprocal"), "done")),
            ],
        )
        # Combinational check of the component input `x` against zero.
        is_negative = CombGroup(
            id=CompVar("is_negative"),
            connections=[
                Connect(ThisPort(CompVar("x")),
                        CompPort(CompVar("lt"), "left")),
                Connect(ConstantPort(width, 0),
                        CompPort(CompVar("lt"), "right")),
            ],
        )

    # Connect final value to the `out` signal of the component.
    out = [Connect(CompPort(output_register, "out"),
                   ThisPort(CompVar("out")))]

    signed_groups = [negate, is_negative, reciprocal] if is_signed else []
    return (
        [init, split_bits]
        + signed_groups
        + [consume_pow(j) for j in range(2, degree + 1)]
        + [multiply_by_reciprocal_factorial(k) for k in range(2, degree + 1)]
        + divide_and_conquer_sums(degree)
        + final_multiply(output_register)
        + out
    )
def divide_and_conquer_sums(degree: int) -> List[Structure]:
    r"""Returns a list of groups for the sums.

    This is done by dividing the groups into log2(N) different rounds,
    where N is the `degree`. These rounds can then be executed in
    parallel.

    For example, with N == 4, we will produce groups:
      group sum_round1_1 { ... }     #    x   p2  p3  p4
                                     #     \  /    \  /
      group sum_round1_2 { ... }     #    sum1     sum2
                                     #       \     /
      group sum_round2_1 { ... }     #        sum1
      group add_degree_zero { ... }  #    sum1 + 1

    Each round pairs index 2k-1 with index 2k and then halves the operand
    count. When a round has an odd number of operands, the last one has no
    partner and no group is generated for it.
    """
    groups = []
    operand_count = degree
    round_number = 1
    while operand_count > 1:
        # The first round will accrue its operands
        # from the previously calculated products.
        register_name = "product" if round_number == 1 else "sum"

        # (1, 2), (3, 4), ...; `zip` stops at the shorter range.
        index_pairs = zip(
            range(1, operand_count + 1, 2),
            range(2, operand_count + 1, 2),
        )
        for position, (lhs, rhs) in enumerate(index_pairs, start=1):
            group_name = CompVar(f"sum_round{round_number}_{position}")
            adder = CompVar(f"add{position}")
            sum_reg = CompVar(f"sum{position}")

            # In the first round and first group, we add the 1st degree,
            # the value `x` itself.
            if round_number == 1 and position == 1:
                left_port = CompPort(CompVar("frac_x"), "out")
            else:
                left_port = CompPort(CompVar(f"{register_name}{lhs}"), "out")
            right_port = CompPort(CompVar(f"{register_name}{rhs}"), "out")

            connections = [
                Connect(left_port, CompPort(adder, "left")),
                Connect(right_port, CompPort(adder, "right")),
                Connect(ConstantPort(1, 1), CompPort(sum_reg, "write_en")),
                Connect(CompPort(adder, "out"), CompPort(sum_reg, "in")),
                Connect(CompPort(sum_reg, "done"),
                        HolePort(group_name, "done")),
            ]
            groups.append(Group(group_name, connections, 1))

        operand_count >>= 1
        round_number += 1

    # Sums the 0th degree value, 1, and the final
    # sum of the divide-and-conquer.
    group_name = CompVar("add_degree_zero")
    adder = CompVar("add1")
    reg = CompVar("sum1")
    groups.append(
        Group(
            id=group_name,
            connections=[
                Connect(CompPort(reg, "out"), CompPort(adder, "left")),
                Connect(
                    CompPort(CompVar("one"), "out"),
                    CompPort(adder, "right"),
                ),
                Connect(ConstantPort(1, 1), CompPort(reg, "write_en")),
                Connect(CompPort(adder, "out"), CompPort(reg, "in")),
                Connect(CompPort(reg, "done"), HolePort(group_name, "done")),
            ],
            static_delay=1,
        )
    )
    return groups
def generate_fp_pow_component(width: int, int_width: int,
                              is_signed: bool) -> Component:
    """Generate the fixed point `fp_pow` component.

    The component computes `base ** integer_exp` by repeated
    multiplication: `init` loads the fixed point value 1.0 into the `pow`
    register and 0 into `count`, then `execute_mul` and `incr_count` run in
    parallel for as long as `count < integer_exp`. The `pow` register
    drives the `out` port.
    """
    stdlib = Stdlib()
    frac_width = width - int_width

    result_reg = CompVar("pow")
    counter = CompVar("count")
    multiplier = CompVar("mul")
    less_than = CompVar("lt")
    incrementer = CompVar("incr")

    cells = [
        Cell(result_reg, stdlib.register(width)),
        Cell(counter, stdlib.register(width)),
        Cell(
            multiplier,
            stdlib.fixed_point_op(
                "mult_pipe", width, int_width, frac_width, signed=is_signed
            ),
        ),
        Cell(less_than, stdlib.op("lt", width, signed=is_signed)),
        Cell(incrementer, stdlib.op("add", width, signed=is_signed)),
    ]

    # Bit pattern of 1.0 in the requested fixed point format.
    one = numeric_types.FixedPoint(
        "1.0", width, int_width, is_signed=is_signed
    ).unsigned_integer()

    init_group = Group(
        id=CompVar("init"),
        connections=[
            Connect(ConstantPort(width, one), CompPort(result_reg, "in")),
            Connect(ConstantPort(1, 1), CompPort(result_reg, "write_en")),
            Connect(ConstantPort(width, 0), CompPort(counter, "in")),
            Connect(ConstantPort(1, 1), CompPort(counter, "write_en")),
            Connect(
                ConstantPort(1, 1),
                HolePort(CompVar("init"), "done"),
                And(
                    Atom(CompPort(result_reg, "done")),
                    Atom(CompPort(counter, "done")),
                ),
            ),
        ],
    )

    multiply_group = Group(
        id=CompVar("execute_mul"),
        connections=[
            Connect(ThisPort(CompVar("base")), CompPort(multiplier, "left")),
            Connect(CompPort(result_reg, "out"),
                    CompPort(multiplier, "right")),
            # `go` is only driven while the multiplier is not `done`.
            Connect(
                ConstantPort(1, 1),
                CompPort(multiplier, "go"),
                Not(Atom(CompPort(multiplier, "done"))),
            ),
            Connect(CompPort(multiplier, "done"),
                    CompPort(result_reg, "write_en")),
            Connect(CompPort(multiplier, "out"), CompPort(result_reg, "in")),
            Connect(
                CompPort(result_reg, "done"),
                HolePort(CompVar("execute_mul"), "done"),
            ),
        ],
    )

    increment_group = Group(
        id=CompVar("incr_count"),
        connections=[
            Connect(ConstantPort(width, 1), CompPort(incrementer, "left")),
            Connect(CompPort(counter, "out"), CompPort(incrementer, "right")),
            Connect(CompPort(incrementer, "out"), CompPort(counter, "in")),
            Connect(ConstantPort(1, 1), CompPort(counter, "write_en")),
            Connect(
                CompPort(counter, "done"),
                HolePort(CompVar("incr_count"), "done"),
            ),
        ],
    )

    condition_group = CombGroup(
        id=CompVar("cond"),
        connections=[
            Connect(CompPort(counter, "out"), CompPort(less_than, "left")),
            Connect(ThisPort(CompVar("integer_exp")),
                    CompPort(less_than, "right")),
        ],
    )

    drive_output = Connect(
        CompPort(CompVar("pow"), "out"),
        ThisPort(CompVar("out")),
    )

    wires = [
        init_group,
        multiply_group,
        increment_group,
        condition_group,
        drive_output,
    ]

    loop = While(
        CompPort(less_than, "out"),
        CompVar("cond"),
        ParComp([Enable("execute_mul"), Enable("incr_count")]),
    )

    return Component(
        "fp_pow",
        inputs=[
            PortDef(CompVar("base"), width),
            PortDef(CompVar("integer_exp"), width),
        ],
        outputs=[PortDef(CompVar("out"), width)],
        structs=cells + wires,
        controls=SeqComp([Enable("init"), loop]),
    )
Cell(CompVar("t"), Stdlib().register(width)), Cell(CompVar("x"), Stdlib().mem_d1(width, 1, 1), is_external=True), Cell( CompVar("ret"), Stdlib().mem_d1(width, 1, 1), is_external=True, ), Cell(CompVar("e"), CompInst("exp", [])), Group( id=CompVar("init"), connections=[ Connect( ConstantPort(1, 0), CompPort(CompVar("x"), "addr0"), ), Connect( CompPort(CompVar("x"), "read_data"), CompPort(CompVar("t"), "in"), ), Connect( ConstantPort(1, 1), CompPort(CompVar("t"), "write_en"), ), Connect( CompPort(CompVar("t"), "done"), HolePort(CompVar("init"), "done"), ), ], ),