Exemplo n.º 1
0
def generate_source(casenum, srcgens):
    """Apply every source-generation directive and write the touched files.

    Directives are applied in two passes over ``srcgens``: first every
    directive whose type is ``insert_stmt``, then every other directive.
    Within a pass the original order of ``srcgens`` is kept.

    Args:
        casenum: accepted for interface compatibility; not used here.
        srcgens: iterable of ``(srcgen, inputfileid, stmt, span)`` tuples,
            where ``srcgen`` is an iterable of ``(gentype, attrs)`` pairs
            and ``gentype`` is a sequence whose first element is the
            directive name.

    Returns:
        The list of input-file ids that had at least one directive, in the
        order they were first encountered.

    Raises:
        ProgramException: if a ``gentype`` holds more than one element.
    """
    srcfileids = []

    def apply_pass(insert_pass):
        # Apply the directives selected by ``insert_pass``: True selects the
        # insert_stmt directives, False selects everything else.
        for srcgen, inputfileid, stmt, span in srcgens:
            for gentype, attrs in srcgen:
                is_insert = gentype[0].lower() == 'insert_stmt'
                if is_insert != insert_pass:
                    continue
                if inputfileid not in srcfileids:
                    srcfileids.append(inputfileid)
                if len(gentype) > 1:
                    # Wrap in a 1-tuple so that a tuple-valued gentype is
                    # formatted as a whole instead of being spread over the
                    # format arguments (which would raise TypeError).
                    raise ProgramException('More than one gentype: %s' %
                                           (gentype, ))
                transform_source(gentype[0], attrs,
                                 State.inputfile[inputfileid], stmt, span)

    apply_pass(True)
    apply_pass(False)

    for srcfileid in srcfileids:
        srcfile = State.inputfile[srcfileid]
        srcfile.write_to_file(Config.path['workdir'] + '/' + srcfile.relpath)

    return srcfileids
Exemplo n.º 2
0
def flat_items(items):
    """Map a list of directive values through ``SrcFile.applymap``.

    The layout is decided by the first element only:

    * ``str``: every element is mapped as-is;
    * ``tuple``: every element is unpacked as a pair and only the first
      entry of the pair's first member is mapped (presumably these are
      ``(items, attrs)`` pairs -- confirm against the producer).

    Raises:
        IndexError: if ``items`` is empty.
        ProgramException: if the first element is neither str nor tuple.
    """
    head = items[0]

    if isinstance(head, str):
        mapped = []
        for text in items:
            mapped.append(SrcFile.applymap(text))
        return mapped

    if isinstance(head, tuple):
        mapped = []
        for values, _ignored in items:
            mapped.append(SrcFile.applymap(values[0]))
        return mapped

    raise ProgramException('Unknown type: %s' % head.__class__)
Exemplo n.º 3
0
    def get_stmt(self, targets):
        """Resolve target specifier(s) to a flat list of statements.

        Accepted specifiers:

        * an ``int`` or a string of digits: looked up via
          ``self.stmt_by_label``;
        * a string starting with ``l``/``L`` followed by a number: looked up
          via ``self.stmt_by_lineno``;
        * a list or tuple of specifiers: each one is resolved recursively.

        An empty string resolves to nothing, and lookups that return a falsy
        value are skipped.

        Returns:
            List of the statements found (possibly empty).

        Raises:
            ProgramException: for a string in neither form above, or for an
                unsupported specifier type.
            ValueError: if the text after the leading ``l`` is not an integer.
        """
        outputs = []

        if isinstance(targets, str):
            if targets:
                if targets.isdigit():
                    stmt = self.stmt_by_label(int(targets))
                elif targets[0].lower() == 'l':
                    stmt = self.stmt_by_lineno(int(targets[1:]))
                else:
                    raise ProgramException('Syntax error: %s' % targets)
                if stmt:
                    outputs.append(stmt)
        elif isinstance(targets, int):
            stmt = self.stmt_by_label(targets)
            if stmt:
                outputs.append(stmt)
        elif isinstance(targets, (list, tuple)):
            # The original tested ``list`` twice; tuples are accepted as the
            # second sequence type.
            for target in targets:
                outputs.extend(self.get_stmt(target))
        else:
            raise ProgramException('Unknown type: %s' % targets.__class__)

        return outputs
Exemplo n.º 4
0
    def walk(self,
             casenumseq,
             selectfunc,
             prefunc,
             postfunc,
             elems=None,
             objs=None,
             **kwargs):
        """Visit this node and the SearchSpace children chosen by selectfunc.

        ``prefunc(self)`` is called first and ``postfunc(self)`` last, when
        given. If ``objs`` is a list, this node is appended to it. The
        children are whatever ``selectfunc(self, casenumseq, **kwargs)``
        yields; falsy entries are ignored.

        When ``elems`` is a list, each child is walked with fresh ``items``
        and ``attrs`` collectors and ``(items, attrs)`` is appended to
        ``elems`` if the child produced at least one item. Otherwise the
        child is walked without collectors.

        Raises:
            ProgramException: if a truthy child is not a SearchSpace.
        """
        if prefunc:
            prefunc(self)

        if objs is not None:
            objs.append(self)

        collecting = elems is not None

        for child in selectfunc(self, casenumseq, **kwargs):
            if not child:
                continue

            if not isinstance(child, SearchSpace):
                raise ProgramException('Unknown type: %s' % child.__class__)

            if not collecting:
                child.walk(casenumseq,
                           selectfunc,
                           prefunc,
                           postfunc,
                           objs=objs,
                           **kwargs)
                continue

            child_items = []
            child_attrs = {}
            child.walk(casenumseq,
                       selectfunc,
                       prefunc,
                       postfunc,
                       items=child_items,
                       attrs=child_attrs,
                       objs=objs,
                       **kwargs)
            if child_items:
                elems.append((child_items, child_attrs))

        if postfunc:
            postfunc(self)
0
    def applymap(cls, obj):
        """Return a copy of ``obj`` with ``cls.strmap`` substitutions applied.

        Strings have every key of ``cls.strmap`` replaced by its value, one
        key after another in the mapping's iteration order. Dicts (keys and
        values), lists and tuples are rebuilt with the substitution applied
        recursively to their contents. ``None`` is returned unchanged.

        Raises:
            ProgramException: if ``obj`` (or a nested element) is of any
                other type.
        """
        if obj is None:
            return None

        if isinstance(obj, str):
            mapped = obj
            # items() works on both Python 2 and 3, unlike iteritems().
            for old, new in cls.strmap.items():
                mapped = mapped.replace(old, new)
            return mapped

        if isinstance(obj, dict):
            mapped = {}
            for key, value in obj.items():
                mapped[cls.applymap(key)] = cls.applymap(value)
            return mapped

        if isinstance(obj, list):
            return [cls.applymap(item) for item in obj]

        if isinstance(obj, tuple):
            return tuple(cls.applymap(list(obj)))

        raise ProgramException('Unknown type: %s' % obj.__class__)
Exemplo n.º 6
0
def loop_unroll(targets, factor, method):
    """Unroll each DO loop in ``targets`` and splice the result into the tree.

    Args:
        targets: iterable of statements; entries that are not ``Do``
            instances are skipped with a warning.
        factor: ``'full'`` to unroll the whole iteration space (only done
            when the loop bounds could be evaluated), or a string of digits
            giving the unroll factor.
        method: passed through unchanged to ``_unroll``.

    For a numeric factor that differs from the evaluated trip count (or
    when the bounds are unknown), two loops are generated: an unrolled loop
    with adjusted end and step, followed by a loop with the original step
    that starts from the current value of the loop variable.

    Raises:
        ProgramException: if the DO statement is not a ``Nonlabel_Do_Stmt``.
        UserException: if ``factor`` is neither ``'full'`` nor all digits.
    """
    for target_stmt in targets:
        if not isinstance(target_stmt, Do):
            Logger.warn('Target statement is not Do type: %s' %
                        target_stmt.__class__)
            continue

        # collect loop control
        target_f2003 = target_stmt.f2003
        if not isinstance(target_f2003, Nonlabel_Do_Stmt):
            # The original referenced an undefined name here, which turned
            # this diagnostic into a NameError.
            raise ProgramException('Not supported type: %s' %
                                   target_f2003.__class__)

        loop_control = target_f2003.items[1]
        loop_var = loop_control.items[0].string.lower()
        bounds = loop_control.items[1]
        start_idx = bounds[0]
        end_idx = bounds[1]
        # Use the explicit step when a third bound is present and fall back
        # to 1 otherwise. The original test was inverted: it discarded an
        # explicit step and indexed past the end when there was none.
        if len(bounds) >= 3 and bounds[2] is not None:
            step = bounds[2]
        else:
            step = Int_Literal_Constant('1')

        # collect loop controls through static analysis
        start_num = target_stmt.get_param(start_idx)
        end_num = target_stmt.get_param(end_idx)
        step_num = target_stmt.get_param(step)
        try:
            loop_indices = range(start_num, end_num + 1, step_num)
        except (TypeError, ValueError, OverflowError):
            # Bounds that did not evaluate to usable integers (or a zero
            # step) mean the iteration space is unknown.
            loop_indices = None

        # TODO: modify analysis if required
        lines = []
        if factor == 'full':
            if loop_indices is not None:
                lines = _unroll(target_stmt.content,
                                loop_var,
                                len(loop_indices),
                                method,
                                start_index=start_num)
            else:
                Logger.warn('Loopcontrol is not collected')

        elif factor.isdigit():
            factor_num = int(factor)
            if loop_indices is not None and len(loop_indices) == factor_num:
                # The factor covers the whole loop: unroll it completely.
                lines = _unroll(target_stmt.content,
                                loop_var,
                                factor_num,
                                method,
                                start_index=start_num)
            else:
                # Unrolled loop: replace end and step.
                # NOTE(review): the step text is not parenthesised, so a
                # compound step expression may change meaning once it is
                # multiplied by the factor -- confirm.
                newstep = '%s*%s' % (step.tofortran(), factor)
                newend = '%s-%s' % (end_idx.tofortran(), newstep)
                lines.append(target_stmt.tooc(do_end=newend, do_step=newstep))
                lines.extend(
                    _unroll(target_stmt.content, loop_var, factor_num, method))
                lines.append(target_stmt.content[-1].tooc())

                # Second loop with the original end and step: replace start
                # with the loop variable's current value.
                newstart = loop_var
                lines.append(
                    target_stmt.tooc(do_start=newstart, remove_label=True))
                lines.extend(_unroll(target_stmt.content, loop_var, 1, method))
                lines.append(target_stmt.content[-1].tooc(remove_label=True))
        else:
            raise UserException('Unknown unroll factor: %s' % factor)

        if lines:
            parsed = parse('\n'.join(lines), analyze=False)
            if len(parsed.content) > 0:
                for stmt, _depth in walk(parsed, -1):
                    stmt.parse_f2003()
                insert_content(target_stmt, parsed.content)
Exemplo n.º 7
0
    def walk(self,
             casenumseq,
             selectfunc,
             prefunc,
             postfunc,
             items=None,
             attrs=None,
             objs=None,
             **kwargs):
        """Visit this node and collect the values of its first case.

        ``prefunc(self)`` is called first and ``postfunc(self)`` last, when
        given. If ``objs`` is a list, this node is appended to it.

        The node's items are read from ``self.case[0][0]`` and its
        attributes from ``self.case[0][1]``. A ``str`` value is collected
        as-is; a ``SearchSpace`` value is walked and whatever it places in
        its ``elems`` collector is collected instead. Item results are
        appended to ``items`` and attribute results to ``attrs[key]`` (a
        list created on demand), when those collectors are given.

        NOTE: ``selectfunc`` and ``casenumseq`` are only forwarded to nested
        walks; this node does not consult ``selectfunc`` itself.

        Raises:
            ProgramException: if an item or attribute value is neither a
                str nor a SearchSpace.
        """
        if prefunc:
            prefunc(self)

        if objs is not None:
            objs.append(self)

        def collect(entry):
            # Return the list of elements contributed by one item/value.
            if isinstance(entry, str):
                return [entry]
            if isinstance(entry, SearchSpace):
                found = []
                entry.walk(casenumseq,
                           selectfunc,
                           prefunc,
                           postfunc,
                           elems=found,
                           objs=objs,
                           **kwargs)
                return found
            raise ProgramException('Unknown type: %s' % entry.__class__)

        for item in self.case[0][0]:
            found = collect(item)
            if items is not None:
                items.extend(found)

        for key, value in self.case[0][1].items():
            # The original reported ``item`` (the last entry of the previous
            # loop, or an unbound name) when ``value`` had a bad type.
            found = collect(value)
            if attrs is not None and found:
                attrs.setdefault(key, []).extend(found)

        if postfunc:
            postfunc(self)
Exemplo n.º 8
0
    def execute(self):
        """Run this case, then measure and verify its output.

        The reference case runs the reference command
        ``ref_outer_iter * ref_inner_iter`` times from
        ``Config.path['refdir']`` and joins the outputs. Any other case
        generates its transformed sources and shell script, copies them to
        ``Config.path['outdir']`` and runs the script.

        The outcome is stored in ``self.result`` (one of the ``Case``
        result constants) and the extracted values in ``self.measured``
        (``var -> list of value strings``). Nothing is returned.

        Raises:
            ProgramException: if the case is not initialized or a verify
                entry names an unsupported method.
            ValueError: if a 'diff' verification meets a value, ``refval``
                or ``maxdiff`` that ``float()`` cannot convert.
        """
        if self.casenum == Case.NOT_INTIALIZED:
            raise ProgramException('Case is not initialized')

        output = None

        if self.casenum == Case.REFCASE:
            # execute refcase
            stdout = []
            cmd = State.direct['refcase'][0][0].cases[0][0][0].case[0][0][0]
            refcmd = 'cd %s; ' % Config.path['refdir'] + SrcFile.applymap(cmd)
            for _outer in range(self.ref_outer_iter):
                for _inner in range(self.ref_inner_iter):
                    stdout.append(exec_cmd(refcmd))
                    time.sleep(0.1)
                time.sleep(0.3)
            output = '\n'.join(stdout)
        else:
            # A single parenthesised argument prints identically under the
            # Python 2 print statement and the Python 3 print function.
            print('Executing case %d of %d' % (self.casenum,
                                               State.cases['size']))

            # transform source
            srcgen = [
                value for key, value in self.directs.items()
                if key.lower().startswith('srcgen')
            ]
            srcfiles = generate_source(self.casenum, srcgen)
            for fileid in srcfiles:
                src = '%s/%s' % (Config.path['workdir'],
                                 State.inputfile[fileid].relpath)
                dst = '%s/%s.%d' % (Config.path['outdir'],
                                    State.inputfile[fileid].relpath,
                                    self.casenum)
                shutil.copyfile(src, dst)

            # generate shell script
            script = generate_script(self.casenum, self.directs, srcfiles)
            src = '%s/case_%d.sh' % (Config.path['workdir'], self.casenum)
            dst = '%s/case_%d.sh' % (Config.path['outdir'], self.casenum)
            shutil.copyfile(src, dst)

            # execute shell script
            output = exec_cmd(script)

        if not output:
            self.result = Case.EXECUTION_FAIL
            return

        # measure: after each occurrence of a variable's prefix, the value
        # is the next whitespace-delimited token of the output.
        self.measured = {}
        for var, attrs in self.parent.measure.items():
            self.measured[var] = []
            prefix = attrs['prefix']
            len_prefix = len(prefix)
            # findall() results are used as start offsets into ``output``.
            match_prefix = findall(prefix, output)
            if not match_prefix:
                self.result = Case.MEASURMENT_FAIL
                return
            for start in match_prefix:
                # split(None, 1) skips leading whitespace and stops at the
                # next whitespace or at the end of the output. The original
                # regex class [\s\r\n\z] treated \z as a literal 'z' (or as
                # an invalid escape) and could not match end-of-output.
                tokens = output[start + len_prefix:].split(None, 1)
                if not tokens:
                    self.result = Case.MEASURMENT_FAIL
                    return
                self.measured[var].append(tokens[0])
        if any(len(values) == 0 for values in self.measured.values()):
            self.result = Case.MEASURMENT_FAIL
            return

        # verify
        for var, attrs in self.parent.verify.items():
            if var not in self.measured:
                self.result = Case.NO_MEASUREMENT_FAIL
                return

            method = attrs['method']
            if method == 'match':
                pattern = attrs['pattern']
                if any(value != pattern for value in self.measured[var]):
                    self.result = Case.VERIFICATION_FAIL
                    return
            elif method == 'diff':
                refval = float(attrs['refval'])
                maxdiff = float(attrs['maxdiff'])
                if any(
                        abs(float(value) - refval) > maxdiff
                        for value in self.measured[var]):
                    self.result = Case.VERIFICATION_FAIL
                    return
            else:
                raise ProgramException('Unsupported method: %s' % method)

        self.result = Case.VERIFIED
        return
Exemplo n.º 9
0
def _has_values(attrs, *keys):
    """Return True when every key is in ``attrs`` with a non-empty value."""
    return all(key in attrs and len(attrs[key]) > 0 for key in keys)


def _optional_flat(attrs, key):
    """Return ``flat_items(attrs[key])``, or None if missing or falsy."""
    values = attrs.get(key)
    return flat_items(values) if values else None


def _single_flat(attrs, key):
    """Flatten ``attrs[key]`` (default ``['']``) and return its one element.

    Raises ProgramException when the flattened list has several elements.
    """
    values = flat_items(attrs.get(key, ['']))
    if len(values) > 1:
        raise ProgramException('More than one element: %s' % values)
    return values[0]


def transform_source(gentype, attrs, inputfile, stmt, span):
    """Dispatch one source-generation directive to its transformation.

    Args:
        gentype: directive name, matched case-insensitively.
        attrs: mapping of attribute name to a sequence of raw values.
        inputfile: source file object; used to resolve targets through
            ``get_stmt`` and passed on to some transformations.
        stmt: statement the directive is attached to; passed through.
        span: passed through unchanged to transformations that take it.

    A directive whose required attributes are missing or empty, or whose
    targets resolve to no statement, is skipped silently.

    Raises:
        UserException: if ``gentype`` is not a known directive.
        ProgramException: if ``factor`` or ``method`` of a ``loop_unroll``
            directive flattens to more than one element.
    """
    kind = gentype.lower()

    if kind == 'loop_unroll':
        if _has_values(attrs, 'target'):
            targets = inputfile.get_stmt(flat_items(attrs['target']))
            if targets:
                factor = _single_flat(attrs, 'factor')
                method = _single_flat(attrs, 'method')
                loop_unroll(targets, factor, method)

    elif kind == 'loop_merge':
        if _has_values(attrs, 'from', 'to'):
            _from = flat_items(attrs['from'])
            _to = flat_items(attrs['to'])
            from_stmts = inputfile.get_stmt(_from)
            to_stmts = inputfile.get_stmt(_to)
            if from_stmts and to_stmts:
                loop_merge(from_stmts, to_stmts)

    elif kind == 'loop_split':
        if _has_values(attrs, 'add_stmt'):
            add_stmt = flat_items(attrs['add_stmt'])
        else:
            add_stmt = None

        # 'before' takes precedence; 'after' is only looked at without it.
        if _has_values(attrs, 'before'):
            before_stmts = inputfile.get_stmt(flat_items(attrs['before']))
            if before_stmts:
                loop_split(before_stmts, add_stmt=add_stmt, before=True)
        elif _has_values(attrs, 'after'):
            after_stmts = inputfile.get_stmt(flat_items(attrs['after']))
            if after_stmts:
                loop_split(after_stmts, add_stmt=add_stmt, before=False)

    elif kind == 'loop_interchange':
        if _has_values(attrs, 'outer', 'inner'):
            _outer = flat_items(attrs['outer'])
            _inner = flat_items(attrs['inner'])
            outer_stmts = inputfile.get_stmt(_outer)
            inner_stmts = inputfile.get_stmt(_inner)
            if outer_stmts and inner_stmts:
                loop_interchange(outer_stmts, inner_stmts)

    elif kind == 'name_change':
        if _has_values(attrs, 'target'):
            targets = inputfile.get_stmt(flat_items(attrs['target']))
            if targets:
                switch = flat_items(attrs.get('switch', ['']))
                rename = flat_items(attrs.get('rename', ['']))
                name_change(targets, switch, rename)

    elif kind == 'openmp':
        if _has_values(attrs, 'sentinel', 'directive'):
            sentinel = flat_items(attrs['sentinel'])
            direct = flat_items(attrs['directive'])
            clauses = _optional_flat(attrs, 'clauses')
            openmp(inputfile, stmt, sentinel, direct, clauses, span)

    elif kind == 'insert_stmt':
        if _has_values(attrs, 'label', 'stmt'):
            label = flat_items(attrs['label'])
            stmt_line = flat_items(attrs['stmt'])
            insert_stmt(inputfile, stmt, label, stmt_line, span)

    elif kind == 'remove_stmt':
        if _has_values(attrs, 'target'):
            targets = inputfile.get_stmt(flat_items(attrs['target']))
            if targets:
                remove_stmt(inputfile, stmt, targets, span)

    elif kind == 'promote':
        if _has_values(attrs, 'name', 'dimension', 'target'):
            name = flat_items(attrs['name'])
            dimension = flat_items(attrs['dimension'])
            target = flat_items(attrs['target'])
            allocate = _optional_flat(attrs, 'allocate')
            promote(inputfile, stmt, name, dimension, target, allocate, span)

    elif kind == 'directive':
        if _has_values(attrs, 'label', 'sentinel', 'directive'):
            label = flat_items(attrs['label'])
            sentinel = flat_items(attrs['sentinel'])
            direct = flat_items(attrs['directive'])
            directive(inputfile, stmt, label, sentinel, direct, span)

    else:
        raise UserException('Not implemented: %s' % gentype)