예제 #1
0
파일: core.py 프로젝트: Ruola/haoda
    def _interfaces(self):
        """Summarize the DRAM banks and FIFOs this object reads and writes.

        Returns:
            A dict with four tuple-valued entries:
            'dram_writes' and 'dram_reads' hold one ``(dram_ref, bank)``
            pair per distinct ``(variable, bank)`` combination;
            'output_fifos' holds the ``c_expr`` of every key of
            ``self.exprs``; 'input_fifos' holds the ``c_expr`` of every
            item returned by ``visitor.get_read_fifo_set(self)``.
        """

        def banked_refs(nodes):
            # Keyed by (variable, bank): a repeated pair keeps its first
            # position but stores the most recent reference.
            unique = collections.OrderedDict()
            for ref in visitor.get_dram_refs(nodes):
                unique.update(
                    ((ref.var, bank), (ref, bank)) for bank in ref.dram)
            return tuple(unique.values())

        # DRAM reads may appear on the right-hand side of a let or in any
        # of the output expressions.
        read_sources = [let.expr for let in self.lets]
        read_sources.extend(self.exprs.values())
        dram_reads = banked_refs(tuple(read_sources))

        # DRAM writes are lets whose target is not a plain string name.
        write_targets = tuple(
            let.name for let in self.lets if not isinstance(let.name, str))
        dram_writes = banked_refs(write_targets)

        output_fifos = tuple(key.c_expr for key in self.exprs)
        input_fifos = tuple(
            fifo.c_expr for fifo in visitor.get_read_fifo_set(self))

        return dict(
            dram_writes=dram_writes,
            output_fifos=output_fifos,
            input_fifos=input_fifos,
            dram_reads=dram_reads,
        )
예제 #2
0
파일: core.py 프로젝트: Ruola/haoda
    def _interfaces(self):
        """Collect the DRAM and FIFO interface names of this object.

        Returns:
            A dict with the keys 'dram_writes', 'output_fifos',
            'input_fifos' and 'dram_reads'.  The two DRAM entries are
            tuples of ``(dram_ref, bank)`` pairs, one per distinct
            ``(variable, bank)``.  'output_fifos' has one name per entry of
            ``self.exprs``, formed from ``FIFORef.ST_PREFIX`` and the
            entry's index; 'input_fifos' has the ``ld_name`` of each load.
        """
        # DRAM reads: right-hand sides of lets plus the output expressions.
        read_nodes = tuple(let.expr for let in self.lets) + tuple(self.exprs)
        read_pairs = collections.OrderedDict(
            ((ref.var, bank), (ref, bank))
            for ref in visitor.get_dram_refs(read_nodes)
            for bank in ref.dram)

        # DRAM writes: let targets that are not plain string names.
        write_nodes = tuple(
            let.name for let in self.lets if not isinstance(let.name, str))
        write_pairs = collections.OrderedDict(
            ((ref.var, bank), (ref, bank))
            for ref in visitor.get_dram_refs(write_nodes)
            for bank in ref.dram)

        store_names = tuple(
            '{}{}'.format(FIFORef.ST_PREFIX, position)
            for position, _unused in enumerate(self.exprs))
        load_names = tuple(load.ld_name for load in self.loads)

        interfaces = {}
        interfaces['dram_writes'] = tuple(write_pairs.values())
        interfaces['output_fifos'] = store_names
        interfaces['input_fifos'] = load_names
        interfaces['dram_reads'] = tuple(read_pairs.values())
        return interfaces
예제 #3
0
def _process_accesses(
    module_trait: ir.ModuleTrait,
    burst_width: int,
    interface: str,
):
  """Build the stream arguments and DRAM access tables of one module.

  A module whose lets/exprs read DRAM is treated as an unpacking module; one
  whose lets write DRAM is treated as a packing module.  A module may not do
  both.

  Args:
    module_trait: Module trait whose `loads`, `lets` and `exprs` are
      inspected.
    burst_width: Width, in bits, of the `ap_uint`/`ap_axiu` payload carried
      by each per-bank DRAM stream.
    interface: 'm_axi' or 'axis'; selects the stream element types.

  Returns:
    A tuple `(fifo_loads, fifo_stores, dram_read_map, dram_write_map,
    num_bank_map, all_dram_reads, coalescing_factor, ii)` where
      * `fifo_loads` / `fifo_stores` are lists of C++ parameter declarations
        for the input / output streams,
      * `dram_read_map` / `dram_write_map` map variable name -> bank tuple ->
        offset -> `ir.DRAMRef`,
      * `num_bank_map` maps variable name -> number of banks,
      * `all_dram_reads` holds one representative `ir.DRAMRef` per
        (variable, bank tuple) that is read,
      * `coalescing_factor` is the number of batches that make up one burst
        across all banks (0 when the module has no DRAM access),
      * `ii` equals `coalescing_factor` when there is DRAM access and 1
        otherwise (presumably the pipeline initiation interval -- confirm
        against the caller).

  Raises:
    ValueError: If `interface` is neither 'm_axi' nor 'axis'.
    AssertionError: If the module both reads and writes DRAM, if a variable
      is accessed with inconsistent bank counts, if the offsets of a bank
      tuple are not exactly 0..n-1, or if the burst is not a multiple of the
      batch.
    util.InternalError: If one batch is wider than one burst across all
      banks (not supported yet).
  """
  # element type of the DRAM-side streams of packing/unpacking modules
  dram_data_types = {
      'm_axi': 'Data<ap_uint<{burst_width}>>',
      'axis': 'ap_axiu<{burst_width}, 0, 0, 0>',
  }
  if interface not in dram_data_types:
    # Fail early and clearly; without this check the FIFO lists would never
    # be bound and the function would die with an UnboundLocalError.
    raise ValueError(f'unsupported interface: {interface!r}')
  dram_stream = 'hls::stream<{}>&'.format(
      dram_data_types[interface].format(burst_width=burst_width))

  # input/output channels
  fifo_loads = [
      f'/* input*/ hls::stream<{DATA_TYPE_FMT[interface].format(x)}>&'
      f' {x.ld_name}' for x in module_trait.loads
  ]
  fifo_stores = [
      f'/*output*/ hls::stream<{DATA_TYPE_FMT[interface].format(expr)}>&'
      f' {ir.FIFORef.ST_PREFIX}{idx}'
      for idx, expr in enumerate(module_trait.exprs)
  ]

  # dict mapping variable name to
  #   dict mapping bank tuple to
  #     dict mapping offset to ir.DRAMRef
  dram_read_map: Dict[str, Dict[Tuple[int, ...], Dict[int, ir.DRAMRef]]]
  dram_read_map = collections.defaultdict(dict)
  dram_write_map: Dict[str, Dict[Tuple[int, ...], Dict[int, ir.DRAMRef]]]
  dram_write_map = collections.defaultdict(dict)

  num_bank_map: Dict[str, int] = {}
  all_dram_reads: List[ir.DRAMRef] = []
  coalescing_factor = 0
  ii = 1

  exprs = [_.expr for _ in module_trait.lets]
  exprs.extend(module_trait.exprs)
  dram_read_refs: Tuple[ir.DRAMRef, ...] = visitor.get_dram_refs(exprs)
  dram_write_refs: Tuple[ir.DRAMRef, ...] = visitor.get_dram_refs(
      _.name for _ in module_trait.lets if not isinstance(_.name, str))

  # Reads and writes are validated identically; only the destination map,
  # the FIFO list that grows and the direction tag differ.
  kind = None
  if dram_read_refs:  # this is an unpacking module
    assert not dram_write_refs, 'cannot read and write DRAM in the same module'
    kind, dram_refs = 'read', dram_read_refs
    access_map, fifo_args, direction = dram_read_map, fifo_loads, '/* input*/'
  elif dram_write_refs:  # this is a packing module
    kind, dram_refs = 'write', dram_write_refs
    access_map, fifo_args, direction = dram_write_map, fifo_stores, '/*output*/'

  if kind is not None:
    # temporary dict mapping variable name to
    #   dict mapping bank tuple to
    #     list of ir.DRAMRef
    dram_map: Dict[str, Dict[Tuple[int, ...], List[ir.DRAMRef]]]
    dram_map = collections.defaultdict(lambda: collections.defaultdict(list))
    for dram_ref in dram_refs:
      dram_map[dram_ref.var][dram_ref.dram].append(dram_ref)
    _logger.debug('dram %s map: %s', kind, dram_map)

    for var, refs_by_bank in dram_map.items():
      for dram, refs in refs_by_bank.items():
        # number of elements per cycle
        batch_size = len(refs)
        offset_dict = {_.offset: _ for _ in refs}
        access_map[var][dram] = offset_dict
        num_banks = len(dram)
        if var in num_bank_map:
          assert num_bank_map[var] == num_banks, 'inconsistent num banks'
        else:
          num_bank_map[var] = num_banks
        _logger.debug('dram %ss: %s', kind, offset_dict)
        assert tuple(sorted(offset_dict.keys())) == tuple(range(batch_size)), \
               'unexpected DRAM accesses pattern %s' % offset_dict
        batch_width = sum(
            _.haoda_type.width_in_bits for _ in offset_dict.values())
        if burst_width * num_banks >= batch_width:
          assert burst_width * num_banks % batch_width == 0, \
              'cannot process such a burst'
          # a single burst consumed in multiple cycles
          coalescing_factor = burst_width * num_banks // batch_width
          ii = coalescing_factor
        else:
          # NOTE(review): this condition does not match the commented-out
          # reassemble_factor formula below; harmless today because the
          # branch always raises, but confirm before implementing it.
          assert batch_width * num_banks % burst_width == 0, \
              'cannot process such a burst'
          # multiple bursts consumed in a single cycle
          # reassemble_factor = batch_width // (burst_width * num_banks)
          raise util.InternalError('cannot process such a burst yet')

      # one representative reference per bank tuple of this variable
      representatives = [
          next(iter(_.values())) for _ in access_map[var].values()
      ]
      if kind == 'read':
        all_dram_reads += representatives
      fifo_args.extend(
          f'{direction} {dram_stream} {x.dram_fifo_name(bank)}'
          for x in representatives
          for bank in x.dram)

  return (
      fifo_loads,
      fifo_stores,
      dram_read_map,
      dram_write_map,
      num_bank_map,
      all_dram_reads,
      coalescing_factor,
      ii,
  )
예제 #4
0
def _print_module_definition(printer, module_trait, module_trait_id, **kwargs):
    """Print the HLS C++ function implementing one module trait.

    The emitted function takes one stream pointer per output expression,
    per load and per DRAM bank, and contains a single pipelined loop that
    reads all inputs, evaluates the lets and writes all outputs.  A module
    that reads DRAM unpacks each burst over several iterations; a module
    that writes DRAM packs several input reads into one burst.

    Args:
        printer: Object providing `println`, `print_func`, `do_scope` and
            `un_scope`; all output goes through it.
        module_trait: Module trait providing `lets`, `exprs`, `loads` and
            the `dram_reads`, `dram_writes`, `input_fifos`, `output_fifos`
            interface attributes.
        module_trait_id: Identifier passed to `util.get_func_name` and
            `util.get_module_name` to derive the function and label names.
        **kwargs: Must contain `burst_width` (in bits) when the module
            accesses DRAM; it is not looked at otherwise.

    Raises:
        KeyError: If the module accesses DRAM and `burst_width` is missing.
        AssertionError: If the module both reads and writes DRAM, or the
            DRAM access pattern is inconsistent.
        util.InternalError: If one batch is wider than one burst across all
            banks (not supported yet).
    """
    println = printer.println
    do_scope = printer.do_scope
    un_scope = printer.un_scope
    func_name = util.get_func_name(module_trait_id)
    func_lower_name = util.get_module_name(module_trait_id)
    ii = 1

    def get_delays(obj, delays):
        # visitor callback: record every delayed reference, change nothing
        if isinstance(obj, ir.DelayedRef):
            delays.append(obj)
        return obj

    delays = []
    for let in module_trait.lets:
        let.visit(get_delays, delays)
    for expr in module_trait.exprs:
        expr.visit(get_delays, delays)
    _logger.debug('delays: %s', delays)

    fifo_loads = tuple(
        '/* input*/ hls::stream<Data<{} > >* {}'.format(_.c_type, _.ld_name)
        for _ in module_trait.loads)
    fifo_stores = tuple('/*output*/ hls::stream<Data<{} > >* {}{}'.format(
        expr.c_type, ir.FIFORef.ST_PREFIX, idx)
                        for idx, expr in enumerate(module_trait.exprs))

    # look for DRAM access
    reads_in_lets = tuple(_.expr for _ in module_trait.lets)
    writes_in_lets = tuple(_.name for _ in module_trait.lets
                           if not isinstance(_.name, str))
    reads_in_exprs = module_trait.exprs
    dram_reads = visitor.get_dram_refs(reads_in_lets + reads_in_exprs)
    dram_writes = visitor.get_dram_refs(writes_in_lets)
    # Both maps go variable name -> bank tuple -> offset -> DRAM reference.
    dram_read_map = collections.OrderedDict()
    dram_write_map = collections.OrderedDict()
    # one representative reference per (variable, bank tuple) that is read
    all_dram_reads = ()
    num_bank_map = {}
    if dram_reads:  # this is an unpacking module
        assert not dram_writes, 'cannot read and write DRAM in the same module'
        for dram_read in dram_reads:
            dram_read_map.setdefault(dram_read.var,
                                     collections.OrderedDict()).setdefault(
                                         dram_read.dram, []).append(dram_read)
        _logger.debug('dram read map: %s', dram_read_map)
        burst_width = kwargs.pop('burst_width')
        for var in dram_read_map:
            for dram in dram_read_map[var]:
                # number of elements per cycle
                batch_size = len(dram_read_map[var][dram])
                # replace the list of references by an offset-indexed dict
                offset_map = collections.OrderedDict(
                    (_.offset, _) for _ in dram_read_map[var][dram])
                dram_read_map[var][dram] = offset_map
                num_banks = len(next(iter(offset_map.values())).dram)
                if var in num_bank_map:
                    assert num_bank_map[
                        var] == num_banks, 'inconsistent num banks'
                else:
                    num_bank_map[var] = num_banks
                _logger.debug('dram reads: %s', offset_map)
                assert tuple(sorted(offset_map.keys())) == tuple(range(batch_size)), \
                       'unexpected DRAM accesses pattern %s' % offset_map
                batch_width = sum(
                    util.get_width_in_bits(_.haoda_type)
                    for _ in offset_map.values())
                if burst_width * num_banks >= batch_width:
                    assert burst_width * num_banks % batch_width == 0, \
                        'cannot process such a burst'
                    # a single burst consumed in multiple cycles
                    coalescing_factor = burst_width * num_banks // batch_width
                    ii = coalescing_factor
                else:
                    assert batch_width * num_banks % burst_width == 0, \
                        'cannot process such a burst'
                    # multiple bursts consumed in a single cycle
                    # reassemble_factor = batch_width // (burst_width * num_banks)
                    raise util.InternalError('cannot process such a burst yet')
            first_reads = tuple(
                next(iter(_.values())) for _ in dram_read_map[var].values())
            all_dram_reads += first_reads
            fifo_loads += tuple(
                '/* input*/ hls::stream<Data<ap_uint<{burst_width} > > >* '
                '{bank_name}'.format(burst_width=burst_width,
                                     bank_name=_.dram_fifo_name(bank))
                for _ in first_reads for bank in _.dram)
    elif dram_writes:  # this is a packing module
        for dram_write in dram_writes:
            dram_write_map.setdefault(dram_write.var,
                                      collections.OrderedDict()).setdefault(
                                          dram_write.dram,
                                          []).append(dram_write)
        _logger.debug('dram write map: %s', dram_write_map)
        burst_width = kwargs.pop('burst_width')
        for var in dram_write_map:
            for dram in dram_write_map[var]:
                # number of elements per cycle
                batch_size = len(dram_write_map[var][dram])
                # replace the list of references by an offset-indexed dict
                offset_map = collections.OrderedDict(
                    (_.offset, _) for _ in dram_write_map[var][dram])
                dram_write_map[var][dram] = offset_map
                num_banks = len(next(iter(offset_map.values())).dram)
                if var in num_bank_map:
                    assert num_bank_map[
                        var] == num_banks, 'inconsistent num banks'
                else:
                    num_bank_map[var] = num_banks
                _logger.debug('dram writes: %s', offset_map)
                assert tuple(sorted(offset_map.keys())) == tuple(range(batch_size)), \
                       'unexpected DRAM accesses pattern %s' % offset_map
                batch_width = sum(
                    util.get_width_in_bits(_.haoda_type)
                    for _ in offset_map.values())
                if burst_width * num_banks >= batch_width:
                    assert burst_width * num_banks % batch_width == 0, \
                        'cannot process such a burst'
                    # a single burst consumed in multiple cycles
                    coalescing_factor = burst_width * num_banks // batch_width
                    ii = coalescing_factor
                else:
                    assert batch_width * num_banks % burst_width == 0, \
                        'cannot process such a burst'
                    # multiple bursts consumed in a single cycle
                    # reassemble_factor = batch_width // (burst_width * num_banks)
                    raise util.InternalError('cannot process such a burst yet')
            first_writes = tuple(
                next(iter(_.values())) for _ in dram_write_map[var].values())
            fifo_stores += tuple(
                '/*output*/ hls::stream<Data<ap_uint<{burst_width} > > >* '
                '{bank_name}'.format(burst_width=burst_width,
                                     bank_name=_.dram_fifo_name(bank))
                for _ in first_writes for bank in _.dram)

    # print function
    printer.print_func('void {}'.format(func_name),
                       fifo_stores + fifo_loads,
                       align=0)
    do_scope(func_name)

    for dram_ref, bank in module_trait.dram_writes:
        println(
            '#pragma HLS data_pack variable = {}'.format(
                dram_ref.dram_fifo_name(bank)), 0)
    for arg in module_trait.output_fifos:
        println('#pragma HLS data_pack variable = %s' % arg, 0)
    for arg in module_trait.input_fifos:
        println('#pragma HLS data_pack variable = %s' % arg, 0)
    for dram_ref, bank in module_trait.dram_reads:
        println(
            '#pragma HLS data_pack variable = {}'.format(
                dram_ref.dram_fifo_name(bank)), 0)

    # print inter-iteration declarations
    for delay in delays:
        println(delay.c_buf_decl)
        println(delay.c_ptr_decl)

    # print loop
    println('{}_epoch:'.format(func_lower_name), indent=0)
    println('for (bool enable = true; enable;)')
    do_scope('for {}_epoch'.format(func_lower_name))
    println('#pragma HLS pipeline II=%d' % ii, 0)
    for delay in delays:
        println(
            '#pragma HLS dependence variable=%s inter false' % delay.buf_name,
            0)

    # print emptyness tests
    input_fifo_names = tuple(_.ld_name for _ in module_trait.loads) + tuple(
        _.dram_fifo_name(bank) for _ in all_dram_reads for bank in _.dram)
    println('if (%s)' % (' && '.join(
        '!{fifo}->empty()'.format(fifo=fifo) for fifo in input_fifo_names)))
    do_scope('if not empty')

    # print intra-iteration declarations
    for fifo_in in module_trait.loads:
        println('{fifo_in.c_type} {fifo_in.ref_name};'.format(
            fifo_in=fifo_in))
    for var in dram_read_map:
        for dram in (next(iter(_.values()))
                     for _ in dram_read_map[var].values()):
            for bank in dram.dram:
                println('ap_uint<{}> {};'.format(burst_width,
                                                 dram.dram_buf_name(bank)))
    for var in dram_write_map:
        for dram in (next(iter(_.values()))
                     for _ in dram_write_map[var].values()):
            for bank in dram.dram:
                println('ap_uint<{}> {};'.format(burst_width,
                                                 dram.dram_buf_name(bank)))

    # print enable conditions
    # (a packing module reads its inputs once per coalesced element instead,
    # see the write path below)
    if not dram_write_map:
        for fifo_in in module_trait.loads:
            println('const bool {fifo_in.ref_name}_enable = '
                    'ReadData(&{fifo_in.ref_name}, {fifo_in.ld_name});'.format(
                        fifo_in=fifo_in))
    for dram in all_dram_reads:
        for bank in dram.dram:
            println('const bool {dram_buf_name}_enable = '
                    'ReadData(&{dram_buf_name}, {dram_fifo_name});'.format(
                        dram_buf_name=dram.dram_buf_name(bank),
                        dram_fifo_name=dram.dram_fifo_name(bank)))
    if not dram_write_map:
        println('const bool enabled = %s;' % (' && '.join(
            tuple('{_.ref_name}_enable'.format(_=_)
                  for _ in module_trait.loads) +
            tuple('{}_enable'.format(_.dram_buf_name(bank))
                  for _ in all_dram_reads for bank in _.dram))))
        println('enable = enabled;')

    # print delays (if any)
    for delay in delays:
        println('const {} {};'.format(delay.c_type, delay.c_buf_load))

    # print lets
    def mutate_dram_ref_for_writes(obj, kwargs):
        # visitor callback: turn a DRAM reference into the bit range of the
        # per-bank burst buffer that holds this element
        if isinstance(obj, ir.DRAMRef):
            coalescing_idx = kwargs.pop('coalescing_idx')
            unroll_factor = kwargs.pop('unroll_factor')
            type_width = util.get_width_in_bits(obj.haoda_type)
            elem_idx = coalescing_idx * unroll_factor + obj.offset
            num_banks = num_bank_map[obj.var]
            # elements are distributed round-robin over the banks
            bank = obj.dram[elem_idx % num_banks]
            lsb = (elem_idx // num_banks) * type_width
            msb = lsb + type_width - 1
            return ir.Var(name='{}({msb}, {lsb})'.format(
                obj.dram_buf_name(bank), msb=msb, lsb=lsb),
                          idx=())
        return obj

    # mutate dram ref for writes
    if dram_write_map:
        for coalescing_idx in range(coalescing_factor):
            is_last = coalescing_idx == coalescing_factor - 1
            for fifo_in in module_trait.loads:
                # only the last read of a burst decides whether to go on
                if is_last:
                    prefix = 'const bool {fifo_in.ref_name}_enable = '.format(
                        fifo_in=fifo_in)
                else:
                    prefix = ''
                println('{prefix}ReadData(&{fifo_in.ref_name},'
                        ' {fifo_in.ld_name});'.format(fifo_in=fifo_in,
                                                      prefix=prefix))
            if is_last:
                # all_dram_reads is empty for a packing module; it is used
                # here only to mirror the non-packing path above.
                println('const bool enabled = %s;' % (' && '.join(
                    tuple('{_.ref_name}_enable'.format(_=_)
                          for _ in module_trait.loads) +
                    tuple('{}_enable'.format(_.dram_buf_name(bank))
                          for _ in all_dram_reads for bank in _.dram))))
                println('enable = enabled;')
            for let in module_trait.lets:
                mutated_let = let.visit(
                    mutate_dram_ref_for_writes, {
                        'coalescing_idx':
                        coalescing_idx,
                        'unroll_factor':
                        len(dram_write_map[let.name.var][let.name.dram])
                    })
                println('{} = Reinterpret<ap_uint<{width} > >({});'.format(
                    mutated_let.name,
                    mutated_let.expr.c_expr,
                    width=util.get_width_in_bits(
                        mutated_let.expr.haoda_type)))
        for var in dram_write_map:
            for dram in (next(iter(_.values()))
                         for _ in dram_write_map[var].values()):
                for bank in dram.dram:
                    println('WriteData({}, {}, enabled);'.format(
                        dram.dram_fifo_name(bank), dram.dram_buf_name(bank)))
    else:
        for let in module_trait.lets:
            println(let.c_expr)

    def mutate_dram_ref_for_reads(obj, kwargs):
        # visitor callback: turn a DRAM reference into a typed view of the
        # bit range of the per-bank burst buffer that holds this element
        if isinstance(obj, ir.DRAMRef):
            coalescing_idx = kwargs.pop('coalescing_idx')
            unroll_factor = kwargs.pop('unroll_factor')
            type_width = util.get_width_in_bits(obj.haoda_type)
            elem_idx = coalescing_idx * unroll_factor + obj.offset
            num_banks = num_bank_map[obj.var]
            # Use the visited node itself, as the write-side callback does;
            # the enclosing loop variable is only correct when the visited
            # node happens to be the expression being printed.
            bank = obj.dram[elem_idx % num_banks]
            lsb = (elem_idx // num_banks) * type_width
            msb = lsb + type_width - 1
            return ir.Var(
                name='Reinterpret<{c_type}>(static_cast<ap_uint<{width} > >('
                '{dram_buf_name}({msb}, {lsb})))'.format(
                    c_type=obj.c_type,
                    dram_buf_name=obj.dram_buf_name(bank),
                    msb=msb,
                    lsb=lsb,
                    width=msb - lsb + 1),
                idx=())
        return obj

    # mutate dram ref for reads
    if dram_read_map:
        for coalescing_idx in range(coalescing_factor):
            # every element but the last of a burst is written
            # unconditionally
            if coalescing_idx < coalescing_factor - 1:
                write_enable = 'true'
            else:
                write_enable = 'enabled'
            for idx, expr in enumerate(module_trait.exprs):
                mutated_expr = expr.visit(
                    mutate_dram_ref_for_reads, {
                        'coalescing_idx':
                        coalescing_idx,
                        'unroll_factor':
                        len(dram_read_map[expr.var][expr.dram])
                    })
                println('WriteData({}{}, {}, {});'.format(
                    ir.FIFORef.ST_PREFIX, idx, mutated_expr.c_expr,
                    write_enable))
    else:
        for idx, expr in enumerate(module_trait.exprs):
            println('WriteData({}{}, {}({}), enabled);'.format(
                ir.FIFORef.ST_PREFIX, idx, expr.c_type, expr.c_expr))

    for delay in delays:
        println(delay.c_buf_store)
        println('{} = {};'.format(delay.ptr, delay.c_next_ptr_expr))

    # close "if not empty", the epoch loop and the function body
    un_scope()
    un_scope()
    un_scope()
    _logger.debug('printing: %s', module_trait)