def fsm(self, name='fsm', clock_domain=None):
    """Create a veriloggen FSM attached to this module and register it.

    Args:
        name: Name passed to ``veriloggen.FSM``.
        clock_domain: Object providing ``clock`` and ``reset`` attributes.
            When ``None``, ``self.clock_domain`` is used instead.

    Returns:
        The newly created ``veriloggen.FSM`` object.

    Raises:
        ValueError: If neither the argument nor ``self.clock_domain``
            supplies a clock domain.
    """
    domain = self.clock_domain if clock_domain is None else clock_domain
    if domain is None:
        raise ValueError('This Module has no clock domain.')

    machine = veriloggen.FSM(self, name, domain.clock, domain.reset)
    self._fsms.append(machine)

    # make_always is registered as a hook so that it is called when the
    # module is converted into verilog
    self.add_hook(machine.make_always)
    return machine
def control_sequence(self, fsm):
    """Build the control FSM for the ReadAct -> Comp -> WriteOut pipeline.

    The method only *describes* hardware: it declares registers/wires on
    ``self.m`` and appends statements and transitions to ``fsm`` (and to a
    second FSM, ``comp_fsm``, created here to drive ``self.stream``).
    The order of the calls below defines the order of the generated states,
    so statements must not be reordered casually.

    Phases, in generated-state order:

    1. initialization: reset offsets, page selectors, counters, skip flags.
    2. ReadAct: DMA read from ``self.arg_objaddrs[0] + act_base_offset``
       into ``self.input_rams[0]``.
    3. Comp: hand over to ``comp_fsm``, which programs the stream source
       and sink patterns and starts the stream.
    4. WriteOut: DMA write from ``self.output_rams[0]`` to
       ``self.objaddr + out_base_offset``.
    5. update: advance counters/offsets, flip the RAM pages, update the
       skip flags, then loop back to ReadAct or leave the loop.

    Both RAMs are addressed as two halves (``length // 2``) selected by the
    ``act_page`` / ``out_page`` registers.

    Args:
        fsm: The main control FSM to which states are appended.

    Returns:
        None. All effects are on ``fsm``, ``self.m`` and ``self.stream``.
    """
    act_ram = self.input_rams[0]
    out_ram = self.output_rams[0]

    # Input base offset: sum of one signed offset register per leading
    # (all but the last two) dimension of act_shape.
    act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                  self.maxi.addrwidth, signed=True)

    act_offsets = [self.m.Reg(self._name('act_offset_%d' % i),
                              self.maxi.addrwidth, initval=0, signed=True)
                   for i, _ in enumerate(self.act_shape[:-2])]

    if act_offsets:
        v = act_offsets[0]
        for act_offset in act_offsets[1:]:
            v += act_offset
        act_base_offset.assign(v)
    else:
        act_base_offset.assign(0)

    # Output base offset, built the same way from out_shape.
    out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                  self.maxi.addrwidth, signed=True)

    out_offsets = [self.m.Reg(self._name('out_offset_%d' % i),
                              self.maxi.addrwidth, initval=0, signed=True)
                   for i, _ in enumerate(self.out_shape[:-2])]

    if out_offsets:
        v = out_offsets[0]
        for out_offset in out_offsets[1:]:
            v += out_offset
        out_base_offset.assign(v)
    else:
        out_base_offset.assign(0)

    # Loop counters and their one-iteration-delayed copies.
    counts = [self.m.Reg(self._name('count_%d' % i),
                         self.maxi.addrwidth, initval=0)
              for i, _ in enumerate(self.act_shape[:-2])]

    prev_counts = [self.m.Reg(self._name('prev_count_%d' % i),
                              self.maxi.addrwidth, initval=0)
                   for i, _ in enumerate(self.act_shape[:-2])]

    stream_act_local = self.m.Reg(self._name('stream_act_local'),
                                  self.maxi.addrwidth, initval=0)
    stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                  self.maxi.addrwidth, initval=0)

    # comp_count is incremented by comp_fsm, out_count by the main FSM;
    # comparing them synchronizes WriteOut with Comp.
    comp_count = self.m.Reg(self._name('comp_count'),
                            self.maxi.addrwidth, initval=0)
    out_count = self.m.Reg(self._name('out_count'),
                           self.maxi.addrwidth, initval=0)

    # Page selectors and the local-address offsets derived from them.
    act_page = self.m.Reg(self._name('act_page'), initval=0)
    act_page_comp_offset = self.m.Reg(self._name('act_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    act_page_dma_offset = self.m.Reg(self._name('act_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    out_page = self.m.Reg(self._name('out_page'), initval=0)
    out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    act_page_size = act_ram.length // 2
    out_page_size = out_ram.length // 2

    skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
    skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
    skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

    # --------------------
    # initialization phase
    # --------------------
    # ReadAct: offset
    for act_offset, act_offset_begin in zip(act_offsets,
                                            self.act_offset_begins):
        fsm(
            act_offset(act_offset_begin)
        )

    # ReadAct: double buffer control
    fsm(
        act_page(0),
        act_page_comp_offset(0),
        act_page_dma_offset(0)
    )

    # WriteOutput: offset
    for out_offset in out_offsets:
        fsm(
            out_offset(0)
        )

    out_offset = out_base_offset

    # WriteOutput: double buffer control
    fsm(
        out_page(0),
        out_page_comp_offset(0),
        out_page_dma_offset(0)
    )

    # counter
    fsm(
        [count(0) for count in counts],
        [prev_count(0) for prev_count in prev_counts]
    )

    # double buffer control: the first iteration has nothing to write out
    fsm(
        skip_read_act(0),
        skip_comp(0),
        skip_write_out(1)
    )

    fsm(
        out_count(0)
    )

    state_init = fsm.current
    fsm.goto_next()

    # --------------------
    # ReadAct phase
    # --------------------
    state_read_act = fsm.current

    act_gaddr = self.arg_objaddrs[0] + act_base_offset

    bt.bus_lock(self.maxi, fsm)

    act_laddr = act_page_dma_offset

    fsm.goto_next()

    bt.dma_read(self.maxi, fsm, act_ram, act_laddr,
                act_gaddr, self.act_read_size, port=1)

    # Release the bus taken by bus_lock above and honour skip_read_act.
    # Previously the lock was never paired with an unlock in this phase and
    # skip_read_act was written but never read, so the read was issued
    # again after the last tile had already been fetched. This mirrors the
    # ReadAct epilogue of the other control_sequence implementations.
    bt.bus_unlock(self.maxi, fsm)
    fsm.goto_next()
    state_read_act_end = fsm.current

    fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

    # --------------------
    # Comp phase
    # --------------------
    state_comp = fsm.current

    # Stream Control FSM
    comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

    comp_state_init = comp_fsm.current
    # comp_fsm starts when the main FSM reaches state_comp (unless skipped);
    # the main FSM proceeds only while comp_fsm is in its initial state.
    comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

    fsm.If(comp_fsm.state == comp_state_init).goto_next()

    # local address
    comp_fsm(
        stream_act_local(self.stream_local),
        stream_out_local(0)
    )

    # Snapshot the page offsets so the main FSM can update them while the
    # stream is still being set up.
    act_page_comp_offset_buf = self.m.Reg(
        self._name('act_page_comp_offset_buf'),
        self.maxi.addrwidth, initval=0)
    out_page_comp_offset_buf = self.m.Reg(
        self._name('out_page_comp_offset_buf'),
        self.maxi.addrwidth, initval=0)

    comp_fsm(
        act_page_comp_offset_buf(act_page_comp_offset),
        out_page_comp_offset_buf(out_page_comp_offset)
    )

    comp_fsm.goto_next()

    # busy check
    self.stream.source_join(comp_fsm)

    # set_source
    name = list(self.stream.sources.keys())[0]
    local = stream_act_local + act_page_comp_offset_buf

    if len(self.out_shape) > 1:
        pat = ((self.stream_size, self.act_strides[-1]),
               (self.out_shape[-2], self.stream_stride))
    else:
        pat = ((self.stream_size, self.act_strides[-1]),)

    self.stream.set_source_pattern(comp_fsm, name, act_ram, local, pat)
    # step the state index back by one before programming the sink
    comp_fsm.set_index(comp_fsm.current - 1)

    # set_sink
    name = list(self.stream.sinks.keys())[0]
    local = stream_out_local + out_page_comp_offset_buf

    if len(self.out_shape) > 1:
        pat = ((self.stream_size, 1),
               (self.out_shape[-2], self.stream_size))
    else:
        pat = ((self.stream_size, 1),)

    self.stream.set_sink_pattern(comp_fsm, name, out_ram, local, pat)

    # stream run (async)
    self.stream.run(comp_fsm)

    comp_fsm.goto_init()

    # sync with WriteOut control
    comp_fsm.seq.If(fsm.state == state_init)(
        comp_count(0)
    )
    comp_fsm.seq.If(self.stream.end_flag)(
        comp_count.inc()
    )

    # --------------------
    # WriteOut phase
    # --------------------
    state_write_out = fsm.current

    # sync with Comp control
    fsm.If(comp_count > out_count).goto_next()

    out_laddr = out_page_dma_offset
    out_gaddr = self.objaddr + out_offset

    bt.bus_lock(self.maxi, fsm)

    bt.dma_write(self.maxi, fsm, out_ram, out_laddr,
                 out_gaddr, self.out_write_size, port=1, use_async=True)

    bt.bus_unlock(self.maxi, fsm)

    fsm(
        out_count.inc()
    )

    fsm.goto_next()

    state_write_out_end = fsm.current
    fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

    # --------------------
    # update for next iteration
    # --------------------
    # ReadAct: count
    # Nested-counter carry chain, innermost dimension first. `cond` is None
    # for the innermost counter, i.e. that update is unconditional.
    cond = None
    for size, count in zip(reversed(self.out_shape[:-2]), reversed(counts)):
        fsm.If(cond)(
            count.inc()
        )
        fsm.If(cond, count >= size - 1)(
            count(0)
        )
        if cond is not None:
            cond = vg.Ands(cond, count >= size - 1)
        else:
            cond = count >= size - 1

    # ReadAct: offset
    cond = None
    for size, count, act_offset, act_offset_stride in zip(
            reversed(self.out_shape[:-2]), reversed(counts),
            reversed(act_offsets), reversed(self.act_offset_strides)):
        fsm.If(cond)(
            act_offset.add(act_offset_stride)
        )
        fsm.If(cond, count >= size - 1)(
            act_offset(0)
        )
        if cond is not None:
            cond = vg.Ands(cond, count >= size - 1)
        else:
            cond = count >= size - 1

    # ReadAct and Comp: double buffer
    fsm.If(vg.Not(act_page))(
        act_page_comp_offset(act_page_size),
        act_page_dma_offset(act_page_size),
        act_page(1)
    )
    fsm.If(act_page)(
        act_page_comp_offset(0),
        act_page_dma_offset(0),
        act_page(0)
    )

    # WriteOut: offset (driven by the delayed counters, and frozen while
    # write-out is being skipped)
    cond = vg.Not(skip_write_out)
    for size, prev_count, out_offset, out_offset_stride in zip(
            reversed(self.out_shape[:-2]), reversed(prev_counts),
            reversed(out_offsets), reversed(self.out_offset_strides)):
        fsm.If(cond)(
            out_offset.add(out_offset_stride)
        )
        fsm.If(cond, prev_count >= size - 1)(
            out_offset(0)
        )
        cond = vg.Ands(cond, prev_count >= size - 1)

    # WriteOut and Comp: double buffer
    # Note that the comp and dma offsets of the output point at opposite
    # pages.
    fsm.If(vg.Not(out_page))(
        out_page_comp_offset(out_page_size),
        out_page_dma_offset(0),
        out_page(1)
    )
    fsm.If(out_page)(
        out_page_comp_offset(0),
        out_page_dma_offset(out_page_size),
        out_page(0)
    )

    # ReadAct and WriteOut: prev
    for count, prev_count in zip(counts, prev_counts):
        fsm(
            prev_count(count)
        )

    # ReadAct, Comp, WriteOut: skip
    # All counters at their last value: nothing further to read or compute.
    cond_skip_read_act = None
    cond_skip_comp = None
    for size, count in zip(reversed(self.out_shape[:-2]), reversed(counts)):
        if cond_skip_read_act is not None:
            cond_skip_read_act = vg.Ands(cond_skip_read_act,
                                         count >= size - 1)
        else:
            cond_skip_read_act = count >= size - 1

    cond_skip_comp = cond_skip_read_act

    # All delayed counters at zero: the first result becomes available, so
    # the initial write-out skip is cancelled.
    cond_cancel_write_out = None
    for size, prev_count in zip(reversed(self.out_shape[:-2]),
                                reversed(prev_counts)):
        if cond_cancel_write_out is not None:
            cond_cancel_write_out = vg.Ands(cond_cancel_write_out,
                                            prev_count == 0)
        else:
            cond_cancel_write_out = prev_count == 0

    # All delayed counters at their last value: the final write was issued.
    cond_done = None
    for size, prev_count in zip(reversed(self.out_shape[:-2]),
                                reversed(prev_counts)):
        if cond_done is not None:
            cond_done = vg.Ands(cond_done, prev_count >= size - 1)
        else:
            cond_done = prev_count >= size - 1

    fsm.If(cond_skip_read_act)(
        skip_read_act(1)
    )
    fsm.If(cond_skip_comp)(
        skip_comp(1)
    )
    fsm.If(skip_write_out,
           cond_cancel_write_out)(
        skip_write_out(0)
    )

    # Default transition loops back to ReadAct; the conditional transition
    # added afterwards leaves the loop.
    fsm.goto(state_read_act)
    fsm.If(vg.Not(skip_write_out), cond_done).goto_next()

    # wait for last DMA write
    bt.dma_wait_write(self.maxi, fsm)
def control_sequence(self, fsm):
    """Build the control FSM for a windowed ReadAct -> Comp -> WriteOut loop.

    The method only describes hardware: it declares registers/wires on
    ``self.m`` and appends statements and transitions to ``fsm`` and to a
    second FSM (``comp_fsm``) created here to drive ``self.stream``.
    The order of the calls defines the order of the generated states.

    One iteration of the main FSM handles one output row:

    1. initialization: reset offsets, DMA flags, pages, counters, skip flags.
    2. ReadAct: for each kernel row, DMA-read into the corresponding group
       of ``ksize_col`` input RAMs, unless the row is padding or its DMA
       flag is cleared.
    3. Comp: ``comp_fsm`` iterates over columns, computing pad masks and
       programming the stream constant, sources and sink for each step.
    4. WriteOut: DMA-write one chunk of ``self.out_write_size`` from the
       output RAM.
    5. update: advance row/batch counters and offsets, rotate the row
       selector, flip RAM pages, update skip flags, loop or exit.

    Side effect: sets ``self.stride_bat`` to 1.

    Args:
        fsm: The main control FSM to which states are appended.

    Returns:
        None.
    """
    # ksize_ch and ksize_bat are not referenced again in this method.
    ksize_ch = self.ksize[-1]
    ksize_col = self.ksize[-2]
    ksize_row = self.ksize[-3]
    ksize_bat = self.ksize[-4]

    self.stride_bat = 1

    act_rams = self.input_rams
    out_ram = self.output_rams[0]

    # Global-address base offsets: row part + batch part.
    act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                  self.maxi.addrwidth, signed=True)
    act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)
    act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)

    act_base_offset.assign(act_base_offset_row + act_base_offset_bat)

    out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                  self.maxi.addrwidth, signed=True)
    out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)
    out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)

    out_base_offset.assign(out_base_offset_row + out_base_offset_bat)

    # One flag per kernel row; a cleared flag suppresses that row's DMA
    # read in the ReadAct phase.
    dma_flags = [
        self.m.Reg(self._name('dma_flag_%d' % i), initval=0)
        for i in range(ksize_row)
    ]

    col_count = self.m.Reg(self._name('col_count'),
                           self.maxi.addrwidth, initval=0)
    row_count = self.m.Reg(self._name('row_count'),
                           self.maxi.addrwidth, initval=0)
    bat_count = self.m.Reg(self._name('bat_count'),
                           self.maxi.addrwidth, initval=0)
    # Selectors used as the select input of mux_1d / mux_2d below.
    col_select = self.m.Reg(self._name('col_select'),
                            bt.log_width(ksize_col), initval=0)
    row_select = self.m.Reg(self._name('row_select'),
                            bt.log_width(ksize_row), initval=0)

    # Copies of the counters/selector delayed by one main-FSM iteration.
    prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                self.maxi.addrwidth, initval=0)
    prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                self.maxi.addrwidth, initval=0)
    prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                 bt.log_width(ksize_row), initval=0)

    # One local read address per input RAM.
    stream_act_locals = [
        self.m.Reg(self._name('stream_act_local_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(len(act_rams))
    ]
    stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                  self.maxi.addrwidth, initval=0)

    # double buffer control
    # Input pages are tracked per kernel row; the output has a single page
    # selector. Each RAM is split into two halves (length // 2).
    act_pages = [
        self.m.Reg(self._name('act_page_%d' % i), initval=0)
        for i in range(ksize_row)
    ]
    act_page_comp_offsets = [
        self.m.Reg(self._name('act_page_comp_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]
    act_page_dma_offsets = [
        self.m.Reg(self._name('act_page_dma_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]

    out_page = self.m.Reg(self._name('out_page'), initval=0)
    out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    act_page_size = act_rams[0].length // 2
    out_page_size = out_ram.length // 2

    skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
    skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
    skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

    # comp_count is advanced by comp_fsm, out_count by the main FSM;
    # comparing them synchronizes WriteOut with Comp.
    comp_count = self.m.Reg(self._name('comp_count'),
                            self.maxi.addrwidth, initval=0)
    out_count = self.m.Reg(self._name('out_count'),
                           self.maxi.addrwidth, initval=0)

    # --------------------
    # initialization phase
    # --------------------
    # ReadAct: offset
    fsm(act_base_offset_row(0), act_base_offset_bat(0))

    # Expressions (not registers): base offset plus each configured value.
    act_offsets = []
    for v in self.act_offset_values:
        act_offset = act_base_offset + v
        act_offsets.append(act_offset)

    # ReadAct: DMA flag
    for y, dma_flag in enumerate(dma_flags):
        fsm(dma_flag(1))

    # A kernel row is masked when row_count + y falls outside
    # [pad_row_top, act_num_row + pad_row_top).
    dma_pad_masks = []

    for y in range(ksize_row):
        v = vg.Ors((row_count + y < self.pad_row_top),
                   (row_count + y >= self.act_num_row + self.pad_row_top))
        dma_pad_mask = self.m.Wire(self._name('dma_pad_mask_%d' % y))
        dma_pad_mask.assign(v)
        dma_pad_masks.append(dma_pad_mask)

    # ReadAct: double buffer control
    fsm([act_page(0) for act_page in act_pages], [
        act_page_comp_offset(0)
        for act_page_comp_offset in act_page_comp_offsets
    ], [
        act_page_dma_offset(0)
        for act_page_dma_offset in act_page_dma_offsets
    ])

    # WriteOutput: offset
    fsm(out_base_offset_row(0), out_base_offset_bat(0))

    out_offset = out_base_offset

    # WriteOut: double buffer control
    fsm(out_page(0), out_page_comp_offset(0), out_page_dma_offset(0))

    # counter
    fsm(row_count(0), bat_count(0), row_select(0), prev_row_count(0),
        prev_bat_count(0), prev_row_select(0))

    # double buffer control: write-out starts in the skipped state
    fsm(skip_read_act(0), skip_comp(0), skip_write_out(1))

    fsm(out_count(0))

    state_init = fsm.current
    fsm.goto_next()

    # --------------------
    # ReadAct phase
    # --------------------
    state_read_act = fsm.current

    act_gaddrs = []
    for act_offset in act_offsets:
        act_gaddr = self.arg_objaddrs[0] + act_offset
        act_gaddrs.append(act_gaddr)

    # Group the flat RAM list into rows of ksize_col RAMs.
    # NOTE(review): exact semantics of line_to_2d / mux_1d are defined
    # elsewhere; presumably mux_1d rotates the list by the selector.
    act_rams_2d = line_to_2d(act_rams, ksize_col)

    mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
    mux_act_gaddrs = []
    for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
        mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                    self.maxi.addrwidth)
        mux_act_gaddr.assign(mux_act_gaddr_value)
        mux_act_gaddrs.append(mux_act_gaddr)

    mux_dma_pad_mask_values = mux_1d(dma_pad_masks, row_select, ksize_row)
    mux_dma_pad_masks = []
    for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
        mux_dma_pad_mask = self.m.Wire(
            self._name('mux_dma_pad_mask_%d' % i))
        mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
        mux_dma_pad_masks.append(mux_dma_pad_mask)

    # determined at the previous phase
    # (hence selected with prev_row_select rather than row_select)
    mux_dma_flag_values = mux_1d(dma_flags, prev_row_select, ksize_row)
    mux_dma_flags = []
    for i, mux_dma_flag_value in enumerate(mux_dma_flag_values):
        mux_dma_flag = self.m.Wire(self._name('mux_dma_flag_%d' % i))
        mux_dma_flag.assign(mux_dma_flag_value)
        mux_dma_flags.append(mux_dma_flag)

    bt.bus_lock(self.maxi, fsm)

    # One DMA read sequence per kernel row; each can be bypassed
    # individually via goto_from(begin, end).
    for (act_rams_row, act_gaddr, act_page_dma_offset, dma_pad_mask,
         dma_flag) in zip(act_rams_2d, mux_act_gaddrs,
                          act_page_dma_offsets, mux_dma_pad_masks,
                          mux_dma_flags):
        act_laddr = act_page_dma_offset

        begin_state_read = fsm.current
        fsm.goto_next()

        # A single RAM in the row uses a plain read; several RAMs use the
        # block variant with self.act_read_block.
        if len(act_rams_row) == 1:
            bt.dma_read(self.maxi, fsm, act_rams_row[0], act_laddr,
                        act_gaddr, self.act_read_size, port=1)
        else:
            bt.dma_read_block(self.maxi, fsm, act_rams_row, act_laddr,
                              act_gaddr, self.act_read_size,
                              self.act_read_block, port=1)

        end_state_read = fsm.current

        # Skip this row's read when it is padding or its flag is cleared.
        fsm.If(vg.Ors(dma_pad_mask,
                      vg.Not(dma_flag))).goto_from(begin_state_read,
                                                   end_state_read)

    bt.bus_unlock(self.maxi, fsm)

    fsm.goto_next()

    state_read_act_end = fsm.current

    # Bypass the whole ReadAct phase when skip_read_act is set.
    fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

    # --------------------
    # Comp phase
    # --------------------
    state_comp = fsm.current

    # Stream Control FSM
    comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

    comp_state_init = comp_fsm.current
    # comp_fsm starts when the main FSM reaches state_comp (unless skipped);
    # the main FSM proceeds only while comp_fsm is in its initial state.
    comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

    fsm.If(comp_fsm.state == comp_state_init).goto_next()

    # local address
    # Later statements in the same state override earlier ones when their
    # condition holds, so the small/large offsets take precedence over 0.
    # NOTE(review): relies on the FSM's last-assignment-wins ordering;
    # confirm against the FSM implementation.
    stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)
    for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
        for x, stream_act_local in enumerate(stream_act_locals_row):
            comp_fsm(stream_act_local(0))
            comp_fsm.If(self.stream_act_local_small_flags[x])(
                stream_act_local(self.stream_act_local_small_offset))
            comp_fsm.If(self.stream_act_local_large_flags[x])(
                stream_act_local(self.stream_act_local_large_offset))

    comp_fsm(stream_out_local(0))

    # count and sel
    comp_fsm(col_count(0))
    comp_fsm(col_select(self.col_select_initval))

    # Snapshots of main-FSM state taken when comp_fsm starts, so the main
    # FSM may update the originals while the computation is running.
    act_page_comp_offset_bufs = [
        self.m.Reg(self._name('act_page_comp_offset_buf_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]
    out_page_comp_offset_buf = self.m.Reg(
        self._name('out_page_comp_offset_buf'),
        self.maxi.addrwidth, initval=0)
    row_count_buf = self.m.Reg(self._name('row_count_buf'),
                               self.maxi.addrwidth, initval=0)
    row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                bt.log_width(ksize_row), initval=0)
    comp_fsm([
        act_page_comp_offset_buf(act_page_comp_offset)
        for act_page_comp_offset_buf, act_page_comp_offset in zip(
            act_page_comp_offset_bufs, act_page_comp_offsets)
    ], out_page_comp_offset_buf(out_page_comp_offset),
        row_count_buf(row_count), row_select_buf(row_select))

    comp_fsm.goto_next()

    # repeat comp
    comp_state_rep = comp_fsm.current

    # pad_mask
    # One mask bit per kernel position: set when the position lies outside
    # the column or row bounds of the input.
    stream_pad_masks = []

    for y in range(ksize_row):
        for x in range(ksize_col):
            stream_col_count = col_count + x
            stream_row_count = row_count_buf + y
            v = vg.Ors(
                (stream_col_count < self.pad_col_left),
                (stream_col_count >= self.act_num_col + self.pad_col_left),
                (stream_row_count < self.pad_row_top),
                (stream_row_count >= self.act_num_row + self.pad_row_top))
            stream_pad_mask = self.m.Wire(
                self._name('stream_pad_mask_%d_%d' % (y, x)))
            stream_pad_mask.assign(v)
            stream_pad_masks.append(stream_pad_mask)

    # Reorder the masks by the current column/row selectors, then flatten
    # back to a single list.
    stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
    stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d, col_select,
                                    row_select_buf, ksize_col, ksize_row)
    stream_pad_masks = [
        flatten for inner in stream_pad_mask_2d_mux for flatten in inner
    ]

    # Register the concatenated mask bits (list reversed before Cat).
    stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                      len(stream_pad_masks), initval=0)
    comp_fsm(stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks))))
    comp_fsm.goto_next()

    # busy check
    self.stream.source_join(comp_fsm)

    stream_masks = stream_pad_masks_reg

    # set_constant
    name = list(self.stream.constants.keys())[0]
    self.stream.set_constant(comp_fsm, name, stream_masks)
    # step the state index back by one before the next stream setup call
    comp_fsm.set_index(comp_fsm.current - 1)

    # set_source
    # Repeat each per-row page offset ksize_col times so that it lines up
    # with the flat list of input RAMs.
    act_page_comp_offset_bufs_dup = []
    for act_page_comp_offset_buf in act_page_comp_offset_bufs:
        act_page_comp_offset_bufs_dup.extend([act_page_comp_offset_buf] *
                                             ksize_col)

    for name, ram, stream_act_local, act_page_comp_offset_buf in zip(
            self.stream.sources.keys(), act_rams, stream_act_locals,
            act_page_comp_offset_bufs_dup):
        local = stream_act_local + act_page_comp_offset_buf
        self.stream.set_source(comp_fsm, name, ram, local,
                               self.stream_size)
        comp_fsm.set_index(comp_fsm.current - 1)

    # set_sink
    name = list(self.stream.sinks.keys())[0]
    local = stream_out_local + out_page_comp_offset_buf
    self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

    # stream run (async)
    self.stream.run(comp_fsm)

    # stream_act_local
    # self.inc_act_laddr_conds is consumed sequentially: ksize_col entries
    # per input RAM, indexed by the running counter i.
    stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)

    i = 0
    for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
        for x, stream_act_local in enumerate(stream_act_locals_row):
            patterns = []
            for col in range(ksize_col):
                val = self.inc_act_laddr_conds[i]
                i += 1
                pat = (col_select == col, val)
                patterns.append(pat)

            # default entry of the pattern mux
            patterns.append((None, 0))
            v = vg.PatternMux(*patterns)

            # Advance by the small or the large increment depending on v;
            # at the last column reset to the initial value instead.
            comp_fsm.If(vg.Not(v))(stream_act_local.add(
                self.inc_act_laddr_small))
            comp_fsm.If(v)(stream_act_local.add(self.inc_act_laddr_large))

            comp_fsm.If(col_count >= self.max_col_count)(
                stream_act_local(0))
            comp_fsm.If(col_count >= self.max_col_count,
                        self.stream_act_local_small_flags[x])(
                stream_act_local(
                    self.stream_act_local_small_offset))
            comp_fsm.If(col_count >= self.max_col_count,
                        self.stream_act_local_large_flags[x])(
                stream_act_local(
                    self.stream_act_local_large_offset))

    # stream_out_local
    comp_fsm(stream_out_local.add(self.inc_out_laddr))
    comp_fsm.If(col_count >= self.max_col_count)(stream_out_local(0))

    # counter
    comp_fsm(col_count.add(self.stride_col))
    comp_fsm.If(col_count >= self.max_col_count)(col_count(0))

    # col_select advances by the stride with wrap-around at ksize_col
    comp_fsm(col_select.add(self.stride_col_mod_ksize))
    comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
        col_select.sub(self.ksize_col_minus_stride_col_mod))

    comp_fsm.If(col_count >= self.max_col_count)(col_select(
        self.col_select_initval), )

    # repeat
    # Default: loop to the next column; at the last column return to init.
    comp_fsm.goto(comp_state_rep)
    comp_fsm.If(col_count >= self.max_col_count).goto_init()

    # sync with WriteOut control
    comp_fsm.seq.If(fsm.state == state_init)(comp_count(0))
    comp_fsm.seq.If(self.stream.end_flag)(comp_count.add(
        self.inc_out_laddr))

    # --------------------
    # WriteOut phase
    # --------------------
    state_write_out = fsm.current

    # sync with Comp control
    # Proceed only once a full out_write_size chunk has been produced.
    fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

    out_laddr = out_page_dma_offset
    out_gaddr = self.objaddr + out_offset

    bt.bus_lock(self.maxi, fsm)

    bt.dma_write(self.maxi, fsm, out_ram, out_laddr,
                 out_gaddr, self.out_write_size, port=1, use_async=True)

    bt.bus_unlock(self.maxi, fsm)

    fsm(out_count.add(self.out_write_size))

    fsm.goto_next()

    state_write_out_end = fsm.current
    # Bypass the whole WriteOut phase when skip_write_out is set.
    fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

    # --------------------
    # update for next iteration
    # --------------------
    # ReadAct: offset
    fsm(act_base_offset_row.add(self.act_row_step))
    fsm.If(row_count >= self.max_row_count)(act_base_offset_row(0),
                                            act_base_offset_bat.add(
                                                self.act_bat_step))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(act_base_offset_bat(0))

    # ReadAct: DMA flag
    # next_dma_flags holds the same next-value as an expression, because
    # the page update below needs it before the registers are updated.
    next_dma_flags = []
    for dma_flag, dma_flag_cond in zip(dma_flags, self.dma_flag_conds):
        fsm(dma_flag(dma_flag_cond))
        fsm.If(row_count >= self.max_row_count)(dma_flag(1))

        next_dma_flags.append(
            vg.Mux(row_count >= self.max_row_count, 1, dma_flag_cond))

    # ReadAct: counter
    fsm(row_count.add(self.stride_row))
    fsm.If(row_count >= self.max_row_count)(row_count(0),
                                            bat_count.add(self.stride_bat))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(bat_count(0))

    # row_select advances by stride_row modulo ksize_row; it is forced to 0
    # when the stride is not smaller than the kernel or at the last row.
    fsm.If(self.stride_row < ksize_row)(row_select.add(self.stride_row),
                                        prev_row_select(row_select))
    fsm.If(self.stride_row < ksize_row,
           row_select + self.stride_row >= ksize_row)(
        row_select(row_select - (vg.Int(ksize_row) - self.stride_row)),
        prev_row_select(row_select))
    fsm.If(vg.Not(self.stride_row < ksize_row))(row_select(0),
                                                prev_row_select(0))

    fsm.If(row_count >= self.max_row_count)(row_select(0),
                                            prev_row_select(0))

    # ReadAct and Comp: double buffer
    mux_next_dma_flag_values = mux_1d(next_dma_flags, row_select,
                                      ksize_row)
    mux_next_dma_flags = []
    for i, mux_next_dma_flag_value in enumerate(mux_next_dma_flag_values):
        mux_next_dma_flag = self.m.Wire(
            self._name('mux_next_dma_flag_%d' % i))
        mux_next_dma_flag.assign(mux_next_dma_flag_value)
        mux_next_dma_flags.append(mux_next_dma_flag)

    # A row's page is flipped only when its next DMA flag is set.
    for (act_page, act_page_comp_offset, act_page_dma_offset,
         mux_next_dma_flag) in zip(act_pages, act_page_comp_offsets,
                                   act_page_dma_offsets,
                                   mux_next_dma_flags):

        fsm.If(vg.Not(act_page),
               mux_next_dma_flag)(act_page_comp_offset(act_page_size),
                                  act_page_dma_offset(act_page_size),
                                  act_page(1))
        fsm.If(act_page,
               mux_next_dma_flag)(act_page_comp_offset(0),
                                  act_page_dma_offset(0), act_page(0))

    # WriteOut: counter
    # Driven by the delayed counters and frozen while write-out is skipped.
    fsm.If(vg.Not(skip_write_out))(out_base_offset_row.add(
        self.out_row_step))
    fsm.If(vg.Not(skip_write_out),
           prev_row_count >= self.max_row_count)(
        out_base_offset_row(0),
        out_base_offset_bat.add(self.out_bat_step))
    fsm.If(vg.Not(skip_write_out),
           prev_row_count >= self.max_row_count,
           prev_bat_count >= self.max_bat_count)(out_base_offset_bat(0))

    # WriteOut and Comp: double buffer
    # The comp and dma offsets of the output point at opposite pages.
    fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                             out_page_dma_offset(0), out_page(1))
    fsm.If(out_page)(out_page_comp_offset(0),
                     out_page_dma_offset(out_page_size), out_page(0))

    # ReadAct and WriteOut: prev
    fsm(prev_row_count(row_count), prev_bat_count(bat_count))

    # ReadAct, Comp, WriteOut: skip
    # Last row of the last batch: nothing further to read or compute.
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(skip_read_act(1))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(skip_comp(1))

    # Delayed counters at zero: cancel the initial write-out skip.
    fsm.If(skip_write_out, prev_row_count == 0,
           prev_bat_count == 0)(skip_write_out(0))

    # Default transition loops back to ReadAct; the conditional transition
    # added afterwards leaves the loop.
    fsm.goto(state_read_act)
    fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
           prev_bat_count >= self.max_bat_count).goto_next()

    # wait for last DMA write
    bt.dma_wait_write(self.maxi, fsm)
def control_sequence(self, fsm): ksize_ch = self.ksize[-1] ksize_col = self.ksize[-2] ksize_row = self.ksize[-3] ksize_bat = self.ksize[-4] self.stride_bat = 1 act_ram = self.input_rams[0] out_ram = self.output_rams[0] act_base_offset = self.m.Wire(self._name('act_base_offset'), self.maxi.addrwidth, signed=True) act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'), self.maxi.addrwidth, initval=0, signed=True) act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'), self.maxi.addrwidth, initval=0, signed=True) act_base_offset.assign(act_base_offset_row + act_base_offset_bat) out_base_offset = self.m.Wire(self._name('out_base_offset'), self.maxi.addrwidth, signed=True) out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'), self.maxi.addrwidth, initval=0, signed=True) out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'), self.maxi.addrwidth, initval=0, signed=True) out_base_offset.assign(out_base_offset_row + out_base_offset_bat) col_count = self.m.Reg(self._name('col_count'), self.maxi.addrwidth, initval=0) row_count = self.m.Reg(self._name('row_count'), self.maxi.addrwidth, initval=0) bat_count = self.m.Reg(self._name('bat_count'), self.maxi.addrwidth, initval=0) if not self.no_reuse: col_select = self.m.Reg(self._name('col_select'), bt.log_width(ksize_col), initval=0) row_select = self.m.Reg(self._name('row_select'), bt.log_width(ksize_row), initval=0) prev_row_count = self.m.Reg(self._name('prev_row_count'), self.maxi.addrwidth, initval=0) prev_bat_count = self.m.Reg(self._name('prev_bat_count'), self.maxi.addrwidth, initval=0) if not self.no_reuse: prev_row_select = self.m.Reg(self._name('prev_row_select'), bt.log_width(ksize_row), initval=0) stream_act_local = self.m.Reg(self._name('stream_act_local'), self.maxi.addrwidth, initval=0) stream_out_local = self.m.Reg(self._name('stream_out_local'), self.maxi.addrwidth, initval=0) # double buffer control act_page = self.m.Reg(self._name('act_page'), initval=0) 
act_page_comp_offset = self.m.Reg(self._name('act_page_comp_offset'), self.maxi.addrwidth, initval=0) act_page_dma_offset = self.m.Reg(self._name('act_page_dma_offset'), self.maxi.addrwidth, initval=0) out_page = self.m.Reg(self._name('out_page'), initval=0) out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'), self.maxi.addrwidth, initval=0) out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'), self.maxi.addrwidth, initval=0) act_page_size = act_ram.length // 2 out_page_size = out_ram.length // 2 skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0) skip_comp = self.m.Reg(self._name('skip_comp'), initval=0) skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0) comp_count = self.m.Reg(self._name('comp_count'), self.maxi.addrwidth, initval=0) out_count = self.m.Reg(self._name('out_count'), self.maxi.addrwidth, initval=0) # -------------------- # initialization phase # -------------------- # ReadAct: offset fsm( act_base_offset_row(0), act_base_offset_bat(0) ) act_offsets = [] for v in self.act_offset_values: act_offset = act_base_offset + v act_offsets.append(act_offset) # ReadAct: DMA flag dma_pad_masks = [] for y in range(ksize_row): v = vg.Ors((row_count + y < self.pad_row_top), (row_count + y >= self.act_num_row + self.pad_row_top)) dma_pad_mask = self.m.Wire( self._name('dma_pad_mask_%d' % y)) dma_pad_mask.assign(v) dma_pad_masks.append(dma_pad_mask) # ReadAct: double buffer control fsm( act_page(0), act_page_comp_offset(0), act_page_dma_offset(0) ) # WriteOutput: offset fsm( out_base_offset_row(0), out_base_offset_bat(0) ) out_offset = out_base_offset # WriteOut: double buffer control fsm( out_page(0), out_page_comp_offset(0), out_page_dma_offset(0) ) # counter fsm( row_count(0), bat_count(0), prev_row_count(0), prev_bat_count(0) ) if not self.no_reuse: fsm( row_select(0), prev_row_select(0) ) # double buffer control fsm( skip_read_act(0), skip_comp(0), skip_write_out(1) ) fsm( out_count(0) ) state_init 
= fsm.current fsm.goto_next() # -------------------- # ReadAct phase # -------------------- state_read_act = fsm.current act_gaddrs = [] for act_offset in act_offsets: act_gaddr = self.arg_objaddrs[0] + act_offset act_gaddrs.append(act_gaddr) if not self.no_reuse: mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row) mux_act_gaddrs = [] for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values): mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i), self.maxi.addrwidth) mux_act_gaddr.assign(mux_act_gaddr_value) mux_act_gaddrs.append(mux_act_gaddr) mux_dma_pad_mask_values = mux_1d( dma_pad_masks, row_select, ksize_row) mux_dma_pad_masks = [] for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values): mux_dma_pad_mask = self.m.Wire( self._name('mux_dma_pad_mask_%d' % i)) mux_dma_pad_mask.assign(mux_dma_pad_mask_value) mux_dma_pad_masks.append(mux_dma_pad_mask) else: mux_act_gaddrs = act_gaddrs mux_dma_pad_masks = dma_pad_masks bt.bus_lock(self.maxi, fsm) act_laddr = act_page_dma_offset for (act_gaddr, dma_pad_mask) in zip( mux_act_gaddrs, mux_dma_pad_masks): begin_state_read = fsm.current fsm.goto_next() bt.dma_read(self.maxi, fsm, act_ram, act_laddr, act_gaddr, self.act_read_size, port=1) end_state_read = fsm.current fsm.If(dma_pad_mask).goto_from( begin_state_read, end_state_read) act_laddr += self.act_read_size bt.bus_unlock(self.maxi, fsm) fsm.goto_next() state_read_act_end = fsm.current fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end) # -------------------- # Comp phase # -------------------- state_comp = fsm.current # Stream Control FSM comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst) comp_state_init = comp_fsm.current comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next() fsm.If(comp_fsm.state == comp_state_init).goto_next() # local address comp_fsm( stream_act_local(self.local_pad_offset) ) comp_fsm( stream_out_local(0) ) # count and sel comp_fsm( col_count(0) ) if not 
self.no_reuse: comp_fsm( col_select(self.col_select_initval) ) act_page_comp_offset_buf = self.m.Reg(self._name('act_page_comp_offset_buf'), self.maxi.addrwidth, initval=0) out_page_comp_offset_buf = self.m.Reg(self._name('out_page_comp_offset_buf'), self.maxi.addrwidth, initval=0) row_count_buf = self.m.Reg(self._name('row_count_buf'), self.maxi.addrwidth, initval=0) if not self.no_reuse: row_select_buf = self.m.Reg(self._name('row_select_buf'), bt.log_width(ksize_row), initval=0) comp_fsm( act_page_comp_offset_buf(act_page_comp_offset), out_page_comp_offset_buf(out_page_comp_offset), row_count_buf(row_count) ) if not self.no_reuse: comp_fsm( row_select_buf(row_select) ) comp_fsm.goto_next() # repeat comp comp_state_rep = comp_fsm.current # pad_mask stream_pad_masks = [] for y in range(ksize_row): for x in range(ksize_col): stream_col_count = col_count + x stream_row_count = row_count_buf + y v = vg.Ors((stream_col_count < self.pad_col_left), (stream_col_count >= self.act_num_col + self.pad_col_left), (stream_row_count < self.pad_row_top), (stream_row_count >= self.act_num_row + self.pad_row_top)) stream_pad_mask = self.m.Wire( self._name('stream_pad_mask_%d_%d' % (y, x))) stream_pad_mask.assign(v) stream_pad_masks.append(stream_pad_mask) if not self.no_reuse: stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col) stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d, col_select, row_select_buf, ksize_col, ksize_row) stream_pad_masks = [flatten for inner in stream_pad_mask_2d_mux for flatten in inner] stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'), len(stream_pad_masks), initval=0) comp_fsm( stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks))) ) comp_fsm.goto_next() # busy check self.stream.source_join(comp_fsm) stream_masks = stream_pad_masks_reg # set_constant name = list(self.stream.constants.keys())[0] self.stream.set_constant(comp_fsm, name, ksize_col * ksize_row) comp_fsm.set_index(comp_fsm.current - 1) name = 
list(self.stream.constants.keys())[1] self.stream.set_constant(comp_fsm, name, stream_masks) comp_fsm.set_index(comp_fsm.current - 1) # set_source name = list(self.stream.sources.keys())[0] local = stream_act_local + act_page_comp_offset_buf pat = ((ksize_col, self.act_read_block), (ksize_row, self.act_read_size), (self.stream_size, 1)) self.stream.set_source_pattern(comp_fsm, name, act_ram, local, pat) comp_fsm.set_index(comp_fsm.current - 1) # set_sink name = list(self.stream.sinks.keys())[0] local = stream_out_local + out_page_comp_offset_buf self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size) # stream run (async) self.stream.run(comp_fsm) # stream_act_local comp_fsm( stream_act_local.add(self.inc_act_laddr) ) comp_fsm.If(col_count >= self.max_col_count)( stream_act_local(self.local_pad_offset) ) # stream_out_local comp_fsm( stream_out_local.add(self.inc_out_laddr) ) comp_fsm.If(col_count >= self.max_col_count)( stream_out_local(0) ) # counter comp_fsm( col_count.add(self.stride_col) ) comp_fsm.If(col_count >= self.max_col_count)( col_count(0) ) if not self.no_reuse: comp_fsm( col_select.add(self.stride_col_mod_ksize) ) comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)( col_select.sub(self.ksize_col_minus_stride_col_mod) ) comp_fsm.If(col_count >= self.max_col_count)( col_select(self.col_select_initval) ) # repeat comp_fsm.goto(comp_state_rep) comp_fsm.If(col_count >= self.max_col_count).goto_init() # sync with WriteOut control comp_fsm.seq.If(fsm.state == state_init)( comp_count(0) ) comp_fsm.seq.If(self.stream.end_flag)( comp_count.add(self.inc_out_laddr) ) # -------------------- # WriteOut phase # -------------------- state_write_out = fsm.current # sync with Comp control fsm.If(comp_count >= out_count + self.out_write_size).goto_next() out_laddr = out_page_dma_offset out_gaddr = self.objaddr + out_offset bt.bus_lock(self.maxi, fsm) bt.dma_write(self.maxi, fsm, out_ram, out_laddr, out_gaddr, self.out_write_size, port=1, 
use_async=True) bt.bus_unlock(self.maxi, fsm) fsm( out_count.add(self.out_write_size) ) fsm.goto_next() state_write_out_end = fsm.current fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end) # -------------------- # update for next iteration # -------------------- # ReadAct: offset fsm( act_base_offset_row.add(self.act_row_step) ) fsm.If(row_count >= self.max_row_count)( act_base_offset_row(0), act_base_offset_bat.add(self.act_bat_step) ) fsm.If(row_count >= self.max_row_count, bat_count >= self.max_bat_count)( act_base_offset_bat(0) ) # ReadAct: counter fsm( row_count.add(self.stride_row) ) fsm.If(row_count >= self.max_row_count)( row_count(0), bat_count.add(self.stride_bat) ) fsm.If(row_count >= self.max_row_count, bat_count >= self.max_bat_count)( bat_count(0) ) if not self.no_reuse: fsm.If(self.stride_row < ksize_row)( row_select.add(self.stride_row), prev_row_select(row_select) ) fsm.If(self.stride_row < ksize_row, row_select + self.stride_row >= ksize_row)( row_select(row_select - (vg.Int(ksize_row) - self.stride_row)), prev_row_select(row_select) ) fsm.If(vg.Not(self.stride_row < ksize_row))( row_select(0), prev_row_select(0) ) fsm.If(row_count >= self.max_row_count)( row_select(0), prev_row_select(0) ) # ReadAct and Comp: double buffer fsm.If(vg.Not(act_page))( act_page_comp_offset(act_page_size), act_page_dma_offset(act_page_size), act_page(1) ) fsm.If(act_page)( act_page_comp_offset(0), act_page_dma_offset(0), act_page(0) ) # WriteOut: counter fsm.If(vg.Not(skip_write_out))( out_base_offset_row.add(self.out_row_step) ) fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count)( out_base_offset_row(0), out_base_offset_bat.add(self.out_bat_step) ) fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count, prev_bat_count >= self.max_bat_count)( out_base_offset_bat(0) ) # WriteOut and Comp: double buffer fsm.If(vg.Not(out_page))( out_page_comp_offset(out_page_size), out_page_dma_offset(0), out_page(1) ) fsm.If(out_page)( 
out_page_comp_offset(0), out_page_dma_offset(out_page_size), out_page(0) ) # ReadAct and WriteOut: prev fsm( prev_row_count(row_count), prev_bat_count(bat_count) ) # ReadAct, Comp, WriteOut: skip fsm.If(row_count >= self.max_row_count, bat_count >= self.max_bat_count)( skip_read_act(1) ) fsm.If(row_count >= self.max_row_count, bat_count >= self.max_bat_count)( skip_comp(1) ) fsm.If(skip_write_out, prev_row_count == 0, prev_bat_count == 0)( skip_write_out(0) ) fsm.goto(state_read_act) fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count, prev_bat_count >= self.max_bat_count).goto_next() # wait for last DMA write bt.dma_wait_write(self.maxi, fsm)