def mux_2d(mat, col_select, row_select, col_size, row_size, width=1):
    """Apply a select-driven index shift to a 2D list, first by column then by row.

    For every line of the input, ``size`` outputs are built.  Output ``j`` is a
    chain of ``vg.Mux`` nodes in which ``select == i`` picks
    ``line[(j + size - i) % size]``; the innermost fallback of the chain is
    ``vg.Int(0, width=width)``.

    The first pass uses ``col_select`` / ``col_size`` on ``mat``.  Its flat
    result is regrouped with ``line_to_2d`` and passed through
    ``transpose_2d`` (both defined outside this block), and the second pass
    uses ``row_select`` / ``row_size`` on that result, followed by the same
    regroup-and-transpose step.

    Args:
        mat: Iterable of indexable lines (a 2D list of signal expressions).
        col_select: Select expression compared against each index in pass 1.
        row_select: Select expression compared against each index in pass 2.
        col_size: Number of outputs per line, and the modulus, in pass 1.
        row_size: Number of outputs per line, and the modulus, in pass 2.
        width: Bit width given to the ``vg.Int(0)`` fallback values.

    Returns:
        Whatever ``transpose_2d(line_to_2d(..., row_size))`` yields for the
        second-pass outputs.
    """

    def _shift_lines(lines, select, size):
        # One pass: build `size` mux chains per line, then regroup/transpose.
        flat = []
        for line in lines:
            for j in range(size):
                chain = vg.Int(0, width=width)
                # Walk indices from high to low so index 0 ends up as the
                # outermost Mux of the chain.
                for i in range(len(line) - 1, -1, -1):
                    hit = select == i
                    picked = line[(j + size - i) % size]
                    chain = vg.Mux(hit, picked, chain)
                flat.append(chain)
        return transpose_2d(line_to_2d(flat, size))

    col_shifted = _shift_lines(mat, col_select, col_size)
    return _shift_lines(col_shifted, row_select, row_size)
def mux_1d(line, select, size, width=1):
    """Build ``size`` mux chains that index ``line`` shifted by ``select``.

    Output ``j`` is a chain of ``vg.Mux`` nodes in which ``select == i``
    picks ``line[(j + size - i) % size]`` for every ``i`` in
    ``range(len(line))``.  The fallback at the end of each chain is
    ``vg.Int(0, width=width)``.

    Args:
        line: Indexable sequence of signal expressions.
        select: Select expression compared against each candidate index.
        size: Number of outputs to build; also the modulus of the index.
        width: Bit width given to the ``vg.Int(0)`` fallback value.

    Returns:
        A list with ``size`` mux expressions.
    """
    outputs = []
    for j in range(size):
        chain = vg.Int(0, width=width)
        # Descending order makes index 0 the outermost Mux of the chain.
        for i in range(len(line) - 1, -1, -1):
            hit = select == i
            picked = line[(j + size - i) % size]
            chain = vg.Mux(hit, picked, chain)
        outputs.append(chain)
    return outputs
def control_sequence(self, fsm):
    """Populate ``fsm`` with the read / compute / write control sequence.

    The generated sequence has five parts, emitted in this order:
    initialization, Read (one DMA read per input RAM), Comp (configure and
    run ``self.stream``), Write (asynchronous DMA writes of
    ``self.output_rams[0]``) and an update state that advances addresses,
    flips the page registers and loops back to the Read phase.

    ``skip_write`` is set to 1 during initialization and cleared in the
    update state; ``skip_read`` and ``skip_comp`` are set in the update
    state when ``comp_count == self.num_comp - 1`` is observed.  The output
    page registers keep ``out_page_comp_offset`` and ``out_page_dma_offset``
    on opposite halves of the output RAM.

    The order of the ``fsm`` calls below determines the state numbering, so
    statements must not be reordered.

    Args:
        fsm: FSM builder that receives the states (presumably a veriloggen
            ``FSM`` -- confirm against the caller).

    Returns:
        None.  The method only adds registers to ``self.m`` and states to
        ``fsm``.
    """
    sources = self.collect_sources()

    # Per-argument global address offsets (added to arg_objaddr when reading).
    arg_gaddrs = [
        self.m.Reg(self._name('arg_gaddr_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]
    # Output global address, split into an iteration base (out_gaddr) and an
    # offset that moves inside the Write phase loop (out_gaddr_offset).
    out_gaddr = self.m.Reg(self._name('out_gaddr'),
                           self.maxi.addrwidth, initval=0)
    out_gaddr_offset = self.m.Reg(self._name('out_gaddr_offset'),
                                  self.maxi.addrwidth, initval=0)
    # Counters; note they are one bit wider than the address registers.
    out_pos_col = self.m.Reg(self._name('out_pos_col'),
                             self.maxi.addrwidth + 1, initval=0)
    out_pos_row = self.m.Reg(self._name('out_pos_row'),
                             self.maxi.addrwidth + 1, initval=0)
    out_col_count = self.m.Reg(self._name('out_col_count'),
                               self.maxi.addrwidth + 1, initval=0)
    comp_count = self.m.Reg(self._name('comp_count'),
                            self.maxi.addrwidth + 1, initval=0)
    # One wrap counter per source (the loop variable `arg` is unused here).
    wrap_counts = [
        self.m.Reg(self._name('wrap_count_%d' % i),
                   self.maxi.addrwidth + 1, initval=0)
        for i, arg in enumerate(sources)
    ]

    # Page flag plus compute-side and DMA-side local offsets per argument.
    arg_pages = [
        self.m.Reg(self._name('arg_page_%d' % i), initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]
    arg_page_comp_offsets = [
        self.m.Reg(self._name('arg_page_comp_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]
    arg_page_dma_offsets = [
        self.m.Reg(self._name('arg_page_dma_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]

    # Same three registers for the output RAM.
    out_page = self.m.Reg(self._name('out_page'), initval=0)
    out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    # A page is half of a RAM.
    # NOTE(review): arg_page_size is derived from output_rams[0], not from an
    # input RAM -- confirm this is intended (input and output RAM lengths
    # may be assumed equal here).
    arg_page_size = self.output_rams[0].length // 2
    out_page_size = self.output_rams[0].length // 2

    # Phase-skip flags, tested at the end of each phase below.
    skip_read = self.m.Reg(self._name('skip_read'), initval=0)
    skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
    skip_write = self.m.Reg(self._name('skip_write'), initval=0)

    # --------------------
    # initialization phase
    # --------------------
    # All of the following fsm(...) calls are issued before the first
    # goto_next(), i.e. against the same fsm.current.
    fsm([arg_gaddr(0) for arg_gaddr in arg_gaddrs])

    fsm(comp_count(0),
        out_gaddr(0),
        out_gaddr_offset(0),
        out_pos_col(0),
        out_pos_row(0),
        out_col_count(0),
        [wrap_count(0) for wrap_count in wrap_counts])

    fsm([arg_page(0) for arg_page in arg_pages],
        [
            arg_page_comp_offset(0)
            for arg_page_comp_offset in arg_page_comp_offsets
        ],
        [
            arg_page_dma_offset(0)
            for arg_page_dma_offset in arg_page_dma_offsets
        ])

    # Output starts with compute on offset 0 and DMA on the other half.
    fsm(out_page(0),
        out_page_comp_offset(0),
        out_page_dma_offset(out_page_size))

    # Nothing has been computed yet, so the first Write phase is skipped.
    fsm(skip_read(0), skip_comp(0), skip_write(1))

    fsm.goto_next()

    # --------------------
    # Read phase
    # --------------------
    state_read = fsm.current

    # Overall iteration: DMA read -> Stream run -> Stream wait -> DMA write
    for (ram, arg_objaddr, arg_gaddr, arg_page_dma_offset,
         wrap_mode, wrap_count, arg) in zip(self.input_rams,
                                            self.arg_objaddrs,
                                            arg_gaddrs,
                                            arg_page_dma_offsets,
                                            self.wrap_modes,
                                            wrap_counts,
                                            sources):

        # b: branch state for this argument; by default falls into the
        # normal read that follows.
        b = fsm.current
        fsm.goto_next()

        # normal: read self.dma_size words
        laddr = arg_page_dma_offset
        gaddr = arg_objaddr + arg_gaddr
        bt.bus_lock(self.maxi, fsm)
        bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, self.dma_size)
        bt.bus_unlock(self.maxi, fsm)
        fsm.goto_next()

        # b_stride0: state reached after the normal read, entry of the
        # single-word read.
        b_stride0 = fsm.current
        fsm.goto_next()

        # stride-0: read a single word (size argument is 1)
        bt.bus_lock(self.maxi, fsm)
        bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, 1)
        bt.bus_unlock(self.maxi, fsm)
        fsm.goto_next()

        # for reuse: extra transitions that select which read is executed.
        #   wrap_mode == 2, wrap_count > 0  : b -> e (no read at all)
        #   wrap_mode == 2, wrap_count == 0 : b -> b_stride0 (single word)
        #   wrap_mode != 2                  : b_stride0 -> e (normal only)
        # Presumably these conditional transitions take precedence over the
        # unconditional goto_next() registered on the same states -- confirm
        # against the FSM library.
        e = fsm.current
        fsm.If(wrap_mode == 2, wrap_count > 0).goto_from(b, e)
        fsm.If(wrap_mode == 2, wrap_count == 0).goto_from(b, b_stride0)
        fsm.If(wrap_mode != 2).goto_from(b_stride0, e)

    # Jump over every read when skip_read is set.
    state_read_end = fsm.current
    fsm.If(skip_read).goto_from(state_read, state_read_end)

    # --------------------
    # Comp phase
    # --------------------
    state_comp = fsm.current

    self.stream.source_join(fsm)

    # set_source, set_constant (dup)
    for (source_name, dup_name,
         arg_page_comp_offset,
         ram, wrap_mode) in zip(self.stream.sources.keys(),
                                self.stream.constants.keys(),
                                arg_page_comp_offsets,
                                self.input_rams,
                                self.wrap_modes):
        read_laddr = arg_page_comp_offset
        read_size = self.dma_size
        # wrap_mode == 2 selects stride 0 and dup 1; otherwise stride 1, dup 0.
        stride = vg.Mux(wrap_mode == 2, 0, 1)
        dup = vg.Mux(wrap_mode == 2, 1, 0)
        self.stream.set_constant(fsm, dup_name, dup)
        # Step the state index back by one after each set_* call; presumably
        # so that all settings share one state (assumes each set_* advances
        # fsm by exactly one state -- TODO confirm).
        fsm.set_index(fsm.current - 1)
        self.stream.set_source(fsm, source_name, ram,
                               read_laddr, read_size, stride)
        fsm.set_index(fsm.current - 1)

    # set_sink
    write_laddr = out_page_comp_offset
    write_size = self.dma_size

    for name, ram in zip(self.stream.sinks.keys(), self.output_rams):
        self.stream.set_sink(fsm, name, ram, write_laddr, write_size)
        fsm.set_index(fsm.current - 1)

    fsm.goto_next()

    self.stream.run(fsm)

    state_comp_end = fsm.current

    self.stream.join(fsm)

    state_comp_end_join = fsm.current

    # skip_comp set: bypass source_join/settings/run and land on the join.
    # skip_comp clear: jump from state_comp_end straight past the join.
    fsm.If(skip_comp).goto_from(state_comp, state_comp_end)
    fsm.If(vg.Not(skip_comp)).goto_from(
        state_comp_end, state_comp_end_join)

    # --------------------
    # Write phase
    # --------------------
    state_write = fsm.current

    # The local address stays at out_page_dma_offset for every write issued
    # by the loop below; only the global address moves.
    laddr = out_page_dma_offset
    gaddr_base = self.objaddr + out_gaddr

    bt.bus_lock(self.maxi, fsm)

    # b: head of the write loop.
    b = fsm.current

    gaddr = gaddr_base + out_gaddr_offset
    bt.dma_write(self.maxi, fsm, self.output_rams[0], laddr, gaddr,
                 self.dma_size, use_async=True)

    # Advance the column position; on the last column wrap to the next row.
    fsm(
        out_pos_col.inc(),
        out_gaddr_offset.add(self.out_col_step),
    )
    fsm.If(out_pos_col == self.max_out_pos_col)(
        out_pos_col(0),
        out_pos_row.inc(),
        out_gaddr_offset.add(self.out_row_step))
    # Loop back to b unless both position counters are at their maximum.
    fsm.goto(b)
    fsm.If(out_pos_col == self.max_out_pos_col,
           out_pos_row == self.max_out_pos_row).goto_next()

    bt.bus_unlock(self.maxi, fsm)

    fsm.goto_next()

    # Jump over the whole Write phase when skip_write is set.
    state_write_end = fsm.current
    fsm.If(skip_write).goto_from(state_write, state_write_end)

    # --------------------
    # update for next iteration
    # --------------------
    fsm(comp_count.inc())

    # Reset the Write-phase loop registers.
    fsm(out_gaddr_offset(0),
        out_pos_col(0),
        out_pos_row(0))

    # Advance the output base address only when a write was performed.
    fsm.If(vg.Not(skip_write))(
        out_gaddr.add(self.out_col_inc),
        out_col_count.inc(),
    )
    fsm.If(vg.Not(skip_write), out_col_count == self.max_out_col_count)(
        out_gaddr.add(self.out_row_inc),
        out_col_count(0),
    )

    for (arg_gaddr, arg_addr_inc, arg_page,
         arg_page_comp_offset, arg_page_dma_offset,
         wrap_mode, wrap_size, wrap_count, arg) in zip(arg_gaddrs,
                                                       self.arg_addr_incs,
                                                       arg_pages,
                                                       arg_page_comp_offsets,
                                                       arg_page_dma_offsets,
                                                       self.wrap_modes,
                                                       self.wrap_sizes,
                                                       wrap_counts,
                                                       sources):

        # wrap_mode 2: mark the argument as loaded (wrap_count becomes 1, so
        # later Read phases take the "no read" branch).
        fsm.If(wrap_mode == 2)(wrap_count(1))
        # wrap_mode 1: advance the address and restart after wrap_size steps.
        fsm.If(wrap_mode == 1)(arg_gaddr.add(arg_addr_inc),
                               wrap_count.inc())
        fsm.If(wrap_mode == 1, wrap_count == wrap_size - 1)(arg_gaddr(0),
                                                            wrap_count(0))
        # wrap_mode 0: plain address increment.
        fsm.If(wrap_mode == 0)(arg_gaddr.add(arg_addr_inc))
        # Flip the argument page unless wrap_mode == 2.  Compute and DMA
        # offsets move to the same half.
        # NOTE(review): the DMA offset uses out_page_size while the compute
        # offset uses arg_page_size; both are currently the same value.
        fsm.If(vg.Not(arg_page), wrap_mode != 2)(
            arg_page_comp_offset(arg_page_size),
            arg_page_dma_offset(out_page_size),
            arg_page(1))
        fsm.If(arg_page, wrap_mode != 2)(arg_page_comp_offset(0),
                                         arg_page_dma_offset(0),
                                         arg_page(0))

    # Flip the output page; compute and DMA offsets stay on opposite halves.
    fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                             out_page_dma_offset(0),
                             out_page(1))
    fsm.If(out_page)(out_page_comp_offset(0),
                     out_page_dma_offset(out_page_size),
                     out_page(0))

    # From now on the Write phase is enabled.
    fsm(skip_write(0))
    # Disable Read and Comp once comp_count == num_comp - 1 is seen here.
    fsm.If(comp_count == self.num_comp - 1)(skip_read(1), skip_comp(1))

    # Loop back to the Read phase, or leave when comp_count == num_comp.
    fsm.If(comp_count < self.num_comp).goto(state_read)
    fsm.If(comp_count == self.num_comp).goto_next()

    # wait for last DMA write
    bt.dma_wait_write(self.maxi, fsm)
def control_sequence(self, fsm):
    """Populate ``fsm`` (and a secondary ``comp_fsm``) with the control logic.

    The main ``fsm`` receives, in this order: initialization, ReadAct (DMA
    reads into the input RAMs), a Comp hand-off state, WriteOut (an
    asynchronous DMA write of the output RAM) and an update state that
    advances offsets, counters, select registers and page registers before
    looping back to ReadAct.

    Stream execution itself is driven by a separate ``vg.FSM`` created here
    (``comp_fsm``).  It is started from the main FSM's Comp state, repeats
    one stream run per column step, and reports progress through
    ``comp_count``, which the WriteOut phase compares with ``out_count``.

    The order of the ``fsm`` / ``comp_fsm`` calls determines the state
    numbering, so statements must not be reordered.

    Side effects:
        Sets ``self.stride_bat = 1`` and adds registers/wires to ``self.m``.

    Args:
        fsm: FSM builder that receives the main control states (presumably a
            veriloggen ``FSM`` -- confirm against the caller).

    Returns:
        None.
    """
    # Kernel sizes, read from the tail of self.ksize.
    # ksize_ch and ksize_bat are not used below.
    ksize_ch = self.ksize[-1]
    ksize_col = self.ksize[-2]
    ksize_row = self.ksize[-3]
    ksize_bat = self.ksize[-4]

    self.stride_bat = 1

    act_rams = self.input_rams
    out_ram = self.output_rams[0]

    # Input base offset = row part + batch part (signed).
    act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                  self.maxi.addrwidth, signed=True)
    act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)
    act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)

    act_base_offset.assign(act_base_offset_row + act_base_offset_bat)

    # Output base offset = row part + batch part (signed).
    out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                  self.maxi.addrwidth, signed=True)
    out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)
    out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)

    out_base_offset.assign(out_base_offset_row + out_base_offset_bat)

    # One DMA-enable flag per kernel row.
    dma_flags = [
        self.m.Reg(self._name('dma_flag_%d' % i), initval=0)
        for i in range(ksize_row)
    ]

    # Position counters and the select registers fed to mux_1d / mux_2d.
    col_count = self.m.Reg(self._name('col_count'),
                           self.maxi.addrwidth, initval=0)
    row_count = self.m.Reg(self._name('row_count'),
                           self.maxi.addrwidth, initval=0)
    bat_count = self.m.Reg(self._name('bat_count'),
                           self.maxi.addrwidth, initval=0)
    col_select = self.m.Reg(self._name('col_select'),
                            bt.log_width(ksize_col), initval=0)
    row_select = self.m.Reg(self._name('row_select'),
                            bt.log_width(ksize_row), initval=0)

    # Values of the counters/select as they were one update earlier.
    prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                self.maxi.addrwidth, initval=0)
    prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                self.maxi.addrwidth, initval=0)
    prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                 bt.log_width(ksize_row), initval=0)

    # Stream-side local addresses: one per input RAM, one for the output.
    stream_act_locals = [
        self.m.Reg(self._name('stream_act_local_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(len(act_rams))
    ]
    stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                  self.maxi.addrwidth, initval=0)

    # double buffer control
    act_pages = [
        self.m.Reg(self._name('act_page_%d' % i), initval=0)
        for i in range(ksize_row)
    ]
    act_page_comp_offsets = [
        self.m.Reg(self._name('act_page_comp_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]
    act_page_dma_offsets = [
        self.m.Reg(self._name('act_page_dma_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]

    out_page = self.m.Reg(self._name('out_page'), initval=0)
    out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    # A page is half of a RAM.
    act_page_size = act_rams[0].length // 2
    out_page_size = out_ram.length // 2

    # Phase-skip flags, tested at the end of each phase below.
    skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
    skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
    skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

    # comp_count is advanced by comp_fsm, out_count by the WriteOut phase.
    comp_count = self.m.Reg(self._name('comp_count'),
                            self.maxi.addrwidth, initval=0)
    out_count = self.m.Reg(self._name('out_count'),
                           self.maxi.addrwidth, initval=0)

    # --------------------
    # initialization phase
    # --------------------
    # All fsm(...) calls up to the first goto_next() are issued against the
    # same fsm.current (recorded below as state_init).

    # ReadAct: offset
    fsm(act_base_offset_row(0), act_base_offset_bat(0))

    # Combinational offsets: base offset plus each configured value.
    act_offsets = []
    for v in self.act_offset_values:
        act_offset = act_base_offset + v
        act_offsets.append(act_offset)

    # ReadAct: DMA flag (every row is read in the first iteration)
    for y, dma_flag in enumerate(dma_flags):
        fsm(dma_flag(1))

    # Row y is masked when row_count + y lies above the top padding boundary
    # or at/after act_num_row + pad_row_top.
    dma_pad_masks = []

    for y in range(ksize_row):
        v = vg.Ors((row_count + y < self.pad_row_top),
                   (row_count + y >= self.act_num_row + self.pad_row_top))
        dma_pad_mask = self.m.Wire(self._name('dma_pad_mask_%d' % y))
        dma_pad_mask.assign(v)
        dma_pad_masks.append(dma_pad_mask)

    # ReadAct: double buffer control
    fsm([act_page(0) for act_page in act_pages],
        [
            act_page_comp_offset(0)
            for act_page_comp_offset in act_page_comp_offsets
        ],
        [
            act_page_dma_offset(0)
            for act_page_dma_offset in act_page_dma_offsets
        ])

    # WriteOutput: offset
    fsm(out_base_offset_row(0), out_base_offset_bat(0))

    out_offset = out_base_offset

    # WriteOut: double buffer control
    fsm(out_page(0), out_page_comp_offset(0), out_page_dma_offset(0))

    # counter
    fsm(row_count(0), bat_count(0), row_select(0),
        prev_row_count(0), prev_bat_count(0), prev_row_select(0))

    # double buffer control: the first WriteOut phase is skipped
    fsm(skip_read_act(0), skip_comp(0), skip_write_out(1))

    fsm(out_count(0))

    # Remembered so comp_fsm can clear comp_count while fsm is in this state.
    state_init = fsm.current

    fsm.goto_next()

    # --------------------
    # ReadAct phase
    # --------------------
    state_read_act = fsm.current

    act_gaddrs = []
    for act_offset in act_offsets:
        act_gaddr = self.arg_objaddrs[0] + act_offset
        act_gaddrs.append(act_gaddr)

    # Group the input RAMs into rows of ksize_col entries.
    act_rams_2d = line_to_2d(act_rams, ksize_col)

    # Re-index addresses and pad masks by the current row_select; each
    # mux result is bound to a named wire.
    mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
    mux_act_gaddrs = []
    for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
        mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                    self.maxi.addrwidth)
        mux_act_gaddr.assign(mux_act_gaddr_value)
        mux_act_gaddrs.append(mux_act_gaddr)

    mux_dma_pad_mask_values = mux_1d(dma_pad_masks, row_select, ksize_row)
    mux_dma_pad_masks = []
    for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
        mux_dma_pad_mask = self.m.Wire(
            self._name('mux_dma_pad_mask_%d' % i))
        mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
        mux_dma_pad_masks.append(mux_dma_pad_mask)

    # determined at the previous phase (hence indexed by prev_row_select)
    mux_dma_flag_values = mux_1d(dma_flags, prev_row_select, ksize_row)
    mux_dma_flags = []
    for i, mux_dma_flag_value in enumerate(mux_dma_flag_values):
        mux_dma_flag = self.m.Wire(self._name('mux_dma_flag_%d' % i))
        mux_dma_flag.assign(mux_dma_flag_value)
        mux_dma_flags.append(mux_dma_flag)

    # The bus is locked once around all row reads.
    bt.bus_lock(self.maxi, fsm)

    for (act_rams_row, act_gaddr, act_page_dma_offset,
         dma_pad_mask, dma_flag) in zip(act_rams_2d,
                                        mux_act_gaddrs,
                                        act_page_dma_offsets,
                                        mux_dma_pad_masks,
                                        mux_dma_flags):
        act_laddr = act_page_dma_offset

        begin_state_read = fsm.current
        fsm.goto_next()

        # A single RAM in the row uses a plain read; several RAMs use the
        # block read with self.act_read_block.
        if len(act_rams_row) == 1:
            bt.dma_read(self.maxi, fsm, act_rams_row[0], act_laddr,
                        act_gaddr, self.act_read_size, port=1)
        else:
            bt.dma_read_block(self.maxi, fsm, act_rams_row, act_laddr,
                              act_gaddr, self.act_read_size,
                              self.act_read_block, port=1)

        end_state_read = fsm.current

        # Skip this row's read when it is padding or its DMA flag is clear.
        fsm.If(vg.Ors(dma_pad_mask, vg.Not(dma_flag))).goto_from(
            begin_state_read, end_state_read)

    bt.bus_unlock(self.maxi, fsm)

    fsm.goto_next()

    # Jump over the whole ReadAct phase when skip_read_act is set.
    state_read_act_end = fsm.current
    fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

    # --------------------
    # Comp phase
    # --------------------
    state_comp = fsm.current

    # Stream Control FSM
    comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

    # comp_fsm leaves its initial state when the main FSM is in state_comp
    # and skip_comp is clear; the main FSM advances past state_comp whenever
    # comp_fsm is in its initial state.
    comp_state_init = comp_fsm.current
    comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

    fsm.If(comp_fsm.state == comp_state_init).goto_next()

    # local address: default 0, overridden by the small/large offset when
    # the corresponding per-column flag is set.
    stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)
    for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
        for x, stream_act_local in enumerate(stream_act_locals_row):
            comp_fsm(stream_act_local(0))
            comp_fsm.If(self.stream_act_local_small_flags[x])(
                stream_act_local(self.stream_act_local_small_offset))
            comp_fsm.If(self.stream_act_local_large_flags[x])(
                stream_act_local(self.stream_act_local_large_offset))

    comp_fsm(stream_out_local(0))

    # count and sel
    comp_fsm(col_count(0))
    comp_fsm(col_select(self.col_select_initval))

    # Copies of main-FSM registers taken in comp_fsm's initial state, so the
    # main FSM can update the originals while comp_fsm is still running.
    act_page_comp_offset_bufs = [
        self.m.Reg(self._name('act_page_comp_offset_buf_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]
    out_page_comp_offset_buf = self.m.Reg(
        self._name('out_page_comp_offset_buf'),
        self.maxi.addrwidth, initval=0)
    row_count_buf = self.m.Reg(self._name('row_count_buf'),
                               self.maxi.addrwidth, initval=0)
    row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                bt.log_width(ksize_row), initval=0)
    comp_fsm([
        act_page_comp_offset_buf(act_page_comp_offset)
        for act_page_comp_offset_buf, act_page_comp_offset in zip(
            act_page_comp_offset_bufs, act_page_comp_offsets)
    ],
        out_page_comp_offset_buf(out_page_comp_offset),
        row_count_buf(row_count),
        row_select_buf(row_select))

    comp_fsm.goto_next()

    # repeat comp: head of the per-column loop
    comp_state_rep = comp_fsm.current

    # pad_mask: one wire per kernel position, set when the position falls
    # outside the unpadded column or row range.
    stream_pad_masks = []

    for y in range(ksize_row):
        for x in range(ksize_col):
            stream_col_count = col_count + x
            stream_row_count = row_count_buf + y
            v = vg.Ors(
                (stream_col_count < self.pad_col_left),
                (stream_col_count >= self.act_num_col + self.pad_col_left),
                (stream_row_count < self.pad_row_top),
                (stream_row_count >= self.act_num_row + self.pad_row_top))
            stream_pad_mask = self.m.Wire(
                self._name('stream_pad_mask_%d_%d' % (y, x)))
            stream_pad_mask.assign(v)
            stream_pad_masks.append(stream_pad_mask)

    # Re-index the masks by col_select / row_select_buf, then flatten.
    stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
    stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d,
                                    col_select, row_select_buf,
                                    ksize_col, ksize_row)
    stream_pad_masks = [
        flatten for inner in stream_pad_mask_2d_mux for flatten in inner
    ]

    # Latch all mask bits into one register; the list is reversed before
    # concatenation (presumably so element 0 becomes the LSB -- confirm
    # against vg.Cat's bit ordering).
    stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                      len(stream_pad_masks), initval=0)
    comp_fsm(stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks))))
    comp_fsm.goto_next()

    # busy check
    self.stream.source_join(comp_fsm)

    stream_masks = stream_pad_masks_reg

    # set_constant: only the first constant of the stream is set here
    name = list(self.stream.constants.keys())[0]
    self.stream.set_constant(comp_fsm, name, stream_masks)
    # Step the state index back by one; presumably so the following set_*
    # calls share this state (assumes each set_* advances comp_fsm by
    # exactly one state -- TODO confirm).
    comp_fsm.set_index(comp_fsm.current - 1)

    # set_source: each row's buffered page offset is repeated ksize_col
    # times so it lines up with the flat list of input RAMs.
    act_page_comp_offset_bufs_dup = []
    for act_page_comp_offset_buf in act_page_comp_offset_bufs:
        act_page_comp_offset_bufs_dup.extend(
            [act_page_comp_offset_buf] * ksize_col)

    for name, ram, stream_act_local, act_page_comp_offset_buf in zip(
            self.stream.sources.keys(), act_rams,
            stream_act_locals, act_page_comp_offset_bufs_dup):
        local = stream_act_local + act_page_comp_offset_buf
        self.stream.set_source(comp_fsm, name, ram, local, self.stream_size)
        comp_fsm.set_index(comp_fsm.current - 1)

    # set_sink: only the first sink of the stream is set; no index rewind
    # follows, so comp_fsm stays advanced after this call.
    name = list(self.stream.sinks.keys())[0]
    local = stream_out_local + out_page_comp_offset_buf
    self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

    # stream run (async)
    self.stream.run(comp_fsm)

    # stream_act_local: advance each local address by the small or the large
    # increment, chosen per RAM from self.inc_act_laddr_conds by col_select.
    stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)

    # i walks self.inc_act_laddr_conds linearly: ksize_col entries are
    # consumed for every local address.
    i = 0
    for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
        for x, stream_act_local in enumerate(stream_act_locals_row):
            patterns = []
            for col in range(ksize_col):
                val = self.inc_act_laddr_conds[i]
                i += 1
                pat = (col_select == col, val)
                patterns.append(pat)

            # Default pattern: 0 when no col_select value matches.
            patterns.append((None, 0))
            v = vg.PatternMux(*patterns)

            comp_fsm.If(vg.Not(v))(stream_act_local.add(
                self.inc_act_laddr_small))
            comp_fsm.If(v)(stream_act_local.add(self.inc_act_laddr_large))

            # At the last column, reset to the same values as in the
            # comp_fsm initial state.
            comp_fsm.If(col_count >= self.max_col_count)(
                stream_act_local(0))
            comp_fsm.If(col_count >= self.max_col_count,
                        self.stream_act_local_small_flags[x])(
                stream_act_local(
                    self.stream_act_local_small_offset))
            comp_fsm.If(col_count >= self.max_col_count,
                        self.stream_act_local_large_flags[x])(
                stream_act_local(
                    self.stream_act_local_large_offset))

    # stream_out_local
    comp_fsm(stream_out_local.add(self.inc_out_laddr))
    comp_fsm.If(col_count >= self.max_col_count)(stream_out_local(0))

    # counter
    comp_fsm(col_count.add(self.stride_col))
    comp_fsm.If(col_count >= self.max_col_count)(col_count(0))

    # col_select advances by stride_col_mod_ksize and wraps within
    # [0, ksize_col) by subtracting ksize_col_minus_stride_col_mod.
    comp_fsm(col_select.add(self.stride_col_mod_ksize))
    comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
        col_select.sub(self.ksize_col_minus_stride_col_mod))

    comp_fsm.If(col_count >= self.max_col_count)(col_select(
        self.col_select_initval), )

    # repeat: loop to the column-loop head, or return to the initial state
    # after the last column.
    comp_fsm.goto(comp_state_rep)
    comp_fsm.If(col_count >= self.max_col_count).goto_init()

    # sync with WriteOut control: these two rules are attached to
    # comp_fsm.seq rather than to a comp_fsm state, so they are not tied to
    # the current comp_fsm state index.
    comp_fsm.seq.If(fsm.state == state_init)(comp_count(0))
    comp_fsm.seq.If(self.stream.end_flag)(comp_count.add(
        self.inc_out_laddr))

    # --------------------
    # WriteOut phase
    # --------------------
    state_write_out = fsm.current

    # sync with Comp control: proceed once comp_count has advanced at least
    # out_write_size beyond what has already been written.
    fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

    out_laddr = out_page_dma_offset
    out_gaddr = self.objaddr + out_offset

    bt.bus_lock(self.maxi, fsm)
    bt.dma_write(self.maxi, fsm, out_ram, out_laddr, out_gaddr,
                 self.out_write_size, port=1, use_async=True)
    bt.bus_unlock(self.maxi, fsm)

    fsm(out_count.add(self.out_write_size))
    fsm.goto_next()

    # Jump over the whole WriteOut phase when skip_write_out is set.
    state_write_out_end = fsm.current
    fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

    # --------------------
    # update for next iteration
    # --------------------
    # ReadAct: offset (row step; on the last row move to the next batch; on
    # the last row of the last batch reset the batch offset)
    fsm(act_base_offset_row.add(self.act_row_step))
    fsm.If(row_count >= self.max_row_count)(
        act_base_offset_row(0),
        act_base_offset_bat.add(self.act_bat_step))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(act_base_offset_bat(0))

    # ReadAct: DMA flag.  next_dma_flags mirrors, as combinational values,
    # what is being assigned to the dma_flag registers in this state.
    next_dma_flags = []
    for dma_flag, dma_flag_cond in zip(dma_flags, self.dma_flag_conds):
        fsm(dma_flag(dma_flag_cond))
        fsm.If(row_count >= self.max_row_count)(dma_flag(1))
        next_dma_flags.append(
            vg.Mux(row_count >= self.max_row_count, 1, dma_flag_cond))

    # ReadAct: counter
    fsm(row_count.add(self.stride_row))
    fsm.If(row_count >= self.max_row_count)(row_count(0),
                                            bat_count.add(self.stride_bat))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(bat_count(0))

    # row_select advances by stride_row and wraps within [0, ksize_row); it
    # is forced to 0 when stride_row >= ksize_row or at the last row.
    # prev_row_select captures the value before the update.
    fsm.If(self.stride_row < ksize_row)(row_select.add(self.stride_row),
                                        prev_row_select(row_select))
    fsm.If(self.stride_row < ksize_row,
           row_select + self.stride_row >= ksize_row)(
        row_select(row_select - (vg.Int(ksize_row) - self.stride_row)),
        prev_row_select(row_select))
    fsm.If(vg.Not(self.stride_row < ksize_row))(row_select(0),
                                                prev_row_select(0))
    fsm.If(row_count >= self.max_row_count)(row_select(0),
                                            prev_row_select(0))

    # ReadAct and Comp: double buffer
    mux_next_dma_flag_values = mux_1d(next_dma_flags, row_select, ksize_row)
    mux_next_dma_flags = []
    for i, mux_next_dma_flag_value in enumerate(mux_next_dma_flag_values):
        mux_next_dma_flag = self.m.Wire(
            self._name('mux_next_dma_flag_%d' % i))
        mux_next_dma_flag.assign(mux_next_dma_flag_value)
        mux_next_dma_flags.append(mux_next_dma_flag)

    # A row's page is flipped only when its re-indexed next DMA flag is set.
    # Compute and DMA offsets move to the same half.
    for (act_page, act_page_comp_offset, act_page_dma_offset,
         mux_next_dma_flag) in zip(act_pages,
                                   act_page_comp_offsets,
                                   act_page_dma_offsets,
                                   mux_next_dma_flags):

        fsm.If(vg.Not(act_page), mux_next_dma_flag)(
            act_page_comp_offset(act_page_size),
            act_page_dma_offset(act_page_size),
            act_page(1))
        fsm.If(act_page, mux_next_dma_flag)(act_page_comp_offset(0),
                                            act_page_dma_offset(0),
                                            act_page(0))

    # WriteOut: counter (driven by the prev_* values, i.e. one update behind
    # the ReadAct counters)
    fsm.If(vg.Not(skip_write_out))(out_base_offset_row.add(
        self.out_row_step))
    fsm.If(vg.Not(skip_write_out),
           prev_row_count >= self.max_row_count)(
        out_base_offset_row(0),
        out_base_offset_bat.add(self.out_bat_step))
    fsm.If(vg.Not(skip_write_out),
           prev_row_count >= self.max_row_count,
           prev_bat_count >= self.max_bat_count)(out_base_offset_bat(0))

    # WriteOut and Comp: double buffer (compute and DMA offsets end up on
    # opposite halves after every flip)
    fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                             out_page_dma_offset(0),
                             out_page(1))
    fsm.If(out_page)(out_page_comp_offset(0),
                     out_page_dma_offset(out_page_size),
                     out_page(0))

    # ReadAct and WriteOut: prev
    fsm(prev_row_count(row_count), prev_bat_count(bat_count))

    # ReadAct, Comp, WriteOut: skip
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(skip_read_act(1))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(skip_comp(1))
    fsm.If(skip_write_out,
           prev_row_count == 0,
           prev_bat_count == 0)(skip_write_out(0))

    # Loop back to ReadAct, or leave once a write has been performed for
    # the last row of the last batch.
    fsm.goto(state_read_act)
    fsm.If(vg.Not(skip_write_out),
           prev_row_count >= self.max_row_count,
           prev_bat_count >= self.max_bat_count).goto_next()

    # wait for last DMA write
    bt.dma_wait_write(self.maxi, fsm)
def func(a, b):
    """Return ``vg.Mux(a > b, a, b)``.

    Assuming ``vg.Mux`` yields its second argument when the condition holds
    and its third otherwise, the result is the larger of ``a`` and ``b``.
    """
    a_is_greater = a > b
    return vg.Mux(a_is_greater, a, b)