Example #1
0
    def get_control_param_values(self):
        """Return the control parameters for a begins/strides window copy.

        Every dimension above the two innermost ones gets a byte offset for
        its begin position and for its stride.  The two innermost
        dimensions are covered instead by the read/write sizes and the
        ``stream_*`` parameters.  Shapes are the aligned shapes of the first
        argument ("act") and of this operator ("out").

        Returns:
            OrderedDict mapping control parameter names to their values.
        """
        def _prod(values):
            # Product of a (possibly empty) sequence; an empty one gives 1.
            return functools.reduce(lambda x, y: x * y, values, 1)

        src = self.args[0]
        src_shape = src.get_aligned_shape()
        src_ch = src_shape[-1]

        dst_shape = self.get_aligned_shape()
        dst_ch = dst_shape[-1]

        # Byte size of one innermost (last-axis) vector of the input.
        src_vec_bytes = bt.to_byte(src_ch * src.get_ram_width())

        # One (bytes per index, begin, stride) triple per outer dimension,
        # collected innermost-first: for the k-th entry the byte size is the
        # vector size times the extents of dimensions -(k + 2) .. -2.
        outer_dims = [
            (src_vec_bytes * _prod(src_shape[-k - 2:-1]), begin, stride)
            for k, (begin, stride) in enumerate(
                zip(reversed(self.begins[:-2]), reversed(self.strides[:-2])))
        ]

        # Emit outermost-first so the lists line up with self.begins[:-2].
        act_offset_begins = [
            unit * begin for unit, begin, _ in reversed(outer_dims)]
        act_offset_strides = [
            unit * stride for unit, _, stride in reversed(outer_dims)]

        src_block = src_ch // self.par
        src_rows = src_shape[-2] if len(src_shape) > 1 else 1
        act_read_size = src_block * src_rows

        dst_vec_bytes = bt.to_byte(dst_ch * self.get_ram_width())
        out_offset_strides = [
            dst_vec_bytes * _prod(dst_shape[-k - 2:-1])
            for k in reversed(range(len(dst_shape) - 2))
        ]

        dst_block = dst_ch // self.par
        dst_rows = dst_shape[-2] if len(dst_shape) > 1 else 1
        out_write_size = dst_block * dst_rows
        stream_size = dst_block

        if len(self.strides) <= 1:
            # At most one dimension: there is no second-innermost stride.
            stream_stride = 0
            stream_local = self.begins[-1]
        else:
            stream_stride = self.strides[-2] * src_block
            # NOTE(review): begins[-1] is added without dividing by par --
            # confirm this matches the unit stream_local is consumed in.
            stream_local = self.begins[-2] * src_block + self.begins[-1]

        return OrderedDict([
            ('act_shape', src_shape),
            ('out_shape', dst_shape),
            ('act_begins', self.begins),
            ('act_strides', self.strides),
            ('act_offset_begins', act_offset_begins),
            ('act_offset_strides', act_offset_strides),
            ('act_read_size', act_read_size),
            ('out_offset_strides', out_offset_strides),
            ('out_write_size', out_write_size),
            ('stream_size', stream_size),
            ('stream_stride', stream_stride),
            ('stream_local', stream_local),
        ])
Example #2
0
    def get_control_param_values(self):
        """Return the control parameters for combining ``self.args`` along
        ``self.axis`` (looks like a concatenation -- confirm).

        The operation is marked "buffered" when any input's dtype differs
        from this operator's dtype, or, when ``self.axis`` is the last axis,
        when any input's last dimension differs from its aligned size.

        Side effect: stores the flag on ``self.buffered_value`` (used by
        ``__str__``).

        Returns:
            OrderedDict mapping control parameter names to their values.
        """
        inputs = self.args

        dtype_mismatch = any(self.dtype != arg.dtype for arg in inputs)
        padded_tail = (self.axis == bt.get_rank(self.shape) - 1 and
                       any(arg.shape[-1] != arg.get_aligned_shape()[-1]
                           for arg in inputs))
        buffered = bool(dtype_mismatch or padded_tail)

        # for __str__
        self.buffered_value = buffered

        out_aligned_shape = self.get_aligned_shape()
        out_aligned_length = self.get_aligned_length()

        arg_read_sizes = []
        arg_addr_incs = []
        arg_chunk_sizes = []
        for arg in inputs:
            tail = arg.shape[-1]
            arg_read_sizes.append(tail)
            # Byte stride between consecutive last-axis vectors of this arg.
            padded_tail_len = bt.align_word(tail, arg.get_word_alignment())
            arg_addr_incs.append(
                bt.to_byte(padded_tail_len * arg.get_ram_width()))
            # Number of last-axis vectors spanned by one index of self.axis.
            arg_chunk_sizes.append(
                functools.reduce(lambda x, y: x * y,
                                 arg.shape[self.axis:-1], 1))

        out_write_size = out_aligned_shape[-1]
        out_padded_tail = bt.align_word(self.shape[-1],
                                        self.get_word_alignment())
        out_addr_inc = bt.to_byte(out_padded_tail * self.get_ram_width())

        num_steps = int(math.ceil(out_aligned_length / out_write_size))
        if not buffered:
            # Unbuffered: each input is streamed separately.
            num_steps *= len(inputs)

        return OrderedDict([
            ('buffered', buffered),
            ('arg_read_sizes', arg_read_sizes),
            ('arg_addr_incs', arg_addr_incs),
            ('arg_chunk_sizes', arg_chunk_sizes),
            ('out_write_size', out_write_size),
            ('out_addr_inc', out_addr_inc),
            ('num_steps', num_steps),
        ])
Example #3
0
    def control_sequence(self, fsm):
        """Emit FSM states that copy ``self.args[0]`` to this operator's
        output with its dimensions reordered by ``self.transpose_perm``.

        Each innermost row of the input (``arg_shape[-1]`` words) is
        DMA-read in one burst into ``self.input_rams[0]``. Its elements are
        then handed to ``bt.read_modify_write`` one pass at a time, at
        output addresses produced by a set of nested counters. What exactly
        ``bt.read_modify_write`` does per pass is not visible here; it
        presumably moves one element from the input RAM to the output
        address -- confirm.

        Args:
            fsm: state machine the states are appended to (presumably a
                veriloggen FSM -- confirm).
        """
        arg = self.args[0]
        ram = self.input_rams[0]

        shape = self.get_aligned_shape()
        arg_shape = arg.get_aligned_shape()

        # burst read, scatter write
        # For every output dimension i, its position in transpose_perm
        # (i.e. the inverse permutation), listed last dimension first.
        write_order = list(
            reversed([self.transpose_perm.index(i)
                      for i in range(len(shape))]))
        # NOTE(review): bt.shape_to_pattern presumably yields one
        # (size, stride) pair per level in write_order -- confirm; only the
        # stride is used below.
        write_pattern = bt.shape_to_pattern(shape, write_order)

        # Byte offset of the current input row.
        read_offset = self.m.TmpReg(self.maxi.addrwidth, initval=0)
        # One byte offset per pattern level; summed with self.objaddr they
        # form the output address of the current element.
        write_offsets = [
            self.m.TmpReg(self.maxi.addrwidth, initval=0)
            for _ in write_pattern
        ]
        write_all_offset = self.objaddr
        for write_offset in write_offsets:
            write_all_offset += write_offset

        # Nested element counters, innermost first (they are compared with
        # reversed(arg_shape) below).
        read_counts = [
            self.m.TmpReg(self.maxi.addrwidth, initval=0)
            for _ in write_pattern
        ]

        # initialize
        fsm(read_offset(0),
            [write_offset(0) for write_offset in write_offsets],
            [read_count(0) for read_count in read_counts])
        fsm.goto_next()

        # DMA read
        # State to return to when the next input row must be fetched.
        read_state = fsm.current

        # Read one full innermost row into the input RAM at address 0.
        laddr = 0
        gaddr = self.arg_objaddrs[0] + read_offset
        read_size = arg_shape[-1]

        bt.bus_lock(self.maxi, fsm)
        bt.dma_read(self.maxi, fsm, ram, laddr, gaddr, read_size)
        bt.bus_unlock(self.maxi, fsm)

        # read-modify-write
        modify_state = fsm.current

        # The innermost counter doubles as the element index in the row.
        laddr = read_counts[0]
        gaddr = write_all_offset

        bt.read_modify_write(self.m, fsm, self.maxi, ram, self.output_rams[0],
                             laddr, gaddr)

        # Odometer-style update: a counter advances only when every lower
        # counter is at its maximum (prev_done). When it reaches maxval - 1
        # it and its write offset are reset to 0; otherwise the write
        # offset grows by that level's stride in bytes.
        # NOTE(review): the inc/add and the reset target the same registers
        # in the same state; this presumably relies on the later assignment
        # taking priority -- confirm. out_size is unused.
        prev_done = 1
        for (read_count, maxval, write_offset,
             (out_size, out_stride)) in zip(read_counts, reversed(arg_shape),
                                            write_offsets, write_pattern):
            fsm.If(prev_done)(
                read_count.inc(),
                write_offset.add(
                    optimize(bt.to_byte(out_stride * self.get_ram_width()))))
            fsm.If(prev_done, read_count == maxval - 1)(read_count(0),
                                                        write_offset(0))
            prev_done = vg.Ands(prev_done, (read_count == maxval - 1))

        # After the last element of the buffered row, move the read pointer
        # to the next input row (one row's worth of bytes).
        fsm.If(laddr == read_size - 1)(read_offset.add(
            optimize(bt.to_byte(read_size * arg.get_ram_width()))))
        # Stay in the element loop until the row is exhausted, then fetch the
        # next row; leave once all counters have wrapped.
        # NOTE(review): on the final element both the goto(read_state) and
        # goto_next conditions hold; presumably the later one wins -- confirm.
        fsm.If(laddr < read_size - 1).goto(modify_state)
        fsm.If(laddr == read_size - 1).goto(read_state)
        fsm.If(prev_done).goto_next()
Example #4
0
    def get_control_param_values(self):
        """Return the control parameters for repeating each input pixel
        ``self.factors[1]`` times along rows and ``self.factors[2]`` times
        along columns (looks like nearest-neighbour upsampling -- confirm).

        Also returns, for each source from ``self.collect_sources()``, its
        address increment and a wrap mode and size. The mode is 0 when the
        source has the same shape as ``self.args[0]``, 2 when it is a
        one-element 1-D "stride-0" source, and 1 (repeat) otherwise.

        Returns:
            OrderedDict mapping control parameter names to their values.
        """
        src_shape = self.args[0].shape
        words = self.get_word_alignment()

        # Round the last dimension up to a multiple of the word alignment.
        tail = src_shape[-1]
        pad = (words - tail % words) % words
        aligned_shape = list(src_shape[:-1]) + [tail + pad]

        aligned_length = bt.shape_to_length(aligned_shape)

        total_size = int(math.ceil(aligned_length / self.par))
        dma_size = int(math.ceil(aligned_shape[-1] / self.par))
        num_comp = int(math.ceil(total_size / dma_size))

        # Byte size of one (aligned) last-axis vector.
        padded_tail = bt.align_word(src_shape[-1], words)
        base = bt.to_byte(padded_tail * self.get_ram_width())

        factor_row, factor_col = self.factors[1], self.factors[2]
        out_num_col = self.shape[-2]

        out_col_step = base
        out_row_step = base * (out_num_col - (factor_col - 1))
        max_out_pos_col = factor_col - 1
        max_out_pos_row = factor_row - 1

        out_col_inc = base * factor_col
        out_row_inc = out_col_inc + base * out_num_col * (factor_row - 1)
        max_out_col_count = src_shape[-2] - 1

        src_key = tuple(src_shape)
        arg_addr_incs = []
        wrap_modes = []
        wrap_sizes = []
        for source in self.collect_sources():
            shape = source.shape
            source_tail = bt.align_word(shape[-1],
                                        source.get_word_alignment())
            arg_addr_incs.append(
                bt.to_byte(source_tail * source.get_ram_width()))

            if tuple(shape) == src_key:
                # Same shape as the main input: no wrapping.
                mode, size = 0, 0
            else:
                # 2: stride-0 (single-element 1-D source), 1: repeat.
                mode = 2 if (len(shape) == 1 and shape[-1] == 1) else 1
                size = bt.get_wrap_size(src_shape, shape)
            wrap_modes.append(mode)
            wrap_sizes.append(size)

        return OrderedDict([
            ('dma_size', dma_size),
            ('num_comp', num_comp),
            ('out_col_step', out_col_step),
            ('out_row_step', out_row_step),
            ('max_out_pos_col', max_out_pos_col),
            ('max_out_pos_row', max_out_pos_row),
            ('out_col_inc', out_col_inc),
            ('out_row_inc', out_row_inc),
            ('max_out_col_count', max_out_col_count),
            ('arg_addr_incs', arg_addr_incs),
            ('wrap_modes', wrap_modes),
            ('wrap_sizes', wrap_sizes),
        ])
Example #5
0
    def get_control_param_values(self):
        """Return the control parameters of this sliding-window operator.

        A single activation (``self.args[0]``) is scanned with window
        ``self.ksize``, ``self.strides`` and ``self.padding``. The last four
        entries of shapes, window sizes and strides are treated as
        (batch, row, col, channel). No weight argument is read here, so
        this looks like a 2-D pooling -- confirm.

        ``self.padding`` may be ``'SAME'`` (split by ``util.pad_size_split``),
        an int (same amount on every side), a tuple/list ordered
        ``(top, bottom, left, right)``, or anything else (no padding).

        Side effect: stores the resolved padding on ``self.pad_*_value``
        for ``__str__``.

        Returns:
            OrderedDict mapping control parameter names to their values.
        """
        def _or_fold(values):
            # Fold with ``or``: the first truthy value, else the last value,
            # else False for an empty sequence.
            hit = False
            for value in values:
                hit = hit or value
            return hit

        act = self.args[0]

        # Only the row/col window extents are used below.
        _, ksize_row, ksize_col, _ = (self.ksize[-4], self.ksize[-3],
                                      self.ksize[-2], self.ksize[-1])

        act_shape = act.get_aligned_shape()
        act_num_bat, act_num_row, act_num_col, act_num_ch = (
            act_shape[-4], act_shape[-3], act_shape[-2], act_shape[-1])

        # The channel stride (self.strides[-1]) is documented as always 1
        # and is not read; the batch stride is documented as always 1 too.
        stride_bat, stride_row, stride_col = (
            self.strides[-4], self.strides[-3], self.strides[-2])

        out_shape = self.get_aligned_shape()
        _, out_num_row, out_num_col, out_num_ch = (
            out_shape[-4], out_shape[-3], out_shape[-2], out_shape[-1])

        padding = self.padding
        if isinstance(padding, str) and padding == 'SAME':
            pad_col, pad_col_left, pad_col_right = util.pad_size_split(
                act_num_col, ksize_col, stride_col)
            pad_row, pad_row_top, pad_row_bottom = util.pad_size_split(
                act_num_row, ksize_row, stride_row)
        elif isinstance(padding, int):
            pad_col_left = pad_col_right = padding
            pad_row_top = pad_row_bottom = padding
            pad_col = padding * 2
            pad_row = padding * 2
        elif isinstance(padding, (tuple, list)):
            pad_row_top, pad_row_bottom = padding[0], padding[1]
            pad_col_left, pad_col_right = padding[2], padding[3]
            pad_col = pad_col_left + pad_col_right
            pad_row = pad_row_top + pad_row_bottom
        else:
            pad_col = pad_col_left = pad_col_right = 0
            pad_row = pad_row_top = pad_row_bottom = 0

        # for __str__
        self.pad_col_left_value = pad_col_left
        self.pad_col_right_value = pad_col_right
        self.pad_row_top_value = pad_row_top
        self.pad_row_bottom_value = pad_row_bottom

        # Upper bounds of the col/row/batch counters, clamped at zero.
        max_col_count = max(
            act_num_col + pad_col + 1 - ksize_col - stride_col, 0)
        max_row_count = max(
            act_num_row + pad_row + 1 - ksize_row - stride_row, 0)
        max_bat_count = max(act_num_bat - stride_bat, 0)

        # Window row slots matched by (i % ksize_row) for i < stride_row.
        dma_flag_conds = [
            _or_fold(row_select == (i % ksize_row)
                     for i in range(stride_row))
            for row_select in range(ksize_row)
        ]

        aligned_act_num_ch = bt.align_word(act_num_ch,
                                           act.get_word_alignment())
        # Byte size of one (aligned) channel vector of the activation.
        act_step = bt.to_byte(aligned_act_num_ch * act.get_ram_width())

        act_offset_values = [
            act_num_col * (y - pad_row_top) * act_step
            for y in range(ksize_row)
        ]
        act_row_step = act_step * act_num_col * stride_row
        act_bat_step = act_step * act_num_col * act_num_row

        act_read_block = int(math.ceil(aligned_act_num_ch / self.par))
        act_read_size = act_read_block * act_num_col

        out_step = bt.to_byte(
            bt.align_word(out_num_ch, self.get_word_alignment()) *
            self.get_ram_width())
        out_row_step = out_step * out_num_col
        out_bat_step = out_step * out_num_col * out_num_row

        out_block = int(math.ceil(out_num_ch / self.par))
        out_write_size = out_block * out_num_col
        inc_out_laddr = out_block

        stream_size = int(math.ceil(act_num_ch / self.par))

        col_select_initval = (
            0 if pad_col_left == 0
            else (ksize_col - pad_col_left) % ksize_col)

        stride_col_mod_ksize = stride_col % ksize_col
        ksize_col_minus_stride_col_mod = ksize_col - stride_col_mod_ksize

        # One flag per (x, col_select) pair; the value does not depend on
        # the window row, so the same block is repeated ksize_row times.
        per_row_conds = [
            _or_fold(col_select == ((x + ksize_col - i) % ksize_col)
                     for i in range(stride_col_mod_ksize))
            for x in range(ksize_col)
            for col_select in range(ksize_col)
        ]
        inc_act_laddr_conds = [
            cond for _ in range(ksize_row) for cond in per_row_conds]

        inc_act_laddr_small = (int(math.floor(stride_col / ksize_col)) *
                               act_read_block)
        inc_act_laddr_large = (int(math.ceil(stride_col / ksize_col)) *
                               act_read_block)

        stream_act_local_small_offset = (
            -1 * int(math.floor(pad_col_left / ksize_col)) * act_read_block)
        stream_act_local_large_offset = (
            -1 * int(math.ceil(pad_col_left / ksize_col)) * act_read_block)

        stream_act_local_small_flags = []
        stream_act_local_large_flags = []
        for x in range(ksize_col):
            remaining = ksize_col - x
            small = remaining <= pad_col_left
            large = remaining <= (pad_col_left % ksize_col)
            stream_act_local_small_flags.append(small)
            stream_act_local_large_flags.append(small and large)

        return OrderedDict([
            ('act_num_col', act_num_col),
            ('act_num_row', act_num_row),
            ('stride_col', stride_col),
            ('stride_row', stride_row),
            ('out_num_col', out_num_col),
            ('out_num_row', out_num_row),
            ('pad_col_left', pad_col_left),
            ('pad_row_top', pad_row_top),
            ('max_col_count', max_col_count),
            ('max_row_count', max_row_count),
            ('max_bat_count', max_bat_count),
            ('dma_flag_conds', dma_flag_conds),
            ('act_offset_values', act_offset_values),
            ('act_row_step', act_row_step),
            ('act_bat_step', act_bat_step),
            ('act_read_size', act_read_size),
            ('act_read_block', act_read_block),
            ('out_row_step', out_row_step),
            ('out_bat_step', out_bat_step),
            ('out_write_size', out_write_size),
            ('stream_size', stream_size),
            ('col_select_initval', col_select_initval),
            ('stride_col_mod_ksize', stride_col_mod_ksize),
            ('ksize_col_minus_stride_col_mod', ksize_col_minus_stride_col_mod),
            ('inc_act_laddr_conds', inc_act_laddr_conds),
            ('inc_act_laddr_small', inc_act_laddr_small),
            ('inc_act_laddr_large', inc_act_laddr_large),
            ('inc_out_laddr', inc_out_laddr),
            ('stream_act_local_small_offset', stream_act_local_small_offset),
            ('stream_act_local_large_offset', stream_act_local_large_offset),
            ('stream_act_local_small_flags', stream_act_local_small_flags),
            ('stream_act_local_large_flags', stream_act_local_large_flags),
        ])