Beispiel #1
0
def get_result_type(operand1, operand2, operator):
    """Pick the result type of a binary operation from its operand types.

    Candidate type names are tried in a fixed priority order: ``double``,
    ``float``, then for each width in 64, 32, 16, 8 the signed name
    (``int<w>_t``) followed by the unsigned name (``uint<w>_t``). The first
    candidate equal to either operand is the result.

    Args:
        operand1: Type of the left operand.
        operand2: Type of the right operand.
        operator: The operator; only used in the error message.

    Returns:
        The highest-priority candidate type name matching either operand.

    Raises:
        util.SemanticError: If neither operand matches any candidate.
    """
    priority = ['double', 'float']
    for width in (64, 32, 16, 8):
        priority.append('int%d_t' % width)
        priority.append('uint%d_t' % width)

    operands = (operand1, operand2)
    winner = next((name for name in priority if name in operands), None)
    if winner is None:
        raise util.SemanticError('cannot parse type: %s %s %s' %
                                 (operand1, operator, operand2))
    return winner
Beispiel #2
0
 def _get_expr_for(self, node):
     """Return the ``expr`` attribute of an Output or Local grammar node.

     Args:
         node: The node to read the expression from.

     Returns:
         ``node.expr`` when ``node`` is a ``grammar.Output`` or
         ``grammar.Local`` instance.

     Raises:
         util.SemanticError: If ``node`` is of any other type.
     """
     # Guard clause: reject anything that is not one of the two supported
     # node kinds, so the happy path is a single return.
     if not isinstance(node, (grammar.Output, grammar.Local)):
         raise util.SemanticError('cannot get expression for %s' %
                                  str(type(node)))
     return node.expr
Beispiel #3
0
 def __init__(self, **kwargs):
     """Initialize the object and derive ``tile_size`` / ``dim`` from its inputs.

     All keyword arguments are forwarded to the parent initializer. Afterwards
     ``self.input_stmts`` is scanned (presumably populated by the parent
     initializer -- confirm against the base class):

     * The first input whose ``tile_size`` has more than one element fixes
       ``self.tile_size`` and ``self.dim``.
     * Once ``self.tile_size`` exists, every subsequent input must have an
       equal ``tile_size``.
     * If no input set ``self.tile_size`` (the 1D case), the last input's
       ``tile_size`` is used.

     Raises:
         util.SemanticError: If an input's tile size disagrees with the one
             already recorded, or if there is no input statement to take a
             tile size from.
     """
     super().__init__(**kwargs)
     # Track the last visited input explicitly instead of relying on the
     # loop variable leaking out of the loop, which is unbound when
     # ``input_stmts`` is empty.
     last_node = None
     for node in self.input_stmts:
         last_node = node
         if hasattr(self, 'tile_size'):
             # pylint: disable=access-member-before-definition
             if self.tile_size != node.tile_size:
                 msg = (
                     'tile size %s doesn\'t match previous one %s' %
                     # pylint: disable=access-member-before-definition
                     (node.tile_size, self.tile_size))
                 raise util.SemanticError(msg)
         elif node.tile_size[:-1]:
             # More than one element in the tile size: adopt it.
             self.tile_size = node.tile_size
             self.dim = len(self.tile_size)
     # deal with 1D case
     if not hasattr(self, 'tile_size'):
         if last_node is None:
             raise util.SemanticError(
                 'cannot determine tile size: no input statements')
         self.tile_size = last_node.tile_size
         self.dim = len(self.tile_size)
Beispiel #4
0
    def __init__(self, **kwargs):
        """Build the stencil from keyword arguments and run the frontend passes.

        Required keys (popped from ``kwargs``; a missing one raises
        ``KeyError``): ``iterate``, ``border``, ``cluster``, ``burst_width``,
        ``app_name``, ``tile_size``, ``unroll_factor``, ``replication_factor``,
        ``dim``, ``param_stmts``, ``input_stmts``, ``local_stmts``,
        ``output_stmts``.

        Optional keys:
            optimizations: Stored as ``self.optimizations`` (default ``{}``).
            dram_in: DRAM bank specification for the input statements.
            dram_out: DRAM bank specification for the output statements.

        A DRAM bank specification is either a dot-separated list of integers
        applied to every statement (e.g. ``0.1``), or comma-separated
        ``name:banks`` pairs (e.g. ``a:0.1,b:2``). ``None`` is ignored.

        After the attributes are set, local and output statements are
        simplified, computation reuse / inlining / rebalancing are applied,
        types are propagated, and the dataflow graph is generated.

        Raises:
            util.SemanticError: If ``iterate`` is less than 1, if a DRAM
                specification names an unknown statement, or if
                ``iterate > 1`` and the inputs and outputs differ in count
                or in types.
        """
        self.iterate = kwargs.pop('iterate')
        if self.iterate < 1:
            raise util.SemanticError('cannot iterate %d times' % self.iterate)
        self.border = kwargs.pop('border')
        self.preserve_border = self.border == 'preserve'
        self.cluster = kwargs.pop('cluster')
        # platform determined
        self.burst_width = kwargs.pop('burst_width')
        # application determined
        self.app_name = kwargs.pop('app_name')
        # parameters that can be explored
        self.tile_size = tuple(kwargs.pop('tile_size'))
        self.unroll_factor = kwargs.pop('unroll_factor')
        self.replication_factor = kwargs.pop('replication_factor')
        # stage-independent
        self.dim = kwargs.pop('dim')
        self.param_stmts = kwargs.pop('param_stmts')
        # stage-specific
        self.input_stmts = kwargs.pop('input_stmts')
        self.local_stmts = kwargs.pop('local_stmts')
        self.output_stmts = kwargs.pop('output_stmts')
        self.optimizations = kwargs.pop('optimizations', {})

        # Inputs and outputs share one parser so both accept the same
        # comma-separated `name:banks` syntax.
        self._bind_dram_banks(self.input_stmts, kwargs.pop('dram_in', None),
                              'input')
        self._bind_dram_banks(self.output_stmts, kwargs.pop('dram_out', None),
                              'output')

        if self.iterate > 1:
            # Iterating feeds outputs back as inputs, so they must line up.
            if len(self.input_stmts) != len(self.output_stmts):
                raise util.SemanticError(
                    'number of input tensors must be the same as output if iterate > 1 '
                    'times, currently there are %d input(s) but %d output(s)' %
                    (len(self.input_stmts), len(self.output_stmts)))
            if self.input_types != self.output_types:
                raise util.SemanticError(
                    'input must have the same type(s) as output if iterate > 1 '
                    'times, current input has type %s but output has type %s' %
                    (util.lst2str(
                        self.input_types), util.lst2str(self.output_types)))
            _logger.debug(
                'pipeline %d iterations of [%s] -> [%s]', self.iterate,
                ', '.join('%s: %s' % (stmt.haoda_type, stmt.name)
                          for stmt in self.input_stmts),
                ', '.join('%s: %s' % (stmt.haoda_type, stmt.name)
                          for stmt in self.output_stmts))

        for stmt in itertools.chain(self.local_stmts, self.output_stmts):
            _logger.debug('simplify %s', stmt.name)
            # LocalStmt and OutputStmt must remember the stencil object
            # for type propagation
            stmt.stencil = self
            stmt.expr = arithmetic.simplify(stmt.expr)
            stmt.let = arithmetic.simplify(stmt.let)

        self._cr_counter = 0
        cr.computation_reuse(self)
        if 'inline' in self.optimizations:
            inline.inline(self)
        inline.rebalance(self)

        for stmt in itertools.chain(self.local_stmts, self.output_stmts):
            stmt.propagate_type()

        # soda frontend successfully parsed
        _logger.debug(
            'producer tensors: [%s]',
            ', '.join(tensor.name for tensor in self.producer_tensors))
        _logger.debug(
            'consumer tensors: [%s]',
            ', '.join(tensor.name for tensor in self.consumer_tensors))

        # TODO: build Ref table and Var table
        # generate reuse buffers and get haoda nodes
        # The bare attribute access is kept from the original; presumably it
        # forces evaluation of a lazily computed attribute -- confirm.
        # pylint: disable=pointless-statement
        self.dataflow_super_source
        _logger.debug('dataflow: %s', self.dataflow_super_source)

        _logger.debug('module table: %s', dict(self.module_table))
        _logger.debug('module traits: %s', self.module_traits)

    @staticmethod
    def _bind_dram_banks(stmts, spec, kind):
        """Set the ``dram`` attribute of statements from a bank specification.

        Args:
            stmts: Statements to update; each must expose ``name``.
            spec: ``None`` (no-op), a dot-separated bank list applied to every
                statement, or comma-separated ``name:banks`` pairs.
            kind: ``'input'`` or ``'output'``; used in the error message.

        Raises:
            util.SemanticError: If a named statement is not in ``stmts``.
        """
        if spec is None:
            return
        if ':' not in spec:
            # One bank list shared by every statement.
            for stmt in stmts:
                stmt.dram = tuple(map(int, spec.split('.')))
            return
        stmt_map = {stmt.name: stmt for stmt in stmts}
        for dram_map in spec.split(','):
            var_name, bank_list = dram_map.split(':')
            if var_name not in stmt_map:
                raise util.SemanticError(
                    'no {} named `{}`'.format(kind, var_name))
            stmt_map[var_name].dram = tuple(map(int, bank_list.split('.')))