Esempio n. 1
0
    def control_sequence(self, fsm):
        """Emit the control state machine of this operator onto ``fsm``.

        States are generated in this order: initialization, ReadAct (DMA
        read into the first input RAM), Comp (stream execution driven by a
        separate ``comp_fsm``), WriteOut (asynchronous DMA write from the
        first output RAM), and the per-iteration update that branches back
        to ReadAct.  A trailing state waits for the last DMA write.

        Args:
            fsm: state machine object the sequence is appended to.  It is
                called to register assignments and must provide ``If``,
                ``goto``, ``goto_next``, ``goto_from`` and ``current``
                (presumably the same kind of object as ``vg.FSM`` used for
                ``comp_fsm`` below -- confirm).
        """
        act_ram = self.input_rams[0]
        out_ram = self.output_rams[0]

        # Global-address offset of the activation: the sum of one offset
        # register per leading dimension (all but the last two) of act_shape,
        # or constant 0 when there is no leading dimension.
        act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                      self.maxi.addrwidth, signed=True)

        act_offsets = [self.m.Reg(self._name('act_offset_%d' % i),
                                  self.maxi.addrwidth, initval=0, signed=True)
                       for i, _ in enumerate(self.act_shape[:-2])]

        if act_offsets:
            v = act_offsets[0]
            for act_offset in act_offsets[1:]:
                v += act_offset
            act_base_offset.assign(v)
        else:
            act_base_offset.assign(0)

        # Same construction for the output, based on out_shape.
        out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                      self.maxi.addrwidth, signed=True)

        out_offsets = [self.m.Reg(self._name('out_offset_%d' % i),
                                  self.maxi.addrwidth, initval=0, signed=True)
                       for i, _ in enumerate(self.out_shape[:-2])]

        if out_offsets:
            v = out_offsets[0]
            for out_offset in out_offsets[1:]:
                v += out_offset
            out_base_offset.assign(v)
        else:
            out_base_offset.assign(0)

        # Loop counters, one per leading dimension of act_shape.
        # prev_counts receives a copy of counts at the end of every iteration
        # (see "ReadAct and WriteOut: prev" below) and drives the WriteOut
        # side of the update logic.
        counts = [self.m.Reg(self._name('count_%d' % i),
                             self.maxi.addrwidth, initval=0)
                  for i, _ in enumerate(self.act_shape[:-2])]

        prev_counts = [self.m.Reg(self._name('prev_count_%d' % i),
                                  self.maxi.addrwidth, initval=0)
                       for i, _ in enumerate(self.act_shape[:-2])]

        # Local (RAM) start addresses handed to the stream, before the page
        # offset is added.
        stream_act_local = self.m.Reg(self._name('stream_act_local'),
                                      self.maxi.addrwidth, initval=0)
        stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                      self.maxi.addrwidth, initval=0)

        # comp_count counts finished stream runs, out_count counts issued
        # writes; WriteOut only proceeds while comp_count > out_count.
        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth, initval=0)
        out_count = self.m.Reg(self._name('out_count'),
                               self.maxi.addrwidth, initval=0)

        # Page selector plus the local-address offsets used by the
        # computation and by the DMA, for the activation RAM ...
        act_page = self.m.Reg(self._name('act_page'), initval=0)
        act_page_comp_offset = self.m.Reg(self._name('act_page_comp_offset'),
                                          self.maxi.addrwidth, initval=0)
        act_page_dma_offset = self.m.Reg(self._name('act_page_dma_offset'),
                                         self.maxi.addrwidth, initval=0)

        # ... and for the output RAM.
        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth, initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth, initval=0)

        # Each RAM is treated as two pages of half its length.
        act_page_size = act_ram.length // 2
        out_page_size = out_ram.length // 2

        # skip_comp gates the start of comp_fsm; skip_write_out bypasses the
        # WriteOut states.
        # NOTE(review): skip_read_act is assigned below but never read in
        # this method -- confirm whether the ReadAct phase should honour it.
        skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

        # --------------------
        # initialization phase
        # --------------------
        # All fsm(...) calls below are issued before the first goto_next(),
        # i.e. while fsm.current still refers to the initial state.
        # ReadAct: offset
        for act_offset, act_offset_begin in zip(act_offsets, self.act_offset_begins):
            fsm(
                act_offset(act_offset_begin)
            )

        # ReadAct: double buffer control
        fsm(
            act_page(0),
            act_page_comp_offset(0),
            act_page_dma_offset(0)
        )

        # WriteOutput: offset
        for out_offset in out_offsets:
            fsm(
                out_offset(0)
            )

        # Rebind the Python name from the loop variable to the summed wire;
        # out_gaddr in the WriteOut phase is built from this wire.
        out_offset = out_base_offset

        # WriteOutput: double buffer control
        fsm(
            out_page(0),
            out_page_comp_offset(0),
            out_page_dma_offset(0)
        )

        # counter
        fsm(
            [count(0) for count in counts],
            [prev_count(0) for prev_count in prev_counts]
        )

        # double buffer control
        # The first pass has nothing to write yet, so WriteOut starts skipped.
        fsm(
            skip_read_act(0),
            skip_comp(0),
            skip_write_out(1)
        )

        fsm(
            out_count(0)
        )

        # Remembered so that comp_fsm can clear comp_count while the main
        # FSM sits in its initialization state.
        state_init = fsm.current

        fsm.goto_next()

        # --------------------
        # ReadAct phase
        # --------------------
        # Loop entry point: the update phase jumps back to this state.
        state_read_act = fsm.current

        act_gaddr = self.arg_objaddrs[0] + act_base_offset

        # NOTE(review): no bt.bus_unlock matching this bus_lock is visible
        # before the second bus_lock in the WriteOut phase -- confirm that
        # this is intended.
        bt.bus_lock(self.maxi, fsm)

        act_laddr = act_page_dma_offset

        # begin_state_read / end_state_read are not referenced again in this
        # method.
        begin_state_read = fsm.current
        fsm.goto_next()

        bt.dma_read(self.maxi, fsm, act_ram, act_laddr,
                    act_gaddr, self.act_read_size, port=1)

        end_state_read = fsm.current

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        # Stream Control FSM
        comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

        # Handshake: comp_fsm leaves its initial state when the main FSM is
        # in state_comp and skip_comp is clear; the main FSM advances only
        # while comp_fsm is in its initial state.
        comp_state_init = comp_fsm.current
        comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

        fsm.If(comp_fsm.state == comp_state_init).goto_next()

        # local address
        comp_fsm(
            stream_act_local(self.stream_local),
            stream_out_local(0)
        )

        # Copies of the page offsets taken in comp_fsm's initial state, so
        # the stream setup below uses the buffered values rather than the
        # registers the main FSM updates in its update phase.
        act_page_comp_offset_buf = self.m.Reg(self._name('act_page_comp_offset_buf'),
                                              self.maxi.addrwidth, initval=0)
        out_page_comp_offset_buf = self.m.Reg(self._name('out_page_comp_offset_buf'),
                                              self.maxi.addrwidth, initval=0)

        comp_fsm(
            act_page_comp_offset_buf(act_page_comp_offset),
            out_page_comp_offset_buf(out_page_comp_offset)
        )

        comp_fsm.goto_next()

        # busy check
        self.stream.source_join(comp_fsm)

        # set_source
        # Only the first registered source of the stream is configured.
        name = list(self.stream.sources.keys())[0]
        local = stream_act_local + act_page_comp_offset_buf

        # Access pattern as (size, stride) pairs; the second pair is added
        # only when out_shape has more than one dimension.
        if len(self.out_shape) > 1:
            pat = ((self.stream_size, self.act_strides[-1]),
                   (self.out_shape[-2], self.stream_stride))
        else:
            pat = ((self.stream_size, self.act_strides[-1]),)

        self.stream.set_source_pattern(comp_fsm, name, act_ram,
                                       local, pat)

        # Step the state index back by one; presumably so that the sink
        # setup below is emitted in the same state as the source setup --
        # confirm against the FSM library.
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_sink
        # Only the first registered sink of the stream is configured.
        name = list(self.stream.sinks.keys())[0]
        local = stream_out_local + out_page_comp_offset_buf

        if len(self.out_shape) > 1:
            pat = ((self.stream_size, 1),
                   (self.out_shape[-2], self.stream_size))
        else:
            pat = ((self.stream_size, 1),)

        self.stream.set_sink_pattern(comp_fsm, name, out_ram,
                                     local, pat)

        # stream run (async)
        self.stream.run(comp_fsm)

        comp_fsm.goto_init()

        # sync with WriteOut control
        # comp_count is cleared while the main FSM is in state_init and
        # incremented whenever the stream raises end_flag.
        comp_fsm.seq.If(fsm.state == state_init)(
            comp_count(0)
        )
        comp_fsm.seq.If(self.stream.end_flag)(
            comp_count.inc()
        )

        # --------------------
        # WriteOut phase
        # --------------------
        state_write_out = fsm.current

        # sync with Comp control
        # Wait until at least one more stream run has finished than has
        # been written out.
        fsm.If(comp_count > out_count).goto_next()

        out_laddr = out_page_dma_offset
        out_gaddr = self.objaddr + out_offset

        bt.bus_lock(self.maxi, fsm)

        bt.dma_write(self.maxi, fsm, out_ram, out_laddr,
                     out_gaddr, self.out_write_size, port=1, use_async=True)

        bt.bus_unlock(self.maxi, fsm)

        fsm(
            out_count.inc()
        )

        fsm.goto_next()

        # When skip_write_out is set, jump from the first WriteOut state
        # directly to the state after the phase.
        state_write_out_end = fsm.current
        fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

        # --------------------
        # update for next iteration
        # --------------------
        # ReadAct: count
        # Carry chain over the counters, last leading dimension first: a
        # counter advances when every previously visited counter has reached
        # size - 1, and is reset to 0 when it has itself reached size - 1.
        # cond is None for the first counter, i.e. no extra condition.
        # NOTE(review): sizes come from out_shape while counts were created
        # from act_shape; zip stops at the shorter one -- confirm both have
        # the same number of leading dimensions.
        cond = None
        for size, count in zip(reversed(self.out_shape[:-2]), reversed(counts)):

            fsm.If(cond)(
                count.inc()
            )
            fsm.If(cond, count >= size - 1)(
                count(0)
            )
            if cond is not None:
                cond = vg.Ands(cond, count >= size - 1)
            else:
                cond = count >= size - 1

        # ReadAct: offset
        # Same carry chain, applied to the activation offsets: add the
        # per-dimension stride, or reset when the dimension wraps.
        # NOTE(review): the offsets are initialised from act_offset_begins
        # but reset to 0 on wrap -- confirm this asymmetry is intended.
        cond = None
        for size, count, act_offset, act_offset_stride in zip(
                reversed(self.out_shape[:-2]), reversed(counts),
                reversed(act_offsets), reversed(self.act_offset_strides)):

            fsm.If(cond)(
                act_offset.add(act_offset_stride)
            )
            fsm.If(cond, count >= size - 1)(
                act_offset(0)
            )
            if cond is not None:
                cond = vg.Ands(cond, count >= size - 1)
            else:
                cond = count >= size - 1

        # ReadAct and Comp: double buffer
        # Toggle the activation page; both offsets move to the same page.
        fsm.If(vg.Not(act_page))(
            act_page_comp_offset(act_page_size),
            act_page_dma_offset(act_page_size),
            act_page(1)
        )
        fsm.If(act_page)(
            act_page_comp_offset(0),
            act_page_dma_offset(0),
            act_page(0)
        )

        # WriteOut: offset
        # Carry chain driven by prev_counts and additionally gated by
        # "not skip_write_out", so the output offsets do not move while
        # WriteOut is skipped.
        cond = vg.Not(skip_write_out)
        for size, prev_count, out_offset, out_offset_stride in zip(
                reversed(self.out_shape[:-2]), reversed(prev_counts),
                reversed(out_offsets), reversed(self.out_offset_strides)):

            fsm.If(cond)(
                out_offset.add(out_offset_stride)
            )
            fsm.If(cond, prev_count >= size - 1)(
                out_offset(0)
            )
            cond = vg.Ands(cond, prev_count >= size - 1)

        # WriteOut and Comp: double buffer
        # Toggle the output page; the computation and DMA offsets are set
        # to opposite pages.
        fsm.If(vg.Not(out_page))(
            out_page_comp_offset(out_page_size),
            out_page_dma_offset(0),
            out_page(1)
        )
        fsm.If(out_page)(
            out_page_comp_offset(0),
            out_page_dma_offset(out_page_size),
            out_page(0)
        )

        # ReadAct and WriteOut: prev
        for count, prev_count in zip(counts, prev_counts):
            fsm(
                prev_count(count)
            )

        # ReadAct, Comp, WriteOut: skip
        # Reading and computing are flagged as skipped once every counter
        # is at its last value.
        cond_skip_read_act = None
        cond_skip_comp = None
        for size, count in zip(reversed(self.out_shape[:-2]), reversed(counts)):
            if cond_skip_read_act is not None:
                cond_skip_read_act = vg.Ands(cond_skip_read_act, count >= size - 1)
            else:
                cond_skip_read_act = count >= size - 1

        cond_skip_comp = cond_skip_read_act

        # The initial WriteOut skip is cancelled when every prev_count is 0.
        cond_cancel_write_out = None
        for size, prev_count in zip(reversed(self.out_shape[:-2]), reversed(prev_counts)):
            if cond_cancel_write_out is not None:
                cond_cancel_write_out = vg.Ands(cond_cancel_write_out, prev_count == 0)
            else:
                cond_cancel_write_out = prev_count == 0

        # Done when every prev_count is at its last value.
        cond_done = None
        for size, prev_count in zip(reversed(self.out_shape[:-2]), reversed(prev_counts)):
            if cond_done is not None:
                cond_done = vg.Ands(cond_done, prev_count >= size - 1)
            else:
                cond_done = prev_count >= size - 1

        fsm.If(cond_skip_read_act)(
            skip_read_act(1)
        )
        fsm.If(cond_skip_comp)(
            skip_comp(1)
        )
        fsm.If(skip_write_out,
               cond_cancel_write_out)(
            skip_write_out(0)
        )

        # Default transition: loop back to ReadAct.  The conditional
        # goto_next registered afterwards presumably takes precedence when
        # its condition holds -- confirm against the FSM library.
        fsm.goto(state_read_act)
        fsm.If(vg.Not(skip_write_out), cond_done).goto_next()

        # wait for last DMA write
        bt.dma_wait_write(self.maxi, fsm)
Esempio n. 2
0
    def control_sequence(self, fsm):
        """Emit the control state machine of this transpose onto ``fsm``.

        The sequence burst-reads ``arg_shape[-1]`` words of the argument into
        the first input RAM, then issues one ``bt.read_modify_write`` per
        word at a scattered global address, and repeats until every counter
        has reached its last value.

        Args:
            fsm: state machine object the sequence is appended to.  It is
                called to register assignments and must provide ``If``,
                ``goto``, ``goto_next`` and ``current``.
        """
        arg = self.args[0]
        ram = self.input_rams[0]

        shape = self.get_aligned_shape()
        arg_shape = arg.get_aligned_shape()

        # burst read, scatter write
        # For each output axis i, the position of i inside transpose_perm,
        # listed in reverse axis order.
        write_order = list(
            reversed([self.transpose_perm.index(i)
                      for i in range(len(shape))]))
        # Iterated below as (size, stride) pairs.
        write_pattern = bt.shape_to_pattern(shape, write_order)

        read_offset = self.m.TmpReg(self.maxi.addrwidth, initval=0)
        write_offsets = [
            self.m.TmpReg(self.maxi.addrwidth, initval=0)
            for _ in write_pattern
        ]
        # Global write address: objaddr plus every per-dimension offset.
        write_all_offset = self.objaddr
        for write_offset in write_offsets:
            write_all_offset += write_offset

        # One counter per write_pattern entry.
        read_counts = [
            self.m.TmpReg(self.maxi.addrwidth, initval=0)
            for _ in write_pattern
        ]

        # initialize
        fsm(read_offset(0),
            [write_offset(0) for write_offset in write_offsets],
            [read_count(0) for read_count in read_counts])
        fsm.goto_next()

        # DMA read
        # Loop entry point for fetching the next burst.
        read_state = fsm.current

        # The burst always lands at local address 0 and covers the last
        # dimension of the aligned argument shape.
        laddr = 0
        gaddr = self.arg_objaddrs[0] + read_offset
        read_size = arg_shape[-1]

        bt.bus_lock(self.maxi, fsm)
        bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, read_size)
        bt.bus_unlock(self.maxi, fsm)

        # read-modify-write
        # Loop entry point for writing the next word of the current burst.
        modify_state = fsm.current

        # From here on laddr is the first counter, i.e. the position inside
        # the burst that was just read.
        laddr = read_counts[0]
        gaddr = write_all_offset

        bt.read_modify_write(self.m, fsm, self.maxi, ram, self.output_rams[0],
                             laddr, gaddr)

        # Carry chain: a counter (and its write offset) advances when all
        # previously visited counters are at maxval - 1, and both are reset
        # to 0 when the counter itself is at maxval - 1.  prev_done starts as
        # constant 1, so the first counter always advances.
        prev_done = 1
        for (read_count, maxval, write_offset,
             (out_size, out_stride)) in zip(read_counts, reversed(arg_shape),
                                            write_offsets, write_pattern):
            fsm.If(prev_done)(
                read_count.inc(),
                write_offset.add(
                    optimize(bt.to_byte(out_stride * self.get_ram_width()))))
            fsm.If(prev_done, read_count == maxval - 1)(read_count(0),
                                                        write_offset(0))
            prev_done = vg.Ands(prev_done, (read_count == maxval - 1))

        # After the last word of a burst, move the read address forward by
        # one burst (bt.to_byte presumably converts a bit count to bytes --
        # confirm).
        fsm.If(laddr == read_size - 1)(read_offset.add(
            optimize(bt.to_byte(read_size * arg.get_ram_width()))))
        # Transitions: stay in the modify loop within a burst, fetch the next
        # burst at its end, and leave once every counter is at its last
        # value (prev_done).  The transition registered last presumably takes
        # precedence when several conditions hold -- confirm against the FSM
        # library.
        fsm.If(laddr < read_size - 1).goto(modify_state)
        fsm.If(laddr == read_size - 1).goto(read_state)
        fsm.If(prev_done).goto_next()
Esempio n. 3
0
    def control_sequence(self, fsm):
        """Emit the control state machine of this operator onto ``fsm``.

        States are generated in this order: initialization, Read (one DMA
        read per input RAM), Comp (stream setup, run and join), Write (a loop
        of asynchronous DMA writes from the first output RAM), and the
        per-iteration update that branches back to Read while
        ``comp_count < self.num_comp``.  A trailing state waits for the last
        DMA write.

        As visible in the code below, ``wrap_mode`` selects per argument:
        0 -- the global address always advances; 1 -- the address advances
        and returns to 0 after ``wrap_size`` iterations; 2 -- a single word
        is read once and then reused (stream stride 0).

        Args:
            fsm: state machine object the sequence is appended to.  It is
                called to register assignments and must provide ``If``,
                ``goto``, ``goto_next``, ``goto_from``, ``set_index`` and
                ``current``.
        """
        sources = self.collect_sources()

        # Per-argument global address offsets (added to arg_objaddrs).
        arg_gaddrs = [
            self.m.Reg(self._name('arg_gaddr_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i, _ in enumerate(self.arg_objaddrs)
        ]
        # Output address: out_gaddr moves once per iteration, and
        # out_gaddr_offset moves inside the Write loop of one iteration.
        out_gaddr = self.m.Reg(self._name('out_gaddr'),
                               self.maxi.addrwidth,
                               initval=0)
        out_gaddr_offset = self.m.Reg(self._name('out_gaddr_offset'),
                                      self.maxi.addrwidth,
                                      initval=0)
        # Position counters of the Write loop.
        out_pos_col = self.m.Reg(self._name('out_pos_col'),
                                 self.maxi.addrwidth + 1,
                                 initval=0)
        out_pos_row = self.m.Reg(self._name('out_pos_row'),
                                 self.maxi.addrwidth + 1,
                                 initval=0)
        out_col_count = self.m.Reg(self._name('out_col_count'),
                                   self.maxi.addrwidth + 1,
                                   initval=0)
        # Number of completed iterations of the Read/Comp/Write loop.
        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth + 1,
                                initval=0)
        wrap_counts = [
            self.m.Reg(self._name('wrap_count_%d' % i),
                       self.maxi.addrwidth + 1,
                       initval=0) for i, arg in enumerate(sources)
        ]

        # Page selector plus the local-address offsets used by the
        # computation and by the DMA, for every argument ...
        arg_pages = [
            self.m.Reg(self._name('arg_page_%d' % i), initval=0)
            for i, _ in enumerate(self.arg_objaddrs)
        ]
        arg_page_comp_offsets = [
            self.m.Reg(self._name('arg_page_comp_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i, _ in enumerate(self.arg_objaddrs)
        ]
        arg_page_dma_offsets = [
            self.m.Reg(self._name('arg_page_dma_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i, _ in enumerate(self.arg_objaddrs)
        ]

        # ... and for the output.
        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth,
                                          initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth,
                                         initval=0)

        # Each RAM is treated as two pages of half its length.
        # NOTE(review): arg_page_size is derived from the first *output* RAM,
        # not from an input RAM -- confirm this is intended.
        arg_page_size = self.output_rams[0].length // 2
        out_page_size = self.output_rams[0].length // 2

        skip_read = self.m.Reg(self._name('skip_read'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write = self.m.Reg(self._name('skip_write'), initval=0)

        # --------------------
        # initialization phase
        # --------------------
        fsm([arg_gaddr(0) for arg_gaddr in arg_gaddrs])

        fsm(comp_count(0), out_gaddr(0), out_gaddr_offset(0), out_pos_col(0),
            out_pos_row(0), out_col_count(0),
            [wrap_count(0) for wrap_count in wrap_counts])

        fsm([arg_page(0) for arg_page in arg_pages], [
            arg_page_comp_offset(0)
            for arg_page_comp_offset in arg_page_comp_offsets
        ], [
            arg_page_dma_offset(0)
            for arg_page_dma_offset in arg_page_dma_offsets
        ])

        # The output DMA offset starts on the page opposite to the
        # computation offset.
        fsm(out_page(0), out_page_comp_offset(0),
            out_page_dma_offset(out_page_size))

        # The first pass has nothing to write yet, so Write starts skipped.
        fsm(skip_read(0), skip_comp(0), skip_write(1))

        fsm.goto_next()

        # --------------------
        # Read phase
        # --------------------
        # Loop entry point: the update phase jumps back to this state.
        state_read = fsm.current

        # DMA read -> Stream run -> Stream wait -> DMA write
        for (ram, arg_objaddr, arg_gaddr, arg_page_dma_offset, wrap_mode,
             wrap_count, arg) in zip(self.input_rams, self.arg_objaddrs,
                                     arg_gaddrs, arg_page_dma_offsets,
                                     self.wrap_modes, wrap_counts, sources):

            # b: entry state of this argument's read sequence.
            b = fsm.current
            fsm.goto_next()

            # normal
            laddr = arg_page_dma_offset
            gaddr = arg_objaddr + arg_gaddr
            bt.bus_lock(self.maxi, fsm)
            bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, self.dma_size)
            bt.bus_unlock(self.maxi, fsm)
            fsm.goto_next()

            # b_stride0: entry state of the single-word read.
            b_stride0 = fsm.current
            fsm.goto_next()

            # stride-0
            bt.bus_lock(self.maxi, fsm)
            bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, 1)
            bt.bus_unlock(self.maxi, fsm)
            fsm.goto_next()

            # for reuse
            # Routing between the two read variants:
            #  - wrap_mode 2, already loaded (wrap_count > 0): b -> e, no read
            #  - wrap_mode 2, first time (wrap_count == 0): b -> b_stride0
            #  - other modes: after the normal read, b_stride0 -> e
            e = fsm.current
            fsm.If(wrap_mode == 2, wrap_count > 0).goto_from(b, e)
            fsm.If(wrap_mode == 2, wrap_count == 0).goto_from(b, b_stride0)
            fsm.If(wrap_mode != 2).goto_from(b_stride0, e)

        # When skip_read is set, bypass the whole Read phase.
        state_read_end = fsm.current
        fsm.If(skip_read).goto_from(state_read, state_read_end)

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        self.stream.source_join(fsm)

        # set_source, set_constant (dup)
        for (source_name, dup_name, arg_page_comp_offset, ram,
             wrap_mode) in zip(self.stream.sources.keys(),
                               self.stream.constants.keys(),
                               arg_page_comp_offsets, self.input_rams,
                               self.wrap_modes):
            read_laddr = arg_page_comp_offset
            read_size = self.dma_size
            # wrap_mode 2 reads with stride 0 and sets the dup constant to 1.
            stride = vg.Mux(wrap_mode == 2, 0, 1)
            dup = vg.Mux(wrap_mode == 2, 1, 0)
            self.stream.set_constant(fsm, dup_name, dup)
            # Step the state index back by one after each setup call;
            # presumably so that all setup calls share one state -- confirm
            # against the FSM library.
            fsm.set_index(fsm.current - 1)
            self.stream.set_source(fsm, source_name, ram, read_laddr,
                                   read_size, stride)
            fsm.set_index(fsm.current - 1)

        # set_sink
        write_laddr = out_page_comp_offset
        write_size = self.dma_size

        for name, ram in zip(self.stream.sinks.keys(), self.output_rams):
            self.stream.set_sink(fsm, name, ram, write_laddr, write_size)
            fsm.set_index(fsm.current - 1)

        fsm.goto_next()

        self.stream.run(fsm)

        state_comp_end = fsm.current

        self.stream.join(fsm)

        state_comp_end_join = fsm.current

        # With skip_comp set, setup and run are bypassed and execution
        # continues at the join; with skip_comp clear, the join states are
        # bypassed instead.
        fsm.If(skip_comp).goto_from(state_comp, state_comp_end)
        fsm.If(vg.Not(skip_comp)).goto_from(state_comp_end,
                                            state_comp_end_join)

        # --------------------
        # Write phase
        # --------------------
        state_write = fsm.current

        laddr = out_page_dma_offset
        gaddr_base = self.objaddr + out_gaddr

        bt.bus_lock(self.maxi, fsm)

        # b: entry state of the write loop.
        b = fsm.current

        gaddr = gaddr_base + out_gaddr_offset
        bt.dma_write(self.maxi,
                     fsm,
                     self.output_rams[0],
                     laddr,
                     gaddr,
                     self.dma_size,
                     use_async=True)

        # Advance the column position and address; at the last column the
        # column position is reset and the row position advances.
        fsm(
            out_pos_col.inc(),
            out_gaddr_offset.add(self.out_col_step),
        )
        fsm.If(out_pos_col == self.max_out_pos_col)(out_pos_col(0),
                                                    out_pos_row.inc(),
                                                    out_gaddr_offset.add(
                                                        self.out_row_step))

        # Default transition: repeat the write.  The conditional goto_next
        # registered afterwards presumably takes precedence at the last
        # column of the last row -- confirm against the FSM library.
        fsm.goto(b)
        fsm.If(out_pos_col == self.max_out_pos_col,
               out_pos_row == self.max_out_pos_row).goto_next()

        bt.bus_unlock(self.maxi, fsm)

        fsm.goto_next()

        # When skip_write is set, bypass the whole Write phase.
        state_write_end = fsm.current
        fsm.If(skip_write).goto_from(state_write, state_write_end)

        # --------------------
        # update for next iteration
        # --------------------
        fsm(comp_count.inc())

        # Reset the Write-loop position for the next iteration.
        fsm(out_gaddr_offset(0), out_pos_col(0), out_pos_row(0))

        # The output base address only moves after an iteration that
        # actually wrote.
        fsm.If(vg.Not(skip_write))(
            out_gaddr.add(self.out_col_inc),
            out_col_count.inc(),
        )
        fsm.If(vg.Not(skip_write), out_col_count == self.max_out_col_count)(
            out_gaddr.add(self.out_row_inc),
            out_col_count(0),
        )

        for (arg_gaddr, arg_addr_inc, arg_page, arg_page_comp_offset,
             arg_page_dma_offset, wrap_mode, wrap_size, wrap_count,
             arg) in zip(arg_gaddrs, self.arg_addr_incs, arg_pages,
                         arg_page_comp_offsets, arg_page_dma_offsets,
                         self.wrap_modes, self.wrap_sizes, wrap_counts,
                         sources):

            # wrap_mode 2: mark the single word as loaded so the Read phase
            # skips it from now on.
            fsm.If(wrap_mode == 2)(wrap_count(1))

            # wrap_mode 1: advance, and return to the start after wrap_size
            # iterations.
            fsm.If(wrap_mode == 1)(arg_gaddr.add(arg_addr_inc),
                                   wrap_count.inc())
            fsm.If(wrap_mode == 1, wrap_count == wrap_size - 1)(arg_gaddr(0),
                                                                wrap_count(0))

            # wrap_mode 0: always advance.
            fsm.If(wrap_mode == 0)(arg_gaddr.add(arg_addr_inc))

            # Toggle the argument page, except in wrap_mode 2 where the
            # loaded word stays in place.
            # NOTE(review): the DMA offset is set from out_page_size rather
            # than arg_page_size; both hold the same value in this method.
            fsm.If(vg.Not(arg_page),
                   wrap_mode != 2)(arg_page_comp_offset(arg_page_size),
                                   arg_page_dma_offset(out_page_size),
                                   arg_page(1))
            fsm.If(arg_page, wrap_mode != 2)(arg_page_comp_offset(0),
                                             arg_page_dma_offset(0),
                                             arg_page(0))

        # Toggle the output page; the computation and DMA offsets are set
        # to opposite pages.
        fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                                 out_page_dma_offset(0), out_page(1))
        fsm.If(out_page)(out_page_comp_offset(0),
                         out_page_dma_offset(out_page_size), out_page(0))

        # Writing is enabled from the second iteration on; reading and
        # computing are disabled for the final, write-only iteration.
        fsm(skip_write(0))
        fsm.If(comp_count == self.num_comp - 1)(skip_read(1), skip_comp(1))

        # comp_count is compared before the increment registered above takes
        # effect (presumably a non-blocking register update -- confirm).
        fsm.If(comp_count < self.num_comp).goto(state_read)
        fsm.If(comp_count == self.num_comp).goto_next()

        # wait for last DMA write
        bt.dma_wait_write(self.maxi, fsm)
Esempio n. 4
0
    def control_sequence(self, fsm):
        ksize_ch = self.ksize[-1]
        ksize_col = self.ksize[-2]
        ksize_row = self.ksize[-3]
        ksize_bat = self.ksize[-4]

        self.stride_bat = 1

        act_rams = self.input_rams
        out_ram = self.output_rams[0]

        act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                      self.maxi.addrwidth,
                                      signed=True)
        act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)
        act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)

        act_base_offset.assign(act_base_offset_row + act_base_offset_bat)

        out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                      self.maxi.addrwidth,
                                      signed=True)
        out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)
        out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                         self.maxi.addrwidth,
                                         initval=0,
                                         signed=True)

        out_base_offset.assign(out_base_offset_row + out_base_offset_bat)

        dma_flags = [
            self.m.Reg(self._name('dma_flag_%d' % i), initval=0)
            for i in range(ksize_row)
        ]

        col_count = self.m.Reg(self._name('col_count'),
                               self.maxi.addrwidth,
                               initval=0)
        row_count = self.m.Reg(self._name('row_count'),
                               self.maxi.addrwidth,
                               initval=0)
        bat_count = self.m.Reg(self._name('bat_count'),
                               self.maxi.addrwidth,
                               initval=0)
        col_select = self.m.Reg(self._name('col_select'),
                                bt.log_width(ksize_col),
                                initval=0)
        row_select = self.m.Reg(self._name('row_select'),
                                bt.log_width(ksize_row),
                                initval=0)

        prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                    self.maxi.addrwidth,
                                    initval=0)
        prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                    self.maxi.addrwidth,
                                    initval=0)
        prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                     bt.log_width(ksize_row),
                                     initval=0)

        stream_act_locals = [
            self.m.Reg(self._name('stream_act_local_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(len(act_rams))
        ]
        stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                      self.maxi.addrwidth,
                                      initval=0)

        # double buffer control
        act_pages = [
            self.m.Reg(self._name('act_page_%d' % i), initval=0)
            for i in range(ksize_row)
        ]
        act_page_comp_offsets = [
            self.m.Reg(self._name('act_page_comp_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]
        act_page_dma_offsets = [
            self.m.Reg(self._name('act_page_dma_offset_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]

        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth,
                                          initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth,
                                         initval=0)

        act_page_size = act_rams[0].length // 2
        out_page_size = out_ram.length // 2

        skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth,
                                initval=0)
        out_count = self.m.Reg(self._name('out_count'),
                               self.maxi.addrwidth,
                               initval=0)

        # --------------------
        # initialization phase
        # --------------------
        # ReadAct: offset
        fsm(act_base_offset_row(0), act_base_offset_bat(0))

        act_offsets = []
        for v in self.act_offset_values:
            act_offset = act_base_offset + v
            act_offsets.append(act_offset)

        # ReadAct: DMA flag
        for y, dma_flag in enumerate(dma_flags):
            fsm(dma_flag(1))

        dma_pad_masks = []

        for y in range(ksize_row):
            v = vg.Ors((row_count + y < self.pad_row_top),
                       (row_count + y >= self.act_num_row + self.pad_row_top))
            dma_pad_mask = self.m.Wire(self._name('dma_pad_mask_%d' % y))
            dma_pad_mask.assign(v)
            dma_pad_masks.append(dma_pad_mask)

        # ReadAct: double buffer control
        fsm([act_page(0) for act_page in act_pages], [
            act_page_comp_offset(0)
            for act_page_comp_offset in act_page_comp_offsets
        ], [
            act_page_dma_offset(0)
            for act_page_dma_offset in act_page_dma_offsets
        ])

        # WriteOutput: offset
        fsm(out_base_offset_row(0), out_base_offset_bat(0))

        out_offset = out_base_offset

        # WriteOut: double buffer control
        fsm(out_page(0), out_page_comp_offset(0), out_page_dma_offset(0))

        # counter
        fsm(row_count(0), bat_count(0), row_select(0), prev_row_count(0),
            prev_bat_count(0), prev_row_select(0))

        # double buffer control
        fsm(skip_read_act(0), skip_comp(0), skip_write_out(1))

        fsm(out_count(0))

        state_init = fsm.current

        fsm.goto_next()

        # --------------------
        # ReadAct phase
        # --------------------
        state_read_act = fsm.current

        act_gaddrs = []
        for act_offset in act_offsets:
            act_gaddr = self.arg_objaddrs[0] + act_offset
            act_gaddrs.append(act_gaddr)

        act_rams_2d = line_to_2d(act_rams, ksize_col)
        mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
        mux_act_gaddrs = []
        for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
            mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                        self.maxi.addrwidth)
            mux_act_gaddr.assign(mux_act_gaddr_value)
            mux_act_gaddrs.append(mux_act_gaddr)

        mux_dma_pad_mask_values = mux_1d(dma_pad_masks, row_select, ksize_row)
        mux_dma_pad_masks = []
        for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
            mux_dma_pad_mask = self.m.Wire(
                self._name('mux_dma_pad_mask_%d' % i))
            mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
            mux_dma_pad_masks.append(mux_dma_pad_mask)

        # determined at the previous phase
        mux_dma_flag_values = mux_1d(dma_flags, prev_row_select, ksize_row)
        mux_dma_flags = []
        for i, mux_dma_flag_value in enumerate(mux_dma_flag_values):
            mux_dma_flag = self.m.Wire(self._name('mux_dma_flag_%d' % i))
            mux_dma_flag.assign(mux_dma_flag_value)
            mux_dma_flags.append(mux_dma_flag)

        bt.bus_lock(self.maxi, fsm)

        for (act_rams_row, act_gaddr, act_page_dma_offset, dma_pad_mask,
             dma_flag) in zip(act_rams_2d, mux_act_gaddrs,
                              act_page_dma_offsets, mux_dma_pad_masks,
                              mux_dma_flags):
            act_laddr = act_page_dma_offset

            begin_state_read = fsm.current
            fsm.goto_next()

            if len(act_rams_row) == 1:
                bt.dma_read(self.maxi,
                            fsm,
                            act_rams_row[0],
                            act_laddr,
                            act_gaddr,
                            self.act_read_size,
                            port=1)
            else:
                bt.dma_read_block(self.maxi,
                                  fsm,
                                  act_rams_row,
                                  act_laddr,
                                  act_gaddr,
                                  self.act_read_size,
                                  self.act_read_block,
                                  port=1)

            end_state_read = fsm.current

            fsm.If(vg.Ors(dma_pad_mask,
                          vg.Not(dma_flag))).goto_from(begin_state_read,
                                                       end_state_read)

        bt.bus_unlock(self.maxi, fsm)

        fsm.goto_next()
        state_read_act_end = fsm.current
        fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        # Stream Control FSM
        comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

        comp_state_init = comp_fsm.current
        comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

        fsm.If(comp_fsm.state == comp_state_init).goto_next()

        # local address
        stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)
        for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
            for x, stream_act_local in enumerate(stream_act_locals_row):
                comp_fsm(stream_act_local(0))
                comp_fsm.If(self.stream_act_local_small_flags[x])(
                    stream_act_local(self.stream_act_local_small_offset))
                comp_fsm.If(self.stream_act_local_large_flags[x])(
                    stream_act_local(self.stream_act_local_large_offset))

        comp_fsm(stream_out_local(0))

        # count and sel
        comp_fsm(col_count(0))
        comp_fsm(col_select(self.col_select_initval))

        act_page_comp_offset_bufs = [
            self.m.Reg(self._name('act_page_comp_offset_buf_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i in range(ksize_row)
        ]
        out_page_comp_offset_buf = self.m.Reg(
            self._name('out_page_comp_offset_buf'),
            self.maxi.addrwidth,
            initval=0)
        row_count_buf = self.m.Reg(self._name('row_count_buf'),
                                   self.maxi.addrwidth,
                                   initval=0)
        row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                    bt.log_width(ksize_row),
                                    initval=0)
        comp_fsm([
            act_page_comp_offset_buf(act_page_comp_offset)
            for act_page_comp_offset_buf, act_page_comp_offset in zip(
                act_page_comp_offset_bufs, act_page_comp_offsets)
        ], out_page_comp_offset_buf(out_page_comp_offset),
                 row_count_buf(row_count), row_select_buf(row_select))

        comp_fsm.goto_next()

        # repeat comp
        comp_state_rep = comp_fsm.current

        # pad_mask
        stream_pad_masks = []

        for y in range(ksize_row):
            for x in range(ksize_col):
                stream_col_count = col_count + x
                stream_row_count = row_count_buf + y
                v = vg.Ors(
                    (stream_col_count < self.pad_col_left),
                    (stream_col_count >= self.act_num_col + self.pad_col_left),
                    (stream_row_count < self.pad_row_top),
                    (stream_row_count >= self.act_num_row + self.pad_row_top))
                stream_pad_mask = self.m.Wire(
                    self._name('stream_pad_mask_%d_%d' % (y, x)))
                stream_pad_mask.assign(v)
                stream_pad_masks.append(stream_pad_mask)

        stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
        stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d, col_select,
                                        row_select_buf, ksize_col, ksize_row)
        stream_pad_masks = [
            flatten for inner in stream_pad_mask_2d_mux for flatten in inner
        ]

        stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                          len(stream_pad_masks),
                                          initval=0)
        comp_fsm(stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks))))
        comp_fsm.goto_next()

        # busy check
        self.stream.source_join(comp_fsm)

        stream_masks = stream_pad_masks_reg

        # set_constant
        name = list(self.stream.constants.keys())[0]
        self.stream.set_constant(comp_fsm, name, stream_masks)
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_source
        act_page_comp_offset_bufs_dup = []
        for act_page_comp_offset_buf in act_page_comp_offset_bufs:
            act_page_comp_offset_bufs_dup.extend([act_page_comp_offset_buf] *
                                                 ksize_col)

        for name, ram, stream_act_local, act_page_comp_offset_buf in zip(
                self.stream.sources.keys(), act_rams, stream_act_locals,
                act_page_comp_offset_bufs_dup):
            local = stream_act_local + act_page_comp_offset_buf
            self.stream.set_source(comp_fsm, name, ram, local,
                                   self.stream_size)
            comp_fsm.set_index(comp_fsm.current - 1)

        # set_sink
        name = list(self.stream.sinks.keys())[0]
        local = stream_out_local + out_page_comp_offset_buf
        self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

        # stream run (async)
        self.stream.run(comp_fsm)

        # stream_act_local
        stream_act_locals_2d = line_to_2d(stream_act_locals, ksize_col)

        i = 0
        for y, stream_act_locals_row in enumerate(stream_act_locals_2d):
            for x, stream_act_local in enumerate(stream_act_locals_row):
                patterns = []
                for col in range(ksize_col):
                    val = self.inc_act_laddr_conds[i]
                    i += 1
                    pat = (col_select == col, val)
                    patterns.append(pat)

                patterns.append((None, 0))
                v = vg.PatternMux(*patterns)

                comp_fsm.If(vg.Not(v))(stream_act_local.add(
                    self.inc_act_laddr_small))
                comp_fsm.If(v)(stream_act_local.add(self.inc_act_laddr_large))

                comp_fsm.If(col_count >= self.max_col_count)(
                    stream_act_local(0))
                comp_fsm.If(col_count >= self.max_col_count,
                            self.stream_act_local_small_flags[x])(
                                stream_act_local(
                                    self.stream_act_local_small_offset))
                comp_fsm.If(col_count >= self.max_col_count,
                            self.stream_act_local_large_flags[x])(
                                stream_act_local(
                                    self.stream_act_local_large_offset))

        # stream_out_local
        comp_fsm(stream_out_local.add(self.inc_out_laddr))
        comp_fsm.If(col_count >= self.max_col_count)(stream_out_local(0))

        # counter
        comp_fsm(col_count.add(self.stride_col))
        comp_fsm.If(col_count >= self.max_col_count)(col_count(0))

        comp_fsm(col_select.add(self.stride_col_mod_ksize))
        comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
            col_select.sub(self.ksize_col_minus_stride_col_mod))

        comp_fsm.If(col_count >= self.max_col_count)(col_select(
            self.col_select_initval), )

        # repeat
        comp_fsm.goto(comp_state_rep)
        comp_fsm.If(col_count >= self.max_col_count).goto_init()

        # sync with WriteOut control
        comp_fsm.seq.If(fsm.state == state_init)(comp_count(0))
        comp_fsm.seq.If(self.stream.end_flag)(comp_count.add(
            self.inc_out_laddr))

        # --------------------
        # WriteOut phase
        # --------------------
        state_write_out = fsm.current

        # sync with Comp control
        fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

        out_laddr = out_page_dma_offset
        out_gaddr = self.objaddr + out_offset

        bt.bus_lock(self.maxi, fsm)

        bt.dma_write(self.maxi,
                     fsm,
                     out_ram,
                     out_laddr,
                     out_gaddr,
                     self.out_write_size,
                     port=1,
                     use_async=True)

        bt.bus_unlock(self.maxi, fsm)

        fsm(out_count.add(self.out_write_size))

        fsm.goto_next()

        state_write_out_end = fsm.current
        fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

        # --------------------
        # update for next iteration
        # --------------------
        # ReadAct: offset
        fsm(act_base_offset_row.add(self.act_row_step))
        fsm.If(row_count >= self.max_row_count)(act_base_offset_row(0),
                                                act_base_offset_bat.add(
                                                    self.act_bat_step))
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(act_base_offset_bat(0))

        # ReadAct: DMA flag
        next_dma_flags = []
        for dma_flag, dma_flag_cond in zip(dma_flags, self.dma_flag_conds):
            fsm(dma_flag(dma_flag_cond))

            fsm.If(row_count >= self.max_row_count)(dma_flag(1))

            next_dma_flags.append(
                vg.Mux(row_count >= self.max_row_count, 1, dma_flag_cond))

        # ReadAct: counter
        fsm(row_count.add(self.stride_row))
        fsm.If(row_count >= self.max_row_count)(row_count(0),
                                                bat_count.add(self.stride_bat))
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(bat_count(0))

        fsm.If(self.stride_row < ksize_row)(row_select.add(self.stride_row),
                                            prev_row_select(row_select))
        fsm.If(self.stride_row < ksize_row,
               row_select + self.stride_row >= ksize_row)(
                   row_select(row_select -
                              (vg.Int(ksize_row) - self.stride_row)),
                   prev_row_select(row_select))
        fsm.If(vg.Not(self.stride_row < ksize_row))(row_select(0),
                                                    prev_row_select(0))

        fsm.If(row_count >= self.max_row_count)(row_select(0),
                                                prev_row_select(0))

        # ReadAct and Comp: double buffer
        mux_next_dma_flag_values = mux_1d(next_dma_flags, row_select,
                                          ksize_row)
        mux_next_dma_flags = []
        for i, mux_next_dma_flag_value in enumerate(mux_next_dma_flag_values):
            mux_next_dma_flag = self.m.Wire(
                self._name('mux_next_dma_flag_%d' % i))
            mux_next_dma_flag.assign(mux_next_dma_flag_value)
            mux_next_dma_flags.append(mux_next_dma_flag)

        for (act_page, act_page_comp_offset, act_page_dma_offset,
             mux_next_dma_flag) in zip(act_pages, act_page_comp_offsets,
                                       act_page_dma_offsets,
                                       mux_next_dma_flags):

            fsm.If(vg.Not(act_page),
                   mux_next_dma_flag)(act_page_comp_offset(act_page_size),
                                      act_page_dma_offset(act_page_size),
                                      act_page(1))
            fsm.If(act_page, mux_next_dma_flag)(act_page_comp_offset(0),
                                                act_page_dma_offset(0),
                                                act_page(0))

        # WriteOut: counter
        fsm.If(vg.Not(skip_write_out))(out_base_offset_row.add(
            self.out_row_step))
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count)(
            out_base_offset_row(0), out_base_offset_bat.add(self.out_bat_step))
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count)(out_base_offset_bat(0))

        # WriteOut and Comp: double buffer
        fsm.If(vg.Not(out_page))(out_page_comp_offset(out_page_size),
                                 out_page_dma_offset(0), out_page(1))
        fsm.If(out_page)(out_page_comp_offset(0),
                         out_page_dma_offset(out_page_size), out_page(0))

        # ReadAct and WriteOut: prev
        fsm(prev_row_count(row_count), prev_bat_count(bat_count))

        # ReadAct, Comp, WriteOut: skip
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(skip_read_act(1))

        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(skip_comp(1))

        fsm.If(skip_write_out, prev_row_count == 0,
               prev_bat_count == 0)(skip_write_out(0))

        fsm.goto(state_read_act)
        fsm.If(vg.Not(skip_write_out), prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count).goto_next()

        # wait for last DMA write
        bt.dma_wait_write(self.maxi, fsm)
Esempio n. 5
0
    def control_sequence(self, fsm):
        """Build the control state machine for this operator on ``fsm``.

        The sequence appended to ``fsm`` consists of the following phases,
        in this order:

        1. initialization: reset offsets, counters, page selectors and
           skip flags.
        2. ReadAct: one DMA read per kernel row from ``self.arg_objaddrs[0]``
           into the activation RAM (rows flagged by the padding mask are
           bypassed).
        3. Comp: a second FSM (``comp_fsm``) created here configures and
           runs ``self.stream`` once per column position.
        4. WriteOut: asynchronous DMA write of the output RAM page to
           ``self.objaddr``.
        5. update: advance offsets/counters, toggle the double-buffer
           pages, update the skip flags and loop back to ReadAct until
           the last output has been written.

        Side effects: sets ``self.stride_bat`` to 1 and declares registers
        and wires on ``self.m``.

        Args:
            fsm: FSM object that states and transitions are appended to.
                It is used through the same interface as the ``vg.FSM``
                created below for the Comp phase.
        """
        # The trailing four entries of self.ksize are read as
        # (bat, row, col, ch). ksize_ch and ksize_bat are not referenced
        # again in this method.
        ksize_ch = self.ksize[-1]
        ksize_col = self.ksize[-2]
        ksize_row = self.ksize[-3]
        ksize_bat = self.ksize[-4]

        # The batch counter below always advances by this fixed step.
        self.stride_bat = 1

        act_ram = self.input_rams[0]
        out_ram = self.output_rams[0]

        # Activation base offset = row component + batch component.
        act_base_offset = self.m.Wire(self._name('act_base_offset'),
                                      self.maxi.addrwidth, signed=True)
        act_base_offset_row = self.m.Reg(self._name('act_base_offset_row'),
                                         self.maxi.addrwidth, initval=0, signed=True)
        act_base_offset_bat = self.m.Reg(self._name('act_base_offset_bat'),
                                         self.maxi.addrwidth, initval=0, signed=True)

        act_base_offset.assign(act_base_offset_row
                               + act_base_offset_bat)

        # Output base offset = row component + batch component.
        out_base_offset = self.m.Wire(self._name('out_base_offset'),
                                      self.maxi.addrwidth, signed=True)
        out_base_offset_row = self.m.Reg(self._name('out_base_offset_row'),
                                         self.maxi.addrwidth, initval=0, signed=True)
        out_base_offset_bat = self.m.Reg(self._name('out_base_offset_bat'),
                                         self.maxi.addrwidth, initval=0, signed=True)

        out_base_offset.assign(out_base_offset_row
                               + out_base_offset_bat)

        # Position counters: column (driven by comp_fsm), row and batch
        # (driven by the main fsm).
        col_count = self.m.Reg(self._name('col_count'),
                               self.maxi.addrwidth, initval=0)
        row_count = self.m.Reg(self._name('row_count'),
                               self.maxi.addrwidth, initval=0)
        bat_count = self.m.Reg(self._name('bat_count'),
                               self.maxi.addrwidth, initval=0)

        # The select registers exist only when self.no_reuse is false;
        # every later use of them is guarded by the same condition.
        if not self.no_reuse:
            col_select = self.m.Reg(self._name('col_select'),
                                    bt.log_width(ksize_col),
                                    initval=0)
            row_select = self.m.Reg(self._name('row_select'),
                                    bt.log_width(ksize_row),
                                    initval=0)

        # Counter values of the previous iteration; the WriteOut-related
        # updates and the loop exit condition are evaluated on these.
        prev_row_count = self.m.Reg(self._name('prev_row_count'),
                                    self.maxi.addrwidth, initval=0)
        prev_bat_count = self.m.Reg(self._name('prev_bat_count'),
                                    self.maxi.addrwidth, initval=0)

        if not self.no_reuse:
            prev_row_select = self.m.Reg(self._name('prev_row_select'),
                                         bt.log_width(ksize_row),
                                         initval=0)

        # Local (RAM) addresses used by the stream source and sink.
        stream_act_local = self.m.Reg(self._name('stream_act_local'),
                                      self.maxi.addrwidth, initval=0)
        stream_out_local = self.m.Reg(self._name('stream_out_local'),
                                      self.maxi.addrwidth, initval=0)

        # double buffer control
        # Each RAM has a 1-bit page selector plus separate local offsets
        # for the Comp side and the DMA side.
        act_page = self.m.Reg(self._name('act_page'), initval=0)
        act_page_comp_offset = self.m.Reg(self._name('act_page_comp_offset'),
                                          self.maxi.addrwidth, initval=0)
        act_page_dma_offset = self.m.Reg(self._name('act_page_dma_offset'),
                                         self.maxi.addrwidth, initval=0)

        out_page = self.m.Reg(self._name('out_page'), initval=0)
        out_page_comp_offset = self.m.Reg(self._name('out_page_comp_offset'),
                                          self.maxi.addrwidth, initval=0)
        out_page_dma_offset = self.m.Reg(self._name('out_page_dma_offset'),
                                         self.maxi.addrwidth, initval=0)

        # A page is half of the corresponding RAM (Python-level integer
        # division of the RAM length).
        act_page_size = act_ram.length // 2
        out_page_size = out_ram.length // 2

        # When set, the corresponding phase is bypassed in an iteration.
        skip_read_act = self.m.Reg(self._name('skip_read_act'), initval=0)
        skip_comp = self.m.Reg(self._name('skip_comp'), initval=0)
        skip_write_out = self.m.Reg(self._name('skip_write_out'), initval=0)

        # comp_count accumulates what Comp has produced and out_count what
        # WriteOut has consumed; WriteOut waits on their difference.
        comp_count = self.m.Reg(self._name('comp_count'),
                                self.maxi.addrwidth, initval=0)
        out_count = self.m.Reg(self._name('out_count'),
                               self.maxi.addrwidth, initval=0)

        # --------------------
        # initialization phase
        # --------------------
        # All fsm(...) calls up to the first goto_next() add assignments
        # to the same (initial) state.
        # ReadAct: offset
        fsm(
            act_base_offset_row(0),
            act_base_offset_bat(0)
        )

        # One global offset expression per entry of act_offset_values.
        act_offsets = []
        for v in self.act_offset_values:
            act_offset = act_base_offset + v
            act_offsets.append(act_offset)

        # ReadAct: DMA flag
        # dma_pad_mask_y is asserted when input row (row_count + y) lies
        # outside [pad_row_top, act_num_row + pad_row_top); the DMA read
        # for such a row is bypassed in the ReadAct phase.
        dma_pad_masks = []

        for y in range(ksize_row):
            v = vg.Ors((row_count + y < self.pad_row_top),
                       (row_count + y >= self.act_num_row + self.pad_row_top))
            dma_pad_mask = self.m.Wire(
                self._name('dma_pad_mask_%d' % y))
            dma_pad_mask.assign(v)
            dma_pad_masks.append(dma_pad_mask)

        # ReadAct: double buffer control
        fsm(
            act_page(0),
            act_page_comp_offset(0),
            act_page_dma_offset(0)
        )

        # WriteOutput: offset
        fsm(
            out_base_offset_row(0),
            out_base_offset_bat(0)
        )

        out_offset = out_base_offset

        # WriteOut: double buffer control
        fsm(
            out_page(0),
            out_page_comp_offset(0),
            out_page_dma_offset(0)
        )

        # counter
        fsm(
            row_count(0),
            bat_count(0),
            prev_row_count(0),
            prev_bat_count(0)
        )

        if not self.no_reuse:
            fsm(
                row_select(0),
                prev_row_select(0)
            )

        # double buffer control
        # skip_write_out starts at 1, so WriteOut is bypassed in the first
        # iteration (see the goto_from at the end of the WriteOut phase).
        fsm(
            skip_read_act(0),
            skip_comp(0),
            skip_write_out(1)
        )

        fsm(
            out_count(0)
        )

        # Remembered so that comp_count can be cleared while the main fsm
        # is in this state (see "sync with WriteOut control").
        state_init = fsm.current

        fsm.goto_next()

        # --------------------
        # ReadAct phase
        # --------------------
        state_read_act = fsm.current

        # Global addresses: base address of the first argument plus each
        # activation offset.
        act_gaddrs = []
        for act_offset in act_offsets:
            act_gaddr = self.arg_objaddrs[0] + act_offset
            act_gaddrs.append(act_gaddr)

        if not self.no_reuse:
            # With reuse enabled, addresses and pad masks are passed
            # through mux_1d keyed by row_select and latched into wires.
            # NOTE(review): mux_1d is defined outside this block;
            # presumably it rotates the list by row_select -- confirm.
            mux_act_gaddr_values = mux_1d(act_gaddrs, row_select, ksize_row)
            mux_act_gaddrs = []
            for i, mux_act_gaddr_value in enumerate(mux_act_gaddr_values):
                mux_act_gaddr = self.m.Wire(self._name('mux_act_gaddr_%d' % i),
                                            self.maxi.addrwidth)
                mux_act_gaddr.assign(mux_act_gaddr_value)
                mux_act_gaddrs.append(mux_act_gaddr)

            mux_dma_pad_mask_values = mux_1d(
                dma_pad_masks, row_select, ksize_row)
            mux_dma_pad_masks = []
            for i, mux_dma_pad_mask_value in enumerate(mux_dma_pad_mask_values):
                mux_dma_pad_mask = self.m.Wire(
                    self._name('mux_dma_pad_mask_%d' % i))
                mux_dma_pad_mask.assign(mux_dma_pad_mask_value)
                mux_dma_pad_masks.append(mux_dma_pad_mask)

        else:
            # Without reuse the unmuxed lists are used directly.
            mux_act_gaddrs = act_gaddrs
            mux_dma_pad_masks = dma_pad_masks

        bt.bus_lock(self.maxi, fsm)

        # Local destination of the first row: start of the current DMA page.
        act_laddr = act_page_dma_offset

        # One DMA read per (address, pad mask) pair. If the pad mask is
        # set, the fsm jumps from the state before the read straight to
        # the state after it.
        for (act_gaddr, dma_pad_mask) in zip(
                mux_act_gaddrs, mux_dma_pad_masks):
            begin_state_read = fsm.current
            fsm.goto_next()

            bt.dma_read(self.maxi, fsm, act_ram, act_laddr,
                        act_gaddr, self.act_read_size, port=1)

            end_state_read = fsm.current

            fsm.If(dma_pad_mask).goto_from(
                begin_state_read, end_state_read)

            # The next row is placed act_read_size further into the RAM.
            # NOTE(review): assumes `+=` on the register yields a new
            # expression rather than mutating act_page_dma_offset -- confirm.
            act_laddr += self.act_read_size

        bt.bus_unlock(self.maxi, fsm)

        fsm.goto_next()
        state_read_act_end = fsm.current
        # Bypass the whole ReadAct phase when skip_read_act is set.
        fsm.If(skip_read_act).goto_from(state_read_act, state_read_act_end)

        # --------------------
        # Comp phase
        # --------------------
        state_comp = fsm.current

        # Stream Control FSM
        comp_fsm = vg.FSM(self.m, self._name('comp_fsm'), self.clk, self.rst)

        # comp_fsm leaves its initial state when the main fsm is in
        # state_comp and skip_comp is not set.
        comp_state_init = comp_fsm.current
        comp_fsm.If(fsm.state == state_comp, vg.Not(skip_comp)).goto_next()

        # The main fsm leaves state_comp only while comp_fsm is in its
        # initial state.
        fsm.If(comp_fsm.state == comp_state_init).goto_next()

        # The assignments below, up to comp_fsm.goto_next(), are added to
        # comp_fsm's initial state.
        # local address
        comp_fsm(
            stream_act_local(self.local_pad_offset)
        )

        comp_fsm(
            stream_out_local(0)
        )

        # count and sel
        comp_fsm(
            col_count(0)
        )

        if not self.no_reuse:
            comp_fsm(
                col_select(self.col_select_initval)
            )

        # Copies of main-fsm registers taken in comp_fsm's initial state;
        # the rest of comp_fsm reads these copies instead of the originals.
        act_page_comp_offset_buf = self.m.Reg(self._name('act_page_comp_offset_buf'),
                                              self.maxi.addrwidth, initval=0)
        out_page_comp_offset_buf = self.m.Reg(self._name('out_page_comp_offset_buf'),
                                              self.maxi.addrwidth, initval=0)
        row_count_buf = self.m.Reg(self._name('row_count_buf'),
                                   self.maxi.addrwidth, initval=0)

        if not self.no_reuse:
            row_select_buf = self.m.Reg(self._name('row_select_buf'),
                                        bt.log_width(ksize_row),
                                        initval=0)
        comp_fsm(
            act_page_comp_offset_buf(act_page_comp_offset),
            out_page_comp_offset_buf(out_page_comp_offset),
            row_count_buf(row_count)
        )

        if not self.no_reuse:
            comp_fsm(
                row_select_buf(row_select)
            )

        comp_fsm.goto_next()

        # repeat comp
        # comp_fsm returns to this state for every column position.
        comp_state_rep = comp_fsm.current

        # pad_mask
        # One mask bit per kernel element (y, x): asserted when the
        # element's column or row index falls outside the unpadded input.
        stream_pad_masks = []

        for y in range(ksize_row):
            for x in range(ksize_col):
                stream_col_count = col_count + x
                stream_row_count = row_count_buf + y
                v = vg.Ors((stream_col_count < self.pad_col_left),
                           (stream_col_count >= self.act_num_col + self.pad_col_left),
                           (stream_row_count < self.pad_row_top),
                           (stream_row_count >= self.act_num_row + self.pad_row_top))
                stream_pad_mask = self.m.Wire(
                    self._name('stream_pad_mask_%d_%d' % (y, x)))
                stream_pad_mask.assign(v)
                stream_pad_masks.append(stream_pad_mask)

        if not self.no_reuse:
            # With reuse enabled the masks are passed through mux_2d keyed
            # by col_select / row_select_buf and flattened again.
            # NOTE(review): mux_2d / line_to_2d are defined outside this
            # block; exact reordering not visible here.
            stream_pad_mask_2d = line_to_2d(stream_pad_masks, ksize_col)
            stream_pad_mask_2d_mux = mux_2d(stream_pad_mask_2d,
                                            col_select, row_select_buf,
                                            ksize_col, ksize_row)
            stream_pad_masks = [flatten for inner in stream_pad_mask_2d_mux
                                for flatten in inner]

        # All mask bits are registered as one vector. The list is reversed
        # before concatenation; presumably this places mask 0 in the
        # least-significant bit (vg.Cat ordering -- TODO confirm).
        stream_pad_masks_reg = self.m.Reg(self._name('stream_pad_masks'),
                                          len(stream_pad_masks), initval=0)
        comp_fsm(
            stream_pad_masks_reg(vg.Cat(*reversed(stream_pad_masks)))
        )
        comp_fsm.goto_next()

        # busy check
        self.stream.source_join(comp_fsm)

        stream_masks = stream_pad_masks_reg

        # set_constant
        # After each set_* call the build index of comp_fsm is moved back
        # by one; presumably so that the following set_* call lands in the
        # same state -- TODO confirm against the stream API.
        # First stream constant: number of kernel elements (col * row).
        name = list(self.stream.constants.keys())[0]
        self.stream.set_constant(comp_fsm, name, ksize_col * ksize_row)
        comp_fsm.set_index(comp_fsm.current - 1)

        # Second stream constant: the registered pad-mask vector.
        name = list(self.stream.constants.keys())[1]
        self.stream.set_constant(comp_fsm, name, stream_masks)
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_source
        # The source starts at stream_act_local within the Comp-side page.
        # NOTE(review): each tuple of `pat` is presumably (count, stride),
        # listed from the outermost to the innermost loop -- confirm
        # against set_source_pattern.
        name = list(self.stream.sources.keys())[0]
        local = stream_act_local + act_page_comp_offset_buf
        pat = ((ksize_col, self.act_read_block),
               (ksize_row, self.act_read_size),
               (self.stream_size, 1))
        self.stream.set_source_pattern(comp_fsm, name, act_ram,
                                       local, pat)
        comp_fsm.set_index(comp_fsm.current - 1)

        # set_sink
        # No index rewind after this call, unlike the set_* calls above.
        name = list(self.stream.sinks.keys())[0]
        local = stream_out_local + out_page_comp_offset_buf
        self.stream.set_sink(comp_fsm, name, out_ram, local, self.stream_size)

        # stream run (async)
        self.stream.run(comp_fsm)

        # In each pair below, an unconditional update is followed by a
        # conditional reset for the last column
        # (col_count >= max_col_count); the conditional assignment is
        # expected to take precedence when both apply.
        # stream_act_local
        comp_fsm(
            stream_act_local.add(self.inc_act_laddr)
        )
        comp_fsm.If(col_count >= self.max_col_count)(
            stream_act_local(self.local_pad_offset)
        )

        # stream_out_local
        comp_fsm(
            stream_out_local.add(self.inc_out_laddr)
        )
        comp_fsm.If(col_count >= self.max_col_count)(
            stream_out_local(0)
        )

        # counter
        comp_fsm(
            col_count.add(self.stride_col)
        )
        comp_fsm.If(col_count >= self.max_col_count)(
            col_count(0)
        )

        if not self.no_reuse:
            # col_select advances by stride_col_mod_ksize and wraps by
            # subtracting ksize_col_minus_stride_col_mod once the sum
            # reaches ksize_col.
            comp_fsm(
                col_select.add(self.stride_col_mod_ksize)
            )
            comp_fsm.If(col_select + self.stride_col_mod_ksize >= ksize_col)(
                col_select.sub(self.ksize_col_minus_stride_col_mod)
            )

            comp_fsm.If(col_count >= self.max_col_count)(
                col_select(self.col_select_initval)
            )

        # repeat
        # Loop to comp_state_rep, or return to the initial state after the
        # last column.
        comp_fsm.goto(comp_state_rep)
        comp_fsm.If(col_count >= self.max_col_count).goto_init()

        # sync with WriteOut control
        # These are attached to comp_fsm.seq rather than to a comp_fsm
        # state: comp_count is cleared while the main fsm is in state_init
        # and grows by inc_out_laddr whenever the stream reports end_flag.
        comp_fsm.seq.If(fsm.state == state_init)(
            comp_count(0)
        )
        comp_fsm.seq.If(self.stream.end_flag)(
            comp_count.add(self.inc_out_laddr)
        )

        # --------------------
        # WriteOut phase
        # --------------------
        state_write_out = fsm.current

        # sync with Comp control
        # Wait until Comp has produced at least one more out_write_size
        # block than has been accounted for in out_count.
        fsm.If(comp_count >= out_count + self.out_write_size).goto_next()

        out_laddr = out_page_dma_offset
        out_gaddr = self.objaddr + out_offset

        bt.bus_lock(self.maxi, fsm)

        # Issued with use_async=True; completion is awaited only by the
        # dma_wait_write call at the very end of the sequence.
        bt.dma_write(self.maxi, fsm, out_ram, out_laddr,
                     out_gaddr, self.out_write_size, port=1, use_async=True)

        bt.bus_unlock(self.maxi, fsm)

        fsm(
            out_count.add(self.out_write_size)
        )

        fsm.goto_next()

        state_write_out_end = fsm.current
        # Bypass the whole WriteOut phase when skip_write_out is set.
        fsm.If(skip_write_out).goto_from(state_write_out, state_write_out_end)

        # --------------------
        # update for next iteration
        # --------------------
        # All assignments below are added to the same state. Each group
        # lists the default update first and the wrap-around cases after
        # it; the later conditional assignments are expected to override
        # the earlier ones when their conditions hold.
        # ReadAct: offset
        fsm(
            act_base_offset_row.add(self.act_row_step)
        )
        fsm.If(row_count >= self.max_row_count)(
            act_base_offset_row(0),
            act_base_offset_bat.add(self.act_bat_step)
        )
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(
            act_base_offset_bat(0)
        )

        # ReadAct: counter
        fsm(
            row_count.add(self.stride_row)
        )
        fsm.If(row_count >= self.max_row_count)(
            row_count(0),
            bat_count.add(self.stride_bat)
        )
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(
            bat_count(0)
        )

        if not self.no_reuse:
            # row_select advances by stride_row modulo ksize_row while
            # stride_row < ksize_row; otherwise it is held at 0. It is
            # also reset at the end of a row sweep. prev_row_select takes
            # the old row_select value (or 0 in the reset cases).
            fsm.If(self.stride_row < ksize_row)(
                row_select.add(self.stride_row),
                prev_row_select(row_select)
            )
            fsm.If(self.stride_row < ksize_row,
                   row_select + self.stride_row >= ksize_row)(
                row_select(row_select - (vg.Int(ksize_row) - self.stride_row)),
                prev_row_select(row_select)
            )
            fsm.If(vg.Not(self.stride_row < ksize_row))(
                row_select(0),
                prev_row_select(0)
            )

            fsm.If(row_count >= self.max_row_count)(
                row_select(0),
                prev_row_select(0)
            )

        # ReadAct and Comp: double buffer
        # Toggle the activation page; the Comp and DMA offsets are both
        # set to the same (new) page.
        fsm.If(vg.Not(act_page))(
            act_page_comp_offset(act_page_size),
            act_page_dma_offset(act_page_size),
            act_page(1)
        )
        fsm.If(act_page)(
            act_page_comp_offset(0),
            act_page_dma_offset(0),
            act_page(0)
        )

        # WriteOut: counter
        # The output offsets advance only when WriteOut was not skipped,
        # and wrap based on the previous iteration's counters.
        fsm.If(vg.Not(skip_write_out))(
            out_base_offset_row.add(self.out_row_step)
        )
        fsm.If(vg.Not(skip_write_out),
               prev_row_count >= self.max_row_count)(
            out_base_offset_row(0),
            out_base_offset_bat.add(self.out_bat_step)
        )
        fsm.If(vg.Not(skip_write_out),
               prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count)(
            out_base_offset_bat(0)
        )

        # WriteOut and Comp: double buffer
        # Toggle the output page; the Comp and DMA offsets are set to
        # opposite pages.
        fsm.If(vg.Not(out_page))(
            out_page_comp_offset(out_page_size),
            out_page_dma_offset(0),
            out_page(1)
        )
        fsm.If(out_page)(
            out_page_comp_offset(0),
            out_page_dma_offset(out_page_size),
            out_page(0)
        )

        # ReadAct and WriteOut: prev
        fsm(
            prev_row_count(row_count),
            prev_bat_count(bat_count)
        )

        # ReadAct, Comp, WriteOut: skip
        # After the last row of the last batch has been processed, ReadAct
        # and Comp are skipped in the following iteration.
        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(
            skip_read_act(1)
        )

        fsm.If(row_count >= self.max_row_count,
               bat_count >= self.max_bat_count)(
            skip_comp(1)
        )

        # skip_write_out is cleared when it is set and both prev counters
        # are 0, which enables WriteOut from the next iteration on.
        fsm.If(skip_write_out,
               prev_row_count == 0,
               prev_bat_count == 0)(
            skip_write_out(0)
        )

        # Loop back to ReadAct by default; leave the loop once WriteOut
        # was active and the previous iteration was the last row of the
        # last batch.
        fsm.goto(state_read_act)
        fsm.If(vg.Not(skip_write_out),
               prev_row_count >= self.max_row_count,
               prev_bat_count >= self.max_bat_count).goto_next()

        # wait for last DMA write
        bt.dma_wait_write(self.maxi, fsm)
Esempio n. 6
0
    def control_sequence(self, fsm):
        """Append the read / copy / write control states to ``fsm``.

        One iteration of the generated state machine:

        1. Read: DMA-read one chunk of the argument chosen by ``arg_select``
           into ``self.input_rams[0]``.
        2. Copy: run ``self.stream`` from ``self.input_rams[0]`` into
           ``self.output_rams[0]``; repeated (via a jump back to Read) until
           ``copy_laddr`` reaches the sum of all ``self.arg_read_sizes``.
           This phase is bypassed when ``self.buffered`` is false.
        3. Write: DMA-write either ``self.output_rams[0]`` (buffered) or
           ``self.input_rams[0]`` (unbuffered) to ``self.objaddr + out_gaddr``.

        The iteration repeats until ``out_count`` equals ``self.num_steps``.

        Args:
            fsm: state-machine builder the states are appended to
                (presumably a veriloggen FSM -- TODO confirm against caller).
        """
        # Per-argument global address offsets; each one is added to the
        # matching arg_objaddr to form the DMA read address.
        arg_gaddrs = [
            self.m.Reg(self._name('arg_gaddr_%d' % i),
                       self.maxi.addrwidth,
                       initval=0) for i, _ in enumerate(self.arg_objaddrs)
        ]
        # Global address offset of the output, added to self.objaddr.
        out_gaddr = self.m.Reg(self._name('out_gaddr'),
                               self.maxi.addrwidth,
                               initval=0)

        # Local address passed to dma_read for input_rams[0].
        arg_laddr = self.m.Reg(self._name('arg_laddr'),
                               self.maxi.addrwidth,
                               initval=0)
        # Local address passed to the stream sink for output_rams[0].
        copy_laddr = self.m.Reg(self._name('copy_laddr'),
                                self.maxi.addrwidth,
                                initval=0)
        # Size of the most recent read; reused as the stream length and as
        # the unbuffered write size.
        copy_size = self.m.Reg(self._name('copy_size'),
                               self.maxi.addrwidth,
                               initval=0)
        # Sum of all read sizes. The width is the widest operand plus
        # ceil(log2(number of operands)) extra bits, enough to hold the sum.
        sum_read_sizes = self.m.Wire(
            self._name('sum_read_sizes'),
            max([i.width for i in self.arg_read_sizes]) +
            int(math.ceil(math.log2(len(self.arg_read_sizes)))))
        sum_read_sizes.assign(vg.Add(*self.arg_read_sizes))

        # Holds the arg_addr_inc of the argument read last; used as the
        # output address increment in the unbuffered write path.
        out_addr_inc_unbuffered = self.m.Reg(
            self._name('out_addr_inc_unbuffered'),
            self.maxi.addrwidth,
            initval=0)

        # Index of the argument to read next; at least 1 bit wide.
        arg_select = self.m.Reg(self._name('arg_select'),
                                int(
                                    max(math.ceil(math.log(len(self.args), 2)),
                                        1)),
                                initval=0)
        # Assigned from arg_select in the read phase and fed to the stream
        # as its constant in the copy phase.
        prev_arg_select = self.m.Reg(
            self._name('prev_arg_select'),
            int(max(math.ceil(math.log(len(self.args), 2)), 1)),
            initval=0)
        # Number of reads issued for the currently selected argument.
        arg_chunk_count = self.m.Reg(self._name('arg_chunk_count'),
                                     self.maxi.addrwidth + 1,
                                     initval=0)
        # Number of completed writes; compared against self.num_steps.
        out_count = self.m.Reg(self._name('out_count'),
                               self.maxi.addrwidth + 1,
                               initval=0)

        # --------------------
        # initialization phase
        # --------------------
        # Clear every register declared above before the first iteration.
        fsm([arg_gaddr(0) for arg_gaddr in arg_gaddrs], out_gaddr(0),
            arg_laddr(0), copy_laddr(0), copy_size(0),
            out_addr_inc_unbuffered(0), arg_select(0), prev_arg_select(0),
            arg_chunk_count(0), out_count(0))

        fsm.goto_next()

        # --------------------
        # Read phase
        # --------------------
        # state_read is a dispatch state: its outgoing transitions are
        # attached below with goto_from once every per-argument entry state
        # is known.
        state_read = fsm.current

        fsm.inc()

        # Entry and exit state of the read sequence built for each argument.
        state_read_begin_list = []
        state_read_end_list = []

        for (arg, arg_objaddr, arg_gaddr, arg_addr_inc, arg_read_size,
             arg_chunk_size) in zip(self.args, self.arg_objaddrs, arg_gaddrs,
                                    self.arg_addr_incs, self.arg_read_sizes,
                                    self.arg_chunk_sizes):

            b = fsm.current
            state_read_begin_list.append(b)

            # normal
            laddr = arg_laddr
            gaddr = arg_objaddr + arg_gaddr

            # DMA read of arg_read_size into input_rams[0], bracketed by
            # bus lock / unlock.
            bt.bus_lock(self.maxi, fsm)
            bt.dma_read(self.maxi, fsm, self.input_rams[0], laddr, gaddr,
                        arg_read_size)
            bt.bus_unlock(self.maxi, fsm)

            # Advance this argument's addresses and chunk counter, and
            # record the size / address increment for the later phases.
            fsm(arg_gaddr.add(arg_addr_inc), arg_laddr.add(arg_read_size),
                arg_chunk_count.inc(), copy_size(arg_read_size),
                out_addr_inc_unbuffered(arg_addr_inc))
            # Last chunk of this argument: restart the chunk counter and
            # move on to the next argument ...
            fsm.If(arg_chunk_count == arg_chunk_size - 1)(arg_chunk_count(0),
                                                          arg_select.inc())
            # ... wrapping back to argument 0 after the last one.
            fsm.If(arg_chunk_count == arg_chunk_size - 1,
                   arg_select == len(self.args) - 1)(arg_select(0))
            # NOTE(review): issued in the same state as the arg_select
            # update, so this presumably captures the pre-update value
            # (the argument just read) -- confirm assignment semantics.
            fsm(prev_arg_select(arg_select))

            e = fsm.current
            state_read_end_list.append(e)

            fsm.inc()

        state_read_end = fsm.current

        # From the dispatch state, enter the read sequence selected by
        # arg_select.
        for i, b in enumerate(state_read_begin_list):
            fsm.If(arg_select == i).goto_from(state_read, b)

        # Every read sequence converges on the common end state.
        for i, e in enumerate(state_read_end_list):
            fsm.goto_from(e, state_read_end)

        # --------------------
        # Copy phase
        # --------------------
        state_copy = fsm.current

        # Stream source: copy_size words of input_rams[0] from address 0.
        name = list(self.stream.sources.keys())[0]
        self.stream.set_source(fsm, name, self.input_rams[0], 0, copy_size)
        # Step the FSM index back by one; presumably so the following set_*
        # call is placed in the same state -- TODO confirm.
        fsm.set_index(fsm.current - 1)

        # Stream constant: the value held in prev_arg_select.
        name = list(self.stream.constants.keys())[0]
        self.stream.set_constant(fsm, name, prev_arg_select)
        fsm.set_index(fsm.current - 1)

        # Stream sink: copy_size words of output_rams[0] at copy_laddr.
        name = list(self.stream.sinks.keys())[0]
        self.stream.set_sink(fsm, name, self.output_rams[0], copy_laddr,
                             copy_size)
        self.stream.run(fsm)
        self.stream.join(fsm)

        # Rewind the read buffer address and advance the sink address past
        # the data just copied.
        fsm(arg_laddr(0), copy_laddr.add(copy_size))
        fsm.goto_next()

        # Return to the read phase until copy_laddr has reached the sum of
        # all read sizes; then continue to the write phase.
        fsm.If(copy_laddr < sum_read_sizes).goto(state_read)
        fsm.If(copy_laddr >= sum_read_sizes).goto_next()

        state_copy_end = fsm.current
        # When not buffered, jump from the start of the copy phase straight
        # to its end.
        fsm.If(vg.Not(self.buffered)).goto_from(state_copy, state_copy_end)

        # --------------------
        # Write phase
        # --------------------
        # state_write is a dispatch state; its transitions to the buffered /
        # unbuffered branches are attached below with goto_from.
        state_write = fsm.current
        fsm.inc()

        # Case with Copy
        state_write_buffered = fsm.current

        # Write self.out_write_size words of output_rams[0] from address 0.
        laddr = 0
        gaddr = self.objaddr + out_gaddr
        bt.bus_lock(self.maxi, fsm)
        bt.dma_write(self.maxi, fsm, self.output_rams[0], laddr, gaddr,
                     self.out_write_size)
        bt.bus_unlock(self.maxi, fsm)

        # Rewind the sink address, advance the output address, count the
        # write.
        fsm(copy_laddr(0), out_gaddr.add(self.out_addr_inc), out_count.inc())

        state_write_end_buffered = fsm.current
        fsm.inc()

        # Case without Copy
        state_write_unbuffered = fsm.current

        # Write copy_size words directly from input_rams[0] (the data of the
        # most recent read).
        laddr = 0
        gaddr = self.objaddr + out_gaddr
        bt.bus_lock(self.maxi, fsm)
        bt.dma_write(self.maxi, fsm, self.input_rams[0], laddr, gaddr,
                     copy_size)
        bt.bus_unlock(self.maxi, fsm)

        # Rewind the read buffer address, advance the output address by the
        # increment recorded in the read phase, count the write.
        fsm(arg_laddr(0), out_gaddr.add(out_addr_inc_unbuffered),
            out_count.inc())

        state_write_end_unbuffered = fsm.current
        fsm.inc()

        state_write_end = fsm.current

        # Select the write branch according to self.buffered.
        fsm.If(self.buffered).goto_from(state_write, state_write_buffered)
        fsm.If(vg.Not(self.buffered)).goto_from(state_write,
                                                state_write_unbuffered)

        # Both branches converge on the common end state.
        fsm.goto_from(state_write_end_buffered, state_write_end)
        fsm.goto_from(state_write_end_unbuffered, state_write_end)

        # --------------------
        # update for next iteration
        # --------------------
        # Loop back to the read phase until num_steps writes have been done.
        fsm.If(out_count < self.num_steps).goto(state_read)
        fsm.If(out_count == self.num_steps).goto_next()