class FSM(vtypes.VeriloggenNode):
    """ Finite State Machine Generator

    Statements are attached to the state index held in ``state_count``;
    ``If``/``Else``/``Elif``/``Delay``/``Keep``/``LazyCond``/``EagerVal``
    record attributes in ``next_kwargs`` that are consumed by the next
    ``add()`` or goto-style call.
    """

    def __init__(self, m, name, clk, rst, width=32, initname='init',
                 nohook=False, as_module=False):
        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.width = width
        self.state_count = 0
        self.state = self.m.Reg(name, width)  # set initval later
        self.mark = collections.OrderedDict()  # key:index
        self._set_mark(0, self.name + '_' + initname)
        self.state.initval = self._get_mark(0)
        self.body = collections.defaultdict(list)
        self.jump = collections.defaultdict(list)
        self.delay_amount = 0
        self.delayed_state = collections.OrderedDict()  # key:delay
        self.delayed_body = collections.defaultdict(
            functools.partial(collections.defaultdict, list))  # key:delay
        self.delayed_cond = collections.OrderedDict()  # key:name
        self.tmp_count = 0
        self.dst_var = collections.OrderedDict()
        self.dst_visitor = SubstDstVisitor()
        self.reset_visitor = ResetVisitor()
        self.seq = Seq(self.m, self.name + '_par', clk, rst, nohook=True)
        self.done = False
        self.last_cond = []
        self.last_kwargs = {}
        self.last_if_statement = None
        self.elif_cond = None
        self.next_kwargs = {}
        self.as_module = as_module

        if not nohook:
            self.m.add_hook(self.implement)

    # -------------------------------------------------------------------------
    def _take_pending_cond(self, cond):
        """Resolve the condition of a goto-style call and reset pending state.

        When ``cond`` is None, a condition registered by a preceding ``If()``
        (kept in ``next_kwargs['cond']``) is used instead. Pending keyword
        arguments, the last if-statement and the last condition are cleared
        in every case.
        """
        if cond is None and 'cond' in self.next_kwargs:
            cond = self.next_kwargs['cond']
        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()
        return cond

    def goto(self, dst, cond=None, else_dst=None):
        """Transition from the current state to ``dst`` (delegates to _go)."""
        cond = self._take_pending_cond(cond)
        src = self.current
        return self._go(src, dst, cond, else_dst)

    def goto_init(self, cond=None):
        """Transition from the current state back to state index 0."""
        cond = self._take_pending_cond(cond)
        src = self.current
        return self._go(src, 0, cond)

    def goto_next(self, cond=None):
        """Transition to ``current + 1`` and then call ``inc()``."""
        cond = self._take_pending_cond(cond)
        src = self.current
        ret = self._go(src, src + 1, cond=cond)
        self.inc()
        return ret

    def goto_from(self, src, dst, cond=None, else_dst=None):
        """Transition from an explicit ``src`` index to ``dst``."""
        cond = self._take_pending_cond(cond)
        return self._go(src, dst, cond, else_dst)

    def inc(self):
        # delegates to _set_index(None); presumably moves to a new state
        # index -- see _set_index (not visible here)
        self._set_index(None)

    # -------------------------------------------------------------------------
    def add(self, *statement, **kwargs):
        """ add new assignments

        Pending attributes from ``next_kwargs`` are merged into ``kwargs``
        and remembered in ``last_kwargs``. After ``Elif()``, plain statements
        (no attribute other than 'cond') extend the previous if-statement.
        """
        kwargs.update(self.next_kwargs)
        self.last_kwargs = kwargs
        self._clear_next_kwargs()

        # if there is no attributes (other than 'cond'), Elif object is reused.
        has_args = not set(kwargs) <= {'cond'}

        if self.elif_cond is not None and not has_args:
            next_call = self.last_if_statement.Elif(self.elif_cond)
            next_call(*statement)
            self.last_if_statement = next_call
            self._add_dst_var(statement)
            self._clear_elif_cond()
            return self

        self._clear_last_if_statement()
        return self._add_statement(statement, **kwargs)

    # -------------------------------------------------------------------------
    def add_reset(self, v):
        return self.seq.add_reset(v)

    # -------------------------------------------------------------------------
    def Prev(self, var, delay, initval=0, cond=None, prefix=None):
        return self.seq.Prev(var, delay, initval, cond, prefix)

    # -------------------------------------------------------------------------
    def If(self, *cond):
        """Register a condition for the next statement.

        Consecutive If() calls are combined with ``vtypes.Ands``; a condition
        that ``make_condition`` turns into None is ignored.
        """
        self._clear_elif_cond()
        cond = make_condition(*cond)

        if cond is None:
            return self

        if 'cond' not in self.next_kwargs:
            self.next_kwargs['cond'] = cond
        else:
            self.next_kwargs['cond'] = vtypes.Ands(
                self.next_kwargs['cond'], cond)

        self.last_cond = [self.next_kwargs['cond']]

        return self

    def Else(self, *statement, **kwargs):
        """Attach ``statement`` to the negation of the last condition.

        Raises:
            ValueError: no previous condition, or the last statement is not
                a ``vtypes.If`` when the Else has to be merged into it.
        """
        self._clear_elif_cond()

        if not self.last_cond:
            raise ValueError("No previous condition for Else.")

        old = self.last_cond.pop()
        self.last_cond.append(vtypes.Not(old))

        # if the true-statement has delay attributes,
        # Else statement is separated.
        if self.last_kwargs.get('delay', 0) > 0:
            prev_cond = self.last_cond
            ret = self.Then()(*statement)
            self.last_cond = prev_cond
            return ret

        # if there is additional attribute, Else statement is separated.
        # NOTE(review): the 'cond' test reads this method's own ``kwargs``
        # rather than ``self.next_kwargs``; looks like a typo -- confirm
        # before changing, callers may depend on the current behavior.
        has_args = not (len(self.next_kwargs) == 0 or  # has no args
                        (len(self.next_kwargs) == 1 and 'cond' in kwargs))  # has only 'cond'

        if has_args:
            prev_cond = self.last_cond
            ret = self.Then()(*statement)
            self.last_cond = prev_cond
            return ret

        if not isinstance(self.last_if_statement, vtypes.If):
            raise ValueError("Last if-statement is not If")

        self.last_if_statement.Else(*statement)
        self._add_dst_var(statement)

        return self

    def Elif(self, *cond):
        """Negate the last condition and add ``cond`` as an else-if branch.

        Raises:
            ValueError: no previous condition, or the last statement is not
                a ``vtypes.If``.
        """
        if not self.last_cond:
            raise ValueError("No previous condition for Elif.")

        cond = make_condition(*cond)

        old = self.last_cond.pop()
        self.last_cond.append(vtypes.Not(old))
        self.last_cond.append(cond)

        # if the true-statement has delay attributes, Else statement is
        # separated.
        if self.last_kwargs.get('delay', 0) > 0:
            prev_cond = self.last_cond
            ret = self.Then()
            self.last_cond = prev_cond
            return ret

        if not isinstance(self.last_if_statement, vtypes.If):
            raise ValueError("Last if-statement is not If")

        self.elif_cond = cond
        self.next_kwargs['cond'] = self._make_cond(self.last_cond)

        return self

    def Delay(self, delay):
        self.next_kwargs['delay'] = delay
        return self

    def Keep(self, keep):
        self.next_kwargs['keep'] = keep
        return self

    def Then(self):
        """Turn the accumulated ``last_cond`` list into a fresh If()."""
        cond = self._make_cond(self.last_cond)
        self._clear_last_cond()
        self.If(cond)
        return self

    def LazyCond(self, value=True):
        self.next_kwargs['lazy_cond'] = value
        return self

    def EagerVal(self, value=True):
        self.next_kwargs['eager_val'] = value
        return self

    def Clear(self):
        """Drop every pending attribute, condition and if-statement."""
        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()
        self._clear_elif_cond()
        return self

    # -------------------------------------------------------------------------
    @property
    def current(self):
        return self.state_count

    @property
    def next(self):
        return self.current + 1

    @property
    def current_delay(self):
        return self.next_kwargs.get('delay', 0)

    @property
    def last_delay(self):
        return self.last_kwargs.get('delay', 0)

    @property
    def current_condition(self):
        """``state == state_count``, ANDed with the pending If() condition."""
        cond = self.next_kwargs.get('cond')
        if cond is not None:
            cond = vtypes.AndList(self.state == self.state_count, cond)
        else:
            cond = self.state == self.state_count
        return cond

    @property
    def last_condition(self):
        """``state == state_count``, ANDed with the last condition chain."""
        cond = self._make_cond(self.last_cond)
        if cond is not None:
            cond = vtypes.AndList(self.state == self.state_count, cond)
        else:
            cond = self.state == self.state_count
        return cond

    @property
    def then(self):
        return self.last_condition

    @property
    def here(self):
        return self.state == self.current

    # -------------------------------------------------------------------------
    def implement(self):
        """Hook entry point: calls make_module() first when ``as_module``,
        then make_always()."""
        if self.as_module:
            self.make_module()
            return

        return self.make_always()

    # -------------------------------------------------------------------------
    def make_always(self, reset=(), body=(), case=True):
        """Emit the clocked always block once; later calls are no-ops.

        The reset branch comes from ``make_reset(reset)``; the body is
        ``body`` followed by ``make_case()`` (or ``make_if()`` when
        ``case`` is False).
        """
        if self.done:
            #raise ValueError('make_always() has been already called.')
            return

        self.done = True

        part_reset = self.make_reset(reset)
        part_body = list(body) + list(
            self.make_case() if case else self.make_if())

        self.m.Always(vtypes.Posedge(self.clk))(
            vtypes.If(self.rst)(
                part_reset,
            )(
                part_body,
            ))
class Stream(object):
    """Set of stream numeric nodes implemented as a scheduled datapath.

    ``implement()`` schedules the nodes with ``scheduler.ASAPScheduler``,
    balances all outputs to the deepest output stage and allocates the
    signals into a Verilog module. Recognized ``opts``: module, clock,
    reset, ivalid, iready, ovalid, oready, aswire (default True), dump
    (default False), dump_base (default 10), dump_mode (default 'all'),
    no_hook and seq_name.
    """

    def __init__(self, *nodes, **opts):
        # ID for manager reuse and merge
        global _stream_counter
        self.object_id = _stream_counter
        _stream_counter += 1

        self.nodes = set()
        self.named_numerics = OrderedDict()
        self.add(*nodes)

        self.max_stage = 0
        self.last_input = None
        self.last_output = None

        self.module = opts.get('module', None)
        self.clock = opts.get('clock', None)
        self.reset = opts.get('reset', None)
        self.ivalid = opts.get('ivalid', None)
        self.iready = opts.get('iready', None)
        self.ovalid = opts.get('ovalid', None)
        self.oready = opts.get('oready', None)
        self.aswire = opts.get('aswire', True)
        self.dump = opts.get('dump', False)
        self.dump_base = opts.get('dump_base', 10)
        self.dump_mode = opts.get('dump_mode', 'all')

        self.seq = None
        self.has_control = False
        self.implemented = False

        if (self.module is not None and self.clock is not None and
                self.reset is not None):

            if not opts.get('no_hook', False):
                self.module.add_hook(self.implement)

            seq_name = opts.get('seq_name',
                                '_stream_seq_%d' % self.object_id)
            self.seq = Seq(self.module, seq_name, self.clock, self.reset)

        if self.dump:
            # NOTE(review): requires a 'module' option; with module None this
            # fails with AttributeError.
            dump_enable_name = '_stream_dump_enable_%d' % self.object_id
            dump_enable = self.module.Reg(dump_enable_name, initval=0)
            dump_mask_name = '_stream_dump_mask_%d' % self.object_id
            dump_mask = self.module.Reg(dump_mask_name, initval=0)
            dump_step_name = '_stream_dump_step_%d' % self.object_id
            dump_step = self.module.Reg(dump_step_name, 32, initval=0)
            self.dump_enable = dump_enable
            self.dump_mask = dump_mask
            self.dump_step = dump_step
            if self.seq:
                self.seq.add_reset(self.dump_enable)
                self.seq.add_reset(self.dump_mask)

    # -------------------------------------------------------------------------
    def add(self, *nodes):
        """Register ``nodes``; nodes with named input/output data are also
        indexed in ``named_numerics`` (an output_data of None is skipped)."""
        self.nodes.update(nodes)

        for node in nodes:
            if hasattr(node, 'input_data'):
                data = node.input_data
            elif hasattr(node, 'output_data'):
                data = node.output_data
                if data is None:
                    continue
            else:
                continue

            name = data if isinstance(data, str) else data.name
            self.named_numerics[name] = node

    # -------------------------------------------------------------------------
    def to_module(self, name, clock='CLK', reset='RST', aswire=False,
                  seq_name=None):
        """ generate a Module definition with its own clock/reset inputs """
        m = Module(name)
        clk = m.Input(clock)
        rst = m.Input(reset)

        m = self.implement(m, clk, rst, aswire=aswire, seq_name=seq_name)

        return m

    # -------------------------------------------------------------------------
    def implement(self, m=None, clock=None, reset=None, aswire=None,
                  seq_name=None):
        """ implement actual registers and operations in Verilog

        Returns the module the stream was implemented into. A second call
        with ``m`` None returns ``self.module``.

        Raises:
            ValueError: a second call with an explicit ``m``.
        """
        if self.implemented:
            if m is None:
                return self.module
            raise ValueError('already implemented.')

        self.implemented = True

        if m is None:
            m = self.module

        if self.module is None:
            self.module = m

        if clock is None:
            clock = self.clock

        if reset is None:
            reset = self.reset

        if self.seq is None:
            if seq_name is None:
                seq_name = '_stream_seq_%d' % self.object_id
            seq = Seq(m, seq_name, clock, reset)
        else:
            seq = self.seq

        if aswire is None:
            aswire = self.aswire

        # control ports/wires are created exactly once, here
        self.add_control(aswire=aswire)
        self.has_control = True

        # for mult and div
        m._clock = clock
        m._reset = reset

        stream_nodes = self.nodes

        input_visitor = visitor.InputVisitor()
        input_vars = set()
        for node in sorted(stream_nodes, key=lambda x: x.object_id):
            input_vars.update(input_visitor.visit(node))

        output_visitor = visitor.OutputVisitor()
        output_vars = set()
        for node in sorted(stream_nodes, key=lambda x: x.object_id):
            output_vars.update(output_visitor.visit(node))

        # add input ports
        for input_var in sorted(input_vars, key=lambda x: x.object_id):
            input_var._implement_input(m, seq, aswire)

        # schedule
        sched = scheduler.ASAPScheduler()
        sched.schedule(output_vars)

        # balance output stage depth
        max_stage = 0
        for output_var in sorted(output_vars, key=lambda x: x.object_id):
            max_stage = stypes._max(max_stage, output_var.end_stage)
        self.max_stage = max_stage

        output_vars = sched.balance_output(output_vars, max_stage)

        # get all vars
        all_visitor = visitor.AllVisitor()
        all_vars = set()
        for output_var in sorted(output_vars, key=lambda x: x.object_id):
            all_vars.update(all_visitor.visit(output_var))

        # control (valid and ready); add_control() has already run above
        self.implement_control(seq)

        # allocate (implement signals)
        alloc = allocator.Allocator()
        alloc.allocate(m, seq, all_vars, self.valid_list, self.senable)

        # set default module information
        for var in sorted(all_vars, key=lambda x: x.object_id):
            var._set_module(m)
            var._set_strm(self)
            if var.seq is not None:
                seq.update(var.seq)
            var._set_seq(seq)

        # add output ports
        for output_var in sorted(output_vars, key=lambda x: x.object_id):
            output_var._implement_output(m, seq, aswire)

        # save schedule result
        self.last_input = input_vars
        self.last_output = output_vars

        if self.dump:
            self.add_dump(m, seq, input_vars, output_vars, all_vars)

        return m

    # -------------------------------------------------------------------------
    @staticmethod
    def _dump_name_of(obj):
        """Display name of a dumped signal: its ``name``, or the class name
        for a ``vtypes._Constant``."""
        if hasattr(obj, 'name'):
            return obj.name
        if isinstance(obj, vtypes._Constant):
            return obj.__class__.__name__
        raise TypeError("cannot derive a dump name from object of type '%s'"
                        % type(obj).__name__)

    def _dump_selected(self, var, extra_modes=()):
        """True when ``var`` is dumped under ``self.dump_mode``.

        'all' and 'stream' dump everything, ``extra_modes`` (e.g.
        'input'/'inout') add per-category modes, and 'selective' dumps only
        variables whose ``dump`` attribute is truthy.
        """
        mode = self.dump_mode
        return (mode == 'all' or mode == 'stream' or mode in extra_modes or
                (mode == 'selective' and hasattr(var, 'dump') and var.dump))

    def _make_dump_fmt(self, var, width, unused):
        """Return the Display value format for ``var``.

        The radix is ``var.dump_base`` if present, else ``self.dump_base``:
        2 -> 'b', 8 -> 'o', 10 -> 'd' (or 'g' when ``point > 0``), any other
        base -> 'x'. ``unused`` appends ' (unused)'.
        """
        base = (var.dump_base if hasattr(var, 'dump_base')
                else self.dump_base)
        base_char = ('b' if base == 2 else
                     'o' if base == 8 else
                     'd' if base == 10 and var.point <= 0 else
                     'g' if base == 10 and var.point > 0 else
                     'x')
        prefix = ('0b' if base == 2 else
                  '0o' if base == 8 else
                  ' ' if base == 10 else
                  '0x')
        fmt = ''.join([prefix, '%', '%d' % width, base_char])
        if unused:
            fmt += ' (unused)'
        return fmt

    @staticmethod
    def _dump_signal_value(sig_data, point):
        """Scale a fixed-point signal for display: divide by 2**point (as a
        real via $itor) when point > 0, multiply by 2**-point when < 0."""
        if point > 0:
            return vtypes.Div(vtypes.SystemTask('itor', sig_data),
                              1.0 * (2 ** point))
        if point < 0:
            return vtypes.Times(sig_data, 2 ** -point)
        return sig_data

    def _emit_dump_display(self, seq, var, sig_data, tag, name_len,
                           stage_digits):
        """Add one Display statement for ``var`` at its end stage.

        The statement is guarded by the stage-delayed dump enable and by
        ``dump_mask`` being low.
        """
        name = self._dump_name_of(sig_data)
        name_alignment = ' ' * (name_len - len(name) - len(tag))
        # NOTE(review): relies on ``self.name``, which this class does not
        # set itself -- presumably provided by a subclass; confirm.
        fmt = ''.join(['<', self.name, ' step:%d, ', 'stage:%',
                       str(stage_digits), 'd, age:%d> ', tag,
                       name_alignment, name, ' = ', var.dump_fmt])

        stage = var.end_stage if var.end_stage is not None else 0
        enable = seq.Prev(self.dump_enable, stage)
        age = seq.Prev(self.dump_step, stage) - 1
        value = self._dump_signal_value(sig_data, var.point)

        seq.If(enable, vtypes.Not(self.dump_mask))(
            vtypes.Display(fmt, self.dump_step, stage, age, value)
        )

    def add_dump(self, m, seq, input_vars, output_vars, all_vars):
        """Add per-stage Display statements for the stream's variables.

        ``m`` is not used; it is kept for interface compatibility. Which
        variables are dumped depends on ``self.dump_mode``.
        """
        pipeline_depth = self.pipeline_depth()
        # Width of the 'stage' field: decimal digits of the largest stage
        # index (0..pipeline_depth). The former ceil(log10(...)) estimate was
        # one digit short for exact powers of ten (10, 100, ...).
        log_pipeline_depth = max(len(str(pipeline_depth)), 1)

        seq(
            self.dump_step(1)
        )

        for i in range(pipeline_depth + 1):
            seq.If(seq.Prev(self.dump_enable, i))(
                self.dump_step.inc()
            )

        def by_id(x):
            return x.object_id

        def by_end_stage(x):
            return ((-1, x.object_id) if x.end_stage is None
                    else (x.end_stage, x.object_id))

        def by_start_stage(x):
            return ((-1, x.object_id) if x.start_stage is None
                    else (x.start_stage, x.object_id))

        sorted_inputs = sorted(input_vars, key=by_id)
        sorted_outputs = sorted(output_vars, key=by_id)
        vars_by_start = sorted(all_vars, key=by_start_stage)

        dumped_inputs = [v for v in sorted_inputs
                         if self._dump_selected(v, ('input', 'inout'))]
        dumped_vars = [v for v in sorted(all_vars, key=by_end_stage)
                       if self._dump_selected(v)]
        dumped_outputs = [v for v in sorted_outputs
                          if self._dump_selected(v, ('output', 'inout'))]

        # name column: longest dumped name plus 6, room for the '(out) ' tag
        names = ([self._dump_name_of(v.sig_data) for v in dumped_inputs] +
                 [self._dump_name_of(v.sig_data) for v in dumped_vars] +
                 [self._dump_name_of(v.output_sig_data)
                  for v in dumped_outputs])
        longest_name_len = max([0] + [len(name) + 6 for name in names])

        # value column: digits needed for the widest variable in its radix
        longest_var_len = 0
        for var in vars_by_start:
            bitwidth = vtypes.get_width(var.sig_data)
            if bitwidth is None or bitwidth <= 0:
                bitwidth = 1
            base = (var.dump_base if hasattr(var, 'dump_base')
                    else self.dump_base)
            total_length = int(math.ceil(bitwidth / math.log(base, 2)))
            longest_var_len = max(longest_var_len, total_length)

        width = longest_var_len + 1
        for input_var in sorted_inputs:
            input_var.dump_fmt = self._make_dump_fmt(
                input_var, width, input_var not in all_vars)
        for output_var in sorted_outputs:
            output_var.dump_fmt = self._make_dump_fmt(
                output_var, width, output_var not in all_vars)
        for var in vars_by_start:
            var.dump_fmt = self._make_dump_fmt(var, width, False)

        for input_var in dumped_inputs:
            self._emit_dump_display(seq, input_var, input_var.sig_data,
                                    '(in) ', longest_name_len,
                                    log_pipeline_depth)

        for var in dumped_vars:
            self._emit_dump_display(seq, var, var.sig_data,
                                    '', longest_name_len,
                                    log_pipeline_depth)

        for output_var in dumped_outputs:
            self._emit_dump_display(seq, output_var,
                                    output_var.output_sig_data,
                                    '(out) ', longest_name_len,
                                    log_pipeline_depth)

    # -------------------------------------------------------------------------
    def _as_control_signal(self, sig, aswire, direction):
        """Turn a control-signal name into a Wire (``aswire``) or a port
        created by ``self.module.<direction>``; non-strings pass through."""
        if not isinstance(sig, str):
            return sig
        if aswire:
            return self.module.Wire(sig)
        return getattr(self.module, direction)(sig)

    def add_control(self, aswire=True):
        """Materialize handshake signals given by name in ``self.module``."""
        self.ivalid = self._as_control_signal(self.ivalid, aswire, 'Input')
        self.iready = self._as_control_signal(self.iready, aswire, 'Output')
        self.ovalid = self._as_control_signal(self.ovalid, aswire, 'Output')
        self.oready = self._as_control_signal(self.oready, aswire, 'Input')

    def implement_control(self, seq):
        """Set ``senable``/``valid_list`` and drive iready/ovalid.

        Without ivalid and oready, ovalid/iready are tied to 1. Without
        ivalid, the pipeline enable is oready. Otherwise a valid shift chain
        of ``max_stage`` registers is built, enabled by oready if present.
        """
        self.valid_list = None

        if self.ivalid is None and self.oready is None:
            self.senable = None
            if self.ovalid is not None:
                self.ovalid.assign(1)
            if self.iready is not None:
                self.iready.assign(1)
            return

        if self.ivalid is None:
            self.senable = self.oready
            if self.iready is not None:
                self.iready.assign(self.senable)
            return

        self.senable = self.oready

        self._make_valid_chain(seq, self.senable)

        if self.iready is not None:
            self.iready.assign(self.senable)

    def _make_valid_chain(self, seq, cond=None):
        """Delay ivalid through one register per stage; ovalid is the last."""
        self.valid_list = []
        self.valid_list.append(self.ivalid)

        name = self.ivalid.name
        prev = self.ivalid

        for i in range(self.max_stage):
            v = self.module.Reg("_{}_{}".format(name, i + 1), initval=0)
            self.valid_list.append(v)
            seq(v(prev), cond=cond)
            prev = v

        if self.ovalid is not None:
            self.ovalid.assign(prev)

    # -------------------------------------------------------------------------
    def draw_graph(self, filename='out.png', prog='dot', rankdir='LR',
                   approx=False):
        """Render the dataflow graph of the implemented outputs.

        If the stream has not been implemented yet, it is first implemented
        into a standalone module named 'stream_<object_id>'. ``to_module``
        requires a name, so the previous argument-less call always raised
        TypeError.
        """
        if self.last_output is None:
            self.to_module('stream_%d' % self.object_id)

        graph.draw_graph(self.last_output, filename=filename, prog=prog,
                         rankdir=rankdir, approx=approx)

    def enable_draw_graph(self, filename='out.png', prog='dot', rankdir='LR',
                          approx=False):
        """Register draw_graph as a hook of ``self.module``."""
        self.module.add_hook(self.draw_graph,
                             kwargs={'filename': filename, 'prog': prog,
                                     'rankdir': rankdir, 'approx': approx})

    # -------------------------------------------------------------------------
    def get_input(self):
        """OrderedDict of str(input_data) -> input variable, ordered by
        object_id; empty before implement()."""
        if self.last_input is None:
            return OrderedDict()

        return OrderedDict(
            (str(input_var.input_data), input_var)
            for input_var in sorted(self.last_input,
                                    key=lambda x: x.object_id))

    def get_output(self):
        """OrderedDict of str(output_data) -> output variable, ordered by
        object_id; empty before implement()."""
        if self.last_output is None:
            return OrderedDict()

        return OrderedDict(
            (str(output_var.output_data), output_var)
            for output_var in sorted(self.last_output,
                                     key=lambda x: x.object_id))

    # -------------------------------------------------------------------------
    def pipeline_depth(self):
        return self.max_stage

    # -------------------------------------------------------------------------
    def __getattr__(self, attr):
        """Expose public ``stypes`` callables as methods.

        Numeric results of a wrapped call are bound to this stream through
        _set_info.
        """
        try:
            return object.__getattribute__(self, attr)

        except AttributeError as e:
            if attr.startswith('__') or attr not in dir(stypes):
                raise e

            func = getattr(stypes, attr)

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                v = func(*args, **kwargs)
                if isinstance(v, (tuple, list)):
                    for item in v:
                        self._set_info(item)
                else:
                    self._set_info(v)
                return v

            return wrapper

    def _set_info(self, v):
        if isinstance(v, stypes._Numeric):
            v._set_module(self.module)
            v._set_strm(self)
            v._set_seq(self.seq)
            self.add(v)

    def get_named_numeric(self, name):
        """Return the node registered under ``name``.

        Raises:
            NameError: no such node.
        """
        if name not in self.named_numerics:
            raise NameError("Numeric '%s' is not defined." % name)
        return self.named_numerics[name]
class FSM(vtypes.VeriloggenNode): """ Finite State Machine Generator """ def __init__(self, m, name, clk, rst, width=32, initname='init', nohook=False, as_module=False): self.m = m self.name = name self.clk = clk self.rst = rst self.width = width self.state_count = 0 self.state = self.m.Reg(name, width) # set initval later self.mark = collections.OrderedDict() # key:index self._set_mark(0, self.name + '_' + initname) self.state.initval = self._get_mark(0) self.body = collections.defaultdict(list) self.jump = collections.defaultdict(list) self.delay_amount = 0 self.delayed_state = collections.OrderedDict() # key:delay self.delayed_body = collections.defaultdict( functools.partial(collections.defaultdict, list)) # key:delay self.delayed_cond = collections.OrderedDict() # key:name self.tmp_count = 0 self.dst_var = collections.OrderedDict() self.dst_visitor = SubstDstVisitor() self.reset_visitor = ResetVisitor() self.seq = Seq(self.m, self.name + '_par', clk, rst, nohook=True) self.done = False self.last_cond = [] self.last_kwargs = {} self.last_if_statement = None self.elif_cond = None self.next_kwargs = {} self.as_module = as_module if not nohook: self.m.add_hook(self.implement) # ------------------------------------------------------------------------- def goto(self, dst, cond=None, else_dst=None): if cond is None and 'cond' in self.next_kwargs: cond = self.next_kwargs['cond'] self._clear_next_kwargs() self._clear_last_if_statement() self._clear_last_cond() src = self.current return self._go(src, dst, cond, else_dst) def goto_init(self, cond=None): if cond is None and 'cond' in self.next_kwargs: cond = self.next_kwargs['cond'] self._clear_next_kwargs() self._clear_last_if_statement() self._clear_last_cond() src = self.current dst = 0 return self._go(src, dst, cond) def goto_next(self, cond=None): if cond is None and 'cond' in self.next_kwargs: cond = self.next_kwargs['cond'] self._clear_next_kwargs() self._clear_last_if_statement() self._clear_last_cond() src = 
self.current dst = self.current + 1 ret = self._go(src, dst, cond=cond) self.inc() return ret def goto_from(self, src, dst, cond=None, else_dst=None): if cond is None and 'cond' in self.next_kwargs: cond = self.next_kwargs['cond'] self._clear_next_kwargs() self._clear_last_if_statement() self._clear_last_cond() return self._go(src, dst, cond, else_dst) def inc(self): self._set_index(None) # ------------------------------------------------------------------------- def add(self, *statement, **kwargs): """ add new assignments """ kwargs.update(self.next_kwargs) self.last_kwargs = kwargs self._clear_next_kwargs() # if there is no attributes, Elif object is reused. has_args = not (len(kwargs) == 0 or # has no args (len(kwargs) == 1 and 'cond' in kwargs)) # has only 'cond' if self.elif_cond is not None and not has_args: next_call = self.last_if_statement.Elif(self.elif_cond) next_call(*statement) self.last_if_statement = next_call self._add_dst_var(statement) self._clear_elif_cond() return self self._clear_last_if_statement() return self._add_statement(statement, **kwargs) # ------------------------------------------------------------------------- def add_reset(self, v): return self.seq.add_reset(v) # ------------------------------------------------------------------------- def Prev(self, var, delay, initval=0, cond=None, prefix=None): return self.seq.Prev(var, delay, initval, cond, prefix) # ------------------------------------------------------------------------- def If(self, *cond): self._clear_elif_cond() cond = make_condition(*cond) if cond is None: return self if 'cond' not in self.next_kwargs: self.next_kwargs['cond'] = cond else: self.next_kwargs['cond'] = vtypes.Ands( self.next_kwargs['cond'], cond) self.last_cond = [self.next_kwargs['cond']] return self def Else(self, *statement, **kwargs): self._clear_elif_cond() if len(self.last_cond) == 0: raise ValueError("No previous condition for Else.") old = self.last_cond.pop() self.last_cond.append(vtypes.Not(old)) # 
if the true-statement has delay attributes, # Else statement is separated. if 'delay' in self.last_kwargs and self.last_kwargs['delay'] > 0: prev_cond = self.last_cond ret = self.Then()(*statement) self.last_cond = prev_cond return ret # if there is additional attribute, Else statement is separated. has_args = not (len(self.next_kwargs) == 0 or # has no args (len(self.next_kwargs) == 1 and 'cond' in kwargs)) # has only 'cond' if has_args: prev_cond = self.last_cond ret = self.Then()(*statement) self.last_cond = prev_cond return ret if not isinstance(self.last_if_statement, vtypes.If): raise ValueError("Last if-statement is not If") self.last_if_statement.Else(*statement) self._add_dst_var(statement) return self def Elif(self, *cond): if len(self.last_cond) == 0: raise ValueError("No previous condition for Else.") cond = make_condition(*cond) old = self.last_cond.pop() self.last_cond.append(vtypes.Not(old)) self.last_cond.append(cond) # if the true-statement has delay attributes, Else statement is # separated. 
if 'delay' in self.last_kwargs and self.last_kwargs['delay'] > 0: prev_cond = self.last_cond ret = self.Then() self.last_cond = prev_cond return ret if not isinstance(self.last_if_statement, vtypes.If): raise ValueError("Last if-statement is not If") self.elif_cond = cond cond = self._make_cond(self.last_cond) self.next_kwargs['cond'] = cond return self def Delay(self, delay): self.next_kwargs['delay'] = delay return self def Keep(self, keep): self.next_kwargs['keep'] = keep return self def Then(self): cond = self._make_cond(self.last_cond) self._clear_last_cond() self.If(cond) return self def LazyCond(self, value=True): self.next_kwargs['lazy_cond'] = value return self def EagerVal(self, value=True): self.next_kwargs['eager_val'] = value return self def Clear(self): self._clear_next_kwargs() self._clear_last_if_statement() self._clear_last_cond() self._clear_elif_cond() return self # ------------------------------------------------------------------------- @property def current(self): return self.state_count @property def next(self): return self.current + 1 @property def current_delay(self): if 'delay' in self.next_kwargs: return self.next_kwargs['delay'] return 0 @property def last_delay(self): if 'delay' in self.last_kwargs: return self.last_kwargs['delay'] return 0 @property def current_condition(self): cond = self.next_kwargs['cond'] if 'cond' in self.next_kwargs else None if cond is not None: cond = vtypes.AndList(self.state == self.state_count, cond) else: cond = self.state == self.state_count return cond @property def last_condition(self): cond = self._make_cond(self.last_cond) if cond is not None: cond = vtypes.AndList(self.state == self.state_count, cond) else: cond = self.state == self.state_count return cond @property def then(self): return self.last_condition @property def here(self): return self.state == self.current # ------------------------------------------------------------------------- def implement(self): if self.as_module: self.make_module() 
return self.make_always() # ------------------------------------------------------------------------- def make_always(self, reset=(), body=(), case=True): if self.done: #raise ValueError('make_always() has been already called.') return self.done = True part_reset = self.make_reset(reset) part_body = list(body) + list(self.make_case() if case else self.make_if()) self.m.Always(vtypes.Posedge(self.clk))( vtypes.If(self.rst)( part_reset, )( part_body, )) # ------------------------------------------------------------------------- def make_module(self, reset=(), body=(), case=True): if self.done: #raise ValueError('make_always() has been already called.') return self.done = True m = Module('sub_%s' % self.name) clk = m.Input('CLK') if self.rst is not None: rst = m.Input('RST') else: rst = None body = list(body) + list(self.make_case() if case else self.make_if()) dst_var = self.seq.dst_var dst_var.update(self.dst_var) dsts = dst_var.values() src_visitor = SubstSrcVisitor() # collect sources in destination variable definitions for dst in dsts: if isinstance(dst, (vtypes.Pointer, vtypes.Slice, vtypes.Cat)): raise ValueError( 'Partial assignment is not supported by as_module mode.') if isinstance(dst, vtypes._Variable): if dst.width is not None: src_visitor.visit(dst.width) if dst.length is not None: src_visitor.visit(dst.length) # collect sources in statements for statement in body: src_visitor.visit(statement) srcs = src_visitor.srcs.values() # collect sources in source variable definitions for src in srcs: if isinstance(src, vtypes._Variable): if src.width is not None: src_visitor.visit(src.width) if src.length is not None: src_visitor.visit(src.length) srcs = src_visitor.srcs.values() params = collections.OrderedDict() ports = collections.OrderedDict() src_rename_dict = collections.OrderedDict() fsm_orig_labels = [v.name for v in self.mark.values()] fsm_labels = collections.OrderedDict() # create parameter/localparam definitions for src in srcs: if isinstance(src, 
(vtypes.Parameter, vtypes.Localparam)): if src.name in fsm_orig_labels: fsm_labels[src.name] = m.Localparam(src.name, src.value) continue arg_name = src.name v = m.Parameter(arg_name, src.value, src.width, src.signed) src_rename_dict[src.name] = v params[arg_name] = src src_rename_visitor = SrcRenameVisitor(src_rename_dict) state_width = src_rename_visitor.visit(self.state.width) state_initval = src_rename_visitor.visit(self.state.initval) state = m.OutputReg(self.state.name, state_width, initval=state_initval) out_state = self.m.TmpWire(state_width, prefix='_%s_out' % self.state.name) self.m.Always()(self.state(out_state, blk=True)) ports[state.name] = out_state src_rename_dict[self.state.name] = state for delay, s in sorted(self.delayed_state.items(), key=lambda x: x[0]): s_width = src_rename_visitor.visit(s.width) s_initval = src_rename_visitor.visit(s.initval) d = m.OutputReg(s.name, s_width, initval=s_initval) out_d = self.m.TmpWire(s_width, prefix='_%s_out' % d.name) self.m.Always()(s(out_d, blk=True)) ports[s.name] = out_d state_names = [self.state.name] state_names.extend([s.name for s in self.delayed_state.values()]) for src in srcs: if isinstance(src, (vtypes.Parameter, vtypes.Localparam)): continue if src.name in state_names: continue if src.name in list(self.delayed_cond.keys()): rep_width = (src_rename_visitor.visit(src.width) if src.width is not None else None) v = m.Reg(src.name, rep_width, initval=0) src_rename_dict[src.name] = v continue arg_name = 'i_%s' % src.name if src.length is not None: width = src.bit_length() length = src.length pack_width = vtypes.Mul(width, length) out_line = self.m.TmpWire(pack_width, prefix='_%s' % self.name) i = self.m.TmpGenvar(prefix='i') v = out_line[i * width:(i + 1) * width] g = self.m.GenerateFor(i(0), i < length, i(i + 1)) p = g.Assign(v(src[i])) rep_width = (src_rename_visitor.visit(src.width) if src.width is not None else None) rep_length = src_rename_visitor.visit(src.length) pack_width = (rep_length if 
rep_width is None else vtypes.Mul(rep_length, rep_width)) in_line = m.Input(arg_name + '_line', pack_width, signed=src.get_signed()) in_array = m.Wire(arg_name, rep_width, rep_length, signed=src.get_signed()) i = m.TmpGenvar(prefix='i') v = in_line[i * rep_width:(i + 1) * rep_width] g = m.GenerateFor(i(0), i < rep_length, i(i + 1)) p = g.Assign(in_array[i](v)) src_rename_dict[src.name] = in_array ports[in_line.name] = out_line else: rep_width = (src_rename_visitor.visit(src.width) if src.width is not None else None) v = m.Input(arg_name, rep_width, signed=src.get_signed()) src_rename_dict[src.name] = v ports[arg_name] = src for dst in dsts: if dst.name in list(self.delayed_cond.keys()): continue arg_name = dst.name rep_width = (src_rename_visitor.visit(dst.width) if dst.width is not None else None) out = m.OutputReg(arg_name, rep_width, signed=dst.get_signed()) out_wire = self.m.TmpWire(rep_width, signed=dst.get_signed(), prefix='_%s_%s' % (self.name, arg_name)) self.m.Always()(dst(out_wire, blk=True)) ports[arg_name] = out_wire body = [src_rename_visitor.visit(statement) for statement in body] reset = self.make_reset(reset) if not reset and not body: pass elif not reset or rst is None: m.Always(vtypes.Posedge(clk))( body, ) else: m.Always(vtypes.Posedge(clk))( vtypes.If(rst)( reset, )( body, )) arg_params = [(name, param) for name, param in params.items()] arg_ports = [('CLK', self.clk)] if self.rst is not None: arg_ports.append(('RST', self.rst)) arg_ports.extend([(name, port) for name, port in ports.items()]) sub = Submodule(self.m, m, 'inst_' + m.name, '_%s_' % self.name, arg_params=arg_params, arg_ports=arg_ports) # ------------------------------------------------------------------------- def make_case(self): indexes = set(self.body.keys()) indexes.update(set(self.jump.keys())) for index in indexes: self._add_mark(index) ret = [] ret.extend(self.seq.make_code()) ret.extend(self._get_delayed_substs()) for delay, dct in sorted(self.delayed_body.items(), 
key=lambda x: x[0], reverse=True): body = tuple([self._get_delayed_when_statement(index, delay) for index in sorted(dct.keys(), key=lambda x:x)]) case = vtypes.Case(self._get_delayed_state(delay))(*body) ret.append(case) body = tuple([self._get_when_statement(index) for index in sorted(indexes, key=lambda x:x)]) case = vtypes.Case(self.state)(*body) if len(case.statement) > 0: ret.append(case) return ret def make_if(self): indexes = set(self.body.keys()) indexes.update(set(self.jump.keys())) for index in indexes: self._add_mark(index) ret = [] ret.extend(self.seq.make_code()) ret.extend(self._get_delayed_substs()) for delay, dct in sorted(self.delayed_body.items(), key=lambda x: x[0], reverse=True): ret.append([self._get_delayed_if_statement(index, delay) for index in sorted(dct.keys(), key=lambda x:x)]) ret.extend([self._get_if_statement(index) for index in sorted(indexes, key=lambda x:x)]) return ret # ------------------------------------------------------------------------- def make_reset(self, reset): ret = collections.OrderedDict() for v in reset: if not isinstance(v, vtypes.Subst): raise TypeError( 'make_reset requires Subst, not %s' % str(type(v))) key = str(v.left) if key not in ret: ret[key] = v v = self.reset_visitor.visit(self.state) key = str(self.state) if v is not None and key not in ret: ret[key] = v for dst in self.delayed_state.values(): v = self.reset_visitor.visit(dst) if v is None: continue key = str(v.left) if key not in ret: ret[key] = v for dst in self.dst_var.values(): v = self.reset_visitor.visit(dst) if v is None: continue key = str(v.left) if key not in ret: ret[key] = v for v in self.seq.make_reset(): if not isinstance(v, vtypes.Subst): raise TypeError( 'make_reset requires Subst, not %s' % str(type(v))) key = str(v.left) if key not in ret: ret[key] = v return list(ret.values()) # ------------------------------------------------------------------------- def set_index(self, index): return self._set_index(index) # 
------------------------------------------------------------------------- def _go(self, src, dst, cond=None, else_dst=None): self._add_jump(src, dst, cond, else_dst) return self def _add_jump(self, src, dst, cond=None, else_dst=None): self.jump[src].append((dst, cond, else_dst)) # ------------------------------------------------------------------------- def _add_statement(self, statement, index=None, keep=None, delay=None, cond=None, lazy_cond=False, eager_val=False, no_delay_cond=False): cond = make_condition(cond) index = self._to_index(index) if index is not None else self.current if keep is not None: for i in range(keep): new_delay = i if delay is None else delay + i self._add_statement(statement, index=index, keep=None, delay=new_delay, cond=cond, lazy_cond=lazy_cond, eager_val=eager_val, no_delay_cond=no_delay_cond) return self if delay is not None and delay > 0: self._add_delayed_state(delay) if eager_val: statement = [self._add_delayed_subst(s, index, delay) for s in statement] if not no_delay_cond: if cond is None: cond = 1 if not lazy_cond: cond = self._add_delayed_cond(cond, index, delay) else: # lazy condition t = self._add_delayed_cond(1, index, delay) if isinstance(cond, int) and cond == 1: cond = t else: cond = vtypes.Ands(t, cond) statement = [vtypes.If(cond)(*statement)] self.delayed_body[delay][index].extend(statement) self._add_dst_var(statement) return self if cond is not None: statement = [vtypes.If(cond)(*statement)] self.last_if_statement = statement[0] self.body[index].extend(statement) self._add_dst_var(statement) return self # ------------------------------------------------------------------------- def _add_dst_var(self, statement): for s in statement: values = self.dst_visitor.visit(s) for v in values: k = str(v) if k not in self.dst_var: self.dst_var[k] = v # ------------------------------------------------------------------------- def _add_delayed_cond(self, statement, index, delay): name_prefix = '_'.join( ['', self.name, 'cond', 
str(index), str(self.tmp_count)]) self.tmp_count += 1 prev = statement for i in range(delay): tmp_name = '_'.join([name_prefix, str(i + 1)]) tmp = self.m.Reg(tmp_name, initval=0) self.delayed_cond[tmp_name] = tmp self._add_statement([tmp(prev)], delay=i, no_delay_cond=True) prev = tmp return prev # ------------------------------------------------------------------------- def _add_delayed_subst(self, subst, index, delay): if not isinstance(subst, vtypes.Subst): return subst left = subst.left right = subst.right if isinstance(right, (bool, int, float, str, vtypes._Constant, vtypes._ParameterVariable)): return subst width = left.bit_length() signed = vtypes.get_signed(left) prev = right name_prefix = ('_'.join(['', left.name, str(index), str(self.tmp_count)]) if isinstance(left, vtypes._Variable) else '_'.join(['', self.name, 'sbst', str(index), str(self.tmp_count)])) self.tmp_count += 1 for i in range(delay): tmp_name = '_'.join([name_prefix, str(i + 1)]) tmp = self.m.Reg(tmp_name, width, initval=0, signed=signed) self._add_statement([tmp(prev)], delay=i, no_delay_cond=True) prev = tmp return left(prev) # ------------------------------------------------------------------------- def _clear_next_kwargs(self): self.next_kwargs = {} def _clear_last_if_statement(self): self.last_if_statement = None def _clear_last_cond(self): self.last_cond = [] def _clear_elif_cond(self): self.elif_cond = None def _make_cond(self, condlist): ret = None for cond in condlist: if ret is None: ret = cond else: ret = vtypes.Ands(ret, cond) return ret # ------------------------------------------------------------------------- def _set_index(self, index=None): if index is None: self.state_count += 1 return self.state_count self.state_count = index return self.state_count def _get_mark(self, index=None): if index is None: index = self.state_count if index not in self.mark: raise KeyError("No such index in FSM marks: %s" % index) return self.mark[index] def _set_mark(self, index=None, name=None): 
if index is None: index = self.state_count if name is None: name = self.name + '_' + str(index) self.mark[index] = self.m.Localparam(name, index) def _get_mark_index(self, s): for index, m in self.mark.items(): if m.name == s.name: return index raise KeyError("No such mark in FSM marks: %s" % s.name) # ------------------------------------------------------------------------- def _add_mark(self, index): index = self._to_index(index) if index not in self.mark: self._set_mark(index) mark = self._get_mark(index) return mark def _to_index(self, index): if not isinstance(index, int): index = self._get_mark_index(index) return index # ------------------------------------------------------------------------- def _add_delayed_state(self, value): if not isinstance(value, int): raise TypeError("Delay amount must be int, not '%s'" % str(type(value))) if value < 0: raise ValueError("Delay amount must be positive number") if value == 0: return self.state if value <= self.delay_amount: return self._get_delayed_state(value) for i in range(self.delay_amount + 1, value + 1): d = self.m.Reg(''.join(['_d', str(i), '_', self.name]), self.width, initval=self._get_mark(0)) self.delayed_state[i] = d self.delay_amount = value return d def _get_delayed_state(self, value): if value == 0: return self.state if value not in self.delayed_state: raise IndexError('No such index %d in delayed state' % value) return self.delayed_state[value] def _get_delayed_substs(self): ret = [] prev = self.state for d in range(1, self.delay_amount + 1): ret.append(vtypes.Subst(self.delayed_state[d], prev)) prev = self.delayed_state[d] return ret def _init_delayed_state(self): ret = [] for d in range(1, self.delay_amount + 1): ret.append(vtypes.Subst(self.delayed_state[d], self._get_mark(0))) return ret def _to_state_assign(self, dst, cond=None, else_dst=None): dst_mark = self._get_mark(dst) value = self.state(dst_mark) if cond is not None: value = vtypes.If(cond)(value) if else_dst is not None: else_dst_mark = 
self._get_mark(else_dst) value = value.Else(self.state(else_dst_mark)) return value # ------------------------------------------------------------------------- def _cond_case(self, index): if index not in self.mark: self._set_mark(index) return self._get_mark(index) def _cond_if(self, index): if index not in self.mark: self._set_mark(index) return (self.state == self._get_mark(index)) def _delayed_cond_if(self, index, delay): if index not in self.mark: self._set_mark(index) if delay > 0 and delay not in self.delayed_state: self._add_delayed_state(delay) return (self._get_delayed_state(delay) == self._get_mark(index)) def _get_when_statement(self, index): body = [] body.extend(self.body[index]) for dst, cond, else_dst in self.jump[index]: self._add_mark(dst) if else_dst is not None: self._add_mark(else_dst) body.append(self._to_state_assign(dst, cond, else_dst)) return vtypes.When(self._cond_case(index))(*body) def _get_delayed_when_statement(self, index, delay): return vtypes.When(self._cond_case(index))(*self.delayed_body[delay][index]) def _get_if_statement(self, index): body = [] body.extend(self.body[index]) for dst, cond, else_dst in self.jump[index]: self._add_mark(dst) if else_dst is not None: self._add_mark(else_dst) body.append(self._to_state_assign(dst, cond, else_dst)) return vtypes.If(self._cond_if(index))(*body) def _get_delayed_if_statement(self, index, delay): return vtypes.If(self._delayed_cond_if(index, delay))(*self.delayed_body[delay][index]) # ------------------------------------------------------------------------- def __call__(self, *statement, **kwargs): return self.add(*statement, **kwargs) def __getitem__(self, index): return self.body[index] def __len__(self): return self.state_count + 1