示例#1
0
 def __init__(self, n, num):
     """Create a point from its integer encoding.

     For example: n=5, indices=[0,2] represents 10100.

     Stores `n` and `num` as given, the index list derived from them by
     `num_and_n_to_indices`, the binary digits of `num` as a tuple of
     one-character strings (no zero padding up to `n`), and the value of
     `popcount(num)`.
     """
     self.n, self.num = n, num
     self.indices = num_and_n_to_indices(num, n)
     # bin() yields e.g. '0b10100'; drop the '0b' prefix before splitting.
     binary_digits = bin(self.num)[2:]
     self.binary_tuple = tuple(binary_digits)
     self.popcount = popcount(num)
示例#2
0
    def __init__(self, N):
        """Enumerate every point encoded by the integers 0 .. 2**N - 1.

        Fills two lookup tables:
          * NUM_TO_POINT: integer encoding -> `Point(N, encoding)`
          * LEVEL_TO_POINTS: popcount of the encoding -> list of points on
            that level, ordered by `num`
        """
        self.N = N
        self.NUM_TO_POINT = {}
        self.LEVEL_TO_POINTS = dict((level, []) for level in xrange(N + 1))

        for num in xrange(2**N):
            point = Point(N, num)
            self.NUM_TO_POINT[num] = point
            self.LEVEL_TO_POINTS[popcount(num)].append(point)

        # Keep each level ordered by the integer encoding of its points.
        for points_on_level in self.LEVEL_TO_POINTS.values():
            points_on_level.sort(key=lambda pt: pt.num)
示例#3
0
def precomp():
    """Precompute the module-level lookup tables for the current CLUSTERSIZE.

    Sets three globals:
      * WMAT: `hadamard(2**CLUSTERSIZE)` when CLUSTERSIZE <= 8, otherwise None.
      * HWTABLE: integer array where entry i is `popcount(i)`.
      * BITTABLE: boolean array of shape (2**CLUSTERSIZE, CLUSTERSIZE) whose
        row i is `intToBits(i, CLUSTERSIZE)`.
    """
    global WMAT, HWTABLE, BITTABLE
    numvals = 2**CLUSTERSIZE

    WMAT = hadamard(numvals) if CLUSTERSIZE <= 8 else None

    HWTABLE = np.zeros(shape=numvals, dtype=int)
    # The `np.bool` alias was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `bool` is the dtype it used to refer to.
    BITTABLE = np.zeros(shape=(numvals, CLUSTERSIZE), dtype=bool)
    for i in range(numvals):
        HWTABLE[i] = popcount(i)
        BITTABLE[i, :] = np.array(intToBits(i, CLUSTERSIZE))
    def initMessages(self):
        """Initialise the messages and lookup tables of this node.

        The cardinalities of all attached edges must add up to
        `self.wordsize` (checked with an assertion).

        Leaf (exactly one edge): the whole distribution is broadcast on that
        edge. Every value is given `self.dist[w]`, where w is the value's
        entry in `settings.HWTABLE`; the result is normalised to sum to one
        and stored in `edge.m2n`. `self.isLeaf` is set to True. Per the
        original author's note this is done once here and not needed later.

        Otherwise: the parent class initialisation runs, each edge gets a
        `factorexpand` table holding `popcount(v)` for every value v of that
        edge, and `self.nodeTable` records, for each of the `totcard`
        positions, the index of the edge it belongs to.
        """
        totcard = sum(e.cardinality for e in self.edges)
        assert totcard == self.wordsize

        if len(self.edges) != 1:
            super().initMessages()
            self.nodeTable = np.zeros(shape=totcard, dtype=int)

            nodeidx = 0
            for edgeidx, edge in enumerate(self.edges):
                weights = [popcount(v) for v in range(2**edge.cardinality)]
                edge.factorexpand = np.array(weights, dtype=int)

                # The next `cardinality` positions all belong to this edge.
                stop = nodeidx + edge.cardinality
                self.nodeTable[nodeidx:stop] = edgeidx
                nodeidx = stop
            return

        # Leaf node.
        self.isLeaf = True

        edge = self.edges[0]
        assert (self.wordsize == edge.cardinality)

        msgout = np.zeros(shape=self.numvals, dtype=settings.NUMPY_DATATYPE)
        # Table lookup instead of calling popcount() for every value.
        hws = settings.HWTABLE[:2**self.wordsize]
        for weight in range(self.wordsize + 1):
            msgout[hws == weight] = self.dist[weight]

        edge.m2n = msgout / np.sum(msgout)
示例#5
0
def sample_monotone_function(ctx):
    """Randomly sample a function over the points of `ctx`.

    This raises a natural question - how to randomly sample a monotone
    function? For now every point starts with the value 1. Random points are
    then drawn, and each one is flipped with a probability that shrinks with
    its distance from the middle level:

      * drawn point currently 1  -> everything returned by
        `get_all_upper_neighbours` is set to -1
      * drawn point currently -1 -> everything returned by
        `get_all_lower_neighbours` is set to 1

    Returns `BooleanFunction(ctx, values)`.
    """
    n = ctx.N
    values = {pt: 1 for pt in ctx.NUM_TO_POINT.values()}
    attempts = max(2**int(n / 2), 100)  # for now, magic numbers...
    mid_level = float(n) / 2

    for _ in xrange(attempts):
        num = random.randint(0, 2**n - 1)
        chosen = ctx.NUM_TO_POINT[num]
        dist_from_mid_level = 1 + abs(popcount(num) - mid_level)
        # TODO - maybe flip with respect to level (extreme levels don't flip)
        if not random.random() < 0.5 * (1 / dist_from_mid_level)**2:
            continue

        if values[chosen] == 1:
            targets = get_all_upper_neighbours(chosen, ctx)
            new_value = -1
        else:
            targets = get_all_lower_neighbours(chosen, ctx)
            new_value = 1
        for pt in targets:
            values[pt] = new_value

    return BooleanFunction(ctx, values)
示例#6
0
def unload(apb_base,
           processor_map,
           input_shape,
           out_array,
           out_offset,
           in_array,
           flatten=False):
    """
    Unload HWC memory from AI84 and return it in `out_array`.
    The generated C code is specific to the network configuration passed in in `processor_map`
    and `input_shape`. Additionally, the generated addresses are offset by `apb_base` and
    `out_offset`. The C code function takes a pointer to a memory array, and the dimensions of
    the array do not matter (flattened or not flattened).
    The additional simulation code takes the `flatten` parameter and an `in_array`.
    If `flatten` is `True`, then the out_array is flattened, i.e. it is indexed with a single
    row-major offset instead of `[channel][row][col]`.

    The C code is written to standard output with `print()`. A `RuntimeError` is raised when
    a read from `in_array` is out of range, not 32-bit aligned, or hits a word that is still
    `MEM_INVALID`.
    """
    def get_val(offs):
        """
        Returns value stored at offset `offs` in the memory array.
        """
        if offs >= (MEM_SIZE << 2) or offs < 0:
            raise RuntimeError(
                f'Offset {offs:04x} is invalid for the memory array.')
        if offs & 3:
            raise RuntimeError(
                f'Offset {offs:04x} should be a 32-bit address.')
        if in_array[offs >> 2] == MEM_INVALID:
            raise RuntimeError(
                f'Trying to read from uninitialized memory at location {offs:04x}.'
            )
        return in_array[offs >> 2]

    print('\n// Custom unload for this network:\n'
          f'// Input shape: {input_shape}\n'
          'void unload(uint8_t *out_buf)\n'
          '{\n  uint32_t val, *addr, offs;\n')

    # Start at the first used processor, rounded down to a multiple of P_SHARED
    coffs = ffs(processor_map) & ~(tc.dev.P_SHARED - 1)
    next_layer_map = processor_map >> coffs
    # Track the last emitted read/write positions so that consecutive accesses can use
    # `addr++` / `offs++` in the generated code instead of reloading the address.
    read_addr = None
    write_addr = None
    c = 0
    while c < input_shape[0]:
        for doffs in range(input_shape[1] * input_shape[2]):
            row, col = divmod(doffs, input_shape[2])
            this_map = next_layer_map
            this_c = c

            # Get four bytes from memory array
            proc = (coffs % tc.dev.MAX_PROC) & ~(tc.dev.P_SHARED - 1)
            # FIXME: seq = ...
            offs = out_offset + \
                (((proc % tc.dev.P_NUMPRO) * tc.dev.INSTANCE_SIZE |
                  (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) +
                 doffs) * 4

            val = get_val(offs)

            if offs != read_addr:
                print(
                    f'  addr = (uint32_t *) 0x{apb_base + tc.dev.C_SRAM_BASE + offs:08x};'
                )
            print('  val = *addr++;')
            read_addr = offs + 4

            # Singulate bytes, ignoring unused processors
            for shift in range(4):
                # Row-major offset of [this_c][row][col]. `row` and `col` come from
                # divmod(doffs, input_shape[2]), so the row stride is input_shape[2];
                # using input_shape[1] made addresses collide for non-square inputs.
                addr = this_c * input_shape[1] * input_shape[2] \
                    + row * input_shape[2] + col
                if shift == 0:
                    if addr != write_addr:
                        print(f'  offs = 0x{addr:04x};')
                    else:
                        print('  offs++;')
                    write_addr = addr + 1
                if this_map & 1:
                    if not flatten:
                        out_array[this_c][row][col] = val & 0xff
                    else:
                        out_array[addr] = val & 0xff
                    print('  out_buf[offs', end='')
                    if shift > 0:
                        # NOTE(review): the per-byte stride in the generated C is hard-coded
                        # to 0x10 -- presumably assumes 16 elements per channel; confirm.
                        print(f'+0x{0x10 * shift:02x}', end='')
                    print('] = ', end='')
                    if shift == 0:
                        print('val', end='')
                    else:
                        print(f'(val >> {shift * 8})', end='')
                    print(' & 0xff;')
                    this_c += 1
                this_map >>= 1
                val >>= 8

        # Advance to the next group of four processors; only the enabled ones
        # contributed a channel.
        coffs += 4
        c += popcount(next_layer_map & 0x0f)
        next_layer_map >>= 4

    print('}')
示例#7
0
def load(  # pylint: disable=too-many-branches,too-many-statements
    verbose,
    embedded_code,
    device,
    apb,
    start_layer,
    layers,
    operator,
    kernel,
    kernel_size,
    quantization,
    processor_map,
    output_processor_map,
    input_chan,
    output_chan,
    out_expand,
    out_expand_thresh,
    in_expand,
    in_expand_thresh,
    flatten=False,
    mexpress=False,
    verify=False,
    riscv_flash=False,
    quad=False,
    debug=False,
    blocklevel=False,
    legacy_kernels=False,
    calcx4=False,
):
    """
    Stack `kernel` values and write them to C code (for `embedded_code` if `True` or
    RTL simulation). The output is written to the `apb` object.
    Input is configured with `kernel_size`, `quantization`, `layers`, `processor_map`,
    `output_processor_map`, `input_chan`, `output_chan`, `out_expand` and `out_expand_thresh`.
    When `mexpress` is `True`, the function uses the memcpy()-friendly hardware functionality to
    reduce the number of transfers. When `verify` is also true (mexpress mode only), kernels are
    read back and compared.
    This function returns the kernel offsets and the kernel lengths for all layers.

    Layers whose operator is not CONV1D, CONV2D or CONVTRANSPOSE2D get an offset and a length
    of 0. The process is terminated with `sys.exit(1)` when `calcx4` is requested on a device
    without support for it, or when a layer does not fit into the kernel memory.

    NOTE(review): `flatten` is indexed per layer (`flatten[ll]`), so the scalar default
    `False` cannot be used once a convolution layer is processed -- confirm that callers
    always pass a per-layer sequence.
    """
    # Kernels: Stack kernels; write only the kernels needed
    proc_kern_max = [0] * tc.dev.MAX_PROC
    kern_offs = [0] * layers
    kern_len = [0] * layers
    # Per processor and column: the layer that owns the column (or _INVALID_VALUE),
    # the number of bytes written to it so far, and the 9 data bytes themselves.
    kernel_map = np.full((tc.dev.MAX_PROC, tc.dev.MASK_WIDTH_LARGE),
                         _INVALID_VALUE,
                         dtype=np.int64)
    kernels_used = np.zeros((tc.dev.MAX_PROC, tc.dev.MASK_WIDTH_LARGE),
                            dtype=np.int64)
    kernel_data = np.zeros((tc.dev.MAX_PROC, tc.dev.MASK_WIDTH_LARGE, 9),
                           dtype=np.int8)
    # There are four 32-bit words per 9-byte kernel.
    # The value map is initialized with zeros so we can later ignore unused entries and use
    # memcpy() on initialized and uninitialized data.
    kernel_values = np.zeros(
        (tc.dev.MAX_PROC, tc.dev.MASK_WIDTH_LARGE * _WORDS_PER_KERNEL),
        dtype=np.int64)
    if debug:
        print('\nLoading Kernels...')

    if calcx4 and not tc.dev.SUPPORT_CALCX4:
        eprint('--calcx4 is not supported on this device.')
        sys.exit(1)
    assert not (
        (embedded_code or mexpress) and calcx4)  # FIXME Add support later

    for ll in range(start_layer, layers):
        if operator[ll] not in [op.CONV1D, op.CONV2D, op.CONVTRANSPOSE2D]:
            kern_len[ll] = 0
            kern_offs[ll] = 0
            continue

        if flatten[ll]:
            kernel_reshaped = kernel[ll].reshape(
                output_chan[ll] * input_chan[ll],
                -1,
                kernel_size[ll][0],
                kernel_size[ll][1],
            )
        else:
            kernel_reshaped = kernel[ll]

        first_proc = ffs(processor_map[ll])
        last_proc = fls(processor_map[ll])
        # `ch` is incremented once per used source processor; `m` advances as kernels are
        # consumed and is reset to 0 for every processor (see the end of the loop below).
        ch = 0
        m = 0
        for p in range(first_proc, last_proc + 1):
            if (processor_map[ll] >> p) & 1 == 0:
                # Unused processor
                continue
            # Get highest offset for all used processors
            kern_offs[ll] = max(proc_kern_max[p], kern_offs[ll])

        ksize = kernel_size[ll][0] * kernel_size[ll][1]
        # Number of quantized weights that fit into one byte
        qfactor = 8 // quantization[ll]
        # Determine the number of kernels that need to be programmed. Since each instance
        # spans 4 processors, kernels for all instances that have a single processor enabled
        # need to be written, i.e. round down the first. The last does not need to be rounded
        # up because hardware takes care of it.
        next_layer_map = output_processor_map[ll]
        # When using kernels smaller than 8 bit, round up to the next 8-bit boundary
        # Gaps are accounted for like any other kernel.
        kern_len[ll] = 1 + quantization[ll] * \
            (fls(next_layer_map) - ffs(next_layer_map)) // 8
        # This extends the kernels to the right on AI85 for input and output expansion
        if output_chan[ll] > tc.dev.MAX_PROC:
            kern_len[ll] = (kern_len[ll] + tc.dev.P_SHARED -
                            1) & ~(tc.dev.P_SHARED - 1)
        kern_len[ll] *= out_expand[ll] * in_expand[ll]
        if not legacy_kernels and flatten[ll]:
            kern_len[ll] *= kernel_reshaped.shape[1]
            kern_len[ll] -= (out_expand[ll] * popcount(next_layer_map) - output_chan[ll]) \
                * kernel_reshaped.shape[1] * 8 // (ksize * quantization[ll])
        if device != 84:
            # Pack kernels when using 1D convolutions, or 1x1 kernels
            kern_len[ll] = (kern_len[ll] * ksize + 8) // 9
        if ll == 0 and quad:
            kern_len[0] = (kern_len[0] + 3) // 4

        # We don't have to use dummy columns if there's space available on the left
        kern_offs[ll] = \
            max(0, kern_offs[ll] - (((ffs(next_layer_map) % tc.dev.P_SHARED)
                                     + qfactor - 1) // qfactor))
        # The kernel offset needs to start at a multiple of 4.
        kern_offs[ll] = (kern_offs[ll] + tc.dev.P_SHARED -
                         1) & ~(tc.dev.P_SHARED - 1)
        # NOTE(review): `p` is the value left over from the processor loop above (the last
        # index it iterated), so only that processor's mask width is checked here.
        if kern_offs[ll] + kern_len[ll] > tc.dev.mask_width(p):
            eprint(
                f'\nKernel memory exceeded at layer {ll}; offset: {kern_offs[ll]}, '
                f'needed: {kern_len[ll]}.'
                '\n\nKernel map so far:')
            print_map(layers, kernel_map, print_fn=eprint_noprefix)
            sys.exit(1)

        proc_mask = 2**qfactor - 1
        # Start at the first used instance
        this_map_init = next_layer_map >> ffs(next_layer_map)
        start_col = ffs(
            next_layer_map) % tc.dev.P_SHARED  # First target column

        for p in range(first_proc, last_proc + 1):
            if (processor_map[ll] >> p) & 1 == 0:
                # Unused source processor
                continue
            col_target = start_col
            for expand in range(out_expand[ll]):
                this_map = this_map_init
                if ll == 0 and quad:
                    col = expand * (out_expand_thresh[ll] + 3) // 4
                    stop_col = col + (out_expand_thresh[ll] + 3) // 4
                else:
                    col = expand * out_expand_thresh[ll]
                    stop_col = col + out_expand_thresh[ll]
                while col < stop_col:
                    # Skip over unused bits in the target processor map
                    # (unused means 1 bit for 8-bit weights, 2 for 4-bit weights, etc.)
                    if this_map != 0:
                        while this_map & proc_mask == 0:
                            assert this_map != 0
                            col_target += 1  # Completely skip
                            this_map >>= qfactor  # and slide forward
                    this_mask = this_map & proc_mask
                    this_map >>= qfactor

                    if ll == 0 and quad:
                        src_offs = ch + (m - p // 16) * input_chan[ll]
                    else:
                        src_offs = ch + m * input_chan[ll]
                    if ll > 0 or not quad or (m % 4 == p // 16):
                        for ie in range(in_expand[ll]):
                            mask = this_mask

                            # Helper: store byte `b` in column `kern_offs[ll] + col_target`
                            # of processor `p`, filling the 9-byte kernel from index 8
                            # downwards. Returns `col_target`, advanced by one once the
                            # column holds all 9 bytes. Exits the program when the column
                            # is outside the processor's mask width.
                            # The function is (re)defined on every pass of this loop.
                            def add_kernel_data(ll, p, col_target, b):
                                col = kern_offs[ll] + col_target
                                if col >= tc.dev.mask_width(p):
                                    eprint(
                                        f'\nKernel memory exceeded in layer {ll}.'
                                        '\n\nKernel map so far:')
                                    print_map(layers,
                                              kernel_map,
                                              print_fn=eprint_noprefix)
                                    sys.exit(1)

                                if kernels_used[p][
                                        col] == 0:  # Update kernel map
                                    assert kernel_map[p][col] == _INVALID_VALUE
                                    kernel_map[p][col] = ll

                                assert kernels_used[p][col] <= 8
                                kernel_data[p][col][
                                    8 - kernels_used[p][col]] = b & 0xff
                                kernels_used[p][col] += 1

                                if kernels_used[p][col] == 9:  # Flush
                                    col_target += 1  # Write 1

                                return col_target

                            # `n` counts the output channels packed in this pass
                            n = 0
                            if src_offs < len(kernel_reshaped):
                                if not flatten[ll]:
                                    k = np.zeros_like(
                                        kernel_reshaped[src_offs].flatten())
                                    for i in range(qfactor):
                                        if m < output_chan[ll]:
                                            # Cycle through phases
                                            idx = n + ie * qfactor
                                            koffs = src_offs + (idx % in_expand[ll]) \
                                                * in_expand_thresh[ll] \
                                                + (idx // in_expand[ll]) \
                                                * input_chan[ll]
                                            if koffs < len(kernel_reshaped):
                                                this_kern = kernel_reshaped[koffs].flatten() \
                                                    & (2**quantization[ll]-1)
                                                k |= this_kern << (
                                                    i * quantization[ll])
                                            n += 1
                                        mask >>= 1
                                else:
                                    kl = (len(kernel_reshaped[src_offs]) +
                                          qfactor - 1) // qfactor
                                    k = np.zeros(kl, dtype=np.int64)
                                    if m < output_chan[ll]:
                                        # Cycle through phases
                                        idx = n + ie * qfactor
                                        koffs = src_offs + (idx % in_expand[ll]) \
                                            * in_expand_thresh[ll] \
                                            + (idx // in_expand[ll]) \
                                            * input_chan[ll]
                                        if koffs < len(kernel_reshaped):
                                            this_kern = kernel_reshaped[
                                                koffs].flatten()
                                            # Zero-pad so the length is a multiple of qfactor
                                            if len(this_kern) % qfactor != 0:
                                                this_kern = np.append(
                                                    this_kern,
                                                    np.zeros(qfactor -
                                                             len(this_kern) %
                                                             qfactor,
                                                             dtype=np.int64))
                                            for i in range(qfactor):
                                                k |= ((this_kern[i::qfactor]
                                                       & (2**quantization[ll]-1))) \
                                                    << (i * quantization[ll])
                                        n += 1
                                        mask >>= 1
                                if debug:
                                    with np.printoptions(
                                            formatter={
                                                'int': '{0:02x}'.format
                                            }):
                                        print(
                                            f'Layer {ll} processor {p} channel '
                                            f'{ch + ie * in_expand_thresh[ll]} m[{m}..{m+n-1}] '
                                            f'of {output_chan[ll]}: {k}')

                                if flatten[ll]:
                                    for _, e in enumerate(k):
                                        col_target = add_kernel_data(
                                            ll, p, col_target, e)
                                else:
                                    # Bytes are handed over in reverse order
                                    for i in range(ksize):
                                        col_target = add_kernel_data(
                                            ll, p, col_target,
                                            k[ksize - i - 1])

                            else:  # When expanding, need to pad with zero kernels if needed
                                for _ in range(ksize // qfactor):
                                    col_target = add_kernel_data(
                                        ll, p, col_target, 0)

                        # Consume kernels
                        if not flatten[ll]:
                            col += qfactor
                            m += qfactor
                        else:
                            col += 1
                            m += 1
                    else:
                        m += qfactor

            # A column that already holds some bytes counts as used: step past it
            if kern_offs[ll] + col_target < tc.dev.mask_width(p) \
               and kernels_used[p][kern_offs[ll] + col_target] > 0:  # Partials
                col_target += 1
            # Pad with zero bytes until the computed kernel length is reached
            while col_target - start_col < kern_len[ll]:
                col_target = add_kernel_data(ll, p, col_target, 0)
            if flatten[ll]:
                kern_len[ll] = col_target
            else:
                assert kern_len[ll] == col_target - start_col
            proc_kern_max[p] = kern_offs[ll] + kern_len[ll]
            ch += 1
            m = 0

    if verbose:
        print('\nKernel map:')
        print_map(layers, kernel_map)

    # In-line writes through `apb.write_kern()`: used for verification, and whenever neither
    # embedded code nor the express loader is requested.
    if verify or not (embedded_code or mexpress):
        if verify:
            apb.output('int verify_kernels(void)\n{\n')
        # Write in-line
        for p in range(tc.dev.MAX_PROC):
            for col in range(0, tc.dev.mask_width(p)):
                ll = kernel_map[p][col]
                if ll != _INVALID_VALUE:
                    k = kernel_data[p][col]
                    apb.write_kern(ll,
                                   p,
                                   col,
                                   k,
                                   verify_only=verify,
                                   calcx4=calcx4)
        if verify:
            apb.output('  return 1;\n}\n\n')
    if embedded_code or mexpress:
        # Write kernels, combining layers and processors where possible to reduce the number
        # of constants and calls to memcpy.
        apb.output('// Kernels:\n')

        if not mexpress:
            for p in range(tc.dev.MAX_PROC):
                for col in range(0, tc.dev.mask_width(p)):
                    ll = kernel_map[p][col]
                    if ll != _INVALID_VALUE:
                        k = kernel_data[p][col]
                        # Pack the 9 kernel bytes into three 32-bit words:
                        # byte 0 alone, then bytes 1..4, then bytes 5..8.
                        offs = _WORDS_PER_KERNEL * col
                        kernel_values[p][offs] = k[0] & 0xff
                        kernel_values[p][offs + 1] = (k[1] & 0xff) << 24 \
                            | (k[2] & 0xff) << 16 | (k[3] & 0xff) << 8 | k[4] & 0xff
                        kernel_values[p][offs + 2] = (k[5] & 0xff) << 24 \
                            | (k[6] & 0xff) << 16 | (k[7] & 0xff) << 8 | k[8] & 0xff

            # First, define the weights (will move to header file)
            # Combining memcopy() requires stacked memories
            max_col = [-1] * tc.dev.MAX_PROC
            min_col = [tc.dev.MASK_WIDTH_LARGE if not legacy_kernels else 0
                       ] * tc.dev.MAX_PROC
            for p in range(0, tc.dev.MAX_PROC):
                for col in range(0, tc.dev.mask_width(p)):
                    ll = kernel_map[p][col]
                    if ll != _INVALID_VALUE:
                        max_col[p] = col
                        min_col[p] = min(min_col[p], col)
            p = 0
            while p < tc.dev.MAX_PROC:
                if max_col[p] >= 0:
                    start = p
                    # Extend the run over following processors while the conditions below
                    # hold; `start` keeps the first processor of the run.
                    while (max_col[p] == tc.dev.MASK_OFFS
                           and p + 1 < tc.dev.MAX_PROC and max_col[p + 1] >= 0
                           and min_col[p + 1] == 0
                           and (start & ~(tc.dev.P_NUMPRO - 1))
                           == (p + 1 & ~(tc.dev.P_NUMPRO - 1))):
                        p += 1
                    # Combine multiple channels into one define
                    k = None
                    for i in range(start, p + 1):
                        if k is None:
                            k = kernel_values[i][min_col[i] *
                                                 _WORDS_PER_KERNEL:
                                                 (max_col[i] + 1) *
                                                 _WORDS_PER_KERNEL]
                        else:
                            k = np.concatenate(
                                (k, kernel_values[i]
                                 [min_col[i] *
                                  _WORDS_PER_KERNEL:(max_col[i] + 1) *
                                  _WORDS_PER_KERNEL]))

                    apb.output_define(k, f'KERNELS_{start}', '0x%08x', 8)
                p += 1

            # Second, initialize static const variables as source for memcpy
            p = 0
            while p < tc.dev.MAX_PROC:
                if max_col[p] >= 0:
                    # NOTE(review): `span` is accumulated in this loop but not read here.
                    span = max_col[p] + 1 - min_col[p]
                    start = p
                    while (max_col[p] == tc.dev.MASK_OFFS
                           and p + 1 < tc.dev.MAX_PROC and max_col[p + 1] >= 0
                           and min_col[p + 1] == 0
                           and (start & ~(tc.dev.P_NUMPRO - 1))
                           == (p + 1 & ~(tc.dev.P_NUMPRO - 1))):
                        p += 1
                        span += max_col[p] + 1 - min_col[p]
                    if riscv_flash:
                        apb.output(rv.RISCV_FLASH)
                    apb.output(
                        f'static const uint32_t kernels_{start}[] = KERNELS_{start};\n'
                    )
                p += 1
            apb.output('\n')

            # Generate code to load the weights using memcpy
            apb.output(
                'void memcpy_96to128(uint32_t *dst, const uint32_t *src, int n)\n{\n'
            )
            apb.output('  while (n-- > 0) {\n'
                       '    *dst++ = *src++;\n'
                       '    *dst++ = *src++;\n'
                       '    *dst++ = *src++;\n'
                       '    *dst++ = 0;  // Execute write\n'
                       '  }\n}\n\n')
        else:
            # When using the express loader, gather all consecutive kernels for each processor
            # and pack them.
            zero_kernel = np.array([0] * 9, dtype=np.uint8)
            k = None

            for p in range(tc.dev.MAX_PROC):
                # Find min/max from kernel_map
                max_col = -1
                min_col = tc.dev.mask_width(p) if not legacy_kernels else 0
                for col in range(0, tc.dev.mask_width(p)):
                    ll = kernel_map[p][col]
                    if ll != _INVALID_VALUE:
                        max_col = col
                        min_col = min(min_col, col)
                if max_col >= 0:
                    # Unused columns inside [min_col, max_col] are filled with zero kernels
                    for col in range(min_col, max_col + 1):
                        ll = kernel_map[p][col]
                        if ll != _INVALID_VALUE:
                            new_k = (kernel_data[p][col] & 0xff).astype(
                                np.uint8)
                        else:
                            new_k = zero_kernel
                        if k is None:
                            k = new_k
                        else:
                            k = np.concatenate((k, new_k))

                    # Round up to multiple of 4
                    if len(k) % 4 != 0:
                        k = np.concatenate((k, zero_kernel[:4 - len(k) % 4]))
                    # '>u4' swaps endianness to what the hardware needs, `view` packs into 32-bit
                    if not blocklevel:
                        apb.output_define(k.view(dtype='>u4'), f'KERNELS_{p}',
                                          '0x%08x', 8)
                    else:
                        addr = tc.dev.C_GROUP_OFFS * (p // tc.dev.P_NUMPRO) \
                            + tc.dev.C_MRAM_BASE + (p % tc.dev.P_NUMPRO) * tc.dev.MASK_OFFS * 16
                        # `+` binds tighter than `|`: this is (addr + min_col * 4) | 0x01
                        apb.write(addr + min_col * 4 | 0x01, 0x01)
                        kb = k.view(dtype=">u4")
                        for _, e in enumerate(kb):
                            apb.write(addr, e)
                            addr += 4

                    if riscv_flash:
                        apb.output(rv.RISCV_FLASH)
                    apb.output(
                        f'static const uint32_t kernels_{p}[] = KERNELS_{p};\n'
                    )
                    k = None
            apb.output('\n')

        if not blocklevel:
            # Emit load_kernels(), with one copy call per run of combined processors
            apb.output('void load_kernels(void)\n{\n')
            max_col = [-1] * tc.dev.MAX_PROC
            min_col = [tc.dev.MASK_WIDTH_LARGE if not legacy_kernels else 0
                       ] * tc.dev.MAX_PROC
            for p in range(0, tc.dev.MAX_PROC):
                for col in range(0, tc.dev.mask_width(p)):
                    ll = kernel_map[p][col]
                    if ll != _INVALID_VALUE:
                        max_col[p] = col
                        min_col[p] = min(min_col[p], col)
            p = 0
            while p < tc.dev.MAX_PROC:
                if max_col[p] >= 0:
                    # `span` is the total number of kernel columns in this run
                    span = max_col[p] + 1 - min_col[p]
                    start = p
                    addr = apb.apb_base + tc.dev.C_GROUP_OFFS * (p // tc.dev.P_NUMPRO) \
                        + tc.dev.C_MRAM_BASE + (p % tc.dev.P_NUMPRO) * tc.dev.MASK_OFFS * 16
                    while (max_col[p] == tc.dev.MASK_OFFS
                           and p + 1 < tc.dev.MAX_PROC and max_col[p + 1] >= 0
                           and min_col[p + 1] == 0
                           and (start & ~(tc.dev.P_NUMPRO - 1))
                           == (p + 1 & ~(tc.dev.P_NUMPRO - 1))):
                        p += 1
                        span += max_col[p] + 1 - min_col[p]
                    assert addr % 16 == 0
                    if not mexpress:
                        apb.output('  memcpy_96to128((uint32_t *)'
                                   f' 0x{addr + min_col[start] * 16:08x},'
                                   f' kernels_{start}, {span});\n')
                    else:
                        apb.output(
                            '  *((volatile uint8_t *)'
                            f' 0x{addr + min_col[start] * 4 | 0x01:08x}) = 0x01; '
                            '// Set address\n')
                        # 9 bytes per kernel column, rounded up to whole 32-bit words
                        apb.output(
                            f'  memcpy32((uint32_t *) 0x{addr:08x}, '
                            f'kernels_{start}, {(span * 9 + 3) // 4});\n')
                p += 1

            apb.output('}\n\n')

    return kern_offs, kern_len
示例#8
0
def unload(
    memfile,
    apb_base,
    processor_map,
    input_shape,
    out_offset,
    out_expand,
    out_expand_thresh,
    output_width=8,
    pool=None,
    pool_stride=None,
    device=84,
    mlator=False,
    blocklevel=False,
):
    """
    Unload HWC memory from AI84, writing C code to the `memfile` handle.
    The generated C code is specific to the network configuration passed in in `processor_map`,
    and `input_shape`. Additionally, the generated addresses are offset by `apb_base` and
    `out_offset`. The C code function takes a pointer to a memory array, and the depth of
    the array does not matter (flattened or not flattened) as long as the size is correct.
    When `mlator` is set, use the hardware mechanism to rearrange 4-channel data into single
    channels.

    Parameters:
        memfile: object with a `write(str)` method that receives the generated C text.
        apb_base: base address added to every absolute address that is emitted.
        processor_map: integer bit mask; each set bit marks a processor that holds one
            output channel. Unset bits are skipped when the data words are split up.
        input_shape: indexable with three entries. `[0]` is the number of channels to
            unload, `[1] * [2]` is the number of elements per channel and `[2]` is the
            row length used to turn a linear element index into (row, col).
        out_offset: offset added to the computed data memory address.
        out_expand: multiplied by the output size in bytes to get the word stride between
            consecutive elements.
        out_expand_thresh: number of channels after which the processor map starts over
            at its first group (with an increased expansion offset).
        output_width: width of one output value in bits; also used for the C pointer type
            `uint{output_width}_t`. The code distinguishes 32 from every other width.
        pool, pool_stride: only inspected for `device == 84`; when both `pool[0]` and
            `pool_stride[0]` are 4 the read offset is adjusted.
        device: device number; only the value 84 changes the generated code.
        mlator: emit the hardware-assisted path (requires 8-bit output).
        blocklevel: only checked against `mlator` in the initial assertion.

    Returns:
        None. All output goes to `memfile`.

    Raises:
        AssertionError: if `blocklevel` and `mlator` are both set, or if the `mlator`
            path is taken with an output size other than one byte or with a target
            index that is not a multiple of four.

    The index into `out_buf` is `channel * H * W + row * W + col` with
    `H = input_shape[1]` and `W = input_shape[2]`. Channels that share one data word are
    addressed relative to the first channel of that word using a stride of `H * W`.
    """
    assert not blocklevel or not mlator

    row_len = input_shape[2]  # elements per row; matches the divmod() below
    chan_size = input_shape[1] * input_shape[2]  # elements per channel

    memfile.write('// Custom unload for this network:\n'
                  f'// {output_width}-bit data, shape: {input_shape}\n'
                  f'void cnn_unload(uint{output_width}_t *out_buf)\n'
                  '{\n'
                  '  volatile uint32_t *addr;\n')
    if output_width != 32:
        # `offs` is only needed when a channel has more than one element
        if chan_size == 1:
            memfile.write('  uint32_t val;\n')
        else:
            memfile.write('  uint32_t val, offs;\n')
    # NOTE(review): the mlator path below emits assignments to `ctrl` and `mlat`, but no
    # declaration for them is written here -- confirm where they are declared.
    if mlator:
        memfile.write('  uint32_t *out_buf32 = (uint32_t *) out_buf;\n\n')
    else:
        memfile.write('\n')

    # Start at the group of processors that contains the first used processor
    coffs_start = ffs(processor_map) & ~(tc.dev.P_SHARED - 1)
    poffs = coffs_start
    next_layer_map_init = processor_map >> coffs_start
    next_layer_map = next_layer_map_init

    # Output expansion for channels and/or wide output
    out_size = output_width // 8
    width = out_expand * out_size

    # Track the last emitted source/target positions so that redundant pointer and
    # offset assignments can be replaced by post-increments in the generated code.
    read_addr = None
    write_addr = None
    mlat_addr = None
    c = 0
    while c < input_shape[0]:
        if c % out_expand_thresh == 0:
            # Wrap around to the first processor group
            poffs = coffs_start
            next_layer_map = next_layer_map_init

        expand = c // out_expand_thresh  # Channels 64+ handled by processors 0+
        proc = poffs & ~(tc.dev.P_SHARED - 1)

        if not mlator or out_size > 1:
            for doffs in range(chan_size):
                row, col = divmod(doffs, row_len)
                this_map = next_layer_map
                this_c = c

                # Get four bytes from memory array
                offs = out_offset + \
                    (((proc % tc.dev.P_NUMPRO) * tc.dev.INSTANCE_SIZE |
                      (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) +
                     doffs * width + expand * out_size) * 4

                if device == 84 and pool and pool[0] == 4 and pool_stride[
                        0] == 4:
                    offs += (doffs // 4) * 8 + 8

                if offs != read_addr:
                    memfile.write(
                        '  addr = (volatile uint32_t *) '
                        f'0x{apb_base + tc.dev.C_SRAM_BASE + offs:08x};\n')
                if out_size != 4:
                    memfile.write('  val = *addr++;\n')
                    read_addr = offs + 4
                else:
                    read_addr = offs

                # Singulate bytes, ignoring unused processors
                for shift in range(4):
                    # Flat index of (this_c, row, col); the row stride is the row length
                    addr = this_c * chan_size + row * row_len + col
                    if (shift == 0 or out_size > 1) \
                       and out_size != 4 and chan_size != 1:
                        if addr != write_addr:
                            memfile.write(f'  offs = 0x{addr:04x};\n')
                        else:
                            memfile.write('  offs++;\n')
                        write_addr = addr + 1
                    if this_map & 1:
                        if out_size != 4:
                            if chan_size != 1:
                                memfile.write('  out_buf[offs')
                                if this_c > c:
                                    # Later channels of the same word are one full
                                    # channel (H * W elements) apart in `out_buf`
                                    memfile.write(
                                        f'+0x{(this_c - c) * chan_size:02x}')
                                memfile.write('] = ')
                            else:
                                memfile.write('  *out_buf++ = ')
                            if shift == 0:
                                memfile.write('val')
                            else:
                                memfile.write(f'(val >> {shift * 8})')
                            if out_size == 1:
                                memfile.write(' & 0xff;\n')
                            else:
                                memfile.write(';\n')
                        else:  # out_size == 4
                            memfile.write('  *out_buf++ = *addr++;\n')
                            write_addr = addr + 4
                            read_addr += 4

                        this_c += 1
                    this_map >>= 1
        else:  # mlator
            assert out_size == 1
            this_map = next_layer_map
            addr = apb_base + tc.dev.C_CNN_BASE + (
                proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS
            mlat = addr + tc.dev.REG_MLAT * 4
            if mlat_addr != mlat:
                # Register pointers only need to be re-emitted when the group changes
                mlat_addr = mlat
                ctrl = addr + tc.dev.REG_CTL * 4
                memfile.write(
                    f'  ctrl = (volatile uint32_t *) 0x{ctrl:08x};\n')
                memfile.write(
                    f'  mlat = (volatile uint32_t *) 0x{mlat:08x};\n')

            this_c = c
            for shift in range(4):
                if this_map & 1:
                    memfile.write(f'  // Channel {this_c}\n')

                    for doffs in range(0, chan_size, 4):
                        row, col = divmod(doffs, row_len)

                        # Get four bytes from memory
                        source = out_offset + \
                            (((proc % tc.dev.P_NUMPRO) * tc.dev.INSTANCE_SIZE |
                              (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) +
                             (doffs >> 2) * width + expand * out_size) * 4
                        # Flat index of (this_c, row, col); the row stride is the row length
                        target = this_c * chan_size + row * row_len + col
                        assert target & 3 == 0

                        if target != write_addr:
                            memfile.write(f'  offs = 0x{target >> 2:04x};\n')
                        if source != read_addr:
                            if doffs != 0:
                                memfile.write(
                                    f'  *ctrl = 0x{tc.dev.READY_SEL << 1 | 1 << 3:08x}; '
                                    '// Disable mlator\n')
                            # Set wptr to start address
                            val = addr + tc.dev.C_CNN*4 \
                                + tc.dev.LREG_WPTR_BASE*4 * tc.dev.MAX_LAYERS
                            memfile.write(
                                f'  *((volatile uint32_t *) 0x{val:08x}) = '
                                f'0x{doffs:08x}; // Set SRAM address\n')
                            # Set wptr_inc to set increment value (default: 1)
                            val = addr + tc.dev.C_CNN*4 \
                                + tc.dev.LREG_LCTL2*4 * tc.dev.MAX_LAYERS
                            memfile.write(
                                f'  *((volatile uint32_t *) 0x{val:08x}) = '
                                f'0x{expand:08x}; // Set pointer increment\n')
                            # Set mlatorld enable bit to load write ptr; select byte 0..3
                            val = tc.dev.READY_SEL << 1 | 1 << 16 | shift << 17 | 1 << 3
                            memfile.write(f'  *ctrl = 0x{val:08x}; '
                                          f'// Enable mlator, byte {shift}\n')
                            # memfile.write('  val = *mlat; // Prime\n')
                            memfile.write(
                                '  asm volatile ("" : "=m" (*mlat) : "r" (*mlat));'
                                ' // Prime\n')

                        # FIXME: Do not write more than `num_bytes = min(4, input_shape[2] - col)`
                        memfile.write('  out_buf32[offs++] = *mlat;'
                                      f' // {this_c},{row},{col}-{col+3}\n')
                        read_addr = source + 4
                        write_addr = target + 4

                    # Disable mlator
                    memfile.write(
                        f'  *ctrl = 0x{tc.dev.READY_SEL << 1 | 1 << 3:08x}; '
                        '// Disable mlator\n')
                # NOTE(review): the channel number advances for unused processors as well,
                # while `c` below only advances by the number of used ones -- confirm that
                # this is intended for processor maps with gaps.
                this_c += 1

                this_map >>= 1

        # Move on to the next group of four processors.
        # NOTE(review): when the remaining map has no set bits and `c` is not a multiple
        # of `out_expand_thresh`, `c` stops advancing and this loop does not terminate.
        poffs += 4
        c += popcount(next_layer_map & 0x0f)
        next_layer_map >>= 4

    memfile.write('}\n\n')