def _modify_always_statement_body(self, m, st, regs):
    """Guard an always-block statement with the write/read control signals.

    The result is ``If(self.ctrl_write)(<assignments>).Else(If(Not(self.ctrl_read))(st))``:
    while ``self.ctrl_write`` is asserted every register in ``regs`` is
    assigned from ``self.get_write_port(reg)``; otherwise ``st`` runs only
    while ``self.ctrl_read`` is deasserted.

    Args:
        m: not used by this implementation.
        st: the statement body to guard.
        regs: mapping whose values are the registers driven on a write.

    Returns:
        The combined conditional statement.
    """
    assignments = [reg(self.get_write_port(reg)) for reg in regs.values()]
    write_branch = vg.If(self.ctrl_write)(*assignments)
    return write_branch.Else(vg.If(vg.Not(self.ctrl_read))(st))
def control_sequence(self, fsm):
    """Build the main control FSM: read activation, run the stream, write output.

    States are appended to ``fsm`` in this order:

    * initialization -- per-dimension activation offsets start at
      ``self.act_offset_begins``, output offsets at 0; counters,
      double-buffer pages and skip flags are reset.  ``skip_write_out``
      starts at 1, so the first iteration writes nothing.
    * ReadAct -- DMA-read ``self.act_read_size`` words from
      ``self.arg_objaddrs[0] + act_base_offset`` into the current act page.
    * Comp -- handshake with a separately built ``comp_fsm`` that latches the
      page offsets, programs the stream source/sink patterns and runs
      ``self.stream`` asynchronously.
    * WriteOut -- once ``comp_count > out_count``, DMA-write (async)
      ``self.out_write_size`` words from the current out page to
      ``self.objaddr + out_base_offset``.
    * update -- advance counters/offsets, flip both double buffers, update
      the skip flags, then loop back to ReadAct until the last output has
      been written.
    * finally wait for the last asynchronous DMA write.

    Args:
        fsm: the main control FSM (presumably a veriloggen ``FSM``).
    """
    act_ram = self.input_rams[0]
    out_ram = self.output_rams[0]

    # Activation base offset = sum of one offset register per leading
    # dimension (all but the last two).
    act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                  self.maxi.addrwidth, signed=True)
    act_offsets = [self.m.Reg(self._name('act_offset_%d' % i),
                              self.maxi.addrwidth, initval=0, signed=True)
                   for i, _ in enumerate(self.act_shape[:-2])]

    if act_offsets:
        v = act_offsets[0]
        for act_offset in act_offsets[1:]:
            v += act_offset
        act_base_offset.assign(v)
    else:
        act_base_offset.assign(0)

    # Output base offset, built the same way over the output dimensions.
    out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                  self.maxi.addrwidth, signed=True)
    out_offsets = [self.m.Reg(self._name('out_offset_%d' % i),
                              self.maxi.addrwidth, initval=0, signed=True)
                   for i, _ in enumerate(self.out_shape[:-2])]

    if out_offsets:
        v = out_offsets[0]
        for out_offset in out_offsets[1:]:
            v += out_offset
        out_base_offset.assign(v)
    else:
        out_base_offset.assign(0)

    counts = [self.m.Reg(self._name('count_%d' % i),
                         self.maxi.addrwidth, initval=0)
              for i, _ in enumerate(self.act_shape[:-2])]

    # Counter values of the previous iteration; the WriteOut side lags the
    # ReadAct side by one iteration.
    prev_counts = [self.m.Reg(self._name('prev_count_%d' % i),
                              self.maxi.addrwidth, initval=0)
                   for i, _ in enumerate(self.act_shape[:-2])]

    stream_act_local = self.m.Reg(self._name('stream_act_local'),
                                  self.maxi.addrwidth, initval=0)
    stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                  self.maxi.addrwidth, initval=0)

    # comp_count: incremented by comp_fsm on every stream end_flag;
    # out_count: incremented after every output DMA write.
    comp_count = self.m.Reg(self._name('comp_count'),
                            self.maxi.addrwidth, initval=0)
    out_count = self.m.Reg(self._name('out_count'),
                           self.maxi.addrwidth, initval=0)

    act_page = self.m.Reg(self._name('act_page'), initval=0)
    act_page_comp_offset = self.m.Reg(self._name('act_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    act_page_dma_offset = self.m.Reg(self._name('act_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    out_page = self.m.Reg(self._name('out_page'), initval=0)
    out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    # Each RAM is split into two pages for double buffering.
    act_page_size = act_ram.length // 2
    out_page_size = out_ram.length // 2

    skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
    skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
    skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

    # --------------------
    # initialization phase
    # --------------------
    # ReadAct: offset
    for act_offset, act_offset_begin in zip(act_offsets, self.act_offset_begins):
        fsm(
            act_offset(act_offset_begin)
        )

    # ReadAct: double buffer control
    fsm(
        act_page(0),
        act_page_comp_offset(0),
        act_page_dma_offset(0)
    )

    # WriteOutput: offset
    for out_offset in out_offsets:
        fsm(
            out_offset(0)
        )

    out_offset = out_base_offset

    # WriteOutput: double buffer control
    fsm(
        out_page(0),
        out_page_comp_offset(0),
        out_page_dma_offset(0)
    )

    # counter
    fsm(
        [count(0) for count in counts],
        [prev_count(0) for prev_count in prev_counts]
    )

    # double buffer control
    fsm(
        skip_read_act(0),
        skip_comp(0),
        skip_write_out(1)
    )

    fsm(
        out_count(0)
    )

    state_init = fsm.current

    fsm.goto_next()

    # --------------------
    # ReadAct phase
    # --------------------
    state_read_act = fsm.current

    act_gaddr = self.arg_objaddrs[0] + act_base_offset

    bt.bus_lock(self.maxi, fsm)

    act_laddr = act_page_dma_offset

    fsm.goto_next()

    bt.dma_read(self.maxi, fsm, act_ram, act_laddr,
                act_gaddr, self.act_read_size, port=1)

    # Release the bus once the read is done (the WriteOut phase below pairs
    # its lock/unlock the same way), and let skip_read_act -- set in the
    # update phase once every counter has reached its last value -- bypass
    # the whole ReadAct phase, including the lock.
    bt.bus_unlock(self.maxi, fsm)

    fsm.goto_next()

    state_read_act_end = fsm.current
    fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

    # --------------------
    # Comp phase
    # --------------------
    state_comp = fsm.current

    # Stream Control FSM
    comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

    # Handshake: comp_fsm leaves its initial state while the main FSM is in
    # state_comp (unless skip_comp), and the main FSM leaves state_comp only
    # while comp_fsm is in its initial state.
    comp_state_init = comp_fsm.current
    comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

    fsm.If(comp_fsm.state == comp_state_init).goto_next()

    # local address
    comp_fsm(
        stream_act_local(self.stream_local),
        stream_out_local(0)
    )

    # Latch the page offsets so the main FSM may flip pages while the
    # stream is still running.
    act_page_comp_offset_buf = self.m.Reg(self._name('act_page_comp_offset_buf'),
                                          self.maxi.addrwidth, initval=0)
    out_page_comp_offset_buf = self.m.Reg(self._name('out_page_comp_offset_buf'),
                                          self.maxi.addrwidth, initval=0)

    comp_fsm(
        act_page_comp_offset_buf(act_page_comp_offset),
        out_page_comp_offset_buf(out_page_comp_offset)
    )

    comp_fsm.goto_next()

    # busy check
    self.stream.source_join(comp_fsm)

    # set_source
    name = list(self.stream.sources.keys())[0]
    local = stream_act_local + act_page_comp_offset_buf

    if len(self.out_shape) > 1:
        pat = ((self.stream_size, self.act_strides[-1]),
               (self.out_shape[-2], self.stream_stride))
    else:
        pat = ((self.stream_size, self.act_strides[-1]),)

    self.stream.set_source_pattern(comp_fsm, name, act_ram, local, pat)
    comp_fsm.set_index(comp_fsm.current - 1)

    # set_sink
    name = list(self.stream.sinks.keys())[0]
    local = stream_out_local + out_page_comp_offset_buf

    if len(self.out_shape) > 1:
        pat = ((self.stream_size, 1),
               (self.out_shape[-2], self.stream_size))
    else:
        pat = ((self.stream_size, 1),)

    self.stream.set_sink_pattern(comp_fsm, name, out_ram, local, pat)

    # stream run (async)
    self.stream.run(comp_fsm)

    comp_fsm.goto_init()

    # sync with WriteOut control
    comp_fsm.seq.If(fsm.state == state_init)(
        comp_count(0)
    )
    comp_fsm.seq.If(self.stream.end_flag)(
        comp_count.inc()
    )

    # --------------------
    # WriteOut phase
    # --------------------
    state_write_out = fsm.current

    # sync with Comp control: wait until a computation not yet written out
    # has finished.
    fsm.If(comp_count > out_count).goto_next()

    out_laddr = out_page_dma_offset
    out_gaddr = self.objaddr + out_offset

    bt.bus_lock(self.maxi, fsm)

    bt.dma_write(self.maxi, fsm, out_ram, out_laddr,
                 out_gaddr, self.out_write_size, port=1, use_async=True)

    bt.bus_unlock(self.maxi, fsm)

    fsm(
        out_count.inc()
    )

    fsm.goto_next()

    state_write_out_end = fsm.current
    fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

    # --------------------
    # update for next iteration
    # --------------------
    # ReadAct: count
    # Mixed-radix increment from the innermost dimension outward.  cond
    # starts as None; presumably fsm.If(None) is unconditional, so the
    # innermost counter advances every iteration.
    cond = None
    for size, count in zip(reversed(self.out_shape[:-2]), reversed(counts)):

        fsm.If(cond)(
            count.inc()
        )
        fsm.If(cond, count >= size - 1)(
            count(0)
        )
        if cond is not None:
            cond = vg.Ands(cond, count >= size - 1)
        else:
            cond = count >= size - 1

    # Value each act_offset returns to when its counter wraps: the begin
    # offset it was loaded with in the initialization phase (0 for any
    # offset without a begin value, which initialization leaves at 0).
    act_offset_resets = list(self.act_offset_begins)[:len(act_offsets)]
    act_offset_resets += [0] * (len(act_offsets) - len(act_offset_resets))

    # ReadAct: offset
    cond = None
    for size, count, act_offset, act_offset_stride, act_offset_reset in zip(
            reversed(self.out_shape[:-2]), reversed(counts),
            reversed(act_offsets), reversed(self.act_offset_strides),
            reversed(act_offset_resets)):

        fsm.If(cond)(
            act_offset.add(act_offset_stride)
        )
        # Wrap back to the dimension's begin offset, not to 0; resetting to
        # 0 would drop a non-zero begin for every outer iteration after the
        # first.
        fsm.If(cond, count >= size - 1)(
            act_offset(act_offset_reset)
        )
        if cond is not None:
            cond = vg.Ands(cond, count >= size - 1)
        else:
            cond = count >= size - 1

    # ReadAct and Comp: double buffer
    fsm.If(vg.Not(act_page))(
        act_page_comp_offset(act_page_size),
        act_page_dma_offset(act_page_size),
        act_page(1)
    )
    fsm.If(act_page)(
        act_page_comp_offset(0),
        act_page_dma_offset(0),
        act_page(0)
    )

    # WriteOut: offset (driven by the previous iteration's counters, and only
    # when a write actually happened)
    cond = vg.Not(skip_write_out)
    for size, prev_count, out_offset, out_offset_stride in zip(
            reversed(self.out_shape[:-2]), reversed(prev_counts),
            reversed(out_offsets), reversed(self.out_offset_strides)):

        fsm.If(cond)(
            out_offset.add(out_offset_stride)
        )
        fsm.If(cond, prev_count >= size - 1)(
            out_offset(0)
        )
        cond = vg.Ands(cond, prev_count >= size - 1)

    # WriteOut and Comp: double buffer (comp and DMA use opposite pages)
    fsm.If(vg.Not(out_page))(
        out_page_comp_offset(out_page_size),
        out_page_dma_offset(0),
        out_page(1)
    )
    fsm.If(out_page)(
        out_page_comp_offset(0),
        out_page_dma_offset(out_page_size),
        out_page(0)
    )

    # ReadAct and WriteOut: prev
    for count, prev_count in zip(counts, prev_counts):
        fsm(
            prev_count(count)
        )

    # ReadAct, Comp, WriteOut: skip
    cond_skip_read_act = None
    cond_skip_comp = None
    for size, count in zip(reversed(self.out_shape[:-2]), reversed(counts)):
        if cond_skip_read_act is not None:
            cond_skip_read_act = vg.Ands(cond_skip_read_act, count >= size - 1)
        else:
            cond_skip_read_act = count >= size - 1

    cond_skip_comp = cond_skip_read_act

    cond_cancel_write_out = None
    for size, prev_count in zip(reversed(self.out_shape[:-2]), reversed(prev_counts)):
        if cond_cancel_write_out is not None:
            cond_cancel_write_out = vg.Ands(cond_cancel_write_out, prev_count == 0)
        else:
            cond_cancel_write_out = prev_count == 0

    cond_done = None
    for size, prev_count in zip(reversed(self.out_shape[:-2]), reversed(prev_counts)):
        if cond_done is not None:
            cond_done = vg.Ands(cond_done, prev_count >= size - 1)
        else:
            cond_done = prev_count >= size - 1

    fsm.If(cond_skip_read_act)(
        skip_read_act(1)
    )
    fsm.If(cond_skip_comp)(
        skip_comp(1)
    )
    fsm.If(skip_write_out, cond_cancel_write_out)(
        skip_write_out(0)
    )

    # Loop back, unless the last output has just been written.
    fsm.goto(state_read_act)
    fsm.If(vg.Not(skip_write_out), cond_done).goto_next()

    # wait for last DMA write
    bt.dma_wait_write(self.maxi, fsm)
def control_sequence(self, fsm):
    """Build the main control FSM: read each input, run the stream, write output.

    States are appended to ``fsm``.  The Read -> Comp -> Write -> update
    sequence runs ``self.num_comp`` computing passes plus one final drain
    pass (``comp_count`` is compared before its increment takes effect):

    * initialization -- reset global addresses, counters, wrap counters,
      double-buffer pages and skip flags.  ``skip_write`` starts at 1, so the
      first pass writes nothing.
    * Read -- for each input RAM: a normal DMA read of ``self.dma_size``
      words, or, when ``wrap_mode == 2``, a single-word read that is done
      only while ``wrap_count == 0`` and skipped afterwards.
    * Comp -- program the stream constants (dup flag), sources and sinks,
      then start ``self.stream``.  As written, ``self.stream.join`` is only
      reached when ``skip_comp`` is set; otherwise it is bypassed.
    * Write -- repeatedly DMA-write (async) ``self.dma_size`` words from the
      current out page, stepping ``out_gaddr_offset`` over
      ``out_pos_col``/``out_pos_row`` until both reach their maxima.
    * update -- advance addresses and counters, flip double buffers, set the
      skip flags and loop back to Read while ``comp_count < self.num_comp``.
    * finally wait for the last asynchronous DMA write.

    Args:
        fsm: the main control FSM (presumably a veriloggen ``FSM``).
    """
    sources = self.collect_sources()

    arg_gaddrs = [
        self.m.Reg(self._name('arg_gaddr_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]
    out_gaddr = self.m.Reg(self._name('out_gaddr'),
                           self.maxi.addrwidth, initval=0)
    out_gaddr_offset = self.m.Reg(self._name('out_gaddr_offset'),
                                  self.maxi.addrwidth, initval=0)
    out_pos_col = self.m.Reg(self._name('out_pos_col'),
                             self.maxi.addrwidth + 1, initval=0)
    out_pos_row = self.m.Reg(self._name('out_pos_row'),
                             self.maxi.addrwidth + 1, initval=0)
    out_col_count = self.m.Reg(self._name('out_col_count'),
                               self.maxi.addrwidth + 1, initval=0)
    comp_count = self.m.Reg(self._name('comp_count'),
                            self.maxi.addrwidth + 1, initval=0)
    wrap_counts = [
        self.m.Reg(self._name('wrap_count_%d' % i),
                   self.maxi.addrwidth + 1, initval=0)
        for i, arg in enumerate(sources)
    ]

    arg_pages = [
        self.m.Reg(self._name('arg_page_%d' % i), initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]
    arg_page_comp_offsets = [
        self.m.Reg(self._name('arg_page_comp_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]
    arg_page_dma_offsets = [
        self.m.Reg(self._name('arg_page_dma_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]

    out_page = self.m.Reg(self._name('out_page'), initval=0)
    out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    # NOTE(review): the input page size is derived from output_rams[0], not
    # from the input RAMs; presumably all RAMs share one length -- confirm.
    arg_page_size = self.output_rams[0].length // 2
    out_page_size = self.output_rams[0].length // 2

    skip_read = self.m.Reg(self._name('skip_read'), initval=0)
    skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
    skip_write = self.m.Reg(self._name('skip_write'), initval=0)

    # --------------------
    # initialization phase
    # --------------------
    fsm([arg_gaddr(0) for arg_gaddr in arg_gaddrs])

    fsm(comp_count(0), out_gaddr(0), out_gaddr_offset(0), out_pos_col(0),
        out_pos_row(0), out_col_count(0),
        [wrap_count(0) for wrap_count in wrap_counts])

    fsm([arg_page(0) for arg_page in arg_pages],
        [arg_page_comp_offset(0)
         for arg_page_comp_offset in arg_page_comp_offsets],
        [arg_page_dma_offset(0)
         for arg_page_dma_offset in arg_page_dma_offsets])

    # Output comp and DMA offsets start on opposite pages.
    fsm(out_page(0), out_page_comp_offset(0),
        out_page_dma_offset(out_page_size))

    fsm(skip_read(0), skip_comp(0), skip_write(1))

    fsm.goto_next()

    # --------------------
    # Read phase
    # --------------------
    state_read = fsm.current

    # DMA read -> Stream run -> Stream wait -> DMA write
    for (ram, arg_objaddr, arg_gaddr, arg_page_dma_offset,
         wrap_mode, wrap_count, arg) in zip(self.input_rams, self.arg_objaddrs,
                                            arg_gaddrs, arg_page_dma_offsets,
                                            self.wrap_modes, wrap_counts,
                                            sources):

        b = fsm.current
        fsm.goto_next()

        # normal
        laddr = arg_page_dma_offset
        gaddr = arg_objaddr + arg_gaddr
        bt.bus_lock(self.maxi, fsm)
        bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, self.dma_size)
        bt.bus_unlock(self.maxi, fsm)
        fsm.goto_next()

        b_stride0 = fsm.current
        fsm.goto_next()

        # stride-0
        bt.bus_lock(self.maxi, fsm)
        bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, 1)
        bt.bus_unlock(self.maxi, fsm)
        fsm.goto_next()

        # for reuse: wrap_mode == 2 skips both reads once wrap_count > 0,
        # and otherwise takes only the single-word read; any other mode
        # takes only the normal read.
        e = fsm.current
        fsm.If(wrap_mode == 2, wrap_count > 0).goto_from(b, e)
        fsm.If(wrap_mode == 2, wrap_count == 0).goto_from(b, b_stride0)
        fsm.If(wrap_mode != 2).goto_from(b_stride0, e)

    state_read_end = fsm.current
    fsm.If(skip_read).goto_from(state_read, state_read_end)

    # --------------------
    # Comp phase
    # --------------------
    state_comp = fsm.current

    self.stream.source_join(fsm)

    # set_source, set_constant (dup)
    # set_index(current - 1) after each set_* keeps all of them in one state.
    for (source_name, dup_name, arg_page_comp_offset,
         ram, wrap_mode) in zip(self.stream.sources.keys(),
                                self.stream.constants.keys(),
                                arg_page_comp_offsets,
                                self.input_rams, self.wrap_modes):

        read_laddr = arg_page_comp_offset
        read_size = self.dma_size
        stride = vg.Mux(wrap_mode == 2, 0, 1)
        dup = vg.Mux(wrap_mode == 2, 1, 0)
        self.stream.set_constant(fsm, dup_name, dup)
        fsm.set_index(fsm.current - 1)
        self.stream.set_source(fsm, source_name, ram, read_laddr, read_size,
                               stride)
        fsm.set_index(fsm.current - 1)

    # set_sink
    write_laddr = out_page_comp_offset
    write_size = self.dma_size

    for name, ram in zip(self.stream.sinks.keys(), self.output_rams):
        self.stream.set_sink(fsm, name, ram, write_laddr, write_size)
        fsm.set_index(fsm.current - 1)

    fsm.goto_next()

    self.stream.run(fsm)

    state_comp_end = fsm.current

    self.stream.join(fsm)

    state_comp_end_join = fsm.current

    # With skip_comp the setup/run is bypassed but the join is executed;
    # without it the run is issued and the join is bypassed.
    fsm.If(skip_comp).goto_from(state_comp, state_comp_end)
    fsm.If(vg.Not(skip_comp)).goto_from(state_comp_end, state_comp_end_join)

    # --------------------
    # Write phase
    # --------------------
    state_write = fsm.current

    laddr = out_page_dma_offset
    gaddr_base = self.objaddr + out_gaddr

    bt.bus_lock(self.maxi, fsm)

    b = fsm.current

    gaddr = gaddr_base + out_gaddr_offset
    bt.dma_write(self.maxi, fsm, self.output_rams[0], laddr, gaddr,
                 self.dma_size, use_async=True)

    fsm(
        out_pos_col.inc(),
        out_gaddr_offset.add(self.out_col_step),
    )
    fsm.If(out_pos_col == self.max_out_pos_col)(
        out_pos_col(0),
        out_pos_row.inc(),
        out_gaddr_offset.add(self.out_row_step))

    # Repeat the write until the last (col, row) position has been written.
    fsm.goto(b)
    fsm.If(out_pos_col == self.max_out_pos_col,
           out_pos_row == self.max_out_pos_row).goto_next()

    bt.bus_unlock(self.maxi, fsm)

    fsm.goto_next()

    state_write_end = fsm.current
    fsm.If(skip_write).goto_from(state_write, state_write_end)

    # --------------------
    # update for next iteration
    # --------------------
    fsm(comp_count.inc())

    fsm(out_gaddr_offset(0), out_pos_col(0), out_pos_row(0))

    fsm.If(vg.Not(skip_write))(
        out_gaddr.add(self.out_col_inc),
        out_col_count.inc(),
    )
    fsm.If(vg.Not(skip_write), out_col_count == self.max_out_col_count)(
        out_gaddr.add(self.out_row_inc),
        out_col_count(0),
    )

    for (arg_gaddr, arg_addr_inc, arg_page, arg_page_comp_offset,
         arg_page_dma_offset, wrap_mode, wrap_size,
         wrap_count, arg) in zip(arg_gaddrs, self.arg_addr_incs, arg_pages,
                                 arg_page_comp_offsets, arg_page_dma_offsets,
                                 self.wrap_modes, self.wrap_sizes,
                                 wrap_counts, sources):

        fsm.If(wrap_mode == 2)(wrap_count(1))

        fsm.If(wrap_mode == 1)(arg_gaddr.add(arg_addr_inc), wrap_count.inc())
        fsm.If(wrap_mode == 1, wrap_count == wrap_size - 1)(arg_gaddr(0),
                                                            wrap_count(0))

        fsm.If(wrap_mode == 0)(arg_gaddr.add(arg_addr_inc))

        # Flip this input's double buffer (inputs with wrap_mode == 2 stay
        # on page 0).  The Read phase fills the page the Comp phase then
        # reads, so the DMA and comp offsets must move together: both use
        # arg_page_size so they cannot drift apart if the input and output
        # page sizes ever differ.
        fsm.If(vg.Not(arg_page), wrap_mode != 2)(
            arg_page_comp_offset(arg_page_size),
            arg_page_dma_offset(arg_page_size),
            arg_page(1))
        fsm.If(arg_page, wrap_mode != 2)(
            arg_page_comp_offset(0),
            arg_page_dma_offset(0),
            arg_page(0))

    # Output comp and DMA offsets are always on opposite pages.
    fsm.If(vg.Not(out_page))(
        out_page_comp_offset(out_page_size),
        out_page_dma_offset(0),
        out_page(1))
    fsm.If(out_page)(
        out_page_comp_offset(0),
        out_page_dma_offset(out_page_size),
        out_page(0))

    fsm(skip_write(0))

    # comp_count here is the value before this state's increment.
    fsm.If(comp_count == self.num_comp - 1)(skip_read(1), skip_comp(1))

    fsm.If(comp_count < self.num_comp).goto(state_read)
    fsm.If(comp_count == self.num_comp).goto_next()

    # wait for last DMA write
    bt.dma_wait_write(self.maxi, fsm)
def control_sequence(self, fsm):
    """Build the main control FSM for an operator parameterized by ``self.ksize``.

    States are appended to ``fsm``.  A separate ``comp_fsm`` drives
    ``self.stream``.  Phases:

    * initialization -- reset row/batch offsets and counters, set every
      ``dma_flag`` to 1, and reset the per-row act pages, the out page and
      the skip flags.  ``skip_write_out`` starts at 1, so the first
      iteration writes nothing.
    * ReadAct -- for each of the ``ksize_row`` rows of ``act_rams``
      (presumably ``line_to_2d`` groups them ``ksize_col`` per row),
      DMA-read into that row's act page.  A row's read is skipped when its
      muxed pad mask is set or its muxed ``dma_flag`` is 0.
    * Comp -- ``comp_fsm`` latches page offsets, ``row_count`` and
      ``row_select``.  It then repeats pad-mask computation, stream
      constant/source/sink setup and an async stream run, stepping
      ``col_count`` by ``self.stride_col``, until
      ``col_count >= self.max_col_count``.
    * WriteOut -- once ``comp_count >= out_count + self.out_write_size``,
      DMA-write (async) ``self.out_write_size`` words from the out page to
      ``self.objaddr + out_base_offset``.
    * update -- advance row/batch offsets, counters, ``row_select`` and the
      dma flags.  Flip each act page whose muxed next dma flag is set, and
      flip the out page.  Update the skip flags and loop back to ReadAct
      until the last row/batch has been written.
    * finally wait for the last asynchronous DMA write.

    Side effect: sets ``self.stride_bat = 1``.

    Args:
        fsm: the main control FSM (presumably a veriloggen ``FSM``).
    """
    ksize_ch = self.ksize[-1]  # NOTE(review): not used in this method
    ksize_col = self.ksize[-2]
    ksize_row = self.ksize[-3]
    ksize_bat = self.ksize[-4]  # NOTE(review): not used in this method

    # Mutates instance state; bat_count below steps by self.stride_bat.
    self.stride_bat = 1

    act_rams = self.input_rams
    out_ram = self.output_rams[0]

    # act_base_offset = row offset + batch offset
    act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                  self.maxi.addrwidth, signed=True)

    act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)
    act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)

    act_base_offset.assign(act_base_offset_row + act_base_offset_bat)

    # out_base_offset = row offset + batch offset
    out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                  self.maxi.addrwidth, signed=True)

    out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)
    out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                     self.maxi.addrwidth, initval=0,
                                     signed=True)

    out_base_offset.assign(out_base_offset_row + out_base_offset_bat)

    # One flag per kernel row; a row is DMA-read only while its flag is 1.
    dma_flags = [
        self.m.Reg(self._name('dma_flag_%d' % i), initval=0)
        for i in range(ksize_row)
    ]

    col_count = self.m.Reg(self._name('col_count'),
                           self.maxi.addrwidth, initval=0)
    row_count = self.m.Reg(self._name('row_count'),
                           self.maxi.addrwidth, initval=0)
    bat_count = self.m.Reg(self._name('bat_count'),
                           self.maxi.addrwidth, initval=0)
    col_select = self.m.Reg(self._name('col_select'),
                            bt.log_width(ksize_col), initval=0)
    row_select = self.m.Reg(self._name('row_select'),
                            bt.log_width(ksize_row), initval=0)

    # Previous-iteration copies used by the WriteOut side and by the
    # ReadAct dma_flag mux.
    prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                self.maxi.addrwidth, initval=0)
    prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                self.maxi.addrwidth, initval=0)
    prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                 bt.log_width(ksize_row), initval=0)

    stream_act_locals = [
        self.m.Reg(self._name('stream_act_local_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(len(act_rams))
    ]
    stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                  self.maxi.addrwidth, initval=0)

    # double buffer control
    act_pages = [
        self.m.Reg(self._name('act_page_%d' % i), initval=0)
        for i in range(ksize_row)
    ]
    act_page_comp_offsets = [
        self.m.Reg(self._name('act_page_comp_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]
    act_page_dma_offsets = [
        self.m.Reg(self._name('act_page_dma_offset_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]

    out_page = self.m.Reg(self._name('out_page'), initval=0)
    out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    # Each RAM is split into two pages for double buffering.
    act_page_size = act_rams[0].length // 2
    out_page_size = out_ram.length // 2

    skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
    skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
    skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

    # comp_count: advanced by comp_fsm on each stream end_flag;
    # out_count: advanced after each output DMA write.
    comp_count = self.m.Reg(self._name('comp_count'),
                            self.maxi.addrwidth, initval=0)
    out_count = self.m.Reg(self._name('out_count'),
                           self.maxi.addrwidth, initval=0)

    # --------------------
    # initialization phase
    # --------------------
    # ReadAct: offset
    fsm(act_base_offset_row(0), act_base_offset_bat(0))

    # One global-address expression per entry of self.act_offset_values.
    act_offsets = []
    for v in self.act_offset_values:
        act_offset = act_base_offset + v
        act_offsets.append(act_offset)

    # ReadAct: DMA flag
    for y, dma_flag in enumerate(dma_flags):
        fsm(dma_flag(1))

    # Row y is masked when row_count + y falls outside
    # [pad_row_top, act_num_row + pad_row_top).
    dma_pad_masks = []

    for y in range(ksize_row):
        v = vg.Ors((row_count + y < self.pad_row_top),
                   (row_count + y >= self.act_num_row + self.pad_row_top))
        dma_pad_mask = self.m.Wire(self._name('dma_pad_mask_%d' % y))
        dma_pad_mask.assign(v)
        dma_pad_masks.append(dma_pad_mask)

    # ReadAct: double buffer control
    fsm([act_page(0) for act_page in act_pages],
        [act_page_comp_offset(0)
         for act_page_comp_offset in act_page_comp_offsets],
        [act_page_dma_offset(0)
         for act_page_dma_offset in act_page_dma_offsets])

    # WriteOutput: offset
    fsm(out_base_offset_row(0), out_base_offset_bat(0))

    out_offset = out_base_offset

    # WriteOut: double buffer control
    fsm(out_page(0), out_page_comp_offset(0), out_page_dma_offset(0))

    # counter
    fsm(row_count(0), bat_count(0), row_select(0), prev_row_count(0),
        prev_bat_count(0), prev_row_select(0))

    # double buffer control
    fsm(skip_read_act(0), skip_comp(0), skip_write_out(1))

    fsm(out_count(0))

    state_init = fsm.current

    fsm.goto_next()

    # --------------------
    # ReadAct phase
    # --------------------
    state_read_act = fsm.current

    act_gaddrs = []
    for act_offset in act_offsets:
        act_gaddr = self.arg_objaddrs[0] + act_offset
        act_gaddrs.append(act_gaddr)

    act_rams_2d = line_to_2d(act_rams, ksize_col)

    # Rotate the per-row addresses/masks by row_select (presumably what
    # mux_1d does -- confirm) so each RAM row receives the right source row.
    mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
    mux_act_gaddrs = []
    for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
        mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                    self.maxi.addrwidth)
        mux_act_gaddr.assign(mux_act_gaddr_value)
        mux_act_gaddrs.append(mux_act_gaddr)

    mux_dma_pad_mask_values = mux_1d(dma_pad_masks, row_select, ksize_row)
    mux_dma_pad_masks = []
    for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
        mux_dma_pad_mask = self.m.Wire(
            self._name('mux_dma_pad_mask_%d' % i))
        mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
        mux_dma_pad_masks.append(mux_dma_pad_mask)

    # determined at the previous phase
    mux_dma_flag_values = mux_1d(dma_flags, prev_row_select, ksize_row)
    mux_dma_flags = []
    for i, mux_dma_flag_value in enumerate(mux_dma_flag_values):
        mux_dma_flag = self.m.Wire(self._name('mux_dma_flag_%d' % i))
        mux_dma_flag.assign(mux_dma_flag_value)
        mux_dma_flags.append(mux_dma_flag)

    bt.bus_lock(self.maxi, fsm)

    for (act_rams_row, act_gaddr, act_page_dma_offset,
         dma_pad_mask, dma_flag) in zip(act_rams_2d, mux_act_gaddrs,
                                        act_page_dma_offsets,
                                        mux_dma_pad_masks, mux_dma_flags):

        act_laddr = act_page_dma_offset

        begin_state_read = fsm.current
        fsm.goto_next()

        if len(act_rams_row) == 1:
            bt.dma_read(self.maxi, fsm, act_rams_row[0], act_laddr,
                        act_gaddr, self.act_read_size, port=1)
        else:
            bt.dma_read_block(self.maxi, fsm, act_rams_row, act_laddr,
                              act_gaddr, self.act_read_size,
                              self.act_read_block, port=1)

        end_state_read = fsm.current

        # Skip this row's read when it is padding or already buffered.
        fsm.If(vg.Ors(dma_pad_mask, vg.Not(dma_flag))).goto_from(
            begin_state_read, end_state_read)

    bt.bus_unlock(self.maxi, fsm)

    fsm.goto_next()

    state_read_act_end = fsm.current
    fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

    # --------------------
    # Comp phase
    # --------------------
    state_comp = fsm.current

    # Stream Control FSM
    comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

    # Handshake: comp_fsm leaves its initial state while the main FSM is in
    # state_comp (unless skip_comp), and the main FSM leaves state_comp only
    # while comp_fsm is in its initial state.
    comp_state_init = comp_fsm.current
    comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

    fsm.If(comp_fsm.state == comp_state_init).goto_next()

    # local address
    stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)
    for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
        for x, stream_act_local in enumerate(stream_act_locals_row):
            comp_fsm(stream_act_local(0))
            comp_fsm.If(self.stream_act_local_small_flags[x])(
                stream_act_local(self.stream_act_local_small_offset))
            comp_fsm.If(self.stream_act_local_large_flags[x])(
                stream_act_local(self.stream_act_local_large_offset))

    comp_fsm(stream_out_local(0))

    # count and sel
    comp_fsm(col_count(0))
    comp_fsm(col_select(self.col_select_initval))

    # Snapshots taken at comp start so the main FSM may update the live
    # registers while the stream is still running.
    act_page_comp_offset_bufs = [
        self.m.Reg(self._name('act_page_comp_offset_buf_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i in range(ksize_row)
    ]
    out_page_comp_offset_buf = self.m.Reg(
        self._name('out_page_comp_offset_buf'), self.maxi.addrwidth,
        initval=0)
    row_count_buf = self.m.Reg(self._name('row_count_buf'),
                               self.maxi.addrwidth, initval=0)
    row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                bt.log_width(ksize_row), initval=0)

    comp_fsm([
        act_page_comp_offset_buf(act_page_comp_offset)
        for act_page_comp_offset_buf, act_page_comp_offset in zip(
            act_page_comp_offset_bufs, act_page_comp_offsets)
    ], out_page_comp_offset_buf(out_page_comp_offset),
        row_count_buf(row_count), row_select_buf(row_select))

    comp_fsm.goto_next()

    # repeat comp
    comp_state_rep = comp_fsm.current

    # pad_mask: element (y, x) of the window is padding when its column or
    # row falls outside the unpadded activation range.
    stream_pad_masks = []

    for y in range(ksize_row):
        for x in range(ksize_col):
            stream_col_count = col_count + x
            stream_row_count = row_count_buf + y
            v = vg.Ors(
                (stream_col_count < self.pad_col_left),
                (stream_col_count >= self.act_num_col + self.pad_col_left),
                (stream_row_count < self.pad_row_top),
                (stream_row_count >= self.act_num_row + self.pad_row_top))
            stream_pad_mask = self.m.Wire(
                self._name('stream_pad_mask_%d_%d' % (y, x)))
            stream_pad_mask.assign(v)
            stream_pad_masks.append(stream_pad_mask)

    stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
    stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d, col_select,
                                    row_select_buf, ksize_col, ksize_row)
    stream_pad_masks = [
        flatten for inner in stream_pad_mask_2d_mux for flatten in inner
    ]

    # Pack the masks into one register; reversed so element 0 is the LSB
    # (assuming vg.Cat puts its first argument in the MSBs -- confirm).
    stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                      len(stream_pad_masks), initval=0)
    comp_fsm(stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks))))
    comp_fsm.goto_next()

    # busy check
    self.stream.source_join(comp_fsm)

    stream_masks = stream_pad_masks_reg

    # set_constant
    name = list(self.stream.constants.keys())[0]
    self.stream.set_constant(comp_fsm, name, stream_masks)
    comp_fsm.set_index(comp_fsm.current - 1)

    # set_source
    # Each row's page offset is shared by the ksize_col RAMs of that row.
    act_page_comp_offset_bufs_dup = []
    for act_page_comp_offset_buf in act_page_comp_offset_bufs:
        act_page_comp_offset_bufs_dup.extend([act_page_comp_offset_buf] *
                                             ksize_col)

    for name, ram, stream_act_local, act_page_comp_offset_buf in zip(
            self.stream.sources.keys(), act_rams, stream_act_locals,
            act_page_comp_offset_bufs_dup):
        local = stream_act_local + act_page_comp_offset_buf
        self.stream.set_source(comp_fsm, name, ram, local, self.stream_size)
        comp_fsm.set_index(comp_fsm.current - 1)

    # set_sink
    name = list(self.stream.sinks.keys())[0]
    local = stream_out_local + out_page_comp_offset_buf
    self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

    # stream run (async)
    self.stream.run(comp_fsm)

    # stream_act_local
    # For each RAM, pick the per-column condition selected by col_select to
    # choose between the small and large local-address increments; reset
    # when the column loop finishes.
    stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)
    i = 0
    for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
        for x, stream_act_local in enumerate(stream_act_locals_row):
            patterns = []
            for col in range(ksize_col):
                val = self.inc_act_laddr_conds[i]
                i += 1
                pat = (col_select == col, val)
                patterns.append(pat)

            patterns.append((None, 0))
            v = vg.PatternMux(*patterns)

            comp_fsm.If(vg.Not(v))(stream_act_local.add(
                self.inc_act_laddr_small))
            comp_fsm.If(v)(stream_act_local.add(self.inc_act_laddr_large))

            comp_fsm.If(col_count >= self.max_col_count)(
                stream_act_local(0))
            comp_fsm.If(col_count >= self.max_col_count,
                        self.stream_act_local_small_flags[x])(
                stream_act_local(self.stream_act_local_small_offset))
            comp_fsm.If(col_count >= self.max_col_count,
                        self.stream_act_local_large_flags[x])(
                stream_act_local(self.stream_act_local_large_offset))

    # stream_out_local
    comp_fsm(stream_out_local.add(self.inc_out_laddr))
    comp_fsm.If(col_count >= self.max_col_count)(stream_out_local(0))

    # counter
    comp_fsm(col_count.add(self.stride_col))
    comp_fsm.If(col_count >= self.max_col_count)(col_count(0))

    # col_select advances modulo ksize_col.
    comp_fsm(col_select.add(self.stride_col_mod_ksize))
    comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
        col_select.sub(self.ksize_col_minus_stride_col_mod))

    comp_fsm.If(col_count >= self.max_col_count)(
        col_select(self.col_select_initval),
    )

    # repeat
    comp_fsm.goto(comp_state_rep)
    comp_fsm.If(col_count >= self.max_col_count).goto_init()

    # sync with WriteOut control
    comp_fsm.seq.If(fsm.state == state_init)(comp_count(0))
    comp_fsm.seq.If(self.stream.end_flag)(comp_count.add(
        self.inc_out_laddr))

    # --------------------
    # WriteOut phase
    # --------------------
    state_write_out = fsm.current

    # sync with Comp control: wait until enough output is computed for the
    # next write of out_write_size words.
    fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

    out_laddr = out_page_dma_offset
    out_gaddr = self.objaddr + out_offset

    bt.bus_lock(self.maxi, fsm)

    bt.dma_write(self.maxi, fsm, out_ram, out_laddr, out_gaddr,
                 self.out_write_size, port=1, use_async=True)

    bt.bus_unlock(self.maxi, fsm)

    fsm(out_count.add(self.out_write_size))

    fsm.goto_next()

    state_write_out_end = fsm.current
    fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

    # --------------------
    # update for next iteration
    # --------------------
    # ReadAct: offset
    fsm(act_base_offset_row.add(self.act_row_step))
    fsm.If(row_count >= self.max_row_count)(
        act_base_offset_row(0), act_base_offset_bat.add(self.act_bat_step))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(act_base_offset_bat(0))

    # ReadAct: DMA flag
    # next_dma_flags mirrors the value each flag register is assigned here,
    # as a combinational expression usable in this same state.
    next_dma_flags = []
    for dma_flag, dma_flag_cond in zip(dma_flags, self.dma_flag_conds):
        fsm(dma_flag(dma_flag_cond))
        fsm.If(row_count >= self.max_row_count)(dma_flag(1))

        next_dma_flags.append(
            vg.Mux(row_count >= self.max_row_count, 1, dma_flag_cond))

    # ReadAct: counter
    fsm(row_count.add(self.stride_row))
    fsm.If(row_count >= self.max_row_count)(row_count(0),
                                            bat_count.add(self.stride_bat))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(bat_count(0))

    # row_select advances by stride_row modulo ksize_row (0 when the stride
    # is not smaller than the kernel, and 0 again on a row wrap).
    fsm.If(self.stride_row < ksize_row)(row_select.add(self.stride_row),
                                        prev_row_select(row_select))
    fsm.If(self.stride_row < ksize_row,
           row_select + self.stride_row >= ksize_row)(
        row_select(row_select - (vg.Int(ksize_row) - self.stride_row)),
        prev_row_select(row_select))
    fsm.If(vg.Not(self.stride_row < ksize_row))(row_select(0),
                                                prev_row_select(0))
    fsm.If(row_count >= self.max_row_count)(row_select(0),
                                            prev_row_select(0))

    # ReadAct and Comp: double buffer
    # Only rows that will be DMA-read in the next iteration flip their page.
    mux_next_dma_flag_values = mux_1d(next_dma_flags, row_select, ksize_row)
    mux_next_dma_flags = []
    for i, mux_next_dma_flag_value in enumerate(mux_next_dma_flag_values):
        mux_next_dma_flag = self.m.Wire(
            self._name('mux_next_dma_flag_%d' % i))
        mux_next_dma_flag.assign(mux_next_dma_flag_value)
        mux_next_dma_flags.append(mux_next_dma_flag)

    for (act_page, act_page_comp_offset, act_page_dma_offset,
         mux_next_dma_flag) in zip(act_pages, act_page_comp_offsets,
                                   act_page_dma_offsets, mux_next_dma_flags):
        fsm.If(vg.Not(act_page), mux_next_dma_flag)(
            act_page_comp_offset(act_page_size),
            act_page_dma_offset(act_page_size), act_page(1))
        fsm.If(act_page, mux_next_dma_flag)(act_page_comp_offset(0),
                                            act_page_dma_offset(0),
                                            act_page(0))

    # WriteOut: counter (uses the previous iteration's counters)
    fsm.If(vg.Not(skip_write_out))(out_base_offset_row.add(
        self.out_row_step))
    fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count)(
        out_base_offset_row(0), out_base_offset_bat.add(self.out_bat_step))
    fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
           prev_bat_count >= self.max_bat_count)(out_base_offset_bat(0))

    # WriteOut and Comp: double buffer (comp and DMA use opposite pages)
    fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                             out_page_dma_offset(0), out_page(1))
    fsm.If(out_page)(out_page_comp_offset(0),
                     out_page_dma_offset(out_page_size), out_page(0))

    # ReadAct and WriteOut: prev
    fsm(prev_row_count(row_count), prev_bat_count(bat_count))

    # ReadAct, Comp, WriteOut: skip
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(skip_read_act(1))
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(skip_comp(1))
    fsm.If(skip_write_out, prev_row_count == 0,
           prev_bat_count == 0)(skip_write_out(0))

    # Loop back, unless the last row/batch has just been written.
    fsm.goto(state_read_act)
    fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
           prev_bat_count >= self.max_bat_count).goto_next()

    # wait for last DMA write
    bt.dma_wait_write(self.maxi, fsm)
def control_sequence(self, fsm):
    """Build the ReadAct -> Comp -> WriteOut control FSM, repeated over rows and batches.

    Registers/wires are allocated on ``self.m`` and states are appended to
    ``fsm``:

    * init: clear base offsets, counters, page selectors and ``out_count``;
      ``skip_write_out`` starts at 1, so the first pass performs no write-out.
    * ReadAct: DMA-read one activation row per entry of ``act_offsets`` into
      the current ``act_ram`` page (at ``act_page_dma_offset``); a row whose
      ``dma_pad_mask`` is set (row index inside the top/bottom padding) is
      not read.  The whole phase is bypassed while ``skip_read_act`` is set.
    * Comp: a separate FSM ``comp_fsm`` drives ``self.stream`` column by
      column, reading ``act_ram`` and writing ``out_ram``.  The main FSM
      waits in ``state_comp`` until ``comp_fsm`` is in its initial state;
      ``comp_fsm`` leaves that state while the main FSM is in
      ``state_comp`` (unless ``skip_comp``), so the new computation then
      runs concurrently with the following phases.
    * WriteOut: once ``comp_count`` covers the next ``out_write_size``
      outputs, issue an async DMA write from ``out_page_dma_offset`` (the
      page the previous pass computed into).  Bypassed while
      ``skip_write_out`` is set.
    * update: advance row/batch counters and offsets, flip both double
      buffers, update skip flags and jump back to ReadAct, leaving the loop
      once the write-out for the last row/batch position has been issued.

    Finally waits for the last async DMA write.

    When ``self.no_reuse`` is false, ``row_select``/``col_select`` rotate
    which buffered row/column feeds which kernel position (through
    ``mux_1d``/``mux_2d``).

    :param fsm: the main control FSM to which states are appended.
    """
    # kernel size per dimension, indexed from the end: (..., bat, row, col, ch)
    # NOTE(review): ksize_ch and ksize_bat are assigned but not read below.
    ksize_ch = self.ksize[-1]
    ksize_col = self.ksize[-2]
    ksize_row = self.ksize[-3]
    ksize_bat = self.ksize[-4]
    # the batch counter always advances by 1 (read back below as self.stride_bat)
    self.stride_bat = 1

    act_ram = self.input_rams[0]
    out_ram = self.output_rams[0]

    # global activation offset = row part + batch part (combinational)
    act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                  self.maxi.addrwidth, signed=True)
    act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                     self.maxi.addrwidth, initval=0, signed=True)
    act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                     self.maxi.addrwidth, initval=0, signed=True)

    act_base_offset.assign(act_base_offset_row + act_base_offset_bat)

    # global output offset = row part + batch part (combinational)
    out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                  self.maxi.addrwidth, signed=True)
    out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                     self.maxi.addrwidth, initval=0, signed=True)
    out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                     self.maxi.addrwidth, initval=0, signed=True)

    out_base_offset.assign(out_base_offset_row + out_base_offset_bat)

    col_count = self.m.Reg(self._name('col_count'),
                           self.maxi.addrwidth, initval=0)
    row_count = self.m.Reg(self._name('row_count'),
                           self.maxi.addrwidth, initval=0)
    bat_count = self.m.Reg(self._name('bat_count'),
                           self.maxi.addrwidth, initval=0)
    # select registers exist only when rows/columns are reused; every later
    # use of them is guarded by the same `if not self.no_reuse`
    if not self.no_reuse:
        col_select = self.m.Reg(self._name('col_select'),
                                bt.log_width(ksize_col), initval=0)
        row_select = self.m.Reg(self._name('row_select'),
                                bt.log_width(ksize_row), initval=0)

    # counters of the previous pass; WriteOut works one pass behind ReadAct
    prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                self.maxi.addrwidth, initval=0)
    prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                self.maxi.addrwidth, initval=0)
    if not self.no_reuse:
        prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                     bt.log_width(ksize_row), initval=0)

    stream_act_local = self.m.Reg(self._name('stream_act_local'),
                                  self.maxi.addrwidth, initval=0)
    stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                  self.maxi.addrwidth, initval=0)

    # double buffer control
    act_page = self.m.Reg(self._name('act_page'), initval=0)
    act_page_comp_offset = self.m.Reg(self._name('act_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    act_page_dma_offset = self.m.Reg(self._name('act_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    out_page = self.m.Reg(self._name('out_page'), initval=0)
    out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                      self.maxi.addrwidth, initval=0)
    out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                     self.maxi.addrwidth, initval=0)

    # each double-buffered RAM is split into two halves (pages)
    act_page_size = act_ram.length // 2
    out_page_size = out_ram.length // 2

    # phase bypass flags (see the "skip" updates at the end of the loop)
    skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
    skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
    skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

    # comp_count: outputs produced by Comp; out_count: outputs written out
    comp_count = self.m.Reg(self._name('comp_count'),
                            self.maxi.addrwidth, initval=0)
    out_count = self.m.Reg(self._name('out_count'),
                           self.maxi.addrwidth, initval=0)

    # --------------------
    # initialization phase
    # --------------------
    # ReadAct: offset
    fsm(
        act_base_offset_row(0),
        act_base_offset_bat(0)
    )

    # one global read offset per entry of self.act_offset_values
    # (presumably one per kernel row, as they are muxed with ksize_row
    # entries below -- confirm)
    act_offsets = []
    for v in self.act_offset_values:
        act_offset = act_base_offset + v
        act_offsets.append(act_offset)

    # ReadAct: DMA flag
    # dma_pad_mask_y is true when row (row_count + y) lies in the top or
    # bottom padding; such rows are not DMA-read
    dma_pad_masks = []

    for y in range(ksize_row):
        v = vg.Ors((row_count + y < self.pad_row_top),
                   (row_count + y >= self.act_num_row + self.pad_row_top))
        dma_pad_mask = self.m.Wire(
            self._name('dma_pad_mask_%d' % y))
        dma_pad_mask.assign(v)
        dma_pad_masks.append(dma_pad_mask)

    # ReadAct: double buffer control
    fsm(
        act_page(0),
        act_page_comp_offset(0),
        act_page_dma_offset(0)
    )

    # WriteOutput: offset
    fsm(
        out_base_offset_row(0),
        out_base_offset_bat(0)
    )

    out_offset = out_base_offset

    # WriteOut: double buffer control
    fsm(
        out_page(0),
        out_page_comp_offset(0),
        out_page_dma_offset(0)
    )

    # counter
    fsm(
        row_count(0),
        bat_count(0),
        prev_row_count(0),
        prev_bat_count(0)
    )

    if not self.no_reuse:
        fsm(
            row_select(0),
            prev_row_select(0)
        )

    # double buffer control
    # skip_write_out starts at 1: the first pass has nothing to write yet
    fsm(
        skip_read_act(0),
        skip_comp(0),
        skip_write_out(1)
    )

    fsm(
        out_count(0)
    )

    state_init = fsm.current

    fsm.goto_next()

    # --------------------
    # ReadAct phase
    # --------------------
    state_read_act = fsm.current

    act_gaddrs = []
    for act_offset in act_offsets:
        act_gaddr = self.arg_objaddrs[0] + act_offset
        act_gaddrs.append(act_gaddr)

    # with reuse, rotate addresses/masks by row_select so each read lands
    # in the buffered row slot it replaces
    if not self.no_reuse:
        mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
        mux_act_gaddrs = []
        for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
            mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                        self.maxi.addrwidth)
            mux_act_gaddr.assign(mux_act_gaddr_value)
            mux_act_gaddrs.append(mux_act_gaddr)

        mux_dma_pad_mask_values = mux_1d(
            dma_pad_masks, row_select, ksize_row)
        mux_dma_pad_masks = []
        for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
            mux_dma_pad_mask = self.m.Wire(
                self._name('mux_dma_pad_mask_%d' % i))
            mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
            mux_dma_pad_masks.append(mux_dma_pad_mask)
    else:
        mux_act_gaddrs = act_gaddrs
        mux_dma_pad_masks = dma_pad_masks

    bt.bus_lock(self.maxi, fsm)

    act_laddr = act_page_dma_offset

    for (act_gaddr, dma_pad_mask) in zip(
            mux_act_gaddrs, mux_dma_pad_masks):
        begin_state_read = fsm.current
        fsm.goto_next()

        bt.dma_read(self.maxi, fsm, act_ram, act_laddr,
                    act_gaddr, self.act_read_size, port=1)

        end_state_read = fsm.current

        # padded row: jump over the DMA read states
        fsm.If(dma_pad_mask).goto_from(
            begin_state_read, end_state_read)

        # next row goes act_read_size further into the local page
        act_laddr += self.act_read_size

    bt.bus_unlock(self.maxi, fsm)

    fsm.goto_next()

    state_read_act_end = fsm.current
    fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

    # --------------------
    # Comp phase
    # --------------------
    state_comp = fsm.current

    # Stream Control FSM
    comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

    # comp_fsm starts while the main FSM sits in state_comp; the main FSM
    # only leaves state_comp while comp_fsm is idle in its initial state
    comp_state_init = comp_fsm.current
    comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

    fsm.If(comp_fsm.state == comp_state_init).goto_next()

    # local address
    comp_fsm(
        stream_act_local(self.local_pad_offset)
    )

    comp_fsm(
        stream_out_local(0)
    )

    # count and sel
    comp_fsm(
        col_count(0)
    )
    if not self.no_reuse:
        comp_fsm(
            col_select(self.col_select_initval)
        )

    # snapshot the main FSM's page offsets / row position so they stay
    # stable while the main FSM moves on to the next pass
    act_page_comp_offset_buf = self.m.Reg(self._name('act_page_comp_offset_buf'),
                                          self.maxi.addrwidth, initval=0)
    out_page_comp_offset_buf = self.m.Reg(self._name('out_page_comp_offset_buf'),
                                          self.maxi.addrwidth, initval=0)
    row_count_buf = self.m.Reg(self._name('row_count_buf'),
                               self.maxi.addrwidth, initval=0)
    if not self.no_reuse:
        row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                    bt.log_width(ksize_row), initval=0)

    comp_fsm(
        act_page_comp_offset_buf(act_page_comp_offset),
        out_page_comp_offset_buf(out_page_comp_offset),
        row_count_buf(row_count)
    )
    if not self.no_reuse:
        comp_fsm(
            row_select_buf(row_select)
        )

    comp_fsm.goto_next()

    # repeat comp
    comp_state_rep = comp_fsm.current

    # pad_mask
    # one flag per kernel tap (y, x): true when that tap falls in padding
    stream_pad_masks = []

    for y in range(ksize_row):
        for x in range(ksize_col):
            stream_col_count = col_count + x
            stream_row_count = row_count_buf + y
            v = vg.Ors((stream_col_count < self.pad_col_left),
                       (stream_col_count >= self.act_num_col + self.pad_col_left),
                       (stream_row_count < self.pad_row_top),
                       (stream_row_count >= self.act_num_row + self.pad_row_top))
            stream_pad_mask = self.m.Wire(
                self._name('stream_pad_mask_%d_%d' % (y, x)))
            stream_pad_mask.assign(v)
            stream_pad_masks.append(stream_pad_mask)

    if not self.no_reuse:
        stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
        stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d,
                                        col_select, row_select_buf,
                                        ksize_col, ksize_row)
        stream_pad_masks = [flatten for inner in stream_pad_mask_2d_mux
                            for flatten in inner]

    # pack all tap masks into one register (first mask in the LSB)
    stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                      len(stream_pad_masks), initval=0)
    comp_fsm(
        stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks)))
    )
    comp_fsm.goto_next()

    # busy check
    self.stream.source_join(comp_fsm)

    stream_masks = stream_pad_masks_reg

    # set_constant
    # NOTE(review): set_index(current - 1) presumably undoes the state
    # advance of each set_* call so the settings share one state -- confirm
    name = list(self.stream.constants.keys())[0]
    self.stream.set_constant(comp_fsm, name, ksize_col * ksize_row)
    comp_fsm.set_index(comp_fsm.current - 1)

    name = list(self.stream.constants.keys())[1]
    self.stream.set_constant(comp_fsm, name, stream_masks)
    comp_fsm.set_index(comp_fsm.current - 1)

    # set_source
    name = list(self.stream.sources.keys())[0]
    local = stream_act_local + act_page_comp_offset_buf
    pat = ((ksize_col, self.act_read_block),
           (ksize_row, self.act_read_size),
           (self.stream_size, 1))
    self.stream.set_source_pattern(comp_fsm, name, act_ram, local, pat)
    comp_fsm.set_index(comp_fsm.current - 1)

    # set_sink
    name = list(self.stream.sinks.keys())[0]
    local = stream_out_local + out_page_comp_offset_buf
    self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

    # stream run (async)
    self.stream.run(comp_fsm)

    # stream_act_local
    comp_fsm(
        stream_act_local.add(self.inc_act_laddr)
    )
    comp_fsm.If(col_count >= self.max_col_count)(
        stream_act_local(self.local_pad_offset)
    )

    # stream_out_local
    comp_fsm(
        stream_out_local.add(self.inc_out_laddr)
    )
    comp_fsm.If(col_count >= self.max_col_count)(
        stream_out_local(0)
    )

    # counter
    comp_fsm(
        col_count.add(self.stride_col)
    )
    comp_fsm.If(col_count >= self.max_col_count)(
        col_count(0)
    )

    # col_select advances modulo ksize_col
    if not self.no_reuse:
        comp_fsm(
            col_select.add(self.stride_col_mod_ksize)
        )
        comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
            col_select.sub(self.ksize_col_minus_stride_col_mod)
        )

        comp_fsm.If(col_count >= self.max_col_count)(
            col_select(self.col_select_initval)
        )

    # repeat
    # loop over columns; return to the initial (idle) state after the last one
    comp_fsm.goto(comp_state_rep)
    comp_fsm.If(col_count >= self.max_col_count).goto_init()

    # sync with WriteOut control
    # comp_count accumulates inc_out_laddr each time a stream run ends
    comp_fsm.seq.If(fsm.state == state_init)(
        comp_count(0)
    )
    comp_fsm.seq.If(self.stream.end_flag)(
        comp_count.add(self.inc_out_laddr)
    )

    # --------------------
    # WriteOut phase
    # --------------------
    state_write_out = fsm.current

    # sync with Comp control
    # wait until Comp has produced the next out_write_size outputs
    fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

    out_laddr = out_page_dma_offset
    out_gaddr = self.objaddr + out_offset

    bt.bus_lock(self.maxi, fsm)

    bt.dma_write(self.maxi, fsm, out_ram, out_laddr, out_gaddr,
                 self.out_write_size, port=1, use_async=True)

    bt.bus_unlock(self.maxi, fsm)

    fsm(
        out_count.add(self.out_write_size)
    )
    fsm.goto_next()

    state_write_out_end = fsm.current
    fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

    # --------------------
    # update for next iteration
    # --------------------
    # ReadAct: offset
    fsm(
        act_base_offset_row.add(self.act_row_step)
    )
    fsm.If(row_count >= self.max_row_count)(
        act_base_offset_row(0),
        act_base_offset_bat.add(self.act_bat_step)
    )
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(
        act_base_offset_bat(0)
    )

    # ReadAct: counter
    fsm(
        row_count.add(self.stride_row)
    )
    fsm.If(row_count >= self.max_row_count)(
        row_count(0),
        bat_count.add(self.stride_bat)
    )
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(
        bat_count(0)
    )

    # row_select advances by stride_row modulo ksize_row (reset to 0 when
    # the stride is not smaller than the kernel, or at the end of a row pass)
    if not self.no_reuse:
        fsm.If(self.stride_row < ksize_row)(
            row_select.add(self.stride_row),
            prev_row_select(row_select)
        )
        fsm.If(self.stride_row < ksize_row,
               row_select + self.stride_row >= ksize_row)(
            row_select(row_select - (vg.Int(ksize_row) - self.stride_row)),
            prev_row_select(row_select)
        )
        fsm.If(vg.Not(self.stride_row < ksize_row))(
            row_select(0),
            prev_row_select(0)
        )

        fsm.If(row_count >= self.max_row_count)(
            row_select(0),
            prev_row_select(0)
        )

    # ReadAct and Comp: double buffer
    # read and compute use the same act page; flip it every pass
    fsm.If(vg.Not(act_page))(
        act_page_comp_offset(act_page_size),
        act_page_dma_offset(act_page_size),
        act_page(1)
    )
    fsm.If(act_page)(
        act_page_comp_offset(0),
        act_page_dma_offset(0),
        act_page(0)
    )

    # WriteOut: counter
    # output offsets follow the previous pass's row/batch position
    fsm.If(vg.Not(skip_write_out))(
        out_base_offset_row.add(self.out_row_step)
    )
    fsm.If(vg.Not(skip_write_out),
           prev_row_count >= self.max_row_count)(
        out_base_offset_row(0),
        out_base_offset_bat.add(self.out_bat_step)
    )
    fsm.If(vg.Not(skip_write_out),
           prev_row_count >= self.max_row_count,
           prev_bat_count >= self.max_bat_count)(
        out_base_offset_bat(0)
    )

    # WriteOut and Comp: double buffer
    # Comp writes one out page while WriteOut DMAs the other
    fsm.If(vg.Not(out_page))(
        out_page_comp_offset(out_page_size),
        out_page_dma_offset(0),
        out_page(1)
    )
    fsm.If(out_page)(
        out_page_comp_offset(0),
        out_page_dma_offset(out_page_size),
        out_page(0)
    )

    # ReadAct and WriteOut: prev
    fsm(
        prev_row_count(row_count),
        prev_bat_count(bat_count)
    )

    # ReadAct, Comp, WriteOut: skip
    # after the last row/batch, read and compute are bypassed so that one
    # more pass only writes out the final results
    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(
        skip_read_act(1)
    )

    fsm.If(row_count >= self.max_row_count,
           bat_count >= self.max_bat_count)(
        skip_comp(1)
    )

    # enable write-out from the second pass on
    fsm.If(skip_write_out,
           prev_row_count == 0,
           prev_bat_count == 0)(
        skip_write_out(0)
    )

    # loop, unless the write-out for the last row/batch was just issued
    fsm.goto(state_read_act)
    fsm.If(vg.Not(skip_write_out),
           prev_row_count >= self.max_row_count,
           prev_bat_count >= self.max_bat_count).goto_next()

    # wait for last DMA write
    bt.dma_wait_write(self.maxi, fsm)
def control_sequence(self, fsm):
    """Append the gather / copy / write control sequence to ``fsm``.

    Each pass through the generated FSM performs:

    1. Read: DMA-read one chunk (``arg_read_size``) of the argument selected
       by ``arg_select`` into ``self.input_rams[0]`` at ``arg_laddr``, then
       advance that argument's global offset by its ``arg_addr_inc``.  After
       ``arg_chunk_size`` chunks of one argument, ``arg_select`` moves on to
       the next argument, wrapping to 0 after the last one.
    2. Copy (only when ``self.buffered`` is true): run ``self.stream`` to
       copy ``copy_size`` entries from ``self.input_rams[0]`` into
       ``self.output_rams[0]`` at ``copy_laddr``, passing
       ``prev_arg_select`` as the stream's first constant.  Loops back to
       Read until ``copy_laddr`` reaches ``sum_read_sizes``.
    3. Write: DMA-write either ``self.out_write_size`` from
       ``self.output_rams[0]`` (buffered) or ``copy_size`` directly from
       ``self.input_rams[0]`` (unbuffered) to ``self.objaddr + out_gaddr``,
       then increment ``out_count``.

    The sequence repeats from Read while ``out_count < self.num_steps`` and
    falls through to the next state otherwise.

    :param fsm: the control FSM to which the states are appended.
    """
    arg_gaddrs = [
        self.m.Reg(self._name('arg_gaddr_%d' % i),
                   self.maxi.addrwidth, initval=0)
        for i, _ in enumerate(self.arg_objaddrs)
    ]
    out_gaddr = self.m.Reg(self._name('out_gaddr'),
                           self.maxi.addrwidth, initval=0)
    arg_laddr = self.m.Reg(self._name('arg_laddr'),
                           self.maxi.addrwidth, initval=0)
    copy_laddr = self.m.Reg(self._name('copy_laddr'),
                            self.maxi.addrwidth, initval=0)
    copy_size = self.m.Reg(self._name('copy_size'),
                           self.maxi.addrwidth, initval=0)

    # A sum of N operands needs the widest operand's width plus
    # ceil(log2(N)) carry bits; (N - 1).bit_length() is that ceiling
    # computed exactly in integers (0 for N == 1), with no float rounding.
    num_read_sizes = len(self.arg_read_sizes)
    sum_read_sizes = self.m.Wire(
        self._name('sum_read_sizes'),
        max([i.width for i in self.arg_read_sizes]) +
        (num_read_sizes - 1).bit_length())
    sum_read_sizes.assign(vg.Add(*self.arg_read_sizes))

    out_addr_inc_unbuffered = self.m.Reg(
        self._name('out_addr_inc_unbuffered'), self.maxi.addrwidth, initval=0)

    # Argument index register: ceil(log2(len(self.args))) bits, at least 1.
    arg_select_width = max((len(self.args) - 1).bit_length(), 1)
    arg_select = self.m.Reg(self._name('arg_select'),
                            arg_select_width, initval=0)
    prev_arg_select = self.m.Reg(self._name('prev_arg_select'),
                                 arg_select_width, initval=0)
    arg_chunk_count = self.m.Reg(self._name('arg_chunk_count'),
                                 self.maxi.addrwidth + 1, initval=0)
    out_count = self.m.Reg(self._name('out_count'),
                           self.maxi.addrwidth + 1, initval=0)

    # --------------------
    # initialization phase
    # --------------------
    fsm([arg_gaddr(0) for arg_gaddr in arg_gaddrs],
        out_gaddr(0), arg_laddr(0), copy_laddr(0), copy_size(0),
        out_addr_inc_unbuffered(0), arg_select(0), prev_arg_select(0),
        arg_chunk_count(0), out_count(0))
    fsm.goto_next()

    # --------------------
    # Read phase
    # --------------------
    # state_read only dispatches to the read sequence of the selected
    # argument; every sequence rejoins at state_read_end.
    state_read = fsm.current
    fsm.inc()

    state_read_begin_list = []
    state_read_end_list = []

    # self.args stays in the zip so the iteration count is unchanged.
    for (_arg, arg_objaddr, arg_gaddr, arg_addr_inc,
         arg_read_size, arg_chunk_size) in zip(self.args,
                                               self.arg_objaddrs,
                                               arg_gaddrs,
                                               self.arg_addr_incs,
                                               self.arg_read_sizes,
                                               self.arg_chunk_sizes):
        b = fsm.current
        state_read_begin_list.append(b)

        # normal
        laddr = arg_laddr
        gaddr = arg_objaddr + arg_gaddr

        bt.bus_lock(self.maxi, fsm)
        bt.dma_read(self.maxi, fsm, self.input_rams[0], laddr, gaddr,
                    arg_read_size)
        bt.bus_unlock(self.maxi, fsm)

        fsm(arg_gaddr.add(arg_addr_inc),
            arg_laddr.add(arg_read_size),
            arg_chunk_count.inc(),
            copy_size(arg_read_size),
            out_addr_inc_unbuffered(arg_addr_inc))
        # after the last chunk of this argument, move to the next one
        fsm.If(arg_chunk_count == arg_chunk_size - 1)(arg_chunk_count(0),
                                                     arg_select.inc())
        fsm.If(arg_chunk_count == arg_chunk_size - 1,
               arg_select == len(self.args) - 1)(arg_select(0))
        # remember which argument this chunk came from (used by Copy)
        fsm(prev_arg_select(arg_select))

        e = fsm.current
        state_read_end_list.append(e)
        fsm.inc()

    state_read_end = fsm.current

    for i, b in enumerate(state_read_begin_list):
        fsm.If(arg_select == i).goto_from(state_read, b)

    for e in state_read_end_list:
        fsm.goto_from(e, state_read_end)

    # --------------------
    # Copy phase
    # --------------------
    state_copy = fsm.current

    name = list(self.stream.sources.keys())[0]
    self.stream.set_source(fsm, name, self.input_rams[0], 0, copy_size)
    fsm.set_index(fsm.current - 1)

    name = list(self.stream.constants.keys())[0]
    self.stream.set_constant(fsm, name, prev_arg_select)
    fsm.set_index(fsm.current - 1)

    name = list(self.stream.sinks.keys())[0]
    self.stream.set_sink(fsm, name, self.output_rams[0], copy_laddr, copy_size)

    self.stream.run(fsm)
    self.stream.join(fsm)

    fsm(arg_laddr(0), copy_laddr.add(copy_size))
    fsm.goto_next()

    # keep gathering until the output buffer holds sum_read_sizes entries
    fsm.If(copy_laddr < sum_read_sizes).goto(state_read)
    fsm.If(copy_laddr >= sum_read_sizes).goto_next()

    state_copy_end = fsm.current
    fsm.If(vg.Not(self.buffered)).goto_from(state_copy, state_copy_end)

    # --------------------
    # Write phase
    # --------------------
    state_write = fsm.current
    fsm.inc()

    # Case with Copy: write the gathered output buffer
    state_write_buffered = fsm.current

    laddr = 0
    gaddr = self.objaddr + out_gaddr

    bt.bus_lock(self.maxi, fsm)
    bt.dma_write(self.maxi, fsm, self.output_rams[0], laddr, gaddr,
                 self.out_write_size)
    bt.bus_unlock(self.maxi, fsm)

    fsm(copy_laddr(0), out_gaddr.add(self.out_addr_inc), out_count.inc())

    state_write_end_buffered = fsm.current
    fsm.inc()

    # Case without Copy: write the chunk just read straight back out
    state_write_unbuffered = fsm.current

    laddr = 0
    gaddr = self.objaddr + out_gaddr

    bt.bus_lock(self.maxi, fsm)
    bt.dma_write(self.maxi, fsm, self.input_rams[0], laddr, gaddr, copy_size)
    bt.bus_unlock(self.maxi, fsm)

    fsm(arg_laddr(0), out_gaddr.add(out_addr_inc_unbuffered), out_count.inc())

    state_write_end_unbuffered = fsm.current
    fsm.inc()

    state_write_end = fsm.current

    fsm.If(self.buffered).goto_from(state_write, state_write_buffered)
    fsm.If(vg.Not(self.buffered)).goto_from(state_write, state_write_unbuffered)

    fsm.goto_from(state_write_end_buffered, state_write_end)
    fsm.goto_from(state_write_end_unbuffered, state_write_end)

    # --------------------
    # update for next iteration
    # --------------------
    # Complementary guards (like the Copy phase's pair) so this state always
    # has an exit, even if out_count overshoots num_steps (e.g. num_steps == 0,
    # since one read/write pass always runs before this check).
    fsm.If(out_count < self.num_steps).goto(state_read)
    fsm.If(out_count >= self.num_steps).goto_next()