예제 #1
0
파일: ssa.py 프로젝트: standanley/magma
def _ssa(defn_env: dict, phi: typing.Union[str, typing.Callable],
         fn: typing.Callable):
    """Rewrite ``fn`` into SSA form and compile the result.

    Args:
        defn_env: Namespace the rewritten function is compiled in. When
            ``phi`` is a callable, it is bound in this dict under a freshly
            generated name (the dict is mutated).
        phi: Either the name of the phi function to reference in the
            generated code, or a phi implementation to bind under a
            generated name.
        fn: The function to convert.

    Returns:
        The result of ``ast_utils.compile_function_to_file`` for the
        rewritten tree, whose body ends with ``return O``.
    """
    func_ast = ast_utils.get_func_ast(fn)
    # Remove the ``ssa`` decorator from the function before recompiling it.
    func_ast.decorator_list = ast_utils.filter_decorator(
        ssa, func_ast.decorator_list, defn_env)

    phi_is_name = isinstance(phi, str)
    if phi_is_name:
        phi_name = phi
    else:
        phi_name = ast_utils.gen_free_name(func_ast, defn_env)

    func_ast, _ = convert_tree_to_ssa(func_ast, defn_env, phi_name=phi_name)

    # The generated name is bound to the callable only after conversion.
    if not phi_is_name:
        defn_env[phi_name] = phi

    return_o = ast.Return(ast.Name("O", ast.Load()))
    func_ast.body.append(return_o)
    return ast_utils.compile_function_to_file(func_ast, defn_env=defn_env)
예제 #2
0
def combinational(defn_env: dict, fn: types.FunctionType):
    """Compile the Python function ``fn`` into a magma circuit definition.

    The function is converted to SSA form, transformed into a circuit
    definition by ``FunctionToCircuitDefTransformer``, and its ``if``
    statements are rewritten by ``IfTransformer``. If ``phi`` is not already
    present in ``defn_env``, ``import magma as m`` and
    ``from mantle import mux as phi`` are prepended to the generated module.

    Args:
        defn_env: Namespace the generated code is compiled in.
        fn: The function to convert.

    Returns:
        A wrapper around ``fn`` that instantiates the compiled circuit
        definition and calls that instance with the given arguments. The
        definition itself is available as ``wrapper.circuit_definition``.
    """
    tree = ast_utils.get_func_ast(fn)
    tree, renamed_args = convert_tree_to_ssa(tree, defn_env)
    tree = FunctionToCircuitDefTransformer(renamed_args).visit(tree)
    tree = ast.fix_missing_locations(tree)

    filename = None
    lines = None
    if get_debug_mode():
        # Source location information is only collected in debug mode.
        filename = inspect.getsourcefile(fn)
        lines = inspect.getsourcelines(fn)
    tree = IfTransformer(filename, lines).visit(tree)
    tree = ast.fix_missing_locations(tree)
    tree.decorator_list = ast_utils.filter_decorator(
        combinational, tree.decorator_list, defn_env)

    if "phi" not in defn_env:
        # ``type_ignores`` is a required field of ``ast.Module`` since
        # Python 3.8; ``compile()`` rejects a Module built without it.
        # Older versions simply store the extra keyword as an attribute.
        tree = ast.Module(
            body=[
                ast.parse("import magma as m").body[0],
                ast.parse("from mantle import mux as phi").body[0],
                tree,
            ],
            type_ignores=[],
        )

    # Line-numbered listing of the generated code, passed to ``debug``.
    source = "\n" + "".join(
        f"    {i}: {line}\n"
        for i, line in enumerate(astor.to_source(tree).splitlines()))
    debug(source)

    circuit_def = ast_utils.compile_function_to_file(tree, fn.__name__,
                                                     defn_env)
    if get_debug_mode() and getattr(circuit_def, "debug_info", False):
        # Rebuild the debug info so it records the module of ``fn``.
        circuit_def.debug_info = debug_info(circuit_def.debug_info.filename,
                                            circuit_def.debug_info.lineno,
                                            inspect.getmodule(fn))

    @functools.wraps(fn)
    def func(*args, **kwargs):
        return circuit_def()(*args, **kwargs)
    # ``functools.wraps`` already copies ``__name__``; ``__qualname__`` is
    # deliberately reset to the bare function name.
    func.__qualname__ = fn.__name__
    # Provide a mechanism for accessing the underlying circuit definition
    func.circuit_definition = circuit_def
    return func
예제 #3
0
def _sequential(defn_env: dict, async_reset: bool, cls,
                combinational_decorator: typing.Callable):
    """Build a sequential circuit definition from the class ``cls``.

    Attributes initialised in ``cls.__init__`` (via
    ``get_initial_value_map``) and the body of ``cls.__call__`` are turned
    into a source string using ``circuit_definition_template``. That source
    is compiled into a ``make_<ClassName>`` function, which is then called
    with ``combinational_decorator`` to produce the circuit definition.

    Args:
        defn_env: Namespace the generated code is compiled in.
        async_reset: Forwarded to ``gen_io_list`` and
            ``gen_register_instances``.
        cls: Class describing the sequential circuit.
        combinational_decorator: Argument passed to the generated
            ``make_<ClassName>`` function.

    Returns:
        The circuit definition returned by the generated constructor.
    """
    initial_value_map = get_initial_value_map(cls.__init__, defn_env)

    call_def = get_ast(cls.__call__).body[0]
    inputs, output_type = get_io(call_def)
    io_list = gen_io_list(inputs, output_type, async_reset)

    # Pieces of the generated combinational function:
    # - its parameter list,
    # - the arguments it is called with,
    # - its output types,
    # - the statements wiring each output ``comb_out[k]`` to a port.
    circuit_combinational_output_type = []
    circuit_combinational_args = []
    circuit_combinational_call_args = []
    comb_out_wiring = []
    for name, type_ in inputs:
        type_ = astor.to_source(type_).rstrip()
        circuit_combinational_args.append(f"{name}: {type_}")
        circuit_combinational_call_args.append(f"io.{name}")

    comb_out_count = 0
    for name, (_, type_, eval_type,
               eval_value) in initial_value_map.items():
        if isinstance(eval_type, m.Kind):
            # The attribute's ``O`` feeds the combinational function, and its
            # ``I`` is driven by one of the function's outputs.
            type_ = astor.to_source(type_).rstrip()
            circuit_combinational_args.append(f"self_{name}_O: {type_}")
            circuit_combinational_call_args.append(f"{name}")
            circuit_combinational_output_type.append(f"{type_}")
            comb_out_wiring.append(f"{name}.I <= comb_out[{comb_out_count}]\n")
            comb_out_count += 1
        else:
            # Otherwise the value exposes ``interface.ports``. Each output
            # port becomes an argument and each input port an output;
            # clock and async-reset ports are skipped.
            for _, port in eval_value.interface.ports.items():
                if isinstance(port, (m.ClockType, m.AsyncResetType)):
                    continue
                port_type = repr(type(port))
                if port.isoutput():
                    circuit_combinational_args.append(
                        f"self_{name}_{port}: m.{port_type}")
                    circuit_combinational_call_args.append(f"{name}.{port}")
                if port.isinput():
                    circuit_combinational_output_type.append(f"m.{port_type}")
                    comb_out_wiring.append(
                        f"{name}.{port} <= comb_out[{comb_out_count}]\n")
                    comb_out_count += 1

    circuit_combinational_args = ', '.join(circuit_combinational_args)
    circuit_combinational_call_args = ', '.join(
        circuit_combinational_call_args)

    if isinstance(output_type, ast.Tuple):
        # Tuple outputs are exposed as io.O0, io.O1, ...
        for i, elem in enumerate(output_type.elts):
            circuit_combinational_output_type.append(
                astor.to_source(elem).rstrip())
            comb_out_wiring.append(
                f"io.O{i} <= comb_out[{comb_out_count + i}]\n")
    else:
        output_type_str = astor.to_source(output_type).rstrip()
        circuit_combinational_output_type.append(output_type_str)
        comb_out_wiring.append(f"io.O <= comb_out[{comb_out_count}]\n")

    tab = 4 * ' '
    comb_out_wiring = (3 * tab).join(comb_out_wiring)
    circuit_combinational_output_type = ', '.join(
        circuit_combinational_output_type)

    circuit_combinational_body = []
    for stmt in call_def.body:
        rewriter = RewriteSelfAttributes(initial_value_map)
        stmt = rewriter.visit(stmt)
        # Apply the return rewrite before the statement is collected. The
        # emitted code then uses the transformer's result even if it returns
        # a new node instead of mutating ``stmt`` in place.
        stmt = RewriteReturn(initial_value_map).visit(stmt)
        code = [stmt]
        # Statements recorded in ``calls_seen`` are emitted before ``stmt``.
        if rewriter.calls_seen:
            code = rewriter.calls_seen + code
        for emitted in code:
            circuit_combinational_body.extend(
                astor.to_source(emitted).rstrip().splitlines())

    circuit_combinational_body = ('\n' +
                                  4 * tab).join(circuit_combinational_body)
    register_instances = gen_register_instances(initial_value_map, async_reset)
    register_instances = ('\n' + 3 * tab).join(register_instances)

    circuit_definition_str = circuit_definition_template.format(
        circuit_name=cls.__name__,
        io_list=io_list,
        register_instances=register_instances,
        circuit_combinational_args=circuit_combinational_args,
        circuit_combinational_output_type=circuit_combinational_output_type,
        circuit_combinational_body=circuit_combinational_body,
        circuit_combinational_call_args=circuit_combinational_call_args,
        comb_out_wiring=comb_out_wiring)
    tree = ast.parse(circuit_definition_str)
    if "DefineRegister" not in defn_env:
        # Prepend to the parsed module in place instead of building a new
        # ``ast.Module``. A Module constructed without ``type_ignores`` is
        # rejected by ``compile()`` on Python 3.8+.
        register_import = ast.parse(
            "from mantle import DefineRegister").body[0]
        tree.body.insert(0, register_import)

    circuit_def_constructor = ast_utils.compile_function_to_file(
        tree, 'make_' + cls.__name__, defn_env)
    circuit_def = circuit_def_constructor(combinational_decorator)

    if get_debug_mode() and getattr(circuit_def, "debug_info", False):
        # Rebuild the debug info so it records the module of ``cls``.
        circuit_def.debug_info = debug_info(circuit_def.debug_info.filename,
                                            circuit_def.debug_info.lineno,
                                            inspect.getmodule(cls))

    return circuit_def
예제 #4
0
파일: ssa.py 프로젝트: Kuree/magma
def ssa(defn_env: dict, fn: types.FunctionType):
    """Convert ``fn`` to SSA form and compile it.

    A ``return O`` statement is appended to the converted function body.
    The body is then compiled with ``defn_env`` as its namespace.

    Args:
        defn_env: Namespace used for SSA conversion and compilation.
        fn: The function to convert.

    Returns:
        The result of ``ast_utils.compile_function_to_file``.
    """
    converted, _renamed = convert_tree_to_ssa(ast_utils.get_func_ast(fn),
                                              defn_env)
    return_o = ast.Return(value=ast.Name(id="O", ctx=ast.Load()))
    converted.body.append(return_o)
    return ast_utils.compile_function_to_file(converted, defn_env=defn_env)