Example #1
0
def loop_interchange(outer_stmts, inner_stmts):
    """Exchange the loop statements of every outer/inner ``Do`` pair.

    For each ``Do`` in *outer_stmts* combined with each ``Do`` in
    *inner_stmts*, every statement yielded by ``walk(outer_stmt, -1)`` is
    re-emitted through ``tooc()``, except that four statements are emitted
    in each other's place: the outer and inner statements themselves, and
    the last ``content`` entry of each (presumably the closing END DO of
    the loop -- confirm against the statement classes).  The text is then
    re-parsed and handed to ``insert_content`` for the outer statement.

    Entries that are not ``Do`` instances are reported with
    ``Logger.warn`` and skipped.  Nothing is returned.
    """
    for outer_stmt in outer_stmts:
        if not isinstance(outer_stmt, Do):
            Logger.warn('Outer statment is not Do type: %s'%outer_stmt.__class__)
            continue

        for inner_stmt in inner_stmts:
            if not isinstance(inner_stmt, Do):
                Logger.warn('Inner statment is not Do type: %s'%inner_stmt.__class__)
                continue

            # Pick, for every visited node, the statement whose text is
            # emitted at that position.
            source = []
            for node, _depth in walk(outer_stmt, -1):
                if node is outer_stmt:
                    emitted = inner_stmt
                elif node is inner_stmt:
                    emitted = outer_stmt
                elif node is inner_stmt.content[-1]:
                    emitted = outer_stmt.content[-1]
                elif node is outer_stmt.content[-1]:
                    emitted = inner_stmt.content[-1]
                else:
                    emitted = node
                source.append(emitted.tooc())

            if not source:
                continue

            parsed = parse('\n'.join(source), analyze=False)
            if len(parsed.content) == 0:
                continue

            # Attach the new top statement to the old parent before the
            # per-statement F2003 parse.
            parsed.content[0].parent = outer_stmt.parent
            for node, _depth in walk(parsed, -1):
                node.parse_f2003()
            insert_content(outer_stmt, parsed.content)
Example #2
0
def promote(inputfile, target_stmt, names, dimensions, targets, allocate, span):
    """Re-emit labelled statements with the variables in *names* promoted.

    Every entry of *targets* is a comma-separated list of ``label:idxname``
    pairs.  For each pair the statement returned by
    ``inputfile.get_stmt(label)`` is walked and rewritten line by line:

    * A ``TypeDeclarationStatement`` declaring any of *names* is split in
      two: one declaration for the remaining entities (only if some remain)
      and one for the promoted entities with ``dimensions[0]`` appended.
      When *allocate* is truthy the ``allocatable`` attribute is added to
      the promoted declaration.
    * At the first statement whose class is in ``execution_part``, one
      ``allocate(name + allocate[0])`` line is emitted per name and
      dimension when *allocate* is truthy.
    * That statement and all later ones are emitted with each name renamed
      to ``name + idxname``.

    The rewritten text is parsed and replaces the original statement via
    ``insert_content(..., remove_olditem=True)``.  Failures of that final
    step are ignored (best effort).  ``target_stmt`` and ``span`` are not
    used here.  Nothing is returned.
    """
    for target in targets:
        for promote_pair in target.split(','):
            label, idxname = promote_pair.split(':')
            promote_stmt = inputfile.get_stmt(label)
            if not promote_stmt:
                continue

            # One [old, new] entry per (name, dimension) combination; the
            # list is the same for every executable statement.
            renames = [[name, name+idxname]
                       for name in names for _dim in dimensions]

            lines = []
            in_exepart = False
            for stmt, depth in walk(promote_stmt[0], -1):
                if isinstance(stmt, TypeDeclarationStatement):
                    if any(name in stmt.entity_decls for name in names):
                        name_decls = [entity for entity in stmt.entity_decls
                                      if entity in names]
                        entity_decls = [entity for entity in stmt.entity_decls
                                        if entity not in names]

                        # Declaration for the entities that are NOT
                        # promoted.  Skipped when nothing remains, so no
                        # declaration without entities is emitted.
                        if entity_decls:
                            stmt.entity_decls = entity_decls
                            lines.append(stmt.tooc())

                        if allocate and 'allocatable' not in stmt.attrspec:
                            stmt.attrspec.append('allocatable')

                        # The declared shape always comes from
                        # ``dimensions``.  The former non-allocate branch
                        # indexed ``allocate`` although it is falsy there,
                        # which could only raise.
                        # NOTE(review): confirm ``dimensions[0]`` is the
                        # intended shape for the non-allocatable case.
                        stmt.entity_decls = [name_decl+dimensions[0]
                                             for name_decl in name_decls]
                        if len(stmt.entity_decls) > 0:
                            lines.append(stmt.tooc())
                            stmt.entity_decls = entity_decls
                    elif len(stmt.entity_decls) > 0:
                        lines.append(stmt.tooc())
                elif not in_exepart and stmt.__class__ in execution_part:
                    # First executable statement: allocation lines go
                    # right before it.
                    if allocate:
                        for name in names:
                            for _dim in dimensions:
                                lines.append('allocate(%s)'%(name+allocate[0]))
                    lines.append(stmt.tooc(name_rename=renames))
                    in_exepart = True
                elif in_exepart:
                    lines.append(stmt.tooc(name_rename=renames))
                else:
                    lines.append(stmt.tooc())

            try:
                parsed = parse('\n'.join(lines), analyze=False, ignore_comments=False)
                if len(parsed.content) > 0:
                    for new_stmt, depth in walk(parsed, -1):
                        new_stmt.parse_f2003()
                    insert_content(promote_stmt[0], parsed.content, remove_olditem=True)
            except Exception:
                # Deliberately best effort: a failed rewrite leaves the
                # original statement in place.  Only ``Exception`` is
                # caught so KeyboardInterrupt/SystemExit still propagate.
                pass
Example #3
0
def remove_stmt(inputfile, target_stmt, targets, span):
    """Replace every truthy entry of *targets* by a ``!``-prefixed copy.

    The text ``'!' + str(target)`` is parsed with comments kept, each
    statement yielded by ``walk(parsed, 1)`` gets ``parse_f2003()`` called,
    and the result replaces the target via
    ``insert_content(..., remove_olditem=True)``.  Falsy targets and
    parses that produce no content are skipped.  ``inputfile``,
    ``target_stmt`` and ``span`` are not used here.
    """
    for target in targets:
        if not target:
            continue

        commented = parse('!'+str(target), analyze=False, ignore_comments=False)
        if len(commented.content) == 0:
            continue

        for node, _depth in walk(commented, 1):
            node.parse_f2003()
        insert_content(target, commented.content, remove_olditem=True)
Example #4
0
def name_change(targets, switch, rename):
    """Re-emit each target with names switched and/or renamed.

    *switch* and *rename* are iterables of ``'old:new'`` strings; empty
    entries are ignored and both halves are stripped of surrounding
    whitespace.  The resulting pairs are passed to ``target_stmt.tooc`` as
    ``name_switch`` / ``name_rename``.  If that yields text, it is parsed
    and replaces the target through
    ``insert_content(..., remove_olditem=True)``.
    """
    for target_stmt in targets:
        list_switch = []
        for pair in switch:
            if pair:
                fields = pair.split(':')
                list_switch.append((fields[0].strip(), fields[1].strip()))

        list_rename = []
        for pair in rename:
            if pair:
                fields = pair.split(':')
                list_rename.append((fields[0].strip(), fields[1].strip()))

        lines = target_stmt.tooc(name_switch=list_switch, name_rename=list_rename)
        if not lines:
            continue

        parsed = parse(lines, analyze=False)
        if len(parsed.content) == 0:
            continue

        # The new top statement inherits the parent of the one it replaces.
        parsed.content[0].parent = target_stmt.parent
        for node, _depth in walk(parsed, -1):
            node.parse_f2003()
        insert_content(target_stmt, parsed.content, remove_olditem=True)
Example #5
0
def loop_split(stmts, add_stmt, before=True):
    """Split the ``Do`` loop enclosing each statement into two loops.

    For every statement in *stmts* whose ``parent`` is a ``Do``, the
    parent's statements (as yielded by ``walk(parent, -1)``) are
    distributed over two text blocks.  Statements are added to the first
    block until the split statement is met and to the second block
    afterwards; the split statement itself goes to the second block when
    *before* is true and to the first block otherwise.  Both blocks start
    with the parent's text and end with the text of ``parent.content[-1]``;
    in the second block everything is emitted with ``remove_label=True``,
    and ``add_stmt[0]`` (if *add_stmt* is truthy) is placed in front of it.

    Each block is parsed and passed to ``insert_content`` for the parent:
    the first with ``remove_olditem=False``, the second with
    ``remove_olditem=True``.  Statements whose parent is not a ``Do`` are
    reported with ``Logger.warn`` and skipped.
    """
    for stmt in stmts:
        parent = stmt.parent

        if not isinstance(parent, Do):
            Logger.warn('Parent of statment is not Do type: %s'%parent.__class__)
            continue

        first_half = [parent.tooc()]

        second_half = []
        if add_stmt:
            second_half.append(add_stmt[0])
        second_half.append(parent.tooc(remove_label=True))

        enddo_stmt = parent.content[-1]

        # ``current`` is the block receiving statements; it flips to the
        # second block once the split statement has been seen.
        current = first_half
        strip_label = False
        for child, _depth in walk(parent, -1):
            if child in (parent, enddo_stmt):
                continue
            if before:
                if child == stmt:
                    current = second_half
                    strip_label = True
                current.append(child.tooc(remove_label=strip_label))
            else:
                current.append(child.tooc(remove_label=strip_label))
                if child == stmt:
                    current = second_half
                    strip_label = True

        first_half.append(enddo_stmt.tooc())
        second_half.append(enddo_stmt.tooc(remove_label=True))

        # The original loop is kept while the first block is inserted and
        # removed when the second one is.
        for block, drop_original in ((first_half, False), (second_half, True)):
            parsed = parse('\n'.join(block), analyze=False, ignore_comments=False)
            if len(parsed.content) > 0:
                parsed.content[0].parent = parent.parent
                for node, _depth in walk(parsed, -1):
                    node.parse_f2003()
                insert_content(parent, parsed.content, remove_olditem=drop_original)
Example #6
0
def insert_stmt(inputfile, target_stmt, label, stmt_line, span):
    """Replace the statement located at *span* by the parsed ``stmt_line[0]``.

    The tree ``inputfile.tree`` is walked to find the first statement whose
    ``item.span`` equals *span*.  ``stmt_line[0]`` is parsed with comments
    kept; every statement yielded by ``walk(parsed, 1)`` receives the label
    ``int(label[0])`` (on the statement itself for ``Comment`` instances,
    on ``stmt.item`` otherwise) and is parsed with ``parse_f2003()``.  The
    result replaces the located statement via
    ``insert_content(..., remove_olditem=True)``.

    Nothing happens when *stmt_line* is empty or when no statement matches
    *span*.  ``target_stmt`` is not used here.
    """
    new_target_stmt = None
    for candidate, _depth in walk(inputfile.tree, -1):
        if candidate.item.span == span:
            new_target_stmt = candidate
            break

    # Without a located statement there is nothing to replace; passing
    # None on to insert_content would only fail later.
    if new_target_stmt is None or not stmt_line:
        return

    parsed = parse(stmt_line[0], analyze=False, ignore_comments=False)
    if len(parsed.content) > 0:
        for stmt, _depth in walk(parsed, 1):
            if isinstance(stmt, Comment):
                stmt.label = int(label[0])
            else:
                stmt.item.label = int(label[0])
            stmt.parse_f2003()
        insert_content(new_target_stmt, parsed.content, remove_olditem=True)
Example #7
0
def name_change(targets, switch, rename):
    """Re-emit each target with names switched and/or renamed.

    *switch* and *rename* hold ``'old:new'`` strings.  Empty entries are
    dropped and both halves are stripped before the pairs are handed to
    ``target_stmt.tooc`` as ``name_switch`` and ``name_rename``.  Text
    returned by ``tooc`` is parsed and replaces the target through
    ``insert_content(..., remove_olditem=True)``.
    """

    def split_pairs(specs):
        # 'old:new' -> ('old', 'new'); falsy entries are ignored.
        pairs = []
        for spec in specs:
            if not spec:
                continue
            fields = spec.split(':')
            pairs.append((fields[0].strip(), fields[1].strip()))
        return pairs

    for target_stmt in targets:
        lines = target_stmt.tooc(name_switch=split_pairs(switch),
                                 name_rename=split_pairs(rename))
        if not lines:
            continue

        parsed = parse(lines, analyze=False)
        if len(parsed.content) == 0:
            continue

        # Keep the replacement attached to the original parent.
        parsed.content[0].parent = target_stmt.parent
        for node, _depth in walk(parsed, -1):
            node.parse_f2003()
        insert_content(target_stmt, parsed.content, remove_olditem=True)
Example #8
0
def openmp(inputfile, target_stmt, sentinel, directive, clauses, span):
    """Replace the statement at *span* by a ``sentinel directive clauses`` line.

    The first statement of ``inputfile.tree`` whose ``item.span`` equals
    *span* is located.  The replacement line is built from
    ``SrcFile.applymap`` applied to ``sentinel[0]``, ``directive[0]`` and,
    when *clauses* is truthy, ``clauses[0]`` (an empty string otherwise),
    joined by single spaces.  The parsed line replaces the located
    statement via ``insert_content(..., remove_olditem=True)``.

    Nothing happens when no statement matches.  ``target_stmt`` is not
    used here.
    """
    line = ''
    new_target_stmt = None
    for node, _depth in walk(inputfile.tree, -1):
        if node.item.span != span:
            continue
        new_target_stmt = node
        mapped_clauses = SrcFile.applymap(clauses[0]) if clauses else ''
        line = '%s %s %s'%(SrcFile.applymap(sentinel[0]), SrcFile.applymap(directive[0]), mapped_clauses)
        break

    if not line:
        return

    parsed = parse(line, analyze=False, ignore_comments=False)
    if len(parsed.content) == 0:
        return

    for node, _depth in walk(parsed, -1):
        node.parse_f2003()
    insert_content(new_target_stmt, parsed.content, remove_olditem=True)
Example #9
0
def directive(inputfile, target_stmt, label, sentinel, directive, span):
    """Replace the statement at *span* by a labelled ``!<sentinel>$`` line.

    The first statement of ``inputfile.tree`` whose ``item.span`` equals
    *span* is located and the text ``'!%s$ %s'`` is built from
    ``SrcFile.applymap(sentinel[0])`` and ``SrcFile.applymap(directive[0])``.
    Every statement yielded by ``walk(parsed, 1)`` is given the label
    ``int(label[0])`` -- on the statement itself for ``Comment`` instances,
    on its ``item`` otherwise -- and parsed with ``parse_f2003()``.  The
    result replaces the located statement via
    ``insert_content(..., remove_olditem=True)``.

    Nothing happens when no statement matches.  ``target_stmt`` is not
    used here.
    """
    line = ''
    new_target_stmt = None
    for node, _depth in walk(inputfile.tree, -1):
        if node.item.span != span:
            continue
        new_target_stmt = node
        line = '!%s$ %s'%(SrcFile.applymap(sentinel[0]), SrcFile.applymap(directive[0]))
        break

    if not line:
        return

    parsed = parse(line, analyze=False, ignore_comments=False)
    if len(parsed.content) == 0:
        return

    for node, _depth in walk(parsed, 1):
        number = int(label[0])
        # Comments carry the label themselves, other statements on .item.
        holder = node if isinstance(node, Comment) else node.item
        holder.label = number
        node.parse_f2003()

    insert_content(new_target_stmt, parsed.content, remove_olditem=True)
Example #10
0
def directive(inputfile, target_stmt, label, sentinel, directive, span):
    """Replace the statement at *span* by a labelled ``!<sentinel>$`` line.

    ``inputfile.tree`` is walked for the first statement whose
    ``item.span`` equals *span*; without a match the function returns
    without doing anything.  Otherwise ``'!%s$ %s'`` is filled with
    ``SrcFile.applymap(sentinel[0])`` and ``SrcFile.applymap(directive[0])``
    and parsed with comments kept.  Each statement yielded by
    ``walk(parsed, 1)`` gets the label ``int(label[0])`` (stored on the
    statement for ``Comment`` instances, on ``item`` otherwise) and is
    parsed with ``parse_f2003()``.  The parsed content then replaces the
    matched statement via ``insert_content(..., remove_olditem=True)``.

    ``target_stmt`` is not used here.
    """
    for new_target_stmt, _depth in walk(inputfile.tree, -1):
        if new_target_stmt.item.span == span:
            break
    else:
        # No statement at the requested span: nothing to replace.
        return

    mapped_sentinel = SrcFile.applymap(sentinel[0])
    mapped_directive = SrcFile.applymap(directive[0])
    line = '!%s$ %s' % (mapped_sentinel, mapped_directive)

    parsed = parse(line, analyze=False, ignore_comments=False)
    if len(parsed.content) > 0:
        for node, _depth in walk(parsed, 1):
            number = int(label[0])
            if isinstance(node, Comment):
                node.label = number
            else:
                node.item.label = number
            node.parse_f2003()

        insert_content(new_target_stmt, parsed.content, remove_olditem=True)
Example #11
0
def loop_unroll(targets, factor, method):
    """Unroll every ``Do`` statement in *targets*.

    *factor* is either ``'full'`` or a string of digits:

    * ``'full'``: when the loop indices can be computed from
      ``target_stmt.get_param`` for start, end and step, the loop content
      is unrolled once per index via ``_unroll``; otherwise a warning is
      logged and the loop is left alone.
    * digits: when the number of computed indices equals the factor, the
      loop is unrolled the same way.  Otherwise two loops are emitted: the
      original loop with step ``step*factor`` and end ``end-step*factor``
      whose body is unrolled *factor* times, followed by a label-free loop
      starting at the loop variable whose body is unrolled once.

    The generated text is parsed and passed to ``insert_content`` for the
    target.  *method* is forwarded to ``_unroll`` unchanged.  Targets that
    are not ``Do`` instances are reported with ``Logger.warn`` and skipped.

    Raises:
        ProgramException: the target's ``f2003`` object is not a
            ``Nonlabel_Do_Stmt``.
        UserException: *factor* is neither ``'full'`` nor all digits.
    """
    for target_stmt in targets:
        if not isinstance(target_stmt, Do):
            Logger.warn('Target statment is not Do type: %s' %
                        target_stmt.__class__)
            continue

        # collect loop control
        target_f2003 = target_stmt.f2003
        if not isinstance(target_f2003, Nonlabel_Do_Stmt):
            # Report the class of the object actually inspected (the old
            # message referenced an undefined name and raised NameError).
            raise ProgramException('Not supported type: %s' %
                                   target_f2003.__class__)

        loop_control = target_f2003.items[1]
        loop_var = loop_control.items[0].string.lower()
        bounds = loop_control.items[1]
        start_idx = bounds[0]
        end_idx = bounds[1]
        # An explicit step is the third bound; default to 1 when the loop
        # has only start and end.  (The test used to be inverted, so a
        # given step was ignored and a missing one raised IndexError.)
        if len(bounds) >= 3 and bounds[2] is not None:
            step = bounds[2]
        else:
            step = Int_Literal_Constant(str(1))

        # collect loop controls through static analysis
        start_num = target_stmt.get_param(start_idx)
        end_num = target_stmt.get_param(end_idx)
        step_num = target_stmt.get_param(step)
        try:
            loop_indices = range(start_num, end_num + 1, step_num)
        except Exception:
            # Bounds that are not usable integers (or a zero step) mean the
            # trip count is unknown.  KeyboardInterrupt/SystemExit are no
            # longer swallowed.
            loop_indices = None

        # TODO: modify analysis if required
        lines = []
        if factor == 'full':
            if loop_indices is not None:
                lines = _unroll(target_stmt.content,
                                loop_var,
                                len(loop_indices),
                                method,
                                start_index=start_num)
            else:
                Logger.warn('Loopcontrol is not collected')

        elif factor.isdigit():
            factor_num = int(factor)
            if loop_indices is not None and len(loop_indices) == factor_num:
                lines = _unroll(target_stmt.content,
                                loop_var,
                                factor_num,
                                method,
                                start_index=start_num)
            else:
                # main loop: replace end and step
                newstep = '%s*%s' % (step.tofortran(), factor)
                newend = '%s-%s' % (end_idx.tofortran(), newstep)
                lines.append(target_stmt.tooc(do_end=newend, do_step=newstep))
                lines.extend(
                    _unroll(target_stmt.content, loop_var, factor_num, method))
                lines.append(target_stmt.content[-1].tooc())

                # remainder loop: continue from the current loop variable
                newstart = loop_var
                lines.append(
                    target_stmt.tooc(do_start=newstart, remove_label=True))
                lines.extend(_unroll(target_stmt.content, loop_var, 1, method))
                lines.append(target_stmt.content[-1].tooc(remove_label=True))
        else:
            raise UserException('Unknown unroll factor: %s' % factor)

        if lines:
            parsed = parse('\n'.join(lines), analyze=False)
            if len(parsed.content) > 0:
                for stmt, depth in walk(parsed, -1):
                    stmt.parse_f2003()
                insert_content(target_stmt, parsed.content)
Example #12
0
def loop_unroll(targets, factor, method):
    """Unroll every ``Do`` statement in *targets*.

    *factor* is either ``"full"`` or a string of digits:

    * ``"full"``: when the loop indices can be computed from
      ``target_stmt.get_param`` for start, end and step, the loop content
      is unrolled once per index via ``_unroll``; otherwise a warning is
      logged and the loop is left alone.
    * digits: when the number of computed indices equals the factor, the
      loop is unrolled the same way.  Otherwise two loops are emitted: the
      original loop with step ``step*factor`` and end ``end-step*factor``
      whose body is unrolled *factor* times, followed by a label-free loop
      starting at the loop variable whose body is unrolled once.

    The generated text is parsed and passed to ``insert_content`` for the
    target.  *method* is forwarded to ``_unroll`` unchanged.  Targets that
    are not ``Do`` instances are reported with ``Logger.warn`` and skipped.

    Raises:
        ProgramException: the target's ``f2003`` object is not a
            ``Nonlabel_Do_Stmt``.
        UserException: *factor* is neither ``"full"`` nor all digits.
    """
    for target_stmt in targets:
        if not isinstance(target_stmt, Do):
            Logger.warn("Target statment is not Do type: %s" % target_stmt.__class__)
            continue

        # collect loop control
        target_f2003 = target_stmt.f2003
        if not isinstance(target_f2003, Nonlabel_Do_Stmt):
            # Report the class of the object actually inspected (the old
            # message referenced an undefined name and raised NameError).
            raise ProgramException("Not supported type: %s" % target_f2003.__class__)

        loop_control = target_f2003.items[1]
        loop_var = loop_control.items[0].string.lower()
        bounds = loop_control.items[1]
        start_idx = bounds[0]
        end_idx = bounds[1]
        # An explicit step is the third bound; default to 1 when the loop
        # has only start and end.  (The test used to be inverted, so a
        # given step was ignored and a missing one raised IndexError.)
        if len(bounds) >= 3 and bounds[2] is not None:
            step = bounds[2]
        else:
            step = Int_Literal_Constant(str(1))

        # collect loop controls through static analysis
        start_num = target_stmt.get_param(start_idx)
        end_num = target_stmt.get_param(end_idx)
        step_num = target_stmt.get_param(step)
        try:
            loop_indices = range(start_num, end_num + 1, step_num)
        except Exception:
            # Bounds that are not usable integers (or a zero step) mean the
            # trip count is unknown.  KeyboardInterrupt/SystemExit are no
            # longer swallowed.
            loop_indices = None

        # TODO: modify analysis if required
        lines = []
        if factor == "full":
            if loop_indices is not None:
                lines = _unroll(target_stmt.content, loop_var, len(loop_indices), method, start_index=start_num)
            else:
                Logger.warn("Loopcontrol is not collected")

        elif factor.isdigit():
            factor_num = int(factor)
            if loop_indices is not None and len(loop_indices) == factor_num:
                lines = _unroll(target_stmt.content, loop_var, factor_num, method, start_index=start_num)
            else:
                # main loop: replace end and step
                newstep = "%s*%s" % (step.tofortran(), factor)
                newend = "%s-%s" % (end_idx.tofortran(), newstep)
                lines.append(target_stmt.tooc(do_end=newend, do_step=newstep))
                lines.extend(_unroll(target_stmt.content, loop_var, factor_num, method))
                lines.append(target_stmt.content[-1].tooc())

                # remainder loop: continue from the current loop variable
                newstart = loop_var
                lines.append(target_stmt.tooc(do_start=newstart, remove_label=True))
                lines.extend(_unroll(target_stmt.content, loop_var, 1, method))
                lines.append(target_stmt.content[-1].tooc(remove_label=True))
        else:
            raise UserException("Unknown unroll factor: %s" % factor)

        if lines:
            parsed = parse("\n".join(lines), analyze=False)
            if len(parsed.content) > 0:
                for stmt, depth in walk(parsed, -1):
                    stmt.parse_f2003()
                insert_content(target_stmt, parsed.content)