def combinational(defn_env: dict, fn: types.FunctionType):
    """Compile the Python function ``fn`` into a magma circuit definition.

    The function's AST is converted to SSA form, then rewritten by
    ``FunctionToCircuitDefTransformer`` and ``IfTransformer``. The result is
    compiled in ``defn_env`` with ``ast_utils.compile_function_to_file``.

    Args:
        defn_env: Namespace used while transforming and compiling the
            generated code.
        fn: The user function to convert.

    Returns:
        A wrapper around ``fn`` that calls ``circuit_def()`` and applies the
        result to the wrapper's arguments. The generated definition is
        exposed as the wrapper's ``circuit_definition`` attribute.
    """
    tree = ast_utils.get_func_ast(fn)
    tree, renamed_args = convert_tree_to_ssa(tree, defn_env)
    tree = FunctionToCircuitDefTransformer(renamed_args).visit(tree)
    tree = ast.fix_missing_locations(tree)

    # Source location is only gathered in debug mode; otherwise
    # IfTransformer receives None for both.
    filename = None
    lines = None
    if get_debug_mode():
        filename = inspect.getsourcefile(fn)
        lines = inspect.getsourcelines(fn)
    tree = IfTransformer(filename, lines).visit(tree)
    tree = ast.fix_missing_locations(tree)

    # filter_decorator is handed ``combinational`` itself, presumably so the
    # generated function no longer carries this decorator (verify against
    # ast_utils.filter_decorator).
    tree.decorator_list = ast_utils.filter_decorator(
        combinational, tree.decorator_list, defn_env)

    if "phi" not in defn_env:
        # Make ``m`` and ``phi`` (mantle's mux) available to the generated
        # code when the environment does not already define ``phi``.
        # ``type_ignores`` is a required field of ``ast.Module`` on Python
        # 3.8+; a Module built without it is rejected by ``compile()``.
        tree = ast.Module(
            body=[
                ast.parse("import magma as m").body[0],
                ast.parse("from mantle import mux as phi").body[0],
                tree,
            ],
            type_ignores=[],
        )

    # Line-numbered listing of the generated source, emitted via ``debug``.
    generated_lines = astor.to_source(tree).splitlines()
    source = "\n" + "".join(
        f" {i}: {line}\n" for i, line in enumerate(generated_lines))
    debug(source)

    circuit_def = ast_utils.compile_function_to_file(tree, fn.__name__,
                                                     defn_env)
    if get_debug_mode() and getattr(circuit_def, "debug_info", False):
        # Keep the recorded filename/line number, but attribute the
        # definition to the module that defined ``fn``.
        circuit_def.debug_info = debug_info(circuit_def.debug_info.filename,
                                            circuit_def.debug_info.lineno,
                                            inspect.getmodule(fn))

    @functools.wraps(fn)
    def func(*args, **kwargs):
        return circuit_def()(*args, **kwargs)

    # functools.wraps copies fn.__qualname__; override it with the bare name.
    func.__name__ = fn.__name__
    func.__qualname__ = fn.__name__
    # Provide a mechanism for accessing the underlying circuit definition
    func.circuit_definition = circuit_def
    return func
def _sequential(defn_env: dict, async_reset: bool, cls,
                combinational_decorator: typing.Callable):
    """Build a magma circuit definition from the sequential class ``cls``.

    State elements come from ``cls.__init__`` (via
    ``get_initial_value_map``) and the logic from the body of
    ``cls.__call__``. These are rendered into
    ``circuit_definition_template``. The resulting source is compiled in
    ``defn_env``, and the ``make_<ClassName>`` function it defines is called
    with ``combinational_decorator``.

    Note that ``cls`` is not checked to actually be a class.

    Args:
        defn_env: Namespace used to evaluate and compile the generated code.
        async_reset: Forwarded to ``gen_io_list`` and
            ``gen_register_instances``.
        cls: The user class describing the sequential circuit.
        combinational_decorator: Passed to the generated ``make_<ClassName>``
            function (presumably applied to the generated combinational
            function -- TODO confirm against the template).

    Returns:
        The circuit definition returned by ``make_<ClassName>``.
    """
    initial_value_map = get_initial_value_map(cls.__init__, defn_env)

    call_def = get_ast(cls.__call__).body[0]
    inputs, output_type = get_io(call_def)
    io_list = gen_io_list(inputs, output_type, async_reset)

    # Pieces of the generated combinational function:
    # - its parameters;
    # - the arguments it is called with;
    # - its output types;
    # - the wiring of each element of its result (``comb_out``) back into
    #   the circuit.
    comb_args = []
    comb_call_args = []
    comb_output_types = []
    comb_out_wiring = []

    # Circuit inputs are passed straight through as ``io.<name>``.
    for name, type_node in inputs:
        type_str = astor.to_source(type_node).rstrip()
        comb_args.append(f"{name}: {type_str}")
        comb_call_args.append(f"io.{name}")

    # State elements contribute their current value(s) as arguments and
    # their next value(s) as leading elements of ``comb_out``.
    comb_out_count = 0
    for name, (_, type_node, eval_type, eval_value) in initial_value_map.items():
        if isinstance(eval_type, m.Kind):
            # A value of a magma type: pass ``<name>`` in as
            # ``self_<name>_O`` and drive ``<name>.I`` from the result.
            type_str = astor.to_source(type_node).rstrip()
            comb_args.append(f"self_{name}_O: {type_str}")
            comb_call_args.append(name)
            comb_output_types.append(type_str)
            comb_out_wiring.append(f"{name}.I <= comb_out[{comb_out_count}]\n")
            comb_out_count += 1
        else:
            # Otherwise iterate the object's interface ports, skipping clock
            # and async-reset ports. Output ports become arguments; input
            # ports are driven from ``comb_out``.
            for port in eval_value.interface.ports.values():
                if isinstance(port, (m.ClockType, m.AsyncResetType)):
                    continue
                port_type = repr(type(port))
                if port.isoutput():
                    comb_args.append(f"self_{name}_{port}: m.{port_type}")
                    comb_call_args.append(f"{name}.{port}")
                if port.isinput():
                    comb_output_types.append(f"m.{port_type}")
                    comb_out_wiring.append(
                        f"{name}.{port} <= comb_out[{comb_out_count}]\n")
                    comb_out_count += 1

    # The declared return value(s) of ``__call__`` follow the state outputs.
    if isinstance(output_type, ast.Tuple):
        for i, elem in enumerate(output_type.elts):
            comb_output_types.append(astor.to_source(elem).rstrip())
            comb_out_wiring.append(
                f"io.O{i} <= comb_out[{comb_out_count + i}]\n")
    else:
        comb_output_types.append(astor.to_source(output_type).rstrip())
        comb_out_wiring.append(f"io.O <= comb_out[{comb_out_count}]\n")

    tab = 4 * ' '

    # Rewrite each statement of ``__call__``. Statements collected by
    # RewriteSelfAttributes in ``calls_seen`` are emitted before it.
    comb_body = []
    for stmt in call_def.body:
        rewriter = RewriteSelfAttributes(initial_value_map)
        stmt = rewriter.visit(stmt)
        # Keep the visitor's return value: a NodeTransformer may return a
        # replacement node instead of mutating ``stmt`` in place.
        stmt = RewriteReturn(initial_value_map).visit(stmt)
        emitted = [stmt]
        if rewriter.calls_seen:
            emitted = rewriter.calls_seen + emitted
        for node in emitted:
            comb_body.extend(astor.to_source(node).rstrip().splitlines())

    register_instances = gen_register_instances(initial_value_map, async_reset)

    # Indentation levels match the nesting expected by the template.
    circuit_definition_str = circuit_definition_template.format(
        circuit_name=cls.__name__,
        io_list=io_list,
        register_instances=('\n' + 3 * tab).join(register_instances),
        circuit_combinational_args=', '.join(comb_args),
        circuit_combinational_output_type=', '.join(comb_output_types),
        circuit_combinational_body=('\n' + 4 * tab).join(comb_body),
        circuit_combinational_call_args=', '.join(comb_call_args),
        comb_out_wiring=(3 * tab).join(comb_out_wiring))

    tree = ast.parse(circuit_definition_str)
    if "DefineRegister" not in defn_env:
        # Prepend the import to the parsed module rather than building a new
        # ``ast.Module``. This keeps its ``type_ignores`` field, which
        # ``compile()`` requires on Python 3.8+.
        tree.body.insert(
            0, ast.parse("from mantle import DefineRegister").body[0])

    circuit_def_constructor = ast_utils.compile_function_to_file(
        tree, 'make_' + cls.__name__, defn_env)
    circuit_def = circuit_def_constructor(combinational_decorator)
    if get_debug_mode() and getattr(circuit_def, "debug_info", False):
        # Attribute the definition to the module that defined ``cls``.
        circuit_def.debug_info = debug_info(circuit_def.debug_info.filename,
                                            circuit_def.debug_info.lineno,
                                            inspect.getmodule(cls))
    return circuit_def