Beispiel #1
0
    def evens(self, n):
        '''evens returns a list of the length of consecutive cascading even
        divisions. For input n = 7 we have:

            7 -> 22     []
            22 -> 11    [1]
            11 -> 34
            34 -> 17    [1,1]
            17 -> 52
            52 -> 26
            26 -> 13    [1,1,2]
            13 -> 40
            40 -> 20
            20 -> 10
            10 -> 5     [1,1,2,3]
            5  -> 16
            16 -> 8
            8  -> 4
            4  -> 2
            2  -> 1     [1,1,2,3,4]

        and we return [1,1,2,3,4]

        Each entry is the number of halvings applied in one run, i.e. the number of
        trailing zero bits of the value reached after an odd step. An odd `n` is first
        mapped through `self(n)`; later odd steps use `self.mult * n + self.add`
        (the table above presumably corresponds to mult == 3 and add == 1).
        An even `n <= 1` yields an empty list.

        Only meaningful when `self.parity == 2`; otherwise an AssertionError is raised.
        '''
        assert self.parity == 2  # This ONLY makes sense with parity of 2
        result = []
        if n % 2 == 1:
            n = self(n)
        while n > 1:
            # ffs() serves as the trailing-zero count: shifting right by it strips
            # every factor of two in one go.
            trailing_zeros = ffs(n)
            result.append(trailing_zeros)
            n >>= trailing_zeros
            if n > 1:
                n = self.mult * n + self.add
        return result
Beispiel #2
0
def print_ffs_after_increase(n):
    """Print the trajectory of `n` under `C`, one numbered step per line.

    Step 1 is `n` itself. Each later line shows the step number and the next
    value. When that value is larger than its predecessor (presumably the
    3n + 1 step of `C`), `ffs()` of the new value is appended to the line.
    Iteration stops as soon as the value is no longer greater than 1.
    """
    assert n > 0
    print(f"1: {n}")
    step = 1  # Number of the most recently printed line
    current = n
    while current > 1:
        step += 1
        successor = C(current)
        if successor > current:
            # Increasing step: also report ffs() of the new value
            print(f"{step}: {successor} {ffs(successor)}")
        else:
            print(f"{step}: {successor}")
        current = successor
Beispiel #3
0
def unload(apb_base,
           processor_map,
           input_shape,
           out_array,
           out_offset,
           in_array,
           flatten=False):
    """
    Unload HWC memory from AI84 and return it in `out_array`.

    The generated C code (printed to stdout) is specific to the network configuration passed
    in `processor_map` and `input_shape`. Additionally, the generated addresses are offset by
    `apb_base` and `out_offset`. The C code function takes a pointer to a memory array, and the
    dimensions of the array do not matter (flattened or not flattened).
    The additional simulation code takes the `flatten` parameter and an `in_array`.
    If `flatten` is `True`, then the out_array is flattened.

    `input_shape` is indexed as (channels, rows, columns). Both the simulated `out_array` and
    the C `out_buf` are filled in channel-major (CHW) order: channel `c`, row `row`, column
    `col` lands at flat index `c * rows * columns + row * columns + col`.

    Raises RuntimeError when an invalid or uninitialized memory location is read, or when
    `processor_map` runs out of enabled processors before all channels are unloaded.
    """
    def get_val(offs):
        """
        Returns value stored at offset `offs` in the memory array.
        """
        if offs >= (MEM_SIZE << 2) or offs < 0:
            raise RuntimeError(
                f'Offset {offs:04x} is invalid for the memory array.')
        if offs & 3:
            raise RuntimeError(
                f'Offset {offs:04x} should be a 32-bit address.')
        if in_array[offs >> 2] == MEM_INVALID:
            raise RuntimeError(
                f'Trying to read from uninitialized memory at location {offs:04x}.'
            )
        return in_array[offs >> 2]

    print('\n// Custom unload for this network:\n'
          f'// Input shape: {input_shape}\n'
          'void unload(uint8_t *out_buf)\n'
          '{\n  uint32_t val, *addr, offs;\n')

    chan_size = input_shape[1] * input_shape[2]  # Elements per channel (rows * columns)
    coffs = ffs(processor_map) & ~(tc.dev.P_SHARED - 1)
    next_layer_map = processor_map >> coffs
    read_addr = None
    write_addr = None
    c = 0
    while c < input_shape[0]:
        if next_layer_map == 0:
            # No enabled processors left, so `c` could never advance (infinite loop).
            raise RuntimeError(
                f'Processor map 0x{processor_map:x} only covers {c} of '
                f'{input_shape[0]} channels.')
        for doffs in range(chan_size):
            row, col = divmod(doffs, input_shape[2])
            this_map = next_layer_map
            this_c = c

            # Get four bytes from memory array
            proc = (coffs % tc.dev.MAX_PROC) & ~(tc.dev.P_SHARED - 1)
            # FIXME: seq = ...
            offs = out_offset + \
                (((proc % tc.dev.P_NUMPRO) * tc.dev.INSTANCE_SIZE |
                  (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) +
                 doffs) * 4

            val = get_val(offs)

            if offs != read_addr:
                print(
                    f'  addr = (uint32_t *) 0x{apb_base + tc.dev.C_SRAM_BASE + offs:08x};'
                )
            print('  val = *addr++;')
            read_addr = offs + 4

            # Singulate bytes, ignoring unused processors
            for shift in range(4):
                # Flat CHW index: channel stride is rows * columns, row stride is columns
                addr = this_c * chan_size + row * input_shape[2] + col
                if shift == 0:
                    base_addr = addr  # Value held by the C variable `offs`
                    if addr != write_addr:
                        print(f'  offs = 0x{addr:04x};')
                    else:
                        print('  offs++;')
                    write_addr = addr + 1
                if this_map & 1:
                    if not flatten:
                        out_array[this_c][row][col] = val & 0xff
                    else:
                        out_array[addr] = val & 0xff
                    print('  out_buf[offs', end='')
                    if addr != base_addr:
                        # The C offset must match the simulated `addr`, i.e. a multiple of
                        # the channel size, skipping unused processors.
                        print(f'+0x{addr - base_addr:02x}', end='')
                    print('] = ', end='')
                    if shift == 0:
                        print('val', end='')
                    else:
                        print(f'(val >> {shift * 8})', end='')
                    print(' & 0xff;')
                    this_c += 1
                this_map >>= 1
                val >>= 8

        coffs += 4
        c += popcount(next_layer_map & 0x0f)
        next_layer_map >>= 4

    print('}')
Beispiel #4
0
def load(  # pylint: disable=too-many-branches,too-many-statements
    verbose,
    embedded_code,
    device,
    apb,
    start_layer,
    layers,
    operator,
    kernel,
    kernel_size,
    quantization,
    processor_map,
    output_processor_map,
    input_chan,
    output_chan,
    out_expand,
    out_expand_thresh,
    in_expand,
    in_expand_thresh,
    flatten=False,
    mexpress=False,
    verify=False,
    riscv_flash=False,
    quad=False,
    debug=False,
    blocklevel=False,
    legacy_kernels=False,
    calcx4=False,
):
    """
    Stack `kernel` values and write them to C code (for `embedded_code` if `True` or
    RTL simulation). The output is written to the `apb` object.
    Input is configured with `kernel_size`, `quantization`, `layers`, `processor_map`,
    `output_processor_map`, `input_chan`, `output_chan`, `out_expand` and `out_expand_thresh`.
    When `mexpress` is `True`, the function uses the memcpy()-friendly hardware functionality to
    reduce the number of transfers. When `verify` is also true (mexpress mode only), kernels are
    read back and compared.
    This function returns the kernel offsets and the kernel lengths for all layers.

    Overview:
      1. Layer loop. For each layer `start_layer`..`layers - 1` whose `operator` is CONV1D,
         CONV2D or CONVTRANSPOSE2D:
           - compute the kernel memory column offset (`kern_offs[ll]`) and length
             (`kern_len[ll]`);
           - pack the weights of every enabled source processor into 9-byte kernel slots
             (`kernel_data`);
           - record the owning layer in `kernel_map`.
         Layers with any other operator get offset 0 and length 0.
      2. Output. Emit the kernels in-line via `apb.write_kern()` (RTL simulation, or read-back
         code when `verify` is set). When `embedded_code` or `mexpress` is set, also emit C
         defines, static arrays and a `load_kernels()` function via `apb.output*()`.

    The per-layer arguments are indexed by layer number: `operator`, `kernel`, `kernel_size`,
    `quantization`, `processor_map`, `output_processor_map`, `input_chan`, `output_chan`,
    `out_expand`, `out_expand_thresh`, `in_expand`, `in_expand_thresh` and `flatten`.
    NOTE(review): `flatten` defaults to `False` but is always indexed as `flatten[ll]`, so
    callers must pass a per-layer sequence.

    Returns:
        (kern_offs, kern_len): two lists of length `layers` holding each layer's first kernel
        column and its number of kernel columns.

    Exits the process with sys.exit(1) in two cases, in both of which it first prints the
    kernel map so far:
      - `calcx4` is requested on a device without support;
      - a layer does not fit into kernel memory.
    """
    # Kernels: Stack kernels; write only the kernels needed
    # proc_kern_max[p]: first kernel column after those already assigned to processor `p`
    proc_kern_max = [0] * tc.dev.MAX_PROC
    kern_offs = [0] * layers
    kern_len = [0] * layers
    # kernel_map[p][col]: layer owning kernel column `col` of processor `p`, or _INVALID_VALUE
    kernel_map = np.full((tc.dev.MAX_PROC, tc.dev.MASK_WIDTH_LARGE),
                         _INVALID_VALUE,
                         dtype=np.int64)
    # kernels_used[p][col]: number of bytes (0..9) already stored in that kernel slot
    kernels_used = np.zeros((tc.dev.MAX_PROC, tc.dev.MASK_WIDTH_LARGE),
                            dtype=np.int64)
    # kernel_data[p][col]: the 9 bytes of that kernel slot, filled from index 8 down to 0
    kernel_data = np.zeros((tc.dev.MAX_PROC, tc.dev.MASK_WIDTH_LARGE, 9),
                           dtype=np.int8)
    # There are four 32-bit words per 9-byte kernel.
    # The value map is initialized with zeros so we can later ignore unused entries and use
    # memcpy() on initialized and uninitialized data.
    # NOTE(review): only three words per kernel are filled below, and memcpy_96to128() reads
    # three source words per kernel -- presumably _WORDS_PER_KERNEL is 3; confirm.
    kernel_values = np.zeros(
        (tc.dev.MAX_PROC, tc.dev.MASK_WIDTH_LARGE * _WORDS_PER_KERNEL),
        dtype=np.int64)
    if debug:
        print('\nLoading Kernels...')

    if calcx4 and not tc.dev.SUPPORT_CALCX4:
        eprint('--calcx4 is not supported on this device.')
        sys.exit(1)
    assert not (
        (embedded_code or mexpress) and calcx4)  # FIXME Add support later

    # Step 1: place and pack the kernels of every layer
    for ll in range(start_layer, layers):
        # Only convolution layers carry kernels
        if operator[ll] not in [op.CONV1D, op.CONV2D, op.CONVTRANSPOSE2D]:
            kern_len[ll] = 0
            kern_offs[ll] = 0
            continue

        # Flattened layers are regrouped into output_chan * input_chan rows of
        # (-1, kernel_size[0], kernel_size[1]) kernels
        if flatten[ll]:
            kernel_reshaped = kernel[ll].reshape(
                output_chan[ll] * input_chan[ll],
                -1,
                kernel_size[ll][0],
                kernel_size[ll][1],
            )
        else:
            kernel_reshaped = kernel[ll]

        first_proc = ffs(processor_map[ll])
        last_proc = fls(processor_map[ll])
        ch = 0  # Source (input) channel handled by the current enabled processor
        m = 0  # Output channel counter within the current source processor
        for p in range(first_proc, last_proc + 1):
            if (processor_map[ll] >> p) & 1 == 0:
                # Unused processor
                continue
            # Get highest offset for all used processors
            kern_offs[ll] = max(proc_kern_max[p], kern_offs[ll])

        ksize = kernel_size[ll][0] * kernel_size[ll][1]
        # Number of `quantization[ll]`-bit weights that fit into one byte
        qfactor = 8 // quantization[ll]
        # Determine the number of kernels that need to be programmed. Since each instance
        # spans 4 processors, kernels for all instances that have a single processor enabled
        # need to be written, i.e. round down the first. The last does not need to be rounded
        # up because hardware takes care of it.
        next_layer_map = output_processor_map[ll]
        # When using kernels smaller than 8 bit, round up to the next 8-bit boundary
        # Gaps are accounted for like any other kernel.
        kern_len[ll] = 1 + quantization[ll] * \
            (fls(next_layer_map) - ffs(next_layer_map)) // 8
        # This extends the kernels to the right on AI85 for input and output expansion
        if output_chan[ll] > tc.dev.MAX_PROC:
            kern_len[ll] = (kern_len[ll] + tc.dev.P_SHARED -
                            1) & ~(tc.dev.P_SHARED - 1)
        kern_len[ll] *= out_expand[ll] * in_expand[ll]
        # Flattened layers: scale by the rows per kernel group, then drop the columns of
        # (presumably) unused output processor slots
        if not legacy_kernels and flatten[ll]:
            kern_len[ll] *= kernel_reshaped.shape[1]
            kern_len[ll] -= (out_expand[ll] * popcount(next_layer_map) - output_chan[ll]) \
                * kernel_reshaped.shape[1] * 8 // (ksize * quantization[ll])
        if device != 84:
            # Pack kernels when using 1D convolutions, or 1x1 kernels
            kern_len[ll] = (kern_len[ll] * ksize + 8) // 9
        if ll == 0 and quad:
            # Quad mode: the first layer uses a quarter of the columns (rounded up)
            kern_len[0] = (kern_len[0] + 3) // 4

        # We don't have to use dummy columns if there's space available on the left
        kern_offs[ll] = \
            max(0, kern_offs[ll] - (((ffs(next_layer_map) % tc.dev.P_SHARED)
                                     + qfactor - 1) // qfactor))
        # The kernel offset needs to start at a multiple of 4.
        kern_offs[ll] = (kern_offs[ll] + tc.dev.P_SHARED -
                         1) & ~(tc.dev.P_SHARED - 1)
        # NOTE(review): `p` is left over from the loop above (normally `last_proc`), so only
        # that processor's mask width is checked here.
        if kern_offs[ll] + kern_len[ll] > tc.dev.mask_width(p):
            eprint(
                f'\nKernel memory exceeded at layer {ll}; offset: {kern_offs[ll]}, '
                f'needed: {kern_len[ll]}.'
                '\n\nKernel map so far:')
            print_map(layers, kernel_map, print_fn=eprint_noprefix)
            sys.exit(1)

        # Selects the `qfactor` output-map bits consumed per kernel column
        proc_mask = 2**qfactor - 1
        # Start at the first used instance
        this_map_init = next_layer_map >> ffs(next_layer_map)
        start_col = ffs(
            next_layer_map) % tc.dev.P_SHARED  # First target column

        # Pack the kernels of each enabled source processor
        for p in range(first_proc, last_proc + 1):
            if (processor_map[ll] >> p) & 1 == 0:
                # Unused source processor
                continue
            col_target = start_col
            for expand in range(out_expand[ll]):
                this_map = this_map_init
                if ll == 0 and quad:
                    col = expand * (out_expand_thresh[ll] + 3) // 4
                    stop_col = col + (out_expand_thresh[ll] + 3) // 4
                else:
                    col = expand * out_expand_thresh[ll]
                    stop_col = col + out_expand_thresh[ll]
                while col < stop_col:
                    # Skip over unused bits in the target processor map
                    # (unused means 1 bit for 8-bit weights, 2 for 4-bit weights, etc.)
                    if this_map != 0:
                        while this_map & proc_mask == 0:
                            assert this_map != 0
                            col_target += 1  # Completely skip
                            this_map >>= qfactor  # and slide forward
                    this_mask = this_map & proc_mask
                    this_map >>= qfactor

                    # Row of `kernel_reshaped` to read for source channel `ch` and
                    # output channel `m`
                    if ll == 0 and quad:
                        src_offs = ch + (m - p // 16) * input_chan[ll]
                    else:
                        src_offs = ch + m * input_chan[ll]
                    # Quad mode, layer 0: output channel `m` is only written by processors
                    # in the group of 16 whose index (p // 16) equals m % 4
                    if ll > 0 or not quad or (m % 4 == p // 16):
                        for ie in range(in_expand[ll]):
                            # NOTE(review): `mask` is updated below but never read
                            mask = this_mask

                            # Store byte `b` in the kernel slot at column
                            # kern_offs[ll] + col_target of processor `p`. Bytes fill the
                            # 9-byte slot from index 8 down; once the slot holds 9 bytes the
                            # target column advances. Exits if the column is outside kernel
                            # memory. Returns the (possibly advanced) `col_target`.
                            def add_kernel_data(ll, p, col_target, b):
                                col = kern_offs[ll] + col_target
                                if col >= tc.dev.mask_width(p):
                                    eprint(
                                        f'\nKernel memory exceeded in layer {ll}.'
                                        '\n\nKernel map so far:')
                                    print_map(layers,
                                              kernel_map,
                                              print_fn=eprint_noprefix)
                                    sys.exit(1)

                                if kernels_used[p][
                                        col] == 0:  # Update kernel map
                                    assert kernel_map[p][col] == _INVALID_VALUE
                                    kernel_map[p][col] = ll

                                assert kernels_used[p][col] <= 8
                                kernel_data[p][col][
                                    8 - kernels_used[p][col]] = b & 0xff
                                kernels_used[p][col] += 1

                                if kernels_used[p][col] == 9:  # Flush
                                    col_target += 1  # Write 1

                                return col_target

                            n = 0  # Number of kernels packed into `k` (used for debug output)
                            if src_offs < len(kernel_reshaped):
                                if not flatten[ll]:
                                    # Combine up to `qfactor` kernels into one byte per
                                    # element; kernel `i` occupies bits starting at
                                    # i * quantization[ll]
                                    k = np.zeros_like(
                                        kernel_reshaped[src_offs].flatten())
                                    for i in range(qfactor):
                                        if m < output_chan[ll]:
                                            # Cycle through phases
                                            idx = n + ie * qfactor
                                            koffs = src_offs + (idx % in_expand[ll]) \
                                                * in_expand_thresh[ll] \
                                                + (idx // in_expand[ll]) \
                                                * input_chan[ll]
                                            if koffs < len(kernel_reshaped):
                                                this_kern = kernel_reshaped[koffs].flatten() \
                                                    & (2**quantization[ll]-1)
                                                k |= this_kern << (
                                                    i * quantization[ll])
                                            n += 1
                                        mask >>= 1
                                else:
                                    # Flattened: pack `qfactor` consecutive weights of one
                                    # kernel row into each byte (zero-padded to a multiple
                                    # of `qfactor`)
                                    kl = (len(kernel_reshaped[src_offs]) +
                                          qfactor - 1) // qfactor
                                    k = np.zeros(kl, dtype=np.int64)
                                    if m < output_chan[ll]:
                                        # Cycle through phases
                                        idx = n + ie * qfactor
                                        koffs = src_offs + (idx % in_expand[ll]) \
                                            * in_expand_thresh[ll] \
                                            + (idx // in_expand[ll]) \
                                            * input_chan[ll]
                                        if koffs < len(kernel_reshaped):
                                            this_kern = kernel_reshaped[
                                                koffs].flatten()
                                            if len(this_kern) % qfactor != 0:
                                                this_kern = np.append(
                                                    this_kern,
                                                    np.zeros(qfactor -
                                                             len(this_kern) %
                                                             qfactor,
                                                             dtype=np.int64))
                                            for i in range(qfactor):
                                                k |= ((this_kern[i::qfactor]
                                                       & (2**quantization[ll]-1))) \
                                                    << (i * quantization[ll])
                                        n += 1
                                        mask >>= 1
                                if debug:
                                    with np.printoptions(
                                            formatter={
                                                'int': '{0:02x}'.format
                                            }):
                                        print(
                                            f'Layer {ll} processor {p} channel '
                                            f'{ch + ie * in_expand_thresh[ll]} m[{m}..{m+n-1}] '
                                            f'of {output_chan[ll]}: {k}')

                                if flatten[ll]:
                                    for _, e in enumerate(k):
                                        col_target = add_kernel_data(
                                            ll, p, col_target, e)
                                else:
                                    # Non-flattened kernels are stored last element first
                                    for i in range(ksize):
                                        col_target = add_kernel_data(
                                            ll, p, col_target,
                                            k[ksize - i - 1])

                            else:  # When expanding, need to pad with zero kernels if needed
                                for _ in range(ksize // qfactor):
                                    col_target = add_kernel_data(
                                        ll, p, col_target, 0)

                        # Consume kernels
                        if not flatten[ll]:
                            col += qfactor
                            m += qfactor
                        else:
                            col += 1
                            m += 1
                    else:
                        m += qfactor

            # Move past a partially filled kernel slot
            if kern_offs[ll] + col_target < tc.dev.mask_width(p) \
               and kernels_used[p][kern_offs[ll] + col_target] > 0:  # Partials
                col_target += 1
            # Pad with zero kernels up to the layer's kernel length.
            # NOTE(review): `add_kernel_data` is only defined inside the loops above; if they
            # never executed during this call, this raises NameError.
            while col_target - start_col < kern_len[ll]:
                col_target = add_kernel_data(ll, p, col_target, 0)
            if flatten[ll]:
                # NOTE(review): unlike the assert below, this length includes `start_col`
                kern_len[ll] = col_target
            else:
                assert kern_len[ll] == col_target - start_col
            proc_kern_max[p] = kern_offs[ll] + kern_len[ll]
            ch += 1
            m = 0

    if verbose:
        print('\nKernel map:')
        print_map(layers, kernel_map)

    # In-line kernel writes (RTL simulation), or read-back code when verifying
    if verify or not (embedded_code or mexpress):
        if verify:
            apb.output('int verify_kernels(void)\n{\n')
        # Write in-line
        for p in range(tc.dev.MAX_PROC):
            for col in range(0, tc.dev.mask_width(p)):
                ll = kernel_map[p][col]
                if ll != _INVALID_VALUE:
                    k = kernel_data[p][col]
                    apb.write_kern(ll,
                                   p,
                                   col,
                                   k,
                                   verify_only=verify,
                                   calcx4=calcx4)
        if verify:
            apb.output('  return 1;\n}\n\n')
    if embedded_code or mexpress:
        # Write kernels, combining layers and processors where possible to reduce the number
        # of constants and calls to memcpy.
        apb.output('// Kernels:\n')

        if not mexpress:
            # Pack each 9-byte kernel into words: word 0 holds byte 0, words 1 and 2 hold
            # bytes 1-4 and 5-8 (first byte in the most significant position)
            for p in range(tc.dev.MAX_PROC):
                for col in range(0, tc.dev.mask_width(p)):
                    ll = kernel_map[p][col]
                    if ll != _INVALID_VALUE:
                        k = kernel_data[p][col]
                        offs = _WORDS_PER_KERNEL * col
                        kernel_values[p][offs] = k[0] & 0xff
                        kernel_values[p][offs + 1] = (k[1] & 0xff) << 24 \
                            | (k[2] & 0xff) << 16 | (k[3] & 0xff) << 8 | k[4] & 0xff
                        kernel_values[p][offs + 2] = (k[5] & 0xff) << 24 \
                            | (k[6] & 0xff) << 16 | (k[7] & 0xff) << 8 | k[8] & 0xff

            # First, define the weights (will move to header file)
            # Combining memcopy() requires stacked memories
            # min_col/max_col: used kernel column range per processor (max_col -1 = unused;
            # min_col is always 0 with `legacy_kernels`)
            max_col = [-1] * tc.dev.MAX_PROC
            min_col = [tc.dev.MASK_WIDTH_LARGE if not legacy_kernels else 0
                       ] * tc.dev.MAX_PROC
            for p in range(0, tc.dev.MAX_PROC):
                for col in range(0, tc.dev.mask_width(p)):
                    ll = kernel_map[p][col]
                    if ll != _INVALID_VALUE:
                        max_col[p] = col
                        min_col[p] = min(min_col[p], col)
            p = 0
            while p < tc.dev.MAX_PROC:
                if max_col[p] >= 0:
                    start = p
                    # Merge the next processor into this define while three conditions hold:
                    #   - the current processor ends at column tc.dev.MASK_OFFS;
                    #   - the next processor starts at column 0;
                    #   - both are in the same group of tc.dev.P_NUMPRO processors.
                    while (max_col[p] == tc.dev.MASK_OFFS
                           and p + 1 < tc.dev.MAX_PROC and max_col[p + 1] >= 0
                           and min_col[p + 1] == 0
                           and (start & ~(tc.dev.P_NUMPRO - 1))
                           == (p + 1 & ~(tc.dev.P_NUMPRO - 1))):
                        p += 1
                    # Combine multiple channels into one define
                    k = None
                    for i in range(start, p + 1):
                        if k is None:
                            k = kernel_values[i][min_col[i] *
                                                 _WORDS_PER_KERNEL:
                                                 (max_col[i] + 1) *
                                                 _WORDS_PER_KERNEL]
                        else:
                            k = np.concatenate(
                                (k, kernel_values[i]
                                 [min_col[i] *
                                  _WORDS_PER_KERNEL:(max_col[i] + 1) *
                                  _WORDS_PER_KERNEL]))

                    apb.output_define(k, f'KERNELS_{start}', '0x%08x', 8)
                p += 1

            # Second, initialize static const variables as source for memcpy
            # (same processor merging as above; NOTE(review): `span` is not used here)
            p = 0
            while p < tc.dev.MAX_PROC:
                if max_col[p] >= 0:
                    span = max_col[p] + 1 - min_col[p]
                    start = p
                    while (max_col[p] == tc.dev.MASK_OFFS
                           and p + 1 < tc.dev.MAX_PROC and max_col[p + 1] >= 0
                           and min_col[p + 1] == 0
                           and (start & ~(tc.dev.P_NUMPRO - 1))
                           == (p + 1 & ~(tc.dev.P_NUMPRO - 1))):
                        p += 1
                        span += max_col[p] + 1 - min_col[p]
                    if riscv_flash:
                        apb.output(rv.RISCV_FLASH)
                    apb.output(
                        f'static const uint32_t kernels_{start}[] = KERNELS_{start};\n'
                    )
                p += 1
            apb.output('\n')

            # Generate code to load the weights using memcpy
            # (copies three source words per kernel and writes a fourth, zero word)
            apb.output(
                'void memcpy_96to128(uint32_t *dst, const uint32_t *src, int n)\n{\n'
            )
            apb.output('  while (n-- > 0) {\n'
                       '    *dst++ = *src++;\n'
                       '    *dst++ = *src++;\n'
                       '    *dst++ = *src++;\n'
                       '    *dst++ = 0;  // Execute write\n'
                       '  }\n}\n\n')
        else:
            # When using the express loader, gather all consecutive kernels for each processor
            # and pack them.
            zero_kernel = np.array([0] * 9, dtype=np.uint8)
            k = None

            for p in range(tc.dev.MAX_PROC):
                # Find min/max from kernel_map
                max_col = -1
                min_col = tc.dev.mask_width(p) if not legacy_kernels else 0
                for col in range(0, tc.dev.mask_width(p)):
                    ll = kernel_map[p][col]
                    if ll != _INVALID_VALUE:
                        max_col = col
                        min_col = min(min_col, col)
                if max_col >= 0:
                    # Concatenate the 9 bytes of every column in range; unused columns
                    # contribute a zero kernel
                    for col in range(min_col, max_col + 1):
                        ll = kernel_map[p][col]
                        if ll != _INVALID_VALUE:
                            new_k = (kernel_data[p][col] & 0xff).astype(
                                np.uint8)
                        else:
                            new_k = zero_kernel
                        if k is None:
                            k = new_k
                        else:
                            k = np.concatenate((k, new_k))

                    # Round up to multiple of 4
                    if len(k) % 4 != 0:
                        k = np.concatenate((k, zero_kernel[:4 - len(k) % 4]))
                    # '>u4' swaps endianness to what the hardware needs, `view` packs into 32-bit
                    if not blocklevel:
                        apb.output_define(k.view(dtype='>u4'), f'KERNELS_{p}',
                                          '0x%08x', 8)
                    else:
                        # Block-level: write 0x01 to (address of min_col) | 0x01, then
                        # stream the packed words to the base address via apb.write()
                        addr = tc.dev.C_GROUP_OFFS * (p // tc.dev.P_NUMPRO) \
                            + tc.dev.C_MRAM_BASE + (p % tc.dev.P_NUMPRO) * tc.dev.MASK_OFFS * 16
                        apb.write(addr + min_col * 4 | 0x01, 0x01)
                        kb = k.view(dtype=">u4")
                        for _, e in enumerate(kb):
                            apb.write(addr, e)
                            addr += 4

                    if riscv_flash:
                        apb.output(rv.RISCV_FLASH)
                    apb.output(
                        f'static const uint32_t kernels_{p}[] = KERNELS_{p};\n'
                    )
                    k = None
            apb.output('\n')

        if not blocklevel:
            # Emit load_kernels(): one copy per (possibly merged) processor range
            apb.output('void load_kernels(void)\n{\n')
            max_col = [-1] * tc.dev.MAX_PROC
            min_col = [tc.dev.MASK_WIDTH_LARGE if not legacy_kernels else 0
                       ] * tc.dev.MAX_PROC
            for p in range(0, tc.dev.MAX_PROC):
                for col in range(0, tc.dev.mask_width(p)):
                    ll = kernel_map[p][col]
                    if ll != _INVALID_VALUE:
                        max_col[p] = col
                        min_col[p] = min(min_col[p], col)
            p = 0
            while p < tc.dev.MAX_PROC:
                if max_col[p] >= 0:
                    # span: number of kernel columns covered by this copy
                    span = max_col[p] + 1 - min_col[p]
                    start = p
                    addr = apb.apb_base + tc.dev.C_GROUP_OFFS * (p // tc.dev.P_NUMPRO) \
                        + tc.dev.C_MRAM_BASE + (p % tc.dev.P_NUMPRO) * tc.dev.MASK_OFFS * 16
                    while (max_col[p] == tc.dev.MASK_OFFS
                           and p + 1 < tc.dev.MAX_PROC and max_col[p + 1] >= 0
                           and min_col[p + 1] == 0
                           and (start & ~(tc.dev.P_NUMPRO - 1))
                           == (p + 1 & ~(tc.dev.P_NUMPRO - 1))):
                        p += 1
                        span += max_col[p] + 1 - min_col[p]
                    assert addr % 16 == 0
                    # Non-express: 16 bytes per kernel column in the target address.
                    # Express: set the start address, then copy ceil(span * 9 / 4) words.
                    if not mexpress:
                        apb.output('  memcpy_96to128((uint32_t *)'
                                   f' 0x{addr + min_col[start] * 16:08x},'
                                   f' kernels_{start}, {span});\n')
                    else:
                        apb.output(
                            '  *((volatile uint8_t *)'
                            f' 0x{addr + min_col[start] * 4 | 0x01:08x}) = 0x01; '
                            '// Set address\n')
                        apb.output(
                            f'  memcpy32((uint32_t *) 0x{addr:08x}, '
                            f'kernels_{start}, {(span * 9 + 3) // 4});\n')
                p += 1

            apb.output('}\n\n')

    return kern_offs, kern_len
Beispiel #5
0
def verify(
    verify_fn,
    ll,
    in_map,
    out_map,
    out_buf,
    processor_map,
    input_shape,
    out_offset,
    out_expand,
    out_expand_thresh,
    output_width=8,
    pool=None,
    pool_stride=None,
    overwrite_ok=False,
    no_error_stop=False,
    device=84,
    mlator=False,
    apb_base=0,
    stream=None,
    max_count=None,
    write_gap=0,
):
    """
    Verify HWC memory from AI8X, writing C or mem code using the `verify_fn` function.

    The generated code is specific to the network configuration passed in `processor_map`
    and `input_shape`. Additionally, the generated addresses are offset by `out_offset`.
    `in_map` and `out_map` are used to optionally prevent overwriting data
    (controlled by `overwrite_ok` and `no_error_stop`).
    When `mlator` is set, use the hardware mechanism to rearrange 4-channel data into single
    channels (8-bit output only; for wider output it is ignored with a warning).

    Arguments:
        verify_fn:      Callback receiving each expected word as
                        ``verify_fn(addr, val, rv=False, comment=..., ...)``.
        ll:             Layer number, recorded in `out_map` and used in diagnostics.
        in_map:         Per-word map (indexed by byte address >> 2) of layer inputs;
                        entries are ``(layer, c, row, col, val)`` tuples or None.
        out_map:        Per-word map of written outputs (same layout as `in_map`),
                        updated here; may be None to disable output tracking.
        out_buf:        Expected output values, indexed as ``out_buf[c][row][col]``.
        processor_map:  Bit mask of the processors producing this output.
        input_shape:    ``(channels, rows, cols)`` of the output being verified.
        out_offset:     Byte offset of the output in data memory.
        out_expand:     Output expansion factor; each pixel occupies
                        ``out_expand * (output_width // 8)`` words.
        out_expand_thresh:  Channels per pass before wrapping back to the first processor.
        output_width:   Bits per output value (8-bit and wider outputs take different paths).
        pool, pool_stride:  Pooling configuration, only used for the AI84 4x4 quirk.
        overwrite_ok:   Skip the overwrite checks when True.
        no_error_stop:  Report overwrites without exiting when True.
        device:         Device number; 84 enables the AI84 address quirk.
        mlator:         Use the mlator to read single channels.
        apb_base:       Base address added to register addresses written to `stream`.
        stream:         Text stream receiving generated C code (mlator set-up and the
                        truncation notice).
        max_count:      Maximum number of 4-channel groups passed to `verify_fn` in the
                        non-mlator path, or None for no limit.
        write_gap:      Pixels are spaced ``(write_gap + 1) * width`` words apart.

    Raises:
        SystemExit: When an overwrite is detected and `no_error_stop` is not set.
    """
    count = 0

    def check_overwrite(
        p,
        target_offs,
        in_map,
        out_map,
        c,
        row,
        col,
    ):
        """Report (and unless `no_error_stop`, exit on) writes to occupied words."""
        # If using single layer, make sure we're not overwriting the input
        if (not overwrite_ok) and in_map[target_offs >> 2] is not None:
            old_ll, old_c, old_row, old_col, _ = in_map[target_offs >> 2]
            eprint(
                f'Processor {p}: '
                f'Layer {ll} output for CHW={c},{row},{col} is overwriting '
                f'input at offset 0x{target_offs:08x} that was created by '
                f'layer {old_ll}, CHW={old_c},{old_row},{old_col}.',
                error=not no_error_stop)
            if not no_error_stop:
                sys.exit(1)
        # Check we're not overflowing the data memory
        if (not overwrite_ok) and out_map is not None and out_map[
                target_offs >> 2] is not None:
            old_ll, old_c, old_row, old_col, old_val = out_map[target_offs
                                                               >> 2]
            eprint(
                f'Processor {p}: '
                f'Layer {ll} output for CHW={c},{row},{col} is overwriting '
                f'offset 0x{target_offs:08x}. Previous write by '
                f'layer {old_ll},CHW={old_c},{old_row},{old_col} with value 0x{old_val:08x}.',
                error=not no_error_stop)
            if not no_error_stop:
                sys.exit(1)

    # Start at the instance of the first active output processor/channel
    coffs_start = ffs(processor_map) & ~(tc.dev.P_SHARED - 1)
    next_layer_map = processor_map >> coffs_start
    # Output expansion for channels and/or wide output
    out_size = output_width // 8
    width = out_expand * out_size

    if not mlator or out_size > 1:
        if mlator:
            eprint('ignoring --mlator for 32-bit output', error=False)

        for doffs in range(input_shape[1] * input_shape[2]):
            row, col = divmod(doffs, input_shape[2])
            this_map = next_layer_map
            poffs = coffs_start
            c = 0
            while c < input_shape[0]:
                if c % out_expand_thresh == 0:
                    poffs = coffs_start
                    this_map = next_layer_map  # Wrap around for AI85 channel expansion

                this_c = c
                expand = c // out_expand_thresh  # Channels 64+ handled by processors 0+
                # Physical offset into instance and group
                proc = poffs & ~(tc.dev.P_SHARED - 1)

                # Get four bytes or words either from output or zeros and construct HWC word
                no_data = True
                if out_size == 1:
                    val = 0
                    for _ in range(4):
                        val >>= 8
                        if this_map & 1:
                            no_data = False
                            if c < input_shape[0]:
                                val |= (out_buf[c][row][col] & 0xff) << 24
                            c += 1
                        this_map >>= 1
                else:
                    val = [0] * 4
                    for i in range(4):
                        if this_map & 1:
                            no_data = False
                            if c < input_shape[0]:
                                val[i] = out_buf[c][row][col] & 0xffffffff
                            c += 1
                        this_map >>= 1

                # Get the offset of the first output byte/word of 4
                offs = tc.dev.C_SRAM_BASE + out_offset - (write_gap << 2) + \
                    (((proc % tc.dev.P_NUMPRO) * tc.dev.INSTANCE_SIZE |
                      (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) +
                     (doffs * (write_gap + 1)) * width + expand * out_size) * 4

                # Special adjustment for AI84 quirk
                if device == 84 and pool and pool[0] == 4 and pool_stride[0] == 4:
                    offs += (doffs // 4) * 8 + 8

                if not no_data:
                    num_bytes = min(c - this_c, input_shape[0] - this_c)
                    if out_size == 1:
                        check_overwrite(
                            proc,
                            offs,
                            in_map,
                            out_map,
                            this_c,
                            row,
                            col,
                        )
                        if out_map is not None:
                            out_map[offs >> 2] = (ll, this_c, row, col, val)
                        if max_count is None or count < max_count:
                            verify_fn(
                                offs,
                                val,
                                rv=False,
                                comment=
                                f' // {row},{col},{this_c}-{this_c+num_bytes-1}',
                                num_bytes=num_bytes,
                                first_proc=ffs(next_layer_map >> proc) % 4,
                            )
                    else:
                        # One word per channel: report and record the channel each
                        # word actually holds (matching the emitted comment).
                        for i in range(min(num_bytes, out_size)):
                            check_overwrite(
                                proc,
                                offs,
                                in_map,
                                out_map,
                                this_c + i,
                                row,
                                col,
                            )
                            if out_map is not None:
                                out_map[offs >> 2] = (ll, this_c + i, row, col,
                                                      val[i])
                            if max_count is None or count < max_count:
                                verify_fn(
                                    offs,
                                    val[i],
                                    rv=False,
                                    comment=f' // {row},{col},{this_c+i}',
                                )
                            offs += out_size
                    count += 1
                    if count == max_count:
                        stream.write('  // Truncated further checks...\n')

                poffs += 4
    else:  # mlator == True
        assert out_size == 1
        c = 0
        poffs = coffs_start
        this_map = next_layer_map
        read_addr = None

        while c < input_shape[0]:
            if c % out_expand_thresh == 0:
                poffs = coffs_start  # Wrap around for AI85 channel expansion
                this_map = next_layer_map

            expand = c // out_expand_thresh  # Channels 64+ handled by processors 0+
            # Physical offset into instance and group
            proc = poffs & ~(tc.dev.P_SHARED - 1)

            addr = tc.dev.C_CNN_BASE + (proc //
                                        tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS
            mlat = addr + tc.dev.REG_MLAT * 4
            ctrl = addr + tc.dev.REG_CTL * 4

            for shift in range(4):
                if this_map & 1:
                    for doffs in range(0, input_shape[1] * input_shape[2], 4):
                        row, col = divmod(doffs, input_shape[2])

                        # Get four bytes or words either from output or zeros and
                        # construct HWC word
                        val = 0
                        for i in range(4):
                            val >>= 8
                            if col + i < input_shape[2]:
                                val |= (out_buf[c][row][col + i] & 0xff) << 24

                        # Get the offset of the first output byte/word of 4
                        source = out_offset + \
                            (((proc % tc.dev.P_NUMPRO) * tc.dev.INSTANCE_SIZE |
                              (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) +
                             (doffs >> 2) * width) * 4

                        if source != read_addr:
                            if doffs != 0:
                                stream.write(
                                    f'  *((volatile uint32_t *) '
                                    f'0x{apb_base + ctrl:08x}) = '
                                    f'0x{tc.dev.READY_SEL << 1 | 1 << 3:08x}; '
                                    '// Disable mlator\n')
                            # Set wptr to start address
                            w = apb_base + addr + tc.dev.C_CNN*4 \
                                + tc.dev.LREG_WPTR_BASE*4 * tc.dev.MAX_LAYERS
                            stream.write(
                                f'  *((volatile uint32_t *) 0x{w:08x}) = '
                                f'0x{source >> 2:08x}; // Set SRAM address\n')
                            # Set wptr_inc to set increment value (default: 1)
                            w = apb_base + addr + tc.dev.C_CNN*4 \
                                + tc.dev.LREG_LCTL2*4 * tc.dev.MAX_LAYERS
                            stream.write(
                                f'  *((volatile uint32_t *) 0x{w:08x}) = '
                                f'0x{expand:08x}; // Set pointer increment\n')
                            # Set mlatorld enable bit to load write ptr; select byte 0..3
                            w = tc.dev.READY_SEL << 1 | 1 << 16 | shift << 17 | 1 << 3
                            stream.write(
                                f'  *((volatile uint32_t *) 0x{apb_base + ctrl:08x}) ='
                                f' 0x{w:08x}; '
                                f'// Enable mlator, byte {shift}\n')
                            stream.write(
                                '  asm volatile ("" : "=m" (*((volatile uint32_t *) '
                                f'0x{apb_base + mlat:08x})) : "r" '
                                f'(*((volatile uint32_t *) 0x{apb_base + mlat:08x})));'
                                ' // Prime\n')

                        num_bytes = min(4, input_shape[2] - col)
                        # Record the output under the same absolute SRAM address that
                        # check_overwrite() looks up (and that the non-mlator path uses),
                        # so that later overlapping writes are detected.
                        sram_addr = tc.dev.C_SRAM_BASE + source
                        check_overwrite(
                            proc,
                            sram_addr,
                            in_map,
                            out_map,
                            c,
                            row,
                            col,
                        )
                        if out_map is not None:
                            out_map[sram_addr >> 2] = (ll, c, row, col, val)
                        verify_fn(
                            mlat,
                            val,
                            rv=False,
                            comment=f' // {row},{col}-{col+num_bytes-1},{c}',
                            num_bytes=num_bytes,
                        )

                        read_addr = source + 4
                    # Disable mlator
                    stream.write(f'  *((volatile uint32_t *) '
                                 f'0x{apb_base + ctrl:08x}) = '
                                 f'0x{tc.dev.READY_SEL << 1 | 1 << 3:08x}; '
                                 '// Disable mlator\n')

                this_map >>= 1
                c += 1

            poffs += 4
Beispiel #6
0
def unload(
    memfile,
    apb_base,
    processor_map,
    input_shape,
    out_offset,
    out_expand,
    out_expand_thresh,
    output_width=8,
    pool=None,
    pool_stride=None,
    device=84,
    mlator=False,
    blocklevel=False,
):
    """
    Unload HWC memory from AI84, writing C code to the `memfile` handle.

    The generated C function `cnn_unload(uint{output_width}_t *out_buf)` is specific to the
    network configuration passed in `processor_map` and `input_shape`. Additionally, the
    generated addresses are offset by `apb_base` and `out_offset`.
    For output narrower than 32 bits, values are stored in `out_buf` in CHW order, i.e.
    channel `c`, pixel `(row, col)` lands at ``c * rows * cols + row * cols + col``.
    When `mlator` is set, use the hardware mechanism to rearrange 4-channel data into single
    channels (only for 8-bit output; wider output uses the regular path).

    Arguments:
        memfile:        Writable text handle receiving the generated C code.
        apb_base:       Base address added to every emitted memory/register address.
        processor_map:  Bit mask of the processors holding the output.
        input_shape:    ``(channels, rows, cols)`` of the data to unload.
        out_offset:     Byte offset of the data in SRAM.
        out_expand:     Output expansion factor; each pixel occupies
                        ``out_expand * (output_width // 8)`` words.
        out_expand_thresh:  Channels per pass before wrapping back to the first processor.
        output_width:   Bits per output value; also selects the C buffer element type.
        pool, pool_stride:  Pooling configuration, only used for the AI84 4x4 quirk.
        device:         Device number; 84 enables the AI84 address quirk.
        mlator:         Use the mlator to read single channels.
        blocklevel:     Must not be combined with `mlator`.
    """
    assert not blocklevel or not mlator

    # Number of values per channel plane (rows * cols)
    plane = input_shape[1] * input_shape[2]

    memfile.write('// Custom unload for this network:\n'
                  f'// {output_width}-bit data, shape: {input_shape}\n'
                  f'void cnn_unload(uint{output_width}_t *out_buf)\n'
                  '{\n'
                  '  volatile uint32_t *addr;\n')
    if output_width != 32:
        if plane == 1:
            memfile.write('  uint32_t val;\n')
        else:
            memfile.write('  uint32_t val, offs;\n')
    if mlator:
        # The mlator code below assigns and dereferences `ctrl` and `mlat`,
        # so the generated function has to declare them.
        memfile.write('  volatile uint32_t *ctrl, *mlat;\n'
                      '  uint32_t *out_buf32 = (uint32_t *) out_buf;\n\n')
    else:
        memfile.write('\n')

    coffs_start = ffs(processor_map) & ~(tc.dev.P_SHARED - 1)
    poffs = coffs_start
    next_layer_map_init = processor_map >> coffs_start
    next_layer_map = next_layer_map_init

    # Output expansion for channels and/or wide output
    out_size = output_width // 8
    width = out_expand * out_size

    read_addr = None
    write_addr = None
    mlat_addr = None
    c = 0
    while c < input_shape[0]:
        if c % out_expand_thresh == 0:
            poffs = coffs_start
            next_layer_map = next_layer_map_init

        expand = c // out_expand_thresh  # Channels 64+ handled by processors 0+
        proc = poffs & ~(tc.dev.P_SHARED - 1)

        if not mlator or out_size > 1:
            for doffs in range(plane):
                this_map = next_layer_map
                this_c = c

                # Get four bytes from memory array
                offs = out_offset + \
                    (((proc % tc.dev.P_NUMPRO) * tc.dev.INSTANCE_SIZE |
                      (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) +
                     doffs * width + expand * out_size) * 4

                # Special adjustment for AI84 quirk
                if device == 84 and pool and pool[0] == 4 and pool_stride[0] == 4:
                    offs += (doffs // 4) * 8 + 8

                if offs != read_addr:
                    memfile.write(
                        '  addr = (volatile uint32_t *) '
                        f'0x{apb_base + tc.dev.C_SRAM_BASE + offs:08x};\n')
                if out_size != 4:
                    memfile.write('  val = *addr++;\n')
                    read_addr = offs + 4
                else:
                    read_addr = offs

                # Singulate bytes, ignoring unused processors
                for shift in range(4):
                    # CHW offset of channel `this_c` at this pixel
                    # (doffs == row * cols + col)
                    addr = this_c * plane + doffs
                    if (shift == 0 or out_size > 1) \
                       and out_size != 4 and plane != 1:
                        if addr != write_addr:
                            memfile.write(f'  offs = 0x{addr:04x};\n')
                        else:
                            memfile.write('  offs++;\n')
                        write_addr = addr + 1
                    if this_map & 1:
                        if out_size != 4:
                            if plane != 1:
                                memfile.write('  out_buf[offs')
                                # For 8-bit output `offs` was set for channel `c` at
                                # shift 0, so channel `this_c` lies (this_c - c) planes
                                # further on; wider output re-sets `offs` every shift.
                                chan_offs = (this_c - c) * plane if out_size == 1 else 0
                                if chan_offs > 0:
                                    memfile.write(f'+0x{chan_offs:02x}')
                                memfile.write('] = ')
                            else:
                                memfile.write('  *out_buf++ = ')
                            if shift == 0:
                                memfile.write('val')
                            else:
                                memfile.write(f'(val >> {shift * 8})')
                            if out_size == 1:
                                memfile.write(' & 0xff;\n')
                            else:
                                memfile.write(';\n')
                        else:  # out_size == 4
                            memfile.write('  *out_buf++ = *addr++;\n')
                            write_addr = addr + 4
                            read_addr += 4

                        this_c += 1
                    this_map >>= 1
        else:  # mlator
            assert out_size == 1
            this_map = next_layer_map
            addr = apb_base + tc.dev.C_CNN_BASE + (
                proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS
            mlat = addr + tc.dev.REG_MLAT * 4
            if mlat_addr != mlat:
                mlat_addr = mlat
                ctrl = addr + tc.dev.REG_CTL * 4
                memfile.write(
                    f'  ctrl = (volatile uint32_t *) 0x{ctrl:08x};\n')
                memfile.write(
                    f'  mlat = (volatile uint32_t *) 0x{mlat:08x};\n')

            this_c = c
            for shift in range(4):
                if this_map & 1:
                    memfile.write(f'  // Channel {this_c}\n')

                    for doffs in range(0, plane, 4):
                        row, col = divmod(doffs, input_shape[2])

                        # Get four bytes from memory
                        source = out_offset + \
                            (((proc % tc.dev.P_NUMPRO) * tc.dev.INSTANCE_SIZE |
                              (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) +
                             (doffs >> 2) * width + expand * out_size) * 4
                        # CHW offset of channel `this_c`, pixel (row, col)
                        target = this_c * plane + doffs
                        assert target & 3 == 0

                        if target != write_addr:
                            memfile.write(f'  offs = 0x{target >> 2:04x};\n')
                        if source != read_addr:
                            if doffs != 0:
                                memfile.write(
                                    f'  *ctrl = 0x{tc.dev.READY_SEL << 1 | 1 << 3:08x}; '
                                    '// Disable mlator\n')
                            # Set wptr to start address
                            # NOTE(review): the value written is the pixel index `doffs`,
                            # not derived from `source` (out_offset/processor are not
                            # included) -- confirm against the hardware documentation.
                            val = addr + tc.dev.C_CNN*4 \
                                + tc.dev.LREG_WPTR_BASE*4 * tc.dev.MAX_LAYERS
                            memfile.write(
                                f'  *((volatile uint32_t *) 0x{val:08x}) = '
                                f'0x{doffs:08x}; // Set SRAM address\n')
                            # Set wptr_inc to set increment value (default: 1)
                            val = addr + tc.dev.C_CNN*4 \
                                + tc.dev.LREG_LCTL2*4 * tc.dev.MAX_LAYERS
                            memfile.write(
                                f'  *((volatile uint32_t *) 0x{val:08x}) = '
                                f'0x{expand:08x}; // Set pointer increment\n')
                            # Set mlatorld enable bit to load write ptr; select byte 0..3
                            val = tc.dev.READY_SEL << 1 | 1 << 16 | shift << 17 | 1 << 3
                            memfile.write(f'  *ctrl = 0x{val:08x}; '
                                          f'// Enable mlator, byte {shift}\n')
                            memfile.write(
                                '  asm volatile ("" : "=m" (*mlat) : "r" (*mlat));'
                                ' // Prime\n')

                        # FIXME: Do not write more than `num_bytes = min(4, input_shape[2] - col)`
                        memfile.write('  out_buf32[offs++] = *mlat;'
                                      f' // {this_c},{row},{col}-{col+3}\n')
                        read_addr = source + 4
                        write_addr = target + 4

                    # Disable mlator
                    memfile.write(
                        f'  *ctrl = 0x{tc.dev.READY_SEL << 1 | 1 << 3:08x}; '
                        '// Disable mlator\n')
                this_c += 1

                this_map >>= 1

        poffs += 4
        # Advance by the number of active processors in this group of four
        c += popcount(next_layer_map & 0x0f)
        next_layer_map >>= 4

    memfile.write('}\n\n')