Example #1
0
    def write(self, fsm, value, *part):
        """Register an assignment of ``value`` to the wrapped variable.

        The assignment is added to this object's ``Seq`` (built on first use
        from the FSM's module, clock and reset) under the condition
        ``fsm.state == fsm.current``; afterwards the FSM is advanced with
        ``goto_next()``.

        Args:
            fsm: State machine issuing the write. Its ``m``, ``clk``, ``rst``,
                ``state`` and ``current`` attributes are read.
            value: Right-hand side handed to the assignment target.
            *part: Zero, one or two selectors applied in turn to the wrapped
                variable. A tuple/list selector ``p`` is applied as
                ``v[p[0], p[1]]``, any other selector as ``v[p]``.

        Returns:
            Always 0.

        Raises:
            TypeError: If more than two selectors are given.
        """
        if self.seq is None:
            # Lazily create the sequential block the first time it is needed.
            self.seq = Seq(fsm.m, '_'.join(['seq', self._value.name]),
                           fsm.clk, fsm.rst)

        in_this_state = fsm.state == fsm.current

        def select(base, selector):
            # Two-element selectors index with a pair, everything else directly.
            if not isinstance(selector, (tuple, list)):
                return base[selector]
            return base[selector[0], selector[1]]

        if len(part) > 2:
            raise TypeError('unsupported type')

        target = self._value
        for selector in part:
            target = select(target, selector)

        self.seq.If(in_this_state)(target(value))

        fsm.goto_next()
        return 0
Example #2
0
    def __init__(self, m, name, clk, rst, datawidth=32, addrwidth=32,
                 nodataflow=False):
        """Create the channel interfaces and bookkeeping of an AXI slave.

        Args:
            m: Module the interfaces are created in.
            name: Name passed to every channel interface and to the ``Seq``.
            clk: Clock handed to the ``Seq`` and the dataflow manager.
            rst: Reset handed to the ``Seq`` and the dataflow manager.
            datawidth: Data width forwarded to the channel interfaces.
            addrwidth: Address width forwarded to the channel interfaces.
            nodataflow: When true, no ``DataflowManager`` is created and
                ``self.df`` is ``None``.
        """
        # Keep the construction arguments for the other methods.
        self.m, self.name = m, name
        self.clk, self.rst = clk, rst
        self.datawidth, self.addrwidth = datawidth, addrwidth

        # The four channels share the same module, name and widths.
        channel_args = (m, name, datawidth, addrwidth)
        self.waddr = AxiSlaveWriteAddress(*channel_args)
        self.wdata = AxiSlaveWriteData(*channel_args)
        self.raddr = AxiSlaveReadAddress(*channel_args)
        self.rdata = AxiSlaveReadData(*channel_args)

        self.seq = Seq(m, name, clk, rst)

        # Counters appended by later requests (both start empty).
        self.write_counters = []
        self.read_counters = []

        self.df = (None if nodataflow
                   else DataflowManager(self.m, self.clk, self.rst))

        # Neither direction is disabled initially.
        self._write_disabled = False
        self._read_disabled = False
Example #3
0
    def __init__(self,
                 m,
                 name,
                 clk,
                 rst,
                 datawidth=32,
                 addrwidth=10,
                 numports=1,
                 nodataflow=False):
        """Instantiate a RAM with ``numports`` interfaces inside module ``m``.

        Args:
            m: Module that receives the interfaces and the RAM instance.
            name: Base name; port ``i`` is named ``name + '_<i>'`` and the
                instance ``'inst_' + name``.
            clk: Clock handed to the ``Seq`` and the dataflow manager.
            rst: Reset handed to the ``Seq`` and the dataflow manager.
            datawidth: Data width forwarded to the interfaces and definition.
            addrwidth: Address width forwarded to the interfaces and definition.
            numports: Number of port interfaces to create.
            nodataflow: When true, no ``DataflowManager`` is created and
                ``self.df`` is ``None``.
        """
        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.datawidth = datawidth
        self.addrwidth = addrwidth

        # One interface object per port, created in port order.
        self.interfaces = []
        for port in range(numports):
            self.interfaces.append(
                RAMInterface(m, name + '_%d' % port, datawidth, addrwidth))

        self.definition = mkRAMDefinition(name, datawidth, addrwidth, numports)
        port_map = m.connect_ports(self.definition)
        self.inst = self.m.Instance(self.definition,
                                    'inst_' + name,
                                    ports=port_map)

        self.seq = Seq(m, name, clk, rst)

        if not nodataflow:
            self.df = DataflowManager(self.m, self.clk, self.rst)
        else:
            self.df = None

        # Per-port flag; no port is write-disabled initially.
        self._write_disabled = [False] * numports
Example #4
0
    def __init__(self, *nodes, **opts):
        """Set up a stream from its nodes and keyword options.

        Recognised options (all optional): ``module``, ``clock``, ``reset``,
        ``ivalid``, ``iready``, ``ovalid``, ``oready`` (default ``None``),
        ``aswire`` (default ``True``), ``dump`` (default ``False``),
        ``dump_base`` (default ``10``), ``dump_mode`` (default ``'all'``),
        ``no_hook`` (default ``False``) and ``seq_name``.

        A ``Seq`` is only created when module, clock and reset are all given;
        in that case ``self.implement`` is also registered as a module hook
        unless ``no_hook`` is set.
        """
        # ID for manager reuse and merge
        global _stream_counter
        self.object_id = _stream_counter
        _stream_counter += 1

        self.nodes = set()
        self.named_numerics = OrderedDict()

        self.add(*nodes)

        self.max_stage = 0
        self.last_input = None
        self.last_output = None

        self.module = opts.get('module')
        self.clock = opts.get('clock')
        self.reset = opts.get('reset')

        self.ivalid = opts.get('ivalid')
        self.iready = opts.get('iready')
        self.ovalid = opts.get('ovalid')
        self.oready = opts.get('oready')

        self.aswire = opts.get('aswire', True)
        self.dump = opts.get('dump', False)
        self.dump_base = opts.get('dump_base', 10)
        self.dump_mode = opts.get('dump_mode', 'all')

        self.seq = None
        self.has_control = False

        self.implemented = False

        has_context = not (self.module is None or
                           self.clock is None or
                           self.reset is None)
        if has_context:
            if not opts.get('no_hook', False):
                self.module.add_hook(self.implement)

            if 'seq_name' in opts:
                seq_name = opts['seq_name']
            else:
                seq_name = '_stream_seq_%d' % self.object_id
            self.seq = Seq(self.module, seq_name, self.clock, self.reset)

        if self.dump:
            # The dump registers are created on self.module unconditionally,
            # so dump=True requires a module to have been supplied.
            dump_regs = [
                self.module.Reg('_stream_dump_enable_%d' % self.object_id,
                                initval=0),
                self.module.Reg('_stream_dump_mask_%d' % self.object_id,
                                initval=0),
                self.module.Reg('_stream_dump_step_%d' % self.object_id,
                                32, initval=0),
            ]
            self.dump_enable, self.dump_mask, self.dump_step = dump_regs

            if self.seq:
                # NOTE(review): dump_step is not passed to add_reset here --
                # confirm whether that is intentional.
                for reg in (self.dump_enable, self.dump_mask):
                    self.seq.add_reset(reg)
Example #5
0
    def __init__(self,
                 m,
                 name,
                 clk,
                 rst,
                 width=32,
                 initname='init',
                 nohook=False,
                 as_module=False):
        """Create the state register and empty bookkeeping of a state machine.

        Args:
            m: Module the state register is declared in.
            name: Name of the state register; also the prefix of the initial
                mark and of the internal ``Seq`` (``name + '_par'``).
            clk: Clock passed to the internal ``Seq``.
            rst: Reset passed to the internal ``Seq``.
            width: Width of the state register.
            initname: Suffix of the mark registered for state index 0.
            nohook: When false, ``self.implement`` is registered as a hook
                on ``m``.
            as_module: Stored unchanged in ``self.as_module``.
        """
        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.width = width
        self.state_count = 0

        # The register's initial value is only known once mark 0 exists.
        self.state = self.m.Reg(name, width)

        self.mark = collections.OrderedDict()  # key:index
        init_label = self.name + '_' + initname
        self._set_mark(0, init_label)
        self.state.initval = self._get_mark(0)

        # Per-state statement lists, created on first access.
        self.body = collections.defaultdict(list)
        self.jump = collections.defaultdict(list)

        self.delay_amount = 0
        self.delayed_state = collections.OrderedDict()  # key:delay
        nested_list_map = functools.partial(collections.defaultdict, list)
        self.delayed_body = collections.defaultdict(nested_list_map)  # key:delay
        self.delayed_cond = collections.OrderedDict()  # key:name
        self.tmp_count = 0

        self.dst_var = collections.OrderedDict()
        self.dst_visitor = SubstDstVisitor()
        self.reset_visitor = ResetVisitor()

        # Companion Seq; it is created with nohook=True.
        self.seq = Seq(self.m, self.name + '_par', clk, rst, nohook=True)

        self.done = False

        # State of the most recent conditional construct.
        self.last_cond = []
        self.last_kwargs = {}
        self.last_if_statement = None
        self.elif_cond = None
        self.next_kwargs = {}

        self.as_module = as_module

        if nohook:
            return
        self.m.add_hook(self.implement)
Example #6
0
    def __init__(self, m, name, clk, rst, width=32):
        """Initialise the bookkeeping, ``Seq`` and data visitor.

        Args:
            m: Module passed to the ``Seq``.
            name: Name of this object and of its ``Seq``.
            clk: Clock passed to the ``Seq``.
            rst: Reset passed to the ``Seq``.
            width: Stored unchanged in ``self.width``.
        """
        self.m, self.name = m, name
        self.clk, self.rst = clk, rst
        self.width = width

        # Counters and the variable list start out empty.
        self.tmp_count = 0
        self.max_stage_id = 0
        self.vars = []

        # The Seq is created with nohook=True.
        self.seq = Seq(self.m, self.name, clk, rst, nohook=True)
        # The visitor receives this (partially initialised) object.
        self.data_visitor = DataVisitor(self)

        self.done = False
Example #7
0
def mkFifoDefinition(name, datawidth=32, addrwidth=4):
    """Build and return the module definition of a FIFO.

    The module has ``CLK``/``RST`` inputs, a write and a read slave
    interface, a memory of ``2 ** addrwidth`` words and two pointers:
    ``head`` (advanced on an accepted enqueue) and ``tail`` (advanced on an
    accepted dequeue). The FIFO reports empty when ``head == tail`` and full
    when ``head + 1`` (masked to ``addrwidth`` bits) equals ``tail``.

    Args:
        name: Module name, also passed to the two interfaces.
        datawidth: Width of the memory words and of the read-data register.
        addrwidth: Width of the pointers.

    Returns:
        The constructed module.
    """
    m = module.Module(name)
    clk = m.Input('CLK')
    rst = m.Input('RST')

    wif = FifoWriteSlaveInterface(m, name, datawidth)
    rif = FifoReadSlaveInterface(m, name, datawidth)

    depth = 2 ** addrwidth
    mem = m.Reg('mem', datawidth, length=depth)
    head = m.Reg('head', addrwidth, initval=0)
    tail = m.Reg('tail', addrwidth, initval=0)

    is_empty = m.Wire('is_empty')
    is_almost_empty = m.Wire('is_almost_empty')
    is_full = m.Wire('is_full')
    is_almost_full = m.Wire('is_almost_full')

    ptr_mask = (2 ** addrwidth) - 1

    def advanced(pointer, step):
        # Pointer moved forward by `step`, masked to the pointer width.
        return (pointer + step) & ptr_mask

    is_empty.assign(head == tail)
    is_almost_empty.assign(head == advanced(tail, 1))
    is_full.assign(advanced(head, 1) == tail)
    is_almost_full.assign(advanced(head, 2) == tail)

    rdata = m.Reg('rdata_reg', datawidth, initval=0)

    # The "almost" flags are also raised in the corresponding extreme state.
    wif.full.assign(is_full)
    wif.almost_full.assign(vtypes.Ors(is_almost_full, is_full))
    rif.empty.assign(is_empty)
    rif.almost_empty.assign(vtypes.Ors(is_almost_empty, is_empty))

    seq = Seq(m, '', clk, rst)

    # Enqueue: store at head and advance it, unless the FIFO is full.
    enq_accepted = vtypes.Ands(wif.enq, vtypes.Not(is_full))
    seq.If(enq_accepted)(
        mem[head](wif.wdata),
        head.inc()
    )

    # Dequeue: latch the word at tail and advance it, unless empty.
    deq_accepted = vtypes.Ands(rif.deq, vtypes.Not(is_empty))
    seq.If(deq_accepted)(
        rdata(mem[tail]),
        tail.inc()
    )

    rif.rdata.assign(rdata)

    seq.make_always()

    return m
Example #8
0
    def __init__(self, *nodes, **opts):
        """Set up a dataflow from its nodes and keyword options.

        Recognised options (all optional): ``module``, ``clock``, ``reset``
        (default ``None``), ``aswire`` (default ``True``), ``no_hook``
        (default ``False``) and ``seq_name``.

        A ``Seq`` is only created when module, clock and reset are all given;
        in that case ``self.implement`` is also registered as a module hook
        unless ``no_hook`` is set.
        """
        # ID for manager reuse and merge
        global _dataflow_counter
        self.object_id = _dataflow_counter
        _dataflow_counter += 1

        self.nodes = set(nodes)
        self.max_stage = 0
        self.last_input = None
        self.last_output = None

        self.module = opts.get('module')
        self.clock = opts.get('clock')
        self.reset = opts.get('reset')

        self.aswire = opts.get('aswire', True)

        self.seq = None

        missing_context = (self.module is None or
                           self.clock is None or
                           self.reset is None)
        if not missing_context:
            if not opts.get('no_hook', False):
                self.module.add_hook(self.implement)

            if 'seq_name' in opts:
                seq_name = opts['seq_name']
            else:
                seq_name = '_dataflow_seq_%d' % self.object_id
            self.seq = Seq(self.module, seq_name, self.clock, self.reset)
Example #9
0
    def __init__(self, m, name, clk, rst, width=32):
        """Create the ``Seq`` and the two lock registers.

        Args:
            m: Module the registers are declared in.
            name: Base name; the registers are called ``_<name>_lock_reg``
                and ``_<name>_lock_id``.
            clk: Clock passed to the ``Seq``.
            rst: Reset passed to the ``Seq``.
            width: Width of the ``lock_id`` register.
        """
        self.m, self.name = m, name
        self.clk, self.rst = clk, rst
        self.width = width

        self.seq = Seq(self.m, self.name, self.clk, self.rst)

        # Register names share the '_<name>' prefix.
        prefix = '_'.join(['', self.name])
        self.lock_reg = self.m.Reg('_'.join([prefix, 'lock_reg']), initval=0)
        self.lock_id = self.m.Reg('_'.join([prefix, 'lock_id']),
                                  self.width, initval=0)

        # Mapping and running count start out empty.
        self.id_map = {}
        self.id_map_count = 0
Example #10
0
    def __init__(self,
                 m,
                 clk,
                 rst,
                 thread_name,
                 thread_id,
                 id,
                 sub_id=None,
                 datawidth=32,
                 addrwidth=10,
                 numports=1):
        """Instantiate a CoRAM memory with ``numports`` interfaces.

        Args:
            m: Module that receives the interfaces and the instance.
            clk: Clock connected to the instance's ``CLK`` port.
            rst: Reset passed to the ``Seq``.
            thread_name: Value of the ``CORAM_THREAD_NAME`` parameter; also
                the prefix of the object name.
            thread_id: Value of the ``CORAM_THREAD_ID`` parameter.
            id: Value of the ``CORAM_ID`` parameter.
            sub_id: Value of ``CORAM_SUB_ID``; the parameter is omitted when
                this is ``None``.
            datawidth: Value of ``CORAM_DATA_WIDTH``.
            addrwidth: Value of ``CORAM_ADDR_LEN``.
            numports: Number of interfaces to create.
        """
        self.m = m
        self.clk = clk
        self.rst = rst
        self.thread_name = thread_name
        self.thread_id = thread_id
        self.id = id
        self.sub_id = sub_id
        self.name = '_'.join([thread_name, 'memory', str(id)])
        self.datawidth = datawidth
        self.addrwidth = addrwidth

        # Interfaces only carry an index when there is more than one port.
        self.interfaces = [
            CoramMemoryInterface(m,
                                 name=self.name,
                                 datawidth=datawidth,
                                 addrwidth=addrwidth,
                                 index=(port if numports > 1 else None))
            for port in range(numports)
        ]

        self.definition = get_coram_memory_definition(numports)

        params = [
            ('CORAM_THREAD_NAME', self.thread_name),
            ('CORAM_THREAD_ID', self.thread_id),
            ('CORAM_ID', self.id),
        ]
        if self.sub_id is not None:
            params.append(('CORAM_SUB_ID', self.sub_id))
        params += [
            ('CORAM_ADDR_LEN', self.addrwidth),
            ('CORAM_DATA_WIDTH', self.datawidth),
        ]

        # CLK is wired explicitly; the remaining ports come from connect_ports.
        ports = [('CLK', self.clk)]
        ports.extend(m.connect_ports(self.definition, prefix=self.name + '_'))

        self.inst = self.m.Instance(self.definition, 'inst_' + self.name,
                                    params, ports)

        self.seq = Seq(m, self.name, clk, rst)

        # Per-port flag; no port is write-disabled initially.
        self._write_disabled = [False] * numports
Example #11
0
    def __init__(self, m, name, clk, rst, func):
        """Store the construction arguments and create the helper objects.

        Args:
            m: Module passed to the dataflow manager and the ``Seq``.
            name: Name of this object; also used as the name of its ``Seq``.
            clk: Clock passed to the dataflow manager and the ``Seq``.
            rst: Reset passed to the dataflow manager and the ``Seq``.
            func: Stored unchanged in ``self.func``.
        """
        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.func = func

        self.df = DataflowManager(self.m, self.clk, self.rst)
        # Use the instance name rather than the literal string 'name', so
        # that each object gets its own Seq name, as the other constructors
        # that build a Seq from a name argument do.
        self.seq = Seq(self.m, self.name, self.clk, self.rst)
        self.start_cond = None
        self.done_flags = []
Example #12
0
    def __init__(self, m, name, clk, rst, datawidth=32, addrwidth=4):
        """Instantiate a FIFO in ``m`` and track its number of entries.

        Args:
            m: Module that receives the interfaces, instance and counter.
            name: Base name of the interfaces, the instance
                (``'inst_' + name``) and the counter (``'count_' + name``).
            clk: Clock passed to the ``Seq``.
            rst: Reset passed to the ``Seq``.
            datawidth: Data width of the interfaces and the FIFO definition.
            addrwidth: Address width of the FIFO definition; the entry
                counter is ``addrwidth + 1`` bits wide.
        """
        self.m, self.name = m, name
        self.clk, self.rst = clk, rst

        self.datawidth = datawidth
        self.addrwidth = addrwidth

        self.wif = FifoWriteInterface(self.m, name, datawidth)
        self.rif = FifoReadInterface(self.m, name, datawidth)

        self.definition = mkFifoDefinition(name, datawidth, addrwidth)

        port_map = m.connect_ports(self.definition)
        self.inst = self.m.Instance(self.definition, 'inst_' + name,
                                    ports=port_map)

        self.seq = Seq(m, name, clk, rst)

        # entry counter
        if isinstance(self.addrwidth, int):
            self._max_size = 2 ** self.addrwidth - 1
        else:
            # Non-int widths are combined into an expression instead.
            self._max_size = vtypes.Int(2) ** self.addrwidth - 1

        self._count = self.m.Reg(
            'count_' + name, self.addrwidth + 1, initval=0)

        enq_accepted = vtypes.Ands(self.wif.enq, vtypes.Not(self.wif.full))
        deq_accepted = vtypes.Ands(self.rif.deq, vtypes.Not(self.rif.empty))

        # Simultaneous enqueue and dequeue keeps the count; otherwise it
        # moves by one in the direction of the accepted operation.
        self.seq.If(vtypes.Ands(enq_accepted, deq_accepted))(
            self._count(self._count)
        ).Elif(enq_accepted)(
            self._count.inc()
        ).Elif(deq_accepted)(
            self._count.dec()
        )

        self._enq_disabled = False
        self._deq_disabled = False

        self.mutex = None
Example #13
0
class Barrier(object):
    """Synchronisation point counted in hardware.

    Each ``wait`` increments ``count`` under a mutex and then holds the
    calling FSM until ``done`` is set. ``done`` is driven to 0 by default
    and to 1 (with ``count`` cleared) when ``count == numparties``.
    """
    __intrinsics__ = ('wait', )

    def __init__(self, m, name, clk, rst, numparties):
        """Create the counter, done flag, mutex and reset logic.

        Args:
            m: Module the registers are declared in.
            name: Base name of the ``Seq`` and of the ``_<name>_barrier_*``
                signals.
            clk: Clock passed to the ``Seq`` and the mutex.
            rst: Reset passed to the ``Seq`` and the mutex.
            numparties: Count value at which ``done`` is raised.
        """
        self.m, self.name = m, name
        self.clk, self.rst = clk, rst
        self.numparties = numparties

        # Counter width: ceil(log2(numparties)) + 1 bits.
        log2_parties = math.log(self.numparties, 2)
        self.width = int(math.ceil(log2_parties)) + 1

        self.seq = Seq(self.m, self.name, self.clk, self.rst)

        prefix = '_'.join(['', self.name])
        self.count = self.m.Reg('_'.join([prefix, 'barrier_count']),
                                self.width, initval=0)
        self.done = self.m.Reg('_'.join([prefix, 'barrier_done']), initval=0)
        self.mutex = Mutex(self.m, '_'.join([prefix, 'barrier_mutex']),
                           self.clk, self.rst)

        # reset condition
        self.seq(self.done(0))
        all_arrived = self.count == self.numparties
        self.seq.If(all_arrived)(self.count(0), self.done(1))

    def wait(self, fsm):
        """Register one arrival from ``fsm`` and block it until ``done``.

        The counter increment is guarded by ``fsm.state == fsm.current`` and
        is issued between ``mutex.lock`` and ``mutex.unlock``.

        Returns:
            Always 0.
        """
        self.mutex.lock(fsm)

        arriving = fsm.state == fsm.current
        self.seq.If(arriving)(self.count.inc())
        fsm.goto_next()

        self.mutex.unlock(fsm)

        # Only leave this state once the barrier reports completion.
        release = fsm.If(self.done)
        release.goto_next()

        return 0
Example #14
0
    def __init__(self, m, name, clk, rst, numparties):
        """Create the counter, done flag, mutex and reset logic of a barrier.

        Args:
            m: Module the registers are declared in.
            name: Base name of the ``Seq`` and of the ``_<name>_barrier_*``
                signals.
            clk: Clock passed to the ``Seq`` and the mutex.
            rst: Reset passed to the ``Seq`` and the mutex.
            numparties: Count value at which ``done`` is raised.
        """
        self.m, self.name = m, name
        self.clk, self.rst = clk, rst
        self.numparties = numparties

        # Counter width: ceil(log2(numparties)) + 1 bits.
        log2_parties = math.log(self.numparties, 2)
        self.width = int(math.ceil(log2_parties)) + 1

        self.seq = Seq(self.m, self.name, self.clk, self.rst)

        prefix = '_'.join(['', self.name])
        self.count = self.m.Reg('_'.join([prefix, 'barrier_count']),
                                self.width, initval=0)
        self.done = self.m.Reg('_'.join([prefix, 'barrier_done']), initval=0)
        self.mutex = Mutex(self.m, '_'.join([prefix, 'barrier_mutex']),
                           self.clk, self.rst)

        # reset condition: done defaults to 0 and is raised, with the
        # counter cleared, once the counter equals numparties.
        self.seq(self.done(0))
        all_arrived = self.count == self.numparties
        self.seq.If(all_arrived)(self.count(0), self.done(1))
Example #15
0
    def __init__(self,
                 m,
                 clk,
                 rst,
                 thread_name,
                 thread_id,
                 id,
                 sub_id=None,
                 datawidth=32,
                 addrwidth=4):
        """Instantiate a CoRAM channel with its write and read interfaces.

        Args:
            m: Module that receives the interfaces and the instance.
            clk: Clock connected to the instance's ``CLK`` port.
            rst: Reset connected to the instance's ``RST`` port.
            thread_name: Value of the ``CORAM_THREAD_NAME`` parameter; also
                the prefix of the object name.
            thread_id: Value of the ``CORAM_THREAD_ID`` parameter.
            id: Value of the ``CORAM_ID`` parameter.
            sub_id: Value of ``CORAM_SUB_ID``; the parameter is omitted when
                this is ``None``.
            datawidth: Value of ``CORAM_DATA_WIDTH``.
            addrwidth: Value of ``CORAM_ADDR_LEN``.
        """
        self.m = m
        self.clk = clk
        self.rst = rst
        self.thread_name = thread_name
        self.thread_id = thread_id
        self.id = id
        self.sub_id = sub_id
        self.name = '_'.join([thread_name, 'channel', str(id)])
        self.datawidth = datawidth
        self.addrwidth = addrwidth
        self.wif = CoramChannelWriteInterface(self.m, self.name, datawidth)
        self.rif = CoramChannelReadInterface(self.m, self.name, datawidth)
        self.definition = get_coram_channel_definition()

        params = [
            ('CORAM_THREAD_NAME', self.thread_name),
            ('CORAM_THREAD_ID', self.thread_id),
            ('CORAM_ID', self.id),
        ]
        if self.sub_id is not None:
            params.append(('CORAM_SUB_ID', self.sub_id))
        params += [
            ('CORAM_ADDR_LEN', self.addrwidth),
            ('CORAM_DATA_WIDTH', self.datawidth),
        ]

        # CLK and RST are wired explicitly; the rest come from connect_ports.
        ports = [('CLK', self.clk), ('RST', self.rst)]
        ports.extend(m.connect_ports(self.definition, prefix=self.name + '_'))

        self.inst = self.m.Instance(self.definition, 'inst_' + self.name,
                                    params, ports)

        self.seq = Seq(m, self.name, clk, rst)

        if isinstance(self.addrwidth, int):
            self._max_size = 2**self.addrwidth - 1
        else:
            # Non-int widths are combined into an expression instead.
            self._max_size = vtypes.Int(2)**self.addrwidth - 1

        self._count = 0

        self._enq_disabled = False
        self._deq_disabled = False
Example #16
0
class Shared(_MutexFunction):
    """Wrapper around a variable that is written through a lazily built Seq.

    ``read`` returns the wrapped variable itself; ``write`` registers a
    guarded assignment on the ``Seq``. The ``Seq`` and the mutex are created
    on first use from attributes of the FSM that triggers them.
    """
    __intrinsics__ = ('read', 'write') + _MutexFunction.__intrinsics__

    def __init__(self, value):
        """Wrap ``value``; the Seq and the mutex do not exist yet."""
        self.mutex = None
        self.seq = None
        self._value = value

    @property
    def value(self):
        """The wrapped variable."""
        return self._value

    def read(self, fsm):
        """Return the wrapped variable; ``fsm`` is not used."""
        return self._value

    def write(self, fsm, value, *part):
        """Register an assignment of ``value`` to the wrapped variable.

        The assignment is guarded by ``fsm.state == fsm.current`` and the
        FSM is advanced with ``goto_next()`` afterwards.

        Args:
            fsm: State machine issuing the write.
            value: Right-hand side handed to the assignment target.
            *part: Zero, one or two selectors applied in turn to the wrapped
                variable. A tuple/list selector ``p`` is applied as
                ``v[p[0], p[1]]``, any other selector as ``v[p]``.

        Returns:
            Always 0.

        Raises:
            TypeError: If more than two selectors are given.
        """
        if self.seq is None:
            # First write: build the Seq from the FSM's module/clock/reset.
            self.seq = Seq(fsm.m, '_'.join(['seq', self._value.name]),
                           fsm.clk, fsm.rst)

        in_this_state = fsm.state == fsm.current

        def select(base, selector):
            if not isinstance(selector, (tuple, list)):
                return base[selector]
            return base[selector[0], selector[1]]

        if len(part) > 2:
            raise TypeError('unsupported type')

        target = self._value
        for selector in part:
            target = select(target, selector)

        self.seq.If(in_this_state)(target(value))

        fsm.goto_next()
        return 0

    def _check_mutex(self, fsm):
        """Create the mutex on first use, named ``_<variable name>_mutex``."""
        if self.mutex is not None:
            return
        mutex_name = '_'.join(['', self._value.name, 'mutex'])
        self.mutex = Mutex(fsm.m, mutex_name, fsm.clk, fsm.rst)
Example #17
0
    def __init__(self, m, name, clk, rst, width=32, initname='init',
                 nohook=False, as_module=False):
        """Create the state register and empty bookkeeping of a state machine.

        Args:
            m: Module the state register is declared in.
            name: Name of the state register; also the prefix of the initial
                mark and of the internal ``Seq`` (``name + '_par'``).
            clk: Clock passed to the internal ``Seq``.
            rst: Reset passed to the internal ``Seq``.
            width: Width of the state register.
            initname: Suffix of the mark registered for state index 0.
            nohook: When false, ``self.implement`` is registered as a hook
                on ``m``.
            as_module: Stored unchanged in ``self.as_module``.
        """
        self.m, self.name = m, name
        self.clk, self.rst = clk, rst
        self.width = width
        self.state_count = 0

        # The register's initial value is only known once mark 0 exists.
        self.state = self.m.Reg(name, width)

        self.mark = collections.OrderedDict()  # key:index
        init_label = self.name + '_' + initname
        self._set_mark(0, init_label)
        self.state.initval = self._get_mark(0)

        # Per-state statement lists, created on first access.
        self.body = collections.defaultdict(list)
        self.jump = collections.defaultdict(list)

        self.delay_amount = 0
        self.delayed_state = collections.OrderedDict()  # key:delay
        nested_list_map = functools.partial(collections.defaultdict, list)
        self.delayed_body = collections.defaultdict(nested_list_map)  # key:delay
        self.delayed_cond = collections.OrderedDict()  # key:name
        self.tmp_count = 0

        self.dst_var = collections.OrderedDict()
        self.dst_visitor = SubstDstVisitor()
        self.reset_visitor = ResetVisitor()

        # Companion Seq; it is created with nohook=True.
        self.seq = Seq(self.m, self.name + '_par', clk, rst, nohook=True)

        self.done = False

        # State of the most recent conditional construct.
        self.last_cond = []
        self.last_kwargs = {}
        self.last_if_statement = None
        self.elif_cond = None
        self.next_kwargs = {}

        self.as_module = as_module

        if nohook:
            return
        self.m.add_hook(self.implement)
Example #18
0
    def __init__(self, *nodes, **opts):
        """Set up a stream from its nodes and keyword options.

        Recognised options (all optional): ``module``, ``clock``, ``reset``,
        ``ivalid``, ``iready``, ``ovalid``, ``oready`` (default ``None``),
        ``aswire`` (default ``True``), ``no_hook`` (default ``False``) and
        ``seq_name``.

        A ``Seq`` is only created when module, clock and reset are all given;
        in that case ``self.implement`` is also registered as a module hook
        unless ``no_hook`` is set.
        """
        # ID for manager reuse and merge
        global _stream_counter
        self.object_id = _stream_counter
        _stream_counter += 1

        self.nodes = set()
        self.named_numerics = OrderedDict()

        self.add(*nodes)

        self.max_stage = 0
        self.last_input = None
        self.last_output = None

        self.module = opts.get('module')
        self.clock = opts.get('clock')
        self.reset = opts.get('reset')

        self.ivalid = opts.get('ivalid')
        self.iready = opts.get('iready')
        self.ovalid = opts.get('ovalid')
        self.oready = opts.get('oready')

        self.aswire = opts.get('aswire', True)

        self.seq = None
        self.has_control = False

        self.implemented = False

        missing_context = (self.module is None or
                           self.clock is None or
                           self.reset is None)
        if not missing_context:
            if not opts.get('no_hook', False):
                self.module.add_hook(self.implement)

            if 'seq_name' in opts:
                seq_name = opts['seq_name']
            else:
                seq_name = '_stream_seq_%d' % self.object_id
            self.seq = Seq(self.module, seq_name, self.clock, self.reset)
Example #19
0
def mkFifoDefinition(name, datawidth=32, addrwidth=4):
    """Build and return the module definition of a FIFO.

    The module has ``CLK``/``RST`` inputs, a write and a read slave
    interface, a ``mem`` register array and two pointers: ``head``
    (advanced on an accepted enqueue) and ``tail`` (advanced on an accepted
    dequeue). The FIFO reports empty when ``head == tail`` and full when
    ``head + 1`` (masked to ``addrwidth`` bits) equals ``tail``.

    Args:
        name: Module name, also passed to the two interfaces.
        datawidth: Width of the memory words and of the read-data register.
        addrwidth: Width of the pointers; ``2**addrwidth`` is passed as the
            third argument of the ``mem`` declaration.

    Returns:
        The constructed module.
    """
    m = module.Module(name)
    clk = m.Input('CLK')
    rst = m.Input('RST')

    wif = FifoWriteSlaveInterface(m, name, datawidth)
    rif = FifoReadSlaveInterface(m, name, datawidth)

    mem = m.Reg('mem', datawidth, 2**addrwidth)
    head = m.Reg('head', addrwidth, initval=0)
    tail = m.Reg('tail', addrwidth, initval=0)

    is_empty = m.Wire('is_empty')
    is_almost_empty = m.Wire('is_almost_empty')
    is_full = m.Wire('is_full')
    is_almost_full = m.Wire('is_almost_full')

    ptr_mask = (2**addrwidth) - 1

    def advanced(pointer, step):
        # Pointer moved forward by `step`, masked to the pointer width.
        return (pointer + step) & ptr_mask

    is_empty.assign(head == tail)
    is_almost_empty.assign(head == advanced(tail, 1))
    is_full.assign(advanced(head, 1) == tail)
    is_almost_full.assign(advanced(head, 2) == tail)

    rdata = m.Reg('rdata_reg', datawidth, initval=0)

    # The "almost" flags are also raised in the corresponding extreme state.
    wif.full.assign(is_full)
    wif.almost_full.assign(vtypes.Ors(is_almost_full, is_full))
    rif.empty.assign(is_empty)
    rif.almost_empty.assign(vtypes.Ors(is_almost_empty, is_empty))

    seq = Seq(m, '', clk, rst)

    # Enqueue: store at head and advance it, unless the FIFO is full.
    enq_accepted = vtypes.Ands(wif.enq, vtypes.Not(is_full))
    seq.If(enq_accepted)(mem[head](wif.wdata), head.inc())

    # Dequeue: latch the word at tail and advance it, unless empty.
    deq_accepted = vtypes.Ands(rif.deq, vtypes.Not(is_empty))
    seq.If(deq_accepted)(rdata(mem[tail]), tail.inc())

    rif.rdata.assign(rdata)

    seq.make_always()

    return m
Example #20
0
class AxiMaster(object):
    """AXI master interface with helpers that generate burst transfers.

    The object owns the four channel interfaces (write address, write data,
    read address, read data), a ``Seq`` on which all generated statements
    are registered and, unless ``nodataflow`` is set, a ``DataflowManager``.
    """

    # Width of the burst-length fields. Requests send ``length - 1`` on
    # awlen/arlen and accept lengths from 1 to 2 ** burst_size_width.
    burst_size_width = 8

    def __init__(self, m, name, clk, rst, datawidth=32, addrwidth=32,
                 nodataflow=False):
        """Create the channel interfaces and bookkeeping.

        Args:
            m: Module the interfaces and temporaries are created in.
            name: Name passed to every channel interface and to the ``Seq``.
            clk: Clock handed to the ``Seq`` and the dataflow manager.
            rst: Reset handed to the ``Seq`` and the dataflow manager.
            datawidth: Data width forwarded to the channel interfaces.
            addrwidth: Address width forwarded to the channel interfaces.
            nodataflow: When true, no ``DataflowManager`` is created and
                ``self.df`` is ``None``.
        """

        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.datawidth = datawidth
        self.addrwidth = addrwidth

        self.waddr = AxiMasterWriteAddress(m, name, datawidth, addrwidth)
        self.wdata = AxiMasterWriteData(m, name, datawidth, addrwidth)
        self.raddr = AxiMasterReadAddress(m, name, datawidth, addrwidth)
        self.rdata = AxiMasterReadData(m, name, datawidth, addrwidth)

        self.seq = Seq(m, name, clk, rst)

        # Counters of issued requests; the data methods default to the
        # most recently appended one.
        self.write_counters = []
        self.read_counters = []

        if nodataflow:
            self.df = None
        else:
            self.df = DataflowManager(self.m, self.clk, self.rst)

        self._write_disabled = False
        self._read_disabled = False

    def _new_burst_counter(self):
        """Return a fresh beat counter able to hold the largest burst.

        Requests accept lengths up to ``2 ** burst_size_width`` inclusive,
        so the counter needs ``burst_size_width + 1`` bits; a counter of
        only ``burst_size_width`` bits cannot represent a maximal burst.
        """
        return self.m.TmpReg(self.burst_size_width + 1, initval=0)

    def disable_write(self):
        """Drive all write-channel outputs to 0 and forbid further writes."""
        self.seq(
            self.waddr.awaddr(0),
            self.waddr.awlen(0),
            self.waddr.awvalid(0),
            self.wdata.wdata(0),
            self.wdata.wstrb(0),
            self.wdata.wvalid(0),
            self.wdata.wlast(0)
        )
        self._write_disabled = True

    def disable_read(self):
        """Drive all read-channel outputs to 0 and forbid further reads."""
        self.seq(
            self.raddr.araddr(0),
            self.raddr.arlen(0),
            self.raddr.arvalid(0)
        )
        self.rdata.rready.assign(0)
        self._read_disabled = True

    def write_request(self, addr, length=1, cond=None, counter=None):
        """Issue a write-address request for a burst of ``length`` beats.

        Args:
            addr: Value driven on ``awaddr``.
            length: Number of beats; ``awlen`` is driven with ``length - 1``.
                Integer values must lie in ``1 .. 2 ** burst_size_width``.
            cond: Optional condition registered on the ``Seq`` before the
                request statements.
            counter: Optional ``vtypes.Reg`` used as the beat counter; a new
                temporary register is created when omitted.

        Returns:
            Tuple ``(ack, counter)``.

        Raises:
            TypeError: If writes are disabled or ``counter`` is not a Reg.
            ValueError: If an integer ``length`` is out of range.
        """

        if self._write_disabled:
            raise TypeError('Write disabled.')

        if isinstance(length, int) and length > 2 ** self.burst_size_width:
            raise ValueError("length must be less than 257.")

        if isinstance(length, int) and length < 1:
            raise ValueError("length must be more than 0.")

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if cond is not None:
            self.seq.If(cond)

        # The request can be accepted when the slave is ready or no request
        # is currently pending.
        ack = vtypes.Ors(self.waddr.awready, vtypes.Not(self.waddr.awvalid))

        if counter is None:
            counter = self._new_burst_counter()

        self.write_counters.append(counter)

        self.seq.If(vtypes.Ands(ack, counter == 0))(
            self.waddr.awaddr(addr),
            self.waddr.awlen(length - 1),
            self.waddr.awvalid(1),
            counter(length)
        )
        self.seq.Then().If(length == 0)(
            self.waddr.awvalid(0)
        )

        # de-assert
        self.seq.Delay(1)(
            self.waddr.awvalid(0)
        )

        # retry
        self.seq.If(vtypes.Ands(self.waddr.awvalid, vtypes.Not(self.waddr.awready)))(
            self.waddr.awvalid(self.waddr.awvalid)
        )

        return ack, counter

    def write_data(self, data, counter=None, cond=None):
        """Send one beat of write data while ``counter`` is positive.

        Args:
            data: Value driven on ``wdata``.
            counter: Beat counter (``vtypes.Reg``); defaults to the counter
                of the most recent write request.
            cond: Optional condition registered on the ``Seq`` before the
                data statements.

        Returns:
            Tuple ``(ack, last)`` where ``last`` is a temporary register set
            together with ``wlast`` when ``counter == 1``.

        Raises:
            TypeError: If writes are disabled or ``counter`` is not a Reg.
        """

        if self._write_disabled:
            raise TypeError('Write disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.write_counters[-1]

        if cond is not None:
            self.seq.If(cond)

        ack = vtypes.Ands(counter > 0,
                          vtypes.Ors(self.wdata.wready, vtypes.Not(self.wdata.wvalid)))
        last = self.m.TmpReg(initval=0)

        self.seq.If(vtypes.Ands(ack, counter > 0))(
            self.wdata.wdata(data),
            self.wdata.wvalid(1),
            self.wdata.wlast(0),
            # All-ones strobe: one bit per byte of the data bus.
            self.wdata.wstrb(vtypes.Repeat(
                vtypes.Int(1, 1), (self.wdata.datawidth // 8))),
            counter.dec()
        )
        self.seq.Then().If(counter == 1)(
            self.wdata.wlast(1),
            last(1)
        )

        # de-assert
        self.seq.Delay(1)(
            self.wdata.wvalid(0),
            self.wdata.wlast(0),
            last(0)
        )

        # retry
        self.seq.If(vtypes.Ands(self.wdata.wvalid, vtypes.Not(self.wdata.wready)))(
            self.wdata.wvalid(self.wdata.wvalid),
            self.wdata.wlast(self.wdata.wlast),
            last(last)
        )

        return ack, last

    def write_dataflow(self, data, counter=None, cond=None, when=None):
        """Send write data taken from a dataflow variable.

        Args:
            data: Object whose ``read(cond=...)`` returns ``(data, valid)``.
            counter: Beat counter (``vtypes.Reg``); defaults to the counter
                of the most recent write request.
            cond: Optional extra condition combined with the bus ``ack``
                when reading ``data``.
            when: Optional condition passed to ``make_condition``; when the
                result is not ``None`` it additionally gates the valid
                signal.

        Returns:
            ``done``: the temporary register that is set together with
            ``wlast`` when ``counter == 1``.

        Raises:
            TypeError: If writes are disabled or ``counter`` is not a Reg.
        """

        if self._write_disabled:
            raise TypeError('Write disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.write_counters[-1]

        ack = vtypes.Ands(counter > 0,
                          vtypes.Ors(self.wdata.wready, vtypes.Not(self.wdata.wvalid)))
        last = self.m.TmpReg(initval=0)

        if cond is None:
            cond = ack
        else:
            cond = (cond, ack)

        raw_data, raw_valid = data.read(cond=cond)

        when_cond = make_condition(when, ready=cond)
        if when_cond is not None:
            raw_valid = vtypes.Ands(when_cond, raw_valid)

        # write condition
        self.seq.If(raw_valid)

        self.seq.If(vtypes.Ands(ack, counter > 0))(
            self.wdata.wdata(raw_data),
            self.wdata.wvalid(1),
            self.wdata.wlast(0),
            # All-ones strobe: one bit per byte of the data bus.
            self.wdata.wstrb(vtypes.Repeat(
                vtypes.Int(1, 1), (self.wdata.datawidth // 8))),
            counter.dec()
        )
        self.seq.Then().If(counter == 1)(
            self.wdata.wlast(1),
            last(1)
        )

        # de-assert
        self.seq.Delay(1)(
            self.wdata.wvalid(0),
            self.wdata.wlast(0),
            last(0)
        )

        # retry
        self.seq.If(vtypes.Ands(self.wdata.wvalid, vtypes.Not(self.wdata.wready)))(
            self.wdata.wvalid(self.wdata.wvalid),
            self.wdata.wlast(self.wdata.wlast),
            last(last)
        )

        done = last

        return done

    def read_request(self, addr, length, cond=None, counter=None):
        """Issue a read-address request for a burst of ``length`` beats.

        Args:
            addr: Value driven on ``araddr``.
            length: Number of beats; ``arlen`` is driven with ``length - 1``.
                Integer values must lie in ``1 .. 2 ** burst_size_width``.
            cond: Optional condition registered on the ``Seq`` before the
                request statements.
            counter: Optional ``vtypes.Reg`` used as the beat counter; a new
                temporary register is created when omitted.

        Returns:
            Tuple ``(ack, counter)``.

        Raises:
            TypeError: If reads are disabled or ``counter`` is not a Reg.
            ValueError: If an integer ``length`` is out of range.
        """

        if self._read_disabled:
            raise TypeError('Read disabled.')

        if isinstance(length, int) and length > 2 ** self.burst_size_width:
            raise ValueError("length must be less than 257.")

        if isinstance(length, int) and length < 1:
            raise ValueError("length must be more than 0.")

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if cond is not None:
            self.seq.If(cond)

        # The request can be accepted when the slave is ready or no request
        # is currently pending.
        ack = vtypes.Ors(self.raddr.arready, vtypes.Not(self.raddr.arvalid))

        if counter is None:
            counter = self._new_burst_counter()

        self.read_counters.append(counter)

        self.seq.If(vtypes.Ands(ack, counter == 0))(
            self.raddr.araddr(addr),
            self.raddr.arlen(length - 1),
            self.raddr.arvalid(1),
            counter(length)
        )

        # de-assert
        self.seq.Delay(1)(
            self.raddr.arvalid(0)
        )

        # retry
        self.seq.If(vtypes.Ands(self.raddr.arvalid, vtypes.Not(self.raddr.arready)))(
            self.raddr.arvalid(self.raddr.arvalid)
        )

        return ack, counter

    def read_data(self, counter=None, cond=None):
        """Accept read data and count down the remaining beats.

        ``rready`` is assigned the condition built from ``cond`` (or 1 when
        there is none). If ``rready`` already has an assignment, the new
        condition is OR-ed into its right-hand side instead.

        Args:
            counter: Beat counter (``vtypes.Reg``); defaults to the counter
                of the most recent read request.
            cond: Optional condition passed to ``make_condition``.

        Returns:
            Tuple ``(data, valid, last)`` where ``valid`` is
            ``rready & rvalid``.

        Raises:
            TypeError: If reads are disabled or ``counter`` is not a Reg.
        """

        if self._read_disabled:
            raise TypeError('Read disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.read_counters[-1]

        ready = make_condition(cond)
        val = 1 if ready is None else ready

        prev_subst = self.rdata.rready._get_subst()
        if not prev_subst:
            self.rdata.rready.assign(val)
        else:
            self.rdata.rready.subst[0].overwrite_right(
                vtypes.Ors(prev_subst[0].right, val))

        ack = vtypes.Ands(self.rdata.rready, self.rdata.rvalid)
        data = self.rdata.rdata
        valid = ack
        last = self.rdata.rlast

        self.seq.If(vtypes.Ands(ack, counter > 0))(
            counter.dec()
        )

        return data, valid, last

    def read_dataflow(self, counter=None, cond=None):
        """Expose incoming read data as dataflow variables.

        Two temporary wires (``data_ready``, ``last_ready``), both assigned
        1, are appended to ``cond`` to form the ``rready`` condition, which
        is installed the same way as in ``read_data``.

        Args:
            counter: Beat counter (``vtypes.Reg``); defaults to the counter
                of the most recent read request.
            cond: Optional condition, or tuple/list of conditions.

        Returns:
            Tuple ``(df_data, df_last, done)``: dataflow variables for
            ``rdata`` and ``rlast`` plus ``done``, which is ``rlast``.

        Raises:
            TypeError: If reads are disabled or ``counter`` is not a Reg.
        """

        if self._read_disabled:
            raise TypeError('Read disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.read_counters[-1]

        data_ready = self.m.TmpWire()
        last_ready = self.m.TmpWire()
        data_ready.assign(1)
        last_ready.assign(1)

        if cond is None:
            cond = (data_ready, last_ready)
        elif isinstance(cond, (tuple, list)):
            cond = tuple(list(cond) + [data_ready, last_ready])
        else:
            cond = (cond, data_ready, last_ready)

        ready = make_condition(*cond)
        val = 1 if ready is None else ready

        prev_subst = self.rdata.rready._get_subst()
        if not prev_subst:
            self.rdata.rready.assign(val)
        else:
            self.rdata.rready.subst[0].overwrite_right(
                vtypes.Ors(prev_subst[0].right, val))

        ack = vtypes.Ands(self.rdata.rready, self.rdata.rvalid)
        data = self.rdata.rdata
        valid = self.rdata.rvalid
        last = self.rdata.rlast

        self.seq.If(vtypes.Ands(ack, counter > 0))(
            counter.dec()
        )

        # Fall back to the module-level dataflow API when this object was
        # created without its own manager.
        df = self.df if self.df is not None else dataflow

        df_data = df.Variable(data, valid, data_ready)
        df_last = df.Variable(last, valid, last_ready, width=1)
        done = last

        return df_data, df_last, done

    def dma_read(self, ram, bus_addr, ram_addr, length, cond=None, ram_port=0):
        """Copy ``length`` words from the bus into ``ram``.

        A temporary FSM optionally waits for ``cond``, issues the read
        request, streams the read data into ``ram.write_dataflow`` and
        returns to its initial state once that reports done.

        Returns:
            The ``done`` value returned by ``ram.write_dataflow``.
        """
        fsm = TmpFSM(self.m, self.clk, self.rst)

        if cond is not None:
            fsm.If(cond).goto_next()

        # The request's counter is picked up implicitly by read_dataflow.
        ack, _counter = self.read_request(bus_addr, length, cond=fsm)
        fsm.If(ack).goto_next()

        data, _last, _done = self.read_dataflow()

        done = ram.write_dataflow(ram_port, ram_addr, data, length, cond=fsm)
        fsm.goto_next()
        fsm.If(done).goto_next()

        fsm.goto_init()

        return done

    def dma_write(self, ram, bus_addr, ram_addr, length, cond=None, ram_port=0):
        """Copy ``length`` words from ``ram`` onto the bus.

        A temporary FSM optionally waits for ``cond``, issues the write
        request, streams ``ram.read_dataflow`` into ``write_dataflow`` and
        returns to its initial state once that reports done.

        Returns:
            The ``done`` value returned by ``write_dataflow``.
        """
        fsm = TmpFSM(self.m, self.clk, self.rst)

        if cond is not None:
            fsm.If(cond).goto_next()

        ack, counter = self.write_request(bus_addr, length, cond=fsm)
        fsm.If(ack).goto_next()

        data, _last, _done = ram.read_dataflow(
            ram_port, ram_addr, length, cond=fsm)
        fsm.goto_next()

        done = self.write_dataflow(data, counter, cond=fsm)
        fsm.If(done).goto_next()

        fsm.goto_init()

        return done
Example #21
0
class FIFO(_MutexFunction):
    """FIFO instance wrapper with RTL-level and FSM-level enq/deq helpers.

    The constructor instantiates the module built by mkFifoDefinition,
    creates write/read interfaces for it and keeps a register that tracks
    the number of stored entries.
    """
    __intrinsics__ = ('enq', 'deq', 'try_enq', 'try_deq',
                      'is_empty', 'is_almost_empty',
                      'is_full', 'is_almost_full') + _MutexFunction.__intrinsics__

    def __init__(self, m, name, clk, rst, datawidth=32, addrwidth=4):
        """Create the interfaces, the FIFO instance and the entry counter.

        @param m module the FIFO is added to
        @param name base name used for the interfaces, instance and counter
        @param clk, rst clock and reset handed to the internal Seq
        @param datawidth width passed to the interfaces and the definition
        @param addrwidth address width; an int or an expression
        """

        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst

        self.datawidth = datawidth
        self.addrwidth = addrwidth

        self.wif = FifoWriteInterface(self.m, name, datawidth, itype='Wire', otype='Wire')
        self.rif = FifoReadInterface(self.m, name, datawidth, itype='Wire', otype='Wire')

        self.definition = mkFifoDefinition(name, datawidth, addrwidth)

        self.inst = self.m.Instance(self.definition, 'inst_' + name,
                                    ports=m.connect_ports(self.definition))

        self.seq = Seq(m, name, clk, rst)

        # entry counter
        # _max_size is 2**addrwidth - 1; it stays a plain int when addrwidth
        # is an int and becomes a vtypes expression otherwise.
        self._max_size = (2 ** self.addrwidth - 1 if isinstance(self.addrwidth, int) else
                          vtypes.Int(2) ** self.addrwidth - 1)

        self._count = self.m.Reg(
            'count_' + name, self.addrwidth + 1, initval=0)

        # Accepted enq and accepted deq together: hold the count.
        # Accepted enq only: increment. Accepted deq only: decrement.
        self.seq.If(
            vtypes.Ands(vtypes.Ands(self.wif.enq, vtypes.Not(self.wif.full)),
                        vtypes.Ands(self.rif.deq, vtypes.Not(self.rif.empty))))(
            self._count(self._count)
        ).Elif(vtypes.Ands(self.wif.enq, vtypes.Not(self.wif.full)))(
            self._count.inc()
        ).Elif(vtypes.Ands(self.rif.deq, vtypes.Not(self.rif.empty)))(
            self._count.dec()
        )

        self._enq_disabled = False
        self._deq_disabled = False

        self.mutex = None

    def _id(self):
        """Return the Python identity of this object."""
        return id(self)

    def disable_enq(self):
        """Drive wif.enq with 0 and make later enq_rtl calls raise."""
        # NOTE(review): this drives enq through self.seq while disable_deq
        # uses .assign(0); both interfaces are created with itype='Wire' --
        # confirm the asymmetry is intended.
        self.seq(
            self.wif.enq(0)
        )
        self._enq_disabled = True

    def disable_deq(self):
        """Assign 0 to rif.deq and make later deq_rtl calls raise."""
        self.rif.deq.assign(0)
        self._deq_disabled = True

    def enq_rtl(self, wdata, cond=None):
        """Enqueue `wdata` while `cond` holds and the FIFO is not almost full.

        @param wdata value multiplexed onto wif.wdata
        @param cond optional enable condition (passed through make_condition)
        @return ack, ready -- ready is Not(wif.almost_full); ack is
                self.seq.Prev(ready, 1)
        @raise TypeError if disable_enq() has been called
        """

        if self._enq_disabled:
            raise TypeError('Enq disabled.')

        cond = make_condition(cond)
        ready = vtypes.Not(self.wif.almost_full)

        # Without a condition the data mux is always enabled and enq simply
        # follows `ready`.
        if cond is not None:
            enq_cond = vtypes.Ands(cond, ready)
            enable = cond
        else:
            enq_cond = ready
            enable = vtypes.Int(1, 1)

        util.add_mux(self.wif.wdata, enable, wdata)
        util.add_enable_cond(self.wif.enq, enable, enq_cond)

        ack = self.seq.Prev(ready, 1)

        return ack, ready

    def deq_rtl(self, cond=None):
        """Dequeue while `cond` holds and the FIFO is not empty.

        @param cond optional enable condition (passed through make_condition)
        @return data, valid, ready -- data is rif.rdata, ready is
                Not(rif.empty), valid is self.seq.Prev(deq_cond, 1)
        @raise TypeError if disable_deq() has been called
        """

        if self._deq_disabled:
            raise TypeError('Deq disabled.')

        cond = make_condition(cond)
        ready = vtypes.Not(self.rif.empty)

        if cond is not None:
            deq_cond = vtypes.Ands(cond, ready)
        else:
            deq_cond = ready

        util.add_enable_cond(self.rif.deq, deq_cond, 1)

        data = self.rif.rdata
        valid = self.seq.Prev(deq_cond, 1)

        return data, valid, ready

    @property
    def wdata(self):
        """Write-side data signal (wif.wdata)."""
        return self.wif.wdata

    @property
    def empty(self):
        """Read-side empty flag (rif.empty)."""
        return self.rif.empty

    @property
    def almost_empty(self):
        """Read-side almost-empty flag (rif.almost_empty)."""
        return self.rif.almost_empty

    @property
    def rdata(self):
        """Read-side data signal (rif.rdata)."""
        return self.rif.rdata

    @property
    def full(self):
        """Write-side full flag (wif.full)."""
        return self.wif.full

    @property
    def almost_full(self):
        """Write-side almost-full flag (wif.almost_full)."""
        return self.wif.almost_full

    @property
    def count(self):
        """Entry counter register maintained by the constructor."""
        return self._count

    @property
    def space(self):
        """Expression for `_max_size - count`."""
        # Wrap a plain int so the subtraction yields a vtypes expression.
        if isinstance(self._max_size, int):
            return vtypes.Int(self._max_size) - self.count
        return self._max_size - self.count

    def has_space(self, num=1):
        """Return a condition that `num` more entries fit.

        Returns the Python value True when num < 1, otherwise the
        expression `count + num < _max_size`.
        """
        if num < 1:
            return True
        return (self._count + num < self._max_size)

    def enq(self, fsm, wdata):
        """FSM-level blocking enqueue: advance once `ready` holds.

        @return 0
        """
        cond = fsm.state == fsm.current

        ack, ready = self.enq_rtl(wdata, cond=cond)
        fsm.If(ready).goto_next()

        return 0

    def deq(self, fsm):
        """FSM-level blocking dequeue.

        Advances when the FIFO is ready, then captures rdata into a
        temporary signed register when valid is seen and advances again.

        @return the register holding the dequeued value
        """
        cond = fsm.state == fsm.current

        rdata, rvalid, rready = self.deq_rtl(cond=cond)
        fsm.If(rready).goto_next()

        rdata_reg = self.m.TmpReg(self.datawidth, initval=0, signed=True)

        fsm.If(rvalid)(
            rdata_reg(rdata)
        )
        fsm.If(rvalid).goto_next()

        return rdata_reg

    def try_enq(self, fsm, wdata):
        """FSM-level non-blocking enqueue.

        The FSM advances unconditionally; in the following state the ack
        from enq_rtl is stored in a temporary register.

        @return the register holding ack
        """
        cond = fsm.state == fsm.current

        ack, ready = self.enq_rtl(wdata, cond=cond)
        fsm.goto_next()

        ack_reg = self.m.TmpReg(initval=0)
        fsm(
            ack_reg(ack)
        )
        fsm.goto_next()

        return ack_reg

    def try_deq(self, fsm):
        """FSM-level non-blocking dequeue.

        The FSM advances unconditionally; in the following state rdata and
        valid from deq_rtl are stored in temporary registers.

        @return rdata_reg, rvalid_reg
        """
        cond = fsm.state == fsm.current

        rdata, rvalid, rready = self.deq_rtl(cond=cond)
        fsm.goto_next()

        rdata_reg = self.m.TmpReg(self.datawidth, initval=0, signed=True)
        rvalid_reg = self.m.TmpReg(initval=0)

        fsm(
            rdata_reg(rdata),
            rvalid_reg(rvalid)
        )
        fsm.goto_next()

        return rdata_reg, rvalid_reg

    def is_almost_empty(self, fsm):
        """Advance the FSM by one state and return the almost_empty flag."""
        fsm.goto_next()
        return self.almost_empty

    def is_empty(self, fsm):
        """Advance the FSM by one state and return the empty flag."""
        fsm.goto_next()
        return self.empty

    def is_almost_full(self, fsm):
        """Advance the FSM by one state and return the almost_full flag."""
        fsm.goto_next()
        return self.almost_full

    def is_full(self, fsm):
        """Advance the FSM by one state and return the full flag."""
        fsm.goto_next()
        return self.full
Example #22
0
class AxiSlave(object):
    """Slave-side AXI helper built from four channel interface objects.

    Provides methods that accept write/read requests, pull write data and
    push read data, either as plain signals or as dataflow variables.
    """
    # Bit width of the temporary burst counters created by the
    # pull_*_request methods.
    burst_size_width = 8

    def __init__(self, m, name, clk, rst, datawidth=32, addrwidth=32,
                 nodataflow=False):
        """Create the channel interfaces, the Seq and the dataflow manager.

        @param m module the signals are added to
        @param name base name passed to the channel interfaces and the Seq
        @param clk, rst clock and reset for the Seq / DataflowManager
        @param datawidth, addrwidth widths passed to the channel interfaces
        @param nodataflow when true no DataflowManager is created and the
               *_dataflow methods fall back to the module-level `dataflow`
        """

        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.datawidth = datawidth
        self.addrwidth = addrwidth

        self.waddr = AxiSlaveWriteAddress(m, name, datawidth, addrwidth)
        self.wdata = AxiSlaveWriteData(m, name, datawidth, addrwidth)
        self.raddr = AxiSlaveReadAddress(m, name, datawidth, addrwidth)
        self.rdata = AxiSlaveReadData(m, name, datawidth, addrwidth)

        self.seq = Seq(m, name, clk, rst)

        # Counters registered by pull_write_request / pull_read_request; the
        # most recent one is the default for the data-phase methods.
        self.write_counters = []
        self.read_counters = []

        if nodataflow:
            self.df = None
        else:
            self.df = DataflowManager(self.m, self.clk, self.rst)

        self._write_disabled = False
        self._read_disabled = False

    def disable_write(self):
        """Tie awready and wready to 0 and make the write methods raise."""
        self.waddr.awready.assign(0)
        self.wdata.wready.assign(0)
        self._write_disabled = True

    def disable_read(self):
        """Tie arready to 0, drive rvalid/rlast to 0 and make the read
        methods raise."""
        self.raddr.arready.assign(0)
        self.seq(
            self.rdata.rvalid(0),
            self.rdata.rlast(0)
        )
        self._read_disabled = True

    def pull_write_request(self, cond=None, counter=None):
        """
        Accept a write address request.

        @param cond optional condition gating awready
        @param counter Reg that receives awlen + 1; a temporary register of
               burst_size_width bits is created when None
        @return addr, counter, valid
        @raise TypeError if writes are disabled or counter is not a Reg
        """

        if self._write_disabled:
            raise TypeError('Write disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.m.TmpReg(self.burst_size_width, initval=0)

        self.write_counters.append(counter)

        ready = make_condition(cond)

        ack = vtypes.Ands(self.waddr.awready, self.waddr.awvalid)
        addr = self.m.TmpReg(self.addrwidth, initval=0)
        valid = self.m.TmpReg(initval=0)

        # awready is only offered while `valid` is low.
        val = (vtypes.Not(valid) if ready is None else
               vtypes.Ands(ready, vtypes.Not(valid)))

        # First caller assigns awready; later callers OR their term into the
        # right-hand side of the existing assignment.
        prev_subst = self.waddr.awready._get_subst()
        if not prev_subst:
            self.waddr.awready.assign(val)
        else:
            self.waddr.awready.subst[0].overwrite_right(
                vtypes.Ors(prev_subst[0].right, val))

        # Latch the address and the beat count (awlen + 1) on handshake.
        self.seq.If(ack)(
            addr(self.waddr.awaddr),
            counter(self.waddr.awlen + 1)
        )

        self.seq(
            valid(ack)
        )

        return addr, counter, valid

    def pull_write_data(self, counter=None, cond=None):
        """
        Accept write data beats as plain signals.

        @param counter Reg decremented on each accepted beat; defaults to
               the counter of the latest pull_write_request
        @param cond optional condition driving wready (1 when omitted)
        @return data, mask, valid, last
        @raise TypeError if writes are disabled or counter is not a Reg
        """

        if self._write_disabled:
            raise TypeError('Write disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.write_counters[-1]

        ready = make_condition(cond)
        val = 1 if ready is None else ready

        # Assign wready, or OR this term into the existing assignment.
        prev_subst = self.wdata.wready._get_subst()
        if not prev_subst:
            self.wdata.wready.assign(val)
        else:
            self.wdata.wready.subst[0].overwrite_right(
                vtypes.Ors(prev_subst[0].right, val))

        ack = vtypes.Ands(self.wdata.wready, self.wdata.wvalid)
        data = self.wdata.wdata
        mask = self.wdata.wstrb
        valid = ack
        last = self.wdata.wlast

        self.seq.If(vtypes.Ands(ack, counter > 0))(
            counter.dec()
        )

        return data, mask, valid, last

    def pull_write_dataflow(self, counter=None, cond=None):
        """
        Accept write data beats as dataflow variables.

        @param counter Reg decremented on each accepted beat; defaults to
               the counter of the latest pull_write_request
        @param cond optional condition, or tuple/list of conditions, that
               is combined with the data and last ready wires
        @return data, mask, last, done
        @raise TypeError if writes are disabled or counter is not a Reg
        """

        if self._write_disabled:
            raise TypeError('Write disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.write_counters[-1]

        # Ready wires default to 1; they are handed to the dataflow
        # variables below.
        data_ready = self.m.TmpWire()
        mask_ready = self.m.TmpWire()
        last_ready = self.m.TmpWire()
        data_ready.assign(1)
        mask_ready.assign(1)
        last_ready.assign(1)

        # NOTE(review): mask_ready is not part of the combined condition --
        # confirm this is intended.
        if cond is None:
            cond = (data_ready, last_ready)
        elif isinstance(cond, (tuple, list)):
            cond = tuple(list(cond) + [data_ready, last_ready])
        else:
            cond = (cond, data_ready, last_ready)

        ready = make_condition(*cond)
        val = 1 if ready is None else ready

        # Assign wready, or OR this term into the existing assignment.
        prev_subst = self.wdata.wready._get_subst()
        if not prev_subst:
            self.wdata.wready.assign(val)
        else:
            self.wdata.wready.subst[0].overwrite_right(
                vtypes.Ors(prev_subst[0].right, val))

        ack = vtypes.Ands(self.wdata.wready, self.wdata.wvalid)
        data = self.wdata.wdata
        mask = self.wdata.wstrb
        valid = self.wdata.wvalid
        last = self.wdata.wlast

        self.seq.If(vtypes.Ands(ack, counter > 0))(
            counter.dec()
        )

        df = self.df if self.df is not None else dataflow

        df_data = df.Variable(data, valid, data_ready)
        # NOTE(review): AXI WSTRB carries one bit per data byte, which would
        # be datawidth // 8 -- confirm the `// 4` width against the
        # AxiSlaveWriteData port declaration.
        df_mask = df.Variable(mask, valid, mask_ready,
                              width=self.datawidth // 4)
        df_last = df.Variable(last, valid, last_ready, width=1)
        done = last

        return df_data, df_mask, df_last, done

    def pull_read_request(self, cond=None, counter=None):
        """
        Accept a read address request.

        @param cond optional condition gating arready
        @param counter Reg that receives arlen + 1; a temporary register of
               burst_size_width bits is created when None
        @return addr, counter, valid
        @raise TypeError if reads are disabled or counter is not a Reg
        """

        if self._read_disabled:
            raise TypeError('Read disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.m.TmpReg(self.burst_size_width, initval=0)

        self.read_counters.append(counter)

        ready = make_condition(cond)

        ack = vtypes.Ands(self.raddr.arready, self.raddr.arvalid)
        addr = self.m.TmpReg(self.addrwidth, initval=0)
        valid = self.m.TmpReg(initval=0)

        # arready is only offered while `valid` is low.
        val = (vtypes.Not(valid) if ready is None else
               vtypes.Ands(ready, vtypes.Not(valid)))

        # Assign arready, or OR this term into the existing assignment.
        prev_subst = self.raddr.arready._get_subst()
        if not prev_subst:
            self.raddr.arready.assign(val)
        else:
            self.raddr.arready.subst[0].overwrite_right(
                vtypes.Ors(prev_subst[0].right, val))

        # Latch the address and the beat count (arlen + 1) on handshake.
        self.seq.If(ack)(
            addr(self.raddr.araddr),
            counter(self.raddr.arlen + 1)
        )

        self.seq(
            valid(ack)
        )

        return addr, counter, valid

    def push_read_data(self, data, counter=None, cond=None):
        """
        Drive one read data beat from a plain value.

        @param data value written to rdata when a beat is issued
        @param counter Reg with the remaining beats; defaults to the counter
               of the latest pull_read_request
        @param cond optional condition for issuing the beat
        @return ack, last
        @raise TypeError if reads are disabled or counter is not a Reg
        """

        if self._read_disabled:
            raise TypeError('Read disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.read_counters[-1]

        # A body-less If; presumably Seq combines it with the condition of
        # the next statement -- confirm against the Seq implementation.
        if cond is not None:
            self.seq.If(cond)

        # A beat can be issued when beats remain and the channel is free
        # (previous beat accepted, or nothing valid on the bus).
        ack = vtypes.Ands(counter > 0,
                          vtypes.Ors(self.rdata.rready, vtypes.Not(self.rdata.rvalid)))
        last = self.m.TmpReg(initval=0)

        self.seq.If(vtypes.Ands(ack, counter > 0))(
            self.rdata.rdata(data),
            self.rdata.rvalid(1),
            self.rdata.rlast(0),
            counter.dec()
        )
        # Mark the final beat when exactly one beat remained.
        self.seq.Then().If(counter == 1)(
            self.rdata.rlast(1),
            last(1)
        )

        # de-assert
        self.seq.Delay(1)(
            self.rdata.rvalid(0),
            self.rdata.rlast(0),
            last(0)
        )

        # retry
        # Hold rvalid/rlast/last while the beat has not been accepted.
        self.seq.If(vtypes.Ands(self.rdata.rvalid, vtypes.Not(self.rdata.rready)))(
            self.rdata.rvalid(self.rdata.rvalid),
            self.rdata.rlast(self.rdata.rlast),
            last(last)
        )

        return ack, last

    def push_read_dataflow(self, data, counter=None, cond=None):
        """
        Drive read data beats from a dataflow variable.

        @param data dataflow variable; its read(cond=...) supplies the raw
               data and valid signals
        @param counter Reg with the remaining beats; defaults to the counter
               of the latest pull_read_request
        @param cond optional extra condition for reading `data`
        @return done
        @raise TypeError if reads are disabled or counter is not a Reg
        """

        if self._read_disabled:
            raise TypeError('Read disabled.')

        if counter is not None and not isinstance(counter, vtypes.Reg):
            raise TypeError("counter must be Reg or None.")

        if counter is None:
            counter = self.read_counters[-1]

        # A beat can be issued when beats remain and the channel is free.
        ack = vtypes.Ands(counter > 0,
                          vtypes.Ors(self.rdata.rready, vtypes.Not(self.rdata.rvalid)))
        last = self.m.TmpReg(initval=0)

        if cond is None:
            cond = ack
        else:
            cond = (cond, ack)

        raw_data, raw_valid = data.read(cond=cond)

        # write condition
        # Body-less If; presumably combined by Seq with the next statement's
        # condition -- confirm against the Seq implementation.
        self.seq.If(raw_valid)

        self.seq.If(vtypes.Ands(ack, counter > 0))(
            self.rdata.rdata(raw_data),
            self.rdata.rvalid(1),
            self.rdata.rlast(0),
            counter.dec()
        )
        # Mark the final beat when exactly one beat remained.
        self.seq.Then().If(counter == 1)(
            self.rdata.rlast(1),
            last(1)
        )

        # de-assert
        self.seq.Delay(1)(
            self.rdata.rvalid(0),
            self.rdata.rlast(0),
            last(0)
        )

        # retry
        # Hold rvalid/rlast/last while the beat has not been accepted.
        self.seq.If(vtypes.Ands(self.rdata.rvalid, vtypes.Not(self.rdata.rready)))(
            self.rdata.rvalid(self.rdata.rvalid),
            self.rdata.rlast(self.rdata.rlast),
            last(last)
        )

        done = last

        return done
Example #23
0
class SyncRAMManager(object):
    """RAM instance wrapper with per-port write/read helpers.

    Each port has a RAMInterface; accesses are generated through a shared
    Seq, either as single operations or as dataflow-driven bursts.
    """

    def __init__(self,
                 m,
                 name,
                 clk,
                 rst,
                 datawidth=32,
                 addrwidth=10,
                 numports=1,
                 nodataflow=False):
        """Create the port interfaces, the RAM instance and the Seq.

        @param m module the RAM is added to
        @param name base name for the interfaces, definition and instance
        @param clk, rst clock and reset for the Seq / DataflowManager
        @param datawidth, addrwidth widths passed to the interfaces and
               the RAM definition
        @param numports number of RAMInterface objects to create
        @param nodataflow when true no DataflowManager is created and
               read_dataflow falls back to the module-level `dataflow`
        """

        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.datawidth = datawidth
        self.addrwidth = addrwidth
        self.interfaces = [
            RAMInterface(m, name + '_%d' % i, datawidth, addrwidth)
            for i in range(numports)
        ]

        self.definition = mkRAMDefinition(name, datawidth, addrwidth, numports)
        self.inst = self.m.Instance(self.definition,
                                    'inst_' + name,
                                    ports=m.connect_ports(self.definition))

        self.seq = Seq(m, name, clk, rst)

        if nodataflow:
            self.df = None
        else:
            self.df = DataflowManager(self.m, self.clk, self.rst)

        # One flag per port, set by disable_write().
        self._write_disabled = [False for i in range(numports)]

    def __getitem__(self, index):
        """Return the interface object of port `index`."""
        return self.interfaces[index]

    def disable_write(self, port):
        """Drive wdata/wenable of `port` to 0 and make writes to it raise."""
        self.seq(self.interfaces[port].wdata(0),
                 self.interfaces[port].wenable(0))
        self._write_disabled[port] = True

    def write(self, port, addr, wdata, cond=None):
        """
        Write a single word.

        Sets addr, wdata and wenable of `port`, then clears wenable with a
        delay of 1.

        @param cond optional condition for the write
        @return None
        @raise TypeError if writing to `port` has been disabled
        """

        if self._write_disabled[port]:
            raise TypeError('Write disabled.')

        # Body-less If; presumably combined by Seq with the next statement
        # -- confirm against the Seq implementation.
        if cond is not None:
            self.seq.If(cond)

        self.seq(self.interfaces[port].addr(addr),
                 self.interfaces[port].wdata(wdata),
                 self.interfaces[port].wenable(1))

        self.seq.Then().Delay(1)(self.interfaces[port].wenable(0))

    def write_dataflow(self, port, addr, data, length=1, cond=None, when=None):
        """
        Write `length` words taken from a dataflow variable.

        @param data dataflow variable; data.read(cond=...) supplies the raw
               data and valid signals
        @param length number of words; must support .bit_length()
               (i.e. a Python int)
        @param cond optional start condition
        @param when optional condition ANDed with the raw valid signal
        @return done
        @raise TypeError if writing to `port` has been disabled
        """

        if self._write_disabled[port]:
            raise TypeError('Write disabled.')

        counter = self.m.TmpReg(length.bit_length() + 1, initval=0)
        last = self.m.TmpReg(initval=0)

        ext_cond = make_condition(cond)
        data_cond = make_condition(counter > 0, vtypes.Not(last))
        # NOTE(review): all_cond is not used below.
        all_cond = make_condition(data_cond, ext_cond)
        raw_data, raw_valid = data.read(cond=data_cond)

        when_cond = make_condition(when, ready=data_cond)
        if when_cond is not None:
            raw_valid = vtypes.Ands(when_cond, raw_valid)

        # Start: load the counter and set the address to addr - 1 so the
        # increment on the first accepted word lands on addr.
        self.seq.If(make_condition(ext_cond, counter == 0))(
            self.interfaces[port].addr(addr - 1),
            counter(length),
        )

        # Each valid word: advance the address, write it, count down.
        self.seq.If(make_condition(
            raw_valid, counter > 0))(self.interfaces[port].addr.inc(),
                                     self.interfaces[port].wdata(raw_data),
                                     self.interfaces[port].wenable(1),
                                     counter.dec())

        # Flag the final word.
        self.seq.If(make_condition(raw_valid, counter == 1))(last(1))

        # de-assert
        self.seq.Delay(1)(self.interfaces[port].wenable(0), last(0))

        done = last

        return done

    def read(self, port, addr, cond=None):
        """
        Read a single word.

        Sets the address of `port`; a temporary rvalid register is set with
        a delay of 1 and cleared with a delay of 2.

        @param cond optional condition for the read
        @return data, valid
        """

        # Body-less If; presumably combined by Seq with the next statement
        # -- confirm against the Seq implementation.
        if cond is not None:
            self.seq.If(cond)

        self.seq(self.interfaces[port].addr(addr))

        rdata = self.interfaces[port].rdata
        rvalid = self.m.TmpReg(initval=0)
        self.seq.Then().Delay(1)(rvalid(1))
        self.seq.Then().Delay(2)(rvalid(0))

        return rdata, rvalid

    def read_dataflow(self, port, addr, length=1, cond=None):
        """
        Read `length` consecutive words as dataflow variables.

        @param length number of words; must support .bit_length()
               (i.e. a Python int)
        @param cond optional start condition
        @return data, last, done
        """

        # Valid registers and ready wires of the two output variables; the
        # ready wires default to 1.
        data_valid = self.m.TmpReg(initval=0)
        last_valid = self.m.TmpReg(initval=0)
        data_ready = self.m.TmpWire()
        last_ready = self.m.TmpWire()
        data_ready.assign(1)
        last_ready.assign(1)

        # An output can advance when it is ready or holds nothing valid.
        data_ack = vtypes.Ors(data_ready, vtypes.Not(data_valid))
        last_ack = vtypes.Ors(last_ready, vtypes.Not(last_valid))

        ext_cond = make_condition(cond)
        data_cond = make_condition(data_ack, last_ack)
        prev_data_cond = self.seq.Prev(data_cond, 1)
        # NOTE(review): all_cond is not used below.
        all_cond = make_condition(data_cond, ext_cond)

        # Output data: take the RAM output when prev_data_cond holds,
        # otherwise the previous value of `data` itself.
        data = self.m.TmpWireLike(self.interfaces[port].rdata)
        prev_data = self.seq.Prev(data, 1)
        data.assign(
            vtypes.Mux(prev_data_cond, self.interfaces[port].rdata, prev_data))

        counter = self.m.TmpReg(length.bit_length() + 1, initval=0)

        next_valid_on = self.m.TmpReg(initval=0)
        next_valid_off = self.m.TmpReg(initval=0)

        next_last = self.m.TmpReg(initval=0)
        last = self.m.TmpReg(initval=0)

        # Drop valid/last after the final word.
        self.seq.If(make_condition(data_cond,
                                   next_valid_off))(last(0), data_valid(0),
                                                    last_valid(0),
                                                    next_valid_off(0))
        # Raise valid for an issued address and move next_last into last.
        self.seq.If(make_condition(data_cond,
                                   next_valid_on))(data_valid(1),
                                                   last_valid(1),
                                                   last(next_last),
                                                   next_last(0),
                                                   next_valid_on(0),
                                                   next_valid_off(1))
        # Start: issue the first address and load the remaining count.
        self.seq.If(
            make_condition(ext_cond, counter == 0, vtypes.Not(next_last),
                           vtypes.Not(last)))(
                               self.interfaces[port].addr(addr),
                               counter(length - 1),
                               next_valid_on(1),
                           )
        # Continue: issue the next address while words remain.
        self.seq.If(make_condition(data_cond, counter > 0))(
            self.interfaces[port].addr.inc(), counter.dec(), next_valid_on(1),
            next_last(0))
        # The address issued while counter == 1 is the final one.
        self.seq.If(make_condition(data_cond, counter == 1))(next_last(1))

        df = self.df if self.df is not None else dataflow

        df_data = df.Variable(data, data_valid, data_ready)
        df_last = df.Variable(last, last_valid, last_ready, width=1)
        done = last

        return df_data, df_last, done
Example #24
0
class SyncRAMManager(object):
    def __init__(self,
                 m,
                 name,
                 clk,
                 rst,
                 datawidth=32,
                 addrwidth=10,
                 numports=1,
                 nodataflow=False):
        """Create the port interfaces, the RAM instance and the Seq.

        @param m module the RAM is added to
        @param name base name for the interfaces, definition and instance
        @param clk, rst clock and reset for the Seq / DataflowManager
        @param datawidth, addrwidth widths passed to the interfaces and
               the RAM definition
        @param numports number of RAMInterface objects to create
        @param nodataflow when true no DataflowManager is created
        """
        self.m, self.name = m, name
        self.clk, self.rst = clk, rst
        self.datawidth, self.addrwidth = datawidth, addrwidth
        self.numports = numports

        # One interface per port, named "<name>_<index>".
        port_interfaces = []
        for idx in range(numports):
            port_interfaces.append(
                RAMInterface(m, name + '_%d' % idx, datawidth, addrwidth))
        self.interfaces = port_interfaces

        ram_definition = mkRAMDefinition(name, datawidth, addrwidth, numports)
        self.definition = ram_definition

        self.inst = self.m.Instance(
            ram_definition,
            'inst_' + name,
            ports=m.connect_ports(ram_definition))

        self.seq = Seq(m, name, clk, rst)

        self.df = (None if nodataflow
                   else DataflowManager(self.m, self.clk, self.rst))

        # Per-port flags set by disable_write() / disable_port().
        self._write_disabled = [False] * numports
        self._port_disabled = [False] * numports

    def __getitem__(self, index):
        return self.interfaces[index]

    def disable_write(self, port):
        self.seq(self.interfaces[port].wdata(0),
                 self.interfaces[port].wenable(0))
        self._write_disabled[port] = True

    def disable_port(self, port):
        self.seq(self.interfaces[port].addr(0), )
        self._port_disabled[port] = True

    def read(self, port, addr, cond=None):
        """ 
        @return data, valid
        """
        return self.read_data(port, addr, cond)

    def read_data(self, port, addr, cond=None):
        """ 
        @return data, valid
        """

        if cond is not None:
            self.seq.If(cond)

        self.seq(self.interfaces[port].addr(addr))

        rdata = self.interfaces[port].rdata
        rvalid = self.m.TmpReg(initval=0)

        self.seq.Then().Delay(1)(rvalid(1))

        self.seq.Then().Delay(2)(rvalid(0))

        return rdata, rvalid

    def read_dataflow(self,
                      port,
                      addr,
                      length=1,
                      stride=1,
                      cond=None,
                      point=0,
                      signed=False):
        """
        Read `length` words, `stride` addresses apart, as dataflow variables.

        @param length number of words; must support .bit_length()
               (i.e. a Python int)
        @param stride value added to the port address per word
        @param cond optional start condition
        @param point, signed forwarded to the data dataflow variable
        @return data, last, done
        """
        mod = self.m
        seq = self.seq

        # Valid registers and ready wires of the two outputs; the ready
        # wires default to 1.
        d_valid = mod.TmpReg(initval=0)
        l_valid = mod.TmpReg(initval=0)
        d_ready = mod.TmpWire()
        l_ready = mod.TmpWire()
        d_ready.assign(1)
        l_ready.assign(1)

        # An output can advance when it is ready or holds nothing valid.
        d_ack = vtypes.Ors(d_ready, vtypes.Not(d_valid))
        l_ack = vtypes.Ors(l_ready, vtypes.Not(l_valid))

        start_cond = make_condition(cond)
        advance = make_condition(d_ack, l_ack)
        advance_d1 = seq.Prev(advance, 1)

        # Output data: the RAM output when advance_d1 holds, otherwise the
        # previous value of the output wire itself.
        ram_if = self.interfaces[port]
        out_data = mod.TmpWireLike(ram_if.rdata)

        out_data_d1 = seq.Prev(out_data, 1)
        out_data.assign(vtypes.Mux(advance_d1, ram_if.rdata, out_data_d1))

        valid_set = mod.TmpReg(initval=0)
        valid_clr = mod.TmpReg(initval=0)

        last_pending = mod.TmpReg(initval=0)
        last_flag = mod.TmpReg(initval=0)

        remaining = mod.TmpReg(length.bit_length() + 1, initval=0)

        # Drop valid/last after the final word.
        seq.If(advance, valid_clr)(
            last_flag(0),
            d_valid(0),
            l_valid(0),
            valid_clr(0))

        # Raise valid for an issued address; move the pending last flag out.
        seq.If(advance, valid_set)(
            d_valid(1),
            l_valid(1),
            last_flag(last_pending),
            last_pending(0),
            valid_set(0),
            valid_clr(1))

        # Start: issue the first address and load the remaining count.
        seq.If(start_cond, remaining == 0, vtypes.Not(last_pending),
               vtypes.Not(last_flag))(
            ram_if.addr(addr),
            remaining(length - 1),
            valid_set(1))

        # Continue: step the address by `stride` while words remain.
        seq.If(advance, remaining > 0)(
            ram_if.addr(ram_if.addr + stride),
            remaining.dec(),
            valid_set(1),
            last_pending(0))

        # The address issued while one word remains is the final one.
        seq.If(advance, remaining == 1)(last_pending(1))

        flow = self.df
        if flow is None:
            flow = dataflow

        df_data = flow.Variable(out_data,
                                d_valid,
                                d_ready,
                                width=self.datawidth,
                                point=point,
                                signed=signed)
        df_last = flow.Variable(last_flag, l_valid, l_ready, width=1)

        return df_data, df_last, last_flag

    def read_dataflow_pattern(self,
                              port,
                              addr,
                              pattern,
                              cond=None,
                              point=0,
                              signed=False):
        """
        Read words following a nested (size, stride) access pattern.

        @param pattern a (size, stride) pair, or a list/tuple of such pairs;
               every size must support .bit_length() (i.e. a Python int)
        @param cond optional start condition
        @param point, signed forwarded to the data dataflow variable
        @return data, last, done
        @raise TypeError if pattern is not a list or tuple
        @raise ValueError if pattern is empty
        """

        if not isinstance(pattern, (tuple, list)):
            raise TypeError('pattern must be list or tuple.')

        if not pattern:
            raise ValueError(
                'pattern must have one (size, stride) pair at least.')

        # A single flat (size, stride) pair is wrapped into a 1-tuple.
        if not isinstance(pattern[0], (tuple, list)):
            pattern = (pattern, )

        # Valid registers and ready wires of the two outputs; the ready
        # wires default to 1.
        data_valid = self.m.TmpReg(initval=0)
        last_valid = self.m.TmpReg(initval=0)
        data_ready = self.m.TmpWire()
        last_ready = self.m.TmpWire()
        data_ready.assign(1)
        last_ready.assign(1)

        # An output can advance when it is ready or holds nothing valid.
        data_ack = vtypes.Ors(data_ready, vtypes.Not(data_valid))
        last_ack = vtypes.Ors(last_ready, vtypes.Not(last_valid))

        ext_cond = make_condition(cond)
        data_cond = make_condition(data_ack, last_ack)
        prev_data_cond = self.seq.Prev(data_cond, 1)

        # Output data: the RAM output when prev_data_cond holds, otherwise
        # the previous value of `data` itself.
        data = self.m.TmpWireLike(self.interfaces[port].rdata)

        prev_data = self.seq.Prev(data, 1)
        data.assign(
            vtypes.Mux(prev_data_cond, self.interfaces[port].rdata, prev_data))

        next_valid_on = self.m.TmpReg(initval=0)
        next_valid_off = self.m.TmpReg(initval=0)

        next_last = self.m.TmpReg(initval=0)
        last = self.m.TmpReg(initval=0)

        # Set while a pattern traversal is in progress.
        running = self.m.TmpReg(initval=0)

        next_addr = self.m.TmpWire(self.addrwidth)
        offset_addr = self.m.TmpWire(self.addrwidth)
        # One offset register for every pattern entry after the first.
        offsets = [
            self.m.TmpReg(self.addrwidth, initval=0) for _ in pattern[1:]
        ]

        # offset_addr = addr + sum of all offset registers.
        offset_addr_value = addr
        for offset in offsets:
            offset_addr_value = offset + offset_addr_value
        offset_addr.assign(offset_addr_value)

        # The first pattern entry has no offset register; the None keeps the
        # list aligned with `pattern` for the zip below.
        offsets.insert(0, None)

        # One counter per pattern entry.
        count_list = [
            self.m.TmpReg(out_size.bit_length() + 1, initval=0)
            for (out_size, out_stride) in pattern
        ]

        # Drop valid/last after the final word.
        self.seq.If(data_cond, next_valid_off)(last(0), data_valid(0),
                                               last_valid(0),
                                               next_valid_off(0))

        # Raise valid for an issued address and move next_last into last.
        self.seq.If(data_cond, next_valid_on)(data_valid(1), last_valid(1),
                                              last(next_last), next_last(0),
                                              next_valid_on(0),
                                              next_valid_off(1))

        # Start: issue the base address and mark the traversal as running.
        self.seq.If(ext_cond, vtypes.Not(running), vtypes.Not(next_last),
                    vtypes.Not(last))(self.interfaces[port].addr(addr),
                                      running(1), next_valid_on(1))

        # Continue: issue next_addr on every advancing step while running.
        self.seq.If(data_cond, running)(self.interfaces[port].addr(next_addr),
                                        next_valid_on(1), next_last(0))

        # Conditions accumulated across pattern entries; each starts as None
        # and is extended once per loop iteration.
        update_count = None
        update_offset = None
        update_addr = None
        last_one = None
        stride_value = None
        carry = None

        # NOTE(review): the first pattern entry appears to act as the
        # innermost (fastest-changing) dimension -- confirm with callers.
        for offset, count, (out_size,
                            out_stride) in zip(offsets, count_list, pattern):
            # Load size - 1 at start; afterwards count down whenever the
            # accumulated update_count holds, reloading on reaching 0.
            self.seq.If(ext_cond, vtypes.Not(running), vtypes.Not(next_last),
                        vtypes.Not(last))(count(out_size - 1))
            self.seq.If(data_cond, running, update_count)(count.dec())
            self.seq.If(data_cond, running, update_count,
                        count == 0)(count(out_size - 1))

            if offset is not None:
                # Clear the offset at start; add this entry's stride when
                # update_offset holds and carry does not; clear it again
                # when this entry's counter is 0.
                self.seq.If(ext_cond, vtypes.Not(running),
                            vtypes.Not(next_last), vtypes.Not(last))(offset(0))
                self.seq.If(data_cond, running, update_offset,
                            vtypes.Not(carry))(offset(offset + out_stride))
                self.seq.If(data_cond, running, update_offset,
                            count == 0)(offset(0))

            # update_count: all counters seen so far are 0.
            if update_count is None:
                update_count = count == 0
            else:
                update_count = vtypes.Ands(update_count, count == 0)

            if update_offset is None:
                update_offset = vtypes.Mux(out_size == 1, 1, count == 1)
            else:
                update_offset = vtypes.Ands(update_offset, count == carry)

            if update_addr is None:
                update_addr = count == 0
            else:
                update_addr = vtypes.Mux(carry, count == 0, update_addr)

            # last_one: all counters seen so far are 0.
            if last_one is None:
                last_one = count == 0
            else:
                last_one = vtypes.Ands(last_one, count == 0)

            if stride_value is None:
                stride_value = out_stride
            else:
                stride_value = vtypes.Mux(carry, out_stride, stride_value)

            # carry: every entry seen so far has size 1.
            if carry is None:
                carry = out_size == 1
            else:
                carry = vtypes.Ands(carry, out_size == 1)

        # Next address: the offset-based address when update_addr holds,
        # otherwise the current address plus the selected stride.
        next_addr.assign(
            vtypes.Mux(update_addr, offset_addr,
                       self.interfaces[port].addr + stride_value))

        # All counters at 0: stop running and flag the final word.
        self.seq.If(data_cond, running, last_one)(running(0), next_last(1))

        df = self.df if self.df is not None else dataflow

        df_data = df.Variable(data,
                              data_valid,
                              data_ready,
                              width=self.datawidth,
                              point=point,
                              signed=signed)
        df_last = df.Variable(last, last_valid, last_ready, width=1)
        done = last

        return df_data, df_last, done

    def read_dataflow_multidim(self,
                               port,
                               addr,
                               shape,
                               order=None,
                               cond=None,
                               point=0,
                               signed=False):
        """Read a multi-dimensional region through read_dataflow_pattern().

        The (shape, order) pair is converted into a list of (size, stride)
        pairs by _to_pattern() and the actual work is delegated to
        read_dataflow_pattern().

        @param port   index into self.interfaces
        @param addr   start address
        @param shape  sizes of each dimension
        @param order  dimension indices in traversal order; when None the
                      dimensions are visited as 0, 1, ..., len(shape) - 1
        @param cond, point, signed  forwarded unchanged
        @return whatever read_dataflow_pattern() returns (data, last, done)
        """

        # Fall back to the natural dimension order when none is given.
        dims = list(range(len(shape))) if order is None else order

        return self.read_dataflow_pattern(port,
                                          addr,
                                          self._to_pattern(shape, dims),
                                          cond=cond,
                                          point=point,
                                          signed=signed)

    def read_dataflow_reuse(self,
                            port,
                            addr,
                            length=1,
                            stride=1,
                            reuse_size=1,
                            num_outputs=1,
                            cond=None,
                            point=0,
                            signed=False):
        """Build a strided read that presents each fetched group of words
        on 'num_outputs' dataflow variables, re-issued 'reuse_size' times.

        A temporary FSM fetches 'num_outputs' words per round into
        next_reuse_data, copies them into reuse_data, and self.seq drives
        the valid/last signals while reuse_count counts down.

        @param port         index into self.interfaces
        @param addr         start address
        @param length       number of words to fetch; must provide
                            bit_length() (i.e. a Python int)
        @param stride       address increment between fetches
        @param reuse_size   reload value of the reuse counter; must provide
                            bit_length() (i.e. a Python int)
        @param num_outputs  number of output variables; must be int
        @param cond         start condition (passed to make_condition)
        @param point, signed  forwarded to df.Variable for the data outputs
        @return tuple of the 'num_outputs' data variables followed by the
                'last' dataflow variable and the raw 'last' register (done)
        @raise TypeError if num_outputs is not an int
        """

        if not isinstance(num_outputs, int):
            raise TypeError('num_outputs must be int')

        # valid/ready handshake signals, one pair per output plus one for 'last'
        data_valid = [self.m.TmpReg(initval=0) for _ in range(num_outputs)]
        last_valid = self.m.TmpReg(initval=0)
        data_ready = [self.m.TmpWire() for _ in range(num_outputs)]
        last_ready = self.m.TmpWire()

        # every ready wire is assigned constant 1 here
        for r in data_ready:
            r.assign(1)
        last_ready.assign(1)

        # a channel is acknowledged when it is ready or currently not valid
        data_ack = vtypes.Ands(*[
            vtypes.Ors(r, vtypes.Not(v))
            for v, r in zip(data_valid, data_ready)
        ])
        last_ack = vtypes.Ors(last_ready, vtypes.Not(last_valid))

        ext_cond = make_condition(cond)
        data_cond = make_condition(data_ack, last_ack)

        # remaining number of words to fetch
        counter = self.m.TmpReg(length.bit_length() + 1, initval=0)

        last = self.m.TmpReg(initval=0)
        # reuse_data holds the values currently presented to the outputs;
        # next_reuse_data receives the values fetched for the next round
        reuse_data = [
            self.m.TmpReg(self.datawidth, initval=0)
            for _ in range(num_outputs)
        ]
        next_reuse_data = [
            self.m.TmpReg(self.datawidth, initval=0)
            for _ in range(num_outputs)
        ]

        reuse_count = self.m.TmpReg(reuse_size.bit_length() + 1, initval=0)
        # one-cycle pulse (cleared via Delay(1) below) that reloads reuse_count
        fill_reuse_count = self.m.TmpReg(initval=0)
        fetch_done = self.m.TmpReg(initval=0)

        fsm = TmpFSM(self.m, self.clk, self.rst)

        # initial state
        # The address is primed one stride before 'addr' because every
        # prefetch state adds 'stride' before reading.
        fsm.If(ext_cond)(self.interfaces[port].addr(addr - stride),
                         fetch_done(0), counter(length))
        fsm.If(ext_cond, length > 0).goto_next()

        # initial prefetch state
        # One FSM state per output register: advance the address, decrement
        # the remaining-word counter (saturating at 0) and capture rdata.
        # NOTE(review): Delay(2) presumably matches the read latency of the
        # RAM interface -- confirm.
        for n in next_reuse_data:
            fsm(
                self.interfaces[port].addr(self.interfaces[port].addr +
                                           stride),
                counter(vtypes.Mux(counter > 0, counter - 1, counter)))
            fsm.Delay(2)(n(self.interfaces[port].rdata))
            fsm.goto_next()

        # two extra states with no statements
        # NOTE(review): presumably wait states for the delayed captures above
        fsm.goto_next()
        fsm.goto_next()

        # initial update state
        for n, r in zip(next_reuse_data, reuse_data):
            fsm(r(n))

        fsm(fill_reuse_count(1), fetch_done(counter == 0))
        fsm.Delay(1)(fill_reuse_count(0))

        fsm.goto_next()

        # prefetch state
        # Same sequence as the initial prefetch; the FSM loops back to this
        # state while words remain.
        read_start_state = fsm.current

        for n in next_reuse_data:
            fsm(
                self.interfaces[port].addr(self.interfaces[port].addr +
                                           stride),
                counter(vtypes.Mux(counter > 0, counter - 1, counter)))
            fsm.Delay(2)(n(self.interfaces[port].rdata))
            fsm.goto_next()

        fsm.goto_next()
        fsm.goto_next()

        # update state
        # Only taken once the current values have been fully issued
        # (reuse_count == 0) and the outputs are acknowledged.
        for n, r in zip(next_reuse_data, reuse_data):
            fsm.If(data_cond, reuse_count == 0)(r(n))

        fsm.If(data_cond,
               reuse_count == 0)(fill_reuse_count(vtypes.Not(fetch_done)),
                                 fetch_done(counter == 0))
        fsm.Delay(1)(fill_reuse_count(0))

        # next -> prefetch state or initial state
        fsm.If(data_cond, reuse_count == 0, counter == 0).goto_init()
        fsm.If(data_cond, reuse_count == 0, counter > 0).goto(read_start_state)

        # output signal control
        # clear valid/last once the 'last' channel has been accepted
        self.seq.If(data_cond, last_valid)(last(0), [d(0) for d in data_valid],
                                           last_valid(0))

        # reload the reuse counter on the fill pulse
        self.seq.If(fill_reuse_count)(reuse_count(reuse_size))

        # while the counter is non-zero, issue the data and count down
        self.seq.If(data_cond, reuse_count > 0)(reuse_count.dec(),
                                                [d(1) for d in data_valid],
                                                last_valid(1), last(0))

        # assert 'last' on the final issue after the last fetch
        self.seq.If(data_cond, reuse_count == 1, fetch_done)(last(1))

        # use the instance's dataflow manager if present, else the
        # 'dataflow' name in scope at module level
        df = self.df if self.df is not None else dataflow

        df_last = df.Variable(last, last_valid, last_ready, width=1)
        done = last

        df_reuse_data = [
            df.Variable(d,
                        v,
                        r,
                        width=self.datawidth,
                        point=point,
                        signed=signed)
            for d, v, r in zip(reuse_data, data_valid, data_ready)
        ]

        return tuple(df_reuse_data + [df_last, done])

    def read_dataflow_reuse_pattern(self,
                                    port,
                                    addr,
                                    pattern,
                                    reuse_size=1,
                                    num_outputs=1,
                                    cond=None,
                                    point=0,
                                    signed=False):
        """Pattern-addressed variant of read_dataflow_reuse().

        'pattern' is a list/tuple of (size, stride) pairs (a single pair is
        also accepted and wrapped). Addresses are generated from per-level
        counters and offsets; the fetched words are presented on
        'num_outputs' dataflow variables and re-issued while reuse_count
        counts down.

        @param port         index into self.interfaces
        @param addr         base address
        @param pattern      (size, stride) pair or sequence of such pairs;
                            each size must provide bit_length()
        @param reuse_size   reload value of the reuse counter; must provide
                            bit_length()
        @param num_outputs  number of output variables; must be int
        @param cond         start condition (passed to make_condition)
        @param point, signed  forwarded to df.Variable for the data outputs
        @return tuple of the 'num_outputs' data variables followed by the
                'last' dataflow variable and the raw 'last' register (done)
        @raise TypeError  if pattern is not a list/tuple or num_outputs is
                          not an int
        @raise ValueError if pattern is empty
        """

        if not isinstance(pattern, (tuple, list)):
            raise TypeError('pattern must be list or tuple.')

        if not pattern:
            raise ValueError(
                'pattern must have one (size, stride) pair at least.')

        # a bare (size, stride) pair is promoted to a one-element pattern
        if not isinstance(pattern[0], (tuple, list)):
            pattern = (pattern, )

        if not isinstance(num_outputs, int):
            raise TypeError('num_outputs must be int')

        # valid/ready handshake signals, one pair per output plus one for 'last'
        data_valid = [self.m.TmpReg(initval=0) for _ in range(num_outputs)]
        last_valid = self.m.TmpReg(initval=0)
        data_ready = [self.m.TmpWire() for _ in range(num_outputs)]
        last_ready = self.m.TmpWire()

        # every ready wire is assigned constant 1 here
        for r in data_ready:
            r.assign(1)
        last_ready.assign(1)

        # a channel is acknowledged when it is ready or currently not valid
        data_ack = vtypes.Ands(*[
            vtypes.Ors(r, vtypes.Not(v))
            for v, r in zip(data_valid, data_ready)
        ])
        last_ack = vtypes.Ors(last_ready, vtypes.Not(last_valid))

        ext_cond = make_condition(cond)
        data_cond = make_condition(data_ack, last_ack)

        next_addr = self.m.TmpWire(self.addrwidth)
        offset_addr = self.m.TmpWire(self.addrwidth)
        # one offset register per pattern level except the first
        offsets = [
            self.m.TmpReg(self.addrwidth, initval=0) for _ in pattern[1:]
        ]

        # offset_addr = addr + sum of all offset registers
        offset_addr_value = addr
        for offset in offsets:
            offset_addr_value = offset + offset_addr_value
        offset_addr.assign(offset_addr_value)

        # placeholder so that offsets lines up with pattern / count_list
        offsets.insert(0, None)

        count_list = [
            self.m.TmpReg(out_size.bit_length() + 1, initval=0)
            for (out_size, out_stride) in pattern
        ]

        last = self.m.TmpReg(initval=0)
        # reuse_data holds the values currently presented to the outputs;
        # next_reuse_data receives the values fetched for the next round
        reuse_data = [
            self.m.TmpReg(self.datawidth, initval=0)
            for _ in range(num_outputs)
        ]
        next_reuse_data = [
            self.m.TmpReg(self.datawidth, initval=0)
            for _ in range(num_outputs)
        ]

        reuse_count = self.m.TmpReg(reuse_size.bit_length() + 1, initval=0)
        # one-cycle pulse (cleared via Delay(1) below) that reloads reuse_count
        fill_reuse_count = self.m.TmpReg(initval=0)

        prefetch_done = self.m.TmpReg(initval=0)
        fetch_done = self.m.TmpReg(initval=0)

        # Build the combinational next-address selection: update_addr picks
        # offset_addr over "current address + stride_value". 'carry' is true
        # while all levels seen so far have size 1.
        update_addr = None
        stride_value = None
        carry = None

        for offset, count, (out_size,
                            out_stride) in zip(offsets, count_list, pattern):
            if update_addr is None:
                update_addr = count == 0
            else:
                update_addr = vtypes.Mux(carry, count == 0, update_addr)

            if stride_value is None:
                stride_value = out_stride
            else:
                stride_value = vtypes.Mux(carry, out_stride, stride_value)

            if carry is None:
                carry = out_size == 1
            else:
                carry = vtypes.Ands(carry, out_size == 1)

        next_addr.assign(
            vtypes.Mux(update_addr, offset_addr,
                       self.interfaces[port].addr + stride_value))

        fsm = TmpFSM(self.m, self.clk, self.rst)

        # initial state
        # The address is primed one stride_value (the expression built by
        # the loop above) before 'addr'.
        fsm.If(ext_cond)(self.interfaces[port].addr(addr - stride_value),
                         prefetch_done(0), fetch_done(0))

        # The first level's counter starts at out_size, the others at
        # out_size - 1; all offset registers are cleared.
        first = True
        for offset, count, (out_size,
                            out_stride) in zip(offsets, count_list, pattern):
            fsm.If(ext_cond)(count(out_size) if first else count(out_size -
                                                                 1), )
            if offset is not None:
                fsm.If(ext_cond)(offset(0))
            first = False

        fsm.If(ext_cond).goto_next()

        # initial prefetch state
        # One FSM state per output register. Within a state, each pattern
        # level updates its counter/offset under the conditions accumulated
        # from the lower levels (update_count / update_offset start as None
        # for the first level).
        # NOTE(review): Delay(2) presumably matches the read latency of the
        # RAM interface -- confirm.
        for n in next_reuse_data:
            update_count = None
            update_offset = None
            last_one = None
            carry = None

            for offset, count, (out_size,
                                out_stride) in zip(offsets, count_list,
                                                   pattern):
                fsm.If(update_count)(count.dec())
                fsm.If(update_count, count == 0)(count(out_size - 1))
                fsm(self.interfaces[port].addr(next_addr))
                fsm.Delay(2)(n(self.interfaces[port].rdata))

                if offset is not None:
                    fsm.If(update_offset,
                           vtypes.Not(carry))(offset(offset + out_stride))
                    fsm.If(update_offset, count == 0)(offset(0))

                if update_count is None:
                    update_count = count == 0
                else:
                    update_count = vtypes.Ands(update_count, count == 0)

                if update_offset is None:
                    update_offset = vtypes.Mux(out_size == 1, 1, count == 1)
                else:
                    update_offset = vtypes.Ands(update_offset, count == carry)

                if last_one is None:
                    last_one = count == 0
                else:
                    last_one = vtypes.Ands(last_one, count == 0)

                if carry is None:
                    carry = out_size == 1
                else:
                    carry = vtypes.Ands(carry, out_size == 1)

            fsm.goto_next()

            # registered on the state reached by the goto_next() above:
            # all counters at 0 marks the end of the pattern
            fsm.If(last_one)(prefetch_done(1))

        # two further states
        # NOTE(review): presumably wait states for the delayed captures above
        fsm.goto_next()
        fsm.goto_next()

        # initial update state
        for r, n in zip(reuse_data, next_reuse_data):
            fsm(r(n))

        fsm(fetch_done(prefetch_done),
            fill_reuse_count(vtypes.Not(fetch_done)))
        fsm.Delay(1)(fill_reuse_count(0))

        fsm.goto_next()

        # prefetch state
        # Same sequence as the initial prefetch; the FSM loops back to this
        # state until fetch_done is set.
        read_start_state = fsm.current

        for n in next_reuse_data:
            update_count = None
            update_offset = None
            last_one = None
            carry = None

            for offset, count, (out_size,
                                out_stride) in zip(offsets, count_list,
                                                   pattern):
                fsm.If(update_count)(count.dec())
                fsm.If(update_count, count == 0)(count(out_size - 1))
                fsm(self.interfaces[port].addr(next_addr))
                fsm.Delay(2)(n(self.interfaces[port].rdata))

                if offset is not None:
                    fsm.If(update_offset,
                           vtypes.Not(carry))(offset(offset + out_stride))
                    fsm.If(update_offset, count == 0)(offset(0))

                if update_count is None:
                    update_count = count == 0
                else:
                    update_count = vtypes.Ands(update_count, count == 0)

                if update_offset is None:
                    update_offset = vtypes.Mux(out_size == 1, 1, count == 1)
                else:
                    update_offset = vtypes.Ands(update_offset, count == carry)

                if last_one is None:
                    last_one = count == 0
                else:
                    last_one = vtypes.Ands(last_one, count == 0)

                if carry is None:
                    carry = out_size == 1
                else:
                    carry = vtypes.Ands(carry, out_size == 1)

            fsm.goto_next()

            fsm.If(last_one)(prefetch_done(1))

        fsm.goto_next()
        fsm.goto_next()

        # update state
        # Only taken once the current values have been fully issued
        # (reuse_count == 0) and the outputs are acknowledged.
        for r, n in zip(reuse_data, next_reuse_data):
            fsm.If(data_cond, reuse_count == 0)(r(n))

        fsm.If(data_cond,
               reuse_count == 0)(fetch_done(prefetch_done),
                                 fill_reuse_count(vtypes.Not(fetch_done)))
        fsm.Delay(1)(fill_reuse_count(0))

        # next -> prefetch state or initial state
        fsm.If(data_cond, reuse_count == 0, fetch_done).goto_init()
        fsm.If(data_cond, reuse_count == 0,
               vtypes.Not(fetch_done)).goto(read_start_state)

        # output signal control
        # clear valid/last once the 'last' channel has been accepted
        self.seq.If(data_cond, last_valid)(last(0), [d(0) for d in data_valid],
                                           last_valid(0))

        # reload the reuse counter on the fill pulse
        self.seq.If(fill_reuse_count)(reuse_count(reuse_size))

        # while the counter is non-zero, issue the data and count down
        self.seq.If(data_cond, reuse_count > 0)(reuse_count.dec(),
                                                [d(1) for d in data_valid],
                                                last_valid(1), last(0))

        # assert 'last' on the final issue after the last fetch
        self.seq.If(data_cond, reuse_count == 1, fetch_done)(last(1))

        # use the instance's dataflow manager if present, else the
        # 'dataflow' name in scope at module level
        df = self.df if self.df is not None else dataflow

        df_last = df.Variable(last, last_valid, last_ready, width=1)
        done = last

        df_reuse_data = [
            df.Variable(d,
                        v,
                        r,
                        width=self.datawidth,
                        point=point,
                        signed=signed)
            for d, v, r in zip(reuse_data, data_valid, data_ready)
        ]

        return tuple(df_reuse_data + [df_last, done])

    def read_dataflow_reuse_multidim(self,
                                     port,
                                     addr,
                                     shape,
                                     order=None,
                                     reuse_size=1,
                                     num_outputs=1,
                                     cond=None,
                                     point=0,
                                     signed=False):
        """Multi-dimensional front end of read_dataflow_reuse_pattern().

        The (shape, order) pair is converted into a list of (size, stride)
        pairs by _to_pattern() and the read is delegated to
        read_dataflow_reuse_pattern().

        @param port         index into self.interfaces
        @param addr         base address
        @param shape        sizes of each dimension
        @param order        dimension indices in traversal order; when None
                            the dimensions are visited as 0, 1, ...
        @param reuse_size   forwarded to read_dataflow_reuse_pattern()
        @param num_outputs  forwarded to read_dataflow_reuse_pattern()
        @param cond, point, signed  forwarded unchanged
        @return the tuple returned by read_dataflow_reuse_pattern()
                (data variables..., last, done)
        """

        if order is None:
            order = list(range(len(shape)))

        pattern = self._to_pattern(shape, order)

        # Delegate to the *reuse* pattern reader: the plain pattern reader
        # has no reuse_size/num_outputs parameters, so passing them there
        # positionally would bind them to other parameters.
        return self.read_dataflow_reuse_pattern(port,
                                                addr,
                                                pattern,
                                                reuse_size=reuse_size,
                                                num_outputs=num_outputs,
                                                cond=cond,
                                                point=point,
                                                signed=signed)

    def write(self, port, addr, wdata, cond=None):
        """Alias of write_data(): forwards all arguments unchanged.

        @return the result of write_data() (None)
        """
        result = self.write_data(port, addr, wdata, cond)
        return result

    def write_data(self, port, addr, wdata, cond=None):
        """Schedule a single-word write on the given port.

        Registers, on self.seq, the assignment of address, write data and
        write-enable for interface 'port', followed by a delayed
        de-assertion of write-enable.

        @param port   index into self.interfaces / self._write_disabled
        @param addr   address value
        @param wdata  data value
        @param cond   optional condition applied to the write
        @return None
        @raise TypeError if writing is disabled on this port
        """

        if self._write_disabled[port]:
            raise TypeError('Write disabled.')

        # An optional condition is registered first and applies to the
        # statement added right after it.
        if cond is not None:
            self.seq.If(cond)

        iface = self.interfaces[port]

        self.seq(iface.addr(addr), iface.wdata(wdata), iface.wenable(1))

        # de-assert write-enable one step after the write
        self.seq.Then().Delay(1)(iface.wenable(0))

    def write_dataflow(self,
                       port,
                       addr,
                       data,
                       length=1,
                       stride=1,
                       cond=None,
                       when=None):
        """Write 'length' words taken from a dataflow variable.

        Words are written to successive addresses starting at 'addr' and
        advancing by 'stride'.

        @param port    index into self.interfaces / self._write_disabled
        @param addr    start address
        @param data    dataflow variable supplying the write data
        @param length  number of words; must provide bit_length()
                       (i.e. a Python int)
        @param stride  address increment between writes
        @param cond    start condition (passed to make_condition)
        @param when    optional write qualifier; a dataflow variable is read
                       together with 'data', anything else is passed to
                       make_condition
        @return done   the 'last' register, set when the final word is written
        @raise TypeError if writing is disabled on this port

        'data' and 'when' must be dataflow variables
        """

        if self._write_disabled[port]:
            raise TypeError('Write disabled.')

        # remaining number of words to write
        counter = self.m.TmpReg(length.bit_length() + 1, initval=0)
        last = self.m.TmpReg(initval=0)

        ext_cond = make_condition(cond)
        # input data is consumed only while a transfer is in progress
        data_cond = make_condition(counter > 0, vtypes.Not(last))

        if when is None or not isinstance(when, df_numeric):
            raw_data, raw_valid = data.read(cond=data_cond)
        else:
            # 'when' is a dataflow variable: read it in step with 'data'
            data_list, raw_valid = read_multi(self.m,
                                              data,
                                              when,
                                              cond=data_cond)
            raw_data = data_list[0]
            when = data_list[1]

        # qualify the valid signal with 'when' if one was given
        when_cond = make_condition(when, ready=data_cond)
        if when_cond is not None:
            raw_valid = vtypes.Ands(when_cond, raw_valid)

        # start: prime the address one stride before 'addr' because every
        # write step adds 'stride' first
        self.seq.If(ext_cond, counter == 0)(
            self.interfaces[port].addr(addr - stride),
            counter(length),
        )

        # write step
        self.seq.If(raw_valid, counter > 0)(
            self.interfaces[port].addr(self.interfaces[port].addr + stride),
            self.interfaces[port].wdata(raw_data),
            self.interfaces[port].wenable(1), counter.dec())

        # flag the final word
        self.seq.If(raw_valid, counter == 1)(last(1))

        # de-assert
        self.seq.Delay(1)(self.interfaces[port].wenable(0), last(0))

        done = last

        return done

    def write_dataflow_pattern(self,
                               port,
                               addr,
                               data,
                               pattern,
                               cond=None,
                               when=None):
        """Write words from a dataflow variable using a (size, stride) pattern.

        Each pattern level has a counter and an offset register; the write
        address is 'addr' plus the sum of all offsets.

        @param port     index into self.interfaces / self._write_disabled
        @param addr     base address
        @param data     dataflow variable supplying the write data
        @param pattern  (size, stride) pair or sequence of such pairs; each
                        size must provide bit_length()
        @param cond     start condition (passed to make_condition)
        @param when     optional write qualifier; a dataflow variable is read
                        together with 'data', anything else is passed to
                        make_condition
        @return done    the 'last' register, set when all counters are 0 on
                        a valid word
        @raise TypeError  if writing is disabled or pattern is not a
                          list/tuple
        @raise ValueError if pattern is empty

        'data' and 'when' must be dataflow variables
        """

        if self._write_disabled[port]:
            raise TypeError('Write disabled.')

        if not isinstance(pattern, (tuple, list)):
            raise TypeError('pattern must be list or tuple.')

        if not pattern:
            raise ValueError(
                'pattern must have one (size, stride) pair at least.')

        # a bare (size, stride) pair is promoted to a one-element pattern
        if not isinstance(pattern[0], (tuple, list)):
            pattern = (pattern, )

        last = self.m.TmpReg(initval=0)

        # set while a pattern write is in progress
        running = self.m.TmpReg(initval=0)

        ext_cond = make_condition(cond)
        # input data is consumed only while running and not on the last flag
        data_cond = make_condition(running, vtypes.Not(last))

        if when is None or not isinstance(when, df_numeric):
            raw_data, raw_valid = data.read(cond=data_cond)
        else:
            # 'when' is a dataflow variable: read it in step with 'data'
            data_list, raw_valid = read_multi(self.m,
                                              data,
                                              when,
                                              cond=data_cond)
            raw_data = data_list[0]
            when = data_list[1]

        # qualify the valid signal with 'when' if one was given
        when_cond = make_condition(when, ready=data_cond)
        if when_cond is not None:
            raw_valid = vtypes.Ands(when_cond, raw_valid)

        offset_addr = self.m.TmpWire(self.addrwidth)
        offsets = [self.m.TmpReg(self.addrwidth, initval=0) for _ in pattern]

        # offset_addr = addr + sum of all offset registers
        offset_addr_value = addr
        for offset in offsets:
            offset_addr_value = offset + offset_addr_value
        offset_addr.assign(offset_addr_value)

        count_list = [
            self.m.TmpReg(out_size.bit_length() + 1, initval=0)
            for (out_size, out_stride) in pattern
        ]

        # start
        self.seq.If(ext_cond, vtypes.Not(running))(running(1))

        # write step
        self.seq.If(raw_valid,
                    running)(self.interfaces[port].addr(offset_addr),
                             self.interfaces[port].wdata(raw_data),
                             self.interfaces[port].wenable(1))

        # update_count accumulates "all lower-level counters are 0"; it is
        # None for the first level, which therefore updates on every word.
        update_count = None
        last_one = None

        for offset, count, (out_size,
                            out_stride) in zip(offsets, count_list, pattern):
            # initialise counter and offset at start
            self.seq.If(ext_cond, vtypes.Not(running))(count(out_size - 1),
                                                       offset(0))
            # step this level
            self.seq.If(raw_valid, running,
                        update_count)(count.dec(), offset(offset + out_stride))
            # wrap this level when its counter has reached 0
            self.seq.If(raw_valid, running, update_count,
                        count == 0)(count(out_size - 1), offset(0))

            if update_count is None:
                update_count = count == 0
            else:
                update_count = vtypes.Ands(update_count, count == 0)

            if last_one is None:
                last_one = count == 0
            else:
                last_one = vtypes.Ands(last_one, count == 0)

        # all counters at 0 on a valid word: finish and raise 'last'
        self.seq.If(raw_valid, last_one)(running(0), last(1))

        # de-assert
        self.seq.Delay(1)(self.interfaces[port].wenable(0), last(0))

        done = last

        return done

    def write_dataflow_multidim(self,
                                port,
                                addr,
                                data,
                                shape,
                                order=None,
                                cond=None,
                                when=None):
        """Multi-dimensional front end of write_dataflow_pattern().

        The (shape, order) pair is converted into a list of (size, stride)
        pairs by _to_pattern() and the write is delegated to
        write_dataflow_pattern().

        @param order  dimension indices in traversal order; when None the
                      dimensions are visited as 0, 1, ..., len(shape) - 1
        @return done  (the result of write_dataflow_pattern())

        'data' and 'when' must be dataflow variables
        """

        # Fall back to the natural dimension order when none is given.
        dims = list(range(len(shape))) if order is None else order

        return self.write_dataflow_pattern(port,
                                           addr,
                                           data,
                                           self._to_pattern(shape, dims),
                                           cond=cond,
                                           when=when)

    def _to_pattern(self, shape, order):
        pattern = []
        for p in order:
            if not isinstance(p, int):
                raise TypeError("Values of 'order' must be 'int', not %s" %
                                str(type(p)))
            size = shape[p]
            basevalue = 1 if isinstance(size, int) else vtypes.Int(1)
            stride = functools.reduce(lambda x, y: x * y, shape[:p],
                                      basevalue) if p > 0 else basevalue
            pattern.append((size, stride))
        return pattern
Example #25
0
class Mutex(object):
    """Hardware mutex shared between FSMs.

    The lock state is kept in two registers: lock_reg (locked flag) and
    lock_id (ID of the current holder). Each FSM is identified by its name;
    IDs are handed out in order of first use by _get_id().
    """
    __intrinsics__ = ('lock', 'try_lock', 'unlock', 'acquire', 'release')

    def __init__(self, m, name, clk, rst, width=32):
        """Create the lock registers and the Seq that drives them.

        @param width  bit width of the lock_id register; limits the number
                      of distinct lock IDs
        """

        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.width = width

        self.seq = Seq(self.m, self.name, self.clk, self.rst)

        self.lock_reg = self.m.Reg('_'.join(['', self.name, 'lock_reg']),
                                   initval=0)
        self.lock_id = self.m.Reg('_'.join(['', self.name, 'lock_id']),
                                  self.width,
                                  initval=0)

        # FSM name -> lock ID, and the next ID to hand out
        self.id_map = {}
        self.id_map_count = 0

    def lock(self, fsm):
        """Blocking lock: the FSM waits until it holds the mutex.

        @return 1
        @raise ValueError if the lock ID does not fit in 'width' bits
        """
        new_lock_id = self._get_checked_id(fsm)

        # try
        try_state = fsm.current

        state_cond = fsm.state == fsm.current
        try_cond = vtypes.Not(self.lock_reg)
        fsm_cond = vtypes.Ors(try_cond, self.lock_id == new_lock_id)

        self.seq.If(state_cond, try_cond)(self.lock_reg(1),
                                          self.lock_id(new_lock_id))

        fsm.If(fsm_cond).goto_next()

        # verify
        cond = vtypes.Ands(self.lock_reg, self.lock_id == new_lock_id)
        fsm.If(vtypes.Not(cond)).goto(try_state)  # try again
        fsm.If(cond).goto_next()  # OK

        return 1

    def try_lock(self, fsm):
        """Non-blocking lock attempt.

        @return a temporary register holding whether this FSM owns the lock
        @raise ValueError if the lock ID does not fit in 'width' bits
        """
        new_lock_id = self._get_checked_id(fsm)

        # try
        state_cond = fsm.state == fsm.current
        try_cond = vtypes.Not(self.lock_reg)

        self.seq.If(state_cond, try_cond)(self.lock_reg(1),
                                          self.lock_id(new_lock_id))

        fsm.goto_next()

        # verify
        cond = vtypes.And(self.lock_reg, self.lock_id == new_lock_id)
        result = self.m.TmpReg(initval=0)
        fsm(result(cond))
        fsm.goto_next()

        return result

    def unlock(self, fsm):
        """Release the mutex if this FSM's ID matches lock_id.

        @return 0
        @raise ValueError if the lock ID does not fit in 'width' bits
        """
        new_lock_id = self._get_checked_id(fsm)

        state_cond = fsm.state == fsm.current

        self.seq.If(state_cond, self.lock_id == new_lock_id)(self.lock_reg(0))

        fsm.goto_next()

        return 0

    def _get_id(self, name):
        """Return the lock ID for 'name', allocating the next one if new."""
        if name not in self.id_map:
            self.id_map[name] = self.id_map_count
            self.id_map_count += 1

        return self.id_map[name]

    def _get_checked_id(self, fsm):
        """Return the lock ID of 'fsm', checking it fits in 'width' bits."""
        lock_id = self._get_id(fsm.name)

        if lock_id > 2**self.width - 1:
            raise ValueError('too many lock IDs')

        return lock_id

    def acquire(self, fsm, blocking=True):
        """ alias of lock() (blocking) or try_lock() (non-blocking) """

        if not isinstance(blocking, (bool, int)):
            raise TypeError('blocking argument must be bool')

        if blocking:
            return self.lock(fsm)

        return self.try_lock(fsm)

    def release(self, fsm):
        """ alias of unlock() """

        return self.unlock(fsm)
Example #26
0
class FIFO(_MutexFunction):
    """FIFO wrapper: instantiates a FIFO definition and offers RTL-level
    (enq_rtl / deq_rtl) and FSM-level (enq / deq / try_* / is_*) access,
    plus a locally maintained entry counter.
    """
    __intrinsics__ = ('enq', 'deq', 'try_enq', 'try_deq',
                      'is_empty', 'is_almost_empty',
                      'is_full', 'is_almost_full') + _MutexFunction.__intrinsics__

    def __init__(self, m, name, clk, rst, datawidth=32, addrwidth=4):
        """Create the FIFO interfaces, the FIFO instance and the counter.

        @param datawidth  width of the data ports
        @param addrwidth  address width; the counter is addrwidth + 1 bits
                          and _max_size is 2 ** addrwidth - 1
        """

        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst

        self.datawidth = datawidth
        self.addrwidth = addrwidth

        self.wif = FifoWriteInterface(self.m, name, datawidth)
        self.rif = FifoReadInterface(self.m, name, datawidth)

        self.definition = mkFifoDefinition(name, datawidth, addrwidth)

        self.inst = self.m.Instance(self.definition, 'inst_' + name,
                                    ports=m.connect_ports(self.definition))

        self.seq = Seq(m, name, clk, rst)

        # entry counter
        # _max_size is computed as a Python int when addrwidth is an int,
        # otherwise as a vtypes expression
        self._max_size = (2 ** self.addrwidth - 1 if isinstance(self.addrwidth, int) else
                          vtypes.Int(2) ** self.addrwidth - 1)

        self._count = self.m.Reg(
            'count_' + name, self.addrwidth + 1, initval=0)

        # counter update: unchanged on simultaneous accepted enq and deq,
        # +1 on an accepted enq only, -1 on an accepted deq only
        self.seq.If(
            vtypes.Ands(vtypes.Ands(self.wif.enq, vtypes.Not(self.wif.full)),
                        vtypes.Ands(self.rif.deq, vtypes.Not(self.rif.empty))))(
            self._count(self._count)
        ).Elif(vtypes.Ands(self.wif.enq, vtypes.Not(self.wif.full)))(
            self._count.inc()
        ).Elif(vtypes.Ands(self.rif.deq, vtypes.Not(self.rif.empty)))(
            self._count.dec()
        )

        self._enq_disabled = False
        self._deq_disabled = False

        self.mutex = None

    def disable_enq(self):
        """Tie enq to 0 and make later enq_rtl() calls raise TypeError."""
        self.seq(
            self.wif.enq(0)
        )
        self._enq_disabled = True

    def disable_deq(self):
        """Tie deq to 0 and make later deq_rtl() calls raise TypeError."""
        self.seq(
            self.rif.deq(0)
        )
        self._deq_disabled = True

    def enq_rtl(self, wdata, cond=None, delay=0):
        """ Enque

        @param wdata  value to enqueue
        @param cond   optional condition for the enqueue
        @param delay  additional delay added to the Seq's current delay
        @return (ack, ready): ack is "not full and enq"; ready is
                "not almost_full" when the total delay is 0, otherwise
                "count + (total delay + 1) < max size"
        @raise TypeError if enqueue has been disabled
        """

        if self._enq_disabled:
            raise TypeError('Enq disabled.')

        if cond is not None:
            self.seq.If(cond)

        current_delay = self.seq.current_delay

        not_full = vtypes.Not(self.wif.full)
        ack = vtypes.Ands(not_full, self.wif.enq)
        if current_delay + delay == 0:
            ready = vtypes.Not(self.wif.almost_full)
        else:
            ready = self._count + (current_delay + delay + 1) < self._max_size

        # write data is updated only while the FIFO is not full
        self.seq.Delay(current_delay + delay).EagerVal().If(not_full)(
            self.wif.wdata(wdata)
        )
        self.seq.Then().Delay(current_delay + delay)(
            self.wif.enq(1)
        )

        # de-assert
        self.seq.Delay(current_delay + delay + 1)(
            self.wif.enq(0)
        )

        return ack, ready

    def deq_rtl(self, cond=None, delay=0):
        """ Deque

        @param cond   optional condition for the dequeue
        @param delay  additional delay added to the Seq's current delay
        @return (rdata, rvalid): the read-data port and a temporary register
                set to "not empty and deq" one step after deq is asserted
        @raise TypeError if dequeue has been disabled
        """

        if self._deq_disabled:
            raise TypeError('Deq disabled.')

        if cond is not None:
            self.seq.If(cond)

        not_empty = vtypes.Not(self.rif.empty)

        current_delay = self.seq.current_delay

        self.seq.Delay(current_delay + delay)(
            self.rif.deq(1)
        )

        rdata = self.rif.rdata
        rvalid = self.m.TmpReg(initval=0)

        self.seq.Then().Delay(current_delay + delay + 1)(
            rvalid(vtypes.Ands(not_empty, self.rif.deq))
        )

        # de-assert
        self.seq.Delay(current_delay + delay + 1)(
            self.rif.deq(0)
        )
        self.seq.Delay(current_delay + delay + 2)(
            rvalid(0)
        )

        return rdata, rvalid

    @property
    def wdata(self):
        """Write-data port of the write interface."""
        return self.wif.wdata

    @property
    def empty(self):
        """'empty' flag of the read interface."""
        return self.rif.empty

    @property
    def almost_empty(self):
        """'almost_empty' flag of the read interface."""
        return self.rif.almost_empty

    @property
    def rdata(self):
        """Read-data port of the read interface."""
        return self.rif.rdata

    @property
    def full(self):
        """'full' flag of the write interface."""
        return self.wif.full

    @property
    def almost_full(self):
        """'almost_full' flag of the write interface."""
        return self.wif.almost_full

    @property
    def count(self):
        """The locally maintained entry counter register."""
        return self._count

    @property
    def space(self):
        """Expression for the remaining space: max size minus count."""
        # wrap a Python int so the subtraction yields a vtypes expression
        if isinstance(self._max_size, int):
            return vtypes.Int(self._max_size) - self.count
        return self._max_size - self.count

    def has_space(self, num=1):
        """Return True for num < 1, else the expression
        "count + num < max size"."""
        if num < 1:
            return True
        return (self._count + num < self._max_size)

    def enq(self, fsm, wdata):
        """FSM-level enqueue: the FSM advances once 'ready' holds.

        @return 0
        """
        cond = fsm.state == fsm.current

        ack, ready = self.enq_rtl(wdata, cond=cond)
        fsm.If(ready).goto_next()

        return 0

    def deq(self, fsm):
        """FSM-level dequeue: waits for a non-empty FIFO, then for rvalid.

        @return a temporary signed register holding the dequeued data
        """
        cond = fsm.state == fsm.current

        rdata, rvalid = self.deq_rtl(cond=cond)
        fsm.If(vtypes.Not(self.empty)).goto_next()
        fsm.goto_next()

        rdata_reg = self.m.TmpReg(self.datawidth, initval=0, signed=True)

        fsm.If(rvalid)(
            rdata_reg(rdata)
        )
        fsm.If(rvalid).goto_next()

        return rdata_reg

    def try_enq(self, fsm, wdata):
        """FSM-level non-blocking enqueue.

        @return a temporary register capturing 'ack'
        """
        cond = fsm.state == fsm.current

        ack, ready = self.enq_rtl(wdata, cond=cond)
        fsm.goto_next()

        ack_reg = self.m.TmpReg(initval=0)
        fsm(
            ack_reg(ack)
        )
        fsm.goto_next()

        return ack_reg

    def try_deq(self, fsm):
        """FSM-level non-blocking dequeue.

        @return (rdata_reg, rvalid_reg): temporary registers capturing the
                read data and its valid flag
        """
        cond = fsm.state == fsm.current

        rdata, rvalid = self.deq_rtl(cond=cond)
        fsm.goto_next()
        fsm.goto_next()

        rdata_reg = self.m.TmpReg(self.datawidth, initval=0, signed=True)
        rvalid_reg = self.m.TmpReg(initval=0)

        fsm(
            rdata_reg(rdata),
            rvalid_reg(rvalid)
        )
        fsm.goto_next()

        return rdata_reg, rvalid_reg

    def is_almost_empty(self, fsm):
        """Advance the FSM by one state and return the almost_empty flag."""
        fsm.goto_next()
        return self.almost_empty

    def is_empty(self, fsm):
        """Advance the FSM by one state and return the empty flag."""
        fsm.goto_next()
        return self.empty

    def is_almost_full(self, fsm):
        """Advance the FSM by one state and return the almost_full flag."""
        fsm.goto_next()
        return self.almost_full

    def is_full(self, fsm):
        """Advance the FSM by one state and return the full flag."""
        fsm.goto_next()
        return self.full
Example #27
0
class Pipeline(vtypes.VeriloggenNode):
    """ Pipeline Generator

    Collects pipeline variables and the register updates implementing them
    in an internal ``Seq``; ``make_always()`` emits the resulting clocked
    always block into the module.
    """

    def __init__(self, m, name, clk, rst, width=32):
        """Create an empty pipeline.

        Args:
            m: module that receives the generated registers and wires.
            name: name of the internal ``Seq``; also embedded in the names of
                generated signals.
            clk: clock used by the always block.
            rst: reset tested by the always block.
            width: default data width of stage registers.
        """
        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.width = width

        self.tmp_count = 0
        self.max_stage_id = 0
        self.vars = []

        self.seq = Seq(self.m, self.name, clk, rst, nohook=True)
        self.data_visitor = DataVisitor(self)

        self.done = False

    #-------------------------------------------------------------------------
    def input(self, data, valid=None, ready=None, width=None):
        """Register ``data`` (with optional ``valid``/``ready``) as a stage-0 variable.

        ``width`` is accepted for interface compatibility but is not used.

        Raises:
            TypeError: if ``ready`` is given and is neither a ``vtypes.Wire``
                nor a ``vtypes.Output``.
        """
        if ready is not None and not isinstance(ready, (vtypes.Wire, vtypes.Output)):
            raise TypeError('ready port of PipelineVariable must be Wire or Output, not %s' %
                            str(type(ready)))
        ret = _PipelineVariable(self, 0, data, valid, ready,
                                _PipelineInterface(data, valid, ready))
        self.vars.append(ret)
        return ret

    #-------------------------------------------------------------------------
    # self.__call__() calls this method
    def stage(self, data, initval=0, width=None, preg=None):
        """Insert a pipeline register after ``data`` and return the new variable.

        Args:
            data: expression to be registered.
            initval: initial value of the stage register.
            width: register width; defaults to ``self.width``.
            preg: when it is a ``_PipelineVariable``, the new variable is
                passed to its ``_add_preg``.
        """
        if width is None:
            width = self.width
        stage_id, raw_data, raw_valid, raw_ready = self.data_visitor.visit(
            data)
        tmp_data, tmp_valid, tmp_ready = self._make_tmp(raw_data, raw_valid, raw_ready,
                                                        width, initval)
        next_stage_id = stage_id + 1 if stage_id is not None else None

        ret = _PipelineVariable(self, next_stage_id,
                                tmp_data, tmp_valid, tmp_ready, data)

        self.vars.append(ret)
        if isinstance(preg, _PipelineVariable):
            preg._add_preg(next_stage_id, ret)

        # Track the deepest stage produced so far.
        if next_stage_id is not None and next_stage_id > self.max_stage_id:
            self.max_stage_id = next_stage_id

        return ret

    #-------------------------------------------------------------------------
    # Accumulator
    def acc_and(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.And``."""
        return self._accumulate([vtypes.And], data, width, initval, resetcond)

    def acc_nand(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.And`` followed by ``vtypes.Unot``."""
        return self._accumulate([vtypes.And, vtypes.Unot], data, width, initval, resetcond)

    def acc_or(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Or``."""
        return self._accumulate([vtypes.Or], data, width, initval, resetcond)

    def acc_xor(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Xor``."""
        return self._accumulate([vtypes.Xor], data, width, initval, resetcond)

    def acc_xnor(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Xor`` followed by ``vtypes.Unot``."""
        return self._accumulate([vtypes.Xor, vtypes.Unot], data, width, initval, resetcond)

    def acc_nor(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Or`` followed by ``vtypes.Unot``."""
        return self._accumulate([vtypes.Or, vtypes.Unot], data, width, initval, resetcond)

    def acc_add(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Plus``."""
        return self._accumulate([vtypes.Plus], data, width, initval, resetcond)

    def acc_sub(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Minus``."""
        return self._accumulate([vtypes.Minus], data, width, initval, resetcond)

    def acc_mul(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Times``."""
        return self._accumulate([vtypes.Times], data, width, initval, resetcond)

    def acc_div(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Divide``."""
        return self._accumulate([vtypes.Divide], data, width, initval, resetcond)

    def acc_mod(self, data, initval=0, resetcond=None, width=None):
        """Accumulate ``data`` with ``vtypes.Mod``."""
        return self._accumulate([vtypes.Mod], data, width, initval, resetcond)

    def acc_max(self, data, initval=0, resetcond=None, width=None):
        """Accumulate the larger of the stored value and ``data``."""
        def op(left, right):
            return vtypes.Cond(left > right, left, right)
        return self._accumulate([op], data, width, initval, resetcond, 'max')

    def acc_min(self, data, initval=0, resetcond=None, width=None):
        """Accumulate the smaller of the stored value and ``data``."""
        def op(left, right):
            return vtypes.Cond(left < right, left, right)
        return self._accumulate([op], data, width, initval, resetcond, 'min')

    def acc_custom(self, data, ops, initval=0, resetcond=None, width=None, label=None):
        """Accumulate ``data`` with caller-supplied ``ops``.

        ``ops`` may be a single operator or a tuple/list of operators that
        are applied in order.
        """
        if not isinstance(ops, (tuple, list)):
            ops = [ops]
        return self._accumulate(ops, data, width, initval, resetcond, label)

    #-------------------------------------------------------------------------
    def make_always(self, reset=(), body=()):
        """Emit the clocked always block; may be called only once.

        Args:
            reset: extra statements placed in the reset branch.
            body: extra statements placed in the non-reset branch.

        Raises:
            ValueError: if this method has already been called.
        """
        if self.done:
            raise ValueError('make_always() has been already called.')
        self.done = True

        for var in self.vars:
            if var.output_vars is not None:
                var.make_output()

        self.m.Always(vtypes.Posedge(self.clk))(
            vtypes.If(self.rst)(
                reset,
                self.make_reset()
            )(
                body,
                self.make_code()
            ))

    #-------------------------------------------------------------------------
    def make_reset(self):
        """Return the reset statements of the internal ``Seq``."""
        return self.seq.make_reset()

    #-------------------------------------------------------------------------
    def make_code(self):
        """Return the body statements of the internal ``Seq``."""
        return self.seq.make_code()

    #-------------------------------------------------------------------------
    def draw_graph(self, filename='out.png', prog='dot'):
        """Delegate to ``_draw_graph`` with this pipeline."""
        _draw_graph(self, filename, prog)

    #-------------------------------------------------------------------------
    def _accumulate(self, ops, data, width=None, initval=0, resetcond=None, oplabel=None):
        """Create an accumulator register fed by ``data`` through ``ops``.

        When ``resetcond`` is given, the new variable's ``reset`` is called
        with ``resetcond`` and ``initval``.
        """
        if width is None:
            width = self.width
        stage_id, raw_data, raw_valid, raw_ready = self.data_visitor.visit(
            data)
        tmp_data, tmp_valid, tmp_ready = self._make_tmp(raw_data, raw_valid, raw_ready,
                                                        width, initval, acc_ops=ops)
        next_stage_id = stage_id + 1 if stage_id is not None else None

        ret = _PipelineVariable(self, next_stage_id, tmp_data, tmp_valid, tmp_ready,
                                data, ops, resetcond, initval, oplabel)

        if resetcond is not None:
            ret.reset(resetcond, initval)

        self.vars.append(ret)

        return ret

    #-------------------------------------------------------------------------
    def _add_reg(self, prefix, count, width=None, initval=0):
        """Add a register named ``_<name>_<prefix>_<count>`` to the module."""
        tmp_name = '_'.join(['', self.name, prefix, str(count)])
        tmp = self.m.Reg(tmp_name, width, initval=initval)
        return tmp

    def _add_wire(self, prefix, count, width=None):
        """Add a wire named ``_<name>_<prefix>_<count>`` to the module."""
        tmp_name = '_'.join(['', self.name, prefix, str(count)])
        tmp = self.m.Wire(tmp_name, width)
        return tmp

    #-------------------------------------------------------------------------
    def _make_tmp(self, data, valid, ready, width=None, initval=0, acc_ops=()):
        """Create the data/valid/ready signals of one pipeline register.

        Args:
            data: expression stored into the new data register.
            valid: upstream valid expression, or ``None``.
            ready: iterable of upstream ready signals (entries may be
                ``None``); ``None`` is treated like an empty iterable.
            width: width of the data register.
            initval: initial value of the data register.
            acc_ops: operators applied to ``(tmp_data, data)`` in order to
                build an accumulator.

        Returns:
            ``(tmp_data, tmp_valid, tmp_ready)``; ``tmp_valid`` is ``None``
            when ``valid`` is ``None`` and ``tmp_ready`` is ``None`` when
            ``ready`` is empty.

        Raises:
            TypeError: if an accumulator operator yields something that is
                not a ``vtypes._Numeric``.
        """
        # The truthiness test below already tolerates None; make the later
        # loops tolerate it as well.
        if ready is None:
            ready = ()

        tmp_data = self._add_reg(
            'data', self.tmp_count, width=width, initval=initval)

        if valid is not None:
            tmp_valid = self._add_reg('valid', self.tmp_count, initval=0)
        else:
            tmp_valid = None

        if ready:
            tmp_ready = self._add_wire('ready', self.tmp_count)
        else:
            tmp_ready = None

        self.tmp_count += 1

        # all ready
        all_ready = None
        for r in ready:
            if r is None:
                continue
            if all_ready is None:
                all_ready = r
            else:
                all_ready = vtypes.AndList(all_ready, r)

        # This register may accept new contents when downstream is ready or
        # when it holds nothing valid; without a valid bit only the
        # downstream ready matters.
        if tmp_ready is None:
            accept = None
        elif tmp_valid is not None:
            accept = vtypes.OrList(tmp_ready, vtypes.Not(tmp_valid))
        else:
            accept = tmp_ready

        # data
        data_cond_vars = []
        if valid is not None:
            data_cond_vars.append(valid)
        if tmp_ready is not None:
            data_cond_vars.append(all_ready)
            data_cond_vars.append(accept)

        if len(data_cond_vars) == 0:
            data_cond = None
        elif len(data_cond_vars) == 1:
            data_cond = data_cond_vars[0]
        else:
            data_cond = vtypes.AndList(*data_cond_vars)

        # Accumulator
        for op in acc_ops:
            if not isinstance(op, type):
                data = op(tmp_data, data)
            elif issubclass(op, vtypes._BinaryOperator):
                data = op(tmp_data, data)
            elif issubclass(op, vtypes._UnaryOperator):
                data = op(data)

            if not isinstance(data, vtypes._Numeric):
                raise TypeError("Operator '%s' returns unsupported object type '%s'."
                                % (str(op), str(type(data))))

        self.seq.add(tmp_data(data), cond=data_cond)

        # valid (only meaningful when a valid register exists)
        if tmp_valid is not None:
            if tmp_ready is not None:
                valid_cond = vtypes.AndList(all_ready, accept)
                self.seq.add(tmp_valid(0), cond=vtypes.AndList(
                    tmp_valid, tmp_ready))
            else:
                valid_cond = None
            self.seq.add(tmp_valid(valid), cond=valid_cond)

        # ready
        if tmp_ready is not None:
            if tmp_valid is not None:
                ordy = vtypes.AndList(accept, valid)
            else:
                # No valid handshake: do not build expressions around None.
                ordy = accept
            for r in ready:
                _connect_ready(self.m, r, ordy)

        return tmp_data, tmp_valid, tmp_ready

    #-------------------------------------------------------------------------
    def _make_prev(self, data, valid, ready, width=None, initval=0):
        """Create a register holding ``data``, updated under ``valid``/``ready``.

        The register is written when every given one of ``valid`` and
        ``ready`` holds, or unconditionally when both are ``None``.

        Returns ``(tmp_data, valid, ready)`` with ``valid`` and ``ready``
        passed through unchanged.
        """
        tmp_data = self._add_reg(
            'data', self.tmp_count, width=width, initval=initval)
        tmp_valid = valid
        tmp_ready = ready

        self.tmp_count += 1

        if valid is not None and ready is not None:
            data_cond = vtypes.AndList(valid, ready)
        elif valid is not None:
            data_cond = valid
        elif ready is not None:
            data_cond = ready
        else:
            data_cond = None

        self.seq.add(tmp_data(data), cond=data_cond)

        return tmp_data, tmp_valid, tmp_ready

    #-------------------------------------------------------------------------
    def __call__(self, data, initval=0, width=None):
        """Shorthand for ``stage(data, initval=initval, width=width)``."""
        return self.stage(data, initval=initval, width=width)
Example #28
0
    def implement(self,
                  m=None,
                  clock=None,
                  reset=None,
                  aswire=None,
                  seq_name=None):
        """Implement the actual registers and operations in Verilog.

        Arguments left as ``None`` fall back to the matching attribute of
        this object (``module``, ``clock``, ``reset``, ``aswire``).  A fresh
        ``Seq`` is created when ``self.seq`` is missing or already done;
        ``seq_name`` names it (default ``'_dataflow_seq_<object_id>'``).

        Returns the module that was implemented into.
        """

        def by_id(obj):
            return obj.object_id

        m = self.module if m is None else m
        clock = self.clock if clock is None else clock
        reset = self.reset if reset is None else reset

        if self.seq is not None and not self.seq.done:
            seq = self.seq
        else:
            if seq_name is None:
                seq_name = '_dataflow_seq_%d' % self.object_id
            seq = Seq(m, seq_name, clock, reset)

        aswire = self.aswire if aswire is None else aswire

        # for mult and div
        m._clock = clock
        m._reset = reset

        nodes = self.nodes

        # collect input and output variables reachable from the nodes
        collect_inputs = visitor.InputVisitor()
        input_vars = set()
        for node in sorted(nodes, key=by_id):
            input_vars.update(collect_inputs.visit(node))

        collect_outputs = visitor.OutputVisitor()
        output_vars = set()
        for node in sorted(nodes, key=by_id):
            output_vars.update(collect_outputs.visit(node))

        # add input ports
        for ivar in sorted(input_vars, key=by_id):
            ivar._implement_input(m, seq, aswire)

        # schedule
        sched = scheduler.ASAPScheduler()
        sched.schedule(output_vars)

        # balance output stage depth
        deepest = 0
        for ovar in sorted(output_vars, key=by_id):
            deepest = dtypes._max(deepest, ovar.end_stage)
        self.max_stage = deepest

        output_vars = sched.balance_output(output_vars, deepest)

        # get all vars
        collect_all = visitor.AllVisitor()
        all_vars = set()
        for ovar in sorted(output_vars, key=by_id):
            all_vars.update(collect_all.visit(ovar))

        # allocate (implement signals)
        allocator.Allocator().allocate(m, seq, all_vars)

        # set default module information
        for var in sorted(all_vars, key=by_id):
            var._set_module(m)
            var._set_df(self)

            own_seq = var.seq
            if own_seq is not None:
                seq.update(own_seq)

            var._set_seq(seq)

        # add output ports
        for ovar in sorted(output_vars, key=by_id):
            ovar._implement_output(m, seq, aswire)

        # save schedule result
        self.last_input = input_vars
        self.last_output = output_vars

        return m
Example #29
0
    def implement(self, m=None, clock=None, reset=None, aswire=None, seq_name=None):
        """Implement the actual registers and operations in Verilog.

        Arguments left as ``None`` fall back to the matching attribute of
        this object (``module``, ``clock``, ``reset``, ``aswire``).  A fresh
        ``Seq`` is created when ``self.seq`` is missing or already done;
        ``seq_name`` names it (default ``'_dataflow_seq_<object_id>'``).

        Returns the module that was implemented into.
        """

        def by_id(obj):
            return obj.object_id

        m = self.module if m is None else m
        clock = self.clock if clock is None else clock
        reset = self.reset if reset is None else reset

        if self.seq is not None and not self.seq.done:
            seq = self.seq
        else:
            if seq_name is None:
                seq_name = '_dataflow_seq_%d' % self.object_id
            seq = Seq(m, seq_name, clock, reset)

        aswire = self.aswire if aswire is None else aswire

        # for mult and div
        m._clock = clock
        m._reset = reset

        nodes = self.nodes

        # collect input and output variables reachable from the nodes
        collect_inputs = visitor.InputVisitor()
        input_vars = set()
        for node in sorted(nodes, key=by_id):
            input_vars.update(collect_inputs.visit(node))

        collect_outputs = visitor.OutputVisitor()
        output_vars = set()
        for node in sorted(nodes, key=by_id):
            output_vars.update(collect_outputs.visit(node))

        # add input ports
        for ivar in sorted(input_vars, key=by_id):
            ivar._implement_input(m, seq, aswire)

        # schedule
        sched = scheduler.ASAPScheduler()
        sched.schedule(output_vars)

        # balance output stage depth
        deepest = 0
        for ovar in sorted(output_vars, key=by_id):
            deepest = dtypes._max(deepest, ovar.end_stage)
        self.max_stage = deepest

        output_vars = sched.balance_output(output_vars, deepest)

        # get all vars
        collect_all = visitor.AllVisitor()
        all_vars = set()
        for ovar in sorted(output_vars, key=by_id):
            all_vars.update(collect_all.visit(ovar))

        # allocate (implement signals)
        allocator.Allocator().allocate(m, seq, all_vars)

        # set default module information
        for var in sorted(all_vars, key=by_id):
            var._set_module(m)
            var._set_df(self)

            own_seq = var.seq
            if own_seq is not None:
                seq.update(own_seq)

            var._set_seq(seq)

        # add output ports
        for ovar in sorted(output_vars, key=by_id):
            ovar._implement_output(m, seq, aswire)

        # save schedule result
        self.last_input = input_vars
        self.last_output = output_vars

        return m
Example #30
0
    def implement(self, m=None, clock=None, reset=None, aswire=None, seq_name=None):
        """Implement the actual registers and operations in Verilog.

        Runs once per object.  A repeated call returns ``self.module`` when
        ``m`` is ``None`` and raises ``ValueError`` otherwise.  Arguments
        left as ``None`` fall back to the matching attribute of this object;
        a ``Seq`` named ``seq_name`` (default ``'_stream_seq_<object_id>'``)
        is created when ``self.seq`` is ``None``.

        Returns the module that was implemented into.
        """

        def by_id(obj):
            return obj.object_id

        if self.implemented:
            if m is not None:
                raise ValueError('already implemented.')
            return self.module

        self.implemented = True

        if m is None:
            m = self.module

        if self.module is None:
            self.module = m

        clock = self.clock if clock is None else clock
        reset = self.reset if reset is None else reset

        if self.seq is not None:
            seq = self.seq
        else:
            if seq_name is None:
                seq_name = '_stream_seq_%d' % self.object_id
            seq = Seq(m, seq_name, clock, reset)

        aswire = self.aswire if aswire is None else aswire

        self.add_control(aswire=aswire)
        self.has_control = True

        # for mult and div
        m._clock = clock
        m._reset = reset

        nodes = self.nodes

        # collect input and output variables reachable from the nodes
        collect_inputs = visitor.InputVisitor()
        input_vars = set()
        for node in sorted(nodes, key=by_id):
            input_vars.update(collect_inputs.visit(node))

        collect_outputs = visitor.OutputVisitor()
        output_vars = set()
        for node in sorted(nodes, key=by_id):
            output_vars.update(collect_outputs.visit(node))

        # add input ports
        for ivar in sorted(input_vars, key=by_id):
            ivar._implement_input(m, seq, aswire)

        # schedule
        sched = scheduler.ASAPScheduler()
        sched.schedule(output_vars)

        # balance output stage depth
        deepest = 0
        for ovar in sorted(output_vars, key=by_id):
            deepest = stypes._max(deepest, ovar.end_stage)
        self.max_stage = deepest

        output_vars = sched.balance_output(output_vars, deepest)

        # get all vars
        collect_all = visitor.AllVisitor()
        all_vars = set()
        for ovar in sorted(output_vars, key=by_id):
            all_vars.update(collect_all.visit(ovar))

        # control (valid and ready)
        if not self.has_control:
            self.add_control(aswire)

        self.implement_control(seq)

        # allocate (implement signals)
        allocator.Allocator().allocate(
            m, seq, all_vars, self.valid_list, self.senable)

        # set default module information
        for var in sorted(all_vars, key=by_id):
            var._set_module(m)
            var._set_strm(self)

            own_seq = var.seq
            if own_seq is not None:
                seq.update(own_seq)

            var._set_seq(seq)

        # add output ports
        for ovar in sorted(output_vars, key=by_id):
            ovar._implement_output(m, seq, aswire)

        # save schedule result
        self.last_input = input_vars
        self.last_output = output_vars

        if self.dump:
            self.add_dump(m, seq, input_vars, output_vars, all_vars)

        return m
Example #31
0
class Stream(object):

    def __init__(self, *nodes, **opts):
        """Create a stream from ``nodes`` and keyword options.

        Recognised options: ``module``, ``clock``, ``reset``, ``ivalid``,
        ``iready``, ``ovalid``, ``oready`` (default ``None``); ``aswire``
        (default ``True``); ``dump`` (default ``False``); ``dump_base``
        (default 10); ``dump_mode`` (default ``'all'``); ``no_hook`` and
        ``seq_name``, which are only consulted when module, clock and reset
        are all given.
        """
        # ID for manager reuse and merge
        global _stream_counter
        self.object_id = _stream_counter
        _stream_counter += 1

        self.nodes = set()
        self.named_numerics = OrderedDict()

        self.add(*nodes)

        self.max_stage = 0
        self.last_input = None
        self.last_output = None

        # options that default to None
        for key in ('module', 'clock', 'reset',
                    'ivalid', 'iready', 'ovalid', 'oready'):
            setattr(self, key, opts.get(key))

        # options with explicit defaults
        for key, default in (('aswire', True),
                             ('dump', False),
                             ('dump_base', 10),
                             ('dump_mode', 'all')):
            setattr(self, key, opts.get(key, default))

        self.seq = None
        self.has_control = False

        self.implemented = False

        fully_bound = (self.module is not None and
                       self.clock is not None and self.reset is not None)
        if fully_bound:
            if not opts.get('no_hook', False):
                self.module.add_hook(self.implement)

            if 'seq_name' in opts:
                seq_name = opts['seq_name']
            else:
                seq_name = '_stream_seq_%d' % self.object_id
            self.seq = Seq(self.module, seq_name, self.clock, self.reset)

        if self.dump:
            oid = self.object_id
            enable_reg = self.module.Reg(
                '_stream_dump_enable_%d' % oid, initval=0)
            mask_reg = self.module.Reg(
                '_stream_dump_mask_%d' % oid, initval=0)
            step_reg = self.module.Reg(
                '_stream_dump_step_%d' % oid, 32, initval=0)

            self.dump_enable = enable_reg
            self.dump_mask = mask_reg
            self.dump_step = step_reg

            if self.seq:
                self.seq.add_reset(self.dump_enable)
                self.seq.add_reset(self.dump_mask)

    # -------------------------------------------------------------------------
    def add(self, *nodes):
        """Add ``nodes`` to the stream and index the named ones.

        A node with an ``input_data`` attribute, or otherwise with a
        non-``None`` ``output_data`` attribute, is stored in
        ``named_numerics`` under that attribute's value if it is a string,
        or under its ``.name`` otherwise.
        """
        self.nodes.update(set(nodes))

        for node in nodes:
            if hasattr(node, 'input_data'):
                port = node.input_data
            elif hasattr(node, 'output_data'):
                port = node.output_data
                if port is None:
                    continue
            else:
                continue

            key = port if isinstance(port, str) else port.name
            self.named_numerics[key] = node

    # -------------------------------------------------------------------------
    def to_module(self, name, clock='CLK', reset='RST', aswire=False, seq_name=None):
        """Generate a Module definition.

        Creates ``Module(name)`` with input ports named ``clock`` and
        ``reset`` and returns the result of ``implement`` on that module.
        """
        mod = Module(name)
        clk_port = mod.Input(clock)
        rst_port = mod.Input(reset)

        return self.implement(mod, clk_port, rst_port,
                              aswire=aswire, seq_name=seq_name)

    # -------------------------------------------------------------------------
    def implement(self, m=None, clock=None, reset=None, aswire=None, seq_name=None):
        """Implement the actual registers and operations in Verilog.

        Runs once per object.  A repeated call returns ``self.module`` when
        ``m`` is ``None`` and raises ``ValueError`` otherwise.  Arguments
        left as ``None`` fall back to the matching attribute of this object;
        a ``Seq`` named ``seq_name`` (default ``'_stream_seq_<object_id>'``)
        is created when ``self.seq`` is ``None``.

        Returns the module that was implemented into.
        """

        def by_id(obj):
            return obj.object_id

        if self.implemented:
            if m is not None:
                raise ValueError('already implemented.')
            return self.module

        self.implemented = True

        if m is None:
            m = self.module

        if self.module is None:
            self.module = m

        clock = self.clock if clock is None else clock
        reset = self.reset if reset is None else reset

        if self.seq is not None:
            seq = self.seq
        else:
            if seq_name is None:
                seq_name = '_stream_seq_%d' % self.object_id
            seq = Seq(m, seq_name, clock, reset)

        aswire = self.aswire if aswire is None else aswire

        self.add_control(aswire=aswire)
        self.has_control = True

        # for mult and div
        m._clock = clock
        m._reset = reset

        nodes = self.nodes

        # collect input and output variables reachable from the nodes
        collect_inputs = visitor.InputVisitor()
        input_vars = set()
        for node in sorted(nodes, key=by_id):
            input_vars.update(collect_inputs.visit(node))

        collect_outputs = visitor.OutputVisitor()
        output_vars = set()
        for node in sorted(nodes, key=by_id):
            output_vars.update(collect_outputs.visit(node))

        # add input ports
        for ivar in sorted(input_vars, key=by_id):
            ivar._implement_input(m, seq, aswire)

        # schedule
        sched = scheduler.ASAPScheduler()
        sched.schedule(output_vars)

        # balance output stage depth
        deepest = 0
        for ovar in sorted(output_vars, key=by_id):
            deepest = stypes._max(deepest, ovar.end_stage)
        self.max_stage = deepest

        output_vars = sched.balance_output(output_vars, deepest)

        # get all vars
        collect_all = visitor.AllVisitor()
        all_vars = set()
        for ovar in sorted(output_vars, key=by_id):
            all_vars.update(collect_all.visit(ovar))

        # control (valid and ready)
        if not self.has_control:
            self.add_control(aswire)

        self.implement_control(seq)

        # allocate (implement signals)
        allocator.Allocator().allocate(
            m, seq, all_vars, self.valid_list, self.senable)

        # set default module information
        for var in sorted(all_vars, key=by_id):
            var._set_module(m)
            var._set_strm(self)

            own_seq = var.seq
            if own_seq is not None:
                seq.update(own_seq)

            var._set_seq(seq)

        # add output ports
        for ovar in sorted(output_vars, key=by_id):
            ovar._implement_output(m, seq, aswire)

        # save schedule result
        self.last_input = input_vars
        self.last_output = output_vars

        if self.dump:
            self.add_dump(m, seq, input_vars, output_vars, all_vars)

        return m

    def add_dump(self, m, seq, input_vars, output_vars, all_vars):
        pipeline_depth = self.pipeline_depth()
        log_pipeline_depth = max(
            int(math.ceil(math.log(max(pipeline_depth, 10), 10))), 1)

        seq(
            self.dump_step(1)
        )

        for i in range(pipeline_depth + 1):
            seq.If(seq.Prev(self.dump_enable, i))(
                self.dump_step.inc()
            )

        def get_name(obj):
            if hasattr(obj, 'name'):
                return obj.name
            if isinstance(obj, vtypes._Constant):
                return obj.__class__.__name__
            raise TypeError()

        longest_name_len = 0
        for input_var in sorted(input_vars, key=lambda x: x.object_id):
            if not (self.dump_mode == 'all' or
                    self.dump_mode == 'stream' or
                    self.dump_mode == 'input' or
                    self.dump_mode == 'inout' or
                    (self.dump_mode == 'selective' and
                        hasattr(input_var, 'dump') and input_var.dump)):
                continue

            name = get_name(input_var.sig_data)
            length = len(name) + 6
            longest_name_len = max(longest_name_len, length)

        for var in sorted(all_vars, key=lambda x: (-1, x.object_id)
                          if x.end_stage is None else
                          (x.end_stage, x.object_id)):
            if not (self.dump_mode == 'all' or
                    self.dump_mode == 'stream' or
                    (self.dump_mode == 'selective' and
                        hasattr(var, 'dump') and var.dump)):
                continue

            name = get_name(var.sig_data)
            length = len(name) + 6
            longest_name_len = max(longest_name_len, length)

        for output_var in sorted(output_vars, key=lambda x: x.object_id):
            if not (self.dump_mode == 'all' or
                    self.dump_mode == 'stream' or
                    self.dump_mode == 'output' or
                    self.dump_mode == 'inout' or
                    (self.dump_mode == 'selective' and
                        hasattr(output_var, 'dump') and output_var.dump)):
                continue

            name = get_name(output_var.output_sig_data)
            length = len(name) + 6
            longest_name_len = max(longest_name_len, length)

        longest_var_len = 0
        for var in sorted(all_vars, key=lambda x: (-1, x.object_id)
                          if x.start_stage is None else
                          (x.start_stage, x.object_id)):
            bitwidth = vtypes.get_width(var.sig_data)
            if bitwidth is None:
                bitwidth = 1
            if bitwidth <= 0:
                bitwidth = 1

            base = (var.dump_base if hasattr(var, 'dump_base') else
                    self.dump_base)
            total_length = int(math.ceil(bitwidth / math.log(base, 2)))
            #point_length = int(math.ceil(var.point / math.log(base, 2)))
            #point_length = max(point_length, 8)
            #longest_var_len = max(longest_var_len, total_length, point_length)
            longest_var_len = max(longest_var_len, total_length)

        for input_var in sorted(input_vars, key=lambda x: x.object_id):

            base = (input_var.dump_base if hasattr(input_var, 'dump_base') else
                    self.dump_base)
            base_char = ('b' if base == 2 else
                         'o' if base == 8 else
                         'd' if base == 10 and input_var.point <= 0 else
                         # 'f' if base == 10 and input_var.point > 0 else
                         'g' if base == 10 and input_var.point > 0 else
                         'x')
            prefix = ('0b' if base == 2 else
                      '0o' if base == 8 else
                      '  ' if base == 10 else
                      '0x')

            # if base_char == 'f':
            #    point_length = int(math.ceil(input_var.point / math.log(base, 2)))
            #    point_length = max(point_length, 8)
            #    fmt_list = [prefix, '%',
            #                '%d.%d' % (longest_var_len + 1, point_length), base_char]
            # if base_char == 'g':
            #    fmt_list = [prefix, '%', base_char]
            # else:
            #    fmt_list = [prefix, '%', '%d' % (longest_var_len + 1), base_char]
            fmt_list = [prefix, '%', '%d' % (longest_var_len + 1), base_char]

            if input_var not in all_vars:
                fmt_list.append(' (unused)')

            input_var.dump_fmt = ''.join(fmt_list)

        for output_var in sorted(output_vars, key=lambda x: x.object_id):

            base = (output_var.dump_base if hasattr(output_var, 'dump_base') else
                    self.dump_base)
            base_char = ('b' if base == 2 else
                         'o' if base == 8 else
                         'd' if base == 10 and output_var.point <= 0 else
                         # 'f' if base == 10 and output_var.point > 0 else
                         'g' if base == 10 and output_var.point > 0 else
                         'x')
            prefix = ('0b' if base == 2 else
                      '0o' if base == 8 else
                      '  ' if base == 10 else
                      '0x')

            # if base_char == 'f':
            #    point_length = int(math.ceil(output_var.point / math.log(base, 2)))
            #    point_length = max(point_length, 8)
            #    fmt_list = [prefix, '%',
            #                '%d.%d' % (longest_var_len + 1, point_length), base_char]
            # if base_char == 'g':
            #    fmt_list = [prefix, '%', base_char]
            # else:
            #    fmt_list = [prefix, '%', '%d' % (longest_var_len + 1), base_char]
            fmt_list = [prefix, '%', '%d' % (longest_var_len + 1), base_char]

            if output_var not in all_vars:
                fmt_list.append(' (unused)')

            output_var.dump_fmt = ''.join(fmt_list)

        for var in sorted(all_vars, key=lambda x: (-1, x.object_id)
                          if x.start_stage is None else
                          (x.start_stage, x.object_id)):

            base = (var.dump_base if hasattr(var, 'dump_base') else
                    self.dump_base)
            base_char = ('b' if base == 2 else
                         'o' if base == 8 else
                         'd' if base == 10 and var.point <= 0 else
                         # 'f' if base == 10 and var.point > 0 else
                         'g' if base == 10 and var.point > 0 else
                         'x')
            prefix = ('0b' if base == 2 else
                      '0o' if base == 8 else
                      '  ' if base == 10 else
                      '0x')

            # if base_char == 'f':
            #    point_length = int(math.ceil(var.point / math.log(base, 2)))
            #    point_length = max(point_length, 8)
            #    fmt_list = [prefix, '%',
            #                '%d.%d' % (longest_var_len + 1, point_length), base_char]
            # if base_char == 'g':
            #    fmt_list = [prefix, '%', base_char]
            # else:
            #    fmt_list = [prefix, '%', '%d' % (longest_var_len + 1), base_char]
            fmt_list = [prefix, '%', '%d' % (longest_var_len + 1), base_char]

            var.dump_fmt = ''.join(fmt_list)

        enables = []
        for input_var in sorted(input_vars, key=lambda x: x.object_id):
            if not (self.dump_mode == 'all' or
                    self.dump_mode == 'stream' or
                    self.dump_mode == 'input' or
                    self.dump_mode == 'inout' or
                    (self.dump_mode == 'selective' and
                        hasattr(input_var, 'dump') and input_var.dump)):
                continue

            vfmt = input_var.dump_fmt

            name = get_name(input_var.sig_data)
            name_alignment = ' ' * (longest_name_len - len(name) -
                                    len('(in) '))
            fmt = ''.join(['<', self.name, ' step:%d, ',
                           'stage:%', str(
                               log_pipeline_depth), 'd, age:%d> (in) ',
                           name_alignment, name, ' = ', vfmt])

            stage = input_var.end_stage if input_var.end_stage is not None else 0
            enable = seq.Prev(self.dump_enable, stage)
            enables.append(enable)
            age = seq.Prev(self.dump_step, stage) - 1

            if input_var.point > 0:
                sig_data = vtypes.Div(vtypes.SystemTask('itor', input_var.sig_data),
                                      1.0 * (2 ** input_var.point))
            elif input_var.point < 0:
                sig_data = vtypes.Times(input_var.sig_data, 2 ** -input_var.point)
            else:
                sig_data = input_var.sig_data

            seq.If(enable, vtypes.Not(self.dump_mask))(
                vtypes.Display(fmt, self.dump_step, stage, age, sig_data)
            )

        for var in sorted(all_vars, key=lambda x: (-1, x.object_id)
                          if x.end_stage is None else
                          (x.end_stage, x.object_id)):
            if not (self.dump_mode == 'all' or
                    self.dump_mode == 'stream' or
                    (self.dump_mode == 'selective' and
                        hasattr(var, 'dump') and var.dump)):
                continue

            vfmt = var.dump_fmt

            name = get_name(var.sig_data)
            name_alignment = ' ' * (longest_name_len - len(name))
            stage = var.end_stage if var.end_stage is not None else 0

            fmt = ''.join(['<', self.name, ' step:%d, ',
                           'stage:%', str(log_pipeline_depth), 'd, age:%d> ',
                           name_alignment, name, ' = ', vfmt])

            enable = seq.Prev(self.dump_enable, stage)
            enables.append(enable)
            age = seq.Prev(self.dump_step, stage) - 1

            if var.point > 0:
                sig_data = vtypes.Div(vtypes.SystemTask('itor', var.sig_data),
                                      1.0 * (2 ** var.point))
            elif var.point < 0:
                sig_data = vtypes.Times(var.sig_data, 2 ** -var.point)
            else:
                sig_data = var.sig_data

            seq.If(enable, vtypes.Not(self.dump_mask))(
                vtypes.Display(fmt, self.dump_step, stage, age, sig_data)
            )

        for output_var in sorted(output_vars, key=lambda x: x.object_id):
            if not (self.dump_mode == 'all' or
                    self.dump_mode == 'stream' or
                    self.dump_mode == 'output' or
                    self.dump_mode == 'inout' or
                    (self.dump_mode == 'selective' and
                        hasattr(output_var, 'dump') and output_var.dump)):
                continue

            vfmt = output_var.dump_fmt

            name = get_name(output_var.output_sig_data)
            name_alignment = ' ' * (longest_name_len - len(name) -
                                    len('(out) '))
            fmt = ''.join(['<', self.name, ' step:%d, ',
                           'stage:%', str(
                               log_pipeline_depth), 'd, age:%d> (out) ',
                           name_alignment, name, ' = ', vfmt])

            stage = output_var.end_stage if output_var.end_stage is not None else 0
            enable = seq.Prev(self.dump_enable, stage)
            enables.append(enable)
            age = seq.Prev(self.dump_step, stage) - 1

            if output_var.point > 0:
                sig_data = vtypes.Div(vtypes.SystemTask('itor', output_var.output_sig_data),
                                      1.0 * (2 ** output_var.point))
            elif output_var.point < 0:
                sig_data = vtypes.Times(output_var.output_sig_data, 2 ** -output_var.point)
            else:
                sig_data = output_var.output_sig_data

            seq.If(enable, vtypes.Not(self.dump_mask))(
                vtypes.Display(fmt, self.dump_step, stage, age, sig_data)
            )

    # -------------------------------------------------------------------------
    def add_control(self, aswire=True):
        """Turn handshake signals that are still given as names into signals.

        Each of `ivalid`, `iready`, `ovalid` and `oready` that currently holds
        a `str` is replaced by an object created on `self.module`: a Wire when
        `aswire` is true, otherwise an Input (`ivalid`, `oready`) or an Output
        (`iready`, `ovalid`).  Attributes holding anything else (including
        None) are left untouched.
        """
        # (attribute name, module constructor used when aswire is false),
        # processed in this order.
        port_kinds = (('ivalid', 'Input'),
                      ('iready', 'Output'),
                      ('ovalid', 'Output'),
                      ('oready', 'Input'))

        for attr, kind in port_kinds:
            signal = getattr(self, attr)
            if not isinstance(signal, str):
                continue
            factory = getattr(self.module, 'Wire' if aswire else kind)
            setattr(self, attr, factory(signal))

    def implement_control(self, seq):
        """Wire up the valid/ready handshake for this object.

        `self.valid_list` is reset to None and `self.senable` is set to
        `self.oready` (which may be None).

        * Neither `ivalid` nor `oready` given: `ovalid` and `iready`, when
          present, are assigned the constant 1 and nothing else happens.
        * Otherwise, if `ivalid` is given, the valid chain is built through
          `_make_valid_chain(seq, self.senable)`.
        * In every case but the first, `iready` (when present) is assigned
          `self.senable`.
        """
        self.valid_list = None
        # In all three cases of the handshake the enable equals oready.
        self.senable = self.oready

        if self.ivalid is None and self.oready is None:
            for always_on in (self.ovalid, self.iready):
                if always_on is not None:
                    always_on.assign(1)
            return

        if self.ivalid is not None:
            self._make_valid_chain(seq, self.senable)

        if self.iready is not None:
            self.iready.assign(self.senable)

    def _make_valid_chain(self, seq, cond=None):
        """Create one valid register per stage, chained from `self.ivalid`.

        `self.valid_list` becomes `[ivalid, reg_1, ..., reg_max_stage]`.  Each
        register is named `_<ivalid name>_<stage>`, created with initval 0 and
        loaded from its predecessor through `seq(..., cond=cond)`.  When
        `self.ovalid` is set it is assigned the last element of the chain.
        """
        head = self.ivalid
        self.valid_list = [head]

        base_name = head.name
        tail = head

        for stage in range(1, self.max_stage + 1):
            reg = self.module.Reg("_{}_{}".format(base_name, stage), initval=0)
            self.valid_list.append(reg)
            seq(reg(tail), cond=cond)
            tail = reg

        if self.ovalid is not None:
            self.ovalid.assign(tail)

    # -------------------------------------------------------------------------
    def draw_graph(self, filename='out.png', prog='dot', rankdir='LR', approx=False):
        """Draw `self.last_output` via `graph.draw_graph`.

        If no output has been recorded yet (`self.last_output` is None),
        `self.to_module()` is called first.  All keyword options are passed
        through to `graph.draw_graph` unchanged.
        """
        if self.last_output is None:
            self.to_module()

        options = dict(filename=filename, prog=prog,
                       rankdir=rankdir, approx=approx)
        graph.draw_graph(self.last_output, **options)

    def enable_draw_graph(self, filename='out.png', prog='dot', rankdir='LR', approx=False):
        """Register `self.draw_graph` as a hook on `self.module`.

        The given options are stored as the keyword arguments the hook will
        be called with.
        """
        options = {'filename': filename, 'prog': prog,
                   'rankdir': rankdir, 'approx': approx}
        self.module.add_hook(self.draw_graph, kwargs=options)

    # -------------------------------------------------------------------------
    def get_input(self):
        """Return the recorded input variables as an OrderedDict.

        Keys are `str(var.input_data)`, values are the variables themselves,
        inserted in ascending `object_id` order.  An empty OrderedDict is
        returned when `self.last_input` is None.
        """
        inputs = self.last_input
        if inputs is None:
            return OrderedDict()

        ordered = sorted(inputs, key=lambda var: var.object_id)
        return OrderedDict((str(var.input_data), var) for var in ordered)

    def get_output(self):
        """Return the recorded output variables as an OrderedDict.

        Keys are `str(var.output_data)`, values are the variables themselves,
        inserted in ascending `object_id` order.  An empty OrderedDict is
        returned when `self.last_output` is None.
        """
        outputs = self.last_output
        if outputs is None:
            return OrderedDict()

        ordered = sorted(outputs, key=lambda var: var.object_id)
        return OrderedDict((str(var.output_data), var) for var in ordered)

    # -------------------------------------------------------------------------
    def pipeline_depth(self):
        """Return the pipeline depth, which is the value of `self.max_stage`."""
        return self.max_stage

    # -------------------------------------------------------------------------
    def __getattr__(self, attr):
        """Fall back to the functions of `stypes` for unknown attributes.

        Regular attribute lookup is tried first.  If it fails and `attr` is a
        non-dunder name listed in `dir(stypes)`, a wrapper around
        `stypes.<attr>` is returned; the wrapper calls the function and passes
        its result (or each item of a tuple/list result) to `_set_info`
        before returning it.  Otherwise the AttributeError is re-raised.
        """
        try:
            return object.__getattribute__(self, attr)

        except AttributeError as e:
            provided_by_stypes = (not attr.startswith('__') and
                                  attr in dir(stypes))
            if not provided_by_stypes:
                raise e

            target = getattr(stypes, attr)

            @functools.wraps(target)
            def wrapper(*args, **kwargs):
                result = target(*args, **kwargs)
                produced = result if isinstance(result, (tuple, list)) else (result,)
                for obj in produced:
                    self._set_info(obj)
                return result

            return wrapper

    def _set_info(self, v):
        """Attach this object's module and seq to `v` and register it.

        Only instances of `stypes._Numeric` are handled; anything else is
        ignored.  The value is finally passed to `self.add`.
        """
        if not isinstance(v, stypes._Numeric):
            return

        v._set_module(self.module)
        v._set_strm(self)
        v._set_seq(self.seq)

        self.add(v)

    def get_named_numeric(self, name):
        """Look up `name` in `self.named_numerics`.

        Raises:
            NameError: if no entry with that name exists.
        """
        numerics = self.named_numerics
        if name in numerics:
            return numerics[name]

        raise NameError("Numeric '%s' is not defined." % name)
Example #32
0
class FSM(vtypes.VeriloggenNode):
    """ Finite State Machine Generator """

    def __init__(self, m, name, clk, rst, width=32, initname='init',
                 nohook=False, as_module=False):
        """Create the state register and the bookkeeping of the FSM.

        Args:
            m: module object; used for `m.Reg` and `m.add_hook` here.
            name: name of the state register; also used as prefix for the
                initial state label and for the internal Seq (`name + '_par'`).
            clk: clock signal, stored and passed to the internal Seq.
            rst: reset signal, stored and passed to the internal Seq.
            width: width passed to `m.Reg` for the state register.
            initname: suffix of the label registered for state index 0.
            nohook: when false, `self.implement` is registered as a hook on `m`.
            as_module: stored; `implement` uses it to choose between
                `make_module` and `make_always`.
        """
        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.width = width
        self.state_count = 0
        self.state = self.m.Reg(name, width)  # set initval later

        # label of state 0 is '<name>_<initname>'; it becomes the initial
        # value of the state register
        self.mark = collections.OrderedDict()  # key:index
        self._set_mark(0, self.name + '_' + initname)
        self.state.initval = self._get_mark(0)

        # per-state statements and per-state jump descriptions
        self.body = collections.defaultdict(list)
        self.jump = collections.defaultdict(list)

        # bookkeeping for statements issued with a delay
        self.delay_amount = 0
        self.delayed_state = collections.OrderedDict()  # key:delay
        self.delayed_body = collections.defaultdict(
            functools.partial(collections.defaultdict, list))  # key:delay
        self.delayed_cond = collections.OrderedDict()  # key:name
        self.tmp_count = 0

        self.dst_var = collections.OrderedDict()
        self.dst_visitor = SubstDstVisitor()
        self.reset_visitor = ResetVisitor()

        # internal Seq; created with nohook=True, its code is emitted through
        # make_case()/make_if() of this FSM instead
        self.seq = Seq(self.m, self.name + '_par', clk, rst, nohook=True)

        # set once make_always()/make_module() has run
        self.done = False

        # state of the If/Elif/Else/Delay/... chaining interface
        self.last_cond = []
        self.last_kwargs = {}
        self.last_if_statement = None
        self.elif_cond = None
        self.next_kwargs = {}

        self.as_module = as_module

        if not nohook:
            self.m.add_hook(self.implement)

    # -------------------------------------------------------------------------
    def goto(self, dst, cond=None, else_dst=None):
        """Add a jump from the current state to `dst`.

        When `cond` is None, a pending condition stored in
        `self.next_kwargs['cond']` is used instead.  The pending keyword
        arguments, the last if-statement and the last condition are cleared.
        Returns the result of `_go`.
        """
        if cond is None:
            cond = self.next_kwargs.get('cond')

        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()

        return self._go(self.current, dst, cond, else_dst)

    def goto_init(self, cond=None):
        """Add a jump from the current state to state index 0.

        When `cond` is None, a pending condition stored in
        `self.next_kwargs['cond']` is used instead.  The pending keyword
        arguments, the last if-statement and the last condition are cleared.
        Returns the result of `_go`.
        """
        if cond is None:
            cond = self.next_kwargs.get('cond')

        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()

        return self._go(self.current, 0, cond)

    def goto_next(self, cond=None):
        """Add a jump from the current state to `current + 1`, then call `inc`.

        When `cond` is None, a pending condition stored in
        `self.next_kwargs['cond']` is used instead.  The pending keyword
        arguments, the last if-statement and the last condition are cleared.
        Returns the result of `_go`.
        """
        if cond is None:
            cond = self.next_kwargs.get('cond')

        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()

        here = self.current
        result = self._go(here, here + 1, cond=cond)
        self.inc()
        return result

    def goto_from(self, src, dst, cond=None, else_dst=None):
        """Add a jump from the explicitly given state `src` to `dst`.

        When `cond` is None, a pending condition stored in
        `self.next_kwargs['cond']` is used instead.  The pending keyword
        arguments, the last if-statement and the last condition are cleared.
        Returns the result of `_go`.
        """
        if cond is None:
            cond = self.next_kwargs.get('cond')

        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()

        return self._go(src, dst, cond, else_dst)

    def inc(self):
        """Call `_set_index(None)`.

        NOTE(review): presumably this advances the FSM to a fresh state
        index; `_set_index` is not visible here -- confirm.
        """
        self._set_index(None)

    # -------------------------------------------------------------------------
    def add(self, *statement, **kwargs):
        """ add new assignments """
        kwargs.update(self.next_kwargs)
        self.last_kwargs = kwargs
        self._clear_next_kwargs()

        # A pending Elif can be chained onto the previous If object only when
        # no attribute other than 'cond' is in effect.
        cond_only = set(kwargs) <= {'cond'}

        if self.elif_cond is None or not cond_only:
            self._clear_last_if_statement()
            return self._add_statement(statement, **kwargs)

        branch = self.last_if_statement.Elif(self.elif_cond)
        branch(*statement)
        self.last_if_statement = branch
        self._add_dst_var(statement)
        self._clear_elif_cond()
        return self

    # -------------------------------------------------------------------------
    def add_reset(self, v):
        """Delegate to `add_reset` of the internal Seq and return its result."""
        return self.seq.add_reset(v)

    # -------------------------------------------------------------------------
    def Prev(self, var, delay, initval=0, cond=None, prefix=None):
        """Delegate to `Prev` of the internal Seq, passing all arguments on."""
        return self.seq.Prev(var, delay, initval, cond, prefix)

    # -------------------------------------------------------------------------
    def If(self, *cond):
        """Set (or extend) the condition applied to the next statement.

        The given conditions are combined by `make_condition`.  If that yields
        None nothing is recorded.  Otherwise the result becomes the pending
        'cond' in `self.next_kwargs` -- AND-ed via `vtypes.Ands` with an
        already pending condition -- and `self.last_cond` is reset to a list
        holding just that condition.  A pending Elif condition is discarded
        in any case.  Returns self.
        """
        self._clear_elif_cond()

        cond = make_condition(*cond)
        if cond is None:
            return self

        pending = self.next_kwargs
        if 'cond' in pending:
            cond = vtypes.Ands(pending['cond'], cond)
        pending['cond'] = cond

        self.last_cond = [cond]

        return self

    def Else(self, *statement, **kwargs):
        """Add `statement` as the else-branch of the previous condition.

        The last entry of `self.last_cond` is replaced by its negation.  The
        statements are then either attached to the previous `vtypes.If`
        object, or -- when the previous statement had a positive delay or the
        `has_args` test below is true -- added as a separate statement guarded
        by the negated condition.

        Raises:
            ValueError: if there is no previous condition, or if the previous
                statement is not a `vtypes.If` when attaching directly.
        """
        self._clear_elif_cond()

        if len(self.last_cond) == 0:
            raise ValueError("No previous condition for Else.")

        old = self.last_cond.pop()
        self.last_cond.append(vtypes.Not(old))

        # if the true-statement has delay attributes,
        # Else statement is separated.
        if 'delay' in self.last_kwargs and self.last_kwargs['delay'] > 0:
            prev_cond = self.last_cond
            ret = self.Then()(*statement)
            self.last_cond = prev_cond
            return ret

        # if there is additional attribute, Else statement is separated.
        # NOTE(review): the second test looks at the `kwargs` parameter, which
        # is not used anywhere else in this method, while the length test and
        # the trailing comment refer to self.next_kwargs.  Possibly
        # `'cond' in self.next_kwargs` was intended -- confirm before changing,
        # since it alters which path a pending 'cond' takes.
        has_args = not (len(self.next_kwargs) == 0 or  # has no args
                        (len(self.next_kwargs) == 1 and 'cond' in kwargs))  # has only 'cond'

        if has_args:
            prev_cond = self.last_cond
            ret = self.Then()(*statement)
            self.last_cond = prev_cond
            return ret

        if not isinstance(self.last_if_statement, vtypes.If):
            raise ValueError("Last if-statement is not If")

        self.last_if_statement.Else(*statement)
        self._add_dst_var(statement)

        return self

    def Elif(self, *cond):
        """Start an else-if branch for the previous condition.

        The last entry of `self.last_cond` is replaced by its negation and the
        new condition (built by `make_condition`) is appended.  If the
        previous statement had a positive delay, the branch is set up as a
        separate condition through `Then()`.  Otherwise the new condition is
        remembered in `self.elif_cond` and the combined condition is stored
        as the pending 'cond' in `self.next_kwargs`.

        Returns:
            self (or the result of `Then()` in the delayed case).

        Raises:
            ValueError: if there is no previous condition, or if the previous
                statement is not a `vtypes.If` in the non-delayed case.
        """
        if len(self.last_cond) == 0:
            # The message names Elif (it previously said "Else", which pointed
            # users at the wrong call).
            raise ValueError("No previous condition for Elif.")

        cond = make_condition(*cond)

        old = self.last_cond.pop()
        self.last_cond.append(vtypes.Not(old))
        self.last_cond.append(cond)

        # if the true-statement has delay attributes, Else statement is
        # separated.
        if 'delay' in self.last_kwargs and self.last_kwargs['delay'] > 0:
            prev_cond = self.last_cond
            ret = self.Then()
            self.last_cond = prev_cond
            return ret

        if not isinstance(self.last_if_statement, vtypes.If):
            raise ValueError("Last if-statement is not If")

        self.elif_cond = cond

        cond = self._make_cond(self.last_cond)
        self.next_kwargs['cond'] = cond

        return self

    def Delay(self, delay):
        """Store `delay` as the pending 'delay' attribute; returns self."""
        self.next_kwargs.update(delay=delay)
        return self

    def Keep(self, keep):
        """Store `keep` as the pending 'keep' attribute; returns self."""
        self.next_kwargs.update(keep=keep)
        return self

    def Then(self):
        """Re-arm the accumulated condition for the next statement.

        The entries of `self.last_cond` are combined by `_make_cond`, the
        list is cleared, and the combined condition is installed through
        `If`.  Returns self.
        """
        # combine first: _clear_last_cond() runs before If() re-populates
        # self.last_cond
        cond = self._make_cond(self.last_cond)
        self._clear_last_cond()
        self.If(cond)
        return self

    def LazyCond(self, value=True):
        """Store `value` as the pending 'lazy_cond' attribute; returns self."""
        self.next_kwargs.update(lazy_cond=value)
        return self

    def EagerVal(self, value=True):
        """Store `value` as the pending 'eager_val' attribute; returns self."""
        self.next_kwargs.update(eager_val=value)
        return self

    def Clear(self):
        """Reset all chaining state and return self.

        Clears the pending keyword arguments, the last if-statement, the
        last condition and the pending Elif condition, in that order.
        """
        for reset in (self._clear_next_kwargs,
                      self._clear_last_if_statement,
                      self._clear_last_cond,
                      self._clear_elif_cond):
            reset()
        return self

    # -------------------------------------------------------------------------
    @property
    def current(self):
        """Index of the current state (the value of `self.state_count`)."""
        return self.state_count

    @property
    def next(self):
        """Index following the current state, i.e. `self.current + 1`."""
        return self.current + 1

    @property
    def current_delay(self):
        """Pending 'delay' attribute for the next statement, or 0 if unset."""
        return self.next_kwargs.get('delay', 0)

    @property
    def last_delay(self):
        """'delay' attribute of the most recently added statement, or 0."""
        return self.last_kwargs.get('delay', 0)

    @property
    def current_condition(self):
        """Condition under which the next statement would take effect.

        This is `state == state_count`, combined through `vtypes.AndList`
        with the pending 'cond' of `self.next_kwargs` when one is set (and
        not None).
        """
        extra = self.next_kwargs.get('cond')
        in_state = self.state == self.state_count
        if extra is None:
            return in_state
        return vtypes.AndList(in_state, extra)

    @property
    def last_condition(self):
        """Condition under which the last conditional statement took effect.

        This is `state == state_count`, combined through `vtypes.AndList`
        with the result of `_make_cond(self.last_cond)` when that is not None.
        """
        extra = self._make_cond(self.last_cond)
        in_state = self.state == self.state_count
        if extra is None:
            return in_state
        return vtypes.AndList(in_state, extra)

    @property
    def then(self):
        """Alias of `last_condition`."""
        return self.last_condition

    @property
    def here(self):
        """Expression comparing the state register with the current index."""
        return self.state == self.current

    # -------------------------------------------------------------------------
    def implement(self):
        """Emit the FSM: `make_module()` if `as_module` is set, else `make_always()`."""
        builder = self.make_module if self.as_module else self.make_always
        builder()

    # -------------------------------------------------------------------------
    def make_always(self, reset=(), body=(), case=True):
        """Emit the FSM as one clocked always block on `self.m`.

        Does nothing when the FSM has already been emitted (`self.done`).
        Otherwise `self.done` is set and an always block on the positive edge
        of `self.clk` is added: under `self.rst` it runs the statements from
        `make_reset(reset)`, otherwise `body` followed by the statements of
        `make_case()` (or `make_if()` when `case` is false).
        """
        if self.done:
            #raise ValueError('make_always() has been already called.')
            return

        self.done = True

        reset_stmts = self.make_reset(reset)

        body_stmts = list(body)
        body_stmts.extend(self.make_case() if case else self.make_if())

        always = self.m.Always(vtypes.Posedge(self.clk))
        guarded = vtypes.If(self.rst)(reset_stmts)(body_stmts)
        always(guarded)

    # -------------------------------------------------------------------------
    def make_module(self, reset=(), body=(), case=True):
        """Emit the FSM as a separate sub-module instantiated in `self.m`.

        Does nothing when the FSM has already been emitted (`self.done`).
        Otherwise a new module named `sub_<name>` is built that contains the
        FSM logic; sources read by the FSM become parameters/inputs of that
        module, the state registers and destination variables become output
        registers, and the module is instantiated in `self.m` through
        `Submodule`.

        Args:
            reset: extra reset statements, passed to `make_reset`.
            body: extra statements placed before the generated FSM code.
            case: use `make_case()` when true, `make_if()` otherwise.

        Raises:
            ValueError: if a destination is a Pointer, Slice or Cat.
        """
        if self.done:
            #raise ValueError('make_always() has been already called.')
            return

        self.done = True

        m = Module('sub_%s' % self.name)

        clk = m.Input('CLK')

        if self.rst is not None:
            rst = m.Input('RST')
        else:
            rst = None

        body = list(body) + list(self.make_case()
                                 if case else self.make_if())
        # note: this updates the dst_var mapping of the internal Seq in place
        dst_var = self.seq.dst_var
        dst_var.update(self.dst_var)
        dsts = dst_var.values()
        src_visitor = SubstSrcVisitor()

        # collect sources in destination variable definitions
        for dst in dsts:
            if isinstance(dst, (vtypes.Pointer, vtypes.Slice, vtypes.Cat)):
                raise ValueError(
                    'Partial assignment is not supported by as_module mode.')

            if isinstance(dst, vtypes._Variable):
                if dst.width is not None:
                    src_visitor.visit(dst.width)
                if dst.length is not None:
                    src_visitor.visit(dst.length)

        # collect sources in statements
        for statement in body:
            src_visitor.visit(statement)

        srcs = src_visitor.srcs.values()

        # collect sources in source variable definitions
        # NOTE(review): `srcs` is taken from src_visitor.srcs while visit() is
        # called inside the loop; if visit() adds entries to that mapping the
        # iteration may be affected -- confirm against SubstSrcVisitor.
        for src in srcs:
            if isinstance(src, vtypes._Variable):
                if src.width is not None:
                    src_visitor.visit(src.width)
                if src.length is not None:
                    src_visitor.visit(src.length)

        srcs = src_visitor.srcs.values()

        params = collections.OrderedDict()
        ports = collections.OrderedDict()
        src_rename_dict = collections.OrderedDict()

        fsm_orig_labels = [v.name for v in self.mark.values()]
        fsm_labels = collections.OrderedDict()

        # create parameter/localparam definitions
        for src in srcs:
            if isinstance(src, (vtypes.Parameter, vtypes.Localparam)):
                # state labels of this FSM are re-created as localparams of
                # the sub-module instead of being passed in
                if src.name in fsm_orig_labels:
                    fsm_labels[src.name] = m.Localparam(src.name, src.value)
                    continue

                arg_name = src.name
                v = m.Parameter(arg_name, src.value, src.width, src.signed)
                src_rename_dict[src.name] = v
                params[arg_name] = src

        src_rename_visitor = SrcRenameVisitor(src_rename_dict)

        # the state register lives in the sub-module as an output register;
        # the original register in self.m is driven from that output
        state_width = src_rename_visitor.visit(self.state.width)
        state_initval = src_rename_visitor.visit(self.state.initval)
        state = m.OutputReg(self.state.name, state_width,
                            initval=state_initval)
        out_state = self.m.TmpWire(state_width,
                                   prefix='_%s_out' % self.state.name)
        self.m.Always()(self.state(out_state, blk=True))
        ports[state.name] = out_state

        src_rename_dict[self.state.name] = state

        # same treatment for every delayed copy of the state register
        for delay, s in sorted(self.delayed_state.items(), key=lambda x: x[0]):
            s_width = src_rename_visitor.visit(s.width)
            s_initval = src_rename_visitor.visit(s.initval)
            d = m.OutputReg(s.name, s_width, initval=s_initval)
            out_d = self.m.TmpWire(s_width, prefix='_%s_out' % d.name)
            self.m.Always()(s(out_d, blk=True))
            ports[s.name] = out_d

        state_names = [self.state.name]
        state_names.extend([s.name for s in self.delayed_state.values()])

        # create input ports for the remaining sources
        for src in srcs:
            if isinstance(src, (vtypes.Parameter, vtypes.Localparam)):
                continue

            if src.name in state_names:
                continue

            # delayed conditions become plain registers inside the sub-module
            if src.name in list(self.delayed_cond.keys()):
                rep_width = (src_rename_visitor.visit(src.width)
                             if src.width is not None else None)
                v = m.Reg(src.name, rep_width, initval=0)
                src_rename_dict[src.name] = v
                continue

            arg_name = 'i_%s' % src.name

            if src.length is not None:
                # sources with a length are packed into one wide wire in
                # self.m and unpacked into an array inside the sub-module
                width = src.bit_length()
                length = src.length
                pack_width = vtypes.Mul(width, length)
                out_line = self.m.TmpWire(pack_width, prefix='_%s' % self.name)
                i = self.m.TmpGenvar(prefix='i')
                v = out_line[i * width:(i + 1) * width]
                g = self.m.GenerateFor(i(0), i < length, i(i + 1))
                p = g.Assign(v(src[i]))

                rep_width = (src_rename_visitor.visit(src.width)
                             if src.width is not None else None)
                rep_length = src_rename_visitor.visit(src.length)
                pack_width = (rep_length if rep_width is None else
                              vtypes.Mul(rep_length, rep_width))
                in_line = m.Input(arg_name + '_line', pack_width,
                                  signed=src.get_signed())
                in_array = m.Wire(arg_name, rep_width, rep_length,
                                  signed=src.get_signed())
                i = m.TmpGenvar(prefix='i')
                v = in_line[i * rep_width:(i + 1) * rep_width]
                g = m.GenerateFor(i(0), i < rep_length, i(i + 1))
                p = g.Assign(in_array[i](v))

                src_rename_dict[src.name] = in_array
                ports[in_line.name] = out_line

            else:
                rep_width = (src_rename_visitor.visit(src.width)
                             if src.width is not None else None)
                v = m.Input(arg_name, rep_width, signed=src.get_signed())

                src_rename_dict[src.name] = v
                ports[arg_name] = src

        # create output registers for the destinations; the original
        # variables in self.m are driven from the sub-module outputs
        for dst in dsts:
            if dst.name in list(self.delayed_cond.keys()):
                continue

            arg_name = dst.name

            rep_width = (src_rename_visitor.visit(dst.width)
                         if dst.width is not None else None)
            out = m.OutputReg(arg_name, rep_width, signed=dst.get_signed())
            out_wire = self.m.TmpWire(rep_width, signed=dst.get_signed(),
                                      prefix='_%s_%s' % (self.name, arg_name))
            self.m.Always()(dst(out_wire, blk=True))
            ports[arg_name] = out_wire

        # rewrite the statements so they refer to the sub-module's objects
        body = [src_rename_visitor.visit(statement)
                for statement in body]

        reset = self.make_reset(reset)

        if not reset and not body:
            pass
        elif not reset or rst is None:
            m.Always(vtypes.Posedge(clk))(
                body,
            )
        else:
            m.Always(vtypes.Posedge(clk))(
                vtypes.If(rst)(
                    reset,
                )(
                    body,
                ))

        arg_params = [(name, param) for name, param in params.items()]

        arg_ports = [('CLK', self.clk)]
        if self.rst is not None:
            arg_ports.append(('RST', self.rst))

        arg_ports.extend([(name, port) for name, port in ports.items()])

        sub = Submodule(self.m, m, 'inst_' + m.name, '_%s_' % self.name,
                        arg_params=arg_params, arg_ports=arg_ports)

    # -------------------------------------------------------------------------
    def make_case(self):
        """Build the FSM body as a list of statements using case-statements.

        Every state index that has a body or a jump gets a mark.  The result
        holds the code of the internal Seq, the delayed substitutions, one
        case-statement per delay (largest delay first) and finally the
        case-statement on the state register, which is omitted when it has no
        entries.
        """
        indexes = set(self.body.keys())
        indexes.update(set(self.jump.keys()))

        for index in indexes:
            self._add_mark(index)

        code = list(self.seq.make_code())
        code.extend(self._get_delayed_substs())

        by_delay_desc = sorted(self.delayed_body.items(),
                               key=lambda item: item[0], reverse=True)
        for delay, per_index in by_delay_desc:
            whens = tuple(self._get_delayed_when_statement(index, delay)
                          for index in sorted(per_index))
            code.append(vtypes.Case(self._get_delayed_state(delay))(*whens))

        whens = tuple(self._get_when_statement(index)
                      for index in sorted(indexes))
        state_case = vtypes.Case(self.state)(*whens)

        if len(state_case.statement) > 0:
            code.append(state_case)

        return code

    def make_if(self):
        """Build the FSM body as a list of statements using if-statements.

        Every state index that has a body or a jump gets a mark.  The result
        holds the code of the internal Seq, the delayed substitutions, one
        nested list of delayed if-statements per delay (largest delay first)
        and finally one if-statement per state index in ascending order.
        """
        indexes = set(self.body.keys())
        indexes.update(set(self.jump.keys()))

        for index in indexes:
            self._add_mark(index)

        code = list(self.seq.make_code())
        code.extend(self._get_delayed_substs())

        by_delay_desc = sorted(self.delayed_body.items(),
                               key=lambda item: item[0], reverse=True)
        for delay, per_index in by_delay_desc:
            # each delay contributes ONE element: the list of its statements
            code.append([self._get_delayed_if_statement(index, delay)
                         for index in sorted(per_index)])

        code.extend([self._get_if_statement(index)
                     for index in sorted(indexes)])
        return code

    # -------------------------------------------------------------------------
    def make_reset(self, reset):
        """Collect the reset statements of the FSM, one per destination.

        Sources are consulted in this order, and the first statement seen
        for a destination wins: the explicit `reset` statements, the state
        register, the delayed state registers, the destination variables,
        and the reset statements of the internal Seq.

        Raises:
            TypeError: if an item of `reset` or of the Seq's reset list is
                not a `vtypes.Subst`.
        """
        resets = collections.OrderedDict()

        def keep_first(key, subst):
            # an earlier statement for the same destination is never replaced
            resets.setdefault(key, subst)

        def add_checked(items):
            for item in items:
                if not isinstance(item, vtypes.Subst):
                    raise TypeError(
                        'make_reset requires Subst, not %s' % str(type(item)))
                keep_first(str(item.left), item)

        add_checked(reset)

        # the state register is keyed by the register itself, not by the
        # left-hand side of the generated statement
        state_reset = self.reset_visitor.visit(self.state)
        if state_reset is not None:
            keep_first(str(self.state), state_reset)

        for group in (self.delayed_state, self.dst_var):
            for dst in group.values():
                subst = self.reset_visitor.visit(dst)
                if subst is not None:
                    keep_first(str(subst.left), subst)

        add_checked(self.seq.make_reset())

        return list(resets.values())

    # -------------------------------------------------------------------------
    def set_index(self, index):
        """Public wrapper around `_set_index`; returns its result."""
        return self._set_index(index)

    # -------------------------------------------------------------------------
    def _go(self, src, dst, cond=None, else_dst=None):
        """Record a jump from `src` to `dst` via `_add_jump` and return self."""
        self._add_jump(src, dst, cond, else_dst)
        return self

    def _add_jump(self, src, dst, cond=None, else_dst=None):
        """Append the tuple `(dst, cond, else_dst)` to the jump list of `src`."""
        self.jump[src].append((dst, cond, else_dst))

    # -------------------------------------------------------------------------
    def _add_statement(self, statement, index=None, keep=None, delay=None, cond=None,
                       lazy_cond=False, eager_val=False, no_delay_cond=False):
        """Register `statement` for a state, optionally delayed or repeated.

        Args:
            statement: sequence of statements to add.
            index: target state, resolved through `_to_index`; defaults to
                the current state.
            keep: when given, the statement is added `keep` times with
                consecutive delays starting at `delay` (or 0).
            delay: when positive, the statement goes to the delayed body for
                that delay instead of the regular body.
            cond: condition guarding the statement (via `make_condition`).
            lazy_cond: for delayed statements, AND the given condition with
                the delayed condition of the constant 1 instead of delaying
                the condition itself.
            eager_val: for delayed statements, pass each statement through
                `_add_delayed_subst`.
            no_delay_cond: for delayed statements, do not wrap them in a
                delayed condition at all.

        Returns:
            self
        """
        cond = make_condition(cond)
        index = self.current if index is None else self._to_index(index)

        if keep is not None:
            first_delay = 0 if delay is None else delay
            for offset in range(keep):
                self._add_statement(statement, index=index,
                                    keep=None, delay=first_delay + offset,
                                    cond=cond,
                                    lazy_cond=lazy_cond, eager_val=eager_val,
                                    no_delay_cond=no_delay_cond)
            return self

        # --- immediate statement ------------------------------------------
        if delay is None or not (delay > 0):
            if cond is not None:
                statement = [vtypes.If(cond)(*statement)]
                # remembered so that Else/Elif can attach to it
                self.last_if_statement = statement[0]

            self.body[index].extend(statement)
            self._add_dst_var(statement)
            return self

        # --- delayed statement --------------------------------------------
        self._add_delayed_state(delay)

        if eager_val:
            statement = [self._add_delayed_subst(s, index, delay)
                         for s in statement]

        if not no_delay_cond:
            if lazy_cond:
                gate = self._add_delayed_cond(1, index, delay)
                unconditional = cond is None or (isinstance(cond, int) and
                                                 cond == 1)
                cond = gate if unconditional else vtypes.Ands(gate, cond)
            else:
                cond = self._add_delayed_cond(1 if cond is None else cond,
                                              index, delay)

            statement = [vtypes.If(cond)(*statement)]

        self.delayed_body[delay][index].extend(statement)
        self._add_dst_var(statement)

        return self

    # -------------------------------------------------------------------------
    def _add_dst_var(self, statement):
        for s in statement:
            values = self.dst_visitor.visit(s)
            for v in values:
                k = str(v)
                if k not in self.dst_var:
                    self.dst_var[k] = v

    # -------------------------------------------------------------------------
    def _add_delayed_cond(self, statement, index, delay):
        """Build a chain of `delay` registers that carries `statement` forward.

        Each link is created with `m.Reg(..., initval=0)`, recorded in
        `delayed_cond` under its name, and assigned from the previous link via
        `_add_statement` using delay `i` and `no_delay_cond=True`.

        Returns:
            The last register of the chain, or `statement` itself when
            `delay` is 0.
        """
        name_prefix = '_'.join(
            ['', self.name, 'cond', str(index), str(self.tmp_count)])
        # tmp_count keeps the generated register names unique per call.
        self.tmp_count += 1
        prev = statement
        for i in range(delay):
            tmp_name = '_'.join([name_prefix, str(i + 1)])
            tmp = self.m.Reg(tmp_name, initval=0)
            self.delayed_cond[tmp_name] = tmp
            # NOTE(review): `index` is used only in the register name; this
            # call does not pass it, so `_add_statement` falls back to the
            # current state. Confirm that is intended.
            self._add_statement([tmp(prev)], delay=i, no_delay_cond=True)
            prev = tmp
        return prev

    # -------------------------------------------------------------------------
    def _add_delayed_subst(self, subst, index, delay):
        """Route the right-hand side of `subst` through `delay` registers.

        Objects that are not `vtypes.Subst`, and substitutions whose right side
        is a Python scalar/str, a `vtypes._Constant` or a
        `vtypes._ParameterVariable`, are returned unchanged.

        Returns:
            `left(prev)`, i.e. the original left side called with the last
            register of the chain (or with the original right side when
            `delay` is 0).
        """
        if not isinstance(subst, vtypes.Subst):
            return subst
        left = subst.left
        right = subst.right
        if isinstance(right, (bool, int, float, str,
                              vtypes._Constant, vtypes._ParameterVariable)):
            return subst
        # The intermediate registers take their width and signedness from the
        # left side of the substitution.
        width = left.bit_length()
        signed = vtypes.get_signed(left)
        prev = right

        # Name the registers after the destination variable when it has one,
        # otherwise after this object; tmp_count keeps the names unique.
        name_prefix = ('_'.join(['', left.name, str(index), str(self.tmp_count)])
                       if isinstance(left, vtypes._Variable) else
                       '_'.join(['', self.name, 'sbst', str(index), str(self.tmp_count)]))
        self.tmp_count += 1

        for i in range(delay):
            tmp_name = '_'.join([name_prefix, str(i + 1)])
            tmp = self.m.Reg(tmp_name, width, initval=0, signed=signed)
            # NOTE(review): as in `_add_delayed_cond`, `index` is not forwarded
            # here, so the current state is used. Confirm.
            self._add_statement([tmp(prev)], delay=i, no_delay_cond=True)
            prev = tmp
        return left(prev)

    # -------------------------------------------------------------------------
    def _clear_next_kwargs(self):
        self.next_kwargs = {}

    def _clear_last_if_statement(self):
        """Forget the most recently registered If statement."""
        self.last_if_statement = None

    def _clear_last_cond(self):
        self.last_cond = []

    def _clear_elif_cond(self):
        """Forget any pending Elif condition."""
        self.elif_cond = None

    def _make_cond(self, condlist):
        ret = None
        for cond in condlist:
            if ret is None:
                ret = cond
            else:
                ret = vtypes.Ands(ret, cond)
        return ret

    # -------------------------------------------------------------------------
    def _set_index(self, index=None):
        if index is None:
            self.state_count += 1
            return self.state_count
        self.state_count = index
        return self.state_count

    def _get_mark(self, index=None):
        if index is None:
            index = self.state_count
        if index not in self.mark:
            raise KeyError("No such index in FSM marks: %s" % index)
        return self.mark[index]

    def _set_mark(self, index=None, name=None):
        if index is None:
            index = self.state_count
        if name is None:
            name = self.name + '_' + str(index)
        self.mark[index] = self.m.Localparam(name, index)

    def _get_mark_index(self, s):
        for index, m in self.mark.items():
            if m.name == s.name:
                return index
        raise KeyError("No such mark in FSM marks: %s" % s.name)

    # -------------------------------------------------------------------------
    def _add_mark(self, index):
        index = self._to_index(index)
        if index not in self.mark:
            self._set_mark(index)
        mark = self._get_mark(index)
        return mark

    def _to_index(self, index):
        if not isinstance(index, int):
            index = self._get_mark_index(index)
        return index

    # -------------------------------------------------------------------------
    def _add_delayed_state(self, value):
        if not isinstance(value, int):
            raise TypeError("Delay amount must be int, not '%s'" %
                            str(type(value)))

        if value < 0:
            raise ValueError("Delay amount must be positive number")

        if value == 0:
            return self.state

        if value <= self.delay_amount:
            return self._get_delayed_state(value)

        for i in range(self.delay_amount + 1, value + 1):
            d = self.m.Reg(''.join(['_d', str(i), '_', self.name]), self.width,
                           initval=self._get_mark(0))
            self.delayed_state[i] = d

        self.delay_amount = value
        return d

    def _get_delayed_state(self, value):
        if value == 0:
            return self.state
        if value not in self.delayed_state:
            raise IndexError('No such index %d in delayed state' % value)
        return self.delayed_state[value]

    def _get_delayed_substs(self):
        ret = []
        prev = self.state
        for d in range(1, self.delay_amount + 1):
            ret.append(vtypes.Subst(self.delayed_state[d], prev))
            prev = self.delayed_state[d]
        return ret

    def _init_delayed_state(self):
        ret = []
        for d in range(1, self.delay_amount + 1):
            ret.append(vtypes.Subst(self.delayed_state[d], self._get_mark(0)))
        return ret

    def _to_state_assign(self, dst, cond=None, else_dst=None):
        dst_mark = self._get_mark(dst)
        value = self.state(dst_mark)
        if cond is not None:
            value = vtypes.If(cond)(value)
            if else_dst is not None:
                else_dst_mark = self._get_mark(else_dst)
                value = value.Else(self.state(else_dst_mark))
        return value

    # -------------------------------------------------------------------------
    def _cond_case(self, index):
        if index not in self.mark:
            self._set_mark(index)
        return self._get_mark(index)

    def _cond_if(self, index):
        if index not in self.mark:
            self._set_mark(index)
        return (self.state == self._get_mark(index))

    def _delayed_cond_if(self, index, delay):
        if index not in self.mark:
            self._set_mark(index)
        if delay > 0 and delay not in self.delayed_state:
            self._add_delayed_state(delay)
        return (self._get_delayed_state(delay) == self._get_mark(index))

    def _get_when_statement(self, index):
        """Build the `vtypes.When` branch for state `index`.

        The branch holds the statements registered for the state followed by
        one state assignment per registered jump; marks for jump targets are
        created as needed.
        """
        stmts = list(self.body[index])
        for dst, cond, else_dst in self.jump[index]:
            self._add_mark(dst)
            if else_dst is not None:
                self._add_mark(else_dst)
            transition = self._to_state_assign(dst, cond, else_dst)
            stmts.append(transition)
        return vtypes.When(self._cond_case(index))(*stmts)

    def _get_delayed_when_statement(self, index, delay):
        """Build the `vtypes.When` branch holding the statements of `index` registered at `delay`."""
        branch = vtypes.When(self._cond_case(index))
        return branch(*self.delayed_body[delay][index])

    def _get_if_statement(self, index):
        """Build the `vtypes.If` block for state `index`.

        The block holds the statements registered for the state followed by
        one state assignment per registered jump; marks for jump targets are
        created as needed.
        """
        stmts = list(self.body[index])
        for dst, cond, else_dst in self.jump[index]:
            self._add_mark(dst)
            if else_dst is not None:
                self._add_mark(else_dst)
            transition = self._to_state_assign(dst, cond, else_dst)
            stmts.append(transition)
        return vtypes.If(self._cond_if(index))(*stmts)

    def _get_delayed_if_statement(self, index, delay):
        """Build the `vtypes.If` block holding the statements of `index` registered at `delay`."""
        guard = vtypes.If(self._delayed_cond_if(index, delay))
        return guard(*self.delayed_body[delay][index])

    # -------------------------------------------------------------------------
    def __call__(self, *statement, **kwargs):
        """Calling the object is shorthand for `add(*statement, **kwargs)`."""
        return self.add(*statement, **kwargs)

    def __getitem__(self, index):
        """Return the entry of `self.body` stored under `index`."""
        return self.body[index]

    def __len__(self):
        """Return `state_count + 1`."""
        return self.state_count + 1
Example #33
0
class Fifo(object):
    """Wrapper around a generated FIFO instance with enqueue/dequeue helpers.

    The constructor instantiates the FIFO definition inside module `m`,
    creates its write and read interfaces, and keeps an entry counter
    register updated through a `Seq`.
    """

    def __init__(self, m, name, clk, rst, datawidth=32, addrwidth=4):
        """Instantiate the FIFO in `m` and set up the entry counter logic."""
        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.datawidth = datawidth
        self.addrwidth = addrwidth

        self.wif = FifoWriteInterface(self.m, name, datawidth)
        self.rif = FifoReadInterface(self.m, name, datawidth)
        self.definition = mkFifoDefinition(name, datawidth, addrwidth)
        self.inst = self.m.Instance(
            self.definition,
            'inst_' + name,
            ports=m.connect_ports(self.definition))

        self.seq = Seq(m, name, clk, rst)

        # Upper bound used by the ready / space computations below.
        if isinstance(self.addrwidth, int):
            self._max_size = 2**self.addrwidth - 1
        else:
            self._max_size = vtypes.Int(2)**self.addrwidth - 1

        # Entry counter, one bit wider than the address.
        self._count = self.m.Reg(
            'count_' + name, self.addrwidth + 1, initval=0)

        def enq_accepted():
            return vtypes.Ands(self.wif.enq, vtypes.Not(self.wif.full))

        def deq_accepted():
            return vtypes.Ands(self.rif.deq, vtypes.Not(self.rif.empty))

        # Both accepted: count unchanged; only enq: +1; only deq: -1.
        chain = self.seq.If(vtypes.Ands(enq_accepted(), deq_accepted()))
        chain = chain(self._count(self._count))
        chain = chain.Elif(enq_accepted())
        chain = chain(self._count.inc())
        chain = chain.Elif(deq_accepted())
        chain(self._count.dec())

        self._enq_disabled = False
        self._deq_disabled = False

    def disable_enq(self):
        """Drive `enq` to 0 and make later `enq()` calls fail."""
        self.seq(self.wif.enq(0))
        self._enq_disabled = True

    def disable_deq(self):
        """Drive `deq` to 0 and make later `deq()` calls fail."""
        self.seq(self.rif.deq(0))
        self._deq_disabled = True

    def enq(self, wdata, cond=None, delay=0):
        """Register an enqueue of `wdata`.

        `cond`, when given, is applied through `seq.If`; `delay` is added to
        the delay currently pending on `seq`.

        Returns:
            `(ack, ready)` expressions.

        Raises:
            TypeError: enqueueing was disabled with `disable_enq`.
        """
        if self._enq_disabled:
            raise TypeError('Enq disabled.')

        if cond is not None:
            self.seq.If(cond)

        current_delay = self.seq.current_delay
        issue_at = current_delay + delay

        not_full = vtypes.Not(self.wif.full)
        ack = vtypes.Ands(not_full, self.wif.enq)
        if issue_at == 0:
            ready = vtypes.Not(self.wif.almost_full)
        else:
            ready = self._count + (issue_at + 1) < self._max_size

        self.seq.Delay(issue_at).EagerVal().If(not_full)(
            self.wif.wdata(wdata))
        self.seq.Then().Delay(issue_at)(self.wif.enq(1))

        # De-assert enq one step later.
        self.seq.Delay(issue_at + 1)(self.wif.enq(0))

        return ack, ready

    def deq(self, cond=None, delay=0):
        """Register a dequeue.

        `cond`, when given, is applied through `seq.If`; `delay` is added to
        the delay currently pending on `seq`.

        Returns:
            `(rdata, rvalid)`, where `rvalid` is a newly created temporary
            register.

        Raises:
            TypeError: dequeueing was disabled with `disable_deq`.
        """
        if self._deq_disabled:
            raise TypeError('Deq disabled.')

        if cond is not None:
            self.seq.If(cond)

        not_empty = vtypes.Not(self.rif.empty)

        current_delay = self.seq.current_delay
        issue_at = current_delay + delay

        self.seq.Delay(issue_at)(self.rif.deq(1))

        rdata = self.rif.rdata
        rvalid = self.m.TmpReg(initval=0)

        self.seq.Then().Delay(issue_at + 1)(
            rvalid(vtypes.Ands(not_empty, self.rif.deq)))

        # De-assert deq, then rvalid one step after that.
        self.seq.Delay(issue_at + 1)(self.rif.deq(0))
        self.seq.Delay(issue_at + 2)(rvalid(0))

        return rdata, rvalid

    @property
    def wdata(self):
        """`wdata` of the write interface."""
        return self.wif.wdata

    @property
    def empty(self):
        """`empty` of the read interface."""
        return self.rif.empty

    @property
    def almost_empty(self):
        """`almost_empty` of the read interface."""
        return self.rif.almost_empty

    @property
    def rdata(self):
        """`rdata` of the read interface."""
        return self.rif.rdata

    @property
    def full(self):
        """`full` of the write interface."""
        return self.wif.full

    @property
    def almost_full(self):
        """`almost_full` of the write interface."""
        return self.wif.almost_full

    @property
    def count(self):
        """The entry counter register."""
        return self._count

    @property
    def space(self):
        """Expression `max_size - count` (an int `max_size` is wrapped in `vtypes.Int`)."""
        limit = self._max_size
        if isinstance(limit, int):
            limit = vtypes.Int(limit)
        return limit - self.count

    def has_space(self, num=1):
        """Return True for `num < 1`, otherwise the expression `count + num < max_size`."""
        if num < 1:
            return True
        return self._count + num < self._max_size
Example #34
0
class FSM(vtypes.VeriloggenNode):
    """ Finite State Machine Generator """
    def __init__(self,
                 m,
                 name,
                 clk,
                 rst,
                 width=32,
                 initname='init',
                 nohook=False,
                 as_module=False):
        """Create the state register, the initial mark and all bookkeeping.

        Args:
            m: Module object in which registers and localparams are created.
            name: Name of the state register; also the prefix of mark names.
            clk: Clock, stored and passed on to the internal `Seq`.
            rst: Reset, stored and passed on to the internal `Seq`.
            width: Width of the state register.
            initname: Suffix of the mark created for state 0
                (`<name>_<initname>`).
            nohook: When true, `implement` is not registered via `m.add_hook`.
            as_module: Stored flag; `implement` uses it to choose between
                `make_module` and `make_always`.
        """
        self.m = m
        self.name = name
        self.clk = clk
        self.rst = rst
        self.width = width
        self.state_count = 0
        self.state = self.m.Reg(name, width)  # set initval later

        self.mark = collections.OrderedDict()  # key:index
        self._set_mark(0, self.name + '_' + initname)
        # The state register starts at the mark of state 0.
        self.state.initval = self._get_mark(0)

        # Statements and jumps registered per state index.
        self.body = collections.defaultdict(list)
        self.jump = collections.defaultdict(list)

        self.delay_amount = 0
        self.delayed_state = collections.OrderedDict()  # key:delay
        self.delayed_body = collections.defaultdict(
            functools.partial(collections.defaultdict, list))  # key:delay
        self.delayed_cond = collections.OrderedDict()  # key:name
        # Counter used to make generated temporary register names unique.
        self.tmp_count = 0

        self.dst_var = collections.OrderedDict()
        self.dst_visitor = SubstDstVisitor()
        self.reset_visitor = ResetVisitor()

        # Companion Seq; created with nohook=True.
        self.seq = Seq(self.m, self.name + '_par', clk, rst, nohook=True)

        # Set by make_always so the always block is emitted only once.
        self.done = False

        # State of the fluent If/Elif/Else/Delay/... call chain.
        self.last_cond = []
        self.last_kwargs = {}
        self.last_if_statement = None
        self.elif_cond = None
        self.next_kwargs = {}

        self.as_module = as_module

        if not nohook:
            self.m.add_hook(self.implement)

    #-------------------------------------------------------------------------
    def goto(self, dst, cond=None, else_dst=None):
        """Register a jump from the current state to `dst`.

        When `cond` is None, a condition pending from `If` is used instead.
        The pending options, last If statement and last condition are cleared
        before the jump is recorded. Returns the result of `_go`.
        """
        if cond is None:
            cond = self.next_kwargs.get('cond')
        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()

        return self._go(self.current, dst, cond, else_dst)

    def goto_init(self, cond=None):
        """Register a jump from the current state back to state 0.

        When `cond` is None, a condition pending from `If` is used instead.
        The pending options, last If statement and last condition are cleared
        before the jump is recorded. Returns the result of `_go`.
        """
        if cond is None:
            cond = self.next_kwargs.get('cond')
        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()

        return self._go(self.current, 0, cond)

    def goto_next(self, cond=None):
        """Register a jump to the next state index and make it the current one.

        When `cond` is None, a condition pending from `If` is used instead.
        The pending options, last If statement and last condition are cleared
        before the jump is recorded. Returns the result of `_go`.
        """
        if cond is None:
            cond = self.next_kwargs.get('cond')
        self._clear_next_kwargs()
        self._clear_last_if_statement()
        self._clear_last_cond()

        here = self.current
        result = self._go(here, self.current + 1, cond=cond)
        self.inc()
        return result

    def goto_from(self, src, dst, cond=None, else_dst=None):
        """Register a jump from an explicit `src` state to `dst`.

        When `cond` is None, a condition pending from `If` is used instead.
        The pending options, last If statement and last condition are cleared
        before the jump is recorded. Returns the result of `_go`.
        """
        if cond is None:
            cond = self.next_kwargs.get('cond')
        for reset in (self._clear_next_kwargs,
                      self._clear_last_if_statement,
                      self._clear_last_cond):
            reset()

        return self._go(src, dst, cond, else_dst)

    def inc(self):
        """Advance to the next state index (calls `_set_index(None)`)."""
        self._set_index(None)

    #-------------------------------------------------------------------------
    def add(self, *statement, **kwargs):
        """Add new assignments to the current state.

        Pending options (from `If`, `Delay`, ...) are merged into `kwargs`,
        remembered in `last_kwargs`, and then cleared. If an Elif condition is
        pending and no option other than 'cond' is present, the statements are
        attached as an Elif branch of the last If statement; otherwise they
        are registered through `_add_statement`.

        Returns:
            self on the Elif path, otherwise the result of `_add_statement`.
        """
        kwargs.update(self.next_kwargs)
        self.last_kwargs = kwargs
        self._clear_next_kwargs()

        # The Elif object is reused only when nothing but 'cond' was given.
        only_cond = set(kwargs) <= {'cond'}

        if self.elif_cond is None or not only_cond:
            self._clear_last_if_statement()
            return self._add_statement(statement, **kwargs)

        branch = self.last_if_statement.Elif(self.elif_cond)
        branch(*statement)
        self.last_if_statement = branch
        self._add_dst_var(statement)
        self._clear_elif_cond()
        return self

    #-------------------------------------------------------------------------
    def Prev(self, var, delay, initval=0, cond=None, prefix=None):
        """Delegate to `self.seq.Prev` with the same arguments and return its result."""
        return self.seq.Prev(var, delay, initval, cond, prefix)

    #-------------------------------------------------------------------------
    def If(self, *cond):
        """Attach a condition to the next statement or jump.

        The arguments are normalized with `make_condition`; a None result
        leaves the pending options untouched. Otherwise the condition is
        stored as the pending 'cond' (ANDed with one that is already pending)
        and becomes the sole entry of `last_cond`. Returns self.
        """
        self._clear_elif_cond()

        cond = make_condition(*cond)
        if cond is None:
            return self

        pending = self.next_kwargs
        if 'cond' in pending:
            pending['cond'] = vtypes.Ands(pending['cond'], cond)
        else:
            pending['cond'] = cond

        self.last_cond = [pending['cond']]
        return self

    def Else(self, *statement, **kwargs):
        """Add `statement` as the else-branch of the preceding condition.

        The last entry of `last_cond` is replaced by its negation. If the
        preceding statement was delayed, or `has_args` below is true, the
        statements are registered separately under the combined condition via
        `Then()`; otherwise they are attached to the last If statement.

        Raises:
            ValueError: there is no previous condition, or (attached case) the
                last registered statement is not a `vtypes.If`.
        """
        self._clear_elif_cond()

        if len(self.last_cond) == 0:
            raise ValueError("No previous condition for Else.")

        old = self.last_cond.pop()
        self.last_cond.append(vtypes.Not(old))

        # if the true-statement has delay attributes,
        # Else statement is separated.
        if 'delay' in self.last_kwargs and self.last_kwargs['delay'] > 0:
            # Keep the current list so it can be restored after the
            # Then()/call chain below.
            prev_cond = self.last_cond
            ret = self.Then()(*statement)
            self.last_cond = prev_cond
            return ret

        # if there is additional attribute, Else statement is separated.
        # NOTE(review): the second test looks at `kwargs`, while the length
        # checks look at `self.next_kwargs`; `kwargs` is not used anywhere
        # else in this method. Confirm which dict was intended.
        has_args = not (len(self.next_kwargs) == 0 or  # has no args
                        (len(self.next_kwargs) == 1 and 'cond' in kwargs)
                        )  # has only 'cond'

        if has_args:
            prev_cond = self.last_cond
            ret = self.Then()(*statement)
            self.last_cond = prev_cond
            return ret

        if not isinstance(self.last_if_statement, vtypes.If):
            raise ValueError("Last if-statement is not If")

        self.last_if_statement.Else(*statement)
        self._add_dst_var(statement)

        return self

    def Elif(self, *cond):
        """Open an else-if branch for the statement that follows.

        The last entry of `last_cond` is replaced by its negation and the new
        condition (normalized with `make_condition`) is appended. If the
        preceding statement was delayed, the combined condition is made
        pending via `Then()`. Otherwise the new condition is stored in
        `elif_cond` and the combined condition becomes the pending 'cond'.

        Returns:
            self, or the result of `Then()` on the delayed path.

        Raises:
            ValueError: there is no previous condition, or (non-delayed case)
                the last registered statement is not a `vtypes.If`.
        """
        if len(self.last_cond) == 0:
            # The message names Elif so the error points at the right call.
            raise ValueError("No previous condition for Elif.")

        cond = make_condition(*cond)

        old = self.last_cond.pop()
        self.last_cond.append(vtypes.Not(old))
        self.last_cond.append(cond)

        # if the true-statement has delay attributes, Else statement is
        # separated.
        if 'delay' in self.last_kwargs and self.last_kwargs['delay'] > 0:
            # Keep the current list so it can be restored after Then().
            prev_cond = self.last_cond
            ret = self.Then()
            self.last_cond = prev_cond
            return ret

        if not isinstance(self.last_if_statement, vtypes.If):
            raise ValueError("Last if-statement is not If")

        self.elif_cond = cond

        cond = self._make_cond(self.last_cond)
        self.next_kwargs['cond'] = cond

        return self

    def Delay(self, delay):
        """Store `delay` as the pending 'delay' option and return self."""
        self.next_kwargs.update(delay=delay)
        return self

    def Keep(self, keep):
        """Store `keep` as the pending 'keep' option and return self."""
        self.next_kwargs.update(keep=keep)
        return self

    def Then(self):
        """Re-apply the last condition to the next statement.

        The entries of `last_cond` are combined with `_make_cond`, the list is
        cleared, and the combined condition is passed to `If`. Returns self.
        """
        combined = self._make_cond(self.last_cond)
        self._clear_last_cond()
        self.If(combined)
        return self

    def LazyCond(self, value=True):
        """Store `value` as the pending 'lazy_cond' option and return self."""
        self.next_kwargs.update(lazy_cond=value)
        return self

    def EagerVal(self, value=True):
        """Store `value` as the pending 'eager_val' option and return self."""
        self.next_kwargs.update(eager_val=value)
        return self

    def Clear(self):
        """Reset all fluent-chain state and return self.

        Clears the pending options, the last If statement, the last condition
        and the pending Elif condition, in that order.
        """
        for reset in (self._clear_next_kwargs,
                      self._clear_last_if_statement,
                      self._clear_last_cond,
                      self._clear_elif_cond):
            reset()
        return self

    #-------------------------------------------------------------------------
    @property
    def current(self):
        """Index of the current state (`state_count`)."""
        return self.state_count

    @property
    def next(self):
        """Index following the current state (`current + 1`)."""
        return self.current + 1

    @property
    def current_delay(self):
        """Pending 'delay' option, or 0 when none is pending."""
        return self.next_kwargs.get('delay', 0)

    @property
    def last_delay(self):
        """'delay' option of the most recent `add` call, or 0 when it had none."""
        return self.last_kwargs.get('delay', 0)

    @property
    def current_condition(self):
        """Expression `state == state_count`, ANDed with the pending 'cond' if there is one."""
        in_state = self.state == self.state_count
        pending = self.next_kwargs.get('cond')
        if pending is None:
            return in_state
        return vtypes.AndList(in_state, pending)

    @property
    def last_condition(self):
        """Expression `state == state_count`, ANDed with the combined `last_cond` if any."""
        combined = self._make_cond(self.last_cond)
        in_state = self.state == self.state_count
        if combined is None:
            return in_state
        return vtypes.AndList(in_state, combined)

    @property
    def then(self):
        """Alias of `last_condition`."""
        return self.last_condition

    @property
    def here(self):
        """Expression comparing the state register with the current state index."""
        return self.state == self.current

    #-------------------------------------------------------------------------
    def implement(self):
        """Emit the FSM: `make_module()` when `as_module` is set, else `make_always()`."""
        if not self.as_module:
            self.make_always()
            return

        self.make_module()

    #-------------------------------------------------------------------------
    def make_always(self, reset=(), body=(), case=True):
        """Emit the clocked always block for this FSM; later calls do nothing.

        Args:
            reset: Extra reset statements, passed to `make_reset`.
            body: Extra statements placed before the generated state logic.
            case: Use `make_case()` for the state logic when true, otherwise
                `make_if()`.
        """
        if self.done:
            #raise ValueError('make_always() has been already called.')
            return

        self.done = True

        part_reset = self.make_reset(reset)
        part_body = list(body) + list(
            self.make_case() if case else self.make_if())
        # If(rst)(reset part)(body part): the second call presumably supplies
        # the else-branch. TODO confirm against vtypes.If.
        self.m.Always(vtypes.Posedge(self.clk))(vtypes.If(self.rst)(
            part_reset, )(part_body, ))