Ejemplo n.º 1
0
def mux_2d(mat, col_select, row_select, col_size, row_size, width=1):
    """Build mux trees that cyclically shift a 2D grid of signals.

    Two passes are made. The first indexes every line of ``mat`` by
    ``col_select``; the flat result is regrouped with ``line_to_2d`` and
    flipped with ``transpose_2d`` (both defined outside this block), and the
    second pass then indexes those lines by ``row_select``.

    In each pass, output position ``j`` of a line is a chain of ``vg.Mux``
    nodes that yields ``line[(j + size - i) % size]`` for the ``i`` matching
    the select signal, and ``vg.Int(0, width=width)`` if no ``i`` in
    ``range(len(line))`` matches.

    Args:
        mat: iterable of indexable lines of signals.
        col_select: select signal compared against column candidates.
        row_select: select signal compared against row candidates.
        col_size: number of outputs per line (and modulus) in the first pass.
        row_size: number of outputs per line (and modulus) in the second pass.
        width: bit width of the zero default at the tail of each mux chain.

    Returns:
        Whatever ``transpose_2d(line_to_2d(..., row_size))`` produces for the
        second-pass outputs.
    """

    def shift_lines(lines, select, size):
        # Flat list of mux chains: `size` outputs for every input line.
        outputs = []
        for line in lines:
            for out_idx in range(size):
                chain = vg.Int(0, width=width)
                # Walk candidates from the highest index down so that the
                # comparison against 0 ends up as the outermost mux.
                for cand in range(len(line) - 1, -1, -1):
                    chain = vg.Mux(
                        select == cand,
                        line[(out_idx + size - cand) % size],
                        chain)
                outputs.append(chain)
        return outputs

    col_shifted = shift_lines(mat, col_select, col_size)
    by_column = transpose_2d(line_to_2d(col_shifted, col_size))

    row_shifted = shift_lines(by_column, row_select, row_size)
    return transpose_2d(line_to_2d(row_shifted, row_size))
Ejemplo n.º 2
0
Archivo: pool.py Proyecto: shp776/nngen
def mux_1d(line, select, size, width=1):
    """Build ``size`` mux chains that cyclically shift ``line`` by ``select``.

    Output ``j`` yields ``line[(j + size - i) % size]`` for the ``i`` in
    ``range(len(line))`` that equals ``select``; if none matches, it falls
    through to ``vg.Int(0, width=width)``.

    Args:
        line: indexable sequence of signals.
        select: select signal compared against each candidate index.
        size: number of outputs to build, also the modulus of the shift.
        width: bit width of the zero default at the tail of each chain.

    Returns:
        A list with ``size`` mux expressions (empty when ``size`` is 0).
    """
    num_candidates = len(line)
    shifted = []
    for out_idx in range(size):
        chain = vg.Int(0, width=width)
        # Highest candidate first, so the comparison against 0 becomes the
        # outermost mux of the chain.
        for cand in range(num_candidates - 1, -1, -1):
            matches = select == cand
            chain = vg.Mux(matches, line[(out_idx + size - cand) % size],
                           chain)
        shifted.append(chain)

    return shifted
Ejemplo n.º 3
0
    def control_sequence(self, fsm):
        """Append the read / compute / write control sequence to ``fsm``.

        The generated sequence has five parts:

        1. initialization of all counters, page registers and skip flags;
        2. a read phase issuing one DMA read per input RAM;
        3. a compute phase configuring and running ``self.stream``;
        4. a write phase issuing asynchronous DMA writes from
           ``self.output_rams[0]``;
        5. an update phase that advances addresses, toggles pages and either
           loops back to the read phase or falls through to a final
           ``bt.dma_wait_write``.

        ``skip_write`` starts at 1 and ``skip_read`` / ``skip_comp`` are set
        when ``comp_count == self.num_comp - 1``, so the write phase is
        bypassed while ``skip_write`` is set and the read/compute phases are
        bypassed once their skip flags are set.

        Args:
            fsm: the FSM object that states and transitions are added to.
        """
        sources = self.collect_sources()

        # Per-argument global address offset; added to the argument's base
        # address (self.arg_objaddrs) to form the DMA read address.
        arg_gaddrs = [
            self.m.Reg(self._name('arg_gaddr_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i, _ in enumerate(self.arg_objaddrs)
        ]
        # Output global address: out_gaddr advances once per iteration,
        # out_gaddr_offset advances inside the write loop.
        out_gaddr = self.m.Reg(self._name('out_gaddr'),
                               self.maxi.addrwidth,
                               initval=0)
        out_gaddr_offset = self.m.Reg(self._name('out_gaddr_offset'),
                                      self.maxi.addrwidth,
                                      initval=0)
        # Position counters used by the write loop.
        out_pos_col = self.m.Reg(self._name('out_pos_col'),
                                 self.maxi.addrwidth + 1,
                                 initval=0)
        out_pos_row = self.m.Reg(self._name('out_pos_row'),
                                 self.maxi.addrwidth + 1,
                                 initval=0)
        # Counts non-skipped iterations; reset when it equals
        # self.max_out_col_count.
        out_col_count = self.m.Reg(self._name('out_col_count'),
                                   self.maxi.addrwidth + 1,
                                   initval=0)
        # Iteration counter compared against self.num_comp.
        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth + 1,
                                initval=0)
        # One wrap counter per collected source.
        wrap_counts = [
            self.m.Reg(self._name('wrap_count_%d' % i),
                       self.maxi.addrwidth + 1,
                       initval=0) for i, arg in enumerate(sources)
        ]

        # Page flag and local-address offsets per argument. The offsets
        # switch between 0 and one page size in the update phase.
        arg_pages = [
            self.m.Reg(self._name('arg_page_%d' % i), initval=0)
            for i, _ in enumerate(self.arg_objaddrs)
        ]
        arg_page_comp_offsets = [
            self.m.Reg(self._name('arg_page_comp_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i, _ in enumerate(self.arg_objaddrs)
        ]
        arg_page_dma_offsets = [
            self.m.Reg(self._name('arg_page_dma_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i, _ in enumerate(self.arg_objaddrs)
        ]

        # Same page bookkeeping for the output RAM.
        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth,
                                          initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth,
                                         initval=0)

        # A page is half of the RAM.
        # NOTE(review): arg_page_size is derived from output_rams[0], not
        # from the input RAMs -- confirm that all RAMs share one length.
        arg_page_size = self.output_rams[0].length // 2
        out_page_size = self.output_rams[0].length // 2

        # Flags that bypass the corresponding phase when set.
        skip_read = self.m.Reg(self._name('skip_read'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write = self.m.Reg(self._name('skip_write'), initval=0)

        # --------------------
        # initialization phase
        # --------------------
        fsm([arg_gaddr(0) for arg_gaddr in arg_gaddrs])

        fsm(comp_count(0), out_gaddr(0), out_gaddr_offset(0), out_pos_col(0),
            out_pos_row(0), out_col_count(0),
            [wrap_count(0) for wrap_count in wrap_counts])

        fsm([arg_page(0) for arg_page in arg_pages], [
            arg_page_comp_offset(0)
            for arg_page_comp_offset in arg_page_comp_offsets
        ], [
            arg_page_dma_offset(0)
            for arg_page_dma_offset in arg_page_dma_offsets
        ])

        # The output compute offset starts at 0 while the output DMA offset
        # starts one page further, so they point at different halves.
        fsm(out_page(0), out_page_comp_offset(0),
            out_page_dma_offset(out_page_size))

        # Nothing has been computed yet, so the write phase starts skipped.
        fsm(skip_read(0), skip_comp(0), skip_write(1))

        fsm.goto_next()

        # --------------------
        # Read phase
        # --------------------
        state_read = fsm.current

        # DMA read -> Stream run -> Stream wait -> DMA write
        for (ram, arg_objaddr, arg_gaddr, arg_page_dma_offset, wrap_mode,
             wrap_count, arg) in zip(self.input_rams, self.arg_objaddrs,
                                     arg_gaddrs, arg_page_dma_offsets,
                                     self.wrap_modes, wrap_counts, sources):

            # b: entry state of this argument's read; patched below.
            b = fsm.current
            fsm.goto_next()

            # normal
            laddr = arg_page_dma_offset
            gaddr = arg_objaddr + arg_gaddr
            bt.bus_lock(self.maxi, fsm)
            bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, self.dma_size)
            bt.bus_unlock(self.maxi, fsm)
            fsm.goto_next()

            # b_stride0: entry state of the single-word read variant.
            b_stride0 = fsm.current
            fsm.goto_next()

            # stride-0
            bt.bus_lock(self.maxi, fsm)
            bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, 1)
            bt.bus_unlock(self.maxi, fsm)
            fsm.goto_next()

            # for reuse
            # Conditional transitions added after the fact:
            #  - wrap_mode == 2 and wrap_count > 0: b -> e (no read at all)
            #  - wrap_mode == 2 and wrap_count == 0: b -> b_stride0
            #    (only the size-1 read)
            #  - wrap_mode != 2: b_stride0 -> e (only the normal read)
            e = fsm.current
            fsm.If(wrap_mode == 2, wrap_count > 0).goto_from(b, e)
            fsm.If(wrap_mode == 2, wrap_count == 0).goto_from(b, b_stride0)
            fsm.If(wrap_mode != 2).goto_from(b_stride0, e)

        # Bypass the whole read phase when skip_read is set.
        state_read_end = fsm.current
        fsm.If(skip_read).goto_from(state_read, state_read_end)

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        self.stream.source_join(fsm)

        # set_source, set_constant (dup)
        for (source_name, dup_name, arg_page_comp_offset, ram,
             wrap_mode) in zip(self.stream.sources.keys(),
                               self.stream.constants.keys(),
                               arg_page_comp_offsets, self.input_rams,
                               self.wrap_modes):
            read_laddr = arg_page_comp_offset
            read_size = self.dma_size
            # wrap_mode 2 reads with stride 0 and sets the dup constant to 1.
            stride = vg.Mux(wrap_mode == 2, 0, 1)
            dup = vg.Mux(wrap_mode == 2, 1, 0)
            self.stream.set_constant(fsm, dup_name, dup)
            # Step the FSM index back by one after each set_* call --
            # presumably so all of them land in the same state (TODO confirm).
            fsm.set_index(fsm.current - 1)
            self.stream.set_source(fsm, source_name, ram, read_laddr,
                                   read_size, stride)
            fsm.set_index(fsm.current - 1)

        # set_sink
        write_laddr = out_page_comp_offset
        write_size = self.dma_size

        for name, ram in zip(self.stream.sinks.keys(), self.output_rams):
            self.stream.set_sink(fsm, name, ram, write_laddr, write_size)
            fsm.set_index(fsm.current - 1)

        fsm.goto_next()

        self.stream.run(fsm)

        state_comp_end = fsm.current

        self.stream.join(fsm)

        state_comp_end_join = fsm.current

        # With skip_comp set, jump from the start of the phase straight to
        # the join; otherwise jump over the join once the stream is started.
        fsm.If(skip_comp).goto_from(state_comp, state_comp_end)
        fsm.If(vg.Not(skip_comp)).goto_from(state_comp_end,
                                            state_comp_end_join)

        # --------------------
        # Write phase
        # --------------------
        state_write = fsm.current

        laddr = out_page_dma_offset
        gaddr_base = self.objaddr + out_gaddr

        bt.bus_lock(self.maxi, fsm)

        # b: head of the write loop.
        b = fsm.current

        gaddr = gaddr_base + out_gaddr_offset
        bt.dma_write(self.maxi,
                     fsm,
                     self.output_rams[0],
                     laddr,
                     gaddr,
                     self.dma_size,
                     use_async=True)

        # Advance the column position; at the last column reset it and move
        # on to the next row.
        fsm(
            out_pos_col.inc(),
            out_gaddr_offset.add(self.out_col_step),
        )
        fsm.If(out_pos_col == self.max_out_pos_col)(out_pos_col(0),
                                                    out_pos_row.inc(),
                                                    out_gaddr_offset.add(
                                                        self.out_row_step))

        # Repeat from b until both position counters are at their maximum.
        fsm.goto(b)
        fsm.If(out_pos_col == self.max_out_pos_col,
               out_pos_row == self.max_out_pos_row).goto_next()

        bt.bus_unlock(self.maxi, fsm)

        fsm.goto_next()

        # Bypass the whole write phase when skip_write is set.
        state_write_end = fsm.current
        fsm.If(skip_write).goto_from(state_write, state_write_end)

        # --------------------
        # update for next iteration
        # --------------------
        fsm(comp_count.inc())

        fsm(out_gaddr_offset(0), out_pos_col(0), out_pos_row(0))

        # The output address only advances for iterations that did write.
        fsm.If(vg.Not(skip_write))(
            out_gaddr.add(self.out_col_inc),
            out_col_count.inc(),
        )
        fsm.If(vg.Not(skip_write), out_col_count == self.max_out_col_count)(
            out_gaddr.add(self.out_row_inc),
            out_col_count(0),
        )

        for (arg_gaddr, arg_addr_inc, arg_page, arg_page_comp_offset,
             arg_page_dma_offset, wrap_mode, wrap_size, wrap_count,
             arg) in zip(arg_gaddrs, self.arg_addr_incs, arg_pages,
                         arg_page_comp_offsets, arg_page_dma_offsets,
                         self.wrap_modes, self.wrap_sizes, wrap_counts,
                         sources):

            # wrap_mode 2: mark as loaded so later iterations skip the read.
            fsm.If(wrap_mode == 2)(wrap_count(1))

            # wrap_mode 1: advance the address, and restart from 0 when
            # wrap_count equals wrap_size - 1.
            fsm.If(wrap_mode == 1)(arg_gaddr.add(arg_addr_inc),
                                   wrap_count.inc())
            fsm.If(wrap_mode == 1, wrap_count == wrap_size - 1)(arg_gaddr(0),
                                                                wrap_count(0))

            # wrap_mode 0: plain address increment.
            fsm.If(wrap_mode == 0)(arg_gaddr.add(arg_addr_inc))

            # Toggle the argument page unless wrap_mode is 2.
            # NOTE(review): the DMA offset uses out_page_size while the
            # compute offset uses arg_page_size; both are computed from
            # output_rams[0] above, so the values are equal here.
            fsm.If(vg.Not(arg_page),
                   wrap_mode != 2)(arg_page_comp_offset(arg_page_size),
                                   arg_page_dma_offset(out_page_size),
                                   arg_page(1))
            fsm.If(arg_page, wrap_mode != 2)(arg_page_comp_offset(0),
                                             arg_page_dma_offset(0),
                                             arg_page(0))

        # Toggle the output page; compute and DMA offsets always take
        # opposite values.
        fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                                 out_page_dma_offset(0), out_page(1))
        fsm.If(out_page)(out_page_comp_offset(0),
                         out_page_dma_offset(out_page_size), out_page(0))

        fsm(skip_write(0))
        fsm.If(comp_count == self.num_comp - 1)(skip_read(1), skip_comp(1))

        # Loop back to the read phase or leave the loop, depending on
        # comp_count.
        fsm.If(comp_count < self.num_comp).goto(state_read)
        fsm.If(comp_count == self.num_comp).goto_next()

        # wait for last DMA write
        bt.dma_wait_write(self.maxi, fsm)
Ejemplo n.º 4
0
    def control_sequence(self, fsm):
        """Append the ReadAct / Comp / WriteOut control sequence to ``fsm``.

        The operator works on a window whose extents are the last four
        entries of ``self.ksize`` (batch, row, column, channel). The main
        ``fsm`` handles DMA reads of the activation rows, DMA writes of the
        output, and the per-iteration bookkeeping. A second FSM
        (``comp_fsm``), created here, drives ``self.stream`` across the
        columns of one row.

        Side effect: sets ``self.stride_bat`` to 1.

        Args:
            fsm: the main FSM object that states and transitions are added to.
        """
        ksize_ch = self.ksize[-1]
        ksize_col = self.ksize[-2]
        ksize_row = self.ksize[-3]
        ksize_bat = self.ksize[-4]

        self.stride_bat = 1

        act_rams = self.input_rams
        out_ram = self.output_rams[0]

        # Activation base offset = row part + batch part (combinational).
        act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                      self.maxi.addrwidth,
                                      signed=True)
        act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)
        act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)

        act_base_offset.assign(act_base_offset_row + act_base_offset_bat)

        # Output base offset, built the same way.
        out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                      self.maxi.addrwidth,
                                      signed=True)
        out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)
        out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)

        out_base_offset.assign(out_base_offset_row + out_base_offset_bat)

        # One flag per window row; a cleared flag suppresses that row's DMA
        # read in the ReadAct phase.
        dma_flags = [
            self.m.Reg(self._name('dma_flag_%d' % i), initval=0)
            for i in range(ksize_row)
        ]

        # Position counters and the select signals fed to the mux helpers.
        col_count = self.m.Reg(self._name('col_count'),
                               self.maxi.addrwidth,
                               initval=0)
        row_count = self.m.Reg(self._name('row_count'),
                               self.maxi.addrwidth,
                               initval=0)
        bat_count = self.m.Reg(self._name('bat_count'),
                               self.maxi.addrwidth,
                               initval=0)
        col_select = self.m.Reg(self._name('col_select'),
                                bt.log_width(ksize_col),
                                initval=0)
        row_select = self.m.Reg(self._name('row_select'),
                                bt.log_width(ksize_row),
                                initval=0)

        # Copies of the counters from the previous iteration; used by the
        # WriteOut bookkeeping and by the DMA-flag mux.
        prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                    self.maxi.addrwidth,
                                    initval=0)
        prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                    self.maxi.addrwidth,
                                    initval=0)
        prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                     bt.log_width(ksize_row),
                                     initval=0)

        # Local (RAM) addresses used by the stream for each activation RAM
        # and for the output RAM.
        stream_act_locals = [
            self.m.Reg(self._name('stream_act_local_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(len(act_rams))
        ]
        stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                      self.maxi.addrwidth,
                                      initval=0)

        # double buffer control
        act_pages = [
            self.m.Reg(self._name('act_page_%d' % i), initval=0)
            for i in range(ksize_row)
        ]
        act_page_comp_offsets = [
            self.m.Reg(self._name('act_page_comp_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]
        act_page_dma_offsets = [
            self.m.Reg(self._name('act_page_dma_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]

        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth,
                                          initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth,
                                         initval=0)

        # A page is half of the corresponding RAM.
        act_page_size = act_rams[0].length // 2
        out_page_size = out_ram.length // 2

        # Flags that bypass the corresponding phase when set.
        skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

        # comp_count accumulates on stream completion, out_count on DMA
        # write; WriteOut waits until comp_count is far enough ahead.
        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth,
                                initval=0)
        out_count = self.m.Reg(self._name('out_count'),
                               self.maxi.addrwidth,
                               initval=0)

        # --------------------
        # initialization phase
        # --------------------
        # ReadAct: offset
        fsm(act_base_offset_row(0), act_base_offset_bat(0))

        # One address expression per entry of self.act_offset_values.
        act_offsets = []
        for v in self.act_offset_values:
            act_offset = act_base_offset + v
            act_offsets.append(act_offset)

        # ReadAct: DMA flag
        for y, dma_flag in enumerate(dma_flags):
            fsm(dma_flag(1))

        dma_pad_masks = []

        # A window row is masked when row_count + y lies above the top
        # padding boundary or at/after act_num_row + pad_row_top.
        for y in range(ksize_row):
            v = vg.Ors((row_count + y < self.pad_row_top),
                       (row_count + y >= self.act_num_row + self.pad_row_top))
            dma_pad_mask = self.m.Wire(self._name('dma_pad_mask_%d' % y))
            dma_pad_mask.assign(v)
            dma_pad_masks.append(dma_pad_mask)

        # ReadAct: double buffer control
        fsm([act_page(0) for act_page in act_pages], [
            act_page_comp_offset(0)
            for act_page_comp_offset in act_page_comp_offsets
        ], [
            act_page_dma_offset(0)
            for act_page_dma_offset in act_page_dma_offsets
        ])

        # WriteOutput: offset
        fsm(out_base_offset_row(0), out_base_offset_bat(0))

        out_offset = out_base_offset

        # WriteOut: double buffer control
        fsm(out_page(0), out_page_comp_offset(0), out_page_dma_offset(0))

        # counter
        fsm(row_count(0), bat_count(0), row_select(0), prev_row_count(0),
            prev_bat_count(0), prev_row_select(0))

        # double buffer control
        # WriteOut starts skipped: nothing has been computed yet.
        fsm(skip_read_act(0), skip_comp(0), skip_write_out(1))

        fsm(out_count(0))

        # Remembered so comp_count can be reset while fsm is in this state.
        state_init = fsm.current

        fsm.goto_next()

        # --------------------
        # ReadAct phase
        # --------------------
        state_read_act = fsm.current

        act_gaddrs = []
        for act_offset in act_offsets:
            act_gaddr = self.arg_objaddrs[0] + act_offset
            act_gaddrs.append(act_gaddr)

        # Group the RAMs into rows of ksize_col and reorder the per-row
        # addresses / masks / flags with mux_1d (a helper defined outside
        # this method), each result exposed as a named wire.
        act_rams_2d = line_to_2d(act_rams, ksize_col)
        mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
        mux_act_gaddrs = []
        for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
            mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                        self.maxi.addrwidth)
            mux_act_gaddr.assign(mux_act_gaddr_value)
            mux_act_gaddrs.append(mux_act_gaddr)

        mux_dma_pad_mask_values = mux_1d(dma_pad_masks, row_select, ksize_row)
        mux_dma_pad_masks = []
        for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
            mux_dma_pad_mask = self.m.Wire(
                self._name('mux_dma_pad_mask_%d' % i))
            mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
            mux_dma_pad_masks.append(mux_dma_pad_mask)

        # determined at the previous phase
        mux_dma_flag_values = mux_1d(dma_flags, prev_row_select, ksize_row)
        mux_dma_flags = []
        for i, mux_dma_flag_value in enumerate(mux_dma_flag_values):
            mux_dma_flag = self.m.Wire(self._name('mux_dma_flag_%d' % i))
            mux_dma_flag.assign(mux_dma_flag_value)
            mux_dma_flags.append(mux_dma_flag)

        bt.bus_lock(self.maxi, fsm)

        for (act_rams_row, act_gaddr, act_page_dma_offset, dma_pad_mask,
             dma_flag) in zip(act_rams_2d, mux_act_gaddrs,
                              act_page_dma_offsets, mux_dma_pad_masks,
                              mux_dma_flags):
            act_laddr = act_page_dma_offset

            begin_state_read = fsm.current
            fsm.goto_next()

            # A single RAM in the row uses a plain DMA read; several RAMs
            # use the block variant with self.act_read_block.
            if len(act_rams_row) == 1:
                bt.dma_read(self.maxi,
                            fsm,
                            act_rams_row[0],
                            act_laddr,
                            act_gaddr,
                            self.act_read_size,
                            port=1)
            else:
                bt.dma_read_block(self.maxi,
                                  fsm,
                                  act_rams_row,
                                  act_laddr,
                                  act_gaddr,
                                  self.act_read_size,
                                  self.act_read_block,
                                  port=1)

            end_state_read = fsm.current

            # Jump over this row's read when it is padding or its DMA flag
            # is cleared.
            fsm.If(vg.Ors(dma_pad_mask,
                          vg.Not(dma_flag))).goto_from(begin_state_read,
                                                       end_state_read)

        bt.bus_unlock(self.maxi, fsm)

        fsm.goto_next()
        # Bypass the whole ReadAct phase when skip_read_act is set.
        state_read_act_end = fsm.current
        fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        # Stream Control FSM
        comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

        # Handshake: comp_fsm leaves its init state when the main fsm is in
        # state_comp and skip_comp is clear; the main fsm advances when
        # comp_fsm is in its init state.
        comp_state_init = comp_fsm.current
        comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

        fsm.If(comp_fsm.state == comp_state_init).goto_next()

        # local address
        # Start at 0, or at the small/large offset when the flag for that
        # column (index x) is set.
        stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)
        for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
            for x, stream_act_local in enumerate(stream_act_locals_row):
                comp_fsm(stream_act_local(0))
                comp_fsm.If(self.stream_act_local_small_flags[x])(
                    stream_act_local(self.stream_act_local_small_offset))
                comp_fsm.If(self.stream_act_local_large_flags[x])(
                    stream_act_local(self.stream_act_local_large_offset))

        comp_fsm(stream_out_local(0))

        # count and sel
        comp_fsm(col_count(0))
        comp_fsm(col_select(self.col_select_initval))

        # Copies of main-FSM registers taken in comp_fsm's init state --
        # presumably so the main fsm can update the originals while
        # comp_fsm is still running (TODO confirm).
        act_page_comp_offset_bufs = [
            self.m.Reg(self._name('act_page_comp_offset_buf_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]
        out_page_comp_offset_buf = self.m.Reg(
            self._name('out_page_comp_offset_buf'),
            self.maxi.addrwidth,
            initval=0)
        row_count_buf = self.m.Reg(self._name('row_count_buf'),
                                   self.maxi.addrwidth,
                                   initval=0)
        row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                    bt.log_width(ksize_row),
                                    initval=0)
        comp_fsm([
            act_page_comp_offset_buf(act_page_comp_offset)
            for act_page_comp_offset_buf, act_page_comp_offset in zip(
                act_page_comp_offset_bufs, act_page_comp_offsets)
        ], out_page_comp_offset_buf(out_page_comp_offset),
                 row_count_buf(row_count), row_select_buf(row_select))

        comp_fsm.goto_next()

        # repeat comp
        comp_state_rep = comp_fsm.current

        # pad_mask
        stream_pad_masks = []

        # One mask per window element: set when the element's column or row
        # position falls in the padding region.
        for y in range(ksize_row):
            for x in range(ksize_col):
                stream_col_count = col_count + x
                stream_row_count = row_count_buf + y
                v = vg.Ors(
                    (stream_col_count < self.pad_col_left),
                    (stream_col_count >= self.act_num_col + self.pad_col_left),
                    (stream_row_count < self.pad_row_top),
                    (stream_row_count >= self.act_num_row + self.pad_row_top))
                stream_pad_mask = self.m.Wire(
                    self._name('stream_pad_mask_%d_%d' % (y, x)))
                stream_pad_mask.assign(v)
                stream_pad_masks.append(stream_pad_mask)

        # Reorder the masks with mux_2d (defined outside this method) using
        # the current column/row selects, then flatten back to a list.
        stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
        stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d, col_select,
                                        row_select_buf, ksize_col, ksize_row)
        stream_pad_masks = [
            flatten for inner in stream_pad_mask_2d_mux for flatten in inner
        ]

        # Register the concatenated masks; the list is reversed so that the
        # first mask is the last operand of vg.Cat.
        stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                          len(stream_pad_masks),
                                          initval=0)
        comp_fsm(stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks))))
        comp_fsm.goto_next()

        # busy check
        self.stream.source_join(comp_fsm)

        stream_masks = stream_pad_masks_reg

        # set_constant
        name = list(self.stream.constants.keys())[0]
        self.stream.set_constant(comp_fsm, name, stream_masks)
        # Step the FSM index back by one after each set_* call -- presumably
        # so all of them land in the same state (TODO confirm).
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_source
        # Repeat each row's page offset ksize_col times so it lines up with
        # the flat list of activation RAMs.
        act_page_comp_offset_bufs_dup = []
        for act_page_comp_offset_buf in act_page_comp_offset_bufs:
            act_page_comp_offset_bufs_dup.extend([act_page_comp_offset_buf] *
                                                 ksize_col)

        for name, ram, stream_act_local, act_page_comp_offset_buf in zip(
                self.stream.sources.keys(), act_rams, stream_act_locals,
                act_page_comp_offset_bufs_dup):
            local = stream_act_local + act_page_comp_offset_buf
            self.stream.set_source(comp_fsm, name, ram, local,
                                   self.stream_size)
            comp_fsm.set_index(comp_fsm.current - 1)

        # set_sink
        # Unlike the calls above, no index rewind follows set_sink.
        name = list(self.stream.sinks.keys())[0]
        local = stream_out_local + out_page_comp_offset_buf
        self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

        # stream run (async)
        self.stream.run(comp_fsm)

        # stream_act_local
        stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)

        # i walks self.inc_act_laddr_conds, consuming ksize_col entries per
        # local-address register.
        i = 0
        for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
            for x, stream_act_local in enumerate(stream_act_locals_row):
                # Pick the condition that belongs to the current col_select;
                # 0 when no pattern matches.
                patterns = []
                for col in range(ksize_col):
                    val = self.inc_act_laddr_conds[i]
                    i += 1
                    pat = (col_select == col, val)
                    patterns.append(pat)

                patterns.append((None, 0))
                v = vg.PatternMux(*patterns)

                # Small or large address step depending on v.
                comp_fsm.If(vg.Not(v))(stream_act_local.add(
                    self.inc_act_laddr_small))
                comp_fsm.If(v)(stream_act_local.add(self.inc_act_laddr_large))

                # At the end of the row, reload the same start values as in
                # comp_fsm's init state.
                comp_fsm.If(col_count >= self.max_col_count)(
                    stream_act_local(0))
                comp_fsm.If(col_count >= self.max_col_count,
                            self.stream_act_local_small_flags[x])(
                                stream_act_local(
                                    self.stream_act_local_small_offset))
                comp_fsm.If(col_count >= self.max_col_count,
                            self.stream_act_local_large_flags[x])(
                                stream_act_local(
                                    self.stream_act_local_large_offset))

        # stream_out_local
        comp_fsm(stream_out_local.add(self.inc_out_laddr))
        comp_fsm.If(col_count >= self.max_col_count)(stream_out_local(0))

        # counter
        comp_fsm(col_count.add(self.stride_col))
        comp_fsm.If(col_count >= self.max_col_count)(col_count(0))

        # Advance col_select by the stride, wrapping around at ksize_col.
        comp_fsm(col_select.add(self.stride_col_mod_ksize))
        comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
            col_select.sub(self.ksize_col_minus_stride_col_mod))

        comp_fsm.If(col_count >= self.max_col_count)(col_select(
            self.col_select_initval), )

        # repeat
        # Loop back to comp_state_rep, or return to the init state once the
        # last column has been issued.
        comp_fsm.goto(comp_state_rep)
        comp_fsm.If(col_count >= self.max_col_count).goto_init()

        # sync with WriteOut control
        # Added through comp_fsm.seq, i.e. not tied to one comp_fsm state.
        comp_fsm.seq.If(fsm.state == state_init)(comp_count(0))
        comp_fsm.seq.If(self.stream.end_flag)(comp_count.add(
            self.inc_out_laddr))

        # --------------------
        # WriteOut phase
        # --------------------
        state_write_out = fsm.current

        # sync with Comp control
        # Wait until the stream has produced enough output for one write.
        fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

        out_laddr = out_page_dma_offset
        out_gaddr = self.objaddr + out_offset

        bt.bus_lock(self.maxi, fsm)

        bt.dma_write(self.maxi,
                     fsm,
                     out_ram,
                     out_laddr,
                     out_gaddr,
                     self.out_write_size,
                     port=1,
                     use_async=True)

        bt.bus_unlock(self.maxi, fsm)

        fsm(out_count.add(self.out_write_size))

        fsm.goto_next()

        # Bypass the whole WriteOut phase when skip_write_out is set.
        state_write_out_end = fsm.current
        fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

        # --------------------
        # update for next iteration
        # --------------------
        # ReadAct: offset
        # Step to the next row; at the last row restart the row offset and
        # step the batch offset; at the last batch restart that as well.
        fsm(act_base_offset_row.add(self.act_row_step))
        fsm.If(row_count >= self.max_row_count)(act_base_offset_row(0),
                                                act_base_offset_bat.add(
                                                    self.act_bat_step))
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(act_base_offset_bat(0))

        # ReadAct: DMA flag
        # next_dma_flags holds, as expressions, the values being written to
        # dma_flags here; they drive the page toggling further down.
        next_dma_flags = []
        for dma_flag, dma_flag_cond in zip(dma_flags, self.dma_flag_conds):
            fsm(dma_flag(dma_flag_cond))

            fsm.If(row_count >= self.max_row_count)(dma_flag(1))

            next_dma_flags.append(
                vg.Mux(row_count >= self.max_row_count, 1, dma_flag_cond))

        # ReadAct: counter
        fsm(row_count.add(self.stride_row))
        fsm.If(row_count >= self.max_row_count)(row_count(0),
                                                bat_count.add(self.stride_bat))
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(bat_count(0))

        # row_select advances by stride_row modulo ksize_row when the stride
        # is smaller than the window; otherwise it stays at 0.
        # prev_row_select takes the value row_select had before the update.
        fsm.If(self.stride_row < ksize_row)(row_select.add(self.stride_row),
                                            prev_row_select(row_select))
        fsm.If(self.stride_row < ksize_row,
               row_select + self.stride_row >= ksize_row)(
                   row_select(row_select -
                              (vg.Int(ksize_row) - self.stride_row)),
                   prev_row_select(row_select))
        fsm.If(vg.Not(self.stride_row < ksize_row))(row_select(0),
                                                    prev_row_select(0))

        fsm.If(row_count >= self.max_row_count)(row_select(0),
                                                prev_row_select(0))

        # ReadAct and Comp: double buffer
        mux_next_dma_flag_values = mux_1d(next_dma_flags, row_select,
                                          ksize_row)
        mux_next_dma_flags = []
        for i, mux_next_dma_flag_value in enumerate(mux_next_dma_flag_values):
            mux_next_dma_flag = self.m.Wire(
                self._name('mux_next_dma_flag_%d' % i))
            mux_next_dma_flag.assign(mux_next_dma_flag_value)
            mux_next_dma_flags.append(mux_next_dma_flag)

        # Toggle a row's page only when its muxed next DMA flag is set.
        for (act_page, act_page_comp_offset, act_page_dma_offset,
             mux_next_dma_flag) in zip(act_pages, act_page_comp_offsets,
                                       act_page_dma_offsets,
                                       mux_next_dma_flags):

            fsm.If(vg.Not(act_page),
                   mux_next_dma_flag)(act_page_comp_offset(act_page_size),
                                      act_page_dma_offset(act_page_size),
                                      act_page(1))
            fsm.If(act_page, mux_next_dma_flag)(act_page_comp_offset(0),
                                                act_page_dma_offset(0),
                                                act_page(0))

        # WriteOut: counter
        # Driven by the prev_* counters and only while writes are not
        # skipped.
        fsm.If(vg.Not(skip_write_out))(out_base_offset_row.add(
            self.out_row_step))
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count)(
            out_base_offset_row(0), out_base_offset_bat.add(self.out_bat_step))
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count)(out_base_offset_bat(0))

        # WriteOut and Comp: double buffer
        # Compute and DMA offsets take opposite values after every toggle.
        fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                                 out_page_dma_offset(0), out_page(1))
        fsm.If(out_page)(out_page_comp_offset(0),
                         out_page_dma_offset(out_page_size), out_page(0))

        # ReadAct and WriteOut: prev
        fsm(prev_row_count(row_count), prev_bat_count(bat_count))

        # ReadAct, Comp, WriteOut: skip
        # After the last row of the last batch, stop reading and computing.
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(skip_read_act(1))

        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(skip_comp(1))

        # Enable WriteOut once, while the prev_* counters are still 0.
        fsm.If(skip_write_out, prev_row_count == 0,
               prev_bat_count == 0)(skip_write_out(0))

        # Loop back to ReadAct; leave the loop when writes are enabled and
        # the prev_* counters have reached their maxima.
        fsm.goto(state_read_act)
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count).goto_next()

        # wait for last DMA write
        bt.dma_wait_write(self.maxi, fsm)
Ejemplo n.º 5
0
 def func(a, b):
     """Build a ``vg.Mux`` that picks ``a`` when ``a > b`` and ``b`` otherwise."""
     a_is_greater = a > b
     return vg.Mux(a_is_greater, a, b)