Exemplo n.º 1
0
 def single_loop_shape(self, args):
     """Build a LoopShape for a lone loop variable with default bounds.

     ``args[0]`` is passed through as the first LoopShape argument and its
     ``var`` attribute names the default bounds: accesses called
     ``<var>_greater_eq`` and ``<var>_less_eq``.  The step defaults to the
     integer literal 1.
     """
     name = args[0].var
     lower_bound = Access(f'{name}_greater_eq')
     upper_bound = Access(f'{name}_less_eq')
     unit_step = Literal(int, 1)
     return LoopShape(args[0], lower_bound, upper_bound, unit_step)
Exemplo n.º 2
0
 def loop_shape_parts(self, args):
     """Fold the builders in *args* into one and build the final shape.

     The first non-None entry becomes the accumulator; every later entry is
     merged into it via ``merge``.  The accumulated builder must carry a
     loop variable, whose ``var`` names the default bounds
     (``<var>_greater_eq`` / ``<var>_less_eq``); the default step is the
     integer literal 1.  The defaults are handed to ``build``.
     """
     combined = None
     for part in args:
         if combined is not None:
             combined.merge(part)
             continue
         combined = part

     assert combined is not None
     assert combined.loop_var is not None

     var_name = combined.loop_var.var
     # Construct the defaults before looking up ``build`` so the order of
     # evaluation stays: lower bound, upper bound, step, then the call.
     defaults = (
         Access(f'{var_name}_greater_eq'),
         Access(f'{var_name}_less_eq'),
         Literal(int, 1),
     )
     return combined.build(*defaults)
Exemplo n.º 3
0
    def transform(self, pattern):
        """Yield an endless stream of randomly unrolled clones of *pattern*.

        Node ids are assigned to the pattern once.  Every
        ``(loop id, loop variable)`` pair is then ordered from the deepest
        to the shallowest position in the tree.  For each yielded clone and
        each pair, a factor is drawn with
        ``random.randint(1, self.max_factor)``:

        * a factor of 1 leaves that loop variable untouched;
        * a shape whose bounds or step are not integer literals, or whose
          step is not positive, is reported as not unrollable and skipped;
        * otherwise the loop is replaced by an unrolled loop (step multiplied
          by the factor, body repeated ``factor`` times with shifted offsets)
          followed by a remainder loop that keeps the original step.

        This is a generator that never finishes by itself; the caller decides
        how many variants to consume.
        """
        pattern_with_ids = assign_node_ids(pattern)
        loops = get_loops(pattern_with_ids)
        loop_vars = []
        for loop in loops:
            loop_id = loop.attributes['node_id']
            for shape in loop.loop_shapes:
                assert(type(shape.loop_var) == Access)
                loop_vars.append((loop_id, shape.loop_var.var))

        # sort by depth
        def depth_rec(node, id_var_pair, current_depth):
            """Return ``(depth, shape index)`` of the pair under *node*.

            Returns None when the pair is not found below *node*.  Every
            statement of a body is searched, not just the first one.
            """
            loop_id, loop_var = id_var_pair
            if type(node) == Program:
                children = node.body
            elif type(node) == AbstractLoop:
                if node.attributes['node_id'] == loop_id:
                    for i, shape in enumerate(node.loop_shapes):
                        if shape.loop_var.var == loop_var:
                            return (current_depth, i)
                    raise RuntimeError(f'Loop var {loop_var} not found in {node.pprint()}')
                children = node.body
            else:
                # Only programs and loops are descended into.
                return None
            for stmt in children:
                found = depth_rec(stmt, id_var_pair, current_depth+1)
                if found is not None:
                    return found
            return None

        def depth(id_var_pair):
            found = depth_rec(pattern_with_ids, id_var_pair, 0)
            # NOTE(review): a pair can only be missed if its loop sits under
            # a node kind that is not searched above; fall back to the
            # shallowest key so the sort keys stay comparable -- confirm.
            return found if found is not None else (0, 0)

        sorted_loop_vars = sorted(loop_vars, key=depth, reverse=True)

        # TODO: assign unique node ids
        while True:
            cloned = pattern_with_ids.clone()

            for loop_id, loop_var in sorted_loop_vars:
                factor = random.randint(1, self.max_factor)
                if factor == 1:
                    continue

                # Look the loop up again in the clone; the last loop carrying
                # the id wins, as no early exit is taken.
                loop = None
                for candidate in get_loops(cloned):
                    if candidate.attributes['node_id'] == loop_id:
                        loop = candidate
                assert(loop is not None)

                loop_shapes_before = []
                loop_var_index = None
                unroll_shape = None
                remainder_shape = None
                base_step = None

                is_unrollable = True

                for i, shape in enumerate(loop.loop_shapes):
                    if shape.loop_var.var != loop_var:
                        loop_shapes_before.append(shape)
                        continue

                    loop_var_index = i

                    # Build the unroll shape
                    # only support literals for simplicity
                    logger.info('trying')
                    if (type(shape.greater_eq) != Literal or shape.greater_eq.ty != int or
                        type(shape.less_eq) != Literal or shape.less_eq.ty != int or
                        type(shape.step) != Literal or shape.step.ty != int or
                        # The bound arithmetic below counts upwards from
                        # greater_eq and divides by the step, so a zero or
                        # negative step cannot be handled.
                        shape.step.val <= 0):
                        is_unrollable = False
                        break

                    logger.info('passed')

                    base_step = shape.step.val
                    unroll_greater_eq = shape.greater_eq.val
                    unroll_step = base_step * factor
                    # Floor division goes negative for an empty range
                    # (less_eq well below greater_eq); clamp to zero so the
                    # remainder loop cannot start below greater_eq and run
                    # iterations the original loop never had.
                    unroll_n_iterations = max(
                        0,
                        (shape.less_eq.val - shape.greater_eq.val + base_step) // unroll_step)
                    unroll_less_eq = unroll_greater_eq + ((unroll_n_iterations - 1) * unroll_step)
                    unroll_shape = LoopShape(shape.loop_var.clone(),
                                             Literal(int, unroll_greater_eq),
                                             Literal(int, unroll_less_eq),
                                             Literal(int, unroll_step))

                    # Build the remainder shape: it starts right after the
                    # last unrolled chunk and keeps the original step.
                    remainder_greater_eq = unroll_less_eq + unroll_step
                    remainder_less_eq = shape.less_eq.val
                    remainder_shape = LoopShape(shape.loop_var.clone(),
                                                Literal(int, remainder_greater_eq),
                                                Literal(int, remainder_less_eq),
                                                Literal(int, base_step))
                    break

                if not is_unrollable:
                    print(f'{loop_var} is not unrollable')
                    continue
                assert(loop_var_index is not None)
                assert(unroll_shape is not None)
                assert(remainder_shape is not None)

                # Shapes nested inside the unrolled variable within this loop.
                loop_shapes_after = list(loop.loop_shapes[loop_var_index+1:])

                unrolled_body = []
                for f in range(0, factor):
                    # Copy number f of the body addresses loop_var + f * step.
                    replacer = UnrollReplacer(loop_var, f * base_step)
                    unrolled_innermost_body = []
                    for stmt in loop.body:
                        unrolled_stmt = stmt.clone()
                        unrolled_stmt.replace(replacer)
                        unrolled_innermost_body.append(unrolled_stmt)
                    if not loop_shapes_after:
                        unrolled_body += unrolled_innermost_body
                    else:
                        shapes = [shape.clone() for shape in loop_shapes_after]
                        unrolled_body.append(AbstractLoop(shapes,
                                                          unrolled_innermost_body))

                remainder_innermost_body = [stmt.clone() for stmt in loop.body]
                if not loop_shapes_after:
                    remainder_body = remainder_innermost_body
                else:
                    shapes = [shape.clone() for shape in loop_shapes_after]
                    remainder_body = [AbstractLoop(shapes, remainder_innermost_body)]

                unrolled_loop = AbstractLoop([unroll_shape], unrolled_body)
                remainder_loop = AbstractLoop([remainder_shape], remainder_body)

                # The unroll sequence is the unrolled loop followed by the remainder loop
                if not loop_shapes_before:
                    unroll_sequence = [unrolled_loop, remainder_loop]
                else:
                    # The surrounding loop needs to preserve the loop_id
                    # since the surrounding loops may be unrolled as well
                    unroll_sequence = [AbstractLoop(loop_shapes_before,
                                                    [unrolled_loop, remainder_loop],
                                                    loop_id)]

                # Replace the original loop with the unroll sequence
                index = loop.surrounding_loop.find_stmt(loop)
                loop.surrounding_loop.remove_stmt(loop)
                loop.surrounding_loop.insert_stmts(index, unroll_sequence)

            yield cloned
Exemplo n.º 4
0
 def replace(self, node):
     """Return a ``+`` Op adding ``self.offset`` to a clone of *node*.

     *node* itself is left untouched; the offset is wrapped in an integer
     Literal and becomes the second operand.
     """
     operands = [node.clone(), Literal(int, self.offset)]
     return Op('+', operands)
Exemplo n.º 5
0
 def int_literal(self, args):
     """Convert ``args[0]`` with ``int`` and wrap it in an integer Literal."""
     value = int(args[0])
     return Literal(int, value)
Exemplo n.º 6
0
 def float_literal(self, args):
     """Convert ``args[0]`` with ``float`` and wrap it in a float Literal."""
     value = float(args[0])
     return Literal(float, value)