def __init__(self, stmt, tile_size):
  """Creates a tensor from a parsed statement.

  Args:
    stmt: Either a `grammar.LocalStmtOrOutputStmt` or a `grammar.InputStmt`.
      Its `haoda_type` is copied onto the tensor.
    tile_size: Stored unchanged as `self._tile_size`.

  Raises:
    util.InternalError: If `stmt` is of neither supported statement type.
  """
  self.haoda_type = stmt.haoda_type
  self._tile_size = tile_size

  if isinstance(stmt, grammar.LocalStmtOrOutputStmt):
    # Local/output tensors own a shallow copy of the store reference, whose
    # parent is re-pointed at this tensor.
    store_ref = copy.copy(stmt.ref)
    self.st_ref = store_ref
    store_ref.parent = self
    self.lets = stmt.let
    self.expr = stmt.expr
  elif isinstance(stmt, grammar.InputStmt):
    # Input tensors carry only a name; they have no store reference, no
    # let-bindings, and no expression.
    self._name = stmt.name
    self.st_ref = None
    self.lets = []
    self.expr = None
  else:
    raise util.InternalError('cannot initialize a Tensor from %s' %
                             type(stmt))

  _logger.debug('tensor initialized from stmt `%s`', stmt)
  # pylint: disable=protected-access
  _logger.debug(' at tx position %d', stmt._tx_position)

  # The relations below start empty and are to be set externally.
  self.parents = collections.OrderedDict()
  self.children = collections.OrderedDict()
  self.ld_refs = collections.OrderedDict()
def print_kernel_xml(
    name: str,
    inputs: Iterable[Tuple[str, str, int, str]],
    outputs: Iterable[Tuple[str, str, int, str]],
    kernel_xml: TextIO,
    interface: str = 'm_axi',
) -> None:
  """Writes the kernel.xml description of a kernel.

  Args:
    name: Name of the kernel.
    inputs: Iterable of (port_name, bundle_name, width, _) of input ports.
    outputs: Iterable of (port_name, bundle_name, width, _) of output ports.
    kernel_xml: File object to write to.
    interface: Interface type, supported values are 'm_axi' and 'axis'.

  Raises:
    util.InternalError: If `interface` is not one of the supported values.
  """
  args: List[backend.Arg]
  if interface == 'm_axi':
    # One memory-mapped pointer per port, outputs listed before inputs,
    # followed by a single scalar argument.
    args = [
        backend.Arg(
            cat=backend.Cat.MMAP,
            name=port_name,
            port=bundle_name,
            ctype=f'ap_uint<{width}>*',
            width=width,
        )
        for ports in (outputs, inputs)
        for port_name, bundle_name, width, _ in ports
    ]
    args.append(
        backend.Arg(
            cat=backend.Cat.SCALAR,
            name='coalesced_data_num',
            port='',
            ctype='uint64_t',
            width=64,
        ))
  elif interface == 'axis':
    # One stream per port, inputs listed before outputs; the bundle name is
    # not used for streams.
    ports_by_category = (
        (backend.Cat.ISTREAM, inputs),
        (backend.Cat.OSTREAM, outputs),
    )
    args = [
        backend.Arg(
            cat=cat,
            name=port_name,
            port='',
            ctype=f'stream<ap_axiu<{width}, 0, 0, 0>>&',
            width=width,
        )
        for cat, ports in ports_by_category
        for port_name, _, width, _ in ports
    ]
  else:
    raise util.InternalError(f'unexpected interface `{interface}`')

  backend.print_kernel_xml(name=name, args=args, kernel_xml=kernel_xml)
def verify_mode_depths(self) -> None:
  """Checks the current FIFO depths against per-node latency constraints.

  An ILP is built with one integer latency variable per valid node; nodes
  that have `self` as a parent are pinned to latency 0. For every other node
  and each of its parents, the node latency must be at least
  `parent.get_latency(node) + latency(parent)` and at most that value plus
  the depth of the FIFO between them. The overall result is logged as the
  'II=1 check', followed by a per-node comparison of the solved latency
  against the smallest capacity offered by any parent.

  Raises:
    util.InternalError: If the solver reports a status other than optimal or
      infeasible.
  """
  latency_table = {}
  lp_problem = pulp.LpProblem('verify_fifo_depths', pulp.LpMinimize)
  for node in self.tpo_valid_node_gen():
    if self in node.parents:
      latency_table[node] = 0
    else:
      latency_table[node] = pulp.LpVariable(
          name=f'latency_{node.name}',
          lowBound=0,
          cat='Integer',
      )
      # Lower bound: a token cannot arrive before any parent delivers it.
      lp_problem.extend(
          parent.get_latency(node) +
          latency_table[parent] <= latency_table[node]
          for parent in node.parents)
      # Upper bound: the FIFO from each parent must be able to absorb the
      # difference in arrival times.
      lp_problem.extend(
          parent.get_latency(node) + latency_table[parent] +
          parent.fifo(node).depth >= latency_table[node]
          for parent in node.parents)

  lp_status = lp_problem.solve()
  if lp_status == pulp.LpStatusOptimal:
    _logger.debug('II=1 check: PASS')
  elif lp_status == pulp.LpStatusInfeasible:
    # `Logger.warn` is a deprecated alias; use `Logger.warning`.
    _logger.warning('II=1 check: FAIL')
  else:
    lp_status_str = pulp.LpStatus[lp_status]
    _logger.error('ILP error: %s\n%s', lp_status_str, lp_problem)
    raise util.InternalError('unexpected ILP status: %s' % lp_status_str)

  # The logger level does not change while reporting, so query it once.
  debug_enabled = _logger.isEnabledFor(logging.DEBUG)
  for node in self.tpo_valid_node_gen():
    if self in node.parents:
      min_capacity = 0
    else:
      min_capacity = min(
          parent.get_latency(node) +
          int(pulp.value(latency_table[parent])) + parent.fifo(node).depth
          for parent in node.parents)

    latency = int(pulp.value(latency_table[node]))
    check_failed = latency > min_capacity
    if debug_enabled or check_failed:
      # With debug logging on, every node is reported at debug level;
      # otherwise only failing nodes are reported, as warnings.
      log = _logger.debug if debug_enabled else _logger.warning
      log(
          'II=1 check %s: %s: latency %d %s min capacity %d',
          '✖' if check_failed else '✔',
          repr(node),
          latency,
          '>' if check_failed else '<=',
          min_capacity,
      )
def name_in_iter(name, iteration):
  """Returns the name that `name` goes by in the given iteration.

  Input and local names get an '_iter<N>' suffix in every iteration but the
  first. An output name maps to its positionally matching input name,
  suffixed with the next iteration, in every iteration but the last. Param
  names are never renamed.

  Raises:
    util.InternalError: If `name` is not an input, output, local, or param
      name.
  """

  def with_suffix(base, step):
    return base + '_iter%d' % step

  if name in self.input_names:
    return with_suffix(name, iteration) if iteration > 0 else name

  if name in self.output_names:
    if iteration >= self.iterate - 1:
      # The last iteration writes the real output.
      return name
    # Earlier iterations feed the next iteration's input instead.
    fed_input = self.input_names[self.output_names.index(name)]
    return with_suffix(fed_input, iteration + 1)

  if name in self.local_names:
    return with_suffix(name, iteration) if iteration > 0 else name

  if name in self.param_names:
    return name

  raise util.InternalError('unknown name: %s' % name)
def _get_haoda_type(self):
  """Derives the type of this operand from its first attribute that is set.

  The attributes named in `self.ATTRS` are examined in order and the first
  one that is not None decides the result:
    * if it has a `haoda_type`, that type is wrapped in `ir.Type`;
    * if it is the 'num' attribute, the type is guessed from the text of the
      literal (see the checks below);
    * otherwise None is returned.

  Raises:
    util.InternalError: If every attribute is None.
  """
  for attr in self.ATTRS:
    val = getattr(self, attr)
    if val is None:
      continue
    if hasattr(val, 'haoda_type'):
      return ir.Type(val.haoda_type)
    if attr != 'num':
      return None

    # Numeric literal: classify by the markers found in its text. The order
    # of the checks matters because the markers overlap.
    literal = val.lower()
    if 'u' in literal:
      return ir.Type('uint64' if 'll' in literal else 'uint32')
    if 'll' in literal:
      return ir.Type('int64')
    if 'fl' in literal:
      return ir.Type('double')
    # NOTE(review): a hexadecimal literal containing the digits `e` or `f`
    # would also match here -- confirm whether such literals can reach this.
    if 'f' in literal or 'e' in literal:
      return ir.Type('float')
    if '.' in val:
      return ir.Type('double')
    return ir.Type('int32')

  raise util.InternalError('undefined Operand')
def _collect_dram_accesses(dram_refs, burst_width, kind, num_bank_map):
  """Groups DRAM references of one direction and validates their pattern.

  Args:
    dram_refs: Non-empty iterable of ir.DRAMRef, all reads or all writes.
    burst_width: Burst width in bits; multiplied by the number of banks when
      compared against the width of one batch of accesses.
    kind: 'read' or 'write'; only used in debug messages.
    num_bank_map: Dict mapping variable name to its number of banks. Updated
      in place.

  Returns:
    A tuple of
      * a dict mapping variable name to a dict mapping bank tuple to a dict
        mapping offset to ir.DRAMRef,
      * a list with the first ir.DRAMRef of every (variable, bank tuple)
        group, in grouping order, and
      * the coalescing factor computed for the last group.

  Raises:
    util.InternalError: If one batch is wider than one burst across all
      banks, which is not supported yet.
  """
  # temporary dict mapping variable name to
  #   dict mapping bank tuple to
  #     list of ir.DRAMRef
  grouped = collections.defaultdict(lambda: collections.defaultdict(list))
  for dram_ref in dram_refs:
    grouped[dram_ref.var][dram_ref.dram].append(dram_ref)
  _logger.debug('dram %s map: %s', kind, grouped)

  ref_map = collections.defaultdict(dict)
  representatives = []
  coalescing_factor = 0
  for var, refs_by_dram in grouped.items():
    for dram, refs in refs_by_dram.items():
      # number of elements per cycle
      batch_size = len(refs)
      offset_dict = {ref.offset: ref for ref in refs}
      ref_map[var][dram] = offset_dict

      num_banks = len(dram)
      if var in num_bank_map:
        assert num_bank_map[var] == num_banks, 'inconsistent num banks'
      else:
        num_bank_map[var] = num_banks

      _logger.debug('dram %ss: %s', kind, offset_dict)
      assert tuple(sorted(offset_dict)) == tuple(range(batch_size)), \
          'unexpected DRAM accesses pattern %s' % offset_dict

      batch_width = sum(
          ref.haoda_type.width_in_bits for ref in offset_dict.values())
      if burst_width * num_banks >= batch_width:
        assert burst_width * num_banks % batch_width == 0, \
            'cannot process such a burst'
        # a single burst consumed in multiple cycles
        coalescing_factor = burst_width * num_banks // batch_width
      else:
        assert batch_width * num_banks % burst_width == 0, \
            'cannot process such a burst'
        # multiple bursts consumed in a single cycle
        # reassemble_factor = batch_width // (burst_width * num_banks)
        raise util.InternalError('cannot process such a burst yet')

    representatives.extend(
        next(iter(offsets.values())) for offsets in ref_map[var].values())
  return ref_map, representatives, coalescing_factor


def _process_accesses(
    module_trait: ir.ModuleTrait,
    burst_width: int,
    interface: str,
):
  """Collects the FIFO arguments and DRAM access maps of a module.

  Args:
    module_trait: Module description; its loads, lets, and exprs are read.
    burst_width: Burst width in bits of each DRAM FIFO.
    interface: Interface type, supported values are 'm_axi' and 'axis'.

  Returns:
    A tuple of
      * fifo_loads: list of C++ parameter declarations of input streams,
      * fifo_stores: list of C++ parameter declarations of output streams,
      * dram_read_map and dram_write_map: dicts mapping variable name to a
        dict mapping bank tuple to a dict mapping offset to ir.DRAMRef,
      * num_bank_map: dict mapping variable name to its number of banks,
      * all_dram_reads: list with one ir.DRAMRef per group of DRAM reads,
      * coalescing_factor: 0 unless the module accesses DRAM, and
      * ii: 1 unless the module accesses DRAM, in which case it equals the
        coalescing factor.

  Raises:
    util.InternalError: If `interface` is not supported, or if the DRAM
      access pattern cannot be processed yet.
  """
  # element type of the streams used by packing/unpacking modules
  if interface == 'm_axi':
    dram_stream_type = f'Data<ap_uint<{burst_width}>>'
  elif interface == 'axis':
    dram_stream_type = f'ap_axiu<{burst_width}, 0, 0, 0>'
  else:
    # Fail early with a descriptive error instead of tripping over an
    # unbound local further down.
    raise util.InternalError(f'unexpected interface `{interface}`')

  # input/output channels
  fifo_loads = [
      f'/* input*/ hls::stream<{DATA_TYPE_FMT[interface].format(x)}>&'
      f' {x.ld_name}' for x in module_trait.loads
  ]
  fifo_stores = [
      f'/*output*/ hls::stream<{DATA_TYPE_FMT[interface].format(expr)}>&'
      f' {ir.FIFORef.ST_PREFIX}{idx}'
      for idx, expr in enumerate(module_trait.exprs)
  ]

  # dict mapping variable name to
  #   dict mapping bank tuple to
  #     dict mapping offset to ir.DRAMRef
  dram_read_map: Dict[str, Dict[Tuple[int, ...], Dict[int, ir.DRAMRef]]]
  dram_read_map = collections.defaultdict(dict)
  dram_write_map: Dict[str, Dict[Tuple[int, ...], Dict[int, ir.DRAMRef]]]
  dram_write_map = collections.defaultdict(dict)

  num_bank_map: Dict[str, int] = {}
  all_dram_reads: List[ir.DRAMRef] = []
  coalescing_factor = 0
  ii = 1

  exprs = [let.expr for let in module_trait.lets]
  exprs.extend(module_trait.exprs)
  dram_read_refs: Tuple[ir.DRAMRef, ...] = visitor.get_dram_refs(exprs)
  dram_write_refs: Tuple[ir.DRAMRef, ...] = visitor.get_dram_refs(
      let.name for let in module_trait.lets if not isinstance(let.name, str))

  if dram_read_refs:  # this is an unpacking module
    assert not dram_write_refs, 'cannot read and write DRAM in the same module'
    dram_read_map, all_dram_reads, coalescing_factor = _collect_dram_accesses(
        dram_read_refs, burst_width, 'read', num_bank_map)
    ii = coalescing_factor
    fifo_loads.extend(
        f'/* input*/ hls::stream<{dram_stream_type}>&'
        f' {x.dram_fifo_name(bank)}' for x in all_dram_reads for bank in x.dram)
  elif dram_write_refs:  # this is a packing module
    dram_write_map, dram_writes, coalescing_factor = _collect_dram_accesses(
        dram_write_refs, burst_width, 'write', num_bank_map)
    ii = coalescing_factor
    fifo_stores.extend(
        f'/*output*/ hls::stream<{dram_stream_type}>&'
        f' {x.dram_fifo_name(bank)}' for x in dram_writes for bank in x.dram)

  return (
      fifo_loads,
      fifo_stores,
      dram_read_map,
      dram_write_map,
      num_bank_map,
      all_dram_reads,
      coalescing_factor,
      ii,
  )
def _print_module_definition(printer, module_trait, module_trait_id, **kwargs):
  """Prints the HLS C++ function that implements one module.

  Args:
    printer: Object providing println, do_scope, un_scope, and print_func.
    module_trait: Module description; its lets, exprs, loads, dram_reads,
      dram_writes, input_fifos, and output_fifos are read.
    module_trait_id: Identifier handed to util.get_func_name and
      util.get_module_name to derive the function and label names.
    **kwargs: Must contain 'burst_width' if the module reads or writes DRAM.

  Raises:
    util.InternalError: If one batch of DRAM accesses is wider than one burst
      across all banks, which is not supported yet.
  """
  println = printer.println
  do_scope = printer.do_scope
  un_scope = printer.un_scope
  func_name = util.get_func_name(module_trait_id)
  func_lower_name = util.get_module_name(module_trait_id)
  ii = 1

  def get_delays(obj, delays):
    if isinstance(obj, ir.DelayedRef):
      delays.append(obj)
    return obj

  delays = []
  for let in module_trait.lets:
    let.visit(get_delays, delays)
  for expr in module_trait.exprs:
    expr.visit(get_delays, delays)
  _logger.debug('delays: %s', delays)

  fifo_loads = tuple(
      '/* input*/ hls::stream<Data<{} > >* {}'.format(_.c_type, _.ld_name)
      for _ in module_trait.loads)
  fifo_stores = tuple('/*output*/ hls::stream<Data<{} > >* {}{}'.format(
      expr.c_type, ir.FIFORef.ST_PREFIX, idx)
                      for idx, expr in enumerate(module_trait.exprs))

  # look for DRAM access
  reads_in_lets = tuple(_.expr for _ in module_trait.lets)
  writes_in_lets = tuple(
      _.name for _ in module_trait.lets if not isinstance(_.name, str))
  reads_in_exprs = module_trait.exprs
  dram_reads = visitor.get_dram_refs(reads_in_lets + reads_in_exprs)
  dram_writes = visitor.get_dram_refs(writes_in_lets)
  # dict mapping variable name to
  #   dict mapping bank tuple to
  #     dict mapping offset to ir.DRAMRef (a list until regrouped below)
  dram_read_map = collections.OrderedDict()
  dram_write_map = collections.OrderedDict()
  # first DRAM reference of every (variable, bank tuple) group of reads
  all_dram_reads = ()
  num_bank_map = {}

  if dram_reads:
    # this is an unpacking module
    assert not dram_writes, 'cannot read and write DRAM in the same module'
    for dram_read in dram_reads:
      dram_read_map.setdefault(dram_read.var,
                               collections.OrderedDict()).setdefault(
                                   dram_read.dram, []).append(dram_read)
    _logger.debug('dram read map: %s', dram_read_map)
    burst_width = kwargs.pop('burst_width')
    for var in dram_read_map:
      for dram in dram_read_map[var]:
        # number of elements per cycle
        batch_size = len(dram_read_map[var][dram])
        offset_map = collections.OrderedDict(
            (_.offset, _) for _ in dram_read_map[var][dram])
        dram_read_map[var][dram] = offset_map
        num_banks = len(next(iter(offset_map.values())).dram)
        if var in num_bank_map:
          assert num_bank_map[var] == num_banks, 'inconsistent num banks'
        else:
          num_bank_map[var] = num_banks
        _logger.debug('dram reads: %s', offset_map)
        assert tuple(sorted(offset_map.keys())) == tuple(range(batch_size)), \
            'unexpected DRAM accesses pattern %s' % offset_map
        batch_width = sum(
            util.get_width_in_bits(_.haoda_type) for _ in offset_map.values())
        if burst_width * num_banks >= batch_width:
          assert burst_width * num_banks % batch_width == 0, \
              'cannot process such a burst'
          # a single burst consumed in multiple cycles
          coalescing_factor = burst_width * num_banks // batch_width
          ii = coalescing_factor
        else:
          assert batch_width * num_banks % burst_width == 0, \
              'cannot process such a burst'
          # multiple bursts consumed in a single cycle
          # reassemble_factor = batch_width // (burst_width * num_banks)
          raise util.InternalError('cannot process such a burst yet')

      var_reads = tuple(
          next(iter(_.values())) for _ in dram_read_map[var].values())
      all_dram_reads += var_reads
      fifo_loads += tuple(
          '/* input*/ hls::stream<Data<ap_uint<{burst_width} > > >* '
          '{bank_name}'.format(burst_width=burst_width,
                               bank_name=_.dram_fifo_name(bank))
          for _ in var_reads
          for bank in _.dram)
  elif dram_writes:
    # this is a packing module
    for dram_write in dram_writes:
      dram_write_map.setdefault(dram_write.var,
                                collections.OrderedDict()).setdefault(
                                    dram_write.dram, []).append(dram_write)
    _logger.debug('dram write map: %s', dram_write_map)
    burst_width = kwargs.pop('burst_width')
    for var in dram_write_map:
      for dram in dram_write_map[var]:
        # number of elements per cycle
        batch_size = len(dram_write_map[var][dram])
        offset_map = collections.OrderedDict(
            (_.offset, _) for _ in dram_write_map[var][dram])
        dram_write_map[var][dram] = offset_map
        num_banks = len(next(iter(offset_map.values())).dram)
        if var in num_bank_map:
          assert num_bank_map[var] == num_banks, 'inconsistent num banks'
        else:
          num_bank_map[var] = num_banks
        _logger.debug('dram writes: %s', offset_map)
        assert tuple(sorted(offset_map.keys())) == tuple(range(batch_size)), \
            'unexpected DRAM accesses pattern %s' % offset_map
        batch_width = sum(
            util.get_width_in_bits(_.haoda_type) for _ in offset_map.values())
        if burst_width * num_banks >= batch_width:
          assert burst_width * num_banks % batch_width == 0, \
              'cannot process such a burst'
          # a single burst consumed in multiple cycles
          coalescing_factor = burst_width * num_banks // batch_width
          ii = coalescing_factor
        else:
          assert batch_width * num_banks % burst_width == 0, \
              'cannot process such a burst'
          # multiple bursts consumed in a single cycle
          # reassemble_factor = batch_width // (burst_width * num_banks)
          raise util.InternalError('cannot process such a burst yet')

      var_writes = tuple(
          next(iter(_.values())) for _ in dram_write_map[var].values())
      fifo_stores += tuple(
          '/*output*/ hls::stream<Data<ap_uint<{burst_width} > > >* '
          '{bank_name}'.format(burst_width=burst_width,
                               bank_name=_.dram_fifo_name(bank))
          for _ in var_writes
          for bank in _.dram)

  # print function
  printer.print_func('void {}'.format(func_name),
                     fifo_stores + fifo_loads,
                     align=0)
  do_scope(func_name)

  for dram_ref, bank in module_trait.dram_writes:
    println(
        '#pragma HLS data_pack variable = {}'.format(
            dram_ref.dram_fifo_name(bank)), 0)
  for arg in module_trait.output_fifos:
    println('#pragma HLS data_pack variable = %s' % arg, 0)
  for arg in module_trait.input_fifos:
    println('#pragma HLS data_pack variable = %s' % arg, 0)
  for dram_ref, bank in module_trait.dram_reads:
    println(
        '#pragma HLS data_pack variable = {}'.format(
            dram_ref.dram_fifo_name(bank)), 0)

  # print inter-iteration declarations
  for delay in delays:
    println(delay.c_buf_decl)
    println(delay.c_ptr_decl)

  # print loop
  println('{}_epoch:'.format(func_lower_name), indent=0)
  println('for (bool enable = true; enable;)')
  do_scope('for {}_epoch'.format(func_lower_name))
  println('#pragma HLS pipeline II=%d' % ii, 0)
  for delay in delays:
    println(
        '#pragma HLS dependence variable=%s inter false' % delay.buf_name, 0)

  # print emptyness tests
  println('if (%s)' % (' && '.join(
      '!{fifo}->empty()'.format(fifo=fifo)
      for fifo in tuple(_.ld_name for _ in module_trait.loads) + tuple(
          _.dram_fifo_name(bank) for _ in all_dram_reads for bank in _.dram))))
  do_scope('if not empty')

  # print intra-iteration declarations
  for fifo_in in module_trait.loads:
    println('{fifo_in.c_type} {fifo_in.ref_name};'.format(fifo_in=fifo_in))
  for var in dram_read_map:
    for dram in (next(iter(_.values())) for _ in dram_read_map[var].values()):
      for bank in dram.dram:
        println('ap_uint<{}> {};'.format(burst_width,
                                         dram.dram_buf_name(bank)))
  for var in dram_write_map:
    for dram in (next(iter(_.values())) for _ in dram_write_map[var].values()):
      for bank in dram.dram:
        println('ap_uint<{}> {};'.format(burst_width,
                                         dram.dram_buf_name(bank)))

  # print enable conditions
  if not dram_write_map:
    for fifo_in in module_trait.loads:
      println('const bool {fifo_in.ref_name}_enable = '
              'ReadData(&{fifo_in.ref_name}, {fifo_in.ld_name});'.format(
                  fifo_in=fifo_in))
  for dram in all_dram_reads:
    for bank in dram.dram:
      println('const bool {dram_buf_name}_enable = '
              'ReadData(&{dram_buf_name}, {dram_fifo_name});'.format(
                  dram_buf_name=dram.dram_buf_name(bank),
                  dram_fifo_name=dram.dram_fifo_name(bank)))
  if not dram_write_map:
    println('const bool enabled = %s;' % (' && '.join(
        tuple('{_.ref_name}_enable'.format(_=_) for _ in module_trait.loads) +
        tuple('{}_enable'.format(_.dram_buf_name(bank))
              for _ in all_dram_reads
              for bank in _.dram))))
    println('enable = enabled;')

  # print delays (if any)
  for delay in delays:
    println('const {} {};'.format(delay.c_type, delay.c_buf_load))

  # print lets
  def mutate_dram_ref_for_writes(obj, kwargs):
    if isinstance(obj, ir.DRAMRef):
      # Index instead of pop so that the same arguments stay available if
      # more than one DRAM reference is visited.
      coalescing_idx = kwargs['coalescing_idx']
      unroll_factor = kwargs['unroll_factor']
      type_width = util.get_width_in_bits(obj.haoda_type)
      elem_idx = coalescing_idx * unroll_factor + obj.offset
      num_banks = num_bank_map[obj.var]
      bank = obj.dram[elem_idx % num_banks]
      lsb = (elem_idx // num_banks) * type_width
      msb = lsb + type_width - 1
      return ir.Var(name='{}({msb}, {lsb})'.format(obj.dram_buf_name(bank),
                                                   msb=msb,
                                                   lsb=lsb),
                    idx=())
    return obj

  # mutate dram ref for writes
  if dram_write_map:
    for coalescing_idx in range(coalescing_factor):
      is_last = coalescing_idx == coalescing_factor - 1
      for fifo_in in module_trait.loads:
        # Only the last read of each FIFO decides whether to continue.
        if is_last:
          prefix = 'const bool {fifo_in.ref_name}_enable = '.format(
              fifo_in=fifo_in)
        else:
          prefix = ''
        println('{prefix}ReadData(&{fifo_in.ref_name},'
                ' {fifo_in.ld_name});'.format(fifo_in=fifo_in, prefix=prefix))
      if is_last:
        # A packing module has no DRAM reads (`all_dram_reads` is empty
        # here), so only the FIFO loads contribute to the condition.
        println('const bool enabled = %s;' % (' && '.join(
            tuple('{_.ref_name}_enable'.format(_=_)
                  for _ in module_trait.loads) +
            tuple('{}_enable'.format(_.dram_buf_name(bank))
                  for _ in all_dram_reads
                  for bank in _.dram))))
        println('enable = enabled;')
      for let in module_trait.lets:
        mutated_let = let.visit(
            mutate_dram_ref_for_writes, {
                'coalescing_idx':
                    coalescing_idx,
                'unroll_factor':
                    len(dram_write_map[let.name.var][let.name.dram])
            })
        println('{} = Reinterpret<ap_uint<{width} > >({});'.format(
            mutated_let.name,
            mutated_let.expr.c_expr,
            width=util.get_width_in_bits(mutated_let.expr.haoda_type)))
    for var in dram_write_map:
      for dram in (
          next(iter(_.values())) for _ in dram_write_map[var].values()):
        for bank in dram.dram:
          println('WriteData({}, {}, enabled);'.format(
              dram.dram_fifo_name(bank), dram.dram_buf_name(bank)))
  else:
    for let in module_trait.lets:
      println(let.c_expr)

  def mutate_dram_ref_for_reads(obj, kwargs):
    if isinstance(obj, ir.DRAMRef):
      coalescing_idx = kwargs['coalescing_idx']
      unroll_factor = kwargs['unroll_factor']
      type_width = util.get_width_in_bits(obj.haoda_type)
      elem_idx = coalescing_idx * unroll_factor + obj.offset
      num_banks = num_bank_map[obj.var]
      # Select the bank from the visited reference itself rather than from
      # the enclosing loop variable.
      bank = obj.dram[elem_idx % num_banks]
      lsb = (elem_idx // num_banks) * type_width
      msb = lsb + type_width - 1
      return ir.Var(
          name='Reinterpret<{c_type}>(static_cast<ap_uint<{width} > >('
          '{dram_buf_name}({msb}, {lsb})))'.format(
              c_type=obj.c_type,
              dram_buf_name=obj.dram_buf_name(bank),
              msb=msb,
              lsb=lsb,
              width=msb - lsb + 1),
          idx=())
    return obj

  # mutate dram ref for reads
  if dram_read_map:
    for coalescing_idx in range(coalescing_factor):
      for idx, expr in enumerate(module_trait.exprs):
        println('WriteData({}{}, {}, {});'.format(
            ir.FIFORef.ST_PREFIX, idx,
            expr.visit(
                mutate_dram_ref_for_reads, {
                    'coalescing_idx': coalescing_idx,
                    'unroll_factor': len(dram_read_map[expr.var][expr.dram])
                }).c_expr,
            'true' if coalescing_idx < coalescing_factor - 1 else 'enabled'))
  else:
    for idx, expr in enumerate(module_trait.exprs):
      println('WriteData({}{}, {}({}), enabled);'.format(
          ir.FIFORef.ST_PREFIX, idx, expr.c_type, expr.c_expr))

  for delay in delays:
    println(delay.c_buf_store)
    println('{} = {};'.format(delay.ptr, delay.c_next_ptr_expr))

  # close the `if not empty`, the epoch loop, and the function scopes
  un_scope()
  un_scope()
  un_scope()
  _logger.debug('printing: %s', module_trait)
def tensors(self):
  """Constructs high-level DAG and creates the tensors.

  One tensor is created per input statement and, for every iteration, one
  per local and output statement. Parent/child relations are derived from
  the loads of each tensor. An ILP is then solved to assign every tensor a
  produce offset and a consume offset such that the sum of
  (consume offset - produce offset) over all tensors is minimal.

  Returns:
    A collections.OrderedDict mapping a tensor's name to the tensor.

  Raises:
    util.InternalError: If a referenced name is unknown, or if the ILP does
      not solve to optimality.
  """
  # TODO: check for name conflicts
  tensor_map = collections.OrderedDict()
  for stmt in self.input_stmts:
    tensor = soda.tensor.Tensor(stmt, self.tile_size)
    tensor_map[stmt.name] = tensor

  def name_in_iter(name, iteration):
    if name in self.input_names:
      if iteration > 0:
        return name + '_iter%d' % iteration
      return name
    if name in self.output_names:
      if iteration < self.iterate - 1:
        return (self.input_names[self.output_names.index(name)] +
                '_iter%d' % (iteration + 1))
      return name
    if name in self.local_names:
      if iteration > 0:
        return name + '_iter%d' % iteration
      return name
    if name in self.param_names:
      return name
    raise util.InternalError('unknown name: %s' % name)

  for iteration in range(self.iterate):
    _logger.debug('iterate %s', iteration)
    _logger.debug('map: %s', self.symbol_table)

    def mutate_name_callback(obj, mutated):
      if isinstance(obj, ir.Ref):
        obj.haoda_type = self.symbol_table[obj.name]
        # pylint: disable=cell-var-from-loop
        obj.name = name_in_iter(obj.name, iteration)
      return obj

    tensors = []
    for stmt in itertools.chain(self.local_stmts, self.output_stmts):
      tensor = soda.tensor.Tensor(stmt.visit(mutate_name_callback),
                                  self.tile_size)
      tensor_map[tensor.name] = tensor
      tensors.append(tensor)

    for tensor in tensors:
      _logger.debug('%s', tensor)

    for tensor in tensors:
      tensor.propagate_type()
      loads = visitor.get_load_dict(tensor)
      for parent_name, ld_refs in loads.items():
        ld_refs = sorted(
            ld_refs,
            key=lambda ref: soda.util.serialize(ref.idx, self.tile_size))
        parent_tensor = tensor_map[parent_name]
        parent_tensor.children[tensor.name] = tensor
        tensor.parents[parent_name] = parent_tensor
        tensor.ld_refs[parent_name] = ld_refs

  # solve ILP for optimal reuse buffer
  lp_problem = pulp.LpProblem("optimal_reuse_buffer", pulp.LpMinimize)
  lp_vars = {self.input_names[0]: 0}  # set 1 and only 1 reference point
  lp_helper_vars = {}  # type: Dict[str, pulp.LpVariable]
  objectives = []
  constraints = []
  for tensor in tensor_map.values():
    lp_var = pulp.LpVariable('produced_offset_' + tensor.name, cat='Integer')
    lp_helper_var = pulp.LpVariable('consumed_offset_' + tensor.name,
                                    cat='Integer')
    lp_vars.setdefault(tensor.name, lp_var)
    lp_helper_vars[tensor.name] = lp_helper_var
    # tensor need to be kept for this long
    objectives.append(lp_helper_var - lp_vars[tensor.name])
    # tensor cannot be consumed until it is produced
    constraints.append(lp_helper_var >= lp_vars[tensor.name])
  lp_problem += sum(objectives)
  lp_problem.extend(constraints)
  for st_tensor in tensor_map.values():
    for ld_tensor_name, offsets in st_tensor.ld_offsets.items():
      oldest_access = min(offsets)
      newest_access = max(offsets)
      _logger.debug('%s @ %s accesses %s @ [%s, %s]', st_tensor.name,
                    st_tensor.st_offset, ld_tensor_name, oldest_access,
                    newest_access)
      # newest ld_tensor access must have been produced
      # when st_tensor is produced
      lp_problem += lp_vars[ld_tensor_name] <= lp_vars[st_tensor.name] + (
          st_tensor.st_offset - newest_access)
      # oldest ld_tensor access must have been not consumed
      # when st_tensor is produced
      lp_problem += lp_helper_vars[ld_tensor_name] >= lp_vars[
          st_tensor.name] + (st_tensor.st_offset - oldest_access)

  lp_status = lp_problem.solve()
  lp_status_str = pulp.LpStatus[lp_status]
  # Check the status before reading any solution value: without an optimal
  # solution the objective may have no value to convert.
  if lp_status != pulp.LpStatusOptimal:
    _logger.error('ILP error: %s\n%s', lp_status_str, lp_problem)
    raise util.InternalError('unexpected ILP status: %s' % lp_status_str)
  total_distance = int(pulp.value(lp_problem.objective))
  _logger.debug('ILP status: %s %s', lp_status_str, total_distance)
  _logger.info('total reuse distance: %d', total_distance)

  # some inputs may need to be delayed relative to others
  base = min(int(pulp.value(lp_vars[x])) for x in self.input_names)

  # set produce offsets
  for tensor in tensor_map.values():
    produce_offset = int(pulp.value(lp_vars[tensor.name])) - base
    consume_offset = int(pulp.value(lp_helper_vars[tensor.name])) - base
    tensor.produce_offset = produce_offset
    tensor.consume_offset = consume_offset
    tensor.max_access = 0  # pixels before current produce
    _logger.debug('%s should be produced @ %d and kept until %d', tensor.name,
                  produce_offset, consume_offset)

  # calculate overall accesses
  for ld_tensor in tensor_map.values():
    for st_tensor in ld_tensor.children.values():
      oldest_access = st_tensor.st_offset - min(
          st_tensor.ld_offsets[ld_tensor.name]
      ) + st_tensor.produce_offset - ld_tensor.produce_offset
      newest_access = st_tensor.st_offset - max(
          st_tensor.ld_offsets[ld_tensor.name]
      ) + st_tensor.produce_offset - ld_tensor.produce_offset
      _logger.debug(
          ' producing %s @ %s accesses [%s, %s] pixels before %s '
          'produced @ %s', st_tensor.name, st_tensor.produce_offset,
          newest_access, oldest_access, ld_tensor.name,
          ld_tensor.produce_offset)
      ld_tensor.max_access = max(ld_tensor.max_access, oldest_access)

  for tensor in tensor_map.values():
    _logger.debug('%s should be kept for %s pixels', tensor.name,
                  tensor.max_access)

  # high-level DAG construction finished
  for tensor in tensor_map.values():
    if tensor.name in self.input_names:
      _logger.debug('<input tensor>: %s', tensor)
    elif tensor.name in self.output_names:
      _logger.debug('<output tensor>: %s', tensor)
    else:
      _logger.debug('<local tensor>: %s', tensor)

  return tensor_map
def update_module_depths(
    self,
    depths: Dict[int, int],
) -> None:
  """Applies module pipeline depths and re-derives all FIFO depths.

  FIFO depths come from an ILP:

  + Objective: minimize the sum of all FIFO depths, each weighted by the
    width of the FIFO.
  + Constraints: the whole DAG can be fully pipelined without artificial
    stalls. The latency of a token at a node is at least the latency along
    every incoming path, and it must not exceed what any incoming path can
    buffer. Along an edge, the minimum is the latency reported by the parent
    for that edge; the maximum additionally includes the depth of the FIFO
    on that edge.

  After the FIFO depths are updated they are checked again via
  `verify_mode_depths`.

  Args:
    depths (Dict[int, int]): Dict mapping module ids to pipeline depths.

  Raises:
    util.InternalError: If the ILP does not solve to optimality.
  """
  # Step 1: push the new pipeline depths into the FIFO write latencies.
  for producer, consumer in self.bfs_valid_edge_gen():
    module_id = self.module_table[producer][1]
    new_latency = depths.get(module_id)
    if new_latency is None:
      continue
    edge_fifo = producer.fifo(consumer)
    if edge_fifo.write_lat != new_latency:
      _logger.debug('%s write latency changed %s -> %d', edge_fifo,
                    edge_fifo.write_lat, new_latency)
      edge_fifo.write_lat = new_latency

  # Step 2: one integer depth variable per edge; objective is the
  # width-weighted sum of depths.
  lp_problem = pulp.LpProblem('optimal_fifo_depths', pulp.LpMinimize)
  lp_vars = {
      (producer, consumer): pulp.LpVariable(
          name=f'depth_{producer.fifo(consumer).c_expr}',
          lowBound=0,
          cat='Integer',
      ) for producer, consumer in self.bfs_valid_edge_gen()
  }
  lp_problem += sum(
      producer.fifo(consumer).haoda_type.width_in_bits * depth_var
      for (producer, consumer), depth_var in lp_vars.items())

  # Step 3: one integer latency variable per node, then the constraints.
  latency_table = {
      node: pulp.LpVariable(name=f'latency_{node.name}',
                            lowBound=0,
                            cat='Integer')
      for node in self.tpo_valid_node_gen()
  }
  for node in self.tpo_valid_node_gen():
    if self in node.parents:
      # Children of this node are the reference point with latency 0.
      latency_table[node] = 0
      continue
    lp_problem.extend(
        parent.get_latency(node) +
        latency_table[parent] <= latency_table[node]
        for parent in node.parents)
    lp_problem.extend(
        parent.get_latency(node) + latency_table[parent] +
        lp_vars[(parent, node)] >= latency_table[node]
        for parent in node.parents)

  # Step 4: solve.
  lp_status = lp_problem.solve()
  if lp_status != pulp.LpStatusOptimal:
    lp_status_str = pulp.LpStatus[lp_status]
    _logger.error('ILP error: %s\n%s', lp_status_str, lp_problem)
    raise util.InternalError('unexpected ILP status: %s' % lp_status_str)

  # Step 5: write the solved depths back to the FIFOs and verify them.
  for (producer, consumer), depth_var in lp_vars.items():
    solved_depth = int(pulp.value(depth_var))
    edge_fifo = producer.fifo(consumer)
    if edge_fifo.depth != solved_depth:
      _logger.debug('%s * depth %d -> %d', edge_fifo, edge_fifo.depth,
                    solved_depth)
      edge_fifo.depth = solved_depth

  self.verify_mode_depths()
def create_dataflow_graph(stencil):
    """Build the dataflow graph of a stencil and return its super source.

    The graph is made of load, forward, compute and store nodes hanging
    between a `SuperSourceNode` and its `SuperSinkNode`. After the nodes are
    connected, one `ir.FIFO` is created per valid edge, the expression (and
    any `let`s) each source node writes to that FIFO is recorded on the node,
    and finally `update_module_depths` is invoked on the super source.

    Args:
      stencil: The stencil description. This function reads, among others,
        its `chronological_tensors`, `input_stmts`, `output_stmts`,
        `replication_factor`, `unroll_factor`, `tensors`, `stmt_table` and
        `reuse_buffer_lengths` attributes.

    Returns:
      The `SuperSourceNode` that roots the constructed graph.

    Raises:
      util.InternalError: If an edge starts from a node that is neither a
        load, a forward nor a compute node.
    """
    chronological_tensors = stencil.chronological_tensors
    super_source = SuperSourceNode(
        fwd_nodes={},
        cpt_nodes={},
        super_sink=SuperSinkNode(),
    )

    # one load/store node per DRAM bank of each input/output statement,
    # keyed by statement name
    load_nodes = {
        stmt.name:
        tuple(LoadNode(var=stmt.name, bank=bank) for bank in stmt.dram)
        for stmt in stencil.input_stmts
    }
    store_nodes = {
        stmt.name:
        tuple(StoreNode(var=stmt.name, bank=bank) for bank in stmt.dram)
        for stmt in stencil.output_stmts
    }

    # loads are children of the super source; stores feed the super sink
    for mem_node in itertools.chain(*load_nodes.values()):
        super_source.add_child(mem_node)
    for mem_node in itertools.chain(*store_nodes.values()):
        mem_node.add_child(super_source.super_sink)

    def color_id(node):
        """Return an ANSI-colored label identifying `node` in log messages."""
        if isinstance(node, LoadNode):
            return f'\033[33mload {node.var}[bank{node.bank}]\033[0m'
        if isinstance(node, StoreNode):
            return f'\033[36mstore {node.var}[bank{node.bank}]\033[0m'
        if isinstance(node, ForwardNode):
            return f'\033[32mforward {node.tensor.name} @{node.offset}\033[0m'
        if isinstance(node, ComputeNode):
            return f'\033[31mcompute {node.tensor.name} #{node.pe_id}\033[0m'
        return 'unknown node'

    def color_attr(node):
        """Return the attributes of `node` formatted as `{key: value, ...}`.

        `parents`/`children` are rendered with `color_id`; the parents of a
        super source and the children of a super sink are left out.
        """
        result = []
        for k, v in node.__dict__.items():
            if (node.__class__, k) in ((SuperSourceNode, 'parents'),
                                       (SuperSinkNode, 'children')):
                continue
            if k in ('parents', 'children'):
                result.append('%s: [%s]' % (k, ', '.join(map(color_id, v))))
            else:
                result.append('%s: %s' % (k, repr(v)))
        return '{%s}' % ', '.join(result)

    # NOTE(review): color_print (and therefore color_attr) is not referenced
    # anywhere else in this function -- presumably kept as a debugging aid.
    def color_print(node):
        """Return `color_id(node)` followed by its formatted attributes."""
        return '%s: %s' % (color_id(node), color_attr(node))

    # formatter used by the debug logs below
    print_node = color_id

    if stencil.replication_factor > 1:
        replicated_next_fifo = stencil.get_replicated_next_fifo()
        replicated_all_points = stencil.get_replicated_all_points()
        replicated_reuse_buffers = stencil.get_replicated_reuse_buffers()

        def add_fwd_nodes(src_name):
            """Create the forward nodes of `src_name` and connect them."""
            dsts = replicated_all_points[src_name]
            reuse_buffer = replicated_reuse_buffers[src_name][1:]
            # (fwd node, key of the fwd node it feeds); connected after all
            # forward nodes exist
            nodes_to_add = []
            for dst_point_dicts in dsts.values():
                for offset in dst_point_dicts:
                    # each (tensor, offset) gets exactly one forward node
                    if (src_name, offset) in super_source.fwd_nodes:
                        continue
                    fwd_node = ForwardNode(
                        tensor=stencil.tensors[src_name],
                        offset=offset,
                        depth=stencil.get_replicated_reuse_buffer_length(
                            src_name, offset))
                    _logger.debug('create %s', print_node(fwd_node))
                    init_offsets = [
                        start for start, end in reuse_buffer if start == end
                    ]
                    if offset in init_offsets:
                        if src_name in [stencil.input.name]:
                            # fed by one of the load nodes of the input
                            load_node_count = len(load_nodes[src_name])
                            load_nodes[src_name][
                                load_node_count - 1 -
                                offset % load_node_count].add_child(fwd_node)
                        else:
                            # fed by the (single) compute node of src_name
                            (super_source.cpt_nodes[(src_name,
                                                     0)].add_child(fwd_node))
                    super_source.fwd_nodes[(src_name, offset)] = fwd_node
                    if offset in replicated_next_fifo[src_name]:
                        nodes_to_add.append(
                            (fwd_node,
                             (src_name,
                              replicated_next_fifo[src_name][offset])))
            for src_node, key in nodes_to_add:
                src_node.add_child(super_source.fwd_nodes[key])

        add_fwd_nodes(stencil.input.name)

        for stage in stencil.get_stages_chronologically():
            cpt_node = ComputeNode(stage=stage, pe_id=0)
            _logger.debug('create %s', print_node(cpt_node))
            super_source.cpt_nodes[(stage.name, 0)] = cpt_node
            for input_name, input_window in stage.window.items():
                for i in range(len(input_window)):
                    # find the offset whose point index is i
                    offset = next(offset for offset, points in (
                        replicated_all_points[input_name][stage.name].items())
                                  if points == i)
                    fwd_node = super_source.fwd_nodes[(input_name, offset)]
                    _logger.debug(' access %s', print_node(fwd_node))
                    fwd_node.add_child(cpt_node)
            if stage.is_output():
                super_source.cpt_nodes[stage.name, 0].add_child(
                    store_nodes[stage.name][0])
            else:
                add_fwd_nodes(stage.name)
    else:
        next_fifo = stencil.next_fifo
        all_points = stencil.all_points
        reuse_buffers = stencil.reuse_buffers

        def add_fwd_nodes(src_name):
            """Create the forward nodes of `src_name` and connect them."""
            dsts = all_points[src_name]
            reuse_buffer = reuse_buffers[src_name][1:]
            # (fwd node, key of the fwd node it feeds); connected after all
            # forward nodes exist
            nodes_to_add = []
            for dst_point_dicts in dsts.values():
                for offset in dst_point_dicts:
                    # each (tensor, offset) gets exactly one forward node
                    if (src_name, offset) in super_source.fwd_nodes:
                        continue
                    fwd_node = ForwardNode(tensor=stencil.tensors[src_name],
                                           offset=offset)
                    #depth=stencil.get_reuse_buffer_length(src_name, offset))
                    _logger.debug('create %s', print_node(fwd_node))
                    # init_offsets is the start of each reuse chain
                    init_offsets = [
                        next(end
                             for start, end in reuse_buffer
                             if start == unroll_idx)
                        for unroll_idx in reversed(
                            range(stencil.unroll_factor))
                    ]
                    _logger.debug('reuse buffer: %s', reuse_buffer)
                    _logger.debug('init offsets: %s', init_offsets)
                    if offset in init_offsets:
                        if src_name in stencil.input_names:
                            # fwd from external input
                            load_node_count = len(load_nodes[src_name])
                            load_nodes[src_name][
                                load_node_count - 1 -
                                offset % load_node_count].add_child(fwd_node)
                        else:
                            # fwd from output of last stage
                            # tensor name and offset are used to find the cpt node
                            cpt_offset = next(
                                unroll_idx
                                for unroll_idx in range(stencil.unroll_factor)
                                if init_offsets[unroll_idx] == offset)
                            cpt_node = super_source.cpt_nodes[(src_name,
                                                               cpt_offset)]
                            cpt_node.add_child(fwd_node)
                    super_source.fwd_nodes[(src_name, offset)] = fwd_node
                    if offset in next_fifo[src_name]:
                        nodes_to_add.append(
                            (fwd_node, (src_name,
                                        next_fifo[src_name][offset])))
            for src_node, key in nodes_to_add:
                # fwd from another fwd node
                src_node.add_child(super_source.fwd_nodes[key])

        for input_name in stencil.input_names:
            add_fwd_nodes(input_name)

        for tensor in chronological_tensors:
            if tensor.is_input():
                continue
            # one compute node per unrolled PE, created from the highest
            # pe_id down to 0
            for unroll_index in range(stencil.unroll_factor):
                pe_id = stencil.unroll_factor - 1 - unroll_index
                cpt_node = ComputeNode(tensor=tensor, pe_id=pe_id)
                _logger.debug('create %s', print_node(cpt_node))
                super_source.cpt_nodes[(tensor.name, pe_id)] = cpt_node
                for input_name, input_window in tensor.ld_indices.items():
                    for i in range(len(input_window)):
                        # find the offset that serves point i of this PE
                        offset = next(
                            offset for offset, points in all_points[input_name]
                            [tensor.name].items()
                            if pe_id in points and points[pe_id] == i)
                        fwd_node = super_source.fwd_nodes[(input_name, offset)]
                        _logger.debug(' access %s', print_node(fwd_node))
                        fwd_node.add_child(cpt_node)
            if tensor.is_output():
                # spread the PEs over the available store nodes round-robin
                for pe_id in range(stencil.unroll_factor):
                    super_source.cpt_nodes[tensor.name, pe_id].add_child(
                        store_nodes[tensor.name][pe_id % len(
                            store_nodes[tensor.name])])
            else:
                add_fwd_nodes(tensor.name)

    # Create one FIFO per valid edge and record, on the source node, the
    # expression (plus any lets) that is written to that FIFO.
    # pylint: disable=too-many-nested-blocks
    for src_node in super_source.tpo_valid_node_gen():
        for dst_node in filter(is_valid_node, src_node.children):
            # 5 possible edge types:
            # 1. load => fwd
            # 2. fwd => fwd
            # 3. fwd => cpt
            # 4. cpt => fwd
            # 5. cpt => store
            if isinstance(src_node, LoadNode):
                write_lat = 0
            elif isinstance(src_node, ForwardNode):
                write_lat = 2
            elif isinstance(src_node, ComputeNode):
                write_lat = src_node.tensor.st_ref.lat
            else:
                raise util.InternalError('unexpected source node: %s' %
                                         repr(src_node))

            # depth starts at 0; update_module_depths() at the end assigns
            # the final value
            fifo = ir.FIFO(src_node, dst_node, depth=0, write_lat=write_lat)
            lets: List[ir.Let] = []
            if isinstance(src_node, LoadNode):
                expr = ir.DRAMRef(
                    haoda_type=dst_node.tensor.haoda_type,
                    dram=(src_node.bank, ),
                    var=dst_node.tensor.name,
                    offset=(stencil.unroll_factor - 1 - dst_node.offset) //
                    len(stencil.stmt_table[dst_node.tensor.name].dram),
                )
            elif isinstance(src_node, ForwardNode):
                if isinstance(dst_node, ComputeNode):
                    # register this FIFO as the source of the reference the
                    # consumer makes at index `idx`
                    dst = src_node.tensor.children[dst_node.tensor.name]
                    src_name = src_node.tensor.name
                    unroll_idx = dst_node.pe_id
                    # NOTE(review): `all_points` is only bound in the
                    # non-replicated branch above; when
                    # stencil.replication_factor > 1 this lookup looks like
                    # it would fail on an unbound local -- confirm whether
                    # that path is still supported.
                    point = all_points[src_name][dst.name][
                        src_node.offset][unroll_idx]
                    idx = list(dst.ld_indices[src_name].values())[point].idx
                    _logger.debug(
                        '%s%s referenced by <%s> @ unroll_idx=%d is %s',
                        src_name, util.idx2str(idx), dst.name, unroll_idx,
                        print_node(src_node))
                    dst_node.fifo_map[src_name][idx] = fifo
                delay = stencil.reuse_buffer_lengths[src_node.tensor.name]\
                    [src_node.offset]
                # NOTE(review): `offset` is not read again in this loop body.
                offset = src_node.offset - delay
                # locate the FIFO that feeds this forward node; `fifo_r` is
                # deliberately used after the loops end
                for parent in src_node.parents:  # fwd node has only 1 parent
                    for fifo_r in parent.fifos:
                        if fifo_r.edge == (parent, src_node):
                            break
                if delay > 0:
                    # TODO: build an index somewhere
                    # reuse the let that already delays `fifo_r`, if any;
                    # otherwise (for-else) create a new one
                    for let in src_node.lets:
                        # pylint: disable=undefined-loop-variable
                        if isinstance(
                                let.expr,
                                ir.DelayedRef) and let.expr.ref == fifo_r:
                            var_name = let.name
                            var_type = let.haoda_type
                            break
                    else:
                        var_name = 'let_%d' % len(src_node.lets)
                        # pylint: disable=undefined-loop-variable
                        var_type = fifo_r.haoda_type
                        lets.append(
                            ir.Let(haoda_type=var_type,
                                   name=var_name,
                                   expr=ir.DelayedRef(delay=delay,
                                                      ref=fifo_r)))
                    expr = ir.Var(name=var_name, idx=[])
                    expr.haoda_type = var_type
                else:
                    # no delay: forward the incoming FIFO as is
                    expr = fifo_r  # pylint: disable=undefined-loop-variable
            elif isinstance(src_node, ComputeNode):

                def replace_refs_callback(obj, args):
                    """Replace each `ir.Ref` with the FIFO mapped to it."""
                    if isinstance(obj, ir.Ref):
                        _logger.debug(
                            'replace %s with %s',
                            obj,
                            # pylint: disable=cell-var-from-loop,undefined-loop-variable
                            src_node.fifo_map[obj.name][obj.idx])
                        # pylint: disable=cell-var-from-loop,undefined-loop-variable
                        return src_node.fifo_map[obj.name][obj.idx]
                    return obj

                _logger.debug('lets: %s', src_node.tensor.lets)
                lets = [
                    _.visit(replace_refs_callback)
                    for _ in src_node.tensor.lets
                ]
                _logger.debug('replaced lets: %s', lets)
                _logger.debug('expr: %s', src_node.tensor.expr)
                expr = src_node.tensor.expr.visit(replace_refs_callback)
                _logger.debug('replaced expr: %s', expr)
                if isinstance(dst_node, StoreNode):
                    # the store node writes the FIFO content to this DRAM
                    # reference (the reference is passed as the let's name)
                    dram_ref = ir.DRAMRef(
                        haoda_type=src_node.tensor.haoda_type,
                        dram=(dst_node.bank, ),
                        var=src_node.tensor.name,
                        offset=(src_node.pe_id) //
                        len(stencil.stmt_table[src_node.tensor.name].dram),
                    )
                    dst_node.lets.append(
                        ir.Let(haoda_type=None, name=dram_ref, expr=fifo))
            else:
                raise util.InternalError('unexpected node of type %s' %
                                         type(src_node))

            src_node.exprs[fifo] = expr
            # add only the lets the source node does not have yet
            src_node.lets.extend(_ for _ in lets if _ not in src_node.lets)
            _logger.debug(
                'fifo [%d]: %s%s => %s', fifo.depth, color_id(src_node),
                '' if fifo.write_lat is None else ' ~%d' % fifo.write_lat,
                color_id(dst_node))

    # called with an empty mapping, i.e. no module pipeline depth is given
    super_source.update_module_depths({})

    return super_source
def visitor(node: ir.Node, args=None) -> ir.Node:
    """Return a flattened replacement for `node`, or `node` if none applies.

    The following rewrites are tried in order; the first one that applies is
    returned (after being passed through `flatten`, which is defined outside
    this block):
      + a `BinaryOp` with a single operand becomes that operand;
      + a `BinaryOp` whose operands include nodes of its own type, attached
        as the first operand or with one of `||`, `&&`, `|`, `&`, `+`, `*`,
        absorbs the operators and operands of those nodes;
      + an `Operand` whose first non-None attribute is a node becomes that
        node;
      + a `Unary` whose operators cancel out (only `+`/`-` with an even
        number of `-`, or an even number of `!` only) becomes its operand;
      + a `Call` to a function in `ir.REDUCTION_FUNCS` absorbs the arguments
        of directly nested calls to the same function.

    Args:
      node: The node being visited.
      args: Unused.

    Returns:
      The flattened node, or `node` itself when nothing can be flattened.

    Raises:
      util.InternalError: If `node` is an `Operand` whose attributes are all
        None.
    """
    if isinstance(node, ir.BinaryOp):
        if len(node.operand) == 1:
            # a singleton BinaryOp is just its operand
            return flatten(node.operand[0])

        # NOTE(review): under '*' every operator of the child is absorbed,
        # whatever it is. If one node type can mix '*' with operators such as
        # '/' or '%', `a * (b % c)` would be rewritten as `a * b % c` --
        # confirm against the grammar.
        flattenable = (None, '||', '&&', '|', '&', '+', '*')
        node_type = type(node)
        merged_operators, merged_operands = [], []
        # pair every operand with the operator in front of it; the first
        # operand has none
        leading_operators = (None, *node.operator)
        for leading_op, child in zip(leading_operators, node.operand):
            if leading_op is not None:
                merged_operators.append(leading_op)
            if leading_op in flattenable and type(child) is node_type:
                merged_operators.extend(child.operator)
                merged_operands.extend(child.operand)
            else:
                merged_operands.append(child)
        # more operands than before means something was absorbed
        if len(merged_operands) > len(node.operand):
            return flatten(
                node_type(operator=merged_operators, operand=merged_operands))

    if isinstance(node, ir.Operand):
        # only the first attribute that is set matters
        attr_values = (getattr(node, attr) for attr in node.ATTRS)
        first_set = next((val for val in attr_values if val is not None),
                         None)
        if first_set is None:
            raise util.InternalError('undefined Operand')
        if isinstance(first_set, ir.Node):
            return flatten(first_set)

    if isinstance(node, ir.Unary):
        unary_ops = node.operator
        # only '+'/'-' with an even number of '-' is the identity
        minus_count = unary_ops.count('-')
        if (minus_count % 2 == 0 and
                unary_ops.count('+') + minus_count == len(unary_ops)):
            return flatten(node.operand)
        # an even number of '!' and nothing else is the identity as well
        not_count = unary_ops.count('!')
        if not_count % 2 == 0 and not_count == len(unary_ops):
            return flatten(node.operand)

    if isinstance(node, ir.Call) and node.name in ir.REDUCTION_FUNCS:
        func_name = node.name
        merged_args = []
        for call_arg in node.arg:
            if isinstance(call_arg, ir.Call) and call_arg.name == func_name:
                merged_args.extend(call_arg.arg)
            else:
                merged_args.append(call_arg)
        # more arguments than before means a nested call was absorbed
        if len(merged_args) > len(node.arg):
            return flatten(ir.Call(name=func_name, arg=merged_args))

    return node