Example No. 1
# Nested helper from the chronological_tensors method shown in Example No. 4;
# `self` and `visitor` come from the enclosing scope, so this fragment is not
# standalone.
def sync(tensor, offset):
    if tensor is None:
        return offset
    _logger.debug('index of tensor <%s>: %s', tensor.name,
                  tensor.st_idx)
    stage_offset = soda_util.serialize(
        tensor.st_idx, self.tile_size)
    _logger.debug('offset of tensor <%s>: %d', tensor.name,
                  stage_offset)
    loads = visitor.get_load_dict(tensor)
    for name in loads:
        loads[name] = tuple(ref.idx for ref in loads[name])
    _logger.debug(
        'loads: %s', ', '.join(
            '%s@%s' %
            (name,
             util.lst2str(map(util.idx2str, indices)))
            for name, indices in loads.items()))
    for n in loads:
        loads[n] = soda_util.serialize_iter(
            loads[n], self.tile_size)
    for l in loads.values():
        l[0], l[-1] = (stage_offset - max(l),
                       stage_offset - min(l))
        del l[1:-1]
        if len(l) == 1:
            l.append(l[-1])
    _logger.debug(
        'load offset range in tensor %s: %s', tensor.name,
        '{%s}' % (', '.join('%s: [%d:%d]' % (n, *v)
                            for n, v in loads.items())))
    for parent in tensor.parents.values():
        tensor_distance = next(
            reversed(tensor.ld_offsets[parent.name]))
        _logger.debug('tensor distance: %s',
                      tensor_distance)
        _logger.debug(
            'want to access tensor <%s> at offset [%d, %d] '
            'to generate tensor <%s> at offset %d',
            parent.name, offset + loads[parent.name][0],
            offset + loads[parent.name][-1], tensor.name,
            offset)
        tensor_offset = (parent.st_delay +
                         tensor_distance - stage_offset)
        if offset < tensor_offset:
            _logger.debug(
                'but tensor <%s> won\'t be available until offset %d',
                parent.name, tensor_offset)
            offset = tensor_offset
            _logger.debug(
                'need to access tensor <%s> at offset [%d, %d] '
                'to generate tensor <%s> at offset %d',
                parent.name,
                offset + loads[parent.name][0],
                offset + loads[parent.name][-1],
                tensor.name, offset)
    return offset
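
The availability rule in the final loop is the heart of the check: a parent's data only becomes valid at parent.st_delay + tensor_distance - stage_offset, and the consumer's offset is raised to the maximum of these over all parents. A minimal sketch of just that rule, with plain ints standing in for the tensor attributes (the helper name synced_offset is illustrative, not part of haoda):

def synced_offset(offset, stage_offset, parents):
    """parents holds one (st_delay, tensor_distance) pair per parent tensor."""
    for st_delay, tensor_distance in parents:
        # A parent's data is only available from this offset onwards.
        offset = max(offset, st_delay + tensor_distance - stage_offset)
    return offset

assert synced_offset(5, 5, [(10, 2)]) == 7  # pushed back from 5 to 7
assert synced_offset(5, 5, [(4, 2)]) == 5   # already satisfiable; unchanged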
Example No. 2
    def __init__(self, **kwargs):
        self.iterate = kwargs.pop('iterate')
        if self.iterate < 1:
            raise util.SemanticError('cannot iterate %d times' % self.iterate)
        self.border = kwargs.pop('border')
        self.preserve_border = self.border == 'preserve'
        self.cluster = kwargs.pop('cluster')
        # platform determined
        self.burst_width = kwargs.pop('burst_width')
        # application determined
        self.app_name = kwargs.pop('app_name')
        # parameters that can be explored
        self.tile_size = tuple(kwargs.pop('tile_size'))
        self.unroll_factor = kwargs.pop('unroll_factor')
        self.replication_factor = kwargs.pop('replication_factor')
        # stage-independent
        self.dim = kwargs.pop('dim')
        self.param_stmts = kwargs.pop('param_stmts')
        # stage-specific
        self.input_stmts = kwargs.pop('input_stmts')
        self.local_stmts = kwargs.pop('local_stmts')
        self.output_stmts = kwargs.pop('output_stmts')
        self.optimizations = {}
        if 'optimizations' in kwargs:
            self.optimizations = kwargs.pop('optimizations')

        if 'dram_in' in kwargs:
            dram_in = kwargs.pop('dram_in')
            if dram_in is not None:
                if ':' in dram_in:
                    input_stmt_map = {_.name: _ for _ in self.input_stmts}
                    for dram_map in dram_in.split('^'):
                        var_name, bank_list = dram_map.split(':')
                        if var_name not in input_stmt_map:
                            raise util.SemanticError(
                                'no input named `{}`'.format(var_name))
                        input_stmt_map[var_name].dram = tuple(
                            map(int, bank_list.split('.')))
                else:
                    for input_stmt in self.input_stmts:
                        input_stmt.dram = tuple(map(int, dram_in.split('.')))

        if 'dram_out' in kwargs:
            dram_out = kwargs.pop('dram_out')
            if dram_out is not None:
                if ':' in dram_out:
                    output_stmt_map = {_.name: _ for _ in self.output_stmts}
                    for dram_map in dram_out.split(','):
                        var_name, bank_list = dram_map.split(':')
                        if var_name not in output_stmt_map:
                            raise util.SemanticError(
                                'no output named `{}`'.format(var_name))
                        output_stmt_map[var_name].dram = tuple(
                            map(int, bank_list.split('.')))
                else:
                    for output_stmt in self.output_stmts:
                        output_stmt.dram = tuple(map(int, dram_out.split('.')))

        if self.iterate > 1:
            if len(self.input_stmts) != len(self.output_stmts):
                raise util.SemanticError(
                    'number of input tensors must be the same as output if iterate > 1 '
                    'times, currently there are %d input(s) but %d output(s)' %
                    (len(self.input_stmts), len(self.output_stmts)))
            if self.input_types != self.output_types:
                raise util.SemanticError(
                    'input must have the same type(s) as output if iterate > 1 '
                    'times, current input has type %s but output has type %s' %
                    (util.lst2str(
                        self.input_types), util.lst2str(self.output_types)))
            _logger.debug(
                'pipeline %d iterations of [%s] -> [%s]' %
                (self.iterate, ', '.join('%s: %s' %
                                         (stmt.haoda_type, stmt.name)
                                         for stmt in self.input_stmts),
                 ', '.join('%s: %s' % (stmt.haoda_type, stmt.name)
                           for stmt in self.output_stmts)))

        for stmt in itertools.chain(self.local_stmts, self.output_stmts):
            _logger.debug('simplify %s', stmt.name)
            # LocalStmt and OutputStmt must remember the stencil object
            # for type propagation
            stmt.stencil = self
            stmt.expr = arithmetic.simplify(stmt.expr)
            stmt.let = arithmetic.simplify(stmt.let)

        self._cr_counter = 0
        cr.computation_reuse(self)
        if 'inline' in self.optimizations:
            inline.inline(self)
        inline.rebalance(self)

        for stmt in itertools.chain(self.local_stmts, self.output_stmts):
            stmt.propagate_type()

        # soda frontend successfully parsed
        _logger.debug(
            'producer tensors: [%s]',
            ', '.join(tensor.name for tensor in self.producer_tensors))
        _logger.debug(
            'consumer tensors: [%s]',
            ', '.join(tensor.name for tensor in self.consumer_tensors))

        # TODO: build Ref table and Var table
        # generate reuse buffers and get haoda nodes
        # pylint: disable=pointless-statement
        self.dataflow_super_source
        _logger.debug('dataflow: %s', self.dataflow_super_source)

        _logger.debug('module table: %s', dict(self.module_table))
        _logger.debug('module traits: %s', self.module_traits)
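
The dram_in/dram_out strings follow a small grammar: ':' separates a variable name from its bank list, '.' separates bank numbers, and a per-variable separator ('^' for dram_in, ',' for dram_out in this snippet) separates mappings. A standalone sketch of the dram_in branch (the helper name parse_dram_map and the plain-dict return value are illustrative, not part of haoda):

def parse_dram_map(dram_spec):
    """Parse 'a:0.1^b:2' into {'a': (0, 1), 'b': (2,)}.

    Without ':' (e.g. '0.1') every variable shares the same banks, which this
    sketch signals with the key None.
    """
    if ':' not in dram_spec:
        return {None: tuple(map(int, dram_spec.split('.')))}
    banks = {}
    for dram_map in dram_spec.split('^'):
        var_name, bank_list = dram_map.split(':')
        banks[var_name] = tuple(map(int, bank_list.split('.')))
    return banks

assert parse_dram_map('a:0.1^b:2') == {'a': (0, 1), 'b': (2,)}
assert parse_dram_map('0.1') == {None: (0, 1)}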
Example No. 3
File: core.py Project: Ruola/haoda
    def __str__(self):
        return 'dram<bank {} {}@{}>'.format(util.lst2str(self.dram), self.var,
                                            self.offset)
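
A minimal usage sketch, assuming util.lst2str renders an iterable as a bracketed, comma-separated list and that the object carries dram, var, and offset attributes (the class name and values below are made up for illustration; only the format string comes from the example):

class DRAMRefSketch:
    dram, var, offset = (0, 1), 'input_a', 42

    def __str__(self):
        # list(...) stands in for util.lst2str(self.dram).
        return 'dram<bank {} {}@{}>'.format(list(self.dram), self.var,
                                            self.offset)

print(DRAMRefSketch())  # dram<bank [0, 1] input_a@42>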
Example No. 4
    def chronological_tensors(self):
        """Computes the offsets of tensors.

        Returns:
          A list of Tensor, in chronological order.
        """
        _logger.info('calculate tensor offsets')
        processing_queue = collections.deque(list(self.input_names))
        processed_tensors = set(self.input_names)
        chronological_tensors = list(map(self.tensors.get, self.input_names))
        for tensor in chronological_tensors:
            _logger.debug('tensor <%s> is at offset %d' %
                          (tensor.name, tensor.st_offset))
        _logger.debug('processing queue: %s', processing_queue)
        _logger.debug('processed_tensors: %s', processed_tensors)
        while processing_queue:
            tensor = self.tensors[processing_queue.popleft()]
            _logger.debug('inspecting tensor %s\'s children' % tensor.name)
            for child in tensor.children.values():
                if ({x.name
                     for x in child.parents.values()} <= processed_tensors
                        and child.name not in processed_tensors):
                    # good, all inputs are processed
                    # can determine offset of current tensor
                    _logger.debug(
                        'input%s for tensor <%s> (i.e. %s) %s processed',
                        '' if len(child.parents) == 1 else 's', child.name,
                        ', '.join([x.name for x in child.parents.values()]),
                        'is' if len(child.parents) == 1 else 'are')
                    stage_offset = soda_util.serialize(child.st_idx,
                                                       self.tile_size)

                    # synchronization check
                    def sync(tensor, offset):
                        if tensor is None:
                            return offset
                        _logger.debug('index of tensor <%s>: %s', tensor.name,
                                      tensor.st_idx)
                        stage_offset = soda_util.serialize(
                            tensor.st_idx, self.tile_size)
                        _logger.debug('offset of tensor <%s>: %d', tensor.name,
                                      stage_offset)
                        loads = visitor.get_load_dict(tensor)
                        for name in loads:
                            loads[name] = tuple(ref.idx for ref in loads[name])
                        _logger.debug(
                            'loads: %s', ', '.join(
                                '%s@%s' %
                                (name,
                                 util.lst2str(map(util.idx2str, indices)))
                                for name, indices in loads.items()))
                        for n in loads:
                            loads[n] = soda_util.serialize_iter(
                                loads[n], self.tile_size)
                        for l in loads.values():
                            l[0], l[-1] = (stage_offset - max(l),
                                           stage_offset - min(l))
                            del l[1:-1]
                            if len(l) == 1:
                                l.append(l[-1])
                        _logger.debug(
                            'load offset range in tensor %s: %s', tensor.name,
                            '{%s}' % (', '.join('%s: [%d:%d]' % (n, *v)
                                                for n, v in loads.items())))
                        for parent in tensor.parents.values():
                            tensor_distance = next(
                                reversed(tensor.ld_offsets[parent.name]))
                            _logger.debug('tensor distance: %s',
                                          tensor_distance)
                            _logger.debug(
                                'want to access tensor <%s> at offset [%d, %d] '
                                'to generate tensor <%s> at offset %d',
                                parent.name, offset + loads[parent.name][0],
                                offset + loads[parent.name][-1], tensor.name,
                                offset)
                            tensor_offset = (parent.st_delay +
                                             tensor_distance - stage_offset)
                            if offset < tensor_offset:
                                _logger.debug(
                                    'but tensor <%s> won\'t be available until offset %d',
                                    parent.name, tensor_offset)
                                offset = tensor_offset
                                _logger.debug(
                                    'need to access tensor <%s> at offset [%d, %d] '
                                    'to generate tensor <%s> at offset %d',
                                    parent.name,
                                    offset + loads[parent.name][0],
                                    offset + loads[parent.name][-1],
                                    tensor.name, offset)
                        return offset

                    _logger.debug(
                        'intend to generate tensor <%s> at offset %d',
                        child.name, child.st_delay)
                    synced_offset = sync(child, child.st_delay)
                    _logger.debug('synced offset: %s', synced_offset)
                    child.st_delay = synced_offset
                    _logger.debug(
                        'decide to generate tensor <%s> at offset %d',
                        child.name, child.st_delay)

                    # add delay
                    for sibling in child.parents.values():
                        delay = child.st_delay - (
                            sibling.st_delay +
                            list(child.ld_offsets[sibling.name].keys())[-1] -
                            stage_offset)
                        if delay > 0:
                            _logger.debug(
                                'tensor %s arrives at tensor <%s> at offset %d < %d; '
                                'add %d delay', sibling.name, child.name,
                                sibling.st_delay + next(
                                    reversed(child.ld_offsets[sibling.name])) -
                                stage_offset, child.st_delay, delay)
                        else:
                            _logger.debug(
                                'tensor %s arrives at tensor <%s> at offset %d = %d; good',
                                sibling.name, child.name,
                                sibling.st_delay + next(
                                    reversed(child.ld_offsets[sibling.name])) -
                                stage_offset, child.st_delay)
                        child.ld_delays[sibling.name] = max(delay, 0)
                        _logger.debug('set delay of |%s <- %s| to %d' %
                                      (child.name, sibling.name,
                                       child.ld_delays[sibling.name]))

                    processing_queue.append(child.name)
                    processed_tensors.add(child.name)
                    chronological_tensors.append(child)
                else:
                    for parent in tensor.parents.values():
                        if parent.name not in processed_tensors:
                            _logger.debug(
                                'tensor %s requires tensor <%s> as an input',
                                tensor.name, parent.name)
                            _logger.debug(
                                'but tensor <%s> isn\'t processed yet',
                                parent.name)
                            _logger.debug('add %s to scheduling queue',
                                          parent.name)
                            processing_queue.append(parent.name)

        _logger.debug('tensors in insertion order: [%s]',
                      ', '.join(map(str, self.tensors)))
        _logger.debug('tensors in chronological order: [%s]',
                      ', '.join(t.name for t in chronological_tensors))

        for tensor in self.tensors.values():
            for name, indices in tensor.ld_indices.items():
                _logger.debug(
                    'stage index: %s@%s <- %s@%s', tensor.name,
                    util.idx2str(tensor.st_idx), name,
                    util.lst2str(util.idx2str(idx) for idx in indices))
        for tensor in self.tensors.values():
            if tensor.is_input():
                continue
            _logger.debug('stage expr: %s = %s', tensor.st_ref, tensor.expr)
        for tensor in self.tensors.values():
            for name, offsets in tensor.ld_offsets.items():
                _logger.debug(
                    'stage offset: %s@%d <- %s@%s', tensor.name,
                    soda_util.serialize(tensor.st_idx, self.tile_size), name,
                    util.lst2str(offsets))
        for tensor in self.tensors.values():
            for name, delay in tensor.ld_delays.items():
                _logger.debug('stage delay: %s <- %s delayed %d' %
                              (tensor.name, name, delay))

        return chronological_tensors
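
The scheduling loop above is essentially a breadth-first topological traversal: a tensor is emitted once all of its parents have been processed, and its start offset is pushed back until every parent's data is available. A standalone sketch of that skeleton over a plain dependency dict (names and the readiness test are illustrative; the real code additionally computes offsets and per-edge delays):

import collections

def chronological_order(parents):
    """Return node names so that every node appears after all of its parents.

    `parents` maps each node to the set of nodes it depends on; inputs map
    to an empty set.
    """
    inputs = [name for name, deps in parents.items() if not deps]
    queue = collections.deque(inputs)
    processed = set(inputs)
    order = list(inputs)
    children = collections.defaultdict(set)
    for name, deps in parents.items():
        for dep in deps:
            children[dep].add(name)
    while queue:
        name = queue.popleft()
        for child in sorted(children[name]):
            # Ready once every parent is processed, like the check above.
            if parents[child] <= processed and child not in processed:
                processed.add(child)
                order.append(child)
                queue.append(child)
    return order

assert chronological_order({'a': set(), 'b': {'a'}, 'c': {'a', 'b'}}) == ['a', 'b', 'c']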
Example No. 5
def inline2(stencil):
    """Inline statements that are referenced by only one other statement.
  """
    if not stencil.local_stmts:
        return stencil

    refs = collections.OrderedDict(
    )  # type: Dict[str, Dict[ir.LocalOrOutputStmt, List[ir.Ref]]]
    for stmt in itertools.chain(stencil.local_stmts, stencil.output_stmts):
        for var_name, ref_list in visitor.get_load_dict(stmt).items():
            if var_name in stencil.input_names or var_name == stmt.name:
                continue
            refs.setdefault(var_name, collections.OrderedDict()).setdefault(
                stmt, []).extend(ref_list)

    refs = {
        name: next(iter(ref_dict.items()))
        for name, ref_dict in refs.items() if len(ref_dict) == 1 and len(
            visitor.get_load_set(
                {stmt.name: stmt.expr
                 for stmt in stencil.local_stmts}[name])) == 1
    }
    for name, (stmt, ref_list) in refs.items():
        _logger.info(
            'name: %s stmt: %s ref_list: %s', name, stmt.name,
            util.lst2str(
                visitor.get_load_set(
                    {stmt.name: stmt.expr
                     for stmt in stencil.local_stmts}[name])))
    if not refs:
        return stencil

    # sort loads to avoid referencing wrong stmt
    local_stmt_table = {
        stmt.name: idx
        for idx, stmt in enumerate(stencil.local_stmts)
    }
    ref_queue = collections.deque(list(refs.items()))
    sorted_refs = []  # type: List[Tuple[ir.Ref, ir.LocalOrOutputStmt]]
    while ref_queue:
        var_name, (load_stmt, ref_list) = ref_queue.popleft()
        store_stmt = stencil.local_stmts[local_stmt_table[ref_list[0].name]]
        accessed_vars = {ref.name for ref in visitor.get_load_set(store_stmt)}
        queued_vars = {var_name for var_name, _ in ref_queue}
        _logger.debug('stmt to be removed: %s', store_stmt)
        _logger.debug('accessed vars: %s', util.lst2str(accessed_vars))
        _logger.debug('queued vars %s', util.lst2str(queued_vars))
        if accessed_vars & queued_vars:
            ref_queue.append((var_name, (load_stmt, ref_list)))
        else:
            sorted_refs.append((var_name, (load_stmt, ref_list)))

    for var_name, (load_stmt, ref_list) in sorted_refs:
        idx, store_stmt = {
            stmt.name: (idx, stmt)
            for idx, stmt in enumerate(stencil.local_stmts)
        }[var_name]
        ref_table = {}
        for ref in ref_list:
            offset = tuple(a - b for a, b in zip(store_stmt.ref.idx, ref.idx))
            ref = mutator.shift(store_stmt.ref, offset)
            lets = tuple(mutator.shift(let, offset) for let in store_stmt.let)
            expr = mutator.shift(store_stmt.expr, offset)
            _logger.info(
                '`%s` is referenced only once by stmt %s, replace with `%s`',
                ref, load_stmt.name, expr)
            ref_table[ref] = expr
        replace_load = lambda obj, args: args.get(obj, obj)
        # TODO: resolve let variable name conflicts
        load_stmt.let = lets + tuple(
            let.visit(replace_load, ref_table) for let in load_stmt.let)
        load_stmt.expr = load_stmt.expr.visit(replace_load, ref_table)
        del stencil.local_stmts[idx]

    # invalidate cached_property
    stencil.__dict__.pop('symbol_table', None)
    stencil.__dict__.pop('local_names', None)
    stencil.__dict__.pop('local_types', None)

    for stmt in itertools.chain(stencil.local_stmts, stencil.output_stmts):
        _logger.debug('simplify  : %s', stmt)
        stmt.expr = arithmetic.simplify(
            arithmetic.base.reverse_distribute(stmt.expr))
        stmt.let = arithmetic.simplify(
            tuple(map(arithmetic.base.reverse_distribute, stmt.let)))
        _logger.debug('simplified:  %s', stmt)
    return inline2(stencil)
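
Stripped of the IR details, both this pass and the `inline` pass in the next example repeat one move until a fixed point: find a local definition with exactly one consumer, substitute its body into that consumer, and delete it. A token-level sketch of that loop (token lists stand in for expression trees; parentheses preserve precedence, and the real passes also skip self-references and input names, shift stencil indices, and re-simplify):

import collections

def inline_single_use(defs, outputs):
    """defs maps names to token lists; names in `outputs` are never removed."""
    while True:
        counts = collections.Counter(
            tok for expr in defs.values() for tok in expr if tok in defs)
        name = next(
            (n for n in defs if counts[n] == 1 and n not in outputs), None)
        if name is None:
            return defs
        body = ['('] + defs.pop(name) + [')']
        defs = {n: [t for tok in expr
                    for t in (body if tok == name else [tok])]
                for n, expr in defs.items()}

defs = {'t': ['a', '+', 'b'], 'out': ['t', '*', 'c']}
assert inline_single_use(defs, {'out'}) == {
    'out': ['(', 'a', '+', 'b', ')', '*', 'c']}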
Example No. 6
def inline(stencil):
    """Inline statements that are only referenced once.
  """
    if not stencil.local_stmts:
        return stencil

    refs = {}  # type: Dict[str, Set[Tuple[ir.Ref, ir.LocalOrOutputStmt]]]
    for stmt in itertools.chain(stencil.local_stmts, stencil.output_stmts):
        for var_name, ref_list in visitor.get_load_dict(stmt).items():
            if var_name in stencil.input_names or var_name == stmt.name:
                continue
            refs.setdefault(var_name,
                            set()).update(zip(ref_list,
                                              itertools.repeat(stmt)))

    refs = {
        name: next(iter(ref_set))
        for name, ref_set in refs.items() if len(ref_set) == 1
    }
    if not refs:
        return stencil

    # sort loads to avoid referencing wrong stmt
    local_stmt_table = {
        stmt.name: idx
        for idx, stmt in enumerate(stencil.local_stmts)
    }
    ref_queue = collections.deque(list(refs.items()))
    sorted_refs = []  # type: List[Tuple[ir.Ref, ir.LocalOrOutputStmt]]
    while ref_queue:
        var_name, (ref, load_stmt) = ref_queue.popleft()
        store_stmt = stencil.local_stmts[local_stmt_table[ref.name]]
        accessed_vars = {ref.name for ref in visitor.get_load_set(store_stmt)}
        queued_vars = {var_name for var_name, _ in ref_queue}
        _logger.debug('stmt to be removed: %s', store_stmt)
        _logger.debug('accessed vars: %s', util.lst2str(accessed_vars))
        _logger.debug('queued vars %s', util.lst2str(queued_vars))
        if accessed_vars & queued_vars:
            ref_queue.append((var_name, (ref, load_stmt)))
        else:
            sorted_refs.append((var_name, (ref, load_stmt)))

    for var_name, (ref, load_stmt) in sorted_refs:
        idx, store_stmt = {
            stmt.name: (idx, stmt)
            for idx, stmt in enumerate(stencil.local_stmts)
        }[var_name]
        offset = tuple(a - b for a, b in zip(store_stmt.ref.idx, ref.idx))
        ref = mutator.shift(store_stmt.ref, offset)
        lets = tuple(mutator.shift(let, offset) for let in store_stmt.let)
        expr = mutator.shift(store_stmt.expr, offset)
        _logger.info('`%s` is referenced only once, replace with `%s`', ref,
                     expr)
        replace_load = lambda obj, args: args[1] if obj == args[0] else obj
        # TODO: resolve let variable name conflicts
        load_stmt.let = lets + tuple(
            let.visit(replace_load, (ref, expr)) for let in load_stmt.let)
        load_stmt.expr = load_stmt.expr.visit(replace_load, (ref, expr))
        del stencil.local_stmts[idx]

    # invalidate cached_property
    stencil.__dict__.pop('symbol_table', None)
    stencil.__dict__.pop('local_names', None)
    stencil.__dict__.pop('local_types', None)

    for stmt in itertools.chain(stencil.local_stmts, stencil.output_stmts):
        _logger.debug('simplify  : %s', stmt)
        stmt.expr = arithmetic.simplify(stmt.expr)
        stmt.let = arithmetic.simplify(stmt.let)
        _logger.debug('simplified:  %s', stmt)
    return inline(stencil)
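
The distinctive step in both inlining passes is the index shift: the definition is written at the store index but referenced at some other index, so every ref in its body must be moved by the difference before substitution. A sketch on bare index tuples (shift_idx is illustrative and subtracts, which matches how `offset` is computed above: shifting the store ref by it must land on the load ref):

def shift_idx(idx, offset):
    """Move a stencil index by an offset, component-wise."""
    return tuple(a - b for a, b in zip(idx, offset))

# store_stmt defines t at (0, 0) from operand a(1, 0); a consumer reads t(2, 3).
store_idx, load_idx, operand_idx = (0, 0), (2, 3), (1, 0)
offset = tuple(a - b for a, b in zip(store_idx, load_idx))  # (-2, -3)
assert shift_idx(store_idx, offset) == load_idx  # the definition now sits at (2, 3)
assert shift_idx(operand_idx, offset) == (3, 3)  # so a(1, 0) becomes a(3, 3)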