Beispiel #1
0
    def fsm(self, name='fsm', clock_domain=None):
        """Create an FSM driven by a clock domain and register it on this module.

        Args:
            name: name given to the new FSM.
            clock_domain: object providing ``clock`` and ``reset``; when
                omitted (``None``), ``self.clock_domain`` is used instead.

        Returns:
            The newly created ``veriloggen.FSM``.

        Raises:
            ValueError: if neither ``clock_domain`` nor ``self.clock_domain``
                is available.
        """
        domain = self.clock_domain if clock_domain is None else clock_domain
        if domain is None:
            raise ValueError('This Module has no clock domain.')

        machine = veriloggen.FSM(self, name, domain.clock, domain.reset)
        self._fsms.append(machine)
        # make_always is deferred: the hook runs when the module is
        # converted into Verilog.
        self.add_hook(machine.make_always)
        return machine
Beispiel #2
0
    def control_sequence(self, fsm):
        """Append the ReadAct -> Comp -> WriteOut control loop to ``fsm``.

        One loop pass per combination of the outer dimensions
        (``self.act_shape[:-2]`` / ``self.out_shape[:-2]``) does:

        * ReadAct: DMA-read one activation chunk into the current page of
          ``act_ram``. This is skipped via ``skip_read_act`` once the final
          chunk has been read.
        * Comp: start a separate ``comp_fsm`` that configures and runs
          ``self.stream`` asynchronously on that page. ``comp_fsm`` does not
          start while ``skip_comp`` is set.
        * WriteOut: DMA-write a page of ``out_ram`` back to memory once
          ``comp_count`` has overtaken ``out_count``. ``skip_write_out``
          starts at 1, so the first pass writes nothing.

        ``act_ram`` and ``out_ram`` are each used as two halves ("pages")
        that are toggled every pass.

        Args:
            fsm: veriloggen FSM to which the states are appended. Its current
                state advances as states are added.
        """
        act_ram = self.input_rams[0]
        out_ram = self.output_rams[0]

        # Global offset of the current activation chunk: the sum of one
        # signed register per outer dimension of act_shape.
        act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                      self.maxi.addrwidth, signed=True)

        act_offsets = [self.m.Reg(self._name('act_offset_%d' % i),
                                  self.maxi.addrwidth, initval=0, signed=True)
                       for i, _ in enumerate(self.act_shape[:-2])]

        if act_offsets:
            v = act_offsets[0]
            for act_offset in act_offsets[1:]:
                v += act_offset
            act_base_offset.assign(v)
        else:
            act_base_offset.assign(0)

        # Same construction for the output address offset.
        out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                      self.maxi.addrwidth, signed=True)

        out_offsets = [self.m.Reg(self._name('out_offset_%d' % i),
                                  self.maxi.addrwidth, initval=0, signed=True)
                       for i, _ in enumerate(self.out_shape[:-2])]

        if out_offsets:
            v = out_offsets[0]
            for out_offset in out_offsets[1:]:
                v += out_offset
            out_base_offset.assign(v)
        else:
            out_base_offset.assign(0)

        # Loop counters for the read side and a one-pass-delayed copy used by
        # the write side.
        counts = [self.m.Reg(self._name('count_%d' % i),
                             self.maxi.addrwidth, initval=0)
                  for i, _ in enumerate(self.act_shape[:-2])]

        prev_counts = [self.m.Reg(self._name('prev_count_%d' % i),
                                  self.maxi.addrwidth, initval=0)
                       for i, _ in enumerate(self.act_shape[:-2])]

        stream_act_local = self.m.Reg(self._name('stream_act_local'),
                                      self.maxi.addrwidth, initval=0)
        stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                      self.maxi.addrwidth, initval=0)

        # comp_count is bumped by comp_fsm on every stream end_flag and
        # out_count by this FSM after each write. WriteOut waits until
        # comp_count > out_count.
        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth, initval=0)
        out_count = self.m.Reg(self._name('out_count'),
                               self.maxi.addrwidth, initval=0)

        # Double-buffer state: which half of each RAM the Comp and DMA sides
        # are using.
        act_page = self.m.Reg(self._name('act_page'), initval=0)
        act_page_comp_offset = self.m.Reg(self._name('act_page_comp_offset'),
                                          self.maxi.addrwidth, initval=0)
        act_page_dma_offset = self.m.Reg(self._name('act_page_dma_offset'),
                                         self.maxi.addrwidth, initval=0)

        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth, initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth, initval=0)

        act_page_size = act_ram.length // 2
        out_page_size = out_ram.length // 2

        skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

        # --------------------
        # initialization phase
        # --------------------
        # ReadAct: offset
        for act_offset, act_offset_begin in zip(act_offsets, self.act_offset_begins):
            fsm(
                act_offset(act_offset_begin)
            )

        # ReadAct: double buffer control
        fsm(
            act_page(0),
            act_page_comp_offset(0),
            act_page_dma_offset(0)
        )

        # WriteOutput: offset
        for out_offset in out_offsets:
            fsm(
                out_offset(0)
            )

        # WriteOutput: double buffer control
        fsm(
            out_page(0),
            out_page_comp_offset(0),
            out_page_dma_offset(0)
        )

        # counter
        fsm(
            [count(0) for count in counts],
            [prev_count(0) for prev_count in prev_counts]
        )

        # double buffer control: nothing has been computed yet, so the first
        # pass must not write anything back.
        fsm(
            skip_read_act(0),
            skip_comp(0),
            skip_write_out(1)
        )

        fsm(
            out_count(0)
        )

        state_init = fsm.current

        fsm.goto_next()

        # --------------------
        # ReadAct phase
        # --------------------
        state_read_act = fsm.current

        act_gaddr = self.arg_objaddrs[0] + act_base_offset

        bt.bus_lock(self.maxi, fsm)

        act_laddr = act_page_dma_offset

        fsm.goto_next()

        bt.dma_read(self.maxi, fsm, act_ram, act_laddr,
                    act_gaddr, self.act_read_size, port=1)

        # Release the bus before the Comp phase. WriteOut acquires it again
        # with its own bus_lock/bus_unlock pair.
        bt.bus_unlock(self.maxi, fsm)

        fsm.goto_next()
        state_read_act_end = fsm.current

        # skip_read_act is raised after the last chunk has been read. The
        # remaining pass only drains the final WriteOut, so the read is
        # skipped (same pattern as the other control sequences).
        fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        # Stream Control FSM: runs alongside the main FSM. The main FSM only
        # waits for comp_fsm to be idle (back in its init state) before
        # moving on. It does not wait for the stream to finish.
        comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

        comp_state_init = comp_fsm.current
        comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

        fsm.If(comp_fsm.state == comp_state_init).goto_next()

        # local address
        comp_fsm(
            stream_act_local(self.stream_local),
            stream_out_local(0)
        )

        # Snapshot the page offsets so the main FSM may toggle the pages
        # while the stream is still configured/running.
        act_page_comp_offset_buf = self.m.Reg(self._name('act_page_comp_offset_buf'),
                                              self.maxi.addrwidth, initval=0)
        out_page_comp_offset_buf = self.m.Reg(self._name('out_page_comp_offset_buf'),
                                              self.maxi.addrwidth, initval=0)

        comp_fsm(
            act_page_comp_offset_buf(act_page_comp_offset),
            out_page_comp_offset_buf(out_page_comp_offset)
        )

        comp_fsm.goto_next()

        # busy check
        self.stream.source_join(comp_fsm)

        # set_source
        name = list(self.stream.sources.keys())[0]
        local = stream_act_local + act_page_comp_offset_buf

        if len(self.out_shape) > 1:
            pat = ((self.stream_size, self.act_strides[-1]),
                   (self.out_shape[-2], self.stream_stride))
        else:
            pat = ((self.stream_size, self.act_strides[-1]),)

        self.stream.set_source_pattern(comp_fsm, name, act_ram,
                                       local, pat)

        # Put the sink setup into the same state as the source setup.
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_sink
        name = list(self.stream.sinks.keys())[0]
        local = stream_out_local + out_page_comp_offset_buf

        if len(self.out_shape) > 1:
            pat = ((self.stream_size, 1),
                   (self.out_shape[-2], self.stream_size))
        else:
            pat = ((self.stream_size, 1),)

        self.stream.set_sink_pattern(comp_fsm, name, out_ram,
                                     local, pat)

        # stream run (async)
        self.stream.run(comp_fsm)

        comp_fsm.goto_init()

        # sync with WriteOut control
        comp_fsm.seq.If(fsm.state == state_init)(
            comp_count(0)
        )
        comp_fsm.seq.If(self.stream.end_flag)(
            comp_count.inc()
        )

        # --------------------
        # WriteOut phase
        # --------------------
        state_write_out = fsm.current

        # sync with Comp control: wait until a computed page is ready.
        fsm.If(comp_count > out_count).goto_next()

        out_laddr = out_page_dma_offset
        out_gaddr = self.objaddr + out_base_offset

        bt.bus_lock(self.maxi, fsm)

        bt.dma_write(self.maxi, fsm, out_ram, out_laddr,
                     out_gaddr, self.out_write_size, port=1, use_async=True)

        bt.bus_unlock(self.maxi, fsm)

        fsm(
            out_count.inc()
        )

        fsm.goto_next()

        state_write_out_end = fsm.current
        fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

        # --------------------
        # update for next iteration
        # --------------------
        # ReadAct: count. The innermost counter (last in `counts`) always
        # increments. An outer counter increments only when every inner
        # counter wraps, and any counter at its limit wraps to 0.
        cond = None
        for size, count in zip(reversed(self.out_shape[:-2]), reversed(counts)):

            fsm.If(cond)(
                count.inc()
            )
            fsm.If(cond, count >= size - 1)(
                count(0)
            )
            if cond is not None:
                cond = vg.Ands(cond, count >= size - 1)
            else:
                cond = count >= size - 1

        # ReadAct: offset (advanced and wrapped in lockstep with the counters)
        cond = None
        for size, count, act_offset, act_offset_stride in zip(
                reversed(self.out_shape[:-2]), reversed(counts),
                reversed(act_offsets), reversed(self.act_offset_strides)):

            fsm.If(cond)(
                act_offset.add(act_offset_stride)
            )
            fsm.If(cond, count >= size - 1)(
                act_offset(0)
            )
            if cond is not None:
                cond = vg.Ands(cond, count >= size - 1)
            else:
                cond = count >= size - 1

        # ReadAct and Comp: double buffer
        fsm.If(vg.Not(act_page))(
            act_page_comp_offset(act_page_size),
            act_page_dma_offset(act_page_size),
            act_page(1)
        )
        fsm.If(act_page)(
            act_page_comp_offset(0),
            act_page_dma_offset(0),
            act_page(0)
        )

        # WriteOut: offset. This follows the delayed prev_counts and does not
        # move while writes are being skipped.
        cond = vg.Not(skip_write_out)
        for size, prev_count, out_offset, out_offset_stride in zip(
                reversed(self.out_shape[:-2]), reversed(prev_counts),
                reversed(out_offsets), reversed(self.out_offset_strides)):

            fsm.If(cond)(
                out_offset.add(out_offset_stride)
            )
            fsm.If(cond, prev_count >= size - 1)(
                out_offset(0)
            )
            cond = vg.Ands(cond, prev_count >= size - 1)

        # WriteOut and Comp: double buffer (Comp and DMA use opposite halves)
        fsm.If(vg.Not(out_page))(
            out_page_comp_offset(out_page_size),
            out_page_dma_offset(0),
            out_page(1)
        )
        fsm.If(out_page)(
            out_page_comp_offset(0),
            out_page_dma_offset(out_page_size),
            out_page(0)
        )

        # ReadAct and WriteOut: prev
        for count, prev_count in zip(counts, prev_counts):
            fsm(
                prev_count(count)
            )

        # ReadAct, Comp, WriteOut: skip
        # All read-side counters at their limit means the pass that just
        # finished handled the last chunk: stop reading and computing.
        cond_skip_read_act = None
        for size, count in zip(reversed(self.out_shape[:-2]), reversed(counts)):
            if cond_skip_read_act is not None:
                cond_skip_read_act = vg.Ands(cond_skip_read_act, count >= size - 1)
            else:
                cond_skip_read_act = count >= size - 1

        cond_skip_comp = cond_skip_read_act

        # zip() keeps the loop length bounded by out_shape[:-2]; size itself
        # is unused here.
        cond_cancel_write_out = None
        for _size, prev_count in zip(reversed(self.out_shape[:-2]), reversed(prev_counts)):
            if cond_cancel_write_out is not None:
                cond_cancel_write_out = vg.Ands(cond_cancel_write_out, prev_count == 0)
            else:
                cond_cancel_write_out = prev_count == 0

        cond_done = None
        for size, prev_count in zip(reversed(self.out_shape[:-2]), reversed(prev_counts)):
            if cond_done is not None:
                cond_done = vg.Ands(cond_done, prev_count >= size - 1)
            else:
                cond_done = prev_count >= size - 1

        fsm.If(cond_skip_read_act)(
            skip_read_act(1)
        )
        fsm.If(cond_skip_comp)(
            skip_comp(1)
        )
        fsm.If(skip_write_out,
               cond_cancel_write_out)(
            skip_write_out(0)
        )

        # Loop back, unless the final page has just been written.
        fsm.goto(state_read_act)
        fsm.If(vg.Not(skip_write_out), cond_done).goto_next()

        # wait for last DMA write (it was issued with use_async=True)
        bt.dma_wait_write(self.maxi, fsm)
Beispiel #3
0
    def control_sequence(self, fsm):
        ksize_ch = self.ksize[-1]
        ksize_col = self.ksize[-2]
        ksize_row = self.ksize[-3]
        ksize_bat = self.ksize[-4]

        self.stride_bat = 1

        act_rams = self.input_rams
        out_ram = self.output_rams[0]

        act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                      self.maxi.addrwidth,
                                      signed=True)
        act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)
        act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)

        act_base_offset.assign(act_base_offset_row + act_base_offset_bat)

        out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                      self.maxi.addrwidth,
                                      signed=True)
        out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)
        out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)

        out_base_offset.assign(out_base_offset_row + out_base_offset_bat)

        dma_flags = [
            self.m.Reg(self._name('dma_flag_%d' % i), initval=0)
            for i in range(ksize_row)
        ]

        col_count = self.m.Reg(self._name('col_count'),
                               self.maxi.addrwidth,
                               initval=0)
        row_count = self.m.Reg(self._name('row_count'),
                               self.maxi.addrwidth,
                               initval=0)
        bat_count = self.m.Reg(self._name('bat_count'),
                               self.maxi.addrwidth,
                               initval=0)
        col_select = self.m.Reg(self._name('col_select'),
                                bt.log_width(ksize_col),
                                initval=0)
        row_select = self.m.Reg(self._name('row_select'),
                                bt.log_width(ksize_row),
                                initval=0)

        prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                    self.maxi.addrwidth,
                                    initval=0)
        prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                    self.maxi.addrwidth,
                                    initval=0)
        prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                     bt.log_width(ksize_row),
                                     initval=0)

        stream_act_locals = [
            self.m.Reg(self._name('stream_act_local_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(len(act_rams))
        ]
        stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                      self.maxi.addrwidth,
                                      initval=0)

        # double buffer control
        act_pages = [
            self.m.Reg(self._name('act_page_%d' % i), initval=0)
            for i in range(ksize_row)
        ]
        act_page_comp_offsets = [
            self.m.Reg(self._name('act_page_comp_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]
        act_page_dma_offsets = [
            self.m.Reg(self._name('act_page_dma_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]

        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth,
                                          initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth,
                                         initval=0)

        act_page_size = act_rams[0].length // 2
        out_page_size = out_ram.length // 2

        skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth,
                                initval=0)
        out_count = self.m.Reg(self._name('out_count'),
                               self.maxi.addrwidth,
                               initval=0)

        # --------------------
        # initialization phase
        # --------------------
        # ReadAct: offset
        fsm(act_base_offset_row(0), act_base_offset_bat(0))

        act_offsets = []
        for v in self.act_offset_values:
            act_offset = act_base_offset + v
            act_offsets.append(act_offset)

        # ReadAct: DMA flag
        for y, dma_flag in enumerate(dma_flags):
            fsm(dma_flag(1))

        dma_pad_masks = []

        for y in range(ksize_row):
            v = vg.Ors((row_count + y < self.pad_row_top),
                       (row_count + y >= self.act_num_row + self.pad_row_top))
            dma_pad_mask = self.m.Wire(self._name('dma_pad_mask_%d' % y))
            dma_pad_mask.assign(v)
            dma_pad_masks.append(dma_pad_mask)

        # ReadAct: double buffer control
        fsm([act_page(0) for act_page in act_pages], [
            act_page_comp_offset(0)
            for act_page_comp_offset in act_page_comp_offsets
        ], [
            act_page_dma_offset(0)
            for act_page_dma_offset in act_page_dma_offsets
        ])

        # WriteOutput: offset
        fsm(out_base_offset_row(0), out_base_offset_bat(0))

        out_offset = out_base_offset

        # WriteOut: double buffer control
        fsm(out_page(0), out_page_comp_offset(0), out_page_dma_offset(0))

        # counter
        fsm(row_count(0), bat_count(0), row_select(0), prev_row_count(0),
            prev_bat_count(0), prev_row_select(0))

        # double buffer control
        fsm(skip_read_act(0), skip_comp(0), skip_write_out(1))

        fsm(out_count(0))

        state_init = fsm.current

        fsm.goto_next()

        # --------------------
        # ReadAct phase
        # --------------------
        state_read_act = fsm.current

        act_gaddrs = []
        for act_offset in act_offsets:
            act_gaddr = self.arg_objaddrs[0] + act_offset
            act_gaddrs.append(act_gaddr)

        act_rams_2d = line_to_2d(act_rams, ksize_col)
        mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
        mux_act_gaddrs = []
        for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
            mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                        self.maxi.addrwidth)
            mux_act_gaddr.assign(mux_act_gaddr_value)
            mux_act_gaddrs.append(mux_act_gaddr)

        mux_dma_pad_mask_values = mux_1d(dma_pad_masks, row_select, ksize_row)
        mux_dma_pad_masks = []
        for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
            mux_dma_pad_mask = self.m.Wire(
                self._name('mux_dma_pad_mask_%d' % i))
            mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
            mux_dma_pad_masks.append(mux_dma_pad_mask)

        # determined at the previous phase
        mux_dma_flag_values = mux_1d(dma_flags, prev_row_select, ksize_row)
        mux_dma_flags = []
        for i, mux_dma_flag_value in enumerate(mux_dma_flag_values):
            mux_dma_flag = self.m.Wire(self._name('mux_dma_flag_%d' % i))
            mux_dma_flag.assign(mux_dma_flag_value)
            mux_dma_flags.append(mux_dma_flag)

        bt.bus_lock(self.maxi, fsm)

        for (act_rams_row, act_gaddr, act_page_dma_offset, dma_pad_mask,
             dma_flag) in zip(act_rams_2d, mux_act_gaddrs,
                              act_page_dma_offsets, mux_dma_pad_masks,
                              mux_dma_flags):
            act_laddr = act_page_dma_offset

            begin_state_read = fsm.current
            fsm.goto_next()

            if len(act_rams_row) == 1:
                bt.dma_read(self.maxi,
                            fsm,
                            act_rams_row[0],
                            act_laddr,
                            act_gaddr,
                            self.act_read_size,
                            port=1)
            else:
                bt.dma_read_block(self.maxi,
                                  fsm,
                                  act_rams_row,
                                  act_laddr,
                                  act_gaddr,
                                  self.act_read_size,
                                  self.act_read_block,
                                  port=1)

            end_state_read = fsm.current

            fsm.If(vg.Ors(dma_pad_mask,
                          vg.Not(dma_flag))).goto_from(begin_state_read,
                                                       end_state_read)

        bt.bus_unlock(self.maxi, fsm)

        fsm.goto_next()
        state_read_act_end = fsm.current
        fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        # Stream Control FSM
        comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

        comp_state_init = comp_fsm.current
        comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

        fsm.If(comp_fsm.state == comp_state_init).goto_next()

        # local address
        stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)
        for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
            for x, stream_act_local in enumerate(stream_act_locals_row):
                comp_fsm(stream_act_local(0))
                comp_fsm.If(self.stream_act_local_small_flags[x])(
                    stream_act_local(self.stream_act_local_small_offset))
                comp_fsm.If(self.stream_act_local_large_flags[x])(
                    stream_act_local(self.stream_act_local_large_offset))

        comp_fsm(stream_out_local(0))

        # count and sel
        comp_fsm(col_count(0))
        comp_fsm(col_select(self.col_select_initval))

        act_page_comp_offset_bufs = [
            self.m.Reg(self._name('act_page_comp_offset_buf_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]
        out_page_comp_offset_buf = self.m.Reg(
            self._name('out_page_comp_offset_buf'),
            self.maxi.addrwidth,
            initval=0)
        row_count_buf = self.m.Reg(self._name('row_count_buf'),
                                   self.maxi.addrwidth,
                                   initval=0)
        row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                    bt.log_width(ksize_row),
                                    initval=0)
        comp_fsm([
            act_page_comp_offset_buf(act_page_comp_offset)
            for act_page_comp_offset_buf, act_page_comp_offset in zip(
                act_page_comp_offset_bufs, act_page_comp_offsets)
        ], out_page_comp_offset_buf(out_page_comp_offset),
                 row_count_buf(row_count), row_select_buf(row_select))

        comp_fsm.goto_next()

        # repeat comp
        comp_state_rep = comp_fsm.current

        # pad_mask
        stream_pad_masks = []

        for y in range(ksize_row):
            for x in range(ksize_col):
                stream_col_count = col_count + x
                stream_row_count = row_count_buf + y
                v = vg.Ors(
                    (stream_col_count < self.pad_col_left),
                    (stream_col_count >= self.act_num_col + self.pad_col_left),
                    (stream_row_count < self.pad_row_top),
                    (stream_row_count >= self.act_num_row + self.pad_row_top))
                stream_pad_mask = self.m.Wire(
                    self._name('stream_pad_mask_%d_%d' % (y, x)))
                stream_pad_mask.assign(v)
                stream_pad_masks.append(stream_pad_mask)

        stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
        stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d, col_select,
                                        row_select_buf, ksize_col, ksize_row)
        stream_pad_masks = [
            flatten for inner in stream_pad_mask_2d_mux for flatten in inner
        ]

        stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                          len(stream_pad_masks),
                                          initval=0)
        comp_fsm(stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks))))
        comp_fsm.goto_next()

        # busy check
        self.stream.source_join(comp_fsm)

        stream_masks = stream_pad_masks_reg

        # set_constant
        name = list(self.stream.constants.keys())[0]
        self.stream.set_constant(comp_fsm, name, stream_masks)
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_source
        act_page_comp_offset_bufs_dup = []
        for act_page_comp_offset_buf in act_page_comp_offset_bufs:
            act_page_comp_offset_bufs_dup.extend([act_page_comp_offset_buf] *
                                                 ksize_col)

        for name, ram, stream_act_local, act_page_comp_offset_buf in zip(
                self.stream.sources.keys(), act_rams, stream_act_locals,
                act_page_comp_offset_bufs_dup):
            local = stream_act_local + act_page_comp_offset_buf
            self.stream.set_source(comp_fsm, name, ram, local,
                                   self.stream_size)
            comp_fsm.set_index(comp_fsm.current - 1)

        # set_sink
        name = list(self.stream.sinks.keys())[0]
        local = stream_out_local + out_page_comp_offset_buf
        self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

        # stream run (async)
        self.stream.run(comp_fsm)

        # stream_act_local
        stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)

        i = 0
        for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
            for x, stream_act_local in enumerate(stream_act_locals_row):
                patterns = []
                for col in range(ksize_col):
                    val = self.inc_act_laddr_conds[i]
                    i += 1
                    pat = (col_select == col, val)
                    patterns.append(pat)

                patterns.append((None, 0))
                v = vg.PatternMux(*patterns)

                comp_fsm.If(vg.Not(v))(stream_act_local.add(
                    self.inc_act_laddr_small))
                comp_fsm.If(v)(stream_act_local.add(self.inc_act_laddr_large))

                comp_fsm.If(col_count >= self.max_col_count)(
                    stream_act_local(0))
                comp_fsm.If(col_count >= self.max_col_count,
                            self.stream_act_local_small_flags[x])(
                                stream_act_local(
                                    self.stream_act_local_small_offset))
                comp_fsm.If(col_count >= self.max_col_count,
                            self.stream_act_local_large_flags[x])(
                                stream_act_local(
                                    self.stream_act_local_large_offset))

        # stream_out_local
        comp_fsm(stream_out_local.add(self.inc_out_laddr))
        comp_fsm.If(col_count >= self.max_col_count)(stream_out_local(0))

        # counter
        comp_fsm(col_count.add(self.stride_col))
        comp_fsm.If(col_count >= self.max_col_count)(col_count(0))

        comp_fsm(col_select.add(self.stride_col_mod_ksize))
        comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
            col_select.sub(self.ksize_col_minus_stride_col_mod))

        comp_fsm.If(col_count >= self.max_col_count)(col_select(
            self.col_select_initval), )

        # repeat
        comp_fsm.goto(comp_state_rep)
        comp_fsm.If(col_count >= self.max_col_count).goto_init()

        # sync with WriteOut control
        comp_fsm.seq.If(fsm.state == state_init)(comp_count(0))
        comp_fsm.seq.If(self.stream.end_flag)(comp_count.add(
            self.inc_out_laddr))

        # --------------------
        # WriteOut phase
        # --------------------
        state_write_out = fsm.current

        # sync with Comp control
        fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

        out_laddr = out_page_dma_offset
        out_gaddr = self.objaddr + out_offset

        bt.bus_lock(self.maxi, fsm)

        bt.dma_write(self.maxi,
                     fsm,
                     out_ram,
                     out_laddr,
                     out_gaddr,
                     self.out_write_size,
                     port=1,
                     use_async=True)

        bt.bus_unlock(self.maxi, fsm)

        fsm(out_count.add(self.out_write_size))

        fsm.goto_next()

        state_write_out_end = fsm.current
        fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

        # --------------------
        # update for next iteration
        # --------------------
        # ReadAct: offset
        fsm(act_base_offset_row.add(self.act_row_step))
        fsm.If(row_count >= self.max_row_count)(act_base_offset_row(0),
                                                act_base_offset_bat.add(
                                                    self.act_bat_step))
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(act_base_offset_bat(0))

        # ReadAct: DMA flag
        next_dma_flags = []
        for dma_flag, dma_flag_cond in zip(dma_flags, self.dma_flag_conds):
            fsm(dma_flag(dma_flag_cond))

            fsm.If(row_count >= self.max_row_count)(dma_flag(1))

            next_dma_flags.append(
                vg.Mux(row_count >= self.max_row_count, 1, dma_flag_cond))

        # ReadAct: counter
        fsm(row_count.add(self.stride_row))
        fsm.If(row_count >= self.max_row_count)(row_count(0),
                                                bat_count.add(self.stride_bat))
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(bat_count(0))

        fsm.If(self.stride_row < ksize_row)(row_select.add(self.stride_row),
                                            prev_row_select(row_select))
        fsm.If(self.stride_row < ksize_row,
               row_select + self.stride_row >= ksize_row)(
                   row_select(row_select -
                              (vg.Int(ksize_row) - self.stride_row)),
                   prev_row_select(row_select))
        fsm.If(vg.Not(self.stride_row < ksize_row))(row_select(0),
                                                    prev_row_select(0))

        fsm.If(row_count >= self.max_row_count)(row_select(0),
                                                prev_row_select(0))

        # ReadAct and Comp: double buffer
        mux_next_dma_flag_values = mux_1d(next_dma_flags, row_select,
                                          ksize_row)
        mux_next_dma_flags = []
        for i, mux_next_dma_flag_value in enumerate(mux_next_dma_flag_values):
            mux_next_dma_flag = self.m.Wire(
                self._name('mux_next_dma_flag_%d' % i))
            mux_next_dma_flag.assign(mux_next_dma_flag_value)
            mux_next_dma_flags.append(mux_next_dma_flag)

        for (act_page, act_page_comp_offset, act_page_dma_offset,
             mux_next_dma_flag) in zip(act_pages, act_page_comp_offsets,
                                       act_page_dma_offsets,
                                       mux_next_dma_flags):

            fsm.If(vg.Not(act_page),
                   mux_next_dma_flag)(act_page_comp_offset(act_page_size),
                                      act_page_dma_offset(act_page_size),
                                      act_page(1))
            fsm.If(act_page, mux_next_dma_flag)(act_page_comp_offset(0),
                                                act_page_dma_offset(0),
                                                act_page(0))

        # WriteOut: counter
        fsm.If(vg.Not(skip_write_out))(out_base_offset_row.add(
            self.out_row_step))
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count)(
            out_base_offset_row(0), out_base_offset_bat.add(self.out_bat_step))
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count)(out_base_offset_bat(0))

        # WriteOut and Comp: double buffer
        fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                                 out_page_dma_offset(0), out_page(1))
        fsm.If(out_page)(out_page_comp_offset(0),
                         out_page_dma_offset(out_page_size), out_page(0))

        # ReadAct and WriteOut: prev
        fsm(prev_row_count(row_count), prev_bat_count(bat_count))

        # ReadAct, Comp, WriteOut: skip
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(skip_read_act(1))

        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(skip_comp(1))

        fsm.If(skip_write_out, prev_row_count == 0,
               prev_bat_count == 0)(skip_write_out(0))

        fsm.goto(state_read_act)
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count).goto_next()

        # wait for last DMA write
        bt.dma_wait_write(self.maxi, fsm)
Beispiel #4
0
    def control_sequence(self, fsm):
        """Emit the control logic that sequences DMA and stream computation.

        States are appended to ``fsm`` (the main controller), and a second
        FSM, ``comp_fsm``, is created to issue stream runs. Each iteration of
        the main FSM covers one output row and has three phases:

        * ReadAct  -- DMA-read activation rows from ``self.arg_objaddrs[0]``
          into the current page of ``act_ram``. Rows whose pad mask is set are
          skipped.
        * Comp     -- start ``comp_fsm``, which sweeps ``col_count`` across
          the row and calls ``self.stream.run`` once per column step.
        * WriteOut -- once ``comp_count`` shows enough results, DMA-write
          ``self.out_write_size`` words from ``out_ram`` to ``self.objaddr``.
          The write uses ``use_async=True``.

        Both ``act_ram`` and ``out_ram`` are split into two halves (pages),
        so DMA for one iteration can overlap the stream computation of
        another. The output pages are swapped each iteration. As a result, the
        WriteOut of iteration i writes what Comp produced in iteration i-1.
        The loop therefore has one extra trip at each end:

        * the first WriteOut is skipped (``skip_write_out`` starts at 1);
        * after the last row of the last batch, ReadAct and Comp are skipped
          and only the final WriteOut runs.

        When ``self.no_reuse`` is false, ``row_select`` and ``col_select``
        rotate through the kernel window. ``mux_1d`` and ``mux_2d`` reorder
        the per-row addresses and pad masks to match.

        Side effect: sets ``self.stride_bat = 1``.

        Args:
            fsm: the FSM into which the main control states are emitted.
        """
        # Kernel extents, read from the last four entries of self.ksize.
        # NOTE(review): ksize_ch and ksize_bat are not referenced below; the
        # ksize[-4] access still requires len(self.ksize) >= 4.
        ksize_ch = self.ksize[-1]
        ksize_col = self.ksize[-2]
        ksize_row = self.ksize[-3]
        ksize_bat = self.ksize[-4]

        # Batch stride is fixed to 1 here. It is stored on self and read back
        # below when advancing bat_count.
        self.stride_bat = 1

        act_ram = self.input_rams[0]
        out_ram = self.output_rams[0]

        # Global (off-chip) activation offset = row part + batch part. The two
        # parts are kept in separate registers so each wraps independently.
        act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                      self.maxi.addrwidth, signed=True)
        act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                         self.maxi.addrwidth, initval=0, signed=True)
        act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                         self.maxi.addrwidth, initval=0, signed=True)

        act_base_offset.assign(act_base_offset_row
                               + act_base_offset_bat)

        # Global output offset, split into row and batch parts like the
        # activation offset above.
        out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                      self.maxi.addrwidth, signed=True)
        out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                         self.maxi.addrwidth, initval=0, signed=True)
        out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                         self.maxi.addrwidth, initval=0, signed=True)

        out_base_offset.assign(out_base_offset_row
                               + out_base_offset_bat)

        # Loop counters. col_count is driven by comp_fsm; row/bat counts are
        # driven by the main fsm.
        col_count = self.m.Reg(self._name('col_count'),
                               self.maxi.addrwidth, initval=0)
        row_count = self.m.Reg(self._name('row_count'),
                               self.maxi.addrwidth, initval=0)
        bat_count = self.m.Reg(self._name('bat_count'),
                               self.maxi.addrwidth, initval=0)

        # Rotating selectors into the kernel window (reuse mode only). They
        # are wide enough to index ksize_col / ksize_row entries.
        if not self.no_reuse:
            col_select = self.m.Reg(self._name('col_select'),
                                    bt.log_width(ksize_col),
                                    initval=0)
            row_select = self.m.Reg(self._name('row_select'),
                                    bt.log_width(ksize_row),
                                    initval=0)

        # Row/batch counts of the previous iteration. WriteOut lags
        # ReadAct/Comp by one iteration, so it uses these values.
        prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                    self.maxi.addrwidth, initval=0)
        prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                    self.maxi.addrwidth, initval=0)

        if not self.no_reuse:
            prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                         bt.log_width(ksize_row),
                                         initval=0)

        # On-chip (local RAM) addresses used by the stream, relative to the
        # current page offset.
        stream_act_local = self.m.Reg(self._name('stream_act_local'),
                                      self.maxi.addrwidth, initval=0)
        stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                      self.maxi.addrwidth, initval=0)

        # double buffer control
        # *_page selects the half of the RAM in use. *_comp_offset is the
        # base address seen by the stream; *_dma_offset is the base seen by
        # DMA.
        act_page = self.m.Reg(self._name('act_page'), initval=0)
        act_page_comp_offset = self.m.Reg(self._name('act_page_comp_offset'),
                                          self.maxi.addrwidth, initval=0)
        act_page_dma_offset = self.m.Reg(self._name('act_page_dma_offset'),
                                         self.maxi.addrwidth, initval=0)

        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth, initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth, initval=0)

        # Each page is half of the corresponding RAM.
        act_page_size = act_ram.length // 2
        out_page_size = out_ram.length // 2

        # Phase-skip flags. Read/comp are skipped on the final draining
        # iteration; write-out is skipped on the first iteration.
        skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

        # Comp -> WriteOut handshake. comp_count accumulates finished stream
        # outputs; out_count accumulates words already handed to DMA write.
        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth, initval=0)
        out_count = self.m.Reg(self._name('out_count'),
                               self.maxi.addrwidth, initval=0)

        # --------------------
        # initialization phase
        # --------------------
        # Every fsm(...) call up to state_init below is emitted into the same
        # FSM state; goto_next() is only called afterwards.
        # ReadAct: offset
        fsm(
            act_base_offset_row(0),
            act_base_offset_bat(0)
        )

        # One global-offset expression per entry of self.act_offset_values
        # (per-kernel-row offsets relative to act_base_offset).
        act_offsets = []
        for v in self.act_offset_values:
            act_offset = act_base_offset + v
            act_offsets.append(act_offset)

        # ReadAct: DMA flag
        # dma_pad_mask_<y> is high when kernel row y of the current window
        # lies in the top or bottom padding. The DMA read for that row is then
        # skipped.
        dma_pad_masks = []

        for y in range(ksize_row):
            v = vg.Ors((row_count + y < self.pad_row_top),
                       (row_count + y >= self.act_num_row + self.pad_row_top))
            dma_pad_mask = self.m.Wire(
                self._name('dma_pad_mask_%d' % y))
            dma_pad_mask.assign(v)
            dma_pad_masks.append(dma_pad_mask)

        # ReadAct: double buffer control
        fsm(
            act_page(0),
            act_page_comp_offset(0),
            act_page_dma_offset(0)
        )

        # WriteOutput: offset
        fsm(
            out_base_offset_row(0),
            out_base_offset_bat(0)
        )

        out_offset = out_base_offset

        # WriteOut: double buffer control
        fsm(
            out_page(0),
            out_page_comp_offset(0),
            out_page_dma_offset(0)
        )

        # counter
        fsm(
            row_count(0),
            bat_count(0),
            prev_row_count(0),
            prev_bat_count(0)
        )

        if not self.no_reuse:
            fsm(
                row_select(0),
                prev_row_select(0)
            )

        # double buffer control
        # Start with ReadAct/Comp enabled. The first WriteOut is disabled
        # because nothing has been computed yet.
        fsm(
            skip_read_act(0),
            skip_comp(0),
            skip_write_out(1)
        )

        fsm(
            out_count(0)
        )

        # Remembered so comp_fsm can reset comp_count while the main FSM is
        # in this state (see "sync with WriteOut control").
        state_init = fsm.current

        fsm.goto_next()

        # --------------------
        # ReadAct phase
        # --------------------
        # Loop-back target for every iteration.
        state_read_act = fsm.current

        act_gaddrs = []
        for act_offset in act_offsets:
            act_gaddr = self.arg_objaddrs[0] + act_offset
            act_gaddrs.append(act_gaddr)

        if not self.no_reuse:
            # Rotate the per-row global addresses and pad masks by row_select
            # so that each entry lines up with the kernel row it feeds.
            mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
            mux_act_gaddrs = []
            for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
                mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                            self.maxi.addrwidth)
                mux_act_gaddr.assign(mux_act_gaddr_value)
                mux_act_gaddrs.append(mux_act_gaddr)

            mux_dma_pad_mask_values = mux_1d(
                dma_pad_masks, row_select, ksize_row)
            mux_dma_pad_masks = []
            for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
                mux_dma_pad_mask = self.m.Wire(
                    self._name('mux_dma_pad_mask_%d' % i))
                mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
                mux_dma_pad_masks.append(mux_dma_pad_mask)

        else:
            mux_act_gaddrs = act_gaddrs
            mux_dma_pad_masks = dma_pad_masks

        bt.bus_lock(self.maxi, fsm)

        # Rows are packed back-to-back (act_read_size apart) in the current
        # DMA page of act_ram.
        act_laddr = act_page_dma_offset

        for (act_gaddr, dma_pad_mask) in zip(
                mux_act_gaddrs, mux_dma_pad_masks):
            # Each read gets its own run of states. If the row is padding,
            # jump from the first of them straight past the read.
            begin_state_read = fsm.current
            fsm.goto_next()

            bt.dma_read(self.maxi, fsm, act_ram, act_laddr,
                        act_gaddr, self.act_read_size, port=1)

            end_state_read = fsm.current

            fsm.If(dma_pad_mask).goto_from(
                begin_state_read, end_state_read)

            act_laddr += self.act_read_size

        bt.bus_unlock(self.maxi, fsm)

        fsm.goto_next()
        state_read_act_end = fsm.current
        # On the draining iteration, bypass the entire ReadAct phase.
        fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        # Stream Control FSM
        comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

        # Handshake: comp_fsm leaves its init state when the main FSM reaches
        # state_comp (unless skip_comp). The main FSM leaves state_comp only
        # while comp_fsm is in its init state, i.e. idle; comp_fsm returns
        # there via goto_init() after its last column.
        comp_state_init = comp_fsm.current
        comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

        fsm.If(comp_fsm.state == comp_state_init).goto_next()

        # local address
        comp_fsm(
            stream_act_local(self.local_pad_offset)
        )

        comp_fsm(
            stream_out_local(0)
        )

        # count and sel
        comp_fsm(
            col_count(0)
        )

        if not self.no_reuse:
            comp_fsm(
                col_select(self.col_select_initval)
            )

        # Snapshots taken when Comp starts. The main FSM goes on to update
        # the originals (page offsets, row_count, row_select) while this
        # stream computation may still be running.
        act_page_comp_offset_buf = self.m.Reg(self._name('act_page_comp_offset_buf'),
                                              self.maxi.addrwidth, initval=0)
        out_page_comp_offset_buf = self.m.Reg(self._name('out_page_comp_offset_buf'),
                                              self.maxi.addrwidth, initval=0)
        row_count_buf = self.m.Reg(self._name('row_count_buf'),
                                   self.maxi.addrwidth, initval=0)

        if not self.no_reuse:
            row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                        bt.log_width(ksize_row),
                                        initval=0)
        comp_fsm(
            act_page_comp_offset_buf(act_page_comp_offset),
            out_page_comp_offset_buf(out_page_comp_offset),
            row_count_buf(row_count)
        )

        if not self.no_reuse:
            comp_fsm(
                row_select_buf(row_select)
            )

        comp_fsm.goto_next()

        # repeat comp
        # Per-column loop head; comp_fsm comes back here once per column step.
        comp_state_rep = comp_fsm.current

        # pad_mask
        # One bit per kernel tap (ksize_row * ksize_col, row-major). A bit is
        # high when that tap falls outside the unpadded activation in either
        # the column or the row direction.
        stream_pad_masks = []

        for y in range(ksize_row):
            for x in range(ksize_col):
                stream_col_count = col_count + x
                stream_row_count = row_count_buf + y
                v = vg.Ors((stream_col_count < self.pad_col_left),
                           (stream_col_count >= self.act_num_col + self.pad_col_left),
                           (stream_row_count < self.pad_row_top),
                           (stream_row_count >= self.act_num_row + self.pad_row_top))
                stream_pad_mask = self.m.Wire(
                    self._name('stream_pad_mask_%d_%d' % (y, x)))
                stream_pad_mask.assign(v)
                stream_pad_masks.append(stream_pad_mask)

        if not self.no_reuse:
            # Reorder the taps to follow the rotated column/row selectors.
            stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
            stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d,
                                            col_select, row_select_buf,
                                            ksize_col, ksize_row)
            stream_pad_masks = [flatten for inner in stream_pad_mask_2d_mux
                                for flatten in inner]

        # Register the mask bits. Cat(*reversed(...)) puts list entry 0 in
        # bit 0 of the register.
        stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                          len(stream_pad_masks), initval=0)
        comp_fsm(
            stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks)))
        )
        comp_fsm.goto_next()

        # busy check
        self.stream.source_join(comp_fsm)

        stream_masks = stream_pad_masks_reg

        # set_constant
        # Constants are addressed by declaration order: the first gets the
        # tap count and the second gets the pad-mask bits. Each
        # set_index(current - 1) rewinds comp_fsm by one state, presumably so
        # the following set_* call is emitted into the same state (confirm
        # against the stream API).
        name = list(self.stream.constants.keys())[0]
        self.stream.set_constant(comp_fsm, name, ksize_col * ksize_row)
        comp_fsm.set_index(comp_fsm.current - 1)

        name = list(self.stream.constants.keys())[1]
        self.stream.set_constant(comp_fsm, name, stream_masks)
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_source
        # Access pattern of (size, stride) pairs, assuming the
        # set_source_pattern convention:
        # - ksize_col taps spaced act_read_block apart;
        # - ksize_row rows spaced act_read_size apart;
        # - stream_size consecutive elements.
        name = list(self.stream.sources.keys())[0]
        local = stream_act_local + act_page_comp_offset_buf
        pat = ((ksize_col, self.act_read_block),
               (ksize_row, self.act_read_size),
               (self.stream_size, 1))
        self.stream.set_source_pattern(comp_fsm, name, act_ram,
                                       local, pat)
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_sink
        name = list(self.stream.sinks.keys())[0]
        local = stream_out_local + out_page_comp_offset_buf
        self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

        # stream run (async)
        self.stream.run(comp_fsm)

        # The updates below share one comp_fsm state. Each conditional
        # assignment that follows an unconditional one overrides it when its
        # condition holds (the later assignment wins in the generated logic).

        # stream_act_local
        comp_fsm(
            stream_act_local.add(self.inc_act_laddr)
        )
        comp_fsm.If(col_count >= self.max_col_count)(
            stream_act_local(self.local_pad_offset)
        )

        # stream_out_local
        comp_fsm(
            stream_out_local.add(self.inc_out_laddr)
        )
        comp_fsm.If(col_count >= self.max_col_count)(
            stream_out_local(0)
        )

        # counter
        comp_fsm(
            col_count.add(self.stride_col)
        )
        comp_fsm.If(col_count >= self.max_col_count)(
            col_count(0)
        )

        if not self.no_reuse:
            # Advance col_select by (stride_col mod ksize_col). When that would
            # reach ksize_col, subtract (ksize_col - stride mod) instead so the
            # value wraps. Reset on the last column.
            comp_fsm(
                col_select.add(self.stride_col_mod_ksize)
            )
            comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
                col_select.sub(self.ksize_col_minus_stride_col_mod)
            )

            comp_fsm.If(col_count >= self.max_col_count)(
                col_select(self.col_select_initval)
            )

        # repeat
        # Loop over columns; after the last column return to init (idle).
        comp_fsm.goto(comp_state_rep)
        comp_fsm.If(col_count >= self.max_col_count).goto_init()

        # sync with WriteOut control
        # comp_count is cleared while the main FSM is in state_init. It grows
        # by inc_out_laddr each time the stream reports end_flag.
        comp_fsm.seq.If(fsm.state == state_init)(
            comp_count(0)
        )
        comp_fsm.seq.If(self.stream.end_flag)(
            comp_count.add(self.inc_out_laddr)
        )

        # --------------------
        # WriteOut phase
        # --------------------
        state_write_out = fsm.current

        # sync with Comp control
        # Wait until another out_write_size results have been produced.
        fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

        out_laddr = out_page_dma_offset
        out_gaddr = self.objaddr + out_offset

        bt.bus_lock(self.maxi, fsm)

        bt.dma_write(self.maxi, fsm, out_ram, out_laddr,
                     out_gaddr, self.out_write_size, port=1, use_async=True)

        bt.bus_unlock(self.maxi, fsm)

        fsm(
            out_count.add(self.out_write_size)
        )

        fsm.goto_next()

        state_write_out_end = fsm.current
        # On the first iteration, bypass the entire WriteOut phase.
        fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

        # --------------------
        # update for next iteration
        # --------------------
        # All updates below are emitted into state_write_out_end. The guarded
        # assignments that follow an unconditional one override it when their
        # condition holds.
        # ReadAct: offset
        fsm(
            act_base_offset_row.add(self.act_row_step)
        )
        fsm.If(row_count >= self.max_row_count)(
            act_base_offset_row(0),
            act_base_offset_bat.add(self.act_bat_step)
        )
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(
            act_base_offset_bat(0)
        )

        # ReadAct: counter
        fsm(
            row_count.add(self.stride_row)
        )
        fsm.If(row_count >= self.max_row_count)(
            row_count(0),
            bat_count.add(self.stride_bat)
        )
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(
            bat_count(0)
        )

        if not self.no_reuse:
            # Advance row_select by stride_row and wrap it below ksize_row.
            # If stride_row >= ksize_row no window rows are shared, so it stays
            # 0. It is also reset to 0 when moving on to the next batch.
            fsm.If(self.stride_row < ksize_row)(
                row_select.add(self.stride_row),
                prev_row_select(row_select)
            )
            fsm.If(self.stride_row < ksize_row,
                   row_select + self.stride_row >= ksize_row)(
                row_select(row_select - (vg.Int(ksize_row) - self.stride_row)),
                prev_row_select(row_select)
            )
            fsm.If(vg.Not(self.stride_row < ksize_row))(
                row_select(0),
                prev_row_select(0)
            )

            fsm.If(row_count >= self.max_row_count)(
                row_select(0),
                prev_row_select(0)
            )

        # ReadAct and Comp: double buffer
        # ReadAct and Comp of the same iteration use the same act page; the
        # page flips every iteration.
        fsm.If(vg.Not(act_page))(
            act_page_comp_offset(act_page_size),
            act_page_dma_offset(act_page_size),
            act_page(1)
        )
        fsm.If(act_page)(
            act_page_comp_offset(0),
            act_page_dma_offset(0),
            act_page(0)
        )

        # WriteOut: counter
        # The output offset advances using prev_* counts because WriteOut
        # handles the previous iteration's row.
        fsm.If(vg.Not(skip_write_out))(
            out_base_offset_row.add(self.out_row_step)
        )
        fsm.If(vg.Not(skip_write_out),
               prev_row_count >= self.max_row_count)(
            out_base_offset_row(0),
            out_base_offset_bat.add(self.out_bat_step)
        )
        fsm.If(vg.Not(skip_write_out),
               prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count)(
            out_base_offset_bat(0)
        )

        # WriteOut and Comp: double buffer
        # Comp and DMA use opposite out pages: Comp fills one half while the
        # previous results are written out from the other.
        fsm.If(vg.Not(out_page))(
            out_page_comp_offset(out_page_size),
            out_page_dma_offset(0),
            out_page(1)
        )
        fsm.If(out_page)(
            out_page_comp_offset(0),
            out_page_dma_offset(out_page_size),
            out_page(0)
        )

        # ReadAct and WriteOut: prev
        fsm(
            prev_row_count(row_count),
            prev_bat_count(bat_count)
        )

        # ReadAct, Comp, WriteOut: skip
        # After the last row of the last batch has been read and computed,
        # skip ReadAct and Comp on the final (draining) iteration.
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(
            skip_read_act(1)
        )

        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(
            skip_comp(1)
        )

        # Enable WriteOut once the first (write-skipped) iteration is done.
        fsm.If(skip_write_out,
               prev_row_count == 0,
               prev_bat_count == 0)(
            skip_write_out(0)
        )

        # Loop back to ReadAct, except after the WriteOut for the last
        # row/batch, when control falls through.
        fsm.goto(state_read_act)
        fsm.If(vg.Not(skip_write_out),
               prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count).goto_next()

        # wait for last DMA write
        # The writes above were issued with use_async=True, so wait for the
        # final one to finish.
        bt.dma_wait_write(self.maxi, fsm)