Ejemplo n.º 1
0
 def get_string( obj ):
   """Return the string that identifies `obj`.

   Non-class objects are identified by `str( obj )`. A bitstruct class is
   identified by the name of the RTLIR data type of a default-constructed
   instance; any other class is identified by its `__name__`.
   """
   # Plain values: their string form is the identifier
   if not isinstance(obj, type):
     return str( obj )

   # Bitstruct classes are named after their RTLIR struct type
   if is_bitstruct_class(obj):
     rtlir_dtype = get_rtlir_dtype( obj() )
     return rtlir_dtype.get_name()

   return obj.__name__
Ejemplo n.º 2
0
    def collect_sig_func(self, top, wavmeta):

        # TODO use actual nets to reduce the amount of saved signals

        # Give all ' and " characters a preceding backslash for .format
        wav_srcs = []

        # Now we create per-cycle signal value collect functions
        for x in top._dsl.all_signals:
            if x.is_top_level_signal() and (not repr(x).endswith('.clk')
                                            or x is top.clk):
                if is_bitstruct_class(x._dsl.Type):
                    raise ModelTypeError(
                        "designs without Bitstruct signals.\n"
                        "- Currently text waveform cannot dump design with bitstruct.\n"
                        "- :) The bitstruct is usually too wide to display anyways."
                    )
                wav_srcs.append(
                    "wavmeta.sigs['{0}'].append( {0}.bin() )".format(x))

        wavmeta.sigs = defaultdict(list)

        # TODO use integer index instead of dict, should be easy
        src = """
def dump_wav():
  {}
""".format("\n  ".join(wav_srcs))
        s, l_dict = top, {}
        exec(compile(src, filename="temp", mode="exec"),
             globals().update(locals()), l_dict)
        return l_dict['dump_wav']
Ejemplo n.º 3
0
    def visit_Call(s, node):
        """Lower a call expression to behavioral RTLIR.

        L3 adds struct instantiation inside update blocks, written as a call
        on the BitStruct class such as `struct( 1, 2, 0 )`.  A call with no
        arguments takes every field value from a default-constructed
        instance; otherwise one argument per field is required.  Calls on
        anything that is not a BitStruct class go to the parent visitor.
        """
        struct_cls = s.get_call_obj(node)

        # Not a struct instantiation: let the lower level deal with it
        if not is_bitstruct_class(struct_cls):
            return super().visit_Call(node)

        field_dict = struct_cls.__bitstruct_fields__
        num_args = len(node.args)
        num_fields = len(field_dict.keys())

        if num_args == 0:
            # Read the field values off an instance built with the default
            # constructor arguments
            default_inst = struct_cls()
            field_values = []
            for field_name in field_dict.keys():
                field_obj = getattr(default_inst, field_name)
                field_values.append(s._datatype_to_bir(field_obj))
        elif num_args == num_fields:
            # Exactly one argument per field
            field_values = [s.visit(arg_node) for arg_node in node.args]
        else:
            raise PyMTLSyntaxError(
                s.blk, node,
                f'BitStruct {struct_cls.__name__} has {num_fields} fields but {num_args} arguments are given!'
            )

        struct_inst = bir.StructInst(struct_cls, field_values)
        struct_inst.ast = node
        return struct_inst
Ejemplo n.º 4
0
  def collect_sig_func( self, top, wavmeta ):

    # TODO use actual nets to reduce the amount of saved signals

    # Give all ' and " characters a preceding backslash for .format
    wav_srcs = []

    # Now we create per-cycle signal value collect functions
    for x in top._dsl.all_signals:
      if x.is_top_level_signal() and ( not repr(x).endswith('.clk') or x is top.clk ):
        if is_bitstruct_class( x._dsl.Type ):
          wav_srcs.append( "wavmeta.text_sigs['{0}'].append( to_bits({0}).bin() )".format(x) )
        elif issubclass( x._dsl.Type, Bits ):
          wav_srcs.append( "wavmeta.text_sigs['{0}'].append( {0}.bin() )".format(x) )

    wavmeta.text_sigs = defaultdict(list)

    # TODO use integer index instead of dict, should be easy
    src =  """
def dump_wav():
  {}
""".format( "\n  ".join(wav_srcs) )
    s, l_dict = top, {}
    exec(compile( src, filename="temp", mode="exec"), globals().update(locals()), l_dict)
    return l_dict['dump_wav']
Ejemplo n.º 5
0
def get_rtlir_dtype(obj):
    """Return the RTLIR data type of obj.

    Supported objects:
      - `dsl.Signal` / `dsl.Const`: typed by their declared `_dsl.Type`
        (a Bits class, `int`, or a bitstruct class);
      - Python `int`: a 32-bit vector;
      - `Bits` instance: a vector of `obj.nbits` bits;
      - bitstruct instance: the corresponding RTLIR struct type.

    Raises:
        RTLIRConversionError: if `obj` is a list, has an unsupported type,
            or if an AssertionError escapes the struct conversion.
    """
    # Failures are raised explicitly as AssertionError (instead of using
    # `assert` statements) so the checks survive `python -O`; the handler at
    # the bottom converts them, and any assertion failing inside
    # `_get_rtlir_dtype_struct`, into RTLIRConversionError.
    try:
        if isinstance(obj, list):
            raise AssertionError(
                'array datatype object should be a field of some struct!')

        # Signals might be parameterized with different data types
        if isinstance(obj, (dsl.Signal, dsl.Const)):
            Type = obj._dsl.Type
            if not isinstance(Type, type):
                raise AssertionError(
                    f'the declared type {Type} of the given object is not a class!')

            # Vector data type
            if issubclass(Type, Bits):
                return Vector(Type.nbits)

            # python int object
            if Type is int:
                return Vector(32)

            # Struct data type
            if is_bitstruct_class(Type):
                try:
                    return _get_rtlir_dtype_struct(Type())
                except TypeError:
                    raise AssertionError(
                        f'__init__() of supposed struct {Type.__name__} '
                        'should take 0 argument ( you can achieve this by '
                        'adding default values to your arguments )!')

            raise AssertionError(
                f'cannot convert object of type {Type} into RTLIR!')

        # Python integer objects
        if isinstance(obj, int):
            # Following the Verilog bitwidth rule: number literals have 32 bit width
            # by default.
            return Vector(32)

        # PyMTL Bits objects
        if isinstance(obj, Bits):
            return Vector(obj.nbits)

        # PyMTL BitStruct objects
        if is_bitstruct_inst(obj):
            return _get_rtlir_dtype_struct(obj)

        raise AssertionError('cannot infer the data type of the given object!')
    except AssertionError as e:
        # A message-less assertion has empty `args`; indexing it blindly
        # would replace the conversion error with an IndexError.
        msg = e.args[0] if e.args and e.args[0] is not None else ''
        raise RTLIRConversionError(obj, msg) from e
Ejemplo n.º 6
0
 def _get_arg_str(name, obj):
     # Support common python types and Bits/BitStruct
     if isinstance(obj, (int, bool, str)):
         return str(obj)
     elif obj == None:
         return 'None'
     elif isinstance(obj, Bits):
         nbits = obj.nbits
         value = int(obj)
         Bits_name = f"Bits{nbits}"
         Bits_arg_str = f"{Bits_name}( {value} )"
         if Bits_name not in symbols and nbits >= 256:
             Bits_class = mk_bits(nbits)
             symbols.update({Bits_name: Bits_class})
         return Bits_arg_str
     elif is_bitstruct_inst(obj):
         raise TypeError(
             "Do you really want to pass in an instance of "
             "a BitStruct? Contact PyMTL developers!")
         # This is hacky: we don't know how to construct an object that
         # is the same as `obj`, but we do have the object itself. If we
         # add `obj` to the namespace of `construct` everything works fine
         # but the user cannot tell what object is passed to the constructor
         # just from the code.
         # Do not use a double underscore prefix because that will be
         # interpreted differently by the Python interpreter
         # bs_name = ("_" if name[0] != "_" else "") + name + "_obj"
         # if bs_name not in symbols:
         # symbols.update( { bs_name : obj } )
         # return bs_name
     elif isinstance(obj, type) and issubclass(obj, Bits):
         nbits = obj.nbits
         Bits_name = f"Bits{nbits}"
         if Bits_name not in symbols and nbits >= 256:
             Bits_class = mk_bits(nbits)
             symbols.update({Bits_name: Bits_class})
         return Bits_name
     elif is_bitstruct_class(obj):
         BitStruct_name = obj.__name__
         if BitStruct_name not in symbols:
             symbols.update({BitStruct_name: obj})
         return BitStruct_name
     # FIXME formalize this
     elif isinstance(obj, type) and hasattr(
             obj, 'req') and hasattr(obj, 'resp'):
         symbols.update({obj.__name__: obj})
         return obj.__name__
     raise TypeError(
         f"Interface constructor argument {obj} is not an int/Bits/BitStruct/TypeTuple!"
     )
Ejemplo n.º 7
0
  def __init__( s, Type=Bits1 ):
    """Initialize the DSL metadata of a signal of type `Type`.

    `Type` is a Bits class, a bitstruct class, or an int N, which is
    converted to a Bits class with `mk_bits( N )`.
    """
    if not isinstance( Type, int ):
      is_class = isinstance( Type, type )
      assert is_class and ( issubclass( Type, Bits ) or is_bitstruct_class(Type) ), \
              f"RTL signal can only be of Bits type or bitstruct type, not {Type}.\n" \
              f"Note: an integer is also accepted: Wire(32) is equivalent to Wire(Bits32))"
    else:
      # Integer shorthand for the bitwidth
      Type = mk_bits(Type)

    meta = s._dsl

    meta.Type = Type
    meta.type_instance = None

    # A fresh signal is not a slice of another wire and is its own
    # top-level signal
    meta.slice = None
    meta.slices = {}
    meta.top_level_signal = s

    meta.needs_double_buffer = False
Ejemplo n.º 8
0
  def enter( s, node ):
    """Visit `node` and return the result only if it is a recognized constant.

    Recognized constant objects:
      1. int, BitsN( X )
      2. BitsN
      3. BitStruct, BitStruct()
      4. Functions listed in `s.pymtl_functions`

    Returns:
      The visited object if it falls in one of the categories above,
      otherwise None.
    """
    ret = s.visit( node )

    is_value = isinstance(ret, (int, Bits)) or is_bitstruct_inst(ret)
    is_type = isinstance(ret, type) and (issubclass(ret, Bits) or is_bitstruct_class(ret))

    # The membership test may raise for objects that cannot be hashed or
    # compared; such objects are simply not functions. Only Exception is
    # caught so that KeyboardInterrupt/SystemExit still propagate.
    try:
      is_function = ret in s.pymtl_functions
    except Exception:
      is_function = False

    if is_value or is_type or is_function:
      return ret
    return None
Ejemplo n.º 9
0
    def visit_Call(s, node):
        """Lower a call expression to behavioral RTLIR.

        L3 adds struct instantiation inside update blocks, written as a call
        on the BitStruct class such as `struct( 1, 2, 0 )`.  At least one
        argument is required.  Calls on anything that is not a BitStruct
        class go to the parent visitor.
        """
        struct_cls = s.get_call_obj(node)

        # Not a struct instantiation: let the lower level deal with it
        if not is_bitstruct_class(struct_cls):
            return super().visit_Call(node)

        if len(node.args) == 0:
            raise PyMTLSyntaxError(
                s.blk, node,
                'at least one value should be provided to struct instantiation!'
            )

        field_values = []
        for arg_node in node.args:
            field_values.append(s.visit(arg_node))

        struct_inst = bir.StructInst(struct_cls, field_values)
        struct_inst.ast = node
        return struct_inst
Ejemplo n.º 10
0
        def compile_scc(i):
            """Compile the i-th SCC into a single schedulable callable.

            An SCC with a single block is returned as that block, unchanged.
            For a larger SCC this orders its blocks with a BFS, generates
            the source of a `wrapped_SCC_<id>` function that re-runs the
            blocks until none of the tracked variables changed during an
            iteration (raising UpblkCyclicError once more than 100
            iterations were needed), executes that source and returns the
            generated function.

            Args:
                i: index into `SCCs` / `scc_pred` of the SCC to compile.

            Raises:
                UpblkCyclicError: if an update_once block is part of the
                    SCC, or if no constraint variable connects the blocks
                    of the SCC.
            """
            nonlocal scc_id

            scc = SCCs[i]

            # Trivial SCC: the block itself is the schedule entry
            if len(scc) == 1:
                return list(scc)[0]

            # update_once blocks must not take part in a cycle; the error
            # lists every block of the SCC with its decorator and host.
            for x in scc:
                if x in onces:
                    raise UpblkCyclicError("update_once blocks are not allowed to appear in a cycle. \n - " + \
                                    "\n - ".join( [
                                      f"{y.__name__} ({'@update_once' if y in onces else '@update'} " \
                                      f"in 'top.{repr(top.get_update_block_host_component(y))[2:]}')"
                                      for y in scc] ))

            # Only non-trivial SCCs consume an id
            scc_id += 1
            if _DEBUG: print(f"{'='*100}\n SCC{scc_id}\n{'='*100}")

            # For each non-trivial SCC, we need to figure out a intra-SCC
            # linear schedule that minimizes the time to re-execute this SCC
            # due to value changes. A bad schedule may inefficiently execute
            # the SCC for many times, each of which changes a few signals.
            # The current algorithm iteratively finds the "entry block" of
            # the SCC and expand its adjacent blocks. The implementation is
            # to first find the actual entry point, and then BFS to expand the
            # footprint until all nodes are visited.

            tmp_schedule = []
            Q = deque()

            if scc_pred[i] is None:
                # We start bfs from the block that has the least number of input
                # edges in the SCC
                # NOTE(review): the code below uses `max`, i.e. it picks the
                # block with the MOST in-SCC incoming edges, which
                # contradicts the comment above -- confirm which one is
                # intended.
                InD = {v: 0 for v in scc}
                for (u, v) in E:  # u -> v
                    if u in scc and v in scc:
                        InD[v] += 1
                Q.append(max(InD, key=InD.get))

            else:
                # We start bfs with the blocks that are successors of the
                # predecessor scc in the previous SCC-level topological sort.
                pred = set(SCCs[scc_pred[i]])
                # Sort by names for a fixed outcome
                # NOTE(review): x is appended once per matching edge, so a
                # block with several edges from `pred` is queued (and later
                # scheduled) more than once -- confirm this is acceptable.
                for x in sorted(scc, key=lambda x: x.__name__):
                    for v in G_T[
                            x]:  # find reversed edges point back to pred SCC
                        if v in pred:
                            Q.append(x)

            # Perform bfs to find a heuristic schedule
            visited = set(Q)
            while Q:
                u = Q.popleft()
                tmp_schedule.append(u)
                for v in G[u]:
                    if v in scc and v not in visited:
                        Q.append(v)
                        visited.add(v)

            variables = set()
            for (u, v) in E:
                # Collect all variables that triggers other blocks in the SCC
                if u in scc and v in scc:
                    variables.update(constraint_objs[(u, v)])

            if len(variables) == 0:
                raise UpblkCyclicError("There is a cyclic dependency without involving variables."
                                "Probably a loop that involves blocks that should be update_once:\n{}"\
                                .format(", ".join( [ x.__name__ for x in scc] )))

            # generate a loop for scc
            # Shunning: we just simply loop over the whole SCC block
            # TODO performance optimizations using Mamba techniques within a SCC block

            # Template slots, as filled by the `template.format` call below:
            #   {0} scc id, {1} statements saving copies of the variables,
            #   {3} calls of the blocks, {2} statements that `continue` the
            #   loop if a variable differs from its saved copy,
            #   {4} block names for the error message.
            template = """
from copy import deepcopy
def wrapped_SCC_{0}():
  N = 0
  while True:
    N += 1
    if N > 100:
      raise UpblkCyclicError("Combinational loop detected at runtime in {{{4}}} after 100 iters!")
    {1}
    {3}
    {2}
    # print( "SCC block{0} is executed", N, "times" )
    break
generated_block = wrapped_SCC_{0}
"""

            # clean up non-top variables if top is there. For slices of Bits
            # we directly use the top level wide Bits since Bits clone is
            # rpython code

            final_variables = set()

            for x in sorted(variables, key=repr):
                w = x.get_top_level_signal()
                if w is x:
                    final_variables.add(x)
                    continue

                # w is not x
                if issubclass(w._dsl.Type, Bits):
                    # Track the whole top-level Bits signal instead of x
                    if w not in final_variables:
                        final_variables.add(w)
                elif is_bitstruct_class(w._dsl.Type):
                    # Track x itself unless its top-level signal is already
                    # tracked
                    if w not in final_variables:
                        final_variables.add(x)
                else:
                    final_variables.add(x)

            # also group them by common ancestor to reduce byte code
            # TODO use longest-common-prefix (LCP) algorithms ...

            final_var_host = defaultdict(list)
            for x in final_variables:
                final_var_host[x.get_host_component()].append(x)

            # Then, we generate the Python code that saves variables at the
            # beginning of each SCC iteration and the code that checks if the
            # values of those variables have changed
            copy_srcs = []
            check_srcs = []

            # var_id is shared across hosts so every temporary t<N> is unique
            var_id = 0
            for host, var_list in final_var_host.items():
                hostlen = len(repr(host))

                copy_srcs.append(f"host = {host!r}")
                check_srcs.append(f"host = {host!r}")

                sub_check_srcs = []

                for var in var_list:
                    var_id += 1
                    # Name of the variable relative to its host: the repr of
                    # the variable with the host repr and one more character
                    # stripped from the front
                    subname = repr(var)[hostlen + 1:]
                    if issubclass(var._dsl.Type, Bits):
                        copy_srcs.append(f"t{var_id}=host.{subname}.clone()")
                    elif is_bitstruct_class(var._dsl.Type):
                        copy_srcs.append(f"t{var_id}=host.{subname}.clone()")
                    else:
                        copy_srcs.append(f"t{var_id}=deepcopy(host.{subname})")

                    sub_check_srcs.append(f"host.{subname} != t{var_id}")

                # One check per host: re-run the SCC if any variable changed
                check_srcs.append(
                    f"if { ' or '.join(sub_check_srcs)}: continue")

            # Divide all blks into meta blocks
            # Branchiness factor is the bound of branchiness in a meta block.
            branchiness_factor = 20
            branchy_block_factor = 6

            num_blks = 0  # sanity check
            cur_meta, cur_br, cur_count = [], 0, 0
            scc_schedule = []

            # Namespace the generated source is executed in
            _globals = {'s': top, 'UpblkCyclicError': UpblkCyclicError}
            blk_srcs = []

            # If there is only 10 blocks, we directly unroll it
            # NOTE: from here on the loop variable `i` shadows the parameter
            # `i`; the parameter is not read again below.
            if len(tmp_schedule) < 10:
                blk_srcs = []
                for i, b in enumerate(tmp_schedule):
                    blk_srcs.append(
                        f"blk{i}() # [br {self.branchiness[b]}, loop {int(self.only_loop_at_top[b])}] {b.__name__}"
                    )
                    _globals[f"blk{i}"] = b  # put it into the block's closure

            else:
                # Greedily pack the blocks, in schedule order, into meta
                # blocks. cur_br accumulates the branchiness of the current
                # meta block and cur_count the number of its branchy blocks.
                for i, blk in enumerate(tmp_schedule):
                    # Same here. If an update block only has top-level loop, br = 0
                    br = 0 if self.only_loop_at_top[blk] else self.branchiness[
                        blk]
                    if cur_br == 0:
                        # No branchiness accumulated yet: always take the
                        # block, then close the meta block if a bound is hit
                        cur_meta.append(blk)
                        cur_br += br
                        cur_count += (br > 0)
                        if cur_br >= branchiness_factor or cur_count >= branchy_block_factor:
                            num_blks += len(cur_meta)
                            scc_schedule.append(cur_meta)
                            cur_meta, cur_br, cur_count = [], 0, 0  # clear
                    else:
                        if br == 0:
                            # If no branchy block available, directly start a new metablock
                            num_blks += len(cur_meta)
                            scc_schedule.append(cur_meta)
                            cur_meta, cur_br, cur_count = [blk], br, (br > 0)
                        else:
                            cur_meta.append(blk)
                            cur_br += br
                            cur_count += (br > 0)

                            # NOTE(review): br was already added to cur_br
                            # (and the block to cur_count) just above, so
                            # this check counts the current block twice --
                            # confirm whether that look-ahead is intended.
                            if cur_br + br >= branchiness_factor or cur_count + 1 >= branchy_block_factor:
                                num_blks += len(cur_meta)
                                scc_schedule.append(cur_meta)
                                cur_meta, cur_br, cur_count = [], 0, 0  # clear

                # Flush the last, still open meta block
                if cur_meta:
                    num_blks += len(cur_meta)
                    scc_schedule.append(cur_meta)

                assert num_blks == len(tmp_schedule), f"Some blocks are missing during trace breaking of SCC "\
                                                      f"({num_blks} compiled, {len(tmp_schedule)} total)"

                blk_srcs = []

                # A single meta block is unrolled directly, like the small
                # SCC case above
                if len(scc_schedule) == 1:
                    for i, b in enumerate(scc_schedule[-1]):
                        blk_srcs.append(
                            f"blk{i}() # [br {self.branchiness[b]}, loop {int(self.only_loop_at_top[b])}] {b.__name__}"
                        )
                        _globals[f"blk{i}"] = b

                else:
                    # TODO we might turn all meta blocks before the last one into meta
                    # blocks, and directly fold the last block into the main loop
                    # for i, meta in enumerate( scc_schedule[:-1] ):
                    # b = self.compile_meta_block( meta )
                    # blk_srcs.append( f"{b.__name__}()" )
                    # _globals[ b.__name__ ] = b

                    # for i, b in enumerate( scc_schedule[-1] ):
                    # blk_srcs.append( f"blk_of_last_meta{i}() # [br {self.branchiness[b]}, loop {int(self.only_loop_at_top[b])}] {b.__name__}" )
                    # _globals[ f"blk_of_last_meta{i}" ] = b

                    # Each meta block is compiled by `self.compile_meta_block`
                    # and called by the name of the returned function
                    for i, meta in enumerate(scc_schedule):
                        b = self.compile_meta_block(meta)
                        blk_srcs.append(f"{b.__name__}()")
                        _globals[b.__name__] = b

            scc_block_src = template.format(
                scc_id, "; ".join(copy_srcs), "\n    ".join(check_srcs),
                '\n    '.join(blk_srcs), ", ".join([x.__name__ for x in scc]))

            if _DEBUG: print(scc_block_src, "\n", "=" * 100)

            # Execute the generated source; it binds `generated_block` to the
            # wrapper function, which is what gets returned
            _locals = {}
            custom_exec(
                py.code.Source(scc_block_src).compile(), _globals, _locals)
            return _locals['generated_block']
Ejemplo n.º 11
0
def run_test_vector_sim(model,
                        test_vectors,
                        cmdline_opts=None,
                        line_trace=True):
    """Simulate `model` against a table of test vectors.

    The first row of `test_vectors` names the ports, either as a list or as
    one whitespace-separated string. A name ending in `*` is an output port
    whose value is checked; any other name is an input port that is driven.
    `name[idx]` addresses one element of a 1D port array (interfaces and
    higher-dimensional arrays are not supported). Each remaining row holds
    one value per port for one cycle; an expected value of `'?'` skips the
    check for that output.

    Args:
        model: the component under test.
        test_vectors: header row followed by one row per cycle.
        cmdline_opts: options passed to `config_model_with_cmdline_opts`;
            a falsy value selects the defaults below.
        line_trace: passed to `DefaultPassGroup` as `print_line_trace`; also
            prints a line trace right before a mismatch is reported.

    Raises:
        RunTestVectorSimError: if an output differs from its expected value.
        Exception: if an indexed port name cannot be parsed.
    """
    cmdline_opts = cmdline_opts or {
        'dump_vcd': False,
        'test_verilog': False,
        'dump_vtb': ''
    }

    # First row in test vectors contains port names

    if isinstance(test_vectors[0], str):
        port_names = test_vectors[0].split()
    else:
        port_names = test_vectors[0]

    # Remaining rows contain the actual test vectors

    test_vectors = test_vectors[1:]

    # Setup the model

    model = config_model_with_cmdline_opts(model, cmdline_opts, [])

    try:
        # Create a simulator
        model.apply(DefaultPassGroup(print_line_trace=line_trace))
        # Reset model
        model.sim_reset()

        # Run the simulation

        row_num = 0
        in_ids = []
        out_ids = []
        groups = [None] * len(port_names)
        types = [None] * len(port_names)
        # Port name (without the `*` marker) per column, so that a mismatch
        # is reported against the port that actually failed
        clean_names = [None] * len(port_names)

        # Preprocess default type
        # Special case for lists of ports
        # NOTE THAT WE ONLY SUPPORT 1D ARRAY and no interface
        for i, port_full_name in enumerate(port_names):
            if port_full_name.endswith("*"):
                out_ids.append(i)
                port_name = port_full_name[:-1]
            else:
                in_ids.append(i)
                port_name = port_full_name
            clean_names[i] = port_name

            if '[' in port_name:
                # Get tokens of the full name
                m = re.match(r'(\w+)\[(\d+)\]', port_name)
                if not m:
                    raise Exception(
                        f"Could not parse port name: {port_name}. "
                        f"Currently we don't support interface or high-D array."
                    )

                groups[i] = g = (True, m.group(1), int(m.group(2)))

                # Get type of all the ports
                t = type(getattr(model, g[1])[g[2]])
            else:
                groups[i] = (False, port_name)
                t = type(getattr(model, port_name))

            # Bitstruct values are applied as given, without conversion
            types[i] = None if is_bitstruct_class(t) else t

        # Run simulation

        for row in test_vectors:
            row_num += 1

            # Apply test inputs
            for i in in_ids:
                in_value = row[i]
                t = types[i]
                if t:
                    in_value = t(in_value)
                g = groups[i]
                x = getattr(model, g[1])
                if g[0]:
                    x[g[2]] @= in_value
                else:
                    x @= in_value

            # Evaluate combinational concurrent blocks
            model.sim_eval_combinational()

            # Check test outputs
            for i in out_ids:
                ref_value = row[i]
                if ref_value == '?':
                    continue

                g = groups[i]
                if g[0]:
                    out_value = getattr(model, g[1])[g[2]]
                else:
                    out_value = getattr(model, g[1])

                if out_value != ref_value:
                    if line_trace:
                        model.print_line_trace()
                    error_msg = """
run_test_vector_sim received an incorrect value!
- row number     : {row_number}
- port name      : {port_name}
- expected value : {expected_msg}
- actual value   : {actual_msg}
"""
                    raise RunTestVectorSimError(
                        error_msg.format(row_number=row_num,
                                         port_name=clean_names[i],
                                         expected_msg=ref_value,
                                         actual_msg=out_value))

            # Tick the simulation
            model.sim_tick()

        # Extra ticks to make VCD easier to read
        model.sim_tick()
        model.sim_tick()
        model.sim_tick()

    finally:
        finalize_verilator(model)
Ejemplo n.º 12
0
def run_test_vector_sim(model,
                        test_vectors,
                        dump_vcd=None,
                        test_verilog=False,
                        line_trace=True):
    """Simulate `model` against a table of test vectors.

    The first row of `test_vectors` names the ports, either as a list or as
    one whitespace-separated string. A name ending in `*` is an output port
    whose value is checked; any other name is an input port that is driven.
    `name[idx]` addresses one element of a 1D port array (interfaces and
    higher-dimensional arrays are not supported). Each remaining row holds
    one value per port for one cycle; an expected value of `'?'` skips the
    check for that output.

    Args:
        model: the component under test; it is elaborated here.
        test_vectors: header row followed by one row per cycle.
        dump_vcd: if truthy, enables VCD tracing with this value as the
            VCD file name.
        test_verilog: if truthy, the model is translated and imported, with
            this value used as the `vl_xinit` import setting.
        line_trace: print a line trace during reset and after every cycle.

    Raises:
        RunTestVectorSimError: if an output differs from its expected value.
        Exception: if an indexed port name cannot be parsed.
    """

    # First row in test vectors contains port names

    if isinstance(test_vectors[0], str):
        port_names = test_vectors[0].split()
    else:
        port_names = test_vectors[0]

    # Remaining rows contain the actual test vectors

    test_vectors = test_vectors[1:]

    # Setup the model

    model.elaborate()

    if dump_vcd:
        model.config_tracing = TracingConfigs(tracing='vcd',
                                              vcd_file_name=dump_vcd)

    if test_verilog:
        # Keep an import configuration the caller already attached
        if not hasattr(model, 'config_verilog_import'):
            model.config_verilog_import = VerilatorImportConfigs(
                vl_xinit=test_verilog, )
        else:
            model.config_verilog_import.vl_xinit = test_verilog
        model.verilog_translate_import = True

    model.apply(VerilogPlaceholderPass())
    model = TranslationImportPass()(model)

    # Create a simulator

    model.apply(SimulationPass())

    # Reset model

    model.sim_reset(print_line_trace=line_trace)

    # Run the simulation

    row_num = 0

    in_ids = []
    out_ids = []

    groups = [None] * len(port_names)
    types = [None] * len(port_names)
    # Port name (without the `*` marker) per column, so that a mismatch is
    # reported against the port that actually failed
    clean_names = [None] * len(port_names)

    # Preprocess default type
    # Special case for lists of ports
    # NOTE THAT WE ONLY SUPPORT 1D ARRAY and no interface

    for i, port_full_name in enumerate(port_names):
        if port_full_name.endswith("*"):
            out_ids.append(i)
            port_name = port_full_name[:-1]
        else:
            in_ids.append(i)
            port_name = port_full_name
        clean_names[i] = port_name

        if '[' in port_name:

            # Get tokens of the full name

            m = re.match(r'(\w+)\[(\d+)\]', port_name)

            if not m:
                raise Exception(
                    f"Could not parse port name: {port_name}. "
                    f"Currently we don't support interface or high-D array.")

            groups[i] = g = (True, m.group(1), int(m.group(2)))

            # Get type of all the ports
            t = type(getattr(model, g[1])[g[2]])

        else:
            groups[i] = (False, port_name)
            t = type(getattr(model, port_name))

        # Bitstruct values are applied as given, without conversion
        types[i] = None if is_bitstruct_class(t) else t

    for row in test_vectors:
        row_num += 1

        # Apply test inputs

        for i in in_ids:

            in_value = row[i]
            t = types[i]
            if t:
                in_value = t(in_value)

            g = groups[i]
            if g[0]:
                getattr(model, g[1])[g[2]] = in_value
            else:
                setattr(model, g[1], in_value)

        # Evaluate combinational concurrent blocks

        model.eval_combinational()

        # Display line trace output

        if line_trace:
            model.print_line_trace()

        # Check test outputs

        for i in out_ids:
            ref_value = row[i]
            if ref_value == '?':
                continue

            g = groups[i]
            if g[0]:
                out_value = getattr(model, g[1])[g[2]]
            else:
                out_value = getattr(model, g[1])

            if out_value != ref_value:

                error_msg = """
run_test_vector_sim received an incorrect value!
- row number     : {row_number}
- port name      : {port_name}
- expected value : {expected_msg}
- actual value   : {actual_msg}
"""
                raise RunTestVectorSimError(
                    error_msg.format(row_number=row_num,
                                     port_name=clean_names[i],
                                     expected_msg=ref_value,
                                     actual_msg=out_value))

        # Tick the simulation

        model.tick()

    # Extra ticks to make VCD easier to read

    model.tick()
    model.tick()
    model.tick()
Ejemplo n.º 13
0
  def schedule_intra_cycle( self, top ):

    # Construct the intra-cycle graph based on normal update blocks

    V   = top._dag.final_upblks - top.get_all_update_ff()

    G   = { v: [] for v in V }
    G_T = { v: [] for v in V } # transpose graph

    E = set()
    for (u, v) in top._dag.all_constraints: # u -> v
      if u in V and v in V:
        G  [u].append( v )
        G_T[v].append( u )
        E.add( (u, v) )

    if 'MAMBA_DAG' in os.environ:
      dump_dag( top, V, E )

    # Compute SCC using Kosaraju's algorithm

    SCCs, G_new = kosaraju_scc( G, G_T )

    # Perform topological sort on SCCs

    InD = { i: 0 for i in range(len(SCCs)) }
    for u, vs in G_new.items():
      for v in vs:
        InD[ v ] += 1

    scc_pred = {}
    scc_schedule = []

    Q = deque( [ i for i in range(len(SCCs)) if not InD[i] ] )
    for x in Q:
      scc_pred[ x ] = None

    while Q:
      u = Q.pop()
      scc_schedule.append( u )
      for v in G_new[u]:
        InD[v] -= 1
        if not InD[v]:
          Q.append( v )
          scc_pred[ v ] = u

    assert len(scc_schedule) == len(SCCs)

    #---------------------------------------------------------------------
    # Now we generate super blocks for each SCC and produce final schedule
    #---------------------------------------------------------------------

    constraint_objs = top._dag.constraint_objs
    onces = top.get_all_update_once()

    # Put the graph schedule to _sched
    top._sched.update_schedule = schedule = []

    scc_id = 0
    for i in scc_schedule:
      scc = SCCs[i]
      if len(scc) == 1:
        schedule.append( list(scc)[0] )
      else:

        # For each non-trivial SCC, we need to figure out a intra-SCC
        # linear schedule that minimizes the time to re-execute this SCC
        # due to value changes. A bad schedule may inefficiently execute
        # the SCC for many times, each of which changes a few signals.
        # The current algorithm iteratively finds the "entry block" of
        # the SCC and expand its adjancent blocks. The implementation is
        # to first find the actual entry point, and then BFS to expand the
        # footprint until all nodes are visited.

        # check update_once first
        for x in scc:
          if x in onces:
            raise UpblkCyclicError("update_once blocks are not allowed to appear in a cycle. \n - " + \
                            "\n - ".join( [
                              f"{y.__name__} ({'@update_once' if y in onces else '@update'} " \
                              f"in 'top.{repr(top.get_update_block_host_component(y))[2:]}')"
                              for y in scc] ))

        tmp_schedule = []
        Q = deque()

        if scc_pred[i] is None:
          # We start bfs from the block that has the least number of input
          # edges in the SCC
          InD = { v: 0 for v in scc }
          for (u, v) in E: # u -> v
            if u in scc and v in scc:
              InD[ v ] += 1
          Q.append( max(InD, key=InD.get) )

        else:
          # We start bfs with the blocks that are successors of the
          # predecessor scc in the previous SCC-level topological sort.
          pred = set( SCCs[ scc_pred[i] ] )
          # Sort by names for a fixed outcome
          for x in sorted( scc, key = lambda x: x.__name__ ):
            for v in G_T[x]: # find reversed edges point back to pred SCC
              if v in pred:
                Q.append( x )

        # Perform bfs to find a heuristic schedule
        visited = set(Q)
        while Q:
          u = Q.popleft()
          tmp_schedule.append( u )
          for v in G[u]:
            if v in scc and v not in visited:
              Q.append( v )
              visited.add( v )

        scc_id += 1
        variables = set()
        for (u, v) in E:
          # Collect all variables that triggers other blocks in the SCC
          if u in scc and v in scc:
            variables.update( constraint_objs[ (u, v) ] )

        if len(variables) == 0:
          raise UpblkCyclicError("There is a cyclic dependency without involving variables."
                          "Probably a loop that involves blocks that should be update_once:\n{}"\
                          .format(", ".join( [ x.__name__ for x in scc] )))

        # generate a loop for scc
        # Shunning: we just simply loop over the whole SCC block
        # TODO performance optimizations using Mamba techniques within a SCC block

        def gen_wrapped_SCCblk( s, scc, src ):

          # TODO mamba?
          scc_tick_func = SimpleTickPass.gen_tick_function( scc )
          _globals = { 's': s, 'scc_tick_func': scc_tick_func, 'deepcopy': deepcopy,
                       'UpblkCyclicError': UpblkCyclicError }
          _locals  = {}

          custom_exec(py.code.Source( src ).compile(), _globals, _locals)
          return _locals[ 'generated_block' ]

        template = """
def wrapped_SCC_{0}():
  N = 0
  while True:
    N += 1
    if N > 100:
      raise UpblkCyclicError("Combinational loop detected at runtime in {{{3}}} after 100 iters!")
    {1}
    scc_tick_func()
    {2}
    # print( "SCC block{0} is executed", num_iters, "times" )
    break
generated_block = wrapped_SCC_{0}
          """

        copy_srcs  = []
        check_srcs = []
        # print_srcs = []

        # clean up non-top variables if top is there. remove slices

        final_variables = set()

        for x in sorted( variables, key=repr ):
          w = x.get_top_level_signal()
          if w is x:
            final_variables.add( x )
            continue

          # w is not x
          if issubclass( w._dsl.Type, Bits ):
            if w not in final_variables:
              final_variables.add( w )
          elif is_bitstruct_class( w._dsl.Type ):
            if w not in final_variables:
              final_variables.add( x )
          else:
            final_variables.add( x )

        # group them by host component so that we create less bytecode

        final_var_host = defaultdict(list)
        for x in final_variables:
          final_var_host[ x.get_host_component() ].append( x )

        # create a block of copy/check code for each host component. Need
        # to allocate global var_id across different host components.

        var_id = 0
        for host, var_list in final_var_host.items():

          copy_srcs .append( f"host={host!r}" )
          check_srcs.append( f"host={host!r}" )

          sub_check_srcs = []

          hostlen = len(repr(host))
          for var in var_list:
            var_id += 1
            subname = repr(var)[hostlen+1:]
            if issubclass( var._dsl.Type, Bits ):     copy_srcs.append( f"t{var_id}=host.{subname}.clone()" )
            elif is_bitstruct_class( var._dsl.Type ): copy_srcs.append( f"t{var_id}=host.{subname}.clone()" )
            else:                                     copy_srcs.append( f"t{var_id}=deepcopy(host.{subname})" )

            sub_check_srcs.append( f"host.{subname} != t{var_id}" )

          check_srcs.append( f"if { ' or '.join(sub_check_srcs)}: continue" )

        scc_block_src = template.format( scc_id, "; ".join( copy_srcs ), "\n    ".join( check_srcs ),
                                         ", ".join( [ x.__name__ for x in scc] ) )

        # print(scc_block_src)
        schedule.append( gen_wrapped_SCCblk( top, tmp_schedule, scc_block_src ) )
Ejemplo n.º 14
0
        def extract_obj_from_names(func,
                                   names,
                                   update_ff=False,
                                   is_write=False):
            """Turn the names referenced by `func` into the objects they denote.

            Here we enumerate names and use the nested helpers below to turn
            names into objects. Lookups start from `s`, which is taken from
            the enclosing scope.

            Args:
                func: The function whose names are being resolved. Its
                    ``__globals__`` and closure cells supply the values of
                    symbolic indices.
                names: Iterable of ``(obj_name, nodelist)`` pairs. ``obj_name``
                    is a sequence of ``(field, idx)`` pairs whose first field
                    is the root name (e.g. ``"s"``) and ``idx`` lists the
                    subscripts applied after that field. Each subscript is
                    either ``"*"``, an ``(is_closure, name)`` tuple, a slice
                    whose bounds may be such tuples, or a plain index.
                    ``nodelist`` holds the nodes whose ``lineno`` is reported
                    in error messages.
                update_ff: When true, every resolved object is flagged with
                    ``_dsl.needs_double_buffer`` and checked to be a top level
                    signal of Bits/BitStruct type.
                is_write: When true, every resolved object must be a
                    ``Signal``.

            Returns:
                The set of all resolved objects, plus the entries of
                ``s._dsl.name_func`` for names without the ``s`` prefix.

            Raises:
                VarNotDeclaredError: An attribute does not exist or a
                    subscript cannot be used as an index.
                InvalidIndexError: Indexing raised an ``AssertionError``.
                InvalidUpblkWriteError: `is_write` is set and a resolved
                    object is not a ``Signal``.
                InvalidFFAssignError: `update_ff` is set and a resolved object
                    is not a top level Bits/BitStruct signal.
            """

            def resolve_free_variable(ref):
                """ Map an (is_closure, name) reference to the variable's value """
                if isinstance(ref, tuple):
                    is_closure, name = ref
                    return _closure[name] if is_closure else _globals[name]
                return ref

            def expand_array_index(obj, name_depth, node_depth, idx_depth,
                                   idx):
                """ Find s.x[0][*][2], if index is exhausted, jump back to lookup_variable """

                if idx_depth >= len(idx):
                    # exhausted, go to next level of name
                    lookup_variable(obj, name_depth + 1, node_depth + 1)
                    return

                current_idx = idx[idx_depth]

                if current_idx == "*":  # special case, materialize all objects
                    if isinstance(obj, NamedObject):
                        # Signal[*] is the signal itself
                        objs.add(obj)
                    else:
                        for child in obj:
                            expand_array_index(child, name_depth, node_depth,
                                               idx_depth + 1, idx)
                    return

                # Here we try to find the value of free variables in the current
                # component scope.
                # tuple: [x] where x is a closure/global variable
                # slice: [x:y] where x and y are either normal integers or
                #        closure/global variable.
                if isinstance(current_idx, slice):
                    current_idx = slice(
                        resolve_free_variable(current_idx.start),
                        resolve_free_variable(current_idx.stop))
                else:
                    current_idx = resolve_free_variable(current_idx)

                try:
                    child = obj[current_idx]
                except TypeError as err:  # cannot convert to integer
                    raise VarNotDeclaredError(
                        obj, current_idx, func, s,
                        nodelist[node_depth].lineno) from err
                except IndexError:
                    # Out-of-range subscripts are skipped silently
                    return
                except AssertionError as err:
                    raise InvalidIndexError(
                        obj, current_idx, func, s,
                        nodelist[node_depth].lineno) from err

                expand_array_index(child, name_depth, node_depth,
                                   idx_depth + 1, idx)

            def lookup_variable(obj, name_depth, node_depth):
                """ Look up the object s.a.b.c in s. Jump to expand_array_index if c[] """
                if obj is None:
                    return

                if name_depth >= len(obj_name):  # exhausted
                    if isinstance(obj, NamedObject):
                        objs.add(obj)
                    elif isinstance(obj, list) and obj:
                        # Exhaust all the elements in the high-d array
                        pending = list(obj)
                        while pending:
                            m = pending.pop()
                            if isinstance(m, NamedObject):
                                objs.add(m)
                            elif isinstance(m, list):
                                pending.extend(m)
                    return

                # still have names
                field, idx = obj_name[name_depth]
                try:
                    child = getattr(obj, field)
                except AttributeError as err:
                    # Chain the original error instead of printing it
                    raise VarNotDeclaredError(
                        obj, field, func, s,
                        nodelist[node_depth].lineno) from err

                if not idx:
                    lookup_variable(child, name_depth + 1, node_depth + 1)
                else:
                    expand_array_index(child, name_depth, node_depth + 1, 0,
                                       idx)

            _globals = func.__globals__

            # Snapshot the closure variables; cells that are still empty
            # raise ValueError on access and are left out.
            _closure = {}
            for i, var in enumerate(func.__code__.co_freevars):
                try:
                    _closure[var] = func.__closure__[i].cell_contents
                except ValueError:
                    pass

            all_objs = set()

            for obj_name, nodelist in names:
                root_name = obj_name[0][0]

                if root_name == "s":

                    objs = set()
                    lookup_variable(s, 1, 1)
                    for obj in objs:
                        if not isinstance(obj, Signal) and is_write:
                            raise InvalidUpblkWriteError(
                                s, func, nodelist[0].lineno, obj)
                        all_objs.add(obj)

                    # Check <<= in update_ff
                    if update_ff:
                        for x in objs:
                            x._dsl.needs_double_buffer = True
                            if not x.is_top_level_signal():
                                raise InvalidFFAssignError(
                                    s, func, nodelist[0].lineno,
                                    "has an invalid left value. It needs to be a top level signal, not a slice or a subfield."
                                )
                            if not issubclass(x._dsl.Type,
                                              Bits) and not is_bitstruct_class(
                                                  x._dsl.Type):
                                raise InvalidFFAssignError(
                                    s, func, nodelist[0].lineno,
                                    "has a wrong type on the left value. "
                                    f"We only allow <<= on Bits/BitStruct type signals, not {x._dsl.Type}"
                                )

                # This is a function call without "s." prefix, check func list
                elif root_name in s._dsl.name_func:
                    all_objs.add(s._dsl.name_func[root_name])

            return all_objs
Ejemplo n.º 15
0
    def extract_obj_from_names( func, names, update_ff=False ):
      """ Turn the names referenced by `func` into the objects they denote.

      Here we enumerate names and use the nested helpers below to turn names
      into objects. Lookups start from `s`, which is taken from the enclosing
      scope.

      Args:
        func: The function whose names are being resolved; it is only passed
          on to the errors raised here.
        names: Iterable of (obj_name, nodelist) pairs. obj_name is a sequence
          of (field, idx) pairs whose first field is the root name (e.g. "s")
          and idx lists the subscripts applied after that field, where "*"
          stands for every element. nodelist holds the nodes whose lineno is
          reported in error messages.
        update_ff: When true, every resolved object is flagged with
          _dsl.needs_double_buffer and checked to be a top level signal of
          Bits/BitStruct type.

      Returns:
        The set of all resolved objects, plus the entries of s._dsl.name_func
        for names without the "s" prefix.

      Raises:
        VarNotDeclaredError: An attribute does not exist or a subscript cannot
          be used as an index.
        InvalidFFAssignError: update_ff is set and a resolved object is not a
          top level Bits/BitStruct signal.
      """

      def expand_array_index( obj, name_depth, node_depth, idx_depth, idx ):
        """ Find s.x[0][*][2], if index is exhausted, jump back to lookup_variable """

        if idx_depth >= len(idx): # exhausted, go to next level of name
          lookup_variable( obj, name_depth+1, node_depth+1 )
          return

        current_idx = idx[ idx_depth ]

        if current_idx == "*": # special case, materialize all objects
          if isinstance( obj, NamedObject ): # Signal[*] is the signal itself
            objs.add( obj )
          else:
            for child in obj:
              expand_array_index( child, name_depth, node_depth, idx_depth+1, idx )
          return

        try:
          child = obj[ current_idx ]
        except TypeError as err: # cannot convert to integer
          raise VarNotDeclaredError( obj, current_idx, func, s,
                                     nodelist[node_depth].lineno ) from err
        except IndexError:
          # Out-of-range subscripts are skipped silently
          return

        expand_array_index( child, name_depth, node_depth, idx_depth+1, idx )

      def lookup_variable( obj, name_depth, node_depth ):
        """ Look up the object s.a.b.c in s. Jump to expand_array_index if c[] """
        if obj is None:
          return

        if name_depth >= len(obj_name): # exhausted
          if   isinstance( obj, NamedObject ):
            objs.add( obj )
          elif isinstance( obj, list ) and obj: # Exhaust all the elements in the high-d array
            pending = list( obj )
            while pending:
              m = pending.pop()
              # Test the popped element, not the work list itself; otherwise
              # objects nested inside lists would never be collected.
              if isinstance( m, NamedObject ):
                objs.add( m )
              elif isinstance( m, list ):
                pending.extend( m )
          return

        # still have names
        field, idx = obj_name[ name_depth ]
        try:
          child = getattr( obj, field )
        except AttributeError as err:
          # Chain the original error instead of printing it
          raise VarNotDeclaredError( obj, field, func, s,
                                     nodelist[node_depth].lineno ) from err

        if not idx: lookup_variable   ( child, name_depth+1, node_depth+1 )
        else:       expand_array_index( child, name_depth,   node_depth+1, 0, idx )

      all_objs = set()

      for obj_name, nodelist in names:
        root_name = obj_name[0][0]

        if root_name == "s":
          objs = set()
          lookup_variable( s, 1, 1 )
          all_objs |= objs

          # Check <<= in update_ff
          if update_ff:
            for x in objs:
              x._dsl.needs_double_buffer = True
              if not x.is_top_level_signal():
                raise InvalidFFAssignError( s, func, nodelist[0].lineno,
                  "has an invalid left value. It needs to be a top level signal, not a slice or a subfield.")
              if not issubclass( x._dsl.Type, Bits ) and not is_bitstruct_class( x._dsl.Type ):
                raise InvalidFFAssignError( s, func, nodelist[0].lineno,
                  "has a wrong type on the left value. "
                  f"We only allow <<= on Bits/BitStruct type signals, not {x._dsl.Type}")

        # This is a function call without "s." prefix, check func list
        elif root_name in s._dsl.name_func:
          all_objs.add( s._dsl.name_func[ root_name ] )

      return all_objs