Пример #1
0
    def __init__(self, stmt, tile_size):
        """Create a tensor from a parsed statement.

        Args:
          stmt: Either a grammar.LocalStmtOrOutputStmt, whose ref, lets and
            expr describe how the tensor is computed, or a grammar.InputStmt,
            which only names the tensor.
          tile_size: Stored as ``self._tile_size``.

        Raises:
          util.InternalError: If ``stmt`` is neither of the types above.
        """
        self.haoda_type = stmt.haoda_type
        self._tile_size = tile_size
        computed = isinstance(stmt, grammar.LocalStmtOrOutputStmt)
        if not (computed or isinstance(stmt, grammar.InputStmt)):
            raise util.InternalError('cannot initialize a Tensor from %s' %
                                     type(stmt))
        if computed:
            # Work on a shallow copy so that re-parenting it does not touch
            # the ref still owned by the statement.
            self.st_ref = copy.copy(stmt.ref)
            self.st_ref.parent = self
            self.lets = stmt.let
            self.expr = stmt.expr
        else:
            # An input tensor has no store ref, no lets and no expression.
            self._name = stmt.name
            self.st_ref = None
            self.lets = []
            self.expr = None
        # pylint: disable=protected-access
        _logger.debug('tensor initialized from stmt `%s`', stmt)
        _logger.debug('                   at tx position %d',
                      stmt._tx_position)

        # Graph links; these are filled in externally after construction.
        self.parents = collections.OrderedDict()
        self.children = collections.OrderedDict()
        self.ld_refs = collections.OrderedDict()
Пример #2
0
def print_kernel_xml(
    name: str,
    inputs: Iterable[Tuple[str, str, int, str]],
    outputs: Iterable[Tuple[str, str, int, str]],
    kernel_xml: TextIO,
    interface: str = 'm_axi',
) -> None:
    """Write the kernel.xml description of a kernel.

    Args:
      name: Name of the kernel.
      inputs: (port_name, bundle_name, width, _) tuples, one per input port.
      outputs: (port_name, bundle_name, width, _) tuples, one per output port.
      kernel_xml: Text stream the XML is written to.
      interface: 'm_axi' (memory-mapped ports plus a ``coalesced_data_num``
        scalar argument) or 'axis' (stream ports).

    Raises:
      util.InternalError: If ``interface`` is neither 'm_axi' nor 'axis'.
    """
    if interface == 'm_axi':
        # Output ports are listed before input ports; the port field carries
        # the bundle name from the tuple.
        args: List[backend.Arg] = [
            backend.Arg(
                cat=backend.Cat.MMAP,
                name=port_name,
                port=bundle_name,
                ctype=f'ap_uint<{width}>*',
                width=width,
            )
            for port_list in (outputs, inputs)
            for port_name, bundle_name, width, _ in port_list
        ]
        args.append(
            backend.Arg(
                cat=backend.Cat.SCALAR,
                name='coalesced_data_num',
                port='',
                ctype='uint64_t',
                width=64,
            ))
    elif interface == 'axis':
        # Stream ports ignore the bundle name; direction comes from the cat.
        directed_ports = (
            (backend.Cat.ISTREAM, inputs),
            (backend.Cat.OSTREAM, outputs),
        )
        args = [
            backend.Arg(
                cat=cat,
                name=port_name,
                port='',
                ctype=f'stream<ap_axiu<{width}, 0, 0, 0>>&',
                width=width,
            )
            for cat, port_list in directed_ports
            for port_name, _, width, _ in port_list
        ]
    else:
        raise util.InternalError(f'unexpected interface `{interface}`')

    backend.print_kernel_xml(name=name, args=args, kernel_xml=kernel_xml)
Пример #3
0
    def verify_mode_depths(self) -> None:
        """Check with an ILP whether the FIFO depths admit an II=1 schedule.

        Every node yielded by ``self.tpo_valid_node_gen()`` gets an integer
        latency. Nodes that have ``self`` as a parent are pinned to 0. For
        every other node and each of its parents, the node's latency must be
        at least ``parent latency + parent.get_latency(node)``. It must also
        be at most that value plus the depth of the FIFO on that edge.

        The overall result is logged (debug on success, warning on
        infeasibility). Each node's solved latency is then compared with
        the smallest upper bound over its parents. Every node is logged when
        debug logging is on; otherwise only violations are logged, as
        warnings.

        Raises:
          util.InternalError: If the solver status is neither optimal nor
            infeasible.
        """
        latency_table = {}
        lp_problem = pulp.LpProblem('verify_fifo_depths', pulp.LpMinimize)
        for node in self.tpo_valid_node_gen():
            if self in node.parents:
                # Nodes fed directly by this node anchor the schedule at 0.
                latency_table[node] = 0
                continue
            latency_table[node] = pulp.LpVariable(
                name=f'latency_{node.name}',
                lowBound=0,
                cat='Integer',
            )
            # Lower bound: data from each parent must have arrived.
            lp_problem.extend(
                parent.get_latency(node) +
                latency_table[parent] <= latency_table[node]
                for parent in node.parents)
            # Upper bound: the lag behind each parent is limited by the depth
            # of the FIFO on that edge.
            lp_problem.extend(
                parent.get_latency(node) + latency_table[parent] +
                parent.fifo(node).depth >= latency_table[node]
                for parent in node.parents)

        lp_status = lp_problem.solve()
        if lp_status == pulp.LpStatusOptimal:
            _logger.debug('II=1 check: PASS')
        elif lp_status == pulp.LpStatusInfeasible:
            # Logger.warn is a deprecated alias; use Logger.warning.
            _logger.warning('II=1 check: FAIL')
        else:
            lp_status_str = pulp.LpStatus[lp_status]
            _logger.error('ILP error: %s\n%s', lp_status_str, lp_problem)
            raise util.InternalError('unexpected ILP status: %s' %
                                     lp_status_str)

        # The logging level does not change inside the loop; query it once.
        debug_enabled = _logger.isEnabledFor(logging.DEBUG)
        report = _logger.debug if debug_enabled else _logger.warning
        for node in self.tpo_valid_node_gen():
            if self in node.parents:
                min_capacity = 0
            else:
                min_capacity = min(
                    parent.get_latency(node) +
                    int(pulp.value(latency_table[parent])) +
                    parent.fifo(node).depth for parent in node.parents)

            # pulp.value() also accepts the plain 0 stored for anchor nodes.
            latency = int(pulp.value(latency_table[node]))
            check_failed = latency > min_capacity
            if debug_enabled or check_failed:
                report(
                    'II=1 check %s: %s: latency %d %s min capacity %d',
                    '✖' if check_failed else '✔',
                    repr(node),
                    latency,
                    '>' if check_failed else '<=',
                    min_capacity,
                )
Пример #4
0
 def name_in_iter(name, iteration):
     """Return the name that `name` refers to in unrolled iteration `iteration`.

     Inputs and locals get an `_iter<N>` suffix in every iteration after the
     first. An output, until the last iteration, is renamed to the input at
     the same position in the name lists, suffixed for iteration + 1.
     Params keep their name.

     Raises:
       util.InternalError: If `name` is in none of the name lists.
     """
     def with_suffix(base, index):
         # Iteration 0 (and anything below it) keeps the bare name.
         return base + '_iter%d' % index if index > 0 else base

     if name in self.input_names:
         return with_suffix(name, iteration)
     if name in self.output_names:
         if iteration >= self.iterate - 1:
             return name
         fed_input = self.input_names[self.output_names.index(name)]
         return fed_input + '_iter%d' % (iteration + 1)
     if name in self.local_names:
         return with_suffix(name, iteration)
     if name in self.param_names:
         return name
     raise util.InternalError('unknown name: %s' % name)
Пример #5
0
 def _get_haoda_type(self):
     """Infer this operand's haoda type from its first non-None attribute.

     Attributes are examined in ``self.ATTRS`` order; the first non-None one
     decides the result:

     - A value that has its own ``haoda_type`` yields ``ir.Type`` of it.
     - A ``num`` literal is typed from its text. A ``u`` makes it unsigned
       (uint64 with ``ll``, else uint32). Otherwise ``ll`` gives int64,
       ``fl`` double, ``f``/``e`` float, a ``.`` double, and anything else
       int32.
     - Any other attribute yields None.

     Raises:
       util.InternalError: If every attribute is None.
     """
     for attr in self.ATTRS:
         val = getattr(self, attr)
         if val is None:
             continue
         if hasattr(val, 'haoda_type'):
             return ir.Type(val.haoda_type)
         if attr != 'num':
             return None
         lowered = val.lower()
         if 'u' in lowered:
             return ir.Type('uint64' if 'll' in lowered else 'uint32')
         if 'll' in lowered:
             type_name = 'int64'
         elif 'fl' in lowered:
             type_name = 'double'
         elif 'f' in lowered or 'e' in lowered:
             type_name = 'float'
         elif '.' in val:
             type_name = 'double'
         else:
             type_name = 'int32'
         return ir.Type(type_name)
     raise util.InternalError('undefined Operand')
Пример #6
0
def _group_dram_refs(
    dram_refs: Tuple[ir.DRAMRef, ...],
    burst_width: int,
    access_map: Dict[str, Dict[Tuple[int, ...], Dict[int, ir.DRAMRef]]],
    num_bank_map: Dict[str, int],
    kind: str,
) -> Tuple[int, List[List[ir.DRAMRef]]]:
  """Group DRAM refs by variable and bank tuple and derive the coalescing.

  Shared by the unpacking (read) and packing (write) paths of
  `_process_accesses`.

  Args:
    dram_refs: DRAM references found in the module.
    burst_width: Width in bits of one burst per bank.
    access_map: Filled in place as var -> bank tuple -> {offset: ref}.
    num_bank_map: Filled in place as var -> number of banks; all bank tuples
      of one variable must agree.
    kind: 'read' or 'write'; only used in debug log messages.

  Returns:
    A pair (coalescing_factor, refs_per_var):
      coalescing_factor: the factor computed for the last (var, bank tuple)
        group, or 0 if `dram_refs` is empty.
      refs_per_var: for each variable in first-seen order, one
        representative ref (the first one inserted) per bank tuple.

  Raises:
    util.InternalError: If one batch of accesses is wider than a burst.
  """
  # temporary dict mapping variable name to
  #   dict mapping bank tuple to
  #     list of ir.DRAMRef
  dram_map: Dict[str, Dict[Tuple[int, ...], List[ir.DRAMRef]]]
  dram_map = collections.defaultdict(lambda: collections.defaultdict(list))
  for dram_ref in dram_refs:
    dram_map[dram_ref.var][dram_ref.dram].append(dram_ref)
  _logger.debug('dram %s map: %s', kind, dram_map)

  coalescing_factor = 0
  refs_per_var: List[List[ir.DRAMRef]] = []
  for var, refs_by_dram in dram_map.items():
    for dram, refs in refs_by_dram.items():
      # number of elements per cycle
      batch_size = len(refs)
      offset_dict = {_.offset: _ for _ in refs}
      access_map[var][dram] = offset_dict
      num_banks = len(dram)
      if var in num_bank_map:
        assert num_bank_map[var] == num_banks, 'inconsistent num banks'
      else:
        num_bank_map[var] = num_banks
      _logger.debug('dram %ss: %s', kind, offset_dict)
      assert tuple(sorted(offset_dict)) == tuple(range(batch_size)), \
             'unexpected DRAM accesses pattern %s' % offset_dict
      batch_width = sum(
          _.haoda_type.width_in_bits for _ in offset_dict.values())
      if burst_width * num_banks >= batch_width:
        assert burst_width * num_banks % batch_width == 0, \
            'cannot process such a burst'
        # a single burst consumed in multiple cycles
        coalescing_factor = burst_width * num_banks // batch_width
      else:
        assert batch_width * num_banks % burst_width == 0, \
            'cannot process such a burst'
        # Multiple bursts consumed in a single cycle would need a reassemble
        # factor of batch_width // (burst_width * num_banks); unsupported.
        raise util.InternalError('cannot process such a burst yet')
    refs_per_var.append(
        [next(iter(_.values())) for _ in access_map[var].values()])
  return coalescing_factor, refs_per_var


def _process_accesses(
    module_trait: ir.ModuleTrait,
    burst_width: int,
    interface: str,
):
  """Build FIFO port declarations and DRAM access maps for a module trait.

  Args:
    module_trait: Module whose loads, lets and exprs are inspected.
    burst_width: Width in bits of one DRAM burst per bank.
    interface: 'm_axi' or 'axis'; selects the stream element types.

  Returns:
    A tuple of
      fifo_loads: input port declarations (one per load, followed by the
        DRAM bank FIFOs of an unpacking module),
      fifo_stores: output port declarations (one per expr, followed by the
        DRAM bank FIFOs of a packing module),
      dram_read_map, dram_write_map: var -> bank tuple -> {offset: DRAMRef},
      num_bank_map: var -> number of banks,
      all_dram_reads: one representative DRAMRef per read bank tuple,
      coalescing_factor: cycles one burst is consumed over (0 without DRAM
        access),
      ii: equal to coalescing_factor when DRAM is accessed, else 1.

  Raises:
    util.InternalError: If `interface` is unsupported, or a DRAM batch is
      wider than one burst.
  """
  # Element type of the DRAM bank streams of packing/unpacking modules.
  # Reject unknown interfaces here; every path below depends on this choice.
  if interface == 'm_axi':
    dram_stream_type = f'Data<ap_uint<{burst_width}>>'
  elif interface == 'axis':
    dram_stream_type = f'ap_axiu<{burst_width}, 0, 0, 0>'
  else:
    raise util.InternalError(f'unexpected interface `{interface}`')

  # input/output channels
  fifo_loads = [
      f'/* input*/ hls::stream<{DATA_TYPE_FMT[interface].format(x)}>&'
      f' {x.ld_name}' for x in module_trait.loads
  ]
  fifo_stores = [
      f'/*output*/ hls::stream<{DATA_TYPE_FMT[interface].format(expr)}>&'
      f' {ir.FIFORef.ST_PREFIX}{idx}'
      for idx, expr in enumerate(module_trait.exprs)
  ]

  # dict mapping variable name to
  #   dict mapping bank tuple to
  #     dict mapping offset to ir.DRAMRef
  dram_read_map: Dict[str, Dict[Tuple[int, ...], Dict[int, ir.DRAMRef]]]
  dram_read_map = collections.defaultdict(dict)
  dram_write_map: Dict[str, Dict[Tuple[int, ...], Dict[int, ir.DRAMRef]]]
  dram_write_map = collections.defaultdict(dict)

  num_bank_map: Dict[str, int] = {}
  all_dram_reads: List[ir.DRAMRef] = []
  coalescing_factor = 0
  ii = 1

  exprs = [_.expr for _ in module_trait.lets]
  exprs.extend(module_trait.exprs)
  dram_read_refs: Tuple[ir.DRAMRef, ...] = visitor.get_dram_refs(exprs)
  dram_write_refs: Tuple[ir.DRAMRef, ...] = visitor.get_dram_refs(
      _.name for _ in module_trait.lets if not isinstance(_.name, str))

  if dram_read_refs:  # this is an unpacking module
    assert not dram_write_refs, 'cannot read and write DRAM in the same module'
    coalescing_factor, reads_per_var = _group_dram_refs(
        dram_read_refs, burst_width, dram_read_map, num_bank_map, 'read')
    ii = coalescing_factor
    for dram_reads in reads_per_var:
      all_dram_reads += dram_reads
      fifo_loads.extend(
          f'/* input*/ hls::stream<{dram_stream_type}>&'
          f' {x.dram_fifo_name(bank)}'
          for x in dram_reads
          for bank in x.dram)
  elif dram_write_refs:  # this is a packing module
    coalescing_factor, writes_per_var = _group_dram_refs(
        dram_write_refs, burst_width, dram_write_map, num_bank_map, 'write')
    ii = coalescing_factor
    for dram_writes in writes_per_var:
      fifo_stores.extend(
          f'/*output*/ hls::stream<{dram_stream_type}>&'
          f' {x.dram_fifo_name(bank)}'
          for x in dram_writes
          for bank in x.dram)

  return (
      fifo_loads,
      fifo_stores,
      dram_read_map,
      dram_write_map,
      num_bank_map,
      all_dram_reads,
      coalescing_factor,
      ii,
  )
Пример #7
0
def _print_module_definition(printer, module_trait, module_trait_id, **kwargs):
    """Print the HLS C++ function that implements one module trait.

    The emitted function runs a ``for (bool enable = true; enable;)`` loop
    with ``#pragma HLS pipeline II=<ii>``. Each iteration checks that all
    input FIFOs (and, for an unpacking module, all DRAM bank FIFOs) are
    non-empty and reads them. It then evaluates the module's lets and exprs
    and writes the results to the output FIFOs (or, for a packing module,
    to the DRAM bank FIFOs). ``enable`` is set from the results of the
    reads.

    Args:
      printer: Code printer; ``println``, ``do_scope``, ``un_scope`` and
        ``print_func`` are used.
      module_trait: The module trait to print.
      module_trait_id: Id used to derive the function and loop label names.
      **kwargs: Must contain ``burst_width`` (emitted as
        ``ap_uint<burst_width>`` bank buffers) when the module reads or
        writes DRAM. Other keys are ignored.

    Raises:
      KeyError: If DRAM is accessed but ``burst_width`` is not given.
      util.InternalError: If a batch of DRAM accesses is wider than a burst.
    """
    println = printer.println
    do_scope = printer.do_scope
    un_scope = printer.un_scope
    func_name = util.get_func_name(module_trait_id)
    func_lower_name = util.get_module_name(module_trait_id)
    ii = 1

    def get_delays(obj, delays):
        # Visitor callback: record every ir.DelayedRef; nodes are unchanged.
        if isinstance(obj, ir.DelayedRef):
            delays.append(obj)
        return obj

    delays = []
    for let in module_trait.lets:
        let.visit(get_delays, delays)
    for expr in module_trait.exprs:
        expr.visit(get_delays, delays)
    _logger.debug('delays: %s', delays)

    fifo_loads = tuple(
        '/* input*/ hls::stream<Data<{} > >* {}'.format(_.c_type, _.ld_name)
        for _ in module_trait.loads)
    fifo_stores = tuple(
        '/*output*/ hls::stream<Data<{} > >* {}{}'.format(
            expr.c_type, ir.FIFORef.ST_PREFIX, idx)
        for idx, expr in enumerate(module_trait.exprs))

    # look for DRAM access
    reads_in_lets = tuple(_.expr for _ in module_trait.lets)
    writes_in_lets = tuple(_.name for _ in module_trait.lets
                           if not isinstance(_.name, str))
    reads_in_exprs = module_trait.exprs
    dram_reads = visitor.get_dram_refs(reads_in_lets + reads_in_exprs)
    dram_writes = visitor.get_dram_refs(writes_in_lets)
    # var -> bank tuple -> list of refs, re-keyed below to offset -> ref
    dram_read_map = collections.OrderedDict()
    dram_write_map = collections.OrderedDict()
    all_dram_reads = ()
    num_bank_map = {}
    if dram_reads:  # this is an unpacking module
        assert not dram_writes, 'cannot read and write DRAM in the same module'
        for dram_read in dram_reads:
            dram_read_map.setdefault(dram_read.var,
                                     collections.OrderedDict()).setdefault(
                                         dram_read.dram, []).append(dram_read)
        _logger.debug('dram read map: %s', dram_read_map)
        burst_width = kwargs.pop('burst_width')
        for var in dram_read_map:
            for dram in dram_read_map[var]:
                # number of elements per cycle
                batch_size = len(dram_read_map[var][dram])
                dram_read_map[var][dram] = collections.OrderedDict(
                    (_.offset, _) for _ in dram_read_map[var][dram])
                dram_reads = dram_read_map[var][dram]
                num_banks = len(next(iter(dram_reads.values())).dram)
                if var in num_bank_map:
                    assert num_bank_map[var] == num_banks, \
                        'inconsistent num banks'
                else:
                    num_bank_map[var] = num_banks
                _logger.debug('dram reads: %s', dram_reads)
                assert tuple(sorted(dram_reads.keys())) == tuple(range(batch_size)), \
                       'unexpected DRAM accesses pattern %s' % dram_reads
                batch_width = sum(
                    util.get_width_in_bits(_.haoda_type)
                    for _ in dram_reads.values())
                del dram_reads
                if burst_width * num_banks >= batch_width:
                    assert burst_width * num_banks % batch_width == 0, \
                        'cannot process such a burst'
                    # a single burst consumed in multiple cycles
                    coalescing_factor = burst_width * num_banks // batch_width
                    ii = coalescing_factor
                else:
                    assert batch_width * num_banks % burst_width == 0, \
                        'cannot process such a burst'
                    # Multiple bursts consumed in a single cycle would need a
                    # reassemble factor; not supported yet.
                    raise util.InternalError('cannot process such a burst yet')
            # one representative ref per bank tuple of this variable
            dram_reads = tuple(
                next(iter(_.values())) for _ in dram_read_map[var].values())
            all_dram_reads += dram_reads
            fifo_loads += tuple(
                '/* input*/ hls::stream<Data<ap_uint<{burst_width} > > >* '
                '{bank_name}'.format(burst_width=burst_width,
                                     bank_name=_.dram_fifo_name(bank))
                for _ in dram_reads for bank in _.dram)
    elif dram_writes:  # this is a packing module
        for dram_write in dram_writes:
            dram_write_map.setdefault(dram_write.var,
                                      collections.OrderedDict()).setdefault(
                                          dram_write.dram,
                                          []).append(dram_write)
        _logger.debug('dram write map: %s', dram_write_map)
        burst_width = kwargs.pop('burst_width')
        for var in dram_write_map:
            for dram in dram_write_map[var]:
                # number of elements per cycle
                batch_size = len(dram_write_map[var][dram])
                dram_write_map[var][dram] = collections.OrderedDict(
                    (_.offset, _) for _ in dram_write_map[var][dram])
                dram_writes = dram_write_map[var][dram]
                num_banks = len(next(iter(dram_writes.values())).dram)
                if var in num_bank_map:
                    assert num_bank_map[var] == num_banks, \
                        'inconsistent num banks'
                else:
                    num_bank_map[var] = num_banks
                _logger.debug('dram writes: %s', dram_writes)
                assert tuple(sorted(dram_writes.keys())) == tuple(range(batch_size)), \
                       'unexpected DRAM accesses pattern %s' % dram_writes
                batch_width = sum(
                    util.get_width_in_bits(_.haoda_type)
                    for _ in dram_writes.values())
                del dram_writes
                if burst_width * num_banks >= batch_width:
                    assert burst_width * num_banks % batch_width == 0, \
                        'cannot process such a burst'
                    # a single burst consumed in multiple cycles
                    coalescing_factor = burst_width * num_banks // batch_width
                    ii = coalescing_factor
                else:
                    assert batch_width * num_banks % burst_width == 0, \
                        'cannot process such a burst'
                    # Multiple bursts consumed in a single cycle would need a
                    # reassemble factor; not supported yet.
                    raise util.InternalError('cannot process such a burst yet')
            # one representative ref per bank tuple of this variable
            dram_writes = tuple(
                next(iter(_.values())) for _ in dram_write_map[var].values())
            fifo_stores += tuple(
                '/*output*/ hls::stream<Data<ap_uint<{burst_width} > > >* '
                '{bank_name}'.format(burst_width=burst_width,
                                     bank_name=_.dram_fifo_name(bank))
                for _ in dram_writes for bank in _.dram)

    # print function
    printer.print_func('void {}'.format(func_name),
                       fifo_stores + fifo_loads,
                       align=0)
    do_scope(func_name)

    for dram_ref, bank in module_trait.dram_writes:
        println(
            '#pragma HLS data_pack variable = {}'.format(
                dram_ref.dram_fifo_name(bank)), 0)
    for arg in module_trait.output_fifos:
        println('#pragma HLS data_pack variable = %s' % arg, 0)
    for arg in module_trait.input_fifos:
        println('#pragma HLS data_pack variable = %s' % arg, 0)
    for dram_ref, bank in module_trait.dram_reads:
        println(
            '#pragma HLS data_pack variable = {}'.format(
                dram_ref.dram_fifo_name(bank)), 0)

    # print inter-iteration declarations
    for delay in delays:
        println(delay.c_buf_decl)
        println(delay.c_ptr_decl)

    # print loop
    println('{}_epoch:'.format(func_lower_name), indent=0)
    println('for (bool enable = true; enable;)')
    do_scope('for {}_epoch'.format(func_lower_name))
    println('#pragma HLS pipeline II=%d' % ii, 0)
    for delay in delays:
        println(
            '#pragma HLS dependence variable=%s inter false' % delay.buf_name,
            0)

    # print emptyness tests
    println(
        'if (%s)' %
        (' && '.join('!{fifo}->empty()'.format(fifo=fifo)
                     for fifo in tuple(_.ld_name
                                       for _ in module_trait.loads) + tuple(
                                           _.dram_fifo_name(bank)
                                           for _ in all_dram_reads
                                           for bank in _.dram))))
    do_scope('if not empty')

    # print intra-iteration declarations
    for fifo_in in module_trait.loads:
        println('{fifo_in.c_type} {fifo_in.ref_name};'.format(fifo_in=fifo_in))
    for var in dram_read_map:
        for dram in (next(iter(_.values()))
                     for _ in dram_read_map[var].values()):
            for bank in dram.dram:
                println('ap_uint<{}> {};'.format(burst_width,
                                                 dram.dram_buf_name(bank)))
    for var in dram_write_map:
        for dram in (next(iter(_.values()))
                     for _ in dram_write_map[var].values()):
            for bank in dram.dram:
                println('ap_uint<{}> {};'.format(burst_width,
                                                 dram.dram_buf_name(bank)))

    # print enable conditions (a packing module reads its inputs once per
    # coalescing step further below instead)
    if not dram_write_map:
        for fifo_in in module_trait.loads:
            println('const bool {fifo_in.ref_name}_enable = '
                    'ReadData(&{fifo_in.ref_name}, {fifo_in.ld_name});'.format(
                        fifo_in=fifo_in))
    for dram in all_dram_reads:
        for bank in dram.dram:
            println('const bool {dram_buf_name}_enable = '
                    'ReadData(&{dram_buf_name}, {dram_fifo_name});'.format(
                        dram_buf_name=dram.dram_buf_name(bank),
                        dram_fifo_name=dram.dram_fifo_name(bank)))
    if not dram_write_map:
        println('const bool enabled = %s;' % (' && '.join(
            tuple('{_.ref_name}_enable'.format(_=_)
                  for _ in module_trait.loads) +
            tuple('{}_enable'.format(_.dram_buf_name(bank))
                  for _ in all_dram_reads for bank in _.dram))))
        println('enable = enabled;')

    # print delays (if any)
    for delay in delays:
        println('const {} {};'.format(delay.c_type, delay.c_buf_load))

    # print lets
    def mutate_dram_ref_for_writes(obj, kwargs):
        # Visitor callback: replace a DRAMRef by the bit slice of its bank
        # buffer that holds this coalescing step's element. Element i lives
        # in bank i % num_banks at bit offset (i // num_banks) * type_width.
        if isinstance(obj, ir.DRAMRef):
            coalescing_idx = kwargs.pop('coalescing_idx')
            unroll_factor = kwargs.pop('unroll_factor')
            type_width = util.get_width_in_bits(obj.haoda_type)
            elem_idx = coalescing_idx * unroll_factor + obj.offset
            num_banks = num_bank_map[obj.var]
            bank = obj.dram[elem_idx % num_banks]
            lsb = (elem_idx // num_banks) * type_width
            msb = lsb + type_width - 1
            return ir.Var(name='{}({msb}, {lsb})'.format(
                obj.dram_buf_name(bank), msb=msb, lsb=lsb),
                          idx=())
        return obj

    # mutate dram ref for writes
    if dram_write_map:
        for coalescing_idx in range(coalescing_factor):
            for fifo_in in module_trait.loads:
                # Only the last read of each step decides whether to go on.
                if coalescing_idx == coalescing_factor - 1:
                    prefix = 'const bool {fifo_in.ref_name}_enable = '.format(
                        fifo_in=fifo_in)
                else:
                    prefix = ''
                println('{prefix}ReadData(&{fifo_in.ref_name},'
                        ' {fifo_in.ld_name});'.format(fifo_in=fifo_in,
                                                      prefix=prefix))
            if coalescing_idx == coalescing_factor - 1:
                # A packing module reads no DRAM, so all_dram_reads is empty
                # here; it is used for consistency with the path above.
                println('const bool enabled = %s;' % (' && '.join(
                    tuple('{_.ref_name}_enable'.format(_=_)
                          for _ in module_trait.loads) +
                    tuple('{}_enable'.format(_.dram_buf_name(bank))
                          for _ in all_dram_reads for bank in _.dram))))
                println('enable = enabled;')
            for let in module_trait.lets:
                let = let.visit(
                    mutate_dram_ref_for_writes, {
                        'coalescing_idx':
                        coalescing_idx,
                        'unroll_factor':
                        len(dram_write_map[let.name.var][let.name.dram])
                    })
                println('{} = Reinterpret<ap_uint<{width} > >({});'.format(
                    let.name,
                    let.expr.c_expr,
                    width=util.get_width_in_bits(let.expr.haoda_type)))
        for var in dram_write_map:
            for dram in (next(iter(_.values()))
                         for _ in dram_write_map[var].values()):
                for bank in dram.dram:
                    println('WriteData({}, {}, enabled);'.format(
                        dram.dram_fifo_name(bank), dram.dram_buf_name(bank)))
    else:
        for let in module_trait.lets:
            println(let.c_expr)

    def mutate_dram_ref_for_reads(obj, kwargs):
        # Visitor callback: replace a DRAMRef by a Reinterpret<> of the bit
        # slice of its bank buffer that holds this coalescing step's element.
        if isinstance(obj, ir.DRAMRef):
            coalescing_idx = kwargs.pop('coalescing_idx')
            unroll_factor = kwargs.pop('unroll_factor')
            type_width = util.get_width_in_bits(obj.haoda_type)
            elem_idx = coalescing_idx * unroll_factor + obj.offset
            num_banks = num_bank_map[obj.var]
            # Use the visited ref's own banks, not the enclosing loop's
            # `expr`, which would be a late-binding closure read.
            bank = obj.dram[elem_idx % num_banks]
            lsb = (elem_idx // num_banks) * type_width
            msb = lsb + type_width - 1
            return ir.Var(
                name='Reinterpret<{c_type}>(static_cast<ap_uint<{width} > >('
                '{dram_buf_name}({msb}, {lsb})))'.format(
                    c_type=obj.c_type,
                    dram_buf_name=obj.dram_buf_name(bank),
                    msb=msb,
                    lsb=lsb,
                    width=msb - lsb + 1),
                idx=())
        return obj

    # mutate dram ref for reads
    if dram_read_map:
        for coalescing_idx in range(coalescing_factor):
            for idx, expr in enumerate(module_trait.exprs):
                # Every step but the last writes unconditionally.
                println('WriteData({}{}, {}, {});'.format(
                    ir.FIFORef.ST_PREFIX, idx,
                    expr.visit(
                        mutate_dram_ref_for_reads, {
                            'coalescing_idx':
                            coalescing_idx,
                            'unroll_factor':
                            len(dram_read_map[expr.var][expr.dram])
                        }).c_expr, 'true'
                    if coalescing_idx < coalescing_factor - 1 else 'enabled'))
    else:
        for idx, expr in enumerate(module_trait.exprs):
            println('WriteData({}{}, {}({}), enabled);'.format(
                ir.FIFORef.ST_PREFIX, idx, expr.c_type, expr.c_expr))

    for delay in delays:
        println(delay.c_buf_store)
        println('{} = {};'.format(delay.ptr, delay.c_next_ptr_expr))

    un_scope()  # if not empty
    un_scope()  # for epoch
    un_scope()  # function body
    _logger.debug('printing: %s', module_trait)
Пример #8
0
    def tensors(self):
        """Constructs high-level DAG and creates the tensors.

    Returns:
      An collections.OrderedDict mapping a tensor's name to the tensor.
    """
        # TODO: check for name conflicts
        tensor_map = collections.OrderedDict()
        for stmt in self.input_stmts:
            tensor = soda.tensor.Tensor(stmt, self.tile_size)
            tensor_map[stmt.name] = tensor

        def name_in_iter(name, iteration):
            if name in self.input_names:
                if iteration > 0:
                    return name + '_iter%d' % iteration
                return name
            if name in self.output_names:
                if iteration < self.iterate - 1:
                    return (self.input_names[self.output_names.index(name)] +
                            '_iter%d' % (iteration + 1))
                return name
            if name in self.local_names:
                if iteration > 0:
                    return name + '_iter%d' % iteration
                return name
            if name in self.param_names:
                return name
            raise util.InternalError('unknown name: %s' % name)

        for iteration in range(self.iterate):
            _logger.debug('iterate %s', iteration)
            _logger.debug('map: %s', self.symbol_table)

            def mutate_name_callback(obj, mutated):
                if isinstance(obj, ir.Ref):
                    obj.haoda_type = self.symbol_table[obj.name]
                    # pylint: disable=cell-var-from-loop
                    obj.name = name_in_iter(obj.name, iteration)
                return obj

            tensors = []
            for stmt in itertools.chain(self.local_stmts, self.output_stmts):
                tensor = soda.tensor.Tensor(stmt.visit(mutate_name_callback),
                                            self.tile_size)
                tensor_map[tensor.name] = tensor
                tensors.append(tensor)

            for tensor in tensors:
                _logger.debug('%s', tensor)

            for tensor in tensors:
                tensor.propagate_type()
                loads = visitor.get_load_dict(tensor)
                for parent_name, ld_refs in loads.items():
                    ld_refs = sorted(ld_refs,
                                     key=lambda ref: soda.util.serialize(
                                         ref.idx, self.tile_size))
                    parent_tensor = tensor_map[parent_name]
                    parent_tensor.children[tensor.name] = tensor
                    tensor.parents[parent_name] = parent_tensor
                    tensor.ld_refs[parent_name] = ld_refs

        # solve ILP for optimal reuse buffer
        lp_problem = pulp.LpProblem("optimal_reuse_buffer", pulp.LpMinimize)
        lp_vars = {self.input_names[0]: 0}  # set 1 and only 1 reference point
        lp_helper_vars = {}  # type: Dict[str, pulp.LpVariable]
        objectives = []
        constraints = []
        for tensor in tensor_map.values():
            lp_var = pulp.LpVariable('produced_offset_' + tensor.name,
                                     cat='Integer')
            lp_helper_var = pulp.LpVariable('consumed_offset_' + tensor.name,
                                            cat='Integer')
            lp_vars.setdefault(tensor.name, lp_var)
            lp_helper_vars[tensor.name] = lp_helper_var
            # tensor need to be kept for this long
            objectives.append(lp_helper_var - lp_vars[tensor.name])
            # tensor cannot be consumed until it is produced
            constraints.append(lp_helper_var >= lp_vars[tensor.name])
        lp_problem += sum(objectives)
        lp_problem.extend(constraints)
        for st_tensor in tensor_map.values():
            for ld_tensor_name, offsets in st_tensor.ld_offsets.items():
                oldest_access = min(offsets)
                newest_access = max(offsets)
                _logger.debug('%s @ %s accesses %s @ [%s, %s]', st_tensor.name,
                              st_tensor.st_offset, ld_tensor_name,
                              oldest_access, newest_access)
                # newest ld_tensor access must have been produced
                # when st_tensor is produced
                lp_problem += lp_vars[ld_tensor_name] <= lp_vars[
                    st_tensor.name] + (st_tensor.st_offset - newest_access)
                # oldest ld_tensor access must have been not consumed
                # when st_tensor is produced
                lp_problem += lp_helper_vars[ld_tensor_name] >= lp_vars[
                    st_tensor.name] + (st_tensor.st_offset - oldest_access)

        lp_status = lp_problem.solve()
        lp_status_str = pulp.LpStatus[lp_status]
        total_distance = int(pulp.value(lp_problem.objective))
        _logger.debug('ILP status: %s %s', lp_status_str, total_distance)
        _logger.info('total reuse distance: %d', total_distance)

        if lp_status != pulp.LpStatusOptimal:
            _logger.error('ILP error: %s\n%s', lp_status_str, lp_problem)
            raise util.InternalError('unexpected ILP status: %s' %
                                     lp_status_str)

        # some inputs may need to be delayed relative to others
        base = min(int(pulp.value(lp_vars[x])) for x in self.input_names)

        # set produce offsets
        for tensor in tensor_map.values():
            produce_offset = int(pulp.value(lp_vars[tensor.name])) - base
            consume_offset = int(pulp.value(
                lp_helper_vars[tensor.name])) - base
            tensor.produce_offset = produce_offset
            tensor.consume_offset = consume_offset
            tensor.max_access = 0  # pixels before current produce
            _logger.debug('%s should be produced @ %d and kept until %d',
                          tensor.name, produce_offset, consume_offset)

        # calculate overall acceses
        for ld_tensor in tensor_map.values():
            for st_tensor in ld_tensor.children.values():
                oldest_access = st_tensor.st_offset - min(
                    st_tensor.ld_offsets[ld_tensor.name]
                ) + st_tensor.produce_offset - ld_tensor.produce_offset
                newest_access = st_tensor.st_offset - max(
                    st_tensor.ld_offsets[ld_tensor.name]
                ) + st_tensor.produce_offset - ld_tensor.produce_offset
                _logger.debug(
                    '  producing %s @ %s accesses [%s, %s] pixels before %s '
                    'produced @ %s', st_tensor.name, st_tensor.produce_offset,
                    newest_access, oldest_access, ld_tensor.name,
                    ld_tensor.produce_offset)
                ld_tensor.max_access = max(ld_tensor.max_access, oldest_access)

        for tensor in tensor_map.values():
            _logger.debug('%s should be kept for %s pixels', tensor.name,
                          tensor.max_access)

        # high-level DAG construction finished
        for tensor in tensor_map.values():
            if tensor.name in self.input_names:
                _logger.debug('<input tensor>: %s', tensor)
            elif tensor.name in self.output_names:
                _logger.debug('<output tensor>: %s', tensor)
            else:
                _logger.debug('<local tensor>: %s', tensor)
        return tensor_map
Пример #9
0
    def update_module_depths(
        self,
        depths: Dict[int, int],
    ) -> None:
        """Update module pipeline depths and FIFO depths.

        First, every FIFO written by a module that has an entry in ``depths``
        gets its write latency set to that pipeline depth.

        Then FIFO depths are chosen by solving an integer linear program:

        + Objective: minimize the sum of all FIFO depths, each weighted by the
          FIFO's width in bits.
        + Constraints: the whole DAG can be fully pipelined without artificial
          stalls.

        Each valid node gets an integer latency variable; nodes whose parents
        include this node are pinned to latency 0. For every other node and
        each of its parents, with ``lat(p, n) = p.get_latency(n)``:

        + ``latency[n] >= latency[p] + lat(p, n)``: a token cannot arrive
          earlier than the minimum latency of the edge.
        + ``latency[n] <= latency[p] + lat(p, n) + depth[p, n]``: the FIFO on
          the edge must be deep enough to absorb the difference between this
          edge and the slowest path into ``n``.

        Finally ``self.verify_mode_depths()`` is called.

        Args:
            depths: Dict mapping module ids to pipeline depths.

        Raises:
            util.InternalError: If the ILP solver does not report an optimal
                solution.
        """
        # update module pipeline depths
        for src_node, dst_node in self.bfs_valid_edge_gen():
            # NOTE(review): element [1] of module_table entries is used as the
            # module id; confirm against where module_table is built.
            module_id = self.module_table[src_node][1]
            depth = depths.get(module_id)
            if depth is not None:
                fifo = src_node.fifo(dst_node)
                if fifo.write_lat != depth:
                    _logger.debug('%s write latency changed %s -> %d', fifo,
                                  fifo.write_lat, depth)
                    fifo.write_lat = depth

        # set up ILP problem, variables (one depth per edge), and objective
        lp_problem = pulp.LpProblem('optimal_fifo_depths', pulp.LpMinimize)
        lp_vars = {}
        for src_node, dst_node in self.bfs_valid_edge_gen():
            lp_vars[(src_node, dst_node)] = pulp.LpVariable(
                name=f'depth_{src_node.fifo(dst_node).c_expr}',
                lowBound=0,
                cat='Integer',
            )
        # wider FIFOs cost more resources per entry, so weight by bit width
        lp_problem += sum(
            x.fifo(y).haoda_type.width_in_bits * v
            for (x, y), v in lp_vars.items())

        # add ILP constraints
        latency_table = {
            x: pulp.LpVariable(name=f'latency_{x.name}',
                               lowBound=0,
                               cat='Integer')
            for x in self.tpo_valid_node_gen()
        }
        for node in self.tpo_valid_node_gen():
            if self in node.parents:
                # fed directly by this node: use it as the latency reference
                latency_table[node] = 0
            else:
                # lower bound: arrival is no earlier than each parent's path
                lp_problem.extend(
                    parent.get_latency(node) +
                    latency_table[parent] <= latency_table[node]
                    for parent in node.parents)
                # upper bound: each FIFO buffers the gap to the slowest input
                lp_problem.extend(
                    parent.get_latency(node) + latency_table[parent] +
                    lp_vars[(parent, node)] >= latency_table[node]
                    for parent in node.parents)

        # solve ILP
        lp_status = lp_problem.solve()
        if lp_status != pulp.LpStatusOptimal:
            lp_status_str = pulp.LpStatus[lp_status]
            _logger.error('ILP error: %s\n%s', lp_status_str, lp_problem)
            raise util.InternalError('unexpected ILP status: %s' %
                                     lp_status_str)

        # update FIFO depths
        for (src_node, dst_node), lp_var in lp_vars.items():
            # Integer variables come back from the solver as floats that may
            # be off by the integrality tolerance (e.g. 2.9999999); round
            # instead of truncating so a FIFO is never made one too shallow.
            depth = round(pulp.value(lp_var))
            fifo = src_node.fifo(dst_node)
            if fifo.depth != depth:
                _logger.debug('%s * depth %d -> %d', fifo, fifo.depth, depth)
                fifo.depth = depth

        self.verify_mode_depths()
Пример #10
0
def create_dataflow_graph(stencil):
    """Build the module-level dataflow graph for ``stencil``.

    The graph is rooted at a SuperSourceNode whose children are one LoadNode
    per (input statement, DRAM bank). One StoreNode per (output statement,
    DRAM bank) feeds the SuperSinkNode. In between:

    + ForwardNodes implement the reuse buffers, one per accessed offset of
      each tensor.
    + ComputeNodes evaluate each non-input tensor: one per unrolled PE, or
      one per stage when ``stencil.replication_factor > 1``.

    After the topology is built, every edge between valid nodes gets an
    ``ir.FIFO`` with a write latency and ``depth=0``. The source node records
    the expression it writes into that FIFO in ``src_node.exprs`` and any
    ``ir.Let`` bindings it needs in ``src_node.lets``. Finally
    ``super_source.update_module_depths({})`` is called, presumably to compute
    the FIFO depths.

    Args:
        stencil: The stencil program to lower.

    Returns:
        The SuperSourceNode rooting the graph. Forward and compute nodes are
        also reachable through its ``fwd_nodes`` and ``cpt_nodes`` dicts.

    Raises:
        util.InternalError: If an edge has a source node of an unexpected type.
    """
    chronological_tensors = stencil.chronological_tensors
    # Forward nodes are indexed by (tensor name, offset) and compute nodes by
    # (tensor name, pe_id) in the dicts attached to the super source.
    super_source = SuperSourceNode(
        fwd_nodes={},
        cpt_nodes={},
        super_sink=SuperSinkNode(),
    )

    # One memory node per DRAM bank listed on each input/output statement.
    load_nodes = {
        stmt.name:
        tuple(LoadNode(var=stmt.name, bank=bank) for bank in stmt.dram)
        for stmt in stencil.input_stmts
    }
    store_nodes = {
        stmt.name:
        tuple(StoreNode(var=stmt.name, bank=bank) for bank in stmt.dram)
        for stmt in stencil.output_stmts
    }

    # Loads hang off the super source; stores all feed the super sink.
    for mem_node in itertools.chain(*load_nodes.values()):
        super_source.add_child(mem_node)
    for mem_node in itertools.chain(*store_nodes.values()):
        mem_node.add_child(super_source.super_sink)

    def color_id(node):
        """Return a short ANSI-colored label for ``node`` for debug logs."""
        if isinstance(node, LoadNode):
            return f'\033[33mload {node.var}[bank{node.bank}]\033[0m'
        if isinstance(node, StoreNode):
            return f'\033[36mstore {node.var}[bank{node.bank}]\033[0m'
        if isinstance(node, ForwardNode):
            return f'\033[32mforward {node.tensor.name} @{node.offset}\033[0m'
        if isinstance(node, ComputeNode):
            return f'\033[31mcompute {node.tensor.name} #{node.pe_id}\033[0m'
        return 'unknown node'

    def color_attr(node):
        """Return ``{attr: value, ...}`` built from ``node.__dict__``.

        ``parents``/``children`` are listed by ``color_id``. The super
        source's parents and the super sink's children are skipped.
        """
        result = []
        for k, v in node.__dict__.items():
            if (node.__class__, k) in ((SuperSourceNode, 'parents'),
                                       (SuperSinkNode, 'children')):
                continue
            if k in ('parents', 'children'):
                result.append('%s: [%s]' % (k, ', '.join(map(color_id, v))))
            else:
                result.append('%s: %s' % (k, repr(v)))
        return '{%s}' % ', '.join(result)

    def color_print(node):
        """Return the label of ``node`` followed by its attribute dump."""
        # NOTE(review): not used below; print_node is bound to color_id.
        return '%s: %s' % (color_id(node), color_attr(node))

    # Formatter used by the debug logs below (color_print is more verbose).
    print_node = color_id

    if stencil.replication_factor > 1:
        # NOTE(review): this branch appears out of sync with the unrolled
        # branch below:
        # + it assumes a single `stencil.input`;
        # + it builds ComputeNode(stage=...) instead of ComputeNode(tensor=...);
        # + it never binds `all_points`, which the edge loop after both
        #   branches reads for forward => compute edges (UnboundLocalError if
        #   reached).
        # Confirm before relying on replication_factor > 1.
        replicated_next_fifo = stencil.get_replicated_next_fifo()
        replicated_all_points = stencil.get_replicated_all_points()
        replicated_reuse_buffers = stencil.get_replicated_reuse_buffers()

        def add_fwd_nodes(src_name):
            """Create and wire the forward nodes that read ``src_name``."""
            dsts = replicated_all_points[src_name]
            reuse_buffer = replicated_reuse_buffers[src_name][1:]
            # (forward node, key of the forward node it feeds); wired after
            # the loop because the target may not have been created yet.
            nodes_to_add = []
            for dst_point_dicts in dsts.values():
                for offset in dst_point_dicts:
                    if (src_name, offset) in super_source.fwd_nodes:
                        continue
                    fwd_node = ForwardNode(
                        tensor=stencil.tensors[src_name],
                        offset=offset,
                        depth=stencil.get_replicated_reuse_buffer_length(
                            src_name, offset))
                    _logger.debug('create %s', print_node(fwd_node))
                    # Chain heads (reuse-buffer entries with start == end) get
                    # their data from a load node or a compute node.
                    init_offsets = [
                        start for start, end in reuse_buffer if start == end
                    ]
                    if offset in init_offsets:
                        if src_name in [stencil.input.name]:
                            # The DRAM bank is picked from the offset,
                            # counting from the last bank.
                            load_node_count = len(load_nodes[src_name])
                            load_nodes[src_name][
                                load_node_count - 1 -
                                offset % load_node_count].add_child(fwd_node)
                        else:
                            (super_source.cpt_nodes[(src_name,
                                                     0)].add_child(fwd_node))
                    super_source.fwd_nodes[(src_name, offset)] = fwd_node
                    if offset in replicated_next_fifo[src_name]:
                        nodes_to_add.append(
                            (fwd_node,
                             (src_name,
                              replicated_next_fifo[src_name][offset])))
            for src_node, key in nodes_to_add:
                src_node.add_child(super_source.fwd_nodes[key])

        add_fwd_nodes(stencil.input.name)

        for stage in stencil.get_stages_chronologically():
            cpt_node = ComputeNode(stage=stage, pe_id=0)
            _logger.debug('create %s', print_node(cpt_node))
            super_source.cpt_nodes[(stage.name, 0)] = cpt_node
            for input_name, input_window in stage.window.items():
                for i in range(len(input_window)):
                    # The forward node whose point index is i feeds the i-th
                    # window element of this stage.
                    offset = next(offset for offset, points in (
                        replicated_all_points[input_name][stage.name].items())
                                  if points == i)
                    fwd_node = super_source.fwd_nodes[(input_name, offset)]
                    _logger.debug('  access %s', print_node(fwd_node))
                    fwd_node.add_child(cpt_node)
            if stage.is_output():
                super_source.cpt_nodes[stage.name,
                                       0].add_child(store_nodes[stage.name][0])
            else:
                add_fwd_nodes(stage.name)

    else:
        next_fifo = stencil.next_fifo
        all_points = stencil.all_points
        reuse_buffers = stencil.reuse_buffers

        def add_fwd_nodes(src_name):
            """Create and wire the forward nodes that read ``src_name``."""
            dsts = all_points[src_name]
            reuse_buffer = reuse_buffers[src_name][1:]
            # (forward node, key of the forward node it feeds); wired after
            # the loop because the target may not have been created yet.
            nodes_to_add = []
            for dst_point_dicts in dsts.values():
                for offset in dst_point_dicts:
                    if (src_name, offset) in super_source.fwd_nodes:
                        continue
                    # Unlike the replicated branch, no depth is passed here.
                    fwd_node = ForwardNode(tensor=stencil.tensors[src_name],
                                           offset=offset)
                    #depth=stencil.get_reuse_buffer_length(src_name, offset))
                    _logger.debug('create %s', print_node(fwd_node))
                    # init_offsets is the start of each reuse chain: entry i
                    # is the `end` paired with start == unroll_factor - 1 - i.
                    init_offsets = [
                        next(end for start, end in reuse_buffer
                             if start == unroll_idx) for unroll_idx in
                        reversed(range(stencil.unroll_factor))
                    ]
                    _logger.debug('reuse buffer: %s', reuse_buffer)
                    _logger.debug('init offsets: %s', init_offsets)
                    if offset in init_offsets:
                        if src_name in stencil.input_names:
                            # fwd from external input; the DRAM bank is picked
                            # from the offset, counting from the last bank
                            load_node_count = len(load_nodes[src_name])
                            load_nodes[src_name][
                                load_node_count - 1 -
                                offset % load_node_count].add_child(fwd_node)
                        else:
                            # fwd from output of last stage
                            # tensor name and offset are used to find the cpt node
                            # (the position of offset in init_offsets is used
                            # as the pe_id key into cpt_nodes)
                            cpt_offset = next(
                                unroll_idx
                                for unroll_idx in range(stencil.unroll_factor)
                                if init_offsets[unroll_idx] == offset)
                            cpt_node = super_source.cpt_nodes[(src_name,
                                                               cpt_offset)]
                            cpt_node.add_child(fwd_node)
                    super_source.fwd_nodes[(src_name, offset)] = fwd_node
                    if offset in next_fifo[src_name]:
                        nodes_to_add.append(
                            (fwd_node, (src_name,
                                        next_fifo[src_name][offset])))
            for src_node, key in nodes_to_add:
                # fwd from another fwd node
                src_node.add_child(super_source.fwd_nodes[key])

        for input_name in stencil.input_names:
            add_fwd_nodes(input_name)

        for tensor in chronological_tensors:
            if tensor.is_input():
                continue
            for unroll_index in range(stencil.unroll_factor):
                pe_id = stencil.unroll_factor - 1 - unroll_index
                cpt_node = ComputeNode(tensor=tensor, pe_id=pe_id)
                _logger.debug('create %s', print_node(cpt_node))
                super_source.cpt_nodes[(tensor.name, pe_id)] = cpt_node
                for input_name, input_window in tensor.ld_indices.items():
                    for i in range(len(input_window)):
                        # The forward node that serves this PE's i-th load of
                        # input_name.
                        offset = next(
                            offset
                            for offset, points in all_points[input_name][
                                tensor.name].items()
                            if pe_id in points and points[pe_id] == i)
                        fwd_node = super_source.fwd_nodes[(input_name, offset)]
                        _logger.debug('  access %s', print_node(fwd_node))
                        fwd_node.add_child(cpt_node)
            if tensor.is_output():
                # PEs are spread over the output's DRAM banks by pe_id modulo
                # the bank count.
                for pe_id in range(stencil.unroll_factor):
                    super_source.cpt_nodes[tensor.name, pe_id].add_child(
                        store_nodes[tensor.name][pe_id % len(
                            store_nodes[tensor.name])])
            else:
                # Runs after this tensor's compute nodes exist because reuse
                # chain heads are looked up in super_source.cpt_nodes.
                add_fwd_nodes(tensor.name)

    # Attach a FIFO and a producing expression to every edge between valid
    # nodes. Sources are visited in tpo_valid_node_gen order (presumably
    # topological). The fwd => cpt bookkeeping below relies on a compute
    # node's incoming edges being handled before its outgoing ones.
    # pylint: disable=too-many-nested-blocks
    for src_node in super_source.tpo_valid_node_gen():
        for dst_node in filter(is_valid_node, src_node.children):
            # 5 possible edge types:
            # 1. load => fwd
            # 2. fwd => fwd
            # 3. fwd => cpt
            # 4. cpt => fwd
            # 5. cpt => store
            # Write latency of the producing module: loads 0, forwards 2,
            # computes the latency of the tensor's store reference.
            if isinstance(src_node, LoadNode):
                write_lat = 0
            elif isinstance(src_node, ForwardNode):
                write_lat = 2
            elif isinstance(src_node, ComputeNode):
                write_lat = src_node.tensor.st_ref.lat
            else:
                raise util.InternalError('unexpected source node: %s' %
                                         repr(src_node))

            fifo = ir.FIFO(src_node, dst_node, depth=0, write_lat=write_lat)
            lets: List[ir.Let] = []
            if isinstance(src_node, LoadNode):
                # load => fwd: read the tensor directly from the DRAM bank.
                expr = ir.DRAMRef(
                    haoda_type=dst_node.tensor.haoda_type,
                    dram=(src_node.bank, ),
                    var=dst_node.tensor.name,
                    offset=(stencil.unroll_factor - 1 - dst_node.offset) //
                    len(stencil.stmt_table[dst_node.tensor.name].dram),
                )
            elif isinstance(src_node, ForwardNode):
                if isinstance(dst_node, ComputeNode):
                    # fwd => cpt: record which FIFO carries each loaded index,
                    # so the compute node's refs can be replaced by FIFOs when
                    # its own outgoing edges are processed.
                    dst = src_node.tensor.children[dst_node.tensor.name]
                    src_name = src_node.tensor.name
                    unroll_idx = dst_node.pe_id
                    point = all_points[src_name][dst.name][
                        src_node.offset][unroll_idx]
                    idx = list(dst.ld_indices[src_name].values())[point].idx
                    _logger.debug(
                        '%s%s referenced by <%s> @ unroll_idx=%d is %s',
                        src_name, util.idx2str(idx), dst.name, unroll_idx,
                        print_node(src_node))
                    dst_node.fifo_map[src_name][idx] = fifo
                # A forward node re-emits what it reads from its parent's FIFO
                # (fifo_r), delayed by its reuse-buffer length.
                delay = stencil.reuse_buffer_lengths[src_node.tensor.name]\
                                                    [src_node.offset]
                # NOTE(review): `offset` is not used below.
                offset = src_node.offset - delay
                for parent in src_node.parents:  # fwd node has only 1 parent
                    for fifo_r in parent.fifos:
                        if fifo_r.edge == (parent, src_node):
                            break
                # NOTE(review): if no FIFO matches, fifo_r is silently left as
                # the last FIFO examined.
                if delay > 0:
                    # TODO: build an index somewhere
                    # Reuse an existing let whose expr is a DelayedRef of
                    # fifo_r; otherwise bind a new one.
                    for let in src_node.lets:
                        # pylint: disable=undefined-loop-variable
                        if isinstance(
                                let.expr,
                                ir.DelayedRef) and let.expr.ref == fifo_r:
                            var_name = let.name
                            var_type = let.haoda_type
                            break
                    else:
                        var_name = 'let_%d' % len(src_node.lets)
                        # pylint: disable=undefined-loop-variable
                        var_type = fifo_r.haoda_type
                        lets.append(
                            ir.Let(haoda_type=var_type,
                                   name=var_name,
                                   expr=ir.DelayedRef(delay=delay,
                                                      ref=fifo_r)))
                    expr = ir.Var(name=var_name, idx=[])
                    expr.haoda_type = var_type
                else:
                    expr = fifo_r  # pylint: disable=undefined-loop-variable
            elif isinstance(src_node, ComputeNode):

                def replace_refs_callback(obj, args):
                    """Replace each ir.Ref with the FIFO in fifo_map."""
                    if isinstance(obj, ir.Ref):
                        _logger.debug(
                            'replace %s with %s',
                            obj,
                            # pylint: disable=cell-var-from-loop,undefined-loop-variable
                            src_node.fifo_map[obj.name][obj.idx])
                        # pylint: disable=cell-var-from-loop,undefined-loop-variable
                        return src_node.fifo_map[obj.name][obj.idx]
                    return obj

                # cpt => fwd/store: the tensor's lets and expression with all
                # refs rewritten to the incoming FIFOs.
                _logger.debug('lets: %s', src_node.tensor.lets)
                lets = [
                    _.visit(replace_refs_callback)
                    for _ in src_node.tensor.lets
                ]
                _logger.debug('replaced lets: %s', lets)
                _logger.debug('expr: %s', src_node.tensor.expr)
                expr = src_node.tensor.expr.visit(replace_refs_callback)
                _logger.debug('replaced expr: %s', expr)
                if isinstance(dst_node, StoreNode):
                    # cpt => store: the store node binds the DRAM location to
                    # this FIFO.
                    dram_ref = ir.DRAMRef(
                        haoda_type=src_node.tensor.haoda_type,
                        dram=(dst_node.bank, ),
                        var=src_node.tensor.name,
                        offset=(src_node.pe_id) //
                        len(stencil.stmt_table[src_node.tensor.name].dram),
                    )
                    dst_node.lets.append(
                        ir.Let(haoda_type=None, name=dram_ref, expr=fifo))
            else:
                raise util.InternalError('unexpected node of type %s' %
                                         type(src_node))

            src_node.exprs[fifo] = expr
            # Append the new lets, skipping ones this node already has.
            src_node.lets.extend(_ for _ in lets if _ not in src_node.lets)
            _logger.debug(
                'fifo [%d]: %s%s => %s', fifo.depth, color_id(src_node),
                '' if fifo.write_lat is None else ' ~%d' % fifo.write_lat,
                color_id(dst_node))

    # With an empty mapping no module write latency is overridden. Every FIFO
    # above was created with depth=0; presumably this call computes the real
    # depths.
    super_source.update_module_depths({})

    return super_source
Пример #11
0
    def visitor(node: ir.Node, args=None) -> ir.Node:
        """Return a flattened equivalent of ``node``, or ``node`` unchanged.

        Rewrites (every rewritten result is passed back through ``flatten``):

        + A BinaryOp with a single operand becomes that operand.
        + An operand that is a BinaryOp of the same type as its parent is
          spliced into the parent when that cannot change the value.
        + An Operand becomes its first non-None attribute (in ``ATTRS`` order)
          if that attribute is an ``ir.Node``.
        + A Unary made only of '+'/'-' with an even number of '-', or only of
          an even number of '!', becomes its operand.
        + Nested calls to the same function in ``ir.REDUCTION_FUNCS`` are
          merged into a single call.

        Args:
          node: The node being visited.
          args: Unused; present to match the visitor callback signature.

        Returns:
          The flattened node, or ``node`` itself if no rewrite applies.

        Raises:
          util.InternalError: If an Operand has none of its ``ATTRS`` set.
        """
        if isinstance(node, ir.BinaryOp):

            # Flatten singleton BinaryOp
            if len(node.operand) == 1:
                return flatten(node.operand[0])

            # Flatten BinaryOp with reduction operators.
            #
            # Maps an operator to the operators that a same-typed operand
            # following it may contain and still be spliced into the parent.
            # Assuming a BinaryOp chain evaluates left to right:
            # + `a + (b - c)` equals `a + b - c`;
            # + `a * (b % c)` differs from `a * b % c`;
            # + for integers, `a * (b / c)` differs from `a * b / c`.
            # So '*' only absorbs pure products.
            spliceable = {
                '||': {'||'},
                '&&': {'&&'},
                '|': {'|'},
                '&': {'&'},
                '+': {'+', '-'},
                '*': {'*'},
            }
            new_operator, new_operand = [], []
            for child_operator, child_operand in zip((None, *node.operator),
                                                     node.operand):
                if child_operator is not None:
                    new_operator.append(child_operator)
                # The first operand can always be spliced if both operations
                # have the same type: the parent evaluates it first anyway.
                splice = type(child_operand) is type(node) and (
                    child_operator is None or
                    (child_operator in spliceable and
                     spliceable[child_operator].issuperset(
                         child_operand.operator)))
                if splice:
                    new_operator.extend(child_operand.operator)
                    new_operand.extend(child_operand.operand)
                else:
                    new_operand.append(child_operand)
            # At least 1 operand is flattened.
            if len(new_operand) > len(node.operand):
                return flatten(
                    type(node)(operator=new_operator, operand=new_operand))

        # Flatten compound Operand
        if isinstance(node, ir.Operand):
            for attr in node.ATTRS:
                val = getattr(node, attr)
                if val is not None:
                    if isinstance(val, ir.Node):
                        return flatten(val)
                    break
            else:
                raise util.InternalError('undefined Operand')

        # Flatten identity unary operators
        if isinstance(node, ir.Unary):
            minus_count = node.operator.count('-')
            if minus_count % 2 == 0:
                plus_count = node.operator.count('+')
                if plus_count + minus_count == len(node.operator):
                    return flatten(node.operand)
            # NOTE(review): in C, `!!x` yields 0 or 1. Dropping an even number
            # of '!' only preserves the value if the operand is already
            # boolean -- confirm that is guaranteed here.
            not_count = node.operator.count('!')
            if not_count % 2 == 0 and not_count == len(node.operator):
                return flatten(node.operand)

        # Flatten reduction functions
        if isinstance(node, ir.Call):
            func_name = getattr(node, 'name')
            if func_name in ir.REDUCTION_FUNCS:
                operands = []
                for operand in getattr(node, 'arg'):
                    if (isinstance(operand, ir.Call)
                            and getattr(operand, 'name') == func_name):
                        operands.extend(getattr(operand, 'arg'))
                    else:
                        operands.append(operand)
                if len(operands) > len(getattr(node, 'arg')):
                    return flatten(ir.Call(name=func_name, arg=operands))

        return node